record pipeline id in conv_store

This commit is contained in:
2026-03-04 15:37:30 +08:00
parent cf1cae51f7
commit 2f40f1c526

View File

@@ -6,10 +6,18 @@ import os
from loguru import logger
from abc import ABC, abstractmethod
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage
from langchain_core.messages import (
HumanMessage,
AIMessage,
ToolMessage,
SystemMessage,
BaseMessage,
)
class MessageType(str, Enum):
"""Enum for message types in the conversation store."""
HUMAN = "human"
AI = "ai"
TOOL = "tool"
@@ -17,9 +25,12 @@ class MessageType(str, Enum):
class BaseConvStore(ABC):
@abstractmethod
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
def record_message_list(
self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
):
pass
class ConversationStore(BaseConvStore):
def __init__(self):
conn_str = os.environ.get("CONN_STR")
@@ -32,46 +43,53 @@ class ConversationStore(BaseConvStore):
conversation_id: str,
msg_type: MessageType,
content: str,
sequence: int, # the conversation number
sequence: int,
pipeline_id: str = None,
):
with psycopg.connect(self.conn_str) as conn:
with conn.cursor() as cur:
# DB schema only supports these columns:
# (conversation_id, message_type, content, sequence_number)
cur.execute(
"""
INSERT INTO messages (conversation_id, message_type, content, sequence_number)
VALUES (%s, %s, %s, %s)
INSERT INTO messages (conversation_id, pipeline_id, message_type, content, sequence_number)
VALUES (%s, %s, %s, %s, %s)
""",
(conversation_id, msg_type.value, content, sequence),
(conversation_id, pipeline_id, msg_type.value, content, sequence),
)
def get_conv_number(self, conversation_id: str) -> int:
"""
if the conversation_id does not exist, return 0
if len(conversation) = 3, it will return 3
if the conversation_id does not exist, return 0
if len(conversation) = 3, it will return 3
"""
with psycopg.connect(self.conn_str) as conn:
with conn.cursor() as cur:
cur.execute("""
cur.execute(
"""
SELECT COUNT(*)
FROM messages
WHERE conversation_id = %s
""", (conversation_id,))
""",
(conversation_id,),
)
return int(cur.fetchone()[0])
def get_conversation(self, conversation_id: str) -> List[Dict]:
with psycopg.connect(self.conn_str) as conn:
with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
cur.execute("""
cur.execute(
"""
SELECT message_type, content, sequence_number, created_at
FROM messages
WHERE conversation_id = %s
ORDER BY sequence_number ASC
""", (conversation_id,))
""",
(conversation_id,),
)
return cur.fetchall()
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
def record_message_list(
self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
):
inp = [e for e in inp if not isinstance(e, SystemMessage)]
curr_len = self.get_conv_number(conv_id)
to_add_msg = inp[curr_len:]
@@ -80,12 +98,13 @@ class ConversationStore(BaseConvStore):
# Serialize dict/list content to JSON string
if not isinstance(content, str):
content = json.dumps(content, ensure_ascii=False, indent=4)
self.add_message(conv_id, self._get_type(msg), content, curr_len + 1)
self.add_message(
conv_id, self._get_type(msg), content, curr_len + 1, pipeline_id
)
curr_len += 1
return curr_len
def _get_type(self, msg:BaseMessage) -> MessageType:
def _get_type(self, msg: BaseMessage) -> MessageType:
if isinstance(msg, HumanMessage):
return MessageType.HUMAN
elif isinstance(msg, AIMessage):
@@ -99,23 +118,27 @@ class ConversationStore(BaseConvStore):
class ConversationPrinter(BaseConvStore):
def __init__(self):
self.id_dic = {}
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
def record_message_list(
self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
):
inp = [e for e in inp if not isinstance(e, SystemMessage)]
curr_len = self.id_dic.get(conv_id, 0)
to_print_msg = inp[curr_len:]
print("\n")
for msg in to_print_msg:
msg.pretty_print()
if curr_len == 0:
self.id_dic[conv_id] = len(inp)
else:
self.id_dic[conv_id] += len(to_print_msg)
CONV_STORE = ConversationStore()
# CONV_STORE = ConversationPrinter()
def use_printer():
global CONV_STORE
CONV_STORE = ConversationPrinter()