Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
244
sam3/agent/helpers/keypoints.py
Executable file
244
sam3/agent/helpers/keypoints.py
Executable file
@@ -0,0 +1,244 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class Keypoints:
    """
    Stores keypoint **annotation** data. GT Instances have a `gt_keypoints` property
    containing the x,y location and visibility flag of each keypoint. This tensor has shape
    (N, K, 3) where N is the number of instances and K is the number of keypoints per instance.

    The visibility flag follows the COCO format and must be one of three integers:

    * v=0: not labeled (in which case x=y=0)
    * v=1: labeled but not visible
    * v=2: labeled and visible
    """

    def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]):
        """
        Arguments:
            keypoints: A Tensor, numpy array, or list of the x, y, and visibility of each keypoint.
                The shape should be (N, K, 3) where N is the number of
                instances, and K is the number of keypoints per instance.

        The data is always stored as float32. A tensor input keeps its device;
        any other input is placed on the CPU.
        """
        device = (
            keypoints.device
            if isinstance(keypoints, torch.Tensor)
            else torch.device("cpu")
        )
        keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device)
        assert keypoints.dim() == 3 and keypoints.shape[2] == 3, keypoints.shape
        self.tensor = keypoints

    def __len__(self) -> int:
        """Return N, the number of instances."""
        return self.tensor.size(0)

    def to(self, *args: Any, **kwargs: Any) -> "Keypoints":
        """
        Return a new object built from ``self.tensor.to(*args, **kwargs)``.

        Note that the constructor converts to float32 again, so a dtype
        passed here does not survive; device moves do.
        """
        return type(self)(self.tensor.to(*args, **kwargs))

    @property
    def device(self) -> torch.device:
        """Device of the underlying tensor."""
        return self.tensor.device

    def to_heatmap(
        self, boxes: torch.Tensor, heatmap_size: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert keypoint annotations to a heatmap of one-hot labels for training,
        as described in :paper:`Mask R-CNN`.

        Arguments:
            boxes: Nx4 tensor, the boxes to draw the keypoints to
            heatmap_size: integer side length of the square heatmap

        Returns:
            heatmaps:
                A tensor of shape (N, K), each element is integer spatial label
                in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
            valid:
                A tensor of shape (N, K) containing whether each keypoint is in the roi or not.
        """
        return _keypoints_to_heatmap(self.tensor, boxes, heatmap_size)

    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints":
        """
        Create a new `Keypoints` by indexing on this `Keypoints`.

        The following usage are allowed:

        1. `new_kpts = kpts[3]`: return a `Keypoints` which contains only one instance.
        2. `new_kpts = kpts[2:10]`: return a slice of key points.
        3. `new_kpts = kpts[vector]`, where vector is a torch.BoolTensor
           with `length = len(kpts)`. True elements in the vector will be selected.

        Note that the returned Keypoints might share storage with this Keypoints,
        subject to Pytorch's indexing semantics.
        """
        if isinstance(item, int):
            # Integer indexing drops the instance axis; put it back as a view.
            # Wrapping the row in a Python list instead would send it through
            # the constructor's non-tensor path, which pins the device to CPU
            # and copies the data element by element.
            return Keypoints(self.tensor[item].unsqueeze(0))
        return Keypoints(self.tensor[item])

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(num_instances={len(self.tensor)})"

    @staticmethod
    def cat(keypoints_list: List["Keypoints"]) -> "Keypoints":
        """
        Concatenates a list of Keypoints into a single Keypoints

        Arguments:
            keypoints_list (list[Keypoints])

        Returns:
            Keypoints: the concatenated Keypoints, of the same type as the
            first element of the list
        """
        assert isinstance(keypoints_list, (list, tuple))
        assert len(keypoints_list) > 0
        assert all(isinstance(keypoints, Keypoints) for keypoints in keypoints_list)

        # Stack along the instance axis; K must agree across inputs for
        # torch.cat to succeed.
        cat_kpts = type(keypoints_list[0])(
            torch.cat([kpts.tensor for kpts in keypoints_list], dim=0)
        )
        return cat_kpts
|
||||
|
||||
|
||||
def _keypoints_to_heatmap(
|
||||
keypoints: torch.Tensor, rois: torch.Tensor, heatmap_size: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Encode keypoint locations into a target heatmap for use in SoftmaxWithLoss across space.
|
||||
|
||||
Maps keypoints from the half-open interval [x1, x2) on continuous image coordinates to the
|
||||
closed interval [0, heatmap_size - 1] on discrete image coordinates. We use the
|
||||
continuous-discrete conversion from Heckbert 1990 ("What is the coordinate of a pixel?"):
|
||||
d = floor(c) and c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
|
||||
|
||||
Arguments:
|
||||
keypoints: tensor of keypoint locations in of shape (N, K, 3).
|
||||
rois: Nx4 tensor of rois in xyxy format
|
||||
heatmap_size: integer side length of square heatmap.
|
||||
|
||||
Returns:
|
||||
heatmaps: A tensor of shape (N, K) containing an integer spatial label
|
||||
in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
|
||||
valid: A tensor of shape (N, K) containing whether each keypoint is in
|
||||
the roi or not.
|
||||
"""
|
||||
|
||||
if rois.numel() == 0:
|
||||
return rois.new().long(), rois.new().long()
|
||||
offset_x = rois[:, 0]
|
||||
offset_y = rois[:, 1]
|
||||
scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
|
||||
scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
|
||||
|
||||
offset_x = offset_x[:, None]
|
||||
offset_y = offset_y[:, None]
|
||||
scale_x = scale_x[:, None]
|
||||
scale_y = scale_y[:, None]
|
||||
|
||||
x = keypoints[..., 0]
|
||||
y = keypoints[..., 1]
|
||||
|
||||
x_boundary_inds = x == rois[:, 2][:, None]
|
||||
y_boundary_inds = y == rois[:, 3][:, None]
|
||||
|
||||
x = (x - offset_x) * scale_x
|
||||
x = x.floor().long()
|
||||
y = (y - offset_y) * scale_y
|
||||
y = y.floor().long()
|
||||
|
||||
x[x_boundary_inds] = heatmap_size - 1
|
||||
y[y_boundary_inds] = heatmap_size - 1
|
||||
|
||||
valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
|
||||
vis = keypoints[..., 2] > 0
|
||||
valid = (valid_loc & vis).long()
|
||||
|
||||
lin_ind = y * heatmap_size + x
|
||||
heatmaps = lin_ind * valid
|
||||
|
||||
return heatmaps, valid
|
||||
|
||||
|
||||
@torch.jit.script_if_tracing
def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
    """
    Extract predicted keypoint locations from heatmaps.

    Each ROI's heatmap is resized (bicubic) to the ROI's own pixel size, the
    arg-max location of every keypoint channel is taken, and that location is
    mapped back to image coordinates.

    Args:
        maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for
            each ROI and each keypoint.
        rois (Tensor): (#ROIs, 4). The box of each ROI. Columns 0/1 are used as the
            x/y origin and columns 2/3 as the far x/y edge.

    Returns:
        Tensor of shape (#ROIs, #keypoints, 4) with the last dimension corresponding to
        (x, y, logit, score) for each keypoint.

    When converting discrete pixel indices in an NxN image to a continuous keypoint coordinate,
    we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from
    Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
    """

    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    # Box sizes are clamped to at least 1 so the resize target below is never empty.
    widths = (rois[:, 2] - rois[:, 0]).clamp(min=1)
    heights = (rois[:, 3] - rois[:, 1]).clamp(min=1)
    # The resize target must be an integer size, so round the box size up.
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()

    num_rois, num_keypoints = maps.shape[:2]
    xy_preds = maps.new_zeros(rois.shape[0], num_keypoints, 4)

    # Ratio of true box size to the rounded-up grid size; used to scale grid
    # coordinates back to the real box extent.
    width_corrections = widths / widths_ceil
    height_corrections = heights / heights_ceil

    # One index per keypoint channel, for gathering the value at each channel's peak.
    keypoints_idx = torch.arange(num_keypoints, device=maps.device)

    for i in range(num_rois):
        outsize = (int(heights_ceil[i]), int(widths_ceil[i]))
        # Indexing with a list (`[[i]]`) keeps the leading ROI axis, so the
        # input to `interpolate` stays 4-D.
        roi_map = F.interpolate(
            maps[[i]], size=outsize, mode="bicubic", align_corners=False
        )

        # Although semantically equivalent, `reshape` is used instead of `squeeze` due
        # to limitation during ONNX export of `squeeze` in scripting mode
        roi_map = roi_map.reshape(roi_map.shape[1:])  # keypoints x H x W

        # softmax over the spatial region
        # The per-keypoint maximum is subtracted before exponentiating
        # (presumably for numerical stability); it cancels in the ratio below.
        max_score, _ = roi_map.view(num_keypoints, -1).max(1)
        max_score = max_score.view(num_keypoints, 1, 1)
        # `exp_` is in-place on the freshly created difference tensors only;
        # `roi_map` and `maps` themselves are not modified.
        tmp_full_resolution = (roi_map - max_score).exp_()
        tmp_pool_resolution = (maps[i] - max_score).exp_()
        # Produce scores over the region H x W, but normalize with POOL_H x POOL_W,
        # so that the scores of objects of different absolute sizes will be more comparable
        roi_map_scores = tmp_full_resolution / tmp_pool_resolution.sum(
            (1, 2), keepdim=True
        )

        # Flattened arg-max per keypoint, then split into column / row indices.
        w = roi_map.shape[2]
        pos = roi_map.view(num_keypoints, -1).argmax(1)

        x_int = pos % w
        y_int = (pos - x_int) // w

        # Sanity check: the score at the logit arg-max equals the maximum score.
        assert (
            roi_map_scores[keypoints_idx, y_int, x_int]
            == roi_map_scores.view(num_keypoints, -1).max(1)[0]
        ).all()

        # Discrete -> continuous (c = d + 0.5), rescaled to the true box size.
        x = (x_int.float() + 0.5) * width_corrections[i]
        y = (y_int.float() + 0.5) * height_corrections[i]

        # Shift by the box origin to obtain image coordinates, then store
        # (x, y, logit at the peak, normalized score at the peak).
        xy_preds[i, :, 0] = x + offset_x[i]
        xy_preds[i, :, 1] = y + offset_y[i]
        xy_preds[i, :, 2] = roi_map[keypoints_idx, y_int, x_int]
        xy_preds[i, :, 3] = roi_map_scores[keypoints_idx, y_int, x_int]

    return xy_preds
|
||||
Reference in New Issue
Block a user