From 067803ee7af0adde654772a3055684ef4e472cc1 Mon Sep 17 00:00:00 2001
From: goulustis
Date: Fri, 7 Nov 2025 14:51:59 +0800
Subject: [PATCH] routing support streaming

---
 lang_agent/graphs/routing.py | 63 ++++++++++++++++++++++++------------
 1 file changed, 42 insertions(+), 21 deletions(-)

diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py
index 5e1c0d0..a806600 100644
--- a/lang_agent/graphs/routing.py
+++ b/lang_agent/graphs/routing.py
@@ -17,6 +17,7 @@
 from lang_agent.base import GraphBase
 from langchain.chat_models import init_chat_model
 from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
+from langchain_core.messages.base import BaseMessageChunk
 from langchain.agents import create_agent
 from langgraph.graph import StateGraph, START, END
 
@@ -80,32 +81,32 @@ class RoutingGraph(GraphBase):
 
         self.workflow = self._build_graph()
 
-    def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs)->str:
+    def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
         self._validate_input(*nargs, **kwargs)
 
         if as_stream:
-            # TODO: this doesn't stream the entire process, we are blind
-            for step in self.workflow.stream({"inp": nargs}, stream_mode="updates", **kwargs):
-                last_el = jax.tree.leaves(step)[-1]
-                if isinstance(last_el, str):
-                    logger.info(last_el)
-                elif isinstance(last_el, BaseMessage):
-                    last_el.pretty_print()
-
-                state = step
+            # Stream messages from the workflow
+            for chunk, metadata in self.workflow.stream({"inp": nargs}, stream_mode="messages", **kwargs):
+                node = metadata.get("langgraph_node")
+                if node != "model":
+                    continue  # skip router or other intermediate nodes
+
+                # Yield only the final message content chunks
+                if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
+                    yield chunk.content
         else:
             state = self.workflow.invoke({"inp": nargs})
-
-        msg_list = jax.tree.leaves(state)
+
+            msg_list = jax.tree.leaves(state)
 
-        for e in msg_list:
-            if isinstance(e, BaseMessage):
-                e.pretty_print()
-
-        if as_raw:
-            return msg_list
+            for e in msg_list:
+                if isinstance(e, BaseMessage):
+                    e.pretty_print()
+
+            if as_raw:
+                return msg_list
 
-        return msg_list[-1].content
+            return msg_list[-1].content
 
     def _validate_input(self, *nargs, **kwargs):
         print("\033[93m====================INPUT MESSAGES=============================\033[0m")
@@ -256,5 +257,25 @@ class RoutingGraph(GraphBase):
         plt.show()
 
 if __name__ == "__main__":
-    route = RoutingConfig().setup()
-    route.show_graph()
\ No newline at end of file
+    from dotenv import load_dotenv
+    from langchain.messages import SystemMessage, HumanMessage
+    from langchain_core.messages.base import BaseMessageChunk
+    load_dotenv()
+
+    route:RoutingGraph = RoutingConfig().setup()
+    graph = route.workflow
+
+    nargs = {
+        "messages": [SystemMessage("you are a helpful bot named jarvis"),
+                     HumanMessage("use the calculator tool to calculate 92*55 and say the answer")]
+    },{"configurable": {"thread_id": "3"}}
+
+    for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
+        node = metadata.get("langgraph_node")
+        if node not in ("model",):
+            continue  # skip router or other intermediate nodes
+
+        # Print only the final message content
+        if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
+            print(chunk.content, end="", flush=True)
+
\ No newline at end of file
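
Usage note (not part of the patch): a minimal sketch of how the new streaming path might be consumed from outside the module, assuming `lang_agent` is installed, `RoutingConfig` and `RoutingGraph` are importable from `lang_agent.graphs.routing`, and the relevant model API key is provided via a `.env` file. With `as_stream=True`, `invoke()` yields content chunks produced by the final "model" node, so a caller can print them as they arrive:

```python
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage

# Assumed import path, mirroring the __main__ block in the patch.
from lang_agent.graphs.routing import RoutingConfig, RoutingGraph

load_dotenv()  # expects the relevant model API key in .env

route: RoutingGraph = RoutingConfig().setup()

# Same positional arguments as the patch's __main__ block: a messages dict
# followed by a LangGraph config dict (thread id for the checkpointer).
messages = {
    "messages": [
        SystemMessage("you are a helpful bot named jarvis"),
        HumanMessage("use the calculator tool to calculate 92*55 and say the answer"),
    ]
}
config = {"configurable": {"thread_id": "3"}}

# invoke(as_stream=True) is a generator of content chunks from the "model"
# node; printing each chunk as it arrives gives incremental output.
for token in route.invoke(messages, config, as_stream=True):
    print(token, end="", flush=True)
print()
```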