configure system prompt with json
This commit is contained in:
8
configs/route_sys_prompts.json
Normal file
8
configs/route_sys_prompts.json
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
{ // Prompt for router; manditory to say return as json for route_prompt; if not, there will be bugs
|
||||||
|
"route_prompt" : "Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input",
|
||||||
|
|
||||||
|
// Prompt for tool branch
|
||||||
|
"tool_prompt" : "You must use tool to complete the possible task"
|
||||||
|
|
||||||
|
// Optionally set chat_prompt to overwrite the system prompt from xiaozhi
|
||||||
|
}
|
||||||
@@ -7,6 +7,8 @@ from PIL import Image
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import jax
|
import jax
|
||||||
|
import os.path as osp
|
||||||
|
import commentjson
|
||||||
|
|
||||||
from lang_agent.config import KeyConfig
|
from lang_agent.config import KeyConfig
|
||||||
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
||||||
@@ -34,9 +36,20 @@ class RoutingConfig(KeyConfig):
|
|||||||
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||||
"""base url; could be used to overwrite the baseurl in llm provider"""
|
"""base url; could be used to overwrite the baseurl in llm provider"""
|
||||||
|
|
||||||
|
sys_promp_json: str = None
|
||||||
|
"path to json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided"
|
||||||
|
|
||||||
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
if self.sys_promp_json is None:
|
||||||
|
self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts.json")
|
||||||
|
logger.warning(f"config_f was not provided. Using default: {self.sys_promp_json}")
|
||||||
|
|
||||||
|
assert osp.exists(self.sys_promp_json), f"config_f {self.sys_promp_json} does not exist."
|
||||||
|
|
||||||
|
|
||||||
class Route(BaseModel):
|
class Route(BaseModel):
|
||||||
step: Literal["chat", "order"] = Field(
|
step: Literal["chat", "order"] = Field(
|
||||||
@@ -89,9 +102,9 @@ class RoutingGraph(GraphBase):
|
|||||||
|
|
||||||
def _build_modules(self):
|
def _build_modules(self):
|
||||||
self.llm = init_chat_model(model=self.config.llm_name,
|
self.llm = init_chat_model(model=self.config.llm_name,
|
||||||
model_provider=self.config.llm_provider,
|
model_provider=self.config.llm_provider,
|
||||||
api_key=self.config.api_key,
|
api_key=self.config.api_key,
|
||||||
base_url=self.config.base_url)
|
base_url=self.config.base_url)
|
||||||
self.memory = MemorySaver()
|
self.memory = MemorySaver()
|
||||||
self.router = self.llm.with_structured_output(Route)
|
self.router = self.llm.with_structured_output(Route)
|
||||||
|
|
||||||
@@ -99,12 +112,16 @@ class RoutingGraph(GraphBase):
|
|||||||
self.chat_model = create_agent(self.llm, [], checkpointer=self.memory)
|
self.chat_model = create_agent(self.llm, [], checkpointer=self.memory)
|
||||||
self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(), checkpointer=self.memory)
|
self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(), checkpointer=self.memory)
|
||||||
|
|
||||||
|
with open(self.config.sys_promp_json , "r") as f:
|
||||||
|
self.prompt_dict = commentjson.load(f)
|
||||||
|
|
||||||
|
|
||||||
def _router_call(self, state:State):
|
def _router_call(self, state:State):
|
||||||
decision:Route = self.router.invoke(
|
decision:Route = self.router.invoke(
|
||||||
[
|
[
|
||||||
SystemMessage(
|
SystemMessage(
|
||||||
content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
|
# content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
|
||||||
|
content=self.prompt_dict.get("route_prompt")
|
||||||
),
|
),
|
||||||
self._get_human_msg(state)
|
self._get_human_msg(state)
|
||||||
]
|
]
|
||||||
@@ -138,6 +155,16 @@ class RoutingGraph(GraphBase):
|
|||||||
else:
|
else:
|
||||||
inp = state["inp"]
|
inp = state["inp"]
|
||||||
|
|
||||||
|
if self.prompt_dict.get("chat_prompt") is not None:
|
||||||
|
inp = {"messages":[
|
||||||
|
SystemMessage(
|
||||||
|
# "You must use tool to complete the possible task"
|
||||||
|
self.prompt_dict["chat_prompt"]
|
||||||
|
),
|
||||||
|
*state["inp"][0]["messages"][1:]
|
||||||
|
]}, state["inp"][1]
|
||||||
|
|
||||||
|
|
||||||
out = self.chat_model.invoke(*inp)
|
out = self.chat_model.invoke(*inp)
|
||||||
return {"messages": out}
|
return {"messages": out}
|
||||||
|
|
||||||
@@ -145,9 +172,9 @@ class RoutingGraph(GraphBase):
|
|||||||
def _tool_model_call(self, state:State):
|
def _tool_model_call(self, state:State):
|
||||||
inp = {"messages":[
|
inp = {"messages":[
|
||||||
SystemMessage(
|
SystemMessage(
|
||||||
"You must use tool to complete the possible task"
|
# "You must use tool to complete the possible task"
|
||||||
|
self.prompt_dict["tool_prompt"]
|
||||||
),
|
),
|
||||||
# self._get_human_msg(state)
|
|
||||||
*state["inp"][0]["messages"][1:]
|
*state["inp"][0]["messages"][1:]
|
||||||
]}, state["inp"][1]
|
]}, state["inp"][1]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user