better handling

This commit is contained in:
2025-10-10 17:20:28 +08:00
parent 3e750bd98a
commit 076503ee3b

View File

@@ -3,6 +3,7 @@ from typing import Type
import tyro import tyro
import asyncio import asyncio
import websockets import websockets
from websockets.asyncio.server import ServerConnection
from loguru import logger from loguru import logger
from langchain.chat_models import init_chat_model from langchain.chat_models import init_chat_model
@@ -42,6 +43,8 @@ class PipelineConfig(InstantiateConfig):
class Pipeline: class Pipeline:
def __init__(self, config:PipelineConfig): def __init__(self, config:PipelineConfig):
self.config = config self.config = config
self.populate_module()
def populate_module(self): def populate_module(self):
self.llm = init_chat_model(model=self.config.llm_name, self.llm = init_chat_model(model=self.config.llm_name,
@@ -49,11 +52,23 @@ class Pipeline:
api_key=self.config.api_key, api_key=self.config.api_key,
base_url=self.config.base_url) base_url=self.config.base_url)
self.agent = self.llm ## NOTE: placeholder for now self.agent = self.llm # NOTE: placeholder for now, add graph later
def respond(self, msg:str):
return self.agent.invoke(msg)
async def handle_connection(self, inp:str): async def handle_connection(self, websocket:ServerConnection):
return "hello" try:
async for message in websocket:
if isinstance(message, bytes):
#NOTE: For binary, echo back.
await websocket.send(message)
else:
# TODO: handle this better, will have system/user prompt send here
response = self.respond(message)
await websocket.send(response)
except websockets.ConnectionClosed:
pass
async def start_server(self): async def start_server(self):