# Differential Revision: D90237984
# fbshipit-source-id: 526fd760f303bf31be4f743bdcd77760496de0de
# (file viewer metadata: 51 lines, 1.6 KiB, Python)
# flake8: noqa
|
|
|
|
# pyre-unsafe
|
|
|
|
from .. import _timing
|
|
from ._base_metric import _BaseMetric
|
|
|
|
|
|
class Count(_BaseMetric):
    """Metric that reports raw tracker / ground-truth detection and ID counts.

    The aggregated fields are ``Dets``, ``GT_Dets``, ``IDs`` and ``GT_IDs``.
    Every ``combine_*`` method aggregates them through ``self._combine_sum``,
    which is not defined in this class (presumably inherited from
    ``_BaseMetric`` -- verify there).
    """

    def __init__(self, config=None):
        """Initialise the base metric and register the count field names.

        Args:
            config: Accepted for signature compatibility; this constructor
                does not read it.
        """
        super().__init__()
        # All three attributes refer to the same list object.
        counted = ["Dets", "GT_Dets", "IDs", "GT_IDs"]
        self.integer_fields = counted
        self.fields = counted
        self.summary_fields = counted

    @_timing.time
    def eval_sequence(self, data):
        """Return the counts for a single sequence.

        Args:
            data: Mapping that must provide ``num_tracker_dets``,
                ``num_gt_dets``, ``num_tracker_ids``, ``num_gt_ids`` and
                ``num_timesteps``; a missing key raises ``KeyError``.

        Returns:
            dict with keys ``Dets``, ``GT_Dets``, ``IDs``, ``GT_IDs`` and
            ``Frames``. Note that ``Frames`` is not listed in
            ``integer_fields``, so the ``combine_*`` methods do not carry it
            forward.
        """
        # (output key, input key) pairs, in the order the values are read.
        key_pairs = (
            ("Dets", "num_tracker_dets"),
            ("GT_Dets", "num_gt_dets"),
            ("IDs", "num_tracker_ids"),
            ("GT_IDs", "num_gt_ids"),
            ("Frames", "num_timesteps"),
        )
        return {out_key: data[in_key] for out_key, in_key in key_pairs}

    def combine_sequences(self, all_res):
        """Aggregate per-sequence results into one dict.

        Each name in ``integer_fields`` is passed, together with ``all_res``,
        to ``self._combine_sum``; the results are returned keyed by field.
        """
        return {
            name: self._combine_sum(all_res, name)
            for name in self.integer_fields
        }

    def combine_classes_class_averaged(self, all_res, ignore_empty_classes=None):
        """Aggregate per-class results for the class-averaged summary.

        Despite the method name, no averaging happens here: every name in
        ``integer_fields`` is aggregated with ``self._combine_sum``, exactly
        as in ``combine_sequences``. ``ignore_empty_classes`` is accepted but
        not used.
        """
        return {
            name: self._combine_sum(all_res, name)
            for name in self.integer_fields
        }

    def combine_classes_det_averaged(self, all_res):
        """Aggregate per-class results for the detection-averaged summary.

        Despite the method name, no averaging happens here: every name in
        ``integer_fields`` is aggregated with ``self._combine_sum``, exactly
        as in ``combine_sequences``.
        """
        return {
            name: self._combine_sum(all_res, name)
            for name in self.integer_fields
        }
|