apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)

Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: itamaro

Differential Revision: D90476315

fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
This commit is contained in:
Bowie Chen
2026-01-11 23:16:49 -08:00
committed by meta-codesync[bot]
parent 7b89b8fc3f
commit 11dec2936d
69 changed files with 445 additions and 522 deletions

View File

@@ -296,9 +296,9 @@ def agent_inference(
assert LATEST_SAM3_TEXT_PROMPT != ""
# Make sure that the last message is a image
assert (
messages[-1]["content"][1]["type"] == "image"
), "Second content element should be an image"
assert messages[-1]["content"][1]["type"] == "image", (
"Second content element should be an image"
)
messages.pop() # Remove the last user message
# Add simplified replacement message
simplified_message = {
@@ -318,7 +318,7 @@ def agent_inference(
# MLLM check the mask one by one
for i in range(num_masks):
print(f"🔍 Checking mask {i+1}/{num_masks}...")
print(f"🔍 Checking mask {i + 1}/{num_masks}...")
image_w_mask_i, image_w_zoomed_in_mask_i = visualize(current_outputs, i)
image_w_zoomed_in_mask_i_path = os.path.join(
@@ -363,7 +363,7 @@ def agent_inference(
raise ValueError(
"Generated text is None, which is unexpected. Please check the Qwen server and the input parameters."
)
print(f"Generated text for mask {i+1}: {checking_generated_text}")
print(f"Generated text for mask {i + 1}: {checking_generated_text}")
verdict = (
checking_generated_text.split("<verdict>")[-1]
.split("</verdict>")[0]
@@ -371,11 +371,11 @@ def agent_inference(
)
if "Accept" in verdict:
assert not "Reject" in verdict
print(f"Mask {i+1} accepted, keeping it in the outputs.")
print(f"Mask {i + 1} accepted, keeping it in the outputs.")
masks_to_keep.append(i)
elif "Reject" in verdict:
assert not "Accept" in verdict
print(f"Mask {i+1} rejected, removing it from the outputs.")
print(f"Mask {i + 1} rejected, removing it from the outputs.")
else:
raise ValueError(
f"Unexpected verdict in generated text: {checking_generated_text}. Expected 'Accept' or 'Reject'."
@@ -397,7 +397,7 @@ def agent_inference(
sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png"
).replace(
".png",
f"_selected_masks_{'-'.join(map(str, [i+1 for i in masks_to_keep]))}.png".replace(
f"_selected_masks_{'-'.join(map(str, [i + 1 for i in masks_to_keep]))}.png".replace(
"/", "_"
),
)

View File

@@ -7,7 +7,6 @@ import os
import torch
from PIL import Image
from sam3.model.box_ops import box_xyxy_to_xywh
from sam3.train.masks_ops import rle_encode

View File

@@ -84,9 +84,9 @@ class BoxMode(IntEnum):
], "Relative mode not yet supported!"
if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
assert (
arr.shape[-1] == 5
), "The last dimension of input shape must be 5 for XYWHA format"
assert arr.shape[-1] == 5, (
"The last dimension of input shape must be 5 for XYWHA format"
)
original_dtype = arr.dtype
arr = arr.double()
@@ -244,9 +244,9 @@ class Boxes:
if isinstance(item, int):
return Boxes(self.tensor[item].view(1, -1))
b = self.tensor[item]
assert (
b.dim() == 2
), "Indexing on Boxes with {} failed to return a matrix!".format(item)
assert b.dim() == 2, (
"Indexing on Boxes with {} failed to return a matrix!".format(item)
)
return Boxes(b)
def __len__(self) -> int:
@@ -425,7 +425,7 @@ def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
Tensor: iou, sized [N].
"""
assert len(boxes1) == len(boxes2), (
"boxlists should have the same" "number of entries, got {}, {}".format(
"boxlists should have the samenumber of entries, got {}, {}".format(
len(boxes1), len(boxes2)
)
)

View File

@@ -13,7 +13,6 @@ from torch import device
from .boxes import Boxes
from .memory import retry_if_cuda_oom
from .roi_align import ROIAlign
@@ -142,10 +141,10 @@ class BitMasks:
if isinstance(item, int):
return BitMasks(self.tensor[item].unsqueeze(0))
m = self.tensor[item]
assert (
m.dim() == 3
), "Indexing on BitMasks with {} returns a tensor with shape {}!".format(
item, m.shape
assert m.dim() == 3, (
"Indexing on BitMasks with {} returns a tensor with shape {}!".format(
item, m.shape
)
)
return BitMasks(m)

View File

@@ -363,9 +363,9 @@ class RotatedBoxes(Boxes):
if isinstance(item, int):
return RotatedBoxes(self.tensor[item].view(1, -1))
b = self.tensor[item]
assert (
b.dim() == 2
), "Indexing on RotatedBoxes with {} failed to return a matrix!".format(item)
assert b.dim() == 2, (
"Indexing on RotatedBoxes with {} failed to return a matrix!".format(item)
)
return RotatedBoxes(b)
def __len__(self) -> int:

View File

@@ -20,7 +20,6 @@ from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image
from .boxes import Boxes, BoxMode
from .color_map import random_color
from .keypoints import Keypoints
from .masks import BitMasks, PolygonMasks
@@ -222,9 +221,9 @@ class _PanopticPrediction:
empty_ids.append(id)
if len(empty_ids) == 0:
return np.zeros(self._seg.shape, dtype=np.uint8)
assert (
len(empty_ids) == 1
), ">1 ids corresponds to no labels. This is currently not supported"
assert len(empty_ids) == 1, (
">1 ids corresponds to no labels. This is currently not supported"
)
return (self._seg != empty_ids[0]).numpy().astype(np.bool)
def semantic_masks(self):

View File

@@ -41,7 +41,7 @@ def run_single_image_inference(
print(f"Output JSON {output_json_path} already exists. Skipping.")
return
print(f"{'-'*30} Starting SAM 3 Agent Session... {'-'*30} ")
print(f"{'-' * 30} Starting SAM 3 Agent Session... {'-' * 30} ")
agent_history, final_output_dict, rendered_final_output = agent_inference(
image_path,
text_prompt,
@@ -50,7 +50,7 @@ def run_single_image_inference(
output_dir=output_dir,
debug=debug,
)
print(f"{'-'*30} End of SAM 3 Agent Session... {'-'*30} ")
print(f"{'-' * 30} End of SAM 3 Agent Session... {'-' * 30} ")
final_output_dict["text_prompt"] = text_prompt
final_output_dict["image_path"] = image_path

View File

@@ -73,7 +73,9 @@ def visualize(
idx = int(zoom_in_index)
num_masks = len(input_json.get("pred_masks", []))
if idx < 0 or idx >= num_masks:
raise ValueError(f"zoom_in_index {idx} is out of range (0..{num_masks-1}).")
raise ValueError(
f"zoom_in_index {idx} is out of range (0..{num_masks - 1})."
)
# (1) Replicate zoom_in_and_visualize
object_data = {

View File

@@ -126,9 +126,9 @@ class COCOCustom(COCO):
# MODIFICATION: faster and cached subset check
if not hasattr(self, "img_id_set"):
self.img_id_set = set(self.getImgIds())
assert set(annsImgIds).issubset(
self.img_id_set
), "Results do not correspond to current coco set"
assert set(annsImgIds).issubset(self.img_id_set), (
"Results do not correspond to current coco set"
)
# END MODIFICATION
if "caption" in anns[0]:
imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
@@ -301,9 +301,9 @@ class CGF1Eval(COCOeval):
TP = (match_scores >= thresh).sum()
FP = len(dt) - TP
FN = len(gt) - TP
assert (
FP >= 0 and FN >= 0
), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
assert FP >= 0 and FN >= 0, (
f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
)
TPs.append(TP)
FPs.append(FP)
FNs.append(FN)
@@ -599,9 +599,9 @@ class CGF1Evaluator:
"""
assert len(self.coco_gts) > 0, "No ground truth provided for evaluation."
assert len(self.coco_gts) == len(
self.coco_evals
), "Mismatch in number of ground truths and evaluators."
assert len(self.coco_gts) == len(self.coco_evals), (
"Mismatch in number of ground truths and evaluators."
)
if self.verbose:
print(f"Loading predictions from {pred_file}")
@@ -668,17 +668,17 @@ class CGF1Evaluator:
if len(scorings) == 1:
return scorings[0]
assert (
scorings[0].ndim == 3
), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
assert (
scorings[0].shape[0] == 1
), f"Expecting a single category, got {scorings[0].shape[0]}"
assert scorings[0].ndim == 3, (
f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
)
assert scorings[0].shape[0] == 1, (
f"Expecting a single category, got {scorings[0].shape[0]}"
)
for scoring in scorings:
assert (
scoring.shape == scorings[0].shape
), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
assert scoring.shape == scorings[0].shape, (
f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
)
selected_imgs = []
for img_id in range(scorings[0].shape[-1]):

View File

@@ -18,19 +18,15 @@ import os
import pickle
from collections import defaultdict
from pathlib import Path
from typing import Any, List, Optional
import numpy as np
import pycocotools.mask as mask_utils
import torch
from iopath.common.file_io import g_pathmgr
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from sam3.train.masks_ops import rle_encode
from sam3.train.utils.distributed import (
all_gather,
gather_to_rank_0_via_filesys,
@@ -755,9 +751,9 @@ def loadRes(self, resFile):
anns = resFile
assert type(anns) == list, "results in not an array of objects"
annsImgIds = [ann["image_id"] for ann in anns]
assert set(annsImgIds) == (
set(annsImgIds) & set(self.getImgIds())
), "Results do not correspond to current coco set"
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), (
"Results do not correspond to current coco set"
)
if "caption" in anns[0]:
imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
[ann["image_id"] for ann in anns]

View File

@@ -83,9 +83,9 @@ class PredictionDumper:
self.merge_predictions = merge_predictions
self.pred_file_evaluators = pred_file_evaluators
if self.pred_file_evaluators is not None:
assert (
merge_predictions
), "merge_predictions must be True if pred_file_evaluators are provided"
assert merge_predictions, (
"merge_predictions must be True if pred_file_evaluators are provided"
)
assert self.dump_dir is not None, "dump_dir must be provided"
if is_main_process():

View File

@@ -13,11 +13,9 @@ from typing import Optional
import numpy as np
import pycocotools.mask as maskUtils
from pycocotools.cocoeval import COCOeval
from sam3.eval.coco_eval import CocoEvaluator
from sam3.train.masks_ops import compute_F_measure
from sam3.train.utils.distributed import is_main_process
from scipy.optimize import linear_sum_assignment
@@ -156,9 +154,9 @@ class DemoEval(COCOeval):
TP = (match_scores >= thresh).sum()
FP = len(dt) - TP
FN = len(gt) - TP
assert (
FP >= 0 and FN >= 0
), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
assert FP >= 0 and FN >= 0, (
f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
)
TPs.append(TP)
FPs.append(FP)
FNs.append(FN)
@@ -528,17 +526,17 @@ class DemoEvaluator(CocoEvaluator):
if len(scorings) == 1:
return scorings[0]
assert (
scorings[0].ndim == 3
), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
assert (
scorings[0].shape[0] == 1
), f"Expecting a single category, got {scorings[0].shape[0]}"
assert scorings[0].ndim == 3, (
f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
)
assert scorings[0].shape[0] == 1, (
f"Expecting a single category, got {scorings[0].shape[0]}"
)
for scoring in scorings:
assert (
scoring.shape == scorings[0].shape
), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
assert scoring.shape == scorings[0].shape, (
f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
)
selected_imgs = []
for img_id in range(scorings[0].shape[-1]):

View File

@@ -255,9 +255,10 @@ class Evaluator:
if show_progressbar and TQDM_IMPORTED:
seq_list_sorted = sorted(seq_list)
with Pool(config["NUM_PARALLEL_CORES"]) as pool, tqdm.tqdm(
total=len(seq_list)
) as pbar:
with (
Pool(config["NUM_PARALLEL_CORES"]) as pool,
tqdm.tqdm(total=len(seq_list)) as pbar,
):
_eval_sequence = partial(
eval_sequence,
dataset=dataset,

View File

@@ -83,9 +83,9 @@ class PostProcessImage(nn.Module):
ret_tensordict: Experimental argument. If true, return a tensordict.TensorDict instead of a list of dictionaries for easier manipulation.
"""
if ret_tensordict:
assert (
consistent is True
), "We don't support returning TensorDict if the outputs have different shapes" # NOTE: It's possible but we don't support it.
assert consistent is True, (
"We don't support returning TensorDict if the outputs have different shapes"
) # NOTE: It's possible but we don't support it.
assert self.detection_threshold <= 0.0, "TODO: implement?"
try:
from tensordict import TensorDict
@@ -118,7 +118,9 @@ class PostProcessImage(nn.Module):
if boxes is None:
assert out_masks is not None
assert not ret_tensordict, "We don't support returning TensorDict if the output does not contain boxes"
assert not ret_tensordict, (
"We don't support returning TensorDict if the output does not contain boxes"
)
B = len(out_masks)
boxes = [None] * B
scores = [None] * B
@@ -418,9 +420,9 @@ class PostProcessAPIVideo(PostProcessImage):
if video_id == -1:
video_id = unique_vid_id.item()
else:
assert (
video_id == unique_vid_id.item()
), "We can only postprocess one video per datapoint"
assert video_id == unique_vid_id.item(), (
"We can only postprocess one video per datapoint"
)
# keeping track of which objects appear in the current frame
obj_ids_per_frame = frame_outs["pred_object_ids"]
assert obj_ids_per_frame.size(-1) == frame_outs["pred_logits"].size(-2)

View File

@@ -95,9 +95,9 @@ class YTVIS(COCO):
anns = resFile
assert type(anns) == list, "results is not an array of objects"
annsImgIds = [ann["image_id"] for ann in anns]
assert set(annsImgIds) == (
set(annsImgIds) & set(self.getImgIds())
), "Results do not correspond to current coco set"
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), (
"Results do not correspond to current coco set"
)
if "bboxes" in anns[0] and not anns[0]["bboxes"] == []:
res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
for id, ann in enumerate(anns):

View File

@@ -109,9 +109,7 @@ class YTVISevalMixin:
) # Num preds x Num GTS x Num frames
inter = inter.sum(-1)
union = union.sum(-1)
assert (
union > 0
).all(), (
assert (union > 0).all(), (
"There exists a tracklet with zero GTs across time. This is suspicious"
)
return inter / union
@@ -136,9 +134,9 @@ class YTVISevalMixin:
iou = inter / union
assert iou >= 0 and iou <= 1, "Encountered an error in IoU computation"
else:
assert np.isclose(inter, 0) and np.isclose(
union, 0
), "Encountered an error in IoU computation"
assert np.isclose(inter, 0) and np.isclose(union, 0), (
"Encountered an error in IoU computation"
)
iou = 1
return iou
@@ -206,16 +204,16 @@ class YTVISResultsWriter:
if len(prediction) == 0:
continue
for k in ["boxes", "scores", "labels"]:
assert (
k in prediction
), f"Expected predictions to have `{k}` key, available keys are {prediction.keys()}"
assert k in prediction, (
f"Expected predictions to have `{k}` key, available keys are {prediction.keys()}"
)
if self.save_per_frame_scores:
assert (
"per_frame_scores" in prediction
), f"Expected predictions to have `per_frame_scores` key, available keys are {prediction.keys()}"
assert xor(
"masks" in prediction, "masks_rle" in prediction
), f"Expected predictions to have either `masks` key or `masks_rle` key, available keys are {prediction.keys()}"
assert "per_frame_scores" in prediction, (
f"Expected predictions to have `per_frame_scores` key, available keys are {prediction.keys()}"
)
assert xor("masks" in prediction, "masks_rle" in prediction), (
f"Expected predictions to have either `masks` key or `masks_rle` key, available keys are {prediction.keys()}"
)
boxes = prediction["boxes"]
boxes = convert_to_xywh(boxes).tolist()
@@ -223,9 +221,9 @@ class YTVISResultsWriter:
labels = prediction["labels"].tolist()
if "masks" in prediction:
masks = prediction["masks"].squeeze(2)
assert (
masks.ndim == 4
), "Expected masks to be of shape(N_preds,T_frames,H,W)"
assert masks.ndim == 4, (
"Expected masks to be of shape(N_preds,T_frames,H,W)"
)
areas = [mask.flatten(1).sum(1).tolist() for mask in masks]
rles = [rle_encode(masklet) for masklet in masks]

View File

@@ -42,9 +42,9 @@ def get_logger(name, level=logging.INFO):
"""A command line logger."""
if "LOG_LEVEL" in os.environ:
level = os.environ["LOG_LEVEL"].upper()
assert (
level in LOG_LEVELS
), f"Invalid LOG_LEVEL: {level}, must be one of {list(LOG_LEVELS.keys())}"
assert level in LOG_LEVELS, (
f"Invalid LOG_LEVEL: {level}, must be one of {list(LOG_LEVELS.keys())}"
)
level = LOG_LEVELS[level]
logger = logging.getLogger(name)
logger.setLevel(level)

View File

@@ -7,7 +7,6 @@ Misc functions, including distributed helpers.
import collections
import re
from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union
@@ -29,9 +28,9 @@ def interpolate(
input, size, scale_factor, mode, align_corners
)
assert (
input.shape[0] != 0 or input.shape[1] != 0
), "At least one of the two first dimensions must be non zero"
assert input.shape[0] != 0 or input.shape[1] != 0, (
"At least one of the two first dimensions must be non zero"
)
if input.shape[1] == 0:
# Pytorch doesn't support null dimension on the channel dimension, so we transpose to fake a null batch dim

View File

@@ -9,18 +9,13 @@ Inspired from Pytorch's version, adds the pre-norm variant
from typing import Any, Dict, List, Optional
import numpy as np
import torch
from sam3.sam.transformer import RoPEAttention
from torch import nn, Tensor
from torchvision.ops.roi_align import RoIAlign
from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .model_misc import (
gen_sineembed_for_position,
get_activation_fn,
@@ -444,9 +439,9 @@ class TransformerDecoder(nn.Module):
- valid_ratios/spatial_shapes: bs, nlevel, 2
"""
if memory_mask is not None:
assert (
self.boxRPB == "none"
), "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
assert self.boxRPB == "none", (
"inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
)
apply_dac = apply_dac if apply_dac is not None else self.dac
if apply_dac:
@@ -516,18 +511,18 @@ class TransformerDecoder(nn.Module):
query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model
if self.boxRPB != "none" and reference_boxes is not None:
assert (
spatial_shapes.shape[0] == 1
), "only single scale support implemented"
assert spatial_shapes.shape[0] == 1, (
"only single scale support implemented"
)
memory_mask = self._get_rpb_matrix(
reference_boxes,
(spatial_shapes[0, 0], spatial_shapes[0, 1]),
)
memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W)
if self.training:
assert (
self.use_act_checkpoint
), "Activation checkpointing not enabled in the decoder"
assert self.use_act_checkpoint, (
"Activation checkpointing not enabled in the decoder"
)
output, presence_out = activation_ckpt_wrapper(layer)(
tgt=output,
tgt_query_pos=query_pos,
@@ -676,9 +671,9 @@ class TransformerEncoderCrossAttention(nn.Module):
src_pos[0],
)
assert (
src.shape[1] == prompt.shape[1]
), "Batch size must be the same for src and prompt"
assert src.shape[1] == prompt.shape[1], (
"Batch size must be the same for src and prompt"
)
output = src

View File

@@ -322,9 +322,9 @@ class TransformerEncoder(nn.Module):
return reference_points
def _prepare_multilevel_features(self, srcs, masks, pos_embeds):
assert (
len(srcs) == self.num_feature_levels
), "mismatch between expected and received # of feature levels"
assert len(srcs) == self.num_feature_levels, (
"mismatch between expected and received # of feature levels"
)
src_flatten = []
mask_flatten = []
@@ -406,9 +406,9 @@ class TransformerEncoder(nn.Module):
- spatial_shapes: Spatial dimensions of each feature level
- valid_ratios: Valid ratios for each feature level
"""
assert (
len(src) == self.num_feature_levels
), "must be equal to num_feature_levels"
assert len(src) == self.num_feature_levels, (
"must be equal to num_feature_levels"
)
if src_key_padding_masks is not None:
assert len(src_key_padding_masks) == self.num_feature_levels
if pos is not None:
@@ -538,9 +538,9 @@ class TransformerEncoderFusion(TransformerEncoder):
else None
)
else:
assert all(
x.dim == 4 for x in src
), "expected list of (bs, c, h, w) tensors"
assert all(x.dim == 4 for x in src), (
"expected list of (bs, c, h, w) tensors"
)
if self.add_pooled_text_to_img_feat:
# Fusion: Add mean pooled text to image features

View File

@@ -11,7 +11,6 @@ from typing_extensions import override
from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .model_misc import get_clones
@@ -148,54 +147,42 @@ class Prompt:
)
# Dimension checks
assert (
box_embeddings is not None
and list(box_embeddings.shape[:2])
== [
box_seq_len,
bs,
]
), f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
assert (
box_mask is not None
and list(box_mask.shape)
== [
bs,
box_seq_len,
]
), f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
assert (
point_embeddings is not None
and list(point_embeddings.shape[:2])
== [
point_seq_len,
bs,
]
), f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
assert (
point_mask is not None
and list(point_mask.shape)
== [
bs,
point_seq_len,
]
), f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
assert (
box_labels is not None
and list(box_labels.shape)
== [
box_seq_len,
bs,
]
), f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
assert (
point_labels is not None
and list(point_labels.shape)
== [
point_seq_len,
bs,
]
), f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
assert box_embeddings is not None and list(box_embeddings.shape[:2]) == [
box_seq_len,
bs,
], (
f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
)
assert box_mask is not None and list(box_mask.shape) == [
bs,
box_seq_len,
], (
f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
)
assert point_embeddings is not None and list(point_embeddings.shape[:2]) == [
point_seq_len,
bs,
], (
f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
)
assert point_mask is not None and list(point_mask.shape) == [
bs,
point_seq_len,
], (
f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
)
assert box_labels is not None and list(box_labels.shape) == [
box_seq_len,
bs,
], (
f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
)
assert point_labels is not None and list(point_labels.shape) == [
point_seq_len,
bs,
], (
f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
)
assert (
# Allowed to be None, we leave it to the encoder to check for validity before encoding.
mask_embeddings is None
@@ -204,41 +191,41 @@ class Prompt:
mask_seq_len,
bs,
]
), f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
assert (
mask_mask is None
or list(mask_mask.shape)
== [
bs,
mask_seq_len,
]
), f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
), (
f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
)
assert mask_mask is None or list(mask_mask.shape) == [
bs,
mask_seq_len,
], (
f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
)
# Device checks
assert (
box_embeddings is not None and box_embeddings.device == device
), f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
assert (
box_mask is not None and box_mask.device == device
), f"Expected box mask to be on device {device}, got {box_mask.device}"
assert (
box_labels is not None and box_labels.device == device
), f"Expected box labels to be on device {device}, got {box_labels.device}"
assert (
point_embeddings is not None and point_embeddings.device == device
), f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
assert (
point_mask is not None and point_mask.device == device
), f"Expected point mask to be on device {device}, got {point_mask.device}"
assert (
point_labels is not None and point_labels.device == device
), f"Expected point labels to be on device {device}, got {point_labels.device}"
assert (
mask_embeddings is None or mask_embeddings.device == device
), f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
assert (
mask_mask is None or mask_mask.device == device
), f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
assert box_embeddings is not None and box_embeddings.device == device, (
f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
)
assert box_mask is not None and box_mask.device == device, (
f"Expected box mask to be on device {device}, got {box_mask.device}"
)
assert box_labels is not None and box_labels.device == device, (
f"Expected box labels to be on device {device}, got {box_labels.device}"
)
assert point_embeddings is not None and point_embeddings.device == device, (
f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
)
assert point_mask is not None and point_mask.device == device, (
f"Expected point mask to be on device {device}, got {point_mask.device}"
)
assert point_labels is not None and point_labels.device == device, (
f"Expected point labels to be on device {device}, got {point_labels.device}"
)
assert mask_embeddings is None or mask_embeddings.device == device, (
f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
)
assert mask_mask is None or mask_mask.device == device, (
f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
)
self.box_embeddings = box_embeddings
self.point_embeddings = point_embeddings
@@ -264,30 +251,30 @@ class Prompt:
if point_embeddings is not None:
point_seq_len = point_embeddings.shape[0]
if bs is not None:
assert (
bs == point_embeddings.shape[1]
), f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
assert bs == point_embeddings.shape[1], (
f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
)
else:
bs = point_embeddings.shape[1]
if device is not None:
assert (
device == point_embeddings.device
), "Device mismatch between box and point embeddings"
assert device == point_embeddings.device, (
"Device mismatch between box and point embeddings"
)
else:
device = point_embeddings.device
if mask_embeddings is not None:
mask_seq_len = mask_embeddings.shape[0]
if bs is not None:
assert (
bs == mask_embeddings.shape[1]
), f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
assert bs == mask_embeddings.shape[1], (
f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
)
else:
bs = mask_embeddings.shape[1]
if device is not None:
assert (
device == mask_embeddings.device
), "Device mismatch between box/point and mask embeddings."
assert device == mask_embeddings.device, (
"Device mismatch between box/point and mask embeddings."
)
else:
device = mask_embeddings.device
@@ -539,9 +526,9 @@ class SequenceGeometryEncoder(nn.Module):
if add_cls:
self.cls_embed = torch.nn.Embedding(1, self.d_model)
assert (
points_direct_project or points_pos_enc or points_pool
), "Error: need at least one way to encode points"
assert points_direct_project or points_pos_enc or points_pool, (
"Error: need at least one way to encode points"
)
assert (
encode_boxes_as_points
or boxes_direct_project
@@ -583,16 +570,16 @@ class SequenceGeometryEncoder(nn.Module):
self.encode = None
if num_layers > 0:
assert (
add_cls
), "It's currently highly recommended to add a CLS when using a transformer"
assert add_cls, (
"It's currently highly recommended to add a CLS when using a transformer"
)
self.encode = get_clones(layer, num_layers)
self.encode_norm = nn.LayerNorm(self.d_model)
if mask_encoder is not None:
assert isinstance(
mask_encoder, MaskEncoder
), f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
assert isinstance(mask_encoder, MaskEncoder), (
f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
)
if add_mask_label:
self.mask_label_embed = torch.nn.Embedding(2, self.d_model)
self.add_mask_label = add_mask_label
@@ -701,16 +688,15 @@ class SequenceGeometryEncoder(nn.Module):
img_feats: torch.Tensor = None,
):
n_masks, bs = masks.shape[:2]
assert (
n_masks == 1
), "We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
assert (
list(attn_mask.shape)
== [
bs,
n_masks,
]
), f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
assert n_masks == 1, (
"We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
)
assert list(attn_mask.shape) == [
bs,
n_masks,
], (
f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
)
masks, pos = self.mask_encoder(
masks=masks.flatten(0, 1).float(),
pix_feat=img_feats,

View File

@@ -13,9 +13,7 @@ import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from sam3.logger import get_logger
from tqdm import tqdm

View File

@@ -248,7 +248,9 @@ class UniversalSegmentationHead(SegmentationHead):
self.d_model = hidden_dim
if dot_product_scorer is not None:
assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake"
assert presence_head, (
"Specifying a dot product scorer without a presence head is likely a mistake"
)
self.presence_head = None
if presence_head:

View File

@@ -62,9 +62,9 @@ class SimpleMaskDownSampler(nn.Module):
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
self.interpol_size = interpol_size
if self.interpol_size is not None:
assert isinstance(
self.interpol_size, (list, tuple)
), f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
assert isinstance(self.interpol_size, (list, tuple)), (
f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
)
self.interpol_size = list(interpol_size)
assert len(self.interpol_size) == 2

View File

@@ -330,9 +330,9 @@ class SAM3Output(list):
self.output = output
else:
self.output = []
assert isinstance(
iter_mode, SAM3Output.IterMode
), f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"
assert isinstance(iter_mode, SAM3Output.IterMode), (
f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"
)
self.iter_mode = iter_mode
# We create a weak reference to self to be used in the lambda functions.
@@ -411,9 +411,9 @@ class SAM3Output(list):
return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode)
def append(self, item: list):
assert isinstance(
item, list
), f"Only list items are supported. Got {type(item)}"
assert isinstance(item, list), (
f"Only list items are supported. Got {type(item)}"
)
self.output.append(item)
def __repr__(self):

View File

@@ -8,7 +8,6 @@ from copy import deepcopy
from typing import List, Optional, Tuple
import torch
import torch.nn as nn

View File

@@ -7,15 +7,12 @@
# LICENSE file in the root directory of this source tree.
import logging
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from PIL.Image import Image
from sam3.model.sam3_tracker_base import Sam3TrackerBase
from sam3.model.utils.sam1_utils import SAM2Transforms
@@ -97,9 +94,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
input_image = self._transforms(image)
input_image = input_image[None, ...].to(self.device)
assert (
len(input_image.shape) == 4 and input_image.shape[1] == 3
), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
assert len(input_image.shape) == 4 and input_image.shape[1] == 3, (
f"input_image must be of size 1x3xHxW, got {input_image.shape}"
)
logging.info("Computing image embeddings for the provided image...")
backbone_out = self.model.forward_image(input_image)
(
@@ -136,17 +133,17 @@ class SAM3InteractiveImagePredictor(nn.Module):
assert isinstance(image_list, list)
self._orig_hw = []
for image in image_list:
assert isinstance(
image, np.ndarray
), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
assert isinstance(image, np.ndarray), (
"Images are expected to be an np.ndarray in RGB format, and of shape HWC"
)
self._orig_hw.append(image.shape[:2])
# Transform the image to the form expected by the model
img_batch = self._transforms.forward_batch(image_list)
img_batch = img_batch.to(self.device)
batch_size = img_batch.shape[0]
assert (
len(img_batch.shape) == 4 and img_batch.shape[1] == 3
), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
assert len(img_batch.shape) == 4 and img_batch.shape[1] == 3, (
f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
)
logging.info("Computing image embeddings for the provided images...")
backbone_out = self.model.forward_image(img_batch)
(
@@ -302,9 +299,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
):
unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
if point_coords is not None:
assert (
point_labels is not None
), "point_labels must be supplied if point_coords is supplied."
assert point_labels is not None, (
"point_labels must be supplied if point_coords is supplied."
)
point_coords = torch.as_tensor(
point_coords, dtype=torch.float, device=self.device
)
@@ -441,9 +438,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
raise RuntimeError(
"An image must be set with .set_image(...) to generate an embedding."
)
assert (
self._features is not None
), "Features must exist if an image has been set."
assert self._features is not None, (
"Features must exist if an image has been set."
)
return self._features["image_embed"]
@property

View File

@@ -8,19 +8,14 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from sam3.model.model_misc import SAM3Output
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
from sam3.model.vl_combiner import SAM3VLBackbone
from sam3.perflib.nms import nms_masks
from sam3.train.data.collator import BatchedDatapoint
from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .geometry_encoders import Prompt
from .model_misc import inverse_sigmoid
@@ -661,9 +656,9 @@ class Sam3Image(torch.nn.Module):
inference_state["original_heights"],
inference_state["original_widths"],
)
assert (
batch_size == len(orig_heights) == len(orig_widths)
), f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
assert batch_size == len(orig_heights) == len(orig_widths), (
f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
)
feats = [
feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
for feat, feat_size in zip(

View File

@@ -6,9 +6,7 @@ from typing import Dict, List
import numpy as np
import PIL
import torch
from sam3.model import box_ops
from sam3.model.data_misc import FindStage, interpolate
from torchvision.transforms import v2
@@ -83,9 +81,9 @@ class Sam3Processor:
if not isinstance(images, list):
raise ValueError("Images must be a list of PIL images or tensors")
assert len(images) > 0, "Images list must not be empty"
assert isinstance(
images[0], PIL.Image.Image
), "Images must be a list of PIL images"
assert isinstance(images[0], PIL.Image.Image), (
"Images must be a list of PIL images"
)
state["original_heights"] = [image.height for image in images]
state["original_widths"] = [image.width for image in images]

View File

@@ -6,11 +6,8 @@ import logging
import torch
import torch.nn.functional as F
from sam3.model.memory import SimpleMaskEncoder
from sam3.model.sam3_tracker_utils import get_1d_sine_pe, select_closest_cond_frames
from sam3.sam.mask_decoder import MaskDecoder, MLP
from sam3.sam.prompt_encoder import PromptEncoder
from sam3.sam.transformer import TwoWayTransformer

View File

@@ -6,7 +6,6 @@ import numpy as np
import torch
import torch.nn.functional as F
from numpy.typing import NDArray
from sam3.model.edt import edt_triton

View File

@@ -6,7 +6,6 @@ import logging
from collections import OrderedDict
import torch
from sam3.model.sam3_tracker_base import concat_points, NO_OBJ_SCORE, Sam3TrackerBase
from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores
from sam3.model.utils.sam2_utils import load_video_frames

View File

@@ -16,7 +16,6 @@ import numpy.typing as npt
import torch
import torch.distributed as dist
import torch.nn.functional as F
from sam3 import perflib
from sam3.logger import get_logger
from sam3.model.box_ops import fast_diag_box_iou
@@ -620,9 +619,9 @@ class Sam3VideoBase(nn.Module):
num_obj_dropped_due_to_limit,
trk_id_to_max_iou_high_conf_det,
]
assert (
len(update_plan) == NUM_BROADCAST_ITEMS
), f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
assert len(update_plan) == NUM_BROADCAST_ITEMS, (
f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
)
self.broadcast_python_obj_cpu(update_plan, src=0)
elif self.rank > 0 and self.world_size > 1:
update_plan = [
@@ -842,9 +841,9 @@ class Sam3VideoBase(nn.Module):
binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0
batch_size = tracker_low_res_masks_global.size(0)
if batch_size > 0:
assert (
len(obj_ids_global) == batch_size
), f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
assert len(obj_ids_global) == batch_size, (
f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
)
NEVER_OCCLUDED = -1
ALWAYS_OCCLUDED = 100000 # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic
last_occluded_prev = torch.cat(
@@ -1023,9 +1022,9 @@ class Sam3VideoBase(nn.Module):
reverse: bool = False,
):
# Suppress overlapping masks for objects that were most recently occluded
assert (
binary_low_res_masks.dtype == torch.bool
), f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
assert binary_low_res_masks.dtype == torch.bool, (
f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
)
to_suppress = torch.zeros(
binary_low_res_masks.size(0),
device=binary_low_res_masks.device,
@@ -1130,9 +1129,9 @@ class Sam3VideoBase(nn.Module):
num_frames_propagated += 1
# only 1 frames should be propagated
assert (
num_frames_propagated == 1 and out_frame_idx == frame_idx
), f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}"
assert num_frames_propagated == 1 and out_frame_idx == frame_idx, (
f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}"
)
assert isinstance(out_obj_ids, list)
obj_ids_local.extend(out_obj_ids)
low_res_masks_list.append(out_low_res_masks.squeeze(1))
@@ -1189,9 +1188,9 @@ class Sam3VideoBase(nn.Module):
assert det_masks.is_floating_point(), "float tensor expected (do not binarize)"
assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)"
assert (
trk_masks.size(0) == len(trk_obj_ids)
), f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
assert trk_masks.size(0) == len(trk_obj_ids), (
f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
)
if trk_masks.size(0) == 0:
# all detections are new
new_det_fa_inds = np.arange(det_masks.size(0))
@@ -1655,9 +1654,9 @@ class Sam3VideoBase(nn.Module):
# a) first, expand "confirmation_data" to include new masklets added in this frame
status_prev = confirmation_data["status"]
consecutive_det_num_prev = confirmation_data["consecutive_det_num"]
assert (
status_prev.shape == obj_ids_all_gpu_prev.shape
), f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
assert status_prev.shape == obj_ids_all_gpu_prev.shape, (
f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
)
obj_id_to_updated_idx = {
obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated)

View File

@@ -9,7 +9,6 @@ import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from sam3 import perflib
from sam3.logger import get_logger
from sam3.model.act_ckpt_utils import clone_output_wrapper
@@ -555,7 +554,9 @@ class Sam3VideoInference(Sam3VideoBase):
assert (
"cached_frame_outputs" in inference_state
and frame_idx in inference_state["cached_frame_outputs"]
), "No cached outputs found. Ensure normal propagation has run first to populate the cache."
), (
"No cached outputs found. Ensure normal propagation has run first to populate the cache."
)
cached_outputs = inference_state["cached_frame_outputs"][frame_idx]
obj_id_to_mask = cached_outputs.copy()
@@ -563,9 +564,9 @@ class Sam3VideoInference(Sam3VideoBase):
# Update with refined masks if provided
if refined_obj_id_to_mask is not None:
for obj_id, refined_mask in refined_obj_id_to_mask.items():
assert (
refined_mask is not None
), f"Refined mask data must be provided for obj_id {obj_id}"
assert refined_mask is not None, (
f"Refined mask data must be provided for obj_id {obj_id}"
)
obj_id_to_mask[obj_id] = refined_mask
return obj_id_to_mask
@@ -660,12 +661,12 @@ class Sam3VideoInference(Sam3VideoBase):
for i, thresh in enumerate(new_det_score_thresh_list):
self.new_det_thresh = thresh
for num_objects in num_objects_list:
logger.info(f"{i+1}/{num_rounds} warming up model compilation")
logger.info(f"{i + 1}/{num_rounds} warming up model compilation")
self.add_prompt(
inference_state, frame_idx=start_frame_idx, text_str="cat"
)
logger.info(
f"{i+1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects"
f"{i + 1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects"
)
inference_state = self.add_fake_objects_to_inference_state(
inference_state, num_objects, frame_idx=start_frame_idx
@@ -690,7 +691,7 @@ class Sam3VideoInference(Sam3VideoBase):
pass
self.reset_state(inference_state)
logger.info(
f"{i+1}/{num_rounds} warming up model compilation -- completed round {i+1} out of {num_rounds}"
f"{i + 1}/{num_rounds} warming up model compilation -- completed round {i + 1} out of {num_rounds}"
)
# Warm up Tracker memory encoder with varying input shapes
@@ -854,12 +855,12 @@ class Sam3VideoInference(Sam3VideoBase):
logger.debug("Running add_prompt on frame %d", frame_idx)
num_frames = inference_state["num_frames"]
assert (
text_str is not None or boxes_xywh is not None
), "at least one type of prompt (text, boxes) must be provided"
assert (
0 <= frame_idx < num_frames
), f"{frame_idx=} is out of range for a total of {num_frames} frames"
assert text_str is not None or boxes_xywh is not None, (
"at least one type of prompt (text, boxes) must be provided"
)
assert 0 <= frame_idx < num_frames, (
f"{frame_idx=} is out of range for a total of {num_frames} frames"
)
# since it's a semantic prompt, we start over
self.reset_state(inference_state)
@@ -1200,9 +1201,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
"propagation_partial",
"propagation_fetch",
]
assert (
action_type in instance_actions + propagation_actions
), f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}"
assert action_type in instance_actions + propagation_actions, (
f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}"
)
action = {
"type": action_type,
"frame_idx": frame_idx,
@@ -1370,12 +1371,12 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
):
if points is not None:
# Tracker instance prompts
assert (
text_str is None and boxes_xywh is None
), "When points are provided, text_str and boxes_xywh must be None."
assert (
obj_id is not None
), "When points are provided, obj_id must be provided."
assert text_str is None and boxes_xywh is None, (
"When points are provided, text_str and boxes_xywh must be None."
)
assert obj_id is not None, (
"When points are provided, obj_id must be provided."
)
return self.add_tracker_new_points(
inference_state,
frame_idx,
@@ -1491,9 +1492,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
tracker_states = self._get_tracker_inference_states_by_obj_ids(
inference_state, [obj_id]
)
assert (
len(tracker_states) == 1
), f"[rank={self.rank}] Multiple Tracker inference states found for the same object id."
assert len(tracker_states) == 1, (
f"[rank={self.rank}] Multiple Tracker inference states found for the same object id."
)
tracker_state = tracker_states[0]
# log

View File

@@ -16,7 +16,6 @@ from typing import List, Optional
import psutil
import torch
from sam3.logger import get_logger
logger = get_logger(__name__)
@@ -170,7 +169,7 @@ class Sam3VideoPredictor:
):
"""Remove an object from tracking."""
logger.debug(
f"remove object {obj_id} in session {session_id}: " f"{is_user_action=}"
f"remove object {obj_id} in session {session_id}: {is_user_action=}"
)
session = self._get_session(session_id)
inference_state = session["state"]

View File

@@ -318,9 +318,9 @@ class VETextEncoder(nn.Module):
# The text is already encoded, use as is.
text_attention_mask, text_memory_resized, tokenized = text
inputs_embeds = tokenized["inputs_embeds"]
assert (
input_boxes is None or len(input_boxes) == 0
), "Can't replace boxes in text if it's already encoded"
assert input_boxes is None or len(input_boxes) == 0, (
"Can't replace boxes in text if it's already encoded"
)
# Note that the input_embeds are returned in pytorch's convention (sequence first)
return (

View File

@@ -708,9 +708,9 @@ class ViT(nn.Module):
self.retain_cls_token = retain_cls_token
if self.retain_cls_token:
assert pretrain_use_cls_token
assert (
len(window_block_indexes) == 0
), "windowing not supported with cls token"
assert len(window_block_indexes) == 0, (
"windowing not supported with cls token"
)
assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"

View File

@@ -9,7 +9,6 @@ from typing import List, Optional
import torch
import torch.nn as nn
from torch.nn.attention import sdpa_kernel, SDPBackend
from .act_ckpt_utils import activation_ckpt_wrapper

View File

@@ -6,7 +6,6 @@ import os
from typing import Optional
import pkg_resources
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

View File

@@ -36,9 +36,9 @@ def connected_components_cpu(input_tensor: torch.Tensor):
if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
input_tensor = input_tensor.squeeze(1)
else:
assert (
input_tensor.dim() == 3
), "Input tensor must be (B, H, W) or (B, 1, H, W)."
assert input_tensor.dim() == 3, (
"Input tensor must be (B, H, W) or (B, 1, H, W)."
)
batch_size = input_tensor.shape[0]
labels_list = []
@@ -67,9 +67,9 @@ def connected_components(input_tensor: torch.Tensor):
if input_tensor.dim() == 3:
input_tensor = input_tensor.unsqueeze(1)
assert (
input_tensor.dim() == 4 and input_tensor.shape[1] == 1
), "Input tensor must be (B, H, W) or (B, 1, H, W)."
assert input_tensor.dim() == 4 and input_tensor.shape[1] == 1, (
"Input tensor must be (B, H, W) or (B, 1, H, W)."
)
if input_tensor.is_cuda:
if HAS_CC_TORCH:

View File

@@ -6,7 +6,6 @@ import logging
import numpy as np
import torch
from sam3.perflib.masks_ops import mask_iou

View File

@@ -407,16 +407,16 @@ def connected_components_triton(input_tensor: torch.Tensor):
- A BxHxW output tensor with dense labels. Background is 0.
- A BxHxW tensor with the size of the connected component for each pixel.
"""
assert (
input_tensor.is_cuda and input_tensor.is_contiguous()
), "Input tensor must be a contiguous CUDA tensor."
assert input_tensor.is_cuda and input_tensor.is_contiguous(), (
"Input tensor must be a contiguous CUDA tensor."
)
out_shape = input_tensor.shape
if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
input_tensor = input_tensor.squeeze(1)
else:
assert (
input_tensor.dim() == 3
), "Input tensor must be (B, H, W) or (B, 1, H, W)."
assert input_tensor.dim() == 3, (
"Input tensor must be (B, H, W) or (B, 1, H, W)."
)
B, H, W = input_tensor.shape
numel = B * H * W

View File

@@ -202,9 +202,9 @@ class MaskDecoder(nn.Module):
assert image_embeddings.shape[0] == tokens.shape[0]
src = image_embeddings
src = src + dense_prompt_embeddings
assert (
image_pe.size(0) == 1
), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
assert image_pe.size(0) == 1, (
"image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
)
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape

View File

@@ -8,7 +8,6 @@ from typing import Tuple, Type
import torch
import torch.nn.functional as F
from sam3.sam.rope import apply_rotary_enc, apply_rotary_enc_real, compute_axial_cis
from torch import nn, Tensor
@@ -205,9 +204,9 @@ class Attention(nn.Module):
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
self.use_fa3 = use_fa3
assert (
self.internal_dim % num_heads == 0
), "num_heads must divide embedding_dim."
assert self.internal_dim % num_heads == 0, (
"num_heads must divide embedding_dim."
)
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)

View File

@@ -142,9 +142,9 @@ class COCO_FROM_JSON:
self.prompts = {}
for loc_dict in prompts:
self.prompts[int(loc_dict["id"])] = loc_dict["name"]
assert len(self.prompts) == len(
self._sorted_cat_ids
), "Number of prompts must match number of categories"
assert len(self.prompts) == len(self._sorted_cat_ids), (
"Number of prompts must match number of categories"
)
def getDatapointIds(self):
"""Return all datapoint indices for training."""

View File

@@ -6,7 +6,6 @@ from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_data
from typing import Any, get_args, get_origin, List, Union
import torch
from sam3.model.data_misc import (
BatchedDatapoint,
BatchedFindTarget,
@@ -217,9 +216,9 @@ def collate_fn_api(
text_batch.append(q.query_text)
stages[stage_id].text_ids.append(text_batch.index(q.query_text))
assert (
q.inference_metadata is not None
), "inference_metadata must be provided when FindQueryLoaded is created."
assert q.inference_metadata is not None, (
"inference_metadata must be provided when FindQueryLoaded is created."
)
for f in fields(q.inference_metadata):
getattr(find_metadatas[stage_id], f.name).append(
getattr(q.inference_metadata, f.name)

View File

@@ -19,10 +19,8 @@ import torch.utils.data
import torchvision
from decord import cpu, VideoReader
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from PIL.Image import DecompressionBombError
from sam3.model.box_ops import box_xywh_to_xyxy
from torchvision.datasets.vision import VisionDataset
@@ -234,9 +232,9 @@ class CustomCocoDetectionAPI(VisionDataset):
if self.coco is not None:
return
assert g_pathmgr.isfile(
self.annFile
), f"please provide valid annotation file. Missing: {self.annFile}"
assert g_pathmgr.isfile(self.annFile), (
f"please provide valid annotation file. Missing: {self.annFile}"
)
annFile = g_pathmgr.get_local_path(self.annFile)
if self.coco is not None:
@@ -326,9 +324,9 @@ class CustomCocoDetectionAPI(VisionDataset):
else:
num_queries_per_stage = stage2num_queries.most_common(1)[0][1]
for stage, num_queries in stage2num_queries.items():
assert (
num_queries == num_queries_per_stage
), f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"
assert num_queries == num_queries_per_stage, (
f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"
)
for query_id, query in enumerate(queries):
h, w = id2imsize[query["image_id"]]

View File

@@ -3,7 +3,6 @@
# pyre-unsafe
import copy
import io
import json
import logging
@@ -16,7 +15,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
import torchvision
# from decord import cpu, VideoReader
from iopath.common.file_io import PathManager
@@ -220,9 +218,9 @@ class VideoGroundingDataset(Sam3ImageDataset):
for query in filtered_queries:
ptr_x_is_empty = query["ptr_x_query_id"] in [None, -1]
ptr_y_is_empty = query["ptr_y_query_id"] in [None, -1]
assert (
ptr_x_is_empty and ptr_y_is_empty
), "Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers"
assert ptr_x_is_empty and ptr_y_is_empty, (
"Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers"
)
query["query_processing_order"] = stage_id_old2new[
query["query_processing_order"]
]

View File

@@ -9,11 +9,8 @@ import torch
import torch.distributed
import torch.nn.functional as F
import torchmetrics
from sam3.model import box_ops
from sam3.model.data_misc import interpolate
from sam3.train.loss.sigmoid_focal_loss import (
triton_sigmoid_focal_loss,
triton_sigmoid_focal_loss_reduce,
@@ -327,7 +324,9 @@ class IABCEMdetr(LossWithWeights):
if num_det_queries is not None:
logging.warning("note: it's not needed to set num_det_queries anymore")
if self.use_separate_loss_for_det_and_trk:
assert not self.weak_loss, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
assert not self.weak_loss, (
"Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
)
self.det_exhaustive_loss_scale_pos = det_exhaustive_loss_scale_pos
self.det_exhaustive_loss_scale_neg = det_exhaustive_loss_scale_neg
self.det_non_exhaustive_loss_scale_pos = det_non_exhaustive_loss_scale_pos
@@ -342,7 +341,9 @@ class IABCEMdetr(LossWithWeights):
and det_non_exhaustive_loss_scale_neg == 1.0
and trk_loss_scale_pos == 1.0
and trk_loss_scale_neg == 1.0
), "If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0"
), (
"If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0"
)
def get_loss(self, outputs, targets, indices, num_boxes):
assert len(outputs["pred_logits"].shape) > 2, "Incorrect predicted logits shape"
@@ -443,7 +444,9 @@ class IABCEMdetr(LossWithWeights):
pass
if self.weak_loss:
assert not self.use_separate_loss_for_det_and_trk, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
assert not self.use_separate_loss_for_det_and_trk, (
"Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
)
# nullify the negative loss for the non-exhaustive classes
assert loss_bce.shape[0] == targets["is_exhaustive"].shape[0]
@@ -497,9 +500,9 @@ class IABCEMdetr(LossWithWeights):
loss_bce = loss_bce.mean()
else:
assert isinstance(self.pad_n_queries, int)
assert (
loss_bce.size(1) < self.pad_n_queries
), f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions."
assert loss_bce.size(1) < self.pad_n_queries, (
f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions."
)
loss_bce = loss_bce.sum() / (self.pad_n_queries * loss_bce.size(0))
bce_f1 = torchmetrics.functional.f1_score(

View File

@@ -3,9 +3,7 @@
# pyre-unsafe
import torch
from sam3.model.model_misc import SAM3Output
from sam3.train.utils.distributed import get_world_size
from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks

View File

@@ -103,9 +103,9 @@ def dilation(mask, kernel_size):
assert mask.ndim == 3
kernel_size = int(kernel_size)
assert (
kernel_size % 2 == 1
), f"Dilation expects a odd kernel size, got {kernel_size}"
assert kernel_size % 2 == 1, (
f"Dilation expects a odd kernel size, got {kernel_size}"
)
if mask.is_cuda:
m = mask.unsqueeze(1).to(torch.float16)

View File

@@ -8,7 +8,6 @@ Modules to compute the matching cost and solve the corresponding LSAP.
import numpy as np
import torch
from sam3.model.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
from scipy.optimize import linear_sum_assignment
from torch import nn
@@ -60,9 +59,9 @@ class HungarianMatcher(nn.Module):
self.cost_bbox = cost_bbox
self.cost_giou = cost_giou
self.norm = nn.Sigmoid() if focal_loss else nn.Softmax(-1)
assert (
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
), "all costs cant be 0"
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
"all costs cant be 0"
)
self.focal_loss = focal_loss
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
@@ -197,9 +196,9 @@ class BinaryHungarianMatcher(nn.Module):
self.cost_bbox = cost_bbox
self.cost_giou = cost_giou
self.norm = nn.Sigmoid()
assert (
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
), "all costs cant be 0"
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
"all costs cant be 0"
)
@torch.no_grad()
def forward(self, outputs, batched_targets, repeats=0, repeat_batch=1):
@@ -322,9 +321,9 @@ class BinaryFocalHungarianMatcher(nn.Module):
self.alpha = alpha
self.gamma = gamma
self.stable = stable
assert (
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
), "all costs cant be 0"
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
"all costs cant be 0"
)
@torch.no_grad()
def forward(self, outputs, batched_targets, repeats=1, repeat_batch=1):
@@ -470,9 +469,9 @@ class BinaryHungarianMatcherV2(nn.Module):
self.cost_bbox = cost_bbox
self.cost_giou = cost_giou
self.norm = nn.Sigmoid()
assert (
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
), "all costs cant be 0"
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
"all costs cant be 0"
)
self.focal = focal
if focal:
self.alpha = alpha

View File

@@ -22,7 +22,6 @@ from typing import (
)
import hydra
import torch
import torch.nn as nn
from omegaconf import DictConfig
@@ -212,9 +211,9 @@ def unix_module_cls_pattern_to_parameter_names(
"match any classes in the model"
)
matching_parameters = module_cls_to_param_names[module_cls]
assert (
len(matching_parameters) > 0
), f"module_cls_name {module_cls_name} does not contain any parameters in the model"
assert len(matching_parameters) > 0, (
f"module_cls_name {module_cls_name} does not contain any parameters in the model"
)
logging.info(
f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
)
@@ -240,9 +239,9 @@ def unix_param_pattern_to_parameter_names(
allowed_parameter_names = []
for param_name in filter_param_names:
matching_parameters = set(fnmatch.filter(parameter_names, param_name))
assert (
len(matching_parameters) >= 1
), f"param_name {param_name} does not match any parameters in the model"
assert len(matching_parameters) >= 1, (
f"param_name {param_name} does not match any parameters in the model"
)
logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
allowed_parameter_names.append(matching_parameters)
return set.union(*allowed_parameter_names)

View File

@@ -12,13 +12,10 @@ from copy import deepcopy
import submitit
import torch
from hydra import compose, initialize_config_module
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from omegaconf import OmegaConf
from sam3.train.utils.train_utils import makedir, register_omegaconf_resolvers
from tqdm import tqdm
@@ -212,9 +209,9 @@ def main(args) -> None:
},
}
if "include_nodes" in submitit_conf:
assert (
len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes
), "Not enough nodes"
assert len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes, (
"Not enough nodes"
)
job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(
submitit_conf["include_nodes"]
)

View File

@@ -15,28 +15,22 @@ from dataclasses import dataclass, field
from typing import Any, Dict, List, Mapping, Optional
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from sam3.model.data_misc import BatchedDatapoint
from sam3.model.model_misc import SAM3Output
from sam3.model.utils.misc import copy_data_to_device
from sam3.train.optim.optimizer import construct_optimizer
from sam3.train.utils.checkpoint_utils import (
assert_skipped_parameters_are_frozen,
exclude_params_matching_unix_pattern,
load_state_dict_into_model,
with_check_parameter_frozen,
)
from sam3.train.utils.distributed import all_reduce_max, barrier, get_rank
from sam3.train.utils.logger import Logger, setup_logging
from sam3.train.utils.train_utils import (
AverageMeter,
@@ -215,9 +209,9 @@ class Trainer:
set_seeds(seed_value, self.max_epochs, self.distributed_rank)
log_env_variables()
assert (
is_dist_avail_and_initialized()
), "Torch distributed needs to be initialized before calling the trainer."
assert is_dist_avail_and_initialized(), (
"Torch distributed needs to be initialized before calling the trainer."
)
self._setup_components() # Except Optimizer everything is setup here.
self._move_to_device()
@@ -227,9 +221,9 @@ class Trainer:
self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")
if self.checkpoint_conf.resume_from is not None:
assert os.path.exists(
self.checkpoint_conf.resume_from
), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
assert os.path.exists(self.checkpoint_conf.resume_from), (
f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
)
dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
if self.distributed_rank == 0 and not os.path.exists(dst):
# Copy the "resume_from" checkpoint to the checkpoint folder
@@ -477,9 +471,9 @@ class Trainer:
return self.loss[key]
assert key != "all", "Loss must be specified for key='all'"
assert (
"default" in self.loss
), f"Key {key} not found in losss, and no default provided"
assert "default" in self.loss, (
f"Key {key} not found in losss, and no default provided"
)
return self.loss["default"]
def _find_meter(self, phase: str, key: str):
@@ -922,12 +916,12 @@ class Trainer:
self.optim.zero_grad(set_to_none=True)
if self.gradient_accumulation_steps > 1:
assert isinstance(
batch, list
), f"Expected a list of batches, got {type(batch)}"
assert (
len(batch) == self.gradient_accumulation_steps
), f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
assert isinstance(batch, list), (
f"Expected a list of batches, got {type(batch)}"
)
assert len(batch) == self.gradient_accumulation_steps, (
f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
)
accum_steps = len(batch)
else:
accum_steps = 1
@@ -1039,9 +1033,9 @@ class Trainer:
def _check_val_key_match(self, val_keys, phase):
if val_keys is not None:
# Check if there are any duplicates
assert len(val_keys) == len(
set(val_keys)
), f"Duplicate keys in val datasets, keys: {val_keys}"
assert len(val_keys) == len(set(val_keys)), (
f"Duplicate keys in val datasets, keys: {val_keys}"
)
# Check that the keys match the meter keys
if self.meters_conf is not None and phase in self.meters_conf:
@@ -1055,9 +1049,9 @@ class Trainer:
loss_keys = set(self.loss_conf.keys()) - set(["all"])
if "default" not in loss_keys:
for k in val_keys:
assert (
k in loss_keys
), f"Error: key {k} is not defined in the losses, and no default is set"
assert k in loss_keys, (
f"Error: key {k} is not defined in the losses, and no default is set"
)
def _setup_components(self):
# Get the keys for all the val datasets, if any

View File

@@ -14,7 +14,6 @@ import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from sam3.model.box_ops import box_xyxy_to_cxcywh
from sam3.model.data_misc import interpolate
@@ -277,9 +276,9 @@ class RandomSizeCrop:
max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1))
)
result_img, result_target = crop(img, target, [j, i, h, w])
assert (
len(result_target["boxes"]) == init_boxes
), f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}"
assert len(result_target["boxes"]) == init_boxes, (
f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}"
)
return result_img, result_target
else:

View File

@@ -7,7 +7,6 @@ Transforms and data augmentation for both image + bbox.
"""
import logging
import numbers
import random
from collections.abc import Sequence
@@ -17,9 +16,7 @@ import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
import torchvision.transforms.v2.functional as Fv2
from PIL import Image as PILImage
from sam3.model.box_ops import box_xyxy_to_cxcywh, masks_to_boxes
from sam3.train.data.sam3_image_dataset import Datapoint
from torchvision.transforms import InterpolationMode

View File

@@ -4,12 +4,10 @@
import logging
import random
from collections import defaultdict
from typing import List, Optional, Union
import torch
from sam3.train.data.sam3_image_dataset import Datapoint, FindQuery, Object
@@ -381,9 +379,9 @@ class FlexibleFilterFindGetQueries:
if len(new_find_queries) == 0:
start_with_zero_check = True
assert (
start_with_zero_check
), "Invalid Find queries, they need to start at query_processing_order = 0"
assert start_with_zero_check, (
"Invalid Find queries, they need to start at query_processing_order = 0"
)
datapoint.find_queries = new_find_queries

View File

@@ -7,7 +7,6 @@ import numpy as np
import torch
from PIL import Image as PILImage
from pycocotools import mask as mask_util
from sam3.train.data.sam3_image_dataset import Datapoint
from torchvision.ops import masks_to_boxes
@@ -250,9 +249,9 @@ class RandomGeometricInputsAPI:
def _get_target_object(self, datapoint, query):
img = datapoint.images[query.image_id]
targets = query.object_ids_output
assert (
len(targets) == 1
), "Geometric queries only support a single target object."
assert len(targets) == 1, (
"Geometric queries only support a single target object."
)
target_idx = targets[0]
return img.objects[target_idx]

View File

@@ -5,12 +5,9 @@
import numpy as np
import pycocotools.mask as mask_utils
import torch
import torchvision.transforms.functional as F
from PIL import Image as PILImage
from sam3.model.box_ops import masks_to_boxes
from sam3.train.data.sam3_image_dataset import Datapoint

View File

@@ -36,9 +36,9 @@ def unix_pattern_to_parameter_names(
parameter_names = []
for param_name in constraints:
matching_parameters = set(fnmatch.filter(all_parameter_names, param_name))
assert (
len(matching_parameters) > 0
), f"param_names {param_name} don't match any param in the given names."
assert len(matching_parameters) > 0, (
f"param_names {param_name} don't match any param in the given names."
)
parameter_names.append(matching_parameters)
return set.union(*parameter_names)

View File

@@ -10,10 +10,8 @@ import uuid
from typing import Any, Dict, Optional, Union
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from numpy import ndarray
from sam3.train.utils.train_utils import get_machine_local_and_dist_rank, makedir
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter

View File

@@ -11,7 +11,6 @@ from datetime import timedelta
from typing import Optional
import hydra
import numpy as np
import omegaconf
import torch
@@ -83,9 +82,9 @@ def get_machine_local_and_dist_rank():
"""
local_rank = int(os.environ.get("LOCAL_RANK", None))
distributed_rank = int(os.environ.get("RANK", None))
assert (
local_rank is not None and distributed_rank is not None
), "Please the set the RANK and LOCAL_RANK environment variables."
assert local_rank is not None and distributed_rank is not None, (
"Please the set the RANK and LOCAL_RANK environment variables."
)
return local_rank, distributed_rank

View File

@@ -158,22 +158,22 @@ def plot_mask(mask, color="r", ax=None):
def normalize_bbox(bbox_xywh, img_w, img_h):
# Assumes bbox_xywh is in XYWH format
if isinstance(bbox_xywh, list):
assert (
len(bbox_xywh) == 4
), "bbox_xywh list must have 4 elements. Batching not support except for torch tensors."
assert len(bbox_xywh) == 4, (
"bbox_xywh list must have 4 elements. Batching not support except for torch tensors."
)
normalized_bbox = bbox_xywh.copy()
normalized_bbox[0] /= img_w
normalized_bbox[1] /= img_h
normalized_bbox[2] /= img_w
normalized_bbox[3] /= img_h
else:
assert isinstance(
bbox_xywh, torch.Tensor
), "Only torch tensors are supported for batching."
assert isinstance(bbox_xywh, torch.Tensor), (
"Only torch tensors are supported for batching."
)
normalized_bbox = bbox_xywh.clone()
assert (
normalized_bbox.size(-1) == 4
), "bbox_xywh tensor must have last dimension of size 4."
assert normalized_bbox.size(-1) == 4, (
"bbox_xywh tensor must have last dimension of size 4."
)
normalized_bbox[..., 0] /= img_w
normalized_bbox[..., 1] /= img_h
normalized_bbox[..., 2] /= img_w
@@ -244,10 +244,10 @@ def visualize_formatted_frame_output(
num_outputs = len(outputs_list)
if titles is None:
titles = [f"Set {i+1}" for i in range(num_outputs)]
assert (
len(titles) == num_outputs
), "length of `titles` should match that of `outputs_list` if not None."
titles = [f"Set {i + 1}" for i in range(num_outputs)]
assert len(titles) == num_outputs, (
"length of `titles` should match that of `outputs_list` if not None."
)
_, axes = plt.subplots(1, num_outputs, figsize=figsize)
if num_outputs == 1:
@@ -703,9 +703,9 @@ def get_all_annotations_for_frame(
# Get the frame
video_df_current = video_df[video_df.id == video_id]
assert (
len(video_df_current) == 1
), f"Expected 1 video row, got {len(video_df_current)}"
assert len(video_df_current) == 1, (
f"Expected 1 video row, got {len(video_df_current)}"
)
video_row = video_df_current.iloc[0]
file_name = video_row.file_names[frame_idx]
file_path = os.path.join(
@@ -796,7 +796,7 @@ def visualize_prompt_overlay(
ax.text(
x_img + 5,
y_img - 5,
f"P{i+1}",
f"P{i + 1}",
color=color,
fontsize=10,
weight="bold",
@@ -828,7 +828,7 @@ def visualize_prompt_overlay(
ax.text(
x_img,
y_img - 5,
f"B{i+1}",
f"B{i + 1}",
color=color,
fontsize=10,
weight="bold",