diff --git a/sam3/train/nms_helper.py b/sam3/train/nms_helper.py index cd5b6dc..d6378bc 100644 --- a/sam3/train/nms_helper.py +++ b/sam3/train/nms_helper.py @@ -33,7 +33,7 @@ def convert_bbox_format(bbox: list) -> List[float]: # -------------------- Track-level NMS -------------------- def process_track_level_nms(video_groups: Dict, nms_threshold: float) -> Dict: """Apply track-level NMS to all videos""" - for video_id, tracks in video_groups.items(): + for tracks in video_groups.values(): track_detections = [] # Process tracks @@ -76,7 +76,7 @@ def process_track_level_nms(video_groups: Dict, nms_threshold: float) -> Dict: # -------------------- Frame-level NMS -------------------- def process_frame_level_nms(video_groups: Dict, nms_threshold: float) -> Dict: """Apply frame-level NMS to all videos""" - for video_id, tracks in video_groups.items(): + for tracks in video_groups.values(): if not tracks: continue