Files
lang-agent/lang_agent/pipeline.py
2025-10-22 18:19:24 +08:00

137 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from dataclasses import dataclass, field
from typing import Type
import tyro
import asyncio
import websockets
from websockets.asyncio.server import ServerConnection
from loguru import logger
import os
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage
from langchain.agents import create_agent
from langgraph.checkpoint.memory import MemorySaver
from lang_agent.config import InstantiateConfig, KeyConfig
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
from lang_agent.base import GraphBase
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class PipelineConfig(KeyConfig):
_target: Type = field(default_factory=lambda: Pipeline)
config_f: str = None
"""path to config file"""
llm_name: str = None
"""name of llm; use default for qwen-plus"""
llm_provider:str = None
"""provider of the llm; use default for openai"""
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
"""base url; could be used to overwrite the baseurl in llm provider"""
host:str = "0.0.0.0"
"""where am I hosted"""
port:int = 23
"""what is my port"""
# graph_config: ReactGraphConfig = field(default_factory=ReactGraphConfig)
# graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig)
graph_config: AnnotatedGraph = field(default_factory=RoutingConfig)
class Pipeline:
def __init__(self, config:PipelineConfig):
self.config = config
self.populate_module()
def populate_module(self):
if self.config.llm_name is None:
logger.info(f"setting llm_provider to default")
self.config.llm_name = "qwen-turbo"
self.config.llm_provider = "openai"
else:
self.config.graph_config.llm_name = self.config.llm_name
self.config.graph_config.llm_provider = self.config.llm_provider
self.config.graph_config.base_url = self.config.base_url if self.config.base_url is not None else self.config.graph_config.base_url
self.config.graph_config.api_key = self.config.api_key
self.graph:GraphBase = self.config.graph_config.setup()
def show_graph(self):
if hasattr(self.graph, "show_graph"):
logger.info("showing graph")
self.graph.show_graph()
else:
logger.info(f"show graph not supported for {type(self.graph)}")
def invoke(self, *nargs, **kwargs):
return self.graph.invoke(*nargs, **kwargs)
async def handle_connection(self, websocket:ServerConnection):
try:
async for message in websocket:
if isinstance(message, bytes):
#NOTE: For binary, echo back.
await websocket.send(message)
else:
# TODO: handle this better, will have system/user prompt send here
response = self.invoke(message)
await websocket.send(response)
except websockets.ConnectionClosed:
pass
async def start_server(self):
async with websockets.serve(
self.handle_connection,
host=self.config.host,
port=self.config.port,
max_size=None, # allow large messages
max_queue=None, # don't bound outgoing queue
):
logger.info(f"listening to {self.get_ws_url}")
await asyncio.Future()
def get_ws_url(self):
return f"ws://{self.config.host}:{self.config.port}"
def chat(self, inp:str, as_stream:bool=False):
u = """
你叫小盏,是一个点餐助手,你的回复要简洁明了,不需要给用户提供选择。对话过程中不要出现提示用户下一步的操作,用可爱的语气进行交流,根据用户的语言使用对应的语言回答
用户需要点餐时,准确调用 MCP 工具套件或相关的 REST 接口,严格按照创建购物车、加菜、查购物车、确认订单的完整业务流程来操作,
不能出现流程跳步或工具用错的情况,首先要清楚用户当前操作处于哪个业务阶段,以及对应的该调用哪个 MCP 工具或 REST 接口。
用户说要开始点餐,就创建购物车会话,优先调用 start_cart 这个 MCP 工具,调用后得返回 uuid而且这个阶段的数据只是临时生成
不会写入数据库,也不会缓存。我们只有(凉菜、热菜、汤类、主食、特调茶品、红茶、生普、黑普/熟普、花茶、乌龙茶、热煮茶、冷翠茶),用户说出其他类型的时候提醒用户
用户说要添加、菜品、饮品、食品或有购买欲的的时候先调用get_resources (resource_type=dishes)
先调get_resources (resource_type=dishes)查询是否有所需菜品,没有的话提醒用户错误
用再调用 add_cart_item 这个 MCP 工具将餐品添加到之前uuid下的购物车中要是没有的话
就创建购物车。用户没说数量,默认是 1 份,但得跟用户确认一下。添加后的数据只写入缓存,有效期是 2 小时同时计算total_price并且保留两位小数。
当用户想查看购物车内容,比如 “看看我点了什么”,这时候调用 cart_items (uuid)。查看的时候优先读取缓存里的数据,
这是支付前的情况;如果缓存不存在或者已经被清除,就会返回数据库中 status=1 的持久化记录,这一般是支付后的情况,而且要告诉用户当前数据是来自缓存还是数据库。
用户说 “确认订单” 或者 “我要付款” 时,就到了生成订单与支付码的阶段,要调用 confirm_cart (uuid, callback_url)。
调用之前,得先通过 cart_items (uuid) 确认购物车里有内容,
调用后会返回 order_id、out_trade_no 和 code_url这时候购物车的内容还在缓存里没落到数据库。支付成功后的购物车持久化
正常情况下是由微信支付的回调触发的,会更新支付状态、订单状态,把购物车内容落库到 ShoppingCart 表status 设为 1同时清除缓存。
用户想查之前点的单,调用 get_resources (resource_type=shopping_carts)
返回数据库中 status=1 并且时间是最新的数据。
"""
inp = {"messages":[SystemMessage(u),
HumanMessage(inp)]}, {"configurable": {"thread_id": 3}}
out = self.invoke(*inp, as_stream=as_stream)
return out['messages'][-1].content