replace jax.tree.leaves with drop in
This commit is contained in:
@@ -6,7 +6,7 @@ from PIL import Image
|
||||
from io import BytesIO
|
||||
import matplotlib.pyplot as plt
|
||||
from loguru import logger
|
||||
import jax
|
||||
from lang_agent.utils import tree_leaves
|
||||
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||
@@ -67,7 +67,7 @@ class GraphBase(ABC):
|
||||
else:
|
||||
state = self.workflow.invoke({"inp": nargs})
|
||||
|
||||
msg_list = jax.tree.leaves(state)
|
||||
msg_list = tree_leaves(state)
|
||||
|
||||
for e in msg_list:
|
||||
if isinstance(e, BaseMessage):
|
||||
@@ -90,7 +90,7 @@ class GraphBase(ABC):
|
||||
else:
|
||||
state = await self.workflow.ainvoke({"inp": nargs})
|
||||
|
||||
msg_list = jax.tree.leaves(state)
|
||||
msg_list = tree_leaves(state)
|
||||
|
||||
for e in msg_list:
|
||||
if isinstance(e, BaseMessage):
|
||||
|
||||
Reference in New Issue
Block a user