Files
lang-agent/lang_agent/pipeline.py
2025-10-10 22:20:21 +08:00

95 lines
2.9 KiB
Python

import asyncio
from dataclasses import dataclass, field
from typing import Optional, Type

import tyro
import websockets
from websockets.asyncio.server import ServerConnection
from loguru import logger
from langchain.chat_models import init_chat_model
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver

from lang_agent.config import InstantiateConfig
from lang_agent.rag.simple import SimpleRagConfig, SimpleRag
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class PipelineConfig(InstantiateConfig):
    """CLI-exposed configuration for :class:`Pipeline`.

    The attribute docstrings below double as ``tyro`` help text; ``setup()``
    (inherited from ``InstantiateConfig``) instantiates ``_target`` with this
    config.
    """

    # Class to instantiate; hidden from the CLI by SuppressFixed.
    _target: Type = field(default_factory=lambda: Pipeline)
    config_f: Optional[str] = None
    """path to config file"""
    llm_name: str = "qwen-turbo"
    """name of llm"""
    llm_provider: str = "openai"
    """provider of the llm"""
    api_key: Optional[str] = None
    """api key for llm"""
    base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    """base url; could be used to overwrite the baseurl in llm provider"""
    host: str = "0.0.0.0"
    """where am I hosted"""
    # NOTE(review): 23 is the telnet port — looks like a placeholder default;
    # confirm the intended port.
    port: int = 23
    """what is my port"""
    # NOTE: For reference
    rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig)
class Pipeline:
    """WebSocket-served chat pipeline.

    Wraps an ``init_chat_model`` LLM in a LangGraph react agent with an
    in-memory checkpointer, and serves it over a websocket: text frames are
    answered by the agent, binary frames are echoed back.
    """

    def __init__(self, config: PipelineConfig):
        self.config = config
        self.populate_module()

    def populate_module(self):
        """Build the LLM, RAG module, and react agent from ``self.config``."""
        self.llm = init_chat_model(
            model=self.config.llm_name,
            model_provider=self.config.llm_provider,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
        )
        # NOTE: placeholder for now, add graph later
        self.rag: SimpleRag = self.config.rag_config.setup()
        memory = MemorySaver()
        tools = []
        self.agent = create_react_agent(self.llm, tools, checkpointer=memory)

    def respond(self, msg: str, thread_id: str = "default") -> str:
        """Run one user turn through the agent and return the reply text.

        The react-agent graph expects a ``{"messages": [...]}`` state dict,
        and a ``configurable.thread_id`` is mandatory whenever a checkpointer
        is attached — invoking with a bare string (the previous behavior)
        raises at runtime.
        """
        result = self.agent.invoke(
            {"messages": [{"role": "user", "content": msg}]},
            config={"configurable": {"thread_id": thread_id}},
        )
        # invoke() returns the full graph state; the reply is the last message.
        return result["messages"][-1].content

    async def handle_connection(self, websocket: ServerConnection):
        """Serve one client until it disconnects.

        Binary frames are echoed back; text frames are answered by the agent.
        """
        try:
            async for message in websocket:
                if isinstance(message, bytes):
                    # NOTE: For binary, echo back.
                    await websocket.send(message)
                else:
                    # TODO: handle this better, will have system/user prompt send here
                    # One conversation thread per connection so the
                    # checkpointer keeps context across messages.
                    thread_id = str(websocket.id)
                    # agent.invoke is blocking; run it off the event loop so
                    # other connections are not stalled while the LLM responds.
                    response = await asyncio.to_thread(self.respond, message, thread_id)
                    await websocket.send(response)
        except websockets.ConnectionClosed:
            # Normal client disconnect — nothing to clean up.
            pass

    async def start_server(self):
        """Listen on ``config.host:config.port`` and serve forever."""
        async with websockets.serve(
            self.handle_connection,
            host=self.config.host,
            port=self.config.port,
            max_size=None,   # allow large messages
            max_queue=None,  # don't bound outgoing queue
        ):
            logger.info(f"listening to ws://{self.config.host}:{self.config.port}")
            await asyncio.Future()  # block until cancelled