From 355b87102a2e3178f12e31ae496472d4958b7170 Mon Sep 17 00:00:00 2001 From: goulustis Date: Fri, 21 Nov 2025 16:23:57 +0800 Subject: [PATCH] bug fixes + test --- lang_agent/graphs/tool_nodes.py | 41 +++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/lang_agent/graphs/tool_nodes.py b/lang_agent/graphs/tool_nodes.py index 06b8907..0cb940e 100644 --- a/lang_agent/graphs/tool_nodes.py +++ b/lang_agent/graphs/tool_nodes.py @@ -92,7 +92,7 @@ class ChattyToolNode(GraphBase): base_url=self.config.base_url, temperature=0) - self.chatty_agent = create_agent(self.chatty_agent, [], checkpointer=self.mem) + self.chatty_agent = create_agent(self.chatty_llm, [], checkpointer=self.mem) self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem) with open(self.config.chatty_sys_prompt_f, "r") as f: @@ -102,6 +102,8 @@ class ChattyToolNode(GraphBase): self.tool_sys_prompt = f.read() def invoke(self, state:State): + self.tool_done = False + inp = {"inp": state["inp"]} out = self.workflow.invoke(inp) chat_msgs = out.get("chatty_message") @@ -119,6 +121,7 @@ class ChattyToolNode(GraphBase): out = self.tool_agent.invoke(*inp) + self.tool_done = True return {"tool_messages": out} @@ -155,7 +158,7 @@ class ChattyToolNode(GraphBase): builder.add_edge(START, "chatty_tool_call") builder.add_edge(START, "chatty_chat_call") builder.add_edge("chatty_chat_call", "chatty_handoff_node") - builder.add_edge("chatty_node_call", "chatty_handoff_node") + builder.add_edge("chatty_tool_call", "chatty_handoff_node") builder.add_edge("chatty_handoff_node", END) self.workflow = builder.compile() @@ -171,5 +174,35 @@ tool_node_union = tyro.extras.subcommand_type_from_defaults(tool_node_dict, pref AnnotatedToolNode = tyro.conf.OmitSubcommandPrefixes[tyro.conf.SuppressFixed[tool_node_union]] if __name__ == "__main__": - tyro.cli(ToolNodeConfig) - + from langchain.chat_models import init_chat_model + from langchain_core.messages.base import BaseMessageChunk + from langchain_core.messages import BaseMessage + + from lang_agent.tool_manager import ToolManagerConfig + + from dotenv import load_dotenv + import os + load_dotenv() + + llm = init_chat_model(model="qwen-flash", + model_provider="openai", + api_key=os.environ.get("ALI_API_KEY"), + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", + temperature=0) + mem = MemorySaver() + tool_manager = ToolManagerConfig().setup() + chatty_node:ChattyToolNode = ChattyToolNodeConfig().setup(tool_manager=tool_manager, + llm=llm, + memory=mem) + + query = "use calculator to calculate 33*42" + input = {"inp" : ({"messages":[SystemMessage("you are a kind helper"), HumanMessage(query)]}, + {"configurable": {"thread_id": '3'}})} + inp = input["inp"] + + graph = chatty_node.workflow + + for chunk, metadata in graph.stream(*inp, stream_mode="messages"): + if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None): + print(chunk.content, end="", flush=True) +