make react a one node workflow
This commit is contained in:
@@ -8,11 +8,13 @@ from lang_agent.config import KeyConfig
|
|||||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import GraphBase
|
||||||
from lang_agent.utils import tree_leaves
|
from lang_agent.utils import tree_leaves
|
||||||
|
from lang_agent.graphs.graph_states import State
|
||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
from langgraph.graph import StateGraph, START, END
|
||||||
|
|
||||||
# NOTE: maybe make this into a base_graph_config?
|
# NOTE: maybe make this into a base_graph_config?
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@@ -46,6 +48,9 @@ class ReactGraph(GraphBase):
|
|||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.populate_modules()
|
self.populate_modules()
|
||||||
|
self.workflow = self._build_graph()
|
||||||
|
|
||||||
|
self.streamable_tags = [["main_llm"]]
|
||||||
|
|
||||||
def populate_modules(self):
|
def populate_modules(self):
|
||||||
self.llm = init_chat_model(model=self.config.llm_name,
|
self.llm = init_chat_model(model=self.config.llm_name,
|
||||||
@@ -58,78 +63,38 @@ class ReactGraph(GraphBase):
|
|||||||
self.tool_manager:ToolManager = self.config.tool_manager_config.setup()
|
self.tool_manager:ToolManager = self.config.tool_manager_config.setup()
|
||||||
self.memory = MemorySaver()
|
self.memory = MemorySaver()
|
||||||
tools = self.tool_manager.get_langchain_tools()
|
tools = self.tool_manager.get_langchain_tools()
|
||||||
self.workflow = create_agent(self.llm, tools, checkpointer=self.memory)
|
self.agent = create_agent(self.llm, tools, checkpointer=self.memory)
|
||||||
|
|
||||||
with open(self.config.sys_prompt_f, "r") as f:
|
with open(self.config.sys_prompt_f, "r") as f:
|
||||||
self.sys_prompt = f.read()
|
self.sys_prompt = f.read()
|
||||||
|
|
||||||
def _get_human_msg(self, *nargs):
|
|
||||||
msgs = nargs[0]["messages"]
|
|
||||||
|
|
||||||
candidate_hum_msg = None
|
|
||||||
for msg in msgs:
|
|
||||||
if isinstance(msg, HumanMessage):
|
|
||||||
candidate_hum_msg = msg
|
|
||||||
break
|
|
||||||
|
|
||||||
assert isinstance(candidate_hum_msg, HumanMessage), "not a human message"
|
|
||||||
|
|
||||||
return candidate_hum_msg
|
|
||||||
|
|
||||||
def _prep_inp(self, *nargs):
|
def _agent_call(self, state:State):
|
||||||
assert len(nargs) == 2, "should have 2 arguements"
|
if state.get("messages") is not None:
|
||||||
|
inp = state["messages"], state["inp"][1]
|
||||||
human_msg = self._get_human_msg(*nargs)
|
|
||||||
conf = nargs[1]
|
|
||||||
return {"messages":[SystemMessage(self.sys_prompt), human_msg]}, conf
|
|
||||||
|
|
||||||
|
|
||||||
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
|
||||||
"""
|
|
||||||
as_stream (bool): for debug only, gets the agent to print its thoughts
|
|
||||||
"""
|
|
||||||
nargs = self._prep_inp(*nargs)
|
|
||||||
if as_stream:
|
|
||||||
for step in self.workflow.stream(*nargs, stream_mode="values", **kwargs):
|
|
||||||
step["messages"][-1].pretty_print()
|
|
||||||
out = step
|
|
||||||
else:
|
else:
|
||||||
out = self.workflow.invoke(*nargs, **kwargs)
|
inp = state["inp"]
|
||||||
|
|
||||||
|
inp = {"messages":[
|
||||||
|
SystemMessage(
|
||||||
|
self.sys_prompt
|
||||||
|
),
|
||||||
|
*self._get_inp_msgs(state)
|
||||||
|
]}, state["inp"][1]
|
||||||
|
|
||||||
msgs_list = tree_leaves(out)
|
|
||||||
|
|
||||||
for e in msgs_list:
|
out = self.agent.invoke(*inp)
|
||||||
if isinstance(e, BaseMessage):
|
return {"messages": out["messages"]}
|
||||||
e.pretty_print()
|
|
||||||
|
|
||||||
if as_raw:
|
|
||||||
return msgs_list
|
|
||||||
else:
|
|
||||||
return msgs_list[-1].content
|
|
||||||
|
|
||||||
async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
def _build_graph(self):
|
||||||
"""
|
builder = StateGraph(State)
|
||||||
Async version of invoke using LangGraph's native async support.
|
|
||||||
as_stream (bool): for debug only, gets the agent to print its thoughts
|
|
||||||
"""
|
|
||||||
nargs = self._prep_inp(*nargs)
|
|
||||||
if as_stream:
|
|
||||||
async for step in self.workflow.astream(*nargs, stream_mode="values", **kwargs):
|
|
||||||
step["messages"][-1].pretty_print()
|
|
||||||
out = step
|
|
||||||
else:
|
|
||||||
out = await self.workflow.ainvoke(*nargs, **kwargs)
|
|
||||||
|
|
||||||
msgs_list = tree_leaves(out)
|
builder.add_node("agent_call", self._agent_call)
|
||||||
|
|
||||||
|
builder.add_edge(START, "agent_call")
|
||||||
|
builder.add_edge("agent_call", END)
|
||||||
|
|
||||||
for e in msgs_list:
|
return builder.compile()
|
||||||
if isinstance(e, BaseMessage):
|
|
||||||
e.pretty_print()
|
|
||||||
|
|
||||||
if as_raw:
|
|
||||||
return msgs_list
|
|
||||||
else:
|
|
||||||
return msgs_list[-1].content
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -139,22 +104,25 @@ if __name__ == "__main__":
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
route:ReactGraph = ReactGraphConfig().setup()
|
route:ReactGraph = ReactGraphConfig().setup()
|
||||||
graph = route.workflow
|
graph = route.agent
|
||||||
|
|
||||||
nargs = {
|
nargs = {
|
||||||
"messages": [SystemMessage("you are a helpful bot named jarvis"),
|
"messages": [SystemMessage("you are a helpful bot named jarvis"),
|
||||||
HumanMessage("use the calculator tool to calculate 92*55 and say the answer")]
|
HumanMessage("say something cool")]
|
||||||
},{"configurable": {"thread_id": "3"}}
|
},{"configurable": {"thread_id": "3"}}
|
||||||
|
|
||||||
|
for out in route.invoke(*nargs, as_stream=True):
|
||||||
|
print(out)
|
||||||
|
|
||||||
# out = route.invoke(*nargs)
|
# out = route.invoke(*nargs)
|
||||||
# assert 0
|
# assert 0
|
||||||
|
|
||||||
# for mode, data in graph.stream(*nargs, stream_mode=["messages", "values"]):
|
# for mode, data in graph.stream(*nargs, stream_mode=["messages", "values"]):
|
||||||
# print(data)
|
# print(data)
|
||||||
|
|
||||||
for _, mode, out in graph.stream(*nargs, subgraphs=True,
|
# for _, mode, out in graph.stream(*nargs, subgraphs=True,
|
||||||
stream_mode=["messages", "values"]):
|
# stream_mode=["messages", "values"]):
|
||||||
if mode == "values":
|
# if mode == "values":
|
||||||
msgs = out.get("messages")
|
# msgs = out.get("messages")
|
||||||
l = len(msgs) if msgs is not None else -1
|
# l = len(msgs) if msgs is not None else -1
|
||||||
print(type(out), out.keys(), l)
|
# print(type(out), out.keys(), l)
|
||||||
Reference in New Issue
Block a user