configure system prompt with json

This commit is contained in:
2025-10-27 17:34:49 +08:00
parent 5822c4a572
commit 4bab1bad0b
2 changed files with 41 additions and 6 deletions

View File

@@ -0,0 +1,8 @@
{ // Prompt for router; mandatory to instruct the model to return JSON for route_prompt; otherwise routing will break
"route_prompt" : "Return a JSON object with 'step'. The value should be one of 'chat' or 'order' based on the user input",
// Prompt for tool branch
"tool_prompt" : "You must use tool to complete the possible task"
// Optionally set chat_prompt to overwrite the system prompt from xiaozhi
}

View File

@@ -7,6 +7,8 @@ from PIL import Image
from io import BytesIO from io import BytesIO
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import jax import jax
import os.path as osp
import commentjson
from lang_agent.config import KeyConfig from lang_agent.config import KeyConfig
from lang_agent.tool_manager import ToolManager, ToolManagerConfig from lang_agent.tool_manager import ToolManager, ToolManagerConfig
@@ -34,9 +36,20 @@ class RoutingConfig(KeyConfig):
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1" base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
"""base url; could be used to overwrite the baseurl in llm provider""" """base url; could be used to overwrite the baseurl in llm provider"""
sys_promp_json: str = None
"Path to a JSON file containing system prompts for the graphs; will overwrite the system prompt from xiaozhi if 'chat_prompt' is provided"
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
def __post_init__(self):
    """Resolve the prompt-file path, defaulting to the repo-level config, and verify it exists."""
    super().__post_init__()
    if self.sys_promp_json is None:
        # Default to <repo_root>/configs/route_sys_prompts.json, three levels up from this module.
        self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts.json")
        # Message previously said "config_f", which is not this attribute's name — keep logs accurate.
        logger.warning(f"sys_promp_json was not provided. Using default: {self.sys_promp_json}")
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would let a missing prompt file slip through to a later crash.
    if not osp.exists(self.sys_promp_json):
        raise FileNotFoundError(f"sys_promp_json {self.sys_promp_json} does not exist.")
class Route(BaseModel): class Route(BaseModel):
step: Literal["chat", "order"] = Field( step: Literal["chat", "order"] = Field(
@@ -89,9 +102,9 @@ class RoutingGraph(GraphBase):
def _build_modules(self): def _build_modules(self):
self.llm = init_chat_model(model=self.config.llm_name, self.llm = init_chat_model(model=self.config.llm_name,
model_provider=self.config.llm_provider, model_provider=self.config.llm_provider,
api_key=self.config.api_key, api_key=self.config.api_key,
base_url=self.config.base_url) base_url=self.config.base_url)
self.memory = MemorySaver() self.memory = MemorySaver()
self.router = self.llm.with_structured_output(Route) self.router = self.llm.with_structured_output(Route)
@@ -99,12 +112,16 @@ class RoutingGraph(GraphBase):
self.chat_model = create_agent(self.llm, [], checkpointer=self.memory) self.chat_model = create_agent(self.llm, [], checkpointer=self.memory)
self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(), checkpointer=self.memory) self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(), checkpointer=self.memory)
with open(self.config.sys_promp_json , "r") as f:
self.prompt_dict = commentjson.load(f)
def _router_call(self, state:State): def _router_call(self, state:State):
decision:Route = self.router.invoke( decision:Route = self.router.invoke(
[ [
SystemMessage( SystemMessage(
content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input" # content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
content=self.prompt_dict.get("route_prompt")
), ),
self._get_human_msg(state) self._get_human_msg(state)
] ]
@@ -137,6 +154,16 @@ class RoutingGraph(GraphBase):
inp = state["messages"], state["inp"][1] inp = state["messages"], state["inp"][1]
else: else:
inp = state["inp"] inp = state["inp"]
if self.prompt_dict.get("chat_prompt") is not None:
inp = {"messages":[
SystemMessage(
# "You must use tool to complete the possible task"
self.prompt_dict["chat_prompt"]
),
*state["inp"][0]["messages"][1:]
]}, state["inp"][1]
out = self.chat_model.invoke(*inp) out = self.chat_model.invoke(*inp)
return {"messages": out} return {"messages": out}
@@ -145,9 +172,9 @@ class RoutingGraph(GraphBase):
def _tool_model_call(self, state:State): def _tool_model_call(self, state:State):
inp = {"messages":[ inp = {"messages":[
SystemMessage( SystemMessage(
"You must use tool to complete the possible task" # "You must use tool to complete the possible task"
self.prompt_dict["tool_prompt"]
), ),
# self._get_human_msg(state)
*state["inp"][0]["messages"][1:] *state["inp"][0]["messages"][1:]
]}, state["inp"][1] ]}, state["inp"][1]