diff --git a/lang_agent/config.py b/lang_agent/config.py index 1419bd5..bcf29a4 100644 --- a/lang_agent/config.py +++ b/lang_agent/config.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, is_dataclass, fields, MISSING from typing import Any, Tuple, Type import yaml from pathlib import Path -import os +from typing import Dict from loguru import logger @@ -56,7 +56,7 @@ class InstantiateConfig(PrintableConfig): -def load_config(filename: str, inp_conf = None) -> InstantiateConfig: +def load_tyro_conf(filename: str, inp_conf = None) -> InstantiateConfig: """load and overwrite config from file""" config = yaml.load(Path(filename).read_text(), Loader=yaml.Loader) @@ -99,3 +99,21 @@ def ovewrite_config(loaded_conf, inp_conf): setattr(loaded_conf, field_name, new_value) return loaded_conf + + +def mcp_langchain_to_ws_config(conf:Dict[str, Dict[str, str]]): + serv_conf = {} + + for k, v in conf.items(): + + if v["transport"] == "stdio": + serv_conf[k] = { + "type" : v["transport"], + "command": v["command"], + "args": ["-m"] + v["args"] if v["command"] == "python" else v["args"], + } + else: + logger.warning(f"Unsupported transport {v['transport']} for MCP {k}. Skipping...") + continue + + return {"mcpServers":serv_conf}