apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)

Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: itamaro

Differential Revision: D90476315

fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
This commit is contained in:
Bowie Chen
2026-01-11 23:16:49 -08:00
committed by meta-codesync[bot]
parent 7b89b8fc3f
commit 11dec2936d
69 changed files with 445 additions and 522 deletions

View File

@@ -142,9 +142,9 @@ class COCO_FROM_JSON:
self.prompts = {}
for loc_dict in prompts:
self.prompts[int(loc_dict["id"])] = loc_dict["name"]
assert len(self.prompts) == len(
self._sorted_cat_ids
), "Number of prompts must match number of categories"
assert len(self.prompts) == len(self._sorted_cat_ids), (
"Number of prompts must match number of categories"
)
def getDatapointIds(self):
"""Return all datapoint indices for training."""

View File

@@ -6,7 +6,6 @@ from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_data
from typing import Any, get_args, get_origin, List, Union
import torch
from sam3.model.data_misc import (
BatchedDatapoint,
BatchedFindTarget,
@@ -217,9 +216,9 @@ def collate_fn_api(
text_batch.append(q.query_text)
stages[stage_id].text_ids.append(text_batch.index(q.query_text))
assert (
q.inference_metadata is not None
), "inference_metadata must be provided when FindQueryLoaded is created."
assert q.inference_metadata is not None, (
"inference_metadata must be provided when FindQueryLoaded is created."
)
for f in fields(q.inference_metadata):
getattr(find_metadatas[stage_id], f.name).append(
getattr(q.inference_metadata, f.name)

View File

@@ -19,10 +19,8 @@ import torch.utils.data
import torchvision
from decord import cpu, VideoReader
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from PIL.Image import DecompressionBombError
from sam3.model.box_ops import box_xywh_to_xyxy
from torchvision.datasets.vision import VisionDataset
@@ -234,9 +232,9 @@ class CustomCocoDetectionAPI(VisionDataset):
if self.coco is not None:
return
assert g_pathmgr.isfile(
self.annFile
), f"please provide valid annotation file. Missing: {self.annFile}"
assert g_pathmgr.isfile(self.annFile), (
f"please provide valid annotation file. Missing: {self.annFile}"
)
annFile = g_pathmgr.get_local_path(self.annFile)
if self.coco is not None:
@@ -326,9 +324,9 @@ class CustomCocoDetectionAPI(VisionDataset):
else:
num_queries_per_stage = stage2num_queries.most_common(1)[0][1]
for stage, num_queries in stage2num_queries.items():
assert (
num_queries == num_queries_per_stage
), f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"
assert num_queries == num_queries_per_stage, (
f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"
)
for query_id, query in enumerate(queries):
h, w = id2imsize[query["image_id"]]

View File

@@ -3,7 +3,6 @@
# pyre-unsafe
import copy
import io
import json
import logging
@@ -16,7 +15,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
import torchvision
# from decord import cpu, VideoReader
from iopath.common.file_io import PathManager
@@ -220,9 +218,9 @@ class VideoGroundingDataset(Sam3ImageDataset):
for query in filtered_queries:
ptr_x_is_empty = query["ptr_x_query_id"] in [None, -1]
ptr_y_is_empty = query["ptr_y_query_id"] in [None, -1]
assert (
ptr_x_is_empty and ptr_y_is_empty
), "Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers"
assert ptr_x_is_empty and ptr_y_is_empty, (
"Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers"
)
query["query_processing_order"] = stage_id_old2new[
query["query_processing_order"]
]

View File

@@ -9,11 +9,8 @@ import torch
import torch.distributed
import torch.nn.functional as F
import torchmetrics
from sam3.model import box_ops
from sam3.model.data_misc import interpolate
from sam3.train.loss.sigmoid_focal_loss import (
triton_sigmoid_focal_loss,
triton_sigmoid_focal_loss_reduce,
@@ -327,7 +324,9 @@ class IABCEMdetr(LossWithWeights):
if num_det_queries is not None:
logging.warning("note: it's not needed to set num_det_queries anymore")
if self.use_separate_loss_for_det_and_trk:
assert not self.weak_loss, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
assert not self.weak_loss, (
"Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
)
self.det_exhaustive_loss_scale_pos = det_exhaustive_loss_scale_pos
self.det_exhaustive_loss_scale_neg = det_exhaustive_loss_scale_neg
self.det_non_exhaustive_loss_scale_pos = det_non_exhaustive_loss_scale_pos
@@ -342,7 +341,9 @@ class IABCEMdetr(LossWithWeights):
and det_non_exhaustive_loss_scale_neg == 1.0
and trk_loss_scale_pos == 1.0
and trk_loss_scale_neg == 1.0
), "If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0"
), (
"If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0"
)
def get_loss(self, outputs, targets, indices, num_boxes):
assert len(outputs["pred_logits"].shape) > 2, "Incorrect predicted logits shape"
@@ -443,7 +444,9 @@ class IABCEMdetr(LossWithWeights):
pass
if self.weak_loss:
assert not self.use_separate_loss_for_det_and_trk, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
assert not self.use_separate_loss_for_det_and_trk, (
"Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
)
# nullify the negative loss for the non-exhaustive classes
assert loss_bce.shape[0] == targets["is_exhaustive"].shape[0]
@@ -497,9 +500,9 @@ class IABCEMdetr(LossWithWeights):
loss_bce = loss_bce.mean()
else:
assert isinstance(self.pad_n_queries, int)
assert (
loss_bce.size(1) < self.pad_n_queries
), f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions."
assert loss_bce.size(1) < self.pad_n_queries, (
f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions."
)
loss_bce = loss_bce.sum() / (self.pad_n_queries * loss_bce.size(0))
bce_f1 = torchmetrics.functional.f1_score(

View File

@@ -3,9 +3,7 @@
# pyre-unsafe
import torch
from sam3.model.model_misc import SAM3Output
from sam3.train.utils.distributed import get_world_size
from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks

View File

@@ -103,9 +103,9 @@ def dilation(mask, kernel_size):
assert mask.ndim == 3
kernel_size = int(kernel_size)
assert (
kernel_size % 2 == 1
), f"Dilation expects a odd kernel size, got {kernel_size}"
assert kernel_size % 2 == 1, (
f"Dilation expects a odd kernel size, got {kernel_size}"
)
if mask.is_cuda:
m = mask.unsqueeze(1).to(torch.float16)

View File

@@ -8,7 +8,6 @@ Modules to compute the matching cost and solve the corresponding LSAP.
import numpy as np
import torch
from sam3.model.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
from scipy.optimize import linear_sum_assignment
from torch import nn
@@ -60,9 +59,9 @@ class HungarianMatcher(nn.Module):
self.cost_bbox = cost_bbox
self.cost_giou = cost_giou
self.norm = nn.Sigmoid() if focal_loss else nn.Softmax(-1)
assert (
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
), "all costs cant be 0"
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
"all costs cant be 0"
)
self.focal_loss = focal_loss
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
@@ -197,9 +196,9 @@ class BinaryHungarianMatcher(nn.Module):
self.cost_bbox = cost_bbox
self.cost_giou = cost_giou
self.norm = nn.Sigmoid()
assert (
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
), "all costs cant be 0"
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
"all costs cant be 0"
)
@torch.no_grad()
def forward(self, outputs, batched_targets, repeats=0, repeat_batch=1):
@@ -322,9 +321,9 @@ class BinaryFocalHungarianMatcher(nn.Module):
self.alpha = alpha
self.gamma = gamma
self.stable = stable
assert (
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
), "all costs cant be 0"
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
"all costs cant be 0"
)
@torch.no_grad()
def forward(self, outputs, batched_targets, repeats=1, repeat_batch=1):
@@ -470,9 +469,9 @@ class BinaryHungarianMatcherV2(nn.Module):
self.cost_bbox = cost_bbox
self.cost_giou = cost_giou
self.norm = nn.Sigmoid()
assert (
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
), "all costs cant be 0"
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
"all costs cant be 0"
)
self.focal = focal
if focal:
self.alpha = alpha

View File

@@ -22,7 +22,6 @@ from typing import (
)
import hydra
import torch
import torch.nn as nn
from omegaconf import DictConfig
@@ -212,9 +211,9 @@ def unix_module_cls_pattern_to_parameter_names(
"match any classes in the model"
)
matching_parameters = module_cls_to_param_names[module_cls]
assert (
len(matching_parameters) > 0
), f"module_cls_name {module_cls_name} does not contain any parameters in the model"
assert len(matching_parameters) > 0, (
f"module_cls_name {module_cls_name} does not contain any parameters in the model"
)
logging.info(
f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
)
@@ -240,9 +239,9 @@ def unix_param_pattern_to_parameter_names(
allowed_parameter_names = []
for param_name in filter_param_names:
matching_parameters = set(fnmatch.filter(parameter_names, param_name))
assert (
len(matching_parameters) >= 1
), f"param_name {param_name} does not match any parameters in the model"
assert len(matching_parameters) >= 1, (
f"param_name {param_name} does not match any parameters in the model"
)
logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
allowed_parameter_names.append(matching_parameters)
return set.union(*allowed_parameter_names)

View File

@@ -12,13 +12,10 @@ from copy import deepcopy
import submitit
import torch
from hydra import compose, initialize_config_module
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from omegaconf import OmegaConf
from sam3.train.utils.train_utils import makedir, register_omegaconf_resolvers
from tqdm import tqdm
@@ -212,9 +209,9 @@ def main(args) -> None:
},
}
if "include_nodes" in submitit_conf:
assert (
len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes
), "Not enough nodes"
assert len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes, (
"Not enough nodes"
)
job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(
submitit_conf["include_nodes"]
)

View File

@@ -15,28 +15,22 @@ from dataclasses import dataclass, field
from typing import Any, Dict, List, Mapping, Optional
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from sam3.model.data_misc import BatchedDatapoint
from sam3.model.model_misc import SAM3Output
from sam3.model.utils.misc import copy_data_to_device
from sam3.train.optim.optimizer import construct_optimizer
from sam3.train.utils.checkpoint_utils import (
assert_skipped_parameters_are_frozen,
exclude_params_matching_unix_pattern,
load_state_dict_into_model,
with_check_parameter_frozen,
)
from sam3.train.utils.distributed import all_reduce_max, barrier, get_rank
from sam3.train.utils.logger import Logger, setup_logging
from sam3.train.utils.train_utils import (
AverageMeter,
@@ -215,9 +209,9 @@ class Trainer:
set_seeds(seed_value, self.max_epochs, self.distributed_rank)
log_env_variables()
assert (
is_dist_avail_and_initialized()
), "Torch distributed needs to be initialized before calling the trainer."
assert is_dist_avail_and_initialized(), (
"Torch distributed needs to be initialized before calling the trainer."
)
self._setup_components() # Except Optimizer everything is setup here.
self._move_to_device()
@@ -227,9 +221,9 @@ class Trainer:
self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")
if self.checkpoint_conf.resume_from is not None:
assert os.path.exists(
self.checkpoint_conf.resume_from
), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
assert os.path.exists(self.checkpoint_conf.resume_from), (
f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
)
dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
if self.distributed_rank == 0 and not os.path.exists(dst):
# Copy the "resume_from" checkpoint to the checkpoint folder
@@ -477,9 +471,9 @@ class Trainer:
return self.loss[key]
assert key != "all", "Loss must be specified for key='all'"
assert (
"default" in self.loss
), f"Key {key} not found in losss, and no default provided"
assert "default" in self.loss, (
f"Key {key} not found in losss, and no default provided"
)
return self.loss["default"]
def _find_meter(self, phase: str, key: str):
@@ -922,12 +916,12 @@ class Trainer:
self.optim.zero_grad(set_to_none=True)
if self.gradient_accumulation_steps > 1:
assert isinstance(
batch, list
), f"Expected a list of batches, got {type(batch)}"
assert (
len(batch) == self.gradient_accumulation_steps
), f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
assert isinstance(batch, list), (
f"Expected a list of batches, got {type(batch)}"
)
assert len(batch) == self.gradient_accumulation_steps, (
f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
)
accum_steps = len(batch)
else:
accum_steps = 1
@@ -1039,9 +1033,9 @@ class Trainer:
def _check_val_key_match(self, val_keys, phase):
if val_keys is not None:
# Check if there are any duplicates
assert len(val_keys) == len(
set(val_keys)
), f"Duplicate keys in val datasets, keys: {val_keys}"
assert len(val_keys) == len(set(val_keys)), (
f"Duplicate keys in val datasets, keys: {val_keys}"
)
# Check that the keys match the meter keys
if self.meters_conf is not None and phase in self.meters_conf:
@@ -1055,9 +1049,9 @@ class Trainer:
loss_keys = set(self.loss_conf.keys()) - set(["all"])
if "default" not in loss_keys:
for k in val_keys:
assert (
k in loss_keys
), f"Error: key {k} is not defined in the losses, and no default is set"
assert k in loss_keys, (
f"Error: key {k} is not defined in the losses, and no default is set"
)
def _setup_components(self):
# Get the keys for all the val datasets, if any

View File

@@ -14,7 +14,6 @@ import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from sam3.model.box_ops import box_xyxy_to_cxcywh
from sam3.model.data_misc import interpolate
@@ -277,9 +276,9 @@ class RandomSizeCrop:
max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1))
)
result_img, result_target = crop(img, target, [j, i, h, w])
assert (
len(result_target["boxes"]) == init_boxes
), f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}"
assert len(result_target["boxes"]) == init_boxes, (
f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}"
)
return result_img, result_target
else:

View File

@@ -7,7 +7,6 @@ Transforms and data augmentation for both image + bbox.
"""
import logging
import numbers
import random
from collections.abc import Sequence
@@ -17,9 +16,7 @@ import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
import torchvision.transforms.v2.functional as Fv2
from PIL import Image as PILImage
from sam3.model.box_ops import box_xyxy_to_cxcywh, masks_to_boxes
from sam3.train.data.sam3_image_dataset import Datapoint
from torchvision.transforms import InterpolationMode

View File

@@ -4,12 +4,10 @@
import logging
import random
from collections import defaultdict
from typing import List, Optional, Union
import torch
from sam3.train.data.sam3_image_dataset import Datapoint, FindQuery, Object
@@ -381,9 +379,9 @@ class FlexibleFilterFindGetQueries:
if len(new_find_queries) == 0:
start_with_zero_check = True
assert (
start_with_zero_check
), "Invalid Find queries, they need to start at query_processing_order = 0"
assert start_with_zero_check, (
"Invalid Find queries, they need to start at query_processing_order = 0"
)
datapoint.find_queries = new_find_queries

View File

@@ -7,7 +7,6 @@ import numpy as np
import torch
from PIL import Image as PILImage
from pycocotools import mask as mask_util
from sam3.train.data.sam3_image_dataset import Datapoint
from torchvision.ops import masks_to_boxes
@@ -250,9 +249,9 @@ class RandomGeometricInputsAPI:
def _get_target_object(self, datapoint, query):
img = datapoint.images[query.image_id]
targets = query.object_ids_output
assert (
len(targets) == 1
), "Geometric queries only support a single target object."
assert len(targets) == 1, (
"Geometric queries only support a single target object."
)
target_idx = targets[0]
return img.objects[target_idx]

View File

@@ -5,12 +5,9 @@
import numpy as np
import pycocotools.mask as mask_utils
import torch
import torchvision.transforms.functional as F
from PIL import Image as PILImage
from sam3.model.box_ops import masks_to_boxes
from sam3.train.data.sam3_image_dataset import Datapoint

View File

@@ -36,9 +36,9 @@ def unix_pattern_to_parameter_names(
parameter_names = []
for param_name in constraints:
matching_parameters = set(fnmatch.filter(all_parameter_names, param_name))
assert (
len(matching_parameters) > 0
), f"param_names {param_name} don't match any param in the given names."
assert len(matching_parameters) > 0, (
f"param_names {param_name} don't match any param in the given names."
)
parameter_names.append(matching_parameters)
return set.union(*parameter_names)

View File

@@ -10,10 +10,8 @@ import uuid
from typing import Any, Dict, Optional, Union
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from numpy import ndarray
from sam3.train.utils.train_utils import get_machine_local_and_dist_rank, makedir
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter

View File

@@ -11,7 +11,6 @@ from datetime import timedelta
from typing import Optional
import hydra
import numpy as np
import omegaconf
import torch
@@ -83,9 +82,9 @@ def get_machine_local_and_dist_rank():
"""
local_rank = int(os.environ.get("LOCAL_RANK", None))
distributed_rank = int(os.environ.get("RANK", None))
assert (
local_rank is not None and distributed_rank is not None
), "Please the set the RANK and LOCAL_RANK environment variables."
assert local_rank is not None and distributed_rank is not None, (
"Please the set the RANK and LOCAL_RANK environment variables."
)
return local_rank, distributed_rank