From ac43eb6f27dfe957668711332dfa8579e9d73e35 Mon Sep 17 00:00:00 2001 From: goulustis Date: Tue, 20 Jan 2026 17:33:02 +0800 Subject: [PATCH 1/5] replace jax.tree.leaves with drop in --- lang_agent/base.py | 6 +++--- lang_agent/graphs/react.py | 27 +++++++++++++++------------ lang_agent/graphs/routing.py | 1 - lang_agent/graphs/tool_nodes.py | 5 ++--- 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/lang_agent/base.py b/lang_agent/base.py index 7e33f35..0e32e6d 100644 --- a/lang_agent/base.py +++ b/lang_agent/base.py @@ -6,7 +6,7 @@ from PIL import Image from io import BytesIO import matplotlib.pyplot as plt from loguru import logger -import jax +from lang_agent.utils import tree_leaves from langgraph.graph.state import CompiledStateGraph from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage @@ -67,7 +67,7 @@ class GraphBase(ABC): else: state = self.workflow.invoke({"inp": nargs}) - msg_list = jax.tree.leaves(state) + msg_list = tree_leaves(state) for e in msg_list: if isinstance(e, BaseMessage): @@ -90,7 +90,7 @@ class GraphBase(ABC): else: state = await self.workflow.ainvoke({"inp": nargs}) - msg_list = jax.tree.leaves(state) + msg_list = tree_leaves(state) for e in msg_list: if isinstance(e, BaseMessage): diff --git a/lang_agent/graphs/react.py b/lang_agent/graphs/react.py index 9960b5a..a488336 100644 --- a/lang_agent/graphs/react.py +++ b/lang_agent/graphs/react.py @@ -1,11 +1,11 @@ from dataclasses import dataclass, field, is_dataclass from typing import Type, List, Callable, Any, AsyncIterator import tyro -import jax from lang_agent.config import KeyConfig from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig from lang_agent.base import GraphBase +from lang_agent.utils import tree_leaves from langchain.chat_models import init_chat_model from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage @@ -63,7 +63,7 @@ class ReactGraph(GraphBase): else: out = self.agent.invoke(*nargs, **kwargs) - msgs_list = jax.tree.leaves(out) + msgs_list = tree_leaves(out) for e in msgs_list: if isinstance(e, BaseMessage): @@ -87,7 +87,7 @@ class ReactGraph(GraphBase): else: out = await self.agent.ainvoke(*nargs, **kwargs) - msgs_list = jax.tree.leaves(out) + msgs_list = tree_leaves(out) for e in msgs_list: if isinstance(e, BaseMessage): @@ -112,14 +112,17 @@ if __name__ == "__main__": "messages": [SystemMessage("you are a helpful bot named jarvis"), HumanMessage("use the calculator tool to calculate 92*55 and say the answer")] },{"configurable": {"thread_id": "3"}} - - for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"): - node = metadata.get("langgraph_node") - if node not in ("model"): - print(node) - continue # skip router or other intermediate nodes - # Print only the final message content - if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None): - print(chunk.content, end="", flush=True) + out = route.invoke(*nargs) + assert 0 + + # for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"): + # node = metadata.get("langgraph_node") + # if node not in ("model"): + # print(node) + # continue # skip router or other intermediate nodes + + # # Print only the final message content + # if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None): + # print(chunk.content, end="", flush=True) \ No newline at end of file diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index 29ab0a5..b135134 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -3,7 +3,6 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterat import tyro from pydantic import BaseModel, Field from loguru import logger -import jax import os.path as osp import commentjson import glob diff --git a/lang_agent/graphs/tool_nodes.py b/lang_agent/graphs/tool_nodes.py index b59fb4d..3923f2d 100644 --- a/lang_agent/graphs/tool_nodes.py +++ b/lang_agent/graphs/tool_nodes.py @@ -11,7 +11,7 @@ from lang_agent.components.tool_manager import ToolManager from lang_agent.components.reit_llm import ReitLLM from lang_agent.base import ToolNodeBase from lang_agent.graphs.graph_states import State, ChattyToolState -from lang_agent.utils import make_llm, words_only +from lang_agent.utils import make_llm, words_only, tree_leaves from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage from langchain.agents import create_agent @@ -269,7 +269,6 @@ def check_mcp_conn(): logger.warning(f"MCP server at {mcp_url} check failed: {e}") def debug_tool_node(): - import jax import httpx from langchain_core.messages.base import BaseMessageChunk from lang_agent.components.tool_manager import ToolManagerConfig @@ -302,7 +301,7 @@ def debug_tool_node(): print("Assistant: ", end="", flush=True) for chunk in graph.stream(*input_data, stream_mode="updates"): - el = jax.tree.leaves(chunk)[-1] + el = tree_leaves(chunk)[-1] el.pretty_print() except Exception as e: print(e) From b7fce5d973a970b50a45d5e74b6d5574b5699499 Mon Sep 17 00:00:00 2001 From: goulustis Date: Tue, 20 Jan 2026 17:33:13 +0800 Subject: [PATCH 2/5] jax drop in replacement --- lang_agent/utils.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/lang_agent/utils.py b/lang_agent/utils.py index d5ce27e..3f4cce0 100644 --- a/lang_agent/utils.py +++ b/lang_agent/utils.py @@ -21,6 +21,26 @@ def make_llm(model="qwen-plus", return llm +def tree_leaves(tree): + """ + Extracts all leaf values from a nested structure (dict, list, tuple). + Drop-in replacement for jax.tree.leaves. + """ + leaves = [] + stack = [tree] + + while stack: + node = stack.pop() + if isinstance(node, dict): + stack.extend(reversed(node.values())) + elif isinstance(node, (list, tuple)): + stack.extend(reversed(node)) + else: + leaves.append(node) + + return leaves + + NON_WORD_PATTERN = re.compile(r'[^\u4e00-\u9fffA-Za-z0-9_\s]') def words_only(text): """ From 0839c9bcee98915c1480d7dcbe204c71d8412124 Mon Sep 17 00:00:00 2001 From: goulustis Date: Tue, 20 Jan 2026 17:33:28 +0800 Subject: [PATCH 3/5] remove jax requirement --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9068613..2e0b2c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,6 @@ dependencies = [ "fastapi", "matplotlib", "Pillow", - "jax", "commentjson", "pandas", "asgiref" From afe94fd9f6adf68e9ee9f5b0f4cec42fb2fc64af Mon Sep 17 00:00:00 2001 From: goulustis Date: Tue, 20 Jan 2026 17:53:32 +0800 Subject: [PATCH 4/5] bug fix --- lang_agent/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lang_agent/utils.py b/lang_agent/utils.py index 3f4cce0..87f65f8 100644 --- a/lang_agent/utils.py +++ b/lang_agent/utils.py @@ -32,7 +32,9 @@ def tree_leaves(tree): while stack: node = stack.pop() if isinstance(node, dict): - stack.extend(reversed(node.values())) + # JAX sorts dict keys alphabetically + sorted_values = [node[k] for k in sorted(node.keys())] + stack.extend(reversed(sorted_values)) elif isinstance(node, (list, tuple)): stack.extend(reversed(node)) else: From 5514f73aa550b86076ce82ce2a81ae9453705d6f Mon Sep 17 00:00:00 2001 From: goulustis Date: Tue, 20 Jan 2026 17:54:03 +0800 Subject: [PATCH 5/5] update example in dataset --- scripts/make_eval_dataset.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scripts/make_eval_dataset.py b/scripts/make_eval_dataset.py index 1ffafad..595ed24 100644 --- a/scripts/make_eval_dataset.py +++ b/scripts/make_eval_dataset.py @@ -135,6 +135,12 @@ examples = [ "answer": "我一直在呢,随时陪你聊聊天、喝杯茶", } }, + { + "inputs": {"text": "介绍一下你自己"}, + "outputs": { + "answer": "我叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶", + } + }, ] cli = Client()