Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
113
sam3/train/loss/mask_sampling.py
Normal file
113
sam3/train/loss/mask_sampling.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
|
||||
def point_sample(input, point_coords, **kwargs):
|
||||
"""
|
||||
A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
|
||||
Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
|
||||
[0, 1] x [0, 1] square.
|
||||
|
||||
Args:
|
||||
input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
|
||||
point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
|
||||
[0, 1] x [0, 1] normalized point coordinates.
|
||||
|
||||
Returns:
|
||||
output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
|
||||
features for points in `point_coords`. The features are obtained via bilinear
|
||||
interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
|
||||
"""
|
||||
add_dim = False
|
||||
if point_coords.dim() == 3:
|
||||
add_dim = True
|
||||
point_coords = point_coords.unsqueeze(2)
|
||||
normalized_point_coords = 2.0 * point_coords - 1.0 # Normalize to [-1,1]
|
||||
output = F.grid_sample(input, normalized_point_coords, **kwargs)
|
||||
if add_dim:
|
||||
output = output.squeeze(3)
|
||||
return output
|
||||
|
||||
|
||||
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
|
||||
def get_uncertain_point_coords_with_randomness(
|
||||
logits: torch.Tensor,
|
||||
uncertainty_func: Callable,
|
||||
num_points: int,
|
||||
oversample_ratio: int,
|
||||
importance_sample_ratio: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties
|
||||
are calculated for each point using 'uncertainty_func' function that takes point's logit
|
||||
prediction as input.
|
||||
See PointRend paper for details.
|
||||
|
||||
Args:
|
||||
logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
|
||||
class-specific or class-agnostic prediction.
|
||||
uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
|
||||
contains logit predictions for P points and returns their uncertainties as a Tensor of
|
||||
shape (N, 1, P).
|
||||
num_points (int): The number of points P to sample.
|
||||
oversample_ratio (int): Oversampling parameter.
|
||||
importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling.
|
||||
|
||||
Returns:
|
||||
point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
|
||||
sampled points.
|
||||
"""
|
||||
assert oversample_ratio >= 1
|
||||
assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
|
||||
num_boxes = logits.shape[0]
|
||||
num_sampled = int(num_points * oversample_ratio)
|
||||
point_coords = torch.rand(num_boxes, num_sampled, 2, device=logits.device)
|
||||
point_logits = point_sample(logits, point_coords, align_corners=False)
|
||||
# It is crucial to calculate uncertainty based on the sampled prediction value for the points.
|
||||
# Calculating uncertainties of the predictions first and sampling them for points leads
|
||||
# to incorrect results.
|
||||
# To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
|
||||
# two predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
|
||||
# However, if we calculate uncertainties for the predictions first,
|
||||
# both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
|
||||
point_uncertainties = uncertainty_func(point_logits)
|
||||
num_uncertain_points = int(importance_sample_ratio * num_points)
|
||||
num_random_points = num_points - num_uncertain_points
|
||||
idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
|
||||
# Flatten the indices
|
||||
shift = num_sampled * torch.arange(
|
||||
num_boxes, dtype=torch.long, device=logits.device
|
||||
)
|
||||
idx += shift[:, None]
|
||||
point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
|
||||
num_boxes, num_uncertain_points, 2
|
||||
)
|
||||
if num_random_points > 0:
|
||||
point_coords = torch.cat(
|
||||
[
|
||||
point_coords,
|
||||
torch.rand(num_boxes, num_random_points, 2, device=logits.device),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
return point_coords
|
||||
|
||||
|
||||
# Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py
|
||||
def calculate_uncertainty(logits: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Estimates uncerainty as L1 distance between 0.0 and the logit prediction.
|
||||
Args:
|
||||
logits (Tensor): A tensor of shape (R, 1, ...) for class-agnostic
|
||||
predicted masks
|
||||
Returns:
|
||||
scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
|
||||
the most uncertain locations having the highest uncertainty score.
|
||||
"""
|
||||
assert logits.shape[1] == 1
|
||||
return -(torch.abs(logits))
|
||||
Reference in New Issue
Block a user