update the device id correctly

This commit is contained in:
2025-12-31 14:41:05 +08:00
parent e84146e549
commit 8e423dc8fe

View File

@@ -3,21 +3,59 @@ from typing import Type, Any, Optional
import tyro import tyro
import commentjson import commentjson
import asyncio import asyncio
import json
import os.path as osp import os.path as osp
from loguru import logger from loguru import logger
from langchain_mcp_adapters.client import MultiServerMCPClient from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_core.tools import BaseTool, StructuredTool from langchain_core.tools import BaseTool, StructuredTool
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from langchain_core.messages import ToolMessage
from pydantic import BaseModel, Field, create_model from pydantic import BaseModel, Field, create_model
from lang_agent.config import InstantiateConfig from lang_agent.config import InstantiateConfig
def _format_tool_result(result: Any, input: dict) -> str:
"""
Format the tool result to match the expected output format.
MCP tools return a tuple (result, error), which needs to be converted
to a JSON array string for consistency with StructuredTool.invoke() behavior.
"""
content = json.dumps(list(result))
return ToolMessage(content=content,
name=input.get("name"),
tool_call_id=input.get("id"))
def _is_tool_call(input: Any) -> bool:
"""Check if input is a ToolCall dict (has 'id' and 'args' keys)."""
return isinstance(input, dict) and "id" in input and "args" in input
def _extract_tool_args(input: Any) -> tuple[dict, dict | None]:
"""
Extract tool arguments from input.
Returns:
(tool_args, tool_call_info) where tool_call_info contains id/name if it was a ToolCall
"""
if _is_tool_call(input):
# Input is a ToolCall: {"id": "...", "name": "...", "args": {...}}
tool_call_info = {"id": input.get("id"), "name": input.get("name")}
return input["args"].copy(), tool_call_info
else:
# Input is already the args dict
return input if isinstance(input, dict) else {}, None
class DeviceIdInjectedTool(StructuredTool): class DeviceIdInjectedTool(StructuredTool):
""" """
A StructuredTool subclass that injects device_id from RunnableConfig A StructuredTool subclass that injects device_id from RunnableConfig
at the invoke/ainvoke level, before any argument parsing. at the invoke/ainvoke level, before any argument parsing.
NOTE: We bypass the parent's invoke/ainvoke to avoid Pydantic schema validation
which would strip the device_id field (since it's not in the new args_schema).
""" """
def invoke( def invoke(
@@ -26,15 +64,31 @@ class DeviceIdInjectedTool(StructuredTool):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs, **kwargs,
): ):
# Inject device_id from config into the input dict # Extract actual args from ToolCall if needed
tool_args, tool_call_info = _extract_tool_args(input)
# Inject device_id from config into the tool args
if config and "configurable" in config: if config and "configurable" in config:
device_id = config["configurable"].get("device_id") device_id = config["configurable"].get("device_id")
logger.info(f"DeviceIdInjectedTool.invoke - device_id from config: {device_id}") logger.info(f"DeviceIdInjectedTool.invoke - device_id from config: {device_id}")
# Add device_id to input if it's valid (not None and not "0") # Add device_id to args if it's valid (not None and not "0")
if isinstance(input, dict) and device_id is not None and device_id != "0": if device_id is not None and device_id != "0":
input = {**input, "device_id": device_id} tool_args = {**tool_args, "device_id": device_id}
logger.info(f"DeviceIdInjectedTool.invoke - calling with args: {list(tool_args.keys())}")
# Call the underlying func directly to bypass schema validation
# which would strip the device_id field not in args_schema
if self.func is not None:
result = self.func(**tool_args)
return _format_tool_result(result, tool_call_info or {})
elif self.coroutine is not None:
# Run async function synchronously
result = asyncio.run(self.coroutine(**tool_args))
return _format_tool_result(result, tool_call_info or {})
else:
# Fallback to parent implementation
return super().invoke(input, config, **kwargs) return super().invoke(input, config, **kwargs)
async def ainvoke( async def ainvoke(
@@ -44,15 +98,33 @@ class DeviceIdInjectedTool(StructuredTool):
**kwargs, **kwargs,
): ):
logger.info(f"========== DeviceIdInjectedTool.ainvoke CALLED ==========") logger.info(f"========== DeviceIdInjectedTool.ainvoke CALLED ==========")
# Inject device_id from config into the input dict
# Extract actual args from ToolCall if needed
tool_args, tool_call_info = _extract_tool_args(input)
# Inject device_id from config into the tool args
if config and "configurable" in config: if config and "configurable" in config:
device_id = config["configurable"].get("device_id") device_id = config["configurable"].get("device_id")
logger.info(f"DeviceIdInjectedTool.ainvoke - device_id from config: {device_id}") logger.info(f"DeviceIdInjectedTool.ainvoke - device_id from config: {device_id}")
# Add device_id to input if it's valid (not None and not "0") # Add device_id to args if it's valid (not None and not "0")
if isinstance(input, dict) and device_id is not None and device_id != "0": if device_id is not None and device_id != "0":
input = {**input, "device_id": device_id} tool_args = {**tool_args, "device_id": device_id}
logger.info(f"DeviceIdInjectedTool.ainvoke - calling with args: {list(tool_args.keys())}")
# Call the underlying coroutine/func directly to bypass schema validation
# which would strip the device_id field not in args_schema
if self.coroutine is not None:
result = await self.coroutine(**tool_args)
return _format_tool_result(result, tool_call_info or {})
elif self.func is not None:
# Run sync function in thread pool to avoid blocking
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, lambda: self.func(**tool_args))
return _format_tool_result(result, tool_call_info or {})
else:
# Fallback to parent implementation
return await super().ainvoke(input, config, **kwargs) return await super().ainvoke(input, config, **kwargs)
@tyro.conf.configure(tyro.conf.SuppressFixed) @tyro.conf.configure(tyro.conf.SuppressFixed)