Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
560
sam3/agent/helpers/masks.py
Executable file
560
sam3/agent/helpers/masks.py
Executable file
@@ -0,0 +1,560 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
from typing import Any, Iterator, List, Union
|
||||
|
||||
import numpy as np
|
||||
import pycocotools.mask as mask_util
|
||||
import torch
|
||||
from torch import device
|
||||
|
||||
from .boxes import Boxes
|
||||
from .memory import retry_if_cuda_oom
|
||||
|
||||
from .roi_align import ROIAlign
|
||||
|
||||
|
||||
def polygon_area(x, y):
|
||||
# Using the shoelace formula
|
||||
# https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
|
||||
return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
|
||||
|
||||
|
||||
def polygons_to_bitmask(
|
||||
polygons: List[np.ndarray], height: int, width: int
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Args:
|
||||
polygons (list[ndarray]): each array has shape (Nx2,)
|
||||
height, width (int)
|
||||
|
||||
Returns:
|
||||
ndarray: a bool mask of shape (height, width)
|
||||
"""
|
||||
if len(polygons) == 0:
|
||||
# COCOAPI does not support empty polygons
|
||||
return np.zeros((height, width)).astype(bool)
|
||||
rles = mask_util.frPyObjects(polygons, height, width)
|
||||
rle = mask_util.merge(rles)
|
||||
return mask_util.decode(rle).astype(bool)
|
||||
|
||||
|
||||
def rasterize_polygons_within_box(
|
||||
polygons: List[np.ndarray], box: np.ndarray, mask_size: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Rasterize the polygons into a mask image and
|
||||
crop the mask content in the given box.
|
||||
The cropped mask is resized to (mask_size, mask_size).
|
||||
|
||||
This function is used when generating training targets for mask head in Mask R-CNN.
|
||||
Given original ground-truth masks for an image, new ground-truth mask
|
||||
training targets in the size of `mask_size x mask_size`
|
||||
must be provided for each predicted box. This function will be called to
|
||||
produce such targets.
|
||||
|
||||
Args:
|
||||
polygons (list[ndarray[float]]): a list of polygons, which represents an instance.
|
||||
box: 4-element numpy array
|
||||
mask_size (int):
|
||||
|
||||
Returns:
|
||||
Tensor: BoolTensor of shape (mask_size, mask_size)
|
||||
"""
|
||||
# 1. Shift the polygons w.r.t the boxes
|
||||
w, h = box[2] - box[0], box[3] - box[1]
|
||||
|
||||
polygons = copy.deepcopy(polygons)
|
||||
for p in polygons:
|
||||
p[0::2] = p[0::2] - box[0]
|
||||
p[1::2] = p[1::2] - box[1]
|
||||
|
||||
# 2. Rescale the polygons to the new box size
|
||||
# max() to avoid division by small number
|
||||
ratio_h = mask_size / max(h, 0.1)
|
||||
ratio_w = mask_size / max(w, 0.1)
|
||||
|
||||
if ratio_h == ratio_w:
|
||||
for p in polygons:
|
||||
p *= ratio_h
|
||||
else:
|
||||
for p in polygons:
|
||||
p[0::2] *= ratio_w
|
||||
p[1::2] *= ratio_h
|
||||
|
||||
# 3. Rasterize the polygons with coco api
|
||||
mask = polygons_to_bitmask(polygons, mask_size, mask_size)
|
||||
mask = torch.from_numpy(mask)
|
||||
return mask
|
||||
|
||||
|
||||
class BitMasks:
|
||||
"""
|
||||
This class stores the segmentation masks for all objects in one image, in
|
||||
the form of bitmaps.
|
||||
|
||||
Attributes:
|
||||
tensor: bool Tensor of N,H,W, representing N instances in the image.
|
||||
"""
|
||||
|
||||
def __init__(self, tensor: Union[torch.Tensor, np.ndarray]):
|
||||
"""
|
||||
Args:
|
||||
tensor: bool Tensor of N,H,W, representing N instances in the image.
|
||||
"""
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
tensor = tensor.to(torch.bool)
|
||||
else:
|
||||
tensor = torch.as_tensor(
|
||||
tensor, dtype=torch.bool, device=torch.device("cpu")
|
||||
)
|
||||
assert tensor.dim() == 3, tensor.size()
|
||||
self.image_size = tensor.shape[1:]
|
||||
self.tensor = tensor
|
||||
|
||||
@torch.jit.unused
|
||||
def to(self, *args: Any, **kwargs: Any) -> "BitMasks":
|
||||
return BitMasks(self.tensor.to(*args, **kwargs))
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.tensor.device
|
||||
|
||||
@torch.jit.unused
|
||||
def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "BitMasks":
|
||||
"""
|
||||
Returns:
|
||||
BitMasks: Create a new :class:`BitMasks` by indexing.
|
||||
|
||||
The following usage are allowed:
|
||||
|
||||
1. `new_masks = masks[3]`: return a `BitMasks` which contains only one mask.
|
||||
2. `new_masks = masks[2:10]`: return a slice of masks.
|
||||
3. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
|
||||
with `length = len(masks)`. Nonzero elements in the vector will be selected.
|
||||
|
||||
Note that the returned object might share storage with this object,
|
||||
subject to Pytorch's indexing semantics.
|
||||
"""
|
||||
if isinstance(item, int):
|
||||
return BitMasks(self.tensor[item].unsqueeze(0))
|
||||
m = self.tensor[item]
|
||||
assert (
|
||||
m.dim() == 3
|
||||
), "Indexing on BitMasks with {} returns a tensor with shape {}!".format(
|
||||
item, m.shape
|
||||
)
|
||||
return BitMasks(m)
|
||||
|
||||
@torch.jit.unused
|
||||
def __iter__(self) -> torch.Tensor:
|
||||
yield from self.tensor
|
||||
|
||||
@torch.jit.unused
|
||||
def __repr__(self) -> str:
|
||||
s = self.__class__.__name__ + "("
|
||||
s += "num_instances={})".format(len(self.tensor))
|
||||
return s
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.tensor.shape[0]
|
||||
|
||||
def nonempty(self) -> torch.Tensor:
|
||||
"""
|
||||
Find masks that are non-empty.
|
||||
|
||||
Returns:
|
||||
Tensor: a BoolTensor which represents
|
||||
whether each mask is empty (False) or non-empty (True).
|
||||
"""
|
||||
return self.tensor.flatten(1).any(dim=1)
|
||||
|
||||
@staticmethod
|
||||
def from_polygon_masks(
|
||||
polygon_masks: Union["PolygonMasks", List[List[np.ndarray]]],
|
||||
height: int,
|
||||
width: int,
|
||||
) -> "BitMasks":
|
||||
"""
|
||||
Args:
|
||||
polygon_masks (list[list[ndarray]] or PolygonMasks)
|
||||
height, width (int)
|
||||
"""
|
||||
if isinstance(polygon_masks, PolygonMasks):
|
||||
polygon_masks = polygon_masks.polygons
|
||||
masks = [polygons_to_bitmask(p, height, width) for p in polygon_masks]
|
||||
if len(masks):
|
||||
return BitMasks(torch.stack([torch.from_numpy(x) for x in masks]))
|
||||
else:
|
||||
return BitMasks(torch.empty(0, height, width, dtype=torch.bool))
|
||||
|
||||
@staticmethod
|
||||
def from_roi_masks(roi_masks: "ROIMasks", height: int, width: int) -> "BitMasks":
|
||||
"""
|
||||
Args:
|
||||
roi_masks:
|
||||
height, width (int):
|
||||
"""
|
||||
return roi_masks.to_bitmasks(height, width)
|
||||
|
||||
def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Crop each bitmask by the given box, and resize results to (mask_size, mask_size).
|
||||
This can be used to prepare training targets for Mask R-CNN.
|
||||
It has less reconstruction error compared to rasterization with polygons.
|
||||
However we observe no difference in accuracy,
|
||||
but BitMasks requires more memory to store all the masks.
|
||||
|
||||
Args:
|
||||
boxes (Tensor): Nx4 tensor storing the boxes for each mask
|
||||
mask_size (int): the size of the rasterized mask.
|
||||
|
||||
Returns:
|
||||
Tensor:
|
||||
A bool tensor of shape (N, mask_size, mask_size), where
|
||||
N is the number of predicted boxes for this image.
|
||||
"""
|
||||
assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))
|
||||
device = self.tensor.device
|
||||
|
||||
batch_inds = torch.arange(len(boxes), device=device).to(dtype=boxes.dtype)[
|
||||
:, None
|
||||
]
|
||||
rois = torch.cat([batch_inds, boxes], dim=1) # Nx5
|
||||
|
||||
bit_masks = self.tensor.to(dtype=torch.float32)
|
||||
rois = rois.to(device=device)
|
||||
output = (
|
||||
ROIAlign((mask_size, mask_size), 1.0, 0, aligned=True)
|
||||
.forward(bit_masks[:, None, :, :], rois)
|
||||
.squeeze(1)
|
||||
)
|
||||
output = output >= 0.5
|
||||
return output
|
||||
|
||||
def get_bounding_boxes(self) -> Boxes:
|
||||
"""
|
||||
Returns:
|
||||
Boxes: tight bounding boxes around bitmasks.
|
||||
If a mask is empty, it's bounding box will be all zero.
|
||||
"""
|
||||
boxes = torch.zeros(self.tensor.shape[0], 4, dtype=torch.float32)
|
||||
x_any = torch.any(self.tensor, dim=1)
|
||||
y_any = torch.any(self.tensor, dim=2)
|
||||
for idx in range(self.tensor.shape[0]):
|
||||
x = torch.where(x_any[idx, :])[0]
|
||||
y = torch.where(y_any[idx, :])[0]
|
||||
if len(x) > 0 and len(y) > 0:
|
||||
boxes[idx, :] = torch.as_tensor(
|
||||
[x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=torch.float32
|
||||
)
|
||||
return Boxes(boxes)
|
||||
|
||||
@staticmethod
|
||||
def cat(bitmasks_list: List["BitMasks"]) -> "BitMasks":
|
||||
"""
|
||||
Concatenates a list of BitMasks into a single BitMasks
|
||||
|
||||
Arguments:
|
||||
bitmasks_list (list[BitMasks])
|
||||
|
||||
Returns:
|
||||
BitMasks: the concatenated BitMasks
|
||||
"""
|
||||
assert isinstance(bitmasks_list, (list, tuple))
|
||||
assert len(bitmasks_list) > 0
|
||||
assert all(isinstance(bitmask, BitMasks) for bitmask in bitmasks_list)
|
||||
|
||||
cat_bitmasks = type(bitmasks_list[0])(
|
||||
torch.cat([bm.tensor for bm in bitmasks_list], dim=0)
|
||||
)
|
||||
return cat_bitmasks
|
||||
|
||||
|
||||
class PolygonMasks:
|
||||
"""
|
||||
This class stores the segmentation masks for all objects in one image, in the form of polygons.
|
||||
|
||||
Attributes:
|
||||
polygons: list[list[ndarray]]. Each ndarray is a float64 vector representing a polygon.
|
||||
"""
|
||||
|
||||
def __init__(self, polygons: List[List[Union[torch.Tensor, np.ndarray]]]):
|
||||
"""
|
||||
Arguments:
|
||||
polygons (list[list[np.ndarray]]): The first
|
||||
level of the list correspond to individual instances,
|
||||
the second level to all the polygons that compose the
|
||||
instance, and the third level to the polygon coordinates.
|
||||
The third level array should have the format of
|
||||
[x0, y0, x1, y1, ..., xn, yn] (n >= 3).
|
||||
"""
|
||||
if not isinstance(polygons, list):
|
||||
raise ValueError(
|
||||
"Cannot create PolygonMasks: Expect a list of list of polygons per image. "
|
||||
"Got '{}' instead.".format(type(polygons))
|
||||
)
|
||||
|
||||
def _make_array(t: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
|
||||
# Use float64 for higher precision, because why not?
|
||||
# Always put polygons on CPU (self.to is a no-op) since they
|
||||
# are supposed to be small tensors.
|
||||
# May need to change this assumption if GPU placement becomes useful
|
||||
if isinstance(t, torch.Tensor):
|
||||
t = t.cpu().numpy()
|
||||
return np.asarray(t).astype("float64")
|
||||
|
||||
def process_polygons(
|
||||
polygons_per_instance: List[Union[torch.Tensor, np.ndarray]],
|
||||
) -> List[np.ndarray]:
|
||||
if not isinstance(polygons_per_instance, list):
|
||||
raise ValueError(
|
||||
"Cannot create polygons: Expect a list of polygons per instance. "
|
||||
"Got '{}' instead.".format(type(polygons_per_instance))
|
||||
)
|
||||
# transform each polygon to a numpy array
|
||||
polygons_per_instance = [_make_array(p) for p in polygons_per_instance]
|
||||
for polygon in polygons_per_instance:
|
||||
if len(polygon) % 2 != 0 or len(polygon) < 6:
|
||||
raise ValueError(
|
||||
f"Cannot create a polygon from {len(polygon)} coordinates."
|
||||
)
|
||||
return polygons_per_instance
|
||||
|
||||
self.polygons: List[List[np.ndarray]] = [
|
||||
process_polygons(polygons_per_instance)
|
||||
for polygons_per_instance in polygons
|
||||
]
|
||||
|
||||
def to(self, *args: Any, **kwargs: Any) -> "PolygonMasks":
|
||||
return self
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device("cpu")
|
||||
|
||||
def get_bounding_boxes(self) -> Boxes:
|
||||
"""
|
||||
Returns:
|
||||
Boxes: tight bounding boxes around polygon masks.
|
||||
"""
|
||||
boxes = torch.zeros(len(self.polygons), 4, dtype=torch.float32)
|
||||
for idx, polygons_per_instance in enumerate(self.polygons):
|
||||
minxy = torch.as_tensor([float("inf"), float("inf")], dtype=torch.float32)
|
||||
maxxy = torch.zeros(2, dtype=torch.float32)
|
||||
for polygon in polygons_per_instance:
|
||||
coords = torch.from_numpy(polygon).view(-1, 2).to(dtype=torch.float32)
|
||||
minxy = torch.min(minxy, torch.min(coords, dim=0).values)
|
||||
maxxy = torch.max(maxxy, torch.max(coords, dim=0).values)
|
||||
boxes[idx, :2] = minxy
|
||||
boxes[idx, 2:] = maxxy
|
||||
return Boxes(boxes)
|
||||
|
||||
def nonempty(self) -> torch.Tensor:
|
||||
"""
|
||||
Find masks that are non-empty.
|
||||
|
||||
Returns:
|
||||
Tensor:
|
||||
a BoolTensor which represents whether each mask is empty (False) or not (True).
|
||||
"""
|
||||
keep = [1 if len(polygon) > 0 else 0 for polygon in self.polygons]
|
||||
return torch.from_numpy(np.asarray(keep, dtype=bool))
|
||||
|
||||
def __getitem__(
|
||||
self, item: Union[int, slice, List[int], torch.BoolTensor]
|
||||
) -> "PolygonMasks":
|
||||
"""
|
||||
Support indexing over the instances and return a `PolygonMasks` object.
|
||||
`item` can be:
|
||||
|
||||
1. An integer. It will return an object with only one instance.
|
||||
2. A slice. It will return an object with the selected instances.
|
||||
3. A list[int]. It will return an object with the selected instances,
|
||||
correpsonding to the indices in the list.
|
||||
4. A vector mask of type BoolTensor, whose length is num_instances.
|
||||
It will return an object with the instances whose mask is nonzero.
|
||||
"""
|
||||
if isinstance(item, int):
|
||||
selected_polygons = [self.polygons[item]]
|
||||
elif isinstance(item, slice):
|
||||
selected_polygons = self.polygons[item]
|
||||
elif isinstance(item, list):
|
||||
selected_polygons = [self.polygons[i] for i in item]
|
||||
elif isinstance(item, torch.Tensor):
|
||||
# Polygons is a list, so we have to move the indices back to CPU.
|
||||
if item.dtype == torch.bool:
|
||||
assert item.dim() == 1, item.shape
|
||||
item = item.nonzero().squeeze(1).cpu().numpy().tolist()
|
||||
elif item.dtype in [torch.int32, torch.int64]:
|
||||
item = item.cpu().numpy().tolist()
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported tensor dtype={} for indexing!".format(item.dtype)
|
||||
)
|
||||
selected_polygons = [self.polygons[i] for i in item]
|
||||
return PolygonMasks(selected_polygons)
|
||||
|
||||
def __iter__(self) -> Iterator[List[np.ndarray]]:
|
||||
"""
|
||||
Yields:
|
||||
list[ndarray]: the polygons for one instance.
|
||||
Each Tensor is a float64 vector representing a polygon.
|
||||
"""
|
||||
return iter(self.polygons)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = self.__class__.__name__ + "("
|
||||
s += "num_instances={})".format(len(self.polygons))
|
||||
return s
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.polygons)
|
||||
|
||||
def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Crop each mask by the given box, and resize results to (mask_size, mask_size).
|
||||
This can be used to prepare training targets for Mask R-CNN.
|
||||
|
||||
Args:
|
||||
boxes (Tensor): Nx4 tensor storing the boxes for each mask
|
||||
mask_size (int): the size of the rasterized mask.
|
||||
|
||||
Returns:
|
||||
Tensor: A bool tensor of shape (N, mask_size, mask_size), where
|
||||
N is the number of predicted boxes for this image.
|
||||
"""
|
||||
assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))
|
||||
|
||||
device = boxes.device
|
||||
# Put boxes on the CPU, as the polygon representation is not efficient GPU-wise
|
||||
# (several small tensors for representing a single instance mask)
|
||||
boxes = boxes.to(torch.device("cpu"))
|
||||
|
||||
results = [
|
||||
rasterize_polygons_within_box(poly, box.numpy(), mask_size)
|
||||
for poly, box in zip(self.polygons, boxes)
|
||||
]
|
||||
"""
|
||||
poly: list[list[float]], the polygons for one instance
|
||||
box: a tensor of shape (4,)
|
||||
"""
|
||||
if len(results) == 0:
|
||||
return torch.empty(0, mask_size, mask_size, dtype=torch.bool, device=device)
|
||||
return torch.stack(results, dim=0).to(device=device)
|
||||
|
||||
def area(self):
|
||||
"""
|
||||
Computes area of the mask.
|
||||
Only works with Polygons, using the shoelace formula:
|
||||
https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
|
||||
|
||||
Returns:
|
||||
Tensor: a vector, area for each instance
|
||||
"""
|
||||
|
||||
area = []
|
||||
for polygons_per_instance in self.polygons:
|
||||
area_per_instance = 0
|
||||
for p in polygons_per_instance:
|
||||
area_per_instance += polygon_area(p[0::2], p[1::2])
|
||||
area.append(area_per_instance)
|
||||
|
||||
return torch.tensor(area)
|
||||
|
||||
@staticmethod
|
||||
def cat(polymasks_list: List["PolygonMasks"]) -> "PolygonMasks":
|
||||
"""
|
||||
Concatenates a list of PolygonMasks into a single PolygonMasks
|
||||
|
||||
Arguments:
|
||||
polymasks_list (list[PolygonMasks])
|
||||
|
||||
Returns:
|
||||
PolygonMasks: the concatenated PolygonMasks
|
||||
"""
|
||||
assert isinstance(polymasks_list, (list, tuple))
|
||||
assert len(polymasks_list) > 0
|
||||
assert all(isinstance(polymask, PolygonMasks) for polymask in polymasks_list)
|
||||
|
||||
cat_polymasks = type(polymasks_list[0])(
|
||||
list(itertools.chain.from_iterable(pm.polygons for pm in polymasks_list))
|
||||
)
|
||||
return cat_polymasks
|
||||
|
||||
|
||||
class ROIMasks:
|
||||
"""
|
||||
Represent masks by N smaller masks defined in some ROIs. Once ROI boxes are given,
|
||||
full-image bitmask can be obtained by "pasting" the mask on the region defined
|
||||
by the corresponding ROI box.
|
||||
"""
|
||||
|
||||
def __init__(self, tensor: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
tensor: (N, M, M) mask tensor that defines the mask within each ROI.
|
||||
"""
|
||||
if tensor.dim() != 3:
|
||||
raise ValueError("ROIMasks must take a masks of 3 dimension.")
|
||||
self.tensor = tensor
|
||||
|
||||
def to(self, device: torch.device) -> "ROIMasks":
|
||||
return ROIMasks(self.tensor.to(device))
|
||||
|
||||
@property
|
||||
def device(self) -> device:
|
||||
return self.tensor.device
|
||||
|
||||
def __len__(self):
|
||||
return self.tensor.shape[0]
|
||||
|
||||
def __getitem__(self, item) -> "ROIMasks":
|
||||
"""
|
||||
Returns:
|
||||
ROIMasks: Create a new :class:`ROIMasks` by indexing.
|
||||
|
||||
The following usage are allowed:
|
||||
|
||||
1. `new_masks = masks[2:10]`: return a slice of masks.
|
||||
2. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
|
||||
with `length = len(masks)`. Nonzero elements in the vector will be selected.
|
||||
|
||||
Note that the returned object might share storage with this object,
|
||||
subject to Pytorch's indexing semantics.
|
||||
"""
|
||||
t = self.tensor[item]
|
||||
if t.dim() != 3:
|
||||
raise ValueError(
|
||||
f"Indexing on ROIMasks with {item} returns a tensor with shape {t.shape}!"
|
||||
)
|
||||
return ROIMasks(t)
|
||||
|
||||
@torch.jit.unused
|
||||
def __repr__(self) -> str:
|
||||
s = self.__class__.__name__ + "("
|
||||
s += "num_instances={})".format(len(self.tensor))
|
||||
return s
|
||||
|
||||
@torch.jit.unused
|
||||
def to_bitmasks(self, boxes: torch.Tensor, height, width, threshold=0.5):
|
||||
"""
|
||||
Args: see documentation of :func:`paste_masks_in_image`.
|
||||
"""
|
||||
from detectron2.layers.mask_ops import (
|
||||
_paste_masks_tensor_shape,
|
||||
paste_masks_in_image,
|
||||
)
|
||||
|
||||
if torch.jit.is_tracing():
|
||||
if isinstance(height, torch.Tensor):
|
||||
paste_func = _paste_masks_tensor_shape
|
||||
else:
|
||||
paste_func = paste_masks_in_image
|
||||
else:
|
||||
paste_func = retry_if_cuda_oom(paste_masks_in_image)
|
||||
bitmasks = paste_func(
|
||||
self.tensor, boxes.tensor, (height, width), threshold=threshold
|
||||
)
|
||||
return BitMasks(bitmasks)
|
||||
Reference in New Issue
Block a user