working chatty tool node

2025-11-21 21:38:02 +08:00
parent b7cf0bc983
commit 60968e3be6


@@ -8,6 +8,7 @@ from lang_agent.config import InstantiateConfig, KeyConfig
 from lang_agent.tool_manager import ToolManager
 from lang_agent.base import GraphBase
 from lang_agent.graphs.graph_states import State, ChattyToolState
+from lang_agent.utils import make_llm
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
@@ -28,16 +29,16 @@ class ToolNodeConfig(InstantiateConfig):
 class ToolNode(GraphBase):
     def __init__(self, config: ToolNodeConfig,
                  tool_manager:ToolManager,
-                 llm:BaseChatModel,
                  memory:MemorySaver):
         self.config = config
         self.tool_manager = tool_manager
-        self.llm = llm
         self.mem = memory
         self.populate_modules()
 
     def populate_modules(self):
+        self.llm = make_llm(tags=["tool_llm"])
         self.tool_agent = create_agent(self.llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
         with open(self.config.tool_prompt_f, "r") as f:
             self.sys_prompt = f.read()
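make_llm is imported from lang_agent.utils but its body is not part of this diff. Judging from the two call sites (make_llm(tags=["tool_llm"]) here, and make_llm(model="qwen-flash", tags=["reit_llm"]) further down) and from the hard-coded init_chat_model call this commit deletes from the __main__ block, a minimal sketch could look like the following; the default model, the ALI_API_KEY variable, and the DashScope endpoint are carried over from that deleted block and are assumptions, not confirmed by the diff:

    import os
    from langchain.chat_models import init_chat_model
    from langchain_core.language_models import BaseChatModel

    def make_llm(model: str = "qwen-flash", tags: list[str] | None = None) -> BaseChatModel:
        # Hypothetical helper: wrap init_chat_model with the DashScope
        # OpenAI-compatible defaults the old __main__ block spelled out inline.
        return init_chat_model(model=model,
                               model_provider="openai",
                               api_key=os.environ.get("ALI_API_KEY"),
                               base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
                               temperature=0,
                               tags=tags or [])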
@@ -73,11 +74,9 @@ class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig):
 class ChattyToolNode(GraphBase):
     def __init__(self, config:ChattyToolNodeConfig,
                  tool_manager:ToolManager,
-                 llm:BaseChatModel,
                  memory:MemorySaver):
         self.config = config
         self.tool_manager = tool_manager
-        self.tool_llm = llm
         self.mem = memory
         self.tool_done = False
@@ -90,7 +89,15 @@ class ChattyToolNode(GraphBase):
                                           model_provider=self.config.llm_provider,
                                           api_key=self.config.api_key,
                                           base_url=self.config.base_url,
-                                          temperature=0)
+                                          temperature=0,
+                                          tags=["chatty_llm"])
+        self.tool_llm = init_chat_model(model=self.config.llm_name,
+                                        model_provider=self.config.llm_provider,
+                                        api_key=self.config.api_key,
+                                        base_url=self.config.base_url,
+                                        temperature=0,
+                                        tags=["tool_llm"])
+        self.reit_llm = make_llm(model="qwen-flash", tags=["reit_llm"])
         self.chatty_agent = create_agent(self.chatty_llm, [], checkpointer=self.mem)
         self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
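Why every model gets a tag: tags attached at construction time ride along with each callback event the model emits, so they show up in the metadata that graph.stream(..., stream_mode="messages") yields next to every token chunk; the new __main__ loop at the end of this diff filters on exactly this. The same effect can be applied after the fact to a model you did not construct; a one-line sketch, with some_llm standing in for any existing chat model:

    # Bind tags onto an already-built model; equivalent to passing tags=[...]
    # at init. Chunks it streams then carry "chatty_llm" in metadata["tags"].
    tagged_llm = some_llm.with_config(tags=["chatty_llm"])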
@@ -141,11 +148,15 @@ class ChattyToolNode(GraphBase):
     def _handoff_node(self, state:ChattyToolState):
-        # NOTE: this exist to have both results
-        # chat_msgs = state.get("chatty_message")
-        # tool_msgs = state.get("tool_message")
-        # return {"messages": chat_msgs + tool_msgs}
+        # NOTE: this node only exists to re-emit the final tool message so it streams correctly
+        tool_msgs = state.get("tool_messages")["messages"]
+        inp = [
+            SystemMessage(
+                "do nothing and repeat the last message"
+            ),
+            tool_msgs[-1].content
+        ]
+        self.reit_llm.invoke(inp)
         return {}
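The rewritten _handoff_node is the core of this commit: rather than merging chatty and tool messages into state (the old, commented-out plan), it replays the last tool message through reit_llm ("reiterate") with a parrot prompt. Any LLM call made inside a graph node surfaces in stream_mode="messages", so this pushes the tool result into the token stream under the reit_llm tag without touching graph state (the node returns {}). The same trick in isolation, with hypothetical names; note that a bare string in a message list is coerced to a HumanMessage by LangChain:

    from langchain_core.messages import SystemMessage

    def replay_through_llm(llm, text: str) -> None:
        # The call itself is the point: its streamed tokens reach the graph's
        # message stream, effectively re-emitting `text` token by token.
        llm.invoke([SystemMessage("do nothing and repeat the last message"), text])

The cost is one extra model round-trip per tool call, and the echo is not guaranteed to be verbatim; a cheap model such as qwen-flash keeps the overhead tolerable.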
@@ -174,36 +185,29 @@ tool_node_union = tyro.extras.subcommand_type_from_defaults(tool_node_dict, pref
 AnnotatedToolNode = tyro.conf.OmitSubcommandPrefixes[tyro.conf.SuppressFixed[tool_node_union]]
 
 if __name__ == "__main__":
-    from langchain.chat_models import init_chat_model
     from langchain_core.messages.base import BaseMessageChunk
     from langchain_core.messages import BaseMessage
     from lang_agent.tool_manager import ToolManagerConfig
     from dotenv import load_dotenv
-    import os
     load_dotenv()
-    llm = init_chat_model(model="qwen-flash",
-                          model_provider="openai",
-                          api_key=os.environ.get("ALI_API_KEY"),
-                          base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
-                          temperature=0)
     mem = MemorySaver()
     tool_manager = ToolManagerConfig().setup()
     chatty_node:ChattyToolNode = ChattyToolNodeConfig().setup(tool_manager=tool_manager,
-                                                              llm=llm,
                                                               memory=mem)
     query = "use calculator to calculate 33*42"
     input = {"inp" : ({"messages":[SystemMessage("you are a kind helper"), HumanMessage(query)]},
                       {"configurable": {"thread_id": '3'}})}
     inp = input
     graph = chatty_node.workflow
-    # for chunk, metadata in graph.stream(inp, stream_mode="messages"):
-    #     if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
-    #         print(chunk.content, end="", flush=True)
-    out = graph.invoke(inp)
-    assert 0
+    for chunk, metadata in graph.stream(inp, stream_mode="messages"):
+        tags = metadata.get("tags")
+        if tags not in [["chatty_llm"], ["reit_llm"]]:
+            continue
+        if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
+            print(chunk.content, end="", flush=True)
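One subtlety in the new streaming loop: the filter compares whole tag lists, not single-tag membership, so a chunk passes only when its metadata carries exactly ["chatty_llm"] or exactly ["reit_llm"]. If the runtime ever appends tags of its own to that list, the exact match fails silently; testing for the tag itself is more defensive:

    tags = ["chatty_llm"]
    tags in [["chatty_llm"], ["reit_llm"]]                     # True: exact list match
    ["chatty_llm", "extra"] in [["chatty_llm"], ["reit_llm"]]  # False: lists no longer equal
    bool({"chatty_llm", "reit_llm"} & set(tags or []))         # True: per-tag membership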