update react with ainvoke
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from typing import Type, List, Callable, Any
|
||||
from typing import Type, List, Callable, Any, AsyncIterator
|
||||
import tyro
|
||||
import jax
|
||||
|
||||
@@ -73,6 +73,30 @@ class ReactGraph(GraphBase):
|
||||
else:
|
||||
return msgs_list[-1].content
|
||||
|
||||
async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
||||
"""
|
||||
Async version of invoke using LangGraph's native async support.
|
||||
as_stream (bool): for debug only, gets the agent to print its thoughts
|
||||
"""
|
||||
|
||||
if as_stream:
|
||||
async for step in self.agent.astream(*nargs, stream_mode="values", **kwargs):
|
||||
step["messages"][-1].pretty_print()
|
||||
out = step
|
||||
else:
|
||||
out = await self.agent.ainvoke(*nargs, **kwargs)
|
||||
|
||||
msgs_list = jax.tree.leaves(out)
|
||||
|
||||
for e in msgs_list:
|
||||
if isinstance(e, BaseMessage):
|
||||
e.pretty_print()
|
||||
|
||||
if as_raw:
|
||||
return msgs_list
|
||||
else:
|
||||
return msgs_list[-1].content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from dotenv import load_dotenv
|
||||
|
||||
Reference in New Issue
Block a user