update conv store

This commit is contained in:
2026-01-30 09:13:00 +08:00
parent c497916bc2
commit 609e31c9ad

View File

@@ -1,8 +1,17 @@
import psycopg
from uuid import UUID
from typing import List, Dict, Literal, Union
from typing import List, Dict, Union
from enum import Enum
import os
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage
class MessageType(str, Enum):
"""Enum for message types in the conversation store."""
HUMAN = "human"
AI = "ai"
TOOL = "tool"
class ConversationStore:
def __init__(self):
conn_str = os.environ.get("CONN_STR")
@@ -21,7 +30,7 @@ class ConversationStore:
def add_message(
self,
conversation_id: Union[str, UUID],
msg_type: Literal["human", "ai", "tool"],
msg_type: MessageType,
content: str,
sequence: int, # the conversation number
):
@@ -35,15 +44,19 @@ class ConversationStore:
INSERT INTO messages (conversation_id, message_type, content, sequence_number)
VALUES (%s, %s, %s, %s)
""",
(conversation_id, msg_type, content, sequence),
(conversation_id, msg_type.value, content, sequence),
)
def get_conv_number(self, conversation_id: Union[str, UUID]) -> int:
"""
if the conversation_id does not exist, return 0
if len(conversation) = 3, it will return 3
"""
conversation_id = self._coerce_conversation_id(conversation_id)
with psycopg.connect(self.conn_str) as conn:
with conn.cursor() as cur:
cur.execute("""
SELECT COALESCE(MAX(sequence_number), -1)
SELECT COUNT(*)
FROM messages
WHERE conversation_id = %s
""", (conversation_id,))
@@ -60,6 +73,24 @@ class ConversationStore:
ORDER BY sequence_number ASC
""", (conversation_id,))
return cur.fetchall()
def record_messages(self, conv_id:str, inp:List[BaseMessage]):
curr_len = self.get_conv_number(conv_id)
to_add_msg = inp[curr_len:]
for msg in to_add_msg:
self.add_message(conv_id, self._get_type(msg), msg.content, curr_len + 1)
curr_len += 1
return curr_len
def _get_type(self, msg:BaseMessage) -> MessageType:
if isinstance(msg, HumanMessage):
return MessageType.HUMAN
elif isinstance(msg, AIMessage):
return MessageType.AI
elif isinstance(msg, ToolMessage):
return MessageType.TOOL
else:
raise ValueError(f"Unknown message type: {type(msg)}")
CONV_STORE = ConversationStore()