diff --git a/lang_agent/components/conv_store.py b/lang_agent/components/conv_store.py index f4b7e6d..9986684 100644 --- a/lang_agent/components/conv_store.py +++ b/lang_agent/components/conv_store.py @@ -1,10 +1,9 @@ import psycopg -from uuid import UUID from typing import List, Dict, Union from enum import Enum import os -from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage +from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage class MessageType(str, Enum): """Enum for message types in the conversation store.""" @@ -19,22 +18,13 @@ class ConversationStore: raise ValueError("CONN_STR is not set") self.conn_str = conn_str - def _coerce_conversation_id(self, conversation_id: Union[str, UUID]) -> UUID: - if isinstance(conversation_id, UUID): - return conversation_id - try: - return UUID(conversation_id) - except (TypeError, ValueError) as e: - raise ValueError("conversation_id must be a UUID (or UUID string)") from e - def add_message( self, - conversation_id: Union[str, UUID], + conversation_id: str, msg_type: MessageType, content: str, sequence: int, # the conversation number ): - conversation_id = self._coerce_conversation_id(conversation_id) with psycopg.connect(self.conn_str) as conn: with conn.cursor() as cur: # DB schema only supports these columns: @@ -47,12 +37,11 @@ class ConversationStore: (conversation_id, msg_type.value, content, sequence), ) - def get_conv_number(self, conversation_id: Union[str, UUID]) -> int: + def get_conv_number(self, conversation_id: str) -> int: """ if the conversation_id does not exist, return 0 if len(conversation) = 3, it will return 3 """ - conversation_id = self._coerce_conversation_id(conversation_id) with psycopg.connect(self.conn_str) as conn: with conn.cursor() as cur: cur.execute(""" @@ -62,8 +51,7 @@ class ConversationStore: """, (conversation_id,)) return int(cur.fetchone()[0]) - def get_conversation(self, conversation_id: Union[str, UUID]) -> List[Dict]: - conversation_id = self._coerce_conversation_id(conversation_id) + def get_conversation(self, conversation_id: str) -> List[Dict]: with psycopg.connect(self.conn_str) as conn: with conn.cursor(row_factory=psycopg.rows.dict_row) as cur: cur.execute(""" @@ -75,6 +63,7 @@ class ConversationStore: return cur.fetchall() def record_messages(self, conv_id:str, inp:List[BaseMessage]): + inp = [e for e in inp if not isinstance(e, SystemMessage)] curr_len = self.get_conv_number(conv_id) to_add_msg = inp[curr_len:] for msg in to_add_msg: