choose tool set

This commit is contained in:
2026-01-26 16:22:30 +08:00
parent d885178ebf
commit 13a0316729

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field, is_dataclass
from typing import Type, Any, Optional
from typing import Type, Any, Optional, List
import tyro
import commentjson
import asyncio
@@ -235,6 +235,9 @@ class ClientToolManagerConfig(InstantiateConfig):
mcp_config_f: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "mcp_config.json")
"""path to all mcp configurations; expect json file"""
tool_keys: List = None
"""tool configs to use; the keys inside mcp_config; if None, use everything"""
def __post_init__(self):
assert osp.exists(self.mcp_config_f), f"config_f {self.mcp_config_f} does not exist."
@@ -247,16 +250,29 @@ class ClientToolManager:
def populate_module(self):
with open(self.config.mcp_config_f, "r") as f:
self.mcp_configs = commentjson.load(f)
self.mcp_configs:dict = commentjson.load(f)
async def aget_tools(self):
"""
Get tools from all configured MCP servers.
Handles connection failures gracefully by logging warnings and continuing.
"""
all_tools = []
for server_name, server_config in self.mcp_configs.items():
def get_to_load_configs() -> dict:
if self.config.tool_keys is None:
to_load_config = self.mcp_configs
else:
to_load_config = {}
for key in self.config.tool_keys:
val = self.mcp_configs.get(key)
if val is None:
logger.warning(f"{key} is not in mcp tools")
else:
to_load_config[key] = val
to_load_config = get_to_load_configs()
all_tools = []
for server_name, server_config in to_load_config.items():
try:
# Create a client for this single server
single_server_config = {server_name: server_config}