Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
360
sam3/train/data/collator.py
Normal file
360
sam3/train/data/collator.py
Normal file
@@ -0,0 +1,360 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
|
||||
from typing import Any, get_args, get_origin, List, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.model.data_misc import (
|
||||
BatchedDatapoint,
|
||||
BatchedFindTarget,
|
||||
BatchedInferenceMetadata,
|
||||
FindStage,
|
||||
)
|
||||
|
||||
from .sam3_image_dataset import Datapoint
|
||||
|
||||
|
||||
MyTensor = Union[torch.Tensor, List[Any]]
|
||||
|
||||
|
||||
def convert_my_tensors(obj):
|
||||
def is_optional_field(field) -> bool:
|
||||
return get_origin(field) is Union and type(None) in get_args(field)
|
||||
|
||||
for field in fields(obj):
|
||||
if is_dataclass(getattr(obj, field.name)):
|
||||
convert_my_tensors(getattr(obj, field.name))
|
||||
continue
|
||||
|
||||
field_type = field.type
|
||||
if is_optional_field(field.type):
|
||||
field_type = Union[get_args(field.type)[:-1]] # Get the Optional field type
|
||||
|
||||
if field_type != MyTensor or getattr(obj, field.name) is None:
|
||||
continue
|
||||
|
||||
elif len(getattr(obj, field.name)) and isinstance(
|
||||
getattr(obj, field.name)[0], torch.Tensor
|
||||
):
|
||||
stack_dim = 0
|
||||
if field.name in [
|
||||
"input_boxes",
|
||||
"input_boxes_label",
|
||||
]:
|
||||
stack_dim = 1
|
||||
setattr(
|
||||
obj,
|
||||
field.name,
|
||||
torch.stack(getattr(obj, field.name), dim=stack_dim).to(
|
||||
getattr(obj, field.name + "__type")
|
||||
),
|
||||
)
|
||||
else:
|
||||
setattr(
|
||||
obj,
|
||||
field.name,
|
||||
torch.as_tensor(
|
||||
getattr(obj, field.name), dtype=getattr(obj, field.name + "__type")
|
||||
),
|
||||
)
|
||||
return obj
|
||||
|
||||
|
||||
def packed_to_padded_naive(boxes_packed, num_boxes, fill_value=0):
|
||||
"""
|
||||
Convert a packed tensor of bounding boxes to a padded tensor of bounding
|
||||
boxes. Naive implementation using a loop.
|
||||
|
||||
Inputs:
|
||||
- boxes_packed: Tensor of shape (N_1 + ... + N_B, 4)
|
||||
- num_boxes: Tensor of shape (B,) where num_boxes[i] = N_i
|
||||
|
||||
Returns:
|
||||
- boxes_padded: Tensor of shape (B, N_max, 4) where N_max = max_i N_i
|
||||
"""
|
||||
B = num_boxes.shape[0]
|
||||
Ns = num_boxes.tolist()
|
||||
|
||||
boxes_padded = boxes_packed.new_zeros(B, max(Ns), *boxes_packed.shape[1:])
|
||||
if fill_value != 0:
|
||||
boxes_padded[...] = fill_value
|
||||
prev_idx = 0
|
||||
for i in range(B):
|
||||
next_idx = prev_idx + Ns[i]
|
||||
boxes_padded[i, : Ns[i]] = boxes_packed[prev_idx:next_idx]
|
||||
prev_idx = next_idx
|
||||
return boxes_padded
|
||||
|
||||
|
||||
def pad_tensor_list_to_longest(
|
||||
tensors: List[torch.Tensor], dim=0, pad_val=0
|
||||
) -> List[torch.Tensor]:
|
||||
# Edits the list in-place
|
||||
if not tensors:
|
||||
return tensors
|
||||
pad_len = max(t.shape[dim] for t in tensors)
|
||||
for i in range(len(tensors)):
|
||||
n_dims = len(tensors[i].shape)
|
||||
n_right_dims = (n_dims - 1) - (n_dims + dim) % n_dims
|
||||
n_pad = pad_len - tensors[i].shape[dim]
|
||||
pad_tuple = tuple([0] * 2 * n_right_dims + [0, n_pad])
|
||||
tensors[i] = torch.nn.functional.pad(tensors[i], pad_tuple, value=pad_val)
|
||||
return tensors
|
||||
|
||||
|
||||
def collate_fn_api_with_chunking(
|
||||
batch,
|
||||
num_chunks,
|
||||
dict_key,
|
||||
with_seg_masks=False,
|
||||
input_points_embedding_dim=257,
|
||||
repeats: int = 0,
|
||||
load_image_in_fp16: bool = False,
|
||||
):
|
||||
assert num_chunks >= 1, "num_chunks must be >= 1"
|
||||
|
||||
# split the batch into num_chunks chunks
|
||||
batch_chunks = [batch[i::num_chunks] for i in range(num_chunks)]
|
||||
|
||||
# collate each chunk
|
||||
collated_chunks = [
|
||||
collate_fn_api(
|
||||
chunk,
|
||||
dict_key,
|
||||
with_seg_masks,
|
||||
input_points_embedding_dim,
|
||||
repeats,
|
||||
# ptr_behaviour,
|
||||
load_image_in_fp16,
|
||||
)
|
||||
for chunk in batch_chunks
|
||||
]
|
||||
return collated_chunks
|
||||
|
||||
|
||||
def collate_fn_api(
|
||||
batch: List[Datapoint],
|
||||
dict_key,
|
||||
with_seg_masks=False,
|
||||
input_points_embedding_dim=257,
|
||||
repeats: int = 0,
|
||||
load_image_in_fp16: bool = False,
|
||||
):
|
||||
# img_batch = torch.stack(sum([[img.data for img in v.images] for v in batch], []))
|
||||
img_batch = []
|
||||
text_batch = []
|
||||
raw_images = None
|
||||
|
||||
num_stages = (
|
||||
max(q.query_processing_order for data in batch for q in data.find_queries) + 1
|
||||
)
|
||||
|
||||
stages = [
|
||||
FindStage(
|
||||
img_ids=[],
|
||||
text_ids=[],
|
||||
input_boxes=[],
|
||||
input_boxes_label=[],
|
||||
input_boxes_mask=[],
|
||||
input_points=[],
|
||||
input_points_mask=[],
|
||||
object_ids=[],
|
||||
)
|
||||
for _ in range(num_stages)
|
||||
]
|
||||
find_targets = [
|
||||
BatchedFindTarget(
|
||||
num_boxes=[],
|
||||
boxes=[],
|
||||
boxes_padded=[],
|
||||
is_exhaustive=[],
|
||||
segments=[],
|
||||
semantic_segments=[],
|
||||
is_valid_segment=[],
|
||||
repeated_boxes=[],
|
||||
object_ids=[],
|
||||
object_ids_padded=[],
|
||||
)
|
||||
for _ in range(num_stages)
|
||||
]
|
||||
find_metadatas = [
|
||||
BatchedInferenceMetadata(
|
||||
coco_image_id=[],
|
||||
original_size=[],
|
||||
object_id=[],
|
||||
frame_index=[],
|
||||
original_image_id=[],
|
||||
original_category_id=[],
|
||||
is_conditioning_only=[],
|
||||
)
|
||||
for _ in range(num_stages)
|
||||
]
|
||||
|
||||
offset_img_id = 0
|
||||
offset_query_id = [0 for _ in range(num_stages)]
|
||||
for i, data in enumerate(batch):
|
||||
img_batch.extend([img.data for img in data.images])
|
||||
|
||||
if data.raw_images is not None:
|
||||
if raw_images is None:
|
||||
raw_images = []
|
||||
raw_images.extend(data.raw_images)
|
||||
|
||||
# Conversion of query_ids indexing in a datapoint to query_ids indexing in a stage
|
||||
datapoint_query_id_2_stage_query_id = []
|
||||
for q in data.find_queries:
|
||||
stage_id = q.query_processing_order
|
||||
datapoint_query_id_2_stage_query_id.append(offset_query_id[stage_id])
|
||||
offset_query_id[stage_id] += 1
|
||||
|
||||
for j, q in enumerate(data.find_queries):
|
||||
stage_id = q.query_processing_order
|
||||
stages[stage_id].img_ids.append(q.image_id + offset_img_id)
|
||||
if q.query_text not in text_batch:
|
||||
text_batch.append(q.query_text)
|
||||
stages[stage_id].text_ids.append(text_batch.index(q.query_text))
|
||||
|
||||
assert (
|
||||
q.inference_metadata is not None
|
||||
), "inference_metadata must be provided when FindQueryLoaded is created."
|
||||
for f in fields(q.inference_metadata):
|
||||
getattr(find_metadatas[stage_id], f.name).append(
|
||||
getattr(q.inference_metadata, f.name)
|
||||
)
|
||||
|
||||
if q.input_bbox is not None:
|
||||
assert q.input_bbox.numel() % 4 == 0
|
||||
assert q.input_bbox_label is not None
|
||||
nb_boxes = q.input_bbox.numel() // 4
|
||||
assert len(q.input_bbox_label) == nb_boxes
|
||||
stages[stage_id].input_boxes.append(q.input_bbox.view(nb_boxes, 4))
|
||||
stages[stage_id].input_boxes_label.append(
|
||||
q.input_bbox_label.view(nb_boxes)
|
||||
)
|
||||
stages[stage_id].input_boxes_mask.append(
|
||||
torch.zeros(nb_boxes, dtype=torch.bool)
|
||||
)
|
||||
else:
|
||||
stages[stage_id].input_boxes.append(torch.zeros(0, 4))
|
||||
stages[stage_id].input_boxes_label.append(
|
||||
torch.zeros(0, dtype=torch.bool)
|
||||
)
|
||||
stages[stage_id].input_boxes_mask.append(
|
||||
torch.ones(0, dtype=torch.bool)
|
||||
)
|
||||
|
||||
if q.input_points is not None:
|
||||
stages[stage_id].input_points.append(
|
||||
q.input_points.squeeze(0) # Strip a trivial batch index
|
||||
)
|
||||
# All masks will be padded up to the longest length
|
||||
# with 1s before final conversion to batchd tensors
|
||||
stages[stage_id].input_points_mask.append(
|
||||
torch.zeros(q.input_points.shape[1])
|
||||
)
|
||||
else:
|
||||
stages[stage_id].input_points.append(
|
||||
torch.empty(0, input_points_embedding_dim)
|
||||
)
|
||||
stages[stage_id].input_points_mask.append(torch.empty(0))
|
||||
|
||||
current_out_boxes = []
|
||||
current_out_object_ids = []
|
||||
# Set the object ids referred to by this query
|
||||
stages[stage_id].object_ids.append(q.object_ids_output)
|
||||
for object_id in q.object_ids_output:
|
||||
current_out_boxes.append(
|
||||
data.images[q.image_id].objects[object_id].bbox
|
||||
)
|
||||
current_out_object_ids.append(object_id)
|
||||
find_targets[stage_id].boxes.extend(current_out_boxes)
|
||||
find_targets[stage_id].object_ids.extend(current_out_object_ids)
|
||||
if repeats > 0:
|
||||
for _ in range(repeats):
|
||||
find_targets[stage_id].repeated_boxes.extend(current_out_boxes)
|
||||
find_targets[stage_id].num_boxes.append(len(current_out_boxes))
|
||||
find_targets[stage_id].is_exhaustive.append(q.is_exhaustive)
|
||||
|
||||
if with_seg_masks:
|
||||
current_seg_mask = []
|
||||
current_is_valid_segment = []
|
||||
for object_id in q.object_ids_output:
|
||||
seg_mask = data.images[q.image_id].objects[object_id].segment
|
||||
if seg_mask is not None:
|
||||
current_seg_mask.append(seg_mask)
|
||||
current_is_valid_segment.append(1)
|
||||
else:
|
||||
dummy_mask = torch.zeros(
|
||||
data.images[q.image_id].data.shape[-2:], dtype=torch.bool
|
||||
)
|
||||
current_seg_mask.append(dummy_mask)
|
||||
current_is_valid_segment.append(0)
|
||||
find_targets[stage_id].segments.extend(current_seg_mask)
|
||||
find_targets[stage_id].is_valid_segment.extend(current_is_valid_segment)
|
||||
else:
|
||||
# We are not loading segmentation masks
|
||||
find_targets[stage_id].segments = None
|
||||
find_targets[stage_id].is_valid_segment = None
|
||||
|
||||
if q.semantic_target is not None:
|
||||
find_targets[stage_id].semantic_segments.append(q.semantic_target)
|
||||
|
||||
offset_img_id += len(data.images)
|
||||
|
||||
# Pad input points to equal sequence lengths
|
||||
for i in range(len(stages)):
|
||||
stages[i].input_points = pad_tensor_list_to_longest(
|
||||
stages[i].input_points, dim=0, pad_val=0
|
||||
)
|
||||
# Masked-out regions indicated by 1s.
|
||||
stages[i].input_points_mask = pad_tensor_list_to_longest(
|
||||
stages[i].input_points_mask, dim=0, pad_val=1
|
||||
)
|
||||
|
||||
# Pad input boxes to equal sequence lengths
|
||||
for i in range(len(stages)):
|
||||
stages[i].input_boxes = pad_tensor_list_to_longest(
|
||||
stages[i].input_boxes, dim=0, pad_val=0
|
||||
)
|
||||
stages[i].input_boxes_label = pad_tensor_list_to_longest(
|
||||
stages[i].input_boxes_label, dim=0, pad_val=0
|
||||
)
|
||||
# Masked-out regions indicated by 1s.
|
||||
stages[i].input_boxes_mask = pad_tensor_list_to_longest(
|
||||
stages[i].input_boxes_mask, dim=0, pad_val=1
|
||||
)
|
||||
|
||||
# Convert to tensors
|
||||
for i in range(len(stages)):
|
||||
stages[i] = convert_my_tensors(stages[i])
|
||||
find_targets[i] = convert_my_tensors(find_targets[i])
|
||||
find_metadatas[i] = convert_my_tensors(find_metadatas[i])
|
||||
# get padded representation for the boxes
|
||||
find_targets[i].boxes_padded = packed_to_padded_naive(
|
||||
find_targets[i].boxes.view(-1, 4), find_targets[i].num_boxes
|
||||
)
|
||||
find_targets[i].object_ids_padded = packed_to_padded_naive(
|
||||
find_targets[i].object_ids, find_targets[i].num_boxes, fill_value=-1
|
||||
)
|
||||
|
||||
# Finalize the image batch
|
||||
# check sizes
|
||||
for img in img_batch[1:]:
|
||||
assert img.shape == img_batch[0].shape, "All images must have the same size"
|
||||
image_batch = torch.stack(img_batch)
|
||||
if load_image_in_fp16:
|
||||
# Optionally, cast the image tensors to fp16, which helps save GPU memory on
|
||||
# long videos with thousands of frames (where image tensors could be several GBs)
|
||||
image_batch = image_batch.half()
|
||||
|
||||
return {
|
||||
dict_key: BatchedDatapoint(
|
||||
img_batch=image_batch,
|
||||
find_text_batch=text_batch,
|
||||
find_inputs=stages,
|
||||
find_targets=find_targets,
|
||||
find_metadatas=find_metadatas,
|
||||
raw_images=raw_images,
|
||||
)
|
||||
}
|
||||
Reference in New Issue
Block a user