record pipeline id in conv_store
This commit is contained in:
@@ -6,10 +6,18 @@ import os
|
||||
from loguru import logger
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage
|
||||
from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
AIMessage,
|
||||
ToolMessage,
|
||||
SystemMessage,
|
||||
BaseMessage,
|
||||
)
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
"""Enum for message types in the conversation store."""
|
||||
|
||||
HUMAN = "human"
|
||||
AI = "ai"
|
||||
TOOL = "tool"
|
||||
@@ -17,9 +25,12 @@ class MessageType(str, Enum):
|
||||
|
||||
class BaseConvStore(ABC):
|
||||
@abstractmethod
|
||||
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
|
||||
def record_message_list(
|
||||
self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class ConversationStore(BaseConvStore):
|
||||
def __init__(self):
|
||||
conn_str = os.environ.get("CONN_STR")
|
||||
@@ -32,18 +43,17 @@ class ConversationStore(BaseConvStore):
|
||||
conversation_id: str,
|
||||
msg_type: MessageType,
|
||||
content: str,
|
||||
sequence: int, # the conversation number
|
||||
sequence: int,
|
||||
pipeline_id: str = None,
|
||||
):
|
||||
with psycopg.connect(self.conn_str) as conn:
|
||||
with conn.cursor() as cur:
|
||||
# DB schema only supports these columns:
|
||||
# (conversation_id, message_type, content, sequence_number)
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO messages (conversation_id, message_type, content, sequence_number)
|
||||
VALUES (%s, %s, %s, %s)
|
||||
INSERT INTO messages (conversation_id, pipeline_id, message_type, content, sequence_number)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
""",
|
||||
(conversation_id, msg_type.value, content, sequence),
|
||||
(conversation_id, pipeline_id, msg_type.value, content, sequence),
|
||||
)
|
||||
|
||||
def get_conv_number(self, conversation_id: str) -> int:
|
||||
@@ -53,25 +63,33 @@ class ConversationStore(BaseConvStore):
|
||||
"""
|
||||
with psycopg.connect(self.conn_str) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("""
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM messages
|
||||
WHERE conversation_id = %s
|
||||
""", (conversation_id,))
|
||||
""",
|
||||
(conversation_id,),
|
||||
)
|
||||
return int(cur.fetchone()[0])
|
||||
|
||||
def get_conversation(self, conversation_id: str) -> List[Dict]:
|
||||
with psycopg.connect(self.conn_str) as conn:
|
||||
with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
|
||||
cur.execute("""
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT message_type, content, sequence_number, created_at
|
||||
FROM messages
|
||||
WHERE conversation_id = %s
|
||||
ORDER BY sequence_number ASC
|
||||
""", (conversation_id,))
|
||||
""",
|
||||
(conversation_id,),
|
||||
)
|
||||
return cur.fetchall()
|
||||
|
||||
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
|
||||
def record_message_list(
|
||||
self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
|
||||
):
|
||||
inp = [e for e in inp if not isinstance(e, SystemMessage)]
|
||||
curr_len = self.get_conv_number(conv_id)
|
||||
to_add_msg = inp[curr_len:]
|
||||
@@ -80,11 +98,12 @@ class ConversationStore(BaseConvStore):
|
||||
# Serialize dict/list content to JSON string
|
||||
if not isinstance(content, str):
|
||||
content = json.dumps(content, ensure_ascii=False, indent=4)
|
||||
self.add_message(conv_id, self._get_type(msg), content, curr_len + 1)
|
||||
self.add_message(
|
||||
conv_id, self._get_type(msg), content, curr_len + 1, pipeline_id
|
||||
)
|
||||
curr_len += 1
|
||||
return curr_len
|
||||
|
||||
|
||||
def _get_type(self, msg: BaseMessage) -> MessageType:
|
||||
if isinstance(msg, HumanMessage):
|
||||
return MessageType.HUMAN
|
||||
@@ -100,7 +119,9 @@ class ConversationPrinter(BaseConvStore):
|
||||
def __init__(self):
|
||||
self.id_dic = {}
|
||||
|
||||
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
|
||||
def record_message_list(
|
||||
self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
|
||||
):
|
||||
inp = [e for e in inp if not isinstance(e, SystemMessage)]
|
||||
curr_len = self.id_dic.get(conv_id, 0)
|
||||
to_print_msg = inp[curr_len:]
|
||||
@@ -113,9 +134,11 @@ class ConversationPrinter(BaseConvStore):
|
||||
else:
|
||||
self.id_dic[conv_id] += len(to_print_msg)
|
||||
|
||||
|
||||
CONV_STORE = ConversationStore()
|
||||
# CONV_STORE = ConversationPrinter()
|
||||
|
||||
|
||||
def use_printer():
|
||||
global CONV_STORE
|
||||
CONV_STORE = ConversationPrinter()
|
||||
|
||||
Reference in New Issue
Block a user