Merge branch 'main' of http://6.6.6.86:3321/Quant-Speed-AI/lang-agent
This commit is contained in:
@@ -6,7 +6,7 @@ from PIL import Image
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import jax
|
from lang_agent.utils import tree_leaves
|
||||||
|
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||||
@@ -67,7 +67,7 @@ class GraphBase(ABC):
|
|||||||
else:
|
else:
|
||||||
state = self.workflow.invoke({"inp": nargs})
|
state = self.workflow.invoke({"inp": nargs})
|
||||||
|
|
||||||
msg_list = jax.tree.leaves(state)
|
msg_list = tree_leaves(state)
|
||||||
|
|
||||||
for e in msg_list:
|
for e in msg_list:
|
||||||
if isinstance(e, BaseMessage):
|
if isinstance(e, BaseMessage):
|
||||||
@@ -90,7 +90,7 @@ class GraphBase(ABC):
|
|||||||
else:
|
else:
|
||||||
state = await self.workflow.ainvoke({"inp": nargs})
|
state = await self.workflow.ainvoke({"inp": nargs})
|
||||||
|
|
||||||
msg_list = jax.tree.leaves(state)
|
msg_list = tree_leaves(state)
|
||||||
|
|
||||||
for e in msg_list:
|
for e in msg_list:
|
||||||
if isinstance(e, BaseMessage):
|
if isinstance(e, BaseMessage):
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from dataclasses import dataclass, field, is_dataclass
|
from dataclasses import dataclass, field, is_dataclass
|
||||||
from typing import Type, List, Callable, Any, AsyncIterator
|
from typing import Type, List, Callable, Any, AsyncIterator
|
||||||
import tyro
|
import tyro
|
||||||
import jax
|
|
||||||
|
|
||||||
from lang_agent.config import KeyConfig
|
from lang_agent.config import KeyConfig
|
||||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import GraphBase
|
||||||
|
from lang_agent.utils import tree_leaves
|
||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
||||||
@@ -63,7 +63,7 @@ class ReactGraph(GraphBase):
|
|||||||
else:
|
else:
|
||||||
out = self.agent.invoke(*nargs, **kwargs)
|
out = self.agent.invoke(*nargs, **kwargs)
|
||||||
|
|
||||||
msgs_list = jax.tree.leaves(out)
|
msgs_list = tree_leaves(out)
|
||||||
|
|
||||||
for e in msgs_list:
|
for e in msgs_list:
|
||||||
if isinstance(e, BaseMessage):
|
if isinstance(e, BaseMessage):
|
||||||
@@ -87,7 +87,7 @@ class ReactGraph(GraphBase):
|
|||||||
else:
|
else:
|
||||||
out = await self.agent.ainvoke(*nargs, **kwargs)
|
out = await self.agent.ainvoke(*nargs, **kwargs)
|
||||||
|
|
||||||
msgs_list = jax.tree.leaves(out)
|
msgs_list = tree_leaves(out)
|
||||||
|
|
||||||
for e in msgs_list:
|
for e in msgs_list:
|
||||||
if isinstance(e, BaseMessage):
|
if isinstance(e, BaseMessage):
|
||||||
@@ -112,14 +112,17 @@ if __name__ == "__main__":
|
|||||||
"messages": [SystemMessage("you are a helpful bot named jarvis"),
|
"messages": [SystemMessage("you are a helpful bot named jarvis"),
|
||||||
HumanMessage("use the calculator tool to calculate 92*55 and say the answer")]
|
HumanMessage("use the calculator tool to calculate 92*55 and say the answer")]
|
||||||
},{"configurable": {"thread_id": "3"}}
|
},{"configurable": {"thread_id": "3"}}
|
||||||
|
|
||||||
for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
|
|
||||||
node = metadata.get("langgraph_node")
|
|
||||||
if node not in ("model"):
|
|
||||||
print(node)
|
|
||||||
continue # skip router or other intermediate nodes
|
|
||||||
|
|
||||||
# Print only the final message content
|
out = route.invoke(*nargs)
|
||||||
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
assert 0
|
||||||
print(chunk.content, end="", flush=True)
|
|
||||||
|
# for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
|
||||||
|
# node = metadata.get("langgraph_node")
|
||||||
|
# if node not in ("model"):
|
||||||
|
# print(node)
|
||||||
|
# continue # skip router or other intermediate nodes
|
||||||
|
|
||||||
|
# # Print only the final message content
|
||||||
|
# if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
||||||
|
# print(chunk.content, end="", flush=True)
|
||||||
|
|
||||||
@@ -3,7 +3,6 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterat
|
|||||||
import tyro
|
import tyro
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import jax
|
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import commentjson
|
import commentjson
|
||||||
import glob
|
import glob
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from lang_agent.components.tool_manager import ToolManager
|
|||||||
from lang_agent.components.reit_llm import ReitLLM
|
from lang_agent.components.reit_llm import ReitLLM
|
||||||
from lang_agent.base import ToolNodeBase
|
from lang_agent.base import ToolNodeBase
|
||||||
from lang_agent.graphs.graph_states import State, ChattyToolState
|
from lang_agent.graphs.graph_states import State, ChattyToolState
|
||||||
from lang_agent.utils import make_llm, words_only
|
from lang_agent.utils import make_llm, words_only, tree_leaves
|
||||||
|
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
@@ -269,7 +269,6 @@ def check_mcp_conn():
|
|||||||
logger.warning(f"MCP server at {mcp_url} check failed: {e}")
|
logger.warning(f"MCP server at {mcp_url} check failed: {e}")
|
||||||
|
|
||||||
def debug_tool_node():
|
def debug_tool_node():
|
||||||
import jax
|
|
||||||
import httpx
|
import httpx
|
||||||
from langchain_core.messages.base import BaseMessageChunk
|
from langchain_core.messages.base import BaseMessageChunk
|
||||||
from lang_agent.components.tool_manager import ToolManagerConfig
|
from lang_agent.components.tool_manager import ToolManagerConfig
|
||||||
@@ -302,7 +301,7 @@ def debug_tool_node():
|
|||||||
|
|
||||||
print("Assistant: ", end="", flush=True)
|
print("Assistant: ", end="", flush=True)
|
||||||
for chunk in graph.stream(*input_data, stream_mode="updates"):
|
for chunk in graph.stream(*input_data, stream_mode="updates"):
|
||||||
el = jax.tree.leaves(chunk)[-1]
|
el = tree_leaves(chunk)[-1]
|
||||||
el.pretty_print()
|
el.pretty_print()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|||||||
@@ -21,6 +21,28 @@ def make_llm(model="qwen-plus",
|
|||||||
|
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
|
def tree_leaves(tree):
|
||||||
|
"""
|
||||||
|
Extracts all leaf values from a nested structure (dict, list, tuple).
|
||||||
|
Drop-in replacement for jax.tree.leaves.
|
||||||
|
"""
|
||||||
|
leaves = []
|
||||||
|
stack = [tree]
|
||||||
|
|
||||||
|
while stack:
|
||||||
|
node = stack.pop()
|
||||||
|
if isinstance(node, dict):
|
||||||
|
# JAX sorts dict keys alphabetically
|
||||||
|
sorted_values = [node[k] for k in sorted(node.keys())]
|
||||||
|
stack.extend(reversed(sorted_values))
|
||||||
|
elif isinstance(node, (list, tuple)):
|
||||||
|
stack.extend(reversed(node))
|
||||||
|
else:
|
||||||
|
leaves.append(node)
|
||||||
|
|
||||||
|
return leaves
|
||||||
|
|
||||||
|
|
||||||
NON_WORD_PATTERN = re.compile(r'[^\u4e00-\u9fffA-Za-z0-9_\s]')
|
NON_WORD_PATTERN = re.compile(r'[^\u4e00-\u9fffA-Za-z0-9_\s]')
|
||||||
def words_only(text):
|
def words_only(text):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ dependencies = [
|
|||||||
"fastapi",
|
"fastapi",
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
"Pillow",
|
"Pillow",
|
||||||
"jax",
|
|
||||||
"commentjson",
|
"commentjson",
|
||||||
"pandas",
|
"pandas",
|
||||||
"asgiref"
|
"asgiref"
|
||||||
|
|||||||
@@ -135,6 +135,12 @@ examples = [
|
|||||||
"answer": "我一直在呢,随时陪你聊聊天、喝杯茶",
|
"answer": "我一直在呢,随时陪你聊聊天、喝杯茶",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"inputs": {"text": "介绍一下你自己"},
|
||||||
|
"outputs": {
|
||||||
|
"answer": "我叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶",
|
||||||
|
}
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
cli = Client()
|
cli = Client()
|
||||||
|
|||||||
Reference in New Issue
Block a user