From 40ee9b742012e45feb96b705c188142b7f503f1e Mon Sep 17 00:00:00 2001 From: goulustis Date: Fri, 30 Jan 2026 18:47:01 +0800 Subject: [PATCH] make react a one node workflow --- lang_agent/graphs/react.py | 106 +++++++++++++------------------------ 1 file changed, 37 insertions(+), 69 deletions(-) diff --git a/lang_agent/graphs/react.py b/lang_agent/graphs/react.py index 360e177..c4e2a6b 100644 --- a/lang_agent/graphs/react.py +++ b/lang_agent/graphs/react.py @@ -8,11 +8,13 @@ from lang_agent.config import KeyConfig from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig from lang_agent.base import GraphBase from lang_agent.utils import tree_leaves +from lang_agent.graphs.graph_states import State from langchain.chat_models import init_chat_model from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage from langchain.agents import create_agent from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph import StateGraph, START, END # NOTE: maybe make this into a base_graph_config? @tyro.conf.configure(tyro.conf.SuppressFixed) @@ -46,6 +48,9 @@ class ReactGraph(GraphBase): self.config = config self.populate_modules() + self.workflow = self._build_graph() + + self.streamable_tags = [["main_llm"]] def populate_modules(self): self.llm = init_chat_model(model=self.config.llm_name, @@ -58,78 +63,38 @@ class ReactGraph(GraphBase): self.tool_manager:ToolManager = self.config.tool_manager_config.setup() self.memory = MemorySaver() tools = self.tool_manager.get_langchain_tools() - self.workflow = create_agent(self.llm, tools, checkpointer=self.memory) + self.agent = create_agent(self.llm, tools, checkpointer=self.memory) with open(self.config.sys_prompt_f, "r") as f: self.sys_prompt = f.read() - - def _get_human_msg(self, *nargs): - msgs = nargs[0]["messages"] - - candidate_hum_msg = None - for msg in msgs: - if isinstance(msg, HumanMessage): - candidate_hum_msg = msg - break - - assert isinstance(candidate_hum_msg, HumanMessage), "not a human message" - - return candidate_hum_msg - def _prep_inp(self, *nargs): - assert len(nargs) == 2, "should have 2 arguements" - - human_msg = self._get_human_msg(*nargs) - conf = nargs[1] - return {"messages":[SystemMessage(self.sys_prompt), human_msg]}, conf - - - def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs): - """ - as_stream (bool): for debug only, gets the agent to print its thoughts - """ - nargs = self._prep_inp(*nargs) - if as_stream: - for step in self.workflow.stream(*nargs, stream_mode="values", **kwargs): - step["messages"][-1].pretty_print() - out = step + def _agent_call(self, state:State): + if state.get("messages") is not None: + inp = state["messages"], state["inp"][1] else: - out = self.workflow.invoke(*nargs, **kwargs) + inp = state["inp"] + + inp = {"messages":[ + SystemMessage( + self.sys_prompt + ), + *self._get_inp_msgs(state) + ]}, state["inp"][1] - msgs_list = tree_leaves(out) - for e in msgs_list: - if isinstance(e, BaseMessage): - e.pretty_print() + out = self.agent.invoke(*inp) + return {"messages": out["messages"]} - if as_raw: - return msgs_list - else: - return msgs_list[-1].content - async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs): - """ - Async version of invoke using LangGraph's native async support. - as_stream (bool): for debug only, gets the agent to print its thoughts - """ - nargs = self._prep_inp(*nargs) - if as_stream: - async for step in self.workflow.astream(*nargs, stream_mode="values", **kwargs): - step["messages"][-1].pretty_print() - out = step - else: - out = await self.workflow.ainvoke(*nargs, **kwargs) + def _build_graph(self): + builder = StateGraph(State) - msgs_list = tree_leaves(out) + builder.add_node("agent_call", self._agent_call) + + builder.add_edge(START, "agent_call") + builder.add_edge("agent_call", END) - for e in msgs_list: - if isinstance(e, BaseMessage): - e.pretty_print() - - if as_raw: - return msgs_list - else: - return msgs_list[-1].content + return builder.compile() if __name__ == "__main__": @@ -139,22 +104,25 @@ if __name__ == "__main__": load_dotenv() route:ReactGraph = ReactGraphConfig().setup() - graph = route.workflow + graph = route.agent nargs = { "messages": [SystemMessage("you are a helpful bot named jarvis"), - HumanMessage("use the calculator tool to calculate 92*55 and say the answer")] + HumanMessage("say something cool")] },{"configurable": {"thread_id": "3"}} + for out in route.invoke(*nargs, as_stream=True): + print(out) + # out = route.invoke(*nargs) # assert 0 # for mode, data in graph.stream(*nargs, stream_mode=["messages", "values"]): # print(data) - for _, mode, out in graph.stream(*nargs, subgraphs=True, - stream_mode=["messages", "values"]): - if mode == "values": - msgs = out.get("messages") - l = len(msgs) if msgs is not None else -1 - print(type(out), out.keys(), l) \ No newline at end of file + # for _, mode, out in graph.stream(*nargs, subgraphs=True, + # stream_mode=["messages", "values"]): + # if mode == "values": + # msgs = out.get("messages") + # l = len(msgs) if msgs is not None else -1 + # print(type(out), out.keys(), l) \ No newline at end of file