record pipeline id in conv_store
This commit is contained in:
@@ -6,10 +6,18 @@ import os
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage
|
from langchain_core.messages import (
|
||||||
|
HumanMessage,
|
||||||
|
AIMessage,
|
||||||
|
ToolMessage,
|
||||||
|
SystemMessage,
|
||||||
|
BaseMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MessageType(str, Enum):
|
class MessageType(str, Enum):
|
||||||
"""Enum for message types in the conversation store."""
|
"""Enum for message types in the conversation store."""
|
||||||
|
|
||||||
HUMAN = "human"
|
HUMAN = "human"
|
||||||
AI = "ai"
|
AI = "ai"
|
||||||
TOOL = "tool"
|
TOOL = "tool"
|
||||||
@@ -17,9 +25,12 @@ class MessageType(str, Enum):
|
|||||||
|
|
||||||
class BaseConvStore(ABC):
|
class BaseConvStore(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
|
def record_message_list(
|
||||||
|
self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ConversationStore(BaseConvStore):
|
class ConversationStore(BaseConvStore):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
conn_str = os.environ.get("CONN_STR")
|
conn_str = os.environ.get("CONN_STR")
|
||||||
@@ -32,46 +43,53 @@ class ConversationStore(BaseConvStore):
|
|||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
msg_type: MessageType,
|
msg_type: MessageType,
|
||||||
content: str,
|
content: str,
|
||||||
sequence: int, # the conversation number
|
sequence: int,
|
||||||
|
pipeline_id: str = None,
|
||||||
):
|
):
|
||||||
with psycopg.connect(self.conn_str) as conn:
|
with psycopg.connect(self.conn_str) as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
# DB schema only supports these columns:
|
|
||||||
# (conversation_id, message_type, content, sequence_number)
|
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO messages (conversation_id, message_type, content, sequence_number)
|
INSERT INTO messages (conversation_id, pipeline_id, message_type, content, sequence_number)
|
||||||
VALUES (%s, %s, %s, %s)
|
VALUES (%s, %s, %s, %s, %s)
|
||||||
""",
|
""",
|
||||||
(conversation_id, msg_type.value, content, sequence),
|
(conversation_id, pipeline_id, msg_type.value, content, sequence),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_conv_number(self, conversation_id: str) -> int:
|
def get_conv_number(self, conversation_id: str) -> int:
|
||||||
"""
|
"""
|
||||||
if the conversation_id does not exist, return 0
|
if the conversation_id does not exist, return 0
|
||||||
if len(conversation) = 3, it will return 3
|
if len(conversation) = 3, it will return 3
|
||||||
"""
|
"""
|
||||||
with psycopg.connect(self.conn_str) as conn:
|
with psycopg.connect(self.conn_str) as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
cur.execute("""
|
cur.execute(
|
||||||
|
"""
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE conversation_id = %s
|
WHERE conversation_id = %s
|
||||||
""", (conversation_id,))
|
""",
|
||||||
|
(conversation_id,),
|
||||||
|
)
|
||||||
return int(cur.fetchone()[0])
|
return int(cur.fetchone()[0])
|
||||||
|
|
||||||
def get_conversation(self, conversation_id: str) -> List[Dict]:
|
def get_conversation(self, conversation_id: str) -> List[Dict]:
|
||||||
with psycopg.connect(self.conn_str) as conn:
|
with psycopg.connect(self.conn_str) as conn:
|
||||||
with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
|
with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
|
||||||
cur.execute("""
|
cur.execute(
|
||||||
|
"""
|
||||||
SELECT message_type, content, sequence_number, created_at
|
SELECT message_type, content, sequence_number, created_at
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE conversation_id = %s
|
WHERE conversation_id = %s
|
||||||
ORDER BY sequence_number ASC
|
ORDER BY sequence_number ASC
|
||||||
""", (conversation_id,))
|
""",
|
||||||
|
(conversation_id,),
|
||||||
|
)
|
||||||
return cur.fetchall()
|
return cur.fetchall()
|
||||||
|
|
||||||
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
|
def record_message_list(
|
||||||
|
self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
|
||||||
|
):
|
||||||
inp = [e for e in inp if not isinstance(e, SystemMessage)]
|
inp = [e for e in inp if not isinstance(e, SystemMessage)]
|
||||||
curr_len = self.get_conv_number(conv_id)
|
curr_len = self.get_conv_number(conv_id)
|
||||||
to_add_msg = inp[curr_len:]
|
to_add_msg = inp[curr_len:]
|
||||||
@@ -80,12 +98,13 @@ class ConversationStore(BaseConvStore):
|
|||||||
# Serialize dict/list content to JSON string
|
# Serialize dict/list content to JSON string
|
||||||
if not isinstance(content, str):
|
if not isinstance(content, str):
|
||||||
content = json.dumps(content, ensure_ascii=False, indent=4)
|
content = json.dumps(content, ensure_ascii=False, indent=4)
|
||||||
self.add_message(conv_id, self._get_type(msg), content, curr_len + 1)
|
self.add_message(
|
||||||
|
conv_id, self._get_type(msg), content, curr_len + 1, pipeline_id
|
||||||
|
)
|
||||||
curr_len += 1
|
curr_len += 1
|
||||||
return curr_len
|
return curr_len
|
||||||
|
|
||||||
|
def _get_type(self, msg: BaseMessage) -> MessageType:
|
||||||
def _get_type(self, msg:BaseMessage) -> MessageType:
|
|
||||||
if isinstance(msg, HumanMessage):
|
if isinstance(msg, HumanMessage):
|
||||||
return MessageType.HUMAN
|
return MessageType.HUMAN
|
||||||
elif isinstance(msg, AIMessage):
|
elif isinstance(msg, AIMessage):
|
||||||
@@ -100,7 +119,9 @@ class ConversationPrinter(BaseConvStore):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.id_dic = {}
|
self.id_dic = {}
|
||||||
|
|
||||||
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
|
def record_message_list(
|
||||||
|
self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
|
||||||
|
):
|
||||||
inp = [e for e in inp if not isinstance(e, SystemMessage)]
|
inp = [e for e in inp if not isinstance(e, SystemMessage)]
|
||||||
curr_len = self.id_dic.get(conv_id, 0)
|
curr_len = self.id_dic.get(conv_id, 0)
|
||||||
to_print_msg = inp[curr_len:]
|
to_print_msg = inp[curr_len:]
|
||||||
@@ -113,9 +134,11 @@ class ConversationPrinter(BaseConvStore):
|
|||||||
else:
|
else:
|
||||||
self.id_dic[conv_id] += len(to_print_msg)
|
self.id_dic[conv_id] += len(to_print_msg)
|
||||||
|
|
||||||
|
|
||||||
CONV_STORE = ConversationStore()
|
CONV_STORE = ConversationStore()
|
||||||
# CONV_STORE = ConversationPrinter()
|
# CONV_STORE = ConversationPrinter()
|
||||||
|
|
||||||
|
|
||||||
def use_printer():
|
def use_printer():
|
||||||
global CONV_STORE
|
global CONV_STORE
|
||||||
CONV_STORE = ConversationPrinter()
|
CONV_STORE = ConversationPrinter()
|
||||||
|
|||||||
Reference in New Issue
Block a user