let chatty_node use tool node
This commit is contained in:
@@ -75,6 +75,8 @@ class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig):
|
||||
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
|
||||
"""path to chatty system prompt"""
|
||||
|
||||
tool_node_conf:ToolNodeConfig = field(default_factory=ToolNodeConfig)
|
||||
|
||||
|
||||
class ChattyToolNode(ToolNodeBase):
|
||||
def __init__(self, config:ChattyToolNodeConfig,
|
||||
@@ -109,7 +111,9 @@ class ChattyToolNode(ToolNodeBase):
|
||||
self.reit_llm = ReitLLM(tags=["reit_llm"])
|
||||
|
||||
self.chatty_agent = create_agent(self.chatty_llm, [], checkpointer=self.mem)
|
||||
self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
|
||||
# self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
|
||||
self.tool_agent = self.config.tool_node_conf.setup(tool_manager=self.tool_manager,
|
||||
memory=self.mem)
|
||||
|
||||
with open(self.config.chatty_sys_prompt_f, "r") as f:
|
||||
self.chatty_sys_prompt = f.read()
|
||||
@@ -132,14 +136,14 @@ class ChattyToolNode(ToolNodeBase):
|
||||
return {"messages": state_msgs + chat_msgs + tool_msgs}
|
||||
|
||||
def _tool_node_call(self, state:ChattyToolState):
|
||||
inp = {"messages":[
|
||||
SystemMessage(
|
||||
self.tool_sys_prompt
|
||||
),
|
||||
*state["inp"][0]["messages"][1:]
|
||||
]}, state["inp"][1]
|
||||
# inp = {"messages":[
|
||||
# SystemMessage(
|
||||
# self.tool_sys_prompt
|
||||
# ),
|
||||
# *state["inp"][0]["messages"][1:]
|
||||
# ]}, state["inp"][1]
|
||||
|
||||
out = self.tool_agent.invoke(*inp)
|
||||
out = self.tool_agent.invoke(state)
|
||||
|
||||
self.tool_done = True
|
||||
return {"tool_messages": out["messages"]}
|
||||
|
||||
Reference in New Issue
Block a user