record conversations
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Type
|
||||
from typing import Type, List
|
||||
import tyro
|
||||
import asyncio
|
||||
import websockets
|
||||
@@ -8,7 +8,7 @@ from loguru import logger
|
||||
import os
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import SystemMessage, HumanMessage
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
@@ -16,7 +16,7 @@ from langgraph.checkpoint.memory import MemorySaver
|
||||
from lang_agent.config import InstantiateConfig, KeyConfig
|
||||
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
|
||||
from lang_agent.base import GraphBase
|
||||
|
||||
from lang_agent.components.conv_store import CONV_STORE
|
||||
|
||||
DEFAULT_PROMPT="""你是半盏新青年茶馆的服务员,擅长倾听、共情且主动回应。聊天时语气自然亲切,像朋友般轻松交流,不使用生硬术语。能接住各种话题,对疑问耐心解答,对情绪及时回应,避免冷场。保持积极正向,不传播负面信息,语言简洁易懂,让对话流畅舒适。与用户(User)交流时必须遵循[语气与格式]、[互动策略]、[安全与边界]、[输出要求]
|
||||
[角色设定]
|
||||
@@ -108,9 +108,9 @@ class Pipeline:
|
||||
def invoke(self, *nargs, **kwargs)->str:
|
||||
out = self.graph.invoke(*nargs, **kwargs)
|
||||
|
||||
# If streaming, yield chunks from the generator
|
||||
# If streaming, return the raw generator (let caller handle wrapping)
|
||||
if kwargs.get("as_stream"):
|
||||
return self._stream_res(out)
|
||||
return out
|
||||
|
||||
# Non-streaming path
|
||||
if kwargs.get("as_raw"):
|
||||
@@ -128,9 +128,13 @@ class Pipeline:
|
||||
assert 0, "something is wrong"
|
||||
|
||||
|
||||
def _stream_res(self, out:list):
|
||||
def _stream_res(self, out:List[str | List[BaseMessage]], conv_id:str=None):
|
||||
for chunk in out:
|
||||
yield chunk
|
||||
if isinstance(chunk, str):
|
||||
yield chunk
|
||||
else:
|
||||
logger.info("logged message")
|
||||
CONV_STORE.record_messages(conv_id, chunk)
|
||||
|
||||
async def _astream_res(self, out):
|
||||
"""Async version of _stream_res for async generators."""
|
||||
@@ -142,8 +146,6 @@ class Pipeline:
|
||||
as_stream (bool): if true, enable the thing to be streamable
|
||||
as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage]
|
||||
"""
|
||||
# NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph
|
||||
u = DEFAULT_PROMPT
|
||||
|
||||
device_id = "0"
|
||||
spl_ls = thread_id.split("_")
|
||||
@@ -151,10 +153,6 @@ class Pipeline:
|
||||
if len(spl_ls) == 2:
|
||||
thread_id, device_id = spl_ls
|
||||
|
||||
# inp = {"messages":[SystemMessage(u),
|
||||
# HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
|
||||
# "device_id":device_id}}
|
||||
|
||||
inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
|
||||
"device_id":device_id}}
|
||||
|
||||
@@ -162,7 +160,7 @@ class Pipeline:
|
||||
|
||||
if as_stream:
|
||||
# Yield chunks from the generator
|
||||
return self._stream_res(out)
|
||||
return self._stream_res(out, thread_id)
|
||||
else:
|
||||
return out
|
||||
|
||||
|
||||
Reference in New Issue
Block a user