replace jax.tree.leaves with drop in

This commit is contained in:
2026-01-20 17:33:02 +08:00
parent 4290ce6756
commit ac43eb6f27
4 changed files with 20 additions and 19 deletions

View File

@@ -1,11 +1,11 @@
from dataclasses import dataclass, field, is_dataclass
from typing import Type, List, Callable, Any, AsyncIterator
import tyro
import jax
from lang_agent.config import KeyConfig
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.base import GraphBase
from lang_agent.utils import tree_leaves
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
@@ -63,7 +63,7 @@ class ReactGraph(GraphBase):
else:
out = self.agent.invoke(*nargs, **kwargs)
msgs_list = jax.tree.leaves(out)
msgs_list = tree_leaves(out)
for e in msgs_list:
if isinstance(e, BaseMessage):
@@ -87,7 +87,7 @@ class ReactGraph(GraphBase):
else:
out = await self.agent.ainvoke(*nargs, **kwargs)
msgs_list = jax.tree.leaves(out)
msgs_list = tree_leaves(out)
for e in msgs_list:
if isinstance(e, BaseMessage):
@@ -112,14 +112,17 @@ if __name__ == "__main__":
"messages": [SystemMessage("you are a helpful bot named jarvis"),
HumanMessage("use the calculator tool to calculate 92*55 and say the answer")]
},{"configurable": {"thread_id": "3"}}
for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
node = metadata.get("langgraph_node")
if node not in ("model"):
print(node)
continue # skip router or other intermediate nodes
# Print only the final message content
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
print(chunk.content, end="", flush=True)
out = route.invoke(*nargs)
assert 0
# for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
# node = metadata.get("langgraph_node")
# if node not in ("model"):
# print(node)
# continue # skip router or other intermediate nodes
# # Print only the final message content
# if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
# print(chunk.content, end="", flush=True)