diff --git a/lang_agent/base.py b/lang_agent/base.py index 7e33f35..0e32e6d 100644 --- a/lang_agent/base.py +++ b/lang_agent/base.py @@ -6,7 +6,7 @@ from PIL import Image from io import BytesIO import matplotlib.pyplot as plt from loguru import logger -import jax +from lang_agent.utils import tree_leaves from langgraph.graph.state import CompiledStateGraph from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage @@ -67,7 +67,7 @@ class GraphBase(ABC): else: state = self.workflow.invoke({"inp": nargs}) - msg_list = jax.tree.leaves(state) + msg_list = tree_leaves(state) for e in msg_list: if isinstance(e, BaseMessage): @@ -90,7 +90,7 @@ class GraphBase(ABC): else: state = await self.workflow.ainvoke({"inp": nargs}) - msg_list = jax.tree.leaves(state) + msg_list = tree_leaves(state) for e in msg_list: if isinstance(e, BaseMessage): diff --git a/lang_agent/graphs/react.py b/lang_agent/graphs/react.py index 9960b5a..a488336 100644 --- a/lang_agent/graphs/react.py +++ b/lang_agent/graphs/react.py @@ -1,11 +1,11 @@ from dataclasses import dataclass, field, is_dataclass from typing import Type, List, Callable, Any, AsyncIterator import tyro -import jax from lang_agent.config import KeyConfig from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig from lang_agent.base import GraphBase +from lang_agent.utils import tree_leaves from langchain.chat_models import init_chat_model from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage @@ -63,7 +63,7 @@ class ReactGraph(GraphBase): else: out = self.agent.invoke(*nargs, **kwargs) - msgs_list = jax.tree.leaves(out) + msgs_list = tree_leaves(out) for e in msgs_list: if isinstance(e, BaseMessage): @@ -87,7 +87,7 @@ class ReactGraph(GraphBase): else: out = await self.agent.ainvoke(*nargs, **kwargs) - msgs_list = jax.tree.leaves(out) + msgs_list = tree_leaves(out) for e in msgs_list: if isinstance(e, BaseMessage): @@ -112,14 +112,17 @@ if __name__ == "__main__": "messages": [SystemMessage("you are a helpful bot named jarvis"), HumanMessage("use the calculator tool to calculate 92*55 and say the answer")] },{"configurable": {"thread_id": "3"}} - - for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"): - node = metadata.get("langgraph_node") - if node not in ("model"): - print(node) - continue # skip router or other intermediate nodes - # Print only the final message content - if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None): - print(chunk.content, end="", flush=True) + out = route.invoke(*nargs) + assert 0 + + # for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"): + # node = metadata.get("langgraph_node") + # if node not in ("model"): + # print(node) + # continue # skip router or other intermediate nodes + + # # Print only the final message content + # if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None): + # print(chunk.content, end="", flush=True) \ No newline at end of file diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index 29ab0a5..b135134 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -3,7 +3,6 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterat import tyro from pydantic import BaseModel, Field from loguru import logger -import jax import os.path as osp import commentjson import glob diff --git a/lang_agent/graphs/tool_nodes.py b/lang_agent/graphs/tool_nodes.py index b59fb4d..3923f2d 100644 --- a/lang_agent/graphs/tool_nodes.py +++ b/lang_agent/graphs/tool_nodes.py @@ -11,7 +11,7 @@ from lang_agent.components.tool_manager import ToolManager from lang_agent.components.reit_llm import ReitLLM from lang_agent.base import ToolNodeBase from lang_agent.graphs.graph_states import State, ChattyToolState -from lang_agent.utils import make_llm, words_only +from lang_agent.utils import make_llm, words_only, tree_leaves from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage from langchain.agents import create_agent @@ -269,7 +269,6 @@ def check_mcp_conn(): logger.warning(f"MCP server at {mcp_url} check failed: {e}") def debug_tool_node(): - import jax import httpx from langchain_core.messages.base import BaseMessageChunk from lang_agent.components.tool_manager import ToolManagerConfig @@ -302,7 +301,7 @@ def debug_tool_node(): print("Assistant: ", end="", flush=True) for chunk in graph.stream(*input_data, stream_mode="updates"): - el = jax.tree.leaves(chunk)[-1] + el = tree_leaves(chunk)[-1] el.pretty_print() except Exception as e: print(e)