wrap the self_* tools to specify tool use
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Type
|
||||
from typing import Type, Any, Optional
|
||||
import tyro
|
||||
import commentjson
|
||||
import asyncio
|
||||
@@ -7,9 +7,58 @@ import os.path as osp
|
||||
from loguru import logger
|
||||
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_core.tools import BaseTool, StructuredTool
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from lang_agent.config import InstantiateConfig
|
||||
|
||||
|
||||
class DeviceIdInjectedTool(StructuredTool):
|
||||
"""
|
||||
A StructuredTool subclass that injects device_id from RunnableConfig
|
||||
at the invoke/ainvoke level, before any argument parsing.
|
||||
"""
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: dict,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs,
|
||||
):
|
||||
logger.info("================================================CONFIG========================")
|
||||
logger.info(config)
|
||||
# Inject device_id from config into the input dict
|
||||
if config and "configurable" in config:
|
||||
device_id = config["configurable"].get("device_id")
|
||||
logger.info(f"DeviceIdInjectedTool.invoke - device_id from config: {device_id}")
|
||||
|
||||
# Add device_id to input if it's valid (not None and not "0")
|
||||
if isinstance(input, dict) and device_id is not None and device_id != "0":
|
||||
input = {**input, "device_id": device_id}
|
||||
|
||||
return super().invoke(input, config, **kwargs)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: dict,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs,
|
||||
):
|
||||
logger.info(f"========== DeviceIdInjectedTool.ainvoke CALLED ==========")
|
||||
logger.info(f"input: {input}")
|
||||
logger.info(f"config: {config}")
|
||||
# Inject device_id from config into the input dict
|
||||
if config and "configurable" in config:
|
||||
device_id = config["configurable"].get("device_id")
|
||||
logger.info(f"DeviceIdInjectedTool.ainvoke - device_id from config: {device_id}")
|
||||
|
||||
# Add device_id to input if it's valid (not None and not "0")
|
||||
if isinstance(input, dict) and device_id is not None and device_id != "0":
|
||||
input = {**input, "device_id": device_id}
|
||||
|
||||
return await super().ainvoke(input, config, **kwargs)
|
||||
|
||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||
@dataclass
|
||||
class ClientToolManagerConfig(InstantiateConfig):
|
||||
@@ -56,11 +105,105 @@ class ClientToolManager:
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
tools = future.result()
|
||||
return tools
|
||||
return self._wrap_tools_with_injected_device_id(tools)
|
||||
except RuntimeError:
|
||||
# No event loop running, safe to use asyncio.run()
|
||||
tools = asyncio.run(self.aget_tools())
|
||||
return tools
|
||||
return self._wrap_tools_with_injected_device_id(tools)
|
||||
|
||||
def _wrap_tools_with_injected_device_id(self, tools: list) -> list:
|
||||
"""
|
||||
Wrap tools that have 'device_id' parameter to inject it from RunnableConfig.
|
||||
This removes the burden from the LLM to pass device_id explicitly.
|
||||
"""
|
||||
wrapped_tools = []
|
||||
for tool in tools:
|
||||
wrapped_tools.append(wrap_tool_with_injected_device_id(tool))
|
||||
return wrapped_tools
|
||||
|
||||
|
||||
def wrap_tool_with_injected_device_id(tool: BaseTool) -> BaseTool:
|
||||
"""
|
||||
Wrap a tool to inject 'device_id' from RunnableConfig instead of requiring LLM to pass it.
|
||||
If the tool doesn't have a device_id parameter, returns the tool unchanged.
|
||||
|
||||
Uses DeviceIdInjectedTool which overrides invoke/ainvoke to inject device_id
|
||||
directly from config before argument parsing.
|
||||
"""
|
||||
# Check if tool has device_id in its schema
|
||||
tool_schema = None
|
||||
if hasattr(tool, "args_schema") and tool.args_schema is not None:
|
||||
if isinstance(tool.args_schema, dict):
|
||||
tool_schema = tool.args_schema
|
||||
elif hasattr(tool.args_schema, "model_json_schema"):
|
||||
tool_schema = tool.args_schema.model_json_schema()
|
||||
elif hasattr(tool.args_schema, "schema"):
|
||||
tool_schema = tool.args_schema.schema()
|
||||
elif hasattr(tool, "args") and tool.args is not None:
|
||||
tool_schema = {"properties": tool.args}
|
||||
|
||||
if tool_schema is None:
|
||||
return tool
|
||||
|
||||
properties = tool_schema.get("properties", {})
|
||||
if "device_id" not in properties:
|
||||
return tool
|
||||
|
||||
# Build a new args_schema WITHOUT device_id visible to LLM
|
||||
# device_id will be injected at invoke/ainvoke level from config
|
||||
new_fields = {}
|
||||
required_fields = tool_schema.get("required", [])
|
||||
|
||||
for field_name, field_info in properties.items():
|
||||
if field_name == "device_id":
|
||||
# Skip device_id - it will be injected from config, not shown to LLM
|
||||
continue
|
||||
else:
|
||||
# Preserve other fields
|
||||
field_type = _get_python_type_from_schema(field_info)
|
||||
is_required = field_name in required_fields
|
||||
if is_required:
|
||||
new_fields[field_name] = (field_type, Field(description=field_info.get("description", "")))
|
||||
else:
|
||||
new_fields[field_name] = (
|
||||
Optional[field_type],
|
||||
Field(default=field_info.get("default"), description=field_info.get("description", ""))
|
||||
)
|
||||
|
||||
# Create the new Pydantic model (without device_id)
|
||||
NewArgsSchema = create_model(f"{tool.name}Args", **new_fields)
|
||||
|
||||
# Get original functions
|
||||
original_func = tool.func if hasattr(tool, 'func') else None
|
||||
original_coroutine = tool.coroutine if hasattr(tool, 'coroutine') else None
|
||||
|
||||
# Create the new wrapped tool using DeviceIdInjectedTool
|
||||
# which injects device_id at invoke/ainvoke level
|
||||
wrapped_tool = DeviceIdInjectedTool(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
args_schema=NewArgsSchema,
|
||||
func=original_func,
|
||||
coroutine=original_coroutine,
|
||||
return_direct=getattr(tool, "return_direct", False),
|
||||
)
|
||||
|
||||
logger.info(f"Wrapped tool '{tool.name}' - type: {type(wrapped_tool).__name__}, has ainvoke: {hasattr(wrapped_tool, 'ainvoke')}")
|
||||
return wrapped_tool
|
||||
|
||||
|
||||
def _get_python_type_from_schema(field_info: dict) -> type:
|
||||
"""Convert JSON schema type to Python type."""
|
||||
json_type = field_info.get("type", "string")
|
||||
type_mapping = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"number": float,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
}
|
||||
return type_mapping.get(json_type, Any)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# NOTE: Simple test
|
||||
|
||||
Reference in New Issue
Block a user