Download config file from HF
Summary: This DIff downloads `config.json` from HF in addition to the checkpoint (`sam3.pt`). This is needed to correctly track downloads. Reviewed By: alcinos Differential Revision: D87439565 Privacy Context Container: L1256182 fbshipit-source-id: 611ddde3e2e3fc24c4a70a0f44e43315e88a5763
This commit is contained in:
committed by
meta-codesync[bot]
parent
a13e358df4
commit
e0e2968a17
@@ -5,10 +5,8 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
|
||||
from sam3.model.decoder import (
|
||||
TransformerDecoder,
|
||||
TransformerDecoderLayer,
|
||||
@@ -32,9 +30,7 @@ from sam3.model.model_misc import (
|
||||
)
|
||||
from sam3.model.necks import Sam3DualViTDetNeck
|
||||
from sam3.model.position_encoding import PositionEmbeddingSine
|
||||
|
||||
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
|
||||
|
||||
from sam3.model.sam3_image import Sam3Image, Sam3ImageOnVideoMultiGPU
|
||||
from sam3.model.sam3_tracking_predictor import Sam3TrackerPredictor
|
||||
from sam3.model.sam3_video_inference import Sam3VideoInferenceWithInstanceInteractivity
|
||||
@@ -45,9 +41,6 @@ from sam3.model.vitdet import ViT
|
||||
from sam3.model.vl_combiner import SAM3VLBackbone
|
||||
from sam3.sam.transformer import RoPEAttention
|
||||
|
||||
SAM3_MODEL_ID = "facebook/sam3"
|
||||
SAM3_CKPT_NAME = "sam3.pt"
|
||||
|
||||
|
||||
# Setup TensorFloat-32 for Ampere GPUs if available
|
||||
def _setup_tf32() -> None:
|
||||
@@ -633,9 +626,7 @@ def build_sam3_image_model(
|
||||
eval_mode,
|
||||
)
|
||||
if load_from_HF and checkpoint_path is None:
|
||||
checkpoint_path = hf_hub_download(
|
||||
repo_id=SAM3_MODEL_ID, filename=SAM3_CKPT_NAME
|
||||
)
|
||||
checkpoint_path = download_ckpt_from_hf()
|
||||
# Load checkpoint if provided
|
||||
if checkpoint_path is not None:
|
||||
_load_checkpoint(model, checkpoint_path)
|
||||
@@ -646,6 +637,15 @@ def build_sam3_image_model(
|
||||
return model
|
||||
|
||||
|
||||
def download_ckpt_from_hf():
|
||||
SAM3_MODEL_ID = "facebook/sam3"
|
||||
SAM3_CKPT_NAME = "sam3.pt"
|
||||
SAM3_CFG_NAME = "config.json"
|
||||
_ = hf_hub_download(repo_id=SAM3_MODEL_ID, filename=SAM3_CFG_NAME)
|
||||
checkpoint_path = hf_hub_download(repo_id=SAM3_MODEL_ID, filename=SAM3_CKPT_NAME)
|
||||
return checkpoint_path
|
||||
|
||||
|
||||
def build_sam3_video_model(
|
||||
checkpoint_path: Optional[str] = None,
|
||||
load_from_HF=True,
|
||||
@@ -768,9 +768,7 @@ def build_sam3_video_model(
|
||||
|
||||
# Load checkpoint if provided
|
||||
if load_from_HF and checkpoint_path is None:
|
||||
checkpoint_path = hf_hub_download(
|
||||
repo_id=SAM3_MODEL_ID, filename=SAM3_CKPT_NAME
|
||||
)
|
||||
checkpoint_path = download_ckpt_from_hf()
|
||||
if checkpoint_path is not None:
|
||||
with g_pathmgr.open(checkpoint_path, "rb") as f:
|
||||
ckpt = torch.load(f, map_location="cpu", weights_only=True)
|
||||
|
||||
Reference in New Issue
Block a user