replace jax.tree.leaves with drop in

This commit is contained in:
2026-01-20 17:33:02 +08:00
parent 4290ce6756
commit ac43eb6f27
4 changed files with 20 additions and 19 deletions

View File

@@ -6,7 +6,7 @@ from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
from loguru import logger
import jax
from lang_agent.utils import tree_leaves
from langgraph.graph.state import CompiledStateGraph
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
@@ -67,7 +67,7 @@ class GraphBase(ABC):
else:
state = self.workflow.invoke({"inp": nargs})
msg_list = jax.tree.leaves(state)
msg_list = tree_leaves(state)
for e in msg_list:
if isinstance(e, BaseMessage):
@@ -90,7 +90,7 @@ class GraphBase(ABC):
else:
state = await self.workflow.ainvoke({"inp": nargs})
msg_list = jax.tree.leaves(state)
msg_list = tree_leaves(state)
for e in msg_list:
if isinstance(e, BaseMessage):