support load sys prompt from txt
This commit is contained in:
1
configs/route_sys_prompts/optional_chat_prompt.txt
Normal file
1
configs/route_sys_prompts/optional_chat_prompt.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
create a chat_prompt.txt to overwrite the system prompt from xiaozhi
|
||||||
1
configs/route_sys_prompts/route_prompt.txt
Normal file
1
configs/route_sys_prompts/route_prompt.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input
|
||||||
1
configs/route_sys_prompts/tool_prompt.txt
Normal file
1
configs/route_sys_prompts/tool_prompt.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
You must use tool to complete the possible task
|
||||||
@@ -9,6 +9,7 @@ import matplotlib.pyplot as plt
|
|||||||
import jax
|
import jax
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import commentjson
|
import commentjson
|
||||||
|
import glob
|
||||||
|
|
||||||
from lang_agent.config import KeyConfig
|
from lang_agent.config import KeyConfig
|
||||||
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
||||||
@@ -45,7 +46,8 @@ class RoutingConfig(KeyConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
if self.sys_promp_json is None:
|
if self.sys_promp_json is None:
|
||||||
self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts.json")
|
# self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts.json")
|
||||||
|
self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts")
|
||||||
logger.warning(f"config_f was not provided. Using default: {self.sys_promp_json}")
|
logger.warning(f"config_f was not provided. Using default: {self.sys_promp_json}")
|
||||||
|
|
||||||
assert osp.exists(self.sys_promp_json), f"config_f {self.sys_promp_json} does not exist."
|
assert osp.exists(self.sys_promp_json), f"config_f {self.sys_promp_json} does not exist."
|
||||||
@@ -119,8 +121,31 @@ class RoutingGraph(GraphBase):
|
|||||||
self.chat_model = create_agent(self.llm, self._get_chat_tools(tool_manager), checkpointer=self.memory)
|
self.chat_model = create_agent(self.llm, self._get_chat_tools(tool_manager), checkpointer=self.memory)
|
||||||
self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory)
|
self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory)
|
||||||
|
|
||||||
with open(self.config.sys_promp_json , "r") as f:
|
self._load_sys_prompts()
|
||||||
self.prompt_dict:Dict[str, str] = commentjson.load(f)
|
|
||||||
|
def _load_sys_prompts(self):
|
||||||
|
if "json" in self.config.sys_promp_json[-5:]:
|
||||||
|
logger.info("loading sys prompt from json")
|
||||||
|
with open(self.config.sys_promp_json , "r") as f:
|
||||||
|
self.prompt_dict:Dict[str, str] = commentjson.load(f)
|
||||||
|
|
||||||
|
elif osp.isdir(self.config.sys_promp_json):
|
||||||
|
logger.info("loading sys_prompt from txt")
|
||||||
|
sys_fs = glob.glob(osp.join(self.config.sys_promp_json, "*.txt"))
|
||||||
|
sys_fs = sorted([e for e in sys_fs if not ("optional" in e)])
|
||||||
|
assert len(sys_fs) <= 3, "AT MOST 3 PROMPT!"
|
||||||
|
self.prompt_dict = {}
|
||||||
|
for sys_f in sys_fs:
|
||||||
|
key = osp.basename(sys_f).split(".")[0]
|
||||||
|
with open(sys_f, "r") as f:
|
||||||
|
self.prompt_dict[key] = f.read()
|
||||||
|
else:
|
||||||
|
err_msg = f"{self.config.sys_promp_json} is not supported"
|
||||||
|
assert 0, err_msg
|
||||||
|
|
||||||
|
for k, _ in self.prompt_dict.items():
|
||||||
|
logger.info(f"loaded {k} system prompt")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _router_call(self, state:State):
|
def _router_call(self, state:State):
|
||||||
|
|||||||
Reference in New Issue
Block a user