choose tool set
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from typing import Type, Any, Optional
|
||||
from typing import Type, Any, Optional, List
|
||||
import tyro
|
||||
import commentjson
|
||||
import asyncio
|
||||
@@ -235,6 +235,9 @@ class ClientToolManagerConfig(InstantiateConfig):
|
||||
mcp_config_f: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "mcp_config.json")
|
||||
"""path to all mcp configurations; expect json file"""
|
||||
|
||||
tool_keys: List = None
|
||||
"""tool configs to use; the keys inside mcp_config; if None, use everything"""
|
||||
|
||||
def __post_init__(self):
|
||||
assert osp.exists(self.mcp_config_f), f"config_f {self.mcp_config_f} does not exist."
|
||||
|
||||
@@ -247,16 +250,29 @@ class ClientToolManager:
|
||||
|
||||
def populate_module(self):
|
||||
with open(self.config.mcp_config_f, "r") as f:
|
||||
self.mcp_configs = commentjson.load(f)
|
||||
self.mcp_configs:dict = commentjson.load(f)
|
||||
|
||||
async def aget_tools(self):
|
||||
"""
|
||||
Get tools from all configured MCP servers.
|
||||
Handles connection failures gracefully by logging warnings and continuing.
|
||||
"""
|
||||
all_tools = []
|
||||
|
||||
for server_name, server_config in self.mcp_configs.items():
|
||||
def get_to_load_configs() -> dict:
|
||||
if self.config.tool_keys is None:
|
||||
to_load_config = self.mcp_configs
|
||||
else:
|
||||
to_load_config = {}
|
||||
for key in self.config.tool_keys:
|
||||
val = self.mcp_configs.get(key)
|
||||
if val is None:
|
||||
logger.warning(f"{key} is not in mcp tools")
|
||||
else:
|
||||
to_load_config[key] = val
|
||||
|
||||
to_load_config = get_to_load_configs()
|
||||
all_tools = []
|
||||
for server_name, server_config in to_load_config.items():
|
||||
try:
|
||||
# Create a client for this single server
|
||||
single_server_config = {server_name: server_config}
|
||||
|
||||
Reference in New Issue
Block a user