diff --git a/sam3/model_builder.py b/sam3/model_builder.py index d4c67de..058bbec 100644 --- a/sam3/model_builder.py +++ b/sam3/model_builder.py @@ -5,10 +5,8 @@ from typing import Optional import torch import torch.nn as nn - from huggingface_hub import hf_hub_download from iopath.common.file_io import g_pathmgr - from sam3.model.decoder import ( TransformerDecoder, TransformerDecoderLayer, @@ -32,9 +30,7 @@ from sam3.model.model_misc import ( ) from sam3.model.necks import Sam3DualViTDetNeck from sam3.model.position_encoding import PositionEmbeddingSine - from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor - from sam3.model.sam3_image import Sam3Image, Sam3ImageOnVideoMultiGPU from sam3.model.sam3_tracking_predictor import Sam3TrackerPredictor from sam3.model.sam3_video_inference import Sam3VideoInferenceWithInstanceInteractivity @@ -45,9 +41,6 @@ from sam3.model.vitdet import ViT from sam3.model.vl_combiner import SAM3VLBackbone from sam3.sam.transformer import RoPEAttention -SAM3_MODEL_ID = "facebook/sam3" -SAM3_CKPT_NAME = "sam3.pt" - # Setup TensorFloat-32 for Ampere GPUs if available def _setup_tf32() -> None: @@ -633,9 +626,7 @@ def build_sam3_image_model( eval_mode, ) if load_from_HF and checkpoint_path is None: - checkpoint_path = hf_hub_download( - repo_id=SAM3_MODEL_ID, filename=SAM3_CKPT_NAME - ) + checkpoint_path = download_ckpt_from_hf() # Load checkpoint if provided if checkpoint_path is not None: _load_checkpoint(model, checkpoint_path) @@ -646,6 +637,15 @@ def build_sam3_image_model( return model +def download_ckpt_from_hf(): + SAM3_MODEL_ID = "facebook/sam3" + SAM3_CKPT_NAME = "sam3.pt" + SAM3_CFG_NAME = "config.json" + _ = hf_hub_download(repo_id=SAM3_MODEL_ID, filename=SAM3_CFG_NAME) + checkpoint_path = hf_hub_download(repo_id=SAM3_MODEL_ID, filename=SAM3_CKPT_NAME) + return checkpoint_path + + def build_sam3_video_model( checkpoint_path: Optional[str] = None, load_from_HF=True, @@ -768,9 +768,7 @@ def build_sam3_video_model( # Load checkpoint if provided if load_from_HF and checkpoint_path is None: - checkpoint_path = hf_hub_download( - repo_id=SAM3_MODEL_ID, filename=SAM3_CKPT_NAME - ) + checkpoint_path = download_ckpt_from_hf() if checkpoint_path is not None: with g_pathmgr.open(checkpoint_path, "rb") as f: ckpt = torch.load(f, map_location="cpu", weights_only=True)