routing support streaming
This commit is contained in:
@@ -17,6 +17,7 @@ from lang_agent.base import GraphBase
|
|||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
||||||
|
from langchain_core.messages.base import BaseMessageChunk
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
|
|
||||||
from langgraph.graph import StateGraph, START, END
|
from langgraph.graph import StateGraph, START, END
|
||||||
@@ -80,32 +81,32 @@ class RoutingGraph(GraphBase):
|
|||||||
self.workflow = self._build_graph()
|
self.workflow = self._build_graph()
|
||||||
|
|
||||||
|
|
||||||
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs)->str:
|
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
||||||
self._validate_input(*nargs, **kwargs)
|
self._validate_input(*nargs, **kwargs)
|
||||||
|
|
||||||
if as_stream:
|
if as_stream:
|
||||||
# TODO: this doesn't stream the entire process, we are blind
|
# Stream messages from the workflow
|
||||||
for step in self.workflow.stream({"inp": nargs}, stream_mode="updates", **kwargs):
|
for chunk, metadata in self.workflow.stream({"inp": nargs}, stream_mode="messages", **kwargs):
|
||||||
last_el = jax.tree.leaves(step)[-1]
|
node = metadata.get("langgraph_node")
|
||||||
if isinstance(last_el, str):
|
if node != "model":
|
||||||
logger.info(last_el)
|
continue # skip router or other intermediate nodes
|
||||||
elif isinstance(last_el, BaseMessage):
|
|
||||||
last_el.pretty_print()
|
# Yield only the final message content chunks
|
||||||
|
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
||||||
state = step
|
yield chunk.content
|
||||||
else:
|
else:
|
||||||
state = self.workflow.invoke({"inp": nargs})
|
state = self.workflow.invoke({"inp": nargs})
|
||||||
|
|
||||||
msg_list = jax.tree.leaves(state)
|
msg_list = jax.tree.leaves(state)
|
||||||
|
|
||||||
for e in msg_list:
|
for e in msg_list:
|
||||||
if isinstance(e, BaseMessage):
|
if isinstance(e, BaseMessage):
|
||||||
e.pretty_print()
|
e.pretty_print()
|
||||||
|
|
||||||
if as_raw:
|
if as_raw:
|
||||||
return msg_list
|
return msg_list
|
||||||
|
|
||||||
return msg_list[-1].content
|
return msg_list[-1].content
|
||||||
|
|
||||||
def _validate_input(self, *nargs, **kwargs):
|
def _validate_input(self, *nargs, **kwargs):
|
||||||
print("\033[93m====================INPUT MESSAGES=============================\033[0m")
|
print("\033[93m====================INPUT MESSAGES=============================\033[0m")
|
||||||
@@ -256,5 +257,25 @@ class RoutingGraph(GraphBase):
|
|||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
route = RoutingConfig().setup()
|
from dotenv import load_dotenv
|
||||||
route.show_graph()
|
from langchain.messages import SystemMessage, HumanMessage
|
||||||
|
from langchain_core.messages.base import BaseMessageChunk
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
route:RoutingGraph = RoutingConfig().setup()
|
||||||
|
graph = route.workflow
|
||||||
|
|
||||||
|
nargs = {
|
||||||
|
"messages": [SystemMessage("you are a helpful bot named jarvis"),
|
||||||
|
HumanMessage("use the calculator tool to calculate 92*55 and say the answer")]
|
||||||
|
},{"configurable": {"thread_id": "3"}}
|
||||||
|
|
||||||
|
for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
|
||||||
|
node = metadata.get("langgraph_node")
|
||||||
|
if node not in ("model"):
|
||||||
|
continue # skip router or other intermediate nodes
|
||||||
|
|
||||||
|
# Print only the final message content
|
||||||
|
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
||||||
|
print(chunk.content, end="", flush=True)
|
||||||
|
|
||||||
Reference in New Issue
Block a user