routing support streaming
This commit is contained in:
@@ -17,6 +17,7 @@ from lang_agent.base import GraphBase
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
||||
from langchain_core.messages.base import BaseMessageChunk
|
||||
from langchain.agents import create_agent
|
||||
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
@@ -80,32 +81,32 @@ class RoutingGraph(GraphBase):
|
||||
self.workflow = self._build_graph()
|
||||
|
||||
|
||||
def invoke(self, *nargs, as_stream: bool = False, as_raw: bool = False, **kwargs):
    """Run the routing workflow, optionally streaming model output.

    Args:
        *nargs: Positional inputs forwarded to the graph as ``{"inp": nargs}``.
        as_stream: When True, return a generator that yields content chunks
            produced by the ``"model"`` node of the workflow.
        as_raw: Non-stream mode only — when True, return the full flattened
            message list instead of the last message's content.
        **kwargs: Extra options passed through to the workflow call
            (e.g. a ``config`` with a thread id).

    Returns:
        A generator of content strings when ``as_stream`` is True; otherwise
        the raw message list (``as_raw=True``) or the final message's content.
    """
    self._validate_input(*nargs, **kwargs)

    if as_stream:
        # BUG FIX: a `yield` directly in this method would turn *every*
        # call into a generator, so the non-stream branch's `return`
        # values never reached callers. Delegating to a private generator
        # keeps this method a normal function on the non-stream path.
        return self._stream_model_content(nargs, **kwargs)

    state = self.workflow.invoke({"inp": nargs})

    # Flatten whatever nested state the graph returns into a message list.
    msg_list = jax.tree.leaves(state)

    for msg in msg_list:
        if isinstance(msg, BaseMessage):
            msg.pretty_print()

    if as_raw:
        return msg_list
    return msg_list[-1].content

def _stream_model_content(self, nargs, **kwargs):
    """Yield content chunks emitted by the workflow's "model" node.

    Skips router and other intermediate nodes so only final model output
    is surfaced to the caller.
    """
    for chunk, metadata in self.workflow.stream({"inp": nargs}, stream_mode="messages", **kwargs):
        if metadata.get("langgraph_node") != "model":
            continue  # skip router or other intermediate nodes

        # Yield only chunks that actually carry text content.
        if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
            yield chunk.content
|
||||
|
||||
def _validate_input(self, *nargs, **kwargs):
|
||||
print("\033[93m====================INPUT MESSAGES=============================\033[0m")
|
||||
@@ -256,5 +257,25 @@ class RoutingGraph(GraphBase):
|
||||
plt.show()
|
||||
|
||||
if __name__ == "__main__":
    # Smoke-test script: build the routing graph and stream a tool-using
    # conversation, printing only the final model's content chunks.
    from dotenv import load_dotenv
    from langchain.messages import SystemMessage, HumanMessage
    from langchain_core.messages.base import BaseMessageChunk

    load_dotenv()

    route: RoutingGraph = RoutingConfig().setup()
    graph = route.workflow

    # (messages-dict, config) tuple forwarded to the graph as its input.
    nargs = {
        "messages": [
            SystemMessage("you are a helpful bot named jarvis"),
            HumanMessage("use the calculator tool to calculate 92*55 and say the answer"),
        ]
    }, {"configurable": {"thread_id": "3"}}

    for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
        node = metadata.get("langgraph_node")
        # BUG FIX: `node not in ("model")` was a substring test against the
        # string "model" (no tuple); use a direct comparison instead, matching
        # the check used inside RoutingGraph.invoke.
        if node != "model":
            continue  # skip router or other intermediate nodes

        # Print only the final message content
        if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
            print(chunk.content, end="", flush=True)
|
||||
|
||||
Reference in New Issue
Block a user