change default sys_promp_path
This commit is contained in:
@@ -37,21 +37,12 @@ class RoutingConfig(KeyConfig):
|
|||||||
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||||
"""base url; could be used to overwrite the baseurl in llm provider"""
|
"""base url; could be used to overwrite the baseurl in llm provider"""
|
||||||
|
|
||||||
sys_promp_json: str = None
|
sys_promp_dir: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts")
|
||||||
"path to json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided"
|
"""path to directory or json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided"""
|
||||||
|
|
||||||
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
||||||
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
super().__post_init__()
|
|
||||||
if self.sys_promp_json is None:
|
|
||||||
# self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts.json")
|
|
||||||
self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts")
|
|
||||||
logger.warning(f"config_f was not provided. Using default: {self.sys_promp_json}")
|
|
||||||
|
|
||||||
assert osp.exists(self.sys_promp_json), f"config_f {self.sys_promp_json} does not exist."
|
|
||||||
|
|
||||||
|
|
||||||
class Route(BaseModel):
|
class Route(BaseModel):
|
||||||
step: Literal["chat", "order"] = Field(
|
step: Literal["chat", "order"] = Field(
|
||||||
@@ -140,14 +131,14 @@ class RoutingGraph(GraphBase):
|
|||||||
self._load_sys_prompts()
|
self._load_sys_prompts()
|
||||||
|
|
||||||
def _load_sys_prompts(self):
|
def _load_sys_prompts(self):
|
||||||
if "json" in self.config.sys_promp_json[-5:]:
|
if "json" in self.config.sys_promp_dir[-5:]:
|
||||||
logger.info("loading sys prompt from json")
|
logger.info("loading sys prompt from json")
|
||||||
with open(self.config.sys_promp_json , "r") as f:
|
with open(self.config.sys_promp_dir , "r") as f:
|
||||||
self.prompt_dict:Dict[str, str] = commentjson.load(f)
|
self.prompt_dict:Dict[str, str] = commentjson.load(f)
|
||||||
|
|
||||||
elif osp.isdir(self.config.sys_promp_json):
|
elif osp.isdir(self.config.sys_promp_dir):
|
||||||
logger.info("loading sys_prompt from txt")
|
logger.info("loading sys_prompt from txt")
|
||||||
sys_fs = glob.glob(osp.join(self.config.sys_promp_json, "*.txt"))
|
sys_fs = glob.glob(osp.join(self.config.sys_promp_dir, "*.txt"))
|
||||||
sys_fs = sorted([e for e in sys_fs if not ("optional" in e)])
|
sys_fs = sorted([e for e in sys_fs if not ("optional" in e)])
|
||||||
self.prompt_dict = {}
|
self.prompt_dict = {}
|
||||||
for sys_f in sys_fs:
|
for sys_f in sys_fs:
|
||||||
@@ -155,7 +146,7 @@ class RoutingGraph(GraphBase):
|
|||||||
with open(sys_f, "r") as f:
|
with open(sys_f, "r") as f:
|
||||||
self.prompt_dict[key] = f.read()
|
self.prompt_dict[key] = f.read()
|
||||||
else:
|
else:
|
||||||
err_msg = f"{self.config.sys_promp_json} is not supported"
|
err_msg = f"{self.config.sys_promp_dir} is not supported"
|
||||||
assert 0, err_msg
|
assert 0, err_msg
|
||||||
|
|
||||||
for k, _ in self.prompt_dict.items():
|
for k, _ in self.prompt_dict.items():
|
||||||
|
|||||||
Reference in New Issue
Block a user