diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index 2dc791a..ed4c7f9 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -37,21 +37,12 @@ class RoutingConfig(KeyConfig): base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1" """base url; could be used to overwrite the baseurl in llm provider""" - sys_promp_json: str = None - "path to json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided" + sys_promp_dir: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts") + """path to directory or json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided""" tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) - def __post_init__(self): - super().__post_init__() - if self.sys_promp_json is None: - # self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts.json") - self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts") - logger.warning(f"config_f was not provided. Using default: {self.sys_promp_json}") - - assert osp.exists(self.sys_promp_json), f"config_f {self.sys_promp_json} does not exist." - class Route(BaseModel): step: Literal["chat", "order"] = Field( @@ -140,14 +131,14 @@ class RoutingGraph(GraphBase): self._load_sys_prompts() def _load_sys_prompts(self): - if "json" in self.config.sys_promp_json[-5:]: + if "json" in self.config.sys_promp_dir[-5:]: logger.info("loading sys prompt from json") - with open(self.config.sys_promp_json , "r") as f: + with open(self.config.sys_promp_dir , "r") as f: self.prompt_dict:Dict[str, str] = commentjson.load(f) - elif osp.isdir(self.config.sys_promp_json): + elif osp.isdir(self.config.sys_promp_dir): logger.info("loading sys_prompt from txt") - sys_fs = glob.glob(osp.join(self.config.sys_promp_json, "*.txt")) + sys_fs = glob.glob(osp.join(self.config.sys_promp_dir, "*.txt")) sys_fs = sorted([e for e in sys_fs if not ("optional" in e)]) self.prompt_dict = {} for sys_f in sys_fs: @@ -155,7 +146,7 @@ class RoutingGraph(GraphBase): with open(sys_f, "r") as f: self.prompt_dict[key] = f.read() else: - err_msg = f"{self.config.sys_promp_json} is not supported" + err_msg = f"{self.config.sys_promp_dir} is not supported" assert 0, err_msg for k, _ in self.prompt_dict.items():