Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
8
sam3/perflib/__init__.py
Normal file
8
sam3/perflib/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import os
|
||||
|
||||
# Global switch for the perflib fast paths; opt out by setting USE_PERFLIB=0.
is_enabled = os.getenv("USE_PERFLIB", "1") == "1"
|
||||
137
sam3/perflib/associate_det_trk.py
Normal file
137
sam3/perflib/associate_det_trk.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sam3.perflib.masks_ops import mask_iou
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
|
||||
def associate_det_trk(
    det_masks,
    track_masks,
    iou_threshold=0.5,
    iou_threshold_trk=0.5,
    det_scores=None,
    new_det_thresh=0.0,
):
    """
    Optimized implementation of detection <-> track association that minimizes DtoH syncs.

    Args:
        det_masks: (N, H, W) tensor of predicted masks
        track_masks: (M, H, W) tensor of track masks
        iou_threshold: IoU above which a detection counts as matched to some track
        iou_threshold_trk: IoU above which a Hungarian pair marks the track as matched
        det_scores: optional (N,) tensor of detection scores
        new_det_thresh: minimum score for an unmatched detection to be reported as new

    Returns:
        new_det_indices: list of indices in det_masks considered 'new'
        unmatched_trk_indices: list of indices in track_masks considered 'unmatched'
        det_to_matched_trk: dict det index -> list of track indices matched above iou_threshold
        matched_det_scores: dict track index -> [det_score, det_score * iou] for its
            Hungarian-matched detection (empty when det_scores is None)
    """
    with torch.autograd.profiler.record_function("perflib: associate_det_trk"):
        assert isinstance(det_masks, torch.Tensor), "det_masks should be a tensor"
        assert isinstance(track_masks, torch.Tensor), "track_masks should be a tensor"
        if det_masks.size(0) == 0 or track_masks.size(0) == 0:
            return list(range(det_masks.size(0))), [], {}, {}  # all detections are new

        if list(det_masks.shape[-2:]) != list(track_masks.shape[-2:]):
            # resize to the smaller size to save GPU memory
            # BUGFIX: compare the spatial sizes (H*W). The previous code indexed the
            # tensors themselves (`det_masks[-2:]`), which sliced off the last two
            # *masks* and compared their element counts instead of the resolutions.
            if det_masks.shape[-2:].numel() < track_masks.shape[-2:].numel():
                track_masks = (
                    F.interpolate(
                        track_masks.unsqueeze(1).float(),
                        size=det_masks.shape[-2:],
                        mode="bilinear",
                        align_corners=False,
                    ).squeeze(1)
                    > 0
                )
            else:
                # resize detections to track size
                det_masks = (
                    F.interpolate(
                        det_masks.unsqueeze(1).float(),
                        size=track_masks.shape[-2:],
                        mode="bilinear",
                        align_corners=False,
                    ).squeeze(1)
                    > 0
                )

        det_masks = det_masks > 0
        track_masks = track_masks > 0

        iou = mask_iou(det_masks, track_masks)  # (N, M)
        igeit = iou >= iou_threshold
        igeit_any_dim_1 = igeit.any(dim=1)
        igeit_trk = iou >= iou_threshold_trk

        # One batched DtoH transfer of everything the CPU-side matching needs.
        iou_list = iou.cpu().numpy().tolist()
        igeit_list = igeit.cpu().numpy().tolist()
        igeit_any_dim_1_list = igeit_any_dim_1.cpu().numpy().tolist()
        igeit_trk_list = igeit_trk.cpu().numpy().tolist()

        det_scores_list = (
            det_scores
            if det_scores is None
            else det_scores.cpu().float().numpy().tolist()
        )

        # Hungarian matching for tracks (one-to-one: each track matches at most one detection)
        # For detections: allow many tracks to match to the same detection (many-to-one)
        # NOTE: a dead duplicate empty-input check that returned a 3-tuple (instead of
        # the 4-tuple every other path returns) was removed; emptiness is handled above.

        # Hungarian matching: maximize IoU for tracks
        cost_matrix = 1 - iou.cpu().numpy()  # Hungarian solves for minimum cost
        row_ind, col_ind = linear_sum_assignment(cost_matrix)

        def branchy_hungarian_better_uses_the_cpu(
            cost_matrix, row_ind, col_ind, iou_list, det_masks, track_masks
        ):
            matched_trk = set()
            matched_det = set()
            matched_det_scores = {}  # track index -> [det_score, det_score * iou] det score of matched detection mask
            for d, t in zip(row_ind, col_ind):
                # BUGFIX: only record scores when det_scores was provided; the
                # previous code indexed det_scores_list unconditionally and crashed
                # with det_scores=None whenever both mask stacks were non-empty.
                if det_scores_list is not None:
                    matched_det_scores[t] = [
                        det_scores_list[d],
                        det_scores_list[d] * iou_list[d][t],
                    ]
                if igeit_trk_list[d][t]:
                    matched_trk.add(t)
                    matched_det.add(d)

            # Tracks not matched by Hungarian assignment above threshold are unmatched
            unmatched_trk_indices = [
                t for t in range(track_masks.size(0)) if t not in matched_trk
            ]

            # For detections: allow many tracks to match to the same detection (many-to-one)
            # So, a detection is 'new' if it does not match any track above threshold
            assert track_masks.size(0) == igeit.size(
                1
            )  # Needed for loop optimization below
            new_det_indices = []
            for d in range(det_masks.size(0)):
                if not igeit_any_dim_1_list[d]:
                    if det_scores is not None and det_scores[d] >= new_det_thresh:
                        new_det_indices.append(d)

            # for each detection, which tracks it matched to (above threshold)
            det_to_matched_trk = defaultdict(list)
            for d in range(det_masks.size(0)):
                for t in range(track_masks.size(0)):
                    if igeit_list[d][t]:
                        det_to_matched_trk[d].append(t)

            return (
                new_det_indices,
                unmatched_trk_indices,
                det_to_matched_trk,
                matched_det_scores,
            )

        return branchy_hungarian_better_uses_the_cpu(
            cost_matrix, row_ind, col_ind, iou_list, det_masks, track_masks
        )
|
||||
99
sam3/perflib/compile.py
Normal file
99
sam3/perflib/compile.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def recursive_fn_factory(fn):
    """Return a function that applies `fn` to every tensor inside a nested structure.

    Dicts, lists, and tuples are traversed recursively (container types are
    preserved); tensors are mapped through `fn`; None/bool/int leaves pass
    through untouched; any other type raises TypeError.
    """

    def recursive_fn(node):
        if isinstance(node, dict):
            return {key: recursive_fn(value) for key, value in node.items()}
        if isinstance(node, (list, tuple)):
            mapped = [recursive_fn(item) for item in node]
            return mapped if isinstance(node, list) else tuple(mapped)
        if isinstance(node, torch.Tensor):
            return fn(node)
        # Yes, writing out an explicit white list of trivial types is tedious,
        # but so are bugs that come from not applying fn when expected to have
        # applied it.
        if node is None or isinstance(node, (bool, int)):
            return node
        raise TypeError(f"Unexpected type {type(node)}")

    return recursive_fn
|
||||
|
||||
|
||||
# Structure-preserving appliers built from the factory: make every tensor in a
# nested container contiguous / clone every tensor, leaving non-tensor leaves intact.
recursive_contiguous = recursive_fn_factory(lambda x: x.contiguous())
recursive_clone = recursive_fn_factory(torch.clone)
|
||||
|
||||
|
||||
def compile_wrapper(
    fn, *, mode="max-autotune", fullgraph=True, dynamic=False, name=None
):
    """Wrap `fn` with torch.compile, normalizing its tensor inputs and outputs.

    Inputs are made contiguous before the compiled call and outputs are cloned
    afterwards, so callers never alias compiler-managed memory. The call is
    recorded under `name` (or a default label) in the autograd profiler.
    """
    compiled = torch.compile(fn, mode=mode, fullgraph=fullgraph, dynamic=dynamic)
    profiler_label = f"compiled {fn}" if name is None else name

    def wrapped(*args, **kwargs):
        with torch.autograd.profiler.record_function(profiler_label):
            contiguous_args = recursive_contiguous(args)
            contiguous_kwargs = recursive_contiguous(kwargs)
            raw_result = compiled(*contiguous_args, **contiguous_kwargs)
            return recursive_clone(raw_result)

    return wrapped
|
||||
|
||||
|
||||
def shape_logging_wrapper(fn, keep_kwargs, enable_logging=False):
    """
    Wraps a function and prints the shapes of all tensor inputs.
    Only prints when a new combination of shapes is seen.

    Args:
        fn: Function to wrap
        keep_kwargs: names of keyword arguments whose shapes should be tracked
            (tensor/list kwargs not in this collection are ignored)
        enable_logging: Boolean flag to enable/disable logging

    NOTE(review): `seen_shapes` is only protected by the GIL; concurrent first
    calls may log the same shape combination twice, which is harmless.
    """
    seen_shapes = set()

    def get_shape(obj):
        if isinstance(obj, torch.Tensor):
            return obj.shape
        elif isinstance(obj, (list, tuple)):
            if len(obj) > 1:
                return tuple(get_shape(x) for x in obj)
            if len(obj) == 1:
                return get_shape(obj[0])
            # BUGFIX: an empty list/tuple previously raised IndexError on obj[0].
            return ()
        elif isinstance(obj, dict):
            return tuple(sorted((k, get_shape(v)) for k, v in obj.items()))
        else:
            # Non-container, non-tensor leaves are keyed by their type name.
            return type(obj).__name__

    def wrapper(*args, **kwargs):
        shapes = tuple(get_shape(arg) for arg in args) + tuple(
            (k, get_shape(v))
            for k, v in kwargs.items()
            if isinstance(v, (torch.Tensor, list))
            and (len(keep_kwargs) > 0 and k in keep_kwargs)
        )
        if shapes not in seen_shapes:
            seen_shapes.add(shapes)
            if enable_logging:
                print(f"[ShapeLogger] New input shapes for {fn.__qualname__}: {shapes}")
        return fn(*args, **kwargs)

    # Allow toggling the flag at runtime
    wrapper.enable_logging = enable_logging

    def set_logging(enabled=False):
        nonlocal enable_logging
        enable_logging = enabled
        wrapper.enable_logging = enable_logging

    wrapper.set_logging = set_logging
    return wrapper
|
||||
84
sam3/perflib/connected_components.py
Normal file
84
sam3/perflib/connected_components.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from cc_torch import get_connected_components
|
||||
|
||||
HAS_CC_TORCH = True
|
||||
except ImportError:
|
||||
logging.debug(
|
||||
"cc_torch not found. Consider installing for better performance. Command line:"
|
||||
" pip install git+https://github.com/ronghanghu/cc_torch.git"
|
||||
)
|
||||
HAS_CC_TORCH = False
|
||||
|
||||
|
||||
def connected_components_cpu_single(values: torch.Tensor):
    """Label connected components of one 2D map on CPU via skimage.

    Args:
        values: (H, W) tensor; nonzero entries are foreground.

    Returns:
        labels: (H, W) int tensor of dense component labels (background = 0).
        counts: (H, W) int tensor giving, per pixel, the size of its component
            (0 for background pixels).
    """
    assert values.dim() == 2
    from skimage.measure import label

    labels_np, num = label(values.cpu().numpy(), return_num=True)
    labels = torch.from_numpy(labels_np)
    # PERF: single bincount pass instead of one full-image scan per component
    # (the previous loop was O(num_components * H * W)).
    sizes = torch.bincount(labels.flatten(), minlength=num + 1)
    sizes[0] = 0  # background pixels report size 0, as before
    counts = sizes[labels].to(labels.dtype)
    return labels, counts
|
||||
|
||||
|
||||
def connected_components_cpu(input_tensor: torch.Tensor):
    """CPU fallback: run per-image connected components over a batch.

    Accepts (B, H, W) or (B, 1, H, W); returns (labels, counts) tensors with
    the same shape as the input, on the input's device.
    """
    out_shape = input_tensor.shape
    if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
        input_tensor = input_tensor.squeeze(1)
    else:
        assert (
            input_tensor.dim() == 3
        ), "Input tensor must be (B, H, W) or (B, 1, H, W)."

    # Label each image independently, then restack.
    per_image = [connected_components_cpu_single(img) for img in input_tensor]
    labels_tensor = torch.stack([lab for lab, _ in per_image], dim=0)
    counts_tensor = torch.stack([cnt for _, cnt in per_image], dim=0)
    device = input_tensor.device
    return (
        labels_tensor.to(device).view(out_shape),
        counts_tensor.to(device).view(out_shape),
    )
|
||||
|
||||
|
||||
def connected_components(input_tensor: torch.Tensor):
    """
    Computes connected components labeling on a batch of 2D tensors, using the best available backend.

    Args:
        input_tensor (torch.Tensor): A BxHxW integer tensor or Bx1xHxW. Non-zero values are considered foreground. Bool tensor also accepted

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Both tensors have the same shape as input_tensor.
            - A tensor with dense labels. Background is 0.
            - A tensor with the size of the connected component for each pixel.
    """
    if input_tensor.dim() == 3:
        input_tensor = input_tensor.unsqueeze(1)

    assert (
        input_tensor.dim() == 4 and input_tensor.shape[1] == 1
    ), "Input tensor must be (B, H, W) or (B, 1, H, W)."

    # Backend preference on GPU: cc_torch CUDA extension, then Triton; CPU otherwise.
    if not input_tensor.is_cuda:
        return connected_components_cpu(input_tensor)

    if HAS_CC_TORCH:
        return get_connected_components(input_tensor.to(torch.uint8))

    # triton fallback (imported lazily so CPU-only installs never touch triton)
    from sam3.perflib.triton.connected_components import connected_components_triton

    return connected_components_triton(input_tensor)
|
||||
27
sam3/perflib/fa3.py
Normal file
27
sam3/perflib/fa3.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# Custom op wrapper so torch.compile can trace through FlashAttention-3.
@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func_op(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    # Lazy import: the module stays importable without flash_attn_interface installed.
    from flash_attn_interface import flash_attn_func as fa3

    # NOTE(review): some flash_attn_interface versions return (out, softmax_lse),
    # but the op schema declares a single Tensor — confirm this fa3 build returns
    # only the attention output (otherwise this should be `fa3(q, k, v)[0]`).
    return fa3(q, k, v)
|
||||
|
||||
|
||||
def flash_attn_func(q, k, v):
    """Run FA3 attention with inputs cast to float8 (e4m3); result is cast back to q's dtype."""
    fp8 = torch.float8_e4m3fn
    out = flash_attn_func_op(q.to(fp8), k.to(fp8), v.to(fp8))
    return out.to(q.dtype)
|
||||
|
||||
|
||||
# Fake (meta) implementation used by torch.compile / fake-tensor tracing:
# only the output's shape/dtype matter here, no real attention is computed.
@flash_attn_func_op.register_fake
def _(q, k, v, **kwargs):
    # two outputs:
    # 1. output: (batch, seq_len, num_heads, head_dim)
    # 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
    # output needs to be bfloat16, not float8!
    meta_q = torch.empty_like(q, dtype=torch.bfloat16).contiguous()
    return meta_q
|
||||
69
sam3/perflib/masks_ops.py
Normal file
69
sam3/perflib/masks_ops.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def masks_to_boxes(masks: torch.Tensor, obj_ids: list[int]):
    """Compute xyxy bounding boxes for a stack of masks without any DtoH sync.

    Args:
        masks: (N, H, W) tensor; nonzero pixels are foreground.
        obj_ids: list of N object ids (used only for a length sanity check).

    Returns:
        (N, 4) float tensor of [x_min, y_min, x_max, y_max] on masks.device
        (an all-zero row for an all-background mask).
    """
    with torch.autograd.profiler.record_function("perflib: masks_to_boxes"):
        # Sanity check based on callsite for replacement
        assert masks.shape[0] == len(obj_ids)
        assert masks.dim() == 3

        # Based on torchvision masks_to_boxes
        if masks.numel() == 0:
            return torch.zeros((0, 4), device=masks.device, dtype=torch.float)

        num, height, width = masks.shape
        dev = masks.device
        row_idx = torch.arange(height, device=dev).view(1, height)
        col_idx = torch.arange(width, device=dev).view(1, width)

        foreground = masks != 0  # (N, H, W)
        col_has_fg = foreground.amax(dim=1)  # (N, W): which columns have objects
        row_has_fg = foreground.amax(dim=2)  # (N, H): which rows have objects

        # Min over foreground coordinates: empty columns/rows contribute W/H so
        # they never win the min; max over foreground coordinates directly.
        x_min = torch.amin((~col_has_fg) * width + col_has_fg * col_idx, dim=1)
        y_min = torch.amin((~row_has_fg) * height + row_has_fg * row_idx, dim=1)
        x_max = torch.amax(col_has_fg * col_idx, dim=1)
        y_max = torch.amax(row_has_fg * row_idx, dim=1)

        boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1).to(dtype=torch.float)
        assert boxes.shape == (num, 4)
        assert boxes.device == masks.device
        assert boxes.dtype == torch.float
        return boxes
|
||||
|
||||
|
||||
def mask_iou(pred_masks: torch.Tensor, gt_masks: torch.Tensor) -> torch.Tensor:
    """
    Compute the IoU (Intersection over Union) between predicted masks and ground truth masks.
    Args:
        - pred_masks: (N, H, W) bool Tensor, containing binary predicted segmentation masks
        - gt_masks: (M, H, W) bool Tensor, containing binary ground truth segmentation masks
    Returns:
        - ious: (N, M) float Tensor, containing IoUs for each pair of predicted and ground truth masks
          (pairs with an empty union get IoU 0).
    """
    assert pred_masks.dtype == gt_masks.dtype == torch.bool
    n = pred_masks.shape[0]
    m = gt_masks.shape[0]
    hw = pred_masks.shape[1] * pred_masks.shape[2]

    # Broadcast (N, 1, H*W) against (1, M, H*W) for all-pairs set arithmetic.
    preds = pred_masks.view(n, 1, hw)
    gts = gt_masks.view(1, m, hw)

    intersection = (preds & gts).sum(dim=2).float()
    union = (preds | gts).sum(dim=2).float()
    # clamp avoids 0/0 for empty pairs (yields 0 instead of NaN)
    return intersection / union.clamp(min=1)
|
||||
91
sam3/perflib/nms.py
Normal file
91
sam3/perflib/nms.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sam3.perflib.masks_ops import mask_iou
|
||||
|
||||
|
||||
try:
|
||||
from torch_generic_nms import generic_nms as generic_nms_cuda
|
||||
|
||||
GENERIC_NMS_AVAILABLE = True
|
||||
except ImportError:
|
||||
logging.debug(
|
||||
"Falling back to triton or CPU mask NMS implementation -- please install `torch_generic_nms` via\n\t"
|
||||
'pip uninstall -y torch_generic_nms; TORCH_CUDA_ARCH_LIST="8.0 9.0" pip install git+https://github.com/ronghanghu/torch_generic_nms'
|
||||
)
|
||||
GENERIC_NMS_AVAILABLE = False
|
||||
|
||||
|
||||
def nms_masks(
    pred_probs: torch.Tensor,
    pred_masks: torch.Tensor,
    prob_threshold: float,
    iou_threshold: float,
) -> torch.Tensor:
    """
    Args:
        - pred_probs: (num_det,) float Tensor, containing the score (probability) of each detection
        - pred_masks: (num_det, H_mask, W_mask) float Tensor, containing the binary segmentation mask of each detection
        - prob_threshold: float, score threshold to prefilter detections (NMS is performed on detections above threshold)
        - iou_threshold: float, mask IoU threshold for NMS

    Returns:
        - keep: (num_det,) bool Tensor, indicating whether each detection is kept after score thresholding + NMS
    """
    # prefilter the detections with prob_threshold ("valid" are those above prob_threshold)
    is_valid = pred_probs > prob_threshold  # (num_det,)
    valid_probs = pred_probs[is_valid]  # (num_valid,)
    valid_masks = pred_masks[is_valid] > 0  # (num_valid, H_mask, W_mask)
    if valid_probs.numel() == 0:
        # no valid detection, return empty keep mask
        return is_valid

    pairwise_iou = mask_iou(valid_masks, valid_masks)  # (num_valid, num_valid)
    kept_inds = generic_nms(pairwise_iou, valid_probs, iou_threshold)

    # Map each original detection to its dense index among valid ones (-1 if filtered),
    # then mark those whose dense index survived NMS.
    dense_idx = torch.where(is_valid, is_valid.cumsum(dim=0) - 1, -1)  # (num_det,)
    return torch.isin(dense_idx, kept_inds)  # (num_det,)
|
||||
|
||||
|
||||
def generic_nms(
    ious: torch.Tensor, scores: torch.Tensor, iou_threshold=0.5
) -> torch.Tensor:
    """A generic version of `torchvision.ops.nms` that takes a pairwise IoU matrix."""

    assert ious.dim() == 2 and ious.size(0) == ious.size(1)
    assert scores.dim() == 1 and scores.size(0) == ious.size(0)

    # Backend dispatch: CUDA extension > Triton fallback > CPU fallback.
    if not ious.is_cuda:
        return generic_nms_cpu(ious, scores, iou_threshold)

    if GENERIC_NMS_AVAILABLE:
        return generic_nms_cuda(ious, scores, iou_threshold, use_iou_matrix=True)

    from sam3.perflib.triton.nms import nms_triton

    return nms_triton(ious, scores, iou_threshold)
|
||||
|
||||
|
||||
def generic_nms_cpu(
    ious: torch.Tensor, scores: torch.Tensor, iou_threshold=0.5
) -> torch.Tensor:
    """
    A generic version of `torchvision.ops.nms` that takes a pairwise IoU matrix. (CPU implementation
    based on https://github.com/jwyang/faster-rcnn.pytorch/blob/master/lib/model/nms/nms_cpu.py)

    Greedy: repeatedly keep the highest-scoring remaining detection and drop
    everything overlapping it above iou_threshold.
    """
    iou_mat = ious.float().detach().cpu().numpy()
    # Indices of detections sorted by descending score.
    order = scores.float().detach().cpu().numpy().argsort()[::-1]

    kept_inds = []
    while order.size > 0:
        best = order.item(0)
        kept_inds.append(best)
        # Survivors: remaining detections whose IoU with `best` is small enough.
        survivors = np.where(iou_mat[best, order[1:]] <= iou_threshold)[0]
        order = order[survivors + 1]  # +1 because order[1:] was indexed

    return torch.tensor(kept_inds, dtype=torch.int64, device=scores.device)
|
||||
BIN
sam3/perflib/tests/assets/masks.tiff
Normal file
BIN
sam3/perflib/tests/assets/masks.tiff
Normal file
Binary file not shown.
59
sam3/perflib/tests/tests.py
Normal file
59
sam3/perflib/tests/tests.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from sam3.perflib.masks_ops import masks_to_boxes
|
||||
|
||||
|
||||
class TestMasksToBoxes:
    # Regression test: perflib masks_to_boxes must reproduce precomputed boxes
    # for the multi-frame TIFF mask fixture in tests/assets/masks.tiff.
    def test_masks_box(self):
        def masks_box_check(masks, expected, atol=1e-4):
            # obj_ids content is irrelevant to the result; only its length is asserted.
            out = masks_to_boxes(masks, [1 for _ in range(masks.shape[0])])
            assert out.dtype == torch.float
            print("out: ", out)
            print("expected: ", expected)
            torch.testing.assert_close(
                out, expected, rtol=0.0, check_dtype=True, atol=atol
            )

        # Check for int type boxes.
        def _get_image():
            # Open the multi-page TIFF fixture shipped next to this test file.
            assets_directory = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "assets"
            )
            mask_path = os.path.join(assets_directory, "masks.tiff")
            image = Image.open(mask_path)
            return image

        def _create_masks(image, masks):
            # Copy each TIFF frame into the preallocated (n_frames, H, W) tensor.
            for index in range(image.n_frames):
                image.seek(index)
                frame = np.array(image)
                masks[index] = torch.tensor(frame)

            return masks

        # Ground-truth xyxy boxes for the seven masks in the fixture.
        expected = torch.tensor(
            [
                [127, 2, 165, 40],
                [2, 50, 44, 92],
                [56, 63, 98, 100],
                [139, 68, 175, 104],
                [160, 112, 198, 145],
                [49, 138, 99, 182],
                [108, 148, 152, 213],
            ],
            dtype=torch.float,
        )

        image = _get_image()
        # The resulting boxes must not depend on the mask's floating dtype.
        for dtype in [torch.float16, torch.float32, torch.float64]:
            masks = torch.zeros(
                (image.n_frames, image.height, image.width), dtype=dtype
            )
            masks = _create_masks(image, masks)
            masks_box_check(masks, expected)
|
||||
468
sam3/perflib/triton/connected_components.py
Normal file
468
sam3/perflib/triton/connected_components.py
Normal file
@@ -0,0 +1,468 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
import math
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def _any_combine(a, b):
    # Bitwise/logical OR used as the combine op for tl.reduce ("any" semantics).
    return a | b
|
||||
|
||||
|
||||
@triton.jit
def tl_any(a, dim=0):
    # Reduce `a` along `dim` with OR: true iff any lane along `dim` is true.
    return tl.reduce(a, dim, _any_combine)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# ## Phase 1: Initialization Kernel
|
||||
# ==============================================================================
|
||||
# Each foreground pixel (value > 0) gets a unique label equal to its
|
||||
# linear index. Background pixels (value == 0) get a sentinel label of -1.
|
||||
# Note that the indexing is done across batch boundaries for simplicity
|
||||
# (i.e., the first pixel of image 1 gets label H*W, etc.)
|
||||
|
||||
|
||||
@triton.jit
def _init_labels_kernel(
    input_ptr, labels_ptr, numel: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    # Phase 1: every foreground pixel (value != 0) gets its own global linear
    # index as its initial label; background pixels get the sentinel -1.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < numel
    input_values = tl.load(input_ptr + offsets, mask=mask, other=0)

    indices = tl.where((input_values != 0), offsets, -1)
    tl.store(labels_ptr + offsets, indices, mask=mask)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# ## Phase 2: Local merging
|
||||
# ==============================================================================
|
||||
# Each pixel tries to merge with its 8-connected neighbors (up, down, left, right)
|
||||
# if they have the same value. This is done using a disjoint-set union operation.
|
||||
|
||||
|
||||
@triton.jit
def find(labels_ptr, indices, mask):
    # Disjoint-set "find": for each active lane, follow parent pointers starting
    # at `indices` until reaching a root (parent == self) or the -1 sentinel.
    current_pids = indices

    # 'is_done' tracks lanes that have finished their work.
    # A lane is initially "done" if it's not active (mask is False).
    is_done = ~mask

    # Loop as long as there is at least one lane that is NOT done.
    while tl_any(~is_done):
        # The work_mask is for lanes that are still active and seeking their root.
        work_mask = ~is_done
        parents = tl.load(labels_ptr + current_pids, mask=work_mask, other=-1)
        # A lane is now done if its parent is itself (it's a root)
        # or if it hits a -1 sentinel (a safe exit condition).
        is_root = parents == current_pids
        is_sentinel = parents == -1
        is_done |= is_root | is_sentinel

        # For lanes that are not yet done, update their pid to their parent to continue traversal.
        current_pids = tl.where(is_done, current_pids, parents)
        # We could add the following line to do path compression, but experimentally it's slower
        # tl.atomic_min(labels_ptr + indices, current_pids, mask=mask)
    return current_pids
|
||||
|
||||
|
||||
@triton.jit
def union(labels_ptr, a, b, process_mask):
    # Disjoint-set union of the components containing `a` and `b` for every lane
    # where process_mask is set; returns the surviving (minimum) root per lane.
    #
    # This function implements a disjoint-set union
    # As an invariant, we use the fact that the roots have the lower id. That helps parallelization
    # However, that is not sufficient by itself. Suppose two threads want to do union(0,2) and union(1,2) at the same time
    # Then if we do a naive atomic_min, 0 and 1 will compete to be the new parent of 2 and min(0, 1) will win.
    # However, 1 still needs to be merged with the new {0, 2} component.
    # To ensure that merge is also done, we need to detect whether the merge was successful, and if not retry until it is

    current_a = a
    current_b = b

    final_root = a
    # A mask to track which lanes have successfully completed their union.
    done_mask = ~process_mask  # tl.zeros_like(a) == 1 # Init with all False

    while tl_any(~done_mask):
        # Define the mask for lanes that still need work in this iteration
        work_mask = process_mask & ~done_mask

        # Find the roots for the current a and b values in the active lanes
        root_a = find(labels_ptr, current_a, work_mask)
        tl.debug_barrier()
        root_b = find(labels_ptr, current_b, work_mask)

        # 7. Merge logic
        # If roots are already the same, the sets are already merged. Mark as done.
        are_equal = root_a == root_b
        final_root = tl.where(are_equal & work_mask & ~done_mask, root_a, final_root)
        done_mask |= are_equal & work_mask

        # Define masks for the two merge cases (a < b or b < a)
        a_is_smaller = root_a < root_b

        # Case 1: root_a < root_b. Attempt to set parent[root_b] = root_a
        merge_mask_a_smaller = work_mask & a_is_smaller & ~are_equal
        ptr_b = labels_ptr + root_b
        old_val_b = tl.atomic_min(ptr_b, root_a, mask=merge_mask_a_smaller)

        # A lane is done if its atomic op was successful (old value was what we expected)
        success_b = old_val_b == root_b
        final_root = tl.where(success_b & work_mask & ~done_mask, root_a, final_root)
        done_mask |= success_b & merge_mask_a_smaller

        # *** Crucial Retry Logic ***
        # If the update failed (old_val_b != root_b), another thread interfered.
        # We update `current_b` to this new root (`old_val_b`) and will retry in the next loop iteration.
        current_b = tl.where(success_b | ~merge_mask_a_smaller, current_b, old_val_b)

        # Case 2: root_b < root_a. Attempt to set parent[root_a] = root_b
        merge_mask_b_smaller = work_mask & ~a_is_smaller & ~are_equal
        ptr_a = labels_ptr + root_a
        old_val_a = tl.atomic_min(ptr_a, root_b, mask=merge_mask_b_smaller)

        success_a = old_val_a == root_a
        final_root = tl.where(success_a & work_mask & ~done_mask, root_b, final_root)
        done_mask |= success_a & merge_mask_b_smaller

        # *** Crucial Retry Logic ***
        # Similarly, update `current_a` if the atomic operation failed.
        current_a = tl.where(success_a | ~merge_mask_b_smaller, current_a, old_val_a)

    return final_root
|
||||
|
||||
|
||||
@triton.jit
def _merge_helper(
    input_ptr,
    labels_ptr,
    base_offset,
    offsets_h,
    offsets_w,
    mask_2d,
    valid_current,
    current_values,
    current_labels,
    H,
    W,
    dx: tl.constexpr,
    dy: tl.constexpr,
):
    # Helper functions to compute merge with a specific neighbor offset (dx, dy)

    neighbor_h = offsets_h + dy
    neighbor_w = offsets_w + dx
    # Proper bounds checking: all four bounds must be satisfied
    mask_n = (
        mask_2d
        & (neighbor_h[:, None] >= 0)
        & (neighbor_h[:, None] < H)
        & (neighbor_w[None, :] >= 0)
        & (neighbor_w[None, :] < W)
    )

    offsets_neighbor = neighbor_h[:, None] * W + neighbor_w[None, :]
    neighbor_values = tl.load(
        input_ptr + base_offset + offsets_neighbor, mask=mask_n, other=-1
    )

    mask_n = tl.ravel(mask_n)
    neighbor_labels = tl.load(
        labels_ptr + tl.ravel(base_offset + offsets_neighbor), mask=mask_n, other=-1
    )

    # Merge only with in-bounds foreground neighbors carrying the same input value.
    to_merge = (
        mask_n & (neighbor_labels != -1) & tl.ravel(current_values == neighbor_values)
    )
    valid_write = valid_current & to_merge

    # returns new parents for the pixels that were merged (otherwise keeps current labels)
    return tl.where(
        valid_write,
        union(labels_ptr, current_labels, neighbor_labels, valid_write),
        current_labels,
    )
|
||||
|
||||
|
||||
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_H": 4, "BLOCK_SIZE_W": 16}, num_stages=1, num_warps=2
        ),
        triton.Config(
            {"BLOCK_SIZE_H": 4, "BLOCK_SIZE_W": 32}, num_stages=2, num_warps=4
        ),
    ],
    key=["H", "W"],
    # Autotuning re-runs the kernel, which mutates labels in place; restore between runs.
    restore_value=["labels_ptr"],
)
@triton.jit
def _local_prop_kernel(
    labels_ptr,
    input_ptr,
    H: tl.constexpr,
    W: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_W: tl.constexpr,
):
    # This is the meat of the Phase 2 to do local merging
    # It will be launched with a 2D grid:
    # - dim 0: batch index
    # - dim 1: block index over HxW image (2D tiling)
    pid_b = tl.program_id(0)
    pid_hw = tl.program_id(1)

    # Calculate offsets for the core block
    offsets_h = (pid_hw // tl.cdiv(W, BLOCK_SIZE_W)) * BLOCK_SIZE_H + tl.arange(
        0, BLOCK_SIZE_H
    )
    offsets_w = (pid_hw % tl.cdiv(W, BLOCK_SIZE_W)) * BLOCK_SIZE_W + tl.arange(
        0, BLOCK_SIZE_W
    )

    base_offset = pid_b * H * W
    offsets_2d = offsets_h[:, None] * W + offsets_w[None, :]
    mask_2d = (offsets_h[:, None] < H) & (offsets_w[None, :] < W)
    mask_1d = tl.ravel(mask_2d)

    # Load the current labels for the block - these are parent pointers
    current_labels = tl.load(
        labels_ptr + tl.ravel(base_offset + offsets_2d), mask=mask_1d, other=-1
    )
    current_values = tl.load(
        input_ptr + base_offset + offsets_2d, mask=mask_2d, other=-1
    )
    valid_current = mask_1d & (current_labels != -1)

    # Four directed neighbor merges cover full 8-connectivity, since the
    # opposite directions are handled by the neighboring pixels themselves.
    # Horizontal merge
    current_labels = _merge_helper(
        input_ptr,
        labels_ptr,
        base_offset,
        offsets_h,
        offsets_w,
        mask_2d,
        valid_current,
        current_values,
        current_labels,
        H,
        W,
        -1,
        0,
    )
    # Vertical merge
    current_labels = _merge_helper(
        input_ptr,
        labels_ptr,
        base_offset,
        offsets_h,
        offsets_w,
        mask_2d,
        valid_current,
        current_values,
        current_labels,
        H,
        W,
        0,
        -1,
    )
    # Diagonal merges
    current_labels = _merge_helper(
        input_ptr,
        labels_ptr,
        base_offset,
        offsets_h,
        offsets_w,
        mask_2d,
        valid_current,
        current_values,
        current_labels,
        H,
        W,
        -1,
        -1,
    )
    current_labels = _merge_helper(
        input_ptr,
        labels_ptr,
        base_offset,
        offsets_h,
        offsets_w,
        mask_2d,
        valid_current,
        current_values,
        current_labels,
        H,
        W,
        -1,
        1,
    )

    # This actually does some path compression, in a lightweight but beneficial way
    tl.atomic_min(
        labels_ptr + tl.ravel(base_offset + offsets_2d), current_labels, mask=mask_1d
    )
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# ## Phase 3: Pointer Jumping Kernel
|
||||
# ==============================================================================
|
||||
# This kernel performs pointer jumping to ensure that all pixels point directly to their root labels.
|
||||
# This is done in a loop until convergence.
|
||||
|
||||
|
||||
@triton.jit
def _pointer_jump_kernel(
    labels_in_ptr, labels_out_ptr, numel: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    """
    Pointer jumping kernel with double buffering to avoid race conditions.
    Reads from labels_in_ptr and writes to labels_out_ptr.

    Each lane repeatedly replaces its label with min(label, parent_label),
    reading only from the input buffer, until every valid lane observes that
    its label is its own parent (a root). The converged labels are then
    written to the output buffer.

    Args:
        labels_in_ptr: int32 labels after local propagation; -1 = background.
        labels_out_ptr: output buffer for the converged labels.
        numel: total number of elements across the whole batch.
        BLOCK_SIZE: number of lanes processed per program.
    """
    # This kernel is launched with a 1D grid, and does not care about batching explicitly.
    # By construction, the labels are global indices across the batch, and we never perform
    # cross-batch merges, so this is safe.

    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < numel

    # Load current labels from input buffer
    current_labels = tl.load(labels_in_ptr + offsets, mask=mask, other=-1)
    # Background (-1) and out-of-range lanes take no part in the jumping.
    valid_mask = mask & (current_labels != -1)

    # A mask to track which lanes have successfully completed their union.
    done_mask = ~valid_mask
    # Loop while any valid lane is still unfinished. NOTE(review): tl_any is a
    # helper defined elsewhere in this file — presumably a block-wide
    # any-reduction; confirm its definition.
    while tl_any(~(done_mask | ~valid_mask)):
        # Follow the pointer one step: load the label of my current label.
        parent_labels = tl.load(
            labels_in_ptr + current_labels, mask=valid_mask, other=-1
        )

        # A lane is at a root once its label equals its parent's label.
        are_equal = current_labels == parent_labels
        done_mask |= are_equal & valid_mask

        # Unfinished lanes jump to the smaller of (label, parent label).
        current_labels = tl.where(
            ~done_mask, tl.minimum(current_labels, parent_labels), current_labels
        )

    # Write to output buffer (safe because we're not reading from it)
    tl.store(labels_out_ptr + offsets, current_labels, mask=mask)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# ## Phase 4: Kernels for Computing Component Sizes
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
# Step 4.1: Count occurrences of each root label using atomic adds.
|
||||
@triton.jit
def _count_labels_kernel(labels_ptr, sizes_ptr, numel, BLOCK_SIZE: tl.constexpr):
    """Build a histogram of root-label occurrences via atomic increments."""
    block_id = tl.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < numel

    # Converged labels; -1 marks background and padded (out-of-range) lanes.
    root = tl.load(labels_ptr + idx, mask=in_bounds, other=-1)
    is_foreground = in_bounds & (root != -1)

    # One atomic increment per foreground pixel accumulates per-label counts.
    tl.atomic_add(sizes_ptr + root, 1, mask=is_foreground)
|
||||
|
||||
|
||||
# Step 4.2: Broadcast the computed sizes back to the output tensor.
|
||||
@triton.jit
def _broadcast_sizes_kernel(
    labels_ptr, sizes_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr
):
    """Scatter per-label component sizes back to every pixel position."""
    block_id = tl.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < numel

    # Converged per-pixel root labels; -1 marks background.
    root = tl.load(labels_ptr + idx, mask=in_bounds, other=-1)
    is_foreground = in_bounds & (root != -1)

    # Gather each pixel's component size from the label histogram.
    # Background lanes fall back to 0 via the masked-load default.
    size = tl.load(sizes_ptr + root, mask=is_foreground, other=0)

    # Store sizes for every in-bounds pixel; background pixels receive 0.
    tl.store(out_ptr + idx, size, mask=in_bounds)
|
||||
|
||||
|
||||
def connected_components_triton(input_tensor: torch.Tensor):
    """
    Batched 2D connected-components labeling on the GPU via Triton kernels.

    Non-zero entries are treated as foreground. Labels are derived from flat
    global indices (shifted by +1 on return), so they are unique across the
    batch but not consecutive.

    Args:
        input_tensor (torch.Tensor): (B, H, W) or (B, 1, H, W) integer or bool
            CUDA tensor; must be contiguous.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - Label tensor with the input's shape. Background is 0.
            - Tensor with the input's shape holding, per pixel, the size of
              that pixel's connected component (0 for background).
    """
    assert (
        input_tensor.is_cuda and input_tensor.is_contiguous()
    ), "Input tensor must be a contiguous CUDA tensor."
    out_shape = input_tensor.shape
    if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
        # Accept a singleton channel dimension; operate on (B, H, W) internally.
        input_tensor = input_tensor.squeeze(1)
    else:
        assert (
            input_tensor.dim() == 3
        ), "Input tensor must be (B, H, W) or (B, 1, H, W)."

    B, H, W = input_tensor.shape
    numel = B * H * W
    device = input_tensor.device

    # Working buffers: provisional labels and the converged (jumped) labels.
    labels = torch.empty_like(input_tensor, dtype=torch.int32)
    output = torch.empty_like(input_tensor, dtype=torch.int32)

    # --- Phase 1: seed each foreground pixel with its own flat index ---
    BLOCK_SIZE = 256
    grid_init = (triton.cdiv(numel, BLOCK_SIZE),)
    _init_labels_kernel[grid_init](
        input_tensor,
        labels,
        numel,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # --- Phase 2: local label propagation, one 2D tile per program ---
    # Tile sizes come from the kernel's meta-parameters (BLOCK_SIZE_H/W).
    grid_local_prop = lambda meta: (
        B,
        triton.cdiv(H, meta["BLOCK_SIZE_H"]) * triton.cdiv(W, meta["BLOCK_SIZE_W"]),
    )
    _local_prop_kernel[grid_local_prop](labels, input_tensor, H, W)

    # --- Phase 3: pointer jumping until every pixel points at its root ---
    # Double-buffered: reads `labels`, writes `output`.
    BLOCK_SIZE = 256
    grid_jump = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)
    _pointer_jump_kernel[grid_jump](labels, output, numel, BLOCK_SIZE=BLOCK_SIZE)

    # --- Phase 4: component sizes ---
    # Per-pixel output for the size of the containing component.
    component_sizes_out = torch.empty_like(input_tensor, dtype=torch.int32)

    # Histogram scratch buffer: labels are flat indices in [0, numel), so
    # numel bins always suffice.
    sizes_histogram = torch.zeros(numel, dtype=torch.int32, device=device)

    # 4.1: count pixels per root label.
    grid_count = (triton.cdiv(numel, BLOCK_SIZE),)
    _count_labels_kernel[grid_count](
        output, sizes_histogram, numel, BLOCK_SIZE=BLOCK_SIZE
    )

    # 4.2: scatter the counts back to every pixel.
    grid_broadcast = (triton.cdiv(numel, BLOCK_SIZE),)
    _broadcast_sizes_kernel[grid_broadcast](
        output, sizes_histogram, component_sizes_out, numel, BLOCK_SIZE=BLOCK_SIZE
    )
    # Shift by +1 so the background label (-1) maps to 0.
    return output.view(out_shape) + 1, component_sizes_out.view(out_shape)
|
||||
124
sam3/perflib/triton/nms.py
Normal file
124
sam3/perflib/triton/nms.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
# Adapted from https://github.com/stackav-oss/conch/blob/main/conch/kernels/vision/nms.py
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.autotune(
    configs=[
        triton.Config({"cxpr_block_size": 128}),
        triton.Config({"cxpr_block_size": 256}),
        triton.Config({"cxpr_block_size": 512}),
        triton.Config({"cxpr_block_size": 1024}),
        triton.Config({"cxpr_block_size": 2048}),
        triton.Config({"cxpr_block_size": 4096}),
        triton.Config({"cxpr_block_size": 8192}),
    ],
    # Re-tune (and re-compile) whenever the number of boxes changes.
    key=["num_boxes"],
)
@triton.jit
def _nms_suppression_kernel(
    # Tensors
    iou_mask_ptr: tl.tensor,  # [N, N]
    keep_mask_ptr: tl.tensor,  # [N]
    # Scalars
    num_boxes: tl.int32,
    # Strides
    iou_mask_stride: tl.int32,
    # Constexprs
    cxpr_block_size: tl.constexpr,
) -> None:
    """NMS suppression kernel.

    Expects boxes (rows/columns of the IoU mask and entries of the keep mask)
    to be pre-sorted by descending score; launched as a single program that
    walks the boxes sequentially while vectorizing over columns.

    Args:
        iou_mask_ptr: Pointer to precomputed IoU mask, shape: (N, N).
        keep_mask_ptr: Pointer to keep mask tensor, shape: (N,).
        num_boxes: Number of boxes.
        iou_mask_stride: Stride for IoU mask tensor.
        cxpr_block_size: Block size for processing.
    """
    # Sequential NMS: for each box in sorted order, suppress later boxes
    for current_box_idx in range(num_boxes - 1):
        # Check if current box is still kept
        is_kept = tl.load(keep_mask_ptr + current_box_idx)
        if is_kept:
            # IoU mask row offset for the current box
            # Because the IoU mask is sorted by score, we will only consider boxes that come after the current box.
            # This means we only need to read the upper triangular part of the IoU mask.
            iou_row_offset = current_box_idx * iou_mask_stride

            # Only process boxes that come after the current box
            next_box_idx = current_box_idx + 1
            remaining_boxes = num_boxes - next_box_idx

            # Iterate blockwise through the columns
            for block_idx in range(tl.cdiv(remaining_boxes, cxpr_block_size)):
                # Masked load of indices for the target boxes in the current block
                block_start = next_box_idx + block_idx * cxpr_block_size
                target_box_offsets = block_start + tl.arange(0, cxpr_block_size)
                target_box_mask = target_box_offsets < num_boxes

                # Suppress boxes with lower scores that have high IoU
                suppression_mask = tl.load(
                    iou_mask_ptr + iou_row_offset + target_box_offsets,
                    mask=target_box_mask,
                    other=False,
                )
                # Cast to int1 so the mask can gate the store below.
                suppression_mask = tl.cast(suppression_mask, tl.int1)

                # Conditionally store suppression result for high-IoU boxes
                tl.store(
                    keep_mask_ptr + target_box_offsets, False, mask=suppression_mask
                )

        # Potential race condition: we need to ensure all threads complete the store before the next
        # iteration otherwise we may load stale data for whether or not a box has been suppressed.
        tl.debug_barrier()
|
||||
|
||||
|
||||
def nms_triton(
    ious: torch.Tensor,
    scores: torch.Tensor,
    iou_threshold: float,
) -> torch.Tensor:
    """Non-maximum suppression driven by a precomputed pairwise IoU matrix.

    Args:
        ious: Pairwise IoU tensor of shape (N, N).
        scores: Scores tensor of shape (N,).
        iou_threshold: Boxes overlapping a kept box above this IoU are dropped.

    Returns:
        Tensor: Indices of kept boxes, sorted by decreasing score.
    """
    assert scores.dim() == 1, "Scores must be 1D"
    # Boolean matrix: entry (i, j) means box j overlaps box i enough to be
    # suppressed if box i is kept.
    overlap = ious > iou_threshold
    assert overlap.dim() == 2
    assert overlap.shape[0] == overlap.shape[1] == scores.shape[0]
    assert overlap.device == scores.device
    assert overlap.dtype == torch.bool

    box_count = scores.size(0)
    keep = torch.ones(len(scores), device=scores.device, dtype=torch.bool)

    # Reorder rows and columns by descending score so the kernel only has to
    # scan the upper triangle; the stable sort keeps ties in original order.
    _, order = torch.sort(scores, dim=0, stable=True, descending=True)
    overlap = overlap[order][:, order].contiguous()

    # Suppression is inherently sequential across boxes, so launch a single
    # program and let it vectorize over columns in blocks.
    _nms_suppression_kernel[(1,)](
        # Tensors
        iou_mask_ptr=overlap,
        keep_mask_ptr=keep,
        # Scalars
        num_boxes=box_count,
        # Strides
        iou_mask_stride=overlap.stride(0),
    )
    # Map surviving (score-sorted) positions back to original box indices.
    return order[keep]
|
||||
Reference in New Issue
Block a user