make react a one node workflow

This commit is contained in:
2026-01-30 18:47:01 +08:00
parent e3e432550f
commit 40ee9b7420

View File

@@ -8,11 +8,13 @@ from lang_agent.config import KeyConfig
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.base import GraphBase from lang_agent.base import GraphBase
from lang_agent.utils import tree_leaves from lang_agent.utils import tree_leaves
from lang_agent.graphs.graph_states import State
from langchain.chat_models import init_chat_model from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
from langchain.agents import create_agent from langchain.agents import create_agent
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
# NOTE: maybe make this into a base_graph_config? # NOTE: maybe make this into a base_graph_config?
@tyro.conf.configure(tyro.conf.SuppressFixed) @tyro.conf.configure(tyro.conf.SuppressFixed)
@@ -46,6 +48,9 @@ class ReactGraph(GraphBase):
self.config = config self.config = config
self.populate_modules() self.populate_modules()
self.workflow = self._build_graph()
self.streamable_tags = [["main_llm"]]
def populate_modules(self): def populate_modules(self):
self.llm = init_chat_model(model=self.config.llm_name, self.llm = init_chat_model(model=self.config.llm_name,
@@ -58,78 +63,38 @@ class ReactGraph(GraphBase):
self.tool_manager:ToolManager = self.config.tool_manager_config.setup() self.tool_manager:ToolManager = self.config.tool_manager_config.setup()
self.memory = MemorySaver() self.memory = MemorySaver()
tools = self.tool_manager.get_langchain_tools() tools = self.tool_manager.get_langchain_tools()
self.workflow = create_agent(self.llm, tools, checkpointer=self.memory) self.agent = create_agent(self.llm, tools, checkpointer=self.memory)
with open(self.config.sys_prompt_f, "r") as f: with open(self.config.sys_prompt_f, "r") as f:
self.sys_prompt = f.read() self.sys_prompt = f.read()
def _get_human_msg(self, *nargs):
msgs = nargs[0]["messages"]
candidate_hum_msg = None
for msg in msgs:
if isinstance(msg, HumanMessage):
candidate_hum_msg = msg
break
assert isinstance(candidate_hum_msg, HumanMessage), "not a human message"
return candidate_hum_msg
def _prep_inp(self, *nargs): def _agent_call(self, state:State):
assert len(nargs) == 2, "should have 2 arguements" if state.get("messages") is not None:
inp = state["messages"], state["inp"][1]
human_msg = self._get_human_msg(*nargs)
conf = nargs[1]
return {"messages":[SystemMessage(self.sys_prompt), human_msg]}, conf
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
"""
as_stream (bool): for debug only, gets the agent to print its thoughts
"""
nargs = self._prep_inp(*nargs)
if as_stream:
for step in self.workflow.stream(*nargs, stream_mode="values", **kwargs):
step["messages"][-1].pretty_print()
out = step
else: else:
out = self.workflow.invoke(*nargs, **kwargs) inp = state["inp"]
inp = {"messages":[
SystemMessage(
self.sys_prompt
),
*self._get_inp_msgs(state)
]}, state["inp"][1]
msgs_list = tree_leaves(out)
for e in msgs_list: out = self.agent.invoke(*inp)
if isinstance(e, BaseMessage): return {"messages": out["messages"]}
e.pretty_print()
if as_raw:
return msgs_list
else:
return msgs_list[-1].content
async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs): def _build_graph(self):
""" builder = StateGraph(State)
Async version of invoke using LangGraph's native async support.
as_stream (bool): for debug only, gets the agent to print its thoughts
"""
nargs = self._prep_inp(*nargs)
if as_stream:
async for step in self.workflow.astream(*nargs, stream_mode="values", **kwargs):
step["messages"][-1].pretty_print()
out = step
else:
out = await self.workflow.ainvoke(*nargs, **kwargs)
msgs_list = tree_leaves(out) builder.add_node("agent_call", self._agent_call)
builder.add_edge(START, "agent_call")
builder.add_edge("agent_call", END)
for e in msgs_list: return builder.compile()
if isinstance(e, BaseMessage):
e.pretty_print()
if as_raw:
return msgs_list
else:
return msgs_list[-1].content
if __name__ == "__main__": if __name__ == "__main__":
@@ -139,22 +104,25 @@ if __name__ == "__main__":
load_dotenv() load_dotenv()
route:ReactGraph = ReactGraphConfig().setup() route:ReactGraph = ReactGraphConfig().setup()
graph = route.workflow graph = route.agent
nargs = { nargs = {
"messages": [SystemMessage("you are a helpful bot named jarvis"), "messages": [SystemMessage("you are a helpful bot named jarvis"),
HumanMessage("use the calculator tool to calculate 92*55 and say the answer")] HumanMessage("say something cool")]
},{"configurable": {"thread_id": "3"}} },{"configurable": {"thread_id": "3"}}
for out in route.invoke(*nargs, as_stream=True):
print(out)
# out = route.invoke(*nargs) # out = route.invoke(*nargs)
# assert 0 # assert 0
# for mode, data in graph.stream(*nargs, stream_mode=["messages", "values"]): # for mode, data in graph.stream(*nargs, stream_mode=["messages", "values"]):
# print(data) # print(data)
for _, mode, out in graph.stream(*nargs, subgraphs=True, # for _, mode, out in graph.stream(*nargs, subgraphs=True,
stream_mode=["messages", "values"]): # stream_mode=["messages", "values"]):
if mode == "values": # if mode == "values":
msgs = out.get("messages") # msgs = out.get("messages")
l = len(msgs) if msgs is not None else -1 # l = len(msgs) if msgs is not None else -1
print(type(out), out.keys(), l) # print(type(out), out.keys(), l)