Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
153
sam3/eval/teta_eval_toolkit/config.py
Normal file
153
sam3/eval/teta_eval_toolkit/config.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
|
||||
"""Config."""
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
def parse_configs():
|
||||
"""Parse command line."""
|
||||
default_eval_config = get_default_eval_config()
|
||||
default_eval_config["DISPLAY_LESS_PROGRESS"] = True
|
||||
default_dataset_config = get_default_dataset_config()
|
||||
default_metrics_config = {"METRICS": ["TETA"]}
|
||||
config = {
|
||||
**default_eval_config,
|
||||
**default_dataset_config,
|
||||
**default_metrics_config,
|
||||
}
|
||||
parser = argparse.ArgumentParser()
|
||||
for setting in config.keys():
|
||||
if type(config[setting]) == list or type(config[setting]) == type(None):
|
||||
parser.add_argument("--" + setting, nargs="+")
|
||||
else:
|
||||
parser.add_argument("--" + setting)
|
||||
args = parser.parse_args().__dict__
|
||||
for setting in args.keys():
|
||||
if args[setting] is not None:
|
||||
if type(config[setting]) == type(True):
|
||||
if args[setting] == "True":
|
||||
x = True
|
||||
elif args[setting] == "False":
|
||||
x = False
|
||||
else:
|
||||
raise Exception(
|
||||
f"Command line parameter {setting} must be True/False"
|
||||
)
|
||||
elif type(config[setting]) == type(1):
|
||||
x = int(args[setting])
|
||||
elif type(args[setting]) == type(None):
|
||||
x = None
|
||||
else:
|
||||
x = args[setting]
|
||||
config[setting] = x
|
||||
eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()}
|
||||
dataset_config = {
|
||||
k: v for k, v in config.items() if k in default_dataset_config.keys()
|
||||
}
|
||||
metrics_config = {
|
||||
k: v for k, v in config.items() if k in default_metrics_config.keys()
|
||||
}
|
||||
|
||||
return eval_config, dataset_config, metrics_config
|
||||
|
||||
|
||||
def get_default_eval_config():
|
||||
"""Returns the default config values for evaluation."""
|
||||
code_path = get_code_path()
|
||||
default_config = {
|
||||
"USE_PARALLEL": True,
|
||||
"NUM_PARALLEL_CORES": 8,
|
||||
"BREAK_ON_ERROR": True,
|
||||
"RETURN_ON_ERROR": False,
|
||||
"LOG_ON_ERROR": os.path.join(code_path, "error_log.txt"),
|
||||
"PRINT_RESULTS": True,
|
||||
"PRINT_ONLY_COMBINED": True,
|
||||
"PRINT_CONFIG": True,
|
||||
"TIME_PROGRESS": True,
|
||||
"DISPLAY_LESS_PROGRESS": True,
|
||||
"OUTPUT_SUMMARY": True,
|
||||
"OUTPUT_EMPTY_CLASSES": True,
|
||||
"OUTPUT_TEM_RAW_DATA": True,
|
||||
"OUTPUT_PER_SEQ_RES": True,
|
||||
}
|
||||
return default_config
|
||||
|
||||
|
||||
def get_default_dataset_config():
|
||||
"""Default class config values"""
|
||||
code_path = get_code_path()
|
||||
default_config = {
|
||||
"GT_FOLDER": os.path.join(
|
||||
code_path, "data/gt/tao/tao_training"
|
||||
), # Location of GT data
|
||||
"TRACKERS_FOLDER": os.path.join(
|
||||
code_path, "data/trackers/tao/tao_training"
|
||||
), # Trackers location
|
||||
"OUTPUT_FOLDER": None, # Where to save eval results (if None, same as TRACKERS_FOLDER)
|
||||
"TRACKERS_TO_EVAL": ['TETer'], # Filenames of trackers to eval (if None, all in folder)
|
||||
"CLASSES_TO_EVAL": None, # Classes to eval (if None, all classes)
|
||||
"SPLIT_TO_EVAL": "training", # Valid: 'training', 'val'
|
||||
"PRINT_CONFIG": True, # Whether to print current config
|
||||
"TRACKER_SUB_FOLDER": "data", # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER
|
||||
"OUTPUT_SUB_FOLDER": "", # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER
|
||||
"TRACKER_DISPLAY_NAMES": None, # Names of trackers to display, if None: TRACKERS_TO_EVAL
|
||||
"MAX_DETECTIONS": 0, # Number of maximal allowed detections per image (0 for unlimited)
|
||||
"USE_MASK": False, # Whether to use mask data for evaluation
|
||||
}
|
||||
return default_config
|
||||
|
||||
|
||||
def init_config(config, default_config, name=None):
|
||||
"""Initialize non-given config values with defaults."""
|
||||
if config is None:
|
||||
config = default_config
|
||||
else:
|
||||
for k in default_config.keys():
|
||||
if k not in config.keys():
|
||||
config[k] = default_config[k]
|
||||
if name and config["PRINT_CONFIG"]:
|
||||
print("\n%s Config:" % name)
|
||||
for c in config.keys():
|
||||
print("%-20s : %-30s" % (c, config[c]))
|
||||
return config
|
||||
|
||||
|
||||
def update_config(config):
|
||||
"""
|
||||
Parse the arguments of a script and updates the config values for a given value if specified in the arguments.
|
||||
:param config: the config to update
|
||||
:return: the updated config
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
for setting in config.keys():
|
||||
if type(config[setting]) == list or type(config[setting]) == type(None):
|
||||
parser.add_argument("--" + setting, nargs="+")
|
||||
else:
|
||||
parser.add_argument("--" + setting)
|
||||
args = parser.parse_args().__dict__
|
||||
for setting in args.keys():
|
||||
if args[setting] is not None:
|
||||
if type(config[setting]) == type(True):
|
||||
if args[setting] == "True":
|
||||
x = True
|
||||
elif args[setting] == "False":
|
||||
x = False
|
||||
else:
|
||||
raise Exception(
|
||||
"Command line parameter " + setting + "must be True or False"
|
||||
)
|
||||
elif type(config[setting]) == type(1):
|
||||
x = int(args[setting])
|
||||
elif type(args[setting]) == type(None):
|
||||
x = None
|
||||
else:
|
||||
x = args[setting]
|
||||
config[setting] = x
|
||||
return config
|
||||
|
||||
|
||||
def get_code_path():
|
||||
"""Get base path where code is"""
|
||||
return os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
Reference in New Issue
Block a user