update the device id correctly

This commit is contained in:
2025-12-31 14:41:05 +08:00
parent e84146e549
commit 8e423dc8fe

View File

@@ -3,21 +3,59 @@ from typing import Type, Any, Optional
import tyro
import commentjson
import asyncio
import json
import os.path as osp
from loguru import logger
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_core.tools import BaseTool, StructuredTool
from langchain_core.runnables import RunnableConfig
from langchain_core.messages import ToolMessage
from pydantic import BaseModel, Field, create_model
from lang_agent.config import InstantiateConfig
def _format_tool_result(result: Any, input: dict) -> str:
"""
Format the tool result to match the expected output format.
MCP tools return a tuple (result, error), which needs to be converted
to a JSON array string for consistency with StructuredTool.invoke() behavior.
"""
content = json.dumps(list(result))
return ToolMessage(content=content,
name=input.get("name"),
tool_call_id=input.get("id"))
def _is_tool_call(input: Any) -> bool:
"""Check if input is a ToolCall dict (has 'id' and 'args' keys)."""
return isinstance(input, dict) and "id" in input and "args" in input
def _extract_tool_args(input: Any) -> tuple[dict, dict | None]:
"""
Extract tool arguments from input.
Returns:
(tool_args, tool_call_info) where tool_call_info contains id/name if it was a ToolCall
"""
if _is_tool_call(input):
# Input is a ToolCall: {"id": "...", "name": "...", "args": {...}}
tool_call_info = {"id": input.get("id"), "name": input.get("name")}
return input["args"].copy(), tool_call_info
else:
# Input is already the args dict
return input if isinstance(input, dict) else {}, None
class DeviceIdInjectedTool(StructuredTool):
"""
A StructuredTool subclass that injects device_id from RunnableConfig
at the invoke/ainvoke level, before any argument parsing.
NOTE: We bypass the parent's invoke/ainvoke to avoid Pydantic schema validation
which would strip the device_id field (since it's not in the new args_schema).
"""
def invoke(
@@ -26,16 +64,32 @@ class DeviceIdInjectedTool(StructuredTool):
config: Optional[RunnableConfig] = None,
**kwargs,
):
# Inject device_id from config into the input dict
# Extract actual args from ToolCall if needed
tool_args, tool_call_info = _extract_tool_args(input)
# Inject device_id from config into the tool args
if config and "configurable" in config:
device_id = config["configurable"].get("device_id")
logger.info(f"DeviceIdInjectedTool.invoke - device_id from config: {device_id}")
# Add device_id to input if it's valid (not None and not "0")
if isinstance(input, dict) and device_id is not None and device_id != "0":
input = {**input, "device_id": device_id}
# Add device_id to args if it's valid (not None and not "0")
if device_id is not None and device_id != "0":
tool_args = {**tool_args, "device_id": device_id}
return super().invoke(input, config, **kwargs)
logger.info(f"DeviceIdInjectedTool.invoke - calling with args: {list(tool_args.keys())}")
# Call the underlying func directly to bypass schema validation
# which would strip the device_id field not in args_schema
if self.func is not None:
result = self.func(**tool_args)
return _format_tool_result(result, tool_call_info or {})
elif self.coroutine is not None:
# Run async function synchronously
result = asyncio.run(self.coroutine(**tool_args))
return _format_tool_result(result, tool_call_info or {})
else:
# Fallback to parent implementation
return super().invoke(input, config, **kwargs)
async def ainvoke(
self,
@@ -44,16 +98,34 @@ class DeviceIdInjectedTool(StructuredTool):
**kwargs,
):
logger.info(f"========== DeviceIdInjectedTool.ainvoke CALLED ==========")
# Inject device_id from config into the input dict
# Extract actual args from ToolCall if needed
tool_args, tool_call_info = _extract_tool_args(input)
# Inject device_id from config into the tool args
if config and "configurable" in config:
device_id = config["configurable"].get("device_id")
logger.info(f"DeviceIdInjectedTool.ainvoke - device_id from config: {device_id}")
# Add device_id to input if it's valid (not None and not "0")
if isinstance(input, dict) and device_id is not None and device_id != "0":
input = {**input, "device_id": device_id}
# Add device_id to args if it's valid (not None and not "0")
if device_id is not None and device_id != "0":
tool_args = {**tool_args, "device_id": device_id}
return await super().ainvoke(input, config, **kwargs)
logger.info(f"DeviceIdInjectedTool.ainvoke - calling with args: {list(tool_args.keys())}")
# Call the underlying coroutine/func directly to bypass schema validation
# which would strip the device_id field not in args_schema
if self.coroutine is not None:
result = await self.coroutine(**tool_args)
return _format_tool_result(result, tool_call_info or {})
elif self.func is not None:
# Run sync function in thread pool to avoid blocking
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, lambda: self.func(**tool_args))
return _format_tool_result(result, tool_call_info or {})
else:
# Fallback to parent implementation
return await super().ainvoke(input, config, **kwargs)
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass