support subgraph streaming
@@ -10,7 +10,7 @@ import glob
 from lang_agent.config import KeyConfig
 from lang_agent.tool_manager import ToolManager, ToolManagerConfig
-from lang_agent.base import GraphBase
+from lang_agent.base import GraphBase, ToolNodeBase
 from lang_agent.graphs.graph_states import State
 from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig
 
@@ -66,10 +66,19 @@ class RoutingGraph(GraphBase):
 
 
     def _stream_result(self, *nargs, **kwargs):
-        for chunk, metadata in self.workflow.stream({"inp": nargs}, stream_mode="messages", **kwargs):
-            node = metadata.get("langgraph_node")
-            if node != "model":
-                continue # skip router or other intermediate nodes
+        streamable_tags = self.tool_node.get_streamable_tags() + [["route_chat_llm"]]
+        for chunk, metadata in self.workflow.stream({"inp": nargs},
+                                                    stream_mode="messages",
+                                                    subgraphs=True,
+                                                    **kwargs):
+
+            if isinstance(metadata, tuple):
+                chunk, metadata = metadata
+
+            tags = metadata.get("tags")
+            if not (tags in streamable_tags):
+                continue
+
 
             # Yield only the final message content chunks
             if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
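Note on the change above: with subgraphs=True, LangGraph's stream() reports events that originate in a nested graph as (namespace, payload) pairs, which is why the loop re-unpacks chunk, metadata = metadata whenever metadata arrives as a tuple, and then filters on the tags carried in each chunk's metadata. A minimal, self-contained sketch of that unpack-and-filter pattern follows; the event shapes, tag names, and helper function are illustrative stand-ins, not code from this repository.

# Sketch (assumed event shapes): normalize parent-graph and subgraph events into
# (chunk, metadata) pairs, then keep only chunks whose tag list is on the allow-list.
def filter_streamed_tokens(events, streamable_tags):
    for chunk, metadata in events:
        if isinstance(metadata, tuple):
            # Subgraph event: the first element was the namespace; the real
            # (chunk, metadata) pair is nested one level deeper.
            chunk, metadata = metadata
        if metadata.get("tags") not in streamable_tags:
            continue
        content = getattr(chunk, "content", None)
        if content:
            yield content

class FakeChunk:  # stand-in for an AIMessageChunk
    def __init__(self, content):
        self.content = content

# Simulated output of workflow.stream(..., stream_mode="messages", subgraphs=True);
# tag names and the namespace string are made up for illustration.
events = [
    (FakeChunk("route to tools"), {"tags": ["router_llm"]}),                        # dropped
    (("tool_node:xyz",), (FakeChunk("partial "), {"tags": ["chatty_tool_llm"]})),   # subgraph, kept
    (FakeChunk("answer"), {"tags": ["route_chat_llm"]}),                            # parent, kept
]
print("".join(filter_streamed_tokens(events, [["chatty_tool_llm"], ["route_chat_llm"]])))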
@@ -130,7 +139,7 @@ class RoutingGraph(GraphBase):
 
         tool_manager:ToolManager = self.config.tool_manager_config.setup()
         self.chat_model = create_agent(self.chat_llm, self._get_chat_tools(tool_manager), checkpointer=self.memory)
-        self.tool_node:GraphBase = self.config.tool_node_config.setup(tool_manager=tool_manager,
-                                                                      memory=self.memory)
+        self.tool_node:ToolNodeBase = self.config.tool_node_config.setup(tool_manager=tool_manager,
+                                                                         memory=self.memory)
 
         self._load_sys_prompts()
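The retyped attribute ties into the streaming change above: _stream_result now calls self.tool_node.get_streamable_tags(), so annotating the attribute as ToolNodeBase rather than the broader GraphBase makes that method part of the declared interface. The base class presumably exposes something along the lines of the sketch below (an assumed shape inferred from this diff, not the actual lang_agent code):

# Assumed shape only; the real classes live in lang_agent and may differ.
class GraphBase:  # stub standing in for lang_agent.base.GraphBase
    pass

class ToolNodeBase(GraphBase):
    def get_streamable_tags(self) -> list[list[str]]:
        """Tag lists whose LLM token chunks the parent graph should surface."""
        raise NotImplementedError

class ChattyToolNode(ToolNodeBase):  # hypothetical subclass for illustration
    def get_streamable_tags(self) -> list[list[str]]:
        return [["chatty_tool_llm"]]  # hypothetical tag name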
@@ -244,22 +253,26 @@ if __name__ == "__main__":
     from dotenv import load_dotenv
     from langchain.messages import SystemMessage, HumanMessage
     from langchain_core.messages.base import BaseMessageChunk
+    from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig, ChattyToolNodeConfig
 
     load_dotenv()
 
-    route:RoutingGraph = RoutingConfig().setup()
+    route:RoutingGraph = RoutingConfig(tool_node_config=ChattyToolNodeConfig()).setup()
     graph = route.workflow
 
     nargs = {
         "messages": [SystemMessage("you are a helpful bot named jarvis"),
-                     HumanMessage("what is your name")]
+                     HumanMessage("use calculator to calculate 926*84")]
         },{"configurable": {"thread_id": "3"}}
 
-    for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
-        node = metadata.get("langgraph_node")
-        if node not in ("model"):
-            continue # skip router or other intermediate nodes
+    for chunk in route.invoke(*nargs, as_stream=True):
+        # pass
+        print(chunk, end="", flush=True)
 
-        # Print only the final message content
-        if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
-            print(chunk.content, end="", flush=True)
+    # for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
+    #     node = metadata.get("langgraph_node")
+    #     if node not in ("model"):
+    #         continue # skip router or other intermediate nodes
+
+    #     # Print only the final message content
+    #     if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
+    #         print(chunk.content, end="", flush=True)
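For context on where a tag such as "route_chat_llm" can come from: tags attached to a runnable's config propagate into the per-chunk metadata that stream_mode="messages" exposes, which is what _stream_result compares against its streamable_tags allow-list (and what the demo, assuming invoke(*nargs, as_stream=True) delegates to _stream_result, ends up printing). A hedged sketch follows; the model class and wiring are assumptions, not taken from this repository.

# Sketch: tag the chat LLM so its streamed token chunks carry ["route_chat_llm"]
# in their metadata. Any LangChain chat model should work; ChatOpenAI is only an example.
from langchain_openai import ChatOpenAI

chat_llm = ChatOpenAI(model="gpt-4o-mini").with_config(tags=["route_chat_llm"])
# When this runnable streams inside a LangGraph node, each message chunk's metadata
# is expected to include {"tags": ["route_chat_llm"], ...}, which satisfies the
# exact-list membership test `tags in streamable_tags` used in _stream_result.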