Initial commit

fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
facebook-github-bot
2025-11-18 23:07:42 -08:00
commit a13e358df4
504 changed files with 122758 additions and 0 deletions

View File

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

View File

@@ -0,0 +1,455 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
Transforms and data augmentation for both image + bbox.
"""
import math
import random
from typing import Iterable
import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from sam3.model.box_ops import box_xyxy_to_cxcywh
from sam3.model.data_misc import interpolate
def crop(image, target, region):
    """Crop `image` to `region` and clip/filter `target` annotations.

    Args:
        image: a PIL image.
        target: annotation dict (boxes in absolute XYXY, optional masks,
            labels, areas, ...).
        region: (top, left, height, width), as produced by torchvision's
            ``RandomCrop.get_params``.

    Returns:
        (cropped_image, cropped_target); instances whose box/mask becomes
        empty after the crop are removed from all per-instance fields.
    """
    cropped_image = F.crop(image, *region)
    target = target.copy()
    i, j, h, w = region
    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])
    # Per-instance fields that must be filtered together with the boxes.
    fields = ["labels", "area", "iscrowd", "positive_map"]
    if "boxes" in target:
        boxes = target["boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        # Shift into crop coordinates, then clip to the crop window.
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32)
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")
    if "input_boxes" in target:
        # Same clipping for prompt boxes; note they are never filtered out.
        boxes = target["input_boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32)
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        target["input_boxes"] = cropped_boxes.reshape(-1, 4)
    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target["masks"] = target["masks"][:, i : i + h, j : j + w]
        fields.append("masks")
    # remove elements for which the boxes or masks that have zero area
    if "boxes" in target or "masks" in target:
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target["boxes"].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target["masks"].flatten(1).any(1)
        for field in fields:
            if field in target:
                target[field] = target[field][keep]
    return cropped_image, target
def hflip(image, target):
    """Horizontally flip the image and mirror all annotations.

    Boxes ("boxes" and "input_boxes", absolute XYXY) are reflected about
    the vertical axis (column reorder keeps x0 <= x1), masks are flipped on
    their last axis, and the words "left"/"right" in the text prompt are
    swapped.
    """
    flipped_image = F.hflip(image)
    w, h = image.size
    target = target.copy()
    reflect = torch.as_tensor([-1, 1, -1, 1])
    offset = torch.as_tensor([w, 0, w, 0])
    for key in ("boxes", "input_boxes"):
        if key in target:
            target[key] = target[key][:, [2, 1, 0, 3]] * reflect + offset
    if "masks" in target:
        target["masks"] = target["masks"].flip(-1)
    if "text_input" in target:
        # Swap via a temporary placeholder so the second replace does not
        # clobber the result of the first.
        swapped = (
            target["text_input"]
            .replace("left", "[TMP]")
            .replace("right", "left")
            .replace("[TMP]", "right")
        )
        target["text_input"] = swapped
    return flipped_image, target
def resize(image, target, size, max_size=None, square=False):
    """Resize `image` and scale `target` annotations to match.

    Args:
        image: a PIL image.
        target: annotation dict, or None to resize the image only.
        size: min-side length (int) or an explicit (w, h) tuple/list.
        max_size: optional cap on the longer side when `size` is an int.
        square: if True, resize to `size` x `size`, ignoring aspect ratio.

    Returns:
        (rescaled_image, rescaled_target)
    """
    # size can be min_size (scalar) or (w, h) tuple
    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        # Shorter side becomes `size`, preserving aspect ratio, while
        # keeping the longer side within `max_size` (if given).
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            if max_original_size / min_original_size * size > max_size:
                size = int(round(max_size * min_original_size / max_original_size))
        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)
        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)
        return (oh, ow)

    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            # Explicit (w, h) requested; F.resize expects (h, w).
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    if square:
        size = size, size
    else:
        size = get_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size)
    if target is None:
        return rescaled_image, None
    # Per-axis scale factors, computed from the actual resized image size.
    ratios = tuple(
        float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)
    )
    ratio_width, ratio_height = ratios
    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        scaled_boxes = boxes * torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32
        )
        target["boxes"] = scaled_boxes
    if "input_boxes" in target:
        boxes = target["input_boxes"]
        scaled_boxes = boxes * torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32
        )
        target["input_boxes"] = scaled_boxes
    if "area" in target:
        area = target["area"]
        # Areas scale by the product of the two per-axis ratios.
        scaled_area = area * (ratio_width * ratio_height)
        target["area"] = scaled_area
    h, w = size
    target["size"] = torch.tensor([h, w])
    if "masks" in target:
        # Nearest-neighbor resize of binary masks, re-binarized at 0.5.
        target["masks"] = (
            interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0]
            > 0.5
        )
    return rescaled_image, target
def pad(image, target, padding):
    """Pad `image` and update `target` to match.

    Args:
        image: a PIL image.
        target: annotation dict, or None to pad the image only.
        padding: either (right, bottom) — pad only the bottom-right — or
            (left, top, right, bottom).

    Returns:
        (padded_image, padded_target)
    """
    if len(padding) == 2:
        # assumes that we only pad on the bottom right corners
        padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
    else:
        # left, top, right, bottom
        padded_image = F.pad(image, (padding[0], padding[1], padding[2], padding[3]))
    if target is None:
        return padded_image, None
    target = target.copy()
    w, h = padded_image.size
    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])
    # Boxes shift only when padding is added on the left/top (4-tuple case).
    if "boxes" in target and len(padding) == 4:
        boxes = target["boxes"]
        boxes = boxes + torch.as_tensor(
            [padding[0], padding[1], padding[0], padding[1]], dtype=torch.float32
        )
        target["boxes"] = boxes
    if "input_boxes" in target and len(padding) == 4:
        boxes = target["input_boxes"]
        boxes = boxes + torch.as_tensor(
            [padding[0], padding[1], padding[0], padding[1]], dtype=torch.float32
        )
        target["input_boxes"] = boxes
    if "masks" in target:
        # torch.nn.functional.pad order on the last two dims is
        # (left, right, top, bottom), unlike torchvision's F.pad above.
        if len(padding) == 2:
            target["masks"] = torch.nn.functional.pad(
                target["masks"], (0, padding[0], 0, padding[1])
            )
        else:
            target["masks"] = torch.nn.functional.pad(
                target["masks"], (padding[0], padding[2], padding[1], padding[3])
            )
    return padded_image, target
class RandomCrop:
    """Crop a random fixed-size region; torchvision samples the window."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        window = T.RandomCrop.get_params(img, self.size)  # (top, left, h, w)
        return crop(img, target, window)
class RandomSizeCrop:
    """Random crop with side lengths sampled in [min_size, max_size].

    When `respect_boxes` is True, the crop window is constrained so that no
    ground-truth box in `target["boxes"]` is cropped out (asserted after
    cropping).
    """

    def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
        self.min_size = min_size
        self.max_size = max_size
        self.respect_boxes = respect_boxes  # if True we can't crop a box out

    def __call__(self, img: PIL.Image.Image, target: dict):
        init_boxes = len(target["boxes"])
        init_boxes_tensor = target["boxes"].clone()
        if self.respect_boxes and init_boxes > 0:
            # Bounds for the crop window size.
            # BUGFIX: minH previously used img.width — the minimum crop
            # height must be bounded by the image *height*.
            minW, minH, maxW, maxH = (
                min(img.width, self.min_size),
                min(img.height, self.min_size),
                min(img.width, self.max_size),
                min(img.height, self.max_size),
            )
            # Rightmost box left edge / bottommost box top edge (+margin):
            # the crop must start no later than these to cover every box.
            minX, minY = (
                target["boxes"][:, 0].max().item() + 10.0,
                target["boxes"][:, 1].max().item() + 10.0,
            )
            minX = min(img.width, minX)
            minY = min(img.height, minY)
            # Leftmost box right edge / topmost box bottom edge (-margin).
            maxX, maxY = (
                target["boxes"][:, 2].min().item() - 10,
                target["boxes"][:, 3].min().item() - 10,
            )
            maxX = max(0.0, maxX)
            maxY = max(0.0, maxY)
            # Window must be large enough to span the box extents.
            minW = max(minW, minX - maxX)
            minH = max(minH, minY - maxY)
            w = random.uniform(minW, max(minW, maxW))
            h = random.uniform(minH, max(minH, maxH))
            if minX > maxX:
                i = random.uniform(max(0, minX - w), max(maxX, max(0, minX - w)))
            else:
                i = random.uniform(
                    max(0, minX - w + 1), max(maxX - 1, max(0, minX - w + 1))
                )
            if minY > maxY:
                j = random.uniform(max(0, minY - h), max(maxY, max(0, minY - h)))
            else:
                j = random.uniform(
                    max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1))
                )
            # crop() expects (top, left, height, width).
            result_img, result_target = crop(img, target, [j, i, h, w])
            assert (
                len(result_target["boxes"]) == init_boxes
            ), f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}"
            return result_img, result_target
        else:
            # Unconstrained: sample a window size, then a uniform position.
            w = random.randint(self.min_size, min(img.width, self.max_size))
            h = random.randint(self.min_size, min(img.height, self.max_size))
            region = T.RandomCrop.get_params(img, (h, w))
            result_img, result_target = crop(img, target, region)
            return result_img, result_target
class CenterCrop:
    """Deterministically crop the central region of the given (h, w) size."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        img_w, img_h = img.size
        out_h, out_w = self.size
        top = int(round((img_h - out_h) / 2.0))
        left = int(round((img_w - out_w) / 2.0))
        return crop(img, target, (top, left, out_h, out_w))
class RandomHorizontalFlip:
    """Apply `hflip` with probability `p` (default 0.5)."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        do_flip = random.random() < self.p
        return hflip(img, target) if do_flip else (img, target)
class RandomResize:
    """Resize to a size picked uniformly at random from `sizes`.

    A single int is treated as a one-element list of sizes; `max_size` and
    `square` are forwarded to `resize`.
    """

    def __init__(self, sizes, max_size=None, square=False):
        if isinstance(sizes, int):
            sizes = (sizes,)
        assert isinstance(sizes, Iterable)
        self.sizes = list(sizes)
        self.max_size = max_size
        self.square = square

    def __call__(self, img, target=None):
        chosen = random.choice(self.sizes)
        return resize(img, target, chosen, self.max_size, square=self.square)
class RandomPad:
    """Pad the right/bottom by independent random amounts in [0, max_pad]."""

    def __init__(self, max_pad):
        self.max_pad = max_pad

    def __call__(self, img, target):
        pad_right = random.randint(0, self.max_pad)
        pad_bottom = random.randint(0, self.max_pad)
        return pad(img, target, (pad_right, pad_bottom))
class PadToSize:
    """Pad an image up to a square `size` x `size` with random placement.

    The total horizontal/vertical padding is fixed by the target size; the
    left/right (and top/bottom) split is chosen uniformly at random.
    Requires the image to already fit within the target size.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        w, h = img.size
        total_x = self.size - w
        total_y = self.size - h
        assert total_x >= 0 and total_y >= 0
        left = random.randint(0, total_x)
        top = random.randint(0, total_y)
        return pad(img, target, (left, top, total_x - left, total_y - top))
class Identity:
    """No-op transform: returns its inputs unchanged."""

    def __call__(self, img, target):
        return (img, target)
class RandomSelect:
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2.
    A missing transform defaults to the no-op Identity.
    """

    def __init__(self, transforms1=None, transforms2=None, p=0.5):
        self.transforms1 = transforms1 or Identity()
        self.transforms2 = transforms2 or Identity()
        self.p = p

    def __call__(self, img, target):
        chosen = self.transforms1 if random.random() < self.p else self.transforms2
        return chosen(img, target)
class ToTensor:
    """Convert the PIL image to a torch tensor; the target passes through."""

    def __call__(self, img, target):
        tensor_img = F.to_tensor(img)
        return tensor_img, target
class RandomErasing:
    """Wrapper around torchvision's RandomErasing; leaves the target alone."""

    def __init__(self, *args, **kwargs):
        self.eraser = T.RandomErasing(*args, **kwargs)

    def __call__(self, img, target):
        erased = self.eraser(img)
        return erased, target
class Normalize:
    """Normalize the image and convert boxes to normalized cxcywh.

    Both "boxes" and "input_boxes" are converted from absolute XYXY to
    center-size format and scaled into [0, 1] by the image width/height.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None
        target = target.copy()
        h, w = image.shape[-2:]
        scale = torch.tensor([w, h, w, h], dtype=torch.float32)
        for key in ("boxes", "input_boxes"):
            if key in target:
                target[key] = box_xyxy_to_cxcywh(target[key]) / scale
        return image, target
class RemoveDifficult:
    """Optionally drop crowd ("difficult") annotations.

    When enabled, instances flagged in `target["iscrowd"]` are removed from
    boxes/labels/iscrowd; when disabled everything is kept unchanged.
    """

    def __init__(self, enabled=False):
        self.remove_difficult = enabled

    def __call__(self, image, target=None):
        if target is None:
            return image, None
        target = target.copy()
        crowd = target["iscrowd"].to(torch.bool)
        # Keep everything unless filtering is enabled; else keep non-crowd.
        keep = ~crowd | (not self.remove_difficult)
        if "boxes" in target:
            target["boxes"] = target["boxes"][keep]
            target["labels"] = target["labels"][keep]
            target["iscrowd"] = target["iscrowd"][keep]
        return image, target
class Compose:
    """Chain several (image, target) transforms, applied left to right."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target

    def __repr__(self):
        body = "".join(f"\n    {t}" for t in self.transforms)
        return f"{self.__class__.__name__}({body}\n)"
def get_random_resize_scales(size, min_size, rounded):
    """Candidate resize scales from `min_size` up to `size`.

    The lower bound is rounded up to a multiple of the stride (128 when
    `rounded`, else 32), and scales step by that stride.
    """
    step = 128 if rounded else 32
    start = int(step * math.ceil(min_size / step))
    return list(range(start, size + 1, step))
def get_random_resize_max_size(size, ratio=5 / 3):
    """Largest allowed long side for a base `size` (default 5/3 of it)."""
    return round(ratio * size)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,607 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import logging
import random
from collections import defaultdict
from typing import List, Optional, Union
import torch
from sam3.train.data.sam3_image_dataset import Datapoint, FindQuery, Object
class FilterDataPointQueries:
    """Base class for transforms that drop find/get queries from a Datapoint.

    Subclasses implement `identify_queries_to_filter`, populating
    `find_ids_to_filter` and `obj_ids_to_filter`; the actual removal is
    performed elsewhere (the filter object is queried via
    `_do_filter_query`).
    """

    # Indices into datapoint.find_queries marked for removal.
    find_ids_to_filter: Optional[set] = None
    # Indices of get queries to remove (not populated by visible subclasses).
    get_ids_to_filter: Optional[set] = None
    # Objects to remove, stored as pairs (img_id, obj_id).
    obj_ids_to_filter: Optional[set] = None

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        """
        Compute set of query ids to keep, for both find and get queries
        """
        raise NotImplementedError

    def _do_filter_query(self, query: FindQuery, query_id: int) -> bool:
        # True if the find query at `query_id` was marked for removal.
        assert self.find_ids_to_filter is not None
        return query_id in self.find_ids_to_filter
class FilterQueryWithText(FilterDataPointQueries):
    """
    Filter out find queries whose text appears in a list of excluded terms.

    `exclude_get_keys` is stored for interface compatibility but is not
    currently used (no get-query filtering is performed here).
    """

    def __init__(
        self, exclude_find_keys: List[str] = None, exclude_get_keys: List[str] = None
    ):
        self.find_filter_keys = exclude_find_keys if exclude_find_keys else []
        self.get_filter_keys = exclude_get_keys if exclude_get_keys else []

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()
        # Drop any find query whose exact text is in the exclusion list.
        # (Removed unused `del_get_ids` accumulator from the original.)
        self.find_ids_to_filter = {
            i
            for i, f_q in enumerate(datapoint.find_queries)
            if f_q.query_text in self.find_filter_keys
        }
class KeepMaxNumFindQueries(FilterDataPointQueries):
    """Randomly subsample find queries down to `max_num_find_queries`.

    If `retain_positive_queries` is True, positive queries (those with a
    non-empty `object_ids_output`) are kept in preference to negative ones.
    """

    def __init__(
        self, max_num_find_queries: int, retain_positive_queries: bool = False
    ):
        self.max_num_find_queries = max_num_find_queries
        self.retain_positive_queries = retain_positive_queries

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        self.obj_ids_to_filter = set()
        num_find_queries = len(datapoint.find_queries)
        if num_find_queries <= self.max_num_find_queries:
            self.find_ids_to_filter = set()  # keep all find queries
            return
        if not self.retain_positive_queries:
            # Uniform subsample, ignoring positive/negative status.
            all_find_query_ids = list(range(num_find_queries))
            num_queries_to_filter = max(0, num_find_queries - self.max_num_find_queries)
            query_ids_to_filter = random.sample(
                all_find_query_ids, k=num_queries_to_filter
            )
        else:
            # keep up to max_num_find_queries positive find queries and fill
            # the remaining slots (if any) with negative find queries
            pos_find_ids, neg_find_ids = [], []
            for i, f_q in enumerate(datapoint.find_queries):
                # Negative finds return an empty list of object_ids_output
                if len(f_q.object_ids_output) == 0:
                    neg_find_ids.append(i)
                else:
                    pos_find_ids.append(i)
            if len(pos_find_ids) >= self.max_num_find_queries:
                # we have more positive find queries than `max_num_find_queries`,
                # so we subsample positive find queries and remove all negative find queries
                num_queries_to_filter = len(pos_find_ids) - self.max_num_find_queries
                query_ids_to_filter = random.sample(
                    pos_find_ids, k=num_queries_to_filter
                )
                query_ids_to_filter.extend(neg_find_ids)
            else:
                # we have fewer positive find queries than `max_num_find_queries`
                # so we need to fill the remaining with negative find queries
                num_queries_to_filter = num_find_queries - self.max_num_find_queries
                query_ids_to_filter = random.sample(
                    neg_find_ids, k=num_queries_to_filter
                )
        # Exactly the excess number of queries must have been marked.
        assert len(query_ids_to_filter) == num_find_queries - self.max_num_find_queries
        self.find_ids_to_filter = set(query_ids_to_filter)
class KeepMaxNumFindQueriesVideo(FilterDataPointQueries):
    """Per-frame cap on find queries for (mosaic) video datapoints.

    Keeps at most `video_mosaic_max_num_find_queries_per_frame` find queries
    per frame. Which queries to keep is decided on frame 0 and the same
    positional choice is replicated on every other frame.

    NOTE(review): the replication below assumes frame image ids are exactly
    0..num_frames-1 and that every frame carries the same number/order of
    find queries as frame 0 — TODO confirm against the mosaic builder.
    """

    def __init__(
        self,
        video_mosaic_max_num_find_queries_per_frame: int,
        retain_positive_queries: bool = False,
    ):
        self.video_mosaic_max_num_find_queries_per_frame = (
            video_mosaic_max_num_find_queries_per_frame
        )
        self.retain_positive_queries = retain_positive_queries

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        self.obj_ids_to_filter = set()
        num_find_queries = len(datapoint.find_queries)
        # Group find-query indices by the frame (image id) they belong to.
        findQueries_to_imageIds = defaultdict(list)
        # Stays True while every frame is within the per-frame budget.
        max_queries_per_frame = True
        for i, f_q in enumerate(datapoint.find_queries):
            findQueries_to_imageIds[f_q.image_id].append(i)
            if (
                len(findQueries_to_imageIds[f_q.image_id])
                > self.video_mosaic_max_num_find_queries_per_frame
            ):
                max_queries_per_frame = False
        if max_queries_per_frame:
            self.find_ids_to_filter = set()  # all frames fit the budget
            return
        num_frames = len(findQueries_to_imageIds)
        # All sampling decisions are made on frame 0 and replicated below.
        findQueries_0 = findQueries_to_imageIds[0]
        num_find_queries_0 = len(findQueries_0)
        max_num_find_queries_per_frame = (
            self.video_mosaic_max_num_find_queries_per_frame
        )
        if not self.retain_positive_queries:
            # Uniform subsample of frame-0 positions.
            find_query_ids_0 = list(range(num_find_queries_0))
            num_queries_to_filter = max(
                0, num_find_queries_0 - max_num_find_queries_per_frame
            )
            query_ids_to_filter_0 = random.sample(
                find_query_ids_0, k=num_queries_to_filter
            )
        else:
            # keep up to max_num_find_queries positive find queries and fill
            # the remaining slots (if any) with negative find queries
            pos_find_ids_0, neg_find_ids_0 = [], []
            for i, f_q_id in enumerate(findQueries_0):
                f_q = datapoint.find_queries[f_q_id]
                # Negative finds return an empty list of object_ids_output
                if len(f_q.object_ids_output) == 0:
                    neg_find_ids_0.append(i)
                else:
                    pos_find_ids_0.append(i)
            if len(pos_find_ids_0) >= max_num_find_queries_per_frame:
                # we have more positive find queries than `max_num_find_queries`,
                # so we subsample positive find queries and remove all negative find queries
                num_queries_to_filter = (
                    len(pos_find_ids_0) - max_num_find_queries_per_frame
                )
                query_ids_to_filter_0 = random.sample(
                    pos_find_ids_0, k=num_queries_to_filter
                )
                query_ids_to_filter_0.extend(neg_find_ids_0)
            else:
                # we have fewer positive find queries than `max_num_find_queries`
                # so we need to fill the remaining with negative find queries
                num_queries_to_filter = (
                    num_find_queries_0 - max_num_find_queries_per_frame
                )
                query_ids_to_filter_0 = random.sample(
                    neg_find_ids_0, k=num_queries_to_filter
                )
        # get based on frame 0 all find queries from all the frames with the same indices as in frame 0
        query_ids_to_filter = []
        for i in range(num_frames):
            findQueries_i = findQueries_to_imageIds[i]
            query_ids_to_filter.extend(
                [findQueries_i[j] for j in query_ids_to_filter_0]
            )
        # Sanity check: only holds when every frame matches frame 0's count.
        assert (
            len(query_ids_to_filter)
            == num_find_queries
            - self.video_mosaic_max_num_find_queries_per_frame * num_frames
        )
        self.find_ids_to_filter = set(query_ids_to_filter)
class KeepSemanticFindQueriesOnly(FilterDataPointQueries):
    """Drop geometric find queries (those carrying an `input_bbox`)."""

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        self.obj_ids_to_filter = set()
        marked = set()
        for idx, query in enumerate(datapoint.find_queries):
            # Geometric queries are identified by a non-None input box.
            if query.input_bbox is not None:
                marked.add(idx)
        self.find_ids_to_filter = marked
        # Get queries that don't depend on filtered finds are all kept.
class KeepUnaryFindQueriesOnly(FilterDataPointQueries):
    """Keep every find query — nothing is marked for filtering here."""

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        # Both filter sets stay empty; all queries and objects survive.
        self.obj_ids_to_filter = set()
        self.find_ids_to_filter = set()
class FilterZeroBoxQueries(FilterDataPointQueries):
    """
    Drop every find query that outputs a degenerate (zero-area) box.
    """

    @staticmethod
    def _is_zero_area_object(obj: Object):
        # Degenerate if either side of the XYXY box has zero length.
        x0, y0, x1, y1 = (obj.bbox[..., k].item() for k in range(4))
        return (y1 - y0) == 0 or (x1 - x0) == 0

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()
        # Assume a single image per datapoint; collect its zero-area objects.
        bad_objects = {
            idx
            for idx, obj in enumerate(datapoint.images[0].objects)
            if self._is_zero_area_object(obj)
        }
        # A find query touching any zero-area object is removed entirely.
        self.find_ids_to_filter = {
            qid
            for qid, query in enumerate(datapoint.find_queries)
            if bad_objects & set(query.object_ids_output)
        }
class FilterFindQueriesWithTooManyOut(FilterDataPointQueries):
    """
    Drop find queries whose output lists more than `max_num_objects` objects.
    """

    def __init__(self, max_num_objects: int):
        self.max_num_objects = max_num_objects

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()
        limit = self.max_num_objects
        self.find_ids_to_filter = {
            qid
            for qid, query in enumerate(datapoint.find_queries)
            if len(query.object_ids_output) > limit
        }
class FilterEmptyTargets(FilterDataPointQueries):
    """
    Mark (img_id, obj_id) pairs whose object area is numerically zero.
    """

    def identify_queries_to_filter(self, datapoint):
        self.find_ids_to_filter = set()
        empties = set()
        for img_id, image in enumerate(datapoint.images):
            for obj_id, obj in enumerate(image.objects):
                # Areas below this epsilon count as empty targets.
                if obj.area < 1e-6:
                    empties.add((img_id, obj_id))
        self.obj_ids_to_filter = empties
class FilterNonExhaustiveFindQueries(FilterDataPointQueries):
    """
    Drop find queries that are not exhaustively annotated.
    """

    def __init__(self, exhaustivity_type: str):
        """
        Args:
            exhaustivity_type: Can be "pixel" or "instance":
                -pixel: filter queries where the union of all segments covers every pixel belonging to target class
                -instance: filter queries where there are non-separable or non annotated instances
                Note that instance exhaustivity implies pixel exhaustivity
        """
        assert exhaustivity_type in ["pixel", "instance"]
        self.exhaustivity_type = exhaustivity_type

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()
        dropped = []
        for qid, query in enumerate(datapoint.find_queries):
            if self.exhaustivity_type == "instance":
                if not query.is_exhaustive:
                    dropped.append(qid)
            elif self.exhaustivity_type == "pixel":
                # A missing (None) pixel-exhaustivity flag counts as exhaustive.
                if query.is_pixel_exhaustive is not None and not query.is_pixel_exhaustive:
                    dropped.append(qid)
            else:
                raise RuntimeError(
                    f"Unknown exhaustivity type {self.exhaustivity_type}"
                )
        self.find_ids_to_filter = set(dropped)
class FilterInvalidGeometricQueries(FilterDataPointQueries):
    """
    Drop geometric queries whose output objects were all removed (eg by cropping).
    """

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()
        dropped = set()
        for qid, query in enumerate(datapoint.find_queries):
            is_geometric = (
                query.input_bbox is not None and query.query_text == "geometric"
            )
            # A geometric query with no surviving outputs is invalid.
            if is_geometric and len(query.object_ids_output) == 0:
                dropped.add(qid)
        self.find_ids_to_filter = dropped
class FlexibleFilterFindGetQueries:
    """Apply a `FilterDataPointQueries` policy to a datapoint, in place.

    Pipeline: (1) run the policy's `identify_queries_to_filter`; (2) remove
    the marked find queries; (3) re-compact query processing orders; (4)
    drop objects that are no longer referenced (plus those the policy
    marked) and remap surviving object ids in each find query; (5) remove
    images left without any find query and remap image ids.

    Raises:
        ValueError: if filtering removes every find query.
        AssertionError: if no surviving query has processing order 0.
    """

    def __init__(
        self, query_filter: FilterDataPointQueries, enabled: bool = True
    ) -> None:
        self.query_filter = query_filter
        self.enabled = enabled

    def __call__(self, datapoint, **kwargs):
        if not self.enabled:
            return datapoint
        # Identify all queries to filter
        self.query_filter.identify_queries_to_filter(datapoint=datapoint)
        del_find_ids = []
        del_get_ids = []  # NOTE(review): collected nowhere, never used below
        for i, f_q in enumerate(datapoint.find_queries):
            if self.query_filter._do_filter_query(f_q, i):
                # Mark in place with None; compacted in the next pass.
                datapoint.find_queries[i] = None
                del_find_ids.append(i)
        new_find_queries = []
        new_get_queries = []  # NOTE(review): unused, kept as-is
        find_old_to_new_map = {}
        get_old_to_new_map = {}  # NOTE(review): unused, kept as-is
        find_counter = 0
        get_counter = 0  # NOTE(review): unused, kept as-is
        for i, f_q in enumerate(datapoint.find_queries):
            if f_q is not None:
                find_old_to_new_map[i] = find_counter
                find_counter += 1
                new_find_queries.append(f_q)
        # Surviving queries must still contain a first-stage
        # (query_processing_order == 0) query, unless none are left at all.
        start_with_zero_check = False
        for n_f_q in new_find_queries:
            if n_f_q.query_processing_order == 0:
                start_with_zero_check = True
                break
        if len(new_find_queries) == 0:
            start_with_zero_check = True
        assert (
            start_with_zero_check
        ), "Invalid Find queries, they need to start at query_processing_order = 0"
        datapoint.find_queries = new_find_queries
        if len(datapoint.find_queries) == 0:
            print("Warning: No find queries left in datapoint, this is not allowed")
            print("Filtering function:", self.query_filter)
            print("Datapoint:", datapoint)
            raise ValueError
        # The deletion may have removed intermediate steps, so we need to remap to make them contiguous again
        all_stages = sorted(
            list(set(q.query_processing_order for q in datapoint.find_queries))
        )
        stage_map = {qpo: i for i, qpo in enumerate(all_stages)}
        for i in range(len(datapoint.find_queries)):
            qpo = datapoint.find_queries[i].query_processing_order
            datapoint.find_queries[i].query_processing_order = stage_map[qpo]
        # Final step, clear up objects that are not used anymore
        for img_id in range(len(datapoint.images)):
            # Object ids still referenced by some find query on this image.
            all_objects_ids = set(
                i
                for find in datapoint.find_queries
                for i in find.object_ids_output
                if find.image_id == img_id
            )
            unused_ids = (
                set(range(len(datapoint.images[img_id].objects))) - all_objects_ids
            )
            # Also drop objects explicitly marked by the filter policy.
            for tgt_img_id, tgt_obj_id in self.query_filter.obj_ids_to_filter:
                if tgt_img_id == img_id:
                    unused_ids.add(tgt_obj_id)
            if len(unused_ids) > 0:
                # Compact the object list and remember old -> new indices.
                old_objects = datapoint.images[img_id].objects
                object_old_to_new_map = {}
                new_objects = []
                for i, o in enumerate(old_objects):
                    if i not in unused_ids:
                        object_old_to_new_map[i] = len(new_objects)
                        new_objects.append(o)
                datapoint.images[img_id].objects = new_objects
                # Remap the outputs of the find queries
                affected_find_queries_ids = set()
                object_old_to_new_map_per_query = {}
                for fid, find in enumerate(datapoint.find_queries):
                    if find.image_id == img_id:
                        old_object_ids_output = find.object_ids_output
                        object_old_to_new_map_per_query[fid] = {}
                        find.object_ids_output = []
                        for oid, old_obj_id in enumerate(old_object_ids_output):
                            if old_obj_id not in unused_ids:
                                new_obj_id = object_old_to_new_map[old_obj_id]
                                find.object_ids_output.append(new_obj_id)
                                object_old_to_new_map_per_query[fid][oid] = (
                                    len(find.object_ids_output) - 1
                                )
                        affected_find_queries_ids.add(fid)
        # finally remove unused images
        all_imgs_to_keep = set()
        for f_q in datapoint.find_queries:
            all_imgs_to_keep.add(f_q.image_id)
        old_img_id_to_new_img_id = {}
        new_images = []
        for img_id, img in enumerate(datapoint.images):
            if img_id in all_imgs_to_keep:
                old_img_id_to_new_img_id[img_id] = len(new_images)
                new_images.append(img)
        datapoint.images = new_images
        for f_q in datapoint.find_queries:
            f_q.image_id = old_img_id_to_new_img_id[f_q.image_id]
        return datapoint
class AddPrefixSuffixToFindText:
    """
    Add prefix or suffix strings to find query text on the fly.

    When `condition_on_text` is set, only queries whose lower-cased,
    stripped text appears in `condition_text_list` are modified. Queries
    with the special text "geometric" are never touched.
    """

    def __init__(
        self,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
        condition_on_text: bool = False,
        condition_text_list: Optional[List[str]] = None,
        enabled: bool = True,
    ) -> None:
        self.prefix = prefix
        self.suffix = suffix
        self.condition_on_text = condition_on_text
        if self.condition_on_text:
            assert condition_text_list is not None
            self.condition_text_set = {s.lower().strip() for s in condition_text_list}
        self.enabled = enabled
        if self.enabled:
            logging.info(
                f"AddPrefixSuffixToFindText: prefix={prefix}, suffix={suffix}, "
                f"condition_on_text={condition_on_text}, condition_text_list={condition_text_list}"
            )

    def __call__(self, datapoint, **kwargs):
        if not self.enabled:
            return datapoint
        for find in datapoint.find_queries:
            # Never decorate geometric (box/point) queries.
            if find.query_text == "geometric":
                continue
            normalized = find.query_text.lower().strip()
            if self.condition_on_text and normalized not in self.condition_text_set:
                continue
            # Wrap the query text with the configured prefix/suffix.
            pieces = [self.prefix or "", find.query_text, self.suffix or ""]
            find.query_text = "".join(pieces)
        return datapoint
class FilterCrowds(FilterDataPointQueries):
    """Mark every crowd-flagged object for removal; keep all find queries."""

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        """
        Compute set of query ids to keep, for both find and get queries
        """
        self.find_ids_to_filter = set()
        crowd_pairs = set()
        for img_id, img in enumerate(datapoint.images):
            for obj_id, obj in enumerate(img.objects):
                if obj.is_crowd:
                    crowd_pairs.add((img_id, obj_id))
        self.obj_ids_to_filter = crowd_pairs
class TextQueryToVisual:
    """
    With probability `probability`, turn a text find query into a visual one
    by using the box of a randomly chosen output object as the prompt.
    Unless `keep_text_queries` is set, the query text becomes "visual".
    """

    def __init__(self, probability, keep_text_queries=False) -> None:
        self.probability = probability
        assert 0 <= probability <= 1
        self.keep_text_queries = keep_text_queries

    def __call__(self, datapoint: Datapoint, **kwargs):
        for find in datapoint.find_queries:
            # Skip queries that are already geometric...
            if find.input_bbox is not None or find.input_points is not None:
                continue
            # ...have no output object to prompt with...
            if len(find.object_ids_output) == 0:
                continue
            # ...or belong to a later processing stage.
            if find.query_processing_order > 0:
                continue
            if random.random() > self.probability:
                continue
            chosen_obj = random.choice(find.object_ids_output)
            objects = datapoint.images[find.image_id].objects
            find.input_bbox = objects[chosen_obj].bbox
            find.input_bbox_label = torch.ones(1, dtype=torch.bool)
            if not self.keep_text_queries:
                find.query_text = "visual"
        return datapoint
class RemoveInputBoxes:
    """
    Strip the input box from every find query that has one.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, datapoint: Datapoint, **kwargs):
        for find in datapoint.find_queries:
            if find.input_bbox is not None:
                if find.query_text == "geometric":
                    # A geometric query without its box loses its prompt.
                    print("Warning: removing input box from geometric find query")
                find.input_bbox = None
        return datapoint
class OverwriteTextQuery:
    """
    Independently for each find query, replace its text with `target_text`
    with probability `probability` (default: always).
    """

    def __init__(self, target_text, probability=1.0) -> None:
        self.probability = probability
        self.target_text = target_text
        assert 0 <= probability <= 1

    def __call__(self, datapoint: Datapoint, **kwargs):
        for find in datapoint.find_queries:
            if random.random() <= self.probability:
                find.query_text = self.target_text
        return datapoint

View File

@@ -0,0 +1,345 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import cv2
import numpy as np
import torch
from PIL import Image as PILImage
from pycocotools import mask as mask_util
from sam3.train.data.sam3_image_dataset import Datapoint
from torchvision.ops import masks_to_boxes
def sample_points_from_rle(rle, n_points, mode, box=None, normalize=True):
    """
    Sample random points from a mask provided in COCO RLE format.
    'mode' is in ["centered", "random_mask", "random_box"]:
      "centered": points are sampled farthest from the mask edges and each other
      "random_mask": points are sampled uniformly from the mask
      "random_box": points are sampled uniformly from the annotation's box
    'box' must be provided if 'mode' is "random_box".
    If 'normalize' is true, point coords are scaled to [0,1] relative to the
    mask h,w (the trailing label column is left untouched).
    """
    decoded = np.ascontiguousarray(mask_util.decode(rle))
    points = sample_points_from_mask(decoded, n_points, mode, box)
    if not normalize:
        return points
    height, width = decoded.shape
    # Divide x by width, y by height; the third (label) column by 1.
    return points / np.array([width, height, 1.0])[None, :]
def sample_points_from_mask(mask, n_points, mode, box=None):
    """Dispatch point sampling on `mode`; see `sample_points_from_rle`."""
    if mode == "centered":
        return center_positive_sample(mask, n_points)
    if mode == "random_mask":
        return uniform_positive_sample(mask, n_points)
    if mode == "random_box":
        assert box is not None, "'random_box' mode requires a provided box."
        return uniform_sample_from_box(mask, box, n_points)
    raise ValueError(f"Unknown point sampling mode {mode}.")
def uniform_positive_sample(mask, n_points):
    """
    Sample positive points uniformly (with replacement) from the mask's
    foreground pixels. Only integer pixel coordinates are produced; each
    returned row is (x, y, 1), where the trailing 1 is the positive label.
    """
    # (N, 2) array of (y, x) foreground coordinates.
    coords = np.argwhere(mask)
    assert len(coords) > 0, "Can't sample positive points from an empty mask."
    chosen = coords[np.random.randint(low=0, high=len(coords), size=n_points)]
    chosen = chosen[:, ::-1]  # (y, x) -> (x, y)
    # Append a column of positive (1) labels.
    return np.concatenate([chosen, np.ones((len(chosen), 1))], axis=1)
def center_positive_sample(mask, n_points):
    """
    Sample points farthest from the mask edges (via distance transform),
    with each subsequent point also farthest from previously chosen ones:
    every selected point is zeroed out so later iterations avoid it. Image
    borders count as mask edges.
    """
    # One pixel of zero padding on every side so the distance transform
    # treats the image border as background.
    padded = np.pad(mask, 1)
    chosen = []
    for _ in range(n_points):
        assert np.max(mask) > 0, "Can't sample positive points from an empty mask."
        dist = cv2.distanceTransform(padded, cv2.DIST_L2, 0)
        py, px = np.unravel_index(dist.argmax(), dist.shape)
        # Mark the selection as background so the next iteration avoids it.
        padded[py, px] = 0
        chosen.append((px, py))  # (y, x) -> (x, y)
    pts = np.stack(chosen, axis=0) - 1  # undo the left/top padding offset
    labels = np.ones((len(pts), 1))
    return np.concatenate([pts, labels], axis=1)
def uniform_sample_from_box(mask, box, n_points):
    """
    Sample integer-coordinate points uniformly from 'box' (unnormalized
    XYXY). Each point's label is read from 'mask', so a positive point is
    not guaranteed. Returned rows are (x, y, label).
    """
    # Lower/right edges are exclusive, so ceil is safe on all four edges.
    x0, y0, x1, y1 = np.ceil(box)
    xs = np.random.randint(low=x0, high=x1, size=n_points)
    ys = np.random.randint(low=y0, high=y1, size=n_points)
    return np.stack([xs, ys, mask[ys, xs]], axis=1)
def rescale_box_xyxy(box, factor, imsize=None):
    """
    Scale an unnormalized XYXY box by 'factor' about its center.
    If 'imsize' (h, w) is provided, clamp the result to the image bounds.
    """
    center_x = (box[0] + box[2]) / 2
    center_y = (box[1] + box[3]) / 2
    half_w = factor * (box[2] - box[0]) / 2
    half_h = factor * (box[3] - box[1]) / 2
    x0, x1 = center_x - half_w, center_x + half_w
    y0, y1 = center_y - half_h, center_y + half_h
    if imsize is not None:
        # Clamp x to [0, width] and y to [0, height].
        x0 = max(min(x0, imsize[1]), 0)
        x1 = max(min(x1, imsize[1]), 0)
        y0 = max(min(y0, imsize[0]), 0)
        y1 = max(min(y1, imsize[0]), 0)
    return [x0, y0, x1, y1]
def noise_box(box, im_size, box_noise_std, box_noise_max, min_box_area):
    """
    Jitter an XYXY box with Gaussian noise proportional to its side lengths.
    Noise is optionally truncated to 'box_noise_max' absolute pixels and the
    result is clamped to 'im_size' (h, w). If the noised box's area drops to
    'min_box_area' or below, the original box is returned unchanged.
    """
    if box_noise_std <= 0.0:
        return box
    w, h = box[2] - box[0], box[3] - box[1]
    # Noise is drawn per coordinate and scaled by the matching side length.
    jitter = box_noise_std * torch.randn(size=(4,)) * torch.tensor([w, h, w, h])
    if box_noise_max is not None:
        jitter = torch.clamp(jitter, -box_noise_max, box_noise_max)
    noised = box + jitter
    # Clamp into the image: x coordinates to [0, w], y coordinates to [0, h].
    bounds = torch.tensor([im_size[1], im_size[0], im_size[1], im_size[0]])
    noised = torch.minimum(torch.maximum(noised, torch.zeros_like(noised)), bounds)
    if (noised[2] - noised[0]) * (noised[3] - noised[1]) <= min_box_area:
        return box
    return noised
class RandomGeometricInputsAPI:
    """
    For geometric queries, replaces the input box or points with a random
    one sampled from the GT mask. Segments must be provided for objects
    that are targets of geometric queries, and must be binary masks. Existing
    point and box queries in the datapoint will be ignored and completely replaced.
    Will sample points and boxes in XYXY format in absolute pixel space.

    Geometry queries are currently determined by taking any query whose
    query text is a set value.

    Args:
        num_points (int or (int, int)): how many points to sample. If a tuple
            or list, sample a random number of points uniformly over the
            inclusive range.
        box_chance (float): fraction of time a box is sampled. A box will replace
            one sampled point.
        box_noise_std (float): if greater than 0, add noise to the sampled boxes
            with this std. Noise is relative to the length of the box side.
        box_noise_max (int): if not none, truncate any box noise larger than this
            in terms of absolute pixels.
        resample_box_from_mask (bool): if True, any sampled box will be determined
            by finding the extrema of the provided mask. If False, the bbox provided
            in the target object will be used.
        point_sample_mode (str): In ["centered", "random_mask", "random_box"],
            controlling how points are sampled:
            "centered": points are sampled farthest from the mask edges and each other
            "random_mask": points are sampled uniformly from the mask
            "random_box": points are sampled uniformly from the annotation's box
            Note that "centered" may be too slow for on-line generation.
        geometric_query_str (str): what string in query_text indicates a
            geometry query.
        minimum_box_area (float): sampled boxes with area this size or smaller after
            noising will use the original box instead. It is the input's responsibility
            to avoid original boxes that violate necessary area bounds.
        concat_points (bool): if True, any sampled points will be added to existing
            ones instead of replacing them.
    """

    def __init__(
        self,
        num_points,
        box_chance,
        box_noise_std=0.0,
        box_noise_max=None,
        minimum_box_area=0.0,
        resample_box_from_mask=False,
        point_sample_mode="random_mask",
        sample_box_scale_factor=1.0,
        geometric_query_str="geometric",
        concat_points=False,
    ):
        if isinstance(num_points, int):
            self.num_points = num_points
        else:
            # Convert the inclusive (lo, hi) range to the exclusive upper
            # bound expected by torch.randint. Built as a fresh tuple so that
            # tuples are accepted (the previous in-place `num_points[1] += 1`
            # raised TypeError on tuples) and a caller's list is not mutated.
            lo, hi = num_points
            self.num_points = (lo, hi + 1)
        self.box_chance = box_chance
        self.box_noise_std = box_noise_std
        self.box_noise_max = box_noise_max
        self.minimum_box_area = minimum_box_area
        self.resample_box_from_mask = resample_box_from_mask
        self.point_sample_mode = point_sample_mode
        assert point_sample_mode in [
            "centered",
            "random_mask",
            "random_box",
        ], "Unknown point sample mode."
        self.geometric_query_str = geometric_query_str
        self.concat_points = concat_points
        self.sample_box_scale_factor = sample_box_scale_factor

    def _sample_num_points_and_if_box(self):
        """Draw the number of points to sample and whether a box replaces one."""
        if isinstance(self.num_points, tuple):
            n_points = torch.randint(
                low=self.num_points[0], high=self.num_points[1], size=(1,)
            ).item()
        else:
            n_points = self.num_points
        if self.box_chance > 0.0:
            use_box = torch.rand(size=(1,)).item() < self.box_chance
            n_points -= int(use_box)  # box stands in for one point
        else:
            use_box = False
        return n_points, use_box

    def _get_original_box(self, target_object):
        """Return the target's box, optionally recomputed from its mask extrema."""
        if not self.resample_box_from_mask:
            return target_object.bbox
        mask = target_object.segment
        return masks_to_boxes(mask[None, :, :])[0]

    def _get_target_object(self, datapoint, query):
        """Return the single object a geometric query points at."""
        img = datapoint.images[query.image_id]
        targets = query.object_ids_output
        assert (
            len(targets) == 1
        ), "Geometric queries only support a single target object."
        target_idx = targets[0]
        return img.objects[target_idx]

    def __call__(self, datapoint, **kwargs):
        for query in datapoint.find_queries:
            if query.query_text != self.geometric_query_str:
                continue
            target_object = self._get_target_object(datapoint, query)
            n_points, use_box = self._sample_num_points_and_if_box()
            box = self._get_original_box(target_object)
            mask = target_object.segment
            if n_points > 0:
                # FIXME: The conversion to numpy and back to reuse code
                # is awkward, but this is all in the dataloader worker anyway
                # on CPU and so I don't think it should matter.
                if self.sample_box_scale_factor != 1.0:
                    sample_box = rescale_box_xyxy(
                        box.numpy(), self.sample_box_scale_factor, mask.shape
                    )
                else:
                    sample_box = box.numpy()
                input_points = sample_points_from_mask(
                    mask.numpy(),
                    n_points,
                    self.point_sample_mode,
                    sample_box,
                )
                input_points = torch.as_tensor(input_points)
                input_points = input_points[None, :, :]
                if self.concat_points and query.input_points is not None:
                    input_points = torch.cat([query.input_points, input_points], dim=1)
            else:
                input_points = query.input_points if self.concat_points else None
            if use_box:
                # NOTE(review): `.size` is unpacked as (w, h) here but as
                # (h, w) elsewhere in this pipeline — confirm the image
                # wrapper's convention.
                w, h = datapoint.images[query.image_id].size
                input_box = noise_box(
                    box,
                    (h, w),
                    box_noise_std=self.box_noise_std,
                    box_noise_max=self.box_noise_max,
                    min_box_area=self.minimum_box_area,
                )
                input_box = input_box[None, :]
            else:
                input_box = query.input_bbox if self.concat_points else None
            query.input_points = input_points
            query.input_bbox = input_box
        return datapoint
class RandomizeInputBbox:
    """
    Simplified version of the geometric transform that only deals with input boxes
    """

    def __init__(
        self,
        box_noise_std=0.0,
        box_noise_max=None,
        minimum_box_area=0.0,
    ):
        self.box_noise_std = box_noise_std
        self.box_noise_max = box_noise_max
        self.minimum_box_area = minimum_box_area

    def __call__(self, datapoint: Datapoint, **kwargs):
        for query in datapoint.find_queries:
            boxes = query.input_bbox
            if boxes is None:
                continue
            img = datapoint.images[query.image_id].data
            # PIL reports (width, height); tensors are (..., H, W).
            if isinstance(img, PILImage.Image):
                w, h = img.size
            else:
                assert isinstance(img, torch.Tensor)
                h, w = img.shape[-2:]
            # Noise every input box in place.
            for idx in range(boxes.shape[0]):
                boxes[idx, :] = noise_box(
                    boxes[idx, :].view(4),
                    (h, w),
                    box_noise_std=self.box_noise_std,
                    box_noise_max=self.box_noise_max,
                    min_box_area=self.minimum_box_area,
                ).view(1, 4)
        return datapoint

View File

@@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import numpy as np
import pycocotools.mask as mask_utils
import torch
import torchvision.transforms.functional as F
from PIL import Image as PILImage
from sam3.model.box_ops import masks_to_boxes
from sam3.train.data.sam3_image_dataset import Datapoint
class InstanceToSemantic(object):
    """Convert instance segmentation to semantic segmentation."""

    def __init__(self, delete_instance=True, use_rle=False):
        self.delete_instance = delete_instance
        self.use_rle = use_rle

    def __call__(self, datapoint: Datapoint, **kwargs):
        for fquery in datapoint.find_queries:
            # NOTE(review): `.size` is unpacked as (h, w) here but as (w, h)
            # elsewhere in this pipeline — confirm the wrapper's convention.
            h, w = datapoint.images[fquery.image_id].size
            objects = datapoint.images[fquery.image_id].objects
            if self.use_rle:
                all_segs = [
                    objects[obj_id].segment for obj_id in fquery.object_ids_output
                ]
                if not all_segs:
                    # There is no good way to create an empty RLE of the correct size
                    # We resort to converting an empty box to RLE
                    fquery.semantic_target = mask_utils.frPyObjects(
                        np.array([[0, 0, 0, 0]], dtype=np.float64), h, w
                    )[0]
                else:
                    # cocotools silently produces an empty [0, 0] mask when the
                    # RLE sizes disagree, so verify consistency up front.
                    for seg in all_segs:
                        assert seg["size"] == all_segs[0]["size"], (
                            "Instance segments have inconsistent sizes. "
                            f"Found sizes {seg['size']} and {all_segs[0]['size']}"
                        )
                    fquery.semantic_target = mask_utils.merge(all_segs)
            else:
                # `semantic_target` is uint8 and remains uint8 throughout the
                # transforms (binary 0/1 values just like each `segment`).
                semantic = torch.zeros((h, w), dtype=torch.uint8)
                for obj_id in fquery.object_ids_output:
                    segment = objects[obj_id].segment
                    if segment is None:
                        continue
                    assert (
                        isinstance(segment, torch.Tensor)
                        and segment.dtype == torch.uint8
                    )
                    semantic |= segment
                fquery.semantic_target = semantic
        if self.delete_instance:
            for img in datapoint.images:
                for obj in img.objects:
                    del obj.segment
                    obj.segment = None
        return datapoint
class RecomputeBoxesFromMasks:
    """Recompute bounding boxes from masks."""

    def __call__(self, datapoint: Datapoint, **kwargs):
        for image in datapoint.images:
            for obj in image.objects:
                # An empty mask yields an undefined box; such targets are
                # expected to be filtered out by a later transform.
                obj.bbox = masks_to_boxes(obj.segment)
                obj.area = obj.segment.sum().item()
        return datapoint
class DecodeRle:
    """This transform decodes RLEs into binary segments.

    Implementing it as a transform allows lazy loading: some transforms
    (eg query filters) may be deleting masks, so decoding them from the
    beginning is wasteful. This transform needs to be called before any
    kind of geometric manipulation.
    """

    def __call__(self, datapoint: Datapoint, **kwargs):
        imgId2size = {}
        warning_shown = False
        for imgId, img in enumerate(datapoint.images):
            if isinstance(img.data, PILImage.Image):
                # PIL reports (width, height).
                img_w, img_h = img.data.size
            elif isinstance(img.data, torch.Tensor):
                # Tensor images are (..., H, W) — height comes first.
                # (Fixed: this previously unpacked as `img_w, img_h`,
                # swapping the dimensions for tensor inputs.)
                img_h, img_w = img.data.shape[-2:]
            else:
                raise RuntimeError(f"Unexpected image type {type(img.data)}")
            imgId2size[imgId] = (img_h, img_w)
            for obj in img.objects:
                if obj.segment is not None and not isinstance(
                    obj.segment, torch.Tensor
                ):
                    if mask_utils.area(obj.segment) == 0:
                        # Degenerate RLE: fall back to filling the object's
                        # box (at least one pixel) so downstream code has a
                        # non-empty mask to work with.
                        print("Warning, empty mask found, approximating from box")
                        obj.segment = torch.zeros(img_h, img_w, dtype=torch.uint8)
                        x1, y1, x2, y2 = obj.bbox.int().tolist()
                        obj.segment[y1 : max(y2, y1 + 1), x1 : max(x1 + 1, x2)] = 1
                    else:
                        obj.segment = mask_utils.decode(obj.segment)
                        # segment is uint8 and remains uint8 throughout the transforms
                        obj.segment = torch.tensor(obj.segment).to(torch.uint8)
                    if list(obj.segment.shape) != [img_h, img_w]:
                        # Should not happen often, but adding for security
                        if not warning_shown:
                            print(
                                f"Warning expected instance segmentation size to be {[img_h, img_w]} but found {list(obj.segment.shape)}"
                            )
                            # Printing only once per datapoint to avoid spam
                            warning_shown = True
                        obj.segment = F.resize(
                            obj.segment[None], (img_h, img_w)
                        ).squeeze(0)
                    assert list(obj.segment.shape) == [img_h, img_w]
        warning_shown = False
        for query in datapoint.find_queries:
            if query.semantic_target is not None and not isinstance(
                query.semantic_target, torch.Tensor
            ):
                query.semantic_target = mask_utils.decode(query.semantic_target)
                # segment is uint8 and remains uint8 throughout the transforms
                query.semantic_target = torch.tensor(query.semantic_target).to(
                    torch.uint8
                )
                if tuple(query.semantic_target.shape) != imgId2size[query.image_id]:
                    if not warning_shown:
                        print(
                            f"Warning expected semantic segmentation size to be {imgId2size[query.image_id]} but found {tuple(query.semantic_target.shape)}"
                        )
                        # Printing only once per datapoint to avoid spam
                        warning_shown = True
                    query.semantic_target = F.resize(
                        query.semantic_target[None], imgId2size[query.image_id]
                    ).squeeze(0)
                assert tuple(query.semantic_target.shape) == imgId2size[query.image_id]
        return datapoint