diff --git a/configs/route_sys_prompts.json b/configs/route_sys_prompts.json new file mode 100644 index 0000000..0e35c59 --- /dev/null +++ b/configs/route_sys_prompts.json @@ -0,0 +1,8 @@ +{ // Prompt for router; mandatory to say return as json for route_prompt; if not, there will be bugs + "route_prompt" : "Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input", + + // Prompt for tool branch + "tool_prompt" : "You must use tool to complete the possible task" + + // Optionally set chat_prompt to overwrite the system prompt from xiaozhi +} \ No newline at end of file diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index 85948c1..592a5ba 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -7,6 +7,8 @@ from PIL import Image from io import BytesIO import matplotlib.pyplot as plt import jax +import os.path as osp +import commentjson from lang_agent.config import KeyConfig from lang_agent.tool_manager import ToolManager, ToolManagerConfig @@ -34,9 +36,20 @@ class RoutingConfig(KeyConfig): base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1" """base url; could be used to overwrite the baseurl in llm provider""" + sys_promp_json: str = None + "path to json containing system prompt for graphs; Will overwrite system prompt from xiaozhi if 'chat_prompt' is provided" + tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) + def __post_init__(self): + super().__post_init__() + if self.sys_promp_json is None: + self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts.json") + logger.warning(f"config_f was not provided. Using default: {self.sys_promp_json}") + + assert osp.exists(self.sys_promp_json), f"config_f {self.sys_promp_json} does not exist." 
+ class Route(BaseModel): step: Literal["chat", "order"] = Field( @@ -89,9 +102,9 @@ class RoutingGraph(GraphBase): def _build_modules(self): self.llm = init_chat_model(model=self.config.llm_name, - model_provider=self.config.llm_provider, - api_key=self.config.api_key, - base_url=self.config.base_url) + model_provider=self.config.llm_provider, + api_key=self.config.api_key, + base_url=self.config.base_url) self.memory = MemorySaver() self.router = self.llm.with_structured_output(Route) @@ -99,12 +112,16 @@ class RoutingGraph(GraphBase): self.chat_model = create_agent(self.llm, [], checkpointer=self.memory) self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(), checkpointer=self.memory) + with open(self.config.sys_promp_json , "r") as f: + self.prompt_dict = commentjson.load(f) + def _router_call(self, state:State): decision:Route = self.router.invoke( [ SystemMessage( - content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input" + # content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input" + content=self.prompt_dict.get("route_prompt") ), self._get_human_msg(state) ] @@ -137,6 +154,16 @@ class RoutingGraph(GraphBase): inp = state["messages"], state["inp"][1] else: inp = state["inp"] + + if self.prompt_dict.get("chat_prompt") is not None: + inp = {"messages":[ + SystemMessage( + # "You must use tool to complete the possible task" + self.prompt_dict["chat_prompt"] + ), + *state["inp"][0]["messages"][1:] + ]}, state["inp"][1] + out = self.chat_model.invoke(*inp) return {"messages": out} @@ -145,9 +172,9 @@ class RoutingGraph(GraphBase): def _tool_model_call(self, state:State): inp = {"messages":[ SystemMessage( - "You must use tool to complete the possible task" + # "You must use tool to complete the possible task" + self.prompt_dict["tool_prompt"] ), - # self._get_human_msg(state) *state["inp"][0]["messages"][1:] ]}, state["inp"][1]