This commit is contained in:
2025-10-23 20:38:34 +08:00
parent 4f14017364
commit da2f1575bd
3 changed files with 10 additions and 6 deletions

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass, field, is_dataclass from dataclasses import dataclass, field, is_dataclass
from typing import Type, List, Callable, Any from typing import Type, List, Callable, Any
import tyro import tyro
import jax
from lang_agent.config import KeyConfig from lang_agent.config import KeyConfig
from lang_agent.tool_manager import ToolManager, ToolManagerConfig from lang_agent.tool_manager import ToolManager, ToolManagerConfig
@@ -49,7 +50,7 @@ class ReactGraph(GraphBase):
tools = self.tool_manager.get_langchain_tools() tools = self.tool_manager.get_langchain_tools()
self.agent = create_agent(self.llm, tools, checkpointer=memory) self.agent = create_agent(self.llm, tools, checkpointer=memory)
def invoke(self, *nargs, as_stream:bool=False, **kwargs): def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
""" """
as_stream (bool): for debug only, gets the agent to print its thoughts as_stream (bool): for debug only, gets the agent to print its thoughts
""" """
@@ -61,4 +62,7 @@ class ReactGraph(GraphBase):
else: else:
out = self.agent.invoke(*nargs, **kwargs) out = self.agent.invoke(*nargs, **kwargs)
return out if as_raw:
return out
else:
return jax.tree.leaves(out)[-1].content

View File

@@ -61,7 +61,7 @@ class RoutingGraph(GraphBase):
self.workflow = self._build_graph() self.workflow = self._build_graph()
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs): def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs)->str:
self._validate_input(*nargs, **kwargs) self._validate_input(*nargs, **kwargs)
if as_stream: if as_stream:
@@ -81,7 +81,7 @@ class RoutingGraph(GraphBase):
if as_raw: if as_raw:
return msg_list return msg_list
return msg_list[-1].content return msg_list[-1].content
def _validate_input(self, *nargs, **kwargs): def _validate_input(self, *nargs, **kwargs):
assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message" assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"

View File

@@ -117,7 +117,7 @@ class Pipeline:
return f"ws://{self.config.host}:{self.config.port}" return f"ws://{self.config.host}:{self.config.port}"
def chat(self, inp:str, as_stream:bool=False)->str: def chat(self, inp:str, as_stream:bool=False, as_raw:bool=False)->str:
u = """ u = """
你叫小盏,是一个点餐助手,你的回复要简洁明了,不需要给用户提供选择。对话过程中不要出现提示用户下一步的操作,用可爱的语气进行交流,根据用户的语言使用对应的语言回答 你叫小盏,是一个点餐助手,你的回复要简洁明了,不需要给用户提供选择。对话过程中不要出现提示用户下一步的操作,用可爱的语气进行交流,根据用户的语言使用对应的语言回答
@@ -142,7 +142,7 @@ class Pipeline:
inp = {"messages":[SystemMessage(u), inp = {"messages":[SystemMessage(u),
HumanMessage(inp)]}, {"configurable": {"thread_id": 3}} HumanMessage(inp)]}, {"configurable": {"thread_id": 3}}
out = self.invoke(*inp, as_stream=as_stream) out = self.invoke(*inp, as_stream=as_stream, as_raw=as_raw)
# return out['messages'][-1].content # return out['messages'][-1].content
return out return out