Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
1
sam3/agent/helpers/__init__.py
Executable file
1
sam3/agent/helpers/__init__.py
Executable file
@@ -0,0 +1 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
438
sam3/agent/helpers/boxes.py
Executable file
438
sam3/agent/helpers/boxes.py
Executable file
@@ -0,0 +1,438 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import math
|
||||
from enum import IntEnum, unique
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import device
|
||||
|
||||
_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]
|
||||
|
||||
|
||||
@unique
class BoxMode(IntEnum):
    """
    Enum of different ways to represent a box.
    """

    XYXY_ABS = 0
    """
    (x0, y0, x1, y1) in absolute floating points coordinates.
    The coordinates in range [0, width or height].
    """
    XYWH_ABS = 1
    """
    (x0, y0, w, h) in absolute floating points coordinates.
    """
    XYXY_REL = 2
    """
    Not yet supported!
    (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
    """
    XYWH_REL = 3
    """
    Not yet supported!
    (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
    """
    XYWHA_ABS = 4
    """
    (xc, yc, w, h, a) in absolute floating points coordinates.
    (xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
    """

    @staticmethod
    def convert(
        box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode"
    ) -> _RawBoxType:
        """
        Convert a box (or batch of boxes) from one :class:`BoxMode` to another.

        Args:
            box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
            from_mode, to_mode (BoxMode)

        Returns:
            The converted box of the same type.
        """
        if from_mode == to_mode:
            return box

        # Remember the input container type so the same type can be returned.
        original_type = type(box)
        is_numpy = isinstance(box, np.ndarray)
        single_box = isinstance(box, (list, tuple))
        if single_box:
            assert len(box) == 4 or len(box) == 5, (
                "BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
                " where k == 4 or 5"
            )
            # Promote the single box to a 1xk batch for uniform processing below.
            arr = torch.tensor(box)[None, :]
        else:
            # avoid modifying the input box
            if is_numpy:
                arr = torch.from_numpy(np.asarray(box)).clone()
            else:
                arr = box.clone()

        assert to_mode not in [
            BoxMode.XYXY_REL,
            BoxMode.XYWH_REL,
        ] and from_mode not in [
            BoxMode.XYXY_REL,
            BoxMode.XYWH_REL,
        ], "Relative mode not yet supported!"

        if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
            assert (
                arr.shape[-1] == 5
            ), "The last dimension of input shape must be 5 for XYWHA format"
            # Compute in double precision to limit rounding error, then cast back.
            original_dtype = arr.dtype
            arr = arr.double()

            w = arr[:, 2]
            h = arr[:, 3]
            a = arr[:, 4]
            c = torch.abs(torch.cos(a * math.pi / 180.0))
            s = torch.abs(torch.sin(a * math.pi / 180.0))
            # This basically computes the horizontal bounding rectangle of the rotated box
            new_w = c * w + s * h
            new_h = c * h + s * w

            # convert center to top-left corner
            arr[:, 0] -= new_w / 2.0
            arr[:, 1] -= new_h / 2.0
            # bottom-right corner
            arr[:, 2] = arr[:, 0] + new_w
            arr[:, 3] = arr[:, 1] + new_h

            # Drop the angle column; the result is an axis-aligned XYXY box.
            arr = arr[:, :4].to(dtype=original_dtype)
        elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
            original_dtype = arr.dtype
            arr = arr.double()
            # Shift top-left corner to the box center; angle is 0 for upright boxes.
            arr[:, 0] += arr[:, 2] / 2.0
            arr[:, 1] += arr[:, 3] / 2.0
            angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype)
            arr = torch.cat((arr, angles), axis=1).to(dtype=original_dtype)
        else:
            if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS:
                # (w, h) -> (x1, y1) by adding the top-left corner.
                arr[:, 2] += arr[:, 0]
                arr[:, 3] += arr[:, 1]
            elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
                # (x1, y1) -> (w, h) by subtracting the top-left corner.
                arr[:, 2] -= arr[:, 0]
                arr[:, 3] -= arr[:, 1]
            else:
                raise NotImplementedError(
                    "Conversion from BoxMode {} to {} is not supported yet".format(
                        from_mode, to_mode
                    )
                )

        if single_box:
            # Rebuild the caller's original container type (list or tuple).
            return original_type(arr.flatten().tolist())
        if is_numpy:
            return arr.numpy()
        else:
            return arr
|
||||
|
||||
|
||||
class Boxes:
    """
    This structure stores a list of boxes as a Nx4 torch.Tensor.
    It supports some common methods about boxes
    (`area`, `clip`, `nonempty`, etc),
    and also behaves like a Tensor
    (support indexing, `to(device)`, `.device`, and iteration over all boxes)

    Attributes:
        tensor (torch.Tensor): float matrix of Nx4. Each row is (x1, y1, x2, y2).
    """

    def __init__(self, tensor: torch.Tensor):
        """
        Args:
            tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2).
        """
        if not isinstance(tensor, torch.Tensor):
            tensor = torch.as_tensor(
                tensor, dtype=torch.float32, device=torch.device("cpu")
            )
        else:
            tensor = tensor.to(torch.float32)
        if tensor.numel() == 0:
            # Use reshape, so we don't end up creating a new tensor that does not depend on
            # the inputs (and consequently confuses jit)
            tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32)
        assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()

        self.tensor = tensor

    def clone(self) -> "Boxes":
        """
        Clone the Boxes.

        Returns:
            Boxes
        """
        return Boxes(self.tensor.clone())

    def to(self, device: torch.device) -> "Boxes":
        # Boxes are assumed float32 and does not support to(dtype)
        return Boxes(self.tensor.to(device=device))

    def area(self) -> torch.Tensor:
        """
        Computes the area of all the boxes.

        Returns:
            torch.Tensor: a vector with areas of each box.
        """
        box = self.tensor
        # (x2 - x1) * (y2 - y1); degenerate boxes yield non-positive values.
        area = (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
        return area

    def clip(self, box_size: Tuple[int, int]) -> None:
        """
        Clip (in place) the boxes by limiting x coordinates to the range [0, width]
        and y coordinates to the range [0, height].

        Args:
            box_size (height, width): The clipping box's size.
        """
        assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
        h, w = box_size
        # Clamp each coordinate independently, then reassemble the tensor.
        x1 = self.tensor[:, 0].clamp(min=0, max=w)
        y1 = self.tensor[:, 1].clamp(min=0, max=h)
        x2 = self.tensor[:, 2].clamp(min=0, max=w)
        y2 = self.tensor[:, 3].clamp(min=0, max=h)
        self.tensor = torch.stack((x1, y1, x2, y2), dim=-1)

    def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
        """
        Find boxes that are non-empty.
        A box is considered empty, if either of its side is no larger than threshold.

        Returns:
            Tensor:
                a binary vector which represents whether each box is empty
                (False) or non-empty (True).
        """
        box = self.tensor
        widths = box[:, 2] - box[:, 0]
        heights = box[:, 3] - box[:, 1]
        # Both sides must strictly exceed the threshold for a box to count.
        keep = (widths > threshold) & (heights > threshold)
        return keep

    def __getitem__(self, item) -> "Boxes":
        """
        Args:
            item: int, slice, or a BoolTensor

        Returns:
            Boxes: Create a new :class:`Boxes` by indexing.

        The following usage are allowed:

        1. `new_boxes = boxes[3]`: return a `Boxes` which contains only one box.
        2. `new_boxes = boxes[2:10]`: return a slice of boxes.
        3. `new_boxes = boxes[vector]`, where vector is a torch.BoolTensor
           with `length = len(boxes)`. Nonzero elements in the vector will be selected.

        Note that the returned Boxes might share storage with this Boxes,
        subject to Pytorch's indexing semantics.
        """
        if isinstance(item, int):
            # Keep a leading dimension so the result is still an Nx4 matrix (N=1).
            return Boxes(self.tensor[item].view(1, -1))
        b = self.tensor[item]
        assert (
            b.dim() == 2
        ), "Indexing on Boxes with {} failed to return a matrix!".format(item)
        return Boxes(b)

    def __len__(self) -> int:
        # Number of boxes stored (N).
        return self.tensor.shape[0]

    def __repr__(self) -> str:
        return "Boxes(" + str(self.tensor) + ")"

    def inside_box(
        self, box_size: Tuple[int, int], boundary_threshold: int = 0
    ) -> torch.Tensor:
        """
        Args:
            box_size (height, width): Size of the reference box.
            boundary_threshold (int): Boxes that extend beyond the reference box
                boundary by more than boundary_threshold are considered "outside".

        Returns:
            a binary vector, indicating whether each box is inside the reference box.
        """
        height, width = box_size
        # All four edges must lie within the (threshold-expanded) reference box.
        inds_inside = (
            (self.tensor[..., 0] >= -boundary_threshold)
            & (self.tensor[..., 1] >= -boundary_threshold)
            & (self.tensor[..., 2] < width + boundary_threshold)
            & (self.tensor[..., 3] < height + boundary_threshold)
        )
        return inds_inside

    def get_centers(self) -> torch.Tensor:
        """
        Returns:
            The box centers in a Nx2 array of (x, y).
        """
        return (self.tensor[:, :2] + self.tensor[:, 2:]) / 2

    def scale(self, scale_x: float, scale_y: float) -> None:
        """
        Scale the box with horizontal and vertical scaling factors
        """
        # Columns 0 and 2 are x-coordinates; columns 1 and 3 are y-coordinates.
        self.tensor[:, 0::2] *= scale_x
        self.tensor[:, 1::2] *= scale_y

    @classmethod
    def cat(cls, boxes_list: List["Boxes"]) -> "Boxes":
        """
        Concatenates a list of Boxes into a single Boxes

        Arguments:
            boxes_list (list[Boxes])

        Returns:
            Boxes: the concatenated Boxes
        """
        assert isinstance(boxes_list, (list, tuple))
        if len(boxes_list) == 0:
            # An empty input yields an empty Boxes (0x4 after __init__ reshaping).
            return cls(torch.empty(0))
        assert all([isinstance(box, Boxes) for box in boxes_list])

        # use torch.cat (v.s. layers.cat) so the returned boxes never share storage with input
        cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0))
        return cat_boxes

    @property
    def device(self) -> device:
        return self.tensor.device

    # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
    # https://github.com/pytorch/pytorch/issues/18627
    @torch.jit.unused
    def __iter__(self):
        """
        Yield a box as a Tensor of shape (4,) at a time.
        """
        yield from self.tensor
|
||||
|
||||
|
||||
def pairwise_intersection(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
    """
    Given two lists of boxes of size N and M,
    compute the intersection area between __all__ N x M pairs of boxes.
    The box order must be (xmin, ymin, xmax, ymax)

    Args:
        boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.

    Returns:
        Tensor: intersection, sized [N,M].
    """
    t1, t2 = boxes1.tensor, boxes2.tensor
    # Per-axis overlap extent: min of the max-corners minus max of the min-corners,
    # broadcast to all N x M pairs; negative extents mean no overlap.
    upper_left = torch.max(t1[:, None, :2], t2[:, :2])  # [N,M,2]
    lower_right = torch.min(t1[:, None, 2:], t2[:, 2:])  # [N,M,2]
    extent = (lower_right - upper_left).clamp(min=0)  # [N,M,2]
    return extent[..., 0] * extent[..., 1]  # [N,M]
|
||||
|
||||
|
||||
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
    """
    Given two lists of boxes of size N and M, compute the IoU
    (intersection over union) between **all** N x M pairs of boxes.
    The box order must be (xmin, ymin, xmax, ymax).

    Args:
        boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.

    Returns:
        Tensor: IoU, sized [N,M].
    """
    areas_n = boxes1.area()  # [N]
    areas_m = boxes2.area()  # [M]
    inter = pairwise_intersection(boxes1, boxes2)  # [N,M]

    # Pairs with zero intersection — including degenerate/empty boxes — map
    # straight to 0, which avoids a 0/0 for two empty boxes.
    union = areas_n[:, None] + areas_m - inter
    zero = torch.zeros(1, dtype=inter.dtype, device=inter.device)
    return torch.where(inter > 0, inter / union, zero)
|
||||
|
||||
|
||||
def pairwise_ioa(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
    """
    Similar to :func:`pairwise_iou` but compute the IoA (intersection over boxes2 area).

    Args:
        boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.

    Returns:
        Tensor: IoA, sized [N,M].
    """
    denom = boxes2.area()  # [M]
    inter = pairwise_intersection(boxes1, boxes2)  # [N,M]

    # Zero-intersection pairs map straight to 0, sidestepping a division by an
    # empty boxes2 area.
    zero = torch.zeros(1, dtype=inter.dtype, device=inter.device)
    return torch.where(inter > 0, inter / denom, zero)
|
||||
|
||||
|
||||
def pairwise_point_box_distance(points: torch.Tensor, boxes: Boxes):
    """
    Pairwise distance between N points and M boxes. The distance between a
    point and a box is represented by the distance from the point to 4 edges
    of the box. Distances are all positive when the point is inside the box.

    Args:
        points: Nx2 coordinates. Each row is (x, y)
        boxes: M boxes

    Returns:
        Tensor: distances of size (N, M, 4). The 4 values are distances from
        the point to the left, top, right, bottom of the box.
    """
    # Shape points to (N, 1) per coordinate and box edges to (1, M) per edge,
    # so subtraction broadcasts to all N x M pairs.
    px, py = points.unsqueeze(dim=2).unbind(dim=1)  # each (N, 1)
    left, top, right, bottom = boxes.tensor.unsqueeze(dim=0).unbind(dim=2)  # each (1, M)
    return torch.stack([px - left, py - top, right - px, bottom - py], dim=2)
|
||||
|
||||
|
||||
def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
    """
    Compute pairwise intersection over union (IOU) of two sets of matched
    boxes that have the same number of boxes.
    Similar to :func:`pairwise_iou`, but computes only diagonal elements of the matrix.

    Args:
        boxes1 (Boxes): bounding boxes, sized [N,4].
        boxes2 (Boxes): same length as boxes1
    Returns:
        Tensor: iou, sized [N].
    Raises:
        AssertionError: if the two box lists have different lengths.
    """
    # Fixed: the original message was built from two adjacent string literals
    # with no separating space, rendering as "the samenumber of entries".
    assert len(boxes1) == len(boxes2), (
        "boxlists should have the same number of entries, got {}, {}".format(
            len(boxes1), len(boxes2)
        )
    )
    area1 = boxes1.area()  # [N]
    area2 = boxes2.area()  # [N]
    box1, box2 = boxes1.tensor, boxes2.tensor
    # Element-wise (matched) overlap: intersect row i of boxes1 with row i of boxes2.
    lt = torch.max(box1[:, :2], box2[:, :2])  # [N,2]
    rb = torch.min(box1[:, 2:], box2[:, 2:])  # [N,2]
    wh = (rb - lt).clamp(min=0)  # [N,2]
    inter = wh[:, 0] * wh[:, 1]  # [N]
    iou = inter / (area1 + area2 - inter)  # [N]
    return iou
|
||||
150
sam3/agent/helpers/color_map.py
Normal file
150
sam3/agent/helpers/color_map.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
"""
|
||||
An awesome colormap for really neat visualizations.
|
||||
Copied from Detectron, and removed gray colors.
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["colormap", "random_color", "random_colors"]
|
||||
|
||||
|
||||
# 20 bright, high-saturation RGB colors for segmentation masks, drawn from the
# edges of the sRGB cube for maximum intensity. Stored as a (20, 3) float32
# array with channel values in [0, 1].
_COLORS = np.array(
    [
        # The original 8 sharp colors
        [1.000, 1.000, 0.000],  # 1. Yellow
        [0.000, 1.000, 0.000],  # 2. Lime
        [0.000, 1.000, 1.000],  # 3. Cyan
        [1.000, 0.000, 1.000],  # 4. Magenta
        [1.000, 0.000, 0.000],  # 5. Red
        [1.000, 0.498, 0.000],  # 6. Orange
        [0.498, 1.000, 0.000],  # 7. Chartreuse
        [0.000, 1.000, 0.498],  # 8. Spring Green
        [1.000, 0.000, 0.498],  # 9. Rose
        [0.498, 0.000, 1.000],  # 10. Violet
        [0.753, 1.000, 0.000],  # 11. Electric Lime
        [1.000, 0.753, 0.000],  # 12. Vivid Orange
        [0.000, 1.000, 0.753],  # 13. Turquoise
        [0.753, 0.000, 1.000],  # 14. Bright Violet
        [1.000, 0.000, 0.753],  # 15. Bright Pink
        [1.000, 0.251, 0.000],  # 16. Fiery Orange
        [0.251, 1.000, 0.000],  # 17. Bright Chartreuse
        [0.000, 1.000, 0.251],  # 18. Malachite Green
        [0.251, 0.000, 1.000],  # 19. Deep Violet
        [1.000, 0.000, 0.251],  # 20. Hot Pink
    ],
    dtype=np.float32,
)
|
||||
|
||||
|
||||
def colormap(rgb=False, maximum=255):
    """
    Args:
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1

    Returns:
        ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1]
    """
    assert maximum in [255, 1], maximum
    scaled = _COLORS * maximum
    # Colors are stored as RGB; reverse the channel axis for BGR output.
    return scaled if rgb else scaled[:, ::-1]
|
||||
|
||||
|
||||
def random_color(rgb=False, maximum=255):
    """
    Args:
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1

    Returns:
        ndarray: a vector of 3 numbers
    """
    # Pick a uniformly random entry from the palette and scale it.
    choice = _COLORS[np.random.randint(0, len(_COLORS))] * maximum
    if rgb:
        return choice
    # Palette is RGB; reverse the channels for BGR.
    return choice[::-1]
|
||||
|
||||
|
||||
def random_colors(N, rgb=False, maximum=255):
    """
    Args:
        N (int): number of unique colors needed
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1

    Returns:
        ndarray: a list of random_color
    """
    # Sample without replacement so all N colors are distinct palette entries.
    picks = random.sample(range(len(_COLORS)), N)
    colors = [_COLORS[i] * maximum for i in picks]
    # Palette is RGB; reverse each color's channels for BGR.
    return colors if rgb else [c[::-1] for c in colors]
|
||||
|
||||
|
||||
# Visual sanity check for the palette: tile each color as a size x size swatch
# on an H x W grid and display it with OpenCV (requires a display environment;
# cv2 is imported lazily so the module itself has no OpenCV dependency).
if __name__ == "__main__":
    import cv2

    size = 100
    H, W = 10, 10
    # Random background so grid cells beyond len(_COLORS) stay visibly unfilled.
    canvas = np.random.rand(H * size, W * size, 3).astype("float32")
    for h in range(H):
        for w in range(W):
            # Row-major index into the palette.
            idx = h * W + w
            if idx >= len(_COLORS):
                break
            canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx]
    cv2.imshow("a", canvas)
    cv2.waitKey(0)
|
||||
244
sam3/agent/helpers/keypoints.py
Executable file
244
sam3/agent/helpers/keypoints.py
Executable file
@@ -0,0 +1,244 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class Keypoints:
    """
    Stores keypoint **annotation** data. GT Instances have a `gt_keypoints` property
    containing the x,y location and visibility flag of each keypoint. This tensor has shape
    (N, K, 3) where N is the number of instances and K is the number of keypoints per instance.

    The visibility flag follows the COCO format and must be one of three integers:

    * v=0: not labeled (in which case x=y=0)
    * v=1: labeled but not visible
    * v=2: labeled and visible
    """

    def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]):
        """
        Arguments:
            keypoints: A Tensor, numpy array, or list of the x, y, and visibility of each keypoint.
                The shape should be (N, K, 3) where N is the number of
                instances, and K is the number of keypoints per instance.
        """
        # Keep the tensor on whatever device the caller already placed it on;
        # non-tensor inputs default to CPU.
        device = (
            keypoints.device
            if isinstance(keypoints, torch.Tensor)
            else torch.device("cpu")
        )
        keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device)
        assert keypoints.dim() == 3 and keypoints.shape[2] == 3, keypoints.shape
        self.tensor = keypoints

    def __len__(self) -> int:
        # Number of instances (N).
        return self.tensor.size(0)

    def to(self, *args: Any, **kwargs: Any) -> "Keypoints":
        # Forward to torch.Tensor.to and rewrap in the same (sub)class.
        return type(self)(self.tensor.to(*args, **kwargs))

    @property
    def device(self) -> torch.device:
        return self.tensor.device

    def to_heatmap(self, boxes: torch.Tensor, heatmap_size: int) -> torch.Tensor:
        """
        Convert keypoint annotations to a heatmap of one-hot labels for training,
        as described in :paper:`Mask R-CNN`.

        Arguments:
            boxes: Nx4 tensor, the boxes to draw the keypoints to

        Returns:
            heatmaps:
                A tensor of shape (N, K), each element is integer spatial label
                in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
            valid:
                A tensor of shape (N, K) containing whether each keypoint is in the roi or not.
        """
        return _keypoints_to_heatmap(self.tensor, boxes, heatmap_size)

    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints":
        """
        Create a new `Keypoints` by indexing on this `Keypoints`.

        The following usage are allowed:

        1. `new_kpts = kpts[3]`: return a `Keypoints` which contains only one instance.
        2. `new_kpts = kpts[2:10]`: return a slice of key points.
        3. `new_kpts = kpts[vector]`, where vector is a torch.ByteTensor
           with `length = len(kpts)`. Nonzero elements in the vector will be selected.

        Note that the returned Keypoints might share storage with this Keypoints,
        subject to Pytorch's indexing semantics.
        """
        if isinstance(item, int):
            # Wrap in a list so the result keeps a leading instance dimension.
            return Keypoints([self.tensor[item]])
        return Keypoints(self.tensor[item])

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_instances={})".format(len(self.tensor))
        return s

    @staticmethod
    def cat(keypoints_list: List["Keypoints"]) -> "Keypoints":
        """
        Concatenates a list of Keypoints into a single Keypoints

        Arguments:
            keypoints_list (list[Keypoints])

        Returns:
            Keypoints: the concatenated Keypoints
        """
        assert isinstance(keypoints_list, (list, tuple))
        assert len(keypoints_list) > 0
        assert all(isinstance(keypoints, Keypoints) for keypoints in keypoints_list)

        # Preserve the (sub)class of the first element in the result.
        cat_kpts = type(keypoints_list[0])(
            torch.cat([kpts.tensor for kpts in keypoints_list], dim=0)
        )
        return cat_kpts
|
||||
|
||||
|
||||
def _keypoints_to_heatmap(
    keypoints: torch.Tensor, rois: torch.Tensor, heatmap_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encode keypoint locations into a target heatmap for use in SoftmaxWithLoss across space.

    Maps keypoints from the half-open interval [x1, x2) on continuous image coordinates to the
    closed interval [0, heatmap_size - 1] on discrete image coordinates. We use the
    continuous-discrete conversion from Heckbert 1990 ("What is the coordinate of a pixel?"):
    d = floor(c) and c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.

    Arguments:
        keypoints: tensor of keypoint locations in of shape (N, K, 3).
        rois: Nx4 tensor of rois in xyxy format
        heatmap_size: integer side length of square heatmap.

    Returns:
        heatmaps: A tensor of shape (N, K) containing an integer spatial label
            in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
        valid: A tensor of shape (N, K) containing whether each keypoint is in
            the roi or not.
    """

    if rois.numel() == 0:
        # No ROIs: return empty long tensors for both heatmaps and validity.
        return rois.new().long(), rois.new().long()
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]
    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])

    # Add a trailing dim so per-ROI offsets/scales broadcast over the K keypoints.
    offset_x = offset_x[:, None]
    offset_y = offset_y[:, None]
    scale_x = scale_x[:, None]
    scale_y = scale_y[:, None]

    x = keypoints[..., 0]
    y = keypoints[..., 1]

    # Keypoints lying exactly on the right/bottom ROI edge would floor to
    # heatmap_size (out of range); remember them so they can be clamped below.
    x_boundary_inds = x == rois[:, 2][:, None]
    y_boundary_inds = y == rois[:, 3][:, None]

    # Map continuous image coordinates into discrete heatmap bins.
    x = (x - offset_x) * scale_x
    x = x.floor().long()
    y = (y - offset_y) * scale_y
    y = y.floor().long()

    x[x_boundary_inds] = heatmap_size - 1
    y[y_boundary_inds] = heatmap_size - 1

    # A target is valid only if it lands inside the heatmap AND is labeled (v > 0).
    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
    vis = keypoints[..., 2] > 0
    valid = (valid_loc & vis).long()

    # Linearize (y, x) into a single spatial class index; invalid entries become 0.
    lin_ind = y * heatmap_size + x
    heatmaps = lin_ind * valid

    return heatmaps, valid
|
||||
|
||||
|
||||
@torch.jit.script_if_tracing
def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
    """
    Extract predicted keypoint locations from heatmaps.

    Args:
        maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for
            each ROI and each keypoint.
        rois (Tensor): (#ROIs, 4). The box of each ROI.

    Returns:
        Tensor of shape (#ROIs, #keypoints, 4) with the last dimension corresponding to
        (x, y, logit, score) for each keypoint.

    When converting discrete pixel indices in an NxN image to a continuous keypoint coordinate,
    we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from
    Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
    """

    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    # Clamp ROI sides to >= 1 so the interpolation target size is never zero.
    widths = (rois[:, 2] - rois[:, 0]).clamp(min=1)
    heights = (rois[:, 3] - rois[:, 1]).clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()

    num_rois, num_keypoints = maps.shape[:2]
    xy_preds = maps.new_zeros(rois.shape[0], num_keypoints, 4)

    # Correct for rounding the ROI size up to whole pixels when mapping back
    # from heatmap coordinates to image coordinates.
    width_corrections = widths / widths_ceil
    height_corrections = heights / heights_ceil

    keypoints_idx = torch.arange(num_keypoints, device=maps.device)

    for i in range(num_rois):
        # Upsample the pooled heatmap to the (integer) ROI resolution.
        outsize = (int(heights_ceil[i]), int(widths_ceil[i]))
        roi_map = F.interpolate(
            maps[[i]], size=outsize, mode="bicubic", align_corners=False
        )

        # Although semantically equivalent, `reshape` is used instead of `squeeze` due
        # to limitation during ONNX export of `squeeze` in scripting mode
        roi_map = roi_map.reshape(roi_map.shape[1:])  # keypoints x H x W

        # softmax over the spatial region
        max_score, _ = roi_map.view(num_keypoints, -1).max(1)
        max_score = max_score.view(num_keypoints, 1, 1)
        tmp_full_resolution = (roi_map - max_score).exp_()
        tmp_pool_resolution = (maps[i] - max_score).exp_()
        # Produce scores over the region H x W, but normalize with POOL_H x POOL_W,
        # so that the scores of objects of different absolute sizes will be more comparable
        roi_map_scores = tmp_full_resolution / tmp_pool_resolution.sum(
            (1, 2), keepdim=True
        )

        w = roi_map.shape[2]
        # Flat argmax per keypoint, then recover 2D (x, y) indices from it.
        pos = roi_map.view(num_keypoints, -1).argmax(1)

        x_int = pos % w
        y_int = (pos - x_int) // w

        # Sanity check: argmax of the logits and of the normalized scores coincide
        # (normalization is monotonic per keypoint).
        assert (
            roi_map_scores[keypoints_idx, y_int, x_int]
            == roi_map_scores.view(num_keypoints, -1).max(1)[0]
        ).all()

        # Discrete bin -> continuous coordinate (Heckbert: c = d + 0.5),
        # rescaled by the rounding correction factors.
        x = (x_int.float() + 0.5) * width_corrections[i]
        y = (y_int.float() + 0.5) * height_corrections[i]

        xy_preds[i, :, 0] = x + offset_x[i]
        xy_preds[i, :, 1] = y + offset_y[i]
        xy_preds[i, :, 2] = roi_map[keypoints_idx, y_int, x_int]
        xy_preds[i, :, 3] = roi_map_scores[keypoints_idx, y_int, x_int]

    return xy_preds
|
||||
128
sam3/agent/helpers/mask_overlap_removal.py
Normal file
128
sam3/agent/helpers/mask_overlap_removal.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
from pycocotools import mask as mask_utils
|
||||
except Exception:
|
||||
mask_utils = None
|
||||
|
||||
|
||||
def mask_intersection(
    masks1: torch.Tensor, masks2: torch.Tensor, block_size: int = 16
) -> torch.Tensor:
    """
    Count pairwise overlapping pixels between two stacks of boolean masks.

    Args:
        masks1: (N, H, W) bool tensor.
        masks2: (M, H, W) bool tensor with the same spatial size as masks1.
        block_size: tile size; pairs are processed in block_size x block_size
            tiles to bound peak memory of the broadcasted AND.

    Returns:
        torch.Tensor: (N, M) long tensor of intersection pixel counts.
    """
    assert masks1.shape[1:] == masks2.shape[1:]
    assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
    n_rows = masks1.shape[0]
    n_cols = masks2.shape[0]
    result = torch.zeros(n_rows, n_cols, device=masks1.device, dtype=torch.long)
    for r0 in range(0, n_rows, block_size):
        row_tile = masks1[r0 : r0 + block_size]
        for c0 in range(0, n_cols, block_size):
            col_tile = masks2[c0 : c0 + block_size]
            # Broadcast AND over the tile pair, then count overlapping pixels.
            counts = (row_tile[:, None] & col_tile[None, :]).flatten(-2).sum(-1)
            result[r0 : r0 + block_size, c0 : c0 + block_size] = counts
    return result
|
||||
|
||||
|
||||
def mask_iom(masks1: torch.Tensor, masks2: torch.Tensor) -> torch.Tensor:
    """
    Intersection-over-minimum (IoM) between two stacks of boolean masks:
    intersection pixel count divided by the smaller of the two mask areas.

    Args:
        masks1: (N, H, W) bool tensor.
        masks2: (M, H, W) bool tensor with the same spatial size.

    Returns:
        torch.Tensor: (N, M) float tensor of IoM values.
    """
    assert masks1.shape[1:] == masks2.shape[1:]
    assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
    overlap = mask_intersection(masks1, masks2)  # (N, M)
    areas1 = masks1.flatten(-2).sum(-1)  # (N,)
    areas2 = masks2.flatten(-2).sum(-1)  # (M,)
    # Clamp to >= 1 so fully-empty masks do not divide by zero.
    smaller = torch.min(areas1[:, None], areas2[None, :]).clamp_min(1)
    return overlap.float() / (smaller.float() + 1e-8)
|
||||
|
||||
|
||||
def _decode_single_mask(mask_repr, h: int, w: int) -> np.ndarray:
    """
    Decode one mask representation into a binary (H, W) uint8 array.

    Args:
        mask_repr: either a 2D array-like of per-pixel values (anything > 0 is
            treated as foreground), or a COCO-style RLE "counts" string/bytes.
        h: height of the target mask (used for the RLE path).
        w: width of the target mask (used for the RLE path).

    Returns:
        np.ndarray: uint8 array of shape (H, W) with values in {0, 1}.

    Raises:
        ValueError: if an array-like input is not 2D, or the type is unsupported.
        ImportError: if an RLE string is given but pycocotools is unavailable.
    """
    # Array-like inputs are thresholded directly; no pycocotools needed.
    if isinstance(mask_repr, (list, tuple, np.ndarray)):
        arr = np.array(mask_repr)
        if arr.ndim != 2:
            raise ValueError("Mask array must be 2D (H, W).")
        return (arr > 0).astype(np.uint8)

    if mask_utils is None:
        raise ImportError(
            "pycocotools is required to decode RLE mask strings. pip install pycocotools"
        )

    if not isinstance(mask_repr, (str, bytes)):
        raise ValueError("Unsupported mask representation type for RLE decode.")

    # mask_repr is guaranteed str/bytes here, so it is used as the RLE "counts"
    # payload directly. (The original conditional `mask_repr if isinstance(...)
    # else str(mask_repr)` was dead code: the else arm could never be reached.)
    # NOTE(review): pycocotools typically expects bytes for compressed counts;
    # str inputs are assumed to be compatible — confirm against callers.
    rle = {"counts": mask_repr, "size": [h, w]}
    decoded = mask_utils.decode(rle)
    if decoded.ndim == 3:
        # pycocotools may return (H, W, 1); keep the single channel.
        decoded = decoded[:, :, 0]
    return (decoded > 0).astype(np.uint8)
|
||||
|
||||
|
||||
def _decode_masks_to_torch_bool(pred_masks: List, h: int, w: int) -> torch.Tensor:
    """Decode every entry of *pred_masks* and stack them into an (N, H, W) bool tensor."""
    decoded = [_decode_single_mask(m, h, w) for m in pred_masks]
    stacked = np.stack(decoded, axis=0).astype(np.uint8)  # (N, H, W)
    return torch.from_numpy(stacked > 0)
|
||||
|
||||
|
||||
def remove_overlapping_masks(sample: Dict, iom_thresh: float = 0.3) -> Dict:
    """
    Greedy keep: sort by score desc; keep a mask if IoM to all kept masks <= threshold.
    If pred_masks has length 0 or 1, returns sample unchanged (no extra keys).

    Args:
        sample: prediction record with "pred_masks" (list of mask
            representations), "orig_img_h"/"orig_img_w", and optionally
            "pred_scores" and "pred_boxes".
        iom_thresh: maximum intersection-over-minimum allowed between a
            candidate mask and every already-kept mask.

    Returns:
        Either the input `sample` object unchanged (when filtering does not
        apply), or a shallow copy with "pred_masks"/"pred_scores" (and
        "pred_boxes" when present) filtered, plus bookkeeping keys
        "kept_indices", "removed_indices", and "iom_threshold".
    """
    # Basic presence checks
    if "pred_masks" not in sample or not isinstance(sample["pred_masks"], list):
        return sample  # nothing to do / preserve as-is

    pred_masks = sample["pred_masks"]
    N = len(pred_masks)

    # --- Early exit: 0 or 1 mask -> do NOT modify the JSON at all ---
    if N <= 1:
        return sample

    # From here on we have at least 2 masks
    h = int(sample["orig_img_h"])
    w = int(sample["orig_img_w"])
    pred_scores = sample.get("pred_scores", [1.0] * N)  # fallback if scores missing
    pred_boxes = sample.get("pred_boxes", None)

    assert N == len(pred_scores), "pred_masks and pred_scores must have same length"
    if pred_boxes is not None:
        assert N == len(pred_boxes), "pred_masks and pred_boxes must have same length"

    masks_bool = _decode_masks_to_torch_bool(pred_masks, h, w)  # (N, H, W)

    # Visit masks from highest to lowest score so higher-confidence masks
    # win overlaps.
    order = sorted(range(N), key=lambda i: float(pred_scores[i]), reverse=True)
    kept_idx: List[int] = []
    kept_masks: List[torch.Tensor] = []

    for i in order:
        cand = masks_bool[i].unsqueeze(0)  # (1, H, W)
        if len(kept_masks) == 0:
            # First (highest-scored) mask is always kept.
            kept_idx.append(i)
            kept_masks.append(masks_bool[i])
            continue

        kept_stack = torch.stack(kept_masks, dim=0)  # (K, H, W)
        iom_vals = mask_iom(cand, kept_stack).squeeze(0)  # (K,)
        if torch.any(iom_vals > iom_thresh):
            continue  # overlaps too much with a higher-scored kept mask
        kept_idx.append(i)
        kept_masks.append(masks_bool[i])

    # Restore original ordering for the surviving predictions.
    kept_idx_sorted = sorted(kept_idx)

    # Build filtered JSON (this *does* modify fields; only for N>=2 case)
    out = dict(sample)
    out["pred_masks"] = [pred_masks[i] for i in kept_idx_sorted]
    out["pred_scores"] = [pred_scores[i] for i in kept_idx_sorted]
    if pred_boxes is not None:
        out["pred_boxes"] = [pred_boxes[i] for i in kept_idx_sorted]
    out["kept_indices"] = kept_idx_sorted
    out["removed_indices"] = [i for i in range(N) if i not in set(kept_idx_sorted)]
    out["iom_threshold"] = float(iom_thresh)
    return out
|
||||
560
sam3/agent/helpers/masks.py
Executable file
560
sam3/agent/helpers/masks.py
Executable file
@@ -0,0 +1,560 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
from typing import Any, Iterator, List, Union
|
||||
|
||||
import numpy as np
|
||||
import pycocotools.mask as mask_util
|
||||
import torch
|
||||
from torch import device
|
||||
|
||||
from .boxes import Boxes
|
||||
from .memory import retry_if_cuda_oom
|
||||
|
||||
from .roi_align import ROIAlign
|
||||
|
||||
|
||||
def polygon_area(x, y):
    """Area of a simple polygon given its vertex coordinates.

    Uses the shoelace formula:
    https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
    """
    cross_terms = np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))
    return 0.5 * np.abs(cross_terms)
|
||||
|
||||
|
||||
def polygons_to_bitmask(
    polygons: List[np.ndarray], height: int, width: int
) -> np.ndarray:
    """
    Rasterize a set of polygons belonging to one instance into a single mask.

    Args:
        polygons (list[ndarray]): each array has shape (Nx2,)
        height, width (int)

    Returns:
        ndarray: a bool mask of shape (height, width)
    """
    if not polygons:
        # COCOAPI does not support empty polygons
        return np.zeros((height, width)).astype(bool)
    merged = mask_util.merge(mask_util.frPyObjects(polygons, height, width))
    return mask_util.decode(merged).astype(bool)
|
||||
|
||||
|
||||
def rasterize_polygons_within_box(
    polygons: List[np.ndarray], box: np.ndarray, mask_size: int
) -> torch.Tensor:
    """
    Rasterize polygons, crop to the given box, and resize the crop to
    (mask_size, mask_size).

    Used to build per-box ground-truth targets for the mask head in
    Mask R-CNN: given an instance's original polygons, produce a
    `mask_size x mask_size` target aligned with one predicted box.

    Args:
        polygons: flat [x0, y0, x1, y1, ...] arrays for one instance.
        box: 4-element numpy array (x0, y0, x1, y1).
        mask_size: side length of the square output mask.

    Returns:
        Tensor: BoolTensor of shape (mask_size, mask_size)
    """
    # 1. Translate the polygons into the box's coordinate frame.
    # Deep-copy first so the caller's arrays are not mutated.
    shifted = copy.deepcopy(polygons)
    for poly in shifted:
        poly[0::2] -= box[0]
        poly[1::2] -= box[1]

    # 2. Scale into the output resolution; max() guards tiny boxes
    # against division by a near-zero extent.
    box_w, box_h = box[2] - box[0], box[3] - box[1]
    scale_y = mask_size / max(box_h, 0.1)
    scale_x = mask_size / max(box_w, 0.1)

    if scale_y == scale_x:
        for poly in shifted:
            poly *= scale_y
    else:
        for poly in shifted:
            poly[0::2] *= scale_x
            poly[1::2] *= scale_y

    # 3. Rasterize with the COCO API and hand back a torch tensor.
    bitmask = polygons_to_bitmask(shifted, mask_size, mask_size)
    return torch.from_numpy(bitmask)
|
||||
|
||||
|
||||
class BitMasks:
    """
    This class stores the segmentation masks for all objects in one image, in
    the form of bitmaps.

    Attributes:
        tensor: bool Tensor of N,H,W, representing N instances in the image.
    """

    def __init__(self, tensor: Union[torch.Tensor, np.ndarray]):
        """
        Args:
            tensor: bool Tensor of N,H,W, representing N instances in the image.
        """
        if isinstance(tensor, torch.Tensor):
            tensor = tensor.to(torch.bool)
        else:
            # Non-tensor input (e.g. ndarray) is converted on the CPU.
            tensor = torch.as_tensor(
                tensor, dtype=torch.bool, device=torch.device("cpu")
            )
        assert tensor.dim() == 3, tensor.size()
        self.image_size = tensor.shape[1:]
        self.tensor = tensor

    @torch.jit.unused
    def to(self, *args: Any, **kwargs: Any) -> "BitMasks":
        """Return a new BitMasks with the tensor moved/cast via `Tensor.to`."""
        return BitMasks(self.tensor.to(*args, **kwargs))

    @property
    def device(self) -> torch.device:
        return self.tensor.device

    @torch.jit.unused
    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "BitMasks":
        """
        Returns:
            BitMasks: Create a new :class:`BitMasks` by indexing.

        The following usage are allowed:

        1. `new_masks = masks[3]`: return a `BitMasks` which contains only one mask.
        2. `new_masks = masks[2:10]`: return a slice of masks.
        3. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
           with `length = len(masks)`. Nonzero elements in the vector will be selected.

        Note that the returned object might share storage with this object,
        subject to Pytorch's indexing semantics.
        """
        if isinstance(item, int):
            # Keep the result 3D: a single mask becomes a (1, H, W) BitMasks.
            return BitMasks(self.tensor[item].unsqueeze(0))
        m = self.tensor[item]
        assert (
            m.dim() == 3
        ), "Indexing on BitMasks with {} returns a tensor with shape {}!".format(
            item, m.shape
        )
        return BitMasks(m)

    @torch.jit.unused
    def __iter__(self) -> torch.Tensor:
        # Yields the raw (H, W) tensors, not BitMasks wrappers.
        yield from self.tensor

    @torch.jit.unused
    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_instances={})".format(len(self.tensor))
        return s

    def __len__(self) -> int:
        return self.tensor.shape[0]

    def nonempty(self) -> torch.Tensor:
        """
        Find masks that are non-empty.

        Returns:
            Tensor: a BoolTensor which represents
                whether each mask is empty (False) or non-empty (True).
        """
        return self.tensor.flatten(1).any(dim=1)

    @staticmethod
    def from_polygon_masks(
        polygon_masks: Union["PolygonMasks", List[List[np.ndarray]]],
        height: int,
        width: int,
    ) -> "BitMasks":
        """
        Rasterize polygon masks into a BitMasks of the given image size.

        Args:
            polygon_masks (list[list[ndarray]] or PolygonMasks)
            height, width (int)
        """
        if isinstance(polygon_masks, PolygonMasks):
            polygon_masks = polygon_masks.polygons
        masks = [polygons_to_bitmask(p, height, width) for p in polygon_masks]
        if len(masks):
            return BitMasks(torch.stack([torch.from_numpy(x) for x in masks]))
        else:
            # torch.stack would fail on an empty list; build an empty tensor.
            return BitMasks(torch.empty(0, height, width, dtype=torch.bool))

    @staticmethod
    def from_roi_masks(roi_masks: "ROIMasks", height: int, width: int) -> "BitMasks":
        """
        Paste ROI-relative masks into full-image BitMasks.

        Args:
            roi_masks:
            height, width (int):
        """
        # NOTE(review): ROIMasks.to_bitmasks is declared below as
        # to_bitmasks(boxes, height, width, threshold=0.5), but this call
        # passes only (height, width) — confirm the intended signature/usage.
        return roi_masks.to_bitmasks(height, width)

    def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
        """
        Crop each bitmask by the given box, and resize results to (mask_size, mask_size).
        This can be used to prepare training targets for Mask R-CNN.
        It has less reconstruction error compared to rasterization with polygons.
        However we observe no difference in accuracy,
        but BitMasks requires more memory to store all the masks.

        Args:
            boxes (Tensor): Nx4 tensor storing the boxes for each mask
            mask_size (int): the size of the rasterized mask.

        Returns:
            Tensor:
                A bool tensor of shape (N, mask_size, mask_size), where
                N is the number of predicted boxes for this image.
        """
        assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))
        device = self.tensor.device

        # Prepend each box's own index so ROIAlign crops mask i with box i.
        batch_inds = torch.arange(len(boxes), device=device).to(dtype=boxes.dtype)[
            :, None
        ]
        rois = torch.cat([batch_inds, boxes], dim=1)  # Nx5

        # ROIAlign interpolates float values; threshold back to bool below.
        bit_masks = self.tensor.to(dtype=torch.float32)
        rois = rois.to(device=device)
        output = (
            ROIAlign((mask_size, mask_size), 1.0, 0, aligned=True)
            .forward(bit_masks[:, None, :, :], rois)
            .squeeze(1)
        )
        output = output >= 0.5
        return output

    def get_bounding_boxes(self) -> Boxes:
        """
        Returns:
            Boxes: tight bounding boxes around bitmasks.
            If a mask is empty, it's bounding box will be all zero.
        """
        boxes = torch.zeros(self.tensor.shape[0], 4, dtype=torch.float32)
        # Reducing over dim=1 (rows) gives per-column occupancy -> x extent;
        # reducing over dim=2 (columns) gives per-row occupancy -> y extent.
        x_any = torch.any(self.tensor, dim=1)
        y_any = torch.any(self.tensor, dim=2)
        for idx in range(self.tensor.shape[0]):
            x = torch.where(x_any[idx, :])[0]
            y = torch.where(y_any[idx, :])[0]
            if len(x) > 0 and len(y) > 0:
                # +1 makes the box exclusive of the last occupied pixel.
                boxes[idx, :] = torch.as_tensor(
                    [x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=torch.float32
                )
        return Boxes(boxes)

    @staticmethod
    def cat(bitmasks_list: List["BitMasks"]) -> "BitMasks":
        """
        Concatenates a list of BitMasks into a single BitMasks

        Arguments:
            bitmasks_list (list[BitMasks])

        Returns:
            BitMasks: the concatenated BitMasks
        """
        assert isinstance(bitmasks_list, (list, tuple))
        assert len(bitmasks_list) > 0
        assert all(isinstance(bitmask, BitMasks) for bitmask in bitmasks_list)

        cat_bitmasks = type(bitmasks_list[0])(
            torch.cat([bm.tensor for bm in bitmasks_list], dim=0)
        )
        return cat_bitmasks
|
||||
|
||||
|
||||
class PolygonMasks:
    """
    This class stores the segmentation masks for all objects in one image, in the form of polygons.

    Attributes:
        polygons: list[list[ndarray]]. Each ndarray is a float64 vector representing a polygon.
    """

    def __init__(self, polygons: List[List[Union[torch.Tensor, np.ndarray]]]):
        """
        Arguments:
            polygons (list[list[np.ndarray]]): The first
                level of the list correspond to individual instances,
                the second level to all the polygons that compose the
                instance, and the third level to the polygon coordinates.
                The third level array should have the format of
                [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
        """
        if not isinstance(polygons, list):
            raise ValueError(
                "Cannot create PolygonMasks: Expect a list of list of polygons per image. "
                "Got '{}' instead.".format(type(polygons))
            )

        def _make_array(t: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
            # Use float64 for higher precision, because why not?
            # Always put polygons on CPU (self.to is a no-op) since they
            # are supposed to be small tensors.
            # May need to change this assumption if GPU placement becomes useful
            if isinstance(t, torch.Tensor):
                t = t.cpu().numpy()
            return np.asarray(t).astype("float64")

        def process_polygons(
            polygons_per_instance: List[Union[torch.Tensor, np.ndarray]],
        ) -> List[np.ndarray]:
            # Validate and normalize the polygons of a single instance.
            if not isinstance(polygons_per_instance, list):
                raise ValueError(
                    "Cannot create polygons: Expect a list of polygons per instance. "
                    "Got '{}' instead.".format(type(polygons_per_instance))
                )
            # transform each polygon to a numpy array
            polygons_per_instance = [_make_array(p) for p in polygons_per_instance]
            for polygon in polygons_per_instance:
                # A valid polygon needs >= 3 (x, y) pairs, i.e. >= 6 numbers.
                if len(polygon) % 2 != 0 or len(polygon) < 6:
                    raise ValueError(
                        f"Cannot create a polygon from {len(polygon)} coordinates."
                    )
            return polygons_per_instance

        self.polygons: List[List[np.ndarray]] = [
            process_polygons(polygons_per_instance)
            for polygons_per_instance in polygons
        ]

    def to(self, *args: Any, **kwargs: Any) -> "PolygonMasks":
        # Polygons always live on the CPU, so device moves are a no-op.
        return self

    @property
    def device(self) -> torch.device:
        return torch.device("cpu")

    def get_bounding_boxes(self) -> Boxes:
        """
        Returns:
            Boxes: tight bounding boxes around polygon masks.
        """
        boxes = torch.zeros(len(self.polygons), 4, dtype=torch.float32)
        for idx, polygons_per_instance in enumerate(self.polygons):
            # Accumulate min/max corners over every polygon of the instance.
            minxy = torch.as_tensor([float("inf"), float("inf")], dtype=torch.float32)
            maxxy = torch.zeros(2, dtype=torch.float32)
            for polygon in polygons_per_instance:
                coords = torch.from_numpy(polygon).view(-1, 2).to(dtype=torch.float32)
                minxy = torch.min(minxy, torch.min(coords, dim=0).values)
                maxxy = torch.max(maxxy, torch.max(coords, dim=0).values)
            boxes[idx, :2] = minxy
            boxes[idx, 2:] = maxxy
        return Boxes(boxes)

    def nonempty(self) -> torch.Tensor:
        """
        Find masks that are non-empty.

        Returns:
            Tensor:
                a BoolTensor which represents whether each mask is empty (False) or not (True).
        """
        # "Non-empty" here means the instance has at least one polygon;
        # polygon areas are not inspected.
        keep = [1 if len(polygon) > 0 else 0 for polygon in self.polygons]
        return torch.from_numpy(np.asarray(keep, dtype=bool))

    def __getitem__(
        self, item: Union[int, slice, List[int], torch.BoolTensor]
    ) -> "PolygonMasks":
        """
        Support indexing over the instances and return a `PolygonMasks` object.
        `item` can be:

        1. An integer. It will return an object with only one instance.
        2. A slice. It will return an object with the selected instances.
        3. A list[int]. It will return an object with the selected instances,
           correpsonding to the indices in the list.
        4. A vector mask of type BoolTensor, whose length is num_instances.
           It will return an object with the instances whose mask is nonzero.
        """
        if isinstance(item, int):
            selected_polygons = [self.polygons[item]]
        elif isinstance(item, slice):
            selected_polygons = self.polygons[item]
        elif isinstance(item, list):
            selected_polygons = [self.polygons[i] for i in item]
        elif isinstance(item, torch.Tensor):
            # Polygons is a list, so we have to move the indices back to CPU.
            if item.dtype == torch.bool:
                assert item.dim() == 1, item.shape
                item = item.nonzero().squeeze(1).cpu().numpy().tolist()
            elif item.dtype in [torch.int32, torch.int64]:
                item = item.cpu().numpy().tolist()
            else:
                raise ValueError(
                    "Unsupported tensor dtype={} for indexing!".format(item.dtype)
                )
            selected_polygons = [self.polygons[i] for i in item]
        return PolygonMasks(selected_polygons)

    def __iter__(self) -> Iterator[List[np.ndarray]]:
        """
        Yields:
            list[ndarray]: the polygons for one instance.
            Each Tensor is a float64 vector representing a polygon.
        """
        return iter(self.polygons)

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_instances={})".format(len(self.polygons))
        return s

    def __len__(self) -> int:
        return len(self.polygons)

    def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
        """
        Crop each mask by the given box, and resize results to (mask_size, mask_size).
        This can be used to prepare training targets for Mask R-CNN.

        Args:
            boxes (Tensor): Nx4 tensor storing the boxes for each mask
            mask_size (int): the size of the rasterized mask.

        Returns:
            Tensor: A bool tensor of shape (N, mask_size, mask_size), where
            N is the number of predicted boxes for this image.
        """
        assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))

        device = boxes.device
        # Put boxes on the CPU, as the polygon representation is not efficient GPU-wise
        # (several small tensors for representing a single instance mask)
        boxes = boxes.to(torch.device("cpu"))

        results = [
            rasterize_polygons_within_box(poly, box.numpy(), mask_size)
            for poly, box in zip(self.polygons, boxes)
        ]
        """
        poly: list[list[float]], the polygons for one instance
        box: a tensor of shape (4,)
        """
        if len(results) == 0:
            # torch.stack fails on an empty list; return an empty tensor.
            return torch.empty(0, mask_size, mask_size, dtype=torch.bool, device=device)
        return torch.stack(results, dim=0).to(device=device)

    def area(self):
        """
        Computes area of the mask.
        Only works with Polygons, using the shoelace formula:
        https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates

        Returns:
            Tensor: a vector, area for each instance
        """

        area = []
        for polygons_per_instance in self.polygons:
            # An instance's area is the sum over its component polygons.
            area_per_instance = 0
            for p in polygons_per_instance:
                area_per_instance += polygon_area(p[0::2], p[1::2])
            area.append(area_per_instance)

        return torch.tensor(area)

    @staticmethod
    def cat(polymasks_list: List["PolygonMasks"]) -> "PolygonMasks":
        """
        Concatenates a list of PolygonMasks into a single PolygonMasks

        Arguments:
            polymasks_list (list[PolygonMasks])

        Returns:
            PolygonMasks: the concatenated PolygonMasks
        """
        assert isinstance(polymasks_list, (list, tuple))
        assert len(polymasks_list) > 0
        assert all(isinstance(polymask, PolygonMasks) for polymask in polymasks_list)

        cat_polymasks = type(polymasks_list[0])(
            list(itertools.chain.from_iterable(pm.polygons for pm in polymasks_list))
        )
        return cat_polymasks
|
||||
|
||||
|
||||
class ROIMasks:
    """
    Represent masks by N smaller masks defined in some ROIs. Once ROI boxes are given,
    full-image bitmask can be obtained by "pasting" the mask on the region defined
    by the corresponding ROI box.
    """

    def __init__(self, tensor: torch.Tensor):
        """
        Args:
            tensor: (N, M, M) mask tensor that defines the mask within each ROI.
        """
        if tensor.dim() != 3:
            raise ValueError("ROIMasks must take a masks of 3 dimension.")
        self.tensor = tensor

    def to(self, device: torch.device) -> "ROIMasks":
        """Return a new ROIMasks on the given device."""
        return ROIMasks(self.tensor.to(device))

    @property
    def device(self) -> device:
        return self.tensor.device

    def __len__(self):
        return self.tensor.shape[0]

    def __getitem__(self, item) -> "ROIMasks":
        """
        Returns:
            ROIMasks: Create a new :class:`ROIMasks` by indexing.

        The following usage are allowed:

        1. `new_masks = masks[2:10]`: return a slice of masks.
        2. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
           with `length = len(masks)`. Nonzero elements in the vector will be selected.

        Note that the returned object might share storage with this object,
        subject to Pytorch's indexing semantics.
        """
        t = self.tensor[item]
        if t.dim() != 3:
            # e.g. integer indexing would drop the instance dimension.
            raise ValueError(
                f"Indexing on ROIMasks with {item} returns a tensor with shape {t.shape}!"
            )
        return ROIMasks(t)

    @torch.jit.unused
    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_instances={})".format(len(self.tensor))
        return s

    @torch.jit.unused
    def to_bitmasks(self, boxes: torch.Tensor, height, width, threshold=0.5):
        """
        Paste the ROI masks into a full-image BitMasks of size (height, width).

        Args: see documentation of :func:`paste_masks_in_image`.
        """
        from detectron2.layers.mask_ops import (
            _paste_masks_tensor_shape,
            paste_masks_in_image,
        )

        if torch.jit.is_tracing():
            # Under tracing, pick the variant matching the height/width type
            # and skip the OOM-retry wrapper (not traceable).
            if isinstance(height, torch.Tensor):
                paste_func = _paste_masks_tensor_shape
            else:
                paste_func = paste_masks_in_image
        else:
            paste_func = retry_if_cuda_oom(paste_masks_in_image)
        bitmasks = paste_func(
            self.tensor, boxes.tensor, (height, width), threshold=threshold
        )
        return BitMasks(bitmasks)
|
||||
87
sam3/agent/helpers/memory.py
Executable file
87
sam3/agent/helpers/memory.py
Executable file
@@ -0,0 +1,87 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["retry_if_cuda_oom"]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _ignore_torch_cuda_oom():
|
||||
"""
|
||||
A context which ignores CUDA OOM exception from pytorch.
|
||||
"""
|
||||
try:
|
||||
yield
|
||||
except RuntimeError as e:
|
||||
# NOTE: the string may change?
|
||||
if "CUDA out of memory. " in str(e):
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def retry_if_cuda_oom(func):
    """
    Wrap `func` so CUDA out-of-memory failures trigger retries.

    Strategy: (1) call normally; (2) on OOM, free cached GPU memory via
    `torch.cuda.empty_cache()` and call again; (3) on a second OOM, copy
    tensor-like arguments to the CPU and call once more. In the last case the
    function is expected to dispatch to a CPU implementation, and the return
    values may be CPU tensors — converting them back is the caller's job.

    Args:
        func: a stateless callable that takes tensor-like objects as arguments

    Returns:
        a callable which retries `func` if OOM is encountered.

    Examples:
    ::
        output = retry_if_cuda_oom(some_torch_function)(input1, input2)
        # output may be on CPU even if inputs are on GPU

    Note:
        1. Only top-level arguments exposing `.device` and `.to` are moved to
           the CPU; nested structures of tensors are not supported.
        2. Since the function might be called more than once, it has to be
           stateless.
    """

    def _cpu_copy_if_cuda(obj):
        # Duck-typing: anything whose `.device.type` is "cuda" and that has a
        # `.to` method is treated as a GPU tensor.
        try:
            cuda_tensor_like = obj.device.type == "cuda" and hasattr(obj, "to")
        except AttributeError:
            cuda_tensor_like = False
        return obj.to(device="cpu") if cuda_tensor_like else obj

    @wraps(func)
    def wrapped(*args, **kwargs):
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # First OOM: release cached blocks and retry on the GPU.
        torch.cuda.empty_cache()
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Second OOM: fall back to CPU. This slows things down, so log it.
        logger = logging.getLogger(__name__)
        logger.info(
            "Attempting to copy inputs of {} to CPU due to CUDA OOM".format(str(func))
        )
        new_args = (_cpu_copy_if_cuda(x) for x in args)
        new_kwargs = {k: _cpu_copy_if_cuda(v) for k, v in kwargs.items()}
        return func(*new_args, **new_kwargs)

    return wrapped
|
||||
122
sam3/agent/helpers/rle.py
Executable file
122
sam3/agent/helpers/rle.py
Executable file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
"""Some utilities for RLE encoding that doesn't require downloading the masks to the cpu"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pycocotools import mask as mask_util
|
||||
|
||||
|
||||
@torch.no_grad()
def rle_encode(orig_mask, return_areas=False):
    """Encodes a collection of masks in RLE format

    This function emulates the behavior of the COCO API's encode function, but
    is executed partially on the GPU for faster execution.

    Args:
        orig_mask (torch.Tensor): A mask of shape (N, H, W) with dtype=torch.bool
        return_areas (bool): If True, add the areas of the masks as a part of
            the RLE output dict under the "area" key. Default is False.

    Returns:
        list[dict]: one COCO-style RLE dict per mask, with "counts" decoded to
        str and "size" == [H, W]. Empty input yields an empty list.
    """
    assert orig_mask.ndim == 3, "Mask must be of shape (N, H, W)"
    assert orig_mask.dtype == torch.bool, "Mask must have dtype=torch.bool"

    if orig_mask.numel() == 0:
        return []

    # First, transpose the spatial dimensions.
    # This is necessary because the COCO API uses Fortran order
    mask = orig_mask.transpose(1, 2)

    # Flatten the mask
    flat_mask = mask.reshape(mask.shape[0], -1)
    if return_areas:
        mask_areas = flat_mask.sum(-1).tolist()
    # Find the indices where the mask changes. The extra trailing column is
    # left True (from torch.ones) so each row's run list is terminated at the
    # end of the flattened mask.
    differences = torch.ones(
        mask.shape[0], flat_mask.shape[1] + 1, device=mask.device, dtype=torch.bool
    )
    differences[:, 1:-1] = flat_mask[:, :-1] != flat_mask[:, 1:]
    differences[:, 0] = flat_mask[:, 0]
    _, change_indices = torch.where(differences)

    try:
        # Per-row change counts -> cumulative offsets into `change_indices`.
        boundaries = torch.cumsum(differences.sum(-1), 0).cpu()
    except RuntimeError as _:
        # Fall back to summing on the CPU (e.g. if the GPU op fails).
        boundaries = torch.cumsum(differences.cpu().sum(-1), 0)

    change_indices_clone = change_indices.clone()
    # First pass computes the RLEs on GPU, in a flatten format
    for i in range(mask.shape[0]):
        # Get the change indices for this batch item
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        # In-place successive differencing turns absolute change positions
        # into run lengths (the row's first entry stays absolute, i.e. the
        # leading run of zeros).
        change_indices[beg + 1 : end] -= change_indices_clone[beg : end - 1]

    # Now we can split the RLES of each batch item, and convert them to strings
    # No more gpu at this point
    change_indices = change_indices.tolist()

    batch_rles = []
    # Process each mask in the batch separately
    for i in range(mask.shape[0]):
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        run_lengths = change_indices[beg:end]

        uncompressed_rle = {"counts": run_lengths, "size": list(orig_mask.shape[1:])}
        h, w = uncompressed_rle["size"]
        # Compress with the COCO API and make "counts" JSON-serializable.
        rle = mask_util.frPyObjects(uncompressed_rle, h, w)
        rle["counts"] = rle["counts"].decode("utf-8")
        if return_areas:
            rle["area"] = mask_areas[i]
        batch_rles.append(rle)

    return batch_rles
|
||||
|
||||
|
||||
def robust_rle_encode(masks):
    """Encodes a collection of masks in RLE format. Uses the gpu version fist, falls back to the cpu version if it fails"""

    assert masks.ndim == 3, "Mask must be of shape (N, H, W)"
    assert masks.dtype == torch.bool, "Mask must have dtype=torch.bool"

    try:
        return rle_encode(masks)
    except RuntimeError as _:
        # CPU fallback: encode each mask with pycocotools directly.
        cpu_masks = masks.cpu().numpy()
        rles = []
        for single_mask in cpu_masks:
            fortran = np.array(single_mask[:, :, np.newaxis], dtype=np.uint8, order="F")
            rles.append(mask_util.encode(fortran)[0])
        for rle in rles:
            rle["counts"] = rle["counts"].decode("utf-8")
        return rles
|
||||
|
||||
|
||||
def ann_to_rle(segm, im_info):
    """Convert annotation which can be polygons, uncompressed RLE to RLE.
    Args:
        segm: a list of polygons, or an RLE dict (compressed or uncompressed).
        im_info (dict): must provide "height" and "width".
    Returns:
        ann (rle)
    """
    h, w = im_info["height"], im_info["width"]
    if isinstance(segm, list):
        # polygon -- a single object might consist of multiple parts
        # we merge all parts into one mask rle code
        return mask_util.merge(mask_util.frPyObjects(segm, h, w))
    if isinstance(segm["counts"], list):
        # uncompressed RLE
        return mask_util.frPyObjects(segm, h, w)
    # already compressed rle
    return segm
|
||||
75
sam3/agent/helpers/roi_align.py
Executable file
75
sam3/agent/helpers/roi_align.py
Executable file
@@ -0,0 +1,75 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
from torch import nn
|
||||
from torchvision.ops import roi_align
|
||||
|
||||
|
||||
# NOTE: torchvision's RoIAlign has a different default aligned=False
|
||||
class ROIAlign(nn.Module):
    """ROIAlign wrapper that defaults to the pixel-model-correct `aligned=True`."""

    def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=True):
        """
        Args:
            output_size (tuple): h, w
            spatial_scale (float): scale the input boxes by this number
            sampling_ratio (int): number of inputs samples to take for each output
                sample. 0 to take samples densely.
            aligned (bool): if False, use the legacy implementation in
                Detectron. If True, align the results more perfectly.

        Note:
            With `aligned=True`, the ROI is scaled and then shifted by -0.5
            before calling torchvision's roi_align, so that a continuous
            coordinate c samples its true pixel neighbors floor(c - 0.5) and
            ceil(c - 0.5). The legacy `aligned=False` skips the -0.5 shift and
            therefore interpolates from slightly misaligned pixels; see
            detectron2/tests/test_roi_align.py for verification. The
            difference does not affect model accuracy when ROIAlign feeds
            conv layers.
        """
        super().__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio
        self.aligned = aligned

        from torchvision import __version__

        version = tuple(int(x) for x in __version__.split(".")[:2])
        # https://github.com/pytorch/vision/pull/2438
        assert version >= (0, 7), "Require torchvision >= 0.7"

    def forward(self, input, rois):
        """
        Args:
            input: NCHW images
            rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy.
        """
        assert rois.dim() == 2 and rois.size(1) == 5
        if input.is_quantized:
            # torchvision's roi_align does not accept quantized tensors.
            input = input.dequantize()
        boxes = rois.to(dtype=input.dtype)
        return roi_align(
            input,
            boxes,
            self.output_size,
            self.spatial_scale,
            self.sampling_ratio,
            self.aligned,
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"output_size={self.output_size}, "
            f"spatial_scale={self.spatial_scale}, "
            f"sampling_ratio={self.sampling_ratio}, "
            f"aligned={self.aligned})"
        )
|
||||
533
sam3/agent/helpers/rotated_boxes.py
Executable file
533
sam3/agent/helpers/rotated_boxes.py
Executable file
@@ -0,0 +1,533 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
# from detectron2.layers.rotated_boxes import pairwise_iou_rotated
|
||||
|
||||
from .boxes import Boxes
|
||||
|
||||
|
||||
def pairwise_iou_rotated(boxes1, boxes2):
    """
    Compute the NxM matrix of pairwise IoU (Jaccard index) between two sets of
    rotated boxes.

    Both sets are expected in (x_center, y_center, width, height, angle) format.

    Args:
        boxes1 (Tensor[N, 5])
        boxes2 (Tensor[M, 5])

    Returns:
        Tensor[N, M]: IoU for every pair (boxes1[i], boxes2[j]).
    """
    # Delegates to detectron2's compiled kernel registered on torch.ops.
    iou_op = torch.ops.detectron2.box_iou_rotated
    return iou_op(boxes1, boxes2)
|
||||
|
||||
|
||||
class RotatedBoxes(Boxes):
    """
    This structure stores a list of rotated boxes as a Nx5 torch.Tensor.
    It supports some common methods about boxes
    (`area`, `clip`, `nonempty`, etc),
    and also behaves like a Tensor
    (support indexing, `to(device)`, `.device`, and iteration over all boxes)
    """

    def __init__(self, tensor: torch.Tensor):
        """
        Args:
            tensor (Tensor[float]): a Nx5 matrix. Each row is
                (x_center, y_center, width, height, angle), with the angle in
                degrees. There is no strict range restriction on the angle,
                but the recommended principal range is [-180, 180).

        Angle convention:
            Start from the horizontal box B = (x_center, y_center, width,
            height), width along the image x-axis and height along the image
            y-axis (image space: x left->right, y top->down). The rotated box
            is B rotated about its center by ``angle`` degrees CCW (CW for
            negative angles). Equivalently, ``angle`` is the CCW rotation from
            the image x-axis to the box's width vector, and from the y-axis to
            its height vector.

            (X, Y, W, H, A) and (X, Y, W, H, A + 360*N) describe the same
            region for any integer N; likewise (5, 3, 4, 2, 90) and
            (5, 3, 4, 2, -90) cover identical regions. They are nevertheless
            distinct boxes (e.g. they yield different RoI pooling results), so
            they are not treated as identical.
        """
        # Inputs that are not tensors (lists, ndarrays) default to CPU.
        device = (
            tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
        )
        tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
        if tensor.numel() == 0:
            # Use reshape, so we don't end up creating a new tensor that does not depend on
            # the inputs (and consequently confuses jit)
            tensor = tensor.reshape((0, 5)).to(dtype=torch.float32, device=device)
        assert tensor.dim() == 2 and tensor.size(-1) == 5, tensor.size()

        self.tensor = tensor
|
||||
|
||||
def clone(self) -> "RotatedBoxes":
|
||||
"""
|
||||
Clone the RotatedBoxes.
|
||||
|
||||
Returns:
|
||||
RotatedBoxes
|
||||
"""
|
||||
return RotatedBoxes(self.tensor.clone())
|
||||
|
||||
def to(self, device: torch.device, non_blocking: bool = False):
|
||||
# Boxes are assumed float32 and does not support to(dtype)
|
||||
return RotatedBoxes(self.tensor.to(device=device, non_blocking=non_blocking))
|
||||
|
||||
def area(self) -> torch.Tensor:
|
||||
"""
|
||||
Computes the area of all the boxes.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: a vector with areas of each box.
|
||||
"""
|
||||
box = self.tensor
|
||||
area = box[:, 2] * box[:, 3]
|
||||
return area
|
||||
|
||||
# Avoid in-place operations so that we can torchscript; NOTE: this creates a new tensor
|
||||
def normalize_angles(self) -> None:
|
||||
"""
|
||||
Restrict angles to the range of [-180, 180) degrees
|
||||
"""
|
||||
angle_tensor = (self.tensor[:, 4] + 180.0) % 360.0 - 180.0
|
||||
self.tensor = torch.cat((self.tensor[:, :4], angle_tensor[:, None]), dim=1)
|
||||
|
||||
def clip(
|
||||
self, box_size: Tuple[int, int], clip_angle_threshold: float = 1.0
|
||||
) -> None:
|
||||
"""
|
||||
Clip (in place) the boxes by limiting x coordinates to the range [0, width]
|
||||
and y coordinates to the range [0, height].
|
||||
|
||||
For RRPN:
|
||||
Only clip boxes that are almost horizontal with a tolerance of
|
||||
clip_angle_threshold to maintain backward compatibility.
|
||||
|
||||
Rotated boxes beyond this threshold are not clipped for two reasons:
|
||||
|
||||
1. There are potentially multiple ways to clip a rotated box to make it
|
||||
fit within the image.
|
||||
2. It's tricky to make the entire rectangular box fit within the image
|
||||
and still be able to not leave out pixels of interest.
|
||||
|
||||
Therefore we rely on ops like RoIAlignRotated to safely handle this.
|
||||
|
||||
Args:
|
||||
box_size (height, width): The clipping box's size.
|
||||
clip_angle_threshold:
|
||||
Iff. abs(normalized(angle)) <= clip_angle_threshold (in degrees),
|
||||
we do the clipping as horizontal boxes.
|
||||
"""
|
||||
h, w = box_size
|
||||
|
||||
# normalize angles to be within (-180, 180] degrees
|
||||
self.normalize_angles()
|
||||
|
||||
idx = torch.where(torch.abs(self.tensor[:, 4]) <= clip_angle_threshold)[0]
|
||||
|
||||
# convert to (x1, y1, x2, y2)
|
||||
x1 = self.tensor[idx, 0] - self.tensor[idx, 2] / 2.0
|
||||
y1 = self.tensor[idx, 1] - self.tensor[idx, 3] / 2.0
|
||||
x2 = self.tensor[idx, 0] + self.tensor[idx, 2] / 2.0
|
||||
y2 = self.tensor[idx, 1] + self.tensor[idx, 3] / 2.0
|
||||
|
||||
# clip
|
||||
x1.clamp_(min=0, max=w)
|
||||
y1.clamp_(min=0, max=h)
|
||||
x2.clamp_(min=0, max=w)
|
||||
y2.clamp_(min=0, max=h)
|
||||
|
||||
# convert back to (xc, yc, w, h)
|
||||
self.tensor[idx, 0] = (x1 + x2) / 2.0
|
||||
self.tensor[idx, 1] = (y1 + y2) / 2.0
|
||||
# make sure widths and heights do not increase due to numerical errors
|
||||
self.tensor[idx, 2] = torch.min(self.tensor[idx, 2], x2 - x1)
|
||||
self.tensor[idx, 3] = torch.min(self.tensor[idx, 3], y2 - y1)
|
||||
|
||||
def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
|
||||
"""
|
||||
Find boxes that are non-empty.
|
||||
A box is considered empty, if either of its side is no larger than threshold.
|
||||
|
||||
Returns:
|
||||
Tensor: a binary vector which represents
|
||||
whether each box is empty (False) or non-empty (True).
|
||||
"""
|
||||
box = self.tensor
|
||||
widths = box[:, 2]
|
||||
heights = box[:, 3]
|
||||
keep = (widths > threshold) & (heights > threshold)
|
||||
return keep
|
||||
|
||||
def __getitem__(self, item) -> "RotatedBoxes":
|
||||
"""
|
||||
Returns:
|
||||
RotatedBoxes: Create a new :class:`RotatedBoxes` by indexing.
|
||||
|
||||
The following usage are allowed:
|
||||
|
||||
1. `new_boxes = boxes[3]`: return a `RotatedBoxes` which contains only one box.
|
||||
2. `new_boxes = boxes[2:10]`: return a slice of boxes.
|
||||
3. `new_boxes = boxes[vector]`, where vector is a torch.ByteTensor
|
||||
with `length = len(boxes)`. Nonzero elements in the vector will be selected.
|
||||
|
||||
Note that the returned RotatedBoxes might share storage with this RotatedBoxes,
|
||||
subject to Pytorch's indexing semantics.
|
||||
"""
|
||||
if isinstance(item, int):
|
||||
return RotatedBoxes(self.tensor[item].view(1, -1))
|
||||
b = self.tensor[item]
|
||||
assert (
|
||||
b.dim() == 2
|
||||
), "Indexing on RotatedBoxes with {} failed to return a matrix!".format(item)
|
||||
return RotatedBoxes(b)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.tensor.shape[0]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "RotatedBoxes(" + str(self.tensor) + ")"
|
||||
|
||||
def inside_box(
|
||||
self, box_size: Tuple[int, int], boundary_threshold: int = 0
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
box_size (height, width): Size of the reference box covering
|
||||
[0, width] x [0, height]
|
||||
boundary_threshold (int): Boxes that extend beyond the reference box
|
||||
boundary by more than boundary_threshold are considered "outside".
|
||||
|
||||
For RRPN, it might not be necessary to call this function since it's common
|
||||
for rotated box to extend to outside of the image boundaries
|
||||
(the clip function only clips the near-horizontal boxes)
|
||||
|
||||
Returns:
|
||||
a binary vector, indicating whether each box is inside the reference box.
|
||||
"""
|
||||
height, width = box_size
|
||||
|
||||
cnt_x = self.tensor[..., 0]
|
||||
cnt_y = self.tensor[..., 1]
|
||||
half_w = self.tensor[..., 2] / 2.0
|
||||
half_h = self.tensor[..., 3] / 2.0
|
||||
a = self.tensor[..., 4]
|
||||
c = torch.abs(torch.cos(a * math.pi / 180.0))
|
||||
s = torch.abs(torch.sin(a * math.pi / 180.0))
|
||||
# This basically computes the horizontal bounding rectangle of the rotated box
|
||||
max_rect_dx = c * half_w + s * half_h
|
||||
max_rect_dy = c * half_h + s * half_w
|
||||
|
||||
inds_inside = (
|
||||
(cnt_x - max_rect_dx >= -boundary_threshold)
|
||||
& (cnt_y - max_rect_dy >= -boundary_threshold)
|
||||
& (cnt_x + max_rect_dx < width + boundary_threshold)
|
||||
& (cnt_y + max_rect_dy < height + boundary_threshold)
|
||||
)
|
||||
|
||||
return inds_inside
|
||||
|
||||
def get_centers(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns:
|
||||
The box centers in a Nx2 array of (x, y).
|
||||
"""
|
||||
return self.tensor[:, :2]
|
||||
|
||||
    def scale(self, scale_x: float, scale_y: float) -> None:
        """
        Scale the rotated boxes (in place) by horizontal/vertical factors.

        Note: when scale_x != scale_y, a rotated box whose angle is not a
        multiple of 90 degrees does not stay rectangular under resizing -- the
        true image of the box is a parallelogram (it acquires skew). This
        method approximates that parallelogram by fitting a rotated rectangle
        to it.
        """
        self.tensor[:, 0] *= scale_x
        self.tensor[:, 1] *= scale_y
        theta = self.tensor[:, 4] * math.pi / 180.0
        c = torch.cos(theta)
        s = torch.sin(theta)

        # Derivation sketch (image space: y top->down, x left->right; the
        # right-handed coordinate system is yOx). In the box's local frame with
        # the center at the origin, the midpoint of the left edge is
        # E = (-w/2, 0) and the midpoint of the top edge is F = (0, -h/2).
        # Rotating into the global frame: (old_x, old_y) = (s*y + c*x, c*y - s*x),
        # so E(old) = (-c*w/2, s*w/2) and F(old) = (-s*h/2, -c*h/2).
        # Applying the scaling factors (sfx, sfy):
        #   E(new) = (-sfx * c * w/2, sfy * s * w/2)
        #   F(new) = (-sfx * s * h/2, -sfy * c * h/2)
        #
        # New width: w(new) = |E(new) - O| * 2 = sqrt[(sfx*c)^2 + (sfy*s)^2] * w.
        # Sanity checks: angle = 0 or 180 => scale_factor_w == scale_x;
        #                |angle| = 90     => scale_factor_w == scale_y.
        self.tensor[:, 2] *= torch.sqrt((scale_x * c) ** 2 + (scale_y * s) ** 2)

        # New height: h(new) = |F(new) - O| * 2 = sqrt[(sfx*s)^2 + (sfy*c)^2] * h.
        # Sanity checks: angle = 0 or 180 => scale_factor_h == scale_y;
        #                |angle| = 90     => scale_factor_h == scale_x.
        self.tensor[:, 3] *= torch.sqrt((scale_x * s) ** 2 + (scale_y * c) ** 2)

        # The angle is the CCW rotation from the image y-axis to the box's
        # height vector:
        #   angle(new) = angle_yOx(O - F(new)) = atan2(sfx * s, sfy * c)
        # Sanity check: when scale_x == scale_y, angle(new) == angle(old).
        self.tensor[:, 4] = torch.atan2(scale_x * s, scale_y * c) * 180 / math.pi
|
||||
|
||||
@classmethod
|
||||
def cat(cls, boxes_list: List["RotatedBoxes"]) -> "RotatedBoxes":
|
||||
"""
|
||||
Concatenates a list of RotatedBoxes into a single RotatedBoxes
|
||||
|
||||
Arguments:
|
||||
boxes_list (list[RotatedBoxes])
|
||||
|
||||
Returns:
|
||||
RotatedBoxes: the concatenated RotatedBoxes
|
||||
"""
|
||||
assert isinstance(boxes_list, (list, tuple))
|
||||
if len(boxes_list) == 0:
|
||||
return cls(torch.empty(0))
|
||||
assert all([isinstance(box, RotatedBoxes) for box in boxes_list])
|
||||
|
||||
# use torch.cat (v.s. layers.cat) so the returned boxes never share storage with input
|
||||
cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0))
|
||||
return cat_boxes
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.tensor.device
|
||||
|
||||
@torch.jit.unused
|
||||
def __iter__(self):
|
||||
"""
|
||||
Yield a box as a Tensor of shape (5,) at a time.
|
||||
"""
|
||||
yield from self.tensor
|
||||
|
||||
|
||||
def pairwise_iou(boxes1: RotatedBoxes, boxes2: RotatedBoxes) -> torch.Tensor:
    """
    Given two lists of rotated boxes of size N and M,
    compute the IoU (intersection over union)
    between **all** N x M pairs of boxes.
    The box order must be (x_center, y_center, width, height, angle).

    Args:
        boxes1, boxes2 (RotatedBoxes):
            two `RotatedBoxes`. Contains N & M rotated boxes, respectively.

    Returns:
        Tensor: IoU, sized [N,M].
    """
    # Fixed return annotation: this function always returns a Tensor, but it
    # was previously annotated as `-> None`.
    return pairwise_iou_rotated(boxes1.tensor, boxes2.tensor)
|
||||
406
sam3/agent/helpers/som_utils.py
Normal file
406
sam3/agent/helpers/som_utils.py
Normal file
@@ -0,0 +1,406 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import colorsys
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
import cv2
|
||||
import matplotlib as mpl
|
||||
import matplotlib.colors as mplc
|
||||
import numpy as np
|
||||
import pycocotools.mask as mask_utils
|
||||
|
||||
|
||||
def rgb_to_hex(rgb_color):
    """
    Convert an RGB color to a hex color string.

    Args:
        rgb_color (tuple/list of ints): RGB color in tuple or list format,
            each channel in [0, 255].

    Returns:
        str: Hex color, e.g. "#ff00f4".

    Example:
        ```
        >>> rgb_to_hex((255, 0, 244))
        '#ff00f4'
        ```
    """
    # "{:02x}" zero-pads each channel to two lowercase hex digits.
    # (The previous docstring example claimed (255, 0, 244) -> '#ff00ff',
    # which was wrong: 244 is 0xf4.)
    return "#" + "".join("{:02x}".format(c) for c in rgb_color)
|
||||
|
||||
|
||||
# The 20 approved annotation colors, keyed by hex string -> human-readable name.
# (A previous, smaller palette that lived here as commented-out code has been
# removed.)
DEFAULT_COLOR_HEX_TO_NAME = {
    rgb_to_hex((255, 255, 0)): "yellow",
    rgb_to_hex((0, 255, 0)): "lime",
    rgb_to_hex((0, 255, 255)): "cyan",
    rgb_to_hex((255, 0, 255)): "magenta",
    rgb_to_hex((255, 0, 0)): "red",
    rgb_to_hex((255, 127, 0)): "orange",
    rgb_to_hex((127, 255, 0)): "chartreuse",
    rgb_to_hex((0, 255, 127)): "spring green",
    rgb_to_hex((255, 0, 127)): "rose",
    rgb_to_hex((127, 0, 255)): "violet",
    rgb_to_hex((192, 255, 0)): "electric lime",
    rgb_to_hex((255, 192, 0)): "vivid orange",
    rgb_to_hex((0, 255, 192)): "turquoise",
    rgb_to_hex((192, 0, 255)): "bright violet",
    rgb_to_hex((255, 0, 192)): "bright pink",
    rgb_to_hex((255, 64, 0)): "fiery orange",
    rgb_to_hex((64, 255, 0)): "bright chartreuse",
    rgb_to_hex((0, 255, 64)): "malachite",
    rgb_to_hex((64, 0, 255)): "deep violet",
    rgb_to_hex((255, 0, 64)): "hot pink",
}


# Palette order follows the dict's insertion order above.
DEFAULT_COLOR_PALETTE = list(DEFAULT_COLOR_HEX_TO_NAME.keys())
|
||||
|
||||
|
||||
def _validate_color_hex(color_hex: str):
|
||||
color_hex = color_hex.lstrip("#")
|
||||
if not all(c in "0123456789abcdefABCDEF" for c in color_hex):
|
||||
raise ValueError("Invalid characters in color hash")
|
||||
if len(color_hex) not in (3, 6):
|
||||
raise ValueError("Invalid length of color hash")
|
||||
|
||||
|
||||
# copied from https://github.com/roboflow/supervision/blob/c8f557af0c61b5c03392bad2cc36c8835598b1e1/supervision/draw/color.py
@dataclass
class Color:
    """
    A color in RGB format.

    Attributes:
        r (int): Red channel.
        g (int): Green channel.
        b (int): Blue channel.
    """

    r: int
    g: int
    b: int

    @classmethod
    def from_hex(cls, color_hex: str):
        """
        Build a Color from a hex string.

        Args:
            color_hex (str): Hex string such as "#ff00ff" or the shorthand "#f0f".

        Returns:
            Color: Instance representing the color.

        Example:
            ```
            >>> Color.from_hex('#ff00ff')
            Color(r=255, g=0, b=255)
            ```
        """
        _validate_color_hex(color_hex)
        digits = color_hex.lstrip("#")
        if len(digits) == 3:
            # Expand shorthand notation, e.g. "f0f" -> "ff00ff".
            digits = "".join(ch * 2 for ch in digits)
        channels = [int(digits[i : i + 2], 16) for i in range(0, 6, 2)]
        return cls(*channels)

    @classmethod
    def to_hex(cls, color):
        """
        Convert a Color instance to a hex string.

        Args:
            color (Color): the color to convert.

        Returns:
            str: hex string such as "#ff00ff".
        """
        return rgb_to_hex((color.r, color.g, color.b))

    def as_rgb(self) -> Tuple[int, int, int]:
        """
        Return the color as an (r, g, b) tuple.

        Example:
            ```
            >>> color.as_rgb()
            (255, 0, 255)
            ```
        """
        return self.r, self.g, self.b

    def as_bgr(self) -> Tuple[int, int, int]:
        """
        Return the color as a (b, g, r) tuple (OpenCV channel order).

        Example:
            ```
            >>> color.as_bgr()
            (255, 0, 255)
            ```
        """
        return self.b, self.g, self.r

    @classmethod
    def white(cls):
        return Color.from_hex("#ffffff")

    @classmethod
    def black(cls):
        return Color.from_hex("#000000")

    @classmethod
    def red(cls):
        return Color.from_hex("#ff0000")

    @classmethod
    def green(cls):
        return Color.from_hex("#00ff00")

    @classmethod
    def blue(cls):
        return Color.from_hex("#0000ff")
|
||||
|
||||
|
||||
@dataclass
class ColorPalette:
    """An ordered list of Colors with lookup and contrast helpers."""

    colors: List[Color]

    @classmethod
    def default(cls):
        """
        Return the default color palette.

        Returns:
            ColorPalette: built from DEFAULT_COLOR_PALETTE.

        Example:
            ```
            >>> ColorPalette.default()
            ColorPalette(colors=[Color(r=255, g=0, b=0), Color(r=0, g=255, b=0), ...])
            ```
        """
        return ColorPalette.from_hex(color_hex_list=DEFAULT_COLOR_PALETTE)

    @classmethod
    def from_hex(cls, color_hex_list: List[str]):
        """
        Build a ColorPalette from a list of hex strings.

        Args:
            color_hex_list (List[str]): List of color hex strings.

        Returns:
            ColorPalette

        Example:
            ```
            >>> ColorPalette.from_hex(['#ff0000', '#00ff00', '#0000ff'])
            ColorPalette(colors=[Color(r=255, g=0, b=0), Color(r=0, g=255, b=0), ...])
            ```
        """
        return cls([Color.from_hex(h) for h in color_hex_list])

    def by_idx(self, idx: int) -> Color:
        """
        Return the color at the given index, wrapping around the palette.

        Args:
            idx (int): non-negative index into the palette.

        Raises:
            ValueError: if idx is negative.
        """
        if idx < 0:
            raise ValueError("idx argument should not be negative")
        return self.colors[idx % len(self.colors)]

    def find_farthest_color(self, img_array):
        """
        Return the palette color farthest (on average) from an image's pixels.

        Args:
            img_array (np array): any *x3 array; the last axis is RGB.

        Returns:
            (Color, str): the farthest color and its name from
            DEFAULT_COLOR_HEX_TO_NAME ("unknown" if absent).
        """
        pixels = img_array.reshape((-1, 3))
        palette = np.array([[c.r, c.g, c.b] for c in self.colors])

        # Broadcast (num_pixels, 1, 3) against (num_colors, 3) to get the
        # per-pixel, per-color Euclidean distances, then average over pixels.
        distances = np.sqrt(
            np.sum((pixels[:, np.newaxis, :] - palette) ** 2, axis=2)
        )
        mean_distances = np.mean(distances, axis=0)

        farthest_color = self.colors[int(np.argmax(mean_distances))]
        farthest_hex = Color.to_hex(farthest_color)
        farthest_name = DEFAULT_COLOR_HEX_TO_NAME.get(farthest_hex, "unknown")

        return farthest_color, farthest_name
|
||||
|
||||
|
||||
def draw_box(ax, box_coord, alpha=0.8, edge_color="g", line_style="-", linewidth=2.0):
    """Draw an unfilled rectangle given as (x0, y0, width, height) on *ax*."""
    x0, y0, width, height = box_coord
    rect = mpl.patches.Rectangle(
        (x0, y0),
        width,
        height,
        fill=False,
        edgecolor=edge_color,
        linewidth=linewidth,
        alpha=alpha,
        linestyle=line_style,
    )
    ax.add_patch(rect)
|
||||
|
||||
|
||||
def draw_text(
    ax,
    text,
    position,
    font_size=None,
    color="g",
    horizontal_alignment="left",
    rotation=0,
):
    """Draw *text* at (x, y) *position* on matplotlib axis *ax*.

    The requested color is brightened for legibility before use: every RGB
    channel is lifted to at least 0.2, and the dominant channel to at least 0.8.
    """
    if not font_size:
        # Fall back to matplotlib's current default font size.
        font_size = mpl.rcParams["font.size"]

    # Lift dark channels, then force the dominant channel to be bright.
    color = np.maximum(list(mplc.to_rgb(color)), 0.2)
    color[np.argmax(color)] = max(0.8, np.max(color))

    x, y = position
    ax.text(
        x,
        y,
        text,
        size=font_size,
        family="sans-serif",
        bbox={"facecolor": "none", "alpha": 0.5, "pad": 0.7, "edgecolor": "none"},
        verticalalignment="top",
        horizontalalignment=horizontal_alignment,
        color=color,
        rotation=rotation,
    )
|
||||
|
||||
|
||||
def draw_mask(
    ax, rle, color, show_holes=True, alpha=0.15, upsample_factor=1.0, rle_upsampled=None
):
    """Draw a segmentation mask on matplotlib axis *ax*.

    Renders a translucent fill (when show_holes, via an RGBA image so holes in
    the mask stay transparent) plus the mask's contour polygons.

    Args:
        ax: matplotlib axis to draw on.
        rle: COCO RLE dict or a binary ndarray mask.
        color: RGB color as an array-like of 3 floats.
        show_holes: if True, fill via an RGBA image that respects holes;
            otherwise fill the contour polygons directly.
        alpha: fill opacity.
        upsample_factor: scale from mask coordinates to display coordinates.
        rle_upsampled: required when upsample_factor > 1 and show_holes;
            the mask at display resolution (RLE dict or ndarray).
    """
    if isinstance(rle, dict):
        mask = mask_utils.decode(rle)
    elif isinstance(rle, np.ndarray):
        mask = rle
    else:
        raise ValueError(f"Unsupported type for rle: {type(rle)}")

    mask_upsampled = None
    if upsample_factor > 1.0 and show_holes:
        # Need a display-resolution mask to paint the RGBA fill.
        assert rle_upsampled is not None
        if isinstance(rle_upsampled, dict):
            mask_upsampled = mask_utils.decode(rle_upsampled)
        elif isinstance(rle_upsampled, np.ndarray):
            mask_upsampled = rle_upsampled
        else:
            raise ValueError(f"Unsupported type for rle: {type(rle)}")

    if show_holes:
        if mask_upsampled is None:
            mask_upsampled = mask
        h, w = mask_upsampled.shape
        # RGBA image: constant color, alpha channel carries the mask.
        mask_img = np.zeros((h, w, 4))
        mask_img[:, :, :-1] = color[np.newaxis, np.newaxis, :]
        mask_img[:, :, -1] = mask_upsampled * alpha
        ax.imshow(mask_img)

    # Take the second-to-last element of the returned tuple -- NOTE(review):
    # this looks intended to work with both the OpenCV 3 (image, contours,
    # hierarchy) and OpenCV 4 (contours, hierarchy) return shapes; confirm
    # against the pinned cv2 version.
    *_, contours, _ = cv2.findContours(
        mask.astype(np.uint8).copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
    )
    # Map contour pixel centers from mask resolution to display resolution.
    upsampled_contours = [(cont + 0.5) * upsample_factor - 0.5 for cont in contours]
    facecolor = (0, 0, 0, 0) if show_holes else color
    if alpha > 0.8:
        # A near-opaque fill needs a darker edge to remain visible.
        edge_color = _change_color_brightness(color, brightness_factor=-0.7)
    else:
        edge_color = color
    for cont in upsampled_contours:
        polygon = mpl.patches.Polygon(
            [el[0] for el in cont],
            edgecolor=edge_color,
            linewidth=2.0,
            facecolor=facecolor,
        )
        ax.add_patch(polygon)
|
||||
|
||||
|
||||
def _change_color_brightness(color, brightness_factor):
    """
    Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
    less or more saturation than the original color.

    Args:
        color: color of the polygon. Refer to `matplotlib.colors` for a full list of
            formats that are accepted.
        brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
            0 will correspond to no change, a factor in [-1.0, 0) range will result in
            a darker color and a factor in (0, 1.0] range will result in a lighter color.

    Returns:
        modified_color (tuple[double]): a tuple containing the RGB values of the
            modified color. Each value in the tuple is in the [0.0, 1.0] range.
    """
    assert -1.0 <= brightness_factor <= 1.0
    # Fix: the original converted the color with mplc.to_rgb twice in a row;
    # convert once, then adjust lightness in HLS space.
    hue, lightness, saturation = colorsys.rgb_to_hls(*mplc.to_rgb(color))
    modified_lightness = lightness + (brightness_factor * lightness)
    # Clamp lightness into the valid [0.0, 1.0] range.
    modified_lightness = min(max(modified_lightness, 0.0), 1.0)
    return colorsys.hls_to_rgb(hue, modified_lightness, saturation)
|
||||
1662
sam3/agent/helpers/visualizer.py
Executable file
1662
sam3/agent/helpers/visualizer.py
Executable file
File diff suppressed because it is too large
Load Diff
195
sam3/agent/helpers/zoom_in.py
Normal file
195
sam3/agent/helpers/zoom_in.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import io
|
||||
import math
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pycocotools.mask as mask_utils
|
||||
from PIL import Image
|
||||
|
||||
from .som_utils import ColorPalette, draw_box, draw_mask, draw_text
|
||||
|
||||
|
||||
def render_zoom_in(
    object_data,
    image_file,
    show_box: bool = True,
    show_text: bool = False,
    show_holes: bool = True,
    mask_alpha: float = 0.15,
):
    """
    Render a two-panel visualization with a cropped original view (left/upper) and a zoomed-in
    mask overlay (right/lower), then return it as a PIL.Image along with the chosen mask color (hex).

    Parameters
    ----------
    object_data : dict
        Dict containing "labels" and COCO RLE "segmentation".
        Expected:
            object_data["labels"][0]["noun_phrase"] : str
            object_data["segmentation"] : COCO RLE (with "size": [H, W])
    image_file : PIL.Image.Image
        Source image (PIL).
    show_box : bool
        Whether to draw the bbox on the cropped original panel.
    show_text : bool
        Whether to draw the noun phrase label near the bbox.
    show_holes : bool
        Whether to render mask holes (passed through to draw_mask).
    mask_alpha : float
        Alpha for the mask overlay.

    Returns
    -------
    pil_img : PIL.Image.Image
        The composed visualization image.
    color_hex : str
        Hex string of the chosen mask color.
    """

    # ---- local constants (avoid module-level globals) ----
    # Relative-area thresholds at which the zoom/crop windows are enlarged.
    _AREA_LARGE = 0.25
    _AREA_MEDIUM = 0.05

    # ---- local helpers (avoid name collisions in a larger class) ----
    def _get_shift(x, w, w_new, w_img):
        # How far to shift the window origin left/up so the widened window
        # stays centered on the box yet remains inside the image bounds.
        assert 0 <= w_new <= w_img
        shift = (w_new - w) / 2
        if x - shift + w_new > w_img:
            # Window would overflow the right/bottom edge; pull it back in.
            shift = x + w_new - w_img
        # Never shift past the image origin (x - shift must stay >= 0).
        return min(x, shift)

    def _get_zoom_in_box(mask_box_xywh, img_h, img_w, mask_area):
        # Computes two XYWH windows around the mask's bbox:
        #   zoom_in_box  — for the zoomed mask panel (grown past _AREA_LARGE)
        #   img_crop_box — for the cropped-original panel (grown past _AREA_MEDIUM)
        box_w, box_h = mask_box_xywh[2], mask_box_xywh[3]
        # Pad the bbox by 20% (at least 16 px) on each dimension, clipped to image.
        w_new = min(box_w + max(0.2 * box_w, 16), img_w)
        h_new = min(box_h + max(0.2 * box_h, 16), img_h)

        mask_relative_area = mask_area / (w_new * h_new)

        # zoom-in (larger box if mask is relatively big)
        w_new_large, h_new_large = w_new, h_new
        if mask_relative_area > _AREA_LARGE:
            # Grow the window so the mask occupies at most _AREA_LARGE of it.
            ratio_large = math.sqrt(mask_relative_area / _AREA_LARGE)
            w_new_large = min(w_new * ratio_large, img_w)
            h_new_large = min(h_new * ratio_large, img_h)

        w_shift_large = _get_shift(
            mask_box_xywh[0], mask_box_xywh[2], w_new_large, img_w
        )
        h_shift_large = _get_shift(
            mask_box_xywh[1], mask_box_xywh[3], h_new_large, img_h
        )
        zoom_in_box = [
            mask_box_xywh[0] - w_shift_large,
            mask_box_xywh[1] - h_shift_large,
            w_new_large,
            h_new_large,
        ]

        # crop box for the original/cropped image
        w_new_medium, h_new_medium = w_new, h_new
        if mask_relative_area > _AREA_MEDIUM:
            # Same growth rule, but targeting the looser _AREA_MEDIUM bound,
            # so the "original" panel shows more surrounding context.
            ratio_med = math.sqrt(mask_relative_area / _AREA_MEDIUM)
            w_new_medium = min(w_new * ratio_med, img_w)
            h_new_medium = min(h_new * ratio_med, img_h)

        w_shift_medium = _get_shift(
            mask_box_xywh[0], mask_box_xywh[2], w_new_medium, img_w
        )
        h_shift_medium = _get_shift(
            mask_box_xywh[1], mask_box_xywh[3], h_new_medium, img_h
        )
        img_crop_box = [
            mask_box_xywh[0] - w_shift_medium,
            mask_box_xywh[1] - h_shift_medium,
            w_new_medium,
            h_new_medium,
        ]
        return zoom_in_box, img_crop_box

    # ---- main body ----
    # Input parsing
    object_label = object_data["labels"][0]["noun_phrase"]
    img = image_file.convert("RGB")
    bbox_xywh = mask_utils.toBbox(object_data["segmentation"])  # [x, y, w, h]

    # Choose a stable, visually distant color based on crop
    bbox_xyxy = [
        bbox_xywh[0],
        bbox_xywh[1],
        bbox_xywh[0] + bbox_xywh[2],
        bbox_xywh[1] + bbox_xywh[3],
    ]
    crop_img = img.crop(bbox_xyxy)
    color_palette = ColorPalette.default()
    # Pick the palette color farthest from the crop's pixels for contrast.
    color_obj, _ = color_palette.find_farthest_color(np.array(crop_img))
    # Normalize to float RGB in [0, 1] for matplotlib-based drawing helpers.
    color = np.array([color_obj.r / 255, color_obj.g / 255, color_obj.b / 255])
    color_hex = f"#{color_obj.r:02x}{color_obj.g:02x}{color_obj.b:02x}"

    # Compute zoom-in / crop boxes
    img_h, img_w = object_data["segmentation"]["size"]
    mask_area = mask_utils.area(object_data["segmentation"])
    zoom_in_box, img_crop_box = _get_zoom_in_box(bbox_xywh, img_h, img_w, mask_area)

    # Layout choice
    w, h = img_crop_box[2], img_crop_box[3]
    # Tall crops sit side by side; wide crops are stacked vertically.
    if w < h:
        fig, (ax1, ax2) = plt.subplots(1, 2)
    else:
        fig, (ax1, ax2) = plt.subplots(2, 1)

    # Panel 1: cropped original with optional box/text
    img_crop_box_xyxy = [
        img_crop_box[0],
        img_crop_box[1],
        img_crop_box[0] + img_crop_box[2],
        img_crop_box[1] + img_crop_box[3],
    ]
    img1 = img.crop(img_crop_box_xyxy)
    # Bbox re-expressed in the cropped panel's coordinate frame.
    bbox_xywh_rel = [
        bbox_xywh[0] - img_crop_box[0],
        bbox_xywh[1] - img_crop_box[1],
        bbox_xywh[2],
        bbox_xywh[3],
    ]
    ax1.imshow(img1)
    ax1.axis("off")
    if show_box:
        draw_box(ax1, bbox_xywh_rel, edge_color=color)
    if show_text:
        # Small offset so the label does not overlap the box edge.
        x0, y0 = bbox_xywh_rel[0] + 2, bbox_xywh_rel[1] + 2
        draw_text(ax1, object_label, [x0, y0], color=color)

    # Panel 2: zoomed-in mask overlay
    binary_mask = mask_utils.decode(object_data["segmentation"])
    # Store the mask in the image's alpha channel so cropping the image also
    # crops the mask, keeping the two aligned.
    alpha = Image.fromarray((binary_mask * 255).astype("uint8"))
    img_rgba = img.convert("RGBA")
    img_rgba.putalpha(alpha)
    zoom_in_box_xyxy = [
        zoom_in_box[0],
        zoom_in_box[1],
        zoom_in_box[0] + zoom_in_box[2],
        zoom_in_box[1] + zoom_in_box[3],
    ]
    img_with_alpha_zoomin = img_rgba.crop(zoom_in_box_xyxy)
    # Recover the cropped mask from the alpha channel (band index 3).
    alpha_zoomin = img_with_alpha_zoomin.split()[3]
    binary_mask_zoomin = np.array(alpha_zoomin).astype(bool)

    ax2.imshow(img_with_alpha_zoomin.convert("RGB"))
    ax2.axis("off")
    draw_mask(
        ax2, binary_mask_zoomin, color=color, show_holes=show_holes, alpha=mask_alpha
    )

    plt.tight_layout()

    # Buffer -> PIL.Image
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=100)
    plt.close(fig)
    buf.seek(0)
    # NOTE(review): Image.open keeps a reference to buf and loads lazily;
    # the returned image is decoded on first use by the caller.
    pil_img = Image.open(buf)

    return pil_img, color_hex
|
||||
Reference in New Issue
Block a user