Initial commit

fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
facebook-github-bot
2025-11-18 23:07:42 -08:00
commit a13e358df4
504 changed files with 122758 additions and 0 deletions

1
sam3/train/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

View File

@@ -0,0 +1,279 @@
# @package _global_
defaults:
- _self_
# This config is the base configuration for all evaluations. Amongst other things, it defines:
# - the model
# - the image transforms
# - the post processors
# - cluster configuration (only relevant for slurm-based evals, ignored otherwise)
#
# Most of the parameters should be kept as-is. The main modifications you may want to make are:
# - the cluster configuration, to adjust partitions/qos to your system
# - the flag gather_pred_via_filesys if your RAM is tight
# - num_val_workers if your number of cores is small (should be roughly number of cores / number of gpus)
# - the paths below
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
# If you leave the checkpoint path to null, the model will be downloaded from hugging-face. Otherwise provide a path
checkpoint_path: null
# the experiments will be subfolders of this
base_experiment_log_dir: <YOUR_EXPERIMENT_LOG_DIR>
# base path to the annotation folder for gold (refer to the readmes on how to download)
base_annotation_path: <YOUR_GOLD_GT_DIR>
# base path to the annotation folder for silver (refer to the readmes on how to download)
base_annotation_path_silver: <YOUR_SILVER_GT_DIR>
# path to the metaclip images, used for SA-Co gold (refer to the readme for instructions). Can be null if you don't intend on evaluating on this dataset.
metaclip_img_path: <YOUR_METACLIP_IMG_DIR>
# path to the sa1b images, used for SA-Co gold (refer to the readme for instructions). Can be null if you don't intend on evaluating on this dataset.
sa1b_img_path: <YOUR_SA1B_IMG_DIR>
# path to the SA-Co/silver images
silver_img_path: <YOUR_SILVER_IMG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
use_presence_eval: True
base_val_transform:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
######## transforms for validation (begin) ########
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution} # originally `resolution: 1024`
max_size:
_target_: sam3.train.transforms.basic.get_random_resize_max_size
size: ${scratch.resolution} # originally `resolution: 1024`
square: true
consistent_transform: False
######## transforms for validation (end) ########
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
loss: null
# Model parameters
d_model: 256
input_box_embedding_dim: ${add:${scratch.d_model},2}
# Box processing
original_box_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessImage
max_dets_per_img: -1 # infinite detections
use_original_ids: true
use_original_sizes_box: true
use_presence: ${scratch.use_presence_eval}
box_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessImage
max_dets_per_img: -1 #infinite detections
use_original_ids: false
use_original_sizes_box: false
use_presence: ${scratch.use_presence_eval}
box_postprocessor_thresholded:
_target_: sam3.eval.postprocessors.PostProcessImage
max_dets_per_img: -1 #infinite detections
use_original_ids: false
use_original_sizes_box: false
detection_threshold: 0.3
use_presence: ${scratch.use_presence_eval}
mask_postprocessor_thresholded:
_target_: sam3.eval.postprocessors.PostProcessImage
max_dets_per_img: -1 #infinite detections
iou_type: "segm"
use_original_ids: false
use_original_sizes_box: false
use_original_sizes_mask: true
convert_mask_to_rle: True
detection_threshold: 0.3
use_presence: ${scratch.use_presence_eval}
# Image processing parameters
resolution: 1008
max_ann_per_img: 200
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
# Training parameters
train_batch_size: 1
val_batch_size: 1
num_train_workers: 0
num_val_workers: 10 # change this depending on the number of cpu cores available
max_data_epochs: 20
target_epoch_size: 1500
hybrid_repeats: 1
context_length: 2
# All reduce - this controls how the predictions are sent back to node 0.
# If you have a lot of ram, CPU gather is faster. Otherwise, we provide a fallback through filesystem (eg NFS)
# Switch to true if you get cpu ooms during gather.
gather_pred_via_filesys: false
# Learning rate and scheduler parameters (unused for eval)
lr_scale: 0.1
lr_transformer: ${times:8e-4,${scratch.lr_scale}}
lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}}
lr_language_backbone: ${times:5e-5,${scratch.lr_scale}}
lrd_vision_backbone: 0.9 # (lower for in-domain and higher for OOD)
wd: 0.1
scheduler_timescale: 20
scheduler_warmup: 20
scheduler_cooldown: 20
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: ${scratch.max_data_epochs}
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all:
_target_: sam3.train.loss.sam3_loss.DummyLoss
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train: null
val: null
model:
_target_: sam3.model_builder.build_sam3_image_model
bpe_path: ${paths.bpe_path}
device: cpus # NOTE(review): "cpus" looks like a typo for "cpu" — confirm against build_sam3_image_model
eval_mode: true
enable_segmentation: true # Warning: Enable this if using segmentation.
checkpoint_path: ${paths.checkpoint_path}
meters:
val: null
optim:
amp:
enabled: True
amp_dtype: bfloat16
optimizer:
_target_: torch.optim.AdamW
gradient_clip:
_target_: sam3.train.optim.optimizer.GradientClipper
max_norm: 0.1
norm_type: 2
param_group_modifiers:
- _target_: sam3.train.optim.optimizer.layer_decay_param_modifier
_partial_: True
layer_decay_value: ${scratch.lrd_vision_backbone}
apply_to: 'backbone.vision_backbone.trunk'
overrides:
- pattern: '*pos_embed*'
value: 1.0
options:
lr:
- scheduler: # transformer and class_embed
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
base_lr: ${scratch.lr_transformer}
timescale: ${scratch.scheduler_timescale}
warmup_steps: ${scratch.scheduler_warmup}
cooldown_steps: ${scratch.scheduler_cooldown}
- scheduler:
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
base_lr: ${scratch.lr_vision_backbone}
timescale: ${scratch.scheduler_timescale}
warmup_steps: ${scratch.scheduler_warmup}
cooldown_steps: ${scratch.scheduler_cooldown}
param_names:
- 'backbone.vision_backbone.*'
- scheduler:
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
base_lr: ${scratch.lr_language_backbone}
timescale: ${scratch.scheduler_timescale}
warmup_steps: ${scratch.scheduler_warmup}
cooldown_steps: ${scratch.scheduler_cooldown}
param_names:
- 'backbone.language_backbone.*'
weight_decay:
- scheduler:
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
value: ${scratch.wd}
- scheduler:
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
value: 0.0
param_names:
- '*bias*'
module_cls_names: ['torch.nn.LayerNorm']
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 only last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 4
gpus_per_node: 8
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null # Add your SLURM account if use_cluster == 1
partition: null
qos: null # Add your QoS if use_cluster == 1
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null

View File

@@ -0,0 +1,66 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/gold_attributes/
coco_gt: ${paths.base_annotation_path}/gold_attributes_merged_a_release_test.json
coco_gts:
- ${paths.base_annotation_path}/gold_attributes_merged_a_release_test.json
- ${paths.base_annotation_path}/gold_attributes_merged_b_release_test.json
- ${paths.base_annotation_path}/gold_attributes_merged_c_release_test.json
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.metaclip_img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: gold_attributes
meters:
val:
gold_attributes: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/gold_attributes
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gts}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gts}
iou_type: "segm"

View File

@@ -0,0 +1,66 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/gold_crowded/
coco_gt: ${paths.base_annotation_path}/gold_crowded_merged_a_release_test.json
coco_gts:
- ${paths.base_annotation_path}/gold_crowded_merged_a_release_test.json
- ${paths.base_annotation_path}/gold_crowded_merged_b_release_test.json
- ${paths.base_annotation_path}/gold_crowded_merged_c_release_test.json
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.metaclip_img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: gold_crowded
meters:
val:
gold_crowded: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/gold_crowded
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gts}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gts}
iou_type: "segm"

View File

@@ -0,0 +1,66 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/gold_fg_food/
coco_gt: ${paths.base_annotation_path}/gold_fg_food_merged_a_release_test.json
coco_gts:
- ${paths.base_annotation_path}/gold_fg_food_merged_a_release_test.json
- ${paths.base_annotation_path}/gold_fg_food_merged_b_release_test.json
- ${paths.base_annotation_path}/gold_fg_food_merged_c_release_test.json
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.metaclip_img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: gold_fg_food
meters:
val:
gold_fg_food: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/gold_fg_food
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gts}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gts}
iou_type: "segm"

View File

@@ -0,0 +1,66 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/gold_fg_sports_equipment/
coco_gt: ${paths.base_annotation_path}/gold_fg_sports_equipment_merged_a_release_test.json
coco_gts:
- ${paths.base_annotation_path}/gold_fg_sports_equipment_merged_a_release_test.json
- ${paths.base_annotation_path}/gold_fg_sports_equipment_merged_b_release_test.json
- ${paths.base_annotation_path}/gold_fg_sports_equipment_merged_c_release_test.json
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.metaclip_img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: gold_fg_sports_equipment
meters:
val:
gold_fg_sports_equipment: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/gold_fg_sports_equipment
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gts}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gts}
iou_type: "segm"

View File

@@ -0,0 +1,66 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/gold_metaclip_nps/
coco_gt: ${paths.base_annotation_path}/gold_metaclip_merged_a_release_test.json
coco_gts:
- ${paths.base_annotation_path}/gold_metaclip_merged_a_release_test.json
- ${paths.base_annotation_path}/gold_metaclip_merged_b_release_test.json
- ${paths.base_annotation_path}/gold_metaclip_merged_c_release_test.json
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.metaclip_img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: gold_metaclip_nps
meters:
val:
gold_metaclip_nps: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/gold_metaclip_nps
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gts}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gts}
iou_type: "segm"

View File

@@ -0,0 +1,66 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/gold_sa1b_nps/
coco_gt: ${paths.base_annotation_path}/gold_sa1b_merged_a_release_test.json
coco_gts:
- ${paths.base_annotation_path}/gold_sa1b_merged_a_release_test.json
- ${paths.base_annotation_path}/gold_sa1b_merged_b_release_test.json
- ${paths.base_annotation_path}/gold_sa1b_merged_c_release_test.json
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.sa1b_img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: gold_sa1b_nps
meters:
val:
gold_sa1b_nps: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/gold_sa1b_nps
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gts}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gts}
iou_type: "segm"

View File

@@ -0,0 +1,66 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/gold_wiki_common/
coco_gt: ${paths.base_annotation_path}/gold_wiki_common_merged_a_release_test.json
coco_gts:
- ${paths.base_annotation_path}/gold_wiki_common_merged_a_release_test.json
- ${paths.base_annotation_path}/gold_wiki_common_merged_b_release_test.json
- ${paths.base_annotation_path}/gold_wiki_common_merged_c_release_test.json
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.metaclip_img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: gold_wiki_common
meters:
val:
gold_wiki_common: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/gold_wiki_common
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gts}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gts}
iou_type: "segm"

View File

@@ -0,0 +1,255 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
# python sam3/train/train.py -c configs/odinw_text_only.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS}
paths:
odinw_data_root: <YOUR_DATA_DIR>
experiment_log_dir: <YOUR_EXPERIMENT_LOG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}}
# Validation transforms pipeline
val_transforms:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution}
max_size:
_target_: sam3.train.transforms.basic.get_random_resize_max_size
size: ${scratch.resolution}
square: true
consistent_transform: False
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
- _target_: sam3.train.transforms.filter_query_transforms.TextQueryToVisual
keep_text_queries: true # Note: set this to false if you only want visual
probability: 1.0 # always
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
enable_segmentation: True
# Box processing
use_presence_eval: True
original_box_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessImage
max_dets_per_img: -1 # infinite detections
use_original_ids: true
use_original_sizes_box: true
use_presence: ${scratch.use_presence_eval}
# Image processing parameters
resolution: 1008
# Normalization parameters
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
# Training parameters
val_batch_size: 2
num_val_workers: 0
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
max_epochs: 1
accelerator: cuda
seed_value: 123
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
prompts: ${odinw35_prompts.${supercategory_tuple.name}}
include_negatives: true
category_chunk_size: 20 # Note: Since we are doing AP +ve we need to include all categories!
_partial_: true
img_folder: ${paths.odinw_data_root}/${supercategory_tuple.val.img_folder}
ann_file:
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
transforms: ${val_transforms}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: 1
dict_key: odinw35
model:
_target_: sam3.model_builder.build_sam3_image_model
bpe_path: ${paths.bpe_path}
device: cpus # NOTE(review): "cpus" looks like a typo for "cpu" — confirm against build_sam3_image_model
eval_mode: true # Set to false if training
enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
meters:
val:
odinw35:
detection:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "bbox"
dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${supercategory_tuple.name}
merge_predictions: True
postprocessor: ${scratch.original_box_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 100
pred_file_evaluators:
- _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
gt_path:
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
tide: False
iou_type: "bbox"
positive_split: true
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 only last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/${supercategory_tuple.name}
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 1
gpus_per_node: 2
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null
job_array:
num_tasks: 13
task_index: 0
# ============================================================================
# ODinW13 Supercategories
# ============================================================================
all_odinw_supercategories:
- name: AerialMaritimeDrone_large
val:
img_folder: AerialMaritimeDrone/large/test/
json: AerialMaritimeDrone/large/test/annotations_without_background.json
- name: Aquarium
val:
img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/
json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json
- name: CottontailRabbits
val:
img_folder: CottontailRabbits/test/
json: CottontailRabbits/test/annotations_without_background.json
- name: EgoHands_generic
val:
img_folder: EgoHands/generic/test/
json: EgoHands/generic/test/annotations_without_background.json
- name: NorthAmericaMushrooms
val:
img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/
json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json
- name: Packages
val:
img_folder: Packages/Raw/test/
json: Packages/Raw/test/annotations_without_background.json
- name: PascalVOC
val:
img_folder: PascalVOC/valid/
json: PascalVOC/valid/annotations_without_background.json
- name: Raccoon
val:
img_folder: Raccoon/Raccoon.v2-raw.coco/test/
json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json
- name: ShellfishOpenImages
val:
img_folder: ShellfishOpenImages/raw/test/
json: ShellfishOpenImages/raw/test/annotations_without_background.json
- name: VehiclesOpenImages
val:
img_folder: VehiclesOpenImages/416x416/test/
json: VehiclesOpenImages/416x416/test/annotations_without_background.json
- name: pistols
val:
img_folder: pistols/export/
json: pistols/export/test_annotations_without_background.json
- name: pothole
val:
img_folder: pothole/test/
json: pothole/test/annotations_without_background.json
- name: thermalDogsAndPeople
val:
img_folder: thermalDogsAndPeople/test/
json: thermalDogsAndPeople/test/annotations_without_background.json
odinw35_prompts:
AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"},
{"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock",
"supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"},
{"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]'
Aquarium: null
CottontailRabbits: null
EgoHands_generic: null
NorthAmericaMushrooms: '[{''id'': 1, ''name'':
''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]'
Packages: null
PascalVOC: null
Raccoon: null
ShellfishOpenImages: null
VehiclesOpenImages: null
pistols: null
pothole: null
thermalDogsAndPeople: null

View File

@@ -0,0 +1,253 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
# python sam3/train/train.py -c configs/odinw_text_only.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS}
paths:
odinw_data_root: <YOUR_DATA_DIR>
  experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}}
# Validation transforms pipeline
val_transforms:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution}
max_size:
_target_: sam3.train.transforms.basic.get_random_resize_max_size
size: ${scratch.resolution}
square: true
consistent_transform: False
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
enable_segmentation: True
# Box processing
use_presence_eval: True
original_box_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessImage
max_dets_per_img: -1 # infinite detections
use_original_ids: true
use_original_sizes_box: true
use_presence: ${scratch.use_presence_eval}
# Image processing parameters
resolution: 1008
# Normalization parameters
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
# Training parameters
val_batch_size: 2
num_val_workers: 0
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
max_epochs: 1
accelerator: cuda
seed_value: 123
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
prompts: ${odinw35_prompts.${supercategory_tuple.name}}
include_negatives: true
category_chunk_size: 20 # Note: Since we are doing AP +ve we need to include all categories!
_partial_: true
img_folder: ${paths.odinw_data_root}/${supercategory_tuple.val.img_folder}
ann_file:
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
transforms: ${val_transforms}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: 1
dict_key: odinw35
model:
_target_: sam3.model_builder.build_sam3_image_model
bpe_path: ${paths.bpe_path}
device: cpus
eval_mode: true # Set to false if training
enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
meters:
val:
odinw35:
detection:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "bbox"
dump_dir: ${launcher.experiment_log_dir}/dumps/odinw/${supercategory_tuple.name}
merge_predictions: True
postprocessor: ${scratch.original_box_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 100
pred_file_evaluators:
- _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
gt_path:
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
tide: False
iou_type: "bbox"
positive_split: False
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
  save_freq: 0 # 0 means only the last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/${supercategory_tuple.name}
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 1
gpus_per_node: 2
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null
job_array:
num_tasks: 13
task_index: 0
# ============================================================================
# ODinW13 Supercategories
# ============================================================================
all_odinw_supercategories:
- name: AerialMaritimeDrone_large
val:
img_folder: AerialMaritimeDrone/large/test/
json: AerialMaritimeDrone/large/test/annotations_without_background.json
- name: Aquarium
val:
img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/
json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json
- name: CottontailRabbits
val:
img_folder: CottontailRabbits/test/
json: CottontailRabbits/test/annotations_without_background.json
- name: EgoHands_generic
val:
img_folder: EgoHands/generic/test/
json: EgoHands/generic/test/annotations_without_background.json
- name: NorthAmericaMushrooms
val:
img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/
json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json
- name: Packages
val:
img_folder: Packages/Raw/test/
json: Packages/Raw/test/annotations_without_background.json
- name: PascalVOC
val:
img_folder: PascalVOC/valid/
json: PascalVOC/valid/annotations_without_background.json
- name: Raccoon
val:
img_folder: Raccoon/Raccoon.v2-raw.coco/test/
json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json
- name: ShellfishOpenImages
val:
img_folder: ShellfishOpenImages/raw/test/
json: ShellfishOpenImages/raw/test/annotations_without_background.json
- name: VehiclesOpenImages
val:
img_folder: VehiclesOpenImages/416x416/test/
json: VehiclesOpenImages/416x416/test/annotations_without_background.json
- name: pistols
val:
img_folder: pistols/export/
json: pistols/export/test_annotations_without_background.json
- name: pothole
val:
img_folder: pothole/test/
json: pothole/test/annotations_without_background.json
- name: thermalDogsAndPeople
val:
img_folder: thermalDogsAndPeople/test/
json: thermalDogsAndPeople/test/annotations_without_background.json
odinw35_prompts:
AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"},
{"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock",
"supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"},
{"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]'
Aquarium: null
CottontailRabbits: null
EgoHands_generic: null
NorthAmericaMushrooms: '[{''id'': 1, ''name'':
''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]'
Packages: null
PascalVOC: null
Raccoon: null
ShellfishOpenImages: null
VehiclesOpenImages: null
pistols: null
pothole: null
thermalDogsAndPeople: null

View File

@@ -0,0 +1,253 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
# python sam3/train/train.py -c configs/odinw_text_only.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS}
paths:
odinw_data_root: <YOUR_DATA_DIR>
  experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}}
# Validation transforms pipeline
val_transforms:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution}
max_size:
_target_: sam3.train.transforms.basic.get_random_resize_max_size
size: ${scratch.resolution}
square: true
consistent_transform: False
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
enable_segmentation: True
# Box processing
use_presence_eval: True
original_box_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessImage
max_dets_per_img: -1 # infinite detections
use_original_ids: true
use_original_sizes_box: true
use_presence: ${scratch.use_presence_eval}
# Image processing parameters
resolution: 1008
# Normalization parameters
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
# Training parameters
val_batch_size: 2
num_val_workers: 0
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
max_epochs: 1
accelerator: cuda
seed_value: 123
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
prompts: ${odinw35_prompts.${supercategory_tuple.name}}
include_negatives: true
category_chunk_size: 20 # Note: Since we are doing AP +ve we need to include all categories!
_partial_: true
img_folder: ${paths.odinw_data_root}/${supercategory_tuple.val.img_folder}
ann_file:
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
transforms: ${val_transforms}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: 1
dict_key: odinw35
model:
_target_: sam3.model_builder.build_sam3_image_model
bpe_path: ${paths.bpe_path}
device: cpus
eval_mode: true # Set to false if training
enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
meters:
val:
odinw35:
detection:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "bbox"
dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${supercategory_tuple.name}
merge_predictions: True
postprocessor: ${scratch.original_box_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 100
pred_file_evaluators:
- _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
gt_path:
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
tide: False
iou_type: "bbox"
positive_split: true
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
  save_freq: 0 # 0 means only the last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/${supercategory_tuple.name}
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 1
gpus_per_node: 2
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null
job_array:
num_tasks: 13
task_index: 0
# ============================================================================
# ODinW13 Supercategories
# ============================================================================
all_odinw_supercategories:
- name: AerialMaritimeDrone_large
val:
img_folder: AerialMaritimeDrone/large/test/
json: AerialMaritimeDrone/large/test/annotations_without_background.json
- name: Aquarium
val:
img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/
json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json
- name: CottontailRabbits
val:
img_folder: CottontailRabbits/test/
json: CottontailRabbits/test/annotations_without_background.json
- name: EgoHands_generic
val:
img_folder: EgoHands/generic/test/
json: EgoHands/generic/test/annotations_without_background.json
- name: NorthAmericaMushrooms
val:
img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/
json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json
- name: Packages
val:
img_folder: Packages/Raw/test/
json: Packages/Raw/test/annotations_without_background.json
- name: PascalVOC
val:
img_folder: PascalVOC/valid/
json: PascalVOC/valid/annotations_without_background.json
- name: Raccoon
val:
img_folder: Raccoon/Raccoon.v2-raw.coco/test/
json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json
- name: ShellfishOpenImages
val:
img_folder: ShellfishOpenImages/raw/test/
json: ShellfishOpenImages/raw/test/annotations_without_background.json
- name: VehiclesOpenImages
val:
img_folder: VehiclesOpenImages/416x416/test/
json: VehiclesOpenImages/416x416/test/annotations_without_background.json
- name: pistols
val:
img_folder: pistols/export/
json: pistols/export/test_annotations_without_background.json
- name: pothole
val:
img_folder: pothole/test/
json: pothole/test/annotations_without_background.json
- name: thermalDogsAndPeople
val:
img_folder: thermalDogsAndPeople/test/
json: thermalDogsAndPeople/test/annotations_without_background.json
odinw35_prompts:
AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"},
{"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock",
"supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"},
{"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]'
Aquarium: null
CottontailRabbits: null
EgoHands_generic: null
NorthAmericaMushrooms: '[{''id'': 1, ''name'':
''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]'
Packages: null
PascalVOC: null
Raccoon: null
ShellfishOpenImages: null
VehiclesOpenImages: null
pistols: null
pothole: null
thermalDogsAndPeople: null

View File

@@ -0,0 +1,591 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
# python sam3/train/train.py -c configs/odinw_text_only.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS}
paths:
odinw_data_root: <YOUR_DATA_DIR>
  experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
odinw_train:
train_file: fewshot_train_shot10_seed300
num_images: null
supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}}
# Training transforms pipeline
train_transforms:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterCrowds
- _target_: sam3.train.transforms.point_sampling.RandomizeInputBbox
box_noise_std: 0.1
box_noise_max: 20
- _target_: sam3.train.transforms.segmentation.DecodeRle
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes:
_target_: sam3.train.transforms.basic.get_random_resize_scales
size: ${scratch.resolution}
min_size: 480
rounded: false
max_size:
_target_: sam3.train.transforms.basic.get_random_resize_max_size
size: ${scratch.resolution}
square: true
consistent_transform: ${scratch.consistent_transform}
- _target_: sam3.train.transforms.basic_for_api.PadToSizeAPI
size: ${scratch.resolution}
consistent_transform: ${scratch.consistent_transform}
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.train_norm_mean}
std: ${scratch.train_norm_std}
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterFindQueriesWithTooManyOut
max_num_objects: ${scratch.max_ann_per_img}
# Validation transforms pipeline
val_transforms:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution}
max_size:
_target_: sam3.train.transforms.basic.get_random_resize_max_size
size: ${scratch.resolution}
square: true
consistent_transform: False
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
# loss config (no mask loss)
loss:
_target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
matcher: ${scratch.matcher}
o2m_weight: 2.0
o2m_matcher:
_target_: sam3.train.matcher.BinaryOneToManyMatcher
alpha: 0.3
threshold: 0.4
topk: 4
use_o2m_matcher_on_o2m_aux: ${scratch.use_o2m_matcher_on_o2m_aux}
loss_fns_find:
- _target_: sam3.train.loss.loss_fns.Boxes
weight_dict:
loss_bbox: 5.0
loss_giou: 2.0
- _target_: sam3.train.loss.loss_fns.IABCEMdetr
weak_loss: False
weight_dict:
loss_ce: ${scratch.loss_ce_weight} # Change
presence_loss: ${scratch.presence_weight} # Change
pos_weight: ${scratch.iabce_pos_weight}
alpha: ${scratch.iabce_alpha}
gamma: 2
use_presence: True # Change
pos_focal: ${scratch.iabce_pos_focal}
pad_n_queries: ${scratch.num_queries}
pad_scale_pos: ${scratch.instance_query_loss_pad_scale_pos}
loss_fn_semantic_seg: null
scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
enable_segmentation: False
use_act_checkpoint_geo_encoder: True
input_geometry_encoder:
_target_: sam3.model.geometry_encoders.SequenceGeometryEncoder
pos_enc: ${scratch.pos_embed}
encode_boxes_as_points: False
points_direct_project: True
points_pool: True
points_pos_enc: True
boxes_direct_project: True
boxes_pool: True
boxes_pos_enc: True
d_model: ${scratch.d_model}
num_layers: 3
use_act_ckpt: ${scratch.use_act_checkpoint_geo_encoder}
layer:
_target_: sam3.model.encoder.TransformerEncoderLayer
activation: "relu"
d_model: ${scratch.d_model}
dim_feedforward: 2048
dropout: ${scratch.encoder_dropout}
pos_enc_at_attn: false
pre_norm: True
pos_enc_at_cross_attn_queries: false
pos_enc_at_cross_attn_keys: true
self_attention:
_target_: sam3.model.attention.MultiheadAttention
attn_type: Vanilla
num_heads: 8
dropout: ${scratch.encoder_dropout}
embed_dim: ${scratch.d_model}
batch_first: False
cross_attention:
_target_: sam3.model.attention.MultiheadAttention
attn_type: Vanilla
num_heads: 8
dropout: ${scratch.encoder_dropout}
embed_dim: ${scratch.d_model}
batch_first: False
add_cls: true
add_post_encode_proj: True
boxRPB: "log"
dac: True
use_early_fusion: true
o2m_mask: false
num_feature_levels: 1 # > 1 not implemented
encoder_dropout: 0.1
decoder_dropout: 0.1
tokenizer_ve:
_target_: sam3.model.tokenizer_ve.SimpleTokenizer
bpe_path: ${paths.bpe_path}
freeze_text_tower: False
freeze_image_tower: NoFreeze
vis_backbone_dp: 0.0
# Activation checkpointing (Save memory)
use_act_checkpoint_vision_backbone: True
use_act_checkpoint_text_backbone: True
use_act_checkpoint_encoder: True
use_act_checkpoint_decoder: True
loss: null
# Loss parameters
num_queries: 200
presence_weight: 20.0
loss_ce_weight: 20.0
iabce_pos_weight: 5.0
iabce_pos_focal: false
iabce_alpha: 0.25
instance_query_loss_pad_scale_pos: 1.0
use_o2m_matcher_on_o2m_aux: false
# Model parameters
use_instance_query: true
d_model: 256
pos_embed:
_target_: sam3.model.position_encoding.PositionEmbeddingSine
num_pos_feats: ${scratch.d_model}
normalize: true
scale: null
temperature: 10000
# Box processing
use_presence_eval: True
original_box_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessImage
max_dets_per_img: -1 # infinite detections
use_original_ids: true
use_original_sizes_box: true
use_presence: ${scratch.use_presence_eval}
# Matcher configuration
matcher:
_target_: sam3.train.matcher.BinaryHungarianMatcherV2
focal: true
cost_class: 2.0
cost_bbox: 5.0
cost_giou: 2.0
alpha: 0.25
gamma: 2
stable: False
scale_by_find_batch_size: True
# Image processing parameters
resolution: 1008
consistent_transform: False
max_ann_per_img: 200
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
# Training parameters
train_batch_size: 1
val_batch_size: 1
num_train_workers: 0
num_val_workers: 0
max_data_epochs: 40
target_epoch_size: 1500
hybrid_repeats: 1
context_length: 2
gather_pred_via_filesys: false
# Learning rate and scheduler parameters
lr_scale: 0.1
lr_transformer: ${times:8e-4,${scratch.lr_scale}}
lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}}
lr_language_backbone: ${times:5e-5,${scratch.lr_scale}}
lrd_vision_backbone: 0.9
wd: 0.1
scheduler_timescale: 20
scheduler_warmup: 20
scheduler_cooldown: 20
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
# _target_: sam3.train.trainer.Trainer
# skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: ${scratch.max_data_epochs}
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: train
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all: ${odinw_train.loss}
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
limit_ids: ${odinw_train.num_images}
transforms: ${odinw_train.train_transforms}
load_segmentation: ${scratch.enable_segmentation}
max_ann_per_img: 500000
multiplier: 1
max_train_queries: 50000
max_val_queries: 50000
training: true
use_caching: False
img_folder: ${paths.odinw_data_root}/${odinw_train.supercategory_tuple.train.img_folder}
ann_file:
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
input_json_path: ${paths.odinw_data_root}/${odinw_train.supercategory_tuple.train.json}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
      prompts: ${odinw35_prompts.${odinw_train.supercategory_tuple.name}} #${odinw_train.supercategory_tuple.name}
_partial_: true
shuffle: True
batch_size: ${scratch.train_batch_size}
num_workers: ${scratch.num_train_workers}
pin_memory: False
drop_last: True
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: all
with_seg_masks: ${scratch.enable_segmentation}
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
load_segmentation: ${scratch.enable_segmentation}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
prompts: ${odinw35_prompts.${odinw_train.supercategory_tuple.name}}
include_negatives: true
category_chunk_size: 20 # Note: Since we are doing AP +ve we need to include all categories!
_partial_: true
img_folder: ${paths.odinw_data_root}/${odinw_train.supercategory_tuple.val.img_folder}
ann_file:
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
input_json_path: ${paths.odinw_data_root}/${odinw_train.supercategory_tuple.val.json}
transforms: ${odinw_train.val_transforms}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: 1
dict_key: odinw35
with_seg_masks: ${scratch.enable_segmentation}
model:
_target_: sam3.model_builder.build_sam3_image_model
bpe_path: ${paths.bpe_path}
device: cpus
eval_mode: false # Set to false if training
enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
meters:
val:
odinw35:
detection:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "bbox"
dump_dir: ${launcher.experiment_log_dir}/dumps/odinw/${odinw_train.supercategory_tuple.name}
merge_predictions: True
postprocessor: ${scratch.original_box_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 100
pred_file_evaluators:
- _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
gt_path:
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
input_json_path: ${paths.odinw_data_root}/${odinw_train.supercategory_tuple.val.json}
tide: False
iou_type: "bbox"
positive_split: False
optim:
amp:
enabled: True
amp_dtype: bfloat16
optimizer:
_target_: torch.optim.AdamW
gradient_clip:
_target_: sam3.train.optim.optimizer.GradientClipper
max_norm: 0.1
norm_type: 2
param_group_modifiers:
- _target_: sam3.train.optim.optimizer.layer_decay_param_modifier
_partial_: True
layer_decay_value: ${scratch.lrd_vision_backbone}
apply_to: 'backbone.vision_backbone.trunk'
overrides:
- pattern: '*pos_embed*'
value: 1.0
options:
lr:
- scheduler: # transformer and class_embed
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
base_lr: ${scratch.lr_transformer}
timescale: ${scratch.scheduler_timescale}
warmup_steps: ${scratch.scheduler_warmup}
cooldown_steps: ${scratch.scheduler_cooldown}
- scheduler:
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
base_lr: ${scratch.lr_vision_backbone}
timescale: ${scratch.scheduler_timescale}
warmup_steps: ${scratch.scheduler_warmup}
cooldown_steps: ${scratch.scheduler_cooldown}
param_names:
- 'backbone.vision_backbone.*'
- scheduler:
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
base_lr: ${scratch.lr_language_backbone}
timescale: ${scratch.scheduler_timescale}
warmup_steps: ${scratch.scheduler_warmup}
cooldown_steps: ${scratch.scheduler_cooldown}
param_names:
- 'backbone.language_backbone.*'
weight_decay:
- scheduler:
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
value: ${scratch.wd}
- scheduler:
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
value: 0.0
param_names:
- '*bias*'
module_cls_names: ['torch.nn.LayerNorm']
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
  save_freq: 0 # 0 means only the last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/${odinw_train.supercategory_tuple.name}
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 1
gpus_per_node: 2
experiment_log_dir: null #${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null
# task_index: 2
# Uncomment for job array configuration
job_array:
num_tasks: 13
task_index: 0
# ============================================================================
# ODinW13 Supercategories
# ============================================================================
all_odinw_supercategories:
- name: AerialMaritimeDrone_large
val:
img_folder: AerialMaritimeDrone/large/test/
json: AerialMaritimeDrone/large/test/annotations_without_background.json
train:
img_folder: AerialMaritimeDrone/large/train/
json: AerialMaritimeDrone/large/train/${odinw_train.train_file}.json
- name: Aquarium
val:
img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/
json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json
train:
img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/train/
json: Aquarium/Aquarium Combined.v2-raw-1024.coco/train/${odinw_train.train_file}.json
- name: CottontailRabbits
val:
img_folder: CottontailRabbits/test/
json: CottontailRabbits/test/annotations_without_background.json
train:
img_folder: CottontailRabbits/train/
json: CottontailRabbits/train/${odinw_train.train_file}.json
- name: EgoHands_generic
val:
img_folder: EgoHands/generic/test/
json: EgoHands/generic/test/annotations_without_background.json
train:
img_folder: EgoHands/generic/train/
json: EgoHands/generic/train/${odinw_train.train_file}.json
- name: NorthAmericaMushrooms
val:
img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/
json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json
train:
img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/train/
json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/train/${odinw_train.train_file}.json
- name: Packages
val:
img_folder: Packages/Raw/test/
json: Packages/Raw/test/annotations_without_background.json
train:
img_folder: Packages/Raw/train/
json: Packages/Raw/train/${odinw_train.train_file}.json
- name: PascalVOC
val:
img_folder: PascalVOC/valid/
json: PascalVOC/valid/annotations_without_background.json
train:
img_folder: PascalVOC/train/
json: PascalVOC/train/${odinw_train.train_file}.json
- name: Raccoon
val:
img_folder: Raccoon/Raccoon.v2-raw.coco/test/
json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json
train:
img_folder: Raccoon/Raccoon.v2-raw.coco/train/
json: Raccoon/Raccoon.v2-raw.coco/train/${odinw_train.train_file}.json
- name: ShellfishOpenImages
val:
img_folder: ShellfishOpenImages/raw/test/
json: ShellfishOpenImages/raw/test/annotations_without_background.json
train:
img_folder: ShellfishOpenImages/raw/train/
json: ShellfishOpenImages/raw/train/${odinw_train.train_file}.json
- name: VehiclesOpenImages
val:
img_folder: VehiclesOpenImages/416x416/test/
json: VehiclesOpenImages/416x416/test/annotations_without_background.json
train:
img_folder: VehiclesOpenImages/416x416/train/
json: VehiclesOpenImages/416x416/train/${odinw_train.train_file}.json
- name: pistols
val:
img_folder: pistols/export/
json: pistols/export/test_annotations_without_background.json
train:
img_folder: pistols/export/
json: pistols/export/${odinw_train.train_file}.json
- name: pothole
val:
img_folder: pothole/test/
json: pothole/test/annotations_without_background.json
train:
img_folder: pothole/train/
json: pothole/train/${odinw_train.train_file}.json
- name: thermalDogsAndPeople
val:
img_folder: thermalDogsAndPeople/test/
json: thermalDogsAndPeople/test/annotations_without_background.json
train:
img_folder: thermalDogsAndPeople/train/
json: thermalDogsAndPeople/train/${odinw_train.train_file}.json
odinw35_prompts:
AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"},
{"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock",
"supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"},
{"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]'
Aquarium: null
CottontailRabbits: null
EgoHands_generic: null
NorthAmericaMushrooms: '[{''id'': 1, ''name'':
''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]'
Packages: null
PascalVOC: null
Raccoon: null
ShellfishOpenImages: null
VehiclesOpenImages: null
pistols: null
pothole: null
thermalDogsAndPeople: null

View File

@@ -0,0 +1,256 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
# python sam3/train/train.py -c configs/odinw_text_only.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS}
paths:
odinw_data_root: <YOUR_DATA_DIR>
  experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}}
# Validation transforms pipeline
val_transforms:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution}
max_size:
_target_: sam3.train.transforms.basic.get_random_resize_max_size
size: ${scratch.resolution}
square: true
consistent_transform: False
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
- _target_: sam3.train.transforms.filter_query_transforms.TextQueryToVisual
keep_text_queries: false # Note: set this to false if you only want visual
probability: 1.0 # always
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
enable_segmentation: True
# Box processing
use_presence_eval: True
original_box_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessImage
max_dets_per_img: -1 # infinite detections
use_original_ids: true
use_original_sizes_box: true
use_presence: ${scratch.use_presence_eval}
# Image processing parameters
resolution: 1008
# Normalization parameters
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
# Training parameters
val_batch_size: 2
num_val_workers: 0
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
max_epochs: 1
accelerator: cuda
seed_value: 123
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
prompts: ${odinw35_prompts.${supercategory_tuple.name}}
include_negatives: true
category_chunk_size: 20 # Note: Since we are doing AP +ve we need to include all categories!
_partial_: true
img_folder: ${paths.odinw_data_root}/${supercategory_tuple.val.img_folder}
ann_file:
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
transforms: ${val_transforms}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: 1
dict_key: odinw35
model:
_target_: sam3.model_builder.build_sam3_image_model
bpe_path: ${paths.bpe_path}
device: cpus
eval_mode: true # Set to false if training
enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
meters:
val:
odinw35:
detection:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "bbox"
dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${supercategory_tuple.name}
merge_predictions: True
postprocessor: ${scratch.original_box_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 100
pred_file_evaluators:
- _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
gt_path:
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
tide: False
iou_type: "bbox"
positive_split: true
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 only last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/${supercategory_tuple.name}
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 1
gpus_per_node: 2
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null
job_array:
num_tasks: 13
task_index: 0
# ============================================================================
# ODinW13 Supercategories
# ============================================================================
all_odinw_supercategories:
- name: AerialMaritimeDrone_large
val:
img_folder: AerialMaritimeDrone/large/test/
json: AerialMaritimeDrone/large/test/annotations_without_background.json
- name: Aquarium
val:
img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/
json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json
- name: CottontailRabbits
val:
img_folder: CottontailRabbits/test/
json: CottontailRabbits/test/annotations_without_background.json
- name: EgoHands_generic
val:
img_folder: EgoHands/generic/test/
json: EgoHands/generic/test/annotations_without_background.json
- name: NorthAmericaMushrooms
val:
img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/
json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json
- name: Packages
val:
img_folder: Packages/Raw/test/
json: Packages/Raw/test/annotations_without_background.json
- name: PascalVOC
val:
img_folder: PascalVOC/valid/
json: PascalVOC/valid/annotations_without_background.json
- name: Raccoon
val:
img_folder: Raccoon/Raccoon.v2-raw.coco/test/
json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json
- name: ShellfishOpenImages
val:
img_folder: ShellfishOpenImages/raw/test/
json: ShellfishOpenImages/raw/test/annotations_without_background.json
- name: VehiclesOpenImages
val:
img_folder: VehiclesOpenImages/416x416/test/
json: VehiclesOpenImages/416x416/test/annotations_without_background.json
- name: pistols
val:
img_folder: pistols/export/
json: pistols/export/test_annotations_without_background.json
- name: pothole
val:
img_folder: pothole/test/
json: pothole/test/annotations_without_background.json
- name: thermalDogsAndPeople
val:
img_folder: thermalDogsAndPeople/test/
json: thermalDogsAndPeople/test/annotations_without_background.json
odinw35_prompts:
AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"},
{"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock",
"supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"},
{"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]'
Aquarium: null
CottontailRabbits: null
EgoHands_generic: null
NorthAmericaMushrooms: '[{''id'': 1, ''name'':
''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]'
Packages: null
PascalVOC: null
Raccoon: null
ShellfishOpenImages: null
VehiclesOpenImages: null
pistols: null
pothole: null
thermalDogsAndPeople: null

View File

@@ -0,0 +1,539 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
roboflow_vl_100_root: <YOUR_DATASET_DIR>
  experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
# Roboflow dataset configuration
roboflow_train:
num_images: 100 # Note: This is the number of images used for training. If null, all images are used.
supercategory: ${all_roboflow_supercategories.${string:${submitit.job_array.task_index}}}
# Training transforms pipeline
train_transforms:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterCrowds
- _target_: sam3.train.transforms.point_sampling.RandomizeInputBbox
box_noise_std: 0.1
box_noise_max: 20
- _target_: sam3.train.transforms.segmentation.DecodeRle
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes:
_target_: sam3.train.transforms.basic.get_random_resize_scales
size: ${scratch.resolution}
min_size: 480
rounded: false
max_size:
_target_: sam3.train.transforms.basic.get_random_resize_max_size
size: ${scratch.resolution}
square: true
consistent_transform: ${scratch.consistent_transform}
- _target_: sam3.train.transforms.basic_for_api.PadToSizeAPI
size: ${scratch.resolution}
consistent_transform: ${scratch.consistent_transform}
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.train_norm_mean}
std: ${scratch.train_norm_std}
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterFindQueriesWithTooManyOut
max_num_objects: ${scratch.max_ann_per_img}
# Validation transforms pipeline
val_transforms:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution}
max_size:
_target_: sam3.train.transforms.basic.get_random_resize_max_size
size: ${scratch.resolution}
square: true
consistent_transform: False
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.train_norm_mean}
std: ${scratch.train_norm_std}
# loss config (no mask loss)
loss:
_target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
matcher: ${scratch.matcher}
o2m_weight: 2.0
o2m_matcher:
_target_: sam3.train.matcher.BinaryOneToManyMatcher
alpha: 0.3
threshold: 0.4
topk: 4
use_o2m_matcher_on_o2m_aux: false # Another option is true
loss_fns_find:
- _target_: sam3.train.loss.loss_fns.Boxes
weight_dict:
loss_bbox: 5.0
loss_giou: 2.0
- _target_: sam3.train.loss.loss_fns.IABCEMdetr
weak_loss: False
weight_dict:
loss_ce: 20.0 # Another option is 100.0
presence_loss: 20.0
pos_weight: 10.0 # Another option is 5.0
alpha: 0.25
gamma: 2
use_presence: True # Change
pos_focal: false
pad_n_queries: 200
pad_scale_pos: 1.0
loss_fn_semantic_seg: null
scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}
# NOTE: Loss to be used for training in case of segmentation
# loss:
# _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
# matcher: ${scratch.matcher}
# o2m_weight: 2.0
# o2m_matcher:
# _target_: sam3.train.matcher.BinaryOneToManyMatcher
# alpha: 0.3
# threshold: 0.4
# topk: 4
# use_o2m_matcher_on_o2m_aux: false
# loss_fns_find:
# - _target_: sam3.train.loss.loss_fns.Boxes
# weight_dict:
# loss_bbox: 5.0
# loss_giou: 2.0
# - _target_: sam3.train.loss.loss_fns.IABCEMdetr
# weak_loss: False
# weight_dict:
# loss_ce: 20.0 # Another option is 100.0
# presence_loss: 20.0
# pos_weight: 10.0 # Another option is 5.0
# alpha: 0.25
# gamma: 2
# use_presence: True # Change
# pos_focal: false
# pad_n_queries: 200
# pad_scale_pos: 1.0
# - _target_: sam3.train.loss.loss_fns.Masks
# focal_alpha: 0.25
# focal_gamma: 2.0
# weight_dict:
# loss_mask: 200.0
# loss_dice: 10.0
# compute_aux: false
# loss_fn_semantic_seg:
# _target_: sam3.losses.loss_fns.SemanticSegCriterion
# presence_head: True
# presence_loss: False # Change
# focal: True
# focal_alpha: 0.6
# focal_gamma: 2.0
# downsample: False
# weight_dict:
# loss_semantic_seg: 20.0
# loss_semantic_presence: 1.0
# loss_semantic_dice: 30.0
# scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
  enable_segmentation: False # NOTE: Boolean flag; set to True to enable the segmentation head (and use the commented-out mask loss above).
# Model parameters
d_model: 256
pos_embed:
_target_: sam3.model.position_encoding.PositionEmbeddingSine
num_pos_feats: ${scratch.d_model}
normalize: true
scale: null
temperature: 10000
# Box processing
use_presence_eval: True
original_box_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessImage
max_dets_per_img: -1 # infinite detections
use_original_ids: true
use_original_sizes_box: true
use_presence: ${scratch.use_presence_eval}
# Matcher configuration
matcher:
_target_: sam3.train.matcher.BinaryHungarianMatcherV2
focal: true # with `focal: true` it is equivalent to BinaryFocalHungarianMatcher
cost_class: 2.0
cost_bbox: 5.0
cost_giou: 2.0
alpha: 0.25
gamma: 2
stable: False
scale_by_find_batch_size: True
# Image processing parameters
resolution: 1008
consistent_transform: False
max_ann_per_img: 200
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
# Training parameters
num_train_workers: 10
num_val_workers: 0
max_data_epochs: 20
target_epoch_size: 1500
hybrid_repeats: 1
context_length: 2
gather_pred_via_filesys: false
# Learning rate and scheduler parameters
lr_scale: 0.1
lr_transformer: ${times:8e-4,${scratch.lr_scale}}
lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}}
lr_language_backbone: ${times:5e-5,${scratch.lr_scale}}
lrd_vision_backbone: 0.9
wd: 0.1
scheduler_timescale: 20
scheduler_warmup: 20
scheduler_cooldown: 20
val_batch_size: 1
collate_fn_val:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: roboflow100
with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!
gradient_accumulation_steps: 1
train_batch_size: 1
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: all
with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: 20
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: val
gradient_accumulation_steps: ${scratch.gradient_accumulation_steps}
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all: ${roboflow_train.loss}
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
limit_ids: ${roboflow_train.num_images}
transforms: ${roboflow_train.train_transforms}
load_segmentation: ${scratch.enable_segmentation}
max_ann_per_img: 500000
multiplier: 1
max_train_queries: 50000
max_val_queries: 50000
training: true
use_caching: False
img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/
ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/_annotations.coco.json
shuffle: True
batch_size: ${scratch.train_batch_size}
num_workers: ${scratch.num_train_workers}
pin_memory: True
drop_last: True
collate_fn: ${scratch.collate_fn}
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
load_segmentation: ${scratch.enable_segmentation}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
include_negatives: true
category_chunk_size: 2 # Note: You can increase this based on the memory of your GPU.
_partial_: true
img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/
ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json
transforms: ${roboflow_train.val_transforms}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: True
drop_last: False
collate_fn: ${scratch.collate_fn_val}
model:
_target_: sam3.model_builder.build_sam3_image_model
bpe_path: ${paths.bpe_path}
device: cpus
eval_mode: true
enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
meters:
val:
roboflow100:
detection:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "bbox"
dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${roboflow_train.supercategory}
merge_predictions: True
postprocessor: ${scratch.original_box_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 100
pred_file_evaluators:
- _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
gt_path: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json
tide: False
iou_type: "bbox"
optim:
amp:
enabled: True
amp_dtype: bfloat16
optimizer:
_target_: torch.optim.AdamW
gradient_clip:
_target_: sam3.train.optim.optimizer.GradientClipper
max_norm: 0.1
norm_type: 2
param_group_modifiers:
- _target_: sam3.train.optim.optimizer.layer_decay_param_modifier
_partial_: True
layer_decay_value: ${scratch.lrd_vision_backbone}
apply_to: 'backbone.vision_backbone.trunk'
overrides:
- pattern: '*pos_embed*'
value: 1.0
options:
lr:
- scheduler: # transformer and class_embed
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
base_lr: ${scratch.lr_transformer}
timescale: ${scratch.scheduler_timescale}
warmup_steps: ${scratch.scheduler_warmup}
cooldown_steps: ${scratch.scheduler_cooldown}
- scheduler:
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
base_lr: ${scratch.lr_vision_backbone}
timescale: ${scratch.scheduler_timescale}
warmup_steps: ${scratch.scheduler_warmup}
cooldown_steps: ${scratch.scheduler_cooldown}
param_names:
- 'backbone.vision_backbone.*'
- scheduler:
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
base_lr: ${scratch.lr_language_backbone}
timescale: ${scratch.scheduler_timescale}
warmup_steps: ${scratch.scheduler_warmup}
cooldown_steps: ${scratch.scheduler_cooldown}
param_names:
- 'backbone.language_backbone.*'
weight_decay:
- scheduler:
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
value: ${scratch.wd}
- scheduler:
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
value: 0.0
param_names:
- '*bias*'
module_cls_names: ['torch.nn.LayerNorm']
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 only last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/${roboflow_train.supercategory}
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 1
gpus_per_node: 2
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null
# Uncomment for job array configuration
job_array:
num_tasks: 100
task_index: 0
# ============================================================================
# Available Roboflow Supercategories (for reference)
# ============================================================================
all_roboflow_supercategories:
- -grccs
- zebrasatasturias
- cod-mw-warzone
- canalstenosis
- label-printing-defect-version-2
- new-defects-in-wood
- orionproducts
- aquarium-combined
- varroa-mites-detection--test-set
- clashroyalechardetector
- stomata-cells
- halo-infinite-angel-videogame
- pig-detection
- urine-analysis1
- aerial-sheep
- orgharvest
- actions
- mahjong
- liver-disease
- needle-base-tip-min-max
- wheel-defect-detection
- aircraft-turnaround-dataset
- xray
- wildfire-smoke
- spinefrxnormalvindr
- ufba-425
- speech-bubbles-detection
- train
- pill
- truck-movement
- car-logo-detection
- inbreast
- sea-cucumbers-new-tiles
- uavdet-small
- penguin-finder-seg
- aerial-airport
- bibdetection
- taco-trash-annotations-in-context
- bees
- recode-waste
- screwdetectclassification
- wine-labels
- aerial-cows
- into-the-vale
- gwhd2021
- lacrosse-object-detection
- defect-detection
- dataconvert
- x-ray-id
- ball
- tube
- 2024-frc
- crystal-clean-brain-tumors-mri-dataset
- grapes-5
- human-detection-in-floods
- buoy-onboarding
- apoce-aerial-photographs-for-object-detection-of-construction-equipment
- l10ul502
- floating-waste
- deeppcb
- ism-band-packet-detection
- weeds4
- invoice-processing
- thermal-cheetah
- tomatoes-2
- marine-sharks
- peixos-fish
- sssod
- aerial-pool
- countingpills
- asphaltdistressdetection
- roboflow-trained-dataset
- everdaynew
- underwater-objects
- soda-bottles
- dentalai
- jellyfish
- deepfruits
- activity-diagrams
- circuit-voltages
- all-elements
- macro-segmentation
- exploratorium-daphnia
- signatures
- conveyor-t-shirts
- fruitjes
- grass-weeds
- infraredimageofpowerequipment
- 13-lkc01
- wb-prova
- flir-camera-objects
- paper-parts
- football-player-detection
- trail-camera
- smd-components
- water-meter
- nih-xray
- the-dreidel-project
- electric-pylon-detection-in-rsi
- cable-damage

View File

@@ -0,0 +1,539 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
roboflow_vl_100_root: <YOUR_DATASET_DIR>
  experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
# Roboflow dataset configuration
roboflow_train:
num_images: 100 # Note: This is the number of images used for training. If null, all images are used.
supercategory: ${all_roboflow_supercategories.${string:${submitit.job_array.task_index}}}
# Training transforms pipeline
train_transforms:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterCrowds
- _target_: sam3.train.transforms.point_sampling.RandomizeInputBbox
box_noise_std: 0.1
box_noise_max: 20
- _target_: sam3.train.transforms.segmentation.DecodeRle
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes:
_target_: sam3.train.transforms.basic.get_random_resize_scales
size: ${scratch.resolution}
min_size: 480
rounded: false
max_size:
_target_: sam3.train.transforms.basic.get_random_resize_max_size
size: ${scratch.resolution}
square: true
consistent_transform: ${scratch.consistent_transform}
- _target_: sam3.train.transforms.basic_for_api.PadToSizeAPI
size: ${scratch.resolution}
consistent_transform: ${scratch.consistent_transform}
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.train_norm_mean}
std: ${scratch.train_norm_std}
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterFindQueriesWithTooManyOut
max_num_objects: ${scratch.max_ann_per_img}
# Validation transforms pipeline
val_transforms:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution}
max_size:
_target_: sam3.train.transforms.basic.get_random_resize_max_size
size: ${scratch.resolution}
square: true
consistent_transform: False
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.train_norm_mean}
std: ${scratch.train_norm_std}
# loss config (no mask loss)
loss:
_target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
matcher: ${scratch.matcher}
o2m_weight: 2.0
o2m_matcher:
_target_: sam3.train.matcher.BinaryOneToManyMatcher
alpha: 0.3
threshold: 0.4
topk: 4
use_o2m_matcher_on_o2m_aux: false # Another option is true
loss_fns_find:
- _target_: sam3.train.loss.loss_fns.Boxes
weight_dict:
loss_bbox: 5.0
loss_giou: 2.0
- _target_: sam3.train.loss.loss_fns.IABCEMdetr
weak_loss: False
weight_dict:
loss_ce: 20.0 # Another option is 100.0
presence_loss: 20.0
pos_weight: 10.0 # Another option is 5.0
alpha: 0.25
gamma: 2
use_presence: True # Change
pos_focal: false
pad_n_queries: 200
pad_scale_pos: 1.0
loss_fn_semantic_seg: null
scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}
# NOTE: Loss to be used for training in case of segmentation
# loss:
# _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
# matcher: ${scratch.matcher}
# o2m_weight: 2.0
# o2m_matcher:
# _target_: sam3.train.matcher.BinaryOneToManyMatcher
# alpha: 0.3
# threshold: 0.4
# topk: 4
# use_o2m_matcher_on_o2m_aux: false
# loss_fns_find:
# - _target_: sam3.train.loss.loss_fns.Boxes
# weight_dict:
# loss_bbox: 5.0
# loss_giou: 2.0
# - _target_: sam3.train.loss.loss_fns.IABCEMdetr
# weak_loss: False
# weight_dict:
# loss_ce: 20.0 # Another option is 100.0
# presence_loss: 20.0
# pos_weight: 10.0 # Another option is 5.0
# alpha: 0.25
# gamma: 2
# use_presence: True # Change
# pos_focal: false
# pad_n_queries: 200
# pad_scale_pos: 1.0
# - _target_: sam3.train.loss.loss_fns.Masks
# focal_alpha: 0.25
# focal_gamma: 2.0
# weight_dict:
# loss_mask: 200.0
# loss_dice: 10.0
# compute_aux: false
# loss_fn_semantic_seg:
# _target_: sam3.losses.loss_fns.SemanticSegCriterion
# presence_head: True
# presence_loss: False # Change
# focal: True
# focal_alpha: 0.6
# focal_gamma: 2.0
# downsample: False
# weight_dict:
# loss_semantic_seg: 20.0
# loss_semantic_presence: 1.0
# loss_semantic_dice: 30.0
# scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
  enable_segmentation: False # NOTE: Boolean flag; set to True to enable the segmentation head (and use the commented-out mask loss above).
# Model parameters
d_model: 256
pos_embed:
_target_: sam3.model.position_encoding.PositionEmbeddingSine
num_pos_feats: ${scratch.d_model}
normalize: true
scale: null
temperature: 10000
# Box processing
use_presence_eval: True
original_box_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessImage
max_dets_per_img: -1 # infinite detections
use_original_ids: true
use_original_sizes_box: true
use_presence: ${scratch.use_presence_eval}
# Matcher configuration
matcher:
_target_: sam3.train.matcher.BinaryHungarianMatcherV2
focal: true # with `focal: true` it is equivalent to BinaryFocalHungarianMatcher
cost_class: 2.0
cost_bbox: 5.0
cost_giou: 2.0
alpha: 0.25
gamma: 2
stable: False
scale_by_find_batch_size: True
# Image processing parameters
resolution: 1008
consistent_transform: False
max_ann_per_img: 200
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
# Training parameters
num_train_workers: 10
num_val_workers: 0
max_data_epochs: 20
target_epoch_size: 1500
hybrid_repeats: 1
context_length: 2
gather_pred_via_filesys: false
# Learning rate and scheduler parameters
lr_scale: 0.1
lr_transformer: ${times:8e-4,${scratch.lr_scale}}
lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}}
lr_language_backbone: ${times:5e-5,${scratch.lr_scale}}
lrd_vision_backbone: 0.9
wd: 0.1
scheduler_timescale: 20
scheduler_warmup: 20
scheduler_cooldown: 20
val_batch_size: 1
collate_fn_val:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: roboflow100
with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!
gradient_accumulation_steps: 1
train_batch_size: 1
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: all
with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: 20
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: train
gradient_accumulation_steps: ${scratch.gradient_accumulation_steps}
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all: ${roboflow_train.loss}
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
limit_ids: ${roboflow_train.num_images}
transforms: ${roboflow_train.train_transforms}
load_segmentation: ${scratch.enable_segmentation}
max_ann_per_img: 500000
multiplier: 1
max_train_queries: 50000
max_val_queries: 50000
training: true
use_caching: False
img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/
ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/_annotations.coco.json
shuffle: True
batch_size: ${scratch.train_batch_size}
num_workers: ${scratch.num_train_workers}
pin_memory: True
drop_last: True
collate_fn: ${scratch.collate_fn}
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
load_segmentation: ${scratch.enable_segmentation}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
include_negatives: true
category_chunk_size: 2 # Note: You can increase this based on the memory of your GPU.
_partial_: true
img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/
ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json
transforms: ${roboflow_train.val_transforms}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: True
drop_last: False
collate_fn: ${scratch.collate_fn_val}
model:
_target_: sam3.model_builder.build_sam3_image_model
bpe_path: ${paths.bpe_path}
  device: cpu  # NOTE(review): "cpus" is not a valid torch device string — confirm "cpu" is intended
eval_mode: false
enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
meters:
val:
roboflow100:
detection:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "bbox"
dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${roboflow_train.supercategory}
merge_predictions: True
postprocessor: ${scratch.original_box_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 100
pred_file_evaluators:
- _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
gt_path: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json
tide: False
iou_type: "bbox"
optim:
amp:
enabled: True
amp_dtype: bfloat16
optimizer:
_target_: torch.optim.AdamW
gradient_clip:
_target_: sam3.train.optim.optimizer.GradientClipper
max_norm: 0.1
norm_type: 2
param_group_modifiers:
- _target_: sam3.train.optim.optimizer.layer_decay_param_modifier
_partial_: True
layer_decay_value: ${scratch.lrd_vision_backbone}
apply_to: 'backbone.vision_backbone.trunk'
overrides:
- pattern: '*pos_embed*'
value: 1.0
options:
lr:
- scheduler: # transformer and class_embed
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
base_lr: ${scratch.lr_transformer}
timescale: ${scratch.scheduler_timescale}
warmup_steps: ${scratch.scheduler_warmup}
cooldown_steps: ${scratch.scheduler_cooldown}
- scheduler:
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
base_lr: ${scratch.lr_vision_backbone}
timescale: ${scratch.scheduler_timescale}
warmup_steps: ${scratch.scheduler_warmup}
cooldown_steps: ${scratch.scheduler_cooldown}
param_names:
- 'backbone.vision_backbone.*'
- scheduler:
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
base_lr: ${scratch.lr_language_backbone}
timescale: ${scratch.scheduler_timescale}
warmup_steps: ${scratch.scheduler_warmup}
cooldown_steps: ${scratch.scheduler_cooldown}
param_names:
- 'backbone.language_backbone.*'
weight_decay:
- scheduler:
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
value: ${scratch.wd}
- scheduler:
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
value: 0.0
param_names:
- '*bias*'
module_cls_names: ['torch.nn.LayerNorm']
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 only last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/${roboflow_train.supercategory}
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 1
gpus_per_node: 2
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null
# Uncomment for job array configuration
job_array:
num_tasks: 100
task_index: 0
# ============================================================================
# Available Roboflow Supercategories (for reference)
# ============================================================================
all_roboflow_supercategories:
- -grccs
- zebrasatasturias
- cod-mw-warzone
- canalstenosis
- label-printing-defect-version-2
- new-defects-in-wood
- orionproducts
- aquarium-combined
- varroa-mites-detection--test-set
- clashroyalechardetector
- stomata-cells
- halo-infinite-angel-videogame
- pig-detection
- urine-analysis1
- aerial-sheep
- orgharvest
- actions
- mahjong
- liver-disease
- needle-base-tip-min-max
- wheel-defect-detection
- aircraft-turnaround-dataset
- xray
- wildfire-smoke
- spinefrxnormalvindr
- ufba-425
- speech-bubbles-detection
- train
- pill
- truck-movement
- car-logo-detection
- inbreast
- sea-cucumbers-new-tiles
- uavdet-small
- penguin-finder-seg
- aerial-airport
- bibdetection
- taco-trash-annotations-in-context
- bees
- recode-waste
- screwdetectclassification
- wine-labels
- aerial-cows
- into-the-vale
- gwhd2021
- lacrosse-object-detection
- defect-detection
- dataconvert
- x-ray-id
- ball
- tube
- 2024-frc
- crystal-clean-brain-tumors-mri-dataset
- grapes-5
- human-detection-in-floods
- buoy-onboarding
- apoce-aerial-photographs-for-object-detection-of-construction-equipment
- l10ul502
- floating-waste
- deeppcb
- ism-band-packet-detection
- weeds4
- invoice-processing
- thermal-cheetah
- tomatoes-2
- marine-sharks
- peixos-fish
- sssod
- aerial-pool
- countingpills
- asphaltdistressdetection
- roboflow-trained-dataset
- everdaynew
- underwater-objects
- soda-bottles
- dentalai
- jellyfish
- deepfruits
- activity-diagrams
- circuit-voltages
- all-elements
- macro-segmentation
- exploratorium-daphnia
- signatures
- conveyor-t-shirts
- fruitjes
- grass-weeds
- infraredimageofpowerequipment
- 13-lkc01
- wb-prova
- flir-camera-objects
- paper-parts
- football-player-detection
- trail-camera
- smd-components
- water-meter
- nih-xray
- the-dreidel-project
- electric-pylon-detection-in-rsi
- cable-damage

View File

@@ -0,0 +1,174 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
dump_file_name: saco_veval_sav_test
  experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
ytvis_json: <YOUR_GT_PATH>/saco_veval_sav_test.json
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
num_videos: null
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
vid_mask_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessNullOp
use_presence_eval: True
video_transforms_val:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.segmentation.DecodeRle
        # resize the image to a square of side scratch.resolution (1008x1008 here)
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution} # originally `resolution: 1024`
square: true
consistent_transform: true
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
# Model parameters
d_model: 256
# Image processing parameters
resolution: 1008
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
val_batch_size: 1
num_val_workers: 0
max_data_epochs: 20
hybrid_repeats: 1
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: ${scratch.max_data_epochs}
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all:
_target_: sam3.train.loss.sam3_loss.DummyLoss
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train: null
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
limit_ids: ${paths.num_videos}
img_folder: ${paths.ytvis_dir}
ann_file: ${paths.ytvis_json}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
_partial_: true
transforms: ${scratch.video_transforms_val}
max_ann_per_img: 100000 # filtered in transforms
max_val_queries: 100000
multiplier: 1
load_segmentation: true
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: True
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: ytvis_val
with_seg_masks: true
model:
_target_: sam3.model_builder.build_sam3_video_model
bpe_path: ${paths.bpe_path}
has_presence_token: True
geo_encoder_use_img_cross_attn: True
apply_temporal_disambiguation: True
meters:
val:
ytvis_val:
pred_file: # key
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
postprocessor: ${scratch.vid_mask_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
optim:
amp:
enabled: True
amp_dtype: bfloat16
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 only last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 8
gpus_per_node: 8
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null

View File

@@ -0,0 +1,174 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
dump_file_name: saco_veval_sav_test
  experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
ytvis_json: <YOUR_GT_PATH>/saco_veval_sav_test.json
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
num_videos: null
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
vid_mask_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessNullOp
use_presence_eval: True
video_transforms_val:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.segmentation.DecodeRle
        # resize the image to a square of side scratch.resolution (1008x1008 here)
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution} # originally `resolution: 1024`
square: true
consistent_transform: true
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
# Model parameters
d_model: 256
# Image processing parameters
resolution: 1008
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
val_batch_size: 1
num_val_workers: 0
max_data_epochs: 20
hybrid_repeats: 1
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: ${scratch.max_data_epochs}
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all:
_target_: sam3.train.loss.sam3_loss.DummyLoss
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train: null
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
limit_ids: ${paths.num_videos}
img_folder: ${paths.ytvis_dir}
ann_file: ${paths.ytvis_json}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
_partial_: true
transforms: ${scratch.video_transforms_val}
max_ann_per_img: 100000 # filtered in transforms
max_val_queries: 100000
multiplier: 1
load_segmentation: true
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: True
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: ytvis_val
with_seg_masks: true
model:
_target_: sam3.model_builder.build_sam3_video_model
bpe_path: ${paths.bpe_path}
has_presence_token: True
geo_encoder_use_img_cross_attn: True
apply_temporal_disambiguation: False
meters:
val:
ytvis_val:
pred_file: # key
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
postprocessor: ${scratch.vid_mask_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
optim:
amp:
enabled: True
amp_dtype: bfloat16
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 only last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 8
gpus_per_node: 8
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null

View File

@@ -0,0 +1,174 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
dump_file_name: saco_veval_sav_val
  experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
ytvis_json: <YOUR_GT_PATH>/saco_veval_sav_val.json
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
num_videos: null
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
vid_mask_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessNullOp
use_presence_eval: True
video_transforms_val:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.segmentation.DecodeRle
        # resize the image to a square of side scratch.resolution (1008x1008 here)
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution} # originally `resolution: 1024`
square: true
consistent_transform: true
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
# Model parameters
d_model: 256
# Image processing parameters
resolution: 1008
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
val_batch_size: 1
num_val_workers: 0
max_data_epochs: 20
hybrid_repeats: 1
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: ${scratch.max_data_epochs}
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all:
_target_: sam3.train.loss.sam3_loss.DummyLoss
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train: null
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
limit_ids: ${paths.num_videos}
img_folder: ${paths.ytvis_dir}
ann_file: ${paths.ytvis_json}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
_partial_: true
transforms: ${scratch.video_transforms_val}
max_ann_per_img: 100000 # filtered in transforms
max_val_queries: 100000
multiplier: 1
load_segmentation: true
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: True
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: ytvis_val
with_seg_masks: true
model:
_target_: sam3.model_builder.build_sam3_video_model
bpe_path: ${paths.bpe_path}
has_presence_token: True
geo_encoder_use_img_cross_attn: True
apply_temporal_disambiguation: True
meters:
val:
ytvis_val:
pred_file: # key
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
postprocessor: ${scratch.vid_mask_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
optim:
amp:
enabled: True
amp_dtype: bfloat16
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 only last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 8
gpus_per_node: 8
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null

View File

@@ -0,0 +1,174 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
dump_file_name: saco_veval_sav_val
  experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
ytvis_json: <YOUR_GT_PATH>/saco_veval_sav_val.json
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
num_videos: null
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
vid_mask_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessNullOp
use_presence_eval: True
video_transforms_val:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.segmentation.DecodeRle
        # resize the image to a square of side scratch.resolution (1008x1008 here)
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution} # originally `resolution: 1024`
square: true
consistent_transform: true
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
# Model parameters
d_model: 256
# Image processing parameters
resolution: 1008
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
val_batch_size: 1
num_val_workers: 0
max_data_epochs: 20
hybrid_repeats: 1
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: ${scratch.max_data_epochs}
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all:
_target_: sam3.train.loss.sam3_loss.DummyLoss
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train: null
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
limit_ids: ${paths.num_videos}
img_folder: ${paths.ytvis_dir}
ann_file: ${paths.ytvis_json}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
_partial_: true
transforms: ${scratch.video_transforms_val}
max_ann_per_img: 100000 # filtered in transforms
max_val_queries: 100000
multiplier: 1
load_segmentation: true
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: True
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: ytvis_val
with_seg_masks: true
model:
_target_: sam3.model_builder.build_sam3_video_model
bpe_path: ${paths.bpe_path}
has_presence_token: True
geo_encoder_use_img_cross_attn: True
apply_temporal_disambiguation: False
meters:
val:
ytvis_val:
pred_file: # key
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
postprocessor: ${scratch.vid_mask_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
optim:
amp:
enabled: True
amp_dtype: bfloat16
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 only last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 8
gpus_per_node: 8
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null

View File

@@ -0,0 +1,174 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
dump_file_name: saco_veval_smartglasses_test
  experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
ytvis_json: <YOUR_GT_PATH>/saco_veval_smartglasses_test.json
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
num_videos: null
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
vid_mask_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessNullOp
use_presence_eval: True
video_transforms_val:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.segmentation.DecodeRle
        # resize the image to a square of side scratch.resolution (1008x1008 here)
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution} # originally `resolution: 1024`
square: true
consistent_transform: true
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
# Model parameters
d_model: 256
# Image processing parameters
resolution: 1008
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
val_batch_size: 1
num_val_workers: 0
max_data_epochs: 20
hybrid_repeats: 1
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: ${scratch.max_data_epochs}
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all:
_target_: sam3.train.loss.sam3_loss.DummyLoss
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train: null
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
limit_ids: ${paths.num_videos}
img_folder: ${paths.ytvis_dir}
ann_file: ${paths.ytvis_json}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
_partial_: true
transforms: ${scratch.video_transforms_val}
max_ann_per_img: 100000 # filtered in transforms
max_val_queries: 100000
multiplier: 1
load_segmentation: true
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: True
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: ytvis_val
with_seg_masks: true
model:
_target_: sam3.model_builder.build_sam3_video_model
bpe_path: ${paths.bpe_path}
has_presence_token: True
geo_encoder_use_img_cross_attn: True
apply_temporal_disambiguation: True
meters:
val:
ytvis_val:
pred_file: # key
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
postprocessor: ${scratch.vid_mask_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
optim:
amp:
enabled: True
amp_dtype: bfloat16
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 only last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 8
gpus_per_node: 8
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null

View File

@@ -0,0 +1,174 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
dump_file_name: saco_veval_smartglasses_test
  experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
ytvis_json: <YOUR_GT_PATH>/saco_veval_smartglasses_test.json
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
num_videos: null
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
vid_mask_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessNullOp
use_presence_eval: True
video_transforms_val:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.segmentation.DecodeRle
        # resize the image to a square of side scratch.resolution (1008x1008 here)
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution} # originally `resolution: 1024`
square: true
consistent_transform: true
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
# Model parameters
d_model: 256
# Image processing parameters
resolution: 1008
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
val_batch_size: 1
num_val_workers: 0
max_data_epochs: 20
hybrid_repeats: 1
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: ${scratch.max_data_epochs}
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all:
_target_: sam3.train.loss.sam3_loss.DummyLoss
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train: null
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
limit_ids: ${paths.num_videos}
img_folder: ${paths.ytvis_dir}
ann_file: ${paths.ytvis_json}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
_partial_: true
transforms: ${scratch.video_transforms_val}
max_ann_per_img: 100000 # filtered in transforms
max_val_queries: 100000
multiplier: 1
load_segmentation: true
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: True
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: ytvis_val
with_seg_masks: true
model:
_target_: sam3.model_builder.build_sam3_video_model
bpe_path: ${paths.bpe_path}
has_presence_token: True
geo_encoder_use_img_cross_attn: True
apply_temporal_disambiguation: False
meters:
val:
ytvis_val:
pred_file: # key
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
postprocessor: ${scratch.vid_mask_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
optim:
amp:
enabled: True
amp_dtype: bfloat16
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 only last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 8
gpus_per_node: 8
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null

View File

@@ -0,0 +1,174 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
dump_file_name: saco_veval_smartglasses_val
experiment_log_dir: <YOUR_EXPERIMENT_LOG_DIR>
ytvis_json: <YOUR_GT_PATH>/saco_veval_smartglasses_val.json
ytvis_dir: <YOUR_VIDEO_JPG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
num_videos: null
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
vid_mask_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessNullOp
use_presence_eval: True
video_transforms_val:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.segmentation.DecodeRle
# resize the image to a 1008x1008 square (see scratch.resolution)
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution} # originally `resolution: 1024`
square: true
consistent_transform: true
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
# Model parameters
d_model: 256
# Image processing parameters
resolution: 1008
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
val_batch_size: 1
num_val_workers: 0
max_data_epochs: 20
hybrid_repeats: 1
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: ${scratch.max_data_epochs}
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all:
_target_: sam3.train.loss.sam3_loss.DummyLoss
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train: null
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
limit_ids: ${paths.num_videos}
img_folder: ${paths.ytvis_dir}
ann_file: ${paths.ytvis_json}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
_partial_: true
transforms: ${scratch.video_transforms_val}
max_ann_per_img: 100000 # filtered in transforms
max_val_queries: 100000
multiplier: 1
load_segmentation: true
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: True
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: ytvis_val
with_seg_masks: true
model:
_target_: sam3.model_builder.build_sam3_video_model
bpe_path: ${paths.bpe_path}
has_presence_token: True
geo_encoder_use_img_cross_attn: True
apply_temporal_disambiguation: True
meters:
val:
ytvis_val:
pred_file: # key
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
postprocessor: ${scratch.vid_mask_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
optim:
amp:
enabled: True
amp_dtype: bfloat16
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 means only the last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 8
gpus_per_node: 8
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null

View File

@@ -0,0 +1,174 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
dump_file_name: saco_veval_smartglasses_val
experiment_log_dir: <YOUR_EXPERIMENT_LOG_DIR>
ytvis_json: <YOUR_GT_PATH>/saco_veval_smartglasses_val.json
ytvis_dir: <YOUR_VIDEO_JPG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
num_videos: null
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
vid_mask_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessNullOp
use_presence_eval: True
video_transforms_val:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.segmentation.DecodeRle
# resize the image to a 1008x1008 square (see scratch.resolution)
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution} # originally `resolution: 1024`
square: true
consistent_transform: true
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
# Model parameters
d_model: 256
# Image processing parameters
resolution: 1008
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
val_batch_size: 1
num_val_workers: 0
max_data_epochs: 20
hybrid_repeats: 1
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: ${scratch.max_data_epochs}
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all:
_target_: sam3.train.loss.sam3_loss.DummyLoss
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train: null
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
limit_ids: ${paths.num_videos}
img_folder: ${paths.ytvis_dir}
ann_file: ${paths.ytvis_json}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
_partial_: true
transforms: ${scratch.video_transforms_val}
max_ann_per_img: 100000 # filtered in transforms
max_val_queries: 100000
multiplier: 1
load_segmentation: true
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: True
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: ytvis_val
with_seg_masks: true
model:
_target_: sam3.model_builder.build_sam3_video_model
bpe_path: ${paths.bpe_path}
has_presence_token: True
geo_encoder_use_img_cross_attn: True
apply_temporal_disambiguation: False
meters:
val:
ytvis_val:
pred_file: # key
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
postprocessor: ${scratch.vid_mask_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
optim:
amp:
enabled: True
amp_dtype: bfloat16
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 means only the last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 8
gpus_per_node: 8
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null

View File

@@ -0,0 +1,174 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
dump_file_name: saco_veval_yt1b_test
experiment_log_dir: <YOUR_EXPERIMENT_LOG_DIR>
ytvis_json: <YOUR_GT_PATH>/saco_veval_yt1b_test.json
ytvis_dir: <YOUR_VIDEO_JPG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
num_videos: null
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
vid_mask_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessNullOp
use_presence_eval: True
video_transforms_val:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.segmentation.DecodeRle
# resize the image to 1024x1024 resolution
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution} # originally `resolution: 1024`
square: true
consistent_transform: true
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
# Model parameters
d_model: 256
# Image processing parameters
resolution: 1008
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
val_batch_size: 1
num_val_workers: 0
max_data_epochs: 20
hybrid_repeats: 1
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: ${scratch.max_data_epochs}
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all:
_target_: sam3.train.loss.sam3_loss.DummyLoss
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train: null
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
limit_ids: ${paths.num_videos}
img_folder: ${paths.ytvis_dir}
ann_file: ${paths.ytvis_json}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
_partial_: true
transforms: ${scratch.video_transforms_val}
max_ann_per_img: 100000 # filtered in transforms
max_val_queries: 100000
multiplier: 1
load_segmentation: true
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: True
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: ytvis_val
with_seg_masks: true
model:
_target_: sam3.model_builder.build_sam3_video_model
bpe_path: ${paths.bpe_path}
has_presence_token: True
geo_encoder_use_img_cross_attn: True
apply_temporal_disambiguation: True
meters:
val:
ytvis_val:
pred_file: # key
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
postprocessor: ${scratch.vid_mask_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
optim:
amp:
enabled: True
amp_dtype: bfloat16
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 means only the last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 8
gpus_per_node: 8
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null

View File

@@ -0,0 +1,174 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
dump_file_name: saco_veval_yt1b_test
experiment_log_dir: <YOUR_EXPERIMENT_LOG_DIR>
ytvis_json: <YOUR_GT_PATH>/saco_veval_yt1b_test.json
ytvis_dir: <YOUR_VIDEO_JPG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
num_videos: null
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
vid_mask_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessNullOp
use_presence_eval: True
video_transforms_val:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.segmentation.DecodeRle
# resize the image to 1024x1024 resolution
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution} # originally `resolution: 1024`
square: true
consistent_transform: true
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
# Model parameters
d_model: 256
# Image processing parameters
resolution: 1008
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
val_batch_size: 1
num_val_workers: 0
max_data_epochs: 20
hybrid_repeats: 1
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: ${scratch.max_data_epochs}
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all:
_target_: sam3.train.loss.sam3_loss.DummyLoss
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train: null
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
limit_ids: ${paths.num_videos}
img_folder: ${paths.ytvis_dir}
ann_file: ${paths.ytvis_json}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
_partial_: true
transforms: ${scratch.video_transforms_val}
max_ann_per_img: 100000 # filtered in transforms
max_val_queries: 100000
multiplier: 1
load_segmentation: true
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: True
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: ytvis_val
with_seg_masks: true
model:
_target_: sam3.model_builder.build_sam3_video_model
bpe_path: ${paths.bpe_path}
has_presence_token: True
geo_encoder_use_img_cross_attn: True
apply_temporal_disambiguation: False
meters:
val:
ytvis_val:
pred_file: # key
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
postprocessor: ${scratch.vid_mask_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
optim:
amp:
enabled: True
amp_dtype: bfloat16
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 means only the last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 8
gpus_per_node: 8
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null

View File

@@ -0,0 +1,174 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
dump_file_name: saco_veval_yt1b_val
experiment_log_dir: <YOUR_EXPERIMENT_LOG_DIR>
ytvis_json: <YOUR_GT_PATH>/saco_veval_yt1b_val.json
ytvis_dir: <YOUR_VIDEO_JPG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
num_videos: null
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
vid_mask_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessNullOp
use_presence_eval: True
video_transforms_val:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.segmentation.DecodeRle
# resize the image to 1024x1024 resolution
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution} # originally `resolution: 1024`
square: true
consistent_transform: true
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
# Model parameters
d_model: 256
# Image processing parameters
resolution: 1008
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
val_batch_size: 1
num_val_workers: 0
max_data_epochs: 20
hybrid_repeats: 1
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: ${scratch.max_data_epochs}
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all:
_target_: sam3.train.loss.sam3_loss.DummyLoss
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train: null
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
limit_ids: ${paths.num_videos}
img_folder: ${paths.ytvis_dir}
ann_file: ${paths.ytvis_json}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
_partial_: true
transforms: ${scratch.video_transforms_val}
max_ann_per_img: 100000 # filtered in transforms
max_val_queries: 100000
multiplier: 1
load_segmentation: true
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: True
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: ytvis_val
with_seg_masks: true
model:
_target_: sam3.model_builder.build_sam3_video_model
bpe_path: ${paths.bpe_path}
has_presence_token: True
geo_encoder_use_img_cross_attn: True
apply_temporal_disambiguation: True
meters:
val:
ytvis_val:
pred_file: # key
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
postprocessor: ${scratch.vid_mask_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
optim:
amp:
enabled: True
amp_dtype: bfloat16
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 means only the last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 8
gpus_per_node: 8
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null

View File

@@ -0,0 +1,174 @@
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
dump_file_name: saco_veval_yt1b_val
experiment_log_dir: <YOUR_EXPERIMENT_LOG_DIR>
ytvis_json: <YOUR_GT_PATH>/saco_veval_yt1b_val.json
ytvis_dir: <YOUR_VIDEO_JPG_DIR>
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
num_videos: null
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
vid_mask_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessNullOp
use_presence_eval: True
video_transforms_val:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.segmentation.DecodeRle
# resize the image to 1024x1024 resolution
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution} # originally `resolution: 1024`
square: true
consistent_transform: true
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.val_norm_mean}
std: ${scratch.val_norm_std}
# Model parameters
d_model: 256
# Image processing parameters
resolution: 1008
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
val_batch_size: 1
num_val_workers: 0
max_data_epochs: 20
hybrid_repeats: 1
gather_pred_via_filesys: false
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: true
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: ${scratch.max_data_epochs}
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: val
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all:
_target_: sam3.train.loss.sam3_loss.DummyLoss
default:
_target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train: null
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
limit_ids: ${paths.num_videos}
img_folder: ${paths.ytvis_dir}
ann_file: ${paths.ytvis_json}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
_partial_: true
transforms: ${scratch.video_transforms_val}
max_ann_per_img: 100000 # filtered in transforms
max_val_queries: 100000
multiplier: 1
load_segmentation: true
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: True
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: ytvis_val
with_seg_masks: true
model:
_target_: sam3.model_builder.build_sam3_video_model
bpe_path: ${paths.bpe_path}
has_presence_token: True
geo_encoder_use_img_cross_attn: True
apply_temporal_disambiguation: False
meters:
val:
ytvis_val:
pred_file: # key
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
postprocessor: ${scratch.vid_mask_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
optim:
amp:
enabled: True
amp_dtype: bfloat16
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 means only the last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 8
gpus_per_node: 8
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: True
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null

View File

@@ -0,0 +1,64 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_bdd100k/
coco_gt: ${paths.base_annotation_path_silver}/silver_bdd100k_merged_test.json
img_path: ${paths.silver_img_path}/bdd100k/
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: silver_bdd100k
meters:
val:
silver_bdd100k: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_bdd100k
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "segm"

View File

@@ -0,0 +1,64 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_droid/
coco_gt: ${paths.base_annotation_path_silver}/silver_droid_merged_test.json
img_path: ${paths.silver_img_path}/droid/
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: silver_droid
meters:
val:
silver_droid: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_droid
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "segm"

View File

@@ -0,0 +1,64 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_ego4d/
coco_gt: ${paths.base_annotation_path_silver}/silver_ego4d_merged_test.json
img_path: ${paths.silver_img_path}/ego4d/
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: silver_ego4d
meters:
val:
silver_ego4d: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_ego4d
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "segm"

View File

@@ -0,0 +1,64 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_fathomnet/
coco_gt: ${paths.base_annotation_path_silver}/silver_fathomnet_test.json
img_path: ${paths.silver_img_path}/fathomnet/
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: silver_fathomnet
meters:
val:
silver_fathomnet: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_fathomnet
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "segm"

View File

@@ -0,0 +1,64 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_food_rec/
coco_gt: ${paths.base_annotation_path_silver}/silver_food_rec_merged_test.json
img_path: ${paths.silver_img_path}/food_rec/
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: silver_food_rec
meters:
val:
silver_food_rec: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_food_rec
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "segm"

View File

@@ -0,0 +1,64 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_geode/
coco_gt: ${paths.base_annotation_path_silver}/silver_geode_merged_test.json
img_path: ${paths.silver_img_path}/geode/
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: silver_geode
meters:
val:
silver_geode: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_geode
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "segm"

View File

@@ -0,0 +1,64 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_inaturalist/
coco_gt: ${paths.base_annotation_path_silver}/silver_inaturalist_merged_test.json
img_path: ${paths.silver_img_path}/inaturalist/
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: silver_inaturalist
meters:
val:
silver_inaturalist: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_inaturalist
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "segm"

View File

@@ -0,0 +1,64 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_nga_art/
coco_gt: ${paths.base_annotation_path_silver}/silver_nga_art_merged_test.json
img_path: ${paths.silver_img_path}/nga/
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: silver_nga_art
meters:
val:
silver_nga_art: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_nga_art
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "segm"

View File

@@ -0,0 +1,64 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_sav/
coco_gt: ${paths.base_annotation_path_silver}/silver_sav_merged_test.json
img_path: ${paths.silver_img_path}/sav/
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: silver_sav
meters:
val:
silver_sav: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_sav
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "segm"

View File

@@ -0,0 +1,64 @@
# @package _global_
defaults:
- /configs/eval_base.yaml
- _self_
# ============================================================================
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
# ============================================================================
paths:
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_yt1b/
coco_gt: ${paths.base_annotation_path_silver}/silver_yt1b_merged_test.json
img_path: ${paths.silver_img_path}/yt1b/
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
data:
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
_partial_: true
img_folder: ${paths.img_path}
ann_file: ${paths.coco_gt}
transforms: ${scratch.base_val_transform}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: False
drop_last: False
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: silver_yt1b
meters:
val:
silver_yt1b: # this key matches the "dict_key" in the dataloader's collate function
cgf1:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "segm"
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_yt1b
merge_predictions: True
postprocessor: ${scratch.mask_postprocessor_thresholded}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 1000000 # no limit
pred_file_evaluators:
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "bbox"
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
gt_path: ${paths.coco_gt}
iou_type: "segm"

View File

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

View File

@@ -0,0 +1,465 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import json
from collections import defaultdict
from typing import Dict, List, Tuple
import torch
from pycocotools import mask as mask_util
# ============================================================================
# Utility Functions
# ============================================================================
def convert_boxlist_to_normalized_tensor(box_list, image_width, image_height):
    """
    Normalize a list of bounding boxes by the image dimensions.

    Args:
        box_list (list of list or tuples): Each box is [x_min, y_min, x_max, y_max].
        image_width (int or float): Width of the image.
        image_height (int or float): Height of the image.

    Returns:
        torch.Tensor: Tensor of shape (N, 4) with coordinates scaled into [0, 1].
    """
    # Per-coordinate divisor: x coords by width, y coords by height.
    scale = torch.tensor(
        [image_width, image_height, image_width, image_height],
        dtype=torch.float32,
    )
    normalized = torch.tensor(box_list, dtype=torch.float32) / scale
    # Clip any out-of-image coordinates to the valid normalized range.
    return normalized.clamp_(0, 1)
def load_coco_and_group_by_image(json_path: str) -> Tuple[List[Dict], Dict[int, str]]:
    """
    Read a COCO-format JSON file and bucket its annotations per image.

    Args:
        json_path (str): Path to COCO JSON file.

    Returns:
        Tuple containing:
            - List of dicts with 'image' and 'annotations' keys, ordered by image id
            - Dict mapping category IDs to category names
    """
    with open(json_path, "r") as fp:
        coco_data = json.load(fp)

    image_by_id = {entry["id"]: entry for entry in coco_data["images"]}

    # Bucket annotations by their owning image; images without annotations
    # simply get an empty list below.
    grouped_anns = defaultdict(list)
    for annotation in coco_data["annotations"]:
        grouped_anns[annotation["image_id"]].append(annotation)

    per_image = [
        {"image": image_by_id[img_id], "annotations": grouped_anns.get(img_id, [])}
        for img_id in sorted(image_by_id)
    ]

    categories = {entry["id"]: entry["name"] for entry in coco_data["categories"]}
    return per_image, categories
def ann_to_rle(segm, im_info: Dict) -> Dict:
    """
    Convert a segmentation (polygons or uncompressed RLE) to compressed RLE.

    Args:
        segm: Segmentation data (polygon list or RLE dict)
        im_info (dict): Image info containing 'height' and 'width'

    Returns:
        RLE encoded segmentation
    """
    height, width = im_info["height"], im_info["width"]
    if isinstance(segm, list):
        # Polygons: encode each part, then merge into a single mask RLE.
        return mask_util.merge(mask_util.frPyObjects(segm, height, width))
    if isinstance(segm["counts"], list):
        # Uncompressed RLE: compress it.
        return mask_util.frPyObjects(segm, height, width)
    # Already a compressed RLE; pass through unchanged.
    return segm
# ============================================================================
# COCO Training API
# ============================================================================
class COCO_FROM_JSON:
    """
    COCO training API for loading box-only annotations from JSON.
    Groups all annotations per image and creates queries per category.

    One datapoint is produced per (image, category-chunk) pair, so with the
    default chunk size (all categories at once) there is exactly one
    datapoint per image.
    """

    def __init__(
        self,
        annotation_file,
        prompts=None,
        include_negatives=True,
        category_chunk_size=None,
    ):
        """
        Initialize the COCO training API.
        Args:
            annotation_file (str): Path to COCO JSON annotation file
            prompts: Optional custom prompts for categories. NOTE(review):
                this value is passed to ``eval`` below, so it is expected to
                be the string repr of a list of {"id": ..., "name": ...}
                dicts covering every category -- confirm with callers.
            include_negatives (bool): Whether to include negative examples (categories with no instances)
            category_chunk_size (int, optional): Number of categories grouped
                into each query chunk; defaults to all categories at once.
        """
        self._raw_data, self._cat_idx_to_text = load_coco_and_group_by_image(
            annotation_file
        )
        self._sorted_cat_ids = sorted(list(self._cat_idx_to_text.keys()))
        self.prompts = None
        self.include_negatives = include_negatives
        self.category_chunk_size = (
            category_chunk_size
            if category_chunk_size is not None
            else len(self._sorted_cat_ids)
        )
        # Fixed-size chunks of sorted category ids; each chunk yields one
        # datapoint per image (the last chunk may be smaller).
        self.category_chunks = [
            self._sorted_cat_ids[i : i + self.category_chunk_size]
            for i in range(0, len(self._sorted_cat_ids), self.category_chunk_size)
        ]
        if prompts is not None:
            # SECURITY: ``eval`` executes arbitrary code -- only pass trusted,
            # locally generated prompt strings here.
            prompts = eval(prompts)
            self.prompts = {}
            for loc_dict in prompts:
                self.prompts[int(loc_dict["id"])] = loc_dict["name"]
            assert len(self.prompts) == len(
                self._sorted_cat_ids
            ), "Number of prompts must match number of categories"

    def getDatapointIds(self):
        """Return all datapoint indices for training (one per image/chunk pair)."""
        return list(range(len(self._raw_data) * len(self.category_chunks)))

    def loadQueriesAndAnnotationsFromDatapoint(self, idx):
        """
        Load queries and annotations for a specific datapoint.
        Args:
            idx (int): Datapoint index
        Returns:
            Tuple of (queries, annotations) lists
        """
        # Datapoints are laid out image-major: idx = img_idx * n_chunks + chunk_idx.
        img_idx = idx // len(self.category_chunks)
        chunk_idx = idx % len(self.category_chunks)
        cat_chunk = self.category_chunks[chunk_idx]
        queries = []
        annotations = []
        query_template = {
            "id": None,
            "original_cat_id": None,
            "object_ids_output": None,
            "query_text": None,
            "query_processing_order": 0,
            "ptr_x_query_id": None,
            "ptr_y_query_id": None,
            "image_id": 0,  # Single image per datapoint
            "input_box": None,
            "input_box_label": None,
            "input_points": None,
            "is_exhaustive": True,
        }
        annot_template = {
            "image_id": 0,
            "bbox": None,  # Normalized bbox in xywh
            "area": None,  # Area of the normalized bbox (w * h in relative coords)
            "segmentation": None,  # RLE encoded
            "object_id": None,
            "is_crowd": None,
            "id": None,
        }
        raw_annotations = self._raw_data[img_idx]["annotations"]
        image_info = self._raw_data[img_idx]["image"]
        width, height = image_info["width"], image_info["height"]
        # Group annotations by category
        cat_id_to_anns = defaultdict(list)
        for ann in raw_annotations:
            cat_id_to_anns[ann["category_id"]].append(ann)
        annotations_by_cat_sorted = [
            (cat_id, cat_id_to_anns[cat_id]) for cat_id in cat_chunk
        ]
        for cat_id, anns in annotations_by_cat_sorted:
            # Optionally skip categories with no instances (negatives).
            if len(anns) == 0 and not self.include_negatives:
                continue
            cur_ann_ids = []
            # Create annotations for this category; ids are sequential within
            # the datapoint and double as object ids.
            for ann in anns:
                annotation = annot_template.copy()
                annotation["id"] = len(annotations)
                annotation["object_id"] = annotation["id"]
                annotation["is_crowd"] = ann["iscrowd"]
                normalized_boxes = convert_boxlist_to_normalized_tensor(
                    [ann["bbox"]], width, height
                )
                bbox = normalized_boxes[0]
                annotation["area"] = (bbox[2] * bbox[3]).item()
                annotation["bbox"] = bbox
                if (
                    "segmentation" in ann
                    and ann["segmentation"] is not None
                    and ann["segmentation"] != []
                ):
                    annotation["segmentation"] = ann_to_rle(
                        ann["segmentation"], im_info=image_info
                    )
                annotations.append(annotation)
                cur_ann_ids.append(annotation["id"])
            # Create query for this category
            query = query_template.copy()
            query["id"] = len(queries)
            query["original_cat_id"] = cat_id
            query["query_text"] = (
                self._cat_idx_to_text[cat_id]
                if self.prompts is None
                else self.prompts[cat_id]
            )
            query["object_ids_output"] = cur_ann_ids
            queries.append(query)
        return queries, annotations

    def loadImagesFromDatapoint(self, idx):
        """
        Load image information for a specific datapoint.
        Args:
            idx (int): Datapoint index
        Returns:
            List containing image info dict
        """
        img_idx = idx // len(self.category_chunks)
        img_data = self._raw_data[img_idx]["image"]
        # The in-datapoint image id is always 0; the original COCO id is kept
        # under 'original_img_id' / 'coco_img_id' for downstream bookkeeping.
        images = [
            {
                "id": 0,
                "file_name": img_data["file_name"],
                "original_img_id": img_data["id"],
                "coco_img_id": img_data["id"],
            }
        ]
        return images
# ============================================================================
# SAM3 Evaluation APIs
# ============================================================================
class SAM3_EVAL_API_FROM_JSON_NP:
    """
    SAM3 evaluation API for loading noun phrase queries from JSON.

    Each datapoint is a single image paired with one text query; no
    ground-truth annotations are attached (evaluation only).
    """

    def __init__(self, annotation_file):
        """
        Initialize the SAM3 evaluation API.
        Args:
            annotation_file (str): Path to SAM3 JSON annotation file
        """
        with open(annotation_file, "r") as fp:
            self._image_data = json.load(fp)["images"]

    def getDatapointIds(self):
        """Return all datapoint indices (one per image)."""
        return list(range(len(self._image_data)))

    def loadQueriesAndAnnotationsFromDatapoint(self, idx):
        """
        Load queries and annotations for a specific datapoint.
        Args:
            idx (int): Datapoint index
        Returns:
            Tuple of (queries, annotations) lists; annotations are always empty.
        """
        record = self._image_data[idx]
        query = {
            "id": 0,
            "original_cat_id": int(record["queried_category"]),
            "object_ids_output": [],
            "query_text": record["text_input"],
            "query_processing_order": 0,
            "ptr_x_query_id": None,
            "ptr_y_query_id": None,
            "image_id": 0,
            "input_box": None,
            "input_box_label": None,
            "input_points": None,
            "is_exhaustive": True,
        }
        return [query], []

    def loadImagesFromDatapoint(self, idx):
        """
        Load image information for a specific datapoint.
        Args:
            idx (int): Datapoint index
        Returns:
            List containing image info dict
        """
        record = self._image_data[idx]
        return [
            {
                "id": 0,
                "file_name": record["file_name"],
                "original_img_id": record["id"],
                "coco_img_id": record["id"],
            }
        ]
class SAM3_VEVAL_API_FROM_JSON_NP:
    """
    SAM3 video evaluation API for loading noun phrase queries from JSON.

    Each datapoint is one video; for every (noun phrase, frame) pair a query
    is emitted, processed in frame order.
    """

    def __init__(self, annotation_file):
        """
        Initialize the SAM3 video evaluation API.
        Args:
            annotation_file (str): Path to SAM3 video JSON annotation file
        """
        with open(annotation_file, "r") as fp:
            data = json.load(fp)
        assert "video_np_pairs" in data, "Incorrect data format"
        self._video_data = data["videos"]
        self._cat_id_to_np = {cat["id"]: cat["name"] for cat in data["categories"]}
        # Map each video id to the noun-phrase category ids queried on it,
        # sanity-checking that the pair's text matches the category table.
        self._video_id_to_np_ids = defaultdict(list)
        for pair in data["video_np_pairs"]:
            self._video_id_to_np_ids[pair["video_id"]].append(pair["category_id"])
            assert (
                self._cat_id_to_np[pair["category_id"]] == pair["noun_phrase"]
            ), "Category name does not match text input"

    def getDatapointIds(self):
        """Return all datapoint indices (one per video)."""
        return list(range(len(self._video_data)))

    def loadQueriesAndAnnotationsFromDatapoint(self, idx):
        """
        Load queries and annotations for a specific video datapoint.
        Args:
            idx (int): Datapoint index
        Returns:
            Tuple of (queries, annotations) lists; annotations are always empty.
        """
        video = self._video_data[idx]
        queries = []
        for np_id in self._video_id_to_np_ids[video["id"]]:
            phrase = self._cat_id_to_np[np_id]
            # One query per frame, processed in frame order.
            for frame_idx in range(len(video["file_names"])):
                queries.append(
                    {
                        "id": len(queries),
                        "original_cat_id": np_id,
                        "object_ids_output": [],
                        "query_text": phrase,
                        "query_processing_order": frame_idx,
                        "ptr_x_query_id": None,
                        "ptr_y_query_id": None,
                        "image_id": frame_idx,
                        "input_box": None,
                        "input_box_label": None,
                        "input_points": None,
                        "is_exhaustive": True,
                    }
                )
        return queries, []

    def loadImagesFromDatapoint(self, idx):
        """
        Load image information for a specific video datapoint.
        Args:
            idx (int): Datapoint index
        Returns:
            List containing image info dicts for all frames
        """
        video = self._video_data[idx]
        return [
            {
                "id": frame_idx,
                "file_name": frame_path,
                "original_img_id": video["id"],
                "coco_img_id": video["id"],
            }
            for frame_idx, frame_path in enumerate(video["file_names"])
        ]

360
sam3/train/data/collator.py Normal file
View File

@@ -0,0 +1,360 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
from typing import Any, get_args, get_origin, List, Union
import torch
from sam3.model.data_misc import (
BatchedDatapoint,
BatchedFindTarget,
BatchedInferenceMetadata,
FindStage,
)
from .sam3_image_dataset import Datapoint
MyTensor = Union[torch.Tensor, List[Any]]
def convert_my_tensors(obj):
    """
    Recursively convert every ``MyTensor`` field of a dataclass in place.

    For each dataclass field annotated as ``MyTensor`` (or ``Optional[MyTensor]``)
    that currently holds a list, replace the list with a single tensor: lists of
    tensors are stacked, any other list goes through ``torch.as_tensor``. The
    target dtype is read from a sibling field named ``<name>__type``.

    Args:
        obj: A dataclass instance (e.g. the batched find/target structures).

    Returns:
        The same ``obj``, mutated in place.
    """
    def is_optional_field(field) -> bool:
        # True when the annotation is Optional[...] / Union[..., None].
        return get_origin(field) is Union and type(None) in get_args(field)
    for field in fields(obj):
        # Recurse into nested dataclasses instead of converting them directly.
        if is_dataclass(getattr(obj, field.name)):
            convert_my_tensors(getattr(obj, field.name))
            continue
        field_type = field.type
        if is_optional_field(field.type):
            # NOTE(review): assumes NoneType is the LAST arg of the Union --
            # true for Optional[X] as typing normalizes it, but worth confirming
            # for hand-written unions.
            field_type = Union[get_args(field.type)[:-1]]  # Get the Optional field type
        # Only fields annotated exactly as MyTensor (and currently non-None)
        # are converted; everything else is left untouched.
        if field_type != MyTensor or getattr(obj, field.name) is None:
            continue
        elif len(getattr(obj, field.name)) and isinstance(
            getattr(obj, field.name)[0], torch.Tensor
        ):
            stack_dim = 0
            # Box prompt fields are stacked along dim 1 rather than dim 0.
            if field.name in [
                "input_boxes",
                "input_boxes_label",
            ]:
                stack_dim = 1
            setattr(
                obj,
                field.name,
                torch.stack(getattr(obj, field.name), dim=stack_dim).to(
                    getattr(obj, field.name + "__type")
                ),
            )
        else:
            # Plain Python lists (possibly empty) become tensors of the dtype
            # stored in the companion ``<name>__type`` field.
            setattr(
                obj,
                field.name,
                torch.as_tensor(
                    getattr(obj, field.name), dtype=getattr(obj, field.name + "__type")
                ),
            )
    return obj
def packed_to_padded_naive(boxes_packed, num_boxes, fill_value=0):
    """
    Convert a packed tensor of bounding boxes to a padded tensor of bounding
    boxes. Naive implementation using a loop.

    Inputs:
    - boxes_packed: Tensor of shape (N_1 + ... + N_B, 4)
    - num_boxes: Tensor of shape (B,) where num_boxes[i] = N_i
    - fill_value: Value written into the padded entries (default 0)

    Returns:
    - boxes_padded: Tensor of shape (B, N_max, 4) where N_max = max_i N_i
      (N_max = 0 when num_boxes is empty, instead of raising ValueError)
    """
    B = num_boxes.shape[0]
    Ns = num_boxes.tolist()
    # ``default=0`` fixes a crash on empty batches: ``max([])`` raises.
    max_n = max(Ns, default=0)
    # new_full covers both the zero and non-zero fill cases in one call.
    boxes_padded = boxes_packed.new_full(
        (B, max_n, *boxes_packed.shape[1:]), fill_value
    )
    prev_idx = 0
    for i in range(B):
        next_idx = prev_idx + Ns[i]
        boxes_padded[i, : Ns[i]] = boxes_packed[prev_idx:next_idx]
        prev_idx = next_idx
    return boxes_padded
def pad_tensor_list_to_longest(
    tensors: List[torch.Tensor], dim=0, pad_val=0
) -> List[torch.Tensor]:
    """
    Right-pad every tensor in ``tensors`` along ``dim`` to the longest length.

    The list is edited in place (entries are replaced) and also returned.
    """
    if not tensors:
        return tensors
    target_len = max(t.shape[dim] for t in tensors)
    for idx, t in enumerate(tensors):
        rank = t.dim()
        # Number of dimensions to the right of ``dim`` (handles negative dim).
        trailing = rank - 1 - (dim % rank)
        # F.pad consumes (left, right) pairs starting from the LAST dimension,
        # so leave ``trailing`` untouched pairs before padding ``dim``.
        pad_spec = (0, 0) * trailing + (0, target_len - t.shape[dim])
        tensors[idx] = torch.nn.functional.pad(t, pad_spec, value=pad_val)
    return tensors
def collate_fn_api_with_chunking(
    batch,
    num_chunks,
    dict_key,
    with_seg_masks=False,
    input_points_embedding_dim=257,
    repeats: int = 0,
    load_image_in_fp16: bool = False,
):
    """
    Collate ``batch`` as ``num_chunks`` independently-collated sub-batches.

    The batch is split round-robin (element i goes to chunk i % num_chunks)
    and each chunk is collated with :func:`collate_fn_api` using the same
    options.

    Returns:
        List of ``num_chunks`` collated chunks.
    """
    assert num_chunks >= 1, "num_chunks must be >= 1"
    return [
        collate_fn_api(
            batch[chunk_idx::num_chunks],
            dict_key,
            with_seg_masks,
            input_points_embedding_dim,
            repeats,
            load_image_in_fp16,
        )
        for chunk_idx in range(num_chunks)
    ]
def collate_fn_api(
    batch: List[Datapoint],
    dict_key,
    with_seg_masks=False,
    input_points_embedding_dim=257,
    repeats: int = 0,
    load_image_in_fp16: bool = False,
):
    """Collate a list of `Datapoint`s into a single `BatchedDatapoint`.

    Find queries are bucketed per processing stage (`query_processing_order`):
    one `FindStage` (inputs), one `BatchedFindTarget` (targets) and one
    `BatchedInferenceMetadata` per stage. Variable-length geometric inputs
    (boxes / points) are padded to the longest sequence in their stage; the
    accompanying masks use 1 = padded / masked-out.

    Args:
        batch: datapoints to collate; all images must share the same spatial
            size (asserted before stacking below).
        dict_key: key under which the `BatchedDatapoint` is stored in the
            returned dict.
        with_seg_masks: when True, collect per-object segmentation targets;
            objects without a segment get an all-zero dummy mask and are
            flagged via is_valid_segment=0.
        input_points_embedding_dim: last-dim size used for the empty points
            tensor of queries that carry no point inputs.
        repeats: when > 0, every target box is additionally replicated
            `repeats` times into `repeated_boxes`.
        load_image_in_fp16: cast the stacked image batch to half precision
            (saves memory on long videos).

    Returns:
        ``{dict_key: BatchedDatapoint(...)}``
    """
    # img_batch = torch.stack(sum([[img.data for img in v.images] for v in batch], []))
    img_batch = []
    # De-duplicated list of query texts; stages store indices into this list.
    text_batch = []
    raw_images = None
    # Number of stages = highest stage id seen across the whole batch + 1.
    num_stages = (
        max(q.query_processing_order for data in batch for q in data.find_queries) + 1
    )
    stages = [
        FindStage(
            img_ids=[],
            text_ids=[],
            input_boxes=[],
            input_boxes_label=[],
            input_boxes_mask=[],
            input_points=[],
            input_points_mask=[],
            object_ids=[],
        )
        for _ in range(num_stages)
    ]
    find_targets = [
        BatchedFindTarget(
            num_boxes=[],
            boxes=[],
            boxes_padded=[],
            is_exhaustive=[],
            segments=[],
            semantic_segments=[],
            is_valid_segment=[],
            repeated_boxes=[],
            object_ids=[],
            object_ids_padded=[],
        )
        for _ in range(num_stages)
    ]
    find_metadatas = [
        BatchedInferenceMetadata(
            coco_image_id=[],
            original_size=[],
            object_id=[],
            frame_index=[],
            original_image_id=[],
            original_category_id=[],
            is_conditioning_only=[],
        )
        for _ in range(num_stages)
    ]
    offset_img_id = 0
    offset_query_id = [0 for _ in range(num_stages)]
    for i, data in enumerate(batch):
        img_batch.extend([img.data for img in data.images])
        if data.raw_images is not None:
            if raw_images is None:
                raw_images = []
            raw_images.extend(data.raw_images)
        # Conversion of query_ids indexing in a datapoint to query_ids indexing in a stage
        datapoint_query_id_2_stage_query_id = []
        for q in data.find_queries:
            stage_id = q.query_processing_order
            datapoint_query_id_2_stage_query_id.append(offset_query_id[stage_id])
            offset_query_id[stage_id] += 1
        for j, q in enumerate(data.find_queries):
            stage_id = q.query_processing_order
            # Shift per-datapoint image ids into the flat batch-wide image list.
            stages[stage_id].img_ids.append(q.image_id + offset_img_id)
            if q.query_text not in text_batch:
                text_batch.append(q.query_text)
            stages[stage_id].text_ids.append(text_batch.index(q.query_text))
            assert (
                q.inference_metadata is not None
            ), "inference_metadata must be provided when FindQueryLoaded is created."
            # Scatter every metadata field into the per-stage column lists.
            for f in fields(q.inference_metadata):
                getattr(find_metadatas[stage_id], f.name).append(
                    getattr(q.inference_metadata, f.name)
                )
            if q.input_bbox is not None:
                assert q.input_bbox.numel() % 4 == 0
                assert q.input_bbox_label is not None
                nb_boxes = q.input_bbox.numel() // 4
                assert len(q.input_bbox_label) == nb_boxes
                stages[stage_id].input_boxes.append(q.input_bbox.view(nb_boxes, 4))
                stages[stage_id].input_boxes_label.append(
                    q.input_bbox_label.view(nb_boxes)
                )
                # 0 = real (unmasked) box entries.
                stages[stage_id].input_boxes_mask.append(
                    torch.zeros(nb_boxes, dtype=torch.bool)
                )
            else:
                # No box input: append empty tensors so stage lists stay aligned.
                stages[stage_id].input_boxes.append(torch.zeros(0, 4))
                stages[stage_id].input_boxes_label.append(
                    torch.zeros(0, dtype=torch.bool)
                )
                stages[stage_id].input_boxes_mask.append(
                    torch.ones(0, dtype=torch.bool)
                )
            if q.input_points is not None:
                stages[stage_id].input_points.append(
                    q.input_points.squeeze(0)  # Strip a trivial batch index
                )
                # All masks will be padded up to the longest length
                # with 1s before final conversion to batched tensors
                stages[stage_id].input_points_mask.append(
                    torch.zeros(q.input_points.shape[1])
                )
            else:
                stages[stage_id].input_points.append(
                    torch.empty(0, input_points_embedding_dim)
                )
                stages[stage_id].input_points_mask.append(torch.empty(0))
            current_out_boxes = []
            current_out_object_ids = []
            # Set the object ids referred to by this query
            stages[stage_id].object_ids.append(q.object_ids_output)
            for object_id in q.object_ids_output:
                current_out_boxes.append(
                    data.images[q.image_id].objects[object_id].bbox
                )
                current_out_object_ids.append(object_id)
            find_targets[stage_id].boxes.extend(current_out_boxes)
            find_targets[stage_id].object_ids.extend(current_out_object_ids)
            if repeats > 0:
                for _ in range(repeats):
                    find_targets[stage_id].repeated_boxes.extend(current_out_boxes)
            find_targets[stage_id].num_boxes.append(len(current_out_boxes))
            find_targets[stage_id].is_exhaustive.append(q.is_exhaustive)
            if with_seg_masks:
                current_seg_mask = []
                current_is_valid_segment = []
                for object_id in q.object_ids_output:
                    seg_mask = data.images[q.image_id].objects[object_id].segment
                    if seg_mask is not None:
                        current_seg_mask.append(seg_mask)
                        current_is_valid_segment.append(1)
                    else:
                        # Keep tensor shapes consistent: substitute an all-zero
                        # mask and mark it invalid so losses can ignore it.
                        dummy_mask = torch.zeros(
                            data.images[q.image_id].data.shape[-2:], dtype=torch.bool
                        )
                        current_seg_mask.append(dummy_mask)
                        current_is_valid_segment.append(0)
                find_targets[stage_id].segments.extend(current_seg_mask)
                find_targets[stage_id].is_valid_segment.extend(current_is_valid_segment)
            else:
                # We are not loading segmentation masks
                find_targets[stage_id].segments = None
                find_targets[stage_id].is_valid_segment = None
            if q.semantic_target is not None:
                find_targets[stage_id].semantic_segments.append(q.semantic_target)
        offset_img_id += len(data.images)
    # Pad input points to equal sequence lengths
    for i in range(len(stages)):
        stages[i].input_points = pad_tensor_list_to_longest(
            stages[i].input_points, dim=0, pad_val=0
        )
        # Masked-out regions indicated by 1s.
        stages[i].input_points_mask = pad_tensor_list_to_longest(
            stages[i].input_points_mask, dim=0, pad_val=1
        )
    # Pad input boxes to equal sequence lengths
    for i in range(len(stages)):
        stages[i].input_boxes = pad_tensor_list_to_longest(
            stages[i].input_boxes, dim=0, pad_val=0
        )
        stages[i].input_boxes_label = pad_tensor_list_to_longest(
            stages[i].input_boxes_label, dim=0, pad_val=0
        )
        # Masked-out regions indicated by 1s.
        stages[i].input_boxes_mask = pad_tensor_list_to_longest(
            stages[i].input_boxes_mask, dim=0, pad_val=1
        )
    # Convert to tensors
    for i in range(len(stages)):
        stages[i] = convert_my_tensors(stages[i])
        find_targets[i] = convert_my_tensors(find_targets[i])
        find_metadatas[i] = convert_my_tensors(find_metadatas[i])
        # get padded representation for the boxes
        find_targets[i].boxes_padded = packed_to_padded_naive(
            find_targets[i].boxes.view(-1, 4), find_targets[i].num_boxes
        )
        # -1 marks padding slots in the padded object-id tensor.
        find_targets[i].object_ids_padded = packed_to_padded_naive(
            find_targets[i].object_ids, find_targets[i].num_boxes, fill_value=-1
        )
    # Finalize the image batch
    # check sizes
    for img in img_batch[1:]:
        assert img.shape == img_batch[0].shape, "All images must have the same size"
    image_batch = torch.stack(img_batch)
    if load_image_in_fp16:
        # Optionally, cast the image tensors to fp16, which helps save GPU memory on
        # long videos with thousands of frames (where image tensors could be several GBs)
        image_batch = image_batch.half()
    return {
        dict_key: BatchedDatapoint(
            img_batch=image_batch,
            find_text_batch=text_batch,
            find_inputs=stages,
            find_targets=find_targets,
            find_metadatas=find_metadatas,
            raw_images=raw_images,
        )
    }

View File

@@ -0,0 +1,528 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""Dataset class for modulated detection"""
import json
import os
import random
import sys
import traceback
from collections import Counter
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils.data
import torchvision
from decord import cpu, VideoReader
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from PIL.Image import DecompressionBombError
from sam3.model.box_ops import box_xywh_to_xyxy
from torchvision.datasets.vision import VisionDataset
from .coco_json_loaders import COCO_FROM_JSON
@dataclass
class InferenceMetadata:
    """Metadata required for postprocessing.

    The dataset loader fills unavailable ids with -1 sentinels (e.g. during
    training, or when the annotation lacks the corresponding original id).
    """

    # Coco id that corresponds to the "image" for evaluation by the coco evaluator
    # This is used for our own "class agnostic" evaluation
    coco_image_id: int
    # id in the original dataset, such that we can use the original evaluator
    original_image_id: int
    # Original category id (if we want to use the original evaluator)
    original_category_id: int
    # Size of the raw image (height, width)
    original_size: Tuple[int, int]
    # Id of the object in the media
    object_id: int
    # Index of the frame in the media (0 if single image)
    frame_index: int
    # Whether it is for conditioning only, e.g., 0-th frame in TA is for conditioning
    # as we assume GT available in frame 0.
    is_conditioning_only: Optional[bool] = False
@dataclass
class FindQuery:
    """A single 'find' query: a text and/or geometric prompt over one image,
    together with the ids of the objects it should predict."""

    # Text prompt (the loader substitutes "" when the annotation has none).
    query_text: str
    # Index of the image this query operates on, within its datapoint.
    image_id: int
    # In case of a find query, the list of object ids that have to be predicted
    object_ids_output: List[int]
    # This is "instance exhaustivity".
    # true iff all instances are separable and annotated
    # See below the slightly different "pixel exhaustivity"
    is_exhaustive: bool
    # The order in which the queries are processed (only meaningful for video)
    query_processing_order: int = 0
    # Input geometry, initially in denormalized XYXY format. Then
    # 1. converted to normalized CxCyWH by the Normalize transform
    input_bbox: Optional[torch.Tensor] = None
    input_bbox_label: Optional[torch.Tensor] = None
    # Only for the PVS task
    input_points: Optional[torch.Tensor] = None
    semantic_target: Optional[torch.Tensor] = None
    # pixel exhaustivity: true iff the union of all segments (including crowds)
    # covers every pixel belonging to the target class
    # Note that instance_exhaustive implies pixel_exhaustive
    is_pixel_exhaustive: Optional[bool] = None
@dataclass
class FindQueryLoaded(FindQuery):
    """`FindQuery` as produced by the dataset loader, with evaluation
    metadata attached."""

    # Must have default value since FindQuery has entries with default values
    inference_metadata: Optional[InferenceMetadata] = None
@dataclass
class Object:
    """A single annotated object (box plus optional segment) in one frame."""

    # Initially in denormalized XYXY format, gets converted to normalized CxCyWH by the Normalize transform
    bbox: torch.Tensor
    area: float
    # Id of the object in the media
    object_id: Optional[int] = -1
    # Index of the frame in the media (0 if single image)
    frame_index: Optional[int] = -1
    segment: Optional[Union[torch.Tensor, dict]] = None  # RLE dict or binary mask
    # NOTE(review): declared bool with default False, but the COCO loader may
    # pass None when the annotation has no "is_crowd" key -- confirm that
    # downstream consumers handle None.
    is_crowd: bool = False
    source: Optional[str] = None
@dataclass
class Image:
    """One frame and the objects annotated on it."""

    # Pixel data; a PIL image when loaded, presumably a tensor after transforms.
    data: Union[torch.Tensor, PILImage.Image]
    objects: List[Object]
    size: Tuple[int, int]  # (height, width)
    # For blurring augmentation
    blurring_mask: Optional[Dict[str, Any]] = None
@dataclass
class Datapoint:
    """Refers to an image/video and all its annotations"""

    # Queries (with attached inference metadata) over the images below.
    find_queries: List[FindQueryLoaded]
    # One entry per frame (a single entry for still images).
    images: List[Image]
    # Optional untransformed copies of the frames, kept alongside the
    # (possibly transformed) `images` -- intended use to be confirmed.
    raw_images: Optional[List[PILImage.Image]] = None
class CustomCocoDetectionAPI(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.
    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: str,
        annFile: str,
        load_segmentation: bool,
        fix_fname: bool = False,
        training: bool = True,
        blurring_masks_path: Optional[str] = None,
        use_caching: bool = True,
        zstd_dict_path=None,
        filter_query=None,
        coco_json_loader: Callable = COCO_FROM_JSON,
        limit_ids: int = None,
    ) -> None:
        """Create the dataset and eagerly parse the annotation file.

        Args:
            root: image root directory (joined with per-image file names).
            annFile: path to the COCO-style json annotation file.
            load_segmentation: keep (undecoded) segmentations on loaded objects.
            fix_fname: keep only the basename of each image file name.
            training: training vs. evaluation mode; affects the id sentinels
                written into `InferenceMetadata` in `load_queries`.
            blurring_masks_path: optional directory of per-image `*-mask.json`
                files used for blurring augmentation.
            use_caching / zstd_dict_path / filter_query: stored as-is; not
                read in this class (presumably consumed elsewhere -- TODO
                confirm).
            coco_json_loader: factory parsing the json into a COCO-like API.
            limit_ids: if set, keep only this many (deterministically
                shuffled) datapoint ids.
        """
        super().__init__(root)
        self.annFile = annFile
        self.use_caching = use_caching
        self.zstd_dict_path = zstd_dict_path
        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.load_segmentation = load_segmentation
        self.fix_fname = fix_fname
        self.filter_query = filter_query
        self.coco = None
        self.coco_json_loader = coco_json_loader
        self.limit_ids = limit_ids
        self.set_sharded_annotation_file(0)
        self.training = training
        self.blurring_masks_path = blurring_masks_path

    def _load_images(
        self, datapoint_id: int, img_ids_to_load: Optional[Set[int]] = None
    ) -> Tuple[List[Tuple[int, PILImage.Image]], List[Dict[str, Any]]]:
        """Load all (or a subset of) images belonging to one datapoint.

        Args:
            datapoint_id: id of the datapoint in the annotation file.
            img_ids_to_load: optional whitelist of image ids; others skipped.

        Returns:
            A pair ``(images, metadata)``: ``images`` is a list of
            ``(image_id, PIL.Image)`` tuples, ``metadata`` the matching list
            of per-image metadata dicts (possibly enriched with a
            ``blurring_mask`` entry).
        """
        all_images = []
        all_img_metadata = []
        for current_meta in self.coco.loadImagesFromDatapoint(datapoint_id):
            img_id = current_meta["id"]
            if img_ids_to_load is not None and img_id not in img_ids_to_load:
                continue
            if self.fix_fname:
                current_meta["file_name"] = current_meta["file_name"].split("/")[-1]
            path = current_meta["file_name"]
            if self.blurring_masks_path is not None:
                # Side-load the optional blurring mask stored as
                # `<basename>-mask.json` in the masks directory.
                mask_fname = os.path.basename(path).replace(".jpg", "-mask.json")
                mask_path = os.path.join(self.blurring_masks_path, mask_fname)
                if os.path.exists(mask_path):
                    with open(mask_path, "r") as fopen:
                        current_meta["blurring_mask"] = json.load(fopen)
            all_img_metadata.append(current_meta)
            path = os.path.join(self.root, path)
            try:
                if ".mp4" in path and path[-4:] == ".mp4":
                    # Going to load a video frame
                    # NOTE(review): this branch splits on "@" to get a
                    # "<video>@<frame>" pair, but the `path[-4:] == ".mp4"`
                    # guard means a plain ".mp4" path with no "@" would fail
                    # the unpack below -- confirm the expected path format.
                    video_path, frame = path.split("@")
                    video = VideoReader(video_path, ctx=cpu(0))
                    # Convert to PIL image
                    all_images.append(
                        (
                            img_id,
                            torchvision.transforms.ToPILImage()(
                                video[int(frame)].asnumpy()
                            ),
                        )
                    )
                else:
                    with g_pathmgr.open(path, "rb") as fopen:
                        all_images.append((img_id, PILImage.open(fopen).convert("RGB")))
            except FileNotFoundError as e:
                print(f"File not found: {path} from dataset: {self.annFile}")
                raise e
        return all_images, all_img_metadata

    def set_curr_epoch(self, epoch: int):
        """Record the current epoch (read by subclasses when applying transforms)."""
        self.curr_epoch = epoch

    def set_epoch(self, epoch: int):
        # Intentionally a no-op; kept for compatibility with callers that
        # invoke `set_epoch` on datasets.
        pass

    def set_sharded_annotation_file(self, data_epoch: int):
        """Parse the annotation file and build the tensor of datapoint ids.

        Only runs once: later calls return early because `self.coco` is
        already populated (`data_epoch` is currently unused).
        """
        if self.coco is not None:
            return
        assert g_pathmgr.isfile(
            self.annFile
        ), f"please provide valid annotation file. Missing: {self.annFile}"
        annFile = g_pathmgr.get_local_path(self.annFile)
        # NOTE(review): unreachable -- the early return above guarantees
        # `self.coco` is None at this point.
        if self.coco is not None:
            del self.coco
        self.coco = self.coco_json_loader(annFile)
        # Use a torch tensor here to optimize memory usage when using several dataloaders
        ids_list = list(sorted(self.coco.getDatapointIds()))
        if self.limit_ids is not None:
            # Deterministic subsample: seeding with the list length makes
            # every worker/process select the same subset.
            local_random = random.Random(len(ids_list))
            local_random.shuffle(ids_list)
            ids_list = ids_list[: self.limit_ids]
        self.ids = torch.as_tensor(ids_list, dtype=torch.long)

    def __getitem__(self, index: int) -> Datapoint:
        return self._load_datapoint(index)

    def _load_datapoint(self, index: int) -> Datapoint:
        """A separate method for easy overriding in subclasses."""
        id = self.ids[index].item()
        pil_images, img_metadata = self._load_images(id)
        queries, annotations = self.coco.loadQueriesAndAnnotationsFromDatapoint(id)
        return self.load_queries(pil_images, annotations, queries, img_metadata)

    def load_queries(self, pil_images, annotations, queries, img_metadata):
        """Transform the raw image and queries into a Datapoint sample.

        Builds `Image`/`Object` structures from the annotations (scaling the
        normalized boxes to absolute pixel coordinates) and one
        `FindQueryLoaded` per query with its `InferenceMetadata`.
        """
        images: List[Image] = []
        id2index_img = {}
        id2index_obj = {}
        id2index_find_query = {}
        id2imsize = {}
        assert len(pil_images) == len(img_metadata)
        for i in range(len(pil_images)):
            w, h = pil_images[i][1].size
            blurring_mask = None
            if "blurring_mask" in img_metadata[i]:
                blurring_mask = img_metadata[i]["blurring_mask"]
            images.append(
                Image(
                    data=pil_images[i][1],
                    objects=[],
                    size=(h, w),
                    blurring_mask=blurring_mask,
                )
            )
            id2index_img[pil_images[i][0]] = i
            id2imsize[pil_images[i][0]] = (h, w)
        for annotation in annotations:
            image_id = id2index_img[annotation["image_id"]]
            # Annotation boxes are stored normalized XYWH; convert to XYXY,
            # scale to absolute pixels and clamp to the image bounds.
            bbox = box_xywh_to_xyxy(torch.as_tensor(annotation["bbox"])).view(1, 4)
            h, w = id2imsize[annotation["image_id"]]
            bbox[:, 0::2].mul_(w).clamp_(min=0, max=w)
            bbox[:, 1::2].mul_(h).clamp_(min=0, max=h)
            segment = None
            if self.load_segmentation and "segmentation" in annotation:
                # We're not decoding the RLE here, a transform will do it lazily later
                segment = annotation["segmentation"]
            images[image_id].objects.append(
                Object(
                    bbox=bbox[0],
                    area=annotation["area"],
                    object_id=(
                        annotation["object_id"] if "object_id" in annotation else -1
                    ),
                    frame_index=(
                        annotation["frame_index"] if "frame_index" in annotation else -1
                    ),
                    segment=segment,
                    is_crowd=(
                        annotation["is_crowd"] if "is_crowd" in annotation else None
                    ),
                    source=annotation["source"] if "source" in annotation else "",
                )
            )
            id2index_obj[annotation["id"]] = len(images[image_id].objects) - 1
        find_queries = []
        stage2num_queries = Counter()
        for i, query in enumerate(queries):
            stage2num_queries[query["query_processing_order"]] += 1
            id2index_find_query[query["id"]] = i
        # Sanity check: all the stages should have the same number of queries
        if len(stage2num_queries) == 0:
            num_queries_per_stage = 0
        else:
            num_queries_per_stage = stage2num_queries.most_common(1)[0][1]
        for stage, num_queries in stage2num_queries.items():
            assert (
                num_queries == num_queries_per_stage
            ), f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"
        for query_id, query in enumerate(queries):
            h, w = id2imsize[query["image_id"]]
            if (
                "input_box" in query
                and query["input_box"] is not None
                and len(query["input_box"]) > 0
            ):
                # Input boxes follow the same normalized-XYWH convention as
                # the annotations above.
                bbox = box_xywh_to_xyxy(torch.as_tensor(query["input_box"])).view(-1, 4)
                bbox[:, 0::2].mul_(w).clamp_(min=0, max=w)
                bbox[:, 1::2].mul_(h).clamp_(min=0, max=h)
                if "input_box_label" in query and query["input_box_label"] is not None:
                    bbox_label = torch.as_tensor(
                        query["input_box_label"], dtype=torch.long
                    ).view(-1)
                    assert len(bbox_label) == len(bbox)
                else:
                    # assume the boxes are positives
                    bbox_label = torch.ones(len(bbox), dtype=torch.long)
            else:
                bbox = None
                bbox_label = None
            if "input_points" in query and query["input_points"] is not None:
                # Points are reshaped to (1, K, 3); the first two channels are
                # x/y and get denormalized here (third channel presumably a
                # point label -- TODO confirm).
                points = torch.as_tensor(query["input_points"]).view(1, -1, 3)
                points[:, :, 0:1].mul_(w).clamp_(min=0, max=w)
                points[:, :, 1:2].mul_(h).clamp_(min=0, max=h)
            else:
                points = None
            try:
                original_image_id = int(
                    img_metadata[id2index_img[query["image_id"]]]["original_img_id"]
                )
            except ValueError:
                original_image_id = -1
            try:
                img_metadata_query = img_metadata[id2index_img[query["image_id"]]]
                coco_image_id = (
                    int(img_metadata_query["coco_img_id"])
                    if "coco_img_id" in img_metadata_query
                    else query["id"]
                )
            except KeyError:
                coco_image_id = -1
            try:
                original_category_id = int(query["original_cat_id"])
            except (ValueError, KeyError):
                original_category_id = -1
            # For evaluation, we associate the ids of the object to be tracked to the query
            if query["object_ids_output"]:
                obj_id = query["object_ids_output"][0]
                obj_idx = id2index_obj[obj_id]
                image_idx = id2index_img[query["image_id"]]
                object_id = images[image_idx].objects[obj_idx].object_id
                frame_index = images[image_idx].objects[obj_idx].frame_index
            else:
                object_id = -1
                frame_index = -1
            find_queries.append(
                FindQueryLoaded(
                    # id=query["id"],
                    # query_type=qtype,
                    query_text=(
                        query["query_text"] if query["query_text"] is not None else ""
                    ),
                    image_id=id2index_img[query["image_id"]],
                    input_bbox=bbox,
                    input_bbox_label=bbox_label,
                    input_points=points,
                    object_ids_output=[
                        id2index_obj[obj_id] for obj_id in query["object_ids_output"]
                    ],
                    is_exhaustive=query["is_exhaustive"],
                    is_pixel_exhaustive=(
                        query["is_pixel_exhaustive"]
                        if "is_pixel_exhaustive" in query
                        else (
                            # instance-exhaustive implies pixel-exhaustive;
                            # otherwise it is unknown (None).
                            query["is_exhaustive"] if query["is_exhaustive"] else None
                        )
                    ),
                    query_processing_order=query["query_processing_order"],
                    inference_metadata=InferenceMetadata(
                        # During training the evaluator ids are irrelevant, so
                        # -1 sentinels are used instead.
                        coco_image_id=-1 if self.training else coco_image_id,
                        original_image_id=(-1 if self.training else original_image_id),
                        frame_index=frame_index,
                        original_category_id=original_category_id,
                        original_size=(h, w),
                        object_id=object_id,
                    ),
                )
            )
        return Datapoint(
            find_queries=find_queries,
            images=images,
            raw_images=[p[1] for p in pil_images],
        )

    def __len__(self) -> int:
        return len(self.ids)
class Sam3ImageDataset(CustomCocoDetectionAPI):
    """COCO-style image dataset that applies SAM3 transforms and skips
    over-sized or unreadable datapoints by retrying with subsequent indices."""

    def __init__(
        self,
        img_folder,
        ann_file,
        transforms,
        max_ann_per_img: int,
        multiplier: int,
        training: bool,
        load_segmentation: bool = False,
        max_train_queries: int = 81,
        max_val_queries: int = 300,
        fix_fname: bool = False,
        is_sharded_annotation_dir: bool = False,
        blurring_masks_path: Optional[str] = None,
        use_caching: bool = True,
        zstd_dict_path=None,
        filter_query=None,
        coco_json_loader: Callable = COCO_FROM_JSON,
        limit_ids: int = None,
    ):
        """Build the dataset.

        Args:
            transforms: iterable of callables applied to each datapoint; each
                is called as ``transform(datapoint, epoch=...)``.
            max_ann_per_img: skip datapoints containing a query with more
                target objects than this.
            multiplier: uniform repeat factor applied to every datapoint
                (stored in `repeat_factors` for repeat-factor sampling).
            max_train_queries / max_val_queries: per-datapoint query caps in
                training / evaluation mode respectively.
            is_sharded_annotation_dir: accepted but not used here (not
                forwarded to the base class) -- TODO confirm.
            Remaining arguments are forwarded to `CustomCocoDetectionAPI`.
        """
        super(Sam3ImageDataset, self).__init__(
            img_folder,
            ann_file,
            fix_fname=fix_fname,
            load_segmentation=load_segmentation,
            training=training,
            blurring_masks_path=blurring_masks_path,
            use_caching=use_caching,
            zstd_dict_path=zstd_dict_path,
            filter_query=filter_query,
            coco_json_loader=coco_json_loader,
            limit_ids=limit_ids,
        )
        self._transforms = transforms
        self.training = training
        self.max_ann_per_img = max_ann_per_img
        self.max_train_queries = max_train_queries
        self.max_val_queries = max_val_queries
        self.repeat_factors = torch.ones(len(self.ids), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.ids)}")
        self._MAX_RETRIES = 100

    def __getitem__(self, idx):
        return self.__orig_getitem__(idx)

    def __orig_getitem__(self, idx):
        """Load and transform a datapoint, retrying with the next index on failure.

        `DecompressionBombError` is (re)used as a generic "skip this
        datapoint" signal for datapoints exceeding the query/annotation caps.
        """
        for _ in range(self._MAX_RETRIES):
            try:
                datapoint = super(Sam3ImageDataset, self).__getitem__(idx)
                # This can be done better by filtering the offending find queries
                # However, this requires care:
                # - Delete any find/get query that may depend on the deleted one
                # - Re-compute the indexes in the pointers to account for the deleted finds
                for q in datapoint.find_queries:
                    if len(q.object_ids_output) > self.max_ann_per_img:
                        raise DecompressionBombError(
                            f"Too many outputs ({len(q.object_ids_output)})"
                        )
                max_queries = (
                    self.max_train_queries if self.training else self.max_val_queries
                )
                if len(datapoint.find_queries) > max_queries:
                    raise DecompressionBombError(
                        f"Too many find queries ({len(datapoint.find_queries)})"
                    )
                if len(datapoint.find_queries) == 0:
                    raise DecompressionBombError("No find queries")
                for transform in self._transforms:
                    datapoint = transform(datapoint, epoch=self.curr_epoch)
                break
            except (DecompressionBombError, OSError, ValueError) as error:
                sys.stderr.write(f"ERROR: got loading error on datapoint {idx}\n")
                sys.stderr.write(f"Exception: {error}\n")
                sys.stderr.write(traceback.format_exc())
                # Fall back to the next datapoint (wrapping around).
                idx = (idx + 1) % len(self)
        else:
            # for/else: the loop never hit `break`, i.e. every retry failed.
            raise RuntimeError(
                f"Failed {self._MAX_RETRIES} times trying to load an image."
            )
        return datapoint

View File

@@ -0,0 +1,327 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import copy
import io
import json
import logging
import math
import os
import pickle
import random
import sys
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
import torchvision
# from decord import cpu, VideoReader
from iopath.common.file_io import PathManager
from PIL import Image as PILImage
from .sam3_image_dataset import Datapoint, Sam3ImageDataset
SEED = 42
class VideoGroundingDataset(Sam3ImageDataset):
    def __init__(
        self,
        num_stages_sample: int = 4,
        stage_stride_min: int = 1,
        stage_stride_max: int = 5,
        random_reverse_time_axis: bool = True,
        is_tiling_single_image: bool = False,
        # By default, we remove those find queries with geometric inputs (input_box or input_points)
        # when creating synthetic videos from frames (since they are not *video-level* text prompts).
        # If we need them later, we can sample them on-the-fly via transforms or inside the model.
        tile_img_keep_find_queries_with_geo_inputs: bool = False,
        tile_img_keep_get_queries: bool = False,
        # the maximum number of find queries (for each frame) to keep in a video; if the datapoint
        # contains more queries per frame than this limit, we subsample them to avoid OOM errors
        max_query_num: int = -1,  # the default -1 means no limit
        # whether to override the "is_exhaustive" flag of the loaded find queries to True
        # (by default, our video datasets are ingested with is_exhaustive=False, since the YTVIS format
        # annotations doesn't involve an "is_exhaustive" flag; this means that those unmatched (negative)
        # detection queries or tracking queries do not receive a classification loss given that we have
        # weak_loss=True in IABCEMdetr -- this could lead to false positives for both image detection
        # and video association.)
        override_query_is_exhaustive_to_true: bool = False,
        # the maximum number of masklets in a video; if the datapoint contains more masklets
        # than this limit, we skip the datapoint to avoid OOM errors (this is useful for
        # training with large videos that contain many objects)
        max_masklet_num_in_video: int = 300,  # 300 masklets is usually OK to avoid OOM
        **kwargs,
    ):
        """
        Loading video grounding data
        Video frame sampling parameters (for training only):
        - num_stages_sample: number of frames to sample from the video during training
        - stage_stride_min: minimum stride between sampled frames during training
        - stage_stride_max: maximum stride between sampled frames during training (if it's
          greater than stage_stride_min, the actual stride is sampled uniformly between min
          and max; during inference, we always use all frames in the video with stride=1)
        - random_reverse_time_axis: whether to randomly invert the video's temporal axis
          (i.e. playing it backwards) during training
        - is_tiling_single_image: treat a single still image as a synthetic video by
          tiling it `num_stages_sample` times (see `_tile_single_image_data`)
        Remaining keyword arguments are forwarded to `Sam3ImageDataset`.
        """
        super().__init__(**kwargs)
        assert num_stages_sample >= 1
        assert stage_stride_min >= 1
        assert stage_stride_max >= stage_stride_min
        self.num_stages_sample = num_stages_sample
        self.stage_stride_min = stage_stride_min
        self.stage_stride_max = stage_stride_max
        self.random_reverse_time_axis = random_reverse_time_axis
        self.is_tiling_single_image = is_tiling_single_image
        self.tile_img_keep_find_queries_with_geo_inputs = (
            tile_img_keep_find_queries_with_geo_inputs
        )
        self.tile_img_keep_get_queries = tile_img_keep_get_queries
        self.max_query_num = max_query_num
        self.override_query_is_exhaustive_to_true = override_query_is_exhaustive_to_true
        self.max_masklet_num_in_video = max_masklet_num_in_video
        # Dataset-local RNG; reseeded per epoch in `set_curr_epoch` so frame
        # sampling is reproducible.
        self.rng = random.Random()
        self.set_curr_epoch(0)
def set_curr_epoch(self, epoch: int):
super().set_curr_epoch(epoch)
self.rng.seed(SEED + epoch)
    def _load_datapoint(self, index: int) -> Datapoint:
        """Load one video datapoint.

        During training (when not tiling a single image) a strided window of
        `num_stages_sample` frames is sampled and optionally time-reversed;
        otherwise all frames are loaded. Datapoints with more masklets than
        `max_masklet_num_in_video` are skipped in favor of the next index.
        """
        id = self.ids[index].item()
        queries, annotations = self.coco.loadQueriesAndAnnotationsFromDatapoint(id)
        # we subsample the video frames during training
        if self.training and not self.is_tiling_single_image:
            # pick a random stride for sampling query stages (`randint` includes both ends)
            stage_stride = self.rng.randint(
                self.stage_stride_min, self.stage_stride_max
            )
            stage_ids_to_keep = self._sample_stage_ids(
                queries, self.num_stages_sample, stage_stride
            )
            # filter the queries and annotations to keep only the selected stages
            # (also remap the stage ids so that they are contiguous and start from 0)
            reverse_time_axis = (
                self.rng.random() < 0.5 if self.random_reverse_time_axis else False
            )
            queries, annotations, kept_img_ids = self._filter_query_and_anns(
                queries,
                annotations,
                stage_ids_to_keep,
                remap_stage_id=True,
                reverse_time_axis=reverse_time_axis,
            )
            pil_images, img_metadata = self._load_images(id, kept_img_ids)
            if reverse_time_axis:
                # reverse the temporal ordering of the images and their metadata
                # so that the image order matches the query order
                pil_images = pil_images[::-1]
                img_metadata = img_metadata[::-1]
        else:
            pil_images, img_metadata = self._load_images(id)
        # check that all the images have the same image size (they are expected
        # to have the same image size since they are frames from the same video)
        assert all(p.size == pil_images[0][1].size for _, p in pil_images)
        queries.sort(key=lambda q: q["query_processing_order"])
        if self.override_query_is_exhaustive_to_true:
            for query in queries:
                query["is_exhaustive"] = True
        datapoint = self.load_queries(pil_images, annotations, queries, img_metadata)
        # skip datapoints with too many masklets to avoid OOM errors
        num_masklets_in_video = len(datapoint.images[0].objects)
        if num_masklets_in_video > self.max_masklet_num_in_video > 0:
            logging.warning(
                f"Datapoint {id} has ({num_masklets_in_video=}), exceeding "
                f"the maximum allowed ({self.max_masklet_num_in_video}). "
                "Skipping this datapoint."
            )
            next_index = (index + 1) % len(self)
            # NOTE(review): recursion may get deep if many consecutive
            # datapoints exceed the limit -- confirm acceptable.
            return self._load_datapoint(next_index)  # move to the next datapoint
        if self.is_tiling_single_image:
            datapoint = self._tile_single_image_data(datapoint, self.num_stages_sample)
        if self.max_query_num > 0:
            datapoint = self._subsample_queries(datapoint, self.max_query_num)
        # ensure that all find queries have the same processing order as their image id
        for query in datapoint.find_queries:
            assert query.image_id == query.query_processing_order, (
                f"find query has inconsistent image_id and "
                f"query_processing_order: {query.image_id=} vs "
                f"{query.query_processing_order=}"
            )
        return datapoint
    def _sample_stage_ids(self, queries, num_stages_sample, stage_stride):
        """Sample a subset of stage ids from all queries.

        Picks `num_stages_sample` stage ids at a fixed stride with a random
        starting position; if the requested stride does not fit within the
        available stages, it is lowered to the largest stride that does.

        Raises:
            ValueError: if there are fewer stages than `num_stages_sample`.
        """
        # Later we can perhaps turn it into a Sampler class to be more flexible.
        all_stage_ids = sorted(set(q["query_processing_order"] for q in queries))
        num_stages_total = len(all_stage_ids)
        if num_stages_total < num_stages_sample:
            raise ValueError("Not enough stages to sample")
        # the difference in index between the first and the last sampled stage ids
        b_e_gap = (num_stages_sample - 1) * stage_stride
        if b_e_gap > num_stages_total - 1:
            # In this case, it's not possible to sample with the provided stride,
            # so we use the maximum possible stride.
            prev_stage_stride = stage_stride
            stage_stride = math.floor((num_stages_total - 1) / (num_stages_sample - 1))
            logging.info(
                f"lowering stride from {prev_stage_stride} to {stage_stride} to "
                f"sample {num_stages_sample} stages (from {num_stages_total} total)"
            )
            b_e_gap = (num_stages_sample - 1) * stage_stride
        # randomly select a starting stage id (`randint` includes both ends)
        b_max = len(all_stage_ids) - 1 - b_e_gap
        b = self.rng.randint(0, b_max)
        e = b + b_e_gap
        stage_ids_to_keep = all_stage_ids[b : e + 1 : stage_stride]
        return stage_ids_to_keep
    def _filter_query_and_anns(
        self, queries, annotations, stage_ids_to_keep, remap_stage_id, reverse_time_axis
    ):
        """Filter queries and annotations to only keep those in `stage_ids_to_keep`.

        Returns:
            ``(filtered_queries, filtered_annotations, kept_img_ids)`` where
            `kept_img_ids` is the set of image ids referenced by the kept
            queries. When `remap_stage_id` is True, surviving stage ids are
            renumbered contiguously from 0 (in descending temporal order if
            `reverse_time_axis` is set).
        """
        stage_ids_to_keep = set(stage_ids_to_keep)
        kept_img_ids = set()
        kept_stage_ids = set()
        # Filter queries -- keep those queries with stage_id in `stage_ids_to_keep`
        filtered_queries = []
        for query in queries:
            input_box = query.get("input_box", None)
            input_points = query.get("input_points", None)
            has_geo_input = input_box is not None or input_points is not None
            # NOTE(review): geometric-input queries are dropped based on
            # `tile_img_keep_find_queries_with_geo_inputs`, even though this
            # method also runs in the non-tiling training path -- confirm
            # this is intended.
            if has_geo_input and not self.tile_img_keep_find_queries_with_geo_inputs:
                continue
            stage_id = query["query_processing_order"]
            if stage_id in stage_ids_to_keep:
                kept_img_ids.add(query["image_id"])
                kept_stage_ids.add(stage_id)
                filtered_queries.append(query)
        # Check that all frames in `stage_ids_to_keep` are present after filtering
        all_frame_present = kept_stage_ids == stage_ids_to_keep
        assert all_frame_present, f"{kept_stage_ids=} vs {stage_ids_to_keep=}"
        if remap_stage_id:
            # Remap those kept stage ids to be contiguous and starting from 0
            old_stage_ids = sorted(kept_stage_ids, reverse=reverse_time_axis)
            stage_id_old2new = {old: new for new, old in enumerate(old_stage_ids)}
            for query in filtered_queries:
                ptr_x_is_empty = query["ptr_x_query_id"] in [None, -1]
                ptr_y_is_empty = query["ptr_y_query_id"] in [None, -1]
                assert (
                    ptr_x_is_empty and ptr_y_is_empty
                ), "Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers"
                query["query_processing_order"] = stage_id_old2new[
                    query["query_processing_order"]
                ]
        # Filter annotations -- keep those annotations with image_id in `kept_img_ids`
        filtered_annotations = [
            ann for ann in annotations if ann["image_id"] in kept_img_ids
        ]
        return filtered_queries, filtered_annotations, kept_img_ids
def _tile_single_image_data(self, datapoint: Datapoint, num_stages_sample: int):
"""
Tile a single image and its queries to simulate video frames. The output is a
datapoint with *identical video frames* (i.e. the same static image) and needs
further transforms (e.g. affine) to get video frames with different content.
"""
# tile `images: List[Image]`
assert len(datapoint.images) == 1, "Expected only one single image"
tiled_images = [
copy.deepcopy(datapoint.images[0]) for _ in range(num_stages_sample)
]
for stage_id, img in enumerate(tiled_images):
for obj in img.objects:
obj.frame_index = stage_id
# tile `raw_images: Optional[List[PILImage.Image]] = None`
tiled_raw_images = None
if datapoint.raw_images is not None:
assert len(datapoint.raw_images) == 1, "Expected only one single image"
tiled_raw_images = [
datapoint.raw_images[0].copy() for _ in range(num_stages_sample)
]
# tile `find_queries: List[FindQueryLoaded]`
tiled_find_queries_per_stage = [[] for _ in range(num_stages_sample)]
for query in datapoint.find_queries:
assert query.image_id == 0
assert query.query_processing_order == 0
# check and make sure that a query doesn't contain pointers or references
# to other queries (that cannot be tiled)
assert query.ptr_x is None and query.ptr_y is None
assert query.ptr_mem is None
# assert query.wkdata_qid is None
# assert query.other_positive_qids is None
# assert query.negative_qids is None
has_geo_input = (
query.input_bbox is not None or query.input_points is not None
)
if has_geo_input and not self.tile_img_keep_find_queries_with_geo_inputs:
continue
for stage_id in range(num_stages_sample):
# copy the query and update the image_id
new_query = copy.deepcopy(query)
new_query.image_id = stage_id
new_query.query_processing_order = stage_id
if new_query.inference_metadata is not None:
new_query.inference_metadata.frame_index = stage_id
tiled_find_queries_per_stage[stage_id].append(new_query)
tiled_find_queries = sum(tiled_find_queries_per_stage, [])
# tile `get_queries: List[GetQuery]` -- we skip them for now (since they involve
# a pointer to a find query that is complicated to tile, and there is not an
# imminent use case for them in the video grounding task in the near future)
if self.tile_img_keep_get_queries:
raise NotImplementedError("Tiling get queries is not implemented yet")
else:
tiled_get_queries = []
return Datapoint(
images=tiled_images,
raw_images=tiled_raw_images,
find_queries=tiled_find_queries,
get_queries=tiled_get_queries,
)
def _subsample_queries(self, datapoint: Datapoint, max_query_num: int):
"""Subsample to keep at most `max_query_num` queries per frame in a datapoint."""
# aggregate the find queries per stage
num_frames = max(q.query_processing_order for q in datapoint.find_queries) + 1
find_queries_per_stage = [[] for _ in range(num_frames)]
for query in datapoint.find_queries:
find_queries_per_stage[query.query_processing_order].append(query)
# verify that all the stages have the same number of queries
num_queries_per_stage = len(find_queries_per_stage[0])
for queries in find_queries_per_stage:
assert len(queries) == num_queries_per_stage
if max_query_num <= 0 or num_queries_per_stage <= max_query_num:
return datapoint
# subsample the queries to keep only `max_query_num` queries
sampled_inds = self.rng.sample(range(num_queries_per_stage), max_query_num)
sampled_find_queries_per_stage = [
[queries[idx] for idx in sampled_inds] for queries in find_queries_per_stage
]
sampled_find_queries = sum(sampled_find_queries_per_stage, [])
return Datapoint(
images=datapoint.images,
raw_images=datapoint.raw_images,
find_queries=sampled_find_queries,
get_queries=datapoint.get_queries,
)

View File

@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Callable, Iterable, Optional
from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset
class TorchDataset:
    """Thin wrapper around ``torch.utils.data.DataLoader`` that optionally
    attaches a ``DistributedSampler`` and propagates the epoch number to both
    the sampler and the wrapped dataset before building each loader."""

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        enable_distributed_sampler=True,
    ) -> None:
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn
        # Iterable-style datasets have no per-index sampling, so they are rejected
        assert not isinstance(self.dataset, IterableDataset), "Not supported yet"
        # NOTE(review): when the distributed sampler is disabled, `shuffle` is not
        # forwarded to the DataLoader, so iteration is sequential — confirm the
        # non-distributed path is only used where that is acceptable.
        self.sampler = (
            DistributedSampler(self.dataset, shuffle=self.shuffle)
            if enable_distributed_sampler
            else None
        )

    def get_loader(self, epoch) -> Iterable:
        """Build a ``DataLoader`` for `epoch`, after forwarding the epoch to the
        sampler/dataset so distributed shuffling stays deterministic."""
        if self.sampler:
            self.sampler.set_epoch(epoch)
        if hasattr(self.dataset, "epoch"):
            self.dataset.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
            sampler=self.sampler,
            collate_fn=self.collate_fn,
            worker_init_fn=self.worker_init_fn,
        )

View File

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

1319
sam3/train/loss/loss_fns.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Callable
import torch
from torch.nn import functional as F
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
def point_sample(input, point_coords, **kwargs):
    """
    Wrapper around :func:`torch.nn.functional.grid_sample` that accepts 3D
    `point_coords` tensors and expects coordinates in the unit square
    [0, 1] x [0, 1] (instead of grid_sample's native [-1, 1] range).

    Args:
        input (Tensor): (N, C, H, W) feature maps on an H x W grid.
        point_coords (Tensor): (N, P, 2) or (N, Hgrid, Wgrid, 2) normalized
            point coordinates in [0, 1] x [0, 1].

    Returns:
        Tensor: (N, C, P) or (N, C, Hgrid, Wgrid) features bilinearly
        interpolated from `input` at `point_coords`, with the same sampling
        semantics as :func:`torch.nn.functional.grid_sample`.
    """
    squeeze_back = point_coords.dim() == 3
    if squeeze_back:
        # grid_sample needs a 4D grid; insert a dummy width axis
        point_coords = point_coords.unsqueeze(2)
    # map [0, 1] coordinates to grid_sample's [-1, 1] convention
    grid = 2.0 * point_coords - 1.0
    sampled = F.grid_sample(input, grid, **kwargs)
    if squeeze_back:
        sampled = sampled.squeeze(3)
    return sampled
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
def get_uncertain_point_coords_with_randomness(
    logits: torch.Tensor,
    uncertainty_func: Callable,
    num_points: int,
    oversample_ratio: int,
    importance_sample_ratio: float,
) -> torch.Tensor:
    """
    Sample points in [0, 1] x [0, 1] coordinate space based on their
    uncertainty. Uncertainties are computed per point with `uncertainty_func`
    from the point's sampled logit prediction. See the PointRend paper.

    Args:
        logits (Tensor): (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
            class-specific or class-agnostic predictions.
        uncertainty_func: maps (N, C, P) / (N, 1, P) point logits to (N, 1, P)
            uncertainties.
        num_points (int): number of points P to sample.
        oversample_ratio (int): oversampling factor for the candidate pool.
        importance_sample_ratio (float): fraction of points taken via
            importance sampling (the rest are uniform).

    Returns:
        Tensor: (N, P, 2) sampled point coordinates.
    """
    assert oversample_ratio >= 1
    assert 0 <= importance_sample_ratio <= 1
    batch = logits.shape[0]
    n_oversampled = int(num_points * oversample_ratio)
    # 1) draw an oversampled pool of candidates uniformly in [0, 1]^2
    candidates = torch.rand(batch, n_oversampled, 2, device=logits.device)
    candidate_logits = point_sample(logits, candidates, align_corners=False)
    # It is crucial to calculate uncertainty on the *sampled* prediction values.
    # Computing uncertainties on the dense predictions first and then sampling
    # them is incorrect: with uncertainty_func(logits) = -abs(logits), a point
    # sampled between predictions of -1 and 1 has logit 0 and hence uncertainty
    # 0, whereas sampling pre-computed uncertainties would (wrongly) give -1.
    candidate_uncertainties = uncertainty_func(candidate_logits)
    n_importance = int(importance_sample_ratio * num_points)
    n_random = num_points - n_importance
    # 2) keep the most uncertain candidates
    top_idx = torch.topk(candidate_uncertainties[:, 0, :], k=n_importance, dim=1)[1]
    # turn per-batch indices into flat indices over the whole candidate pool
    offsets = n_oversampled * torch.arange(
        batch, dtype=torch.long, device=logits.device
    )
    top_idx = top_idx + offsets[:, None]
    point_coords = candidates.view(-1, 2)[top_idx.view(-1), :].view(
        batch, n_importance, 2
    )
    # 3) pad with fresh uniform samples for the remaining quota
    if n_random > 0:
        extra = torch.rand(batch, n_random, 2, device=logits.device)
        point_coords = torch.cat([point_coords, extra], dim=1)
    return point_coords
# Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py
def calculate_uncertainty(logits: torch.Tensor) -> torch.Tensor:
    """
    Uncertainty as the negative L1 distance between the logit and 0.0, so that
    logits closest to the decision boundary score highest.

    Args:
        logits (Tensor): (R, 1, ...) class-agnostic predicted masks.

    Returns:
        Tensor: (R, 1, ...) uncertainty scores; the most uncertain locations
        carry the highest (least negative) values.
    """
    assert logits.shape[1] == 1
    return logits.abs().neg()

View File

@@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import torch
from sam3.model.model_misc import SAM3Output
from sam3.train.utils.distributed import get_world_size
from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks
class DummyLoss(torch.nn.Module):
    """A no-op loss that always evaluates to zero (placeholder during eval)."""

    def __init__(
        self,
        core_loss_key: str = CORE_LOSS_KEY,
        device: str = "cuda",
        **kwargs,
    ):
        """
        Args:
            core_loss_key: dict key under which the zero total loss is reported.
            device: device on which the zero scalar is allocated.
        """
        super().__init__()
        self.core_loss_key = core_loss_key
        self.device = torch.device(device)

    def _zero_scalar(self):
        # fresh 0-dim zero tensor on the configured device
        return torch.tensor(0.0, device=self.device)

    def forward(self, *args, **kwargs):
        """Ignore every input and return a dict with a zero core loss."""
        return {self.core_loss_key: self._zero_scalar()}

    def accumulate(self, out_dict):
        """
        Called by iterative losses; guarantees the core-loss entry exists.
        """
        if self.core_loss_key not in out_dict:
            out_dict[self.core_loss_key] = self._zero_scalar()
        return out_dict
class Sam3LossWrapper(torch.nn.Module):
    """Applies the configured "find" losses to a `SAM3Output`, covering the main
    output, auxiliary outputs, the first-stage output and (when present) the
    one-to-many (o2m) prediction head, and sums everything into one core loss.
    """
    def __init__(
        self,
        loss_fns_find,
        normalization="global",
        matcher=None,
        o2m_matcher=None,
        o2m_weight=1.0,
        use_o2m_matcher_on_o2m_aux=True,
        loss_fn_semantic_seg=None,
        normalize_by_valid_object_num=False,
        normalize_by_stage_num=False,
        scale_by_find_batch_size=False,
    ):
        """
        Args:
            loss_fns_find: iterable of loss modules applied to each output.
            normalization: how the target-box count used for normalization is
                aggregated: "global" (averaged across all ranks), "local"
                (per-rank) or "none" (no normalization).
            matcher: one-to-one matcher; used on o2m queries of auxiliary
                outputs when `use_o2m_matcher_on_o2m_aux` is False.
            o2m_matcher: one-to-many matcher for the o2m prediction head.
            o2m_weight: scalar multiplier applied to every o2m loss term.
            use_o2m_matcher_on_o2m_aux: whether the o2m matcher is also used on
                the o2m queries in auxiliary outputs.
            loss_fn_semantic_seg: optional extra semantic-segmentation loss.
            normalize_by_valid_object_num: normalize by the number of valid
                (non-degenerate) boxes instead of the reported box count.
            normalize_by_stage_num: divide each stage's core loss by the number
                of stages (keeps image/video loss scales comparable).
            scale_by_find_batch_size: multiply the core loss by sqrt(batch size).
        """
        super().__init__()
        self.loss_fns_find = loss_fns_find
        assert normalization in ["global", "local", "none"]
        self.normalization = normalization
        self.normalize_by_valid_object_num = normalize_by_valid_object_num
        self.normalize_by_stage_num = normalize_by_stage_num
        self.matcher = matcher
        self.o2m_matcher = o2m_matcher
        self.o2m_weight = o2m_weight
        # whether to use the o2m matcher on the o2m queries in auxiliary outputs
        self.use_o2m_matcher_on_o2m_aux = use_o2m_matcher_on_o2m_aux
        self.loss_fn_semantic_seg = loss_fn_semantic_seg
        self.scale_by_find_batch_size = scale_by_find_batch_size
    def _get_num_boxes(self, targets):
        """Return the (possibly cross-rank averaged) target-box count used to
        normalize the losses, clamped to at least 1."""
        # the average number of target boxes for loss normalization
        if self.normalize_by_valid_object_num:
            # valid boxes are those with non-zero height and width
            # (while padded invisible boxes are zero-sized)
            boxes_hw = targets["boxes"].view(-1, 4)  # cx, cy, w, h
            num_boxes = (boxes_hw[:, 2:] > 0).all(dim=-1).sum().float()
        else:
            num_boxes = targets["num_boxes"].sum().float()
        if self.normalization == "global":
            # average the box count over all ranks so every rank normalizes alike
            torch.distributed.all_reduce(num_boxes)
            num_boxes = torch.clamp(num_boxes / get_world_size(), min=1)
        elif self.normalization == "local":
            num_boxes = torch.clamp(num_boxes, min=1)
        elif self.normalization == "none":
            num_boxes = 1
        return num_boxes
    def compute_loss(self, nested_out, targets):
        """Compute all configured find losses (o2o and, when available, o2m) for
        one step's output dict, including its auxiliary and first-stage outputs.

        Returns a dict of individual loss terms plus the summed CORE_LOSS_KEY.
        """
        num_boxes = self._get_num_boxes(targets)
        o2m_out_is_valid = nested_out.get("o2m_out_is_valid", None)
        o2m_target_is_valid_padded = nested_out.get("o2m_target_is_valid_padded", None)
        # Get a list of outputs, including auxiliary and first stage outputs
        output_list = [(nested_out, "", False)]  # (out, suffix, is_aux)
        if "aux_outputs" in nested_out:
            output_list.extend(
                (aux_out, f"_aux_{i}", True)
                for i, aux_out in enumerate(nested_out["aux_outputs"])
            )
        if "first_stage" in nested_out:
            output_list.append((nested_out["first_stage"], "_fs", True))
        # Compute all the requested losses
        losses = {}
        total_core_loss = 0.0
        for out, suffix, is_aux in output_list:
            # o2o matcher indices need to be computed by the model (as the video model requires
            # a specific way of matching free and locked indices beyond just calling the matcher)
            indices = out["indices"]
            has_o2m_out = "pred_logits_o2m" in out
            if has_o2m_out:
                # strip the "_o2m" suffix so the o2m dict mirrors the o2o keys
                o2m_out = {
                    k[: -len("_o2m")]: v for k, v in out.items() if k.endswith("_o2m")
                }
                # o2m targets are the same as the o2o targets (assuming repeat=1)
                o2m_targets = targets
                if self.use_o2m_matcher_on_o2m_aux or not is_aux:
                    o2m_indices = self.o2m_matcher(
                        o2m_out,
                        o2m_targets,
                        out_is_valid=o2m_out_is_valid,
                        target_is_valid_padded=o2m_target_is_valid_padded,
                    )
                else:
                    o2m_indices = self.matcher(
                        o2m_out,
                        o2m_targets,
                        out_is_valid=o2m_out_is_valid,
                        target_is_valid_padded=o2m_target_is_valid_padded,
                    )
            for loss_fn in self.loss_fns_find:
                l_dict = loss_fn(
                    outputs=out,
                    targets=targets,
                    indices=indices,
                    num_boxes=num_boxes,
                    is_aux=is_aux,
                )
                total_core_loss += l_dict.pop(CORE_LOSS_KEY)
                losses.update({f"{k}{suffix}": v for k, v in l_dict.items()})
                compute_o2m_loss = has_o2m_out
                # a special handling to allow turning off mask loss in o2m
                # (to be compatible with the original implementation)
                if isinstance(loss_fn, Masks):
                    compute_o2m_loss = compute_o2m_loss and "pred_masks" in o2m_out
                if isinstance(loss_fn, Det2TrkAssoc):
                    compute_o2m_loss = False  # Det2TrkAssoc does not support o2m
                if compute_o2m_loss:
                    l_dict = loss_fn(
                        outputs=o2m_out,
                        targets=o2m_targets,
                        indices=o2m_indices,
                        num_boxes=num_boxes,
                        is_aux=is_aux,
                    )
                    # o2m terms are down/up-weighted by a dedicated weight
                    for k in l_dict:
                        l_dict[k] *= self.o2m_weight
                    total_core_loss += l_dict.pop(CORE_LOSS_KEY)
                    losses.update({f"{k}{suffix}_o2m": v for k, v in l_dict.items()})
        losses[CORE_LOSS_KEY] = total_core_loss
        return losses
    def forward(self, find_stages: SAM3Output, find_targets):
        """Sum the per-stage / per-step losses over the whole `SAM3Output`.

        Args:
            find_stages: model output; iterated stage by stage, with all steps
                of a stage sharing that stage's targets.
            find_targets: per-stage target dicts (subset-selected when the
                output declares `loss_stages`).

        Returns:
            dict mapping loss names to accumulated values (incl. CORE_LOSS_KEY).
        """
        if find_stages.loss_stages is not None:
            # only a subset of stages contributes to the loss
            find_targets = [find_targets[i] for i in find_stages.loss_stages]
        with SAM3Output.iteration_mode(
            find_stages, iter_mode=SAM3Output.IterMode.ALL_STEPS_PER_STAGE
        ) as find_stages:
            assert len(find_stages) == len(find_targets)
            total_losses = {}
            for stage_outputs, stage_targets in zip(find_stages, find_targets):
                stage_targets = [stage_targets] * len(stage_outputs)
                # If there are multiple steps within a stage, compute the loss for all of them (e.g. interactivity)
                for outputs, targets in zip(stage_outputs, stage_targets):
                    cur_losses = self.compute_loss(outputs, targets)
                    if self.loss_fn_semantic_seg is not None:
                        cur_losses_semantic = self.loss_fn_semantic_seg(
                            outputs, targets
                        )
                        cur_losses[CORE_LOSS_KEY] += cur_losses_semantic.pop(
                            CORE_LOSS_KEY
                        )
                        # make sure the semantic losses don't overlap with the find losses
                        assert set(cur_losses).isdisjoint(set(cur_losses_semantic))
                        cur_losses.update(cur_losses_semantic)
                    # Optionally, normalize the loss by the number of find stages (training video frames) so that
                    # image batches and video batches have similar loss scales. (Otherwise video batches would
                    # have a much higher loss scale due to summing the losses over all the find stages.)
                    if self.normalize_by_stage_num:
                        cur_losses[CORE_LOSS_KEY] /= len(find_stages)
                    if self.scale_by_find_batch_size:
                        bs = targets["num_boxes"].shape[0]
                        # sqrt scaling based on the "effective" batch size
                        cur_losses[CORE_LOSS_KEY] *= bs**0.5
                    # accumulate every loss term across stages/steps
                    for k, v in cur_losses.items():
                        if k not in total_losses:
                            total_losses[k] = v
                        else:
                            total_losses[k] += v
        return total_losses

View File

@@ -0,0 +1,321 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""Triton kernel for faster and memory efficient sigmoid focal loss"""
import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import libdevice
"""
The sigmoid focal loss is defined as:
prob = inputs.sigmoid()
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * ce_loss * ((1 - p_t) ** gamma)
Where alpha and gamma are scalar parameters, inputs are the logits, targets the float targets.
We implement two versions of the sigmoid focal loss: with and without sum reduction.
The latter is implemented with a built-in reduction to avoid materializing the gradient with respect to the output of the loss.
This can help save a bit of peak memory.
The reduction version is implemented using somewhat of a hack. Pytorch's generated kernels usually do the point-wise operation in a first kernel, and implement the reduction in another kernel launched on a grid of size 1, where the reduction happens as a for loop in the triton kernel.
Since we want to fuse those two kernels, that is not a good idea: we'd have to launch the overall kernel on a grid of size 1, which is obviously inefficient.
On the other hand, typical CUDA algorithms for reduction (eg reduction tree) are hard to implement in triton due to the lack of thread sync primitives.
We settle for a version that abuses triton's atomic_add: we can have all threads simply add to the same location.
In practice, this is not good, since it creates a massive bottleneck on the semaphore for that single memory location. So instead, we create M reduction locations. Each thread will simply write to thread_id%M. The python code can finally sum over the M reductions.
M = 32 works fine in benchmarking tests. The forward is a tiny bit slower compared to the non-reduced kernel, but the backward breaks even due to one less memory allocation.
"""
@triton.jit
def _inner_focal_loss_fwd(inputs, targets, alpha, gamma):
    """Element-wise forward of the sigmoid focal loss for one block.

    Args:
        inputs: logits block.
        targets: float targets block (same shape as `inputs`).
        alpha: class-balancing factor.
        gamma: focusing exponent.

    Returns:
        alpha_t * (1 - p_t)**gamma * BCE(inputs, targets), element-wise.
    """
    inv_targets = 1 - targets
    # Sigmoid
    sig = tl.sigmoid(inputs)
    # Binary cross entropy with logits
    # In practice, we want the following:
    # bce_loss = -targets * tl.log(sig) - (1 - targets) * tl.log(1 - sig)
    # However, the above is not numerically stable.
    # We're also not directly taking the sum here, so the usual log-sum-exp trick doesn't apply
    # The bce can be reformulated, after algebraic manipulation, to
    # bce_loss = log(1 + exp(-x)) + x * (1-y)
    # This is still not stable, because for large (-x) the exponential will blow up.
    # We'll use the following alternate formulation:
    # bce_loss = max(x, 0) - x * y + log(1 + exp(-abs(x)))
    # Let's show that it's equivalent:
    # Case x>=0: abs(x) = x , max(x, 0) = x
    #    so we get x - x * y + log(1 + exp(-x)) which is equivalent
    # Case x<0: abs(x) = -x, max(x, 0) = 0
    #    we have log(1 + exp(-abs(x))) = log(1 + exp(x)) = log(exp(x)(1 + exp(-x))) = x+log(1 + exp(-x))
    #    plugging it in, we get
    #    0 - x * y + x + log(1 + exp(-x)), which is also equivalent
    # Note that this is stable because now the exponent are guaranteed to be below 0.
    max_val = tl.clamp(inputs, min=0, max=1e9)
    bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))
    # Modulating factor (1 - p_t)**gamma
    p_t = sig * targets + (1 - sig) * inv_targets
    mod_factor = libdevice.pow(1 - p_t, gamma)
    # Alpha factor: alpha for positives, (1 - alpha) for negatives
    alpha_t = alpha * targets + (1 - alpha) * inv_targets
    # Final loss calculation
    return alpha_t * mod_factor * bce_loss
# Non-reduced version
@triton.jit
def sigmoid_focal_loss_fwd_kernel(
    inputs_ptr,
    targets_ptr,
    loss_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Forward kernel without reduction: one loss value stored per element."""
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    # guard the tail block against out-of-range accesses
    mask = offset < n_elements
    # Load data (upcast logits to fp32 for numerical stability)
    inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
    targets = tl.load(targets_ptr + offset, mask=mask)
    final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma)
    # Store result
    tl.store(loss_ptr + offset, final_loss, mask=mask)
# version with reduction
@triton.jit
def sigmoid_focal_loss_fwd_kernel_reduce(
    inputs_ptr,
    targets_ptr,
    loss_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
    REDUCE_SIZE: tl.constexpr,
):
    """Forward kernel with built-in sum reduction.

    Each program sums its block's losses and atomically adds the partial sum
    into one of REDUCE_SIZE accumulator slots (slot = pid % REDUCE_SIZE); the
    host sums the slots afterwards. Spreading the atomics over several slots
    avoids serializing every program on a single memory location.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    reduce_loc = pid % REDUCE_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    # Load data
    inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
    targets = tl.load(targets_ptr + offset, mask=mask)
    # multiply by `mask` so out-of-range lanes contribute 0 to the block sum
    final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma) * mask
    fl = tl.sum(final_loss)
    # Store result
    tl.atomic_add(loss_ptr + reduce_loc, fl)
@triton.jit
def _inner_focal_loss_bwd(inputs, targets, alpha, gamma):
    """Element-wise derivative of the sigmoid focal loss w.r.t. the logits."""
    inv_targets = 1 - targets
    # Recompute forward
    max_val = tl.clamp(inputs, min=0, max=1e9)
    bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))
    # Sigmoid
    sig = tl.sigmoid(inputs)
    inv_sig = 1 - sig
    # Modulating factor
    p_t = sig * targets + inv_sig * inv_targets
    # (1 - p_t)**(gamma - 1) is shared by the factor and its derivative
    tmp = libdevice.pow(1 - p_t, gamma - 1)
    mod_factor = tmp * (1 - p_t)
    # Alpha factor
    alpha_t = alpha * targets + (1 - alpha) * inv_targets
    # Now computing the derivatives
    # d(p_t)/dx: sig * (1 - sig) is the sigmoid derivative, sign set by target
    d_pt = (2 * targets - 1) * sig * inv_sig
    d_mod_factor = -gamma * d_pt * tmp
    # d(BCE)/dx = sigmoid(x) - y
    d_bce_loss = sig - targets
    # product rule over mod_factor * bce_loss, scaled by alpha_t
    return alpha_t * (d_bce_loss * mod_factor + d_mod_factor * bce_loss)
@triton.jit
def sigmoid_focal_loss_bwd_kernel(
    inputs_ptr,
    targets_ptr,
    grad_inputs_ptr,
    grad_out_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Backward kernel for the non-reduced loss: `grad_out` is per-element."""
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    # guard the tail block against out-of-range accesses
    mask = offset < n_elements
    input_ptrs = inputs_ptr + offset
    target_ptrs = targets_ptr + offset
    grad_input_ptrs = grad_inputs_ptr + offset
    grad_out_ptrs = grad_out_ptr + offset
    # Load data (upcast logits to fp32 for numerical stability)
    inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
    targets = tl.load(target_ptrs, mask=mask)
    grad_out = tl.load(grad_out_ptrs, mask=mask)
    # chain rule: incoming gradient times the local derivative
    d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
    tl.store(grad_input_ptrs, d_loss, mask=mask)
@triton.jit
def sigmoid_focal_loss_bwd_kernel_reduce(
    inputs_ptr,
    targets_ptr,
    grad_inputs_ptr,
    grad_out_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Backward kernel for the sum-reduced loss."""
    # The only difference is that the gradient is now a single scalar
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    input_ptrs = inputs_ptr + offset
    target_ptrs = targets_ptr + offset
    grad_input_ptrs = grad_inputs_ptr + offset
    # Load data
    inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
    targets = tl.load(target_ptrs, mask=mask)
    # scalar upstream gradient broadcast to every element
    grad_out = tl.load(grad_out_ptr)
    d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
    tl.store(grad_input_ptrs, d_loss, mask=mask)
class SigmoidFocalLoss(torch.autograd.Function):
    """Autograd wrapper around the element-wise (non-reduced) Triton focal-loss
    kernels. Returns a per-element fp32 loss tensor shaped like `inputs`."""
    # number of elements handled by each Triton program instance
    BLOCK_SIZE = 256
    @staticmethod
    def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
        """Compute the per-element focal loss of logits `inputs` vs `targets`."""
        n_elements = inputs.numel()
        assert targets.numel() == n_elements
        input_shape = inputs.shape
        inputs = inputs.view(-1).contiguous()
        targets = targets.view(-1).contiguous()
        # losses are stored in fp32 regardless of the input dtype
        loss = torch.empty(inputs.shape, dtype=torch.float32, device=inputs.device)
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_fwd_kernel[grid](
            inputs, targets, loss, alpha, gamma, n_elements, SigmoidFocalLoss.BLOCK_SIZE
        )
        # keep flattened-then-reshaped tensors + scalars for the backward pass
        ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
        ctx.alpha = alpha
        ctx.gamma = gamma
        return loss.view(input_shape)
    @staticmethod
    def backward(ctx, grad_output):
        """Chain the per-element upstream gradient through the focal-loss
        derivative; only `inputs` receives a gradient."""
        inputs, targets = ctx.saved_tensors
        alpha = ctx.alpha
        gamma = ctx.gamma
        n_elements = inputs.numel()
        input_shape = inputs.shape
        grad_inputs = torch.empty(
            inputs.shape, dtype=grad_output.dtype, device=grad_output.device
        )
        inputs_ptr = inputs.view(-1).contiguous()
        targets_ptr = targets.view(-1).contiguous()
        grad_output_ptr = grad_output.view(-1).contiguous()
        grad_inputs_ptr = grad_inputs
        assert grad_output.numel() == n_elements
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_bwd_kernel[grid](
            inputs_ptr,
            targets_ptr,
            grad_inputs_ptr,
            grad_output_ptr,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLoss.BLOCK_SIZE,
        )
        # no gradients for targets / alpha / gamma
        return grad_inputs.view(input_shape), None, None, None
triton_sigmoid_focal_loss = SigmoidFocalLoss.apply
class SigmoidFocalLossReduced(torch.autograd.Function):
    """Autograd wrapper around the sum-reduced Triton focal-loss kernels.
    Returns a scalar (the sum of all per-element losses) without materializing
    the full loss tensor."""
    # number of elements handled by each Triton program instance
    BLOCK_SIZE = 256
    # number of partial-sum slots the kernel atomically adds into
    REDUCE_SIZE = 32
    @staticmethod
    def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
        """Compute the summed focal loss of logits `inputs` vs `targets`."""
        # NOTE(review): unlike SigmoidFocalLoss.forward, targets.numel() is not
        # checked against inputs here — confirm callers guarantee equal sizes.
        n_elements = inputs.numel()
        input_shape = inputs.shape
        inputs = inputs.view(-1).contiguous()
        targets = targets.view(-1).contiguous()
        # REDUCE_SIZE partial sums; must start at zero for atomic accumulation
        loss = torch.zeros(
            SigmoidFocalLossReduced.REDUCE_SIZE,
            device=inputs.device,
            dtype=torch.float32,
        )
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_fwd_kernel_reduce[grid](
            inputs,
            targets,
            loss,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLossReduced.BLOCK_SIZE,
            SigmoidFocalLossReduced.REDUCE_SIZE,
        )
        ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
        ctx.alpha = alpha
        ctx.gamma = gamma
        # final host-side reduction over the partial-sum slots
        return loss.sum()
    @staticmethod
    def backward(ctx, grad_output):
        """Broadcast the scalar upstream gradient through the per-element
        focal-loss derivative; only `inputs` receives a gradient."""
        inputs, targets = ctx.saved_tensors
        alpha = ctx.alpha
        gamma = ctx.gamma
        n_elements = inputs.numel()
        input_shape = inputs.shape
        grad_inputs = torch.empty(
            inputs.shape, dtype=grad_output.dtype, device=grad_output.device
        )
        inputs_ptr = inputs.view(-1).contiguous()
        targets_ptr = targets.reshape(-1).contiguous()
        # the upstream gradient of a sum-reduced loss is a single scalar
        assert grad_output.numel() == 1
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_bwd_kernel_reduce[grid](
            inputs_ptr,
            targets_ptr,
            grad_inputs,
            grad_output,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLossReduced.BLOCK_SIZE,
        )
        # no gradients for targets / alpha / gamma
        return grad_inputs.view(input_shape), None, None, None
triton_sigmoid_focal_loss_reduce = SigmoidFocalLossReduced.apply

272
sam3/train/masks_ops.py Normal file
View File

@@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""Utilities for masks manipulation"""
import numpy as np
import pycocotools.mask as maskUtils
import torch
from pycocotools import mask as mask_util
def instance_masks_to_semantic_masks(
    instance_masks: torch.Tensor, num_instances: torch.Tensor
) -> torch.Tensor:
    """Collapse per-instance masks into one semantic (union) mask per image.

    `instance_masks` holds the instance masks of a whole batch concatenated
    along dim 0, and `num_instances[i]` says how many of them belong to image
    i. Images with zero instances produce an all-zeros semantic mask.

    Args:
        instance_masks (torch.Tensor): (N, H, W) tensor, N = total number of
            instances across the batch.
        num_instances (torch.Tensor): (B,) tensor with each image's instance
            count.

    Returns:
        torch.Tensor: (B, H, W) tensor where each slice is the union (logical
        OR) of that image's instance masks.
    """
    per_image_masks = torch.split(instance_masks, num_instances.tolist())
    union_masks = [image_masks.any(dim=0) for image_masks in per_image_masks]
    return torch.stack(union_masks, dim=0)
def mask_intersection(masks1, masks2, block_size=16):
    """Pairwise intersection areas between two sets of boolean masks.

    Processed block by block so the (N1, N2, H, W) broadcast intermediate stays
    small instead of blowing up memory.

    Args:
        masks1: (N1, H, W) bool tensor.
        masks2: (N2, H, W) bool tensor with the same spatial size as `masks1`.
        block_size: number of masks paired per block.

    Returns:
        (N1, N2) long tensor of intersection pixel counts.
    """
    assert masks1.shape[1:] == masks2.shape[1:]
    assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
    n1, n2 = masks1.shape[0], masks2.shape[0]
    out = torch.zeros(n1, n2, device=masks1.device, dtype=torch.long)
    for row in range(0, n1, block_size):
        row_masks = masks1[row : row + block_size]
        for col in range(0, n2, block_size):
            col_masks = masks2[col : col + block_size]
            # broadcasted AND (via *), then count the overlapping pixels
            counts = (row_masks[:, None] * col_masks[None]).flatten(-2).sum(-1)
            out[row : row + block_size, col : col + block_size] = counts
    return out
def mask_iom(masks1, masks2):
    """
    Intersection-over-minimum: like IoU, but the denominator is the area of the
    smaller of the two masks.
    """
    assert masks1.shape[1:] == masks2.shape[1:]
    assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
    # intersection = (masks1[:, None] * masks2[None]).flatten(-2).sum(-1)
    inter = mask_intersection(masks1, masks2)
    areas1 = masks1.flatten(-2).sum(-1)
    areas2 = masks2.flatten(-2).sum(-1)
    smaller_area = torch.min(areas1[:, None], areas2[None, :])
    # epsilon guards against division by zero for empty masks
    return inter / (smaller_area + 1e-8)
def compute_boundary(seg):
"""
Adapted from https://github.com/JonathonLuiten/TrackEval/blob/master/trackeval/metrics/j_and_f.py#L148
Return a 1pix wide boundary of the given mask
"""
assert seg.ndim >= 2
e = torch.zeros_like(seg)
s = torch.zeros_like(seg)
se = torch.zeros_like(seg)
e[..., :, :-1] = seg[..., :, 1:]
s[..., :-1, :] = seg[..., 1:, :]
se[..., :-1, :-1] = seg[..., 1:, 1:]
b = seg ^ e | seg ^ s | seg ^ se
b[..., -1, :] = seg[..., -1, :] ^ e[..., -1, :]
b[..., :, -1] = seg[..., :, -1] ^ s[..., :, -1]
b[..., -1, -1] = 0
return b
def dilation(mask, kernel_size):
    """
    Morphological dilation of a batch of binary masks with a square all-ones
    structuring element of side `kernel_size` (must be odd).

    On CUDA tensors the dilation is done with two separable "same" convolutions
    (k x 1 then 1 x k), which for an all-ones kernel matches a full k x k
    dilation; on CPU we fall back to cv2.dilate.

    Args:
        mask: (N, H, W) binary mask tensor.
        kernel_size: side length of the square kernel; must be odd.

    Returns:
        Tensor shaped like `mask`: boolean on the CUDA path, cast back to
        `mask`'s dtype/device on the CPU path.
    """
    assert mask.ndim == 3
    kernel_size = int(kernel_size)
    assert (
        kernel_size % 2 == 1
    ), f"Dilation expects a odd kernel size, got {kernel_size}"
    if mask.is_cuda:
        # fp16 keeps the convolutions cheap; any positive response means covered
        m = mask.unsqueeze(1).to(torch.float16)
        k = torch.ones(1, 1, kernel_size, 1, dtype=m.dtype, device=m.device)
        result = torch.nn.functional.conv2d(m, k, padding="same")
        result = torch.nn.functional.conv2d(result, k.transpose(-1, -2), padding="same")
        return result.view_as(mask) > 0
    # CPU fallback: run OpenCV's dilate on each mask individually
    all_masks = mask.view(-1, mask.size(-2), mask.size(-1)).numpy().astype(np.uint8)
    kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
    import cv2
    processed = [torch.from_numpy(cv2.dilate(m, kernel)) for m in all_masks]
    return torch.stack(processed).view_as(mask).to(mask)
def compute_F_measure(
    gt_boundary_rle, gt_dilated_boundary_rle, dt_boundary_rle, dt_dilated_boundary_rle
):
    """Adapted from https://github.com/JonathonLuiten/TrackEval/blob/master/trackeval/metrics/j_and_f.py#L207

    Boundary F-measure between a ground-truth and a predicted mask. The
    (dilated) boundaries must already be computed and RLE-encoded; the dilation
    provides the matching tolerance.
    """
    # boundary pixels of one mask that land inside the other's dilated boundary
    gt_match = maskUtils.merge([gt_boundary_rle, dt_dilated_boundary_rle], True)
    dt_match = maskUtils.merge([dt_boundary_rle, gt_dilated_boundary_rle], True)
    n_dt = maskUtils.area(dt_boundary_rle)
    n_gt = maskUtils.area(gt_boundary_rle)
    # % Compute precision and recall, pinning the degenerate empty-boundary cases
    if n_dt == 0 and n_gt > 0:
        precision, recall = 1, 0
    elif n_dt > 0 and n_gt == 0:
        precision, recall = 0, 1
    elif n_dt == 0 and n_gt == 0:
        precision, recall = 1, 1
    else:
        precision = maskUtils.area(dt_match) / float(n_dt)
        recall = maskUtils.area(gt_match) / float(n_gt)
    # Compute F measure as the harmonic mean, guarding the 0/0 case
    if precision + recall == 0:
        return 0
    return 2 * precision * recall / (precision + recall)
@torch.no_grad()
def rle_encode(orig_mask, return_areas=False):
    """Encodes a collection of masks in RLE format.
    This function emulates the behavior of the COCO API's encode function, but
    is executed partially on the GPU for faster execution.
    Args:
        orig_mask (torch.Tensor): A mask of shape (N, H, W) with dtype=torch.bool
        return_areas (bool): If True, add the areas of the masks as a part of
            the RLE output dict under the "area" key. Default is False.
    Returns:
        list[dict]: one compressed COCO RLE dict per input mask, with the
            "counts" bytes decoded to utf-8 strings
    """
    assert orig_mask.ndim == 3, "Mask must be of shape (N, H, W)"
    assert orig_mask.dtype == torch.bool, "Mask must have dtype=torch.bool"
    if orig_mask.numel() == 0:
        return []
    # First, transpose the spatial dimensions.
    # This is necessary because the COCO API uses Fortran order
    mask = orig_mask.transpose(1, 2)
    # Flatten the mask
    flat_mask = mask.reshape(mask.shape[0], -1)
    if return_areas:
        mask_areas = flat_mask.sum(-1).tolist()
    # Find the indices where the mask changes.
    # differences[n, i+1] is True iff flat_mask[n] flips value between i and i+1.
    # The last column stays True (from the ones-init) to close the final run, and
    # column 0 is True iff the mask starts with a set pixel, so that the implicit
    # leading background run then has length 0 (COCO RLE starts with a 0-run).
    differences = torch.ones(
        mask.shape[0], flat_mask.shape[1] + 1, device=mask.device, dtype=torch.bool
    )
    differences[:, 1:-1] = flat_mask[:, :-1] != flat_mask[:, 1:]
    differences[:, 0] = flat_mask[:, 0]
    _, change_indices = torch.where(differences)
    try:
        # boundaries[n] = total number of change points in masks 0..n (prefix sum);
        # used below to slice the flattened `change_indices` per mask.
        boundaries = torch.cumsum(differences.sum(-1), 0).cpu()
    except RuntimeError as _:
        # Fallback: perform the reduction on CPU (e.g. if the GPU op failed)
        boundaries = torch.cumsum(differences.cpu().sum(-1), 0)
    change_indices_clone = change_indices.clone()
    # First pass computes the RLEs on GPU, in a flatten format
    # (absolute change positions are turned into run lengths by differencing)
    for i in range(mask.shape[0]):
        # Get the change indices for this batch item
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        change_indices[beg + 1 : end] -= change_indices_clone[beg : end - 1]
    # Now we can split the RLES of each batch item, and convert them to strings
    # No more gpu at this point
    change_indices = change_indices.tolist()
    batch_rles = []
    # Process each mask in the batch separately
    for i in range(mask.shape[0]):
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        run_lengths = change_indices[beg:end]
        uncompressed_rle = {"counts": run_lengths, "size": list(orig_mask.shape[1:])}
        h, w = uncompressed_rle["size"]
        # Compress with the COCO API and make "counts" JSON-serializable
        rle = mask_util.frPyObjects(uncompressed_rle, h, w)
        rle["counts"] = rle["counts"].decode("utf-8")
        if return_areas:
            rle["area"] = mask_areas[i]
        batch_rles.append(rle)
    return batch_rles
def robust_rle_encode(masks):
    """Encodes a collection of masks in RLE format.

    Tries the (faster, partially GPU-based) ``rle_encode`` first, and falls
    back to the pycocotools CPU encoder if that raises a RuntimeError.

    Args:
        masks (torch.Tensor): boolean masks of shape (N, H, W)

    Returns:
        list[dict]: one compressed COCO RLE dict per mask, with the "counts"
            bytes decoded to utf-8 strings
    """
    assert masks.ndim == 3, "Mask must be of shape (N, H, W)"
    assert masks.dtype == torch.bool, "Mask must have dtype=torch.bool"
    try:
        return rle_encode(masks)
    except RuntimeError:
        # Fallback: encode each mask on CPU with pycocotools.
        # The (H, W, 1) Fortran-ordered uint8 layout is what mask_util expects.
        masks = masks.cpu().numpy()
        rles = [
            mask_util.encode(
                np.array(mask[:, :, np.newaxis], dtype=np.uint8, order="F")
            )[0]
            for mask in masks
        ]
        for rle in rles:
            # Make "counts" JSON-serializable, matching rle_encode's output
            rle["counts"] = rle["counts"].decode("utf-8")
        return rles
def ann_to_rle(segm, im_info):
    """Convert a segmentation (polygons or uncompressed RLE) to compressed RLE.
    Args:
        segm (list | dict): either a list of polygons, an uncompressed RLE dict
            (with a list "counts"), or an already-compressed RLE dict
        im_info (dict): image info; must contain "height" and "width"
    Returns:
        dict: compressed COCO RLE
    """
    h, w = im_info["height"], im_info["width"]
    if isinstance(segm, list):
        # polygon -- a single object might consist of multiple parts
        # we merge all parts into one mask rle code
        rles = mask_util.frPyObjects(segm, h, w)
        rle = mask_util.merge(rles)
    elif isinstance(segm["counts"], list):
        # uncompressed RLE
        rle = mask_util.frPyObjects(segm, h, w)
    else:
        # rle (already compressed: returned unchanged)
        rle = segm
    return rle

806
sam3/train/matcher.py Normal file
View File

@@ -0,0 +1,806 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
Modules to compute the matching cost and solve the corresponding LSAP.
"""
import numpy as np
import torch
from sam3.model.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
from scipy.optimize import linear_sum_assignment
from torch import nn
def _do_matching(cost, repeats=1, return_tgt_indices=False, do_filtering=False):
if repeats > 1:
cost = np.tile(cost, (1, repeats))
i, j = linear_sum_assignment(cost)
if do_filtering:
# filter out invalid entries (i.e. those with cost > 1e8)
valid_thresh = 1e8
valid_ijs = [(ii, jj) for ii, jj in zip(i, j) if cost[ii, jj] < valid_thresh]
i, j = zip(*valid_ijs) if len(valid_ijs) > 0 else ([], [])
i, j = np.array(i, dtype=np.int64), np.array(j, dtype=np.int64)
if return_tgt_indices:
return i, j
order = np.argsort(j)
return i[order]
class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network
    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """
    def __init__(
        self,
        cost_class: float = 1,
        cost_bbox: float = 1,
        cost_giou: float = 1,
        focal_loss: bool = False,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2,
    ):
        """Creates the matcher
        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
            focal_loss: if True, class probabilities use a sigmoid and a focal-style
                classification cost; otherwise a softmax over classes is used
            focal_alpha: alpha parameter of the focal cost (only used if focal_loss)
            focal_gamma: gamma parameter of the focal cost (only used if focal_loss)
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        # Sigmoid for the focal (multi-label) cost, softmax for plain cross-entropy
        self.norm = nn.Sigmoid() if focal_loss else nn.Softmax(-1)
        assert (
            cost_class != 0 or cost_bbox != 0 or cost_giou != 0
        ), "all costs cant be 0"
        self.focal_loss = focal_loss
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
    @torch.no_grad()
    def forward(self, outputs, batched_targets):
        """Performs the matching
        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
            batched_targets: dict of targets concatenated over the batch, containing:
                 "boxes": Tensor of dim [total_num_boxes, 4] with the target box coordinates
                 "num_boxes": Tensor of dim [batch_size] with the number of target boxes per image
                 and either "positive_map" (multi-hot class targets, one row per box)
                 or "labels" (class indices, one per box)
        Returns:
            (batch_idx, src_idx): two int64 tensors of equal length giving, for every
            target box (in the order of batched_targets["boxes"]), the batch element
            and the query index it was assigned to.
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]
        # We flatten to compute the cost matrices in a batch
        out_prob = self.norm(
            outputs["pred_logits"].flatten(0, 1)
        )  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
        # Also concat the target labels and boxes
        tgt_bbox = batched_targets["boxes"]
        if "positive_map" in batched_targets:
            # In this case we have a multi-hot target
            positive_map = batched_targets["positive_map"]
            assert len(tgt_bbox) == len(positive_map)
            if self.focal_loss:
                positive_map = positive_map > 1e-4
                alpha = self.focal_alpha
                gamma = self.focal_gamma
                # Focal matching cost: (pos_cost - neg_cost), summed over the
                # positive classes of each target box
                neg_cost_class = (
                    (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
                )
                pos_cost_class = (
                    alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
                )
                cost_class = (
                    (pos_cost_class - neg_cost_class).unsqueeze(1)
                    * positive_map.unsqueeze(0)
                ).sum(-1)
            else:
                # Compute the soft-cross entropy between the predicted token alignment and the GT one for each box
                cost_class = -(out_prob.unsqueeze(1) * positive_map.unsqueeze(0)).sum(
                    -1
                )
        else:
            # In this case we are doing a "standard" cross entropy
            tgt_ids = batched_targets["labels"]
            assert len(tgt_bbox) == len(tgt_ids)
            if self.focal_loss:
                alpha = self.focal_alpha
                gamma = self.focal_gamma
                neg_cost_class = (
                    (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
                )
                pos_cost_class = (
                    alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
                )
                cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
            else:
                # Compute the classification cost. Contrary to the loss, we don't use the NLL,
                # but approximate it in 1 - proba[target class].
                # The 1 is a constant that doesn't change the matching, it can be omitted.
                cost_class = -out_prob[:, tgt_ids]
        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
        assert cost_class.shape == cost_bbox.shape
        # Compute the giou cost betwen boxes
        cost_giou = -generalized_box_iou(
            box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
        )
        # Final cost matrix
        C = (
            self.cost_bbox * cost_bbox
            + self.cost_class * cost_class
            + self.cost_giou * cost_giou
        )
        C = C.view(bs, num_queries, -1).cpu().numpy()
        # Split the cost columns per batch element and solve one LSAP per image
        sizes = torch.cumsum(batched_targets["num_boxes"], -1)[:-1]
        costs = [c[i] for i, c in enumerate(np.split(C, sizes.cpu().numpy(), axis=-1))]
        indices = [_do_matching(c) for c in costs]
        # Flatten the per-image matches into (batch_idx, src_idx), ordered by target
        batch_idx = torch.as_tensor(
            sum([[i] * len(src) for i, src in enumerate(indices)], []), dtype=torch.long
        )
        src_idx = torch.from_numpy(np.concatenate(indices)).long()
        return batch_idx, src_idx
class BinaryHungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network
    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """
    def __init__(
        self,
        cost_class: float = 1,
        cost_bbox: float = 1,
        cost_giou: float = 1,
    ):
        """Creates the matcher
        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        # Binary (objectness) matching: a single logit normalized with a sigmoid
        self.norm = nn.Sigmoid()
        assert (
            cost_class != 0 or cost_bbox != 0 or cost_giou != 0
        ), "all costs cant be 0"
    @torch.no_grad()
    def forward(self, outputs, batched_targets, repeats=0, repeat_batch=1):
        """Performs the matching
        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, 1] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
            batched_targets: dict with "boxes" (concatenated target boxes) and
                 "num_boxes" (per-image target counts)
            repeats: if > 1, each target may be matched to up to `repeats` queries
                 (the cost columns are tiled in `_do_matching`).
                 NOTE(review): the default repeats=0 behaves identically to 1 (only
                 `repeats > 1` triggers tiling); sibling matchers default to 1.
        Returns:
            (batch_idx, src_idx, tgt_idx) int64 tensors; tgt_idx is None when every
            target could be matched (then matches are ordered by target index).
        """
        if repeat_batch != 1:
            raise NotImplementedError("please use BinaryHungarianMatcherV2 instead")
        bs, num_queries = outputs["pred_logits"].shape[:2]
        # We flatten to compute the cost matrices in a batch
        out_prob = self.norm(outputs["pred_logits"].flatten(0, 1)).squeeze(
            -1
        )  # [batch_size * num_queries]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
        # Also concat the target labels and boxes
        tgt_bbox = batched_targets["boxes"]
        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
        # Objectness cost is target-independent: broadcast it over the targets
        cost_class = -out_prob.unsqueeze(-1).expand_as(cost_bbox)
        assert cost_class.shape == cost_bbox.shape
        # Compute the giou cost betwen boxes
        cost_giou = -generalized_box_iou(
            box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
        )
        # Final cost matrix
        C = (
            self.cost_bbox * cost_bbox
            + self.cost_class * cost_class
            + self.cost_giou * cost_giou
        )
        C = C.view(bs, num_queries, -1).cpu().numpy()
        sizes = torch.cumsum(batched_targets["num_boxes"], -1)[:-1]
        costs = [c[i] for i, c in enumerate(np.split(C, sizes.cpu().numpy(), axis=-1))]
        # If some image has fewer queries than (repeated) targets, not every target
        # can be matched, so we must also report which targets were actually used
        return_tgt_indices = False
        for c in costs:
            n_targ = c.shape[1]
            if repeats > 1:
                n_targ *= repeats
            if c.shape[0] < n_targ:
                return_tgt_indices = True
                break
        if return_tgt_indices:
            indices, tgt_indices = zip(
                *(
                    _do_matching(
                        c, repeats=repeats, return_tgt_indices=return_tgt_indices
                    )
                    for c in costs
                )
            )
            tgt_indices = list(tgt_indices)
            # Offset per-image target indices so they index the concatenated targets
            for i in range(1, len(tgt_indices)):
                tgt_indices[i] += sizes[i - 1].item()
            tgt_idx = torch.from_numpy(np.concatenate(tgt_indices)).long()
        else:
            indices = [_do_matching(c, repeats=repeats) for c in costs]
            tgt_idx = None
        batch_idx = torch.as_tensor(
            sum([[i] * len(src) for i, src in enumerate(indices)], []), dtype=torch.long
        )
        src_idx = torch.from_numpy(np.concatenate(indices)).long()
        return batch_idx, src_idx, tgt_idx
class BinaryFocalHungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network
    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """
    def __init__(
        self,
        cost_class: float = 1,
        cost_bbox: float = 1,
        cost_giou: float = 1,
        alpha: float = 0.25,
        gamma: float = 2.0,
        stable: bool = False,
    ):
        """Creates the matcher
        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
            alpha / gamma: parameters of the focal classification cost
            stable: if True, the probability is modulated by the rescaled GIoU
                before the focal cost, making the class cost target-dependent
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.norm = nn.Sigmoid()
        self.alpha = alpha
        self.gamma = gamma
        self.stable = stable
        assert (
            cost_class != 0 or cost_bbox != 0 or cost_giou != 0
        ), "all costs cant be 0"
    @torch.no_grad()
    def forward(self, outputs, batched_targets, repeats=1, repeat_batch=1):
        """Performs the matching
        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, 1] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
            batched_targets: dict with "boxes" (concatenated target boxes) and
                 "num_boxes" (per-image target counts)
            repeats: if > 1, each target may be matched to up to `repeats` queries
        Returns:
            (batch_idx, src_idx, tgt_idx) int64 tensors; tgt_idx is None when every
            target could be matched (then matches are ordered by target index).
        """
        if repeat_batch != 1:
            raise NotImplementedError("please use BinaryHungarianMatcherV2 instead")
        bs, num_queries = outputs["pred_logits"].shape[:2]
        # We flatten to compute the cost matrices in a batch
        out_score = outputs["pred_logits"].flatten(0, 1).squeeze(-1)
        out_prob = self.norm(out_score)  # [batch_size * num_queries]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
        # Also concat the target labels and boxes
        tgt_bbox = batched_targets["boxes"]
        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
        # Compute the giou cost betwen boxes
        cost_giou = -generalized_box_iou(
            box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
        )
        # cost_class = -out_prob.unsqueeze(-1).expand_as(cost_bbox)
        if self.stable:
            # "Stable" variant: modulate the probability by the rescaled GIoU
            # (mapped to [0, 1]) so the focal cost reflects localization quality
            rescaled_giou = (-cost_giou + 1) / 2
            out_prob = out_prob.unsqueeze(-1).expand_as(cost_bbox) * rescaled_giou
            cost_class = -self.alpha * (1 - out_prob) ** self.gamma * torch.log(
                out_prob
            ) + (1 - self.alpha) * out_prob**self.gamma * torch.log(1 - out_prob)
        else:
            # directly computing log sigmoid (more numerically stable)
            log_out_prob = torch.nn.functional.logsigmoid(out_score)
            log_one_minus_out_prob = torch.nn.functional.logsigmoid(-out_score)
            cost_class = (
                -self.alpha * (1 - out_prob) ** self.gamma * log_out_prob
                + (1 - self.alpha) * out_prob**self.gamma * log_one_minus_out_prob
            )
        if not self.stable:
            # Non-stable cost is target-independent: broadcast over the targets
            cost_class = cost_class.unsqueeze(-1).expand_as(cost_bbox)
        assert cost_class.shape == cost_bbox.shape
        # Final cost matrix
        C = (
            self.cost_bbox * cost_bbox
            + self.cost_class * cost_class
            + self.cost_giou * cost_giou
        )
        C = C.view(bs, num_queries, -1).cpu().numpy()
        sizes = torch.cumsum(batched_targets["num_boxes"], -1)[:-1]
        costs = [c[i] for i, c in enumerate(np.split(C, sizes.cpu().numpy(), axis=-1))]
        # If some image has fewer queries than (repeated) targets, not every target
        # can be matched, so we must also report which targets were actually used
        return_tgt_indices = False
        for c in costs:
            n_targ = c.shape[1]
            if repeats > 1:
                n_targ *= repeats
            if c.shape[0] < n_targ:
                return_tgt_indices = True
                break
        if return_tgt_indices:
            indices, tgt_indices = zip(
                *(
                    _do_matching(
                        c, repeats=repeats, return_tgt_indices=return_tgt_indices
                    )
                    for c in costs
                )
            )
            tgt_indices = list(tgt_indices)
            # Offset per-image target indices so they index the concatenated targets
            for i in range(1, len(tgt_indices)):
                tgt_indices[i] += sizes[i - 1].item()
            tgt_idx = torch.from_numpy(np.concatenate(tgt_indices)).long()
        else:
            indices = [_do_matching(c, repeats=repeats) for c in costs]
            tgt_idx = None
        batch_idx = torch.as_tensor(
            sum([[i] * len(src) for i, src in enumerate(indices)], []), dtype=torch.long
        )
        src_idx = torch.from_numpy(np.concatenate(indices)).long()
        return batch_idx, src_idx, tgt_idx
class BinaryHungarianMatcherV2(nn.Module):
    """
    This class computes an assignment between the targets and the predictions
    of the network
    For efficiency reasons, the targets don't include the no_object. Because of
    this, in general, there are more predictions than targets. In this case, we
    do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).
    This is a more efficient implementation of BinaryHungarianMatcher.
    """
    def __init__(
        self,
        cost_class: float = 1,
        cost_bbox: float = 1,
        cost_giou: float = 1,
        focal: bool = False,
        alpha: float = 0.25,
        gamma: float = 2.0,
        stable: bool = False,
        remove_samples_with_0_gt: bool = True,
    ):
        """
        Creates the matcher
        Params:
        - cost_class: Relative weight of the classification error in the
          matching cost
        - cost_bbox: Relative weight of the L1 error of the bounding box
          coordinates in the matching cost
        - cost_giou: This is the relative weight of the giou loss of the
          bounding box in the matching cost
        - focal: if True, use a focal-style classification cost
        - alpha / gamma: focal cost parameters (only stored when focal=True)
        - stable: with focal, modulate probabilities by the rescaled GIoU
        - remove_samples_with_0_gt: drop batch entries that have no GT boxes
          before matching (their queries can never be matched anyway)
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.norm = nn.Sigmoid()
        assert (
            cost_class != 0 or cost_bbox != 0 or cost_giou != 0
        ), "all costs cant be 0"
        self.focal = focal
        # NOTE: alpha/gamma/stable are only defined when focal=True
        if focal:
            self.alpha = alpha
            self.gamma = gamma
            self.stable = stable
        self.remove_samples_with_0_gt = remove_samples_with_0_gt
    @torch.no_grad()
    def forward(
        self,
        outputs,
        batched_targets,
        repeats=1,
        repeat_batch=1,
        out_is_valid=None,
        target_is_valid_padded=None,
    ):
        """
        Performs the matching. The inputs and outputs are the same as
        BinaryHungarianMatcher.forward, except for the optional cached_padded
        flag and the optional "_boxes_padded" entry of batched_targets.
        Inputs:
        - outputs: A dict with the following keys:
          - "pred_logits": Tensor of shape (batch_size, num_queries, 1) with
            classification logits
          - "pred_boxes": Tensor of shape (batch_size, num_queries, 4) with
            predicted box coordinates in cxcywh format.
        - batched_targets: A dict of targets. There may be a variable number of
          targets per batch entry; suppose that there are T_b targets for batch
          entry 0 <= b < batch_size. It should have the following keys:
          - "boxes": Tensor of shape (sum_b T_b, 4) giving ground-truth boxes
            in cxcywh format for all batch entries packed into a single tensor
          - "num_boxes": int64 Tensor of shape (batch_size,) giving the number
            of ground-truth boxes per batch entry: num_boxes[b] = T_b
          - "_boxes_padded": Tensor of shape (batch_size, max_b T_b, 4) giving
            a padded version of ground-truth boxes. If this is not present then
            it will be computed from batched_targets["boxes"] instead, but
            caching it here can improve performance for repeated calls with the
            same targets.
        - out_is_valid: If not None, it should be a boolean tensor of shape
          (batch_size, num_queries) indicating which predictions are valid.
          Invalid predictions are ignored during matching and won't appear in
          the output indices.
        - target_is_valid_padded: If not None, it should be a boolean tensor of
          shape (batch_size, max_num_gt_boxes) in padded format indicating
          which GT boxes are valid. Invalid GT boxes are ignored during matching
          and won't appear in the output indices.
        Returns:
            A list of size batch_size, containing tuples of (idx_i, idx_j):
            - idx_i is the indices of the selected predictions (in order)
            - idx_j is the indices of the corresponding selected targets
              (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j)
                             = min(num_queries, num_target_boxes)
        """
        _, num_queries = outputs["pred_logits"].shape[:2]
        out_score = outputs["pred_logits"].squeeze(-1)  # (B, Q)
        out_bbox = outputs["pred_boxes"]  # (B, Q, 4))
        device = out_score.device
        num_boxes = batched_targets["num_boxes"].cpu()
        # Get a padded version of target boxes (as precomputed in the collator).
        # It should work for both repeat==1 (o2o) and repeat>1 (o2m) matching.
        tgt_bbox = batched_targets["boxes_padded"]
        if self.remove_samples_with_0_gt:
            # keep only samples w/ at least 1 GT box in targets (num_boxes and tgt_bbox)
            batch_keep = num_boxes > 0
            num_boxes = num_boxes[batch_keep]
            tgt_bbox = tgt_bbox[batch_keep]
            if target_is_valid_padded is not None:
                target_is_valid_padded = target_is_valid_padded[batch_keep]
        # Repeat the targets (for the case of batched aux outputs in the matcher)
        if repeat_batch > 1:
            # In this case, out_prob and out_bbox will be a concatenation of
            # both final and auxiliary outputs, so we also repeat the targets
            num_boxes = num_boxes.repeat(repeat_batch)
            tgt_bbox = tgt_bbox.repeat(repeat_batch, 1, 1)
            if target_is_valid_padded is not None:
                target_is_valid_padded = target_is_valid_padded.repeat(repeat_batch, 1)
        # keep only samples w/ at least 1 GT box in outputs
        if self.remove_samples_with_0_gt:
            if repeat_batch > 1:
                batch_keep = batch_keep.repeat(repeat_batch)
            out_score = out_score[batch_keep]
            out_bbox = out_bbox[batch_keep]
            if out_is_valid is not None:
                out_is_valid = out_is_valid[batch_keep]
        assert out_bbox.shape[0] == tgt_bbox.shape[0]
        assert out_bbox.shape[0] == num_boxes.shape[0]
        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
        # Compute the giou cost betwen boxes
        cost_giou = -generalized_box_iou(
            box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
        )
        out_prob = self.norm(out_score)
        if not self.focal:
            # Objectness cost is target-independent: broadcast over the targets
            cost_class = -out_prob.unsqueeze(-1).expand_as(cost_bbox)
        else:
            if self.stable:
                # Modulate the probability by the rescaled GIoU (in [0, 1]) so
                # the focal cost also reflects localization quality
                rescaled_giou = (-cost_giou + 1) / 2
                out_prob = out_prob.unsqueeze(-1).expand_as(cost_bbox) * rescaled_giou
                cost_class = -self.alpha * (1 - out_prob) ** self.gamma * torch.log(
                    out_prob
                ) + (1 - self.alpha) * out_prob**self.gamma * torch.log(1 - out_prob)
            else:
                # directly computing log sigmoid (more numerically stable)
                log_out_prob = torch.nn.functional.logsigmoid(out_score)
                log_one_minus_out_prob = torch.nn.functional.logsigmoid(-out_score)
                cost_class = (
                    -self.alpha * (1 - out_prob) ** self.gamma * log_out_prob
                    + (1 - self.alpha) * out_prob**self.gamma * log_one_minus_out_prob
                )
            if not self.stable:
                cost_class = cost_class.unsqueeze(-1).expand_as(cost_bbox)
        assert cost_class.shape == cost_bbox.shape
        # Final cost matrix
        C = (
            self.cost_bbox * cost_bbox
            + self.cost_class * cost_class
            + self.cost_giou * cost_giou
        )
        # assign a very high cost (1e9) to invalid outputs and targets, so that we can
        # filter them out (in `_do_matching`) from bipartite matching results
        do_filtering = out_is_valid is not None or target_is_valid_padded is not None
        if out_is_valid is not None:
            C = torch.where(out_is_valid[:, :, None], C, 1e9)
        if target_is_valid_padded is not None:
            C = torch.where(target_is_valid_padded[:, None, :], C, 1e9)
        C = C.cpu().numpy()
        # Crop each padded cost matrix to its true number of targets
        costs = [C[i, :, :s] for i, s in enumerate(num_boxes.tolist())]
        return_tgt_indices = (
            do_filtering or torch.any(num_queries < num_boxes * max(repeats, 1)).item()
        )
        if len(costs) == 0:
            # We have size 0 in the batch dimension, so we return empty matching indices
            # (note that this can happen due to `remove_samples_with_0_gt=True` even if
            # the original input batch size is not 0, when all queries have empty GTs).
            indices = []
            tgt_idx = torch.zeros(0).long().to(device) if return_tgt_indices else None
        elif return_tgt_indices:
            indices, tgt_indices = zip(
                *(
                    _do_matching(
                        c,
                        repeats=repeats,
                        return_tgt_indices=return_tgt_indices,
                        do_filtering=do_filtering,
                    )
                    for c in costs
                )
            )
            tgt_indices = list(tgt_indices)
            sizes = torch.cumsum(num_boxes, -1)[:-1]
            # Offset per-image target indices so they index the concatenated targets
            for i in range(1, len(tgt_indices)):
                tgt_indices[i] += sizes[i - 1].item()
            tgt_idx = torch.from_numpy(np.concatenate(tgt_indices)).long().to(device)
        else:
            indices = [
                _do_matching(c, repeats=repeats, do_filtering=do_filtering)
                for c in costs
            ]
            tgt_idx = None
        if self.remove_samples_with_0_gt:
            # Map positions in the filtered batch back to original batch indices
            kept_inds = batch_keep.nonzero().squeeze(1)
            batch_idx = torch.as_tensor(
                sum([[kept_inds[i]] * len(src) for i, src in enumerate(indices)], []),
                dtype=torch.long,
                device=device,
            )
        else:
            batch_idx = torch.as_tensor(
                sum([[i] * len(src) for i, src in enumerate(indices)], []),
                dtype=torch.long,
                device=device,
            )
        # indices could be an empty list (since we remove samples w/ 0 GT boxes)
        if len(indices) > 0:
            src_idx = torch.from_numpy(np.concatenate(indices)).long().to(device)
        else:
            src_idx = torch.empty(0, dtype=torch.long, device=device)
        return batch_idx, src_idx, tgt_idx
class BinaryOneToManyMatcher(nn.Module):
    """
    This class computes a greedy assignment between the targets and the predictions of the network.
    In this formulation, several predictions can be assigned to each target, but each prediction can be assigned to
    at most one target.
    See DAC-Detr for details
    """
    def __init__(
        self,
        alpha: float = 0.3,
        threshold: float = 0.4,
        topk: int = 6,
    ):
        """
        Creates the matcher
        Params:
            alpha: relative balancing between classification and localization
            threshold: threshold used to select positive predictions
            topk: number of top scoring predictions to consider
        """
        super().__init__()
        self.norm = nn.Sigmoid()
        self.alpha = alpha
        self.threshold = threshold
        self.topk = topk
    @torch.no_grad()
    def forward(
        self,
        outputs,
        batched_targets,
        repeats=1,
        repeat_batch=1,
        out_is_valid=None,
        target_is_valid_padded=None,
    ):
        """
        Performs the matching. The inputs and outputs are the same as
        BinaryHungarianMatcher.forward
        Inputs:
        - outputs: A dict with the following keys:
          - "pred_logits": Tensor of shape (batch_size, num_queries, 1) with
            classification logits
          - "pred_boxes": Tensor of shape (batch_size, num_queries, 4) with
            predicted box coordinates in cxcywh format.
        - batched_targets: A dict of targets. There may be a variable number of
          targets per batch entry; suppose that there are T_b targets for batch
          entry 0 <= b < batch_size. It should have the following keys:
          - "num_boxes": int64 Tensor of shape (batch_size,) giving the number
            of ground-truth boxes per batch entry: num_boxes[b] = T_b
          - "_boxes_padded": Tensor of shape (batch_size, max_b T_b, 4) giving
            a padded version of ground-truth boxes. If this is not present then
            it will be computed from batched_targets["boxes"] instead, but
            caching it here can improve performance for repeated calls with the
            same targets.
        - out_is_valid: If not None, it should be a boolean tensor of shape
          (batch_size, num_queries) indicating which predictions are valid.
          Invalid predictions are ignored during matching and won't appear in
          the output indices.
        - target_is_valid_padded: If not None, it should be a boolean tensor of
          shape (batch_size, max_num_gt_boxes) in padded format indicating
          which GT boxes are valid. Invalid GT boxes are ignored during matching
          and won't appear in the output indices.
        Returns:
            A list of size batch_size, containing tuples of (idx_i, idx_j):
            - idx_i is the indices of the selected predictions (in order)
            - idx_j is the indices of the corresponding selected targets
              (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j)
                             = min(num_queries, num_target_boxes)
        """
        # One-to-many matching does not support repeated targets/batches
        assert repeats <= 1 and repeat_batch <= 1
        bs, num_queries = outputs["pred_logits"].shape[:2]
        out_prob = self.norm(outputs["pred_logits"]).squeeze(-1)  # (B, Q)
        out_bbox = outputs["pred_boxes"]  # (B, Q, 4))
        num_boxes = batched_targets["num_boxes"]
        # Get a padded version of target boxes (as precomputed in the collator).
        tgt_bbox = batched_targets["boxes_padded"]
        assert len(tgt_bbox) == bs
        num_targets = tgt_bbox.shape[1]
        if num_targets == 0:
            # No targets at all: return empty (batch, src, tgt) index tensors
            return (
                torch.empty(0, dtype=torch.long, device=out_prob.device),
                torch.empty(0, dtype=torch.long, device=out_prob.device),
                torch.empty(0, dtype=torch.long, device=out_prob.device),
            )
        iou, _ = box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
        assert iou.shape == (bs, num_queries, num_targets)
        # Final cost matrix (higher is better in `C`; this is unlike the case
        # of BinaryHungarianMatcherV2 above where lower is better in its `C`)
        C = self.alpha * out_prob.unsqueeze(-1) + (1 - self.alpha) * iou
        if out_is_valid is not None:
            C = torch.where(out_is_valid[:, :, None], C, -1e9)
        if target_is_valid_padded is not None:
            C = torch.where(target_is_valid_padded[:, None, :], C, -1e9)
        # Selecting topk predictions
        # (per target: keep scores above the (1 - topk/Q) quantile over queries)
        matches = C > torch.quantile(
            C, 1 - self.topk / num_queries, dim=1, keepdim=True
        )
        # Selecting predictions above threshold
        matches = matches & (C > self.threshold)
        if out_is_valid is not None:
            matches = matches & out_is_valid[:, :, None]
        if target_is_valid_padded is not None:
            matches = matches & target_is_valid_padded[:, None, :]
        # Removing padding
        matches = matches & (
            torch.arange(0, num_targets, device=num_boxes.device)[None]
            < num_boxes[:, None]
        ).unsqueeze(1)
        batch_idx, src_idx, tgt_idx = torch.nonzero(matches, as_tuple=True)
        # Offset target indices so they refer to the concatenated (unpadded) targets
        cum_num_boxes = torch.cat(
            [
                torch.zeros(1, dtype=num_boxes.dtype, device=num_boxes.device),
                num_boxes.cumsum(-1)[:-1],
            ]
        )
        tgt_idx += cum_num_boxes[batch_idx]
        return batch_idx, src_idx, tgt_idx

306
sam3/train/nms_helper.py Normal file
View File

@@ -0,0 +1,306 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import warnings
from typing import Dict, List
import numpy as np
# Check if Numba is available
HAS_NUMBA = False
try:
import numba as nb
HAS_NUMBA = True
except ImportError:
warnings.warn(
"Numba not found. Using slower pure Python implementations.", UserWarning
)
# -------------------- Helper Functions --------------------
def is_zero_box(bbox: list) -> bool:
    """Return True if ``bbox`` is missing, malformed, or degenerate.

    A box is invalid when it is None, has fewer than 4 entries, or when all of
    its first four coordinates (x, y, w, h) are <= 0.
    """
    if bbox is None:
        return True
    # Check the length first: a short box is invalid regardless of its values.
    # (The original relied on `all()` over a short slice short-circuiting into
    # the length check, which was fragile.)
    if len(bbox) < 4:
        return True
    return all(coord <= 0 for coord in bbox[:4])
def convert_bbox_format(bbox: list) -> List[float]:
    """Convert a box from ``(x, y, w, h)`` into corner form ``(x1, y1, x2, y2)``."""
    left, top, width, height = bbox
    return [left, top, left + width, top + height]
# -------------------- Track-level NMS --------------------
def process_track_level_nms(video_groups: Dict, nms_threshold: float) -> Dict:
    """Apply track-level NMS to all videos.
    Mutates ``video_groups`` in place: every bbox of a suppressed track is
    replaced by None. Each track dict is expected to carry a per-frame
    "bboxes" list in (x, y, w, h) format and a scalar "score".
    Returns the (mutated) ``video_groups`` for convenience.
    """
    for video_id, tracks in video_groups.items():
        track_detections = []
        # Process tracks
        for track_idx, track in enumerate(tracks):
            if not track["bboxes"]:
                continue
            converted_bboxes = []
            valid_frames = []
            for bbox in track["bboxes"]:
                if bbox and not is_zero_box(bbox):
                    converted_bboxes.append(convert_bbox_format(bbox))
                    valid_frames.append(True)
                else:
                    # Keep the frame slot but mark it invalid with a NaN box,
                    # so all tracks stay aligned frame-by-frame
                    converted_bboxes.append([np.nan] * 4)
                    valid_frames.append(False)
            if any(valid_frames):
                track_detections.append(
                    {
                        "track_idx": track_idx,
                        "bboxes": np.array(converted_bboxes, dtype=np.float32),
                        "score": track["score"],
                    }
                )
        # Apply NMS
        if track_detections:
            scores = np.array([d["score"] for d in track_detections], dtype=np.float32)
            # apply_track_nms is defined elsewhere in this module; returns the
            # indices (into track_detections) of the tracks to keep
            keep = apply_track_nms(track_detections, scores, nms_threshold)
            # Suppress non-kept tracks
            for idx, track in enumerate(track_detections):
                if idx not in keep:
                    tracks[track["track_idx"]]["bboxes"] = [None] * len(track["bboxes"])
    return video_groups
# -------------------- Frame-level NMS --------------------
def process_frame_level_nms(video_groups: Dict, nms_threshold: float) -> Dict:
    """Run per-frame NMS independently on every frame of every video.

    For each frame, the valid boxes of all tracks are collected, standard
    NMS is applied, and suppressed detections are blanked out (set to None)
    in their owning track.
    """
    for _video_id, tracks in video_groups.items():
        if not tracks:
            continue
        # All tracks are indexed by the frame count of track 0.
        num_frames = len(tracks[0]["bboxes"])
        for frame_idx in range(num_frames):
            detections = []
            for track_idx, track in enumerate(tracks):
                bbox = track["bboxes"][frame_idx]
                if not bbox or is_zero_box(bbox):
                    continue
                detections.append(
                    {
                        "track_idx": track_idx,
                        "bbox": np.array(
                            convert_bbox_format(bbox), dtype=np.float32
                        ),
                        "score": track["score"],
                    }
                )
            if not detections:
                continue
            boxes = np.stack([det["bbox"] for det in detections])
            confidences = np.array(
                [det["score"] for det in detections], dtype=np.float32
            )
            keep = apply_frame_nms(boxes, confidences, nms_threshold)
            for det_idx, det in enumerate(detections):
                if det_idx not in keep:
                    tracks[det["track_idx"]]["bboxes"][frame_idx] = None
    return video_groups
# Track-level NMS helpers ------------------------------------------------------
def compute_track_iou_matrix(
    bboxes_stacked: np.ndarray, valid_masks: np.ndarray, areas: np.ndarray
) -> np.ndarray:
    """Compute the pairwise track IoU matrix, using Numba when available.

    Track IoU is defined as sum(per-frame intersections) / sum(per-frame
    unions), accumulated over the frames where *both* tracks have a valid box.

    Args:
        bboxes_stacked: (num_tracks, num_frames, 4) xyxy boxes; invalid
            frames hold NaNs.
        valid_masks: (num_tracks, num_frames) bool mask of valid frames.
        areas: (num_tracks, num_frames) box areas (0 where invalid).

    Returns:
        (num_tracks, num_tracks) symmetric float32 IoU matrix with a zero
        diagonal (only pairs i < j are filled, then mirrored).
    """
    num_tracks = bboxes_stacked.shape[0]
    iou_matrix = np.zeros((num_tracks, num_tracks), dtype=np.float32)
    if HAS_NUMBA:
        iou_matrix = _compute_track_iou_matrix_numba(bboxes_stacked, valid_masks, areas)
    else:
        # Pure Python implementation — kept numerically identical to the
        # Numba version so results do not depend on whether Numba is installed.
        for i in range(num_tracks):
            for j in range(i + 1, num_tracks):
                # Frames where BOTH tracks have a valid box.
                valid_ij = valid_masks[i] & valid_masks[j]
                if not valid_ij.any():
                    continue
                bboxes_i = bboxes_stacked[i, valid_ij]
                bboxes_j = bboxes_stacked[j, valid_ij]
                area_i = areas[i, valid_ij]
                area_j = areas[j, valid_ij]
                inter_total = 0.0
                union_total = 0.0
                for k in range(bboxes_i.shape[0]):
                    # Per-frame intersection rectangle (clamped to >= 0).
                    x1 = max(bboxes_i[k, 0], bboxes_j[k, 0])
                    y1 = max(bboxes_i[k, 1], bboxes_j[k, 1])
                    x2 = min(bboxes_i[k, 2], bboxes_j[k, 2])
                    y2 = min(bboxes_i[k, 3], bboxes_j[k, 3])
                    inter = max(0, x2 - x1) * max(0, y2 - y1)
                    union = area_i[k] + area_j[k] - inter
                    inter_total += inter
                    union_total += union
                if union_total > 0:
                    iou_matrix[i, j] = inter_total / union_total
                    iou_matrix[j, i] = iou_matrix[i, j]
    return iou_matrix
if HAS_NUMBA:

    @nb.jit(nopython=True, parallel=True)
    def _compute_track_iou_matrix_numba(bboxes_stacked, valid_masks, areas):
        """Numba-optimized IoU matrix computation for track-level NMS.

        Mirrors the pure-Python fallback in ``compute_track_iou_matrix``;
        the outer loop is parallelized across tracks via ``nb.prange`` and
        only pairs j > i are computed, then mirrored.
        """
        num_tracks = bboxes_stacked.shape[0]
        iou_matrix = np.zeros((num_tracks, num_tracks), dtype=np.float32)
        for i in nb.prange(num_tracks):
            for j in range(i + 1, num_tracks):
                # Frames where both tracks have a valid (non-NaN) box.
                valid_ij = valid_masks[i] & valid_masks[j]
                if not valid_ij.any():
                    continue
                bboxes_i = bboxes_stacked[i, valid_ij]
                bboxes_j = bboxes_stacked[j, valid_ij]
                area_i = areas[i, valid_ij]
                area_j = areas[j, valid_ij]
                inter_total = 0.0
                union_total = 0.0
                for k in range(bboxes_i.shape[0]):
                    x1 = max(bboxes_i[k, 0], bboxes_j[k, 0])
                    y1 = max(bboxes_i[k, 1], bboxes_j[k, 1])
                    x2 = min(bboxes_i[k, 2], bboxes_j[k, 2])
                    y2 = min(bboxes_i[k, 3], bboxes_j[k, 3])
                    inter = max(0, x2 - x1) * max(0, y2 - y1)
                    union = area_i[k] + area_j[k] - inter
                    inter_total += inter
                    union_total += union
                if union_total > 0:
                    iou_matrix[i, j] = inter_total / union_total
                    iou_matrix[j, i] = iou_matrix[i, j]
        return iou_matrix
def apply_track_nms(
    track_detections: List[dict], scores: np.ndarray, nms_threshold: float
) -> List[int]:
    """Vectorized track-level NMS implementation.

    Greedily keeps the highest-scoring remaining track and suppresses every
    other track whose track-level IoU with it is >= ``nms_threshold``.

    Args:
        track_detections: dicts with "bboxes" — a (num_frames, 4) xyxy array
            with NaN rows for missing frames — and "score".
        scores: per-track confidence scores, aligned with track_detections.
        nms_threshold: IoU threshold at or above which a track is suppressed.

    Returns:
        Indices (into track_detections) of the tracks to keep.
    """
    if not track_detections:
        return []
    bboxes_stacked = np.stack([d["bboxes"] for d in track_detections], axis=0)
    # A frame is valid only if none of its 4 coordinates is NaN.
    valid_masks = ~np.isnan(bboxes_stacked).any(axis=2)
    areas = (bboxes_stacked[:, :, 2] - bboxes_stacked[:, :, 0]) * (
        bboxes_stacked[:, :, 3] - bboxes_stacked[:, :, 1]
    )
    areas[~valid_masks] = 0
    iou_matrix = compute_track_iou_matrix(bboxes_stacked, valid_masks, areas)
    keep = []
    order = np.argsort(-scores)  # indices sorted by descending score
    suppress = np.zeros(len(track_detections), dtype=bool)
    for i in range(len(order)):
        if not suppress[order[i]]:
            keep.append(order[i])
            # The slice includes order[i] itself, but the IoU diagonal is 0
            # (compute_track_iou_matrix only fills i < j), so the kept track
            # is not self-suppressed for any positive threshold.
            suppress[order[i:]] = suppress[order[i:]] | (
                iou_matrix[order[i], order[i:]] >= nms_threshold
            )
    return keep
# Frame-level NMS helpers ------------------------------------------------------
def compute_frame_ious(bbox: np.ndarray, bboxes: np.ndarray) -> np.ndarray:
    """IoU of one xyxy box against an array of xyxy boxes.

    Dispatches to the Numba kernel when available; otherwise runs a
    numerically identical pure-Python loop.

    Args:
        bbox: single box (4,) in xyxy format.
        bboxes: candidate boxes (N, 4) in xyxy format.

    Returns:
        (N,) float32 IoUs; 0.0 where the union area is not positive.
    """
    if HAS_NUMBA:
        return _compute_frame_ious_numba(bbox, bboxes)
    else:
        # Pure Python implementation
        ious = np.zeros(len(bboxes), dtype=np.float32)
        for i in range(len(bboxes)):
            # Intersection rectangle, clamped to zero when boxes are disjoint.
            x1 = max(bbox[0], bboxes[i, 0])
            y1 = max(bbox[1], bboxes[i, 1])
            x2 = min(bbox[2], bboxes[i, 2])
            y2 = min(bbox[3], bboxes[i, 3])
            inter = max(0, x2 - x1) * max(0, y2 - y1)
            area1 = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            area2 = (bboxes[i, 2] - bboxes[i, 0]) * (bboxes[i, 3] - bboxes[i, 1])
            union = area1 + area2 - inter
            ious[i] = inter / union if union > 0 else 0.0
        return ious
if HAS_NUMBA:

    @nb.jit(nopython=True, parallel=True)
    def _compute_frame_ious_numba(bbox, bboxes):
        """Numba-optimized IoU computation.

        Parallel (``nb.prange``) twin of the pure-Python fallback in
        ``compute_frame_ious``; must stay numerically identical to it.
        """
        ious = np.zeros(len(bboxes), dtype=np.float32)
        for i in nb.prange(len(bboxes)):
            x1 = max(bbox[0], bboxes[i, 0])
            y1 = max(bbox[1], bboxes[i, 1])
            x2 = min(bbox[2], bboxes[i, 2])
            y2 = min(bbox[3], bboxes[i, 3])
            inter = max(0, x2 - x1) * max(0, y2 - y1)
            area1 = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            area2 = (bboxes[i, 2] - bboxes[i, 0]) * (bboxes[i, 3] - bboxes[i, 1])
            union = area1 + area2 - inter
            ious[i] = inter / union if union > 0 else 0.0
        return ious
def apply_frame_nms(
    bboxes: np.ndarray, scores: np.ndarray, nms_threshold: float
) -> List[int]:
    """Frame-level NMS implementation with fallback to pure Python.

    Greedy NMS: keep the highest-scoring remaining box, suppress every later
    box with IoU >= ``nms_threshold`` against it.

    Args:
        bboxes: (N, 4) xyxy boxes.
        scores: (N,) confidence scores.
        nms_threshold: IoU threshold at or above which a box is suppressed.

    Returns:
        Indices (into bboxes) of the detections to keep.
    """
    if HAS_NUMBA:
        return _apply_frame_nms_numba(bboxes, scores, nms_threshold)
    else:
        # Pure Python implementation
        order = np.argsort(-scores)  # indices sorted by descending score
        keep = []
        suppress = np.zeros(len(bboxes), dtype=bool)
        for i in range(len(order)):
            if not suppress[order[i]]:
                keep.append(order[i])
                current_bbox = bboxes[order[i]]
                remaining_bboxes = bboxes[order[i + 1 :]]
                if len(remaining_bboxes) > 0:  # Check if there are any remaining boxes
                    ious = compute_frame_ious(current_bbox, remaining_bboxes)
                    suppress[order[i + 1 :]] = suppress[order[i + 1 :]] | (
                        ious >= nms_threshold
                    )
        return keep
if HAS_NUMBA:

    @nb.jit(nopython=True)
    def _apply_frame_nms_numba(bboxes, scores, nms_threshold):
        """Numba-optimized NMS implementation.

        Mirrors the pure-Python fallback in ``apply_frame_nms``.
        NOTE(review): ``keep`` is a reflected Python list inside nopython
        mode — works, but is a deprecated Numba feature; confirm against the
        pinned Numba version.
        """
        order = np.argsort(-scores)
        keep = []
        suppress = np.zeros(len(bboxes), dtype=nb.boolean)
        for i in range(len(order)):
            if not suppress[order[i]]:
                keep.append(order[i])
                current_bbox = bboxes[order[i]]
                if i + 1 < len(order):  # Check bounds
                    ious = _compute_frame_ious_numba(
                        current_bbox, bboxes[order[i + 1 :]]
                    )
                    suppress[order[i + 1 :]] = suppress[order[i + 1 :]] | (
                        ious >= nms_threshold
                    )
        return keep

View File

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

View File

@@ -0,0 +1,498 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import fnmatch
import inspect
import itertools
import logging
import types
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
Type,
Union,
)
import hydra
import torch
import torch.nn as nn
from omegaconf import DictConfig
from torch import Tensor
class Optimizer:
    """Couple a torch optimizer with per-option, per-param-group schedulers.

    ``schedulers`` is a list (one entry per param group, aligned with
    ``optimizer.param_groups``) of dicts mapping an optimizer option name
    (e.g. "lr", "weight_decay") to a callable scheduler.
    """

    def __init__(self, optimizer, schedulers=None) -> None:
        self.optimizer = optimizer
        self.schedulers = schedulers
        self._validate_optimizer_schedulers()
        # Prime every scheduled option at the very start of training.
        self.step_schedulers(0.0, 0)

    def _validate_optimizer_schedulers(self):
        """Assert each scheduled option is a known option of the optimizer."""
        if self.schedulers is None:
            return
        for group_schedulers in self.schedulers:
            for option in group_schedulers:
                assert option in self.optimizer.defaults, (
                    "Optimizer option "
                    f"{option} not found in {self.optimizer}. Valid options are "
                    f"{self.optimizer.defaults.keys()}"
                )

    def step_schedulers(self, where: float, step: int) -> None:
        """Recompute every scheduled option for each param group.

        Schedulers whose ``__call__`` accepts a ``step`` argument (directly,
        or through a ValueScaler-style wrapper exposing ``.scheduler``) are
        called with both ``step`` and ``where``; others get only ``where``.
        """
        if self.schedulers is None:
            return
        for param_group, group_schedulers in zip(
            self.optimizer.param_groups, self.schedulers
        ):
            for option, scheduler in group_schedulers.items():
                if "step" in inspect.signature(scheduler.__call__).parameters:
                    new_value = scheduler(step=step, where=where)
                elif hasattr(scheduler, "scheduler") and "step" in inspect.signature(
                    scheduler.scheduler.__call__
                ).parameters:
                    # To handle ValueScaler wrappers, which forward kwargs.
                    new_value = scheduler(step=step, where=where)
                else:
                    new_value = scheduler(where)
                param_group[option] = new_value

    def step(self, where, step, closure=None):
        """Advance all schedulers, then take one optimizer step."""
        self.step_schedulers(where, step)
        return self.optimizer.step(closure)

    def zero_grad(self, *args, **kwargs):
        """Delegate gradient zeroing to the wrapped optimizer."""
        return self.optimizer.zero_grad(*args, **kwargs)
def set_default_parameters(
    scheduler_cfgs: List[DictConfig], all_parameter_names: Set[str]
) -> None:
    """Set up the "default" scheduler with the right parameters.

    Each scheduler config may restrict itself to a set of parameter names;
    at most one config is allowed to leave ``parameter_names`` unset and
    thereby act as the default for every parameter not claimed by the
    others. If no config volunteers as the default, a bare config (holding
    only the leftover parameter names, no scheduler) is appended.

    Args:
        scheduler_cfgs: Scheduler configs for one optimizer option; mutated
            in place.
        all_parameter_names: Names of all the parameters to consider.
    """
    claimed = [
        cfg.parameter_names
        for cfg in scheduler_cfgs
        if cfg.parameter_names is not None
    ]
    leftover = (
        set(all_parameter_names)
        if not claimed
        else all_parameter_names - set.union(*claimed)
    )
    num_defaults = 0
    for cfg in scheduler_cfgs:
        if cfg.parameter_names is None:
            cfg.parameter_names = leftover
            num_defaults += 1
    assert num_defaults <= 1, "Only one scheduler per option can be default"
    if num_defaults == 0:
        # No default scheduler specified: cover the leftovers with a config
        # that has no scheduler for this option.
        scheduler_cfgs.append({"parameter_names": leftover})
def name_constraints_to_parameters(
    param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor]
) -> List[torch.nn.Parameter]:
    """Return the parameters whose names satisfy *every* constraint set.

    Note that this returns the parameters themselves, not their names.

    Args:
        param_constraints: Each element is a set of allowed parameter names.
        named_parameters: Mapping from parameter name to the parameter.

    Returns:
        Parameters whose names lie in the intersection of all constraint sets,
        in ``named_parameters`` iteration order.
    """
    allowed = set.intersection(*param_constraints)
    return [param for name, param in named_parameters.items() if name in allowed]
def map_scheduler_cfgs_to_param_groups(
    all_scheduler_cfgs: Iterable[List[Dict]],
    named_parameters: Dict[str, Tensor],
) -> Tuple[List[Dict[Any, Any]], List[Dict[str, List[torch.nn.Parameter]]]]:
    """Produce parameter groups corresponding to all the scheduler configs.

    Each optimizer option (e.g. "lr", "weight_decay") comes with a list of
    scheduler configs restricted to certain parameter names. Every
    combination of one config per option defines a candidate param group:
    the parameters matched by ALL configs in the combination. Empty
    combinations are dropped.

    Args:
        all_scheduler_cfgs: All the scheduler configs covering every option.
        named_parameters: Mapping from a parameter name to the parameter.

    Returns:
        Tuple (schedulers, param_groups) where schedulers[i] applies to
        param_groups[i].
    """
    schedulers = []
    param_groups = []
    for combo in itertools.product(*all_scheduler_cfgs):
        # Parameters claimed by every config in this combination.
        allowed = set.intersection(
            *[cfg["parameter_names"] for cfg in combo]
        )
        matching_parameters = [
            param for name, param in named_parameters.items() if name in allowed
        ]
        if not matching_parameters:  # If no overlap of parameters, skip
            continue
        schedulers.append(
            {
                cfg["option"]: cfg["scheduler"]
                for cfg in combo
                if "option" in cfg
            }
        )
        param_groups.append({"params": matching_parameters})
    return schedulers, param_groups
def validate_param_group_params(param_groups: List[Dict], model: nn.Module):
    """Sanity-check param groups: duplicate-free, disjoint, and exhaustive.

    Args:
        param_groups: List of all param groups.
        model: Model to validate against; every one of its parameters must
            appear in exactly one group.

    Raises:
        AssertionError: if a group repeats a parameter, two groups overlap,
            or the union of all groups misses a model parameter.
    """
    # No parameter may appear twice within a single group.
    for group in param_groups:
        assert len(group["params"]) == len(set(group["params"]))
    group_sets = [set(group["params"]) for group in param_groups]
    model_parameters = {param for _, param in model.named_parameters()}
    for first, second in itertools.permutations(group_sets, 2):
        assert first.isdisjoint(
            second
        ), "Scheduler generated param_groups should be disjoint"
    assert set.union(*group_sets) == model_parameters, (
        "Scheduler generated param_groups must include all parameters of the model."
        f" Found {len(set.union(*group_sets))} params whereas model has"
        f" {len(model_parameters)} params"
    )
def unix_module_cls_pattern_to_parameter_names(
    filter_module_cls_names: List[str],
    module_cls_to_param_names: Dict[Type, str],
) -> Union[None, Set[str]]:
    """Return the parameter names owned by any of the given module classes.

    Args:
        filter_module_cls_names: Fully-qualified class names, e.g.
            ["torch.nn.LayerNorm", "torch.nn.BatchNorm2d"]; None yields an
            empty set.
        module_cls_to_param_names: Mapping from module classes to the
            parameter names they contain. See `get_module_cls_to_param_names`.

    Raises:
        AssertionError: if a class is absent from the model or owns no
            parameters.
    """
    if filter_module_cls_names is None:
        return set()
    per_class_matches = []
    for module_cls_name in filter_module_cls_names:
        module_cls = hydra.utils.get_class(module_cls_name)
        if module_cls not in module_cls_to_param_names:
            raise AssertionError(
                f"module_cls_name {module_cls_name} does not "
                "match any classes in the model"
            )
        matching_parameters = module_cls_to_param_names[module_cls]
        assert (
            len(matching_parameters) > 0
        ), f"module_cls_name {module_cls_name} does not contain any parameters in the model"
        logging.info(
            f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
        )
        per_class_matches.append(matching_parameters)
    return set.union(*per_class_matches)
def unix_param_pattern_to_parameter_names(
    filter_param_names: Optional[List[str]],
    parameter_names: Dict[str, torch.Tensor],
) -> Union[None, Set[str]]:
    """Return the parameter names matching any of the unix-style patterns.

    Args:
        filter_param_names: Unix-style filter strings with optional
            wildcards, like ["block.2.*", "block.2.linear.weight"]; None
            yields an empty set.
        parameter_names: Mapping whose keys are the candidate names.

    Raises:
        AssertionError: if a pattern matches nothing.
    """
    if filter_param_names is None:
        return set()
    per_pattern_matches = []
    for param_name in filter_param_names:
        matching_parameters = set(fnmatch.filter(parameter_names, param_name))
        assert (
            len(matching_parameters) >= 1
        ), f"param_name {param_name} does not match any parameters in the model"
        logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
        per_pattern_matches.append(matching_parameters)
    return set.union(*per_pattern_matches)
def _unix_pattern_to_parameter_names(
    scheduler_cfg: DictConfig,
    parameter_names: Set[str],
    module_cls_to_param_names: Dict[Type, str],
) -> Union[None, Set[str]]:
    """Collect the parameter names selected by a scheduler config's filters.

    Args:
        scheduler_cfg: The config for the scheduler; may carry "param_names"
            (unix patterns) and/or "module_cls_names" (class names).
        parameter_names: The set of all parameter names to filter.
        module_cls_to_param_names: Mapping from module classes to owned
            parameter names.

    Returns:
        None when the config has no filter at all (meaning "default");
        otherwise the union of both filters' matches.
    """
    has_filters = (
        "param_names" in scheduler_cfg or "module_cls_names" in scheduler_cfg
    )
    if not has_filters:
        return None
    from_patterns = unix_param_pattern_to_parameter_names(
        scheduler_cfg.get("param_names"), parameter_names
    )
    from_classes = unix_module_cls_pattern_to_parameter_names(
        scheduler_cfg.get("module_cls_names"), module_cls_to_param_names
    )
    return from_patterns.union(from_classes)
def get_module_cls_to_param_names(
    model: nn.Module, param_allowlist: Set[str] = None
) -> Dict[Type, str]:
    """Map every module class in the model to the parameter names it owns.

    Only counts a parameter as part of its immediate parent module
    (``recurse=False``); recursive parents do not count.

    Args:
        model: Model to iterate over.
        param_allowlist: If specified, only these param names are collected.
    """
    cls_to_params = {}
    for module_name, module in model.named_modules():
        owner_cls = type(module)
        cls_to_params.setdefault(owner_cls, set())
        for param_name, _ in module.named_parameters(recurse=False):
            # Fully-qualified name: bare name at top level, dotted otherwise.
            full_param_name = (
                param_name if module_name == "" else f"{module_name}.{param_name}"
            )
            if param_allowlist is None or full_param_name in param_allowlist:
                cls_to_params[owner_cls].add(full_param_name)
    return cls_to_params
def construct_optimizer(
    model: torch.nn.Module,
    optimizer_conf: Any,
    options_conf: Mapping[str, List] = None,
    param_group_modifiers_conf: List[Callable] = None,
    param_allowlist: Optional[Set[str]] = None,
    validate_param_groups=True,
) -> Optimizer:
    """
    Constructs a stochastic gradient descent or ADAM (or ADAMw) optimizer
    with momentum. i.e, constructs a torch.optim.Optimizer with zero-weight decay
    Batchnorm and/or no-update 1-D parameters support, based on the config.
    Supports wrapping the optimizer with Layer-wise Adaptive Rate Scaling
    (LARS): https://arxiv.org/abs/1708.03888
    Args:
        model: model to perform stochastic gradient descent
            optimization or ADAM optimization.
        optimizer_conf: Hydra config consisting a partial torch optimizer like SGD or
            ADAM, still missing the params argument which this function provides to
            produce the final optimizer
        options_conf: Hydra config mapping each optimizer option (e.g. "lr",
            "weight_decay") to its list of scheduler configs; when empty the
            optimizer is built with a single param group and no schedulers.
        param_group_modifiers_conf: Optional user specified functions which can modify
            the final scheduler configs before the optimizer's param groups are built
        param_allowlist: The parameters to optimize. Parameters which are not part of
            this allowlist will be skipped.
        validate_param_groups: If enabled, valides that the produced param_groups don't
            overlap and cover all the model parameters.
    """
    if param_allowlist is None:
        param_allowlist = {name for name, _ in model.named_parameters()}
    named_parameters = {
        name: param
        for name, param in model.named_parameters()
        if name in param_allowlist
    }
    # Fast path: no per-option schedulers -> one param group, no schedulers.
    if not options_conf:
        optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values())
        return Optimizer(optimizer)
    all_parameter_names = {
        name for name, _ in model.named_parameters() if name in param_allowlist
    }
    module_cls_to_all_param_names = get_module_cls_to_param_names(
        model, param_allowlist
    )
    scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf)
    all_scheduler_cfgs = []
    for option, scheduler_cfgs in scheduler_cfgs_per_option.items():
        # Resolve each config's name/class filters into concrete param names,
        # then let exactly one config per option act as the default.
        for config in scheduler_cfgs:
            config.option = option
            config.parameter_names = _unix_pattern_to_parameter_names(
                config, all_parameter_names, module_cls_to_all_param_names
            )
        set_default_parameters(scheduler_cfgs, all_parameter_names)
        all_scheduler_cfgs.append(scheduler_cfgs)
    if param_group_modifiers_conf:
        # User hooks (e.g. layer_decay_param_modifier) may rewrite the cfgs
        # before the param groups are materialized.
        for custom_param_modifier in param_group_modifiers_conf:
            custom_param_modifier = hydra.utils.instantiate(custom_param_modifier)
            all_scheduler_cfgs = custom_param_modifier(
                scheduler_cfgs=all_scheduler_cfgs, model=model
            )
    schedulers, param_groups = map_scheduler_cfgs_to_param_groups(
        all_scheduler_cfgs, named_parameters
    )
    if validate_param_groups:
        validate_param_group_params(param_groups, model)
    optimizer = hydra.utils.instantiate(optimizer_conf, param_groups)
    return Optimizer(optimizer, schedulers)
def get_full_parameter_name(module_name, param_name):
    """Join module and parameter names with a dot; bare name at the top level."""
    return param_name if module_name == "" else f"{module_name}.{param_name}"
class GradientClipper:
    """
    Gradient clipping utils that works for DDP
    """

    def __init__(self, max_norm: float = 1.0, norm_type: int = 2):
        assert isinstance(max_norm, (int, float)) or max_norm is None
        # None disables clipping entirely; otherwise normalize to float.
        self.max_norm = None if max_norm is None else float(max_norm)
        self.norm_type = norm_type

    def __call__(self, model: nn.Module):
        """Clip the model's gradients in place; no-op when max_norm is None."""
        if self.max_norm is None:
            return  # no-op
        nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=self.max_norm, norm_type=self.norm_type
        )
class ValueScaler:
    """Wrap a scheduler and scale its output by a constant multiplier."""

    def __init__(self, scheduler, mult_val: float):
        self.scheduler = scheduler
        self.mult_val = mult_val

    def __call__(self, *args, **kwargs):
        # Forward all arguments to the wrapped scheduler, then scale.
        return self.scheduler(*args, **kwargs) * self.mult_val
def rgetattr(obj, rattrs: str = None):
    """
    Like getattr(), but supports dotted notation for nested objects.
    rattrs is a str of form 'attr1.attr2', returns obj.attr1.attr2
    """
    if rattrs is None:
        return obj
    target = obj
    for attr in rattrs.split("."):
        target = getattr(target, attr)
    return target
def layer_decay_param_modifier(
    scheduler_cfgs: List[List[Dict]],
    model,
    layer_decay_value: float,
    layer_decay_min: Optional[float] = None,
    apply_to: Optional[str] = None,
    overrides: List[Dict] = (),
) -> List[List[Dict]]:
    """
    Args
    - scheduler_cfgs: a list of omegaconf.ListConfigs.
        Each element in the list is a omegaconfg.DictConfig with the following structure
        {
            "scheduler": <some fvcore scheduler>
            "option": <value> possible options are "lr", "weight_decay" etc.
            "parameter_names": Set of str indicating param names that this scheduler applies to
        }
    - model: a model that implements a method `get_layer_id` that maps layer_name to an integer and
        and a method get_num_layers.
        Alternatively, use apply_to argument to select a specific component of the model.
    - layer_decay_value: float
    - layer_decay_min: min val for layer decay
    - apply_to: optional arg to select which component of the model to apply the the layer decay modifier to
    - overrides: to manually override lr for specific patterns. Is a list of dicts. Each dict, has keys "pattern", "value".
    Returns
    - scheduler_configs: same structure as the input, elements can be modified
    """
    # NOTE(review): if apply_to is None, param_name.startswith(apply_to)
    # below raises TypeError — confirm apply_to is always provided when this
    # modifier is configured.
    model = rgetattr(model, apply_to)
    num_layers = model.get_num_layers() + 1
    # layer_decays[i] is the LR multiplier for layer i: deepest layers get
    # the smallest scale (layer_decay_value ** num_layers), the last gets
    # layer_decay_value ** 0 == 1.
    layer_decays = [
        layer_decay_value ** (num_layers - i) for i in range(num_layers + 1)
    ]
    if layer_decay_min is not None:
        layer_decays = [max(val, layer_decay_min) for val in layer_decays]
    final_scheduler_cfgs = []
    # scheduler_cfgs is a list of lists
    for scheduler_cfg_group in scheduler_cfgs:
        curr_cfg_group = []
        # scheduler_cfg_group is a list of dictionaries
        for scheduler_cfg in scheduler_cfg_group:
            # Only the "lr" option is layer-decayed; others pass through.
            if scheduler_cfg["option"] != "lr":
                curr_cfg_group.append(scheduler_cfg)
                continue
            # Need sorted so that the list of parameter names is deterministic and consistent
            # across re-runs of this job. Else it was causing issues with loading the optimizer
            # state during a job restart
            parameter_names = sorted(scheduler_cfg["parameter_names"])
            # Only want one cfg group per layer
            layer_cfg_groups = {}
            for param_name in parameter_names:
                # Parameters outside the apply_to component get the top
                # (undecayed) scale.
                layer_id = num_layers
                this_scale = layer_decays[layer_id]
                if param_name.startswith(apply_to):
                    layer_id = model.get_layer_id(param_name)
                    this_scale = layer_decays[layer_id]
                # Overrides
                for override in overrides:
                    if fnmatch.fnmatchcase(param_name, override["pattern"]):
                        this_scale = float(override["value"])
                        # The pattern string doubles as the grouping key so
                        # overridden params form their own group.
                        layer_id = override["pattern"]
                        break
                if layer_id not in layer_cfg_groups:
                    curr_param = {
                        "option": scheduler_cfg["option"],
                        "scheduler": ValueScaler(
                            scheduler_cfg["scheduler"], this_scale
                        ),
                        "parameter_names": {param_name},
                    }
                else:
                    curr_param = layer_cfg_groups[layer_id]
                    curr_param["parameter_names"].add(param_name)
                layer_cfg_groups[layer_id] = curr_param
            for layer_cfg in layer_cfg_groups.values():
                curr_cfg_group.append(layer_cfg)
        final_scheduler_cfgs.append(curr_cfg_group)
    return final_scheduler_cfgs

View File

@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import math
class InverseSquareRootParamScheduler:
    """Inverse-square-root decay with linear warmup and linear cooldown.

    After ``warmup_steps`` the base value decays as
    ``base_lr / sqrt((step + timescale - warmup_steps) / timescale)``;
    a linear warmup factor is applied while ``step < warmup_steps`` and a
    linear cooldown factor over the final ``cooldown_steps``.
    """

    def __init__(
        self,
        base_lr: float,
        warmup_steps: int,
        cooldown_steps: int,
        timescale: int,
    ):
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.cooldown_steps = cooldown_steps
        self.timescale = timescale

    def __call__(self, step: int, where: float):
        # ``where`` is the completed fraction of training, so step / where
        # recovers the total number of steps (used by the cooldown factor).
        if where > 0:
            total_steps = step / where
            # NOTE: ``progress`` is computed and clamped but never used below;
            # kept for parity with the original implementation (including its
            # ZeroDivisionError when total_steps == warmup_steps).
            progress = (step - self.warmup_steps) / float(
                total_steps - self.warmup_steps
            )
            progress = max(min(progress, 1), 0)
        else:
            total_steps = 1
            progress = 0
        value = self.base_lr
        if step > self.warmup_steps:
            # Shift so the inverse-sqrt decay starts right after warmup.
            offset = self.timescale - self.warmup_steps
            value = value / math.sqrt((step + offset) / self.timescale)
        if self.warmup_steps:
            value = value * min(1.0, step / self.warmup_steps)
        if self.cooldown_steps:
            value = value * min(1.0, (total_steps - step) / self.cooldown_steps)
        return value

339
sam3/train/train.py Normal file
View File

@@ -0,0 +1,339 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import logging
import os
import random
import sys
import traceback
from argparse import ArgumentParser
from copy import deepcopy
import submitit
import torch
from hydra import compose, initialize_config_module
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from omegaconf import OmegaConf
from sam3.train.utils.train_utils import makedir, register_omegaconf_resolvers
from tqdm import tqdm
os.environ["HYDRA_FULL_ERROR"] = "1"
class SlurmEvent:
    """String constants naming lifecycle events of a submitted SLURM job."""

    QUEUED = "QUEUED"
    START = "START"
    FINISH = "FINISH"
    JOB_ERROR = "JOB_ERROR"
    SLURM_SIGNAL = "SLURM_SIGNAL"
def handle_custom_resolving(cfg):
    """Return an un-resolved copy of the config.

    We'll resolve the config here, so we can catch mistakes early. However,
    we need to pass the un-resolved config to the launcher (because DVC
    resolving needs to be done on the node it will run on), so the copy is
    made without triggering resolving.
    """
    unresolved_container = OmegaConf.to_container(cfg, resolve=False)
    return OmegaConf.create(unresolved_container)
def single_proc_run(local_rank, main_port, cfg, world_size):
    """Single GPU process"""
    # torch.distributed rendezvous settings: single-node runs rendezvous on
    # localhost at the chosen port, one rank per local GPU process.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(main_port)
    os.environ["RANK"] = str(local_rank)
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    try:
        register_omegaconf_resolvers()
    except Exception as e:
        # NOTE(review): failures here (presumably resolvers already being
        # registered on re-entry) are logged and ignored — confirm intended.
        logging.info(e)
    # Build the trainer from its Hydra config and run the training loop.
    trainer = instantiate(cfg.trainer, _recursive_=False)
    trainer.run()
def single_node_runner(cfg, main_port: int):
    """Launch training on the local machine, one process per GPU."""
    assert cfg.launcher.num_nodes == 1
    # assert cfg.launcher.gpus_per_node == 1
    num_proc = cfg.launcher.gpus_per_node
    # CUDA runtime does not support `fork`
    torch.multiprocessing.set_start_method(
        "spawn"
    )
    if num_proc == 1:
        # directly call single_proc so we can easily set breakpoints
        # mp.spawn does not let us set breakpoints
        single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=num_proc)
    else:
        mp_runner = torch.multiprocessing.start_processes
        args = (main_port, cfg, num_proc)
        # NOTE(review): the original comment claimed "fork" is used here, but
        # "spawn" is what is actually passed below — confirm which start
        # method is intended.
        mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="spawn")
def format_exception(e: Exception, limit=20):
    """Render an exception as '<Type>: <message>' plus up to *limit* traceback frames."""
    frames = traceback.format_tb(e.__traceback__, limit=limit)
    return f"{type(e).__name__}: {e}\nTraceback:\n{''.join(frames)}"
class SubmititRunner(submitit.helpers.Checkpointable):
    """A callable which is passed to submitit to launch the jobs."""

    def __init__(self, port, cfg):
        # cfg is kept un-resolved; it is resolved on the worker node in
        # run_trainer (see handle_custom_resolving for the rationale).
        self.cfg = cfg
        self.port = port
        self.has_setup = False

    def run_trainer(self):
        """Configure the distributed environment from SLURM and run training."""
        job_env = submitit.JobEnvironment()
        # Need to add this again so the hydra.job.set_env PYTHONPATH
        # is also set when launching jobs.
        add_pythonpath_to_sys_path()
        # torch.distributed rendezvous: rank 0's host is the master.
        os.environ["MASTER_ADDR"] = job_env.hostnames[0]
        os.environ["MASTER_PORT"] = str(self.port)
        os.environ["RANK"] = str(job_env.global_rank)
        os.environ["LOCAL_RANK"] = str(job_env.local_rank)
        os.environ["WORLD_SIZE"] = str(job_env.num_tasks)
        register_omegaconf_resolvers()
        # Copy without resolving, then let instantiate resolve on this node.
        cfg_resolved = OmegaConf.to_container(self.cfg, resolve=False)
        cfg_resolved = OmegaConf.create(cfg_resolved)
        trainer = instantiate(cfg_resolved.trainer, _recursive_=False)
        trainer.run()

    def __call__(self):
        job_env = submitit.JobEnvironment()
        self.setup_job_info(job_env.job_id, job_env.global_rank)
        try:
            self.run_trainer()
        except Exception as e:
            # Log the exception. Then raise it again (as what SubmititRunner currently does).
            message = format_exception(e)
            logging.error(message)
            raise e

    def setup_job_info(self, job_id, rank):
        """Set up slurm job info"""
        self.job_info = {
            "job_id": job_id,
            "rank": rank,
            "cluster": self.cfg.get("cluster", None),
            "experiment_log_dir": self.cfg.launcher.experiment_log_dir,
        }
        self.has_setup = True
def add_pythonpath_to_sys_path():
    """Prepend the entries of $PYTHONPATH (when set and non-empty) to sys.path."""
    pythonpath = os.environ.get("PYTHONPATH")
    if not pythonpath:
        return
    sys.path = pythonpath.split(":") + sys.path
def main(args) -> None:
    """Compose the Hydra config, persist it, and launch training.

    Depending on ``submitit.use_cluster`` the job is either submitted to
    SLURM via submitit (optionally as a job array) or run locally on a
    single node. Command-line arguments override the corresponding config
    values.

    Fix: in the job-array branch, the resolved config is now derived from
    the per-task ``curr_cfg`` (which carries ``task_index``) instead of the
    shared ``cfg``, so each saved ``config_resolved.yaml`` reflects its task.
    """
    cfg = compose(config_name=args.config)
    if cfg.launcher.experiment_log_dir is None:
        cfg.launcher.experiment_log_dir = os.path.join(
            os.getcwd(), "sam3_logs", args.config
        )
    print("###################### Train App Config ####################")
    print(OmegaConf.to_yaml(cfg))
    print("############################################################")
    add_pythonpath_to_sys_path()
    makedir(cfg.launcher.experiment_log_dir)
    # Persist both the raw and the resolved config for reproducibility.
    with g_pathmgr.open(
        os.path.join(cfg.launcher.experiment_log_dir, "config.yaml"), "w"
    ) as f:
        f.write(OmegaConf.to_yaml(cfg))
    cfg_resolved = OmegaConf.to_container(cfg, resolve=False)
    cfg_resolved = OmegaConf.create(cfg_resolved)
    with g_pathmgr.open(
        os.path.join(cfg.launcher.experiment_log_dir, "config_resolved.yaml"), "w"
    ) as f:
        f.write(OmegaConf.to_yaml(cfg_resolved, resolve=True))
    submitit_conf = cfg.get("submitit", None)
    assert submitit_conf is not None, "Missing submitit config"
    experiment_log_dir = cfg.launcher.experiment_log_dir
    print(f"Experiment Log Dir:\n{experiment_log_dir}")
    submitit_dir = os.path.join(experiment_log_dir, "submitit_logs")
    # Prioritize cmd line args over config values.
    cfg.launcher.gpus_per_node = (
        args.num_gpus if args.num_gpus is not None else cfg.launcher.gpus_per_node
    )
    cfg.launcher.num_nodes = (
        args.num_nodes if args.num_nodes is not None else cfg.launcher.num_nodes
    )
    submitit_conf.use_cluster = (
        args.use_cluster if args.use_cluster is not None else submitit_conf.use_cluster
    )
    if submitit_conf.use_cluster:
        executor = submitit.AutoExecutor(folder=submitit_dir)
        submitit_conf.partition = (
            args.partition
            if args.partition is not None
            else submitit_conf.get("partition", None)
        )
        submitit_conf.account = (
            args.account
            if args.account is not None
            else submitit_conf.get("account", None)
        )
        submitit_conf.qos = (
            args.qos if args.qos is not None else submitit_conf.get("qos", None)
        )
        job_kwargs = {
            "timeout_min": 60 * submitit_conf.timeout_hour,
            "name": (
                submitit_conf.name if hasattr(submitit_conf, "name") else args.config
            ),
            "slurm_partition": submitit_conf.partition,
            "gpus_per_node": cfg.launcher.gpus_per_node,
            "tasks_per_node": cfg.launcher.gpus_per_node,  # one task per GPU
            "cpus_per_task": submitit_conf.cpus_per_task,
            "nodes": cfg.launcher.num_nodes,
            "slurm_additional_parameters": {
                # NOTE(review): SLURM node lists are usually comma-separated;
                # confirm that a space-joined list is accepted here.
                "exclude": " ".join(submitit_conf.get("exclude_nodes", [])),
            },
        }
        if "include_nodes" in submitit_conf:
            assert (
                len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes
            ), "Not enough nodes"
            job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(
                submitit_conf["include_nodes"]
            )
        if submitit_conf.account is not None:
            job_kwargs["slurm_additional_parameters"]["account"] = submitit_conf.account
        if submitit_conf.qos is not None:
            job_kwargs["slurm_additional_parameters"]["qos"] = submitit_conf.qos
        # Memory can be given either as whole GB or as a raw SLURM string.
        if submitit_conf.get("mem_gb", None) is not None:
            job_kwargs["mem_gb"] = submitit_conf.mem_gb
        elif submitit_conf.get("mem", None) is not None:
            job_kwargs["slurm_mem"] = submitit_conf.mem
        if submitit_conf.get("constraints", None) is not None:
            job_kwargs["slurm_constraint"] = submitit_conf.constraints
        if submitit_conf.get("comment", None) is not None:
            job_kwargs["slurm_comment"] = submitit_conf.comment
        # Supports only cpu-bind option within srun_args. New options can be added here
        if submitit_conf.get("srun_args", None) is not None:
            job_kwargs["slurm_srun_args"] = []
            if submitit_conf.srun_args.get("cpu_bind", None) is not None:
                job_kwargs["slurm_srun_args"].extend(
                    ["--cpu-bind", submitit_conf.srun_args.cpu_bind]
                )
        print("###################### SLURM Config ####################")
        print(job_kwargs)
        print("##########################################")
        executor.update_parameters(**job_kwargs)
        if (
            "job_array" in submitit_conf
            and submitit_conf.job_array.get("num_tasks", -1) > 0
        ):
            num_tasks = submitit_conf.job_array.num_tasks
            job_array_config_dir = os.path.join(
                cfg.launcher.experiment_log_dir, "job_array_configs"
            )
            makedir(job_array_config_dir)
            job_indices = range(num_tasks)
            # One distinct rendezvous port per array task.
            ports = random.sample(
                range(submitit_conf.port_range[0], submitit_conf.port_range[1] + 1),
                k=len(job_indices),
            )
            jobs_runners_configs = []
            with executor.batch():
                task_index = 0
                for _, main_port in tqdm(zip(job_indices, ports)):
                    curr_cfg = deepcopy(cfg)
                    curr_cfg.submitit.job_array["task_index"] = task_index
                    # BUGFIX: resolve the per-task config (carrying
                    # task_index), not the shared base cfg.
                    curr_cfg_resolved = handle_custom_resolving(curr_cfg)
                    runner = SubmititRunner(main_port, curr_cfg)
                    job = executor.submit(runner)
                    jobs_runners_configs.append(
                        (job, runner, curr_cfg, curr_cfg_resolved)
                    )
                    task_index += 1
            for job, runner, job_cfg, job_cfg_resolved in jobs_runners_configs:
                print("Submitit Job ID:", job.job_id)
                # Save job specific config
                job_array_config_file = os.path.join(
                    job_array_config_dir, "{}.config.yaml".format(job.job_id)
                )
                with g_pathmgr.open(job_array_config_file, "w") as f:
                    f.write(OmegaConf.to_yaml(job_cfg))
                job_array_config_resolved_file = os.path.join(
                    job_array_config_dir, "{}.config_resolved.yaml".format(job.job_id)
                )
                with g_pathmgr.open(job_array_config_resolved_file, "w") as f:
                    f.write(OmegaConf.to_yaml(job_cfg_resolved, resolve=True))
                runner.setup_job_info(job.job_id, rank=0)
                # runner.log_event(event_type=SlurmEvent.QUEUED)
        else:
            main_port = random.randint(
                submitit_conf.port_range[0], submitit_conf.port_range[1]
            )
            runner = SubmititRunner(main_port, cfg)
            job = executor.submit(runner)
            print(f"Submitit Job ID: {job.job_id}")
            runner.setup_job_info(job.job_id, rank=0)
    else:
        # Local run: single node, random port from the configured range.
        cfg.launcher.num_nodes = 1
        main_port = random.randint(
            submitit_conf.port_range[0], submitit_conf.port_range[1]
        )
        single_node_runner(cfg, main_port)
if __name__ == "__main__":
    # Register the config module so `compose(config_name=...)` can find
    # configs shipped with the package.
    initialize_config_module("sam3.train", version_base="1.2")
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        required=True,
        type=str,
        help="path to config file (e.g. configs/roboflow_v100_full_ft_100_images.yaml)",
    )
    parser.add_argument(
        "--use-cluster",
        type=int,
        default=None,
        help="whether to launch on a cluster, 0: run locally, 1: run on a cluster",
    )
    parser.add_argument("--partition", type=str, default=None, help="SLURM partition")
    parser.add_argument("--account", type=str, default=None, help="SLURM account")
    parser.add_argument("--qos", type=str, default=None, help="SLURM qos")
    parser.add_argument(
        "--num-gpus", type=int, default=None, help="number of GPUS per node"
    )
    parser.add_argument("--num-nodes", type=int, default=None, help="Number of nodes")
    args = parser.parse_args()
    # Keep None as "not specified" so config values win in main(); otherwise
    # coerce the 0/1 flag to a bool.
    args.use_cluster = bool(args.use_cluster) if args.use_cluster is not None else None
    register_omegaconf_resolvers()
    main(args)

1193
sam3/train/trainer.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

View File

@@ -0,0 +1,455 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
Transforms and data augmentation for both image + bbox.
"""
import math
import random
from typing import Iterable
import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from sam3.model.box_ops import box_xyxy_to_cxcywh
from sam3.model.data_misc import interpolate
def crop(image, target, region):
    """
    Crop `image` to `region` = (top, left, height, width) and remap `target`
    (boxes, masks, area, ...) into the cropped coordinate frame.

    Instances whose box (or mask, when no boxes exist) becomes empty after the
    crop are removed from all per-instance fields.
    """
    cropped_image = F.crop(image, *region)
    target = target.copy()
    i, j, h, w = region
    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])
    # Per-instance fields that must stay aligned with the kept instances.
    fields = ["labels", "area", "iscrowd", "positive_map"]
    if "boxes" in target:
        boxes = target["boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        # Shift boxes into crop coordinates, then clamp corners to the window.
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32)
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")
    if "input_boxes" in target:
        # Same remap for prompt boxes; these are never filtered out below.
        boxes = target["input_boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32)
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        target["input_boxes"] = cropped_boxes.reshape(-1, 4)
    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target["masks"] = target["masks"][:, i : i + h, j : j + w]
        fields.append("masks")
    # remove elements for which the boxes or masks that have zero area
    if "boxes" in target or "masks" in target:
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target["boxes"].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target["masks"].flatten(1).any(1)
        for field in fields:
            if field in target:
                target[field] = target[field][keep]
    return cropped_image, target
def hflip(image, target):
    """Horizontally flip the image and mirror boxes, masks, and left/right text."""
    flipped_image = F.hflip(image)
    w, h = image.size
    target = target.copy()

    def _mirror(boxes):
        # Swap x0/x1 and reflect about the image width; y coordinates unchanged.
        return boxes[:, [2, 1, 0, 3]] * torch.as_tensor(
            [-1, 1, -1, 1]
        ) + torch.as_tensor([w, 0, w, 0])

    for key in ("boxes", "input_boxes"):
        if key in target:
            target[key] = _mirror(target[key])
    if "masks" in target:
        target["masks"] = target["masks"].flip(-1)
    if "text_input" in target:
        # Swap the words "left"/"right" so the text stays consistent with the flip.
        target["text_input"] = (
            target["text_input"]
            .replace("left", "[TMP]")
            .replace("right", "left")
            .replace("[TMP]", "right")
        )
    return flipped_image, target
def resize(image, target, size, max_size=None, square=False):
    """
    Resize `image` and rescale `target` boxes/areas/masks to match.

    `size` is either a scalar short-side length or a (w, h) tuple. With a
    scalar, aspect ratio is preserved and `max_size` caps the longer side.
    When `square` is True, the image is resized to (size, size) regardless.
    """
    # size can be min_size (scalar) or (w, h) tuple
    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        # Returns (h, w) such that the short side equals `size`, shrinking
        # `size` first if the long side would exceed `max_size`.
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            if max_original_size / min_original_size * size > max_size:
                size = int(round(max_size * min_original_size / max_original_size))
        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)
        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)
        return (oh, ow)
    def get_size(image_size, size, max_size=None):
        # A (w, h) tuple is taken literally (converted to (h, w)).
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)
    if square:
        size = size, size
    else:
        size = get_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size)
    if target is None:
        return rescaled_image, None
    # Per-axis scale factors; PIL image sizes are (w, h).
    ratios = tuple(
        float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)
    )
    ratio_width, ratio_height = ratios
    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        scaled_boxes = boxes * torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32
        )
        target["boxes"] = scaled_boxes
    if "input_boxes" in target:
        boxes = target["input_boxes"]
        scaled_boxes = boxes * torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32
        )
        target["input_boxes"] = scaled_boxes
    if "area" in target:
        area = target["area"]
        scaled_area = area * (ratio_width * ratio_height)
        target["area"] = scaled_area
    h, w = size
    target["size"] = torch.tensor([h, w])
    if "masks" in target:
        # Nearest-neighbor resize, then threshold to keep masks binary.
        target["masks"] = (
            interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0]
            > 0.5
        )
    return rescaled_image, target
def pad(image, target, padding):
    """
    Pad `image` and update `target`. `padding` is either (right, bottom) —
    padding only the bottom-right — or (left, top, right, bottom).
    """
    if len(padding) == 2:
        # assumes that we only pad on the bottom right corners
        padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
    else:
        # left, top, right, bottom
        padded_image = F.pad(image, (padding[0], padding[1], padding[2], padding[3]))
    if target is None:
        return padded_image, None
    target = target.copy()
    w, h = padded_image.size
    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])
    # Boxes only shift when padding may occur on the left/top (4-tuple form).
    if "boxes" in target and len(padding) == 4:
        boxes = target["boxes"]
        boxes = boxes + torch.as_tensor(
            [padding[0], padding[1], padding[0], padding[1]], dtype=torch.float32
        )
        target["boxes"] = boxes
    if "input_boxes" in target and len(padding) == 4:
        boxes = target["input_boxes"]
        boxes = boxes + torch.as_tensor(
            [padding[0], padding[1], padding[0], padding[1]], dtype=torch.float32
        )
        target["input_boxes"] = boxes
    if "masks" in target:
        # torch.nn.functional.pad takes (left, right, top, bottom) for the
        # trailing two dimensions.
        if len(padding) == 2:
            target["masks"] = torch.nn.functional.pad(
                target["masks"], (0, padding[0], 0, padding[1])
            )
        else:
            target["masks"] = torch.nn.functional.pad(
                target["masks"], (padding[0], padding[2], padding[1], padding[3])
            )
    return padded_image, target
class RandomCrop:
    """Crop a random fixed-size region; region sampling delegated to torchvision."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        return crop(img, target, T.RandomCrop.get_params(img, self.size))
class RandomSizeCrop:
    """
    Crop a random region whose side lengths lie in [min_size, max_size]
    (clamped to the image).

    With `respect_boxes=True`, the crop window is constrained so that every
    ground-truth box fully survives the crop (asserted after cropping).
    """

    def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
        self.min_size = min_size
        self.max_size = max_size
        self.respect_boxes = respect_boxes  # if True we can't crop a box out

    def __call__(self, img: PIL.Image.Image, target: dict):
        init_boxes = len(target["boxes"])
        init_boxes_tensor = target["boxes"].clone()
        if self.respect_boxes and init_boxes > 0:
            # Smallest/largest legal crop extents, clamped to the image.
            # BUGFIX: minH was previously clamped by img.width; it must be
            # clamped by img.height (symmetric with maxH below), otherwise the
            # sampled crop height could exceed the image height.
            minW, minH, maxW, maxH = (
                min(img.width, self.min_size),
                min(img.height, self.min_size),
                min(img.width, self.max_size),
                min(img.height, self.max_size),
            )
            # Upper bounds on the crop's left/top so no box is cut on the
            # left/top side (10px margin), clamped to the image.
            minX, minY = (
                target["boxes"][:, 0].max().item() + 10.0,
                target["boxes"][:, 1].max().item() + 10.0,
            )
            minX = min(img.width, minX)
            minY = min(img.height, minY)
            maxX, maxY = (
                target["boxes"][:, 2].min().item() - 10,
                target["boxes"][:, 3].min().item() - 10,
            )
            maxX = max(0.0, maxX)
            maxY = max(0.0, maxY)
            # The crop must be wide/tall enough to span all the boxes.
            minW = max(minW, minX - maxX)
            minH = max(minH, minY - maxY)
            w = random.uniform(minW, max(minW, maxW))
            h = random.uniform(minH, max(minH, maxH))
            # i/j are the crop's left/top offsets (x and y respectively).
            if minX > maxX:
                # i = random.uniform(max(0, minX - w + 1), max(maxX, max(0, minX - w + 1)))
                i = random.uniform(max(0, minX - w), max(maxX, max(0, minX - w)))
            else:
                i = random.uniform(
                    max(0, minX - w + 1), max(maxX - 1, max(0, minX - w + 1))
                )
            if minY > maxY:
                # j = random.uniform(max(0, minY - h + 1), max(maxY, max(0, minY - h + 1)))
                j = random.uniform(max(0, minY - h), max(maxY, max(0, minY - h)))
            else:
                j = random.uniform(
                    max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1))
                )
            # crop() expects region = (top, left, height, width).
            result_img, result_target = crop(img, target, [j, i, h, w])
            assert (
                len(result_target["boxes"]) == init_boxes
            ), f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}"
            return result_img, result_target
        else:
            w = random.randint(self.min_size, min(img.width, self.max_size))
            h = random.randint(self.min_size, min(img.height, self.max_size))
            region = T.RandomCrop.get_params(img, (h, w))
            result_img, result_target = crop(img, target, region)
            return result_img, result_target
class CenterCrop:
    """Crop the central (crop_height, crop_width) region of the image."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        img_w, img_h = img.size
        out_h, out_w = self.size
        top = int(round((img_h - out_h) / 2.0))
        left = int(round((img_w - out_w) / 2.0))
        return crop(img, target, (top, left, out_h, out_w))
class RandomHorizontalFlip:
    """Flip image and target horizontally with probability `p`."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        if random.random() >= self.p:
            return img, target
        return hflip(img, target)
class RandomResize:
    """Resize to a size drawn uniformly from `sizes` (see `resize` for semantics)."""

    def __init__(self, sizes, max_size=None, square=False):
        if isinstance(sizes, int):
            sizes = (sizes,)
        assert isinstance(sizes, Iterable)
        self.sizes = list(sizes)
        self.max_size = max_size
        self.square = square

    def __call__(self, img, target=None):
        chosen = random.choice(self.sizes)
        return resize(img, target, chosen, self.max_size, square=self.square)
class RandomPad:
    """Pad the bottom-right by independent random amounts in [0, max_pad]."""

    def __init__(self, max_pad):
        self.max_pad = max_pad

    def __call__(self, img, target):
        padding = (random.randint(0, self.max_pad), random.randint(0, self.max_pad))
        return pad(img, target, padding)
class PadToSize:
    """Pad up to a square of side `size`, randomly splitting left/right and top/bottom."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        w, h = img.size
        total_x = self.size - w
        total_y = self.size - h
        assert total_x >= 0 and total_y >= 0
        left = random.randint(0, total_x)
        top = random.randint(0, total_y)
        return pad(img, target, (left, top, total_x - left, total_y - top))
class Identity:
    """No-op transform: returns its inputs unchanged."""

    def __call__(self, img, target):
        return img, target
class RandomSelect:
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2
    """

    def __init__(self, transforms1=None, transforms2=None, p=0.5):
        self.transforms1 = transforms1 or Identity()
        self.transforms2 = transforms2 or Identity()
        self.p = p

    def __call__(self, img, target):
        chosen = self.transforms1 if random.random() < self.p else self.transforms2
        return chosen(img, target)
class ToTensor:
    """Convert the image to a tensor; the target passes through untouched."""

    def __call__(self, img, target):
        return F.to_tensor(img), target
class RandomErasing:
    """Apply torchvision's RandomErasing to the image; target is untouched."""

    def __init__(self, *args, **kwargs):
        self.eraser = T.RandomErasing(*args, **kwargs)

    def __call__(self, img, target):
        return self.eraser(img), target
class Normalize:
    """Normalize image channels and convert boxes to normalized cxcywh format."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None
        target = target.copy()
        h, w = image.shape[-2:]
        # Convert xyxy -> cxcywh, then scale into [0, 1] by image size.
        scale = torch.tensor([w, h, w, h], dtype=torch.float32)
        for key in ("boxes", "input_boxes"):
            if key in target:
                target[key] = box_xyxy_to_cxcywh(target[key]) / scale
        return image, target
class RemoveDifficult:
    """Optionally drop crowd ("difficult") annotations from the target."""

    def __init__(self, enabled=False):
        self.remove_difficult = enabled

    def __call__(self, image, target=None):
        if target is None:
            return image, None
        target = target.copy()
        # When disabled, the mask is all-True and nothing is removed.
        keep = ~target["iscrowd"].to(torch.bool) | (not self.remove_difficult)
        if "boxes" in target:
            for key in ("boxes", "labels", "iscrowd"):
                target[key] = target[key][keep]
        return image, target
class Compose:
    """Chain several (image, target) transforms, applying them in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target

    def __repr__(self):
        body = "".join("\n" + "    {0}".format(t) for t in self.transforms)
        return self.__class__.__name__ + "(" + body + "\n)"
def get_random_resize_scales(size, min_size, rounded):
    """Candidate scales from (min_size rounded up to a stride multiple) to size."""
    stride = 128 if rounded else 32
    aligned_min = stride * math.ceil(min_size / stride)
    return list(range(int(aligned_min), size + 1, stride))
def get_random_resize_max_size(size, ratio=5 / 3):
    """Largest allowed long side for short side `size`, given the aspect cap."""
    return round(ratio * size)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,607 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import logging
import random
from collections import defaultdict
from typing import List, Optional, Union
import torch
from sam3.train.data.sam3_image_dataset import Datapoint, FindQuery, Object
class FilterDataPointQueries:
    """
    Base class for transforms that mark a subset of a datapoint's queries and
    objects for removal.

    Subclasses implement `identify_queries_to_filter` to populate the id sets
    below; the actual deletion is performed by the caller that consumes them.
    """

    # Indices into datapoint.find_queries to drop.
    find_ids_to_filter: Optional[set] = None
    # Indices of get queries to drop (currently unused).
    get_ids_to_filter: Optional[set] = None
    # Objects to drop, stored as (img_id, obj_id) pairs.
    obj_ids_to_filter: Optional[set] = None

    def identify_queries_to_filter(self, datapoint: "Datapoint") -> None:
        """
        Populate the sets of find-query / object ids to drop for `datapoint`.
        """
        raise NotImplementedError

    def _do_filter_query(self, query: "FindQuery", query_id: int):
        # NOTE: the original annotation was `Union[FindQuery]`, which collapses
        # to just `FindQuery`; `query` itself is unused — only the id matters.
        assert self.find_ids_to_filter is not None
        return query_id in self.find_ids_to_filter
class FilterQueryWithText(FilterDataPointQueries):
    """
    Filter all find queries whose query text is in a specified list of excluded terms.
    """

    def __init__(
        self,
        exclude_find_keys: Optional[List[str]] = None,
        exclude_get_keys: Optional[List[str]] = None,
    ):
        # Fixed: annotations previously claimed `List[str]` while defaulting to None.
        self.find_filter_keys = exclude_find_keys if exclude_find_keys else []
        self.get_filter_keys = exclude_get_keys if exclude_get_keys else []

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()
        # Drop every find query whose exact text appears in the exclusion list.
        # (Removed the unused `del_get_ids` scaffolding from the original.)
        self.find_ids_to_filter = {
            i
            for i, f_q in enumerate(datapoint.find_queries)
            if f_q.query_text in self.find_filter_keys
        }
class KeepMaxNumFindQueries(FilterDataPointQueries):
    """
    Randomly subsample a datapoint's find queries down to at most
    `max_num_find_queries`.

    With `retain_positive_queries=True`, positive queries (those with at least
    one output object) are kept preferentially; negatives only fill remaining
    slots.
    """

    def __init__(
        self, max_num_find_queries: int, retain_positive_queries: bool = False
    ):
        self.max_num_find_queries = max_num_find_queries
        self.retain_positive_queries = retain_positive_queries

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        self.obj_ids_to_filter = set()
        num_find_queries = len(datapoint.find_queries)
        if num_find_queries <= self.max_num_find_queries:
            self.find_ids_to_filter = set()  # keep all find queries
            return
        if not self.retain_positive_queries:
            # Uniform subsampling, regardless of positive/negative status.
            all_find_query_ids = list(range(num_find_queries))
            num_queries_to_filter = max(0, num_find_queries - self.max_num_find_queries)
            query_ids_to_filter = random.sample(
                all_find_query_ids, k=num_queries_to_filter
            )
        else:
            # keep up to max_num_find_queries positive find queries and fill
            # the remaining slots (if any) with negative find queries
            pos_find_ids, neg_find_ids = [], []
            for i, f_q in enumerate(datapoint.find_queries):
                # Negative finds return an empty list of object_ids_output
                if len(f_q.object_ids_output) == 0:
                    neg_find_ids.append(i)
                else:
                    pos_find_ids.append(i)
            if len(pos_find_ids) >= self.max_num_find_queries:
                # we have more positive find queries than `max_num_find_queries`,
                # so we subsample positive find queries and remove all negative find queries
                num_queries_to_filter = len(pos_find_ids) - self.max_num_find_queries
                query_ids_to_filter = random.sample(
                    pos_find_ids, k=num_queries_to_filter
                )
                query_ids_to_filter.extend(neg_find_ids)
            else:
                # we have fewer positive find queries than `max_num_find_queries`
                # so we need to fill the remaining with negative find queries
                num_queries_to_filter = num_find_queries - self.max_num_find_queries
                query_ids_to_filter = random.sample(
                    neg_find_ids, k=num_queries_to_filter
                )
        # Exactly the surplus must have been selected for removal.
        assert len(query_ids_to_filter) == num_find_queries - self.max_num_find_queries
        self.find_ids_to_filter = set(query_ids_to_filter)
class KeepMaxNumFindQueriesVideo(FilterDataPointQueries):
    """
    Per-frame variant of KeepMaxNumFindQueries for video mosaics: subsample
    each frame's find queries down to at most
    `video_mosaic_max_num_find_queries_per_frame`.

    NOTE(review): the positions to drop are sampled once on frame 0 and reused
    for every frame, so this assumes frames are keyed 0..num_frames-1 and all
    carry the same number of aligned queries (enforced only by the final
    assert) — confirm against the mosaic builder.
    """

    def __init__(
        self,
        video_mosaic_max_num_find_queries_per_frame: int,
        retain_positive_queries: bool = False,
    ):
        self.video_mosaic_max_num_find_queries_per_frame = (
            video_mosaic_max_num_find_queries_per_frame
        )
        self.retain_positive_queries = retain_positive_queries

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        self.obj_ids_to_filter = set()
        num_find_queries = len(datapoint.find_queries)
        # Map image (frame) id -> list of find-query indices on that frame.
        findQueries_to_imageIds = defaultdict(list)
        # True while every frame is within the per-frame budget.
        max_queries_per_frame = True
        for i, f_q in enumerate(datapoint.find_queries):
            findQueries_to_imageIds[f_q.image_id].append(i)
            if (
                len(findQueries_to_imageIds[f_q.image_id])
                > self.video_mosaic_max_num_find_queries_per_frame
            ):
                max_queries_per_frame = False
        if max_queries_per_frame:
            # All frames already within budget: nothing to drop.
            self.find_ids_to_filter = set()
            return
        num_frames = len(findQueries_to_imageIds)
        # Sampling is done on frame 0's queries; the chosen positions are then
        # applied to every frame below.
        findQueries_0 = findQueries_to_imageIds[0]
        num_find_queries_0 = len(findQueries_0)
        max_num_find_queries_per_frame = (
            self.video_mosaic_max_num_find_queries_per_frame
        )
        if not self.retain_positive_queries:
            # Uniform subsampling of frame-0 positions.
            find_query_ids_0 = list(range(num_find_queries_0))
            num_queries_to_filter = max(
                0, num_find_queries_0 - max_num_find_queries_per_frame
            )
            query_ids_to_filter_0 = random.sample(
                find_query_ids_0, k=num_queries_to_filter
            )
        else:
            # keep up to max_num_find_queries positive find queries and fill
            # the remaining slots (if any) with negative find queries
            pos_find_ids_0, neg_find_ids_0 = [], []
            for i, f_q_id in enumerate(findQueries_0):
                f_q = datapoint.find_queries[f_q_id]
                # Negative finds return an empty list of object_ids_output
                if len(f_q.object_ids_output) == 0:
                    neg_find_ids_0.append(i)
                else:
                    pos_find_ids_0.append(i)
            if len(pos_find_ids_0) >= max_num_find_queries_per_frame:
                # we have more positive find queries than `max_num_find_queries`,
                # so we subsample positive find queries and remove all negative find queries
                num_queries_to_filter = (
                    len(pos_find_ids_0) - max_num_find_queries_per_frame
                )
                query_ids_to_filter_0 = random.sample(
                    pos_find_ids_0, k=num_queries_to_filter
                )
                query_ids_to_filter_0.extend(neg_find_ids_0)
            else:
                # we have fewer positive find queries than `max_num_find_queries`
                # so we need to fill the remaining with negative find queries
                num_queries_to_filter = (
                    num_find_queries_0 - max_num_find_queries_per_frame
                )
                query_ids_to_filter_0 = random.sample(
                    neg_find_ids_0, k=num_queries_to_filter
                )
        # get based on frame 0 all find queries from all the frames with the same indices as in frame 0
        query_ids_to_filter = []
        for i in range(num_frames):
            findQueries_i = findQueries_to_imageIds[i]
            query_ids_to_filter.extend(
                [findQueries_i[j] for j in query_ids_to_filter_0]
            )
        # Sanity check: the per-frame budget holds globally after dropping.
        assert (
            len(query_ids_to_filter)
            == num_find_queries
            - self.video_mosaic_max_num_find_queries_per_frame * num_frames
        )
        self.find_ids_to_filter = set(query_ids_to_filter)
class KeepSemanticFindQueriesOnly(FilterDataPointQueries):
    """Drop geometric find queries (those prompted with an input box)."""

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        self.obj_ids_to_filter = set()
        geometric_ids = set()
        for idx, query in enumerate(datapoint.find_queries):
            # Geometric queries are identified by a non-None input box.
            if query.input_bbox is not None:
                geometric_ids.add(idx)
        self.find_ids_to_filter = geometric_ids
        # Keep all get queries which don't depend on filtered finds
class KeepUnaryFindQueriesOnly(FilterDataPointQueries):
    """No-op filter: every find query and object is kept."""

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        # Nothing is dropped: both sets stay empty.
        self.obj_ids_to_filter = set()
        self.find_ids_to_filter = set()
        # Keep all get queries which don't depend on filtered finds
class FilterZeroBoxQueries(FilterDataPointQueries):
    """
    Filters all find queries which predict a box with zero area
    """

    @staticmethod
    def _is_zero_area_object(obj: Object):
        # Degenerate when either side of the XYXY box collapses to zero.
        bbox = obj.bbox  # Assume in XYXY format
        height = bbox[..., 3].item() - bbox[..., 1].item()
        width = bbox[..., 2].item() - bbox[..., 0].item()
        return height == 0 or width == 0

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()
        # Assume only one image per datapoint.
        degenerate = {
            idx
            for idx, obj in enumerate(datapoint.images[0].objects)
            if self._is_zero_area_object(obj)
        }
        # Any query whose output touches a degenerate object is dropped whole.
        self.find_ids_to_filter = {
            qid
            for qid, query in enumerate(datapoint.find_queries)
            if degenerate & set(query.object_ids_output)
        }
class FilterFindQueriesWithTooManyOut(FilterDataPointQueries):
    """
    Filters all find queries which have more than a specified number of objects in the output
    """

    def __init__(self, max_num_objects: int):
        self.max_num_objects = max_num_objects

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()
        limit = self.max_num_objects
        # Drop the whole find query when its output exceeds the limit.
        self.find_ids_to_filter = {
            qid
            for qid, query in enumerate(datapoint.find_queries)
            if len(query.object_ids_output) > limit
        }
class FilterEmptyTargets(FilterDataPointQueries):
    """
    Filters all targets which have zero area
    """

    def identify_queries_to_filter(self, datapoint):
        self.find_ids_to_filter = set()
        # Mark every (image, object) whose area is effectively zero.
        self.obj_ids_to_filter = {
            (img_id, obj_id)
            for img_id, image in enumerate(datapoint.images)
            for obj_id, obj in enumerate(image.objects)
            if obj.area < 1e-6
        }
class FilterNonExhaustiveFindQueries(FilterDataPointQueries):
    """
    Filters all find queries which are non-exhaustive
    """

    def __init__(self, exhaustivity_type: str):
        """
        Args:
            exhaustivity_type: Can be "pixel" or "instance":
                -pixel: filter queries where the union of all segments covers every pixel belonging to target class
                -instance: filter queries where there are non-separable or non annotated instances
                Note that instance exhaustivity implies pixel exhaustivity
        """
        assert exhaustivity_type in ["pixel", "instance"]
        self.exhaustivity_type = exhaustivity_type

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()
        drop_ids = set()
        for qid, query in enumerate(datapoint.find_queries):
            if self.exhaustivity_type == "instance":
                if not query.is_exhaustive:
                    drop_ids.add(qid)
            elif self.exhaustivity_type == "pixel":
                # Pixel exhaustivity may be unannotated (None): keep in that case.
                if query.is_pixel_exhaustive is not None and not query.is_pixel_exhaustive:
                    drop_ids.add(qid)
            else:
                # Unreachable given the __init__ assert; kept defensively.
                raise RuntimeError(
                    f"Unknown exhaustivity type {self.exhaustivity_type}"
                )
        self.find_ids_to_filter = drop_ids
class FilterInvalidGeometricQueries(FilterDataPointQueries):
    """
    Filters geometric queries whose output got deleted (eg due to cropping)
    """

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()
        # A geometric query with no surviving output objects is no longer valid.
        self.find_ids_to_filter = {
            qid
            for qid, query in enumerate(datapoint.find_queries)
            if query.input_bbox is not None
            and query.query_text == "geometric"
            and len(query.object_ids_output) == 0
        }
class FlexibleFilterFindGetQueries:
    """
    Apply a `FilterDataPointQueries` policy to a datapoint: remove the marked
    find queries and objects, then compact and remap all remaining indices
    (query processing orders, object ids, image ids) so they stay contiguous.

    NOTE(review): the get-query machinery (del_get_ids, new_get_queries,
    get_old_to_new_map, get_counter, object_old_to_new_map_per_query,
    affected_find_queries_ids) is populated but never consumed here — it looks
    vestigial; confirm before relying on it.
    """

    def __init__(
        self, query_filter: FilterDataPointQueries, enabled: bool = True
    ) -> None:
        self.query_filter = query_filter
        self.enabled = enabled

    def __call__(self, datapoint, **kwargs):
        if not self.enabled:
            return datapoint
        # Identify all queries to filter
        self.query_filter.identify_queries_to_filter(datapoint=datapoint)
        del_find_ids = []
        del_get_ids = []
        # Null out the filtered find queries in place first.
        for i, f_q in enumerate(datapoint.find_queries):
            if self.query_filter._do_filter_query(f_q, i):
                datapoint.find_queries[i] = None
                del_find_ids.append(i)
        # Compact the surviving queries, remembering old -> new indices.
        new_find_queries = []
        new_get_queries = []
        find_old_to_new_map = {}
        get_old_to_new_map = {}
        find_counter = 0
        get_counter = 0
        for i, f_q in enumerate(datapoint.find_queries):
            if f_q is not None:
                find_old_to_new_map[i] = find_counter
                find_counter += 1
                new_find_queries.append(f_q)
        # At least one surviving query must be a first-stage query (order 0).
        start_with_zero_check = False
        for n_f_q in new_find_queries:
            if n_f_q.query_processing_order == 0:
                start_with_zero_check = True
                break
        if len(new_find_queries) == 0:
            start_with_zero_check = True
        assert (
            start_with_zero_check
        ), "Invalid Find queries, they need to start at query_processing_order = 0"
        datapoint.find_queries = new_find_queries
        if len(datapoint.find_queries) == 0:
            # A datapoint without find queries cannot be trained on.
            print("Warning: No find queries left in datapoint, this is not allowed")
            print("Filtering function:", self.query_filter)
            print("Datapoint:", datapoint)
            raise ValueError
        # The deletion may have removed intermediate steps, so we need to remap to make them contiguous again
        all_stages = sorted(
            list(set(q.query_processing_order for q in datapoint.find_queries))
        )
        stage_map = {qpo: i for i, qpo in enumerate(all_stages)}
        for i in range(len(datapoint.find_queries)):
            qpo = datapoint.find_queries[i].query_processing_order
            datapoint.find_queries[i].query_processing_order = stage_map[qpo]
        # Final step, clear up objects that are not used anymore
        for img_id in range(len(datapoint.images)):
            # Objects still referenced by some surviving query on this image.
            all_objects_ids = set(
                i
                for find in datapoint.find_queries
                for i in find.object_ids_output
                if find.image_id == img_id
            )
            unused_ids = (
                set(range(len(datapoint.images[img_id].objects))) - all_objects_ids
            )
            # Objects the filter explicitly marked for removal on this image.
            for tgt_img_id, tgt_obj_id in self.query_filter.obj_ids_to_filter:
                if tgt_img_id == img_id:
                    unused_ids.add(tgt_obj_id)
            if len(unused_ids) > 0:
                # Compact the image's object list and remap surviving ids.
                old_objects = datapoint.images[img_id].objects
                object_old_to_new_map = {}
                new_objects = []
                for i, o in enumerate(old_objects):
                    if i not in unused_ids:
                        object_old_to_new_map[i] = len(new_objects)
                        new_objects.append(o)
                datapoint.images[img_id].objects = new_objects
                # Remap the outputs of the find queries
                affected_find_queries_ids = set()
                object_old_to_new_map_per_query = {}
                for fid, find in enumerate(datapoint.find_queries):
                    if find.image_id == img_id:
                        old_object_ids_output = find.object_ids_output
                        object_old_to_new_map_per_query[fid] = {}
                        find.object_ids_output = []
                        for oid, old_obj_id in enumerate(old_object_ids_output):
                            if old_obj_id not in unused_ids:
                                new_obj_id = object_old_to_new_map[old_obj_id]
                                find.object_ids_output.append(new_obj_id)
                                object_old_to_new_map_per_query[fid][oid] = (
                                    len(find.object_ids_output) - 1
                                )
                        affected_find_queries_ids.add(fid)
        # finally remove unused images
        all_imgs_to_keep = set()
        for f_q in datapoint.find_queries:
            all_imgs_to_keep.add(f_q.image_id)
        old_img_id_to_new_img_id = {}
        new_images = []
        for img_id, img in enumerate(datapoint.images):
            if img_id in all_imgs_to_keep:
                old_img_id_to_new_img_id[img_id] = len(new_images)
                new_images.append(img)
        datapoint.images = new_images
        for f_q in datapoint.find_queries:
            f_q.image_id = old_img_id_to_new_img_id[f_q.image_id]
        return datapoint
class AddPrefixSuffixToFindText:
    """
    Add prefix or suffix strings to find query text on the fly.
    If `condition_on_text` is True, the prefix or suffix strings are only added
    to those find query text in `condition_text_list` (case-insensitive).
    """

    def __init__(
        self,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
        condition_on_text: bool = False,
        condition_text_list: Optional[List[str]] = None,
        enabled: bool = True,
    ) -> None:
        self.prefix = prefix
        self.suffix = suffix
        self.condition_on_text = condition_on_text
        if self.condition_on_text:
            assert condition_text_list is not None
            # Normalized (lowercase, stripped) texts eligible for decoration.
            self.condition_text_set = {s.lower().strip() for s in condition_text_list}
        self.enabled = enabled
        if self.enabled:
            logging.info(
                f"AddPrefixSuffixToFindText: prefix={prefix}, suffix={suffix}, "
                f"condition_on_text={condition_on_text}, condition_text_list={condition_text_list}"
            )

    def __call__(self, datapoint, **kwargs):
        if not self.enabled:
            return datapoint
        for query in datapoint.find_queries:
            # Never decorate geometric queries.
            if query.query_text == "geometric":
                continue
            # When conditioned, only decorate texts from the allowed set.
            if self.condition_on_text and (
                query.query_text.lower().strip() not in self.condition_text_set
            ):
                continue
            if self.prefix is not None:
                query.query_text = self.prefix + query.query_text
            if self.suffix is not None:
                query.query_text = query.query_text + self.suffix
        return datapoint
class FilterCrowds(FilterDataPointQueries):
    """Mark every crowd-labelled object for removal; find queries are untouched."""

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        """
        Compute set of query ids to keep, for both find and get queries
        """
        self.find_ids_to_filter = set()
        self.obj_ids_to_filter = {
            (img_id, obj_id)
            for img_id, img in enumerate(datapoint.images)
            for obj_id, obj in enumerate(img.objects)
            if obj.is_crowd
        }
class TextQueryToVisual:
    """
    Transform a text query to a visual query (with some proba), using any of the output targets as the prompt
    """

    def __init__(self, probability, keep_text_queries=False) -> None:
        # probability: chance of converting each eligible text query.
        # keep_text_queries: if True, the original text is kept alongside the box prompt.
        self.probability = probability
        assert 0 <= probability <= 1
        self.keep_text_queries = keep_text_queries

    def __call__(self, datapoint: Datapoint, **kwargs):
        for find in datapoint.find_queries:
            if find.input_bbox is not None or find.input_points is not None:
                # skip geometric find queries
                continue
            if len(find.object_ids_output) == 0:
                # Can't create a visual query, skip
                continue
            if find.query_processing_order > 0:
                # Second stage query, can't use
                continue
            # NOTE: random.random()/random.choice consumption order matters for
            # reproducibility; do not reorder these calls.
            if random.random() > self.probability:
                continue
            # Use one of the query's own target boxes as the visual prompt.
            selected_vq_id = random.choice(find.object_ids_output)
            img_id = find.image_id
            find.input_bbox = datapoint.images[img_id].objects[selected_vq_id].bbox
            find.input_bbox_label = torch.ones(1, dtype=torch.bool)
            if not self.keep_text_queries:
                find.query_text = "visual"
        return datapoint
class RemoveInputBoxes:
    """
    Remove input (prompt) boxes from all find queries.

    Geometric queries rely on their input box, so stripping one is flagged
    with a warning. (The original no-op `__init__` was removed; the default
    constructor is equivalent.)
    """

    def __call__(self, datapoint: "Datapoint", **kwargs):
        for find in datapoint.find_queries:
            if find.input_bbox is None:
                continue
            if find.query_text == "geometric":
                # A geometric query without its box is left promptless.
                print("Warning: removing input box from geometric find query")
            find.input_bbox = None
        return datapoint
class OverwriteTextQuery:
    """
    With some probability, overwrite the text query with a custom text
    """

    def __init__(self, target_text, probability=1.0) -> None:
        assert 0 <= probability <= 1
        self.probability = probability
        self.target_text = target_text

    def __call__(self, datapoint: Datapoint, **kwargs):
        for find in datapoint.find_queries:
            # One RNG draw per query; overwrite when it lands within `probability`.
            if random.random() <= self.probability:
                find.query_text = self.target_text
        return datapoint

View File

@@ -0,0 +1,345 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import cv2
import numpy as np
import torch
from PIL import Image as PILImage
from pycocotools import mask as mask_util
from sam3.train.data.sam3_image_dataset import Datapoint
from torchvision.ops import masks_to_boxes
def sample_points_from_rle(rle, n_points, mode, box=None, normalize=True):
    """
    Sample random points from a mask provided in COCO RLE format.
    'mode' is in ["centered", "random_mask", "random_box"]
      "centered": points are sampled farthest from the mask edges and each other
      "random_mask": points are sampled uniformly from the mask
      "random_box": points are sampled uniformly from the annotation's box
    'box' must be provided if 'mode' is "random_box".
    If 'normalize' is true, points are in [0,1], relative to mask h,w.
    """
    mask = np.ascontiguousarray(mask_util.decode(rle))
    points = sample_points_from_mask(mask, n_points, mode, box)
    if normalize:
        h, w = mask.shape
        # Normalize x by width and y by height; the label column is untouched.
        norm = np.array([w, h, 1.0])[None, :]
        points = points / norm
    return points
def sample_points_from_mask(mask, n_points, mode, box=None):
    """Dispatch point sampling to the strategy named by `mode`."""
    if mode == "centered":
        return center_positive_sample(mask, n_points)
    if mode == "random_mask":
        return uniform_positive_sample(mask, n_points)
    if mode == "random_box":
        assert box is not None, "'random_box' mode requires a provided box."
        return uniform_sample_from_box(mask, box, n_points)
    raise ValueError(f"Unknown point sampling mode {mode}.")
def uniform_positive_sample(mask, n_points):
    """
    Samples positive points uniformly from the mask. Only integer pixel
    values are sampled.
    """
    # Sampling directly from the uncompressed RLE would be faster but is
    # likely unnecessary.
    rc_coords = np.stack(np.nonzero(mask), axis=0).transpose(1, 0)
    assert len(rc_coords) > 0, "Can't sample positive points from an empty mask."
    chosen = rc_coords[np.random.randint(low=0, high=len(rc_coords), size=n_points)]
    xy = chosen[:, ::-1]  # stored (row, col) = (y, x); emit (x, y)
    ones = np.ones((len(xy), 1))
    return np.concatenate([xy, ones], axis=1)
def center_positive_sample(mask, n_points):
    """
    Samples points farthest from mask edges (by distance transform)
    and subsequent points also farthest from each other. Each new point
    sampled is treated as an edge for future points. Edges of the image are
    treated as edges of the mask.

    Returns:
        An (n_points, 3) array of (x, y, 1) rows in pixel coordinates.
    """
    # Pad mask by one pixel on each end to assure distance transform
    # avoids edges
    padded_mask = np.pad(mask, 1)
    points = []
    for _ in range(n_points):
        # Check the *working* mask, not the input: previously selected points
        # are zeroed out below, so the foreground can become empty mid-loop.
        # (The original checked `mask`, which is never modified, so after the
        # first iteration exhaustion went undetected and argmax of an all-zero
        # distance map produced the degenerate point (-1, -1).)
        assert (
            np.max(padded_mask) > 0
        ), "Can't sample positive points from an empty mask."
        dist = cv2.distanceTransform(padded_mask, cv2.DIST_L2, 0)
        point = np.unravel_index(dist.argmax(), dist.shape)
        # Mark selected point as background so next point avoids it
        padded_mask[point[0], point[1]] = 0
        points.append(point[::-1])  # (y, x) -> (x, y)
    points = np.stack(points, axis=0)
    points = points - 1  # Subtract left/top padding of 1
    labels = np.ones((len(points), 1))
    points = np.concatenate([points, labels], axis=1)
    return points
def uniform_sample_from_box(mask, box, n_points):
    """
    Draw `n_points` integer-valued points uniformly from `box` (unnormalized
    XYXY). Each point's label is read off `mask` at its location, so a
    positive point is not guaranteed. Returns an (n_points, 3) array of
    (x, y, label) rows.
    """
    # Since lower/right edges are exclusive, ceil can be applied to all edges
    x0, y0, x1, y1 = np.ceil(box)
    xs = np.random.randint(low=x0, high=x1, size=n_points)
    ys = np.random.randint(low=y0, high=y1, size=n_points)
    return np.stack([xs, ys, mask[ys, xs]], axis=1)
def rescale_box_xyxy(box, factor, imsize=None):
    """
    Scale `box` (unnormalized XYXY) by `factor` about its center. If
    `imsize` = (h, w) is given, clamp the result to the image bounds.
    Returns a new [x0, y0, x1, y1] list; `box` is not modified.
    """
    x0, y0, x1, y1 = box
    cx = (x0 + x1) / 2
    cy = (y0 + y1) / 2
    half_w = factor * (x1 - x0) / 2
    half_h = factor * (y1 - y0) / 2
    scaled = [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
    if imsize is not None:
        # Clamp x-coordinates to [0, w] and y-coordinates to [0, h].
        limits = [imsize[1], imsize[0], imsize[1], imsize[0]]
        scaled = [min(max(v, 0), lim) for v, lim in zip(scaled, limits)]
    return scaled
def noise_box(box, im_size, box_noise_std, box_noise_max, min_box_area):
    """
    Jitter an XYXY `box` with Gaussian noise scaled by the box side lengths.

    Args:
        box: XYXY tensor in absolute pixels.
        im_size: (h, w); the noised box is clamped to the image.
        box_noise_std: relative noise std; if <= 0, `box` is returned as-is.
        box_noise_max: if not None, per-coordinate noise is clamped to this
            many absolute pixels.
        min_box_area: if the noised box's area is <= this, the original box
            is returned instead.
    """
    if box_noise_std <= 0.0:
        return box
    jitter = box_noise_std * torch.randn(size=(4,))
    width, height = box[2] - box[0], box[3] - box[1]
    # Noise is relative to the box side lengths.
    jitter = jitter * torch.tensor([width, height, width, height])
    if box_noise_max is not None:
        jitter = torch.clamp(jitter, -box_noise_max, box_noise_max)
    noised = box + jitter
    # Clamp the box inside the image: x in [0, w], y in [0, h].
    h, w = im_size
    bounds = torch.tensor([w, h, w, h])
    noised = torch.maximum(noised, torch.zeros_like(noised))
    noised = torch.minimum(noised, bounds)
    # Reject degenerate boxes and fall back to the original.
    if (noised[2] - noised[0]) * (noised[3] - noised[1]) <= min_box_area:
        return box
    return noised
class RandomGeometricInputsAPI:
    """
    For geometric queries, replaces the input box or points with a random
    one sampled from the GT mask. Segments must be provided for objects
    that are targets of geometric queries, and must be binary masks. Existing
    point and box queries in the datapoint will be ignored and completely replaced.
    Will sample points and boxes in XYXY format in absolute pixel space.
    Geometry queries are currently determined by taking any query whose
    query text is a set value.
    Args:
        num_points (int or (int, int)): how many points to sample. If a tuple,
        sample a random number of points uniformly over the inclusive range.
        box_chance (float): fraction of time a box is sampled. A box will replace
        one sampled point.
        box_noise_std (float): if greater than 0, add noise to the sampled boxes
        with this std. Noise is relative to the length of the box side.
        box_noise_max (int): if not none, truncate any box noise larger than this
        in terms of absolute pixels.
        resample_box_from_mask (bool): if True, any sampled box will be determined
        by finding the extrema of the provided mask. If False, the bbox provided
        in the target object will be used.
        point_sample_mode (str): In ["centered", "random_mask", "random_box"],
        controlling how points are sampled:
            "centered": points are sampled farthest from the mask edges and each other
            "random_mask": points are sampled uniformly from the mask
            "random_box": points are sampled uniformly from the annotation's box
        Note that "centered" may be too slow for on-line generation.
        geometric_query_str (str): what string in query_text indicates a
        geometry query.
        minimum_box_area (float): sampled boxes with area this size or smaller after
        noising will use the original box instead. It is the input's responsibility
        to avoid original boxes that violate necessary area bounds.
        concat_points (bool): if True, any sampled points will be added to existing
        ones instead of replacing them.
    """
    def __init__(
        self,
        num_points,
        box_chance,
        box_noise_std=0.0,
        box_noise_max=None,
        minimum_box_area=0.0,
        resample_box_from_mask=False,
        point_sample_mode="random_mask",
        sample_box_scale_factor=1.0,
        geometric_query_str="geometric",
        concat_points=False,
    ):
        self.num_points = num_points
        if not isinstance(self.num_points, int):
            # Convert from the inclusive (low, high) range of the API to the
            # exclusive upper bound expected by torch.randint. Build a fresh
            # tuple rather than mutating in place: the original in-place
            # `self.num_points[1] += 1` raised TypeError for tuple inputs
            # (despite the documented "(int, int)" signature) and silently
            # mutated a caller-owned list otherwise.
            self.num_points = (self.num_points[0], self.num_points[1] + 1)
        self.box_chance = box_chance
        self.box_noise_std = box_noise_std
        self.box_noise_max = box_noise_max
        self.minimum_box_area = minimum_box_area
        self.resample_box_from_mask = resample_box_from_mask
        self.point_sample_mode = point_sample_mode
        assert point_sample_mode in [
            "centered",
            "random_mask",
            "random_box",
        ], "Unknown point sample mode."
        self.geometric_query_str = geometric_query_str
        self.concat_points = concat_points
        self.sample_box_scale_factor = sample_box_scale_factor
    def _sample_num_points_and_if_box(self):
        """Draw the number of points for this query and whether to use a box."""
        if isinstance(self.num_points, tuple):
            n_points = torch.randint(
                low=self.num_points[0], high=self.num_points[1], size=(1,)
            ).item()
        else:
            n_points = self.num_points
        if self.box_chance > 0.0:
            use_box = torch.rand(size=(1,)).item() < self.box_chance
            n_points -= int(use_box)  # box stands in for one point
        else:
            use_box = False
        return n_points, use_box
    def _get_original_box(self, target_object):
        """Return the object's box, recomputed from its mask if configured."""
        if not self.resample_box_from_mask:
            return target_object.bbox
        mask = target_object.segment
        return masks_to_boxes(mask[None, :, :])[0]
    def _get_target_object(self, datapoint, query):
        """Return the single object targeted by a geometric query."""
        img = datapoint.images[query.image_id]
        targets = query.object_ids_output
        assert (
            len(targets) == 1
        ), "Geometric queries only support a single target object."
        target_idx = targets[0]
        return img.objects[target_idx]
    def __call__(self, datapoint, **kwargs):
        for query in datapoint.find_queries:
            # Only rewrite queries flagged as geometric via their query text.
            if query.query_text != self.geometric_query_str:
                continue
            target_object = self._get_target_object(datapoint, query)
            n_points, use_box = self._sample_num_points_and_if_box()
            box = self._get_original_box(target_object)
            mask = target_object.segment
            if n_points > 0:
                # FIXME: The conversion to numpy and back to reuse code
                # is awkward, but this is all in the dataloader worker anyway
                # on CPU and so I don't think it should matter.
                if self.sample_box_scale_factor != 1.0:
                    sample_box = rescale_box_xyxy(
                        box.numpy(), self.sample_box_scale_factor, mask.shape
                    )
                else:
                    sample_box = box.numpy()
                input_points = sample_points_from_mask(
                    mask.numpy(),
                    n_points,
                    self.point_sample_mode,
                    sample_box,
                )
                input_points = torch.as_tensor(input_points)
                input_points = input_points[None, :, :]
                if self.concat_points and query.input_points is not None:
                    input_points = torch.cat([query.input_points, input_points], dim=1)
            else:
                input_points = query.input_points if self.concat_points else None
            if use_box:
                w, h = datapoint.images[query.image_id].size
                input_box = noise_box(
                    box,
                    (h, w),
                    box_noise_std=self.box_noise_std,
                    box_noise_max=self.box_noise_max,
                    min_box_area=self.minimum_box_area,
                )
                input_box = input_box[None, :]
            else:
                input_box = query.input_bbox if self.concat_points else None
            query.input_points = input_points
            query.input_bbox = input_box
        return datapoint
class RandomizeInputBbox:
    """
    Simplified version of the geometric transform that only deals with input boxes

    Applies `noise_box` independently to every row of each query's
    `input_bbox` (in-place), clamping the result to the image size.

    Args:
        box_noise_std: noise std relative to the box side lengths; <= 0
            leaves boxes unchanged (see `noise_box`).
        box_noise_max: optional absolute-pixel cap on the noise.
        minimum_box_area: boxes whose noised area is <= this fall back to
            the original box.
    """
    def __init__(
        self,
        box_noise_std=0.0,
        box_noise_max=None,
        minimum_box_area=0.0,
    ):
        self.box_noise_std = box_noise_std
        self.box_noise_max = box_noise_max
        self.minimum_box_area = minimum_box_area
    def __call__(self, datapoint: Datapoint, **kwargs):
        for query in datapoint.find_queries:
            # Only queries that already carry input boxes are noised.
            if query.input_bbox is None:
                continue
            img = datapoint.images[query.image_id].data
            if isinstance(img, PILImage.Image):
                # PIL's Image.size is (width, height).
                w, h = img.size
            else:
                assert isinstance(img, torch.Tensor)
                # Tensor layout is (..., H, W).
                h, w = img.shape[-2:]
            # Noise each box row independently, writing back in place.
            for box_id in range(query.input_bbox.shape[0]):
                query.input_bbox[box_id, :] = noise_box(
                    query.input_bbox[box_id, :].view(4),
                    (h, w),
                    box_noise_std=self.box_noise_std,
                    box_noise_max=self.box_noise_max,
                    min_box_area=self.minimum_box_area,
                ).view(1, 4)
        return datapoint

View File

@@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import numpy as np
import pycocotools.mask as mask_utils
import torch
import torchvision.transforms.functional as F
from PIL import Image as PILImage
from sam3.model.box_ops import masks_to_boxes
from sam3.train.data.sam3_image_dataset import Datapoint
class InstanceToSemantic(object):
    """Convert instance segmentation to semantic segmentation.

    For each find-query, merges the segments of its output objects into a
    single binary `semantic_target` (either as a COCO RLE via
    `mask_utils.merge` when `use_rle`, or as a uint8 tensor OR-ed together).
    Optionally deletes the per-object instance segments afterwards.
    """
    def __init__(self, delete_instance=True, use_rle=False):
        # delete_instance: drop per-object segments after building the
        #   semantic target (saves memory downstream).
        # use_rle: keep the semantic target in COCO RLE form instead of a
        #   dense uint8 tensor.
        self.delete_instance = delete_instance
        self.use_rle = use_rle
    def __call__(self, datapoint: Datapoint, **kwargs):
        for fquery in datapoint.find_queries:
            # NOTE(review): assumes the image wrapper's `.size` is (h, w) —
            # PIL's Image.size is (w, h), so verify against the Datapoint
            # image type's definition.
            h, w = datapoint.images[fquery.image_id].size
            if self.use_rle:
                all_segs = [
                    datapoint.images[fquery.image_id].objects[obj_id].segment
                    for obj_id in fquery.object_ids_output
                ]
                if len(all_segs) > 0:
                    # we need to double check that all rles are the correct size
                    # Otherwise cocotools will fail silently to an empty [0,0] mask
                    for seg in all_segs:
                        assert seg["size"] == all_segs[0]["size"], (
                            "Instance segments have inconsistent sizes. "
                            f"Found sizes {seg['size']} and {all_segs[0]['size']}"
                        )
                    fquery.semantic_target = mask_utils.merge(all_segs)
                else:
                    # There is no good way to create an empty RLE of the correct size
                    # We resort to converting an empty box to RLE
                    fquery.semantic_target = mask_utils.frPyObjects(
                        np.array([[0, 0, 0, 0]], dtype=np.float64), h, w
                    )[0]
            else:
                # `semantic_target` is uint8 and remains uint8 throughout the transforms
                # (it contains binary 0 and 1 values just like `segment` for each object)
                fquery.semantic_target = torch.zeros((h, w), dtype=torch.uint8)
                for obj_id in fquery.object_ids_output:
                    segment = datapoint.images[fquery.image_id].objects[obj_id].segment
                    if segment is not None:
                        assert (
                            isinstance(segment, torch.Tensor)
                            and segment.dtype == torch.uint8
                        )
                        # OR the binary instance mask into the semantic mask.
                        fquery.semantic_target |= segment
        if self.delete_instance:
            # Drop instance segments once the semantic targets are built.
            for img in datapoint.images:
                for obj in img.objects:
                    del obj.segment
                    obj.segment = None
        return datapoint
class RecomputeBoxesFromMasks:
    """Recompute every object's bounding box and area from its binary mask."""
    def __call__(self, datapoint: Datapoint, **kwargs):
        all_objects = (
            obj for image in datapoint.images for obj in image.objects
        )
        for obj in all_objects:
            # Note: if the mask is empty, the bounding box will be undefined.
            # The empty targets should be subsequently filtered.
            obj.bbox = masks_to_boxes(obj.segment)
            obj.area = obj.segment.sum().item()
        return datapoint
class DecodeRle:
    """This transform decodes RLEs into binary segments.
    Implementing it as a transforms allows lazy loading. Some transforms (eg query filters)
    may be deleting masks, so decoding them from the beginning is wasteful.
    This transforms needs to be called before any kind of geometric manipulation
    """
    def __call__(self, datapoint: Datapoint, **kwargs):
        imgId2size = {}
        warning_shown = False
        for imgId, img in enumerate(datapoint.images):
            if isinstance(img.data, PILImage.Image):
                # PIL's Image.size is (width, height).
                img_w, img_h = img.data.size
            elif isinstance(img.data, torch.Tensor):
                # Tensor layout is (..., H, W): height comes first.
                # (Fixed: the original unpacked this as `img_w, img_h`,
                # swapping the two dimensions for non-square tensor images
                # and forcing a spurious resize of every decoded mask.)
                img_h, img_w = img.data.shape[-2:]
            else:
                raise RuntimeError(f"Unexpected image type {type(img.data)}")
            imgId2size[imgId] = (img_h, img_w)
            for obj in img.objects:
                if obj.segment is not None and not isinstance(
                    obj.segment, torch.Tensor
                ):
                    if mask_utils.area(obj.segment) == 0:
                        # Degenerate RLE: synthesize a mask from the box so
                        # downstream transforms still see a non-empty target.
                        print("Warning, empty mask found, approximating from box")
                        obj.segment = torch.zeros(img_h, img_w, dtype=torch.uint8)
                        x1, y1, x2, y2 = obj.bbox.int().tolist()
                        obj.segment[y1 : max(y2, y1 + 1), x1 : max(x1 + 1, x2)] = 1
                    else:
                        obj.segment = mask_utils.decode(obj.segment)
                        # segment is uint8 and remains uint8 throughout the transforms
                        obj.segment = torch.tensor(obj.segment).to(torch.uint8)
                    if list(obj.segment.shape) != [img_h, img_w]:
                        # Should not happen often, but adding for security
                        if not warning_shown:
                            print(
                                f"Warning expected instance segmentation size to be {[img_h, img_w]} but found {list(obj.segment.shape)}"
                            )
                            # Printing only once per datapoint to avoid spam
                            warning_shown = True
                        obj.segment = F.resize(
                            obj.segment[None], (img_h, img_w)
                        ).squeeze(0)
                    assert list(obj.segment.shape) == [img_h, img_w]
        warning_shown = False
        for query in datapoint.find_queries:
            if query.semantic_target is not None and not isinstance(
                query.semantic_target, torch.Tensor
            ):
                query.semantic_target = mask_utils.decode(query.semantic_target)
                # segment is uint8 and remains uint8 throughout the transforms
                query.semantic_target = torch.tensor(query.semantic_target).to(
                    torch.uint8
                )
                if tuple(query.semantic_target.shape) != imgId2size[query.image_id]:
                    if not warning_shown:
                        print(
                            f"Warning expected semantic segmentation size to be {imgId2size[query.image_id]} but found {tuple(query.semantic_target.shape)}"
                        )
                        # Printing only once per datapoint to avoid spam
                        warning_shown = True
                    query.semantic_target = F.resize(
                        query.semantic_target[None], imgId2size[query.image_id]
                    ).squeeze(0)
                assert tuple(query.semantic_target.shape) == imgId2size[query.image_id]
        return datapoint

View File

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

View File

@@ -0,0 +1,358 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import contextlib
import fnmatch
import logging
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import numpy as np
import torch
import torch.nn as nn
from iopath.common.file_io import g_pathmgr
from torch.jit._script import RecursiveScriptModule
def unix_pattern_to_parameter_names(
    constraints: List[str], all_parameter_names: Sequence[str]
) -> Union[None, Set[str]]:
    """
    Return the set of names in `all_parameter_names` that match at least one
    of the unix shell patterns in `constraints`. Every pattern must match at
    least one name, otherwise an AssertionError is raised.
    """
    matched_sets = []
    for pattern in constraints:
        hits = set(fnmatch.filter(all_parameter_names, pattern))
        assert (
            len(hits) > 0
        ), f"param_names {pattern} don't match any param in the given names."
        matched_sets.append(hits)
    return set.union(*matched_sets)
def filter_params_matching_unix_pattern(
    patterns: List[str], state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """
    Keep only the parameters of the state dictionary that match at least
    one of the provided unix patterns.
    Args:
        patterns: the list of unix patterns to include
        state_dict: the dictionary to filter
    Returns:
        A new state dictionary containing only the matching entries
        (empty if `patterns` is empty)
    """
    if len(patterns) == 0:
        return {}
    all_keys = list(state_dict.keys())
    included_keys = unix_pattern_to_parameter_names(patterns, all_keys)
    return {k: state_dict[k] for k in included_keys}
def exclude_params_matching_unix_pattern(
    patterns: List[str], state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """
    Remove from the state dictionary the parameters matching the provided unix patterns
    Args:
        patterns: the list of unix patterns to exclude
        state_dict: the dictionary to filter
    Returns:
        A new state dictionary
    """
    if not patterns:
        return state_dict
    excluded = unix_pattern_to_parameter_names(patterns, list(state_dict.keys()))
    return {
        name: tensor for name, tensor in state_dict.items() if name not in excluded
    }
def _get_state_dict_summary(state_dict: Dict[str, torch.Tensor]):
keys = []
trace = []
for k, v in state_dict.items():
keys.append(k)
trace.append(v.sum().item())
trace = np.array(trace)[np.argsort(keys)]
return trace
def assert_skipped_parameters_are_frozen(model: nn.Module, patterns: List[str]):
    """
    Verifies that all the parameters matching the provided patterns
    are frozen - this acts as a safeguard when ignoring parameters
    when saving checkpoints - if the parameters are in fact trainable.

    Raises:
        ValueError: if any parameter matching `patterns` has
            requires_grad=True.
    """
    if not patterns:
        return
    # Parameters that would be skipped at save time.
    frozen_state_dict = filter_params_matching_unix_pattern(
        patterns=patterns, state_dict=model.state_dict()
    )
    # Any of those that are still trainable is a configuration error.
    non_frozen_keys = {
        n
        for n, p in model.named_parameters()
        if n in frozen_state_dict and p.requires_grad
    }
    if non_frozen_keys:
        raise ValueError(
            f"Parameters excluded with `skip_saving_parameters` should be frozen: {non_frozen_keys}"
        )
@contextlib.contextmanager
def with_check_parameter_frozen(
    model: nn.Module, patterns: List[str], disabled: bool = True
):
    """
    Context manager that inspects a model surrounding a piece of code
    and verifies if the model has been updated by this piece of code
    The function will raise an exception if the model has been updated
    on at least one of the parameter that matches one of the pattern
    Args:
        model: the model that might have been updated
        patterns: unix patterns selecting the parameters we want to observe
        disabled: if True (the default), the check is skipped entirely
    """
    if not patterns or disabled:
        yield
        return
    # Snapshot the watched parameters before the wrapped code runs.
    frozen_state_dict = filter_params_matching_unix_pattern(
        patterns=patterns, state_dict=model.state_dict()
    )
    summary_before = _get_state_dict_summary(frozen_state_dict)
    yield
    # Re-summarize after and compare per-tensor sums (atol=1e-6); a change
    # means the supposedly-frozen parameters were modified.
    frozen_state_dict = filter_params_matching_unix_pattern(
        patterns=patterns, state_dict=model.state_dict()
    )
    summary_after = _get_state_dict_summary(frozen_state_dict)
    if not np.allclose(summary_before, summary_after, atol=1e-6):
        raise ValueError(
            f"""
            The `model_weight_initializer` has initialized parameters frozen with `skip_saving_parameters`.
            You can resolve this error by either initializing those parameters from within the model definition
            or using the flag `trainer.checkpoint.initialize_after_preemption` to True.
            """
        )
class CkptExcludeKernel:
    """
    Removes the keys from the given model state_dict that match the key_pattern.
    Args:
        key_pattern: Patterns used to select the keys in the state_dict
            that are eligible for this kernel.
    """
    def __init__(self, key_pattern: List[str]):
        self.key_pattern = key_pattern
    def __call__(self, state_dict: Dict):
        """Return `state_dict` without the entries matching `self.key_pattern`."""
        if not self.key_pattern:
            return state_dict
        excluded = unix_pattern_to_parameter_names(
            self.key_pattern, state_dict.keys()
        )
        return {
            name: value
            for name, value in state_dict.items()
            if name not in excluded
        }
def load_checkpoint(
    path_list: List[str],
    pick_recursive_keys: Optional[List[str]] = None,
    map_location: str = "cpu",
) -> Any:
    """
    Loads a checkpoint from the specified path.
    Args:
        path_list: A list of paths which contain the checkpoint. Each element
            is tried (in order) until a file that exists is found. That file is then
            used to read the checkpoint.
        pick_recursive_keys: Picks sub dicts from the loaded checkpoint if not None.
            For pick_recursive_keys = ["a", "b"], will return checkpoint_dict["a"]["b"]
        map_location (str): a function, torch.device, string or a dict specifying how to
            remap storage locations
    Returns: The loaded checkpoint object (the selected sub-dict when
        `pick_recursive_keys` is given).
    Raises:
        ValueError: if none of the paths in `path_list` exists.
    """
    path_exists = False
    # After this loop, `path` holds the first existing candidate.
    for path in path_list:
        if g_pathmgr.exists(path):
            path_exists = True
            break
    if not path_exists:
        raise ValueError(f"No path exists in {path_list}")
    with g_pathmgr.open(path, "rb") as f:
        checkpoint = torch.load(f, map_location=map_location)
    logging.info(f"Loaded checkpoint from {path}")
    if pick_recursive_keys is not None:
        # Descend through nested dicts, e.g. ["a", "b"] -> checkpoint["a"]["b"].
        for key in pick_recursive_keys:
            checkpoint = checkpoint[key]
    return checkpoint
def get_state_dict(checkpoint, ckpt_state_dict_keys):
    """
    Extract the model state dict from a loaded checkpoint by descending
    through the nested keys in `ckpt_state_dict_keys`.

    Args:
        checkpoint: the loaded checkpoint (a mapping/sequence, or a
            torchscript module whose state_dict is returned directly).
        ckpt_state_dict_keys: sequence of keys/indices to follow, e.g.
            ("state_dict",) returns checkpoint["state_dict"].

    Raises:
        KeyError: if a key is missing (mapping) or an index is out of
            range (sequence) anywhere along the path.
    """
    if isinstance(checkpoint, RecursiveScriptModule):
        # This is a torchscript JIT model
        return checkpoint.state_dict()
    pre_train_dict = checkpoint
    for i, key in enumerate(ckpt_state_dict_keys):
        if (isinstance(pre_train_dict, Mapping) and key not in pre_train_dict) or (
            isinstance(pre_train_dict, Sequence) and key >= len(pre_train_dict)
        ):
            # Render the path of keys already consumed for the error message.
            # (Fixed: `map` takes the function first — the original
            # `map(ckpt_state_dict_keys[:i], str)` raised TypeError instead
            # of the intended KeyError.)
            key_str = '["' + '"]["'.join(map(str, ckpt_state_dict_keys[:i])) + '"]'
            raise KeyError(
                f"'{key}' not found in checkpoint{key_str} "
                f"with keys: {pre_train_dict.keys()}"
            )
        pre_train_dict = pre_train_dict[key]
    return pre_train_dict
def load_checkpoint_and_apply_kernels(
    checkpoint_path: str,
    checkpoint_kernels: List[Callable] = None,
    ckpt_state_dict_keys: Tuple[str] = ("state_dict",),
    map_location: str = "cpu",
) -> Dict[str, torch.Tensor]:
    """
    Performs checkpoint loading with a variety of pre-processing kernel applied in
    sequence.
    Args:
        checkpoint_path (str): Path to the checkpoint.
        checkpoint_kernels List(Callable): A list of checkpoint processing kernels
            to apply in the specified order. Supported kernels include `CkptIncludeKernel`,
            `CkptExcludeKernel`, etc. These kernels are applied in the
            given order.
        ckpt_state_dict_keys (str): Keys containing the model state dict.
        map_location (str): a function, torch.device, string or a dict specifying how to
            remap storage locations
    Returns: The pre-trained state dict extracted from the checkpoint, after
        all kernels have been applied (not an nn.Module).
    """
    assert g_pathmgr.exists(checkpoint_path), "Checkpoint '{}' not found".format(
        checkpoint_path
    )
    # Load the checkpoint on CPU to avoid GPU mem spike.
    with g_pathmgr.open(checkpoint_path, "rb") as f:
        checkpoint = torch.load(f, map_location=map_location)
    pre_train_dict = get_state_dict(checkpoint, ckpt_state_dict_keys)
    # Not logging into info etc since it's a huge log
    logging.debug(
        "Loaded Checkpoint State Dict pre-kernel application: %s"
        % str(", ".join(list(pre_train_dict.keys())))
    )
    # Apply kernels
    if checkpoint_kernels is not None:
        for f in checkpoint_kernels:
            pre_train_dict = f(state_dict=pre_train_dict)
    logging.debug(
        "Loaded Checkpoint State Dict Post-kernel application %s"
        % str(", ".join(list(pre_train_dict.keys())))
    )
    return pre_train_dict
def check_load_state_dict_errors(
    missing_keys,
    unexpected_keys,
    strict: bool,
    ignore_missing_keys: List[str] = None,
    ignore_unexpected_keys: List[str] = None,
):
    """
    Validate the (missing, unexpected) key lists returned by
    `nn.Module.load_state_dict(strict=False)`.

    Keys matching `ignore_missing_keys` / `ignore_unexpected_keys` unix
    patterns are dropped first. Any remaining mismatch is logged as a
    warning; unexpected keys always raise a KeyError, while missing keys
    raise only when `strict` is True.
    """
    if ignore_missing_keys is not None and len(ignore_missing_keys) > 0:
        skip = unix_pattern_to_parameter_names(ignore_missing_keys, missing_keys)
        missing_keys = [k for k in missing_keys if k not in skip]
    if ignore_unexpected_keys is not None and len(ignore_unexpected_keys) > 0:
        skip = unix_pattern_to_parameter_names(
            ignore_unexpected_keys, unexpected_keys
        )
        unexpected_keys = [k for k in unexpected_keys if k not in skip]
    if not missing_keys and not unexpected_keys:
        return
    err = "State key mismatch."
    if unexpected_keys:
        err += f" Unexpected keys: {unexpected_keys}."
    if missing_keys:
        err += f" Missing keys: {missing_keys}."
    logging.warning(err)
    if unexpected_keys or strict:
        raise KeyError(err)
def load_state_dict_into_model(
    state_dict: Dict,
    model: nn.Module,
    strict: bool = True,
    ignore_missing_keys: List[str] = None,
    ignore_unexpected_keys: List[str] = None,
    checkpoint_kernels: List[Callable] = None,
):
    """
    Loads a state dict into the given model.
    Args:
        state_dict: A dictionary containing the model's
            state dict, or a subset if strict is False
        model: Model to load the checkpoint weights into
        strict: raise if the state_dict has missing state keys
        ignore_missing_keys: unix pattern of keys to ignore
        ignore_unexpected_keys: unix pattern of unexpected keys to ignore
        checkpoint_kernels: optional processing kernels applied to
            `state_dict` (in order) before loading
    Returns:
        The same `model`, with the weights loaded in place.
    """
    # Apply kernels
    if checkpoint_kernels is not None:
        for f in checkpoint_kernels:
            state_dict = f(state_dict=state_dict)
    # Load non-strictly, then enforce strictness/ignore rules ourselves so
    # that the ignore patterns can be applied to the reported key lists.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    check_load_state_dict_errors(
        missing_keys,
        unexpected_keys,
        strict=strict,
        ignore_missing_keys=ignore_missing_keys,
        ignore_unexpected_keys=ignore_unexpected_keys,
    )
    return model

View File

@@ -0,0 +1,585 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import functools
import io
import logging
import os
import random
import tempfile
import time
from typing import Any, Callable, List, Tuple
import torch
import torch.autograd as autograd
import torch.distributed as dist
# Default to GPU 0
_cuda_device_index: int = 0
# Setting _cuda_device_index to -1 internally implies that we should use CPU
_CPU_DEVICE_INDEX = -1
_PRIMARY_RANK = 0
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks
    The result is cached.

    NOTE(review): assumes the default process group is already initialized —
    `dist.get_backend()` raises otherwise; confirm callers only reach this
    after init.
    """
    if dist.get_backend() == "nccl":
        # Increase timeout from 1800 sec to 43200 sec (12 hr) to avoid some processes
        # being much slower than others causing a timeout (which can happen in relation
        # or LVIS class mAP evaluation).
        timeout = 43200
        return dist.new_group(
            backend="gloo",
            timeout=datetime.timedelta(seconds=timeout),
        )
    return dist.group.WORLD
def is_main_process():
    """Return true if the current process is the main one (rank 0)."""
    return get_rank() == 0
def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors), similar to
    `all_gather` above, but using filesystem instead of collective ops.
    If gather_to_rank_0_only is True, only rank 0 will load the gathered object list
    (and other ranks will have an empty list).

    Args:
        data: any picklable object.
        filesys_save_dir: directory shared by all ranks to exchange files
            through; falls back to $EXP_DIR, then this module's directory.
        gather_to_rank_0_only: skip loading on non-zero ranks.

    Returns:
        list[data]: objects gathered from each rank (empty on non-zero ranks
        when gather_to_rank_0_only is True).
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]
    print("gathering via files")
    cpu_group = _get_global_gloo_group()
    # if unspecified, we will save to the current python file dir
    if filesys_save_dir is not None:
        save_dir = filesys_save_dir
    elif "EXP_DIR" in os.environ:
        save_dir = os.environ["EXP_DIR"]
    else:
        # try the same directory where the code is stored
        # (filesys_save_dir is necessarily None on this branch)
        save_dir = os.path.dirname(__file__)
    save_dir = os.path.join(save_dir, "all_gather_via_filesys")
    if is_main_process():
        os.makedirs(save_dir, exist_ok=True)
    # use a timestamp and salt to distinguish different all_gather
    timestamp = int(time.time()) if is_main_process() else 0
    salt = random.randint(0, 2**31 - 1) if is_main_process() else 0
    # broadcast the timestamp and salt across ranks
    # (all-reduce will do the broadcasting since only rank 0 is non-zero)
    timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long)
    dist.all_reduce(timestamp_and_salt, group=cpu_group)
    timestamp, salt = timestamp_and_salt.tolist()
    # save the data to a file on the disk
    rank_save = get_rank()
    save_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_save}.pkl"
    save_data_path = os.path.join(save_dir, save_data_filename)
    assert not os.path.exists(save_data_path), f"{save_data_path} already exists"
    torch.save(data, save_data_path)
    # Wait until every rank has written its file before reading any.
    dist.barrier(group=cpu_group)
    # read the data from the files
    data_list = []
    if rank_save == 0 or not gather_to_rank_0_only:
        for rank_load in range(world_size):
            load_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_load}.pkl"
            load_data_path = os.path.join(save_dir, load_data_filename)
            # (Fixed: the message previously interpolated `save_data_path`,
            # which names this rank's own file, not the one that failed.)
            assert os.path.exists(load_data_path), f"cannot read {load_data_path}"
            data_list.append(torch.load(load_data_path, weights_only=False))
    # Wait until every rank has finished reading before deleting files.
    dist.barrier(group=cpu_group)
    # delete the saved file
    os.remove(save_data_path)
    return data_list
def all_gather(data, force_cpu=False, force_filesys=False, filesys_save_dir=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
        force_cpu: gather over a gloo (CPU) group instead of the default backend
        force_filesys: gather by writing/reading files on a shared filesystem
        filesys_save_dir: directory used by the filesystem-based gather
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]
    # Environment variables can force the filesystem- or CPU-based paths.
    if os.getenv("MDETR_FILESYS_REDUCE_RANK_0_ONLY") == "1":
        return all_gather_via_filesys(
            data, filesys_save_dir, gather_to_rank_0_only=True
        )
    if os.getenv("MDETR_FILESYS_REDUCE") == "1" or force_filesys:
        return all_gather_via_filesys(data, filesys_save_dir)
    cpu_group = None
    if os.getenv("MDETR_CPU_REDUCE") == "1" or force_cpu:
        cpu_group = _get_global_gloo_group()
    # Serialize the object to a byte tensor.
    buffer = io.BytesIO()
    torch.save(data, buffer)
    data_view = buffer.getbuffer()
    device = "cuda" if cpu_group is None else "cpu"
    tensor = torch.ByteTensor(data_view).to(device)
    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
    size_list = [
        torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)
    ]
    if cpu_group is None:
        dist.all_gather(size_list, local_size)
    else:
        print("gathering on cpu")
        dist.all_gather(size_list, local_size, group=cpu_group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)
    assert isinstance(local_size.item(), int)
    local_size = int(local_size.item())
    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
    if local_size != max_size:
        padding = torch.empty(
            size=(max_size - local_size,), dtype=torch.uint8, device=device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    if cpu_group is None:
        dist.all_gather(tensor_list, tensor)
    else:
        dist.all_gather(tensor_list, tensor, group=cpu_group)
    # Strip the padding from each gathered buffer and deserialize.
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
        buffer = io.BytesIO(tensor.cpu().numpy())
        obj = torch.load(buffer, weights_only=False)
        data_list.append(obj)
    return data_list
def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This helper function converts to the correct
    device and returns the tensor + original device.

    NOTE(review): `get_backend()` requires an initialized default process
    group; callers in this file guard with `is_distributed_training_run()`
    first — confirm before reusing elsewhere.
    """
    orig_device = "cpu" if not tensor.is_cuda else "gpu"
    if (
        torch.distributed.is_available()
        and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
        and not tensor.is_cuda
    ):
        tensor = tensor.cuda()
    return (tensor, orig_device)
def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
    """
    Undo `convert_to_distributed_tensor`: move the tensor back to CPU when
    it originally lived there ("cpu") but is now on the GPU.
    """
    moved_to_gpu = tensor.is_cuda and orig_device == "cpu"
    return tensor.cpu() if moved_to_gpu else tensor
def is_distributed_training_run() -> bool:
    """True iff torch.distributed is available, initialized, and world size > 1."""
    if not torch.distributed.is_available():
        return False
    if not torch.distributed.is_initialized():
        return False
    return torch.distributed.get_world_size() > 1
def is_primary() -> bool:
    """
    Returns True if this is rank 0 of a distributed training job OR if it is
    a single trainer job. Otherwise False.
    """
    # _PRIMARY_RANK is 0 (module-level constant).
    return get_rank() == _PRIMARY_RANK
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing mean reduction
    of tensor over all processes.
    """
    # Mean = SUM all-reduce followed by division by the world size.
    return all_reduce_op(
        tensor,
        torch.distributed.ReduceOp.SUM,
        lambda t: t / torch.distributed.get_world_size(),
    )
def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing sum
    reduction of tensor over all processes in both distributed /
    non-distributed scenarios.
    """
    return all_reduce_op(tensor, torch.distributed.ReduceOp.SUM)
def all_reduce_min(tensor: torch.Tensor) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing min
    reduction of tensor over all processes in both distributed /
    non-distributed scenarios.
    """
    return all_reduce_op(tensor, torch.distributed.ReduceOp.MIN)
def all_reduce_max(tensor: torch.Tensor) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing max
    reduction of tensor over all processes in both distributed /
    non-distributed scenarios.
    """
    return all_reduce_op(tensor, torch.distributed.ReduceOp.MAX)
def all_reduce_op(
    tensor: torch.Tensor,
    op: torch.distributed.ReduceOp,
    after_op_func: Callable[[torch.Tensor], torch.Tensor] = None,
) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing
    reduction of tensor over all processes in both distributed /
    non-distributed scenarios.

    Args:
        tensor: the tensor to reduce (moved to GPU if the backend needs it).
        op: the torch.distributed reduce operation (SUM/MIN/MAX/...).
        after_op_func: optional post-processing applied to the reduced
            tensor (e.g. division by world size for a mean).
    Returns:
        The reduced tensor, moved back to its original device; the input
        tensor unchanged when not in a distributed run.
    """
    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        torch.distributed.all_reduce(tensor, op)
        if after_op_func is not None:
            tensor = after_op_func(tensor)
        tensor = convert_to_normal_tensor(tensor, orig_device)
    return tensor
def gather_tensors_from_all(tensor: torch.Tensor) -> List[torch.Tensor]:
    """
    All-gather ``tensor`` and return one tensor per rank; in a
    non-distributed run, returns a single-element list.
    """
    # 0-dim tensors cannot be gathered, so promote them to 1-dim first.
    if tensor.ndim == 0:
        tensor = tensor.unsqueeze(0)

    if not is_distributed_training_run():
        return [tensor]

    dist_tensor, orig_device = convert_to_distributed_tensor(tensor)
    world_size = torch.distributed.get_world_size()
    buckets = [torch.zeros_like(dist_tensor) for _ in range(world_size)]
    torch.distributed.all_gather(buckets, dist_tensor)
    return [convert_to_normal_tensor(t, orig_device) for t in buckets]
def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    """Gather ``tensor`` from every worker and concatenate along dim 0."""
    per_rank = gather_tensors_from_all(tensor)
    return torch.cat(per_rank, 0)
def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """
    Wrapper over torch.distributed.broadcast for broadcasting a tensor from the source
    to all processes in both distributed / non-distributed scenarios.

    Args:
        tensor: on ``src``, the data to send; on other ranks, a same-shaped
            buffer that receives the data. Returned unchanged when not
            distributed.
        src: rank holding the authoritative copy.
    """
    if is_distributed_training_run():
        # Helpers move the tensor to/from the backend's device around the
        # collective (definitions not shown here).
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        torch.distributed.broadcast(tensor, src)
        tensor = convert_to_normal_tensor(tensor, orig_device)
    return tensor
def barrier() -> None:
    """
    Synchronize all workers. Unlike torch.distributed.barrier, this is a
    silent no-op when no process group has been initialized.
    """
    dist_module = torch.distributed
    if dist_module.is_available() and dist_module.is_initialized():
        dist_module.barrier()
def get_world_size() -> int:
    """
    World size that is safe to call in both distributed and
    non-distributed settings (where it returns 1).
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return 1
def get_rank() -> int:
    """
    Rank of the current process, safe to call in both distributed and
    non-distributed settings (where it returns 0).
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0
def get_primary_rank() -> int:
    """Return the rank designated as primary (the module-level constant)."""
    return _PRIMARY_RANK
def set_cuda_device_index(idx: int) -> None:
    """Record ``idx`` as this process's CUDA device and make it the current one."""
    # Stored in a module-level global so init_distributed_data_parallel_model
    # and get_cuda_device_index can read it back later.
    global _cuda_device_index
    _cuda_device_index = idx
    torch.cuda.set_device(_cuda_device_index)
def set_cpu_device() -> None:
    """Mark this process as CPU-only by storing the CPU sentinel device index."""
    global _cuda_device_index
    _cuda_device_index = _CPU_DEVICE_INDEX
def get_cuda_device_index() -> int:
    """Return the device index set via set_cuda_device_index / set_cpu_device."""
    return _cuda_device_index
def init_distributed_data_parallel_model(
    model: torch.nn.Module,
    broadcast_buffers: bool = False,
    find_unused_parameters: bool = True,
    bucket_cap_mb: int = 25,
) -> torch.nn.parallel.DistributedDataParallel:
    """
    Wrap ``model`` in DistributedDataParallel, either CPU-only or pinned to
    the CUDA device previously selected via ``set_cuda_device_index``.

    Args:
        model: the module to wrap (must already live on the right device)
        broadcast_buffers: whether DDP syncs buffers at each forward
        find_unused_parameters: let DDP tolerate params unused in a forward
        bucket_cap_mb: DDP gradient-bucket size in MB

    Returns:
        The DDP-wrapped model.
    """
    global _cuda_device_index

    if _cuda_device_index == _CPU_DEVICE_INDEX:
        # CPU-only model, don't specify device
        return torch.nn.parallel.DistributedDataParallel(
            model,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=find_unused_parameters,
            bucket_cap_mb=bucket_cap_mb,
        )
    else:
        # GPU model: pin DDP's inputs/outputs to the selected device.
        return torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[_cuda_device_index],
            output_device=_cuda_device_index,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=find_unused_parameters,
            bucket_cap_mb=bucket_cap_mb,
        )
def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any:
    """Broadcast an object from a source to all workers.

    Args:
        obj: Object to broadcast, must be serializable
        src: Source rank for broadcast (default is primary)
        use_disk: If enabled, removes redundant CPU memory copies by writing to
            disk

    Two-phase protocol: first broadcast the serialized payload length so
    receivers can allocate an exactly-sized buffer, then broadcast the bytes.
    """
    # Either broadcast from primary to the fleet (default),
    # or use the src setting as the original rank
    if get_rank() == src:
        # Emit data
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data_view = buffer.getbuffer()
        length_tensor = torch.LongTensor([len(data_view)])
        length_tensor = broadcast(length_tensor, src=src)
        data_tensor = torch.ByteTensor(data_view)
        data_tensor = broadcast(data_tensor, src=src)
    else:
        # Fetch from the source
        length_tensor = torch.LongTensor([0])
        length_tensor = broadcast(length_tensor, src=src)
        data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8)
        data_tensor = broadcast(data_tensor, src=src)
        if use_disk:
            # Spill the received bytes to a temp file so the byte tensor can be
            # freed before torch.load materializes the object (lower peak RAM).
            with tempfile.TemporaryFile("r+b") as f:
                f.write(data_tensor.numpy())
                # remove reference to the data tensor and hope that Python garbage
                # collects it
                del data_tensor
                f.seek(0)
                # NOTE(review): weights_only=False runs full pickle on the
                # received bytes — only safe among trusted workers.
                obj = torch.load(f, weights_only=False)
        else:
            buffer = io.BytesIO(data_tensor.numpy())
            obj = torch.load(buffer, weights_only=False)
    return obj
def all_gather_tensor(tensor: torch.Tensor, world_size=None):
    """
    All-gather ``tensor`` across ``world_size`` ranks.

    Args:
        tensor: contiguous tensor to gather (asserted contiguous below)
        world_size: number of ranks; defaults to the current world size

    Returns:
        A list with one tensor per rank, restored to the input's device.
    """
    if world_size is None:
        world_size = get_world_size()
    # make contiguous because NCCL won't gather the tensor otherwise
    assert tensor.is_contiguous(), f"{tensor.shape} is not contiguous!"
    tensor, orig_device = convert_to_distributed_tensor(tensor)
    tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_all, tensor, async_op=False)  # performance opt
    tensor_all = [
        convert_to_normal_tensor(tensor, orig_device) for tensor in tensor_all
    ]
    return tensor_all
def all_gather_batch(tensors: List[torch.Tensor]):
    """
    Gather every tensor in ``tensors`` across all workers and concatenate
    each one along dim 0. Gradients do NOT flow through the gather.
    """
    world_size = get_world_size()
    # Single process: gathering is the identity.
    if world_size == 1:
        return tensors
    gathered = [all_gather_tensor(t, world_size) for t in tensors]
    return [torch.cat(parts, dim=0) for parts in gathered]
class GatherLayer(autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    This implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        # Standard all_gather: every rank ends up with a copy of x from
        # every rank, returned as a tuple of per-rank tensors.
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # Sum the incoming per-output gradients across all ranks, then return
        # the slice corresponding to this rank's original input.
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]
def all_gather_batch_with_grad(tensors):
    """
    Gather every tensor in ``tensors`` across all workers and concatenate
    each one along dim 0, keeping the autograd graph connected (via
    GatherLayer) so gradients flow back through the gather.
    """
    world_size = get_world_size()
    # Single process: gathering is the identity.
    if world_size == 1:
        return tensors
    gathered = [GatherLayer.apply(t) for t in tensors]
    return [torch.cat(parts, dim=0) for parts in gathered]
def unwrap_ddp_if_wrapped(model):
    """Return the underlying module when ``model`` is DDP-wrapped, else ``model``."""
    is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    return model.module if is_ddp else model
def create_new_process_group(group_size):
    """
    Creates process groups of a given `group_size` and returns the
    process group that the current GPU participates in.

    `group_size` must divide the total number of GPUs (world_size).

    Modified from
    https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60

    Args:
        group_size (int): number of GPU's to collaborate for sync bn

    Returns:
        The process group containing this process's rank.
    """
    assert group_size > 0

    world_size = torch.distributed.get_world_size()
    # Small world sizes are assumed to be local debug runs, so cap instead of
    # failing the divisibility assertions below.
    if world_size <= 8:
        if group_size > world_size:
            logging.warning(
                f"Requested group size [{group_size}] > world size [{world_size}]. "
                "Assuming local debug run and capping it to world size."
            )
            group_size = world_size

    assert world_size >= group_size
    assert world_size % group_size == 0

    group = None
    for group_num in range(world_size // group_size):
        group_ids = range(group_num * group_size, (group_num + 1) * group_size)
        cur_group = torch.distributed.new_group(ranks=group_ids)
        if torch.distributed.get_rank() // group_size == group_num:
            group = cur_group
        # can not drop out and return here, every process must go through creation of all subgroups

    assert group is not None
    return group
def is_dist_avail_and_initialized():
    """True iff torch.distributed is compiled in and a process group exists."""
    return dist.is_available() and dist.is_initialized()
def gather_to_rank_0_via_filesys(data, filesys_save_dir=None):
    """
    Gather any picklable ``data`` onto rank 0 only, delegating to
    all_gather_via_filesys (filesystem-based exchange).
    """
    return all_gather_via_filesys(
        data, filesys_save_dir, gather_to_rank_0_only=True
    )

241
sam3/train/utils/logger.py Normal file
View File

@@ -0,0 +1,241 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import atexit
import functools
import logging
import sys
import uuid
from typing import Any, Dict, Optional, Union
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from numpy import ndarray
from sam3.train.utils.train_utils import get_machine_local_and_dist_rank, makedir
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
Scalar = Union[Tensor, ndarray, int, float]
def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any):
    """Build a TensorBoardLogger writing to ``log_dir`` (created if missing)."""
    makedir(log_dir)
    return TensorBoardLogger(
        path=log_dir, summary_writer_method=SummaryWriter, **writer_kwargs
    )
class TensorBoardWriterWrapper:
    """
    A wrapper around a SummaryWriter object.
    """

    def __init__(
        self,
        path: str,
        *args: Any,
        filename_suffix: str = None,
        summary_writer_method: Any = SummaryWriter,
        **kwargs: Any,
    ) -> None:
        """Create a new TensorBoard logger.
        On construction, the logger creates a new events file that logs
        will be written to. If the environment variable `RANK` is defined,
        logger will only log if RANK = 0.

        NOTE: If using the logger with distributed training:
        - This logger can call collective operations
        - Logs will be written on rank 0 only
        - Logger must be constructed synchronously *after* initializing distributed process group.

        Args:
            path (str): path to write logs to
            *args, **kwargs: Extra arguments to pass to SummaryWriter
        """
        self._writer: Optional[SummaryWriter] = None
        _, self._rank = get_machine_local_and_dist_rank()
        self._path: str = path
        # Only rank 0 materializes a real writer; all other ranks keep None
        # so the log/flush/close methods become no-ops.
        if self._rank == 0:
            logging.info(
                f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}"
            )
            # Random filename suffix avoids event-file collisions when several
            # runs share the same log dir.
            self._writer = summary_writer_method(
                log_dir=path,
                *args,
                filename_suffix=filename_suffix or str(uuid.uuid4()),
                **kwargs,
            )
        else:
            logging.debug(
                f"Not logging meters on this host because env RANK: {self._rank} != 0"
            )
        # Ensure pending events are flushed when the interpreter exits.
        atexit.register(self.close)

    @property
    def writer(self) -> Optional[SummaryWriter]:
        """Underlying SummaryWriter; None on non-zero ranks or after close()."""
        return self._writer

    @property
    def path(self) -> str:
        """Directory the event files are written to."""
        return self._path

    def flush(self) -> None:
        """Writes pending logs to disk."""
        if not self._writer:
            return
        self._writer.flush()

    def close(self) -> None:
        """Close writer, flushing pending logs to disk.
        Logs cannot be written after `close` is called.
        """
        if not self._writer:
            return
        self._writer.close()
        self._writer = None
class TensorBoardLogger(TensorBoardWriterWrapper):
    """
    A simple logger for TensorBoard.
    """

    def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
        """Add multiple scalar values to TensorBoard.

        Args:
            payload (dict): dictionary of tag name and scalar value
            step (int, Optional): step value to record
        """
        if self._writer is None:
            return
        for tag, value in payload.items():
            self.log(tag, value, step)

    def log(self, name: str, data: Scalar, step: int) -> None:
        """Add scalar data to TensorBoard.

        Args:
            name (string): tag name used to group scalars
            data (float/int/Tensor): scalar data to log
            step (int, optional): step value to record
        """
        writer = self._writer
        if writer is None:
            return
        writer.add_scalar(name, data, global_step=step, new_style=True)

    def log_hparams(
        self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
    ) -> None:
        """Add hyperparameter data to TensorBoard.

        Args:
            hparams (dict): dictionary of hyperparameter names and corresponding values
            meters (dict): dictionary of name of meter and corresponding values
        """
        if self._writer is None:
            return
        self._writer.add_hparams(hparams, meters)
class Logger:
    """
    A logger class that can interface with multiple loggers. It now supports tensorboard only for simplicity, but you can extend it with your own logger.
    """

    def __init__(self, logging_conf):
        # allow turning off TensorBoard with "should_log: false" in config
        tb_config = logging_conf.tensorboard_writer
        # NOTE(review): pop() mutates the config in place — building a second
        # Logger from the same conf object would see should_log default to True.
        tb_should_log = tb_config and tb_config.pop("should_log", True)
        # Hydra-instantiate the writer config (e.g. make_tensorboard_logger).
        self.tb_logger = instantiate(tb_config) if tb_should_log else None

    def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
        """Forward a dict of tag -> scalar to TensorBoard, if enabled."""
        if self.tb_logger:
            self.tb_logger.log_dict(payload, step)

    def log(self, name: str, data: Scalar, step: int) -> None:
        """Forward a single scalar to TensorBoard, if enabled."""
        if self.tb_logger:
            self.tb_logger.log(name, data, step)

    def log_hparams(
        self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
    ) -> None:
        """Forward hyperparameters and final meter values, if enabled."""
        if self.tb_logger:
            self.tb_logger.log_hparams(hparams, meters)
# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    """Open ``filename`` for append once and reuse the stream across calls."""
    # we tune the buffering value so that the logs are updated frequently.
    buffer_size = 10 * 1024  # 10KB
    stream = g_pathmgr.open(filename, mode="a", buffering=buffer_size)
    # Close (and thereby flush) the stream on interpreter exit.
    atexit.register(stream.close)
    return stream
def setup_logging(
    name,
    output_dir=None,
    rank=0,
    log_level_primary="INFO",
    log_level_secondary="ERROR",
):
    """
    Setup various logging streams: stdout and file handlers.
    For file handlers, we only setup for the master gpu.

    Args:
        name: logger name to configure; its handlers are replaced.
        output_dir: when set, rank 0 additionally logs to <output_dir>/log.txt.
        rank: distributed rank of this process.
        log_level_primary: level used on rank 0 (console and file).
        log_level_secondary: console level used on all other ranks.
    """
    # get the filename if we want to log to the file as well
    log_filename = None
    if output_dir:
        makedir(output_dir)
        if rank == 0:
            log_filename = f"{output_dir}/log.txt"

    logger = logging.getLogger(name)
    logger.setLevel(log_level_primary)

    # create formatter
    FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s"
    formatter = logging.Formatter(FORMAT)

    # Cleanup any existing handlers. Iterate over a copy: removing from the
    # live `logger.handlers` list while iterating it skips every other handler.
    for h in list(logger.handlers):
        logger.removeHandler(h)
    logger.root.handlers = []

    # setup the console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    if rank == 0:
        console_handler.setLevel(log_level_primary)
    else:
        console_handler.setLevel(log_level_secondary)

    # we log to file as well if user wants
    if log_filename and rank == 0:
        file_handler = logging.StreamHandler(_cached_log_stream(log_filename))
        file_handler.setLevel(log_level_primary)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    # NOTE(review): replacing logging.root globally routes every module-level
    # logging.info(...) call through this named logger's handlers; it affects
    # all libraries that use the root logger.
    logging.root = logger
def shutdown_logging():
    """
    After training is done, flush and close every handler attached to the
    root logger so no log lines are lost.
    """
    logging.info("Shutting down loggers...")
    for handler in logging.root.handlers:
        handler.close()

View File

@@ -0,0 +1,285 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import logging
import math
import os
import random
import re
from datetime import timedelta
from typing import Optional
import hydra
import numpy as np
import omegaconf
import torch
import torch.distributed as dist
from iopath.common.file_io import g_pathmgr
from omegaconf import OmegaConf
def multiply_all(*args):
    """Multiply all positional arguments together and return a Python scalar."""
    product = np.prod(np.asarray(args))
    return product.item()
def collect_dict_keys(config):
    """
    Recursively walk a dataset configuration and collect every ``dict_key``
    defined on a collate_fn node (a node whose ``_target_`` mentions
    "collate_fn").
    """
    is_collate_node = "_target_" in config and re.match(
        r".*collate_fn.*", config["_target_"]
    )
    if is_collate_node:
        return [config["dict_key"]]

    keys = []
    node_type = type(config)
    for value in config.values():
        if isinstance(value, node_type):
            keys.extend(collect_dict_keys(value))
        elif isinstance(value, omegaconf.listconfig.ListConfig):
            for item in value:
                if isinstance(item, node_type):
                    keys.extend(collect_dict_keys(item))
    return keys
class Phase:
    """String constants naming the two training phases."""

    TRAIN = "train"
    VAL = "val"
def register_omegaconf_resolvers():
    """Register the custom ``${resolver:...}`` functions used by the configs.

    Covers Hydra lookups (get_method/get_class), arithmetic (add, times,
    divide, pow, subtract, ceil_int), and misc helpers (range, int, merge,
    string). Call once before composing any config that uses them.
    """
    OmegaConf.register_new_resolver("get_method", hydra.utils.get_method)
    OmegaConf.register_new_resolver("get_class", hydra.utils.get_class)
    OmegaConf.register_new_resolver("add", lambda x, y: x + y)
    OmegaConf.register_new_resolver("times", multiply_all)
    OmegaConf.register_new_resolver("divide", lambda x, y: x / y)
    OmegaConf.register_new_resolver("pow", lambda x, y: x**y)
    OmegaConf.register_new_resolver("subtract", lambda x, y: x - y)
    OmegaConf.register_new_resolver("range", lambda x: list(range(x)))
    OmegaConf.register_new_resolver("int", lambda x: int(x))
    OmegaConf.register_new_resolver("ceil_int", lambda x: int(math.ceil(x)))
    OmegaConf.register_new_resolver("merge", lambda *x: OmegaConf.merge(*x))
    OmegaConf.register_new_resolver("string", lambda x: str(x))
def setup_distributed_backend(backend, timeout_mins):
    """
    Initialize the torch.distributed process group.
    Expects environment variables to be set as per
    https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization

    Args:
        backend: process-group backend name (e.g. "nccl", "gloo")
        timeout_mins: collective-operation timeout, in minutes

    Returns:
        This process's global rank.
    """
    # enable TORCH_NCCL_ASYNC_ERROR_HANDLING to ensure dist nccl ops time out after timeout_mins
    # of waiting
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
    logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins")
    dist.init_process_group(backend=backend, timeout=timedelta(minutes=timeout_mins))
    return dist.get_rank()
def get_machine_local_and_dist_rank():
    """
    Get the distributed and local rank of the current gpu.

    Returns:
        Tuple (local_rank, distributed_rank) parsed from the LOCAL_RANK and
        RANK environment variables.

    Raises:
        AssertionError: if either environment variable is not set.
    """
    # Fetch the raw values first and defer int() until after the presence
    # check: the original int(os.environ.get(..., None)) raised a bare
    # TypeError on a missing variable before the helpful assert could fire.
    local_rank = os.environ.get("LOCAL_RANK", None)
    distributed_rank = os.environ.get("RANK", None)
    assert (
        local_rank is not None and distributed_rank is not None
    ), "Please set the RANK and LOCAL_RANK environment variables."
    return int(local_rank), int(distributed_rank)
def print_cfg(cfg):
    """
    Log the resolved config as YAML; supports both Hydra DictConfig and the
    AttrDict config.
    """
    rendered = OmegaConf.to_yaml(cfg)
    logging.info("Training with config:")
    logging.info(rendered)
def set_seeds(seed_value, max_epochs, dist_rank):
    """
    Seed python's random, numpy and torch (including all CUDA devices when
    available) so training is deterministic per rank.
    """
    # The pytorch sampler increments the seed by 1 every epoch, so spread
    # per-rank seeds max_epochs apart to keep them disjoint.
    machine_seed = (seed_value + dist_rank) * max_epochs
    logging.info(f"MACHINE SEED: {machine_seed}")
    random.seed(machine_seed)
    np.random.seed(machine_seed)
    torch.manual_seed(machine_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(machine_seed)
def makedir(dir_path):
    """
    Create ``dir_path`` if it does not exist yet.

    Returns:
        True on success (or when the directory already exists), False if
        creation failed (the error is logged, not raised).
    """
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        return True
    except BaseException:
        logging.info(f"Error creating directory: {dir_path}")
        return False
def is_dist_avail_and_initialized():
    """Return whether a torch.distributed process group is up and usable."""
    return dist.is_available() and dist.is_initialized()
def get_amp_type(amp_type: Optional[str] = None):
    """
    Map an AMP dtype name ("bfloat16" or "float16") to the torch dtype;
    None passes through unchanged.
    """
    if amp_type is None:
        return None
    assert amp_type in ["bfloat16", "float16"], "Invalid Amp type."
    return torch.bfloat16 if amp_type == "bfloat16" else torch.float16
def log_env_variables():
    """Dump every environment variable, sorted by name, to the info log."""
    dump = "".join(
        f"{key}={os.environ[key]}\n" for key in sorted(os.environ.keys())
    )
    logging.info("Logging ENV_VARIABLES")
    logging.info(dump)
class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.device = device
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self._allow_updates = True

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Renders as "name: current (average)" using the configured format.
        template = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return template.format(**self.__dict__)
class MemMeter:
    """Computes and stores the current, avg, and max of peak Mem usage per iteration"""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.device = device
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = 0  # Per iteration max usage
        self.avg = 0  # Avg per iteration max usage
        self.peak = 0  # Peak usage for lifetime of program
        self.sum = 0
        self.count = 0
        self._allow_updates = True

    def update(self, n=1, reset_peak_usage=True):
        """Sample CUDA peak memory since the last counter reset and fold it in.

        Args:
            n: weight of this sample in the running average
            reset_peak_usage: clear the CUDA peak-memory counter afterwards so
                the next call measures only the upcoming iteration
        """
        # `// 1e9` floors to whole GB but yields a float (int // float -> float).
        self.val = torch.cuda.max_memory_allocated() // 1e9
        self.sum += self.val * n
        self.count += n
        self.avg = self.sum / self.count
        self.peak = max(self.peak, self.val)
        if reset_peak_usage:
            torch.cuda.reset_peak_memory_stats()

    def __str__(self):
        # Renders as "name: current (avg/peak)" using the configured format.
        fmtstr = (
            "{name}: {val"
            + self.fmt
            + "} ({avg"
            + self.fmt
            + "}/{peak"
            + self.fmt
            + "})"
        )
        return fmtstr.format(**self.__dict__)
def human_readable_time(time_seconds):
    """Format a duration in seconds as 'DDd HHh MMm' (seconds are dropped)."""
    total_minutes, _ = divmod(int(time_seconds), 60)
    total_hours, minutes = divmod(total_minutes, 60)
    days, hours = divmod(total_hours, 24)
    return f"{days:02}d {hours:02}h {minutes:02}m"
class DurationMeter:
    """Tracks an elapsed-time value (seconds) and renders it human-readably."""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.device = device
        self.fmt = fmt
        self.val = 0

    def reset(self):
        """Zero the tracked duration."""
        self.val = 0

    def update(self, val):
        """Overwrite the tracked duration with ``val``."""
        self.val = val

    def add(self, val):
        """Accumulate ``val`` onto the tracked duration."""
        self.val += val

    def __str__(self):
        rendered = human_readable_time(self.val)
        return f"{self.name}: {rendered}"
class ProgressMeter:
    """Renders batch progress plus all meter values as one log line."""

    def __init__(self, num_batches, meters, real_meters, prefix=""):
        # Fixed-width "[  i/total]" template so successive lines align.
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.real_meters = real_meters
        self.prefix = prefix

    def display(self, batch, enable_print=False):
        """Log (and optionally print) the formatted progress line for ``batch``."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        # real_meters values are expected to expose compute() returning a dict
        # of sub-metric name -> float; tags are joined as "name/subname" via
        # os.path.join. TODO confirm the compute() contract against callers.
        entries += [
            " | ".join(
                [
                    f"{os.path.join(name, subname)}: {val:.4f}"
                    for subname, val in meter.compute().items()
                ]
            )
            for name, meter in self.real_meters.items()
        ]
        logging.info(" | ".join(entries))
        if enable_print:
            print(" | ".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the batch index to the digit count of num_batches for alignment.
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"
def get_resume_checkpoint(checkpoint_save_dir):
    """
    Return the path to checkpoint.pt inside ``checkpoint_save_dir`` when both
    the directory and the file exist, otherwise None.
    """
    if not g_pathmgr.isdir(checkpoint_save_dir):
        return None
    candidate = os.path.join(checkpoint_save_dir, "checkpoint.pt")
    return candidate if g_pathmgr.isfile(candidate) else None