store id as string instead

This commit is contained in:
2026-01-30 10:38:58 +08:00
parent fde8c85488
commit 59031c9919

View File

@@ -1,10 +1,9 @@
import psycopg import psycopg
from uuid import UUID
from typing import List, Dict, Union from typing import List, Dict, Union
from enum import Enum from enum import Enum
import os import os
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage
class MessageType(str, Enum): class MessageType(str, Enum):
"""Enum for message types in the conversation store.""" """Enum for message types in the conversation store."""
@@ -19,22 +18,13 @@ class ConversationStore:
raise ValueError("CONN_STR is not set") raise ValueError("CONN_STR is not set")
self.conn_str = conn_str self.conn_str = conn_str
def _coerce_conversation_id(self, conversation_id: Union[str, UUID]) -> UUID:
if isinstance(conversation_id, UUID):
return conversation_id
try:
return UUID(conversation_id)
except (TypeError, ValueError) as e:
raise ValueError("conversation_id must be a UUID (or UUID string)") from e
def add_message( def add_message(
self, self,
conversation_id: Union[str, UUID], conversation_id: str,
msg_type: MessageType, msg_type: MessageType,
content: str, content: str,
sequence: int, # the conversation number sequence: int, # the conversation number
): ):
conversation_id = self._coerce_conversation_id(conversation_id)
with psycopg.connect(self.conn_str) as conn: with psycopg.connect(self.conn_str) as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
# DB schema only supports these columns: # DB schema only supports these columns:
@@ -47,12 +37,11 @@ class ConversationStore:
(conversation_id, msg_type.value, content, sequence), (conversation_id, msg_type.value, content, sequence),
) )
def get_conv_number(self, conversation_id: Union[str, UUID]) -> int: def get_conv_number(self, conversation_id: str) -> int:
""" """
if the conversation_id does not exist, return 0 if the conversation_id does not exist, return 0
if len(conversation) = 3, it will return 3 if len(conversation) = 3, it will return 3
""" """
conversation_id = self._coerce_conversation_id(conversation_id)
with psycopg.connect(self.conn_str) as conn: with psycopg.connect(self.conn_str) as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(""" cur.execute("""
@@ -62,8 +51,7 @@ class ConversationStore:
""", (conversation_id,)) """, (conversation_id,))
return int(cur.fetchone()[0]) return int(cur.fetchone()[0])
def get_conversation(self, conversation_id: Union[str, UUID]) -> List[Dict]: def get_conversation(self, conversation_id: str) -> List[Dict]:
conversation_id = self._coerce_conversation_id(conversation_id)
with psycopg.connect(self.conn_str) as conn: with psycopg.connect(self.conn_str) as conn:
with conn.cursor(row_factory=psycopg.rows.dict_row) as cur: with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
cur.execute(""" cur.execute("""
@@ -75,6 +63,7 @@ class ConversationStore:
return cur.fetchall() return cur.fetchall()
def record_messages(self, conv_id:str, inp:List[BaseMessage]): def record_messages(self, conv_id:str, inp:List[BaseMessage]):
inp = [e for e in inp if not isinstance(e, SystemMessage)]
curr_len = self.get_conv_number(conv_id) curr_len = self.get_conv_number(conv_id)
to_add_msg = inp[curr_len:] to_add_msg = inp[curr_len:]
for msg in to_add_msg: for msg in to_add_msg: