configure system prompt with json
This commit is contained in:
8
configs/route_sys_prompts.json
Normal file
8
configs/route_sys_prompts.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{ // Prompt for router; manditory to say return as json for route_prompt; if not, there will be bugs
|
||||
"route_prompt" : "Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input",
|
||||
|
||||
// Prompt for tool branch
|
||||
"tool_prompt" : "You must use tool to complete the possible task"
|
||||
|
||||
// Optionally set chat_prompt to overwrite the system prompt from xiaozhi
|
||||
}
|
||||
@@ -7,6 +7,8 @@ from PIL import Image
|
||||
from io import BytesIO
|
||||
import matplotlib.pyplot as plt
|
||||
import jax
|
||||
import os.path as osp
|
||||
import commentjson
|
||||
|
||||
from lang_agent.config import KeyConfig
|
||||
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
||||
@@ -34,9 +36,20 @@ class RoutingConfig(KeyConfig):
|
||||
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
"""base url; could be used to overwrite the baseurl in llm provider"""
|
||||
|
||||
sys_promp_json: str = None
|
||||
"path to json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided"
|
||||
|
||||
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.sys_promp_json is None:
|
||||
self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts.json")
|
||||
logger.warning(f"config_f was not provided. Using default: {self.sys_promp_json}")
|
||||
|
||||
assert osp.exists(self.sys_promp_json), f"config_f {self.sys_promp_json} does not exist."
|
||||
|
||||
|
||||
class Route(BaseModel):
|
||||
step: Literal["chat", "order"] = Field(
|
||||
@@ -99,12 +112,16 @@ class RoutingGraph(GraphBase):
|
||||
self.chat_model = create_agent(self.llm, [], checkpointer=self.memory)
|
||||
self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(), checkpointer=self.memory)
|
||||
|
||||
with open(self.config.sys_promp_json , "r") as f:
|
||||
self.prompt_dict = commentjson.load(f)
|
||||
|
||||
|
||||
def _router_call(self, state:State):
|
||||
decision:Route = self.router.invoke(
|
||||
[
|
||||
SystemMessage(
|
||||
content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
|
||||
# content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
|
||||
content=self.prompt_dict.get("route_prompt")
|
||||
),
|
||||
self._get_human_msg(state)
|
||||
]
|
||||
@@ -138,6 +155,16 @@ class RoutingGraph(GraphBase):
|
||||
else:
|
||||
inp = state["inp"]
|
||||
|
||||
if self.prompt_dict.get("chat_prompt") is not None:
|
||||
inp = {"messages":[
|
||||
SystemMessage(
|
||||
# "You must use tool to complete the possible task"
|
||||
self.prompt_dict["chat_prompt"]
|
||||
),
|
||||
*state["inp"][0]["messages"][1:]
|
||||
]}, state["inp"][1]
|
||||
|
||||
|
||||
out = self.chat_model.invoke(*inp)
|
||||
return {"messages": out}
|
||||
|
||||
@@ -145,9 +172,9 @@ class RoutingGraph(GraphBase):
|
||||
def _tool_model_call(self, state:State):
|
||||
inp = {"messages":[
|
||||
SystemMessage(
|
||||
"You must use tool to complete the possible task"
|
||||
# "You must use tool to complete the possible task"
|
||||
self.prompt_dict["tool_prompt"]
|
||||
),
|
||||
# self._get_human_msg(state)
|
||||
*state["inp"][0]["messages"][1:]
|
||||
]}, state["inp"][1]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user