update conv store

This commit is contained in:
2026-01-30 09:13:00 +08:00
parent c497916bc2
commit 609e31c9ad

View File

@@ -1,8 +1,17 @@
import psycopg import psycopg
from uuid import UUID from uuid import UUID
from typing import List, Dict, Literal, Union from typing import List, Dict, Union
from enum import Enum
import os import os
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage
class MessageType(str, Enum):
"""Enum for message types in the conversation store."""
HUMAN = "human"
AI = "ai"
TOOL = "tool"
class ConversationStore: class ConversationStore:
def __init__(self): def __init__(self):
conn_str = os.environ.get("CONN_STR") conn_str = os.environ.get("CONN_STR")
@@ -21,7 +30,7 @@ class ConversationStore:
def add_message( def add_message(
self, self,
conversation_id: Union[str, UUID], conversation_id: Union[str, UUID],
msg_type: Literal["human", "ai", "tool"], msg_type: MessageType,
content: str, content: str,
sequence: int, # the conversation number sequence: int, # the conversation number
): ):
@@ -35,15 +44,19 @@ class ConversationStore:
INSERT INTO messages (conversation_id, message_type, content, sequence_number) INSERT INTO messages (conversation_id, message_type, content, sequence_number)
VALUES (%s, %s, %s, %s) VALUES (%s, %s, %s, %s)
""", """,
(conversation_id, msg_type, content, sequence), (conversation_id, msg_type.value, content, sequence),
) )
def get_conv_number(self, conversation_id: Union[str, UUID]) -> int: def get_conv_number(self, conversation_id: Union[str, UUID]) -> int:
"""
if the conversation_id does not exist, return 0
if len(conversation) = 3, it will return 3
"""
conversation_id = self._coerce_conversation_id(conversation_id) conversation_id = self._coerce_conversation_id(conversation_id)
with psycopg.connect(self.conn_str) as conn: with psycopg.connect(self.conn_str) as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(""" cur.execute("""
SELECT COALESCE(MAX(sequence_number), -1) SELECT COUNT(*)
FROM messages FROM messages
WHERE conversation_id = %s WHERE conversation_id = %s
""", (conversation_id,)) """, (conversation_id,))
@@ -61,5 +74,23 @@ class ConversationStore:
""", (conversation_id,)) """, (conversation_id,))
return cur.fetchall() return cur.fetchall()
def record_messages(self, conv_id:str, inp:List[BaseMessage]):
curr_len = self.get_conv_number(conv_id)
to_add_msg = inp[curr_len:]
for msg in to_add_msg:
self.add_message(conv_id, self._get_type(msg), msg.content, curr_len + 1)
curr_len += 1
return curr_len
def _get_type(self, msg:BaseMessage) -> MessageType:
if isinstance(msg, HumanMessage):
return MessageType.HUMAN
elif isinstance(msg, AIMessage):
return MessageType.AI
elif isinstance(msg, ToolMessage):
return MessageType.TOOL
else:
raise ValueError(f"Unknown message type: {type(msg)}")
CONV_STORE = ConversationStore() CONV_STORE = ConversationStore()