update the device id correctly
This commit is contained in:
@@ -3,21 +3,59 @@ from typing import Type, Any, Optional
|
||||
import tyro
|
||||
import commentjson
|
||||
import asyncio
|
||||
import json
|
||||
import os.path as osp
|
||||
from loguru import logger
|
||||
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_core.tools import BaseTool, StructuredTool
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.messages import ToolMessage
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from lang_agent.config import InstantiateConfig
|
||||
|
||||
|
||||
def _format_tool_result(result: Any, input: dict) -> str:
|
||||
"""
|
||||
Format the tool result to match the expected output format.
|
||||
MCP tools return a tuple (result, error), which needs to be converted
|
||||
to a JSON array string for consistency with StructuredTool.invoke() behavior.
|
||||
"""
|
||||
content = json.dumps(list(result))
|
||||
return ToolMessage(content=content,
|
||||
name=input.get("name"),
|
||||
tool_call_id=input.get("id"))
|
||||
|
||||
|
||||
def _is_tool_call(input: Any) -> bool:
|
||||
"""Check if input is a ToolCall dict (has 'id' and 'args' keys)."""
|
||||
return isinstance(input, dict) and "id" in input and "args" in input
|
||||
|
||||
|
||||
def _extract_tool_args(input: Any) -> tuple[dict, dict | None]:
|
||||
"""
|
||||
Extract tool arguments from input.
|
||||
|
||||
Returns:
|
||||
(tool_args, tool_call_info) where tool_call_info contains id/name if it was a ToolCall
|
||||
"""
|
||||
if _is_tool_call(input):
|
||||
# Input is a ToolCall: {"id": "...", "name": "...", "args": {...}}
|
||||
tool_call_info = {"id": input.get("id"), "name": input.get("name")}
|
||||
return input["args"].copy(), tool_call_info
|
||||
else:
|
||||
# Input is already the args dict
|
||||
return input if isinstance(input, dict) else {}, None
|
||||
|
||||
|
||||
class DeviceIdInjectedTool(StructuredTool):
|
||||
"""
|
||||
A StructuredTool subclass that injects device_id from RunnableConfig
|
||||
at the invoke/ainvoke level, before any argument parsing.
|
||||
|
||||
NOTE: We bypass the parent's invoke/ainvoke to avoid Pydantic schema validation
|
||||
which would strip the device_id field (since it's not in the new args_schema).
|
||||
"""
|
||||
|
||||
def invoke(
|
||||
@@ -26,16 +64,32 @@ class DeviceIdInjectedTool(StructuredTool):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Inject device_id from config into the input dict
|
||||
# Extract actual args from ToolCall if needed
|
||||
tool_args, tool_call_info = _extract_tool_args(input)
|
||||
|
||||
# Inject device_id from config into the tool args
|
||||
if config and "configurable" in config:
|
||||
device_id = config["configurable"].get("device_id")
|
||||
logger.info(f"DeviceIdInjectedTool.invoke - device_id from config: {device_id}")
|
||||
|
||||
# Add device_id to input if it's valid (not None and not "0")
|
||||
if isinstance(input, dict) and device_id is not None and device_id != "0":
|
||||
input = {**input, "device_id": device_id}
|
||||
# Add device_id to args if it's valid (not None and not "0")
|
||||
if device_id is not None and device_id != "0":
|
||||
tool_args = {**tool_args, "device_id": device_id}
|
||||
|
||||
return super().invoke(input, config, **kwargs)
|
||||
logger.info(f"DeviceIdInjectedTool.invoke - calling with args: {list(tool_args.keys())}")
|
||||
|
||||
# Call the underlying func directly to bypass schema validation
|
||||
# which would strip the device_id field not in args_schema
|
||||
if self.func is not None:
|
||||
result = self.func(**tool_args)
|
||||
return _format_tool_result(result, tool_call_info or {})
|
||||
elif self.coroutine is not None:
|
||||
# Run async function synchronously
|
||||
result = asyncio.run(self.coroutine(**tool_args))
|
||||
return _format_tool_result(result, tool_call_info or {})
|
||||
else:
|
||||
# Fallback to parent implementation
|
||||
return super().invoke(input, config, **kwargs)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
@@ -44,16 +98,34 @@ class DeviceIdInjectedTool(StructuredTool):
|
||||
**kwargs,
|
||||
):
|
||||
logger.info(f"========== DeviceIdInjectedTool.ainvoke CALLED ==========")
|
||||
# Inject device_id from config into the input dict
|
||||
|
||||
# Extract actual args from ToolCall if needed
|
||||
tool_args, tool_call_info = _extract_tool_args(input)
|
||||
|
||||
# Inject device_id from config into the tool args
|
||||
if config and "configurable" in config:
|
||||
device_id = config["configurable"].get("device_id")
|
||||
logger.info(f"DeviceIdInjectedTool.ainvoke - device_id from config: {device_id}")
|
||||
|
||||
# Add device_id to input if it's valid (not None and not "0")
|
||||
if isinstance(input, dict) and device_id is not None and device_id != "0":
|
||||
input = {**input, "device_id": device_id}
|
||||
# Add device_id to args if it's valid (not None and not "0")
|
||||
if device_id is not None and device_id != "0":
|
||||
tool_args = {**tool_args, "device_id": device_id}
|
||||
|
||||
return await super().ainvoke(input, config, **kwargs)
|
||||
logger.info(f"DeviceIdInjectedTool.ainvoke - calling with args: {list(tool_args.keys())}")
|
||||
|
||||
# Call the underlying coroutine/func directly to bypass schema validation
|
||||
# which would strip the device_id field not in args_schema
|
||||
if self.coroutine is not None:
|
||||
result = await self.coroutine(**tool_args)
|
||||
return _format_tool_result(result, tool_call_info or {})
|
||||
elif self.func is not None:
|
||||
# Run sync function in thread pool to avoid blocking
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(None, lambda: self.func(**tool_args))
|
||||
return _format_tool_result(result, tool_call_info or {})
|
||||
else:
|
||||
# Fallback to parent implementation
|
||||
return await super().ainvoke(input, config, **kwargs)
|
||||
|
||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user