diff --git a/lang_agent/graphs/react.py b/lang_agent/graphs/react.py index 6282325..ff46214 100644 --- a/lang_agent/graphs/react.py +++ b/lang_agent/graphs/react.py @@ -62,7 +62,9 @@ class ReactGraph(GraphBase): else: out = self.agent.invoke(*nargs, **kwargs) + msgs_list = jax.tree.leaves(out) + if as_raw: return out else: - return jax.tree.leaves(out)[-1].content + return msgs_list[-1].content