188 lines
6.3 KiB
Python
188 lines
6.3 KiB
Python
import asyncio
import os
from dataclasses import dataclass, field
from typing import Optional, Type

import tyro
import websockets
from websockets.asyncio.server import ServerConnection
from loguru import logger

from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage

from langchain.agents import create_agent
from langgraph.checkpoint.memory import MemorySaver

from lang_agent.config import InstantiateConfig, KeyConfig
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
from lang_agent.base import GraphBase
|
|
|
|
|
|
# System prompt used as the first message of every chat request.
# The literal ends with one empty line; stray "|" separator lines that had
# leaked into the string were removed so they are no longer sent to the model.
DEFAULT_PROMPT = """you are a helpful helper

"""
|
|
|
|
|
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class PipelineConfig(KeyConfig):
    """Settings for building a Pipeline: LLM selection, endpoint and graph config.

    The string literal under each field is that field's description; the
    existing ones are kept verbatim.
    """

    _target: Type = field(default_factory=lambda: Pipeline)

    # Optional because the default is None (no config file given).
    config_f: Optional[str] = None
    """path to config file"""

    llm_name: str = "qwen-plus"
    """name of llm; use default for qwen-plus"""

    llm_provider: str = "openai"
    """provider of the llm; use default for openai"""

    base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    """base url; could be used to overwrite the baseurl in llm provider"""

    host: str = "0.0.0.0"
    """where am I hosted"""

    port: int = 23
    """what is my port"""

    # graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig)
    graph_config: AnnotatedGraph = field(default_factory=RoutingConfig)
    """config of the graph to build; defaults to RoutingConfig"""
|
|
|
|
|
|
|
|
|
|
class Pipeline:
    """Wrap the graph built from ``config.graph_config`` behind chat-style calls.

    ``invoke`` / ``ainvoke`` forward their arguments to the graph unchanged and
    then unwrap the result; ``chat`` / ``achat`` first turn a plain user string
    into the ``(messages, config)`` pair the graph is called with.
    """

    def __init__(self, config: "PipelineConfig"):
        self.config = config

        self.populate_module()

    def populate_module(self):
        """Push the LLM settings onto the graph config, then build the graph."""
        cfg = self.config
        if cfg.llm_name is None:
            logger.info("setting llm_provider to default")
            cfg.llm_name = "qwen-turbo"
            cfg.llm_provider = "openai"
            # NOTE(review): these fallbacks are stored on the pipeline config
            # only; graph_config keeps whatever it already had in this branch.
            # Confirm whether they should be forwarded as well.
        else:
            graph_cfg = cfg.graph_config
            graph_cfg.llm_name = cfg.llm_name
            graph_cfg.llm_provider = cfg.llm_provider
            # Only override the graph's base url when one was actually given.
            if cfg.base_url is not None:
                graph_cfg.base_url = cfg.base_url
            graph_cfg.api_key = cfg.api_key

        self.graph: GraphBase = cfg.graph_config.setup()

    def show_graph(self):
        """Ask the graph to display itself if it offers ``show_graph``."""
        if hasattr(self.graph, "show_graph"):
            logger.info("showing graph")
            self.graph.show_graph()
        else:
            logger.info(f"show graph not supported for {type(self.graph)}")

    @staticmethod
    def _postprocess(out, as_raw=False):
        """Reduce a non-streaming graph result to the value handed to callers.

        Returns ``out`` untouched when ``as_raw`` is truthy or ``out`` is a
        ``str``; the ``content`` of the last element for a list; the
        ``content`` of a single System/Human message.

        Raises:
            IndexError: ``out`` is an empty list.
            AssertionError: ``out`` has none of the supported types.
        """
        if as_raw:
            return out

        if isinstance(out, str):
            return out

        if isinstance(out, list):
            if not out:
                raise IndexError("graph returned an empty message list")
            return out[-1].content

        if isinstance(out, (SystemMessage, HumanMessage)):
            return out.content

        # Raised explicitly (not `assert 0`) so the check survives `python -O`
        # while callers still see the same exception type and message.
        raise AssertionError("something is wrong")

    @staticmethod
    def _split_thread_id(thread_id):
        """Split ``"<thread>_<device>"`` into ``(thread_id, device_id)``.

        A value without an underscore is returned as-is with device id "0".

        Raises:
            AssertionError: more than one underscore is present.
        """
        parts = thread_id.split("_")
        if len(parts) > 2:
            # Explicit raise keeps the original exception type under `-O`.
            raise AssertionError("something wrong!")
        if len(parts) == 2:
            return parts[0], parts[1]
        return thread_id, "0"

    @staticmethod
    def _build_inputs(inp, thread_id, device_id):
        """Build the ``(state, config)`` pair passed positionally to the graph."""
        # NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph
        state = {"messages": [SystemMessage(DEFAULT_PROMPT), HumanMessage(inp)]}
        run_config = {"configurable": {"thread_id": thread_id,
                                       "device_id": device_id}}
        return state, run_config

    def invoke(self, *nargs, **kwargs):
        """Call the graph synchronously and unwrap its result.

        All positional and keyword arguments are forwarded to
        ``self.graph.invoke``. With a truthy ``as_stream`` keyword a generator
        over the graph's chunks is returned; otherwise see ``_postprocess``.
        """
        out = self.graph.invoke(*nargs, **kwargs)

        # If streaming, yield chunks from the generator
        if kwargs.get("as_stream"):
            return self._stream_res(out)

        # Non-streaming path
        return self._postprocess(out, as_raw=kwargs.get("as_raw"))

    def _stream_res(self, out):
        """Yield every chunk of the iterable ``out``."""
        for chunk in out:
            yield chunk

    async def _astream_res(self, out):
        """Async version of _stream_res for async generators."""
        async for chunk in out:
            yield chunk

    def chat(self, inp: str, as_stream: bool = False, as_raw: bool = False, thread_id: str = '3'):
        """
        Send one user message through the graph.

        inp (str): the user message
        as_stream (bool): if true, return a generator of chunks
        as_raw (bool): return full dialogue of List[SystemMessage, HumanMessage, ToolMessage]
        thread_id (str): "<thread>" or "<thread>_<device>"; device defaults to "0"
        """
        thread_id, device_id = self._split_thread_id(thread_id)
        graph_args = self._build_inputs(inp, thread_id, device_id)

        # invoke() already returns a generator when streaming, so its result
        # is handed back directly in both modes.
        return self.invoke(*graph_args, as_stream=as_stream, as_raw=as_raw)

    async def ainvoke(self, *nargs, **kwargs):
        """Async version of invoke using LangGraph's native async support."""
        out = await self.graph.ainvoke(*nargs, **kwargs)

        # If streaming, return async generator
        if kwargs.get("as_stream"):
            return self._astream_res(out)

        # Non-streaming path
        return self._postprocess(out, as_raw=kwargs.get("as_raw"))

    async def achat(self, inp: str, as_stream: bool = False, as_raw: bool = False, thread_id: str = '3'):
        """
        Async version of chat using LangGraph's native async support.

        inp (str): the user message
        as_stream (bool): if true, return an async generator of chunks
        as_raw (bool): return full dialogue of List[SystemMessage, HumanMessage, ToolMessage]
        thread_id (str): "<thread>" or "<thread>_<device>"; device defaults to "0"
        """
        thread_id, device_id = self._split_thread_id(thread_id)
        # Debug banner showing the device id used for this request.
        print(f"\033[32m====================DEVICE ID: {device_id}=============================\033[0m")

        graph_args = self._build_inputs(inp, thread_id, device_id)

        # ainvoke() already returns an async generator when streaming, so its
        # result is handed back directly in both modes.
        return await self.ainvoke(*graph_args, as_stream=bool(as_stream), as_raw=as_raw)