This commit is contained in:
2026-01-22 16:02:18 +08:00
7 changed files with 48 additions and 20 deletions

View File

@@ -6,7 +6,7 @@ from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
from loguru import logger
import jax
from lang_agent.utils import tree_leaves
from langgraph.graph.state import CompiledStateGraph
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
@@ -67,7 +67,7 @@ class GraphBase(ABC):
else:
state = self.workflow.invoke({"inp": nargs})
msg_list = jax.tree.leaves(state)
msg_list = tree_leaves(state)
for e in msg_list:
if isinstance(e, BaseMessage):
@@ -90,7 +90,7 @@ class GraphBase(ABC):
else:
state = await self.workflow.ainvoke({"inp": nargs})
msg_list = jax.tree.leaves(state)
msg_list = tree_leaves(state)
for e in msg_list:
if isinstance(e, BaseMessage):

View File

@@ -1,11 +1,11 @@
from dataclasses import dataclass, field, is_dataclass
from typing import Type, List, Callable, Any, AsyncIterator
import tyro
import jax
from lang_agent.config import KeyConfig
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.base import GraphBase
from lang_agent.utils import tree_leaves
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
@@ -63,7 +63,7 @@ class ReactGraph(GraphBase):
else:
out = self.agent.invoke(*nargs, **kwargs)
msgs_list = jax.tree.leaves(out)
msgs_list = tree_leaves(out)
for e in msgs_list:
if isinstance(e, BaseMessage):
@@ -87,7 +87,7 @@ class ReactGraph(GraphBase):
else:
out = await self.agent.ainvoke(*nargs, **kwargs)
msgs_list = jax.tree.leaves(out)
msgs_list = tree_leaves(out)
for e in msgs_list:
if isinstance(e, BaseMessage):
@@ -112,14 +112,17 @@ if __name__ == "__main__":
"messages": [SystemMessage("you are a helpful bot named jarvis"),
HumanMessage("use the calculator tool to calculate 92*55 and say the answer")]
},{"configurable": {"thread_id": "3"}}
for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
node = metadata.get("langgraph_node")
if node not in ("model"):
print(node)
continue # skip router or other intermediate nodes
# Print only the final message content
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
print(chunk.content, end="", flush=True)
out = route.invoke(*nargs)
assert 0
# for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
# node = metadata.get("langgraph_node")
# if node not in ("model"):
# print(node)
# continue # skip router or other intermediate nodes
# # Print only the final message content
# if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
# print(chunk.content, end="", flush=True)

View File

@@ -3,7 +3,6 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterat
import tyro
from pydantic import BaseModel, Field
from loguru import logger
import jax
import os.path as osp
import commentjson
import glob

View File

@@ -11,7 +11,7 @@ from lang_agent.components.tool_manager import ToolManager
from lang_agent.components.reit_llm import ReitLLM
from lang_agent.base import ToolNodeBase
from lang_agent.graphs.graph_states import State, ChattyToolState
from lang_agent.utils import make_llm, words_only
from lang_agent.utils import make_llm, words_only, tree_leaves
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain.agents import create_agent
@@ -269,7 +269,6 @@ def check_mcp_conn():
logger.warning(f"MCP server at {mcp_url} check failed: {e}")
def debug_tool_node():
import jax
import httpx
from langchain_core.messages.base import BaseMessageChunk
from lang_agent.components.tool_manager import ToolManagerConfig
@@ -302,7 +301,7 @@ def debug_tool_node():
print("Assistant: ", end="", flush=True)
for chunk in graph.stream(*input_data, stream_mode="updates"):
el = jax.tree.leaves(chunk)[-1]
el = tree_leaves(chunk)[-1]
el.pretty_print()
except Exception as e:
print(e)

View File

@@ -21,6 +21,28 @@ def make_llm(model="qwen-plus",
return llm
def tree_leaves(tree):
    """
    Extract all leaf values from a nested structure of dicts, lists, and tuples.

    Drop-in replacement for ``jax.tree.leaves`` on plain Python containers:

    * dict values are visited in sorted-key order (JAX sorts dict keys
      when flattening), and
    * ``None`` is treated as an empty subtree and never emitted as a leaf,
      matching JAX, which registers ``None`` as a container with no leaves.

    Args:
        tree: An arbitrarily nested dict / list / tuple structure, or a
            single leaf value.

    Returns:
        list: All leaves in deterministic depth-first order.
    """
    leaves = []
    stack = [tree]
    while stack:
        node = stack.pop()
        if node is None:
            # jax.tree.leaves drops None (empty pytree), so we do too
            continue
        if isinstance(node, dict):
            # JAX flattens dicts in sorted-key order
            sorted_values = [node[k] for k in sorted(node.keys())]
            # reversed() keeps depth-first order correct, since we pop
            # from the end of the stack
            stack.extend(reversed(sorted_values))
        elif isinstance(node, (list, tuple)):
            stack.extend(reversed(node))
        else:
            leaves.append(node)
    return leaves
# Matches any single character that is NOT a CJK unified ideograph
# (U+4E00..U+9FFF), an ASCII letter/digit, an underscore, or whitespace.
# NOTE(review): presumably used by words_only() below to strip
# punctuation/symbols — function body not fully visible here, confirm.
NON_WORD_PATTERN = re.compile(r'[^\u4e00-\u9fffA-Za-z0-9_\s]')
def words_only(text):
"""

View File

@@ -21,7 +21,6 @@ dependencies = [
"fastapi",
"matplotlib",
"Pillow",
"jax",
"commentjson",
"pandas",
"asgiref"

View File

@@ -135,6 +135,12 @@ examples = [
"answer": "我一直在呢,随时陪你聊聊天、喝杯茶",
}
},
{
"inputs": {"text": "介绍一下你自己"},
"outputs": {
"answer": "我叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶",
}
},
]
cli = Client()