diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py index d195f50..cb8b62a 100644 --- a/lang_agent/pipeline.py +++ b/lang_agent/pipeline.py @@ -3,6 +3,7 @@ from typing import Type import tyro import asyncio import websockets +from websockets.asyncio.server import ServerConnection from loguru import logger from langchain.chat_models import init_chat_model @@ -42,6 +43,8 @@ class PipelineConfig(InstantiateConfig): class Pipeline: def __init__(self, config:PipelineConfig): self.config = config + + self.populate_module() def populate_module(self): self.llm = init_chat_model(model=self.config.llm_name, @@ -49,11 +52,23 @@ class Pipeline: api_key=self.config.api_key, base_url=self.config.base_url) - self.agent = self.llm ## NOTE: placeholder for now + self.agent = self.llm # NOTE: placeholder for now, add graph later + def respond(self, msg:str): + return self.agent.invoke(msg) - async def handle_connection(self, inp:str): - return "hello" + async def handle_connection(self, websocket:ServerConnection): + try: + async for message in websocket: + if isinstance(message, bytes): + #NOTE: For binary, echo back. + await websocket.send(message) + else: + # TODO: handle this better, will have system/user prompt send here + response = self.respond(message) + await websocket.send(response) + except websockets.ConnectionClosed: + pass async def start_server(self):