as raw
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
from dataclasses import dataclass, field, is_dataclass
|
from dataclasses import dataclass, field, is_dataclass
|
||||||
from typing import Type, List, Callable, Any
|
from typing import Type, List, Callable, Any
|
||||||
import tyro
|
import tyro
|
||||||
|
import jax
|
||||||
|
|
||||||
from lang_agent.config import KeyConfig
|
from lang_agent.config import KeyConfig
|
||||||
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
||||||
@@ -49,7 +50,7 @@ class ReactGraph(GraphBase):
|
|||||||
tools = self.tool_manager.get_langchain_tools()
|
tools = self.tool_manager.get_langchain_tools()
|
||||||
self.agent = create_agent(self.llm, tools, checkpointer=memory)
|
self.agent = create_agent(self.llm, tools, checkpointer=memory)
|
||||||
|
|
||||||
def invoke(self, *nargs, as_stream:bool=False, **kwargs):
|
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
as_stream (bool): for debug only, gets the agent to print its thoughts
|
as_stream (bool): for debug only, gets the agent to print its thoughts
|
||||||
"""
|
"""
|
||||||
@@ -61,4 +62,7 @@ class ReactGraph(GraphBase):
|
|||||||
else:
|
else:
|
||||||
out = self.agent.invoke(*nargs, **kwargs)
|
out = self.agent.invoke(*nargs, **kwargs)
|
||||||
|
|
||||||
|
if as_raw:
|
||||||
return out
|
return out
|
||||||
|
else:
|
||||||
|
return jax.tree.leaves(out)[-1].content
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ class RoutingGraph(GraphBase):
|
|||||||
self.workflow = self._build_graph()
|
self.workflow = self._build_graph()
|
||||||
|
|
||||||
|
|
||||||
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs)->str:
|
||||||
self._validate_input(*nargs, **kwargs)
|
self._validate_input(*nargs, **kwargs)
|
||||||
|
|
||||||
if as_stream:
|
if as_stream:
|
||||||
|
|||||||
@@ -117,7 +117,7 @@ class Pipeline:
|
|||||||
return f"ws://{self.config.host}:{self.config.port}"
|
return f"ws://{self.config.host}:{self.config.port}"
|
||||||
|
|
||||||
|
|
||||||
def chat(self, inp:str, as_stream:bool=False)->str:
|
def chat(self, inp:str, as_stream:bool=False, as_raw:bool=False)->str:
|
||||||
u = """
|
u = """
|
||||||
你叫小盏,是一个点餐助手,你的回复要简洁明了,不需要给用户提供选择。对话过程中不要出现提示用户下一步的操作,用可爱的语气进行交流,根据用户的语言使用对应的语言回答
|
你叫小盏,是一个点餐助手,你的回复要简洁明了,不需要给用户提供选择。对话过程中不要出现提示用户下一步的操作,用可爱的语气进行交流,根据用户的语言使用对应的语言回答
|
||||||
|
|
||||||
@@ -142,7 +142,7 @@ class Pipeline:
|
|||||||
inp = {"messages":[SystemMessage(u),
|
inp = {"messages":[SystemMessage(u),
|
||||||
HumanMessage(inp)]}, {"configurable": {"thread_id": 3}}
|
HumanMessage(inp)]}, {"configurable": {"thread_id": 3}}
|
||||||
|
|
||||||
out = self.invoke(*inp, as_stream=as_stream)
|
out = self.invoke(*inp, as_stream=as_stream, as_raw=as_raw)
|
||||||
|
|
||||||
# return out['messages'][-1].content
|
# return out['messages'][-1].content
|
||||||
return out
|
return out
|
||||||
Reference in New Issue
Block a user