Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
1
sam3/train/data/__init__.py
Normal file
1
sam3/train/data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
465
sam3/train/data/coco_json_loaders.py
Normal file
465
sam3/train/data/coco_json_loaders.py
Normal file
@@ -0,0 +1,465 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from pycocotools import mask as mask_util
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Utility Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def convert_boxlist_to_normalized_tensor(box_list, image_width, image_height):
    """
    Converts a list of bounding boxes to a normalized PyTorch tensor.

    Args:
        box_list (list of list or tuples): Each box is [x_min, y_min, x_max, y_max].
        image_width (int or float): Width of the image.
        image_height (int or float): Height of the image.

    Returns:
        torch.Tensor: Normalized tensor of shape (N, 4), values in [0, 1].
        An empty ``box_list`` yields an empty (0, 4) tensor.
    """
    # reshape(-1, 4) keeps the (N, 4) shape even for an empty input, where
    # torch.tensor([]) would otherwise produce a 1-D tensor and the column
    # indexing below would raise an IndexError.
    boxes = torch.tensor(box_list, dtype=torch.float32).reshape(-1, 4)
    boxes[:, [0, 2]] /= image_width  # x_min, x_max
    boxes[:, [1, 3]] /= image_height  # y_min, y_max
    # Clip out-of-frame coordinates into the valid normalized range.
    boxes = boxes.clamp(0, 1)
    return boxes
|
||||
|
||||
|
||||
def load_coco_and_group_by_image(json_path: str) -> Tuple[List[Dict], Dict[int, str]]:
    """
    Load a COCO JSON file and bucket its annotations per image.

    Args:
        json_path (str): Path to COCO JSON file.

    Returns:
        Tuple containing:
            - List of dicts with 'image' and 'annotations' keys,
              ordered by ascending image id
            - Dict mapping category IDs to category names
    """
    with open(json_path, "r") as fp:
        dataset = json.load(fp)

    # Bucket annotations by the image they belong to.
    anns_per_image = defaultdict(list)
    for annotation in dataset["annotations"]:
        anns_per_image[annotation["image_id"]].append(annotation)

    id_to_image = {entry["id"]: entry for entry in dataset["images"]}

    # Deterministic order: ascending image id. Images without annotations
    # still appear, with an empty annotation list.
    grouped = [
        {"image": id_to_image[image_id], "annotations": anns_per_image.get(image_id, [])}
        for image_id in sorted(id_to_image)
    ]

    category_names = {entry["id"]: entry["name"] for entry in dataset["categories"]}
    return grouped, category_names
|
||||
|
||||
|
||||
def ann_to_rle(segm, im_info: Dict) -> Dict:
    """
    Convert an annotation (polygons or uncompressed RLE) to compressed RLE.

    Args:
        segm: Segmentation data (polygon list or RLE dict)
        im_info (dict): Image info containing 'height' and 'width'

    Returns:
        RLE encoded segmentation
    """
    height, width = im_info["height"], im_info["width"]

    if isinstance(segm, list):
        # Polygon: encode each part, then merge them into a single mask RLE.
        return mask_util.merge(mask_util.frPyObjects(segm, height, width))
    if isinstance(segm["counts"], list):
        # Uncompressed RLE: compress it.
        return mask_util.frPyObjects(segm, height, width)
    # Already compressed RLE; pass through unchanged.
    return segm
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# COCO Training API
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class COCO_FROM_JSON:
    """
    COCO training API for loading box-only annotations from JSON.
    Groups all annotations per image and creates queries per category.

    Each datapoint is one (image, category-chunk) pair: the datapoint index
    encodes the image index and the chunk index together (see
    getDatapointIds / loadQueriesAndAnnotationsFromDatapoint).
    """

    def __init__(
        self,
        annotation_file,
        prompts=None,
        include_negatives=True,
        category_chunk_size=None,
    ):
        """
        Initialize the COCO training API.

        Args:
            annotation_file (str): Path to COCO JSON annotation file
            prompts: Optional custom prompts for categories; a string that
                evaluates to a list of {"id": ..., "name": ...} dicts
            include_negatives (bool): Whether to include negative examples (categories with no instances)
            category_chunk_size (int, optional): Number of categories queried
                per datapoint; defaults to all categories in one chunk
        """
        self._raw_data, self._cat_idx_to_text = load_coco_and_group_by_image(
            annotation_file
        )
        self._sorted_cat_ids = sorted(list(self._cat_idx_to_text.keys()))
        self.prompts = None
        self.include_negatives = include_negatives
        # One chunk covering every category unless a chunk size was given.
        self.category_chunk_size = (
            category_chunk_size
            if category_chunk_size is not None
            else len(self._sorted_cat_ids)
        )
        self.category_chunks = [
            self._sorted_cat_ids[i : i + self.category_chunk_size]
            for i in range(0, len(self._sorted_cat_ids), self.category_chunk_size)
        ]
        if prompts is not None:
            # SECURITY: eval() executes arbitrary Python from the `prompts`
            # string — it must only ever come from a trusted config source.
            # Consider ast.literal_eval if the input is always a literal.
            prompts = eval(prompts)
            self.prompts = {}
            for loc_dict in prompts:
                self.prompts[int(loc_dict["id"])] = loc_dict["name"]
            assert len(self.prompts) == len(
                self._sorted_cat_ids
            ), "Number of prompts must match number of categories"

    def getDatapointIds(self):
        """Return all datapoint indices for training (one per image/chunk pair)."""
        return list(range(len(self._raw_data) * len(self.category_chunks)))

    def loadQueriesAndAnnotationsFromDatapoint(self, idx):
        """
        Load queries and annotations for a specific datapoint.

        Args:
            idx (int): Datapoint index; decoded as
                (image index, category-chunk index) below.

        Returns:
            Tuple of (queries, annotations) lists
        """
        img_idx = idx // len(self.category_chunks)
        chunk_idx = idx % len(self.category_chunks)
        cat_chunk = self.category_chunks[chunk_idx]

        queries = []
        annotations = []

        query_template = {
            "id": None,
            "original_cat_id": None,
            "object_ids_output": None,
            "query_text": None,
            "query_processing_order": 0,
            "ptr_x_query_id": None,
            "ptr_y_query_id": None,
            "image_id": 0,  # Single image per datapoint
            "input_box": None,
            "input_box_label": None,
            "input_points": None,
            "is_exhaustive": True,
        }

        annot_template = {
            "image_id": 0,
            "bbox": None,  # Normalized bbox in xywh
            # NOTE(review): filled below with normalized w*h, i.e. a
            # normalized area despite the original "Unnormalized" note — confirm.
            "area": None,
            "segmentation": None,  # RLE encoded
            "object_id": None,
            "is_crowd": None,
            "id": None,
        }

        raw_annotations = self._raw_data[img_idx]["annotations"]
        image_info = self._raw_data[img_idx]["image"]
        width, height = image_info["width"], image_info["height"]

        # Group annotations by category
        cat_id_to_anns = defaultdict(list)
        for ann in raw_annotations:
            cat_id_to_anns[ann["category_id"]].append(ann)

        annotations_by_cat_sorted = [
            (cat_id, cat_id_to_anns[cat_id]) for cat_id in cat_chunk
        ]

        for cat_id, anns in annotations_by_cat_sorted:
            # Categories with no instances are "negatives"; skip them unless
            # negative queries were requested.
            if len(anns) == 0 and not self.include_negatives:
                continue

            cur_ann_ids = []

            # Create annotations for this category
            for ann in anns:
                annotation = annot_template.copy()
                annotation["id"] = len(annotations)
                annotation["object_id"] = annotation["id"]
                annotation["is_crowd"] = ann["iscrowd"]

                # NOTE(review): COCO bboxes are [x, y, w, h]; the helper
                # divides columns [0,2] by width and [1,3] by height, which
                # is correct for xywh too, but its docstring says xyxy — confirm.
                normalized_boxes = convert_boxlist_to_normalized_tensor(
                    [ann["bbox"]], width, height
                )
                bbox = normalized_boxes[0]

                annotation["area"] = (bbox[2] * bbox[3]).item()
                annotation["bbox"] = bbox

                if (
                    "segmentation" in ann
                    and ann["segmentation"] is not None
                    and ann["segmentation"] != []
                ):
                    annotation["segmentation"] = ann_to_rle(
                        ann["segmentation"], im_info=image_info
                    )

                annotations.append(annotation)
                cur_ann_ids.append(annotation["id"])

            # Create query for this category
            query = query_template.copy()
            query["id"] = len(queries)
            query["original_cat_id"] = cat_id
            query["query_text"] = (
                self._cat_idx_to_text[cat_id]
                if self.prompts is None
                else self.prompts[cat_id]
            )
            query["object_ids_output"] = cur_ann_ids
            queries.append(query)

        return queries, annotations

    def loadImagesFromDatapoint(self, idx):
        """
        Load image information for a specific datapoint.

        Args:
            idx (int): Datapoint index

        Returns:
            List containing image info dict
        """
        img_idx = idx // len(self.category_chunks)
        img_data = self._raw_data[img_idx]["image"]
        images = [
            {
                # Local id within the datapoint; always 0 (single image).
                "id": 0,
                "file_name": img_data["file_name"],
                "original_img_id": img_data["id"],
                "coco_img_id": img_data["id"],
            }
        ]
        return images
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SAM3 Evaluation APIs
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SAM3_EVAL_API_FROM_JSON_NP:
    """
    SAM3 evaluation API for loading noun phrase queries from JSON.

    Each datapoint is one image with exactly one text query and no
    ground-truth annotations (evaluation only).
    """

    def __init__(self, annotation_file):
        """
        Initialize the SAM3 evaluation API.

        Args:
            annotation_file (str): Path to SAM3 JSON annotation file
        """
        with open(annotation_file, "r") as fp:
            payload = json.load(fp)
        self._image_data = payload["images"]

    def getDatapointIds(self):
        """Return all datapoint indices."""
        return list(range(len(self._image_data)))

    def loadQueriesAndAnnotationsFromDatapoint(self, idx):
        """
        Load queries and annotations for a specific datapoint.

        Args:
            idx (int): Datapoint index

        Returns:
            Tuple of (queries, annotations) lists; annotations are always empty.
        """
        record = self._image_data[idx]
        query = {
            "id": 0,
            "original_cat_id": int(record["queried_category"]),
            "object_ids_output": [],
            "query_text": record["text_input"],
            "query_processing_order": 0,
            "ptr_x_query_id": None,
            "ptr_y_query_id": None,
            "image_id": 0,
            "input_box": None,
            "input_box_label": None,
            "input_points": None,
            "is_exhaustive": True,
        }
        return [query], []

    def loadImagesFromDatapoint(self, idx):
        """
        Load image information for a specific datapoint.

        Args:
            idx (int): Datapoint index

        Returns:
            List containing image info dict
        """
        record = self._image_data[idx]
        return [
            {
                "id": 0,
                "file_name": record["file_name"],
                "original_img_id": record["id"],
                "coco_img_id": record["id"],
            }
        ]
|
||||
|
||||
|
||||
class SAM3_VEVAL_API_FROM_JSON_NP:
    """
    SAM3 video evaluation API for loading noun phrase queries from JSON.

    Each datapoint is one video; every (noun phrase, frame) pair becomes a
    separate query, with no ground-truth annotations (evaluation only).
    """

    def __init__(self, annotation_file):
        """
        Initialize the SAM3 video evaluation API.

        Args:
            annotation_file (str): Path to SAM3 video JSON annotation file
        """
        with open(annotation_file, "r") as fp:
            payload = json.load(fp)

        assert "video_np_pairs" in payload, "Incorrect data format"

        self._video_data = payload["videos"]
        self._video_id_to_np_ids = defaultdict(list)
        self._cat_id_to_np = {}

        for category in payload["categories"]:
            self._cat_id_to_np[category["id"]] = category["name"]

        for pair in payload["video_np_pairs"]:
            cat_id = pair["category_id"]
            self._video_id_to_np_ids[pair["video_id"]].append(cat_id)
            # Sanity check: category table and pair table must agree.
            assert (
                self._cat_id_to_np[cat_id] == pair["noun_phrase"]
            ), "Category name does not match text input"

    def getDatapointIds(self):
        """Return all datapoint indices."""
        return list(range(len(self._video_data)))

    def loadQueriesAndAnnotationsFromDatapoint(self, idx):
        """
        Load queries and annotations for a specific video datapoint.

        Args:
            idx (int): Datapoint index

        Returns:
            Tuple of (queries, annotations) lists; annotations are always empty.
        """
        video = self._video_data[idx]
        queries = []

        for np_id in self._video_id_to_np_ids[video["id"]]:
            phrase = self._cat_id_to_np[np_id]
            # One query per frame, processed in frame order.
            for frame_idx in range(len(video["file_names"])):
                queries.append(
                    {
                        "id": len(queries),
                        "original_cat_id": np_id,
                        "object_ids_output": [],
                        "query_text": phrase,
                        "query_processing_order": frame_idx,
                        "ptr_x_query_id": None,
                        "ptr_y_query_id": None,
                        "image_id": frame_idx,
                        "input_box": None,
                        "input_box_label": None,
                        "input_points": None,
                        "is_exhaustive": True,
                    }
                )

        return queries, []

    def loadImagesFromDatapoint(self, idx):
        """
        Load image information for a specific video datapoint.

        Args:
            idx (int): Datapoint index

        Returns:
            List containing image info dicts for all frames
        """
        video = self._video_data[idx]
        return [
            {
                "id": frame_idx,
                "file_name": file_name,
                "original_img_id": video["id"],
                "coco_img_id": video["id"],
            }
            for frame_idx, file_name in enumerate(video["file_names"])
        ]
|
||||
360
sam3/train/data/collator.py
Normal file
360
sam3/train/data/collator.py
Normal file
@@ -0,0 +1,360 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
|
||||
from typing import Any, get_args, get_origin, List, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.model.data_misc import (
|
||||
BatchedDatapoint,
|
||||
BatchedFindTarget,
|
||||
BatchedInferenceMetadata,
|
||||
FindStage,
|
||||
)
|
||||
|
||||
from .sam3_image_dataset import Datapoint
|
||||
|
||||
|
||||
MyTensor = Union[torch.Tensor, List[Any]]
|
||||
|
||||
|
||||
def convert_my_tensors(obj):
    """
    Recursively convert every ``MyTensor``-annotated field of dataclass
    ``obj`` from a list of values/tensors into a single ``torch.Tensor``,
    in place, and return ``obj``.

    The target dtype for a field ``name`` is read from a companion field
    named ``name + "__type"`` — assumes the dataclass declares such a
    companion for every MyTensor field (TODO confirm against FindStage,
    BatchedFindTarget, BatchedInferenceMetadata).
    """

    def is_optional_field(field) -> bool:
        # True when the annotation is Optional[...] (a Union containing NoneType).
        return get_origin(field) is Union and type(None) in get_args(field)

    for field in fields(obj):
        # Nested dataclasses are converted recursively.
        if is_dataclass(getattr(obj, field.name)):
            convert_my_tensors(getattr(obj, field.name))
            continue

        field_type = field.type
        if is_optional_field(field.type):
            # Unwrap Optional[X] -> X; assumes NoneType is the last Union arg.
            field_type = Union[get_args(field.type)[:-1]]  # Get the Optional field type

        # Skip fields that are not MyTensor, and fields explicitly set to None
        # (e.g. segments when segmentation masks are not loaded).
        if field_type != MyTensor or getattr(obj, field.name) is None:
            continue

        elif len(getattr(obj, field.name)) and isinstance(
            getattr(obj, field.name)[0], torch.Tensor
        ):
            # A non-empty list of tensors: stack into one tensor, then cast
            # to the dtype given by the companion "<name>__type" field.
            # Input boxes/labels are stacked along dim 1 instead of dim 0.
            stack_dim = 0
            if field.name in [
                "input_boxes",
                "input_boxes_label",
            ]:
                stack_dim = 1
            setattr(
                obj,
                field.name,
                torch.stack(getattr(obj, field.name), dim=stack_dim).to(
                    getattr(obj, field.name + "__type")
                ),
            )
        else:
            # A plain (possibly empty) list of scalars: build a tensor directly.
            setattr(
                obj,
                field.name,
                torch.as_tensor(
                    getattr(obj, field.name), dtype=getattr(obj, field.name + "__type")
                ),
            )
    return obj
|
||||
|
||||
|
||||
def packed_to_padded_naive(boxes_packed, num_boxes, fill_value=0):
    """
    Convert a packed tensor of bounding boxes to a padded tensor of bounding
    boxes. Naive implementation using a loop.

    Inputs:
    - boxes_packed: Tensor of shape (N_1 + ... + N_B, 4)
    - num_boxes: Tensor of shape (B,) where num_boxes[i] = N_i
    - fill_value: value used for the padded slots

    Returns:
    - boxes_padded: Tensor of shape (B, N_max, 4) where N_max = max_i N_i
    """
    counts = num_boxes.tolist()
    batch_size = len(counts)

    # Allocate the output pre-filled with fill_value, then copy each
    # sample's rows into its leading slots.
    padded = boxes_packed.new_full(
        (batch_size, max(counts), *boxes_packed.shape[1:]), fill_value
    )
    offset = 0
    for row, count in enumerate(counts):
        padded[row, :count] = boxes_packed[offset : offset + count]
        offset += count
    return padded
|
||||
|
||||
|
||||
def pad_tensor_list_to_longest(
    tensors: List[torch.Tensor], dim=0, pad_val=0
) -> List[torch.Tensor]:
    """
    Right-pad every tensor in ``tensors`` along ``dim`` with ``pad_val`` so
    they all match the longest one. Edits the list in-place and returns it.
    """
    if not tensors:
        return tensors
    target_len = max(t.shape[dim] for t in tensors)
    for idx, tensor in enumerate(tensors):
        rank = tensor.dim()
        # F.pad's spec runs from the LAST dim backwards; count how many
        # dims sit to the right of `dim` so we can skip them with (0, 0).
        trailing = (rank - 1) - (rank + dim) % rank
        grow_by = target_len - tensor.shape[dim]
        pad_spec = [0, 0] * trailing + [0, grow_by]
        tensors[idx] = torch.nn.functional.pad(tensor, tuple(pad_spec), value=pad_val)
    return tensors
|
||||
|
||||
|
||||
def collate_fn_api_with_chunking(
    batch,
    num_chunks,
    dict_key,
    with_seg_masks=False,
    input_points_embedding_dim=257,
    repeats: int = 0,
    load_image_in_fp16: bool = False,
):
    """
    Split ``batch`` into ``num_chunks`` interleaved sub-batches and collate
    each one independently with ``collate_fn_api``.

    Returns:
        List of ``num_chunks`` collated results.
    """
    assert num_chunks >= 1, "num_chunks must be >= 1"

    # Interleaved split (round-robin), then collate chunk by chunk.
    return [
        collate_fn_api(
            batch[start::num_chunks],
            dict_key,
            with_seg_masks,
            input_points_embedding_dim,
            repeats,
            load_image_in_fp16,
        )
        for start in range(num_chunks)
    ]
|
||||
|
||||
|
||||
def collate_fn_api(
    batch: List[Datapoint],
    dict_key,
    with_seg_masks=False,
    input_points_embedding_dim=257,
    repeats: int = 0,
    load_image_in_fp16: bool = False,
):
    """
    Collate a list of Datapoints into a single BatchedDatapoint, bucketed by
    processing stage (``query_processing_order``).

    Args:
        batch: Datapoints to collate; each holds images and find queries.
        dict_key: Key under which the BatchedDatapoint is returned.
        with_seg_masks: If True, also collect per-object segmentation masks.
        input_points_embedding_dim: Width of the empty placeholder when a
            query has no input points.
        repeats: If > 0, duplicate each query's target boxes this many times
            into ``repeated_boxes``.
        load_image_in_fp16: If True, cast the stacked image batch to fp16.

    Returns:
        Dict ``{dict_key: BatchedDatapoint}``.
    """
    # img_batch = torch.stack(sum([[img.data for img in v.images] for v in batch], []))
    img_batch = []
    text_batch = []  # de-duplicated query texts; queries store indices into it
    raw_images = None

    # A "stage" groups queries that share the same processing order.
    num_stages = (
        max(q.query_processing_order for data in batch for q in data.find_queries) + 1
    )

    # Per-stage accumulators; list fields are converted to tensors at the end.
    stages = [
        FindStage(
            img_ids=[],
            text_ids=[],
            input_boxes=[],
            input_boxes_label=[],
            input_boxes_mask=[],
            input_points=[],
            input_points_mask=[],
            object_ids=[],
        )
        for _ in range(num_stages)
    ]
    find_targets = [
        BatchedFindTarget(
            num_boxes=[],
            boxes=[],
            boxes_padded=[],
            is_exhaustive=[],
            segments=[],
            semantic_segments=[],
            is_valid_segment=[],
            repeated_boxes=[],
            object_ids=[],
            object_ids_padded=[],
        )
        for _ in range(num_stages)
    ]
    find_metadatas = [
        BatchedInferenceMetadata(
            coco_image_id=[],
            original_size=[],
            object_id=[],
            frame_index=[],
            original_image_id=[],
            original_category_id=[],
            is_conditioning_only=[],
        )
        for _ in range(num_stages)
    ]

    # Running offset to turn per-datapoint image ids into batch-global ids.
    offset_img_id = 0
    offset_query_id = [0 for _ in range(num_stages)]
    for i, data in enumerate(batch):
        img_batch.extend([img.data for img in data.images])

        if data.raw_images is not None:
            if raw_images is None:
                raw_images = []
            raw_images.extend(data.raw_images)

        # Conversion of query_ids indexing in a datapoint to query_ids indexing in a stage
        # NOTE(review): this mapping is built but never read below — possibly
        # dead code kept for pointer-query behaviour; confirm before removal.
        datapoint_query_id_2_stage_query_id = []
        for q in data.find_queries:
            stage_id = q.query_processing_order
            datapoint_query_id_2_stage_query_id.append(offset_query_id[stage_id])
            offset_query_id[stage_id] += 1

        for j, q in enumerate(data.find_queries):
            stage_id = q.query_processing_order
            stages[stage_id].img_ids.append(q.image_id + offset_img_id)
            # De-duplicate texts across the whole batch; store an index.
            if q.query_text not in text_batch:
                text_batch.append(q.query_text)
            stages[stage_id].text_ids.append(text_batch.index(q.query_text))

            assert (
                q.inference_metadata is not None
            ), "inference_metadata must be provided when FindQueryLoaded is created."
            # Copy every metadata field into the per-stage metadata lists.
            for f in fields(q.inference_metadata):
                getattr(find_metadatas[stage_id], f.name).append(
                    getattr(q.inference_metadata, f.name)
                )

            if q.input_bbox is not None:
                assert q.input_bbox.numel() % 4 == 0
                assert q.input_bbox_label is not None
                nb_boxes = q.input_bbox.numel() // 4
                assert len(q.input_bbox_label) == nb_boxes
                stages[stage_id].input_boxes.append(q.input_bbox.view(nb_boxes, 4))
                stages[stage_id].input_boxes_label.append(
                    q.input_bbox_label.view(nb_boxes)
                )
                # 0 = real entry in the attention/validity mask.
                stages[stage_id].input_boxes_mask.append(
                    torch.zeros(nb_boxes, dtype=torch.bool)
                )
            else:
                # No input boxes: append empty placeholders so list lengths
                # stay aligned across queries.
                stages[stage_id].input_boxes.append(torch.zeros(0, 4))
                stages[stage_id].input_boxes_label.append(
                    torch.zeros(0, dtype=torch.bool)
                )
                stages[stage_id].input_boxes_mask.append(
                    torch.ones(0, dtype=torch.bool)
                )

            if q.input_points is not None:
                stages[stage_id].input_points.append(
                    q.input_points.squeeze(0)  # Strip a trivial batch index
                )
                # All masks will be padded up to the longest length
                # with 1s before final conversion to batchd tensors
                stages[stage_id].input_points_mask.append(
                    torch.zeros(q.input_points.shape[1])
                )
            else:
                stages[stage_id].input_points.append(
                    torch.empty(0, input_points_embedding_dim)
                )
                stages[stage_id].input_points_mask.append(torch.empty(0))

            current_out_boxes = []
            current_out_object_ids = []
            # Set the object ids referred to by this query
            stages[stage_id].object_ids.append(q.object_ids_output)
            for object_id in q.object_ids_output:
                current_out_boxes.append(
                    data.images[q.image_id].objects[object_id].bbox
                )
                current_out_object_ids.append(object_id)
            find_targets[stage_id].boxes.extend(current_out_boxes)
            find_targets[stage_id].object_ids.extend(current_out_object_ids)
            if repeats > 0:
                for _ in range(repeats):
                    find_targets[stage_id].repeated_boxes.extend(current_out_boxes)
            find_targets[stage_id].num_boxes.append(len(current_out_boxes))
            find_targets[stage_id].is_exhaustive.append(q.is_exhaustive)

            if with_seg_masks:
                current_seg_mask = []
                current_is_valid_segment = []
                for object_id in q.object_ids_output:
                    seg_mask = data.images[q.image_id].objects[object_id].segment
                    if seg_mask is not None:
                        current_seg_mask.append(seg_mask)
                        current_is_valid_segment.append(1)
                    else:
                        # Keep positions aligned: substitute an all-zero mask
                        # and flag it invalid.
                        dummy_mask = torch.zeros(
                            data.images[q.image_id].data.shape[-2:], dtype=torch.bool
                        )
                        current_seg_mask.append(dummy_mask)
                        current_is_valid_segment.append(0)
                find_targets[stage_id].segments.extend(current_seg_mask)
                find_targets[stage_id].is_valid_segment.extend(current_is_valid_segment)
            else:
                # We are not loading segmentation masks
                find_targets[stage_id].segments = None
                find_targets[stage_id].is_valid_segment = None

            if q.semantic_target is not None:
                find_targets[stage_id].semantic_segments.append(q.semantic_target)

        offset_img_id += len(data.images)

    # Pad input points to equal sequence lengths
    for i in range(len(stages)):
        stages[i].input_points = pad_tensor_list_to_longest(
            stages[i].input_points, dim=0, pad_val=0
        )
        # Masked-out regions indicated by 1s.
        stages[i].input_points_mask = pad_tensor_list_to_longest(
            stages[i].input_points_mask, dim=0, pad_val=1
        )

    # Pad input boxes to equal sequence lengths
    for i in range(len(stages)):
        stages[i].input_boxes = pad_tensor_list_to_longest(
            stages[i].input_boxes, dim=0, pad_val=0
        )
        stages[i].input_boxes_label = pad_tensor_list_to_longest(
            stages[i].input_boxes_label, dim=0, pad_val=0
        )
        # Masked-out regions indicated by 1s.
        stages[i].input_boxes_mask = pad_tensor_list_to_longest(
            stages[i].input_boxes_mask, dim=0, pad_val=1
        )

    # Convert to tensors
    for i in range(len(stages)):
        stages[i] = convert_my_tensors(stages[i])
        find_targets[i] = convert_my_tensors(find_targets[i])
        find_metadatas[i] = convert_my_tensors(find_metadatas[i])
        # get padded representation for the boxes
        find_targets[i].boxes_padded = packed_to_padded_naive(
            find_targets[i].boxes.view(-1, 4), find_targets[i].num_boxes
        )
        # -1 marks padded (absent) object-id slots.
        find_targets[i].object_ids_padded = packed_to_padded_naive(
            find_targets[i].object_ids, find_targets[i].num_boxes, fill_value=-1
        )

    # Finalize the image batch
    # check sizes
    for img in img_batch[1:]:
        assert img.shape == img_batch[0].shape, "All images must have the same size"
    image_batch = torch.stack(img_batch)
    if load_image_in_fp16:
        # Optionally, cast the image tensors to fp16, which helps save GPU memory on
        # long videos with thousands of frames (where image tensors could be several GBs)
        image_batch = image_batch.half()

    return {
        dict_key: BatchedDatapoint(
            img_batch=image_batch,
            find_text_batch=text_batch,
            find_inputs=stages,
            find_targets=find_targets,
            find_metadatas=find_metadatas,
            raw_images=raw_images,
        )
    }
|
||||
528
sam3/train/data/sam3_image_dataset.py
Normal file
528
sam3/train/data/sam3_image_dataset.py
Normal file
@@ -0,0 +1,528 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
"""Dataset class for modulated detection"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import traceback
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import torchvision
|
||||
from decord import cpu, VideoReader
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from PIL.Image import DecompressionBombError
|
||||
|
||||
from sam3.model.box_ops import box_xywh_to_xyxy
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
|
||||
from .coco_json_loaders import COCO_FROM_JSON
|
||||
|
||||
|
||||
@dataclass
class InferenceMetadata:
    """Metadata required for postprocessing a single query's predictions."""

    # Coco id that corresponds to the "image" for evaluation by the coco evaluator
    # This is used for our own "class agnostic" evaluation
    coco_image_id: int

    # id in the original dataset, such that we can use the original evaluator
    original_image_id: int

    # Original category id (if we want to use the original evaluator)
    original_category_id: int

    # Size of the raw image (height, width)
    original_size: Tuple[int, int]

    # Id of the object in the media
    object_id: int

    # Index of the frame in the media (0 if single image)
    frame_index: int

    # Whether it is for conditioning only, e.g., 0-th frame in TA is for conditioning
    # as we assume GT available in frame 0.
    is_conditioning_only: Optional[bool] = False
|
||||
|
||||
|
||||
@dataclass
class FindQuery:
    """A single "find" request: locate the listed objects matching a text query."""

    # The text (noun phrase / category name) to look for
    query_text: str

    # Index of the image within the datapoint this query applies to
    image_id: int

    # In case of a find query, the list of object ids that have to be predicted
    object_ids_output: List[int]

    # This is "instance exhaustivity".
    # true iff all instances are separable and annotated
    # See below the slightly different "pixel exhaustivity"
    is_exhaustive: bool

    # The order in which the queries are processed (only meaningful for video)
    query_processing_order: int = 0

    # Input geometry, initially in denormalized XYXY format. Then
    # 1. converted to normalized CxCyWH by the Normalize transform
    input_bbox: Optional[torch.Tensor] = None
    input_bbox_label: Optional[torch.Tensor] = None

    # Only for the PVS task
    input_points: Optional[torch.Tensor] = None

    # Optional dense (semantic) target for this query
    semantic_target: Optional[torch.Tensor] = None

    # pixel exhaustivity: true iff the union of all segments (including crowds)
    # covers every pixel belonging to the target class
    # Note that instance_exhaustive implies pixel_exhaustive
    is_pixel_exhaustive: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
class FindQueryLoaded(FindQuery):
    """FindQuery augmented with the metadata needed at inference time."""

    # Must have default value since FindQuery has entries with default values
    inference_metadata: Optional[InferenceMetadata] = None
|
||||
|
||||
|
||||
@dataclass
class Object:
    """A single annotated object (box plus optional segment) in one frame."""

    # Initially in denormalized XYXY format, gets converted to normalized CxCyWH by the Normalize transform
    bbox: torch.Tensor
    # Area of the object
    area: float

    # Id of the object in the media
    object_id: Optional[int] = -1

    # Index of the frame in the media (0 if single image)
    frame_index: Optional[int] = -1

    segment: Optional[Union[torch.Tensor, dict]] = None  # RLE dict or binary mask

    # COCO-style crowd flag
    is_crowd: bool = False

    # Origin label for this annotation, if tracked — presumably a dataset
    # name; verify against callers.
    source: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class Image:
    """One frame together with all of its annotated objects."""

    # Pixel data, either as a tensor or a PIL image
    data: Union[torch.Tensor, PILImage.Image]
    # All objects annotated on this frame
    objects: List[Object]
    size: Tuple[int, int]  # (height, width)

    # For blurring augmentation
    blurring_mask: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
class Datapoint:
    """Refers to an image/video and all its annotations"""

    # All find queries attached to this image/video
    find_queries: List[FindQueryLoaded]
    # Frames of the media (a single entry for still images)
    images: List[Image]
    # Un-transformed PIL images, when available
    raw_images: Optional[List[PILImage.Image]] = None
|
||||
|
||||
|
||||
class CustomCocoDetectionAPI(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: str,
        annFile: str,
        load_segmentation: bool,
        fix_fname: bool = False,
        training: bool = True,
        blurring_masks_path: Optional[str] = None,
        use_caching: bool = True,
        zstd_dict_path=None,
        filter_query=None,
        coco_json_loader: Callable = COCO_FROM_JSON,
        limit_ids: int = None,
    ) -> None:
        # `load_segmentation`: keep RLE/mask data on loaded Objects.
        # `fix_fname`: strip directory components from annotation file names.
        # `limit_ids`: if set, keep only a deterministic random subset of ids.
        super().__init__(root)

        self.annFile = annFile
        self.use_caching = use_caching
        self.zstd_dict_path = zstd_dict_path

        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.load_segmentation = load_segmentation
        self.fix_fname = fix_fname
        self.filter_query = filter_query

        self.coco = None
        self.coco_json_loader = coco_json_loader
        self.limit_ids = limit_ids
        # Loads the annotation file immediately (populates self.coco, self.ids).
        self.set_sharded_annotation_file(0)
        self.training = training
        self.blurring_masks_path = blurring_masks_path

    def _load_images(
        self, datapoint_id: int, img_ids_to_load: Optional[Set[int]] = None
    ) -> Tuple[List[Tuple[int, PILImage.Image]], List[Dict[str, Any]]]:
        """Load the images of a datapoint (optionally restricted to a subset).

        Returns a list of ``(image_id, PIL image)`` pairs and the parallel list
        of per-image metadata dicts from the annotation file.
        """
        all_images = []
        all_img_metadata = []
        for current_meta in self.coco.loadImagesFromDatapoint(datapoint_id):
            img_id = current_meta["id"]
            if img_ids_to_load is not None and img_id not in img_ids_to_load:
                continue
            if self.fix_fname:
                current_meta["file_name"] = current_meta["file_name"].split("/")[-1]
            path = current_meta["file_name"]
            if self.blurring_masks_path is not None:
                # Sidecar mask file is optional; silently skipped when absent.
                mask_fname = os.path.basename(path).replace(".jpg", "-mask.json")
                mask_path = os.path.join(self.blurring_masks_path, mask_fname)
                if os.path.exists(mask_path):
                    with open(mask_path, "r") as fopen:
                        current_meta["blurring_mask"] = json.load(fopen)

            all_img_metadata.append(current_meta)
            path = os.path.join(self.root, path)
            try:
                if ".mp4" in path and path[-4:] == ".mp4":
                    # Going to load a video frame
                    # NOTE(review): assumes the path embeds the frame index
                    # separated by "@" — confirm the expected path format,
                    # since the path also ends with ".mp4" here.
                    video_path, frame = path.split("@")
                    video = VideoReader(video_path, ctx=cpu(0))
                    # Convert to PIL image
                    all_images.append(
                        (
                            img_id,
                            torchvision.transforms.ToPILImage()(
                                video[int(frame)].asnumpy()
                            ),
                        )
                    )
                else:
                    with g_pathmgr.open(path, "rb") as fopen:
                        all_images.append((img_id, PILImage.open(fopen).convert("RGB")))
            except FileNotFoundError as e:
                print(f"File not found: {path} from dataset: {self.annFile}")
                raise e

        return all_images, all_img_metadata

    def set_curr_epoch(self, epoch: int):
        # Record the current epoch so epoch-dependent transforms can use it.
        self.curr_epoch = epoch

    def set_epoch(self, epoch: int):
        # Hook for TorchDataset.get_loader; intentionally a no-op here.
        pass

    def set_sharded_annotation_file(self, data_epoch: int):
        """Load the annotation file and build the list of datapoint ids.

        Idempotent: returns immediately if annotations are already loaded.
        """
        if self.coco is not None:
            return

        assert g_pathmgr.isfile(
            self.annFile
        ), f"please provide valid annotation file. Missing: {self.annFile}"
        annFile = g_pathmgr.get_local_path(self.annFile)

        # NOTE(review): unreachable — the early return above guarantees
        # self.coco is None at this point.
        if self.coco is not None:
            del self.coco

        self.coco = self.coco_json_loader(annFile)
        # Use a torch tensor here to optimize memory usage when using several dataloaders
        ids_list = list(sorted(self.coco.getDatapointIds()))
        if self.limit_ids is not None:
            # Seed from the list length so the subsample is deterministic for
            # a given dataset.
            local_random = random.Random(len(ids_list))
            local_random.shuffle(ids_list)
            ids_list = ids_list[: self.limit_ids]
        self.ids = torch.as_tensor(ids_list, dtype=torch.long)

    def __getitem__(self, index: int) -> Datapoint:
        return self._load_datapoint(index)

    def _load_datapoint(self, index: int) -> Datapoint:
        """A separate method for easy overriding in subclasses."""
        id = self.ids[index].item()
        pil_images, img_metadata = self._load_images(id)
        queries, annotations = self.coco.loadQueriesAndAnnotationsFromDatapoint(id)
        return self.load_queries(pil_images, annotations, queries, img_metadata)

    def load_queries(self, pil_images, annotations, queries, img_metadata):
        """Transform the raw image and queries into a Datapoint sample."""
        images: List[Image] = []
        # Maps from annotation-file ids to positional indices in this sample.
        id2index_img = {}
        id2index_obj = {}
        # NOTE(review): populated below but never read in this method.
        id2index_find_query = {}
        id2imsize = {}
        assert len(pil_images) == len(img_metadata)
        for i in range(len(pil_images)):
            w, h = pil_images[i][1].size
            blurring_mask = None
            if "blurring_mask" in img_metadata[i]:
                blurring_mask = img_metadata[i]["blurring_mask"]
            images.append(
                Image(
                    data=pil_images[i][1],
                    objects=[],
                    size=(h, w),
                    blurring_mask=blurring_mask,
                )
            )
            id2index_img[pil_images[i][0]] = i
            id2imsize[pil_images[i][0]] = (h, w)

        for annotation in annotations:
            image_id = id2index_img[annotation["image_id"]]
            # Boxes are scaled by the image size below, so they are assumed to
            # be stored normalized (0..1) in XYWH — TODO confirm against the
            # ingestion format.
            bbox = box_xywh_to_xyxy(torch.as_tensor(annotation["bbox"])).view(1, 4)
            h, w = id2imsize[annotation["image_id"]]
            bbox[:, 0::2].mul_(w).clamp_(min=0, max=w)
            bbox[:, 1::2].mul_(h).clamp_(min=0, max=h)
            segment = None
            if self.load_segmentation and "segmentation" in annotation:
                # We're not decoding the RLE here, a transform will do it lazily later
                segment = annotation["segmentation"]
            images[image_id].objects.append(
                Object(
                    bbox=bbox[0],
                    area=annotation["area"],
                    object_id=(
                        annotation["object_id"] if "object_id" in annotation else -1
                    ),
                    frame_index=(
                        annotation["frame_index"] if "frame_index" in annotation else -1
                    ),
                    segment=segment,
                    is_crowd=(
                        annotation["is_crowd"] if "is_crowd" in annotation else None
                    ),
                    source=annotation["source"] if "source" in annotation else "",
                )
            )
            id2index_obj[annotation["id"]] = len(images[image_id].objects) - 1

        find_queries = []
        stage2num_queries = Counter()
        for i, query in enumerate(queries):
            stage2num_queries[query["query_processing_order"]] += 1
            id2index_find_query[query["id"]] = i

        # Sanity check: all the stages should have the same number of queries
        if len(stage2num_queries) == 0:
            num_queries_per_stage = 0
        else:
            num_queries_per_stage = stage2num_queries.most_common(1)[0][1]
        for stage, num_queries in stage2num_queries.items():
            assert (
                num_queries == num_queries_per_stage
            ), f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"

        for query_id, query in enumerate(queries):
            h, w = id2imsize[query["image_id"]]
            if (
                "input_box" in query
                and query["input_box"] is not None
                and len(query["input_box"]) > 0
            ):
                # Input boxes get the same normalized->pixel scaling as the
                # annotation boxes above.
                bbox = box_xywh_to_xyxy(torch.as_tensor(query["input_box"])).view(-1, 4)
                bbox[:, 0::2].mul_(w).clamp_(min=0, max=w)
                bbox[:, 1::2].mul_(h).clamp_(min=0, max=h)
                if "input_box_label" in query and query["input_box_label"] is not None:
                    bbox_label = torch.as_tensor(
                        query["input_box_label"], dtype=torch.long
                    ).view(-1)
                    assert len(bbox_label) == len(bbox)
                else:
                    # assume the boxes are positives
                    bbox_label = torch.ones(len(bbox), dtype=torch.long)
            else:
                bbox = None
                bbox_label = None

            if "input_points" in query and query["input_points"] is not None:
                # Points are (x, y, label) triplets; only x/y are scaled.
                points = torch.as_tensor(query["input_points"]).view(1, -1, 3)
                points[:, :, 0:1].mul_(w).clamp_(min=0, max=w)
                points[:, :, 1:2].mul_(h).clamp_(min=0, max=h)
            else:
                points = None

            # Fallback to -1 sentinels whenever the original ids are missing
            # or non-numeric in the metadata.
            try:
                original_image_id = int(
                    img_metadata[id2index_img[query["image_id"]]]["original_img_id"]
                )
            except ValueError:
                original_image_id = -1

            try:
                img_metadata_query = img_metadata[id2index_img[query["image_id"]]]
                coco_image_id = (
                    int(img_metadata_query["coco_img_id"])
                    if "coco_img_id" in img_metadata_query
                    else query["id"]
                )
            except KeyError:
                coco_image_id = -1

            try:
                original_category_id = int(query["original_cat_id"])
            except (ValueError, KeyError):
                original_category_id = -1

            # For evaluation, we associate the ids of the object to be tracked to the query
            if query["object_ids_output"]:
                obj_id = query["object_ids_output"][0]
                obj_idx = id2index_obj[obj_id]
                image_idx = id2index_img[query["image_id"]]
                object_id = images[image_idx].objects[obj_idx].object_id
                frame_index = images[image_idx].objects[obj_idx].frame_index
            else:
                object_id = -1
                frame_index = -1

            find_queries.append(
                FindQueryLoaded(
                    # id=query["id"],
                    # query_type=qtype,
                    query_text=(
                        query["query_text"] if query["query_text"] is not None else ""
                    ),
                    image_id=id2index_img[query["image_id"]],
                    input_bbox=bbox,
                    input_bbox_label=bbox_label,
                    input_points=points,
                    object_ids_output=[
                        id2index_obj[obj_id] for obj_id in query["object_ids_output"]
                    ],
                    is_exhaustive=query["is_exhaustive"],
                    is_pixel_exhaustive=(
                        query["is_pixel_exhaustive"]
                        if "is_pixel_exhaustive" in query
                        else (
                            query["is_exhaustive"] if query["is_exhaustive"] else None
                        )
                    ),
                    query_processing_order=query["query_processing_order"],
                    inference_metadata=InferenceMetadata(
                        # During training the eval-only ids are masked to -1.
                        coco_image_id=-1 if self.training else coco_image_id,
                        original_image_id=(-1 if self.training else original_image_id),
                        frame_index=frame_index,
                        original_category_id=original_category_id,
                        original_size=(h, w),
                        object_id=object_id,
                    ),
                )
            )

        return Datapoint(
            find_queries=find_queries,
            images=images,
            raw_images=[p[1] for p in pil_images],
        )

    def __len__(self) -> int:
        return len(self.ids)
class Sam3ImageDataset(CustomCocoDetectionAPI):
    """COCO-style image dataset with transforms and bad-sample skipping.

    On loading failure (corrupt image, too many queries/annotations) the
    offending index is skipped and the next one is tried, up to
    ``_MAX_RETRIES`` times.
    """

    def __init__(
        self,
        img_folder,
        ann_file,
        transforms,
        max_ann_per_img: int,
        multiplier: int,
        training: bool,
        load_segmentation: bool = False,
        max_train_queries: int = 81,
        max_val_queries: int = 300,
        fix_fname: bool = False,
        is_sharded_annotation_dir: bool = False,
        blurring_masks_path: Optional[str] = None,
        use_caching: bool = True,
        zstd_dict_path=None,
        filter_query=None,
        coco_json_loader: Callable = COCO_FROM_JSON,
        limit_ids: int = None,
    ):
        # `multiplier` scales the per-sample repeat factors (used by
        # repeat-factor sampling); `max_ann_per_img` bounds outputs per query.
        # NOTE(review): `is_sharded_annotation_dir` is accepted but unused here.
        super(Sam3ImageDataset, self).__init__(
            img_folder,
            ann_file,
            fix_fname=fix_fname,
            load_segmentation=load_segmentation,
            training=training,
            blurring_masks_path=blurring_masks_path,
            use_caching=use_caching,
            zstd_dict_path=zstd_dict_path,
            filter_query=filter_query,
            coco_json_loader=coco_json_loader,
            limit_ids=limit_ids,
        )

        self._transforms = transforms
        self.training = training
        self.max_ann_per_img = max_ann_per_img
        self.max_train_queries = max_train_queries
        self.max_val_queries = max_val_queries

        self.repeat_factors = torch.ones(len(self.ids), dtype=torch.float32)

        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.ids)}")

        # Upper bound on consecutive bad samples before giving up.
        self._MAX_RETRIES = 100

    def __getitem__(self, idx):
        # Indirection kept so subclasses can wrap/replace the retry logic.
        return self.__orig_getitem__(idx)

    def __orig_getitem__(self, idx):
        """Load datapoint `idx`, skipping forward past unloadable samples."""
        for _ in range(self._MAX_RETRIES):
            try:
                datapoint = super(Sam3ImageDataset, self).__getitem__(idx)

                # This can be done better by filtering the offending find queries
                # However, this requires care:
                # - Delete any find/get query that may depend on the deleted one
                # - Re-compute the indexes in the pointers to account for the deleted finds
                # DecompressionBombError is (re)used as the "skip this sample"
                # signal since it is already caught below for oversized images.
                for q in datapoint.find_queries:
                    if len(q.object_ids_output) > self.max_ann_per_img:
                        raise DecompressionBombError(
                            f"Too many outputs ({len(q.object_ids_output)})"
                        )

                max_queries = (
                    self.max_train_queries if self.training else self.max_val_queries
                )

                if len(datapoint.find_queries) > max_queries:
                    raise DecompressionBombError(
                        f"Too many find queries ({len(datapoint.find_queries)})"
                    )

                if len(datapoint.find_queries) == 0:
                    raise DecompressionBombError("No find queries")
                for transform in self._transforms:
                    datapoint = transform(datapoint, epoch=self.curr_epoch)

                break
            except (DecompressionBombError, OSError, ValueError) as error:
                sys.stderr.write(f"ERROR: got loading error on datapoint {idx}\n")
                sys.stderr.write(f"Exception: {error}\n")
                sys.stderr.write(traceback.format_exc())
                # Try the next sample (wrapping around the dataset).
                idx = (idx + 1) % len(self)
        else:
            # for/else: only reached when the loop exhausted all retries.
            raise RuntimeError(
                f"Failed {self._MAX_RETRIES} times trying to load an image."
            )

        return datapoint
327
sam3/train/data/sam3_video_dataset.py
Normal file
327
sam3/train/data/sam3_video_dataset.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import copy
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
# from decord import cpu, VideoReader
|
||||
|
||||
from iopath.common.file_io import PathManager
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from .sam3_image_dataset import Datapoint, Sam3ImageDataset
|
||||
|
||||
|
||||
SEED = 42
|
||||
|
||||
|
||||
class VideoGroundingDataset(Sam3ImageDataset):
    """Video grounding dataset: samples frame subsets from videos, or tiles a
    single image into a synthetic video."""

    def __init__(
        self,
        num_stages_sample: int = 4,
        stage_stride_min: int = 1,
        stage_stride_max: int = 5,
        random_reverse_time_axis: bool = True,
        is_tiling_single_image: bool = False,
        # By default, we remove find those queries with geometric inputs (input_box or input_points)
        # when creating synthetic videos from frames (since they are not *video-level* text prompts).
        # If we need them later, we can sample them on-the-fly via transforms or inside the model.
        tile_img_keep_find_queries_with_geo_inputs: bool = False,
        tile_img_keep_get_queries: bool = False,
        # the maximum number of find queries (for each frame) to keep in a video; if the datapoint
        # contains more queries per frame than this limit, we subsample them to avoid OOM errors
        max_query_num: int = -1,  # the default -1 means no limit
        # whether to override the "is_exhaustive" flag of the loaded find queries to True
        # (by default, our video datasets are ingested with is_exhaustive=False, since the YTVIS format
        # annotations doesn't involve an "is_exhaustive" flag; this means that those unmatched (negative)
        # detection queries or tracking queries do not receive a classification loss given that we have
        # weak_loss=True in IABCEMdetr -- this could lead to false positives for both image detection
        # and video association.)
        override_query_is_exhaustive_to_true: bool = False,
        # the maximum number of masklets in a video; if the datapoint contains more masklets
        # than this limit, we skip the datapoint to avoid OOM errors (this is useful for
        # training with large videos that contain many objects)
        max_masklet_num_in_video: int = 300,  # 300 masklets is usually OK to avoid OOM
        **kwargs,
    ):
        """
        Loading video grounding data

        Video frame sampling parameters (for training only):
        - num_stages_sample: number of frames to sample from the video during training
        - stage_stride_min: minimum stride between sampled frames during training
        - stage_stride_max: maximum stride between sampled frames during training (if it's
          greater than stage_stride_min, the actual stride is sampled uniformly between min
          and max; during inference, we always use all frames in the video with stride=1)
        - random_reverse_time_axis: whether to randomly invert the video's temporal axis
          (i.e. playing it backwards) during training
        """
        super().__init__(**kwargs)
        assert num_stages_sample >= 1
        assert stage_stride_min >= 1
        assert stage_stride_max >= stage_stride_min
        self.num_stages_sample = num_stages_sample
        self.stage_stride_min = stage_stride_min
        self.stage_stride_max = stage_stride_max
        self.random_reverse_time_axis = random_reverse_time_axis
        self.is_tiling_single_image = is_tiling_single_image
        self.tile_img_keep_find_queries_with_geo_inputs = (
            tile_img_keep_find_queries_with_geo_inputs
        )
        self.tile_img_keep_get_queries = tile_img_keep_get_queries
        self.max_query_num = max_query_num
        self.override_query_is_exhaustive_to_true = override_query_is_exhaustive_to_true
        self.max_masklet_num_in_video = max_masklet_num_in_video
        # Dedicated RNG, reseeded per epoch for reproducible frame sampling.
        self.rng = random.Random()
        self.set_curr_epoch(0)

    def set_curr_epoch(self, epoch: int):
        super().set_curr_epoch(epoch)
        self.rng.seed(SEED + epoch)

    def _load_datapoint(self, index: int) -> Datapoint:
        """Load a video datapoint, subsampling frames during training."""
        id = self.ids[index].item()
        queries, annotations = self.coco.loadQueriesAndAnnotationsFromDatapoint(id)

        # we subsample the video frames during training
        if self.training and not self.is_tiling_single_image:
            # pick a random stride for sampling query stages (`randint` includes both ends)
            stage_stride = self.rng.randint(
                self.stage_stride_min, self.stage_stride_max
            )
            stage_ids_to_keep = self._sample_stage_ids(
                queries, self.num_stages_sample, stage_stride
            )
            # filter the queries and annotations to keep only the selected stages
            # (also remap the stage ids so that they are contiguous and start from 0)
            reverse_time_axis = (
                self.rng.random() < 0.5 if self.random_reverse_time_axis else False
            )
            queries, annotations, kept_img_ids = self._filter_query_and_anns(
                queries,
                annotations,
                stage_ids_to_keep,
                remap_stage_id=True,
                reverse_time_axis=reverse_time_axis,
            )
            pil_images, img_metadata = self._load_images(id, kept_img_ids)
            if reverse_time_axis:
                # reverse the temporal ordering of the images and their metadata
                # so that the image order matches the query order
                pil_images = pil_images[::-1]
                img_metadata = img_metadata[::-1]
        else:
            pil_images, img_metadata = self._load_images(id)

        # check that all the images have the same image size (they are expected
        # to have the same image size since they are frames from the same video)
        assert all(p.size == pil_images[0][1].size for _, p in pil_images)

        queries.sort(key=lambda q: q["query_processing_order"])
        if self.override_query_is_exhaustive_to_true:
            for query in queries:
                query["is_exhaustive"] = True
        datapoint = self.load_queries(pil_images, annotations, queries, img_metadata)

        # skip datapoints with too many masklets to avoid OOM errors
        num_masklets_in_video = len(datapoint.images[0].objects)
        if num_masklets_in_video > self.max_masklet_num_in_video > 0:
            logging.warning(
                f"Datapoint {id} has ({num_masklets_in_video=}), exceeding "
                f"the maximum allowed ({self.max_masklet_num_in_video}). "
                "Skipping this datapoint."
            )
            next_index = (index + 1) % len(self)
            return self._load_datapoint(next_index)  # move to the next datapoint

        if self.is_tiling_single_image:
            datapoint = self._tile_single_image_data(datapoint, self.num_stages_sample)
        if self.max_query_num > 0:
            datapoint = self._subsample_queries(datapoint, self.max_query_num)

        # ensure that all find queries have the same processing order as their image id
        for query in datapoint.find_queries:
            assert query.image_id == query.query_processing_order, (
                f"find query has inconsistent image_id and "
                f"query_processing_order: {query.image_id=} vs "
                f"{query.query_processing_order=}"
            )
        return datapoint

    def _sample_stage_ids(self, queries, num_stages_sample, stage_stride):
        """Sample a subset of stage ids from all queries."""
        # Later we can perhaps turn it into a Sampler class to be more flexible.
        all_stage_ids = sorted(set(q["query_processing_order"] for q in queries))
        num_stages_total = len(all_stage_ids)
        if num_stages_total < num_stages_sample:
            raise ValueError("Not enough stages to sample")

        # the difference in index between the first and the last sampled stage ids
        b_e_gap = (num_stages_sample - 1) * stage_stride
        if b_e_gap > num_stages_total - 1:
            # In this case, it's not possible to sample with the provide stride,
            # so we use the maximum possible stride.
            prev_stage_stride = stage_stride
            stage_stride = math.floor((num_stages_total - 1) / (num_stages_sample - 1))
            logging.info(
                f"lowering stride from {prev_stage_stride} to {stage_stride} to "
                f"sample {num_stages_sample} stages (from {num_stages_total} total)"
            )
            b_e_gap = (num_stages_sample - 1) * stage_stride

        # randomly select a starting stage id (`randint` includes both ends)
        b_max = len(all_stage_ids) - 1 - b_e_gap
        b = self.rng.randint(0, b_max)
        e = b + b_e_gap
        stage_ids_to_keep = all_stage_ids[b : e + 1 : stage_stride]
        return stage_ids_to_keep

    def _filter_query_and_anns(
        self, queries, annotations, stage_ids_to_keep, remap_stage_id, reverse_time_axis
    ):
        """Filter queries and annotations to only keep those in `stage_ids_to_keep`."""
        stage_ids_to_keep = set(stage_ids_to_keep)
        kept_img_ids = set()
        kept_stage_ids = set()

        # Filter queries -- keep those queries with stage_id in `stage_ids_to_keep`
        filtered_queries = []
        for query in queries:
            input_box = query.get("input_box", None)
            input_points = query.get("input_points", None)
            has_geo_input = input_box is not None or input_points is not None
            if has_geo_input and not self.tile_img_keep_find_queries_with_geo_inputs:
                continue
            stage_id = query["query_processing_order"]
            if stage_id in stage_ids_to_keep:
                kept_img_ids.add(query["image_id"])
                kept_stage_ids.add(stage_id)
                filtered_queries.append(query)
        # Check that all frames in `stage_ids_to_keep` are present after filtering
        all_frame_present = kept_stage_ids == stage_ids_to_keep
        assert all_frame_present, f"{kept_stage_ids=} vs {stage_ids_to_keep=}"
        if remap_stage_id:
            # Remap those kept stage ids to be contiguous and starting from 0
            # (reversed enumeration implements the backwards-playback case).
            old_stage_ids = sorted(kept_stage_ids, reverse=reverse_time_axis)
            stage_id_old2new = {old: new for new, old in enumerate(old_stage_ids)}
            for query in filtered_queries:
                ptr_x_is_empty = query["ptr_x_query_id"] in [None, -1]
                ptr_y_is_empty = query["ptr_y_query_id"] in [None, -1]
                assert (
                    ptr_x_is_empty and ptr_y_is_empty
                ), "Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers"
                query["query_processing_order"] = stage_id_old2new[
                    query["query_processing_order"]
                ]

        # Filter annotations -- keep those annotations with image_id in `kept_img_ids`
        filtered_annotations = [
            ann for ann in annotations if ann["image_id"] in kept_img_ids
        ]

        return filtered_queries, filtered_annotations, kept_img_ids

    def _tile_single_image_data(self, datapoint: Datapoint, num_stages_sample: int):
        """
        Tile a single image and its queries to simulate video frames. The output is a
        datapoint with *identical video frames* (i.e. the same static image) and needs
        further transforms (e.g. affine) to get video frames with different content.
        """
        # tile `images: List[Image]`
        assert len(datapoint.images) == 1, "Expected only one single image"
        tiled_images = [
            copy.deepcopy(datapoint.images[0]) for _ in range(num_stages_sample)
        ]
        for stage_id, img in enumerate(tiled_images):
            for obj in img.objects:
                obj.frame_index = stage_id

        # tile `raw_images: Optional[List[PILImage.Image]] = None`
        tiled_raw_images = None
        if datapoint.raw_images is not None:
            assert len(datapoint.raw_images) == 1, "Expected only one single image"
            tiled_raw_images = [
                datapoint.raw_images[0].copy() for _ in range(num_stages_sample)
            ]

        # tile `find_queries: List[FindQueryLoaded]`
        tiled_find_queries_per_stage = [[] for _ in range(num_stages_sample)]
        for query in datapoint.find_queries:
            assert query.image_id == 0
            assert query.query_processing_order == 0
            # check and make sure that a query doesn't contain pointers or references
            # to other queries (that cannot be tiled)
            assert query.ptr_x is None and query.ptr_y is None
            assert query.ptr_mem is None
            # assert query.wkdata_qid is None
            # assert query.other_positive_qids is None
            # assert query.negative_qids is None
            has_geo_input = (
                query.input_bbox is not None or query.input_points is not None
            )
            if has_geo_input and not self.tile_img_keep_find_queries_with_geo_inputs:
                continue
            for stage_id in range(num_stages_sample):
                # copy the query and update the image_id
                new_query = copy.deepcopy(query)
                new_query.image_id = stage_id
                new_query.query_processing_order = stage_id
                if new_query.inference_metadata is not None:
                    new_query.inference_metadata.frame_index = stage_id
                tiled_find_queries_per_stage[stage_id].append(new_query)

        tiled_find_queries = sum(tiled_find_queries_per_stage, [])

        # tile `get_queries: List[GetQuery]` -- we skip them for now (since they involve
        # a pointer to a find query that is complicated to tile, and there is not an
        # imminent use case for them in the video grounding task in the near future)
        if self.tile_img_keep_get_queries:
            raise NotImplementedError("Tiling get queries is not implemented yet")
        else:
            tiled_get_queries = []

        # NOTE(review): the Datapoint dataclass defined in sam3_image_dataset
        # (as visible in this diff) has no `get_queries` field — confirm it
        # accepts this keyword, otherwise this raises TypeError.
        return Datapoint(
            images=tiled_images,
            raw_images=tiled_raw_images,
            find_queries=tiled_find_queries,
            get_queries=tiled_get_queries,
        )

    def _subsample_queries(self, datapoint: Datapoint, max_query_num: int):
        """Subsample to keep at most `max_query_num` queries per frame in a datapoint."""
        # aggregate the find queries per stage
        num_frames = max(q.query_processing_order for q in datapoint.find_queries) + 1
        find_queries_per_stage = [[] for _ in range(num_frames)]
        for query in datapoint.find_queries:
            find_queries_per_stage[query.query_processing_order].append(query)

        # verify that all the stages have the same number of queries
        num_queries_per_stage = len(find_queries_per_stage[0])
        for queries in find_queries_per_stage:
            assert len(queries) == num_queries_per_stage
        if max_query_num <= 0 or num_queries_per_stage <= max_query_num:
            return datapoint

        # subsample the queries to keep only `max_query_num` queries
        # (the same indices are kept on every frame to preserve tracking)
        sampled_inds = self.rng.sample(range(num_queries_per_stage), max_query_num)
        sampled_find_queries_per_stage = [
            [queries[idx] for idx in sampled_inds] for queries in find_queries_per_stage
        ]
        sampled_find_queries = sum(sampled_find_queries_per_stage, [])
        return Datapoint(
            images=datapoint.images,
            raw_images=datapoint.raw_images,
            find_queries=sampled_find_queries,
            get_queries=datapoint.get_queries,
        )
52
sam3/train/data/torch_dataset.py
Normal file
52
sam3/train/data/torch_dataset.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
from typing import Callable, Iterable, Optional
|
||||
|
||||
from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset
|
||||
|
||||
|
||||
class TorchDataset:
    """Thin wrapper around ``torch.utils.data.DataLoader`` handling optional
    distributed sampling and per-epoch bookkeeping.

    Args:
        dataset: Map-style dataset to load from (iterable-style datasets are
            not supported).
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes for the loader.
        shuffle: Whether to shuffle samples every epoch.
        pin_memory: Whether to use pinned (page-locked) host memory.
        drop_last: Whether to drop the final incomplete batch.
        collate_fn: Optional custom batch-collation function.
        worker_init_fn: Optional per-worker initialization function.
        enable_distributed_sampler: If True, shard the dataset across ranks
            with ``DistributedSampler`` (requires an initialized process group).
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        enable_distributed_sampler=True,
    ) -> None:
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn
        assert not isinstance(self.dataset, IterableDataset), "Not supported yet"
        if enable_distributed_sampler:
            # The sampler takes over shuffling (within each rank's shard).
            self.sampler = DistributedSampler(self.dataset, shuffle=self.shuffle)
        else:
            self.sampler = None

    def get_loader(self, epoch) -> Iterable:
        """Build a ``DataLoader`` for the given epoch.

        Propagates the epoch to the sampler (so distributed shuffling differs
        across epochs) and to the dataset when it exposes the corresponding
        attribute/hook.
        """
        if self.sampler:
            self.sampler.set_epoch(epoch)
        if hasattr(self.dataset, "epoch"):
            self.dataset.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            # Bug fix: `shuffle` was previously never forwarded, so it was
            # silently ignored whenever the distributed sampler was disabled.
            # It must be gated on `sampler is None` because DataLoader forbids
            # shuffle=True together with an explicit sampler.
            shuffle=(self.sampler is None and self.shuffle),
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
            sampler=self.sampler,
            collate_fn=self.collate_fn,
            worker_init_fn=self.worker_init_fn,
        )
Reference in New Issue
Block a user