update the device id correctly
This commit is contained in:
@@ -3,21 +3,59 @@ from typing import Type, Any, Optional
|
|||||||
import tyro
|
import tyro
|
||||||
import commentjson
|
import commentjson
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||||
from langchain_core.tools import BaseTool, StructuredTool
|
from langchain_core.tools import BaseTool, StructuredTool
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
from pydantic import BaseModel, Field, create_model
|
from pydantic import BaseModel, Field, create_model
|
||||||
|
|
||||||
from lang_agent.config import InstantiateConfig
|
from lang_agent.config import InstantiateConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _format_tool_result(result: Any, input: dict) -> str:
|
||||||
|
"""
|
||||||
|
Format the tool result to match the expected output format.
|
||||||
|
MCP tools return a tuple (result, error), which needs to be converted
|
||||||
|
to a JSON array string for consistency with StructuredTool.invoke() behavior.
|
||||||
|
"""
|
||||||
|
content = json.dumps(list(result))
|
||||||
|
return ToolMessage(content=content,
|
||||||
|
name=input.get("name"),
|
||||||
|
tool_call_id=input.get("id"))
|
||||||
|
|
||||||
|
|
||||||
|
def _is_tool_call(input: Any) -> bool:
|
||||||
|
"""Check if input is a ToolCall dict (has 'id' and 'args' keys)."""
|
||||||
|
return isinstance(input, dict) and "id" in input and "args" in input
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_tool_args(input: Any) -> tuple[dict, dict | None]:
|
||||||
|
"""
|
||||||
|
Extract tool arguments from input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(tool_args, tool_call_info) where tool_call_info contains id/name if it was a ToolCall
|
||||||
|
"""
|
||||||
|
if _is_tool_call(input):
|
||||||
|
# Input is a ToolCall: {"id": "...", "name": "...", "args": {...}}
|
||||||
|
tool_call_info = {"id": input.get("id"), "name": input.get("name")}
|
||||||
|
return input["args"].copy(), tool_call_info
|
||||||
|
else:
|
||||||
|
# Input is already the args dict
|
||||||
|
return input if isinstance(input, dict) else {}, None
|
||||||
|
|
||||||
|
|
||||||
class DeviceIdInjectedTool(StructuredTool):
|
class DeviceIdInjectedTool(StructuredTool):
|
||||||
"""
|
"""
|
||||||
A StructuredTool subclass that injects device_id from RunnableConfig
|
A StructuredTool subclass that injects device_id from RunnableConfig
|
||||||
at the invoke/ainvoke level, before any argument parsing.
|
at the invoke/ainvoke level, before any argument parsing.
|
||||||
|
|
||||||
|
NOTE: We bypass the parent's invoke/ainvoke to avoid Pydantic schema validation
|
||||||
|
which would strip the device_id field (since it's not in the new args_schema).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
@@ -26,15 +64,31 @@ class DeviceIdInjectedTool(StructuredTool):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# Inject device_id from config into the input dict
|
# Extract actual args from ToolCall if needed
|
||||||
|
tool_args, tool_call_info = _extract_tool_args(input)
|
||||||
|
|
||||||
|
# Inject device_id from config into the tool args
|
||||||
if config and "configurable" in config:
|
if config and "configurable" in config:
|
||||||
device_id = config["configurable"].get("device_id")
|
device_id = config["configurable"].get("device_id")
|
||||||
logger.info(f"DeviceIdInjectedTool.invoke - device_id from config: {device_id}")
|
logger.info(f"DeviceIdInjectedTool.invoke - device_id from config: {device_id}")
|
||||||
|
|
||||||
# Add device_id to input if it's valid (not None and not "0")
|
# Add device_id to args if it's valid (not None and not "0")
|
||||||
if isinstance(input, dict) and device_id is not None and device_id != "0":
|
if device_id is not None and device_id != "0":
|
||||||
input = {**input, "device_id": device_id}
|
tool_args = {**tool_args, "device_id": device_id}
|
||||||
|
|
||||||
|
logger.info(f"DeviceIdInjectedTool.invoke - calling with args: {list(tool_args.keys())}")
|
||||||
|
|
||||||
|
# Call the underlying func directly to bypass schema validation
|
||||||
|
# which would strip the device_id field not in args_schema
|
||||||
|
if self.func is not None:
|
||||||
|
result = self.func(**tool_args)
|
||||||
|
return _format_tool_result(result, tool_call_info or {})
|
||||||
|
elif self.coroutine is not None:
|
||||||
|
# Run async function synchronously
|
||||||
|
result = asyncio.run(self.coroutine(**tool_args))
|
||||||
|
return _format_tool_result(result, tool_call_info or {})
|
||||||
|
else:
|
||||||
|
# Fallback to parent implementation
|
||||||
return super().invoke(input, config, **kwargs)
|
return super().invoke(input, config, **kwargs)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
@@ -44,15 +98,33 @@ class DeviceIdInjectedTool(StructuredTool):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
logger.info(f"========== DeviceIdInjectedTool.ainvoke CALLED ==========")
|
logger.info(f"========== DeviceIdInjectedTool.ainvoke CALLED ==========")
|
||||||
# Inject device_id from config into the input dict
|
|
||||||
|
# Extract actual args from ToolCall if needed
|
||||||
|
tool_args, tool_call_info = _extract_tool_args(input)
|
||||||
|
|
||||||
|
# Inject device_id from config into the tool args
|
||||||
if config and "configurable" in config:
|
if config and "configurable" in config:
|
||||||
device_id = config["configurable"].get("device_id")
|
device_id = config["configurable"].get("device_id")
|
||||||
logger.info(f"DeviceIdInjectedTool.ainvoke - device_id from config: {device_id}")
|
logger.info(f"DeviceIdInjectedTool.ainvoke - device_id from config: {device_id}")
|
||||||
|
|
||||||
# Add device_id to input if it's valid (not None and not "0")
|
# Add device_id to args if it's valid (not None and not "0")
|
||||||
if isinstance(input, dict) and device_id is not None and device_id != "0":
|
if device_id is not None and device_id != "0":
|
||||||
input = {**input, "device_id": device_id}
|
tool_args = {**tool_args, "device_id": device_id}
|
||||||
|
|
||||||
|
logger.info(f"DeviceIdInjectedTool.ainvoke - calling with args: {list(tool_args.keys())}")
|
||||||
|
|
||||||
|
# Call the underlying coroutine/func directly to bypass schema validation
|
||||||
|
# which would strip the device_id field not in args_schema
|
||||||
|
if self.coroutine is not None:
|
||||||
|
result = await self.coroutine(**tool_args)
|
||||||
|
return _format_tool_result(result, tool_call_info or {})
|
||||||
|
elif self.func is not None:
|
||||||
|
# Run sync function in thread pool to avoid blocking
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
result = await loop.run_in_executor(None, lambda: self.func(**tool_args))
|
||||||
|
return _format_tool_result(result, tool_call_info or {})
|
||||||
|
else:
|
||||||
|
# Fallback to parent implementation
|
||||||
return await super().ainvoke(input, config, **kwargs)
|
return await super().ainvoke(input, config, **kwargs)
|
||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
|
|||||||
Reference in New Issue
Block a user