Files
sam3_local/sam3/eval/hota_eval_toolkit/trackeval/eval.py
Bowie Chen 11dec2936d apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)
Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: itamaro

Differential Revision: D90476315

fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
2026-01-11 23:16:49 -08:00

399 lines
17 KiB
Python

# flake8: noqa
# pyre-unsafe
import os
import time
import traceback
from functools import partial
from multiprocessing.pool import Pool
import numpy as np
from . import _timing, utils
from .metrics import Count
from .utils import TrackEvalException
try:
import tqdm
TQDM_IMPORTED = True
except ImportError as _:
TQDM_IMPORTED = False
class Evaluator:
    """Evaluator class for evaluating different metrics for different datasets"""

    @staticmethod
    def get_default_eval_config():
        """Returns the default config values for evaluation"""
        code_path = utils.get_code_path()
        default_config = {
            "USE_PARALLEL": False,
            "NUM_PARALLEL_CORES": 8,
            "BREAK_ON_ERROR": True,  # Raises exception and exits with error
            "RETURN_ON_ERROR": False,  # if not BREAK_ON_ERROR, then returns from function on error
            "LOG_ON_ERROR": os.path.join(
                code_path, "error_log.txt"
            ),  # if not None, save any errors into a log file.
            "PRINT_RESULTS": True,
            "PRINT_ONLY_COMBINED": False,
            "PRINT_CONFIG": True,
            "TIME_PROGRESS": True,
            "DISPLAY_LESS_PROGRESS": True,
            "OUTPUT_SUMMARY": True,
            "OUTPUT_EMPTY_CLASSES": True,  # If False, summary files are not output for classes with no detections
            "OUTPUT_DETAILED": True,
            "PLOT_CURVES": True,
        }
        return default_config

    def __init__(self, config=None):
        """Initialise the evaluator with a config file"""
        self.config = utils.init_config(config, self.get_default_eval_config(), "Eval")
        # Only run timing analysis if not run in parallel.
        if self.config["TIME_PROGRESS"] and not self.config["USE_PARALLEL"]:
            _timing.DO_TIMING = True
            if self.config["DISPLAY_LESS_PROGRESS"]:
                _timing.DISPLAY_LESS_PROGRESS = True

    def _combine_results(
        self,
        res,
        metrics_list,
        metric_names,
        dataset,
        res_field="COMBINED_SEQ",
        target_tag=None,
    ):
        """Combine per-sequence results into a single combined entry.

        Mutates ``res`` in place by adding ``res[res_field]``, which holds
        per-class results combined over sequences, plus (optionally) class- and
        super-class-averaged entries.

        Args:
            res: nested dict indexed res[seq][class][metric_name].
            metrics_list: metric objects providing the combine_* methods.
            metric_names: names parallel to metrics_list.
            dataset: dataset object providing eval info and gt_data.
            res_field: key under which combined results are stored; must start
                with "COMBINED_SEQ" so per-sequence keys can be filtered out.
            target_tag: if given, only sequences whose annotations carry this
                tag are combined (sequence-level tags only).

        Returns:
            Tuple of (res, combined_cls_keys) where combined_cls_keys lists the
            class keys that hold combined (not per-class) results.
        """
        assert res_field.startswith("COMBINED_SEQ")
        # collecting combined cls keys (cls averaged, det averaged, super classes)
        tracker_list, seq_list, class_list = dataset.get_eval_info()
        combined_cls_keys = []
        res[res_field] = {}
        # narrow the target for evaluation
        if target_tag is not None:
            target_video_ids = [
                annot["video_id"]
                for annot in dataset.gt_data["annotations"]
                if target_tag in annot["tags"]
            ]
            # map video id -> sequence name (first path component of the first frame)
            vid2name = {
                video["id"]: video["file_names"][0].split("/")[0]
                for video in dataset.gt_data["videos"]
            }
            target_video_ids = set(target_video_ids)
            target_video = [vid2name[video_id] for video_id in target_video_ids]
            if len(target_video) == 0:
                raise TrackEvalException(
                    "No sequences found with the tag %s" % target_tag
                )
            target_annotations = [
                annot
                for annot in dataset.gt_data["annotations"]
                if annot["video_id"] in target_video_ids
            ]
            # The tag must apply to every annotation of each selected sequence.
            assert all(target_tag in annot["tags"] for annot in target_annotations), (
                f"Not all annotations in the target sequences have the target tag {target_tag}. "
                "We currently only support a target tag at the sequence level, not at the annotation level."
            )
        else:
            target_video = seq_list
        # combine sequences for each class
        for c_cls in class_list:
            res[res_field][c_cls] = {}
            for metric, metric_name in zip(metrics_list, metric_names):
                curr_res = {
                    seq_key: seq_value[c_cls][metric_name]
                    for seq_key, seq_value in res.items()
                    if not seq_key.startswith("COMBINED_SEQ")
                    and seq_key in target_video
                }
                res[res_field][c_cls][metric_name] = metric.combine_sequences(curr_res)
        # combine classes
        if dataset.should_classes_combine:
            combined_cls_keys += [
                "cls_comb_cls_av",
                "cls_comb_det_av",
                "all",
            ]
            res[res_field]["cls_comb_cls_av"] = {}
            res[res_field]["cls_comb_det_av"] = {}
            for metric, metric_name in zip(metrics_list, metric_names):
                cls_res = {
                    cls_key: cls_value[metric_name]
                    for cls_key, cls_value in res[res_field].items()
                    if cls_key not in combined_cls_keys
                }
                res[res_field]["cls_comb_cls_av"][metric_name] = (
                    metric.combine_classes_class_averaged(cls_res)
                )
                res[res_field]["cls_comb_det_av"][metric_name] = (
                    metric.combine_classes_det_averaged(cls_res)
                )
        # combine classes to super classes
        if dataset.use_super_categories:
            for cat, sub_cats in dataset.super_categories.items():
                combined_cls_keys.append(cat)
                res[res_field][cat] = {}
                for metric, metric_name in zip(metrics_list, metric_names):
                    cat_res = {
                        cls_key: cls_value[metric_name]
                        for cls_key, cls_value in res[res_field].items()
                        if cls_key in sub_cats
                    }
                    res[res_field][cat][metric_name] = (
                        metric.combine_classes_det_averaged(cat_res)
                    )
        return res, combined_cls_keys

    def _summarize_results(
        self,
        res,
        tracker,
        metrics_list,
        metric_names,
        dataset,
        res_field,
        combined_cls_keys,
    ):
        """Print and write result tables/summaries for one tracker.

        Depending on the config, this prints per-class tables, writes summary
        and detailed result files to the dataset's output folder, and plots
        metric curves. ``combined_cls_keys`` identifies class keys that hold
        combined results (which have no per-sequence breakdown).
        """
        config = self.config
        output_fol = dataset.get_output_fol(tracker)
        tracker_display_name = dataset.get_display_name(tracker)
        for c_cls in res[
            res_field
        ].keys():  # class_list + combined classes if calculated
            summaries = []
            details = []
            num_dets = res[res_field][c_cls]["Count"]["Dets"]
            if config["OUTPUT_EMPTY_CLASSES"] or num_dets > 0:
                for metric, metric_name in zip(metrics_list, metric_names):
                    # for combined classes there is no per sequence evaluation
                    if c_cls in combined_cls_keys:
                        table_res = {res_field: res[res_field][c_cls][metric_name]}
                    else:
                        table_res = {
                            seq_key: seq_value[c_cls][metric_name]
                            for seq_key, seq_value in res.items()
                        }
                    if config["PRINT_RESULTS"] and config["PRINT_ONLY_COMBINED"]:
                        dont_print = (
                            dataset.should_classes_combine
                            and c_cls not in combined_cls_keys
                        )
                        if not dont_print:
                            metric.print_table(
                                {res_field: table_res[res_field]},
                                tracker_display_name,
                                c_cls,
                                res_field,
                                res_field,
                            )
                    elif config["PRINT_RESULTS"]:
                        metric.print_table(
                            table_res, tracker_display_name, c_cls, res_field, res_field
                        )
                    if config["OUTPUT_SUMMARY"]:
                        summaries.append(metric.summary_results(table_res))
                    if config["OUTPUT_DETAILED"]:
                        details.append(metric.detailed_results(table_res))
                    if config["PLOT_CURVES"]:
                        metric.plot_single_tracker_results(
                            table_res,
                            tracker_display_name,
                            c_cls,
                            output_fol,
                        )
                if config["OUTPUT_SUMMARY"]:
                    utils.write_summary_results(summaries, c_cls, output_fol)
                if config["OUTPUT_DETAILED"]:
                    utils.write_detailed_results(details, c_cls, output_fol)

    @_timing.time
    def evaluate(self, dataset_list, metrics_list, show_progressbar=False):
        """Evaluate a set of metrics on a set of datasets"""
        config = self.config
        metrics_list = metrics_list + [Count()]  # Count metrics are always run
        metric_names = utils.validate_metrics_list(metrics_list)
        dataset_names = [dataset.get_name() for dataset in dataset_list]
        output_res = {}
        output_msg = {}
        for dataset, dataset_name in zip(dataset_list, dataset_names):
            # Get dataset info about what to evaluate
            output_res[dataset_name] = {}
            output_msg[dataset_name] = {}
            tracker_list, seq_list, class_list = dataset.get_eval_info()
            print(
                "\nEvaluating %i tracker(s) on %i sequence(s) for %i class(es) on %s dataset using the following "
                "metrics: %s\n"
                % (
                    len(tracker_list),
                    len(seq_list),
                    len(class_list),
                    dataset_name,
                    ", ".join(metric_names),
                )
            )
            # Evaluate each tracker
            for tracker in tracker_list:
                # if not config['BREAK_ON_ERROR'] then go to next tracker without breaking
                try:
                    # Evaluate each sequence in parallel or in series.
                    # returns a nested dict (res), indexed like: res[seq][class][metric_name][sub_metric field]
                    # e.g. res[seq_0001][pedestrian][hota][DetA]
                    print("\nEvaluating %s\n" % tracker)
                    time_start = time.time()
                    if config["USE_PARALLEL"]:
                        if show_progressbar and TQDM_IMPORTED:
                            seq_list_sorted = sorted(seq_list)
                            with (
                                Pool(config["NUM_PARALLEL_CORES"]) as pool,
                                tqdm.tqdm(total=len(seq_list)) as pbar,
                            ):
                                _eval_sequence = partial(
                                    eval_sequence,
                                    dataset=dataset,
                                    tracker=tracker,
                                    class_list=class_list,
                                    metrics_list=metrics_list,
                                    metric_names=metric_names,
                                )
                                results = []
                                # imap (not map) so the progress bar can tick
                                # as each sequence finishes.
                                for r in pool.imap(
                                    _eval_sequence, seq_list_sorted, chunksize=20
                                ):
                                    results.append(r)
                                    pbar.update()
                                res = dict(zip(seq_list_sorted, results))
                        else:
                            with Pool(config["NUM_PARALLEL_CORES"]) as pool:
                                _eval_sequence = partial(
                                    eval_sequence,
                                    dataset=dataset,
                                    tracker=tracker,
                                    class_list=class_list,
                                    metrics_list=metrics_list,
                                    metric_names=metric_names,
                                )
                                results = pool.map(_eval_sequence, seq_list)
                                res = dict(zip(seq_list, results))
                    else:
                        res = {}
                        if show_progressbar and TQDM_IMPORTED:
                            seq_list_sorted = sorted(seq_list)
                            for curr_seq in tqdm.tqdm(seq_list_sorted):
                                res[curr_seq] = eval_sequence(
                                    curr_seq,
                                    dataset,
                                    tracker,
                                    class_list,
                                    metrics_list,
                                    metric_names,
                                )
                        else:
                            for curr_seq in sorted(seq_list):
                                res[curr_seq] = eval_sequence(
                                    curr_seq,
                                    dataset,
                                    tracker,
                                    class_list,
                                    metrics_list,
                                    metric_names,
                                )
                    # Combine results over all sequences and then over all classes
                    res, combined_cls_keys = self._combine_results(
                        res, metrics_list, metric_names, dataset, "COMBINED_SEQ"
                    )
                    # Built-in all() short-circuits and avoids building an
                    # intermediate list (previously np.all on a list comp).
                    if all(
                        "tags" in annot for annot in dataset.gt_data["annotations"]
                    ):
                        # Combine results over the challenging sequences and then over all classes
                        # currently only support "tracking_challenging_pair"
                        res, _ = self._combine_results(
                            res,
                            metrics_list,
                            metric_names,
                            dataset,
                            "COMBINED_SEQ_CHALLENGING",
                            "tracking_challenging_pair",
                        )
                    # Print and output results in various formats
                    if config["TIME_PROGRESS"]:
                        print(
                            "\nAll sequences for %s finished in %.2f seconds"
                            % (tracker, time.time() - time_start)
                        )
                    self._summarize_results(
                        res,
                        tracker,
                        metrics_list,
                        metric_names,
                        dataset,
                        "COMBINED_SEQ",
                        combined_cls_keys,
                    )
                    if "COMBINED_SEQ_CHALLENGING" in res:
                        self._summarize_results(
                            res,
                            tracker,
                            metrics_list,
                            metric_names,
                            dataset,
                            "COMBINED_SEQ_CHALLENGING",
                            combined_cls_keys,
                        )
                    # Output for returning from function
                    output_res[dataset_name][tracker] = res
                    output_msg[dataset_name][tracker] = "Success"
                except Exception as err:
                    output_res[dataset_name][tracker] = None
                    # isinstance (not type equality) so TrackEvalException
                    # subclasses also report their message.
                    if isinstance(err, TrackEvalException):
                        output_msg[dataset_name][tracker] = str(err)
                    else:
                        output_msg[dataset_name][tracker] = "Unknown error occurred."
                    print("Tracker %s was unable to be evaluated." % tracker)
                    print(err)
                    traceback.print_exc()
                    if config["LOG_ON_ERROR"] is not None:
                        with open(config["LOG_ON_ERROR"], "a") as f:
                            print(dataset_name, file=f)
                            print(tracker, file=f)
                            print(traceback.format_exc(), file=f)
                            print("\n\n\n", file=f)
                    if config["BREAK_ON_ERROR"]:
                        raise err
                    elif config["RETURN_ON_ERROR"]:
                        return output_res, output_msg
        return output_res, output_msg
@_timing.time
def eval_sequence(seq, dataset, tracker, class_list, metrics_list, metric_names):
    """Function for evaluating a single sequence"""
    raw_data = dataset.get_raw_seq_data(tracker, seq)
    results = {}
    for cls_name in class_list:
        cls_data = dataset.get_preprocessed_seq_data(raw_data, cls_name)
        results[cls_name] = {
            name: metric.eval_sequence(cls_data)
            for metric, name in zip(metrics_list, metric_names)
        }
    return results