better handling
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Type
|
||||
import tyro
|
||||
import asyncio
|
||||
import websockets
|
||||
from websockets.asyncio.server import ServerConnection
|
||||
from loguru import logger
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
@@ -42,6 +43,8 @@ class PipelineConfig(InstantiateConfig):
|
||||
class Pipeline:
|
||||
def __init__(self, config:PipelineConfig):
|
||||
self.config = config
|
||||
|
||||
self.populate_module()
|
||||
|
||||
def populate_module(self):
|
||||
self.llm = init_chat_model(model=self.config.llm_name,
|
||||
@@ -49,11 +52,23 @@ class Pipeline:
|
||||
api_key=self.config.api_key,
|
||||
base_url=self.config.base_url)
|
||||
|
||||
self.agent = self.llm ## NOTE: placeholder for now
|
||||
self.agent = self.llm # NOTE: placeholder for now, add graph later
|
||||
|
||||
def respond(self, msg:str):
|
||||
return self.agent.invoke(msg)
|
||||
|
||||
async def handle_connection(self, inp:str):
|
||||
return "hello"
|
||||
async def handle_connection(self, websocket:ServerConnection):
|
||||
try:
|
||||
async for message in websocket:
|
||||
if isinstance(message, bytes):
|
||||
#NOTE: For binary, echo back.
|
||||
await websocket.send(message)
|
||||
else:
|
||||
# TODO: handle this better, will have system/user prompt send here
|
||||
response = self.respond(message)
|
||||
await websocket.send(response)
|
||||
except websockets.ConnectionClosed:
|
||||
pass
|
||||
|
||||
|
||||
async def start_server(self):
|
||||
|
||||
Reference in New Issue
Block a user