configure system prompt with json

This commit is contained in:
2025-10-27 17:34:49 +08:00
parent 5822c4a572
commit 4bab1bad0b
2 changed files with 41 additions and 6 deletions

View File

@@ -0,0 +1,8 @@
{ // Prompt for router; manditory to say return as json for route_prompt; if not, there will be bugs
"route_prompt" : "Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input",
// Prompt for tool branch
"tool_prompt" : "You must use tool to complete the possible task"
// Optionally set chat_prompt to overwrite the system prompt from xiaozhi
}

View File

@@ -7,6 +7,8 @@ from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import jax
import os.path as osp
import commentjson
from lang_agent.config import KeyConfig
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
@@ -34,9 +36,20 @@ class RoutingConfig(KeyConfig):
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
"""base url; could be used to overwrite the baseurl in llm provider"""
sys_promp_json: str = None
"path to json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided"
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
def __post_init__(self):
super().__post_init__()
if self.sys_promp_json is None:
self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts.json")
logger.warning(f"config_f was not provided. Using default: {self.sys_promp_json}")
assert osp.exists(self.sys_promp_json), f"config_f {self.sys_promp_json} does not exist."
class Route(BaseModel):
step: Literal["chat", "order"] = Field(
@@ -89,9 +102,9 @@ class RoutingGraph(GraphBase):
def _build_modules(self):
self.llm = init_chat_model(model=self.config.llm_name,
model_provider=self.config.llm_provider,
api_key=self.config.api_key,
base_url=self.config.base_url)
model_provider=self.config.llm_provider,
api_key=self.config.api_key,
base_url=self.config.base_url)
self.memory = MemorySaver()
self.router = self.llm.with_structured_output(Route)
@@ -99,12 +112,16 @@ class RoutingGraph(GraphBase):
self.chat_model = create_agent(self.llm, [], checkpointer=self.memory)
self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(), checkpointer=self.memory)
with open(self.config.sys_promp_json , "r") as f:
self.prompt_dict = commentjson.load(f)
def _router_call(self, state:State):
decision:Route = self.router.invoke(
[
SystemMessage(
content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
# content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
content=self.prompt_dict.get("route_prompt")
),
self._get_human_msg(state)
]
@@ -137,6 +154,16 @@ class RoutingGraph(GraphBase):
inp = state["messages"], state["inp"][1]
else:
inp = state["inp"]
if self.prompt_dict.get("chat_prompt") is not None:
inp = {"messages":[
SystemMessage(
# "You must use tool to complete the possible task"
self.prompt_dict["chat_prompt"]
),
*state["inp"][0]["messages"][1:]
]}, state["inp"][1]
out = self.chat_model.invoke(*inp)
return {"messages": out}
@@ -145,9 +172,9 @@ class RoutingGraph(GraphBase):
def _tool_model_call(self, state:State):
inp = {"messages":[
SystemMessage(
"You must use tool to complete the possible task"
# "You must use tool to complete the possible task"
self.prompt_dict["tool_prompt"]
),
# self._get_human_msg(state)
*state["inp"][0]["messages"][1:]
]}, state["inp"][1]