diff --git a/lang_agent/graphs/tool_nodes.py b/lang_agent/graphs/tool_nodes.py index aabdc3f..221048b 100644 --- a/lang_agent/graphs/tool_nodes.py +++ b/lang_agent/graphs/tool_nodes.py @@ -98,7 +98,6 @@ class ChattyToolNode(ToolNodeBase): self.config = config self.tool_manager = tool_manager self.mem = memory - self.tool_done = False self.chat_key = "[CHATTY_OUT]" self.tool_key = "[TOOL_OUT]" @@ -138,9 +137,7 @@ class ChattyToolNode(ToolNodeBase): return [["chatty_llm"], ["reit_llm"]] def invoke(self, state:State): - self.tool_done = False - - inp = {"inp": state["inp"]} + inp = {"inp": state["inp"], "tool_done": False} out = self.workflow.invoke(inp) chat_msgs = out.get("chatty_messages")["messages"] tool_msgs = out.get("tool_messages")["messages"] @@ -150,9 +147,7 @@ class ChattyToolNode(ToolNodeBase): async def ainvoke(self, state:State): """Async version of invoke using LangGraph's native async support.""" - self.tool_done = False - - inp = {"inp": state["inp"]} + inp = {"inp": state["inp"], "tool_done": False} out = await self.workflow.ainvoke(inp) chat_msgs = out.get("chatty_messages")["messages"] tool_msgs = out.get("tool_messages")["messages"] @@ -170,14 +165,13 @@ class ChattyToolNode(ToolNodeBase): out = self.tool_agent.invoke(state) - self.tool_done = True - return {"tool_messages": out["messages"]} + return {"tool_messages": out["messages"], "tool_done": True} def _chat_node_call(self, state:ChattyToolState): outs:List[BaseMessage] = [] - while not self.tool_done: + while not state.get("tool_done", False): inp = {"messages":[ SystemMessage( f"回复的最开始应该是{self.chat_key}\n"+self.chatty_sys_prompt