Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
956
sam3/model/decoder.py
Normal file
956
sam3/model/decoder.py
Normal file
@@ -0,0 +1,956 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
"""
|
||||
Transformer decoder.
|
||||
Inspired from Pytorch's version, adds the pre-norm variant
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.sam.transformer import RoPEAttention
|
||||
|
||||
from torch import nn, Tensor
|
||||
from torchvision.ops.roi_align import RoIAlign
|
||||
|
||||
from .act_ckpt_utils import activation_ckpt_wrapper
|
||||
|
||||
from .box_ops import box_cxcywh_to_xyxy
|
||||
|
||||
from .model_misc import (
|
||||
gen_sineembed_for_position,
|
||||
get_activation_fn,
|
||||
get_clones,
|
||||
inverse_sigmoid,
|
||||
MLP,
|
||||
)
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
activation: str,
|
||||
d_model: int,
|
||||
dim_feedforward: int,
|
||||
dropout: float,
|
||||
cross_attention: nn.Module,
|
||||
n_heads: int,
|
||||
use_text_cross_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# cross attention
|
||||
self.cross_attn = cross_attention
|
||||
self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
|
||||
# cross attention text
|
||||
self.use_text_cross_attention = use_text_cross_attention
|
||||
if use_text_cross_attention:
|
||||
self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
|
||||
self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
||||
self.catext_norm = nn.LayerNorm(d_model)
|
||||
|
||||
# self attention
|
||||
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
|
||||
self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
|
||||
# ffn
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.activation = get_activation_fn(activation)
|
||||
self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
|
||||
@staticmethod
|
||||
def with_pos_embed(tensor, pos):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_ffn(self, tgt):
|
||||
with torch.amp.autocast(device_type="cuda", enabled=False):
|
||||
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout4(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
# for tgt
|
||||
tgt: Optional[Tensor], # nq, bs, d_model
|
||||
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
|
||||
tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
|
||||
memory_text: Optional[Tensor] = None, # num_token, bs, d_model
|
||||
text_attention_mask: Optional[Tensor] = None, # bs, num_token
|
||||
# for memory
|
||||
memory: Optional[Tensor] = None, # hw, bs, d_model
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_level_start_index: Optional[Tensor] = None, # num_levels
|
||||
memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
|
||||
memory_pos: Optional[Tensor] = None, # pos for memory
|
||||
# sa
|
||||
self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
|
||||
cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
|
||||
# dac
|
||||
dac=False,
|
||||
dac_use_selfatt_ln=True,
|
||||
presence_token=None,
|
||||
# skip inside deformable attn
|
||||
identity=0.0,
|
||||
**kwargs, # additional kwargs for compatibility
|
||||
):
|
||||
"""
|
||||
Input:
|
||||
- tgt/tgt_query_pos: nq, bs, d_model
|
||||
-
|
||||
"""
|
||||
# self attention
|
||||
if self.self_attn is not None:
|
||||
if dac:
|
||||
# we only apply self attention to the first half of the queries
|
||||
assert tgt.shape[0] % 2 == 0
|
||||
num_o2o_queries = tgt.shape[0] // 2
|
||||
tgt_o2o = tgt[:num_o2o_queries]
|
||||
tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries]
|
||||
tgt_o2m = tgt[num_o2o_queries:]
|
||||
else:
|
||||
tgt_o2o = tgt
|
||||
tgt_query_pos_o2o = tgt_query_pos
|
||||
|
||||
if presence_token is not None:
|
||||
tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0)
|
||||
tgt_query_pos_o2o = torch.cat(
|
||||
[torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0
|
||||
)
|
||||
tgt_query_pos = torch.cat(
|
||||
[torch.zeros_like(presence_token), tgt_query_pos], dim=0
|
||||
)
|
||||
|
||||
q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o)
|
||||
tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0]
|
||||
tgt_o2o = tgt_o2o + self.dropout2(tgt2)
|
||||
if dac:
|
||||
if not dac_use_selfatt_ln:
|
||||
tgt_o2o = self.norm2(tgt_o2o)
|
||||
tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0) # Recombine
|
||||
if dac_use_selfatt_ln:
|
||||
tgt = self.norm2(tgt)
|
||||
else:
|
||||
tgt = tgt_o2o
|
||||
tgt = self.norm2(tgt)
|
||||
|
||||
if self.use_text_cross_attention:
|
||||
tgt2 = self.ca_text(
|
||||
self.with_pos_embed(tgt, tgt_query_pos),
|
||||
memory_text,
|
||||
memory_text,
|
||||
key_padding_mask=text_attention_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.catext_dropout(tgt2)
|
||||
tgt = self.catext_norm(tgt)
|
||||
|
||||
if presence_token is not None:
|
||||
presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
|
||||
cross_attn_mask = torch.cat(
|
||||
[presence_token_mask, cross_attn_mask], dim=1
|
||||
) # (bs*nheads, 1+nq, hw)
|
||||
|
||||
# Cross attention to image
|
||||
tgt2 = self.cross_attn(
|
||||
query=self.with_pos_embed(tgt, tgt_query_pos),
|
||||
key=self.with_pos_embed(memory, memory_pos),
|
||||
value=memory,
|
||||
attn_mask=cross_attn_mask,
|
||||
key_padding_mask=(
|
||||
memory_key_padding_mask.transpose(0, 1)
|
||||
if memory_key_padding_mask is not None
|
||||
else None
|
||||
),
|
||||
)[0]
|
||||
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
|
||||
# ffn
|
||||
tgt = self.forward_ffn(tgt)
|
||||
|
||||
presence_token_out = None
|
||||
if presence_token is not None:
|
||||
presence_token_out = tgt[:1]
|
||||
tgt = tgt[1:]
|
||||
|
||||
return tgt, presence_token_out
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
frozen: bool,
|
||||
interaction_layer,
|
||||
layer,
|
||||
num_layers: int,
|
||||
num_queries: int,
|
||||
return_intermediate: bool,
|
||||
box_refine: bool = False,
|
||||
num_o2m_queries: int = 0,
|
||||
dac: bool = False,
|
||||
boxRPB: str = "none",
|
||||
# Experimental: An object query for SAM 2 tasks
|
||||
instance_query: bool = False,
|
||||
# Defines the number of additional instance queries,
|
||||
# 1 or 4 are the most likely for single vs multi mask support
|
||||
num_instances: int = 1, # Irrelevant if instance_query is False
|
||||
dac_use_selfatt_ln: bool = True,
|
||||
use_act_checkpoint: bool = False,
|
||||
compile_mode=None,
|
||||
presence_token: bool = False,
|
||||
clamp_presence_logits: bool = True,
|
||||
clamp_presence_logit_max_val: float = 10.0,
|
||||
use_normed_output_consistently: bool = True,
|
||||
separate_box_head_instance: bool = False,
|
||||
separate_norm_instance: bool = False,
|
||||
resolution: Optional[int] = None,
|
||||
stride: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.layers = get_clones(layer, num_layers)
|
||||
self.fine_layers = (
|
||||
get_clones(interaction_layer, num_layers)
|
||||
if interaction_layer is not None
|
||||
else [None] * num_layers
|
||||
)
|
||||
self.num_layers = num_layers
|
||||
self.num_queries = num_queries
|
||||
self.dac = dac
|
||||
if dac:
|
||||
self.num_o2m_queries = num_queries
|
||||
tot_num_queries = num_queries
|
||||
else:
|
||||
self.num_o2m_queries = num_o2m_queries
|
||||
tot_num_queries = num_queries + num_o2m_queries
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
self.return_intermediate = return_intermediate
|
||||
self.bbox_embed = MLP(d_model, d_model, 4, 3)
|
||||
self.query_embed = nn.Embedding(tot_num_queries, d_model)
|
||||
self.instance_query_embed = None
|
||||
self.instance_query_reference_points = None
|
||||
self.use_instance_query = instance_query
|
||||
self.num_instances = num_instances
|
||||
self.use_normed_output_consistently = use_normed_output_consistently
|
||||
|
||||
self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None
|
||||
self.instance_bbox_embed = None
|
||||
if separate_box_head_instance:
|
||||
self.instance_bbox_embed = MLP(d_model, d_model, 4, 3)
|
||||
if instance_query:
|
||||
self.instance_query_embed = nn.Embedding(num_instances, d_model)
|
||||
self.box_refine = box_refine
|
||||
if box_refine:
|
||||
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
|
||||
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
|
||||
|
||||
self.reference_points = nn.Embedding(num_queries, 4)
|
||||
if instance_query:
|
||||
self.instance_reference_points = nn.Embedding(num_instances, 4)
|
||||
|
||||
assert boxRPB in ["none", "log", "linear", "both"]
|
||||
self.boxRPB = boxRPB
|
||||
if boxRPB != "none":
|
||||
try:
|
||||
nheads = self.layers[0].cross_attn_image.num_heads
|
||||
except AttributeError:
|
||||
nheads = self.layers[0].cross_attn.num_heads
|
||||
|
||||
n_input = 4 if boxRPB == "both" else 2
|
||||
self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2)
|
||||
self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2)
|
||||
self.compilable_cord_cache = None
|
||||
self.compilable_stored_size = None
|
||||
self.coord_cache = {}
|
||||
|
||||
if resolution is not None and stride is not None:
|
||||
feat_size = resolution // stride
|
||||
coords_h, coords_w = self._get_coords(
|
||||
feat_size, feat_size, device="cuda"
|
||||
)
|
||||
self.compilable_cord_cache = (coords_h, coords_w)
|
||||
self.compilable_stored_size = (feat_size, feat_size)
|
||||
|
||||
self.roi_pooler = (
|
||||
RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True)
|
||||
if interaction_layer is not None
|
||||
else None
|
||||
)
|
||||
if frozen:
|
||||
for p in self.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
self.presence_token = None
|
||||
self.clamp_presence_logits = clamp_presence_logits
|
||||
self.clamp_presence_logit_max_val = clamp_presence_logit_max_val
|
||||
if presence_token:
|
||||
self.presence_token = nn.Embedding(1, d_model)
|
||||
self.presence_token_head = MLP(d_model, d_model, 1, 3)
|
||||
self.presence_token_out_norm = nn.LayerNorm(d_model)
|
||||
|
||||
self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2)
|
||||
self.dac_use_selfatt_ln = dac_use_selfatt_ln
|
||||
self.use_act_checkpoint = use_act_checkpoint
|
||||
|
||||
nn.init.normal_(self.query_embed.weight.data)
|
||||
if self.instance_query_embed is not None:
|
||||
nn.init.normal_(self.instance_query_embed.weight.data)
|
||||
|
||||
assert self.roi_pooler is None
|
||||
assert self.return_intermediate, "support return_intermediate only"
|
||||
assert self.box_refine, "support box refine only"
|
||||
|
||||
self.compile_mode = compile_mode
|
||||
self.compiled = False
|
||||
# We defer compilation till after the first forward, to first warm-up the boxRPB cache
|
||||
|
||||
# assign layer index to each layer so that some layers can decide what to do
|
||||
# based on which layer index they are (e.g. cross attention to memory bank only
|
||||
# in selected layers)
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
layer.layer_idx = layer_idx
|
||||
|
||||
@staticmethod
|
||||
def _get_coords(H, W, device):
|
||||
coords_h = torch.arange(0, H, device=device, dtype=torch.float32) / H
|
||||
coords_w = torch.arange(0, W, device=device, dtype=torch.float32) / W
|
||||
return coords_h, coords_w
|
||||
|
||||
def _get_rpb_matrix(self, reference_boxes, feat_size):
|
||||
H, W = feat_size
|
||||
boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes).transpose(0, 1)
|
||||
bs, num_queries, _ = boxes_xyxy.shape
|
||||
if self.compilable_cord_cache is None:
|
||||
self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device)
|
||||
self.compilable_stored_size = (H, W)
|
||||
|
||||
if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
|
||||
H,
|
||||
W,
|
||||
):
|
||||
# good, hitting the cache, will be compilable
|
||||
coords_h, coords_w = self.compilable_cord_cache
|
||||
else:
|
||||
# cache miss, will create compilation issue
|
||||
# In case we're not compiling, we'll still rely on the dict-based cache
|
||||
if feat_size not in self.coord_cache:
|
||||
self.coord_cache[feat_size] = self._get_coords(
|
||||
H, W, reference_boxes.device
|
||||
)
|
||||
coords_h, coords_w = self.coord_cache[feat_size]
|
||||
|
||||
assert coords_h.shape == (H,)
|
||||
assert coords_w.shape == (W,)
|
||||
|
||||
deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
|
||||
deltas_y = deltas_y.view(bs, num_queries, -1, 2)
|
||||
deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
|
||||
deltas_x = deltas_x.view(bs, num_queries, -1, 2)
|
||||
|
||||
if self.boxRPB in ["log", "both"]:
|
||||
deltas_x_log = deltas_x * 8 # normalize to -8, 8
|
||||
deltas_x_log = (
|
||||
torch.sign(deltas_x_log)
|
||||
* torch.log2(torch.abs(deltas_x_log) + 1.0)
|
||||
/ np.log2(8)
|
||||
)
|
||||
|
||||
deltas_y_log = deltas_y * 8 # normalize to -8, 8
|
||||
deltas_y_log = (
|
||||
torch.sign(deltas_y_log)
|
||||
* torch.log2(torch.abs(deltas_y_log) + 1.0)
|
||||
/ np.log2(8)
|
||||
)
|
||||
if self.boxRPB == "log":
|
||||
deltas_x = deltas_x_log
|
||||
deltas_y = deltas_y_log
|
||||
else:
|
||||
deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1)
|
||||
deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1)
|
||||
|
||||
if self.training:
|
||||
assert self.use_act_checkpoint, "activation ckpt not enabled in decoder"
|
||||
deltas_x = activation_ckpt_wrapper(self.boxRPB_embed_x)(
|
||||
x=deltas_x,
|
||||
act_ckpt_enable=self.training and self.use_act_checkpoint,
|
||||
) # bs, num_queries, W, n_heads
|
||||
deltas_y = activation_ckpt_wrapper(self.boxRPB_embed_y)(
|
||||
x=deltas_y,
|
||||
act_ckpt_enable=self.training and self.use_act_checkpoint,
|
||||
) # bs, num_queries, H, n_heads
|
||||
|
||||
if not torch.compiler.is_dynamo_compiling():
|
||||
assert deltas_x.shape[:3] == (bs, num_queries, W)
|
||||
assert deltas_y.shape[:3] == (bs, num_queries, H)
|
||||
|
||||
B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(
|
||||
2
|
||||
) # bs, num_queries, H, W, n_heads
|
||||
if not torch.compiler.is_dynamo_compiling():
|
||||
assert B.shape[:4] == (bs, num_queries, H, W)
|
||||
B = B.flatten(2, 3) # bs, num_queries, H*W, n_heads
|
||||
B = B.permute(0, 3, 1, 2) # bs, n_heads, num_queries, H*W
|
||||
B = B.contiguous() # memeff attn likes ordered strides
|
||||
if not torch.compiler.is_dynamo_compiling():
|
||||
assert B.shape[2:] == (num_queries, H * W)
|
||||
return B
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
reference_boxes: Optional[Tensor] = None, # num_queries, bs, 4
|
||||
# for memory
|
||||
level_start_index: Optional[Tensor] = None, # num_levels
|
||||
spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
|
||||
valid_ratios: Optional[Tensor] = None,
|
||||
# for text
|
||||
memory_text: Optional[Tensor] = None,
|
||||
text_attention_mask: Optional[Tensor] = None,
|
||||
# if `apply_dac` is None, it will default to `self.dac`
|
||||
apply_dac: Optional[bool] = None,
|
||||
is_instance_prompt=False,
|
||||
decoder_extra_kwargs: Optional[Dict] = None,
|
||||
# ROI memory bank
|
||||
obj_roi_memory_feat=None,
|
||||
obj_roi_memory_mask=None,
|
||||
box_head_trk=None,
|
||||
):
|
||||
"""
|
||||
Input:
|
||||
- tgt: nq, bs, d_model
|
||||
- memory: \\sum{hw}, bs, d_model
|
||||
- pos: \\sum{hw}, bs, d_model
|
||||
- reference_boxes: nq, bs, 4 (after sigmoid)
|
||||
- valid_ratios/spatial_shapes: bs, nlevel, 2
|
||||
"""
|
||||
if memory_mask is not None:
|
||||
assert (
|
||||
self.boxRPB == "none"
|
||||
), "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
|
||||
|
||||
apply_dac = apply_dac if apply_dac is not None else self.dac
|
||||
if apply_dac:
|
||||
assert (tgt.shape[0] == self.num_queries) or (
|
||||
self.use_instance_query
|
||||
and (tgt.shape[0] == self.instance_query_embed.num_embeddings)
|
||||
)
|
||||
|
||||
tgt = tgt.repeat(2, 1, 1)
|
||||
# note that we don't tile tgt_mask, since DAC doesn't
|
||||
# use self-attention in o2m queries
|
||||
if reference_boxes is not None:
|
||||
assert (reference_boxes.shape[0] == self.num_queries) or (
|
||||
self.use_instance_query
|
||||
and (
|
||||
reference_boxes.shape[0]
|
||||
== self.instance_query_embed.num_embeddings
|
||||
)
|
||||
)
|
||||
reference_boxes = reference_boxes.repeat(2, 1, 1)
|
||||
|
||||
bs = tgt.shape[1]
|
||||
intermediate = []
|
||||
intermediate_presence_logits = []
|
||||
presence_feats = None
|
||||
|
||||
if self.box_refine:
|
||||
if reference_boxes is None:
|
||||
# In this case, we're in a one-stage model, so we generate the reference boxes
|
||||
reference_boxes = self.reference_points.weight.unsqueeze(1)
|
||||
reference_boxes = (
|
||||
reference_boxes.repeat(2, bs, 1)
|
||||
if apply_dac
|
||||
else reference_boxes.repeat(1, bs, 1)
|
||||
)
|
||||
reference_boxes = reference_boxes.sigmoid()
|
||||
intermediate_ref_boxes = [reference_boxes]
|
||||
else:
|
||||
reference_boxes = None
|
||||
intermediate_ref_boxes = None
|
||||
|
||||
output = tgt
|
||||
presence_out = None
|
||||
if self.presence_token is not None and is_instance_prompt is False:
|
||||
# expand to batch dim
|
||||
presence_out = self.presence_token.weight[None].expand(1, bs, -1)
|
||||
|
||||
box_head = self.bbox_embed
|
||||
if is_instance_prompt and self.instance_bbox_embed is not None:
|
||||
box_head = self.instance_bbox_embed
|
||||
|
||||
out_norm = self.norm
|
||||
if is_instance_prompt and self.instance_norm is not None:
|
||||
out_norm = self.instance_norm
|
||||
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
reference_points_input = (
|
||||
reference_boxes[:, :, None]
|
||||
* torch.cat([valid_ratios, valid_ratios], -1)[None, :]
|
||||
) # nq, bs, nlevel, 4
|
||||
|
||||
query_sine_embed = gen_sineembed_for_position(
|
||||
reference_points_input[:, :, 0, :], self.d_model
|
||||
) # nq, bs, d_model*2
|
||||
|
||||
# conditional query
|
||||
query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model
|
||||
|
||||
if self.boxRPB != "none" and reference_boxes is not None:
|
||||
assert (
|
||||
spatial_shapes.shape[0] == 1
|
||||
), "only single scale support implemented"
|
||||
memory_mask = self._get_rpb_matrix(
|
||||
reference_boxes,
|
||||
(spatial_shapes[0, 0], spatial_shapes[0, 1]),
|
||||
)
|
||||
memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W)
|
||||
if self.training:
|
||||
assert (
|
||||
self.use_act_checkpoint
|
||||
), "Activation checkpointing not enabled in the decoder"
|
||||
output, presence_out = activation_ckpt_wrapper(layer)(
|
||||
tgt=output,
|
||||
tgt_query_pos=query_pos,
|
||||
tgt_query_sine_embed=query_sine_embed,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
tgt_reference_points=reference_points_input,
|
||||
memory_text=memory_text,
|
||||
text_attention_mask=text_attention_mask,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
memory_level_start_index=level_start_index,
|
||||
memory_spatial_shapes=spatial_shapes,
|
||||
memory_pos=pos,
|
||||
self_attn_mask=tgt_mask,
|
||||
cross_attn_mask=memory_mask,
|
||||
dac=apply_dac,
|
||||
dac_use_selfatt_ln=self.dac_use_selfatt_ln,
|
||||
presence_token=presence_out,
|
||||
**(decoder_extra_kwargs or {}),
|
||||
act_ckpt_enable=self.training and self.use_act_checkpoint,
|
||||
# ROI memory bank
|
||||
obj_roi_memory_feat=obj_roi_memory_feat,
|
||||
obj_roi_memory_mask=obj_roi_memory_mask,
|
||||
)
|
||||
|
||||
# iter update
|
||||
if self.box_refine:
|
||||
reference_before_sigmoid = inverse_sigmoid(reference_boxes)
|
||||
if box_head_trk is None:
|
||||
# delta_unsig = self.bbox_embed(output)
|
||||
if not self.use_normed_output_consistently:
|
||||
delta_unsig = box_head(output)
|
||||
else:
|
||||
delta_unsig = box_head(out_norm(output))
|
||||
else:
|
||||
# box_head_trk use a separate box head for tracking queries
|
||||
Q_det = decoder_extra_kwargs["Q_det"]
|
||||
assert output.size(0) >= Q_det
|
||||
delta_unsig_det = self.bbox_embed(output[:Q_det])
|
||||
delta_unsig_trk = box_head_trk(output[Q_det:])
|
||||
delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0)
|
||||
outputs_unsig = delta_unsig + reference_before_sigmoid
|
||||
new_reference_points = outputs_unsig.sigmoid()
|
||||
|
||||
reference_boxes = new_reference_points.detach()
|
||||
if layer_idx != self.num_layers - 1:
|
||||
intermediate_ref_boxes.append(new_reference_points)
|
||||
else:
|
||||
raise NotImplementedError("not implemented yet")
|
||||
|
||||
intermediate.append(out_norm(output))
|
||||
if self.presence_token is not None and is_instance_prompt is False:
|
||||
# norm, mlp head
|
||||
intermediate_layer_presence_logits = self.presence_token_head(
|
||||
self.presence_token_out_norm(presence_out)
|
||||
).squeeze(-1)
|
||||
|
||||
# clamp to mitigate numerical issues
|
||||
if self.clamp_presence_logits:
|
||||
intermediate_layer_presence_logits.clamp(
|
||||
min=-self.clamp_presence_logit_max_val,
|
||||
max=self.clamp_presence_logit_max_val,
|
||||
)
|
||||
|
||||
intermediate_presence_logits.append(intermediate_layer_presence_logits)
|
||||
presence_feats = presence_out.clone()
|
||||
|
||||
if not self.compiled and self.compile_mode is not None:
|
||||
self.forward = torch.compile(
|
||||
self.forward, mode=self.compile_mode, fullgraph=True
|
||||
)
|
||||
self.compiled = True
|
||||
|
||||
return (
|
||||
torch.stack(intermediate),
|
||||
torch.stack(intermediate_ref_boxes),
|
||||
(
|
||||
torch.stack(intermediate_presence_logits)
|
||||
if self.presence_token is not None and is_instance_prompt is False
|
||||
else None
|
||||
),
|
||||
presence_feats,
|
||||
)
|
||||
|
||||
|
||||
class TransformerEncoderCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
frozen: bool,
|
||||
pos_enc_at_input: bool,
|
||||
layer,
|
||||
num_layers: int,
|
||||
use_act_checkpoint: bool = False,
|
||||
batch_first: bool = False, # Do layers expect batch first input?
|
||||
# which layers to exclude cross attention? default: None, means all
|
||||
# layers use cross attention
|
||||
remove_cross_attention_layers: Optional[list] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.layers = get_clones(layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
self.pos_enc_at_input = pos_enc_at_input
|
||||
self.use_act_checkpoint = use_act_checkpoint
|
||||
|
||||
if frozen:
|
||||
for p in self.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
self.batch_first = batch_first
|
||||
|
||||
# remove cross attention layers if specified
|
||||
self.remove_cross_attention_layers = [False] * self.num_layers
|
||||
if remove_cross_attention_layers is not None:
|
||||
for i in remove_cross_attention_layers:
|
||||
self.remove_cross_attention_layers[i] = True
|
||||
assert len(self.remove_cross_attention_layers) == len(self.layers)
|
||||
|
||||
for i, remove_cross_attention in enumerate(self.remove_cross_attention_layers):
|
||||
if remove_cross_attention:
|
||||
self.layers[i].cross_attn_image = None
|
||||
self.layers[i].norm2 = None
|
||||
self.layers[i].dropout2 = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src, # self-attention inputs
|
||||
prompt, # cross-attention inputs
|
||||
src_mask: Optional[Tensor] = None, # att.mask for self-attention inputs
|
||||
prompt_mask: Optional[Tensor] = None, # att.mask for cross-attention inputs
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
prompt_key_padding_mask: Optional[Tensor] = None,
|
||||
src_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
|
||||
prompt_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
|
||||
feat_sizes: Optional[list] = None,
|
||||
num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
|
||||
):
|
||||
if isinstance(src, list):
|
||||
assert isinstance(src_key_padding_mask, list) and isinstance(src_pos, list)
|
||||
assert len(src) == len(src_key_padding_mask) == len(src_pos) == 1
|
||||
src, src_key_padding_mask, src_pos = (
|
||||
src[0],
|
||||
src_key_padding_mask[0],
|
||||
src_pos[0],
|
||||
)
|
||||
|
||||
assert (
|
||||
src.shape[1] == prompt.shape[1]
|
||||
), "Batch size must be the same for src and prompt"
|
||||
|
||||
output = src
|
||||
|
||||
if self.pos_enc_at_input and src_pos is not None:
|
||||
output = output + 0.1 * src_pos
|
||||
|
||||
if self.batch_first:
|
||||
# Convert to batch first
|
||||
output = output.transpose(0, 1)
|
||||
src_pos = src_pos.transpose(0, 1)
|
||||
prompt = prompt.transpose(0, 1)
|
||||
prompt_pos = prompt_pos.transpose(0, 1)
|
||||
|
||||
for layer in self.layers:
|
||||
kwds = {}
|
||||
if isinstance(layer.cross_attn_image, RoPEAttention):
|
||||
kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
|
||||
|
||||
output = activation_ckpt_wrapper(layer)(
|
||||
tgt=output,
|
||||
memory=prompt,
|
||||
tgt_mask=src_mask,
|
||||
memory_mask=prompt_mask,
|
||||
tgt_key_padding_mask=src_key_padding_mask,
|
||||
memory_key_padding_mask=prompt_key_padding_mask,
|
||||
pos=prompt_pos,
|
||||
query_pos=src_pos,
|
||||
dac=False,
|
||||
attn_bias=None,
|
||||
act_ckpt_enable=self.training and self.use_act_checkpoint,
|
||||
**kwds,
|
||||
)
|
||||
normed_output = self.norm(output)
|
||||
|
||||
if self.batch_first:
|
||||
# Convert back to seq first
|
||||
normed_output = normed_output.transpose(0, 1)
|
||||
src_pos = src_pos.transpose(0, 1)
|
||||
|
||||
return {
|
||||
"memory": normed_output,
|
||||
"pos_embed": src_pos,
|
||||
"padding_mask": src_key_padding_mask,
|
||||
}
|
||||
|
||||
|
||||
class TransformerDecoderLayerv1(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
activation: str,
|
||||
cross_attention: nn.Module,
|
||||
d_model: int,
|
||||
dim_feedforward: int,
|
||||
dropout: float,
|
||||
pos_enc_at_attn: bool,
|
||||
pos_enc_at_cross_attn_keys: bool,
|
||||
pos_enc_at_cross_attn_queries: bool,
|
||||
pre_norm: bool,
|
||||
self_attention: nn.Module,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.dim_feedforward = dim_feedforward
|
||||
self.dropout_value = dropout
|
||||
self.self_attn = self_attention
|
||||
self.cross_attn_image = cross_attention
|
||||
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.activation_str = activation
|
||||
self.activation = get_activation_fn(activation)
|
||||
self.pre_norm = pre_norm
|
||||
|
||||
self.pos_enc_at_attn = pos_enc_at_attn
|
||||
self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
|
||||
self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
q = k = tgt + query_pos if self.pos_enc_at_attn else tgt
|
||||
|
||||
# Self attention
|
||||
tgt2 = self.self_attn(
|
||||
q,
|
||||
k,
|
||||
value=tgt,
|
||||
attn_mask=tgt_mask,
|
||||
key_padding_mask=tgt_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
|
||||
# Cross attention to image
|
||||
tgt2 = self.cross_attn_image(
|
||||
query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt,
|
||||
key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
|
||||
# FFN
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
dac: bool = False,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
attn_bias: Optional[Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if dac:
|
||||
# we only apply self attention to the first half of the queries
|
||||
assert tgt.shape[0] % 2 == 0
|
||||
other_tgt = tgt[tgt.shape[0] // 2 :]
|
||||
tgt = tgt[: tgt.shape[0] // 2]
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
|
||||
tgt2 = self.self_attn(
|
||||
q,
|
||||
k,
|
||||
value=tgt2,
|
||||
attn_mask=tgt_mask,
|
||||
key_padding_mask=tgt_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
if dac:
|
||||
# Recombine
|
||||
tgt = torch.cat((tgt, other_tgt), dim=0)
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.cross_attn_image(
|
||||
query=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
|
||||
key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
attn_bias=attn_bias,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
return tgt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
dac: bool = False,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
attn_bias: Optional[Tensor] = None,
|
||||
**kwds: Any,
|
||||
) -> torch.Tensor:
|
||||
fwd_fn = self.forward_pre if self.pre_norm else self.forward_post
|
||||
return fwd_fn(
|
||||
tgt,
|
||||
memory,
|
||||
dac=dac,
|
||||
tgt_mask=tgt_mask,
|
||||
memory_mask=memory_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
pos=pos,
|
||||
query_pos=query_pos,
|
||||
attn_bias=attn_bias,
|
||||
**kwds,
|
||||
)
|
||||
|
||||
|
||||
class TransformerDecoderLayerv2(TransformerDecoderLayerv1):
|
||||
def __init__(self, cross_attention_first=False, *args: Any, **kwds: Any):
|
||||
super().__init__(*args, **kwds)
|
||||
self.cross_attention_first = cross_attention_first
|
||||
|
||||
def _forward_sa(self, tgt, query_pos):
|
||||
# Self-Attention
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
|
||||
tgt2 = self.self_attn(q, k, v=tgt2)
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
return tgt
|
||||
|
||||
def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
|
||||
if self.cross_attn_image is None:
|
||||
return tgt
|
||||
|
||||
kwds = {}
|
||||
if num_k_exclude_rope > 0:
|
||||
assert isinstance(self.cross_attn_image, RoPEAttention)
|
||||
kwds = {"num_k_exclude_rope": num_k_exclude_rope}
|
||||
|
||||
# Cross-Attention
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.cross_attn_image(
|
||||
q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
|
||||
k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
|
||||
v=memory,
|
||||
**kwds,
|
||||
)
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
return tgt
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
dac: bool,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
attn_bias: Optional[Tensor] = None,
|
||||
num_k_exclude_rope: int = 0,
|
||||
):
|
||||
assert dac is False
|
||||
assert tgt_mask is None
|
||||
assert memory_mask is None
|
||||
assert tgt_key_padding_mask is None
|
||||
assert memory_key_padding_mask is None
|
||||
assert attn_bias is None
|
||||
|
||||
if self.cross_attention_first:
|
||||
tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
|
||||
tgt = self._forward_sa(tgt, query_pos)
|
||||
else:
|
||||
tgt = self._forward_sa(tgt, query_pos)
|
||||
tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
|
||||
|
||||
# MLP
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
return tgt
|
||||
|
||||
def forward(self, *args: Any, **kwds: Any) -> torch.Tensor:
|
||||
if self.pre_norm:
|
||||
return self.forward_pre(*args, **kwds)
|
||||
raise NotImplementedError
|
||||
Reference in New Issue
Block a user