Initial commit

fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
facebook-github-bot
2025-11-18 23:07:42 -08:00
commit a13e358df4
504 changed files with 122758 additions and 0 deletions

124
sam3/perflib/triton/nms.py Normal file
View File

@@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# Adapted from https://github.com/stackav-oss/conch/blob/main/conch/kernels/vision/nms.py
import torch
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({"cxpr_block_size": 128}),
triton.Config({"cxpr_block_size": 256}),
triton.Config({"cxpr_block_size": 512}),
triton.Config({"cxpr_block_size": 1024}),
triton.Config({"cxpr_block_size": 2048}),
triton.Config({"cxpr_block_size": 4096}),
triton.Config({"cxpr_block_size": 8192}),
],
key=["num_boxes"],
)
@triton.jit
def _nms_suppression_kernel(
# Tensors
iou_mask_ptr: tl.tensor, # [N, N]
keep_mask_ptr: tl.tensor, # [N]
# Scalars
num_boxes: tl.int32,
# Strides
iou_mask_stride: tl.int32,
# Constexprs
cxpr_block_size: tl.constexpr,
) -> None:
"""NMS suppression kernel.
Args:
iou_mask_ptr: Pointer to precomputed IoU mask, shape: (N, N).
keep_mask_ptr: Pointer to keep mask tensor, shape: (N,).
num_boxes: Number of boxes.
iou_mask_stride: Stride for IoU mask tensor.
cxpr_block_size: Block size for processing.
"""
# Sequential NMS: for each box in sorted order, suppress later boxes
for current_box_idx in range(num_boxes - 1):
# Check if current box is still kept
is_kept = tl.load(keep_mask_ptr + current_box_idx)
if is_kept:
# IoU mask row offset for the current box
# Because the IoU mask is sorted by score, we will only consider boxes that come after the current box.
# This means we only need to read the upper triangular part of the IoU mask.
iou_row_offset = current_box_idx * iou_mask_stride
# Only process boxes that come after the current box
next_box_idx = current_box_idx + 1
remaining_boxes = num_boxes - next_box_idx
# Iterate blockwise through the columns
for block_idx in range(tl.cdiv(remaining_boxes, cxpr_block_size)):
# Masked load of indices for the target boxes in the current block
block_start = next_box_idx + block_idx * cxpr_block_size
target_box_offsets = block_start + tl.arange(0, cxpr_block_size)
target_box_mask = target_box_offsets < num_boxes
# Suppress boxes with lower scores that have high IoU
suppression_mask = tl.load(
iou_mask_ptr + iou_row_offset + target_box_offsets,
mask=target_box_mask,
other=False,
)
suppression_mask = tl.cast(suppression_mask, tl.int1)
# Conditionally store suppression result for high-IoU boxes
tl.store(
keep_mask_ptr + target_box_offsets, False, mask=suppression_mask
)
# Potential race condition: we need to ensure all threads complete the store before the next
# iteration otherwise we may load stale data for whether or not a box has been suppressed.
tl.debug_barrier()
def nms_triton(
ious: torch.Tensor,
scores: torch.Tensor,
iou_threshold: float,
) -> torch.Tensor:
"""Perform NMS given the iou matrix, the scores and the iou threshold
Args:
ious: Pairwise IoU tensor of shape (N, N).
scores: Scores tensor of shape (N,).
iou_threshold: IoU threshold for suppression.
Returns:
Tensor: Indices of kept boxes, sorted by decreasing score.
"""
assert scores.dim() == 1, "Scores must be 1D"
iou_mask = ious > iou_threshold
assert iou_mask.dim() == 2
assert iou_mask.shape[0] == iou_mask.shape[1] == scores.shape[0]
assert iou_mask.device == scores.device
assert iou_mask.dtype == torch.bool
num_boxes = scores.size(0)
keep_mask = torch.ones(len(scores), device=scores.device, dtype=torch.bool)
# Sort boxes by scores in descending order
_, sorted_indices = torch.sort(scores, dim=0, stable=True, descending=True)
iou_mask = iou_mask[sorted_indices][:, sorted_indices].contiguous()
# For the suppression stage, we need to process sequentially, but we'll still take
# advantage of parallelism by processing in blocks in one program.
stage2_grid = (1,)
_nms_suppression_kernel[stage2_grid](
# Tensors
iou_mask_ptr=iou_mask,
keep_mask_ptr=keep_mask,
# Scalars
num_boxes=num_boxes,
# Strides
iou_mask_stride=iou_mask.stride(0),
)
# Extract indices of kept boxes
return sorted_indices[keep_mask]