as raw
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from typing import Type, List, Callable, Any
|
||||
import tyro
|
||||
import jax
|
||||
|
||||
from lang_agent.config import KeyConfig
|
||||
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
||||
@@ -49,7 +50,7 @@ class ReactGraph(GraphBase):
|
||||
tools = self.tool_manager.get_langchain_tools()
|
||||
self.agent = create_agent(self.llm, tools, checkpointer=memory)
|
||||
|
||||
def invoke(self, *nargs, as_stream:bool=False, **kwargs):
|
||||
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
||||
"""
|
||||
as_stream (bool): for debug only, gets the agent to print its thoughts
|
||||
"""
|
||||
@@ -61,4 +62,7 @@ class ReactGraph(GraphBase):
|
||||
else:
|
||||
out = self.agent.invoke(*nargs, **kwargs)
|
||||
|
||||
return out
|
||||
if as_raw:
|
||||
return out
|
||||
else:
|
||||
return jax.tree.leaves(out)[-1].content
|
||||
|
||||
@@ -61,7 +61,7 @@ class RoutingGraph(GraphBase):
|
||||
self.workflow = self._build_graph()
|
||||
|
||||
|
||||
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
||||
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs)->str:
|
||||
self._validate_input(*nargs, **kwargs)
|
||||
|
||||
if as_stream:
|
||||
@@ -81,7 +81,7 @@ class RoutingGraph(GraphBase):
|
||||
if as_raw:
|
||||
return msg_list
|
||||
|
||||
return msg_list[-1].content
|
||||
return msg_list[-1].content
|
||||
|
||||
def _validate_input(self, *nargs, **kwargs):
|
||||
assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
|
||||
|
||||
Reference in New Issue
Block a user