apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)
Summary: Formats the covered files with pyfmt. paintitblack Reviewed By: itamaro Differential Revision: D90476315 fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
This commit is contained in:
committed by
meta-codesync[bot]
parent
7b89b8fc3f
commit
11dec2936d
@@ -15,28 +15,22 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from hydra.utils import instantiate
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
|
||||
from sam3.model.data_misc import BatchedDatapoint
|
||||
from sam3.model.model_misc import SAM3Output
|
||||
from sam3.model.utils.misc import copy_data_to_device
|
||||
|
||||
from sam3.train.optim.optimizer import construct_optimizer
|
||||
|
||||
from sam3.train.utils.checkpoint_utils import (
|
||||
assert_skipped_parameters_are_frozen,
|
||||
exclude_params_matching_unix_pattern,
|
||||
load_state_dict_into_model,
|
||||
with_check_parameter_frozen,
|
||||
)
|
||||
|
||||
from sam3.train.utils.distributed import all_reduce_max, barrier, get_rank
|
||||
|
||||
from sam3.train.utils.logger import Logger, setup_logging
|
||||
from sam3.train.utils.train_utils import (
|
||||
AverageMeter,
|
||||
@@ -215,9 +209,9 @@ class Trainer:
|
||||
set_seeds(seed_value, self.max_epochs, self.distributed_rank)
|
||||
log_env_variables()
|
||||
|
||||
assert (
|
||||
is_dist_avail_and_initialized()
|
||||
), "Torch distributed needs to be initialized before calling the trainer."
|
||||
assert is_dist_avail_and_initialized(), (
|
||||
"Torch distributed needs to be initialized before calling the trainer."
|
||||
)
|
||||
|
||||
self._setup_components() # Except Optimizer everything is setup here.
|
||||
self._move_to_device()
|
||||
@@ -227,9 +221,9 @@ class Trainer:
|
||||
self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")
|
||||
|
||||
if self.checkpoint_conf.resume_from is not None:
|
||||
assert os.path.exists(
|
||||
self.checkpoint_conf.resume_from
|
||||
), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
|
||||
assert os.path.exists(self.checkpoint_conf.resume_from), (
|
||||
f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
|
||||
)
|
||||
dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
|
||||
if self.distributed_rank == 0 and not os.path.exists(dst):
|
||||
# Copy the "resume_from" checkpoint to the checkpoint folder
|
||||
@@ -477,9 +471,9 @@ class Trainer:
|
||||
return self.loss[key]
|
||||
|
||||
assert key != "all", "Loss must be specified for key='all'"
|
||||
assert (
|
||||
"default" in self.loss
|
||||
), f"Key {key} not found in losss, and no default provided"
|
||||
assert "default" in self.loss, (
|
||||
f"Key {key} not found in losss, and no default provided"
|
||||
)
|
||||
return self.loss["default"]
|
||||
|
||||
def _find_meter(self, phase: str, key: str):
|
||||
@@ -922,12 +916,12 @@ class Trainer:
|
||||
self.optim.zero_grad(set_to_none=True)
|
||||
|
||||
if self.gradient_accumulation_steps > 1:
|
||||
assert isinstance(
|
||||
batch, list
|
||||
), f"Expected a list of batches, got {type(batch)}"
|
||||
assert (
|
||||
len(batch) == self.gradient_accumulation_steps
|
||||
), f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
|
||||
assert isinstance(batch, list), (
|
||||
f"Expected a list of batches, got {type(batch)}"
|
||||
)
|
||||
assert len(batch) == self.gradient_accumulation_steps, (
|
||||
f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
|
||||
)
|
||||
accum_steps = len(batch)
|
||||
else:
|
||||
accum_steps = 1
|
||||
@@ -1039,9 +1033,9 @@ class Trainer:
|
||||
def _check_val_key_match(self, val_keys, phase):
|
||||
if val_keys is not None:
|
||||
# Check if there are any duplicates
|
||||
assert len(val_keys) == len(
|
||||
set(val_keys)
|
||||
), f"Duplicate keys in val datasets, keys: {val_keys}"
|
||||
assert len(val_keys) == len(set(val_keys)), (
|
||||
f"Duplicate keys in val datasets, keys: {val_keys}"
|
||||
)
|
||||
|
||||
# Check that the keys match the meter keys
|
||||
if self.meters_conf is not None and phase in self.meters_conf:
|
||||
@@ -1055,9 +1049,9 @@ class Trainer:
|
||||
loss_keys = set(self.loss_conf.keys()) - set(["all"])
|
||||
if "default" not in loss_keys:
|
||||
for k in val_keys:
|
||||
assert (
|
||||
k in loss_keys
|
||||
), f"Error: key {k} is not defined in the losses, and no default is set"
|
||||
assert k in loss_keys, (
|
||||
f"Error: key {k} is not defined in the losses, and no default is set"
|
||||
)
|
||||
|
||||
def _setup_components(self):
|
||||
# Get the keys for all the val datasets, if any
|
||||
|
||||
Reference in New Issue
Block a user