add pipeline id to conv_store
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import List, Dict, Union
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
import os
|
import os
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage
|
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage
|
||||||
|
|
||||||
@@ -13,7 +14,13 @@ class MessageType(str, Enum):
|
|||||||
AI = "ai"
|
AI = "ai"
|
||||||
TOOL = "tool"
|
TOOL = "tool"
|
||||||
|
|
||||||
class ConversationStore:
|
|
||||||
|
class BaseConvStore(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ConversationStore(BaseConvStore):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
conn_str = os.environ.get("CONN_STR")
|
conn_str = os.environ.get("CONN_STR")
|
||||||
if conn_str is None:
|
if conn_str is None:
|
||||||
@@ -64,7 +71,7 @@ class ConversationStore:
|
|||||||
""", (conversation_id,))
|
""", (conversation_id,))
|
||||||
return cur.fetchall()
|
return cur.fetchall()
|
||||||
|
|
||||||
def record_message_list(self, conv_id:str, inp:List[BaseMessage]):
|
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
|
||||||
inp = [e for e in inp if not isinstance(e, SystemMessage)]
|
inp = [e for e in inp if not isinstance(e, SystemMessage)]
|
||||||
curr_len = self.get_conv_number(conv_id)
|
curr_len = self.get_conv_number(conv_id)
|
||||||
to_add_msg = inp[curr_len:]
|
to_add_msg = inp[curr_len:]
|
||||||
@@ -89,11 +96,11 @@ class ConversationStore:
|
|||||||
raise ValueError(f"Unknown message type: {type(msg)}")
|
raise ValueError(f"Unknown message type: {type(msg)}")
|
||||||
|
|
||||||
|
|
||||||
class ConversationPrinter:
|
class ConversationPrinter(BaseConvStore):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.id_dic = {}
|
self.id_dic = {}
|
||||||
|
|
||||||
def record_message_list(self, conv_id:str, inp:List[BaseMessage]):
|
def record_message_list(self, conv_id:str, inp:List[BaseMessage], pipeline_id:str=None):
|
||||||
inp = [e for e in inp if not isinstance(e, SystemMessage)]
|
inp = [e for e in inp if not isinstance(e, SystemMessage)]
|
||||||
curr_len = self.id_dic.get(conv_id, 0)
|
curr_len = self.id_dic.get(conv_id, 0)
|
||||||
to_print_msg = inp[curr_len:]
|
to_print_msg = inp[curr_len:]
|
||||||
|
|||||||
Reference in New Issue
Block a user