record conversations

commit 781b3f450e
parent 59031c9919
2026-01-30 10:42:29 +08:00


@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Type
+from typing import Type, List
 import tyro
 import asyncio
 import websockets
@@ -8,7 +8,7 @@ from loguru import logger
 import os
 from langchain.chat_models import init_chat_model
-from langchain_core.messages import SystemMessage, HumanMessage
+from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
 from langchain.agents import create_agent
 from langgraph.checkpoint.memory import MemorySaver
@@ -16,7 +16,7 @@ from langgraph.checkpoint.memory import MemorySaver
 from lang_agent.config import InstantiateConfig, KeyConfig
 from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
 from lang_agent.base import GraphBase
+from lang_agent.components.conv_store import CONV_STORE
 DEFAULT_PROMPT="""You are a server at the 半盏新青年茶馆 teahouse, skilled at listening, empathizing, and responding proactively. Chat in a natural, friendly tone, as relaxed as talking with a friend, and avoid stiff jargon. Pick up whatever topic comes: answer questions patiently, respond to emotions promptly, and never let the conversation go cold. Stay positive, do not spread negativity, and keep your language simple and clear so the dialogue flows comfortably. When talking with the User, you must follow [Tone & Format], [Interaction Strategy], [Safety & Boundaries], and [Output Requirements].
 [Role Setting]
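
CONV_STORE is only imported by this commit, so its interface never appears in the diff. For orientation, a minimal in-memory sketch of the one method the change relies on, record_messages(conv_id, messages), could look like the following; the class layout and the get_messages helper are assumptions, not code from lang_agent.

    # Minimal sketch of the store interface assumed by this commit; the real
    # lang_agent.components.conv_store implementation is not shown in the diff.
    from collections import defaultdict
    from typing import Dict, List

    from langchain_core.messages import BaseMessage

    class ConvStore:
        """In-memory conversation log keyed by conversation id."""

        def __init__(self) -> None:
            self._conversations: Dict[str, List[BaseMessage]] = defaultdict(list)

        def record_messages(self, conv_id: str, messages: List[BaseMessage]) -> None:
            # Append this turn's messages to the conversation history.
            self._conversations[conv_id].extend(messages)

        def get_messages(self, conv_id: str) -> List[BaseMessage]:  # assumed helper
            return list(self._conversations[conv_id])

    CONV_STORE = ConvStore()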
@@ -108,9 +108,9 @@ class Pipeline:
     def invoke(self, *nargs, **kwargs)->str:
         out = self.graph.invoke(*nargs, **kwargs)
-        # If streaming, yield chunks from the generator
+        # If streaming, return the raw generator (let caller handle wrapping)
         if kwargs.get("as_stream"):
-            return self._stream_res(out)
+            return out
         # Non-streaming path
         if kwargs.get("as_raw"):
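
The intent of this hunk is to move stream wrapping out of invoke: the raw generator is returned so the outer entry point, which knows the thread id, can do the wrapping (see the last hunk below). A hypothetical caller-side sketch, with the run name and signature invented for illustration:

    # Illustrative only: invoke() hands back the raw graph generator, and the
    # outer entry point wraps it where the conversation id is known.
    def run(pipeline, text: str, thread_id: str):
        out = pipeline.invoke(text, as_stream=True)   # raw generator, unwrapped
        return pipeline._stream_res(out, thread_id)   # wrapping + recording here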
@@ -128,9 +128,13 @@ class Pipeline:
         assert 0, "something is wrong"

-    def _stream_res(self, out:list):
+    def _stream_res(self, out:List[str | List[BaseMessage]], conv_id:str=None):
         for chunk in out:
             if isinstance(chunk, str):
                 yield chunk
+            else:
+                logger.info("logged message")
+                CONV_STORE.record_messages(conv_id, chunk)

     async def _astream_res(self, out):
         """Async version of _stream_res for async generators."""
@@ -142,8 +146,6 @@ class Pipeline:
             as_stream (bool): if true, return a streaming generator instead of the final string
             as_raw (bool): return the full dialogue as List[SystemMessage, HumanMessage, ToolMessage]
         """
-        # NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph
         u = DEFAULT_PROMPT
         device_id = "0"
         spl_ls = thread_id.split("_")
@@ -151,10 +153,6 @@ class Pipeline:
         if len(spl_ls) == 2:
             thread_id, device_id = spl_ls
-        # inp = {"messages":[SystemMessage(u),
-        #        HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
-        #        "device_id":device_id}}
         inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
                "device_id":device_id}}
@@ -162,7 +160,7 @@ class Pipeline:
         if as_stream:
             # Yield chunks from the generator
-            return self._stream_res(out)
+            return self._stream_res(out, thread_id)
         else:
             return out
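
Putting the pieces together: the conv_id that keys the recorded messages is the parsed thread_id, so every streamed turn from the same thread lands in one conversation record. Assuming the hypothetical get_messages helper from the store sketch above, reading a history back might look like:

    # Hypothetical read-back; get_messages is the helper assumed in the store
    # sketch above and is not part of this commit.
    for msg in CONV_STORE.get_messages("user42"):
        print(type(msg).__name__, ":", msg.content)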