Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
1
sam3/train/__init__.py
Normal file
1
sam3/train/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
279
sam3/train/configs/eval_base.yaml
Normal file
279
sam3/train/configs/eval_base.yaml
Normal file
@@ -0,0 +1,279 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# This config is the base configuration for all evaluations. Amongst other things, it defines:
|
||||
# - the model
|
||||
# - the image transforms
|
||||
# - the post processors
|
||||
# - cluster configuration (only relevant for slurm-based evals, ignored otherwise)
|
||||
#
|
||||
# Most of the parameters should be kept as-is. The main modifications you may want to make are:
|
||||
# - the cluster configuration, to adjust partitions/qos to your system
|
||||
# - the flag gather_pred_via_filesys if your RAM is tight
|
||||
# - num_val_workers if your number of cores is small (should be roughly number of cores / number of gpus)
|
||||
# - the paths below
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
paths:
|
||||
# If you leave the checkpoint path to null, the model will be downloaded from hugging-face. Otherwise provide a path
|
||||
checkpoint_path: null
|
||||
# the experiments will be subfolders of this
|
||||
base_experiment_log_dir: <YOUR_EXPERIMENT_LOG_DIR>
|
||||
|
||||
# base path to the annotation folder for gold (refer to the readmes on how to download)
|
||||
base_annotation_path: <YOUR_GOLD_GT_DIR>
|
||||
|
||||
# base path to the annotation folder for silver (refer to the readmes on how to download)
|
||||
base_annotation_path_silver: <YOUR_SILVER_GT_DIR>
|
||||
|
||||
# path to the metaclip images, used for SA-Co gold (refer to the readme for instructions). Can be null if you don't intend on evaluating on this dataset.
|
||||
metaclip_img_path: <YOUR_METACLIP_IMG_DIR>
|
||||
|
||||
# path to the sa1b images, used for SA-Co gold (refer to the readme for instructions). Can be null if you don't intend on evaluating on this dataset.
|
||||
sa1b_img_path: <YOUR_SA1B_IMG_DIR>
|
||||
|
||||
# path to the SA-Co/silver images
|
||||
silver_img_path: <YOUR_SILVER_IMG_DIR>
|
||||
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
|
||||
use_presence_eval: True
|
||||
|
||||
base_val_transform:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
######## transforms for validation (begin) ########
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution} # originally `resolution: 1024`
|
||||
max_size:
|
||||
_target_: sam3.train.transforms.basic.get_random_resize_max_size
|
||||
size: ${scratch.resolution} # originally `resolution: 1024`
|
||||
square: true
|
||||
consistent_transform: False
|
||||
######## transforms for validation (end) ########
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
loss: null
|
||||
|
||||
# Model parameters
|
||||
d_model: 256
|
||||
input_box_embedding_dim: ${add:${scratch.d_model},2}
|
||||
|
||||
# Box processing
|
||||
original_box_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessImage
|
||||
max_dets_per_img: -1 # infinite detections
|
||||
use_original_ids: true
|
||||
use_original_sizes_box: true
|
||||
use_presence: ${scratch.use_presence_eval}
|
||||
|
||||
box_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessImage
|
||||
max_dets_per_img: -1 #infinite detections
|
||||
use_original_ids: false
|
||||
use_original_sizes_box: false
|
||||
use_presence: ${scratch.use_presence_eval}
|
||||
|
||||
box_postprocessor_thresholded:
|
||||
_target_: sam3.eval.postprocessors.PostProcessImage
|
||||
max_dets_per_img: -1 #infinite detections
|
||||
use_original_ids: false
|
||||
use_original_sizes_box: false
|
||||
detection_threshold: 0.3
|
||||
use_presence: ${scratch.use_presence_eval}
|
||||
|
||||
mask_postprocessor_thresholded:
|
||||
_target_: sam3.eval.postprocessors.PostProcessImage
|
||||
max_dets_per_img: -1 #infinite detections
|
||||
iou_type: "segm"
|
||||
use_original_ids: false
|
||||
use_original_sizes_box: false
|
||||
use_original_sizes_mask: true
|
||||
convert_mask_to_rle: True
|
||||
detection_threshold: 0.3
|
||||
use_presence: ${scratch.use_presence_eval}
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
max_ann_per_img: 200
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
# Training parameters
|
||||
train_batch_size: 1
|
||||
val_batch_size: 1
|
||||
num_train_workers: 0
|
||||
num_val_workers: 10 # change this depending on the number of cpu cores available
|
||||
max_data_epochs: 20
|
||||
target_epoch_size: 1500
|
||||
hybrid_repeats: 1
|
||||
context_length: 2
|
||||
|
||||
# All reduce - this controls how the predictions are sent back to node 0.
|
||||
# If you have a lot of ram, CPU gather is faster. Otherwise, we provide a fallback through filesystem (eg NFS)
|
||||
# Switch to true if you get cpu ooms during gather.
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
# Learning rate and scheduler parameters (unused for eval)
|
||||
lr_scale: 0.1
|
||||
lr_transformer: ${times:8e-4,${scratch.lr_scale}}
|
||||
lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}}
|
||||
lr_language_backbone: ${times:5e-5,${scratch.lr_scale}}
|
||||
lrd_vision_backbone: 0.9 # (lower for in-domain and higher for ood)
|
||||
wd: 0.1
|
||||
scheduler_timescale: 20
|
||||
scheduler_warmup: 20
|
||||
scheduler_cooldown: 20
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: ${scratch.max_data_epochs}
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train: null
|
||||
val: null
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_image_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
device: cpus
|
||||
eval_mode: true
|
||||
enable_segmentation: true # Warning: Enable this if using segmentation.
|
||||
checkpoint_path: ${paths.checkpoint_path}
|
||||
|
||||
meters:
|
||||
val: null
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
optimizer:
|
||||
_target_: torch.optim.AdamW
|
||||
|
||||
gradient_clip:
|
||||
_target_: sam3.train.optim.optimizer.GradientClipper
|
||||
max_norm: 0.1
|
||||
norm_type: 2
|
||||
|
||||
param_group_modifiers:
|
||||
- _target_: sam3.train.optim.optimizer.layer_decay_param_modifier
|
||||
_partial_: True
|
||||
layer_decay_value: ${scratch.lrd_vision_backbone}
|
||||
apply_to: 'backbone.vision_backbone.trunk'
|
||||
overrides:
|
||||
- pattern: '*pos_embed*'
|
||||
value: 1.0
|
||||
|
||||
options:
|
||||
lr:
|
||||
- scheduler: # transformer and class_embed
|
||||
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
|
||||
base_lr: ${scratch.lr_transformer}
|
||||
timescale: ${scratch.scheduler_timescale}
|
||||
warmup_steps: ${scratch.scheduler_warmup}
|
||||
cooldown_steps: ${scratch.scheduler_cooldown}
|
||||
- scheduler:
|
||||
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
|
||||
base_lr: ${scratch.lr_vision_backbone}
|
||||
timescale: ${scratch.scheduler_timescale}
|
||||
warmup_steps: ${scratch.scheduler_warmup}
|
||||
cooldown_steps: ${scratch.scheduler_cooldown}
|
||||
param_names:
|
||||
- 'backbone.vision_backbone.*'
|
||||
- scheduler:
|
||||
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
|
||||
base_lr: ${scratch.lr_language_backbone}
|
||||
timescale: ${scratch.scheduler_timescale}
|
||||
warmup_steps: ${scratch.scheduler_warmup}
|
||||
cooldown_steps: ${scratch.scheduler_cooldown}
|
||||
param_names:
|
||||
- 'backbone.language_backbone.*'
|
||||
|
||||
weight_decay:
|
||||
- scheduler:
|
||||
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
||||
value: ${scratch.wd}
|
||||
- scheduler:
|
||||
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
||||
value: 0.0
|
||||
param_names:
|
||||
- '*bias*'
|
||||
module_cls_names: ['torch.nn.LayerNorm']
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 4
|
||||
gpus_per_node: 8
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
|
||||
submitit:
|
||||
account: null # Add your SLURM account if use_cluster == 1
|
||||
partition: null
|
||||
qos: null # Add your QoS if use_cluster == 1
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
@@ -0,0 +1,66 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/gold_attributes/
|
||||
coco_gt: ${paths.base_annotation_path}/gold_attributes_merged_a_release_test.json
|
||||
coco_gts:
|
||||
- ${paths.base_annotation_path}/gold_attributes_merged_a_release_test.json
|
||||
- ${paths.base_annotation_path}/gold_attributes_merged_b_release_test.json
|
||||
- ${paths.base_annotation_path}/gold_attributes_merged_c_release_test.json
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.metaclip_img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: gold_attributes
|
||||
|
||||
meters:
|
||||
val:
|
||||
gold_attributes: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/gold_attributes
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gts}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gts}
|
||||
iou_type: "segm"
|
||||
@@ -0,0 +1,66 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/gold_crowded/
|
||||
coco_gt: ${paths.base_annotation_path}/gold_crowded_merged_a_release_test.json
|
||||
coco_gts:
|
||||
- ${paths.base_annotation_path}/gold_crowded_merged_a_release_test.json
|
||||
- ${paths.base_annotation_path}/gold_crowded_merged_b_release_test.json
|
||||
- ${paths.base_annotation_path}/gold_crowded_merged_c_release_test.json
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.metaclip_img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: gold_crowded
|
||||
|
||||
meters:
|
||||
val:
|
||||
gold_crowded: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/gold_crowded
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gts}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gts}
|
||||
iou_type: "segm"
|
||||
@@ -0,0 +1,66 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/gold_fg_food/
|
||||
coco_gt: ${paths.base_annotation_path}/gold_fg_food_merged_a_release_test.json
|
||||
coco_gts:
|
||||
- ${paths.base_annotation_path}/gold_fg_food_merged_a_release_test.json
|
||||
- ${paths.base_annotation_path}/gold_fg_food_merged_b_release_test.json
|
||||
- ${paths.base_annotation_path}/gold_fg_food_merged_c_release_test.json
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.metaclip_img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: gold_fg_food
|
||||
|
||||
meters:
|
||||
val:
|
||||
gold_fg_food: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/gold_fg_food
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gts}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gts}
|
||||
iou_type: "segm"
|
||||
@@ -0,0 +1,66 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/gold_fg_sports_equipment/
|
||||
coco_gt: ${paths.base_annotation_path}/gold_fg_sports_equipment_merged_a_release_test.json
|
||||
coco_gts:
|
||||
- ${paths.base_annotation_path}/gold_fg_sports_equipment_merged_a_release_test.json
|
||||
- ${paths.base_annotation_path}/gold_fg_sports_equipment_merged_b_release_test.json
|
||||
- ${paths.base_annotation_path}/gold_fg_sports_equipment_merged_c_release_test.json
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.metaclip_img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: gold_fg_sports_equipment
|
||||
|
||||
meters:
|
||||
val:
|
||||
gold_fg_sports_equipment: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/gold_fg_sports_equipment
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gts}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gts}
|
||||
iou_type: "segm"
|
||||
@@ -0,0 +1,66 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/gold_metaclip_nps/
|
||||
coco_gt: ${paths.base_annotation_path}/gold_metaclip_merged_a_release_test.json
|
||||
coco_gts:
|
||||
- ${paths.base_annotation_path}/gold_metaclip_merged_a_release_test.json
|
||||
- ${paths.base_annotation_path}/gold_metaclip_merged_b_release_test.json
|
||||
- ${paths.base_annotation_path}/gold_metaclip_merged_c_release_test.json
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.metaclip_img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: gold_metaclip_nps
|
||||
|
||||
meters:
|
||||
val:
|
||||
gold_metaclip_nps: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/gold_metaclip_nps
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gts}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gts}
|
||||
iou_type: "segm"
|
||||
@@ -0,0 +1,66 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/gold_sa1b_nps/
|
||||
coco_gt: ${paths.base_annotation_path}/gold_sa1b_merged_a_release_test.json
|
||||
coco_gts:
|
||||
- ${paths.base_annotation_path}/gold_sa1b_merged_a_release_test.json
|
||||
- ${paths.base_annotation_path}/gold_sa1b_merged_b_release_test.json
|
||||
- ${paths.base_annotation_path}/gold_sa1b_merged_c_release_test.json
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.sa1b_img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: gold_sa1b_nps
|
||||
|
||||
meters:
|
||||
val:
|
||||
gold_sa1b_nps: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/gold_sa1b_nps
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gts}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gts}
|
||||
iou_type: "segm"
|
||||
@@ -0,0 +1,66 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/gold_wiki_common/
|
||||
coco_gt: ${paths.base_annotation_path}/gold_wiki_common_merged_a_release_test.json
|
||||
coco_gts:
|
||||
- ${paths.base_annotation_path}/gold_wiki_common_merged_a_release_test.json
|
||||
- ${paths.base_annotation_path}/gold_wiki_common_merged_b_release_test.json
|
||||
- ${paths.base_annotation_path}/gold_wiki_common_merged_c_release_test.json
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.metaclip_img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: gold_wiki_common
|
||||
|
||||
meters:
|
||||
val:
|
||||
gold_wiki_common: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/gold_wiki_common
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gts}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gts}
|
||||
iou_type: "segm"
|
||||
255
sam3/train/configs/odinw13/odinw_text_and_visual.yaml
Normal file
255
sam3/train/configs/odinw13/odinw_text_and_visual.yaml
Normal file
@@ -0,0 +1,255 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
# python sam3/train/train.py -c configs/odinw13/odinw_text_and_visual.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS}
|
||||
|
||||
paths:
|
||||
odinw_data_root: <YOUR_DATA_DIR>
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
|
||||
supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}}
|
||||
# Validation transforms pipeline
|
||||
val_transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution}
|
||||
max_size:
|
||||
_target_: sam3.train.transforms.basic.get_random_resize_max_size
|
||||
size: ${scratch.resolution}
|
||||
square: true
|
||||
consistent_transform: False
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
- _target_: sam3.train.transforms.filter_query_transforms.TextQueryToVisual
|
||||
keep_text_queries: true # Note: set this to false if you only want visual
|
||||
probability: 1.0 # always
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
enable_segmentation: True
|
||||
# Box processing
|
||||
use_presence_eval: True
|
||||
original_box_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessImage
|
||||
max_dets_per_img: -1 # infinite detections
|
||||
use_original_ids: true
|
||||
use_original_sizes_box: true
|
||||
use_presence: ${scratch.use_presence_eval}
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
# Normalization parameters
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
# Training parameters
|
||||
val_batch_size: 2
|
||||
num_val_workers: 0
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
max_epochs: 1
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
|
||||
prompts: ${odinw35_prompts.${supercategory_tuple.name}}
|
||||
include_negatives: true
|
||||
category_chunk_size: 20 # Note: Since we are doing AP +ve we need to include all categories!
|
||||
_partial_: true
|
||||
img_folder: ${paths.odinw_data_root}/${supercategory_tuple.val.img_folder}
|
||||
ann_file:
|
||||
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
|
||||
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
|
||||
transforms: ${val_transforms}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: 1
|
||||
dict_key: odinw35
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_image_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
device: cpus
|
||||
eval_mode: true # Set to false if training
|
||||
enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
|
||||
|
||||
meters:
|
||||
val:
|
||||
odinw35:
|
||||
detection:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "bbox"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${supercategory_tuple.name}
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.original_box_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 100
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
|
||||
gt_path:
|
||||
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
|
||||
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
|
||||
tide: False
|
||||
iou_type: "bbox"
|
||||
positive_split: true
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/${supercategory_tuple.name}
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 1
|
||||
gpus_per_node: 2
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
|
||||
job_array:
|
||||
num_tasks: 13
|
||||
task_index: 0
|
||||
|
||||
# ============================================================================
|
||||
# ODinW13 Supercategories
|
||||
# ============================================================================
|
||||
|
||||
all_odinw_supercategories:
|
||||
- name: AerialMaritimeDrone_large
|
||||
val:
|
||||
img_folder: AerialMaritimeDrone/large/test/
|
||||
json: AerialMaritimeDrone/large/test/annotations_without_background.json
|
||||
- name: Aquarium
|
||||
val:
|
||||
img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/
|
||||
json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json
|
||||
- name: CottontailRabbits
|
||||
val:
|
||||
img_folder: CottontailRabbits/test/
|
||||
json: CottontailRabbits/test/annotations_without_background.json
|
||||
- name: EgoHands_generic
|
||||
val:
|
||||
img_folder: EgoHands/generic/test/
|
||||
json: EgoHands/generic/test/annotations_without_background.json
|
||||
- name: NorthAmericaMushrooms
|
||||
val:
|
||||
img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/
|
||||
json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json
|
||||
- name: Packages
|
||||
val:
|
||||
img_folder: Packages/Raw/test/
|
||||
json: Packages/Raw/test/annotations_without_background.json
|
||||
- name: PascalVOC
|
||||
val:
|
||||
img_folder: PascalVOC/valid/
|
||||
json: PascalVOC/valid/annotations_without_background.json
|
||||
- name: Raccoon
|
||||
val:
|
||||
img_folder: Raccoon/Raccoon.v2-raw.coco/test/
|
||||
json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json
|
||||
- name: ShellfishOpenImages
|
||||
val:
|
||||
img_folder: ShellfishOpenImages/raw/test/
|
||||
json: ShellfishOpenImages/raw/test/annotations_without_background.json
|
||||
- name: VehiclesOpenImages
|
||||
val:
|
||||
img_folder: VehiclesOpenImages/416x416/test/
|
||||
json: VehiclesOpenImages/416x416/test/annotations_without_background.json
|
||||
- name: pistols
|
||||
val:
|
||||
img_folder: pistols/export/
|
||||
json: pistols/export/test_annotations_without_background.json
|
||||
- name: pothole
|
||||
val:
|
||||
img_folder: pothole/test/
|
||||
json: pothole/test/annotations_without_background.json
|
||||
- name: thermalDogsAndPeople
|
||||
val:
|
||||
img_folder: thermalDogsAndPeople/test/
|
||||
json: thermalDogsAndPeople/test/annotations_without_background.json
|
||||
|
||||
|
||||
odinw35_prompts:
|
||||
AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"},
|
||||
{"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock",
|
||||
"supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"},
|
||||
{"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]'
|
||||
Aquarium: null
|
||||
CottontailRabbits: null
|
||||
EgoHands_generic: null
|
||||
NorthAmericaMushrooms: '[{''id'': 1, ''name'':
|
||||
''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]'
|
||||
Packages: null
|
||||
PascalVOC: null
|
||||
Raccoon: null
|
||||
ShellfishOpenImages: null
|
||||
VehiclesOpenImages: null
|
||||
pistols: null
|
||||
pothole: null
|
||||
thermalDogsAndPeople: null
|
||||
253
sam3/train/configs/odinw13/odinw_text_only.yaml
Normal file
253
sam3/train/configs/odinw13/odinw_text_only.yaml
Normal file
@@ -0,0 +1,253 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
# python sam3/train/train.py -c configs/odinw13/odinw_text_only.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS}
|
||||
|
||||
paths:
|
||||
odinw_data_root: <YOUR_DATA_DIR>
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
|
||||
|
||||
supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}}
|
||||
# Validation transforms pipeline
|
||||
val_transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution}
|
||||
max_size:
|
||||
_target_: sam3.train.transforms.basic.get_random_resize_max_size
|
||||
size: ${scratch.resolution}
|
||||
square: true
|
||||
consistent_transform: False
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
enable_segmentation: True
|
||||
# Box processing
|
||||
use_presence_eval: True
|
||||
original_box_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessImage
|
||||
max_dets_per_img: -1 # infinite detections
|
||||
use_original_ids: true
|
||||
use_original_sizes_box: true
|
||||
use_presence: ${scratch.use_presence_eval}
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
# Normalization parameters
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
# Training parameters
|
||||
val_batch_size: 2
|
||||
num_val_workers: 0
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
max_epochs: 1
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
|
||||
prompts: ${odinw35_prompts.${supercategory_tuple.name}}
|
||||
include_negatives: true
|
||||
category_chunk_size: 20 # Note: Since we are doing AP +ve we need to include all categories!
|
||||
_partial_: true
|
||||
img_folder: ${paths.odinw_data_root}/${supercategory_tuple.val.img_folder}
|
||||
ann_file:
|
||||
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
|
||||
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
|
||||
transforms: ${val_transforms}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: 1
|
||||
dict_key: odinw35
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_image_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
device: cpus
|
||||
eval_mode: true # Set to false if training
|
||||
enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
|
||||
|
||||
meters:
|
||||
val:
|
||||
odinw35:
|
||||
detection:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "bbox"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/odinw/${supercategory_tuple.name}
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.original_box_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 100
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
|
||||
gt_path:
|
||||
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
|
||||
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
|
||||
tide: False
|
||||
iou_type: "bbox"
|
||||
positive_split: False
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/${supercategory_tuple.name}
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 1
|
||||
gpus_per_node: 2
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
|
||||
job_array:
|
||||
num_tasks: 13
|
||||
task_index: 0
|
||||
|
||||
# ============================================================================
|
||||
# ODinW13 Supercategories
|
||||
# ============================================================================
|
||||
|
||||
all_odinw_supercategories:
|
||||
- name: AerialMaritimeDrone_large
|
||||
val:
|
||||
img_folder: AerialMaritimeDrone/large/test/
|
||||
json: AerialMaritimeDrone/large/test/annotations_without_background.json
|
||||
- name: Aquarium
|
||||
val:
|
||||
img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/
|
||||
json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json
|
||||
- name: CottontailRabbits
|
||||
val:
|
||||
img_folder: CottontailRabbits/test/
|
||||
json: CottontailRabbits/test/annotations_without_background.json
|
||||
- name: EgoHands_generic
|
||||
val:
|
||||
img_folder: EgoHands/generic/test/
|
||||
json: EgoHands/generic/test/annotations_without_background.json
|
||||
- name: NorthAmericaMushrooms
|
||||
val:
|
||||
img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/
|
||||
json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json
|
||||
- name: Packages
|
||||
val:
|
||||
img_folder: Packages/Raw/test/
|
||||
json: Packages/Raw/test/annotations_without_background.json
|
||||
- name: PascalVOC
|
||||
val:
|
||||
img_folder: PascalVOC/valid/
|
||||
json: PascalVOC/valid/annotations_without_background.json
|
||||
- name: Raccoon
|
||||
val:
|
||||
img_folder: Raccoon/Raccoon.v2-raw.coco/test/
|
||||
json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json
|
||||
- name: ShellfishOpenImages
|
||||
val:
|
||||
img_folder: ShellfishOpenImages/raw/test/
|
||||
json: ShellfishOpenImages/raw/test/annotations_without_background.json
|
||||
- name: VehiclesOpenImages
|
||||
val:
|
||||
img_folder: VehiclesOpenImages/416x416/test/
|
||||
json: VehiclesOpenImages/416x416/test/annotations_without_background.json
|
||||
- name: pistols
|
||||
val:
|
||||
img_folder: pistols/export/
|
||||
json: pistols/export/test_annotations_without_background.json
|
||||
- name: pothole
|
||||
val:
|
||||
img_folder: pothole/test/
|
||||
json: pothole/test/annotations_without_background.json
|
||||
- name: thermalDogsAndPeople
|
||||
val:
|
||||
img_folder: thermalDogsAndPeople/test/
|
||||
json: thermalDogsAndPeople/test/annotations_without_background.json
|
||||
|
||||
|
||||
odinw35_prompts:
|
||||
AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"},
|
||||
{"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock",
|
||||
"supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"},
|
||||
{"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]'
|
||||
Aquarium: null
|
||||
CottontailRabbits: null
|
||||
EgoHands_generic: null
|
||||
NorthAmericaMushrooms: '[{''id'': 1, ''name'':
|
||||
''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]'
|
||||
Packages: null
|
||||
PascalVOC: null
|
||||
Raccoon: null
|
||||
ShellfishOpenImages: null
|
||||
VehiclesOpenImages: null
|
||||
pistols: null
|
||||
pothole: null
|
||||
thermalDogsAndPeople: null
|
||||
253
sam3/train/configs/odinw13/odinw_text_only_positive.yaml
Normal file
253
sam3/train/configs/odinw13/odinw_text_only_positive.yaml
Normal file
@@ -0,0 +1,253 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
# python sam3/train/train.py -c configs/odinw13/odinw_text_only_positive.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS}
|
||||
|
||||
paths:
|
||||
odinw_data_root: <YOUR_DATA_DIR>
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
|
||||
|
||||
supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}}
|
||||
# Validation transforms pipeline
|
||||
val_transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution}
|
||||
max_size:
|
||||
_target_: sam3.train.transforms.basic.get_random_resize_max_size
|
||||
size: ${scratch.resolution}
|
||||
square: true
|
||||
consistent_transform: False
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
enable_segmentation: True
|
||||
# Box processing
|
||||
use_presence_eval: True
|
||||
original_box_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessImage
|
||||
max_dets_per_img: -1 # infinite detections
|
||||
use_original_ids: true
|
||||
use_original_sizes_box: true
|
||||
use_presence: ${scratch.use_presence_eval}
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
# Normalization parameters
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
# Training parameters
|
||||
val_batch_size: 2
|
||||
num_val_workers: 0
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
max_epochs: 1
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
|
||||
prompts: ${odinw35_prompts.${supercategory_tuple.name}}
|
||||
include_negatives: true
|
||||
category_chunk_size: 20 # Note: Since we are doing AP +ve we need to include all categories!
|
||||
_partial_: true
|
||||
img_folder: ${paths.odinw_data_root}/${supercategory_tuple.val.img_folder}
|
||||
ann_file:
|
||||
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
|
||||
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
|
||||
transforms: ${val_transforms}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: 1
|
||||
dict_key: odinw35
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_image_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
device: cpus
|
||||
eval_mode: true # Set to false if training
|
||||
enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
|
||||
|
||||
meters:
|
||||
val:
|
||||
odinw35:
|
||||
detection:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "bbox"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${supercategory_tuple.name}
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.original_box_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 100
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
|
||||
gt_path:
|
||||
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
|
||||
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
|
||||
tide: False
|
||||
iou_type: "bbox"
|
||||
positive_split: true
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/${supercategory_tuple.name}
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 1
|
||||
gpus_per_node: 2
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
|
||||
job_array:
|
||||
num_tasks: 13
|
||||
task_index: 0
|
||||
|
||||
# ============================================================================
|
||||
# ODinW13 Supercategories
|
||||
# ============================================================================
|
||||
|
||||
all_odinw_supercategories:
|
||||
- name: AerialMaritimeDrone_large
|
||||
val:
|
||||
img_folder: AerialMaritimeDrone/large/test/
|
||||
json: AerialMaritimeDrone/large/test/annotations_without_background.json
|
||||
- name: Aquarium
|
||||
val:
|
||||
img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/
|
||||
json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json
|
||||
- name: CottontailRabbits
|
||||
val:
|
||||
img_folder: CottontailRabbits/test/
|
||||
json: CottontailRabbits/test/annotations_without_background.json
|
||||
- name: EgoHands_generic
|
||||
val:
|
||||
img_folder: EgoHands/generic/test/
|
||||
json: EgoHands/generic/test/annotations_without_background.json
|
||||
- name: NorthAmericaMushrooms
|
||||
val:
|
||||
img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/
|
||||
json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json
|
||||
- name: Packages
|
||||
val:
|
||||
img_folder: Packages/Raw/test/
|
||||
json: Packages/Raw/test/annotations_without_background.json
|
||||
- name: PascalVOC
|
||||
val:
|
||||
img_folder: PascalVOC/valid/
|
||||
json: PascalVOC/valid/annotations_without_background.json
|
||||
- name: Raccoon
|
||||
val:
|
||||
img_folder: Raccoon/Raccoon.v2-raw.coco/test/
|
||||
json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json
|
||||
- name: ShellfishOpenImages
|
||||
val:
|
||||
img_folder: ShellfishOpenImages/raw/test/
|
||||
json: ShellfishOpenImages/raw/test/annotations_without_background.json
|
||||
- name: VehiclesOpenImages
|
||||
val:
|
||||
img_folder: VehiclesOpenImages/416x416/test/
|
||||
json: VehiclesOpenImages/416x416/test/annotations_without_background.json
|
||||
- name: pistols
|
||||
val:
|
||||
img_folder: pistols/export/
|
||||
json: pistols/export/test_annotations_without_background.json
|
||||
- name: pothole
|
||||
val:
|
||||
img_folder: pothole/test/
|
||||
json: pothole/test/annotations_without_background.json
|
||||
- name: thermalDogsAndPeople
|
||||
val:
|
||||
img_folder: thermalDogsAndPeople/test/
|
||||
json: thermalDogsAndPeople/test/annotations_without_background.json
|
||||
|
||||
|
||||
odinw35_prompts:
|
||||
AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"},
|
||||
{"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock",
|
||||
"supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"},
|
||||
{"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]'
|
||||
Aquarium: null
|
||||
CottontailRabbits: null
|
||||
EgoHands_generic: null
|
||||
NorthAmericaMushrooms: '[{''id'': 1, ''name'':
|
||||
''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]'
|
||||
Packages: null
|
||||
PascalVOC: null
|
||||
Raccoon: null
|
||||
ShellfishOpenImages: null
|
||||
VehiclesOpenImages: null
|
||||
pistols: null
|
||||
pothole: null
|
||||
thermalDogsAndPeople: null
|
||||
591
sam3/train/configs/odinw13/odinw_text_only_train.yaml
Normal file
591
sam3/train/configs/odinw13/odinw_text_only_train.yaml
Normal file
@@ -0,0 +1,591 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
# python sam3/train/train.py -c configs/odinw_text_only.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS}
|
||||
|
||||
paths:
|
||||
odinw_data_root: <YOUR_DATA_DIR>
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
|
||||
|
||||
odinw_train:
|
||||
train_file: fewshot_train_shot10_seed300
|
||||
num_images: null
|
||||
supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}}
|
||||
# Training transforms pipeline
|
||||
train_transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
|
||||
query_filter:
|
||||
_target_: sam3.train.transforms.filter_query_transforms.FilterCrowds
|
||||
- _target_: sam3.train.transforms.point_sampling.RandomizeInputBbox
|
||||
box_noise_std: 0.1
|
||||
box_noise_max: 20
|
||||
- _target_: sam3.train.transforms.segmentation.DecodeRle
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes:
|
||||
_target_: sam3.train.transforms.basic.get_random_resize_scales
|
||||
size: ${scratch.resolution}
|
||||
min_size: 480
|
||||
rounded: false
|
||||
max_size:
|
||||
_target_: sam3.train.transforms.basic.get_random_resize_max_size
|
||||
size: ${scratch.resolution}
|
||||
square: true
|
||||
consistent_transform: ${scratch.consistent_transform}
|
||||
- _target_: sam3.train.transforms.basic_for_api.PadToSizeAPI
|
||||
size: ${scratch.resolution}
|
||||
consistent_transform: ${scratch.consistent_transform}
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
|
||||
query_filter:
|
||||
_target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.train_norm_mean}
|
||||
std: ${scratch.train_norm_std}
|
||||
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
|
||||
query_filter:
|
||||
_target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
|
||||
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
|
||||
query_filter:
|
||||
_target_: sam3.train.transforms.filter_query_transforms.FilterFindQueriesWithTooManyOut
|
||||
max_num_objects: ${scratch.max_ann_per_img}
|
||||
|
||||
# Validation transforms pipeline
|
||||
val_transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution}
|
||||
max_size:
|
||||
_target_: sam3.train.transforms.basic.get_random_resize_max_size
|
||||
size: ${scratch.resolution}
|
||||
square: true
|
||||
consistent_transform: False
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
# loss config (no mask loss)
|
||||
loss:
|
||||
_target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
|
||||
matcher: ${scratch.matcher}
|
||||
o2m_weight: 2.0
|
||||
o2m_matcher:
|
||||
_target_: sam3.train.matcher.BinaryOneToManyMatcher
|
||||
alpha: 0.3
|
||||
threshold: 0.4
|
||||
topk: 4
|
||||
use_o2m_matcher_on_o2m_aux: ${scratch.use_o2m_matcher_on_o2m_aux}
|
||||
loss_fns_find:
|
||||
- _target_: sam3.train.loss.loss_fns.Boxes
|
||||
weight_dict:
|
||||
loss_bbox: 5.0
|
||||
loss_giou: 2.0
|
||||
- _target_: sam3.train.loss.loss_fns.IABCEMdetr
|
||||
weak_loss: False
|
||||
weight_dict:
|
||||
loss_ce: ${scratch.loss_ce_weight} # Change
|
||||
presence_loss: ${scratch.presence_weight} # Change
|
||||
pos_weight: ${scratch.iabce_pos_weight}
|
||||
alpha: ${scratch.iabce_alpha}
|
||||
gamma: 2
|
||||
use_presence: True # Change
|
||||
pos_focal: ${scratch.iabce_pos_focal}
|
||||
pad_n_queries: ${scratch.num_queries}
|
||||
pad_scale_pos: ${scratch.instance_query_loss_pad_scale_pos}
|
||||
|
||||
loss_fn_semantic_seg: null
|
||||
scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
enable_segmentation: False
|
||||
use_act_checkpoint_geo_encoder: True
|
||||
input_geometry_encoder:
|
||||
_target_: sam3.model.geometry_encoders.SequenceGeometryEncoder
|
||||
pos_enc: ${scratch.pos_embed}
|
||||
encode_boxes_as_points: False
|
||||
points_direct_project: True
|
||||
points_pool: True
|
||||
points_pos_enc: True
|
||||
boxes_direct_project: True
|
||||
boxes_pool: True
|
||||
boxes_pos_enc: True
|
||||
d_model: ${scratch.d_model}
|
||||
num_layers: 3
|
||||
use_act_ckpt: ${scratch.use_act_checkpoint_geo_encoder}
|
||||
layer:
|
||||
_target_: sam3.model.encoder.TransformerEncoderLayer
|
||||
activation: "relu"
|
||||
d_model: ${scratch.d_model}
|
||||
dim_feedforward: 2048
|
||||
dropout: ${scratch.encoder_dropout}
|
||||
pos_enc_at_attn: false
|
||||
pre_norm: True
|
||||
pos_enc_at_cross_attn_queries: false
|
||||
pos_enc_at_cross_attn_keys: true
|
||||
self_attention:
|
||||
_target_: sam3.model.attention.MultiheadAttention
|
||||
attn_type: Vanilla
|
||||
num_heads: 8
|
||||
dropout: ${scratch.encoder_dropout}
|
||||
embed_dim: ${scratch.d_model}
|
||||
batch_first: False
|
||||
cross_attention:
|
||||
_target_: sam3.model.attention.MultiheadAttention
|
||||
attn_type: Vanilla
|
||||
num_heads: 8
|
||||
dropout: ${scratch.encoder_dropout}
|
||||
embed_dim: ${scratch.d_model}
|
||||
batch_first: False
|
||||
add_cls: true
|
||||
add_post_encode_proj: True
|
||||
|
||||
boxRPB: "log"
|
||||
dac: True
|
||||
use_early_fusion: true
|
||||
o2m_mask: false
|
||||
num_feature_levels: 1 # > 1 not implemented
|
||||
encoder_dropout: 0.1
|
||||
decoder_dropout: 0.1
|
||||
|
||||
tokenizer_ve:
|
||||
_target_: sam3.model.tokenizer_ve.SimpleTokenizer
|
||||
bpe_path: ${paths.bpe_path}
|
||||
|
||||
|
||||
freeze_text_tower: False
|
||||
freeze_image_tower: NoFreeze
|
||||
vis_backbone_dp: 0.0
|
||||
# Activation checkpointing (Save memory)
|
||||
use_act_checkpoint_vision_backbone: True
|
||||
use_act_checkpoint_text_backbone: True
|
||||
use_act_checkpoint_encoder: True
|
||||
use_act_checkpoint_decoder: True
|
||||
|
||||
loss: null
|
||||
# Loss parameters
|
||||
num_queries: 200
|
||||
presence_weight: 20.0
|
||||
loss_ce_weight: 20.0
|
||||
iabce_pos_weight: 5.0
|
||||
iabce_pos_focal: false
|
||||
iabce_alpha: 0.25
|
||||
instance_query_loss_pad_scale_pos: 1.0
|
||||
use_o2m_matcher_on_o2m_aux: false
|
||||
|
||||
# Model parameters
|
||||
use_instance_query: true
|
||||
d_model: 256
|
||||
pos_embed:
|
||||
_target_: sam3.model.position_encoding.PositionEmbeddingSine
|
||||
num_pos_feats: ${scratch.d_model}
|
||||
normalize: true
|
||||
scale: null
|
||||
temperature: 10000
|
||||
|
||||
# Box processing
|
||||
use_presence_eval: True
|
||||
original_box_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessImage
|
||||
max_dets_per_img: -1 # infinite detections
|
||||
use_original_ids: true
|
||||
use_original_sizes_box: true
|
||||
use_presence: ${scratch.use_presence_eval}
|
||||
|
||||
|
||||
# Matcher configuration
|
||||
matcher:
|
||||
_target_: sam3.train.matcher.BinaryHungarianMatcherV2
|
||||
focal: true
|
||||
cost_class: 2.0
|
||||
cost_bbox: 5.0
|
||||
cost_giou: 2.0
|
||||
alpha: 0.25
|
||||
gamma: 2
|
||||
stable: False
|
||||
scale_by_find_batch_size: True
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
consistent_transform: False
|
||||
max_ann_per_img: 200
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
# Training parameters
|
||||
train_batch_size: 1
|
||||
val_batch_size: 1
|
||||
num_train_workers: 0
|
||||
num_val_workers: 0
|
||||
max_data_epochs: 40
|
||||
target_epoch_size: 1500
|
||||
hybrid_repeats: 1
|
||||
context_length: 2
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
# Learning rate and scheduler parameters
|
||||
lr_scale: 0.1
|
||||
lr_transformer: ${times:8e-4,${scratch.lr_scale}}
|
||||
lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}}
|
||||
lr_language_backbone: ${times:5e-5,${scratch.lr_scale}}
|
||||
lrd_vision_backbone: 0.9
|
||||
wd: 0.1
|
||||
scheduler_timescale: 20
|
||||
scheduler_warmup: 20
|
||||
scheduler_cooldown: 20
|
||||
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
# _target_: sam3.train.trainer.Trainer
|
||||
# skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: ${scratch.max_data_epochs}
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: train
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all: ${odinw_train.loss}
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
limit_ids: ${odinw_train.num_images}
|
||||
transforms: ${odinw_train.train_transforms}
|
||||
load_segmentation: ${scratch.enable_segmentation}
|
||||
max_ann_per_img: 500000
|
||||
multiplier: 1
|
||||
max_train_queries: 50000
|
||||
max_val_queries: 50000
|
||||
training: true
|
||||
use_caching: False
|
||||
img_folder: ${paths.odinw_data_root}/${odinw_train.supercategory_tuple.train.img_folder}
|
||||
ann_file:
|
||||
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
|
||||
input_json_path: ${paths.odinw_data_root}/${odinw_train.supercategory_tuple.train.json}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
|
||||
prompts: ${odinw35_prompts.${odinw_train.supercategory_tuple.name}}
|
||||
_partial_: true
|
||||
shuffle: True
|
||||
batch_size: ${scratch.train_batch_size}
|
||||
num_workers: ${scratch.num_train_workers}
|
||||
pin_memory: False
|
||||
drop_last: True
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: all
|
||||
with_seg_masks: ${scratch.enable_segmentation}
|
||||
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
load_segmentation: ${scratch.enable_segmentation}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
|
||||
prompts: ${odinw35_prompts.${odinw_train.supercategory_tuple.name}}
|
||||
include_negatives: true
|
||||
category_chunk_size: 20 # Note: Since we are doing AP +ve we need to include all categories!
|
||||
_partial_: true
|
||||
img_folder: ${paths.odinw_data_root}/${odinw_train.supercategory_tuple.val.img_folder}
|
||||
ann_file:
|
||||
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
|
||||
input_json_path: ${paths.odinw_data_root}/${odinw_train.supercategory_tuple.val.json}
|
||||
transforms: ${odinw_train.val_transforms}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: 1
|
||||
dict_key: odinw35
|
||||
with_seg_masks: ${scratch.enable_segmentation}
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_image_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
device: cpus
|
||||
eval_mode: false # Set to false if training
|
||||
enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
|
||||
|
||||
meters:
|
||||
val:
|
||||
odinw35:
|
||||
detection:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "bbox"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/odinw/${odinw_train.supercategory_tuple.name}
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.original_box_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 100
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
|
||||
gt_path:
|
||||
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
|
||||
input_json_path: ${paths.odinw_data_root}/${odinw_train.supercategory_tuple.val.json}
|
||||
tide: False
|
||||
iou_type: "bbox"
|
||||
positive_split: False
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
optimizer:
|
||||
_target_: torch.optim.AdamW
|
||||
|
||||
gradient_clip:
|
||||
_target_: sam3.train.optim.optimizer.GradientClipper
|
||||
max_norm: 0.1
|
||||
norm_type: 2
|
||||
|
||||
param_group_modifiers:
|
||||
- _target_: sam3.train.optim.optimizer.layer_decay_param_modifier
|
||||
_partial_: True
|
||||
layer_decay_value: ${scratch.lrd_vision_backbone}
|
||||
apply_to: 'backbone.vision_backbone.trunk'
|
||||
overrides:
|
||||
- pattern: '*pos_embed*'
|
||||
value: 1.0
|
||||
|
||||
options:
|
||||
lr:
|
||||
- scheduler: # transformer and class_embed
|
||||
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
|
||||
base_lr: ${scratch.lr_transformer}
|
||||
timescale: ${scratch.scheduler_timescale}
|
||||
warmup_steps: ${scratch.scheduler_warmup}
|
||||
cooldown_steps: ${scratch.scheduler_cooldown}
|
||||
- scheduler:
|
||||
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
|
||||
base_lr: ${scratch.lr_vision_backbone}
|
||||
timescale: ${scratch.scheduler_timescale}
|
||||
warmup_steps: ${scratch.scheduler_warmup}
|
||||
cooldown_steps: ${scratch.scheduler_cooldown}
|
||||
param_names:
|
||||
- 'backbone.vision_backbone.*'
|
||||
- scheduler:
|
||||
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
|
||||
base_lr: ${scratch.lr_language_backbone}
|
||||
timescale: ${scratch.scheduler_timescale}
|
||||
warmup_steps: ${scratch.scheduler_warmup}
|
||||
cooldown_steps: ${scratch.scheduler_cooldown}
|
||||
param_names:
|
||||
- 'backbone.language_backbone.*'
|
||||
|
||||
weight_decay:
|
||||
- scheduler:
|
||||
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
||||
value: ${scratch.wd}
|
||||
- scheduler:
|
||||
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
||||
value: 0.0
|
||||
param_names:
|
||||
- '*bias*'
|
||||
module_cls_names: ['torch.nn.LayerNorm']
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/${odinw_train.supercategory_tuple.name}
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 1
|
||||
gpus_per_node: 2
|
||||
experiment_log_dir: null #${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
|
||||
# task_index: 2
|
||||
# Uncomment for job array configuration
|
||||
job_array:
|
||||
num_tasks: 13
|
||||
task_index: 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ODinW13 Supercategories
|
||||
# ============================================================================
|
||||
|
||||
all_odinw_supercategories:
|
||||
- name: AerialMaritimeDrone_large
|
||||
val:
|
||||
img_folder: AerialMaritimeDrone/large/test/
|
||||
json: AerialMaritimeDrone/large/test/annotations_without_background.json
|
||||
train:
|
||||
img_folder: AerialMaritimeDrone/large/train/
|
||||
json: AerialMaritimeDrone/large/train/${odinw_train.train_file}.json
|
||||
- name: Aquarium
|
||||
val:
|
||||
img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/
|
||||
json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json
|
||||
train:
|
||||
img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/train/
|
||||
json: Aquarium/Aquarium Combined.v2-raw-1024.coco/train/${odinw_train.train_file}.json
|
||||
- name: CottontailRabbits
|
||||
val:
|
||||
img_folder: CottontailRabbits/test/
|
||||
json: CottontailRabbits/test/annotations_without_background.json
|
||||
train:
|
||||
img_folder: CottontailRabbits/train/
|
||||
json: CottontailRabbits/train/${odinw_train.train_file}.json
|
||||
- name: EgoHands_generic
|
||||
val:
|
||||
img_folder: EgoHands/generic/test/
|
||||
json: EgoHands/generic/test/annotations_without_background.json
|
||||
train:
|
||||
img_folder: EgoHands/generic/train/
|
||||
json: EgoHands/generic/train/${odinw_train.train_file}.json
|
||||
- name: NorthAmericaMushrooms
|
||||
val:
|
||||
img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/
|
||||
json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json
|
||||
train:
|
||||
img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/train/
|
||||
json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/train/${odinw_train.train_file}.json
|
||||
- name: Packages
|
||||
val:
|
||||
img_folder: Packages/Raw/test/
|
||||
json: Packages/Raw/test/annotations_without_background.json
|
||||
train:
|
||||
img_folder: Packages/Raw/train/
|
||||
json: Packages/Raw/train/${odinw_train.train_file}.json
|
||||
- name: PascalVOC
|
||||
val:
|
||||
img_folder: PascalVOC/valid/
|
||||
json: PascalVOC/valid/annotations_without_background.json
|
||||
train:
|
||||
img_folder: PascalVOC/train/
|
||||
json: PascalVOC/train/${odinw_train.train_file}.json
|
||||
- name: Raccoon
|
||||
val:
|
||||
img_folder: Raccoon/Raccoon.v2-raw.coco/test/
|
||||
json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json
|
||||
train:
|
||||
img_folder: Raccoon/Raccoon.v2-raw.coco/train/
|
||||
json: Raccoon/Raccoon.v2-raw.coco/train/${odinw_train.train_file}.json
|
||||
- name: ShellfishOpenImages
|
||||
val:
|
||||
img_folder: ShellfishOpenImages/raw/test/
|
||||
json: ShellfishOpenImages/raw/test/annotations_without_background.json
|
||||
train:
|
||||
img_folder: ShellfishOpenImages/raw/train/
|
||||
json: ShellfishOpenImages/raw/train/${odinw_train.train_file}.json
|
||||
- name: VehiclesOpenImages
|
||||
val:
|
||||
img_folder: VehiclesOpenImages/416x416/test/
|
||||
json: VehiclesOpenImages/416x416/test/annotations_without_background.json
|
||||
train:
|
||||
img_folder: VehiclesOpenImages/416x416/train/
|
||||
json: VehiclesOpenImages/416x416/train/${odinw_train.train_file}.json
|
||||
- name: pistols
|
||||
val:
|
||||
img_folder: pistols/export/
|
||||
json: pistols/export/test_annotations_without_background.json
|
||||
train:
|
||||
img_folder: pistols/export/
|
||||
json: pistols/export/${odinw_train.train_file}.json
|
||||
- name: pothole
|
||||
val:
|
||||
img_folder: pothole/test/
|
||||
json: pothole/test/annotations_without_background.json
|
||||
train:
|
||||
img_folder: pothole/train/
|
||||
json: pothole/train/${odinw_train.train_file}.json
|
||||
- name: thermalDogsAndPeople
|
||||
val:
|
||||
img_folder: thermalDogsAndPeople/test/
|
||||
json: thermalDogsAndPeople/test/annotations_without_background.json
|
||||
train:
|
||||
img_folder: thermalDogsAndPeople/train/
|
||||
json: thermalDogsAndPeople/train/${odinw_train.train_file}.json
|
||||
|
||||
|
||||
odinw35_prompts:
|
||||
AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"},
|
||||
{"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock",
|
||||
"supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"},
|
||||
{"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]'
|
||||
Aquarium: null
|
||||
CottontailRabbits: null
|
||||
EgoHands_generic: null
|
||||
NorthAmericaMushrooms: '[{''id'': 1, ''name'':
|
||||
''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]'
|
||||
Packages: null
|
||||
PascalVOC: null
|
||||
Raccoon: null
|
||||
ShellfishOpenImages: null
|
||||
VehiclesOpenImages: null
|
||||
pistols: null
|
||||
pothole: null
|
||||
thermalDogsAndPeople: null
|
||||
256
sam3/train/configs/odinw13/odinw_visual_only.yaml
Normal file
256
sam3/train/configs/odinw13/odinw_visual_only.yaml
Normal file
@@ -0,0 +1,256 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
# python sam3/train/train.py -c configs/odinw_text_only.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS}
|
||||
|
||||
paths:
|
||||
odinw_data_root: <YOUR_DATA_DIR>
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
|
||||
|
||||
supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}}
|
||||
# Validation transforms pipeline
|
||||
val_transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution}
|
||||
max_size:
|
||||
_target_: sam3.train.transforms.basic.get_random_resize_max_size
|
||||
size: ${scratch.resolution}
|
||||
square: true
|
||||
consistent_transform: False
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
- _target_: sam3.train.transforms.filter_query_transforms.TextQueryToVisual
|
||||
keep_text_queries: false # Note: set this to false if you only want visual
|
||||
probability: 1.0 # always
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
enable_segmentation: True
|
||||
# Box processing
|
||||
use_presence_eval: True
|
||||
original_box_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessImage
|
||||
max_dets_per_img: -1 # infinite detections
|
||||
use_original_ids: true
|
||||
use_original_sizes_box: true
|
||||
use_presence: ${scratch.use_presence_eval}
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
# Normalization parameters
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
# Training parameters
|
||||
val_batch_size: 2
|
||||
num_val_workers: 0
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
max_epochs: 1
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
|
||||
prompts: ${odinw35_prompts.${supercategory_tuple.name}}
|
||||
include_negatives: true
|
||||
category_chunk_size: 20 # Note: Since we are doing AP +ve we need to include all categories!
|
||||
_partial_: true
|
||||
img_folder: ${paths.odinw_data_root}/${supercategory_tuple.val.img_folder}
|
||||
ann_file:
|
||||
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
|
||||
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
|
||||
transforms: ${val_transforms}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: 1
|
||||
dict_key: odinw35
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_image_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
device: cpus
|
||||
eval_mode: true # Set to false if training
|
||||
enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
|
||||
|
||||
meters:
|
||||
val:
|
||||
odinw35:
|
||||
detection:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "bbox"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${supercategory_tuple.name}
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.original_box_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 100
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
|
||||
gt_path:
|
||||
_target_: sam3.eval.coco_reindex.reindex_coco_to_temp
|
||||
input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json}
|
||||
tide: False
|
||||
iou_type: "bbox"
|
||||
positive_split: true
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/${supercategory_tuple.name}
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 1
|
||||
gpus_per_node: 2
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
|
||||
job_array:
|
||||
num_tasks: 13
|
||||
task_index: 0
|
||||
|
||||
# ============================================================================
|
||||
# ODinW13 Supercategories
|
||||
# ============================================================================
|
||||
|
||||
all_odinw_supercategories:
|
||||
- name: AerialMaritimeDrone_large
|
||||
val:
|
||||
img_folder: AerialMaritimeDrone/large/test/
|
||||
json: AerialMaritimeDrone/large/test/annotations_without_background.json
|
||||
- name: Aquarium
|
||||
val:
|
||||
img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/
|
||||
json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json
|
||||
- name: CottontailRabbits
|
||||
val:
|
||||
img_folder: CottontailRabbits/test/
|
||||
json: CottontailRabbits/test/annotations_without_background.json
|
||||
- name: EgoHands_generic
|
||||
val:
|
||||
img_folder: EgoHands/generic/test/
|
||||
json: EgoHands/generic/test/annotations_without_background.json
|
||||
- name: NorthAmericaMushrooms
|
||||
val:
|
||||
img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/
|
||||
json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json
|
||||
- name: Packages
|
||||
val:
|
||||
img_folder: Packages/Raw/test/
|
||||
json: Packages/Raw/test/annotations_without_background.json
|
||||
- name: PascalVOC
|
||||
val:
|
||||
img_folder: PascalVOC/valid/
|
||||
json: PascalVOC/valid/annotations_without_background.json
|
||||
- name: Raccoon
|
||||
val:
|
||||
img_folder: Raccoon/Raccoon.v2-raw.coco/test/
|
||||
json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json
|
||||
- name: ShellfishOpenImages
|
||||
val:
|
||||
img_folder: ShellfishOpenImages/raw/test/
|
||||
json: ShellfishOpenImages/raw/test/annotations_without_background.json
|
||||
- name: VehiclesOpenImages
|
||||
val:
|
||||
img_folder: VehiclesOpenImages/416x416/test/
|
||||
json: VehiclesOpenImages/416x416/test/annotations_without_background.json
|
||||
- name: pistols
|
||||
val:
|
||||
img_folder: pistols/export/
|
||||
json: pistols/export/test_annotations_without_background.json
|
||||
- name: pothole
|
||||
val:
|
||||
img_folder: pothole/test/
|
||||
json: pothole/test/annotations_without_background.json
|
||||
- name: thermalDogsAndPeople
|
||||
val:
|
||||
img_folder: thermalDogsAndPeople/test/
|
||||
json: thermalDogsAndPeople/test/annotations_without_background.json
|
||||
|
||||
|
||||
odinw35_prompts:
|
||||
AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"},
|
||||
{"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock",
|
||||
"supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"},
|
||||
{"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]'
|
||||
Aquarium: null
|
||||
CottontailRabbits: null
|
||||
EgoHands_generic: null
|
||||
NorthAmericaMushrooms: '[{''id'': 1, ''name'':
|
||||
''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]'
|
||||
Packages: null
|
||||
PascalVOC: null
|
||||
Raccoon: null
|
||||
ShellfishOpenImages: null
|
||||
VehiclesOpenImages: null
|
||||
pistols: null
|
||||
pothole: null
|
||||
thermalDogsAndPeople: null
|
||||
539
sam3/train/configs/roboflow_v100/roboflow_v100_eval.yaml
Normal file
539
sam3/train/configs/roboflow_v100/roboflow_v100_eval.yaml
Normal file
@@ -0,0 +1,539 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
paths:
|
||||
roboflow_vl_100_root: <YOUR_DATASET_DIR>
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
|
||||
# Roboflow dataset configuration
|
||||
roboflow_train:
|
||||
num_images: 100 # Note: This is the number of images used for training. If null, all images are used.
|
||||
supercategory: ${all_roboflow_supercategories.${string:${submitit.job_array.task_index}}}
|
||||
|
||||
# Training transforms pipeline
|
||||
train_transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
|
||||
query_filter:
|
||||
_target_: sam3.train.transforms.filter_query_transforms.FilterCrowds
|
||||
- _target_: sam3.train.transforms.point_sampling.RandomizeInputBbox
|
||||
box_noise_std: 0.1
|
||||
box_noise_max: 20
|
||||
- _target_: sam3.train.transforms.segmentation.DecodeRle
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes:
|
||||
_target_: sam3.train.transforms.basic.get_random_resize_scales
|
||||
size: ${scratch.resolution}
|
||||
min_size: 480
|
||||
rounded: false
|
||||
max_size:
|
||||
_target_: sam3.train.transforms.basic.get_random_resize_max_size
|
||||
size: ${scratch.resolution}
|
||||
square: true
|
||||
consistent_transform: ${scratch.consistent_transform}
|
||||
- _target_: sam3.train.transforms.basic_for_api.PadToSizeAPI
|
||||
size: ${scratch.resolution}
|
||||
consistent_transform: ${scratch.consistent_transform}
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
|
||||
query_filter:
|
||||
_target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.train_norm_mean}
|
||||
std: ${scratch.train_norm_std}
|
||||
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
|
||||
query_filter:
|
||||
_target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
|
||||
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
|
||||
query_filter:
|
||||
_target_: sam3.train.transforms.filter_query_transforms.FilterFindQueriesWithTooManyOut
|
||||
max_num_objects: ${scratch.max_ann_per_img}
|
||||
|
||||
# Validation transforms pipeline
|
||||
val_transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution}
|
||||
max_size:
|
||||
_target_: sam3.train.transforms.basic.get_random_resize_max_size
|
||||
size: ${scratch.resolution}
|
||||
square: true
|
||||
consistent_transform: False
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.train_norm_mean}
|
||||
std: ${scratch.train_norm_std}
|
||||
|
||||
# loss config (no mask loss)
|
||||
loss:
|
||||
_target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
|
||||
matcher: ${scratch.matcher}
|
||||
o2m_weight: 2.0
|
||||
o2m_matcher:
|
||||
_target_: sam3.train.matcher.BinaryOneToManyMatcher
|
||||
alpha: 0.3
|
||||
threshold: 0.4
|
||||
topk: 4
|
||||
use_o2m_matcher_on_o2m_aux: false # Another option is true
|
||||
loss_fns_find:
|
||||
- _target_: sam3.train.loss.loss_fns.Boxes
|
||||
weight_dict:
|
||||
loss_bbox: 5.0
|
||||
loss_giou: 2.0
|
||||
- _target_: sam3.train.loss.loss_fns.IABCEMdetr
|
||||
weak_loss: False
|
||||
weight_dict:
|
||||
loss_ce: 20.0 # Another option is 100.0
|
||||
presence_loss: 20.0
|
||||
pos_weight: 10.0 # Another option is 5.0
|
||||
alpha: 0.25
|
||||
gamma: 2
|
||||
use_presence: True # Change
|
||||
pos_focal: false
|
||||
pad_n_queries: 200
|
||||
pad_scale_pos: 1.0
|
||||
|
||||
loss_fn_semantic_seg: null
|
||||
scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}
|
||||
|
||||
|
||||
# NOTE: Loss to be used for training in case of segmentation
|
||||
# loss:
|
||||
# _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
|
||||
# matcher: ${scratch.matcher}
|
||||
# o2m_weight: 2.0
|
||||
# o2m_matcher:
|
||||
# _target_: sam3.train.matcher.BinaryOneToManyMatcher
|
||||
# alpha: 0.3
|
||||
# threshold: 0.4
|
||||
# topk: 4
|
||||
# use_o2m_matcher_on_o2m_aux: false
|
||||
# loss_fns_find:
|
||||
# - _target_: sam3.train.loss.loss_fns.Boxes
|
||||
# weight_dict:
|
||||
# loss_bbox: 5.0
|
||||
# loss_giou: 2.0
|
||||
# - _target_: sam3.train.loss.loss_fns.IABCEMdetr
|
||||
# weak_loss: False
|
||||
# weight_dict:
|
||||
# loss_ce: 20.0 # Another option is 100.0
|
||||
# presence_loss: 20.0
|
||||
# pos_weight: 10.0 # Another option is 5.0
|
||||
# alpha: 0.25
|
||||
# gamma: 2
|
||||
# use_presence: True # Change
|
||||
# pos_focal: false
|
||||
# pad_n_queries: 200
|
||||
# pad_scale_pos: 1.0
|
||||
# - _target_: sam3.train.loss.loss_fns.Masks
|
||||
# focal_alpha: 0.25
|
||||
# focal_gamma: 2.0
|
||||
# weight_dict:
|
||||
# loss_mask: 200.0
|
||||
# loss_dice: 10.0
|
||||
# compute_aux: false
|
||||
# loss_fn_semantic_seg:
|
||||
# _target_: sam3.losses.loss_fns.SemanticSegCriterion
|
||||
# presence_head: True
|
||||
# presence_loss: False # Change
|
||||
# focal: True
|
||||
# focal_alpha: 0.6
|
||||
# focal_gamma: 2.0
|
||||
# downsample: False
|
||||
# weight_dict:
|
||||
# loss_semantic_seg: 20.0
|
||||
# loss_semantic_presence: 1.0
|
||||
# loss_semantic_dice: 30.0
|
||||
# scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
enable_segmentation: False # NOTE: Set to True to enable segmentation heads and mask losses.
|
||||
# Model parameters
|
||||
d_model: 256
|
||||
pos_embed:
|
||||
_target_: sam3.model.position_encoding.PositionEmbeddingSine
|
||||
num_pos_feats: ${scratch.d_model}
|
||||
normalize: true
|
||||
scale: null
|
||||
temperature: 10000
|
||||
|
||||
# Box processing
|
||||
use_presence_eval: True
|
||||
original_box_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessImage
|
||||
max_dets_per_img: -1 # infinite detections
|
||||
use_original_ids: true
|
||||
use_original_sizes_box: true
|
||||
use_presence: ${scratch.use_presence_eval}
|
||||
|
||||
# Matcher configuration
|
||||
matcher:
|
||||
_target_: sam3.train.matcher.BinaryHungarianMatcherV2
|
||||
focal: true # with `focal: true` it is equivalent to BinaryFocalHungarianMatcher
|
||||
cost_class: 2.0
|
||||
cost_bbox: 5.0
|
||||
cost_giou: 2.0
|
||||
alpha: 0.25
|
||||
gamma: 2
|
||||
stable: False
|
||||
scale_by_find_batch_size: True
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
consistent_transform: False
|
||||
max_ann_per_img: 200
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
# Training parameters
|
||||
num_train_workers: 10
|
||||
num_val_workers: 0
|
||||
max_data_epochs: 20
|
||||
target_epoch_size: 1500
|
||||
hybrid_repeats: 1
|
||||
context_length: 2
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
# Learning rate and scheduler parameters
|
||||
lr_scale: 0.1
|
||||
lr_transformer: ${times:8e-4,${scratch.lr_scale}}
|
||||
lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}}
|
||||
lr_language_backbone: ${times:5e-5,${scratch.lr_scale}}
|
||||
lrd_vision_backbone: 0.9
|
||||
wd: 0.1
|
||||
scheduler_timescale: 20
|
||||
scheduler_warmup: 20
|
||||
scheduler_cooldown: 20
|
||||
|
||||
val_batch_size: 1
|
||||
collate_fn_val:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: roboflow100
|
||||
with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
train_batch_size: 1
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: all
|
||||
with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: 20
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: val
|
||||
gradient_accumulation_steps: ${scratch.gradient_accumulation_steps}
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all: ${roboflow_train.loss}
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
limit_ids: ${roboflow_train.num_images}
|
||||
transforms: ${roboflow_train.train_transforms}
|
||||
load_segmentation: ${scratch.enable_segmentation}
|
||||
max_ann_per_img: 500000
|
||||
multiplier: 1
|
||||
max_train_queries: 50000
|
||||
max_val_queries: 50000
|
||||
training: true
|
||||
use_caching: False
|
||||
img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/
|
||||
ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/_annotations.coco.json
|
||||
|
||||
shuffle: True
|
||||
batch_size: ${scratch.train_batch_size}
|
||||
num_workers: ${scratch.num_train_workers}
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
collate_fn: ${scratch.collate_fn}
|
||||
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
load_segmentation: ${scratch.enable_segmentation}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
|
||||
include_negatives: true
|
||||
category_chunk_size: 2 # Note: You can increase this based on the memory of your GPU.
|
||||
_partial_: true
|
||||
img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/
|
||||
ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json
|
||||
transforms: ${roboflow_train.val_transforms}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: True
|
||||
drop_last: False
|
||||
collate_fn: ${scratch.collate_fn_val}
|
||||
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_image_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
device: cpus # NOTE(review): looks like a typo for "cpu" — verify accepted device strings in sam3.model_builder
|
||||
eval_mode: true
|
||||
enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
|
||||
|
||||
meters:
|
||||
val:
|
||||
roboflow100:
|
||||
detection:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "bbox"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${roboflow_train.supercategory}
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.original_box_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 100
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
|
||||
gt_path: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json
|
||||
tide: False
|
||||
iou_type: "bbox"
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
optimizer:
|
||||
_target_: torch.optim.AdamW
|
||||
|
||||
gradient_clip:
|
||||
_target_: sam3.train.optim.optimizer.GradientClipper
|
||||
max_norm: 0.1
|
||||
norm_type: 2
|
||||
|
||||
param_group_modifiers:
|
||||
- _target_: sam3.train.optim.optimizer.layer_decay_param_modifier
|
||||
_partial_: True
|
||||
layer_decay_value: ${scratch.lrd_vision_backbone}
|
||||
apply_to: 'backbone.vision_backbone.trunk'
|
||||
overrides:
|
||||
- pattern: '*pos_embed*'
|
||||
value: 1.0
|
||||
|
||||
options:
|
||||
lr:
|
||||
- scheduler: # transformer and class_embed
|
||||
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
|
||||
base_lr: ${scratch.lr_transformer}
|
||||
timescale: ${scratch.scheduler_timescale}
|
||||
warmup_steps: ${scratch.scheduler_warmup}
|
||||
cooldown_steps: ${scratch.scheduler_cooldown}
|
||||
- scheduler:
|
||||
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
|
||||
base_lr: ${scratch.lr_vision_backbone}
|
||||
timescale: ${scratch.scheduler_timescale}
|
||||
warmup_steps: ${scratch.scheduler_warmup}
|
||||
cooldown_steps: ${scratch.scheduler_cooldown}
|
||||
param_names:
|
||||
- 'backbone.vision_backbone.*'
|
||||
- scheduler:
|
||||
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
|
||||
base_lr: ${scratch.lr_language_backbone}
|
||||
timescale: ${scratch.scheduler_timescale}
|
||||
warmup_steps: ${scratch.scheduler_warmup}
|
||||
cooldown_steps: ${scratch.scheduler_cooldown}
|
||||
param_names:
|
||||
- 'backbone.language_backbone.*'
|
||||
|
||||
weight_decay:
|
||||
- scheduler:
|
||||
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
||||
value: ${scratch.wd}
|
||||
- scheduler:
|
||||
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
||||
value: 0.0
|
||||
param_names:
|
||||
- '*bias*'
|
||||
module_cls_names: ['torch.nn.LayerNorm']
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/${roboflow_train.supercategory}
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 1
|
||||
gpus_per_node: 2
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
# Job array configuration (one task per Roboflow supercategory; task_index selects the dataset)
|
||||
job_array:
|
||||
num_tasks: 100
|
||||
task_index: 0
|
||||
|
||||
# ============================================================================
|
||||
# Available Roboflow Supercategories (for reference)
|
||||
# ============================================================================
|
||||
|
||||
all_roboflow_supercategories:
|
||||
- -grccs
|
||||
- zebrasatasturias
|
||||
- cod-mw-warzone
|
||||
- canalstenosis
|
||||
- label-printing-defect-version-2
|
||||
- new-defects-in-wood
|
||||
- orionproducts
|
||||
- aquarium-combined
|
||||
- varroa-mites-detection--test-set
|
||||
- clashroyalechardetector
|
||||
- stomata-cells
|
||||
- halo-infinite-angel-videogame
|
||||
- pig-detection
|
||||
- urine-analysis1
|
||||
- aerial-sheep
|
||||
- orgharvest
|
||||
- actions
|
||||
- mahjong
|
||||
- liver-disease
|
||||
- needle-base-tip-min-max
|
||||
- wheel-defect-detection
|
||||
- aircraft-turnaround-dataset
|
||||
- xray
|
||||
- wildfire-smoke
|
||||
- spinefrxnormalvindr
|
||||
- ufba-425
|
||||
- speech-bubbles-detection
|
||||
- train
|
||||
- pill
|
||||
- truck-movement
|
||||
- car-logo-detection
|
||||
- inbreast
|
||||
- sea-cucumbers-new-tiles
|
||||
- uavdet-small
|
||||
- penguin-finder-seg
|
||||
- aerial-airport
|
||||
- bibdetection
|
||||
- taco-trash-annotations-in-context
|
||||
- bees
|
||||
- recode-waste
|
||||
- screwdetectclassification
|
||||
- wine-labels
|
||||
- aerial-cows
|
||||
- into-the-vale
|
||||
- gwhd2021
|
||||
- lacrosse-object-detection
|
||||
- defect-detection
|
||||
- dataconvert
|
||||
- x-ray-id
|
||||
- ball
|
||||
- tube
|
||||
- 2024-frc
|
||||
- crystal-clean-brain-tumors-mri-dataset
|
||||
- grapes-5
|
||||
- human-detection-in-floods
|
||||
- buoy-onboarding
|
||||
- apoce-aerial-photographs-for-object-detection-of-construction-equipment
|
||||
- l10ul502
|
||||
- floating-waste
|
||||
- deeppcb
|
||||
- ism-band-packet-detection
|
||||
- weeds4
|
||||
- invoice-processing
|
||||
- thermal-cheetah
|
||||
- tomatoes-2
|
||||
- marine-sharks
|
||||
- peixos-fish
|
||||
- sssod
|
||||
- aerial-pool
|
||||
- countingpills
|
||||
- asphaltdistressdetection
|
||||
- roboflow-trained-dataset
|
||||
- everdaynew
|
||||
- underwater-objects
|
||||
- soda-bottles
|
||||
- dentalai
|
||||
- jellyfish
|
||||
- deepfruits
|
||||
- activity-diagrams
|
||||
- circuit-voltages
|
||||
- all-elements
|
||||
- macro-segmentation
|
||||
- exploratorium-daphnia
|
||||
- signatures
|
||||
- conveyor-t-shirts
|
||||
- fruitjes
|
||||
- grass-weeds
|
||||
- infraredimageofpowerequipment
|
||||
- 13-lkc01
|
||||
- wb-prova
|
||||
- flir-camera-objects
|
||||
- paper-parts
|
||||
- football-player-detection
|
||||
- trail-camera
|
||||
- smd-components
|
||||
- water-meter
|
||||
- nih-xray
|
||||
- the-dreidel-project
|
||||
- electric-pylon-detection-in-rsi
|
||||
- cable-damage
|
||||
@@ -0,0 +1,539 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
paths:
|
||||
roboflow_vl_100_root: <YOUR_DATASET_DIR>
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
|
||||
# Roboflow dataset configuration
|
||||
roboflow_train:
|
||||
num_images: 100 # Note: This is the number of images used for training. If null, all images are used.
|
||||
supercategory: ${all_roboflow_supercategories.${string:${submitit.job_array.task_index}}}
|
||||
|
||||
# Training transforms pipeline
|
||||
train_transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
|
||||
query_filter:
|
||||
_target_: sam3.train.transforms.filter_query_transforms.FilterCrowds
|
||||
- _target_: sam3.train.transforms.point_sampling.RandomizeInputBbox
|
||||
box_noise_std: 0.1
|
||||
box_noise_max: 20
|
||||
- _target_: sam3.train.transforms.segmentation.DecodeRle
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes:
|
||||
_target_: sam3.train.transforms.basic.get_random_resize_scales
|
||||
size: ${scratch.resolution}
|
||||
min_size: 480
|
||||
rounded: false
|
||||
max_size:
|
||||
_target_: sam3.train.transforms.basic.get_random_resize_max_size
|
||||
size: ${scratch.resolution}
|
||||
square: true
|
||||
consistent_transform: ${scratch.consistent_transform}
|
||||
- _target_: sam3.train.transforms.basic_for_api.PadToSizeAPI
|
||||
size: ${scratch.resolution}
|
||||
consistent_transform: ${scratch.consistent_transform}
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
|
||||
query_filter:
|
||||
_target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.train_norm_mean}
|
||||
std: ${scratch.train_norm_std}
|
||||
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
|
||||
query_filter:
|
||||
_target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
|
||||
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
|
||||
query_filter:
|
||||
_target_: sam3.train.transforms.filter_query_transforms.FilterFindQueriesWithTooManyOut
|
||||
max_num_objects: ${scratch.max_ann_per_img}
|
||||
|
||||
# Validation transforms pipeline
|
||||
val_transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution}
|
||||
max_size:
|
||||
_target_: sam3.train.transforms.basic.get_random_resize_max_size
|
||||
size: ${scratch.resolution}
|
||||
square: true
|
||||
consistent_transform: False
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.train_norm_mean}
|
||||
std: ${scratch.train_norm_std}
|
||||
|
||||
# loss config (no mask loss)
|
||||
loss:
|
||||
_target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
|
||||
matcher: ${scratch.matcher}
|
||||
o2m_weight: 2.0
|
||||
o2m_matcher:
|
||||
_target_: sam3.train.matcher.BinaryOneToManyMatcher
|
||||
alpha: 0.3
|
||||
threshold: 0.4
|
||||
topk: 4
|
||||
use_o2m_matcher_on_o2m_aux: false # Another option is true
|
||||
loss_fns_find:
|
||||
- _target_: sam3.train.loss.loss_fns.Boxes
|
||||
weight_dict:
|
||||
loss_bbox: 5.0
|
||||
loss_giou: 2.0
|
||||
- _target_: sam3.train.loss.loss_fns.IABCEMdetr
|
||||
weak_loss: False
|
||||
weight_dict:
|
||||
loss_ce: 20.0 # Another option is 100.0
|
||||
presence_loss: 20.0
|
||||
pos_weight: 10.0 # Another option is 5.0
|
||||
alpha: 0.25
|
||||
gamma: 2
|
||||
use_presence: True # Change
|
||||
pos_focal: false
|
||||
pad_n_queries: 200
|
||||
pad_scale_pos: 1.0
|
||||
|
||||
loss_fn_semantic_seg: null
|
||||
scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}
|
||||
|
||||
|
||||
# NOTE: Loss to be used for training in case of segmentation
|
||||
# loss:
|
||||
# _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
|
||||
# matcher: ${scratch.matcher}
|
||||
# o2m_weight: 2.0
|
||||
# o2m_matcher:
|
||||
# _target_: sam3.train.matcher.BinaryOneToManyMatcher
|
||||
# alpha: 0.3
|
||||
# threshold: 0.4
|
||||
# topk: 4
|
||||
# use_o2m_matcher_on_o2m_aux: false
|
||||
# loss_fns_find:
|
||||
# - _target_: sam3.train.loss.loss_fns.Boxes
|
||||
# weight_dict:
|
||||
# loss_bbox: 5.0
|
||||
# loss_giou: 2.0
|
||||
# - _target_: sam3.train.loss.loss_fns.IABCEMdetr
|
||||
# weak_loss: False
|
||||
# weight_dict:
|
||||
# loss_ce: 20.0 # Another option is 100.0
|
||||
# presence_loss: 20.0
|
||||
# pos_weight: 10.0 # Another option is 5.0
|
||||
# alpha: 0.25
|
||||
# gamma: 2
|
||||
# use_presence: True # Change
|
||||
# pos_focal: false
|
||||
# pad_n_queries: 200
|
||||
# pad_scale_pos: 1.0
|
||||
# - _target_: sam3.train.loss.loss_fns.Masks
|
||||
# focal_alpha: 0.25
|
||||
# focal_gamma: 2.0
|
||||
# weight_dict:
|
||||
# loss_mask: 200.0
|
||||
# loss_dice: 10.0
|
||||
# compute_aux: false
|
||||
# loss_fn_semantic_seg:
|
||||
# _target_: sam3.losses.loss_fns.SemanticSegCriterion
|
||||
# presence_head: True
|
||||
# presence_loss: False # Change
|
||||
# focal: True
|
||||
# focal_alpha: 0.6
|
||||
# focal_gamma: 2.0
|
||||
# downsample: False
|
||||
# weight_dict:
|
||||
# loss_semantic_seg: 20.0
|
||||
# loss_semantic_presence: 1.0
|
||||
# loss_semantic_dice: 30.0
|
||||
# scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
enable_segmentation: False # NOTE: Set to True to enable segmentation heads and mask losses.
|
||||
# Model parameters
|
||||
d_model: 256
|
||||
pos_embed:
|
||||
_target_: sam3.model.position_encoding.PositionEmbeddingSine
|
||||
num_pos_feats: ${scratch.d_model}
|
||||
normalize: true
|
||||
scale: null
|
||||
temperature: 10000
|
||||
|
||||
# Box processing
|
||||
use_presence_eval: True
|
||||
original_box_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessImage
|
||||
max_dets_per_img: -1 # infinite detections
|
||||
use_original_ids: true
|
||||
use_original_sizes_box: true
|
||||
use_presence: ${scratch.use_presence_eval}
|
||||
|
||||
# Matcher configuration
|
||||
matcher:
|
||||
_target_: sam3.train.matcher.BinaryHungarianMatcherV2
|
||||
focal: true # with `focal: true` it is equivalent to BinaryFocalHungarianMatcher
|
||||
cost_class: 2.0
|
||||
cost_bbox: 5.0
|
||||
cost_giou: 2.0
|
||||
alpha: 0.25
|
||||
gamma: 2
|
||||
stable: False
|
||||
scale_by_find_batch_size: True
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
consistent_transform: False
|
||||
max_ann_per_img: 200
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
# Training parameters
|
||||
num_train_workers: 10
|
||||
num_val_workers: 0
|
||||
max_data_epochs: 20
|
||||
target_epoch_size: 1500
|
||||
hybrid_repeats: 1
|
||||
context_length: 2
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
# Learning rate and scheduler parameters
|
||||
lr_scale: 0.1
|
||||
lr_transformer: ${times:8e-4,${scratch.lr_scale}}
|
||||
lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}}
|
||||
lr_language_backbone: ${times:5e-5,${scratch.lr_scale}}
|
||||
lrd_vision_backbone: 0.9
|
||||
wd: 0.1
|
||||
scheduler_timescale: 20
|
||||
scheduler_warmup: 20
|
||||
scheduler_cooldown: 20
|
||||
|
||||
val_batch_size: 1
|
||||
collate_fn_val:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: roboflow100
|
||||
with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
train_batch_size: 1
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: all
|
||||
with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: 20
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: train
|
||||
gradient_accumulation_steps: ${scratch.gradient_accumulation_steps}
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all: ${roboflow_train.loss}
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
limit_ids: ${roboflow_train.num_images}
|
||||
transforms: ${roboflow_train.train_transforms}
|
||||
load_segmentation: ${scratch.enable_segmentation}
|
||||
max_ann_per_img: 500000
|
||||
multiplier: 1
|
||||
max_train_queries: 50000
|
||||
max_val_queries: 50000
|
||||
training: true
|
||||
use_caching: False
|
||||
img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/
|
||||
ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/_annotations.coco.json
|
||||
|
||||
shuffle: True
|
||||
batch_size: ${scratch.train_batch_size}
|
||||
num_workers: ${scratch.num_train_workers}
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
collate_fn: ${scratch.collate_fn}
|
||||
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
load_segmentation: ${scratch.enable_segmentation}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
|
||||
include_negatives: true
|
||||
category_chunk_size: 2 # Note: You can increase this based on the memory of your GPU.
|
||||
_partial_: true
|
||||
img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/
|
||||
ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json
|
||||
transforms: ${roboflow_train.val_transforms}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: True
|
||||
drop_last: False
|
||||
collate_fn: ${scratch.collate_fn_val}
|
||||
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_image_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
device: cpus # NOTE(review): looks like a typo for "cpu" — verify accepted device strings in sam3.model_builder
|
||||
eval_mode: false
|
||||
enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
|
||||
|
||||
meters:
|
||||
val:
|
||||
roboflow100:
|
||||
detection:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "bbox"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${roboflow_train.supercategory}
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.original_box_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 100
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
|
||||
gt_path: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json
|
||||
tide: False
|
||||
iou_type: "bbox"
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
optimizer:
|
||||
_target_: torch.optim.AdamW
|
||||
|
||||
gradient_clip:
|
||||
_target_: sam3.train.optim.optimizer.GradientClipper
|
||||
max_norm: 0.1
|
||||
norm_type: 2
|
||||
|
||||
param_group_modifiers:
|
||||
- _target_: sam3.train.optim.optimizer.layer_decay_param_modifier
|
||||
_partial_: True
|
||||
layer_decay_value: ${scratch.lrd_vision_backbone}
|
||||
apply_to: 'backbone.vision_backbone.trunk'
|
||||
overrides:
|
||||
- pattern: '*pos_embed*'
|
||||
value: 1.0
|
||||
|
||||
options:
|
||||
lr:
|
||||
- scheduler: # transformer and class_embed
|
||||
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
|
||||
base_lr: ${scratch.lr_transformer}
|
||||
timescale: ${scratch.scheduler_timescale}
|
||||
warmup_steps: ${scratch.scheduler_warmup}
|
||||
cooldown_steps: ${scratch.scheduler_cooldown}
|
||||
- scheduler:
|
||||
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
|
||||
base_lr: ${scratch.lr_vision_backbone}
|
||||
timescale: ${scratch.scheduler_timescale}
|
||||
warmup_steps: ${scratch.scheduler_warmup}
|
||||
cooldown_steps: ${scratch.scheduler_cooldown}
|
||||
param_names:
|
||||
- 'backbone.vision_backbone.*'
|
||||
- scheduler:
|
||||
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
|
||||
base_lr: ${scratch.lr_language_backbone}
|
||||
timescale: ${scratch.scheduler_timescale}
|
||||
warmup_steps: ${scratch.scheduler_warmup}
|
||||
cooldown_steps: ${scratch.scheduler_cooldown}
|
||||
param_names:
|
||||
- 'backbone.language_backbone.*'
|
||||
|
||||
weight_decay:
|
||||
- scheduler:
|
||||
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
||||
value: ${scratch.wd}
|
||||
- scheduler:
|
||||
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
||||
value: 0.0
|
||||
param_names:
|
||||
- '*bias*'
|
||||
module_cls_names: ['torch.nn.LayerNorm']
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/${roboflow_train.supercategory}
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 1
|
||||
gpus_per_node: 2
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
# Uncomment for job array configuration
|
||||
job_array:
|
||||
num_tasks: 100
|
||||
task_index: 0
|
||||
|
||||
# ============================================================================
|
||||
# Available Roboflow Supercategories (for reference)
|
||||
# ============================================================================
|
||||
|
||||
all_roboflow_supercategories:
|
||||
- -grccs
|
||||
- zebrasatasturias
|
||||
- cod-mw-warzone
|
||||
- canalstenosis
|
||||
- label-printing-defect-version-2
|
||||
- new-defects-in-wood
|
||||
- orionproducts
|
||||
- aquarium-combined
|
||||
- varroa-mites-detection--test-set
|
||||
- clashroyalechardetector
|
||||
- stomata-cells
|
||||
- halo-infinite-angel-videogame
|
||||
- pig-detection
|
||||
- urine-analysis1
|
||||
- aerial-sheep
|
||||
- orgharvest
|
||||
- actions
|
||||
- mahjong
|
||||
- liver-disease
|
||||
- needle-base-tip-min-max
|
||||
- wheel-defect-detection
|
||||
- aircraft-turnaround-dataset
|
||||
- xray
|
||||
- wildfire-smoke
|
||||
- spinefrxnormalvindr
|
||||
- ufba-425
|
||||
- speech-bubbles-detection
|
||||
- train
|
||||
- pill
|
||||
- truck-movement
|
||||
- car-logo-detection
|
||||
- inbreast
|
||||
- sea-cucumbers-new-tiles
|
||||
- uavdet-small
|
||||
- penguin-finder-seg
|
||||
- aerial-airport
|
||||
- bibdetection
|
||||
- taco-trash-annotations-in-context
|
||||
- bees
|
||||
- recode-waste
|
||||
- screwdetectclassification
|
||||
- wine-labels
|
||||
- aerial-cows
|
||||
- into-the-vale
|
||||
- gwhd2021
|
||||
- lacrosse-object-detection
|
||||
- defect-detection
|
||||
- dataconvert
|
||||
- x-ray-id
|
||||
- ball
|
||||
- tube
|
||||
- 2024-frc
|
||||
- crystal-clean-brain-tumors-mri-dataset
|
||||
- grapes-5
|
||||
- human-detection-in-floods
|
||||
- buoy-onboarding
|
||||
- apoce-aerial-photographs-for-object-detection-of-construction-equipment
|
||||
- l10ul502
|
||||
- floating-waste
|
||||
- deeppcb
|
||||
- ism-band-packet-detection
|
||||
- weeds4
|
||||
- invoice-processing
|
||||
- thermal-cheetah
|
||||
- tomatoes-2
|
||||
- marine-sharks
|
||||
- peixos-fish
|
||||
- sssod
|
||||
- aerial-pool
|
||||
- countingpills
|
||||
- asphaltdistressdetection
|
||||
- roboflow-trained-dataset
|
||||
- everdaynew
|
||||
- underwater-objects
|
||||
- soda-bottles
|
||||
- dentalai
|
||||
- jellyfish
|
||||
- deepfruits
|
||||
- activity-diagrams
|
||||
- circuit-voltages
|
||||
- all-elements
|
||||
- macro-segmentation
|
||||
- exploratorium-daphnia
|
||||
- signatures
|
||||
- conveyor-t-shirts
|
||||
- fruitjes
|
||||
- grass-weeds
|
||||
- infraredimageofpowerequipment
|
||||
- 13-lkc01
|
||||
- wb-prova
|
||||
- flir-camera-objects
|
||||
- paper-parts
|
||||
- football-player-detection
|
||||
- trail-camera
|
||||
- smd-components
|
||||
- water-meter
|
||||
- nih-xray
|
||||
- the-dreidel-project
|
||||
- electric-pylon-detection-in-rsi
|
||||
- cable-damage
|
||||
174
sam3/train/configs/saco_video_evals/saco_veval_sav_test.yaml
Normal file
174
sam3/train/configs/saco_video_evals/saco_veval_sav_test.yaml
Normal file
@@ -0,0 +1,174 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
paths:
|
||||
|
||||
dump_file_name: saco_veval_sav_test
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
ytvis_json: <YOUR_GT_PATH>/saco_veval_sav_test.json
|
||||
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
num_videos: null
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
vid_mask_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessNullOp
|
||||
|
||||
use_presence_eval: True
|
||||
|
||||
video_transforms_val:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.segmentation.DecodeRle
|
||||
# resize the image to a square at the side length given by scratch.resolution
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution} # originally `resolution: 1024`
|
||||
square: true
|
||||
consistent_transform: true
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
# Model parameters
|
||||
d_model: 256
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
val_batch_size: 1
|
||||
num_val_workers: 0
|
||||
max_data_epochs: 20
|
||||
hybrid_repeats: 1
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: ${scratch.max_data_epochs}
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train: null
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
|
||||
limit_ids: ${paths.num_videos}
|
||||
img_folder: ${paths.ytvis_dir}
|
||||
ann_file: ${paths.ytvis_json}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
|
||||
transforms: ${scratch.video_transforms_val}
|
||||
max_ann_per_img: 100000 # filtered in transforms
|
||||
max_val_queries: 100000
|
||||
multiplier: 1
|
||||
load_segmentation: true
|
||||
training: false
|
||||
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: True
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: ytvis_val
|
||||
with_seg_masks: true
|
||||
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_video_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
has_presence_token: True
|
||||
geo_encoder_use_img_cross_attn: True
|
||||
apply_temporal_disambiguation: True
|
||||
|
||||
meters:
|
||||
val:
|
||||
ytvis_val:
|
||||
pred_file: # key
|
||||
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
|
||||
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
|
||||
postprocessor: ${scratch.vid_mask_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 8
|
||||
gpus_per_node: 8
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
@@ -0,0 +1,174 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
paths:
|
||||
|
||||
dump_file_name: saco_veval_sav_test
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
ytvis_json: <YOUR_GT_PATH>/saco_veval_sav_test.json
|
||||
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
num_videos: null
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
vid_mask_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessNullOp
|
||||
|
||||
use_presence_eval: True
|
||||
|
||||
video_transforms_val:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.segmentation.DecodeRle
|
||||
# resize the image to a square at the side length given by scratch.resolution
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution} # originally `resolution: 1024`
|
||||
square: true
|
||||
consistent_transform: true
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
# Model parameters
|
||||
d_model: 256
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
val_batch_size: 1
|
||||
num_val_workers: 0
|
||||
max_data_epochs: 20
|
||||
hybrid_repeats: 1
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: ${scratch.max_data_epochs}
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train: null
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
|
||||
limit_ids: ${paths.num_videos}
|
||||
img_folder: ${paths.ytvis_dir}
|
||||
ann_file: ${paths.ytvis_json}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
|
||||
transforms: ${scratch.video_transforms_val}
|
||||
max_ann_per_img: 100000 # filtered in transforms
|
||||
max_val_queries: 100000
|
||||
multiplier: 1
|
||||
load_segmentation: true
|
||||
training: false
|
||||
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: True
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: ytvis_val
|
||||
with_seg_masks: true
|
||||
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_video_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
has_presence_token: True
|
||||
geo_encoder_use_img_cross_attn: True
|
||||
apply_temporal_disambiguation: False
|
||||
|
||||
meters:
|
||||
val:
|
||||
ytvis_val:
|
||||
pred_file: # key
|
||||
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
|
||||
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
|
||||
postprocessor: ${scratch.vid_mask_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 8
|
||||
gpus_per_node: 8
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
174
sam3/train/configs/saco_video_evals/saco_veval_sav_val.yaml
Normal file
174
sam3/train/configs/saco_video_evals/saco_veval_sav_val.yaml
Normal file
@@ -0,0 +1,174 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
paths:
|
||||
|
||||
dump_file_name: saco_veval_sav_val
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
ytvis_json: <YOUR_GT_PATH>/saco_veval_sav_val.json
|
||||
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
num_videos: null
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
vid_mask_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessNullOp
|
||||
|
||||
use_presence_eval: True
|
||||
|
||||
video_transforms_val:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.segmentation.DecodeRle
|
||||
# resize the image to a square at the side length given by scratch.resolution
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution} # originally `resolution: 1024`
|
||||
square: true
|
||||
consistent_transform: true
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
# Model parameters
|
||||
d_model: 256
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
val_batch_size: 1
|
||||
num_val_workers: 0
|
||||
max_data_epochs: 20
|
||||
hybrid_repeats: 1
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: ${scratch.max_data_epochs}
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train: null
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
|
||||
limit_ids: ${paths.num_videos}
|
||||
img_folder: ${paths.ytvis_dir}
|
||||
ann_file: ${paths.ytvis_json}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
|
||||
transforms: ${scratch.video_transforms_val}
|
||||
max_ann_per_img: 100000 # filtered in transforms
|
||||
max_val_queries: 100000
|
||||
multiplier: 1
|
||||
load_segmentation: true
|
||||
training: false
|
||||
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: True
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: ytvis_val
|
||||
with_seg_masks: true
|
||||
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_video_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
has_presence_token: True
|
||||
geo_encoder_use_img_cross_attn: True
|
||||
apply_temporal_disambiguation: True
|
||||
|
||||
meters:
|
||||
val:
|
||||
ytvis_val:
|
||||
pred_file: # key
|
||||
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
|
||||
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
|
||||
postprocessor: ${scratch.vid_mask_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 8
|
||||
gpus_per_node: 8
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
@@ -0,0 +1,174 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
paths:
|
||||
|
||||
dump_file_name: saco_veval_sav_val
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
ytvis_json: <YOUR_GT_PATH>/saco_veval_sav_val.json
|
||||
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
num_videos: null
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
vid_mask_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessNullOp
|
||||
|
||||
use_presence_eval: True
|
||||
|
||||
video_transforms_val:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.segmentation.DecodeRle
|
||||
# resize the image to a square at the side length given by scratch.resolution
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution} # originally `resolution: 1024`
|
||||
square: true
|
||||
consistent_transform: true
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
# Model parameters
|
||||
d_model: 256
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
val_batch_size: 1
|
||||
num_val_workers: 0
|
||||
max_data_epochs: 20
|
||||
hybrid_repeats: 1
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: ${scratch.max_data_epochs}
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train: null
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
|
||||
limit_ids: ${paths.num_videos}
|
||||
img_folder: ${paths.ytvis_dir}
|
||||
ann_file: ${paths.ytvis_json}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
|
||||
transforms: ${scratch.video_transforms_val}
|
||||
max_ann_per_img: 100000 # filtered in transforms
|
||||
max_val_queries: 100000
|
||||
multiplier: 1
|
||||
load_segmentation: true
|
||||
training: false
|
||||
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: True
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: ytvis_val
|
||||
with_seg_masks: true
|
||||
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_video_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
has_presence_token: True
|
||||
geo_encoder_use_img_cross_attn: True
|
||||
apply_temporal_disambiguation: False
|
||||
|
||||
meters:
|
||||
val:
|
||||
ytvis_val:
|
||||
pred_file: # key
|
||||
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
|
||||
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
|
||||
postprocessor: ${scratch.vid_mask_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 8
|
||||
gpus_per_node: 8
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
@@ -0,0 +1,174 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
paths:
|
||||
|
||||
dump_file_name: saco_veval_smartglasses_test
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
ytvis_json: <YOUR_GT_PATH>/saco_veval_smartglasses_test.json
|
||||
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
num_videos: null
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
vid_mask_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessNullOp
|
||||
|
||||
use_presence_eval: True
|
||||
|
||||
video_transforms_val:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.segmentation.DecodeRle
|
||||
# resize the image to a square at the side length given by scratch.resolution
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution} # originally `resolution: 1024`
|
||||
square: true
|
||||
consistent_transform: true
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
# Model parameters
|
||||
d_model: 256
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
val_batch_size: 1
|
||||
num_val_workers: 0
|
||||
max_data_epochs: 20
|
||||
hybrid_repeats: 1
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: ${scratch.max_data_epochs}
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train: null
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
|
||||
limit_ids: ${paths.num_videos}
|
||||
img_folder: ${paths.ytvis_dir}
|
||||
ann_file: ${paths.ytvis_json}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
|
||||
transforms: ${scratch.video_transforms_val}
|
||||
max_ann_per_img: 100000 # filtered in transforms
|
||||
max_val_queries: 100000
|
||||
multiplier: 1
|
||||
load_segmentation: true
|
||||
training: false
|
||||
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: True
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: ytvis_val
|
||||
with_seg_masks: true
|
||||
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_video_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
has_presence_token: True
|
||||
geo_encoder_use_img_cross_attn: True
|
||||
apply_temporal_disambiguation: True
|
||||
|
||||
meters:
|
||||
val:
|
||||
ytvis_val:
|
||||
pred_file: # key
|
||||
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
|
||||
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
|
||||
postprocessor: ${scratch.vid_mask_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 8
|
||||
gpus_per_node: 8
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
@@ -0,0 +1,174 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
paths:
|
||||
|
||||
dump_file_name: saco_veval_smartglasses_test
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
ytvis_json: <YOUR_GT_PATH>/saco_veval_smartglasses_test.json
|
||||
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
num_videos: null
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
vid_mask_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessNullOp
|
||||
|
||||
use_presence_eval: True
|
||||
|
||||
video_transforms_val:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.segmentation.DecodeRle
|
||||
# resize the image to a square of side `scratch.resolution` (currently 1008x1008)
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution} # originally `resolution: 1024`
|
||||
square: true
|
||||
consistent_transform: true
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
# Model parameters
|
||||
d_model: 256
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
val_batch_size: 1
|
||||
num_val_workers: 0
|
||||
max_data_epochs: 20
|
||||
hybrid_repeats: 1
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: ${scratch.max_data_epochs}
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train: null
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
|
||||
limit_ids: ${paths.num_videos}
|
||||
img_folder: ${paths.ytvis_dir}
|
||||
ann_file: ${paths.ytvis_json}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
|
||||
transforms: ${scratch.video_transforms_val}
|
||||
max_ann_per_img: 100000 # filtered in transforms
|
||||
max_val_queries: 100000
|
||||
multiplier: 1
|
||||
load_segmentation: true
|
||||
training: false
|
||||
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: True
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: ytvis_val
|
||||
with_seg_masks: true
|
||||
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_video_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
has_presence_token: True
|
||||
geo_encoder_use_img_cross_attn: True
|
||||
apply_temporal_disambiguation: False
|
||||
|
||||
meters:
|
||||
val:
|
||||
ytvis_val:
|
||||
pred_file: # key
|
||||
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
|
||||
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
|
||||
postprocessor: ${scratch.vid_mask_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 8
|
||||
gpus_per_node: 8
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
@@ -0,0 +1,174 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
paths:
|
||||
|
||||
dump_file_name: saco_veval_smartglasses_val
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
ytvis_json: <YOUR_GT_PATH>/saco_veval_smartglasses_val.json
|
||||
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
num_videos: null
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
vid_mask_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessNullOp
|
||||
|
||||
use_presence_eval: True
|
||||
|
||||
video_transforms_val:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.segmentation.DecodeRle
|
||||
# resize the image to a square of side `scratch.resolution` (currently 1008x1008)
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution} # originally `resolution: 1024`
|
||||
square: true
|
||||
consistent_transform: true
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
# Model parameters
|
||||
d_model: 256
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
val_batch_size: 1
|
||||
num_val_workers: 0
|
||||
max_data_epochs: 20
|
||||
hybrid_repeats: 1
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: ${scratch.max_data_epochs}
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train: null
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
|
||||
limit_ids: ${paths.num_videos}
|
||||
img_folder: ${paths.ytvis_dir}
|
||||
ann_file: ${paths.ytvis_json}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
|
||||
transforms: ${scratch.video_transforms_val}
|
||||
max_ann_per_img: 100000 # filtered in transforms
|
||||
max_val_queries: 100000
|
||||
multiplier: 1
|
||||
load_segmentation: true
|
||||
training: false
|
||||
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: True
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: ytvis_val
|
||||
with_seg_masks: true
|
||||
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_video_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
has_presence_token: True
|
||||
geo_encoder_use_img_cross_attn: True
|
||||
apply_temporal_disambiguation: True
|
||||
|
||||
meters:
|
||||
val:
|
||||
ytvis_val:
|
||||
pred_file: # key
|
||||
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
|
||||
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
|
||||
postprocessor: ${scratch.vid_mask_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 8
|
||||
gpus_per_node: 8
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
@@ -0,0 +1,174 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
paths:
|
||||
|
||||
dump_file_name: saco_veval_smartglasses_val
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
ytvis_json: <YOUR_GT_PATH>/saco_veval_smartglasses_val.json
|
||||
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
num_videos: null
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
vid_mask_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessNullOp
|
||||
|
||||
use_presence_eval: True
|
||||
|
||||
video_transforms_val:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.segmentation.DecodeRle
|
||||
# resize the image to a square of side `scratch.resolution` (currently 1008x1008)
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution} # originally `resolution: 1024`
|
||||
square: true
|
||||
consistent_transform: true
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
# Model parameters
|
||||
d_model: 256
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
val_batch_size: 1
|
||||
num_val_workers: 0
|
||||
max_data_epochs: 20
|
||||
hybrid_repeats: 1
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: ${scratch.max_data_epochs}
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train: null
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
|
||||
limit_ids: ${paths.num_videos}
|
||||
img_folder: ${paths.ytvis_dir}
|
||||
ann_file: ${paths.ytvis_json}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
|
||||
transforms: ${scratch.video_transforms_val}
|
||||
max_ann_per_img: 100000 # filtered in transforms
|
||||
max_val_queries: 100000
|
||||
multiplier: 1
|
||||
load_segmentation: true
|
||||
training: false
|
||||
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: True
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: ytvis_val
|
||||
with_seg_masks: true
|
||||
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_video_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
has_presence_token: True
|
||||
geo_encoder_use_img_cross_attn: True
|
||||
apply_temporal_disambiguation: False
|
||||
|
||||
meters:
|
||||
val:
|
||||
ytvis_val:
|
||||
pred_file: # key
|
||||
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
|
||||
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
|
||||
postprocessor: ${scratch.vid_mask_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 8
|
||||
gpus_per_node: 8
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
174
sam3/train/configs/saco_video_evals/saco_veval_yt1b_test.yaml
Normal file
174
sam3/train/configs/saco_video_evals/saco_veval_yt1b_test.yaml
Normal file
@@ -0,0 +1,174 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
paths:
|
||||
|
||||
dump_file_name: saco_veval_yt1b_test
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
ytvis_json: <YOUR_GT_PATH>/saco_veval_yt1b_test.json
|
||||
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
num_videos: null
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
vid_mask_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessNullOp
|
||||
|
||||
use_presence_eval: True
|
||||
|
||||
video_transforms_val:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.segmentation.DecodeRle
|
||||
# resize the image to a square of side `scratch.resolution` (currently 1008x1008)
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution} # originally `resolution: 1024`
|
||||
square: true
|
||||
consistent_transform: true
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
# Model parameters
|
||||
d_model: 256
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
val_batch_size: 1
|
||||
num_val_workers: 0
|
||||
max_data_epochs: 20
|
||||
hybrid_repeats: 1
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: ${scratch.max_data_epochs}
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train: null
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
|
||||
limit_ids: ${paths.num_videos}
|
||||
img_folder: ${paths.ytvis_dir}
|
||||
ann_file: ${paths.ytvis_json}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
|
||||
transforms: ${scratch.video_transforms_val}
|
||||
max_ann_per_img: 100000 # filtered in transforms
|
||||
max_val_queries: 100000
|
||||
multiplier: 1
|
||||
load_segmentation: true
|
||||
training: false
|
||||
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: True
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: ytvis_val
|
||||
with_seg_masks: true
|
||||
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_video_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
has_presence_token: True
|
||||
geo_encoder_use_img_cross_attn: True
|
||||
apply_temporal_disambiguation: True
|
||||
|
||||
meters:
|
||||
val:
|
||||
ytvis_val:
|
||||
pred_file: # key
|
||||
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
|
||||
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
|
||||
postprocessor: ${scratch.vid_mask_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 8
|
||||
gpus_per_node: 8
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
@@ -0,0 +1,174 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
paths:
|
||||
|
||||
dump_file_name: saco_veval_yt1b_test
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
ytvis_json: <YOUR_GT_PATH>/saco_veval_yt1b_test.json
|
||||
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
num_videos: null
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
vid_mask_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessNullOp
|
||||
|
||||
use_presence_eval: True
|
||||
|
||||
video_transforms_val:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.segmentation.DecodeRle
|
||||
# resize the image to a square of side `scratch.resolution` (currently 1008x1008)
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution} # originally `resolution: 1024`
|
||||
square: true
|
||||
consistent_transform: true
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
# Model parameters
|
||||
d_model: 256
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
val_batch_size: 1
|
||||
num_val_workers: 0
|
||||
max_data_epochs: 20
|
||||
hybrid_repeats: 1
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: ${scratch.max_data_epochs}
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train: null
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
|
||||
limit_ids: ${paths.num_videos}
|
||||
img_folder: ${paths.ytvis_dir}
|
||||
ann_file: ${paths.ytvis_json}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
|
||||
transforms: ${scratch.video_transforms_val}
|
||||
max_ann_per_img: 100000 # filtered in transforms
|
||||
max_val_queries: 100000
|
||||
multiplier: 1
|
||||
load_segmentation: true
|
||||
training: false
|
||||
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: True
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: ytvis_val
|
||||
with_seg_masks: true
|
||||
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_video_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
has_presence_token: True
|
||||
geo_encoder_use_img_cross_attn: True
|
||||
apply_temporal_disambiguation: False
|
||||
|
||||
meters:
|
||||
val:
|
||||
ytvis_val:
|
||||
pred_file: # key
|
||||
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
|
||||
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
|
||||
postprocessor: ${scratch.vid_mask_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 8
|
||||
gpus_per_node: 8
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
174
sam3/train/configs/saco_video_evals/saco_veval_yt1b_val.yaml
Normal file
174
sam3/train/configs/saco_video_evals/saco_veval_yt1b_val.yaml
Normal file
@@ -0,0 +1,174 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
paths:
|
||||
|
||||
dump_file_name: saco_veval_yt1b_val
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
ytvis_json: <YOUR_GT_PATH>/saco_veval_yt1b_val.json
|
||||
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
num_videos: null
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
vid_mask_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessNullOp
|
||||
|
||||
use_presence_eval: True
|
||||
|
||||
video_transforms_val:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.segmentation.DecodeRle
|
||||
# resize the image to a square of side `scratch.resolution` (currently 1008x1008)
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution} # originally `resolution: 1024`
|
||||
square: true
|
||||
consistent_transform: true
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
# Model parameters
|
||||
d_model: 256
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
val_batch_size: 1
|
||||
num_val_workers: 0
|
||||
max_data_epochs: 20
|
||||
hybrid_repeats: 1
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: ${scratch.max_data_epochs}
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train: null
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
|
||||
limit_ids: ${paths.num_videos}
|
||||
img_folder: ${paths.ytvis_dir}
|
||||
ann_file: ${paths.ytvis_json}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
|
||||
transforms: ${scratch.video_transforms_val}
|
||||
max_ann_per_img: 100000 # filtered in transforms
|
||||
max_val_queries: 100000
|
||||
multiplier: 1
|
||||
load_segmentation: true
|
||||
training: false
|
||||
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: True
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: ytvis_val
|
||||
with_seg_masks: true
|
||||
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_video_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
has_presence_token: True
|
||||
geo_encoder_use_img_cross_attn: True
|
||||
apply_temporal_disambiguation: True
|
||||
|
||||
meters:
|
||||
val:
|
||||
ytvis_val:
|
||||
pred_file: # key
|
||||
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
|
||||
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
|
||||
postprocessor: ${scratch.vid_mask_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 8
|
||||
gpus_per_node: 8
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
@@ -0,0 +1,174 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (Change this to your own paths)
|
||||
# ============================================================================
|
||||
paths:
|
||||
|
||||
dump_file_name: saco_veval_yt1b_val
|
||||
experiment_log_dir: <YOUR EXPERIMENT LOG_DIR>
|
||||
ytvis_json: <YOUR_GT_PATH>/saco_veval_yt1b_val.json
|
||||
ytvis_dir : <YOUR_VIDEO_JPG_DIR>
|
||||
bpe_path: <BPE_PATH> # This should be under assets/bpe_simple_vocab_16e6.txt.gz
|
||||
num_videos: null
|
||||
|
||||
# ============================================================================
|
||||
# Different helper parameters and functions
|
||||
# ============================================================================
|
||||
scratch:
|
||||
vid_mask_postprocessor:
|
||||
_target_: sam3.eval.postprocessors.PostProcessNullOp
|
||||
|
||||
use_presence_eval: True
|
||||
|
||||
video_transforms_val:
|
||||
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
|
||||
transforms:
|
||||
- _target_: sam3.train.transforms.segmentation.DecodeRle
|
||||
# resize the image to 1024x1024 resolution
|
||||
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
|
||||
sizes: ${scratch.resolution} # originally `resolution: 1024`
|
||||
square: true
|
||||
consistent_transform: true
|
||||
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
|
||||
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
|
||||
mean: ${scratch.val_norm_mean}
|
||||
std: ${scratch.val_norm_std}
|
||||
|
||||
# Model parameters
|
||||
d_model: 256
|
||||
|
||||
# Image processing parameters
|
||||
resolution: 1008
|
||||
|
||||
# Normalization parameters
|
||||
train_norm_mean: [0.5, 0.5, 0.5]
|
||||
train_norm_std: [0.5, 0.5, 0.5]
|
||||
val_norm_mean: [0.5, 0.5, 0.5]
|
||||
val_norm_std: [0.5, 0.5, 0.5]
|
||||
|
||||
val_batch_size: 1
|
||||
num_val_workers: 0
|
||||
max_data_epochs: 20
|
||||
hybrid_repeats: 1
|
||||
gather_pred_via_filesys: false
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
_target_: sam3.train.trainer.Trainer
|
||||
skip_saving_ckpts: true
|
||||
empty_gpu_mem_cache_after_eval: True
|
||||
skip_first_val: True
|
||||
max_epochs: ${scratch.max_data_epochs}
|
||||
accelerator: cuda
|
||||
seed_value: 123
|
||||
val_epoch_freq: 10
|
||||
mode: val
|
||||
|
||||
distributed:
|
||||
backend: nccl
|
||||
find_unused_parameters: True
|
||||
gradient_as_bucket_view: True
|
||||
|
||||
loss:
|
||||
all:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
default:
|
||||
_target_: sam3.train.loss.sam3_loss.DummyLoss
|
||||
|
||||
data:
|
||||
train: null
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset
|
||||
limit_ids: ${paths.num_videos}
|
||||
img_folder: ${paths.ytvis_dir}
|
||||
ann_file: ${paths.ytvis_json}
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
|
||||
transforms: ${scratch.video_transforms_val}
|
||||
max_ann_per_img: 100000 # filtered in transforms
|
||||
max_val_queries: 100000
|
||||
multiplier: 1
|
||||
load_segmentation: true
|
||||
training: false
|
||||
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: True
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: ytvis_val
|
||||
with_seg_masks: true
|
||||
|
||||
|
||||
model:
|
||||
_target_: sam3.model_builder.build_sam3_video_model
|
||||
bpe_path: ${paths.bpe_path}
|
||||
has_presence_token: True
|
||||
geo_encoder_use_img_cross_attn: True
|
||||
apply_temporal_disambiguation: False
|
||||
|
||||
meters:
|
||||
val:
|
||||
ytvis_val:
|
||||
pred_file: # key
|
||||
_target_: sam3.eval.ytvis_eval.YTVISResultsWriter
|
||||
dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json
|
||||
postprocessor: ${scratch.vid_mask_postprocessor}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
|
||||
optim:
|
||||
amp:
|
||||
enabled: True
|
||||
amp_dtype: bfloat16
|
||||
|
||||
|
||||
checkpoint:
|
||||
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
||||
save_freq: 0 # 0 only last checkpoint is saved.
|
||||
|
||||
|
||||
logging:
|
||||
tensorboard_writer:
|
||||
_target_: sam3.train.utils.logger.make_tensorboard_logger
|
||||
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
||||
flush_secs: 120
|
||||
should_log: True
|
||||
wandb_writer: null
|
||||
log_dir: ${launcher.experiment_log_dir}/logs/
|
||||
log_freq: 10
|
||||
|
||||
# ============================================================================
|
||||
# Launcher and Submitit Configuration
|
||||
# ============================================================================
|
||||
|
||||
launcher:
|
||||
num_nodes: 8
|
||||
gpus_per_node: 8
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
multiprocessing_context: forkserver
|
||||
|
||||
submitit:
|
||||
account: null
|
||||
partition: null
|
||||
qos: null
|
||||
timeout_hour: 72
|
||||
use_cluster: True
|
||||
cpus_per_task: 10
|
||||
port_range: [10000, 65000]
|
||||
constraint: null
|
||||
@@ -0,0 +1,64 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct)
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_bdd100k/
|
||||
coco_gt: ${paths.base_annotation_path_silver}/silver_bdd100k_merged_test.json
|
||||
img_path: ${paths.silver_img_path}/bdd100k/
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: silver_bdd100k
|
||||
|
||||
meters:
|
||||
val:
|
||||
silver_bdd100k: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_bdd100k
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "segm"
|
||||
@@ -0,0 +1,64 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_droid/
|
||||
coco_gt: ${paths.base_annotation_path_silver}/silver_droid_merged_test.json
|
||||
img_path: ${paths.silver_img_path}/droid/
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: silver_droid
|
||||
|
||||
meters:
|
||||
val:
|
||||
silver_droid: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_droid
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "segm"
|
||||
@@ -0,0 +1,64 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_ego4d/
|
||||
coco_gt: ${paths.base_annotation_path_silver}/silver_ego4d_merged_test.json
|
||||
img_path: ${paths.silver_img_path}/ego4d/
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: silver_ego4d
|
||||
|
||||
meters:
|
||||
val:
|
||||
silver_ego4d: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_ego4d
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "segm"
|
||||
@@ -0,0 +1,64 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_fathomnet/
|
||||
coco_gt: ${paths.base_annotation_path_silver}/silver_fathomnet_test.json
|
||||
img_path: ${paths.silver_img_path}/fathomnet/
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: silver_fathomnet
|
||||
|
||||
meters:
|
||||
val:
|
||||
silver_fathomnet: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_fathomnet
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "segm"
|
||||
@@ -0,0 +1,64 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_food_rec/
|
||||
coco_gt: ${paths.base_annotation_path_silver}/silver_food_rec_merged_test.json
|
||||
img_path: ${paths.silver_img_path}/food_rec/
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: silver_food_rec
|
||||
|
||||
meters:
|
||||
val:
|
||||
silver_food_rec: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_food_rec
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "segm"
|
||||
@@ -0,0 +1,64 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_geode/
|
||||
coco_gt: ${paths.base_annotation_path_silver}/silver_geode_merged_test.json
|
||||
img_path: ${paths.silver_img_path}/geode/
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: silver_geode
|
||||
|
||||
meters:
|
||||
val:
|
||||
silver_geode: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_geode
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "segm"
|
||||
@@ -0,0 +1,64 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_inaturalist/
|
||||
coco_gt: ${paths.base_annotation_path_silver}/silver_inaturalist_merged_test.json
|
||||
img_path: ${paths.silver_img_path}/inaturalist/
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: silver_inaturalist
|
||||
|
||||
meters:
|
||||
val:
|
||||
silver_inaturalist: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_inaturalist
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "segm"
|
||||
@@ -0,0 +1,64 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_nga_art/
|
||||
coco_gt: ${paths.base_annotation_path_silver}/silver_nga_art_merged_test.json
|
||||
img_path: ${paths.silver_img_path}/nga/
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: silver_nga_art
|
||||
|
||||
meters:
|
||||
val:
|
||||
silver_nga_art: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_nga_art
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "segm"
|
||||
@@ -0,0 +1,64 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_sav/
|
||||
coco_gt: ${paths.base_annotation_path_silver}/silver_sav_merged_test.json
|
||||
img_path: ${paths.silver_img_path}/sav/
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: silver_sav
|
||||
|
||||
meters:
|
||||
val:
|
||||
silver_sav: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_sav
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "segm"
|
||||
@@ -0,0 +1,64 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /configs/eval_base.yaml
|
||||
- _self_
|
||||
|
||||
# ============================================================================
|
||||
# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct
|
||||
# ============================================================================
|
||||
paths:
|
||||
experiment_log_dir: ${paths.base_experiment_log_dir}/silver_yt1b/
|
||||
coco_gt: ${paths.base_annotation_path_silver}/silver_yt1b_merged_test.json
|
||||
img_path: ${paths.silver_img_path}/yt1b/
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trainer Configuration
|
||||
# ============================================================================
|
||||
|
||||
trainer:
|
||||
data:
|
||||
val:
|
||||
_target_: sam3.train.data.torch_dataset.TorchDataset
|
||||
dataset:
|
||||
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
|
||||
coco_json_loader:
|
||||
_target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP
|
||||
_partial_: true
|
||||
img_folder: ${paths.img_path}
|
||||
ann_file: ${paths.coco_gt}
|
||||
transforms: ${scratch.base_val_transform}
|
||||
max_ann_per_img: 100000
|
||||
multiplier: 1
|
||||
training: false
|
||||
|
||||
shuffle: False
|
||||
batch_size: ${scratch.val_batch_size}
|
||||
num_workers: ${scratch.num_val_workers}
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
collate_fn:
|
||||
_target_: sam3.train.data.collator.collate_fn_api
|
||||
_partial_: true
|
||||
repeats: ${scratch.hybrid_repeats}
|
||||
dict_key: silver_yt1b
|
||||
|
||||
meters:
|
||||
val:
|
||||
silver_yt1b: # this key matches the "dict_key" in the dataloader's collate function
|
||||
cgf1:
|
||||
_target_: sam3.eval.coco_writer.PredictionDumper
|
||||
iou_type: "segm"
|
||||
dump_dir: ${launcher.experiment_log_dir}/dumps/silver_yt1b
|
||||
merge_predictions: True
|
||||
postprocessor: ${scratch.mask_postprocessor_thresholded}
|
||||
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
|
||||
maxdets: 1000000 # no limit
|
||||
pred_file_evaluators:
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "bbox"
|
||||
- _target_: sam3.eval.cgf1_eval.CGF1Evaluator
|
||||
gt_path: ${paths.coco_gt}
|
||||
iou_type: "segm"
|
||||
1
sam3/train/data/__init__.py
Normal file
1
sam3/train/data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
465
sam3/train/data/coco_json_loaders.py
Normal file
465
sam3/train/data/coco_json_loaders.py
Normal file
@@ -0,0 +1,465 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from pycocotools import mask as mask_util
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Utility Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def convert_boxlist_to_normalized_tensor(box_list, image_width, image_height):
    """
    Convert a list of bounding boxes to a normalized PyTorch tensor.

    Args:
        box_list (list of list or tuple): Each box is [x_min, y_min, x_max, y_max]
            in absolute pixel coordinates.
        image_width (int or float): Width of the image (assumed > 0).
        image_height (int or float): Height of the image (assumed > 0).

    Returns:
        torch.Tensor: Tensor of shape (N, 4), values clamped to [0, 1].
            An empty ``box_list`` yields an empty (0, 4) tensor.
    """
    if not box_list:
        # torch.tensor([]) would be 1-D (shape (0,)), making the 2-D
        # column indexing below raise an IndexError.
        return torch.zeros((0, 4), dtype=torch.float32)
    boxes = torch.tensor(box_list, dtype=torch.float32)
    boxes[:, [0, 2]] /= image_width  # normalize x_min, x_max
    boxes[:, [1, 3]] /= image_height  # normalize y_min, y_max
    boxes = boxes.clamp(0, 1)
    return boxes
|
||||
|
||||
|
||||
def load_coco_and_group_by_image(json_path: str) -> Tuple[List[Dict], Dict[int, str]]:
    """
    Load a COCO JSON file and group its annotations by image.

    Args:
        json_path (str): Path to the COCO JSON file.

    Returns:
        Tuple containing:
            - List of {"image": ..., "annotations": [...]} dicts, ordered by
              ascending image id (images without annotations get an empty list)
            - Dict mapping category id -> category name
    """
    with open(json_path, "r") as f:
        coco = json.load(f)

    images_by_id = {img["id"]: img for img in coco["images"]}

    anns_by_image = defaultdict(list)
    for annotation in coco["annotations"]:
        anns_by_image[annotation["image_id"]].append(annotation)

    # Deterministic ordering: iterate image ids in sorted order.
    grouped = [
        {"image": images_by_id[img_id], "annotations": anns_by_image.get(img_id, [])}
        for img_id in sorted(images_by_id)
    ]

    cat_id_to_name = {category["id"]: category["name"] for category in coco["categories"]}

    return grouped, cat_id_to_name
|
||||
|
||||
|
||||
def ann_to_rle(segm, im_info: Dict) -> Dict:
    """
    Convert a segmentation (polygons or uncompressed RLE) to compressed RLE.

    Args:
        segm: Segmentation data — either a list of polygons, an uncompressed
            RLE dict (``counts`` is a list), or an already-compressed RLE dict.
        im_info (dict): Image info containing 'height' and 'width'.

    Returns:
        Compressed RLE-encoded segmentation dict. Already-compressed input is
        returned unchanged.
    """
    height, width = im_info["height"], im_info["width"]

    if isinstance(segm, list):
        # Polygon format: encode each polygon part and merge into a single RLE.
        return mask_util.merge(mask_util.frPyObjects(segm, height, width))
    if isinstance(segm["counts"], list):
        # Uncompressed RLE: compress it.
        return mask_util.frPyObjects(segm, height, width)
    # Already compressed RLE: pass through as-is.
    return segm
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# COCO Training API
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class COCO_FROM_JSON:
    """
    COCO training API for loading box-only annotations from JSON.
    Groups all annotations per image and creates queries per category.

    Each "datapoint" is a (image, category-chunk) pair: with C category chunks,
    datapoint ``idx`` maps to image ``idx // C`` and chunk ``idx % C``.
    """

    def __init__(
        self,
        annotation_file,
        prompts=None,
        include_negatives=True,
        category_chunk_size=None,
    ):
        """
        Initialize the COCO training API.

        Args:
            annotation_file (str): Path to COCO JSON annotation file
            prompts: Optional custom prompts for categories. Expected to be a
                string that evaluates to a list of {"id": ..., "name": ...}
                dicts; must cover every category.
            include_negatives (bool): Whether to include negative examples (categories with no instances)
            category_chunk_size (int, optional): Number of categories queried
                per datapoint; defaults to all categories in one chunk.
        """
        self._raw_data, self._cat_idx_to_text = load_coco_and_group_by_image(
            annotation_file
        )
        self._sorted_cat_ids = sorted(list(self._cat_idx_to_text.keys()))
        self.prompts = None
        self.include_negatives = include_negatives
        # One chunk containing every category unless a chunk size is given.
        self.category_chunk_size = (
            category_chunk_size
            if category_chunk_size is not None
            else len(self._sorted_cat_ids)
        )
        self.category_chunks = [
            self._sorted_cat_ids[i : i + self.category_chunk_size]
            for i in range(0, len(self._sorted_cat_ids), self.category_chunk_size)
        ]
        if prompts is not None:
            # SECURITY NOTE(review): eval() on the prompts string executes
            # arbitrary code — only pass trusted configuration here. Consider
            # ast.literal_eval / json.loads if prompts are plain literals.
            prompts = eval(prompts)
            self.prompts = {}
            for loc_dict in prompts:
                self.prompts[int(loc_dict["id"])] = loc_dict["name"]
            assert len(self.prompts) == len(
                self._sorted_cat_ids
            ), "Number of prompts must match number of categories"

    def getDatapointIds(self):
        """Return all datapoint indices for training (one per image/chunk pair)."""
        return list(range(len(self._raw_data) * len(self.category_chunks)))

    def loadQueriesAndAnnotationsFromDatapoint(self, idx):
        """
        Load queries and annotations for a specific datapoint.

        Builds one query per category in the datapoint's chunk (skipping empty
        categories when ``include_negatives`` is False), and one annotation per
        instance. Annotation ids are local to the datapoint and assigned in
        emission order.

        Args:
            idx (int): Datapoint index

        Returns:
            Tuple of (queries, annotations) lists
        """
        img_idx = idx // len(self.category_chunks)
        chunk_idx = idx % len(self.category_chunks)
        cat_chunk = self.category_chunks[chunk_idx]

        queries = []
        annotations = []

        query_template = {
            "id": None,
            "original_cat_id": None,
            "object_ids_output": None,
            "query_text": None,
            "query_processing_order": 0,
            "ptr_x_query_id": None,
            "ptr_y_query_id": None,
            "image_id": 0,  # Single image per datapoint
            "input_box": None,
            "input_box_label": None,
            "input_points": None,
            "is_exhaustive": True,
        }

        annot_template = {
            "image_id": 0,
            "bbox": None,  # Normalized bbox in xywh
            # NOTE(review): "area" below is bbox[2]*bbox[3] of the NORMALIZED
            # box, i.e. a normalized area — confirm downstream expectations.
            "area": None,  # Unnormalized area
            "segmentation": None,  # RLE encoded
            "object_id": None,
            "is_crowd": None,
            "id": None,
        }

        raw_annotations = self._raw_data[img_idx]["annotations"]
        image_info = self._raw_data[img_idx]["image"]
        width, height = image_info["width"], image_info["height"]

        # Group annotations by category
        cat_id_to_anns = defaultdict(list)
        for ann in raw_annotations:
            cat_id_to_anns[ann["category_id"]].append(ann)

        # Keep the chunk's category order (sorted, since chunks come from
        # the sorted id list) regardless of annotation order in the file.
        annotations_by_cat_sorted = [
            (cat_id, cat_id_to_anns[cat_id]) for cat_id in cat_chunk
        ]

        for cat_id, anns in annotations_by_cat_sorted:
            if len(anns) == 0 and not self.include_negatives:
                continue

            cur_ann_ids = []

            # Create annotations for this category
            for ann in anns:
                annotation = annot_template.copy()
                # Annotation id == position in the output list; object_id mirrors it.
                annotation["id"] = len(annotations)
                annotation["object_id"] = annotation["id"]
                annotation["is_crowd"] = ann["iscrowd"]

                normalized_boxes = convert_boxlist_to_normalized_tensor(
                    [ann["bbox"]], width, height
                )
                bbox = normalized_boxes[0]

                annotation["area"] = (bbox[2] * bbox[3]).item()
                annotation["bbox"] = bbox

                if (
                    "segmentation" in ann
                    and ann["segmentation"] is not None
                    and ann["segmentation"] != []
                ):
                    annotation["segmentation"] = ann_to_rle(
                        ann["segmentation"], im_info=image_info
                    )

                annotations.append(annotation)
                cur_ann_ids.append(annotation["id"])

            # Create query for this category
            query = query_template.copy()
            query["id"] = len(queries)
            query["original_cat_id"] = cat_id
            # Custom prompts (if any) override the category name as query text.
            query["query_text"] = (
                self._cat_idx_to_text[cat_id]
                if self.prompts is None
                else self.prompts[cat_id]
            )
            query["object_ids_output"] = cur_ann_ids
            queries.append(query)

        return queries, annotations

    def loadImagesFromDatapoint(self, idx):
        """
        Load image information for a specific datapoint.

        Args:
            idx (int): Datapoint index

        Returns:
            List containing a single image info dict with local id 0; the
            original COCO image id is kept in 'original_img_id'/'coco_img_id'.
        """
        img_idx = idx // len(self.category_chunks)
        img_data = self._raw_data[img_idx]["image"]
        images = [
            {
                "id": 0,
                "file_name": img_data["file_name"],
                "original_img_id": img_data["id"],
                "coco_img_id": img_data["id"],
            }
        ]
        return images
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SAM3 Evaluation APIs
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SAM3_EVAL_API_FROM_JSON_NP:
    """
    SAM3 evaluation API for loading noun phrase queries from JSON.

    Each datapoint is one image record carrying a single text query
    ('text_input') and its queried category id.
    """

    def __init__(self, annotation_file):
        """
        Initialize the SAM3 evaluation API.

        Args:
            annotation_file (str): Path to SAM3 JSON annotation file
        """
        with open(annotation_file, "r") as f:
            self._image_data = json.load(f)["images"]

    def getDatapointIds(self):
        """Return all datapoint indices."""
        return list(range(len(self._image_data)))

    def loadQueriesAndAnnotationsFromDatapoint(self, idx):
        """
        Load queries and annotations for a specific datapoint.

        Args:
            idx (int): Datapoint index

        Returns:
            Tuple of (queries, annotations); annotations are always empty
            for this evaluation API.
        """
        record = self._image_data[idx]
        query = {
            "id": 0,
            "original_cat_id": int(record["queried_category"]),
            "object_ids_output": [],
            "query_text": record["text_input"],
            "query_processing_order": 0,
            "ptr_x_query_id": None,
            "ptr_y_query_id": None,
            "image_id": 0,
            "input_box": None,
            "input_box_label": None,
            "input_points": None,
            "is_exhaustive": True,
        }
        return [query], []

    def loadImagesFromDatapoint(self, idx):
        """
        Load image information for a specific datapoint.

        Args:
            idx (int): Datapoint index

        Returns:
            List containing one image info dict with local id 0.
        """
        record = self._image_data[idx]
        return [
            {
                "id": 0,
                "file_name": record["file_name"],
                "original_img_id": record["id"],
                "coco_img_id": record["id"],
            }
        ]
|
||||
|
||||
|
||||
class SAM3_VEVAL_API_FROM_JSON_NP:
    """
    SAM3 video evaluation API for loading noun phrase queries from JSON.

    Each datapoint is one video; every (noun phrase, frame) pair of that
    video becomes a query, with frame index used as the processing order.
    """

    def __init__(self, annotation_file):
        """
        Initialize the SAM3 video evaluation API.

        Args:
            annotation_file (str): Path to SAM3 video JSON annotation file
        """
        with open(annotation_file, "r") as f:
            data = json.load(f)

        assert "video_np_pairs" in data, "Incorrect data format"

        self._video_data = data["videos"]
        self._cat_id_to_np = {cat["id"]: cat["name"] for cat in data["categories"]}
        self._video_id_to_np_ids = defaultdict(list)

        for pair in data["video_np_pairs"]:
            cat_id = pair["category_id"]
            self._video_id_to_np_ids[pair["video_id"]].append(cat_id)
            # Sanity check: the pair's noun phrase must match the category name.
            assert (
                self._cat_id_to_np[cat_id] == pair["noun_phrase"]
            ), "Category name does not match text input"

    def getDatapointIds(self):
        """Return all datapoint indices."""
        return list(range(len(self._video_data)))

    def loadQueriesAndAnnotationsFromDatapoint(self, idx):
        """
        Load queries and annotations for a specific video datapoint.

        Args:
            idx (int): Datapoint index

        Returns:
            Tuple of (queries, annotations); annotations are always empty
            for this evaluation API.
        """
        video = self._video_data[idx]
        num_frames = len(video["file_names"])
        queries = []

        for np_id in self._video_id_to_np_ids[video["id"]]:
            phrase = self._cat_id_to_np[np_id]
            for frame_idx in range(num_frames):
                queries.append(
                    {
                        "id": len(queries),
                        "original_cat_id": np_id,
                        "object_ids_output": [],
                        "query_text": phrase,
                        "query_processing_order": frame_idx,
                        "ptr_x_query_id": None,
                        "ptr_y_query_id": None,
                        "image_id": frame_idx,
                        "input_box": None,
                        "input_box_label": None,
                        "input_points": None,
                        "is_exhaustive": True,
                    }
                )

        return queries, []

    def loadImagesFromDatapoint(self, idx):
        """
        Load image information for a specific video datapoint.

        Args:
            idx (int): Datapoint index

        Returns:
            List of image info dicts, one per frame, with local ids 0..N-1.
        """
        video = self._video_data[idx]
        records = []
        for frame_idx, fname in enumerate(video["file_names"]):
            records.append(
                {
                    "id": frame_idx,
                    "file_name": fname,
                    "original_img_id": video["id"],
                    "coco_img_id": video["id"],
                }
            )
        return records
|
||||
360
sam3/train/data/collator.py
Normal file
360
sam3/train/data/collator.py
Normal file
@@ -0,0 +1,360 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
|
||||
from typing import Any, get_args, get_origin, List, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.model.data_misc import (
|
||||
BatchedDatapoint,
|
||||
BatchedFindTarget,
|
||||
BatchedInferenceMetadata,
|
||||
FindStage,
|
||||
)
|
||||
|
||||
from .sam3_image_dataset import Datapoint
|
||||
|
||||
|
||||
MyTensor = Union[torch.Tensor, List[Any]]
|
||||
|
||||
|
||||
def convert_my_tensors(obj):
    """
    Recursively convert every ``MyTensor``-typed field of a dataclass to a
    real ``torch.Tensor``, in place.

    For each field annotated as ``MyTensor`` (or ``Optional[MyTensor]``) whose
    value is not None, the current list value is either stacked (if it already
    holds tensors) or converted via ``torch.as_tensor``. The target dtype is
    read from a companion field named ``<field>__type`` on the same dataclass
    — assumed to exist for every MyTensor field (TODO confirm in data_misc).

    Returns the same object, mutated in place.
    """

    def is_optional_field(field) -> bool:
        # Optional[X] is Union[X, None]: check for a Union containing NoneType.
        return get_origin(field) is Union and type(None) in get_args(field)

    for field in fields(obj):
        if is_dataclass(getattr(obj, field.name)):
            # Nested dataclass: recurse into it and move on.
            convert_my_tensors(getattr(obj, field.name))
            continue

        field_type = field.type
        if is_optional_field(field.type):
            field_type = Union[get_args(field.type)[:-1]]  # Get the Optional field type

        # Skip anything that is not a MyTensor field, or is explicitly unset.
        if field_type != MyTensor or getattr(obj, field.name) is None:
            continue

        elif len(getattr(obj, field.name)) and isinstance(
            getattr(obj, field.name)[0], torch.Tensor
        ):
            # Non-empty list of tensors: stack into one batched tensor.
            # input_boxes / input_boxes_label are stacked along dim 1 —
            # presumably to put the batch on the second axis; confirm against
            # the consumers of FindStage.
            stack_dim = 0
            if field.name in [
                "input_boxes",
                "input_boxes_label",
            ]:
                stack_dim = 1
            setattr(
                obj,
                field.name,
                torch.stack(getattr(obj, field.name), dim=stack_dim).to(
                    getattr(obj, field.name + "__type")
                ),
            )
        else:
            # Plain Python list (possibly empty): convert directly, with the
            # dtype taken from the companion "<name>__type" field.
            setattr(
                obj,
                field.name,
                torch.as_tensor(
                    getattr(obj, field.name), dtype=getattr(obj, field.name + "__type")
                ),
            )
    return obj
|
||||
|
||||
|
||||
def packed_to_padded_naive(boxes_packed, num_boxes, fill_value=0):
    """
    Convert a packed tensor of bounding boxes to a padded tensor of bounding
    boxes. Naive implementation using a loop.

    Inputs:
    - boxes_packed: Tensor of shape (N_1 + ... + N_B, ...)
    - num_boxes: Tensor of shape (B,) where num_boxes[i] = N_i
    - fill_value: Value used for the padded positions (default 0)

    Returns:
    - boxes_padded: Tensor of shape (B, N_max, ...) where N_max = max_i N_i
      (N_max = 0 when the batch is empty)
    """
    B = num_boxes.shape[0]
    Ns = num_boxes.tolist()

    # max(..., default=0) keeps this from raising ValueError on an empty
    # batch (B == 0), where the original max(Ns) would crash.
    N_max = max(Ns, default=0)
    boxes_padded = boxes_packed.new_full(
        (B, N_max, *boxes_packed.shape[1:]), fill_value
    )
    prev_idx = 0
    for i in range(B):
        next_idx = prev_idx + Ns[i]
        boxes_padded[i, : Ns[i]] = boxes_packed[prev_idx:next_idx]
        prev_idx = next_idx
    return boxes_padded
|
||||
|
||||
|
||||
def pad_tensor_list_to_longest(
    tensors: List[torch.Tensor], dim=0, pad_val=0
) -> List[torch.Tensor]:
    """
    Pad every tensor in the list along ``dim`` up to the longest length.

    The list is edited in place (entries are replaced by padded copies) and
    also returned. Padding is appended at the end of ``dim`` with ``pad_val``.
    Supports negative ``dim`` values.
    """
    if not tensors:
        return tensors
    target_len = max(t.shape[dim] for t in tensors)
    for idx, t in enumerate(tensors):
        rank = t.dim()
        # Number of dimensions to the right of `dim` (handles negative dims);
        # F.pad's spec is ordered from the LAST dimension backwards.
        dims_after = (rank - 1) - (rank + dim) % rank
        pad_spec = [0, 0] * dims_after + [0, target_len - t.shape[dim]]
        tensors[idx] = torch.nn.functional.pad(t, tuple(pad_spec), value=pad_val)
    return tensors
|
||||
|
||||
|
||||
def collate_fn_api_with_chunking(
    batch,
    num_chunks,
    dict_key,
    with_seg_masks=False,
    input_points_embedding_dim=257,
    repeats: int = 0,
    load_image_in_fp16: bool = False,
):
    """
    Split ``batch`` round-robin into ``num_chunks`` sub-batches and collate
    each one with :func:`collate_fn_api`.

    Returns a list of ``num_chunks`` collated results (chunk sizes differ by
    at most one element).
    """
    assert num_chunks >= 1, "num_chunks must be >= 1"

    return [
        collate_fn_api(
            batch[start::num_chunks],
            dict_key,
            with_seg_masks,
            input_points_embedding_dim,
            repeats,
            load_image_in_fp16,
        )
        for start in range(num_chunks)
    ]
|
||||
|
||||
|
||||
def collate_fn_api(
    batch: List[Datapoint],
    dict_key,
    with_seg_masks=False,
    input_points_embedding_dim=257,
    repeats: int = 0,
    load_image_in_fp16: bool = False,
):
    """
    Collate a list of Datapoints into a single BatchedDatapoint.

    Queries are bucketed into "stages" by their ``query_processing_order``
    (stage 0 first); each stage gets its own FindStage / BatchedFindTarget /
    BatchedInferenceMetadata. Query texts are de-duplicated into
    ``text_batch`` and referenced by index. Variable-length point/box prompts
    are padded to the longest sequence in their stage, then all list fields
    are converted to tensors via ``convert_my_tensors``.

    Args:
        batch: Datapoints to collate. All image tensors must share one shape.
        dict_key: Key under which the BatchedDatapoint is returned.
        with_seg_masks: If True, collect per-object segmentation masks
            (missing masks become all-zero dummies flagged invalid); if
            False, ``segments``/``is_valid_segment`` are set to None.
        input_points_embedding_dim: Feature dim of the empty placeholder used
            when a query carries no point prompt.
        repeats: If > 0, ``repeated_boxes`` receives ``repeats`` copies of
            each query's target boxes.
        load_image_in_fp16: Cast the stacked image batch to half precision.

    Returns:
        ``{dict_key: BatchedDatapoint(...)}``
    """
    # img_batch = torch.stack(sum([[img.data for img in v.images] for v in batch], []))
    img_batch = []
    text_batch = []
    raw_images = None

    # Number of stages = highest processing order across the batch, plus one.
    num_stages = (
        max(q.query_processing_order for data in batch for q in data.find_queries) + 1
    )

    stages = [
        FindStage(
            img_ids=[],
            text_ids=[],
            input_boxes=[],
            input_boxes_label=[],
            input_boxes_mask=[],
            input_points=[],
            input_points_mask=[],
            object_ids=[],
        )
        for _ in range(num_stages)
    ]
    find_targets = [
        BatchedFindTarget(
            num_boxes=[],
            boxes=[],
            boxes_padded=[],
            is_exhaustive=[],
            segments=[],
            semantic_segments=[],
            is_valid_segment=[],
            repeated_boxes=[],
            object_ids=[],
            object_ids_padded=[],
        )
        for _ in range(num_stages)
    ]
    find_metadatas = [
        BatchedInferenceMetadata(
            coco_image_id=[],
            original_size=[],
            object_id=[],
            frame_index=[],
            original_image_id=[],
            original_category_id=[],
            is_conditioning_only=[],
        )
        for _ in range(num_stages)
    ]

    offset_img_id = 0
    offset_query_id = [0 for _ in range(num_stages)]
    for i, data in enumerate(batch):
        img_batch.extend([img.data for img in data.images])

        if data.raw_images is not None:
            if raw_images is None:
                raw_images = []
            raw_images.extend(data.raw_images)

        # Conversion of query_ids indexing in a datapoint to query_ids indexing in a stage
        # NOTE(review): this mapping (and offset_query_id) is computed but
        # never read again within this function — dead code candidate.
        datapoint_query_id_2_stage_query_id = []
        for q in data.find_queries:
            stage_id = q.query_processing_order
            datapoint_query_id_2_stage_query_id.append(offset_query_id[stage_id])
            offset_query_id[stage_id] += 1

        for j, q in enumerate(data.find_queries):
            stage_id = q.query_processing_order
            # Image ids are made global across the batch via offset_img_id.
            stages[stage_id].img_ids.append(q.image_id + offset_img_id)
            # De-duplicate query texts; a query references its text by index.
            if q.query_text not in text_batch:
                text_batch.append(q.query_text)
            stages[stage_id].text_ids.append(text_batch.index(q.query_text))

            assert (
                q.inference_metadata is not None
            ), "inference_metadata must be provided when FindQueryLoaded is created."
            # Copy every metadata field into the stage's batched metadata.
            for f in fields(q.inference_metadata):
                getattr(find_metadatas[stage_id], f.name).append(
                    getattr(q.inference_metadata, f.name)
                )

            if q.input_bbox is not None:
                assert q.input_bbox.numel() % 4 == 0
                assert q.input_bbox_label is not None
                nb_boxes = q.input_bbox.numel() // 4
                assert len(q.input_bbox_label) == nb_boxes
                stages[stage_id].input_boxes.append(q.input_bbox.view(nb_boxes, 4))
                stages[stage_id].input_boxes_label.append(
                    q.input_bbox_label.view(nb_boxes)
                )
                # Mask value 0 = real box; padded entries become 1 later.
                stages[stage_id].input_boxes_mask.append(
                    torch.zeros(nb_boxes, dtype=torch.bool)
                )
            else:
                # No box prompt: contribute empty placeholders so padding and
                # stacking still line up per-query.
                stages[stage_id].input_boxes.append(torch.zeros(0, 4))
                stages[stage_id].input_boxes_label.append(
                    torch.zeros(0, dtype=torch.bool)
                )
                stages[stage_id].input_boxes_mask.append(
                    torch.ones(0, dtype=torch.bool)
                )

            if q.input_points is not None:
                stages[stage_id].input_points.append(
                    q.input_points.squeeze(0)  # Strip a trivial batch index
                )
                # All masks will be padded up to the longest length
                # with 1s before final conversion to batched tensors
                stages[stage_id].input_points_mask.append(
                    torch.zeros(q.input_points.shape[1])
                )
            else:
                stages[stage_id].input_points.append(
                    torch.empty(0, input_points_embedding_dim)
                )
                stages[stage_id].input_points_mask.append(torch.empty(0))

            current_out_boxes = []
            current_out_object_ids = []
            # Set the object ids referred to by this query
            stages[stage_id].object_ids.append(q.object_ids_output)
            for object_id in q.object_ids_output:
                current_out_boxes.append(
                    data.images[q.image_id].objects[object_id].bbox
                )
                current_out_object_ids.append(object_id)
            find_targets[stage_id].boxes.extend(current_out_boxes)
            find_targets[stage_id].object_ids.extend(current_out_object_ids)
            if repeats > 0:
                for _ in range(repeats):
                    find_targets[stage_id].repeated_boxes.extend(current_out_boxes)
            find_targets[stage_id].num_boxes.append(len(current_out_boxes))
            find_targets[stage_id].is_exhaustive.append(q.is_exhaustive)

            if with_seg_masks:
                current_seg_mask = []
                current_is_valid_segment = []
                for object_id in q.object_ids_output:
                    seg_mask = data.images[q.image_id].objects[object_id].segment
                    if seg_mask is not None:
                        current_seg_mask.append(seg_mask)
                        current_is_valid_segment.append(1)
                    else:
                        # Missing mask: keep tensor shapes consistent with an
                        # all-zero dummy, flagged invalid (0) for the loss.
                        dummy_mask = torch.zeros(
                            data.images[q.image_id].data.shape[-2:], dtype=torch.bool
                        )
                        current_seg_mask.append(dummy_mask)
                        current_is_valid_segment.append(0)
                find_targets[stage_id].segments.extend(current_seg_mask)
                find_targets[stage_id].is_valid_segment.extend(current_is_valid_segment)
            else:
                # We are not loading segmentation masks
                find_targets[stage_id].segments = None
                find_targets[stage_id].is_valid_segment = None

            if q.semantic_target is not None:
                find_targets[stage_id].semantic_segments.append(q.semantic_target)

        offset_img_id += len(data.images)

    # Pad input points to equal sequence lengths
    for i in range(len(stages)):
        stages[i].input_points = pad_tensor_list_to_longest(
            stages[i].input_points, dim=0, pad_val=0
        )
        # Masked-out regions indicated by 1s.
        stages[i].input_points_mask = pad_tensor_list_to_longest(
            stages[i].input_points_mask, dim=0, pad_val=1
        )

    # Pad input boxes to equal sequence lengths
    for i in range(len(stages)):
        stages[i].input_boxes = pad_tensor_list_to_longest(
            stages[i].input_boxes, dim=0, pad_val=0
        )
        stages[i].input_boxes_label = pad_tensor_list_to_longest(
            stages[i].input_boxes_label, dim=0, pad_val=0
        )
        # Masked-out regions indicated by 1s.
        stages[i].input_boxes_mask = pad_tensor_list_to_longest(
            stages[i].input_boxes_mask, dim=0, pad_val=1
        )

    # Convert to tensors
    for i in range(len(stages)):
        stages[i] = convert_my_tensors(stages[i])
        find_targets[i] = convert_my_tensors(find_targets[i])
        find_metadatas[i] = convert_my_tensors(find_metadatas[i])
        # get padded representation for the boxes
        find_targets[i].boxes_padded = packed_to_padded_naive(
            find_targets[i].boxes.view(-1, 4), find_targets[i].num_boxes
        )
        # object_ids are padded with -1 so padding can't collide with real ids.
        find_targets[i].object_ids_padded = packed_to_padded_naive(
            find_targets[i].object_ids, find_targets[i].num_boxes, fill_value=-1
        )

    # Finalize the image batch
    # check sizes
    for img in img_batch[1:]:
        assert img.shape == img_batch[0].shape, "All images must have the same size"
    image_batch = torch.stack(img_batch)
    if load_image_in_fp16:
        # Optionally, cast the image tensors to fp16, which helps save GPU memory on
        # long videos with thousands of frames (where image tensors could be several GBs)
        image_batch = image_batch.half()

    return {
        dict_key: BatchedDatapoint(
            img_batch=image_batch,
            find_text_batch=text_batch,
            find_inputs=stages,
            find_targets=find_targets,
            find_metadatas=find_metadatas,
            raw_images=raw_images,
        )
    }
|
||||
528
sam3/train/data/sam3_image_dataset.py
Normal file
528
sam3/train/data/sam3_image_dataset.py
Normal file
@@ -0,0 +1,528 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
"""Dataset class for modulated detection"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import traceback
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import torchvision
|
||||
from decord import cpu, VideoReader
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from PIL.Image import DecompressionBombError
|
||||
|
||||
from sam3.model.box_ops import box_xywh_to_xyxy
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
|
||||
from .coco_json_loaders import COCO_FROM_JSON
|
||||
|
||||
|
||||
@dataclass
class InferenceMetadata:
    """Metadata required for postprocessing a single query's predictions."""

    # Coco id that corresponds to the "image" for evaluation by the coco evaluator
    # This is used for our own "class agnostic" evaluation
    coco_image_id: int

    # id in the original dataset, such that we can use the original evaluator
    original_image_id: int

    # Original category id (if we want to use the original evaluator)
    original_category_id: int

    # Size of the raw image (height, width)
    original_size: Tuple[int, int]

    # Id of the object in the media
    object_id: int

    # Index of the frame in the media (0 if single image)
    frame_index: int

    # Whether it is for conditioning only, e.g., 0-th frame in TA is for conditioning
    # as we assume GT available in frame 0.
    is_conditioning_only: Optional[bool] = False
|
||||
|
||||
|
||||
@dataclass
class FindQuery:
    """A single text query against one image, plus its expected outputs."""

    # The noun-phrase / text prompt for this query
    query_text: str

    # Index of the image (within the datapoint) this query applies to
    image_id: int

    # In case of a find query, the list of object ids that have to be predicted
    object_ids_output: List[int]

    # This is "instance exhaustivity".
    # true iff all instances are separable and annotated
    # See below the slightly different "pixel exhaustivity"
    is_exhaustive: bool

    # The order in which the queries are processed (only meaningful for video)
    query_processing_order: int = 0

    # Input geometry, initially in denormalized XYXY format. Then
    # 1. converted to normalized CxCyWH by the Normalize transform
    input_bbox: Optional[torch.Tensor] = None
    input_bbox_label: Optional[torch.Tensor] = None

    # Only for the PVS task
    input_points: Optional[torch.Tensor] = None

    semantic_target: Optional[torch.Tensor] = None

    # pixel exhaustivity: true iff the union of all segments (including crowds)
    # covers every pixel belonging to the target class
    # Note that instance_exhaustive implies pixel_exhaustive
    is_pixel_exhaustive: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
class FindQueryLoaded(FindQuery):
    """A FindQuery augmented with the metadata needed for inference-time postprocessing."""

    # Must have default value since FindQuery has entries with default values
    inference_metadata: Optional[InferenceMetadata] = None
|
||||
|
||||
|
||||
@dataclass
class Object:
    """A single annotated object instance within one frame/image."""

    # Initially in denormalized XYXY format, gets converted to normalized CxCyWH by the Normalize transform
    bbox: torch.Tensor
    area: float

    # Id of the object in the media
    object_id: Optional[int] = -1

    # Index of the frame in the media (0 if single image)
    frame_index: Optional[int] = -1

    segment: Optional[Union[torch.Tensor, dict]] = None  # RLE dict or binary mask

    is_crowd: bool = False

    source: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class Image:
    """One image (or video frame) together with its annotated objects."""

    # Pixel data: a tensor after transforms, or a PIL image before them
    data: Union[torch.Tensor, PILImage.Image]
    objects: List[Object]
    size: Tuple[int, int]  # (height, width)

    # For blurring augmentation
    blurring_mask: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
class Datapoint:
    """Refers to an image/video and all its annotations"""

    # The queries to run against this datapoint's images
    find_queries: List[FindQueryLoaded]
    # One entry per image/frame; queries index into this list via image_id
    images: List[Image]
    # Optional untransformed copies of the images
    raw_images: Optional[List[PILImage.Image]] = None
|
||||
|
||||
|
||||
class CustomCocoDetectionAPI(VisionDataset):
|
||||
"""`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.
|
||||
|
||||
Args:
|
||||
root (string): Root directory where images are downloaded to.
|
||||
annFile (string): Path to json annotation file.
|
||||
transform (callable, optional): A function/transform that takes in an PIL image
|
||||
and returns a transformed version. E.g, ``transforms.ToTensor``
|
||||
target_transform (callable, optional): A function/transform that takes in the
|
||||
target and transforms it.
|
||||
transforms (callable, optional): A function/transform that takes input sample and its target as entry
|
||||
and returns a transformed version.
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        root: str,
        annFile: str,
        load_segmentation: bool,
        fix_fname: bool = False,
        training: bool = True,
        blurring_masks_path: Optional[str] = None,
        use_caching: bool = True,
        zstd_dict_path=None,
        filter_query=None,
        coco_json_loader: Callable = COCO_FROM_JSON,
        limit_ids: int = None,
    ) -> None:
        """
        Initialize the dataset wrapper.

        Args:
            root: Root directory images are resolved against.
            annFile: Path to the JSON annotation file.
            load_segmentation: Whether segmentation masks should be loaded.
            fix_fname: If True, strip directories from stored file names and
                keep only the basename.
            training: Training vs. evaluation mode flag.
            blurring_masks_path: Optional directory with per-image
                "<name>-mask.json" files used for blurring augmentation.
            use_caching: Caching toggle (stored; usage not visible here).
            zstd_dict_path: Optional zstd dictionary path (stored; usage not
                visible here).
            filter_query: Optional query filter (stored; usage not visible here).
            coco_json_loader: Factory used to build the annotation API
                (defaults to COCO_FROM_JSON).
            limit_ids: Optional cap on the number of ids (stored; usage not
                visible here).
        """
        super().__init__(root)

        self.annFile = annFile
        self.use_caching = use_caching
        self.zstd_dict_path = zstd_dict_path

        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.load_segmentation = load_segmentation
        self.fix_fname = fix_fname
        self.filter_query = filter_query

        # self.coco is populated lazily by set_sharded_annotation_file below.
        self.coco = None
        self.coco_json_loader = coco_json_loader
        self.limit_ids = limit_ids
        self.set_sharded_annotation_file(0)
        self.training = training
        self.blurring_masks_path = blurring_masks_path
|
||||
|
||||
def _load_images(
|
||||
self, datapoint_id: int, img_ids_to_load: Optional[Set[int]] = None
|
||||
) -> Tuple[List[Tuple[int, PILImage.Image]], List[Dict[str, Any]]]:
|
||||
all_images = []
|
||||
all_img_metadata = []
|
||||
for current_meta in self.coco.loadImagesFromDatapoint(datapoint_id):
|
||||
img_id = current_meta["id"]
|
||||
if img_ids_to_load is not None and img_id not in img_ids_to_load:
|
||||
continue
|
||||
if self.fix_fname:
|
||||
current_meta["file_name"] = current_meta["file_name"].split("/")[-1]
|
||||
path = current_meta["file_name"]
|
||||
if self.blurring_masks_path is not None:
|
||||
mask_fname = os.path.basename(path).replace(".jpg", "-mask.json")
|
||||
mask_path = os.path.join(self.blurring_masks_path, mask_fname)
|
||||
if os.path.exists(mask_path):
|
||||
with open(mask_path, "r") as fopen:
|
||||
current_meta["blurring_mask"] = json.load(fopen)
|
||||
|
||||
all_img_metadata.append(current_meta)
|
||||
path = os.path.join(self.root, path)
|
||||
try:
|
||||
if ".mp4" in path and path[-4:] == ".mp4":
|
||||
# Going to load a video frame
|
||||
video_path, frame = path.split("@")
|
||||
video = VideoReader(video_path, ctx=cpu(0))
|
||||
# Convert to PIL image
|
||||
all_images.append(
|
||||
(
|
||||
img_id,
|
||||
torchvision.transforms.ToPILImage()(
|
||||
video[int(frame)].asnumpy()
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
with g_pathmgr.open(path, "rb") as fopen:
|
||||
all_images.append((img_id, PILImage.open(fopen).convert("RGB")))
|
||||
except FileNotFoundError as e:
|
||||
print(f"File not found: {path} from dataset: {self.annFile}")
|
||||
raise e
|
||||
|
||||
return all_images, all_img_metadata
|
||||
|
||||
def set_curr_epoch(self, epoch: int):
|
||||
self.curr_epoch = epoch
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
pass
|
||||
|
||||
def set_sharded_annotation_file(self, data_epoch: int):
|
||||
if self.coco is not None:
|
||||
return
|
||||
|
||||
assert g_pathmgr.isfile(
|
||||
self.annFile
|
||||
), f"please provide valid annotation file. Missing: {self.annFile}"
|
||||
annFile = g_pathmgr.get_local_path(self.annFile)
|
||||
|
||||
if self.coco is not None:
|
||||
del self.coco
|
||||
|
||||
self.coco = self.coco_json_loader(annFile)
|
||||
# Use a torch tensor here to optimize memory usage when using several dataloaders
|
||||
ids_list = list(sorted(self.coco.getDatapointIds()))
|
||||
if self.limit_ids is not None:
|
||||
local_random = random.Random(len(ids_list))
|
||||
local_random.shuffle(ids_list)
|
||||
ids_list = ids_list[: self.limit_ids]
|
||||
self.ids = torch.as_tensor(ids_list, dtype=torch.long)
|
||||
|
||||
def __getitem__(self, index: int) -> Datapoint:
|
||||
return self._load_datapoint(index)
|
||||
|
||||
def _load_datapoint(self, index: int) -> Datapoint:
|
||||
"""A separate method for easy overriding in subclasses."""
|
||||
id = self.ids[index].item()
|
||||
pil_images, img_metadata = self._load_images(id)
|
||||
queries, annotations = self.coco.loadQueriesAndAnnotationsFromDatapoint(id)
|
||||
return self.load_queries(pil_images, annotations, queries, img_metadata)
|
||||
|
||||
def load_queries(self, pil_images, annotations, queries, img_metadata):
|
||||
"""Transform the raw image and queries into a Datapoint sample."""
|
||||
images: List[Image] = []
|
||||
id2index_img = {}
|
||||
id2index_obj = {}
|
||||
id2index_find_query = {}
|
||||
id2imsize = {}
|
||||
assert len(pil_images) == len(img_metadata)
|
||||
for i in range(len(pil_images)):
|
||||
w, h = pil_images[i][1].size
|
||||
blurring_mask = None
|
||||
if "blurring_mask" in img_metadata[i]:
|
||||
blurring_mask = img_metadata[i]["blurring_mask"]
|
||||
images.append(
|
||||
Image(
|
||||
data=pil_images[i][1],
|
||||
objects=[],
|
||||
size=(h, w),
|
||||
blurring_mask=blurring_mask,
|
||||
)
|
||||
)
|
||||
id2index_img[pil_images[i][0]] = i
|
||||
id2imsize[pil_images[i][0]] = (h, w)
|
||||
|
||||
for annotation in annotations:
|
||||
image_id = id2index_img[annotation["image_id"]]
|
||||
bbox = box_xywh_to_xyxy(torch.as_tensor(annotation["bbox"])).view(1, 4)
|
||||
h, w = id2imsize[annotation["image_id"]]
|
||||
bbox[:, 0::2].mul_(w).clamp_(min=0, max=w)
|
||||
bbox[:, 1::2].mul_(h).clamp_(min=0, max=h)
|
||||
segment = None
|
||||
if self.load_segmentation and "segmentation" in annotation:
|
||||
# We're not decoding the RLE here, a transform will do it lazily later
|
||||
segment = annotation["segmentation"]
|
||||
images[image_id].objects.append(
|
||||
Object(
|
||||
bbox=bbox[0],
|
||||
area=annotation["area"],
|
||||
object_id=(
|
||||
annotation["object_id"] if "object_id" in annotation else -1
|
||||
),
|
||||
frame_index=(
|
||||
annotation["frame_index"] if "frame_index" in annotation else -1
|
||||
),
|
||||
segment=segment,
|
||||
is_crowd=(
|
||||
annotation["is_crowd"] if "is_crowd" in annotation else None
|
||||
),
|
||||
source=annotation["source"] if "source" in annotation else "",
|
||||
)
|
||||
)
|
||||
id2index_obj[annotation["id"]] = len(images[image_id].objects) - 1
|
||||
|
||||
find_queries = []
|
||||
stage2num_queries = Counter()
|
||||
for i, query in enumerate(queries):
|
||||
stage2num_queries[query["query_processing_order"]] += 1
|
||||
id2index_find_query[query["id"]] = i
|
||||
|
||||
# Sanity check: all the stages should have the same number of queries
|
||||
if len(stage2num_queries) == 0:
|
||||
num_queries_per_stage = 0
|
||||
else:
|
||||
num_queries_per_stage = stage2num_queries.most_common(1)[0][1]
|
||||
for stage, num_queries in stage2num_queries.items():
|
||||
assert (
|
||||
num_queries == num_queries_per_stage
|
||||
), f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"
|
||||
|
||||
for query_id, query in enumerate(queries):
|
||||
h, w = id2imsize[query["image_id"]]
|
||||
if (
|
||||
"input_box" in query
|
||||
and query["input_box"] is not None
|
||||
and len(query["input_box"]) > 0
|
||||
):
|
||||
bbox = box_xywh_to_xyxy(torch.as_tensor(query["input_box"])).view(-1, 4)
|
||||
bbox[:, 0::2].mul_(w).clamp_(min=0, max=w)
|
||||
bbox[:, 1::2].mul_(h).clamp_(min=0, max=h)
|
||||
if "input_box_label" in query and query["input_box_label"] is not None:
|
||||
bbox_label = torch.as_tensor(
|
||||
query["input_box_label"], dtype=torch.long
|
||||
).view(-1)
|
||||
assert len(bbox_label) == len(bbox)
|
||||
else:
|
||||
# assume the boxes are positives
|
||||
bbox_label = torch.ones(len(bbox), dtype=torch.long)
|
||||
else:
|
||||
bbox = None
|
||||
bbox_label = None
|
||||
|
||||
if "input_points" in query and query["input_points"] is not None:
|
||||
points = torch.as_tensor(query["input_points"]).view(1, -1, 3)
|
||||
points[:, :, 0:1].mul_(w).clamp_(min=0, max=w)
|
||||
points[:, :, 1:2].mul_(h).clamp_(min=0, max=h)
|
||||
else:
|
||||
points = None
|
||||
|
||||
try:
|
||||
original_image_id = int(
|
||||
img_metadata[id2index_img[query["image_id"]]]["original_img_id"]
|
||||
)
|
||||
except ValueError:
|
||||
original_image_id = -1
|
||||
|
||||
try:
|
||||
img_metadata_query = img_metadata[id2index_img[query["image_id"]]]
|
||||
coco_image_id = (
|
||||
int(img_metadata_query["coco_img_id"])
|
||||
if "coco_img_id" in img_metadata_query
|
||||
else query["id"]
|
||||
)
|
||||
except KeyError:
|
||||
coco_image_id = -1
|
||||
|
||||
try:
|
||||
original_category_id = int(query["original_cat_id"])
|
||||
except (ValueError, KeyError):
|
||||
original_category_id = -1
|
||||
|
||||
# For evaluation, we associate the ids of the object to be tracked to the query
|
||||
if query["object_ids_output"]:
|
||||
obj_id = query["object_ids_output"][0]
|
||||
obj_idx = id2index_obj[obj_id]
|
||||
image_idx = id2index_img[query["image_id"]]
|
||||
object_id = images[image_idx].objects[obj_idx].object_id
|
||||
frame_index = images[image_idx].objects[obj_idx].frame_index
|
||||
else:
|
||||
object_id = -1
|
||||
frame_index = -1
|
||||
|
||||
find_queries.append(
|
||||
FindQueryLoaded(
|
||||
# id=query["id"],
|
||||
# query_type=qtype,
|
||||
query_text=(
|
||||
query["query_text"] if query["query_text"] is not None else ""
|
||||
),
|
||||
image_id=id2index_img[query["image_id"]],
|
||||
input_bbox=bbox,
|
||||
input_bbox_label=bbox_label,
|
||||
input_points=points,
|
||||
object_ids_output=[
|
||||
id2index_obj[obj_id] for obj_id in query["object_ids_output"]
|
||||
],
|
||||
is_exhaustive=query["is_exhaustive"],
|
||||
is_pixel_exhaustive=(
|
||||
query["is_pixel_exhaustive"]
|
||||
if "is_pixel_exhaustive" in query
|
||||
else (
|
||||
query["is_exhaustive"] if query["is_exhaustive"] else None
|
||||
)
|
||||
),
|
||||
query_processing_order=query["query_processing_order"],
|
||||
inference_metadata=InferenceMetadata(
|
||||
coco_image_id=-1 if self.training else coco_image_id,
|
||||
original_image_id=(-1 if self.training else original_image_id),
|
||||
frame_index=frame_index,
|
||||
original_category_id=original_category_id,
|
||||
original_size=(h, w),
|
||||
object_id=object_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return Datapoint(
|
||||
find_queries=find_queries,
|
||||
images=images,
|
||||
raw_images=[p[1] for p in pil_images],
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.ids)
|
||||
|
||||
|
||||
class Sam3ImageDataset(CustomCocoDetectionAPI):
    """Image-level dataset with validation limits, transforms, and skip-on-error.

    Extends :class:`CustomCocoDetectionAPI` with (a) per-datapoint sanity
    limits (annotations per image, find queries per datapoint), (b) a
    transform pipeline applied after loading, and (c) a retry loop that,
    on a loading/validation failure, advances to the next index instead of
    crashing the dataloader worker.
    """

    def __init__(
        self,
        img_folder,
        ann_file,
        transforms,
        max_ann_per_img: int,
        multiplier: int,
        training: bool,
        load_segmentation: bool = False,
        max_train_queries: int = 81,
        max_val_queries: int = 300,
        fix_fname: bool = False,
        is_sharded_annotation_dir: bool = False,
        blurring_masks_path: Optional[str] = None,
        use_caching: bool = True,
        zstd_dict_path=None,
        filter_query=None,
        coco_json_loader: Callable = COCO_FROM_JSON,
        limit_ids: int = None,
    ):
        # `is_sharded_annotation_dir` is accepted for signature compatibility;
        # it is not referenced in this class.
        super().__init__(
            img_folder,
            ann_file,
            fix_fname=fix_fname,
            load_segmentation=load_segmentation,
            training=training,
            blurring_masks_path=blurring_masks_path,
            use_caching=use_caching,
            zstd_dict_path=zstd_dict_path,
            filter_query=filter_query,
            coco_json_loader=coco_json_loader,
            limit_ids=limit_ids,
        )

        self._transforms = transforms
        self.training = training
        self.max_ann_per_img = max_ann_per_img
        self.max_train_queries = max_train_queries
        self.max_val_queries = max_val_queries

        # Uniform per-datapoint sampling weight, scaled by the dataset multiplier.
        self.repeat_factors = torch.ones(len(self.ids), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.ids)}")

        # Upper bound on how many consecutive indices are tried before giving up.
        self._MAX_RETRIES = 100

    def __getitem__(self, idx):
        return self.__orig_getitem__(idx)

    def __orig_getitem__(self, idx):
        """Load datapoint `idx`, validate it, and run it through the transforms.

        On a recoverable failure the error is logged to stderr and the next
        index (modulo the dataset size) is tried, up to ``self._MAX_RETRIES``
        attempts; exhausting all attempts raises ``RuntimeError``.
        """
        remaining_attempts = self._MAX_RETRIES
        while remaining_attempts > 0:
            try:
                datapoint = super().__getitem__(idx)

                # This can be done better by filtering the offending find queries
                # However, this requires care:
                # - Delete any find/get query that may depend on the deleted one
                # - Re-compute the indexes in the pointers to account for the deleted finds
                for q in datapoint.find_queries:
                    if len(q.object_ids_output) > self.max_ann_per_img:
                        raise DecompressionBombError(
                            f"Too many outputs ({len(q.object_ids_output)})"
                        )

                max_queries = (
                    self.max_train_queries if self.training else self.max_val_queries
                )

                if len(datapoint.find_queries) > max_queries:
                    raise DecompressionBombError(
                        f"Too many find queries ({len(datapoint.find_queries)})"
                    )

                if len(datapoint.find_queries) == 0:
                    raise DecompressionBombError("No find queries")
                for transform in self._transforms:
                    datapoint = transform(datapoint, epoch=self.curr_epoch)

                return datapoint
            except (DecompressionBombError, OSError, ValueError) as error:
                sys.stderr.write(f"ERROR: got loading error on datapoint {idx}\n")
                sys.stderr.write(f"Exception: {error}\n")
                sys.stderr.write(traceback.format_exc())
                idx = (idx + 1) % len(self)
                remaining_attempts -= 1

        raise RuntimeError(
            f"Failed {self._MAX_RETRIES} times trying to load an image."
        )
|
||||
327
sam3/train/data/sam3_video_dataset.py
Normal file
327
sam3/train/data/sam3_video_dataset.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import copy
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
# from decord import cpu, VideoReader
|
||||
|
||||
from iopath.common.file_io import PathManager
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from .sam3_image_dataset import Datapoint, Sam3ImageDataset
|
||||
|
||||
|
||||
# Base seed for the per-epoch RNG: VideoGroundingDataset.set_curr_epoch
# reseeds its random.Random with SEED + epoch.
SEED = 42
|
||||
|
||||
|
||||
class VideoGroundingDataset(Sam3ImageDataset):
    """Video grounding dataset: samples frame subsets or tiles a single image.

    Extends :class:`Sam3ImageDataset` with (a) training-time subsampling of
    video frames ("stages") with a random stride and optional temporal
    reversal, (b) tiling a single image into identical pseudo-video frames,
    and (c) per-video limits on queries and masklets.
    """

    def __init__(
        self,
        num_stages_sample: int = 4,
        stage_stride_min: int = 1,
        stage_stride_max: int = 5,
        random_reverse_time_axis: bool = True,
        is_tiling_single_image: bool = False,
        # By default, we remove those find queries with geometric inputs (input_box or input_points)
        # when creating synthetic videos from frames (since they are not *video-level* text prompts).
        # If we need them later, we can sample them on-the-fly via transforms or inside the model.
        tile_img_keep_find_queries_with_geo_inputs: bool = False,
        tile_img_keep_get_queries: bool = False,
        # the maximum number of find queries (for each frame) to keep in a video; if the datapoint
        # contains more queries per frame than this limit, we subsample them to avoid OOM errors
        max_query_num: int = -1,  # the default -1 means no limit
        # whether to override the "is_exhaustive" flag of the loaded find queries to True
        # (by default, our video datasets are ingested with is_exhaustive=False, since the YTVIS format
        # annotations doesn't involve an "is_exhaustive" flag; this means that those unmatched (negative)
        # detection queries or tracking queries do not receive a classification loss given that we have
        # weak_loss=True in IABCEMdetr -- this could lead to false positives for both image detection
        # and video association.)
        override_query_is_exhaustive_to_true: bool = False,
        # the maximum number of masklets in a video; if the datapoint contains more masklets
        # than this limit, we skip the datapoint to avoid OOM errors (this is useful for
        # training with large videos that contain many objects)
        max_masklet_num_in_video: int = 300,  # 300 masklets is usually OK to avoid OOM
        **kwargs,
    ):
        """
        Loading video grounding data

        Video frame sampling parameters (for training only):
        - num_stages_sample: number of frames to sample from the video during training
        - stage_stride_min: minimum stride between sampled frames during training
        - stage_stride_max: maximum stride between sampled frames during training (if it's
          greater than stage_stride_min, the actual stride is sampled uniformly between min
          and max; during inference, we always use all frames in the video with stride=1)
        - random_reverse_time_axis: whether to randomly invert the video's temporal axis
          (i.e. playing it backwards) during training

        Remaining keyword arguments are forwarded to Sam3ImageDataset.
        """
        super().__init__(**kwargs)
        assert num_stages_sample >= 1
        assert stage_stride_min >= 1
        assert stage_stride_max >= stage_stride_min
        self.num_stages_sample = num_stages_sample
        self.stage_stride_min = stage_stride_min
        self.stage_stride_max = stage_stride_max
        self.random_reverse_time_axis = random_reverse_time_axis
        self.is_tiling_single_image = is_tiling_single_image
        self.tile_img_keep_find_queries_with_geo_inputs = (
            tile_img_keep_find_queries_with_geo_inputs
        )
        self.tile_img_keep_get_queries = tile_img_keep_get_queries
        self.max_query_num = max_query_num
        self.override_query_is_exhaustive_to_true = override_query_is_exhaustive_to_true
        self.max_masklet_num_in_video = max_masklet_num_in_video
        # Dedicated RNG so frame sampling is reproducible per epoch (see
        # set_curr_epoch), independent of the global random state.
        self.rng = random.Random()
        self.set_curr_epoch(0)

    def set_curr_epoch(self, epoch: int):
        # Reseed the sampling RNG so every epoch draws a different but
        # reproducible sequence of frame subsets.
        super().set_curr_epoch(epoch)
        self.rng.seed(SEED + epoch)

    def _load_datapoint(self, index: int) -> Datapoint:
        """Load one video datapoint, subsampling frames during training.

        NOTE(review): oversized datapoints are skipped by recursing on the
        next index; if every datapoint exceeded max_masklet_num_in_video this
        would recurse without bound -- confirm that's acceptable.
        """
        id = self.ids[index].item()
        queries, annotations = self.coco.loadQueriesAndAnnotationsFromDatapoint(id)

        # we subsample the video frames during training
        if self.training and not self.is_tiling_single_image:
            # pick a random stride for sampling query stages (`randint` includes both ends)
            stage_stride = self.rng.randint(
                self.stage_stride_min, self.stage_stride_max
            )
            stage_ids_to_keep = self._sample_stage_ids(
                queries, self.num_stages_sample, stage_stride
            )
            # filter the queries and annotations to keep only the selected stages
            # (also remap the stage ids so that they are contiguous and start from 0)
            reverse_time_axis = (
                self.rng.random() < 0.5 if self.random_reverse_time_axis else False
            )
            queries, annotations, kept_img_ids = self._filter_query_and_anns(
                queries,
                annotations,
                stage_ids_to_keep,
                remap_stage_id=True,
                reverse_time_axis=reverse_time_axis,
            )
            pil_images, img_metadata = self._load_images(id, kept_img_ids)
            if reverse_time_axis:
                # reverse the temporal ordering of the images and their metadata
                # so that the image order matches the query order
                pil_images = pil_images[::-1]
                img_metadata = img_metadata[::-1]
        else:
            pil_images, img_metadata = self._load_images(id)

        # check that all the images have the same image size (they are expected
        # to have the same image size since they are frames from the same video)
        assert all(p.size == pil_images[0][1].size for _, p in pil_images)

        queries.sort(key=lambda q: q["query_processing_order"])
        if self.override_query_is_exhaustive_to_true:
            for query in queries:
                query["is_exhaustive"] = True
        datapoint = self.load_queries(pil_images, annotations, queries, img_metadata)

        # skip datapoints with too many masklets to avoid OOM errors
        num_masklets_in_video = len(datapoint.images[0].objects)
        if num_masklets_in_video > self.max_masklet_num_in_video > 0:
            logging.warning(
                f"Datapoint {id} has ({num_masklets_in_video=}), exceeding "
                f"the maximum allowed ({self.max_masklet_num_in_video}). "
                "Skipping this datapoint."
            )
            next_index = (index + 1) % len(self)
            return self._load_datapoint(next_index)  # move to the next datapoint

        if self.is_tiling_single_image:
            datapoint = self._tile_single_image_data(datapoint, self.num_stages_sample)
        if self.max_query_num > 0:
            datapoint = self._subsample_queries(datapoint, self.max_query_num)

        # ensure that all find queries have the same processing order as their image id
        for query in datapoint.find_queries:
            assert query.image_id == query.query_processing_order, (
                f"find query has inconsistent image_id and "
                f"query_processing_order: {query.image_id=} vs "
                f"{query.query_processing_order=}"
            )
        return datapoint

    def _sample_stage_ids(self, queries, num_stages_sample, stage_stride):
        """Sample a subset of stage ids from all queries."""
        # Later we can perhaps turn it into a Sampler class to be more flexible.
        all_stage_ids = sorted(set(q["query_processing_order"] for q in queries))
        num_stages_total = len(all_stage_ids)
        if num_stages_total < num_stages_sample:
            raise ValueError("Not enough stages to sample")

        # the difference in index between the first and the last sampled stage ids
        b_e_gap = (num_stages_sample - 1) * stage_stride
        if b_e_gap > num_stages_total - 1:
            # In this case, it's not possible to sample with the provided stride,
            # so we use the maximum possible stride.
            prev_stage_stride = stage_stride
            stage_stride = math.floor((num_stages_total - 1) / (num_stages_sample - 1))
            logging.info(
                f"lowering stride from {prev_stage_stride} to {stage_stride} to "
                f"sample {num_stages_sample} stages (from {num_stages_total} total)"
            )
            b_e_gap = (num_stages_sample - 1) * stage_stride

        # randomly select a starting stage id (`randint` includes both ends)
        b_max = len(all_stage_ids) - 1 - b_e_gap
        b = self.rng.randint(0, b_max)
        e = b + b_e_gap
        stage_ids_to_keep = all_stage_ids[b : e + 1 : stage_stride]
        return stage_ids_to_keep

    def _filter_query_and_anns(
        self, queries, annotations, stage_ids_to_keep, remap_stage_id, reverse_time_axis
    ):
        """Filter queries and annotations to only keep those in `stage_ids_to_keep`."""
        stage_ids_to_keep = set(stage_ids_to_keep)
        kept_img_ids = set()
        kept_stage_ids = set()

        # Filter queries -- keep those queries with stage_id in `stage_ids_to_keep`
        filtered_queries = []
        for query in queries:
            input_box = query.get("input_box", None)
            input_points = query.get("input_points", None)
            has_geo_input = input_box is not None or input_points is not None
            if has_geo_input and not self.tile_img_keep_find_queries_with_geo_inputs:
                continue
            stage_id = query["query_processing_order"]
            if stage_id in stage_ids_to_keep:
                kept_img_ids.add(query["image_id"])
                kept_stage_ids.add(stage_id)
                filtered_queries.append(query)
        # Check that all frames in `stage_ids_to_keep` are present after filtering
        all_frame_present = kept_stage_ids == stage_ids_to_keep
        assert all_frame_present, f"{kept_stage_ids=} vs {stage_ids_to_keep=}"
        if remap_stage_id:
            # Remap those kept stage ids to be contiguous and starting from 0
            # (sorting in reverse order implements the temporal flip).
            old_stage_ids = sorted(kept_stage_ids, reverse=reverse_time_axis)
            stage_id_old2new = {old: new for new, old in enumerate(old_stage_ids)}
            for query in filtered_queries:
                ptr_x_is_empty = query["ptr_x_query_id"] in [None, -1]
                ptr_y_is_empty = query["ptr_y_query_id"] in [None, -1]
                assert (
                    ptr_x_is_empty and ptr_y_is_empty
                ), "Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers"
                query["query_processing_order"] = stage_id_old2new[
                    query["query_processing_order"]
                ]

        # Filter annotations -- keep those annotations with image_id in `kept_img_ids`
        filtered_annotations = [
            ann for ann in annotations if ann["image_id"] in kept_img_ids
        ]

        return filtered_queries, filtered_annotations, kept_img_ids

    def _tile_single_image_data(self, datapoint: Datapoint, num_stages_sample: int):
        """
        Tile a single image and its queries to simulate video frames. The output is a
        datapoint with *identical video frames* (i.e. the same static image) and needs
        further transforms (e.g. affine) to get video frames with different content.
        """
        # tile `images: List[Image]`
        assert len(datapoint.images) == 1, "Expected only one single image"
        tiled_images = [
            copy.deepcopy(datapoint.images[0]) for _ in range(num_stages_sample)
        ]
        for stage_id, img in enumerate(tiled_images):
            for obj in img.objects:
                obj.frame_index = stage_id

        # tile `raw_images: Optional[List[PILImage.Image]] = None`
        tiled_raw_images = None
        if datapoint.raw_images is not None:
            assert len(datapoint.raw_images) == 1, "Expected only one single image"
            tiled_raw_images = [
                datapoint.raw_images[0].copy() for _ in range(num_stages_sample)
            ]

        # tile `find_queries: List[FindQueryLoaded]`
        tiled_find_queries_per_stage = [[] for _ in range(num_stages_sample)]
        for query in datapoint.find_queries:
            assert query.image_id == 0
            assert query.query_processing_order == 0
            # check and make sure that a query doesn't contain pointers or references
            # to other queries (that cannot be tiled)
            assert query.ptr_x is None and query.ptr_y is None
            assert query.ptr_mem is None
            # assert query.wkdata_qid is None
            # assert query.other_positive_qids is None
            # assert query.negative_qids is None
            has_geo_input = (
                query.input_bbox is not None or query.input_points is not None
            )
            if has_geo_input and not self.tile_img_keep_find_queries_with_geo_inputs:
                continue
            for stage_id in range(num_stages_sample):
                # copy the query and update the image_id
                new_query = copy.deepcopy(query)
                new_query.image_id = stage_id
                new_query.query_processing_order = stage_id
                if new_query.inference_metadata is not None:
                    new_query.inference_metadata.frame_index = stage_id
                tiled_find_queries_per_stage[stage_id].append(new_query)

        tiled_find_queries = sum(tiled_find_queries_per_stage, [])

        # tile `get_queries: List[GetQuery]` -- we skip them for now (since they involve
        # a pointer to a find query that is complicated to tile, and there is not an
        # imminent use case for them in the video grounding task in the near future)
        if self.tile_img_keep_get_queries:
            raise NotImplementedError("Tiling get queries is not implemented yet")
        else:
            tiled_get_queries = []

        # NOTE(review): the `Datapoint` dataclass in sam3_image_dataset (as
        # imported here) declares no `get_queries` field -- confirm it accepts
        # this keyword, otherwise this call raises TypeError.
        return Datapoint(
            images=tiled_images,
            raw_images=tiled_raw_images,
            find_queries=tiled_find_queries,
            get_queries=tiled_get_queries,
        )

    def _subsample_queries(self, datapoint: Datapoint, max_query_num: int):
        """Subsample to keep at most `max_query_num` queries per frame in a datapoint."""
        # aggregate the find queries per stage
        num_frames = max(q.query_processing_order for q in datapoint.find_queries) + 1
        find_queries_per_stage = [[] for _ in range(num_frames)]
        for query in datapoint.find_queries:
            find_queries_per_stage[query.query_processing_order].append(query)

        # verify that all the stages have the same number of queries
        num_queries_per_stage = len(find_queries_per_stage[0])
        for queries in find_queries_per_stage:
            assert len(queries) == num_queries_per_stage
        if max_query_num <= 0 or num_queries_per_stage <= max_query_num:
            return datapoint

        # subsample the queries to keep only `max_query_num` queries
        # (the same indices are kept on every frame so tracks stay aligned)
        sampled_inds = self.rng.sample(range(num_queries_per_stage), max_query_num)
        sampled_find_queries_per_stage = [
            [queries[idx] for idx in sampled_inds] for queries in find_queries_per_stage
        ]
        sampled_find_queries = sum(sampled_find_queries_per_stage, [])
        # NOTE(review): see _tile_single_image_data -- the imported `Datapoint`
        # declares no `get_queries` field; confirm this keyword is accepted.
        return Datapoint(
            images=datapoint.images,
            raw_images=datapoint.raw_images,
            find_queries=sampled_find_queries,
            get_queries=datapoint.get_queries,
        )
|
||||
52
sam3/train/data/torch_dataset.py
Normal file
52
sam3/train/data/torch_dataset.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
from typing import Callable, Iterable, Optional
|
||||
|
||||
from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset
|
||||
|
||||
|
||||
class TorchDataset:
    """Builds a (optionally distributed) DataLoader over a map-style dataset.

    When ``enable_distributed_sampler`` is True, a DistributedSampler shards
    the dataset across processes (torch.distributed must be initialized) and
    owns the shuffling; otherwise shuffling is delegated to the DataLoader.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        enable_distributed_sampler=True,
    ) -> None:
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn
        # Iterable datasets have no len()/indexing, which both the sampler
        # and shuffling rely on.
        assert not isinstance(self.dataset, IterableDataset), "Not supported yet"
        if enable_distributed_sampler:
            self.sampler = DistributedSampler(self.dataset, shuffle=self.shuffle)
        else:
            self.sampler = None

    def get_loader(self, epoch) -> Iterable:
        """Return a DataLoader for `epoch`, reseeding the sampler and notifying the dataset."""
        if self.sampler:
            # DistributedSampler derives its per-epoch shuffling seed from this.
            self.sampler.set_epoch(epoch)
        if hasattr(self.dataset, "epoch"):
            self.dataset.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            # Bug fix: `shuffle` used to be dropped entirely here, so with
            # enable_distributed_sampler=False a shuffle=True request was
            # silently ignored. DataLoader forbids combining shuffle=True with
            # an explicit sampler, so it is only forwarded when no sampler is
            # set (the DistributedSampler shuffles on its own).
            shuffle=self.shuffle if self.sampler is None else False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
            sampler=self.sampler,
            collate_fn=self.collate_fn,
            worker_init_fn=self.worker_init_fn,
        )
|
||||
1
sam3/train/loss/__init__.py
Normal file
1
sam3/train/loss/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
1319
sam3/train/loss/loss_fns.py
Normal file
1319
sam3/train/loss/loss_fns.py
Normal file
File diff suppressed because it is too large
Load Diff
113
sam3/train/loss/mask_sampling.py
Normal file
113
sam3/train/loss/mask_sampling.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
|
||||
def point_sample(input, point_coords, **kwargs):
    """Bilinearly sample `input` at normalized point locations.

    Thin wrapper around :func:`torch.nn.functional.grid_sample` that takes
    point coordinates in the [0, 1] x [0, 1] square (instead of [-1, 1]) and
    additionally accepts flat (N, P, 2) point tensors on top of the usual
    (N, Hgrid, Wgrid, 2) grids.

    Args:
        input (Tensor): Feature map of shape (N, C, H, W).
        point_coords (Tensor): Points of shape (N, P, 2) or (N, Hgrid, Wgrid, 2)
            with coordinates normalized to [0, 1] x [0, 1].
        **kwargs: Forwarded verbatim to ``grid_sample`` (e.g. ``align_corners``).

    Returns:
        Tensor: Sampled features of shape (N, C, P) or (N, C, Hgrid, Wgrid),
        interpolated bilinearly exactly as ``grid_sample`` would.
    """
    squeeze_output = point_coords.dim() == 3
    # grid_sample only understands 4D grids, so lift (N, P, 2) to (N, P, 1, 2).
    grid = point_coords.unsqueeze(2) if squeeze_output else point_coords
    # Remap [0, 1] coordinates to grid_sample's [-1, 1] convention.
    sampled = F.grid_sample(input, 2.0 * grid - 1.0, **kwargs)
    return sampled.squeeze(3) if squeeze_output else sampled
|
||||
|
||||
|
||||
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
|
||||
def get_uncertain_point_coords_with_randomness(
    logits: torch.Tensor,
    uncertainty_func: Callable,
    num_points: int,
    oversample_ratio: int,
    importance_sample_ratio: float,
) -> torch.Tensor:
    """
    Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties
    are calculated for each point using the 'uncertainty_func' function that takes the point's logit
    prediction as input.
    See the PointRend paper for details.

    Args:
        logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
            class-specific or class-agnostic prediction.
        uncertainty_func (Callable): A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
            contains logit predictions for P points and returns their uncertainties as a Tensor of
            shape (N, 1, P).
        num_points (int): The number of points P to sample.
        oversample_ratio (int): Oversampling parameter; num_points * oversample_ratio candidate
            points are drawn before the most uncertain ones are selected.
        importance_sample_ratio (float): Ratio of points that are sampled via importance sampling
            (the remainder are drawn uniformly at random).

    Returns:
        point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
            sampled points.
    """
    assert oversample_ratio >= 1
    assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
    num_boxes = logits.shape[0]
    num_sampled = int(num_points * oversample_ratio)
    # Candidate points: uniform in [0, 1] x [0, 1].
    point_coords = torch.rand(num_boxes, num_sampled, 2, device=logits.device)
    point_logits = point_sample(logits, point_coords, align_corners=False)
    # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
    # Calculating uncertainties of the predictions first and sampling them for points leads
    # to incorrect results.
    # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
    # two predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
    # However, if we calculate uncertainties for the predictions first,
    # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
    point_uncertainties = uncertainty_func(point_logits)
    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    # Per-row indices (into the oversampled candidate set) of the most uncertain points.
    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    # Flatten the indices: add a per-row offset so `idx` indexes into the
    # flattened (num_boxes * num_sampled, 2) view of `point_coords`.
    shift = num_sampled * torch.arange(
        num_boxes, dtype=torch.long, device=logits.device
    )
    idx += shift[:, None]
    point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
        num_boxes, num_uncertain_points, 2
    )
    if num_random_points > 0:
        # Top up with uniformly random points to reach num_points total.
        point_coords = torch.cat(
            [
                point_coords,
                torch.rand(num_boxes, num_random_points, 2, device=logits.device),
            ],
            dim=1,
        )
    return point_coords
|
||||
|
||||
|
||||
# Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py
|
||||
def calculate_uncertainty(logits: torch.Tensor) -> torch.Tensor:
    """Estimate uncertainty as the negative L1 distance between the logit and 0.

    Logits close to 0 (i.e. probabilities close to 0.5) are the most
    uncertain, so they receive the highest (least negative) score.

    Args:
        logits (Tensor): shape (R, 1, ...) class-agnostic predicted masks.

    Returns:
        Tensor: shape (R, 1, ...) uncertainty scores; the most uncertain
        locations carry the highest score.
    """
    # Class-agnostic prediction: exactly one channel is expected.
    assert logits.shape[1] == 1
    return torch.abs(logits).neg()
|
||||
203
sam3/train/loss/sam3_loss.py
Normal file
203
sam3/train/loss/sam3_loss.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.model.model_misc import SAM3Output
|
||||
|
||||
from sam3.train.utils.distributed import get_world_size
|
||||
|
||||
from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks
|
||||
|
||||
|
||||
class DummyLoss(torch.nn.Module):
    """Placeholder loss used during evaluation; always yields a zero loss."""

    def __init__(
        self,
        core_loss_key: str = CORE_LOSS_KEY,
        device: str = "cuda",
        **kwargs,
    ):
        super().__init__()
        # Key under which the (zero) total loss is reported.
        self.core_loss_key = core_loss_key
        self.device = torch.device(device)

    def _zero(self):
        # Fresh scalar zero on the configured device.
        return torch.tensor(0.0, device=self.device)

    def forward(self, *args, **kwargs):
        """Ignore all arguments and return a zero core loss."""
        return {self.core_loss_key: self._zero()}

    def accumulate(self, out_dict):
        """Called by iterative losses; ensures the core loss key is present."""
        out_dict.setdefault(self.core_loss_key, self._zero())
        return out_dict
|
||||
|
||||
|
||||
class Sam3LossWrapper(torch.nn.Module):
    """Aggregates all SAM3 "find" losses over a (possibly multi-stage) model output.

    For every output stage/step this wrapper:
      * computes a target-box count used to normalize the losses
        (`normalization` selects global-average / local / none),
      * applies each loss in `loss_fns_find` to the one-to-one (o2o)
        predictions using indices precomputed by the model,
      * optionally applies the same losses to one-to-many (o2m) predictions
        (output keys ending in "_o2m"), matched here via `o2m_matcher`
        (or `matcher`, depending on `use_o2m_matcher_on_o2m_aux`),
      * optionally adds a semantic-segmentation loss and rescales the total.

    Individual loss terms are reported with suffixes: "_aux_{i}" for
    auxiliary outputs, "_fs" for the first stage, "_o2m" for one-to-many
    terms. The summed total is reported under CORE_LOSS_KEY.
    """

    def __init__(
        self,
        loss_fns_find,
        normalization="global",
        matcher=None,
        o2m_matcher=None,
        o2m_weight=1.0,
        use_o2m_matcher_on_o2m_aux=True,
        loss_fn_semantic_seg=None,
        normalize_by_valid_object_num=False,
        normalize_by_stage_num=False,
        scale_by_find_batch_size=False,
    ):
        """
        Args:
            loss_fns_find: iterable of loss modules applied to the "find" outputs.
            normalization (str): "global" (average box count across workers),
                "local" (per-worker count), or "none" (no normalization).
            matcher: one-to-one matcher, used for o2m aux outputs when
                `use_o2m_matcher_on_o2m_aux` is False.
            o2m_matcher: one-to-many matcher for the "_o2m" predictions.
            o2m_weight (float): multiplier applied to every o2m loss term.
            use_o2m_matcher_on_o2m_aux (bool): whether to use the o2m matcher
                (rather than `matcher`) on o2m queries of auxiliary outputs.
            loss_fn_semantic_seg: optional semantic-segmentation loss module.
            normalize_by_valid_object_num (bool): count only boxes with
                non-zero width/height when normalizing.
            normalize_by_stage_num (bool): divide the core loss by the number
                of find stages (balances image vs. video batches).
            scale_by_find_batch_size (bool): scale the core loss by
                sqrt(batch size).
        """
        super().__init__()
        self.loss_fns_find = loss_fns_find
        assert normalization in ["global", "local", "none"]
        self.normalization = normalization
        self.normalize_by_valid_object_num = normalize_by_valid_object_num
        self.normalize_by_stage_num = normalize_by_stage_num
        self.matcher = matcher
        self.o2m_matcher = o2m_matcher
        self.o2m_weight = o2m_weight
        # whether to use the o2m matcher on the o2m queries in auxiliary outputs
        self.use_o2m_matcher_on_o2m_aux = use_o2m_matcher_on_o2m_aux
        self.loss_fn_semantic_seg = loss_fn_semantic_seg
        self.scale_by_find_batch_size = scale_by_find_batch_size

    def _get_num_boxes(self, targets):
        """Return the (possibly cross-worker averaged) target box count used
        to normalize the losses; clamped to at least 1."""
        # the average number of target boxes for loss normalization
        if self.normalize_by_valid_object_num:
            # valid boxes are those with non-zero height and width
            # (padded invisible boxes have zero width/height)
            boxes_hw = targets["boxes"].view(-1, 4)  # cx, cy, w, h
            num_boxes = (boxes_hw[:, 2:] > 0).all(dim=-1).sum().float()
        else:
            num_boxes = targets["num_boxes"].sum().float()
        if self.normalization == "global":
            # Average the count across all distributed workers.
            torch.distributed.all_reduce(num_boxes)
            num_boxes = torch.clamp(num_boxes / get_world_size(), min=1)
        elif self.normalization == "local":
            num_boxes = torch.clamp(num_boxes, min=1)
        elif self.normalization == "none":
            num_boxes = 1
        return num_boxes

    def compute_loss(self, nested_out, targets):
        """Compute all configured losses for one output dict (including its
        auxiliary and first-stage sub-outputs) against `targets`.

        Returns a dict of individual loss terms plus the summed total under
        CORE_LOSS_KEY.
        """
        num_boxes = self._get_num_boxes(targets)
        o2m_out_is_valid = nested_out.get("o2m_out_is_valid", None)
        o2m_target_is_valid_padded = nested_out.get("o2m_target_is_valid_padded", None)

        # Get a list of outputs, including auxiliary and first stage outputs
        output_list = [(nested_out, "", False)]  # (out, suffix, is_aux)
        if "aux_outputs" in nested_out:
            output_list.extend(
                (aux_out, f"_aux_{i}", True)
                for i, aux_out in enumerate(nested_out["aux_outputs"])
            )
        if "first_stage" in nested_out:
            output_list.append((nested_out["first_stage"], "_fs", True))

        # Compute all the requested losses
        losses = {}
        total_core_loss = 0.0
        for out, suffix, is_aux in output_list:
            # o2o matcher indices need to be computed by the model (as the video model requires
            # a specific way of matching free and locked indices beyond just calling the matcher)
            indices = out["indices"]
            has_o2m_out = "pred_logits_o2m" in out
            if has_o2m_out:
                # Strip the "_o2m" suffix so o2m predictions look like a
                # regular output dict to the loss functions.
                o2m_out = {
                    k[: -len("_o2m")]: v for k, v in out.items() if k.endswith("_o2m")
                }
                # o2m targets are the same as the o2o targets (assuming repeat=1)
                o2m_targets = targets
                if self.use_o2m_matcher_on_o2m_aux or not is_aux:
                    o2m_indices = self.o2m_matcher(
                        o2m_out,
                        o2m_targets,
                        out_is_valid=o2m_out_is_valid,
                        target_is_valid_padded=o2m_target_is_valid_padded,
                    )
                else:
                    o2m_indices = self.matcher(
                        o2m_out,
                        o2m_targets,
                        out_is_valid=o2m_out_is_valid,
                        target_is_valid_padded=o2m_target_is_valid_padded,
                    )

            for loss_fn in self.loss_fns_find:
                l_dict = loss_fn(
                    outputs=out,
                    targets=targets,
                    indices=indices,
                    num_boxes=num_boxes,
                    is_aux=is_aux,
                )
                total_core_loss += l_dict.pop(CORE_LOSS_KEY)
                losses.update({f"{k}{suffix}": v for k, v in l_dict.items()})

                compute_o2m_loss = has_o2m_out
                # a special handling to allow turning off mask loss in o2m
                # (to be compatible with the original implementation)
                if isinstance(loss_fn, Masks):
                    compute_o2m_loss = compute_o2m_loss and "pred_masks" in o2m_out
                if isinstance(loss_fn, Det2TrkAssoc):
                    compute_o2m_loss = False  # Det2TrkAssoc does not support o2m
                if compute_o2m_loss:
                    l_dict = loss_fn(
                        outputs=o2m_out,
                        targets=o2m_targets,
                        indices=o2m_indices,
                        num_boxes=num_boxes,
                        is_aux=is_aux,
                    )
                    # Down/up-weight all o2m terms by a single scalar.
                    for k in l_dict:
                        l_dict[k] *= self.o2m_weight
                    total_core_loss += l_dict.pop(CORE_LOSS_KEY)
                    losses.update({f"{k}{suffix}_o2m": v for k, v in l_dict.items()})

        losses[CORE_LOSS_KEY] = total_core_loss
        return losses

    def forward(self, find_stages: SAM3Output, find_targets):
        """Sum `compute_loss` over all find stages (and all interactive steps
        within each stage), applying the optional semantic-seg loss and the
        stage-count / batch-size rescalings along the way."""
        if find_stages.loss_stages is not None:
            # Only compute the loss on the stages the model flagged.
            find_targets = [find_targets[i] for i in find_stages.loss_stages]
        with SAM3Output.iteration_mode(
            find_stages, iter_mode=SAM3Output.IterMode.ALL_STEPS_PER_STAGE
        ) as find_stages:
            assert len(find_stages) == len(find_targets)
            total_losses = {}
            for stage_outputs, stage_targets in zip(find_stages, find_targets):
                # Same targets for every step of the stage.
                stage_targets = [stage_targets] * len(stage_outputs)
                # If there are multiple steps within a stage, compute the loss for all of them (e.g. interactivity)
                for outputs, targets in zip(stage_outputs, stage_targets):
                    cur_losses = self.compute_loss(outputs, targets)

                    if self.loss_fn_semantic_seg is not None:
                        cur_losses_semantic = self.loss_fn_semantic_seg(
                            outputs, targets
                        )
                        cur_losses[CORE_LOSS_KEY] += cur_losses_semantic.pop(
                            CORE_LOSS_KEY
                        )
                        # make sure the semantic losses don't overlap with the find losses
                        assert set(cur_losses).isdisjoint(set(cur_losses_semantic))
                        cur_losses.update(cur_losses_semantic)

                    # Optionally, normalize the loss by the number of find stages (training video frames) so that
                    # image batches and video batches have similar loss scales. (Otherwise video batches would
                    # have a much higher loss scale due to summing the losses over all the find stages.)
                    if self.normalize_by_stage_num:
                        cur_losses[CORE_LOSS_KEY] /= len(find_stages)

                    if self.scale_by_find_batch_size:
                        bs = targets["num_boxes"].shape[0]
                        # sqrt scaling based on the "effective" batch size
                        cur_losses[CORE_LOSS_KEY] *= bs**0.5

                    # Accumulate per-key sums across stages/steps.
                    for k, v in cur_losses.items():
                        if k not in total_losses:
                            total_losses[k] = v
                        else:
                            total_losses[k] += v

            return total_losses
|
||||
321
sam3/train/loss/sigmoid_focal_loss.py
Normal file
321
sam3/train/loss/sigmoid_focal_loss.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
"""Triton kernel for faster and memory efficient sigmoid focal loss"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch._inductor.runtime.triton_helpers import libdevice
|
||||
|
||||
"""
|
||||
|
||||
The sigmoid focal loss is defined as:
|
||||
|
||||
prob = inputs.sigmoid()
|
||||
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
|
||||
p_t = prob * targets + (1 - prob) * (1 - targets)
|
||||
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
||||
loss = alpha_t * ce_loss * ((1 - p_t) ** gamma)
|
||||
|
||||
Where alpha and gamma are scalar parameters, inputs are the logits, targets the float targets.
|
||||
|
||||
We implement two versions of the sigmoid focal loss: with and without sum reduction.
|
||||
The latter is implemented with built-in reduction to avoid materializing wrt the output of the loss.
|
||||
This can help save a bit of peak memory.
|
||||
|
||||
The reduction version is implemented using somewhat of a hack. Pytorch's generated kernels usually do the point-wise operation in a first kernel, and implement the reduction another kernel launched on a grid of size 1, where the reduction happens as a for loop in the triton kernel.
|
||||
Since we want to fuse those two kernels, that is not a good idea: we'd have to launch the overall kernel on a grid of size 1, which is obviously inefficient.
|
||||
On the other hand, typical CUDA algorithms for reduction (eg reduction tree) are hard to implement in triton due to the lack of thread sync primitives.
|
||||
We settle for a version that abuses triton's atomic_add: we can have all threads simply add to the same location.
|
||||
In practice, this is not good, since it creates a massive bottleneck on the semaphore for that single memory location. So instead, we create M reduction locations. Each thread will simply write to thread_id%M. The python code can finally sum over the M reductions.
|
||||
M = 32 works fine in benchmarking tests. The forward is a tiny bit slower compared to the non-reduced kernel, but the backward breaks even due to one less memory allocation.
|
||||
"""
|
||||
|
||||
|
||||
@triton.jit
def _inner_focal_loss_fwd(inputs, targets, alpha, gamma):
    """Element-wise sigmoid focal loss (forward), shared by both launch kernels.

    Computes alpha_t * (1 - p_t)**gamma * BCE(inputs, targets) for each lane,
    where `inputs` are logits and `targets` are float labels in [0, 1].
    """
    inv_targets = 1 - targets
    # Sigmoid
    sig = tl.sigmoid(inputs)

    # Binary cross entropy with logits
    # In practice, we want the following:
    # bce_loss = -targets * tl.log(sig) - (1 - targets) * tl.log(1 - sig)
    # However, the above is not numerically stable.
    # We're also not directly taking the sum here, so the usual log-sum-exp trick doesn't apply
    # The bce can be reformulated, after algebraic manipulation, to
    # bce_loss = log(1 + exp(-x)) + x * (1-y)
    # This is still not stable, because for large (-x) the exponential will blow up.
    # We'll use the following alternate formulation:
    # bce_loss = max(x, 0) - x * y + log(1 + exp(-abs(x)))
    # Let's show that it's equivalent:
    # Case x>=0: abs(x) = x , max(x, 0) = x
    #   so we get x - x * y + log(1 + exp(-x)) which is equivalent
    # Case x<0: abs(x) = -x, max(x, 0) = 0
    #   we have log(1 + exp(-abs(x))) = log(1 + exp(x)) = log(exp(x)(1 + exp(-x))) = x+log(1 + exp(-x))
    #   plugging it in, we get
    #   0 - x * y + x + log(1 + exp(-x)), which is also equivalent
    # Note that this is stable because now the exponent are guaranteed to be below 0.
    max_val = tl.clamp(inputs, min=0, max=1e9)
    bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))

    # Modulating factor: (1 - p_t)**gamma down-weights easy examples.
    p_t = sig * targets + (1 - sig) * inv_targets
    mod_factor = libdevice.pow(1 - p_t, gamma)

    # Alpha factor: class-balancing weight (alpha for positives, 1-alpha for negatives).
    alpha_t = alpha * targets + (1 - alpha) * inv_targets

    # Final loss calculation
    return alpha_t * mod_factor * bce_loss
|
||||
|
||||
|
||||
# Non-reduced version
|
||||
@triton.jit
def sigmoid_focal_loss_fwd_kernel(
    inputs_ptr,
    targets_ptr,
    loss_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Forward kernel without reduction: writes the element-wise focal loss
    to `loss_ptr`. One program instance handles BLOCK_SIZE contiguous elements."""
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block: lanes past n_elements are masked out.
    mask = offset < n_elements

    # Load data (compute in fp32 regardless of the input dtype)
    inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
    targets = tl.load(targets_ptr + offset, mask=mask)

    final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma)

    # Store result
    tl.store(loss_ptr + offset, final_loss, mask=mask)
|
||||
|
||||
|
||||
# version with reduction
|
||||
@triton.jit
def sigmoid_focal_loss_fwd_kernel_reduce(
    inputs_ptr,
    targets_ptr,
    loss_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
    REDUCE_SIZE: tl.constexpr,
):
    """Forward kernel with built-in sum reduction.

    Each program instance sums its BLOCK_SIZE partial losses and atomically
    adds them to one of REDUCE_SIZE accumulator slots (slot = pid % REDUCE_SIZE)
    to avoid contention on a single memory location; the Python caller sums
    the REDUCE_SIZE slots afterwards.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    reduce_loc = pid % REDUCE_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    # Load data (fp32 math)
    inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
    targets = tl.load(targets_ptr + offset, mask=mask)

    # Multiply by `mask` so out-of-range lanes (loaded as 0) contribute 0 to the sum.
    final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma) * mask

    fl = tl.sum(final_loss)

    # Store result: atomic since several program instances share each slot.
    tl.atomic_add(loss_ptr + reduce_loc, fl)
|
||||
|
||||
|
||||
@triton.jit
def _inner_focal_loss_bwd(inputs, targets, alpha, gamma):
    """Element-wise gradient of the sigmoid focal loss w.r.t. the logits.

    Recomputes the forward quantities (cheaper than storing them) and applies
    the product rule: d/dx [alpha_t * mod * bce] =
    alpha_t * (bce' * mod + mod' * bce).
    """
    inv_targets = 1 - targets

    # Recompute forward (same stable BCE formulation as _inner_focal_loss_fwd)
    max_val = tl.clamp(inputs, min=0, max=1e9)
    bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))

    # Sigmoid
    sig = tl.sigmoid(inputs)
    inv_sig = 1 - sig

    # Modulating factor; keep (1-p_t)**(gamma-1) around since the derivative reuses it.
    p_t = sig * targets + inv_sig * inv_targets
    tmp = libdevice.pow(1 - p_t, gamma - 1)
    mod_factor = tmp * (1 - p_t)

    # Alpha factor
    alpha_t = alpha * targets + (1 - alpha) * inv_targets

    # Now computing the derivatives
    # d(p_t)/dx: sigma'(x) = sig*(1-sig), signed by whether the target is 1 or 0.
    d_pt = (2 * targets - 1) * sig * inv_sig
    d_mod_factor = -gamma * d_pt * tmp

    # d(BCE)/dx for BCE-with-logits is simply sigmoid(x) - y.
    d_bce_loss = sig - targets

    return alpha_t * (d_bce_loss * mod_factor + d_mod_factor * bce_loss)
|
||||
|
||||
|
||||
@triton.jit
def sigmoid_focal_loss_bwd_kernel(
    inputs_ptr,
    targets_ptr,
    grad_inputs_ptr,
    grad_out_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Backward kernel for the non-reduced loss: grad_in = grad_out * dL/dx,
    with a per-element upstream gradient tensor."""
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block.
    mask = offset < n_elements
    input_ptrs = inputs_ptr + offset
    target_ptrs = targets_ptr + offset
    grad_input_ptrs = grad_inputs_ptr + offset
    grad_out_ptrs = grad_out_ptr + offset
    # Load data (fp32 math)
    inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
    targets = tl.load(target_ptrs, mask=mask)
    grad_out = tl.load(grad_out_ptrs, mask=mask)
    d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
    tl.store(grad_input_ptrs, d_loss, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
def sigmoid_focal_loss_bwd_kernel_reduce(
    inputs_ptr,
    targets_ptr,
    grad_inputs_ptr,
    grad_out_ptr,
    alpha: float,
    gamma: float,
    n_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Backward kernel for the sum-reduced loss.

    The only difference from `sigmoid_focal_loss_bwd_kernel` is that the
    upstream gradient is a single scalar (the reduced loss has one output),
    broadcast to every element.
    """
    # The only difference is that the gradient is now a single scalar
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    input_ptrs = inputs_ptr + offset
    target_ptrs = targets_ptr + offset
    grad_input_ptrs = grad_inputs_ptr + offset
    # Load data (fp32 math)
    inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
    targets = tl.load(target_ptrs, mask=mask)
    # Scalar upstream gradient, shared by all lanes.
    grad_out = tl.load(grad_out_ptr)
    d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
    tl.store(grad_input_ptrs, d_loss, mask=mask)
|
||||
|
||||
|
||||
class SigmoidFocalLoss(torch.autograd.Function):
    """Autograd wrapper around the element-wise (non-reduced) triton focal loss.

    Returns a float32 loss tensor with the same shape as the logits; the
    caller is responsible for any reduction.
    """

    # Elements processed per triton program instance.
    BLOCK_SIZE = 256

    @staticmethod
    def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
        """inputs: logits; targets: float labels, same numel as inputs."""
        n_elements = inputs.numel()
        assert targets.numel() == n_elements
        input_shape = inputs.shape
        # Flatten to 1D buffers for the kernel.
        # NOTE(review): `.view(-1)` requires contiguous inputs; `.reshape(-1)`
        # would also accept non-contiguous tensors — confirm callers always
        # pass contiguous logits/targets.
        inputs = inputs.view(-1).contiguous()
        targets = targets.view(-1).contiguous()
        # Loss is always accumulated in float32.
        loss = torch.empty(inputs.shape, dtype=torch.float32, device=inputs.device)
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_fwd_kernel[grid](
            inputs, targets, loss, alpha, gamma, n_elements, SigmoidFocalLoss.BLOCK_SIZE
        )
        # Save inputs/targets (not the loss) — backward recomputes the rest.
        ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
        ctx.alpha = alpha
        ctx.gamma = gamma
        return loss.view(input_shape)

    @staticmethod
    def backward(ctx, grad_output):
        """Per-element upstream gradient; alpha/gamma get no gradient."""
        inputs, targets = ctx.saved_tensors
        alpha = ctx.alpha
        gamma = ctx.gamma
        n_elements = inputs.numel()
        input_shape = inputs.shape
        grad_inputs = torch.empty(
            inputs.shape, dtype=grad_output.dtype, device=grad_output.device
        )
        inputs_ptr = inputs.view(-1).contiguous()
        targets_ptr = targets.view(-1).contiguous()
        grad_output_ptr = grad_output.view(-1).contiguous()
        grad_inputs_ptr = grad_inputs
        assert grad_output.numel() == n_elements
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_bwd_kernel[grid](
            inputs_ptr,
            targets_ptr,
            grad_inputs_ptr,
            grad_output_ptr,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLoss.BLOCK_SIZE,
        )
        return grad_inputs.view(input_shape), None, None, None
|
||||
|
||||
|
||||
# Functional entry point: element-wise (unreduced) sigmoid focal loss.
triton_sigmoid_focal_loss = SigmoidFocalLoss.apply
|
||||
|
||||
|
||||
class SigmoidFocalLossReduced(torch.autograd.Function):
    """Autograd wrapper around the triton focal loss with built-in sum reduction.

    Returns a single float32 scalar (the sum of the per-element losses),
    avoiding materialization of the full loss tensor in the forward pass.
    """

    # Elements processed per triton program instance.
    BLOCK_SIZE = 256
    # Number of atomic-add accumulator slots (see the module docstring for why
    # multiple slots are used instead of a single location).
    REDUCE_SIZE = 32

    @staticmethod
    def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
        """inputs: logits; targets: float labels, same numel as inputs."""
        n_elements = inputs.numel()
        # Consistency with SigmoidFocalLoss.forward: fail fast on a
        # shape mismatch instead of silently reading masked-out garbage.
        assert targets.numel() == n_elements
        input_shape = inputs.shape
        inputs = inputs.view(-1).contiguous()
        targets = targets.view(-1).contiguous()
        # Partial sums; must be zero-initialized because the kernel uses atomic_add.
        loss = torch.zeros(
            SigmoidFocalLossReduced.REDUCE_SIZE,
            device=inputs.device,
            dtype=torch.float32,
        )
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_fwd_kernel_reduce[grid](
            inputs,
            targets,
            loss,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLossReduced.BLOCK_SIZE,
            SigmoidFocalLossReduced.REDUCE_SIZE,
        )
        # Save inputs/targets — backward recomputes everything else.
        ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
        ctx.alpha = alpha
        ctx.gamma = gamma
        # Final reduction over the REDUCE_SIZE partial sums.
        return loss.sum()

    @staticmethod
    def backward(ctx, grad_output):
        """Upstream gradient is a scalar (the forward output is a scalar)."""
        inputs, targets = ctx.saved_tensors
        alpha = ctx.alpha
        gamma = ctx.gamma
        n_elements = inputs.numel()
        input_shape = inputs.shape
        grad_inputs = torch.empty(
            inputs.shape, dtype=grad_output.dtype, device=grad_output.device
        )
        inputs_ptr = inputs.view(-1).contiguous()
        targets_ptr = targets.reshape(-1).contiguous()
        assert grad_output.numel() == 1
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sigmoid_focal_loss_bwd_kernel_reduce[grid](
            inputs_ptr,
            targets_ptr,
            grad_inputs,
            grad_output,
            alpha,
            gamma,
            n_elements,
            SigmoidFocalLossReduced.BLOCK_SIZE,
        )
        return grad_inputs.view(input_shape), None, None, None
|
||||
|
||||
|
||||
# Functional entry point: sum-reduced sigmoid focal loss (returns a scalar).
triton_sigmoid_focal_loss_reduce = SigmoidFocalLossReduced.apply
|
||||
272
sam3/train/masks_ops.py
Normal file
272
sam3/train/masks_ops.py
Normal file
@@ -0,0 +1,272 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
"""Utilities for masks manipulation"""
|
||||
|
||||
import numpy as np
|
||||
import pycocotools.mask as maskUtils
|
||||
import torch
|
||||
from pycocotools import mask as mask_util
|
||||
|
||||
|
||||
def instance_masks_to_semantic_masks(
    instance_masks: torch.Tensor, num_instances: torch.Tensor
) -> torch.Tensor:
    """Collapse per-instance masks into one semantic mask per batch element.

    The instance masks of the whole batch arrive concatenated along the first
    dimension; `num_instances[i]` says how many of them belong to batch
    element i. Each batch element's semantic mask is the pixel-wise union of
    its instance masks. A batch element with zero instances yields an
    all-False mask.

    Args:
        instance_masks (torch.Tensor): (N, H, W) concatenated instance masks,
            where N == num_instances.sum().
        num_instances (torch.Tensor): (B,) instance counts per batch element.

    Returns:
        torch.Tensor: (B, H, W) boolean semantic masks.
    """
    per_image = instance_masks.split(num_instances.tolist())
    # torch.any over an empty group returns all-False, covering the
    # zero-instance case for free.
    semantic = [group.any(dim=0) for group in per_image]
    return torch.stack(semantic, dim=0)
|
||||
|
||||
|
||||
def mask_intersection(masks1, masks2, block_size=16):
    """Pairwise intersection areas (pixel counts) between two sets of masks.

    Processes the N1 x N2 mask pairs in `block_size` x `block_size` tiles so
    the intermediate broadcast tensor of shape (tile, tile, H, W) stays small
    instead of materializing all pairs at once.

    Args:
        masks1: (N1, H, W) boolean masks.
        masks2: (N2, H, W) boolean masks.
        block_size (int): tile edge length.

    Returns:
        torch.Tensor: (N1, N2) long tensor of intersection pixel counts.
    """
    assert masks1.shape[1:] == masks2.shape[1:]
    assert masks1.dtype == torch.bool and masks2.dtype == torch.bool

    n1, n2 = masks1.shape[0], masks2.shape[0]
    out = torch.zeros(n1, n2, device=masks1.device, dtype=torch.long)
    for row in range(0, n1, block_size):
        rows = masks1[row : row + block_size]
        for col in range(0, n2, block_size):
            cols = masks2[col : col + block_size]
            # Broadcast to (tile1, tile2, H, W), then count overlapping pixels.
            tile = (rows[:, None] * cols[None]).flatten(-2).sum(-1)
            out[row : row + block_size, col : col + block_size] = tile
    return out
|
||||
|
||||
|
||||
def mask_iom(masks1, masks2):
    """Intersection-over-minimum between two sets of boolean masks.

    Like IoU, but normalized by the smaller of the two mask areas, so a mask
    fully contained inside another scores ~1 regardless of the size gap.

    Args:
        masks1: (N1, H, W) boolean masks.
        masks2: (N2, H, W) boolean masks.

    Returns:
        torch.Tensor: (N1, N2) IoM scores.
    """
    assert masks1.shape[1:] == masks2.shape[1:]
    assert masks1.dtype == torch.bool and masks2.dtype == torch.bool

    inter = mask_intersection(masks1, masks2)
    areas1 = masks1.flatten(-2).sum(-1)
    areas2 = masks2.flatten(-2).sum(-1)
    smaller = torch.min(areas1[:, None], areas2[None, :])
    # Epsilon guards against division by zero for empty masks.
    return inter / (smaller + 1e-8)
|
||||
|
||||
|
||||
def compute_boundary(seg):
|
||||
"""
|
||||
Adapted from https://github.com/JonathonLuiten/TrackEval/blob/master/trackeval/metrics/j_and_f.py#L148
|
||||
Return a 1pix wide boundary of the given mask
|
||||
"""
|
||||
assert seg.ndim >= 2
|
||||
e = torch.zeros_like(seg)
|
||||
s = torch.zeros_like(seg)
|
||||
se = torch.zeros_like(seg)
|
||||
|
||||
e[..., :, :-1] = seg[..., :, 1:]
|
||||
s[..., :-1, :] = seg[..., 1:, :]
|
||||
se[..., :-1, :-1] = seg[..., 1:, 1:]
|
||||
|
||||
b = seg ^ e | seg ^ s | seg ^ se
|
||||
b[..., -1, :] = seg[..., -1, :] ^ e[..., -1, :]
|
||||
b[..., :, -1] = seg[..., :, -1] ^ s[..., :, -1]
|
||||
b[..., -1, -1] = 0
|
||||
return b
|
||||
|
||||
|
||||
def dilation(mask, kernel_size):
    """
    Implements the dilation operation. If the input is on cpu, we call the cv2 version.
    Otherwise, we implement it using a convolution.

    The kernel is assumed to be a square kernel of ones, so on GPU the 2D
    dilation is computed as two separable 1D convolutions (vertical then
    horizontal), which is cheaper than one kernel_size x kernel_size conv.

    Args:
        mask (torch.Tensor): (N, H, W) masks to dilate.
        kernel_size (int): odd edge length of the square structuring element.

    Returns:
        torch.Tensor: dilated masks, same shape as `mask` (boolean on the
        GPU path; original dtype on the CPU path).
    """

    assert mask.ndim == 3
    kernel_size = int(kernel_size)
    assert (
        kernel_size % 2 == 1
    ), f"Dilation expects a odd kernel size, got {kernel_size}"

    if mask.is_cuda:
        # Add a channel dim and convolve with a column of ones, then a row of
        # ones; any pixel with a non-zero neighbourhood sum is set.
        # NOTE(review): fp16 accumulation is exact here only while the
        # neighbourhood sums stay within fp16's integer range — presumably
        # fine for realistic kernel sizes, but worth confirming.
        m = mask.unsqueeze(1).to(torch.float16)
        k = torch.ones(1, 1, kernel_size, 1, dtype=m.dtype, device=m.device)

        result = torch.nn.functional.conv2d(m, k, padding="same")
        result = torch.nn.functional.conv2d(result, k.transpose(-1, -2), padding="same")
        return result.view_as(mask) > 0

    all_masks = mask.view(-1, mask.size(-2), mask.size(-1)).numpy().astype(np.uint8)
    kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)

    # Imported lazily so cv2 is only required on the CPU code path.
    import cv2

    processed = [torch.from_numpy(cv2.dilate(m, kernel)) for m in all_masks]
    return torch.stack(processed).view_as(mask).to(mask)
|
||||
|
||||
|
||||
def compute_F_measure(
    gt_boundary_rle, gt_dilated_boundary_rle, dt_boundary_rle, dt_dilated_boundary_rle
):
    """Boundary F-measure between a ground-truth and a predicted mask.

    Adapted from https://github.com/JonathonLuiten/TrackEval/blob/master/trackeval/metrics/j_and_f.py#L207

    Assumes the boundaries and their dilated versions have already been
    computed and converted to RLE. A boundary pixel of one mask counts as
    matched when it falls inside the dilated boundary of the other mask.
    """
    # merge(..., True) intersects the two RLE masks.
    gt_match = maskUtils.merge([gt_boundary_rle, dt_dilated_boundary_rle], True)
    dt_match = maskUtils.merge([dt_boundary_rle, gt_dilated_boundary_rle], True)

    n_dt = maskUtils.area(dt_boundary_rle)
    n_gt = maskUtils.area(gt_boundary_rle)

    # Precision and recall, with the degenerate (empty-boundary) conventions
    # of the original TrackEval implementation.
    if n_dt > 0 and n_gt > 0:
        precision = maskUtils.area(dt_match) / float(n_dt)
        recall = maskUtils.area(gt_match) / float(n_gt)
    elif n_dt == 0 and n_gt == 0:
        # Both boundaries empty: perfect agreement by convention.
        precision, recall = 1, 1
    elif n_gt > 0:
        # Nothing predicted.
        precision, recall = 1, 0
    else:
        # Prediction has a boundary but the ground truth does not.
        precision, recall = 0, 1

    # Harmonic mean of precision and recall.
    denom = precision + recall
    return 0 if denom == 0 else 2 * precision * recall / denom
|
||||
|
||||
|
||||
@torch.no_grad()
def rle_encode(orig_mask, return_areas=False):
    """Encodes a collection of masks in RLE format

    This function emulates the behavior of the COCO API's encode function, but
    is executed partially on the GPU for faster execution.

    Args:
        orig_mask (torch.Tensor): A mask of shape (N, H, W) with dtype=torch.bool
        return_areas (bool): If True, add the areas of the masks as a part of
            the RLE output dict under the "area" key. Default is False.

    Returns:
        list[dict]: one COCO-style RLE dict per mask, each with string
        "counts" and "size" == [H, W] (an empty list for an empty input).
    """
    assert orig_mask.ndim == 3, "Mask must be of shape (N, H, W)"
    assert orig_mask.dtype == torch.bool, "Mask must have dtype=torch.bool"

    if orig_mask.numel() == 0:
        return []

    # First, transpose the spatial dimensions.
    # This is necessary because the COCO API uses Fortran order
    mask = orig_mask.transpose(1, 2)

    # Flatten the mask
    flat_mask = mask.reshape(mask.shape[0], -1)
    if return_areas:
        mask_areas = flat_mask.sum(-1).tolist()
    # Find the indices where the mask changes.
    # Column 0 is True iff the mask starts with a 1 (the COCO counts always
    # begin with the length of a zero-run, possibly 0); the final sentinel
    # column is always True so every row ends with a run boundary.
    differences = torch.ones(
        mask.shape[0], flat_mask.shape[1] + 1, device=mask.device, dtype=torch.bool
    )
    differences[:, 1:-1] = flat_mask[:, :-1] != flat_mask[:, 1:]
    differences[:, 0] = flat_mask[:, 0]
    # Flattened (row-major) positions of all change points across the batch.
    _, change_indices = torch.where(differences)

    try:
        # Cumulative change counts mark where each mask's changes end.
        boundaries = torch.cumsum(differences.sum(-1), 0).cpu()
    except RuntimeError as _:
        # Presumably a GPU OOM on very large inputs — fall back to a CPU
        # cumsum; TODO confirm this is the intended failure mode.
        boundaries = torch.cumsum(differences.cpu().sum(-1), 0)

    change_indices_clone = change_indices.clone()
    # First pass computes the RLEs on GPU, in a flatten format:
    # adjacent differences of change positions give run lengths.
    for i in range(mask.shape[0]):
        # Get the change indices for this batch item
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        change_indices[beg + 1 : end] -= change_indices_clone[beg : end - 1]

    # Now we can split the RLES of each batch item, and convert them to strings
    # No more gpu at this point
    change_indices = change_indices.tolist()

    batch_rles = []
    # Process each mask in the batch separately
    for i in range(mask.shape[0]):
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        run_lengths = change_indices[beg:end]

        # "size" uses the ORIGINAL (H, W); the transpose above only changed
        # the traversal order to match COCO's Fortran layout.
        uncompressed_rle = {"counts": run_lengths, "size": list(orig_mask.shape[1:])}
        h, w = uncompressed_rle["size"]
        # Compress the uncompressed counts with pycocotools.
        rle = mask_util.frPyObjects(uncompressed_rle, h, w)
        rle["counts"] = rle["counts"].decode("utf-8")
        if return_areas:
            rle["area"] = mask_areas[i]
        batch_rles.append(rle)

    return batch_rles
|
||||
|
||||
|
||||
def robust_rle_encode(masks):
    """Encode a batch of binary masks in RLE format.

    Tries the GPU implementation (`rle_encode`) first; if that raises a
    RuntimeError, falls back to the pycocotools CPU encoder.

    Args:
        masks: boolean tensor of shape (N, H, W).

    Returns:
        A list of N RLE dicts with utf-8 decoded "counts" strings.
    """
    assert masks.ndim == 3, "Mask must be of shape (N, H, W)"
    assert masks.dtype == torch.bool, "Mask must have dtype=torch.bool"

    try:
        return rle_encode(masks)
    except RuntimeError as _:
        # CPU fallback: pycocotools expects uint8 arrays in Fortran order.
        np_masks = masks.cpu().numpy()
        encoded = []
        for single_mask in np_masks:
            rle = mask_util.encode(
                np.array(single_mask[:, :, np.newaxis], dtype=np.uint8, order="F")
            )[0]
            # Decode bytes so the RLE is JSON-serializable.
            rle["counts"] = rle["counts"].decode("utf-8")
            encoded.append(rle)
        return encoded
|
||||
|
||||
|
||||
def ann_to_rle(segm, im_info):
    """Convert a segmentation annotation to compressed RLE.

    Args:
        segm: One of
            - a list of polygons (a single object may have several parts),
            - an uncompressed RLE dict (list under "counts"),
            - an already-compressed RLE dict.
        im_info: dict providing the image "height" and "width".

    Returns:
        A compressed RLE dict.
    """
    height, width = im_info["height"], im_info["width"]
    if isinstance(segm, list):
        # Polygon annotation: rasterize every part and merge into one RLE.
        part_rles = mask_util.frPyObjects(segm, height, width)
        return mask_util.merge(part_rles)
    if isinstance(segm["counts"], list):
        # Uncompressed RLE -> compressed RLE.
        return mask_util.frPyObjects(segm, height, width)
    # Already in compressed RLE form; pass through unchanged.
    return segm
|
||||
806
sam3/train/matcher.py
Normal file
806
sam3/train/matcher.py
Normal file
@@ -0,0 +1,806 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
"""
|
||||
Modules to compute the matching cost and solve the corresponding LSAP.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sam3.model.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
from torch import nn
|
||||
|
||||
|
||||
def _do_matching(cost, repeats=1, return_tgt_indices=False, do_filtering=False):
|
||||
if repeats > 1:
|
||||
cost = np.tile(cost, (1, repeats))
|
||||
|
||||
i, j = linear_sum_assignment(cost)
|
||||
if do_filtering:
|
||||
# filter out invalid entries (i.e. those with cost > 1e8)
|
||||
valid_thresh = 1e8
|
||||
valid_ijs = [(ii, jj) for ii, jj in zip(i, j) if cost[ii, jj] < valid_thresh]
|
||||
i, j = zip(*valid_ijs) if len(valid_ijs) > 0 else ([], [])
|
||||
i, j = np.array(i, dtype=np.int64), np.array(j, dtype=np.int64)
|
||||
if return_tgt_indices:
|
||||
return i, j
|
||||
order = np.argsort(j)
|
||||
return i[order]
|
||||
|
||||
|
||||
class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(
        self,
        cost_class: float = 1,
        cost_bbox: float = 1,
        cost_giou: float = 1,
        focal_loss: bool = False,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2,
    ):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
            focal_loss: If True, class scores come from a per-class sigmoid and the
                classification cost uses the focal-loss formulation; otherwise a
                softmax over classes is used.
            focal_alpha: alpha parameter of the focal cost (only used when focal_loss=True)
            focal_gamma: gamma parameter of the focal cost (only used when focal_loss=True)
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        # Sigmoid for the (multi-label) focal formulation, softmax otherwise.
        self.norm = nn.Sigmoid() if focal_loss else nn.Softmax(-1)
        assert (
            cost_class != 0 or cost_bbox != 0 or cost_giou != 0
        ), "all costs cant be 0"
        self.focal_loss = focal_loss
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

    @torch.no_grad()
    def forward(self, outputs, batched_targets):
        """Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

            batched_targets: dict of batched targets with at least these entries:
                 "boxes": Tensor of dim [total_num_target_boxes, 4] with target boxes
                     for all batch elements packed into a single tensor
                 "num_boxes": Tensor of dim [batch_size] with the number of target
                     boxes of each batch element
                 "labels": class id of each target box (used when no "positive_map")
                 "positive_map" (optional): per-box soft or multi-hot alignment over
                     classes/tokens; when present it replaces "labels"

        Returns:
            A tuple (batch_idx, src_idx) of equal-length 1-D long tensors:
                - batch_idx[k] is the batch element of the k-th matched pair
                - src_idx[k] is the index of the matched prediction; within each
                  batch element the pairs are ordered by target index, i.e. the
                  k-th pair of an element corresponds to its k-th target
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = self.norm(
            outputs["pred_logits"].flatten(0, 1)
        )  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_bbox = batched_targets["boxes"]

        if "positive_map" in batched_targets:
            # In this case we have a multi-hot target
            positive_map = batched_targets["positive_map"]
            assert len(tgt_bbox) == len(positive_map)

            if self.focal_loss:
                # Binarize the soft map; focal terms are then summed over the
                # positive entries of each target.
                positive_map = positive_map > 1e-4
                alpha = self.focal_alpha
                gamma = self.focal_gamma
                # The 1e-8 guards log() against probabilities of exactly 0 or 1.
                neg_cost_class = (
                    (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
                )
                pos_cost_class = (
                    alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
                )
                cost_class = (
                    (pos_cost_class - neg_cost_class).unsqueeze(1)
                    * positive_map.unsqueeze(0)
                ).sum(-1)
            else:
                # Compute the soft-cross entropy between the predicted token alignment and the GT one for each box
                cost_class = -(out_prob.unsqueeze(1) * positive_map.unsqueeze(0)).sum(
                    -1
                )
        else:
            # In this case we are doing a "standard" cross entropy
            tgt_ids = batched_targets["labels"]
            assert len(tgt_bbox) == len(tgt_ids)

            if self.focal_loss:
                alpha = self.focal_alpha
                gamma = self.focal_gamma
                # Same focal terms as above, indexed by the target class ids.
                neg_cost_class = (
                    (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
                )
                pos_cost_class = (
                    alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
                )
                cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
            else:
                # Compute the classification cost. Contrary to the loss, we don't use the NLL,
                # but approximate it in 1 - proba[target class].
                # The 1 is a constant that doesn't change the matching, it can be omitted.
                cost_class = -out_prob[:, tgt_ids]

        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
        assert cost_class.shape == cost_bbox.shape

        # Compute the giou cost between boxes
        cost_giou = -generalized_box_iou(
            box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
        )

        # Final cost matrix
        C = (
            self.cost_bbox * cost_bbox
            + self.cost_class * cost_class
            + self.cost_giou * cost_giou
        )
        C = C.view(bs, num_queries, -1).cpu().numpy()

        # Split the packed target axis back per batch element and solve one
        # assignment problem per element.
        sizes = torch.cumsum(batched_targets["num_boxes"], -1)[:-1]
        costs = [c[i] for i, c in enumerate(np.split(C, sizes.cpu().numpy(), axis=-1))]
        indices = [_do_matching(c) for c in costs]
        batch_idx = torch.as_tensor(
            sum([[i] * len(src) for i, src in enumerate(indices)], []), dtype=torch.long
        )
        src_idx = torch.from_numpy(np.concatenate(indices)).long()
        return batch_idx, src_idx
|
||||
|
||||
|
||||
class BinaryHungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).

    "Binary" variant: predictions carry a single objectness logit (num_classes == 1),
    so the classification cost is simply the negated sigmoid score.
    """

    def __init__(
        self,
        cost_class: float = 1,
        cost_bbox: float = 1,
        cost_giou: float = 1,
    ):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        # Single objectness score per query -> sigmoid.
        self.norm = nn.Sigmoid()
        assert (
            cost_class != 0 or cost_bbox != 0 or cost_giou != 0
        ), "all costs cant be 0"

    @torch.no_grad()
    def forward(self, outputs, batched_targets, repeats=0, repeat_batch=1):
        """Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, 1] with the objectness logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

            batched_targets: dict of batched targets with at least these entries:
                 "boxes": Tensor of dim [total_num_target_boxes, 4] with target boxes
                     for all batch elements packed into a single tensor
                 "num_boxes": Tensor of dim [batch_size] with the number of target
                     boxes of each batch element
            repeats: if > 1, each target may be matched to up to `repeats`
                predictions (one-to-many matching)
            repeat_batch: unsupported here; values != 1 raise NotImplementedError
                (use BinaryHungarianMatcherV2)

        Returns:
            A tuple (batch_idx, src_idx, tgt_idx) of 1-D long tensors:
                - batch_idx[k] is the batch element of the k-th matched pair
                - src_idx[k] is the index of the matched prediction
                - tgt_idx[k] is the (globally offset) index of the matched target,
                  or tgt_idx is None when every element had enough predictions to
                  cover all (possibly repeated) targets — in that case pairs are
                  ordered by target index within each element
        """
        if repeat_batch != 1:
            raise NotImplementedError("please use BinaryHungarianMatcherV2 instead")

        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = self.norm(outputs["pred_logits"].flatten(0, 1)).squeeze(
            -1
        )  # [batch_size * num_queries]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_bbox = batched_targets["boxes"]

        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Classification cost is target-independent: just the negated score.
        cost_class = -out_prob.unsqueeze(-1).expand_as(cost_bbox)

        assert cost_class.shape == cost_bbox.shape

        # Compute the giou cost between boxes
        cost_giou = -generalized_box_iou(
            box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
        )

        # Final cost matrix
        C = (
            self.cost_bbox * cost_bbox
            + self.cost_class * cost_class
            + self.cost_giou * cost_giou
        )
        C = C.view(bs, num_queries, -1).cpu().numpy()

        # Split the packed target axis back per batch element.
        sizes = torch.cumsum(batched_targets["num_boxes"], -1)[:-1]
        costs = [c[i] for i, c in enumerate(np.split(C, sizes.cpu().numpy(), axis=-1))]
        # If any element has fewer predictions than (repeated) targets, not every
        # target gets matched, so explicit target indices must be returned.
        return_tgt_indices = False
        for c in costs:
            n_targ = c.shape[1]
            if repeats > 1:
                n_targ *= repeats
            if c.shape[0] < n_targ:
                return_tgt_indices = True
                break
        if return_tgt_indices:
            indices, tgt_indices = zip(
                *(
                    _do_matching(
                        c, repeats=repeats, return_tgt_indices=return_tgt_indices
                    )
                    for c in costs
                )
            )
            tgt_indices = list(tgt_indices)
            # Offset per-element target indices into the packed "boxes" tensor.
            for i in range(1, len(tgt_indices)):
                tgt_indices[i] += sizes[i - 1].item()
            tgt_idx = torch.from_numpy(np.concatenate(tgt_indices)).long()
        else:
            indices = [_do_matching(c, repeats=repeats) for c in costs]
            tgt_idx = None

        batch_idx = torch.as_tensor(
            sum([[i] * len(src) for i, src in enumerate(indices)], []), dtype=torch.long
        )
        src_idx = torch.from_numpy(np.concatenate(indices)).long()
        return batch_idx, src_idx, tgt_idx
|
||||
|
||||
|
||||
class BinaryFocalHungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).

    Like BinaryHungarianMatcher, but the classification cost uses the focal-loss
    formulation of the single objectness score.
    """

    def __init__(
        self,
        cost_class: float = 1,
        cost_bbox: float = 1,
        cost_giou: float = 1,
        alpha: float = 0.25,
        gamma: float = 2.0,
        stable: bool = False,
    ):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
            alpha: alpha parameter of the focal cost
            gamma: gamma parameter of the focal cost
            stable: if True, the score is modulated by the (rescaled) GIoU with each
                target before the focal cost; if False, a numerically-stable
                log-sigmoid formulation is used instead
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.norm = nn.Sigmoid()
        self.alpha = alpha
        self.gamma = gamma
        self.stable = stable
        assert (
            cost_class != 0 or cost_bbox != 0 or cost_giou != 0
        ), "all costs cant be 0"

    @torch.no_grad()
    def forward(self, outputs, batched_targets, repeats=1, repeat_batch=1):
        """Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, 1] with the objectness logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

            batched_targets: dict of batched targets with at least these entries:
                 "boxes": Tensor of dim [total_num_target_boxes, 4] with target boxes
                     for all batch elements packed into a single tensor
                 "num_boxes": Tensor of dim [batch_size] with the number of target
                     boxes of each batch element
            repeats: if > 1, each target may be matched to up to `repeats`
                predictions (one-to-many matching)
            repeat_batch: unsupported here; values != 1 raise NotImplementedError
                (use BinaryHungarianMatcherV2)

        Returns:
            A tuple (batch_idx, src_idx, tgt_idx) of 1-D long tensors:
                - batch_idx[k] is the batch element of the k-th matched pair
                - src_idx[k] is the index of the matched prediction
                - tgt_idx[k] is the (globally offset) index of the matched target,
                  or tgt_idx is None when every element had enough predictions to
                  cover all (possibly repeated) targets — in that case pairs are
                  ordered by target index within each element
        """
        if repeat_batch != 1:
            raise NotImplementedError("please use BinaryHungarianMatcherV2 instead")

        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_score = outputs["pred_logits"].flatten(0, 1).squeeze(-1)
        out_prob = self.norm(out_score)  # [batch_size * num_queries]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_bbox = batched_targets["boxes"]

        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Compute the giou cost between boxes
        cost_giou = -generalized_box_iou(
            box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
        )

        # cost_class = -out_prob.unsqueeze(-1).expand_as(cost_bbox)
        if self.stable:
            # Modulate the score with the GIoU (mapped from [-1, 1] to [0, 1])
            # so the focal cost becomes target-dependent.
            rescaled_giou = (-cost_giou + 1) / 2
            out_prob = out_prob.unsqueeze(-1).expand_as(cost_bbox) * rescaled_giou
            cost_class = -self.alpha * (1 - out_prob) ** self.gamma * torch.log(
                out_prob
            ) + (1 - self.alpha) * out_prob**self.gamma * torch.log(1 - out_prob)
        else:
            # directly computing log sigmoid (more numerically stable)
            log_out_prob = torch.nn.functional.logsigmoid(out_score)
            log_one_minus_out_prob = torch.nn.functional.logsigmoid(-out_score)
            cost_class = (
                -self.alpha * (1 - out_prob) ** self.gamma * log_out_prob
                + (1 - self.alpha) * out_prob**self.gamma * log_one_minus_out_prob
            )
        if not self.stable:
            # Per-query cost is target-independent here; broadcast over targets.
            cost_class = cost_class.unsqueeze(-1).expand_as(cost_bbox)

        assert cost_class.shape == cost_bbox.shape

        # Final cost matrix
        C = (
            self.cost_bbox * cost_bbox
            + self.cost_class * cost_class
            + self.cost_giou * cost_giou
        )
        C = C.view(bs, num_queries, -1).cpu().numpy()

        # Split the packed target axis back per batch element.
        sizes = torch.cumsum(batched_targets["num_boxes"], -1)[:-1]
        costs = [c[i] for i, c in enumerate(np.split(C, sizes.cpu().numpy(), axis=-1))]
        # If any element has fewer predictions than (repeated) targets, explicit
        # target indices must be returned.
        return_tgt_indices = False
        for c in costs:
            n_targ = c.shape[1]
            if repeats > 1:
                n_targ *= repeats
            if c.shape[0] < n_targ:
                return_tgt_indices = True
                break
        if return_tgt_indices:
            indices, tgt_indices = zip(
                *(
                    _do_matching(
                        c, repeats=repeats, return_tgt_indices=return_tgt_indices
                    )
                    for c in costs
                )
            )
            tgt_indices = list(tgt_indices)
            # Offset per-element target indices into the packed "boxes" tensor.
            for i in range(1, len(tgt_indices)):
                tgt_indices[i] += sizes[i - 1].item()
            tgt_idx = torch.from_numpy(np.concatenate(tgt_indices)).long()
        else:
            indices = [_do_matching(c, repeats=repeats) for c in costs]
            tgt_idx = None

        batch_idx = torch.as_tensor(
            sum([[i] * len(src) for i, src in enumerate(indices)], []), dtype=torch.long
        )
        src_idx = torch.from_numpy(np.concatenate(indices)).long()
        return batch_idx, src_idx, tgt_idx
|
||||
|
||||
|
||||
class BinaryHungarianMatcherV2(nn.Module):
    """
    This class computes an assignment between the targets and the predictions
    of the network

    For efficiency reasons, the targets don't include the no_object. Because of
    this, in general, there are more predictions than targets. In this case, we
    do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).

    This is a more efficient implementation of BinaryHungarianMatcher.
    """

    def __init__(
        self,
        cost_class: float = 1,
        cost_bbox: float = 1,
        cost_giou: float = 1,
        focal: bool = False,
        alpha: float = 0.25,
        gamma: float = 2.0,
        stable: bool = False,
        remove_samples_with_0_gt: bool = True,
    ):
        """
        Creates the matcher

        Params:
            - cost_class: Relative weight of the classification error in the
              matching cost
            - cost_bbox: Relative weight of the L1 error of the bounding box
              coordinates in the matching cost
            - cost_giou: This is the relative weight of the giou loss of the
              bounding box in the matching cost
            - focal: If True, use the focal-loss classification cost (with
              alpha/gamma/stable below); otherwise the cost is the negated
              sigmoid score.
            - alpha, gamma: focal cost parameters (only used when focal=True)
            - stable: focal variant that modulates the score with the GIoU
              (only used when focal=True)
            - remove_samples_with_0_gt: If True, batch elements with zero GT
              boxes are dropped before matching (they cannot contribute pairs).
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.norm = nn.Sigmoid()
        assert (
            cost_class != 0 or cost_bbox != 0 or cost_giou != 0
        ), "all costs cant be 0"
        self.focal = focal
        if focal:
            # alpha/gamma/stable only exist in focal mode; non-focal forward
            # never reads them.
            self.alpha = alpha
            self.gamma = gamma
            self.stable = stable
        self.remove_samples_with_0_gt = remove_samples_with_0_gt

    @torch.no_grad()
    def forward(
        self,
        outputs,
        batched_targets,
        repeats=1,
        repeat_batch=1,
        out_is_valid=None,
        target_is_valid_padded=None,
    ):
        """
        Performs the matching. The inputs and outputs are the same as
        BinaryHungarianMatcher.forward, except for the optional validity masks
        and the padded "boxes_padded" entry of batched_targets.

        Inputs:
            - outputs: A dict with the following keys:
              - "pred_logits": Tensor of shape (batch_size, num_queries, 1) with
                classification logits
              - "pred_boxes": Tensor of shape (batch_size, num_queries, 4) with
                predicted box coordinates in cxcywh format.
            - batched_targets: A dict of targets. There may be a variable number of
              targets per batch entry; suppose that there are T_b targets for batch
              entry 0 <= b < batch_size. It should have the following keys:
              - "num_boxes": int64 Tensor of shape (batch_size,) giving the number
                of ground-truth boxes per batch entry: num_boxes[b] = T_b
              - "boxes_padded": Tensor of shape (batch_size, max_b T_b, 4) giving
                a padded version of ground-truth boxes in cxcywh format (as
                precomputed in the collator).
            - repeats: if > 1, each target may be matched to up to `repeats`
              predictions (one-to-many matching)
            - repeat_batch: if > 1, outputs are a concatenation of final and
              auxiliary predictions along the batch axis; targets are repeated
              to match.
            - out_is_valid: If not None, it should be a boolean tensor of shape
              (batch_size, num_queries) indicating which predictions are valid.
              Invalid predictions are ignored during matching and won't appear in
              the output indices.
            - target_is_valid_padded: If not None, it should be a boolean tensor of
              shape (batch_size, max_num_gt_boxes) in padded format indicating
              which GT boxes are valid. Invalid GT boxes are ignored during matching
              and won't appear in the output indices.

        Returns:
            A tuple (batch_idx, src_idx, tgt_idx) of 1-D long tensors on the
            input device:
                - batch_idx[k] is the (original, pre-filtering) batch element of
                  the k-th matched pair
                - src_idx[k] is the index of the matched prediction
                - tgt_idx[k] is the (globally offset) index of the matched target,
                  or tgt_idx is None when target indices were not needed — in that
                  case pairs are ordered by target index within each element
        """
        _, num_queries = outputs["pred_logits"].shape[:2]

        out_score = outputs["pred_logits"].squeeze(-1)  # (B, Q)
        out_bbox = outputs["pred_boxes"]  # (B, Q, 4)

        device = out_score.device

        num_boxes = batched_targets["num_boxes"].cpu()
        # Get a padded version of target boxes (as precomputed in the collator).
        # It should work for both repeat==1 (o2o) and repeat>1 (o2m) matching.
        tgt_bbox = batched_targets["boxes_padded"]
        if self.remove_samples_with_0_gt:
            # keep only samples w/ at least 1 GT box in targets (num_boxes and tgt_bbox)
            batch_keep = num_boxes > 0
            num_boxes = num_boxes[batch_keep]
            tgt_bbox = tgt_bbox[batch_keep]
            if target_is_valid_padded is not None:
                target_is_valid_padded = target_is_valid_padded[batch_keep]
        # Repeat the targets (for the case of batched aux outputs in the matcher)
        if repeat_batch > 1:
            # In this case, out_prob and out_bbox will be a concatenation of
            # both final and auxiliary outputs, so we also repeat the targets
            num_boxes = num_boxes.repeat(repeat_batch)
            tgt_bbox = tgt_bbox.repeat(repeat_batch, 1, 1)
            if target_is_valid_padded is not None:
                target_is_valid_padded = target_is_valid_padded.repeat(repeat_batch, 1)

        # keep only samples w/ at least 1 GT box in outputs
        if self.remove_samples_with_0_gt:
            if repeat_batch > 1:
                batch_keep = batch_keep.repeat(repeat_batch)
            out_score = out_score[batch_keep]
            out_bbox = out_bbox[batch_keep]
            if out_is_valid is not None:
                out_is_valid = out_is_valid[batch_keep]
        assert out_bbox.shape[0] == tgt_bbox.shape[0]
        assert out_bbox.shape[0] == num_boxes.shape[0]

        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Compute the giou cost between boxes
        cost_giou = -generalized_box_iou(
            box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
        )

        out_prob = self.norm(out_score)
        if not self.focal:
            # Classification cost is target-independent: negated sigmoid score.
            cost_class = -out_prob.unsqueeze(-1).expand_as(cost_bbox)
        else:
            if self.stable:
                # Modulate the score with the GIoU (mapped from [-1, 1] to
                # [0, 1]) so the focal cost becomes target-dependent.
                rescaled_giou = (-cost_giou + 1) / 2
                out_prob = out_prob.unsqueeze(-1).expand_as(cost_bbox) * rescaled_giou
                cost_class = -self.alpha * (1 - out_prob) ** self.gamma * torch.log(
                    out_prob
                ) + (1 - self.alpha) * out_prob**self.gamma * torch.log(1 - out_prob)
            else:
                # directly computing log sigmoid (more numerically stable)
                log_out_prob = torch.nn.functional.logsigmoid(out_score)
                log_one_minus_out_prob = torch.nn.functional.logsigmoid(-out_score)
                cost_class = (
                    -self.alpha * (1 - out_prob) ** self.gamma * log_out_prob
                    + (1 - self.alpha) * out_prob**self.gamma * log_one_minus_out_prob
                )
            if not self.stable:
                # Per-query cost is target-independent; broadcast over targets.
                cost_class = cost_class.unsqueeze(-1).expand_as(cost_bbox)

        assert cost_class.shape == cost_bbox.shape

        # Final cost matrix
        C = (
            self.cost_bbox * cost_bbox
            + self.cost_class * cost_class
            + self.cost_giou * cost_giou
        )
        # assign a very high cost (1e9) to invalid outputs and targets, so that we can
        # filter them out (in `_do_matching`) from bipartite matching results
        do_filtering = out_is_valid is not None or target_is_valid_padded is not None
        if out_is_valid is not None:
            C = torch.where(out_is_valid[:, :, None], C, 1e9)
        if target_is_valid_padded is not None:
            C = torch.where(target_is_valid_padded[:, None, :], C, 1e9)
        C = C.cpu().numpy()
        # Crop each element's cost matrix to its true (unpadded) target count.
        costs = [C[i, :, :s] for i, s in enumerate(num_boxes.tolist())]
        return_tgt_indices = (
            do_filtering or torch.any(num_queries < num_boxes * max(repeats, 1)).item()
        )
        if len(costs) == 0:
            # We have size 0 in the batch dimension, so we return empty matching indices
            # (note that this can happen due to `remove_samples_with_0_gt=True` even if
            # the original input batch size is not 0, when all queries have empty GTs).
            indices = []
            tgt_idx = torch.zeros(0).long().to(device) if return_tgt_indices else None
        elif return_tgt_indices:
            indices, tgt_indices = zip(
                *(
                    _do_matching(
                        c,
                        repeats=repeats,
                        return_tgt_indices=return_tgt_indices,
                        do_filtering=do_filtering,
                    )
                    for c in costs
                )
            )
            tgt_indices = list(tgt_indices)
            # Offset per-element target indices into the packed target ordering.
            sizes = torch.cumsum(num_boxes, -1)[:-1]
            for i in range(1, len(tgt_indices)):
                tgt_indices[i] += sizes[i - 1].item()
            tgt_idx = torch.from_numpy(np.concatenate(tgt_indices)).long().to(device)
        else:
            indices = [
                _do_matching(c, repeats=repeats, do_filtering=do_filtering)
                for c in costs
            ]
            tgt_idx = None

        if self.remove_samples_with_0_gt:
            # Map positions in the filtered batch back to original batch indices.
            kept_inds = batch_keep.nonzero().squeeze(1)
            batch_idx = torch.as_tensor(
                sum([[kept_inds[i]] * len(src) for i, src in enumerate(indices)], []),
                dtype=torch.long,
                device=device,
            )
        else:
            batch_idx = torch.as_tensor(
                sum([[i] * len(src) for i, src in enumerate(indices)], []),
                dtype=torch.long,
                device=device,
            )

        # indices could be an empty list (since we remove samples w/ 0 GT boxes)
        if len(indices) > 0:
            src_idx = torch.from_numpy(np.concatenate(indices)).long().to(device)
        else:
            src_idx = torch.empty(0, dtype=torch.long, device=device)
        return batch_idx, src_idx, tgt_idx
|
||||
|
||||
|
||||
class BinaryOneToManyMatcher(nn.Module):
    """Greedy one-to-many matcher between predictions and ground-truth targets.

    Unlike a Hungarian matcher, several predictions may be assigned to the same
    target, while each prediction is assigned to at most one target.

    See DAC-DETR for details.
    """

    def __init__(
        self,
        alpha: float = 0.3,
        threshold: float = 0.4,
        topk: int = 6,
    ):
        """Creates the matcher.

        Params:
            alpha: relative balancing between classification and localization
            threshold: threshold used to select positive predictions
            topk: number of top scoring predictions to consider
        """
        super().__init__()
        self.norm = nn.Sigmoid()
        self.alpha = alpha
        self.threshold = threshold
        self.topk = topk

    @torch.no_grad()
    def forward(
        self,
        outputs,
        batched_targets,
        repeats=1,
        repeat_batch=1,
        out_is_valid=None,
        target_is_valid_padded=None,
    ):
        """Performs the greedy matching.

        Inputs:
            - outputs: dict with keys:
                - "pred_logits": (batch_size, num_queries, 1) classification logits
                - "pred_boxes": (batch_size, num_queries, 4) boxes in cxcywh format
            - batched_targets: dict with keys:
                - "num_boxes": int64 tensor (batch_size,), number of GT boxes per entry
                - "boxes_padded": (batch_size, max_num_gt_boxes, 4) padded GT boxes
            - out_is_valid: optional bool tensor (batch_size, num_queries); invalid
              predictions are ignored and never appear in the output indices.
            - target_is_valid_padded: optional bool tensor
              (batch_size, max_num_gt_boxes); invalid GT boxes are ignored.

        Returns:
            Tuple (batch_idx, src_idx, tgt_idx) of long tensors of equal length,
            where src_idx indexes predictions and tgt_idx indexes the flattened
            (across the batch) list of ground-truth boxes.
        """
        # Query/batch repetition is not supported by this matcher.
        assert repeats <= 1 and repeat_batch <= 1
        bs, num_queries = outputs["pred_logits"].shape[:2]

        pred_scores = self.norm(outputs["pred_logits"]).squeeze(-1)  # (B, Q)
        pred_boxes = outputs["pred_boxes"]  # (B, Q, 4)

        num_boxes = batched_targets["num_boxes"]

        # Padded GT boxes, precomputed in the collator.
        gt_boxes = batched_targets["boxes_padded"]
        assert len(gt_boxes) == bs
        num_targets = gt_boxes.shape[1]
        if num_targets == 0:
            device = pred_scores.device
            return (
                torch.empty(0, dtype=torch.long, device=device),
                torch.empty(0, dtype=torch.long, device=device),
                torch.empty(0, dtype=torch.long, device=device),
            )

        iou, _ = box_iou(box_cxcywh_to_xyxy(pred_boxes), box_cxcywh_to_xyxy(gt_boxes))
        assert iou.shape == (bs, num_queries, num_targets)

        # Combined matching score; higher is better (unlike Hungarian cost).
        score = self.alpha * pred_scores.unsqueeze(-1) + (1 - self.alpha) * iou
        if out_is_valid is not None:
            score = torch.where(out_is_valid[:, :, None], score, -1e9)
        if target_is_valid_padded is not None:
            score = torch.where(target_is_valid_padded[:, None, :], score, -1e9)

        # Keep only the top-k scoring predictions per target...
        cutoff = torch.quantile(score, 1 - self.topk / num_queries, dim=1, keepdim=True)
        matches = score > cutoff

        # ...that also clear the absolute score threshold.
        matches = matches & (score > self.threshold)
        if out_is_valid is not None:
            matches = matches & out_is_valid[:, :, None]
        if target_is_valid_padded is not None:
            matches = matches & target_is_valid_padded[:, None, :]

        # Drop matches that fall into each sample's padding region.
        not_padding = (
            torch.arange(0, num_targets, device=num_boxes.device)[None]
            < num_boxes[:, None]
        )
        matches = matches & not_padding.unsqueeze(1)

        batch_idx, src_idx, tgt_idx = torch.nonzero(matches, as_tuple=True)

        # Offset per-sample target indices into the flattened GT list.
        cum_num_boxes = torch.cat(
            [
                torch.zeros(1, dtype=num_boxes.dtype, device=num_boxes.device),
                num_boxes.cumsum(-1)[:-1],
            ]
        )
        tgt_idx += cum_num_boxes[batch_idx]

        return batch_idx, src_idx, tgt_idx
|
||||
306
sam3/train/nms_helper.py
Normal file
306
sam3/train/nms_helper.py
Normal file
@@ -0,0 +1,306 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
import warnings
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Check if Numba is available
|
||||
HAS_NUMBA = False
|
||||
try:
|
||||
import numba as nb
|
||||
|
||||
HAS_NUMBA = True
|
||||
except ImportError:
|
||||
warnings.warn(
|
||||
"Numba not found. Using slower pure Python implementations.", UserWarning
|
||||
)
|
||||
|
||||
|
||||
# -------------------- Helper Functions --------------------
|
||||
def is_zero_box(bbox: list) -> bool:
    """Return True when *bbox* is missing, too short, or degenerate (all coords <= 0)."""
    if bbox is None:
        return True
    if len(bbox) < 4:
        return True
    return all(coord <= 0 for coord in bbox[:4])
|
||||
|
||||
|
||||
def convert_bbox_format(bbox: list) -> List[float]:
    """Convert a box from (x, y, w, h) to corner format (x1, y1, x2, y2)."""
    x1, y1, width, height = bbox
    return [x1, y1, x1 + width, y1 + height]
|
||||
|
||||
|
||||
# -------------------- Track-level NMS --------------------
|
||||
def process_track_level_nms(video_groups: Dict, nms_threshold: float) -> Dict:
    """Apply track-level NMS to every video in *video_groups*, in place.

    Each value of video_groups is a list of tracks; a track is a dict with a
    per-frame "bboxes" list (x, y, w, h or None) and a scalar "score".
    Suppressed tracks have all their bboxes replaced with None.

    Args:
        video_groups: mapping video_id -> list of track dicts.
        nms_threshold: IoU threshold at which a lower-scoring track is suppressed.

    Returns:
        The same video_groups mapping (mutated in place).
    """
    # Fix vs original: iterate values() (the video id was unused), avoid
    # shadowing the outer `track` variable in the suppression loop, and use a
    # set for the O(1) `kept` membership test instead of a list.
    for tracks in video_groups.values():
        track_detections = []

        # Collect tracks that have at least one valid box; invalid frames are
        # marked with NaNs so downstream IoU code can mask them out.
        for track_idx, track in enumerate(tracks):
            if not track["bboxes"]:
                continue

            converted_bboxes = []
            valid_frames = []
            for bbox in track["bboxes"]:
                if bbox and not is_zero_box(bbox):
                    converted_bboxes.append(convert_bbox_format(bbox))
                    valid_frames.append(True)
                else:
                    converted_bboxes.append([np.nan] * 4)
                    valid_frames.append(False)

            if any(valid_frames):
                track_detections.append(
                    {
                        "track_idx": track_idx,
                        "bboxes": np.array(converted_bboxes, dtype=np.float32),
                        "score": track["score"],
                    }
                )

        # Apply NMS and blank out the suppressed tracks.
        if track_detections:
            scores = np.array([d["score"] for d in track_detections], dtype=np.float32)
            kept = set(apply_track_nms(track_detections, scores, nms_threshold))
            for idx, det in enumerate(track_detections):
                if idx not in kept:
                    tracks[det["track_idx"]]["bboxes"] = [None] * len(det["bboxes"])

    return video_groups
|
||||
|
||||
|
||||
# -------------------- Frame-level NMS --------------------
|
||||
def process_frame_level_nms(video_groups: Dict, nms_threshold: float) -> Dict:
    """Apply per-frame NMS across tracks for every video, in place.

    For each frame, the valid boxes of all tracks compete under NMS; a
    suppressed detection has its box for that frame set to None.

    Args:
        video_groups: mapping video_id -> list of track dicts (each with a
            per-frame "bboxes" list and a scalar "score").
        nms_threshold: IoU threshold at which a lower-scoring box is suppressed.

    Returns:
        The same video_groups mapping (mutated in place).
    """
    # Fix vs original: iterate values() (the video id was unused) and use a
    # set for the O(1) `kept` membership test instead of a list.
    for tracks in video_groups.values():
        if not tracks:
            continue

        # NOTE(review): assumes every track has the same number of frames as
        # tracks[0] — confirm against the producer of video_groups.
        num_frames = len(tracks[0]["bboxes"])

        for frame_idx in range(num_frames):
            frame_detections = []

            # Collect the valid detections of this frame across all tracks.
            for track_idx, track in enumerate(tracks):
                bbox = track["bboxes"][frame_idx]
                if bbox and not is_zero_box(bbox):
                    frame_detections.append(
                        {
                            "track_idx": track_idx,
                            "bbox": np.array(
                                convert_bbox_format(bbox), dtype=np.float32
                            ),
                            "score": track["score"],
                        }
                    )

            # Apply NMS and drop the suppressed detections for this frame.
            if frame_detections:
                bboxes = np.stack([d["bbox"] for d in frame_detections])
                scores = np.array(
                    [d["score"] for d in frame_detections], dtype=np.float32
                )
                kept = set(apply_frame_nms(bboxes, scores, nms_threshold))
                for i, d in enumerate(frame_detections):
                    if i not in kept:
                        tracks[d["track_idx"]]["bboxes"][frame_idx] = None

    return video_groups
|
||||
|
||||
|
||||
# Track-level NMS helpers ------------------------------------------------------
|
||||
def compute_track_iou_matrix(
    bboxes_stacked: np.ndarray, valid_masks: np.ndarray, areas: np.ndarray
) -> np.ndarray:
    """IoU matrix computation for track-level NMS with fallback to pure Python.

    Track-level IoU of two tracks is defined as (sum of per-frame
    intersections) / (sum of per-frame unions), accumulated only over
    frames where both tracks have a valid box.

    Args:
        bboxes_stacked: (num_tracks, num_frames, 4) boxes in corner
            (x1, y1, x2, y2) format; invalid frames hold NaNs.
        valid_masks: (num_tracks, num_frames) bool, True where the box is valid.
        areas: (num_tracks, num_frames) precomputed box areas (0 for invalid).

    Returns:
        (num_tracks, num_tracks) symmetric float32 IoU matrix (zero diagonal).
    """
    num_tracks = bboxes_stacked.shape[0]
    iou_matrix = np.zeros((num_tracks, num_tracks), dtype=np.float32)
    if HAS_NUMBA:
        # Fast path: JIT-compiled, parallel implementation (same math).
        iou_matrix = _compute_track_iou_matrix_numba(bboxes_stacked, valid_masks, areas)
    else:
        # Pure Python implementation
        for i in range(num_tracks):
            # Only the upper triangle is computed; the result is mirrored below.
            for j in range(i + 1, num_tracks):
                valid_ij = valid_masks[i] & valid_masks[j]
                if not valid_ij.any():
                    continue
                bboxes_i = bboxes_stacked[i, valid_ij]
                bboxes_j = bboxes_stacked[j, valid_ij]
                area_i = areas[i, valid_ij]
                area_j = areas[j, valid_ij]
                inter_total = 0.0
                union_total = 0.0
                for k in range(bboxes_i.shape[0]):
                    # Intersection rectangle of the k-th co-valid frame.
                    x1 = max(bboxes_i[k, 0], bboxes_j[k, 0])
                    y1 = max(bboxes_i[k, 1], bboxes_j[k, 1])
                    x2 = min(bboxes_i[k, 2], bboxes_j[k, 2])
                    y2 = min(bboxes_i[k, 3], bboxes_j[k, 3])
                    inter = max(0, x2 - x1) * max(0, y2 - y1)
                    union = area_i[k] + area_j[k] - inter
                    inter_total += inter
                    union_total += union
                if union_total > 0:
                    iou_matrix[i, j] = inter_total / union_total
                    iou_matrix[j, i] = iou_matrix[i, j]
    return iou_matrix
|
||||
|
||||
|
||||
if HAS_NUMBA:

    @nb.jit(nopython=True, parallel=True)
    def _compute_track_iou_matrix_numba(bboxes_stacked, valid_masks, areas):
        """Numba-optimized IoU matrix computation for track-level NMS.

        Mirrors the pure-Python fallback in compute_track_iou_matrix exactly;
        keep the two implementations in sync. The outer loop runs in parallel
        via nb.prange; each (i, j) cell is written by exactly one iteration,
        so there are no write conflicts.
        """
        num_tracks = bboxes_stacked.shape[0]
        iou_matrix = np.zeros((num_tracks, num_tracks), dtype=np.float32)
        for i in nb.prange(num_tracks):
            # Upper triangle only; mirrored into the lower triangle below.
            for j in range(i + 1, num_tracks):
                valid_ij = valid_masks[i] & valid_masks[j]
                if not valid_ij.any():
                    continue
                bboxes_i = bboxes_stacked[i, valid_ij]
                bboxes_j = bboxes_stacked[j, valid_ij]
                area_i = areas[i, valid_ij]
                area_j = areas[j, valid_ij]
                inter_total = 0.0
                union_total = 0.0
                for k in range(bboxes_i.shape[0]):
                    # Intersection rectangle of the k-th co-valid frame.
                    x1 = max(bboxes_i[k, 0], bboxes_j[k, 0])
                    y1 = max(bboxes_i[k, 1], bboxes_j[k, 1])
                    x2 = min(bboxes_i[k, 2], bboxes_j[k, 2])
                    y2 = min(bboxes_i[k, 3], bboxes_j[k, 3])
                    inter = max(0, x2 - x1) * max(0, y2 - y1)
                    union = area_i[k] + area_j[k] - inter
                    inter_total += inter
                    union_total += union
                if union_total > 0:
                    iou_matrix[i, j] = inter_total / union_total
                    iou_matrix[j, i] = iou_matrix[i, j]
        return iou_matrix
|
||||
|
||||
|
||||
def apply_track_nms(
    track_detections: List[dict], scores: np.ndarray, nms_threshold: float
) -> List[int]:
    """Greedy track-level NMS.

    Tracks are visited in descending score order; each kept track suppresses
    every not-yet-visited track whose track-level IoU with it reaches
    nms_threshold.

    Returns:
        Indices (into track_detections) of the tracks that survive.
    """
    if not track_detections:
        return []

    stacked = np.stack([d["bboxes"] for d in track_detections], axis=0)
    # A frame is valid only if none of its 4 coordinates is NaN.
    valid = ~np.isnan(stacked).any(axis=2)
    areas = (stacked[:, :, 2] - stacked[:, :, 0]) * (
        stacked[:, :, 3] - stacked[:, :, 1]
    )
    areas[~valid] = 0
    iou_matrix = compute_track_iou_matrix(stacked, valid, areas)

    order = np.argsort(-scores)
    suppressed = np.zeros(len(track_detections), dtype=bool)
    keep = []
    for rank in range(len(order)):
        idx = order[rank]
        if suppressed[idx]:
            continue
        keep.append(idx)
        # Suppress all remaining candidates that overlap too much with idx
        # (idx itself stays kept: its diagonal IoU is 0).
        tail = order[rank:]
        suppressed[tail] = suppressed[tail] | (
            iou_matrix[idx, tail] >= nms_threshold
        )
    return keep
|
||||
|
||||
|
||||
# Frame-level NMS helpers ------------------------------------------------------
|
||||
def compute_frame_ious(bbox: np.ndarray, bboxes: np.ndarray) -> np.ndarray:
    """IoU of one box against many, with a pure-Python fallback.

    Args:
        bbox: single box (x1, y1, x2, y2).
        bboxes: (N, 4) boxes in the same corner format.

    Returns:
        float32 array of N IoU values.
    """
    if HAS_NUMBA:
        return _compute_frame_ious_numba(bbox, bboxes)

    # Pure Python fallback (kept in sync with the numba implementation).
    ious = np.zeros(len(bboxes), dtype=np.float32)
    ref_area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
    for idx in range(len(bboxes)):
        left = max(bbox[0], bboxes[idx, 0])
        top = max(bbox[1], bboxes[idx, 1])
        right = min(bbox[2], bboxes[idx, 2])
        bottom = min(bbox[3], bboxes[idx, 3])

        inter = max(0, right - left) * max(0, bottom - top)
        other_area = (bboxes[idx, 2] - bboxes[idx, 0]) * (
            bboxes[idx, 3] - bboxes[idx, 1]
        )
        union = ref_area + other_area - inter

        ious[idx] = inter / union if union > 0 else 0.0
    return ious
|
||||
|
||||
|
||||
if HAS_NUMBA:

    @nb.jit(nopython=True, parallel=True)
    def _compute_frame_ious_numba(bbox, bboxes):
        """Numba-optimized IoU computation of one box against many.

        Mirrors the pure-Python fallback in compute_frame_ious; keep the two
        implementations in sync. Boxes are in corner (x1, y1, x2, y2) format.
        Each prange iteration writes a distinct element, so the parallel loop
        is race-free.
        """
        ious = np.zeros(len(bboxes), dtype=np.float32)
        for i in nb.prange(len(bboxes)):
            # Intersection rectangle of bbox and bboxes[i].
            x1 = max(bbox[0], bboxes[i, 0])
            y1 = max(bbox[1], bboxes[i, 1])
            x2 = min(bbox[2], bboxes[i, 2])
            y2 = min(bbox[3], bboxes[i, 3])

            inter = max(0, x2 - x1) * max(0, y2 - y1)
            area1 = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            area2 = (bboxes[i, 2] - bboxes[i, 0]) * (bboxes[i, 3] - bboxes[i, 1])
            union = area1 + area2 - inter

            ious[i] = inter / union if union > 0 else 0.0
        return ious
|
||||
|
||||
|
||||
def apply_frame_nms(
    bboxes: np.ndarray, scores: np.ndarray, nms_threshold: float
) -> List[int]:
    """Frame-level NMS implementation with fallback to pure Python.

    Boxes are visited in descending score order; each kept box suppresses all
    remaining boxes whose IoU with it reaches nms_threshold.

    Returns:
        Indices (into bboxes) of the detections that survive.
    """
    if HAS_NUMBA:
        return _apply_frame_nms_numba(bboxes, scores, nms_threshold)

    # Pure Python fallback.
    order = np.argsort(-scores)
    suppressed = np.zeros(len(bboxes), dtype=bool)
    keep = []

    for rank in range(len(order)):
        det_idx = order[rank]
        if suppressed[det_idx]:
            continue
        keep.append(det_idx)

        rest = order[rank + 1 :]
        if len(rest) > 0:  # anything left to compare against?
            ious = compute_frame_ious(bboxes[det_idx], bboxes[rest])
            suppressed[rest] = suppressed[rest] | (ious >= nms_threshold)

    return keep
|
||||
|
||||
|
||||
if HAS_NUMBA:

    @nb.jit(nopython=True)
    def _apply_frame_nms_numba(bboxes, scores, nms_threshold):
        """Numba-optimized greedy NMS (descending score order).

        Mirrors the pure-Python fallback in apply_frame_nms; keep the two in
        sync. NOTE(review): `keep` is a reflected Python list inside nopython
        mode, which numba deprecates in newer releases — confirm against the
        pinned numba version.
        """
        order = np.argsort(-scores)
        keep = []
        suppress = np.zeros(len(bboxes), dtype=nb.boolean)

        for i in range(len(order)):
            if not suppress[order[i]]:
                keep.append(order[i])
                current_bbox = bboxes[order[i]]

                if i + 1 < len(order):  # Check bounds
                    # Suppress every remaining box overlapping too much.
                    ious = _compute_frame_ious_numba(
                        current_bbox, bboxes[order[i + 1 :]]
                    )
                    suppress[order[i + 1 :]] = suppress[order[i + 1 :]] | (
                        ious >= nms_threshold
                    )

        return keep
|
||||
1
sam3/train/optim/__init__.py
Normal file
1
sam3/train/optim/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
498
sam3/train/optim/optimizer.py
Normal file
498
sam3/train/optim/optimizer.py
Normal file
@@ -0,0 +1,498 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import fnmatch
|
||||
import inspect
|
||||
import itertools
|
||||
import logging
|
||||
import types
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import hydra
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import DictConfig
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class Optimizer:
    """Wrapper around a torch optimizer that drives per-param-group option
    schedulers (e.g. "lr", "weight_decay") as training progresses.

    schedulers[i] is a dict mapping an optimizer option name to a scheduler
    callable for param_groups[i].
    """

    def __init__(self, optimizer, schedulers=None) -> None:
        self.optimizer = optimizer
        self.schedulers = schedulers
        self._validate_optimizer_schedulers()
        # Initialize every scheduled option to its value at the very start.
        self.step_schedulers(0.0, 0)

    def _validate_optimizer_schedulers(self):
        """Assert every scheduled option is one the optimizer understands."""
        if self.schedulers is None:
            return
        for set_of_schedulers in self.schedulers:
            for option in set_of_schedulers:
                assert option in self.optimizer.defaults, (
                    "Optimizer option "
                    f"{option} not found in {self.optimizer}. Valid options are "
                    f"{self.optimizer.defaults.keys()}"
                )

    @staticmethod
    def _evaluate_scheduler(scheduler, where: float, step: int):
        """Call a scheduler, passing `step` only when its signature takes it.

        ValueScaler-style wrappers expose the wrapped callable as
        `.scheduler`; its signature decides whether `step` is forwarded.
        """
        if "step" in inspect.signature(scheduler.__call__).parameters:
            return scheduler(step=step, where=where)
        if (
            hasattr(scheduler, "scheduler")
            and "step" in inspect.signature(scheduler.scheduler.__call__).parameters
        ):
            return scheduler(step=step, where=where)
        return scheduler(where)

    def step_schedulers(self, where: float, step: int) -> None:
        """Refresh every scheduled option of every param group.

        Args:
            where: training progress in [0, 1].
            step: current optimizer step count.
        """
        if self.schedulers is None:
            return
        for group_idx, param_group in enumerate(self.optimizer.param_groups):
            for option, scheduler in self.schedulers[group_idx].items():
                param_group[option] = self._evaluate_scheduler(scheduler, where, step)

    def step(self, where, step, closure=None):
        """Advance the schedulers, then take an optimizer step."""
        self.step_schedulers(where, step)
        return self.optimizer.step(closure)

    def zero_grad(self, *args, **kwargs):
        """Delegate to the wrapped optimizer."""
        return self.optimizer.zero_grad(*args, **kwargs)
|
||||
|
||||
|
||||
def set_default_parameters(
    scheduler_cfgs: List[DictConfig], all_parameter_names: Set[str]
) -> None:
    """Assign the leftover parameters to the "default" scheduler config.

    Each scheduler config may name the parameters it applies to; at most one
    config per option may leave `parameter_names` unset, making it the default
    for every parameter not claimed by the others. If no config is a default,
    a bare entry (with no scheduler) is appended to own the leftovers.

    Args:
        scheduler_cfgs: scheduler configs for one option; mutated in place.
        all_parameter_names: names of all the parameters to consider.
    """
    explicit = [
        cfg.parameter_names
        for cfg in scheduler_cfgs
        if cfg.parameter_names is not None
    ]
    if explicit:
        default_params = all_parameter_names - set.union(*explicit)
    else:
        default_params = set(all_parameter_names)

    num_defaults = 0
    for cfg in scheduler_cfgs:
        if cfg.parameter_names is None:
            cfg.parameter_names = default_params
            num_defaults += 1
    assert num_defaults <= 1, "Only one scheduler per option can be default"
    if num_defaults == 0:
        # No default scheduler specified: add one owning the leftover params,
        # but without any scheduler attached for this option.
        scheduler_cfgs.append({"parameter_names": default_params})
|
||||
|
||||
|
||||
def name_constraints_to_parameters(
    param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor]
) -> List[torch.nn.Parameter]:
    """Return the parameters whose names satisfy every constraint set.

    Note that this returns the parameter objects themselves, not their names.

    Args:
        param_constraints: each element is a set of allowed parameter names.
        named_parameters: mapping from a parameter name to the parameter.

    Returns:
        Parameters whose names lie in the intersection of all constraint sets,
        in named_parameters iteration order.
    """
    allowed_names = set.intersection(*param_constraints)
    selected = []
    for name, param in named_parameters.items():
        if name in allowed_names:
            selected.append(param)
    return selected
|
||||
|
||||
|
||||
def map_scheduler_cfgs_to_param_groups(
    all_scheduler_cfgs: Iterable[List[Dict]],
    named_parameters: Dict[str, Tensor],
) -> Tuple[List[Dict[Any, Any]], List[Dict[str, List[torch.nn.Parameter]]]]:
    """Build optimizer param groups from the per-option scheduler configs.

    Every combination of one scheduler config per option (cartesian product)
    defines a candidate param group: the parameters matched by *all* configs
    in the combination. Empty combinations are dropped.

    Args:
        all_scheduler_cfgs: per-option lists of scheduler configs; each config
            carries "parameter_names" and, unless it is a bare default entry,
            "option" and "scheduler".
        named_parameters: mapping from a parameter name to the parameter.

    Returns:
        (schedulers, param_groups) where schedulers[i] maps option name ->
        scheduler for param_groups[i].
    """
    schedulers = []
    param_groups = []
    for cfg_combo in itertools.product(*all_scheduler_cfgs):
        constraints = [cfg["parameter_names"] for cfg in cfg_combo]
        group_params = name_constraints_to_parameters(constraints, named_parameters)
        # Combinations matching no parameter produce no group.
        if len(group_params) == 0:
            continue
        option_to_scheduler = {}
        for cfg in cfg_combo:
            # Bare default entries carry no "option"/"scheduler" keys.
            if "option" in cfg:
                option_to_scheduler[cfg["option"]] = cfg["scheduler"]
        schedulers.append(option_to_scheduler)
        param_groups.append({"params": group_params})
    return schedulers, param_groups
|
||||
|
||||
|
||||
def validate_param_group_params(param_groups: List[Dict], model: nn.Module):
    """Check that the param groups are non-overlapping and cover all parameters.

    Args:
        param_groups: list of optimizer param groups ({"params": [...]}).
        model: model to validate against; every one of its parameters must
            appear in exactly one group.

    Raises:
        AssertionError: if a group repeats a parameter, two groups overlap,
            or the union of all groups misses a model parameter.
    """
    for pg in param_groups:
        # No param should be repeated within a group.
        assert len(pg["params"]) == len(set(pg["params"]))
    parameters = [set(param_group["params"]) for param_group in param_groups]
    model_parameters = {parameter for _, parameter in model.named_parameters()}
    # combinations, not permutations: disjointness is symmetric, so each pair
    # only needs to be checked once (the previous code checked every pair twice).
    for p1, p2 in itertools.combinations(parameters, 2):
        assert p1.isdisjoint(p2), "Scheduler generated param_groups should be disjoint"
    assert set.union(*parameters) == model_parameters, (
        "Scheduler generated param_groups must include all parameters of the model."
        f" Found {len(set.union(*parameters))} params whereas model has"
        f" {len(model_parameters)} params"
    )
|
||||
|
||||
|
||||
def unix_module_cls_pattern_to_parameter_names(
    filter_module_cls_names: List[str],
    module_cls_to_param_names: Dict[Type, str],
) -> Union[None, Set[str]]:
    """Returns param names owned by the module classes in filter_module_cls_names.

    Args:
        filter_module_cls_names: fully-qualified class names, like
            ["torch.nn.LayerNorm", "torch.nn.BatchNorm2d"].
        module_cls_to_param_names: mapping from module classes to the parameter
            names they own. See `get_module_cls_to_param_names`.

    Raises:
        AssertionError: if a listed class is absent from the model or owns no
            parameters.
    """
    # An absent OR empty filter selects nothing. (The previous code only
    # handled None and crashed on `set.union()` with no arguments when an
    # empty list was passed.)
    if not filter_module_cls_names:
        return set()
    allowed_parameter_names = []
    for module_cls_name in filter_module_cls_names:
        module_cls = hydra.utils.get_class(module_cls_name)
        if module_cls not in module_cls_to_param_names:
            raise AssertionError(
                f"module_cls_name {module_cls_name} does not "
                "match any classes in the model"
            )
        matching_parameters = module_cls_to_param_names[module_cls]
        assert (
            len(matching_parameters) > 0
        ), f"module_cls_name {module_cls_name} does not contain any parameters in the model"
        logging.info(
            f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
        )
        allowed_parameter_names.append(matching_parameters)
    return set.union(*allowed_parameter_names)
|
||||
|
||||
|
||||
def unix_param_pattern_to_parameter_names(
    filter_param_names: Optional[List[str]],
    parameter_names: Dict[str, torch.Tensor],
) -> Union[None, Set[str]]:
    """Returns param names which pass the filters in filter_param_names.

    Args:
        filter_param_names: unix-style filter strings with optional wildcards,
            like ["block.2.*", "block.2.linear.weight"].
        parameter_names: mapping from parameter name to the parameter; only the
            keys are consulted.

    Raises:
        AssertionError: if a pattern matches no parameter.
    """
    # An absent OR empty filter selects nothing. (The previous code only
    # handled None and crashed on `set.union()` with no arguments when an
    # empty list was passed.)
    if not filter_param_names:
        return set()
    allowed_parameter_names = []
    for param_name in filter_param_names:
        # fnmatch.filter iterates the mapping, i.e. matches against the keys.
        matching_parameters = set(fnmatch.filter(parameter_names, param_name))
        assert (
            len(matching_parameters) >= 1
        ), f"param_name {param_name} does not match any parameters in the model"
        logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
        allowed_parameter_names.append(matching_parameters)
    return set.union(*allowed_parameter_names)
|
||||
|
||||
|
||||
def _unix_pattern_to_parameter_names(
    scheduler_cfg: DictConfig,
    parameter_names: Set[str],
    module_cls_to_param_names: Dict[Type, str],
) -> Union[None, Set[str]]:
    """Resolve a scheduler config's filters into a set of parameter names.

    Args:
        scheduler_cfg: config that may carry "param_names" (name patterns)
            and/or "module_cls_names" (module class names) filters.
        parameter_names: all candidate parameter names.
        module_cls_to_param_names: mapping from module classes to the
            parameter names they own.

    Returns:
        None when the config declares no filter at all (meaning
        "unconstrained", later resolved by set_default_parameters);
        otherwise the union of names selected by both filter kinds.
    """
    has_name_filter = "param_names" in scheduler_cfg
    has_cls_filter = "module_cls_names" in scheduler_cfg
    if not (has_name_filter or has_cls_filter):
        return None
    from_patterns = unix_param_pattern_to_parameter_names(
        scheduler_cfg.get("param_names"), parameter_names
    )
    from_classes = unix_module_cls_pattern_to_parameter_names(
        scheduler_cfg.get("module_cls_names"), module_cls_to_param_names
    )
    return from_patterns.union(from_classes)
|
||||
|
||||
|
||||
def get_module_cls_to_param_names(
    model: nn.Module, param_allowlist: Optional[Set[str]] = None
) -> Dict[Type, Set[str]]:
    """Map every module class in the model to the parameter names it owns.

    Only counts a parameter as part of its immediate parent module, i.e.
    recursive parents do not count.

    Args:
        model: model to iterate over.
        param_allowlist: if given, only these (full) param names are collected.

    Returns:
        Mapping from module class to the set of full parameter names directly
        owned by instances of that class. (Annotations fixed: the allowlist is
        Optional, and the values are sets of names, not single strings.)
    """
    module_cls_to_params = {}
    for module_name, module in model.named_modules():
        module_cls = type(module)
        params_of_cls = module_cls_to_params.setdefault(module_cls, set())
        for param_name, _ in module.named_parameters(recurse=False):
            # Full dotted name; the root module has an empty name.
            full_param_name = (
                param_name if module_name == "" else f"{module_name}.{param_name}"
            )
            if param_allowlist is None or full_param_name in param_allowlist:
                params_of_cls.add(full_param_name)
    return module_cls_to_params
|
||||
|
||||
|
||||
def construct_optimizer(
    model: torch.nn.Module,
    optimizer_conf: Any,
    options_conf: Mapping[str, List] = None,
    param_group_modifiers_conf: List[Callable] = None,
    param_allowlist: Optional[Set[str]] = None,
    validate_param_groups=True,
) -> Optimizer:
    """
    Constructs a wrapped torch optimizer (SGD / ADAM / ADAMW, ...) whose
    param groups and per-group option schedulers are derived from config.

    Args:
        model: model whose parameters will be optimized.
        optimizer_conf: Hydra config of a partial torch optimizer (e.g. SGD or
            ADAM), still missing the params argument which this function
            provides to produce the final optimizer.
        options_conf: Hydra config mapping each optimizer option (e.g. "lr",
            "weight_decay") to its list of scheduler configs. When empty, a
            single param group with no schedulers is used.
        param_group_modifiers_conf: optional user-specified functions which can
            modify the scheduler configs before the param groups are built.
        param_allowlist: the parameters to optimize. Parameters which are not
            part of this allowlist will be skipped; None means all parameters.
        validate_param_groups: if enabled, validates that the produced
            param_groups don't overlap and cover all the model parameters.

    Returns:
        An Optimizer wrapper holding the instantiated torch optimizer and the
        per-param-group schedulers.
    """
    if param_allowlist is None:
        param_allowlist = {name for name, _ in model.named_parameters()}

    named_parameters = {
        name: param
        for name, param in model.named_parameters()
        if name in param_allowlist
    }

    # Fast path: no scheduled options -> one param group, no schedulers.
    if not options_conf:
        optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values())
        return Optimizer(optimizer)

    all_parameter_names = {
        name for name, _ in model.named_parameters() if name in param_allowlist
    }
    module_cls_to_all_param_names = get_module_cls_to_param_names(
        model, param_allowlist
    )

    scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf)
    all_scheduler_cfgs = []
    for option, scheduler_cfgs in scheduler_cfgs_per_option.items():
        # Resolve each config's name/class filters to concrete param names,
        # then make sure leftover params fall under a default config.
        for config in scheduler_cfgs:
            config.option = option
            config.parameter_names = _unix_pattern_to_parameter_names(
                config, all_parameter_names, module_cls_to_all_param_names
            )
        set_default_parameters(scheduler_cfgs, all_parameter_names)
        all_scheduler_cfgs.append(scheduler_cfgs)

    # Optional user hooks (e.g. layer_decay_param_modifier) rewrite the
    # scheduler configs before param groups are formed.
    if param_group_modifiers_conf:
        for custom_param_modifier in param_group_modifiers_conf:
            custom_param_modifier = hydra.utils.instantiate(custom_param_modifier)
            all_scheduler_cfgs = custom_param_modifier(
                scheduler_cfgs=all_scheduler_cfgs, model=model
            )
    schedulers, param_groups = map_scheduler_cfgs_to_param_groups(
        all_scheduler_cfgs, named_parameters
    )
    if validate_param_groups:
        validate_param_group_params(param_groups, model)
    optimizer = hydra.utils.instantiate(optimizer_conf, param_groups)
    return Optimizer(optimizer, schedulers)
|
||||
|
||||
|
||||
def get_full_parameter_name(module_name, param_name):
    """Join a module path and a parameter name into a dotted full name.

    The root module has an empty name, in which case the bare parameter
    name is returned.
    """
    return param_name if module_name == "" else f"{module_name}.{param_name}"
|
||||
|
||||
|
||||
class GradientClipper:
    """
    Gradient clipping utility that works for DDP.

    A max_norm of None disables clipping entirely.
    """

    def __init__(self, max_norm: float = 1.0, norm_type: int = 2):
        assert max_norm is None or isinstance(max_norm, (int, float))
        self.max_norm = None if max_norm is None else float(max_norm)
        self.norm_type = norm_type

    def __call__(self, model: nn.Module):
        # Clipping disabled.
        if self.max_norm is None:
            return

        nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=self.max_norm, norm_type=self.norm_type
        )
|
||||
|
||||
|
||||
class ValueScaler:
    """Wraps a scheduler and multiplies its output by a constant factor."""

    def __init__(self, scheduler, mult_val: float):
        self.scheduler = scheduler
        self.mult_val = mult_val

    def __call__(self, *args, **kwargs):
        return self.scheduler(*args, **kwargs) * self.mult_val
|
||||
|
||||
|
||||
def rgetattr(obj, rattrs: str = None):
    """
    Like getattr(), but supports dotted notation for nested objects.
    rattrs is a str of form 'attr1.attr2'; returns obj.attr1.attr2.
    None returns obj itself.
    """
    if rattrs is None:
        return obj
    current = obj
    for attr in rattrs.split("."):
        current = getattr(current, attr)
    return current
|
||||
|
||||
|
||||
def layer_decay_param_modifier(
    scheduler_cfgs: List[List[Dict]],
    model,
    layer_decay_value: float,
    layer_decay_min: Optional[float] = None,
    apply_to: Optional[str] = None,
    overrides: List[Dict] = (),
) -> List[List[Dict]]:
    """
    Apply per-layer learning-rate decay to "lr" scheduler configs.

    Args
    - scheduler_cfgs: a list of omegaconf.ListConfigs.
        Each element in the list is a omegaconfg.DictConfig with the following structure
        {
            "scheduler": <some fvcore scheduler>
            "option": <value> possible options are "lr", "weight_decay" etc.
            "parameter_names": Set of str indicating param names that this scheduler applies to
        }
    - model: a model that implements a method `get_layer_id` that maps layer_name to an integer and
        and a method get_num_layers.
        Alternatively, use apply_to argument to select a specific component of the model.
    - layer_decay_value: float
    - layer_decay_min: min val for layer decay
    - apply_to: optional arg to select which component of the model to apply the the layer decay modifier to
    - overrides: to manually override lr for specific patterns. Is a list of dicts. Each dict, has keys "pattern", "value".

    Returns
    - scheduler_configs: same structure as the input, elements can be modified
    """
    model = rgetattr(model, apply_to)
    # One extra slot so parameters outside `apply_to` get scale decay**0 == 1.
    num_layers = model.get_num_layers() + 1
    layer_decays = [
        layer_decay_value ** (num_layers - i) for i in range(num_layers + 1)
    ]
    if layer_decay_min is not None:
        layer_decays = [max(val, layer_decay_min) for val in layer_decays]
    final_scheduler_cfgs = []
    # scheduler_cfgs is a list of lists
    for scheduler_cfg_group in scheduler_cfgs:
        curr_cfg_group = []
        # scheduler_cfg_group is a list of dictionaries
        for scheduler_cfg in scheduler_cfg_group:
            if scheduler_cfg["option"] != "lr":
                # Layer decay only modifies learning rates; pass through.
                curr_cfg_group.append(scheduler_cfg)
                continue
            # Need sorted so that the list of parameter names is deterministic and consistent
            # across re-runs of this job. Else it was causing issues with loading the optimizer
            # state during a job restart
            parameter_names = sorted(scheduler_cfg["parameter_names"])

            # Only want one cfg group per layer
            layer_cfg_groups = {}
            for param_name in parameter_names:
                layer_id = num_layers
                this_scale = layer_decays[layer_id]
                # BUGFIX: `apply_to=None` previously crashed on
                # `str.startswith(None)`; treat None as "apply the layer
                # decay to every parameter of the (whole) model".
                if apply_to is None or param_name.startswith(apply_to):
                    layer_id = model.get_layer_id(param_name)
                    this_scale = layer_decays[layer_id]
                # Overrides win over the computed layer scale.
                for override in overrides:
                    if fnmatch.fnmatchcase(param_name, override["pattern"]):
                        this_scale = float(override["value"])
                        # Group overridden params by pattern, not layer id.
                        layer_id = override["pattern"]
                        break

                if layer_id not in layer_cfg_groups:
                    curr_param = {
                        "option": scheduler_cfg["option"],
                        "scheduler": ValueScaler(
                            scheduler_cfg["scheduler"], this_scale
                        ),
                        "parameter_names": {param_name},
                    }
                else:
                    curr_param = layer_cfg_groups[layer_id]
                    curr_param["parameter_names"].add(param_name)
                layer_cfg_groups[layer_id] = curr_param

            for layer_cfg in layer_cfg_groups.values():
                curr_cfg_group.append(layer_cfg)

        final_scheduler_cfgs.append(curr_cfg_group)
    return final_scheduler_cfgs
|
||||
41
sam3/train/optim/schedulers.py
Normal file
41
sam3/train/optim/schedulers.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import math
|
||||
|
||||
|
||||
class InverseSquareRootParamScheduler:
    """Inverse-square-root schedule with linear warmup and cooldown ramps.

    After ``warmup_steps``, the value decays as
    ``base_lr / sqrt((step + timescale - warmup_steps) / timescale)``.
    A linear ramp scales the value up during warmup and down during the last
    ``cooldown_steps`` of training (the total step count is inferred from the
    ``where`` fraction as ``step / where``).
    """

    def __init__(
        self,
        base_lr: float,
        warmup_steps: int,
        cooldown_steps: int,
        timescale: int,
    ):
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.cooldown_steps = cooldown_steps
        self.timescale = timescale

    def __call__(self, step: int, where: float):
        """Return the scheduled value at ``step``.

        ``where`` is the completed fraction of training in [0, 1); it is used
        only to infer the total number of steps for the cooldown ramp.
        """
        lr = self.base_lr

        # Infer total steps from the completed fraction; guard where == 0.
        # Fix: removed a dead `progress` computation that was never used and
        # raised ZeroDivisionError whenever total_steps == warmup_steps.
        if where > 0:
            total_steps = step / where
        else:
            total_steps = 1

        shift = self.timescale - self.warmup_steps
        if self.warmup_steps < step:
            lr = lr / math.sqrt((step + shift) / self.timescale)

        if self.warmup_steps:
            lr = lr * min(1.0, step / self.warmup_steps)
        if self.cooldown_steps:
            lr = lr * min(1.0, (total_steps - step) / self.cooldown_steps)

        return lr
|
||||
339
sam3/train/train.py
Normal file
339
sam3/train/train.py
Normal file
@@ -0,0 +1,339 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import traceback
|
||||
from argparse import ArgumentParser
|
||||
from copy import deepcopy
|
||||
|
||||
import submitit
|
||||
import torch
|
||||
|
||||
from hydra import compose, initialize_config_module
|
||||
from hydra.utils import instantiate
|
||||
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from sam3.train.utils.train_utils import makedir, register_omegaconf_resolvers
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
os.environ["HYDRA_FULL_ERROR"] = "1"
|
||||
|
||||
|
||||
class SlurmEvent:
    """String constants naming slurm job lifecycle events.

    NOTE(review): in this file only QUEUED is referenced (in a commented-out
    log call); presumably these are consumed by an event logger elsewhere.
    """

    QUEUED = "QUEUED"
    START = "START"
    FINISH = "FINISH"
    JOB_ERROR = "JOB_ERROR"
    SLURM_SIGNAL = "SLURM_SIGNAL"
|
||||
|
||||
|
||||
def handle_custom_resolving(cfg):
    """Return an unresolved deep copy of ``cfg``.

    Resolution is deliberately NOT triggered here (``resolve=False``): per the
    original author's note, resolving must happen on the node the job will
    actually run on, so the launcher only round-trips the config through a
    plain container to obtain an independent copy.
    """
    plain_container = OmegaConf.to_container(cfg, resolve=False)
    return OmegaConf.create(plain_container)
|
||||
|
||||
|
||||
def single_proc_run(local_rank, main_port, cfg, world_size):
    """Single GPU process: set up the torch.distributed env vars and train."""
    os.environ.update(
        {
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": str(main_port),
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
        }
    )
    try:
        register_omegaconf_resolvers()
    except Exception as e:
        # Registration can fail (presumably when resolvers were already
        # registered by a parent process — confirm); log and continue.
        logging.info(e)

    trainer = instantiate(cfg.trainer, _recursive_=False)
    trainer.run()
|
||||
|
||||
|
||||
def single_node_runner(cfg, main_port: int):
    """Launch training on a single node, spawning one process per GPU."""
    assert cfg.launcher.num_nodes == 1
    num_proc = cfg.launcher.gpus_per_node
    # CUDA runtime does not support `fork`.
    torch.multiprocessing.set_start_method("spawn")
    if num_proc == 1:
        # Run in-process so breakpoints work (mp.spawn does not allow them).
        single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=num_proc)
        return
    # NOTE(review): an earlier comment here claimed "fork" was used, but the
    # call below passes start_method="spawn" — the comment was stale.
    torch.multiprocessing.start_processes(
        single_proc_run,
        args=(main_port, cfg, num_proc),
        nprocs=num_proc,
        start_method="spawn",
    )
|
||||
|
||||
|
||||
def format_exception(e: Exception, limit=20):
    """Render an exception as '<Type>: <message>' plus up to *limit* traceback frames."""
    frames = traceback.format_tb(e.__traceback__, limit=limit)
    return f"{type(e).__name__}: {e}\nTraceback:\n{''.join(frames)}"
|
||||
|
||||
|
||||
class SubmititRunner(submitit.helpers.Checkpointable):
    """A callable which is passed to submitit to launch the jobs."""

    def __init__(self, port, cfg):
        # port: master port used for torch.distributed rendezvous.
        # cfg: unresolved omegaconf config; resolved on the worker node.
        self.cfg = cfg
        self.port = port
        self.has_setup = False

    def run_trainer(self):
        """Set up the distributed environment from the slurm job env and train."""
        job_env = submitit.JobEnvironment()
        # Need to add this again so the hydra.job.set_env PYTHONPATH
        # is also set when launching jobs.
        add_pythonpath_to_sys_path()
        # The first hostname in the allocation acts as the rendezvous master.
        os.environ["MASTER_ADDR"] = job_env.hostnames[0]
        os.environ["MASTER_PORT"] = str(self.port)
        os.environ["RANK"] = str(job_env.global_rank)
        os.environ["LOCAL_RANK"] = str(job_env.local_rank)
        os.environ["WORLD_SIZE"] = str(job_env.num_tasks)

        register_omegaconf_resolvers()
        # NOTE(review): despite the name, this copy is created with
        # resolve=False — actual resolution happens lazily on access.
        cfg_resolved = OmegaConf.to_container(self.cfg, resolve=False)
        cfg_resolved = OmegaConf.create(cfg_resolved)

        trainer = instantiate(cfg_resolved.trainer, _recursive_=False)
        trainer.run()

    def __call__(self):
        """Entry point invoked by submitit on the worker node."""
        job_env = submitit.JobEnvironment()
        self.setup_job_info(job_env.job_id, job_env.global_rank)
        try:
            self.run_trainer()
        except Exception as e:
            # Log the exception. Then raise it again (as what SubmititRunner currently does).
            message = format_exception(e)
            logging.error(message)
            raise e

    def setup_job_info(self, job_id, rank):
        """Set up slurm job info"""
        self.job_info = {
            "job_id": job_id,
            "rank": rank,
            "cluster": self.cfg.get("cluster", None),
            "experiment_log_dir": self.cfg.launcher.experiment_log_dir,
        }

        self.has_setup = True
|
||||
|
||||
|
||||
def add_pythonpath_to_sys_path():
    """Prepend each PYTHONPATH entry to sys.path; no-op when unset or empty.

    Fix: split on ``os.pathsep`` (":" on POSIX, ";" on Windows) instead of a
    hard-coded ":".
    """
    pythonpath = os.environ.get("PYTHONPATH")
    if not pythonpath:
        return
    sys.path = pythonpath.split(os.pathsep) + sys.path
|
||||
|
||||
|
||||
def main(args) -> None:
    """Compose the hydra config, persist it, then launch training either on a
    SLURM cluster via submitit or locally on a single node.

    Fixes:
    - per-task resolved configs were built from the shared base ``cfg`` instead
      of the task-specific ``curr_cfg``, so every saved
      ``<job_id>.config_resolved.yaml`` missed its own ``task_index``;
    - the job-array loop kept an unused zip variable alongside a manually
      incremented counter — replaced with ``enumerate``.
    """
    cfg = compose(config_name=args.config)
    if cfg.launcher.experiment_log_dir is None:
        cfg.launcher.experiment_log_dir = os.path.join(
            os.getcwd(), "sam3_logs", args.config
        )
    print("###################### Train App Config ####################")
    print(OmegaConf.to_yaml(cfg))
    print("############################################################")

    add_pythonpath_to_sys_path()
    makedir(cfg.launcher.experiment_log_dir)
    # Persist both the raw and the resolved config for reproducibility.
    with g_pathmgr.open(
        os.path.join(cfg.launcher.experiment_log_dir, "config.yaml"), "w"
    ) as f:
        f.write(OmegaConf.to_yaml(cfg))

    cfg_resolved = OmegaConf.to_container(cfg, resolve=False)
    cfg_resolved = OmegaConf.create(cfg_resolved)

    with g_pathmgr.open(
        os.path.join(cfg.launcher.experiment_log_dir, "config_resolved.yaml"), "w"
    ) as f:
        f.write(OmegaConf.to_yaml(cfg_resolved, resolve=True))

    submitit_conf = cfg.get("submitit", None)
    assert submitit_conf is not None, "Missing submitit config"

    experiment_log_dir = cfg.launcher.experiment_log_dir
    print(f"Experiment Log Dir:\n{experiment_log_dir}")
    submitit_dir = os.path.join(experiment_log_dir, "submitit_logs")

    # Prioritize cmd line args over config values.
    cfg.launcher.gpus_per_node = (
        args.num_gpus if args.num_gpus is not None else cfg.launcher.gpus_per_node
    )
    cfg.launcher.num_nodes = (
        args.num_nodes if args.num_nodes is not None else cfg.launcher.num_nodes
    )
    submitit_conf.use_cluster = (
        args.use_cluster if args.use_cluster is not None else submitit_conf.use_cluster
    )
    if submitit_conf.use_cluster:
        executor = submitit.AutoExecutor(folder=submitit_dir)
        submitit_conf.partition = (
            args.partition
            if args.partition is not None
            else submitit_conf.get("partition", None)
        )
        submitit_conf.account = (
            args.account
            if args.account is not None
            else submitit_conf.get("account", None)
        )
        submitit_conf.qos = (
            args.qos if args.qos is not None else submitit_conf.get("qos", None)
        )
        job_kwargs = {
            "timeout_min": 60 * submitit_conf.timeout_hour,
            "name": (
                submitit_conf.name if hasattr(submitit_conf, "name") else args.config
            ),
            "slurm_partition": submitit_conf.partition,
            "gpus_per_node": cfg.launcher.gpus_per_node,
            "tasks_per_node": cfg.launcher.gpus_per_node,  # one task per GPU
            "cpus_per_task": submitit_conf.cpus_per_task,
            "nodes": cfg.launcher.num_nodes,
            "slurm_additional_parameters": {
                # NOTE(review): SLURM node lists are normally comma-separated;
                # confirm that a space-joined list is accepted by your cluster.
                "exclude": " ".join(submitit_conf.get("exclude_nodes", [])),
            },
        }
        if "include_nodes" in submitit_conf:
            assert (
                len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes
            ), "Not enough nodes"
            job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(
                submitit_conf["include_nodes"]
            )
        if submitit_conf.account is not None:
            job_kwargs["slurm_additional_parameters"]["account"] = submitit_conf.account
        if submitit_conf.qos is not None:
            job_kwargs["slurm_additional_parameters"]["qos"] = submitit_conf.qos

        if submitit_conf.get("mem_gb", None) is not None:
            job_kwargs["mem_gb"] = submitit_conf.mem_gb
        elif submitit_conf.get("mem", None) is not None:
            job_kwargs["slurm_mem"] = submitit_conf.mem

        if submitit_conf.get("constraints", None) is not None:
            job_kwargs["slurm_constraint"] = submitit_conf.constraints

        if submitit_conf.get("comment", None) is not None:
            job_kwargs["slurm_comment"] = submitit_conf.comment

        # Supports only cpu-bind option within srun_args. New options can be added here
        if submitit_conf.get("srun_args", None) is not None:
            job_kwargs["slurm_srun_args"] = []
            if submitit_conf.srun_args.get("cpu_bind", None) is not None:
                job_kwargs["slurm_srun_args"].extend(
                    ["--cpu-bind", submitit_conf.srun_args.cpu_bind]
                )

        print("###################### SLURM Config ####################")
        print(job_kwargs)
        print("##########################################")
        executor.update_parameters(**job_kwargs)

        if (
            "job_array" in submitit_conf
            and submitit_conf.job_array.get("num_tasks", -1) > 0
        ):
            # Job-array mode: submit num_tasks jobs, each with its own
            # task_index and a distinct rendezvous port.
            num_tasks = submitit_conf.job_array.num_tasks
            job_array_config_dir = os.path.join(
                cfg.launcher.experiment_log_dir, "job_array_configs"
            )
            makedir(job_array_config_dir)

            ports = random.sample(
                range(submitit_conf.port_range[0], submitit_conf.port_range[1] + 1),
                k=num_tasks,
            )

            jobs_runners_configs = []
            with executor.batch():
                for task_index, main_port in enumerate(tqdm(ports)):
                    curr_cfg = deepcopy(cfg)
                    curr_cfg.submitit.job_array["task_index"] = task_index
                    # BUGFIX: resolve the per-task config (curr_cfg), not the
                    # shared base cfg, so each saved config_resolved.yaml
                    # carries its own task_index.
                    curr_cfg_resolved = handle_custom_resolving(curr_cfg)
                    runner = SubmititRunner(main_port, curr_cfg)
                    job = executor.submit(runner)
                    jobs_runners_configs.append(
                        (job, runner, curr_cfg, curr_cfg_resolved)
                    )

            for job, runner, job_cfg, job_cfg_resolved in jobs_runners_configs:
                print("Submitit Job ID:", job.job_id)

                # Save job specific config
                job_array_config_file = os.path.join(
                    job_array_config_dir, "{}.config.yaml".format(job.job_id)
                )
                with g_pathmgr.open(job_array_config_file, "w") as f:
                    f.write(OmegaConf.to_yaml(job_cfg))

                job_array_config_resolved_file = os.path.join(
                    job_array_config_dir, "{}.config_resolved.yaml".format(job.job_id)
                )
                with g_pathmgr.open(job_array_config_resolved_file, "w") as f:
                    f.write(OmegaConf.to_yaml(job_cfg_resolved, resolve=True))

                runner.setup_job_info(job.job_id, rank=0)
        else:
            # Single cluster job.
            main_port = random.randint(
                submitit_conf.port_range[0], submitit_conf.port_range[1]
            )
            runner = SubmititRunner(main_port, cfg)
            job = executor.submit(runner)
            print(f"Submitit Job ID: {job.job_id}")
            runner.setup_job_info(job.job_id, rank=0)

    else:
        # Local (non-cluster) launch is always single-node.
        cfg.launcher.num_nodes = 1
        main_port = random.randint(
            submitit_conf.port_range[0], submitit_conf.port_range[1]
        )
        single_node_runner(cfg, main_port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Make the packaged hydra configs discoverable before composing.
    initialize_config_module("sam3.train", version_base="1.2")
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        required=True,
        type=str,
        help="path to config file (e.g. configs/roboflow_v100_full_ft_100_images.yaml)",
    )
    parser.add_argument(
        "--use-cluster",
        type=int,
        default=None,
        help="whether to launch on a cluster, 0: run locally, 1: run on a cluster",
    )
    parser.add_argument("--partition", type=str, default=None, help="SLURM partition")
    parser.add_argument("--account", type=str, default=None, help="SLURM account")
    parser.add_argument("--qos", type=str, default=None, help="SLURM qos")
    parser.add_argument(
        "--num-gpus", type=int, default=None, help="number of GPUS per node"
    )
    parser.add_argument("--num-nodes", type=int, default=None, help="Number of nodes")
    args = parser.parse_args()
    # Keep None ("defer to the config") distinct from an explicit 0/1 flag.
    args.use_cluster = bool(args.use_cluster) if args.use_cluster is not None else None
    register_omegaconf_resolvers()
    main(args)
|
||||
1193
sam3/train/trainer.py
Normal file
1193
sam3/train/trainer.py
Normal file
File diff suppressed because it is too large
Load Diff
1
sam3/train/transforms/__init__.py
Normal file
1
sam3/train/transforms/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
455
sam3/train/transforms/basic.py
Normal file
455
sam3/train/transforms/basic.py
Normal file
@@ -0,0 +1,455 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
"""
|
||||
Transforms and data augmentation for both image + bbox.
|
||||
"""
|
||||
|
||||
import math
|
||||
import random
|
||||
from typing import Iterable
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
from sam3.model.box_ops import box_xyxy_to_cxcywh
|
||||
from sam3.model.data_misc import interpolate
|
||||
|
||||
|
||||
def crop(image, target, region):
    """Crop *image* to *region* = (top, left, height, width) and update *target*.

    Boxes are clipped to the crop window and their areas recomputed; instances
    whose box (or mask, when no boxes exist) becomes empty are dropped from
    every per-instance field.
    """
    cropped_image = F.crop(image, *region)

    target = target.copy()
    i, j, h, w = region

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])

    # Per-instance fields that must be filtered together when instances drop.
    fields = ["labels", "area", "iscrowd", "positive_map"]

    if "boxes" in target:
        boxes = target["boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        # Shift into crop coordinates, then clip each (x, y) corner pair to
        # [0, w] x [0, h] via a (N, 2, 2) view.
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32)
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")

    if "input_boxes" in target:
        # "input_boxes" (presumably prompt/query boxes — confirm) are clipped
        # the same way but do not influence which instances are kept.
        boxes = target["input_boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32)
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        target["input_boxes"] = cropped_boxes.reshape(-1, 4)

    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target["masks"] = target["masks"][:, i : i + h, j : j + w]
        fields.append("masks")

    # remove elements for which the boxes or masks that have zero area
    if "boxes" in target or "masks" in target:
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target["boxes"].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target["masks"].flatten(1).any(1)

        for field in fields:
            if field in target:
                target[field] = target[field][keep]

    return cropped_image, target
|
||||
|
||||
|
||||
def hflip(image, target):
    """Horizontally flip the image and mirror boxes, masks, and text mentions."""
    flipped_image = F.hflip(image)
    w, h = image.size

    target = target.copy()

    def _mirror(boxes):
        # Swap x1/x2 and reflect both about the image width; y is untouched.
        return boxes[:, [2, 1, 0, 3]] * torch.as_tensor(
            [-1, 1, -1, 1]
        ) + torch.as_tensor([w, 0, w, 0])

    if "boxes" in target:
        target["boxes"] = _mirror(target["boxes"])

    if "input_boxes" in target:
        target["input_boxes"] = _mirror(target["input_boxes"])

    if "masks" in target:
        target["masks"] = target["masks"].flip(-1)

    if "text_input" in target:
        # Swap "left"/"right" mentions since the image content is mirrored.
        swapped = (
            target["text_input"]
            .replace("left", "[TMP]")
            .replace("right", "left")
            .replace("[TMP]", "right")
        )
        target["text_input"] = swapped

    return flipped_image, target
|
||||
|
||||
|
||||
def resize(image, target, size, max_size=None, square=False):
    """Resize *image* (and rescale *target* boxes/areas/masks) to *size*.

    size can be min_size (scalar) or (w, h) tuple; with ``square=True`` the
    image is resized to (size, size) ignoring aspect ratio.
    """

    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        # Scale the short side to `size`, capping the long side at `max_size`.
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            if max_original_size / min_original_size * size > max_size:
                size = int(round(max_size * min_original_size / max_original_size))

        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)

        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)

        return (oh, ow)

    def get_size(image_size, size, max_size=None):
        # A (w, h) tuple is used verbatim (reversed to (h, w)); a scalar
        # triggers aspect-ratio-preserving resizing.
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    if square:
        size = size, size
    else:
        size = get_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size)

    if target is None:
        return rescaled_image, None

    # Per-axis scale factors (PIL .size is (w, h)).
    ratios = tuple(
        float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)
    )
    ratio_width, ratio_height = ratios

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        scaled_boxes = boxes * torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32
        )
        target["boxes"] = scaled_boxes
    if "input_boxes" in target:
        boxes = target["input_boxes"]
        scaled_boxes = boxes * torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32
        )
        target["input_boxes"] = scaled_boxes

    if "area" in target:
        area = target["area"]
        scaled_area = area * (ratio_width * ratio_height)
        target["area"] = scaled_area

    h, w = size
    target["size"] = torch.tensor([h, w])

    if "masks" in target:
        # Nearest-neighbor resize, then re-binarize.
        target["masks"] = (
            interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0]
            > 0.5
        )

    return rescaled_image, target
|
||||
|
||||
|
||||
def pad(image, target, padding):
    """Pad *image* and shift/pad *target* accordingly.

    padding is either (right, bottom) — bottom-right padding only — or a
    full (left, top, right, bottom) tuple.
    """
    if len(padding) == 2:
        # assumes that we only pad on the bottom right corners
        padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
    else:
        # left, top, right, bottom
        padded_image = F.pad(image, (padding[0], padding[1], padding[2], padding[3]))
    if target is None:
        return padded_image, None
    target = target.copy()

    w, h = padded_image.size

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])
    # Boxes only shift when left/top padding is applied (4-tuple form).
    if "boxes" in target and len(padding) == 4:
        boxes = target["boxes"]
        boxes = boxes + torch.as_tensor(
            [padding[0], padding[1], padding[0], padding[1]], dtype=torch.float32
        )
        target["boxes"] = boxes

    if "input_boxes" in target and len(padding) == 4:
        boxes = target["input_boxes"]
        boxes = boxes + torch.as_tensor(
            [padding[0], padding[1], padding[0], padding[1]], dtype=torch.float32
        )
        target["input_boxes"] = boxes

    if "masks" in target:
        # torch F.pad orders the last-two-dim pads as (left, right, top, bottom).
        if len(padding) == 2:
            target["masks"] = torch.nn.functional.pad(
                target["masks"], (0, padding[0], 0, padding[1])
            )
        else:
            target["masks"] = torch.nn.functional.pad(
                target["masks"], (padding[0], padding[2], padding[1], padding[3])
            )
    return padded_image, target
|
||||
|
||||
|
||||
class RandomCrop:
    """Crop a random region of fixed ``size``; region sampling is delegated
    to torchvision's RandomCrop parameter helper."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        sampled_region = T.RandomCrop.get_params(img, self.size)
        return crop(img, target, sampled_region)
|
||||
|
||||
|
||||
class RandomSizeCrop:
    """Crop a random region with side lengths in [min_size, max_size].

    With ``respect_boxes=True``, the sampled window is constrained so that no
    ground-truth box is cropped out (asserted after the crop).

    Fix: ``minH`` was computed as ``min(img.width, self.min_size)`` — a
    copy-paste slip; the height bound must use ``img.height`` (as the
    corresponding ``maxH`` bound already did).
    """

    def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
        self.min_size = min_size
        self.max_size = max_size
        self.respect_boxes = respect_boxes  # if True we can't crop a box out

    def __call__(self, img: PIL.Image.Image, target: dict):
        init_boxes = len(target["boxes"])
        init_boxes_tensor = target["boxes"].clone()
        if self.respect_boxes and init_boxes > 0:
            # Bounds on the crop width/height, clamped to the image.
            minW, minH, maxW, maxH = (
                min(img.width, self.min_size),
                min(img.height, self.min_size),  # BUGFIX: was min(img.width, ...)
                min(img.width, self.max_size),
                min(img.height, self.max_size),
            )
            # The crop's left/top must not exceed (minX, minY) and its extent
            # must reach past (maxX, maxY) to cover every box (10px margin).
            minX, minY = (
                target["boxes"][:, 0].max().item() + 10.0,
                target["boxes"][:, 1].max().item() + 10.0,
            )
            minX = min(img.width, minX)
            minY = min(img.height, minY)
            maxX, maxY = (
                target["boxes"][:, 2].min().item() - 10,
                target["boxes"][:, 3].min().item() - 10,
            )
            maxX = max(0.0, maxX)
            maxY = max(0.0, maxY)
            minW = max(minW, minX - maxX)
            minH = max(minH, minY - maxY)
            w = random.uniform(minW, max(minW, maxW))
            h = random.uniform(minH, max(minH, maxH))
            if minX > maxX:
                i = random.uniform(max(0, minX - w), max(maxX, max(0, minX - w)))
            else:
                i = random.uniform(
                    max(0, minX - w + 1), max(maxX - 1, max(0, minX - w + 1))
                )
            if minY > maxY:
                j = random.uniform(max(0, minY - h), max(maxY, max(0, minY - h)))
            else:
                j = random.uniform(
                    max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1))
                )
            result_img, result_target = crop(img, target, [j, i, h, w])
            assert (
                len(result_target["boxes"]) == init_boxes
            ), f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}"

            return result_img, result_target
        else:
            # Unconstrained: sample a size, then a random window of that size.
            w = random.randint(self.min_size, min(img.width, self.max_size))
            h = random.randint(self.min_size, min(img.height, self.max_size))
            region = T.RandomCrop.get_params(img, (h, w))
            result_img, result_target = crop(img, target, region)
            return result_img, result_target
|
||||
|
||||
|
||||
class CenterCrop:
    """Deterministically crop the central region of ``size`` = (height, width)."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        img_w, img_h = img.size
        out_h, out_w = self.size
        top = int(round((img_h - out_h) / 2.0))
        left = int(round((img_w - out_w) / 2.0))
        return crop(img, target, (top, left, out_h, out_w))
|
||||
|
||||
|
||||
class RandomHorizontalFlip:
    """Apply hflip with probability ``p``; otherwise pass inputs through."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        if random.random() >= self.p:
            return img, target
        return hflip(img, target)
|
||||
|
||||
|
||||
class RandomResize:
    """Resize to a target size drawn uniformly from ``sizes``."""

    def __init__(self, sizes, max_size=None, square=False):
        if isinstance(sizes, int):
            sizes = (sizes,)
        assert isinstance(sizes, Iterable)
        self.sizes = list(sizes)
        self.max_size = max_size
        self.square = square

    def __call__(self, img, target=None):
        chosen_size = random.choice(self.sizes)
        return resize(img, target, chosen_size, self.max_size, square=self.square)
|
||||
|
||||
|
||||
class RandomPad:
    """Pad the right/bottom edges by independent uniform amounts in [0, max_pad]."""

    def __init__(self, max_pad):
        self.max_pad = max_pad

    def __call__(self, img, target):
        right = random.randint(0, self.max_pad)
        bottom = random.randint(0, self.max_pad)
        return pad(img, target, (right, bottom))
|
||||
|
||||
|
||||
class PadToSize:
    """Pad an image up to a square ``size`` x ``size``, splitting the slack
    randomly between left/right and top/bottom."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        w, h = img.size
        extra_w = self.size - w
        extra_h = self.size - h
        # The image must already fit inside the target square.
        assert extra_w >= 0 and extra_h >= 0
        left = random.randint(0, extra_w)
        top = random.randint(0, extra_h)
        return pad(img, target, (left, top, extra_w - left, extra_h - top))
|
||||
|
||||
|
||||
class Identity:
    """No-op transform: returns its inputs unchanged."""

    def __call__(self, img, target):
        return img, target
|
||||
|
||||
|
||||
class RandomSelect:
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2
    """

    def __init__(self, transforms1=None, transforms2=None, p=0.5):
        # Falsy transforms default to the identity, matching `x or Identity()`.
        self.transforms1 = transforms1 if transforms1 else Identity()
        self.transforms2 = transforms2 if transforms2 else Identity()
        self.p = p

    def __call__(self, img, target):
        chosen = self.transforms1 if random.random() < self.p else self.transforms2
        return chosen(img, target)
|
||||
|
||||
|
||||
class ToTensor:
    """Convert the image to a tensor via torchvision; the target passes through."""

    def __call__(self, img, target):
        tensor_img = F.to_tensor(img)
        return tensor_img, target
|
||||
|
||||
|
||||
class RandomErasing:
    """Thin wrapper around torchvision's RandomErasing; only the image is
    modified, the target passes through."""

    def __init__(self, *args, **kwargs):
        self.eraser = T.RandomErasing(*args, **kwargs)

    def __call__(self, img, target):
        erased_img = self.eraser(img)
        return erased_img, target
|
||||
|
||||
|
||||
class Normalize:
    """Channel-normalize the image tensor and convert absolute xyxy boxes to
    width/height-relative cxcywh."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None
        target = target.copy()
        h, w = image.shape[-2:]
        denom = torch.tensor([w, h, w, h], dtype=torch.float32)
        for key in ("boxes", "input_boxes"):
            if key in target:
                target[key] = box_xyxy_to_cxcywh(target[key]) / denom
        return image, target
|
||||
|
||||
|
||||
class RemoveDifficult:
    """Optionally drop crowd/difficult annotations from the target.

    When enabled, every annotation whose ``iscrowd`` flag is set is removed
    from ``boxes``, ``labels`` and ``iscrowd`` (only when the target carries
    boxes, matching the original behavior). When disabled, the target is
    returned as a shallow copy without being touched.
    """

    def __init__(self, enabled=False):
        # Whether filtering is active at all.
        self.remove_difficult = enabled

    def __call__(self, image, target=None):
        if target is None:
            return image, None
        target = target.copy()
        # Fix: the original built an all-True keep mask when disabled, which
        # both did useless indexing work and raised KeyError if "iscrowd"
        # was absent even though nothing was going to be filtered.
        if not self.remove_difficult:
            return image, target
        keep = ~target["iscrowd"].to(torch.bool)
        if "boxes" in target:
            target["boxes"] = target["boxes"][keep]
            target["labels"] = target["labels"][keep]
            target["iscrowd"] = target["iscrowd"][keep]
        return image, target
|
||||
|
||||
|
||||
class Compose:
    """Chain transforms: each is called as (image, target) -> (image, target)."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target

    def __repr__(self):
        # One transform per indented line, matching torchvision's style.
        body = "".join(f"\n    {t}" for t in self.transforms)
        return f"{self.__class__.__name__}({body}\n)"
|
||||
|
||||
|
||||
def get_random_resize_scales(size, min_size, rounded):
    """Return candidate resize scales between min_size and size.

    The stride is 128 when ``rounded`` else 32; ``min_size`` is rounded up to
    the next stride multiple and scales are emitted every stride up to and
    including ``size``.
    """
    step = 128 if rounded else 32
    lowest = int(step * math.ceil(min_size / step))
    return list(range(lowest, size + 1, step))
|
||||
|
||||
|
||||
def get_random_resize_max_size(size, ratio=5 / 3):
    """Return the maximum allowed resize edge: ``size`` scaled by ``ratio``, rounded to the nearest int."""
    return round(ratio * size)
|
||||
1396
sam3/train/transforms/basic_for_api.py
Normal file
1396
sam3/train/transforms/basic_for_api.py
Normal file
File diff suppressed because it is too large
Load Diff
607
sam3/train/transforms/filter_query_transforms.py
Normal file
607
sam3/train/transforms/filter_query_transforms.py
Normal file
@@ -0,0 +1,607 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.train.data.sam3_image_dataset import Datapoint, FindQuery, Object
|
||||
|
||||
|
||||
class FilterDataPointQueries:
    """Base class for transforms that mark find queries / objects to drop from a Datapoint.

    Subclasses implement `identify_queries_to_filter`, which fills in the id
    sets below; the actual removal and remapping is performed elsewhere (see
    FlexibleFilterFindGetQueries).
    """

    # Indices of find queries to remove; populated by identify_queries_to_filter.
    find_ids_to_filter: set = None
    # NOTE(review): no subclass visible in this file populates this —
    # presumably reserved for future get-query filtering; confirm before use.
    get_ids_to_filter: set = None
    obj_ids_to_filter: set = None  # stored as pairs (img_id, obj_id)

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        """
        Compute the sets of find-query ids and (img_id, obj_id) pairs to
        filter (remove) from the datapoint.
        """
        raise NotImplementedError

    def _do_filter_query(self, query: FindQuery, query_id: int) -> bool:
        # Membership test against the set computed by
        # identify_queries_to_filter — that method must run first.
        assert self.find_ids_to_filter is not None

        return query_id in self.find_ids_to_filter
|
||||
|
||||
|
||||
class FilterQueryWithText(FilterDataPointQueries):
    """
    Filter all datapoints which have query text in a specified list of excluded terms
    """

    def __init__(
        self, exclude_find_keys: List[str] = None, exclude_get_keys: List[str] = None
    ):
        # Normalize None (and empty) to [] to avoid the mutable-default pitfall.
        self.find_filter_keys = exclude_find_keys if exclude_find_keys else []
        # NOTE(review): get_filter_keys is stored but never consulted below —
        # get-query filtering appears unimplemented; kept for interface
        # compatibility.
        self.get_filter_keys = exclude_get_keys if exclude_get_keys else []

    def identify_queries_to_filter(self, datapoint):
        """Mark find queries whose text matches an excluded term; no objects are marked."""
        self.obj_ids_to_filter = set()
        # Fix: the original also built an unused `del_get_ids` list.
        self.find_ids_to_filter = {
            i
            for i, f_q in enumerate(datapoint.find_queries)
            if f_q.query_text in self.find_filter_keys
        }
|
||||
|
||||
|
||||
class KeepMaxNumFindQueries(FilterDataPointQueries):
    """Cap the number of find queries on a datapoint at `max_num_find_queries`.

    When over the cap, surplus queries are chosen at random for removal.
    With `retain_positive_queries`, positive queries (those with at least one
    output object) are kept preferentially and negatives only fill leftover
    slots.
    """

    def __init__(
        self, max_num_find_queries: int, retain_positive_queries: bool = False
    ):
        self.max_num_find_queries = max_num_find_queries
        self.retain_positive_queries = retain_positive_queries

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        self.obj_ids_to_filter = set()
        num_find_queries = len(datapoint.find_queries)
        if num_find_queries <= self.max_num_find_queries:
            self.find_ids_to_filter = set()  # keep all find queries
            return

        if not self.retain_positive_queries:
            # Uniformly subsample the surplus across all queries.
            all_find_query_ids = list(range(num_find_queries))
            num_queries_to_filter = max(0, num_find_queries - self.max_num_find_queries)
            query_ids_to_filter = random.sample(
                all_find_query_ids, k=num_queries_to_filter
            )
        else:
            # keep up to max_num_find_queries positive find queries and fill
            # the remaining slots (if any) with negative find queries
            pos_find_ids, neg_find_ids = [], []
            for i, f_q in enumerate(datapoint.find_queries):
                # Negative finds return an empty list of object_ids_output
                if len(f_q.object_ids_output) == 0:
                    neg_find_ids.append(i)
                else:
                    pos_find_ids.append(i)

            if len(pos_find_ids) >= self.max_num_find_queries:
                # we have more positive find queries than `max_num_find_queries`,
                # so we subsample positive find queries and remove all negative find queries
                num_queries_to_filter = len(pos_find_ids) - self.max_num_find_queries
                query_ids_to_filter = random.sample(
                    pos_find_ids, k=num_queries_to_filter
                )
                query_ids_to_filter.extend(neg_find_ids)
            else:
                # we have fewer positive find queries than `max_num_find_queries`
                # so we need to fill the remaining with negative find queries
                num_queries_to_filter = num_find_queries - self.max_num_find_queries
                query_ids_to_filter = random.sample(
                    neg_find_ids, k=num_queries_to_filter
                )

        # Exactly the surplus must have been selected in every branch.
        assert len(query_ids_to_filter) == num_find_queries - self.max_num_find_queries
        self.find_ids_to_filter = set(query_ids_to_filter)
|
||||
|
||||
|
||||
class KeepMaxNumFindQueriesVideo(FilterDataPointQueries):
    """Per-frame cap on find queries for mosaic video datapoints.

    Queries are grouped by `image_id` (frame). If any frame exceeds the
    per-frame cap, surplus query *positions* are sampled on frame 0 and the
    same positions are dropped from every frame, keeping queries aligned
    across frames.

    NOTE(review): the final assert assumes every frame carries the same
    number of queries, ordered like frame 0 — confirm with the mosaic
    builder.
    """

    def __init__(
        self,
        video_mosaic_max_num_find_queries_per_frame: int,
        retain_positive_queries: bool = False,
    ):
        self.video_mosaic_max_num_find_queries_per_frame = (
            video_mosaic_max_num_find_queries_per_frame
        )
        self.retain_positive_queries = retain_positive_queries

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        self.obj_ids_to_filter = set()
        num_find_queries = len(datapoint.find_queries)

        # Group query indices by frame. NOTE(review): despite its name,
        # `max_queries_per_frame` means "every frame is within budget".
        findQueries_to_imageIds = defaultdict(list)
        max_queries_per_frame = True
        for i, f_q in enumerate(datapoint.find_queries):
            findQueries_to_imageIds[f_q.image_id].append(i)
            if (
                len(findQueries_to_imageIds[f_q.image_id])
                > self.video_mosaic_max_num_find_queries_per_frame
            ):
                max_queries_per_frame = False

        if max_queries_per_frame:
            # All frames within budget: nothing to drop.
            self.find_ids_to_filter = set()
            return

        num_frames = len(findQueries_to_imageIds)
        # Sampling decisions are made on frame 0 and replicated to all frames.
        findQueries_0 = findQueries_to_imageIds[0]
        num_find_queries_0 = len(findQueries_0)
        max_num_find_queries_per_frame = (
            self.video_mosaic_max_num_find_queries_per_frame
        )
        if not self.retain_positive_queries:
            # Uniformly subsample surplus positions on frame 0.
            find_query_ids_0 = list(range(num_find_queries_0))
            num_queries_to_filter = max(
                0, num_find_queries_0 - max_num_find_queries_per_frame
            )
            query_ids_to_filter_0 = random.sample(
                find_query_ids_0, k=num_queries_to_filter
            )
        else:
            # keep up to max_num_find_queries positive find queries and fill
            # the remaining slots (if any) with negative find queries
            pos_find_ids_0, neg_find_ids_0 = [], []
            for i, f_q_id in enumerate(findQueries_0):
                f_q = datapoint.find_queries[f_q_id]
                # Negative finds return an empty list of object_ids_output
                if len(f_q.object_ids_output) == 0:
                    neg_find_ids_0.append(i)
                else:
                    pos_find_ids_0.append(i)

            if len(pos_find_ids_0) >= max_num_find_queries_per_frame:
                # we have more positive find queries than `max_num_find_queries`,
                # so we subsample positive find queries and remove all negative find queries
                num_queries_to_filter = (
                    len(pos_find_ids_0) - max_num_find_queries_per_frame
                )
                query_ids_to_filter_0 = random.sample(
                    pos_find_ids_0, k=num_queries_to_filter
                )
                query_ids_to_filter_0.extend(neg_find_ids_0)
            else:
                # we have fewer positive find queries than `max_num_find_queries`
                # so we need to fill the remaining with negative find queries
                num_queries_to_filter = (
                    num_find_queries_0 - max_num_find_queries_per_frame
                )
                query_ids_to_filter_0 = random.sample(
                    neg_find_ids_0, k=num_queries_to_filter
                )

        # get based on frame 0 all find queries from all the frames with the same indices as in frame 0
        query_ids_to_filter = []
        for i in range(num_frames):
            findQueries_i = findQueries_to_imageIds[i]
            query_ids_to_filter.extend(
                [findQueries_i[j] for j in query_ids_to_filter_0]
            )

        assert (
            len(query_ids_to_filter)
            == num_find_queries
            - self.video_mosaic_max_num_find_queries_per_frame * num_frames
        )
        self.find_ids_to_filter = set(query_ids_to_filter)
|
||||
|
||||
|
||||
class KeepSemanticFindQueriesOnly(FilterDataPointQueries):
    """Drop geometric find queries (those carrying an input box), keeping text-only ones."""

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        self.obj_ids_to_filter = set()
        geometric_ids = set()
        for idx, query in enumerate(datapoint.find_queries):
            # A non-None input_bbox marks the query as geometric.
            if query.input_bbox is not None:
                geometric_ids.add(idx)
        self.find_ids_to_filter = geometric_ids
|
||||
|
||||
# Keep all get queries which don't depend on filtered finds
|
||||
|
||||
|
||||
class KeepUnaryFindQueriesOnly(FilterDataPointQueries):
    """Keep every find query: nothing is marked for filtering."""

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        # Nothing to drop — both sets stay empty.
        self.obj_ids_to_filter = set()
        self.find_ids_to_filter = set()
|
||||
|
||||
# Keep all get queries which don't depend on filtered finds
|
||||
|
||||
|
||||
class FilterZeroBoxQueries(FilterDataPointQueries):
    """
    Filters all find queries which predict a box with zero area
    """

    @staticmethod
    def _is_zero_area_object(obj: Object):
        # An object is degenerate when its XYXY box has zero width or height.
        x0, y0, x1, y1 = (obj.bbox[..., k].item() for k in range(4))
        return (y1 - y0) == 0 or (x1 - x0) == 0

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()

        # Collect degenerate object ids (assumes a single image per datapoint).
        degenerate = {
            obj_id
            for obj_id, obj in enumerate(datapoint.images[0].objects)
            if self._is_zero_area_object(obj)
        }

        # Drop any find query whose output references a degenerate object.
        self.find_ids_to_filter = {
            i
            for i, f_q in enumerate(datapoint.find_queries)
            if degenerate & set(f_q.object_ids_output)
        }
|
||||
|
||||
|
||||
class FilterFindQueriesWithTooManyOut(FilterDataPointQueries):
    """
    Filters all find queries which have more than a specified number of objects in the output
    """

    def __init__(self, max_num_objects: int):
        self.max_num_objects = max_num_objects

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()
        # A query is dropped when its output object list exceeds the cap.
        self.find_ids_to_filter = {
            i
            for i, f_q in enumerate(datapoint.find_queries)
            if len(f_q.object_ids_output) > self.max_num_objects
        }
|
||||
|
||||
|
||||
class FilterEmptyTargets(FilterDataPointQueries):
    """
    Filters all targets which have zero area
    """

    def identify_queries_to_filter(self, datapoint):
        # Mark (img_id, obj_id) pairs whose area is effectively zero;
        # no find queries are dropped here.
        self.obj_ids_to_filter = {
            (img_id, obj_id)
            for img_id, image in enumerate(datapoint.images)
            for obj_id, obj in enumerate(image.objects)
            if obj.area < 1e-6
        }
        self.find_ids_to_filter = set()
|
||||
|
||||
|
||||
class FilterNonExhaustiveFindQueries(FilterDataPointQueries):
    """
    Filters all find queries which are non-exhaustive
    """

    def __init__(self, exhaustivity_type: str):
        """
        Args:
            exhaustivity_type: Can be "pixel" or "instance":
                -pixel: filter queries where the union of all segments covers every pixel belonging to target class
                -instance: filter queries where there are non-separable or non annotated instances
                Note that instance exhaustivity implies pixel exhaustivity
        """
        # Fix: validate with a real exception — `assert` is stripped under
        # `python -O`, which would defer the failure to filtering time.
        if exhaustivity_type not in ("pixel", "instance"):
            raise ValueError(f"Unknown exhaustivity type {exhaustivity_type}")
        self.exhaustivity_type = exhaustivity_type

    def identify_queries_to_filter(self, datapoint):
        """Mark non-exhaustive find queries for removal; no objects are marked."""
        self.obj_ids_to_filter = set()

        del_find_ids = []
        for i, f_q in enumerate(datapoint.find_queries):
            if self.exhaustivity_type == "instance":
                if not f_q.is_exhaustive:
                    del_find_ids.append(i)
            elif self.exhaustivity_type == "pixel":
                # A None pixel-exhaustivity flag counts as exhaustive (kept).
                if f_q.is_pixel_exhaustive is not None and not f_q.is_pixel_exhaustive:
                    del_find_ids.append(i)
            else:
                # Unreachable given constructor validation; kept as a guard.
                raise RuntimeError(
                    f"Unknown exhaustivity type {self.exhaustivity_type}"
                )

        self.find_ids_to_filter = set(del_find_ids)
|
||||
|
||||
|
||||
class FilterInvalidGeometricQueries(FilterDataPointQueries):
    """
    Filters geometric queries whose output got deleted (eg due to cropping)
    """

    def identify_queries_to_filter(self, datapoint):
        self.obj_ids_to_filter = set()
        invalid = set()
        for idx, query in enumerate(datapoint.find_queries):
            # Geometric queries carry an input box and the "geometric" tag;
            # with no surviving output object, the query is unusable.
            is_geometric = (
                query.input_bbox is not None and query.query_text == "geometric"
            )
            if is_geometric and len(query.object_ids_output) == 0:
                invalid.add(idx)
        self.find_ids_to_filter = invalid
|
||||
|
||||
|
||||
class FlexibleFilterFindGetQueries:
    """Apply a FilterDataPointQueries strategy to a datapoint.

    Runs `query_filter.identify_queries_to_filter`, removes the marked find
    queries, re-compacts query_processing_order, drops objects no longer
    referenced (or explicitly marked), remaps object ids inside the surviving
    queries, and finally removes images no longer referenced by any query.
    """

    def __init__(
        self, query_filter: FilterDataPointQueries, enabled: bool = True
    ) -> None:
        self.query_filter = query_filter
        self.enabled = enabled

    def __call__(self, datapoint, **kwargs):
        if not self.enabled:
            return datapoint

        # Identify all queries to filter
        self.query_filter.identify_queries_to_filter(datapoint=datapoint)

        # NOTE(review): del_get_ids is never used — get-query filtering
        # appears unimplemented throughout this class.
        del_find_ids = []
        del_get_ids = []
        for i, f_q in enumerate(datapoint.find_queries):
            if self.query_filter._do_filter_query(f_q, i):
                # Marked queries are nulled first, then compacted below.
                datapoint.find_queries[i] = None
                del_find_ids.append(i)

        # NOTE(review): the get_* counterparts below are dead locals.
        new_find_queries = []
        new_get_queries = []

        find_old_to_new_map = {}
        get_old_to_new_map = {}

        find_counter = 0
        get_counter = 0

        # Compact surviving find queries, recording old->new index mapping.
        for i, f_q in enumerate(datapoint.find_queries):
            if f_q is not None:
                find_old_to_new_map[i] = find_counter
                find_counter += 1
                new_find_queries.append(f_q)

        # At least one surviving query must be a first-stage query
        # (an empty result set is vacuously valid here and rejected later).
        start_with_zero_check = False
        for n_f_q in new_find_queries:
            if n_f_q.query_processing_order == 0:
                start_with_zero_check = True
                break

        if len(new_find_queries) == 0:
            start_with_zero_check = True

        assert (
            start_with_zero_check
        ), "Invalid Find queries, they need to start at query_processing_order = 0"

        datapoint.find_queries = new_find_queries

        if len(datapoint.find_queries) == 0:
            print("Warning: No find queries left in datapoint, this is not allowed")
            print("Filtering function:", self.query_filter)
            print("Datapoint:", datapoint)
            raise ValueError

        # The deletion may have removed intermediate steps, so we need to remap to make them contiguous again
        all_stages = sorted(
            list(set(q.query_processing_order for q in datapoint.find_queries))
        )
        stage_map = {qpo: i for i, qpo in enumerate(all_stages)}
        for i in range(len(datapoint.find_queries)):
            qpo = datapoint.find_queries[i].query_processing_order
            datapoint.find_queries[i].query_processing_order = stage_map[qpo]

        # Final step, clear up objects that are not used anymore
        for img_id in range(len(datapoint.images)):
            # Object ids still referenced by some surviving query on this image.
            all_objects_ids = set(
                i
                for find in datapoint.find_queries
                for i in find.object_ids_output
                if find.image_id == img_id
            )
            unused_ids = (
                set(range(len(datapoint.images[img_id].objects))) - all_objects_ids
            )
            # Also drop objects the filter explicitly marked for removal.
            for tgt_img_id, tgt_obj_id in self.query_filter.obj_ids_to_filter:
                if tgt_img_id == img_id:
                    unused_ids.add(tgt_obj_id)

            if len(unused_ids) > 0:
                old_objects = datapoint.images[img_id].objects
                object_old_to_new_map = {}
                new_objects = []
                for i, o in enumerate(old_objects):
                    if i not in unused_ids:
                        object_old_to_new_map[i] = len(new_objects)
                        new_objects.append(o)

                datapoint.images[img_id].objects = new_objects

                # Remap the outputs of the find queries
                # NOTE(review): affected_find_queries_ids and
                # object_old_to_new_map_per_query are built but never read.
                affected_find_queries_ids = set()
                object_old_to_new_map_per_query = {}
                for fid, find in enumerate(datapoint.find_queries):
                    if find.image_id == img_id:
                        old_object_ids_output = find.object_ids_output
                        object_old_to_new_map_per_query[fid] = {}
                        find.object_ids_output = []
                        for oid, old_obj_id in enumerate(old_object_ids_output):
                            if old_obj_id not in unused_ids:
                                new_obj_id = object_old_to_new_map[old_obj_id]
                                find.object_ids_output.append(new_obj_id)
                                object_old_to_new_map_per_query[fid][oid] = (
                                    len(find.object_ids_output) - 1
                                )
                                affected_find_queries_ids.add(fid)

        # finally remove unused images
        all_imgs_to_keep = set()
        for f_q in datapoint.find_queries:
            all_imgs_to_keep.add(f_q.image_id)

        old_img_id_to_new_img_id = {}
        new_images = []
        for img_id, img in enumerate(datapoint.images):
            if img_id in all_imgs_to_keep:
                old_img_id_to_new_img_id[img_id] = len(new_images)
                new_images.append(img)
        datapoint.images = new_images

        # Point surviving queries at the compacted image indices.
        for f_q in datapoint.find_queries:
            f_q.image_id = old_img_id_to_new_img_id[f_q.image_id]

        return datapoint
|
||||
|
||||
|
||||
class AddPrefixSuffixToFindText:
    """
    Add prefix or suffix strings to find query text on the fly.

    If `condition_on_text` is True, the prefix or suffix strings are only added
    to those find query text in `condition_text_list` (case-insensitive).
    """

    def __init__(
        self,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
        condition_on_text: bool = False,
        condition_text_list: Optional[List[str]] = None,
        enabled: bool = True,
    ) -> None:
        self.prefix = prefix
        self.suffix = suffix
        self.condition_on_text = condition_on_text
        if self.condition_on_text:
            assert condition_text_list is not None
            # Normalized (lowercase, stripped) for case-insensitive matching.
            self.condition_text_set = {s.lower().strip() for s in condition_text_list}
        self.enabled = enabled
        if self.enabled:
            logging.info(
                f"AddPrefixSuffixToFindText: prefix={prefix}, suffix={suffix}, "
                f"condition_on_text={condition_on_text}, condition_text_list={condition_text_list}"
            )

    def __call__(self, datapoint, **kwargs):
        if not self.enabled:
            return datapoint

        for query in datapoint.find_queries:
            # Geometric queries carry no natural-language text to decorate.
            if query.query_text == "geometric":
                continue
            if self.condition_on_text:
                normalized = query.query_text.lower().strip()
                if normalized not in self.condition_text_set:
                    continue

            new_text = query.query_text
            if self.prefix is not None:
                new_text = self.prefix + new_text
            if self.suffix is not None:
                new_text = new_text + self.suffix
            query.query_text = new_text

        return datapoint
|
||||
|
||||
|
||||
class FilterCrowds(FilterDataPointQueries):
    """Mark every crowd-labelled object for removal; all find queries are kept."""

    def identify_queries_to_filter(self, datapoint: Datapoint) -> None:
        """
        Compute set of query ids to keep, for both find and get queries
        """
        self.find_ids_to_filter = set()
        self.obj_ids_to_filter = {
            (img_id, obj_id)
            for img_id, img in enumerate(datapoint.images)
            for obj_id, obj in enumerate(img.objects)
            if obj.is_crowd
        }
|
||||
|
||||
|
||||
class TextQueryToVisual:
    """
    Transform a test query to a visual query (with some proba), using any of the output targets as the prompt
    """

    def __init__(self, probability, keep_text_queries=False) -> None:
        assert 0 <= probability <= 1
        self.probability = probability
        self.keep_text_queries = keep_text_queries

    def __call__(self, datapoint: Datapoint, **kwargs):
        for query in datapoint.find_queries:
            # Skip queries that already carry a geometric prompt.
            if query.input_bbox is not None or query.input_points is not None:
                continue
            # A visual prompt needs at least one target object.
            if len(query.object_ids_output) == 0:
                continue
            # Second-stage queries cannot be converted.
            if query.query_processing_order > 0:
                continue
            if random.random() > self.probability:
                continue

            # Use the box of a random target object as the visual prompt.
            prompt_obj_id = random.choice(query.object_ids_output)
            prompt_obj = datapoint.images[query.image_id].objects[prompt_obj_id]
            query.input_bbox = prompt_obj.bbox
            query.input_bbox_label = torch.ones(1, dtype=torch.bool)
            if not self.keep_text_queries:
                query.query_text = "visual"

        return datapoint
|
||||
|
||||
|
||||
class RemoveInputBoxes:
    """
    Remove input boxes from find queries
    """

    def __init__(self) -> None:
        pass

    def __call__(self, datapoint: Datapoint, **kwargs):
        for query in datapoint.find_queries:
            if query.input_bbox is None:
                continue
            # Stripping the box from a geometric query leaves it promptless.
            if query.query_text == "geometric":
                print("Warning: removing input box from geometric find query")
            query.input_bbox = None
        return datapoint
|
||||
|
||||
|
||||
class OverwriteTextQuery:
    """
    With some probability, overwrite the text query with a custom text
    """

    def __init__(self, target_text, probability=1.0) -> None:
        assert 0 <= probability <= 1
        self.target_text = target_text
        self.probability = probability

    def __call__(self, datapoint: Datapoint, **kwargs):
        # Each query is rewritten independently with probability `probability`.
        for query in datapoint.find_queries:
            if random.random() <= self.probability:
                query.query_text = self.target_text
        return datapoint
|
||||
345
sam3/train/transforms/point_sampling.py
Normal file
345
sam3/train/transforms/point_sampling.py
Normal file
@@ -0,0 +1,345 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image as PILImage
|
||||
from pycocotools import mask as mask_util
|
||||
|
||||
from sam3.train.data.sam3_image_dataset import Datapoint
|
||||
from torchvision.ops import masks_to_boxes
|
||||
|
||||
|
||||
def sample_points_from_rle(rle, n_points, mode, box=None, normalize=True):
    """
    Sample random points from a mask provided in COCO RLE format.
    'mode' is in ["centered", "random_mask", "random_box"]:
        "centered": points are sampled farthest from the mask edges and each other
        "random_mask": points are sampled uniformly from the mask
        "random_box": points are sampled uniformly from the annotation's box
    'box' must be provided if 'mode' is "random_box".
    If 'normalize' is true, points are in [0,1], relative to mask h,w.
    """
    decoded = np.ascontiguousarray(mask_util.decode(rle))
    points = sample_points_from_mask(decoded, n_points, mode, box)

    if normalize:
        height, width = decoded.shape
        # Labels (third column) are left untouched by the division.
        points = points / np.array([width, height, 1.0])[None, :]

    return points
|
||||
|
||||
|
||||
def sample_points_from_mask(mask, n_points, mode, box=None):
    """Dispatch point sampling to the strategy named by ``mode``.

    Modes: "centered", "random_mask", "random_box" — the latter requires
    ``box``. Raises ValueError for an unknown mode.
    """
    if mode == "centered":
        return center_positive_sample(mask, n_points)
    if mode == "random_mask":
        return uniform_positive_sample(mask, n_points)
    if mode == "random_box":
        assert box is not None, "'random_box' mode requires a provided box."
        return uniform_sample_from_box(mask, box, n_points)
    raise ValueError(f"Unknown point sampling mode {mode}.")
|
||||
|
||||
|
||||
def uniform_positive_sample(mask, n_points):
    """
    Samples positive points uniformly from the mask. Only integer pixel
    values are sampled. Returns an (n_points, 3) array of (x, y, label)
    rows with label fixed to 1.
    """
    ys, xs = np.nonzero(mask)
    assert len(ys) > 0, "Can't sample positive points from an empty mask."
    # Sample pixel indices with replacement.
    picks = np.random.randint(low=0, high=len(ys), size=n_points)
    return np.stack([xs[picks], ys[picks], np.ones(n_points)], axis=1)
|
||||
|
||||
|
||||
def center_positive_sample(mask, n_points):
    """
    Samples points farthest from mask edges (by distance transform)
    and subsequent points also farthest from each other. Each new point
    sampled is treated as an edge for future points. Edges of the image are
    treated as edges of the mask. Returns (n_points, 3) rows of (x, y, label)
    with label fixed to 1.
    """

    # Pad mask by one pixel on each end to assure distance transform
    # avoids edges
    padded_mask = np.pad(mask, 1)

    points = []
    for _ in range(n_points):
        # Fix: check the *padded* mask, which is depleted as points are drawn.
        # The original checked the untouched input mask, so exhausting the
        # foreground pixels silently produced the (-1, -1) corner instead of
        # failing.
        assert np.max(padded_mask) > 0, "Can't sample positive points from an empty mask."
        dist = cv2.distanceTransform(padded_mask, cv2.DIST_L2, 0)
        point = np.unravel_index(dist.argmax(), dist.shape)
        # Mark selected point as background so next point avoids it
        padded_mask[point[0], point[1]] = 0
        points.append(point[::-1])  # (y, x) -> (x, y)

    points = np.stack(points, axis=0)
    points = points - 1  # Subtract left/top padding of 1
    labels = np.ones((len(points), 1))
    points = np.concatenate([points, labels], axis=1)

    return points
|
||||
|
||||
|
||||
def uniform_sample_from_box(mask, box, n_points):
    """
    Sample points uniformly from the provided box. The points' labels
    are determined by the provided mask. Does not guarantee a positive
    point is sampled. The box is assumed unnormalized in XYXY format.
    Points are sampled at integer values.
    """
    # Since lower/right edges are exclusive, ceil can be applied to all edges
    x0, y0, x1, y1 = np.ceil(box)
    xs = np.random.randint(low=x0, high=x1, size=n_points)
    ys = np.random.randint(low=y0, high=y1, size=n_points)
    # Label each point by whether it falls inside the mask.
    return np.stack([xs, ys, mask[ys, xs]], axis=1)
|
||||
|
||||
|
||||
def rescale_box_xyxy(box, factor, imsize=None):
    """
    Rescale a box provided in unnormalized XYXY format, fixing the center.
    If imsize (h, w) is provided, clamp the result to the image bounds.
    """
    x0, y0, x1, y1 = box
    center_x = (x0 + x1) / 2
    center_y = (y0 + y1) / 2
    half_w = factor * (x1 - x0) / 2
    half_h = factor * (y1 - y0) / 2

    out = [center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h]

    if imsize is not None:
        # imsize is (h, w): clamp x into [0, w] and y into [0, h].
        max_y, max_x = imsize[0], imsize[1]
        out = [
            min(max(out[0], 0), max_x),
            min(max(out[1], 0), max_y),
            min(max(out[2], 0), max_x),
            min(max(out[3], 0), max_y),
        ]

    return out
|
||||
|
||||
|
||||
def noise_box(box, im_size, box_noise_std, box_noise_max, min_box_area):
    """
    Jitter an XYXY box with Gaussian noise scaled by the box side lengths.

    The noised box is clamped to the image; the original box is returned
    unchanged when std is non-positive or when the jittered area collapses
    to `min_box_area` or below.
    """
    if box_noise_std <= 0.0:
        return box

    width, height = box[2] - box[0], box[3] - box[1]
    jitter = box_noise_std * torch.randn(size=(4,)) * torch.tensor(
        [width, height, width, height]
    )
    if box_noise_max is not None:
        jitter = torch.clamp(jitter, -box_noise_max, box_noise_max)

    noised = box + jitter
    # Clamp into the image: im_size is (h, w).
    upper = torch.tensor([im_size[1], im_size[0], im_size[1], im_size[0]])
    noised = torch.maximum(noised, torch.zeros_like(noised))
    noised = torch.minimum(noised, upper)

    if (noised[2] - noised[0]) * (noised[3] - noised[1]) <= min_box_area:
        return box

    return noised
|
||||
|
||||
|
||||
class RandomGeometricInputsAPI:
|
||||
"""
|
||||
For geometric queries, replaces the input box or points with a random
|
||||
one sampled from the GT mask. Segments must be provided for objects
|
||||
that are targets of geometric queries, and must be binary masks. Existing
|
||||
point and box queries in the datapoint will be ignored and completely replaced.
|
||||
Will sample points and boxes in XYXY format in absolute pixel space.
|
||||
|
||||
Geometry queries are currently determined by taking any query whose
|
||||
query text is a set value.
|
||||
|
||||
Args:
|
||||
num_points (int or (int, int)): how many points to sample. If a tuple,
|
||||
sample a random number of points uniformly over the inclusive range.
|
||||
box_chance (float): fraction of time a box is sampled. A box will replace
|
||||
one sampled point.
|
||||
box_noise_std (float): if greater than 0, add noise to the sampled boxes
|
||||
with this std. Noise is relative to the length of the box side.
|
||||
box_noise_max (int): if not none, truncate any box noise larger than this
|
||||
in terms of absolute pixels.
|
||||
resample_box_from_mask (bool): if True, any sampled box will be determined
|
||||
by finding the extrema of the provided mask. If False, the bbox provided
|
||||
in the target object will be used.
|
||||
point_sample_mode (str): In ["centered", "random_mask", "random_box"],
|
||||
controlling how points are sampled:
|
||||
"centered": points are sampled farthest from the mask edges and each other
|
||||
"random_mask": points are sampled uniformly from the mask
|
||||
"random_box": points are sampled uniformly from the annotation's box
|
||||
Note that "centered" may be too slow for on-line generation.
|
||||
geometric_query_str (str): what string in query_text indicates a
|
||||
geometry query.
|
||||
minimum_box_area (float): sampled boxes with area this size or smaller after
|
||||
noising will use the original box instead. It is the input's responsibility
|
||||
to avoid original boxes that violate necessary area bounds.
|
||||
concat_points (bool): if True, any sampled points will be added to existing
|
||||
ones instead of replacing them.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_points,
|
||||
box_chance,
|
||||
box_noise_std=0.0,
|
||||
box_noise_max=None,
|
||||
minimum_box_area=0.0,
|
||||
resample_box_from_mask=False,
|
||||
point_sample_mode="random_mask",
|
||||
sample_box_scale_factor=1.0,
|
||||
geometric_query_str="geometric",
|
||||
concat_points=False,
|
||||
):
|
||||
self.num_points = num_points
|
||||
if not isinstance(self.num_points, int):
|
||||
# Convert from inclusive range to exclusive range expected by torch
|
||||
self.num_points[1] += 1
|
||||
self.num_points = tuple(self.num_points)
|
||||
self.box_chance = box_chance
|
||||
self.box_noise_std = box_noise_std
|
||||
self.box_noise_max = box_noise_max
|
||||
self.minimum_box_area = minimum_box_area
|
||||
self.resample_box_from_mask = resample_box_from_mask
|
||||
self.point_sample_mode = point_sample_mode
|
||||
assert point_sample_mode in [
|
||||
"centered",
|
||||
"random_mask",
|
||||
"random_box",
|
||||
], "Unknown point sample mode."
|
||||
self.geometric_query_str = geometric_query_str
|
||||
self.concat_points = concat_points
|
||||
self.sample_box_scale_factor = sample_box_scale_factor
|
||||
|
||||
def _sample_num_points_and_if_box(self):
|
||||
if isinstance(self.num_points, tuple):
|
||||
n_points = torch.randint(
|
||||
low=self.num_points[0], high=self.num_points[1], size=(1,)
|
||||
).item()
|
||||
else:
|
||||
n_points = self.num_points
|
||||
if self.box_chance > 0.0:
|
||||
use_box = torch.rand(size=(1,)).item() < self.box_chance
|
||||
n_points -= int(use_box) # box stands in for one point
|
||||
else:
|
||||
use_box = False
|
||||
return n_points, use_box
|
||||
|
||||
def _get_original_box(self, target_object):
|
||||
if not self.resample_box_from_mask:
|
||||
return target_object.bbox
|
||||
mask = target_object.segment
|
||||
return masks_to_boxes(mask[None, :, :])[0]
|
||||
|
||||
def _get_target_object(self, datapoint, query):
|
||||
img = datapoint.images[query.image_id]
|
||||
targets = query.object_ids_output
|
||||
assert (
|
||||
len(targets) == 1
|
||||
), "Geometric queries only support a single target object."
|
||||
target_idx = targets[0]
|
||||
return img.objects[target_idx]
|
||||
|
||||
    def __call__(self, datapoint, **kwargs):
        """Sample point/box prompts in place for every geometric query.

        For each find-query whose `query_text` equals `geometric_query_str`,
        samples `n_points` points from the target's mask (optionally within a
        rescaled box) and, with probability `box_chance`, a (possibly noised)
        box. Results are written to `query.input_points` / `query.input_bbox`;
        unless `concat_points` is set, any pre-existing prompts are replaced.

        Returns:
            The same `datapoint`, mutated in place.
        """
        for query in datapoint.find_queries:
            # Only queries flagged as geometric (by their text) are touched.
            if query.query_text != self.geometric_query_str:
                continue

            target_object = self._get_target_object(datapoint, query)
            n_points, use_box = self._sample_num_points_and_if_box()
            box = self._get_original_box(target_object)

            mask = target_object.segment
            if n_points > 0:
                # FIXME: The conversion to numpy and back to reuse code
                # is awkward, but this is all in the dataloader worker anyway
                # on CPU and so I don't think it should matter.
                if self.sample_box_scale_factor != 1.0:
                    # Grow/shrink the sampling region around the annotation box.
                    sample_box = rescale_box_xyxy(
                        box.numpy(), self.sample_box_scale_factor, mask.shape
                    )
                else:
                    sample_box = box.numpy()
                input_points = sample_points_from_mask(
                    mask.numpy(),
                    n_points,
                    self.point_sample_mode,
                    sample_box,
                )
                input_points = torch.as_tensor(input_points)
                # Add the leading query dimension: [1, n_points, 2].
                input_points = input_points[None, :, :]
                if self.concat_points and query.input_points is not None:
                    input_points = torch.cat([query.input_points, input_points], dim=1)
            else:
                # No points sampled: keep existing ones only in concat mode.
                input_points = query.input_points if self.concat_points else None

            if use_box:
                w, h = datapoint.images[query.image_id].size
                input_box = noise_box(
                    box,
                    (h, w),
                    box_noise_std=self.box_noise_std,
                    box_noise_max=self.box_noise_max,
                    min_box_area=self.minimum_box_area,
                )
                # Leading query dimension: [1, 4].
                input_box = input_box[None, :]
            else:
                input_box = query.input_bbox if self.concat_points else None

            query.input_points = input_points
            query.input_bbox = input_box

        return datapoint
|
||||
|
||||
|
||||
class RandomizeInputBbox:
    """
    Simplified version of the geometric transform that only deals with input boxes
    """

    def __init__(
        self,
        box_noise_std=0.0,
        box_noise_max=None,
        minimum_box_area=0.0,
    ):
        # Forwarded verbatim to `noise_box`: std is relative to the box side,
        # max is an absolute pixel cap on the noise.
        self.box_noise_std = box_noise_std
        self.box_noise_max = box_noise_max
        self.minimum_box_area = minimum_box_area

    def __call__(self, datapoint: Datapoint, **kwargs):
        """Add noise in place to every query's input boxes (if any)."""
        for query in datapoint.find_queries:
            boxes = query.input_bbox
            if boxes is None:
                continue

            image_data = datapoint.images[query.image_id].data
            if isinstance(image_data, PILImage.Image):
                width, height = image_data.size
            else:
                assert isinstance(image_data, torch.Tensor)
                height, width = image_data.shape[-2:]

            for idx in range(boxes.shape[0]):
                boxes[idx, :] = noise_box(
                    boxes[idx, :].view(4),
                    (height, width),
                    box_noise_std=self.box_noise_std,
                    box_noise_max=self.box_noise_max,
                    min_box_area=self.minimum_box_area,
                ).view(1, 4)

        return datapoint
|
||||
157
sam3/train/transforms/segmentation.py
Normal file
157
sam3/train/transforms/segmentation.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import numpy as np
|
||||
import pycocotools.mask as mask_utils
|
||||
import torch
|
||||
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from sam3.model.box_ops import masks_to_boxes
|
||||
|
||||
from sam3.train.data.sam3_image_dataset import Datapoint
|
||||
|
||||
|
||||
class InstanceToSemantic(object):
    """Convert instance segmentation to semantic segmentation.

    For each find-query, merges the binary segments of all its output objects
    into a single per-query `semantic_target`: a COCO RLE dict when `use_rle`
    is True (requires segments to still be RLE, i.e. run before DecodeRle),
    otherwise a uint8 HxW tensor of 0/1 values (requires decoded segments).
    """

    def __init__(self, delete_instance=True, use_rle=False):
        # delete_instance: drop per-object instance segments after merging
        # (saves memory downstream).
        self.delete_instance = delete_instance
        # use_rle: merge in RLE space via pycocotools instead of OR-ing tensors.
        self.use_rle = use_rle

    def __call__(self, datapoint: Datapoint, **kwargs):
        for fquery in datapoint.find_queries:
            # NOTE(review): unpacked as `h, w` here, but the geometric
            # transforms unpack the same `.size` attribute as `w, h`
            # (PIL convention) — one of the two orderings is likely swapped;
            # confirm the image container's `size` convention.
            h, w = datapoint.images[fquery.image_id].size

            if self.use_rle:
                all_segs = [
                    datapoint.images[fquery.image_id].objects[obj_id].segment
                    for obj_id in fquery.object_ids_output
                ]
                if len(all_segs) > 0:
                    # we need to double check that all rles are the correct size
                    # Otherwise cocotools will fail silently to an empty [0,0] mask
                    for seg in all_segs:
                        assert seg["size"] == all_segs[0]["size"], (
                            "Instance segments have inconsistent sizes. "
                            f"Found sizes {seg['size']} and {all_segs[0]['size']}"
                        )
                    fquery.semantic_target = mask_utils.merge(all_segs)
                else:
                    # There is no good way to create an empty RLE of the correct size
                    # We resort to converting an empty box to RLE
                    fquery.semantic_target = mask_utils.frPyObjects(
                        np.array([[0, 0, 0, 0]], dtype=np.float64), h, w
                    )[0]

            else:
                # `semantic_target` is uint8 and remains uint8 throughout the transforms
                # (it contains binary 0 and 1 values just like `segment` for each object)
                fquery.semantic_target = torch.zeros((h, w), dtype=torch.uint8)
                for obj_id in fquery.object_ids_output:
                    segment = datapoint.images[fquery.image_id].objects[obj_id].segment
                    if segment is not None:
                        assert (
                            isinstance(segment, torch.Tensor)
                            and segment.dtype == torch.uint8
                        )
                        # Bitwise OR accumulates the union of binary masks.
                        fquery.semantic_target |= segment

        if self.delete_instance:
            # Instance segments are no longer needed once merged.
            for img in datapoint.images:
                for obj in img.objects:
                    del obj.segment
                    obj.segment = None

        return datapoint
|
||||
|
||||
|
||||
class RecomputeBoxesFromMasks:
    """Recompute bounding boxes from masks.

    Overwrites each object's `bbox` and `area` with values derived from its
    decoded segment; must run after `DecodeRle` so segments are tensors.
    """

    def __call__(self, datapoint: Datapoint, **kwargs):
        for img in datapoint.images:
            for obj in img.objects:
                # Note: if the mask is empty, the bounding box will be undefined
                # The empty targets should be subsequently filtered
                # NOTE(review): masks_to_boxes is called on the raw segment
                # here, while other call sites wrap it as a [N, H, W] batch
                # (`segment[None]`) and index `[0]` — confirm the expected
                # input rank and the resulting bbox shape.
                obj.bbox = masks_to_boxes(obj.segment)
                obj.area = obj.segment.sum().item()

        return datapoint
|
||||
|
||||
|
||||
class DecodeRle:
    """This transform decodes RLEs into binary segments.

    Implementing it as a transform allows lazy loading. Some transforms
    (eg query filters) may be deleting masks, so decoding them from the
    beginning is wasteful.

    This transform needs to be called before any kind of geometric
    manipulation.
    """

    def __call__(self, datapoint: Datapoint, **kwargs):
        imgId2size = {}  # image index -> (height, width)
        warning_shown = False
        for imgId, img in enumerate(datapoint.images):
            if isinstance(img.data, PILImage.Image):
                # PIL's `size` is (width, height).
                img_w, img_h = img.data.size
            elif isinstance(img.data, torch.Tensor):
                # Tensor images are (..., H, W): height comes first.
                # Fix: the original unpacked this as `img_w, img_h`, swapping
                # the two dimensions for any non-square tensor image.
                img_h, img_w = img.data.shape[-2:]
            else:
                raise RuntimeError(f"Unexpected image type {type(img.data)}")

            imgId2size[imgId] = (img_h, img_w)

            for obj in img.objects:
                if obj.segment is not None and not isinstance(
                    obj.segment, torch.Tensor
                ):
                    if mask_utils.area(obj.segment) == 0:
                        # Degenerate RLE: fall back to filling the object's
                        # bbox (clamped to at least one pixel per side).
                        print("Warning, empty mask found, approximating from box")
                        obj.segment = torch.zeros(img_h, img_w, dtype=torch.uint8)
                        x1, y1, x2, y2 = obj.bbox.int().tolist()
                        obj.segment[y1 : max(y2, y1 + 1), x1 : max(x1 + 1, x2)] = 1
                    else:
                        obj.segment = mask_utils.decode(obj.segment)
                        # segment is uint8 and remains uint8 throughout the transforms
                        obj.segment = torch.tensor(obj.segment).to(torch.uint8)

                    if list(obj.segment.shape) != [img_h, img_w]:
                        # Should not happen often, but adding for security
                        if not warning_shown:
                            print(
                                f"Warning expected instance segmentation size to be {[img_h, img_w]} but found {list(obj.segment.shape)}"
                            )
                            # Printing only once per datapoint to avoid spam
                            warning_shown = True

                        obj.segment = F.resize(
                            obj.segment[None], (img_h, img_w)
                        ).squeeze(0)

                    assert list(obj.segment.shape) == [img_h, img_w]

        warning_shown = False
        for query in datapoint.find_queries:
            if query.semantic_target is not None and not isinstance(
                query.semantic_target, torch.Tensor
            ):
                query.semantic_target = mask_utils.decode(query.semantic_target)
                # segment is uint8 and remains uint8 throughout the transforms
                query.semantic_target = torch.tensor(query.semantic_target).to(
                    torch.uint8
                )
                if tuple(query.semantic_target.shape) != imgId2size[query.image_id]:
                    if not warning_shown:
                        print(
                            f"Warning expected semantic segmentation size to be {imgId2size[query.image_id]} but found {tuple(query.semantic_target.shape)}"
                        )
                        # Printing only once per datapoint to avoid spam
                        warning_shown = True

                    query.semantic_target = F.resize(
                        query.semantic_target[None], imgId2size[query.image_id]
                    ).squeeze(0)

                assert tuple(query.semantic_target.shape) == imgId2size[query.image_id]

        return datapoint
|
||||
1
sam3/train/utils/__init__.py
Normal file
1
sam3/train/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
358
sam3/train/utils/checkpoint_utils.py
Normal file
358
sam3/train/utils/checkpoint_utils.py
Normal file
@@ -0,0 +1,358 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
|
||||
import contextlib
|
||||
import fnmatch
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from torch.jit._script import RecursiveScriptModule
|
||||
|
||||
|
||||
def unix_pattern_to_parameter_names(
    constraints: List[str], all_parameter_names: Sequence[str]
) -> Set[str]:
    """
    Go through the list of parameter names and select those that match
    any of the provided constraints.

    Args:
        constraints: unix-style glob patterns (fnmatch syntax)
        all_parameter_names: the candidate parameter names

    Returns:
        The set of names matching at least one pattern. Empty when no
        pattern is given (the original crashed with TypeError from
        `set.union(*[])` in that case).

    Raises:
        AssertionError: if any single pattern matches no name at all.
    """
    matched_names: Set[str] = set()
    for param_name in constraints:
        matching_parameters = set(fnmatch.filter(all_parameter_names, param_name))
        assert (
            len(matching_parameters) > 0
        ), f"param_names {param_name} don't match any param in the given names."
        matched_names |= matching_parameters
    return matched_names
|
||||
|
||||
|
||||
def filter_params_matching_unix_pattern(
    patterns: List[str], state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """
    Keep only the parameters of the state dictionary that match the provided
    unix patterns.

    Args:
        patterns: the list of unix patterns to keep
        state_dict: the dictionary to filter

    Returns:
        A new state dictionary containing only the matching entries
        (empty when no pattern is given)
    """
    if len(patterns) == 0:
        return {}

    all_keys = list(state_dict.keys())
    included_keys = unix_pattern_to_parameter_names(patterns, all_keys)
    return {k: state_dict[k] for k in included_keys}
|
||||
|
||||
|
||||
def exclude_params_matching_unix_pattern(
    patterns: List[str], state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """
    Drop from the state dictionary every parameter whose name matches one of
    the provided unix patterns.

    Args:
        patterns: the list of unix patterns to exclude
        state_dict: the dictionary to filter

    Returns:
        A new state dictionary without the matching entries (the input dict
        itself when no pattern is given).
    """
    if not patterns:
        return state_dict

    excluded = unix_pattern_to_parameter_names(patterns, list(state_dict.keys()))
    return {name: value for name, value in state_dict.items() if name not in excluded}
|
||||
|
||||
|
||||
def _get_state_dict_summary(state_dict: Dict[str, torch.Tensor]):
    """Summarize a state dict as an array of per-tensor sums, ordered by the
    lexicographic order of the parameter names (so two snapshots of the same
    model are directly comparable)."""
    names = list(state_dict.keys())
    sums = np.array([state_dict[name].sum().item() for name in names])
    return sums[np.argsort(names)]
|
||||
|
||||
|
||||
def assert_skipped_parameters_are_frozen(model: nn.Module, patterns: List[str]):
    """
    Safeguard for `skip_saving_parameters`: every parameter matching one of
    the provided patterns must be frozen, otherwise trainable weights would
    silently be dropped from saved checkpoints.

    Raises:
        ValueError: if any matched parameter still has requires_grad=True.
    """
    if not patterns:
        return

    skipped = filter_params_matching_unix_pattern(
        patterns=patterns, state_dict=model.state_dict()
    )
    offending = {
        name
        for name, param in model.named_parameters()
        if name in skipped and param.requires_grad
    }
    if offending:
        raise ValueError(
            f"Parameters excluded with `skip_saving_parameters` should be frozen: {offending}"
        )
|
||||
|
||||
|
||||
@contextlib.contextmanager
def with_check_parameter_frozen(
    model: nn.Module, patterns: List[str], disabled: bool = True
):
    """
    Context manager that inspects a model surrounding a piece of code
    and verifies if the model has been updated by this piece of code.

    The function will raise an exception if the model has been updated
    on at least one of the parameters that matches one of the patterns.

    Args:
        model: the model that might have been updated
        patterns: unix patterns selecting the parameters to observe
        disabled: if True (the default), the check is skipped entirely
    """
    # Nothing to observe, or check explicitly disabled: plain passthrough.
    if not patterns or disabled:
        yield
        return

    # Snapshot a cheap summary (per-tensor sums) of the observed parameters
    # before running the wrapped code.
    frozen_state_dict = filter_params_matching_unix_pattern(
        patterns=patterns, state_dict=model.state_dict()
    )
    summary_before = _get_state_dict_summary(frozen_state_dict)

    yield

    # Re-snapshot after the wrapped code ran and compare.
    frozen_state_dict = filter_params_matching_unix_pattern(
        patterns=patterns, state_dict=model.state_dict()
    )
    summary_after = _get_state_dict_summary(frozen_state_dict)

    # Tolerance absorbs benign numeric noise; any real update trips this.
    if not np.allclose(summary_before, summary_after, atol=1e-6):
        raise ValueError(
            f"""
            The `model_weight_initializer` has initialized parameters frozen with `skip_saving_parameters`.
            You can resolve this error by either initializing those parameters from within the model definition
            or using the flag `trainer.checkpoint.initialize_after_preemption` to True.
            """
        )
|
||||
|
||||
|
||||
class CkptExcludeKernel:
    """
    Removes the keys from the given model state_dict that match the key_pattern.

    Args:
        key_pattern: Patterns used to select the keys in the state_dict
            that are eligible for this kernel.
    """

    def __init__(self, key_pattern: List[str]):
        self.key_pattern = key_pattern

    def __call__(self, state_dict: Dict):
        """
        Args:
            state_dict: A dictionary representing the given checkpoint's state dict.

        Returns:
            The state dict without the entries matching the patterns (the
            input dict itself when there is no pattern).
        """
        if not self.key_pattern:
            return state_dict
        excluded = unix_pattern_to_parameter_names(
            self.key_pattern, state_dict.keys()
        )
        return {
            name: value for name, value in state_dict.items() if name not in excluded
        }
|
||||
|
||||
|
||||
def load_checkpoint(
    path_list: List[str],
    pick_recursive_keys: Optional[List[str]] = None,
    map_location: str = "cpu",
) -> Any:
    """
    Loads a checkpoint from the specified path.

    Args:
        path_list: A list of paths which contain the checkpoint. Each element
            is tried (in order) until a file that exists is found. That file is then
            used to read the checkpoint.
        pick_recursive_keys: Picks sub dicts from the loaded checkpoint if not None.
            For pick_recursive_keys = ["a", "b"], will return checkpoint_dict["a"]["b"]
        map_location (str): a function, torch.device, string or a dict specifying how to
            remap storage locations

    Returns: The loaded checkpoint object (or the selected sub-dictionary
        when `pick_recursive_keys` is provided).

    Raises:
        ValueError: if none of the paths in `path_list` exists.
    """
    path_exists = False
    for path in path_list:
        if g_pathmgr.exists(path):
            path_exists = True
            break

    if not path_exists:
        raise ValueError(f"No path exists in {path_list}")

    # `path` is the first existing file found by the loop above.
    with g_pathmgr.open(path, "rb") as f:
        checkpoint = torch.load(f, map_location=map_location)

    logging.info(f"Loaded checkpoint from {path}")
    if pick_recursive_keys is not None:
        # Descend into the requested nested sub-dictionaries, in order.
        for key in pick_recursive_keys:
            checkpoint = checkpoint[key]
    return checkpoint
|
||||
|
||||
|
||||
def get_state_dict(checkpoint, ckpt_state_dict_keys):
    """
    Extract the model state dict from a loaded checkpoint by descending the
    sequence of keys in `ckpt_state_dict_keys`.

    Args:
        checkpoint: the loaded checkpoint object (a mapping/sequence tree, or
            a torchscript module).
        ckpt_state_dict_keys: keys to follow, in order, to reach the state dict.

    Returns:
        The nested object reached after following all keys.

    Raises:
        KeyError: if a key is missing (or an index is out of range) at any
            level, with the path traversed so far in the message.
    """
    if isinstance(checkpoint, RecursiveScriptModule):
        # This is a torchscript JIT model
        return checkpoint.state_dict()
    pre_train_dict = checkpoint
    for i, key in enumerate(ckpt_state_dict_keys):
        if (isinstance(pre_train_dict, Mapping) and key not in pre_train_dict) or (
            isinstance(pre_train_dict, Sequence) and key >= len(pre_train_dict)
        ):
            # Human-readable path of the keys traversed so far.
            # Fix: the original called map(ckpt_state_dict_keys[:i], str) with
            # the arguments swapped, which raised TypeError instead of the
            # intended KeyError on this error path.
            key_str = (
                '["' + '"]["'.join(map(str, ckpt_state_dict_keys[:i])) + '"]'
            )
            raise KeyError(
                f"'{key}' not found in checkpoint{key_str} "
                f"with keys: {pre_train_dict.keys()}"
            )
        pre_train_dict = pre_train_dict[key]
    return pre_train_dict
|
||||
|
||||
|
||||
def load_checkpoint_and_apply_kernels(
    checkpoint_path: str,
    checkpoint_kernels: List[Callable] = None,
    ckpt_state_dict_keys: Tuple[str] = ("state_dict",),
    map_location: str = "cpu",
) -> Dict[str, torch.Tensor]:
    """
    Performs checkpoint loading with a variety of pre-processing kernels applied in
    sequence.

    Args:
        checkpoint_path (str): Path to the checkpoint.
        checkpoint_kernels List(Callable): A list of checkpoint processing kernels
            to apply in the specified order. Supported kernels include `CkptIncludeKernel`,
            `CkptExcludeKernel`, etc. These kernels are applied in the
            given order.
        ckpt_state_dict_keys (str): Keys containing the model state dict.
        map_location (str): a function, torch.device, string or a dict specifying how to
            remap storage locations

    Returns: The extracted state dict, after all kernels were applied.
    """
    assert g_pathmgr.exists(checkpoint_path), "Checkpoint '{}' not found".format(
        checkpoint_path
    )

    # Load the checkpoint on CPU to avoid GPU mem spike.
    with g_pathmgr.open(checkpoint_path, "rb") as f:
        checkpoint = torch.load(f, map_location=map_location)

    pre_train_dict = get_state_dict(checkpoint, ckpt_state_dict_keys)

    # Not logging into info etc since it's a huge log
    logging.debug(
        "Loaded Checkpoint State Dict pre-kernel application: %s"
        % str(", ".join(list(pre_train_dict.keys())))
    )
    # Apply kernels
    if checkpoint_kernels is not None:
        for f in checkpoint_kernels:
            pre_train_dict = f(state_dict=pre_train_dict)

    logging.debug(
        "Loaded Checkpoint State Dict Post-kernel application %s"
        % str(", ".join(list(pre_train_dict.keys())))
    )

    return pre_train_dict
|
||||
|
||||
|
||||
def check_load_state_dict_errors(
    missing_keys,
    unexpected_keys,
    strict: bool,
    ignore_missing_keys: List[str] = None,
    ignore_unexpected_keys: List[str] = None,
):
    """
    Inspect the missing/unexpected keys reported by `load_state_dict` and
    decide whether to warn or fail.

    Keys matching the `ignore_*` unix patterns are filtered out first. Any
    remaining mismatch is logged as a warning; a KeyError is raised when there
    are unexpected keys, or when there are missing keys and `strict` is True.
    """
    if ignore_missing_keys is not None and len(ignore_missing_keys) > 0:
        ignored = unix_pattern_to_parameter_names(ignore_missing_keys, missing_keys)
        missing_keys = [key for key in missing_keys if key not in ignored]

    if ignore_unexpected_keys is not None and len(ignore_unexpected_keys) > 0:
        ignored = unix_pattern_to_parameter_names(
            ignore_unexpected_keys, unexpected_keys
        )
        unexpected_keys = [key for key in unexpected_keys if key not in ignored]

    err = "State key mismatch."
    if unexpected_keys:
        err += f" Unexpected keys: {unexpected_keys}."
    if missing_keys:
        err += f" Missing keys: {missing_keys}."

    if not (unexpected_keys or missing_keys):
        return
    logging.warning(err)
    # Unexpected keys always fail; missing keys only fail in strict mode.
    if unexpected_keys or strict:
        raise KeyError(err)
|
||||
|
||||
|
||||
def load_state_dict_into_model(
    state_dict: Dict,
    model: nn.Module,
    strict: bool = True,
    ignore_missing_keys: List[str] = None,
    ignore_unexpected_keys: List[str] = None,
    checkpoint_kernels: List[Callable] = None,
):
    """
    Loads a state dict into the given model.

    Args:
        state_dict: A dictionary containing the model's
            state dict, or a subset if strict is False
        model: Model to load the checkpoint weights into
        strict: raise if the state_dict has missing state keys
        ignore_missing_keys: unix patterns of missing keys to tolerate
        ignore_unexpected_keys: unix patterns of unexpected keys to tolerate
        checkpoint_kernels: pre-processing kernels applied to the state dict,
            in order, before loading

    Returns:
        The model, with the weights loaded.
    """
    # Pre-process the state dict with each kernel, in order.
    if checkpoint_kernels is not None:
        for kernel in checkpoint_kernels:
            state_dict = kernel(state_dict=state_dict)

    # Always load non-strictly; strictness is enforced by the check below,
    # which understands the ignore patterns.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    check_load_state_dict_errors(
        missing_keys,
        unexpected_keys,
        strict=strict,
        ignore_missing_keys=ignore_missing_keys,
        ignore_unexpected_keys=ignore_unexpected_keys,
    )
    return model
|
||||
585
sam3/train/utils/distributed.py
Normal file
585
sam3/train/utils/distributed.py
Normal file
@@ -0,0 +1,585 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any, Callable, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.autograd as autograd
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
# Default to GPU 0
|
||||
_cuda_device_index: int = 0
|
||||
|
||||
# Setting _cuda_device_index to -1 internally implies that we should use CPU
|
||||
_CPU_DEVICE_INDEX = -1
|
||||
_PRIMARY_RANK = 0
|
||||
|
||||
|
||||
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks.
    The result is cached (so the gloo group is only ever created once per
    process).

    Requires torch.distributed to already be initialized.
    """

    if dist.get_backend() == "nccl":
        # Increase timeout from 1800 sec to 43200 sec (12 hr) to avoid some processes
        # being much slower than others causing a timeout (which can happen in relation
        # or LVIS class mAP evaluation).
        timeout = 43200
        return dist.new_group(
            backend="gloo",
            timeout=datetime.timedelta(seconds=timeout),
        )

    # Already a CPU-capable backend: reuse the default (world) group.
    return dist.group.WORLD
|
||||
|
||||
|
||||
def is_main_process():
    """Return true if the current process is the main one (rank 0)."""
    rank = get_rank()
    return rank == 0
|
||||
|
||||
|
||||
def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors), similar to
    `all_gather` above, but using filesystem instead of collective ops.

    Each rank saves its payload to a shared directory, then every rank (or only
    rank 0 when `gather_to_rank_0_only` is True) loads all payloads back.

    If gather_to_rank_0_only is True, only rank 0 will load the gathered object list
    (and other ranks will have an empty list).
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    print("gathering via files")
    cpu_group = _get_global_gloo_group()

    # if unspecified, we will save to the current python file dir
    if filesys_save_dir is not None:
        save_dir = filesys_save_dir
    elif "EXP_DIR" in os.environ:
        save_dir = os.environ["EXP_DIR"]
    else:
        # fall back to the directory where this module lives
        # (filesys_save_dir is known to be None on this branch)
        save_dir = os.path.dirname(__file__)
    save_dir = os.path.join(save_dir, "all_gather_via_filesys")
    if is_main_process():
        os.makedirs(save_dir, exist_ok=True)

    # use a timestamp and salt to distinguish different all_gather
    timestamp = int(time.time()) if is_main_process() else 0
    salt = random.randint(0, 2**31 - 1) if is_main_process() else 0
    # broadcast the timestamp and salt across ranks
    # (all-reduce will do the broadcasting since only rank 0 is non-zero)
    timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long)
    dist.all_reduce(timestamp_and_salt, group=cpu_group)
    timestamp, salt = timestamp_and_salt.tolist()

    # save the data to a file on the disk
    rank_save = get_rank()
    save_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_save}.pkl"
    save_data_path = os.path.join(save_dir, save_data_filename)
    assert not os.path.exists(save_data_path), f"{save_data_path} already exists"
    torch.save(data, save_data_path)
    # Wait until every rank has written its file before reading any of them.
    dist.barrier(group=cpu_group)

    # read the data from the files
    data_list = []
    if rank_save == 0 or not gather_to_rank_0_only:
        for rank_load in range(world_size):
            load_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_load}.pkl"
            load_data_path = os.path.join(save_dir, load_data_filename)
            # Fix: the error message used to reference `save_data_path`
            # (this rank's own file) instead of the file actually being read.
            assert os.path.exists(load_data_path), f"cannot read {load_data_path}"
            data_list.append(torch.load(load_data_path, weights_only=False))
    # Wait until every rank is done reading before deleting any file.
    dist.barrier(group=cpu_group)

    # delete the saved file
    os.remove(save_data_path)
    return data_list
|
||||
|
||||
|
||||
def all_gather(data, force_cpu=False, force_filesys=False, filesys_save_dir=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
        force_cpu: gather over the gloo (CPU) group instead of the default backend
        force_filesys: gather via shared filesystem instead of collective ops
        filesys_save_dir: directory used by the filesystem-based gather
    Returns:
        list[data]: list of data gathered from each rank
    """

    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # Environment variables allow switching the gather strategy globally.
    if os.getenv("MDETR_FILESYS_REDUCE_RANK_0_ONLY") == "1":
        return all_gather_via_filesys(
            data, filesys_save_dir, gather_to_rank_0_only=True
        )

    if os.getenv("MDETR_FILESYS_REDUCE") == "1" or force_filesys:
        return all_gather_via_filesys(data, filesys_save_dir)

    cpu_group = None
    if os.getenv("MDETR_CPU_REDUCE") == "1" or force_cpu:
        cpu_group = _get_global_gloo_group()

    # Serialize the payload into a flat byte tensor on the gathering device.
    buffer = io.BytesIO()
    torch.save(data, buffer)
    data_view = buffer.getbuffer()
    device = "cuda" if cpu_group is None else "cpu"
    tensor = torch.ByteTensor(data_view).to(device)

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
    size_list = [
        torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)
    ]
    if cpu_group is None:
        dist.all_gather(size_list, local_size)
    else:
        print("gathering on cpu")
        dist.all_gather(size_list, local_size, group=cpu_group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)
    assert isinstance(local_size.item(), int)
    local_size = int(local_size.item())

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
    if local_size != max_size:
        padding = torch.empty(
            size=(max_size - local_size,), dtype=torch.uint8, device=device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    if cpu_group is None:
        dist.all_gather(tensor_list, tensor)
    else:
        dist.all_gather(tensor_list, tensor, group=cpu_group)

    # Strip each rank's padding (using the gathered true sizes) and
    # deserialize the original objects.
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
        buffer = io.BytesIO(tensor.cpu().numpy())
        obj = torch.load(buffer, weights_only=False)
        data_list.append(obj)

    return data_list
|
||||
|
||||
|
||||
def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. Move the tensor to the GPU when the NCCL backend is
    active, and report the original device so the caller can move the result
    back afterwards (see `convert_to_normal_tensor`).
    """
    orig_device = "gpu" if tensor.is_cuda else "cpu"
    # Evaluation order matters: get_backend() may only be called when the
    # distributed runtime is available.
    needs_gpu = (
        torch.distributed.is_available()
        and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
        and not tensor.is_cuda
    )
    if needs_gpu:
        tensor = tensor.cuda()
    return (tensor, orig_device)
|
||||
|
||||
|
||||
def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
    """
    Counterpart of `convert_to_distributed_tensor`: move the tensor back to
    the CPU if that is where it originally lived.
    """
    moved_for_comms = tensor.is_cuda and orig_device == "cpu"
    return tensor.cpu() if moved_for_comms else tensor
|
||||
|
||||
|
||||
def is_distributed_training_run() -> bool:
    """True iff torch.distributed is up and the job spans more than one process."""
    if not torch.distributed.is_available():
        return False
    if not torch.distributed.is_initialized():
        return False
    return torch.distributed.get_world_size() > 1
|
||||
|
||||
|
||||
def is_primary() -> bool:
    """
    True on rank 0 of a distributed training job, and always True for a
    single-trainer (non-distributed) run.
    """
    current_rank = get_rank()
    return current_rank == _PRIMARY_RANK
|
||||
|
||||
|
||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """SUM-reduce *tensor* across all processes, then divide by the world size."""

    def _divide_by_world_size(summed: torch.Tensor) -> torch.Tensor:
        # Runs after the SUM all_reduce to turn the total into a mean.
        return summed / torch.distributed.get_world_size()

    return all_reduce_op(
        tensor, torch.distributed.ReduceOp.SUM, _divide_by_world_size
    )
|
||||
|
||||
|
||||
def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
    """
    SUM-reduce *tensor* across all processes; a no-op in the
    non-distributed case.
    """
    reduce_op = torch.distributed.ReduceOp.SUM
    return all_reduce_op(tensor, reduce_op)
|
||||
|
||||
|
||||
def all_reduce_min(tensor: torch.Tensor) -> torch.Tensor:
    """
    MIN-reduce *tensor* across all processes; a no-op in the
    non-distributed case.
    """
    reduce_op = torch.distributed.ReduceOp.MIN
    return all_reduce_op(tensor, reduce_op)
|
||||
|
||||
|
||||
def all_reduce_max(tensor: torch.Tensor) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing max
    reduction of tensor over all processes in both distributed /
    non-distributed scenarios.
    """
    return all_reduce_op(tensor, torch.distributed.ReduceOp.MAX)
|
||||
|
||||
|
||||
def all_reduce_op(
    tensor: torch.Tensor,
    op: torch.distributed.ReduceOp,
    after_op_func: Callable[[torch.Tensor], torch.Tensor] = None,
) -> torch.Tensor:
    """
    Reduce *tensor* across all processes with *op*, optionally applying
    *after_op_func* to the reduced result (e.g. turning a SUM into a
    mean). Outside a distributed run the tensor is returned unchanged.
    """
    if not is_distributed_training_run():
        return tensor
    comm_tensor, source_device = convert_to_distributed_tensor(tensor)
    torch.distributed.all_reduce(comm_tensor, op)
    if after_op_func is not None:
        comm_tensor = after_op_func(comm_tensor)
    return convert_to_normal_tensor(comm_tensor, source_device)
|
||||
|
||||
|
||||
def gather_tensors_from_all(tensor: torch.Tensor) -> List[torch.Tensor]:
    """
    All-gather *tensor* from every process. Returns a list with one
    tensor per rank; in a non-distributed run the list just contains
    the input. Scalars (0-dim) are unsqueezed to 1-dim first because
    all_gather cannot handle 0-dim tensors.
    """
    if tensor.ndim == 0:
        # 0-dim tensors cannot be gathered, so promote to 1-dim.
        tensor = tensor.unsqueeze(0)

    if not is_distributed_training_run():
        return [tensor]

    comm_tensor, source_device = convert_to_distributed_tensor(tensor)
    world_size = torch.distributed.get_world_size()
    gathered = [torch.zeros_like(comm_tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, comm_tensor)
    return [convert_to_normal_tensor(t, source_device) for t in gathered]
|
||||
|
||||
|
||||
def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    """All-gather *tensor* from every rank and concatenate along dim 0."""
    per_rank = gather_tensors_from_all(tensor)
    return torch.cat(per_rank, 0)
|
||||
|
||||
|
||||
def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """
    Broadcast *tensor* from rank *src* to all processes; a no-op in the
    non-distributed case.
    """
    if not is_distributed_training_run():
        return tensor
    comm_tensor, source_device = convert_to_distributed_tensor(tensor)
    torch.distributed.broadcast(comm_tensor, src)
    return convert_to_normal_tensor(comm_tensor, source_device)
|
||||
|
||||
|
||||
def barrier() -> None:
    """
    Synchronize all processes; silently returns (instead of raising)
    when no distributed process group is initialized.
    """
    dist_ready = (
        torch.distributed.is_available() and torch.distributed.is_initialized()
    )
    if dist_ready:
        torch.distributed.barrier()
|
||||
|
||||
|
||||
def get_world_size() -> int:
    """Number of processes in the job; 1 when not running distributed."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return 1
|
||||
|
||||
|
||||
def get_rank() -> int:
    """Rank of this process; 0 when not running distributed."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0
|
||||
|
||||
|
||||
def get_primary_rank() -> int:
    # Rank considered the "primary" worker; see the module-level
    # _PRIMARY_RANK constant (defined earlier in this file).
    return _PRIMARY_RANK
|
||||
|
||||
|
||||
def set_cuda_device_index(idx: int) -> None:
    # Remember this process's CUDA device index in the module-level
    # state and make it the current device for subsequent CUDA work.
    global _cuda_device_index
    _cuda_device_index = idx
    torch.cuda.set_device(_cuda_device_index)
|
||||
|
||||
|
||||
def set_cpu_device() -> None:
    # Mark this process as CPU-only by storing the CPU sentinel index;
    # consulted later by init_distributed_data_parallel_model().
    global _cuda_device_index
    _cuda_device_index = _CPU_DEVICE_INDEX
|
||||
|
||||
|
||||
def get_cuda_device_index() -> int:
    # Device index set via set_cuda_device_index(), or the CPU sentinel
    # if set_cpu_device() was called.
    return _cuda_device_index
|
||||
|
||||
|
||||
def init_distributed_data_parallel_model(
    model: torch.nn.Module,
    broadcast_buffers: bool = False,
    find_unused_parameters: bool = True,
    bucket_cap_mb: int = 25,
) -> torch.nn.parallel.DistributedDataParallel:
    """
    Wrap *model* in DistributedDataParallel, honoring the device chosen
    earlier via set_cuda_device_index() / set_cpu_device().
    """
    global _cuda_device_index

    ddp_kwargs = dict(
        broadcast_buffers=broadcast_buffers,
        find_unused_parameters=find_unused_parameters,
        bucket_cap_mb=bucket_cap_mb,
    )
    if _cuda_device_index != _CPU_DEVICE_INDEX:
        # GPU model: pin DDP input/output to the configured device.
        ddp_kwargs["device_ids"] = [_cuda_device_index]
        ddp_kwargs["output_device"] = _cuda_device_index
    # CPU-only models get no device specification at all.
    return torch.nn.parallel.DistributedDataParallel(model, **ddp_kwargs)
|
||||
|
||||
|
||||
def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any:
    """Broadcast an object from a source to all workers.

    Two-step collective protocol: first the serialized byte length is
    broadcast so receivers can allocate an exact buffer, then the bytes
    themselves. All ranks must call this with the same `src`.

    Args:
        obj: Object to broadcast, must be serializable
        src: Source rank for broadcast (default is primary)
        use_disk: If enabled, removes redundant CPU memory copies by writing to
            disk
    """
    # Either broadcast from primary to the fleet (default),
    # or use the src setting as the original rank
    if get_rank() == src:
        # Emit data: serialize into an in-memory buffer with torch.save.
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data_view = buffer.getbuffer()
        # Step 1: announce the payload size to all ranks.
        length_tensor = torch.LongTensor([len(data_view)])
        length_tensor = broadcast(length_tensor, src=src)
        # Step 2: send the payload bytes.
        data_tensor = torch.ByteTensor(data_view)
        data_tensor = broadcast(data_tensor, src=src)
    else:
        # Fetch from the source
        length_tensor = torch.LongTensor([0])
        length_tensor = broadcast(length_tensor, src=src)
        # Allocate a receive buffer of exactly the announced size.
        data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8)
        data_tensor = broadcast(data_tensor, src=src)
        if use_disk:
            with tempfile.TemporaryFile("r+b") as f:
                f.write(data_tensor.numpy())
                # remove reference to the data tensor and hope that Python garbage
                # collects it
                del data_tensor
                f.seek(0)
                # NOTE(review): weights_only=False unpickles arbitrary objects;
                # only safe between trusted peers of the same job.
                obj = torch.load(f, weights_only=False)
        else:
            buffer = io.BytesIO(data_tensor.numpy())
            obj = torch.load(buffer, weights_only=False)
    return obj
|
||||
|
||||
|
||||
def all_gather_tensor(tensor: torch.Tensor, world_size=None):
    """
    All-gather one tensor from every rank, returning a list with one
    tensor per rank (moved back to the input's original device).
    """
    if world_size is None:
        world_size = get_world_size()
    # NCCL requires contiguous memory for collectives.
    assert tensor.is_contiguous(), f"{tensor.shape} is not contiguous!"
    comm_tensor, source_device = convert_to_distributed_tensor(tensor)
    gathered = [torch.ones_like(comm_tensor) for _ in range(world_size)]
    dist.all_gather(gathered, comm_tensor, async_op=False)  # performance opt
    return [convert_to_normal_tensor(t, source_device) for t in gathered]
|
||||
|
||||
|
||||
def all_gather_batch(tensors: List[torch.Tensor]):
    """
    All-gather each tensor in *tensors* and concatenate the per-rank
    pieces along dim 0. Returns the input list untouched when running
    on a single process.
    """
    world_size = get_world_size()
    if world_size == 1:
        # Nothing to gather in the single-process case.
        return tensors
    gathered_per_input = [all_gather_tensor(t, world_size) for t in tensors]
    return [torch.cat(parts, dim=0) for parts in gathered_per_input]
|
||||
|
||||
|
||||
class GatherLayer(autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    This implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        # Plain all_gather: every rank receives a copy of x from each rank.
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # Sum the incoming per-rank gradients across all processes, then
        # return only the slice that corresponds to this rank's input.
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]
|
||||
|
||||
|
||||
def all_gather_batch_with_grad(tensors):
    """
    All-gather each tensor in *tensors* while keeping the autograd graph
    connected (unlike torch.distributed.all_gather). Returns the input
    list untouched when running on a single process.
    """
    world_size = get_world_size()
    if world_size == 1:
        # No reduction needed for a single process.
        return tensors
    gathered_per_input = [GatherLayer.apply(t) for t in tensors]
    return [torch.cat(parts, dim=0) for parts in gathered_per_input]
|
||||
|
||||
|
||||
def unwrap_ddp_if_wrapped(model):
    """Return the underlying module if *model* is DDP-wrapped, else *model* itself."""
    is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    return model.module if is_ddp else model
|
||||
|
||||
|
||||
def create_new_process_group(group_size):
    """
    Creates process groups of a given `group_size` and returns the
    process group that current GPU participates in.

    `group_size` must divide the total number of GPUs (world_size).

    Modified from
    https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60

    Args:
        group_size (int): number of GPU's to collaborate for sync bn
    """

    assert group_size > 0

    world_size = torch.distributed.get_world_size()
    # For small (likely local/debug) world sizes, tolerate an oversized
    # request instead of failing the asserts below.
    if world_size <= 8:
        if group_size > world_size:
            logging.warning(
                f"Requested group size [{group_size}] > world size [{world_size}]. "
                "Assuming local debug run and capping it to world size."
            )
            group_size = world_size
    assert world_size >= group_size
    assert world_size % group_size == 0

    group = None
    # new_group() is itself collective: EVERY rank must create EVERY
    # subgroup in the same order, even the ones it does not belong to.
    for group_num in range(world_size // group_size):
        group_ids = range(group_num * group_size, (group_num + 1) * group_size)
        cur_group = torch.distributed.new_group(ranks=group_ids)
        if torch.distributed.get_rank() // group_size == group_num:
            group = cur_group
        # can not drop out and return here, every process must go through creation of all subgroups

    assert group is not None
    return group
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
    """True iff torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
|
||||
|
||||
|
||||
def gather_to_rank_0_via_filesys(data, filesys_save_dir=None):
    """
    Gather any picklable data to rank 0 via filesystem, using all_gather_via_filesys.

    Args:
        data: picklable object contributed by this rank.
        filesys_save_dir: directory used for the exchange — presumably a
            path shared by all ranks; TODO confirm against all_gather_via_filesys.
    """
    return all_gather_via_filesys(data, filesys_save_dir, gather_to_rank_0_only=True)
|
||||
241
sam3/train/utils/logger.py
Normal file
241
sam3/train/utils/logger.py
Normal file
@@ -0,0 +1,241 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import atexit
|
||||
import functools
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from hydra.utils import instantiate
|
||||
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from numpy import ndarray
|
||||
|
||||
from sam3.train.utils.train_utils import get_machine_local_and_dist_rank, makedir
|
||||
from torch import Tensor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
Scalar = Union[Tensor, ndarray, int, float]
|
||||
|
||||
|
||||
def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any):
    """Ensure *log_dir* exists and return a TensorBoardLogger writing into it."""
    makedir(log_dir)
    return TensorBoardLogger(
        path=log_dir, summary_writer_method=SummaryWriter, **writer_kwargs
    )
|
||||
|
||||
|
||||
class TensorBoardWriterWrapper:
    """
    A wrapper around a SummaryWriter object.

    The underlying writer is created only on rank 0; on other ranks the
    wrapper is inert and all writer operations are no-ops.
    """

    def __init__(
        self,
        path: str,
        *args: Any,
        filename_suffix: str = None,
        summary_writer_method: Any = SummaryWriter,
        **kwargs: Any,
    ) -> None:
        """Create a new TensorBoard logger.
        On construction, the logger creates a new events file that logs
        will be written to. If the environment variable `RANK` is defined,
        logger will only log if RANK = 0.

        NOTE: If using the logger with distributed training:
        - This logger can call collective operations
        - Logs will be written on rank 0 only
        - Logger must be constructed synchronously *after* initializing distributed process group.

        Args:
            path (str): path to write logs to
            *args, **kwargs: Extra arguments to pass to SummaryWriter
        """
        self._writer: Optional[SummaryWriter] = None
        _, self._rank = get_machine_local_and_dist_rank()
        self._path: str = path
        if self._rank == 0:
            logging.info(
                f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}"
            )
            self._writer = summary_writer_method(
                log_dir=path,
                *args,
                # A random suffix keeps event files from distinct runs apart.
                filename_suffix=filename_suffix or str(uuid.uuid4()),
                **kwargs,
            )
        else:
            logging.debug(
                f"Not logging meters on this host because env RANK: {self._rank} != 0"
            )
        # Flush/close the events file automatically at interpreter exit.
        atexit.register(self.close)

    @property
    def writer(self) -> Optional[SummaryWriter]:
        # Underlying SummaryWriter; None on non-primary ranks or after close().
        return self._writer

    @property
    def path(self) -> str:
        # Directory the event files are written into.
        return self._path

    def flush(self) -> None:
        """Writes pending logs to disk."""

        if not self._writer:
            return

        self._writer.flush()

    def close(self) -> None:
        """Close writer, flushing pending logs to disk.
        Logs cannot be written after `close` is called.
        """

        if not self._writer:
            return

        self._writer.close()
        # Drop the handle so later flush()/close() calls become no-ops.
        self._writer = None
|
||||
|
||||
|
||||
class TensorBoardLogger(TensorBoardWriterWrapper):
    """
    A simple logger for TensorBoard.
    """

    def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
        """Log every (tag, scalar) pair in *payload* at *step*.

        Args:
            payload (dict): dictionary of tag name and scalar value
            step (int, Optional): step value to record
        """
        if not self._writer:
            return
        for tag, value in payload.items():
            self.log(tag, value, step)

    def log(self, name: str, data: Scalar, step: int) -> None:
        """Log one scalar under tag *name* at *step*.

        Args:
            name (string): tag name used to group scalars
            data (float/int/Tensor): scalar data to log
            step (int, optional): step value to record
        """
        if not self._writer:
            return
        self._writer.add_scalar(name, data, global_step=step, new_style=True)

    def log_hparams(
        self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
    ) -> None:
        """Log hyperparameters together with final meter values.

        Args:
            hparams (dict): dictionary of hyperparameter names and corresponding values
            meters (dict): dictionary of name of meter and corresponding values
        """
        if not self._writer:
            return
        self._writer.add_hparams(hparams, meters)
|
||||
|
||||
|
||||
class Logger:
    """
    A logger class that can interface with multiple loggers. It now supports tensorboard only for simplicity, but you can extend it with your own logger.
    """

    def __init__(self, logging_conf):
        # "should_log: false" in the tensorboard config disables TensorBoard.
        # Note: pop() mutates the config node in place.
        tb_config = logging_conf.tensorboard_writer
        should_log = tb_config and tb_config.pop("should_log", True)
        self.tb_logger = instantiate(tb_config) if should_log else None

    def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
        """Forward a dict of scalars to the TensorBoard logger, if enabled."""
        if self.tb_logger:
            self.tb_logger.log_dict(payload, step)

    def log(self, name: str, data: Scalar, step: int) -> None:
        """Forward one scalar to the TensorBoard logger, if enabled."""
        if self.tb_logger:
            self.tb_logger.log(name, data, step)

    def log_hparams(
        self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
    ) -> None:
        """Forward hyperparameters + final metrics to TensorBoard, if enabled."""
        if self.tb_logger:
            self.tb_logger.log_hparams(hparams, meters)
|
||||
|
||||
|
||||
# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    """Open *filename* for appending with a small buffer and cache the handle."""
    # A modest buffer keeps the log file reasonably fresh on disk.
    buffer_size = 10 * 1024  # 10KB
    # Name the stream something other than `io` to avoid shadowing the
    # stdlib module name.
    stream = g_pathmgr.open(filename, mode="a", buffering=buffer_size)
    # Make sure the stream is flushed and closed at interpreter exit.
    atexit.register(stream.close)
    return stream
|
||||
|
||||
|
||||
def setup_logging(
    name,
    output_dir=None,
    rank=0,
    log_level_primary="INFO",
    log_level_secondary="ERROR",
):
    """
    Setup various logging streams: stdout and file handlers.
    For file handlers, we only setup for the master gpu.

    Args:
        name: logger name passed to logging.getLogger.
        output_dir: if given, rank 0 additionally logs to <output_dir>/log.txt.
        rank: rank of this process; non-zero ranks use the secondary level.
        log_level_primary: console/file level for rank 0.
        log_level_secondary: console level for all other ranks.
    """
    # get the filename if we want to log to the file as well
    log_filename = None
    if output_dir:
        makedir(output_dir)
        if rank == 0:
            log_filename = f"{output_dir}/log.txt"

    logger = logging.getLogger(name)
    logger.setLevel(log_level_primary)

    # create formatter
    FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s"
    formatter = logging.Formatter(FORMAT)

    # Cleanup any existing handlers so repeated calls don't duplicate output
    for h in logger.handlers:
        logger.removeHandler(h)
    logger.root.handlers = []

    # setup the console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    if rank == 0:
        console_handler.setLevel(log_level_primary)
    else:
        console_handler.setLevel(log_level_secondary)

    # we log to file as well if user wants
    if log_filename and rank == 0:
        # StreamHandler over a cached stream (rather than FileHandler) so
        # repeated setup calls append through one shared file object.
        file_handler = logging.StreamHandler(_cached_log_stream(log_filename))
        file_handler.setLevel(log_level_primary)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    # NOTE(review): replacing the global root logger is unusual — it routes
    # all bare `logging.info(...)` calls through this named logger. Confirm
    # this is intended before changing.
    logging.root = logger
|
||||
|
||||
|
||||
def shutdown_logging():
    """
    After training is done, we ensure to shut down all the logger streams.
    """
    logging.info("Shutting down loggers...")
    for handler in logging.root.handlers:
        handler.close()
|
||||
285
sam3/train/utils/train_utils.py
Normal file
285
sam3/train/utils/train_utils.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
import hydra
|
||||
|
||||
import numpy as np
|
||||
import omegaconf
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def multiply_all(*args):
    """Return the product of all arguments as a native Python number."""
    product = np.prod(np.array(args))
    return product.item()
|
||||
|
||||
|
||||
def collect_dict_keys(config):
    """Recursively walk a dataset configuration and collect every `dict_key` defined on a collate-fn node."""
    val_keys = []
    # If this config points to a collate function, it carries a dict_key
    if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]):
        val_keys.append(config["dict_key"])
    else:
        # Recursively proceed
        for v in config.values():
            # isinstance(v, type(config)) matches nested mappings of the
            # same class as `config` — presumably an omegaconf DictConfig;
            # TODO confirm at call sites.
            if isinstance(v, type(config)):
                val_keys.extend(collect_dict_keys(v))
            elif isinstance(v, omegaconf.listconfig.ListConfig):
                # Lists may contain further nested config mappings.
                for item in v:
                    if isinstance(item, type(config)):
                        val_keys.extend(collect_dict_keys(item))
    return val_keys
|
||||
|
||||
|
||||
class Phase:
    """String constants naming the two phases of a run."""

    # Used as dictionary keys / logging prefixes in the training code.
    TRAIN = "train"
    VAL = "val"
|
||||
|
||||
|
||||
def register_omegaconf_resolvers():
    """Register the arithmetic/utility resolvers used by the YAML configs."""
    resolvers = {
        "get_method": hydra.utils.get_method,
        "get_class": hydra.utils.get_class,
        "add": lambda x, y: x + y,
        "times": multiply_all,
        "divide": lambda x, y: x / y,
        "pow": lambda x, y: x**y,
        "subtract": lambda x, y: x - y,
        "range": lambda x: list(range(x)),
        "int": lambda x: int(x),
        "ceil_int": lambda x: int(math.ceil(x)),
        "merge": lambda *x: OmegaConf.merge(*x),
        "string": lambda x: str(x),
    }
    for resolver_name, fn in resolvers.items():
        OmegaConf.register_new_resolver(resolver_name, fn)
|
||||
|
||||
|
||||
def setup_distributed_backend(backend, timeout_mins):
    """
    Initialize torch.distributed and set the CUDA device.
    Expects environment variables to be set as per
    https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
    along with the environ variable "LOCAL_RANK" which is used to set the CUDA device.

    Args:
        backend: torch.distributed backend name (e.g. "nccl", "gloo").
        timeout_mins: collective-operation timeout, in minutes.

    Returns:
        The global rank of this process.
    """
    # enable TORCH_NCCL_ASYNC_ERROR_HANDLING to ensure dist nccl ops time out after timeout_mins
    # of waiting
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
    logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins")
    dist.init_process_group(backend=backend, timeout=timedelta(minutes=timeout_mins))
    return dist.get_rank()
|
||||
|
||||
|
||||
def get_machine_local_and_dist_rank():
    """
    Get the distributed and local rank of the current gpu.

    Returns:
        (local_rank, distributed_rank) parsed from the LOCAL_RANK and
        RANK environment variables.

    Raises:
        AssertionError: if either environment variable is unset.
    """
    local_rank_env = os.environ.get("LOCAL_RANK", None)
    distributed_rank_env = os.environ.get("RANK", None)
    # Validate BEFORE int() conversion: the previous int(os.environ.get(...))
    # raised TypeError on a missing variable, so the friendly assert message
    # below was unreachable.
    assert (
        local_rank_env is not None and distributed_rank_env is not None
    ), "Please the set the RANK and LOCAL_RANK environment variables."
    return int(local_rank_env), int(distributed_rank_env)
|
||||
|
||||
|
||||
def print_cfg(cfg):
    """
    Supports printing both Hydra DictConfig and also the AttrDict config
    """
    cfg_yaml = OmegaConf.to_yaml(cfg)
    logging.info("Training with config:")
    logging.info(cfg_yaml)
|
||||
|
||||
|
||||
def set_seeds(seed_value, max_epochs, dist_rank):
    """
    Seed python's `random`, numpy and torch (plus all CUDA devices, when
    available) with a per-rank value for deterministic training.
    """
    # The pytorch sampler increments the seed by 1 every epoch, so space
    # per-rank seeds max_epochs apart to keep them disjoint.
    machine_seed = (seed_value + dist_rank) * max_epochs
    logging.info(f"MACHINE SEED: {machine_seed}")
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(machine_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(machine_seed)
|
||||
|
||||
|
||||
def makedir(dir_path):
    """
    Create the directory if it does not exist.

    Returns:
        True on success (including when the directory already exists),
        False when creation failed.
    """
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        return True
    except BaseException:
        # Best-effort: report and carry on rather than crashing the job.
        logging.info(f"Error creating directory: {dir_path}")
        return False
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
    """True iff the torch.distributed process group is up and usable."""
    return dist.is_available() and dist.is_initialized()
|
||||
|
||||
|
||||
def get_amp_type(amp_type: Optional[str] = None):
    """Map an amp-type name to its torch dtype; None passes through unchanged."""
    if amp_type is None:
        return None
    dtype_by_name = {"bfloat16": torch.bfloat16, "float16": torch.float16}
    assert amp_type in dtype_by_name, "Invalid Amp type."
    return dtype_by_name[amp_type]
|
||||
|
||||
|
||||
def log_env_variables():
    """Dump every environment variable (sorted by name) to the log."""
    dump = "".join(f"{k}={os.environ[k]}\n" for k in sorted(os.environ.keys()))
    logging.info("Logging ENV_VARIABLES")
    logging.info(dump)
|
||||
|
||||
|
||||
class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.fmt = fmt  # format spec used by __str__, e.g. ":.4f"
        self.device = device
        self.reset()

    def reset(self):
        """Zero out all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self._allow_updates = True

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return template.format(**self.__dict__)
|
||||
|
||||
|
||||
class MemMeter:
    """Computes and stores the current, avg, and max of peak Mem usage per iteration"""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.fmt = fmt  # format spec used by __str__, e.g. ":.2f"
        self.device = device
        self.reset()

    def reset(self):
        self.val = 0  # Per iteration max usage
        self.avg = 0  # Avg per iteration max usage
        self.peak = 0  # Peak usage for lifetime of program
        self.sum = 0
        self.count = 0
        self._allow_updates = True

    def update(self, n=1, reset_peak_usage=True):
        # Peak CUDA memory since the last reset, in GB (// 1e9 yields a float).
        self.val = torch.cuda.max_memory_allocated() // 1e9
        self.sum += self.val * n
        self.count += n
        self.avg = self.sum / self.count
        self.peak = max(self.peak, self.val)
        if reset_peak_usage:
            # Start a fresh per-iteration peak-measurement window.
            torch.cuda.reset_peak_memory_stats()

    def __str__(self):
        fmtstr = (
            "{name}: {val"
            + self.fmt
            + "} ({avg"
            + self.fmt
            + "}/{peak"
            + self.fmt
            + "})"
        )
        return fmtstr.format(**self.__dict__)
|
||||
|
||||
|
||||
def human_readable_time(time_seconds):
    """Format a duration in seconds as 'DDd HHh MMm' (seconds are dropped)."""
    total_minutes, _ = divmod(int(time_seconds), 60)
    total_hours, minutes = divmod(total_minutes, 60)
    days, hours = divmod(total_hours, 24)
    return f"{days:02}d {hours:02}h {minutes:02}m"
|
||||
|
||||
|
||||
class DurationMeter:
    """Tracks an accumulated duration (in seconds) under a display name."""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.device = device
        self.fmt = fmt
        self.val = 0

    def reset(self):
        """Zero the accumulated duration."""
        self.val = 0

    def update(self, val):
        """Overwrite the stored duration with *val*."""
        self.val = val

    def add(self, val):
        """Accumulate *val* seconds onto the stored duration."""
        self.val += val

    def __str__(self):
        return f"{self.name}: {human_readable_time(self.val)}"
|
||||
|
||||
|
||||
class ProgressMeter:
    """Pretty-prints meter values alongside a [batch/total] progress prefix."""

    def __init__(self, num_batches, meters, real_meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.real_meters = real_meters
        self.prefix = prefix

    def display(self, batch, enable_print=False):
        """Log (and optionally also print) the progress line for *batch*."""
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(str(meter) for meter in self.meters)
        for name, meter in self.real_meters.items():
            computed = meter.compute()
            parts.append(
                " | ".join(
                    f"{os.path.join(name, subname)}: {val:.4f}"
                    for subname, val in computed.items()
                )
            )
        line = " | ".join(parts)
        logging.info(line)
        if enable_print:
            print(line)

    def _get_batch_fmtstr(self, num_batches):
        # Width of the counter field matches the digit count of the total.
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"
|
||||
|
||||
|
||||
def get_resume_checkpoint(checkpoint_save_dir):
    """
    Return the path to checkpoint.pt inside *checkpoint_save_dir*, or
    None when the directory or file does not exist.
    """
    if not g_pathmgr.isdir(checkpoint_save_dir):
        return None
    candidate = os.path.join(checkpoint_save_dir, "checkpoint.pt")
    return candidate if g_pathmgr.isfile(candidate) else None
|
||||
Reference in New Issue
Block a user