Patch for lang_agent/components/conv_store.py.

Functional change
  * In ConversationStore, the method whose signature opens the third hunk
    (presumably add_message; its "def" line lies outside the hunk) gains an
    optional trailing parameter "pipeline_id: str = None" and now writes it
    to a pipeline_id column in the INSERT INTO messages statement.
  * ConversationStore.record_message_list forwards its pipeline_id argument
    to self.add_message(...).
  * The comment saying the DB schema only supports (conversation_id,
    message_type, content, sequence_number) is removed accordingly.

Everything else is formatting only: the langchain_core.messages import is
wrapped, long signatures and cur.execute(...) calls are reflowed, and blank
lines are normalised.

NOTE(review): the INSERT now names a pipeline_id column, but this patch
contains no schema change -- confirm the messages table already has it.

NOTE(review): this patch was restored from a copy whose line breaks and
indentation had been collapsed. Leading indentation, the whitespace on the
removed whitespace-only lines, and the whitespace-only difference on the
get_conv_number docstring lines were reconstructed and may not match the
target file byte-for-byte. Verify against the target before applying.

diff --git a/lang_agent/components/conv_store.py b/lang_agent/components/conv_store.py
index e17a20a..873e5d4 100644
--- a/lang_agent/components/conv_store.py
+++ b/lang_agent/components/conv_store.py
@@ -6,10 +6,18 @@ import os
 from loguru import logger
 from abc import ABC, abstractmethod
 
-from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage
+from langchain_core.messages import (
+    HumanMessage,
+    AIMessage,
+    ToolMessage,
+    SystemMessage,
+    BaseMessage,
+)
+
 
 class MessageType(str, Enum):
     """Enum for message types in the conversation store."""
+
     HUMAN = "human"
     AI = "ai"
     TOOL = "tool"
@@ -17,9 +25,12 @@ class MessageType(str, Enum):
 
 class BaseConvStore(ABC):
     @abstractmethod
-    def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
+    def record_message_list(
+        self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
+    ):
         pass
 
+
 class ConversationStore(BaseConvStore):
     def __init__(self):
         conn_str = os.environ.get("CONN_STR")
@@ -32,46 +43,53 @@ class ConversationStore(BaseConvStore):
         conversation_id: str,
         msg_type: MessageType,
         content: str,
-        sequence: int, # the conversation number
+        sequence: int,
+        pipeline_id: str = None,
     ):
         with psycopg.connect(self.conn_str) as conn:
             with conn.cursor() as cur:
-                # DB schema only supports these columns:
-                # (conversation_id, message_type, content, sequence_number)
                 cur.execute(
                     """
-                    INSERT INTO messages (conversation_id, message_type, content, sequence_number)
-                    VALUES (%s, %s, %s, %s)
+                    INSERT INTO messages (conversation_id, pipeline_id, message_type, content, sequence_number)
+                    VALUES (%s, %s, %s, %s, %s)
                     """,
-                    (conversation_id, msg_type.value, content, sequence),
+                    (conversation_id, pipeline_id, msg_type.value, content, sequence),
                 )
-    
+
     def get_conv_number(self, conversation_id: str) -> int:
         """
-            if the conversation_id does not exist, return 0
-            if len(conversation) = 3, it will return 3
+        if the conversation_id does not exist, return 0
+        if len(conversation) = 3, it will return 3
         """
         with psycopg.connect(self.conn_str) as conn:
             with conn.cursor() as cur:
-                cur.execute("""
+                cur.execute(
+                    """
                     SELECT COUNT(*)
                     FROM messages
                     WHERE conversation_id = %s
-                """, (conversation_id,))
+                """,
+                    (conversation_id,),
+                )
                 return int(cur.fetchone()[0])
-    
+
     def get_conversation(self, conversation_id: str) -> List[Dict]:
         with psycopg.connect(self.conn_str) as conn:
             with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
-                cur.execute("""
+                cur.execute(
+                    """
                     SELECT message_type, content, sequence_number, created_at
                     FROM messages
                     WHERE conversation_id = %s
                     ORDER BY sequence_number ASC
-                """, (conversation_id,))
+                """,
+                    (conversation_id,),
+                )
                 return cur.fetchall()
-    
-    def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
+
+    def record_message_list(
+        self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
+    ):
         inp = [e for e in inp if not isinstance(e, SystemMessage)]
         curr_len = self.get_conv_number(conv_id)
         to_add_msg = inp[curr_len:]
@@ -80,12 +98,13 @@ class ConversationStore(BaseConvStore):
             # Serialize dict/list content to JSON string
             if not isinstance(content, str):
                 content = json.dumps(content, ensure_ascii=False, indent=4)
-            self.add_message(conv_id, self._get_type(msg), content, curr_len + 1)
+            self.add_message(
+                conv_id, self._get_type(msg), content, curr_len + 1, pipeline_id
+            )
             curr_len += 1
         return curr_len
-    
-    
-    def _get_type(self, msg:BaseMessage) -> MessageType:
+
+    def _get_type(self, msg: BaseMessage) -> MessageType:
         if isinstance(msg, HumanMessage):
             return MessageType.HUMAN
         elif isinstance(msg, AIMessage):
@@ -99,23 +118,27 @@ class ConversationStore(BaseConvStore):
 class ConversationPrinter(BaseConvStore):
     def __init__(self):
         self.id_dic = {}
-    
-    def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
+
+    def record_message_list(
+        self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
+    ):
         inp = [e for e in inp if not isinstance(e, SystemMessage)]
         curr_len = self.id_dic.get(conv_id, 0)
         to_print_msg = inp[curr_len:]
         print("\n")
         for msg in to_print_msg:
             msg.pretty_print()
-        
+
         if curr_len == 0:
             self.id_dic[conv_id] = len(inp)
         else:
             self.id_dic[conv_id] += len(to_print_msg)
-    
+
+
 CONV_STORE = ConversationStore()
 # CONV_STORE = ConversationPrinter()
 
+
 def use_printer():
     global CONV_STORE
     CONV_STORE = ConversationPrinter()