remove memory if new conv in device

This commit is contained in:
2026-02-02 13:37:42 +08:00
parent 776be9ee22
commit e99da5093e

View File

@@ -82,6 +82,7 @@ class PipelineConfig(KeyConfig):
class Pipeline:
def __init__(self, config:PipelineConfig):
self.config = config
self.thread_id_cache = {}
self.populate_module()
@@ -149,11 +150,15 @@ class Pipeline:
as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage]
"""
rm_id = self.get_remove_id(thread_id)
if rm_id:
self.graph.clear_memory(rm_id)
device_id = "0"
spl_ls = thread_id.split("_")
assert len(spl_ls) <= 2, "something wrong!"
if len(spl_ls) == 2:
thread_id, device_id = spl_ls
_, device_id = spl_ls
inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
"device_id":device_id}}
@@ -165,6 +170,31 @@ class Pipeline:
return self._stream_res(out, thread_id)
else:
return out
def get_remove_id(self, thread_id:str) -> bool:
"""
returns a id to remove if a new conversation has starte
"""
parts = thread_id.split("_")
if len(parts) < 2:
return None
assert len(parts) == 2, "should have exactly two parts"
thread_id, device_id = parts
c_th_id = self.thread_id_cache.get(device_id)
if c_th_id is None:
self.thread_id_cache[device_id] = thread_id
return None
elif c_th_id == thread_id:
return None
elif c_th_id != thread_id:
self.thread_id_cache[device_id] = thread_id
return f"{c_th_id}_{device_id}"
else:
assert 0, "BUG SHOULD NOT BE HERE"
async def ainvoke(self, *nargs, **kwargs):
"""Async version of invoke using LangGraph's native async support."""
@@ -196,6 +226,10 @@ class Pipeline:
as_stream (bool): if true, enable the thing to be streamable
as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage]
"""
rm_id = self.get_remove_id(thread_id)
if rm_id:
await self.graph.aclear_memory(rm_id)
# NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph
u = DEFAULT_PROMPT
@@ -203,7 +237,7 @@ class Pipeline:
spl_ls = thread_id.split("_")
assert len(spl_ls) <= 2, "something wrong!"
if len(spl_ls) == 2:
thread_id, device_id = spl_ls
_, device_id = spl_ls
print(f"\033[32m====================DEVICE ID: {device_id}=============================\033[0m")
inp_data = {"messages":[SystemMessage(u),