From 522e1a551741b9c6289c44eade91d77036a56264 Mon Sep 17 00:00:00 2001 From: goulustis Date: Fri, 10 Oct 2025 16:28:50 +0800 Subject: [PATCH] add config class --- lang_agent/config.py | 115 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 lang_agent/config.py diff --git a/lang_agent/config.py b/lang_agent/config.py new file mode 100644 index 0000000..2fda88b --- /dev/null +++ b/lang_agent/config.py @@ -0,0 +1,115 @@ +from dataclasses import dataclass, is_dataclass, fields, MISSING +from typing import Any, Tuple, Type +import yaml +from pathlib import Path +import os + +from loguru import logger + +## base classes taken from nerfstudio +# Pretty printing class +class PrintableConfig: + """Printable Config defining str function""" + + def __str__(self): + lines = [self.__class__.__name__ + ":"] + for key, val in vars(self).items(): + + if key.endswith("_secret") or ("key" in key): + val = "****" + + if isinstance(val, Tuple): + flattened_val = "[" + for item in val: + flattened_val += str(item) + "\n" + flattened_val = flattened_val.rstrip("\n") + val = flattened_val + "]" + lines += f"{key}: {str(val)}".split("\n") + return "\n" + "\n ".join(lines) + + +# Base instantiate configs +@dataclass +class InstantiateConfig(PrintableConfig): + """Config class for instantiating an the class specified in the _target attribute.""" + + _target: Type + + def setup(self, **kwargs) -> Any: + """Returns the instantiated object using the config.""" + return self._target(self, **kwargs) + + def save_config(self, filename: str) -> None: + """Save the config to a YAML file.""" + with open(filename, 'w') as f: + yaml.dump(self, f) + logger.info(f"[yellow]config saved to: {filename}[/yellow]") + + +@dataclass +class LiveConfig(InstantiateConfig): + key_id: str = None + """alpaca key id""" + + key_secret: str = None + """alpaca secret""" + + paper: bool = True + """is paper trading or not""" + + def __post_init__(self): + if self.key_id is None: + from dotenv import load_dotenv + load_dotenv() + + self.key_id = os.getenv("ALPACA_KEY_ID") + self.key_secret = os.getenv("ALPACA_KEY_SECRET") + + assert self.key_id is not None, "alpaca key id required!" + assert self.key_secret is not None, "alpaca key secret required!" + + + +def load_config(filename: str, inp_conf = None) -> InstantiateConfig: + """load and overwrite config from file""" + config = yaml.load(Path(filename).read_text(), Loader=yaml.Loader) + + config = ovewrite_config(config, inp_conf) if inp_conf is not None else config + return config + +def is_default(instance, field_): + """ + Check if the value of a field in a dataclass instance is the default value. + """ + value = getattr(instance, field_.name) + + if field_.default is not MISSING: + # Compare with default value + return value == field_.default + elif field_.default_factory is not MISSING: + # Compare with value generated by the default factory + return value == field_.default_factory() + else: + # No default value specified + return False + +def ovewrite_config(loaded_conf, inp_conf): + """for non-default values in inp_conf, overwrite the corresponding values in loaded_conf""" + if not (is_dataclass(loaded_conf) and is_dataclass(inp_conf)): + return loaded_conf + + for field_ in fields(loaded_conf): + field_name = field_.name + # if field_name in inp_conf: + current_value = getattr(inp_conf, field_name) + new_value = getattr(inp_conf, field_name) #inp_conf[field_name] + + if is_dataclass(current_value): + # Recurse for nested dataclasses + merged_value = ovewrite_config(current_value, new_value) + setattr(loaded_conf, field_name, merged_value) + elif not is_default(inp_conf, field_): + # Overwrite only if the current value is not default + setattr(loaded_conf, field_name, new_value) + + return loaded_conf