Spaces:
Runtime error
Runtime error
import os | |
from typing import Dict, List, Optional, Tuple, Union | |
import gradio as gr | |
from common.call_llm import chat, chat_stream_generator | |
from plugin_task.model import Plugin, ReActStep | |
from plugin_task.plugins import PLUGIN_JSON_SCHEMA, PLUGINS | |
from plugin_task.prompt import ( | |
FILLING_SLOT_PROMPT, | |
FINAL_PROMPT, | |
INTENT_RECOGNITION_PROMPT, | |
) | |
from plugin_task.util import ( | |
build_prompt_plugin_variables, | |
parse_reAct_step, | |
plugin_parameter_validator, | |
) | |
# Endpoint handed to every model call made from this module; None if the
# PLUGIN_ENDPOINT environment variable is not set.
PLUGIN_ENDPOINT = os.getenv("PLUGIN_ENDPOINT")
def api_plugin_chat(
    session: Dict,
    message: str,
    chat_history: List[List[str]],
    *radio_plugins,
):
    """Handle one chat turn in plugin mode and stream UI updates.

    When no plugin conversation is active, the enabled plugins are derived from
    the radio states and the message is classified by ``intent_recognition``.
    Otherwise the message is treated as a reply in the ongoing slot-filling
    dialogue and handled by ``filling_slot_with_loop``.

    Args:
        session: Mutable per-user state dict; non-empty while a plugin
            conversation is in progress.
        message: The user's latest message.
        chat_history: Chat rows; the second element of the last row receives
            the reply.
        *radio_plugins: One status string per plugin, positionally matching
            PLUGIN_JSON_SCHEMA; the value "开启" marks a plugin as enabled.

    Yields:
        Tuples whose first element is ``session``. The direct branches here
        yield ``(session, None, chat_history)``; the plugin and chat branches
        forward what ``call_final_answer`` / ``call_chat`` yield.
    """
    if check_in_plugin_session(session):
        # A plugin conversation is already running: the message answers it.
        intention, reAct_step = filling_slot_with_loop(session, message)
    else:
        enabled_plugins = prepare_plugins(radio_plugins)
        if not enabled_plugins:
            gr.Warning("没有启用插件")
            return
        intention, reAct_step = intent_recognition(message, enabled_plugins)
        if intention in ("ask_user_for_required_params", "plugin"):
            # Start a new plugin conversation rooted at this message.
            session["origin_message"] = message
            session["choice_plugin"] = reAct_step.thought["tool_to_use_for_user"]
            session["reAct_step"] = [reAct_step]

    banner = "=" * 25 + "END" + "=" * 25
    print(
        f"[API_PLUGIN_CHAT]. message: {message},\n intention: {intention},\n session: {session}\n"
        + banner
    )

    if intention == "fail":
        # On failure reAct_step carries the raw model response text.
        chat_history[-1][1] = reAct_step
        session.clear()
        yield session, None, chat_history
    elif intention == "ask_user_for_required_params":
        chat_history[-1][1] = reAct_step.action_input.get("question", "")
        yield session, None, chat_history
    elif intention == "plugin":
        yield from call_final_answer(session, reAct_step, chat_history)
    elif intention == "chat":
        yield from call_chat(session, message, chat_history)
    elif intention == "end":
        session.clear()
        chat_history[-1][1] = "[系统消息]:当前插件对话结束"
        yield session, None, chat_history
def filling_slot_with_loop(
    session: Dict, message: str, retry: int = 3
) -> Tuple[str, Optional[Union[ReActStep, str]]]:
    """Continue the slot-filling dialogue for the plugin chosen in this session.

    The user's latest ``message`` is recorded as the observation of the most
    recent ReAct step (if that step has none yet). The model is then asked, via
    FILLING_SLOT_PROMPT, for the next step until it returns one that can be
    acted on or the attempt budget runs out.

    Args:
        session: Plugin session dict. Reads "choice_plugin", "origin_message"
            and "reAct_step" (list of ReActStep); new steps are appended to
            "reAct_step".
        message: The user's latest reply.
        retry: Extra model calls allowed beyond the first. An attempt is
            consumed by an unparseable response, a plugin call whose parameters
            fail validation, or a step that matches none of the handled cases.

    Returns:
        ("end", step) when the model ends the conversation;
        ("ask_user_for_required_params", step) when parameters are missing;
        ("plugin", step) when the plugin can be run, with ``step.action`` set
        to the plugin name and ``step.action_input`` holding its arguments;
        ("fail", raw_model_response) when all attempts are used up.
    """
    plugin = PLUGINS[session["choice_plugin"]]
    plugin_name = plugin.unique_name_for_model
    model_response = ""
    # Every iteration is one model call and consumes one attempt, so a model
    # that keeps producing unusable steps cannot make this loop spin forever.
    for _attempt in range(max(retry, 0) + 1):
        latest_step = session["reAct_step"][-1]
        if not latest_step.observation:
            latest_step.observation = {"user_answer": message}
        reAct_step_str = "\n".join(step.to_str() for step in session["reAct_step"])
        ask_content = FILLING_SLOT_PROMPT.format(
            plugin_name=plugin_name,
            description_for_human=plugin.description_for_human,
            parameter_schema=plugin.parameter_schema,
            question=session["origin_message"],
            reAct_step_str=reAct_step_str,
        )
        model_response = chat(
            [{"content": ask_content, "role": "user"}],
            stop="Observation",
            endpoint=PLUGIN_ENDPOINT,
        )
        print(
            f"[FILLING_SLOT_WITH_LOOP] message: {message} ask_content: {ask_content}\n model_response: {model_response}\n"
            + "=" * 25
            + "END"
            + "=" * 25
        )
        reAct_step = parse_reAct_step(model_response)
        if not reAct_step:
            continue

        tool_to_use_for_user = reAct_step.thought.get("tool_to_use_for_user")
        known_parameter = reAct_step.thought.get("known_params", {})

        if (
            reAct_step.action == "end_conversation"
            or tool_to_use_for_user == "end_conversation"
        ):
            return "end", reAct_step

        if (
            reAct_step.action == "ASK_USER_FOR_REQUIRED_PARAMS"
            and tool_to_use_for_user == plugin_name
        ):
            passed, _ = plugin_parameter_validator(
                known_parameter,
                tool_to_use_for_user,
            )
            if passed:
                # Everything required is already known, so promote the question
                # step to a plugin call. Its action_input holds the question for
                # the user, not plugin arguments, and the plugin is later run
                # with **action_input, so swap in the validated parameters.
                reAct_step.action = tool_to_use_for_user
                reAct_step.action_input = known_parameter
                action = "plugin"
            else:
                action = "ask_user_for_required_params"
            session["reAct_step"].append(reAct_step)
            return action, reAct_step

        if (
            reAct_step.action == plugin_name
            and tool_to_use_for_user == plugin_name
        ):
            passed, invalid_info = plugin_parameter_validator(
                known_parameter,
                tool_to_use_for_user,
            )
            if passed:
                session["reAct_step"].append(reAct_step)
                return "plugin", reAct_step
            # Feed the validation errors back to the model on the next attempt.
            reAct_step.observation = {"tool_parameters_verification": invalid_info}
            session["reAct_step"].append(reAct_step)
    return "fail", model_response
def call_chat(session: Dict, message: str, chat_history: List[List[str]]):
    """Fall back to ordinary chat and stream its output.

    Each chunk produced by ``generate_chat`` is unpacked after ``session``, so
    callers receive ``(session, *chunk)`` tuples.
    """
    # Imported lazily, at call time, rather than at module import.
    from chat_task.chat import generate_chat

    stream = generate_chat(message, chat_history, PLUGIN_ENDPOINT)
    for piece in stream:
        yield (session, *piece)
def check_in_plugin_session(session: Dict) -> bool:
    """Return True when a plugin conversation is in progress.

    Any non-empty session counts as an active plugin conversation; an empty
    (or falsy) session means a fresh turn.
    """
    return True if session else False
def prepare_plugins(
    radio_plugins: List[str],
) -> List[Plugin]:
    """Return the Plugin objects whose radio status is "开启" (enabled).

    ``radio_plugins[i]`` is the status for ``PLUGIN_JSON_SCHEMA[i]``; the
    plugin is looked up in PLUGINS by that schema's "unique_name_for_model".
    Order follows the radio order.
    """
    enabled = []
    for index, status in enumerate(radio_plugins):
        if status != "开启":
            continue
        model_name = PLUGIN_JSON_SCHEMA[index]["unique_name_for_model"]
        enabled.append(PLUGINS[model_name])
    return enabled
def intent_recognition(
    message: str, choice_plugins: List[Plugin]
) -> Tuple[str, Union[ReActStep, str]]:
    """Classify the user's first message against the enabled plugins.

    The model is queried with INTENT_RECOGNITION_PROMPT up to three times until
    its response parses into a ReActStep.

    Args:
        message: The user's message.
        choice_plugins: Enabled plugins to offer to the model.

    Returns:
        ("fail", raw_model_response) when no attempt parsed;
        ("chat", step) for TOOL_OTHER or when no enabled plugin was chosen;
        ("end", step) when the model ends the conversation;
        ("ask_user_for_required_params", step) when a plugin was chosen but
        its parameters are incomplete;
        ("plugin", step) when a plugin can be run, with ``step.action`` set to
        the plugin name.
    """
    plugins, plugin_names = build_prompt_plugin_variables(choice_plugins)
    ask_content = INTENT_RECOGNITION_PROMPT.format(
        plugins=plugins, plugin_names=plugin_names, question=message
    )
    print(
        f"[INTENT_RECOGNITION] message:{message} ask_content: {ask_content}"
        + "=" * 25
        + "END"
        + "=" * 25
    )
    reAct_step = None
    model_response = ""
    for _attempt in range(3):
        model_response = chat(
            [{"content": ask_content, "role": "user"}],
            stop="Observation",
            endpoint=PLUGIN_ENDPOINT,
        )
        reAct_step = parse_reAct_step(model_response)
        if reAct_step:
            break
    if not reAct_step:
        print(f"[INTENT_RECOGNITION] model fail: {model_response}")
        return "fail", model_response

    # Comma-separated names from the prompt builder; strip stray whitespace so
    # "a, b" still matches. A list (not a set) keeps membership tests safe for
    # unhashable values the model might put in "tool_to_use_for_user".
    available_plugins = [name.strip() for name in plugin_names.split(",")]
    tool_to_use_for_user = reAct_step.thought.get("tool_to_use_for_user")
    known_params = reAct_step.thought.get("known_params", {})

    if reAct_step.action == "TOOL_OTHER":
        return "chat", reAct_step
    if (
        reAct_step.action == "end_conversation"
        and tool_to_use_for_user == "end_conversation"
    ):
        return "end", reAct_step
    if tool_to_use_for_user not in available_plugins:
        return "chat", reAct_step

    if reAct_step.action in ("ASK_USER_FOR_INTENT", "ASK_USER_FOR_REQUIRED_PARAMS"):
        passed, _ = plugin_parameter_validator(
            known_params,
            tool_to_use_for_user,
        )
        if passed:
            # Promote the ask-user step to a plugin call. Its action_input holds
            # the question for the user, and the plugin is later run with
            # **action_input, so replace it with the validated parameters.
            reAct_step.action = tool_to_use_for_user
            reAct_step.action_input = known_params
            return "plugin", reAct_step
        return "ask_user_for_required_params", reAct_step
    if reAct_step.action in available_plugins:
        return "plugin", reAct_step
    return "chat", reAct_step
def call_final_answer(session: Dict, reAct_step: ReActStep, history: List[List[str]]):
    """Run the selected plugin, then stream the model's final answer into chat.

    The plugin named by ``reAct_step.action`` is called with
    ``**reAct_step.action_input``. Its result becomes the observation of the
    latest session step, and FINAL_PROMPT (original question plus the ReAct
    trace) is streamed from the model into the last chat row.

    Args:
        session: Plugin session dict (reads "reAct_step", "origin_message").
        reAct_step: The step describing the plugin call.
        history: Chat rows; text is appended to ``history[-1][1]``.

    Yields:
        ``(session, None, history)`` after each streamed piece of text.

    The session is always cleared on exit (normal completion, a plugin or
    model error, or the consumer closing the stream), so a failed or abandoned
    answer does not leave a stale plugin conversation behind.
    """
    try:
        plugin_result = PLUGINS[reAct_step.action].run(**reAct_step.action_input)
        latest_step = session["reAct_step"][-1]
        latest_step.observation = {"tool_response": plugin_result}
        reAct_step_str = "\n".join(step.to_str() for step in session["reAct_step"])
        final_prompt = FINAL_PROMPT.format(
            question=session["origin_message"],
            reAct_step_str=reAct_step_str,
        )
        print(
            f"[CALL_FINAL_ANSWER] final_prompt: {final_prompt}\n"
            + "=" * 25
            + "END"
            + "=" * 25
        )
        stream_response = chat_stream_generator(
            [{"content": final_prompt, "role": "user"}],
            endpoint=PLUGIN_ENDPOINT,
        )
        for character in stream_response:
            # The bot slot may still be a None placeholder; treat it as "".
            history[-1][1] = (history[-1][1] or "") + character
            yield session, None, history
    finally:
        session.clear()