diff --git a/sam3/agent/agent_core.py b/sam3/agent/agent_core.py index f893f00..8b4fc8a 100644 --- a/sam3/agent/agent_core.py +++ b/sam3/agent/agent_core.py @@ -296,9 +296,9 @@ def agent_inference( assert LATEST_SAM3_TEXT_PROMPT != "" # Make sure that the last message is a image - assert ( - messages[-1]["content"][1]["type"] == "image" - ), "Second content element should be an image" + assert messages[-1]["content"][1]["type"] == "image", ( + "Second content element should be an image" + ) messages.pop() # Remove the last user message # Add simplified replacement message simplified_message = { @@ -318,7 +318,7 @@ def agent_inference( # MLLM check the mask one by one for i in range(num_masks): - print(f"🔍 Checking mask {i+1}/{num_masks}...") + print(f"🔍 Checking mask {i + 1}/{num_masks}...") image_w_mask_i, image_w_zoomed_in_mask_i = visualize(current_outputs, i) image_w_zoomed_in_mask_i_path = os.path.join( @@ -363,7 +363,7 @@ def agent_inference( raise ValueError( "Generated text is None, which is unexpected. Please check the Qwen server and the input parameters." ) - print(f"Generated text for mask {i+1}: {checking_generated_text}") + print(f"Generated text for mask {i + 1}: {checking_generated_text}") verdict = ( checking_generated_text.split("")[-1] .split("")[0] @@ -371,11 +371,11 @@ def agent_inference( ) if "Accept" in verdict: assert not "Reject" in verdict - print(f"Mask {i+1} accepted, keeping it in the outputs.") + print(f"Mask {i + 1} accepted, keeping it in the outputs.") masks_to_keep.append(i) elif "Reject" in verdict: assert not "Accept" in verdict - print(f"Mask {i+1} rejected, removing it from the outputs.") + print(f"Mask {i + 1} rejected, removing it from the outputs.") else: raise ValueError( f"Unexpected verdict in generated text: {checking_generated_text}. Expected 'Accept' or 'Reject'." @@ -397,7 +397,7 @@ def agent_inference( sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png" ).replace( ".png", - f"_selected_masks_{'-'.join(map(str, [i+1 for i in masks_to_keep]))}.png".replace( + f"_selected_masks_{'-'.join(map(str, [i + 1 for i in masks_to_keep]))}.png".replace( "/", "_" ), ) diff --git a/sam3/agent/client_sam3.py b/sam3/agent/client_sam3.py index b138e60..daeb849 100755 --- a/sam3/agent/client_sam3.py +++ b/sam3/agent/client_sam3.py @@ -7,7 +7,6 @@ import os import torch from PIL import Image - from sam3.model.box_ops import box_xyxy_to_xywh from sam3.train.masks_ops import rle_encode diff --git a/sam3/agent/helpers/boxes.py b/sam3/agent/helpers/boxes.py index 40df92e..df44769 100755 --- a/sam3/agent/helpers/boxes.py +++ b/sam3/agent/helpers/boxes.py @@ -84,9 +84,9 @@ class BoxMode(IntEnum): ], "Relative mode not yet supported!" if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS: - assert ( - arr.shape[-1] == 5 - ), "The last dimension of input shape must be 5 for XYWHA format" + assert arr.shape[-1] == 5, ( + "The last dimension of input shape must be 5 for XYWHA format" + ) original_dtype = arr.dtype arr = arr.double() @@ -244,9 +244,9 @@ class Boxes: if isinstance(item, int): return Boxes(self.tensor[item].view(1, -1)) b = self.tensor[item] - assert ( - b.dim() == 2 - ), "Indexing on Boxes with {} failed to return a matrix!".format(item) + assert b.dim() == 2, ( + "Indexing on Boxes with {} failed to return a matrix!".format(item) + ) return Boxes(b) def __len__(self) -> int: @@ -425,7 +425,7 @@ def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor: Tensor: iou, sized [N]. """ assert len(boxes1) == len(boxes2), ( - "boxlists should have the same" "number of entries, got {}, {}".format( + "boxlists should have the samenumber of entries, got {}, {}".format( len(boxes1), len(boxes2) ) ) diff --git a/sam3/agent/helpers/masks.py b/sam3/agent/helpers/masks.py index a55b5e2..a303f6e 100755 --- a/sam3/agent/helpers/masks.py +++ b/sam3/agent/helpers/masks.py @@ -13,7 +13,6 @@ from torch import device from .boxes import Boxes from .memory import retry_if_cuda_oom - from .roi_align import ROIAlign @@ -142,10 +141,10 @@ class BitMasks: if isinstance(item, int): return BitMasks(self.tensor[item].unsqueeze(0)) m = self.tensor[item] - assert ( - m.dim() == 3 - ), "Indexing on BitMasks with {} returns a tensor with shape {}!".format( - item, m.shape + assert m.dim() == 3, ( + "Indexing on BitMasks with {} returns a tensor with shape {}!".format( + item, m.shape + ) ) return BitMasks(m) diff --git a/sam3/agent/helpers/rotated_boxes.py b/sam3/agent/helpers/rotated_boxes.py index cd39af8..0017335 100755 --- a/sam3/agent/helpers/rotated_boxes.py +++ b/sam3/agent/helpers/rotated_boxes.py @@ -363,9 +363,9 @@ class RotatedBoxes(Boxes): if isinstance(item, int): return RotatedBoxes(self.tensor[item].view(1, -1)) b = self.tensor[item] - assert ( - b.dim() == 2 - ), "Indexing on RotatedBoxes with {} failed to return a matrix!".format(item) + assert b.dim() == 2, ( + "Indexing on RotatedBoxes with {} failed to return a matrix!".format(item) + ) return RotatedBoxes(b) def __len__(self) -> int: diff --git a/sam3/agent/helpers/visualizer.py b/sam3/agent/helpers/visualizer.py index c6ce032..2050652 100755 --- a/sam3/agent/helpers/visualizer.py +++ b/sam3/agent/helpers/visualizer.py @@ -20,7 +20,6 @@ from matplotlib.backends.backend_agg import FigureCanvasAgg from PIL import Image from .boxes import Boxes, BoxMode - from .color_map import random_color from .keypoints import Keypoints from .masks import BitMasks, PolygonMasks @@ -222,9 +221,9 @@ class _PanopticPrediction: empty_ids.append(id) if len(empty_ids) == 0: return np.zeros(self._seg.shape, dtype=np.uint8) - assert ( - len(empty_ids) == 1 - ), ">1 ids corresponds to no labels. This is currently not supported" + assert len(empty_ids) == 1, ( + ">1 ids corresponds to no labels. This is currently not supported" + ) return (self._seg != empty_ids[0]).numpy().astype(np.bool) def semantic_masks(self): diff --git a/sam3/agent/inference.py b/sam3/agent/inference.py index 85167d1..01f1b63 100644 --- a/sam3/agent/inference.py +++ b/sam3/agent/inference.py @@ -41,7 +41,7 @@ def run_single_image_inference( print(f"Output JSON {output_json_path} already exists. Skipping.") return - print(f"{'-'*30} Starting SAM 3 Agent Session... {'-'*30} ") + print(f"{'-' * 30} Starting SAM 3 Agent Session... {'-' * 30} ") agent_history, final_output_dict, rendered_final_output = agent_inference( image_path, text_prompt, @@ -50,7 +50,7 @@ def run_single_image_inference( output_dir=output_dir, debug=debug, ) - print(f"{'-'*30} End of SAM 3 Agent Session... {'-'*30} ") + print(f"{'-' * 30} End of SAM 3 Agent Session... {'-' * 30} ") final_output_dict["text_prompt"] = text_prompt final_output_dict["image_path"] = image_path diff --git a/sam3/agent/viz.py b/sam3/agent/viz.py index 523e246..f5a7867 100644 --- a/sam3/agent/viz.py +++ b/sam3/agent/viz.py @@ -73,7 +73,9 @@ def visualize( idx = int(zoom_in_index) num_masks = len(input_json.get("pred_masks", [])) if idx < 0 or idx >= num_masks: - raise ValueError(f"zoom_in_index {idx} is out of range (0..{num_masks-1}).") + raise ValueError( + f"zoom_in_index {idx} is out of range (0..{num_masks - 1})." + ) # (1) Replicate zoom_in_and_visualize object_data = { diff --git a/sam3/eval/cgf1_eval.py b/sam3/eval/cgf1_eval.py index 71fe2ea..cbc91c1 100644 --- a/sam3/eval/cgf1_eval.py +++ b/sam3/eval/cgf1_eval.py @@ -126,9 +126,9 @@ class COCOCustom(COCO): # MODIFICATION: faster and cached subset check if not hasattr(self, "img_id_set"): self.img_id_set = set(self.getImgIds()) - assert set(annsImgIds).issubset( - self.img_id_set - ), "Results do not correspond to current coco set" + assert set(annsImgIds).issubset(self.img_id_set), ( + "Results do not correspond to current coco set" + ) # END MODIFICATION if "caption" in anns[0]: imgIds = set([img["id"] for img in res.dataset["images"]]) & set( @@ -301,9 +301,9 @@ class CGF1Eval(COCOeval): TP = (match_scores >= thresh).sum() FP = len(dt) - TP FN = len(gt) - TP - assert ( - FP >= 0 and FN >= 0 - ), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}" + assert FP >= 0 and FN >= 0, ( + f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}" + ) TPs.append(TP) FPs.append(FP) FNs.append(FN) @@ -599,9 +599,9 @@ class CGF1Evaluator: """ assert len(self.coco_gts) > 0, "No ground truth provided for evaluation." - assert len(self.coco_gts) == len( - self.coco_evals - ), "Mismatch in number of ground truths and evaluators." + assert len(self.coco_gts) == len(self.coco_evals), ( + "Mismatch in number of ground truths and evaluators." + ) if self.verbose: print(f"Loading predictions from {pred_file}") @@ -668,17 +668,17 @@ class CGF1Evaluator: if len(scorings) == 1: return scorings[0] - assert ( - scorings[0].ndim == 3 - ), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}" - assert ( - scorings[0].shape[0] == 1 - ), f"Expecting a single category, got {scorings[0].shape[0]}" + assert scorings[0].ndim == 3, ( + f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}" + ) + assert scorings[0].shape[0] == 1, ( + f"Expecting a single category, got {scorings[0].shape[0]}" + ) for scoring in scorings: - assert ( - scoring.shape == scorings[0].shape - ), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}" + assert scoring.shape == scorings[0].shape, ( + f"Shape mismatch: {scoring.shape}, {scorings[0].shape}" + ) selected_imgs = [] for img_id in range(scorings[0].shape[-1]): diff --git a/sam3/eval/coco_eval.py b/sam3/eval/coco_eval.py index fbb82a0..7eee615 100644 --- a/sam3/eval/coco_eval.py +++ b/sam3/eval/coco_eval.py @@ -18,19 +18,15 @@ import os import pickle from collections import defaultdict from pathlib import Path - from typing import Any, List, Optional import numpy as np - import pycocotools.mask as mask_utils import torch from iopath.common.file_io import g_pathmgr from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval - from sam3.train.masks_ops import rle_encode - from sam3.train.utils.distributed import ( all_gather, gather_to_rank_0_via_filesys, @@ -755,9 +751,9 @@ def loadRes(self, resFile): anns = resFile assert type(anns) == list, "results in not an array of objects" annsImgIds = [ann["image_id"] for ann in anns] - assert set(annsImgIds) == ( - set(annsImgIds) & set(self.getImgIds()) - ), "Results do not correspond to current coco set" + assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), ( + "Results do not correspond to current coco set" + ) if "caption" in anns[0]: imgIds = set([img["id"] for img in res.dataset["images"]]) & set( [ann["image_id"] for ann in anns] diff --git a/sam3/eval/coco_writer.py b/sam3/eval/coco_writer.py index 54f9307..f49fcf2 100644 --- a/sam3/eval/coco_writer.py +++ b/sam3/eval/coco_writer.py @@ -83,9 +83,9 @@ class PredictionDumper: self.merge_predictions = merge_predictions self.pred_file_evaluators = pred_file_evaluators if self.pred_file_evaluators is not None: - assert ( - merge_predictions - ), "merge_predictions must be True if pred_file_evaluators are provided" + assert merge_predictions, ( + "merge_predictions must be True if pred_file_evaluators are provided" + ) assert self.dump_dir is not None, "dump_dir must be provided" if is_main_process(): diff --git a/sam3/eval/demo_eval.py b/sam3/eval/demo_eval.py index a6076ad..6ac7063 100644 --- a/sam3/eval/demo_eval.py +++ b/sam3/eval/demo_eval.py @@ -13,11 +13,9 @@ from typing import Optional import numpy as np import pycocotools.mask as maskUtils from pycocotools.cocoeval import COCOeval - from sam3.eval.coco_eval import CocoEvaluator from sam3.train.masks_ops import compute_F_measure from sam3.train.utils.distributed import is_main_process - from scipy.optimize import linear_sum_assignment @@ -156,9 +154,9 @@ class DemoEval(COCOeval): TP = (match_scores >= thresh).sum() FP = len(dt) - TP FN = len(gt) - TP - assert ( - FP >= 0 and FN >= 0 - ), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}" + assert FP >= 0 and FN >= 0, ( + f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}" + ) TPs.append(TP) FPs.append(FP) FNs.append(FN) @@ -528,17 +526,17 @@ class DemoEvaluator(CocoEvaluator): if len(scorings) == 1: return scorings[0] - assert ( - scorings[0].ndim == 3 - ), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}" - assert ( - scorings[0].shape[0] == 1 - ), f"Expecting a single category, got {scorings[0].shape[0]}" + assert scorings[0].ndim == 3, ( + f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}" + ) + assert scorings[0].shape[0] == 1, ( + f"Expecting a single category, got {scorings[0].shape[0]}" + ) for scoring in scorings: - assert ( - scoring.shape == scorings[0].shape - ), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}" + assert scoring.shape == scorings[0].shape, ( + f"Shape mismatch: {scoring.shape}, {scorings[0].shape}" + ) selected_imgs = [] for img_id in range(scorings[0].shape[-1]): diff --git a/sam3/eval/hota_eval_toolkit/trackeval/eval.py b/sam3/eval/hota_eval_toolkit/trackeval/eval.py index ed32ff1..7dcdfd5 100644 --- a/sam3/eval/hota_eval_toolkit/trackeval/eval.py +++ b/sam3/eval/hota_eval_toolkit/trackeval/eval.py @@ -255,9 +255,10 @@ class Evaluator: if show_progressbar and TQDM_IMPORTED: seq_list_sorted = sorted(seq_list) - with Pool(config["NUM_PARALLEL_CORES"]) as pool, tqdm.tqdm( - total=len(seq_list) - ) as pbar: + with ( + Pool(config["NUM_PARALLEL_CORES"]) as pool, + tqdm.tqdm(total=len(seq_list)) as pbar, + ): _eval_sequence = partial( eval_sequence, dataset=dataset, diff --git a/sam3/eval/postprocessors.py b/sam3/eval/postprocessors.py index 7bd26a9..44eb103 100644 --- a/sam3/eval/postprocessors.py +++ b/sam3/eval/postprocessors.py @@ -83,9 +83,9 @@ class PostProcessImage(nn.Module): ret_tensordict: Experimental argument. If true, return a tensordict.TensorDict instead of a list of dictionaries for easier manipulation. """ if ret_tensordict: - assert ( - consistent is True - ), "We don't support returning TensorDict if the outputs have different shapes" # NOTE: It's possible but we don't support it. + assert consistent is True, ( + "We don't support returning TensorDict if the outputs have different shapes" + ) # NOTE: It's possible but we don't support it. assert self.detection_threshold <= 0.0, "TODO: implement?" try: from tensordict import TensorDict @@ -118,7 +118,9 @@ class PostProcessImage(nn.Module): if boxes is None: assert out_masks is not None - assert not ret_tensordict, "We don't support returning TensorDict if the output does not contain boxes" + assert not ret_tensordict, ( + "We don't support returning TensorDict if the output does not contain boxes" + ) B = len(out_masks) boxes = [None] * B scores = [None] * B @@ -418,9 +420,9 @@ class PostProcessAPIVideo(PostProcessImage): if video_id == -1: video_id = unique_vid_id.item() else: - assert ( - video_id == unique_vid_id.item() - ), "We can only postprocess one video per datapoint" + assert video_id == unique_vid_id.item(), ( + "We can only postprocess one video per datapoint" + ) # keeping track of which objects appear in the current frame obj_ids_per_frame = frame_outs["pred_object_ids"] assert obj_ids_per_frame.size(-1) == frame_outs["pred_logits"].size(-2) diff --git a/sam3/eval/ytvis_coco_wrapper.py b/sam3/eval/ytvis_coco_wrapper.py index 5412d69..25feda4 100644 --- a/sam3/eval/ytvis_coco_wrapper.py +++ b/sam3/eval/ytvis_coco_wrapper.py @@ -95,9 +95,9 @@ class YTVIS(COCO): anns = resFile assert type(anns) == list, "results is not an array of objects" annsImgIds = [ann["image_id"] for ann in anns] - assert set(annsImgIds) == ( - set(annsImgIds) & set(self.getImgIds()) - ), "Results do not correspond to current coco set" + assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), ( + "Results do not correspond to current coco set" + ) if "bboxes" in anns[0] and not anns[0]["bboxes"] == []: res.dataset["categories"] = copy.deepcopy(self.dataset["categories"]) for id, ann in enumerate(anns): diff --git a/sam3/eval/ytvis_eval.py b/sam3/eval/ytvis_eval.py index 2ff2309..0754b62 100644 --- a/sam3/eval/ytvis_eval.py +++ b/sam3/eval/ytvis_eval.py @@ -109,9 +109,7 @@ class YTVISevalMixin: ) # Num preds x Num GTS x Num frames inter = inter.sum(-1) union = union.sum(-1) - assert ( - union > 0 - ).all(), ( + assert (union > 0).all(), ( "There exists a tracklet with zero GTs across time. This is suspicious" ) return inter / union @@ -136,9 +134,9 @@ class YTVISevalMixin: iou = inter / union assert iou >= 0 and iou <= 1, "Encountered an error in IoU computation" else: - assert np.isclose(inter, 0) and np.isclose( - union, 0 - ), "Encountered an error in IoU computation" + assert np.isclose(inter, 0) and np.isclose(union, 0), ( + "Encountered an error in IoU computation" + ) iou = 1 return iou @@ -206,16 +204,16 @@ class YTVISResultsWriter: if len(prediction) == 0: continue for k in ["boxes", "scores", "labels"]: - assert ( - k in prediction - ), f"Expected predictions to have `{k}` key, available keys are {prediction.keys()}" + assert k in prediction, ( + f"Expected predictions to have `{k}` key, available keys are {prediction.keys()}" + ) if self.save_per_frame_scores: - assert ( - "per_frame_scores" in prediction - ), f"Expected predictions to have `per_frame_scores` key, available keys are {prediction.keys()}" - assert xor( - "masks" in prediction, "masks_rle" in prediction - ), f"Expected predictions to have either `masks` key or `masks_rle` key, available keys are {prediction.keys()}" + assert "per_frame_scores" in prediction, ( + f"Expected predictions to have `per_frame_scores` key, available keys are {prediction.keys()}" + ) + assert xor("masks" in prediction, "masks_rle" in prediction), ( + f"Expected predictions to have either `masks` key or `masks_rle` key, available keys are {prediction.keys()}" + ) boxes = prediction["boxes"] boxes = convert_to_xywh(boxes).tolist() @@ -223,9 +221,9 @@ class YTVISResultsWriter: labels = prediction["labels"].tolist() if "masks" in prediction: masks = prediction["masks"].squeeze(2) - assert ( - masks.ndim == 4 - ), "Expected masks to be of shape(N_preds,T_frames,H,W)" + assert masks.ndim == 4, ( + "Expected masks to be of shape(N_preds,T_frames,H,W)" + ) areas = [mask.flatten(1).sum(1).tolist() for mask in masks] rles = [rle_encode(masklet) for masklet in masks] diff --git a/sam3/logger.py b/sam3/logger.py index 35dcc0d..2ce9c09 100644 --- a/sam3/logger.py +++ b/sam3/logger.py @@ -42,9 +42,9 @@ def get_logger(name, level=logging.INFO): """A command line logger.""" if "LOG_LEVEL" in os.environ: level = os.environ["LOG_LEVEL"].upper() - assert ( - level in LOG_LEVELS - ), f"Invalid LOG_LEVEL: {level}, must be one of {list(LOG_LEVELS.keys())}" + assert level in LOG_LEVELS, ( + f"Invalid LOG_LEVEL: {level}, must be one of {list(LOG_LEVELS.keys())}" + ) level = LOG_LEVELS[level] logger = logging.getLogger(name) logger.setLevel(level) diff --git a/sam3/model/data_misc.py b/sam3/model/data_misc.py index 8f2efa9..298340d 100644 --- a/sam3/model/data_misc.py +++ b/sam3/model/data_misc.py @@ -7,7 +7,6 @@ Misc functions, including distributed helpers. import collections import re - from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union @@ -29,9 +28,9 @@ def interpolate( input, size, scale_factor, mode, align_corners ) - assert ( - input.shape[0] != 0 or input.shape[1] != 0 - ), "At least one of the two first dimensions must be non zero" + assert input.shape[0] != 0 or input.shape[1] != 0, ( + "At least one of the two first dimensions must be non zero" + ) if input.shape[1] == 0: # Pytorch doesn't support null dimension on the channel dimension, so we transpose to fake a null batch dim diff --git a/sam3/model/decoder.py b/sam3/model/decoder.py index c074b9c..7a204be 100644 --- a/sam3/model/decoder.py +++ b/sam3/model/decoder.py @@ -9,18 +9,13 @@ Inspired from Pytorch's version, adds the pre-norm variant from typing import Any, Dict, List, Optional import numpy as np - import torch - from sam3.sam.transformer import RoPEAttention - from torch import nn, Tensor from torchvision.ops.roi_align import RoIAlign from .act_ckpt_utils import activation_ckpt_wrapper - from .box_ops import box_cxcywh_to_xyxy - from .model_misc import ( gen_sineembed_for_position, get_activation_fn, @@ -444,9 +439,9 @@ class TransformerDecoder(nn.Module): - valid_ratios/spatial_shapes: bs, nlevel, 2 """ if memory_mask is not None: - assert ( - self.boxRPB == "none" - ), "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented" + assert self.boxRPB == "none", ( + "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented" + ) apply_dac = apply_dac if apply_dac is not None else self.dac if apply_dac: @@ -516,18 +511,18 @@ class TransformerDecoder(nn.Module): query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model if self.boxRPB != "none" and reference_boxes is not None: - assert ( - spatial_shapes.shape[0] == 1 - ), "only single scale support implemented" + assert spatial_shapes.shape[0] == 1, ( + "only single scale support implemented" + ) memory_mask = self._get_rpb_matrix( reference_boxes, (spatial_shapes[0, 0], spatial_shapes[0, 1]), ) memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W) if self.training: - assert ( - self.use_act_checkpoint - ), "Activation checkpointing not enabled in the decoder" + assert self.use_act_checkpoint, ( + "Activation checkpointing not enabled in the decoder" + ) output, presence_out = activation_ckpt_wrapper(layer)( tgt=output, tgt_query_pos=query_pos, @@ -676,9 +671,9 @@ class TransformerEncoderCrossAttention(nn.Module): src_pos[0], ) - assert ( - src.shape[1] == prompt.shape[1] - ), "Batch size must be the same for src and prompt" + assert src.shape[1] == prompt.shape[1], ( + "Batch size must be the same for src and prompt" + ) output = src diff --git a/sam3/model/encoder.py b/sam3/model/encoder.py index d825df4..3fc9406 100644 --- a/sam3/model/encoder.py +++ b/sam3/model/encoder.py @@ -322,9 +322,9 @@ class TransformerEncoder(nn.Module): return reference_points def _prepare_multilevel_features(self, srcs, masks, pos_embeds): - assert ( - len(srcs) == self.num_feature_levels - ), "mismatch between expected and received # of feature levels" + assert len(srcs) == self.num_feature_levels, ( + "mismatch between expected and received # of feature levels" + ) src_flatten = [] mask_flatten = [] @@ -406,9 +406,9 @@ class TransformerEncoder(nn.Module): - spatial_shapes: Spatial dimensions of each feature level - valid_ratios: Valid ratios for each feature level """ - assert ( - len(src) == self.num_feature_levels - ), "must be equal to num_feature_levels" + assert len(src) == self.num_feature_levels, ( + "must be equal to num_feature_levels" + ) if src_key_padding_masks is not None: assert len(src_key_padding_masks) == self.num_feature_levels if pos is not None: @@ -538,9 +538,9 @@ class TransformerEncoderFusion(TransformerEncoder): else None ) else: - assert all( - x.dim == 4 for x in src - ), "expected list of (bs, c, h, w) tensors" + assert all(x.dim == 4 for x in src), ( + "expected list of (bs, c, h, w) tensors" + ) if self.add_pooled_text_to_img_feat: # Fusion: Add mean pooled text to image features diff --git a/sam3/model/geometry_encoders.py b/sam3/model/geometry_encoders.py index acd9a15..d60ee54 100644 --- a/sam3/model/geometry_encoders.py +++ b/sam3/model/geometry_encoders.py @@ -11,7 +11,6 @@ from typing_extensions import override from .act_ckpt_utils import activation_ckpt_wrapper from .box_ops import box_cxcywh_to_xyxy - from .model_misc import get_clones @@ -148,54 +147,42 @@ class Prompt: ) # Dimension checks - assert ( - box_embeddings is not None - and list(box_embeddings.shape[:2]) - == [ - box_seq_len, - bs, - ] - ), f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}" - assert ( - box_mask is not None - and list(box_mask.shape) - == [ - bs, - box_seq_len, - ] - ), f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}" - assert ( - point_embeddings is not None - and list(point_embeddings.shape[:2]) - == [ - point_seq_len, - bs, - ] - ), f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}" - assert ( - point_mask is not None - and list(point_mask.shape) - == [ - bs, - point_seq_len, - ] - ), f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}" - assert ( - box_labels is not None - and list(box_labels.shape) - == [ - box_seq_len, - bs, - ] - ), f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}" - assert ( - point_labels is not None - and list(point_labels.shape) - == [ - point_seq_len, - bs, - ] - ), f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}" + assert box_embeddings is not None and list(box_embeddings.shape[:2]) == [ + box_seq_len, + bs, + ], ( + f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}" + ) + assert box_mask is not None and list(box_mask.shape) == [ + bs, + box_seq_len, + ], ( + f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}" + ) + assert point_embeddings is not None and list(point_embeddings.shape[:2]) == [ + point_seq_len, + bs, + ], ( + f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}" + ) + assert point_mask is not None and list(point_mask.shape) == [ + bs, + point_seq_len, + ], ( + f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}" + ) + assert box_labels is not None and list(box_labels.shape) == [ + box_seq_len, + bs, + ], ( + f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}" + ) + assert point_labels is not None and list(point_labels.shape) == [ + point_seq_len, + bs, + ], ( + f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}" + ) assert ( # Allowed to be None, we leave it to the encoder to check for validity before encoding. mask_embeddings is None @@ -204,41 +191,41 @@ class Prompt: mask_seq_len, bs, ] - ), f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}" - assert ( - mask_mask is None - or list(mask_mask.shape) - == [ - bs, - mask_seq_len, - ] - ), f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}" + ), ( + f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}" + ) + assert mask_mask is None or list(mask_mask.shape) == [ + bs, + mask_seq_len, + ], ( + f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}" + ) # Device checks - assert ( - box_embeddings is not None and box_embeddings.device == device - ), f"Expected box embeddings to be on device {device}, got {box_embeddings.device}" - assert ( - box_mask is not None and box_mask.device == device - ), f"Expected box mask to be on device {device}, got {box_mask.device}" - assert ( - box_labels is not None and box_labels.device == device - ), f"Expected box labels to be on device {device}, got {box_labels.device}" - assert ( - point_embeddings is not None and point_embeddings.device == device - ), f"Expected point embeddings to be on device {device}, got {point_embeddings.device}" - assert ( - point_mask is not None and point_mask.device == device - ), f"Expected point mask to be on device {device}, got {point_mask.device}" - assert ( - point_labels is not None and point_labels.device == device - ), f"Expected point labels to be on device {device}, got {point_labels.device}" - assert ( - mask_embeddings is None or mask_embeddings.device == device - ), f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}" - assert ( - mask_mask is None or mask_mask.device == device - ), f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}" + assert box_embeddings is not None and box_embeddings.device == device, ( + f"Expected box embeddings to be on device {device}, got {box_embeddings.device}" + ) + assert box_mask is not None and box_mask.device == device, ( + f"Expected box mask to be on device {device}, got {box_mask.device}" + ) + assert box_labels is not None and box_labels.device == device, ( + f"Expected box labels to be on device {device}, got {box_labels.device}" + ) + assert point_embeddings is not None and point_embeddings.device == device, ( + f"Expected point embeddings to be on device {device}, got {point_embeddings.device}" + ) + assert point_mask is not None and point_mask.device == device, ( + f"Expected point mask to be on device {device}, got {point_mask.device}" + ) + assert point_labels is not None and point_labels.device == device, ( + f"Expected point labels to be on device {device}, got {point_labels.device}" + ) + assert mask_embeddings is None or mask_embeddings.device == device, ( + f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}" + ) + assert mask_mask is None or mask_mask.device == device, ( + f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}" + ) self.box_embeddings = box_embeddings self.point_embeddings = point_embeddings @@ -264,30 +251,30 @@ class Prompt: if point_embeddings is not None: point_seq_len = point_embeddings.shape[0] if bs is not None: - assert ( - bs == point_embeddings.shape[1] - ), f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}." + assert bs == point_embeddings.shape[1], ( + f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}." + ) else: bs = point_embeddings.shape[1] if device is not None: - assert ( - device == point_embeddings.device - ), "Device mismatch between box and point embeddings" + assert device == point_embeddings.device, ( + "Device mismatch between box and point embeddings" + ) else: device = point_embeddings.device if mask_embeddings is not None: mask_seq_len = mask_embeddings.shape[0] if bs is not None: - assert ( - bs == mask_embeddings.shape[1] - ), f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}" + assert bs == mask_embeddings.shape[1], ( + f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}" + ) else: bs = mask_embeddings.shape[1] if device is not None: - assert ( - device == mask_embeddings.device - ), "Device mismatch between box/point and mask embeddings." + assert device == mask_embeddings.device, ( + "Device mismatch between box/point and mask embeddings." + ) else: device = mask_embeddings.device @@ -539,9 +526,9 @@ class SequenceGeometryEncoder(nn.Module): if add_cls: self.cls_embed = torch.nn.Embedding(1, self.d_model) - assert ( - points_direct_project or points_pos_enc or points_pool - ), "Error: need at least one way to encode points" + assert points_direct_project or points_pos_enc or points_pool, ( + "Error: need at least one way to encode points" + ) assert ( encode_boxes_as_points or boxes_direct_project @@ -583,16 +570,16 @@ class SequenceGeometryEncoder(nn.Module): self.encode = None if num_layers > 0: - assert ( - add_cls - ), "It's currently highly recommended to add a CLS when using a transformer" + assert add_cls, ( + "It's currently highly recommended to add a CLS when using a transformer" + ) self.encode = get_clones(layer, num_layers) self.encode_norm = nn.LayerNorm(self.d_model) if mask_encoder is not None: - assert isinstance( - mask_encoder, MaskEncoder - ), f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}." + assert isinstance(mask_encoder, MaskEncoder), ( + f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}." + ) if add_mask_label: self.mask_label_embed = torch.nn.Embedding(2, self.d_model) self.add_mask_label = add_mask_label @@ -701,16 +688,15 @@ class SequenceGeometryEncoder(nn.Module): img_feats: torch.Tensor = None, ): n_masks, bs = masks.shape[:2] - assert ( - n_masks == 1 - ), "We assume one mask per prompt for now. Code should still be functional if this assertion is removed." - assert ( - list(attn_mask.shape) - == [ - bs, - n_masks, - ] - ), f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}." + assert n_masks == 1, ( + "We assume one mask per prompt for now. Code should still be functional if this assertion is removed." + ) + assert list(attn_mask.shape) == [ + bs, + n_masks, + ], ( + f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}." + ) masks, pos = self.mask_encoder( masks=masks.flatten(0, 1).float(), pix_feat=img_feats, diff --git a/sam3/model/io_utils.py b/sam3/model/io_utils.py index 082ba23..067f125 100644 --- a/sam3/model/io_utils.py +++ b/sam3/model/io_utils.py @@ -13,9 +13,7 @@ import numpy as np import torch import torch.nn.functional as F import torchvision.transforms.functional as TF - from PIL import Image - from sam3.logger import get_logger from tqdm import tqdm diff --git a/sam3/model/maskformer_segmentation.py b/sam3/model/maskformer_segmentation.py index 8790d7a..a2d5c68 100644 --- a/sam3/model/maskformer_segmentation.py +++ b/sam3/model/maskformer_segmentation.py @@ -248,7 +248,9 @@ class UniversalSegmentationHead(SegmentationHead): self.d_model = hidden_dim if dot_product_scorer is not None: - assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake" + assert presence_head, ( + "Specifying a dot product scorer without a presence head is likely a mistake" + ) self.presence_head = None if presence_head: diff --git a/sam3/model/memory.py b/sam3/model/memory.py index 397e1c8..196dbf9 100644 --- a/sam3/model/memory.py +++ b/sam3/model/memory.py @@ -62,9 +62,9 @@ class SimpleMaskDownSampler(nn.Module): self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) self.interpol_size = interpol_size if self.interpol_size is not None: - assert isinstance( - self.interpol_size, (list, tuple) - ), f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple." + assert isinstance(self.interpol_size, (list, tuple)), ( + f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple." + ) self.interpol_size = list(interpol_size) assert len(self.interpol_size) == 2 diff --git a/sam3/model/model_misc.py b/sam3/model/model_misc.py index 9fe38a1..d961461 100644 --- a/sam3/model/model_misc.py +++ b/sam3/model/model_misc.py @@ -330,9 +330,9 @@ class SAM3Output(list): self.output = output else: self.output = [] - assert isinstance( - iter_mode, SAM3Output.IterMode - ), f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}" + assert isinstance(iter_mode, SAM3Output.IterMode), ( + f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}" + ) self.iter_mode = iter_mode # We create a weak reference to self to be used in the lambda functions. @@ -411,9 +411,9 @@ class SAM3Output(list): return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode) def append(self, item: list): - assert isinstance( - item, list - ), f"Only list items are supported. Got {type(item)}" + assert isinstance(item, list), ( + f"Only list items are supported. Got {type(item)}" + ) self.output.append(item) def __repr__(self): diff --git a/sam3/model/necks.py b/sam3/model/necks.py index 5a30825..c60f87f 100644 --- a/sam3/model/necks.py +++ b/sam3/model/necks.py @@ -8,7 +8,6 @@ from copy import deepcopy from typing import List, Optional, Tuple import torch - import torch.nn as nn diff --git a/sam3/model/sam1_task_predictor.py b/sam3/model/sam1_task_predictor.py index 1ca18bd..5cf0fde 100644 --- a/sam3/model/sam1_task_predictor.py +++ b/sam3/model/sam1_task_predictor.py @@ -7,15 +7,12 @@ # LICENSE file in the root directory of this source tree. import logging - from typing import List, Optional, Tuple, Union import numpy as np import torch - import torch.nn as nn from PIL.Image import Image - from sam3.model.sam3_tracker_base import Sam3TrackerBase from sam3.model.utils.sam1_utils import SAM2Transforms @@ -97,9 +94,9 @@ class SAM3InteractiveImagePredictor(nn.Module): input_image = self._transforms(image) input_image = input_image[None, ...].to(self.device) - assert ( - len(input_image.shape) == 4 and input_image.shape[1] == 3 - ), f"input_image must be of size 1x3xHxW, got {input_image.shape}" + assert len(input_image.shape) == 4 and input_image.shape[1] == 3, ( + f"input_image must be of size 1x3xHxW, got {input_image.shape}" + ) logging.info("Computing image embeddings for the provided image...") backbone_out = self.model.forward_image(input_image) ( @@ -136,17 +133,17 @@ class SAM3InteractiveImagePredictor(nn.Module): assert isinstance(image_list, list) self._orig_hw = [] for image in image_list: - assert isinstance( - image, np.ndarray - ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC" + assert isinstance(image, np.ndarray), ( + "Images are expected to be an np.ndarray in RGB format, and of shape HWC" + ) self._orig_hw.append(image.shape[:2]) # Transform the image to the form expected by the model img_batch = self._transforms.forward_batch(image_list) img_batch = img_batch.to(self.device) batch_size = img_batch.shape[0] - assert ( - len(img_batch.shape) == 4 and img_batch.shape[1] == 3 - ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" + assert len(img_batch.shape) == 4 and img_batch.shape[1] == 3, ( + f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" + ) logging.info("Computing image embeddings for the provided images...") backbone_out = self.model.forward_image(img_batch) ( @@ -302,9 +299,9 @@ class SAM3InteractiveImagePredictor(nn.Module): ): unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None if point_coords is not None: - assert ( - point_labels is not None - ), "point_labels must be supplied if point_coords is supplied." + assert point_labels is not None, ( + "point_labels must be supplied if point_coords is supplied." + ) point_coords = torch.as_tensor( point_coords, dtype=torch.float, device=self.device ) @@ -441,9 +438,9 @@ class SAM3InteractiveImagePredictor(nn.Module): raise RuntimeError( "An image must be set with .set_image(...) to generate an embedding." ) - assert ( - self._features is not None - ), "Features must exist if an image has been set." + assert self._features is not None, ( + "Features must exist if an image has been set." + ) return self._features["image_embed"] @property diff --git a/sam3/model/sam3_image.py b/sam3/model/sam3_image.py index 3e4f9de..679300d 100644 --- a/sam3/model/sam3_image.py +++ b/sam3/model/sam3_image.py @@ -8,19 +8,14 @@ from typing import Dict, List, Optional, Tuple import numpy as np import torch - from sam3.model.model_misc import SAM3Output - from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor from sam3.model.vl_combiner import SAM3VLBackbone from sam3.perflib.nms import nms_masks - from sam3.train.data.collator import BatchedDatapoint from .act_ckpt_utils import activation_ckpt_wrapper - from .box_ops import box_cxcywh_to_xyxy - from .geometry_encoders import Prompt from .model_misc import inverse_sigmoid @@ -661,9 +656,9 @@ class Sam3Image(torch.nn.Module): inference_state["original_heights"], inference_state["original_widths"], ) - assert ( - batch_size == len(orig_heights) == len(orig_widths) - ), f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}" + assert batch_size == len(orig_heights) == len(orig_widths), ( + f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}" + ) feats = [ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) for feat, feat_size in zip( diff --git a/sam3/model/sam3_image_processor.py b/sam3/model/sam3_image_processor.py index 7a55e02..83bbb0d 100644 --- a/sam3/model/sam3_image_processor.py +++ b/sam3/model/sam3_image_processor.py @@ -6,9 +6,7 @@ from typing import Dict, List import numpy as np import PIL import torch - from sam3.model import box_ops - from sam3.model.data_misc import FindStage, interpolate from torchvision.transforms import v2 @@ -83,9 +81,9 @@ class Sam3Processor: if not isinstance(images, list): raise ValueError("Images must be a list of PIL images or tensors") assert len(images) > 0, "Images list must not be empty" - assert isinstance( - images[0], PIL.Image.Image - ), "Images must be a list of PIL images" + assert isinstance(images[0], PIL.Image.Image), ( + "Images must be a list of PIL images" + ) state["original_heights"] = [image.height for image in images] state["original_widths"] = [image.width for image in images] diff --git a/sam3/model/sam3_tracker_base.py b/sam3/model/sam3_tracker_base.py index a5c557d..c7f40b7 100644 --- a/sam3/model/sam3_tracker_base.py +++ b/sam3/model/sam3_tracker_base.py @@ -6,11 +6,8 @@ import logging import torch import torch.nn.functional as F - from sam3.model.memory import SimpleMaskEncoder - from sam3.model.sam3_tracker_utils import get_1d_sine_pe, select_closest_cond_frames - from sam3.sam.mask_decoder import MaskDecoder, MLP from sam3.sam.prompt_encoder import PromptEncoder from sam3.sam.transformer import TwoWayTransformer diff --git a/sam3/model/sam3_tracker_utils.py b/sam3/model/sam3_tracker_utils.py index e88c093..e971dac 100644 --- a/sam3/model/sam3_tracker_utils.py +++ b/sam3/model/sam3_tracker_utils.py @@ -6,7 +6,6 @@ import numpy as np import torch import torch.nn.functional as F from numpy.typing import NDArray - from sam3.model.edt import edt_triton diff --git a/sam3/model/sam3_tracking_predictor.py b/sam3/model/sam3_tracking_predictor.py index 0fb3ff7..43b068b 100644 --- a/sam3/model/sam3_tracking_predictor.py +++ b/sam3/model/sam3_tracking_predictor.py @@ -6,7 +6,6 @@ import logging from collections import OrderedDict import torch - from sam3.model.sam3_tracker_base import concat_points, NO_OBJ_SCORE, Sam3TrackerBase from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores from sam3.model.utils.sam2_utils import load_video_frames diff --git a/sam3/model/sam3_video_base.py b/sam3/model/sam3_video_base.py index 78e1044..8780f1a 100644 --- a/sam3/model/sam3_video_base.py +++ b/sam3/model/sam3_video_base.py @@ -16,7 +16,6 @@ import numpy.typing as npt import torch import torch.distributed as dist import torch.nn.functional as F - from sam3 import perflib from sam3.logger import get_logger from sam3.model.box_ops import fast_diag_box_iou @@ -620,9 +619,9 @@ class Sam3VideoBase(nn.Module): num_obj_dropped_due_to_limit, trk_id_to_max_iou_high_conf_det, ] - assert ( - len(update_plan) == NUM_BROADCAST_ITEMS - ), f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}" + assert len(update_plan) == NUM_BROADCAST_ITEMS, ( + f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}" + ) self.broadcast_python_obj_cpu(update_plan, src=0) elif self.rank > 0 and self.world_size > 1: update_plan = [ @@ -842,9 +841,9 @@ class Sam3VideoBase(nn.Module): binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0 batch_size = tracker_low_res_masks_global.size(0) if batch_size > 0: - assert ( - len(obj_ids_global) == batch_size - ), f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}" + assert len(obj_ids_global) == batch_size, ( + f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}" + ) NEVER_OCCLUDED = -1 ALWAYS_OCCLUDED = 100000 # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic last_occluded_prev = torch.cat( @@ -1023,9 +1022,9 @@ class Sam3VideoBase(nn.Module): reverse: bool = False, ): # Suppress overlapping masks for objects that were most recently occluded - assert ( - binary_low_res_masks.dtype == torch.bool - ), f"Expected boolean tensor, got {binary_low_res_masks.dtype}" + assert binary_low_res_masks.dtype == torch.bool, ( + f"Expected boolean tensor, got {binary_low_res_masks.dtype}" + ) to_suppress = torch.zeros( binary_low_res_masks.size(0), device=binary_low_res_masks.device, @@ -1130,9 +1129,9 @@ class Sam3VideoBase(nn.Module): num_frames_propagated += 1 # only 1 frames should be propagated - assert ( - num_frames_propagated == 1 and out_frame_idx == frame_idx - ), f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}" + assert num_frames_propagated == 1 and out_frame_idx == frame_idx, ( + f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}" + ) assert isinstance(out_obj_ids, list) obj_ids_local.extend(out_obj_ids) low_res_masks_list.append(out_low_res_masks.squeeze(1)) @@ -1189,9 +1188,9 @@ class Sam3VideoBase(nn.Module): assert det_masks.is_floating_point(), "float tensor expected (do not binarize)" assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)" - assert ( - trk_masks.size(0) == len(trk_obj_ids) - ), f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}" + assert trk_masks.size(0) == len(trk_obj_ids), ( + f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}" + ) if trk_masks.size(0) == 0: # all detections are new new_det_fa_inds = np.arange(det_masks.size(0)) @@ -1655,9 +1654,9 @@ class Sam3VideoBase(nn.Module): # a) first, expand "confirmation_data" to include new masklets added in this frame status_prev = confirmation_data["status"] consecutive_det_num_prev = confirmation_data["consecutive_det_num"] - assert ( - status_prev.shape == obj_ids_all_gpu_prev.shape - ), f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}" + assert status_prev.shape == obj_ids_all_gpu_prev.shape, ( + f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}" + ) obj_id_to_updated_idx = { obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated) diff --git a/sam3/model/sam3_video_inference.py b/sam3/model/sam3_video_inference.py index 263b4d2..6f031be 100644 --- a/sam3/model/sam3_video_inference.py +++ b/sam3/model/sam3_video_inference.py @@ -9,7 +9,6 @@ import numpy as np import torch import torch.distributed as dist import torch.nn.functional as F - from sam3 import perflib from sam3.logger import get_logger from sam3.model.act_ckpt_utils import clone_output_wrapper @@ -555,7 +554,9 @@ class Sam3VideoInference(Sam3VideoBase): assert ( "cached_frame_outputs" in inference_state and frame_idx in inference_state["cached_frame_outputs"] - ), "No cached outputs found. Ensure normal propagation has run first to populate the cache." + ), ( + "No cached outputs found. Ensure normal propagation has run first to populate the cache." + ) cached_outputs = inference_state["cached_frame_outputs"][frame_idx] obj_id_to_mask = cached_outputs.copy() @@ -563,9 +564,9 @@ class Sam3VideoInference(Sam3VideoBase): # Update with refined masks if provided if refined_obj_id_to_mask is not None: for obj_id, refined_mask in refined_obj_id_to_mask.items(): - assert ( - refined_mask is not None - ), f"Refined mask data must be provided for obj_id {obj_id}" + assert refined_mask is not None, ( + f"Refined mask data must be provided for obj_id {obj_id}" + ) obj_id_to_mask[obj_id] = refined_mask return obj_id_to_mask @@ -660,12 +661,12 @@ class Sam3VideoInference(Sam3VideoBase): for i, thresh in enumerate(new_det_score_thresh_list): self.new_det_thresh = thresh for num_objects in num_objects_list: - logger.info(f"{i+1}/{num_rounds} warming up model compilation") + logger.info(f"{i + 1}/{num_rounds} warming up model compilation") self.add_prompt( inference_state, frame_idx=start_frame_idx, text_str="cat" ) logger.info( - f"{i+1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects" + f"{i + 1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects" ) inference_state = self.add_fake_objects_to_inference_state( inference_state, num_objects, frame_idx=start_frame_idx @@ -690,7 +691,7 @@ class Sam3VideoInference(Sam3VideoBase): pass self.reset_state(inference_state) logger.info( - f"{i+1}/{num_rounds} warming up model compilation -- completed round {i+1} out of {num_rounds}" + f"{i + 1}/{num_rounds} warming up model compilation -- completed round {i + 1} out of {num_rounds}" ) # Warm up Tracker memory encoder with varying input shapes @@ -854,12 +855,12 @@ class Sam3VideoInference(Sam3VideoBase): logger.debug("Running add_prompt on frame %d", frame_idx) num_frames = inference_state["num_frames"] - assert ( - text_str is not None or boxes_xywh is not None - ), "at least one type of prompt (text, boxes) must be provided" - assert ( - 0 <= frame_idx < num_frames - ), f"{frame_idx=} is out of range for a total of {num_frames} frames" + assert text_str is not None or boxes_xywh is not None, ( + "at least one type of prompt (text, boxes) must be provided" + ) + assert 0 <= frame_idx < num_frames, ( + f"{frame_idx=} is out of range for a total of {num_frames} frames" + ) # since it's a semantic prompt, we start over self.reset_state(inference_state) @@ -1200,9 +1201,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference): "propagation_partial", "propagation_fetch", ] - assert ( - action_type in instance_actions + propagation_actions - ), f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}" + assert action_type in instance_actions + propagation_actions, ( + f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}" + ) action = { "type": action_type, "frame_idx": frame_idx, @@ -1370,12 +1371,12 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference): ): if points is not None: # Tracker instance prompts - assert ( - text_str is None and boxes_xywh is None - ), "When points are provided, text_str and boxes_xywh must be None." - assert ( - obj_id is not None - ), "When points are provided, obj_id must be provided." + assert text_str is None and boxes_xywh is None, ( + "When points are provided, text_str and boxes_xywh must be None." + ) + assert obj_id is not None, ( + "When points are provided, obj_id must be provided." + ) return self.add_tracker_new_points( inference_state, frame_idx, @@ -1491,9 +1492,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference): tracker_states = self._get_tracker_inference_states_by_obj_ids( inference_state, [obj_id] ) - assert ( - len(tracker_states) == 1 - ), f"[rank={self.rank}] Multiple Tracker inference states found for the same object id." + assert len(tracker_states) == 1, ( + f"[rank={self.rank}] Multiple Tracker inference states found for the same object id." + ) tracker_state = tracker_states[0] # log diff --git a/sam3/model/sam3_video_predictor.py b/sam3/model/sam3_video_predictor.py index de2ec60..13b1448 100644 --- a/sam3/model/sam3_video_predictor.py +++ b/sam3/model/sam3_video_predictor.py @@ -16,7 +16,6 @@ from typing import List, Optional import psutil import torch - from sam3.logger import get_logger logger = get_logger(__name__) @@ -170,7 +169,7 @@ class Sam3VideoPredictor: ): """Remove an object from tracking.""" logger.debug( - f"remove object {obj_id} in session {session_id}: " f"{is_user_action=}" + f"remove object {obj_id} in session {session_id}: {is_user_action=}" ) session = self._get_session(session_id) inference_state = session["state"] diff --git a/sam3/model/text_encoder_ve.py b/sam3/model/text_encoder_ve.py index 53ddd5d..5788358 100644 --- a/sam3/model/text_encoder_ve.py +++ b/sam3/model/text_encoder_ve.py @@ -318,9 +318,9 @@ class VETextEncoder(nn.Module): # The text is already encoded, use as is. text_attention_mask, text_memory_resized, tokenized = text inputs_embeds = tokenized["inputs_embeds"] - assert ( - input_boxes is None or len(input_boxes) == 0 - ), "Can't replace boxes in text if it's already encoded" + assert input_boxes is None or len(input_boxes) == 0, ( + "Can't replace boxes in text if it's already encoded" + ) # Note that the input_embeds are returned in pytorch's convention (sequence first) return ( diff --git a/sam3/model/vitdet.py b/sam3/model/vitdet.py index 1b4d41b..bc4eeb0 100644 --- a/sam3/model/vitdet.py +++ b/sam3/model/vitdet.py @@ -708,9 +708,9 @@ class ViT(nn.Module): self.retain_cls_token = retain_cls_token if self.retain_cls_token: assert pretrain_use_cls_token - assert ( - len(window_block_indexes) == 0 - ), "windowing not supported with cls token" + assert len(window_block_indexes) == 0, ( + "windowing not supported with cls token" + ) assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token" diff --git a/sam3/model/vl_combiner.py b/sam3/model/vl_combiner.py index 5c400e0..faf5504 100644 --- a/sam3/model/vl_combiner.py +++ b/sam3/model/vl_combiner.py @@ -9,7 +9,6 @@ from typing import List, Optional import torch import torch.nn as nn - from torch.nn.attention import sdpa_kernel, SDPBackend from .act_ckpt_utils import activation_ckpt_wrapper diff --git a/sam3/model_builder.py b/sam3/model_builder.py index 4df2671..103b324 100644 --- a/sam3/model_builder.py +++ b/sam3/model_builder.py @@ -6,7 +6,6 @@ import os from typing import Optional import pkg_resources - import torch import torch.nn as nn from huggingface_hub import hf_hub_download diff --git a/sam3/perflib/connected_components.py b/sam3/perflib/connected_components.py index 0be67e5..fa0506d 100644 --- a/sam3/perflib/connected_components.py +++ b/sam3/perflib/connected_components.py @@ -36,9 +36,9 @@ def connected_components_cpu(input_tensor: torch.Tensor): if input_tensor.dim() == 4 and input_tensor.shape[1] == 1: input_tensor = input_tensor.squeeze(1) else: - assert ( - input_tensor.dim() == 3 - ), "Input tensor must be (B, H, W) or (B, 1, H, W)." + assert input_tensor.dim() == 3, ( + "Input tensor must be (B, H, W) or (B, 1, H, W)." + ) batch_size = input_tensor.shape[0] labels_list = [] @@ -67,9 +67,9 @@ def connected_components(input_tensor: torch.Tensor): if input_tensor.dim() == 3: input_tensor = input_tensor.unsqueeze(1) - assert ( - input_tensor.dim() == 4 and input_tensor.shape[1] == 1 - ), "Input tensor must be (B, H, W) or (B, 1, H, W)." + assert input_tensor.dim() == 4 and input_tensor.shape[1] == 1, ( + "Input tensor must be (B, H, W) or (B, 1, H, W)." + ) if input_tensor.is_cuda: if HAS_CC_TORCH: diff --git a/sam3/perflib/nms.py b/sam3/perflib/nms.py index 1b2f835..acd9162 100644 --- a/sam3/perflib/nms.py +++ b/sam3/perflib/nms.py @@ -6,7 +6,6 @@ import logging import numpy as np import torch - from sam3.perflib.masks_ops import mask_iou diff --git a/sam3/perflib/triton/connected_components.py b/sam3/perflib/triton/connected_components.py index cdb7d44..1d4376e 100644 --- a/sam3/perflib/triton/connected_components.py +++ b/sam3/perflib/triton/connected_components.py @@ -407,16 +407,16 @@ def connected_components_triton(input_tensor: torch.Tensor): - A BxHxW output tensor with dense labels. Background is 0. - A BxHxW tensor with the size of the connected component for each pixel. """ - assert ( - input_tensor.is_cuda and input_tensor.is_contiguous() - ), "Input tensor must be a contiguous CUDA tensor." + assert input_tensor.is_cuda and input_tensor.is_contiguous(), ( + "Input tensor must be a contiguous CUDA tensor." + ) out_shape = input_tensor.shape if input_tensor.dim() == 4 and input_tensor.shape[1] == 1: input_tensor = input_tensor.squeeze(1) else: - assert ( - input_tensor.dim() == 3 - ), "Input tensor must be (B, H, W) or (B, 1, H, W)." + assert input_tensor.dim() == 3, ( + "Input tensor must be (B, H, W) or (B, 1, H, W)." + ) B, H, W = input_tensor.shape numel = B * H * W diff --git a/sam3/sam/mask_decoder.py b/sam3/sam/mask_decoder.py index 3e1bbd2..944e57a 100644 --- a/sam3/sam/mask_decoder.py +++ b/sam3/sam/mask_decoder.py @@ -202,9 +202,9 @@ class MaskDecoder(nn.Module): assert image_embeddings.shape[0] == tokens.shape[0] src = image_embeddings src = src + dense_prompt_embeddings - assert ( - image_pe.size(0) == 1 - ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + assert image_pe.size(0) == 1, ( + "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + ) pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) b, c, h, w = src.shape diff --git a/sam3/sam/transformer.py b/sam3/sam/transformer.py index 5c4bd34..1ff2380 100644 --- a/sam3/sam/transformer.py +++ b/sam3/sam/transformer.py @@ -8,7 +8,6 @@ from typing import Tuple, Type import torch import torch.nn.functional as F - from sam3.sam.rope import apply_rotary_enc, apply_rotary_enc_real, compute_axial_cis from torch import nn, Tensor @@ -205,9 +204,9 @@ class Attention(nn.Module): self.internal_dim = embedding_dim // downsample_rate self.num_heads = num_heads self.use_fa3 = use_fa3 - assert ( - self.internal_dim % num_heads == 0 - ), "num_heads must divide embedding_dim." + assert self.internal_dim % num_heads == 0, ( + "num_heads must divide embedding_dim." + ) self.q_proj = nn.Linear(embedding_dim, self.internal_dim) self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) diff --git a/sam3/train/data/coco_json_loaders.py b/sam3/train/data/coco_json_loaders.py index 1618e19..d39b6f4 100644 --- a/sam3/train/data/coco_json_loaders.py +++ b/sam3/train/data/coco_json_loaders.py @@ -142,9 +142,9 @@ class COCO_FROM_JSON: self.prompts = {} for loc_dict in prompts: self.prompts[int(loc_dict["id"])] = loc_dict["name"] - assert len(self.prompts) == len( - self._sorted_cat_ids - ), "Number of prompts must match number of categories" + assert len(self.prompts) == len(self._sorted_cat_ids), ( + "Number of prompts must match number of categories" + ) def getDatapointIds(self): """Return all datapoint indices for training.""" diff --git a/sam3/train/data/collator.py b/sam3/train/data/collator.py index b32b7e3..4a0f2e8 100644 --- a/sam3/train/data/collator.py +++ b/sam3/train/data/collator.py @@ -6,7 +6,6 @@ from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_data from typing import Any, get_args, get_origin, List, Union import torch - from sam3.model.data_misc import ( BatchedDatapoint, BatchedFindTarget, @@ -217,9 +216,9 @@ def collate_fn_api( text_batch.append(q.query_text) stages[stage_id].text_ids.append(text_batch.index(q.query_text)) - assert ( - q.inference_metadata is not None - ), "inference_metadata must be provided when FindQueryLoaded is created." + assert q.inference_metadata is not None, ( + "inference_metadata must be provided when FindQueryLoaded is created." + ) for f in fields(q.inference_metadata): getattr(find_metadatas[stage_id], f.name).append( getattr(q.inference_metadata, f.name) diff --git a/sam3/train/data/sam3_image_dataset.py b/sam3/train/data/sam3_image_dataset.py index 7941e37..c5a1c83 100644 --- a/sam3/train/data/sam3_image_dataset.py +++ b/sam3/train/data/sam3_image_dataset.py @@ -19,10 +19,8 @@ import torch.utils.data import torchvision from decord import cpu, VideoReader from iopath.common.file_io import g_pathmgr - from PIL import Image as PILImage from PIL.Image import DecompressionBombError - from sam3.model.box_ops import box_xywh_to_xyxy from torchvision.datasets.vision import VisionDataset @@ -234,9 +232,9 @@ class CustomCocoDetectionAPI(VisionDataset): if self.coco is not None: return - assert g_pathmgr.isfile( - self.annFile - ), f"please provide valid annotation file. Missing: {self.annFile}" + assert g_pathmgr.isfile(self.annFile), ( + f"please provide valid annotation file. Missing: {self.annFile}" + ) annFile = g_pathmgr.get_local_path(self.annFile) if self.coco is not None: @@ -326,9 +324,9 @@ class CustomCocoDetectionAPI(VisionDataset): else: num_queries_per_stage = stage2num_queries.most_common(1)[0][1] for stage, num_queries in stage2num_queries.items(): - assert ( - num_queries == num_queries_per_stage - ), f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}" + assert num_queries == num_queries_per_stage, ( + f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}" + ) for query_id, query in enumerate(queries): h, w = id2imsize[query["image_id"]] diff --git a/sam3/train/data/sam3_video_dataset.py b/sam3/train/data/sam3_video_dataset.py index 91396f0..75eee62 100644 --- a/sam3/train/data/sam3_video_dataset.py +++ b/sam3/train/data/sam3_video_dataset.py @@ -3,7 +3,6 @@ # pyre-unsafe import copy - import io import json import logging @@ -16,7 +15,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch import torchvision - # from decord import cpu, VideoReader from iopath.common.file_io import PathManager @@ -220,9 +218,9 @@ class VideoGroundingDataset(Sam3ImageDataset): for query in filtered_queries: ptr_x_is_empty = query["ptr_x_query_id"] in [None, -1] ptr_y_is_empty = query["ptr_y_query_id"] in [None, -1] - assert ( - ptr_x_is_empty and ptr_y_is_empty - ), "Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers" + assert ptr_x_is_empty and ptr_y_is_empty, ( + "Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers" + ) query["query_processing_order"] = stage_id_old2new[ query["query_processing_order"] ] diff --git a/sam3/train/loss/loss_fns.py b/sam3/train/loss/loss_fns.py index 2bb9039..8fa1774 100644 --- a/sam3/train/loss/loss_fns.py +++ b/sam3/train/loss/loss_fns.py @@ -9,11 +9,8 @@ import torch import torch.distributed import torch.nn.functional as F import torchmetrics - from sam3.model import box_ops - from sam3.model.data_misc import interpolate - from sam3.train.loss.sigmoid_focal_loss import ( triton_sigmoid_focal_loss, triton_sigmoid_focal_loss_reduce, @@ -327,7 +324,9 @@ class IABCEMdetr(LossWithWeights): if num_det_queries is not None: logging.warning("note: it's not needed to set num_det_queries anymore") if self.use_separate_loss_for_det_and_trk: - assert not self.weak_loss, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead" + assert not self.weak_loss, ( + "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead" + ) self.det_exhaustive_loss_scale_pos = det_exhaustive_loss_scale_pos self.det_exhaustive_loss_scale_neg = det_exhaustive_loss_scale_neg self.det_non_exhaustive_loss_scale_pos = det_non_exhaustive_loss_scale_pos @@ -342,7 +341,9 @@ class IABCEMdetr(LossWithWeights): and det_non_exhaustive_loss_scale_neg == 1.0 and trk_loss_scale_pos == 1.0 and trk_loss_scale_neg == 1.0 - ), "If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0" + ), ( + "If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0" + ) def get_loss(self, outputs, targets, indices, num_boxes): assert len(outputs["pred_logits"].shape) > 2, "Incorrect predicted logits shape" @@ -443,7 +444,9 @@ class IABCEMdetr(LossWithWeights): pass if self.weak_loss: - assert not self.use_separate_loss_for_det_and_trk, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead" + assert not self.use_separate_loss_for_det_and_trk, ( + "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead" + ) # nullify the negative loss for the non-exhaustive classes assert loss_bce.shape[0] == targets["is_exhaustive"].shape[0] @@ -497,9 +500,9 @@ class IABCEMdetr(LossWithWeights): loss_bce = loss_bce.mean() else: assert isinstance(self.pad_n_queries, int) - assert ( - loss_bce.size(1) < self.pad_n_queries - ), f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions." + assert loss_bce.size(1) < self.pad_n_queries, ( + f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions." + ) loss_bce = loss_bce.sum() / (self.pad_n_queries * loss_bce.size(0)) bce_f1 = torchmetrics.functional.f1_score( diff --git a/sam3/train/loss/sam3_loss.py b/sam3/train/loss/sam3_loss.py index 300100a..5aa5791 100644 --- a/sam3/train/loss/sam3_loss.py +++ b/sam3/train/loss/sam3_loss.py @@ -3,9 +3,7 @@ # pyre-unsafe import torch - from sam3.model.model_misc import SAM3Output - from sam3.train.utils.distributed import get_world_size from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks diff --git a/sam3/train/masks_ops.py b/sam3/train/masks_ops.py index a98e250..113a9c4 100644 --- a/sam3/train/masks_ops.py +++ b/sam3/train/masks_ops.py @@ -103,9 +103,9 @@ def dilation(mask, kernel_size): assert mask.ndim == 3 kernel_size = int(kernel_size) - assert ( - kernel_size % 2 == 1 - ), f"Dilation expects a odd kernel size, got {kernel_size}" + assert kernel_size % 2 == 1, ( + f"Dilation expects a odd kernel size, got {kernel_size}" + ) if mask.is_cuda: m = mask.unsqueeze(1).to(torch.float16) diff --git a/sam3/train/matcher.py b/sam3/train/matcher.py index 660e1d0..5adc6a4 100644 --- a/sam3/train/matcher.py +++ b/sam3/train/matcher.py @@ -8,7 +8,6 @@ Modules to compute the matching cost and solve the corresponding LSAP. import numpy as np import torch - from sam3.model.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou from scipy.optimize import linear_sum_assignment from torch import nn @@ -60,9 +59,9 @@ class HungarianMatcher(nn.Module): self.cost_bbox = cost_bbox self.cost_giou = cost_giou self.norm = nn.Sigmoid() if focal_loss else nn.Softmax(-1) - assert ( - cost_class != 0 or cost_bbox != 0 or cost_giou != 0 - ), "all costs cant be 0" + assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, ( + "all costs cant be 0" + ) self.focal_loss = focal_loss self.focal_alpha = focal_alpha self.focal_gamma = focal_gamma @@ -197,9 +196,9 @@ class BinaryHungarianMatcher(nn.Module): self.cost_bbox = cost_bbox self.cost_giou = cost_giou self.norm = nn.Sigmoid() - assert ( - cost_class != 0 or cost_bbox != 0 or cost_giou != 0 - ), "all costs cant be 0" + assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, ( + "all costs cant be 0" + ) @torch.no_grad() def forward(self, outputs, batched_targets, repeats=0, repeat_batch=1): @@ -322,9 +321,9 @@ class BinaryFocalHungarianMatcher(nn.Module): self.alpha = alpha self.gamma = gamma self.stable = stable - assert ( - cost_class != 0 or cost_bbox != 0 or cost_giou != 0 - ), "all costs cant be 0" + assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, ( + "all costs cant be 0" + ) @torch.no_grad() def forward(self, outputs, batched_targets, repeats=1, repeat_batch=1): @@ -470,9 +469,9 @@ class BinaryHungarianMatcherV2(nn.Module): self.cost_bbox = cost_bbox self.cost_giou = cost_giou self.norm = nn.Sigmoid() - assert ( - cost_class != 0 or cost_bbox != 0 or cost_giou != 0 - ), "all costs cant be 0" + assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, ( + "all costs cant be 0" + ) self.focal = focal if focal: self.alpha = alpha diff --git a/sam3/train/optim/optimizer.py b/sam3/train/optim/optimizer.py index 7fa9b90..66f305e 100644 --- a/sam3/train/optim/optimizer.py +++ b/sam3/train/optim/optimizer.py @@ -22,7 +22,6 @@ from typing import ( ) import hydra - import torch import torch.nn as nn from omegaconf import DictConfig @@ -212,9 +211,9 @@ def unix_module_cls_pattern_to_parameter_names( "match any classes in the model" ) matching_parameters = module_cls_to_param_names[module_cls] - assert ( - len(matching_parameters) > 0 - ), f"module_cls_name {module_cls_name} does not contain any parameters in the model" + assert len(matching_parameters) > 0, ( + f"module_cls_name {module_cls_name} does not contain any parameters in the model" + ) logging.info( f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} " ) @@ -240,9 +239,9 @@ def unix_param_pattern_to_parameter_names( allowed_parameter_names = [] for param_name in filter_param_names: matching_parameters = set(fnmatch.filter(parameter_names, param_name)) - assert ( - len(matching_parameters) >= 1 - ), f"param_name {param_name} does not match any parameters in the model" + assert len(matching_parameters) >= 1, ( + f"param_name {param_name} does not match any parameters in the model" + ) logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}") allowed_parameter_names.append(matching_parameters) return set.union(*allowed_parameter_names) diff --git a/sam3/train/train.py b/sam3/train/train.py index 976fde3..df6cbe1 100644 --- a/sam3/train/train.py +++ b/sam3/train/train.py @@ -12,13 +12,10 @@ from copy import deepcopy import submitit import torch - from hydra import compose, initialize_config_module from hydra.utils import instantiate - from iopath.common.file_io import g_pathmgr from omegaconf import OmegaConf - from sam3.train.utils.train_utils import makedir, register_omegaconf_resolvers from tqdm import tqdm @@ -212,9 +209,9 @@ def main(args) -> None: }, } if "include_nodes" in submitit_conf: - assert ( - len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes - ), "Not enough nodes" + assert len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes, ( + "Not enough nodes" + ) job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join( submitit_conf["include_nodes"] ) diff --git a/sam3/train/trainer.py b/sam3/train/trainer.py index 4d25d92..cc6c37e 100644 --- a/sam3/train/trainer.py +++ b/sam3/train/trainer.py @@ -15,28 +15,22 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Mapping, Optional import numpy as np - import torch import torch.distributed as dist import torch.nn as nn from hydra.utils import instantiate from iopath.common.file_io import g_pathmgr - from sam3.model.data_misc import BatchedDatapoint from sam3.model.model_misc import SAM3Output from sam3.model.utils.misc import copy_data_to_device - from sam3.train.optim.optimizer import construct_optimizer - from sam3.train.utils.checkpoint_utils import ( assert_skipped_parameters_are_frozen, exclude_params_matching_unix_pattern, load_state_dict_into_model, with_check_parameter_frozen, ) - from sam3.train.utils.distributed import all_reduce_max, barrier, get_rank - from sam3.train.utils.logger import Logger, setup_logging from sam3.train.utils.train_utils import ( AverageMeter, @@ -215,9 +209,9 @@ class Trainer: set_seeds(seed_value, self.max_epochs, self.distributed_rank) log_env_variables() - assert ( - is_dist_avail_and_initialized() - ), "Torch distributed needs to be initialized before calling the trainer." + assert is_dist_avail_and_initialized(), ( + "Torch distributed needs to be initialized before calling the trainer." + ) self._setup_components() # Except Optimizer everything is setup here. self._move_to_device() @@ -227,9 +221,9 @@ class Trainer: self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f") if self.checkpoint_conf.resume_from is not None: - assert os.path.exists( - self.checkpoint_conf.resume_from - ), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!" + assert os.path.exists(self.checkpoint_conf.resume_from), ( + f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!" + ) dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt") if self.distributed_rank == 0 and not os.path.exists(dst): # Copy the "resume_from" checkpoint to the checkpoint folder @@ -477,9 +471,9 @@ class Trainer: return self.loss[key] assert key != "all", "Loss must be specified for key='all'" - assert ( - "default" in self.loss - ), f"Key {key} not found in losss, and no default provided" + assert "default" in self.loss, ( + f"Key {key} not found in losss, and no default provided" + ) return self.loss["default"] def _find_meter(self, phase: str, key: str): @@ -922,12 +916,12 @@ class Trainer: self.optim.zero_grad(set_to_none=True) if self.gradient_accumulation_steps > 1: - assert isinstance( - batch, list - ), f"Expected a list of batches, got {type(batch)}" - assert ( - len(batch) == self.gradient_accumulation_steps - ), f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}" + assert isinstance(batch, list), ( + f"Expected a list of batches, got {type(batch)}" + ) + assert len(batch) == self.gradient_accumulation_steps, ( + f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}" + ) accum_steps = len(batch) else: accum_steps = 1 @@ -1039,9 +1033,9 @@ class Trainer: def _check_val_key_match(self, val_keys, phase): if val_keys is not None: # Check if there are any duplicates - assert len(val_keys) == len( - set(val_keys) - ), f"Duplicate keys in val datasets, keys: {val_keys}" + assert len(val_keys) == len(set(val_keys)), ( + f"Duplicate keys in val datasets, keys: {val_keys}" + ) # Check that the keys match the meter keys if self.meters_conf is not None and phase in self.meters_conf: @@ -1055,9 +1049,9 @@ class Trainer: loss_keys = set(self.loss_conf.keys()) - set(["all"]) if "default" not in loss_keys: for k in val_keys: - assert ( - k in loss_keys - ), f"Error: key {k} is not defined in the losses, and no default is set" + assert k in loss_keys, ( + f"Error: key {k} is not defined in the losses, and no default is set" + ) def _setup_components(self): # Get the keys for all the val datasets, if any diff --git a/sam3/train/transforms/basic.py b/sam3/train/transforms/basic.py index cdcce33..7211389 100644 --- a/sam3/train/transforms/basic.py +++ b/sam3/train/transforms/basic.py @@ -14,7 +14,6 @@ import PIL import torch import torchvision.transforms as T import torchvision.transforms.functional as F - from sam3.model.box_ops import box_xyxy_to_cxcywh from sam3.model.data_misc import interpolate @@ -277,9 +276,9 @@ class RandomSizeCrop: max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1)) ) result_img, result_target = crop(img, target, [j, i, h, w]) - assert ( - len(result_target["boxes"]) == init_boxes - ), f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}" + assert len(result_target["boxes"]) == init_boxes, ( + f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}" + ) return result_img, result_target else: diff --git a/sam3/train/transforms/basic_for_api.py b/sam3/train/transforms/basic_for_api.py index c7cc494..27a3d4b 100644 --- a/sam3/train/transforms/basic_for_api.py +++ b/sam3/train/transforms/basic_for_api.py @@ -7,7 +7,6 @@ Transforms and data augmentation for both image + bbox. """ import logging - import numbers import random from collections.abc import Sequence @@ -17,9 +16,7 @@ import torch import torchvision.transforms as T import torchvision.transforms.functional as F import torchvision.transforms.v2.functional as Fv2 - from PIL import Image as PILImage - from sam3.model.box_ops import box_xyxy_to_cxcywh, masks_to_boxes from sam3.train.data.sam3_image_dataset import Datapoint from torchvision.transforms import InterpolationMode diff --git a/sam3/train/transforms/filter_query_transforms.py b/sam3/train/transforms/filter_query_transforms.py index 2075838..3ebd64e 100644 --- a/sam3/train/transforms/filter_query_transforms.py +++ b/sam3/train/transforms/filter_query_transforms.py @@ -4,12 +4,10 @@ import logging import random - from collections import defaultdict from typing import List, Optional, Union import torch - from sam3.train.data.sam3_image_dataset import Datapoint, FindQuery, Object @@ -381,9 +379,9 @@ class FlexibleFilterFindGetQueries: if len(new_find_queries) == 0: start_with_zero_check = True - assert ( - start_with_zero_check - ), "Invalid Find queries, they need to start at query_processing_order = 0" + assert start_with_zero_check, ( + "Invalid Find queries, they need to start at query_processing_order = 0" + ) datapoint.find_queries = new_find_queries diff --git a/sam3/train/transforms/point_sampling.py b/sam3/train/transforms/point_sampling.py index f1f8cad..521e2f5 100644 --- a/sam3/train/transforms/point_sampling.py +++ b/sam3/train/transforms/point_sampling.py @@ -7,7 +7,6 @@ import numpy as np import torch from PIL import Image as PILImage from pycocotools import mask as mask_util - from sam3.train.data.sam3_image_dataset import Datapoint from torchvision.ops import masks_to_boxes @@ -250,9 +249,9 @@ class RandomGeometricInputsAPI: def _get_target_object(self, datapoint, query): img = datapoint.images[query.image_id] targets = query.object_ids_output - assert ( - len(targets) == 1 - ), "Geometric queries only support a single target object." + assert len(targets) == 1, ( + "Geometric queries only support a single target object." + ) target_idx = targets[0] return img.objects[target_idx] diff --git a/sam3/train/transforms/segmentation.py b/sam3/train/transforms/segmentation.py index 466a109..4fd2316 100644 --- a/sam3/train/transforms/segmentation.py +++ b/sam3/train/transforms/segmentation.py @@ -5,12 +5,9 @@ import numpy as np import pycocotools.mask as mask_utils import torch - import torchvision.transforms.functional as F from PIL import Image as PILImage - from sam3.model.box_ops import masks_to_boxes - from sam3.train.data.sam3_image_dataset import Datapoint diff --git a/sam3/train/utils/checkpoint_utils.py b/sam3/train/utils/checkpoint_utils.py index 465e006..32ca776 100644 --- a/sam3/train/utils/checkpoint_utils.py +++ b/sam3/train/utils/checkpoint_utils.py @@ -36,9 +36,9 @@ def unix_pattern_to_parameter_names( parameter_names = [] for param_name in constraints: matching_parameters = set(fnmatch.filter(all_parameter_names, param_name)) - assert ( - len(matching_parameters) > 0 - ), f"param_names {param_name} don't match any param in the given names." + assert len(matching_parameters) > 0, ( + f"param_names {param_name} don't match any param in the given names." + ) parameter_names.append(matching_parameters) return set.union(*parameter_names) diff --git a/sam3/train/utils/logger.py b/sam3/train/utils/logger.py index 0c835a4..4d6c071 100644 --- a/sam3/train/utils/logger.py +++ b/sam3/train/utils/logger.py @@ -10,10 +10,8 @@ import uuid from typing import Any, Dict, Optional, Union from hydra.utils import instantiate - from iopath.common.file_io import g_pathmgr from numpy import ndarray - from sam3.train.utils.train_utils import get_machine_local_and_dist_rank, makedir from torch import Tensor from torch.utils.tensorboard import SummaryWriter diff --git a/sam3/train/utils/train_utils.py b/sam3/train/utils/train_utils.py index 10b6929..ca259a3 100644 --- a/sam3/train/utils/train_utils.py +++ b/sam3/train/utils/train_utils.py @@ -11,7 +11,6 @@ from datetime import timedelta from typing import Optional import hydra - import numpy as np import omegaconf import torch @@ -83,9 +82,9 @@ def get_machine_local_and_dist_rank(): """ local_rank = int(os.environ.get("LOCAL_RANK", None)) distributed_rank = int(os.environ.get("RANK", None)) - assert ( - local_rank is not None and distributed_rank is not None - ), "Please the set the RANK and LOCAL_RANK environment variables." + assert local_rank is not None and distributed_rank is not None, ( + "Please the set the RANK and LOCAL_RANK environment variables." + ) return local_rank, distributed_rank diff --git a/sam3/visualization_utils.py b/sam3/visualization_utils.py index 07cd0f7..c007c38 100644 --- a/sam3/visualization_utils.py +++ b/sam3/visualization_utils.py @@ -158,22 +158,22 @@ def plot_mask(mask, color="r", ax=None): def normalize_bbox(bbox_xywh, img_w, img_h): # Assumes bbox_xywh is in XYWH format if isinstance(bbox_xywh, list): - assert ( - len(bbox_xywh) == 4 - ), "bbox_xywh list must have 4 elements. Batching not support except for torch tensors." + assert len(bbox_xywh) == 4, ( + "bbox_xywh list must have 4 elements. Batching not support except for torch tensors." + ) normalized_bbox = bbox_xywh.copy() normalized_bbox[0] /= img_w normalized_bbox[1] /= img_h normalized_bbox[2] /= img_w normalized_bbox[3] /= img_h else: - assert isinstance( - bbox_xywh, torch.Tensor - ), "Only torch tensors are supported for batching." + assert isinstance(bbox_xywh, torch.Tensor), ( + "Only torch tensors are supported for batching." + ) normalized_bbox = bbox_xywh.clone() - assert ( - normalized_bbox.size(-1) == 4 - ), "bbox_xywh tensor must have last dimension of size 4." + assert normalized_bbox.size(-1) == 4, ( + "bbox_xywh tensor must have last dimension of size 4." + ) normalized_bbox[..., 0] /= img_w normalized_bbox[..., 1] /= img_h normalized_bbox[..., 2] /= img_w @@ -244,10 +244,10 @@ def visualize_formatted_frame_output( num_outputs = len(outputs_list) if titles is None: - titles = [f"Set {i+1}" for i in range(num_outputs)] - assert ( - len(titles) == num_outputs - ), "length of `titles` should match that of `outputs_list` if not None." + titles = [f"Set {i + 1}" for i in range(num_outputs)] + assert len(titles) == num_outputs, ( + "length of `titles` should match that of `outputs_list` if not None." + ) _, axes = plt.subplots(1, num_outputs, figsize=figsize) if num_outputs == 1: @@ -703,9 +703,9 @@ def get_all_annotations_for_frame( # Get the frame video_df_current = video_df[video_df.id == video_id] - assert ( - len(video_df_current) == 1 - ), f"Expected 1 video row, got {len(video_df_current)}" + assert len(video_df_current) == 1, ( + f"Expected 1 video row, got {len(video_df_current)}" + ) video_row = video_df_current.iloc[0] file_name = video_row.file_names[frame_idx] file_path = os.path.join( @@ -796,7 +796,7 @@ def visualize_prompt_overlay( ax.text( x_img + 5, y_img - 5, - f"P{i+1}", + f"P{i + 1}", color=color, fontsize=10, weight="bold", @@ -828,7 +828,7 @@ def visualize_prompt_overlay( ax.text( x_img, y_img - 5, - f"B{i+1}", + f"B{i + 1}", color=color, fontsize=10, weight="bold", diff --git a/scripts/eval/silver/download_videos.py b/scripts/eval/silver/download_videos.py index b8fda76..d0a1ba8 100644 --- a/scripts/eval/silver/download_videos.py +++ b/scripts/eval/silver/download_videos.py @@ -11,7 +11,6 @@ from concurrent.futures import as_completed, ThreadPoolExecutor from pathlib import Path import yt_dlp - from utils import ( annotation_files, config, @@ -244,9 +243,9 @@ def download_sav(): def main(): assert len(sys.argv) > 1, "You have to provide the name of the dataset" dataset_name = sys.argv[1] - assert ( - dataset_name in annotation_files - ), f"The dataset can be one of {list(annotation_files.keys())}" + assert dataset_name in annotation_files, ( + f"The dataset can be one of {list(annotation_files.keys())}" + ) if dataset_name == "yt1b": download_youtube() diff --git a/scripts/eval/silver/extract_frames.py b/scripts/eval/silver/extract_frames.py index 6be7f05..5c4285d 100644 --- a/scripts/eval/silver/extract_frames.py +++ b/scripts/eval/silver/extract_frames.py @@ -68,9 +68,9 @@ def process_image(args): def main(): assert len(sys.argv) > 1, "You have to provide the name of the dataset" dataset_name = sys.argv[1] - assert ( - dataset_name in annotation_files - ), f"The dataset can be one of {list(annotation_files.keys())}" + assert dataset_name in annotation_files, ( + f"The dataset can be one of {list(annotation_files.keys())}" + ) all_outputs = [] for file in annotation_files[dataset_name]: with open(os.path.join(config["path_annotations"], file), "r") as f: diff --git a/scripts/eval/silver/preprocess_silver_geode_bdd100k_food_rec.py b/scripts/eval/silver/preprocess_silver_geode_bdd100k_food_rec.py index 6b0c65f..2173d1a 100644 --- a/scripts/eval/silver/preprocess_silver_geode_bdd100k_food_rec.py +++ b/scripts/eval/silver/preprocess_silver_geode_bdd100k_food_rec.py @@ -51,7 +51,7 @@ def main(args, n_workers=20): paths = [ ( raw_folder_food_images - / f'{Path(each).stem.split("_")[-1]}{Path(each).suffix}', + / f"{Path(each).stem.split('_')[-1]}{Path(each).suffix}", processed_folder / each, ) for each in img_filenames diff --git a/scripts/eval/veval/saco_yt1b_downloader.py b/scripts/eval/veval/saco_yt1b_downloader.py index ee83102..658b5ce 100644 --- a/scripts/eval/veval/saco_yt1b_downloader.py +++ b/scripts/eval/veval/saco_yt1b_downloader.py @@ -3,7 +3,6 @@ # pyre-unsafe import argparse import logging - import multiprocessing as mp import os from functools import partial diff --git a/scripts/eval/veval/saco_yt1b_frame_prep_util.py b/scripts/eval/veval/saco_yt1b_frame_prep_util.py index 49c4f79..d4f14ac 100644 --- a/scripts/eval/veval/saco_yt1b_frame_prep_util.py +++ b/scripts/eval/veval/saco_yt1b_frame_prep_util.py @@ -58,9 +58,9 @@ class YtVideoPrep: df = self.yt1b_start_end_time_df[ self.yt1b_start_end_time_df.saco_yt1b_id == self.saco_yt1b_id ] - assert ( - len(df) == 1 - ), f"Expected exactly 1 row for saco_yt1b_id: {self.saco_yt1b_id}, found {len(df)}" + assert len(df) == 1, ( + f"Expected exactly 1 row for saco_yt1b_id: {self.saco_yt1b_id}, found {len(df)}" + ) id_and_frame_map_row = df.iloc[0] yt_video_id = ( @@ -82,9 +82,9 @@ class YtVideoPrep: def download_youtube_video(self): video_url = f"https://youtube.com/watch?v={self.yt_video_id}" - assert os.path.exists( - self.cookies_file - ), f"Cookies file '{self.cookies_file}' not found. Must have it to download videos." + assert os.path.exists(self.cookies_file), ( + f"Cookies file '{self.cookies_file}' not found. Must have it to download videos." + ) outtmpl = self.raw_video_path