Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
195
sam3/eval/hota_eval_toolkit/trackeval/utils.py
Normal file
195
sam3/eval/hota_eval_toolkit/trackeval/utils.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# flake8: noqa
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def init_config(config, default_config, name=None):
|
||||
"""Initialise non-given config values with defaults"""
|
||||
if config is None:
|
||||
config = default_config
|
||||
else:
|
||||
for k in default_config.keys():
|
||||
if k not in config.keys():
|
||||
config[k] = default_config[k]
|
||||
if name and config["PRINT_CONFIG"]:
|
||||
print("\n%s Config:" % name)
|
||||
for c in config.keys():
|
||||
print("%-20s : %-30s" % (c, config[c]))
|
||||
return config
|
||||
|
||||
|
||||
def update_config(config):
|
||||
"""
|
||||
Parse the arguments of a script and updates the config values for a given value if specified in the arguments.
|
||||
:param config: the config to update
|
||||
:return: the updated config
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
for setting in config.keys():
|
||||
if type(config[setting]) == list or type(config[setting]) == type(None):
|
||||
parser.add_argument("--" + setting, nargs="+")
|
||||
else:
|
||||
parser.add_argument("--" + setting)
|
||||
args = parser.parse_args().__dict__
|
||||
for setting in args.keys():
|
||||
if args[setting] is not None:
|
||||
if type(config[setting]) == type(True):
|
||||
if args[setting] == "True":
|
||||
x = True
|
||||
elif args[setting] == "False":
|
||||
x = False
|
||||
else:
|
||||
raise Exception(
|
||||
"Command line parameter " + setting + "must be True or False"
|
||||
)
|
||||
elif type(config[setting]) == type(1):
|
||||
x = int(args[setting])
|
||||
elif type(args[setting]) == type(None):
|
||||
x = None
|
||||
else:
|
||||
x = args[setting]
|
||||
config[setting] = x
|
||||
return config
|
||||
|
||||
|
||||
def get_code_path():
|
||||
"""Get base path where code is"""
|
||||
return os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
|
||||
def validate_metrics_list(metrics_list):
|
||||
"""Get names of metric class and ensures they are unique, further checks that the fields within each metric class
|
||||
do not have overlapping names.
|
||||
"""
|
||||
metric_names = [metric.get_name() for metric in metrics_list]
|
||||
# check metric names are unique
|
||||
if len(metric_names) != len(set(metric_names)):
|
||||
raise TrackEvalException(
|
||||
"Code being run with multiple metrics of the same name"
|
||||
)
|
||||
fields = []
|
||||
for m in metrics_list:
|
||||
fields += m.fields
|
||||
# check metric fields are unique
|
||||
if len(fields) != len(set(fields)):
|
||||
raise TrackEvalException(
|
||||
"Code being run with multiple metrics with fields of the same name"
|
||||
)
|
||||
return metric_names
|
||||
|
||||
|
||||
def write_summary_results(summaries, cls, output_folder):
|
||||
"""Write summary results to file"""
|
||||
|
||||
fields = sum([list(s.keys()) for s in summaries], [])
|
||||
values = sum([list(s.values()) for s in summaries], [])
|
||||
|
||||
# In order to remain consistent upon new fields being adding, for each of the following fields if they are present
|
||||
# they will be output in the summary first in the order below. Any further fields will be output in the order each
|
||||
# metric family is called, and within each family either in the order they were added to the dict (python >= 3.6) or
|
||||
# randomly (python < 3.6).
|
||||
default_order = [
|
||||
"HOTA",
|
||||
"DetA",
|
||||
"AssA",
|
||||
"DetRe",
|
||||
"DetPr",
|
||||
"AssRe",
|
||||
"AssPr",
|
||||
"LocA",
|
||||
"OWTA",
|
||||
"HOTA(0)",
|
||||
"LocA(0)",
|
||||
"HOTALocA(0)",
|
||||
"MOTA",
|
||||
"MOTP",
|
||||
"MODA",
|
||||
"CLR_Re",
|
||||
"CLR_Pr",
|
||||
"MTR",
|
||||
"PTR",
|
||||
"MLR",
|
||||
"CLR_TP",
|
||||
"CLR_FN",
|
||||
"CLR_FP",
|
||||
"IDSW",
|
||||
"MT",
|
||||
"PT",
|
||||
"ML",
|
||||
"Frag",
|
||||
"sMOTA",
|
||||
"IDF1",
|
||||
"IDR",
|
||||
"IDP",
|
||||
"IDTP",
|
||||
"IDFN",
|
||||
"IDFP",
|
||||
"Dets",
|
||||
"GT_Dets",
|
||||
"IDs",
|
||||
"GT_IDs",
|
||||
]
|
||||
default_ordered_dict = OrderedDict(
|
||||
zip(default_order, [None for _ in default_order])
|
||||
)
|
||||
for f, v in zip(fields, values):
|
||||
default_ordered_dict[f] = v
|
||||
for df in default_order:
|
||||
if default_ordered_dict[df] is None:
|
||||
del default_ordered_dict[df]
|
||||
fields = list(default_ordered_dict.keys())
|
||||
values = list(default_ordered_dict.values())
|
||||
|
||||
out_file = os.path.join(output_folder, cls + "_summary.txt")
|
||||
os.makedirs(os.path.dirname(out_file), exist_ok=True)
|
||||
with open(out_file, "w", newline="") as f:
|
||||
writer = csv.writer(f, delimiter=" ")
|
||||
writer.writerow(fields)
|
||||
writer.writerow(values)
|
||||
|
||||
|
||||
def write_detailed_results(details, cls, output_folder):
|
||||
"""Write detailed results to file"""
|
||||
sequences = details[0].keys()
|
||||
fields = ["seq"] + sum([list(s["COMBINED_SEQ"].keys()) for s in details], [])
|
||||
out_file = os.path.join(output_folder, cls + "_detailed.csv")
|
||||
os.makedirs(os.path.dirname(out_file), exist_ok=True)
|
||||
with open(out_file, "w", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(fields)
|
||||
for seq in sorted(sequences):
|
||||
if seq == "COMBINED_SEQ":
|
||||
continue
|
||||
writer.writerow([seq] + sum([list(s[seq].values()) for s in details], []))
|
||||
writer.writerow(
|
||||
["COMBINED"] + sum([list(s["COMBINED_SEQ"].values()) for s in details], [])
|
||||
)
|
||||
|
||||
|
||||
def load_detail(file):
|
||||
"""Loads detailed data for a tracker."""
|
||||
data = {}
|
||||
with open(file) as f:
|
||||
for i, row_text in enumerate(f):
|
||||
row = row_text.replace("\r", "").replace("\n", "").split(",")
|
||||
if i == 0:
|
||||
keys = row[1:]
|
||||
continue
|
||||
current_values = row[1:]
|
||||
seq = row[0]
|
||||
if seq == "COMBINED":
|
||||
seq = "COMBINED_SEQ"
|
||||
if (len(current_values) == len(keys)) and seq != "":
|
||||
data[seq] = {}
|
||||
for key, value in zip(keys, current_values):
|
||||
data[seq][key] = float(value)
|
||||
return data
|
||||
|
||||
|
||||
class TrackEvalException(Exception):
|
||||
"""Custom exception for catching expected errors."""
|
||||
|
||||
...
|
||||
Reference in New Issue
Block a user