choose tool set
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass, field, is_dataclass
|
from dataclasses import dataclass, field, is_dataclass
|
||||||
from typing import Type, Any, Optional
|
from typing import Type, Any, Optional, List
|
||||||
import tyro
|
import tyro
|
||||||
import commentjson
|
import commentjson
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -235,6 +235,9 @@ class ClientToolManagerConfig(InstantiateConfig):
|
|||||||
mcp_config_f: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "mcp_config.json")
|
mcp_config_f: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "mcp_config.json")
|
||||||
"""path to all mcp configurations; expect json file"""
|
"""path to all mcp configurations; expect json file"""
|
||||||
|
|
||||||
|
tool_keys: List = None
|
||||||
|
"""tool configs to use; the keys inside mcp_config; if None, use everything"""
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
assert osp.exists(self.mcp_config_f), f"config_f {self.mcp_config_f} does not exist."
|
assert osp.exists(self.mcp_config_f), f"config_f {self.mcp_config_f} does not exist."
|
||||||
|
|
||||||
@@ -247,16 +250,29 @@ class ClientToolManager:
|
|||||||
|
|
||||||
def populate_module(self):
|
def populate_module(self):
|
||||||
with open(self.config.mcp_config_f, "r") as f:
|
with open(self.config.mcp_config_f, "r") as f:
|
||||||
self.mcp_configs = commentjson.load(f)
|
self.mcp_configs:dict = commentjson.load(f)
|
||||||
|
|
||||||
async def aget_tools(self):
|
async def aget_tools(self):
|
||||||
"""
|
"""
|
||||||
Get tools from all configured MCP servers.
|
Get tools from all configured MCP servers.
|
||||||
Handles connection failures gracefully by logging warnings and continuing.
|
Handles connection failures gracefully by logging warnings and continuing.
|
||||||
"""
|
"""
|
||||||
all_tools = []
|
|
||||||
|
|
||||||
for server_name, server_config in self.mcp_configs.items():
|
def get_to_load_configs() -> dict:
|
||||||
|
if self.config.tool_keys is None:
|
||||||
|
to_load_config = self.mcp_configs
|
||||||
|
else:
|
||||||
|
to_load_config = {}
|
||||||
|
for key in self.config.tool_keys:
|
||||||
|
val = self.mcp_configs.get(key)
|
||||||
|
if val is None:
|
||||||
|
logger.warning(f"{key} is not in mcp tools")
|
||||||
|
else:
|
||||||
|
to_load_config[key] = val
|
||||||
|
|
||||||
|
to_load_config = get_to_load_configs()
|
||||||
|
all_tools = []
|
||||||
|
for server_name, server_config in to_load_config.items():
|
||||||
try:
|
try:
|
||||||
# Create a client for this single server
|
# Create a client for this single server
|
||||||
single_server_config = {server_name: server_config}
|
single_server_config = {server_name: server_config}
|
||||||
|
|||||||
Reference in New Issue
Block a user