diff --git a/lang_agent/graphs/react.py b/lang_agent/graphs/react.py index e55f861..10da0e8 100644 --- a/lang_agent/graphs/react.py +++ b/lang_agent/graphs/react.py @@ -8,7 +8,7 @@ from lang_agent.tool_manager import ToolManager, ToolManagerConfig from lang_agent.base import GraphBase from langchain.chat_models import init_chat_model -from langchain_core.messages import SystemMessage, HumanMessage +from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage from langchain.agents import create_agent from langgraph.checkpoint.memory import MemorySaver @@ -64,7 +64,37 @@ class ReactGraph(GraphBase): msgs_list = jax.tree.leaves(out) + for e in msgs_list: + if isinstance(e, BaseMessage): + e.pretty_print() + if as_raw: return msgs_list else: return msgs_list[-1].content + + +if __name__ == "__main__": + from dotenv import load_dotenv + from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage + from langchain_core.messages.base import BaseMessageChunk + load_dotenv() + + route:ReactGraph = ReactGraphConfig().setup() + graph = route.agent + + nargs = { + "messages": [SystemMessage("you are a helpful bot named jarvis"), + HumanMessage("use the calculator tool to calculate 92*55 and say the answer")] + },{"configurable": {"thread_id": "3"}} + + for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"): + node = metadata.get("langgraph_node") + if node not in ("model"): + print(node) + continue # skip router or other intermediate nodes + + # Print only the final message content + if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None): + print(chunk.content, end="", flush=True) + \ No newline at end of file