Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
176
sam3/model/vl_combiner.py
Normal file
176
sam3/model/vl_combiner.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
"""Provides utility to combine a vision backbone with a language backbone."""
|
||||
|
||||
from copy import copy
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torch.nn.attention import sdpa_kernel, SDPBackend
|
||||
|
||||
from .act_ckpt_utils import activation_ckpt_wrapper
|
||||
from .necks import Sam3DualViTDetNeck
|
||||
|
||||
|
||||
class SAM3VLBackbone(nn.Module):
    """This backbone combines a vision backbone and a language backbone without fusion.

    As such it is more of a convenience wrapper to handle the two backbones together.

    It adds support for activation checkpointing and compilation.
    """

    def __init__(
        self,
        visual: Sam3DualViTDetNeck,
        text,
        compile_visual: bool = False,
        act_ckpt_whole_vision_backbone: bool = False,
        act_ckpt_whole_language_backbone: bool = False,
        scalp: int = 0,
    ):
        """Initialize the backbone combiner.

        :param visual: The vision backbone to use
        :param text: The text encoder to use
        :param compile_visual: If True, wrap the vision backbone in
            ``torch.compile``
        :param act_ckpt_whole_vision_backbone: If True, run activation
            checkpointing over the entire vision backbone during training
        :param act_ckpt_whole_language_backbone: If True, run activation
            checkpointing over the entire language backbone during training
        :param scalp: Number of lowest-resolution feature levels to discard
            from the vision backbone outputs (0 keeps all levels)
        """
        super().__init__()
        self.vision_backbone: Sam3DualViTDetNeck = (
            torch.compile(visual) if compile_visual else visual
        )
        self.language_backbone = text
        self.scalp = scalp
        # allow running activation checkpointing on the entire vision and language backbones
        self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone
        self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone

    def forward(
        self,
        samples: torch.Tensor,
        captions: List[str],
        input_boxes: Optional[torch.Tensor] = None,
        additional_text: Optional[List[str]] = None,
    ):
        """Forward pass of the backbone combiner.

        :param samples: The input images
        :param captions: The input captions
        :param input_boxes: If the text contains place-holders for boxes, this
            parameter contains the tensor containing their spatial features
        :param additional_text: This can be used to encode some additional text
            (different from the captions) in the same forward of the backbone
        :return: Output dictionary with the following keys:
            - vision_features: The output of the vision backbone
            - language_features: The output of the language backbone
            - language_mask: The attention mask of the language backbone
            - vision_pos_enc: The positional encoding of the vision backbone
            - (optional) additional_text_features: The output of the language
              backbone for the additional text
            - (optional) additional_text_mask: The attention mask of the
              language backbone for the additional text
        """
        output = self.forward_image(samples)
        # Encode the text on the same device that the vision features live on.
        device = output["vision_features"].device
        output.update(self.forward_text(captions, input_boxes, additional_text, device))
        return output

    def forward_image(self, samples: torch.Tensor) -> dict:
        """Run the vision backbone, optionally under activation checkpointing.

        Checkpointing is only enabled when configured AND in training mode,
        since it trades compute for memory and is useless at inference.

        :param samples: The input images
        :return: Output dictionary from :meth:`_forward_image_no_act_ckpt`
        """
        return activation_ckpt_wrapper(self._forward_image_no_act_ckpt)(
            samples=samples,
            act_ckpt_enable=self.act_ckpt_whole_vision_backbone and self.training,
        )

    def _forward_image_no_act_ckpt(self, samples: torch.Tensor) -> dict:
        """Vision forward without activation checkpointing.

        :param samples: The input images
        :return: Dict with keys ``vision_features`` (highest-level feature
            map), ``vision_pos_enc``, ``backbone_fpn`` (all feature levels),
            and ``sam2_backbone_out`` (same structure for the SAM2 branch, or
            None when that branch is absent)
        """
        # Forward through backbone
        sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward(
            samples
        )
        if self.scalp > 0:
            # Discard the lowest resolution features
            sam3_features, sam3_pos = (
                sam3_features[: -self.scalp],
                sam3_pos[: -self.scalp],
            )
            if sam2_features is not None and sam2_pos is not None:
                sam2_features, sam2_pos = (
                    sam2_features[: -self.scalp],
                    sam2_pos[: -self.scalp],
                )

        sam2_output = None

        if sam2_features is not None and sam2_pos is not None:
            # The last level is the coarsest / most semantic feature map.
            sam2_src = sam2_features[-1]
            sam2_output = {
                "vision_features": sam2_src,
                "vision_pos_enc": sam2_pos,
                "backbone_fpn": sam2_features,
            }

        sam3_src = sam3_features[-1]
        output = {
            "vision_features": sam3_src,
            "vision_pos_enc": sam3_pos,
            "backbone_fpn": sam3_features,
            "sam2_backbone_out": sam2_output,
        }

        return output

    def forward_text(
        self,
        captions: List[str],
        input_boxes: Optional[torch.Tensor] = None,
        additional_text: Optional[List[str]] = None,
        device="cuda",
    ) -> dict:
        """Run the language backbone, optionally under activation checkpointing.

        :param captions: The input captions
        :param input_boxes: Optional spatial features for box place-holders
        :param additional_text: Optional extra text encoded in the same forward
        :param device: Device on which to encode the text
        :return: Output dictionary from :meth:`_forward_text_no_act_ckpt`
        """
        return activation_ckpt_wrapper(self._forward_text_no_act_ckpt)(
            captions=captions,
            input_boxes=input_boxes,
            additional_text=additional_text,
            device=device,
            act_ckpt_enable=self.act_ckpt_whole_language_backbone and self.training,
        )

    def _forward_text_no_act_ckpt(
        self,
        captions: List[str],
        input_boxes: Optional[torch.Tensor] = None,
        additional_text: Optional[List[str]] = None,
        device="cuda",
    ) -> dict:
        """Language forward without activation checkpointing.

        :param captions: The input captions
        :param input_boxes: Optional spatial features for box place-holders
        :param additional_text: Optional extra text; it is concatenated after
            the captions for a single backbone forward, then split back out
        :param device: Device on which to encode the text
        :return: Dict with keys ``language_features``, ``language_mask``,
            ``language_embeds``, plus ``additional_text_features`` /
            ``additional_text_mask`` when ``additional_text`` is given
        """
        output = {}

        # Forward through text_encoder
        text_to_encode = copy(captions)
        if additional_text is not None:
            # if there are additional_text, we piggy-back them into this forward.
            # They'll be used later for output alignment
            text_to_encode += additional_text

        # Restrict SDPA to backends that are safe for this text encoder.
        sdpa_context = sdpa_kernel(
            [
                SDPBackend.MATH,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.FLASH_ATTENTION,
            ]
        )

        with sdpa_context:
            text_attention_mask, text_memory, text_embeds = self.language_backbone(
                text_to_encode, input_boxes, device=device
            )

        if additional_text is not None:
            # NOTE(review): text_memory is sliced on dim 1 and the mask on
            # dim 0 — presumably (seq, batch, dim) memory vs (batch, seq)
            # mask; confirm against the language backbone's output layout.
            output["additional_text_features"] = text_memory[:, -len(additional_text) :]
            output["additional_text_mask"] = text_attention_mask[
                -len(additional_text) :
            ]

        # Keep only the caption part of the batch for the main outputs.
        text_memory = text_memory[:, : len(captions)]
        text_attention_mask = text_attention_mask[: len(captions)]
        text_embeds = text_embeds[:, : len(captions)]
        output["language_features"] = text_memory
        output["language_mask"] = text_attention_mask
        output["language_embeds"] = (
            text_embeds  # Text embeddings before forward to the encoder
        )

        return output
|
||||
Reference in New Issue
Block a user