From 4db1b87c811cc18ecb6ea51ae73e6c4c0d7f8c27 Mon Sep 17 00:00:00 2001 From: goulustis Date: Sat, 22 Nov 2025 19:42:26 +0800 Subject: [PATCH] support subgraph streaming --- lang_agent/graphs/routing.py | 49 +++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index 52b001c..1b973eb 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -10,7 +10,7 @@ import glob from lang_agent.config import KeyConfig from lang_agent.tool_manager import ToolManager, ToolManagerConfig -from lang_agent.base import GraphBase +from lang_agent.base import GraphBase, ToolNodeBase from lang_agent.graphs.graph_states import State from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig @@ -66,10 +66,19 @@ class RoutingGraph(GraphBase): def _stream_result(self, *nargs, **kwargs): - for chunk, metadata in self.workflow.stream({"inp": nargs}, stream_mode="messages", **kwargs): - node = metadata.get("langgraph_node") - if node != "model": - continue # skip router or other intermediate nodes + streamable_tags = self.tool_node.get_streamable_tags() + [["route_chat_llm"]] + + for chunk, metadata in self.workflow.stream({"inp": nargs}, + stream_mode="messages", + subgraphs=True, + **kwargs): + + if isinstance(metadata, tuple): + chunk, metadata = metadata + + tags = metadata.get("tags") + if not (tags in streamable_tags): + continue # Yield only the final message content chunks if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None): @@ -130,8 +139,8 @@ class RoutingGraph(GraphBase): tool_manager:ToolManager = self.config.tool_manager_config.setup() self.chat_model = create_agent(self.chat_llm, self._get_chat_tools(tool_manager), checkpointer=self.memory) - self.tool_node:GraphBase = self.config.tool_node_config.setup(tool_manager=tool_manager, - memory=self.memory) + self.tool_node:ToolNodeBase = self.config.tool_node_config.setup(tool_manager=tool_manager, + memory=self.memory) self._load_sys_prompts() @@ -244,22 +253,26 @@ if __name__ == "__main__": from dotenv import load_dotenv from langchain.messages import SystemMessage, HumanMessage from langchain_core.messages.base import BaseMessageChunk + from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig, ChattyToolNodeConfig load_dotenv() - route:RoutingGraph = RoutingConfig().setup() + route:RoutingGraph = RoutingConfig(tool_node_config=ChattyToolNodeConfig()).setup() graph = route.workflow nargs = { "messages": [SystemMessage("you are a helpful bot named jarvis"), - HumanMessage("what is your name")] + HumanMessage("use calculator to calculate 926*84")] },{"configurable": {"thread_id": "3"}} - - for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"): - node = metadata.get("langgraph_node") - if node not in ("model"): - continue # skip router or other intermediate nodes - # Print only the final message content - if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None): - print(chunk.content, end="", flush=True) - \ No newline at end of file + for chunk in route.invoke(*nargs, as_stream=True): + # pass + print(chunk, end="", flush=True) + + # for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"): + # node = metadata.get("langgraph_node") + # if node not in ("model"): + # continue # skip router or other intermediate nodes + + # # Print only the final message content + # if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None): + # print(chunk.content, end="", flush=True) \ No newline at end of file