replace jax.tree.leaves with drop in
This commit is contained in:
@@ -6,7 +6,7 @@ from PIL import Image
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import jax
|
from lang_agent.utils import tree_leaves
|
||||||
|
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||||
@@ -67,7 +67,7 @@ class GraphBase(ABC):
|
|||||||
else:
|
else:
|
||||||
state = self.workflow.invoke({"inp": nargs})
|
state = self.workflow.invoke({"inp": nargs})
|
||||||
|
|
||||||
msg_list = jax.tree.leaves(state)
|
msg_list = tree_leaves(state)
|
||||||
|
|
||||||
for e in msg_list:
|
for e in msg_list:
|
||||||
if isinstance(e, BaseMessage):
|
if isinstance(e, BaseMessage):
|
||||||
@@ -90,7 +90,7 @@ class GraphBase(ABC):
|
|||||||
else:
|
else:
|
||||||
state = await self.workflow.ainvoke({"inp": nargs})
|
state = await self.workflow.ainvoke({"inp": nargs})
|
||||||
|
|
||||||
msg_list = jax.tree.leaves(state)
|
msg_list = tree_leaves(state)
|
||||||
|
|
||||||
for e in msg_list:
|
for e in msg_list:
|
||||||
if isinstance(e, BaseMessage):
|
if isinstance(e, BaseMessage):
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from dataclasses import dataclass, field, is_dataclass
|
from dataclasses import dataclass, field, is_dataclass
|
||||||
from typing import Type, List, Callable, Any, AsyncIterator
|
from typing import Type, List, Callable, Any, AsyncIterator
|
||||||
import tyro
|
import tyro
|
||||||
import jax
|
|
||||||
|
|
||||||
from lang_agent.config import KeyConfig
|
from lang_agent.config import KeyConfig
|
||||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import GraphBase
|
||||||
|
from lang_agent.utils import tree_leaves
|
||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
||||||
@@ -63,7 +63,7 @@ class ReactGraph(GraphBase):
|
|||||||
else:
|
else:
|
||||||
out = self.agent.invoke(*nargs, **kwargs)
|
out = self.agent.invoke(*nargs, **kwargs)
|
||||||
|
|
||||||
msgs_list = jax.tree.leaves(out)
|
msgs_list = tree_leaves(out)
|
||||||
|
|
||||||
for e in msgs_list:
|
for e in msgs_list:
|
||||||
if isinstance(e, BaseMessage):
|
if isinstance(e, BaseMessage):
|
||||||
@@ -87,7 +87,7 @@ class ReactGraph(GraphBase):
|
|||||||
else:
|
else:
|
||||||
out = await self.agent.ainvoke(*nargs, **kwargs)
|
out = await self.agent.ainvoke(*nargs, **kwargs)
|
||||||
|
|
||||||
msgs_list = jax.tree.leaves(out)
|
msgs_list = tree_leaves(out)
|
||||||
|
|
||||||
for e in msgs_list:
|
for e in msgs_list:
|
||||||
if isinstance(e, BaseMessage):
|
if isinstance(e, BaseMessage):
|
||||||
@@ -112,14 +112,17 @@ if __name__ == "__main__":
|
|||||||
"messages": [SystemMessage("you are a helpful bot named jarvis"),
|
"messages": [SystemMessage("you are a helpful bot named jarvis"),
|
||||||
HumanMessage("use the calculator tool to calculate 92*55 and say the answer")]
|
HumanMessage("use the calculator tool to calculate 92*55 and say the answer")]
|
||||||
},{"configurable": {"thread_id": "3"}}
|
},{"configurable": {"thread_id": "3"}}
|
||||||
|
|
||||||
for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
|
|
||||||
node = metadata.get("langgraph_node")
|
|
||||||
if node not in ("model"):
|
|
||||||
print(node)
|
|
||||||
continue # skip router or other intermediate nodes
|
|
||||||
|
|
||||||
# Print only the final message content
|
out = route.invoke(*nargs)
|
||||||
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
assert 0
|
||||||
print(chunk.content, end="", flush=True)
|
|
||||||
|
# for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
|
||||||
|
# node = metadata.get("langgraph_node")
|
||||||
|
# if node not in ("model"):
|
||||||
|
# print(node)
|
||||||
|
# continue # skip router or other intermediate nodes
|
||||||
|
|
||||||
|
# # Print only the final message content
|
||||||
|
# if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
||||||
|
# print(chunk.content, end="", flush=True)
|
||||||
|
|
||||||
@@ -3,7 +3,6 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterat
|
|||||||
import tyro
|
import tyro
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import jax
|
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import commentjson
|
import commentjson
|
||||||
import glob
|
import glob
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from lang_agent.components.tool_manager import ToolManager
|
|||||||
from lang_agent.components.reit_llm import ReitLLM
|
from lang_agent.components.reit_llm import ReitLLM
|
||||||
from lang_agent.base import ToolNodeBase
|
from lang_agent.base import ToolNodeBase
|
||||||
from lang_agent.graphs.graph_states import State, ChattyToolState
|
from lang_agent.graphs.graph_states import State, ChattyToolState
|
||||||
from lang_agent.utils import make_llm, words_only
|
from lang_agent.utils import make_llm, words_only, tree_leaves
|
||||||
|
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
@@ -269,7 +269,6 @@ def check_mcp_conn():
|
|||||||
logger.warning(f"MCP server at {mcp_url} check failed: {e}")
|
logger.warning(f"MCP server at {mcp_url} check failed: {e}")
|
||||||
|
|
||||||
def debug_tool_node():
|
def debug_tool_node():
|
||||||
import jax
|
|
||||||
import httpx
|
import httpx
|
||||||
from langchain_core.messages.base import BaseMessageChunk
|
from langchain_core.messages.base import BaseMessageChunk
|
||||||
from lang_agent.components.tool_manager import ToolManagerConfig
|
from lang_agent.components.tool_manager import ToolManagerConfig
|
||||||
@@ -302,7 +301,7 @@ def debug_tool_node():
|
|||||||
|
|
||||||
print("Assistant: ", end="", flush=True)
|
print("Assistant: ", end="", flush=True)
|
||||||
for chunk in graph.stream(*input_data, stream_mode="updates"):
|
for chunk in graph.stream(*input_data, stream_mode="updates"):
|
||||||
el = jax.tree.leaves(chunk)[-1]
|
el = tree_leaves(chunk)[-1]
|
||||||
el.pretty_print()
|
el.pretty_print()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|||||||
Reference in New Issue
Block a user