more types

This commit is contained in:
2025-10-15 13:30:25 +08:00
parent bc898e526f
commit 2b22cbc765

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Type from typing import Type, List
import tyro import tyro
import asyncio import asyncio
import websockets import websockets
@@ -8,6 +8,7 @@ from loguru import logger
import os import os
from langchain.chat_models import init_chat_model from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.prebuilt import create_react_agent from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
@@ -72,7 +73,7 @@ class Pipeline:
tools = self.tool_manager.get_tools() tools = self.tool_manager.get_tools()
self.agent = create_react_agent(self.llm, tools, checkpointer=memory) self.agent = create_react_agent(self.llm, tools, checkpointer=memory)
def respond(self, msg:str): def respond(self, msg:str | List[SystemMessage, HumanMessage]):
return self.agent.invoke(msg) return self.agent.invoke(msg)
async def handle_connection(self, websocket:ServerConnection): async def handle_connection(self, websocket:ServerConnection):