routing: support token streaming (stream_mode="messages") in RoutingGraph.invoke

This commit is contained in:
2025-11-07 14:51:59 +08:00
parent daeb0ca251
commit 067803ee7a

View File

@@ -17,6 +17,7 @@ from lang_agent.base import GraphBase
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
from langchain_core.messages.base import BaseMessageChunk
from langchain.agents import create_agent
from langgraph.graph import StateGraph, START, END
@@ -80,19 +81,19 @@ class RoutingGraph(GraphBase):
self.workflow = self._build_graph()
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs)->str:
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
self._validate_input(*nargs, **kwargs)
if as_stream:
# TODO this doesn't stream the entire process, we are blind
for step in self.workflow.stream({"inp": nargs}, stream_mode="updates", **kwargs):
last_el = jax.tree.leaves(step)[-1]
if isinstance(last_el, str):
logger.info(last_el)
elif isinstance(last_el, BaseMessage):
last_el.pretty_print()
# Stream messages from the workflow
for chunk, metadata in self.workflow.stream({"inp": nargs}, stream_mode="messages", **kwargs):
node = metadata.get("langgraph_node")
if node != "model":
continue # skip router or other intermediate nodes
state = step
# Yield only the final message content chunks
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
yield chunk.content
else:
state = self.workflow.invoke({"inp": nargs})
@@ -256,5 +257,25 @@ class RoutingGraph(GraphBase):
plt.show()
if __name__ == "__main__":
    # Manual smoke test: run the routing graph and stream the model's answer
    # token by token to stdout.
    from dotenv import load_dotenv
    # Consistency fix: import from langchain_core.messages like the rest of
    # this module (see the top-of-file imports), not `langchain.messages`.
    from langchain_core.messages import SystemMessage, HumanMessage
    from langchain_core.messages.base import BaseMessageChunk

    load_dotenv()
    route: RoutingGraph = RoutingConfig().setup()
    graph = route.workflow
    # nargs is a (state-input, config) pair; invoke()/stream() forward it
    # wrapped as {"inp": nargs}.
    nargs = {
        "messages": [SystemMessage("you are a helpful bot named jarvis"),
                     HumanMessage("use the calculator tool to calculate 92*55 and say the answer")]
    }, {"configurable": {"thread_id": "3"}}
    for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
        node = metadata.get("langgraph_node")
        # BUG FIX: `node not in ("model")` was a substring test against the
        # string "model" (parentheses without a comma do not make a tuple),
        # so e.g. node == "mode" would NOT be skipped. Compare directly.
        if node != "model":
            continue  # skip router or other intermediate nodes
        # Print only the final model's content chunks.
        if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
            print(chunk.content, end="", flush=True)
    print()  # terminate the streamed line with a newline