add pipeline id to conv_store

This commit is contained in:
2026-03-04 14:42:14 +08:00
parent 9d1eeaeec5
commit bb6d98c9f4

View File

@@ -4,6 +4,7 @@ from typing import List, Dict, Union
from enum import Enum from enum import Enum
import os import os
from loguru import logger from loguru import logger
from abc import ABC, abstractmethod
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage
@@ -13,7 +14,13 @@ class MessageType(str, Enum):
AI = "ai" AI = "ai"
TOOL = "tool" TOOL = "tool"
class ConversationStore:
class BaseConvStore(ABC):
@abstractmethod
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
pass
class ConversationStore(BaseConvStore):
def __init__(self): def __init__(self):
conn_str = os.environ.get("CONN_STR") conn_str = os.environ.get("CONN_STR")
if conn_str is None: if conn_str is None:
@@ -64,7 +71,7 @@ class ConversationStore:
""", (conversation_id,)) """, (conversation_id,))
return cur.fetchall() return cur.fetchall()
def record_message_list(self, conv_id:str, inp:List[BaseMessage]): def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
inp = [e for e in inp if not isinstance(e, SystemMessage)] inp = [e for e in inp if not isinstance(e, SystemMessage)]
curr_len = self.get_conv_number(conv_id) curr_len = self.get_conv_number(conv_id)
to_add_msg = inp[curr_len:] to_add_msg = inp[curr_len:]
@@ -89,11 +96,11 @@ class ConversationStore:
raise ValueError(f"Unknown message type: {type(msg)}") raise ValueError(f"Unknown message type: {type(msg)}")
class ConversationPrinter: class ConversationPrinter(BaseConvStore):
def __init__(self): def __init__(self):
self.id_dic = {} self.id_dic = {}
def record_message_list(self, conv_id:str, inp:List[BaseMessage]): def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
inp = [e for e in inp if not isinstance(e, SystemMessage)] inp = [e for e in inp if not isinstance(e, SystemMessage)]
curr_len = self.id_dic.get(conv_id, 0) curr_len = self.id_dic.get(conv_id, 0)
to_print_msg = inp[curr_len:] to_print_msg = inp[curr_len:]