diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py index 4e4f004..8ef48a3 100644 --- a/lang_agent/pipeline.py +++ b/lang_agent/pipeline.py @@ -73,8 +73,11 @@ class Pipeline: tools = self.tool_manager.get_tools() self.agent = create_react_agent(self.llm, tools, checkpointer=memory) - def respond(self, msg:str | List[SystemMessage, HumanMessage]): - return self.agent.invoke(msg) + # def respond(self, msg:str | List[SystemMessage, HumanMessage]): + # return self.agent.invoke(msg) + + def invoke(self, *nargs, **kwargs): + return self.agent.invoke(*nargs, **kwargs) async def handle_connection(self, websocket:ServerConnection): try: @@ -84,7 +87,7 @@ class Pipeline: await websocket.send(message) else: # TODO: handle this better, will have system/user prompt send here - response = self.respond(message) + response = self.invoke(message) await websocket.send(response) except websockets.ConnectionClosed: pass