diff --git a/configs/route_sys_prompts/optional_chat_prompt.txt b/configs/route_sys_prompts/optional_chat_prompt.txt new file mode 100644 index 0000000..ede1ef0 --- /dev/null +++ b/configs/route_sys_prompts/optional_chat_prompt.txt @@ -0,0 +1 @@ +create a chat_prompt.txt to overwrite the system prompt from xiaozhi \ No newline at end of file diff --git a/configs/route_sys_prompts/route_prompt.txt b/configs/route_sys_prompts/route_prompt.txt new file mode 100644 index 0000000..c46bffb --- /dev/null +++ b/configs/route_sys_prompts/route_prompt.txt @@ -0,0 +1 @@ +Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input diff --git a/configs/route_sys_prompts/tool_prompt.txt b/configs/route_sys_prompts/tool_prompt.txt new file mode 100644 index 0000000..a2f8359 --- /dev/null +++ b/configs/route_sys_prompts/tool_prompt.txt @@ -0,0 +1 @@ +You must use tool to complete the possible task diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index 7dfaed9..da09a5d 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -9,6 +9,7 @@ import matplotlib.pyplot as plt import jax import os.path as osp import commentjson +import glob from lang_agent.config import KeyConfig from lang_agent.tool_manager import ToolManager, ToolManagerConfig @@ -45,7 +46,8 @@ class RoutingConfig(KeyConfig): def __post_init__(self): super().__post_init__() if self.sys_promp_json is None: - self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts.json") + # self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts.json") + self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts") logger.warning(f"config_f was not provided. Using default: {self.sys_promp_json}") assert osp.exists(self.sys_promp_json), f"config_f {self.sys_promp_json} does not exist." @@ -119,8 +121,31 @@ class RoutingGraph(GraphBase): self.chat_model = create_agent(self.llm, self._get_chat_tools(tool_manager), checkpointer=self.memory) self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory) - with open(self.config.sys_promp_json , "r") as f: - self.prompt_dict:Dict[str, str] = commentjson.load(f) + self._load_sys_prompts() + + def _load_sys_prompts(self): + if "json" in self.config.sys_promp_json[-5:]: + logger.info("loading sys prompt from json") + with open(self.config.sys_promp_json , "r") as f: + self.prompt_dict:Dict[str, str] = commentjson.load(f) + + elif osp.isdir(self.config.sys_promp_json): + logger.info("loading sys_prompt from txt") + sys_fs = glob.glob(osp.join(self.config.sys_promp_json, "*.txt")) + sys_fs = sorted([e for e in sys_fs if not ("optional" in e)]) + assert len(sys_fs) <= 3, "AT MOST 3 PROMPT!" + self.prompt_dict = {} + for sys_f in sys_fs: + key = osp.basename(sys_f).split(".")[0] + with open(sys_f, "r") as f: + self.prompt_dict[key] = f.read() + else: + err_msg = f"{self.config.sys_promp_json} is not supported" + assert 0, err_msg + + for k, _ in self.prompt_dict.items(): + logger.info(f"loaded {k} system prompt") + def _router_call(self, state:State):