wrap the self_* tools to specify tool use

This commit is contained in:
2025-12-30 22:44:51 +08:00
parent 69713b3977
commit 21f8fe93de
2 changed files with 150 additions and 13 deletions

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Type from typing import Type, Any, Optional
import tyro import tyro
import commentjson import commentjson
import asyncio import asyncio
@@ -7,9 +7,58 @@ import os.path as osp
from loguru import logger from loguru import logger
from langchain_mcp_adapters.client import MultiServerMCPClient from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_core.tools import BaseTool, StructuredTool
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel, Field, create_model
from lang_agent.config import InstantiateConfig from lang_agent.config import InstantiateConfig
class DeviceIdInjectedTool(StructuredTool):
    """StructuredTool that fills in ``device_id`` from the RunnableConfig.

    The value is read from ``config["configurable"]["device_id"]`` and merged
    into the tool input before delegating to ``StructuredTool.invoke`` /
    ``StructuredTool.ainvoke``, so the caller (typically the LLM) does not
    have to supply it.
    """

    def _inject_device_id(self, tool_input, config):
        """Return ``tool_input`` with ``device_id`` merged in when available.

        Args:
            tool_input: The value passed to ``invoke``/``ainvoke``. Either a
                plain argument dict or a ToolCall-shaped dict
                (``{"type": "tool_call", "args": {...}, ...}``). Any other
                value is returned untouched.
            config: The RunnableConfig passed to the call, or ``None``.

        Returns:
            A new dict containing ``device_id`` when the config carries a
            usable value (not ``None`` and not ``"0"``); otherwise the
            original ``tool_input`` object. The caller's dict is never
            mutated.
        """
        configurable = (config or {}).get("configurable") or {}
        device_id = configurable.get("device_id")
        logger.debug(
            f"{type(self).__name__} '{self.name}' - device_id from config: {device_id}"
        )

        # "0" is treated as "no device", same as a missing value.
        if not isinstance(tool_input, dict) or device_id is None or device_id == "0":
            return tool_input

        # A ToolCall envelope carries the real arguments under "args"; a key
        # added at the top level of the envelope would never reach the tool.
        call_args = tool_input.get("args")
        if tool_input.get("type") == "tool_call" and isinstance(call_args, dict):
            return {**tool_input, "args": {**call_args, "device_id": device_id}}

        return {**tool_input, "device_id": device_id}

    def invoke(
        self,
        input: dict,
        config: Optional[RunnableConfig] = None,
        **kwargs,
    ):
        """Synchronously run the tool after injecting ``device_id``."""
        return super().invoke(self._inject_device_id(input, config), config, **kwargs)

    async def ainvoke(
        self,
        input: dict,
        config: Optional[RunnableConfig] = None,
        **kwargs,
    ):
        """Asynchronously run the tool after injecting ``device_id``."""
        return await super().ainvoke(
            self._inject_device_id(input, config), config, **kwargs
        )
@tyro.conf.configure(tyro.conf.SuppressFixed) @tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass @dataclass
class ClientToolManagerConfig(InstantiateConfig): class ClientToolManagerConfig(InstantiateConfig):
@@ -56,11 +105,105 @@ class ClientToolManager:
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread) future = executor.submit(run_in_thread)
tools = future.result() tools = future.result()
return tools return self._wrap_tools_with_injected_device_id(tools)
except RuntimeError: except RuntimeError:
# No event loop running, safe to use asyncio.run() # No event loop running, safe to use asyncio.run()
tools = asyncio.run(self.aget_tools()) tools = asyncio.run(self.aget_tools())
return tools return self._wrap_tools_with_injected_device_id(tools)
def _wrap_tools_with_injected_device_id(self, tools: list) -> list:
"""
Wrap tools that have 'device_id' parameter to inject it from RunnableConfig.
This removes the burden from the LLM to pass device_id explicitly.
"""
wrapped_tools = []
for tool in tools:
wrapped_tools.append(wrap_tool_with_injected_device_id(tool))
return wrapped_tools
def wrap_tool_with_injected_device_id(tool: BaseTool) -> BaseTool:
    """
    Wrap a tool so 'device_id' comes from RunnableConfig instead of the LLM.

    The tool's argument schema is inspected; when it has no ``device_id``
    property (or no schema can be obtained) the tool is returned unchanged.
    Otherwise a ``DeviceIdInjectedTool`` is built with the same name,
    description, callables, ``return_direct``, ``response_format`` and
    ``metadata``, but with an argument schema that omits ``device_id``.

    Args:
        tool: The tool to inspect and possibly wrap.

    Returns:
        The original tool, or a ``DeviceIdInjectedTool`` replacing it.
    """
    # Resolve the tool's argument schema into a JSON-schema-like dict.
    args_schema = getattr(tool, "args_schema", None)
    tool_schema = None
    if args_schema is not None:
        if isinstance(args_schema, dict):
            tool_schema = args_schema
        elif hasattr(args_schema, "model_json_schema"):
            tool_schema = args_schema.model_json_schema()
        elif hasattr(args_schema, "schema"):
            tool_schema = args_schema.schema()
    elif getattr(tool, "args", None) is not None:
        tool_schema = {"properties": tool.args}

    if tool_schema is None:
        return tool

    # Tolerate schemas that spell "no properties"/"nothing required" as null.
    properties = tool_schema.get("properties") or {}
    if "device_id" not in properties:
        return tool

    required_fields = set(tool_schema.get("required") or ())

    # Rebuild the argument model without device_id so it is hidden from the
    # LLM; the value is supplied by DeviceIdInjectedTool at call time.
    # NOTE(review): device_id is absent from this schema - confirm that
    # StructuredTool's argument validation does not strip the injected value
    # before it reaches the underlying callable.
    new_fields = {}
    for field_name, field_info in properties.items():
        if field_name == "device_id":
            continue
        field_type = _get_python_type_from_schema(field_info)
        description = field_info.get("description", "")
        if field_name in required_fields:
            new_fields[field_name] = (field_type, Field(description=description))
        else:
            new_fields[field_name] = (
                Optional[field_type],
                Field(default=field_info.get("default"), description=description),
            )

    NewArgsSchema = create_model(f"{tool.name}Args", **new_fields)

    wrapped_tool = DeviceIdInjectedTool(
        name=tool.name,
        description=tool.description,
        args_schema=NewArgsSchema,
        func=getattr(tool, "func", None),
        coroutine=getattr(tool, "coroutine", None),
        return_direct=getattr(tool, "return_direct", False),
        # Keep the original output contract: a tool declared as
        # "content_and_artifact" returns a tuple that would otherwise be
        # passed through as plain content.
        response_format=getattr(tool, "response_format", "content"),
        metadata=getattr(tool, "metadata", None),
    )

    logger.info(
        f"Wrapped tool '{tool.name}' - type: {type(wrapped_tool).__name__}"
    )
    return wrapped_tool
def _get_python_type_from_schema(field_info: dict) -> type:
"""Convert JSON schema type to Python type."""
json_type = field_info.get("type", "string")
type_mapping = {
"string": str,
"integer": int,
"number": float,
"boolean": bool,
"array": list,
"object": dict,
}
return type_mapping.get(json_type, Any)
if __name__ == "__main__": if __name__ == "__main__":
# NOTE: Simple test # NOTE: Simple test

View File

@@ -155,16 +155,10 @@ class ToolManager:
@functools.wraps(func.coroutine) @functools.wraps(func.coroutine)
def sync_func(*args, _wrapper=sync_wrapper, **kwargs): def sync_func(*args, _wrapper=sync_wrapper, **kwargs):
return _wrapper(*args, **kwargs) return _wrapper(*args, **kwargs)
new_tool = StructuredTool( # Preserve the original tool's class (e.g., DeviceIdInjectedTool)
name=func.name, # by setting func directly instead of creating a new StructuredTool
description=func.description, func.func = sync_func
args_schema=func.args_schema, self.langchain_tools.append(func)
func=sync_func,
coroutine=func.coroutine,
metadata=func.metadata if hasattr(func, 'metadata') else None,
return_direct=func.return_direct if hasattr(func, 'return_direct') else False,
)
self.langchain_tools.append(new_tool)
else: else:
self.langchain_tools.append(func) self.langchain_tools.append(func)
else: else: