Remove stored conversation memory when a device starts a new conversation
This commit is contained in:
@@ -82,6 +82,7 @@ class PipelineConfig(KeyConfig):
|
||||
class Pipeline:
|
||||
def __init__(self, config:PipelineConfig):
    """Initialize the pipeline from *config* and set up per-device conversation tracking."""
    self.config = config
    # Maps device_id -> most recently seen thread_id for that device.
    # Read/written by get_remove_id() to detect when a device starts a
    # new conversation (so the old conversation's memory can be cleared).
    self.thread_id_cache = {}

    # Builds the underlying graph/modules; defined elsewhere in this class.
    self.populate_module()
|
||||
|
||||
@@ -149,11 +150,15 @@ class Pipeline:
|
||||
as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage]
|
||||
"""
|
||||
|
||||
rm_id = self.get_remove_id(thread_id)
|
||||
if rm_id:
|
||||
self.graph.clear_memory(rm_id)
|
||||
|
||||
device_id = "0"
|
||||
spl_ls = thread_id.split("_")
|
||||
assert len(spl_ls) <= 2, "something wrong!"
|
||||
if len(spl_ls) == 2:
|
||||
thread_id, device_id = spl_ls
|
||||
_, device_id = spl_ls
|
||||
|
||||
inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
|
||||
"device_id":device_id}}
|
||||
@@ -166,6 +171,31 @@ class Pipeline:
|
||||
else:
|
||||
return out
|
||||
|
||||
def get_remove_id(self, thread_id:str) -> bool:
|
||||
"""
|
||||
returns a id to remove if a new conversation has starte
|
||||
"""
|
||||
parts = thread_id.split("_")
|
||||
if len(parts) < 2:
|
||||
return None
|
||||
|
||||
assert len(parts) == 2, "should have exactly two parts"
|
||||
|
||||
thread_id, device_id = parts
|
||||
c_th_id = self.thread_id_cache.get(device_id)
|
||||
|
||||
if c_th_id is None:
|
||||
self.thread_id_cache[device_id] = thread_id
|
||||
return None
|
||||
elif c_th_id == thread_id:
|
||||
return None
|
||||
elif c_th_id != thread_id:
|
||||
self.thread_id_cache[device_id] = thread_id
|
||||
return f"{c_th_id}_{device_id}"
|
||||
else:
|
||||
assert 0, "BUG SHOULD NOT BE HERE"
|
||||
|
||||
|
||||
async def ainvoke(self, *nargs, **kwargs):
|
||||
"""Async version of invoke using LangGraph's native async support."""
|
||||
out = await self.graph.ainvoke(*nargs, **kwargs)
|
||||
@@ -196,6 +226,10 @@ class Pipeline:
|
||||
as_stream (bool): if true, enable the thing to be streamable
|
||||
as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage]
|
||||
"""
|
||||
rm_id = self.get_remove_id(thread_id)
|
||||
if rm_id:
|
||||
await self.graph.aclear_memory(rm_id)
|
||||
|
||||
# NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph
|
||||
u = DEFAULT_PROMPT
|
||||
|
||||
@@ -203,7 +237,7 @@ class Pipeline:
|
||||
spl_ls = thread_id.split("_")
|
||||
assert len(spl_ls) <= 2, "something wrong!"
|
||||
if len(spl_ls) == 2:
|
||||
thread_id, device_id = spl_ls
|
||||
_, device_id = spl_ls
|
||||
print(f"\033[32m====================DEVICE ID: {device_id}=============================\033[0m")
|
||||
|
||||
inp_data = {"messages":[SystemMessage(u),
|
||||
|
||||
Reference in New Issue
Block a user