diff --git a/lang_agent/graphs/react.py b/lang_agent/graphs/react.py index 51de1d9..d1a978d 100644 --- a/lang_agent/graphs/react.py +++ b/lang_agent/graphs/react.py @@ -65,6 +65,6 @@ class ReactGraph(GraphBase): msgs_list = jax.tree.leaves(out) if as_raw: - return out + return msgs_list else: return msgs_list[-1].content