Initial commit

fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
facebook-github-bot
2025-11-18 23:07:42 -08:00
commit a13e358df4
504 changed files with 122758 additions and 0 deletions

1
sam3/agent/helpers/__init__.py Executable file
View File

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

438
sam3/agent/helpers/boxes.py Executable file
View File

@@ -0,0 +1,438 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import math
from enum import IntEnum, unique
from typing import List, Tuple, Union
import numpy as np
import torch
from torch import device
# Any raw box container accepted by BoxMode.convert.
_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]


@unique
class BoxMode(IntEnum):
    """
    Enum of different ways to represent a box.
    """

    XYXY_ABS = 0
    """
    (x0, y0, x1, y1) in absolute floating points coordinates.
    The coordinates in range [0, width or height].
    """
    XYWH_ABS = 1
    """
    (x0, y0, w, h) in absolute floating points coordinates.
    """
    XYXY_REL = 2
    """
    Not yet supported!
    (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
    """
    XYWH_REL = 3
    """
    Not yet supported!
    (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
    """
    XYWHA_ABS = 4
    """
    (xc, yc, w, h, a) in absolute floating points coordinates.
    (xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
    """

    @staticmethod
    def convert(
        box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode"
    ) -> _RawBoxType:
        """
        Convert a box (or a batch of boxes) between representations.

        Args:
            box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
            from_mode, to_mode (BoxMode)
        Returns:
            The converted box of the same type (tuple/list/ndarray/Tensor).
        Raises:
            NotImplementedError: for unsupported mode pairs.
        """
        if from_mode == to_mode:
            return box
        original_type = type(box)
        is_numpy = isinstance(box, np.ndarray)
        single_box = isinstance(box, (list, tuple))
        if single_box:
            assert len(box) == 4 or len(box) == 5, (
                "BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
                " where k == 4 or 5"
            )
            arr = torch.tensor(box)[None, :]
        else:
            # avoid modifying the input box
            if is_numpy:
                arr = torch.from_numpy(np.asarray(box)).clone()
            else:
                arr = box.clone()
        assert to_mode not in [
            BoxMode.XYXY_REL,
            BoxMode.XYWH_REL,
        ] and from_mode not in [
            BoxMode.XYXY_REL,
            BoxMode.XYWH_REL,
        ], "Relative mode not yet supported!"
        if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
            assert (
                arr.shape[-1] == 5
            ), "The last dimension of input shape must be 5 for XYWHA format"
            original_dtype = arr.dtype
            # compute in double to limit rounding error, restore dtype at the end
            arr = arr.double()
            w = arr[:, 2]
            h = arr[:, 3]
            a = arr[:, 4]
            c = torch.abs(torch.cos(a * math.pi / 180.0))
            s = torch.abs(torch.sin(a * math.pi / 180.0))
            # This basically computes the horizontal bounding rectangle of the rotated box
            new_w = c * w + s * h
            new_h = c * h + s * w
            # convert center to top-left corner
            arr[:, 0] -= new_w / 2.0
            arr[:, 1] -= new_h / 2.0
            # bottom-right corner
            arr[:, 2] = arr[:, 0] + new_w
            arr[:, 3] = arr[:, 1] + new_h
            arr = arr[:, :4].to(dtype=original_dtype)
        elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
            original_dtype = arr.dtype
            arr = arr.double()
            # shift top-left corner to the box center
            arr[:, 0] += arr[:, 2] / 2.0
            arr[:, 1] += arr[:, 3] / 2.0
            # BUGFIX: create the angle column on the same device as `arr`
            # (previously defaulted to CPU, which breaks CUDA inputs), and use
            # torch.cat's documented `dim=` keyword instead of the numpy alias.
            angles = torch.zeros(
                (arr.shape[0], 1), dtype=arr.dtype, device=arr.device
            )
            arr = torch.cat((arr, angles), dim=1).to(dtype=original_dtype)
        else:
            if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS:
                arr[:, 2] += arr[:, 0]
                arr[:, 3] += arr[:, 1]
            elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
                arr[:, 2] -= arr[:, 0]
                arr[:, 3] -= arr[:, 1]
            else:
                raise NotImplementedError(
                    "Conversion from BoxMode {} to {} is not supported yet".format(
                        from_mode, to_mode
                    )
                )
        if single_box:
            return original_type(arr.flatten().tolist())
        if is_numpy:
            return arr.numpy()
        else:
            return arr
class Boxes:
    """
    Stores N axis-aligned boxes as an Nx4 float32 tensor, one row per box in
    (x1, y1, x2, y2) order. Provides common box operations (`area`, `clip`,
    `nonempty`, ...) and tensor-like behavior: indexing, `to(device)`,
    `.device`, and iteration over all boxes.

    Attributes:
        tensor (torch.Tensor): float matrix of Nx4. Each row is (x1, y1, x2, y2).
    """

    def __init__(self, tensor: torch.Tensor):
        """
        Args:
            tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2).
        """
        if isinstance(tensor, torch.Tensor):
            tensor = tensor.to(torch.float32)
        else:
            tensor = torch.as_tensor(
                tensor, dtype=torch.float32, device=torch.device("cpu")
            )
        if tensor.numel() == 0:
            # Use reshape so the result still depends on the input tensor
            # (a fresh unrelated tensor would confuse torch.jit tracing).
            tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32)
        assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
        self.tensor = tensor

    def clone(self) -> "Boxes":
        """Return a deep copy of this Boxes."""
        return Boxes(self.tensor.clone())

    def to(self, device: torch.device):
        # Boxes are always float32; only device conversion is supported.
        return Boxes(self.tensor.to(device=device))

    def area(self) -> torch.Tensor:
        """
        Returns:
            torch.Tensor: a vector with the area of each box.
        """
        t = self.tensor
        return (t[:, 2] - t[:, 0]) * (t[:, 3] - t[:, 1])

    def clip(self, box_size: Tuple[int, int]) -> None:
        """
        Clip (in place) the boxes by limiting x coordinates to [0, width]
        and y coordinates to [0, height].

        Args:
            box_size (height, width): The clipping box's size.
        """
        assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
        h, w = box_size
        clipped = self.tensor.clone()
        # even columns (x1, x2) are bounded by width; odd (y1, y2) by height
        clipped[:, 0::2].clamp_(min=0, max=w)
        clipped[:, 1::2].clamp_(min=0, max=h)
        self.tensor = clipped

    def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
        """
        Find boxes that are non-empty: a box is empty if either side is
        no larger than the threshold.

        Returns:
            Tensor: a binary vector, True for each non-empty box.
        """
        sides = self.tensor[:, 2:] - self.tensor[:, :2]  # (N, 2) = (w, h)
        return (sides > threshold).all(dim=1)

    def __getitem__(self, item) -> "Boxes":
        """
        Index the boxes with an int, a slice, or a BoolTensor of length N.
        Returns a new Boxes, which may share storage with this one
        (subject to Pytorch's indexing semantics).
        """
        if isinstance(item, int):
            return Boxes(self.tensor[item].view(1, -1))
        selected = self.tensor[item]
        assert selected.dim() == 2, (
            "Indexing on Boxes with {} failed to return a matrix!".format(item)
        )
        return Boxes(selected)

    def __len__(self) -> int:
        return self.tensor.shape[0]

    def __repr__(self) -> str:
        return "Boxes(" + str(self.tensor) + ")"

    def inside_box(
        self, box_size: Tuple[int, int], boundary_threshold: int = 0
    ) -> torch.Tensor:
        """
        Args:
            box_size (height, width): size of the reference box.
            boundary_threshold (int): boxes extending beyond the reference box
                by more than this amount are considered "outside".
        Returns:
            a binary vector, True for each box inside the reference box.
        """
        height, width = box_size
        t = self.tensor
        low_ok = (t[..., 0] >= -boundary_threshold) & (
            t[..., 1] >= -boundary_threshold
        )
        high_ok = (t[..., 2] < width + boundary_threshold) & (
            t[..., 3] < height + boundary_threshold
        )
        return low_ok & high_ok

    def get_centers(self) -> torch.Tensor:
        """Return the box centers as an Nx2 array of (x, y)."""
        return (self.tensor[:, :2] + self.tensor[:, 2:]) / 2

    def scale(self, scale_x: float, scale_y: float) -> None:
        """Scale the boxes in place by per-axis factors."""
        self.tensor[:, 0::2] *= scale_x
        self.tensor[:, 1::2] *= scale_y

    @classmethod
    def cat(cls, boxes_list: List["Boxes"]) -> "Boxes":
        """
        Concatenate a list of Boxes into a single Boxes.

        Arguments:
            boxes_list (list[Boxes])
        Returns:
            Boxes: the concatenated Boxes
        """
        assert isinstance(boxes_list, (list, tuple))
        if len(boxes_list) == 0:
            return cls(torch.empty(0))
        assert all(isinstance(b, Boxes) for b in boxes_list)
        # torch.cat never shares storage with its inputs
        return cls(torch.cat([b.tensor for b in boxes_list], dim=0))

    @property
    def device(self) -> device:
        return self.tensor.device

    # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
    # https://github.com/pytorch/pytorch/issues/18627
    @torch.jit.unused
    def __iter__(self):
        """Yield a box as a Tensor of shape (4,) at a time."""
        yield from self.tensor
def pairwise_intersection(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
    """
    Given two lists of boxes of size N and M,
    compute the intersection area between __all__ N x M pairs of boxes.
    The box order must be (xmin, ymin, xmax, ymax).

    Args:
        boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.
    Returns:
        Tensor: intersection, sized [N,M].
    """
    t1, t2 = boxes1.tensor, boxes2.tensor
    # broadcast to [N, M, 2]: upper-left and lower-right of each overlap
    lt = torch.max(t1[:, None, :2], t2[:, :2])
    rb = torch.min(t1[:, None, 2:], t2[:, 2:])
    wh = (rb - lt).clamp(min=0)  # zero out non-overlapping pairs
    return wh[..., 0] * wh[..., 1]  # [N, M]
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
    """
    Given two lists of boxes of size N and M, compute the IoU
    (intersection over union) between **all** N x M pairs of boxes.
    The box order must be (xmin, ymin, xmax, ymax).

    Args:
        boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.
    Returns:
        Tensor: IoU, sized [N,M].
    """
    area1 = boxes1.area()  # [N]
    area2 = boxes2.area()  # [M]
    inter = pairwise_intersection(boxes1, boxes2)
    union = area1[:, None] + area2 - inter
    # zero-intersection pairs (incl. empty boxes) are forced to IoU 0
    zero = torch.zeros(1, dtype=inter.dtype, device=inter.device)
    return torch.where(inter > 0, inter / union, zero)
def pairwise_ioa(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
    """
    Similar to :func:`pairwise_iou` but compute the IoA (intersection over boxes2 area).

    Args:
        boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.
    Returns:
        Tensor: IoA, sized [N,M].
    """
    area2 = boxes2.area()  # [M]
    inter = pairwise_intersection(boxes1, boxes2)
    # zero-intersection pairs (incl. empty boxes) are forced to IoA 0
    zero = torch.zeros(1, dtype=inter.dtype, device=inter.device)
    return torch.where(inter > 0, inter / area2, zero)
def pairwise_point_box_distance(points: torch.Tensor, boxes: Boxes):
    """
    Pairwise distance between N points and M boxes. The distance between a
    point and a box is represented by the distance from the point to 4 edges
    of the box. Distances are all positive when the point is inside the box.

    Args:
        points: Nx2 coordinates. Each row is (x, y)
        boxes: M boxes
    Returns:
        Tensor: distances of size (N, M, 4). The 4 values are distances from
        the point to the left, top, right, bottom of the box.
    """
    px, py = points.unsqueeze(dim=2).unbind(dim=1)  # each (N, 1)
    left, top, right, bottom = boxes.tensor.unsqueeze(dim=0).unbind(dim=2)  # each (1, M)
    return torch.stack([px - left, py - top, right - px, bottom - py], dim=2)
def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
    """
    Compute pairwise intersection over union (IOU) of two sets of matched
    boxes that have the same number of boxes.
    Similar to :func:`pairwise_iou`, but computes only diagonal elements of the matrix.

    Args:
        boxes1 (Boxes): bounding boxes, sized [N,4].
        boxes2 (Boxes): same length as boxes1
    Returns:
        Tensor: iou, sized [N].
    """
    # BUGFIX: the previous implicit string concatenation was missing a space
    # and rendered the message as "...the samenumber of entries...".
    assert len(boxes1) == len(boxes2), (
        "boxlists should have the same number of entries, got {}, {}".format(
            len(boxes1), len(boxes2)
        )
    )
    area1 = boxes1.area()  # [N]
    area2 = boxes2.area()  # [N]
    box1, box2 = boxes1.tensor, boxes2.tensor
    lt = torch.max(box1[:, :2], box2[:, :2])  # [N,2]
    rb = torch.min(box1[:, 2:], box2[:, 2:])  # [N,2]
    wh = (rb - lt).clamp(min=0)  # [N,2]
    inter = wh[:, 0] * wh[:, 1]  # [N]
    iou = inter / (area1 + area2 - inter)  # [N]
    return iou

View File

@@ -0,0 +1,150 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
An awesome colormap for really neat visualizations.
Copied from Detectron, and removed gray colors.
"""
import random
import numpy as np
__all__ = ["colormap", "random_color", "random_colors"]
# A list of 20 bright and sharp colors for segmentation masks,
# generated from the edges of the sRGB color space for maximum intensity.
# Values are RGB floats in [0, 1]; reshaped to (20, 3) below.
_COLORS = (
    np.array(
        [
            # The original 8 sharp colors
            1.000, 1.000, 0.000,  # 1. Yellow
            0.000, 1.000, 0.000,  # 2. Lime
            0.000, 1.000, 1.000,  # 3. Cyan
            1.000, 0.000, 1.000,  # 4. Magenta
            1.000, 0.000, 0.000,  # 5. Red
            1.000, 0.498, 0.000,  # 6. Orange
            0.498, 1.000, 0.000,  # 7. Chartreuse
            0.000, 1.000, 0.498,  # 8. Spring Green
            1.000, 0.000, 0.498,  # 9. Rose
            0.498, 0.000, 1.000,  # 10. Violet
            0.753, 1.000, 0.000,  # 11. Electric Lime
            1.000, 0.753, 0.000,  # 12. Vivid Orange
            0.000, 1.000, 0.753,  # 13. Turquoise
            0.753, 0.000, 1.000,  # 14. Bright Violet
            1.000, 0.000, 0.753,  # 15. Bright Pink
            1.000, 0.251, 0.000,  # 16. Fiery Orange
            0.251, 1.000, 0.000,  # 17. Bright Chartreuse
            0.000, 1.000, 0.251,  # 18. Malachite Green
            0.251, 0.000, 1.000,  # 19. Deep Violet
            1.000, 0.000, 0.251,  # 20. Hot Pink
        ]
    )
    .astype(np.float32)
    .reshape(-1, 3)
)
def colormap(rgb=False, maximum=255):
    """
    Args:
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1
    Returns:
        ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1]
    """
    assert maximum in [255, 1], maximum
    scaled = _COLORS * maximum
    # reverse the channel axis to get BGR when RGB is not requested
    return scaled if rgb else scaled[:, ::-1]
def random_color(rgb=False, maximum=255):
    """
    Args:
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1
    Returns:
        ndarray: a vector of 3 numbers
    """
    choice = _COLORS[np.random.randint(0, len(_COLORS))] * maximum
    # flip channel order to BGR unless RGB was requested
    return choice if rgb else choice[::-1]
def random_colors(N, rgb=False, maximum=255):
    """
    Args:
        N (int): number of unique colors needed
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1
    Returns:
        ndarray: a list of random_color
    """
    # sample without replacement, so N must not exceed the palette size
    picks = random.sample(range(len(_COLORS)), N)
    colors = [_COLORS[i] * maximum for i in picks]
    return colors if rgb else [c[::-1] for c in colors]
if __name__ == "__main__":
    # Visual sanity check: tile each palette color over a random canvas
    # and display it with OpenCV (press any key to close).
    import cv2

    size = 100
    H, W = 10, 10
    canvas = np.random.rand(H * size, W * size, 3).astype("float32")
    for h in range(H):
        for w in range(W):
            idx = h * W + w
            if idx >= len(_COLORS):
                break  # fewer colors than grid cells; leave the rest random
            canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx]
    cv2.imshow("a", canvas)
    cv2.waitKey(0)

244
sam3/agent/helpers/keypoints.py Executable file
View File

@@ -0,0 +1,244 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Any, List, Tuple, Union
import numpy as np
import torch
from torch.nn import functional as F
class Keypoints:
"""
Stores keypoint **annotation** data. GT Instances have a `gt_keypoints` property
containing the x,y location and visibility flag of each keypoint. This tensor has shape
(N, K, 3) where N is the number of instances and K is the number of keypoints per instance.
The visibility flag follows the COCO format and must be one of three integers:
* v=0: not labeled (in which case x=y=0)
* v=1: labeled but not visible
* v=2: labeled and visible
"""
def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]):
"""
Arguments:
keypoints: A Tensor, numpy array, or list of the x, y, and visibility of each keypoint.
The shape should be (N, K, 3) where N is the number of
instances, and K is the number of keypoints per instance.
"""
device = (
keypoints.device
if isinstance(keypoints, torch.Tensor)
else torch.device("cpu")
)
keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device)
assert keypoints.dim() == 3 and keypoints.shape[2] == 3, keypoints.shape
self.tensor = keypoints
def __len__(self) -> int:
return self.tensor.size(0)
def to(self, *args: Any, **kwargs: Any) -> "Keypoints":
return type(self)(self.tensor.to(*args, **kwargs))
@property
def device(self) -> torch.device:
return self.tensor.device
def to_heatmap(self, boxes: torch.Tensor, heatmap_size: int) -> torch.Tensor:
"""
Convert keypoint annotations to a heatmap of one-hot labels for training,
as described in :paper:`Mask R-CNN`.
Arguments:
boxes: Nx4 tensor, the boxes to draw the keypoints to
Returns:
heatmaps:
A tensor of shape (N, K), each element is integer spatial label
in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
valid:
A tensor of shape (N, K) containing whether each keypoint is in the roi or not.
"""
return _keypoints_to_heatmap(self.tensor, boxes, heatmap_size)
def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints":
"""
Create a new `Keypoints` by indexing on this `Keypoints`.
The following usage are allowed:
1. `new_kpts = kpts[3]`: return a `Keypoints` which contains only one instance.
2. `new_kpts = kpts[2:10]`: return a slice of key points.
3. `new_kpts = kpts[vector]`, where vector is a torch.ByteTensor
with `length = len(kpts)`. Nonzero elements in the vector will be selected.
Note that the returned Keypoints might share storage with this Keypoints,
subject to Pytorch's indexing semantics.
"""
if isinstance(item, int):
return Keypoints([self.tensor[item]])
return Keypoints(self.tensor[item])
def __repr__(self) -> str:
s = self.__class__.__name__ + "("
s += "num_instances={})".format(len(self.tensor))
return s
@staticmethod
def cat(keypoints_list: List["Keypoints"]) -> "Keypoints":
"""
Concatenates a list of Keypoints into a single Keypoints
Arguments:
keypoints_list (list[Keypoints])
Returns:
Keypoints: the concatenated Keypoints
"""
assert isinstance(keypoints_list, (list, tuple))
assert len(keypoints_list) > 0
assert all(isinstance(keypoints, Keypoints) for keypoints in keypoints_list)
cat_kpts = type(keypoints_list[0])(
torch.cat([kpts.tensor for kpts in keypoints_list], dim=0)
)
return cat_kpts
def _keypoints_to_heatmap(
keypoints: torch.Tensor, rois: torch.Tensor, heatmap_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Encode keypoint locations into a target heatmap for use in SoftmaxWithLoss across space.
Maps keypoints from the half-open interval [x1, x2) on continuous image coordinates to the
closed interval [0, heatmap_size - 1] on discrete image coordinates. We use the
continuous-discrete conversion from Heckbert 1990 ("What is the coordinate of a pixel?"):
d = floor(c) and c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
Arguments:
keypoints: tensor of keypoint locations in of shape (N, K, 3).
rois: Nx4 tensor of rois in xyxy format
heatmap_size: integer side length of square heatmap.
Returns:
heatmaps: A tensor of shape (N, K) containing an integer spatial label
in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
valid: A tensor of shape (N, K) containing whether each keypoint is in
the roi or not.
"""
if rois.numel() == 0:
return rois.new().long(), rois.new().long()
offset_x = rois[:, 0]
offset_y = rois[:, 1]
scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
offset_x = offset_x[:, None]
offset_y = offset_y[:, None]
scale_x = scale_x[:, None]
scale_y = scale_y[:, None]
x = keypoints[..., 0]
y = keypoints[..., 1]
x_boundary_inds = x == rois[:, 2][:, None]
y_boundary_inds = y == rois[:, 3][:, None]
x = (x - offset_x) * scale_x
x = x.floor().long()
y = (y - offset_y) * scale_y
y = y.floor().long()
x[x_boundary_inds] = heatmap_size - 1
y[y_boundary_inds] = heatmap_size - 1
valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
vis = keypoints[..., 2] > 0
valid = (valid_loc & vis).long()
lin_ind = y * heatmap_size + x
heatmaps = lin_ind * valid
return heatmaps, valid
@torch.jit.script_if_tracing
def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
    """
    Extract predicted keypoint locations from heatmaps.

    Args:
        maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for
            each ROI and each keypoint.
        rois (Tensor): (#ROIs, 4). The box of each ROI.

    Returns:
        Tensor of shape (#ROIs, #keypoints, 4) with the last dimension corresponding to
        (x, y, logit, score) for each keypoint.

    When converting discrete pixel indices in an NxN image to a continuous keypoint coordinate,
    we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from
    Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
    """
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]
    # clamp to at least 1 pixel so interpolation targets are never empty
    widths = (rois[:, 2] - rois[:, 0]).clamp(min=1)
    heights = (rois[:, 3] - rois[:, 1]).clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()
    num_rois, num_keypoints = maps.shape[:2]
    xy_preds = maps.new_zeros(rois.shape[0], num_keypoints, 4)
    # map the rounded-up integer resolution back to the true box size
    width_corrections = widths / widths_ceil
    height_corrections = heights / heights_ceil
    keypoints_idx = torch.arange(num_keypoints, device=maps.device)
    for i in range(num_rois):
        # upsample the pooled heatmap to the (rounded-up) box resolution
        outsize = (int(heights_ceil[i]), int(widths_ceil[i]))
        roi_map = F.interpolate(
            maps[[i]], size=outsize, mode="bicubic", align_corners=False
        )
        # Although semantically equivalent, `reshape` is used instead of `squeeze` due
        # to limitation during ONNX export of `squeeze` in scripting mode
        roi_map = roi_map.reshape(roi_map.shape[1:])  # keypoints x H x W
        # softmax over the spatial region
        max_score, _ = roi_map.view(num_keypoints, -1).max(1)
        max_score = max_score.view(num_keypoints, 1, 1)
        tmp_full_resolution = (roi_map - max_score).exp_()
        tmp_pool_resolution = (maps[i] - max_score).exp_()
        # Produce scores over the region H x W, but normalize with POOL_H x POOL_W,
        # so that the scores of objects of different absolute sizes will be more comparable
        roi_map_scores = tmp_full_resolution / tmp_pool_resolution.sum(
            (1, 2), keepdim=True
        )
        w = roi_map.shape[2]
        # per-keypoint argmax over the flattened upsampled map
        pos = roi_map.view(num_keypoints, -1).argmax(1)
        x_int = pos % w
        y_int = (pos - x_int) // w
        # sanity check: the logit argmax must also be the score argmax
        assert (
            roi_map_scores[keypoints_idx, y_int, x_int]
            == roi_map_scores.view(num_keypoints, -1).max(1)[0]
        ).all()
        # +0.5: discrete-to-continuous conversion (Heckbert 1990)
        x = (x_int.float() + 0.5) * width_corrections[i]
        y = (y_int.float() + 0.5) * height_corrections[i]
        xy_preds[i, :, 0] = x + offset_x[i]
        xy_preds[i, :, 1] = y + offset_y[i]
        xy_preds[i, :, 2] = roi_map[keypoints_idx, y_int, x_int]
        xy_preds[i, :, 3] = roi_map_scores[keypoints_idx, y_int, x_int]
    return xy_preds

View File

@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Dict, List
import numpy as np
import torch
try:
from pycocotools import mask as mask_utils
except Exception:
mask_utils = None
def mask_intersection(
    masks1: torch.Tensor, masks2: torch.Tensor, block_size: int = 16
) -> torch.Tensor:
    """
    Compute pairwise intersection areas between two sets of boolean masks.

    Args:
        masks1: (N, H, W) bool tensor.
        masks2: (M, H, W) bool tensor with the same spatial size.
        block_size: number of masks processed per tile; bounds the size of the
            (block, block, H, W) intermediate to limit peak memory.
    Returns:
        Tensor: (N, M) long tensor of intersection pixel counts.
    """
    assert masks1.shape[1:] == masks2.shape[1:]
    assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
    n, m = masks1.shape[0], masks2.shape[0]
    result = torch.zeros(n, m, device=masks1.device, dtype=torch.long)
    for row in range(0, n, block_size):
        rows = masks1[row : row + block_size]
        for col in range(0, m, block_size):
            cols = masks2[col : col + block_size]
            # broadcast AND, then count overlapping pixels per pair
            pair_counts = (rows[:, None] & cols[None, :]).flatten(-2).sum(-1)
            result[row : row + block_size, col : col + block_size] = pair_counts
    return result
def mask_iom(masks1: torch.Tensor, masks2: torch.Tensor) -> torch.Tensor:
    """
    Pairwise intersection-over-minimum-area between two sets of bool masks.

    Args:
        masks1: (N, H, W) bool tensor.
        masks2: (M, H, W) bool tensor with the same spatial size.
    Returns:
        Tensor: (N, M) float tensor of IoM values.
    """
    assert masks1.shape[1:] == masks2.shape[1:]
    assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
    inter = mask_intersection(masks1, masks2)
    area1 = masks1.flatten(-2).sum(-1)  # (N,)
    area2 = masks2.flatten(-2).sum(-1)  # (M,)
    # clamp keeps empty masks from dividing by zero
    min_area = torch.min(area1[:, None], area2[None, :]).clamp_min(1)
    return inter.float() / (min_area.float() + 1e-8)
def _decode_single_mask(mask_repr, h: int, w: int) -> np.ndarray:
if isinstance(mask_repr, (list, tuple, np.ndarray)):
arr = np.array(mask_repr)
if arr.ndim != 2:
raise ValueError("Mask array must be 2D (H, W).")
return (arr > 0).astype(np.uint8)
if mask_utils is None:
raise ImportError(
"pycocotools is required to decode RLE mask strings. pip install pycocotools"
)
if not isinstance(mask_repr, (str, bytes)):
raise ValueError("Unsupported mask representation type for RLE decode.")
rle = {
"counts": mask_repr if isinstance(mask_repr, (str, bytes)) else str(mask_repr),
"size": [h, w],
}
decoded = mask_utils.decode(rle)
if decoded.ndim == 3:
decoded = decoded[:, :, 0]
return (decoded > 0).astype(np.uint8)
def _decode_masks_to_torch_bool(pred_masks: List, h: int, w: int) -> torch.Tensor:
    """Decode a list of mask representations into one (N, H, W) bool tensor."""
    decoded = [_decode_single_mask(m, h, w) for m in pred_masks]
    stacked = np.stack(decoded, axis=0).astype(np.uint8)  # (N, H, W)
    return torch.from_numpy(stacked > 0)
def remove_overlapping_masks(sample: Dict, iom_thresh: float = 0.3) -> Dict:
    """
    Greedy keep: sort by score desc; keep a mask if IoM to all kept masks <= threshold.
    If pred_masks has length 0 or 1, returns sample unchanged (no extra keys).
    """
    # Basic presence checks — leave the sample untouched if there is nothing to do.
    pred_masks = sample.get("pred_masks")
    if not isinstance(pred_masks, list):
        return sample
    num = len(pred_masks)
    # --- Early exit: 0 or 1 mask -> do NOT modify the JSON at all ---
    if num <= 1:
        return sample
    # From here on we have at least 2 masks.
    h = int(sample["orig_img_h"])
    w = int(sample["orig_img_w"])
    scores = sample.get("pred_scores", [1.0] * num)  # fallback if scores missing
    boxes = sample.get("pred_boxes", None)
    assert num == len(scores), "pred_masks and pred_scores must have same length"
    if boxes is not None:
        assert num == len(boxes), "pred_masks and pred_boxes must have same length"
    masks_bool = _decode_masks_to_torch_bool(pred_masks, h, w)  # (N, H, W)
    by_score = sorted(range(num), key=lambda i: float(scores[i]), reverse=True)
    kept: List[int] = []
    for idx in by_score:
        if kept:
            # IoM of the candidate against every already-kept (higher-scored) mask
            ioms = mask_iom(
                masks_bool[idx].unsqueeze(0),
                torch.stack([masks_bool[k] for k in kept], dim=0),
            ).squeeze(0)  # (K,)
            if bool(torch.any(ioms > iom_thresh)):
                continue  # overlaps too much with a kept mask
        kept.append(idx)
    kept_sorted = sorted(kept)
    kept_set = set(kept_sorted)
    # Build filtered JSON (this *does* modify fields; only for N>=2 case)
    result = dict(sample)
    result["pred_masks"] = [pred_masks[i] for i in kept_sorted]
    result["pred_scores"] = [scores[i] for i in kept_sorted]
    if boxes is not None:
        result["pred_boxes"] = [boxes[i] for i in kept_sorted]
    result["kept_indices"] = kept_sorted
    result["removed_indices"] = [i for i in range(num) if i not in kept_set]
    result["iom_threshold"] = float(iom_thresh)
    return result

560
sam3/agent/helpers/masks.py Executable file
View File

@@ -0,0 +1,560 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import copy
import itertools
from typing import Any, Iterator, List, Union
import numpy as np
import pycocotools.mask as mask_util
import torch
from torch import device
from .boxes import Boxes
from .memory import retry_if_cuda_oom
from .roi_align import ROIAlign
def polygon_area(x, y):
    """Return the area of a simple polygon given its vertex coordinates."""
    # Shoelace formula:
    # https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
    cross_terms = np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))
    return 0.5 * np.abs(cross_terms)
def polygons_to_bitmask(
    polygons: List[np.ndarray], height: int, width: int
) -> np.ndarray:
    """
    Rasterize polygons into a single binary mask.

    Args:
        polygons (list[ndarray]): each array has shape (Nx2,)
        height, width (int)
    Returns:
        ndarray: a bool mask of shape (height, width)
    """
    if not polygons:
        # COCOAPI does not support empty polygons
        return np.zeros((height, width)).astype(bool)
    encoded = mask_util.frPyObjects(polygons, height, width)
    merged = mask_util.merge(encoded)
    return mask_util.decode(merged).astype(bool)
def rasterize_polygons_within_box(
    polygons: List[np.ndarray], box: np.ndarray, mask_size: int
) -> torch.Tensor:
    """
    Rasterize the polygons into a mask image and crop the mask content in the
    given box; the cropped mask is resized to (mask_size, mask_size).

    This function is used when generating training targets for mask head in
    Mask R-CNN: given original ground-truth masks for an image, new
    ground-truth mask training targets of size `mask_size x mask_size` must
    be produced for each predicted box.

    Args:
        polygons (list[ndarray[float]]): a list of polygons, which represents an instance.
        box: 4-element numpy array
        mask_size (int):
    Returns:
        Tensor: BoolTensor of shape (mask_size, mask_size)
    """
    box_w, box_h = box[2] - box[0], box[3] - box[1]
    # 1. Shift the polygons w.r.t the boxes (on a private copy)
    shifted = copy.deepcopy(polygons)
    for poly in shifted:
        poly[0::2] = poly[0::2] - box[0]
        poly[1::2] = poly[1::2] - box[1]
    # 2. Rescale the polygons to the new box size;
    # max() avoids division by a degenerate box side
    ratio_h = mask_size / max(box_h, 0.1)
    ratio_w = mask_size / max(box_w, 0.1)
    if ratio_h == ratio_w:
        for poly in shifted:
            poly *= ratio_h
    else:
        for poly in shifted:
            poly[0::2] *= ratio_w
            poly[1::2] *= ratio_h
    # 3. Rasterize the polygons with coco api
    bitmask = polygons_to_bitmask(shifted, mask_size, mask_size)
    return torch.from_numpy(bitmask)
class BitMasks:
"""
This class stores the segmentation masks for all objects in one image, in
the form of bitmaps.
Attributes:
tensor: bool Tensor of N,H,W, representing N instances in the image.
"""
def __init__(self, tensor: Union[torch.Tensor, np.ndarray]):
    """
    Args:
        tensor: bool Tensor of N,H,W, representing N instances in the image.
    """
    if isinstance(tensor, torch.Tensor):
        # keep the tensor on its current device; only cast to bool
        tensor = tensor.to(torch.bool)
    else:
        # numpy arrays / nested lists are converted on CPU
        tensor = torch.as_tensor(
            tensor, dtype=torch.bool, device=torch.device("cpu")
        )
    assert tensor.dim() == 3, tensor.size()
    # (H, W) of the image these masks belong to
    self.image_size = tensor.shape[1:]
    self.tensor = tensor
@torch.jit.unused
def to(self, *args: Any, **kwargs: Any) -> "BitMasks":
    """Return a new BitMasks whose tensor is converted via ``Tensor.to``."""
    converted = self.tensor.to(*args, **kwargs)
    return BitMasks(converted)
@property
def device(self) -> torch.device:
    # Device of the underlying mask tensor.
    return self.tensor.device
@torch.jit.unused
def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "BitMasks":
    """
    Returns:
        BitMasks: Create a new :class:`BitMasks` by indexing.

    The following usage are allowed:

    1. `new_masks = masks[3]`: return a `BitMasks` which contains only one mask.
    2. `new_masks = masks[2:10]`: return a slice of masks.
    3. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
    with `length = len(masks)`. Nonzero elements in the vector will be selected.

    Note that the returned object might share storage with this object,
    subject to Pytorch's indexing semantics.
    """
    if isinstance(item, int):
        # int indexing drops the instance dim; restore it to keep (1, H, W)
        return BitMasks(self.tensor[item].unsqueeze(0))
    m = self.tensor[item]
    assert (
        m.dim() == 3
    ), "Indexing on BitMasks with {} returns a tensor with shape {}!".format(
        item, m.shape
    )
    return BitMasks(m)
@torch.jit.unused
def __iter__(self) -> torch.Tensor:
    """Yield one (H, W) bool mask at a time."""
    yield from self.tensor
@torch.jit.unused
def __repr__(self) -> str:
    """e.g. ``BitMasks(num_instances=3)``."""
    return "{}(num_instances={})".format(self.__class__.__name__, len(self.tensor))
def __len__(self) -> int:
    """Number of mask instances stored."""
    return self.tensor.shape[0]
def nonempty(self) -> torch.Tensor:
    """
    Find masks that are non-empty.

    Returns:
        Tensor: a BoolTensor which represents
        whether each mask is empty (False) or non-empty (True).
    """
    per_instance = self.tensor.flatten(1)  # (N, H*W)
    return per_instance.any(dim=1)
@staticmethod
def from_polygon_masks(
    polygon_masks: Union["PolygonMasks", List[List[np.ndarray]]],
    height: int,
    width: int,
) -> "BitMasks":
    """
    Rasterize polygon instances into a BitMasks of shape (N, height, width).

    Args:
        polygon_masks (list[list[ndarray]] or PolygonMasks)
        height, width (int)
    """
    if isinstance(polygon_masks, PolygonMasks):
        polygon_masks = polygon_masks.polygons
    masks = [polygons_to_bitmask(p, height, width) for p in polygon_masks]
    if len(masks):
        return BitMasks(torch.stack([torch.from_numpy(x) for x in masks]))
    else:
        # no instances: keep an explicit (0, H, W) bool tensor
        return BitMasks(torch.empty(0, height, width, dtype=torch.bool))
@staticmethod
def from_roi_masks(roi_masks: "ROIMasks", height: int, width: int) -> "BitMasks":
    """
    Convert ROIMasks (masks defined relative to their boxes) into
    full-image BitMasks.

    Args:
        roi_masks: per-box masks to paste into the image
        height, width (int): output image size
    """
    return roi_masks.to_bitmasks(height, width)
def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
    """
    Crop each bitmask by the given box, and resize results to (mask_size, mask_size).
    This can be used to prepare training targets for Mask R-CNN.
    It has less reconstruction error compared to rasterization with polygons.
    However we observe no difference in accuracy,
    but BitMasks requires more memory to store all the masks.

    Args:
        boxes (Tensor): Nx4 tensor storing the boxes for each mask
        mask_size (int): the size of the rasterized mask.

    Returns:
        Tensor:
            A bool tensor of shape (N, mask_size, mask_size), where
            N is the number of predicted boxes for this image.
    """
    assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))
    device = self.tensor.device

    # Prefix each box with its batch index to form Nx5 ROIs, as expected by
    # ROIAlign: mask i is cropped by box i.
    batch_inds = torch.arange(len(boxes), device=device).to(dtype=boxes.dtype)[
        :, None
    ]
    rois = torch.cat([batch_inds, boxes], dim=1)  # Nx5

    # ROIAlign interpolates bilinearly, so compute in float32 and threshold
    # the result back to bool afterwards.
    bit_masks = self.tensor.to(dtype=torch.float32)
    rois = rois.to(device=device)
    output = (
        ROIAlign((mask_size, mask_size), 1.0, 0, aligned=True)
        .forward(bit_masks[:, None, :, :], rois)
        .squeeze(1)
    )
    output = output >= 0.5
    return output
def get_bounding_boxes(self) -> Boxes:
    """
    Returns:
        Boxes: tight axis-aligned bounding boxes around the bitmasks;
        an empty mask yields an all-zero box.
    """
    num_masks = self.tensor.shape[0]
    boxes = torch.zeros(num_masks, 4, dtype=torch.float32)
    # Column occupancy (reduce over rows) and row occupancy (reduce over columns).
    col_any = torch.any(self.tensor, dim=1)
    row_any = torch.any(self.tensor, dim=2)
    for i in range(num_masks):
        xs = torch.where(col_any[i, :])[0]
        ys = torch.where(row_any[i, :])[0]
        if len(xs) > 0 and len(ys) > 0:
            # +1 on the max index makes the box boundary exclusive.
            boxes[i, :] = torch.as_tensor(
                [xs[0], ys[0], xs[-1] + 1, ys[-1] + 1], dtype=torch.float32
            )
    return Boxes(boxes)
@staticmethod
def cat(bitmasks_list: List["BitMasks"]) -> "BitMasks":
    """
    Concatenate a non-empty list of BitMasks into a single BitMasks.

    Arguments:
        bitmasks_list (list[BitMasks])

    Returns:
        BitMasks: the concatenated BitMasks
    """
    assert isinstance(bitmasks_list, (list, tuple))
    assert len(bitmasks_list) > 0
    assert all(isinstance(bm, BitMasks) for bm in bitmasks_list)
    stacked = torch.cat([bm.tensor for bm in bitmasks_list], dim=0)
    # Instantiate through the first element's concrete type so subclasses
    # survive concatenation.
    return type(bitmasks_list[0])(stacked)
class PolygonMasks:
    """
    This class stores the segmentation masks for all objects in one image, in the form of polygons.

    Attributes:
        polygons: list[list[ndarray]]. Each ndarray is a float64 vector representing a polygon.
    """

    def __init__(self, polygons: List[List[Union[torch.Tensor, np.ndarray]]]):
        """
        Arguments:
            polygons (list[list[np.ndarray]]): The first
                level of the list correspond to individual instances,
                the second level to all the polygons that compose the
                instance, and the third level to the polygon coordinates.
                The third level array should have the format of
                [x0, y0, x1, y1, ..., xn, yn] (n >= 3).

        Raises:
            ValueError: if the nesting is not list-of-lists, or any polygon
                has an odd number of coordinates or fewer than 3 points.
        """
        if not isinstance(polygons, list):
            raise ValueError(
                "Cannot create PolygonMasks: Expect a list of list of polygons per image. "
                "Got '{}' instead.".format(type(polygons))
            )

        def _make_array(t: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
            # Use float64 for higher precision, because why not?
            # Always put polygons on CPU (self.to is a no-op) since they
            # are supposed to be small tensors.
            # May need to change this assumption if GPU placement becomes useful
            if isinstance(t, torch.Tensor):
                t = t.cpu().numpy()
            return np.asarray(t).astype("float64")

        def process_polygons(
            polygons_per_instance: List[Union[torch.Tensor, np.ndarray]],
        ) -> List[np.ndarray]:
            # Validate and normalize the polygons of a single instance.
            if not isinstance(polygons_per_instance, list):
                raise ValueError(
                    "Cannot create polygons: Expect a list of polygons per instance. "
                    "Got '{}' instead.".format(type(polygons_per_instance))
                )
            # transform each polygon to a numpy array
            polygons_per_instance = [_make_array(p) for p in polygons_per_instance]
            for polygon in polygons_per_instance:
                # A valid polygon needs at least 3 (x, y) pairs, i.e. 6 numbers.
                if len(polygon) % 2 != 0 or len(polygon) < 6:
                    raise ValueError(
                        f"Cannot create a polygon from {len(polygon)} coordinates."
                    )
            return polygons_per_instance

        # Per-instance lists of float64 coordinate vectors.
        self.polygons: List[List[np.ndarray]] = [
            process_polygons(polygons_per_instance)
            for polygons_per_instance in polygons
        ]

    def to(self, *args: Any, **kwargs: Any) -> "PolygonMasks":
        # Polygons are stored as CPU numpy arrays, so device moves are no-ops.
        return self

    @property
    def device(self) -> torch.device:
        # Polygons always live on the CPU.
        return torch.device("cpu")

    def get_bounding_boxes(self) -> Boxes:
        """
        Returns:
            Boxes: tight bounding boxes around polygon masks.
        """
        boxes = torch.zeros(len(self.polygons), 4, dtype=torch.float32)
        for idx, polygons_per_instance in enumerate(self.polygons):
            # Running min/max over every vertex of every polygon of the instance.
            minxy = torch.as_tensor([float("inf"), float("inf")], dtype=torch.float32)
            maxxy = torch.zeros(2, dtype=torch.float32)
            for polygon in polygons_per_instance:
                coords = torch.from_numpy(polygon).view(-1, 2).to(dtype=torch.float32)
                minxy = torch.min(minxy, torch.min(coords, dim=0).values)
                maxxy = torch.max(maxxy, torch.max(coords, dim=0).values)
            boxes[idx, :2] = minxy
            boxes[idx, 2:] = maxxy
        return Boxes(boxes)

    def nonempty(self) -> torch.Tensor:
        """
        Find masks that are non-empty.

        Returns:
            Tensor:
                a BoolTensor which represents whether each mask is empty (False) or not (True).
        """
        # An instance is non-empty when it has at least one polygon.
        keep = [1 if len(polygon) > 0 else 0 for polygon in self.polygons]
        return torch.from_numpy(np.asarray(keep, dtype=bool))

    def __getitem__(
        self, item: Union[int, slice, List[int], torch.BoolTensor]
    ) -> "PolygonMasks":
        """
        Support indexing over the instances and return a `PolygonMasks` object.
        `item` can be:

        1. An integer. It will return an object with only one instance.
        2. A slice. It will return an object with the selected instances.
        3. A list[int]. It will return an object with the selected instances,
           corresponding to the indices in the list.
        4. A vector mask of type BoolTensor, whose length is num_instances.
           It will return an object with the instances whose mask is nonzero.
        """
        if isinstance(item, int):
            selected_polygons = [self.polygons[item]]
        elif isinstance(item, slice):
            selected_polygons = self.polygons[item]
        elif isinstance(item, list):
            selected_polygons = [self.polygons[i] for i in item]
        elif isinstance(item, torch.Tensor):
            # Polygons is a list, so we have to move the indices back to CPU.
            if item.dtype == torch.bool:
                assert item.dim() == 1, item.shape
                item = item.nonzero().squeeze(1).cpu().numpy().tolist()
            elif item.dtype in [torch.int32, torch.int64]:
                item = item.cpu().numpy().tolist()
            else:
                raise ValueError(
                    "Unsupported tensor dtype={} for indexing!".format(item.dtype)
                )
            selected_polygons = [self.polygons[i] for i in item]
        # NOTE(review): an unsupported `item` type falls through all branches
        # and raises NameError on `selected_polygons` — confirm acceptable.
        return PolygonMasks(selected_polygons)

    def __iter__(self) -> Iterator[List[np.ndarray]]:
        """
        Yields:
            list[ndarray]: the polygons for one instance.
            Each Tensor is a float64 vector representing a polygon.
        """
        return iter(self.polygons)

    def __repr__(self) -> str:
        # Short summary: class name and instance count.
        s = self.__class__.__name__ + "("
        s += "num_instances={})".format(len(self.polygons))
        return s

    def __len__(self) -> int:
        # Number of instances.
        return len(self.polygons)

    def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
        """
        Crop each mask by the given box, and resize results to (mask_size, mask_size).
        This can be used to prepare training targets for Mask R-CNN.

        Args:
            boxes (Tensor): Nx4 tensor storing the boxes for each mask
            mask_size (int): the size of the rasterized mask.

        Returns:
            Tensor: A bool tensor of shape (N, mask_size, mask_size), where
            N is the number of predicted boxes for this image.
        """
        assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))
        device = boxes.device
        # Put boxes on the CPU, as the polygon representation is not efficient GPU-wise
        # (several small tensors for representing a single instance mask)
        boxes = boxes.to(torch.device("cpu"))
        results = [
            rasterize_polygons_within_box(poly, box.numpy(), mask_size)
            for poly, box in zip(self.polygons, boxes)
        ]
        """
        poly: list[list[float]], the polygons for one instance
        box: a tensor of shape (4,)
        """
        if len(results) == 0:
            return torch.empty(0, mask_size, mask_size, dtype=torch.bool, device=device)
        return torch.stack(results, dim=0).to(device=device)

    def area(self):
        """
        Computes area of the mask.
        Only works with Polygons, using the shoelace formula:
        https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates

        Returns:
            Tensor: a vector, area for each instance
        """
        area = []
        for polygons_per_instance in self.polygons:
            area_per_instance = 0
            # Total area is the sum over the instance's component polygons.
            for p in polygons_per_instance:
                area_per_instance += polygon_area(p[0::2], p[1::2])
            area.append(area_per_instance)
        return torch.tensor(area)

    @staticmethod
    def cat(polymasks_list: List["PolygonMasks"]) -> "PolygonMasks":
        """
        Concatenates a list of PolygonMasks into a single PolygonMasks

        Arguments:
            polymasks_list (list[PolygonMasks])

        Returns:
            PolygonMasks: the concatenated PolygonMasks
        """
        assert isinstance(polymasks_list, (list, tuple))
        assert len(polymasks_list) > 0
        assert all(isinstance(polymask, PolygonMasks) for polymask in polymasks_list)
        cat_polymasks = type(polymasks_list[0])(
            list(itertools.chain.from_iterable(pm.polygons for pm in polymasks_list))
        )
        return cat_polymasks
class ROIMasks:
    """
    Represent masks by N smaller masks defined in some ROIs. Once ROI boxes are given,
    full-image bitmask can be obtained by "pasting" the mask on the region defined
    by the corresponding ROI box.
    """

    def __init__(self, tensor: torch.Tensor):
        """
        Args:
            tensor: (N, M, M) mask tensor that defines the mask within each ROI.

        Raises:
            ValueError: if ``tensor`` is not 3-dimensional.
        """
        if tensor.dim() != 3:
            raise ValueError("ROIMasks must take a masks of 3 dimension.")
        self.tensor = tensor

    def to(self, device: torch.device) -> "ROIMasks":
        # Return a new ROIMasks with the tensor moved to the given device.
        return ROIMasks(self.tensor.to(device))

    @property
    def device(self) -> device:
        # Device on which the ROI mask tensor lives.
        return self.tensor.device

    def __len__(self):
        # Number of ROI masks.
        return self.tensor.shape[0]

    def __getitem__(self, item) -> "ROIMasks":
        """
        Returns:
            ROIMasks: Create a new :class:`ROIMasks` by indexing.

        The following usage are allowed:

        1. `new_masks = masks[2:10]`: return a slice of masks.
        2. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
           with `length = len(masks)`. Nonzero elements in the vector will be selected.

        Note that the returned object might share storage with this object,
        subject to Pytorch's indexing semantics.
        """
        t = self.tensor[item]
        # Integer indexing would drop a dimension, hence is rejected here.
        if t.dim() != 3:
            raise ValueError(
                f"Indexing on ROIMasks with {item} returns a tensor with shape {t.shape}!"
            )
        return ROIMasks(t)

    @torch.jit.unused
    def __repr__(self) -> str:
        # Short summary: class name and instance count.
        s = self.__class__.__name__ + "("
        s += "num_instances={})".format(len(self.tensor))
        return s

    @torch.jit.unused
    def to_bitmasks(self, boxes: torch.Tensor, height, width, threshold=0.5):
        """
        Paste the ROI-local masks into a full-image :class:`BitMasks`.

        Args: see documentation of :func:`paste_masks_in_image`.
        """
        from detectron2.layers.mask_ops import (
            _paste_masks_tensor_shape,
            paste_masks_in_image,
        )

        if torch.jit.is_tracing():
            # Under tracing, avoid retry_if_cuda_oom's Python control flow and
            # pick the paste function matching tensor vs int shape arguments.
            if isinstance(height, torch.Tensor):
                paste_func = _paste_masks_tensor_shape
            else:
                paste_func = paste_masks_in_image
        else:
            paste_func = retry_if_cuda_oom(paste_masks_in_image)
        bitmasks = paste_func(
            self.tensor, boxes.tensor, (height, width), threshold=threshold
        )
        return BitMasks(bitmasks)

87
sam3/agent/helpers/memory.py Executable file
View File

@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import logging
from contextlib import contextmanager
from functools import wraps
import torch
__all__ = ["retry_if_cuda_oom"]
@contextmanager
def _ignore_torch_cuda_oom():
"""
A context which ignores CUDA OOM exception from pytorch.
"""
try:
yield
except RuntimeError as e:
# NOTE: the string may change?
if "CUDA out of memory. " in str(e):
pass
else:
raise
def retry_if_cuda_oom(func):
    """
    Wrap ``func`` so that pytorch CUDA OOM errors trigger retries.

    The wrapper first retries after calling ``torch.cuda.empty_cache()``.
    If that still fails, it retries once more after moving tensor-like
    arguments to the CPU, expecting ``func`` to dispatch to a CPU
    implementation. Returned values may then be CPU tensors; it is the
    caller's responsibility to move them back to CUDA if needed.

    Args:
        func: a stateless callable that takes tensor-like objects as arguments

    Returns:
        a callable which retries `func` if OOM is encountered.

    Examples:
    ::
        output = retry_if_cuda_oom(some_torch_function)(input1, input2)
        # output may be on CPU even if inputs are on GPU

    Note:
        1. Only top-level arguments exposing both ``.device`` and ``.to``
           are converted; nested structures of tensors are not supported.
        2. ``func`` may be invoked several times, so it must be stateless.
    """

    def _cpu_copy_if_cuda(value):
        # Duck-typed: anything with a CUDA .device and a .to method.
        try:
            on_cuda = value.device.type == "cuda" and hasattr(value, "to")
        except AttributeError:
            on_cuda = False
        return value.to(device="cpu") if on_cuda else value

    @wraps(func)
    def wrapped(*args, **kwargs):
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # First retry: free cached blocks, then try again on the GPU.
        torch.cuda.empty_cache()
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Last resort: run on CPU. This slows things down, so leave a notice.
        logger = logging.getLogger(__name__)
        logger.info(
            "Attempting to copy inputs of {} to CPU due to CUDA OOM".format(str(func))
        )
        new_args = (_cpu_copy_if_cuda(x) for x in args)
        new_kwargs = {k: _cpu_copy_if_cuda(v) for k, v in kwargs.items()}
        return func(*new_args, **new_kwargs)

    return wrapped

122
sam3/agent/helpers/rle.py Executable file
View File

@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""Some utilities for RLE encoding that doesn't require downloading the masks to the cpu"""
import numpy as np
import torch
from pycocotools import mask as mask_util
@torch.no_grad()
def rle_encode(orig_mask, return_areas=False):
    """Encodes a collection of masks in RLE format

    This function emulates the behavior of the COCO API's encode function, but
    is executed partially on the GPU for faster execution.

    Args:
        orig_mask (torch.Tensor): A mask of shape (N, H, W) with dtype=torch.bool
        return_areas (bool): If True, add the areas of the masks as a part of
            the RLE output dict under the "area" key. Default is False.

    Returns:
        list[dict]: one COCO-style RLE dict per mask (``counts`` is a UTF-8
            string, ``size`` is [H, W]).
    """
    assert orig_mask.ndim == 3, "Mask must be of shape (N, H, W)"
    assert orig_mask.dtype == torch.bool, "Mask must have dtype=torch.bool"
    if orig_mask.numel() == 0:
        return []
    # First, transpose the spatial dimensions.
    # This is necessary because the COCO API uses Fortran order
    mask = orig_mask.transpose(1, 2)
    # Flatten the mask
    flat_mask = mask.reshape(mask.shape[0], -1)
    if return_areas:
        mask_areas = flat_mask.sum(-1).tolist()
    # Find the indices where the mask changes.
    # differences[i, j] marks a run boundary at position j of mask i.
    # Column 0 marks a leading run of ones (RLE counts start with the number
    # of zeros, possibly 0); the last column is a sentinel closing the final run.
    differences = torch.ones(
        mask.shape[0], flat_mask.shape[1] + 1, device=mask.device, dtype=torch.bool
    )
    differences[:, 1:-1] = flat_mask[:, :-1] != flat_mask[:, 1:]
    differences[:, 0] = flat_mask[:, 0]
    _, change_indices = torch.where(differences)
    try:
        # Cumulative boundary counts delimit each mask's slice inside the
        # flattened `change_indices` vector.
        boundaries = torch.cumsum(differences.sum(-1), 0).cpu()
    except RuntimeError as _:
        # Fall back to a CPU reduction if the GPU one runs out of memory.
        boundaries = torch.cumsum(differences.cpu().sum(-1), 0)
    change_indices_clone = change_indices.clone()
    # First pass computes the RLEs on GPU, in a flatten format
    for i in range(mask.shape[0]):
        # Get the change indices for this batch item
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        # Successive differences convert absolute change positions to run lengths.
        change_indices[beg + 1 : end] -= change_indices_clone[beg : end - 1]
    # Now we can split the RLES of each batch item, and convert them to strings
    # No more gpu at this point
    change_indices = change_indices.tolist()
    batch_rles = []
    # Process each mask in the batch separately
    for i in range(mask.shape[0]):
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        run_lengths = change_indices[beg:end]
        uncompressed_rle = {"counts": run_lengths, "size": list(orig_mask.shape[1:])}
        h, w = uncompressed_rle["size"]
        # Compress with pycocotools and expose `counts` as a UTF-8 string.
        rle = mask_util.frPyObjects(uncompressed_rle, h, w)
        rle["counts"] = rle["counts"].decode("utf-8")
        if return_areas:
            rle["area"] = mask_areas[i]
        batch_rles.append(rle)
    return batch_rles
def robust_rle_encode(masks):
    """Encodes a collection of masks in RLE format: tries the GPU path first
    and falls back to CPU pycocotools encoding if it fails."""
    assert masks.ndim == 3, "Mask must be of shape (N, H, W)"
    assert masks.dtype == torch.bool, "Mask must have dtype=torch.bool"
    try:
        return rle_encode(masks)
    except RuntimeError:
        np_masks = masks.cpu().numpy()
        encoded = []
        for single_mask in np_masks:
            # pycocotools expects Fortran-ordered uint8 arrays of shape HxWx1.
            rle = mask_util.encode(
                np.array(single_mask[:, :, np.newaxis], dtype=np.uint8, order="F")
            )[0]
            rle["counts"] = rle["counts"].decode("utf-8")
            encoded.append(rle)
        return encoded
def ann_to_rle(segm, im_info):
    """Convert a segmentation, which can be polygons or uncompressed RLE, to RLE.

    Args:
        segm (list | dict): a polygon list, an uncompressed RLE dict
            (``counts`` is a list), or an already-compressed RLE dict.
        im_info (dict): must contain "height" and "width" of the image.

    Returns:
        dict: the compressed RLE.
    """
    h, w = im_info["height"], im_info["width"]
    if isinstance(segm, list):
        # polygon -- a single object might consist of multiple parts
        # we merge all parts into one mask rle code
        rles = mask_util.frPyObjects(segm, h, w)
        rle = mask_util.merge(rles)
    elif isinstance(segm["counts"], list):
        # uncompressed RLE
        rle = mask_util.frPyObjects(segm, h, w)
    else:
        # rle
        rle = segm
    return rle

75
sam3/agent/helpers/roi_align.py Executable file
View File

@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from torch import nn
from torchvision.ops import roi_align
# NOTE: torchvision's RoIAlign has a different default aligned=False
class ROIAlign(nn.Module):
    """
    Region-of-interest align, a thin wrapper over ``torchvision.ops.roi_align``.

    Unlike torchvision's module, this defaults to ``aligned=True``: the ROI is
    scaled and then shifted by -0.5 before sampling, so that a continuous
    coordinate c interpolates between pixel indices floor(c - 0.5) and
    ceil(c - 0.5) (pixel centers at half-integer coordinates).
    ``aligned=False`` reproduces the legacy Detectron behavior, whose
    bilinear sampling is slightly misaligned relative to that pixel model.
    """

    def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=True):
        """
        Args:
            output_size (tuple): h, w
            spatial_scale (float): scale the input boxes by this number
            sampling_ratio (int): number of inputs samples to take for each output
                sample. 0 to take samples densely.
            aligned (bool): if False, use the legacy implementation in
                Detectron. If True, align the results more perfectly.
        """
        super().__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio
        self.aligned = aligned

        from torchvision import __version__

        # The `aligned` flag landed in torchvision 0.7
        # (https://github.com/pytorch/vision/pull/2438).
        major_minor = tuple(int(part) for part in __version__.split(".")[:2])
        assert major_minor >= (0, 7), "Require torchvision >= 0.7"

    def forward(self, input, rois):
        """
        Args:
            input: NCHW images
            rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy.
        """
        assert rois.dim() == 2 and rois.size(1) == 5
        if input.is_quantized:
            # roi_align has no quantized kernel; fall back to float values.
            input = input.dequantize()
        return roi_align(
            input,
            rois.to(dtype=input.dtype),
            self.output_size,
            self.spatial_scale,
            self.sampling_ratio,
            self.aligned,
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"output_size={self.output_size}"
            f", spatial_scale={self.spatial_scale}"
            f", sampling_ratio={self.sampling_ratio}"
            f", aligned={self.aligned})"
        )

View File

@@ -0,0 +1,533 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from __future__ import absolute_import, division, print_function, unicode_literals
import math
from typing import List, Tuple
import torch
# from detectron2.layers.rotated_boxes import pairwise_iou_rotated
from .boxes import Boxes
def pairwise_iou_rotated(boxes1, boxes2):
    """
    Return intersection-over-union (Jaccard index) of boxes.

    Both sets of boxes are expected to be in
    (x_center, y_center, width, height, angle) format.

    Arguments:
        boxes1 (Tensor[N, 5])
        boxes2 (Tensor[M, 5])

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """
    # Dispatches to the compiled custom op; presumably the detectron2
    # C++/CUDA extension is loaded elsewhere so that `torch.ops.detectron2`
    # is registered — TODO confirm in this repo.
    return torch.ops.detectron2.box_iou_rotated(boxes1, boxes2)
class RotatedBoxes(Boxes):
"""
This structure stores a list of rotated boxes as a Nx5 torch.Tensor.
It supports some common methods about boxes
(`area`, `clip`, `nonempty`, etc),
and also behaves like a Tensor
(support indexing, `to(device)`, `.device`, and iteration over all boxes)
"""
def __init__(self, tensor: torch.Tensor):
    """
    Args:
        tensor (Tensor[float]): an Nx5 matrix, each row being
            (x_center, y_center, width, height, angle) with the angle in
            degrees. There is no strict range restriction, but the
            recommended principal range is [-180, 180).

    Angle convention: in image space (y pointing down, x pointing right),
    angle > 0 rotates the corresponding axis-aligned box
    (x_center, y_center, width, height) counter-clockwise about its
    center, and angle < 0 rotates it clockwise. (X, Y, W, H, A) and
    (X, Y, W, H, A + 360*k) cover the same image region, but they are
    distinct boxes (e.g. they produce different RoI pooling results) and
    should not be treated as identical.
    """
    source_device = (
        tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
    )
    tensor = torch.as_tensor(tensor, dtype=torch.float32, device=source_device)
    if tensor.numel() == 0:
        # Use reshape (not a fresh tensor) so the result still depends on
        # the input and does not confuse jit tracing.
        tensor = tensor.reshape((0, 5)).to(dtype=torch.float32, device=source_device)
    assert tensor.dim() == 2 and tensor.size(-1) == 5, tensor.size()
    self.tensor = tensor
def clone(self) -> "RotatedBoxes":
    """
    Clone the RotatedBoxes.

    Returns:
        RotatedBoxes: a new instance backed by a cloned (independent) tensor.
    """
    return RotatedBoxes(self.tensor.clone())
def to(self, device: torch.device, non_blocking: bool = False):
    """Return a new RotatedBoxes with the tensor moved to ``device``."""
    # Boxes are assumed float32 and does not support to(dtype)
    return RotatedBoxes(self.tensor.to(device=device, non_blocking=non_blocking))
def area(self) -> torch.Tensor:
    """
    Computes the area of all the boxes.

    Returns:
        torch.Tensor: a vector with areas of each box.
    """
    widths = self.tensor[:, 2]
    heights = self.tensor[:, 3]
    return widths * heights
# Deliberately not in-place (builds a new tensor) so the method stays
# torchscript-friendly.
def normalize_angles(self) -> None:
    """Wrap all angles into the principal range [-180, 180) degrees."""
    wrapped = (self.tensor[:, 4] + 180.0) % 360.0 - 180.0
    self.tensor = torch.cat((self.tensor[:, :4], wrapped[:, None]), dim=1)
def clip(
    self, box_size: Tuple[int, int], clip_angle_threshold: float = 1.0
) -> None:
    """
    Clip (in place) the boxes by limiting x coordinates to the range [0, width]
    and y coordinates to the range [0, height].

    For RRPN:
    Only clip boxes that are almost horizontal with a tolerance of
    clip_angle_threshold to maintain backward compatibility.

    Rotated boxes beyond this threshold are not clipped for two reasons:

    1. There are potentially multiple ways to clip a rotated box to make it
       fit within the image.
    2. It's tricky to make the entire rectangular box fit within the image
       and still be able to not leave out pixels of interest.

    Therefore we rely on ops like RoIAlignRotated to safely handle this.

    Args:
        box_size (height, width): The clipping box's size.
        clip_angle_threshold:
            Iff. abs(normalized(angle)) <= clip_angle_threshold (in degrees),
            we do the clipping as horizontal boxes.
    """
    h, w = box_size

    # normalize angles to be within [-180, 180) degrees
    self.normalize_angles()

    # Only near-horizontal boxes (|angle| within the threshold) are clipped.
    idx = torch.where(torch.abs(self.tensor[:, 4]) <= clip_angle_threshold)[0]

    # convert to (x1, y1, x2, y2)
    x1 = self.tensor[idx, 0] - self.tensor[idx, 2] / 2.0
    y1 = self.tensor[idx, 1] - self.tensor[idx, 3] / 2.0
    x2 = self.tensor[idx, 0] + self.tensor[idx, 2] / 2.0
    y2 = self.tensor[idx, 1] + self.tensor[idx, 3] / 2.0

    # clip
    x1.clamp_(min=0, max=w)
    y1.clamp_(min=0, max=h)
    x2.clamp_(min=0, max=w)
    y2.clamp_(min=0, max=h)

    # convert back to (xc, yc, w, h)
    self.tensor[idx, 0] = (x1 + x2) / 2.0
    self.tensor[idx, 1] = (y1 + y2) / 2.0
    # make sure widths and heights do not increase due to numerical errors
    self.tensor[idx, 2] = torch.min(self.tensor[idx, 2], x2 - x1)
    self.tensor[idx, 3] = torch.min(self.tensor[idx, 3], y2 - y1)
def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
    """
    Find boxes that are non-empty.
    A box is considered empty, if either of its sides is no larger than
    ``threshold``.

    Returns:
        Tensor: a binary vector which represents
        whether each box is empty (False) or non-empty (True).
    """
    # Columns 2 and 3 hold (width, height); both must exceed the threshold.
    sizes = self.tensor[:, 2:4]
    return (sizes > threshold).all(dim=1)
def __getitem__(self, item) -> "RotatedBoxes":
    """
    Index into the boxes and wrap the result in a new :class:`RotatedBoxes`.

    Supported forms:
        1. `boxes[3]` -> a RotatedBoxes holding a single box;
        2. `boxes[2:10]` -> a slice of boxes;
        3. `boxes[vector]` with a length-N mask/index vector -> the boxes
           at nonzero positions.

    The result may share storage with this RotatedBoxes, subject to
    PyTorch's indexing semantics.
    """
    if isinstance(item, int):
        # Keep a 2-D (1, 5) shape for the single selected box.
        return RotatedBoxes(self.tensor[item].view(1, -1))
    selected = self.tensor[item]
    assert (
        selected.dim() == 2
    ), "Indexing on RotatedBoxes with {} failed to return a matrix!".format(item)
    return RotatedBoxes(selected)
def __len__(self) -> int:
    # Number of boxes (rows of the Nx5 tensor).
    return self.tensor.shape[0]
def __repr__(self) -> str:
    # Full tensor repr; can be long for a large number of boxes.
    return "RotatedBoxes(" + str(self.tensor) + ")"
def inside_box(
    self, box_size: Tuple[int, int], boundary_threshold: int = 0
) -> torch.Tensor:
    """
    Args:
        box_size (height, width): Size of the reference box covering
            [0, width] x [0, height]
        boundary_threshold (int): Boxes that extend beyond the reference box
            boundary by more than boundary_threshold are considered "outside".

    For RRPN, it might not be necessary to call this function since it's common
    for rotated box to extend to outside of the image boundaries
    (the clip function only clips the near-horizontal boxes)

    Returns:
        a binary vector, indicating whether each box is inside the reference box.
    """
    height, width = box_size
    centers_x = self.tensor[..., 0]
    centers_y = self.tensor[..., 1]
    half_widths = self.tensor[..., 2] / 2.0
    half_heights = self.tensor[..., 3] / 2.0
    angles = self.tensor[..., 4]
    cos_abs = torch.abs(torch.cos(angles * math.pi / 180.0))
    sin_abs = torch.abs(torch.sin(angles * math.pi / 180.0))
    # Half-extents of the axis-aligned bounding rectangle of each rotated box.
    extent_x = cos_abs * half_widths + sin_abs * half_heights
    extent_y = cos_abs * half_heights + sin_abs * half_widths
    return (
        (centers_x - extent_x >= -boundary_threshold)
        & (centers_y - extent_y >= -boundary_threshold)
        & (centers_x + extent_x < width + boundary_threshold)
        & (centers_y + extent_y < height + boundary_threshold)
    )
def get_centers(self) -> torch.Tensor:
    """
    Returns:
        The box centers in a Nx2 array of (x, y).
    """
    centers = self.tensor[:, :2]
    return centers
    def scale(self, scale_x: float, scale_y: float) -> None:
        """
        Scale the rotated box with horizontal and vertical scaling factors.

        Note: when scale_factor_x != scale_factor_y,
        the rotated box does not preserve the rectangular shape when the angle
        is not a multiple of 90 degrees under resize transformation.
        Instead, the shape is a parallelogram (that has skew).
        Here we make an approximation by fitting a rotated rectangle to the parallelogram.

        Args:
            scale_x: scaling factor along the x (width) axis.
            scale_y: scaling factor along the y (height) axis.

        Modifies ``self.tensor`` in place and returns None.
        """
        self.tensor[:, 0] *= scale_x
        self.tensor[:, 1] *= scale_y
        # NOTE: theta/c/s must be read *before* the angle column is overwritten
        # at the end of this method.
        theta = self.tensor[:, 4] * math.pi / 180.0
        c = torch.cos(theta)
        s = torch.sin(theta)
        # In image space, y is top->down and x is left->right
        # Consider the local coordintate system for the rotated box,
        # where the box center is located at (0, 0), and the four vertices ABCD are
        # A(-w / 2, -h / 2), B(w / 2, -h / 2), C(w / 2, h / 2), D(-w / 2, h / 2)
        # the midpoint of the left edge AD of the rotated box E is:
        # E = (A+D)/2 = (-w / 2, 0)
        # the midpoint of the top edge AB of the rotated box F is:
        # F(0, -h / 2)
        # To get the old coordinates in the global system, apply the rotation transformation
        # (Note: the right-handed coordinate system for image space is yOx):
        # (old_x, old_y) = (s * y + c * x, c * y - s * x)
        # E(old) = (s * 0 + c * (-w/2), c * 0 - s * (-w/2)) = (-c * w / 2, s * w / 2)
        # F(old) = (s * (-h / 2) + c * 0, c * (-h / 2) - s * 0) = (-s * h / 2, -c * h / 2)
        # After applying the scaling factor (sfx, sfy):
        # E(new) = (-sfx * c * w / 2, sfy * s * w / 2)
        # F(new) = (-sfx * s * h / 2, -sfy * c * h / 2)
        # The new width after scaling tranformation becomes:
        # w(new) = |E(new) - O| * 2
        #        = sqrt[(sfx * c * w / 2)^2 + (sfy * s * w / 2)^2] * 2
        #        = sqrt[(sfx * c)^2 + (sfy * s)^2] * w
        # i.e., scale_factor_w = sqrt[(sfx * c)^2 + (sfy * s)^2]
        #
        # For example,
        # when angle = 0 or 180, |c| = 1, s = 0, scale_factor_w == scale_factor_x;
        # when |angle| = 90, c = 0, |s| = 1, scale_factor_w == scale_factor_y
        self.tensor[:, 2] *= torch.sqrt((scale_x * c) ** 2 + (scale_y * s) ** 2)
        # h(new) = |F(new) - O| * 2
        #        = sqrt[(sfx * s * h / 2)^2 + (sfy * c * h / 2)^2] * 2
        #        = sqrt[(sfx * s)^2 + (sfy * c)^2] * h
        # i.e., scale_factor_h = sqrt[(sfx * s)^2 + (sfy * c)^2]
        #
        # For example,
        # when angle = 0 or 180, |c| = 1, s = 0, scale_factor_h == scale_factor_y;
        # when |angle| = 90, c = 0, |s| = 1, scale_factor_h == scale_factor_x
        self.tensor[:, 3] *= torch.sqrt((scale_x * s) ** 2 + (scale_y * c) ** 2)
        # The angle is the rotation angle from y-axis in image space to the height
        # vector (top->down in the box's local coordinate system) of the box in CCW.
        #
        # angle(new) = angle_yOx(O - F(new))
        #            = angle_yOx( (sfx * s * h / 2, sfy * c * h / 2) )
        #            = atan2(sfx * s * h / 2, sfy * c * h / 2)
        #            = atan2(sfx * s, sfy * c)
        #
        # For example,
        # when sfx == sfy, angle(new) == atan2(s, c) == angle(old)
        self.tensor[:, 4] = torch.atan2(scale_x * s, scale_y * c) * 180 / math.pi
@classmethod
def cat(cls, boxes_list: List["RotatedBoxes"]) -> "RotatedBoxes":
"""
Concatenates a list of RotatedBoxes into a single RotatedBoxes
Arguments:
boxes_list (list[RotatedBoxes])
Returns:
RotatedBoxes: the concatenated RotatedBoxes
"""
assert isinstance(boxes_list, (list, tuple))
if len(boxes_list) == 0:
return cls(torch.empty(0))
assert all([isinstance(box, RotatedBoxes) for box in boxes_list])
# use torch.cat (v.s. layers.cat) so the returned boxes never share storage with input
cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0))
return cat_boxes
@property
def device(self) -> torch.device:
return self.tensor.device
@torch.jit.unused
def __iter__(self):
"""
Yield a box as a Tensor of shape (5,) at a time.
"""
yield from self.tensor
def pairwise_iou(boxes1: RotatedBoxes, boxes2: RotatedBoxes) -> torch.Tensor:
    """
    Given two lists of rotated boxes of size N and M,
    compute the IoU (intersection over union)
    between **all** N x M pairs of boxes.
    The box order must be (x_center, y_center, width, height, angle).

    Args:
        boxes1, boxes2 (RotatedBoxes):
            two `RotatedBoxes`. Contains N & M rotated boxes, respectively.

    Returns:
        Tensor: IoU, sized [N,M].
    """
    # Fix: the return annotation previously said `None`, contradicting the
    # docstring and the actual returned Tensor.
    return pairwise_iou_rotated(boxes1.tensor, boxes2.tensor)

View File

@@ -0,0 +1,406 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import colorsys
from dataclasses import dataclass
from typing import List, Tuple
import cv2
import matplotlib as mpl
import matplotlib.colors as mplc
import numpy as np
import pycocotools.mask as mask_utils
def rgb_to_hex(rgb_color):
    """
    Convert an RGB color to a hex color string.

    Args:
        rgb_color (tuple/list of ints): RGB color with each channel in [0, 255].

    Returns:
        str: Hex color string, lowercase, prefixed with "#".

    Example:
        ```
        >>> rgb_to_hex((255, 0, 244))
        '#ff00f4'
        ```
    """
    # Fix: the original doctest claimed (255, 0, 244) -> '#ff00ff'; the correct
    # value is '#ff00f4'. f"{c:02x}" zero-pads each channel to two hex digits.
    return "#" + "".join(f"{c:02x}" for c in rgb_color)
# The 20 approved high-visibility overlay colors, keyed by their lowercase hex
# string (produced by `rgb_to_hex` above) and mapped to a short human-readable
# name. (An earlier 9-color palette was removed; see version control history.)
DEFAULT_COLOR_HEX_TO_NAME = {
    # The top 20 approved colors
    rgb_to_hex((255, 255, 0)): "yellow",
    rgb_to_hex((0, 255, 0)): "lime",
    rgb_to_hex((0, 255, 255)): "cyan",
    rgb_to_hex((255, 0, 255)): "magenta",
    rgb_to_hex((255, 0, 0)): "red",
    rgb_to_hex((255, 127, 0)): "orange",
    rgb_to_hex((127, 255, 0)): "chartreuse",
    rgb_to_hex((0, 255, 127)): "spring green",
    rgb_to_hex((255, 0, 127)): "rose",
    rgb_to_hex((127, 0, 255)): "violet",
    rgb_to_hex((192, 255, 0)): "electric lime",
    rgb_to_hex((255, 192, 0)): "vivid orange",
    rgb_to_hex((0, 255, 192)): "turquoise",
    rgb_to_hex((192, 0, 255)): "bright violet",
    rgb_to_hex((255, 0, 192)): "bright pink",
    rgb_to_hex((255, 64, 0)): "fiery orange",
    rgb_to_hex((64, 255, 0)): "bright chartreuse",
    rgb_to_hex((0, 255, 64)): "malachite",
    rgb_to_hex((64, 0, 255)): "deep violet",
    rgb_to_hex((255, 0, 64)): "hot pink",
}
# Ordered list of the approved hex strings, used as the default palette order.
DEFAULT_COLOR_PALETTE = list(DEFAULT_COLOR_HEX_TO_NAME.keys())
def _validate_color_hex(color_hex: str):
    """Raise ValueError unless *color_hex* is a 3- or 6-digit hex color (optional '#')."""
    digits = color_hex.lstrip("#")
    # Character check runs first to preserve the original error precedence.
    if any(ch not in "0123456789abcdefABCDEF" for ch in digits):
        raise ValueError("Invalid characters in color hash")
    if len(digits) not in (3, 6):
        raise ValueError("Invalid length of color hash")
# copied from https://github.com/roboflow/supervision/blob/c8f557af0c61b5c03392bad2cc36c8835598b1e1/supervision/draw/color.py
@dataclass
class Color:
    """
    An RGB color with integer channels in [0, 255].

    Attributes:
        r (int): Red channel.
        g (int): Green channel.
        b (int): Blue channel.
    """

    r: int
    g: int
    b: int

    @classmethod
    def from_hex(cls, color_hex: str):
        """
        Build a Color from a hex string such as "#ff00ff" or "#f0f".

        Args:
            color_hex (str): Hex string of the color.

        Returns:
            Color: Instance representing the color.

        Example:
            ```
            >>> Color.from_hex('#ff00ff')
            Color(r=255, g=0, b=255)
            ```
        """
        _validate_color_hex(color_hex)
        digits = color_hex.lstrip("#")
        if len(digits) == 3:
            # Expand shorthand notation: "f0a" -> "ff00aa".
            digits = "".join(ch * 2 for ch in digits)
        red, green, blue = (int(digits[i : i + 2], 16) for i in range(0, 6, 2))
        return cls(red, green, blue)

    @classmethod
    def to_hex(cls, color):
        """
        Convert a Color instance to its hex string.

        Args:
            color (Color): Color instance to convert.

        Returns:
            str: hex string such as "#ff00ff".
        """
        return rgb_to_hex((color.r, color.g, color.b))

    def as_rgb(self) -> Tuple[int, int, int]:
        """Return the color as an (r, g, b) tuple."""
        return self.r, self.g, self.b

    def as_bgr(self) -> Tuple[int, int, int]:
        """Return the color as a (b, g, r) tuple."""
        return self.b, self.g, self.r

    @classmethod
    def white(cls):
        return Color.from_hex(color_hex="#ffffff")

    @classmethod
    def black(cls):
        return Color.from_hex(color_hex="#000000")

    @classmethod
    def red(cls):
        return Color.from_hex(color_hex="#ff0000")

    @classmethod
    def green(cls):
        return Color.from_hex(color_hex="#00ff00")

    @classmethod
    def blue(cls):
        return Color.from_hex(color_hex="#0000ff")
@dataclass
class ColorPalette:
    """An ordered list of colors; index lookups wrap around modulo its length."""

    colors: List[Color]

    @classmethod
    def default(cls):
        """
        Build the default palette from the module-level approved colors.

        Returns:
            ColorPalette: palette backed by ``DEFAULT_COLOR_PALETTE``.

        Example:
            ```
            >>> ColorPalette.default()
            ColorPalette(colors=[Color(r=255, g=255, b=0), ...])
            ```
        """
        return ColorPalette.from_hex(color_hex_list=DEFAULT_COLOR_PALETTE)

    @classmethod
    def from_hex(cls, color_hex_list: List[str]):
        """
        Create a ColorPalette from a list of hex strings.

        Args:
            color_hex_list (List[str]): List of color hex strings.

        Returns:
            ColorPalette: palette of the parsed colors.
        """
        return cls([Color.from_hex(color_hex) for color_hex in color_hex_list])

    def by_idx(self, idx: int) -> Color:
        """
        Return the palette color at *idx*, wrapping modulo the palette size.

        Args:
            idx (int): non-negative index.

        Returns:
            Color: color at the (wrapped) index.

        Raises:
            ValueError: if idx is negative.
        """
        if idx < 0:
            raise ValueError("idx argument should not be negative")
        return self.colors[idx % len(self.colors)]

    def find_farthest_color(self, img_array):
        """
        Pick the palette color farthest (on average) from the image's pixels.

        Args:
            img_array (np.ndarray): any array reshapeable to (-1, 3) RGB values.

        Returns:
            tuple: (Color, str) — the farthest color and its name, or "unknown"
            when its hex value is absent from DEFAULT_COLOR_HEX_TO_NAME.
        """
        pixels = img_array.reshape((-1, 3))
        palette_rgb = np.array([[c.r, c.g, c.b] for c in self.colors])
        # Broadcast (num_pixels, 1, 3) against (num_colors, 3) to obtain
        # pairwise Euclidean distances of shape (num_pixels, num_colors).
        distances = np.sqrt(
            np.sum((pixels[:, np.newaxis, :] - palette_rgb) ** 2, axis=2)
        )
        mean_distances = np.mean(distances, axis=0)
        farthest_idx = np.argmax(mean_distances)
        farthest_color = self.colors[farthest_idx]
        farthest_hex = Color.to_hex(farthest_color)
        farthest_name = DEFAULT_COLOR_HEX_TO_NAME.get(farthest_hex, "unknown")
        return farthest_color, farthest_name
def draw_box(ax, box_coord, alpha=0.8, edge_color="g", line_style="-", linewidth=2.0):
    """Draw an unfilled rectangle for an (x0, y0, w, h) box on a matplotlib axes."""
    x0, y0, box_w, box_h = box_coord
    rect = mpl.patches.Rectangle(
        (x0, y0),
        box_w,
        box_h,
        fill=False,
        edgecolor=edge_color,
        linewidth=linewidth,
        alpha=alpha,
        linestyle=line_style,
    )
    ax.add_patch(rect)
def draw_text(
    ax,
    text,
    position,
    font_size=None,
    color="g",
    horizontal_alignment="left",
    rotation=0,
):
    """Render a text label at (x, y) on a matplotlib axes, brightened for legibility."""
    if not font_size:
        font_size = mpl.rcParams["font.size"]
    # Keep the label readable: clamp every channel to at least 0.2, then push
    # the dominant channel up to at least 0.8.
    rgb = np.maximum(list(mplc.to_rgb(color)), 0.2)
    rgb[np.argmax(rgb)] = max(0.8, np.max(rgb))
    x, y = position
    ax.text(
        x,
        y,
        text,
        size=font_size,
        family="sans-serif",
        bbox={"facecolor": "none", "alpha": 0.5, "pad": 0.7, "edgecolor": "none"},
        verticalalignment="top",
        horizontalalignment=horizontal_alignment,
        color=rgb,
        rotation=rotation,
    )
def draw_mask(
    ax, rle, color, show_holes=True, alpha=0.15, upsample_factor=1.0, rle_upsampled=None
):
    """
    Overlay a segmentation mask on a matplotlib axes.

    Args:
        ax: matplotlib axes to draw on.
        rle: COCO RLE dict or a binary numpy mask.
        color: RGB color as floats in [0, 1] (numpy array) for the overlay.
        show_holes: if True, fill via an RGBA image so holes inside the mask
            stay transparent; otherwise fill the contour polygons directly.
        alpha: opacity of the mask fill.
        upsample_factor: scale applied to contour coordinates; when > 1 and
            show_holes is True, ``rle_upsampled`` must provide the high-res mask.
        rle_upsampled: optional upsampled mask (COCO RLE dict or numpy array).

    Raises:
        ValueError: if ``rle`` or ``rle_upsampled`` is neither a dict nor a
            numpy array.
    """
    if isinstance(rle, dict):
        mask = mask_utils.decode(rle)
    elif isinstance(rle, np.ndarray):
        mask = rle
    else:
        raise ValueError(f"Unsupported type for rle: {type(rle)}")
    mask_upsampled = None
    if upsample_factor > 1.0 and show_holes:
        assert rle_upsampled is not None
        if isinstance(rle_upsampled, dict):
            mask_upsampled = mask_utils.decode(rle_upsampled)
        elif isinstance(rle_upsampled, np.ndarray):
            mask_upsampled = rle_upsampled
        else:
            # Fix: the original message blamed `rle` (and printed its type)
            # when the invalid argument was actually `rle_upsampled`.
            raise ValueError(
                f"Unsupported type for rle_upsampled: {type(rle_upsampled)}"
            )
    if show_holes:
        if mask_upsampled is None:
            mask_upsampled = mask
        h, w = mask_upsampled.shape
        # RGBA image whose alpha channel is the mask itself, so interior holes
        # (alpha 0) remain transparent.
        mask_img = np.zeros((h, w, 4))
        mask_img[:, :, :-1] = color[np.newaxis, np.newaxis, :]
        mask_img[:, :, -1] = mask_upsampled * alpha
        ax.imshow(mask_img)
    # The starred unpacking tolerates both 2-element and 3-element return
    # tuples from cv2.findContours (the signature changed across OpenCV majors).
    *_, contours, _ = cv2.findContours(
        mask.astype(np.uint8).copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
    )
    upsampled_contours = [(cont + 0.5) * upsample_factor - 0.5 for cont in contours]
    facecolor = (0, 0, 0, 0) if show_holes else color
    if alpha > 0.8:
        # Darken the outline for very opaque fills so the edge stays visible.
        edge_color = _change_color_brightness(color, brightness_factor=-0.7)
    else:
        edge_color = color
    for cont in upsampled_contours:
        polygon = mpl.patches.Polygon(
            [el[0] for el in cont],
            edgecolor=edge_color,
            linewidth=2.0,
            facecolor=facecolor,
        )
        ax.add_patch(polygon)
def _change_color_brightness(color, brightness_factor):
    """
    Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
    less or more saturation than the original color.

    Args:
        color: color of the polygon. Refer to `matplotlib.colors` for a full list of
            formats that are accepted.
        brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
            0 will correspond to no change, a factor in [-1.0, 0) range will result in
            a darker color and a factor in (0, 1.0] range will result in a lighter color.

    Returns:
        modified_color (tuple[double]): a tuple containing the RGB values of the
        modified color. Each value in the tuple is in the [0.0, 1.0] range.
    """
    assert -1.0 <= brightness_factor <= 1.0
    # `to_rgb` already yields an (r, g, b) float tuple; the original converted
    # the color twice for no benefit.
    rgb = mplc.to_rgb(color)
    hue, lightness, saturation = colorsys.rgb_to_hls(*rgb)
    # Scale lightness multiplicatively, then clamp to the valid [0, 1] range;
    # hue and saturation are preserved.
    lightness = lightness + brightness_factor * lightness
    lightness = min(1.0, max(0.0, lightness))
    return colorsys.hls_to_rgb(hue, lightness, saturation)

1662
sam3/agent/helpers/visualizer.py Executable file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,195 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import io
import math
import matplotlib.pyplot as plt
import numpy as np
import pycocotools.mask as mask_utils
from PIL import Image
from .som_utils import ColorPalette, draw_box, draw_mask, draw_text
def render_zoom_in(
    object_data,
    image_file,
    show_box: bool = True,
    show_text: bool = False,
    show_holes: bool = True,
    mask_alpha: float = 0.15,
):
    """
    Render a two-panel visualization with a cropped original view (left/upper) and a zoomed-in
    mask overlay (right/lower), then return it as a PIL.Image along with the chosen mask color (hex).

    Parameters
    ----------
    object_data : dict
        Dict containing "labels" and COCO RLE "segmentation".
        Expected:
        object_data["labels"][0]["noun_phrase"] : str
        object_data["segmentation"] : COCO RLE (with "size": [H, W])
    image_file : PIL.Image.Image
        Source image (PIL).
    show_box : bool
        Whether to draw the bbox on the cropped original panel.
    show_text : bool
        Whether to draw the noun phrase label near the bbox.
    show_holes : bool
        Whether to render mask holes (passed through to draw_mask).
    mask_alpha : float
        Alpha for the mask overlay.

    Returns
    -------
    pil_img : PIL.Image.Image
        The composed visualization image.
    color_hex : str
        Hex string of the chosen mask color.
    """
    # ---- local constants (avoid module-level globals) ----
    # Mask-area / crop-area ratios above which the crop window is enlarged
    # (see _get_zoom_in_box below).
    _AREA_LARGE = 0.25
    _AREA_MEDIUM = 0.05

    # ---- local helpers (avoid name collisions in a larger class) ----
    def _get_shift(x, w, w_new, w_img):
        # Offset that centers a window of size w_new on the span [x, x + w],
        # clamped so the window stays within [0, w_img].
        assert 0 <= w_new <= w_img
        shift = (w_new - w) / 2
        if x - shift + w_new > w_img:
            shift = x + w_new - w_img
        return min(x, shift)

    def _get_zoom_in_box(mask_box_xywh, img_h, img_w, mask_area):
        # Compute two xywh windows around the mask's bounding box:
        # - zoom_in_box: tight window for the mask-overlay panel
        #   (grown when the mask fills more than _AREA_LARGE of it)
        # - img_crop_box: wider context window for the original-image panel
        #   (grown when the mask fills more than _AREA_MEDIUM of it)
        box_w, box_h = mask_box_xywh[2], mask_box_xywh[3]
        # Pad the bbox by 20% (at least 16 px) on each dimension, capped at
        # the image size.
        w_new = min(box_w + max(0.2 * box_w, 16), img_w)
        h_new = min(box_h + max(0.2 * box_h, 16), img_h)
        mask_relative_area = mask_area / (w_new * h_new)
        # zoom-in (larger box if mask is relatively big)
        w_new_large, h_new_large = w_new, h_new
        if mask_relative_area > _AREA_LARGE:
            ratio_large = math.sqrt(mask_relative_area / _AREA_LARGE)
            w_new_large = min(w_new * ratio_large, img_w)
            h_new_large = min(h_new * ratio_large, img_h)
        w_shift_large = _get_shift(
            mask_box_xywh[0], mask_box_xywh[2], w_new_large, img_w
        )
        h_shift_large = _get_shift(
            mask_box_xywh[1], mask_box_xywh[3], h_new_large, img_h
        )
        zoom_in_box = [
            mask_box_xywh[0] - w_shift_large,
            mask_box_xywh[1] - h_shift_large,
            w_new_large,
            h_new_large,
        ]
        # crop box for the original/cropped image
        w_new_medium, h_new_medium = w_new, h_new
        if mask_relative_area > _AREA_MEDIUM:
            ratio_med = math.sqrt(mask_relative_area / _AREA_MEDIUM)
            w_new_medium = min(w_new * ratio_med, img_w)
            h_new_medium = min(h_new * ratio_med, img_h)
        w_shift_medium = _get_shift(
            mask_box_xywh[0], mask_box_xywh[2], w_new_medium, img_w
        )
        h_shift_medium = _get_shift(
            mask_box_xywh[1], mask_box_xywh[3], h_new_medium, img_h
        )
        img_crop_box = [
            mask_box_xywh[0] - w_shift_medium,
            mask_box_xywh[1] - h_shift_medium,
            w_new_medium,
            h_new_medium,
        ]
        return zoom_in_box, img_crop_box

    # ---- main body ----
    # Input parsing
    object_label = object_data["labels"][0]["noun_phrase"]
    img = image_file.convert("RGB")
    bbox_xywh = mask_utils.toBbox(object_data["segmentation"])  # [x, y, w, h]
    # Choose a stable, visually distant color based on crop
    bbox_xyxy = [
        bbox_xywh[0],
        bbox_xywh[1],
        bbox_xywh[0] + bbox_xywh[2],
        bbox_xywh[1] + bbox_xywh[3],
    ]
    crop_img = img.crop(bbox_xyxy)
    color_palette = ColorPalette.default()
    color_obj, _ = color_palette.find_farthest_color(np.array(crop_img))
    # `color` as floats in [0, 1] for matplotlib; `color_hex` for the caller.
    color = np.array([color_obj.r / 255, color_obj.g / 255, color_obj.b / 255])
    color_hex = f"#{color_obj.r:02x}{color_obj.g:02x}{color_obj.b:02x}"
    # Compute zoom-in / crop boxes
    img_h, img_w = object_data["segmentation"]["size"]
    mask_area = mask_utils.area(object_data["segmentation"])
    zoom_in_box, img_crop_box = _get_zoom_in_box(bbox_xywh, img_h, img_w, mask_area)
    # Layout choice: tall crops go side-by-side, wide crops stack vertically.
    w, h = img_crop_box[2], img_crop_box[3]
    if w < h:
        fig, (ax1, ax2) = plt.subplots(1, 2)
    else:
        fig, (ax1, ax2) = plt.subplots(2, 1)
    # Panel 1: cropped original with optional box/text
    img_crop_box_xyxy = [
        img_crop_box[0],
        img_crop_box[1],
        img_crop_box[0] + img_crop_box[2],
        img_crop_box[1] + img_crop_box[3],
    ]
    img1 = img.crop(img_crop_box_xyxy)
    # Box coordinates relative to the crop origin.
    bbox_xywh_rel = [
        bbox_xywh[0] - img_crop_box[0],
        bbox_xywh[1] - img_crop_box[1],
        bbox_xywh[2],
        bbox_xywh[3],
    ]
    ax1.imshow(img1)
    ax1.axis("off")
    if show_box:
        draw_box(ax1, bbox_xywh_rel, edge_color=color)
    if show_text:
        x0, y0 = bbox_xywh_rel[0] + 2, bbox_xywh_rel[1] + 2
        draw_text(ax1, object_label, [x0, y0], color=color)
    # Panel 2: zoomed-in mask overlay. The mask rides along with the crop in
    # the image's alpha channel, so cropping keeps mask and pixels aligned.
    binary_mask = mask_utils.decode(object_data["segmentation"])
    alpha = Image.fromarray((binary_mask * 255).astype("uint8"))
    img_rgba = img.convert("RGBA")
    img_rgba.putalpha(alpha)
    zoom_in_box_xyxy = [
        zoom_in_box[0],
        zoom_in_box[1],
        zoom_in_box[0] + zoom_in_box[2],
        zoom_in_box[1] + zoom_in_box[3],
    ]
    img_with_alpha_zoomin = img_rgba.crop(zoom_in_box_xyxy)
    # Recover the cropped mask from the alpha channel.
    alpha_zoomin = img_with_alpha_zoomin.split()[3]
    binary_mask_zoomin = np.array(alpha_zoomin).astype(bool)
    ax2.imshow(img_with_alpha_zoomin.convert("RGB"))
    ax2.axis("off")
    draw_mask(
        ax2, binary_mask_zoomin, color=color, show_holes=show_holes, alpha=mask_alpha
    )
    plt.tight_layout()
    # Buffer -> PIL.Image
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=100)
    plt.close(fig)
    buf.seek(0)
    # NOTE(review): PIL loads lazily from the buffer; `buf` is intentionally
    # left open so the returned image stays readable.
    pil_img = Image.open(buf)
    return pil_img, color_hex