Differential Revision: D90237984 fbshipit-source-id: 526fd760f303bf31be4f743bdcd77760496de0de
49 lines
1.3 KiB
Python
49 lines
1.3 KiB
Python
# fmt: off
|
|
# flake8: noqa
|
|
|
|
# pyre-unsafe
|
|
|
|
import csv
|
|
import os
|
|
from collections import OrderedDict
|
|
|
|
|
|
def validate_metrics_list(metrics_list):
    """Check that metric names and metric field names are unique.

    Args:
        metrics_list: Metric objects; each must provide a ``get_name()``
            method and a ``fields`` iterable of hashable field names.
            The collection is traversed twice.

    Returns:
        The list of metric names, in the order of ``metrics_list``.

    Raises:
        TrackEvalException: If two metrics share a name, or if a field
            name appears more than once across all metrics.
    """
    names = [entry.get_name() for entry in metrics_list]
    # A set drops duplicates, so a shorter set means a repeated name.
    if len(set(names)) != len(names):
        raise TrackEvalException(
            "Code being run with multiple metrics of the same name"
        )

    # Flatten every metric's fields into one list before comparing.
    all_fields = [field for entry in metrics_list for field in entry.fields]
    if len(set(all_fields)) != len(all_fields):
        raise TrackEvalException(
            "Code being run with multiple metrics with fields of the same name"
        )

    return names
|
|
|
|
|
|
def get_track_id_str(ann):
    """Return the key under which an annotation stores its track ID.

    The candidate keys are tried in a fixed priority order: ``"track_id"``,
    then ``"instance_id"``, then ``"scalabel_id"``.

    Args:
        ann: Annotation container supporting the ``in`` operator
            (e.g. a dict).

    Returns:
        The first candidate key present in ``ann``.

    Raises:
        AssertionError: If none of the candidate keys is present.
    """
    for key in ("track_id", "instance_id", "scalabel_id"):
        if key in ann:
            return key
    # Raised explicitly rather than via ``assert False`` so the check
    # still fires when Python runs with -O (which strips assert statements).
    raise AssertionError("No track/instance ID.")
|
|
|
|
|
|
class TrackEvalException(Exception):
    """Exception type used to signal expected error conditions.

    Having a dedicated subclass lets callers catch these errors separately
    from other exceptions.
    """
|