add pipeline id to conv_store

This commit is contained in:
2026-03-04 14:42:14 +08:00
parent 9d1eeaeec5
commit bb6d98c9f4

View File

@@ -4,6 +4,7 @@ from typing import List, Dict, Union
from enum import Enum
import os
from loguru import logger
from abc import ABC, abstractmethod
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage
@@ -13,7 +14,13 @@ class MessageType(str, Enum):
AI = "ai"
TOOL = "tool"
class ConversationStore:
class BaseConvStore(ABC):
@abstractmethod
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
pass
class ConversationStore(BaseConvStore):
def __init__(self):
conn_str = os.environ.get("CONN_STR")
if conn_str is None:
@@ -64,7 +71,7 @@ class ConversationStore:
""", (conversation_id,))
return cur.fetchall()
def record_message_list(self, conv_id:str, inp:List[BaseMessage]):
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
inp = [e for e in inp if not isinstance(e, SystemMessage)]
curr_len = self.get_conv_number(conv_id)
to_add_msg = inp[curr_len:]
@@ -89,11 +96,11 @@ class ConversationStore:
raise ValueError(f"Unknown message type: {type(msg)}")
class ConversationPrinter:
class ConversationPrinter(BaseConvStore):
def __init__(self):
self.id_dic = {}
def record_message_list(self, conv_id:str, inp:List[BaseMessage]):
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
inp = [e for e in inp if not isinstance(e, SystemMessage)]
curr_len = self.id_dic.get(conv_id, 0)
to_print_msg = inp[curr_len:]