From 8e423dc8fe85233baaedeb5ffc0cdbdb6e9887be Mon Sep 17 00:00:00 2001 From: goulustis Date: Wed, 31 Dec 2025 14:41:05 +0800 Subject: [PATCH] update the device id correctly --- lang_agent/components/client_tool_manager.py | 92 +++++++++++++++++--- 1 file changed, 82 insertions(+), 10 deletions(-) diff --git a/lang_agent/components/client_tool_manager.py b/lang_agent/components/client_tool_manager.py index f16b3dc..809d0cf 100644 --- a/lang_agent/components/client_tool_manager.py +++ b/lang_agent/components/client_tool_manager.py @@ -3,21 +3,59 @@ from typing import Type, Any, Optional import tyro import commentjson import asyncio +import json import os.path as osp from loguru import logger from langchain_mcp_adapters.client import MultiServerMCPClient from langchain_core.tools import BaseTool, StructuredTool from langchain_core.runnables import RunnableConfig +from langchain_core.messages import ToolMessage from pydantic import BaseModel, Field, create_model from lang_agent.config import InstantiateConfig +def _format_tool_result(result: Any, input: dict) -> str: + """ + Format the tool result to match the expected output format. + MCP tools return a tuple (result, error), which needs to be converted + to a JSON array string for consistency with StructuredTool.invoke() behavior. + """ + content = json.dumps(list(result)) + return ToolMessage(content=content, + name=input.get("name"), + tool_call_id=input.get("id")) + + +def _is_tool_call(input: Any) -> bool: + """Check if input is a ToolCall dict (has 'id' and 'args' keys).""" + return isinstance(input, dict) and "id" in input and "args" in input + + +def _extract_tool_args(input: Any) -> tuple[dict, dict | None]: + """ + Extract tool arguments from input. + + Returns: + (tool_args, tool_call_info) where tool_call_info contains id/name if it was a ToolCall + """ + if _is_tool_call(input): + # Input is a ToolCall: {"id": "...", "name": "...", "args": {...}} + tool_call_info = {"id": input.get("id"), "name": input.get("name")} + return input["args"].copy(), tool_call_info + else: + # Input is already the args dict + return input if isinstance(input, dict) else {}, None + + class DeviceIdInjectedTool(StructuredTool): """ A StructuredTool subclass that injects device_id from RunnableConfig at the invoke/ainvoke level, before any argument parsing. + + NOTE: We bypass the parent's invoke/ainvoke to avoid Pydantic schema validation + which would strip the device_id field (since it's not in the new args_schema). """ def invoke( @@ -26,16 +64,32 @@ class DeviceIdInjectedTool(StructuredTool): config: Optional[RunnableConfig] = None, **kwargs, ): - # Inject device_id from config into the input dict + # Extract actual args from ToolCall if needed + tool_args, tool_call_info = _extract_tool_args(input) + + # Inject device_id from config into the tool args if config and "configurable" in config: device_id = config["configurable"].get("device_id") logger.info(f"DeviceIdInjectedTool.invoke - device_id from config: {device_id}") - # Add device_id to input if it's valid (not None and not "0") - if isinstance(input, dict) and device_id is not None and device_id != "0": - input = {**input, "device_id": device_id} + # Add device_id to args if it's valid (not None and not "0") + if device_id is not None and device_id != "0": + tool_args = {**tool_args, "device_id": device_id} - return super().invoke(input, config, **kwargs) + logger.info(f"DeviceIdInjectedTool.invoke - calling with args: {list(tool_args.keys())}") + + # Call the underlying func directly to bypass schema validation + # which would strip the device_id field not in args_schema + if self.func is not None: + result = self.func(**tool_args) + return _format_tool_result(result, tool_call_info or {}) + elif self.coroutine is not None: + # Run async function synchronously + result = asyncio.run(self.coroutine(**tool_args)) + return _format_tool_result(result, tool_call_info or {}) + else: + # Fallback to parent implementation + return super().invoke(input, config, **kwargs) async def ainvoke( self, @@ -44,16 +98,34 @@ class DeviceIdInjectedTool(StructuredTool): **kwargs, ): logger.info(f"========== DeviceIdInjectedTool.ainvoke CALLED ==========") - # Inject device_id from config into the input dict + + # Extract actual args from ToolCall if needed + tool_args, tool_call_info = _extract_tool_args(input) + + # Inject device_id from config into the tool args if config and "configurable" in config: device_id = config["configurable"].get("device_id") logger.info(f"DeviceIdInjectedTool.ainvoke - device_id from config: {device_id}") - # Add device_id to input if it's valid (not None and not "0") - if isinstance(input, dict) and device_id is not None and device_id != "0": - input = {**input, "device_id": device_id} + # Add device_id to args if it's valid (not None and not "0") + if device_id is not None and device_id != "0": + tool_args = {**tool_args, "device_id": device_id} - return await super().ainvoke(input, config, **kwargs) + logger.info(f"DeviceIdInjectedTool.ainvoke - calling with args: {list(tool_args.keys())}") + + # Call the underlying coroutine/func directly to bypass schema validation + # which would strip the device_id field not in args_schema + if self.coroutine is not None: + result = await self.coroutine(**tool_args) + return _format_tool_result(result, tool_call_info or {}) + elif self.func is not None: + # Run sync function in thread pool to avoid blocking + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, lambda: self.func(**tool_args)) + return _format_tool_result(result, tool_call_info or {}) + else: + # Fallback to parent implementation + return await super().ainvoke(input, config, **kwargs) @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass