Files
sam3_local/sam3/model/geometry_encoders.py
generatedunixname89002005307016 7b89b8fc3f Add missing Pyre mode headers] [batch:11/N] [shard:17/N]
Differential Revision: D90237984

fbshipit-source-id: 526fd760f303bf31be4f743bdcd77760496de0de
2026-01-07 05:16:41 -08:00

853 lines
33 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
from typing import Tuple
import torch
import torch.nn as nn
import torchvision
from typing_extensions import override
from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .model_misc import get_clones
def is_right_padded(mask):
    """Tell whether every row of a padding mask has its padding on the right.

    The mask follows the pytorch convention (1s / True for padded values).
    A row is right-padded exactly when it is already sorted in ascending
    order along the last dimension, i.e. all 0s come before all 1s.

    :param mask: padding mask; it is converted with ``.long()`` before comparing.
    :return: a 0-dim boolean tensor, True if all rows are right-padded.
    """
    as_long = mask.long()
    sorted_values, _ = torch.sort(as_long, dim=-1)
    return torch.eq(as_long, sorted_values).all()
def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
    """
    Concatenates two right-padded sequences, such that the resulting sequence
    is contiguous and also right-padded.

    Following pytorch's convention, tensors are sequence first, and the masks are
    batch first, with 1s (True) for padded values.

    The output is allocated with the device and dtype of seq2 and always has
    length seq1_length + seq2_length along the sequence dimension.

    :param seq1: A tensor of shape (seq1_length, batch_size, hidden_size).
    :param mask1: A tensor of shape (batch_size, seq1_length).
    :param seq2: A tensor of shape (seq2_length, batch_size, hidden_size).
    :param mask2: A tensor of shape (batch_size, seq2_length).
    :param return_index: If True, also returns the position of every element of seq2
        in the concatenated sequence, as a tensor of shape (seq2_length, batch_size).
        This can be used to retrieve the elements of seq2 from the output. Positions
        of padded elements of seq2 fall inside the padded tail of the output.
    :return: A tuple (concatenated_sequence, concatenated_mask) if return_index is False,
        otherwise (concatenated_sequence, concatenated_mask, index).
    """
    seq1_length, batch_size, hidden_size = seq1.shape
    # Unpack into distinct names so the checks below really compare seq1 to seq2.
    seq2_length, seq2_batch_size, seq2_hidden_size = seq2.shape

    assert batch_size == seq2_batch_size == mask1.size(0) == mask2.size(0)
    assert hidden_size == seq2_hidden_size
    assert seq1_length == mask1.size(1)
    assert seq2_length == mask2.size(1)

    # Asynchronous checks: they do not force a device synchronization.
    torch._assert_async(is_right_padded(mask1))
    torch._assert_async(is_right_padded(mask2))

    # `~` is a *bitwise* NOT: on a 0/1 integer mask it yields -1/-2 rather than
    # 1/0, which would produce negative lengths. Cast to bool first so that
    # both boolean and 0/1-valued masks are counted correctly.
    actual_seq1_lengths = (~mask1.bool()).sum(dim=-1)
    actual_seq2_lengths = (~mask2.bool()).sum(dim=-1)

    final_lengths = actual_seq1_lengths + actual_seq2_lengths
    max_length = seq1_length + seq2_length

    # Position p of batch element b is padding iff p >= final_lengths[b].
    concatenated_mask = (
        torch.arange(max_length, device=seq2.device)[None].repeat(batch_size, 1)
        >= final_lengths[:, None]
    )

    # (max_len, batch_size, hidden_size)
    concatenated_sequence = torch.zeros(
        (max_length, batch_size, hidden_size), device=seq2.device, dtype=seq2.dtype
    )
    concatenated_sequence[:seq1_length, :, :] = seq1

    # At this point, the elements of seq1 are in the right place.
    # We just need to shift the elements of seq2 so that, for each batch element,
    # they start right after the last valid element of seq1 (overwriting the
    # padding of seq1 that was copied above).
    index = torch.arange(seq2_length, device=seq2.device)[:, None].repeat(1, batch_size)
    index = index + actual_seq1_lengths[None]

    concatenated_sequence = concatenated_sequence.scatter(
        0, index[:, :, None].expand(-1, -1, hidden_size), seq2
    )

    if return_index:
        return concatenated_sequence, concatenated_mask, index
    return concatenated_sequence, concatenated_mask
class Prompt:
    """Utility class to manipulate geometric prompts.

    We expect the sequences in pytorch convention, that is sequence first, batch second
    The dimensions are expected as follows:
    box_embeddings shape: N_boxes x B x C_box
    box_mask shape: B x N_boxes. Can be None if nothing is masked out
    point_embeddings shape: N_points x B x C_point
    point_mask shape: B x N_points. Can be None if nothing is masked out
    mask_embeddings shape: N_masks x B x 1 x H_mask x W_mask
    mask_mask shape: B x N_masks. Can be None if nothing is masked out

    We also store positive/negative labels. These tensors are stored sequence-first,
    like the embeddings. If they are None, we'll assume positive labels everywhere
    box_labels: long tensor of shape N_boxes x B
    point_labels: long tensor of shape N_points x B
    mask_labels: long tensor of shape N_masks x B

    If no embeddings at all are given, the prompt is "null": every attribute is None.
    Otherwise the missing box/point embeddings are replaced by zero-length tensors,
    missing labels by all-ones tensors and missing attention masks by all-False
    boolean tensors, so that only `mask_embeddings` may remain None.
    """

    def __init__(
        self,
        box_embeddings=None,
        box_mask=None,
        point_embeddings=None,
        point_mask=None,
        box_labels=None,
        point_labels=None,
        mask_embeddings=None,
        mask_mask=None,  # Attention mask for mask prompt
        mask_labels=None,
    ):
        # Check for null prompt
        if (
            box_embeddings is None
            and point_embeddings is None
            and mask_embeddings is None
        ):
            self.box_embeddings = None
            self.box_labels = None
            self.box_mask = None
            self.point_embeddings = None
            self.point_labels = None
            self.point_mask = None
            self.mask_embeddings = None
            self.mask_mask = None
            # Masks are assumed positive only for now.
            self.mask_labels = None
            return

        # Get sequence lengths and device
        box_seq_len, point_seq_len, mask_seq_len, bs, device = (
            self._init_seq_len_and_device(
                box_embeddings, point_embeddings, mask_embeddings
            )
        )

        # Initialize embeds, labels, attention masks.
        box_embeddings, box_labels, box_mask = self._init_box(
            box_embeddings, box_labels, box_mask, box_seq_len, bs, device
        )
        point_embeddings, point_labels, point_mask = self._init_point(
            point_embeddings, point_labels, point_mask, point_seq_len, bs, device
        )
        mask_embeddings, mask_labels, mask_mask = self._init_mask(
            mask_embeddings, mask_labels, mask_mask, mask_seq_len, bs, device
        )

        # Dimension checks
        assert box_embeddings is not None and list(box_embeddings.shape[:2]) == [
            box_seq_len,
            bs,
        ], f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
        assert box_mask is not None and list(box_mask.shape) == [
            bs,
            box_seq_len,
        ], f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
        assert point_embeddings is not None and list(point_embeddings.shape[:2]) == [
            point_seq_len,
            bs,
        ], f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
        assert point_mask is not None and list(point_mask.shape) == [
            bs,
            point_seq_len,
        ], f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
        assert box_labels is not None and list(box_labels.shape) == [
            box_seq_len,
            bs,
        ], f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
        assert point_labels is not None and list(point_labels.shape) == [
            point_seq_len,
            bs,
        ], f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
        assert (
            # Allowed to be None, we leave it to the encoder to check for validity before encoding.
            mask_embeddings is None
            or list(mask_embeddings.shape[:2]) == [mask_seq_len, bs]
        ), f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
        assert mask_mask is None or list(mask_mask.shape) == [
            bs,
            mask_seq_len,
        ], f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"

        # Device checks
        assert (
            box_embeddings is not None and box_embeddings.device == device
        ), f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
        assert (
            box_mask is not None and box_mask.device == device
        ), f"Expected box mask to be on device {device}, got {box_mask.device}"
        assert (
            box_labels is not None and box_labels.device == device
        ), f"Expected box labels to be on device {device}, got {box_labels.device}"
        assert (
            point_embeddings is not None and point_embeddings.device == device
        ), f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
        assert (
            point_mask is not None and point_mask.device == device
        ), f"Expected point mask to be on device {device}, got {point_mask.device}"
        assert (
            point_labels is not None and point_labels.device == device
        ), f"Expected point labels to be on device {device}, got {point_labels.device}"
        assert (
            mask_embeddings is None or mask_embeddings.device == device
        ), f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
        assert (
            mask_mask is None or mask_mask.device == device
        ), f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"

        self.box_embeddings = box_embeddings
        self.point_embeddings = point_embeddings
        self.box_mask = box_mask
        self.point_mask = point_mask
        self.box_labels = box_labels
        self.point_labels = point_labels
        self.mask_embeddings = mask_embeddings
        self.mask_labels = mask_labels
        self.mask_mask = mask_mask

    def _init_seq_len_and_device(
        self, box_embeddings, point_embeddings, mask_embeddings
    ):
        """Infer sequence lengths, batch size and device from the given embeddings.

        The batch size and device are taken from the first non-None embedding
        (boxes, then points, then masks); the remaining ones are asserted to agree.
        A missing embedding contributes a sequence length of 0.

        :return: (box_seq_len, point_seq_len, mask_seq_len, bs, device)
        """
        box_seq_len = point_seq_len = mask_seq_len = 0
        bs = None
        device = None

        if box_embeddings is not None:
            bs = box_embeddings.shape[1]
            box_seq_len = box_embeddings.shape[0]
            device = box_embeddings.device

        if point_embeddings is not None:
            point_seq_len = point_embeddings.shape[0]
            if bs is not None:
                assert (
                    bs == point_embeddings.shape[1]
                ), f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
            else:
                bs = point_embeddings.shape[1]
            if device is not None:
                assert (
                    device == point_embeddings.device
                ), "Device mismatch between box and point embeddings"
            else:
                device = point_embeddings.device

        if mask_embeddings is not None:
            mask_seq_len = mask_embeddings.shape[0]
            if bs is not None:
                assert (
                    bs == mask_embeddings.shape[1]
                ), f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
            else:
                bs = mask_embeddings.shape[1]
            if device is not None:
                assert (
                    device == mask_embeddings.device
                ), "Device mismatch between box/point and mask embeddings."
            else:
                device = mask_embeddings.device

        return box_seq_len, point_seq_len, mask_seq_len, bs, device

    def _init_box(self, box_embeddings, box_labels, box_mask, box_seq_len, bs, device):
        """Fill in defaults for the box prompt.

        Missing embeddings become zeros of shape (box_seq_len, bs, 4), missing labels
        become ones (positive) and a missing mask becomes all-False (nothing padded).
        """
        if box_embeddings is None:
            box_embeddings = torch.zeros(box_seq_len, bs, 4, device=device)
        if box_labels is None:
            box_labels = torch.ones(box_seq_len, bs, device=device, dtype=torch.long)
        if box_mask is None:
            box_mask = torch.zeros(bs, box_seq_len, device=device, dtype=torch.bool)
        return box_embeddings, box_labels, box_mask

    def _init_point(
        self, point_embeddings, point_labels, point_mask, point_seq_len, bs, device
    ):
        """
        Identical to _init_box. Except that C=2 for points (vs. 4 for boxes).
        """
        if point_embeddings is None:
            point_embeddings = torch.zeros(point_seq_len, bs, 2, device=device)
        if point_labels is None:
            point_labels = torch.ones(
                point_seq_len, bs, device=device, dtype=torch.long
            )
        if point_mask is None:
            point_mask = torch.zeros(bs, point_seq_len, device=device, dtype=torch.bool)
        return point_embeddings, point_labels, point_mask

    def _init_mask(
        self, mask_embeddings, mask_labels, mask_mask, mask_seq_len, bs, device
    ):
        """Fill in default labels and attention mask for the mask prompt.

        The mask embeddings themselves are returned untouched (possibly None).
        """
        # NOTE: Mask embeddings can be of arbitrary resolution, so we don't initialize it here.
        # In case we append new mask, we check that its resolution matches existing ones (if any).
        # In case mask_embeddings is None, we should never encode it.
        if mask_labels is None:
            mask_labels = torch.ones(mask_seq_len, bs, device=device, dtype=torch.long)
        if mask_mask is None:
            mask_mask = torch.zeros(bs, mask_seq_len, device=device, dtype=torch.bool)
        return mask_embeddings, mask_labels, mask_mask

    def append_boxes(self, boxes, labels, mask=None):
        """Append boxes (N x B x C_box) and their labels (N x B) to the prompt.

        :param mask: optional padding mask of shape B x N; defaults to "nothing padded".
        """
        if self.box_embeddings is None:
            if mask is None:
                # Always store a concrete mask: a later append concatenates against
                # it and would crash on None.
                mask = torch.zeros(
                    boxes.shape[1], boxes.shape[0], dtype=torch.bool, device=boxes.device
                )
            self.box_embeddings = boxes
            self.box_labels = labels
            self.box_mask = mask
            return

        bs = self.box_embeddings.shape[1]
        assert boxes.shape[1] == labels.shape[1] == bs
        assert list(boxes.shape[:2]) == list(labels.shape[:2])
        if mask is None:
            mask = torch.zeros(
                bs, boxes.shape[0], dtype=torch.bool, device=boxes.device
            )

        # Labels are concatenated first: both calls must see the *old* box mask.
        self.box_labels, _ = concat_padded_sequences(
            self.box_labels.unsqueeze(-1), self.box_mask, labels.unsqueeze(-1), mask
        )
        self.box_labels = self.box_labels.squeeze(-1)
        self.box_embeddings, self.box_mask = concat_padded_sequences(
            self.box_embeddings, self.box_mask, boxes, mask
        )

    def append_points(self, points, labels, mask=None):
        """Append points (N x B x C_point) and their labels (N x B) to the prompt.

        :param mask: optional padding mask of shape B x N; defaults to "nothing padded".
        """
        if self.point_embeddings is None:
            if mask is None:
                # Always store a concrete mask: a later append concatenates against
                # it and would crash on None.
                mask = torch.zeros(
                    points.shape[1],
                    points.shape[0],
                    dtype=torch.bool,
                    device=points.device,
                )
            self.point_embeddings = points
            self.point_labels = labels
            self.point_mask = mask
            return

        bs = self.point_embeddings.shape[1]
        assert points.shape[1] == labels.shape[1] == bs
        assert list(points.shape[:2]) == list(labels.shape[:2])
        if mask is None:
            mask = torch.zeros(
                bs, points.shape[0], dtype=torch.bool, device=points.device
            )

        # Labels are concatenated first: both calls must see the *old* point mask.
        self.point_labels, _ = concat_padded_sequences(
            self.point_labels.unsqueeze(-1), self.point_mask, labels.unsqueeze(-1), mask
        )
        self.point_labels = self.point_labels.squeeze(-1)
        self.point_embeddings, self.point_mask = concat_padded_sequences(
            self.point_embeddings, self.point_mask, points, mask
        )

    def append_masks(self, masks, labels=None, attn_mask=None):
        """Set the mask prompt. Only a single call per prompt is supported.

        :param masks: mask embeddings, sequence first (N_masks x B x ...).
        :param labels: optional labels of shape N_masks x B; defaults to all ones.
        :param attn_mask: optional attention mask of shape B x N_masks; defaults to
            all False.
        :raises NotImplementedError: if the prompt already holds mask embeddings.
        """
        if labels is not None:
            assert list(masks.shape[:2]) == list(labels.shape[:2])
        if self.mask_embeddings is None:
            self.mask_embeddings = masks
            mask_seq_len, bs = masks.shape[:2]
            if labels is None:
                self.mask_labels = torch.ones(
                    mask_seq_len, bs, device=masks.device, dtype=torch.long
                )
            else:
                self.mask_labels = labels
            if attn_mask is None:
                self.mask_mask = torch.zeros(
                    bs, mask_seq_len, device=masks.device, dtype=torch.bool
                )
            else:
                self.mask_mask = attn_mask
        else:
            raise NotImplementedError("Only one mask per prompt is supported.")

    def clone(self):
        """Return a deep copy of the prompt: every stored tensor is cloned.

        This includes the mask prompt (embeddings, attention mask and labels), so
        the copy is a faithful, independent duplicate of this prompt.
        """

        def _copy(tensor):
            return None if tensor is None else tensor.clone()

        return Prompt(
            box_embeddings=_copy(self.box_embeddings),
            box_mask=_copy(self.box_mask),
            point_embeddings=_copy(self.point_embeddings),
            point_mask=_copy(self.point_mask),
            box_labels=_copy(self.box_labels),
            point_labels=_copy(self.point_labels),
            mask_embeddings=_copy(self.mask_embeddings),
            mask_mask=_copy(self.mask_mask),
            mask_labels=_copy(self.mask_labels),
        )
class MaskEncoder(nn.Module):
    """
    Base class for mask encoders: downsample the masks, then compute a position
    encoding for the downsampled result.
    """

    def __init__(
        self,
        mask_downsampler: nn.Module,
        position_encoding: nn.Module,
    ):
        """
        :param mask_downsampler: module applied to the raw masks.
        :param position_encoding: module applied to the downsampled masks.
        """
        super().__init__()
        self.mask_downsampler = mask_downsampler
        self.position_encoding = position_encoding

    def forward(self, masks, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(downsampled_masks, position_encoding)``.

        Extra positional/keyword arguments are accepted and ignored, so that
        subclasses taking more inputs can be called the same way. The position
        encoding is cast to the dtype of the downsampled masks.
        """
        downsampled = self.mask_downsampler(masks)
        pos = self.position_encoding(downsampled)
        return downsampled, pos.to(downsampled.dtype)
class FusedMaskEncoder(MaskEncoder):
    """
    Identical to memory.SimpleMaskEncoder but follows the interface of geometry_encoders.MaskEncoder.
    We also remove the `skip_mask_sigmoid` option (to be handled outside the MaskEncoder).
    Fuses backbone image features with mask features.
    """

    def __init__(
        self,
        mask_downsampler: nn.Module,
        position_encoding: nn.Module,
        fuser: nn.Module,
        in_dim: int = 256,
        out_dim: int = 256,
    ):
        """
        :param mask_downsampler: module applied to the raw masks.
        :param position_encoding: module applied to the fused features.
        :param fuser: module applied to the sum of projected image features and
            downsampled masks.
        :param in_dim: channel count of the image features.
        :param out_dim: channel count of the output; a 1x1 convolution is added
            only when it differs from `in_dim`.
        """
        super().__init__(mask_downsampler, position_encoding)
        self.fuser = fuser
        needs_projection = out_dim != in_dim
        self.out_proj = (
            nn.Conv2d(in_dim, out_dim, kernel_size=1)
            if needs_projection
            else nn.Identity()
        )
        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)

    @override
    def forward(
        self,
        masks: torch.Tensor,
        pix_feat: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(fused_features, position_encoding)``.

        The image features are projected, summed with the downsampled masks,
        passed through the fuser and the output projection. The position encoding
        is computed on the result and cast to its dtype.
        """
        downsampled = self.mask_downsampler(masks)

        # The visual features may live on another device (e.g. CPU): bring them
        # to the device of the downsampled masks before fusing.
        pix_feat = pix_feat.to(downsampled.device)

        fused = self.pix_feat_proj(pix_feat) + downsampled
        fused = self.out_proj(self.fuser(fused))

        pos = self.position_encoding(fused).to(fused.dtype)
        return fused, pos
class SequenceGeometryEncoder(nn.Module):
    """
    This is a fully fledged encoder for geometric prompts.
    It assumes boxes are passed in the "normalized CxCyWH" format, and points in normalized xy
    This allows flexibility in how to encode the features (eg do pooling)

    Points and boxes can be encoded with any of the three possibilities:
     - direct projection: we just compute a linear from coordinate space to d_model
     - pooling: pool features from the backbone in the requested location.
                For boxes, it's a roi align
                For points it's a grid sample
     - pos encoder: Take the position encoding of the point or box center

    These three options are mutually compatible. If several are selected, we'll take a simple addition

    As an alternative, we offer the possibility to encode points only.
    In that case, the boxes are converted to two points for the top left and bottom right corners (with appropriate labels)

    On top of these encodings, we offer the possibility to further encode the prompt sequence with a transformer.
    """

    def __init__(
        self,
        encode_boxes_as_points: bool,
        points_direct_project: bool,
        points_pool: bool,
        points_pos_enc: bool,
        boxes_direct_project: bool,
        boxes_pool: bool,
        boxes_pos_enc: bool,
        d_model: int,
        pos_enc,
        num_layers: int,
        layer: nn.Module,
        roi_size: int = 7,  # for boxes pool
        add_cls: bool = True,
        add_post_encode_proj: bool = True,
        mask_encoder: "MaskEncoder" = None,
        add_mask_label: bool = False,
        use_act_ckpt: bool = False,
    ):
        """
        :param encode_boxes_as_points: encode each box as its two corner points; the
            box-specific projections are then not created.
        :param points_direct_project / points_pool / points_pos_enc: which point
            encodings to enable (at least one is required).
        :param boxes_direct_project / boxes_pool / boxes_pos_enc: which box encodings
            to enable (at least one is required unless boxes are encoded as points).
        :param d_model: embedding dimension.
        :param pos_enc: position encoder; must provide `_encode_xy` when
            `points_pos_enc` is set and `encode_boxes` when `boxes_pos_enc` is set.
        :param num_layers: number of transformer layers applied to the prompt
            sequence (0 disables it; a CLS token is required otherwise).
        :param layer: layer that is cloned `num_layers` times.
        :param roi_size: output size of the roi align used for box pooling.
        :param add_cls: append a learned CLS token to the sequence.
        :param add_post_encode_proj: apply a linear projection + LayerNorm after
            the sequence is assembled.
        :param mask_encoder: optional MaskEncoder (or None) used to encode mask prompts.
        :param add_mask_label: add a learned label embedding to the mask tokens.
        :param use_act_ckpt: enable activation checkpointing of the transformer
            layers while training.
        """
        super().__init__()
        self.d_model = d_model
        self.pos_enc = pos_enc
        self.encode_boxes_as_points = encode_boxes_as_points
        self.roi_size = roi_size

        # There usually are two labels: positive and negatives.
        # If we encode boxes as points, we have 3 types of points: regular, top left, bottom right
        # These 3 types can be positives or negatives, hence 2*3 = 6 labels
        num_labels = 6 if self.encode_boxes_as_points else 2
        self.label_embed = torch.nn.Embedding(num_labels, self.d_model)

        # This is a cls token, can be used for pooling if need be.
        # It also ensures that the encoded sequences are always non-empty
        self.cls_embed = None
        if add_cls:
            self.cls_embed = torch.nn.Embedding(1, self.d_model)

        assert (
            points_direct_project or points_pos_enc or points_pool
        ), "Error: need at least one way to encode points"
        assert (
            encode_boxes_as_points
            or boxes_direct_project
            or boxes_pos_enc
            or boxes_pool
        ), "Error: need at least one way to encode boxes"

        self.points_direct_project = None
        if points_direct_project:
            self.points_direct_project = nn.Linear(2, self.d_model)
        self.points_pool_project = None
        if points_pool:
            self.points_pool_project = nn.Linear(self.d_model, self.d_model)
        self.points_pos_enc_project = None
        if points_pos_enc:
            self.points_pos_enc_project = nn.Linear(self.d_model, self.d_model)

        self.boxes_direct_project = None
        self.boxes_pool_project = None
        self.boxes_pos_enc_project = None
        if not encode_boxes_as_points:
            if boxes_direct_project:
                self.boxes_direct_project = nn.Linear(4, self.d_model)
            if boxes_pool:
                # A roi_size x roi_size convolution over a roi_size x roi_size
                # crop reduces each pooled box to a single d_model vector.
                self.boxes_pool_project = nn.Conv2d(
                    self.d_model, self.d_model, self.roi_size
                )
            if boxes_pos_enc:
                # presumably pos_enc.encode_boxes returns d_model + 2 features
                # per box -- TODO confirm against the position encoder.
                self.boxes_pos_enc_project = nn.Linear(self.d_model + 2, self.d_model)

        self.final_proj = None
        if add_post_encode_proj:
            self.final_proj = nn.Linear(self.d_model, self.d_model)
            self.norm = nn.LayerNorm(self.d_model)

        self.img_pre_norm = nn.Identity()
        if self.points_pool_project is not None or self.boxes_pool_project is not None:
            self.img_pre_norm = nn.LayerNorm(self.d_model)

        self.encode = None
        if num_layers > 0:
            assert (
                add_cls
            ), "It's currently highly recommended to add a CLS when using a transformer"
            self.encode = get_clones(layer, num_layers)
            self.encode_norm = nn.LayerNorm(self.d_model)

        if mask_encoder is not None:
            assert isinstance(
                mask_encoder, MaskEncoder
            ), f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
            if add_mask_label:
                self.mask_label_embed = torch.nn.Embedding(2, self.d_model)
        self.add_mask_label = add_mask_label
        self.mask_encoder = mask_encoder
        self.use_act_ckpt = use_act_ckpt

    def _encode_points(self, points, points_mask, points_labels, img_feats):
        """Encode points of shape [n_points, bs, 2] into [n_points, bs, d_model].

        The enabled encodings (direct projection, pooled image features, position
        encoding) are summed and the label embedding is added on top.

        :param img_feats: image features in [bs, d_model, H, W] layout; only read
            when point pooling is enabled.
        :return: (embeddings, points_mask); the mask is returned unchanged.
        """
        points_embed = None
        n_points, bs = points.shape[:2]

        if self.points_direct_project is not None:
            proj = self.points_direct_project(points)
            assert points_embed is None
            points_embed = proj

        if self.points_pool_project is not None:
            # points are [Num_points, bs, 2], normalized in [0, 1]
            # the grid needs to be [Bs, H_out, W_out, 2] normalized in [-1,1]
            # Will take H_out = num_points, w_out = 1
            grid = points.transpose(0, 1).unsqueeze(2)
            # re normalize to [-1, 1]
            grid = (grid * 2) - 1
            sampled = torch.nn.functional.grid_sample(
                img_feats, grid, align_corners=False
            )
            assert list(sampled.shape) == [bs, self.d_model, n_points, 1]
            # [bs, C, n_points, 1] -> [n_points, bs, C]
            sampled = sampled.squeeze(-1).permute(2, 0, 1)
            proj = self.points_pool_project(sampled)
            if points_embed is None:
                points_embed = proj
            else:
                points_embed = points_embed + proj

        if self.points_pos_enc_project is not None:
            x, y = points.unbind(-1)
            enc_x, enc_y = self.pos_enc._encode_xy(x.flatten(), y.flatten())
            enc_x = enc_x.view(n_points, bs, enc_x.shape[-1])
            enc_y = enc_y.view(n_points, bs, enc_y.shape[-1])
            enc = torch.cat([enc_x, enc_y], -1)

            proj = self.points_pos_enc_project(enc)
            if points_embed is None:
                points_embed = proj
            else:
                points_embed = points_embed + proj

        type_embed = self.label_embed(points_labels.long())
        return type_embed + points_embed, points_mask

    def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats):
        """Encode boxes of shape [n_boxes, bs, 4] into [n_boxes, bs, d_model].

        The enabled encodings (direct projection, roi-aligned image features,
        position encoding) are summed and the label embedding is added on top.

        :param img_feats: image features in [bs, d_model, H, W] layout; only read
            when box pooling is enabled.
        :return: (embeddings, boxes_mask); the mask is returned unchanged.
        """
        boxes_embed = None
        n_boxes, bs = boxes.shape[:2]

        if self.boxes_direct_project is not None:
            proj = self.boxes_direct_project(boxes)
            assert boxes_embed is None
            boxes_embed = proj

        if self.boxes_pool_project is not None:
            H, W = img_feats.shape[-2:]

            # boxes are [Num_boxes, bs, 4], normalized in [0, 1]
            # We need to denormalize, and convert to [x, y, x, y]
            boxes_xyxy = box_cxcywh_to_xyxy(boxes)
            scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
            if boxes_xyxy.is_cuda:
                # Pinned memory makes the host-to-device copy below truly
                # asynchronous. Pinning is only possible (and only useful) with
                # an accelerator: on a CPU-only host `pin_memory()` raises.
                scale = scale.pin_memory()
            scale = scale.to(device=boxes_xyxy.device, non_blocking=True)
            scale = scale.view(1, 1, 4)
            boxes_xyxy = boxes_xyxy * scale
            # roi_align takes one [n_boxes, 4] tensor per batch element.
            sampled = torchvision.ops.roi_align(
                img_feats, boxes_xyxy.float().transpose(0, 1).unbind(0), self.roi_size
            )
            assert list(sampled.shape) == [
                bs * n_boxes,
                self.d_model,
                self.roi_size,
                self.roi_size,
            ]
            proj = self.boxes_pool_project(sampled)
            # Output is batch-major: [bs * n_boxes, C, 1, 1] -> [n_boxes, bs, C]
            proj = proj.view(bs, n_boxes, self.d_model).transpose(0, 1)
            if boxes_embed is None:
                boxes_embed = proj
            else:
                boxes_embed = boxes_embed + proj

        if self.boxes_pos_enc_project is not None:
            cx, cy, w, h = boxes.unbind(-1)
            enc = self.pos_enc.encode_boxes(
                cx.flatten(), cy.flatten(), w.flatten(), h.flatten()
            )
            enc = enc.view(boxes.shape[0], boxes.shape[1], enc.shape[-1])

            proj = self.boxes_pos_enc_project(enc)
            if boxes_embed is None:
                boxes_embed = proj
            else:
                boxes_embed = boxes_embed + proj

        type_embed = self.label_embed(boxes_labels.long())
        return type_embed + boxes_embed, boxes_mask

    def _encode_masks(
        self,
        masks: torch.Tensor,
        attn_mask: torch.Tensor,
        mask_labels: torch.Tensor,
        img_feats: torch.Tensor = None,
    ):
        """Encode mask prompts into a token sequence.

        :param masks: mask prompts, [n_masks, bs, ...]; only n_masks == 1 is accepted.
        :param attn_mask: attention mask of shape [bs, n_masks].
        :param mask_labels: labels of shape [n_masks, bs]; only read when
            `add_mask_label` is set.
        :param img_feats: forwarded to the mask encoder as `pix_feat`.
        :return: (tokens, attn_mask) with tokens of shape [n_masks * H*W, bs, C] and
            the attention mask repeated to [bs, n_masks * H*W].
        """
        n_masks, bs = masks.shape[:2]
        assert (
            n_masks == 1
        ), "We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
        assert list(attn_mask.shape) == [
            bs,
            n_masks,
        ], f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."

        masks, pos = self.mask_encoder(
            masks=masks.flatten(0, 1).float(),
            pix_feat=img_feats,
        )
        H, W = masks.shape[-2:]
        n_tokens_per_mask = H * W

        # NOTE: We directly add pos enc here as we usually don't keep track of pos encoding for the concatenated prompt (text, other geometric prompts). Might need to do some refactoring for more flexibility.
        masks = masks + pos
        masks = masks.view(n_masks, bs, *masks.shape[1:]).flatten(
            -2
        )  # n_masks x bs x C x H*W
        masks = masks.permute(0, 3, 1, 2).flatten(0, 1)  # n_masks * H*W x bs x C
        attn_mask = attn_mask.repeat_interleave(n_tokens_per_mask, dim=1)

        if self.add_mask_label:
            masks = masks + self.mask_label_embed(mask_labels.long())

        return masks, attn_mask

    def forward(self, geo_prompt: "Prompt", img_feats, img_sizes, img_pos_embeds=None):
        """Encode a geometric prompt into a right-padded token sequence.

        :param geo_prompt: the prompt; its point and box fields must be populated
            (possibly with zero-length tensors).
        :param img_feats: sequence of image features; only the last entry is used,
            in [H*W, B, C] layout.
        :param img_sizes: sequence of (H, W) sizes matching `img_feats`; only read
            when pooling is enabled.
        :param img_pos_embeds: optional position embeddings matching `img_feats`;
            zeros are used if None.
        :return: (final_embeds, final_mask): sequence-first embeddings and the
            batch-first padding mask. If only mask prompts are present, the encoded
            mask tokens are returned directly.
        """
        points = geo_prompt.point_embeddings
        points_mask = geo_prompt.point_mask
        points_labels = geo_prompt.point_labels
        boxes = geo_prompt.box_embeddings
        boxes_mask = geo_prompt.box_mask
        boxes_labels = geo_prompt.box_labels
        masks = geo_prompt.mask_embeddings
        masks_mask = geo_prompt.mask_mask
        masks_labels = geo_prompt.mask_labels

        seq_first_img_feats = img_feats[-1]  # [H*W, B, C]
        seq_first_img_pos_embeds = (
            img_pos_embeds[-1]
            if img_pos_embeds is not None
            else torch.zeros_like(seq_first_img_feats)
        )

        # Explicit None checks: module truthiness depends on `__len__` for
        # container modules and must not decide whether pooling is enabled.
        if self.points_pool_project is not None or self.boxes_pool_project is not None:
            assert len(img_feats) == len(img_sizes)
            cur_img_feat = img_feats[-1]
            cur_img_feat = self.img_pre_norm(cur_img_feat)
            H, W = img_sizes[-1]
            assert cur_img_feat.shape[0] == H * W
            N, C = cur_img_feat.shape[-2:]
            # Put back in NxCxHxW
            cur_img_feat = cur_img_feat.permute(1, 2, 0)
            cur_img_feat = cur_img_feat.view(N, C, H, W)
            img_feats = cur_img_feat

        if self.encode_boxes_as_points:
            assert boxes is not None
            assert geo_prompt.box_mask is not None
            assert geo_prompt.box_labels is not None
            assert boxes.shape[-1] == 4

            boxes_xyxy = box_cxcywh_to_xyxy(boxes)
            top_left, bottom_right = boxes_xyxy.split(split_size=2, dim=-1)

            # Shift the 0/1 labels into the top-left (2, 3) and bottom-right (4, 5) ranges.
            labels_tl = geo_prompt.box_labels + 2
            labels_br = geo_prompt.box_labels + 4

            # Append to the existing points. The points and their labels must be
            # concatenated against the same (old) mask, so the mask is only
            # updated by the second call of each pair.
            points, _ = concat_padded_sequences(
                points, points_mask, top_left, boxes_mask
            )
            points_labels, points_mask = concat_padded_sequences(
                points_labels.unsqueeze(-1),
                points_mask,
                labels_tl.unsqueeze(-1),
                boxes_mask,
            )
            points_labels = points_labels.squeeze(-1)

            points, _ = concat_padded_sequences(
                points, points_mask, bottom_right, boxes_mask
            )
            points_labels, points_mask = concat_padded_sequences(
                points_labels.unsqueeze(-1),
                points_mask,
                labels_br.unsqueeze(-1),
                boxes_mask,
            )
            points_labels = points_labels.squeeze(-1)

        final_embeds, final_mask = self._encode_points(
            points=points,
            points_mask=points_mask,
            points_labels=points_labels,
            img_feats=img_feats,
        )

        if not self.encode_boxes_as_points:
            boxes_embeds, boxes_mask = self._encode_boxes(
                boxes=boxes,
                boxes_mask=boxes_mask,
                boxes_labels=boxes_labels,
                img_feats=img_feats,
            )
            final_embeds, final_mask = concat_padded_sequences(
                final_embeds, final_mask, boxes_embeds, boxes_mask
            )

        if masks is not None and self.mask_encoder is not None:
            # NOTE(review): when pooling is disabled, `img_feats` is still the
            # caller's sequence of features rather than an NxCxHxW tensor --
            # confirm that the configured mask encoder accepts that.
            masks_embed, masks_mask = self._encode_masks(
                masks=masks,
                attn_mask=masks_mask,
                mask_labels=masks_labels,
                img_feats=img_feats,
            )
            if points.size(0) == boxes.size(0) == 0:
                # Mask-only prompt: skip the CLS token / projection / transformer.
                return masks_embed, masks_mask

        bs = final_embeds.shape[1]
        assert final_mask.shape[0] == bs
        if self.cls_embed is not None:
            cls = self.cls_embed.weight.view(1, 1, self.d_model).repeat(1, bs, 1)
            cls_mask = torch.zeros(
                bs, 1, dtype=final_mask.dtype, device=final_mask.device
            )
            final_embeds, final_mask = concat_padded_sequences(
                final_embeds, final_mask, cls, cls_mask
            )

        if self.final_proj is not None:
            final_embeds = self.norm(self.final_proj(final_embeds))

        if self.encode is not None:
            for lay in self.encode:
                final_embeds = activation_ckpt_wrapper(lay)(
                    tgt=final_embeds,
                    memory=seq_first_img_feats,
                    tgt_key_padding_mask=final_mask,
                    pos=seq_first_img_pos_embeds,
                    act_ckpt_enable=self.training and self.use_act_ckpt,
                )
            final_embeds = self.encode_norm(final_embeds)

        # Finally, concat mask embeddings if any
        if masks is not None and self.mask_encoder is not None:
            final_embeds, final_mask = concat_padded_sequences(
                final_embeds, final_mask, masks_embed, masks_mask
            )
        return final_embeds, final_mask