Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
5
sam3/eval/teta_eval_toolkit/__init__.py
Normal file
5
sam3/eval/teta_eval_toolkit/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
|
||||
from . import config, datasets, metrics, utils
|
||||
from .eval import Evaluator
|
||||
69
sam3/eval/teta_eval_toolkit/_timing.py
Normal file
69
sam3/eval/teta_eval_toolkit/_timing.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from time import perf_counter
|
||||
|
||||
# Master switch: the wrapper produced by ``time`` only measures when this is true.
DO_TIMING = False
# When true, timed methods (first parameter named ``self``) are neither recorded nor printed.
DISPLAY_LESS_PROGRESS = False
# Accumulated run time in seconds, keyed by function name or "ClassName.method".
timer_dict = {}
# Running index printed in front of every timed plain-function line.
counter = 0


def time(f):
    """Decorate ``f`` so its wall-clock run time is measured and reported.

    The returned wrapper checks the module-level ``DO_TIMING`` flag on every
    call. When the flag is false the call is forwarded untouched. When it is
    true the call is timed with ``perf_counter``, the elapsed time is added to
    ``timer_dict`` and a progress line is printed. Once a method named
    ``Evaluator.evaluate`` returns, the accumulated ``timer_dict`` is printed
    as a summary instead of a single progress line.

    :param f: the function or method to wrap
    :return: a wrapper with the same call signature and return value as ``f``
    """

    @wraps(f)
    def wrap(*args, **kw):
        global counter

        if not DO_TIMING:
            # Timing is switched off: run the function normally without timing.
            return f(*args, **kw)

        # Run function with timing
        ts = perf_counter()
        result = f(*args, **kw)
        tt = perf_counter() - ts

        # Get function name. A callable without named positional parameters
        # (e.g. ``def f()`` or ``def f(*args)``) has an empty name list, so the
        # first name must not be indexed unconditionally.
        arg_names = inspect.getfullargspec(f)[0]
        first_arg = arg_names[0] if arg_names else None
        is_method = first_arg == "self"

        if is_method and DISPLAY_LESS_PROGRESS:
            return result
        if is_method:
            method_name = type(args[0]).__name__ + "." + f.__name__
        else:
            method_name = f.__name__

        # Record accumulative time in each function for analysis
        timer_dict[method_name] = timer_dict.get(method_name, 0) + tt

        # If code is finished, display timing summary
        if method_name == "Evaluator.evaluate":
            print("")
            print("Timing analysis:")
            for key, value in timer_dict.items():
                print("%-70s %2.4f sec" % (key, value))
            return result

        # Get function argument values for printing special arguments of interest.
        # Values are looked up positionally first, then by keyword, and are
        # converted with str() so non-string values do not break the join.
        arg_titles = ("tracker", "seq", "cls")
        arg_vals = []
        for i, a in enumerate(arg_names):
            if a not in arg_titles:
                continue
            if i < len(args):
                arg_vals.append(str(args[i]))
            elif a in kw:
                arg_vals.append(str(kw[a]))
        arg_text = "(" + ", ".join(arg_vals) + ")"

        # Display methods and functions with different indentation.
        if is_method:
            print("%-74s %2.4f sec" % (" " * 4 + method_name + arg_text, tt))
        elif first_arg == "test":
            # Functions whose first parameter is named "test" are recorded but not printed.
            pass
        else:
            counter += 1
            print("%i %-70s %2.4f sec" % (counter, method_name + arg_text, tt))

        return result

    return wrap
|
||||
153
sam3/eval/teta_eval_toolkit/config.py
Normal file
153
sam3/eval/teta_eval_toolkit/config.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
|
||||
"""Config."""
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
def parse_configs():
    """Build the eval, dataset and metrics configs from defaults plus CLI overrides.

    Every default setting is exposed as a ``--SETTING`` command-line option.
    Options whose default is a list or ``None`` accept one or more values; all
    others accept a single value. Values given on the command line replace the
    default, converted to ``bool`` or ``int`` when the default has that type.

    :return: tuple ``(eval_config, dataset_config, metrics_config)``
    :raises Exception: if a boolean setting is given anything but "True"/"False"
    """
    eval_defaults = get_default_eval_config()
    eval_defaults["DISPLAY_LESS_PROGRESS"] = True
    dataset_defaults = get_default_dataset_config()
    metrics_defaults = {"METRICS": ["TETA"]}

    # Later groups win on duplicate keys, exactly like a {**a, **b, **c} merge.
    config = {}
    config.update(eval_defaults)
    config.update(dataset_defaults)
    config.update(metrics_defaults)

    parser = argparse.ArgumentParser()
    for name, default in config.items():
        if isinstance(default, list) or default is None:
            parser.add_argument("--" + name, nargs="+")
        else:
            parser.add_argument("--" + name)

    cli_values = vars(parser.parse_args())
    for name, raw in cli_values.items():
        if raw is None:
            # Option not given on the command line: keep the default.
            continue
        default = config[name]
        if isinstance(default, bool):
            # bool is tested before int because bool is a subclass of int.
            if raw == "True":
                parsed = True
            elif raw == "False":
                parsed = False
            else:
                raise Exception(
                    f"Command line parameter {name} must be True/False"
                )
        elif isinstance(default, int):
            parsed = int(raw)
        else:
            parsed = raw
        config[name] = parsed

    # Split the merged settings back into the three groups they came from.
    eval_config = {k: v for k, v in config.items() if k in eval_defaults}
    dataset_config = {k: v for k, v in config.items() if k in dataset_defaults}
    metrics_config = {k: v for k, v in config.items() if k in metrics_defaults}

    return eval_config, dataset_config, metrics_config
|
||||
|
||||
|
||||
def get_default_eval_config():
    """Return a new dict holding the default evaluation settings.

    The error log defaults to ``error_log.txt`` inside the directory returned
    by ``get_code_path()``.
    """
    error_log_path = os.path.join(get_code_path(), "error_log.txt")

    parallel_settings = {
        "USE_PARALLEL": True,
        "NUM_PARALLEL_CORES": 8,
    }
    error_settings = {
        "BREAK_ON_ERROR": True,
        "RETURN_ON_ERROR": False,
        "LOG_ON_ERROR": error_log_path,
    }
    console_settings = {
        "PRINT_RESULTS": True,
        "PRINT_ONLY_COMBINED": True,
        "PRINT_CONFIG": True,
        "TIME_PROGRESS": True,
        "DISPLAY_LESS_PROGRESS": True,
    }
    output_settings = {
        "OUTPUT_SUMMARY": True,
        "OUTPUT_EMPTY_CLASSES": True,
        "OUTPUT_TEM_RAW_DATA": True,
        "OUTPUT_PER_SEQ_RES": True,
    }
    # Merged in this order so the key order of the result is unchanged.
    return {
        **parallel_settings,
        **error_settings,
        **console_settings,
        **output_settings,
    }
|
||||
|
||||
|
||||
def get_default_dataset_config():
    """Return a new dict holding the default dataset settings.

    The ground-truth and tracker folders default to sub-directories of the
    directory returned by ``get_code_path()``.
    """
    base_dir = get_code_path()
    gt_folder = os.path.join(base_dir, "data/gt/tao/tao_training")
    trackers_folder = os.path.join(base_dir, "data/trackers/tao/tao_training")

    return {
        # Location of GT data
        "GT_FOLDER": gt_folder,
        # Trackers location
        "TRACKERS_FOLDER": trackers_folder,
        # Where to save eval results (if None, same as TRACKERS_FOLDER)
        "OUTPUT_FOLDER": None,
        # Filenames of trackers to eval (if None, all in folder)
        "TRACKERS_TO_EVAL": ["TETer"],
        # Classes to eval (if None, all classes)
        "CLASSES_TO_EVAL": None,
        # Valid: 'training', 'val'
        "SPLIT_TO_EVAL": "training",
        # Whether to print current config
        "PRINT_CONFIG": True,
        # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER
        "TRACKER_SUB_FOLDER": "data",
        # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER
        "OUTPUT_SUB_FOLDER": "",
        # Names of trackers to display, if None: TRACKERS_TO_EVAL
        "TRACKER_DISPLAY_NAMES": None,
        # Number of maximal allowed detections per image (0 for unlimited)
        "MAX_DETECTIONS": 0,
        # Whether to use mask data for evaluation
        "USE_MASK": False,
    }
|
||||
|
||||
|
||||
def init_config(config, default_config, name=None):
    """Fill settings missing from ``config`` with values from ``default_config``.

    :param config: user settings; ``None`` means "use the defaults as they are".
        A given dict is completed in place.
    :param default_config: dict with the default value for each setting
    :param name: optional title; when given and ``PRINT_CONFIG`` is truthy in
        the resulting config, the full config is printed under that title
    :return: ``default_config`` itself when ``config`` is None, else ``config``
    """
    if config is None:
        config = default_config
    else:
        for key, fallback in default_config.items():
            # Only settings the caller did not provide receive the default.
            config.setdefault(key, fallback)

    if name and config["PRINT_CONFIG"]:
        print("\n%s Config:" % name)
        for key, value in config.items():
            print("%-20s : %-30s" % (key, value))
    return config
|
||||
|
||||
|
||||
def update_config(config):
    """Override values in ``config`` with matching command-line arguments.

    Every key of ``config`` is exposed as a ``--KEY`` option. Options whose
    current value is a list or ``None`` accept one or more values; all others
    accept a single value. A given option replaces the current value, converted
    to ``bool`` or ``int`` when the current value has that type; anything else
    is stored as parsed. ``config`` is updated in place.

    :param config: the config to update
    :return: the updated config (the same dict object)
    :raises Exception: if a boolean setting is given anything but "True"/"False"
    """
    parser = argparse.ArgumentParser()
    for setting, current in config.items():
        if isinstance(current, list) or current is None:
            parser.add_argument("--" + setting, nargs="+")
        else:
            parser.add_argument("--" + setting)

    args = vars(parser.parse_args())
    for setting, raw in args.items():
        if raw is None:
            # Option not given on the command line: keep the current value.
            continue
        current = config[setting]
        if isinstance(current, bool):
            # bool is tested before int because bool is a subclass of int.
            if raw == "True":
                x = True
            elif raw == "False":
                x = False
            else:
                # The space before "must" was missing, which glued the setting
                # name onto the rest of the message.
                raise Exception(
                    "Command line parameter " + setting + " must be True or False"
                )
        elif isinstance(current, int):
            x = int(raw)
        else:
            x = raw
        config[setting] = x
    return config
|
||||
|
||||
|
||||
def get_code_path():
    """Return the absolute path of the directory one level above this module's directory."""
    module_dir = os.path.dirname(__file__)
    parent_dir = os.path.join(module_dir, "..")
    return os.path.abspath(parent_dir)
|
||||
5
sam3/eval/teta_eval_toolkit/datasets/__init__.py
Normal file
5
sam3/eval/teta_eval_toolkit/datasets/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
"""Datasets."""
|
||||
from .coco import COCO
|
||||
from .tao import TAO
|
||||
379
sam3/eval/teta_eval_toolkit/datasets/_base_dataset.py
Normal file
379
sam3/eval/teta_eval_toolkit/datasets/_base_dataset.py
Normal file
@@ -0,0 +1,379 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
|
||||
import csv
|
||||
import io
|
||||
import os
|
||||
import traceback
|
||||
import zipfile
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import _timing
|
||||
from ..utils import TrackEvalException
|
||||
|
||||
|
||||
class _BaseDataset(ABC):
    """Abstract base class for the datasets of the evaluation toolkit.

    Subclasses implement raw-file loading, per-class preprocessing and the
    similarity computation. This class supplies the shared plumbing: naming and
    output-folder helpers, merging of raw ground-truth and tracker data, a
    generic text-file loader, and IoU / distance similarity functions.
    """

    @abstractmethod
    def __init__(self):
        # Populated by subclasses; returned together by get_eval_info().
        self.tracker_list = None
        self.seq_list = None
        self.class_list = None
        # Joined with the tracker name in get_output_fol().
        self.output_fol = None
        self.output_sub_fol = None
        self.should_classes_combine = True
        self.use_super_categories = False

    # Functions to implement:

    @abstractmethod
    def _load_raw_file(self, tracker, seq, is_gt):
        ...

    @_timing.time
    @abstractmethod
    def get_preprocessed_seq_data(self, raw_data, cls):
        ...

    @abstractmethod
    def _calculate_similarities(self, gt_dets_t, tracker_dets_t):
        ...

    # Helper functions for all datasets:

    @classmethod
    def get_class_name(cls):
        """Return the name of the (sub)class."""
        return cls.__name__

    def get_name(self):
        """Return the dataset name, which is the class name."""
        return self.get_class_name()

    def get_output_fol(self, tracker):
        """Return ``output_fol/tracker/output_sub_fol`` for the given tracker."""
        return os.path.join(self.output_fol, tracker, self.output_sub_fol)

    def get_display_name(self, tracker):
        """Can be overwritten if the trackers name (in files) is different to how it should be displayed.
        By default this method just returns the trackers name as is.
        """
        return tracker

    def get_eval_info(self):
        """Return info about the dataset needed for the Evaluator"""
        return self.tracker_list, self.seq_list, self.class_list

    @_timing.time
    def get_raw_seq_data(self, tracker, seq):
        """Loads raw data (tracker and ground-truth) for a single tracker on a single sequence.

        The ground-truth and tracker dicts returned by ``_load_raw_file`` are
        merged into one dict (ground-truth entries win on a key clash). For
        every timestep the similarity between ``raw_data["gt_dets"]`` and
        ``raw_data["tk_dets"]`` is computed with ``_calculate_similarities``
        and the resulting list is stored under ``"similarity_scores"``.

        Note that similarities are extracted as part of the dataset and not the metric, because almost all metrics are
        independent of the exact method of calculating the similarity. However datasets are not (e.g. segmentation
        masks vs 2D boxes vs 3D boxes).
        We calculate the similarity before preprocessing because often both preprocessing and evaluation require it and
        we don't wish to calculate this twice.
        We calculate similarity between all gt and tracker classes (not just each class individually) to allow for
        calculation of metrics such as class confusion matrices. Typically the impact of this on performance is low.

        :param tracker: tracker name, forwarded to ``_load_raw_file``
        :param seq: sequence name, forwarded to ``_load_raw_file``
        :return: the merged raw data dict including ``"similarity_scores"``
        """
        # Load raw data.
        raw_gt_data = self._load_raw_file(tracker, seq, is_gt=True)
        raw_tracker_data = self._load_raw_file(tracker, seq, is_gt=False)
        raw_data = {**raw_tracker_data, **raw_gt_data}  # Merges dictionaries

        # Calculate similarities for each timestep.
        raw_data["similarity_scores"] = [
            self._calculate_similarities(gt_dets_t, tracker_dets_t)
            for gt_dets_t, tracker_dets_t in zip(
                raw_data["gt_dets"], raw_data["tk_dets"]
            )
        ]
        return raw_data

    @staticmethod
    def _load_simple_text_file(
        file,
        time_col=0,
        id_col=None,
        remove_negative_ids=False,
        valid_filter=None,
        crowd_ignore_filter=None,
        convert_filter=None,
        is_zipped=False,
        zip_file=None,
        force_delimiters=None,
    ):
        """Function that loads data which is in a commonly used text file format.
        Assumes each det is given by one row of a text file.
        There is no limit to the number or meaning of each column,
        however one column needs to give the timestep of each det (time_col) which is default col 0.

        The file dialect (delimiter, num cols, etc) is determined automatically.
        This function automatically separates dets by timestep.

        If remove_negative_ids is True and id_col is not None, dets with negative values in id_col are excluded.
        These are not excluded from ignore data.

        valid_filter can be used to only include certain classes.
        It is a dict with ints as keys, and lists as values,
        such that a row is included if "row[key].lower() is in value" for all key/value pairs in the dict.
        If None, all classes are included.

        crowd_ignore_filter can be used to read crowd_ignore regions separately. It has the same format as valid filter.

        convert_filter can be used to convert value read to another format.
        This is used most commonly to convert classes given as string to a class id.
        This is a dict such that the key is the column to convert, and the value is another dict giving the mapping.

        Optionally, input files could be a zip of multiple text files for storage efficiency.

        Returns read_data and ignore_data.
        Each is a dict (with keys as timesteps as strings) of lists (over dets) of lists (over column values).
        Note that all data is returned as strings, and must be converted to float/int later if needed.
        Note that timesteps will not be present in the returned dict keys if there are no dets for them

        :raises TrackEvalException: on inconsistent arguments, or when the file
            is missing or cannot be parsed
        """
        if remove_negative_ids and id_col is None:
            raise TrackEvalException(
                "remove_negative_ids is True, but id_col is not given."
            )
        if crowd_ignore_filter is None:
            crowd_ignore_filter = {}
        if convert_filter is None:
            convert_filter = {}

        # Tracked separately so that both are released in ``finally`` even when
        # opening or parsing fails part-way through.
        archive = None
        fp = None
        try:
            if is_zipped:  # Either open file directly or within a zip.
                if zip_file is None:
                    raise TrackEvalException(
                        "is_zipped set to True, but no zip_file is given."
                    )
                archive = zipfile.ZipFile(os.path.join(zip_file), "r")
                fp = io.TextIOWrapper(archive.open(file, "r"))
            else:
                fp = open(file)
            read_data = {}
            crowd_ignore_data = {}
            fp.seek(0, os.SEEK_END)
            # check if file is empty
            if fp.tell():
                fp.seek(0)
                # Auto determine structure.
                dialect = csv.Sniffer().sniff(
                    fp.readline(), delimiters=force_delimiters
                )
                # Deal with extra spaces between columns
                dialect.skipinitialspace = True
                fp.seek(0)
                reader = csv.reader(fp, dialect)
                for row in reader:
                    try:
                        # Deal with extra trailing spaces at the end of rows
                        if row[-1] == "":
                            row = row[:-1]
                        timestep = str(int(float(row[time_col])))

                        # Read ignore regions separately.
                        is_ignored = False
                        for ignore_key, ignore_value in crowd_ignore_filter.items():
                            if row[ignore_key].lower() in ignore_value:
                                # Convert values in one column (e.g. string to id)
                                for convert_key, convert_value in convert_filter.items():
                                    row[convert_key] = convert_value[
                                        row[convert_key].lower()
                                    ]
                                # Save data separated by timestep.
                                crowd_ignore_data.setdefault(timestep, []).append(row)
                                is_ignored = True
                        # if det is an ignore region, it cannot be a normal det.
                        if is_ignored:
                            continue

                        # Exclude some dets if not valid. The check has to skip
                        # the ROW; a ``continue`` inside a loop over the filter
                        # items would only advance that inner loop.
                        if valid_filter is not None and any(
                            row[key].lower() not in value
                            for key, value in valid_filter.items()
                        ):
                            continue
                        if remove_negative_ids and int(float(row[id_col])) < 0:
                            continue

                        # Convert values in one column (e.g. string to id)
                        for convert_key, convert_value in convert_filter.items():
                            row[convert_key] = convert_value[row[convert_key].lower()]
                        # Save data separated by timestep.
                        read_data.setdefault(timestep, []).append(row)
                    except Exception as err:
                        exc_str_init = (
                            "In file %s the following line cannot be read correctly: \n"
                            % os.path.basename(file)
                        )
                        exc_str = " ".join([exc_str_init] + row)
                        raise TrackEvalException(exc_str) from err
        except Exception as err:
            print("Error loading file: %s, printing traceback." % file)
            traceback.print_exc()
            raise TrackEvalException(
                "File %s cannot be read because it is either not present or invalidly formatted"
                % os.path.basename(file)
            ) from err
        finally:
            if fp is not None:
                fp.close()
            if archive is not None:
                archive.close()
        return read_data, crowd_ignore_data

    @staticmethod
    def _calculate_mask_ious(masks1, masks2, is_encoded=False, do_ioa=False):
        """Calculates the IOU (intersection over union) between two arrays of segmentation masks.
        If is_encoded a run length encoding with pycocotools is assumed as input format, otherwise an input of numpy
        arrays of the shape (num_masks, height, width) is assumed and the encoding is performed.
        If do_ioa (intersection over area) , then calculates the intersection over the area of masks1 - this is commonly
        used to determine if detections are within crowd ignore region.
        :param masks1:  first set of masks (numpy array of shape (num_masks, height, width) if not encoded,
                        else pycocotools rle encoded format)
        :param masks2:  second set of masks (numpy array of shape (num_masks, height, width) if not encoded,
                        else pycocotools rle encoded format)
        :param is_encoded: whether the input is in pycocotools rle encoded format
        :param do_ioa: whether to perform IoA computation
        :return: the IoU/IoA scores
        """

        # Only loaded when run to reduce minimum requirements
        from pycocotools import mask as mask_utils

        # use pycocotools for run length encoding of masks
        if not is_encoded:
            masks1 = mask_utils.encode(
                np.array(np.transpose(masks1, (1, 2, 0)), order="F")
            )
            masks2 = mask_utils.encode(
                np.array(np.transpose(masks2, (1, 2, 0)), order="F")
            )

        # use pycocotools for iou computation of rle encoded masks
        ious = mask_utils.iou(masks1, masks2, [do_ioa] * len(masks2))
        if len(masks1) == 0 or len(masks2) == 0:
            # Force a (len(masks1), len(masks2)) array when either side is empty.
            ious = np.asarray(ious).reshape(len(masks1), len(masks2))
        assert (ious >= 0 - np.finfo("float").eps).all()
        assert (ious <= 1 + np.finfo("float").eps).all()

        return ious

    @staticmethod
    def _calculate_box_ious(bboxes1, bboxes2, box_format="xywh", do_ioa=False):
        """Calculates the IOU (intersection over union) between two arrays of boxes.
        Allows variable box formats ('xywh' and 'x0y0x1y1').
        If do_ioa (intersection over area) , then calculates the intersection over the area of boxes1 - this is commonly
        used to determine if detections are within crowd ignore region.

        :param bboxes1: 2D array with one box per row
        :param bboxes2: 2D array with one box per row
        :param box_format: exactly 'xywh' or 'x0y0x1y1'
        :param do_ioa: whether to compute IoA (relative to bboxes1) instead of IoU
        :return: array of shape (len(bboxes1), len(bboxes2))
        :raises TrackEvalException: if box_format is not one of the two formats
        """
        # The format is compared for equality: a substring test (``in``) would
        # accept strings such as "" or "xy" as valid formats.
        if box_format == "xywh":
            # layout: (x0, y0, w, h)
            bboxes1 = deepcopy(bboxes1)
            bboxes2 = deepcopy(bboxes2)

            bboxes1[:, 2] = bboxes1[:, 0] + bboxes1[:, 2]
            bboxes1[:, 3] = bboxes1[:, 1] + bboxes1[:, 3]
            bboxes2[:, 2] = bboxes2[:, 0] + bboxes2[:, 2]
            bboxes2[:, 3] = bboxes2[:, 1] + bboxes2[:, 3]
        elif box_format != "x0y0x1y1":
            raise TrackEvalException("box_format %s is not implemented" % box_format)

        eps = np.finfo("float").eps

        # layout: (x0, y0, x1, y1)
        min_ = np.minimum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :])
        max_ = np.maximum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :])
        intersection = np.maximum(min_[..., 2] - max_[..., 0], 0) * np.maximum(
            min_[..., 3] - max_[..., 1], 0
        )
        area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
            bboxes1[..., 3] - bboxes1[..., 1]
        )

        if do_ioa:
            # Rows whose box has (near-)zero area keep an IoA of 0.
            ioas = np.zeros_like(intersection)
            valid_mask = area1 > 0 + eps
            ioas[valid_mask, :] = (
                intersection[valid_mask, :] / area1[valid_mask][:, np.newaxis]
            )
            return ioas

        area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
            bboxes2[..., 3] - bboxes2[..., 1]
        )
        union = area1[:, np.newaxis] + area2[np.newaxis, :] - intersection
        # Degenerate boxes and unions yield an IoU of 0 rather than a division by zero.
        intersection[area1 <= 0 + eps, :] = 0
        intersection[:, area2 <= 0 + eps] = 0
        intersection[union <= 0 + eps] = 0
        union[union <= 0 + eps] = 1
        ious = intersection / union
        return ious

    @staticmethod
    def _calculate_euclidean_similarity(dets1, dets2, zero_distance=2.0):
        """Calculates the euclidean distance between two sets of detections, and then converts this into a similarity
        measure with values between 0 and 1 using the following formula: sim = max(0, 1 - dist/zero_distance).
        The default zero_distance of 2.0, corresponds to the default used in MOT15_3D, such that a 0.5 similarity
        threshold corresponds to a 1m distance threshold for TPs.
        """
        dist = np.linalg.norm(dets1[:, np.newaxis] - dets2[np.newaxis, :], axis=2)
        sim = np.maximum(0, 1 - dist / zero_distance)
        return sim

    @staticmethod
    def _check_unique_ids(data, after_preproc=False):
        """Check the requirement that the tracker_ids and gt_ids are unique per timestep

        :param data: dict providing per-timestep id arrays under "gt_ids" and
            "tk_ids", plus the sequence name under "seq" (used in the message)
        :param after_preproc: if True, the error message carries a note that the
            duplicate appeared after preprocessing
        :raises TrackEvalException: if an id occurs more than once in a timestep
        """
        preproc_note = (
            "\n Note that this error occurred after preprocessing (but not before), "
            "so ids may not be as in file, and something seems wrong with preproc."
        )
        timesteps = zip(data["gt_ids"], data["tk_ids"])
        for t, (gt_ids_t, tracker_ids_t) in enumerate(timesteps):
            # Tracker ids are checked before ground-truth ids in each timestep.
            checks = (
                (
                    "Tracker predicts the same ID more than once in a single timestep ",
                    tracker_ids_t,
                ),
                (
                    "Ground-truth has the same ID more than once in a single timestep ",
                    gt_ids_t,
                ),
            )
            for prefix, ids_t in checks:
                if len(ids_t) == 0:
                    continue
                unique_ids, counts = np.unique(ids_t, return_counts=True)
                if np.max(counts) == 1:
                    continue
                duplicate_ids = unique_ids[counts > 1]
                exc_str_init = prefix + "(seq: %s, frame: %i, ids:" % (
                    data["seq"],
                    t + 1,
                )
                exc_str = (
                    " ".join([exc_str_init] + [str(d) for d in duplicate_ids]) + ")"
                )
                if after_preproc:
                    # Appended to the final message; adding it to the prefix
                    # after the message is assembled would drop the note.
                    exc_str += preproc_note
                raise TrackEvalException(exc_str)
|
||||
637
sam3/eval/teta_eval_toolkit/datasets/coco.py
Normal file
637
sam3/eval/teta_eval_toolkit/datasets/coco.py
Normal file
@@ -0,0 +1,637 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
|
||||
"""COCO Dataset."""
|
||||
import copy
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
from .. import _timing, utils
|
||||
from ..config import get_default_dataset_config, init_config
|
||||
from ..utils import TrackEvalException
|
||||
from ._base_dataset import _BaseDataset
|
||||
|
||||
|
||||
class COCO(_BaseDataset):
|
||||
"""Tracking datasets in COCO format."""
|
||||
|
||||
def __init__(self, config=None):
    """Initialize dataset, checking that all required files are present.

    Loads the ground-truth JSON (either ``GT_FOLDER`` itself when it ends in
    ".json", or the single JSON file inside that folder), derives the
    sequence, class and tracker lists from it and the config, and loads one
    JSON result file per tracker.

    :param config: dataset config dict; missing keys are filled from
        ``get_default_dataset_config()``
    :raises TrackEvalException: if a folder does not contain exactly one JSON
        file, a requested class is not present in the ground truth, or the
        display names do not match the tracker list
    """
    super().__init__()
    # Fill non-given config values with defaults
    self.config = init_config(config, get_default_dataset_config(), self.get_name())
    self.gt_fol = self.config["GT_FOLDER"]
    self.tracker_fol = self.config["TRACKERS_FOLDER"]
    self.should_classes_combine = True
    self.use_super_categories = False
    self.use_mask = self.config["USE_MASK"]

    self.tracker_sub_fol = self.config["TRACKER_SUB_FOLDER"]
    self.output_fol = self.config["OUTPUT_FOLDER"]
    if self.output_fol is None:
        self.output_fol = self.tracker_fol
    self.output_sub_fol = self.config["OUTPUT_SUB_FOLDER"]

    # GT_FOLDER may point directly at a json file or at a folder holding one.
    if self.gt_fol.endswith(".json"):
        gt_path = self.gt_fol
    else:
        gt_dir_files = [
            file for file in os.listdir(self.gt_fol) if file.endswith(".json")
        ]
        if len(gt_dir_files) != 1:
            raise TrackEvalException(
                f"{self.gt_fol} does not contain exactly one json file."
            )
        gt_path = os.path.join(self.gt_fol, gt_dir_files[0])
    # ``with`` closes the handle in both cases.
    with open(gt_path) as f:
        self.gt_data = json.load(f)

    # fill missing video ids
    self._fill_video_ids_inplace(self.gt_data["annotations"])

    # get sequences to eval and sequence information
    self.seq_list = [
        vid["name"].replace("/", "-") for vid in self.gt_data["videos"]
    ]
    self.seq_name2seqid = {
        vid["name"].replace("/", "-"): vid["id"] for vid in self.gt_data["videos"]
    }
    # compute mappings from videos to annotation data
    self.video2gt_track, self.video2gt_image = self._compute_vid_mappings(
        self.gt_data["annotations"]
    )
    # compute sequence lengths
    self.seq_lengths = {vid["id"]: 0 for vid in self.gt_data["videos"]}
    for img in self.gt_data["images"]:
        self.seq_lengths[img["video_id"]] += 1
    self.seq2images2timestep = self._compute_image_to_timestep_mappings()
    self.seq2cls = {
        vid["id"]: {
            "pos_cat_ids": list(
                {track["category_id"] for track in self.video2gt_track[vid["id"]]}
            ),
        }
        for vid in self.gt_data["videos"]
    }

    # Get classes to eval
    considered_vid_ids = [self.seq_name2seqid[vid] for vid in self.seq_list]
    seen_cats = {
        cat_id
        for vid_id in considered_vid_ids
        for cat_id in self.seq2cls[vid_id]["pos_cat_ids"]
    }
    # only classes with ground truth are evaluated in TAO
    self.valid_classes = [
        cls["name"] for cls in self.gt_data["categories"] if cls["id"] in seen_cats
    ]
    cls_name2clsid_map = {
        cls["name"]: cls["id"] for cls in self.gt_data["categories"]
    }

    if self.config["CLASSES_TO_EVAL"]:
        # Requested classes that are not valid become None so all() can flag them.
        self.class_list = [
            cls.lower() if cls.lower() in self.valid_classes else None
            for cls in self.config["CLASSES_TO_EVAL"]
        ]
        if not all(self.class_list):
            valid_cls = ", ".join(self.valid_classes)
            raise TrackEvalException(
                "Attempted to evaluate an invalid class. Only classes "
                f"{valid_cls} are valid (classes present in ground truth"
                " data)."
            )
    else:
        self.class_list = list(self.valid_classes)
    self.cls_name2clsid = {
        k: v for k, v in cls_name2clsid_map.items() if k in self.class_list
    }
    self.clsid2cls_name = {
        v: k for k, v in cls_name2clsid_map.items() if k in self.class_list
    }

    # get trackers to eval
    if self.config["TRACKERS_TO_EVAL"] is None:
        self.tracker_list = os.listdir(self.tracker_fol)
    else:
        self.tracker_list = self.config["TRACKERS_TO_EVAL"]

    display_names = self.config["TRACKER_DISPLAY_NAMES"]
    if display_names is None:
        self.tracker_to_disp = dict(zip(self.tracker_list, self.tracker_list))
    else:
        # The names used to be read from "TK_DISPLAY_NAMES", a key the default
        # config does not define, so supplying TRACKER_DISPLAY_NAMES alone
        # raised KeyError. The legacy key is still honoured when present.
        legacy_names = self.config.get("TK_DISPLAY_NAMES")
        if legacy_names is not None:
            display_names = legacy_names
        if self.config["TRACKERS_TO_EVAL"] is not None and len(
            display_names
        ) == len(self.tracker_list):
            self.tracker_to_disp = dict(zip(self.tracker_list, display_names))
        else:
            raise TrackEvalException(
                "List of tracker files and tracker display names do not match."
            )

    self.tracker_data = {tracker: dict() for tracker in self.tracker_list}

    for tracker in self.tracker_list:
        # TRACKER_SUB_FOLDER may itself be a json file path; otherwise the
        # single json file in TRACKERS_FOLDER/tracker/TRACKER_SUB_FOLDER is used.
        if self.tracker_sub_fol.endswith(".json"):
            tr_path = self.tracker_sub_fol
        else:
            tr_dir = os.path.join(self.tracker_fol, tracker, self.tracker_sub_fol)
            tr_dir_files = [
                file for file in os.listdir(tr_dir) if file.endswith(".json")
            ]
            if len(tr_dir_files) != 1:
                raise TrackEvalException(
                    f"{tr_dir} does not contain exactly one json file."
                )
            tr_path = os.path.join(tr_dir, tr_dir_files[0])
        with open(tr_path) as f:
            curr_data = json.load(f)

        # limit detections if MAX_DETECTIONS > 0
        if self.config["MAX_DETECTIONS"]:
            curr_data = self._limit_dets_per_image(curr_data)

        # fill missing video ids
        self._fill_video_ids_inplace(curr_data)

        # make track ids unique over whole evaluation set
        self._make_tk_ids_unique(curr_data)

        # get tracker sequence information
        curr_vids2tracks, curr_vids2images = self._compute_vid_mappings(curr_data)
        self.tracker_data[tracker]["vids_to_tracks"] = curr_vids2tracks
        self.tracker_data[tracker]["vids_to_images"] = curr_vids2images
|
||||
|
||||
def get_display_name(self, tracker):
    """Return the display name stored for ``tracker`` in ``self.tracker_to_disp``.

    :raises KeyError: if ``tracker`` has no entry in the mapping
    """
    display_names = self.tracker_to_disp
    return display_names[tracker]
|
||||
|
||||
def _load_raw_file(self, tracker, seq, is_gt):
|
||||
"""Load a file (gt or tracker) in the TAO format
|
||||
|
||||
If is_gt, this returns a dict which contains the fields:
|
||||
[gt_ids, gt_classes]:
|
||||
list (for each timestep) of 1D NDArrays (for each det).
|
||||
[gt_dets]: list (for each timestep) of lists of detections.
|
||||
|
||||
if not is_gt, this returns a dict which contains the fields:
|
||||
[tk_ids, tk_classes]:
|
||||
list (for each timestep) of 1D NDArrays (for each det).
|
||||
[tk_dets]: list (for each timestep) of lists of detections.
|
||||
"""
|
||||
seq_id = self.seq_name2seqid[seq]
|
||||
# file location
|
||||
if is_gt:
|
||||
imgs = self.video2gt_image[seq_id]
|
||||
else:
|
||||
imgs = self.tracker_data[tracker]["vids_to_images"][seq_id]
|
||||
|
||||
# convert data to required format
|
||||
num_timesteps = self.seq_lengths[seq_id]
|
||||
img_to_timestep = self.seq2images2timestep[seq_id]
|
||||
data_keys = ["ids", "classes", "dets"]
|
||||
# if not is_gt:
|
||||
# data_keys += ["tk_confidences"]
|
||||
raw_data = {key: [None] * num_timesteps for key in data_keys}
|
||||
for img in imgs:
|
||||
# some tracker data contains images without any ground truth info,
|
||||
# these are ignored
|
||||
if img["id"] not in img_to_timestep:
|
||||
continue
|
||||
t = img_to_timestep[img["id"]]
|
||||
anns = img["annotations"]
|
||||
tk_str = utils.get_track_id_str(anns[0])
|
||||
|
||||
if self.use_mask:
|
||||
# When using mask, extract segmentation data
|
||||
raw_data["dets"][t] = [ann.get("segmentation") for ann in anns]
|
||||
else:
|
||||
# When using bbox, extract bbox data
|
||||
raw_data["dets"][t] = np.atleast_2d([ann["bbox"] for ann in anns]).astype(
|
||||
float
|
||||
)
|
||||
raw_data["ids"][t] = np.atleast_1d([ann[tk_str] for ann in anns]).astype(
|
||||
int
|
||||
)
|
||||
raw_data["classes"][t] = np.atleast_1d(
|
||||
[ann["category_id"] for ann in anns]
|
||||
).astype(int)
|
||||
# if not is_gt:
|
||||
# raw_data["tk_confidences"][t] = np.atleast_1d(
|
||||
# [ann["score"] for ann in anns]
|
||||
# ).astype(float)
|
||||
|
||||
for t, d in enumerate(raw_data["dets"]):
|
||||
if d is None:
|
||||
raw_data["dets"][t] = np.empty((0, 4)).astype(float)
|
||||
raw_data["ids"][t] = np.empty(0).astype(int)
|
||||
raw_data["classes"][t] = np.empty(0).astype(int)
|
||||
# if not is_gt:
|
||||
# raw_data["tk_confidences"][t] = np.empty(0)
|
||||
|
||||
if is_gt:
|
||||
key_map = {"ids": "gt_ids", "classes": "gt_classes", "dets": "gt_dets"}
|
||||
else:
|
||||
key_map = {"ids": "tk_ids", "classes": "tk_classes", "dets": "tk_dets"}
|
||||
for k, v in key_map.items():
|
||||
raw_data[v] = raw_data.pop(k)
|
||||
|
||||
raw_data["num_timesteps"] = num_timesteps
|
||||
raw_data["seq"] = seq
|
||||
return raw_data
|
||||
|
||||
def get_preprocessed_seq_data_thr(self, raw_data, cls, assignment=None):
    """Preprocess data for a single sequence for a single class.

    Inputs:
        raw_data: dict containing the data for the sequence already
            read in by get_raw_seq_data(). The per-timestep lists
            "gt_ids", "gt_classes", "gt_dets", "tk_ids", "tk_classes",
            "tk_dets" and "similarity_scores" are read, together with
            "num_timesteps" and "seq".
        cls: class to be evaluated, or "all" to keep every GT.
        assignment: optional list (for each timestep) of dicts mapping a
            GT id to the tracker id assigned to it. Tracker ids assigned
            to a GT outside the evaluated class are removed from the
            overlapping ids of that timestep.
    Outputs (keys of the returned dict):
        gt_ids:
            list (for each timestep) of ids of GT tracks (relabelled to
            be contiguous)
        tk_ids:
            list (for each timestep) of ids of predicted tracks (all for TP
            matching (Det + AssocA)), relabelled to be contiguous
        gt_id_map / tk_id_map:
            dicts from relabelled id back to original id (only built when
            at least one id exists)
        tk_overlap_ids:
            list (for each timestep) of ids of predicted tracks that overlap
            with GTs
        tk_exh_ids:
            list (for each timestep) of ids of predicted tracks whose class
            equals the evaluated class (empty list for "all")
        tk_class_eval_tk_ids:
            list (for each timestep) of the set of tk_overlap_ids and
            tk_exh_ids (original ids)
        gt_dets / tk_dets:
            list (for each timestep) of detections that correspond to
            gt_ids / tk_ids
        tk_classes / tk_overlap_classes:
            list (for each timestep) of classes that correspond to the
            tk_ids / tk_overlap_ids
        sim_scores:
            similarity score between gt_ids and tk_ids.
        plus the counters num_tk_cls_dets, num_tk_overlap_dets,
        num_gt_dets, num_tk_ids, num_gt_ids, and num_timesteps, seq.
    """
    evaluate_all = cls == "all"
    if not evaluate_all:
        cls_id = self.cls_name2clsid[cls]

    num_timesteps = raw_data["num_timesteps"]
    has_assignment = bool(assignment)
    overlap_ious_thr = 0.5

    data_keys = [
        "gt_ids",
        "tk_ids",
        "gt_id_map",
        "tk_id_map",
        "gt_dets",
        "gt_classes",
        "gt_class_name",
        "tk_overlap_classes",
        "tk_overlap_ids",
        "tk_class_eval_tk_ids",
        "tk_dets",
        "tk_classes",
        "tk_exh_ids",
        "sim_scores",
    ]
    data = {key: [None] * num_timesteps for key in data_keys}
    sim_scores = raw_data["similarity_scores"]

    # Pass 1: per timestep, find the tracker ids that overlap a GT of the
    # evaluated class. The class masks are kept for the second pass.
    gt_class_masks = []
    loc_and_asso_tk_ids = []
    for t in range(num_timesteps):
        # only extract relevant dets for this class for preproc and eval
        gt_classes_t = raw_data["gt_classes"][t]
        if evaluate_all:
            gt_class_mask = np.ones_like(gt_classes_t).astype(bool)
        else:
            gt_class_mask = np.atleast_1d(gt_classes_t == cls_id).astype(bool)
        gt_class_masks.append(gt_class_mask)

        # compute overlapped tracks and add their ids to overlap_tk_ids
        overlap_ids_mask = (sim_scores[t][gt_class_mask] >= overlap_ious_thr).any(
            axis=0
        )
        overlap_tk_ids_t = raw_data["tk_ids"][t][overlap_ids_mask]
        if has_assignment:
            # tracker ids assigned to a GT that is not in the evaluated class
            gt_ids_in = raw_data["gt_ids"][t][gt_class_mask]
            gt_ids_out = set(assignment[t].keys()) - set(gt_ids_in)
            tk_ids_out = {assignment[t][gt_id] for gt_id in gt_ids_out}
            data["tk_overlap_ids"][t] = list(set(overlap_tk_ids_t) - tk_ids_out)
        else:
            data["tk_overlap_ids"][t] = list(set(overlap_tk_ids_t))
        loc_and_asso_tk_ids += data["tk_overlap_ids"][t]

        if evaluate_all:
            data["tk_exh_ids"][t] = []
        else:
            # track ids predicted with the evaluated class
            tk_exh_mask = np.atleast_1d(raw_data["tk_classes"][t] == cls_id)
            tk_exh_mask = tk_exh_mask.astype(bool)
            data["tk_exh_ids"][t] = raw_data["tk_ids"][t][tk_exh_mask]

    # tracker ids that overlap a GT of the class in at least one timestep
    loc_and_asso_tk_ids = np.array(list(set(loc_and_asso_tk_ids)))

    # Pass 2: remove all unwanted unmatched tracker detections
    unique_gt_ids = []
    unique_tk_ids = []
    num_gt_dets = 0
    num_tk_cls_dets = 0
    num_tk_overlap_dets = 0
    for t in range(num_timesteps):
        gt_class_mask = gt_class_masks[t]
        if not evaluate_all:
            data["gt_classes"][t] = cls_id
            data["gt_class_name"][t] = cls

        # add gt to the data
        data["gt_ids"][t] = raw_data["gt_ids"][t][gt_class_mask]
        if self.use_mask:
            # mask detections are plain lists, so filter element-wise
            data["gt_dets"][t] = [
                det
                for det, keep in zip(raw_data["gt_dets"][t], gt_class_mask)
                if keep
            ]
        else:
            data["gt_dets"][t] = raw_data["gt_dets"][t][gt_class_mask]

        # filter pred and only keep those that highly overlap with GTs
        raw_tk_ids_t = raw_data["tk_ids"][t]
        tk_mask = np.isin(raw_tk_ids_t, loc_and_asso_tk_ids, assume_unique=True)
        tk_overlap_mask = np.isin(
            raw_tk_ids_t,
            np.array(data["tk_overlap_ids"][t]),
            assume_unique=True,
        )

        # add filtered prediction to the data
        data["tk_ids"][t] = raw_tk_ids_t[tk_mask]
        if self.use_mask:
            data["tk_dets"][t] = [
                det for det, keep in zip(raw_data["tk_dets"][t], tk_mask) if keep
            ]
        else:
            data["tk_dets"][t] = raw_data["tk_dets"][t][tk_mask]
        data["tk_classes"][t] = raw_data["tk_classes"][t][tk_mask]
        # overlap classes are used for computing the FP for Cls term
        data["tk_overlap_classes"][t] = raw_data["tk_classes"][t][tk_overlap_mask]
        data["sim_scores"][t] = sim_scores[t][gt_class_mask, :][:, tk_mask]
        data["tk_class_eval_tk_ids"][t] = set(
            list(data["tk_overlap_ids"][t]) + list(data["tk_exh_ids"][t])
        )

        # count total number of detections
        unique_gt_ids += list(np.unique(data["gt_ids"][t]))
        # the unique track ids are for association.
        unique_tk_ids += list(np.unique(data["tk_ids"][t]))

        num_tk_overlap_dets += len(data["tk_overlap_ids"][t])
        num_tk_cls_dets += len(data["tk_class_eval_tk_ids"][t])
        num_gt_dets += len(data["gt_ids"][t])

    # re-label IDs such that there are no empty IDs
    if len(unique_gt_ids) > 0:
        unique_gt_ids = np.unique(unique_gt_ids)
        gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1))
        gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids))
        data["gt_id_map"] = {
            gt_id_map[gt_id].astype(int): gt_id for gt_id in unique_gt_ids
        }
        for t in range(num_timesteps):
            if len(data["gt_ids"][t]) > 0:
                data["gt_ids"][t] = gt_id_map[data["gt_ids"][t]].astype(int)

    if len(unique_tk_ids) > 0:
        unique_tk_ids = np.unique(unique_tk_ids)
        tk_id_map = np.nan * np.ones((np.max(unique_tk_ids) + 1))
        tk_id_map[unique_tk_ids] = np.arange(len(unique_tk_ids))
        data["tk_id_map"] = {
            tk_id_map[track_id].astype(int): track_id for track_id in unique_tk_ids
        }
        for t in range(num_timesteps):
            if len(data["tk_ids"][t]) > 0:
                data["tk_ids"][t] = tk_id_map[data["tk_ids"][t]].astype(int)
            if len(data["tk_overlap_ids"][t]) > 0:
                data["tk_overlap_ids"][t] = tk_id_map[
                    data["tk_overlap_ids"][t]
                ].astype(int)

    # record overview statistics.
    data["num_tk_cls_dets"] = num_tk_cls_dets
    data["num_tk_overlap_dets"] = num_tk_overlap_dets
    data["num_gt_dets"] = num_gt_dets
    data["num_tk_ids"] = len(unique_tk_ids)
    data["num_gt_ids"] = len(unique_gt_ids)
    data["num_timesteps"] = num_timesteps
    data["seq"] = raw_data["seq"]

    self._check_unique_ids(data)

    return data
|
||||
|
||||
@_timing.time
def get_preprocessed_seq_data(
    self, raw_data, cls, assignment=None, thresholds=(50, 75)
):
    """Preprocess data for a single sequence for a single class.

    Inputs:
        raw_data: raw sequence data, passed through unchanged to
            get_preprocessed_seq_data_thr().
        cls: class to be evaluated.
        assignment: optional mapping indexed by threshold; the entry for
            each threshold is forwarded to get_preprocessed_seq_data_thr().
        thresholds: iterable of thresholds, a single int, or None (which
            falls back to [50, 75]). The default is an immutable tuple so
            it cannot be shared and mutated across calls.
    Outputs:
        dict mapping each threshold to its preprocessed data.
    """
    if thresholds is None:
        thresholds = [50, 75]
    elif isinstance(thresholds, int):
        thresholds = [thresholds]

    data = {}
    for thr in thresholds:
        assignment_thr = None if assignment is None else assignment[thr]
        data[thr] = self.get_preprocessed_seq_data_thr(
            raw_data, cls, assignment_thr
        )

    return data
|
||||
|
||||
def _calculate_similarities(self, gt_dets_t, tk_dets_t):
|
||||
"""Compute similarity scores."""
|
||||
if self.use_mask:
|
||||
similarity_scores = self._calculate_mask_ious(gt_dets_t, tk_dets_t, is_encoded=True, do_ioa=False)
|
||||
else:
|
||||
similarity_scores = self._calculate_box_ious(gt_dets_t, tk_dets_t)
|
||||
return similarity_scores
|
||||
|
||||
def _compute_vid_mappings(self, annotations):
|
||||
"""Computes mappings from videos to corresponding tracks and images."""
|
||||
vids_to_tracks = {}
|
||||
vids_to_imgs = {}
|
||||
vid_ids = [vid["id"] for vid in self.gt_data["videos"]]
|
||||
|
||||
# compute an mapping from image IDs to images
|
||||
images = {}
|
||||
for image in self.gt_data["images"]:
|
||||
images[image["id"]] = image
|
||||
|
||||
tk_str = utils.get_track_id_str(annotations[0])
|
||||
for ann in annotations:
|
||||
ann["area"] = ann["bbox"][2] * ann["bbox"][3]
|
||||
|
||||
vid = ann["video_id"]
|
||||
if ann["video_id"] not in vids_to_tracks.keys():
|
||||
vids_to_tracks[ann["video_id"]] = list()
|
||||
if ann["video_id"] not in vids_to_imgs.keys():
|
||||
vids_to_imgs[ann["video_id"]] = list()
|
||||
|
||||
# fill in vids_to_tracks
|
||||
tid = ann[tk_str]
|
||||
exist_tids = [track["id"] for track in vids_to_tracks[vid]]
|
||||
try:
|
||||
index1 = exist_tids.index(tid)
|
||||
except ValueError:
|
||||
index1 = -1
|
||||
if tid not in exist_tids:
|
||||
curr_track = {
|
||||
"id": tid,
|
||||
"category_id": ann["category_id"],
|
||||
"video_id": vid,
|
||||
"annotations": [ann],
|
||||
}
|
||||
vids_to_tracks[vid].append(curr_track)
|
||||
else:
|
||||
vids_to_tracks[vid][index1]["annotations"].append(ann)
|
||||
|
||||
# fill in vids_to_imgs
|
||||
img_id = ann["image_id"]
|
||||
exist_img_ids = [img["id"] for img in vids_to_imgs[vid]]
|
||||
try:
|
||||
index2 = exist_img_ids.index(img_id)
|
||||
except ValueError:
|
||||
index2 = -1
|
||||
if index2 == -1:
|
||||
curr_img = {"id": img_id, "annotations": [ann]}
|
||||
vids_to_imgs[vid].append(curr_img)
|
||||
else:
|
||||
vids_to_imgs[vid][index2]["annotations"].append(ann)
|
||||
|
||||
# sort annotations by frame index and compute track area
|
||||
for vid, tracks in vids_to_tracks.items():
|
||||
for track in tracks:
|
||||
track["annotations"] = sorted(
|
||||
track["annotations"],
|
||||
key=lambda x: images[x["image_id"]]["frame_id"],
|
||||
)
|
||||
# compute average area
|
||||
track["area"] = sum(x["area"] for x in track["annotations"]) / len(
|
||||
track["annotations"]
|
||||
)
|
||||
|
||||
# ensure all videos are present
|
||||
for vid_id in vid_ids:
|
||||
if vid_id not in vids_to_tracks.keys():
|
||||
vids_to_tracks[vid_id] = []
|
||||
if vid_id not in vids_to_imgs.keys():
|
||||
vids_to_imgs[vid_id] = []
|
||||
|
||||
return vids_to_tracks, vids_to_imgs
|
||||
|
||||
def _compute_image_to_timestep_mappings(self):
|
||||
"""Computes a mapping from images to timestep in sequence."""
|
||||
images = {}
|
||||
for image in self.gt_data["images"]:
|
||||
images[image["id"]] = image
|
||||
|
||||
seq_to_imgs_to_timestep = {vid["id"]: dict() for vid in self.gt_data["videos"]}
|
||||
for vid in seq_to_imgs_to_timestep:
|
||||
curr_imgs = [img["id"] for img in self.video2gt_image[vid]]
|
||||
curr_imgs = sorted(curr_imgs, key=lambda x: images[x]["frame_id"])
|
||||
seq_to_imgs_to_timestep[vid] = {
|
||||
curr_imgs[i]: i for i in range(len(curr_imgs))
|
||||
}
|
||||
|
||||
return seq_to_imgs_to_timestep
|
||||
|
||||
def _limit_dets_per_image(self, annotations):
|
||||
"""Limits the number of detections for each image.
|
||||
|
||||
Adapted from https://github.com/TAO-Dataset/.
|
||||
"""
|
||||
max_dets = self.config["MAX_DETECTIONS"]
|
||||
img_ann = defaultdict(list)
|
||||
for ann in annotations:
|
||||
img_ann[ann["image_id"]].append(ann)
|
||||
|
||||
for img_id, _anns in img_ann.items():
|
||||
if len(_anns) <= max_dets:
|
||||
continue
|
||||
_anns = sorted(_anns, key=lambda x: x["score"], reverse=True)
|
||||
img_ann[img_id] = _anns[:max_dets]
|
||||
|
||||
return [ann for anns in img_ann.values() for ann in anns]
|
||||
|
||||
def _fill_video_ids_inplace(self, annotations):
|
||||
"""Fills in missing video IDs inplace.
|
||||
|
||||
Adapted from https://github.com/TAO-Dataset/.
|
||||
"""
|
||||
missing_video_id = [x for x in annotations if "video_id" not in x]
|
||||
if missing_video_id:
|
||||
image_id_to_video_id = {
|
||||
x["id"]: x["video_id"] for x in self.gt_data["images"]
|
||||
}
|
||||
for x in missing_video_id:
|
||||
x["video_id"] = image_id_to_video_id[x["image_id"]]
|
||||
|
||||
@staticmethod
|
||||
def _make_tk_ids_unique(annotations):
|
||||
"""Makes track IDs unqiue over the whole annotation set.
|
||||
|
||||
Adapted from https://github.com/TAO-Dataset/.
|
||||
"""
|
||||
track_id_videos = {}
|
||||
track_ids_to_update = set()
|
||||
max_track_id = 0
|
||||
|
||||
tk_str = utils.get_track_id_str(annotations[0])
|
||||
for ann in annotations:
|
||||
t = int(ann[tk_str])
|
||||
if t not in track_id_videos:
|
||||
track_id_videos[t] = ann["video_id"]
|
||||
|
||||
if ann["video_id"] != track_id_videos[t]:
|
||||
# track id is assigned to multiple videos
|
||||
track_ids_to_update.add(t)
|
||||
max_track_id = max(max_track_id, t)
|
||||
|
||||
if track_ids_to_update:
|
||||
print("true")
|
||||
next_id = itertools.count(max_track_id + 1)
|
||||
new_tk_ids = defaultdict(lambda: next(next_id))
|
||||
for ann in annotations:
|
||||
t = ann[tk_str]
|
||||
v = ann["video_id"]
|
||||
if t in track_ids_to_update:
|
||||
ann[tk_str] = new_tk_ids[t, v]
|
||||
return len(track_ids_to_update)
|
||||
659
sam3/eval/teta_eval_toolkit/datasets/tao.py
Normal file
659
sam3/eval/teta_eval_toolkit/datasets/tao.py
Normal file
@@ -0,0 +1,659 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
|
||||
"""TAO Dataset."""
|
||||
import copy
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import _timing
|
||||
from ..config import get_default_dataset_config, init_config
|
||||
from ..utils import TrackEvalException
|
||||
from ._base_dataset import _BaseDataset
|
||||
|
||||
|
||||
class TAO(_BaseDataset):
|
||||
"""Dataset class for TAO tracking"""
|
||||
|
||||
def __init__(self, config=None):
    """Initialize dataset, checking that all required files are present.

    Inputs:
        config: optional dict of settings; non-given values are filled
            with the defaults of get_default_dataset_config().

    Raises:
        TrackEvalException: if a GT or tracker folder does not contain
            exactly one json file, if a class to evaluate is not among the
            classes present in the ground truth, or if the tracker display
            names do not match the tracker list.
    """
    super().__init__()
    # Fill non-given config values with defaults
    self.config = init_config(config, get_default_dataset_config(), self.get_name())
    self.gt_fol = self.config["GT_FOLDER"]
    self.tracker_fol = self.config["TRACKERS_FOLDER"]
    self.should_classes_combine = True
    self.use_super_categories = False
    self.use_mask = self.config["USE_MASK"]

    self.tracker_sub_fol = self.config["TRACKER_SUB_FOLDER"]
    self.output_fol = self.config["OUTPUT_FOLDER"]
    if self.output_fol is None:
        self.output_fol = self.tracker_fol
    self.output_sub_fol = self.config["OUTPUT_SUB_FOLDER"]

    # GT_FOLDER is either the json file itself or a folder holding one
    if self.gt_fol.endswith(".json"):
        gt_path = self.gt_fol
    else:
        gt_dir_files = [
            file for file in os.listdir(self.gt_fol) if file.endswith(".json")
        ]
        if len(gt_dir_files) != 1:
            raise TrackEvalException(
                f"{self.gt_fol} does not contain exactly one json file."
            )
        gt_path = os.path.join(self.gt_fol, gt_dir_files[0])
    # context manager so the file handle is always closed
    with open(gt_path) as f:
        self.gt_data = json.load(f)

    # merge categories marked with a merged tag in TAO dataset
    self._merge_categories(self.gt_data["annotations"] + self.gt_data["tracks"])

    # get sequences to eval and sequence information
    self.seq_list = [
        vid["name"].replace("/", "-") for vid in self.gt_data["videos"]
    ]
    self.seq_name2seqid = {
        vid["name"].replace("/", "-"): vid["id"] for vid in self.gt_data["videos"]
    }
    # compute mappings from videos to annotation data
    self.video2gt_track, self.video2gt_image = self._compute_vid_mappings(
        self.gt_data["annotations"]
    )
    # compute sequence lengths
    self.seq_lengths = {vid["id"]: 0 for vid in self.gt_data["videos"]}
    for img in self.gt_data["images"]:
        self.seq_lengths[img["video_id"]] += 1
    self.seq2images2timestep = self._compute_image_to_timestep_mappings()
    self.seq2cls = {
        vid["id"]: {
            "pos_cat_ids": list(
                {track["category_id"] for track in self.video2gt_track[vid["id"]]}
            ),
            "neg_cat_ids": vid["neg_category_ids"],
            "not_exh_labeled_cat_ids": vid["not_exhaustive_category_ids"],
        }
        for vid in self.gt_data["videos"]
    }

    # Get classes to eval
    considered_vid_ids = [self.seq_name2seqid[vid] for vid in self.seq_list]
    seen_cats = {
        cat_id
        for vid_id in considered_vid_ids
        for cat_id in self.seq2cls[vid_id]["pos_cat_ids"]
    }
    # only classes with ground truth are evaluated in TAO
    self.valid_classes = [
        cls["name"] for cls in self.gt_data["categories"] if cls["id"] in seen_cats
    ]
    cls_name2clsid_map = {
        cls["name"]: cls["id"] for cls in self.gt_data["categories"]
    }

    if self.config["CLASSES_TO_EVAL"]:
        # invalid class names become None so they can be detected below
        self.class_list = [
            cls.lower() if cls.lower() in self.valid_classes else None
            for cls in self.config["CLASSES_TO_EVAL"]
        ]
        if not all(self.class_list):
            valid_cls = ", ".join(self.valid_classes)
            raise TrackEvalException(
                "Attempted to evaluate an invalid class. Only classes "
                f"{valid_cls} are valid (classes present in ground truth"
                " data)."
            )
    else:
        self.class_list = list(self.valid_classes)
    self.cls_name2clsid = {
        k: v for k, v in cls_name2clsid_map.items() if k in self.class_list
    }
    self.clsid2cls_name = {
        v: k for k, v in cls_name2clsid_map.items() if k in self.class_list
    }

    # get trackers to eval
    if self.config["TRACKERS_TO_EVAL"] is None:
        self.tracker_list = os.listdir(self.tracker_fol)
    else:
        self.tracker_list = self.config["TRACKERS_TO_EVAL"]

    # The same config key is used for the None test and for the names, so
    # a non-None list of display names no longer hits a different key.
    display_names = self.config["TRACKER_DISPLAY_NAMES"]
    if display_names is None:
        self.tracker_to_disp = dict(zip(self.tracker_list, self.tracker_list))
    elif (self.config["TRACKERS_TO_EVAL"] is not None) and (
        len(display_names) == len(self.tracker_list)
    ):
        self.tracker_to_disp = dict(zip(self.tracker_list, display_names))
    else:
        raise TrackEvalException(
            "List of tracker files and tracker display names do not match."
        )

    self.tracker_data = {tracker: dict() for tracker in self.tracker_list}

    for tracker in self.tracker_list:
        if self.tracker_sub_fol.endswith(".json"):
            # TRACKER_SUB_FOLDER points directly at a result file, which is
            # then loaded for every tracker in the list
            tracker_path = self.tracker_sub_fol
        else:
            tr_dir = os.path.join(self.tracker_fol, tracker, self.tracker_sub_fol)
            tr_dir_files = [
                file for file in os.listdir(tr_dir) if file.endswith(".json")
            ]
            if len(tr_dir_files) != 1:
                raise TrackEvalException(
                    f"{tr_dir} does not contain exactly one json file."
                )
            tracker_path = os.path.join(tr_dir, tr_dir_files[0])
        with open(tracker_path) as f:
            curr_data = json.load(f)

        # limit detections if MAX_DETECTIONS > 0
        if self.config["MAX_DETECTIONS"]:
            curr_data = self._limit_dets_per_image(curr_data)

        # fill missing video ids
        self._fill_video_ids_inplace(curr_data)

        # make track ids unique over whole evaluation set
        self._make_tk_ids_unique(curr_data)

        # merge categories marked with a merged tag in TAO dataset
        self._merge_categories(curr_data)

        # get tracker sequence information
        curr_vids2tracks, curr_vids2images = self._compute_vid_mappings(curr_data)
        self.tracker_data[tracker]["vids_to_tracks"] = curr_vids2tracks
        self.tracker_data[tracker]["vids_to_images"] = curr_vids2images
|
||||
|
||||
def get_display_name(self, tracker):
    """Look up the display name of ``tracker`` in ``self.tracker_to_disp``.

    Raises KeyError if the tracker has no registered display name.
    """
    name_by_tracker = self.tracker_to_disp
    return name_by_tracker[tracker]
|
||||
|
||||
def _load_raw_file(self, tracker, seq, is_gt):
|
||||
"""Load a file (gt or tracker) in the TAO format
|
||||
|
||||
If is_gt, this returns a dict which contains the fields:
|
||||
[gt_ids, gt_classes]:
|
||||
list (for each timestep) of 1D NDArrays (for each det).
|
||||
[gt_dets]: list (for each timestep) of lists of detections.
|
||||
|
||||
if not is_gt, this returns a dict which contains the fields:
|
||||
[tk_ids, tk_classes, tk_confidences]:
|
||||
list (for each timestep) of 1D NDArrays (for each det).
|
||||
[tk_dets]: list (for each timestep) of lists of detections.
|
||||
"""
|
||||
seq_id = self.seq_name2seqid[seq]
|
||||
# file location
|
||||
if is_gt:
|
||||
imgs = self.video2gt_image[seq_id]
|
||||
else:
|
||||
imgs = self.tracker_data[tracker]["vids_to_images"][seq_id]
|
||||
|
||||
# convert data to required format
|
||||
num_timesteps = self.seq_lengths[seq_id]
|
||||
img_to_timestep = self.seq2images2timestep[seq_id]
|
||||
data_keys = ["ids", "classes", "dets"]
|
||||
if not is_gt:
|
||||
data_keys += ["tk_confidences"]
|
||||
raw_data = {key: [None] * num_timesteps for key in data_keys}
|
||||
for img in imgs:
|
||||
# some tracker data contains images without any ground truth info,
|
||||
# these are ignored
|
||||
if img["id"] not in img_to_timestep:
|
||||
continue
|
||||
t = img_to_timestep[img["id"]]
|
||||
anns = img["annotations"]
|
||||
if self.use_mask:
|
||||
# When using mask, extract segmentation data
|
||||
raw_data["dets"][t] = [ann.get("segmentation") for ann in anns]
|
||||
else:
|
||||
# When using bbox, extract bbox data
|
||||
raw_data["dets"][t] = np.atleast_2d([ann["bbox"] for ann in anns]).astype(
|
||||
float
|
||||
)
|
||||
raw_data["ids"][t] = np.atleast_1d(
|
||||
[ann["track_id"] for ann in anns]
|
||||
).astype(int)
|
||||
raw_data["classes"][t] = np.atleast_1d(
|
||||
[ann["category_id"] for ann in anns]
|
||||
).astype(int)
|
||||
if not is_gt:
|
||||
raw_data["tk_confidences"][t] = np.atleast_1d(
|
||||
[ann["score"] for ann in anns]
|
||||
).astype(float)
|
||||
|
||||
for t, d in enumerate(raw_data["dets"]):
|
||||
if d is None:
|
||||
raw_data["dets"][t] = np.empty((0, 4)).astype(float)
|
||||
raw_data["ids"][t] = np.empty(0).astype(int)
|
||||
raw_data["classes"][t] = np.empty(0).astype(int)
|
||||
if not is_gt:
|
||||
raw_data["tk_confidences"][t] = np.empty(0)
|
||||
|
||||
if is_gt:
|
||||
key_map = {"ids": "gt_ids", "classes": "gt_classes", "dets": "gt_dets"}
|
||||
else:
|
||||
key_map = {"ids": "tk_ids", "classes": "tk_classes", "dets": "tk_dets"}
|
||||
for k, v in key_map.items():
|
||||
raw_data[v] = raw_data.pop(k)
|
||||
|
||||
raw_data["num_timesteps"] = num_timesteps
|
||||
raw_data["neg_cat_ids"] = self.seq2cls[seq_id]["neg_cat_ids"]
|
||||
raw_data["not_exh_labeled_cls"] = self.seq2cls[seq_id][
|
||||
"not_exh_labeled_cat_ids"
|
||||
]
|
||||
raw_data["seq"] = seq
|
||||
return raw_data
|
||||
|
||||
def get_preprocessed_seq_data_thr(self, raw_data, cls, assignment=None):
    """Preprocess data for a single sequence for a single class.

    Inputs:
        raw_data: dict containing the data for the sequence already
            read in by get_raw_seq_data(). The per-timestep lists
            "gt_ids", "gt_classes", "gt_dets", "tk_ids", "tk_classes",
            "tk_dets", "tk_confidences" and "similarity_scores" are read,
            together with "num_timesteps" and "seq".
        cls: class to be evaluated, or "all" to keep every GT.
        assignment: optional list (for each timestep) of dicts mapping a
            GT id to the tracker id assigned to it. Tracker ids assigned
            to a GT outside the evaluated class are removed from the
            overlapping ids of that timestep.
    Outputs (keys of the returned dict):
        gt_ids:
            list (for each timestep) of ids of GT tracks (relabelled to
            be contiguous)
        tk_ids:
            list (for each timestep) of ids of predicted tracks (all for TP
            matching (Det + AssocA)), relabelled to be contiguous
        gt_id_map / tk_id_map:
            dicts from relabelled id back to original id (only built when
            at least one id exists)
        tk_overlap_ids:
            list (for each timestep) of ids of predicted tracks that overlap
            with GTs
        tk_neg_ids / tk_exh_ids:
            list (for each timestep) of empty lists; this implementation
            never adds ids to them
        tk_class_eval_tk_ids:
            list (for each timestep) of the set of tk_overlap_ids,
            tk_neg_ids and tk_exh_ids (original ids)
        gt_dets / tk_dets:
            list (for each timestep) of detections that correspond to
            gt_ids / tk_ids
        tk_classes / tk_overlap_classes:
            list (for each timestep) of classes that correspond to the
            tk_ids / tk_overlap_ids
        tk_confidences:
            list (for each timestep) of confidences that correspond to
            the tk_ids
        sim_scores:
            similarity score between gt_ids and tk_ids.
        plus the counters num_tk_cls_dets, num_tk_overlap_dets,
        num_gt_dets, num_tk_ids, num_gt_ids, and num_timesteps, seq.
    """
    evaluate_all = cls == "all"
    if not evaluate_all:
        cls_id = self.cls_name2clsid[cls]

    num_timesteps = raw_data["num_timesteps"]
    has_assignment = bool(assignment)
    overlap_ious_thr = 0.5

    data_keys = [
        "gt_ids",
        "tk_ids",
        "gt_id_map",
        "tk_id_map",
        "gt_dets",
        "gt_classes",
        "gt_class_name",
        "tk_overlap_classes",
        "tk_overlap_ids",
        "tk_neg_ids",
        "tk_exh_ids",
        "tk_class_eval_tk_ids",
        "tk_dets",
        "tk_classes",
        "tk_confidences",
        "sim_scores",
    ]
    data = {key: [None] * num_timesteps for key in data_keys}
    sim_scores = raw_data["similarity_scores"]

    # Pass 1: per timestep, find the tracker ids that overlap a GT of the
    # evaluated class. The class masks are kept for the second pass.
    gt_class_masks = []
    loc_and_asso_tk_ids = []
    for t in range(num_timesteps):
        # only extract relevant dets for this class for preproc and eval
        gt_classes_t = raw_data["gt_classes"][t]
        if evaluate_all:
            gt_class_mask = np.ones_like(gt_classes_t).astype(bool)
        else:
            gt_class_mask = np.atleast_1d(gt_classes_t == cls_id).astype(bool)
        gt_class_masks.append(gt_class_mask)

        # compute overlapped tracks and add their ids to overlap_tk_ids
        overlap_ids_mask = (sim_scores[t][gt_class_mask] >= overlap_ious_thr).any(
            axis=0
        )
        overlap_tk_ids_t = raw_data["tk_ids"][t][overlap_ids_mask]
        if has_assignment:
            # tracker ids assigned to a GT that is not in the evaluated class
            gt_ids_in = raw_data["gt_ids"][t][gt_class_mask]
            gt_ids_out = set(assignment[t].keys()) - set(gt_ids_in)
            tk_ids_out = {assignment[t][gt_id] for gt_id in gt_ids_out}
            data["tk_overlap_ids"][t] = list(set(overlap_tk_ids_t) - tk_ids_out)
        else:
            data["tk_overlap_ids"][t] = list(set(overlap_tk_ids_t))
        loc_and_asso_tk_ids += data["tk_overlap_ids"][t]

        data["tk_exh_ids"][t] = []
        data["tk_neg_ids"][t] = []

    # tracker ids that overlap a GT of the class in at least one timestep
    loc_and_asso_tk_ids = np.array(list(set(loc_and_asso_tk_ids)))

    # Pass 2: remove all unwanted unmatched tracker detections
    unique_gt_ids = []
    unique_tk_ids = []
    num_gt_dets = 0
    num_tk_cls_dets = 0
    num_tk_overlap_dets = 0
    for t in range(num_timesteps):
        gt_class_mask = gt_class_masks[t]
        if not evaluate_all:
            data["gt_classes"][t] = cls_id
            data["gt_class_name"][t] = cls

        # add gt to the data
        data["gt_ids"][t] = raw_data["gt_ids"][t][gt_class_mask]
        if self.use_mask:
            # mask detections are plain lists, so filter element-wise
            data["gt_dets"][t] = [
                det
                for det, keep in zip(raw_data["gt_dets"][t], gt_class_mask)
                if keep
            ]
        else:
            data["gt_dets"][t] = raw_data["gt_dets"][t][gt_class_mask]

        # filter pred and only keep those that highly overlap with GTs
        raw_tk_ids_t = raw_data["tk_ids"][t]
        tk_mask = np.isin(raw_tk_ids_t, loc_and_asso_tk_ids, assume_unique=True)
        tk_overlap_mask = np.isin(
            raw_tk_ids_t,
            np.array(data["tk_overlap_ids"][t]),
            assume_unique=True,
        )

        # add filtered prediction to the data
        data["tk_ids"][t] = raw_tk_ids_t[tk_mask]
        if self.use_mask:
            data["tk_dets"][t] = [
                det for det, keep in zip(raw_data["tk_dets"][t], tk_mask) if keep
            ]
        else:
            data["tk_dets"][t] = raw_data["tk_dets"][t][tk_mask]
        data["tk_classes"][t] = raw_data["tk_classes"][t][tk_mask]
        # overlap classes are used for computing the FP for Cls term
        data["tk_overlap_classes"][t] = raw_data["tk_classes"][t][tk_overlap_mask]
        data["tk_confidences"][t] = raw_data["tk_confidences"][t][tk_mask]
        data["sim_scores"][t] = sim_scores[t][gt_class_mask, :][:, tk_mask]
        data["tk_class_eval_tk_ids"][t] = set(
            list(data["tk_overlap_ids"][t])
            + list(data["tk_neg_ids"][t])
            + list(data["tk_exh_ids"][t])
        )

        # count total number of detections
        unique_gt_ids += list(np.unique(data["gt_ids"][t]))
        # the unique track ids are for association.
        unique_tk_ids += list(np.unique(data["tk_ids"][t]))

        num_tk_overlap_dets += len(data["tk_overlap_ids"][t])
        num_tk_cls_dets += len(data["tk_class_eval_tk_ids"][t])
        num_gt_dets += len(data["gt_ids"][t])

    # re-label IDs such that there are no empty IDs
    if len(unique_gt_ids) > 0:
        unique_gt_ids = np.unique(unique_gt_ids)
        gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1))
        gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids))
        data["gt_id_map"] = {
            gt_id_map[gt_id].astype(int): gt_id for gt_id in unique_gt_ids
        }
        for t in range(num_timesteps):
            if len(data["gt_ids"][t]) > 0:
                data["gt_ids"][t] = gt_id_map[data["gt_ids"][t]].astype(int)

    if len(unique_tk_ids) > 0:
        unique_tk_ids = np.unique(unique_tk_ids)
        tk_id_map = np.nan * np.ones((np.max(unique_tk_ids) + 1))
        tk_id_map[unique_tk_ids] = np.arange(len(unique_tk_ids))
        data["tk_id_map"] = {
            tk_id_map[track_id].astype(int): track_id for track_id in unique_tk_ids
        }
        for t in range(num_timesteps):
            if len(data["tk_ids"][t]) > 0:
                data["tk_ids"][t] = tk_id_map[data["tk_ids"][t]].astype(int)
            if len(data["tk_overlap_ids"][t]) > 0:
                data["tk_overlap_ids"][t] = tk_id_map[
                    data["tk_overlap_ids"][t]
                ].astype(int)

    # record overview statistics.
    data["num_tk_cls_dets"] = num_tk_cls_dets
    data["num_tk_overlap_dets"] = num_tk_overlap_dets
    data["num_gt_dets"] = num_gt_dets
    data["num_tk_ids"] = len(unique_tk_ids)
    data["num_gt_ids"] = len(unique_gt_ids)
    data["num_timesteps"] = num_timesteps
    data["seq"] = raw_data["seq"]

    self._check_unique_ids(data)

    return data
|
||||
|
||||
@_timing.time
|
||||
def get_preprocessed_seq_data(
    self, raw_data, cls, assignment=None, thresholds=(50, 75)
):
    """Preprocess one sequence for one class at every requested threshold.

    Args:
        raw_data: Raw sequence data, forwarded unchanged to
            ``get_preprocessed_seq_data_thr``.
        cls: Class to preprocess, forwarded unchanged.
        assignment: Optional mapping indexed by threshold. When given,
            ``assignment[thr]`` is forwarded for each threshold; when
            ``None``, ``None`` is forwarded instead.
        thresholds: Iterable of thresholds, a single ``int``, or ``None``
            (treated as ``[50]``). The default is an immutable tuple so the
            default object can never be shared and mutated between calls.

    Returns:
        dict mapping every threshold to the value returned by
        ``get_preprocessed_seq_data_thr`` for it.
    """
    if thresholds is None:
        thresholds = [50]
    elif isinstance(thresholds, int):
        thresholds = [thresholds]

    data = {}
    for thr in thresholds:
        assignment_thr = None if assignment is None else assignment[thr]
        data[thr] = self.get_preprocessed_seq_data_thr(
            raw_data, cls, assignment_thr
        )
    return data
|
||||
|
||||
def _calculate_similarities(self, gt_dets_t, tk_dets_t):
|
||||
"""Compute similarity scores."""
|
||||
if self.use_mask:
|
||||
similarity_scores = self._calculate_mask_ious(gt_dets_t, tk_dets_t, is_encoded=True, do_ioa=False)
|
||||
else:
|
||||
similarity_scores = self._calculate_box_ious(gt_dets_t, tk_dets_t)
|
||||
return similarity_scores
|
||||
|
||||
def _merge_categories(self, annotations):
|
||||
"""Merges categories with a merged tag.
|
||||
|
||||
Adapted from https://github.com/TAO-Dataset.
|
||||
"""
|
||||
merge_map = {}
|
||||
for category in self.gt_data["categories"]:
|
||||
if "merged" in category:
|
||||
for to_merge in category["merged"]:
|
||||
merge_map[to_merge["id"]] = category["id"]
|
||||
|
||||
for ann in annotations:
|
||||
ann["category_id"] = merge_map.get(ann["category_id"], ann["category_id"])
|
||||
|
||||
def _compute_vid_mappings(self, annotations):
|
||||
"""Computes mappings from videos to corresponding tracks and images."""
|
||||
vids_to_tracks = {}
|
||||
vids_to_imgs = {}
|
||||
vid_ids = [vid["id"] for vid in self.gt_data["videos"]]
|
||||
|
||||
# compute an mapping from image IDs to images
|
||||
images = {}
|
||||
for image in self.gt_data["images"]:
|
||||
images[image["id"]] = image
|
||||
|
||||
for ann in annotations:
|
||||
ann["area"] = ann["bbox"][2] * ann["bbox"][3]
|
||||
|
||||
vid = ann["video_id"]
|
||||
if ann["video_id"] not in vids_to_tracks.keys():
|
||||
vids_to_tracks[ann["video_id"]] = list()
|
||||
if ann["video_id"] not in vids_to_imgs.keys():
|
||||
vids_to_imgs[ann["video_id"]] = list()
|
||||
|
||||
# fill in vids_to_tracks
|
||||
tid = ann["track_id"]
|
||||
exist_tids = [track["id"] for track in vids_to_tracks[vid]]
|
||||
try:
|
||||
index1 = exist_tids.index(tid)
|
||||
except ValueError:
|
||||
index1 = -1
|
||||
if tid not in exist_tids:
|
||||
curr_track = {
|
||||
"id": tid,
|
||||
"category_id": ann["category_id"],
|
||||
"video_id": vid,
|
||||
"annotations": [ann],
|
||||
}
|
||||
vids_to_tracks[vid].append(curr_track)
|
||||
else:
|
||||
vids_to_tracks[vid][index1]["annotations"].append(ann)
|
||||
|
||||
# fill in vids_to_imgs
|
||||
img_id = ann["image_id"]
|
||||
exist_img_ids = [img["id"] for img in vids_to_imgs[vid]]
|
||||
try:
|
||||
index2 = exist_img_ids.index(img_id)
|
||||
except ValueError:
|
||||
index2 = -1
|
||||
if index2 == -1:
|
||||
curr_img = {"id": img_id, "annotations": [ann]}
|
||||
vids_to_imgs[vid].append(curr_img)
|
||||
else:
|
||||
vids_to_imgs[vid][index2]["annotations"].append(ann)
|
||||
|
||||
# sort annotations by frame index and compute track area
|
||||
for vid, tracks in vids_to_tracks.items():
|
||||
for track in tracks:
|
||||
track["annotations"] = sorted(
|
||||
track["annotations"],
|
||||
key=lambda x: images[x["image_id"]]["frame_index"],
|
||||
)
|
||||
# compute average area
|
||||
track["area"] = sum(x["area"] for x in track["annotations"]) / len(
|
||||
track["annotations"]
|
||||
)
|
||||
|
||||
# ensure all videos are present
|
||||
for vid_id in vid_ids:
|
||||
if vid_id not in vids_to_tracks.keys():
|
||||
vids_to_tracks[vid_id] = []
|
||||
if vid_id not in vids_to_imgs.keys():
|
||||
vids_to_imgs[vid_id] = []
|
||||
|
||||
return vids_to_tracks, vids_to_imgs
|
||||
|
||||
def _compute_image_to_timestep_mappings(self):
|
||||
"""Computes a mapping from images to timestep in sequence."""
|
||||
images = {}
|
||||
for image in self.gt_data["images"]:
|
||||
images[image["id"]] = image
|
||||
|
||||
seq_to_imgs_to_timestep = {vid["id"]: dict() for vid in self.gt_data["videos"]}
|
||||
for vid in seq_to_imgs_to_timestep:
|
||||
curr_imgs = [img["id"] for img in self.video2gt_image[vid]]
|
||||
curr_imgs = sorted(curr_imgs, key=lambda x: images[x]["frame_index"])
|
||||
seq_to_imgs_to_timestep[vid] = {
|
||||
curr_imgs[i]: i for i in range(len(curr_imgs))
|
||||
}
|
||||
|
||||
return seq_to_imgs_to_timestep
|
||||
|
||||
def _limit_dets_per_image(self, annotations):
|
||||
"""Limits the number of detections for each image.
|
||||
|
||||
Adapted from https://github.com/TAO-Dataset/.
|
||||
"""
|
||||
max_dets = self.config["MAX_DETECTIONS"]
|
||||
img_ann = defaultdict(list)
|
||||
for ann in annotations:
|
||||
img_ann[ann["image_id"]].append(ann)
|
||||
|
||||
for img_id, _anns in img_ann.items():
|
||||
if len(_anns) <= max_dets:
|
||||
continue
|
||||
_anns = sorted(_anns, key=lambda x: x["score"], reverse=True)
|
||||
img_ann[img_id] = _anns[:max_dets]
|
||||
|
||||
return [ann for anns in img_ann.values() for ann in anns]
|
||||
|
||||
def _fill_video_ids_inplace(self, annotations):
|
||||
"""Fills in missing video IDs inplace.
|
||||
|
||||
Adapted from https://github.com/TAO-Dataset/.
|
||||
"""
|
||||
missing_video_id = [x for x in annotations if "video_id" not in x]
|
||||
if missing_video_id:
|
||||
image_id_to_video_id = {
|
||||
x["id"]: x["video_id"] for x in self.gt_data["images"]
|
||||
}
|
||||
for x in missing_video_id:
|
||||
x["video_id"] = image_id_to_video_id[x["image_id"]]
|
||||
|
||||
@staticmethod
|
||||
def _make_tk_ids_unique(annotations):
|
||||
"""Makes track IDs unqiue over the whole annotation set.
|
||||
|
||||
Adapted from https://github.com/TAO-Dataset/.
|
||||
"""
|
||||
track_id_videos = {}
|
||||
track_ids_to_update = set()
|
||||
max_track_id = 0
|
||||
for ann in annotations:
|
||||
t = ann["track_id"]
|
||||
if t not in track_id_videos:
|
||||
track_id_videos[t] = ann["video_id"]
|
||||
|
||||
if ann["video_id"] != track_id_videos[t]:
|
||||
# track id is assigned to multiple videos
|
||||
track_ids_to_update.add(t)
|
||||
max_track_id = max(max_track_id, t)
|
||||
|
||||
if track_ids_to_update:
|
||||
print("true")
|
||||
next_id = itertools.count(max_track_id + 1)
|
||||
new_tk_ids = defaultdict(lambda: next(next_id))
|
||||
for ann in annotations:
|
||||
t = ann["track_id"]
|
||||
v = ann["video_id"]
|
||||
if t in track_ids_to_update:
|
||||
ann["track_id"] = new_tk_ids[t, v]
|
||||
return len(track_ids_to_update)
|
||||
275
sam3/eval/teta_eval_toolkit/eval.py
Normal file
275
sam3/eval/teta_eval_toolkit/eval.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
|
||||
import copy
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
import traceback
|
||||
from functools import partial
|
||||
from multiprocessing.pool import Pool
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import _timing, utils
|
||||
from .config import get_default_eval_config, init_config
|
||||
from .utils import TrackEvalException
|
||||
|
||||
|
||||
class Evaluator:
    """Evaluator class for evaluating different metrics for each datasets."""

    def __init__(self, config=None):
        """Build the configuration via ``init_config`` and set timing flags.

        Args:
            config: Optional user configuration, merged with
                ``get_default_eval_config()`` by ``init_config``.
        """
        self.config = init_config(config, get_default_eval_config(), "Eval")
        # Only run timing analysis if not run in parallel.
        if self.config["TIME_PROGRESS"] and not self.config["USE_PARALLEL"]:
            _timing.DO_TIMING = True
            if self.config["DISPLAY_LESS_PROGRESS"]:
                _timing.DISPLAY_LESS_PROGRESS = True

    @_timing.time
    def evaluate(self, dataset_list, metrics_list):
        """Evaluate a set of metrics on a set of datasets.

        Args:
            dataset_list: Dataset objects providing ``get_name()`` and
                ``get_eval_info()``.
            metrics_list: Metric objects, validated by
                ``utils.validate_metrics_list``.

        Returns:
            Tuple ``(output_res, output_msg)`` of dicts keyed by dataset
            name. On failure of a tracker, ``output_res[dname][tracker]`` is
            ``None`` and ``output_msg[dname][tracker]`` holds the error text.

        Raises:
            Exception: Re-raises the tracker's error when
                ``config["BREAK_ON_ERROR"]`` is set.
        """
        config = self.config
        metric_names = utils.validate_metrics_list(metrics_list)
        dataset_names = [dataset.get_name() for dataset in dataset_list]
        output_res = {}
        output_msg = {}

        for dataset, dname in zip(dataset_list, dataset_names):
            # Get dataset info about what to evaluate
            output_res[dname] = {}
            output_msg[dname] = {}
            tracker_list, seq_list, class_list = dataset.get_eval_info()
            print(
                f"\nEvaluating {len(tracker_list)} tracker(s) on "
                f"{len(seq_list)} sequence(s) for {len(class_list)} class(es)"
                f" on {dname} dataset using the following "
                f'metrics: {", ".join(metric_names)}\n'
            )

            # Evaluate each tracker
            for tracker in tracker_list:
                try:
                    output_res, output_msg = self.evaluate_tracker(
                        tracker,
                        dataset,
                        dname,
                        class_list,
                        metrics_list,
                        metric_names,
                        seq_list,
                        output_res,
                        output_msg,
                    )
                except Exception as err:
                    output_res[dname][tracker] = None
                    # isinstance so subclasses of TrackEvalException also
                    # report their own message
                    if isinstance(err, TrackEvalException):
                        output_msg[dname][tracker] = str(err)
                    else:
                        output_msg[dname][tracker] = "Unknown error occurred."
                    print("Tracker %s was unable to be evaluated." % tracker)
                    print(err)
                    traceback.print_exc()
                    if config["LOG_ON_ERROR"] is not None:
                        with open(config["LOG_ON_ERROR"], "a") as f:
                            print(dname, file=f)
                            print(tracker, file=f)
                            print(traceback.format_exc(), file=f)
                            print("\n\n\n", file=f)
                    if config["BREAK_ON_ERROR"]:
                        raise err
                    elif config["RETURN_ON_ERROR"]:
                        return output_res, output_msg

        return output_res, output_msg

    def evaluate_tracker(
        self,
        tracker,
        dataset,
        dname,
        class_list,
        metrics_list,
        metric_names,
        seq_list,
        output_res,
        output_msg,
    ):
        """Evaluate one tracker on every sequence, in parallel or in series.

        Per-sequence results are combined over sequences, then (if the
        dataset asks for it) over classes and super categories. Summary
        tables are printed and, when ``config["OUTPUT_TEM_RAW_DATA"]`` is
        set, the summary is pickled into the tracker's output folder.

        Args:
            tracker: Tracker name.
            dataset: Dataset object being evaluated.
            dname: Dataset name, used as key into the output dicts.
            class_list: Classes to evaluate.
            metrics_list: Metric objects.
            metric_names: Names matching ``metrics_list`` element-wise.
            seq_list: Sequences to evaluate.
            output_res: Result dict to update and return.
            output_msg: Message dict to update and return.

        Returns:
            The updated ``(output_res, output_msg)``.
        """
        print("\nEvaluating %s\n" % tracker)
        time_start = time.time()
        config = self.config
        if config["USE_PARALLEL"]:
            with Pool(config["NUM_PARALLEL_CORES"]) as pool:
                _eval_sequence = partial(
                    eval_sequence,
                    dataset=dataset,
                    tracker=tracker,
                    class_list=class_list,
                    metrics_list=metrics_list,
                    metric_names=metric_names,
                )
                results = pool.map(_eval_sequence, seq_list)
                res = dict(zip(seq_list, results))
        else:
            res = {}
            for curr_seq in sorted(seq_list):
                res[curr_seq] = eval_sequence(
                    curr_seq, dataset, tracker, class_list, metrics_list, metric_names
                )

        # collecting combined cls keys (cls averaged, det averaged, super classes)
        cls_keys = []
        res["COMBINED_SEQ"] = {}
        # combine sequences for each class
        for c_cls in class_list:
            res["COMBINED_SEQ"][c_cls] = {}
            for metric, mname in zip(metrics_list, metric_names):
                curr_res = {
                    seq_key: seq_value[c_cls][mname]
                    for seq_key, seq_value in res.items()
                    if seq_key != "COMBINED_SEQ"
                }
                # combine results over all sequences and then over all classes
                res["COMBINED_SEQ"][c_cls][mname] = metric.combine_sequences(curr_res)

        # combine classes
        if dataset.should_classes_combine:
            if config["OUTPUT_PER_SEQ_RES"]:
                video_keys = res.keys()
            else:
                video_keys = ["COMBINED_SEQ"]
            for v_key in video_keys:
                cls_keys += ["average"]
                res[v_key]["average"] = {}
                for metric, mname in zip(metrics_list, metric_names):
                    cls_res = {
                        cls_key: cls_value[mname]
                        for cls_key, cls_value in res[v_key].items()
                        if cls_key not in cls_keys
                    }
                    res[v_key]["average"][
                        mname
                    ] = metric.combine_classes_class_averaged(
                        cls_res, ignore_empty=True
                    )

        # combine classes to super classes
        if dataset.use_super_categories:
            for cat, sub_cats in dataset.super_categories.items():
                cls_keys.append(cat)
                res["COMBINED_SEQ"][cat] = {}
                for metric, mname in zip(metrics_list, metric_names):
                    cat_res = {
                        cls_key: cls_value[mname]
                        for cls_key, cls_value in res["COMBINED_SEQ"].items()
                        if cls_key in sub_cats
                    }
                    res["COMBINED_SEQ"][cat][
                        mname
                    ] = metric.combine_classes_det_averaged(cat_res)

        # Print and output results in various formats
        if config["TIME_PROGRESS"]:
            print(
                f"\nAll sequences for {tracker} finished in"
                f" {time.time() - time_start} seconds"
            )
        output_fol = dataset.get_output_fol(tracker)
        os.makedirs(output_fol, exist_ok=True)

        # take a mean of each field of each thr
        if config["OUTPUT_PER_SEQ_RES"]:
            all_res = copy.deepcopy(res)
            summary_keys = res.keys()
        else:
            # Keep the "COMBINED_SEQ" level: every lookup below indexes
            # all_res[s_key] first, which raised KeyError without it.
            all_res = {"COMBINED_SEQ": copy.deepcopy(res["COMBINED_SEQ"])}
            summary_keys = ["COMBINED_SEQ"]
        thr_key_list = [50]
        for s_key in summary_keys:
            for metric, mname in zip(metrics_list, metric_names):
                if mname != "TETA":
                    if s_key == "COMBINED_SEQ":
                        metric.print_table(
                            {"COMBINED_SEQ": res["COMBINED_SEQ"][cls_keys[0]][mname]},
                            tracker,
                            cls_keys[0],
                        )
                    continue

                for c_cls in res[s_key].keys():
                    for thr in thr_key_list:
                        all_res[s_key][c_cls][mname][thr] = metric._summary_row(
                            res[s_key][c_cls][mname][thr]
                        )
                    # average the per-threshold summary rows element-wise
                    x = (
                        np.array(list(all_res[s_key][c_cls]["TETA"].values()))
                        .astype("float")
                        .mean(axis=0)
                    )
                    all_res_summary = list(x.round(decimals=2).astype("str"))
                    all_res[s_key][c_cls][mname]["ALL"] = all_res_summary
                if config["OUTPUT_SUMMARY"] and s_key == "COMBINED_SEQ":
                    for t in thr_key_list:
                        metric.print_summary_table(
                            all_res[s_key][cls_keys[0]][mname][t],
                            t,
                            tracker,
                            cls_keys[0],
                        )

        if config["OUTPUT_TEM_RAW_DATA"]:
            out_file = os.path.join(output_fol, "teta_summary_results.pth")
            # context manager closes the handle even if pickling fails
            with open(out_file, "wb") as f:
                pickle.dump(all_res, f)
            print("Saved the TETA summary results.")

        # output: s_key and mname are the values left by the loops above
        # (last summary key, last metric). The threshold is taken from
        # thr_key_list directly; the former loop variable only existed when
        # the summary table had been printed.
        output_res[dname][mname] = all_res[s_key][cls_keys[0]][mname][
            thr_key_list[-1]
        ]
        output_msg[dname][tracker] = "Success"

        return output_res, output_msg
|
||||
|
||||
|
||||
@_timing.time
|
||||
def eval_sequence(seq, dataset, tracker, class_list, metrics_list, metric_names):
    """Evaluate every metric on every class of a single sequence.

    Args:
        seq: Sequence identifier passed to ``dataset.get_raw_seq_data``.
        dataset: Dataset object providing ``get_raw_seq_data``,
            ``get_preprocessed_seq_data`` and ``clsid2cls_name``.
        tracker: Tracker name.
        class_list: Classes to evaluate.
        metrics_list: Metric objects.
        metric_names: Names matching ``metrics_list`` element-wise.

    Returns:
        dict of ``cls -> {metric_name: result}``. When "TETA" is requested,
        the class false positives accumulated across all classes are added
        to each class's ``"Cls_FP"`` entry.
    """
    raw_data = dataset.get_raw_seq_data(tracker, seq)
    seq_res = {}
    use_teta = "TETA" in metric_names

    # Defined unconditionally so the per-class preprocessing below also works
    # when TETA is not among the requested metrics.
    thresholds = [50]
    assignment = None
    cls_fp = {}

    if use_teta:
        data_all_class = dataset.get_preprocessed_seq_data(
            raw_data, "all", thresholds=thresholds
        )
        teta = metrics_list[metric_names.index("TETA")]
        assignment = teta.compute_global_assignment(data_all_class)

        # create a dict to save Cls_FP for each class in different thr.
        num_cls_alphas = len(np.arange(0.5, 0.99, 0.05))
        cls_fp = {
            key: {cls: np.zeros(num_cls_alphas) for cls in class_list}
            for key in thresholds
        }

    for cls in class_list:
        seq_res[cls] = {}
        data = dataset.get_preprocessed_seq_data(raw_data, cls, assignment, thresholds)

        for metric, mname in zip(metrics_list, metric_names):
            if mname == "TETA":
                seq_res[cls][mname], cls_fp, _ = metric.eval_sequence(
                    data, cls, dataset.clsid2cls_name, cls_fp
                )
            else:
                seq_res[cls][mname] = metric.eval_sequence(data)

    if use_teta:
        # Cls_FP is only complete once every class has been processed, since
        # evaluating one class can add false positives to another.
        for thr in thresholds:
            for cls in class_list:
                seq_res[cls]["TETA"][thr]["Cls_FP"] += cls_fp[thr][cls]

    return seq_res
|
||||
4
sam3/eval/teta_eval_toolkit/metrics/__init__.py
Normal file
4
sam3/eval/teta_eval_toolkit/metrics/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
|
||||
from .teta import TETA
|
||||
148
sam3/eval/teta_eval_toolkit/metrics/_base_metric.py
Normal file
148
sam3/eval/teta_eval_toolkit/metrics/_base_metric.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import _timing
|
||||
from ..utils import TrackEvalException
|
||||
|
||||
|
||||
class _BaseMetric(ABC):
    """Abstract base for metrics: combination hooks plus table/summary helpers."""

    @abstractmethod
    def __init__(self):
        # field-name registries, filled in by concrete metrics
        self.fields = []
        self.summary_fields = []
        self.integer_fields = []
        self.float_fields = []
        self.integer_array_fields = []
        self.float_array_fields = []
        # labels of the entries of every array field
        self.array_labels = []
        self.plottable = False
        self.registered = False

    #####################################################################
    # Abstract functions for subclasses to implement

    @_timing.time
    @abstractmethod
    def eval_sequence(self, data):
        ...

    @abstractmethod
    def combine_sequences(self, all_res):
        ...

    @abstractmethod
    def combine_classes_class_averaged(self, all_res, ignore_empty=False):
        ...

    @abstractmethod
    def combine_classes_det_averaged(self, all_res):
        ...

    def plot_single_tracker_results(self, all_res, tracker, output_folder, cls):
        """Plot results; plottable metrics must override this method."""
        if not self.plottable:
            return
        raise NotImplementedError(
            f"plot_results is not implemented for metric {self.get_name()}"
        )

    #####################################################################
    # Helper functions which are useful for all metrics:

    @classmethod
    def get_name(cls):
        """Return the metric's class name."""
        return cls.__name__

    @staticmethod
    def _combine_sum(all_res, field):
        """Sum ``field`` over every entry of ``all_res``."""
        return sum([entry[field] for entry in all_res.values()])

    @staticmethod
    def _combine_weighted_av(all_res, field, comb_res, weight_field):
        """Average ``field`` over ``all_res``, weighted by ``weight_field``.

        The normaliser is ``comb_res[weight_field]``, floored at 1.
        """
        weighted_total = sum(
            [entry[field] * entry[weight_field] for entry in all_res.values()]
        )
        return weighted_total / np.maximum(1.0, comb_res[weight_field])

    def print_table(self, table_res, tracker, cls):
        """Print one summary row per sequence, then the combined row."""
        print("")
        header = self.get_name() + ": " + tracker + "-" + cls
        self._row_print([header] + self.summary_fields)
        for seq, results in sorted(table_res.items()):
            if seq != "COMBINED_SEQ":
                self._row_print([seq] + self._summary_row(results))
        combined_row = self._summary_row(table_res["COMBINED_SEQ"])
        self._row_print(["COMBINED"] + combined_row)

    def _summary_row(self, results_):
        """Format every summary field of ``results_`` as a string.

        Float (array) fields are scaled by 100, array fields are averaged
        first; integer fields are printed as integers.
        """
        row = []
        for name in self.summary_fields:
            if name in self.float_array_fields:
                cell = f"{100 * np.mean(results_[name]):1.5g}"
            elif name in self.float_fields:
                cell = f"{100 * float(results_[name]):1.5g}"
            elif name in self.integer_fields:
                cell = f"{int(results_[name]):d}"
            else:
                raise NotImplementedError(
                    "Summary function not implemented for this field type."
                )
            row.append(cell)
        return row

    @staticmethod
    def _row_print(*argv):
        """Print values left-aligned: 35 columns for the first, 10 for the rest."""
        cells = argv[0] if len(argv) == 1 else argv
        line = "%-35s" % cells[0]
        line += "".join("%-10s" % str(cell) for cell in cells[1:])
        print(line)

    def summary_results(self, table_res):
        """Return the combined summary row as a ``field -> value`` dict."""
        combined_row = self._summary_row(table_res["COMBINED_SEQ"])
        return dict(zip(self.summary_fields, combined_row))

    def detailed_results(self, table_res):
        """Return, per sequence, a dict of every scalar and array-entry value.

        Raises:
            TrackEvalException: If a row's length differs from the number of
                field names.
        """
        # Column names: scalar fields, then for each array field one column
        # per label (as an integer percentage) followed by an AUC column.
        array_fields = self.float_array_fields + self.integer_array_fields
        suffixes = [str(int(100 * x)) for x in self.array_labels] + ["AUC"]
        detailed_fields = self.float_fields + self.integer_fields
        detailed_fields += [
            name + "___" + suffix for name in array_fields for suffix in suffixes
        ]

        detailed_results = {}
        for seq, res in table_res.items():
            detailed_row = self._detailed_row(res)
            if len(detailed_row) != len(detailed_fields):
                raise TrackEvalException(
                    f"Field names and data have different sizes "
                    f"({len(detailed_row)} and {len(detailed_fields)})"
                )
            detailed_results[seq] = dict(zip(detailed_fields, detailed_row))
        return detailed_results

    def _detailed_row(self, res):
        """Flatten ``res`` into the column order used by ``detailed_results``."""
        row = [res[name] for name in self.float_fields + self.integer_fields]
        for name in self.float_array_fields + self.integer_array_fields:
            values = res[name]
            for i, _ in enumerate(self.array_labels):
                row.append(values[i])
            row.append(np.mean(values))
        return row
|
||||
399
sam3/eval/teta_eval_toolkit/metrics/teta.py
Normal file
399
sam3/eval/teta_eval_toolkit/metrics/teta.py
Normal file
@@ -0,0 +1,399 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
|
||||
"""Track Every Thing Accuracy metric."""
|
||||
|
||||
import numpy as np
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
from .. import _timing
|
||||
from ._base_metric import _BaseMetric
|
||||
|
||||
EPS = np.finfo("float").eps # epsilon
|
||||
|
||||
|
||||
class TETA(_BaseMetric):
|
||||
"""TETA metric."""
|
||||
|
||||
def __init__(self, exhaustive=False, config=None):
    """Set up the field registries and alpha grids of the metric.

    Args:
        exhaustive: Stored as ``self.exhaustive``.
        config: Accepted but not used by this initializer.
    """
    super().__init__()
    self.exhaustive = exhaustive
    self.plottable = True

    # alpha grids: localization/association use 0.0..0.95, classification
    # only the upper half 0.5..0.95 (both in steps of 0.05)
    self.array_labels = np.arange(0.0, 0.99, 0.05)
    self.cls_array_labels = np.arange(0.5, 0.99, 0.05)

    self.integer_array_fields = [
        f"{group}_{kind}" for group in ("Loc", "Cls") for kind in ("TP", "FN", "FP")
    ]
    self.float_array_fields = [
        "TETA",
        "LocA",
        "AssocA",
        "ClsA",
        "LocRe",
        "LocPr",
        "AssocRe",
        "AssocPr",
        "ClsRe",
        "ClsPr",
    ]
    self.fields = self.float_array_fields + self.integer_array_fields
    self.summary_fields = self.float_array_fields
|
||||
|
||||
def compute_global_assignment(self, data_thr, alpha=0.5):
    """Compute the global ground-truth-to-tracker assignment of true positives.

    Args:
        data_thr: dict mapping each threshold to that threshold's
            preprocessed sequence data.
        alpha: Similarity level passed to ``compute_matches``.

    Returns:
        dict of ``thr -> {timestep: {original_gt_id: original_tracker_id}}``.
        Every timestep is present; it maps to an empty dict when nothing was
        matched or when the threshold has no tracker or no ground-truth
        detections at all.
    """
    res = {
        thr: {t: {} for t in range(data_thr[thr]["num_timesteps"])}
        for thr in data_thr
    }

    for thr, data in data_thr.items():
        # An empty tracker or gt sequence only leaves THIS threshold empty;
        # the remaining thresholds must still be processed.
        if data["num_tk_overlap_dets"] == 0 or data["num_gt_dets"] == 0:
            continue

        # global alignment score
        ga_score, _, _ = self.compute_global_alignment_score(data)

        # calculate scores for each timestep
        for t, (gt_ids_t, tk_ids_t) in enumerate(
            zip(data["gt_ids"], data["tk_ids"])
        ):
            # get matches optimizing for TETA
            amatch_rows, amatch_cols = self.compute_matches(
                data, t, ga_score, gt_ids_t, tk_ids_t, alpha=alpha
            )
            # translate the relabelled ids back to the original ids
            gt_ids = [data["gt_id_map"][tid] for tid in gt_ids_t[amatch_rows[0]]]
            matched_ids = [
                data["tk_id_map"][tid] for tid in tk_ids_t[amatch_cols[0]]
            ]
            res[thr][t] = dict(zip(gt_ids, matched_ids))

    return res
|
||||
|
||||
def eval_sequence_single_thr(self, data, cls, cid2clsname, cls_fp_thr, thr):
    """Computes TETA metric for one threshold for one sequence.

    Args:
        data: Preprocessed sequence data for this threshold.
        cls: Class being evaluated; key into ``cls_fp_thr``.
        cid2clsname: Mapping from class id to class name, used to attribute
            wrongly classified matches to the predicted class.
        cls_fp_thr: Per-class classification false-positive counters,
            updated in place and returned.
        thr: Threshold in percent; ``thr / 100`` is the similarity a
            tracker detection needs to count as a localization candidate.

    Returns:
        Tuple ``(res, cls_fp_thr, class_info_list)``; the list is always
        empty.
    """
    num_loc = len(self.array_labels)
    num_cls = len(self.cls_array_labels)
    class_info_list = []
    res = {
        field: np.zeros(num_cls if field.startswith("Cls") else num_loc, dtype=float)
        for field in self.float_array_fields + self.integer_array_fields
    }

    # return empty result if tracker or gt sequence is empty
    tracker_empty = data["num_tk_overlap_dets"] == 0
    if tracker_empty or data["num_gt_dets"] == 0:
        if tracker_empty:
            # with no tracker detections every gt detection is missed
            res["Loc_FN"] = data["num_gt_dets"] * np.ones(num_loc, dtype=float)
        if self.exhaustive:
            cls_fp_thr[cls] = data["num_tk_cls_dets"] * np.ones(num_cls, dtype=float)
        return self._compute_final_fields(res), cls_fp_thr, class_info_list

    # global alignment score
    ga_score, gt_id_count, tk_id_count = self.compute_global_alignment_score(data)
    matches_counts = [np.zeros_like(ga_score) for _ in self.array_labels]
    min_sim = thr / 100

    per_timestep = zip(
        data["gt_ids"],
        data["tk_ids"],
        data["tk_overlap_ids"],
        data["tk_class_eval_tk_ids"],
    )
    # calculate scores for each timestep
    for t, (gt_ids_t, tk_ids_t, tk_overlap_ids_t, tk_cls_ids_t) in enumerate(
        per_timestep
    ):
        # deal with the case that there are no gt_det/tk_det in a timestep
        if len(gt_ids_t) == 0:
            if self.exhaustive:
                cls_fp_thr[cls] += len(tk_cls_ids_t)
            continue

        # get matches optimizing for TETA
        amatch_rows, amatch_cols = self.compute_matches(
            data, t, ga_score, gt_ids_t, tk_ids_t, list(self.array_labels)
        )

        # map overlap_ids to original ids.
        if len(tk_overlap_ids_t) != 0:
            order = np.argsort(tk_ids_t)
            # column of every overlapping tracker id inside sim_scores[t]
            columns = order[np.searchsorted(tk_ids_t, tk_overlap_ids_t, sorter=order)]
            overlap_sim = data["sim_scores"][t][:, columns]
            candidates = tk_overlap_ids_t[(overlap_sim >= min_sim).any(axis=0)]
            fpl_candidates_ori_ids_t = np.array(
                [data["tk_id_map"][tid] for tid in candidates]
            )
        else:
            fpl_candidates_ori_ids_t = []

        if self.exhaustive:
            cls_fp_thr[cls] += len(tk_cls_ids_t) - len(tk_overlap_ids_t)

        # calculate and accumulate basic statistics
        for a, alpha in enumerate(self.array_labels):
            match_row = amatch_rows[a]
            match_col = amatch_cols[a]
            num_matches = len(match_row)
            matched_ori_ids = {data["tk_id_map"][tid] for tid in tk_ids_t[match_col]}
            match_tk_cls = data["tk_classes"][t][match_col]
            wrong_tk_cls = match_tk_cls[match_tk_cls != data["gt_classes"][t]]
            num_class_and_det_matches = np.sum(match_tk_cls == data["gt_classes"][t])

            if alpha >= 0.5:
                # classification arrays start at alpha 0.5, i.e. 10 entries
                # into the localization alpha grid
                c = a - 10
                for cid in wrong_tk_cls:
                    if cid in cid2clsname:
                        cls_fp_thr[cid2clsname[cid]][c] += 1
                res["Cls_TP"][c] += num_class_and_det_matches
                res["Cls_FN"][c] += num_matches - num_class_and_det_matches

            res["Loc_TP"][a] += num_matches
            res["Loc_FN"][a] += len(gt_ids_t) - num_matches
            res["Loc_FP"][a] += len(set(fpl_candidates_ori_ids_t) - matched_ori_ids)

            if num_matches > 0:
                matches_counts[a][gt_ids_t[match_row], tk_ids_t[match_col]] += 1

    # calculate AssocA, AssocRe, AssocPr
    self.compute_association_scores(res, matches_counts, gt_id_count, tk_id_count)

    # calculate final scores
    res = self._compute_final_fields(res)
    return res, cls_fp_thr, class_info_list
|
||||
|
||||
def compute_global_alignment_score(self, data):
    """Compute the Jaccard alignment score between every gt and tracker id.

    Returns:
        Tuple ``(ga_score, gt_id_count, tk_id_count)``: the
        ``(num_gt_ids, num_tk_ids)`` score matrix, the number of timesteps
        each gt id occurs in (column vector) and the number each tracker id
        occurs in (row vector).
    """
    n_gt, n_tk = data["num_gt_ids"], data["num_tk_ids"]
    num_matches = np.zeros((n_gt, n_tk))
    gt_id_count = np.zeros((n_gt, 1))
    tk_id_count = np.zeros((1, n_tk))

    # loop through each timestep and accumulate global track info.
    for t, (gt_ids_t, tk_ids_t) in enumerate(zip(data["gt_ids"], data["tk_ids"])):
        # Potential matches between ids in this timestep: each similarity is
        # normalised by its row sum plus column sum minus itself.
        sim = data["sim_scores"][t]
        row_totals = sim.sum(1, keepdims=True)
        col_totals = sim.sum(0, keepdims=True)
        denom = col_totals + row_totals - sim
        valid = denom > (0 + EPS)
        weighted = np.zeros_like(sim)
        weighted[valid] = sim[valid] / denom[valid]
        num_matches[gt_ids_t[:, None], tk_ids_t[None, :]] += weighted

        # calculate total number of dets for each gt_id and tk_id.
        gt_id_count[gt_ids_t] += 1
        tk_id_count[0, tk_ids_t] += 1

    # Calculate overall Jaccard alignment score between IDs
    union = gt_id_count + tk_id_count - num_matches
    ga_score = num_matches / union
    return ga_score, gt_id_count, tk_id_count
|
||||
|
||||
def compute_matches(self, data, t, ga_score, gt_ids, tk_ids, alpha):
    """Match gt and tracker detections of timestep ``t`` for every alpha.

    One assignment maximising ``ga_score * similarity`` is computed with the
    Hungarian algorithm; for each alpha only the pairs whose similarity
    reaches that alpha (within ``EPS``) are kept.

    Args:
        alpha: A single similarity level or a ``list`` of levels.

    Returns:
        Tuple ``(rows, cols)`` of lists with one index array per alpha.
    """
    sim = data["sim_scores"][t]
    score_mat = ga_score[gt_ids[:, None], tk_ids[None, :]] * sim
    # negated because linear_sum_assignment minimises the total cost
    match_rows, match_cols = linear_sum_assignment(-score_mat)

    alphas = alpha if isinstance(alpha, list) else [alpha]
    matched_sim = sim[match_rows, match_cols]

    alpha_match_rows, alpha_match_cols = [], []
    for level in alphas:
        keep = matched_sim >= level - EPS
        alpha_match_rows.append(match_rows[keep])
        alpha_match_cols.append(match_cols[keep])
    return alpha_match_rows, alpha_match_cols
|
||||
|
||||
def compute_association_scores(self, res, matches_counts, gt_id_count, tk_id_count):
|
||||
"""Calculate association scores for each alpha.
|
||||
|
||||
First calculate scores per gt_id/tk_id combo,
|
||||
and then average over the number of detections.
|
||||
"""
|
||||
for a, _ in enumerate(self.array_labels):
|
||||
matches_count = matches_counts[a]
|
||||
ass_a = matches_count / np.maximum(
|
||||
1, gt_id_count + tk_id_count - matches_count
|
||||
)
|
||||
res["AssocA"][a] = np.sum(matches_count * ass_a) / np.maximum(
|
||||
1, res["Loc_TP"][a]
|
||||
)
|
||||
ass_re = matches_count / np.maximum(1, gt_id_count)
|
||||
res["AssocRe"][a] = np.sum(matches_count * ass_re) / np.maximum(
|
||||
1, res["Loc_TP"][a]
|
||||
)
|
||||
ass_pr = matches_count / np.maximum(1, tk_id_count)
|
||||
res["AssocPr"][a] = np.sum(matches_count * ass_pr) / np.maximum(
|
||||
1, res["Loc_TP"][a]
|
||||
)
|
||||
|
||||
@_timing.time
|
||||
def eval_sequence(self, data, cls, cls_id_name_mapping, cls_fp):
    """Evaluate a single sequence at every threshold present in ``data``.

    Returns:
        Tuple ``(res, cls_fp, class_info_dict)``; ``res`` and
        ``class_info_dict`` are keyed by threshold and ``cls_fp`` is the
        input mapping with each threshold's entry replaced by the updated
        counters.
    """
    res = {}
    class_info_dict = {}

    for thr, data_thr in data.items():
        outcome = self.eval_sequence_single_thr(
            data_thr, cls, cls_id_name_mapping, cls_fp[thr], thr
        )
        res[thr], cls_fp[thr], class_info_dict[thr] = outcome

    return res, cls_fp, class_info_dict
|
||||
|
||||
def combine_sequences(self, all_res):
    """Combine the per-sequence results of every threshold.

    The thresholds are taken from the first sequence's result; with no
    sequences at all, the single threshold 50 is used.

    Returns:
        dict of ``thr -> self._combine_sequences_thr({seq: result at thr})``.
    """
    if all_res:
        first_seq_res = list(all_res.values())[0]
        thresholds = list(first_seq_res.keys())
    else:
        thresholds = [50]

    # regroup from seq -> thr -> result to thr -> seq -> result
    by_threshold = {
        thr: {seq_key: seq_res[thr] for seq_key, seq_res in all_res.items()}
        for thr in thresholds
    }
    return {
        thr: self._combine_sequences_thr(by_threshold[thr]) for thr in thresholds
    }
|
||||
|
||||
def _combine_sequences_thr(self, all_res):
|
||||
"""Combines sequences over each threshold."""
|
||||
res = {}
|
||||
for field in self.integer_array_fields:
|
||||
res[field] = self._combine_sum(all_res, field)
|
||||
for field in ["AssocRe", "AssocPr", "AssocA"]:
|
||||
res[field] = self._combine_weighted_av(
|
||||
all_res, field, res, weight_field="Loc_TP"
|
||||
)
|
||||
res = self._compute_final_fields(res)
|
||||
return res
|
||||
|
||||
def combine_classes_class_averaged(self, all_res, ignore_empty=False):
    """Combine per-class results into one class-averaged result per threshold.

    ``all_res`` maps a class key to a ``{threshold: result}`` dict. The
    threshold list is taken from the first class; when ``all_res`` is empty
    a single threshold of 50 is used.

    ``ignore_empty`` is forwarded unchanged to
    ``self._combine_classes_class_averaged_thr``, which is called once per
    threshold with that threshold's per-class results.

    Returns:
        A dict mapping each threshold to the helper's return value.
    """
    thresholds = list(next(iter(all_res.values())).keys()) if all_res else [50]

    # Regroup from class -> threshold into threshold -> class.
    by_threshold = {
        thr: {cls_key: all_res[cls_key][thr] for cls_key in all_res}
        for thr in thresholds
    }

    return {
        thr: self._combine_classes_class_averaged_thr(
            per_cls, ignore_empty=ignore_empty
        )
        for thr, per_cls in by_threshold.items()
    }
|
||||
|
||||
def _combine_classes_class_averaged_thr(self, all_res, ignore_empty=False):
    """Combine the per-class results of a single threshold by class averaging.

    Args:
        all_res: Mapping from class key to that class's result dict.
        ignore_empty: When True, a class is left out of both the sums and
            the means if no element of ``Loc_TP + Loc_FN + Loc_FP`` exceeds
            ``EPS``.

    Returns:
        A dict holding, for every name in ``self.integer_array_fields``, the
        value of ``self._combine_sum`` over the selected classes, and for
        every name in ``self.float_array_fields``, the ``np.mean`` (axis 0)
        of that field over the selected classes.
    """

    def check_empty(val):
        """Returns True if empty."""
        return not (val["Loc_TP"] + val["Loc_FN"] + val["Loc_FP"] > 0 + EPS).any()

    # The set of classes to keep does not depend on the field being
    # combined, so select it once instead of re-filtering per field.
    if ignore_empty:
        selected = {k: v for k, v in all_res.items() if not check_empty(v)}
    else:
        selected = dict(all_res)

    res = {}
    for field in self.integer_array_fields:
        res[field] = self._combine_sum(selected, field)

    # NOTE(review): if every class is filtered out, np.mean is handed an
    # empty list here - confirm callers cope with the resulting value.
    for field in self.float_array_fields:
        res[field] = np.mean([v[field] for v in selected.values()], axis=0)

    return res
|
||||
|
||||
def combine_classes_det_averaged(self, all_res):
    """Combine per-class results into one detection-averaged result per threshold.

    ``all_res`` maps a class key to a ``{threshold: result}`` dict. The
    threshold list is taken from the first class; when ``all_res`` is empty
    a single threshold of 50 is used.

    Returns:
        A dict mapping each threshold to the value returned by
        ``self._combine_classes_det_averaged_thr`` for that threshold's
        per-class results.
    """
    thresholds = list(next(iter(all_res.values())).keys()) if all_res else [50]

    # Regroup from class -> threshold into threshold -> class.
    by_threshold = {
        thr: {cls_key: all_res[cls_key][thr] for cls_key in all_res}
        for thr in thresholds
    }

    return {
        thr: self._combine_classes_det_averaged_thr(per_cls)
        for thr, per_cls in by_threshold.items()
    }
|
||||
|
||||
def _combine_classes_det_averaged_thr(self, all_res):
    """Combine the per-class results of a single threshold.

    Every field in ``self.integer_array_fields`` is combined with
    ``self._combine_sum``; the three association fields are combined with
    ``self._combine_weighted_av`` using ``"Loc_TP"`` as the weight field.
    The combined dict is then passed through ``self._compute_final_fields``
    and the result of that call is returned.
    """
    combined = {
        name: self._combine_sum(all_res, name)
        for name in self.integer_array_fields
    }

    # The weighted average receives the running dict so it can read the
    # already-combined weight field.
    for name in ("AssocRe", "AssocPr", "AssocA"):
        combined[name] = self._combine_weighted_av(
            all_res, name, combined, weight_field="Loc_TP"
        )

    return self._compute_final_fields(combined)
|
||||
|
||||
@staticmethod
|
||||
def _compute_final_fields(res):
|
||||
"""Calculate final metric values.
|
||||
|
||||
This function is used both for both per-sequence calculation,
|
||||
and in combining values across sequences.
|
||||
"""
|
||||
# LocA
|
||||
res["LocRe"] = res["Loc_TP"] / np.maximum(1, res["Loc_TP"] + res["Loc_FN"])
|
||||
res["LocPr"] = res["Loc_TP"] / np.maximum(1, res["Loc_TP"] + res["Loc_FP"])
|
||||
res["LocA"] = res["Loc_TP"] / np.maximum(
|
||||
1, res["Loc_TP"] + res["Loc_FN"] + res["Loc_FP"]
|
||||
)
|
||||
|
||||
# ClsA
|
||||
res["ClsRe"] = res["Cls_TP"] / np.maximum(1, res["Cls_TP"] + res["Cls_FN"])
|
||||
res["ClsPr"] = res["Cls_TP"] / np.maximum(1, res["Cls_TP"] + res["Cls_FP"])
|
||||
res["ClsA"] = res["Cls_TP"] / np.maximum(
|
||||
1, res["Cls_TP"] + res["Cls_FN"] + res["Cls_FP"]
|
||||
)
|
||||
|
||||
res["ClsRe"] = np.mean(res["ClsRe"])
|
||||
res["ClsPr"] = np.mean(res["ClsPr"])
|
||||
res["ClsA"] = np.mean(res["ClsA"])
|
||||
|
||||
res["TETA"] = (res["LocA"] + res["AssocA"] + res["ClsA"]) / 3
|
||||
|
||||
return res
|
||||
|
||||
def print_summary_table(self, thr_res, thr, tracker, cls):
    """Print the summary table for one threshold / tracker / class.

    Emits a blank line, then a header row made of a title (metric name,
    threshold, tracker and class) followed by ``self.summary_fields``, then
    a ``"COMBINED"`` row followed by ``thr_res``. Both rows go through
    ``self._row_print``; ``thr_res`` is concatenated to a list, so it must
    itself be a list.
    """
    print("")
    title = f"{self.get_name()}{str(thr)}: {tracker}-{cls}"
    header_row = [title] + self.summary_fields
    self._row_print(header_row)
    self._row_print(["COMBINED"] + thr_res)
|
||||
46
sam3/eval/teta_eval_toolkit/utils.py
Normal file
46
sam3/eval/teta_eval_toolkit/utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
|
||||
import csv
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def validate_metrics_list(metrics_list):
    """Return the metric names after checking names and fields for clashes.

    Names come from each metric's ``get_name()``; fields come from each
    metric's ``fields`` attribute.

    Raises:
        TrackEvalException: If two metrics report the same name, or if any
            field name appears more than once in the concatenation of all
            metrics' ``fields``.
    """
    names = [metric.get_name() for metric in metrics_list]
    # Names are checked first, before any ``fields`` attribute is read.
    if len(set(names)) != len(names):
        raise TrackEvalException(
            "Code being run with multiple metrics of the same name"
        )

    all_fields = [field for metric in metrics_list for field in metric.fields]
    if len(set(all_fields)) != len(all_fields):
        raise TrackEvalException(
            "Code being run with multiple metrics with fields of the same name"
        )

    return names
|
||||
|
||||
|
||||
def get_track_id_str(ann):
    """Return the key under which ``ann`` stores its track ID.

    The keys ``"track_id"``, ``"instance_id"`` and ``"scalabel_id"`` are
    tried in that order and the first one present in ``ann`` is returned.

    Raises:
        AssertionError: If ``ann`` contains none of the three keys.
    """
    for key in ("track_id", "instance_id", "scalabel_id"):
        if key in ann:
            return key
    # Raised explicitly rather than via ``assert False`` so the check is
    # not stripped when Python runs with -O (which would otherwise fall
    # through to returning an unbound local).
    raise AssertionError("No track/instance ID.")
|
||||
|
||||
|
||||
class TrackEvalException(Exception):
    """Exception type for expected, anticipated evaluation errors.

    Adds nothing to ``Exception``; it exists so callers can catch these
    errors separately from unexpected failures.
    """
|
||||
Reference in New Issue
Block a user