Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
1
sam3/train/transforms/__init__.py
Normal file
1
sam3/train/transforms/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
455
sam3/train/transforms/basic.py
Normal file
455
sam3/train/transforms/basic.py
Normal file
@@ -0,0 +1,455 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
"""
|
||||
Transforms and data augmentation for both image + bbox.
|
||||
"""
|
||||
|
||||
import math
|
||||
import random
|
||||
from typing import Iterable
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
from sam3.model.box_ops import box_xyxy_to_cxcywh
|
||||
from sam3.model.data_misc import interpolate
|
||||
|
||||
|
||||
def crop(image, target, region):
    """Crop *image* to *region* and update the target annotations to match.

    Args:
        image: PIL image to crop.
        target: dict of annotations; boxes are XYXY in absolute pixels.
        region: (top, left, height, width) — the same format produced by
            ``T.RandomCrop.get_params`` (unpacked below as ``i, j, h, w``).

    Returns:
        (cropped_image, updated_target). The caller's target dict is not
        mutated (a shallow copy is edited).
    """
    cropped_image = F.crop(image, *region)

    target = target.copy()
    i, j, h, w = region

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])

    # per-object fields that must be re-indexed if objects are dropped below
    fields = ["labels", "area", "iscrowd", "positive_map"]

    if "boxes" in target:
        boxes = target["boxes"]
        # translate boxes into the crop frame, then clip to the crop extent
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32)
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        # recompute areas from the clipped corners
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")

    if "input_boxes" in target:
        # same translate/clip for prompt boxes — but note these are NOT added
        # to `fields`, so they are never filtered by the keep mask below
        boxes = target["input_boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32)
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        target["input_boxes"] = cropped_boxes.reshape(-1, 4)

    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target["masks"] = target["masks"][:, i : i + h, j : j + w]
        fields.append("masks")

    # remove elements for which the boxes or masks that have zero area
    if "boxes" in target or "masks" in target:
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target["boxes"].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target["masks"].flatten(1).any(1)

        for field in fields:
            if field in target:
                target[field] = target[field][keep]

    return cropped_image, target
|
||||
|
||||
|
||||
def hflip(image, target):
    """Horizontally flip *image* and mirror its annotations.

    Boxes (XYXY) are mirrored about the vertical center line, masks are
    flipped along their last axis, and any "left"/"right" words in the text
    prompt are swapped so the description still matches the flipped image.
    """
    flipped_image = F.hflip(image)

    w, h = image.size

    target = target.copy()

    def _mirror_xyxy(box_tensor):
        # swap x1/x2, negate the x coordinates, then shift back into [0, w]
        return box_tensor[:, [2, 1, 0, 3]] * torch.as_tensor(
            [-1, 1, -1, 1]
        ) + torch.as_tensor([w, 0, w, 0])

    if "boxes" in target:
        target["boxes"] = _mirror_xyxy(target["boxes"])

    if "input_boxes" in target:
        target["input_boxes"] = _mirror_xyxy(target["input_boxes"])

    if "masks" in target:
        target["masks"] = target["masks"].flip(-1)

    if "text_input" in target:
        # swap the words via a temporary placeholder so the two replacements
        # do not clobber each other
        target["text_input"] = (
            target["text_input"]
            .replace("left", "[TMP]")
            .replace("right", "left")
            .replace("[TMP]", "right")
        )

    return flipped_image, target
|
||||
|
||||
|
||||
def resize(image, target, size, max_size=None, square=False):
    """Resize *image* (and rescale its annotations) to *size*.

    Args:
        image: PIL image.
        target: annotation dict, or None (image-only resize).
        size: either a scalar short-side length, or a (w, h) tuple.
        max_size: optional cap on the long side when ``size`` is a scalar.
        square: if True, resize to ``(size, size)`` ignoring aspect ratio.

    Returns:
        (rescaled_image, rescaled_target_or_None).
    """
    # size can be min_size (scalar) or (w, h) tuple

    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        # Compute an (h, w) output size whose short side is `size`,
        # shrinking `size` first if the long side would exceed `max_size`.
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            if max_original_size / min_original_size * size > max_size:
                size = int(round(max_size * min_original_size / max_original_size))

        # already at the requested short side: keep dimensions unchanged
        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)

        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)

        return (oh, ow)

    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            # caller passed (w, h); torchvision expects (h, w)
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    if square:
        size = size, size
    else:
        size = get_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size)

    if target is None:
        return rescaled_image, None

    # per-axis scale factors (width ratio, height ratio) from PIL's (w, h)
    ratios = tuple(
        float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)
    )
    ratio_width, ratio_height = ratios

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        scaled_boxes = boxes * torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32
        )
        target["boxes"] = scaled_boxes
    if "input_boxes" in target:
        boxes = target["input_boxes"]
        scaled_boxes = boxes * torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32
        )
        target["input_boxes"] = scaled_boxes

    if "area" in target:
        area = target["area"]
        # area scales by the product of the two linear factors
        scaled_area = area * (ratio_width * ratio_height)
        target["area"] = scaled_area

    h, w = size
    target["size"] = torch.tensor([h, w])

    if "masks" in target:
        # nearest-neighbour resample, then re-binarize
        target["masks"] = (
            interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0]
            > 0.5
        )

    return rescaled_image, target
|
||||
|
||||
|
||||
def pad(image, target, padding):
    """Pad *image* and shift/pad its annotations accordingly.

    ``padding`` is either a 2-tuple ``(right, bottom)`` — pad only the
    bottom-right, so boxes need no shifting — or a 4-tuple
    ``(left, top, right, bottom)``, in which case boxes are offset by the
    left/top amounts.
    """
    two_sided = len(padding) == 2
    if two_sided:
        # assumes that we only pad on the bottom right corners
        padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
    else:
        # left, top, right, bottom
        padded_image = F.pad(image, (padding[0], padding[1], padding[2], padding[3]))

    if target is None:
        return padded_image, None

    target = target.copy()
    w, h = padded_image.size

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])

    if not two_sided and len(padding) == 4:
        # left/top padding shifts box coordinates; right/bottom does not
        shift = torch.as_tensor(
            [padding[0], padding[1], padding[0], padding[1]], dtype=torch.float32
        )
        for key in ("boxes", "input_boxes"):
            if key in target:
                target[key] = target[key] + shift

    if "masks" in target:
        if two_sided:
            mask_padding = (0, padding[0], 0, padding[1])
        else:
            mask_padding = (padding[0], padding[2], padding[1], padding[3])
        target["masks"] = torch.nn.functional.pad(target["masks"], mask_padding)

    return padded_image, target
|
||||
|
||||
|
||||
class RandomCrop:
    """Crop a random window of fixed ``size`` from the image and its target."""

    def __init__(self, size):
        # (height, width) of the output crop
        self.size = size

    def __call__(self, img, target):
        crop_region = T.RandomCrop.get_params(img, self.size)
        return crop(img, target, crop_region)
|
||||
|
||||
|
||||
class RandomSizeCrop:
    """Randomly crop a region whose side lengths lie in [min_size, max_size].

    If ``respect_boxes`` is True (and the target has boxes), the crop window
    is constrained so that every ground-truth box survives the crop.
    """

    def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
        self.min_size = min_size
        self.max_size = max_size
        self.respect_boxes = respect_boxes  # if True we can't crop a box out

    def __call__(self, img: PIL.Image.Image, target: dict):
        init_boxes = len(target["boxes"])
        init_boxes_tensor = target["boxes"].clone()
        if self.respect_boxes and init_boxes > 0:
            # Crop-size bounds, clamped to the image dimensions.
            # BUGFIX: minH was previously clamped by img.width (copy-paste
            # error); it must be clamped by the image *height*.
            minW, minH, maxW, maxH = (
                min(img.width, self.min_size),
                min(img.height, self.min_size),
                min(img.width, self.max_size),
                min(img.height, self.max_size),
            )
            # minX/minY: right/bottom-most extents the crop must reach to
            # contain every box (10px margin); maxX/maxY: left/top-most
            # coordinates the crop origin may take.
            minX, minY = (
                target["boxes"][:, 0].max().item() + 10.0,
                target["boxes"][:, 1].max().item() + 10.0,
            )
            minX = min(img.width, minX)
            minY = min(img.height, minY)
            maxX, maxY = (
                target["boxes"][:, 2].min().item() - 10,
                target["boxes"][:, 3].min().item() - 10,
            )
            maxX = max(0.0, maxX)
            maxY = max(0.0, maxY)
            # the crop must at least span all boxes
            minW = max(minW, minX - maxX)
            minH = max(minH, minY - maxY)
            w = random.uniform(minW, max(minW, maxW))
            h = random.uniform(minH, max(minH, maxH))
            if minX > maxX:
                i = random.uniform(max(0, minX - w), max(maxX, max(0, minX - w)))
            else:
                i = random.uniform(
                    max(0, minX - w + 1), max(maxX - 1, max(0, minX - w + 1))
                )
            if minY > maxY:
                j = random.uniform(max(0, minY - h), max(maxY, max(0, minY - h)))
            else:
                j = random.uniform(
                    max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1))
                )
            # crop() expects (top, left, height, width)
            result_img, result_target = crop(img, target, [j, i, h, w])
            assert (
                len(result_target["boxes"]) == init_boxes
            ), f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}"

            return result_img, result_target
        else:
            # unconstrained: sample a size, then a uniformly random window
            w = random.randint(self.min_size, min(img.width, self.max_size))
            h = random.randint(self.min_size, min(img.height, self.max_size))
            region = T.RandomCrop.get_params(img, (h, w))
            result_img, result_target = crop(img, target, region)
            return result_img, result_target
|
||||
|
||||
|
||||
class CenterCrop:
    """Crop a window of fixed (height, width) ``size`` from the image center."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        img_w, img_h = img.size
        out_h, out_w = self.size
        # center the window; negative offsets are passed through to crop()
        top = int(round((img_h - out_h) / 2.0))
        left = int(round((img_w - out_w) / 2.0))
        return crop(img, target, (top, left, out_h, out_w))
|
||||
|
||||
|
||||
class RandomHorizontalFlip:
    """With probability ``p``, horizontally flip the image and its target."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        do_flip = random.random() < self.p
        return hflip(img, target) if do_flip else (img, target)
|
||||
|
||||
|
||||
class RandomResize:
    """Resize to a size picked uniformly at random from ``sizes``."""

    def __init__(self, sizes, max_size=None, square=False):
        # accept a single int as a one-element choice list
        if isinstance(sizes, int):
            sizes = (sizes,)
        assert isinstance(sizes, Iterable)
        self.sizes = list(sizes)
        self.max_size = max_size
        self.square = square

    def __call__(self, img, target=None):
        chosen_size = random.choice(self.sizes)
        return resize(img, target, chosen_size, self.max_size, square=self.square)
|
||||
|
||||
|
||||
class RandomPad:
    """Pad the bottom-right by independent random amounts in [0, max_pad]."""

    def __init__(self, max_pad):
        self.max_pad = max_pad

    def __call__(self, img, target):
        pad_right = random.randint(0, self.max_pad)
        pad_bottom = random.randint(0, self.max_pad)
        return pad(img, target, (pad_right, pad_bottom))
|
||||
|
||||
|
||||
class PadToSize:
    """Pad the image up to a ``size`` x ``size`` square, placing the original
    image at a uniformly random offset inside the padded canvas.

    Raises:
        ValueError: if the image is already larger than ``size`` in either
            dimension (padding amounts would be negative).
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        w, h = img.size
        pad_x = self.size - w
        pad_y = self.size - h
        # Explicit validation instead of `assert`, which is stripped when
        # Python runs with -O and would let negative padding through.
        if pad_x < 0 or pad_y < 0:
            raise ValueError(
                f"PadToSize: image size ({w}, {h}) exceeds target size {self.size}"
            )
        # split each total padding randomly between the two opposite sides
        pad_left = random.randint(0, pad_x)
        pad_right = pad_x - pad_left
        pad_top = random.randint(0, pad_y)
        pad_bottom = pad_y - pad_top
        return pad(img, target, (pad_left, pad_top, pad_right, pad_bottom))
|
||||
|
||||
|
||||
class Identity:
    """No-op transform: returns its inputs unchanged."""

    def __call__(self, img, target):
        return (img, target)
|
||||
|
||||
|
||||
class RandomSelect:
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2
    """

    def __init__(self, transforms1=None, transforms2=None, p=0.5):
        # a falsy transform defaults to the identity
        self.transforms1 = transforms1 if transforms1 else Identity()
        self.transforms2 = transforms2 if transforms2 else Identity()
        self.p = p

    def __call__(self, img, target):
        chosen = self.transforms1 if random.random() < self.p else self.transforms2
        return chosen(img, target)
|
||||
|
||||
|
||||
class ToTensor:
    """Convert the PIL image to a tensor; the target passes through as-is."""

    def __call__(self, img, target):
        tensor_img = F.to_tensor(img)
        return tensor_img, target
|
||||
|
||||
|
||||
class RandomErasing:
    """Wrap ``torchvision.transforms.RandomErasing`` in the (img, target) API."""

    def __init__(self, *args, **kwargs):
        # forward all configuration to the underlying torchvision transform
        self.eraser = T.RandomErasing(*args, **kwargs)

    def __call__(self, img, target):
        erased = self.eraser(img)
        return erased, target
|
||||
|
||||
|
||||
class Normalize:
    """Normalize the image tensor and convert boxes to normalized cxcywh.

    Applies channel-wise ``(x - mean) / std`` to the image, then rewrites
    "boxes" and "input_boxes" from absolute XYXY to center-x/center-y/w/h
    divided by the image width/height (i.e. values in [0, 1]).
    """

    def __init__(self, mean, std):
        # per-channel normalization statistics, forwarded to F.normalize
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None
        target = target.copy()
        # image is a CHW tensor at this point (shape[-2:] gives h, w)
        h, w = image.shape[-2:]
        if "boxes" in target:
            boxes = target["boxes"]
            boxes = box_xyxy_to_cxcywh(boxes)
            boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
            target["boxes"] = boxes
        if "input_boxes" in target:
            boxes = target["input_boxes"]
            boxes = box_xyxy_to_cxcywh(boxes)
            boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
            target["input_boxes"] = boxes
        return image, target
|
||||
|
||||
|
||||
class RemoveDifficult:
    """Optionally drop crowd ("difficult") annotations from the target."""

    def __init__(self, enabled=False):
        self.remove_difficult = enabled

    def __call__(self, image, target=None):
        if target is None:
            return image, None
        target = target.copy()
        # when disabled, `| True` keeps every entry; otherwise keep non-crowd
        keep = ~target["iscrowd"].to(torch.bool) | (not self.remove_difficult)
        if "boxes" in target:
            for key in ("boxes", "labels", "iscrowd"):
                target[key] = target[key][keep]
        return image, target
|
||||
|
||||
|
||||
class Compose:
    """Apply a sequence of (image, target) transforms in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target

    def __repr__(self):
        # one transform per line, indented, inside ClassName( ... )
        pieces = [self.__class__.__name__ + "("]
        for transform in self.transforms:
            pieces.append("    {0}".format(transform))
        return "\n".join(pieces) + "\n)"
|
||||
|
||||
|
||||
def get_random_resize_scales(size, min_size, rounded):
    """Candidate resize scales from ``min_size`` (rounded up to the stride)
    up to ``size``, spaced by the stride (128 if ``rounded`` else 32)."""
    step = 128 if rounded else 32
    lowest = int(step * math.ceil(min_size / step))
    return list(range(lowest, size + 1, step))
|
||||
|
||||
|
||||
def get_random_resize_max_size(size, ratio=5 / 3):
    """Largest allowed long side for a given short-side ``size``."""
    return round(ratio * size)
|
||||
1396
sam3/train/transforms/basic_for_api.py
Normal file
1396
sam3/train/transforms/basic_for_api.py
Normal file
File diff suppressed because it is too large
Load Diff
607
sam3/train/transforms/filter_query_transforms.py
Normal file
607
sam3/train/transforms/filter_query_transforms.py
Normal file
@@ -0,0 +1,607 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.train.data.sam3_image_dataset import Datapoint, FindQuery, Object
|
||||
|
||||
|
||||
class FilterDataPointQueries:
    """Base class for datapoint query filters.

    Subclasses implement ``identify_queries_to_filter`` to populate the
    class-level id sets below; ``FlexibleFilterFindGetQueries`` then consults
    ``_do_filter_query`` / ``obj_ids_to_filter`` to drop queries and objects.
    """

    # ids of find queries to drop (populated by identify_queries_to_filter)
    find_ids_to_filter: set = None
    # ids of get queries to drop — NOTE(review): never read by the visible
    # filtering pipeline in this file
    get_ids_to_filter: set = None
    obj_ids_to_filter: set = None  # stored as pairs (img_id, obj_id)

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        """
        Compute set of query ids to keep, for both find and get queries
        """
        raise NotImplementedError

    def _do_filter_query(self, query: Union[FindQuery], query_id: int):
        # identify_queries_to_filter must have been called first
        assert self.find_ids_to_filter is not None

        return query_id in self.find_ids_to_filter
|
||||
|
||||
|
||||
class FilterQueryWithText(FilterDataPointQueries):
    """
    Filter all datapoints which have query text in a specified list of excluded terms
    """

    def __init__(
        self, exclude_find_keys: List[str] = None, exclude_get_keys: List[str] = None
    ):
        self.find_filter_keys = exclude_find_keys if exclude_find_keys else []
        # NOTE(review): get-query filtering is not implemented below; this
        # list is stored for interface compatibility but currently unused.
        self.get_filter_keys = exclude_get_keys if exclude_get_keys else []

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()
        # drop every find query whose text is in the exclusion list
        # (removed the unused del_get_ids bookkeeping from the original)
        self.find_ids_to_filter = {
            i
            for i, f_q in enumerate(datapoint.find_queries)
            if f_q.query_text in self.find_filter_keys
        }
|
||||
|
||||
|
||||
class KeepMaxNumFindQueries(FilterDataPointQueries):
    """Subsample find queries down to at most ``max_num_find_queries``.

    With ``retain_positive_queries`` the sampler prefers to keep queries that
    have at least one output object, filling any remaining slots with
    negative (empty-output) queries.
    """

    def __init__(
        self, max_num_find_queries: int, retain_positive_queries: bool = False
    ):
        self.max_num_find_queries = max_num_find_queries
        self.retain_positive_queries = retain_positive_queries

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        self.obj_ids_to_filter = set()
        num_find_queries = len(datapoint.find_queries)
        if num_find_queries <= self.max_num_find_queries:
            self.find_ids_to_filter = set()  # keep all find queries
            return

        if not self.retain_positive_queries:
            # uniform subsampling over all find queries
            all_find_query_ids = list(range(num_find_queries))
            num_queries_to_filter = max(0, num_find_queries - self.max_num_find_queries)
            query_ids_to_filter = random.sample(
                all_find_query_ids, k=num_queries_to_filter
            )
        else:
            # keep up to max_num_find_queries positive find queries and fill
            # the remaining slots (if any) with negative find queries
            pos_find_ids, neg_find_ids = [], []
            for i, f_q in enumerate(datapoint.find_queries):
                # Negative finds return an empty list of object_ids_output
                if len(f_q.object_ids_output) == 0:
                    neg_find_ids.append(i)
                else:
                    pos_find_ids.append(i)

            if len(pos_find_ids) >= self.max_num_find_queries:
                # we have more positive find queries than `max_num_find_queries`,
                # so we subsample positive find queries and remove all negative find queries
                num_queries_to_filter = len(pos_find_ids) - self.max_num_find_queries
                query_ids_to_filter = random.sample(
                    pos_find_ids, k=num_queries_to_filter
                )
                query_ids_to_filter.extend(neg_find_ids)
            else:
                # we have fewer positive find queries than `max_num_find_queries`
                # so we need to fill the remaining with negative find queries
                num_queries_to_filter = num_find_queries - self.max_num_find_queries
                query_ids_to_filter = random.sample(
                    neg_find_ids, k=num_queries_to_filter
                )

        # exactly max_num_find_queries queries must survive
        assert len(query_ids_to_filter) == num_find_queries - self.max_num_find_queries
        self.find_ids_to_filter = set(query_ids_to_filter)
|
||||
|
||||
|
||||
class KeepMaxNumFindQueriesVideo(FilterDataPointQueries):
    """Per-frame cap on find queries for video-mosaic datapoints.

    The slots to drop are chosen once on frame 0 and the queries at the same
    per-frame positions are dropped on every other frame, keeping the set of
    surviving queries aligned across frames.

    NOTE(review): the code below indexes ``findQueries_to_imageIds[0]`` and
    iterates ``range(num_frames)``, so it relies on image ids being
    contiguous from 0 and on every frame having the same number of find
    queries — confirm against the dataset/mosaic builder.
    """

    def __init__(
        self,
        video_mosaic_max_num_find_queries_per_frame: int,
        retain_positive_queries: bool = False,
    ):
        self.video_mosaic_max_num_find_queries_per_frame = (
            video_mosaic_max_num_find_queries_per_frame
        )
        self.retain_positive_queries = retain_positive_queries

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        self.obj_ids_to_filter = set()
        num_find_queries = len(datapoint.find_queries)

        # group find-query indices by frame; track whether any frame exceeds
        # the per-frame cap (True means "within the cap everywhere")
        findQueries_to_imageIds = defaultdict(list)
        max_queries_per_frame = True
        for i, f_q in enumerate(datapoint.find_queries):
            findQueries_to_imageIds[f_q.image_id].append(i)
            if (
                len(findQueries_to_imageIds[f_q.image_id])
                > self.video_mosaic_max_num_find_queries_per_frame
            ):
                max_queries_per_frame = False

        if max_queries_per_frame:
            # every frame is already within the cap: keep everything
            self.find_ids_to_filter = set()
            return

        num_frames = len(findQueries_to_imageIds)
        # choose which per-frame slots to drop based on frame 0 only
        findQueries_0 = findQueries_to_imageIds[0]
        num_find_queries_0 = len(findQueries_0)
        max_num_find_queries_per_frame = (
            self.video_mosaic_max_num_find_queries_per_frame
        )
        if not self.retain_positive_queries:
            # uniform subsampling of frame-0 slot positions
            find_query_ids_0 = list(range(num_find_queries_0))
            num_queries_to_filter = max(
                0, num_find_queries_0 - max_num_find_queries_per_frame
            )
            query_ids_to_filter_0 = random.sample(
                find_query_ids_0, k=num_queries_to_filter
            )
        else:
            # keep up to max_num_find_queries positive find queries and fill
            # the remaining slots (if any) with negative find queries
            pos_find_ids_0, neg_find_ids_0 = [], []
            for i, f_q_id in enumerate(findQueries_0):
                f_q = datapoint.find_queries[f_q_id]
                # Negative finds return an empty list of object_ids_output
                if len(f_q.object_ids_output) == 0:
                    neg_find_ids_0.append(i)
                else:
                    pos_find_ids_0.append(i)

            if len(pos_find_ids_0) >= max_num_find_queries_per_frame:
                # we have more positive find queries than `max_num_find_queries`,
                # so we subsample positive find queries and remove all negative find queries
                num_queries_to_filter = (
                    len(pos_find_ids_0) - max_num_find_queries_per_frame
                )
                query_ids_to_filter_0 = random.sample(
                    pos_find_ids_0, k=num_queries_to_filter
                )
                query_ids_to_filter_0.extend(neg_find_ids_0)
            else:
                # we have fewer positive find queries than `max_num_find_queries`
                # so we need to fill the remaining with negative find queries
                num_queries_to_filter = (
                    num_find_queries_0 - max_num_find_queries_per_frame
                )
                query_ids_to_filter_0 = random.sample(
                    neg_find_ids_0, k=num_queries_to_filter
                )

        # get based on frame 0 all find queries from all the frames with the same indices as in frame 0
        query_ids_to_filter = []
        for i in range(num_frames):
            findQueries_i = findQueries_to_imageIds[i]
            query_ids_to_filter.extend(
                [findQueries_i[j] for j in query_ids_to_filter_0]
            )

        # only holds when every frame had the same query count — see class note
        assert (
            len(query_ids_to_filter)
            == num_find_queries
            - self.video_mosaic_max_num_find_queries_per_frame * num_frames
        )
        self.find_ids_to_filter = set(query_ids_to_filter)
|
||||
|
||||
|
||||
class KeepSemanticFindQueriesOnly(FilterDataPointQueries):
    """Keep only text (semantic) find queries; drop geometric ones."""

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        self.obj_ids_to_filter = set()
        # geometric find queries are the ones carrying an input box
        geometric_ids = set()
        for idx, query in enumerate(datapoint.find_queries):
            if query.input_bbox is not None:
                geometric_ids.add(idx)
        self.find_ids_to_filter = geometric_ids

        # Keep all get queries which don't depend on filtered finds
|
||||
|
||||
|
||||
class KeepUnaryFindQueriesOnly(FilterDataPointQueries):
    """Currently a pass-through: no find queries or objects are filtered."""

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        # nothing to drop — both sets stay empty
        self.obj_ids_to_filter = set()
        self.find_ids_to_filter = set()

        # Keep all get queries which don't depend on filtered finds
|
||||
|
||||
|
||||
class FilterZeroBoxQueries(FilterDataPointQueries):
    """
    Filters all find queries which predict a box with zero area
    """

    @staticmethod
    def _is_zero_area_object(obj: Object):
        # Degenerate when either side of the (assumed XYXY) box is zero
        box = obj.bbox
        side_h = box[..., 3].item() - box[..., 1].item()
        side_w = box[..., 2].item() - box[..., 0].item()
        return side_h == 0 or side_w == 0

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()

        # Collect degenerate objects (assumes only one image per datapoint)
        degenerate_ids = {
            obj_id
            for obj_id, obj in enumerate(datapoint.images[0].objects)
            if self._is_zero_area_object(obj)
        }

        # Drop the whole find query if any of its outputs is degenerate
        self.find_ids_to_filter = {
            i
            for i, f_q in enumerate(datapoint.find_queries)
            if degenerate_ids.intersection(f_q.object_ids_output)
        }
|
||||
|
||||
|
||||
class FilterFindQueriesWithTooManyOut(FilterDataPointQueries):
    """
    Filters all find queries which have more than a specified number of objects in the output
    """

    def __init__(self, max_num_objects: int):
        self.max_num_objects = max_num_objects

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()
        # drop any find query whose output object count exceeds the cap
        self.find_ids_to_filter = {
            i
            for i, f_q in enumerate(datapoint.find_queries)
            if len(f_q.object_ids_output) > self.max_num_objects
        }
|
||||
|
||||
|
||||
class FilterEmptyTargets(FilterDataPointQueries):
    """
    Filters all targets which have zero area
    """

    def identify_queries_to_filter(self, datapoint):
        self.find_ids_to_filter = set()
        # mark (img_id, obj_id) pairs whose object area is effectively zero
        self.obj_ids_to_filter = {
            (img_id, obj_id)
            for img_id, image in enumerate(datapoint.images)
            for obj_id, obj in enumerate(image.objects)
            if obj.area < 1e-6
        }
|
||||
|
||||
|
||||
class FilterNonExhaustiveFindQueries(FilterDataPointQueries):
    """
    Filters all find queries which are non-exhaustive
    """

    def __init__(self, exhaustivity_type: str):
        """
        Args:
            exhaustivity_type: Can be "pixel" or "instance":
                -pixel: filter queries where the union of all segments covers every pixel belonging to target class
                -instance: filter queries where there are non-separable or non annotated instances
                Note that instance exhaustivity implies pixel exhaustivity
        """
        # Validate eagerly with a real exception (assert is stripped under -O)
        if exhaustivity_type not in ("pixel", "instance"):
            raise ValueError(f"Unknown exhaustivity type {exhaustivity_type}")
        self.exhaustivity_type = exhaustivity_type

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()

        # Drop find queries failing the configured exhaustivity criterion.
        # The type is fixed per instance, so dispatch once instead of
        # re-checking (and carrying an unreachable raise) inside the loop.
        del_find_ids = []
        if self.exhaustivity_type == "instance":
            for i, f_q in enumerate(datapoint.find_queries):
                if not f_q.is_exhaustive:
                    del_find_ids.append(i)
        else:  # "pixel" — validated in __init__
            for i, f_q in enumerate(datapoint.find_queries):
                # is_pixel_exhaustive may be None (unknown): treat as keep
                if f_q.is_pixel_exhaustive is not None and not f_q.is_pixel_exhaustive:
                    del_find_ids.append(i)

        self.find_ids_to_filter = set(del_find_ids)
|
||||
|
||||
|
||||
class FilterInvalidGeometricQueries(FilterDataPointQueries):
    """
    Filters geometric queries whose output got deleted (eg due to cropping)
    """

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()

        dropped = []
        for idx, query in enumerate(datapoint.find_queries):
            # geometric queries carry an input box and the literal text "geometric"
            is_geometric = (
                query.input_bbox is not None and query.query_text == "geometric"
            )
            if is_geometric and len(query.object_ids_output) == 0:
                dropped.append(idx)
        self.find_ids_to_filter = set(dropped)
|
||||
|
||||
|
||||
class FlexibleFilterFindGetQueries:
    """Apply a ``FilterDataPointQueries`` policy to a datapoint in place.

    Pipeline: (1) ask the policy which find queries to drop, (2) compact the
    surviving queries, (3) re-number their processing stages contiguously,
    (4) drop objects no longer referenced by any query (plus objects the
    policy flagged), remapping query outputs, and (5) drop images no longer
    referenced by any query, remapping image ids.
    """

    def __init__(
        self, query_filter: FilterDataPointQueries, enabled: bool = True
    ) -> None:
        self.query_filter = query_filter
        self.enabled = enabled  # when False, __call__ is a no-op

    def __call__(self, datapoint, **kwargs):
        if not self.enabled:
            return datapoint

        # Identify all queries to filter
        self.query_filter.identify_queries_to_filter(datapoint=datapoint)

        # NOTE(review): the *_get_* bookkeeping below (del_get_ids,
        # new_get_queries, get_old_to_new_map, get_counter) is vestigial —
        # it is written but never read in this method.
        del_find_ids = []
        del_get_ids = []
        # mark filtered queries as None so positions are preserved for now
        for i, f_q in enumerate(datapoint.find_queries):
            if self.query_filter._do_filter_query(f_q, i):
                datapoint.find_queries[i] = None
                del_find_ids.append(i)

        new_find_queries = []
        new_get_queries = []

        find_old_to_new_map = {}
        get_old_to_new_map = {}

        find_counter = 0
        get_counter = 0

        # compact: keep survivors in order, recording old->new index mapping
        for i, f_q in enumerate(datapoint.find_queries):
            if f_q is not None:
                find_old_to_new_map[i] = find_counter
                find_counter += 1
                new_find_queries.append(f_q)

        # at least one surviving query must start at processing order 0
        # (an empty survivor list passes here and is rejected just below)
        start_with_zero_check = False
        for n_f_q in new_find_queries:
            if n_f_q.query_processing_order == 0:
                start_with_zero_check = True
                break

        if len(new_find_queries) == 0:
            start_with_zero_check = True

        assert (
            start_with_zero_check
        ), "Invalid Find queries, they need to start at query_processing_order = 0"

        datapoint.find_queries = new_find_queries

        if len(datapoint.find_queries) == 0:
            print("Warning: No find queries left in datapoint, this is not allowed")
            print("Filtering function:", self.query_filter)
            print("Datapoint:", datapoint)
            raise ValueError

        # The deletion may have removed intermediate steps, so we need to remap to make them contiguous again
        all_stages = sorted(
            list(set(q.query_processing_order for q in datapoint.find_queries))
        )
        stage_map = {qpo: i for i, qpo in enumerate(all_stages)}
        for i in range(len(datapoint.find_queries)):
            qpo = datapoint.find_queries[i].query_processing_order
            datapoint.find_queries[i].query_processing_order = stage_map[qpo]

        # Final step, clear up objects that are not used anymore
        for img_id in range(len(datapoint.images)):
            # every object id still referenced by some query on this image
            all_objects_ids = set(
                i
                for find in datapoint.find_queries
                for i in find.object_ids_output
                if find.image_id == img_id
            )
            unused_ids = (
                set(range(len(datapoint.images[img_id].objects))) - all_objects_ids
            )
            # also drop objects the policy explicitly flagged for this image
            for tgt_img_id, tgt_obj_id in self.query_filter.obj_ids_to_filter:
                if tgt_img_id == img_id:
                    unused_ids.add(tgt_obj_id)

            if len(unused_ids) > 0:
                # compact the object list, keeping an old->new id mapping
                old_objects = datapoint.images[img_id].objects
                object_old_to_new_map = {}
                new_objects = []
                for i, o in enumerate(old_objects):
                    if i not in unused_ids:
                        object_old_to_new_map[i] = len(new_objects)
                        new_objects.append(o)

                datapoint.images[img_id].objects = new_objects

                # Remap the outputs of the find queries
                # NOTE(review): affected_find_queries_ids and
                # object_old_to_new_map_per_query are populated but never
                # consumed here.
                affected_find_queries_ids = set()
                object_old_to_new_map_per_query = {}
                for fid, find in enumerate(datapoint.find_queries):
                    if find.image_id == img_id:
                        old_object_ids_output = find.object_ids_output
                        object_old_to_new_map_per_query[fid] = {}
                        find.object_ids_output = []
                        for oid, old_obj_id in enumerate(old_object_ids_output):
                            if old_obj_id not in unused_ids:
                                new_obj_id = object_old_to_new_map[old_obj_id]
                                find.object_ids_output.append(new_obj_id)
                                object_old_to_new_map_per_query[fid][oid] = (
                                    len(find.object_ids_output) - 1
                                )
                        affected_find_queries_ids.add(fid)

        # finally remove unused images
        all_imgs_to_keep = set()
        for f_q in datapoint.find_queries:
            all_imgs_to_keep.add(f_q.image_id)

        old_img_id_to_new_img_id = {}
        new_images = []
        for img_id, img in enumerate(datapoint.images):
            if img_id in all_imgs_to_keep:
                old_img_id_to_new_img_id[img_id] = len(new_images)
                new_images.append(img)
        datapoint.images = new_images

        # point surviving queries at the compacted image indices
        for f_q in datapoint.find_queries:
            f_q.image_id = old_img_id_to_new_img_id[f_q.image_id]

        return datapoint
|
||||
|
||||
|
||||
class AddPrefixSuffixToFindText:
    """
    Add prefix or suffix strings to find query text on the fly.

    If `condition_on_text` is True, the prefix or suffix strings are only added
    to those find query text in `condition_text_list` (case-insensitive).
    """

    def __init__(
        self,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
        condition_on_text: bool = False,
        condition_text_list: Optional[List[str]] = None,
        enabled: bool = True,
    ) -> None:
        # Stored verbatim; applied to each eligible query in __call__.
        self.prefix = prefix
        self.suffix = suffix
        self.condition_on_text = condition_on_text
        if self.condition_on_text:
            assert condition_text_list is not None
            # Normalized (lowercased + stripped) for case-insensitive matching.
            self.condition_text_set = {t.lower().strip() for t in condition_text_list}
        self.enabled = enabled
        if self.enabled:
            logging.info(
                f"AddPrefixSuffixToFindText: prefix={prefix}, suffix={suffix}, "
                f"condition_on_text={condition_on_text}, condition_text_list={condition_text_list}"
            )

    def __call__(self, datapoint, **kwargs):
        if not self.enabled:
            return datapoint

        for find in datapoint.find_queries:
            # skip geometric find queries
            if find.query_text == "geometric":
                continue
            if self.condition_on_text:
                # if condition_on_text is True, skip those queries not in condition_text_set
                normalized = find.query_text.lower().strip()
                if normalized not in self.condition_text_set:
                    continue

            # add prefix and/or suffix strings to the find query text
            pieces = [find.query_text]
            if self.prefix is not None:
                pieces.insert(0, self.prefix)
            if self.suffix is not None:
                pieces.append(self.suffix)
            find.query_text = "".join(pieces)

        return datapoint
|
||||
|
||||
|
||||
class FilterCrowds(FilterDataPointQueries):
    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        """
        Compute set of query ids to keep, for both find and get queries
        """
        # No find queries are dropped directly; only crowd objects are marked.
        self.find_ids_to_filter = set()
        # self.get_ids_to_filter = set()
        # Every (image_id, object_id) pair flagged as a crowd annotation.
        self.obj_ids_to_filter = {
            (img_idx, obj_idx)
            for img_idx, image in enumerate(datapoint.images)
            for obj_idx, candidate in enumerate(image.objects)
            if candidate.is_crowd
        }
|
||||
|
||||
|
||||
class TextQueryToVisual:
    """
    Transform a test query to a visual query (with some proba), using any of the output targets as the prompt
    """

    def __init__(self, probability, keep_text_queries=False) -> None:
        assert 0 <= probability <= 1
        # Chance of converting each eligible text query into a visual one.
        self.probability = probability
        # When True, the original text is preserved alongside the visual prompt.
        self.keep_text_queries = keep_text_queries

    def _is_convertible(self, find):
        """Return True when `find` is a first-stage text query with targets."""
        if find.input_bbox is not None or find.input_points is not None:
            # skip geometric find queries
            return False
        if len(find.object_ids_output) == 0:
            # Can't create a visual query, skip
            return False
        if find.query_processing_order > 0:
            # Second stage query, can't use
            return False
        return True

    def __call__(self, datapoint: Datapoint, **kwargs):
        for find in datapoint.find_queries:
            if not self._is_convertible(find):
                continue
            if random.random() > self.probability:
                continue

            # Use one of the GT target objects as the visual prompt box.
            chosen_obj = random.choice(find.object_ids_output)
            source_image = datapoint.images[find.image_id]
            find.input_bbox = source_image.objects[chosen_obj].bbox
            find.input_bbox_label = torch.ones(1, dtype=torch.bool)
            if not self.keep_text_queries:
                find.query_text = "visual"

        return datapoint
|
||||
|
||||
|
||||
class RemoveInputBoxes:
    """
    Remove input boxes from find queries
    """

    def __init__(self) -> None:
        pass

    def __call__(self, datapoint: Datapoint, **kwargs):
        # Strip the box prompt from every query that carries one.
        for query in datapoint.find_queries:
            if query.input_bbox is None:
                continue

            if query.query_text == "geometric":
                # A geometric query without its box prompt is likely unintended.
                print("Warning: removing input box from geometric find query")

            query.input_bbox = None
        return datapoint
|
||||
|
||||
|
||||
class OverwriteTextQuery:
    """
    With some probability, overwrite the text query with a custom text
    """

    def __init__(self, target_text, probability=1.0) -> None:
        assert 0 <= probability <= 1
        # Replacement text and the per-query chance of applying it.
        self.target_text = target_text
        self.probability = probability

    def __call__(self, datapoint: Datapoint, **kwargs):
        for query in datapoint.find_queries:
            # One independent draw per query.
            if random.random() <= self.probability:
                query.query_text = self.target_text
        return datapoint
|
||||
345
sam3/train/transforms/point_sampling.py
Normal file
345
sam3/train/transforms/point_sampling.py
Normal file
@@ -0,0 +1,345 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image as PILImage
|
||||
from pycocotools import mask as mask_util
|
||||
|
||||
from sam3.train.data.sam3_image_dataset import Datapoint
|
||||
from torchvision.ops import masks_to_boxes
|
||||
|
||||
|
||||
def sample_points_from_rle(rle, n_points, mode, box=None, normalize=True):
    """
    Sample random points from a mask provided in COCO RLE format.

    'mode' is in ["centered", "random_mask", "random_box"]:
    "centered": points are sampled farthest from the mask edges and each other
    "random_mask": points are sampled uniformly from the mask
    "random_box": points are sampled uniformly from the annotation's box
    'box' must be provided if 'mode' is "random_box".
    If 'normalize' is true, points are in [0,1], relative to mask h,w.
    """
    binary_mask = np.ascontiguousarray(mask_util.decode(rle))
    sampled = sample_points_from_mask(binary_mask, n_points, mode, box)

    if not normalize:
        return sampled

    # Divide x by width and y by height; the label column stays untouched.
    height, width = binary_mask.shape
    scale = np.array([width, height, 1.0])[None, :]
    return sampled / scale
|
||||
|
||||
|
||||
def sample_points_from_mask(mask, n_points, mode, box=None):
    """Dispatch point sampling on a binary mask to the strategy named by `mode`."""
    if mode == "centered":
        return center_positive_sample(mask, n_points)
    if mode == "random_mask":
        return uniform_positive_sample(mask, n_points)
    if mode == "random_box":
        assert box is not None, "'random_box' mode requires a provided box."
        return uniform_sample_from_box(mask, box, n_points)
    raise ValueError(f"Unknown point sampling mode {mode}.")
|
||||
|
||||
|
||||
def uniform_positive_sample(mask, n_points):
    """
    Samples positive points uniformly from the mask. Only integer pixel
    values are sampled.
    """
    # Sampling directly from the uncompressed RLE would be faster but is
    # likely unnecessary.
    ys, xs = np.nonzero(mask)
    assert len(ys) > 0, "Can't sample positive points from an empty mask."
    picks = np.random.randint(low=0, high=len(ys), size=n_points)

    # Column order is (x, y, label) with every label positive (1).
    ones = np.ones(n_points)
    return np.stack([xs[picks], ys[picks], ones], axis=1)
|
||||
|
||||
|
||||
def center_positive_sample(mask, n_points):
    """
    Samples points farthest from mask edges (by distance transform)
    and subsequent points also farthest from each other. Each new point
    sampled is treated as an edge for future points. Edges of the image are
    treated as edges of the mask.

    Returns an (n_points, 3) array of (x, y, label) rows with label == 1.
    Raises AssertionError if the mask runs out of positive pixels before
    `n_points` points have been sampled.
    """

    # Pad mask by one pixel on each end to assure distance transform
    # avoids edges
    padded_mask = np.pad(mask, 1)

    points = []
    for _ in range(n_points):
        # BUGFIX: check the padded (consumed) mask rather than the pristine
        # input. Each sampled pixel is zeroed below, so positives can run out
        # even when the input mask was non-empty; the old check on `mask`
        # never fired and argmax of an all-zero distance map silently
        # produced a bogus (-1, -1) point.
        assert (
            np.max(padded_mask) > 0
        ), "Can't sample positive points from an empty mask."
        dist = cv2.distanceTransform(padded_mask, cv2.DIST_L2, 0)
        point = np.unravel_index(dist.argmax(), dist.shape)
        # Mark selected point as background so next point avoids it
        padded_mask[point[0], point[1]] = 0
        points.append(point[::-1])  # (y, x) -> (x, y)

    points = np.stack(points, axis=0)
    points = points - 1  # Subtract left/top padding of 1
    labels = np.ones((len(points), 1))
    points = np.concatenate([points, labels], axis=1)

    return points
|
||||
|
||||
|
||||
def uniform_sample_from_box(mask, box, n_points):
    """
    Sample points uniformly from the provided box. The points' labels
    are determined by the provided mask. Does not guarantee a positive
    point is sampled. The box is assumed unnormalized in XYXY format.
    Points are sampled at integer values.
    """
    # Since lower/right edges are exclusive, ceil can be applied to all edges
    x0, y0, x1, y1 = np.ceil(box)

    xs = np.random.randint(low=x0, high=x1, size=n_points)
    ys = np.random.randint(low=y0, high=y1, size=n_points)
    # Look up each sampled point's label (fg/bg) in the mask.
    return np.stack([xs, ys, mask[ys, xs]], axis=1)
|
||||
|
||||
|
||||
def rescale_box_xyxy(box, factor, imsize=None):
    """
    Rescale a box providing in unnormalized XYXY format, fixing the center.
    If imsize is provided, clamp to the image.
    """
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    center_x = (x0 + x1) / 2
    center_y = (y0 + y1) / 2
    half_w = factor * (x1 - x0) / 2
    half_h = factor * (y1 - y0) / 2

    out = [
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    ]

    if imsize is not None:
        # imsize is (h, w): clamp x coords to [0, w] and y coords to [0, h].
        limits = [imsize[1], imsize[0], imsize[1], imsize[0]]
        out = [min(max(v, 0), lim) for v, lim in zip(out, limits)]

    return out
|
||||
|
||||
|
||||
def noise_box(box, im_size, box_noise_std, box_noise_max, min_box_area):
    """
    Jitter a box (XYXY, absolute pixels) with Gaussian noise scaled by the
    box side lengths, then clamp it to the image. Returns the original box
    unchanged when noising is disabled or the noised box is too small.
    """
    if box_noise_std <= 0.0:
        return box

    jitter = box_noise_std * torch.randn(size=(4,))
    # Noise magnitude is relative to the box's width/height per coordinate.
    width = box[2] - box[0]
    height = box[3] - box[1]
    jitter = jitter * torch.tensor([width, height, width, height])
    if box_noise_max is not None:
        # Truncate noise to at most `box_noise_max` absolute pixels.
        jitter = torch.clamp(jitter, -box_noise_max, box_noise_max)

    noised = box + jitter
    # Clamp to maximum image size; im_size is (h, w).
    upper = torch.tensor([im_size[1], im_size[0], im_size[1], im_size[0]])
    noised = torch.maximum(noised, torch.zeros_like(noised))
    noised = torch.minimum(noised, upper)

    # Degenerate (too-small) boxes fall back to the original.
    if (noised[2] - noised[0]) * (noised[3] - noised[1]) <= min_box_area:
        return box

    return noised
|
||||
|
||||
|
||||
class RandomGeometricInputsAPI:
    """
    For geometric queries, replaces the input box or points with a random
    one sampled from the GT mask. Segments must be provided for objects
    that are targets of geometric queries, and must be binary masks. Existing
    point and box queries in the datapoint will be ignored and completely replaced.
    Will sample points and boxes in XYXY format in absolute pixel space.

    Geometry queries are currently determined by taking any query whose
    query text is a set value.

    Args:
        num_points (int or (int, int)): how many points to sample. If a tuple,
            sample a random number of points uniformly over the inclusive range.
        box_chance (float): fraction of time a box is sampled. A box will replace
            one sampled point.
        box_noise_std (float): if greater than 0, add noise to the sampled boxes
            with this std. Noise is relative to the length of the box side.
        box_noise_max (int): if not none, truncate any box noise larger than this
            in terms of absolute pixels.
        resample_box_from_mask (bool): if True, any sampled box will be determined
            by finding the extrema of the provided mask. If False, the bbox provided
            in the target object will be used.
        point_sample_mode (str): In ["centered", "random_mask", "random_box"],
            controlling how points are sampled:
            "centered": points are sampled farthest from the mask edges and each other
            "random_mask": points are sampled uniformly from the mask
            "random_box": points are sampled uniformly from the annotation's box
            Note that "centered" may be too slow for on-line generation.
        geometric_query_str (str): what string in query_text indicates a
            geometry query.
        minimum_box_area (float): sampled boxes with area this size or smaller after
            noising will use the original box instead. It is the input's responsibility
            to avoid original boxes that violate necessary area bounds.
        concat_points (bool): if True, any sampled points will be added to existing
            ones instead of replacing them.
    """

    def __init__(
        self,
        num_points,
        box_chance,
        box_noise_std=0.0,
        box_noise_max=None,
        minimum_box_area=0.0,
        resample_box_from_mask=False,
        point_sample_mode="random_mask",
        sample_box_scale_factor=1.0,
        geometric_query_str="geometric",
        concat_points=False,
    ):
        if isinstance(num_points, int):
            self.num_points = num_points
        else:
            # Convert from inclusive range to exclusive range expected by torch.
            # BUGFIX: build a fresh tuple instead of mutating in place — the
            # previous `self.num_points[1] += 1` clobbered the caller's list
            # and raised TypeError when a tuple was passed.
            low, high = num_points
            self.num_points = (low, high + 1)
        self.box_chance = box_chance
        self.box_noise_std = box_noise_std
        self.box_noise_max = box_noise_max
        self.minimum_box_area = minimum_box_area
        self.resample_box_from_mask = resample_box_from_mask
        self.point_sample_mode = point_sample_mode
        assert point_sample_mode in [
            "centered",
            "random_mask",
            "random_box",
        ], "Unknown point sample mode."
        self.geometric_query_str = geometric_query_str
        self.concat_points = concat_points
        self.sample_box_scale_factor = sample_box_scale_factor

    def _sample_num_points_and_if_box(self):
        """Draw (n_points, use_box); a box, when drawn, stands in for one point."""
        if isinstance(self.num_points, tuple):
            n_points = torch.randint(
                low=self.num_points[0], high=self.num_points[1], size=(1,)
            ).item()
        else:
            n_points = self.num_points
        if self.box_chance > 0.0:
            use_box = torch.rand(size=(1,)).item() < self.box_chance
            n_points -= int(use_box)  # box stands in for one point
        else:
            use_box = False
        return n_points, use_box

    def _get_original_box(self, target_object):
        """Box prompt source: annotation bbox, or mask extrema if configured."""
        if not self.resample_box_from_mask:
            return target_object.bbox
        mask = target_object.segment
        return masks_to_boxes(mask[None, :, :])[0]

    def _get_target_object(self, datapoint, query):
        """Return the single GT object this geometric query points at."""
        img = datapoint.images[query.image_id]
        targets = query.object_ids_output
        assert (
            len(targets) == 1
        ), "Geometric queries only support a single target object."
        target_idx = targets[0]
        return img.objects[target_idx]

    def __call__(self, datapoint, **kwargs):
        for query in datapoint.find_queries:
            if query.query_text != self.geometric_query_str:
                continue

            target_object = self._get_target_object(datapoint, query)
            n_points, use_box = self._sample_num_points_and_if_box()
            box = self._get_original_box(target_object)

            mask = target_object.segment
            if n_points > 0:
                # FIXME: The conversion to numpy and back to reuse code
                # is awkward, but this is all in the dataloader worker anyway
                # on CPU and so I don't think it should matter.
                if self.sample_box_scale_factor != 1.0:
                    sample_box = rescale_box_xyxy(
                        box.numpy(), self.sample_box_scale_factor, mask.shape
                    )
                else:
                    sample_box = box.numpy()
                input_points = sample_points_from_mask(
                    mask.numpy(),
                    n_points,
                    self.point_sample_mode,
                    sample_box,
                )
                input_points = torch.as_tensor(input_points)
                # Add a leading prompt-group dimension: (1, n_points, 3).
                input_points = input_points[None, :, :]
                if self.concat_points and query.input_points is not None:
                    input_points = torch.cat([query.input_points, input_points], dim=1)
            else:
                input_points = query.input_points if self.concat_points else None

            if use_box:
                # NOTE(review): `.size` is unpacked as (w, h) here; confirm
                # against InstanceToSemantic, which unpacks it as (h, w).
                w, h = datapoint.images[query.image_id].size
                input_box = noise_box(
                    box,
                    (h, w),
                    box_noise_std=self.box_noise_std,
                    box_noise_max=self.box_noise_max,
                    min_box_area=self.minimum_box_area,
                )
                input_box = input_box[None, :]
            else:
                input_box = query.input_bbox if self.concat_points else None

            query.input_points = input_points
            query.input_bbox = input_box

        return datapoint
|
||||
|
||||
|
||||
class RandomizeInputBbox:
    """
    Simplified version of the geometric transform that only deals with input boxes
    """

    def __init__(
        self,
        box_noise_std=0.0,
        box_noise_max=None,
        minimum_box_area=0.0,
    ):
        # Noise parameters, forwarded verbatim to `noise_box`.
        self.box_noise_std = box_noise_std
        self.box_noise_max = box_noise_max
        self.minimum_box_area = minimum_box_area

    def __call__(self, datapoint: Datapoint, **kwargs):
        for query in datapoint.find_queries:
            if query.input_bbox is None:
                continue

            # Resolve the image extent (h, w) regardless of backing type.
            img = datapoint.images[query.image_id].data
            if isinstance(img, PILImage.Image):
                w, h = img.size
            else:
                assert isinstance(img, torch.Tensor)
                h, w = img.shape[-2:]

            # Jitter every prompt box in place, one row at a time.
            n_boxes = query.input_bbox.shape[0]
            for idx in range(n_boxes):
                noised = noise_box(
                    query.input_bbox[idx, :].view(4),
                    (h, w),
                    box_noise_std=self.box_noise_std,
                    box_noise_max=self.box_noise_max,
                    min_box_area=self.minimum_box_area,
                )
                query.input_bbox[idx, :] = noised.view(1, 4)

        return datapoint
|
||||
157
sam3/train/transforms/segmentation.py
Normal file
157
sam3/train/transforms/segmentation.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import numpy as np
|
||||
import pycocotools.mask as mask_utils
|
||||
import torch
|
||||
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from sam3.model.box_ops import masks_to_boxes
|
||||
|
||||
from sam3.train.data.sam3_image_dataset import Datapoint
|
||||
|
||||
|
||||
class InstanceToSemantic(object):
    """Convert instance segmentation to semantic segmentation."""

    def __init__(self, delete_instance=True, use_rle=False):
        # delete_instance: drop per-object masks once merged.
        # use_rle: merge in COCO RLE space instead of dense tensors.
        self.delete_instance = delete_instance
        self.use_rle = use_rle

    def _merge_rle(self, datapoint, query, h, w):
        """Merge the query's instance RLEs into a single semantic RLE."""
        objects = datapoint.images[query.image_id].objects
        segments = [objects[obj_id].segment for obj_id in query.object_ids_output]
        if len(segments) == 0:
            # There is no good way to create an empty RLE of the correct size
            # We resort to converting an empty box to RLE
            return mask_utils.frPyObjects(
                np.array([[0, 0, 0, 0]], dtype=np.float64), h, w
            )[0]
        # we need to double check that all rles are the correct size
        # Otherwise cocotools will fail silently to an empty [0,0] mask
        reference_size = segments[0]["size"]
        for seg in segments:
            assert seg["size"] == reference_size, (
                "Instance segments have inconsistent sizes. "
                f"Found sizes {seg['size']} and {reference_size}"
            )
        return mask_utils.merge(segments)

    def _merge_dense(self, datapoint, query, h, w):
        """OR together the query's instance masks into one uint8 tensor."""
        # `semantic_target` is uint8 and remains uint8 throughout the transforms
        # (it contains binary 0 and 1 values just like `segment` for each object)
        merged = torch.zeros((h, w), dtype=torch.uint8)
        objects = datapoint.images[query.image_id].objects
        for obj_id in query.object_ids_output:
            segment = objects[obj_id].segment
            if segment is not None:
                assert (
                    isinstance(segment, torch.Tensor)
                    and segment.dtype == torch.uint8
                )
                merged |= segment
        return merged

    def __call__(self, datapoint: Datapoint, **kwargs):
        for query in datapoint.find_queries:
            # NOTE(review): `.size` is unpacked as (h, w) here but as (w, h)
            # in RandomGeometricInputsAPI.__call__ — confirm the convention.
            h, w = datapoint.images[query.image_id].size
            if self.use_rle:
                query.semantic_target = self._merge_rle(datapoint, query, h, w)
            else:
                query.semantic_target = self._merge_dense(datapoint, query, h, w)

        if self.delete_instance:
            for img in datapoint.images:
                for obj in img.objects:
                    del obj.segment
                    obj.segment = None

        return datapoint
|
||||
|
||||
|
||||
class RecomputeBoxesFromMasks:
    """Recompute bounding boxes from masks."""

    def __call__(self, datapoint: Datapoint, **kwargs):
        for image in datapoint.images:
            for obj in image.objects:
                # Note: if the mask is empty, the bounding box will be undefined
                # The empty targets should be subsequently filtered
                obj.bbox = masks_to_boxes(obj.segment)
                obj.area = obj.segment.sum().item()

        return datapoint
|
||||
|
||||
|
||||
class DecodeRle:
    """This transform decodes RLEs into binary segments.
    Implementing it as a transforms allows lazy loading. Some transforms (eg query filters)
    may be deleting masks, so decoding them from the beginning is wasteful.

    This transforms needs to be called before any kind of geometric manipulation
    """

    def __call__(self, datapoint: Datapoint, **kwargs):
        # Map image index -> (h, w), used to validate semantic targets below.
        imgId2size = {}
        warning_shown = False
        for imgId, img in enumerate(datapoint.images):
            if isinstance(img.data, PILImage.Image):
                # PIL's `size` is (width, height).
                img_w, img_h = img.data.size
            elif isinstance(img.data, torch.Tensor):
                # Tensor images are (..., H, W), so shape[-2:] is (height, width).
                # BUGFIX: this previously unpacked as `img_w, img_h`, swapping
                # the dimensions for every non-square tensor image and breaking
                # all the size checks/resizes below.
                img_h, img_w = img.data.shape[-2:]
            else:
                raise RuntimeError(f"Unexpected image type {type(img.data)}")

            imgId2size[imgId] = (img_h, img_w)

            for obj in img.objects:
                if obj.segment is not None and not isinstance(
                    obj.segment, torch.Tensor
                ):
                    if mask_utils.area(obj.segment) == 0:
                        # Empty RLE: fall back to a box-shaped mask (at least
                        # 1px in each dimension) so downstream code has pixels.
                        print("Warning, empty mask found, approximating from box")
                        obj.segment = torch.zeros(img_h, img_w, dtype=torch.uint8)
                        x1, y1, x2, y2 = obj.bbox.int().tolist()
                        obj.segment[y1 : max(y2, y1 + 1), x1 : max(x1 + 1, x2)] = 1
                    else:
                        obj.segment = mask_utils.decode(obj.segment)
                        # segment is uint8 and remains uint8 throughout the transforms
                        obj.segment = torch.tensor(obj.segment).to(torch.uint8)

                    if list(obj.segment.shape) != [img_h, img_w]:
                        # Should not happen often, but adding for security
                        if not warning_shown:
                            print(
                                f"Warning expected instance segmentation size to be {[img_h, img_w]} but found {list(obj.segment.shape)}"
                            )
                            # Printing only once per datapoint to avoid spam
                            warning_shown = True

                        obj.segment = F.resize(
                            obj.segment[None], (img_h, img_w)
                        ).squeeze(0)

                    assert list(obj.segment.shape) == [img_h, img_w]

        warning_shown = False
        for query in datapoint.find_queries:
            if query.semantic_target is not None and not isinstance(
                query.semantic_target, torch.Tensor
            ):
                query.semantic_target = mask_utils.decode(query.semantic_target)
                # segment is uint8 and remains uint8 throughout the transforms
                query.semantic_target = torch.tensor(query.semantic_target).to(
                    torch.uint8
                )
                if tuple(query.semantic_target.shape) != imgId2size[query.image_id]:
                    if not warning_shown:
                        print(
                            f"Warning expected semantic segmentation size to be {imgId2size[query.image_id]} but found {tuple(query.semantic_target.shape)}"
                        )
                        # Printing only once per datapoint to avoid spam
                        warning_shown = True

                    query.semantic_target = F.resize(
                        query.semantic_target[None], imgId2size[query.image_id]
                    ).squeeze(0)

                assert tuple(query.semantic_target.shape) == imgId2size[query.image_id]

        return datapoint
|
||||
Reference in New Issue
Block a user