Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
46
sam3/eval/teta_eval_toolkit/utils.py
Normal file
46
sam3/eval/teta_eval_toolkit/utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
|
||||
import csv
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def validate_metrics_list(metrics_list):
|
||||
"""Get names of metric class and ensures they are unique, further checks that the fields within each metric class
|
||||
do not have overlapping names.
|
||||
"""
|
||||
metric_names = [metric.get_name() for metric in metrics_list]
|
||||
# check metric names are unique
|
||||
if len(metric_names) != len(set(metric_names)):
|
||||
raise TrackEvalException(
|
||||
"Code being run with multiple metrics of the same name"
|
||||
)
|
||||
fields = []
|
||||
for m in metrics_list:
|
||||
fields += m.fields
|
||||
# check metric fields are unique
|
||||
if len(fields) != len(set(fields)):
|
||||
raise TrackEvalException(
|
||||
"Code being run with multiple metrics with fields of the same name"
|
||||
)
|
||||
return metric_names
|
||||
|
||||
|
||||
def get_track_id_str(ann):
|
||||
"""Get name of track ID in annotation."""
|
||||
if "track_id" in ann:
|
||||
tk_str = "track_id"
|
||||
elif "instance_id" in ann:
|
||||
tk_str = "instance_id"
|
||||
elif "scalabel_id" in ann:
|
||||
tk_str = "scalabel_id"
|
||||
else:
|
||||
assert False, "No track/instance ID."
|
||||
return tk_str
|
||||
|
||||
|
||||
class TrackEvalException(Exception):
|
||||
"""Custom exception for catching expected errors."""
|
||||
|
||||
...
|
||||
Reference in New Issue
Block a user