diff --git a/lang_agent/base.py b/lang_agent/base.py index 7e33f35..0e32e6d 100644 --- a/lang_agent/base.py +++ b/lang_agent/base.py @@ -6,7 +6,7 @@ from PIL import Image from io import BytesIO import matplotlib.pyplot as plt from loguru import logger -import jax +from lang_agent.utils import tree_leaves from langgraph.graph.state import CompiledStateGraph from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage @@ -67,7 +67,7 @@ class GraphBase(ABC): else: state = self.workflow.invoke({"inp": nargs}) - msg_list = jax.tree.leaves(state) + msg_list = tree_leaves(state) for e in msg_list: if isinstance(e, BaseMessage): @@ -90,7 +90,7 @@ class GraphBase(ABC): else: state = await self.workflow.ainvoke({"inp": nargs}) - msg_list = jax.tree.leaves(state) + msg_list = tree_leaves(state) for e in msg_list: if isinstance(e, BaseMessage): diff --git a/lang_agent/graphs/react.py b/lang_agent/graphs/react.py index 9960b5a..a488336 100644 --- a/lang_agent/graphs/react.py +++ b/lang_agent/graphs/react.py @@ -1,11 +1,11 @@ from dataclasses import dataclass, field, is_dataclass from typing import Type, List, Callable, Any, AsyncIterator import tyro -import jax from lang_agent.config import KeyConfig from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig from lang_agent.base import GraphBase +from lang_agent.utils import tree_leaves from langchain.chat_models import init_chat_model from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage @@ -63,7 +63,7 @@ class ReactGraph(GraphBase): else: out = self.agent.invoke(*nargs, **kwargs) - msgs_list = jax.tree.leaves(out) + msgs_list = tree_leaves(out) for e in msgs_list: if isinstance(e, BaseMessage): @@ -87,7 +87,7 @@ class ReactGraph(GraphBase): else: out = await self.agent.ainvoke(*nargs, **kwargs) - msgs_list = jax.tree.leaves(out) + msgs_list = tree_leaves(out) for e in msgs_list: if isinstance(e, BaseMessage): @@ -112,14 +112,17 @@ if __name__ == "__main__": "messages": [SystemMessage("you are a helpful bot named jarvis"), HumanMessage("use the calculator tool to calculate 92*55 and say the answer")] },{"configurable": {"thread_id": "3"}} - - for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"): - node = metadata.get("langgraph_node") - if node not in ("model"): - print(node) - continue # skip router or other intermediate nodes - # Print only the final message content - if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None): - print(chunk.content, end="", flush=True) + out = route.invoke(*nargs) + assert 0 + + # for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"): + # node = metadata.get("langgraph_node") + # if node not in ("model"): + # print(node) + # continue # skip router or other intermediate nodes + + # # Print only the final message content + # if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None): + # print(chunk.content, end="", flush=True) \ No newline at end of file diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index 29ab0a5..b135134 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -3,7 +3,6 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterat import tyro from pydantic import BaseModel, Field from loguru import logger -import jax import os.path as osp import commentjson import glob diff --git a/lang_agent/graphs/tool_nodes.py b/lang_agent/graphs/tool_nodes.py index b59fb4d..3923f2d 100644 --- a/lang_agent/graphs/tool_nodes.py +++ b/lang_agent/graphs/tool_nodes.py @@ -11,7 +11,7 @@ from lang_agent.components.tool_manager import ToolManager from lang_agent.components.reit_llm import ReitLLM from lang_agent.base import ToolNodeBase from lang_agent.graphs.graph_states import State, ChattyToolState -from lang_agent.utils import make_llm, words_only +from lang_agent.utils import make_llm, words_only, tree_leaves from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage from langchain.agents import create_agent @@ -269,7 +269,6 @@ def check_mcp_conn(): logger.warning(f"MCP server at {mcp_url} check failed: {e}") def debug_tool_node(): - import jax import httpx from langchain_core.messages.base import BaseMessageChunk from lang_agent.components.tool_manager import ToolManagerConfig @@ -302,7 +301,7 @@ def debug_tool_node(): print("Assistant: ", end="", flush=True) for chunk in graph.stream(*input_data, stream_mode="updates"): - el = jax.tree.leaves(chunk)[-1] + el = tree_leaves(chunk)[-1] el.pretty_print() except Exception as e: print(e) diff --git a/lang_agent/utils.py b/lang_agent/utils.py index d5ce27e..87f65f8 100644 --- a/lang_agent/utils.py +++ b/lang_agent/utils.py @@ -21,6 +21,28 @@ def make_llm(model="qwen-plus", return llm +def tree_leaves(tree): + """ + Extracts all leaf values from a nested structure (dict, list, tuple). + Drop-in replacement for jax.tree.leaves. + """ + leaves = [] + stack = [tree] + + while stack: + node = stack.pop() + if isinstance(node, dict): + # JAX sorts dict keys alphabetically + sorted_values = [node[k] for k in sorted(node.keys())] + stack.extend(reversed(sorted_values)) + elif isinstance(node, (list, tuple)): + stack.extend(reversed(node)) + else: + leaves.append(node) + + return leaves + + NON_WORD_PATTERN = re.compile(r'[^\u4e00-\u9fffA-Za-z0-9_\s]') def words_only(text): """ diff --git a/pyproject.toml b/pyproject.toml index 9068613..2e0b2c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,6 @@ dependencies = [ "fastapi", "matplotlib", "Pillow", - "jax", "commentjson", "pandas", "asgiref" diff --git a/scripts/make_eval_dataset.py b/scripts/make_eval_dataset.py index 1ffafad..595ed24 100644 --- a/scripts/make_eval_dataset.py +++ b/scripts/make_eval_dataset.py @@ -135,6 +135,12 @@ examples = [ "answer": "我一直在呢,随时陪你聊聊天、喝杯茶", } }, + { + "inputs": {"text": "介绍一下你自己"}, + "outputs": { + "answer": "我叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶", + } + }, ] cli = Client()