Initial commit

fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
facebook-github-bot
2025-11-18 23:07:42 -08:00
commit a13e358df4
504 changed files with 122758 additions and 0 deletions

View File

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

View File

@@ -0,0 +1,465 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import json
from collections import defaultdict
from typing import Dict, List, Tuple
import torch
from pycocotools import mask as mask_util
# ============================================================================
# Utility Functions
# ============================================================================
def convert_boxlist_to_normalized_tensor(box_list, image_width, image_height):
    """
    Converts a list of bounding boxes to a normalized PyTorch tensor.

    Args:
        box_list (list of list or tuples): Each box is [x_min, y_min, x_max, y_max].
        image_width (int or float): Width of the image.
        image_height (int or float): Height of the image.

    Returns:
        torch.Tensor: Normalized tensor of shape (N, 4), values in [0, 1].
        An empty ``box_list`` yields an empty (0, 4) tensor.
    """
    # reshape(-1, 4) keeps the function well-defined for an empty box list,
    # where torch.tensor([]) would produce a 1-D tensor that cannot be
    # indexed with [:, [0, 2]].
    boxes = torch.tensor(box_list, dtype=torch.float32).reshape(-1, 4)
    boxes[:, [0, 2]] /= image_width  # normalize x coordinates
    boxes[:, [1, 3]] /= image_height  # normalize y coordinates
    return boxes.clamp(0, 1)
def load_coco_and_group_by_image(json_path: str) -> Tuple[List[Dict], Dict[int, str]]:
    """
    Load a COCO JSON file and group its annotations by image.

    Args:
        json_path (str): Path to COCO JSON file.

    Returns:
        Tuple containing:
            - List of dicts with 'image' and 'annotations' keys, ordered by
              ascending image id (images without annotations get an empty list)
            - Dict mapping category IDs to category names
    """
    with open(json_path, "r") as f:
        coco = json.load(f)
    image_by_id = {img["id"]: img for img in coco["images"]}
    grouped_anns = defaultdict(list)
    for annotation in coco["annotations"]:
        grouped_anns[annotation["image_id"]].append(annotation)
    grouped = [
        {"image": image_by_id[img_id], "annotations": grouped_anns.get(img_id, [])}
        for img_id in sorted(image_by_id)
    ]
    category_names = {cat["id"]: cat["name"] for cat in coco["categories"]}
    return grouped, category_names
def ann_to_rle(segm, im_info: Dict) -> Dict:
    """
    Convert a segmentation (polygons or uncompressed RLE) into compressed RLE.

    Args:
        segm: Segmentation data — either a list of polygons or an RLE dict
            (compressed or uncompressed, distinguished by the type of 'counts')
        im_info (dict): Image info containing 'height' and 'width'

    Returns:
        Compressed RLE-encoded segmentation dict
    """
    height, width = im_info["height"], im_info["width"]
    if isinstance(segm, list):
        # A list of polygon parts: rasterize each one and merge them into a
        # single RLE mask.
        return mask_util.merge(mask_util.frPyObjects(segm, height, width))
    if isinstance(segm["counts"], list):
        # Uncompressed RLE: compress it.
        return mask_util.frPyObjects(segm, height, width)
    # Already compressed RLE; pass through unchanged.
    return segm
# ============================================================================
# COCO Training API
# ============================================================================
class COCO_FROM_JSON:
    """
    COCO training API for loading box-only annotations from JSON.

    Groups all annotations per image and creates one query per category.
    Categories are split into chunks of ``category_chunk_size``; each
    datapoint corresponds to one (image, category-chunk) pair, so datapoint
    ``idx`` maps to image ``idx // num_chunks`` and chunk ``idx % num_chunks``.
    """

    def __init__(
        self,
        annotation_file,
        prompts=None,
        include_negatives=True,
        category_chunk_size=None,
    ):
        """
        Initialize the COCO training API.

        Args:
            annotation_file (str): Path to COCO JSON annotation file
            prompts: Optional string holding a literal list of
                ``{"id": ..., "name": ...}`` dicts that override the default
                per-category query text
            include_negatives (bool): Whether to include negative examples
                (categories with no instances)
            category_chunk_size (int, optional): Number of categories per
                datapoint; defaults to all categories in a single chunk
        """
        self._raw_data, self._cat_idx_to_text = load_coco_and_group_by_image(
            annotation_file
        )
        self._sorted_cat_ids = sorted(self._cat_idx_to_text.keys())
        self.include_negatives = include_negatives
        self.category_chunk_size = (
            category_chunk_size
            if category_chunk_size is not None
            else len(self._sorted_cat_ids)
        )
        self.category_chunks = [
            self._sorted_cat_ids[i : i + self.category_chunk_size]
            for i in range(0, len(self._sorted_cat_ids), self.category_chunk_size)
        ]
        self.prompts = None
        if prompts is not None:
            # SECURITY: parse the prompt spec with ast.literal_eval instead of
            # eval(). literal_eval accepts the same literal syntax (lists,
            # dicts, strings, numbers) but cannot execute arbitrary code
            # embedded in the config string.
            import ast

            prompt_list = ast.literal_eval(prompts)
            self.prompts = {
                int(loc_dict["id"]): loc_dict["name"] for loc_dict in prompt_list
            }
            assert len(self.prompts) == len(
                self._sorted_cat_ids
            ), "Number of prompts must match number of categories"

    def getDatapointIds(self):
        """Return all datapoint indices (one per image/category-chunk pair)."""
        return list(range(len(self._raw_data) * len(self.category_chunks)))

    def loadQueriesAndAnnotationsFromDatapoint(self, idx):
        """
        Load queries and annotations for a specific datapoint.

        Args:
            idx (int): Datapoint index

        Returns:
            Tuple of (queries, annotations) lists. One query is emitted per
            category in the datapoint's chunk (skipping empty categories when
            ``include_negatives`` is False); each query references the ids of
            its annotations via ``object_ids_output``.
        """
        img_idx = idx // len(self.category_chunks)
        chunk_idx = idx % len(self.category_chunks)
        cat_chunk = self.category_chunks[chunk_idx]
        queries = []
        annotations = []
        query_template = {
            "id": None,
            "original_cat_id": None,
            "object_ids_output": None,
            "query_text": None,
            "query_processing_order": 0,
            "ptr_x_query_id": None,
            "ptr_y_query_id": None,
            "image_id": 0,  # Single image per datapoint
            "input_box": None,
            "input_box_label": None,
            "input_points": None,
            "is_exhaustive": True,
        }
        annot_template = {
            "image_id": 0,
            "bbox": None,  # Normalized bbox in xywh
            "area": None,  # Area of the normalized bbox (w * h, in [0, 1] units)
            "segmentation": None,  # RLE encoded
            "object_id": None,
            "is_crowd": None,
            "id": None,
        }
        raw_annotations = self._raw_data[img_idx]["annotations"]
        image_info = self._raw_data[img_idx]["image"]
        width, height = image_info["width"], image_info["height"]
        # Group this image's annotations by category id
        cat_id_to_anns = defaultdict(list)
        for ann in raw_annotations:
            cat_id_to_anns[ann["category_id"]].append(ann)
        for cat_id in cat_chunk:
            anns = cat_id_to_anns[cat_id]
            if len(anns) == 0 and not self.include_negatives:
                continue
            cur_ann_ids = []
            # Create annotations for this category
            for ann in anns:
                annotation = annot_template.copy()
                annotation["id"] = len(annotations)
                annotation["object_id"] = annotation["id"]
                annotation["is_crowd"] = ann["iscrowd"]
                normalized_boxes = convert_boxlist_to_normalized_tensor(
                    [ann["bbox"]], width, height
                )
                bbox = normalized_boxes[0]
                annotation["area"] = (bbox[2] * bbox[3]).item()
                annotation["bbox"] = bbox
                if (
                    "segmentation" in ann
                    and ann["segmentation"] is not None
                    and ann["segmentation"] != []
                ):
                    annotation["segmentation"] = ann_to_rle(
                        ann["segmentation"], im_info=image_info
                    )
                annotations.append(annotation)
                cur_ann_ids.append(annotation["id"])
            # Create the query for this category
            query = query_template.copy()
            query["id"] = len(queries)
            query["original_cat_id"] = cat_id
            query["query_text"] = (
                self._cat_idx_to_text[cat_id]
                if self.prompts is None
                else self.prompts[cat_id]
            )
            query["object_ids_output"] = cur_ann_ids
            queries.append(query)
        return queries, annotations

    def loadImagesFromDatapoint(self, idx):
        """
        Load image information for a specific datapoint.

        Args:
            idx (int): Datapoint index

        Returns:
            List containing a single image info dict (id 0), carrying the
            original COCO image id for evaluation.
        """
        img_idx = idx // len(self.category_chunks)
        img_data = self._raw_data[img_idx]["image"]
        images = [
            {
                "id": 0,
                "file_name": img_data["file_name"],
                "original_img_id": img_data["id"],
                "coco_img_id": img_data["id"],
            }
        ]
        return images
# ============================================================================
# SAM3 Evaluation APIs
# ============================================================================
class SAM3_EVAL_API_FROM_JSON_NP:
    """SAM3 evaluation API that serves noun-phrase queries loaded from JSON."""

    def __init__(self, annotation_file):
        """
        Initialize the SAM3 evaluation API.

        Args:
            annotation_file (str): Path to SAM3 JSON annotation file
        """
        with open(annotation_file, "r") as handle:
            self._image_data = json.load(handle)["images"]

    def getDatapointIds(self):
        """Return all datapoint indices (one per image record)."""
        return list(range(len(self._image_data)))

    def loadQueriesAndAnnotationsFromDatapoint(self, idx):
        """
        Build the (queries, annotations) pair for datapoint ``idx``.

        Evaluation datapoints carry no ground-truth annotations, so the
        annotation list is always empty and the single query has an empty
        ``object_ids_output``.
        """
        record = self._image_data[idx]
        query = {
            "id": 0,
            "original_cat_id": int(record["queried_category"]),
            "object_ids_output": [],
            "query_text": record["text_input"],
            "query_processing_order": 0,
            "ptr_x_query_id": None,
            "ptr_y_query_id": None,
            "image_id": 0,
            "input_box": None,
            "input_box_label": None,
            "input_points": None,
            "is_exhaustive": True,
        }
        return [query], []

    def loadImagesFromDatapoint(self, idx):
        """
        Return the single-image metadata list for datapoint ``idx``.

        The in-datapoint image id is always 0; the original dataset id is
        carried through for evaluation.
        """
        record = self._image_data[idx]
        return [
            {
                "id": 0,
                "file_name": record["file_name"],
                "original_img_id": record["id"],
                "coco_img_id": record["id"],
            }
        ]
class SAM3_VEVAL_API_FROM_JSON_NP:
    """SAM3 video evaluation API that serves noun-phrase queries loaded from JSON."""

    def __init__(self, annotation_file):
        """
        Initialize the SAM3 video evaluation API.

        Args:
            annotation_file (str): Path to SAM3 video JSON annotation file
        """
        with open(annotation_file, "r") as handle:
            data = json.load(handle)
        assert "video_np_pairs" in data, "Incorrect data format"
        self._video_data = data["videos"]
        self._cat_id_to_np = {cat["id"]: cat["name"] for cat in data["categories"]}
        self._video_id_to_np_ids = defaultdict(list)
        for pair in data["video_np_pairs"]:
            self._video_id_to_np_ids[pair["video_id"]].append(pair["category_id"])
            # Cross-check the denormalized noun phrase against the category table.
            assert (
                self._cat_id_to_np[pair["category_id"]] == pair["noun_phrase"]
            ), "Category name does not match text input"

    def getDatapointIds(self):
        """Return all datapoint indices (one per video)."""
        return list(range(len(self._video_data)))

    def loadQueriesAndAnnotationsFromDatapoint(self, idx):
        """
        Build the (queries, annotations) pair for video datapoint ``idx``.

        One query is emitted per (noun phrase, frame) combination, with the
        frame index doubling as both ``image_id`` and processing order.
        Evaluation datapoints carry no ground-truth annotations.
        """
        video = self._video_data[idx]
        queries = []
        for np_id in self._video_id_to_np_ids[video["id"]]:
            phrase = self._cat_id_to_np[np_id]
            for frame_idx in range(len(video["file_names"])):
                queries.append(
                    {
                        "id": len(queries),
                        "original_cat_id": np_id,
                        "object_ids_output": [],
                        "query_text": phrase,
                        "query_processing_order": frame_idx,
                        "ptr_x_query_id": None,
                        "ptr_y_query_id": None,
                        "image_id": frame_idx,
                        "input_box": None,
                        "input_box_label": None,
                        "input_points": None,
                        "is_exhaustive": True,
                    }
                )
        return queries, []

    def loadImagesFromDatapoint(self, idx):
        """
        Return per-frame image metadata for video datapoint ``idx``.

        Every frame shares the video's original id; the in-datapoint id is
        the frame index.
        """
        video = self._video_data[idx]
        return [
            {
                "id": frame_idx,
                "file_name": file_name,
                "original_img_id": video["id"],
                "coco_img_id": video["id"],
            }
            for frame_idx, file_name in enumerate(video["file_names"])
        ]

360
sam3/train/data/collator.py Normal file
View File

@@ -0,0 +1,360 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
from typing import Any, get_args, get_origin, List, Union
import torch
from sam3.model.data_misc import (
BatchedDatapoint,
BatchedFindTarget,
BatchedInferenceMetadata,
FindStage,
)
from .sam3_image_dataset import Datapoint
# Marker type for dataclass fields that start life as Python lists and are
# converted to torch.Tensors at collation time by convert_my_tensors().
MyTensor = Union[torch.Tensor, List[Any]]


def convert_my_tensors(obj):
    """Recursively convert every ``MyTensor``-annotated field of a dataclass
    from a Python list into a torch.Tensor, in place.

    For each field ``f`` annotated as ``MyTensor`` (or ``Optional[MyTensor]``),
    the target dtype is read from a companion attribute named ``f__type`` on
    the same object. Lists of tensors are stacked; plain lists go through
    ``torch.as_tensor``. Nested dataclass fields are converted recursively.
    Returns ``obj`` (mutated).

    NOTE(review): the ``field.type != MyTensor`` comparison assumes field
    annotations are evaluated type objects, i.e. the defining module must not
    use ``from __future__ import annotations`` — confirm if that changes.
    """
    def is_optional_field(field) -> bool:
        # True for Optional[X] / Union[..., None] annotations.
        return get_origin(field) is Union and type(None) in get_args(field)

    for field in fields(obj):
        if is_dataclass(getattr(obj, field.name)):
            # Recurse into nested dataclasses (e.g. stage/target containers).
            convert_my_tensors(getattr(obj, field.name))
            continue
        field_type = field.type
        if is_optional_field(field.type):
            field_type = Union[get_args(field.type)[:-1]]  # Get the Optional field type
        if field_type != MyTensor or getattr(obj, field.name) is None:
            # Not a convertible field (or explicitly left as None): skip.
            continue
        elif len(getattr(obj, field.name)) and isinstance(
            getattr(obj, field.name)[0], torch.Tensor
        ):
            # Non-empty list of tensors: stack and cast to the companion dtype.
            stack_dim = 0
            if field.name in [
                "input_boxes",
                "input_boxes_label",
            ]:
                # Box prompts are stacked along dim 1 rather than dim 0 —
                # presumably so the per-query axis lands second; TODO confirm
                # against the model's expected input layout.
                stack_dim = 1
            setattr(
                obj,
                field.name,
                torch.stack(getattr(obj, field.name), dim=stack_dim).to(
                    getattr(obj, field.name + "__type")
                ),
            )
        else:
            # Plain (possibly empty) list of scalars/lists: build a new tensor
            # with the companion dtype.
            setattr(
                obj,
                field.name,
                torch.as_tensor(
                    getattr(obj, field.name), dtype=getattr(obj, field.name + "__type")
                ),
            )
    return obj
def packed_to_padded_naive(boxes_packed, num_boxes, fill_value=0):
    """
    Convert a packed tensor of bounding boxes to a padded tensor of bounding
    boxes. Naive implementation using a loop.

    Inputs:
    - boxes_packed: Tensor of shape (N_1 + ... + N_B, *) — trailing dims are
      preserved (4 for boxes, none for object ids)
    - num_boxes: Tensor of shape (B,) where num_boxes[i] = N_i
    - fill_value: Value used for the padded entries (default 0)

    Returns:
    - boxes_padded: Tensor of shape (B, N_max, *) where N_max = max_i N_i
      (0 when the batch is empty)
    """
    Ns = num_boxes.tolist()
    B = len(Ns)
    # `default=0` keeps the function well-defined for an empty batch, where
    # plain max() over an empty list would raise ValueError.
    max_n = max(Ns, default=0)
    # new_full handles both the zero and non-zero fill cases in one call,
    # preserving dtype/device of the packed input.
    boxes_padded = boxes_packed.new_full((B, max_n, *boxes_packed.shape[1:]), fill_value)
    prev_idx = 0
    for i, n in enumerate(Ns):
        next_idx = prev_idx + n
        boxes_padded[i, :n] = boxes_packed[prev_idx:next_idx]
        prev_idx = next_idx
    return boxes_padded
def pad_tensor_list_to_longest(
    tensors: List[torch.Tensor], dim=0, pad_val=0
) -> List[torch.Tensor]:
    """Right-pad every tensor in ``tensors`` with ``pad_val`` so all of them
    match the longest size along ``dim``. Mutates the list in place and
    returns it; an empty list is returned unchanged."""
    if not tensors:
        return tensors
    target_len = max(t.shape[dim] for t in tensors)
    for idx, tensor in enumerate(tensors):
        rank = len(tensor.shape)
        # Count of dimensions to the right of `dim` (works for negative dims);
        # F.pad's spec is ordered from the last dimension backwards, so those
        # dims get (0, 0) pairs before the (0, n_pad) pair for `dim`.
        trailing_dims = (rank - 1) - (rank + dim) % rank
        n_pad = target_len - tensor.shape[dim]
        pad_spec = tuple([0, 0] * trailing_dims + [0, n_pad])
        tensors[idx] = torch.nn.functional.pad(tensor, pad_spec, value=pad_val)
    return tensors
def collate_fn_api_with_chunking(
    batch,
    num_chunks,
    dict_key,
    with_seg_masks=False,
    input_points_embedding_dim=257,
    repeats: int = 0,
    load_image_in_fp16: bool = False,
):
    """Split ``batch`` into ``num_chunks`` sub-batches and collate each one
    independently with :func:`collate_fn_api`.

    The split is round-robin (``batch[i::num_chunks]``), so chunk sizes differ
    by at most one. Returns the list of collated chunk results.
    """
    assert num_chunks >= 1, "num_chunks must be >= 1"
    collated_chunks = []
    for start in range(num_chunks):
        chunk = batch[start::num_chunks]
        collated_chunks.append(
            collate_fn_api(
                chunk,
                dict_key,
                with_seg_masks,
                input_points_embedding_dim,
                repeats,
                load_image_in_fp16,
            )
        )
    return collated_chunks
def collate_fn_api(
    batch: List[Datapoint],
    dict_key,
    with_seg_masks=False,
    input_points_embedding_dim=257,
    repeats: int = 0,
    load_image_in_fp16: bool = False,
):
    """Collate a list of Datapoints into a single BatchedDatapoint.

    Queries are bucketed by ``query_processing_order`` into stages. For each
    stage this builds the FindStage inputs (image/text ids, box and point
    prompts with their padding masks), the BatchedFindTarget ground truth
    (boxes, object ids, optional segmentation masks) and the per-query
    BatchedInferenceMetadata, then converts list fields into tensors.

    Args:
        batch: Datapoint samples to collate. All images must share one size.
        dict_key: Key under which the BatchedDatapoint is returned.
        with_seg_masks: When False, targets' ``segments``/``is_valid_segment``
            are set to None instead of mask lists.
        input_points_embedding_dim: Width of the empty placeholder used when a
            query has no input points.
        repeats: When > 0, target boxes are additionally replicated this many
            times into ``repeated_boxes``.
        load_image_in_fp16: Cast the stacked image batch to half precision.

    Returns:
        ``{dict_key: BatchedDatapoint(...)}``
    """
    # img_batch = torch.stack(sum([[img.data for img in v.images] for v in batch], []))
    img_batch = []
    text_batch = []
    raw_images = None
    # One stage per distinct processing order; orders are assumed to be dense
    # starting at 0 (num_stages = max order + 1).
    num_stages = (
        max(q.query_processing_order for data in batch for q in data.find_queries) + 1
    )
    # Per-stage accumulators, filled with Python lists and converted to
    # tensors at the end by convert_my_tensors().
    stages = [
        FindStage(
            img_ids=[],
            text_ids=[],
            input_boxes=[],
            input_boxes_label=[],
            input_boxes_mask=[],
            input_points=[],
            input_points_mask=[],
            object_ids=[],
        )
        for _ in range(num_stages)
    ]
    find_targets = [
        BatchedFindTarget(
            num_boxes=[],
            boxes=[],
            boxes_padded=[],
            is_exhaustive=[],
            segments=[],
            semantic_segments=[],
            is_valid_segment=[],
            repeated_boxes=[],
            object_ids=[],
            object_ids_padded=[],
        )
        for _ in range(num_stages)
    ]
    find_metadatas = [
        BatchedInferenceMetadata(
            coco_image_id=[],
            original_size=[],
            object_id=[],
            frame_index=[],
            original_image_id=[],
            original_category_id=[],
            is_conditioning_only=[],
        )
        for _ in range(num_stages)
    ]
    # Images from all datapoints are concatenated; per-datapoint image ids are
    # shifted by this running offset to stay unique in the batch.
    offset_img_id = 0
    offset_query_id = [0 for _ in range(num_stages)]
    for i, data in enumerate(batch):
        img_batch.extend([img.data for img in data.images])
        if data.raw_images is not None:
            if raw_images is None:
                raw_images = []
            raw_images.extend(data.raw_images)
        # Conversion of query_ids indexing in a datapoint to query_ids indexing in a stage
        # NOTE(review): this mapping is built (and offset_query_id advanced)
        # but never read below — presumably kept for pointer-query support.
        datapoint_query_id_2_stage_query_id = []
        for q in data.find_queries:
            stage_id = q.query_processing_order
            datapoint_query_id_2_stage_query_id.append(offset_query_id[stage_id])
            offset_query_id[stage_id] += 1
        for j, q in enumerate(data.find_queries):
            stage_id = q.query_processing_order
            stages[stage_id].img_ids.append(q.image_id + offset_img_id)
            # Texts are deduplicated across the whole batch; queries reference
            # them by index into text_batch.
            if q.query_text not in text_batch:
                text_batch.append(q.query_text)
            stages[stage_id].text_ids.append(text_batch.index(q.query_text))
            assert (
                q.inference_metadata is not None
            ), "inference_metadata must be provided when FindQueryLoaded is created."
            # Copy every metadata field into the per-stage columnar lists.
            for f in fields(q.inference_metadata):
                getattr(find_metadatas[stage_id], f.name).append(
                    getattr(q.inference_metadata, f.name)
                )
            if q.input_bbox is not None:
                # Box prompt: reshape flat coords into (nb_boxes, 4) and emit
                # an all-False (= not masked) validity mask.
                assert q.input_bbox.numel() % 4 == 0
                assert q.input_bbox_label is not None
                nb_boxes = q.input_bbox.numel() // 4
                assert len(q.input_bbox_label) == nb_boxes
                stages[stage_id].input_boxes.append(q.input_bbox.view(nb_boxes, 4))
                stages[stage_id].input_boxes_label.append(
                    q.input_bbox_label.view(nb_boxes)
                )
                stages[stage_id].input_boxes_mask.append(
                    torch.zeros(nb_boxes, dtype=torch.bool)
                )
            else:
                # No box prompt: empty placeholders so per-query lists stay
                # aligned; padding later fills them to the longest length.
                stages[stage_id].input_boxes.append(torch.zeros(0, 4))
                stages[stage_id].input_boxes_label.append(
                    torch.zeros(0, dtype=torch.bool)
                )
                stages[stage_id].input_boxes_mask.append(
                    torch.ones(0, dtype=torch.bool)
                )
            if q.input_points is not None:
                stages[stage_id].input_points.append(
                    q.input_points.squeeze(0)  # Strip a trivial batch index
                )
                # All masks will be padded up to the longest length
                # with 1s before final conversion to batchd tensors
                stages[stage_id].input_points_mask.append(
                    torch.zeros(q.input_points.shape[1])
                )
            else:
                stages[stage_id].input_points.append(
                    torch.empty(0, input_points_embedding_dim)
                )
                stages[stage_id].input_points_mask.append(torch.empty(0))
            current_out_boxes = []
            current_out_object_ids = []
            # Set the object ids referred to by this query
            stages[stage_id].object_ids.append(q.object_ids_output)
            for object_id in q.object_ids_output:
                current_out_boxes.append(
                    data.images[q.image_id].objects[object_id].bbox
                )
                current_out_object_ids.append(object_id)
            find_targets[stage_id].boxes.extend(current_out_boxes)
            find_targets[stage_id].object_ids.extend(current_out_object_ids)
            if repeats > 0:
                for _ in range(repeats):
                    find_targets[stage_id].repeated_boxes.extend(current_out_boxes)
            find_targets[stage_id].num_boxes.append(len(current_out_boxes))
            find_targets[stage_id].is_exhaustive.append(q.is_exhaustive)
            if with_seg_masks:
                current_seg_mask = []
                current_is_valid_segment = []
                for object_id in q.object_ids_output:
                    seg_mask = data.images[q.image_id].objects[object_id].segment
                    if seg_mask is not None:
                        current_seg_mask.append(seg_mask)
                        current_is_valid_segment.append(1)
                    else:
                        # Objects without a segment get an all-zero mask,
                        # flagged invalid so the loss can skip them.
                        dummy_mask = torch.zeros(
                            data.images[q.image_id].data.shape[-2:], dtype=torch.bool
                        )
                        current_seg_mask.append(dummy_mask)
                        current_is_valid_segment.append(0)
                find_targets[stage_id].segments.extend(current_seg_mask)
                find_targets[stage_id].is_valid_segment.extend(current_is_valid_segment)
            else:
                # We are not loading segmentation masks
                find_targets[stage_id].segments = None
                find_targets[stage_id].is_valid_segment = None
            if q.semantic_target is not None:
                find_targets[stage_id].semantic_segments.append(q.semantic_target)
        offset_img_id += len(data.images)
    # Pad input points to equal sequence lengths
    for i in range(len(stages)):
        stages[i].input_points = pad_tensor_list_to_longest(
            stages[i].input_points, dim=0, pad_val=0
        )
        # Masked-out regions indicated by 1s.
        stages[i].input_points_mask = pad_tensor_list_to_longest(
            stages[i].input_points_mask, dim=0, pad_val=1
        )
    # Pad input boxes to equal sequence lengths
    for i in range(len(stages)):
        stages[i].input_boxes = pad_tensor_list_to_longest(
            stages[i].input_boxes, dim=0, pad_val=0
        )
        stages[i].input_boxes_label = pad_tensor_list_to_longest(
            stages[i].input_boxes_label, dim=0, pad_val=0
        )
        # Masked-out regions indicated by 1s.
        stages[i].input_boxes_mask = pad_tensor_list_to_longest(
            stages[i].input_boxes_mask, dim=0, pad_val=1
        )
    # Convert to tensors
    for i in range(len(stages)):
        stages[i] = convert_my_tensors(stages[i])
        find_targets[i] = convert_my_tensors(find_targets[i])
        find_metadatas[i] = convert_my_tensors(find_metadatas[i])
        # get padded representation for the boxes
        find_targets[i].boxes_padded = packed_to_padded_naive(
            find_targets[i].boxes.view(-1, 4), find_targets[i].num_boxes
        )
        # Padded object ids use -1 so real (non-negative) ids are distinguishable.
        find_targets[i].object_ids_padded = packed_to_padded_naive(
            find_targets[i].object_ids, find_targets[i].num_boxes, fill_value=-1
        )
    # Finalize the image batch
    # check sizes
    for img in img_batch[1:]:
        assert img.shape == img_batch[0].shape, "All images must have the same size"
    image_batch = torch.stack(img_batch)
    if load_image_in_fp16:
        # Optionally, cast the image tensors to fp16, which helps save GPU memory on
        # long videos with thousands of frames (where image tensors could be several GBs)
        image_batch = image_batch.half()
    return {
        dict_key: BatchedDatapoint(
            img_batch=image_batch,
            find_text_batch=text_batch,
            find_inputs=stages,
            find_targets=find_targets,
            find_metadatas=find_metadatas,
            raw_images=raw_images,
        )
    }

View File

@@ -0,0 +1,528 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""Dataset class for modulated detection"""
import json
import os
import random
import sys
import traceback
from collections import Counter
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils.data
import torchvision
from decord import cpu, VideoReader
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from PIL.Image import DecompressionBombError
from sam3.model.box_ops import box_xywh_to_xyxy
from torchvision.datasets.vision import VisionDataset
from .coco_json_loaders import COCO_FROM_JSON
@dataclass
class InferenceMetadata:
    """Metadata required for postprocessing a query's predictions."""

    # Coco id that corresponds to the "image" for evaluation by the coco evaluator
    # This is used for our own "class agnostic" evaluation
    coco_image_id: int
    # Id in the original dataset, such that we can use the original evaluator
    original_image_id: int
    # Original category id (if we want to use the original evaluator)
    original_category_id: int
    # Size of the raw image (height, width)
    original_size: Tuple[int, int]
    # Id of the object in the media
    object_id: int
    # Index of the frame in the media (0 if single image)
    frame_index: int
    # Whether it is for conditioning only, e.g., 0-th frame in TA is for conditioning
    # as we assume GT available in frame 0.
    is_conditioning_only: Optional[bool] = False
@dataclass
class FindQuery:
    """A single "find" request: locate the objects matching a text/geometry
    prompt within one image of a datapoint."""

    # The noun-phrase (or other text) prompt for this query
    query_text: str
    # Index of the target image within the datapoint's image list
    image_id: int
    # In case of a find query, the list of object ids that have to be predicted
    object_ids_output: List[int]
    # This is "instance exhaustivity".
    # true iff all instances are separable and annotated
    # See below the slightly different "pixel exhaustivity"
    is_exhaustive: bool
    # The order in which the queries are processed (only meaningful for video)
    query_processing_order: int = 0
    # Input geometry, initially in denormalized XYXY format. Then
    # 1. converted to normalized CxCyWH by the Normalize transform
    input_bbox: Optional[torch.Tensor] = None
    input_bbox_label: Optional[torch.Tensor] = None
    # Only for the PVS task
    input_points: Optional[torch.Tensor] = None
    semantic_target: Optional[torch.Tensor] = None
    # pixel exhaustivity: true iff the union of all segments (including crowds)
    # covers every pixel belonging to the target class
    # Note that instance_exhaustive implies pixel_exhaustive
    is_pixel_exhaustive: Optional[bool] = None
@dataclass
class FindQueryLoaded(FindQuery):
    """A FindQuery augmented with the metadata needed for postprocessing."""

    # Must have default value since FindQuery has entries with default values
    inference_metadata: Optional[InferenceMetadata] = None
@dataclass
class Object:
    """A single annotated object instance within one image/frame."""

    # Initially in denormalized XYXY format, gets converted to normalized CxCyWH by the Normalize transform
    bbox: torch.Tensor
    # Area of the object (units depend on the producing loader)
    area: float
    # Id of the object in the media
    object_id: Optional[int] = -1
    # Index of the frame in the media (0 if single image)
    frame_index: Optional[int] = -1
    segment: Optional[Union[torch.Tensor, dict]] = None  # RLE dict or binary mask
    # COCO-style crowd flag (group annotation rather than a single instance)
    is_crowd: bool = False
    # Free-form provenance tag for the annotation
    source: Optional[str] = None
@dataclass
class Image:
    """One image (or video frame) together with its annotated objects."""

    # Raw pixel data; a PIL image before transforms, a tensor after
    data: Union[torch.Tensor, PILImage.Image]
    # Objects annotated on this image
    objects: List[Object]
    size: Tuple[int, int]  # (height, width)
    # For blurring augmentation
    blurring_mask: Optional[Dict[str, Any]] = None
@dataclass
class Datapoint:
    """Refers to an image/video and all its annotations"""

    # Queries to answer on this datapoint (indices refer into `images`)
    find_queries: List[FindQueryLoaded]
    # All images/frames of the datapoint
    images: List[Image]
    # Untransformed copies of the images, kept when requested by the loader
    raw_images: Optional[List[PILImage.Image]] = None
class CustomCocoDetectionAPI(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    Wraps a JSON-backed annotation API (``coco_json_loader``) and materializes
    each datapoint as a :class:`Datapoint` of images, objects and find-queries.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: str,
        annFile: str,
        load_segmentation: bool,
        fix_fname: bool = False,
        training: bool = True,
        blurring_masks_path: Optional[str] = None,
        use_caching: bool = True,
        zstd_dict_path=None,
        filter_query=None,
        coco_json_loader: Callable = COCO_FROM_JSON,
        limit_ids: Optional[int] = None,
    ) -> None:
        # load_segmentation: attach (lazily decoded) RLE segments to Objects.
        # fix_fname: keep only the basename of each image file name.
        # blurring_masks_path: optional dir of per-image "<name>-mask.json"
        #   files loaded into Image.blurring_mask.
        # coco_json_loader: factory building the annotation API from a path.
        # limit_ids: if set, keep a deterministic random subset of datapoints.
        # NOTE(review): use_caching, zstd_dict_path and filter_query are stored
        # but never read in this class — presumably consumed by subclasses or
        # alternate loaders; confirm before removing.
        super().__init__(root)
        self.annFile = annFile
        self.use_caching = use_caching
        self.zstd_dict_path = zstd_dict_path
        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.load_segmentation = load_segmentation
        self.fix_fname = fix_fname
        self.filter_query = filter_query
        self.coco = None
        self.coco_json_loader = coco_json_loader
        self.limit_ids = limit_ids
        self.set_sharded_annotation_file(0)
        self.training = training
        self.blurring_masks_path = blurring_masks_path

    def _load_images(
        self, datapoint_id: int, img_ids_to_load: Optional[Set[int]] = None
    ) -> Tuple[List[Tuple[int, PILImage.Image]], List[Dict[str, Any]]]:
        """Load the PIL images (and their metadata dicts) for one datapoint.

        Args:
            datapoint_id: Id understood by the annotation API.
            img_ids_to_load: Optional whitelist of per-datapoint image ids;
                images outside it are skipped entirely.

        Returns:
            Tuple of (list of (img_id, PIL image), list of metadata dicts),
            in the order the annotation API yields them.
        """
        all_images = []
        all_img_metadata = []
        for current_meta in self.coco.loadImagesFromDatapoint(datapoint_id):
            img_id = current_meta["id"]
            if img_ids_to_load is not None and img_id not in img_ids_to_load:
                continue
            if self.fix_fname:
                # Strip any directory component from the stored file name.
                current_meta["file_name"] = current_meta["file_name"].split("/")[-1]
            path = current_meta["file_name"]
            if self.blurring_masks_path is not None:
                # Side-car blurring mask, keyed by the image basename.
                mask_fname = os.path.basename(path).replace(".jpg", "-mask.json")
                mask_path = os.path.join(self.blurring_masks_path, mask_fname)
                if os.path.exists(mask_path):
                    with open(mask_path, "r") as fopen:
                        current_meta["blurring_mask"] = json.load(fopen)
            all_img_metadata.append(current_meta)
            path = os.path.join(self.root, path)
            try:
                # NOTE(review): the second clause requires the path to *end*
                # with ".mp4", yet the branch splits on "@" as if paths looked
                # like "<video>.mp4@<frame>" — verify the expected path format.
                if ".mp4" in path and path[-4:] == ".mp4":
                    # Going to load a video frame
                    video_path, frame = path.split("@")
                    video = VideoReader(video_path, ctx=cpu(0))
                    # Convert to PIL image
                    all_images.append(
                        (
                            img_id,
                            torchvision.transforms.ToPILImage()(
                                video[int(frame)].asnumpy()
                            ),
                        )
                    )
                else:
                    with g_pathmgr.open(path, "rb") as fopen:
                        all_images.append((img_id, PILImage.open(fopen).convert("RGB")))
            except FileNotFoundError as e:
                print(f"File not found: {path} from dataset: {self.annFile}")
                raise e
        return all_images, all_img_metadata

    def set_curr_epoch(self, epoch: int):
        """Record the current epoch (for loaders whose behavior is epoch-dependent)."""
        self.curr_epoch = epoch

    def set_epoch(self, epoch: int):
        """No-op hook kept for compatibility with samplers that call set_epoch."""
        pass

    def set_sharded_annotation_file(self, data_epoch: int):
        """Build the annotation API and the (optionally subsampled) id list.

        Idempotent: returns immediately once ``self.coco`` is set, so the
        ``data_epoch`` argument currently has no effect after the first call.
        """
        if self.coco is not None:
            return
        assert g_pathmgr.isfile(
            self.annFile
        ), f"please provide valid annotation file. Missing: {self.annFile}"
        annFile = g_pathmgr.get_local_path(self.annFile)
        # NOTE(review): unreachable — the early return above guarantees
        # self.coco is None here.
        if self.coco is not None:
            del self.coco
        self.coco = self.coco_json_loader(annFile)
        # Use a torch tensor here to optimize memory usage when using several dataloaders
        ids_list = list(sorted(self.coco.getDatapointIds()))
        if self.limit_ids is not None:
            # Deterministic subsample: the RNG is seeded with the list length,
            # so every worker/process selects the same subset.
            local_random = random.Random(len(ids_list))
            local_random.shuffle(ids_list)
            ids_list = ids_list[: self.limit_ids]
        self.ids = torch.as_tensor(ids_list, dtype=torch.long)

    def __getitem__(self, index: int) -> Datapoint:
        return self._load_datapoint(index)

    def _load_datapoint(self, index: int) -> Datapoint:
        """A separate method for easy overriding in subclasses."""
        id = self.ids[index].item()
        pil_images, img_metadata = self._load_images(id)
        queries, annotations = self.coco.loadQueriesAndAnnotationsFromDatapoint(id)
        return self.load_queries(pil_images, annotations, queries, img_metadata)

    def load_queries(self, pil_images, annotations, queries, img_metadata):
        """Transform the raw image and queries into a Datapoint sample.

        Denormalizes bounding boxes/points into pixel XYXY coordinates,
        attaches objects to their images, and wraps each raw query dict in a
        :class:`FindQueryLoaded` with its :class:`InferenceMetadata`.
        """
        images: List[Image] = []
        # Lookup tables from API-level ids to positional indices in this sample.
        id2index_img = {}
        id2index_obj = {}
        id2index_find_query = {}
        id2imsize = {}
        assert len(pil_images) == len(img_metadata)
        for i in range(len(pil_images)):
            w, h = pil_images[i][1].size
            blurring_mask = None
            if "blurring_mask" in img_metadata[i]:
                blurring_mask = img_metadata[i]["blurring_mask"]
            images.append(
                Image(
                    data=pil_images[i][1],
                    objects=[],
                    size=(h, w),
                    blurring_mask=blurring_mask,
                )
            )
            id2index_img[pil_images[i][0]] = i
            id2imsize[pil_images[i][0]] = (h, w)
        for annotation in annotations:
            image_id = id2index_img[annotation["image_id"]]
            # Annotation bboxes arrive normalized xywh; convert to xyxy, then
            # denormalize to pixel coordinates and clamp into the image.
            bbox = box_xywh_to_xyxy(torch.as_tensor(annotation["bbox"])).view(1, 4)
            h, w = id2imsize[annotation["image_id"]]
            bbox[:, 0::2].mul_(w).clamp_(min=0, max=w)
            bbox[:, 1::2].mul_(h).clamp_(min=0, max=h)
            segment = None
            if self.load_segmentation and "segmentation" in annotation:
                # We're not decoding the RLE here, a transform will do it lazily later
                segment = annotation["segmentation"]
            images[image_id].objects.append(
                Object(
                    bbox=bbox[0],
                    area=annotation["area"],
                    object_id=(
                        annotation["object_id"] if "object_id" in annotation else -1
                    ),
                    frame_index=(
                        annotation["frame_index"] if "frame_index" in annotation else -1
                    ),
                    segment=segment,
                    is_crowd=(
                        annotation["is_crowd"] if "is_crowd" in annotation else None
                    ),
                    source=annotation["source"] if "source" in annotation else "",
                )
            )
            id2index_obj[annotation["id"]] = len(images[image_id].objects) - 1
        find_queries = []
        stage2num_queries = Counter()
        for i, query in enumerate(queries):
            stage2num_queries[query["query_processing_order"]] += 1
            id2index_find_query[query["id"]] = i
        # Sanity check: all the stages should have the same number of queries
        if len(stage2num_queries) == 0:
            num_queries_per_stage = 0
        else:
            num_queries_per_stage = stage2num_queries.most_common(1)[0][1]
        for stage, num_queries in stage2num_queries.items():
            assert (
                num_queries == num_queries_per_stage
            ), f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"
        for query_id, query in enumerate(queries):
            h, w = id2imsize[query["image_id"]]
            if (
                "input_box" in query
                and query["input_box"] is not None
                and len(query["input_box"]) > 0
            ):
                # Box prompts follow the same normalized-xywh convention as
                # annotations; denormalize to pixel xyxy.
                bbox = box_xywh_to_xyxy(torch.as_tensor(query["input_box"])).view(-1, 4)
                bbox[:, 0::2].mul_(w).clamp_(min=0, max=w)
                bbox[:, 1::2].mul_(h).clamp_(min=0, max=h)
                if "input_box_label" in query and query["input_box_label"] is not None:
                    bbox_label = torch.as_tensor(
                        query["input_box_label"], dtype=torch.long
                    ).view(-1)
                    assert len(bbox_label) == len(bbox)
                else:
                    # assume the boxes are positives
                    bbox_label = torch.ones(len(bbox), dtype=torch.long)
            else:
                bbox = None
                bbox_label = None
            if "input_points" in query and query["input_points"] is not None:
                # Points are (x, y, label) triplets; denormalize x and y only.
                points = torch.as_tensor(query["input_points"]).view(1, -1, 3)
                points[:, :, 0:1].mul_(w).clamp_(min=0, max=w)
                points[:, :, 1:2].mul_(h).clamp_(min=0, max=h)
            else:
                points = None
            # Best-effort extraction of evaluator ids: fall back to -1 when
            # the metadata is absent or not integral.
            try:
                original_image_id = int(
                    img_metadata[id2index_img[query["image_id"]]]["original_img_id"]
                )
            except ValueError:
                original_image_id = -1
            try:
                img_metadata_query = img_metadata[id2index_img[query["image_id"]]]
                coco_image_id = (
                    int(img_metadata_query["coco_img_id"])
                    if "coco_img_id" in img_metadata_query
                    else query["id"]
                )
            except KeyError:
                coco_image_id = -1
            try:
                original_category_id = int(query["original_cat_id"])
            except (ValueError, KeyError):
                original_category_id = -1
            # For evaluation, we associate the ids of the object to be tracked to the query
            if query["object_ids_output"]:
                obj_id = query["object_ids_output"][0]
                obj_idx = id2index_obj[obj_id]
                image_idx = id2index_img[query["image_id"]]
                object_id = images[image_idx].objects[obj_idx].object_id
                frame_index = images[image_idx].objects[obj_idx].frame_index
            else:
                object_id = -1
                frame_index = -1
            find_queries.append(
                FindQueryLoaded(
                    # id=query["id"],
                    # query_type=qtype,
                    query_text=(
                        query["query_text"] if query["query_text"] is not None else ""
                    ),
                    image_id=id2index_img[query["image_id"]],
                    input_bbox=bbox,
                    input_bbox_label=bbox_label,
                    input_points=points,
                    object_ids_output=[
                        id2index_obj[obj_id] for obj_id in query["object_ids_output"]
                    ],
                    is_exhaustive=query["is_exhaustive"],
                    # NOTE(review): the fallback maps is_exhaustive True->True
                    # but False->None (unknown), not False — confirm intended.
                    is_pixel_exhaustive=(
                        query["is_pixel_exhaustive"]
                        if "is_pixel_exhaustive" in query
                        else (
                            query["is_exhaustive"] if query["is_exhaustive"] else None
                        )
                    ),
                    query_processing_order=query["query_processing_order"],
                    inference_metadata=InferenceMetadata(
                        # During training the evaluator ids are not needed.
                        coco_image_id=-1 if self.training else coco_image_id,
                        original_image_id=(-1 if self.training else original_image_id),
                        frame_index=frame_index,
                        original_category_id=original_category_id,
                        original_size=(h, w),
                        object_id=object_id,
                    ),
                )
            )
        return Datapoint(
            find_queries=find_queries,
            images=images,
            raw_images=[p[1] for p in pil_images],
        )

    def __len__(self) -> int:
        return len(self.ids)
class Sam3ImageDataset(CustomCocoDetectionAPI):
    """COCO-style detection dataset yielding ``Datapoint`` objects with find queries.

    Wraps ``CustomCocoDetectionAPI.__getitem__`` with per-datapoint validation
    (limits on annotations per query and queries per image), applies the
    configured transforms, and retries with subsequent indices when loading a
    datapoint fails.
    """

    def __init__(
        self,
        img_folder,
        ann_file,
        transforms,
        max_ann_per_img: int,
        multiplier: int,
        training: bool,
        load_segmentation: bool = False,
        max_train_queries: int = 81,
        max_val_queries: int = 300,
        fix_fname: bool = False,
        is_sharded_annotation_dir: bool = False,
        blurring_masks_path: Optional[str] = None,
        use_caching: bool = True,
        zstd_dict_path=None,
        filter_query=None,
        coco_json_loader: Callable = COCO_FROM_JSON,
        limit_ids: int = None,
    ):
        # NOTE(review): `is_sharded_annotation_dir` is accepted but not used in
        # this class -- presumably kept for config compatibility; confirm.
        super().__init__(
            img_folder,
            ann_file,
            fix_fname=fix_fname,
            load_segmentation=load_segmentation,
            training=training,
            blurring_masks_path=blurring_masks_path,
            use_caching=use_caching,
            zstd_dict_path=zstd_dict_path,
            filter_query=filter_query,
            coco_json_loader=coco_json_loader,
            limit_ids=limit_ids,
        )
        self._transforms = transforms
        self.training = training
        self.max_ann_per_img = max_ann_per_img
        self.max_train_queries = max_train_queries
        self.max_val_queries = max_val_queries
        # Per-datapoint repeat factors: uniform weights scaled by `multiplier`.
        self.repeat_factors = torch.full(
            (len(self.ids),), float(multiplier), dtype=torch.float32
        )
        print(f"Raw dataset length = {len(self.ids)}")
        # Maximum number of consecutive indices tried before giving up.
        self._MAX_RETRIES = 100

    def __getitem__(self, idx):
        return self.__orig_getitem__(idx)

    def __orig_getitem__(self, idx):
        """Load the datapoint at `idx`, skipping to the next index on failure.

        Raises:
            RuntimeError: if `_MAX_RETRIES` consecutive indices fail to load.
        """
        # The query limit only depends on the training flag, so hoist it.
        max_queries = self.max_train_queries if self.training else self.max_val_queries
        for _attempt in range(self._MAX_RETRIES):
            try:
                datapoint = super().__getitem__(idx)
                # This can be done better by filtering the offending find queries
                # However, this requires care:
                # - Delete any find/get query that may depend on the deleted one
                # - Re-compute the indexes in the pointers to account for the deleted finds
                for query in datapoint.find_queries:
                    num_outputs = len(query.object_ids_output)
                    if num_outputs > self.max_ann_per_img:
                        raise DecompressionBombError(
                            f"Too many outputs ({num_outputs})"
                        )
                num_queries = len(datapoint.find_queries)
                if num_queries > max_queries:
                    raise DecompressionBombError(
                        f"Too many find queries ({num_queries})"
                    )
                if num_queries == 0:
                    raise DecompressionBombError("No find queries")
                for transform in self._transforms:
                    datapoint = transform(datapoint, epoch=self.curr_epoch)
                return datapoint
            except (DecompressionBombError, OSError, ValueError) as error:
                sys.stderr.write(f"ERROR: got loading error on datapoint {idx}\n")
                sys.stderr.write(f"Exception: {error}\n")
                sys.stderr.write(traceback.format_exc())
                # Move on to the next datapoint (wrapping around) and retry.
                idx = (idx + 1) % len(self)
        raise RuntimeError(
            f"Failed {self._MAX_RETRIES} times trying to load an image."
        )

View File

@@ -0,0 +1,327 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import copy
import io
import json
import logging
import math
import os
import pickle
import random
import sys
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
import torchvision
# from decord import cpu, VideoReader
from iopath.common.file_io import PathManager
from PIL import Image as PILImage
from .sam3_image_dataset import Datapoint, Sam3ImageDataset
# Base seed for the per-epoch frame-sampling RNG
# (combined with the epoch number in VideoGroundingDataset.set_curr_epoch).
SEED = 42
class VideoGroundingDataset(Sam3ImageDataset):
    """Video grounding dataset built on top of ``Sam3ImageDataset``.

    Each datapoint is a sequence of frames ("stages"). During training, frames
    are subsampled with a random stride and the time axis may be randomly
    reversed. A single static image can also be tiled into a synthetic video.
    Per-frame queries can be subsampled and over-large datapoints skipped to
    avoid OOM errors.
    """

    def __init__(
        self,
        num_stages_sample: int = 4,
        stage_stride_min: int = 1,
        stage_stride_max: int = 5,
        random_reverse_time_axis: bool = True,
        is_tiling_single_image: bool = False,
        # By default, we remove those find queries with geometric inputs (input_box or input_points)
        # when creating synthetic videos from frames (since they are not *video-level* text prompts).
        # If we need them later, we can sample them on-the-fly via transforms or inside the model.
        tile_img_keep_find_queries_with_geo_inputs: bool = False,
        tile_img_keep_get_queries: bool = False,
        # the maximum number of find queries (for each frame) to keep in a video; if the datapoint
        # contains more queries per frame than this limit, we subsample them to avoid OOM errors
        max_query_num: int = -1,  # the default -1 means no limit
        # whether to override the "is_exhaustive" flag of the loaded find queries to True
        # (by default, our video datasets are ingested with is_exhaustive=False, since the YTVIS format
        # annotations doesn't involve an "is_exhaustive" flag; this means that those unmatched (negative)
        # detection queries or tracking queries do not receive a classification loss given that we have
        # weak_loss=True in IABCEMdetr -- this could lead to false positives for both image detection
        # and video association.)
        override_query_is_exhaustive_to_true: bool = False,
        # the maximum number of masklets in a video; if the datapoint contains more masklets
        # than this limit, we skip the datapoint to avoid OOM errors (this is useful for
        # training with large videos that contain many objects)
        max_masklet_num_in_video: int = 300,  # 300 masklets is usually OK to avoid OOM
        **kwargs,
    ):
        """
        Loading video grounding data.

        Video frame sampling parameters (for training only):
        - num_stages_sample: number of frames to sample from the video during training
        - stage_stride_min: minimum stride between sampled frames during training
        - stage_stride_max: maximum stride between sampled frames during training (if it's
          greater than stage_stride_min, the actual stride is sampled uniformly between min
          and max; during inference, we always use all frames in the video with stride=1)
        - random_reverse_time_axis: whether to randomly invert the video's temporal axis
          (i.e. playing it backwards) during training

        Remaining keyword arguments are forwarded to ``Sam3ImageDataset``.
        """
        super().__init__(**kwargs)
        assert num_stages_sample >= 1
        assert stage_stride_min >= 1
        assert stage_stride_max >= stage_stride_min
        self.num_stages_sample = num_stages_sample
        self.stage_stride_min = stage_stride_min
        self.stage_stride_max = stage_stride_max
        self.random_reverse_time_axis = random_reverse_time_axis
        self.is_tiling_single_image = is_tiling_single_image
        self.tile_img_keep_find_queries_with_geo_inputs = (
            tile_img_keep_find_queries_with_geo_inputs
        )
        self.tile_img_keep_get_queries = tile_img_keep_get_queries
        self.max_query_num = max_query_num
        self.override_query_is_exhaustive_to_true = override_query_is_exhaustive_to_true
        self.max_masklet_num_in_video = max_masklet_num_in_video
        # Dedicated RNG for frame/query sampling, re-seeded deterministically
        # per epoch in `set_curr_epoch` so sampling is reproducible.
        self.rng = random.Random()
        self.set_curr_epoch(0)

    def set_curr_epoch(self, epoch: int):
        """Set the current epoch and re-seed the sampling RNG deterministically."""
        super().set_curr_epoch(epoch)
        self.rng.seed(SEED + epoch)

    def _load_datapoint(self, index: int) -> Datapoint:
        """Load one video datapoint, subsampling frames during training.

        May recursively skip to the next index when the datapoint exceeds
        `max_masklet_num_in_video`.
        """
        id = self.ids[index].item()
        queries, annotations = self.coco.loadQueriesAndAnnotationsFromDatapoint(id)
        # we subsample the video frames during training
        if self.training and not self.is_tiling_single_image:
            # pick a random stride for sampling query stages (`randint` includes both ends)
            stage_stride = self.rng.randint(
                self.stage_stride_min, self.stage_stride_max
            )
            stage_ids_to_keep = self._sample_stage_ids(
                queries, self.num_stages_sample, stage_stride
            )
            # filter the queries and annotations to keep only the selected stages
            # (also remap the stage ids so that they are contiguous and start from 0)
            reverse_time_axis = (
                self.rng.random() < 0.5 if self.random_reverse_time_axis else False
            )
            queries, annotations, kept_img_ids = self._filter_query_and_anns(
                queries,
                annotations,
                stage_ids_to_keep,
                remap_stage_id=True,
                reverse_time_axis=reverse_time_axis,
            )
            pil_images, img_metadata = self._load_images(id, kept_img_ids)
            if reverse_time_axis:
                # reverse the temporal ordering of the images and their metadata
                # so that the image order matches the (remapped) query order
                pil_images = pil_images[::-1]
                img_metadata = img_metadata[::-1]
        else:
            pil_images, img_metadata = self._load_images(id)
        # check that all the images have the same image size (they are expected
        # to have the same image size since they are frames from the same video)
        assert all(p.size == pil_images[0][1].size for _, p in pil_images)
        # process queries in temporal (stage) order
        queries.sort(key=lambda q: q["query_processing_order"])
        if self.override_query_is_exhaustive_to_true:
            for query in queries:
                query["is_exhaustive"] = True
        datapoint = self.load_queries(pil_images, annotations, queries, img_metadata)
        # skip datapoints with too many masklets to avoid OOM errors
        num_masklets_in_video = len(datapoint.images[0].objects)
        if num_masklets_in_video > self.max_masklet_num_in_video > 0:
            logging.warning(
                f"Datapoint {id} has ({num_masklets_in_video=}), exceeding "
                f"the maximum allowed ({self.max_masklet_num_in_video}). "
                "Skipping this datapoint."
            )
            next_index = (index + 1) % len(self)
            return self._load_datapoint(next_index)  # move to the next datapoint
        if self.is_tiling_single_image:
            datapoint = self._tile_single_image_data(datapoint, self.num_stages_sample)
        if self.max_query_num > 0:
            datapoint = self._subsample_queries(datapoint, self.max_query_num)
        # ensure that all find queries have the same processing order as their image id
        for query in datapoint.find_queries:
            assert query.image_id == query.query_processing_order, (
                f"find query has inconsistent image_id and "
                f"query_processing_order: {query.image_id=} vs "
                f"{query.query_processing_order=}"
            )
        return datapoint

    def _sample_stage_ids(self, queries, num_stages_sample, stage_stride):
        """Sample `num_stages_sample` evenly-strided stage ids from all queries.

        Raises:
            ValueError: if the video has fewer stages than `num_stages_sample`.
        """
        # Later we can perhaps turn it into a Sampler class to be more flexible.
        all_stage_ids = sorted(set(q["query_processing_order"] for q in queries))
        num_stages_total = len(all_stage_ids)
        if num_stages_total < num_stages_sample:
            raise ValueError("Not enough stages to sample")
        # the difference in index between the first and the last sampled stage ids
        b_e_gap = (num_stages_sample - 1) * stage_stride
        if b_e_gap > num_stages_total - 1:
            # In this case, it's not possible to sample with the provided stride,
            # so we use the maximum possible stride.
            prev_stage_stride = stage_stride
            stage_stride = math.floor((num_stages_total - 1) / (num_stages_sample - 1))
            logging.info(
                f"lowering stride from {prev_stage_stride} to {stage_stride} to "
                f"sample {num_stages_sample} stages (from {num_stages_total} total)"
            )
            b_e_gap = (num_stages_sample - 1) * stage_stride
        # randomly select a starting stage id (`randint` includes both ends)
        b_max = len(all_stage_ids) - 1 - b_e_gap
        b = self.rng.randint(0, b_max)
        e = b + b_e_gap
        stage_ids_to_keep = all_stage_ids[b : e + 1 : stage_stride]
        return stage_ids_to_keep

    def _filter_query_and_anns(
        self, queries, annotations, stage_ids_to_keep, remap_stage_id, reverse_time_axis
    ):
        """Filter queries and annotations to only keep those in `stage_ids_to_keep`.

        When `remap_stage_id` is True, the kept stage ids are remapped to be
        contiguous from 0; `reverse_time_axis` reverses the remapped ordering.
        Returns (filtered_queries, filtered_annotations, kept_img_ids).
        """
        stage_ids_to_keep = set(stage_ids_to_keep)
        kept_img_ids = set()
        kept_stage_ids = set()
        # Filter queries -- keep those queries with stage_id in `stage_ids_to_keep`
        filtered_queries = []
        for query in queries:
            input_box = query.get("input_box", None)
            input_points = query.get("input_points", None)
            has_geo_input = input_box is not None or input_points is not None
            # NOTE(review): geometric-input queries are dropped here based on
            # `tile_img_keep_find_queries_with_geo_inputs`, although this method
            # is only called on the non-tiling training path -- confirm intended.
            if has_geo_input and not self.tile_img_keep_find_queries_with_geo_inputs:
                continue
            stage_id = query["query_processing_order"]
            if stage_id in stage_ids_to_keep:
                kept_img_ids.add(query["image_id"])
                kept_stage_ids.add(stage_id)
                filtered_queries.append(query)
        # Check that all frames in `stage_ids_to_keep` are present after filtering
        all_frame_present = kept_stage_ids == stage_ids_to_keep
        assert all_frame_present, f"{kept_stage_ids=} vs {stage_ids_to_keep=}"
        if remap_stage_id:
            # Remap those kept stage ids to be contiguous and starting from 0
            old_stage_ids = sorted(kept_stage_ids, reverse=reverse_time_axis)
            stage_id_old2new = {old: new for new, old in enumerate(old_stage_ids)}
            for query in filtered_queries:
                # Remapping would invalidate cross-query pointers, so require
                # that both pointer fields are empty.
                ptr_x_is_empty = query["ptr_x_query_id"] in [None, -1]
                ptr_y_is_empty = query["ptr_y_query_id"] in [None, -1]
                assert (
                    ptr_x_is_empty and ptr_y_is_empty
                ), "Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers"
                query["query_processing_order"] = stage_id_old2new[
                    query["query_processing_order"]
                ]
        # Filter annotations -- keep those annotations with image_id in `kept_img_ids`
        filtered_annotations = [
            ann for ann in annotations if ann["image_id"] in kept_img_ids
        ]
        return filtered_queries, filtered_annotations, kept_img_ids

    def _tile_single_image_data(self, datapoint: Datapoint, num_stages_sample: int):
        """
        Tile a single image and its queries to simulate video frames. The output is a
        datapoint with *identical video frames* (i.e. the same static image) and needs
        further transforms (e.g. affine) to get video frames with different content.
        """
        # tile `images: List[Image]`
        assert len(datapoint.images) == 1, "Expected only one single image"
        tiled_images = [
            copy.deepcopy(datapoint.images[0]) for _ in range(num_stages_sample)
        ]
        for stage_id, img in enumerate(tiled_images):
            # each tiled copy represents a distinct (synthetic) frame index
            for obj in img.objects:
                obj.frame_index = stage_id
        # tile `raw_images: Optional[List[PILImage.Image]] = None`
        tiled_raw_images = None
        if datapoint.raw_images is not None:
            assert len(datapoint.raw_images) == 1, "Expected only one single image"
            tiled_raw_images = [
                datapoint.raw_images[0].copy() for _ in range(num_stages_sample)
            ]
        # tile `find_queries: List[FindQueryLoaded]`
        tiled_find_queries_per_stage = [[] for _ in range(num_stages_sample)]
        for query in datapoint.find_queries:
            assert query.image_id == 0
            assert query.query_processing_order == 0
            # check and make sure that a query doesn't contain pointers or references
            # to other queries (that cannot be tiled)
            assert query.ptr_x is None and query.ptr_y is None
            assert query.ptr_mem is None
            # assert query.wkdata_qid is None
            # assert query.other_positive_qids is None
            # assert query.negative_qids is None
            has_geo_input = (
                query.input_bbox is not None or query.input_points is not None
            )
            if has_geo_input and not self.tile_img_keep_find_queries_with_geo_inputs:
                continue
            for stage_id in range(num_stages_sample):
                # copy the query and update the image_id
                new_query = copy.deepcopy(query)
                new_query.image_id = stage_id
                new_query.query_processing_order = stage_id
                if new_query.inference_metadata is not None:
                    new_query.inference_metadata.frame_index = stage_id
                tiled_find_queries_per_stage[stage_id].append(new_query)
        # flatten to a single list ordered by stage
        tiled_find_queries = sum(tiled_find_queries_per_stage, [])
        # tile `get_queries: List[GetQuery]` -- we skip them for now (since they involve
        # a pointer to a find query that is complicated to tile, and there is not an
        # imminent use case for them in the video grounding task in the near future)
        if self.tile_img_keep_get_queries:
            raise NotImplementedError("Tiling get queries is not implemented yet")
        else:
            tiled_get_queries = []
        return Datapoint(
            images=tiled_images,
            raw_images=tiled_raw_images,
            find_queries=tiled_find_queries,
            get_queries=tiled_get_queries,
        )

    def _subsample_queries(self, datapoint: Datapoint, max_query_num: int):
        """Subsample to keep at most `max_query_num` queries per frame in a datapoint."""
        # aggregate the find queries per stage
        num_frames = max(q.query_processing_order for q in datapoint.find_queries) + 1
        find_queries_per_stage = [[] for _ in range(num_frames)]
        for query in datapoint.find_queries:
            find_queries_per_stage[query.query_processing_order].append(query)
        # verify that all the stages have the same number of queries
        num_queries_per_stage = len(find_queries_per_stage[0])
        for queries in find_queries_per_stage:
            assert len(queries) == num_queries_per_stage
        if max_query_num <= 0 or num_queries_per_stage <= max_query_num:
            return datapoint
        # subsample the queries to keep only `max_query_num` queries
        # (the same sampled indices are used for every frame, so the kept
        # queries stay aligned across frames)
        sampled_inds = self.rng.sample(range(num_queries_per_stage), max_query_num)
        sampled_find_queries_per_stage = [
            [queries[idx] for idx in sampled_inds] for queries in find_queries_per_stage
        ]
        sampled_find_queries = sum(sampled_find_queries_per_stage, [])
        return Datapoint(
            images=datapoint.images,
            raw_images=datapoint.raw_images,
            find_queries=sampled_find_queries,
            get_queries=datapoint.get_queries,
        )

View File

@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Callable, Iterable, Optional
from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset
class TorchDataset:
    """Thin wrapper around ``torch.utils.data.DataLoader`` with epoch handling.

    Optionally shards the dataset across distributed ranks with a
    ``DistributedSampler`` and propagates the current epoch to both the
    sampler and the dataset before each epoch's loader is created.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        enable_distributed_sampler: bool = True,
    ) -> None:
        """
        Args:
            dataset: Map-style dataset (IterableDataset is not supported).
            batch_size: Number of samples per batch.
            num_workers: Number of DataLoader worker processes.
            shuffle: Whether to shuffle samples each epoch.
            pin_memory: Whether to pin host memory for faster GPU transfer.
            drop_last: Whether to drop the last incomplete batch.
            collate_fn: Optional custom batch collation function.
            worker_init_fn: Optional per-worker initialization hook.
            enable_distributed_sampler: If True, shard the dataset across
                distributed ranks with a DistributedSampler (requires an
                initialized process group).
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn
        assert not isinstance(self.dataset, IterableDataset), "Not supported yet"
        if enable_distributed_sampler:
            # The sampler takes over shuffling (per-rank, re-seeded per epoch).
            self.sampler = DistributedSampler(self.dataset, shuffle=self.shuffle)
        else:
            self.sampler = None

    def get_loader(self, epoch) -> Iterable:
        """Build a DataLoader for the given epoch.

        Propagates `epoch` to the sampler (so distributed shuffling differs
        across epochs) and to the dataset via its `epoch` attribute and/or
        `set_epoch` method, if present.
        """
        if self.sampler:
            self.sampler.set_epoch(epoch)
        if hasattr(self.dataset, "epoch"):
            self.dataset.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            # Bug fix: honor `shuffle` when no sampler is used. Previously the
            # flag was silently ignored unless a DistributedSampler was
            # created. DataLoader forbids shuffle=True together with a custom
            # sampler, so shuffling is delegated to the sampler when one is set.
            shuffle=(self.shuffle and self.sampler is None),
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
            sampler=self.sampler,
            collate_fn=self.collate_fn,
            worker_init_fn=self.worker_init_fn,
        )