diff --git a/lang_agent/components/conv_store.py b/lang_agent/components/conv_store.py index 486e5ba..e17a20a 100644 --- a/lang_agent/components/conv_store.py +++ b/lang_agent/components/conv_store.py @@ -4,6 +4,7 @@ from typing import List, Dict, Union from enum import Enum import os from loguru import logger +from abc import ABC, abstractmethod from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage @@ -13,7 +14,13 @@ class MessageType(str, Enum): AI = "ai" TOOL = "tool" -class ConversationStore: + +class BaseConvStore(ABC): + @abstractmethod + def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None): + pass + +class ConversationStore(BaseConvStore): def __init__(self): conn_str = os.environ.get("CONN_STR") if conn_str is None: @@ -64,7 +71,7 @@ class ConversationStore: """, (conversation_id,)) return cur.fetchall() - def record_message_list(self, conv_id:str, inp:List[BaseMessage]): + def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None): inp = [e for e in inp if not isinstance(e, SystemMessage)] curr_len = self.get_conv_number(conv_id) to_add_msg = inp[curr_len:] @@ -89,11 +96,11 @@ class ConversationStore: raise ValueError(f"Unknown message type: {type(msg)}") -class ConversationPrinter: +class ConversationPrinter(BaseConvStore): def __init__(self): self.id_dic = {} - def record_message_list(self, conv_id:str, inp:List[BaseMessage]): + def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None): inp = [e for e in inp if not isinstance(e, SystemMessage)] curr_len = self.id_dic.get(conv_id, 0) to_print_msg = inp[curr_len:]