more types
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Type
|
from typing import Type, List
|
||||||
import tyro
|
import tyro
|
||||||
import asyncio
|
import asyncio
|
||||||
import websockets
|
import websockets
|
||||||
@@ -8,6 +8,7 @@ from loguru import logger
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
|
from langchain_core.messages import SystemMessage, HumanMessage
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langgraph.prebuilt import create_react_agent
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
|
||||||
@@ -72,7 +73,7 @@ class Pipeline:
|
|||||||
tools = self.tool_manager.get_tools()
|
tools = self.tool_manager.get_tools()
|
||||||
self.agent = create_react_agent(self.llm, tools, checkpointer=memory)
|
self.agent = create_react_agent(self.llm, tools, checkpointer=memory)
|
||||||
|
|
||||||
def respond(self, msg:str):
|
def respond(self, msg:str | List[SystemMessage, HumanMessage]):
|
||||||
return self.agent.invoke(msg)
|
return self.agent.invoke(msg)
|
||||||
|
|
||||||
async def handle_connection(self, websocket:ServerConnection):
|
async def handle_connection(self, websocket:ServerConnection):
|
||||||
|
|||||||
Reference in New Issue
Block a user