choose tool set

This commit is contained in:
2026-01-26 16:22:30 +08:00
parent d885178ebf
commit 13a0316729

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field, is_dataclass from dataclasses import dataclass, field, is_dataclass
from typing import Type, Any, Optional from typing import Type, Any, Optional, List
import tyro import tyro
import commentjson import commentjson
import asyncio import asyncio
@@ -235,6 +235,9 @@ class ClientToolManagerConfig(InstantiateConfig):
mcp_config_f: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "mcp_config.json") mcp_config_f: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "mcp_config.json")
"""path to all mcp configurations; expect json file""" """path to all mcp configurations; expect json file"""
tool_keys: List = None
"""tool configs to use; the keys inside mcp_config; if None, use everything"""
def __post_init__(self): def __post_init__(self):
assert osp.exists(self.mcp_config_f), f"config_f {self.mcp_config_f} does not exist." assert osp.exists(self.mcp_config_f), f"config_f {self.mcp_config_f} does not exist."
@@ -247,16 +250,29 @@ class ClientToolManager:
def populate_module(self): def populate_module(self):
with open(self.config.mcp_config_f, "r") as f: with open(self.config.mcp_config_f, "r") as f:
self.mcp_configs = commentjson.load(f) self.mcp_configs:dict = commentjson.load(f)
async def aget_tools(self): async def aget_tools(self):
""" """
Get tools from all configured MCP servers. Get tools from all configured MCP servers.
Handles connection failures gracefully by logging warnings and continuing. Handles connection failures gracefully by logging warnings and continuing.
""" """
all_tools = []
for server_name, server_config in self.mcp_configs.items(): def get_to_load_configs() -> dict:
if self.config.tool_keys is None:
to_load_config = self.mcp_configs
else:
to_load_config = {}
for key in self.config.tool_keys:
val = self.mcp_configs.get(key)
if val is None:
logger.warning(f"{key} is not in mcp tools")
else:
to_load_config[key] = val
to_load_config = get_to_load_configs()
all_tools = []
for server_name, server_config in to_load_config.items():
try: try:
# Create a client for this single server # Create a client for this single server
single_server_config = {server_name: server_config} single_server_config = {server_name: server_config}