Files
lang-agent/lang_agent/config.py
2025-11-03 23:36:23 +08:00

181 lines
6.0 KiB
Python

from dataclasses import dataclass, is_dataclass, fields, MISSING
from typing import Any, Tuple, Type
import yaml
from pathlib import Path
from typing import Dict
import os
from loguru import logger
from dotenv import load_dotenv
load_dotenv()
## NOTE: base classes taken from nerfstudio
class PrintableConfig:
"""Printable Config defining str function"""
def __str__(self):
lines = [self.__class__.__name__ + ":"]
for key, val in vars(self).items():
if self.is_secrete(key):
val = str(val)
val = val[:3] + "*"*(len(val) - 6) + val[-3:]
if isinstance(val, Tuple):
flattened_val = "["
for item in val:
flattened_val += str(item) + "\n"
flattened_val = flattened_val.rstrip("\n")
val = flattened_val + "]"
lines += f"{key}: {str(val)}".split("\n")
return "\n" + "\n ".join(lines)
def is_secrete(self, inp:str):
sec_list = ["secret", "api_key"]
for sec in sec_list:
if sec in inp:
return True
return False
# Base instantiate configs
@dataclass
class InstantiateConfig(PrintableConfig):
"""Config class for instantiating an the class specified in the _target attribute."""
_target: Type
def setup(self, **kwargs) -> Any:
"""Returns the instantiated object using the config."""
return self._target(self, **kwargs)
def save_config(self, filename: str) -> None:
"""Save the config to a YAML file."""
def mask_value(key, value):
# Apply masking if key is secret-like
if isinstance(value, str) and self.is_secrete(key):
sval = str(value)
return sval[:3] + "*" * (len(sval) - 6) + sval[-3:]
return value
def to_masked_serializable(obj):
# Recursively convert dataclasses and containers to serializable with masked secrets
if is_dataclass(obj):
out = {}
for k, v in vars(obj).items():
if is_dataclass(v) or isinstance(v, (dict, list, tuple)):
out[k] = to_masked_serializable(v)
else:
out[k] = mask_value(k, v)
return out
if isinstance(obj, dict):
out = {}
for k, v in obj.items():
if is_dataclass(v) or isinstance(v, (dict, list, tuple)):
out[k] = to_masked_serializable(v)
else:
# k might be a non-string; convert to str for is_secrete check consistency
key_str = str(k)
out[k] = mask_value(key_str, v)
return out
if isinstance(obj, list):
return [to_masked_serializable(v) for v in obj]
if isinstance(obj, tuple):
return tuple(to_masked_serializable(v) for v in obj)
return obj
masked = to_masked_serializable(self)
with open(filename, 'w') as f:
yaml.dump(masked, f)
logger.info(f"[yellow]config saved to: {filename}[/yellow]")
def get_name(self):
return self.__class__.__name__
@dataclass
class KeyConfig(InstantiateConfig):
api_key:str = None
"""api key for llm"""
def __post_init__(self):
if self.api_key == "wrong-key" or self.api_key is None:
self.api_key = os.environ.get("ALI_API_KEY")
if self.api_key is None:
logger.error(f"no ALI_API_KEY provided for embedding")
else:
logger.info("ALI_API_KEY loaded from environ")
@dataclass
class ToolConfig(InstantiateConfig):
use_tool:bool = True
"""specify to use tool or not"""
def load_tyro_conf(filename: str, inp_conf = None) -> InstantiateConfig:
"""load and overwrite config from file"""
config = yaml.load(Path(filename).read_text(), Loader=yaml.Loader)
config = ovewrite_config(config, inp_conf) if inp_conf is not None else config
return config
def is_default(instance, field_):
"""
Check if the value of a field in a dataclass instance is the default value.
"""
value = getattr(instance, field_.name)
if field_.default is not MISSING:
# Compare with default value
return value == field_.default
elif field_.default_factory is not MISSING:
# Compare with value generated by the default factory
return value == field_.default_factory()
else:
# No default value specified
return False
def ovewrite_config(loaded_conf, inp_conf):
"""for non-default values in inp_conf, overwrite the corresponding values in loaded_conf"""
if not (is_dataclass(loaded_conf) and is_dataclass(inp_conf)):
return loaded_conf
for field_ in fields(loaded_conf):
field_name = field_.name
# if field_name in inp_conf:
current_value = getattr(inp_conf, field_name)
new_value = getattr(inp_conf, field_name) #inp_conf[field_name]
if is_dataclass(current_value):
# Recurse for nested dataclasses
merged_value = ovewrite_config(current_value, new_value)
setattr(loaded_conf, field_name, merged_value)
elif not is_default(inp_conf, field_):
# Overwrite only if the current value is not default
setattr(loaded_conf, field_name, new_value)
return loaded_conf
def mcp_langchain_to_ws_config(conf:Dict[str, Dict[str, str]]):
serv_conf = {}
for k, v in conf.items():
if v["transport"] == "stdio":
serv_conf[k] = {
"type" : v["transport"],
"command": v["command"],
"args": v["args"],
}
else:
logger.warning(f"Unsupported transport {v['transport']} for MCP {k}. Skipping...")
continue
return {"mcpServers":serv_conf}