diff --git a/lang_agent/base.py b/lang_agent/base.py index 8b8a155..1a2fcdc 100644 --- a/lang_agent/base.py +++ b/lang_agent/base.py @@ -6,13 +6,15 @@ from PIL import Image from io import BytesIO import matplotlib.pyplot as plt from loguru import logger -from lang_agent.utils import tree_leaves from langgraph.graph.state import CompiledStateGraph +from langgraph.checkpoint.memory import MemorySaver + from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage from langchain_core.messages.base import BaseMessageChunk from lang_agent.components.text_releaser import TextReleaser, AsyncTextReleaser +from lang_agent.utils import tree_leaves if TYPE_CHECKING: from lang_agent.graphs.graph_states import State @@ -173,6 +175,21 @@ class GraphBase(ABC): plt.show() else: return img + + def clear_memory(self): + # Clear the agent's (LangChain) memory if available + if hasattr(self, "memory") and self.memory is not None: + if isinstance(self.memory, MemorySaver): + for thread_id in self.memory.storage: + self.memory.delete_thread(thread_id) + + async def aclear_memory(self): + # Clear the agent's (LangChain) memory if available (async version) + if hasattr(self, "memory") and self.memory is not None: + if isinstance(self.memory, MemorySaver): + for thread_id in self.memory.storage: + await self.memory.adelete_thread(thread_id) + class ToolNodeBase(GraphBase):