From 609e31c9ad75c6f9db2bf3edd2829afb2d3d3b3c Mon Sep 17 00:00:00 2001 From: goulustis Date: Fri, 30 Jan 2026 09:13:00 +0800 Subject: [PATCH] update conv store --- lang_agent/components/conv_store.py | 41 +++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/lang_agent/components/conv_store.py b/lang_agent/components/conv_store.py index 87f781b..f4b7e6d 100644 --- a/lang_agent/components/conv_store.py +++ b/lang_agent/components/conv_store.py @@ -1,8 +1,17 @@ import psycopg from uuid import UUID -from typing import List, Dict, Literal, Union +from typing import List, Dict, Union +from enum import Enum import os +from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage + +class MessageType(str, Enum): + """Enum for message types in the conversation store.""" + HUMAN = "human" + AI = "ai" + TOOL = "tool" + class ConversationStore: def __init__(self): conn_str = os.environ.get("CONN_STR") @@ -21,7 +30,7 @@ class ConversationStore: def add_message( self, conversation_id: Union[str, UUID], - msg_type: Literal["human", "ai", "tool"], + msg_type: MessageType, content: str, sequence: int, # the conversation number ): @@ -35,15 +44,19 @@ class ConversationStore: INSERT INTO messages (conversation_id, message_type, content, sequence_number) VALUES (%s, %s, %s, %s) """, - (conversation_id, msg_type, content, sequence), + (conversation_id, msg_type.value, content, sequence), ) def get_conv_number(self, conversation_id: Union[str, UUID]) -> int: + """ + if the conversation_id does not exist, return 0 + if len(conversation) = 3, it will return 3 + """ conversation_id = self._coerce_conversation_id(conversation_id) with psycopg.connect(self.conn_str) as conn: with conn.cursor() as cur: cur.execute(""" - SELECT COALESCE(MAX(sequence_number), -1) + SELECT COUNT(*) FROM messages WHERE conversation_id = %s """, (conversation_id,)) @@ -60,6 +73,24 @@ class ConversationStore: ORDER BY sequence_number ASC """, (conversation_id,)) return cur.fetchall() - + + def record_messages(self, conv_id:str, inp:List[BaseMessage]): + curr_len = self.get_conv_number(conv_id) + to_add_msg = inp[curr_len:] + for msg in to_add_msg: + self.add_message(conv_id, self._get_type(msg), msg.content, curr_len + 1) + curr_len += 1 + return curr_len + + + def _get_type(self, msg:BaseMessage) -> MessageType: + if isinstance(msg, HumanMessage): + return MessageType.HUMAN + elif isinstance(msg, AIMessage): + return MessageType.AI + elif isinstance(msg, ToolMessage): + return MessageType.TOOL + else: + raise ValueError(f"Unknown message type: {type(msg)}") CONV_STORE = ConversationStore() \ No newline at end of file