shft debug to invoke
This commit is contained in:
@@ -75,8 +75,16 @@ class Pipeline:
|
|||||||
self.agent = create_react_agent(self.llm, tools, checkpointer=memory)
|
self.agent = create_react_agent(self.llm, tools, checkpointer=memory)
|
||||||
|
|
||||||
|
|
||||||
def invoke(self, *nargs, **kwargs):
|
def invoke(self, *nargs, as_stream:bool=False, **kwargs):
|
||||||
return self.agent.invoke(*nargs, **kwargs)
|
|
||||||
|
if as_stream:
|
||||||
|
for step in self.agent.stream(*nargs, stream_mode="values", **kwargs):
|
||||||
|
step["messages"][-1].pretty_print()
|
||||||
|
out = step
|
||||||
|
else:
|
||||||
|
out = self.agent.invoke(*nargs, **kwargs)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
async def handle_connection(self, websocket:ServerConnection):
|
async def handle_connection(self, websocket:ServerConnection):
|
||||||
try:
|
try:
|
||||||
@@ -106,18 +114,14 @@ class Pipeline:
|
|||||||
def get_ws_url(self):
|
def get_ws_url(self):
|
||||||
return f"ws://{self.config.host}:{self.config.port}"
|
return f"ws://{self.config.host}:{self.config.port}"
|
||||||
|
|
||||||
|
|
||||||
def chat(self, inp:str, as_stream:bool=False):
|
def chat(self, inp:str, as_stream:bool=False):
|
||||||
"""
|
"""
|
||||||
as_stream (bool): for debug only, gets the agent to print its thoughts
|
as_stream (bool): for debug only, gets the agent to print its thoughts
|
||||||
"""
|
"""
|
||||||
inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": 3}}
|
inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": 3}}
|
||||||
|
|
||||||
if as_stream:
|
out = self.invoke(*inp, as_stream=as_stream)
|
||||||
for step in self.agent.stream(*inp, stream_mode="values"):
|
|
||||||
step["messages"][-1].pretty_print()
|
|
||||||
out = step
|
|
||||||
else:
|
|
||||||
out = self.invoke(*inp)
|
|
||||||
|
|
||||||
return out['messages'][-1].content
|
return out['messages'][-1].content
|
||||||
|
|
||||||
@@ -125,6 +129,4 @@ class Pipeline:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pipeline:Pipeline = PipelineConfig().setup()
|
pipeline:Pipeline = PipelineConfig().setup()
|
||||||
|
|
||||||
u = pipeline.chat("use the calculator tool to calculate what is 900 * 321", as_stream=True)
|
u = pipeline.chat("查查光与尘这杯茶的特点", as_stream=True)
|
||||||
print("================out")
|
|
||||||
print(u)
|
|
||||||
Reference in New Issue
Block a user