record conversations

This commit is contained in:
2026-01-30 10:42:29 +08:00
parent 59031c9919
commit 781b3f450e

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Type from typing import Type, List
import tyro import tyro
import asyncio import asyncio
import websockets import websockets
@@ -8,7 +8,7 @@ from loguru import logger
import os import os
from langchain.chat_models import init_chat_model from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
from langchain.agents import create_agent from langchain.agents import create_agent
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
@@ -16,7 +16,7 @@ from langgraph.checkpoint.memory import MemorySaver
from lang_agent.config import InstantiateConfig, KeyConfig from lang_agent.config import InstantiateConfig, KeyConfig
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
from lang_agent.base import GraphBase from lang_agent.base import GraphBase
from lang_agent.components.conv_store import CONV_STORE
DEFAULT_PROMPT="""你是半盏新青年茶馆的服务员擅长倾听、共情且主动回应。聊天时语气自然亲切像朋友般轻松交流不使用生硬术语。能接住各种话题对疑问耐心解答对情绪及时回应避免冷场。保持积极正向不传播负面信息语言简洁易懂让对话流畅舒适。与用户User交流时必须遵循[语气与格式]、[互动策略]、[安全与边界]、[输出要求] DEFAULT_PROMPT="""你是半盏新青年茶馆的服务员擅长倾听、共情且主动回应。聊天时语气自然亲切像朋友般轻松交流不使用生硬术语。能接住各种话题对疑问耐心解答对情绪及时回应避免冷场。保持积极正向不传播负面信息语言简洁易懂让对话流畅舒适。与用户User交流时必须遵循[语气与格式]、[互动策略]、[安全与边界]、[输出要求]
[角色设定] [角色设定]
@@ -108,9 +108,9 @@ class Pipeline:
def invoke(self, *nargs, **kwargs)->str: def invoke(self, *nargs, **kwargs)->str:
out = self.graph.invoke(*nargs, **kwargs) out = self.graph.invoke(*nargs, **kwargs)
# If streaming, yield chunks from the generator # If streaming, return the raw generator (let caller handle wrapping)
if kwargs.get("as_stream"): if kwargs.get("as_stream"):
return self._stream_res(out) return out
# Non-streaming path # Non-streaming path
if kwargs.get("as_raw"): if kwargs.get("as_raw"):
@@ -128,9 +128,13 @@ class Pipeline:
assert 0, "something is wrong" assert 0, "something is wrong"
def _stream_res(self, out:list): def _stream_res(self, out:List[str | List[BaseMessage]], conv_id:str=None):
for chunk in out: for chunk in out:
yield chunk if isinstance(chunk, str):
yield chunk
else:
logger.info("logged message")
CONV_STORE.record_messages(conv_id, chunk)
async def _astream_res(self, out): async def _astream_res(self, out):
"""Async version of _stream_res for async generators.""" """Async version of _stream_res for async generators."""
@@ -142,8 +146,6 @@ class Pipeline:
as_stream (bool): if true, enable the thing to be streamable as_stream (bool): if true, enable the thing to be streamable
as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage] as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage]
""" """
# NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph
u = DEFAULT_PROMPT
device_id = "0" device_id = "0"
spl_ls = thread_id.split("_") spl_ls = thread_id.split("_")
@@ -151,10 +153,6 @@ class Pipeline:
if len(spl_ls) == 2: if len(spl_ls) == 2:
thread_id, device_id = spl_ls thread_id, device_id = spl_ls
# inp = {"messages":[SystemMessage(u),
# HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
# "device_id":device_id}}
inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id, inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
"device_id":device_id}} "device_id":device_id}}
@@ -162,7 +160,7 @@ class Pipeline:
if as_stream: if as_stream:
# Yield chunks from the generator # Yield chunks from the generator
return self._stream_res(out) return self._stream_res(out, thread_id)
else: else:
return out return out