store id as string instead
This commit is contained in:
@@ -1,10 +1,9 @@
|
|||||||
import psycopg
|
import psycopg
|
||||||
from uuid import UUID
|
|
||||||
from typing import List, Dict, Union
|
from typing import List, Dict, Union
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage
|
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage
|
||||||
|
|
||||||
class MessageType(str, Enum):
|
class MessageType(str, Enum):
|
||||||
"""Enum for message types in the conversation store."""
|
"""Enum for message types in the conversation store."""
|
||||||
@@ -19,22 +18,13 @@ class ConversationStore:
|
|||||||
raise ValueError("CONN_STR is not set")
|
raise ValueError("CONN_STR is not set")
|
||||||
self.conn_str = conn_str
|
self.conn_str = conn_str
|
||||||
|
|
||||||
def _coerce_conversation_id(self, conversation_id: Union[str, UUID]) -> UUID:
|
|
||||||
if isinstance(conversation_id, UUID):
|
|
||||||
return conversation_id
|
|
||||||
try:
|
|
||||||
return UUID(conversation_id)
|
|
||||||
except (TypeError, ValueError) as e:
|
|
||||||
raise ValueError("conversation_id must be a UUID (or UUID string)") from e
|
|
||||||
|
|
||||||
def add_message(
|
def add_message(
|
||||||
self,
|
self,
|
||||||
conversation_id: Union[str, UUID],
|
conversation_id: str,
|
||||||
msg_type: MessageType,
|
msg_type: MessageType,
|
||||||
content: str,
|
content: str,
|
||||||
sequence: int, # the conversation number
|
sequence: int, # the conversation number
|
||||||
):
|
):
|
||||||
conversation_id = self._coerce_conversation_id(conversation_id)
|
|
||||||
with psycopg.connect(self.conn_str) as conn:
|
with psycopg.connect(self.conn_str) as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
# DB schema only supports these columns:
|
# DB schema only supports these columns:
|
||||||
@@ -47,12 +37,11 @@ class ConversationStore:
|
|||||||
(conversation_id, msg_type.value, content, sequence),
|
(conversation_id, msg_type.value, content, sequence),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_conv_number(self, conversation_id: Union[str, UUID]) -> int:
|
def get_conv_number(self, conversation_id: str) -> int:
|
||||||
"""
|
"""
|
||||||
if the conversation_id does not exist, return 0
|
if the conversation_id does not exist, return 0
|
||||||
if len(conversation) = 3, it will return 3
|
if len(conversation) = 3, it will return 3
|
||||||
"""
|
"""
|
||||||
conversation_id = self._coerce_conversation_id(conversation_id)
|
|
||||||
with psycopg.connect(self.conn_str) as conn:
|
with psycopg.connect(self.conn_str) as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
cur.execute("""
|
cur.execute("""
|
||||||
@@ -62,8 +51,7 @@ class ConversationStore:
|
|||||||
""", (conversation_id,))
|
""", (conversation_id,))
|
||||||
return int(cur.fetchone()[0])
|
return int(cur.fetchone()[0])
|
||||||
|
|
||||||
def get_conversation(self, conversation_id: Union[str, UUID]) -> List[Dict]:
|
def get_conversation(self, conversation_id: str) -> List[Dict]:
|
||||||
conversation_id = self._coerce_conversation_id(conversation_id)
|
|
||||||
with psycopg.connect(self.conn_str) as conn:
|
with psycopg.connect(self.conn_str) as conn:
|
||||||
with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
|
with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
|
||||||
cur.execute("""
|
cur.execute("""
|
||||||
@@ -75,6 +63,7 @@ class ConversationStore:
|
|||||||
return cur.fetchall()
|
return cur.fetchall()
|
||||||
|
|
||||||
def record_messages(self, conv_id:str, inp:List[BaseMessage]):
|
def record_messages(self, conv_id:str, inp:List[BaseMessage]):
|
||||||
|
inp = [e for e in inp if not isinstance(e, SystemMessage)]
|
||||||
curr_len = self.get_conv_number(conv_id)
|
curr_len = self.get_conv_number(conv_id)
|
||||||
to_add_msg = inp[curr_len:]
|
to_add_msg = inp[curr_len:]
|
||||||
for msg in to_add_msg:
|
for msg in to_add_msg:
|
||||||
|
|||||||
Reference in New Issue
Block a user