Initial commit

fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
facebook-github-bot
2025-11-18 23:07:42 -08:00
commit a13e358df4
504 changed files with 122758 additions and 0 deletions

1
sam3/model/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

View File

@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import inspect
from functools import wraps
from typing import Callable, TypeVar, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torch.utils._pytree import tree_map_only
# Type variables for better type hinting
T = TypeVar("T")
Module = TypeVar("Module", bound=nn.Module)
def activation_ckpt_wrapper(module: Union[nn.Module, Callable]) -> Callable:
    """
    Wraps a given module to enable or disable activation checkpointing.

    Activation checkpointing (gradient checkpointing) trades compute for memory by
    recomputing intermediate activations during the backward pass instead of storing
    them in memory during the forward pass.

    When activation checkpointing is enabled, the wrapper expects only keyword arguments,
    and it maps these to positional arguments based on the module's signature.

    Args:
        module: The module or function to wrap with activation checkpointing

    Returns:
        A wrapped callable that supports activation checkpointing

    Usage:
        The returned wrapper function can be called with the same arguments as the
        original module, with an additional `act_ckpt_enable` keyword argument to control
        activation checkpointing and optional `use_reentrant` parameter.

    Example:
        ```python
        wrapped_module = activation_ckpt_wrapper(my_module)
        output = wrapped_module(x=input_tensor, y=another_tensor, act_ckpt_enable=True)
        ```
    """

    @wraps(module)
    def act_ckpt_wrapper(
        *args, act_ckpt_enable: bool = True, use_reentrant: bool = False, **kwargs
    ):
        if act_ckpt_enable:
            # Keyword-only calls are enforced so we can rebuild the positional
            # argument list ourselves from the module's signature below.
            if len(args) > 0:
                raise ValueError(
                    "This wrapper expects keyword arguments only when `act_ckpt_enable=True`"
                )
            # Get the signature of the target function/module
            callable_fn = module.forward if isinstance(module, nn.Module) else module
            sig = inspect.signature(callable_fn)
            # Create a mapping of parameter names to their default values
            param_defaults = {
                name: param.default for name, param in sig.parameters.items()
            }
            # Map kwargs onto positional arguments in signature order, filling in
            # declared defaults for parameters the caller omitted.
            args = []
            for p_name in param_defaults.keys():
                if p_name in kwargs:
                    args.append(kwargs.pop(p_name))
                elif param_defaults[p_name] is not inspect.Parameter.empty:
                    # Set arg to default value if it's not in kwargs. Useful for primitive types or args that default to None
                    args.append(param_defaults[p_name])
                elif (
                    sig.parameters[p_name].kind is not inspect.Parameter.VAR_KEYWORD
                ):  # Skip **kwargs parameter
                    raise ValueError(f"Missing positional argument: {p_name}")
            # Scan remaining kwargs for torch.Tensor
            remaining_keys = list(kwargs.keys())
            for key in remaining_keys:
                if isinstance(kwargs[key], torch.Tensor):
                    # Remove the tensor from kwargs, assuming it's not required by the module.
                    # If it is required, the module's signature should be modified to accept it as a positional or keyword argument.
                    # NOTE(review): the tensor is replaced with a sentinel string rather
                    # than dropped — presumably so `checkpoint` does not treat it as a
                    # differentiable input; confirm against checkpoint semantics.
                    kwargs[key] = "_REMOVED_BY_ACT_CKPT_WRAPPER_"
            ret = checkpoint.checkpoint(
                module, *args, use_reentrant=use_reentrant, **kwargs
            )
        else:
            # Checkpointing disabled: plain passthrough call.
            ret = module(*args, **kwargs)
        return ret

    return act_ckpt_wrapper
def clone_output_wrapper(f: Callable[..., T]) -> Callable[..., T]:
    """
    Clone the CUDA output tensors of a function to avoid in-place operations.

    Cloning breaks aliasing between the returned tensors and the function's
    internal buffers, which prevents in-place-operation errors when the wrapped
    function is used together with torch.compile. CPU tensors and non-tensor
    leaves are passed through unchanged.

    Args:
        f: The function whose CUDA tensor outputs should be cloned

    Returns:
        A wrapped function that clones any CUDA tensor outputs
    """

    def _maybe_clone(t: torch.Tensor) -> torch.Tensor:
        # Only CUDA tensors need the defensive copy.
        return t.clone() if t.is_cuda else t

    @wraps(f)
    def wrapped(*args, **kwargs):
        result = f(*args, **kwargs)
        return tree_map_only(torch.Tensor, _maybe_clone, result)

    return wrapped

217
sam3/model/box_ops.py Normal file
View File

@@ -0,0 +1,217 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
from typing import Tuple
import torch
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to corner (x0, y0, x1, y1) format."""
    cx, cy, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    return torch.stack((cx - half_w, cy - half_h, cx + half_w, cy + half_h), dim=-1)
def box_cxcywh_to_xywh(x):
    """Convert boxes from (cx, cy, w, h) to top-left (x, y, w, h) format."""
    cx, cy, w, h = x.unbind(-1)
    return torch.stack((cx - 0.5 * w, cy - 0.5 * h, w, h), dim=-1)
def box_xywh_to_xyxy(x):
    """Convert boxes from (x, y, w, h) to corner (x0, y0, x1, y1) format."""
    left, top, w, h = x.unbind(-1)
    return torch.stack((left, top, left + w, top + h), dim=-1)
def box_xywh_to_cxcywh(x):
    """Convert boxes from (x, y, w, h) to center (cx, cy, w, h) format."""
    left, top, w, h = x.unbind(-1)
    return torch.stack((left + 0.5 * w, top + 0.5 * h, w, h), dim=-1)
def box_xyxy_to_xywh(x):
    """Convert boxes from corner (x0, y0, x1, y1) to (x, y, w, h) format."""
    left, top, right, bottom = x.unbind(-1)
    return torch.stack((left, top, right - left, bottom - top), dim=-1)
def box_xyxy_to_cxcywh(x):
    """Convert boxes from corner (x0, y0, x1, y1) to center (cx, cy, w, h) format."""
    x0, y0, x1, y1 = x.unbind(-1)
    center_x = (x0 + x1) / 2
    center_y = (y0 + y1) / 2
    return torch.stack((center_x, center_y, x1 - x0, y1 - y0), dim=-1)
def box_area(boxes):
    """
    Batched version of box area. Boxes should be in [x0, y0, x1, y1] format.

    Inputs:
    - boxes: Tensor of shape (..., 4)

    Returns:
    - areas: Tensor of shape (...,)
    """
    widths = boxes[..., 2] - boxes[..., 0]
    heights = boxes[..., 3] - boxes[..., 1]
    return widths * heights
def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks.

    The masks should be in format [N, H, W] where N is the number of masks,
    (H, W) are the spatial dimensions.

    Returns a [N, 4] tensor, with the boxes in xyxy format. Boxes of empty
    masks are all-zero.
    """
    if masks.numel() == 0:
        # No masks at all: empty (0, 4) result on the same device.
        return torch.zeros((0, 4), device=masks.device)
    h, w = masks.shape[-2:]
    y = torch.arange(0, h, dtype=torch.float, device=masks.device)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device)
    # Explicit indexing="ij" matches the legacy default and avoids the
    # torch.meshgrid deprecation warning.
    y, x = torch.meshgrid(y, x, indexing="ij")
    x_mask = masks * x.unsqueeze(0)
    # +1 so the box right/bottom edge is exclusive of the last mask pixel.
    x_max = x_mask.flatten(1).max(-1)[0] + 1
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
    y_mask = masks * y.unsqueeze(0)
    y_max = y_mask.flatten(1).max(-1)[0] + 1
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
    boxes = torch.stack([x_min, y_min, x_max, y_max], 1)
    # Invalidate boxes corresponding to empty masks. The validity flag has
    # shape (N,), so unsqueeze to (N, 1) to broadcast against the (N, 4)
    # boxes — the bare (N,) flag only broadcast by accident when N == 4.
    boxes = boxes * masks.flatten(-2).any(-1).unsqueeze(-1)
    return boxes
def box_iou(boxes1, boxes2):
    """
    Batched version of box_iou. Boxes should be in [x0, y0, x1, y1] format.

    Inputs:
    - boxes1: Tensor of shape (..., N, 4)
    - boxes2: Tensor of shape (..., M, 4)

    Returns:
    - iou, union: Tensors of shape (..., N, M)
    """
    # Areas computed directly from the corner coordinates.
    area1 = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1])
    area2 = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1])
    # Pairwise intersection corners via broadcasting:
    # boxes1 -> (..., N, 1, 2), boxes2 -> (..., 1, M, 2)
    top_left = torch.max(boxes1[..., :, None, :2], boxes2[..., None, :, :2])
    bottom_right = torch.min(boxes1[..., :, None, 2:], boxes2[..., None, :, 2:])
    inter_wh = (bottom_right - top_left).clamp(min=0)  # (..., N, M, 2)
    inter = inter_wh[..., 0] * inter_wh[..., 1]  # (..., N, M)
    union = area1[..., None] + area2[..., None, :] - inter
    return inter / union, union
def generalized_box_iou(boxes1, boxes2):
    """
    Batched version of Generalized IoU from https://giou.stanford.edu/
    Boxes should be in [x0, y0, x1, y1] format

    Inputs:
    - boxes1: Tensor of shape (..., N, 4)
    - boxes2: Tensor of shape (..., M, 4)

    Returns:
    - giou: Tensor of shape (..., N, M)
    """
    # Pairwise IoU/union, computed inline via broadcasting:
    # boxes1 -> (..., N, 1, 2), boxes2 -> (..., 1, M, 2)
    area1 = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1])
    area2 = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1])
    inter_tl = torch.max(boxes1[..., :, None, :2], boxes2[..., None, :, :2])
    inter_br = torch.min(boxes1[..., :, None, 2:], boxes2[..., None, :, 2:])
    inter_wh = (inter_br - inter_tl).clamp(min=0)  # (..., N, M, 2)
    inter = inter_wh[..., 0] * inter_wh[..., 1]  # (..., N, M)
    union = area1[..., None] + area2[..., None, :] - inter
    iou = inter / union
    # Smallest enclosing box for each pair of boxes.
    hull_tl = torch.min(boxes1[..., :, None, :2], boxes2[..., None, :, :2])
    hull_br = torch.max(boxes1[..., :, None, 2:], boxes2[..., None, :, 2:])
    hull_wh = (hull_br - hull_tl).clamp(min=0)  # (..., N, M, 2)
    hull_area = hull_wh[..., 0] * hull_wh[..., 1]  # (..., N, M)
    return iou - (hull_area - union) / hull_area
@torch.jit.script
def fast_diag_generalized_box_iou(boxes1, boxes2):
    """Element-wise (diagonal) GIoU between two aligned [N, 4] xyxy box lists."""
    assert len(boxes1) == len(boxes2)
    tl1 = boxes1[:, :2]
    br1 = boxes1[:, 2:]
    tl2 = boxes2[:, :2]
    br2 = boxes2[:, 2:]
    # assert (br1 >= tl1).all()
    # assert (br2 >= tl2).all()
    area1 = (br1 - tl1).prod(-1)
    area2 = (br2 - tl2).prod(-1)
    inter_tl = torch.max(tl1, tl2)  # [N,2]
    inter_br = torch.min(br1, br2)  # [N,2]
    hull_tl = torch.min(tl1, tl2)
    hull_br = torch.max(br1, br2)
    inter = (inter_br - inter_tl).clamp(min=0).prod(-1)
    hull_area = (hull_br - hull_tl).clamp(min=0).prod(-1)
    union = area1 + area2 - inter
    iou = inter / union
    return iou - (hull_area - union) / hull_area
@torch.jit.script
def fast_diag_box_iou(boxes1, boxes2):
    """Element-wise (diagonal) IoU between two aligned [N, 4] xyxy box lists."""
    assert len(boxes1) == len(boxes2)
    tl1 = boxes1[:, :2]
    br1 = boxes1[:, 2:]
    tl2 = boxes2[:, :2]
    br2 = boxes2[:, 2:]
    # assert (br1 >= tl1).all()
    # assert (br2 >= tl2).all()
    area1 = (br1 - tl1).prod(-1)
    area2 = (br2 - tl2).prod(-1)
    inter_tl = torch.max(tl1, tl2)  # [N,2]
    inter_br = torch.min(br1, br2)  # [N,2]
    inter = (inter_br - inter_tl).clamp(min=0).prod(-1)
    union = area1 + area2 - inter
    return inter / union
def box_xywh_inter_union(
    boxes1: torch.Tensor, boxes2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Element-wise intersection and union areas for boxes in xywh format."""
    assert boxes1.size(-1) == 4 and boxes2.size(-1) == 4
    # Convert xywh -> corner coordinates inline.
    tl1 = boxes1[..., :2]
    br1 = tl1 + boxes1[..., 2:]
    tl2 = boxes2[..., :2]
    br2 = tl2 + boxes2[..., 2:]
    area1 = (br1 - tl1).prod(-1)
    area2 = (br2 - tl2).prod(-1)
    assert (area1 >= 0).all() and (area2 >= 0).all()
    inter = (torch.min(br1, br2) - torch.max(tl1, tl2)).clamp(min=0).prod(-1)
    union = area1 + area2 - inter
    return inter, union

209
sam3/model/data_misc.py Normal file
View File

@@ -0,0 +1,209 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.
"""
import collections
import re
from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union
import torch
MyTensor = Union[torch.Tensor, List[Any]]
def interpolate(
    input, size=None, scale_factor=None, mode="nearest", align_corners=None
):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty channel sizes.
    """
    resize = torch.nn.functional.interpolate
    if input.numel() > 0:
        # Non-empty input: the stock interpolate handles it directly.
        return resize(input, size, scale_factor, mode, align_corners)
    assert (
        input.shape[0] != 0 or input.shape[1] != 0
    ), "At least one of the two first dimensions must be non zero"
    if input.shape[1] == 0:
        # Pytorch rejects an empty channel dimension, so swap it with the
        # batch dimension (empty batch is supported), resize, and swap back.
        swapped = input.transpose(0, 1)
        return resize(swapped, size, scale_factor, mode, align_corners).transpose(0, 1)
    # Empty batch dimension is now supported in pytorch.
    return resize(input, size, scale_factor, mode, align_corners)
@dataclass
class BatchedPointer:
    """Batched pointer data collated across a batch.

    Each `<name>__type` class attribute stores the dtype that
    `convert_my_tensors` casts the corresponding field to when collating.
    Field semantics are inferred from names — verify against producers.
    """

    # Stage indices the pointers refer to (collated to torch.long)
    stage_ids: MyTensor
    stage_ids__type = torch.long
    # Query indices the pointers refer to
    query_ids: MyTensor
    query_ids__type = torch.long
    # Object ids associated with the pointers
    object_ids: MyTensor
    object_ids__type = torch.long
    # Validity mask over the pointers
    ptr_mask: MyTensor
    ptr_mask__type = torch.bool
    # Integer type codes of the pointers
    ptr_types: MyTensor
    ptr_types__type = torch.long
@dataclass
class FindStage:
    """Inputs for one "find" stage of a batch.

    Each `<name>__type` class attribute stores the dtype that
    `convert_my_tensors` casts the corresponding field to. Note that
    `input_boxes` and `input_boxes_label` are stacked along dim 1 (not dim 0)
    by `convert_my_tensors`.
    """

    # Image ids for the queries
    img_ids: MyTensor
    img_ids__type = torch.long
    # Text ids for the queries
    text_ids: MyTensor
    text_ids__type = torch.long
    # Box prompts (float) and their validity mask / labels
    input_boxes: MyTensor
    input_boxes__type = torch.float
    input_boxes_mask: MyTensor
    input_boxes_mask__type = torch.bool
    input_boxes_label: MyTensor
    input_boxes_label__type = torch.long
    # Point prompts (float) and their validity mask
    input_points: MyTensor
    input_points__type = torch.float
    input_points_mask: MyTensor
    input_points_mask__type = torch.bool
    # We track the object ids referred to by this query.
    # This is beneficial for tracking in videos without the need for pointers.
    object_ids: Optional[List[List]] = None  # List of objects per query
@dataclass
class BatchedFindTarget:
    """Ground-truth targets for a batch of find queries.

    Each `<name>__type` class attribute stores the dtype that
    `convert_my_tensors` casts the corresponding field to.
    """

    # The number of boxes in each find query
    num_boxes: MyTensor
    num_boxes__type = torch.long
    # Target boxes in normalized CxCywh format
    boxes: MyTensor
    boxes__type = torch.float
    # Target boxes in normalized CxCywh format but in padded representation
    # as used in BinaryHungarianMatcherV2 (unlike the packed ones in `boxes`)
    boxes_padded: MyTensor
    boxes_padded__type = torch.float
    # For hybrid matching, we repeat the boxes
    repeated_boxes: MyTensor
    repeated_boxes__type = torch.float
    # Target Segmentation masks
    segments: Optional[MyTensor]
    segments__type = torch.bool
    # Target Semantic Segmentation masks
    semantic_segments: Optional[MyTensor]
    semantic_segments__type = torch.bool
    # Validity flags for the segmentation masks
    is_valid_segment: Optional[MyTensor]
    is_valid_segment__type = torch.bool
    # Whether annotations are exhaustive for each query
    is_exhaustive: MyTensor
    is_exhaustive__type = torch.bool
    # The object id for each ground-truth box, in both packed and padded representations
    object_ids: MyTensor
    object_ids__type = torch.long
    object_ids_padded: MyTensor
    object_ids_padded__type = torch.long
@dataclass
class BatchedInferenceMetadata:
    """All metadata required to post-process a find stage.

    Each `<name>__type` class attribute stores the dtype that
    `convert_my_tensors` casts the corresponding field to.
    """

    # Coco id that corresponds to the "image" for evaluation by the coco evaluator
    coco_image_id: MyTensor
    coco_image_id__type = torch.long
    # id in the original dataset, such that we can use the original evaluator
    original_image_id: MyTensor
    original_image_id__type = torch.long
    # Original category id (if we want to use the original evaluator)
    original_category_id: MyTensor
    original_category_id__type = torch.int
    # Size of the raw image (height, width)
    original_size: MyTensor
    original_size__type = torch.long
    # id of the object in the media (track_id for a video)
    object_id: MyTensor
    object_id__type = torch.long
    # index of the frame in the media (0 in the case of a single-frame media)
    frame_index: MyTensor
    frame_index__type = torch.long
    # Adding for relations inference
    # get_text_input: List[Optional[str]]
    # Adding for TA conditional inference
    is_conditioning_only: List[Optional[bool]]
@dataclass
class BatchedDatapoint:
    """A fully collated training/inference batch: the image tensor, the text
    prompts, and per-stage inputs/targets/metadata."""

    # Batched image tensor
    img_batch: torch.Tensor
    # Text prompt per find query
    find_text_batch: List[str]
    # One FindStage of inputs per stage
    find_inputs: List[FindStage]
    # Ground-truth targets aligned with `find_inputs`
    find_targets: List[BatchedFindTarget]
    # Post-processing metadata aligned with `find_inputs`
    find_metadatas: List[BatchedInferenceMetadata]
    # Optional un-processed images (kept only when needed downstream)
    raw_images: Optional[List[Any]] = None
def convert_my_tensors(obj):
    """Convert every `MyTensor` field of a dataclass instance into a single
    `torch.Tensor`, cast to the dtype stored in the sibling `<name>__type`
    class attribute.

    Lists of tensors are stacked (dim 0, except `input_boxes` and
    `input_boxes_label`, which stack along dim 1); any other list goes through
    `torch.as_tensor`. Nested dataclass fields are converted recursively.
    The object is mutated in place and also returned.
    """

    def is_optional_field(field) -> bool:
        # True when the annotation is Optional[...] (a Union containing None).
        return get_origin(field) is Union and type(None) in get_args(field)

    for field in fields(obj):
        if is_dataclass(getattr(obj, field.name)):
            # Recurse into nested dataclasses (e.g. stages inside a batch).
            convert_my_tensors(getattr(obj, field.name))
            continue
        field_type = field.type
        if is_optional_field(field.type):
            field_type = Union[get_args(field.type)[:-1]]  # Get the Optional field type
        if field_type != MyTensor or getattr(obj, field.name) is None:
            # Not a MyTensor field (or an unset Optional): leave untouched.
            continue
        elif len(getattr(obj, field.name)) and isinstance(
            getattr(obj, field.name)[0], torch.Tensor
        ):
            # Non-empty list of tensors: stack into one tensor.
            stack_dim = 0
            if field.name in [
                "input_boxes",
                "input_boxes_label",
            ]:
                stack_dim = 1
            setattr(
                obj,
                field.name,
                torch.stack(getattr(obj, field.name), dim=stack_dim).to(
                    getattr(obj, field.name + "__type")
                ),
            )
        else:
            # Empty list or list of plain values: build a tensor directly.
            setattr(
                obj,
                field.name,
                torch.as_tensor(
                    getattr(obj, field.name), dtype=getattr(obj, field.name + "__type")
                ),
            )
    return obj

956
sam3/model/decoder.py Normal file
View File

@@ -0,0 +1,956 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
Transformer decoder.
Inspired from Pytorch's version, adds the pre-norm variant
"""
from typing import Any, Dict, List, Optional
import numpy as np
import torch
from sam3.sam.transformer import RoPEAttention
from torch import nn, Tensor
from torchvision.ops.roi_align import RoIAlign
from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .model_misc import (
gen_sineembed_for_position,
get_activation_fn,
get_clones,
inverse_sigmoid,
MLP,
)
class TransformerDecoderLayer(nn.Module):
    """One transformer decoder layer: self-attention over the queries,
    optional cross-attention to text tokens, cross-attention to the image
    memory, then a feed-forward block (each with residual + LayerNorm).

    Supports two extras used by the decoder:
    - DAC mode (`dac=True` in forward): the query tensor holds two halves
      (o2o then o2m); self-attention is applied only to the first half.
    - An optional presence token prepended to the queries for the duration
      of the layer and split off again before returning.
    """

    def __init__(
        self,
        activation: str,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        cross_attention: nn.Module,
        n_heads: int,
        use_text_cross_attention: bool = False,
    ):
        super().__init__()
        # cross attention (to image memory); module is injected by the caller
        self.cross_attn = cross_attention
        self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm1 = nn.LayerNorm(d_model)
        # cross attention text
        self.use_text_cross_attention = use_text_cross_attention
        if use_text_cross_attention:
            self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
            self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
            self.catext_norm = nn.LayerNorm(d_model)
        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm2 = nn.LayerNorm(d_model)
        # ffn
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        # Additive positional embedding; no-op when pos is None.
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        # Run the FFN with autocast disabled so it executes in the input's
        # (full) precision rather than a reduced autocast dtype.
        with torch.amp.autocast(device_type="cuda", enabled=False):
            tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
            tgt = tgt + self.dropout4(tgt2)
            tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        # for tgt
        tgt: Optional[Tensor],  # nq, bs, d_model
        tgt_query_pos: Optional[Tensor] = None,  # pos for query. MLP(Sine(pos))
        tgt_query_sine_embed: Optional[Tensor] = None,  # pos for query. Sine(pos)
        tgt_key_padding_mask: Optional[Tensor] = None,
        tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4
        memory_text: Optional[Tensor] = None,  # num_token, bs, d_model
        text_attention_mask: Optional[Tensor] = None,  # bs, num_token
        # for memory
        memory: Optional[Tensor] = None,  # hw, bs, d_model
        memory_key_padding_mask: Optional[Tensor] = None,
        memory_level_start_index: Optional[Tensor] = None,  # num_levels
        memory_spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
        memory_pos: Optional[Tensor] = None,  # pos for memory
        # sa
        self_attn_mask: Optional[Tensor] = None,  # mask used for self-attention
        cross_attn_mask: Optional[Tensor] = None,  # mask used for cross-attention
        # dac
        dac=False,
        dac_use_selfatt_ln=True,
        presence_token=None,
        # skip inside deformable attn
        identity=0.0,
        **kwargs,  # additional kwargs for compatibility
    ):
        """
        Input:
            - tgt/tgt_query_pos: nq, bs, d_model

        Returns:
            (tgt, presence_token_out) — updated queries and, when a presence
            token was supplied, its updated value (otherwise None).
        """
        # self attention
        if self.self_attn is not None:
            if dac:
                # we only apply self attention to the first half of the queries
                assert tgt.shape[0] % 2 == 0
                num_o2o_queries = tgt.shape[0] // 2
                tgt_o2o = tgt[:num_o2o_queries]
                tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries]
                tgt_o2m = tgt[num_o2o_queries:]
            else:
                tgt_o2o = tgt
                tgt_query_pos_o2o = tgt_query_pos
            if presence_token is not None:
                # Prepend the presence token (with a zero positional embedding)
                # so it participates in self-attention with the o2o queries.
                tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0)
                tgt_query_pos_o2o = torch.cat(
                    [torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0
                )
                # Also extend the full query-pos tensor used by the later
                # cross-attention steps.
                tgt_query_pos = torch.cat(
                    [torch.zeros_like(presence_token), tgt_query_pos], dim=0
                )
            q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o)
            tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0]
            tgt_o2o = tgt_o2o + self.dropout2(tgt2)
            if dac:
                # Either normalize the o2o half alone, or the recombined tensor.
                if not dac_use_selfatt_ln:
                    tgt_o2o = self.norm2(tgt_o2o)
                tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0)  # Recombine
                if dac_use_selfatt_ln:
                    tgt = self.norm2(tgt)
            else:
                tgt = tgt_o2o
                tgt = self.norm2(tgt)
        if self.use_text_cross_attention:
            tgt2 = self.ca_text(
                self.with_pos_embed(tgt, tgt_query_pos),
                memory_text,
                memory_text,
                key_padding_mask=text_attention_mask,
            )[0]
            tgt = tgt + self.catext_dropout(tgt2)
            tgt = self.catext_norm(tgt)
        if presence_token is not None:
            # Extend the cross-attention mask with an all-visible row for the
            # prepended presence token.
            presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
            cross_attn_mask = torch.cat(
                [presence_token_mask, cross_attn_mask], dim=1
            )  # (bs*nheads, 1+nq, hw)
        # Cross attention to image
        tgt2 = self.cross_attn(
            query=self.with_pos_embed(tgt, tgt_query_pos),
            key=self.with_pos_embed(memory, memory_pos),
            value=memory,
            attn_mask=cross_attn_mask,
            key_padding_mask=(
                memory_key_padding_mask.transpose(0, 1)
                if memory_key_padding_mask is not None
                else None
            ),
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # ffn
        tgt = self.forward_ffn(tgt)
        # Split the presence token back off the front of the query tensor.
        presence_token_out = None
        if presence_token is not None:
            presence_token_out = tgt[:1]
            tgt = tgt[1:]
        return tgt, presence_token_out
class TransformerDecoder(nn.Module):
    """Stack of `TransformerDecoderLayer`s with iterative box refinement.

    Per layer: build positional embeddings from the current reference boxes,
    optionally compute a boxRPB relative-position attention bias, run the
    layer (optionally under activation checkpointing), then refine the
    reference boxes with a box head. Intermediate outputs of every layer are
    returned (``return_intermediate`` is required), together with per-layer
    presence logits when a presence token is enabled.

    Fix vs. original: the presence-logit clamp result is now assigned back —
    `Tensor.clamp` is out-of-place, so the original guarded clamp was a no-op.
    """

    def __init__(
        self,
        d_model: int,
        frozen: bool,
        interaction_layer,
        layer,
        num_layers: int,
        num_queries: int,
        return_intermediate: bool,
        box_refine: bool = False,
        num_o2m_queries: int = 0,
        dac: bool = False,
        boxRPB: str = "none",
        # Experimental: An object query for SAM 2 tasks
        instance_query: bool = False,
        # Defines the number of additional instance queries,
        # 1 or 4 are the most likely for single vs multi mask support
        num_instances: int = 1,  # Irrelevant if instance_query is False
        dac_use_selfatt_ln: bool = True,
        use_act_checkpoint: bool = False,
        compile_mode=None,
        presence_token: bool = False,
        clamp_presence_logits: bool = True,
        clamp_presence_logit_max_val: float = 10.0,
        use_normed_output_consistently: bool = True,
        separate_box_head_instance: bool = False,
        separate_norm_instance: bool = False,
        resolution: Optional[int] = None,
        stride: Optional[int] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.fine_layers = (
            get_clones(interaction_layer, num_layers)
            if interaction_layer is not None
            else [None] * num_layers
        )
        self.num_layers = num_layers
        self.num_queries = num_queries
        self.dac = dac
        if dac:
            # DAC duplicates the queries inside forward, so no extra o2m
            # queries are allocated here.
            self.num_o2m_queries = num_queries
            tot_num_queries = num_queries
        else:
            self.num_o2m_queries = num_o2m_queries
            tot_num_queries = num_queries + num_o2m_queries
        self.norm = nn.LayerNorm(d_model)
        self.return_intermediate = return_intermediate
        self.bbox_embed = MLP(d_model, d_model, 4, 3)
        self.query_embed = nn.Embedding(tot_num_queries, d_model)
        self.instance_query_embed = None
        self.instance_query_reference_points = None
        self.use_instance_query = instance_query
        self.num_instances = num_instances
        self.use_normed_output_consistently = use_normed_output_consistently
        self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None
        self.instance_bbox_embed = None
        if separate_box_head_instance:
            self.instance_bbox_embed = MLP(d_model, d_model, 4, 3)
        if instance_query:
            self.instance_query_embed = nn.Embedding(num_instances, d_model)
        self.box_refine = box_refine
        if box_refine:
            # Zero-init the last box-head layer so refinement starts as identity.
            nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
            self.reference_points = nn.Embedding(num_queries, 4)
            if instance_query:
                self.instance_reference_points = nn.Embedding(num_instances, 4)
        assert boxRPB in ["none", "log", "linear", "both"]
        self.boxRPB = boxRPB
        if boxRPB != "none":
            try:
                nheads = self.layers[0].cross_attn_image.num_heads
            except AttributeError:
                nheads = self.layers[0].cross_attn.num_heads
            # "both" concatenates linear and log deltas -> 4 inputs, else 2.
            n_input = 4 if boxRPB == "both" else 2
            self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2)
            self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2)
            self.compilable_cord_cache = None
            self.compilable_stored_size = None
            self.coord_cache = {}
            if resolution is not None and stride is not None:
                # Pre-warm the coordinate cache for the expected feature size.
                # NOTE(review): coords are created on "cuda" here — assumes a
                # GPU is available at construction time.
                feat_size = resolution // stride
                coords_h, coords_w = self._get_coords(
                    feat_size, feat_size, device="cuda"
                )
                self.compilable_cord_cache = (coords_h, coords_w)
                self.compilable_stored_size = (feat_size, feat_size)
        self.roi_pooler = (
            RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True)
            if interaction_layer is not None
            else None
        )
        if frozen:
            for p in self.parameters():
                p.requires_grad_(False)
        self.presence_token = None
        self.clamp_presence_logits = clamp_presence_logits
        self.clamp_presence_logit_max_val = clamp_presence_logit_max_val
        if presence_token:
            self.presence_token = nn.Embedding(1, d_model)
            self.presence_token_head = MLP(d_model, d_model, 1, 3)
            self.presence_token_out_norm = nn.LayerNorm(d_model)
        self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2)
        self.dac_use_selfatt_ln = dac_use_selfatt_ln
        self.use_act_checkpoint = use_act_checkpoint
        nn.init.normal_(self.query_embed.weight.data)
        if self.instance_query_embed is not None:
            nn.init.normal_(self.instance_query_embed.weight.data)
        assert self.roi_pooler is None
        assert self.return_intermediate, "support return_intermediate only"
        assert self.box_refine, "support box refine only"
        self.compile_mode = compile_mode
        self.compiled = False
        # We defer compilation till after the first forward, to first warm-up the boxRPB cache
        # assign layer index to each layer so that some layers can decide what to do
        # based on which layer index they are (e.g. cross attention to memory bank only
        # in selected layers)
        for layer_idx, layer in enumerate(self.layers):
            layer.layer_idx = layer_idx

    @staticmethod
    def _get_coords(H, W, device):
        """Normalized (0..1) pixel-center coordinates for an H x W grid."""
        coords_h = torch.arange(0, H, device=device, dtype=torch.float32) / H
        coords_w = torch.arange(0, W, device=device, dtype=torch.float32) / W
        return coords_h, coords_w

    def _get_rpb_matrix(self, reference_boxes, feat_size):
        """Compute the boxRPB attention bias (bs*..., n_heads, nq, H*W).

        Encodes, per query, the signed distances of every feature-map row and
        column from the query's box edges, passed through small MLPs and
        summed into a per-head additive attention bias.
        """
        H, W = feat_size
        boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes).transpose(0, 1)
        bs, num_queries, _ = boxes_xyxy.shape
        if self.compilable_cord_cache is None:
            self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device)
            self.compilable_stored_size = (H, W)
        if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
            H,
            W,
        ):
            # good, hitting the cache, will be compilable
            coords_h, coords_w = self.compilable_cord_cache
        else:
            # cache miss, will create compilation issue
            # In case we're not compiling, we'll still rely on the dict-based cache
            if feat_size not in self.coord_cache:
                self.coord_cache[feat_size] = self._get_coords(
                    H, W, reference_boxes.device
                )
            coords_h, coords_w = self.coord_cache[feat_size]
        assert coords_h.shape == (H,)
        assert coords_w.shape == (W,)
        # Distances from each row/column to the box's (y0, y1) / (x0, x1).
        deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
        deltas_y = deltas_y.view(bs, num_queries, -1, 2)
        deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
        deltas_x = deltas_x.view(bs, num_queries, -1, 2)
        if self.boxRPB in ["log", "both"]:
            # Signed log-scaled variant of the deltas.
            deltas_x_log = deltas_x * 8  # normalize to -8, 8
            deltas_x_log = (
                torch.sign(deltas_x_log)
                * torch.log2(torch.abs(deltas_x_log) + 1.0)
                / np.log2(8)
            )
            deltas_y_log = deltas_y * 8  # normalize to -8, 8
            deltas_y_log = (
                torch.sign(deltas_y_log)
                * torch.log2(torch.abs(deltas_y_log) + 1.0)
                / np.log2(8)
            )
            if self.boxRPB == "log":
                deltas_x = deltas_x_log
                deltas_y = deltas_y_log
            else:
                deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1)
                deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1)
        if self.training:
            assert self.use_act_checkpoint, "activation ckpt not enabled in decoder"
        deltas_x = activation_ckpt_wrapper(self.boxRPB_embed_x)(
            x=deltas_x,
            act_ckpt_enable=self.training and self.use_act_checkpoint,
        )  # bs, num_queries, W, n_heads
        deltas_y = activation_ckpt_wrapper(self.boxRPB_embed_y)(
            x=deltas_y,
            act_ckpt_enable=self.training and self.use_act_checkpoint,
        )  # bs, num_queries, H, n_heads
        if not torch.compiler.is_dynamo_compiling():
            assert deltas_x.shape[:3] == (bs, num_queries, W)
            assert deltas_y.shape[:3] == (bs, num_queries, H)
        B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(
            2
        )  # bs, num_queries, H, W, n_heads
        if not torch.compiler.is_dynamo_compiling():
            assert B.shape[:4] == (bs, num_queries, H, W)
        B = B.flatten(2, 3)  # bs, num_queries, H*W, n_heads
        B = B.permute(0, 3, 1, 2)  # bs, n_heads, num_queries, H*W
        B = B.contiguous()  # memeff attn likes ordered strides
        if not torch.compiler.is_dynamo_compiling():
            assert B.shape[2:] == (num_queries, H * W)
        return B

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        reference_boxes: Optional[Tensor] = None,  # num_queries, bs, 4
        # for memory
        level_start_index: Optional[Tensor] = None,  # num_levels
        spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
        valid_ratios: Optional[Tensor] = None,
        # for text
        memory_text: Optional[Tensor] = None,
        text_attention_mask: Optional[Tensor] = None,
        # if `apply_dac` is None, it will default to `self.dac`
        apply_dac: Optional[bool] = None,
        is_instance_prompt=False,
        decoder_extra_kwargs: Optional[Dict] = None,
        # ROI memory bank
        obj_roi_memory_feat=None,
        obj_roi_memory_mask=None,
        box_head_trk=None,
    ):
        """
        Input:
            - tgt: nq, bs, d_model
            - memory: \\sum{hw}, bs, d_model
            - pos: \\sum{hw}, bs, d_model
            - reference_boxes: nq, bs, 4 (after sigmoid)
            - valid_ratios/spatial_shapes: bs, nlevel, 2

        Returns:
            (stacked per-layer outputs, stacked per-layer reference boxes,
             stacked per-layer presence logits or None, presence features)
        """
        if memory_mask is not None:
            assert (
                self.boxRPB == "none"
            ), "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
        apply_dac = apply_dac if apply_dac is not None else self.dac
        if apply_dac:
            assert (tgt.shape[0] == self.num_queries) or (
                self.use_instance_query
                and (tgt.shape[0] == self.instance_query_embed.num_embeddings)
            )
            # Duplicate queries: first half o2o, second half o2m.
            tgt = tgt.repeat(2, 1, 1)
            # note that we don't tile tgt_mask, since DAC doesn't
            # use self-attention in o2m queries
            if reference_boxes is not None:
                assert (reference_boxes.shape[0] == self.num_queries) or (
                    self.use_instance_query
                    and (
                        reference_boxes.shape[0]
                        == self.instance_query_embed.num_embeddings
                    )
                )
                reference_boxes = reference_boxes.repeat(2, 1, 1)
        bs = tgt.shape[1]
        intermediate = []
        intermediate_presence_logits = []
        presence_feats = None
        if self.box_refine:
            if reference_boxes is None:
                # In this case, we're in a one-stage model, so we generate the reference boxes
                reference_boxes = self.reference_points.weight.unsqueeze(1)
                reference_boxes = (
                    reference_boxes.repeat(2, bs, 1)
                    if apply_dac
                    else reference_boxes.repeat(1, bs, 1)
                )
                reference_boxes = reference_boxes.sigmoid()
            intermediate_ref_boxes = [reference_boxes]
        else:
            reference_boxes = None
            intermediate_ref_boxes = None
        output = tgt
        presence_out = None
        if self.presence_token is not None and is_instance_prompt is False:
            # expand to batch dim
            presence_out = self.presence_token.weight[None].expand(1, bs, -1)
        # Select box head / output norm (instance prompts may use separate ones).
        box_head = self.bbox_embed
        if is_instance_prompt and self.instance_bbox_embed is not None:
            box_head = self.instance_bbox_embed
        out_norm = self.norm
        if is_instance_prompt and self.instance_norm is not None:
            out_norm = self.instance_norm
        for layer_idx, layer in enumerate(self.layers):
            reference_points_input = (
                reference_boxes[:, :, None]
                * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
            )  # nq, bs, nlevel, 4
            query_sine_embed = gen_sineembed_for_position(
                reference_points_input[:, :, 0, :], self.d_model
            )  # nq, bs, d_model*2
            # conditional query
            query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, d_model
            if self.boxRPB != "none" and reference_boxes is not None:
                assert (
                    spatial_shapes.shape[0] == 1
                ), "only single scale support implemented"
                memory_mask = self._get_rpb_matrix(
                    reference_boxes,
                    (spatial_shapes[0, 0], spatial_shapes[0, 1]),
                )
                memory_mask = memory_mask.flatten(0, 1)  # (bs*n_heads, nq, H*W)
            if self.training:
                assert (
                    self.use_act_checkpoint
                ), "Activation checkpointing not enabled in the decoder"
            output, presence_out = activation_ckpt_wrapper(layer)(
                tgt=output,
                tgt_query_pos=query_pos,
                tgt_query_sine_embed=query_sine_embed,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_reference_points=reference_points_input,
                memory_text=memory_text,
                text_attention_mask=text_attention_mask,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                memory_level_start_index=level_start_index,
                memory_spatial_shapes=spatial_shapes,
                memory_pos=pos,
                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask,
                dac=apply_dac,
                dac_use_selfatt_ln=self.dac_use_selfatt_ln,
                presence_token=presence_out,
                **(decoder_extra_kwargs or {}),
                act_ckpt_enable=self.training and self.use_act_checkpoint,
                # ROI memory bank
                obj_roi_memory_feat=obj_roi_memory_feat,
                obj_roi_memory_mask=obj_roi_memory_mask,
            )
            # iter update
            if self.box_refine:
                reference_before_sigmoid = inverse_sigmoid(reference_boxes)
                if box_head_trk is None:
                    # delta_unsig = self.bbox_embed(output)
                    if not self.use_normed_output_consistently:
                        delta_unsig = box_head(output)
                    else:
                        delta_unsig = box_head(out_norm(output))
                else:
                    # box_head_trk use a separate box head for tracking queries
                    Q_det = decoder_extra_kwargs["Q_det"]
                    assert output.size(0) >= Q_det
                    delta_unsig_det = self.bbox_embed(output[:Q_det])
                    delta_unsig_trk = box_head_trk(output[Q_det:])
                    delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()
                # Detach so box refinement gradients stay layer-local.
                reference_boxes = new_reference_points.detach()
                if layer_idx != self.num_layers - 1:
                    intermediate_ref_boxes.append(new_reference_points)
            else:
                raise NotImplementedError("not implemented yet")
            intermediate.append(out_norm(output))
            if self.presence_token is not None and is_instance_prompt is False:
                # norm, mlp head
                intermediate_layer_presence_logits = self.presence_token_head(
                    self.presence_token_out_norm(presence_out)
                ).squeeze(-1)
                # clamp to mitigate numerical issues
                # (Tensor.clamp is out-of-place: the result must be assigned back,
                # otherwise the clamp is a no-op.)
                if self.clamp_presence_logits:
                    intermediate_layer_presence_logits = (
                        intermediate_layer_presence_logits.clamp(
                            min=-self.clamp_presence_logit_max_val,
                            max=self.clamp_presence_logit_max_val,
                        )
                    )
                intermediate_presence_logits.append(intermediate_layer_presence_logits)
                presence_feats = presence_out.clone()
        if not self.compiled and self.compile_mode is not None:
            # Compile after the first (cache-warming) forward pass.
            self.forward = torch.compile(
                self.forward, mode=self.compile_mode, fullgraph=True
            )
            self.compiled = True
        return (
            torch.stack(intermediate),
            torch.stack(intermediate_ref_boxes),
            (
                torch.stack(intermediate_presence_logits)
                if self.presence_token is not None and is_instance_prompt is False
                else None
            ),
            presence_feats,
        )
class TransformerEncoderCrossAttention(nn.Module):
    """Stack of layers that self-attend over ``src`` and cross-attend to ``prompt``.

    Each layer runs self-attention over the input sequence and cross-attention
    to the prompt features. Cross-attention can be stripped from selected
    layers via ``remove_cross_attention_layers``, in which case those layers'
    cross-attention submodules (and their norm/dropout) are set to ``None``.
    """

    def __init__(
        self,
        d_model: int,
        frozen: bool,
        pos_enc_at_input: bool,
        layer,
        num_layers: int,
        use_act_checkpoint: bool = False,
        batch_first: bool = False,  # Do layers expect batch first input?
        # which layers to exclude cross attention? default: None, means all
        # layers use cross attention
        remove_cross_attention_layers: Optional[list] = None,
    ):
        """
        Args:
            d_model: Hidden size of the stack (also the LayerNorm dim).
            frozen: If True, freeze all parameters of this module.
            pos_enc_at_input: If True, add (down-weighted) positional encodings
                to the input once, before the first layer.
            layer: Layer module to be cloned ``num_layers`` times.
            num_layers: Number of stacked layers.
            use_act_checkpoint: Enable activation checkpointing during training.
            batch_first: Whether the layers expect batch-first tensors.
            remove_cross_attention_layers: Indices of layers whose
                cross-attention should be removed.
        """
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.use_act_checkpoint = use_act_checkpoint
        if frozen:
            # Freeze every parameter (e.g. when reusing pretrained weights).
            for p in self.parameters():
                p.requires_grad_(False)
        self.batch_first = batch_first
        # remove cross attention layers if specified
        self.remove_cross_attention_layers = [False] * self.num_layers
        if remove_cross_attention_layers is not None:
            for i in remove_cross_attention_layers:
                self.remove_cross_attention_layers[i] = True
        assert len(self.remove_cross_attention_layers) == len(self.layers)
        for i, remove_cross_attention in enumerate(self.remove_cross_attention_layers):
            if remove_cross_attention:
                # Drop the cross-attention sub-block (and its norm/dropout)
                # from this layer entirely.
                self.layers[i].cross_attn_image = None
                self.layers[i].norm2 = None
                self.layers[i].dropout2 = None

    def forward(
        self,
        src,  # self-attention inputs
        prompt,  # cross-attention inputs
        src_mask: Optional[Tensor] = None,  # att.mask for self-attention inputs
        prompt_mask: Optional[Tensor] = None,  # att.mask for cross-attention inputs
        src_key_padding_mask: Optional[Tensor] = None,
        prompt_key_padding_mask: Optional[Tensor] = None,
        src_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
        prompt_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
        feat_sizes: Optional[list] = None,
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
    ):
        """Run the layer stack over ``src`` with cross-attention to ``prompt``.

        Inputs are sequence-first (batch on dim 1); they are transposed in and
        out when ``batch_first`` is set. ``src`` may arrive wrapped in a
        one-element list (single-level features).

        Returns:
            Dict with keys "memory" (normalized output, seq-first),
            "pos_embed" (the src positional encodings), and "padding_mask".
        """
        if isinstance(src, list):
            # Single-level features may arrive as one-element lists; unwrap.
            assert isinstance(src_key_padding_mask, list) and isinstance(src_pos, list)
            assert len(src) == len(src_key_padding_mask) == len(src_pos) == 1
            src, src_key_padding_mask, src_pos = (
                src[0],
                src_key_padding_mask[0],
                src_pos[0],
            )
        assert (
            src.shape[1] == prompt.shape[1]
        ), "Batch size must be the same for src and prompt"
        output = src
        if self.pos_enc_at_input and src_pos is not None:
            # NOTE(review): positional encoding is down-weighted by a fixed
            # 0.1 factor — presumably tuned; confirm before changing.
            output = output + 0.1 * src_pos
        if self.batch_first:
            # Convert to batch first
            output = output.transpose(0, 1)
            src_pos = src_pos.transpose(0, 1)
            prompt = prompt.transpose(0, 1)
            prompt_pos = prompt_pos.transpose(0, 1)
        for layer in self.layers:
            kwds = {}
            if isinstance(layer.cross_attn_image, RoPEAttention):
                # RoPE layers must know how many trailing keys (object pointer
                # tokens) to exclude from the rotary position encoding.
                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
            output = activation_ckpt_wrapper(layer)(
                tgt=output,
                memory=prompt,
                tgt_mask=src_mask,
                memory_mask=prompt_mask,
                tgt_key_padding_mask=src_key_padding_mask,
                memory_key_padding_mask=prompt_key_padding_mask,
                pos=prompt_pos,
                query_pos=src_pos,
                dac=False,
                attn_bias=None,
                act_ckpt_enable=self.training and self.use_act_checkpoint,
                **kwds,
            )
        normed_output = self.norm(output)
        if self.batch_first:
            # Convert back to seq first
            normed_output = normed_output.transpose(0, 1)
            src_pos = src_pos.transpose(0, 1)
        return {
            "memory": normed_output,
            "pos_embed": src_pos,
            "padding_mask": src_key_padding_mask,
        }
class TransformerDecoderLayerv1(nn.Module):
    """Decoder layer: self-attention, cross-attention to image memory, FFN.

    ``pre_norm`` selects between pre-normalization (LayerNorm before each
    sub-block) and post-normalization (LayerNorm after each residual add).
    Positional encodings can be injected independently at self-attention and
    at the queries/keys of cross-attention.
    """

    def __init__(
        self,
        activation: str,
        cross_attention: nn.Module,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        pos_enc_at_attn: bool,
        pos_enc_at_cross_attn_keys: bool,
        pos_enc_at_cross_attn_queries: bool,
        pre_norm: bool,
        self_attention: nn.Module,
    ):
        """See class docstring; ``activation`` names the FFN non-linearity."""
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        self.self_attn = self_attention
        self.cross_attn_image = cross_attention
        # Feedforward sub-block: linear -> activation -> dropout -> linear.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # One LayerNorm + residual-dropout pair per sub-block.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation_str = activation
        self.activation = get_activation_fn(activation)
        self.pre_norm = pre_norm
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        **kwargs,
    ):
        """Post-norm path: each LayerNorm follows its residual addition."""
        # Self-attention sub-block.
        if self.pos_enc_at_attn:
            q = k = tgt + query_pos
        else:
            q = k = tgt
        sa_out = self.self_attn(
            q,
            k,
            value=tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        tgt = self.norm1(tgt + self.dropout1(sa_out))
        # Cross-attention to the image memory.
        ca_query = tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt
        ca_key = memory + pos if self.pos_enc_at_cross_attn_keys else memory
        ca_out = self.cross_attn_image(
            query=ca_query,
            key=ca_key,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = self.norm2(tgt + self.dropout2(ca_out))
        # Feedforward sub-block.
        ffn_out = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        return self.norm3(tgt + self.dropout3(ffn_out))

    def forward_pre(
        self,
        tgt,
        memory,
        dac: bool = False,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        **kwargs,
    ):
        """Pre-norm path; ``dac`` self-attends only the first half of the queries."""
        if dac:
            # Divide-and-conquer: self-attend the first half of the queries
            # only, then reattach the untouched half before cross-attention.
            assert tgt.shape[0] % 2 == 0
            half = tgt.shape[0] // 2
            other_tgt = tgt[half:]
            tgt = tgt[:half]
        normed = self.norm1(tgt)
        if self.pos_enc_at_attn:
            q = k = normed + query_pos
        else:
            q = k = normed
        sa_out = self.self_attn(
            q,
            k,
            value=normed,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout1(sa_out)
        if dac:
            # Recombine both halves.
            tgt = torch.cat((tgt, other_tgt), dim=0)
        normed = self.norm2(tgt)
        ca_out = self.cross_attn_image(
            query=normed + query_pos if self.pos_enc_at_cross_attn_queries else normed,
            key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
            attn_bias=attn_bias,
        )[0]
        tgt = tgt + self.dropout2(ca_out)
        normed = self.norm3(tgt)
        ffn_out = self.linear2(self.dropout(self.activation(self.linear1(normed))))
        return tgt + self.dropout3(ffn_out)

    def forward(
        self,
        tgt,
        memory,
        dac: bool = False,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        **kwds: Any,
    ) -> torch.Tensor:
        """Dispatch to the pre-norm or post-norm implementation."""
        if self.pre_norm:
            fwd = self.forward_pre
        else:
            fwd = self.forward_post
        return fwd(
            tgt,
            memory,
            dac=dac,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            pos=pos,
            query_pos=query_pos,
            attn_bias=attn_bias,
            **kwds,
        )
class TransformerDecoderLayerv2(TransformerDecoderLayerv1):
    """Pre-norm-only decoder layer variant with optional cross-attention-first order.

    Differences from v1:
      * the sub-block order can be swapped (cross-attention before self-attention),
      * attention modules are invoked with ``q``/``k``/``v`` keywords and the
        result is used directly (v1 indexes ``[0]`` into a returned tuple),
      * RoPE cross-attention is supported via ``num_k_exclude_rope``.
    Only the pre-norm path is implemented (see ``forward``).
    """

    def __init__(self, cross_attention_first=False, *args: Any, **kwds: Any):
        # NOTE(review): ``cross_attention_first`` precedes *args, so a caller
        # passing the parent's arguments positionally would have the first one
        # captured here — construct this class with keyword arguments.
        super().__init__(*args, **kwds)
        # If True, run cross-attention before self-attention.
        self.cross_attention_first = cross_attention_first

    def _forward_sa(self, tgt, query_pos):
        """Pre-norm self-attention sub-block with residual connection."""
        # Self-Attention
        tgt2 = self.norm1(tgt)
        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
        # NOTE(review): unlike v1 the result is not indexed with [0] — assumes
        # self_attn returns a tensor, not an (output, weights) tuple; verify
        # against the attention module used.
        tgt2 = self.self_attn(q, k, v=tgt2)
        tgt = tgt + self.dropout1(tgt2)
        return tgt

    def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
        """Pre-norm cross-attention sub-block; no-op when cross-attention is removed."""
        if self.cross_attn_image is None:
            # Cross-attention was stripped from this layer (see
            # remove_cross_attention_layers handling in the owning stack).
            return tgt
        kwds = {}
        if num_k_exclude_rope > 0:
            # Only RoPE attention understands this argument.
            assert isinstance(self.cross_attn_image, RoPEAttention)
            kwds = {"num_k_exclude_rope": num_k_exclude_rope}
        # Cross-Attention
        tgt2 = self.norm2(tgt)
        tgt2 = self.cross_attn_image(
            q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
            k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            v=memory,
            **kwds,
        )
        tgt = tgt + self.dropout2(tgt2)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        dac: bool,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
    ):
        """Pre-norm forward pass.

        Options carried over from the v1 signature (dac, masks, attn_bias) are
        asserted off — this variant does not support them.
        """
        assert dac is False
        assert tgt_mask is None
        assert memory_mask is None
        assert tgt_key_padding_mask is None
        assert memory_key_padding_mask is None
        assert attn_bias is None
        if self.cross_attention_first:
            tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
            tgt = self._forward_sa(tgt, query_pos)
        else:
            tgt = self._forward_sa(tgt, query_pos)
            tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
        # MLP
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, *args: Any, **kwds: Any) -> torch.Tensor:
        """Dispatch; only the pre-norm configuration is implemented."""
        if self.pre_norm:
            return self.forward_pre(*args, **kwds)
        raise NotImplementedError

173
sam3/model/edt.py Normal file
View File

@@ -0,0 +1,173 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""Triton kernel for euclidean distance transform (EDT)"""
import torch
import triton
import triton.language as tl
"""
Disclaimer: This implementation is not meant to be extremely efficient. A CUDA kernel would likely be more efficient.
Even in Triton, there may be more suitable algorithms.
The goal of this kernel is to mimic cv2.distanceTransform(input, cv2.DIST_L2, 0).
Recall that the euclidean distance transform (EDT) calculates the L2 distance to the closest zero pixel for each pixel of the source image.
For images of size NxN, the naive algorithm would be to compute pairwise distances between every pair of points, leading to a O(N^4) algorithm, which is obviously impractical.
One can do better using the following approach:
- First, compute the distance to the closest point in the same row. We can write it as Row_EDT[i,j] = min_k (sqrt((k-j)^2) if input[i,k]==0 else +infinity). With a naive implementation, this step has a O(N^3) complexity
- Then, because of triangular inequality, we notice that the EDT for a given location [i,j] is the min of the row EDTs in the same column. EDT[i,j] = min_k Row_EDT[k, j]. This is also O(N^3)
Overall, this algorithm is quite amenable to parallelization, and has a complexity O(N^3). Can we do better?
It turns out that we can leverage the structure of the L2 distance (nice and convex) to find the minimum in a more efficient way.
We follow the algorithm from "Distance Transforms of Sampled Functions" (https://cs.brown.edu/people/pfelzens/papers/dt-final.pdf), which is also what's implemented in opencv
For a single dimension EDT, we can compute the EDT of an arbitrary function F, that we discretize over the grid. Note that for the binary EDT that we're interested in, we can set F(i,j) = 0 if input[i,j]==0 else +infinity
For now, we'll compute the EDT squared, and will take the sqrt only at the very end.
The basic idea is that each point at location i spawns a parabola around itself, with a bias equal to F(i). So specifically, we're looking at the parabola (x - i)^2 + F(i)
When we're looking for the row EDT at location j, we're effectively looking for min_i (j-i)^2 + F(i). In other words, we want to find the lowest parabola at location j.
To do this efficiently, we need to maintain the lower envelope of the union of parabolas. This can be constructed on the fly using a sort of stack approach:
- every time we want to add a new parabola, we check if it may be covering the current right-most parabola. If so, then that parabola was useless, so we can pop it from the stack
- repeat until we can't find any more parabola to pop. Then push the new one.
This algorithm runs in O(N) for a single row, so overall O(N^2) when applied to all rows
Similarly as before, we notice that we can decompose the algorithm for rows and columns, leading to an overall run-time of O(N^2)
This algorithm is less well suited to GPUs, since the one-dimensional EDT computation is quite sequential in nature. However, we can parallelize over batch and row dimensions.
In Triton, things are particularly bad at the moment, since there is no support for reading/writing to the local memory at a specific index (a local gather is coming soon, see https://github.com/triton-lang/triton/issues/974, but no mention of writing, ie scatter)
One could emulate these operations with masking, but in initial tests, it proved to be worse than naively reading and writing to the global memory. My guess is that the cache is compensating somewhat for the repeated single-point accesses.
The timing obtained on a H100 for a random batch of masks of dimension 256 x 1024 x 1024 are as follows:
- OpenCV: 1780ms (including round-trip to cpu, but discounting the fact that it introduces a synchronization point)
- triton, O(N^3) algo: 627ms
- triton, O(N^2) algo: 322ms
Overall, despite being quite naive, this implementation is roughly 5.5x faster than the openCV cpu implem
"""
@triton.jit
def edt_kernel(inputs_ptr, outputs_ptr, v, z, height, width, horizontal: tl.constexpr):
    """Single-direction 1D squared-EDT pass (lower envelope of parabolas).

    Each program instance processes one slice: a row when ``horizontal`` is
    True, a column otherwise. ``v`` stores the grid locations of the parabolas
    forming the lower envelope (used as a per-slice stack) and ``z`` the
    boundaries between consecutive envelope parabolas. Outputs are *squared*
    distances; the caller takes the square root after the final pass.
    """
    # This is a somewhat verbatim implementation of the efficient 1D EDT algorithm described above
    # It can be applied horizontally or vertically depending if we're doing the first or second stage.
    # It's parallelized across batch+row (or batch+col if horizontal=False)
    # TODO: perhaps the implementation can be revisited if/when local gather/scatter become available in triton
    batch_id = tl.program_id(axis=0)
    if horizontal:
        row_id = tl.program_id(axis=1)
        block_start = (batch_id * height * width) + row_id * width
        length = width
        stride = 1
    else:
        col_id = tl.program_id(axis=1)
        block_start = (batch_id * height * width) + col_id
        length = height
        stride = width
    # This will be the index of the right most parabola in the envelope ("the top of the stack")
    k = 0
    for q in range(1, length):
        # Read the function value at the current location. Note that we're doing a singular read, not very efficient
        cur_input = tl.load(inputs_ptr + block_start + (q * stride))
        # location of the parabola on top of the stack
        r = tl.load(v + block_start + (k * stride))
        # associated boundary
        z_k = tl.load(z + block_start + (k * stride))
        # value of the function at the parabola location
        previous_input = tl.load(inputs_ptr + block_start + (r * stride))
        # intersection between the two parabolas
        s = (cur_input - previous_input + q * q - r * r) / (q - r) / 2
        # we'll pop as many parabolas as required
        # (a stacked parabola is useless once the new one's intersection falls
        # at or before that parabola's own left boundary)
        while s <= z_k and k - 1 >= 0:
            k = k - 1
            r = tl.load(v + block_start + (k * stride))
            z_k = tl.load(z + block_start + (k * stride))
            previous_input = tl.load(inputs_ptr + block_start + (r * stride))
            s = (cur_input - previous_input + q * q - r * r) / (q - r) / 2
        # Store the new one
        k = k + 1
        tl.store(v + block_start + (k * stride), q)
        tl.store(z + block_start + (k * stride), s)
        if k + 1 < length:
            # Sentinel boundary so the read phase never walks past the top.
            tl.store(z + block_start + ((k + 1) * stride), 1e9)
    # Last step, we read the envelope to find the min in every location
    k = 0
    for q in range(length):
        while (
            k + 1 < length
            and tl.load(
                z + block_start + ((k + 1) * stride), mask=(k + 1) < length, other=q
            )
            < q
        ):
            k += 1
        r = tl.load(v + block_start + (k * stride))
        d = q - r
        old_value = tl.load(inputs_ptr + block_start + (r * stride))
        tl.store(outputs_ptr + block_start + (q * stride), old_value + d * d)
def edt_triton(data: torch.Tensor):
    """
    Computes the Euclidean Distance Transform (EDT) of a batch of binary images.

    Args:
        data: A tensor of shape (B, H, W) representing a batch of binary images
            (nonzero = foreground, zero = background). Must live on a CUDA device.

    Returns:
        A float tensor of the same shape as data containing the EDT (the L2
        distance from each pixel to the closest zero pixel). It should be
        equivalent to a batched version of
        cv2.distanceTransform(input, cv2.DIST_L2, 0).
    """
    assert data.dim() == 3
    assert data.is_cuda
    B, H, W = data.shape
    data = data.contiguous()
    # "Function" tensor of the Felzenszwalb-Huttenlocher formulation:
    # implicitly F is 0 where data[i,j]==0 else +infinity (1e18 stands in).
    output = torch.where(data, 1e18, 0.0)
    assert output.is_contiguous()
    # Scratch tensors for the parabola stacks: locations (v) and lower-envelope
    # intersection boundaries (z).
    parabola_loc = torch.zeros(B, H, W, dtype=torch.uint32, device=data.device)
    parabola_inter = torch.empty(B, H, W, dtype=torch.float, device=data.device)
    # Horizontal pass: each (batch, row) slice is a stack laid out along W,
    # so stack entries 0 and 1 sit on the last axis.
    parabola_inter[:, :, 0] = -1e18
    parabola_inter[:, :, 1] = 1e18
    # Grid size (number of blocks)
    grid = (B, H)
    edt_kernel[grid](
        output.clone(),
        output,
        parabola_loc,
        parabola_inter,
        H,
        W,
        horizontal=True,
    )
    # Reset the parabola stacks for the vertical pass. The kernel now walks
    # each (batch, column) slice with stride W, so stack entries 0 and 1 of
    # column j live at rows 0 and 1 — initialize along the row axis (the
    # previous [:, :, 0/1] indexing targeted the wrong axis for this pass;
    # it was masked at runtime by the kernel's k-1 >= 0 guard).
    parabola_loc.zero_()
    parabola_inter[:, 0, :] = -1e18
    parabola_inter[:, 1, :] = 1e18
    grid = (B, W)
    edt_kernel[grid](
        output.clone(),
        output,
        parabola_loc,
        parabola_inter,
        H,
        W,
        horizontal=False,
    )
    # The passes accumulate squared distances; take the sqrt at the very end.
    return output.sqrt()

594
sam3/model/encoder.py Normal file
View File

@@ -0,0 +1,594 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# Based on https://github.com/IDEA-Research/GroundingDINO
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import nn, Tensor
from .act_ckpt_utils import activation_ckpt_wrapper
from .model_misc import get_activation_fn, get_clones, get_valid_ratio
class TransformerEncoderLayer(nn.Module):
    """
    Encoder layer performing self-attention followed by cross-attention.

    Previously named TransformerDecoderLayer; renamed to better reflect its
    role in this architecture. The layer attends over its own input sequence,
    then cross-attends to a second input (typically image features), and
    finishes with a feedforward network. Both pre-norm and post-norm
    arrangements are supported, and positional encodings can optionally be
    injected at self-attention and at the queries/keys of cross-attention.
    """

    def __init__(
        self,
        activation: str,
        cross_attention: nn.Module,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        pos_enc_at_attn: bool,
        pos_enc_at_cross_attn_keys: bool,
        pos_enc_at_cross_attn_queries: bool,
        pre_norm: bool,
        self_attention: nn.Module,
    ):
        """
        Initialize a transformer encoder layer.

        Args:
            activation: Name of the FFN non-linearity.
            cross_attention: Module attending to the memory (image features).
            d_model: Hidden size of the layer.
            dim_feedforward: Width of the FFN's inner projection.
            dropout: Dropout probability used throughout the layer.
            pos_enc_at_attn: Add positional encodings before self-attention.
            pos_enc_at_cross_attn_keys: Add positional encodings to cross-attention keys.
            pos_enc_at_cross_attn_queries: Add positional encodings to cross-attention queries.
            pre_norm: Select pre-norm (True) or post-norm (False) ordering.
            self_attention: Module performing self-attention.
        """
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        self.self_attn = self_attention
        self.cross_attn_image = cross_attention
        # FFN: linear -> activation -> dropout -> linear.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # One LayerNorm + residual-dropout pair per sub-block.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation_str = activation
        self.activation = get_activation_fn(activation)
        self.pre_norm = pre_norm
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
        # Filled in by the owning encoder stack so a layer can condition on
        # its depth.
        self.layer_idx = None

    def forward_post(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        **kwargs,
    ) -> Tensor:
        """
        Post-norm path: normalization follows each residual addition.

        Args:
            tgt: Input tensor to be processed.
            memory: Memory tensor for cross-attention.
            tgt_mask / memory_mask: Attention masks for self-/cross-attention.
            tgt_key_padding_mask / memory_key_padding_mask: Key padding masks.
            pos: Positional encoding for memory.
            query_pos: Positional encoding for the queries.
            **kwargs: Ignored extras (keeps the call-through signature uniform).

        Returns:
            The processed tensor.
        """
        # Self-attention over tgt.
        if self.pos_enc_at_attn:
            q = k = tgt + query_pos
        else:
            q = k = tgt
        sa_out = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = self.norm1(tgt + self.dropout1(sa_out))
        # Cross-attention to the image memory.
        ca_query = tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt
        ca_key = memory + pos if self.pos_enc_at_cross_attn_keys else memory
        ca_out = self.cross_attn_image(
            query=ca_query,
            key=ca_key,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = self.norm2(tgt + self.dropout2(ca_out))
        # Feedforward network.
        ffn_out = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        return self.norm3(tgt + self.dropout3(ffn_out))

    def forward_pre(
        self,
        tgt: Tensor,
        memory: Tensor,
        dac: bool = False,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Pre-norm path: normalization precedes each sub-block.

        ``dac`` (divide-and-conquer) restricts self-attention to the first
        half of the queries; the second half is reattached before
        cross-attention.
        """
        if dac:
            # Split: only the first half of the queries self-attends.
            assert tgt.shape[0] % 2 == 0
            half = tgt.shape[0] // 2
            other_tgt = tgt[half:]
            tgt = tgt[:half]
        normed = self.norm1(tgt)
        if self.pos_enc_at_attn:
            q = k = normed + query_pos
        else:
            q = k = normed
        sa_out = self.self_attn(
            q, k, value=normed, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(sa_out)
        if dac:
            # Reattach the untouched half.
            tgt = torch.cat((tgt, other_tgt), dim=0)
        normed = self.norm2(tgt)
        ca_out = self.cross_attn_image(
            query=normed + query_pos if self.pos_enc_at_cross_attn_queries else normed,
            key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(ca_out)
        normed = self.norm3(tgt)
        ffn_out = self.linear2(self.dropout(self.activation(self.linear1(normed))))
        return tgt + self.dropout3(ffn_out)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        dac: bool = False,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ) -> torch.Tensor:
        """Dispatch to the pre-norm or post-norm implementation (see class docstring)."""
        if self.pre_norm:
            fwd = self.forward_pre
        else:
            fwd = self.forward_post
        return fwd(
            tgt,
            memory,
            dac=dac,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            pos=pos,
            query_pos=query_pos,
        )
class TransformerEncoder(nn.Module):
    """
    Transformer encoder that processes multi-level features.

    This encoder takes multi-level features (e.g., from a backbone network) and processes
    them through a stack of transformer encoder layers. It supports features from multiple
    levels (e.g., different resolutions) and can apply activation checkpointing for memory
    efficiency during training.

    Args:
        layer: The encoder layer to be stacked multiple times
        num_layers: Number of encoder layers to stack
        d_model: Model dimension/hidden size
        num_feature_levels: Number of feature levels to process
        frozen: Whether to freeze the parameters of this module
        use_act_checkpoint: Whether to use activation checkpointing during training
    """

    def __init__(
        self,
        layer: nn.Module,
        num_layers: int,
        d_model: int,
        num_feature_levels: int,
        frozen: bool = False,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        self.layers = get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.num_feature_levels = num_feature_levels
        # Learned per-level embedding; only needed when fusing several levels.
        self.level_embed = None
        if num_feature_levels > 1:
            self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
        if frozen:
            for p in self.parameters():
                p.requires_grad_(False)
        self.use_act_checkpoint = use_act_checkpoint
        # assign layer index to each layer so that some layers can decide what to do
        # based on which layer index they are (e.g. cross attention to memory bank only
        # in selected layers)
        for layer_idx, layer in enumerate(self.layers):
            layer.layer_idx = layer_idx

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Build normalized (x, y) reference points for every level's pixel grid."""
        with torch.no_grad():
            reference_points_list = []
            for lvl, (H_, W_) in enumerate(spatial_shapes):
                # Pixel-center coordinates, normalized by the valid (unpadded)
                # extent of each sample at this level.
                ref_y, ref_x = torch.meshgrid(
                    torch.linspace(
                        0.5, H_ - 0.5, H_, dtype=torch.float32, device=device
                    ),
                    torch.linspace(
                        0.5, W_ - 0.5, W_, dtype=torch.float32, device=device
                    ),
                    indexing="ij",  # explicit; matches the legacy default
                )
                ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
                ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
                ref = torch.stack((ref_x, ref_y), -1)
                reference_points_list.append(ref)
            reference_points = torch.cat(reference_points_list, 1)
            reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def _prepare_multilevel_features(self, srcs, masks, pos_embeds):
        """Flatten per-level features/masks/pos-embeds and gather level metadata.

        Args:
            srcs: list of (bs, c, h, w) feature maps, one per level.
            masks: optional list of (bs, h, w) padding masks; may be None
                (or a list of None entries) when no padding is present.
            pos_embeds: list of (bs, c, h, w) positional embeddings.

        Returns:
            Tuple of (src_flatten, mask_flatten, lvl_pos_embed_flatten,
            level_start_index, valid_ratios, spatial_shapes).
        """
        assert (
            len(srcs) == self.num_feature_levels
        ), "mismatch between expected and received # of feature levels"
        if masks is None:
            # The forward signature declares masks Optional; substitute
            # per-level None entries so the zip() below does not crash.
            masks = [None] * len(srcs)
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        has_mask = masks[0] is not None
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            src = src.flatten(2).transpose(1, 2)  # bs, hw, c
            if has_mask:
                mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
            if self.level_embed is not None:
                lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            else:
                lvl_pos_embed = pos_embed
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            if has_mask:
                mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)  # bs, \sum{hxw}, c
        mask_flatten = torch.cat(mask_flatten, 1) if has_mask else None  # bs, \sum{hxw}
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)  # bs, \sum{hxw}, c
        spatial_shapes = torch.tensor(
            spatial_shapes, dtype=torch.long, device=src_flatten.device
        )
        level_start_index = torch.cat(
            (
                spatial_shapes.new_zeros((1,)),
                spatial_shapes.prod(1).cumsum(0)[:-1],
            )
        )
        if has_mask:
            valid_ratios = torch.stack([get_valid_ratio(m) for m in masks], 1)
        else:
            # No padding information: treat every pixel as valid.
            valid_ratios = torch.ones(
                (src_flatten.shape[0], self.num_feature_levels, 2),
                device=src_flatten.device,
            )
        return (
            src_flatten,
            mask_flatten,
            lvl_pos_embed_flatten,
            level_start_index,
            valid_ratios,
            spatial_shapes,
        )

    def forward(
        self,
        src: List[Tensor],
        src_key_padding_masks: Optional[List[Tensor]] = None,
        pos: Optional[List[Tensor]] = None,
        prompt: Optional[Tensor] = None,
        prompt_key_padding_mask: Optional[Tensor] = None,
        encoder_extra_kwargs: Optional[Dict] = None,
    ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor, Tensor]:
        """
        Process multi-level features through the transformer encoder.

        Args:
            src: List of multi-level features, each with shape (batch_size, channels, height, width)
            src_key_padding_masks: List of padding masks for each feature level, each with shape (batch_size, height, width)
            pos: List of positional embeddings for each feature level, each with shape (batch_size, channels, height, width)
            prompt: Optional text/prompt features to attend to, with shape (seq_len, batch_size, d_model)
            prompt_key_padding_mask: Optional padding mask for prompt, with shape (batch_size, seq_len)
            encoder_extra_kwargs: Optional additional arguments to pass to each encoder layer

        Returns:
            A tuple containing:
                - output: Processed features with shape (seq_len, batch_size, d_model)
                - key_padding_masks_flatten: Flattened padding masks
                - lvl_pos_embed_flatten: Flattened positional embeddings
                - level_start_index: Starting indices for each feature level
                - spatial_shapes: Spatial dimensions of each feature level
                - valid_ratios: Valid ratios for each feature level
        """
        assert (
            len(src) == self.num_feature_levels
        ), "must be equal to num_feature_levels"
        if src_key_padding_masks is not None:
            assert len(src_key_padding_masks) == self.num_feature_levels
        if pos is not None:
            assert len(pos) == self.num_feature_levels
        # Flatten multilevel feats and add level pos embeds
        (
            src_flatten,
            key_padding_masks_flatten,
            lvl_pos_embed_flatten,
            level_start_index,
            valid_ratios,
            spatial_shapes,
        ) = self._prepare_multilevel_features(src, src_key_padding_masks, pos)
        # NOTE(review): reference_points is computed but not consumed here —
        # presumably kept for subclasses; confirm before removing.
        reference_points = self.get_reference_points(
            spatial_shapes, valid_ratios, device=src_flatten.device
        )
        output = src_flatten
        for layer in self.layers:
            layer_kwargs = {}
            assert isinstance(layer, TransformerEncoderLayer)
            layer_kwargs["memory"] = prompt
            layer_kwargs["memory_key_padding_mask"] = prompt_key_padding_mask
            layer_kwargs["query_pos"] = lvl_pos_embed_flatten
            layer_kwargs["tgt"] = output
            layer_kwargs["tgt_key_padding_mask"] = key_padding_masks_flatten
            if self.training:
                assert self.use_act_checkpoint, "activation ckpt not enabled in encoder"
            if encoder_extra_kwargs is not None:
                layer_kwargs.update(encoder_extra_kwargs)
            output = activation_ckpt_wrapper(layer)(
                **layer_kwargs,
                act_ckpt_enable=self.training and self.use_act_checkpoint,
            )
        # return as seq first
        return (
            output.transpose(0, 1),
            (
                key_padding_masks_flatten.transpose(0, 1)
                if key_padding_masks_flatten is not None
                else None
            ),
            lvl_pos_embed_flatten.transpose(0, 1),
            level_start_index,
            spatial_shapes,
            valid_ratios,
        )
class TransformerEncoderFusion(TransformerEncoder):
    """
    Transformer encoder that fuses text and image features.

    This encoder extends TransformerEncoder to handle both text and image features,
    with the ability to add pooled text features to image features for better
    cross-modal fusion. It supports torch.compile for performance optimization.

    Args:
        layer: The encoder layer to be stacked multiple times
        num_layers: Number of encoder layers to stack
        d_model: Model dimension/hidden size
        num_feature_levels: Number of feature levels to process
        add_pooled_text_to_img_feat: Whether to add pooled text features to image features
        pool_text_with_mask: Whether to use the mask when pooling text features
        compile_mode: Mode for torch.compile, or None to disable compilation
        **kwargs: Additional arguments to pass to the parent class
    """

    def __init__(
        self,
        layer: nn.Module,
        num_layers: int,
        d_model: int,
        num_feature_levels: int,
        add_pooled_text_to_img_feat: bool = True,
        pool_text_with_mask: bool = False,
        compile_mode: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(
            layer,
            num_layers,
            d_model,
            num_feature_levels,
            **kwargs,
        )
        self.add_pooled_text_to_img_feat = add_pooled_text_to_img_feat
        if self.add_pooled_text_to_img_feat:
            # Projects the pooled text vector before it is broadcast-added to
            # every spatial location of the image features.
            self.text_pooling_proj = nn.Linear(d_model, d_model)
        self.pool_text_with_mask = pool_text_with_mask
        if compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=compile_mode, fullgraph=True
            )

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        # Reference points are not used by this fusion encoder.
        return None

    def forward(
        self,
        src: List[Tensor],
        prompt: Tensor,
        src_key_padding_mask: Optional[List[Tensor]] = None,
        src_pos: Optional[List[Tensor]] = None,
        prompt_key_padding_mask: Optional[Tensor] = None,
        prompt_pos: Optional[Tensor] = None,
        feat_sizes: Optional[List[int]] = None,
        encoder_extra_kwargs: Optional[Dict] = None,
    ):
        """Fuse text (`prompt`) into multi-level image features (`src`) and run
        the parent encoder.

        `src`/`src_pos` are sequence-first (H*W, B, C) when `feat_sizes` is
        given, otherwise already (B, C, H, W). Returns a dict with the encoded
        memory, flattened masks/positional embeddings and level bookkeeping.
        """
        # Restore spatial shapes of vision features (inputs are sequence-first).
        bs = src[0].shape[1]  # seq first
        if feat_sizes is not None:
            assert len(feat_sizes) == len(src)
            if src_key_padding_mask is None:
                src_key_padding_mask = [None] * len(src)
            for i, (h, w) in enumerate(feat_sizes):
                src[i] = src[i].reshape(h, w, bs, -1).permute(2, 3, 0, 1)
                src_pos[i] = src_pos[i].reshape(h, w, bs, -1).permute(2, 3, 0, 1)
                src_key_padding_mask[i] = (
                    src_key_padding_mask[i].reshape(h, w, bs).permute(2, 0, 1)
                    if src_key_padding_mask[i] is not None
                    else None
                )
        else:
            # BUGFIX: the original compared the bound method `x.dim` (not its
            # result) to 4, which is always False and made this assert fail
            # unconditionally. Call the method instead.
            assert all(
                x.dim() == 4 for x in src
            ), "expected list of (bs, c, h, w) tensors"
        if self.add_pooled_text_to_img_feat:
            # Fusion: Add mean pooled text to image features
            pooled_text = pool_text_feat(
                prompt, prompt_key_padding_mask, self.pool_text_with_mask
            )
            pooled_text = self.text_pooling_proj(pooled_text)[
                ..., None, None
            ]  # prompt is seq first
            # NOTE: in-place add mutates the caller's tensors (avoids a copy).
            src = [x.add_(pooled_text) for x in src]
        (
            out,
            key_padding_masks_flatten,
            lvl_pos_embed_flatten,
            level_start_index,
            spatial_shapes,
            valid_ratios,
        ) = super().forward(
            src,
            src_key_padding_masks=src_key_padding_mask,
            pos=src_pos,
            prompt=prompt.transpose(0, 1),
            prompt_key_padding_mask=prompt_key_padding_mask,
            encoder_extra_kwargs=encoder_extra_kwargs,
        )
        return {
            "memory": out,
            "padding_mask": key_padding_masks_flatten,
            "pos_embed": lvl_pos_embed_flatten,
            "memory_text": prompt,
            "level_start_index": level_start_index,
            "spatial_shapes": spatial_shapes,
            "valid_ratios": valid_ratios,
        }
def pool_text_feat(prompt, prompt_mask, pool_with_mask):
    """Mean-pool a sequence-first text tensor over its sequence axis.

    :param prompt: (seq, bs, dim) text features.
    :param prompt_mask: (bs, seq) bool mask, True marks padding. May be None
        when pool_with_mask is False.
    :param pool_with_mask: if True, average only over non-padded tokens.
    :return: (bs, dim) pooled features.
    """
    if not pool_with_mask:
        # Simple unmasked average over the sequence dimension.
        return prompt.mean(dim=0)
    assert prompt_mask.dim() == 2
    # (seq, bs, 1) indicator: 1.0 for valid tokens, 0.0 for padding.
    valid = (~prompt_mask).float().transpose(0, 1).unsqueeze(-1)
    # Avoid division by zero for fully-padded rows.
    denom = torch.clamp(valid.sum(dim=0), min=1.0)
    return (prompt * valid).sum(dim=0) / denom

View File

@@ -0,0 +1,850 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Tuple
import torch
import torch.nn as nn
import torchvision
from typing_extensions import override
from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .model_misc import get_clones
def is_right_padded(mask):
    """Given a padding mask (following pytorch convention, 1s for padded values),
    returns whether the padding is on the right or not."""
    as_int = mask.long()
    ascending, _ = torch.sort(as_int, dim=-1)
    # Right padding means every row is already sorted ascending (0s then 1s).
    return (as_int == ascending).all()
def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
    """Concatenate two right-padded sequences into one contiguous, right-padded one.

    Following pytorch's convention, tensors are sequence first and masks are
    batch first, with 1s (True) for padded values.

    :param seq1: (seq1_length, batch_size, hidden_size) tensor.
    :param mask1: (batch_size, seq1_length) bool mask.
    :param seq2: (seq2_length, batch_size, hidden_size) tensor.
    :param mask2: (batch_size, seq2_length) bool mask.
    :param return_index: if True, also return the positions that the elements of
        seq2 occupy in the concatenated sequence (useful to gather them back).
    :return: (concatenated_sequence, concatenated_mask) or, with return_index,
        (concatenated_sequence, concatenated_mask, index).
    """
    len1, batch, hidden = seq1.shape
    len2 = seq2.size(0)
    assert batch == seq2.size(1) == mask1.size(0) == mask2.size(0)
    assert hidden == seq2.size(2)
    assert len1 == mask1.size(1)
    assert len2 == mask2.size(1)
    torch._assert_async(is_right_padded(mask1))
    torch._assert_async(is_right_padded(mask2))
    n_valid1 = (~mask1).sum(dim=-1)
    n_valid2 = (~mask2).sum(dim=-1)
    total_len = len1 + len2
    # A position is padding iff it falls beyond the combined valid length.
    out_mask = (
        torch.arange(total_len, device=seq2.device)[None].repeat(batch, 1)
        >= (n_valid1 + n_valid2)[:, None]
    )
    out = torch.zeros(
        (total_len, batch, hidden), device=seq2.device, dtype=seq2.dtype
    )
    # seq1's entries (valid and padded alike) are already in their final slots.
    out[:len1] = seq1
    # Shift every seq2 element right by the number of valid seq1 tokens.
    index = torch.arange(len2, device=seq2.device)[:, None].repeat(1, batch)
    index = index + n_valid1[None]
    out = out.scatter(0, index[:, :, None].expand(-1, -1, hidden), seq2)
    if return_index:
        return out, out_mask, index
    return out, out_mask
class Prompt:
    """Utility class to manipulate geometric prompts.

    We expect the sequences in pytorch convention, that is sequence first, batch second.
    The dimensions are expected as follows:
    box_embeddings shape: N_boxes x B x C_box
    box_mask shape: B x N_boxes. Can be None if nothing is masked out
    point_embeddings shape: N_points x B x C_point
    point_mask shape: B x N_points. Can be None if nothing is masked out
    mask_embeddings shape: N_masks x B x 1 x H_mask x W_mask
    mask_mask shape: B x N_masks. Can be None if nothing is masked out

    We also store positive/negative labels. Like the embeddings, these tensors
    are stored sequence-first. If they are None, we'll assume positive labels
    everywhere:
    box_labels: long tensor of shape N_boxes x B
    point_labels: long tensor of shape N_points x B
    mask_labels: long tensor of shape N_masks x B
    """

    def __init__(
        self,
        box_embeddings=None,
        box_mask=None,
        point_embeddings=None,
        point_mask=None,
        box_labels=None,
        point_labels=None,
        mask_embeddings=None,
        mask_mask=None,  # Attention mask for mask prompt
        mask_labels=None,
    ):
        # Check for null prompt: everything stays None.
        if (
            box_embeddings is None
            and point_embeddings is None
            and mask_embeddings is None
        ):
            self.box_embeddings = None
            self.box_labels = None
            self.box_mask = None
            self.point_embeddings = None
            self.point_labels = None
            self.point_mask = None
            self.mask_embeddings = None
            self.mask_mask = None
            # Masks are assumed positive only for now.
            self.mask_labels = None
            return
        # Get sequence lengths, batch size and device from whichever prompt
        # types are actually present.
        box_seq_len, point_seq_len, mask_seq_len, bs, device = (
            self._init_seq_len_and_device(
                box_embeddings, point_embeddings, mask_embeddings
            )
        )
        # Initialize embeds, labels, attention masks (empty defaults when absent).
        box_embeddings, box_labels, box_mask = self._init_box(
            box_embeddings, box_labels, box_mask, box_seq_len, bs, device
        )
        point_embeddings, point_labels, point_mask = self._init_point(
            point_embeddings, point_labels, point_mask, point_seq_len, bs, device
        )
        mask_embeddings, mask_labels, mask_mask = self._init_mask(
            mask_embeddings, mask_labels, mask_mask, mask_seq_len, bs, device
        )
        # --- Shape checks: embeddings/labels are sequence-first, masks batch-first.
        assert list(box_embeddings.shape[:2]) == [
            box_seq_len,
            bs,
        ], f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
        assert list(box_mask.shape) == [
            bs,
            box_seq_len,
        ], f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
        assert list(point_embeddings.shape[:2]) == [
            point_seq_len,
            bs,
        ], f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
        assert list(point_mask.shape) == [
            bs,
            point_seq_len,
        ], f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
        assert list(box_labels.shape) == [
            box_seq_len,
            bs,
        ], f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
        assert list(point_labels.shape) == [
            point_seq_len,
            bs,
        ], f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
        # Mask embeddings are allowed to be None; we leave it to the encoder to
        # check validity before encoding.
        assert mask_embeddings is None or list(mask_embeddings.shape[:2]) == [
            mask_seq_len,
            bs,
        ], f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
        assert mask_mask is None or list(mask_mask.shape) == [
            bs,
            mask_seq_len,
        ], f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
        # --- Device checks: every tensor must live on the same device.
        for name, tensor in (
            ("box embeddings", box_embeddings),
            ("box mask", box_mask),
            ("box labels", box_labels),
            ("point embeddings", point_embeddings),
            ("point mask", point_mask),
            ("point labels", point_labels),
        ):
            assert (
                tensor.device == device
            ), f"Expected {name} to be on device {device}, got {tensor.device}"
        assert (
            mask_embeddings is None or mask_embeddings.device == device
        ), f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
        assert (
            mask_mask is None or mask_mask.device == device
        ), f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
        self.box_embeddings = box_embeddings
        self.point_embeddings = point_embeddings
        self.box_mask = box_mask
        self.point_mask = point_mask
        self.box_labels = box_labels
        self.point_labels = point_labels
        self.mask_embeddings = mask_embeddings
        self.mask_labels = mask_labels
        self.mask_mask = mask_mask

    def _init_seq_len_and_device(
        self, box_embeddings, point_embeddings, mask_embeddings
    ):
        """Derive per-prompt sequence lengths plus the shared batch size and
        device, checking consistency across the prompt types that are present."""
        box_seq_len = point_seq_len = mask_seq_len = 0
        bs = None
        device = None
        if box_embeddings is not None:
            bs = box_embeddings.shape[1]
            box_seq_len = box_embeddings.shape[0]
            device = box_embeddings.device
        if point_embeddings is not None:
            point_seq_len = point_embeddings.shape[0]
            if bs is not None:
                assert (
                    bs == point_embeddings.shape[1]
                ), f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
            else:
                bs = point_embeddings.shape[1]
            if device is not None:
                assert (
                    device == point_embeddings.device
                ), "Device mismatch between box and point embeddings"
            else:
                device = point_embeddings.device
        if mask_embeddings is not None:
            mask_seq_len = mask_embeddings.shape[0]
            if bs is not None:
                assert (
                    bs == mask_embeddings.shape[1]
                ), f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
            else:
                bs = mask_embeddings.shape[1]
            if device is not None:
                assert (
                    device == mask_embeddings.device
                ), "Device mismatch between box/point and mask embeddings."
            else:
                device = mask_embeddings.device
        return box_seq_len, point_seq_len, mask_seq_len, bs, device

    def _init_box(self, box_embeddings, box_labels, box_mask, box_seq_len, bs, device):
        """Fill in defaults for absent box prompts (empty embeds, positive
        labels, nothing padded)."""
        if box_embeddings is None:
            box_embeddings = torch.zeros(box_seq_len, bs, 4, device=device)
        if box_labels is None:
            box_labels = torch.ones(box_seq_len, bs, device=device, dtype=torch.long)
        if box_mask is None:
            box_mask = torch.zeros(bs, box_seq_len, device=device, dtype=torch.bool)
        return box_embeddings, box_labels, box_mask

    def _init_point(
        self, point_embeddings, point_labels, point_mask, point_seq_len, bs, device
    ):
        """
        Identical to _init_box. Except that C=2 for points (vs. 4 for boxes).
        """
        if point_embeddings is None:
            point_embeddings = torch.zeros(point_seq_len, bs, 2, device=device)
        if point_labels is None:
            point_labels = torch.ones(
                point_seq_len, bs, device=device, dtype=torch.long
            )
        if point_mask is None:
            point_mask = torch.zeros(bs, point_seq_len, device=device, dtype=torch.bool)
        return point_embeddings, point_labels, point_mask

    def _init_mask(
        self, mask_embeddings, mask_labels, mask_mask, mask_seq_len, bs, device
    ):
        # NOTE: Mask embeddings can be of arbitrary resolution, so we don't initialize it here.
        # In case we append new mask, we check that its resolution matches existing ones (if any).
        # In case mask_embeddings is None, we should never encode it.
        if mask_labels is None:
            mask_labels = torch.ones(mask_seq_len, bs, device=device, dtype=torch.long)
        if mask_mask is None:
            mask_mask = torch.zeros(bs, mask_seq_len, device=device, dtype=torch.bool)
        return mask_embeddings, mask_labels, mask_mask

    def append_boxes(self, boxes, labels, mask=None):
        """Append box prompts to this Prompt.

        :param boxes: (n_new, bs, 4) box embeddings.
        :param labels: (n_new, bs) long labels.
        :param mask: optional (bs, n_new) bool padding mask (True = padded).
        """
        if mask is None:
            # BUGFIX: default to "nothing padded" before the early return below.
            # Previously a None mask could be stored on the first append, which
            # crashed concat_padded_sequences on any subsequent append.
            mask = torch.zeros(
                boxes.shape[1], boxes.shape[0], dtype=torch.bool, device=boxes.device
            )
        if self.box_embeddings is None:
            self.box_embeddings = boxes
            self.box_labels = labels
            self.box_mask = mask
            return
        bs = self.box_embeddings.shape[1]
        assert boxes.shape[1] == labels.shape[1] == bs
        assert list(boxes.shape[:2]) == list(labels.shape[:2])
        # Labels ride along as an extra trailing dim so we can reuse the padded
        # sequence concatenation helper, then squeeze back.
        self.box_labels, _ = concat_padded_sequences(
            self.box_labels.unsqueeze(-1), self.box_mask, labels.unsqueeze(-1), mask
        )
        self.box_labels = self.box_labels.squeeze(-1)
        self.box_embeddings, self.box_mask = concat_padded_sequences(
            self.box_embeddings, self.box_mask, boxes, mask
        )

    def append_points(self, points, labels, mask=None):
        """Append point prompts to this Prompt.

        :param points: (n_new, bs, 2) point embeddings.
        :param labels: (n_new, bs) long labels.
        :param mask: optional (bs, n_new) bool padding mask (True = padded).
        """
        if mask is None:
            # BUGFIX: same None-mask default as append_boxes (see comment there).
            mask = torch.zeros(
                points.shape[1], points.shape[0], dtype=torch.bool, device=points.device
            )
        if self.point_embeddings is None:
            self.point_embeddings = points
            self.point_labels = labels
            self.point_mask = mask
            return
        bs = self.point_embeddings.shape[1]
        assert points.shape[1] == labels.shape[1] == bs
        assert list(points.shape[:2]) == list(labels.shape[:2])
        self.point_labels, _ = concat_padded_sequences(
            self.point_labels.unsqueeze(-1), self.point_mask, labels.unsqueeze(-1), mask
        )
        self.point_labels = self.point_labels.squeeze(-1)
        self.point_embeddings, self.point_mask = concat_padded_sequences(
            self.point_embeddings, self.point_mask, points, mask
        )

    def append_masks(self, masks, labels=None, attn_mask=None):
        """Attach mask prompts. Only a single mask prompt is supported: calling
        this when masks are already present raises NotImplementedError."""
        if labels is not None:
            assert list(masks.shape[:2]) == list(labels.shape[:2])
        if self.mask_embeddings is None:
            self.mask_embeddings = masks
            mask_seq_len, bs = masks.shape[:2]
            if labels is None:
                # Masks default to positive labels.
                self.mask_labels = torch.ones(
                    mask_seq_len, bs, device=masks.device, dtype=torch.long
                )
            else:
                self.mask_labels = labels
            if attn_mask is None:
                self.mask_mask = torch.zeros(
                    bs, mask_seq_len, device=masks.device, dtype=torch.bool
                )
            else:
                self.mask_mask = attn_mask
        else:
            raise NotImplementedError("Only one mask per prompt is supported.")

    def clone(self):
        """Return a deep copy of this prompt (all tensors cloned).

        BUGFIX: the mask prompt fields (mask_embeddings, mask_mask, mask_labels)
        were previously dropped on clone; they are now propagated as well.
        """

        def _clone(t):
            return None if t is None else t.clone()

        return Prompt(
            box_embeddings=_clone(self.box_embeddings),
            box_mask=_clone(self.box_mask),
            point_embeddings=_clone(self.point_embeddings),
            point_mask=_clone(self.point_mask),
            box_labels=_clone(self.box_labels),
            point_labels=_clone(self.point_labels),
            mask_embeddings=_clone(self.mask_embeddings),
            mask_mask=_clone(self.mask_mask),
            mask_labels=_clone(self.mask_labels),
        )
class MaskEncoder(nn.Module):
    """
    Base class for mask encoders: downsample the input masks, then attach a
    positional encoding computed on the downsampled result.
    """

    def __init__(
        self,
        mask_downsampler: nn.Module,
        position_encoding: nn.Module,
    ):
        super().__init__()
        self.mask_downsampler = mask_downsampler
        self.position_encoding = position_encoding

    def forward(self, masks, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (downsampled_masks, positional_encoding)."""
        downsampled = self.mask_downsampler(masks)
        # Match the feature dtype so downstream additions don't upcast.
        pos = self.position_encoding(downsampled).to(downsampled.dtype)
        return downsampled, pos
class FusedMaskEncoder(MaskEncoder):
    """
    Identical to memory.SimpleMaskEncoder but follows the interface of geometry_encoders.MaskEncoder.
    We also remove the `skip_mask_sigmoid` option (to be handled outside the MaskEncoder).
    Fuses backbone image features with mask features.
    """

    def __init__(
        self,
        mask_downsampler: nn.Module,
        position_encoding: nn.Module,
        fuser: nn.Module,
        in_dim: int = 256,
        out_dim: int = 256,
    ):
        super().__init__(mask_downsampler, position_encoding)
        self.fuser = fuser
        # Only project when the fused features must change dimensionality.
        self.out_proj = (
            nn.Conv2d(in_dim, out_dim, kernel_size=1)
            if out_dim != in_dim
            else nn.Identity()
        )
        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)

    @override
    def forward(
        self,
        masks: torch.Tensor,
        pix_feat: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Downsample `masks`, fuse with backbone features `pix_feat`, and
        return (fused_features, positional_encoding)."""
        down = self.mask_downsampler(masks)
        # Backbone features may live on a different device (e.g. offloaded to
        # CPU); bring them onto the mask features' device before fusing.
        feats = self.pix_feat_proj(pix_feat.to(down.device))
        fused = self.fuser(feats + down)
        fused = self.out_proj(fused)
        pos = self.position_encoding(fused).to(fused.dtype)
        return fused, pos
class SequenceGeometryEncoder(nn.Module):
    """
    This a fully fledged encoder for geometric prompts.
    It assumes boxes are passed in the "normalized CxCyWH" format, and points in normalized xy
    This allows flexibility in how to encode the features (eg do pooling)
    Points and boxes can be encoded with any of the three possibilities:
    - direct projection: we just compute a linear from coordinate space to d_model
    - pooling: pool features from the backbone in the requested location.
      For boxes, it's a roi align
      For points it's a grid sample
    - pos encoder: Take the position encoding of the point or box center
    These three options are mutually compatible. If several are selected, we'll take a simple addition
    As an alternative, we offer the possibility to encode points only.
    In that case, the boxes are converted to two points for the top left and bottom right corners (with appropriate labels)
    On top of these encodings, we offer the possibility to further encode the prompt sequence with a transformer.
    """

    def __init__(
        self,
        encode_boxes_as_points: bool,
        points_direct_project: bool,
        points_pool: bool,
        points_pos_enc: bool,
        boxes_direct_project: bool,
        boxes_pool: bool,
        boxes_pos_enc: bool,
        d_model: int,
        pos_enc,
        num_layers: int,
        layer: nn.Module,
        roi_size: int = 7,  # for boxes pool
        add_cls: bool = True,
        add_post_encode_proj: bool = True,
        mask_encoder: MaskEncoder = None,
        add_mask_label: bool = False,
        use_act_ckpt: bool = False,
    ):
        super().__init__()
        self.d_model = d_model
        self.pos_enc = pos_enc
        self.encode_boxes_as_points = encode_boxes_as_points
        self.roi_size = roi_size
        # There usually are two labels: positive and negatives.
        # If we encode boxes as points, we have 3 types of points: regular, top left, bottom right
        # These 3 types can be positives or negatives, hence 2*3 = 6 labels
        num_labels = 6 if self.encode_boxes_as_points else 2
        self.label_embed = torch.nn.Embedding(num_labels, self.d_model)
        # This is a cls token, can be used for pooling if need be.
        # It also ensures that the encoded sequences are always non-empty
        self.cls_embed = None
        if add_cls:
            self.cls_embed = torch.nn.Embedding(1, self.d_model)
        assert (
            points_direct_project or points_pos_enc or points_pool
        ), "Error: need at least one way to encode points"
        assert (
            encode_boxes_as_points
            or boxes_direct_project
            or boxes_pos_enc
            or boxes_pool
        ), "Error: need at least one way to encode boxes"
        # Each encoding path is represented by an (optional) projection module;
        # a path is enabled iff its module is not None.
        self.points_direct_project = None
        if points_direct_project:
            self.points_direct_project = nn.Linear(2, self.d_model)
        self.points_pool_project = None
        if points_pool:
            self.points_pool_project = nn.Linear(self.d_model, self.d_model)
        self.points_pos_enc_project = None
        if points_pos_enc:
            self.points_pos_enc_project = nn.Linear(self.d_model, self.d_model)
        self.boxes_direct_project = None
        self.boxes_pool_project = None
        self.boxes_pos_enc_project = None
        if not encode_boxes_as_points:
            if boxes_direct_project:
                self.boxes_direct_project = nn.Linear(4, self.d_model)
            if boxes_pool:
                # Conv with kernel == roi_size collapses the pooled ROI patch
                # to a single vector per box.
                self.boxes_pool_project = nn.Conv2d(
                    self.d_model, self.d_model, self.roi_size
                )
            if boxes_pos_enc:
                self.boxes_pos_enc_project = nn.Linear(self.d_model + 2, self.d_model)
        self.final_proj = None
        if add_post_encode_proj:
            self.final_proj = nn.Linear(self.d_model, self.d_model)
            self.norm = nn.LayerNorm(self.d_model)
        # Normalize backbone features before pooling from them (identity otherwise).
        self.img_pre_norm = nn.Identity()
        if self.points_pool_project is not None or self.boxes_pool_project is not None:
            self.img_pre_norm = nn.LayerNorm(self.d_model)
        self.encode = None
        if num_layers > 0:
            assert (
                add_cls
            ), "It's currently highly recommended to add a CLS when using a transformer"
            self.encode = get_clones(layer, num_layers)
            self.encode_norm = nn.LayerNorm(self.d_model)
        if mask_encoder is not None:
            assert isinstance(
                mask_encoder, MaskEncoder
            ), f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
            if add_mask_label:
                self.mask_label_embed = torch.nn.Embedding(2, self.d_model)
        self.add_mask_label = add_mask_label
        self.mask_encoder = mask_encoder
        self.use_act_ckpt = use_act_ckpt

    def _encode_points(self, points, points_mask, points_labels, img_feats):
        """Encode point prompts into d_model embeddings.

        Sums the enabled encoding paths (direct projection, backbone pooling
        via grid_sample, positional encoding) and adds the label embedding.
        Returns (embeddings of shape (n_points, bs, d_model), points_mask).
        """
        points_embed = None
        n_points, bs = points.shape[:2]
        if self.points_direct_project is not None:
            proj = self.points_direct_project(points)
            assert points_embed is None
            points_embed = proj
        if self.points_pool_project is not None:
            # points are [Num_points, bs, 2], normalized in [0, 1]
            # the grid needs to be [Bs, H_out, W_out, 2] normalized in [-1,1]
            # Will take H_out = num_points, w_out = 1
            grid = points.transpose(0, 1).unsqueeze(2)
            # re normalize to [-1, 1]
            grid = (grid * 2) - 1
            sampled = torch.nn.functional.grid_sample(
                img_feats, grid, align_corners=False
            )
            assert list(sampled.shape) == [bs, self.d_model, n_points, 1]
            # Back to sequence-first (n_points, bs, d_model).
            sampled = sampled.squeeze(-1).permute(2, 0, 1)
            proj = self.points_pool_project(sampled)
            if points_embed is None:
                points_embed = proj
            else:
                points_embed = points_embed + proj
        if self.points_pos_enc_project is not None:
            x, y = points.unbind(-1)
            # NOTE(review): relies on the private pos_enc._encode_xy; presumably
            # it returns per-axis sinusoidal features — confirm with pos_enc impl.
            enc_x, enc_y = self.pos_enc._encode_xy(x.flatten(), y.flatten())
            enc_x = enc_x.view(n_points, bs, enc_x.shape[-1])
            enc_y = enc_y.view(n_points, bs, enc_y.shape[-1])
            enc = torch.cat([enc_x, enc_y], -1)
            proj = self.points_pos_enc_project(enc)
            if points_embed is None:
                points_embed = proj
            else:
                points_embed = points_embed + proj
        type_embed = self.label_embed(points_labels.long())
        return type_embed + points_embed, points_mask

    def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats):
        """Encode box prompts (normalized cx,cy,w,h) into d_model embeddings.

        Sums the enabled encoding paths (direct projection, roi_align pooling,
        positional encoding of the box) and adds the label embedding.
        Returns (embeddings of shape (n_boxes, bs, d_model), boxes_mask).
        """
        boxes_embed = None
        n_boxes, bs = boxes.shape[:2]
        if self.boxes_direct_project is not None:
            proj = self.boxes_direct_project(boxes)
            assert boxes_embed is None
            boxes_embed = proj
        if self.boxes_pool_project is not None:
            H, W = img_feats.shape[-2:]
            # boxes are [Num_boxes, bs, 4], normalized in [0, 1]
            # We need to denormalize, and convert to [x, y, x, y]
            boxes_xyxy = box_cxcywh_to_xyxy(boxes)
            scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
            # pin_memory + non_blocking to overlap the host-to-device copy.
            scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
            scale = scale.view(1, 1, 4)
            boxes_xyxy = boxes_xyxy * scale
            sampled = torchvision.ops.roi_align(
                img_feats, boxes_xyxy.float().transpose(0, 1).unbind(0), self.roi_size
            )
            assert list(sampled.shape) == [
                bs * n_boxes,
                self.d_model,
                self.roi_size,
                self.roi_size,
            ]
            proj = self.boxes_pool_project(sampled)
            proj = proj.view(bs, n_boxes, self.d_model).transpose(0, 1)
            if boxes_embed is None:
                boxes_embed = proj
            else:
                boxes_embed = boxes_embed + proj
        if self.boxes_pos_enc_project is not None:
            cx, cy, w, h = boxes.unbind(-1)
            enc = self.pos_enc.encode_boxes(
                cx.flatten(), cy.flatten(), w.flatten(), h.flatten()
            )
            enc = enc.view(boxes.shape[0], boxes.shape[1], enc.shape[-1])
            proj = self.boxes_pos_enc_project(enc)
            if boxes_embed is None:
                boxes_embed = proj
            else:
                boxes_embed = boxes_embed + proj
        type_embed = self.label_embed(boxes_labels.long())
        return type_embed + boxes_embed, boxes_mask

    def _encode_masks(
        self,
        masks: torch.Tensor,
        attn_mask: torch.Tensor,
        mask_labels: torch.Tensor,
        img_feats: torch.Tensor = None,
    ):
        """Encode mask prompts via self.mask_encoder into a token sequence.

        Each mask becomes H*W tokens; the per-mask attention mask is expanded
        accordingly. Returns (tokens of shape (n_masks*H*W, bs, C), attn_mask
        of shape (bs, n_masks*H*W)).
        """
        n_masks, bs = masks.shape[:2]
        assert (
            n_masks == 1
        ), "We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
        assert (
            list(attn_mask.shape)
            == [
                bs,
                n_masks,
            ]
        ), f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
        masks, pos = self.mask_encoder(
            masks=masks.flatten(0, 1).float(),
            pix_feat=img_feats,
        )
        H, W = masks.shape[-2:]
        n_tokens_per_mask = H * W
        # NOTE: We directly add pos enc here as we usually don't keep track of pos encoding for the concatenated prompt (text, other geometric prompts). Might need to do some refactoring for more flexibility.
        masks = masks + pos
        masks = masks.view(n_masks, bs, *masks.shape[1:]).flatten(
            -2
        )  # n_masks x bs x C x H*W
        masks = masks.permute(0, 3, 1, 2).flatten(0, 1)  # n_masks * H*W x bs x C
        # One attention-mask entry per spatial token of each mask.
        attn_mask = attn_mask.repeat_interleave(n_tokens_per_mask, dim=1)
        if self.add_mask_label:
            masks = masks + self.mask_label_embed(mask_labels.long())
        return masks, attn_mask

    def forward(self, geo_prompt: Prompt, img_feats, img_sizes, img_pos_embeds=None):
        """Encode a full geometric Prompt into one padded token sequence.

        :param geo_prompt: Prompt holding point/box/mask embeddings, masks, labels.
        :param img_feats: list of backbone feature levels, sequence-first
            (H*W, B, C); only the last level is used here.
        :param img_sizes: list of (H, W) per level, matching img_feats.
        :param img_pos_embeds: optional positional embeddings per level
            (zeros are used when absent).
        :return: (final_embeds, final_mask) — sequence-first embeddings and
            batch-first padding mask.
        """
        points = geo_prompt.point_embeddings
        points_mask = geo_prompt.point_mask
        points_labels = geo_prompt.point_labels
        boxes = geo_prompt.box_embeddings
        boxes_mask = geo_prompt.box_mask
        boxes_labels = geo_prompt.box_labels
        masks = geo_prompt.mask_embeddings
        masks_mask = geo_prompt.mask_mask
        masks_labels = geo_prompt.mask_labels
        seq_first_img_feats = img_feats[-1]  # [H*W, B, C]
        seq_first_img_pos_embeds = (
            img_pos_embeds[-1]
            if img_pos_embeds is not None
            else torch.zeros_like(seq_first_img_feats)
        )
        if self.points_pool_project or self.boxes_pool_project:
            # Pooling paths need the backbone features back in (N, C, H, W).
            assert len(img_feats) == len(img_sizes)
            cur_img_feat = img_feats[-1]
            cur_img_feat = self.img_pre_norm(cur_img_feat)
            H, W = img_sizes[-1]
            assert cur_img_feat.shape[0] == H * W
            N, C = cur_img_feat.shape[-2:]
            # Put back in NxCxHxW
            cur_img_feat = cur_img_feat.permute(1, 2, 0)
            cur_img_feat = cur_img_feat.view(N, C, H, W)
            img_feats = cur_img_feat
        if self.encode_boxes_as_points:
            # Turn each box into two labeled corner points appended to the
            # point sequence (labels shifted by +2 / +4 for TL / BR corners).
            assert boxes is not None
            assert geo_prompt.box_mask is not None
            assert geo_prompt.box_labels is not None
            assert boxes.shape[-1] == 4
            boxes_xyxy = box_cxcywh_to_xyxy(boxes)
            top_left, bottom_right = boxes_xyxy.split(split_size=2, dim=-1)
            labels_tl = geo_prompt.box_labels + 2
            labels_br = geo_prompt.box_labels + 4
            # Append to the existing points
            points, _ = concat_padded_sequences(
                points, points_mask, top_left, boxes_mask
            )
            points_labels, points_mask = concat_padded_sequences(
                points_labels.unsqueeze(-1),
                points_mask,
                labels_tl.unsqueeze(-1),
                boxes_mask,
            )
            points_labels = points_labels.squeeze(-1)
            points, _ = concat_padded_sequences(
                points, points_mask, bottom_right, boxes_mask
            )
            points_labels, points_mask = concat_padded_sequences(
                points_labels.unsqueeze(-1),
                points_mask,
                labels_br.unsqueeze(-1),
                boxes_mask,
            )
            points_labels = points_labels.squeeze(-1)
        final_embeds, final_mask = self._encode_points(
            points=points,
            points_mask=points_mask,
            points_labels=points_labels,
            img_feats=img_feats,
        )
        if not self.encode_boxes_as_points:
            boxes_embeds, boxes_mask = self._encode_boxes(
                boxes=boxes,
                boxes_mask=boxes_mask,
                boxes_labels=boxes_labels,
                img_feats=img_feats,
            )
            final_embeds, final_mask = concat_padded_sequences(
                final_embeds, final_mask, boxes_embeds, boxes_mask
            )
        if masks is not None and self.mask_encoder is not None:
            masks_embed, masks_mask = self._encode_masks(
                masks=masks,
                attn_mask=masks_mask,
                mask_labels=masks_labels,
                img_feats=img_feats,
            )
            if points.size(0) == boxes.size(0) == 0:
                # Mask-only prompt: skip cls/projection/transformer entirely.
                return masks_embed, masks_mask
        bs = final_embeds.shape[1]
        assert final_mask.shape[0] == bs
        if self.cls_embed is not None:
            cls = self.cls_embed.weight.view(1, 1, self.d_model).repeat(1, bs, 1)
            cls_mask = torch.zeros(
                bs, 1, dtype=final_mask.dtype, device=final_mask.device
            )
            final_embeds, final_mask = concat_padded_sequences(
                final_embeds, final_mask, cls, cls_mask
            )
        if self.final_proj is not None:
            final_embeds = self.norm(self.final_proj(final_embeds))
        if self.encode is not None:
            # Optional transformer refinement over the prompt sequence, with
            # activation checkpointing when training and enabled.
            for lay in self.encode:
                final_embeds = activation_ckpt_wrapper(lay)(
                    tgt=final_embeds,
                    memory=seq_first_img_feats,
                    tgt_key_padding_mask=final_mask,
                    pos=seq_first_img_pos_embeds,
                    act_ckpt_enable=self.training and self.use_act_ckpt,
                )
            final_embeds = self.encode_norm(final_embeds)
        # Finally, concat mask embeddings if any
        if masks is not None and self.mask_encoder is not None:
            final_embeds, final_mask = concat_padded_sequences(
                final_embeds, final_mask, masks_embed, masks_mask
            )
        return final_embeds, final_mask

709
sam3/model/io_utils.py Normal file
View File

@@ -0,0 +1,709 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import contextlib
import os
import queue
import re
import time
from threading import Condition, get_ident, Lock, Thread
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from sam3.logger import get_logger
from tqdm import tqdm
logger = get_logger(__name__)
# Process identity is read from the environment (set by the distributed
# launcher); defaults assume a single, main process.
IS_MAIN_PROCESS = os.getenv("IS_MAIN_PROCESS", "1") == "1"
RANK = int(os.getenv("RANK", "0"))
# Recognized (lowercase) file extensions used to route an input path to the
# image or video loading code paths.
IMAGE_EXTS = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"]
VIDEO_EXTS = [".mp4", ".mov", ".avi", ".mkv", ".webm"]
def load_resource_as_video_frames(
    resource_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.5, 0.5, 0.5),
    img_std=(0.5, 0.5, 0.5),
    async_loading_frames=False,
    video_loader_type="cv2",
):
    """
    Load video frames from either a video or an image (as a single-frame video).
    Alternatively, if input is a list of PIL images, convert its format.

    Args:
        resource_path: a video path, an image path, or a list of PIL images.
        image_size: side length the frames are resized to (square).
        offload_video_to_cpu: keep frames on CPU instead of moving them to GPU.
        img_mean: per-channel normalization mean.
        img_std: per-channel normalization std.
        async_loading_frames: load frames in a background thread (videos only).
        video_loader_type: decoding backend for video files.

    Returns:
        (frames, orig_height, orig_width) where frames is a float16 tensor of
        shape (num_frames, 3, image_size, image_size).
    """
    if isinstance(resource_path, list):
        img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
        img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
        assert all(isinstance(img_pil, Image.Image) for img_pil in resource_path)
        # BUGFIX: the original `len(resource_path) is not None` was always True
        # (len never returns None); actually require a non-empty list before
        # indexing into it below.
        assert len(resource_path) > 0, "expected a non-empty list of images"
        # PIL's Image.size is (width, height).
        orig_width, orig_height = resource_path[0].size
        images = []
        for img_pil in resource_path:
            img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
            assert img_np.dtype == np.uint8, "np.uint8 is expected for JPEG images"
            img_np = img_np / 255.0
            img = torch.from_numpy(img_np).permute(2, 0, 1)
            # float16 precision should be sufficient for image tensor storage
            img = img.to(dtype=torch.float16)
            # normalize by mean and std
            img -= img_mean
            img /= img_std
            images.append(img)
        images = torch.stack(images)
        if not offload_video_to_cpu:
            images = images.cuda()
        return images, orig_height, orig_width
    is_image = (
        isinstance(resource_path, str)
        and os.path.splitext(resource_path)[-1].lower() in IMAGE_EXTS
    )
    if is_image:
        return load_image_as_single_frame_video(
            image_path=resource_path,
            image_size=image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            img_mean=img_mean,
            img_std=img_std,
        )
    else:
        return load_video_frames(
            video_path=resource_path,
            image_size=image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            img_mean=img_mean,
            img_std=img_std,
            async_loading_frames=async_loading_frames,
            video_loader_type=video_loader_type,
        )
def load_image_as_single_frame_video(
    image_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.5, 0.5, 0.5),
    img_std=(0.5, 0.5, 0.5),
):
    """Load an image file as a one-frame video tensor of shape (1, 3, H, W)."""
    frame, image_height, image_width = _load_img_as_tensor(image_path, image_size)
    # add a leading temporal dimension; float16 is sufficient for storage
    images = frame.unsqueeze(0).half()
    mean_t = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
    std_t = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
    if not offload_video_to_cpu:
        # move everything to GPU so normalization also runs there
        images = images.cuda()
        mean_t = mean_t.cuda()
        std_t = std_t.cuda()
    # normalize by mean and std
    images -= mean_t
    images /= std_t
    return images, image_height, image_width
def load_video_frames(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.5, 0.5, 0.5),
    img_std=(0.5, 0.5, 0.5),
    async_loading_frames=False,
    video_loader_type="cv2",
):
    """
    Load the video frames from video_path. The frames are resized to image_size as in
    the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo.
    """
    assert isinstance(video_path, str)
    if video_path.startswith("<load-dummy-video"):
        # Sentinel path "<load-dummy-video-N>" yields N random frames
        # (testing / compilation warmup); default is 60 frames.
        match = re.match(r"<load-dummy-video-(\d+)>", video_path)
        num_frames = 60 if match is None else int(match.group(1))
        return load_dummy_video(image_size, offload_video_to_cpu, num_frames=num_frames)
    if os.path.isdir(video_path):
        # A directory is interpreted as a folder of per-frame image files.
        return load_video_frames_from_image_folder(
            image_folder=video_path,
            image_size=image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            img_mean=img_mean,
            img_std=img_std,
            async_loading_frames=async_loading_frames,
        )
    if os.path.splitext(video_path)[-1].lower() in VIDEO_EXTS:
        return load_video_frames_from_video_file(
            video_path=video_path,
            image_size=image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            img_mean=img_mean,
            img_std=img_std,
            async_loading_frames=async_loading_frames,
            video_loader_type=video_loader_type,
        )
    raise NotImplementedError("Only video files and image folders are supported")
def load_video_frames_from_image_folder(
    image_folder,
    image_size,
    offload_video_to_cpu,
    img_mean,
    img_std,
    async_loading_frames,
):
    """
    Load the video frames from a directory of image files ("<frame_index>.<img_ext>" format)
    """
    frame_names = [
        name
        for name in os.listdir(image_folder)
        if os.path.splitext(name)[-1].lower() in IMAGE_EXTS
    ]
    try:
        # sort numerically by the frame index encoded in the file name
        frame_names.sort(key=lambda name: int(os.path.splitext(name)[0]))
    except ValueError:
        # fallback to lexicographic sort if the format is not "<frame_index>.<img_ext>"
        logger.warning(
            f'frame names are not in "<frame_index>.<img_ext>" format: {frame_names[:5]=}, '
            f"falling back to lexicographic sort."
        )
        frame_names.sort()
    if not frame_names:
        raise RuntimeError(f"no images found in {image_folder}")
    img_paths = [os.path.join(image_folder, name) for name in frame_names]
    img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
    img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
    if async_loading_frames:
        # hand off to the background loader; height/width are known once frame 0 loads
        lazy_images = AsyncImageFrameLoader(
            img_paths, image_size, offload_video_to_cpu, img_mean, img_std
        )
        return lazy_images, lazy_images.video_height, lazy_images.video_width
    # float16 precision should be sufficient for image tensor storage
    images = torch.zeros(
        len(img_paths), 3, image_size, image_size, dtype=torch.float16
    )
    video_height, video_width = None, None
    progress = tqdm(img_paths, desc=f"frame loading (image folder) [rank={RANK}]")
    for idx, img_path in enumerate(progress):
        images[idx], video_height, video_width = _load_img_as_tensor(
            img_path, image_size
        )
    if not offload_video_to_cpu:
        images = images.cuda()
        img_mean = img_mean.cuda()
        img_std = img_std.cuda()
    # normalize by mean and std
    images -= img_mean
    images /= img_std
    return images, video_height, video_width
def load_video_frames_from_video_file(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean,
    img_std,
    async_loading_frames,
    gpu_acceleration=False,
    gpu_device=None,
    video_loader_type="cv2",
):
    """Load the video frames from a video file."""
    if video_loader_type == "cv2":
        # Synchronous OpenCV decoding (ignores async_loading_frames).
        return load_video_frames_from_video_file_using_cv2(
            video_path=video_path,
            image_size=image_size,
            img_mean=img_mean,
            img_std=img_std,
            offload_video_to_cpu=offload_video_to_cpu,
        )
    if video_loader_type == "torchcodec":
        logger.info("Using torchcodec to load video file")
        lazy_images = AsyncVideoFileLoaderWithTorchCodec(
            video_path=video_path,
            image_size=image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            img_mean=img_mean,
            img_std=img_std,
            gpu_acceleration=gpu_acceleration,
            gpu_device=gpu_device,
        )
        # The `AsyncVideoFileLoaderWithTorchCodec` class always loads the videos asynchronously,
        # so we just wait for its loading thread to finish if async_loading_frames=False.
        if not async_loading_frames:
            loader_thread = lazy_images.thread
            if loader_thread is not None:
                loader_thread.join()
        return lazy_images, lazy_images.video_height, lazy_images.video_width
    raise RuntimeError("video_loader_type must be either 'cv2' or 'torchcodec'")
def load_video_frames_from_video_file_using_cv2(
    video_path: str,
    image_size: int,
    img_mean: tuple = (0.5, 0.5, 0.5),
    img_std: tuple = (0.5, 0.5, 0.5),
    offload_video_to_cpu: bool = False,
) -> tuple:
    """
    Load video from path, convert to normalized tensor with specified preprocessing
    Args:
        video_path: Path to video file
        image_size: Target size for square frames (height and width)
        img_mean: Normalization mean (RGB, for pixel values scaled to [0, 1])
        img_std: Normalization standard deviation (RGB, for pixel values in [0, 1])
    Returns:
        (video_tensor, original_height, original_width): preprocessed video tensor
        in shape (T, C, H, W) with float16 dtype, plus the source resolution
    Raises:
        ValueError: if the video cannot be opened or no frames can be decoded
    """
    import cv2  # delay OpenCV import to avoid unnecessary dependency

    # Initialize video capture
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {video_path}")
    original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # CAP_PROP_FRAME_COUNT can be 0/negative for some containers; then the
    # progress bar runs without a known total
    num_frames = num_frames if num_frames > 0 else None
    frames = []
    pbar = tqdm(desc=f"frame loading (OpenCV) [rank={RANK}]", total=num_frames)
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Convert BGR to RGB and resize
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_resized = cv2.resize(
            frame_rgb, (image_size, image_size), interpolation=cv2.INTER_CUBIC
        )
        frames.append(frame_resized)
        pbar.update(1)
    cap.release()
    pbar.close()
    if not frames:
        # np.stack would otherwise fail with a confusing error on an empty list
        raise ValueError(f"No frames could be decoded from video: {video_path}")
    # Convert to tensor, scaling uint8 pixels from [0, 255] to [0, 1].
    # (bugfix: this scaling was missing, so mean/std normalization -- which, like
    # every other loader in this file, expects [0, 1] inputs -- was applied to
    # raw 0-255 values)
    frames_np = np.stack(frames, axis=0).astype(np.float32) / 255.0  # (T, H, W, C)
    video_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2)  # (T, C, H, W)
    # float16 precision should be sufficient for image tensor storage
    video_tensor = video_tensor.half()
    img_mean = torch.tensor(img_mean, dtype=torch.float16).view(1, 3, 1, 1)
    img_std = torch.tensor(img_std, dtype=torch.float16).view(1, 3, 1, 1)
    if not offload_video_to_cpu:
        video_tensor = video_tensor.cuda()
        img_mean = img_mean.cuda()
        img_std = img_std.cuda()
    # normalize by mean and std
    video_tensor -= img_mean
    video_tensor /= img_std
    return video_tensor, original_height, original_width
def load_dummy_video(image_size, offload_video_to_cpu, num_frames=60):
    """
    Load a dummy video with random frames for testing and compilation warmup purposes.
    """
    # pretend the source video was 640x480 (height x width = 480 x 640)
    dummy_height, dummy_width = 480, 640
    frames = torch.randn(
        num_frames, 3, image_size, image_size, dtype=torch.float16
    )
    if not offload_video_to_cpu:
        frames = frames.cuda()
    return frames, dummy_height, dummy_width
def _load_img_as_tensor(img_path, image_size):
    """Load and resize an image and convert it into a PyTorch tensor."""
    pil_img = Image.open(img_path).convert("RGB")
    original_width, original_height = pil_img.width, pil_img.height
    resized = TF.resize(pil_img, size=(image_size, image_size))
    # to_tensor converts HWC uint8 [0, 255] into CHW float32 [0, 1]
    return TF.to_tensor(resized), original_height, original_width
class AsyncImageFrameLoader:
    """
    A list of video frames to be loaded asynchronously without blocking session start.

    Frame 0 is loaded eagerly in the constructor (it also fills `video_height` /
    `video_width`); the remaining frames are loaded by a daemon thread. Any
    exception raised on that thread is stored and re-raised on the next
    `__getitem__` call.
    """

    def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
        self.img_paths = img_paths
        self.image_size = image_size
        self.offload_video_to_cpu = offload_video_to_cpu
        # img_mean / img_std are expected to be (3, 1, 1) float16 tensors
        # prepared by the caller (see load_video_frames_from_image_folder)
        self.img_mean = img_mean
        self.img_std = img_std
        # items in `self._images` will be loaded asynchronously
        self.images = [None] * len(img_paths)
        # catch and raise any exceptions in the async loading thread
        self.exception = None
        # video_height and video_width be filled when loading the first image
        self.video_height = None
        self.video_width = None
        # load the first frame to fill video_height and video_width and also
        # to cache it (since it's most likely where the user will click)
        self.__getitem__(0)

        # load the rest of frames asynchronously without blocking the session start
        def _load_frames():
            try:
                # __getitem__ is a no-op for frames already cached, so this
                # simply fills in every remaining slot in order
                for n in tqdm(
                    range(len(self.images)),
                    desc=f"frame loading (image folder) [rank={RANK}]",
                ):
                    self.__getitem__(n)
            except Exception as e:
                # surface the failure on the next __getitem__ call
                self.exception = e

        self.thread = Thread(target=_load_frames, daemon=True)
        self.thread.start()

    def __getitem__(self, index):
        """Return frame `index`, loading (and caching) it on demand."""
        if self.exception is not None:
            raise RuntimeError("Failure in frame loading thread") from self.exception
        img = self.images[index]
        if img is not None:
            return img
        # cache miss: load, normalize and cache this frame
        img, video_height, video_width = _load_img_as_tensor(
            self.img_paths[index], self.image_size
        )
        self.video_height = video_height
        self.video_width = video_width
        # float16 precision should be sufficient for image tensor storage
        img = img.to(dtype=torch.float16)
        # normalize by mean and std
        img -= self.img_mean
        img /= self.img_std
        if not self.offload_video_to_cpu:
            img = img.cuda()
        self.images[index] = img
        return img

    def __len__(self):
        return len(self.images)
class TorchCodecDecoder:
    """
    A wrapper to support GPU device and num_threads in TorchCodec decoder,
    which are not supported by `torchcodec.decoders.SimpleVideoDecoder` yet.

    Exposes a minimal sequence interface: `len(decoder)` is the number of frames
    and `decoder[i]` returns the decoded frame tensor at index `i`.
    """

    def __init__(self, source, dimension_order="NCHW", device="cpu", num_threads=1):
        # torchcodec is imported lazily so this module does not hard-require it
        from torchcodec import _core as core

        self._source = source  # hold a reference to the source to prevent it from GC
        if isinstance(source, str):
            # "exact" seek mode -- NOTE(review): presumably trades speed for
            # frame-accurate indexing; confirm against torchcodec docs
            self._decoder = core.create_from_file(source, "exact")
        elif isinstance(source, bytes):
            self._decoder = core.create_from_bytes(source, "exact")
        else:
            raise TypeError(f"Unknown source type: {type(source)}.")
        assert dimension_order in ("NCHW", "NHWC")
        device_string = str(device)
        # scan the file so frame counts come from content rather than metadata
        core.scan_all_streams_to_update_metadata(self._decoder)
        core.add_video_stream(
            self._decoder,
            dimension_order=dimension_order,
            device=device_string,
            # multi-threaded decoding is only used on CPU
            num_threads=(1 if "cuda" in device_string else num_threads),
        )
        video_metadata = core.get_container_metadata(self._decoder)
        best_stream_index = video_metadata.best_video_stream_index
        assert best_stream_index is not None
        self.metadata = video_metadata.streams[best_stream_index]
        assert self.metadata.num_frames_from_content is not None
        self._num_frames = self.metadata.num_frames_from_content

    def __len__(self) -> int:
        return self._num_frames

    def __getitem__(self, key: int):
        from torchcodec import _core as core

        # support negative indexing like a regular sequence
        if key < 0:
            key += self._num_frames
        if key >= self._num_frames or key < 0:
            raise IndexError(
                f"Index {key} is out of bounds; length is {self._num_frames}"
            )
        frame_data, *_ = core.get_frame_at_index(
            self._decoder,
            frame_index=key,
        )
        return frame_data
class FIFOLock:
    """A lock that ensures FIFO ordering of lock acquisitions.

    Waiting threads enqueue their thread id; a thread may take the underlying
    lock only when its id is at the front of the queue, so acquisitions are
    granted in arrival order.
    """

    def __init__(self):
        self._lock = Lock()
        self._waiters = queue.Queue()
        self._condition = Condition()

    def acquire(self):
        """Block until this thread is both first in line and holds the lock."""
        me = get_ident()
        with self._condition:
            self._waiters.put(me)
            while True:
                first_in_line = self._waiters.queue[0] == me
                if first_in_line and self._lock.acquire(blocking=False):
                    # got the lock and it's our turn
                    return
                self._condition.wait()

    def release(self):
        """Release the lock, drop our queue entry, and wake every waiter."""
        with self._condition:
            self._lock.release()
            self._waiters.get()
            self._condition.notify_all()

    def __enter__(self):
        self.acquire()

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()
class AsyncVideoFileLoaderWithTorchCodec:
    """
    Loading frames from video files asynchronously without blocking session start.
    Unlike `AsyncVideoFileLoader`, this class uses PyTorch's official TorchCodec library
    for video decoding, which is more efficient and supports more video formats.

    Frame 0 is decoded synchronously in the constructor; the rest are decoded in
    chunks by a daemon thread. `__getitem__` polls until the requested frame is
    ready. Once everything is decoded, the decoder, progress bar, thread, and
    lock are released so the loader can be pickled as part of a session.
    """

    def __init__(
        self,
        video_path,
        image_size,
        offload_video_to_cpu,
        img_mean,
        img_std,
        gpu_acceleration=True,
        gpu_device=None,
        use_rand_seek_in_loading=False,
    ):
        # Check and possibly infer the output device (and also get its GPU id when applicable)
        assert gpu_device is None or gpu_device.type == "cuda"
        gpu_id = (
            gpu_device.index
            if gpu_device is not None and gpu_device.index is not None
            else torch.cuda.current_device()
        )
        if offload_video_to_cpu:
            out_device = torch.device("cpu")
        else:
            out_device = torch.device("cuda") if gpu_device is None else gpu_device
        self.out_device = out_device
        self.gpu_acceleration = gpu_acceleration
        self.gpu_id = gpu_id
        self.image_size = image_size
        self.offload_video_to_cpu = offload_video_to_cpu
        # normalization stats may arrive as tuples or pre-built (3, 1, 1) tensors
        if not isinstance(img_mean, torch.Tensor):
            img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
        self.img_mean = img_mean
        if not isinstance(img_std, torch.Tensor):
            img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
        self.img_std = img_std
        if gpu_acceleration:
            # decode on GPU; keep the stats there so _transform_frame stays on-device
            self.img_mean = self.img_mean.to(f"cuda:{self.gpu_id}")
            self.img_std = self.img_std.to(f"cuda:{self.gpu_id}")
            decoder_option = {"device": f"cuda:{self.gpu_id}"}
        else:
            self.img_mean = self.img_mean.cpu()
            self.img_std = self.img_std.cpu()
            decoder_option = {"num_threads": 1}  # use a single thread to save memory
        self.rank = int(os.environ.get("RANK", "0"))
        self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
        self.async_reader = TorchCodecDecoder(video_path, **decoder_option)
        # `num_frames_from_content` is the true number of frames in the video content
        # from the scan operation (rather than from the metadata, which could be wrong)
        self.num_frames = self.async_reader.metadata.num_frames_from_content
        self.video_height = self.async_reader.metadata.height
        self.video_width = self.async_reader.metadata.width
        # items in `self._images` will be loaded asynchronously
        self.images_loaded = [False] * self.num_frames
        self.images = torch.zeros(
            self.num_frames,
            3,
            self.image_size,
            self.image_size,
            dtype=torch.float16,
            device=self.out_device,
        )
        # catch and raise any exceptions in the async loading thread
        self.exception = None
        self.use_rand_seek_in_loading = use_rand_seek_in_loading
        self.rand_seek_idx_queue = queue.Queue()
        # use a lock to avoid race condition between concurrent access to torchcodec
        # libs (which are not thread-safe); the lock is replaced with a nullcontext
        # when the video is fully loaded
        self.torchcodec_access_lock = FIFOLock()
        self._start_video_loading()

    def _load_one_frame(self, idx):
        # decode frame `idx` and resize/normalize it to the model input format
        frame_resized = self._transform_frame(self.async_reader[idx])
        return frame_resized

    @torch.inference_mode()
    def _start_video_loading(self):
        """Decode frame 0 eagerly, then spawn the background decoding thread."""
        desc = f"frame loading (TorchCodec w/ {'GPU' if self.gpu_acceleration else 'CPU'}) [rank={RANK}]"
        pbar = tqdm(desc=desc, total=self.num_frames)
        self.num_loaded_frames = 0
        # load the first frame synchronously to cache it before the session is opened
        idx = self.num_loaded_frames
        self.images[idx] = self._load_one_frame(idx)
        self.images_loaded[idx] = True
        self.num_loaded_frames += 1
        pbar.update(n=1)
        self.all_frames_loaded = self.num_loaded_frames == self.num_frames

        # load the frames asynchronously without blocking the session start
        def _load_frames():
            finished = self.all_frames_loaded
            chunk_size = 16
            while not finished:
                # asynchronously load `chunk_size` frames each time we acquire the lock
                # (releasing it between chunks lets __getitem__ callers interleave)
                with self.torchcodec_access_lock, torch.inference_mode():
                    for _ in range(chunk_size):
                        try:
                            idx = self.num_loaded_frames
                            self.images[idx] = self._load_one_frame(idx)
                            self.images_loaded[idx] = True
                            self.num_loaded_frames += 1
                            pbar.update(n=1)
                            if self.num_loaded_frames >= self.num_frames:
                                finished = True
                                break
                        except Exception as e:
                            self.exception = e
                            raise
                    # also read the frame that is being randomly seeked to
                    while True:
                        try:
                            idx = self.rand_seek_idx_queue.get_nowait()
                            if not self.images_loaded[idx]:
                                self.images[idx] = self._load_one_frame(idx)
                                self.images_loaded[idx] = True
                        except queue.Empty:
                            break
                        except Exception as e:
                            self.exception = e
                            raise
            # finished -- check whether we have loaded the total number of frames
            if self.num_loaded_frames != self.num_frames:
                raise RuntimeError(
                    f"There are {self.num_frames} frames in the video, but only "
                    f"{self.num_loaded_frames} frames can be loaded successfully."
                )
            else:
                self.all_frames_loaded = True
            pbar.close()
            with self.torchcodec_access_lock:
                import gc

                # all frames have been loaded, so we can release the readers and free their memory
                # also remove pbar and thread (which shouldn't be a part of session saving)
                reader = self.async_reader
                if reader is not None:
                    reader._source = None
                self.async_reader = None
                self.pbar = None
                self.thread = None
                self.rand_seek_idx_queue = None
                gc.collect()
            # remove the lock (replace it with nullcontext) when the video is fully loaded
            self.torchcodec_access_lock = contextlib.nullcontext()

        self.thread = Thread(target=_load_frames, daemon=True)
        self.thread.start()

    def _transform_frame(self, frame):
        """Resize a decoded uint8 CHW frame and normalize it to float16."""
        frame = frame.clone()  # make a copy to avoid modifying the original frame bytes
        frame = frame.float()  # convert to float32 before interpolation
        frame_resized = F.interpolate(
            frame[None, :],
            size=(self.image_size, self.image_size),
            mode="bicubic",
            align_corners=False,
        )[0]
        # float16 precision should be sufficient for image tensor storage
        frame_resized = frame_resized.half()  # uint8 -> float16
        frame_resized /= 255
        frame_resized -= self.img_mean
        frame_resized /= self.img_std
        if self.offload_video_to_cpu:
            frame_resized = frame_resized.cpu()
        elif frame_resized.device != self.out_device:
            frame_resized = frame_resized.to(device=self.out_device, non_blocking=True)
        return frame_resized

    def __getitem__(self, index):
        """Return frame `index`, polling until the background thread provides it."""
        if self.exception is not None:
            raise RuntimeError("Failure in frame loading thread") from self.exception
        # poll for up to ~2 minutes (1200 tries x 0.1 s sleep)
        max_tries = 1200
        for _ in range(max_tries):
            # use a lock to avoid race condition between concurrent access to torchcodec
            # libs (which are not thread-safe); the lock is replaced with a nullcontext
            # when the video is fully loaded
            with self.torchcodec_access_lock:
                if self.images_loaded[index]:
                    return self.images[index]
                if self.use_rand_seek_in_loading:
                    # async loading hasn't reached this frame yet, so we load this frame individually
                    # (it will be loaded by in _load_frames thread and added to self.images[index])
                    self.rand_seek_idx_queue.put(index)
            time.sleep(0.1)
        raise RuntimeError(f"Failed to load frame {index} after {max_tries} tries")

    def __len__(self):
        return len(self.images)

    def __getstate__(self):
        """
        Remove a few attributes during pickling, so that this async video loader can be
        saved and loaded as a part of the model session.
        """
        # wait for async video loading to finish before pickling
        async_thread = self.thread
        if async_thread is not None:
            async_thread.join()
        # release a few objects that cannot be pickled
        reader = self.async_reader
        if reader is not None:
            reader._source = None
        self.async_reader = None
        self.pbar = None
        self.thread = None
        self.rand_seek_idx_queue = None
        self.torchcodec_access_lock = contextlib.nullcontext()
        return self.__dict__.copy()

View File

@@ -0,0 +1,323 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import math
from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from .model_misc import MLP
class LinearPresenceHead(nn.Sequential):
    """Scores presence with a single linear projection to one logit.

    `prompt` and `prompt_mask` are accepted only for interface compatibility
    with other presence heads (e.g. a dot-product scorer) and are ignored.
    """

    def __init__(self, d_model):
        # a hack to make `LinearPresenceHead` compatible with old checkpoints
        # (the two Identity modules presumably keep the Linear at submodule
        # index 2 so parameter names match older state_dicts -- TODO confirm)
        super().__init__(nn.Identity(), nn.Identity(), nn.Linear(d_model, 1))

    def forward(self, hs, prompt, prompt_mask):
        # only `hs` is used; the prompt inputs are intentionally unused
        return super().forward(hs)
class MaskPredictor(nn.Module):
    """Predicts mask logits by a dot product between projected object queries
    and per-pixel embeddings."""

    def __init__(self, hidden_dim, mask_dim):
        super().__init__()
        # 3-layer MLP projecting query embeddings into the pixel-embedding space
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    def forward(self, obj_queries, pixel_embed):
        """
        Args:
            obj_queries: (B, Q, C), or (L, B, Q, C) when auxiliary decoder-layer
                outputs are included.
            pixel_embed: (C, H, W) when the batch dimension was omitted,
                otherwise (B, C, H, W).

        Returns:
            Mask logits of shape (B, Q, H, W) or (L, B, Q, H, W).
        """
        query_embed = self.mask_embed(obj_queries)
        has_aux_layers = len(obj_queries.shape) != 3
        pixel_has_batch = pixel_embed.ndim != 3
        if has_aux_layers:
            # Assumed to have aux masks
            equation = "lbqc,bchw->lbqhw" if pixel_has_batch else "lbqc,chw->lbqhw"
        else:
            equation = "bqc,bchw->bqhw" if pixel_has_batch else "bqc,chw->bqhw"
        return torch.einsum(equation, query_embed, pixel_embed)
class SegmentationHead(nn.Module):
    """Predict per-query instance masks from backbone (and optionally encoder) features.

    A pixel decoder fuses/upsamples the backbone feature pyramid into a dense
    pixel embedding; masks are then obtained either by scoring object queries
    against that embedding (`MaskPredictor`) or, with `no_dec=True`, by a plain
    3x3 convolution on the pixel embedding.
    """

    def __init__(
        self,
        hidden_dim,
        upsampling_stages,
        use_encoder_inputs=False,
        aux_masks=False,
        no_dec=False,
        pixel_decoder=None,
        act_ckpt=False,
        shared_conv=False,
        compile_mode_pixel_decoder=None,
    ):
        """
        Args:
            hidden_dim: channel width of features, queries and pixel embeddings.
            upsampling_stages: number of top-down stages for the default PixelDecoder.
            use_encoder_inputs: if True, splice the encoder hidden states into the
                last (coarsest) backbone level before pixel decoding.
            aux_masks: if True, `obj_queries` carries a leading decoder-layer dim
                and masks are predicted for every layer.
            no_dec: if True, ignore object queries and predict a single-channel
                mask directly from the pixel embedding.
            pixel_decoder: optional externally constructed pixel decoder; a
                default `PixelDecoder` is built when None.
            act_ckpt: if True, run the pixel decoder under activation checkpointing.
            shared_conv, compile_mode_pixel_decoder: forwarded to the default
                PixelDecoder (unused when `pixel_decoder` is provided).
        """
        super().__init__()
        self.use_encoder_inputs = use_encoder_inputs
        self.aux_masks = aux_masks
        if pixel_decoder is not None:
            self.pixel_decoder = pixel_decoder
        else:
            self.pixel_decoder = PixelDecoder(
                hidden_dim,
                upsampling_stages,
                shared_conv=shared_conv,
                compile_mode=compile_mode_pixel_decoder,
            )
        self.no_dec = no_dec
        if no_dec:
            # query-free variant: a conv maps the pixel embedding to one mask channel
            self.mask_predictor = nn.Conv2d(
                hidden_dim, 1, kernel_size=3, stride=1, padding=1
            )
        else:
            self.mask_predictor = MaskPredictor(hidden_dim, mask_dim=hidden_dim)
        self.act_ckpt = act_ckpt
        # used to update the output dictionary
        self.instance_keys = ["pred_masks"]

    @property
    def device(self):
        # cache the parameter device; invalidated by `to()` below
        self._device = getattr(self, "_device", None) or next(self.parameters()).device
        return self._device

    def to(self, *args, **kwargs):
        # clear cached _device in case the model is moved to a different device
        self._device = None
        return super().to(*args, **kwargs)

    def _embed_pixels(
        self,
        backbone_feats: List[torch.Tensor],
        image_ids,
        encoder_hidden_states,
    ) -> torch.Tensor:
        """Run the pixel decoder, optionally fusing encoder outputs and expanding
        per-image features into per-query features via `image_ids`."""
        feature_device = backbone_feats[0].device  # features could be on CPU
        model_device = self.device
        image_ids_ = image_ids.to(feature_device)
        if self.use_encoder_inputs:
            if backbone_feats[0].shape[0] > 1:
                # For bs > 1, we construct the per query backbone features
                backbone_visual_feats = []
                for feat in backbone_feats:
                    # Copy the img features per query (pixel decoder won't share img feats)
                    backbone_visual_feats.append(feat[image_ids_, ...].to(model_device))
            else:
                # Bs=1, we rely on broadcasting for query-based processing
                backbone_visual_feats = [bb_feat.clone() for bb_feat in backbone_feats]
            # Extract visual embeddings
            # encoder_hidden_states: assumed (seq, batch, channels) -- the permute
            # below converts it to (batch, channels, seq)
            encoder_hidden_states = encoder_hidden_states.permute(1, 2, 0)
            # the leading `spatial_dim` tokens correspond to the coarsest feature map
            spatial_dim = math.prod(backbone_feats[-1].shape[-2:])
            encoder_visual_embed = encoder_hidden_states[..., :spatial_dim].reshape(
                -1, *backbone_feats[-1].shape[1:]
            )
            # replace the coarsest backbone level with the encoder's visual tokens
            backbone_visual_feats[-1] = encoder_visual_embed
            if self.act_ckpt:
                # trade compute for memory on the pixel decoder
                pixel_embed = checkpoint.checkpoint(
                    self.pixel_decoder, backbone_visual_feats, use_reentrant=False
                )
            else:
                pixel_embed = self.pixel_decoder(backbone_visual_feats)
        else:
            backbone_feats = [x.to(model_device) for x in backbone_feats]
            pixel_embed = self.pixel_decoder(backbone_feats)
            if pixel_embed.shape[0] == 1:
                # For batch_size=1 training, we can avoid the indexing to save memory
                pixel_embed = pixel_embed.squeeze(0)
            else:
                pixel_embed = pixel_embed[image_ids, ...]
        return pixel_embed

    def forward(
        self,
        backbone_feats: List[torch.Tensor],
        obj_queries: torch.Tensor,
        image_ids,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Return {"pred_masks": ...} computed from queries and pixel embeddings."""
        if self.use_encoder_inputs:
            assert encoder_hidden_states is not None
        pixel_embed = self._embed_pixels(
            backbone_feats=backbone_feats,
            image_ids=image_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        if self.no_dec:
            mask_pred = self.mask_predictor(pixel_embed)
        elif self.aux_masks:
            # predict masks for every decoder layer's queries
            mask_pred = self.mask_predictor(obj_queries, pixel_embed)
        else:
            # only the last decoder layer's queries
            mask_pred = self.mask_predictor(obj_queries[-1], pixel_embed)
        return {"pred_masks": mask_pred}
class PixelDecoder(nn.Module):
    """FPN-style pixel decoder.

    Walks the feature pyramid top-down: at each stage the coarser map is
    upsampled to the lateral feature's resolution, added to it, and refined by
    a 3x3 conv followed by GroupNorm and ReLU. With `shared_conv=True` a single
    conv/norm pair is reused across all stages.
    """

    def __init__(
        self,
        hidden_dim,
        num_upsampling_stages,
        interpolation_mode="nearest",
        shared_conv=False,
        compile_mode=None,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_upsampling_stages = num_upsampling_stages
        self.interpolation_mode = interpolation_mode
        num_convs = 1 if shared_conv else num_upsampling_stages
        self.conv_layers = nn.ModuleList(
            nn.Conv2d(self.hidden_dim, self.hidden_dim, 3, 1, 1)
            for _ in range(num_convs)
        )
        self.norms = nn.ModuleList(
            nn.GroupNorm(8, self.hidden_dim) for _ in range(num_convs)
        )
        self.shared_conv = shared_conv
        self.out_dim = self.conv_layers[-1].out_channels
        if compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=compile_mode, dynamic=True, fullgraph=True
            )
            # Needed to make checkpointing happy. But we don't know if the module is checkpointed, so we disable it by default.
            torch._dynamo.config.optimize_ddp = False

    def forward(self, backbone_feats: List[torch.Tensor]):
        # Assumes backbone features are already projected (C == hidden dim)
        top_down = backbone_feats[-1]
        for stage, lateral in enumerate(reversed(backbone_feats[:-1])):
            # upsample the coarser map and fuse with the lateral connection
            top_down = lateral + F.interpolate(
                top_down, size=lateral.shape[-2:], mode=self.interpolation_mode
            )
            conv_idx = 0 if self.shared_conv else stage
            top_down = self.conv_layers[conv_idx](top_down)
            top_down = F.relu(self.norms[conv_idx](top_down))
        return top_down
class UniversalSegmentationHead(SegmentationHead):
    """This module handles semantic+instance segmentation.

    On top of `SegmentationHead` (always run with `use_encoder_inputs=True`), it
    adds: a 1x1 semantic segmentation head, a 1x1 projection for instance
    embeddings, an optional presence head (linear or dot-product scorer), and an
    optional cross-attention of the encoder states over the prompt tokens.
    """

    def __init__(
        self,
        hidden_dim,
        upsampling_stages,
        pixel_decoder,
        aux_masks=False,
        no_dec=False,
        act_ckpt=False,
        presence_head: bool = False,
        dot_product_scorer=None,
        cross_attend_prompt=None,
    ):
        """
        Args:
            hidden_dim: model width (also the instance-embedding dimension).
            upsampling_stages: forwarded to the base SegmentationHead.
            pixel_decoder: pixel decoder instance (required here, not optional).
            aux_masks / no_dec / act_ckpt: see SegmentationHead.
            presence_head: if True, predict a per-image presence logit.
            dot_product_scorer: optional module to use as the presence head
                instead of the default `LinearPresenceHead`.
            cross_attend_prompt: optional attention module; when given, encoder
                hidden states cross-attend to the prompt before mask prediction.
        """
        super().__init__(
            hidden_dim=hidden_dim,
            upsampling_stages=upsampling_stages,
            use_encoder_inputs=True,
            aux_masks=aux_masks,
            no_dec=no_dec,
            pixel_decoder=pixel_decoder,
            act_ckpt=act_ckpt,
        )
        self.d_model = hidden_dim
        if dot_product_scorer is not None:
            assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake"
        self.presence_head = None
        if presence_head:
            self.presence_head = (
                dot_product_scorer
                if dot_product_scorer is not None
                else LinearPresenceHead(self.d_model)
            )
        self.cross_attend_prompt = cross_attend_prompt
        if self.cross_attend_prompt is not None:
            # pre-norm for the prompt cross-attention residual branch
            self.cross_attn_norm = nn.LayerNorm(self.d_model)
        self.semantic_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, 1, kernel_size=1)
        self.instance_seg_head = nn.Conv2d(
            self.pixel_decoder.out_dim, self.d_model, kernel_size=1
        )

    def forward(
        self,
        backbone_feats: List[torch.Tensor],
        obj_queries: torch.Tensor,
        image_ids,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        prompt: Optional[torch.Tensor] = None,
        prompt_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Return instance masks, a semantic segmentation map, and (optionally)
        a presence logit. `encoder_hidden_states` is required and assumed to be
        (seq, batch, channels) -- consistent with the `.shape[1]`/`.mean(0)` use
        below."""
        assert encoder_hidden_states is not None
        bs = encoder_hidden_states.shape[1]
        if self.cross_attend_prompt is not None:
            # residual cross-attention: encoder states attend to prompt tokens
            tgt2 = self.cross_attn_norm(encoder_hidden_states)
            tgt2 = self.cross_attend_prompt(
                query=tgt2,
                key=prompt,
                value=prompt,
                key_padding_mask=prompt_mask,
            )[0]
            encoder_hidden_states = tgt2 + encoder_hidden_states
        presence_logit = None
        if self.presence_head is not None:
            # mean-pool over the sequence dim, score presence per image
            pooled_enc = encoder_hidden_states.mean(0)
            presence_logit = (
                self.presence_head(
                    pooled_enc.view(1, bs, 1, self.d_model),
                    prompt=prompt,
                    prompt_mask=prompt_mask,
                )
                .squeeze(0)
                .squeeze(1)
            )
        pixel_embed = self._embed_pixels(
            backbone_feats=backbone_feats,
            image_ids=image_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        # project pixel embeddings for the instance (query-scored) branch
        instance_embeds = self.instance_seg_head(pixel_embed)
        if self.no_dec:
            mask_pred = self.mask_predictor(instance_embeds)
        elif self.aux_masks:
            mask_pred = self.mask_predictor(obj_queries, instance_embeds)
        else:
            mask_pred = self.mask_predictor(obj_queries[-1], instance_embeds)
        return {
            "pred_masks": mask_pred,
            "semantic_seg": self.semantic_seg_head(pixel_embed),
            "presence_logit": presence_logit,
        }

201
sam3/model/memory.py Normal file
View File

@@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import math
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from timm.layers import DropPath
except ModuleNotFoundError:
# compatibility for older timm versions
from timm.models.layers import DropPath
from .model_misc import get_clones, LayerNorm2d
class SimpleMaskDownSampler(nn.Module):
    """
    Progressively downsample a mask by total_stride, each time by stride.
    Note that LayerNorm is applied per *token*, like in ViT.
    With each downsample (by a factor stride**2), channel capacity increases by the same factor.
    In the end, we linearly project to embed_dim channels.
    """

    def __init__(
        self,
        embed_dim=256,
        kernel_size=4,
        stride=4,
        padding=0,
        total_stride=16,
        activation=nn.GELU,
        # Option to interpolate the input mask first before downsampling using convs. In that case, the total_stride is assumed to be after interpolation.
        # If set to input resolution or None, we don't interpolate. We default to None to be safe (for older configs or if not explicitly set)
        interpol_size=None,
    ):
        super().__init__()
        # number of conv stages needed so that stride**num_layers == total_stride
        num_layers = int(math.log2(total_stride) // math.log2(stride))
        assert stride**num_layers == total_stride
        self.encoder = nn.Sequential()
        mask_in_chans, mask_out_chans = 1, 1
        for _ in range(num_layers):
            # each stage expands channels by stride**2, preserving capacity
            mask_out_chans = mask_in_chans * (stride**2)
            self.encoder.append(
                nn.Conv2d(
                    mask_in_chans,
                    mask_out_chans,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            self.encoder.append(LayerNorm2d(mask_out_chans))
            self.encoder.append(activation())
            mask_in_chans = mask_out_chans
        # final 1x1 projection to the target embedding dimension
        self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
        self.interpol_size = interpol_size
        if self.interpol_size is not None:
            assert isinstance(
                self.interpol_size, (list, tuple)
            ), f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
            self.interpol_size = list(interpol_size)
            assert len(self.interpol_size) == 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # optionally resize the input mask first so the conv stack always sees
        # the resolution that `total_stride` was computed for
        if self.interpol_size is not None and self.interpol_size != list(x.shape[-2:]):
            x = F.interpolate(
                x.float(),
                size=self.interpol_size,
                align_corners=False,
                mode="bilinear",
                antialias=True,
            )
        return self.encoder(x)
# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
class CXBlock(nn.Module):
    r"""ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        kernel_size (int): Spatial kernel size of the (depthwise) conv. Default: 7
        padding (int): Padding for the (depthwise) conv. Default: 3
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
            A non-positive value disables layer scale entirely.
        use_dwconv (bool): If False, the first conv is a regular (dense) conv
            instead of a depthwise one. Default: True
    """

    def __init__(
        self,
        dim,
        kernel_size=7,
        padding=3,
        drop_path=0.0,
        layer_scale_init_value=1e-6,
        use_dwconv=True,
    ):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=padding,
            groups=dim if use_dwconv else 1,
        )  # depthwise conv
        self.norm = LayerNorm2d(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, 4 * dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # per-channel learnable scaling of the residual branch (layer scale)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        # residual branch input
        input = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        # residual connection with (optional) stochastic depth
        x = input + self.drop_path(x)
        return x
class SimpleFuser(nn.Module):
    """Applies `num_layers` clones of `layer` in sequence, optionally preceded
    by a 1x1 conv input projection."""

    def __init__(self, layer, num_layers, dim=None, input_projection=False):
        super().__init__()
        # identity projection unless an input projection is requested
        self.proj = nn.Identity()
        self.layers = get_clones(layer, num_layers)
        if input_projection:
            assert dim is not None
            self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        # normally x: (N, C, H, W)
        out = self.proj(x)
        for block in self.layers:
            out = block(out)
        return out
class SimpleMaskEncoder(nn.Module):
    """Encodes a predicted mask together with pixel features into memory features.

    The mask is downsampled to the feature resolution, added to projected pixel
    features, fused, projected to `out_dim`, and paired with a position encoding.
    """

    def __init__(
        self,
        out_dim,
        mask_downsampler,
        fuser,
        position_encoding,
        in_dim=256,  # in_dim of pix_feats
    ):
        super().__init__()
        self.mask_downsampler = mask_downsampler
        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.fuser = fuser
        self.position_encoding = position_encoding
        # only project when the output width differs from the input width
        self.out_proj = nn.Identity()
        if out_dim != in_dim:
            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)

    # NOTE: despite the original Tuple annotation, this returns a dict with
    # "vision_features" (tensor) and "vision_pos_enc" (single-element list).
    def forward(
        self,
        pix_feat: torch.Tensor,
        masks: torch.Tensor,
        skip_mask_sigmoid: bool = False,
    ) -> dict:
        ## Process masks
        # sigmoid, so that less domain shift from gt masks which are bool
        if not skip_mask_sigmoid:
            masks = F.sigmoid(masks)
        masks = self.mask_downsampler(masks)
        ## Fuse pix_feats and downsampled masks
        # in case the visual features are on CPU, cast them to CUDA
        pix_feat = pix_feat.to(masks.device)
        x = self.pix_feat_proj(pix_feat)
        x = x + masks
        x = self.fuser(x)
        x = self.out_proj(x)
        pos = self.position_encoding(x).to(x.dtype)
        return {"vision_features": x, "vision_pos_enc": [pos]}

428
sam3/model/model_misc.py Normal file
View File

@@ -0,0 +1,428 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""Various utility models"""
import copy
import math
import weakref
from collections.abc import Iterator
from contextlib import AbstractContextManager
from enum import auto, Enum
from typing import Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from typing_extensions import override
def inverse_sigmoid(x, eps=1e-3):
    """Logit function — the inverse of the sigmoid activation.

    Inputs are clamped to [0, 1] first; the numerator and denominator are both
    clamped from below by `eps` for numerical stability.
    Note: it might face numerical issues with fp16 and a small eps.
    """
    clamped = x.clamp(min=0, max=1)
    numer = clamped.clamp(min=eps)
    denom = (1 - clamped).clamp(min=eps)
    return torch.log(numer / denom)
class MultiheadAttentionWrapper(nn.MultiheadAttention):
    """nn.MultiheadAttention that never computes/returns attention weights."""

    def forward(self, *args, **kwargs):
        # Force need_weights=False regardless of what the caller passed.
        return super().forward(*args, **{**kwargs, "need_weights": False})
class DotProductScoring(torch.nn.Module):
    """Scores decoder queries against a mean-pooled text prompt via dot product.

    The prompt is (optionally MLP-transformed,) mean-pooled over valid tokens,
    and both the pooled prompt and the query embeddings are projected to a
    shared `d_proj` space; the score is their scaled dot product.
    """

    def __init__(
        self,
        d_model,
        d_proj,
        prompt_mlp=None,
        clamp_logits=True,
        clamp_max_val=12.0,
    ):
        """
        Args:
            d_model (int): dimension of the query / prompt embeddings.
            d_proj (int): dimension of the shared projection space.
            prompt_mlp (nn.Module, optional): extra MLP applied to the prompt
                before pooling and projection.
            clamp_logits (bool): if True, clamp scores to +/- clamp_max_val.
            clamp_max_val (float): clamping magnitude (only used when
                clamp_logits is True).
        """
        super().__init__()
        self.d_proj = d_proj
        assert isinstance(prompt_mlp, torch.nn.Module) or prompt_mlp is None
        self.prompt_mlp = prompt_mlp  # an optional MLP projection for prompt
        self.prompt_proj = torch.nn.Linear(d_model, d_proj)
        self.hs_proj = torch.nn.Linear(d_model, d_proj)
        # 1/sqrt(d) scaling, as in standard dot-product attention
        self.scale = float(1.0 / np.sqrt(d_proj))
        self.clamp_logits = clamp_logits
        if self.clamp_logits:
            self.clamp_max_val = clamp_max_val

    def mean_pool_text(self, prompt, prompt_mask):
        """Mean-pool `prompt` (seq, bs, d) over valid (non-padding) tokens."""
        # is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding
        # (prompt_mask is inverted here, i.e. True in prompt_mask marks padding)
        is_valid = (~prompt_mask).float().permute(1, 0)[..., None]
        # num_valid has shape (bs, 1); clamped to >= 1 to avoid division by zero
        num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0)
        # mean pool over all the valid tokens -- pooled_prompt has shape (bs, proj_dim)
        pooled_prompt = (prompt * is_valid).sum(dim=0) / num_valid
        return pooled_prompt

    def forward(self, hs, prompt, prompt_mask):
        """Return per-query scores of shape (num_layer, bs, num_query, 1)."""
        # hs has shape (num_layer, bs, num_query, d_model)
        # prompt has shape (seq, bs, d_model)
        # prompt_mask has shape (bs, seq); it is inverted in mean_pool_text,
        # so True marks padding tokens
        assert hs.dim() == 4 and prompt.dim() == 3 and prompt_mask.dim() == 2
        # apply MLP on prompt if specified
        if self.prompt_mlp is not None:
            prompt = self.prompt_mlp(prompt)
        # first, get the mean-pooled version of the prompt
        pooled_prompt = self.mean_pool_text(prompt, prompt_mask)
        # then, project pooled_prompt and hs to d_proj dimensions
        proj_pooled_prompt = self.prompt_proj(pooled_prompt)  # (bs, d_proj)
        proj_hs = self.hs_proj(hs)  # (num_layer, bs, num_query, d_proj)
        # finally, get dot-product scores of shape (num_layer, bs, num_query, 1)
        scores = torch.matmul(proj_hs, proj_pooled_prompt.unsqueeze(-1))
        scores *= self.scale
        # clamp scores to a max value to avoid numerical issues in loss or matcher
        if self.clamp_logits:
            scores.clamp_(min=-self.clamp_max_val, max=self.clamp_max_val)
        return scores
class LayerScale(nn.Module):
    """Learnable per-channel scaling (LayerScale), optionally applied in-place."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dimension of (N, C, H, W) tensors."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize each spatial position across channels.
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
class TransformerWrapper(nn.Module):
    """Bundles an encoder and decoder and Xavier-initializes their parameters."""

    def __init__(
        self,
        encoder,
        decoder,
        d_model: int,
        two_stage_type="none",  # ["none"] only for now
        pos_enc_at_input_dec=True,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.num_queries = None if decoder is None else decoder.num_queries
        self.pos_enc_at_input_dec = pos_enc_at_input_dec
        # for two stage
        assert two_stage_type in ["none"], "unknown param {} of two_stage_type".format(
            two_stage_type
        )
        self.two_stage_type = two_stage_type
        self._reset_parameters()
        self.d_model = d_model

    def _reset_parameters(self):
        # Xavier-init every matrix parameter, except embeddings / reference
        # points which keep their own initialization.
        skip_keys = ("box_embed", "query_embed", "reference_points")
        for name, param in self.named_parameters():
            if param.dim() > 1 and not any(key in name for key in skip_keys):
                nn.init.xavier_uniform_(param)
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        dropout: float = 0.0,
        residual: bool = False,
        out_norm: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.num_layers = num_layers
        # Layer dims: input_dim -> hidden_dim x (num_layers - 1) -> output_dim
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        self.drop = nn.Identity() if dropout <= 0 else nn.Dropout(dropout)
        # whether to add the output as a residual connection to the input
        if residual and input_dim != output_dim:
            raise ValueError("residual is only supported if input_dim == output_dim")
        self.residual = residual
        # whether to apply a normalization layer to the output
        assert isinstance(out_norm, nn.Module) or out_norm is None
        self.out_norm = out_norm or nn.Identity()

    def forward(self, x):
        shortcut = x
        *hidden, last = self.layers
        # ReLU + dropout on every layer except the last one.
        for layer in hidden:
            x = self.drop(F.relu(layer(x)))
        x = last(x)
        if self.residual:
            x = x + shortcut
        return self.out_norm(x)
def get_clones(module, N):
    """Return an nn.ModuleList of N independent deep copies of `module`."""
    return nn.ModuleList(copy.deepcopy(module) for _ in range(N))
def get_clones_seq(module, N):
    """Return an nn.Sequential of N independent deep copies of `module`."""
    return nn.Sequential(*(copy.deepcopy(module) for _ in range(N)))
def get_activation_fn(activation):
    """Return an activation function (functional form) given a string.

    Args:
        activation (str): one of "relu", "gelu", or "glu".

    Raises:
        RuntimeError: if `activation` is not a recognized name.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    # Fixed: the message previously omitted "glu" even though it is supported.
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
def get_activation_module(activation):
    """Return an activation module class (not an instance) given a string.

    Args:
        activation (str): one of "relu", "gelu", or "glu".

    Raises:
        RuntimeError: if `activation` is not a recognized name.
    """
    if activation == "relu":
        return nn.ReLU
    if activation == "gelu":
        return nn.GELU
    if activation == "glu":
        return nn.GLU
    # Fixed: the message previously omitted "glu" even though it is supported.
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
def get_valid_ratio(mask):
    """Fraction of valid (non-padded) extent along W and H per batch item.

    `mask` is a (B, H, W) boolean padding mask (True marks padding, as it is
    inverted below). Returns a (B, 2) tensor of [valid_ratio_w, valid_ratio_h].
    """
    _, H, W = mask.shape
    not_padded = ~mask
    # Valid extent measured along the first column / first row.
    valid_ratio_h = torch.sum(not_padded[:, :, 0], 1).float() / H
    valid_ratio_w = torch.sum(not_padded[:, 0, :], 1).float() / W
    return torch.stack([valid_ratio_w, valid_ratio_h], -1)
def gen_sineembed_for_position(pos_tensor, num_feats=256):
    """Sinusoidal positional embeddings for normalized box/point coordinates.

    Args:
        pos_tensor: Tensor of shape (n_query, bs, 2) holding (x, y), or
            (n_query, bs, 4) holding (x, y, w, h); values are scaled by 2*pi
            below, so they are expected to be normalized to [0, 1].
        num_feats (int): must be even; each individual coordinate is embedded
            with num_feats // 2 channels.

    Returns:
        Tensor of shape (n_query, bs, num_feats) for 2-d inputs (ordered y, x),
        or (n_query, bs, 2 * num_feats) for 4-d inputs (ordered y, x, w, h).

    Raises:
        ValueError: if the last dimension of pos_tensor is neither 2 nor 4.
    """
    assert num_feats % 2 == 0
    num_feats = num_feats // 2
    # n_query, bs, _ = pos_tensor.size()
    # sineembed_tensor = torch.zeros(n_query, bs, 256)
    scale = 2 * math.pi
    # Standard transformer frequency schedule with temperature 10000.
    dim_t = torch.arange(num_feats, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / num_feats)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    # Interleave sin over even channels and cos over odd channels.
    pos_x = torch.stack(
        (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    pos_y = torch.stack(
        (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        # Same sin/cos scheme applied to the width and height coordinates.
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack(
            (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
        ).flatten(2)
        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack(
            (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
        ).flatten(2)
        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
    return pos
class SAM3Output(list):
    """
    A class representing the output of a SAM3 model.
    It provides an iterable interface that supports different iteration modes, including iterating over all steps per stage,
    last step per stage, and flattened output.
    Attributes:
        output: The output of the SAM3 model, represented as a list of lists
            (one inner list of per-step outputs for each stage).
        iter_mode: The current iteration mode.
    Example:
        >>> output = [[1, 2], [3, 4], [5, 6]]
        >>> sam3_output = SAM3Output(output)
        >>> for step in sam3_output:
        ...     print(step)
        [1, 2]
        [3, 4]
        [5, 6]
        >>> with SAM3Output.iteration_mode(sam3_output, SAM3Output.IterMode.LAST_STEP_PER_STAGE) as sam3_last_step_out:
        ...     for step in sam3_last_step_out:
        ...         print(step)
        2
        4
        6
        >>> with SAM3Output.iteration_mode(sam3_output, SAM3Output.IterMode.FLATTENED) as sam3_flattened_out:
        ...     for step in sam3_flattened_out:
        ...         print(step)
        1
        2
        3
        4
        5
        6
    """

    class IterMode(Enum):
        # Defines the type of iterator over outputs.
        ALL_STEPS_PER_STAGE = auto()
        LAST_STEP_PER_STAGE = auto()
        FLATTENED = auto()  # Returns each interactivity step as if it is a separate stage (this is used in SAM3Image model)

    def __init__(
        self,
        output: Optional[List[List[Dict]]] = None,
        iter_mode: IterMode = IterMode.ALL_STEPS_PER_STAGE,
        loss_stages: Optional[List[int]] = None,
    ):
        """
        Args:
            output: nested per-stage, per-step outputs (a list of lists).
            iter_mode: the initial iteration mode.
            loss_stages: optional indices of stages to compute losses on.
        """
        if output is not None:
            assert (
                isinstance(output, list)
                and len(output) > 0
                and isinstance(output[0], list)
            ), "Expected output to be a list of lists"
            self.output = output
        else:
            self.output = []
        assert isinstance(
            iter_mode, SAM3Output.IterMode
        ), f"iter_mode should be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"
        self.iter_mode = iter_mode
        # We create a weak reference to self to be used in the lambda functions.
        # This is to avoid cyclic references and let SAM3Output be garbage collected.
        self_ref = weakref.ref(self)
        self._mode2iter = {
            SAM3Output.IterMode.ALL_STEPS_PER_STAGE: lambda: iter(self_ref().output),
            SAM3Output.IterMode.LAST_STEP_PER_STAGE: lambda: (
                inner_list[-1] for inner_list in self_ref().output
            ),
            SAM3Output.IterMode.FLATTENED: lambda: (
                element for inner_list in self_ref().output for element in inner_list
            ),
        }
        self.loss_stages = loss_stages

    @override
    def __iter__(self) -> Iterator:
        return self._mode2iter[self.iter_mode]()

    def __getitem__(self, index):
        """
        Returns the item at the specified index, interpreted under the
        current iteration mode.
        Args:
            index (int): The index of the item to return.
        Returns:
            list or element: The item at the specified index.
        """
        assert isinstance(index, int), f"index should be an integer. Got {type(index)}"
        if self.iter_mode == SAM3Output.IterMode.ALL_STEPS_PER_STAGE:
            return self.output[index]
        elif self.iter_mode == SAM3Output.IterMode.LAST_STEP_PER_STAGE:
            return self.output[index][-1]
        elif self.iter_mode == SAM3Output.IterMode.FLATTENED:
            if index == -1:
                # Fast path for the last step of the last stage, without
                # flattening. (Fixed: this previously read `self.self.output`,
                # which raised AttributeError.)
                return self.output[-1][-1]
            else:
                flattened_output = sum(self.output, [])
                return flattened_output[index]

    class _IterationMode(AbstractContextManager):
        """
        A context manager that temporarily changes the iteration mode of a SAM3Output object.
        This class is used internally by the SAM3Output.iteration_mode method.
        """

        def __init__(
            self, model_output: "SAM3Output", iter_mode: "SAM3Output.IterMode"
        ):
            self._model_output = model_output
            self._orig_iter_mode = model_output.iter_mode
            self._new_iter_mode = iter_mode

        @override
        def __enter__(self) -> "SAM3Output":
            self._model_output.iter_mode = self._new_iter_mode
            return self._model_output

        @override
        def __exit__(self, exc_type, exc_value, traceback):
            # Restore the original mode even when the body raised.
            self._model_output.iter_mode = self._orig_iter_mode
            return super().__exit__(exc_type, exc_value, traceback)

    @staticmethod
    def iteration_mode(
        model_output: "SAM3Output", iter_mode: IterMode
    ) -> _IterationMode:
        """
        Returns a context manager that allows you to temporarily change the iteration mode of the SAM3Output object.
        Args:
            model_output: The SAM3Output object.
            iter_mode: The new iteration mode.
        Returns:
            SAM3Output._IterationMode: A context manager that changes the iteration mode of the SAM3Output object.
        """
        return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode)

    def append(self, item: list):
        """Append one stage (a list of per-step outputs)."""
        assert isinstance(
            item, list
        ), f"Only list items are supported. Got {type(item)}"
        self.output.append(item)

    def __repr__(self):
        return self.output.__repr__()

    def __len__(self):
        # Length follows the current iteration mode (stages vs. total steps).
        if self.iter_mode in [
            SAM3Output.IterMode.ALL_STEPS_PER_STAGE,
            SAM3Output.IterMode.LAST_STEP_PER_STAGE,
        ]:
            return len(self.output)
        elif self.iter_mode == SAM3Output.IterMode.FLATTENED:
            flattened_output = sum(self.output, [])
            return len(flattened_output)

125
sam3/model/necks.py Normal file
View File

@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""Necks are the interface between a vision backbone and the rest of the detection model"""
from copy import deepcopy
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
class Sam3DualViTDetNeck(nn.Module):
    def __init__(
        self,
        trunk: nn.Module,
        position_encoding: nn.Module,
        d_model: int,
        scale_factors=(4.0, 2.0, 1.0, 0.5),
        add_sam2_neck: bool = False,
    ):
        """
        SimpleFPN neck a la ViTDet
        (From detectron2, very lightly adapted)
        It supports a "dual neck" setting, where we have two identical necks (for SAM3 and SAM2), with different weights
        :param trunk: the backbone
        :param position_encoding: the positional encoding to use
        :param d_model: the dimension of the model
        :param scale_factors: per-output-level up/down-sampling factor relative
            to the trunk's last feature map; only 4.0, 2.0, 1.0 and 0.5 are
            supported
        :param add_sam2_neck: if True, clone the conv stack so a second
            (SAM2) set of features can be produced with separate weights
        """
        super().__init__()
        self.trunk = trunk
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()
        self.scale_factors = scale_factors
        use_bias = True
        # Channel count of the trunk's last (and only used) feature map.
        dim: int = self.trunk.channel_list[-1]
        for _, scale in enumerate(scale_factors):
            current = nn.Sequential()
            if scale == 4.0:
                # 4x upsampling: two stride-2 transposed convs with a GELU in
                # between, halving the channels at each step.
                current.add_module(
                    "dconv_2x2_0",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                current.add_module(
                    "gelu",
                    nn.GELU(),
                )
                current.add_module(
                    "dconv_2x2_1",
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                )
                out_dim = dim // 4
            elif scale == 2.0:
                # 2x upsampling: a single stride-2 transposed conv.
                current.add_module(
                    "dconv_2x2",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                out_dim = dim // 2
            elif scale == 1.0:
                # Identity resolution: only the 1x1/3x3 projection below.
                out_dim = dim
            elif scale == 0.5:
                # 2x downsampling via max pooling.
                current.add_module(
                    "maxpool_2x2",
                    nn.MaxPool2d(kernel_size=2, stride=2),
                )
                out_dim = dim
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
            # Project to d_model with a 1x1 conv, then refine with a 3x3 conv.
            current.add_module(
                "conv_1x1",
                nn.Conv2d(
                    in_channels=out_dim,
                    out_channels=d_model,
                    kernel_size=1,
                    bias=use_bias,
                ),
            )
            current.add_module(
                "conv_3x3",
                nn.Conv2d(
                    in_channels=d_model,
                    out_channels=d_model,
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                ),
            )
            self.convs.append(current)
        self.sam2_convs = None
        if add_sam2_neck:
            # Assumes sam2 neck is just a clone of the original neck
            self.sam2_convs = deepcopy(self.convs)

    def forward(
        self, tensor_list: List[torch.Tensor]
    ) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        Optional[List[torch.Tensor]],
        Optional[List[torch.Tensor]],
    ]:
        """Run the trunk and build per-scale feature and position-encoding lists.

        Returns (sam3_out, sam3_pos, sam2_out, sam2_pos); the SAM2 lists are
        None unless the neck was built with add_sam2_neck=True.
        """
        xs = self.trunk(tensor_list)
        sam3_out, sam3_pos = [], []
        sam2_out, sam2_pos = None, None
        if self.sam2_convs is not None:
            sam2_out, sam2_pos = [], []
        x = xs[-1]  # simpleFPN
        for i in range(len(self.convs)):
            sam3_x_out = self.convs[i](x)
            sam3_pos_out = self.position_encoding(sam3_x_out).to(sam3_x_out.dtype)
            sam3_out.append(sam3_x_out)
            sam3_pos.append(sam3_pos_out)
            if self.sam2_convs is not None:
                # The SAM2 branch consumes the same trunk feature with its own weights.
                sam2_x_out = self.sam2_convs[i](x)
                sam2_pos_out = self.position_encoding(sam2_x_out).to(sam2_x_out.dtype)
                sam2_out.append(sam2_x_out)
                sam2_pos.append(sam2_pos_out)
        return sam3_out, sam3_pos, sam2_out, sam2_pos

View File

@@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import math
from typing import Optional
import torch
from torch import nn
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(
        self,
        num_pos_feats,
        temperature: int = 10000,
        normalize: bool = True,
        scale: Optional[float] = None,
        precompute_resolution: Optional[int] = None,
    ):
        """
        Args:
            num_pos_feats (int): total embedding channels; must be even (half
                for y, half for x).
            temperature (int): frequency temperature of the sinusoids.
            normalize (bool): normalize coordinates to [0, scale] before
                embedding.
            scale (float, optional): coordinate scale; requires normalize=True.
                Defaults to 2*pi.
            precompute_resolution (int, optional): if set, pre-fills the cache
                for strides 4/8/16/32 of this resolution.
        """
        super().__init__()
        assert num_pos_feats % 2 == 0, "Expecting even model width"
        self.num_pos_feats = num_pos_feats // 2
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale
        # Maps (H, W) -> precomputed encoding of shape (C, H, W).
        self.cache = {}
        # Precompute positional encodings under `precompute_resolution` to fill the cache
        # and avoid symbolic shape tracing errors in torch.compile in PyTorch 2.4 nightly.
        # NOTE(review): the dummy tensors below are allocated on "cuda", so
        # setting precompute_resolution assumes a GPU is available — confirm
        # for CPU-only runs.
        if precompute_resolution is not None:
            # We precompute pos enc for stride 4, 8, 16 and 32 to fill `self.cache`.
            precompute_sizes = [
                (precompute_resolution // 4, precompute_resolution // 4),
                (precompute_resolution // 8, precompute_resolution // 8),
                (precompute_resolution // 16, precompute_resolution // 16),
                (precompute_resolution // 32, precompute_resolution // 32),
            ]
            for size in precompute_sizes:
                tensors = torch.zeros((1, 1) + size, device="cuda")
                self.forward(tensors)
                # further clone and detach it in the cache (just to be safe)
                self.cache[size] = self.cache[size].clone().detach()

    def _encode_xy(self, x, y):
        """Sine/cosine-embed 1-D coordinate tensors x and y (same length)."""
        # The positions are expected to be normalized
        assert len(x) == len(y) and x.ndim == y.ndim == 1
        x_embed = x * self.scale
        y_embed = y * self.scale
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_x = x_embed[:, None] / dim_t
        pos_y = y_embed[:, None] / dim_t
        # Interleave sin on even channels and cos on odd channels.
        pos_x = torch.stack(
            (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
        ).flatten(1)
        pos_y = torch.stack(
            (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
        ).flatten(1)
        return pos_x, pos_y

    @torch.no_grad()
    def encode_boxes(self, x, y, w, h):
        """Embed box centers (x, y); w and h are appended raw (not embedded)."""
        pos_x, pos_y = self._encode_xy(x, y)
        pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
        return pos

    encode = encode_boxes  # Backwards compatibility

    @torch.no_grad()
    def encode_points(self, x, y, labels):
        """Embed point coords (bs, n) each; labels are appended raw per point."""
        (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
        assert bx == by and nx == ny and bx == bl and nx == nl
        pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
        pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
        pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
        return pos

    @torch.no_grad()
    def forward(self, x):
        """Return a (B, num_pos_feats, H, W) encoding for an input of shape (B, *, H, W)."""
        # Cache is keyed on spatial size only; the encoding is batch-independent.
        # (The None assignment below is immediately overwritten; kept as-is.)
        cache_key = None
        cache_key = (x.shape[-2], x.shape[-1])
        if cache_key in self.cache:
            return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
        # 1-based row/column indices, broadcast over the batch.
        y_embed = (
            torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
            .view(1, -1, 1)
            .repeat(x.shape[0], 1, x.shape[-1])
        )
        x_embed = (
            torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
            .view(1, 1, -1)
            .repeat(x.shape[0], x.shape[-2], 1)
        )
        if self.normalize:
            eps = 1e-6
            # Normalize by the last index so coordinates span (0, scale].
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin/cos, then move channels to dim 1.
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        if cache_key is not None:
            # Store a single (batch-free) copy for reuse at this spatial size.
            self.cache[cache_key] = pos[0]
        return pos

View File

@@ -0,0 +1,458 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from PIL.Image import Image
from sam3.model.sam3_tracker_base import Sam3TrackerBase
from sam3.model.utils.sam1_utils import SAM2Transforms
# Adapted from https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_image_predictor.py
class SAM3InteractiveImagePredictor(nn.Module):
    def __init__(
        self,
        sam_model: Sam3TrackerBase,
        mask_threshold=0.0,
        max_hole_area=256.0,
        max_sprinkle_area=0.0,
        **kwargs,
    ) -> None:
        """
        Uses SAM-3 to calculate the image embedding for an image, and then
        allow repeated, efficient mask prediction given prompts.
        Arguments:
          sam_model : The model to use for mask prediction.
          mask_threshold (float): The threshold to use when converting mask logits
            to binary masks. Masks are thresholded at 0 by default.
          max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
            the maximum area of max_hole_area in low_res_masks.
          max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
            the maximum area of max_sprinkle_area in low_res_masks.
          **kwargs: Extra keyword arguments are accepted for configuration
            compatibility and ignored here.
        """
        super().__init__()
        self.model = sam_model
        self._transforms = SAM2Transforms(
            resolution=self.model.image_size,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )
        # Predictor state
        self._is_image_set = False
        self._features = None
        self._orig_hw = None
        # Whether the predictor is set for single image or a batch of images
        self._is_batch = False
        # Predictor config
        self.mask_threshold = mask_threshold
        # Spatial dim for backbone feature maps
        # NOTE(review): these correspond to strides 4/8/16 of a 1152x1152
        # input — confirm they match the configured model.image_size.
        self._bb_feat_sizes = [
            (288, 288),
            (144, 144),
            (72, 72),
        ]
    @torch.no_grad()
    def set_image(
        self,
        image: Union[np.ndarray, Image],
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method.
        Arguments:
          image (np.ndarray or PIL Image): The input image to embed in RGB format.
            The image should be in HWC format if np.ndarray, or WHC format if
            PIL Image, with pixel values in [0, 255].
        Raises:
          NotImplementedError: if `image` is neither an np.ndarray nor a PIL Image.
        """
        self.reset_predictor()
        # Transform the image to the form expected by the model
        if isinstance(image, np.ndarray):
            logging.info("For numpy array image, we assume (HxWxC) format")
            self._orig_hw = [image.shape[:2]]
        elif isinstance(image, Image):
            # PIL reports (width, height); store as (height, width).
            w, h = image.size
            self._orig_hw = [(h, w)]
        else:
            raise NotImplementedError("Image format not supported")
        input_image = self._transforms(image)
        input_image = input_image[None, ...].to(self.device)
        assert (
            len(input_image.shape) == 4 and input_image.shape[1] == 3
        ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
        logging.info("Computing image embeddings for the provided image...")
        backbone_out = self.model.forward_image(input_image)
        (
            _,
            vision_feats,
            _,
            _,
        ) = self.model._prepare_backbone_features(backbone_out)
        # Add no_mem_embed, which is added to the lowest res feat. map during training on videos
        vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
        # Reshape flattened (HW, B, C) features back to (B, C, H, W) per level.
        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        self._is_image_set = True
        logging.info("Image embeddings computed.")
    @torch.no_grad()
    def set_image_batch(
        self,
        image_list: List[np.ndarray],
    ) -> None:
        """
        Calculates the image embeddings for the provided image batch, allowing
        masks to be predicted with the 'predict_batch' method.
        Arguments:
          image_list (List[np.ndarray]): The input images to embed in RGB format.
            Each image should be in HWC format if np.ndarray, with pixel values
            in [0, 255].
        """
        self.reset_predictor()
        assert isinstance(image_list, list)
        self._orig_hw = []
        for image in image_list:
            assert isinstance(
                image, np.ndarray
            ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
            self._orig_hw.append(image.shape[:2])
        # Transform the image to the form expected by the model
        img_batch = self._transforms.forward_batch(image_list)
        img_batch = img_batch.to(self.device)
        batch_size = img_batch.shape[0]
        assert (
            len(img_batch.shape) == 4 and img_batch.shape[1] == 3
        ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
        logging.info("Computing image embeddings for the provided images...")
        backbone_out = self.model.forward_image(img_batch)
        (
            _,
            vision_feats,
            _,
            _,
        ) = self.model._prepare_backbone_features(backbone_out)
        # Add no_mem_embed, which is added to the lowest res feat. map during training on videos
        vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
        # Reshape flattened (HW, B, C) features back to (B, C, H, W) per level.
        feats = [
            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        self._is_image_set = True
        self._is_batch = True
        logging.info("Image embeddings computed.")
    def predict_batch(
        self,
        point_coords_batch: List[np.ndarray] = None,
        point_labels_batch: List[np.ndarray] = None,
        box_batch: List[np.ndarray] = None,
        mask_input_batch: List[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        normalize_coords=True,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
        """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
        It returns a tuple of lists of masks, ious, and low_res_masks_logits.

        Each *_batch argument is a per-image list aligned with the images
        passed to set_image_batch(); any of them may be None. See predict()
        for the meaning of the individual prompt types.
        Raises:
          RuntimeError: if set_image_batch(...) has not been called first.
        """
        assert self._is_batch, "This function should only be used when in batched mode"
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image_batch(...) before mask prediction."
            )
        num_images = len(self._features["image_embed"])
        all_masks = []
        all_ious = []
        all_low_res_masks = []
        # Run prediction sequentially for each image in the batch.
        for img_idx in range(num_images):
            # Transform input prompts
            point_coords = (
                point_coords_batch[img_idx] if point_coords_batch is not None else None
            )
            point_labels = (
                point_labels_batch[img_idx] if point_labels_batch is not None else None
            )
            box = box_batch[img_idx] if box_batch is not None else None
            mask_input = (
                mask_input_batch[img_idx] if mask_input_batch is not None else None
            )
            mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
                point_coords,
                point_labels,
                box,
                mask_input,
                normalize_coords,
                img_idx=img_idx,
            )
            masks, iou_predictions, low_res_masks = self._predict(
                unnorm_coords,
                labels,
                unnorm_box,
                mask_input,
                multimask_output,
                return_logits=return_logits,
                img_idx=img_idx,
            )
            # Convert per-image outputs back to numpy on the CPU.
            masks_np = masks.squeeze(0).float().detach().cpu().numpy()
            iou_predictions_np = (
                iou_predictions.squeeze(0).float().detach().cpu().numpy()
            )
            low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
            all_masks.append(masks_np)
            all_ious.append(iou_predictions_np)
            all_low_res_masks.append(low_res_masks_np)
        return all_masks, all_ious, all_low_res_masks
    def predict(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        normalize_coords=True,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict masks for the given input prompts, using the currently set image.
        Arguments:
          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (np.ndarray or None): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          box (np.ndarray or None): A length 4 array given a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where
            for SAM, H=W=256.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.
          normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions.
        Returns:
          (np.ndarray): The output masks in CxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (np.ndarray): An array of length C containing the model's
            predictions for the quality of each mask.
          (np.ndarray): An array of shape CxHxW, where C is the number
            of masks and H=W=256. These low resolution logits can be passed to
            a subsequent iteration as mask input.
        Raises:
          RuntimeError: if set_image(...) has not been called first.
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )
        # Transform input prompts
        mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
            point_coords, point_labels, box, mask_input, normalize_coords
        )
        masks, iou_predictions, low_res_masks = self._predict(
            unnorm_coords,
            labels,
            unnorm_box,
            mask_input,
            multimask_output,
            return_logits=return_logits,
        )
        # Convert outputs back to numpy on the CPU.
        masks_np = masks.squeeze(0).float().detach().cpu().numpy()
        iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
        low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
        return masks_np, iou_predictions_np, low_res_masks_np
    def _prep_prompts(
        self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
    ):
        """Convert raw (numpy / list) prompts into transformed torch tensors.

        Returns:
            (mask_input, unnorm_coords, labels, unnorm_box) — each entry is
            None when the corresponding prompt was not provided.
        """
        unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            point_coords = torch.as_tensor(
                point_coords, dtype=torch.float, device=self.device
            )
            unnorm_coords = self._transforms.transform_coords(
                point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
            )
            labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
            if len(unnorm_coords.shape) == 2:
                # Add a batch dimension for a single (N, 2) set of points.
                unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
        if box is not None:
            box = torch.as_tensor(box, dtype=torch.float, device=self.device)
            unnorm_box = self._transforms.transform_boxes(
                box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
            )  # Bx2x2
        if mask_logits is not None:
            mask_input = torch.as_tensor(
                mask_logits, dtype=torch.float, device=self.device
            )
            if len(mask_input.shape) == 3:
                # Add a batch dimension for a single (1, H, W) mask.
                mask_input = mask_input[None, :, :, :]
        return mask_input, unnorm_coords, labels, unnorm_box
    @torch.no_grad()
    def _predict(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        img_idx: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched torch tensors and are expected to already be
        transformed to the input frame using SAM2Transforms.
        Arguments:
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (np.ndarray or None): A Bx4 array given a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form Bx1xHxW, where
            for SAM, H=W=256. Masks returned by a previous iteration of the
            predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.
          img_idx (int): Index into the stored per-image features/sizes
            (-1 selects the last, i.e. the single currently-set image).
        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )
        if point_coords is not None:
            concat_points = (point_coords, point_labels)
        else:
            concat_points = None
        # Embed prompts
        if boxes is not None:
            # a Bx4 XYXY box becomes two corner points (Bx2x2) with the special
            # labels 2 and 3 that the prompt encoder reserves for box corners
            box_coords = boxes.reshape(-1, 2, 2)
            box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
            box_labels = box_labels.repeat(boxes.size(0), 1)
            # we merge "boxes" and "points" into a single "concat_points" input (where
            # boxes are added at the beginning) to sam_prompt_encoder
            if concat_points is not None:
                concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
                concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
                concat_points = (concat_coords, concat_labels)
            else:
                concat_points = (box_coords, box_labels)
        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
            points=concat_points,
            boxes=None,
            masks=mask_input,
        )
        # Predict masks
        batched_mode = (
            concat_points is not None and concat_points[0].shape[0] > 1
        )  # multi object prediction
        high_res_features = [
            feat_level[img_idx].unsqueeze(0)
            for feat_level in self._features["high_res_feats"]
        ]
        low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
            image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )
        # Upscale the masks to the original image resolution
        masks = self._transforms.postprocess_masks(
            low_res_masks, self._orig_hw[img_idx]
        )
        # bound the low-res logits so they are numerically safe to feed back
        # as `mask_input` in a later iteration
        low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
        if not return_logits:
            masks = masks > self.mask_threshold
        return masks, iou_predictions, low_res_masks
def get_image_embedding(self) -> torch.Tensor:
"""
Returns the image embeddings for the currently set image, with
shape 1xCxHxW, where C is the embedding dimension and (H,W) are
the embedding spatial dimension of SAM (typically C=256, H=W=64).
"""
if not self._is_image_set:
raise RuntimeError(
"An image must be set with .set_image(...) to generate an embedding."
)
assert (
self._features is not None
), "Features must exist if an image has been set."
return self._features["image_embed"]
    @property
    def device(self) -> torch.device:
        # The predictor owns no tensors itself; its device is the model's device.
        return self.model.device
def reset_predictor(self) -> None:
"""
Resets the image embeddings and other state variables.
"""
self._is_image_set = False
self._features = None
self._orig_hw = None
self._is_batch = False

883
sam3/model/sam3_image.py Normal file
View File

@@ -0,0 +1,883 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import os
from copy import deepcopy
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from sam3.model.model_misc import SAM3Output
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
from sam3.model.vl_combiner import SAM3VLBackbone
from sam3.perflib.nms import nms_masks
from sam3.train.data.collator import BatchedDatapoint
from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .geometry_encoders import Prompt
from .model_misc import inverse_sigmoid
def _update_out(out, out_name, out_value, auxiliary=True, update_aux=True):
out[out_name] = out_value[-1] if auxiliary else out_value
if auxiliary and update_aux:
if "aux_outputs" not in out:
out["aux_outputs"] = [{} for _ in range(len(out_value) - 1)]
assert len(out["aux_outputs"]) == len(out_value) - 1
for aux_output, aux_value in zip(out["aux_outputs"], out_value[:-1]):
aux_output[out_name] = aux_value
class Sam3Image(torch.nn.Module):
    """SAM3 single-image model.

    Combines a vision-language backbone, a geometry (point/box/mask) encoder,
    an encoder-decoder transformer, and an optional segmentation head to
    produce box, mask, and score predictions from text, visual, or geometric
    prompts (see `forward_grounding` / `forward`).
    """

    # Identifiers marking which prompt modality a text slot was built from.
    TEXT_ID_FOR_TEXT = 0
    TEXT_ID_FOR_VISUAL = 1
    TEXT_ID_FOR_GEOMETRIC = 2
    def __init__(
        self,
        backbone: SAM3VLBackbone,
        transformer,
        input_geometry_encoder,
        segmentation_head=None,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=None,
        use_instance_query: bool = True,
        multimask_output: bool = True,
        use_act_checkpoint_seg_head: bool = True,
        interactivity_in_encoder: bool = True,
        matcher=None,
        use_dot_prod_scoring=True,
        supervise_joint_box_scores: bool = False,  # only relevant if using presence token/score
        detach_presence_in_joint_score: bool = False,  # only relevant if using presence token/score
        separate_scorer_for_instance: bool = False,
        num_interactive_steps_val: int = 0,
        inst_interactive_predictor: SAM3InteractiveImagePredictor = None,
        **kwargs,
    ):
        """
        Args:
            backbone: joint vision-language backbone producing image and text features.
            transformer: encoder-decoder transformer; `transformer.d_model` sets the
                model's hidden dimension.
            input_geometry_encoder: encoder for geometric prompts (boxes/points/masks).
            segmentation_head: optional mask head; when None, mask outputs are skipped.
            num_feature_levels: number of FPN levels consumed from the backbone.
            o2m_mask_predict: if True, also predict masks for one-to-many (o2m) queries.
            dot_prod_scoring: scoring head used when `use_dot_prod_scoring` is True.
            use_instance_query: whether instance queries are used (stored, used elsewhere).
            multimask_output: whether to expose multiple candidate masks per prompt.
            use_act_checkpoint_seg_head: activation-checkpoint the segmentation head
                during training to save memory.
            interactivity_in_encoder: whether interactive prompts are fused in the encoder.
            matcher: Hungarian-style matcher used for training/eval matching.
            use_dot_prod_scoring: choose dot-product scoring vs. a linear class head.
            supervise_joint_box_scores: supervise box scores jointly with the presence
                score (only relevant if using presence token/score).
            detach_presence_in_joint_score: detach the presence probability when forming
                the joint score (only relevant if using presence token/score).
            separate_scorer_for_instance: keep a separate (deep-copied) scorer for
                instance prompts.
            num_interactive_steps_val: number of interactive refinement steps at eval time.
            inst_interactive_predictor: SAM2-style interactive predictor used by
                `predict_inst` / `predict_inst_batch`.
        """
        super().__init__()
        self.backbone = backbone
        self.geometry_encoder = input_geometry_encoder
        self.transformer = transformer
        self.hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels
        self.segmentation_head = segmentation_head
        self.o2m_mask_predict = o2m_mask_predict
        self.dot_prod_scoring = dot_prod_scoring
        self.use_act_checkpoint_seg_head = use_act_checkpoint_seg_head
        self.interactivity_in_encoder = interactivity_in_encoder
        self.matcher = matcher
        self.num_interactive_steps_val = num_interactive_steps_val
        self.use_dot_prod_scoring = use_dot_prod_scoring
        if self.use_dot_prod_scoring:
            # score queries by dot product against the prompt features
            assert dot_prod_scoring is not None
            self.dot_prod_scoring = dot_prod_scoring
            self.instance_dot_prod_scoring = None
            if separate_scorer_for_instance:
                self.instance_dot_prod_scoring = deepcopy(dot_prod_scoring)
        else:
            # plain linear classification head over the query embeddings
            self.class_embed = torch.nn.Linear(self.hidden_dim, 1)
            self.instance_class_embed = None
            if separate_scorer_for_instance:
                self.instance_class_embed = deepcopy(self.class_embed)
        self.supervise_joint_box_scores = supervise_joint_box_scores
        self.detach_presence_in_joint_score = detach_presence_in_joint_score
        # verify the number of queries for O2O and O2M
        num_o2o_static = self.transformer.decoder.num_queries
        num_o2m_static = self.transformer.decoder.num_o2m_queries
        assert num_o2m_static == (num_o2o_static if self.transformer.decoder.dac else 0)
        self.dac = self.transformer.decoder.dac
        self.use_instance_query = use_instance_query
        self.multimask_output = multimask_output
        self.inst_interactive_predictor = inst_interactive_predictor
@property
def device(self):
self._device = getattr(self, "_device", None) or next(self.parameters()).device
return self._device
    def to(self, *args, **kwargs):
        """Move/cast the module, invalidating the cached device first."""
        # clear cached _device in case the model is moved to a different device
        self._device = None
        return super().to(*args, **kwargs)
    def _get_img_feats(self, backbone_out, img_ids):
        """Retrieve correct image features from backbone output.

        If `backbone_out` already holds "backbone_fpn" features, the requested
        `img_ids` are indexed out of them (translated through "id_mapping" when
        present). Otherwise the backbone is run on-the-fly on the unique
        requested frames and the result is cached into a new `backbone_out`.

        Returns:
            (backbone_out, img_feats, img_pos_embeds, vis_feat_sizes) where
            img_feats/img_pos_embeds are lists of HWxNxC sequence-first tensors
            (one per feature level) and vis_feat_sizes are their (H, W) shapes.
        """
        if "backbone_fpn" in backbone_out:
            if "id_mapping" in backbone_out and backbone_out["id_mapping"] is not None:
                img_ids = backbone_out["id_mapping"][img_ids]
            # If this assert fails, it likely means we're requesting different img_ids (perhaps a different frame?)
            # We currently don't expect this to happen. We could technically trigger a recompute here,
            # but likely at the cost of a cpu<->gpu sync point, which would deteriorate perf
            torch._assert_async((img_ids >= 0).all())
            vis_feats = backbone_out["backbone_fpn"][-self.num_feature_levels :]
            vis_pos_enc = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
            vis_feat_sizes = [x.shape[-2:] for x in vis_pos_enc]  # (H, W) shapes
            # index and flatten visual features NxCxHxW => HWxNxC (batch-first => seq-first)
            img_feats = [x[img_ids].flatten(2).permute(2, 0, 1) for x in vis_feats]
            img_pos_embeds = [
                x[img_ids].flatten(2).permute(2, 0, 1) for x in vis_pos_enc
            ]
            return backbone_out, img_feats, img_pos_embeds, vis_feat_sizes
        # Image features not available in backbone output, so we compute them on the fly
        # This case likely occurs for video. In that case, we want to forward only the current frame
        img_batch = backbone_out["img_batch_all_stages"]
        if img_ids.numel() > 1:
            # Only forward backbone on unique image ids to avoid repetitive computation
            unique_ids, _ = torch.unique(img_ids, return_inverse=True)
        else:
            unique_ids, _ = img_ids, slice(None)
        # Compute the image features on those unique image ids
        # note: we allow using a list (or other indexable types) of tensors as img_batch
        # (e.g. for async frame loading in demo). In this case we index img_batch.tensors directly
        if isinstance(img_batch, torch.Tensor):
            image = img_batch[unique_ids]
        elif unique_ids.numel() == 1:
            image = img_batch[unique_ids.item()].unsqueeze(0)
        else:
            image = torch.stack([img_batch[i] for i in unique_ids.tolist()])
        # `img_batch` might be fp16 and offloaded to CPU
        image = image.to(dtype=torch.float32, device=self.device)
        # Next time we call this function, we want to remember which indices we computed
        id_mapping = torch.full(
            (len(img_batch),), -1, dtype=torch.long, device=self.device
        )
        id_mapping[unique_ids] = torch.arange(len(unique_ids), device=self.device)
        backbone_out = {
            **backbone_out,
            **self.backbone.forward_image(image),
            "id_mapping": id_mapping,
        }
        assert "backbone_fpn" in backbone_out
        # recurse: the features now exist, so the first branch will index them
        return self._get_img_feats(backbone_out, img_ids=img_ids)
    def _encode_prompt(
        self,
        backbone_out,
        find_input,
        geometric_prompt,
        visual_prompt_embed=None,
        visual_prompt_mask=None,
        encode_text=True,
        prev_mask_pred=None,
    ):
        """Assemble the full prompt sequence (text + geometry + visual exemplars).

        Returns (prompt, prompt_mask, backbone_out): a sequence-first prompt
        tensor, its key-padding mask, and the (possibly updated) backbone
        output. When `encode_text` is False, text features are left out.
        `prev_mask_pred`, if given, is added to the finest image feature level
        before the geometry encoder runs.
        """
        # index text features (note that regardless of early or late fusion, the batch size of
        # `txt_feats` is always the number of *prompts* in the encoder)
        txt_ids = find_input.text_ids
        txt_feats = backbone_out["language_features"][:, txt_ids]
        txt_masks = backbone_out["language_mask"][txt_ids]
        feat_tuple = self._get_img_feats(backbone_out, find_input.img_ids)
        backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = feat_tuple
        if prev_mask_pred is not None:
            img_feats = [img_feats[-1] + prev_mask_pred]
        # Encode geometry
        geo_feats, geo_masks = self.geometry_encoder(
            geo_prompt=geometric_prompt,
            img_feats=img_feats,
            img_sizes=vis_feat_sizes,
            img_pos_embeds=img_pos_embeds,
        )
        if visual_prompt_embed is None:
            # no visual prompt: use zero-length placeholders so the concat below is uniform
            visual_prompt_embed = torch.zeros(
                (0, *geo_feats.shape[1:]), device=geo_feats.device
            )
            visual_prompt_mask = torch.zeros(
                (*geo_masks.shape[:-1], 0),
                device=geo_masks.device,
                dtype=geo_masks.dtype,
            )
        if encode_text:
            prompt = torch.cat([txt_feats, geo_feats, visual_prompt_embed], dim=0)
            prompt_mask = torch.cat([txt_masks, geo_masks, visual_prompt_mask], dim=1)
        else:
            prompt = torch.cat([geo_feats, visual_prompt_embed], dim=0)
            prompt_mask = torch.cat([geo_masks, visual_prompt_mask], dim=1)
        return prompt, prompt_mask, backbone_out
    def _run_encoder(
        self,
        backbone_out,
        find_input,
        prompt,
        prompt_mask,
        encoder_extra_kwargs: Optional[Dict] = None,
    ):
        """Run the transformer encoder over image features conditioned on the prompt.

        Returns (backbone_out, encoder_out, feat_tuple) where `encoder_out`
        bundles the encoded image memory (plus positional/padding metadata) and
        the prompt features before/after encoding.
        """
        feat_tuple = self._get_img_feats(backbone_out, find_input.img_ids)
        backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = feat_tuple
        # Run the encoder
        # prompts carry no positional encoding; feed zeros of matching shape
        prompt_pos_embed = torch.zeros_like(prompt)
        # make a copy of the image feature lists since the encoder may modify these lists in-place
        memory = self.transformer.encoder(
            src=img_feats.copy(),
            src_key_padding_mask=None,
            src_pos=img_pos_embeds.copy(),
            prompt=prompt,
            prompt_pos=prompt_pos_embed,
            prompt_key_padding_mask=prompt_mask,
            feat_sizes=vis_feat_sizes,
            encoder_extra_kwargs=encoder_extra_kwargs,
        )
        encoder_out = {
            # encoded image features
            "encoder_hidden_states": memory["memory"],
            "pos_embed": memory["pos_embed"],
            "padding_mask": memory["padding_mask"],
            "level_start_index": memory["level_start_index"],
            "spatial_shapes": memory["spatial_shapes"],
            "valid_ratios": memory["valid_ratios"],
            "vis_feat_sizes": vis_feat_sizes,
            # encoded text features (or other prompts)
            "prompt_before_enc": prompt,
            "prompt_after_enc": memory.get("memory_text", prompt),
            "prompt_mask": prompt_mask,
        }
        return backbone_out, encoder_out, feat_tuple
    def _run_decoder(
        self,
        pos_embed,
        memory,
        src_mask,
        out,
        prompt,
        prompt_mask,
        encoder_out,
    ):
        """Run the query decoder and fill `out` with score/box predictions.

        Returns (out, hs) where `hs` holds the per-layer decoder hidden states
        in batch-first layout; `out` is updated in place via
        `_update_scores_and_boxes` (and with presence features when produced).
        """
        bs = memory.shape[1]
        # tile the learned query embeddings across the batch
        query_embed = self.transformer.decoder.query_embed.weight
        tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
        # DAC (dual-assignment) queries are only active during training
        apply_dac = self.transformer.decoder.dac and self.training
        hs, reference_boxes, dec_presence_out, dec_presence_feats = (
            self.transformer.decoder(
                tgt=tgt,
                memory=memory,
                memory_key_padding_mask=src_mask,
                pos=pos_embed,
                reference_boxes=None,
                level_start_index=encoder_out["level_start_index"],
                spatial_shapes=encoder_out["spatial_shapes"],
                valid_ratios=encoder_out["valid_ratios"],
                tgt_mask=None,
                memory_text=prompt,
                text_attention_mask=prompt_mask,
                apply_dac=apply_dac,
            )
        )
        hs = hs.transpose(1, 2)  # seq-first to batch-first
        reference_boxes = reference_boxes.transpose(1, 2)  # seq-first to batch-first
        if dec_presence_out is not None:
            # seq-first to batch-first
            dec_presence_out = dec_presence_out.transpose(1, 2)
            out["presence_feats"] = dec_presence_feats
        self._update_scores_and_boxes(
            out,
            hs,
            reference_boxes,
            prompt,
            prompt_mask,
            dec_presence_out=dec_presence_out,
        )
        return out, hs
    def _update_scores_and_boxes(
        self,
        out,
        hs,
        reference_boxes,
        prompt,
        prompt_mask,
        dec_presence_out=None,
        is_instance_prompt=False,
    ):
        """Compute classification scores and boxes from decoder states into `out`.

        Splits the query dimension into one-to-one (o2o) and, during DAC
        training, one-to-many (o2m) halves; predictions for each half are
        written under separate keys ("pred_*" vs "pred_*_o2m"). When
        `is_instance_prompt` is True and dedicated instance heads exist, those
        are used instead of the shared heads.
        """
        apply_dac = self.transformer.decoder.dac and self.training
        # with DAC the first half of the queries are o2o, the second half o2m
        num_o2o = (hs.size(2) // 2) if apply_dac else hs.size(2)
        num_o2m = hs.size(2) - num_o2o
        assert num_o2m == (num_o2o if apply_dac else 0)
        out["queries"] = hs[-1][:, :num_o2o]  # remove o2m queries if there are any
        # score prediction
        if self.use_dot_prod_scoring:
            dot_prod_scoring_head = self.dot_prod_scoring
            if is_instance_prompt and self.instance_dot_prod_scoring is not None:
                dot_prod_scoring_head = self.instance_dot_prod_scoring
            outputs_class = dot_prod_scoring_head(hs, prompt, prompt_mask)
        else:
            class_embed_head = self.class_embed
            if is_instance_prompt and self.instance_class_embed is not None:
                class_embed_head = self.instance_class_embed
            outputs_class = class_embed_head(hs)
        # box prediction
        box_head = self.transformer.decoder.bbox_embed
        if (
            is_instance_prompt
            and self.transformer.decoder.instance_bbox_embed is not None
        ):
            box_head = self.transformer.decoder.instance_bbox_embed
        # boxes are predicted as offsets relative to the (inverse-sigmoid) reference boxes
        anchor_box_offsets = box_head(hs)
        reference_boxes_inv_sig = inverse_sigmoid(reference_boxes)
        outputs_coord = (reference_boxes_inv_sig + anchor_box_offsets).sigmoid()
        outputs_boxes_xyxy = box_cxcywh_to_xyxy(outputs_coord)
        if dec_presence_out is not None:
            _update_out(
                out, "presence_logit_dec", dec_presence_out, update_aux=self.training
            )
        if self.supervise_joint_box_scores:
            # fold the presence probability into the per-query scores
            assert dec_presence_out is not None
            prob_dec_presence_out = dec_presence_out.clone().sigmoid()
            if self.detach_presence_in_joint_score:
                prob_dec_presence_out = prob_dec_presence_out.detach()
            outputs_class = inverse_sigmoid(
                outputs_class.sigmoid() * prob_dec_presence_out.unsqueeze(2)
            ).clamp(min=-10.0, max=10.0)
        _update_out(
            out, "pred_logits", outputs_class[:, :, :num_o2o], update_aux=self.training
        )
        _update_out(
            out, "pred_boxes", outputs_coord[:, :, :num_o2o], update_aux=self.training
        )
        _update_out(
            out,
            "pred_boxes_xyxy",
            outputs_boxes_xyxy[:, :, :num_o2o],
            update_aux=self.training,
        )
        if num_o2m > 0 and self.training:
            _update_out(
                out,
                "pred_logits_o2m",
                outputs_class[:, :, num_o2o:],
                update_aux=self.training,
            )
            _update_out(
                out,
                "pred_boxes_o2m",
                outputs_coord[:, :, num_o2o:],
                update_aux=self.training,
            )
            _update_out(
                out,
                "pred_boxes_xyxy_o2m",
                outputs_boxes_xyxy[:, :, num_o2o:],
                update_aux=self.training,
            )
    def _run_segmentation_heads(
        self,
        out,
        backbone_out,
        img_ids,
        vis_feat_sizes,  # NOTE(review): unused here; kept for call-site compatibility
        encoder_hidden_states,
        prompt,
        prompt_mask,
        hs,
    ):
        """Run the segmentation head (if any) and write mask outputs into `out`.

        Instance-specific outputs are split into o2o and, optionally, o2m
        parts; all other head outputs are copied through unchanged. When no
        segmentation head is configured, the FPN features are dropped from
        `backbone_out` (presumably to free memory -- they are no longer needed).
        """
        apply_dac = self.transformer.decoder.dac and self.training
        if self.segmentation_head is not None:
            num_o2o = (hs.size(2) // 2) if apply_dac else hs.size(2)
            num_o2m = hs.size(2) - num_o2o
            obj_queries = hs if self.o2m_mask_predict else hs[:, :, :num_o2o]
            # activation checkpointing is only enabled while training (saves memory)
            seg_head_outputs = activation_ckpt_wrapper(self.segmentation_head)(
                backbone_feats=backbone_out["backbone_fpn"],
                obj_queries=obj_queries,
                image_ids=img_ids,
                encoder_hidden_states=encoder_hidden_states,
                act_ckpt_enable=self.training and self.use_act_checkpoint_seg_head,
                prompt=prompt,
                prompt_mask=prompt_mask,
            )
            aux_masks = False  # self.aux_loss and self.segmentation_head.aux_masks
            for k, v in seg_head_outputs.items():
                if k in self.segmentation_head.instance_keys:
                    _update_out(out, k, v[:, :num_o2o], auxiliary=aux_masks)
                    if (
                        self.o2m_mask_predict and num_o2m > 0
                    ):  # handle o2m mask prediction
                        _update_out(
                            out, f"{k}_o2m", v[:, num_o2o:], auxiliary=aux_masks
                        )
                else:
                    out[k] = v
        else:
            backbone_out.pop("backbone_fpn", None)
def _get_best_mask(self, out):
prev_mask_idx = out["pred_logits"].argmax(dim=1).squeeze(1)
batch_idx = torch.arange(
out["pred_logits"].shape[0], device=prev_mask_idx.device
)
prev_mask_pred = out["pred_masks"][batch_idx, prev_mask_idx][:, None]
# Downsample mask to match image resolution.
prev_mask_pred = self.geometry_encoder.mask_encoder.mask_downsampler(
prev_mask_pred
)
prev_mask_pred = prev_mask_pred.flatten(-2).permute(2, 0, 1)
return prev_mask_pred
    def forward_grounding(
        self,
        backbone_out,
        find_input,
        find_target,
        geometric_prompt: Prompt,
    ):
        """Run one full detection pass (prompt -> encoder -> decoder -> masks).

        Args:
            backbone_out: backbone outputs (image/text features, possibly cached).
            find_input: per-frame input bundle (image ids, text ids, prompts).
            find_target: targets for matching; may be None at pure inference.
            geometric_prompt: encoded box/point prompt for this pass.

        Returns:
            dict with score/box/mask predictions, encoder state under
            "prev_encoder_out", and (when training or interactive eval)
            matcher "indices".
        """
        with torch.profiler.record_function("SAM3Image._encode_prompt"):
            prompt, prompt_mask, backbone_out = self._encode_prompt(
                backbone_out, find_input, geometric_prompt
            )
        # Run the encoder
        with torch.profiler.record_function("SAM3Image._run_encoder"):
            backbone_out, encoder_out, _ = self._run_encoder(
                backbone_out, find_input, prompt, prompt_mask
            )
        out = {
            "encoder_hidden_states": encoder_out["encoder_hidden_states"],
            # keep the encoder/backbone state around so later stages can reuse it
            "prev_encoder_out": {
                "encoder_out": encoder_out,
                "backbone_out": backbone_out,
            },
        }
        # Run the decoder
        with torch.profiler.record_function("SAM3Image._run_decoder"):
            out, hs = self._run_decoder(
                memory=out["encoder_hidden_states"],
                pos_embed=encoder_out["pos_embed"],
                src_mask=encoder_out["padding_mask"],
                out=out,
                prompt=prompt,
                prompt_mask=prompt_mask,
                encoder_out=encoder_out,
            )
        # Run segmentation heads
        with torch.profiler.record_function("SAM3Image._run_segmentation_heads"):
            self._run_segmentation_heads(
                out=out,
                backbone_out=backbone_out,
                img_ids=find_input.img_ids,
                vis_feat_sizes=encoder_out["vis_feat_sizes"],
                encoder_hidden_states=out["encoder_hidden_states"],
                prompt=prompt,
                prompt_mask=prompt_mask,
                hs=hs,
            )
        if self.training or self.num_interactive_steps_val > 0:
            self._compute_matching(out, self.back_convert(find_target))
        return out
    def _postprocess_out(self, out: Dict, multimask_output: bool = False):
        """Reduce multi-mask outputs to the single best mask at eval time.

        For multimask output, during eval we return the single best mask
        (selected by the highest predicted logit per batch element) under the
        dict keys expected by the evaluators, and keep the full multi-mask
        outputs under new "multi_*" keys.
        """
        num_mask_boxes = out["pred_boxes"].size(1)
        if not self.training and multimask_output and num_mask_boxes > 1:
            out["multi_pred_logits"] = out["pred_logits"]
            if "pred_masks" in out:
                out["multi_pred_masks"] = out["pred_masks"]
            out["multi_pred_boxes"] = out["pred_boxes"]
            out["multi_pred_boxes_xyxy"] = out["pred_boxes_xyxy"]
            best_mask_idx = out["pred_logits"].argmax(1).squeeze(1)
            batch_idx = torch.arange(len(best_mask_idx), device=best_mask_idx.device)
            out["pred_logits"] = out["pred_logits"][batch_idx, best_mask_idx].unsqueeze(
                1
            )
            if "pred_masks" in out:
                out["pred_masks"] = out["pred_masks"][
                    batch_idx, best_mask_idx
                ].unsqueeze(1)
            out["pred_boxes"] = out["pred_boxes"][batch_idx, best_mask_idx].unsqueeze(1)
            out["pred_boxes_xyxy"] = out["pred_boxes_xyxy"][
                batch_idx, best_mask_idx
            ].unsqueeze(1)
        return out
    def _get_dummy_prompt(self, num_prompts=1):
        """Build an empty geometric prompt (zero boxes) for `num_prompts` slots."""
        device = self.device
        geometric_prompt = Prompt(
            box_embeddings=torch.zeros(0, num_prompts, 4, device=device),
            box_mask=torch.zeros(num_prompts, 0, device=device, dtype=torch.bool),
        )
        return geometric_prompt
    def forward(self, input: BatchedDatapoint):
        """Run the model on a single-frame batch, optionally with interactive steps.

        Computes image and text backbone features once, then runs
        `forward_grounding` for the initial step plus `num_interactive_steps_val`
        refinement steps at eval time (0 during training), sampling new
        geometric prompts from the previous step's output in between.

        Returns:
            SAM3Output collecting the per-stage outputs.
        """
        device = self.device
        backbone_out = {"img_batch_all_stages": input.img_batch}
        backbone_out.update(self.backbone.forward_image(input.img_batch))
        num_frames = len(input.find_inputs)
        # this image model only supports a single frame per datapoint
        assert num_frames == 1
        text_outputs = self.backbone.forward_text(input.find_text_batch, device=device)
        backbone_out.update(text_outputs)
        previous_stages_out = SAM3Output(
            iter_mode=SAM3Output.IterMode.LAST_STEP_PER_STAGE
        )
        find_input = input.find_inputs[0]
        find_target = input.find_targets[0]
        if find_input.input_points is not None and find_input.input_points.numel() > 0:
            print("Warning: Point prompts are ignored in PCS.")
        num_interactive_steps = 0 if self.training else self.num_interactive_steps_val
        geometric_prompt = Prompt(
            box_embeddings=find_input.input_boxes,
            box_mask=find_input.input_boxes_mask,
            box_labels=find_input.input_boxes_label,
        )
        # Init vars that are shared across the loop.
        stage_outs = []
        for cur_step in range(num_interactive_steps + 1):
            if cur_step > 0:
                # We sample interactive geometric prompts (boxes, points)
                # NOTE(review): `interactive_prompt_sampler` is not set in __init__;
                # presumably attached externally/by a subclass -- confirm before
                # running with num_interactive_steps_val > 0.
                geometric_prompt, _ = self.interactive_prompt_sampler.sample(
                    geo_prompt=geometric_prompt,
                    find_target=find_target,
                    previous_out=stage_outs[-1],
                )
            out = self.forward_grounding(
                backbone_out=backbone_out,
                find_input=find_input,
                find_target=find_target,
                geometric_prompt=geometric_prompt.clone(),
            )
            stage_outs.append(out)
        previous_stages_out.append(stage_outs)
        return previous_stages_out
def _compute_matching(self, out, targets):
out["indices"] = self.matcher(out, targets)
for aux_out in out.get("aux_outputs", []):
aux_out["indices"] = self.matcher(aux_out, targets)
def back_convert(self, targets):
batched_targets = {
"boxes": targets.boxes.view(-1, 4),
"boxes_xyxy": box_cxcywh_to_xyxy(targets.boxes.view(-1, 4)),
"boxes_padded": targets.boxes_padded,
"positive_map": targets.boxes.new_ones(len(targets.boxes), 1),
"num_boxes": targets.num_boxes,
"masks": targets.segments,
"semantic_masks": targets.semantic_segments,
"is_valid_mask": targets.is_valid_segment,
"is_exhaustive": targets.is_exhaustive,
"object_ids_packed": targets.object_ids,
"object_ids_padded": targets.object_ids_padded,
}
return batched_targets
    def predict_inst(
        self,
        inference_state,
        **kwargs,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run the SAM2-style interactive predictor on a single image.

        Temporarily installs the cached backbone features from
        `inference_state` into `inst_interactive_predictor`, delegates to its
        `predict(**kwargs)`, then clears the predictor state again.

        Returns:
            The (masks, iou_predictions, low_res_masks) triple produced by the
            interactive predictor.
        """
        orig_h, orig_w = (
            inference_state["original_height"],
            inference_state["original_width"],
        )
        backbone_out = inference_state["backbone_out"]["sam2_backbone_out"]
        (
            _,
            vision_feats,
            _,
            _,
        ) = self.inst_interactive_predictor.model._prepare_backbone_features(
            backbone_out
        )
        # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
        vision_feats[-1] = (
            vision_feats[-1] + self.inst_interactive_predictor.model.no_mem_embed
        )
        # reshape seq-first features back to 1xCxHxW maps per level
        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(
                vision_feats[::-1], self.inst_interactive_predictor._bb_feat_sizes[::-1]
            )
        ][::-1]
        self.inst_interactive_predictor._features = {
            "image_embed": feats[-1],
            "high_res_feats": feats[:-1],
        }
        self.inst_interactive_predictor._is_image_set = True
        self.inst_interactive_predictor._orig_hw = [(orig_h, orig_w)]
        res = self.inst_interactive_predictor.predict(**kwargs)
        # reset the predictor so stale features cannot leak into the next call
        self.inst_interactive_predictor._features = None
        self.inst_interactive_predictor._is_image_set = False
        return res
    def predict_inst_batch(
        self,
        inference_state,
        *args,
        **kwargs,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
        """Batched variant of `predict_inst` over multiple images.

        Installs the batched backbone features and per-image original sizes
        into `inst_interactive_predictor`, delegates to its
        `predict_batch(*args, **kwargs)`, then clears the predictor state.
        """
        backbone_out = inference_state["backbone_out"]["sam2_backbone_out"]
        (
            _,
            vision_feats,
            _,
            _,
        ) = self.inst_interactive_predictor.model._prepare_backbone_features(
            backbone_out
        )
        # Add no_mem_embed, which is added to the lowest res feat. map during training on videos
        vision_feats[-1] = (
            vision_feats[-1] + self.inst_interactive_predictor.model.no_mem_embed
        )
        batch_size = vision_feats[-1].shape[1]
        orig_heights, orig_widths = (
            inference_state["original_heights"],
            inference_state["original_widths"],
        )
        assert (
            batch_size == len(orig_heights) == len(orig_widths)
        ), f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
        # reshape seq-first features back to BxCxHxW maps per level
        feats = [
            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
            for feat, feat_size in zip(
                vision_feats[::-1], self.inst_interactive_predictor._bb_feat_sizes[::-1]
            )
        ][::-1]
        self.inst_interactive_predictor._features = {
            "image_embed": feats[-1],
            "high_res_feats": feats[:-1],
        }
        self.inst_interactive_predictor._is_image_set = True
        self.inst_interactive_predictor._is_batch = True
        self.inst_interactive_predictor._orig_hw = [
            (orig_h, orig_w) for orig_h, orig_w in zip(orig_heights, orig_widths)
        ]
        res = self.inst_interactive_predictor.predict_batch(*args, **kwargs)
        # reset the predictor so stale features cannot leak into the next call
        self.inst_interactive_predictor._features = None
        self.inst_interactive_predictor._is_image_set = False
        self.inst_interactive_predictor._is_batch = False
        return res
class Sam3ImageOnVideoMultiGPU(Sam3Image):
    """Sam3Image variant that distributes per-frame detection across GPUs.

    Each rank computes detections for one frame of a chunk of `world_size`
    frames and the results are all-gathered, so every rank ends up with the
    whole chunk's outputs (see `forward_video_grounding_multigpu`).
    """

    def __init__(
        self, *args, async_all_gather=True, gather_backbone_out=None, **kwargs
    ):
        """
        Args:
            async_all_gather: use asynchronous (overlapped) all-gather ops.
            gather_backbone_out: also all-gather SAM2 backbone features; when
                None, enabled only for `SAM3VLBackbone` backbones.
        """
        super().__init__(*args, **kwargs)
        # rank/world size come from the standard torch.distributed env vars
        self.rank = int(os.getenv("RANK", "0"))
        self.world_size = int(os.getenv("WORLD_SIZE", "1"))
        self.async_all_gather = async_all_gather
        # if gather_backbone is not set, default to gathering only for `SAM3VLBackbone`
        if gather_backbone_out is None:
            gather_backbone_out = isinstance(self.backbone, SAM3VLBackbone)
        self.gather_backbone_out = gather_backbone_out
    def forward_video_grounding_multigpu(
        self,
        backbone_out,
        find_inputs,
        geometric_prompt: Prompt,
        frame_idx,
        num_frames,
        # `multigpu_buffer` is a dict to cache detector's outputs in a chunk between different calls
        multigpu_buffer,
        track_in_reverse=False,
        # whether to also return the SAM2 backbone features
        return_sam2_backbone_feats=False,
        # whether to perform NMS and suppress the scores of those detections removed by NMS
        run_nms=False,
        nms_prob_thresh=None,
        nms_iou_thresh=None,
        **kwargs,
    ):
        """
        Compute the detector's detection outputs in a distributed manner, where all GPUs process
        a chunk of frames (equal to the number of GPUs) at once and store them in cache.

        Returns (out, backbone_out): `out` holds the gathered per-frame results
        for `frame_idx`; the next chunk is prefetched so transfer overlaps with
        compute, and the previous chunk is evicted from `multigpu_buffer`.
        """
        # Step 1: fetch the detector outputs in the current chunk from buffer
        frame_idx_curr_b = frame_idx - frame_idx % self.world_size
        frame_idx_curr_e = min(frame_idx_curr_b + self.world_size, num_frames)
        # in case the current frame's detection results are not in the buffer yet, build the current chunk
        # (this should only happen on the first chunk, since we are also building the next chunk below)
        if frame_idx not in multigpu_buffer:
            with torch.profiler.record_function("build_multigpu_buffer_next_chunk1"):
                self._build_multigpu_buffer_next_chunk(
                    backbone_out=backbone_out,
                    find_inputs=find_inputs,
                    geometric_prompt=geometric_prompt,
                    frame_idx_begin=frame_idx_curr_b,
                    frame_idx_end=frame_idx_curr_e,
                    num_frames=num_frames,
                    multigpu_buffer=multigpu_buffer,
                    run_nms=run_nms,
                    nms_prob_thresh=nms_prob_thresh,
                    nms_iou_thresh=nms_iou_thresh,
                )
        # read out the current frame's results from `multigpu_buffer`
        out = {}
        for k, (v, handle) in multigpu_buffer[frame_idx].items():
            if k.startswith("sam2_backbone_") and not return_sam2_backbone_feats:
                continue
            if handle is not None:
                handle.wait()  # wait for async all-gather to finish
            out[k] = v
        # Step 2: remove detection outputs of the previous chunk from cache to save GPU memory
        if not track_in_reverse and frame_idx_curr_b - self.world_size >= 0:
            frame_idx_prev_e = frame_idx_curr_b
            frame_idx_prev_b = frame_idx_curr_b - self.world_size
        elif track_in_reverse and frame_idx_curr_e < num_frames:
            frame_idx_prev_b = frame_idx_curr_e
            frame_idx_prev_e = min(frame_idx_prev_b + self.world_size, num_frames)
        else:
            frame_idx_prev_b = frame_idx_prev_e = None
        if frame_idx_prev_b is not None:
            for frame_idx_rm in range(frame_idx_prev_b, frame_idx_prev_e):
                multigpu_buffer.pop(frame_idx_rm, None)
        # Step 3: compute and cache detection outputs of the next chunk ahead of time
        # (so that we can overlap computation with all-gather transfer)
        if not track_in_reverse and frame_idx_curr_e < num_frames:
            frame_idx_next_b = frame_idx_curr_e
            frame_idx_next_e = min(frame_idx_next_b + self.world_size, num_frames)
        elif track_in_reverse and frame_idx_curr_b - self.world_size >= 0:
            frame_idx_next_e = frame_idx_curr_b
            frame_idx_next_b = frame_idx_curr_b - self.world_size
        else:
            frame_idx_next_b = frame_idx_next_e = None
        if frame_idx_next_b is not None and frame_idx_next_b not in multigpu_buffer:
            with torch.profiler.record_function("build_multigpu_buffer_next_chunk2"):
                self._build_multigpu_buffer_next_chunk(
                    backbone_out=backbone_out,
                    find_inputs=find_inputs,
                    geometric_prompt=geometric_prompt,
                    frame_idx_begin=frame_idx_next_b,
                    frame_idx_end=frame_idx_next_e,
                    num_frames=num_frames,
                    multigpu_buffer=multigpu_buffer,
                    run_nms=run_nms,
                    nms_prob_thresh=nms_prob_thresh,
                    nms_iou_thresh=nms_iou_thresh,
                )
        return out, backbone_out
    def _build_multigpu_buffer_next_chunk(
        self,
        backbone_out,
        find_inputs,
        geometric_prompt: Prompt,
        frame_idx_begin,
        frame_idx_end,
        num_frames,
        multigpu_buffer,
        run_nms=False,
        nms_prob_thresh=None,
        nms_iou_thresh=None,
    ):
        """Compute detection outputs on a chunk of frames and store their results in multigpu_buffer.

        Each buffer entry maps a frame index to ``{key: (tensor, async_handle)}``;
        the handle (if any) must be waited on before reading the tensor.
        """
        # each GPU computes detections on one frame in the chunk (in a round-robin manner)
        frame_idx_local_gpu = min(frame_idx_begin + self.rank, frame_idx_end - 1)
        # `forward_grounding` (from base class `Sam3ImageOnVideo`) runs the detector on a single frame
        with torch.profiler.record_function("forward_grounding"):
            out_local = self.forward_grounding(
                backbone_out=backbone_out,
                find_input=find_inputs[frame_idx_local_gpu],
                find_target=None,
                geometric_prompt=geometric_prompt,
            )
        if run_nms:
            with torch.profiler.record_function("nms_masks"):
                # run NMS as a post-processing step on top of the detection outputs
                assert nms_prob_thresh is not None and nms_iou_thresh is not None
                pred_probs = out_local["pred_logits"].squeeze(-1).sigmoid()
                pred_masks = out_local["pred_masks"]
                # loop over text prompts (not an overhead for demo where there's only 1 prompt)
                for prompt_idx in range(pred_probs.size(0)):
                    keep = nms_masks(
                        pred_probs=pred_probs[prompt_idx],
                        pred_masks=pred_masks[prompt_idx],
                        prob_threshold=nms_prob_thresh,
                        iou_threshold=nms_iou_thresh,
                    )
                    # set a very low threshold for those detections removed by NMS
                    out_local["pred_logits"][prompt_idx, :, 0] -= 1e4 * (~keep).float()
        if self.gather_backbone_out:
            # gather the SAM 2 backbone features across GPUs
            feats = out_local["prev_encoder_out"]["backbone_out"]["sam2_backbone_out"]
            assert len(feats["backbone_fpn"]) == 3  # SAM2 backbone always have 3 levels
            # cast the SAM2 backbone features to bfloat16 for all-gather (this is usually
            # a no-op, SAM2 backbone features are likely already in bfloat16 due to AMP)
            backbone_fpn_bf16 = [x.to(torch.bfloat16) for x in feats["backbone_fpn"]]
            fpn0, fpn_handle0 = self._gather_tensor(backbone_fpn_bf16[0])
            fpn1, fpn_handle1 = self._gather_tensor(backbone_fpn_bf16[1])
            fpn2, fpn_handle2 = self._gather_tensor(backbone_fpn_bf16[2])
            # vision_pos_enc is the same on all frames, so no need to all-gather them
            vision_pos_enc = feats["vision_pos_enc"]
            # trim the detector output to only include the necessary keys
            out_local = {
                "pred_logits": out_local["pred_logits"],
                "pred_boxes": out_local["pred_boxes"],
                "pred_boxes_xyxy": out_local["pred_boxes_xyxy"],
                "pred_masks": out_local["pred_masks"],
            }
        # gather the results: after this step, each GPU will receive detector outputs on
        # all frames in the chunk and store them in `multigpu_buffer`
        out_gathered = {k: self._gather_tensor(v) for k, v in out_local.items()}
        for rank in range(self.world_size):
            frame_idx_to_save = frame_idx_begin + rank
            if frame_idx_to_save >= num_frames:
                continue
            frame_buffer = {
                k: (v[rank], handle) for k, (v, handle) in out_gathered.items()
            }
            if self.gather_backbone_out:
                # also add gathered SAM 2 backbone features to frame_buffer
                frame_buffer["tracker_backbone_fpn_0"] = (fpn0[rank], fpn_handle0)
                frame_buffer["tracker_backbone_fpn_1"] = (fpn1[rank], fpn_handle1)
                frame_buffer["tracker_backbone_fpn_2"] = (fpn2[rank], fpn_handle2)
                frame_buffer["tracker_backbone_pos_enc"] = (vision_pos_enc, None)
            multigpu_buffer[frame_idx_to_save] = frame_buffer
def _gather_tensor(self, x):
if self.world_size == 1:
return [x], None
async_op = self.async_all_gather
# here `.contiguous()` is required -- otherwise NCCL all_gather
# sometimes gives wrong results
x = x.contiguous() # ensure contiguous memory for NCCL
output_list = [torch.empty_like(x) for _ in range(self.world_size)]
handle = torch.distributed.all_gather(output_list, x, async_op=async_op)
return output_list, handle

View File

@@ -0,0 +1,222 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Dict, List
import numpy as np
import PIL
import torch
from sam3.model import box_ops
from sam3.model.data_misc import FindStage, interpolate
from torchvision.transforms import v2
class Sam3Processor:
    """Image-level inference wrapper around a SAM3 grounding model.

    Handles image preprocessing, text and geometric (box) prompting, and
    post-processing of the model outputs into boxes, masks and scores, all
    carried in a mutable `state` dict returned by each call.
    """

    def __init__(self, model, resolution=1008, device="cuda", confidence_threshold=0.5):
        self.model = model
        self.resolution = resolution
        self.device = device
        # Resize to a fixed square resolution and normalize to [-1, 1].
        self.transform = v2.Compose(
            [
                v2.ToDtype(torch.uint8, scale=True),
                v2.Resize(size=(resolution, resolution)),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.confidence_threshold = confidence_threshold
        # Fixed single-image / single-text "find" request with no geometric inputs.
        self.find_stage = FindStage(
            img_ids=torch.tensor([0], device=device, dtype=torch.long),
            text_ids=torch.tensor([0], device=device, dtype=torch.long),
            input_boxes=None,
            input_boxes_mask=None,
            input_boxes_label=None,
            input_points=None,
            input_points_mask=None,
        )

    def _project_sam2_fpn_features(self, backbone_out):
        """Apply the instance-interactive predictor's conv_s0/conv_s1 projections
        to the two high-res SAM2 FPN levels, in place on `backbone_out`.

        No-op when instance interactivity is disabled or no SAM2 features exist.
        (Extracted helper: this block was duplicated in `set_image` and
        `set_image_batch`.)
        """
        if self.model.inst_interactive_predictor is None:
            return
        if "sam2_backbone_out" not in backbone_out:
            return
        sam2_backbone_out = backbone_out["sam2_backbone_out"]
        decoder = self.model.inst_interactive_predictor.model.sam_mask_decoder
        sam2_backbone_out["backbone_fpn"][0] = decoder.conv_s0(
            sam2_backbone_out["backbone_fpn"][0]
        )
        sam2_backbone_out["backbone_fpn"][1] = decoder.conv_s1(
            sam2_backbone_out["backbone_fpn"][1]
        )

    @torch.inference_mode()
    def set_image(self, image, state=None):
        """Sets the image on which we want to do predictions.

        Accepts a PIL image or a tensor/ndarray laid out so that the last two
        dims are (H, W). Stores the backbone features and original size in
        `state` and returns it.
        """
        if state is None:
            state = {}
        if isinstance(image, PIL.Image.Image):
            width, height = image.size
        elif isinstance(image, (torch.Tensor, np.ndarray)):
            height, width = image.shape[-2:]
        else:
            raise ValueError("Image must be a PIL image or a tensor")
        image = v2.functional.to_image(image).to(self.device)
        image = self.transform(image).unsqueeze(0)
        state["original_height"] = height
        state["original_width"] = width
        state["backbone_out"] = self.model.backbone.forward_image(image)
        self._project_sam2_fpn_features(state["backbone_out"])
        return state

    @torch.inference_mode()
    def set_image_batch(self, images: List[np.ndarray], state=None):
        """Sets the image batch on which we want to do predictions.

        `images` must be a non-empty list of PIL images.
        """
        if state is None:
            state = {}
        if not isinstance(images, list):
            raise ValueError("Images must be a list of PIL images or tensors")
        assert len(images) > 0, "Images list must not be empty"
        assert isinstance(
            images[0], PIL.Image.Image
        ), "Images must be a list of PIL images"
        state["original_heights"] = [image.height for image in images]
        state["original_widths"] = [image.width for image in images]
        images = [
            self.transform(v2.functional.to_image(image).to(self.device))
            for image in images
        ]
        images = torch.stack(images, dim=0)
        state["backbone_out"] = self.model.backbone.forward_image(images)
        self._project_sam2_fpn_features(state["backbone_out"])
        return state

    @torch.inference_mode()
    def set_text_prompt(self, prompt: str, state: Dict):
        """Sets the text prompt and run the inference"""
        if "backbone_out" not in state:
            raise ValueError("You must call set_image before set_text_prompt")
        text_outputs = self.model.backbone.forward_text([prompt], device=self.device)
        # will erase the previous text prompt if any
        state["backbone_out"].update(text_outputs)
        if "geometric_prompt" not in state:
            state["geometric_prompt"] = self.model._get_dummy_prompt()
        return self._forward_grounding(state)

    @torch.inference_mode()
    def add_geometric_prompt(self, box: List, label: bool, state: Dict):
        """Adds a box prompt and run the inference.
        The image needs to be set, but not necessarily the text prompt.
        The box is assumed to be in [center_x, center_y, width, height] format and normalized in [0, 1] range.
        The label is True for a positive box, False for a negative box.
        """
        if "backbone_out" not in state:
            # BUGFIX: the message used to say "set_text_prompt" (copy-paste error)
            raise ValueError("You must call set_image before add_geometric_prompt")
        if "language_features" not in state["backbone_out"]:
            # Looks like we don't have a text prompt yet. This is allowed, but we need
            # to set the text prompt to "visual" for the model to rely only on the
            # geometric prompt.
            dummy_text_outputs = self.model.backbone.forward_text(
                ["visual"], device=self.device
            )
            state["backbone_out"].update(dummy_text_outputs)
        if "geometric_prompt" not in state:
            state["geometric_prompt"] = self.model._get_dummy_prompt()
        # adding a batch and sequence dimension
        boxes = torch.tensor(box, device=self.device, dtype=torch.float32).view(1, 1, 4)
        labels = torch.tensor([label], device=self.device, dtype=torch.bool).view(1, 1)
        state["geometric_prompt"].append_boxes(boxes, labels)
        return self._forward_grounding(state)

    def reset_all_prompts(self, state: Dict):
        """Removes all the prompts and results"""
        if "backbone_out" in state:
            backbone_keys_to_del = [
                "language_features",
                "language_mask",
                "language_embeds",
            ]
            for key in backbone_keys_to_del:
                if key in state["backbone_out"]:
                    del state["backbone_out"][key]
        keys_to_del = ["geometric_prompt", "boxes", "masks", "masks_logits", "scores"]
        for key in keys_to_del:
            if key in state:
                del state[key]

    @torch.inference_mode()
    def set_confidence_threshold(self, threshold: float, state=None):
        """Sets the confidence threshold for the masks"""
        self.confidence_threshold = threshold
        if state is not None and "boxes" in state:
            # we need to filter the boxes again
            # In principle we could do this more efficiently since we would only need
            # to rerun the heads. But this is simpler and not too inefficient
            return self._forward_grounding(state)
        return state

    @torch.inference_mode()
    def _forward_grounding(self, state: Dict):
        """Run the grounding forward pass and post-process the raw outputs.

        Populates `state` with thresholded, image-sized results:
        "boxes" (xyxy in pixels), "masks" (bool), "masks_logits" (sigmoid
        probabilities) and "scores".
        """
        outputs = self.model.forward_grounding(
            backbone_out=state["backbone_out"],
            find_input=self.find_stage,
            geometric_prompt=state["geometric_prompt"],
            find_target=None,
        )
        out_bbox = outputs["pred_boxes"]
        out_logits = outputs["pred_logits"]
        out_masks = outputs["pred_masks"]
        # Per-query probabilities are modulated by the global presence score.
        out_probs = out_logits.sigmoid()
        presence_score = outputs["presence_logit_dec"].sigmoid().unsqueeze(1)
        out_probs = (out_probs * presence_score).squeeze(-1)
        # Keep only detections above the confidence threshold.
        keep = out_probs > self.confidence_threshold
        out_probs = out_probs[keep]
        out_masks = out_masks[keep]
        out_bbox = out_bbox[keep]
        # convert to [x0, y0, x1, y1] format and scale to original image pixels
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        img_h = state["original_height"]
        img_w = state["original_width"]
        scale_fct = torch.tensor([img_w, img_h, img_w, img_h]).to(self.device)
        boxes = boxes * scale_fct[None, :]
        # Upsample mask logits to the original image size, then binarize at 0.5.
        out_masks = interpolate(
            out_masks.unsqueeze(1),
            (img_h, img_w),
            mode="bilinear",
            align_corners=False,
        ).sigmoid()
        state["masks_logits"] = out_masks
        state["masks"] = out_masks > 0.5
        state["boxes"] = boxes
        state["scores"] = out_probs
        return state

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,427 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import numpy as np
import torch
import torch.nn.functional as F
from numpy.typing import NDArray
from sam3.model.edt import edt_triton
def sample_box_points(
    masks: torch.Tensor,
    noise: float = 0.1,  # SAM default
    noise_bound: int = 20,  # SAM default
    top_left_label: int = 2,
    bottom_right_label: int = 3,
) -> tuple[NDArray, NDArray]:
    """
    Sample a noised version of the top-left and bottom-right corners of each mask's box.

    Inputs:
    - masks: [B, 1, H, W] tensor
    - noise: noise as a fraction of box width and height, dtype=float
    - noise_bound: maximum amount of noise (in pure pixels), dtype=int

    Returns:
    - box_coords: [B, 2, 2], (x, y) coordinates of the top-left and bottom-right corners
    - box_labels: [B, 2], label 2 is reserved for top-left and 3 for bottom-right corners
    """
    device = masks.device
    B, _, H, W = masks.shape
    corners = mask_to_box(masks)  # [B, 1, 4] as (x0, y0, x1, y1)
    labels = torch.tensor(
        [top_left_label, bottom_right_label], dtype=torch.int, device=device
    ).repeat(B)
    if noise > 0.0:
        bound = noise_bound
        if not isinstance(bound, torch.Tensor):
            bound = torch.tensor(noise_bound, device=device)
        # Per-box jitter magnitude: a fraction of the box size, capped in pixels.
        widths = corners[..., 2] - corners[..., 0]
        heights = corners[..., 3] - corners[..., 1]
        max_dx = torch.min(widths * noise, bound)
        max_dy = torch.min(heights * noise, bound)
        jitter = 2 * torch.rand(B, 1, 4, device=device) - 1  # uniform in [-1, 1)
        jitter = jitter * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)
        corners = corners + jitter
        # Clamp to valid (uncentered) pixel coordinates, in place.
        limits = torch.tensor([W, H, W, H], device=device) - 1
        corners.clamp_(torch.zeros_like(limits), limits)
    return corners.reshape(-1, 2, 2), labels.reshape(-1, 2)
def mask_to_box(masks: torch.Tensor):
    """
    Compute the tight bounding box of each input mask.

    Inputs:
    - masks: [B, 1, H, W] tensor

    Returns:
    - box_coords: [B, 1, 4] with (x_min, y_min, x_max, y_max); all zeros for empty masks
    """
    B, _, h, w = masks.shape
    device = masks.device
    col_idx = torch.arange(w, device=device, dtype=torch.int32)
    row_idx = torch.arange(h, device=device, dtype=torch.int32)
    # "xy" indexing: both grids have shape (h, w), x varying along columns.
    grid_x, grid_y = torch.meshgrid(col_idx, row_idx, indexing="xy")
    grid_x = grid_x[None, None, ...].expand(B, 1, h, w)
    grid_y = grid_y[None, None, ...].expand(B, 1, h, w)
    # Out-of-mask positions get sentinels (w/h for min, -1 for max) so they never win.
    x_min = torch.where(masks, grid_x, w).flatten(-2).min(dim=-1).values
    x_max = torch.where(masks, grid_x, -1).flatten(-2).max(dim=-1).values
    y_min = torch.where(masks, grid_y, h).flatten(-2).min(dim=-1).values
    y_max = torch.where(masks, grid_y, -1).flatten(-2).max(dim=-1).values
    boxes = torch.stack((x_min, y_min, x_max, y_max), dim=-1)
    # Empty masks would yield degenerate sentinel boxes; zero them out instead.
    is_nonempty = masks.sum(dim=(-1, -2))[..., None] > 0
    return torch.where(is_nonempty, boxes, torch.zeros_like(boxes))
def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
    """
    Sample `num_pt` random points (with labels) independently from the error regions.

    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
    - num_pt: number of points to sample independently for each of the B error maps

    Outputs:
    - points: [B, num_pt, 2], dtype=torch.float, (x, y) coordinates of each sampled point
    - labels: [B, num_pt], dtype=torch.int32, 1 for positive clicks and 0 for negative clicks
    """
    if pred_masks is None:
        # No prediction provided: treat it as an all-empty mask.
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
    assert num_pt >= 0

    B, _, H_im, W_im = gt_masks.shape
    device = gt_masks.device
    # False positives want a negative click; false negatives want a positive click.
    fp_region = pred_masks & ~gt_masks
    fn_region = gt_masks & ~pred_masks
    # Per-mask flag: prediction exactly matches the ground truth everywhere.
    exact_match = torch.all((gt_masks == pred_masks).flatten(2), dim=2)[..., None, None]
    # Channel 0 scores candidate negative clicks, channel 1 candidate positive
    # clicks; the overall argmax decides both position and label. When the
    # prediction is fully correct, a negative click is drawn from background.
    scores = torch.rand(B, num_pt, H_im, W_im, 2, device=device)
    scores[..., 0] *= fp_region | (exact_match & ~gt_masks)
    scores[..., 1] *= fn_region
    flat_idx = scores.flatten(2).argmax(dim=2)
    labels = (flat_idx % 2).to(torch.int32)  # channel index == click label
    pixel_idx = flat_idx // 2
    xs = pixel_idx % W_im
    ys = pixel_idx // W_im
    points = torch.stack([xs, ys], dim=2).to(torch.float)
    return points, labels
def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
    """
    Sample 1 point (with its label) from the center of each error region, i.e. the
    point with the largest distance to the boundary of that region (RITM sampling,
    https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py).

    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
    - padding: if True, pad with boundary of 1 px for distance transform

    Outputs:
    - points: [B, 1, 2], (x, y) coordinates of each sampled point
    - labels: [B, 1], 1 for positive clicks and 0 for negative clicks
    """
    if pred_masks is None:
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape

    B, _, H, W = gt_masks.shape
    # FP pixels want a negative click, FN pixels a positive click.
    fp = (pred_masks & ~gt_masks).squeeze(1)
    fn = (gt_masks & ~pred_masks).squeeze(1)
    if padding:
        # Surround with a 1-px zero border so the distance transform treats the
        # image edge as region boundary.
        fp_padded = fp.new_zeros(B, H + 2, W + 2)
        fp_padded[:, 1:-1, 1:-1] = fp
        fn_padded = fn.new_zeros(B, H + 2, W + 2)
        fn_padded[:, 1:-1, 1:-1] = fn
    else:
        fp_padded, fn_padded = fp, fn
    fn_dist = edt_triton(fn_padded)
    fp_dist = edt_triton(fp_padded)
    if padding:
        fn_dist = fn_dist[:, 1:-1, 1:-1]
        fp_dist = fp_dist[:, 1:-1, 1:-1]
    # Pick whichever error region has the deeper interior point.
    fn_best, fn_idx = fn_dist.reshape(B, -1).max(dim=-1)
    fp_best, fp_idx = fp_dist.reshape(B, -1).max(dim=-1)
    take_positive = fn_best > fp_best
    chosen = torch.where(take_positive, fn_idx, fp_idx)
    points = torch.stack([chosen % W, chosen // W], -1)
    return points.unsqueeze(1), take_positive.long().unsqueeze(1)
def sample_one_point_from_error_center_slow(gt_masks, pred_masks, padding=True):
    """
    Reference (CPU/OpenCV) implementation of RITM error-center sampling: pick the
    point with the largest distance to the boundary of each error region
    (https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py).

    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
    - padding: if True, pad with boundary of 1 px for distance transform

    Outputs:
    - points: [B, 1, 2], dtype=torch.float, (x, y) coordinates of each sampled point
    - labels: [B, 1], dtype=torch.int32, 1 for positive clicks and 0 for negative clicks
    """
    import cv2  # delay OpenCV import to avoid unnecessary dependency

    if pred_masks is None:
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape

    B, _, _, W_im = gt_masks.shape
    device = gt_masks.device
    # FP pixels want a negative click, FN pixels a positive click.
    fp_np = (~gt_masks & pred_masks).cpu().numpy()
    fn_np = (gt_masks & ~pred_masks).cpu().numpy()
    points = torch.zeros(B, 1, 2, dtype=torch.float)
    labels = torch.ones(B, 1, dtype=torch.int32)
    for b in range(B):
        fn_mask = fn_np[b, 0]
        fp_mask = fp_np[b, 0]
        if padding:
            fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
            fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")
        # Distance of every FN/FP pixel to its region boundary.
        fn_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
        fp_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
        if padding:
            fn_dt = fn_dt[1:-1, 1:-1]
            fp_dt = fp_dt[1:-1, 1:-1]
        # Deepest interior point of each region; keep whichever wins.
        fn_flat = fn_dt.reshape(-1)
        fp_flat = fp_dt.reshape(-1)
        fn_best = np.argmax(fn_flat)
        fp_best = np.argmax(fp_flat)
        is_positive = fn_flat[fn_best] > fp_flat[fp_best]
        chosen = fn_best if is_positive else fp_best
        points[b, 0, 0] = chosen % W_im  # x
        points[b, 0, 1] = chosen // W_im  # y
        labels[b, 0] = int(is_positive)
    return points.to(device), labels.to(device)
def get_next_point(gt_masks, pred_masks, method):
    """Dispatch to a point-sampling strategy by name ("uniform" or "center")."""
    samplers = {
        "uniform": sample_random_points_from_errors,
        "center": sample_one_point_from_error_center,
    }
    if method not in samplers:
        raise ValueError(f"unknown sampling method {method}")
    return samplers[method](gt_masks, pred_masks)
def select_closest_cond_frames(
    frame_idx, cond_frame_outputs, max_cond_frame_num, keep_first_cond_frame=False
):
    """
    Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
    that are temporally closest to the current frame at `frame_idx`. Here, we take
    - a) the closest conditioning frame before `frame_idx` (if any);
    - b) the closest conditioning frame at or after `frame_idx` (if any);
    - c) optionally (`keep_first_cond_frame=True`) the first conditioning frame;
    - d) any other temporally closest conditioning frames until reaching a total
         of `max_cond_frame_num` conditioning frames.

    Outputs:
    - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
    - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
    """
    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
        # Unlimited budget (or everything fits): keep all frames.
        selected_outputs = cond_frame_outputs
        unselected_outputs = {}
    else:
        assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
        selected_outputs = {}
        if keep_first_cond_frame:
            # Pin the earliest conditioning frame before `frame_idx`; if none
            # exists (e.g. when tracking in reverse), pin the latest one after it.
            idx_first = min(
                (t for t in cond_frame_outputs if t < frame_idx), default=None
            )
            if idx_first is None:
                # Maybe we are tracking in reverse
                idx_first = max(
                    (t for t in cond_frame_outputs if t > frame_idx), default=None
                )
            if idx_first is not None:
                selected_outputs[idx_first] = cond_frame_outputs[idx_first]
        # the closest conditioning frame before `frame_idx` (if any)
        idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
        if idx_before is not None:
            selected_outputs[idx_before] = cond_frame_outputs[idx_before]
        # the closest conditioning frame at or after `frame_idx` (if any)
        idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
        if idx_after is not None:
            selected_outputs[idx_after] = cond_frame_outputs[idx_after]
        # add other temporally closest conditioning frames until reaching a total
        # of `max_cond_frame_num` conditioning frames.
        # BUGFIX: clamp the remaining budget at zero. With
        # `keep_first_cond_frame=True` the forced picks above can already exceed
        # `max_cond_frame_num`; a negative value here would make the slice below
        # `[:negative]` and incorrectly add almost all remaining frames.
        num_remain = max(0, max_cond_frame_num - len(selected_outputs))
        inds_remain = sorted(
            (t for t in cond_frame_outputs if t not in selected_outputs),
            key=lambda x: abs(x - frame_idx),
        )[:num_remain]
        selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
        unselected_outputs = {
            t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
        }
    return selected_outputs, unselected_outputs
def get_1d_sine_pe(pos_inds, dim, temperature=10000):
    """
    1D sinusoidal positional embedding as in the original Transformer paper.
    Returns a tensor of shape `pos_inds.shape + (dim,)`.
    """
    half = dim // 2
    freq_idx = torch.arange(half, dtype=torch.float32, device=pos_inds.device)
    # Paired frequencies: indices 2k and 2k+1 share the same wavelength.
    inv_freq = temperature ** (2 * (freq_idx // 2) / half)
    angles = pos_inds.unsqueeze(-1) / inv_freq
    # First half sines, second half cosines.
    return torch.cat([angles.sin(), angles.cos()], dim=-1)
def get_best_gt_match_from_multimasks(pred_multimasks, gt_masks, pred_scores=None):
    """
    Select the channel of `pred_multimasks` that best matches `gt_masks` by IoU.
    Optionally falls back to `pred_scores` to break ties when every IoU is zero
    (e.g. an empty GT mask).
    """
    assert pred_multimasks.ndim == 4 and gt_masks.ndim == 4
    if pred_multimasks.size(1) == 1:
        return pred_multimasks  # single mask channel -- nothing to choose from
    binary = pred_multimasks > 0
    inter = (binary & gt_masks).sum(dim=(2, 3)).float()
    union = (binary | gt_masks).sum(dim=(2, 3)).float()
    ious = inter / union.clamp(min=1.0)  # clamp avoids 0/0 on empty unions
    if pred_scores is not None:
        # If no channel overlaps GT at all, rank by the predicted scores instead.
        any_overlap = torch.any(ious > 0).expand_as(ious)
        ranking = torch.where(any_overlap, ious, pred_scores)
    else:
        ranking = ious
    best_idx = torch.argmax(ranking, dim=-1)
    rows = torch.arange(ranking.size(0), device=ranking.device)
    return pred_multimasks[rows, best_idx].unsqueeze(1)
def fill_holes_in_mask_scores(mask, max_area, fill_holes=True, remove_sprinkles=True):
    """
    Post-process mask scores by filling small background holes and/or removing
    small foreground sprinkles whose connected-component area is <= `max_area`.

    Relies on the "cc_torch" package for fast connected components; install it via
    ```
    pip uninstall -y cc_torch; TORCH_CUDA_ARCH_LIST=8.0 9.0 pip install git+https://github.com/ronghanghu/cc_torch
    ```
    (`TORCH_CUDA_ARCH_LIST=8.0` is for A100 GPUs). Otherwise it falls back to a
    slightly slower triton implementation, or skimage if the tensor is on CPU.
    """
    if max_area <= 0:
        return mask  # nothing to fill in this case
    if fill_holes:
        # Flip small background components to foreground by assigning a small
        # positive mask score (0.1).
        background = mask <= 0
        _, bg_areas = _get_connected_components_with_padding(background)
        tiny_bg = background & (bg_areas <= max_area)
        mask = torch.where(tiny_bg, 0.1, mask)
    if remove_sprinkles:
        # Flip small foreground components to background by assigning a small
        # negative mask score (-0.1). Only components whose area is under both
        # `max_area` and half of the entire mask's foreground area are removed,
        # so sprinkles go away while tiny objects we want to track survive.
        foreground = mask > 0
        area_limit = foreground.sum(dim=(2, 3), keepdim=True, dtype=torch.int32)
        area_limit.floor_divide_(2).clamp_(max=max_area)
        _, fg_areas = _get_connected_components_with_padding(foreground)
        tiny_fg = foreground & (fg_areas <= area_limit)
        mask = torch.where(tiny_fg, -0.1, mask)
    return mask
def _get_connected_components_with_padding(mask):
    """Get connected components from masks (possibly padding them to an even size)."""
    from sam3.perflib.connected_components import connected_components

    mask = mask.to(torch.uint8)
    _, _, H, W = mask.shape
    # cc_torch requires even spatial dims; pad on the right/bottom when needed.
    pad_h = H % 2
    pad_w = W % 2
    if pad_h or pad_w:
        # F.pad padding format is (left, right, top, bottom)
        padded = F.pad(mask, (0, pad_w, 0, pad_h), mode="constant", value=0)
        labels, counts = connected_components(padded)
        # Crop back to the original spatial size.
        return labels[:, :, :H, :W], counts[:, :, :H, :W]
    return connected_components(mask)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,521 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import datetime
import gc
import multiprocessing as mp
import os
import queue
import socket
import sys
import time
import uuid
from contextlib import closing
from typing import List, Optional
import psutil
import torch
from sam3.logger import get_logger
logger = get_logger(__name__)
class Sam3VideoPredictor:
# a global dictionary that holds all inference states for this model (key is session_id)
_ALL_INFERENCE_STATES = {}
def __init__(
self,
checkpoint_path=None,
bpe_path=None,
has_presence_token=True,
geo_encoder_use_img_cross_attn=True,
strict_state_dict_loading=True,
async_loading_frames=False,
video_loader_type="cv2",
apply_temporal_disambiguation: bool = True,
):
self.async_loading_frames = async_loading_frames
self.video_loader_type = video_loader_type
from sam3.model_builder import build_sam3_video_model
self.model = (
build_sam3_video_model(
checkpoint_path=checkpoint_path,
bpe_path=bpe_path,
has_presence_token=has_presence_token,
geo_encoder_use_img_cross_attn=geo_encoder_use_img_cross_attn,
strict_state_dict_loading=strict_state_dict_loading,
apply_temporal_disambiguation=apply_temporal_disambiguation,
)
.cuda()
.eval()
)
@torch.inference_mode()
def handle_request(self, request):
"""Dispatch a request based on its type."""
request_type = request["type"]
if request_type == "start_session":
return self.start_session(
resource_path=request["resource_path"],
session_id=request.get("session_id", None),
)
elif request_type == "add_prompt":
return self.add_prompt(
session_id=request["session_id"],
frame_idx=request["frame_index"],
text=request.get("text", None),
points=request.get("points", None),
point_labels=request.get("point_labels", None),
bounding_boxes=request.get("bounding_boxes", None),
bounding_box_labels=request.get("bounding_box_labels", None),
obj_id=request.get("obj_id", None),
)
elif request_type == "remove_object":
return self.remove_object(
session_id=request["session_id"],
obj_id=request["obj_id"],
is_user_action=request.get("is_user_action", True),
)
elif request_type == "reset_session":
return self.reset_session(session_id=request["session_id"])
elif request_type == "close_session":
return self.close_session(session_id=request["session_id"])
else:
raise RuntimeError(f"invalid request type: {request_type}")
@torch.inference_mode()
def handle_stream_request(self, request):
"""Dispatch a stream request based on its type."""
request_type = request["type"]
if request_type == "propagate_in_video":
yield from self.propagate_in_video(
session_id=request["session_id"],
propagation_direction=request.get("propagation_direction", "both"),
start_frame_idx=request.get("start_frame_index", None),
max_frame_num_to_track=request.get("max_frame_num_to_track", None),
)
else:
raise RuntimeError(f"invalid request type: {request_type}")
def start_session(self, resource_path, session_id=None):
"""
Start a new inference session on an image or a video. Here `resource_path`
can be either a path to an image file (for image inference) or an MP4 file
or directory with JPEG video frames (for video inference).
If `session_id` is defined, it will be used as identifier for the
session. If it is not defined, the start_session function will create
a session id and return it.
"""
# get an initial inference_state from the model
inference_state = self.model.init_state(
resource_path=resource_path,
async_loading_frames=self.async_loading_frames,
video_loader_type=self.video_loader_type,
)
if not session_id:
session_id = str(uuid.uuid4())
self._ALL_INFERENCE_STATES[session_id] = {
"state": inference_state,
"session_id": session_id,
"start_time": time.time(),
}
logger.debug(
f"started new session {session_id}; {self._get_session_stats()}; "
f"{self._get_torch_and_gpu_properties()}"
)
return {"session_id": session_id}
def add_prompt(
self,
session_id: str,
frame_idx: int,
text: Optional[str] = None,
points: Optional[List[List[float]]] = None,
point_labels: Optional[List[int]] = None,
bounding_boxes: Optional[List[List[float]]] = None,
bounding_box_labels: Optional[List[int]] = None,
obj_id: Optional[int] = None,
):
"""Add text, box and/or point prompt on a specific video frame."""
logger.debug(
f"add prompt on frame {frame_idx} in session {session_id}: "
f"{text=}, {points=}, {point_labels=}, "
f"{bounding_boxes=}, {bounding_box_labels=}"
)
session = self._get_session(session_id)
inference_state = session["state"]
frame_idx, outputs = self.model.add_prompt(
inference_state=inference_state,
frame_idx=frame_idx,
text_str=text,
points=points,
point_labels=point_labels,
boxes_xywh=bounding_boxes,
box_labels=bounding_box_labels,
obj_id=obj_id,
)
return {"frame_index": frame_idx, "outputs": outputs}
def remove_object(
self,
session_id: str,
obj_id: int,
is_user_action: bool = True,
):
"""Remove an object from tracking."""
logger.debug(
f"remove object {obj_id} in session {session_id}: " f"{is_user_action=}"
)
session = self._get_session(session_id)
inference_state = session["state"]
self.model.remove_object(
inference_state=inference_state,
obj_id=obj_id,
is_user_action=is_user_action,
)
return {"is_success": True}
def propagate_in_video(
self,
session_id,
propagation_direction,
start_frame_idx,
max_frame_num_to_track,
):
"""Propagate the added prompts to get grounding results on all video frames."""
logger.debug(
f"propagate in video in session {session_id}: "
f"{propagation_direction=}, {start_frame_idx=}, {max_frame_num_to_track=}"
)
try:
session = self._get_session(session_id)
inference_state = session["state"]
if propagation_direction not in ["both", "forward", "backward"]:
raise ValueError(
f"invalid propagation direction: {propagation_direction}"
)
# First doing the forward propagation
if propagation_direction in ["both", "forward"]:
for frame_idx, outputs in self.model.propagate_in_video(
inference_state=inference_state,
start_frame_idx=start_frame_idx,
max_frame_num_to_track=max_frame_num_to_track,
reverse=False,
):
yield {"frame_index": frame_idx, "outputs": outputs}
# Then doing the backward propagation (reverse in time)
if propagation_direction in ["both", "backward"]:
for frame_idx, outputs in self.model.propagate_in_video(
inference_state=inference_state,
start_frame_idx=start_frame_idx,
max_frame_num_to_track=max_frame_num_to_track,
reverse=True,
):
yield {"frame_index": frame_idx, "outputs": outputs}
finally:
# Log upon completion (so that e.g. we can see if two propagations happen in parallel).
# Using `finally` here to log even when the tracking is aborted with GeneratorExit.
logger.debug(
f"propagation ended in session {session_id}; {self._get_session_stats()}"
)
def reset_session(self, session_id):
"""Reset the session to its initial state (as when it's initial opened)."""
logger.debug(f"reset session {session_id}")
session = self._get_session(session_id)
inference_state = session["state"]
self.model.reset_state(inference_state)
return {"is_success": True}
def close_session(self, session_id):
"""
Close a session. This method is idempotent and can be called multiple
times on the same "session_id".
"""
session = self._ALL_INFERENCE_STATES.pop(session_id, None)
if session is None:
logger.warning(
f"cannot close session {session_id} as it does not exist (it might have expired); "
f"{self._get_session_stats()}"
)
else:
del session
gc.collect()
logger.info(f"removed session {session_id}; {self._get_session_stats()}")
return {"is_success": True}
def _get_session(self, session_id):
session = self._ALL_INFERENCE_STATES.get(session_id, None)
if session is None:
raise RuntimeError(
f"Cannot find session {session_id}; it might have expired"
)
return session
def _get_session_stats(self):
"""Get a statistics string for live sessions and their GPU usage."""
# print both the session ids and their video frame numbers
live_session_strs = [
f"'{session_id}' ({session['state']['num_frames']} frames)"
for session_id, session in self._ALL_INFERENCE_STATES.items()
]
session_stats_str = (
f"live sessions: [{', '.join(live_session_strs)}], GPU memory: "
f"{torch.cuda.memory_allocated() // 1024**2} MiB used and "
f"{torch.cuda.memory_reserved() // 1024**2} MiB reserved"
f" (max over time: {torch.cuda.max_memory_allocated() // 1024**2} MiB used "
f"and {torch.cuda.max_memory_reserved() // 1024**2} MiB reserved)"
)
return session_stats_str
    def _get_torch_and_gpu_properties(self):
        """Get a string for PyTorch and GPU properties (for logging and debugging)."""
        # NOTE(review): assumes CUDA is available on this host; the calls
        # below raise on a CPU-only build -- confirm callers guard for that.
        torch_and_gpu_str = (
            f"torch: {torch.__version__} with CUDA arch {torch.cuda.get_arch_list()}, "
            f"GPU device: {torch.cuda.get_device_properties(torch.cuda.current_device())}"
        )
        return torch_and_gpu_str
    def shutdown(self):
        """Shutdown the predictor and clear all sessions.

        Drops every live inference state at once; individual sessions are not
        reset first (their memory is reclaimed when the dict entries are freed).
        """
        self._ALL_INFERENCE_STATES.clear()
class Sam3VideoPredictorMultiGPU(Sam3VideoPredictor):
    """Multi-GPU wrapper around `Sam3VideoPredictor`.

    Rank 0 (the main process) spawns one daemon worker process per additional
    GPU; every request is broadcast to the workers over multiprocessing
    queues, each rank handles it locally, and ranks synchronize with a
    `torch.distributed` (NCCL) barrier. The rendezvous is configured via the
    standard MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE env variables.
    """

    def __init__(self, *model_args, gpus_to_use=None, **model_kwargs):
        # `gpus_to_use`: list of CUDA device ids; defaults to the current GPU.
        if gpus_to_use is None:
            # if not specified, use only the current GPU by default
            gpus_to_use = [torch.cuda.current_device()]
        # IS_MAIN_PROCESS defaults to "1"; worker processes are spawned with
        # IS_MAIN_PROCESS=0 and RANK already set (see _start_worker_processes).
        IS_MAIN_PROCESS = os.getenv("IS_MAIN_PROCESS", "1") == "1"
        if IS_MAIN_PROCESS:
            gpus_to_use = sorted(set(gpus_to_use))
            logger.info(f"using the following GPU IDs: {gpus_to_use}")
            assert len(gpus_to_use) > 0 and all(isinstance(i, int) for i in gpus_to_use)
            assert all(0 <= i < torch.cuda.device_count() for i in gpus_to_use)
            # set up the NCCL "env://" rendezvous for all ranks
            os.environ["MASTER_ADDR"] = "localhost"
            os.environ["MASTER_PORT"] = f"{self._find_free_port()}"
            os.environ["RANK"] = "0"
            os.environ["WORLD_SIZE"] = f"{len(gpus_to_use)}"
        self.gpus_to_use = gpus_to_use
        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        self.rank_str = f"rank={self.rank} with world_size={self.world_size}"
        # each rank pins itself to one GPU from the list
        self.device = torch.device(f"cuda:{self.gpus_to_use[self.rank]}")
        torch.cuda.set_device(self.device)
        self.has_shutdown = False
        if self.rank == 0:
            logger.info("\n\n\n\t*** START loading model on all ranks ***\n\n")
        logger.info(f"loading model on {self.rank_str} -- this could take a while ...")
        super().__init__(*model_args, **model_kwargs)
        logger.info(f"loading model on {self.rank_str} -- DONE locally")
        if self.world_size > 1 and self.rank == 0:
            # start the worker processes *after* the model is loaded in the main process
            # so that the main process can run torch.compile and fill the cache first
            self._start_worker_processes(*model_args, **model_kwargs)
            for rank in range(1, self.world_size):
                self.command_queues[rank].put(("start_nccl_process_group", None))
            self._start_nccl_process_group()
        if self.rank == 0:
            logger.info("\n\n\n\t*** DONE loading model on all ranks ***\n\n")

    @torch.inference_mode()
    def handle_request(self, request):
        """Dispatch a request based on its type."""
        if self.has_shutdown:
            raise RuntimeError(
                "cannot handle request after the predictor has shutdown; please create a new predictor"
            )
        # when starting a session, we need to create a session id before dispatching
        # the request to the workers
        if request["type"] == "start_session" and request.get("session_id") is None:
            request["session_id"] = str(uuid.uuid4())
        # dispatch the request to all worker processes
        if self.world_size > 1 and self.rank == 0:
            for rank in range(1, self.world_size):
                self.command_queues[rank].put((request, False))
        response = super().handle_request(request)
        if self.world_size > 1:
            torch.distributed.barrier()  # wait for all ranks to finish
        return response

    @torch.inference_mode()
    def handle_stream_request(self, request):
        """Dispatch a stream request based on its type."""
        if self.has_shutdown:
            raise RuntimeError(
                "cannot handle request after the predictor has shutdown; please create a new predictor"
            )
        # dispatch the request to all worker processes
        if self.world_size > 1 and self.rank == 0:
            for rank in range(1, self.world_size):
                self.command_queues[rank].put((request, True))
        yield from super().handle_stream_request(request)
        if self.world_size > 1:
            torch.distributed.barrier()  # wait for all ranks to finish

    def _start_worker_processes(self, *model_args, **model_kwargs):
        """Start worker processes for handling model inference."""
        world_size = self.world_size
        logger.info(f"spawning {world_size - 1} worker processes")
        # Use "spawn" (instead of "fork") for different PyTorch or CUDA context
        mp_ctx = mp.get_context("spawn")
        self.command_queues = {rank: mp_ctx.Queue() for rank in range(1, world_size)}
        self.result_queues = {rank: mp_ctx.Queue() for rank in range(1, world_size)}
        parent_pid = os.getpid()
        for rank in range(1, world_size):
            # set the environment variables for each worker process
            # NOTE(review): mutating os.environ before each spawn relies on
            # Process.start() snapshotting the env synchronously -- confirm.
            os.environ["IS_MAIN_PROCESS"] = "0"  # mark this as a worker process
            os.environ["RANK"] = f"{rank}"
            worker_process = mp_ctx.Process(
                target=Sam3VideoPredictorMultiGPU._worker_process_command_loop,
                args=(
                    rank,
                    world_size,
                    self.command_queues[rank],
                    self.result_queues[rank],
                    model_args,
                    model_kwargs,
                    self.gpus_to_use,
                    parent_pid,
                ),
                daemon=True,
            )
            worker_process.start()
        # revert the environment variables for the main process
        os.environ["IS_MAIN_PROCESS"] = "1"
        os.environ["RANK"] = "0"
        # wait for all the worker processes to load the model and collect their PIDs
        self.worker_pids = {}
        for rank in range(1, self.world_size):
            # a large timeout to cover potentially long model loading time due to compilation
            _, worker_pid = self.result_queues[rank].get(timeout=7200)
            self.worker_pids[rank] = worker_pid
        logger.info(f"spawned {world_size - 1} worker processes")

    def _start_nccl_process_group(self):
        """Initialize the NCCL process group on this rank (no-op if world_size == 1)."""
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        if world_size == 1:
            return
        logger.debug(f"starting NCCL process group on {rank=} with {world_size=}")
        assert not torch.distributed.is_initialized()
        # use the "env://" init method with environment variables set in start_worker_processes
        # a short 3-min timeout to quickly detect any synchronization failures
        timeout_sec = int(os.getenv("SAM3_COLLECTIVE_OP_TIMEOUT_SEC", "180"))
        timeout = datetime.timedelta(seconds=timeout_sec)
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            timeout=timeout,
            device_id=self.device,
        )
        # warm-up the NCCL process group by running a dummy all-reduce
        tensor = torch.ones(1024, 1024).cuda()
        torch.distributed.all_reduce(tensor)
        logger.debug(f"started NCCL process group on {rank=} with {world_size=}")

    def _find_free_port(self) -> int:
        """
        Find a free port (a random free port from 1024 to 65535 will be selected)
        https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number)
        """
        # NOTE(review): the port is released when the socket closes, so another
        # process could grab it before init_process_group binds it (rare race).
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
            s.bind(("", 0))
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            return s.getsockname()[1]

    @staticmethod
    def _worker_process_command_loop(
        rank,
        world_size,
        command_queue,
        result_queue,
        model_args,
        model_kwargs,
        gpus_to_use,
        parent_pid,
    ):
        """
        The command loop for each worker process. It listens to commands from the main process
        and executes them using the model.
        """
        logger.info(f"starting worker process {rank=} with {world_size=}")
        # verify that the environment variables are set correctly
        assert int(os.environ["IS_MAIN_PROCESS"]) == 0
        assert int(os.environ["RANK"]) == rank
        assert int(os.environ["WORLD_SIZE"]) == world_size
        # load the model in this worker process
        predictor = Sam3VideoPredictorMultiGPU(
            *model_args, gpus_to_use=gpus_to_use, **model_kwargs
        )
        logger.info(f"started worker {rank=} with {world_size=}")
        # return the worker process id to the main process for bookkeeping
        worker_pid = os.getpid()
        result_queue.put(("load_model", worker_pid))
        # wait for the command to start the NCCL process group
        request_type, _ = command_queue.get(timeout=7200)
        assert request_type == "start_nccl_process_group"
        predictor._start_nccl_process_group()
        # keep listening to commands from the main process
        while True:
            try:
                request, is_stream_request = command_queue.get(timeout=5.0)
                if request == "shutdown":
                    logger.info(f"worker {rank=} shutting down")
                    torch.distributed.destroy_process_group()
                    result_queue.put(("shutdown", True))  # acknowledge the shutdown
                    sys.exit(0)
                logger.debug(f"worker {rank=} received request {request['type']=}")
                if is_stream_request:
                    for _ in predictor.handle_stream_request(request):
                        pass  # handle stream requests in a generator fashion
                else:
                    predictor.handle_request(request)
            except queue.Empty:
                # Usually Python's multiprocessing module will shutdown all the daemon worker
                # processes when the main process exits gracefully. However, the user may kill
                # the main process using SIGKILL and thereby leaving no chance for the main process
                # to clean up its daemon child processes. So here we manually check whether the
                # parent process still exists (every 5 sec as in `command_queue.get` timeout).
                if not psutil.pid_exists(parent_pid):
                    logger.info(
                        f"stopping worker {rank=} as its parent process has exited"
                    )
                    sys.exit(1)
            except Exception as e:
                # NOTE(review): errors are logged and the loop keeps going, so
                # a failed broadcast may leave ranks out of sync until the next
                # barrier times out -- intentional best-effort behavior?
                logger.error(f"worker {rank=} exception: {e}", exc_info=True)

    def shutdown(self):
        """Shutdown all worker processes.

        Rank 0 broadcasts a "shutdown" command, tears down its own process
        group, and waits for each worker's acknowledgement before clearing
        local sessions via the parent class shutdown.
        """
        if self.rank == 0 and self.world_size > 1:
            logger.info(f"shutting down {self.world_size - 1} worker processes")
            for rank in range(1, self.world_size):
                self.command_queues[rank].put(("shutdown", False))
            torch.distributed.destroy_process_group()
            for rank in range(1, self.world_size):
                self.result_queues[rank].get()  # wait for the worker to acknowledge
            logger.info(f"shut down {self.world_size - 1} worker processes")
        self.has_shutdown = True
        super().shutdown()

View File

@@ -0,0 +1,328 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from collections import OrderedDict
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from .model_misc import LayerScale
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: multi-head attention followed by an MLP,
    each wrapped in a residual connection with optional LayerScale."""

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
    ):
        super().__init__()
        # attention (batch_first: inputs are [batch, seq, dim])
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
        # pre-norm layers for the two residual branches
        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        # optional LayerScale; identity when no init value is given
        if ls_init_value is not None:
            self.ls_1 = LayerScale(d_model, ls_init_value)
            self.ls_2 = LayerScale(d_model, ls_init_value)
        else:
            self.ls_1 = nn.Identity()
            self.ls_2 = nn.Identity()
        # feed-forward network
        hidden = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, hidden)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(hidden, d_model)),
                ]
            )
        )

    def attention(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run multi-head attention; keys/values default to the queries."""
        if k_x is None:
            k_x = q_x
        if v_x is None:
            v_x = q_x
        # boolean masks pass through unchanged; additive masks match query dtype
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q_x.dtype)
        out, _ = self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)
        return out

    def forward(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the attention and MLP residual branches.

        Cross-attention inputs (`k_x`/`v_x`) are only honored when an
        `ln_1_kv` norm has been attached to this module; otherwise they are
        dropped and the block self-attends over `q_x` (matching the original
        behavior).
        """
        has_kv_norm = hasattr(self, "ln_1_kv")
        k_x = self.ln_1_kv(k_x) if has_kv_norm and k_x is not None else None
        v_x = self.ln_1_kv(v_x) if has_kv_norm and v_x is not None else None
        x = q_x + self.ls_1(
            self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
        )
        return x + self.ls_2(self.mlp(self.ln_2(x)))
class Transformer(nn.Module):
    """A stack of `ResidualAttentionBlock`s with optional activation
    checkpointing and an optionally `torch.compile`d forward pass."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = use_act_checkpoint
        blocks = [
            ResidualAttentionBlock(
                width,
                heads,
                mlp_ratio,
                ls_init_value=ls_init_value,
                act_layer=act_layer,
                norm_layer=norm_layer,
            )
            for _ in range(layers)
        ]
        self.resblocks = nn.ModuleList(blocks)
        if compile_mode is not None:
            # compile the bound forward in place; fullgraph forbids graph breaks
            self.forward = torch.compile(
                self.forward, mode=compile_mode, fullgraph=True
            )
            if self.grad_checkpointing:
                torch._dynamo.config.optimize_ddp = False

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run `x` through every block, checkpointing activations during
        training when enabled (saves memory, recomputes on backward)."""
        use_ckpt = (
            self.grad_checkpointing
            and not torch.jit.is_scripting()
            and self.training
        )
        for block in self.resblocks:
            if use_ckpt:
                x = checkpoint(block, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = block(x, attn_mask=attn_mask)
        return x
def text_global_pool(
    x: torch.Tensor, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pool per-token features `x` [batch, seq, dim] into one vector per sample.

    pool_type:
        "first"  -- first token is pooled, remaining tokens returned;
        "last"   -- last token is pooled, remaining tokens returned;
        "argmax" -- token with the highest id in `text` is pooled (the eot
                    token has the highest id in each sequence); all tokens
                    returned;
        anything else -- no pooling; `x` returned twice.

    Returns:
        (pooled, tokens)
    """
    if pool_type == "first":
        return x[:, 0], x[:, 1:]
    if pool_type == "last":
        return x[:, -1], x[:, :-1]
    if pool_type == "argmax":
        assert text is not None
        batch_idx = torch.arange(x.shape[0])
        return x[batch_idx, text.argmax(dim=-1)], x
    return x, x
class TextTransformer(nn.Module):
    """CLIP-style causal text transformer.

    Embeds token ids, adds learned positional embeddings, runs the
    transformer stack under an (optional) causal mask, applies a final
    LayerNorm, pools per `pool_type`, and projects to `output_dim`.
    """

    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        output_dim: int = 512,
        no_causal_mask: bool = False,
        pool_type: str = "none",  # no pooling
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        output_tokens: bool = False,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        assert pool_type in ("first", "last", "argmax", "none")
        self.output_tokens = output_tokens
        self.num_pos = self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pool_type = pool_type
        self.token_embedding = nn.Embedding(self.vocab_size, width)
        # NOTE(review): allocated with torch.empty (uninitialized memory);
        # presumably overwritten by a pretrained checkpoint -- confirm before
        # training from scratch.
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()
        if no_causal_mask:
            self.attn_mask = None
        else:
            # non-persistent buffer: rebuilt at init, excluded from state_dict
            self.register_buffer(
                "attn_mask", self.build_causal_mask(), persistent=False
            )
        if proj_bias:
            self.text_projection = nn.Linear(width, output_dim)
        else:
            # NOTE(review): also uninitialized (torch.empty); see note above.
            self.text_projection = nn.Parameter(torch.empty(width, output_dim))

    def build_causal_mask(self) -> torch.Tensor:
        """Build an additive [num_pos, num_pos] mask with -inf above the
        diagonal and 0 elsewhere (PyTorch's additive attn_mask convention)."""
        # lazily create causal attention mask, with full attention between the tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def forward(
        self, text: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Encode token ids `text` of shape [batch, seq_len].

        Returns the pooled features, or (pooled, tokens) when
        `output_tokens=True`.
        """
        seq_len = text.shape[1]
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        attn_mask = self.attn_mask
        if attn_mask is not None:
            # crop the pre-built causal mask to the actual sequence length
            attn_mask = attn_mask[:seq_len, :seq_len]
        x = x + self.positional_embedding[:seq_len]
        x = self.transformer(x, attn_mask=attn_mask)
        x = self.ln_final(x)
        pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type)
        # NOTE(review): text_projection is always created in __init__, so this
        # `is not None` check never fails; kept for parity with open_clip.
        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                pooled = self.text_projection(pooled)
            else:
                pooled = pooled @ self.text_projection
        if self.output_tokens:
            return pooled, tokens
        return pooled
class VETextEncoder(nn.Module):
    """Text encoder wrapping a CLIP-style `TextTransformer`.

    Tokenizes raw prompt strings, encodes them, and resizes the hidden
    states to the decoder width `d_model`. Outputs use PyTorch's
    sequence-first convention (see `forward`).
    """

    def __init__(
        self,
        d_model: int,
        tokenizer: Callable,
        width: int = 1024,
        heads: int = 16,
        layers: int = 24,
        context_length: int = 32,
        vocab_size: int = 49408,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = True,
    ):
        super().__init__()
        self.context_length = context_length
        self.use_ln_post = use_ln_post
        self.tokenizer = tokenizer
        self.encoder = TextTransformer(
            context_length=self.context_length,
            vocab_size=vocab_size,
            width=width,
            heads=heads,
            layers=layers,
            # we want the tokens, not just the pooled output
            output_tokens=True,
            use_ln_post=use_ln_post,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        # project encoder width -> decoder d_model
        self.resizer = nn.Linear(self.encoder.width, d_model)

    def forward(
        self,
        text: Union[List[str], Tuple[torch.Tensor, torch.Tensor, dict]],
        input_boxes: Optional[List] = None,
        device: torch.device = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode text prompts, or pass through pre-encoded text unchanged.

        Args:
            text: either a list of raw strings, or a pre-encoded triple
                (attention_mask, memory_resized, tokenized-dict).
            input_boxes: unsupported; must be None or empty.
            device: target device for the token ids (raw-string path only).

        Returns:
            (text_attention_mask, text_memory_resized, inputs_embeds);
            the mask is True at padding positions (inverted for PyTorch
            attention), and both tensors are sequence-first.
        """
        if isinstance(text[0], str):
            # no use case for this
            assert input_boxes is None or len(input_boxes) == 0, "not supported"
            # Encode the text
            tokenized = self.tokenizer(text, context_length=self.context_length).to(
                device
            )  # [b, seq_len]
            # nonzero ids are real tokens (tokenizer zero-pads; assumes pad id 0)
            text_attention_mask = (tokenized != 0).bool()
            # manually embed the tokens
            inputs_embeds = self.encoder.token_embedding(
                tokenized
            )  # [b, seq_len, d=1024]
            _, text_memory = self.encoder(tokenized)  # [b, seq_len, d=1024]
            assert text_memory.shape[1] == inputs_embeds.shape[1]
            # Invert attention mask because its the opposite in pytorch transformer
            text_attention_mask = text_attention_mask.ne(1)
            # Transpose memory because pytorch's attention expects sequence first
            text_memory = text_memory.transpose(0, 1)
            # Resize the encoder hidden states to be of the same d_model as the decoder
            text_memory_resized = self.resizer(text_memory)
        else:
            # The text is already encoded, use as is.
            text_attention_mask, text_memory_resized, tokenized = text
            inputs_embeds = tokenized["inputs_embeds"]
            assert (
                input_boxes is None or len(input_boxes) == 0
            ), "Can't replace boxes in text if it's already encoded"
        # Note that the input_embeds are returned in pytorch's convention (sequence first)
        return (
            text_attention_mask,
            text_memory_resized,
            inputs_embeds.transpose(0, 1),
        )

253
sam3/model/tokenizer_ve.py Normal file
View File

@@ -0,0 +1,253 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
Text Tokenizer.
Copied and lightly adapted from VE repo, which in turn copied
from open_clip and openAI CLIP.
"""
import gzip
import html
import io
import os
import string
from functools import lru_cache
from typing import List, Optional, Union
import ftfy
import regex as re
import torch
from iopath.common.file_io import g_pathmgr
# https://stackoverflow.com/q/62691279
os.environ["TOKENIZERS_PARALLELISM"] = "false"
DEFAULT_CONTEXT_LENGTH = 77
@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    # printable byte values keep their own codepoint ...
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    codepoints = printable.copy()
    # ... all remaining bytes are remapped above 255 to stay printable
    offset = 0
    for byte in range(256):
        if byte not in printable:
            printable.append(byte)
            codepoints.append(256 + offset)
            offset += 1
    return dict(zip(printable, (chr(c) for c in codepoints)))
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Args:
        word: tuple of symbols (symbols being variable-length strings).

    Returns:
        Set of (prev_symbol, next_symbol) tuples. Empty for words with fewer
        than two symbols (the previous implementation raised IndexError on an
        empty word because it unconditionally read ``word[0]``).
    """
    # zip of the word with itself shifted by one yields all adjacent pairs
    return set(zip(word, word[1:]))
def basic_clean(text):
    """Fix mojibake via ftfy, unescape HTML entities (twice, to handle
    doubly-escaped input), and strip surrounding whitespace."""
    fixed = ftfy.fix_text(text)
    return html.unescape(html.unescape(fixed)).strip()
def whitespace_clean(text):
    """Collapse every whitespace run to a single space and trim the ends."""
    return re.sub(r"\s+", " ", text).strip()
def _clean_canonicalize(x):
    """Clean `x` (ftfy + HTML-unescape), then canonicalize it: lowercase,
    punctuation removed, whitespace collapsed."""
    # basic, remove whitespace, remove punctuation, lower case
    return canonicalize_text(basic_clean(x))
def _clean_lower(x):
    """Clean `x` (ftfy + HTML-unescape), collapse whitespace, and lowercase."""
    # basic, remove whitespace, lower case
    return whitespace_clean(basic_clean(x)).lower()
def _clean_whitespace(x):
    """Clean `x` (ftfy + HTML-unescape) and collapse whitespace; case kept."""
    # basic, remove whitespace
    return whitespace_clean(basic_clean(x))
def get_clean_fn(type: str):
    """Return the text-cleaning function for the given mode.

    Args:
        type: one of "canonicalize", "lower", or "whitespace".

    Returns:
        The corresponding cleaning callable (str -> str).

    Raises:
        ValueError: if `type` is not a recognized cleaning mode. (Previously
        this path was `assert False`, which is silently skipped under
        `python -O` and would have returned None.)
    """
    if type == "canonicalize":
        return _clean_canonicalize
    elif type == "lower":
        return _clean_lower
    elif type == "whitespace":
        return _clean_whitespace
    raise ValueError(f"Invalid clean function ({type}).")
def canonicalize_text(text, *, keep_punctuation_exact_string=None):
    """Returns canonicalized `text` (lowercase and punctuation removed).
    From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
    Args:
        text: string to be canonicalized.
        keep_punctuation_exact_string: If provided, then this exact string kept.
            For example providing '{}' will keep any occurrences of '{}' (but will
            still remove '{' and '}' that appear separately).
    """
    no_punct = str.maketrans("", "", string.punctuation)
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        # strip punctuation between occurrences of the kept string, then rejoin
        pieces = text.split(keep_punctuation_exact_string)
        text = keep_punctuation_exact_string.join(
            piece.translate(no_punct) for piece in pieces
        )
    else:
        text = text.translate(no_punct)
    # lowercase, collapse whitespace runs, and trim
    return re.sub(r"\s+", " ", text.lower()).strip()
class SimpleTokenizer(object):
def __init__(
self,
bpe_path: Union[str, os.PathLike],
additional_special_tokens: Optional[List[str]] = None,
context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
clean: str = "lower",
):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
with g_pathmgr.open(bpe_path, "rb") as fh:
bpe_bytes = io.BytesIO(fh.read())
merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
# merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
merges = merges[1 : 49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + "</w>" for v in vocab]
for merge in merges:
vocab.append("".join(merge))
special_tokens = ["<start_of_text>", "<end_of_text>"]
if additional_special_tokens:
special_tokens += additional_special_tokens
vocab.extend(special_tokens)
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {t: t for t in special_tokens}
special = "|".join(special_tokens)
self.pat = re.compile(
special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE,
)
self.vocab_size = len(self.encoder)
self.all_special_ids = [self.encoder[t] for t in special_tokens]
self.sot_token_id = self.all_special_ids[0]
self.eot_token_id = self.all_special_ids[1]
self.context_length = context_length
self.clean_fn = get_clean_fn(clean)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + "</w>",)
pairs = get_pairs(word)
if not pairs:
return token + "</w>"
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = self.clean_fn(text)
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
)
return bpe_tokens
def decode(self, tokens):
text = "".join([self.decoder[token] for token in tokens])
text = (
bytearray([self.byte_decoder[c] for c in text])
.decode("utf-8", errors="replace")
.replace("</w>", " ")
)
return text
def __call__(
self, texts: Union[str, List[str]], context_length: Optional[int] = None
) -> torch.LongTensor:
"""Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]
context_length = context_length or self.context_length
assert context_length, "Please set a valid context length"
all_tokens = [
[self.sot_token_id] + self.encode(text) + [self.eot_token_id]
for text in texts
]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
tokens = tokens[:context_length] # Truncate
tokens[-1] = self.eot_token_id
result[i, : len(tokens)] = torch.tensor(tokens)
return result

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

77
sam3/model/utils/misc.py Normal file
View File

@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from collections import defaultdict
from dataclasses import fields, is_dataclass
from typing import Any, Mapping, Protocol, runtime_checkable
import torch
def _is_named_tuple(x) -> bool:
return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields")
@runtime_checkable
class _CopyableData(Protocol):
    # Structural type: anything exposing a `.to(device, ...)` method (e.g.
    # torch.Tensor or nn.Module) counts as device-copyable data.
    def to(self, device: torch.device, *args: Any, **kwargs: Any):
        """Copy data to the specified device"""
        ...
def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs: Any):
    """Function that recursively copies data to a torch.device.

    Args:
        data: The data to copy to device
        device: The device to which the data should be copied
        args: positional arguments that will be passed to the `to` call
        kwargs: keyword arguments that will be passed to the `to` call

    Returns:
        The data on the correct device
    """
    # NOTE: the branch order below is load-bearing -- namedtuples are also
    # tuples, and defaultdicts are also Mappings, so the more specific cases
    # must be checked first.
    if _is_named_tuple(data):
        # rebuild the namedtuple via keyword args (a plain iterable
        # constructor call would not match its generated __init__)
        return type(data)(
            **copy_data_to_device(data._asdict(), device, *args, **kwargs)
        )
    elif isinstance(data, (list, tuple)):
        return type(data)(copy_data_to_device(e, device, *args, **kwargs) for e in data)
    elif isinstance(data, defaultdict):
        # preserve the default_factory, which a plain dict copy would drop
        return type(data)(
            data.default_factory,
            {
                k: copy_data_to_device(v, device, *args, **kwargs)
                for k, v in data.items()
            },
        )
    elif isinstance(data, Mapping):
        return type(data)(
            {
                k: copy_data_to_device(v, device, *args, **kwargs)
                for k, v in data.items()
            }
        )
    elif is_dataclass(data) and not isinstance(data, type):
        # rebuild the dataclass from its init fields first ...
        new_data_class = type(data)(
            **{
                field.name: copy_data_to_device(
                    getattr(data, field.name), device, *args, **kwargs
                )
                for field in fields(data)
                if field.init
            }
        )
        # ... then copy over non-init fields, which the constructor can't take
        for field in fields(data):
            if not field.init:
                setattr(
                    new_data_class,
                    field.name,
                    copy_data_to_device(
                        getattr(data, field.name), device, *args, **kwargs
                    ),
                )
        return new_data_class
    elif isinstance(data, _CopyableData):
        return data.to(device, *args, **kwargs)
    # anything without a `.to` (ints, strings, None, ...) is returned as-is
    return data

View File

@@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Normalize, Resize, ToTensor
# Adapted from https://github.com/facebookresearch/sam2/blob/main/sam2/utils/transforms.py
class SAM2Transforms(nn.Module):
    def __init__(
        self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
    ):
        """
        Transforms for SAM2.

        Args:
            resolution: model input side length; images are resized to
                (resolution, resolution).
            mask_threshold: logit threshold separating foreground from
                background in output masks.
            max_hole_area: during postprocessing, background connected
                components (holes) up to this area are filled as foreground.
            max_sprinkle_area: foreground connected components (sprinkles)
                up to this area are removed as background.
        """
        super().__init__()
        self.resolution = resolution
        self.mask_threshold = mask_threshold
        self.max_hole_area = max_hole_area
        self.max_sprinkle_area = max_sprinkle_area
        # mean/std of 0.5 map [0, 1] inputs to [-1, 1]
        self.mean = [0.5, 0.5, 0.5]
        self.std = [0.5, 0.5, 0.5]
        self.to_tensor = ToTensor()
        # script the resize+normalize pipeline once up front
        self.transforms = torch.jit.script(
            nn.Sequential(
                Resize((self.resolution, self.resolution)),
                Normalize(self.mean, self.std),
            )
        )

    def __call__(self, x):
        # x: a single image (PIL or HWC ndarray) -> normalized CHW tensor
        x = self.to_tensor(x)
        return self.transforms(x)

    def forward_batch(self, img_list):
        # transform each image, then stack into a [N, C, H, W] batch
        img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
        img_batch = torch.stack(img_batch, dim=0)
        return img_batch

    def transform_coords(
        self, coords: torch.Tensor, normalize=False, orig_hw=None
    ) -> torch.Tensor:
        """
        Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates,
        If the coords are in absolute image coordinates, normalize should be set to True and original image size is required.

        Returns
            Coordinates scaled to [0, self.resolution] (model input space),
            which is what the SAM model expects.
        """
        if normalize:
            assert orig_hw is not None
            h, w = orig_hw
            coords = coords.clone()  # avoid mutating the caller's tensor
            coords[..., 0] = coords[..., 0] / w
            coords[..., 1] = coords[..., 1] / h

        coords = coords * self.resolution  # unnormalize coords
        return coords

    def transform_boxes(
        self, boxes: torch.Tensor, normalize=False, orig_hw=None
    ) -> torch.Tensor:
        """
        Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates,
        if the coords are in absolute image coordinates, normalize should be set to True and original image size is required.
        """
        # treat each box as two (x, y) corner points
        boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
        return boxes

    def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
        """
        Perform PostProcessing on output masks.
        """
        masks = masks.float()
        input_masks = masks  # kept so we can fall back if postprocessing fails
        mask_flat = masks.flatten(0, 1).unsqueeze(1)  # flatten as 1-channel image
        try:
            from sam3.perflib.connected_components import connected_components

            if self.max_hole_area > 0:
                # Holes are those connected components in background with area <= self.fill_hole_area
                # (background regions are those with mask scores <= self.mask_threshold)
                labels, areas = connected_components(
                    (mask_flat <= self.mask_threshold).to(torch.uint8)
                )
                is_hole = (labels > 0) & (areas <= self.max_hole_area)
                is_hole = is_hole.reshape_as(masks)
                # We fill holes with a small positive mask score (10.0) to change them to foreground.
                masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)

            if self.max_sprinkle_area > 0:
                labels, areas = connected_components(
                    (mask_flat > self.mask_threshold).to(torch.uint8)
                )
                is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
                is_hole = is_hole.reshape_as(masks)
                # We fill holes with negative mask score (-10.0) to change them to background.
                masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
        except Exception as e:
            # Skip the post-processing step if the CUDA kernel fails
            warnings.warn(
                f"{e}\n\nSkipping the post-processing step due to the error above. You can "
                "still use SAM 3 and it's OK to ignore the error above, although some post-processing "
                "functionality may be limited (which doesn't affect the results in most cases; see "
                "https://github.com/facebookresearch/sam3/blob/main/INSTALL.md).",
                category=UserWarning,
                stacklevel=2,
            )
            masks = input_masks

        # upsample back to the original image resolution
        masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
        return masks

View File

@@ -0,0 +1,233 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from threading import Thread
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
def _load_img_as_tensor(img_path, image_size):
    """Load a single image, resize it to (image_size, image_size), and return
    it as a float CHW tensor in [0, 1] along with the original height/width."""
    pil_img = Image.open(img_path)
    orig_w, orig_h = pil_img.size  # PIL reports (width, height)
    img_np = np.array(pil_img.convert("RGB").resize((image_size, image_size)))
    if img_np.dtype != np.uint8:  # np.uint8 is expected for JPEG images
        raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
    tensor = torch.from_numpy(img_np / 255.0).permute(2, 0, 1)
    return tensor, orig_h, orig_w
class AsyncVideoFrameLoader:
    """
    A list of video frames to be load asynchronously without blocking session start.
    """

    def __init__(
        self,
        img_paths,
        image_size,
        offload_video_to_cpu,
        img_mean,
        img_std,
        compute_device,
    ):
        # ordered list of JPEG paths; list index == frame index
        self.img_paths = img_paths
        # every frame is resized to (image_size, image_size)
        self.image_size = image_size
        # if True, keep frames on CPU instead of moving them to compute_device
        self.offload_video_to_cpu = offload_video_to_cpu
        # per-channel normalization stats (broadcastable against a CHW tensor)
        self.img_mean = img_mean
        self.img_std = img_std
        # items in `self.images` will be loaded asynchronously
        self.images = [None] * len(img_paths)
        # catch and raise any exceptions in the async loading thread
        self.exception = None
        # video_height and video_width will be filled when loading the first image
        self.video_height = None
        self.video_width = None
        self.compute_device = compute_device

        # load the first frame to fill video_height and video_width and also
        # to cache it (since it's most likely where the user will click)
        self.__getitem__(0)

        # load the rest of frames asynchronously without blocking the session start
        def _load_frames():
            try:
                for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
                    self.__getitem__(n)
            except Exception as e:
                # stash the exception; it is re-raised on the next __getitem__ call
                self.exception = e

        # daemon thread so a half-finished load never blocks interpreter exit
        self.thread = Thread(target=_load_frames, daemon=True)
        self.thread.start()

    def __getitem__(self, index):
        # Return frame `index`, loading, normalizing, and caching it on first access.
        if self.exception is not None:
            raise RuntimeError("Failure in frame loading thread") from self.exception

        img = self.images[index]
        if img is not None:
            # cache hit: already loaded and normalized by an earlier access
            return img

        img, video_height, video_width = _load_img_as_tensor(
            self.img_paths[index], self.image_size
        )
        self.video_height = video_height
        self.video_width = video_width
        # normalize by mean and std
        img -= self.img_mean
        img /= self.img_std
        if not self.offload_video_to_cpu:
            img = img.to(self.compute_device, non_blocking=True)
        self.images[index] = img
        return img

    def __len__(self):
        return len(self.images)
def load_video_frames(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    async_loading_frames=False,
    compute_device=torch.device("cuda"),
):
    """
    Load video frames from `video_path`, which may be raw MP4 bytes, a path to
    an MP4 file, or a directory of JPEG frames. Frames are resized to
    image_size and moved to GPU unless offload_video_to_cpu is set.
    This is used by the demo.
    """
    path_is_bytes = isinstance(video_path, bytes)
    path_is_str = isinstance(video_path, str)
    # Raw bytes or an .mp4 path are decoded as a video file.
    if path_is_bytes or (
        path_is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"]
    ):
        return load_video_frames_from_video_file(
            video_path=video_path,
            image_size=image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            img_mean=img_mean,
            img_std=img_std,
            compute_device=compute_device,
        )
    # A directory is treated as a folder of "<frame_index>.jpg" files.
    if path_is_str and os.path.isdir(video_path):
        return load_video_frames_from_jpg_images(
            video_path=video_path,
            image_size=image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            img_mean=img_mean,
            img_std=img_std,
            async_loading_frames=async_loading_frames,
            compute_device=compute_device,
        )
    raise NotImplementedError(
        "Only MP4 video and JPEG folder are supported at this moment"
    )
def load_video_frames_from_jpg_images(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    async_loading_frames=False,
    compute_device=torch.device("cuda"),
):
    """
    Load video frames from a directory of JPEG files ("<frame_index>.jpg" format).
    Frames are resized to image_size x image_size and placed on `compute_device`
    unless `offload_video_to_cpu` is True (then they stay on CPU). Setting
    `async_loading_frames` loads frames lazily in a background thread.
    """
    if not (isinstance(video_path, str) and os.path.isdir(video_path)):
        raise NotImplementedError(
            "Only JPEG frames are supported at this moment. For video files, you may use "
            "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n"
            "```\n"
            "ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'\n"
            "```\n"
            "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks "
            "ffmpeg to start the JPEG file from 00000.jpg."
        )
    jpg_folder = video_path

    # Collect frame files and order them by their integer frame index.
    valid_exts = [".jpg", ".jpeg", ".JPG", ".JPEG"]
    frame_names = sorted(
        (p for p in os.listdir(jpg_folder) if os.path.splitext(p)[-1] in valid_exts),
        key=lambda p: int(os.path.splitext(p)[0]),
    )
    if not frame_names:
        raise RuntimeError(f"no images found in {jpg_folder}")
    num_frames = len(frame_names)
    img_paths = [os.path.join(jpg_folder, name) for name in frame_names]

    # Normalization stats shaped (3, 1, 1) to broadcast over CHW frames.
    mean_t = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
    std_t = torch.tensor(img_std, dtype=torch.float32)[:, None, None]

    if async_loading_frames:
        lazy_images = AsyncVideoFrameLoader(
            img_paths,
            image_size,
            offload_video_to_cpu,
            mean_t,
            std_t,
            compute_device,
        )
        return lazy_images, lazy_images.video_height, lazy_images.video_width

    images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
    for idx, path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
        images[idx], video_height, video_width = _load_img_as_tensor(path, image_size)
    if not offload_video_to_cpu:
        images = images.to(compute_device)
        mean_t = mean_t.to(compute_device)
        std_t = std_t.to(compute_device)
    # normalize by mean and std
    images -= mean_t
    images /= std_t
    return images, video_height, video_width
def load_video_frames_from_video_file(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    compute_device=torch.device("cuda"),
):
    """Load the video frames from a video file."""
    import decord

    # Normalization stats shaped (3, 1, 1) to broadcast over CHW frames.
    mean_t = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
    std_t = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
    # Get the original video height and width
    decord.bridge.set_bridge("torch")
    video_height, video_width, _ = decord.VideoReader(video_path).next().shape
    # Decode every frame, already resized to (image_size, image_size) by decord.
    frames = [
        frame.permute(2, 0, 1)
        for frame in decord.VideoReader(video_path, width=image_size, height=image_size)
    ]
    images = torch.stack(frames, dim=0).float() / 255.0
    if not offload_video_to_cpu:
        images = images.to(compute_device)
        mean_t = mean_t.to(compute_device)
        std_t = std_t.to(compute_device)
    # normalize by mean and std
    images -= mean_t
    images /= std_t
    return images, video_height, video_width

879
sam3/model/vitdet.py Normal file
View File

@@ -0,0 +1,879 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
ViTDet backbone adapted from Detectron2.
This module implements Vision Transformer (ViT) backbone for object detection.
Rope embedding code adopted from:
1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
2. https://github.com/naver-ai/rope-vit
3. https://github.com/lucidrains/rotary-embedding-torch
"""
import math
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
try:
from timm.layers import DropPath, Mlp, trunc_normal_
except ModuleNotFoundError:
# compatibility for older timm versions
from timm.models.layers import DropPath, Mlp, trunc_normal_
from torch import Tensor
from .model_misc import LayerScale
def init_t_xy(
    end_x: int, end_y: int, scale: float = 1.0, offset: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return flattened (x, y) grid coordinates for an end_x-by-end_y grid,
    in row-major order, optionally scaled and shifted by `offset`."""
    flat = torch.arange(end_x * end_y, dtype=torch.float32)
    xs = torch.remainder(flat, end_x).float()
    ys = torch.div(flat, end_x, rounding_mode="floor").float()
    return xs * scale + offset, ys * scale + offset
def compute_axial_cis(
    dim: int,
    end_x: int,
    end_y: int,
    theta: float = 10000.0,
    scale_pos: float = 1.0,
    offset: int = 0,
) -> torch.Tensor:
    """Precompute axial (2D) rotary-embedding rotations for an end_x-by-end_y
    grid: half the channels rotate with the x coordinate, half with y.
    Returns a complex tensor of shape (end_x * end_y, dim // 2)."""
    # One set of inverse frequencies shared by both axes.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    # Flattened grid coordinates (row-major), scaled/offset as requested.
    flat = torch.arange(end_x * end_y, dtype=torch.float32)
    t_x = (flat % end_x).float() * scale_pos + offset
    t_y = torch.div(flat, end_x, rounding_mode="floor").float() * scale_pos + offset
    angles_x = torch.outer(t_x, inv_freq)
    angles_y = torch.outer(t_y, inv_freq)
    cis_x = torch.polar(torch.ones_like(angles_x), angles_x)
    cis_y = torch.polar(torch.ones_like(angles_y), angles_y)
    return torch.cat([cis_x, cis_y], dim=-1)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Reshape `freqs_cis` of shape (L, D) to (1, ..., 1, L, D) so that it
    broadcasts against `x`, whose last two dims must be exactly (L, D)."""
    assert x.ndim > 1
    assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
    target_shape = [1] * (x.ndim - 2) + list(x.shape[-2:])
    return freqs_cis.view(*target_shape)
def apply_rotary_enc(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    repeat_freqs_k: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position encoding to query/key tensors.

    Args:
        xq: query tensor; last dim is split into (dim//2, 2) real pairs and
            viewed as complex numbers for the rotation.
        xk: key tensor; may have a different (longer) sequence length than xq,
            and may be empty along the sequence dim.
        freqs_cis: complex rotation factors with shape matching xq's last two
            (complex-viewed) dims; see `reshape_for_broadcast`.
        repeat_freqs_k: if True, tile freqs_cis along the sequence dim so it
            covers xk's longer sequence (assumes len(k) is a multiple of len(q)
            -- TODO confirm).

    Returns:
        (rotated xq, rotated xk); xk is returned unrotated if it has no keys.
    """
    # View the last dim as complex pairs: (..., d) -> (..., d//2) complex.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = (
        torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        if xk.shape[-2] != 0
        else None
    )
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Complex multiplication performs the 2D rotation; flatten back to real pairs.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    if xk_ is None:
        # no keys to rotate, due to dropout
        return xq_out.type_as(xq).to(xq.device), xk
    # repeat freqs along seq_len dim to match k seq_len
    if repeat_freqs_k:
        r = xk_.shape[-2] // xq_.shape[-2]
        freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
def window_partition(x: Tensor, window_size: int) -> Tuple[Tensor, Tuple[int, int]]:
    """
    Split a (B, H, W, C) tensor into non-overlapping square windows, padding
    the bottom/right edges with zeros if H or W is not a multiple of
    window_size.

    Returns:
        windows: (B * num_windows, window_size, window_size, C) tensor.
        (Hp, Wp): the padded height and width before partitioning.
    """
    batch, height, width, channels = x.shape
    pad_bottom = (window_size - height % window_size) % window_size
    pad_right = (window_size - width % window_size) % window_size
    if pad_bottom or pad_right:
        x = F.pad(x, (0, 0, 0, pad_right, 0, pad_bottom))
    padded_h = height + pad_bottom
    padded_w = width + pad_right
    grid = x.view(
        batch,
        padded_h // window_size,
        window_size,
        padded_w // window_size,
        window_size,
        channels,
    )
    windows = grid.permute(0, 1, 3, 2, 4, 5).reshape(
        -1, window_size, window_size, channels
    )
    return windows, (padded_h, padded_w)
def window_unpartition(
    windows: Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> Tensor:
    """
    Inverse of `window_partition`: reassemble windows into a (B, H, W, C)
    tensor and strip any bottom/right padding that was added.

    Args:
        windows: (B * num_windows, window_size, window_size, C) tensor.
        window_size: side length of each square window.
        pad_hw: padded (Hp, Wp) returned by `window_partition`.
        hw: original (H, W) before padding.
    """
    padded_h, padded_w = pad_hw
    height, width = hw
    windows_per_image = padded_h * padded_w // window_size // window_size
    batch = windows.shape[0] // windows_per_image
    grid = windows.reshape(
        batch,
        padded_h // window_size,
        padded_w // window_size,
        window_size,
        window_size,
        -1,
    )
    x = grid.permute(0, 1, 3, 2, 4, 5).reshape(batch, padded_h, padded_w, -1)
    if padded_h > height or padded_w > width:
        x = x[:, :height, :width, :]
    return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: Tensor) -> Tensor:
    """
    Select relative positional embeddings for every (query, key) index pair.

    Args:
        q_size: spatial size of the query axis.
        k_size: spatial size of the key axis.
        rel_pos: (L, C) table of relative position embeddings; linearly
            resampled if L does not match 2 * max(q_size, k_size) - 1.

    Returns:
        (q_size, k_size, C) tensor of embeddings indexed by relative offset.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    if rel_pos.shape[0] == max_rel_dist:
        table = rel_pos
    else:
        # Resample the table to the required number of relative offsets.
        resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
            align_corners=False,
        )
        table = resized.reshape(-1, max_rel_dist).permute(1, 0)
    # When q and k sizes differ, scale coordinates by the shorter length.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_scale
    k_coords = torch.arange(k_size)[None, :] * k_scale
    relative_coords = (q_coords - k_coords) + (k_size - 1) * k_scale
    return table[relative_coords.long()]
def get_abs_pos(
    abs_pos: Tensor,
    has_cls_token: bool,
    hw: Tuple[int, int],
    retain_cls_token: bool = False,
    tiling: bool = False,
) -> Tensor:
    """
    Adapt absolute positional embeddings to the target token grid.

    The stored embeddings cover a square grid (optionally preceded by one cls
    embedding). If the stored grid does not match `hw`, the spatial part is
    either bicubically interpolated or tiled (`tiling=True`) to fit.

    Args:
        abs_pos: (1, num_positions, C) embedding table.
        has_cls_token: whether abs_pos[ :, 0] is a cls-token embedding.
        hw: target (h, w) grid of image tokens.
        retain_cls_token: keep the cls embedding and return flattened tokens.
        tiling: tile instead of interpolating when resizing.

    Returns:
        (1, h, w, C) if retain_cls_token is False, else (1, 1 + h*w, C).
    """
    if retain_cls_token:
        assert has_cls_token
    h, w = hw
    cls_pos = None
    if has_cls_token:
        cls_pos, abs_pos = abs_pos[:, :1], abs_pos[:, 1:]
    num_tokens = abs_pos.shape[1]
    side = int(math.sqrt(num_tokens))
    assert side * side == num_tokens

    if side == h and side == w:
        # Stored grid already matches the target: no resizing needed.
        if retain_cls_token:
            assert has_cls_token
            return torch.cat([cls_pos, abs_pos], dim=1)
        return abs_pos.reshape(1, h, w, -1)

    grid = abs_pos.reshape(1, side, side, -1).permute(0, 3, 1, 2)
    if tiling:
        # Repeat the grid enough times to cover (h, w), then crop.
        reps = [x // y + 1 for x, y in zip((h, w), grid.shape[2:])]
        grid = grid.tile([1, 1] + reps)[:, :, :h, :w]
    else:
        grid = F.interpolate(grid, size=(h, w), mode="bicubic", align_corners=False)
    if retain_cls_token:
        # add cls_token back, flatten spatial dims
        assert has_cls_token
        return torch.cat(
            [cls_pos, grid.permute(0, 2, 3, 1).reshape(1, h * w, -1)], dim=1
        )
    return grid.permute(0, 2, 3, 1)
def concat_rel_pos(
    q: Tensor,
    k: Tensor,
    q_hw: Tuple[int, int],
    k_hw: Tuple[int, int],
    rel_pos_h: Tensor,
    rel_pos_w: Tensor,
    rescale: bool = False,
    relative_coords: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Concatenate rel pos coeffs to the q & k tensors, so that qk^T is now
    effectively including rel pos biases.

    The trick: extra channels appended to q hold the precomputed rel-pos dot
    products, and matching one-hot (identity) channels appended to k select
    them, so a plain q @ k^T reproduces attn + rel_h + rel_w.

    Args:
        q (Tensor): q tensor with shape (B, L_q, C).
        k (Tensor): k tensor with shape (B, L_k, C).
        q_hw, k_hw: These are spatial size of q & k tensors.
        rel_pos_h, rel_pos_w: These are relative pos embeddings/params of height, width.
        rescale (bool): whether to rescale. e.g. for use when using sdpa, pytorch will
            scale by the wrong factor due to the concat.
        relative_coords: optional precomputed (q_size, k_size) offset index
            table; skips calling `get_rel_pos` (assumes same table applies to
            both axes -- requires square inputs, asserted below).

    Returns:
        q, k: But, padded so that qk^T accounts for rel pos biases
    """
    q_h, q_w = q_hw
    k_h, k_w = k_hw
    assert (q_h == q_w) and (k_h == k_w), "only square inputs supported"
    if relative_coords is not None:
        Rh = rel_pos_h[relative_coords]
        Rw = rel_pos_w[relative_coords]
    else:
        Rh = get_rel_pos(q_h, k_h, rel_pos_h)
        Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    # sdpa divides logits by sqrt(d) of the *padded* dim; pre-multiply so the
    # effective scaling matches the unpadded head dim.
    old_scale = dim**0.5
    new_scale = (dim + k_h + k_w) ** 0.5 if rescale else old_scale  # for sdpa
    # attn will be divided by new_scale, but we want to divide q by old_scale
    scale_ratio = new_scale / old_scale
    # Rel-pos bias terms, one per key row / key column.
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) * new_scale  # (B, q_h, q_w, k_h)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) * new_scale  # (B, q_h, q_w, k_w)
    # One-hot selectors appended to k pick out the matching bias channel in q.
    eye_h = torch.eye(k_h, dtype=q.dtype, device=q.device)
    eye_w = torch.eye(k_w, dtype=q.dtype, device=q.device)
    eye_h = eye_h.view(1, k_h, 1, k_h).expand([B, k_h, k_w, k_h])
    eye_w = eye_w.view(1, 1, k_w, k_w).expand([B, k_h, k_w, k_w])
    q = torch.cat([r_q * scale_ratio, rel_h, rel_w], dim=-1).view(B, q_h * q_w, -1)
    k = torch.cat([k.view(B, k_h, k_w, -1), eye_h, eye_w], dim=-1).view(
        B, k_h * k_w, -1
    )
    return q, k
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.

    A single strided convolution projects the image into a grid of patch
    tokens, returned channels-last.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
        bias: bool = True,
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            bias (bool): whether the projection conv has a bias term.
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Project (B, C, H, W) pixels to (B, H', W', embed_dim) patch tokens."""
        return self.proj(x).permute(0, 2, 3, 1)
class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings and 2d-rope."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
        cls_token: bool = False,
        use_rope: bool = False,
        rope_theta: float = 10000.0,
        rope_pt_size: Optional[Tuple[int, int]] = None,
        rope_interp: bool = False,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (int or None): Input resolution for calculating the relative positional
                parameter size or rope size.
            cls_token: whether a cls_token is present.
            use_rope: whether to use rope 2d (indep of use_rel_pos, as it can be used together)
            rope_theta: control frequencies of rope
            rope_pt_size: size of rope in previous stage of training, needed for interpolation or tiling
            rope_interp: whether to interpolate (or extrapolate) rope to match input size
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.cls_token = cls_token
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        # rel_pos embeddings and rope
        self.use_rel_pos = use_rel_pos
        self.input_size = input_size
        self.use_rope = use_rope
        self.rope_theta = rope_theta
        self.rope_pt_size = rope_pt_size
        self.rope_interp = rope_interp
        # init rel_pos embeddings and rope
        self._setup_rel_pos(rel_pos_zero_init)
        self._setup_rope_freqs()

    def _setup_rel_pos(self, rel_pos_zero_init: bool = True) -> None:
        # Create the decomposed (per-axis) rel-pos parameter tables and a
        # precomputed offset-index buffer; no-op when rel pos is disabled.
        if not self.use_rel_pos:
            self.rel_pos_h = None
            self.rel_pos_w = None
            return
        assert self.input_size is not None
        assert self.cls_token is False, "not supported"
        # initialize relative positional embeddings
        self.rel_pos_h = nn.Parameter(
            torch.zeros(2 * self.input_size[0] - 1, self.head_dim)
        )
        self.rel_pos_w = nn.Parameter(
            torch.zeros(2 * self.input_size[1] - 1, self.head_dim)
        )
        if not rel_pos_zero_init:
            trunc_normal_(self.rel_pos_h, std=0.02)
            trunc_normal_(self.rel_pos_w, std=0.02)
        # Precompute the relative coords
        H, W = self.input_size
        q_coords = torch.arange(H)[:, None]
        k_coords = torch.arange(W)[None, :]
        relative_coords = (q_coords - k_coords) + (H - 1)
        self.register_buffer("relative_coords", relative_coords.long())

    def _setup_rope_freqs(self) -> None:
        # Precompute complex rope rotation factors for the configured input
        # size (optionally rescaled from a pretraining size); no-op otherwise.
        if not self.use_rope:
            self.freqs_cis = None
            return
        assert self.input_size is not None
        # determine rope input size
        if self.rope_pt_size is None:
            self.rope_pt_size = self.input_size
        # initialize 2d rope freqs
        self.compute_cis = partial(
            compute_axial_cis,
            dim=self.head_dim,
            theta=self.rope_theta,
        )
        # interpolate rope
        scale_pos = 1.0
        if self.rope_interp:
            scale_pos = self.rope_pt_size[0] / self.input_size[0]
        # get scaled freqs_cis
        freqs_cis = self.compute_cis(
            end_x=self.input_size[0],
            end_y=self.input_size[1],
            scale_pos=scale_pos,
        )
        if self.cls_token:
            # zero-angle rotation for the cls token (leaves it unrotated)
            t = torch.zeros(
                self.head_dim // 2,
                dtype=torch.float32,
                device=freqs_cis.device,
            )
            cls_freqs_cis = torch.polar(torch.ones_like(t), t)[None, :]
            freqs_cis = torch.cat([cls_freqs_cis, freqs_cis], dim=0)
        self.register_buffer("freqs_cis", freqs_cis)

    def _apply_rope(self, q, k) -> Tuple[Tensor, Tensor]:
        # Rotate q/k with the precomputed factors; identity when rope is off.
        if not self.use_rope:
            return q, k
        assert self.freqs_cis is not None
        return apply_rotary_enc(q, k, freqs_cis=self.freqs_cis)

    def forward(self, x: Tensor) -> Tensor:
        # Accepts either (B, H, W, C) spatial input or (B, L, C) token input
        # (the latter is used when a cls token is present).
        s = 1 if self.cls_token else 0  # used to exclude cls_token
        if x.ndim == 4:
            B, H, W, _ = x.shape
            assert s == 0  # no cls_token
            L = H * W
            ndim = 4
        else:
            assert x.ndim == 3
            B, L, _ = x.shape
            ndim = 3
            # NOTE(review): math.sqrt returns a float here; this is only safe
            # because the 3-D path never reshapes with H/W unless use_rel_pos
            # is set (asserted off with cls tokens) -- confirm.
            H = W = math.sqrt(L - s)
        # qkv with shape (3, B, nHead, L, C)
        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, -1)
        # q, k, v with shape (B, nHead, L, C)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        # handle rope and rel pos embeddings
        q, k = self._apply_rope(q, k)
        if self.use_rel_pos:
            q, k = concat_rel_pos(
                q.flatten(0, 1),
                k.flatten(0, 1),
                (H, W),
                x.shape[1:3],
                self.rel_pos_h,
                self.rel_pos_w,
                rescale=True,
                relative_coords=self.relative_coords,
            )
            # sdpa expects [B, nheads, H*W, C] so we transpose back
            q = q.reshape(B, self.num_heads, H * W, -1)
            k = k.reshape(B, self.num_heads, H * W, -1)
        x = F.scaled_dot_product_attention(q, k, v)
        if ndim == 4:
            x = (
                x.view(B, self.num_heads, H, W, -1)
                .permute(0, 2, 3, 1, 4)
                .reshape(B, H, W, -1)
            )
        else:
            x = x.view(B, self.num_heads, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1)
        x = self.proj(x)
        return x
class Block(nn.Module):
    """Transformer blocks with support of window attention"""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
        use_rope: bool = False,
        rope_pt_size: Optional[Tuple[int, int]] = None,
        rope_tiled: bool = False,
        rope_interp: bool = False,
        use_ve_rope: bool = False,
        cls_token: bool = False,
        dropout: float = 0.0,
        init_values: Optional[float] = None,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then not
                use window attention.
            input_size (int or None): Input resolution for calculating the relative positional
                parameter size.
            dropout (float): Dropout rate.
            cls_token: whether a cls_token is present.
            use_rope: whether to use rope 2d (indep of use_rel_pos, as it can be used together)
            rope_pt_size: size of rope in previous stage of training, needed for interpolation or tiling
            rope_interp: whether to interpolate (or extrapolate) rope to match target input size,
                expected to specify source size as rope_pt_size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        # Windowed blocks attend within (window_size, window_size) tiles, so
        # the attention module sizes its rel-pos/rope tables accordingly.
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
            use_rope=use_rope,
            rope_pt_size=rope_pt_size,
            rope_interp=rope_interp,
            cls_token=cls_token,
        )
        # Optional LayerScale (identity when init_values is None/0)
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=(dropout, 0.0),
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.dropout = nn.Dropout(dropout)
        self.window_size = window_size

    def forward(self, x: Tensor) -> Tensor:
        # Pre-norm transformer block: x + attn(norm(x)), then x + mlp(norm(x)),
        # with optional windowed attention around the attn sub-block.
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.ls1(self.attn(x))
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + self.dropout(self.drop_path(x))
        x = x + self.dropout(self.drop_path(self.ls2(self.mlp(self.norm2(x)))))
        return x
class ViT(nn.Module):
    """
    This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
    "Exploring Plain Vision Transformer Backbones for Object Detection",
    https://arxiv.org/abs/2203.16527
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path_rate: float = 0.0,
        norm_layer: Union[Callable[..., nn.Module], str] = "LayerNorm",
        act_layer: Callable[..., nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        tile_abs_pos: bool = True,
        rel_pos_blocks: Union[Tuple[int, ...], bool] = (2, 5, 8, 11),
        rel_pos_zero_init: bool = True,
        window_size: int = 14,
        global_att_blocks: Tuple[int, ...] = (2, 5, 8, 11),
        use_rope: bool = False,
        rope_pt_size: Optional[int] = None,
        use_interp_rope: bool = False,
        pretrain_img_size: int = 224,
        pretrain_use_cls_token: bool = True,
        retain_cls_token: bool = True,
        dropout: float = 0.0,
        return_interm_layers: bool = False,
        init_values: Optional[float] = None,  # for layerscale
        ln_pre: bool = False,
        ln_post: bool = False,
        bias_patch_embed: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = True,
    ):
        """
        Args:
            img_size (int): Input image size. Only relevant for rel pos or rope.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path_rate (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            tile_abs_pos (bool): If True, tile absolute positional embeddings instead of interpolation.
            rel_pos_blocks (list): Blocks which have rel pos embeddings.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_att_blocks (list): Indexes for blocks using global attention (other blocks use window attention).
            use_rope (bool): whether to use rope 2d (indep of rel_pos_blocks, as it can be used together).
            rope_pt_size (int): size of rope in previous stage of training, needed for interpolation or tiling.
            use_interp_rope: whether to interpolate (or extrapolate) rope to match target input size,
                expected to specify source size as rope_pt_size.
            use_act_checkpoint (bool): If True, use activation checkpointing.
            pretrain_img_size (int): input image size for pretraining models.
            pretrain_use_cls_token (bool): If True, pretraining models use class token.
            retain_cls_token: whether cls_token should be retained.
            dropout (float): Dropout rate. Applied in residual blocks of attn, mlp and inside the mlp.
            return_interm_layers (bool): Whether to return intermediate layers (all global attention blocks).
            init_values: layer scale init, None for no layer scale.
            ln_pre (bool): If True, apply layer norm before transformer blocks.
            ln_post (bool): If True, apply layer norm after transformer blocks.
            bias_patch_embed (bool): bias in conv for patch embed?
            compile_mode (str): mode to compile the forward
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token

        window_block_indexes = [i for i in range(depth) if i not in global_att_blocks]
        self.full_attn_ids = list(global_att_blocks)

        # Expand rel_pos_blocks into a per-block boolean mask.
        self.rel_pos_blocks = [False] * depth
        if isinstance(rel_pos_blocks, bool) and rel_pos_blocks:
            self.rel_pos_blocks = [True] * depth
        else:
            for i in rel_pos_blocks:
                self.rel_pos_blocks[i] = True

        self.retain_cls_token = retain_cls_token
        if self.retain_cls_token:
            # cls token forces a flattened (B, 1+HW, C) token layout, which is
            # incompatible with windowing and rel-pos (both need a 4-D grid).
            assert pretrain_use_cls_token
            assert (
                len(window_block_indexes) == 0
            ), "windowing not supported with cls token"
            assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"
            scale = embed_dim**-0.5
            self.class_embedding = nn.Parameter(scale * torch.randn(1, 1, embed_dim))

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-5)

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=bias_patch_embed,
        )

        # Handle absolute positional embedding
        self.tile_abs_pos = tile_abs_pos
        self.use_abs_pos = use_abs_pos
        if self.tile_abs_pos:
            assert self.use_abs_pos
        if self.use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) * (
                pretrain_img_size // patch_size
            )
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.use_act_checkpoint = use_act_checkpoint
        self.blocks = nn.ModuleList()
        cur_stage = 1
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=self.rel_pos_blocks[i],
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i in window_block_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
                use_rope=use_rope,
                rope_pt_size=(
                    (window_size, window_size)
                    if rope_pt_size is None
                    else (rope_pt_size, rope_pt_size)
                ),
                rope_interp=use_interp_rope,
                cls_token=self.retain_cls_token,
                dropout=dropout,
                init_values=init_values,
            )
            if i not in window_block_indexes:
                cur_stage += 1
            self.blocks.append(block)

        self.return_interm_layers = return_interm_layers
        self.channel_list = (
            [embed_dim] * len(self.full_attn_ids)
            if return_interm_layers
            else [embed_dim]
        )

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=0.02)

        self.ln_pre = norm_layer(embed_dim) if ln_pre else nn.Identity()
        self.ln_post = norm_layer(embed_dim) if ln_post else nn.Identity()

        self.apply(self._init_weights)
        if compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=compile_mode, fullgraph=True
            )
            if self.use_act_checkpoint and self.training:
                torch._dynamo.config.optimize_ddp = False

    def _init_weights(self, m: nn.Module) -> None:
        # Standard ViT init: trunc-normal linear weights, unit LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Run the backbone; returns the final feature map (and intermediate
        global-attention features when return_interm_layers is set), each as a
        (B, C, H, W) tensor."""
        x = self.patch_embed(x)
        h, w = x.shape[1], x.shape[2]
        s = 0
        if self.retain_cls_token:
            # If cls_token is retained, we don't
            # maintain spatial shape
            x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
            s = 1

        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed,
                self.pretrain_use_cls_token,
                (h, w),
                self.retain_cls_token,
                tiling=self.tile_abs_pos,
            )

        x = self.ln_pre(x)

        outputs = []
        for i, blk in enumerate(self.blocks):
            if self.use_act_checkpoint and self.training:
                x = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
            if (i == self.full_attn_ids[-1]) or (
                self.return_interm_layers and i in self.full_attn_ids
            ):
                if i == self.full_attn_ids[-1]:
                    x = self.ln_post(x)
                feats = x[:, s:]
                if feats.ndim == 4:
                    feats = feats.permute(0, 3, 1, 2)
                else:
                    assert feats.ndim == 3
                    # Tokens are flattened: recover the square spatial grid.
                    # int() is required -- math.sqrt returns a float, and
                    # Tensor.reshape rejects float sizes.
                    h = w = int(math.sqrt(feats.shape[1]))
                    feats = feats.reshape(
                        feats.shape[0], h, w, feats.shape[-1]
                    ).permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs

    def get_layer_id(self, layer_name: str) -> int:
        # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
        # Used for layer-wise LR decay: maps a parameter name to a depth index.
        num_layers = self.get_num_layers()

        if layer_name.find("rel_pos") != -1:
            return num_layers + 1
        elif layer_name.find("ln_pre") != -1:
            return 0
        elif layer_name.find("pos_embed") != -1 or layer_name.find("cls_token") != -1:
            return 0
        elif layer_name.find("patch_embed") != -1:
            return 0
        elif layer_name.find("blocks") != -1:
            return int(layer_name.split("blocks")[1].split(".")[1]) + 1
        else:
            return num_layers + 1

    def get_num_layers(self) -> int:
        return len(self.blocks)

176
sam3/model/vl_combiner.py Normal file
View File

@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""Provides utility to combine a vision backbone with a language backbone."""
from copy import copy
from typing import List, Optional
import torch
import torch.nn as nn
from torch.nn.attention import sdpa_kernel, SDPBackend
from .act_ckpt_utils import activation_ckpt_wrapper
from .necks import Sam3DualViTDetNeck
class SAM3VLBackbone(nn.Module):
"""This backbone combines a vision backbone and a language backbone without fusion.
As such it is more of a convenience wrapper to handle the two backbones together.
It adds support for activation checkpointing and compilation.
"""
def __init__(
self,
visual: Sam3DualViTDetNeck,
text,
compile_visual: bool = False,
act_ckpt_whole_vision_backbone: bool = False,
act_ckpt_whole_language_backbone: bool = False,
scalp=0,
):
"""Initialize the backbone combiner.
:param visual: The vision backbone to use
:param text: The text encoder to use
"""
super().__init__()
self.vision_backbone: Sam3DualViTDetNeck = (
torch.compile(visual) if compile_visual else visual
)
self.language_backbone = text
self.scalp = scalp
# allow running activation checkpointing on the entire vision and language backbones
self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone
self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone
def forward(
self,
samples: torch.Tensor,
captions: List[str],
input_boxes: Optional[torch.Tensor] = None,
additional_text: Optional[List[str]] = None,
):
"""Forward pass of the backbone combiner.
:param samples: The input images
:param captions: The input captions
:param input_boxes: If the text contains place-holders for boxes, this
parameter contains the tensor containing their spatial features
:param additional_text: This can be used to encode some additional text
(different from the captions) in the same forward of the backbone
:return: Output dictionary with the following keys:
- vision_features: The output of the vision backbone
- language_features: The output of the language backbone
- language_mask: The attention mask of the language backbone
- vision_pos_enc: The positional encoding of the vision backbone
- (optional) additional_text_features: The output of the language
backbone for the additional text
- (optional) additional_text_mask: The attention mask of the
language backbone for the additional text
"""
output = self.forward_image(samples)
device = output["vision_features"].device
output.update(self.forward_text(captions, input_boxes, additional_text, device))
return output
def forward_image(self, samples: torch.Tensor):
return activation_ckpt_wrapper(self._forward_image_no_act_ckpt)(
samples=samples,
act_ckpt_enable=self.act_ckpt_whole_vision_backbone and self.training,
)
def _forward_image_no_act_ckpt(self, samples):
# Forward through backbone
sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward(
samples
)
if self.scalp > 0:
# Discard the lowest resolution features
sam3_features, sam3_pos = (
sam3_features[: -self.scalp],
sam3_pos[: -self.scalp],
)
if sam2_features is not None and sam2_pos is not None:
sam2_features, sam2_pos = (
sam2_features[: -self.scalp],
sam2_pos[: -self.scalp],
)
sam2_output = None
if sam2_features is not None and sam2_pos is not None:
sam2_src = sam2_features[-1]
sam2_output = {
"vision_features": sam2_src,
"vision_pos_enc": sam2_pos,
"backbone_fpn": sam2_features,
}
sam3_src = sam3_features[-1]
output = {
"vision_features": sam3_src,
"vision_pos_enc": sam3_pos,
"backbone_fpn": sam3_features,
"sam2_backbone_out": sam2_output,
}
return output
def forward_text(
self, captions, input_boxes=None, additional_text=None, device="cuda"
):
return activation_ckpt_wrapper(self._forward_text_no_ack_ckpt)(
captions=captions,
input_boxes=input_boxes,
additional_text=additional_text,
device=device,
act_ckpt_enable=self.act_ckpt_whole_language_backbone and self.training,
)
def _forward_text_no_ack_ckpt(
self,
captions,
input_boxes=None,
additional_text=None,
device="cuda",
):
output = {}
# Forward through text_encoder
text_to_encode = copy(captions)
if additional_text is not None:
# if there are additional_text, we piggy-back them into this forward.
# They'll be used later for output alignment
text_to_encode += additional_text
sdpa_context = sdpa_kernel(
[
SDPBackend.MATH,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.FLASH_ATTENTION,
]
)
with sdpa_context:
text_attention_mask, text_memory, text_embeds = self.language_backbone(
text_to_encode, input_boxes, device=device
)
if additional_text is not None:
output["additional_text_features"] = text_memory[:, -len(additional_text) :]
output["additional_text_mask"] = text_attention_mask[
-len(additional_text) :
]
text_memory = text_memory[:, : len(captions)]
text_attention_mask = text_attention_mask[: len(captions)]
text_embeds = text_embeds[:, : len(captions)]
output["language_features"] = text_memory
output["language_mask"] = text_attention_mask
output["language_embeds"] = (
text_embeds # Text embeddings before forward to the encoder
)
return output