From 21f8fe93de05a1f946f4e34c0103d762a418c4d5 Mon Sep 17 00:00:00 2001
From: goulustis
Date: Tue, 30 Dec 2025 22:44:51 +0800
Subject: [PATCH] wrap the self_* tools to specify tool use

---
 lang_agent/components/client_tool_manager.py | 163 ++++++++++++++++++-
 lang_agent/components/tool_manager.py        |  14 +-
 2 files changed, 164 insertions(+), 13 deletions(-)

diff --git a/lang_agent/components/client_tool_manager.py b/lang_agent/components/client_tool_manager.py
index 5a51d69..699dfad 100644
--- a/lang_agent/components/client_tool_manager.py
+++ b/lang_agent/components/client_tool_manager.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Type
+from typing import Type, Any, Optional
 import tyro
 import commentjson
 import asyncio
@@ -7,9 +7,60 @@ import os.path as osp
 from loguru import logger
 
 from langchain_mcp_adapters.client import MultiServerMCPClient
+from langchain_core.tools import BaseTool, StructuredTool
+from langchain_core.runnables import RunnableConfig
+from pydantic import BaseModel, Field, create_model
 
 from lang_agent.config import InstantiateConfig
 
+
+def _inject_device_id(tool_input: Any, config: Optional[RunnableConfig]) -> Any:
+    """
+    Return tool_input with device_id from config["configurable"] merged in.
+
+    tool_input is never mutated. It comes back unchanged when it is not a dict
+    or when device_id is missing, None or "0". A ToolCall-shaped dict
+    (type == "tool_call") gets the id inside its nested "args", because that is
+    the part langchain-core parses as the tool arguments.
+    """
+    if not config or not isinstance(tool_input, dict):
+        return tool_input
+
+    device_id = (config.get("configurable") or {}).get("device_id")
+    logger.debug(f"DeviceIdInjectedTool - device_id from config: {device_id}")
+    if device_id is None or device_id == "0":
+        return tool_input
+
+    call_args = tool_input.get("args")
+    if tool_input.get("type") == "tool_call" and isinstance(call_args, dict):
+        return {**tool_input, "args": {**call_args, "device_id": device_id}}
+    return {**tool_input, "device_id": device_id}
+
+
+class DeviceIdInjectedTool(StructuredTool):
+    """
+    A StructuredTool subclass that injects device_id from RunnableConfig
+    at the invoke/ainvoke level, before any argument parsing.
+    """
+
+    def invoke(
+        self,
+        input: dict,
+        config: Optional[RunnableConfig] = None,
+        **kwargs,
+    ):
+        """Inject device_id into input, then delegate to StructuredTool.invoke."""
+        return super().invoke(_inject_device_id(input, config), config, **kwargs)
+
+    async def ainvoke(
+        self,
+        input: dict,
+        config: Optional[RunnableConfig] = None,
+        **kwargs,
+    ):
+        """Inject device_id into input, then delegate to StructuredTool.ainvoke."""
+        return await super().ainvoke(_inject_device_id(input, config), config, **kwargs)
+
 @tyro.conf.configure(tyro.conf.SuppressFixed)
 @dataclass
 class ClientToolManagerConfig(InstantiateConfig):
@@ -56,11 +107,117 @@ class ClientToolManager:
             with concurrent.futures.ThreadPoolExecutor() as executor:
                 future = executor.submit(run_in_thread)
                 tools = future.result()
-                return tools
+                return self._wrap_tools_with_injected_device_id(tools)
         except RuntimeError:
             # No event loop running, safe to use asyncio.run()
             tools = asyncio.run(self.aget_tools())
-            return tools
+            return self._wrap_tools_with_injected_device_id(tools)
+
+    def _wrap_tools_with_injected_device_id(self, tools: list) -> list:
+        """
+        Wrap tools that have 'device_id' parameter to inject it from RunnableConfig.
+        This removes the burden from the LLM to pass device_id explicitly.
+        """
+        return [wrap_tool_with_injected_device_id(tool) for tool in tools]
+
+
+def wrap_tool_with_injected_device_id(tool: BaseTool) -> BaseTool:
+    """
+    Wrap a tool to inject 'device_id' from RunnableConfig instead of requiring LLM to pass it.
+    If the tool doesn't have a device_id parameter, returns the tool unchanged.
+
+    Uses DeviceIdInjectedTool which overrides invoke/ainvoke to inject device_id
+    directly from config before argument parsing.
+    """
+    # Check if tool has device_id in its schema
+    tool_schema = None
+    if hasattr(tool, "args_schema") and tool.args_schema is not None:
+        if isinstance(tool.args_schema, dict):
+            tool_schema = tool.args_schema
+        elif hasattr(tool.args_schema, "model_json_schema"):
+            tool_schema = tool.args_schema.model_json_schema()
+        elif hasattr(tool.args_schema, "schema"):
+            tool_schema = tool.args_schema.schema()
+    elif hasattr(tool, "args") and tool.args is not None:
+        tool_schema = {"properties": tool.args}
+
+    if tool_schema is None:
+        return tool
+
+    properties = tool_schema.get("properties", {})
+    if "device_id" not in properties:
+        return tool
+
+    # Get original functions
+    original_func = getattr(tool, "func", None)
+    original_coroutine = getattr(tool, "coroutine", None)
+    if original_func is None and original_coroutine is None:
+        # Nothing callable to delegate to, so wrapping would produce a dead tool
+        logger.warning(f"Tool '{tool.name}' has device_id but no func/coroutine; left unwrapped")
+        return tool
+
+    # Build a new args_schema WITHOUT device_id visible to LLM
+    # device_id will be injected at invoke/ainvoke level from config
+    new_fields = {}
+    required_fields = tool_schema.get("required", [])
+    for field_name, field_info in properties.items():
+        if field_name == "device_id":
+            # Skip device_id - it will be injected from config, not shown to LLM
+            continue
+        # Preserve other fields
+        field_type = _get_python_type_from_schema(field_info)
+        description = field_info.get("description", "")
+        if field_name in required_fields:
+            new_fields[field_name] = (field_type, Field(description=description))
+        else:
+            new_fields[field_name] = (
+                Optional[field_type],
+                Field(default=field_info.get("default"), description=description),
+            )
+
+    # Create the new Pydantic model (without device_id)
+    NewArgsSchema = create_model(f"{tool.name}Args", **new_fields)
+
+    # Create the new wrapped tool using DeviceIdInjectedTool
+    # which injects device_id at invoke/ainvoke level
+    wrapped_tool = DeviceIdInjectedTool(
+        name=tool.name,
+        description=tool.description,
+        args_schema=NewArgsSchema,
+        func=original_func,
+        coroutine=original_coroutine,
+        return_direct=getattr(tool, "return_direct", False),
+        # Carry over the output contract: a "content_and_artifact" tool returns a
+        # (content, artifact) tuple that must not be treated as plain content
+        response_format=getattr(tool, "response_format", "content"),
+        metadata=getattr(tool, "metadata", None),
+    )
+
+    logger.info(f"Wrapped tool '{tool.name}' - type: {type(wrapped_tool).__name__}, has ainvoke: {hasattr(wrapped_tool, 'ainvoke')}")
+    return wrapped_tool
+
+
+def _get_python_type_from_schema(field_info: dict) -> Any:
+    """
+    Convert JSON schema type to Python type.
+
+    A list-valued "type" (e.g. ["string", "null"]) resolves to its first
+    non-null entry; anything unrecognised falls back to Any.
+    """
+    json_type = field_info.get("type", "string")
+    if isinstance(json_type, list):
+        json_type = next((t for t in json_type if t != "null"), None)
+    if not isinstance(json_type, str):
+        return Any
+    type_mapping = {
+        "string": str,
+        "integer": int,
+        "number": float,
+        "boolean": bool,
+        "array": list,
+        "object": dict,
+    }
+    return type_mapping.get(json_type, Any)
 
 if __name__ == "__main__":
     # NOTE: Simple test
diff --git a/lang_agent/components/tool_manager.py b/lang_agent/components/tool_manager.py
index 5d6c5c7..2f5c052 100644
--- a/lang_agent/components/tool_manager.py
+++ b/lang_agent/components/tool_manager.py
@@ -155,16 +155,10 @@ class ToolManager:
                     @functools.wraps(func.coroutine)
                     def sync_func(*args, _wrapper=sync_wrapper, **kwargs):
                         return _wrapper(*args, **kwargs)
-                    new_tool = StructuredTool(
-                        name=func.name,
-                        description=func.description,
-                        args_schema=func.args_schema,
-                        func=sync_func,
-                        coroutine=func.coroutine,
-                        metadata=func.metadata if hasattr(func, 'metadata') else None,
-                        return_direct=func.return_direct if hasattr(func, 'return_direct') else False,
-                    )
-                    self.langchain_tools.append(new_tool)
+                    # Preserve the original tool's class (e.g., DeviceIdInjectedTool)
+                    # by setting func directly instead of creating a new StructuredTool
+                    func.func = sync_func
+                    self.langchain_tools.append(func)
                 else:
                     self.langchain_tools.append(func)
             else: