store id as string instead

This commit is contained in:
2026-01-30 10:38:58 +08:00
parent fde8c85488
commit 59031c9919

View File

@@ -1,10 +1,9 @@
import psycopg
from uuid import UUID
from typing import List, Dict, Union
from enum import Enum
import os
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage
class MessageType(str, Enum):
"""Enum for message types in the conversation store."""
@@ -19,22 +18,13 @@ class ConversationStore:
raise ValueError("CONN_STR is not set")
self.conn_str = conn_str
def _coerce_conversation_id(self, conversation_id: Union[str, UUID]) -> UUID:
if isinstance(conversation_id, UUID):
return conversation_id
try:
return UUID(conversation_id)
except (TypeError, ValueError) as e:
raise ValueError("conversation_id must be a UUID (or UUID string)") from e
def add_message(
self,
conversation_id: Union[str, UUID],
conversation_id: str,
msg_type: MessageType,
content: str,
sequence: int, # the conversation number
):
conversation_id = self._coerce_conversation_id(conversation_id)
with psycopg.connect(self.conn_str) as conn:
with conn.cursor() as cur:
# DB schema only supports these columns:
@@ -47,12 +37,11 @@ class ConversationStore:
(conversation_id, msg_type.value, content, sequence),
)
def get_conv_number(self, conversation_id: Union[str, UUID]) -> int:
def get_conv_number(self, conversation_id: str) -> int:
"""
if the conversation_id does not exist, return 0
if len(conversation) = 3, it will return 3
"""
conversation_id = self._coerce_conversation_id(conversation_id)
with psycopg.connect(self.conn_str) as conn:
with conn.cursor() as cur:
cur.execute("""
@@ -62,8 +51,7 @@ class ConversationStore:
""", (conversation_id,))
return int(cur.fetchone()[0])
def get_conversation(self, conversation_id: Union[str, UUID]) -> List[Dict]:
conversation_id = self._coerce_conversation_id(conversation_id)
def get_conversation(self, conversation_id: str) -> List[Dict]:
with psycopg.connect(self.conn_str) as conn:
with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
cur.execute("""
@@ -75,6 +63,7 @@ class ConversationStore:
return cur.fetchall()
def record_messages(self, conv_id:str, inp:List[BaseMessage]):
inp = [e for e in inp if not isinstance(e, SystemMessage)]
curr_len = self.get_conv_number(conv_id)
to_add_msg = inp[curr_len:]
for msg in to_add_msg: