apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)
Summary: Formats the covered files with pyfmt. paintitblack Reviewed By: itamaro Differential Revision: D90476315 fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
This commit is contained in:
committed by
meta-codesync[bot]
parent
7b89b8fc3f
commit
11dec2936d
@@ -142,9 +142,9 @@ class COCO_FROM_JSON:
|
||||
self.prompts = {}
|
||||
for loc_dict in prompts:
|
||||
self.prompts[int(loc_dict["id"])] = loc_dict["name"]
|
||||
assert len(self.prompts) == len(
|
||||
self._sorted_cat_ids
|
||||
), "Number of prompts must match number of categories"
|
||||
assert len(self.prompts) == len(self._sorted_cat_ids), (
|
||||
"Number of prompts must match number of categories"
|
||||
)
|
||||
|
||||
def getDatapointIds(self):
|
||||
"""Return all datapoint indices for training."""
|
||||
|
||||
@@ -6,7 +6,6 @@ from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_data
|
||||
from typing import Any, get_args, get_origin, List, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.model.data_misc import (
|
||||
BatchedDatapoint,
|
||||
BatchedFindTarget,
|
||||
@@ -217,9 +216,9 @@ def collate_fn_api(
|
||||
text_batch.append(q.query_text)
|
||||
stages[stage_id].text_ids.append(text_batch.index(q.query_text))
|
||||
|
||||
assert (
|
||||
q.inference_metadata is not None
|
||||
), "inference_metadata must be provided when FindQueryLoaded is created."
|
||||
assert q.inference_metadata is not None, (
|
||||
"inference_metadata must be provided when FindQueryLoaded is created."
|
||||
)
|
||||
for f in fields(q.inference_metadata):
|
||||
getattr(find_metadatas[stage_id], f.name).append(
|
||||
getattr(q.inference_metadata, f.name)
|
||||
|
||||
@@ -19,10 +19,8 @@ import torch.utils.data
|
||||
import torchvision
|
||||
from decord import cpu, VideoReader
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from PIL.Image import DecompressionBombError
|
||||
|
||||
from sam3.model.box_ops import box_xywh_to_xyxy
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
|
||||
@@ -234,9 +232,9 @@ class CustomCocoDetectionAPI(VisionDataset):
|
||||
if self.coco is not None:
|
||||
return
|
||||
|
||||
assert g_pathmgr.isfile(
|
||||
self.annFile
|
||||
), f"please provide valid annotation file. Missing: {self.annFile}"
|
||||
assert g_pathmgr.isfile(self.annFile), (
|
||||
f"please provide valid annotation file. Missing: {self.annFile}"
|
||||
)
|
||||
annFile = g_pathmgr.get_local_path(self.annFile)
|
||||
|
||||
if self.coco is not None:
|
||||
@@ -326,9 +324,9 @@ class CustomCocoDetectionAPI(VisionDataset):
|
||||
else:
|
||||
num_queries_per_stage = stage2num_queries.most_common(1)[0][1]
|
||||
for stage, num_queries in stage2num_queries.items():
|
||||
assert (
|
||||
num_queries == num_queries_per_stage
|
||||
), f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"
|
||||
assert num_queries == num_queries_per_stage, (
|
||||
f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"
|
||||
)
|
||||
|
||||
for query_id, query in enumerate(queries):
|
||||
h, w = id2imsize[query["image_id"]]
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
# pyre-unsafe
|
||||
|
||||
import copy
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
@@ -16,7 +15,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
# from decord import cpu, VideoReader
|
||||
|
||||
from iopath.common.file_io import PathManager
|
||||
@@ -220,9 +218,9 @@ class VideoGroundingDataset(Sam3ImageDataset):
|
||||
for query in filtered_queries:
|
||||
ptr_x_is_empty = query["ptr_x_query_id"] in [None, -1]
|
||||
ptr_y_is_empty = query["ptr_y_query_id"] in [None, -1]
|
||||
assert (
|
||||
ptr_x_is_empty and ptr_y_is_empty
|
||||
), "Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers"
|
||||
assert ptr_x_is_empty and ptr_y_is_empty, (
|
||||
"Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers"
|
||||
)
|
||||
query["query_processing_order"] = stage_id_old2new[
|
||||
query["query_processing_order"]
|
||||
]
|
||||
|
||||
@@ -9,11 +9,8 @@ import torch
|
||||
import torch.distributed
|
||||
import torch.nn.functional as F
|
||||
import torchmetrics
|
||||
|
||||
from sam3.model import box_ops
|
||||
|
||||
from sam3.model.data_misc import interpolate
|
||||
|
||||
from sam3.train.loss.sigmoid_focal_loss import (
|
||||
triton_sigmoid_focal_loss,
|
||||
triton_sigmoid_focal_loss_reduce,
|
||||
@@ -327,7 +324,9 @@ class IABCEMdetr(LossWithWeights):
|
||||
if num_det_queries is not None:
|
||||
logging.warning("note: it's not needed to set num_det_queries anymore")
|
||||
if self.use_separate_loss_for_det_and_trk:
|
||||
assert not self.weak_loss, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
||||
assert not self.weak_loss, (
|
||||
"Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
||||
)
|
||||
self.det_exhaustive_loss_scale_pos = det_exhaustive_loss_scale_pos
|
||||
self.det_exhaustive_loss_scale_neg = det_exhaustive_loss_scale_neg
|
||||
self.det_non_exhaustive_loss_scale_pos = det_non_exhaustive_loss_scale_pos
|
||||
@@ -342,7 +341,9 @@ class IABCEMdetr(LossWithWeights):
|
||||
and det_non_exhaustive_loss_scale_neg == 1.0
|
||||
and trk_loss_scale_pos == 1.0
|
||||
and trk_loss_scale_neg == 1.0
|
||||
), "If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0"
|
||||
), (
|
||||
"If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0"
|
||||
)
|
||||
|
||||
def get_loss(self, outputs, targets, indices, num_boxes):
|
||||
assert len(outputs["pred_logits"].shape) > 2, "Incorrect predicted logits shape"
|
||||
@@ -443,7 +444,9 @@ class IABCEMdetr(LossWithWeights):
|
||||
pass
|
||||
|
||||
if self.weak_loss:
|
||||
assert not self.use_separate_loss_for_det_and_trk, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
||||
assert not self.use_separate_loss_for_det_and_trk, (
|
||||
"Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
||||
)
|
||||
|
||||
# nullify the negative loss for the non-exhaustive classes
|
||||
assert loss_bce.shape[0] == targets["is_exhaustive"].shape[0]
|
||||
@@ -497,9 +500,9 @@ class IABCEMdetr(LossWithWeights):
|
||||
loss_bce = loss_bce.mean()
|
||||
else:
|
||||
assert isinstance(self.pad_n_queries, int)
|
||||
assert (
|
||||
loss_bce.size(1) < self.pad_n_queries
|
||||
), f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions."
|
||||
assert loss_bce.size(1) < self.pad_n_queries, (
|
||||
f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions."
|
||||
)
|
||||
loss_bce = loss_bce.sum() / (self.pad_n_queries * loss_bce.size(0))
|
||||
|
||||
bce_f1 = torchmetrics.functional.f1_score(
|
||||
|
||||
@@ -3,9 +3,7 @@
|
||||
# pyre-unsafe
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.model.model_misc import SAM3Output
|
||||
|
||||
from sam3.train.utils.distributed import get_world_size
|
||||
|
||||
from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks
|
||||
|
||||
@@ -103,9 +103,9 @@ def dilation(mask, kernel_size):
|
||||
|
||||
assert mask.ndim == 3
|
||||
kernel_size = int(kernel_size)
|
||||
assert (
|
||||
kernel_size % 2 == 1
|
||||
), f"Dilation expects a odd kernel size, got {kernel_size}"
|
||||
assert kernel_size % 2 == 1, (
|
||||
f"Dilation expects a odd kernel size, got {kernel_size}"
|
||||
)
|
||||
|
||||
if mask.is_cuda:
|
||||
m = mask.unsqueeze(1).to(torch.float16)
|
||||
|
||||
@@ -8,7 +8,6 @@ Modules to compute the matching cost and solve the corresponding LSAP.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sam3.model.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
from torch import nn
|
||||
@@ -60,9 +59,9 @@ class HungarianMatcher(nn.Module):
|
||||
self.cost_bbox = cost_bbox
|
||||
self.cost_giou = cost_giou
|
||||
self.norm = nn.Sigmoid() if focal_loss else nn.Softmax(-1)
|
||||
assert (
|
||||
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
|
||||
), "all costs cant be 0"
|
||||
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
|
||||
"all costs cant be 0"
|
||||
)
|
||||
self.focal_loss = focal_loss
|
||||
self.focal_alpha = focal_alpha
|
||||
self.focal_gamma = focal_gamma
|
||||
@@ -197,9 +196,9 @@ class BinaryHungarianMatcher(nn.Module):
|
||||
self.cost_bbox = cost_bbox
|
||||
self.cost_giou = cost_giou
|
||||
self.norm = nn.Sigmoid()
|
||||
assert (
|
||||
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
|
||||
), "all costs cant be 0"
|
||||
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
|
||||
"all costs cant be 0"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, batched_targets, repeats=0, repeat_batch=1):
|
||||
@@ -322,9 +321,9 @@ class BinaryFocalHungarianMatcher(nn.Module):
|
||||
self.alpha = alpha
|
||||
self.gamma = gamma
|
||||
self.stable = stable
|
||||
assert (
|
||||
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
|
||||
), "all costs cant be 0"
|
||||
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
|
||||
"all costs cant be 0"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, batched_targets, repeats=1, repeat_batch=1):
|
||||
@@ -470,9 +469,9 @@ class BinaryHungarianMatcherV2(nn.Module):
|
||||
self.cost_bbox = cost_bbox
|
||||
self.cost_giou = cost_giou
|
||||
self.norm = nn.Sigmoid()
|
||||
assert (
|
||||
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
|
||||
), "all costs cant be 0"
|
||||
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
|
||||
"all costs cant be 0"
|
||||
)
|
||||
self.focal = focal
|
||||
if focal:
|
||||
self.alpha = alpha
|
||||
|
||||
@@ -22,7 +22,6 @@ from typing import (
|
||||
)
|
||||
|
||||
import hydra
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import DictConfig
|
||||
@@ -212,9 +211,9 @@ def unix_module_cls_pattern_to_parameter_names(
|
||||
"match any classes in the model"
|
||||
)
|
||||
matching_parameters = module_cls_to_param_names[module_cls]
|
||||
assert (
|
||||
len(matching_parameters) > 0
|
||||
), f"module_cls_name {module_cls_name} does not contain any parameters in the model"
|
||||
assert len(matching_parameters) > 0, (
|
||||
f"module_cls_name {module_cls_name} does not contain any parameters in the model"
|
||||
)
|
||||
logging.info(
|
||||
f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
|
||||
)
|
||||
@@ -240,9 +239,9 @@ def unix_param_pattern_to_parameter_names(
|
||||
allowed_parameter_names = []
|
||||
for param_name in filter_param_names:
|
||||
matching_parameters = set(fnmatch.filter(parameter_names, param_name))
|
||||
assert (
|
||||
len(matching_parameters) >= 1
|
||||
), f"param_name {param_name} does not match any parameters in the model"
|
||||
assert len(matching_parameters) >= 1, (
|
||||
f"param_name {param_name} does not match any parameters in the model"
|
||||
)
|
||||
logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
|
||||
allowed_parameter_names.append(matching_parameters)
|
||||
return set.union(*allowed_parameter_names)
|
||||
|
||||
@@ -12,13 +12,10 @@ from copy import deepcopy
|
||||
|
||||
import submitit
|
||||
import torch
|
||||
|
||||
from hydra import compose, initialize_config_module
|
||||
from hydra.utils import instantiate
|
||||
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from sam3.train.utils.train_utils import makedir, register_omegaconf_resolvers
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -212,9 +209,9 @@ def main(args) -> None:
|
||||
},
|
||||
}
|
||||
if "include_nodes" in submitit_conf:
|
||||
assert (
|
||||
len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes
|
||||
), "Not enough nodes"
|
||||
assert len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes, (
|
||||
"Not enough nodes"
|
||||
)
|
||||
job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(
|
||||
submitit_conf["include_nodes"]
|
||||
)
|
||||
|
||||
@@ -15,28 +15,22 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from hydra.utils import instantiate
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
|
||||
from sam3.model.data_misc import BatchedDatapoint
|
||||
from sam3.model.model_misc import SAM3Output
|
||||
from sam3.model.utils.misc import copy_data_to_device
|
||||
|
||||
from sam3.train.optim.optimizer import construct_optimizer
|
||||
|
||||
from sam3.train.utils.checkpoint_utils import (
|
||||
assert_skipped_parameters_are_frozen,
|
||||
exclude_params_matching_unix_pattern,
|
||||
load_state_dict_into_model,
|
||||
with_check_parameter_frozen,
|
||||
)
|
||||
|
||||
from sam3.train.utils.distributed import all_reduce_max, barrier, get_rank
|
||||
|
||||
from sam3.train.utils.logger import Logger, setup_logging
|
||||
from sam3.train.utils.train_utils import (
|
||||
AverageMeter,
|
||||
@@ -215,9 +209,9 @@ class Trainer:
|
||||
set_seeds(seed_value, self.max_epochs, self.distributed_rank)
|
||||
log_env_variables()
|
||||
|
||||
assert (
|
||||
is_dist_avail_and_initialized()
|
||||
), "Torch distributed needs to be initialized before calling the trainer."
|
||||
assert is_dist_avail_and_initialized(), (
|
||||
"Torch distributed needs to be initialized before calling the trainer."
|
||||
)
|
||||
|
||||
self._setup_components() # Except Optimizer everything is setup here.
|
||||
self._move_to_device()
|
||||
@@ -227,9 +221,9 @@ class Trainer:
|
||||
self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")
|
||||
|
||||
if self.checkpoint_conf.resume_from is not None:
|
||||
assert os.path.exists(
|
||||
self.checkpoint_conf.resume_from
|
||||
), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
|
||||
assert os.path.exists(self.checkpoint_conf.resume_from), (
|
||||
f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
|
||||
)
|
||||
dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
|
||||
if self.distributed_rank == 0 and not os.path.exists(dst):
|
||||
# Copy the "resume_from" checkpoint to the checkpoint folder
|
||||
@@ -477,9 +471,9 @@ class Trainer:
|
||||
return self.loss[key]
|
||||
|
||||
assert key != "all", "Loss must be specified for key='all'"
|
||||
assert (
|
||||
"default" in self.loss
|
||||
), f"Key {key} not found in losss, and no default provided"
|
||||
assert "default" in self.loss, (
|
||||
f"Key {key} not found in losss, and no default provided"
|
||||
)
|
||||
return self.loss["default"]
|
||||
|
||||
def _find_meter(self, phase: str, key: str):
|
||||
@@ -922,12 +916,12 @@ class Trainer:
|
||||
self.optim.zero_grad(set_to_none=True)
|
||||
|
||||
if self.gradient_accumulation_steps > 1:
|
||||
assert isinstance(
|
||||
batch, list
|
||||
), f"Expected a list of batches, got {type(batch)}"
|
||||
assert (
|
||||
len(batch) == self.gradient_accumulation_steps
|
||||
), f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
|
||||
assert isinstance(batch, list), (
|
||||
f"Expected a list of batches, got {type(batch)}"
|
||||
)
|
||||
assert len(batch) == self.gradient_accumulation_steps, (
|
||||
f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
|
||||
)
|
||||
accum_steps = len(batch)
|
||||
else:
|
||||
accum_steps = 1
|
||||
@@ -1039,9 +1033,9 @@ class Trainer:
|
||||
def _check_val_key_match(self, val_keys, phase):
|
||||
if val_keys is not None:
|
||||
# Check if there are any duplicates
|
||||
assert len(val_keys) == len(
|
||||
set(val_keys)
|
||||
), f"Duplicate keys in val datasets, keys: {val_keys}"
|
||||
assert len(val_keys) == len(set(val_keys)), (
|
||||
f"Duplicate keys in val datasets, keys: {val_keys}"
|
||||
)
|
||||
|
||||
# Check that the keys match the meter keys
|
||||
if self.meters_conf is not None and phase in self.meters_conf:
|
||||
@@ -1055,9 +1049,9 @@ class Trainer:
|
||||
loss_keys = set(self.loss_conf.keys()) - set(["all"])
|
||||
if "default" not in loss_keys:
|
||||
for k in val_keys:
|
||||
assert (
|
||||
k in loss_keys
|
||||
), f"Error: key {k} is not defined in the losses, and no default is set"
|
||||
assert k in loss_keys, (
|
||||
f"Error: key {k} is not defined in the losses, and no default is set"
|
||||
)
|
||||
|
||||
def _setup_components(self):
|
||||
# Get the keys for all the val datasets, if any
|
||||
|
||||
@@ -14,7 +14,6 @@ import PIL
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
from sam3.model.box_ops import box_xyxy_to_cxcywh
|
||||
from sam3.model.data_misc import interpolate
|
||||
|
||||
@@ -277,9 +276,9 @@ class RandomSizeCrop:
|
||||
max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1))
|
||||
)
|
||||
result_img, result_target = crop(img, target, [j, i, h, w])
|
||||
assert (
|
||||
len(result_target["boxes"]) == init_boxes
|
||||
), f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}"
|
||||
assert len(result_target["boxes"]) == init_boxes, (
|
||||
f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}"
|
||||
)
|
||||
|
||||
return result_img, result_target
|
||||
else:
|
||||
|
||||
@@ -7,7 +7,6 @@ Transforms and data augmentation for both image + bbox.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import numbers
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
@@ -17,9 +16,7 @@ import torch
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as F
|
||||
import torchvision.transforms.v2.functional as Fv2
|
||||
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from sam3.model.box_ops import box_xyxy_to_cxcywh, masks_to_boxes
|
||||
from sam3.train.data.sam3_image_dataset import Datapoint
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
@@ -4,12 +4,10 @@
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.train.data.sam3_image_dataset import Datapoint, FindQuery, Object
|
||||
|
||||
|
||||
@@ -381,9 +379,9 @@ class FlexibleFilterFindGetQueries:
|
||||
if len(new_find_queries) == 0:
|
||||
start_with_zero_check = True
|
||||
|
||||
assert (
|
||||
start_with_zero_check
|
||||
), "Invalid Find queries, they need to start at query_processing_order = 0"
|
||||
assert start_with_zero_check, (
|
||||
"Invalid Find queries, they need to start at query_processing_order = 0"
|
||||
)
|
||||
|
||||
datapoint.find_queries = new_find_queries
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import numpy as np
|
||||
import torch
|
||||
from PIL import Image as PILImage
|
||||
from pycocotools import mask as mask_util
|
||||
|
||||
from sam3.train.data.sam3_image_dataset import Datapoint
|
||||
from torchvision.ops import masks_to_boxes
|
||||
|
||||
@@ -250,9 +249,9 @@ class RandomGeometricInputsAPI:
|
||||
def _get_target_object(self, datapoint, query):
|
||||
img = datapoint.images[query.image_id]
|
||||
targets = query.object_ids_output
|
||||
assert (
|
||||
len(targets) == 1
|
||||
), "Geometric queries only support a single target object."
|
||||
assert len(targets) == 1, (
|
||||
"Geometric queries only support a single target object."
|
||||
)
|
||||
target_idx = targets[0]
|
||||
return img.objects[target_idx]
|
||||
|
||||
|
||||
@@ -5,12 +5,9 @@
|
||||
import numpy as np
|
||||
import pycocotools.mask as mask_utils
|
||||
import torch
|
||||
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from sam3.model.box_ops import masks_to_boxes
|
||||
|
||||
from sam3.train.data.sam3_image_dataset import Datapoint
|
||||
|
||||
|
||||
|
||||
@@ -36,9 +36,9 @@ def unix_pattern_to_parameter_names(
|
||||
parameter_names = []
|
||||
for param_name in constraints:
|
||||
matching_parameters = set(fnmatch.filter(all_parameter_names, param_name))
|
||||
assert (
|
||||
len(matching_parameters) > 0
|
||||
), f"param_names {param_name} don't match any param in the given names."
|
||||
assert len(matching_parameters) > 0, (
|
||||
f"param_names {param_name} don't match any param in the given names."
|
||||
)
|
||||
parameter_names.append(matching_parameters)
|
||||
return set.union(*parameter_names)
|
||||
|
||||
|
||||
@@ -10,10 +10,8 @@ import uuid
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from hydra.utils import instantiate
|
||||
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from numpy import ndarray
|
||||
|
||||
from sam3.train.utils.train_utils import get_machine_local_and_dist_rank, makedir
|
||||
from torch import Tensor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@@ -11,7 +11,6 @@ from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
import hydra
|
||||
|
||||
import numpy as np
|
||||
import omegaconf
|
||||
import torch
|
||||
@@ -83,9 +82,9 @@ def get_machine_local_and_dist_rank():
|
||||
"""
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", None))
|
||||
distributed_rank = int(os.environ.get("RANK", None))
|
||||
assert (
|
||||
local_rank is not None and distributed_rank is not None
|
||||
), "Please the set the RANK and LOCAL_RANK environment variables."
|
||||
assert local_rank is not None and distributed_rank is not None, (
|
||||
"Please the set the RANK and LOCAL_RANK environment variables."
|
||||
)
|
||||
return local_rank, distributed_rank
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user