remove memory if new conv in device
This commit is contained in:
@@ -82,6 +82,7 @@ class PipelineConfig(KeyConfig):
|
|||||||
class Pipeline:
|
class Pipeline:
|
||||||
def __init__(self, config:PipelineConfig):
|
def __init__(self, config:PipelineConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.thread_id_cache = {}
|
||||||
|
|
||||||
self.populate_module()
|
self.populate_module()
|
||||||
|
|
||||||
@@ -149,11 +150,15 @@ class Pipeline:
|
|||||||
as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage]
|
as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
rm_id = self.get_remove_id(thread_id)
|
||||||
|
if rm_id:
|
||||||
|
self.graph.clear_memory(rm_id)
|
||||||
|
|
||||||
device_id = "0"
|
device_id = "0"
|
||||||
spl_ls = thread_id.split("_")
|
spl_ls = thread_id.split("_")
|
||||||
assert len(spl_ls) <= 2, "something wrong!"
|
assert len(spl_ls) <= 2, "something wrong!"
|
||||||
if len(spl_ls) == 2:
|
if len(spl_ls) == 2:
|
||||||
thread_id, device_id = spl_ls
|
_, device_id = spl_ls
|
||||||
|
|
||||||
inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
|
inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
|
||||||
"device_id":device_id}}
|
"device_id":device_id}}
|
||||||
@@ -166,6 +171,31 @@ class Pipeline:
|
|||||||
else:
|
else:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def get_remove_id(self, thread_id:str) -> bool:
|
||||||
|
"""
|
||||||
|
returns a id to remove if a new conversation has starte
|
||||||
|
"""
|
||||||
|
parts = thread_id.split("_")
|
||||||
|
if len(parts) < 2:
|
||||||
|
return None
|
||||||
|
|
||||||
|
assert len(parts) == 2, "should have exactly two parts"
|
||||||
|
|
||||||
|
thread_id, device_id = parts
|
||||||
|
c_th_id = self.thread_id_cache.get(device_id)
|
||||||
|
|
||||||
|
if c_th_id is None:
|
||||||
|
self.thread_id_cache[device_id] = thread_id
|
||||||
|
return None
|
||||||
|
elif c_th_id == thread_id:
|
||||||
|
return None
|
||||||
|
elif c_th_id != thread_id:
|
||||||
|
self.thread_id_cache[device_id] = thread_id
|
||||||
|
return f"{c_th_id}_{device_id}"
|
||||||
|
else:
|
||||||
|
assert 0, "BUG SHOULD NOT BE HERE"
|
||||||
|
|
||||||
|
|
||||||
async def ainvoke(self, *nargs, **kwargs):
|
async def ainvoke(self, *nargs, **kwargs):
|
||||||
"""Async version of invoke using LangGraph's native async support."""
|
"""Async version of invoke using LangGraph's native async support."""
|
||||||
out = await self.graph.ainvoke(*nargs, **kwargs)
|
out = await self.graph.ainvoke(*nargs, **kwargs)
|
||||||
@@ -196,6 +226,10 @@ class Pipeline:
|
|||||||
as_stream (bool): if true, enable the thing to be streamable
|
as_stream (bool): if true, enable the thing to be streamable
|
||||||
as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage]
|
as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage]
|
||||||
"""
|
"""
|
||||||
|
rm_id = self.get_remove_id(thread_id)
|
||||||
|
if rm_id:
|
||||||
|
await self.graph.aclear_memory(rm_id)
|
||||||
|
|
||||||
# NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph
|
# NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph
|
||||||
u = DEFAULT_PROMPT
|
u = DEFAULT_PROMPT
|
||||||
|
|
||||||
@@ -203,7 +237,7 @@ class Pipeline:
|
|||||||
spl_ls = thread_id.split("_")
|
spl_ls = thread_id.split("_")
|
||||||
assert len(spl_ls) <= 2, "something wrong!"
|
assert len(spl_ls) <= 2, "something wrong!"
|
||||||
if len(spl_ls) == 2:
|
if len(spl_ls) == 2:
|
||||||
thread_id, device_id = spl_ls
|
_, device_id = spl_ls
|
||||||
print(f"\033[32m====================DEVICE ID: {device_id}=============================\033[0m")
|
print(f"\033[32m====================DEVICE ID: {device_id}=============================\033[0m")
|
||||||
|
|
||||||
inp_data = {"messages":[SystemMessage(u),
|
inp_data = {"messages":[SystemMessage(u),
|
||||||
|
|||||||
Reference in New Issue
Block a user