Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
795
sam3/model_builder.py
Normal file
795
sam3/model_builder.py
Normal file
@@ -0,0 +1,795 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
|
||||
from sam3.model.decoder import (
|
||||
TransformerDecoder,
|
||||
TransformerDecoderLayer,
|
||||
TransformerDecoderLayerv2,
|
||||
TransformerEncoderCrossAttention,
|
||||
)
|
||||
from sam3.model.encoder import TransformerEncoderFusion, TransformerEncoderLayer
|
||||
from sam3.model.geometry_encoders import SequenceGeometryEncoder
|
||||
from sam3.model.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead
|
||||
from sam3.model.memory import (
|
||||
CXBlock,
|
||||
SimpleFuser,
|
||||
SimpleMaskDownSampler,
|
||||
SimpleMaskEncoder,
|
||||
)
|
||||
from sam3.model.model_misc import (
|
||||
DotProductScoring,
|
||||
MLP,
|
||||
MultiheadAttentionWrapper as MultiheadAttention,
|
||||
TransformerWrapper,
|
||||
)
|
||||
from sam3.model.necks import Sam3DualViTDetNeck
|
||||
from sam3.model.position_encoding import PositionEmbeddingSine
|
||||
|
||||
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
|
||||
|
||||
from sam3.model.sam3_image import Sam3Image, Sam3ImageOnVideoMultiGPU
|
||||
from sam3.model.sam3_tracking_predictor import Sam3TrackerPredictor
|
||||
from sam3.model.sam3_video_inference import Sam3VideoInferenceWithInstanceInteractivity
|
||||
from sam3.model.sam3_video_predictor import Sam3VideoPredictorMultiGPU
|
||||
from sam3.model.text_encoder_ve import VETextEncoder
|
||||
from sam3.model.tokenizer_ve import SimpleTokenizer
|
||||
from sam3.model.vitdet import ViT
|
||||
from sam3.model.vl_combiner import SAM3VLBackbone
|
||||
from sam3.sam.transformer import RoPEAttention
|
||||
|
||||
# Hugging Face Hub repo id and checkpoint filename for the released SAM3 weights.
SAM3_MODEL_ID = "facebook/sam3"
SAM3_CKPT_NAME = "sam3.pt"
|
||||
|
||||
|
||||
# Setup TensorFloat-32 for Ampere GPUs if available
|
||||
def _setup_tf32() -> None:
|
||||
"""Enable TensorFloat-32 for Ampere GPUs if available."""
|
||||
if torch.cuda.is_available():
|
||||
device_props = torch.cuda.get_device_properties(0)
|
||||
if device_props.major >= 8:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
|
||||
_setup_tf32()
|
||||
|
||||
|
||||
def _create_position_encoding(precompute_resolution=None):
    """Build the sine position embedding used by the visual components.

    Args:
        precompute_resolution: optional square resolution at which the
            embedding table is precomputed (``None`` disables precomputation).
    """
    config = dict(
        num_pos_feats=256,
        normalize=True,
        scale=None,
        temperature=10000,
        precompute_resolution=precompute_resolution,
    )
    return PositionEmbeddingSine(**config)
|
||||
|
||||
|
||||
def _create_vit_backbone(compile_mode=None):
    """Create ViT backbone for visual feature extraction.

    Args:
        compile_mode: torch.compile mode forwarded to the ViT
            (e.g. "default"); ``None`` disables compilation.

    Returns:
        ViT: a large ViT trunk configured for 1008x1008 inputs with a
        14-pixel patch size (i.e. a 72x72 token grid).
    """
    return ViT(
        img_size=1008,
        pretrain_img_size=336,  # pretraining resolution; positions are adapted to 1008
        patch_size=14,
        embed_dim=1024,
        depth=32,
        num_heads=16,
        mlp_ratio=4.625,
        norm_layer="LayerNorm",
        drop_path_rate=0.1,
        qkv_bias=True,
        use_abs_pos=True,
        tile_abs_pos=True,
        # NOTE(review): presumably the remaining blocks use windowed attention
        # (window_size=24) — confirm against the ViT implementation.
        global_att_blocks=(7, 15, 23, 31),
        rel_pos_blocks=(),  # no relative-position bias; RoPE is used instead
        use_rope=True,
        use_interp_rope=True,
        window_size=24,
        pretrain_use_cls_token=True,
        retain_cls_token=False,
        ln_pre=True,
        ln_post=False,
        return_interm_layers=False,
        bias_patch_embed=False,
        compile_mode=compile_mode,
    )
|
||||
|
||||
|
||||
def _create_vit_neck(position_encoding, vit_backbone, enable_inst_interactivity=False):
    """Wrap the ViT trunk in the dual neck that builds the feature pyramid.

    Args:
        position_encoding: position embedding module attached to the neck.
        vit_backbone: ViT trunk providing the base feature map.
        enable_inst_interactivity: when True, additionally attach the
            SAM2-style neck required by the interactive predictor.
    """
    neck = Sam3DualViTDetNeck(
        trunk=vit_backbone,
        position_encoding=position_encoding,
        d_model=256,
        scale_factors=[4.0, 2.0, 1.0, 0.5],
        add_sam2_neck=enable_inst_interactivity,
    )
    return neck
|
||||
|
||||
|
||||
def _create_vl_backbone(vit_neck, text_encoder):
    """Combine the visual neck and the text encoder into one VL backbone."""
    backbone = SAM3VLBackbone(
        visual=vit_neck,
        text=text_encoder,
        scalp=1,
    )
    return backbone
|
||||
|
||||
|
||||
def _create_transformer_encoder() -> TransformerEncoderFusion:
    """Build the 6-layer fusion encoder used by the SAM3 detector."""

    def _make_attention():
        # Self- and cross-attention use identical configurations.
        return MultiheadAttention(
            embed_dim=256,
            num_heads=8,
            dropout=0.1,
            batch_first=True,
        )

    fusion_layer = TransformerEncoderLayer(
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        pos_enc_at_attn=True,
        pos_enc_at_cross_attn_keys=False,
        pos_enc_at_cross_attn_queries=False,
        pre_norm=True,
        self_attention=_make_attention(),
        cross_attention=_make_attention(),
    )

    return TransformerEncoderFusion(
        layer=fusion_layer,
        num_layers=6,
        d_model=256,
        num_feature_levels=1,
        frozen=False,
        use_act_checkpoint=True,
        add_pooled_text_to_img_feat=False,
        pool_text_with_mask=True,
    )
|
||||
|
||||
|
||||
def _create_transformer_decoder(presence_token: bool = True) -> TransformerDecoder:
    """Create transformer decoder with its layer.

    Args:
        presence_token: whether the decoder maintains a presence token.
            Defaults to True, matching the previous hard-coded behavior, so
            existing callers are unaffected; exposing it lets builders such
            as ``_create_sam3_transformer`` forward their flag.

    Returns:
        TransformerDecoder: 6-layer, box-refining decoder with 200 queries.
    """
    decoder_layer = TransformerDecoderLayer(
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        cross_attention=MultiheadAttention(
            num_heads=8,
            dropout=0.1,
            embed_dim=256,
        ),
        n_heads=8,
        use_text_cross_attention=True,
    )

    decoder = TransformerDecoder(
        layer=decoder_layer,
        num_layers=6,
        num_queries=200,
        return_intermediate=True,
        box_refine=True,
        num_o2m_queries=0,
        dac=True,
        boxRPB="log",  # log-scaled relative position bias for boxes
        d_model=256,
        frozen=False,
        interaction_layer=None,
        dac_use_selfatt_ln=True,
        resolution=1008,
        stride=14,
        use_act_checkpoint=True,
        presence_token=presence_token,
    )
    return decoder
|
||||
|
||||
|
||||
def _create_dot_product_scoring():
    """Build the dot-product scoring head with its prompt-projection MLP."""
    # MLP built first, then the scorer (same construction order as before).
    projection = MLP(
        input_dim=256,
        hidden_dim=2048,
        output_dim=256,
        num_layers=2,
        dropout=0.1,
        residual=True,
        out_norm=nn.LayerNorm(256),
    )
    scorer = DotProductScoring(d_model=256, d_proj=256, prompt_mlp=projection)
    return scorer
|
||||
|
||||
|
||||
def _create_segmentation_head(compile_mode=None):
    """Assemble the universal segmentation head and its pixel decoder.

    Args:
        compile_mode: torch.compile mode for the pixel decoder (None = off).
    """
    # Three nearest-neighbor upsampling stages back toward input resolution.
    decoder = PixelDecoder(
        num_upsampling_stages=3,
        interpolation_mode="nearest",
        hidden_dim=256,
        compile_mode=compile_mode,
    )

    # Attention from mask features onto the prompt tokens.
    prompt_attention = MultiheadAttention(
        num_heads=8,
        dropout=0,
        embed_dim=256,
    )

    return UniversalSegmentationHead(
        hidden_dim=256,
        upsampling_stages=3,
        aux_masks=False,
        presence_head=False,
        dot_product_scorer=None,
        act_ckpt=True,
        cross_attend_prompt=prompt_attention,
        pixel_decoder=decoder,
    )
|
||||
|
||||
|
||||
def _create_geometry_encoder():
    """Create the geometry (point/box prompt) encoder.

    Returns:
        SequenceGeometryEncoder: a 3-layer encoder embedding point and box
        prompts into 256-d tokens.
    """
    # Position encoding for the geometry tokens.
    geo_pos_enc = _create_position_encoding()

    # NOTE: a CXBlock used to be constructed here but was never used
    # (dead code) — removed.

    # Encoder layer shared by all 3 geometry-encoder layers.
    geo_layer = TransformerEncoderLayer(
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        pos_enc_at_attn=False,
        pre_norm=True,
        self_attention=MultiheadAttention(
            num_heads=8,
            dropout=0.1,
            embed_dim=256,
            batch_first=False,
        ),
        pos_enc_at_cross_attn_queries=False,
        pos_enc_at_cross_attn_keys=True,
        cross_attention=MultiheadAttention(
            num_heads=8,
            dropout=0.1,
            embed_dim=256,
            batch_first=False,
        ),
    )

    # Assemble the geometry encoder.
    input_geometry_encoder = SequenceGeometryEncoder(
        pos_enc=geo_pos_enc,
        encode_boxes_as_points=False,
        points_direct_project=True,
        points_pool=True,
        points_pos_enc=True,
        boxes_direct_project=True,
        boxes_pool=True,
        boxes_pos_enc=True,
        d_model=256,
        num_layers=3,
        layer=geo_layer,
        use_act_ckpt=True,
        add_cls=True,
        add_post_encode_proj=True,
    )
    return input_geometry_encoder
|
||||
|
||||
|
||||
def _create_sam3_model(
    backbone,
    transformer,
    input_geometry_encoder,
    segmentation_head,
    dot_prod_scoring,
    inst_interactive_predictor,
    eval_mode,
):
    """Instantiate the SAM3 image model.

    In training mode (``eval_mode=False``) a Hungarian matcher is attached
    for target assignment; in eval mode the matcher is ``None``.
    """
    if eval_mode:
        matcher = None
    else:
        # Training-only dependency; imported lazily to keep inference light.
        from sam3.train.matcher import BinaryHungarianMatcherV2

        matcher = BinaryHungarianMatcherV2(
            focal=True,
            cost_class=2.0,
            cost_bbox=5.0,
            cost_giou=2.0,
            alpha=0.25,
            gamma=2,
            stable=False,
        )

    return Sam3Image(
        backbone=backbone,
        transformer=transformer,
        input_geometry_encoder=input_geometry_encoder,
        segmentation_head=segmentation_head,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=dot_prod_scoring,
        use_instance_query=False,
        multimask_output=True,
        inst_interactive_predictor=inst_interactive_predictor,
        matcher=matcher,
    )
|
||||
|
||||
|
||||
def _create_tracker_maskmem_backbone():
    """Create the SAM3 Tracker memory encoder.

    Returns:
        SimpleMaskEncoder: encodes predicted masks into 64-d memory features
        (out_dim=64, matched by the 64-feature position encoding below).
    """
    # Position encoding for mask memory backbone.
    position_encoding = PositionEmbeddingSine(
        num_pos_feats=64,
        normalize=True,
        scale=None,
        temperature=10000,
        precompute_resolution=1008,
    )

    # Mask processing components: strided downsampling after resizing masks
    # to 1152x1152.
    mask_downsampler = SimpleMaskDownSampler(
        kernel_size=3, stride=2, padding=1, interpol_size=[1152, 1152]
    )

    # ConvNeXt-style block used twice by the fuser.
    cx_block_layer = CXBlock(
        dim=256,
        kernel_size=7,
        padding=3,
        layer_scale_init_value=1.0e-06,
        use_dwconv=True,
    )

    fuser = SimpleFuser(layer=cx_block_layer, num_layers=2)

    maskmem_backbone = SimpleMaskEncoder(
        out_dim=64,
        position_encoding=position_encoding,
        mask_downsampler=mask_downsampler,
        fuser=fuser,
    )

    return maskmem_backbone
|
||||
|
||||
|
||||
def _create_tracker_transformer():
    """Create the SAM3 Tracker transformer components.

    Returns:
        TransformerWrapper: encoder-only wrapper (decoder=None); the 4-layer
        encoder cross-attends tracker features to the 64-d mask-memory
        features produced by the memory encoder.
    """
    # Self attention over the 256-d tokens; RoPE at the 72x72 feature grid
    # (1008 image size / 14 stride).
    self_attention = RoPEAttention(
        embedding_dim=256,
        num_heads=1,
        downsample_rate=1,
        dropout=0.1,
        rope_theta=10000.0,
        feat_sizes=[72, 72],
        use_fa3=False,
        use_rope_real=False,
    )

    # Cross attention into memory; kv_in_dim=64 matches the memory
    # encoder's out_dim.
    cross_attention = RoPEAttention(
        embedding_dim=256,
        num_heads=1,
        downsample_rate=1,
        dropout=0.1,
        kv_in_dim=64,
        rope_theta=10000.0,
        feat_sizes=[72, 72],
        rope_k_repeat=True,
        use_fa3=False,
        use_rope_real=False,
    )

    # Encoder layer (a decoder-style layer: self-attention then
    # cross-attention into memory).
    encoder_layer = TransformerDecoderLayerv2(
        cross_attention_first=False,
        activation="relu",
        dim_feedforward=2048,
        dropout=0.1,
        pos_enc_at_attn=False,
        pre_norm=True,
        self_attention=self_attention,
        d_model=256,
        pos_enc_at_cross_attn_keys=True,
        pos_enc_at_cross_attn_queries=False,
        cross_attention=cross_attention,
    )

    # Encoder
    encoder = TransformerEncoderCrossAttention(
        remove_cross_attention_layers=[],
        batch_first=True,
        d_model=256,
        frozen=False,
        pos_enc_at_input=True,
        layer=encoder_layer,
        num_layers=4,
        use_act_checkpoint=False,
    )

    # Decoder-less wrapper: the tracker only consumes the encoder output.
    transformer = TransformerWrapper(
        encoder=encoder,
        decoder=None,
        d_model=256,
    )

    return transformer
|
||||
|
||||
|
||||
def build_tracker(
    apply_temporal_disambiguation: bool, with_backbone: bool = False, compile_mode=None
) -> Sam3TrackerPredictor:
    """
    Build the SAM3 Tracker module for video tracking.

    Args:
        apply_temporal_disambiguation: enable memory selection for temporal
            disambiguation.
        with_backbone: if True, attach a vision-only VL backbone (text=None)
            so the tracker can run standalone; otherwise the caller supplies
            features.
        compile_mode: torch.compile mode forwarded to the vision backbone
            (only used when with_backbone=True).

    Returns:
        Sam3TrackerPredictor: Wrapped SAM3 Tracker module
    """

    # Create model components
    maskmem_backbone = _create_tracker_maskmem_backbone()
    transformer = _create_tracker_transformer()
    backbone = None
    if with_backbone:
        vision_backbone = _create_vision_backbone(compile_mode=compile_mode)
        backbone = SAM3VLBackbone(scalp=1, visual=vision_backbone, text=None)
    # Create the Tracker module
    model = Sam3TrackerPredictor(
        image_size=1008,
        num_maskmem=7,  # number of memory frames retained
        backbone=backbone,
        backbone_stride=14,
        transformer=transformer,
        maskmem_backbone=maskmem_backbone,
        # SAM parameters
        multimask_output_in_sam=True,
        # Evaluation
        forward_backbone_per_frame_for_eval=True,
        trim_past_non_cond_mem_for_eval=False,
        # Multimask
        multimask_output_for_tracking=True,
        multimask_min_pt_num=0,
        multimask_max_pt_num=1,
        # Additional settings
        always_start_from_first_ann_frame=False,
        # Mask overlap
        non_overlap_masks_for_mem_enc=False,
        non_overlap_masks_for_output=False,
        max_cond_frames_in_attn=4,
        offload_output_to_cpu_for_eval=False,
        # SAM decoder settings
        sam_mask_decoder_extra_args={
            "dynamic_multimask_via_stability": True,
            "dynamic_multimask_stability_delta": 0.05,
            "dynamic_multimask_stability_thresh": 0.98,
        },
        clear_non_cond_mem_around_input=True,
        fill_hole_area=0,
        use_memory_selection=apply_temporal_disambiguation,
    )

    return model
|
||||
|
||||
|
||||
def _create_text_encoder(bpe_path: str) -> VETextEncoder:
    """Create the SAM3 text encoder.

    Args:
        bpe_path: path to the BPE vocabulary file used by the tokenizer.
    """
    return VETextEncoder(
        tokenizer=SimpleTokenizer(bpe_path=bpe_path),
        d_model=256,
        width=1024,
        heads=16,
        layers=24,
    )
|
||||
|
||||
|
||||
def _create_vision_backbone(
    compile_mode=None, enable_inst_interactivity=True
) -> Sam3DualViTDetNeck:
    """Create the SAM3 visual backbone (ViT trunk wrapped in the dual neck).

    Args:
        compile_mode: torch.compile mode for the ViT trunk (None = off).
        enable_inst_interactivity: forwarded to the neck; adds the SAM2-style
            branch needed by the interactive (SAM 1 task) predictor.
    """
    # Arguments are built left-to-right: position encoding, then trunk,
    # then the neck that combines them.
    return _create_vit_neck(
        _create_position_encoding(precompute_resolution=1008),
        _create_vit_backbone(compile_mode=compile_mode),
        enable_inst_interactivity=enable_inst_interactivity,
    )
|
||||
|
||||
|
||||
def _create_sam3_transformer(has_presence_token: bool = True) -> TransformerWrapper:
    """Create SAM3 transformer encoder and decoder.

    Args:
        has_presence_token: NOTE(review): currently unused — the decoder is
            always built with its hard-coded ``presence_token=True``
            regardless of this flag; confirm whether it should be forwarded.
    """
    encoder: TransformerEncoderFusion = _create_transformer_encoder()
    decoder: TransformerDecoder = _create_transformer_decoder()

    return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256)
|
||||
|
||||
|
||||
def _load_checkpoint(model, checkpoint_path):
    """Load SAM3 image-model weights from a combined detector/tracker checkpoint.

    The checkpoint stores detector weights under a ``detector.`` prefix and,
    when instance interactivity is enabled, tracker weights under a
    ``tracker.`` prefix; both are remapped onto this model's parameter names.

    Args:
        model: the SAM3 image model to load into.
        checkpoint_path: path to the checkpoint file.
    """
    with g_pathmgr.open(checkpoint_path, "rb") as f:
        ckpt = torch.load(f, map_location="cpu", weights_only=True)
    if "model" in ckpt and isinstance(ckpt["model"], dict):
        ckpt = ckpt["model"]
    # Remap "detector." keys onto this model's namespace.
    sam3_image_ckpt = {
        k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k
    }
    if model.inst_interactive_predictor is not None:
        sam3_image_ckpt.update(
            {
                k.replace("tracker.", "inst_interactive_predictor.model."): v
                for k, v in ckpt.items()
                if "tracker" in k
            }
        )
    # Fix: the unexpected keys were previously discarded even though the
    # message claimed to report "missing and/or unexpected keys".
    missing_keys, unexpected_keys = model.load_state_dict(
        sam3_image_ckpt, strict=False
    )
    if missing_keys or unexpected_keys:
        print(
            f"loaded {checkpoint_path} and found "
            f"missing and/or unexpected keys:\n{missing_keys=}\n{unexpected_keys=}"
        )
|
||||
|
||||
|
||||
def _setup_device_and_mode(model, device, eval_mode):
|
||||
"""Setup model device and evaluation mode."""
|
||||
if device == "cuda":
|
||||
model = model.cuda()
|
||||
if eval_mode:
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def build_sam3_image_model(
    bpe_path=None,
    device="cuda" if torch.cuda.is_available() else "cpu",
    eval_mode=True,
    checkpoint_path=None,
    load_from_HF=True,
    enable_segmentation=True,
    enable_inst_interactivity=False,
    compile=False,
):
    """
    Build SAM3 image model

    Args:
        bpe_path: Path to the BPE tokenizer vocabulary; defaults to the
            bundled ``assets/bpe_simple_vocab_16e6.txt.gz``
        device: Device to load the model on ('cuda' or 'cpu'); note the
            default is resolved once at import time
        eval_mode: Whether to set the model to evaluation mode
        checkpoint_path: Optional path to model checkpoint
        load_from_HF: If True and no checkpoint_path is given, download the
            released checkpoint from the Hugging Face Hub
        enable_segmentation: Whether to enable segmentation head
        enable_inst_interactivity: Whether to enable instance interactivity (SAM 1 task)
        compile: If True, build submodules with torch.compile mode "default"
            (NOTE: this parameter shadows the ``compile`` builtin)

    Returns:
        A SAM3 image model
    """
    if bpe_path is None:
        bpe_path = os.path.join(
            os.path.dirname(__file__), "..", "assets", "bpe_simple_vocab_16e6.txt.gz"
        )
    # Create visual components
    compile_mode = "default" if compile else None
    vision_encoder = _create_vision_backbone(
        compile_mode=compile_mode, enable_inst_interactivity=enable_inst_interactivity
    )

    # Create text components
    text_encoder = _create_text_encoder(bpe_path)

    # Create visual-language backbone
    backbone = _create_vl_backbone(vision_encoder, text_encoder)

    # Create transformer components
    transformer = _create_sam3_transformer()

    # Create dot product scoring
    dot_prod_scoring = _create_dot_product_scoring()

    # Create segmentation head if enabled
    segmentation_head = (
        _create_segmentation_head(compile_mode=compile_mode)
        if enable_segmentation
        else None
    )

    # Create geometry encoder
    input_geometry_encoder = _create_geometry_encoder()
    if enable_inst_interactivity:
        # The interactive (SAM 1 task) predictor wraps a tracker built
        # without temporal disambiguation.
        sam3_pvs_base = build_tracker(apply_temporal_disambiguation=False)
        inst_predictor = SAM3InteractiveImagePredictor(sam3_pvs_base)
    else:
        inst_predictor = None
    # Create the SAM3 model
    model = _create_sam3_model(
        backbone,
        transformer,
        input_geometry_encoder,
        segmentation_head,
        dot_prod_scoring,
        inst_predictor,
        eval_mode,
    )
    if load_from_HF and checkpoint_path is None:
        checkpoint_path = hf_hub_download(
            repo_id=SAM3_MODEL_ID, filename=SAM3_CKPT_NAME
        )
    # Load checkpoint if provided
    if checkpoint_path is not None:
        _load_checkpoint(model, checkpoint_path)

    # Setup device and mode
    model = _setup_device_and_mode(model, device, eval_mode)

    return model
|
||||
|
||||
|
||||
def build_sam3_video_model(
    checkpoint_path: Optional[str] = None,
    load_from_HF=True,
    bpe_path: Optional[str] = None,
    has_presence_token: bool = True,
    geo_encoder_use_img_cross_attn: bool = True,
    strict_state_dict_loading: bool = True,
    apply_temporal_disambiguation: bool = True,
    device="cuda" if torch.cuda.is_available() else "cpu",
    compile=False,
) -> Sam3VideoInferenceWithInstanceInteractivity:
    """
    Build SAM3 dense tracking model.

    Args:
        checkpoint_path: Optional path to checkpoint file
        load_from_HF: When True and no checkpoint_path is given, download the
            released checkpoint from the Hugging Face Hub
        bpe_path: Path to the BPE tokenizer file
        has_presence_token: Forwarded to the detector transformer and used
            for joint box-score supervision
        geo_encoder_use_img_cross_attn: NOTE(review): currently unused here
        strict_state_dict_loading: Passed through to ``load_state_dict``
        apply_temporal_disambiguation: Enable the temporal heuristics
            (hotstart, reconditioning, memory selection); disable for
            ablation studies
        device: Target device for the assembled model
        compile: Compile the video model if True

    Returns:
        Sam3VideoInferenceWithInstanceInteractivity: The instantiated dense tracking model
    """
    if bpe_path is None:
        bpe_path = os.path.join(
            os.path.dirname(__file__), "..", "assets", "bpe_simple_vocab_16e6.txt.gz"
        )

    # Tracker module (memory-based video tracker).
    tracker = build_tracker(apply_temporal_disambiguation=apply_temporal_disambiguation)

    # Detector components.
    visual_neck = _create_vision_backbone()
    text_encoder = _create_text_encoder(bpe_path)
    backbone = SAM3VLBackbone(scalp=1, visual=visual_neck, text=text_encoder)
    transformer = _create_sam3_transformer(has_presence_token=has_presence_token)
    segmentation_head: UniversalSegmentationHead = _create_segmentation_head()
    input_geometry_encoder = _create_geometry_encoder()

    # Main dot-product scoring head (MLP first, then scorer).
    main_dot_prod_mlp = MLP(
        input_dim=256,
        hidden_dim=2048,
        output_dim=256,
        num_layers=2,
        dropout=0.1,
        residual=True,
        out_norm=nn.LayerNorm(256),
    )
    main_dot_prod_scoring = DotProductScoring(
        d_model=256, d_proj=256, prompt_mlp=main_dot_prod_mlp
    )

    # Detector module.
    detector = Sam3ImageOnVideoMultiGPU(
        num_feature_levels=1,
        backbone=backbone,
        transformer=transformer,
        segmentation_head=segmentation_head,
        semantic_segmentation_head=None,
        input_geometry_encoder=input_geometry_encoder,
        use_early_fusion=True,
        use_dot_prod_scoring=True,
        dot_prod_scoring=main_dot_prod_scoring,
        supervise_joint_box_scores=has_presence_token,
    )

    # The two previous configurations differed only in these four values;
    # zeroing them yields the heuristic-free variant for ablation studies.
    if apply_temporal_disambiguation:
        heuristics = dict(
            hotstart_delay=15,
            hotstart_unmatch_thresh=8,
            hotstart_dup_thresh=8,
            recondition_every_nth_frame=16,
        )
    else:
        heuristics = dict(
            hotstart_delay=0,
            hotstart_unmatch_thresh=0,
            hotstart_dup_thresh=0,
            recondition_every_nth_frame=0,
        )

    model = Sam3VideoInferenceWithInstanceInteractivity(
        detector=detector,
        tracker=tracker,
        score_threshold_detection=0.5,
        assoc_iou_thresh=0.1,
        det_nms_thresh=0.1,
        new_det_thresh=0.7,
        suppress_unmatched_only_within_hotstart=True,
        min_trk_keep_alive=-1,
        max_trk_keep_alive=30,
        init_trk_keep_alive=30,
        suppress_overlapping_based_on_recent_occlusion_threshold=0.7,
        suppress_det_close_to_boundary=False,
        fill_hole_area=16,
        masklet_confirmation_enable=False,
        decrease_trk_keep_alive_for_empty_masklets=False,
        image_size=1008,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        compile_model=compile,
        **heuristics,
    )

    # Resolve and load the checkpoint, if any.
    if load_from_HF and checkpoint_path is None:
        checkpoint_path = hf_hub_download(
            repo_id=SAM3_MODEL_ID, filename=SAM3_CKPT_NAME
        )
    if checkpoint_path is not None:
        with g_pathmgr.open(checkpoint_path, "rb") as f:
            ckpt = torch.load(f, map_location="cpu", weights_only=True)
        if "model" in ckpt and isinstance(ckpt["model"], dict):
            ckpt = ckpt["model"]

        missing_keys, unexpected_keys = model.load_state_dict(
            ckpt, strict=strict_state_dict_loading
        )
        if missing_keys:
            print(f"Missing keys: {missing_keys}")
        if unexpected_keys:
            print(f"Unexpected keys: {unexpected_keys}")

    model.to(device=device)
    return model
|
||||
|
||||
|
||||
def build_sam3_video_predictor(*model_args, gpus_to_use=None, **model_kwargs):
    """Construct a multi-GPU SAM3 video predictor.

    All positional and keyword arguments are forwarded to
    ``Sam3VideoPredictorMultiGPU``; ``gpus_to_use`` selects the GPUs.
    """
    predictor = Sam3VideoPredictorMultiGPU(
        *model_args, gpus_to_use=gpus_to_use, **model_kwargs
    )
    return predictor
|
||||
Reference in New Issue
Block a user