wrap the self_* tools to specify tool use
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Type
|
from typing import Type, Any, Optional
|
||||||
import tyro
|
import tyro
|
||||||
import commentjson
|
import commentjson
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -7,9 +7,58 @@ import os.path as osp
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||||
|
from langchain_core.tools import BaseTool, StructuredTool
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from pydantic import BaseModel, Field, create_model
|
||||||
|
|
||||||
from lang_agent.config import InstantiateConfig
|
from lang_agent.config import InstantiateConfig
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceIdInjectedTool(StructuredTool):
|
||||||
|
"""
|
||||||
|
A StructuredTool subclass that injects device_id from RunnableConfig
|
||||||
|
at the invoke/ainvoke level, before any argument parsing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self,
|
||||||
|
input: dict,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
logger.info("================================================CONFIG========================")
|
||||||
|
logger.info(config)
|
||||||
|
# Inject device_id from config into the input dict
|
||||||
|
if config and "configurable" in config:
|
||||||
|
device_id = config["configurable"].get("device_id")
|
||||||
|
logger.info(f"DeviceIdInjectedTool.invoke - device_id from config: {device_id}")
|
||||||
|
|
||||||
|
# Add device_id to input if it's valid (not None and not "0")
|
||||||
|
if isinstance(input, dict) and device_id is not None and device_id != "0":
|
||||||
|
input = {**input, "device_id": device_id}
|
||||||
|
|
||||||
|
return super().invoke(input, config, **kwargs)
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self,
|
||||||
|
input: dict,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
logger.info(f"========== DeviceIdInjectedTool.ainvoke CALLED ==========")
|
||||||
|
logger.info(f"input: {input}")
|
||||||
|
logger.info(f"config: {config}")
|
||||||
|
# Inject device_id from config into the input dict
|
||||||
|
if config and "configurable" in config:
|
||||||
|
device_id = config["configurable"].get("device_id")
|
||||||
|
logger.info(f"DeviceIdInjectedTool.ainvoke - device_id from config: {device_id}")
|
||||||
|
|
||||||
|
# Add device_id to input if it's valid (not None and not "0")
|
||||||
|
if isinstance(input, dict) and device_id is not None and device_id != "0":
|
||||||
|
input = {**input, "device_id": device_id}
|
||||||
|
|
||||||
|
return await super().ainvoke(input, config, **kwargs)
|
||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class ClientToolManagerConfig(InstantiateConfig):
|
class ClientToolManagerConfig(InstantiateConfig):
|
||||||
@@ -56,11 +105,105 @@ class ClientToolManager:
|
|||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
future = executor.submit(run_in_thread)
|
future = executor.submit(run_in_thread)
|
||||||
tools = future.result()
|
tools = future.result()
|
||||||
return tools
|
return self._wrap_tools_with_injected_device_id(tools)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
# No event loop running, safe to use asyncio.run()
|
# No event loop running, safe to use asyncio.run()
|
||||||
tools = asyncio.run(self.aget_tools())
|
tools = asyncio.run(self.aget_tools())
|
||||||
return tools
|
return self._wrap_tools_with_injected_device_id(tools)
|
||||||
|
|
||||||
|
def _wrap_tools_with_injected_device_id(self, tools: list) -> list:
|
||||||
|
"""
|
||||||
|
Wrap tools that have 'device_id' parameter to inject it from RunnableConfig.
|
||||||
|
This removes the burden from the LLM to pass device_id explicitly.
|
||||||
|
"""
|
||||||
|
wrapped_tools = []
|
||||||
|
for tool in tools:
|
||||||
|
wrapped_tools.append(wrap_tool_with_injected_device_id(tool))
|
||||||
|
return wrapped_tools
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_tool_with_injected_device_id(tool: BaseTool) -> BaseTool:
|
||||||
|
"""
|
||||||
|
Wrap a tool to inject 'device_id' from RunnableConfig instead of requiring LLM to pass it.
|
||||||
|
If the tool doesn't have a device_id parameter, returns the tool unchanged.
|
||||||
|
|
||||||
|
Uses DeviceIdInjectedTool which overrides invoke/ainvoke to inject device_id
|
||||||
|
directly from config before argument parsing.
|
||||||
|
"""
|
||||||
|
# Check if tool has device_id in its schema
|
||||||
|
tool_schema = None
|
||||||
|
if hasattr(tool, "args_schema") and tool.args_schema is not None:
|
||||||
|
if isinstance(tool.args_schema, dict):
|
||||||
|
tool_schema = tool.args_schema
|
||||||
|
elif hasattr(tool.args_schema, "model_json_schema"):
|
||||||
|
tool_schema = tool.args_schema.model_json_schema()
|
||||||
|
elif hasattr(tool.args_schema, "schema"):
|
||||||
|
tool_schema = tool.args_schema.schema()
|
||||||
|
elif hasattr(tool, "args") and tool.args is not None:
|
||||||
|
tool_schema = {"properties": tool.args}
|
||||||
|
|
||||||
|
if tool_schema is None:
|
||||||
|
return tool
|
||||||
|
|
||||||
|
properties = tool_schema.get("properties", {})
|
||||||
|
if "device_id" not in properties:
|
||||||
|
return tool
|
||||||
|
|
||||||
|
# Build a new args_schema WITHOUT device_id visible to LLM
|
||||||
|
# device_id will be injected at invoke/ainvoke level from config
|
||||||
|
new_fields = {}
|
||||||
|
required_fields = tool_schema.get("required", [])
|
||||||
|
|
||||||
|
for field_name, field_info in properties.items():
|
||||||
|
if field_name == "device_id":
|
||||||
|
# Skip device_id - it will be injected from config, not shown to LLM
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Preserve other fields
|
||||||
|
field_type = _get_python_type_from_schema(field_info)
|
||||||
|
is_required = field_name in required_fields
|
||||||
|
if is_required:
|
||||||
|
new_fields[field_name] = (field_type, Field(description=field_info.get("description", "")))
|
||||||
|
else:
|
||||||
|
new_fields[field_name] = (
|
||||||
|
Optional[field_type],
|
||||||
|
Field(default=field_info.get("default"), description=field_info.get("description", ""))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the new Pydantic model (without device_id)
|
||||||
|
NewArgsSchema = create_model(f"{tool.name}Args", **new_fields)
|
||||||
|
|
||||||
|
# Get original functions
|
||||||
|
original_func = tool.func if hasattr(tool, 'func') else None
|
||||||
|
original_coroutine = tool.coroutine if hasattr(tool, 'coroutine') else None
|
||||||
|
|
||||||
|
# Create the new wrapped tool using DeviceIdInjectedTool
|
||||||
|
# which injects device_id at invoke/ainvoke level
|
||||||
|
wrapped_tool = DeviceIdInjectedTool(
|
||||||
|
name=tool.name,
|
||||||
|
description=tool.description,
|
||||||
|
args_schema=NewArgsSchema,
|
||||||
|
func=original_func,
|
||||||
|
coroutine=original_coroutine,
|
||||||
|
return_direct=getattr(tool, "return_direct", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Wrapped tool '{tool.name}' - type: {type(wrapped_tool).__name__}, has ainvoke: {hasattr(wrapped_tool, 'ainvoke')}")
|
||||||
|
return wrapped_tool
|
||||||
|
|
||||||
|
|
||||||
|
def _get_python_type_from_schema(field_info: dict) -> type:
|
||||||
|
"""Convert JSON schema type to Python type."""
|
||||||
|
json_type = field_info.get("type", "string")
|
||||||
|
type_mapping = {
|
||||||
|
"string": str,
|
||||||
|
"integer": int,
|
||||||
|
"number": float,
|
||||||
|
"boolean": bool,
|
||||||
|
"array": list,
|
||||||
|
"object": dict,
|
||||||
|
}
|
||||||
|
return type_mapping.get(json_type, Any)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# NOTE: Simple test
|
# NOTE: Simple test
|
||||||
|
|||||||
@@ -155,16 +155,10 @@ class ToolManager:
|
|||||||
@functools.wraps(func.coroutine)
|
@functools.wraps(func.coroutine)
|
||||||
def sync_func(*args, _wrapper=sync_wrapper, **kwargs):
|
def sync_func(*args, _wrapper=sync_wrapper, **kwargs):
|
||||||
return _wrapper(*args, **kwargs)
|
return _wrapper(*args, **kwargs)
|
||||||
new_tool = StructuredTool(
|
# Preserve the original tool's class (e.g., DeviceIdInjectedTool)
|
||||||
name=func.name,
|
# by setting func directly instead of creating a new StructuredTool
|
||||||
description=func.description,
|
func.func = sync_func
|
||||||
args_schema=func.args_schema,
|
self.langchain_tools.append(func)
|
||||||
func=sync_func,
|
|
||||||
coroutine=func.coroutine,
|
|
||||||
metadata=func.metadata if hasattr(func, 'metadata') else None,
|
|
||||||
return_direct=func.return_direct if hasattr(func, 'return_direct') else False,
|
|
||||||
)
|
|
||||||
self.langchain_tools.append(new_tool)
|
|
||||||
else:
|
else:
|
||||||
self.langchain_tools.append(func)
|
self.langchain_tools.append(func)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user