update conv store
This commit is contained in:
@@ -1,8 +1,17 @@
|
|||||||
import psycopg
|
import psycopg
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import List, Dict, Literal, Union
|
from typing import List, Dict, Union
|
||||||
|
from enum import Enum
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage
|
||||||
|
|
||||||
|
class MessageType(str, Enum):
|
||||||
|
"""Enum for message types in the conversation store."""
|
||||||
|
HUMAN = "human"
|
||||||
|
AI = "ai"
|
||||||
|
TOOL = "tool"
|
||||||
|
|
||||||
class ConversationStore:
|
class ConversationStore:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
conn_str = os.environ.get("CONN_STR")
|
conn_str = os.environ.get("CONN_STR")
|
||||||
@@ -21,7 +30,7 @@ class ConversationStore:
|
|||||||
def add_message(
|
def add_message(
|
||||||
self,
|
self,
|
||||||
conversation_id: Union[str, UUID],
|
conversation_id: Union[str, UUID],
|
||||||
msg_type: Literal["human", "ai", "tool"],
|
msg_type: MessageType,
|
||||||
content: str,
|
content: str,
|
||||||
sequence: int, # the conversation number
|
sequence: int, # the conversation number
|
||||||
):
|
):
|
||||||
@@ -35,15 +44,19 @@ class ConversationStore:
|
|||||||
INSERT INTO messages (conversation_id, message_type, content, sequence_number)
|
INSERT INTO messages (conversation_id, message_type, content, sequence_number)
|
||||||
VALUES (%s, %s, %s, %s)
|
VALUES (%s, %s, %s, %s)
|
||||||
""",
|
""",
|
||||||
(conversation_id, msg_type, content, sequence),
|
(conversation_id, msg_type.value, content, sequence),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_conv_number(self, conversation_id: Union[str, UUID]) -> int:
|
def get_conv_number(self, conversation_id: Union[str, UUID]) -> int:
|
||||||
|
"""
|
||||||
|
if the conversation_id does not exist, return 0
|
||||||
|
if len(conversation) = 3, it will return 3
|
||||||
|
"""
|
||||||
conversation_id = self._coerce_conversation_id(conversation_id)
|
conversation_id = self._coerce_conversation_id(conversation_id)
|
||||||
with psycopg.connect(self.conn_str) as conn:
|
with psycopg.connect(self.conn_str) as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
cur.execute("""
|
cur.execute("""
|
||||||
SELECT COALESCE(MAX(sequence_number), -1)
|
SELECT COUNT(*)
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE conversation_id = %s
|
WHERE conversation_id = %s
|
||||||
""", (conversation_id,))
|
""", (conversation_id,))
|
||||||
@@ -61,5 +74,23 @@ class ConversationStore:
|
|||||||
""", (conversation_id,))
|
""", (conversation_id,))
|
||||||
return cur.fetchall()
|
return cur.fetchall()
|
||||||
|
|
||||||
|
def record_messages(self, conv_id:str, inp:List[BaseMessage]):
|
||||||
|
curr_len = self.get_conv_number(conv_id)
|
||||||
|
to_add_msg = inp[curr_len:]
|
||||||
|
for msg in to_add_msg:
|
||||||
|
self.add_message(conv_id, self._get_type(msg), msg.content, curr_len + 1)
|
||||||
|
curr_len += 1
|
||||||
|
return curr_len
|
||||||
|
|
||||||
|
|
||||||
|
def _get_type(self, msg:BaseMessage) -> MessageType:
|
||||||
|
if isinstance(msg, HumanMessage):
|
||||||
|
return MessageType.HUMAN
|
||||||
|
elif isinstance(msg, AIMessage):
|
||||||
|
return MessageType.AI
|
||||||
|
elif isinstance(msg, ToolMessage):
|
||||||
|
return MessageType.TOOL
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown message type: {type(msg)}")
|
||||||
|
|
||||||
CONV_STORE = ConversationStore()
|
CONV_STORE = ConversationStore()
|
||||||
Reference in New Issue
Block a user