record conversations
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Type
|
from typing import Type, List
|
||||||
import tyro
|
import tyro
|
||||||
import asyncio
|
import asyncio
|
||||||
import websockets
|
import websockets
|
||||||
@@ -8,7 +8,7 @@ from loguru import logger
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage
|
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
||||||
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
@@ -16,7 +16,7 @@ from langgraph.checkpoint.memory import MemorySaver
|
|||||||
from lang_agent.config import InstantiateConfig, KeyConfig
|
from lang_agent.config import InstantiateConfig, KeyConfig
|
||||||
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
|
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import GraphBase
|
||||||
|
from lang_agent.components.conv_store import CONV_STORE
|
||||||
|
|
||||||
DEFAULT_PROMPT="""你是半盏新青年茶馆的服务员,擅长倾听、共情且主动回应。聊天时语气自然亲切,像朋友般轻松交流,不使用生硬术语。能接住各种话题,对疑问耐心解答,对情绪及时回应,避免冷场。保持积极正向,不传播负面信息,语言简洁易懂,让对话流畅舒适。与用户(User)交流时必须遵循[语气与格式]、[互动策略]、[安全与边界]、[输出要求]
|
DEFAULT_PROMPT="""你是半盏新青年茶馆的服务员,擅长倾听、共情且主动回应。聊天时语气自然亲切,像朋友般轻松交流,不使用生硬术语。能接住各种话题,对疑问耐心解答,对情绪及时回应,避免冷场。保持积极正向,不传播负面信息,语言简洁易懂,让对话流畅舒适。与用户(User)交流时必须遵循[语气与格式]、[互动策略]、[安全与边界]、[输出要求]
|
||||||
[角色设定]
|
[角色设定]
|
||||||
@@ -108,9 +108,9 @@ class Pipeline:
|
|||||||
def invoke(self, *nargs, **kwargs)->str:
|
def invoke(self, *nargs, **kwargs)->str:
|
||||||
out = self.graph.invoke(*nargs, **kwargs)
|
out = self.graph.invoke(*nargs, **kwargs)
|
||||||
|
|
||||||
# If streaming, yield chunks from the generator
|
# If streaming, return the raw generator (let caller handle wrapping)
|
||||||
if kwargs.get("as_stream"):
|
if kwargs.get("as_stream"):
|
||||||
return self._stream_res(out)
|
return out
|
||||||
|
|
||||||
# Non-streaming path
|
# Non-streaming path
|
||||||
if kwargs.get("as_raw"):
|
if kwargs.get("as_raw"):
|
||||||
@@ -128,9 +128,13 @@ class Pipeline:
|
|||||||
assert 0, "something is wrong"
|
assert 0, "something is wrong"
|
||||||
|
|
||||||
|
|
||||||
def _stream_res(self, out:list):
|
def _stream_res(self, out:List[str | List[BaseMessage]], conv_id:str=None):
|
||||||
for chunk in out:
|
for chunk in out:
|
||||||
|
if isinstance(chunk, str):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
else:
|
||||||
|
logger.info("logged message")
|
||||||
|
CONV_STORE.record_messages(conv_id, chunk)
|
||||||
|
|
||||||
async def _astream_res(self, out):
|
async def _astream_res(self, out):
|
||||||
"""Async version of _stream_res for async generators."""
|
"""Async version of _stream_res for async generators."""
|
||||||
@@ -142,8 +146,6 @@ class Pipeline:
|
|||||||
as_stream (bool): if true, enable the thing to be streamable
|
as_stream (bool): if true, enable the thing to be streamable
|
||||||
as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage]
|
as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage]
|
||||||
"""
|
"""
|
||||||
# NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph
|
|
||||||
u = DEFAULT_PROMPT
|
|
||||||
|
|
||||||
device_id = "0"
|
device_id = "0"
|
||||||
spl_ls = thread_id.split("_")
|
spl_ls = thread_id.split("_")
|
||||||
@@ -151,10 +153,6 @@ class Pipeline:
|
|||||||
if len(spl_ls) == 2:
|
if len(spl_ls) == 2:
|
||||||
thread_id, device_id = spl_ls
|
thread_id, device_id = spl_ls
|
||||||
|
|
||||||
# inp = {"messages":[SystemMessage(u),
|
|
||||||
# HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
|
|
||||||
# "device_id":device_id}}
|
|
||||||
|
|
||||||
inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
|
inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
|
||||||
"device_id":device_id}}
|
"device_id":device_id}}
|
||||||
|
|
||||||
@@ -162,7 +160,7 @@ class Pipeline:
|
|||||||
|
|
||||||
if as_stream:
|
if as_stream:
|
||||||
# Yield chunks from the generator
|
# Yield chunks from the generator
|
||||||
return self._stream_res(out)
|
return self._stream_res(out, thread_id)
|
||||||
else:
|
else:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user