From 60968e3be6097458041a45ab86efae8c16bf3191 Mon Sep 17 00:00:00 2001
From: goulustis
Date: Fri, 21 Nov 2025 21:38:02 +0800
Subject: [PATCH] working chatty tool node

---
 lang_agent/graphs/tool_nodes.py | 52 ++++++++++++++++++---------------
 1 file changed, 28 insertions(+), 24 deletions(-)

diff --git a/lang_agent/graphs/tool_nodes.py b/lang_agent/graphs/tool_nodes.py
index 51dd04e..1daab72 100644
--- a/lang_agent/graphs/tool_nodes.py
+++ b/lang_agent/graphs/tool_nodes.py
@@ -8,6 +8,7 @@ from lang_agent.config import InstantiateConfig, KeyConfig
 from lang_agent.tool_manager import ToolManager
 from lang_agent.base import GraphBase
 from lang_agent.graphs.graph_states import State, ChattyToolState
+from lang_agent.utils import make_llm
 
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
@@ -28,16 +29,16 @@ class ToolNode(GraphBase):
     def __init__(self,
                  config: ToolNodeConfig,
                  tool_manager:ToolManager,
-                 llm:BaseChatModel,
                  memory:MemorySaver):
         self.config = config
         self.tool_manager = tool_manager
-        self.llm = llm
         self.mem = memory
 
         self.populate_modules()
 
     def populate_modules(self):
+        self.llm = make_llm(tags=["tool_llm"])
+
         self.tool_agent = create_agent(self.llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
         with open(self.config.tool_prompt_f, "r") as f:
             self.sys_prompt = f.read()
@@ -73,11 +74,9 @@ class ChattyToolNode(GraphBase):
     def __init__(self,
                  config:ChattyToolNodeConfig,
                  tool_manager:ToolManager,
-                 llm:BaseChatModel,
                  memory:MemorySaver):
         self.config = config
         self.tool_manager = tool_manager
-        self.tool_llm = llm
         self.mem = memory
 
         self.tool_done = False
@@ -90,7 +89,15 @@ class ChattyToolNode(GraphBase):
                                          model_provider=self.config.llm_provider,
                                          api_key=self.config.api_key,
                                          base_url=self.config.base_url,
-                                         temperature=0)
+                                         temperature=0,
+                                         tags=["chatty_llm"])
+        self.tool_llm = init_chat_model(model=self.config.llm_name,
+                                        model_provider=self.config.llm_provider,
+                                        api_key=self.config.api_key,
+                                        base_url=self.config.base_url,
+                                        temperature=0,
+                                        tags=["tool_llm"])
+        self.reit_llm = make_llm(model="qwen-flash", tags=["reit_llm"])
 
         self.chatty_agent = create_agent(self.chatty_llm, [], checkpointer=self.mem)
         self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
@@ -141,11 +148,15 @@ class ChattyToolNode(GraphBase):
 
 
     def _handoff_node(self, state:ChattyToolState):
-        # NOTE: this exist to have both results
-        # chat_msgs = state.get("chatty_message")
-        # tool_msgs = state.get("tool_message")
-
-        # return {"messages": chat_msgs + tool_msgs}
+        # NOTE: streaming shim only; reit_llm re-emits the last tool message as tokens
+        tool_msgs = state.get("tool_messages")["messages"]
+        inp = [
+            SystemMessage(
+                "do nothing and repeat the last message"
+            ),
+            HumanMessage(tool_msgs[-1].content)
+        ]
+        self.reit_llm.invoke(inp)
 
         return {}
 
@@ -174,36 +185,29 @@ tool_node_union = tyro.extras.subcommand_type_from_defaults(tool_node_dict, pref
 AnnotatedToolNode = tyro.conf.OmitSubcommandPrefixes[tyro.conf.SuppressFixed[tool_node_union]]
 
 
 if __name__ == "__main__":
-    from langchain.chat_models import init_chat_model
     from langchain_core.messages.base import BaseMessageChunk
     from langchain_core.messages import BaseMessage
     from lang_agent.tool_manager import ToolManagerConfig
     from dotenv import load_dotenv
-    import os
 
     load_dotenv()
-    llm = init_chat_model(model="qwen-flash",
-                          model_provider="openai",
-                          api_key=os.environ.get("ALI_API_KEY"),
-                          base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
-                          temperature=0)
 
     mem = MemorySaver()
     tool_manager = ToolManagerConfig().setup()
 
     chatty_node:ChattyToolNode = ChattyToolNodeConfig().setup(tool_manager=tool_manager,
-                                                              llm=llm,
                                                               memory=mem)
 
     query = "use calculator to calculate 33*42"
     input = {"inp" : ({"messages":[SystemMessage("you are a kind helper"), HumanMessage(query)]},
                       {"configurable": {"thread_id": '3'}})}
     inp = input
 
     graph = chatty_node.workflow
-    # for chunk, metadata in graph.stream(inp, stream_mode="messages"):
-    #     if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
-    #         print(chunk.content, end="", flush=True)
-    out = graph.invoke(inp)
-    assert 0
+    for chunk, metadata in graph.stream(inp, stream_mode="messages"):
+        tags = metadata.get("tags")
+        if tags not in (["chatty_llm"], ["reit_llm"]):
+            continue
+
+        if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
+            print(chunk.content, end="", flush=True)
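
Note: make_llm is imported from lang_agent.utils but defined outside this diff.
For context, here is a minimal sketch of what it plausibly wraps, inferred from
the init_chat_model(...) call this patch deletes from __main__. The helper's
name comes from the import; its signature and defaults below are assumptions,
not the actual code in lang_agent.utils:

    # Hypothetical sketch -- the real lang_agent.utils.make_llm may differ.
    import os

    from langchain.chat_models import init_chat_model

    def make_llm(model: str = "qwen-flash", tags: list[str] | None = None):
        # Same OpenAI-compatible DashScope setup the old __main__ built inline;
        # tags are attached so stream consumers can route tokens per model.
        return init_chat_model(model=model,
                               model_provider="openai",
                               api_key=os.environ.get("ALI_API_KEY"),
                               base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
                               temperature=0,
                               tags=tags or [])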
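Design note: the reit_llm call in _handoff_node is a streaming shim. With
stream_mode="messages", graph.stream() only yields tokens produced by LLM runs,
so the tool agent's final answer (already materialized in state) would never
appear on the token stream; prompting a cheap model to repeat that message
re-emits it token by token under the "reit_llm" tag, which the __main__ loop
forwards alongside "chatty_llm". The node deliberately returns {} so nothing is
written back to state. One caveat: the exact-equality tag filter drops any run
that carries additional tags; if that ever happens, a subset test is more
robust (sketch):

    wanted = {"chatty_llm", "reit_llm"}
    if not wanted.intersection(metadata.get("tags") or ()):
        continue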