replace jax.tree.leaves with drop in

This commit is contained in:
2026-01-20 17:33:02 +08:00
parent 4290ce6756
commit ac43eb6f27
4 changed files with 20 additions and 19 deletions

View File

@@ -11,7 +11,7 @@ from lang_agent.components.tool_manager import ToolManager
from lang_agent.components.reit_llm import ReitLLM
from lang_agent.base import ToolNodeBase
from lang_agent.graphs.graph_states import State, ChattyToolState
from lang_agent.utils import make_llm, words_only
from lang_agent.utils import make_llm, words_only, tree_leaves
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain.agents import create_agent
@@ -269,7 +269,6 @@ def check_mcp_conn():
logger.warning(f"MCP server at {mcp_url} check failed: {e}")
def debug_tool_node():
import jax
import httpx
from langchain_core.messages.base import BaseMessageChunk
from lang_agent.components.tool_manager import ToolManagerConfig
@@ -302,7 +301,7 @@ def debug_tool_node():
print("Assistant: ", end="", flush=True)
for chunk in graph.stream(*input_data, stream_mode="updates"):
el = jax.tree.leaves(chunk)[-1]
el = tree_leaves(chunk)[-1]
el.pretty_print()
except Exception as e:
print(e)