change save location
This commit is contained in:
@@ -72,56 +72,12 @@ class InstantiateConfig(PrintableConfig):
|
|||||||
将配置保存到 YAML 文件
|
将配置保存到 YAML 文件
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
# Persist the full config object (including type tags) so it can be
|
||||||
def mask_value(key, value):
|
# deserialized back into config instances with methods like .setup().
|
||||||
"""
|
# Secret masking is intentionally handled by __str__ for printing/logging,
|
||||||
Apply masking if key is secret-like
|
# not when writing to disk.
|
||||||
如果键是敏感的,应用掩码
|
with open(filename, "w", encoding="utf-8") as f:
|
||||||
|
yaml.dump(self, f)
|
||||||
检查键是否敏感(如包含 "secret" 或 "api_key"),如果是,则对值进行掩码处理
|
|
||||||
"""
|
|
||||||
if isinstance(value, str) and self.is_secrete(str(key)):
|
|
||||||
sval = str(value)
|
|
||||||
return sval[:3] + "*" * (len(sval) - 6) + sval[-3:]
|
|
||||||
return value
|
|
||||||
|
|
||||||
def to_serializable(obj, apply_mask: bool):
|
|
||||||
"""
|
|
||||||
Recursively convert dataclasses and containers to serializable format,
|
|
||||||
optionally masking secrets.
|
|
||||||
|
|
||||||
递归地将数据类和容器转换为可序列化的格式,可选地对敏感信息进行掩码处理
|
|
||||||
"""
|
|
||||||
if is_dataclass(obj):
|
|
||||||
out = {}
|
|
||||||
for k, v in vars(obj).items():
|
|
||||||
if is_dataclass(v) or isinstance(v, (dict, list, tuple)):
|
|
||||||
out[k] = to_serializable(v, apply_mask)
|
|
||||||
else:
|
|
||||||
out[k] = mask_value(k, v) if apply_mask else v
|
|
||||||
return out
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
out = {}
|
|
||||||
for k, v in obj.items():
|
|
||||||
if is_dataclass(v) or isinstance(v, (dict, list, tuple)):
|
|
||||||
out[k] = to_serializable(v, apply_mask)
|
|
||||||
else:
|
|
||||||
key_str = str(k)
|
|
||||||
out[k] = mask_value(key_str, v) if apply_mask else v
|
|
||||||
return out
|
|
||||||
if isinstance(obj, list):
|
|
||||||
return [to_serializable(v, apply_mask) for v in obj]
|
|
||||||
if isinstance(obj, tuple):
|
|
||||||
return tuple(to_serializable(v, apply_mask) for v in obj)
|
|
||||||
return obj
|
|
||||||
|
|
||||||
# NOTE: we intentionally do NOT mask secrets when saving to disk so that
|
|
||||||
# configs can be reloaded with real values. Masking is handled in __str__
|
|
||||||
# for safe logging/printing. If you need a redacted copy, call
|
|
||||||
# to_serializable(self, apply_mask=True) manually and dump it yourself.
|
|
||||||
serializable = to_serializable(self, apply_mask=False)
|
|
||||||
with open(filename, 'w') as f:
|
|
||||||
yaml.dump(serializable, f)
|
|
||||||
logger.info(f"[yellow]config saved to: {filename}[/yellow]")
|
logger.info(f"[yellow]config saved to: {filename}[/yellow]")
|
||||||
|
|
||||||
def get_name(self):
|
def get_name(self):
|
||||||
@@ -182,7 +138,7 @@ def load_tyro_conf(filename: str, inp_conf = None) -> InstantiateConfig:
|
|||||||
"""
|
"""
|
||||||
config = yaml.load(Path(filename).read_text(), Loader=yaml.Loader)
|
config = yaml.load(Path(filename).read_text(), Loader=yaml.Loader)
|
||||||
|
|
||||||
config = ovewrite_config(config, inp_conf) if inp_conf is not None else config
|
# config = ovewrite_config(config, inp_conf) if inp_conf is not None else config
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def is_default(instance, field_):
|
def is_default(instance, field_):
|
||||||
|
|||||||
Reference in New Issue
Block a user