replace jax.tree.leaves with drop in
This commit is contained in:
@@ -6,7 +6,7 @@ from PIL import Image
|
||||
from io import BytesIO
|
||||
import matplotlib.pyplot as plt
|
||||
from loguru import logger
|
||||
import jax
|
||||
from lang_agent.utils import tree_leaves
|
||||
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||
@@ -67,7 +67,7 @@ class GraphBase(ABC):
|
||||
else:
|
||||
state = self.workflow.invoke({"inp": nargs})
|
||||
|
||||
msg_list = jax.tree.leaves(state)
|
||||
msg_list = tree_leaves(state)
|
||||
|
||||
for e in msg_list:
|
||||
if isinstance(e, BaseMessage):
|
||||
@@ -90,7 +90,7 @@ class GraphBase(ABC):
|
||||
else:
|
||||
state = await self.workflow.ainvoke({"inp": nargs})
|
||||
|
||||
msg_list = jax.tree.leaves(state)
|
||||
msg_list = tree_leaves(state)
|
||||
|
||||
for e in msg_list:
|
||||
if isinstance(e, BaseMessage):
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from typing import Type, List, Callable, Any, AsyncIterator
|
||||
import tyro
|
||||
import jax
|
||||
|
||||
from lang_agent.config import KeyConfig
|
||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||
from lang_agent.base import GraphBase
|
||||
from lang_agent.utils import tree_leaves
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
||||
@@ -63,7 +63,7 @@ class ReactGraph(GraphBase):
|
||||
else:
|
||||
out = self.agent.invoke(*nargs, **kwargs)
|
||||
|
||||
msgs_list = jax.tree.leaves(out)
|
||||
msgs_list = tree_leaves(out)
|
||||
|
||||
for e in msgs_list:
|
||||
if isinstance(e, BaseMessage):
|
||||
@@ -87,7 +87,7 @@ class ReactGraph(GraphBase):
|
||||
else:
|
||||
out = await self.agent.ainvoke(*nargs, **kwargs)
|
||||
|
||||
msgs_list = jax.tree.leaves(out)
|
||||
msgs_list = tree_leaves(out)
|
||||
|
||||
for e in msgs_list:
|
||||
if isinstance(e, BaseMessage):
|
||||
@@ -112,14 +112,17 @@ if __name__ == "__main__":
|
||||
"messages": [SystemMessage("you are a helpful bot named jarvis"),
|
||||
HumanMessage("use the calculator tool to calculate 92*55 and say the answer")]
|
||||
},{"configurable": {"thread_id": "3"}}
|
||||
|
||||
for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
|
||||
node = metadata.get("langgraph_node")
|
||||
if node not in ("model"):
|
||||
print(node)
|
||||
continue # skip router or other intermediate nodes
|
||||
|
||||
# Print only the final message content
|
||||
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
||||
print(chunk.content, end="", flush=True)
|
||||
out = route.invoke(*nargs)
|
||||
assert 0
|
||||
|
||||
# for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
|
||||
# node = metadata.get("langgraph_node")
|
||||
# if node not in ("model"):
|
||||
# print(node)
|
||||
# continue # skip router or other intermediate nodes
|
||||
|
||||
# # Print only the final message content
|
||||
# if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
||||
# print(chunk.content, end="", flush=True)
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterat
|
||||
import tyro
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
import jax
|
||||
import os.path as osp
|
||||
import commentjson
|
||||
import glob
|
||||
|
||||
@@ -11,7 +11,7 @@ from lang_agent.components.tool_manager import ToolManager
|
||||
from lang_agent.components.reit_llm import ReitLLM
|
||||
from lang_agent.base import ToolNodeBase
|
||||
from lang_agent.graphs.graph_states import State, ChattyToolState
|
||||
from lang_agent.utils import make_llm, words_only
|
||||
from lang_agent.utils import make_llm, words_only, tree_leaves
|
||||
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
|
||||
from langchain.agents import create_agent
|
||||
@@ -269,7 +269,6 @@ def check_mcp_conn():
|
||||
logger.warning(f"MCP server at {mcp_url} check failed: {e}")
|
||||
|
||||
def debug_tool_node():
|
||||
import jax
|
||||
import httpx
|
||||
from langchain_core.messages.base import BaseMessageChunk
|
||||
from lang_agent.components.tool_manager import ToolManagerConfig
|
||||
@@ -302,7 +301,7 @@ def debug_tool_node():
|
||||
|
||||
print("Assistant: ", end="", flush=True)
|
||||
for chunk in graph.stream(*input_data, stream_mode="updates"):
|
||||
el = jax.tree.leaves(chunk)[-1]
|
||||
el = tree_leaves(chunk)[-1]
|
||||
el.pretty_print()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
Reference in New Issue
Block a user