Download config file from HF

Summary: This DIff downloads `config.json` from HF in addition to the checkpoint (`sam3.pt`). This is needed to correctly track downloads.

Reviewed By: alcinos

Differential Revision:
D87439565

Privacy Context Container: L1256182

fbshipit-source-id: 611ddde3e2e3fc24c4a70a0f44e43315e88a5763
This commit is contained in:
Haitham Khedr
2025-11-19 06:34:54 -08:00
committed by meta-codesync[bot]
parent a13e358df4
commit e0e2968a17

View File

@@ -5,10 +5,8 @@ from typing import Optional
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from iopath.common.file_io import g_pathmgr
from sam3.model.decoder import (
TransformerDecoder,
TransformerDecoderLayer,
@@ -32,9 +30,7 @@ from sam3.model.model_misc import (
)
from sam3.model.necks import Sam3DualViTDetNeck
from sam3.model.position_encoding import PositionEmbeddingSine
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
from sam3.model.sam3_image import Sam3Image, Sam3ImageOnVideoMultiGPU
from sam3.model.sam3_tracking_predictor import Sam3TrackerPredictor
from sam3.model.sam3_video_inference import Sam3VideoInferenceWithInstanceInteractivity
@@ -45,9 +41,6 @@ from sam3.model.vitdet import ViT
from sam3.model.vl_combiner import SAM3VLBackbone
from sam3.sam.transformer import RoPEAttention
SAM3_MODEL_ID = "facebook/sam3"
SAM3_CKPT_NAME = "sam3.pt"
# Setup TensorFloat-32 for Ampere GPUs if available
def _setup_tf32() -> None:
@@ -633,9 +626,7 @@ def build_sam3_image_model(
eval_mode,
)
if load_from_HF and checkpoint_path is None:
checkpoint_path = hf_hub_download(
repo_id=SAM3_MODEL_ID, filename=SAM3_CKPT_NAME
)
checkpoint_path = download_ckpt_from_hf()
# Load checkpoint if provided
if checkpoint_path is not None:
_load_checkpoint(model, checkpoint_path)
@@ -646,6 +637,15 @@ def build_sam3_image_model(
return model
def download_ckpt_from_hf():
SAM3_MODEL_ID = "facebook/sam3"
SAM3_CKPT_NAME = "sam3.pt"
SAM3_CFG_NAME = "config.json"
_ = hf_hub_download(repo_id=SAM3_MODEL_ID, filename=SAM3_CFG_NAME)
checkpoint_path = hf_hub_download(repo_id=SAM3_MODEL_ID, filename=SAM3_CKPT_NAME)
return checkpoint_path
def build_sam3_video_model(
checkpoint_path: Optional[str] = None,
load_from_HF=True,
@@ -768,9 +768,7 @@ def build_sam3_video_model(
# Load checkpoint if provided
if load_from_HF and checkpoint_path is None:
checkpoint_path = hf_hub_download(
repo_id=SAM3_MODEL_ID, filename=SAM3_CKPT_NAME
)
checkpoint_path = download_ckpt_from_hf()
if checkpoint_path is not None:
with g_pathmgr.open(checkpoint_path, "rb") as f:
ckpt = torch.load(f, map_location="cpu", weights_only=True)