diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py index 0fabac4..4e4f004 100644 --- a/lang_agent/pipeline.py +++ b/lang_agent/pipeline.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Type +from typing import Type, List import tyro import asyncio import websockets @@ -8,6 +8,7 @@ from loguru import logger import os from langchain.chat_models import init_chat_model +from langchain_core.messages import SystemMessage, HumanMessage from langgraph.prebuilt import create_react_agent from langgraph.checkpoint.memory import MemorySaver @@ -72,7 +73,7 @@ class Pipeline: tools = self.tool_manager.get_tools() self.agent = create_react_agent(self.llm, tools, checkpointer=memory) - def respond(self, msg:str): + def respond(self, msg:str | List[SystemMessage, HumanMessage]): return self.agent.invoke(msg) async def handle_connection(self, websocket:ServerConnection):