diff --git a/lang_agent/components/conv_store.py b/lang_agent/components/conv_store.py index 9986684..1645761 100644 --- a/lang_agent/components/conv_store.py +++ b/lang_agent/components/conv_store.py @@ -62,7 +62,7 @@ class ConversationStore: """, (conversation_id,)) return cur.fetchall() - def record_messages(self, conv_id:str, inp:List[BaseMessage]): + def record_message_list(self, conv_id:str, inp:List[BaseMessage]): inp = [e for e in inp if not isinstance(e, SystemMessage)] curr_len = self.get_conv_number(conv_id) to_add_msg = inp[curr_len:] @@ -82,4 +82,11 @@ class ConversationStore: else: raise ValueError(f"Unknown message type: {type(msg)}") + +class ConversationPrinter: + def __init__(self): + self.id_dic = {} + + + CONV_STORE = ConversationStore() \ No newline at end of file diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py index 28cca47..995856e 100644 --- a/lang_agent/pipeline.py +++ b/lang_agent/pipeline.py @@ -134,7 +134,7 @@ class Pipeline: yield chunk else: logger.info("logged message") - CONV_STORE.record_messages(conv_id, chunk) + CONV_STORE.record_message_list(conv_id, chunk) async def _astream_res(self, out): """Async version of _stream_res for async generators."""