Initial commit

fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
facebook-github-bot
2025-11-18 23:07:42 -08:00
commit a13e358df4
504 changed files with 122758 additions and 0 deletions

View File

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

1319
sam3/train/loss/loss_fns.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Callable
import torch
from torch.nn import functional as F
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
def point_sample(input, point_coords, **kwargs):
    """
    Wrapper around :func:`torch.nn.functional.grid_sample` that accepts 3D
    ``point_coords`` tensors and expects coordinates normalized to the
    [0, 1] x [0, 1] square (instead of grid_sample's [-1, 1] range).

    Args:
        input (Tensor): A (N, C, H, W) feature map on an H x W grid.
        point_coords (Tensor): (N, P, 2) or (N, Hgrid, Wgrid, 2) point
            coordinates in [0, 1] x [0, 1].

    Returns:
        Tensor: (N, C, P) or (N, C, Hgrid, Wgrid) features bilinearly
        interpolated from ``input`` at ``point_coords``, exactly as
        :func:`torch.nn.functional.grid_sample` would produce.
    """
    squeeze_after = point_coords.dim() == 3
    if squeeze_after:
        # grid_sample needs a 4D grid; insert a dummy Wgrid dimension.
        point_coords = point_coords.unsqueeze(2)
    # Map [0, 1] coordinates into grid_sample's expected [-1, 1] range.
    sampled = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
    return sampled.squeeze(3) if squeeze_after else sampled
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
def get_uncertain_point_coords_with_randomness(
    logits: torch.Tensor,
    uncertainty_func: Callable,
    num_points: int,
    oversample_ratio: int,
    importance_sample_ratio: float,
) -> torch.Tensor:
    """
    Sample ``num_points`` coordinates in the [0, 1] x [0, 1] square, biased
    towards locations where the prediction is uncertain.

    Following the PointRend paper, an oversampled candidate set is drawn
    uniformly first; the most uncertain candidates (scored by
    ``uncertainty_func``) are kept via importance sampling, and the rest of
    the budget is filled with fresh uniform samples.

    Args:
        logits (Tensor): (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask)
            class-specific or class-agnostic predictions.
        uncertainty_func: Maps point logits of shape (N, C, P) or (N, 1, P)
            to uncertainties of shape (N, 1, P).
        num_points (int): Number of points P to sample per item.
        oversample_ratio (int): Oversampling factor (must be >= 1).
        importance_sample_ratio (float): Fraction of points taken from the
            most uncertain candidates (must lie in [0, 1]).

    Returns:
        Tensor: (N, P, 2) sampled point coordinates.
    """
    assert oversample_ratio >= 1
    assert 0 <= importance_sample_ratio <= 1
    num_boxes = logits.shape[0]
    num_sampled = int(num_points * oversample_ratio)

    # Draw the oversampled candidate set and evaluate the logits there.
    candidates = torch.rand(num_boxes, num_sampled, 2, device=logits.device)
    candidate_logits = point_sample(logits, candidates, align_corners=False)
    # Uncertainty must be computed on the *sampled* logit values: sampling
    # first and scoring second is not equivalent to scoring the dense map
    # and interpolating the scores. E.g. with uncertainty = -|logit|, a
    # point midway between logits -1 and 1 has logit 0 and uncertainty 0,
    # whereas interpolating the scores would wrongly give it -1.
    uncertainties = uncertainty_func(candidate_logits)

    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points

    # Keep the top-k most uncertain candidates for each item.
    idx = torch.topk(uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    # Offset per-item indices so they address the flattened candidate list.
    idx = idx + num_sampled * torch.arange(
        num_boxes, dtype=torch.long, device=logits.device
    ).unsqueeze(1)
    point_coords = candidates.view(-1, 2)[idx.view(-1), :].view(
        num_boxes, num_uncertain_points, 2
    )

    # Fill the remaining budget with fresh uniform samples.
    if num_random_points > 0:
        extra = torch.rand(num_boxes, num_random_points, 2, device=logits.device)
        point_coords = torch.cat([point_coords, extra], dim=1)
    return point_coords
# Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py
def calculate_uncertainty(logits: torch.Tensor) -> torch.Tensor:
    """
    Estimate uncertainty as the negated L1 distance between the logit
    prediction and 0.0 (the decision boundary): logits closest to zero
    receive the highest score.

    Args:
        logits (Tensor): A (R, 1, ...) tensor of class-agnostic predicted
            mask logits.

    Returns:
        Tensor: A (R, 1, ...) tensor of uncertainty scores; the most
        uncertain locations have the highest values.
    """
    assert logits.shape[1] == 1
    return logits.abs().neg()

View File

@@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import torch
from sam3.model.model_misc import SAM3Output
from sam3.train.utils.distributed import get_world_size
from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks
class DummyLoss(torch.nn.Module):
    """Placeholder loss for eval runs: always reports a zero core loss."""

    def __init__(
        self,
        core_loss_key: str = CORE_LOSS_KEY,
        device: str = "cuda",
        **kwargs,
    ):
        """
        Args:
            core_loss_key: key under which the zero loss is reported.
            device: device on which the zero tensors are created.
            **kwargs: ignored (kept for config compatibility).
        """
        super().__init__()
        self.core_loss_key = core_loss_key
        self.device = torch.device(device)

    def forward(self, *args, **kwargs):
        # Ignore all inputs and report a constant zero core loss.
        return {self.core_loss_key: torch.tensor(0.0, device=self.device)}

    def accumulate(self, out_dict):
        """Hook called by iterative losses; ensures the core key exists."""
        out_dict.setdefault(
            self.core_loss_key, torch.tensor(0.0, device=self.device)
        )
        return out_dict
class Sam3LossWrapper(torch.nn.Module):
    """Top-level SAM3 "find" loss: matches predictions to targets and sums
    the configured loss terms over the main output, auxiliary decoder
    outputs, and a first-stage output, with optional one-to-many (o2m)
    query losses and an optional semantic-segmentation loss.
    """

    def __init__(
        self,
        loss_fns_find,
        normalization="global",
        matcher=None,
        o2m_matcher=None,
        o2m_weight=1.0,
        use_o2m_matcher_on_o2m_aux=True,
        loss_fn_semantic_seg=None,
        normalize_by_valid_object_num=False,
        normalize_by_stage_num=False,
        scale_by_find_batch_size=False,
    ):
        """
        Args:
            loss_fns_find: iterable of loss modules; each is called with
                (outputs, targets, indices, num_boxes, is_aux) and returns a
                dict of loss terms that includes CORE_LOSS_KEY.
            normalization: how the target-box count normalizer is computed:
                "global" (averaged across workers via all_reduce), "local"
                (per-worker), or "none" (normalizer fixed to 1).
            matcher: one-to-one matcher, used for o2m queries of auxiliary
                outputs when use_o2m_matcher_on_o2m_aux is False.
            o2m_matcher: one-to-many matcher applied to o2m query outputs.
            o2m_weight: scalar multiplier applied to every o2m loss term.
            use_o2m_matcher_on_o2m_aux: if True, auxiliary o2m outputs are
                matched with o2m_matcher; otherwise with the o2o matcher.
            loss_fn_semantic_seg: optional extra loss called on each
                (outputs, targets) pair; its terms must not collide with
                the find loss terms.
            normalize_by_valid_object_num: if True, only boxes with
                non-zero width and height count towards the normalizer.
            normalize_by_stage_num: if True, divide the core loss by the
                number of find stages (balances image vs. video batches).
            scale_by_find_batch_size: if True, scale the core loss by
                sqrt(batch size).
        """
        super().__init__()
        self.loss_fns_find = loss_fns_find
        assert normalization in ["global", "local", "none"]
        self.normalization = normalization
        self.normalize_by_valid_object_num = normalize_by_valid_object_num
        self.normalize_by_stage_num = normalize_by_stage_num
        self.matcher = matcher
        self.o2m_matcher = o2m_matcher
        self.o2m_weight = o2m_weight
        # whether to use the o2m matcher on the o2m queries in auxiliary outputs
        self.use_o2m_matcher_on_o2m_aux = use_o2m_matcher_on_o2m_aux
        self.loss_fn_semantic_seg = loss_fn_semantic_seg
        self.scale_by_find_batch_size = scale_by_find_batch_size

    def _get_num_boxes(self, targets):
        """Return the target-box count used as the loss normalizer (clamped
        to at least 1; averaged across workers when normalization=global)."""
        # the average number of target boxes for loss normalization
        if self.normalize_by_valid_object_num:
            # valid boxes are those with non-zero height and width
            # (padded/invisible boxes are excluded by the w/h > 0 check below)
            boxes_hw = targets["boxes"].view(-1, 4) # cx, cy, w, h
            num_boxes = (boxes_hw[:, 2:] > 0).all(dim=-1).sum().float()
        else:
            num_boxes = targets["num_boxes"].sum().float()
        if self.normalization == "global":
            # NOTE(review): assumes torch.distributed is initialized — confirm
            torch.distributed.all_reduce(num_boxes)
            num_boxes = torch.clamp(num_boxes / get_world_size(), min=1)
        elif self.normalization == "local":
            num_boxes = torch.clamp(num_boxes, min=1)
        elif self.normalization == "none":
            num_boxes = 1
        return num_boxes

    def compute_loss(self, nested_out, targets):
        """Compute all configured losses for one output dict (plus its
        auxiliary and first-stage outputs) against one target dict.

        Returns:
            dict mapping suffixed loss names to values, with the summed
            core loss under CORE_LOSS_KEY.
        """
        num_boxes = self._get_num_boxes(targets)
        # validity masks for the one-to-many (o2m) queries, if the model provides them
        o2m_out_is_valid = nested_out.get("o2m_out_is_valid", None)
        o2m_target_is_valid_padded = nested_out.get("o2m_target_is_valid_padded", None)
        # Get a list of outputs, including auxiliary and first stage outputs
        output_list = [(nested_out, "", False)] # (out, suffix, is_aux)
        if "aux_outputs" in nested_out:
            output_list.extend(
                (aux_out, f"_aux_{i}", True)
                for i, aux_out in enumerate(nested_out["aux_outputs"])
            )
        if "first_stage" in nested_out:
            output_list.append((nested_out["first_stage"], "_fs", True))
        # Compute all the requested losses
        losses = {}
        total_core_loss = 0.0
        for out, suffix, is_aux in output_list:
            # o2o matcher indices need to be computed by the model (as the video model requires
            # a specific way of matching free and locked indices beyond just calling the matcher)
            indices = out["indices"]
            has_o2m_out = "pred_logits_o2m" in out
            if has_o2m_out:
                # strip the "_o2m" suffix so o2m outputs look like a standard output dict
                o2m_out = {
                    k[: -len("_o2m")]: v for k, v in out.items() if k.endswith("_o2m")
                }
                # o2m targets are the same as the o2o targets (assuming repeat=1)
                o2m_targets = targets
                if self.use_o2m_matcher_on_o2m_aux or not is_aux:
                    o2m_indices = self.o2m_matcher(
                        o2m_out,
                        o2m_targets,
                        out_is_valid=o2m_out_is_valid,
                        target_is_valid_padded=o2m_target_is_valid_padded,
                    )
                else:
                    # aux o2m outputs fall back to the one-to-one matcher
                    o2m_indices = self.matcher(
                        o2m_out,
                        o2m_targets,
                        out_is_valid=o2m_out_is_valid,
                        target_is_valid_padded=o2m_target_is_valid_padded,
                    )
            for loss_fn in self.loss_fns_find:
                l_dict = loss_fn(
                    outputs=out,
                    targets=targets,
                    indices=indices,
                    num_boxes=num_boxes,
                    is_aux=is_aux,
                )
                total_core_loss += l_dict.pop(CORE_LOSS_KEY)
                losses.update({f"{k}{suffix}": v for k, v in l_dict.items()})
                compute_o2m_loss = has_o2m_out
                # a special handling to allow turning off mask loss in o2m
                # (to be compatible with the original implementation)
                if isinstance(loss_fn, Masks):
                    compute_o2m_loss = compute_o2m_loss and "pred_masks" in o2m_out
                if isinstance(loss_fn, Det2TrkAssoc):
                    compute_o2m_loss = False # Det2TrkAssoc does not support o2m
                if compute_o2m_loss:
                    l_dict = loss_fn(
                        outputs=o2m_out,
                        targets=o2m_targets,
                        indices=o2m_indices,
                        num_boxes=num_boxes,
                        is_aux=is_aux,
                    )
                    # re-weight every o2m term before folding it into the totals
                    for k in l_dict:
                        l_dict[k] *= self.o2m_weight
                    total_core_loss += l_dict.pop(CORE_LOSS_KEY)
                    losses.update({f"{k}{suffix}_o2m": v for k, v in l_dict.items()})
        losses[CORE_LOSS_KEY] = total_core_loss
        return losses

    def forward(self, find_stages: SAM3Output, find_targets):
        """Sum the losses over all find stages (and all steps within each
        stage, e.g. interactive refinement steps).

        Args:
            find_stages: SAM3Output holding per-stage model outputs.
            find_targets: per-stage targets aligned with find_stages.

        Returns:
            dict of loss terms accumulated over all stages and steps.
        """
        if find_stages.loss_stages is not None:
            # restrict the loss to the requested subset of stages
            find_targets = [find_targets[i] for i in find_stages.loss_stages]
        with SAM3Output.iteration_mode(
            find_stages, iter_mode=SAM3Output.IterMode.ALL_STEPS_PER_STAGE
        ) as find_stages:
            assert len(find_stages) == len(find_targets)
            total_losses = {}
            for stage_outputs, stage_targets in zip(find_stages, find_targets):
                # the same targets are reused for every step within a stage
                stage_targets = [stage_targets] * len(stage_outputs)
                # If there are multiple steps within a stage, compute the loss for all of them (e.g. interactivity)
                for outputs, targets in zip(stage_outputs, stage_targets):
                    cur_losses = self.compute_loss(outputs, targets)
                    if self.loss_fn_semantic_seg is not None:
                        cur_losses_semantic = self.loss_fn_semantic_seg(
                            outputs, targets
                        )
                        cur_losses[CORE_LOSS_KEY] += cur_losses_semantic.pop(
                            CORE_LOSS_KEY
                        )
                        # make sure the semantic losses don't overlap with the find losses
                        assert set(cur_losses).isdisjoint(set(cur_losses_semantic))
                        cur_losses.update(cur_losses_semantic)
                    # Optionally, normalize the loss by the number of find stages (training video frames) so that
                    # image batches and video batches have similar loss scales. (Otherwise video batches would
                    # have a much higher loss scale due to summing the losses over all the find stages.)
                    if self.normalize_by_stage_num:
                        cur_losses[CORE_LOSS_KEY] /= len(find_stages)
                    if self.scale_by_find_batch_size:
                        bs = targets["num_boxes"].shape[0]
                        # sqrt scaling based on the "effective" batch size
                        cur_losses[CORE_LOSS_KEY] *= bs**0.5
                    # accumulate each loss term across stages/steps
                    for k, v in cur_losses.items():
                        if k not in total_losses:
                            total_losses[k] = v
                        else:
                            total_losses[k] += v
        return total_losses

View File

@@ -0,0 +1,321 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""Triton kernel for faster and memory efficient sigmoid focal loss"""
import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import libdevice
"""
The sigmoid focal loss is defined as:
prob = inputs.sigmoid()
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * ce_loss * ((1 - p_t) ** gamma)
Where alpha and gamma are scalar parameters, inputs are the logits, targets the float targets.
We implement two versions of the sigmoid focal loss: with and without sum reduction.
The latter is implemented with built-in reduction to avoid materializing wrt the output of the loss.
This can help save a bit of peak memory.
The reduction version is implemented using somewhat of a hack. PyTorch's generated kernels usually do the point-wise operation in a first kernel, and implement the reduction in a second kernel launched on a grid of size 1, where the reduction happens as a for loop inside the Triton kernel.
Since we want to fuse those two kernels, that is not a good idea: we'd have to launch the overall kernel on a grid of size 1, which is obviously inefficient.
On the other hand, typical CUDA algorithms for reduction (eg reduction tree) are hard to implement in triton due to the lack of thread sync primitives.
We settle for a version that abuses triton's atomic_add: we can have all threads simply add to the same location.
In practice, this is not good, since it creates a massive bottleneck on the semaphore for that single memory location. So instead, we create M reduction locations. Each thread will simply write to thread_id%M. The python code can finally sum over the M reductions.
M = 32 works fine in benchmarking tests. The forward is a tiny bit slower compared to the non-reduced kernel, but the backward breaks even due to one less memory allocation.
"""
# Per-element focal loss value: alpha_t * BCE(inputs, targets) * (1 - p_t)**gamma,
# where `inputs` are logits and `targets` are float labels in [0, 1].
# (Comment placed outside the body: Triton-jitted functions take # comments only.)
@triton.jit
def _inner_focal_loss_fwd(inputs, targets, alpha, gamma):
    inv_targets = 1 - targets
    # Sigmoid
    sig = tl.sigmoid(inputs)
    # Binary cross entropy with logits
    # In practice, we want the following:
    # bce_loss = -targets * tl.log(sig) - (1 - targets) * tl.log(1 - sig)
    # However, the above is not numerically stable.
    # We're also not directly taking the sum here, so the usual log-sum-exp trick doesn't apply
    # The bce can be reformulated, after algebraic manipulation, to
    # bce_loss = log(1 + exp(-x)) + x * (1-y)
    # This is still not stable, because for large (-x) the exponential will blow up.
    # We'll use the following alternate formulation:
    # bce_loss = max(x, 0) - x * y + log(1 + exp(-abs(x)))
    # Let's show that it's equivalent:
    # Case x>=0: abs(x) = x , max(x, 0) = x
    #   so we get x - x * y + log(1 + exp(-x)) which is equivalent
    # Case x<0: abs(x) = -x, max(x, 0) = 0
    #   we have log(1 + exp(-abs(x))) = log(1 + exp(x)) = log(exp(x)(1 + exp(-x))) = x+log(1 + exp(-x))
    #   plugging it in, we get
    #   0 - x * y + x + log(1 + exp(-x)), which is also equivalent
    # Note that this is stable because now the exponents are guaranteed to be below 0.
    # clamp with a large finite upper bound stands in for max(x, 0)
    max_val = tl.clamp(inputs, min=0, max=1e9)
    bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))
    # Modulating factor: p_t is the probability assigned to the true class
    p_t = sig * targets + (1 - sig) * inv_targets
    mod_factor = libdevice.pow(1 - p_t, gamma)
    # Alpha factor
    alpha_t = alpha * targets + (1 - alpha) * inv_targets
    # Final loss calculation
    return alpha_t * mod_factor * bce_loss
# Non-reduced version: each program handles one BLOCK_SIZE chunk of the
# flattened tensors and writes the per-element loss values to loss_ptr.
@triton.jit
def sigmoid_focal_loss_fwd_kernel(
    inputs_ptr,
    targets_ptr,
    loss_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    # guard against reading/writing past the end of the flattened tensors
    mask = offset < n_elements
    # Load data (compute in float32 regardless of the input dtype)
    inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
    targets = tl.load(targets_ptr + offset, mask=mask)
    final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma)
    # Store result
    tl.store(loss_ptr + offset, final_loss, mask=mask)
# version with reduction: each program sums its BLOCK_SIZE chunk and
# atomically adds the partial sum to one of REDUCE_SIZE slots in loss_ptr
# (slot = pid % REDUCE_SIZE) to limit contention on a single atomic target.
# The caller sums the REDUCE_SIZE slots on the host side.
@triton.jit
def sigmoid_focal_loss_fwd_kernel_reduce(
    inputs_ptr,
    targets_ptr,
    loss_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
    REDUCE_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    reduce_loc = pid % REDUCE_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    # Load data
    inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
    targets = tl.load(targets_ptr + offset, mask=mask)
    # multiply by the mask so out-of-bounds lanes contribute 0 to the sum
    final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma) * mask
    fl = tl.sum(final_loss)
    # Store result (atomic: many programs share the same reduction slot)
    tl.atomic_add(loss_ptr + reduce_loc, fl)
# d(loss)/d(inputs) for the focal loss, recomputing the forward quantities
# instead of saving them (trades compute for memory).
@triton.jit
def _inner_focal_loss_bwd(inputs, targets, alpha, gamma):
    inv_targets = 1 - targets
    # Recompute forward
    max_val = tl.clamp(inputs, min=0, max=1e9)
    bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))
    # Sigmoid
    sig = tl.sigmoid(inputs)
    inv_sig = 1 - sig
    # Modulating factor
    p_t = sig * targets + inv_sig * inv_targets
    # tmp = (1 - p_t)**(gamma - 1), reused by the factor and its derivative
    tmp = libdevice.pow(1 - p_t, gamma - 1)
    mod_factor = tmp * (1 - p_t)
    # Alpha factor
    alpha_t = alpha * targets + (1 - alpha) * inv_targets
    # Now computing the derivatives
    # d(p_t)/dx: sigmoid'(x) = sig * (1 - sig), sign flips for the negative class
    d_pt = (2 * targets - 1) * sig * inv_sig
    d_mod_factor = -gamma * d_pt * tmp
    # d(bce)/dx = sigmoid(x) - y
    d_bce_loss = sig - targets
    # product rule over mod_factor * bce_loss, scaled by alpha_t
    return alpha_t * (d_bce_loss * mod_factor + d_mod_factor * bce_loss)
# Backward kernel for the non-reduced loss: upstream gradients are
# per-element, so grad_out is indexed with the same offsets as the inputs.
@triton.jit
def sigmoid_focal_loss_bwd_kernel(
    inputs_ptr,
    targets_ptr,
    grad_inputs_ptr,
    grad_out_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    # guard against reading/writing past the end of the flattened tensors
    mask = offset < n_elements
    input_ptrs = inputs_ptr + offset
    target_ptrs = targets_ptr + offset
    grad_input_ptrs = grad_inputs_ptr + offset
    grad_out_ptrs = grad_out_ptr + offset
    # Load data
    inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
    targets = tl.load(target_ptrs, mask=mask)
    grad_out = tl.load(grad_out_ptrs, mask=mask)
    # chain rule: upstream gradient times the local derivative
    d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
    tl.store(grad_input_ptrs, d_loss, mask=mask)
# Backward kernel for the sum-reduced loss.
@triton.jit
def sigmoid_focal_loss_bwd_kernel_reduce(
    inputs_ptr,
    targets_ptr,
    grad_inputs_ptr,
    grad_out_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    # The only difference is that the gradient is now a single scalar
    # (grad_out_ptr points to one element, broadcast over the whole block).
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    input_ptrs = inputs_ptr + offset
    target_ptrs = targets_ptr + offset
    grad_input_ptrs = grad_inputs_ptr + offset
    # Load data
    inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
    targets = tl.load(target_ptrs, mask=mask)
    # scalar upstream gradient: loaded without per-lane offsets
    grad_out = tl.load(grad_out_ptr)
    d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
    tl.store(grad_input_ptrs, d_loss, mask=mask)
class SigmoidFocalLoss(torch.autograd.Function):
    """Element-wise (unreduced) sigmoid focal loss backed by Triton kernels.

    Computes ``alpha_t * BCE(inputs, targets) * (1 - p_t) ** gamma`` for
    every element and returns it as a float32 tensor of the input shape.
    Gradients flow to ``inputs`` only.
    """

    # Work-group size for the 1D kernels.
    BLOCK_SIZE = 256

    @staticmethod
    def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
        """Compute the per-element focal loss.

        Args:
            inputs: Logits tensor of any shape.
            targets: Float targets with the same number of elements.
            alpha: Class balance factor.
            gamma: Focusing exponent.

        Returns:
            float32 tensor with the same shape as ``inputs``.
        """
        n_elements = inputs.numel()
        assert targets.numel() == n_elements
        input_shape = inputs.shape
        # reshape (not view): handles non-contiguous inputs by copying.
        inputs = inputs.reshape(-1).contiguous()
        targets = targets.reshape(-1).contiguous()
        loss = torch.empty(inputs.shape, dtype=torch.float32, device=inputs.device)
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_fwd_kernel[grid](
            inputs, targets, loss, alpha, gamma, n_elements, SigmoidFocalLoss.BLOCK_SIZE
        )
        ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
        ctx.alpha = alpha
        ctx.gamma = gamma
        return loss.view(input_shape)

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass; per-element upstream gradients are expected."""
        inputs, targets = ctx.saved_tensors
        alpha = ctx.alpha
        gamma = ctx.gamma
        n_elements = inputs.numel()
        input_shape = inputs.shape
        grad_inputs = torch.empty(
            inputs.shape, dtype=grad_output.dtype, device=grad_output.device
        )
        inputs_ptr = inputs.view(-1).contiguous()
        targets_ptr = targets.view(-1).contiguous()
        # reshape: upstream gradients are not guaranteed to be contiguous.
        grad_output_ptr = grad_output.reshape(-1).contiguous()
        grad_inputs_ptr = grad_inputs
        assert grad_output.numel() == n_elements
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_bwd_kernel[grid](
            inputs_ptr,
            targets_ptr,
            grad_inputs_ptr,
            grad_output_ptr,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLoss.BLOCK_SIZE,
        )
        # No gradients for targets / alpha / gamma.
        return grad_inputs.view(input_shape), None, None, None


triton_sigmoid_focal_loss = SigmoidFocalLoss.apply
class SigmoidFocalLossReduced(torch.autograd.Function):
    """Sigmoid focal loss with a fused sum reduction, backed by Triton.

    The forward kernel atomically accumulates per-block partial sums into
    ``REDUCE_SIZE`` slots (slot = program id % REDUCE_SIZE) to avoid
    contending on a single memory location; the slots are summed here into
    a float32 scalar. Gradients flow to ``inputs`` only.
    """

    # Work-group size for the 1D kernels.
    BLOCK_SIZE = 256
    # Number of partial-sum slots used by the atomic reduction.
    REDUCE_SIZE = 32

    @staticmethod
    def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
        """Compute the summed focal loss.

        Args:
            inputs: Logits tensor of any shape.
            targets: Float targets with the same number of elements.
            alpha: Class balance factor.
            gamma: Focusing exponent.

        Returns:
            float32 scalar tensor: the sum of the per-element losses.
        """
        n_elements = inputs.numel()
        # Consistency with SigmoidFocalLoss: element counts must agree.
        assert targets.numel() == n_elements
        input_shape = inputs.shape
        # reshape (not view): handles non-contiguous inputs by copying.
        inputs = inputs.reshape(-1).contiguous()
        targets = targets.reshape(-1).contiguous()
        # zeros (not empty): the kernel accumulates via tl.atomic_add.
        loss = torch.zeros(
            SigmoidFocalLossReduced.REDUCE_SIZE,
            device=inputs.device,
            dtype=torch.float32,
        )
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_fwd_kernel_reduce[grid](
            inputs,
            targets,
            loss,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLossReduced.BLOCK_SIZE,
            SigmoidFocalLossReduced.REDUCE_SIZE,
        )
        ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
        ctx.alpha = alpha
        ctx.gamma = gamma
        # Sum the partial-reduction slots into the final scalar.
        return loss.sum()

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass; ``grad_output`` is a scalar broadcast to all elements."""
        inputs, targets = ctx.saved_tensors
        alpha = ctx.alpha
        gamma = ctx.gamma
        n_elements = inputs.numel()
        input_shape = inputs.shape
        grad_inputs = torch.empty(
            inputs.shape, dtype=grad_output.dtype, device=grad_output.device
        )
        inputs_ptr = inputs.view(-1).contiguous()
        targets_ptr = targets.reshape(-1).contiguous()
        assert grad_output.numel() == 1
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_bwd_kernel_reduce[grid](
            inputs_ptr,
            targets_ptr,
            grad_inputs,
            grad_output,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLossReduced.BLOCK_SIZE,
        )
        # No gradients for targets / alpha / gamma.
        return grad_inputs.view(input_shape), None, None, None


triton_sigmoid_focal_loss_reduce = SigmoidFocalLossReduced.apply