update react with ainvoke
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass, field, is_dataclass
|
from dataclasses import dataclass, field, is_dataclass
|
||||||
from typing import Type, List, Callable, Any
|
from typing import Type, List, Callable, Any, AsyncIterator
|
||||||
import tyro
|
import tyro
|
||||||
import jax
|
import jax
|
||||||
|
|
||||||
@@ -73,6 +73,30 @@ class ReactGraph(GraphBase):
|
|||||||
else:
|
else:
|
||||||
return msgs_list[-1].content
|
return msgs_list[-1].content
|
||||||
|
|
||||||
|
async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
||||||
|
"""
|
||||||
|
Async version of invoke using LangGraph's native async support.
|
||||||
|
as_stream (bool): for debug only, gets the agent to print its thoughts
|
||||||
|
"""
|
||||||
|
|
||||||
|
if as_stream:
|
||||||
|
async for step in self.agent.astream(*nargs, stream_mode="values", **kwargs):
|
||||||
|
step["messages"][-1].pretty_print()
|
||||||
|
out = step
|
||||||
|
else:
|
||||||
|
out = await self.agent.ainvoke(*nargs, **kwargs)
|
||||||
|
|
||||||
|
msgs_list = jax.tree.leaves(out)
|
||||||
|
|
||||||
|
for e in msgs_list:
|
||||||
|
if isinstance(e, BaseMessage):
|
||||||
|
e.pretty_print()
|
||||||
|
|
||||||
|
if as_raw:
|
||||||
|
return msgs_list
|
||||||
|
else:
|
||||||
|
return msgs_list[-1].content
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|||||||
Reference in New Issue
Block a user