add config class
This commit is contained in:
115
lang_agent/config.py
Normal file
115
lang_agent/config.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from dataclasses import dataclass, is_dataclass, fields, MISSING
|
||||
from typing import Any, Tuple, Type
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
from loguru import logger
|
||||
|
||||
## base classes taken from nerfstudio
|
||||
# Pretty printing class
|
||||
class PrintableConfig:
|
||||
"""Printable Config defining str function"""
|
||||
|
||||
def __str__(self):
|
||||
lines = [self.__class__.__name__ + ":"]
|
||||
for key, val in vars(self).items():
|
||||
|
||||
if key.endswith("_secret") or ("key" in key):
|
||||
val = "****"
|
||||
|
||||
if isinstance(val, Tuple):
|
||||
flattened_val = "["
|
||||
for item in val:
|
||||
flattened_val += str(item) + "\n"
|
||||
flattened_val = flattened_val.rstrip("\n")
|
||||
val = flattened_val + "]"
|
||||
lines += f"{key}: {str(val)}".split("\n")
|
||||
return "\n" + "\n ".join(lines)
|
||||
|
||||
|
||||
# Base instantiate configs
|
||||
@dataclass
|
||||
class InstantiateConfig(PrintableConfig):
|
||||
"""Config class for instantiating an the class specified in the _target attribute."""
|
||||
|
||||
_target: Type
|
||||
|
||||
def setup(self, **kwargs) -> Any:
|
||||
"""Returns the instantiated object using the config."""
|
||||
return self._target(self, **kwargs)
|
||||
|
||||
def save_config(self, filename: str) -> None:
|
||||
"""Save the config to a YAML file."""
|
||||
with open(filename, 'w') as f:
|
||||
yaml.dump(self, f)
|
||||
logger.info(f"[yellow]config saved to: {filename}[/yellow]")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LiveConfig(InstantiateConfig):
|
||||
key_id: str = None
|
||||
"""alpaca key id"""
|
||||
|
||||
key_secret: str = None
|
||||
"""alpaca secret"""
|
||||
|
||||
paper: bool = True
|
||||
"""is paper trading or not"""
|
||||
|
||||
def __post_init__(self):
|
||||
if self.key_id is None:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
self.key_id = os.getenv("ALPACA_KEY_ID")
|
||||
self.key_secret = os.getenv("ALPACA_KEY_SECRET")
|
||||
|
||||
assert self.key_id is not None, "alpaca key id required!"
|
||||
assert self.key_secret is not None, "alpaca key secret required!"
|
||||
|
||||
|
||||
|
||||
def load_config(filename: str, inp_conf = None) -> InstantiateConfig:
|
||||
"""load and overwrite config from file"""
|
||||
config = yaml.load(Path(filename).read_text(), Loader=yaml.Loader)
|
||||
|
||||
config = ovewrite_config(config, inp_conf) if inp_conf is not None else config
|
||||
return config
|
||||
|
||||
def is_default(instance, field_):
|
||||
"""
|
||||
Check if the value of a field in a dataclass instance is the default value.
|
||||
"""
|
||||
value = getattr(instance, field_.name)
|
||||
|
||||
if field_.default is not MISSING:
|
||||
# Compare with default value
|
||||
return value == field_.default
|
||||
elif field_.default_factory is not MISSING:
|
||||
# Compare with value generated by the default factory
|
||||
return value == field_.default_factory()
|
||||
else:
|
||||
# No default value specified
|
||||
return False
|
||||
|
||||
def ovewrite_config(loaded_conf, inp_conf):
|
||||
"""for non-default values in inp_conf, overwrite the corresponding values in loaded_conf"""
|
||||
if not (is_dataclass(loaded_conf) and is_dataclass(inp_conf)):
|
||||
return loaded_conf
|
||||
|
||||
for field_ in fields(loaded_conf):
|
||||
field_name = field_.name
|
||||
# if field_name in inp_conf:
|
||||
current_value = getattr(inp_conf, field_name)
|
||||
new_value = getattr(inp_conf, field_name) #inp_conf[field_name]
|
||||
|
||||
if is_dataclass(current_value):
|
||||
# Recurse for nested dataclasses
|
||||
merged_value = ovewrite_config(current_value, new_value)
|
||||
setattr(loaded_conf, field_name, merged_value)
|
||||
elif not is_default(inp_conf, field_):
|
||||
# Overwrite only if the current value is not default
|
||||
setattr(loaded_conf, field_name, new_value)
|
||||
|
||||
return loaded_conf
|
||||
Reference in New Issue
Block a user