apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)

Summary:
Formats the covered files with pyfmt.

paintitblack
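The visible effect of the new Black style is on long assert statements: the wrapping parentheses move from the condition to the message. A minimal sketch of the before/after pattern (hypothetical code, not taken from the covered files):

    def validate_batches(batch, expected_steps):
        # Old wrapping (pre-25.11.0 style), shown as a comment:
        # assert (
        #     len(batch) == expected_steps
        # ), f"Expected {expected_steps} batches, got {len(batch)}"
        #
        # New wrapping (25.11.0 style): the condition stays on the assert line
        # and only the message is parenthesized.
        assert len(batch) == expected_steps, (
            f"Expected {expected_steps} batches, got {len(batch)}"
        )

    validate_batches([0, 1], 2)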

Reviewed By: itamaro

Differential Revision: D90476315

fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
Author: Bowie Chen, 2026-01-11 23:16:49 -08:00
Committed by: meta-codesync[bot]
Parent: 7b89b8fc3f
Commit: 11dec2936d
69 changed files with 445 additions and 522 deletions


@@ -15,28 +15,22 @@ from dataclasses import dataclass, field
from typing import Any, Dict, List, Mapping, Optional
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from sam3.model.data_misc import BatchedDatapoint
from sam3.model.model_misc import SAM3Output
from sam3.model.utils.misc import copy_data_to_device
from sam3.train.optim.optimizer import construct_optimizer
from sam3.train.utils.checkpoint_utils import (
    assert_skipped_parameters_are_frozen,
    exclude_params_matching_unix_pattern,
    load_state_dict_into_model,
    with_check_parameter_frozen,
)
from sam3.train.utils.distributed import all_reduce_max, barrier, get_rank
from sam3.train.utils.logger import Logger, setup_logging
from sam3.train.utils.train_utils import (
    AverageMeter,
@@ -215,9 +209,9 @@ class Trainer:
         set_seeds(seed_value, self.max_epochs, self.distributed_rank)
         log_env_variables()
-        assert (
-            is_dist_avail_and_initialized()
-        ), "Torch distributed needs to be initialized before calling the trainer."
+        assert is_dist_avail_and_initialized(), (
+            "Torch distributed needs to be initialized before calling the trainer."
+        )
         self._setup_components()  # Except Optimizer everything is setup here.
         self._move_to_device()
@@ -227,9 +221,9 @@ class Trainer:
         self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")
         if self.checkpoint_conf.resume_from is not None:
-            assert os.path.exists(
-                self.checkpoint_conf.resume_from
-            ), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
+            assert os.path.exists(self.checkpoint_conf.resume_from), (
+                f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
+            )
             dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
             if self.distributed_rank == 0 and not os.path.exists(dst):
                 # Copy the "resume_from" checkpoint to the checkpoint folder
@@ -477,9 +471,9 @@ class Trainer:
             return self.loss[key]
         assert key != "all", "Loss must be specified for key='all'"
-        assert (
-            "default" in self.loss
-        ), f"Key {key} not found in losss, and no default provided"
+        assert "default" in self.loss, (
+            f"Key {key} not found in losss, and no default provided"
+        )
         return self.loss["default"]

     def _find_meter(self, phase: str, key: str):
@@ -922,12 +916,12 @@ class Trainer:
         self.optim.zero_grad(set_to_none=True)
         if self.gradient_accumulation_steps > 1:
-            assert isinstance(
-                batch, list
-            ), f"Expected a list of batches, got {type(batch)}"
-            assert (
-                len(batch) == self.gradient_accumulation_steps
-            ), f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
+            assert isinstance(batch, list), (
+                f"Expected a list of batches, got {type(batch)}"
+            )
+            assert len(batch) == self.gradient_accumulation_steps, (
+                f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
+            )
             accum_steps = len(batch)
         else:
             accum_steps = 1
@@ -1039,9 +1033,9 @@ class Trainer:
     def _check_val_key_match(self, val_keys, phase):
         if val_keys is not None:
             # Check if there are any duplicates
-            assert len(val_keys) == len(
-                set(val_keys)
-            ), f"Duplicate keys in val datasets, keys: {val_keys}"
+            assert len(val_keys) == len(set(val_keys)), (
+                f"Duplicate keys in val datasets, keys: {val_keys}"
+            )
             # Check that the keys match the meter keys
             if self.meters_conf is not None and phase in self.meters_conf:
@@ -1055,9 +1049,9 @@ class Trainer:
                 loss_keys = set(self.loss_conf.keys()) - set(["all"])
                 if "default" not in loss_keys:
                     for k in val_keys:
-                        assert (
-                            k in loss_keys
-                        ), f"Error: key {k} is not defined in the losses, and no default is set"
+                        assert k in loss_keys, (
+                            f"Error: key {k} is not defined in the losses, and no default is set"
+                        )

     def _setup_components(self):
         # Get the keys for all the val datasets, if any