diff --git a/lang_agent/graphs/tool_nodes.py b/lang_agent/graphs/tool_nodes.py index 50ef0a9..d1e3660 100644 --- a/lang_agent/graphs/tool_nodes.py +++ b/lang_agent/graphs/tool_nodes.py @@ -75,6 +75,8 @@ class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig): chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt") """path to chatty system prompt""" + tool_node_conf:ToolNodeConfig = field(default_factory=ToolNodeConfig) + class ChattyToolNode(ToolNodeBase): def __init__(self, config:ChattyToolNodeConfig, @@ -109,7 +111,9 @@ class ChattyToolNode(ToolNodeBase): self.reit_llm = ReitLLM(tags=["reit_llm"]) self.chatty_agent = create_agent(self.chatty_llm, [], checkpointer=self.mem) - self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem) + # self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem) + self.tool_agent = self.config.tool_node_conf.setup(tool_manager=self.tool_manager, + memory=self.mem) with open(self.config.chatty_sys_prompt_f, "r") as f: self.chatty_sys_prompt = f.read() @@ -132,14 +136,14 @@ class ChattyToolNode(ToolNodeBase): return {"messages": state_msgs + chat_msgs + tool_msgs} def _tool_node_call(self, state:ChattyToolState): - inp = {"messages":[ - SystemMessage( - self.tool_sys_prompt - ), - *state["inp"][0]["messages"][1:] - ]}, state["inp"][1] + # inp = {"messages":[ + # SystemMessage( + # self.tool_sys_prompt + # ), + # *state["inp"][0]["messages"][1:] + # ]}, state["inp"][1] - out = self.tool_agent.invoke(*inp) + out = self.tool_agent.invoke(state) self.tool_done = True return {"tool_messages": out["messages"]}