diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py index 36c5686..8ab9783 100644 --- a/lang_agent/pipeline.py +++ b/lang_agent/pipeline.py @@ -82,6 +82,7 @@ class PipelineConfig(KeyConfig): class Pipeline: def __init__(self, config:PipelineConfig): self.config = config + self.thread_id_cache = {} self.populate_module() @@ -149,11 +150,15 @@ class Pipeline: as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage] """ + rm_id = self.get_remove_id(thread_id) + if rm_id: + self.graph.clear_memory(rm_id) + device_id = "0" spl_ls = thread_id.split("_") assert len(spl_ls) <= 2, "something wrong!" if len(spl_ls) == 2: - thread_id, device_id = spl_ls + _, device_id = spl_ls inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id, "device_id":device_id}} @@ -165,6 +170,31 @@ class Pipeline: return self._stream_res(out, thread_id) else: return out + + def get_remove_id(self, thread_id:str) -> bool: + """ + returns a id to remove if a new conversation has starte + """ + parts = thread_id.split("_") + if len(parts) < 2: + return None + + assert len(parts) == 2, "should have exactly two parts" + + thread_id, device_id = parts + c_th_id = self.thread_id_cache.get(device_id) + + if c_th_id is None: + self.thread_id_cache[device_id] = thread_id + return None + elif c_th_id == thread_id: + return None + elif c_th_id != thread_id: + self.thread_id_cache[device_id] = thread_id + return f"{c_th_id}_{device_id}" + else: + assert 0, "BUG SHOULD NOT BE HERE" + async def ainvoke(self, *nargs, **kwargs): """Async version of invoke using LangGraph's native async support.""" @@ -196,6 +226,10 @@ class Pipeline: as_stream (bool): if true, enable the thing to be streamable as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage] """ + rm_id = self.get_remove_id(thread_id) + if rm_id: + await self.graph.aclear_memory(rm_id) + # NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph u = DEFAULT_PROMPT @@ -203,7 +237,7 @@ class Pipeline: spl_ls = thread_id.split("_") assert len(spl_ls) <= 2, "something wrong!" if len(spl_ls) == 2: - thread_id, device_id = spl_ls + _, device_id = spl_ls print(f"\033[32m====================DEVICE ID: {device_id}=============================\033[0m") inp_data = {"messages":[SystemMessage(u),