diff --git a/lang_agent/graphs/react.py b/lang_agent/graphs/react.py index 0279ca1..64c9b15 100644 --- a/lang_agent/graphs/react.py +++ b/lang_agent/graphs/react.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field, is_dataclass -from typing import Type, List, Callable, Any +from typing import Type, List, Callable, Any, AsyncIterator import tyro import jax @@ -73,6 +73,30 @@ class ReactGraph(GraphBase): else: return msgs_list[-1].content + async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs): + """ + Async version of invoke using LangGraph's native async support. + as_stream (bool): for debug only, gets the agent to print its thoughts + """ + + if as_stream: + async for step in self.agent.astream(*nargs, stream_mode="values", **kwargs): + step["messages"][-1].pretty_print() + out = step + else: + out = await self.agent.ainvoke(*nargs, **kwargs) + + msgs_list = jax.tree.leaves(out) + + for e in msgs_list: + if isinstance(e, BaseMessage): + e.pretty_print() + + if as_raw: + return msgs_list + else: + return msgs_list[-1].content + if __name__ == "__main__": from dotenv import load_dotenv