diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py index 93b0ce0..1924b39 100644 --- a/lang_agent/pipeline.py +++ b/lang_agent/pipeline.py @@ -75,8 +75,16 @@ class Pipeline: self.agent = create_react_agent(self.llm, tools, checkpointer=memory) - def invoke(self, *nargs, **kwargs): - return self.agent.invoke(*nargs, **kwargs) + def invoke(self, *nargs, as_stream:bool=False, **kwargs): + + if as_stream: + for step in self.agent.stream(*nargs, stream_mode="values", **kwargs): + step["messages"][-1].pretty_print() + out = step + else: + out = self.agent.invoke(*nargs, **kwargs) + + return out async def handle_connection(self, websocket:ServerConnection): try: @@ -106,18 +114,14 @@ class Pipeline: def get_ws_url(self): return f"ws://{self.config.host}:{self.config.port}" + def chat(self, inp:str, as_stream:bool=False): """ as_stream (bool): for debug only, gets the agent to print its thoughts """ inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": 3}} - if as_stream: - for step in self.agent.stream(*inp, stream_mode="values"): - step["messages"][-1].pretty_print() - out = step - else: - out = self.invoke(*inp) + out = self.invoke(*inp, as_stream=as_stream) return out['messages'][-1].content @@ -125,6 +129,4 @@ class Pipeline: if __name__ == "__main__": pipeline:Pipeline = PipelineConfig().setup() - u = pipeline.chat("use the calculator tool to calculate what is 900 * 321", as_stream=True) - print("================out") - print(u) \ No newline at end of file + u = pipeline.chat("查查光与尘这杯茶的特点", as_stream=True) \ No newline at end of file