Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
1
sam3/train/loss/__init__.py
Normal file
1
sam3/train/loss/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
1319
sam3/train/loss/loss_fns.py
Normal file
1319
sam3/train/loss/loss_fns.py
Normal file
File diff suppressed because it is too large
Load Diff
113
sam3/train/loss/mask_sampling.py
Normal file
113
sam3/train/loss/mask_sampling.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
def point_sample(input, point_coords, **kwargs):
    """
    A wrapper around :function:`torch.nn.functional.grid_sample` that also accepts
    3D ``point_coords`` tensors. Unlike :function:`torch.nn.functional.grid_sample`,
    the coordinates are expected to lie in the [0, 1] x [0, 1] square.

    Args:
        input (Tensor): A tensor of shape (N, C, H, W) holding a feature map on an H x W grid.
        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) with
            [0, 1] x [0, 1] normalized point coordinates.
        **kwargs: Forwarded verbatim to ``F.grid_sample`` (e.g. ``align_corners``).

    Returns:
        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) with the
            features bilinearly interpolated from ``input`` at ``point_coords``, exactly
            as :function:`torch.nn.functional.grid_sample` would produce.
    """
    # grid_sample wants a 4D grid; lift (N, P, 2) to (N, P, 1, 2) and remember to undo it.
    needs_squeeze = point_coords.dim() == 3
    if needs_squeeze:
        point_coords = point_coords.unsqueeze(2)
    # Map [0, 1] coordinates into grid_sample's expected [-1, 1] range.
    grid = point_coords * 2.0 - 1.0
    sampled = F.grid_sample(input, grid, **kwargs)
    return sampled.squeeze(3) if needs_squeeze else sampled
|
||||
|
||||
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
def get_uncertain_point_coords_with_randomness(
    logits: torch.Tensor,
    uncertainty_func: Callable,
    num_points: int,
    oversample_ratio: int,
    importance_sample_ratio: float,
) -> torch.Tensor:
    """
    Sample points in the [0, 1] x [0, 1] coordinate space based on their uncertainty.
    Uncertainties are computed per point by ``uncertainty_func`` from the point's
    sampled logit prediction. See the PointRend paper for details.

    Args:
        logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask)
            for class-specific or class-agnostic prediction.
        uncertainty_func: A function mapping a (N, C, P) or (N, 1, P) tensor of point
            logits to a (N, 1, P) tensor of uncertainties.
        num_points (int): The number of points P to sample.
        oversample_ratio (int): Oversampling parameter (>= 1).
        importance_sample_ratio (float): Fraction of points drawn via importance sampling.

    Returns:
        point_coords (Tensor): A tensor of shape (N, P, 2) with the coordinates of the
            P sampled points.
    """
    assert oversample_ratio >= 1
    assert 0 <= importance_sample_ratio <= 1
    batch = logits.shape[0]
    total_candidates = int(num_points * oversample_ratio)

    # Draw an oversampled pool of random candidate locations and read their logits.
    candidate_coords = torch.rand(batch, total_candidates, 2, device=logits.device)
    candidate_logits = point_sample(logits, candidate_coords, align_corners=False)

    # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
    # Calculating uncertainties of the predictions first and sampling them for points leads
    # to incorrect results.
    # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
    # two predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
    # However, if we calculate uncertainties for the predictions first,
    # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
    uncertainties = uncertainty_func(candidate_logits)

    n_importance = int(importance_sample_ratio * num_points)
    n_random = num_points - n_importance

    # Keep the top-k most uncertain candidates per batch element.
    topk_idx = torch.topk(uncertainties[:, 0, :], k=n_importance, dim=1)[1]
    # Flatten the indices so they address candidate_coords.view(-1, 2) directly.
    row_offsets = total_candidates * torch.arange(
        batch, dtype=torch.long, device=logits.device
    )
    topk_idx = topk_idx + row_offsets[:, None]
    picked = candidate_coords.view(-1, 2)[topk_idx.view(-1), :].view(
        batch, n_importance, 2
    )

    # Top up with uniformly random points for the remaining quota.
    if n_random > 0:
        picked = torch.cat(
            [picked, torch.rand(batch, n_random, 2, device=logits.device)],
            dim=1,
        )
    return picked
|
||||
|
||||
|
||||
# Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py
def calculate_uncertainty(logits: torch.Tensor) -> torch.Tensor:
    """
    Estimate uncertainty as the L1 distance between 0.0 and the logit prediction.

    Args:
        logits (Tensor): A tensor of shape (R, 1, ...) holding class-agnostic
            predicted masks.

    Returns:
        scores (Tensor): A tensor of shape (R, 1, ...) of uncertainty scores; the
            most uncertain locations (logits closest to 0) score the highest.
    """
    assert logits.shape[1] == 1
    # -|logit|: maximal (0) at the decision boundary, more negative as confidence grows.
    return torch.abs(logits).neg()
||||
203
sam3/train/loss/sam3_loss.py
Normal file
203
sam3/train/loss/sam3_loss.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.model.model_misc import SAM3Output
|
||||
|
||||
from sam3.train.utils.distributed import get_world_size
|
||||
|
||||
from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks
|
||||
|
||||
|
||||
class DummyLoss(torch.nn.Module):
    """A dummy loss that always returns 0 (as a placeholder for eval)"""

    def __init__(
        self,
        core_loss_key: str = CORE_LOSS_KEY,
        device: str = "cuda",
        **kwargs,
    ):
        """
        Args:
            core_loss_key: dict key under which the zero core loss is reported.
            device: device on which the zero scalar is allocated.
            **kwargs: accepted and ignored, so this can drop in for any real loss.
        """
        super().__init__()
        self.core_loss_key = core_loss_key
        self.device = torch.device(device)

    def _zero_scalar(self):
        # A fresh zero scalar on the configured device.
        return torch.tensor(0.0, device=self.device)

    def forward(self, *args, **kwargs):
        """Ignore all inputs and report a zero core loss."""
        return {self.core_loss_key: self._zero_scalar()}

    def accumulate(self, out_dict):
        """
        Called by iterative losses; guarantees the core loss key is present.
        """
        if self.core_loss_key not in out_dict:
            out_dict[self.core_loss_key] = self._zero_scalar()
        return out_dict
||||
|
||||
|
||||
class Sam3LossWrapper(torch.nn.Module):
    """
    Orchestrates the SAM3 "find" losses: applies each loss fn to the main,
    auxiliary and first-stage outputs, optionally adds one-to-many (o2m) and
    semantic-segmentation losses, and sums everything into one core loss.
    """

    def __init__(
        self,
        loss_fns_find,
        normalization="global",
        matcher=None,
        o2m_matcher=None,
        o2m_weight=1.0,
        use_o2m_matcher_on_o2m_aux=True,
        loss_fn_semantic_seg=None,
        normalize_by_valid_object_num=False,
        normalize_by_stage_num=False,
        scale_by_find_batch_size=False,
    ):
        """
        Args:
            loss_fns_find: iterable of loss modules applied to the find outputs.
            normalization: box-count normalization mode; one of "global"
                (all-reduce across workers), "local" (per worker), "none".
            matcher: one-to-one matcher, used for o2m queries in aux outputs
                when use_o2m_matcher_on_o2m_aux is False.
            o2m_matcher: matcher for the one-to-many (``*_o2m``) outputs.
            o2m_weight: scalar weight multiplied into every o2m loss term.
            use_o2m_matcher_on_o2m_aux: whether to use the o2m matcher on the
                o2m queries in auxiliary outputs (else fall back to `matcher`).
            loss_fn_semantic_seg: optional extra semantic segmentation loss.
            normalize_by_valid_object_num: if True, only boxes with non-zero
                width/height count towards the normalization denominator.
            normalize_by_stage_num: if True, divide each stage loss by the
                number of find stages (video frames).
            scale_by_find_batch_size: if True, scale the core loss by
                sqrt(batch size).
        """
        super().__init__()
        self.loss_fns_find = loss_fns_find
        assert normalization in ["global", "local", "none"]
        self.normalization = normalization
        self.normalize_by_valid_object_num = normalize_by_valid_object_num
        self.normalize_by_stage_num = normalize_by_stage_num
        self.matcher = matcher
        self.o2m_matcher = o2m_matcher
        self.o2m_weight = o2m_weight
        # whether to use the o2m matcher on the o2m queries in auxiliary outputs
        self.use_o2m_matcher_on_o2m_aux = use_o2m_matcher_on_o2m_aux
        self.loss_fn_semantic_seg = loss_fn_semantic_seg
        self.scale_by_find_batch_size = scale_by_find_batch_size

    def _get_num_boxes(self, targets):
        """
        Return the (average) number of target boxes used as the loss
        normalization denominator, per ``self.normalization``.
        """
        # the average number of target boxes for loss normalization
        if self.normalize_by_valid_object_num:
            # valid boxes are those with non-zero height and width
            # (padded invisible boxes have zero width/height and are excluded)
            boxes_hw = targets["boxes"].view(-1, 4)  # cx, cy, w, h
            num_boxes = (boxes_hw[:, 2:] > 0).all(dim=-1).sum().float()
        else:
            num_boxes = targets["num_boxes"].sum().float()
        if self.normalization == "global":
            # average the count across all workers so every rank divides by the same value
            torch.distributed.all_reduce(num_boxes)
            num_boxes = torch.clamp(num_boxes / get_world_size(), min=1)
        elif self.normalization == "local":
            num_boxes = torch.clamp(num_boxes, min=1)
        elif self.normalization == "none":
            num_boxes = 1
        return num_boxes

    def compute_loss(self, nested_out, targets):
        """
        Compute all find losses (o2o and optional o2m) for one step's outputs,
        including auxiliary and first-stage outputs. Returns a dict of
        individual loss terms plus the summed core loss under CORE_LOSS_KEY.
        """
        num_boxes = self._get_num_boxes(targets)
        o2m_out_is_valid = nested_out.get("o2m_out_is_valid", None)
        o2m_target_is_valid_padded = nested_out.get("o2m_target_is_valid_padded", None)

        # Get a list of outputs, including auxiliary and first stage outputs
        output_list = [(nested_out, "", False)]  # (out, suffix, is_aux)
        if "aux_outputs" in nested_out:
            output_list.extend(
                (aux_out, f"_aux_{i}", True)
                for i, aux_out in enumerate(nested_out["aux_outputs"])
            )
        if "first_stage" in nested_out:
            output_list.append((nested_out["first_stage"], "_fs", True))

        # Compute all the requested losses
        losses = {}
        total_core_loss = 0.0
        for out, suffix, is_aux in output_list:
            # o2o matcher indices need to be computed by the model (as the video model requires
            # a specific way of matching free and locked indices beyond just calling the matcher)
            indices = out["indices"]
            has_o2m_out = "pred_logits_o2m" in out
            if has_o2m_out:
                # collect the "*_o2m" entries under their base names
                o2m_out = {
                    k[: -len("_o2m")]: v for k, v in out.items() if k.endswith("_o2m")
                }
                # o2m targets are the same as the o2o targets (assuming repeat=1)
                o2m_targets = targets
                if self.use_o2m_matcher_on_o2m_aux or not is_aux:
                    o2m_indices = self.o2m_matcher(
                        o2m_out,
                        o2m_targets,
                        out_is_valid=o2m_out_is_valid,
                        target_is_valid_padded=o2m_target_is_valid_padded,
                    )
                else:
                    o2m_indices = self.matcher(
                        o2m_out,
                        o2m_targets,
                        out_is_valid=o2m_out_is_valid,
                        target_is_valid_padded=o2m_target_is_valid_padded,
                    )

            for loss_fn in self.loss_fns_find:
                l_dict = loss_fn(
                    outputs=out,
                    targets=targets,
                    indices=indices,
                    num_boxes=num_boxes,
                    is_aux=is_aux,
                )
                total_core_loss += l_dict.pop(CORE_LOSS_KEY)
                losses.update({f"{k}{suffix}": v for k, v in l_dict.items()})

                compute_o2m_loss = has_o2m_out
                # a special handling to allow turning off mask loss in o2m
                # (to be compatible with the original implementation)
                if isinstance(loss_fn, Masks):
                    compute_o2m_loss = compute_o2m_loss and "pred_masks" in o2m_out
                if isinstance(loss_fn, Det2TrkAssoc):
                    compute_o2m_loss = False  # Det2TrkAssoc does not support o2m
                if compute_o2m_loss:
                    l_dict = loss_fn(
                        outputs=o2m_out,
                        targets=o2m_targets,
                        indices=o2m_indices,
                        num_boxes=num_boxes,
                        is_aux=is_aux,
                    )
                    # weight every o2m term before folding into the totals
                    for k in l_dict:
                        l_dict[k] *= self.o2m_weight
                    total_core_loss += l_dict.pop(CORE_LOSS_KEY)
                    losses.update({f"{k}{suffix}_o2m": v for k, v in l_dict.items()})

        losses[CORE_LOSS_KEY] = total_core_loss
        return losses

    def forward(self, find_stages: SAM3Output, find_targets):
        """
        Compute and accumulate losses over all find stages (and all steps
        within each stage). Returns the dict of summed loss terms.
        """
        if find_stages.loss_stages is not None:
            # restrict the targets to the stages we compute losses for
            find_targets = [find_targets[i] for i in find_stages.loss_stages]
        with SAM3Output.iteration_mode(
            find_stages, iter_mode=SAM3Output.IterMode.ALL_STEPS_PER_STAGE
        ) as find_stages:
            assert len(find_stages) == len(find_targets)
            total_losses = {}
            for stage_outputs, stage_targets in zip(find_stages, find_targets):
                # each step in a stage is scored against the same stage targets
                stage_targets = [stage_targets] * len(stage_outputs)
                # If there are multiple steps within a stage, compute the loss for all of them (e.g. interactivity)
                for outputs, targets in zip(stage_outputs, stage_targets):
                    cur_losses = self.compute_loss(outputs, targets)

                    if self.loss_fn_semantic_seg is not None:
                        cur_losses_semantic = self.loss_fn_semantic_seg(
                            outputs, targets
                        )
                        cur_losses[CORE_LOSS_KEY] += cur_losses_semantic.pop(
                            CORE_LOSS_KEY
                        )
                        # make sure the semantic losses don't overlap with the find losses
                        assert set(cur_losses).isdisjoint(set(cur_losses_semantic))
                        cur_losses.update(cur_losses_semantic)

                    # Optionally, normalize the loss by the number of find stages (training video frames) so that
                    # image batches and video batches have similar loss scales. (Otherwise video batches would
                    # have a much higher loss scale due to summing the losses over all the find stages.)
                    if self.normalize_by_stage_num:
                        cur_losses[CORE_LOSS_KEY] /= len(find_stages)

                    if self.scale_by_find_batch_size:
                        bs = targets["num_boxes"].shape[0]
                        # sqrt scaling based on the "effective" batch size
                        cur_losses[CORE_LOSS_KEY] *= bs**0.5

                    # fold this step's losses into the running totals
                    for k, v in cur_losses.items():
                        if k not in total_losses:
                            total_losses[k] = v
                        else:
                            total_losses[k] += v

            return total_losses
||||
321
sam3/train/loss/sigmoid_focal_loss.py
Normal file
321
sam3/train/loss/sigmoid_focal_loss.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
"""Triton kernel for faster and memory efficient sigmoid focal loss"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch._inductor.runtime.triton_helpers import libdevice
|
||||
|
||||
"""
|
||||
|
||||
The sigmoid focal loss is defined as:
|
||||
|
||||
prob = inputs.sigmoid()
|
||||
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
|
||||
p_t = prob * targets + (1 - prob) * (1 - targets)
|
||||
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
||||
loss = alpha_t * ce_loss * ((1 - p_t) ** gamma)
|
||||
|
||||
Where alpha and gamma are scalar parameters, inputs are the logits, targets the float targets.
|
||||
|
||||
We implement two versions of the sigmoid focal loss: with and without sum reduction.
|
||||
The latter is implemented with built-in reduction to avoid materializing wrt the output of the loss.
|
||||
This can help save a bit of peak memory.
|
||||
|
||||
The reduction version is implemented using somewhat of a hack. Pytorch's generated kernels usually do the point-wise operation in a first kernel, and implement the reduction another kernel launched on a grid of size 1, where the reduction happens as a for loop in the triton kernel.
|
||||
Since we want to fuse those two kernels, that is not a good idea: we'd have to launch the overall kernel on a grid of size 1, which is obviously inefficient.
|
||||
On the other hand, typical CUDA algorithms for reduction (eg reduction tree) are hard to implement in triton due to the lack of thread sync primitives.
|
||||
We settle for a version that abuses triton's atomic_add: we can have all threads simply add to the same location.
|
||||
In practice, this is not good, since it creates a massive bottleneck on the semaphore for that single memory location. So instead, we create M reduction locations. Each thread will simply write to thread_id%M. The python code can finally sum over the M reductions.
|
||||
M = 32 works fine in benchmarking tests. The forward is a tiny bit slower compared to the non-reduced kernel, but the backward breaks even due to one less memory allocation.
|
||||
"""
|
||||
|
||||
|
||||
@triton.jit
def _inner_focal_loss_fwd(inputs, targets, alpha, gamma):
    # Element-wise focal loss: alpha_t * (1 - p_t)^gamma * BCE(inputs, targets),
    # where `inputs` are logits and `targets` are float labels.
    inv_targets = 1 - targets
    # Sigmoid
    sig = tl.sigmoid(inputs)

    # Binary cross entropy with logits
    # In practice, we want the following:
    # bce_loss = -targets * tl.log(sig) - (1 - targets) * tl.log(1 - sig)
    # However, the above is not numerically stable.
    # We're also not directly taking the sum here, so the usual log-sum-exp trick doesn't apply
    # The bce can be reformulated, after algebraic manipulation, to
    # bce_loss = log(1 + exp(-x)) + x * (1-y)
    # This is still not stable, because for large (-x) the exponential will blow up.
    # We'll use the following alternate formulation:
    # bce_loss = max(x, 0) - x * y + log(1 + exp(-abs(x)))
    # Let's show that it's equivalent:
    # Case x>=0: abs(x) = x , max(x, 0) = x
    # so we get x - x * y + log(1 + exp(-x)) which is equivalent
    # Case x<0: abs(x) = -x, max(x, 0) = 0
    # we have log(1 + exp(-abs(x))) = log(1 + exp(x)) = log(exp(x)(1 + exp(-x))) = x+log(1 + exp(-x))
    # plugging it in, we get
    # 0 - x * y + x + log(1 + exp(-x)), which is also equivalent
    # Note that this is stable because now the exponent are guaranteed to be below 0.
    # 1e9 acts as an effective +inf upper bound: this computes max(x, 0).
    max_val = tl.clamp(inputs, min=0, max=1e9)
    bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))

    # Modulating factor (1 - p_t)^gamma, where p_t is the probability of the true class
    p_t = sig * targets + (1 - sig) * inv_targets
    mod_factor = libdevice.pow(1 - p_t, gamma)

    # Alpha factor: class-balancing weight
    alpha_t = alpha * targets + (1 - alpha) * inv_targets

    # Final loss calculation
    return alpha_t * mod_factor * bce_loss
||||
|
||||
|
||||
# Non-reduced version: writes one loss value per input element.
@triton.jit
def sigmoid_focal_loss_fwd_kernel(
    inputs_ptr,
    targets_ptr,
    loss_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance handles one BLOCK_SIZE chunk of the flattened tensors.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard against out-of-range lanes in the final (partial) block.
    mask = offset < n_elements

    # Load data (logits are upcast to float32 for numerical stability)
    inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
    targets = tl.load(targets_ptr + offset, mask=mask)

    final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma)

    # Store result
    tl.store(loss_ptr + offset, final_loss, mask=mask)
||||
|
||||
|
||||
# version with reduction: each block sums its losses and atomically adds the
# partial sum into one of REDUCE_SIZE slots (see module docstring for rationale).
@triton.jit
def sigmoid_focal_loss_fwd_kernel_reduce(
    inputs_ptr,
    targets_ptr,
    loss_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
    REDUCE_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    # Spread atomic adds over REDUCE_SIZE locations to avoid a single-semaphore bottleneck.
    reduce_loc = pid % REDUCE_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    # Load data
    inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
    targets = tl.load(targets_ptr + offset, mask=mask)

    # Multiply by the mask so out-of-range lanes contribute 0 to the block sum.
    final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma) * mask

    fl = tl.sum(final_loss)

    # Store result: accumulate this block's partial sum into its reduction slot.
    tl.atomic_add(loss_ptr + reduce_loc, fl)
||||
|
||||
|
||||
@triton.jit
def _inner_focal_loss_bwd(inputs, targets, alpha, gamma):
    # Element-wise derivative of the focal loss w.r.t. the logits `inputs`,
    # via the product rule on alpha_t * mod_factor * bce_loss.
    inv_targets = 1 - targets

    # Recompute forward (cheaper than saving intermediates; see fwd for the stable BCE form)
    max_val = tl.clamp(inputs, min=0, max=1e9)
    bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))

    # Sigmoid
    sig = tl.sigmoid(inputs)
    inv_sig = 1 - sig

    # Modulating factor: (1 - p_t)^gamma computed as (1 - p_t)^(gamma-1) * (1 - p_t)
    # so the (gamma-1) power can be reused in d_mod_factor below.
    p_t = sig * targets + inv_sig * inv_targets
    tmp = libdevice.pow(1 - p_t, gamma - 1)
    mod_factor = tmp * (1 - p_t)

    # Alpha factor
    alpha_t = alpha * targets + (1 - alpha) * inv_targets

    # Now computing the derivatives
    # d(p_t)/dx: sigmoid'(x) = sig * (1 - sig), signed by the target
    d_pt = (2 * targets - 1) * sig * inv_sig
    d_mod_factor = -gamma * d_pt * tmp

    # d(BCE)/dx = sigmoid(x) - y
    d_bce_loss = sig - targets

    # Product rule: alpha_t is constant in x
    return alpha_t * (d_bce_loss * mod_factor + d_mod_factor * bce_loss)
||||
|
||||
|
||||
@triton.jit
def sigmoid_focal_loss_bwd_kernel(
    inputs_ptr,
    targets_ptr,
    grad_inputs_ptr,
    grad_out_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    # Backward for the non-reduced loss: grad_out has one value per element.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard against out-of-range lanes in the final (partial) block.
    mask = offset < n_elements
    input_ptrs = inputs_ptr + offset
    target_ptrs = targets_ptr + offset
    grad_input_ptrs = grad_inputs_ptr + offset
    grad_out_ptrs = grad_out_ptr + offset
    # Load data (logits upcast to float32, matching the forward pass)
    inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
    targets = tl.load(target_ptrs, mask=mask)
    grad_out = tl.load(grad_out_ptrs, mask=mask)
    # Chain rule: upstream gradient times the local derivative.
    d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
    tl.store(grad_input_ptrs, d_loss, mask=mask)
||||
|
||||
|
||||
@triton.jit
def sigmoid_focal_loss_bwd_kernel_reduce(
    inputs_ptr,
    targets_ptr,
    grad_inputs_ptr,
    grad_out_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    # The only difference is that the gradient is now a single scalar
    # (the upstream gradient of the summed loss), broadcast to every element.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    input_ptrs = inputs_ptr + offset
    target_ptrs = targets_ptr + offset
    grad_input_ptrs = grad_inputs_ptr + offset
    # Load data
    inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
    targets = tl.load(target_ptrs, mask=mask)
    # Scalar load: same upstream gradient for every element.
    grad_out = tl.load(grad_out_ptr)
    d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
    tl.store(grad_input_ptrs, d_loss, mask=mask)
||||
|
||||
|
||||
class SigmoidFocalLoss(torch.autograd.Function):
    """
    Autograd wrapper for the element-wise (non-reduced) Triton sigmoid focal
    loss. Returns a float32 loss tensor with the same shape as `inputs`;
    the backward pass recomputes intermediates in a Triton kernel.
    """

    # Elements processed per Triton program instance.
    BLOCK_SIZE = 256

    @staticmethod
    def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
        """
        Args:
            inputs: logits tensor (any shape).
            targets: float targets, same number of elements as `inputs`.
            alpha: class-balancing factor.
            gamma: focusing exponent.
        Returns:
            Per-element float32 focal loss, reshaped to `inputs.shape`.
        """
        n_elements = inputs.numel()
        assert targets.numel() == n_elements
        input_shape = inputs.shape
        # Kernels operate on flat contiguous buffers.
        inputs = inputs.view(-1).contiguous()
        targets = targets.view(-1).contiguous()
        loss = torch.empty(inputs.shape, dtype=torch.float32, device=inputs.device)
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_fwd_kernel[grid](
            inputs, targets, loss, alpha, gamma, n_elements, SigmoidFocalLoss.BLOCK_SIZE
        )
        # Save inputs/targets (reshaped back) so backward can recompute the intermediates.
        ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
        ctx.alpha = alpha
        ctx.gamma = gamma
        return loss.view(input_shape)

    @staticmethod
    def backward(ctx, grad_output):
        """Compute d(loss)/d(inputs); alpha/gamma/targets get no gradient."""
        inputs, targets = ctx.saved_tensors
        alpha = ctx.alpha
        gamma = ctx.gamma
        n_elements = inputs.numel()
        input_shape = inputs.shape
        grad_inputs = torch.empty(
            inputs.shape, dtype=grad_output.dtype, device=grad_output.device
        )
        # Flatten everything for the kernel, matching the forward layout.
        inputs_ptr = inputs.view(-1).contiguous()
        targets_ptr = targets.view(-1).contiguous()
        grad_output_ptr = grad_output.view(-1).contiguous()
        grad_inputs_ptr = grad_inputs
        assert grad_output.numel() == n_elements
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_bwd_kernel[grid](
            inputs_ptr,
            targets_ptr,
            grad_inputs_ptr,
            grad_output_ptr,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLoss.BLOCK_SIZE,
        )
        # One gradient per forward argument: (inputs, targets, alpha, gamma).
        return grad_inputs.view(input_shape), None, None, None
||||
|
||||
|
||||
# Functional entry point for the non-reduced (per-element) focal loss.
triton_sigmoid_focal_loss = SigmoidFocalLoss.apply
|
||||
|
||||
|
||||
class SigmoidFocalLossReduced(torch.autograd.Function):
    """
    Autograd wrapper for the sum-reduced Triton sigmoid focal loss. The
    forward fuses the point-wise loss and the reduction (partial sums over
    REDUCE_SIZE atomic slots, then a final .sum() in Python); returns a
    float32 scalar. See the module docstring for why REDUCE_SIZE slots
    are used instead of a single atomic location.
    """

    # Elements processed per Triton program instance.
    BLOCK_SIZE = 256
    # Number of partial-sum slots the kernel atomically accumulates into.
    REDUCE_SIZE = 32

    @staticmethod
    def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
        """
        Args:
            inputs: logits tensor (any shape).
            targets: float targets with the same number of elements.
            alpha: class-balancing factor.
            gamma: focusing exponent.
        Returns:
            Scalar float32 sum of the per-element focal losses.
        """
        n_elements = inputs.numel()
        input_shape = inputs.shape
        # Kernels operate on flat contiguous buffers.
        inputs = inputs.view(-1).contiguous()
        targets = targets.view(-1).contiguous()
        # Zero-initialized partial-sum buffer the kernel atomic_adds into.
        loss = torch.zeros(
            SigmoidFocalLossReduced.REDUCE_SIZE,
            device=inputs.device,
            dtype=torch.float32,
        )
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_fwd_kernel_reduce[grid](
            inputs,
            targets,
            loss,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLossReduced.BLOCK_SIZE,
            SigmoidFocalLossReduced.REDUCE_SIZE,
        )
        # Save inputs/targets so backward can recompute the intermediates.
        ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
        ctx.alpha = alpha
        ctx.gamma = gamma
        # Final reduction over the REDUCE_SIZE partial sums.
        return loss.sum()

    @staticmethod
    def backward(ctx, grad_output):
        """Compute d(loss)/d(inputs) given the scalar upstream gradient."""
        inputs, targets = ctx.saved_tensors
        alpha = ctx.alpha
        gamma = ctx.gamma
        n_elements = inputs.numel()
        input_shape = inputs.shape
        grad_inputs = torch.empty(
            inputs.shape, dtype=grad_output.dtype, device=grad_output.device
        )
        inputs_ptr = inputs.view(-1).contiguous()
        targets_ptr = targets.reshape(-1).contiguous()
        # The reduced loss is a scalar, so its upstream gradient must be too.
        assert grad_output.numel() == 1
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_bwd_kernel_reduce[grid](
            inputs_ptr,
            targets_ptr,
            grad_inputs,
            grad_output,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLossReduced.BLOCK_SIZE,
        )
        # One gradient per forward argument: (inputs, targets, alpha, gamma).
        return grad_inputs.view(input_shape), None, None, None
||||
|
||||
|
||||
# Functional entry point for the sum-reduced (scalar) focal loss.
triton_sigmoid_focal_loss_reduce = SigmoidFocalLossReduced.apply
|
||||
Reference in New Issue
Block a user