wrap the self_* tools to specify tool use

This commit is contained in:
2025-12-30 22:44:51 +08:00
parent 69713b3977
commit 21f8fe93de
2 changed files with 150 additions and 13 deletions

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Type
from typing import Type, Any, Optional
import tyro
import commentjson
import asyncio
@@ -7,9 +7,58 @@ import os.path as osp
from loguru import logger
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_core.tools import BaseTool, StructuredTool
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel, Field, create_model
from lang_agent.config import InstantiateConfig
class DeviceIdInjectedTool(StructuredTool):
"""
A StructuredTool subclass that injects device_id from RunnableConfig
at the invoke/ainvoke level, before any argument parsing.
"""
def invoke(
self,
input: dict,
config: Optional[RunnableConfig] = None,
**kwargs,
):
logger.info("================================================CONFIG========================")
logger.info(config)
# Inject device_id from config into the input dict
if config and "configurable" in config:
device_id = config["configurable"].get("device_id")
logger.info(f"DeviceIdInjectedTool.invoke - device_id from config: {device_id}")
# Add device_id to input if it's valid (not None and not "0")
if isinstance(input, dict) and device_id is not None and device_id != "0":
input = {**input, "device_id": device_id}
return super().invoke(input, config, **kwargs)
async def ainvoke(
self,
input: dict,
config: Optional[RunnableConfig] = None,
**kwargs,
):
logger.info(f"========== DeviceIdInjectedTool.ainvoke CALLED ==========")
logger.info(f"input: {input}")
logger.info(f"config: {config}")
# Inject device_id from config into the input dict
if config and "configurable" in config:
device_id = config["configurable"].get("device_id")
logger.info(f"DeviceIdInjectedTool.ainvoke - device_id from config: {device_id}")
# Add device_id to input if it's valid (not None and not "0")
if isinstance(input, dict) and device_id is not None and device_id != "0":
input = {**input, "device_id": device_id}
return await super().ainvoke(input, config, **kwargs)
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class ClientToolManagerConfig(InstantiateConfig):
@@ -56,11 +105,105 @@ class ClientToolManager:
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
tools = future.result()
return tools
return self._wrap_tools_with_injected_device_id(tools)
except RuntimeError:
# No event loop running, safe to use asyncio.run()
tools = asyncio.run(self.aget_tools())
return tools
return self._wrap_tools_with_injected_device_id(tools)
def _wrap_tools_with_injected_device_id(self, tools: list) -> list:
"""
Wrap tools that have 'device_id' parameter to inject it from RunnableConfig.
This removes the burden from the LLM to pass device_id explicitly.
"""
wrapped_tools = []
for tool in tools:
wrapped_tools.append(wrap_tool_with_injected_device_id(tool))
return wrapped_tools
def wrap_tool_with_injected_device_id(tool: BaseTool) -> BaseTool:
"""
Wrap a tool to inject 'device_id' from RunnableConfig instead of requiring LLM to pass it.
If the tool doesn't have a device_id parameter, returns the tool unchanged.
Uses DeviceIdInjectedTool which overrides invoke/ainvoke to inject device_id
directly from config before argument parsing.
"""
# Check if tool has device_id in its schema
tool_schema = None
if hasattr(tool, "args_schema") and tool.args_schema is not None:
if isinstance(tool.args_schema, dict):
tool_schema = tool.args_schema
elif hasattr(tool.args_schema, "model_json_schema"):
tool_schema = tool.args_schema.model_json_schema()
elif hasattr(tool.args_schema, "schema"):
tool_schema = tool.args_schema.schema()
elif hasattr(tool, "args") and tool.args is not None:
tool_schema = {"properties": tool.args}
if tool_schema is None:
return tool
properties = tool_schema.get("properties", {})
if "device_id" not in properties:
return tool
# Build a new args_schema WITHOUT device_id visible to LLM
# device_id will be injected at invoke/ainvoke level from config
new_fields = {}
required_fields = tool_schema.get("required", [])
for field_name, field_info in properties.items():
if field_name == "device_id":
# Skip device_id - it will be injected from config, not shown to LLM
continue
else:
# Preserve other fields
field_type = _get_python_type_from_schema(field_info)
is_required = field_name in required_fields
if is_required:
new_fields[field_name] = (field_type, Field(description=field_info.get("description", "")))
else:
new_fields[field_name] = (
Optional[field_type],
Field(default=field_info.get("default"), description=field_info.get("description", ""))
)
# Create the new Pydantic model (without device_id)
NewArgsSchema = create_model(f"{tool.name}Args", **new_fields)
# Get original functions
original_func = tool.func if hasattr(tool, 'func') else None
original_coroutine = tool.coroutine if hasattr(tool, 'coroutine') else None
# Create the new wrapped tool using DeviceIdInjectedTool
# which injects device_id at invoke/ainvoke level
wrapped_tool = DeviceIdInjectedTool(
name=tool.name,
description=tool.description,
args_schema=NewArgsSchema,
func=original_func,
coroutine=original_coroutine,
return_direct=getattr(tool, "return_direct", False),
)
logger.info(f"Wrapped tool '{tool.name}' - type: {type(wrapped_tool).__name__}, has ainvoke: {hasattr(wrapped_tool, 'ainvoke')}")
return wrapped_tool
def _get_python_type_from_schema(field_info: dict) -> type:
"""Convert JSON schema type to Python type."""
json_type = field_info.get("type", "string")
type_mapping = {
"string": str,
"integer": int,
"number": float,
"boolean": bool,
"array": list,
"object": dict,
}
return type_mapping.get(json_type, Any)
if __name__ == "__main__":
# NOTE: Simple test