record pipeline id in conv_store

This commit is contained in:
2026-03-04 15:37:30 +08:00
parent cf1cae51f7
commit 2f40f1c526

View File

@@ -6,10 +6,18 @@ import os
from loguru import logger from loguru import logger
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage from langchain_core.messages import (
HumanMessage,
AIMessage,
ToolMessage,
SystemMessage,
BaseMessage,
)
class MessageType(str, Enum): class MessageType(str, Enum):
"""Enum for message types in the conversation store.""" """Enum for message types in the conversation store."""
HUMAN = "human" HUMAN = "human"
AI = "ai" AI = "ai"
TOOL = "tool" TOOL = "tool"
@@ -17,9 +25,12 @@ class MessageType(str, Enum):
class BaseConvStore(ABC): class BaseConvStore(ABC):
@abstractmethod @abstractmethod
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None): def record_message_list(
self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
):
pass pass
class ConversationStore(BaseConvStore): class ConversationStore(BaseConvStore):
def __init__(self): def __init__(self):
conn_str = os.environ.get("CONN_STR") conn_str = os.environ.get("CONN_STR")
@@ -32,46 +43,53 @@ class ConversationStore(BaseConvStore):
conversation_id: str, conversation_id: str,
msg_type: MessageType, msg_type: MessageType,
content: str, content: str,
sequence: int, # the conversation number sequence: int,
pipeline_id: str = None,
): ):
with psycopg.connect(self.conn_str) as conn: with psycopg.connect(self.conn_str) as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
# DB schema only supports these columns:
# (conversation_id, message_type, content, sequence_number)
cur.execute( cur.execute(
""" """
INSERT INTO messages (conversation_id, message_type, content, sequence_number) INSERT INTO messages (conversation_id, pipeline_id, message_type, content, sequence_number)
VALUES (%s, %s, %s, %s) VALUES (%s, %s, %s, %s, %s)
""", """,
(conversation_id, msg_type.value, content, sequence), (conversation_id, pipeline_id, msg_type.value, content, sequence),
) )
def get_conv_number(self, conversation_id: str) -> int: def get_conv_number(self, conversation_id: str) -> int:
""" """
if the conversation_id does not exist, return 0 if the conversation_id does not exist, return 0
if len(conversation) = 3, it will return 3 if len(conversation) = 3, it will return 3
""" """
with psycopg.connect(self.conn_str) as conn: with psycopg.connect(self.conn_str) as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(""" cur.execute(
"""
SELECT COUNT(*) SELECT COUNT(*)
FROM messages FROM messages
WHERE conversation_id = %s WHERE conversation_id = %s
""", (conversation_id,)) """,
(conversation_id,),
)
return int(cur.fetchone()[0]) return int(cur.fetchone()[0])
def get_conversation(self, conversation_id: str) -> List[Dict]: def get_conversation(self, conversation_id: str) -> List[Dict]:
with psycopg.connect(self.conn_str) as conn: with psycopg.connect(self.conn_str) as conn:
with conn.cursor(row_factory=psycopg.rows.dict_row) as cur: with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
cur.execute(""" cur.execute(
"""
SELECT message_type, content, sequence_number, created_at SELECT message_type, content, sequence_number, created_at
FROM messages FROM messages
WHERE conversation_id = %s WHERE conversation_id = %s
ORDER BY sequence_number ASC ORDER BY sequence_number ASC
""", (conversation_id,)) """,
(conversation_id,),
)
return cur.fetchall() return cur.fetchall()
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None): def record_message_list(
self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
):
inp = [e for e in inp if not isinstance(e, SystemMessage)] inp = [e for e in inp if not isinstance(e, SystemMessage)]
curr_len = self.get_conv_number(conv_id) curr_len = self.get_conv_number(conv_id)
to_add_msg = inp[curr_len:] to_add_msg = inp[curr_len:]
@@ -80,12 +98,13 @@ class ConversationStore(BaseConvStore):
# Serialize dict/list content to JSON string # Serialize dict/list content to JSON string
if not isinstance(content, str): if not isinstance(content, str):
content = json.dumps(content, ensure_ascii=False, indent=4) content = json.dumps(content, ensure_ascii=False, indent=4)
self.add_message(conv_id, self._get_type(msg), content, curr_len + 1) self.add_message(
conv_id, self._get_type(msg), content, curr_len + 1, pipeline_id
)
curr_len += 1 curr_len += 1
return curr_len return curr_len
def _get_type(self, msg: BaseMessage) -> MessageType:
def _get_type(self, msg:BaseMessage) -> MessageType:
if isinstance(msg, HumanMessage): if isinstance(msg, HumanMessage):
return MessageType.HUMAN return MessageType.HUMAN
elif isinstance(msg, AIMessage): elif isinstance(msg, AIMessage):
@@ -100,7 +119,9 @@ class ConversationPrinter(BaseConvStore):
def __init__(self): def __init__(self):
self.id_dic = {} self.id_dic = {}
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None): def record_message_list(
self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
):
inp = [e for e in inp if not isinstance(e, SystemMessage)] inp = [e for e in inp if not isinstance(e, SystemMessage)]
curr_len = self.id_dic.get(conv_id, 0) curr_len = self.id_dic.get(conv_id, 0)
to_print_msg = inp[curr_len:] to_print_msg = inp[curr_len:]
@@ -113,9 +134,11 @@ class ConversationPrinter(BaseConvStore):
else: else:
self.id_dic[conv_id] += len(to_print_msg) self.id_dic[conv_id] += len(to_print_msg)
CONV_STORE = ConversationStore() CONV_STORE = ConversationStore()
# CONV_STORE = ConversationPrinter() # CONV_STORE = ConversationPrinter()
def use_printer(): def use_printer():
global CONV_STORE global CONV_STORE
CONV_STORE = ConversationPrinter() CONV_STORE = ConversationPrinter()