Initial commit

fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
facebook-github-bot
2025-11-18 23:07:42 -08:00
commit a13e358df4
504 changed files with 122758 additions and 0 deletions

4
sam3/sam/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from .transformer import TwoWayTransformer

39
sam3/sam/common.py Normal file
View File

@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Type
import torch
import torch.nn as nn
class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x

319
sam3/sam/mask_decoder.py Normal file
View File

@@ -0,0 +1,319 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import List, Optional, Tuple, Type
import torch
from torch import nn
from torch.nn import functional as F
from .common import LayerNorm2d
class MaskDecoder(nn.Module):
def __init__(
self,
*,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
activation: Type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
use_high_res_features: bool = False,
iou_prediction_use_sigmoid=False,
dynamic_multimask_via_stability=False,
dynamic_multimask_stability_delta=0.05,
dynamic_multimask_stability_thresh=0.98,
pred_obj_scores: bool = False,
pred_obj_scores_mlp: bool = False,
use_multimask_token_for_obj_ptr: bool = False,
) -> None:
"""
Predicts masks given an image and prompt embeddings, using a
transformer architecture.
Arguments:
transformer_dim (int): the channel dimension of the transformer
transformer (nn.Module): the transformer used to predict masks
num_multimask_outputs (int): the number of masks to predict
when disambiguating masks
activation (nn.Module): the type of activation to use when
upscaling masks
iou_head_depth (int): the depth of the MLP used to predict
mask quality
iou_head_hidden_dim (int): the hidden dimension of the MLP
used to predict mask quality
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
self.num_multimask_outputs = num_multimask_outputs
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.pred_obj_scores = pred_obj_scores
if self.pred_obj_scores:
self.obj_score_token = nn.Embedding(1, transformer_dim)
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(
transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(
transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
),
activation(),
)
self.use_high_res_features = use_high_res_features
if use_high_res_features:
self.conv_s0 = nn.Conv2d(
transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
)
self.conv_s1 = nn.Conv2d(
transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
)
self.output_hypernetworks_mlps = nn.ModuleList(
[
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
for i in range(self.num_mask_tokens)
]
)
self.iou_prediction_head = MLP(
transformer_dim,
iou_head_hidden_dim,
self.num_mask_tokens,
iou_head_depth,
sigmoid_output=iou_prediction_use_sigmoid,
)
if self.pred_obj_scores:
self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
if pred_obj_scores_mlp:
self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
# When outputting a single mask, optionally we can dynamically fall back to the best
# multimask output token if the single mask output token gives low stability scores.
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
repeat_image: bool,
high_res_features: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
torch.Tensor: batched SAM token for mask output
"""
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
repeat_image=repeat_image,
high_res_features=high_res_features,
)
# Select the correct mask or masks for output
if multimask_output:
masks = masks[:, 1:, :, :]
iou_pred = iou_pred[:, 1:]
elif self.dynamic_multimask_via_stability and not self.training:
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
else:
masks = masks[:, 0:1, :, :]
iou_pred = iou_pred[:, 0:1]
if multimask_output and self.use_multimask_token_for_obj_ptr:
sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
else:
# Take the mask output token. Here we *always* use the token for single mask output.
# At test time, even if we track after 1-click (and using multimask_output=True),
# we still take the single mask token here. The rationale is that we always track
# after multiple clicks during training, so the past tokens seen during training
# are always the single mask token (and we'll let it be the object-memory token).
sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
# Prepare output
return masks, iou_pred, sam_tokens_out, object_score_logits
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
repeat_image: bool,
high_res_features: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
# Concatenate output tokens
s = 0
if self.pred_obj_scores:
output_tokens = torch.cat(
[
self.obj_score_token.weight,
self.iou_token.weight,
self.mask_tokens.weight,
],
dim=0,
)
s = 1
else:
output_tokens = torch.cat(
[self.iou_token.weight, self.mask_tokens.weight], dim=0
)
output_tokens = output_tokens.unsqueeze(0).expand(
sparse_prompt_embeddings.size(0), -1, -1
)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
if repeat_image:
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
else:
assert image_embeddings.shape[0] == tokens.shape[0]
src = image_embeddings
src = src + dense_prompt_embeddings
assert (
image_pe.size(0) == 1
), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, s, :]
mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
if not self.use_high_res_features:
upscaled_embedding = self.output_upscaling(src)
else:
dc1, ln1, act1, dc2, act2 = self.output_upscaling
feat_s0, feat_s1 = high_res_features
upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
)
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
if self.pred_obj_scores:
assert s == 1
object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
else:
# Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
return masks, iou_pred, mask_tokens_out, object_score_logits
def _get_stability_scores(self, mask_logits):
"""
Compute stability scores of the mask logits based on the IoU between upper and
lower thresholds.
"""
mask_logits = mask_logits.flatten(-2)
stability_delta = self.dynamic_multimask_stability_delta
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
return stability_scores
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
"""
When outputting a single mask, if the stability score from the current single-mask
output (based on output token 0) falls below a threshold, we instead select from
multi-mask outputs (based on output token 1~3) the mask with the highest predicted
IoU score. This is intended to ensure a valid mask for both clicking and tracking.
"""
# The best mask from multimask output tokens (1~3)
multimask_logits = all_mask_logits[:, 1:, :, :]
multimask_iou_scores = all_iou_scores[:, 1:]
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
batch_inds = torch.arange(
multimask_iou_scores.size(0), device=all_iou_scores.device
)
best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
best_multimask_logits = best_multimask_logits.unsqueeze(1)
best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
# The mask from singlemask output token 0 and its stability score
singlemask_logits = all_mask_logits[:, 0:1, :, :]
singlemask_iou_scores = all_iou_scores[:, 0:1]
stability_scores = self._get_stability_scores(singlemask_logits)
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
# Dynamically fall back to best multimask output upon low stability scores.
mask_logits_out = torch.where(
is_stable[..., None, None].expand_as(singlemask_logits),
singlemask_logits,
best_multimask_logits,
)
iou_scores_out = torch.where(
is_stable.expand_as(singlemask_iou_scores),
singlemask_iou_scores,
best_multimask_iou_scores,
)
return mask_logits_out, iou_scores_out
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
sigmoid_output: bool = False,
) -> None:
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
self.sigmoid_output = sigmoid_output
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
if self.sigmoid_output:
x = F.sigmoid(x)
return x

243
sam3/sam/prompt_encoder.py Normal file
View File

@@ -0,0 +1,243 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Any, Optional, Tuple, Type
import numpy as np
import torch
from torch import nn
from .common import LayerNorm2d
class PromptEncoder(nn.Module):
def __init__(
self,
embed_dim: int,
image_embedding_size: Tuple[int, int],
input_image_size: Tuple[int, int],
mask_in_chans: int,
activation: Type[nn.Module] = nn.GELU,
) -> None:
"""
Encodes prompts for input to SAM's mask decoder.
Arguments:
embed_dim (int): The prompts' embedding dimension
image_embedding_size (tuple(int, int)): The spatial size of the
image embedding, as (H, W).
input_image_size (int): The padded size of the image as input
to the image encoder, as (H, W).
mask_in_chans (int): The number of hidden channels used for
encoding input masks.
activation (nn.Module): The activation to use when encoding
input masks.
"""
super().__init__()
self.embed_dim = embed_dim
self.input_image_size = input_image_size
self.image_embedding_size = image_embedding_size
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
point_embeddings = [
nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
]
self.point_embeddings = nn.ModuleList(point_embeddings)
self.not_a_point_embed = nn.Embedding(1, embed_dim)
self.mask_input_size = (
4 * image_embedding_size[0],
4 * image_embedding_size[1],
)
self.mask_downscaling = nn.Sequential(
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4),
activation(),
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans),
activation(),
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
self.no_mask_embed = nn.Embedding(1, embed_dim)
def get_dense_pe(self) -> torch.Tensor:
"""
Returns the positional encoding used to encode point prompts,
applied to a dense set of points the shape of the image encoding.
Returns:
torch.Tensor: Positional encoding with shape
1x(embed_dim)x(embedding_h)x(embedding_w)
"""
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
def _embed_points(
self,
points: torch.Tensor,
labels: torch.Tensor,
pad: bool,
) -> torch.Tensor:
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
point_embedding = self.pe_layer.forward_with_coords(
points, self.input_image_size
)
point_embedding = torch.where(
(labels == -1).unsqueeze(-1),
torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
point_embedding,
)
point_embedding = torch.where(
(labels == 0).unsqueeze(-1),
point_embedding + self.point_embeddings[0].weight,
point_embedding,
)
point_embedding = torch.where(
(labels == 1).unsqueeze(-1),
point_embedding + self.point_embeddings[1].weight,
point_embedding,
)
point_embedding = torch.where(
(labels == 2).unsqueeze(-1),
point_embedding + self.point_embeddings[2].weight,
point_embedding,
)
point_embedding = torch.where(
(labels == 3).unsqueeze(-1),
point_embedding + self.point_embeddings[3].weight,
point_embedding,
)
return point_embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
coords = boxes.reshape(-1, 2, 2)
corner_embedding = self.pe_layer.forward_with_coords(
coords, self.input_image_size
)
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
return corner_embedding
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
"""Embeds mask inputs."""
mask_embedding = self.mask_downscaling(masks)
return mask_embedding
def _get_batch_size(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> int:
"""
Gets the batch size of the output given the batch size of the input prompts.
"""
if points is not None:
return points[0].shape[0]
elif boxes is not None:
return boxes.shape[0]
elif masks is not None:
return masks.shape[0]
else:
return 1
def _get_device(self) -> torch.device:
return self.point_embeddings[0].weight.device
def forward(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Embeds different types of prompts, returning both sparse and dense
embeddings.
Arguments:
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
and labels to embed.
boxes (torch.Tensor or none): boxes to embed
masks (torch.Tensor or none): masks to embed
Returns:
torch.Tensor: sparse embeddings for the points and boxes, with shape
BxNx(embed_dim), where N is determined by the number of input points
and boxes.
torch.Tensor: dense embeddings for the masks, in the shape
Bx(embed_dim)x(embed_H)x(embed_W)
"""
bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty(
(bs, 0, self.embed_dim), device=self._get_device()
)
if points is not None:
coords, labels = points
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
if boxes is not None:
box_embeddings = self._embed_boxes(boxes)
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
return sparse_embeddings, dense_embeddings
class PositionEmbeddingRandom(nn.Module):
"""
Positional encoding using random spatial frequencies.
"""
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
self.register_buffer(
"positional_encoding_gaussian_matrix",
scale * torch.randn((2, num_pos_feats)),
)
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
"""Generate positional encoding for a grid of the specified size."""
h, w = size
device: Any = self.positional_encoding_gaussian_matrix.device
grid = torch.ones((h, w), device=device, dtype=torch.float32)
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
y_embed = y_embed / h
x_embed = x_embed / w
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
return pe.permute(2, 0, 1) # C x H x W
def forward_with_coords(
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
) -> torch.Tensor:
"""Positionally encode points that are not normalized to [0,1]."""
coords = coords_input.clone()
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
return self._pe_encoding(coords.to(torch.float)) # B x N x C

161
sam3/sam/rope.py Normal file
View File

@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
Adapted from:
1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
2. https://github.com/naver-ai/rope-vit
3. https://github.com/lucidrains/rotary-embedding-torch
"""
from typing import Optional
import torch
from einops import rearrange, repeat
from torch import broadcast_tensors, nn
def init_t_xy(end_x: int, end_y: int, scale: float = 1.0, offset: int = 0, device=None):
t = torch.arange(end_x * end_y, dtype=torch.float32, device=device)
t_x = (t % end_x).float()
t_y = torch.div(t, end_x, rounding_mode="floor").float()
return t_x * scale + offset, t_y * scale + offset
def compute_axial_cis(
dim: int,
end_x: int,
end_y: int,
theta: float = 10000.0,
scale_pos: float = 1.0,
offset: int = 0,
device=None,
):
freqs_x = 1.0 / (
theta ** (torch.arange(0, dim, 4, device=device)[: (dim // 4)].float() / dim)
)
freqs_y = 1.0 / (
theta ** (torch.arange(0, dim, 4, device=device)[: (dim // 4)].float() / dim)
)
t_x, t_y = init_t_xy(end_x, end_y, scale_pos, offset, device=device)
freqs_x = torch.outer(t_x, freqs_x)
freqs_y = torch.outer(t_y, freqs_y)
freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_enc(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
repeat_freqs_k: bool = False,
):
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = (
torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
if xk.shape[-2] != 0
else None
)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
if xk_ is None:
# no keys to rotate, due to dropout
return xq_out.type_as(xq).to(xq.device), xk
# repeat freqs along seq_len dim to match k seq_len
if repeat_freqs_k:
r = xk_.shape[-2] // xq_.shape[-2]
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
def complex_mult(xq_real, xq_imag, freqs_cis_real, freqs_cis_imag):
# Compute the real part of the product
real_part = xq_real * freqs_cis_real - xq_imag * freqs_cis_imag
# Compute the imaginary part of the product
imag_part = xq_real * freqs_cis_imag + xq_imag * freqs_cis_real
# Stack the real and imaginary parts along the last dimension
return torch.stack([real_part, imag_part], dim=-1)
def apply_rotary_enc_real(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis_real: torch.Tensor,
freqs_cis_imag: torch.Tensor,
repeat_freqs_k: bool = False,
):
assert xk is not None
assert xk.shape[-2] != 0
xq_real = xq.float().reshape(*xq.shape[:-1], -1, 2)[..., 0]
xq_imag = xq.float().reshape(*xq.shape[:-1], -1, 2)[..., 1]
xk_real = xk.float().reshape(*xk.shape[:-1], -1, 2)[..., 0]
xk_imag = xk.float().reshape(*xk.shape[:-1], -1, 2)[..., 1]
freqs_cis_real = reshape_for_broadcast(freqs_cis_real, xq_real)
freqs_cis_imag = reshape_for_broadcast(freqs_cis_imag, xq_imag)
xq_out = complex_mult(xq_real, xq_imag, freqs_cis_real, freqs_cis_imag).flatten(3)
if repeat_freqs_k:
r = xk_real.shape[-2] // xq_real.shape[-2]
freqs_cis_real = freqs_cis_real.repeat(*([1] * (freqs_cis_real.ndim - 2)), r, 1)
freqs_cis_imag = freqs_cis_imag.repeat(*([1] * (freqs_cis_imag.ndim - 2)), r, 1)
xk_out = complex_mult(xk_real, xk_imag, freqs_cis_real, freqs_cis_imag).flatten(3)
# xq_out = torch.view_as_real(torch.complex(xq_real, xq_imag) * torch.complex(freqs_cis_real, freqs_cis_imag)).flatten(3)
# xk_out = torch.view_as_real(torch.compelx(xk_real, xk_imag) * torch.complex(freqs_cis_real, freqs_cis_imag)).flatten(3)
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
# rotary embedding helper functions
def broadcat(tensors, dim=-1):
broadcasted_tensors = broadcast_tensors(*tensors)
return torch.cat(broadcasted_tensors, dim=dim)
def rotate_half(x: torch.Tensor):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
class VisionRotaryEmbeddingVE(nn.Module):
def __init__(
self,
dim: int,
seq_len: int,
pt_seq_len: Optional[int] = None,
theta: float = 10000.0,
offset: int = 1, # specific to VE
):
super().__init__()
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
scale = 1.0
if pt_seq_len is not None:
scale = pt_seq_len / seq_len
# offset of +1 following VE - even though for the
# attention op only differences matter
t = torch.arange(seq_len) * scale + offset
freqs = torch.einsum("..., f -> ... f", t, freqs)
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs[None, :, :], freqs[:, None, :]), dim=-1)
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
self.register_buffer("freqs_cos", freqs_cos)
self.register_buffer("freqs_sin", freqs_sin)
def forward(self, t: torch.Tensor):
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin

358
sam3/sam/transformer.py Normal file
View File

@@ -0,0 +1,358 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import math
from functools import partial
from typing import Tuple, Type
import torch
import torch.nn.functional as F
from sam3.sam.rope import apply_rotary_enc, apply_rotary_enc_real, compute_axial_cis
from torch import nn, Tensor
from .common import MLPBlock
class TwoWayTransformer(nn.Module):
def __init__(
self,
depth: int,
embedding_dim: int,
num_heads: int,
mlp_dim: int,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
) -> None:
"""
A transformer decoder that attends to an input image using
queries whose positional embedding is supplied.
Args:
depth (int): number of layers in the transformer
embedding_dim (int): the channel dimension for the input embeddings
num_heads (int): the number of heads for multihead attention. Must
divide embedding_dim
mlp_dim (int): the channel dimension internal to the MLP block
activation (nn.Module): the activation to use in the MLP block
"""
super().__init__()
self.depth = depth
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.mlp_dim = mlp_dim
self.layers = nn.ModuleList()
for i in range(depth):
self.layers.append(
TwoWayAttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
activation=activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
)
)
self.final_attn_token_to_image = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
def forward(
self,
image_embedding: Tensor,
image_pe: Tensor,
point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape
B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must
have the same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
torch.Tensor: the processed point_embedding
torch.Tensor: the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
# Prepare queries
queries = point_embedding
keys = image_embedding
# Apply transformer blocks and final layernorm
for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
# Apply the final attention layer from the points to the image
q = queries + point_embedding
k = keys + image_pe
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries)
return queries, keys
class TwoWayAttentionBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
num_heads: int,
mlp_dim: int = 2048,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""
A transformer block with four layers: (1) self-attention of sparse
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
block on sparse inputs, and (4) cross attention of dense inputs to sparse
inputs.
Arguments:
embedding_dim (int): the channel dimension of the embeddings
num_heads (int): the number of heads in the attention layers
mlp_dim (int): the hidden dimension of the mlp block
activation (nn.Module): the activation of the mlp block
skip_first_layer_pe (bool): skip the PE on the first layer
"""
super().__init__()
self.self_attn = Attention(embedding_dim, num_heads)
self.norm1 = nn.LayerNorm(embedding_dim)
self.cross_attn_token_to_image = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.norm2 = nn.LayerNorm(embedding_dim)
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
) -> Tuple[Tensor, Tensor]:
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(q=queries, k=queries, v=queries)
else:
q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)
# Cross attention block, tokens attending to image embedding
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)
# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)
return queries, keys
class Attention(nn.Module):
"""
An attention layer that allows for downscaling the size of the embedding
after projection to queries, keys, and values.
"""
def __init__(
self,
embedding_dim: int,
num_heads: int,
downsample_rate: int = 1,
dropout: float = 0.0,
kv_in_dim: int = None,
use_fa3: bool = False,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
self.use_fa3 = use_fa3
assert (
self.internal_dim % num_heads == 0
), "num_heads must divide embedding_dim."
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
self.dropout_p = dropout
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
b, n, c = x.shape
x = x.reshape(b, n, num_heads, c // num_heads)
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
def _recombine_heads(self, x: Tensor) -> Tensor:
b, n_heads, n_tokens, c_per_head = x.shape
x = x.transpose(1, 2)
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
# Input projections
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# Separate into heads
q = self._separate_heads(q, self.num_heads)
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)
dropout_p = self.dropout_p if self.training else 0.0
# Attention
# with torch.backends.cuda.sdp_kernel(
# enable_flash=USE_FLASH_ATTN,
# # if Flash attention kernel is off, then math kernel needs to be enabled
# enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
# enable_mem_efficient=OLD_GPU,
# ):
# Let's trust the dispatcher....
if self.use_fa3:
from sam3.perflib.fa3 import flash_attn_func
assert dropout_p == 0.0
out = flash_attn_func(
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
else:
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
out = self._recombine_heads(out)
out = self.out_proj(out)
return out
class RoPEAttention(Attention):
"""Attention with rotary position encoding."""
def __init__(
self,
*args,
rope_theta=10000.0,
# whether to repeat q rope to match k length
# this is needed for cross-attention to memories
rope_k_repeat=False,
feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution
use_rope_real=False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.use_rope_real = use_rope_real
self.compute_cis = partial(
compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
)
device = torch.device("cuda") if torch.cuda.is_available() else None
self.freqs_cis = self.compute_cis(
end_x=feat_sizes[0], end_y=feat_sizes[1], device=device
)
if self.use_rope_real:
self.freqs_cis_real = self.freqs_cis.real
self.freqs_cis_imag = self.freqs_cis.imag
self.rope_k_repeat = rope_k_repeat
def forward(
self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
) -> Tensor:
# Input projections
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# Separate into heads
q = self._separate_heads(q, self.num_heads)
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)
# Apply rotary position encoding
w = h = math.sqrt(q.shape[-2])
if self.freqs_cis.shape[0] != q.shape[-2]:
self.freqs_cis = self.compute_cis(end_x=w, end_y=h, device=q.device)
self.freqs_cis_real = self.freqs_cis.real
self.freqs_cis_imag = self.freqs_cis.imag
if q.shape[-2] != k.shape[-2]:
assert self.rope_k_repeat
num_k_rope = k.size(-2) - num_k_exclude_rope
if self.use_rope_real:
q, k[:, :, :num_k_rope] = apply_rotary_enc_real(
q,
k[:, :, :num_k_rope],
freqs_cis_real=self.freqs_cis_real,
freqs_cis_imag=self.freqs_cis_imag,
repeat_freqs_k=self.rope_k_repeat,
)
else:
q, k[:, :, :num_k_rope] = apply_rotary_enc(
q,
k[:, :, :num_k_rope],
self.freqs_cis,
repeat_freqs_k=self.rope_k_repeat,
)
dropout_p = self.dropout_p if self.training else 0.0
# Attention
# with torch.backends.cuda.sdp_kernel(
# enable_flash=USE_FLASH_ATTN,
# # if Flash attention kernel is off, then math kernel needs to be enabled
# enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
# enable_mem_efficient=OLD_GPU,
# ):
# Let's trust the dispatcher....
if self.use_fa3:
from sam3.perflib.fa3 import flash_attn_func
assert dropout_p == 0.0
out = flash_attn_func(
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
else:
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
out = self._recombine_heads(out)
out = self.out_proj(out)
return out