let chatty_node use tool node

This commit is contained in:
2025-12-12 16:50:48 +08:00
parent 069d15d254
commit 4dbdf4d14b

View File

@@ -75,6 +75,8 @@ class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig):
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
"""path to chatty system prompt"""
tool_node_conf:ToolNodeConfig = field(default_factory=ToolNodeConfig)
class ChattyToolNode(ToolNodeBase):
def __init__(self, config:ChattyToolNodeConfig,
@@ -109,7 +111,9 @@ class ChattyToolNode(ToolNodeBase):
self.reit_llm = ReitLLM(tags=["reit_llm"])
self.chatty_agent = create_agent(self.chatty_llm, [], checkpointer=self.mem)
self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
# self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
self.tool_agent = self.config.tool_node_conf.setup(tool_manager=self.tool_manager,
memory=self.mem)
with open(self.config.chatty_sys_prompt_f, "r") as f:
self.chatty_sys_prompt = f.read()
@@ -132,14 +136,14 @@ class ChattyToolNode(ToolNodeBase):
return {"messages": state_msgs + chat_msgs + tool_msgs}
def _tool_node_call(self, state:ChattyToolState):
inp = {"messages":[
SystemMessage(
self.tool_sys_prompt
),
*state["inp"][0]["messages"][1:]
]}, state["inp"][1]
# inp = {"messages":[
# SystemMessage(
# self.tool_sys_prompt
# ),
# *state["inp"][0]["messages"][1:]
# ]}, state["inp"][1]
out = self.tool_agent.invoke(*inp)
out = self.tool_agent.invoke(state)
self.tool_done = True
return {"tool_messages": out["messages"]}