diff --git a/lang_agent/graphs/react.py b/lang_agent/graphs/react.py index 3cdf125..6282325 100644 --- a/lang_agent/graphs/react.py +++ b/lang_agent/graphs/react.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field, is_dataclass from typing import Type, List, Callable, Any import tyro +import jax from lang_agent.config import KeyConfig from lang_agent.tool_manager import ToolManager, ToolManagerConfig @@ -49,7 +50,7 @@ class ReactGraph(GraphBase): tools = self.tool_manager.get_langchain_tools() self.agent = create_agent(self.llm, tools, checkpointer=memory) - def invoke(self, *nargs, as_stream:bool=False, **kwargs): + def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs): """ as_stream (bool): for debug only, gets the agent to print its thoughts """ @@ -61,4 +62,7 @@ class ReactGraph(GraphBase): else: out = self.agent.invoke(*nargs, **kwargs) - return out + if as_raw: + return out + else: + return jax.tree.leaves(out)[-1].content diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index e303879..85948c1 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -61,7 +61,7 @@ class RoutingGraph(GraphBase): self.workflow = self._build_graph() - def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs): + def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs)->str: self._validate_input(*nargs, **kwargs) if as_stream: @@ -81,7 +81,7 @@ class RoutingGraph(GraphBase): if as_raw: return msg_list - return msg_list[-1].content + return msg_list[-1].content def _validate_input(self, *nargs, **kwargs): assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message" diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py index 67de96b..12c72bc 100644 --- a/lang_agent/pipeline.py +++ b/lang_agent/pipeline.py @@ -117,7 +117,7 @@ class Pipeline: return f"ws://{self.config.host}:{self.config.port}" - def chat(self, inp:str, as_stream:bool=False)->str: + def chat(self, inp:str, as_stream:bool=False, as_raw:bool=False)->str: u = """ 你叫小盏,是一个点餐助手,你的回复要简洁明了,不需要给用户提供选择。对话过程中不要出现提示用户下一步的操作,用可爱的语气进行交流,根据用户的语言使用对应的语言回答 @@ -142,7 +142,7 @@ class Pipeline: inp = {"messages":[SystemMessage(u), HumanMessage(inp)]}, {"configurable": {"thread_id": 3}} - out = self.invoke(*inp, as_stream=as_stream) + out = self.invoke(*inp, as_stream=as_stream, as_raw=as_raw) # return out['messages'][-1].content return out \ No newline at end of file