diff --git a/sam3/train/data/collator.py b/sam3/train/data/collator.py index 4a0f2e8..38d031d 100644 --- a/sam3/train/data/collator.py +++ b/sam3/train/data/collator.py @@ -194,7 +194,7 @@ def collate_fn_api( offset_img_id = 0 offset_query_id = [0 for _ in range(num_stages)] - for i, data in enumerate(batch): + for data in batch: img_batch.extend([img.data for img in data.images]) if data.raw_images is not None: @@ -209,7 +209,7 @@ def collate_fn_api( datapoint_query_id_2_stage_query_id.append(offset_query_id[stage_id]) offset_query_id[stage_id] += 1 - for j, q in enumerate(data.find_queries): + for q in data.find_queries: stage_id = q.query_processing_order stages[stage_id].img_ids.append(q.image_id + offset_img_id) if q.query_text not in text_batch: