Files
lang-agent/lang_agent/components/conv_store.py
2026-01-29 09:44:43 +08:00

65 lines
2.6 KiB
Python

import psycopg
from uuid import UUID
from typing import List, Dict, Literal, Union
import os
class ConversationStore:
    """Store and read conversation messages in a PostgreSQL ``messages`` table.

    The connection string is read once from the ``CONN_STR`` environment
    variable at construction time. Each public method opens its own
    short-lived connection with it.
    """

    def __init__(self):
        """Read the database connection string from the environment.

        Raises:
            ValueError: if ``CONN_STR`` is not set.
        """
        conn_str = os.environ.get("CONN_STR")
        if conn_str is None:
            raise ValueError("CONN_STR is not set")
        self.conn_str = conn_str

    def _connect(self):
        """Open a new psycopg connection using the stored connection string."""
        return psycopg.connect(self.conn_str)

    def _coerce_conversation_id(self, conversation_id: Union[str, UUID]) -> UUID:
        """Return *conversation_id* as a ``UUID``.

        Accepts either a ``UUID`` instance (returned unchanged) or a string
        that ``uuid.UUID`` can parse.

        Raises:
            ValueError: for any other type, or a string that is not a UUID.
        """
        if isinstance(conversation_id, UUID):
            return conversation_id
        message = "conversation_id must be a UUID (or UUID string)"
        # uuid.UUID() calls .replace() on its argument. Non-string inputs such
        # as ints would therefore escape as AttributeError instead of the
        # documented ValueError, so reject them before parsing.
        if not isinstance(conversation_id, str):
            raise ValueError(message)
        try:
            return UUID(conversation_id)
        except (TypeError, ValueError) as e:
            raise ValueError(message) from e

    def add_message(
        self,
        conversation_id: Union[str, UUID],
        msg_type: Literal["human", "ai", "tool"],
        content: str,
        sequence: int,  # the conversation number
    ):
        """Insert one message row for a conversation.

        Args:
            conversation_id: UUID (or UUID string) of the conversation.
            msg_type: stored in the ``message_type`` column. It is not
                validated here; the ``Literal`` hint is advisory only.
            content: message text.
            sequence: stored in the ``sequence_number`` column.

        Raises:
            ValueError: if *conversation_id* is not a valid UUID.
        """
        conversation_id = self._coerce_conversation_id(conversation_id)
        # The write is committed by psycopg 3's connection context manager on
        # a clean exit from the `with` block.
        with self._connect() as conn:
            with conn.cursor() as cur:
                # DB schema only supports these columns:
                # (conversation_id, message_type, content, sequence_number)
                cur.execute(
                    """
                    INSERT INTO messages (conversation_id, message_type, content, sequence_number)
                    VALUES (%s, %s, %s, %s)
                    """,
                    (conversation_id, msg_type, content, sequence),
                )

    def get_conv_number(self, conversation_id: Union[str, UUID]) -> int:
        """Return the highest ``sequence_number`` stored for a conversation.

        Returns -1 when the conversation has no messages (via ``COALESCE``).

        Raises:
            ValueError: if *conversation_id* is not a valid UUID.
        """
        conversation_id = self._coerce_conversation_id(conversation_id)
        with self._connect() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT COALESCE(MAX(sequence_number), -1)
                    FROM messages
                    WHERE conversation_id = %s
                    """,
                    (conversation_id,),
                )
                # An aggregate without GROUP BY always yields exactly one row.
                return int(cur.fetchone()[0])

    def get_conversation(self, conversation_id: Union[str, UUID]) -> List[Dict]:
        """Return all messages of a conversation, ordered by ``sequence_number``.

        Each row is a dict with the keys ``message_type``, ``content``,
        ``sequence_number`` and ``created_at``.

        Raises:
            ValueError: if *conversation_id* is not a valid UUID.
        """
        # Import explicitly rather than relying on `import psycopg` having
        # loaded the `psycopg.rows` submodule as a side effect.
        from psycopg.rows import dict_row

        conversation_id = self._coerce_conversation_id(conversation_id)
        with self._connect() as conn:
            with conn.cursor(row_factory=dict_row) as cur:
                cur.execute(
                    """
                    SELECT message_type, content, sequence_number, created_at
                    FROM messages
                    WHERE conversation_id = %s
                    ORDER BY sequence_number ASC
                    """,
                    (conversation_id,),
                )
                return cur.fetchall()
# Shared module-level store instance. Building it reads CONN_STR, so importing
# this module raises ValueError when that environment variable is unset.
CONV_STORE: ConversationStore = ConversationStore()