apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)

Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: itamaro

Differential Revision: D90476315

fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
This commit is contained in:
Bowie Chen
2026-01-11 23:16:49 -08:00
committed by meta-codesync[bot]
parent 7b89b8fc3f
commit 11dec2936d
69 changed files with 445 additions and 522 deletions

View File

@@ -296,9 +296,9 @@ def agent_inference(
assert LATEST_SAM3_TEXT_PROMPT != "" assert LATEST_SAM3_TEXT_PROMPT != ""
# Make sure that the last message is a image # Make sure that the last message is a image
assert ( assert messages[-1]["content"][1]["type"] == "image", (
messages[-1]["content"][1]["type"] == "image" "Second content element should be an image"
), "Second content element should be an image" )
messages.pop() # Remove the last user message messages.pop() # Remove the last user message
# Add simplified replacement message # Add simplified replacement message
simplified_message = { simplified_message = {
@@ -318,7 +318,7 @@ def agent_inference(
# MLLM check the mask one by one # MLLM check the mask one by one
for i in range(num_masks): for i in range(num_masks):
print(f"🔍 Checking mask {i+1}/{num_masks}...") print(f"🔍 Checking mask {i + 1}/{num_masks}...")
image_w_mask_i, image_w_zoomed_in_mask_i = visualize(current_outputs, i) image_w_mask_i, image_w_zoomed_in_mask_i = visualize(current_outputs, i)
image_w_zoomed_in_mask_i_path = os.path.join( image_w_zoomed_in_mask_i_path = os.path.join(
@@ -363,7 +363,7 @@ def agent_inference(
raise ValueError( raise ValueError(
"Generated text is None, which is unexpected. Please check the Qwen server and the input parameters." "Generated text is None, which is unexpected. Please check the Qwen server and the input parameters."
) )
print(f"Generated text for mask {i+1}: {checking_generated_text}") print(f"Generated text for mask {i + 1}: {checking_generated_text}")
verdict = ( verdict = (
checking_generated_text.split("<verdict>")[-1] checking_generated_text.split("<verdict>")[-1]
.split("</verdict>")[0] .split("</verdict>")[0]
@@ -371,11 +371,11 @@ def agent_inference(
) )
if "Accept" in verdict: if "Accept" in verdict:
assert not "Reject" in verdict assert not "Reject" in verdict
print(f"Mask {i+1} accepted, keeping it in the outputs.") print(f"Mask {i + 1} accepted, keeping it in the outputs.")
masks_to_keep.append(i) masks_to_keep.append(i)
elif "Reject" in verdict: elif "Reject" in verdict:
assert not "Accept" in verdict assert not "Accept" in verdict
print(f"Mask {i+1} rejected, removing it from the outputs.") print(f"Mask {i + 1} rejected, removing it from the outputs.")
else: else:
raise ValueError( raise ValueError(
f"Unexpected verdict in generated text: {checking_generated_text}. Expected 'Accept' or 'Reject'." f"Unexpected verdict in generated text: {checking_generated_text}. Expected 'Accept' or 'Reject'."
@@ -397,7 +397,7 @@ def agent_inference(
sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png" sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png"
).replace( ).replace(
".png", ".png",
f"_selected_masks_{'-'.join(map(str, [i+1 for i in masks_to_keep]))}.png".replace( f"_selected_masks_{'-'.join(map(str, [i + 1 for i in masks_to_keep]))}.png".replace(
"/", "_" "/", "_"
), ),
) )

View File

@@ -7,7 +7,6 @@ import os
import torch import torch
from PIL import Image from PIL import Image
from sam3.model.box_ops import box_xyxy_to_xywh from sam3.model.box_ops import box_xyxy_to_xywh
from sam3.train.masks_ops import rle_encode from sam3.train.masks_ops import rle_encode

View File

@@ -84,9 +84,9 @@ class BoxMode(IntEnum):
], "Relative mode not yet supported!" ], "Relative mode not yet supported!"
if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS: if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
assert ( assert arr.shape[-1] == 5, (
arr.shape[-1] == 5 "The last dimension of input shape must be 5 for XYWHA format"
), "The last dimension of input shape must be 5 for XYWHA format" )
original_dtype = arr.dtype original_dtype = arr.dtype
arr = arr.double() arr = arr.double()
@@ -244,9 +244,9 @@ class Boxes:
if isinstance(item, int): if isinstance(item, int):
return Boxes(self.tensor[item].view(1, -1)) return Boxes(self.tensor[item].view(1, -1))
b = self.tensor[item] b = self.tensor[item]
assert ( assert b.dim() == 2, (
b.dim() == 2 "Indexing on Boxes with {} failed to return a matrix!".format(item)
), "Indexing on Boxes with {} failed to return a matrix!".format(item) )
return Boxes(b) return Boxes(b)
def __len__(self) -> int: def __len__(self) -> int:
@@ -425,7 +425,7 @@ def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
Tensor: iou, sized [N]. Tensor: iou, sized [N].
""" """
assert len(boxes1) == len(boxes2), ( assert len(boxes1) == len(boxes2), (
"boxlists should have the same" "number of entries, got {}, {}".format( "boxlists should have the samenumber of entries, got {}, {}".format(
len(boxes1), len(boxes2) len(boxes1), len(boxes2)
) )
) )

View File

@@ -13,7 +13,6 @@ from torch import device
from .boxes import Boxes from .boxes import Boxes
from .memory import retry_if_cuda_oom from .memory import retry_if_cuda_oom
from .roi_align import ROIAlign from .roi_align import ROIAlign
@@ -142,10 +141,10 @@ class BitMasks:
if isinstance(item, int): if isinstance(item, int):
return BitMasks(self.tensor[item].unsqueeze(0)) return BitMasks(self.tensor[item].unsqueeze(0))
m = self.tensor[item] m = self.tensor[item]
assert ( assert m.dim() == 3, (
m.dim() == 3 "Indexing on BitMasks with {} returns a tensor with shape {}!".format(
), "Indexing on BitMasks with {} returns a tensor with shape {}!".format( item, m.shape
item, m.shape )
) )
return BitMasks(m) return BitMasks(m)

View File

@@ -363,9 +363,9 @@ class RotatedBoxes(Boxes):
if isinstance(item, int): if isinstance(item, int):
return RotatedBoxes(self.tensor[item].view(1, -1)) return RotatedBoxes(self.tensor[item].view(1, -1))
b = self.tensor[item] b = self.tensor[item]
assert ( assert b.dim() == 2, (
b.dim() == 2 "Indexing on RotatedBoxes with {} failed to return a matrix!".format(item)
), "Indexing on RotatedBoxes with {} failed to return a matrix!".format(item) )
return RotatedBoxes(b) return RotatedBoxes(b)
def __len__(self) -> int: def __len__(self) -> int:

View File

@@ -20,7 +20,6 @@ from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image from PIL import Image
from .boxes import Boxes, BoxMode from .boxes import Boxes, BoxMode
from .color_map import random_color from .color_map import random_color
from .keypoints import Keypoints from .keypoints import Keypoints
from .masks import BitMasks, PolygonMasks from .masks import BitMasks, PolygonMasks
@@ -222,9 +221,9 @@ class _PanopticPrediction:
empty_ids.append(id) empty_ids.append(id)
if len(empty_ids) == 0: if len(empty_ids) == 0:
return np.zeros(self._seg.shape, dtype=np.uint8) return np.zeros(self._seg.shape, dtype=np.uint8)
assert ( assert len(empty_ids) == 1, (
len(empty_ids) == 1 ">1 ids corresponds to no labels. This is currently not supported"
), ">1 ids corresponds to no labels. This is currently not supported" )
return (self._seg != empty_ids[0]).numpy().astype(np.bool) return (self._seg != empty_ids[0]).numpy().astype(np.bool)
def semantic_masks(self): def semantic_masks(self):

View File

@@ -41,7 +41,7 @@ def run_single_image_inference(
print(f"Output JSON {output_json_path} already exists. Skipping.") print(f"Output JSON {output_json_path} already exists. Skipping.")
return return
print(f"{'-'*30} Starting SAM 3 Agent Session... {'-'*30} ") print(f"{'-' * 30} Starting SAM 3 Agent Session... {'-' * 30} ")
agent_history, final_output_dict, rendered_final_output = agent_inference( agent_history, final_output_dict, rendered_final_output = agent_inference(
image_path, image_path,
text_prompt, text_prompt,
@@ -50,7 +50,7 @@ def run_single_image_inference(
output_dir=output_dir, output_dir=output_dir,
debug=debug, debug=debug,
) )
print(f"{'-'*30} End of SAM 3 Agent Session... {'-'*30} ") print(f"{'-' * 30} End of SAM 3 Agent Session... {'-' * 30} ")
final_output_dict["text_prompt"] = text_prompt final_output_dict["text_prompt"] = text_prompt
final_output_dict["image_path"] = image_path final_output_dict["image_path"] = image_path

View File

@@ -73,7 +73,9 @@ def visualize(
idx = int(zoom_in_index) idx = int(zoom_in_index)
num_masks = len(input_json.get("pred_masks", [])) num_masks = len(input_json.get("pred_masks", []))
if idx < 0 or idx >= num_masks: if idx < 0 or idx >= num_masks:
raise ValueError(f"zoom_in_index {idx} is out of range (0..{num_masks-1}).") raise ValueError(
f"zoom_in_index {idx} is out of range (0..{num_masks - 1})."
)
# (1) Replicate zoom_in_and_visualize # (1) Replicate zoom_in_and_visualize
object_data = { object_data = {

View File

@@ -126,9 +126,9 @@ class COCOCustom(COCO):
# MODIFICATION: faster and cached subset check # MODIFICATION: faster and cached subset check
if not hasattr(self, "img_id_set"): if not hasattr(self, "img_id_set"):
self.img_id_set = set(self.getImgIds()) self.img_id_set = set(self.getImgIds())
assert set(annsImgIds).issubset( assert set(annsImgIds).issubset(self.img_id_set), (
self.img_id_set "Results do not correspond to current coco set"
), "Results do not correspond to current coco set" )
# END MODIFICATION # END MODIFICATION
if "caption" in anns[0]: if "caption" in anns[0]:
imgIds = set([img["id"] for img in res.dataset["images"]]) & set( imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
@@ -301,9 +301,9 @@ class CGF1Eval(COCOeval):
TP = (match_scores >= thresh).sum() TP = (match_scores >= thresh).sum()
FP = len(dt) - TP FP = len(dt) - TP
FN = len(gt) - TP FN = len(gt) - TP
assert ( assert FP >= 0 and FN >= 0, (
FP >= 0 and FN >= 0 f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}" )
TPs.append(TP) TPs.append(TP)
FPs.append(FP) FPs.append(FP)
FNs.append(FN) FNs.append(FN)
@@ -599,9 +599,9 @@ class CGF1Evaluator:
""" """
assert len(self.coco_gts) > 0, "No ground truth provided for evaluation." assert len(self.coco_gts) > 0, "No ground truth provided for evaluation."
assert len(self.coco_gts) == len( assert len(self.coco_gts) == len(self.coco_evals), (
self.coco_evals "Mismatch in number of ground truths and evaluators."
), "Mismatch in number of ground truths and evaluators." )
if self.verbose: if self.verbose:
print(f"Loading predictions from {pred_file}") print(f"Loading predictions from {pred_file}")
@@ -668,17 +668,17 @@ class CGF1Evaluator:
if len(scorings) == 1: if len(scorings) == 1:
return scorings[0] return scorings[0]
assert ( assert scorings[0].ndim == 3, (
scorings[0].ndim == 3 f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}" )
assert ( assert scorings[0].shape[0] == 1, (
scorings[0].shape[0] == 1 f"Expecting a single category, got {scorings[0].shape[0]}"
), f"Expecting a single category, got {scorings[0].shape[0]}" )
for scoring in scorings: for scoring in scorings:
assert ( assert scoring.shape == scorings[0].shape, (
scoring.shape == scorings[0].shape f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}" )
selected_imgs = [] selected_imgs = []
for img_id in range(scorings[0].shape[-1]): for img_id in range(scorings[0].shape[-1]):

View File

@@ -18,19 +18,15 @@ import os
import pickle import pickle
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Any, List, Optional from typing import Any, List, Optional
import numpy as np import numpy as np
import pycocotools.mask as mask_utils import pycocotools.mask as mask_utils
import torch import torch
from iopath.common.file_io import g_pathmgr from iopath.common.file_io import g_pathmgr
from pycocotools.coco import COCO from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval from pycocotools.cocoeval import COCOeval
from sam3.train.masks_ops import rle_encode from sam3.train.masks_ops import rle_encode
from sam3.train.utils.distributed import ( from sam3.train.utils.distributed import (
all_gather, all_gather,
gather_to_rank_0_via_filesys, gather_to_rank_0_via_filesys,
@@ -755,9 +751,9 @@ def loadRes(self, resFile):
anns = resFile anns = resFile
assert type(anns) == list, "results in not an array of objects" assert type(anns) == list, "results in not an array of objects"
annsImgIds = [ann["image_id"] for ann in anns] annsImgIds = [ann["image_id"] for ann in anns]
assert set(annsImgIds) == ( assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), (
set(annsImgIds) & set(self.getImgIds()) "Results do not correspond to current coco set"
), "Results do not correspond to current coco set" )
if "caption" in anns[0]: if "caption" in anns[0]:
imgIds = set([img["id"] for img in res.dataset["images"]]) & set( imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
[ann["image_id"] for ann in anns] [ann["image_id"] for ann in anns]

View File

@@ -83,9 +83,9 @@ class PredictionDumper:
self.merge_predictions = merge_predictions self.merge_predictions = merge_predictions
self.pred_file_evaluators = pred_file_evaluators self.pred_file_evaluators = pred_file_evaluators
if self.pred_file_evaluators is not None: if self.pred_file_evaluators is not None:
assert ( assert merge_predictions, (
merge_predictions "merge_predictions must be True if pred_file_evaluators are provided"
), "merge_predictions must be True if pred_file_evaluators are provided" )
assert self.dump_dir is not None, "dump_dir must be provided" assert self.dump_dir is not None, "dump_dir must be provided"
if is_main_process(): if is_main_process():

View File

@@ -13,11 +13,9 @@ from typing import Optional
import numpy as np import numpy as np
import pycocotools.mask as maskUtils import pycocotools.mask as maskUtils
from pycocotools.cocoeval import COCOeval from pycocotools.cocoeval import COCOeval
from sam3.eval.coco_eval import CocoEvaluator from sam3.eval.coco_eval import CocoEvaluator
from sam3.train.masks_ops import compute_F_measure from sam3.train.masks_ops import compute_F_measure
from sam3.train.utils.distributed import is_main_process from sam3.train.utils.distributed import is_main_process
from scipy.optimize import linear_sum_assignment from scipy.optimize import linear_sum_assignment
@@ -156,9 +154,9 @@ class DemoEval(COCOeval):
TP = (match_scores >= thresh).sum() TP = (match_scores >= thresh).sum()
FP = len(dt) - TP FP = len(dt) - TP
FN = len(gt) - TP FN = len(gt) - TP
assert ( assert FP >= 0 and FN >= 0, (
FP >= 0 and FN >= 0 f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}" )
TPs.append(TP) TPs.append(TP)
FPs.append(FP) FPs.append(FP)
FNs.append(FN) FNs.append(FN)
@@ -528,17 +526,17 @@ class DemoEvaluator(CocoEvaluator):
if len(scorings) == 1: if len(scorings) == 1:
return scorings[0] return scorings[0]
assert ( assert scorings[0].ndim == 3, (
scorings[0].ndim == 3 f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}" )
assert ( assert scorings[0].shape[0] == 1, (
scorings[0].shape[0] == 1 f"Expecting a single category, got {scorings[0].shape[0]}"
), f"Expecting a single category, got {scorings[0].shape[0]}" )
for scoring in scorings: for scoring in scorings:
assert ( assert scoring.shape == scorings[0].shape, (
scoring.shape == scorings[0].shape f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}" )
selected_imgs = [] selected_imgs = []
for img_id in range(scorings[0].shape[-1]): for img_id in range(scorings[0].shape[-1]):

View File

@@ -255,9 +255,10 @@ class Evaluator:
if show_progressbar and TQDM_IMPORTED: if show_progressbar and TQDM_IMPORTED:
seq_list_sorted = sorted(seq_list) seq_list_sorted = sorted(seq_list)
with Pool(config["NUM_PARALLEL_CORES"]) as pool, tqdm.tqdm( with (
total=len(seq_list) Pool(config["NUM_PARALLEL_CORES"]) as pool,
) as pbar: tqdm.tqdm(total=len(seq_list)) as pbar,
):
_eval_sequence = partial( _eval_sequence = partial(
eval_sequence, eval_sequence,
dataset=dataset, dataset=dataset,

View File

@@ -83,9 +83,9 @@ class PostProcessImage(nn.Module):
ret_tensordict: Experimental argument. If true, return a tensordict.TensorDict instead of a list of dictionaries for easier manipulation. ret_tensordict: Experimental argument. If true, return a tensordict.TensorDict instead of a list of dictionaries for easier manipulation.
""" """
if ret_tensordict: if ret_tensordict:
assert ( assert consistent is True, (
consistent is True "We don't support returning TensorDict if the outputs have different shapes"
), "We don't support returning TensorDict if the outputs have different shapes" # NOTE: It's possible but we don't support it. ) # NOTE: It's possible but we don't support it.
assert self.detection_threshold <= 0.0, "TODO: implement?" assert self.detection_threshold <= 0.0, "TODO: implement?"
try: try:
from tensordict import TensorDict from tensordict import TensorDict
@@ -118,7 +118,9 @@ class PostProcessImage(nn.Module):
if boxes is None: if boxes is None:
assert out_masks is not None assert out_masks is not None
assert not ret_tensordict, "We don't support returning TensorDict if the output does not contain boxes" assert not ret_tensordict, (
"We don't support returning TensorDict if the output does not contain boxes"
)
B = len(out_masks) B = len(out_masks)
boxes = [None] * B boxes = [None] * B
scores = [None] * B scores = [None] * B
@@ -418,9 +420,9 @@ class PostProcessAPIVideo(PostProcessImage):
if video_id == -1: if video_id == -1:
video_id = unique_vid_id.item() video_id = unique_vid_id.item()
else: else:
assert ( assert video_id == unique_vid_id.item(), (
video_id == unique_vid_id.item() "We can only postprocess one video per datapoint"
), "We can only postprocess one video per datapoint" )
# keeping track of which objects appear in the current frame # keeping track of which objects appear in the current frame
obj_ids_per_frame = frame_outs["pred_object_ids"] obj_ids_per_frame = frame_outs["pred_object_ids"]
assert obj_ids_per_frame.size(-1) == frame_outs["pred_logits"].size(-2) assert obj_ids_per_frame.size(-1) == frame_outs["pred_logits"].size(-2)

View File

@@ -95,9 +95,9 @@ class YTVIS(COCO):
anns = resFile anns = resFile
assert type(anns) == list, "results is not an array of objects" assert type(anns) == list, "results is not an array of objects"
annsImgIds = [ann["image_id"] for ann in anns] annsImgIds = [ann["image_id"] for ann in anns]
assert set(annsImgIds) == ( assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), (
set(annsImgIds) & set(self.getImgIds()) "Results do not correspond to current coco set"
), "Results do not correspond to current coco set" )
if "bboxes" in anns[0] and not anns[0]["bboxes"] == []: if "bboxes" in anns[0] and not anns[0]["bboxes"] == []:
res.dataset["categories"] = copy.deepcopy(self.dataset["categories"]) res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
for id, ann in enumerate(anns): for id, ann in enumerate(anns):

View File

@@ -109,9 +109,7 @@ class YTVISevalMixin:
) # Num preds x Num GTS x Num frames ) # Num preds x Num GTS x Num frames
inter = inter.sum(-1) inter = inter.sum(-1)
union = union.sum(-1) union = union.sum(-1)
assert ( assert (union > 0).all(), (
union > 0
).all(), (
"There exists a tracklet with zero GTs across time. This is suspicious" "There exists a tracklet with zero GTs across time. This is suspicious"
) )
return inter / union return inter / union
@@ -136,9 +134,9 @@ class YTVISevalMixin:
iou = inter / union iou = inter / union
assert iou >= 0 and iou <= 1, "Encountered an error in IoU computation" assert iou >= 0 and iou <= 1, "Encountered an error in IoU computation"
else: else:
assert np.isclose(inter, 0) and np.isclose( assert np.isclose(inter, 0) and np.isclose(union, 0), (
union, 0 "Encountered an error in IoU computation"
), "Encountered an error in IoU computation" )
iou = 1 iou = 1
return iou return iou
@@ -206,16 +204,16 @@ class YTVISResultsWriter:
if len(prediction) == 0: if len(prediction) == 0:
continue continue
for k in ["boxes", "scores", "labels"]: for k in ["boxes", "scores", "labels"]:
assert ( assert k in prediction, (
k in prediction f"Expected predictions to have `{k}` key, available keys are {prediction.keys()}"
), f"Expected predictions to have `{k}` key, available keys are {prediction.keys()}" )
if self.save_per_frame_scores: if self.save_per_frame_scores:
assert ( assert "per_frame_scores" in prediction, (
"per_frame_scores" in prediction f"Expected predictions to have `per_frame_scores` key, available keys are {prediction.keys()}"
), f"Expected predictions to have `per_frame_scores` key, available keys are {prediction.keys()}" )
assert xor( assert xor("masks" in prediction, "masks_rle" in prediction), (
"masks" in prediction, "masks_rle" in prediction f"Expected predictions to have either `masks` key or `masks_rle` key, available keys are {prediction.keys()}"
), f"Expected predictions to have either `masks` key or `masks_rle` key, available keys are {prediction.keys()}" )
boxes = prediction["boxes"] boxes = prediction["boxes"]
boxes = convert_to_xywh(boxes).tolist() boxes = convert_to_xywh(boxes).tolist()
@@ -223,9 +221,9 @@ class YTVISResultsWriter:
labels = prediction["labels"].tolist() labels = prediction["labels"].tolist()
if "masks" in prediction: if "masks" in prediction:
masks = prediction["masks"].squeeze(2) masks = prediction["masks"].squeeze(2)
assert ( assert masks.ndim == 4, (
masks.ndim == 4 "Expected masks to be of shape(N_preds,T_frames,H,W)"
), "Expected masks to be of shape(N_preds,T_frames,H,W)" )
areas = [mask.flatten(1).sum(1).tolist() for mask in masks] areas = [mask.flatten(1).sum(1).tolist() for mask in masks]
rles = [rle_encode(masklet) for masklet in masks] rles = [rle_encode(masklet) for masklet in masks]

View File

@@ -42,9 +42,9 @@ def get_logger(name, level=logging.INFO):
"""A command line logger.""" """A command line logger."""
if "LOG_LEVEL" in os.environ: if "LOG_LEVEL" in os.environ:
level = os.environ["LOG_LEVEL"].upper() level = os.environ["LOG_LEVEL"].upper()
assert ( assert level in LOG_LEVELS, (
level in LOG_LEVELS f"Invalid LOG_LEVEL: {level}, must be one of {list(LOG_LEVELS.keys())}"
), f"Invalid LOG_LEVEL: {level}, must be one of {list(LOG_LEVELS.keys())}" )
level = LOG_LEVELS[level] level = LOG_LEVELS[level]
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(level) logger.setLevel(level)

View File

@@ -7,7 +7,6 @@ Misc functions, including distributed helpers.
import collections import collections
import re import re
from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union
@@ -29,9 +28,9 @@ def interpolate(
input, size, scale_factor, mode, align_corners input, size, scale_factor, mode, align_corners
) )
assert ( assert input.shape[0] != 0 or input.shape[1] != 0, (
input.shape[0] != 0 or input.shape[1] != 0 "At least one of the two first dimensions must be non zero"
), "At least one of the two first dimensions must be non zero" )
if input.shape[1] == 0: if input.shape[1] == 0:
# Pytorch doesn't support null dimension on the channel dimension, so we transpose to fake a null batch dim # Pytorch doesn't support null dimension on the channel dimension, so we transpose to fake a null batch dim

View File

@@ -9,18 +9,13 @@ Inspired from Pytorch's version, adds the pre-norm variant
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import numpy as np import numpy as np
import torch import torch
from sam3.sam.transformer import RoPEAttention from sam3.sam.transformer import RoPEAttention
from torch import nn, Tensor from torch import nn, Tensor
from torchvision.ops.roi_align import RoIAlign from torchvision.ops.roi_align import RoIAlign
from .act_ckpt_utils import activation_ckpt_wrapper from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy from .box_ops import box_cxcywh_to_xyxy
from .model_misc import ( from .model_misc import (
gen_sineembed_for_position, gen_sineembed_for_position,
get_activation_fn, get_activation_fn,
@@ -444,9 +439,9 @@ class TransformerDecoder(nn.Module):
- valid_ratios/spatial_shapes: bs, nlevel, 2 - valid_ratios/spatial_shapes: bs, nlevel, 2
""" """
if memory_mask is not None: if memory_mask is not None:
assert ( assert self.boxRPB == "none", (
self.boxRPB == "none" "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
), "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented" )
apply_dac = apply_dac if apply_dac is not None else self.dac apply_dac = apply_dac if apply_dac is not None else self.dac
if apply_dac: if apply_dac:
@@ -516,18 +511,18 @@ class TransformerDecoder(nn.Module):
query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model
if self.boxRPB != "none" and reference_boxes is not None: if self.boxRPB != "none" and reference_boxes is not None:
assert ( assert spatial_shapes.shape[0] == 1, (
spatial_shapes.shape[0] == 1 "only single scale support implemented"
), "only single scale support implemented" )
memory_mask = self._get_rpb_matrix( memory_mask = self._get_rpb_matrix(
reference_boxes, reference_boxes,
(spatial_shapes[0, 0], spatial_shapes[0, 1]), (spatial_shapes[0, 0], spatial_shapes[0, 1]),
) )
memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W) memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W)
if self.training: if self.training:
assert ( assert self.use_act_checkpoint, (
self.use_act_checkpoint "Activation checkpointing not enabled in the decoder"
), "Activation checkpointing not enabled in the decoder" )
output, presence_out = activation_ckpt_wrapper(layer)( output, presence_out = activation_ckpt_wrapper(layer)(
tgt=output, tgt=output,
tgt_query_pos=query_pos, tgt_query_pos=query_pos,
@@ -676,9 +671,9 @@ class TransformerEncoderCrossAttention(nn.Module):
src_pos[0], src_pos[0],
) )
assert ( assert src.shape[1] == prompt.shape[1], (
src.shape[1] == prompt.shape[1] "Batch size must be the same for src and prompt"
), "Batch size must be the same for src and prompt" )
output = src output = src

View File

@@ -322,9 +322,9 @@ class TransformerEncoder(nn.Module):
return reference_points return reference_points
def _prepare_multilevel_features(self, srcs, masks, pos_embeds): def _prepare_multilevel_features(self, srcs, masks, pos_embeds):
assert ( assert len(srcs) == self.num_feature_levels, (
len(srcs) == self.num_feature_levels "mismatch between expected and received # of feature levels"
), "mismatch between expected and received # of feature levels" )
src_flatten = [] src_flatten = []
mask_flatten = [] mask_flatten = []
@@ -406,9 +406,9 @@ class TransformerEncoder(nn.Module):
- spatial_shapes: Spatial dimensions of each feature level - spatial_shapes: Spatial dimensions of each feature level
- valid_ratios: Valid ratios for each feature level - valid_ratios: Valid ratios for each feature level
""" """
assert ( assert len(src) == self.num_feature_levels, (
len(src) == self.num_feature_levels "must be equal to num_feature_levels"
), "must be equal to num_feature_levels" )
if src_key_padding_masks is not None: if src_key_padding_masks is not None:
assert len(src_key_padding_masks) == self.num_feature_levels assert len(src_key_padding_masks) == self.num_feature_levels
if pos is not None: if pos is not None:
@@ -538,9 +538,9 @@ class TransformerEncoderFusion(TransformerEncoder):
else None else None
) )
else: else:
assert all( assert all(x.dim == 4 for x in src), (
x.dim == 4 for x in src "expected list of (bs, c, h, w) tensors"
), "expected list of (bs, c, h, w) tensors" )
if self.add_pooled_text_to_img_feat: if self.add_pooled_text_to_img_feat:
# Fusion: Add mean pooled text to image features # Fusion: Add mean pooled text to image features

View File

@@ -11,7 +11,6 @@ from typing_extensions import override
from .act_ckpt_utils import activation_ckpt_wrapper from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy from .box_ops import box_cxcywh_to_xyxy
from .model_misc import get_clones from .model_misc import get_clones
@@ -148,54 +147,42 @@ class Prompt:
) )
# Dimension checks # Dimension checks
assert ( assert box_embeddings is not None and list(box_embeddings.shape[:2]) == [
box_embeddings is not None box_seq_len,
and list(box_embeddings.shape[:2]) bs,
== [ ], (
box_seq_len, f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
bs, )
] assert box_mask is not None and list(box_mask.shape) == [
), f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}" bs,
assert ( box_seq_len,
box_mask is not None ], (
and list(box_mask.shape) f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
== [ )
bs, assert point_embeddings is not None and list(point_embeddings.shape[:2]) == [
box_seq_len, point_seq_len,
] bs,
), f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}" ], (
assert ( f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
point_embeddings is not None )
and list(point_embeddings.shape[:2]) assert point_mask is not None and list(point_mask.shape) == [
== [ bs,
point_seq_len, point_seq_len,
bs, ], (
] f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
), f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}" )
assert ( assert box_labels is not None and list(box_labels.shape) == [
point_mask is not None box_seq_len,
and list(point_mask.shape) bs,
== [ ], (
bs, f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
point_seq_len, )
] assert point_labels is not None and list(point_labels.shape) == [
), f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}" point_seq_len,
assert ( bs,
box_labels is not None ], (
and list(box_labels.shape) f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
== [ )
box_seq_len,
bs,
]
), f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
assert (
point_labels is not None
and list(point_labels.shape)
== [
point_seq_len,
bs,
]
), f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
assert ( assert (
# Allowed to be None, we leave it to the encoder to check for validity before encoding. # Allowed to be None, we leave it to the encoder to check for validity before encoding.
mask_embeddings is None mask_embeddings is None
@@ -204,41 +191,41 @@ class Prompt:
mask_seq_len, mask_seq_len,
bs, bs,
] ]
), f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}" ), (
assert ( f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
mask_mask is None )
or list(mask_mask.shape) assert mask_mask is None or list(mask_mask.shape) == [
== [ bs,
bs, mask_seq_len,
mask_seq_len, ], (
] f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
), f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}" )
# Device checks # Device checks
assert ( assert box_embeddings is not None and box_embeddings.device == device, (
box_embeddings is not None and box_embeddings.device == device f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
), f"Expected box embeddings to be on device {device}, got {box_embeddings.device}" )
assert ( assert box_mask is not None and box_mask.device == device, (
box_mask is not None and box_mask.device == device f"Expected box mask to be on device {device}, got {box_mask.device}"
), f"Expected box mask to be on device {device}, got {box_mask.device}" )
assert ( assert box_labels is not None and box_labels.device == device, (
box_labels is not None and box_labels.device == device f"Expected box labels to be on device {device}, got {box_labels.device}"
), f"Expected box labels to be on device {device}, got {box_labels.device}" )
assert ( assert point_embeddings is not None and point_embeddings.device == device, (
point_embeddings is not None and point_embeddings.device == device f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
), f"Expected point embeddings to be on device {device}, got {point_embeddings.device}" )
assert ( assert point_mask is not None and point_mask.device == device, (
point_mask is not None and point_mask.device == device f"Expected point mask to be on device {device}, got {point_mask.device}"
), f"Expected point mask to be on device {device}, got {point_mask.device}" )
assert ( assert point_labels is not None and point_labels.device == device, (
point_labels is not None and point_labels.device == device f"Expected point labels to be on device {device}, got {point_labels.device}"
), f"Expected point labels to be on device {device}, got {point_labels.device}" )
assert ( assert mask_embeddings is None or mask_embeddings.device == device, (
mask_embeddings is None or mask_embeddings.device == device f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
), f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}" )
assert ( assert mask_mask is None or mask_mask.device == device, (
mask_mask is None or mask_mask.device == device f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
), f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}" )
self.box_embeddings = box_embeddings self.box_embeddings = box_embeddings
self.point_embeddings = point_embeddings self.point_embeddings = point_embeddings
@@ -264,30 +251,30 @@ class Prompt:
if point_embeddings is not None: if point_embeddings is not None:
point_seq_len = point_embeddings.shape[0] point_seq_len = point_embeddings.shape[0]
if bs is not None: if bs is not None:
assert ( assert bs == point_embeddings.shape[1], (
bs == point_embeddings.shape[1] f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
), f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}." )
else: else:
bs = point_embeddings.shape[1] bs = point_embeddings.shape[1]
if device is not None: if device is not None:
assert ( assert device == point_embeddings.device, (
device == point_embeddings.device "Device mismatch between box and point embeddings"
), "Device mismatch between box and point embeddings" )
else: else:
device = point_embeddings.device device = point_embeddings.device
if mask_embeddings is not None: if mask_embeddings is not None:
mask_seq_len = mask_embeddings.shape[0] mask_seq_len = mask_embeddings.shape[0]
if bs is not None: if bs is not None:
assert ( assert bs == mask_embeddings.shape[1], (
bs == mask_embeddings.shape[1] f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
), f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}" )
else: else:
bs = mask_embeddings.shape[1] bs = mask_embeddings.shape[1]
if device is not None: if device is not None:
assert ( assert device == mask_embeddings.device, (
device == mask_embeddings.device "Device mismatch between box/point and mask embeddings."
), "Device mismatch between box/point and mask embeddings." )
else: else:
device = mask_embeddings.device device = mask_embeddings.device
@@ -539,9 +526,9 @@ class SequenceGeometryEncoder(nn.Module):
if add_cls: if add_cls:
self.cls_embed = torch.nn.Embedding(1, self.d_model) self.cls_embed = torch.nn.Embedding(1, self.d_model)
assert ( assert points_direct_project or points_pos_enc or points_pool, (
points_direct_project or points_pos_enc or points_pool "Error: need at least one way to encode points"
), "Error: need at least one way to encode points" )
assert ( assert (
encode_boxes_as_points encode_boxes_as_points
or boxes_direct_project or boxes_direct_project
@@ -583,16 +570,16 @@ class SequenceGeometryEncoder(nn.Module):
self.encode = None self.encode = None
if num_layers > 0: if num_layers > 0:
assert ( assert add_cls, (
add_cls "It's currently highly recommended to add a CLS when using a transformer"
), "It's currently highly recommended to add a CLS when using a transformer" )
self.encode = get_clones(layer, num_layers) self.encode = get_clones(layer, num_layers)
self.encode_norm = nn.LayerNorm(self.d_model) self.encode_norm = nn.LayerNorm(self.d_model)
if mask_encoder is not None: if mask_encoder is not None:
assert isinstance( assert isinstance(mask_encoder, MaskEncoder), (
mask_encoder, MaskEncoder f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
), f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}." )
if add_mask_label: if add_mask_label:
self.mask_label_embed = torch.nn.Embedding(2, self.d_model) self.mask_label_embed = torch.nn.Embedding(2, self.d_model)
self.add_mask_label = add_mask_label self.add_mask_label = add_mask_label
@@ -701,16 +688,15 @@ class SequenceGeometryEncoder(nn.Module):
img_feats: torch.Tensor = None, img_feats: torch.Tensor = None,
): ):
n_masks, bs = masks.shape[:2] n_masks, bs = masks.shape[:2]
assert ( assert n_masks == 1, (
n_masks == 1 "We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
), "We assume one mask per prompt for now. Code should still be functional if this assertion is removed." )
assert ( assert list(attn_mask.shape) == [
list(attn_mask.shape) bs,
== [ n_masks,
bs, ], (
n_masks, f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
] )
), f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
masks, pos = self.mask_encoder( masks, pos = self.mask_encoder(
masks=masks.flatten(0, 1).float(), masks=masks.flatten(0, 1).float(),
pix_feat=img_feats, pix_feat=img_feats,

View File

@@ -13,9 +13,7 @@ import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
from PIL import Image from PIL import Image
from sam3.logger import get_logger from sam3.logger import get_logger
from tqdm import tqdm from tqdm import tqdm

View File

@@ -248,7 +248,9 @@ class UniversalSegmentationHead(SegmentationHead):
self.d_model = hidden_dim self.d_model = hidden_dim
if dot_product_scorer is not None: if dot_product_scorer is not None:
assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake" assert presence_head, (
"Specifying a dot product scorer without a presence head is likely a mistake"
)
self.presence_head = None self.presence_head = None
if presence_head: if presence_head:

View File

@@ -62,9 +62,9 @@ class SimpleMaskDownSampler(nn.Module):
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
self.interpol_size = interpol_size self.interpol_size = interpol_size
if self.interpol_size is not None: if self.interpol_size is not None:
assert isinstance( assert isinstance(self.interpol_size, (list, tuple)), (
self.interpol_size, (list, tuple) f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
), f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple." )
self.interpol_size = list(interpol_size) self.interpol_size = list(interpol_size)
assert len(self.interpol_size) == 2 assert len(self.interpol_size) == 2

View File

@@ -330,9 +330,9 @@ class SAM3Output(list):
self.output = output self.output = output
else: else:
self.output = [] self.output = []
assert isinstance( assert isinstance(iter_mode, SAM3Output.IterMode), (
iter_mode, SAM3Output.IterMode f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"
), f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}" )
self.iter_mode = iter_mode self.iter_mode = iter_mode
# We create a weak reference to self to be used in the lambda functions. # We create a weak reference to self to be used in the lambda functions.
@@ -411,9 +411,9 @@ class SAM3Output(list):
return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode) return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode)
def append(self, item: list): def append(self, item: list):
assert isinstance( assert isinstance(item, list), (
item, list f"Only list items are supported. Got {type(item)}"
), f"Only list items are supported. Got {type(item)}" )
self.output.append(item) self.output.append(item)
def __repr__(self): def __repr__(self):

View File

@@ -8,7 +8,6 @@ from copy import deepcopy
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn

View File

@@ -7,15 +7,12 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import logging import logging
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL.Image import Image from PIL.Image import Image
from sam3.model.sam3_tracker_base import Sam3TrackerBase from sam3.model.sam3_tracker_base import Sam3TrackerBase
from sam3.model.utils.sam1_utils import SAM2Transforms from sam3.model.utils.sam1_utils import SAM2Transforms
@@ -97,9 +94,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
input_image = self._transforms(image) input_image = self._transforms(image)
input_image = input_image[None, ...].to(self.device) input_image = input_image[None, ...].to(self.device)
assert ( assert len(input_image.shape) == 4 and input_image.shape[1] == 3, (
len(input_image.shape) == 4 and input_image.shape[1] == 3 f"input_image must be of size 1x3xHxW, got {input_image.shape}"
), f"input_image must be of size 1x3xHxW, got {input_image.shape}" )
logging.info("Computing image embeddings for the provided image...") logging.info("Computing image embeddings for the provided image...")
backbone_out = self.model.forward_image(input_image) backbone_out = self.model.forward_image(input_image)
( (
@@ -136,17 +133,17 @@ class SAM3InteractiveImagePredictor(nn.Module):
assert isinstance(image_list, list) assert isinstance(image_list, list)
self._orig_hw = [] self._orig_hw = []
for image in image_list: for image in image_list:
assert isinstance( assert isinstance(image, np.ndarray), (
image, np.ndarray "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
), "Images are expected to be an np.ndarray in RGB format, and of shape HWC" )
self._orig_hw.append(image.shape[:2]) self._orig_hw.append(image.shape[:2])
# Transform the image to the form expected by the model # Transform the image to the form expected by the model
img_batch = self._transforms.forward_batch(image_list) img_batch = self._transforms.forward_batch(image_list)
img_batch = img_batch.to(self.device) img_batch = img_batch.to(self.device)
batch_size = img_batch.shape[0] batch_size = img_batch.shape[0]
assert ( assert len(img_batch.shape) == 4 and img_batch.shape[1] == 3, (
len(img_batch.shape) == 4 and img_batch.shape[1] == 3 f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" )
logging.info("Computing image embeddings for the provided images...") logging.info("Computing image embeddings for the provided images...")
backbone_out = self.model.forward_image(img_batch) backbone_out = self.model.forward_image(img_batch)
( (
@@ -302,9 +299,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
): ):
unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
if point_coords is not None: if point_coords is not None:
assert ( assert point_labels is not None, (
point_labels is not None "point_labels must be supplied if point_coords is supplied."
), "point_labels must be supplied if point_coords is supplied." )
point_coords = torch.as_tensor( point_coords = torch.as_tensor(
point_coords, dtype=torch.float, device=self.device point_coords, dtype=torch.float, device=self.device
) )
@@ -441,9 +438,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
raise RuntimeError( raise RuntimeError(
"An image must be set with .set_image(...) to generate an embedding." "An image must be set with .set_image(...) to generate an embedding."
) )
assert ( assert self._features is not None, (
self._features is not None "Features must exist if an image has been set."
), "Features must exist if an image has been set." )
return self._features["image_embed"] return self._features["image_embed"]
@property @property

View File

@@ -8,19 +8,14 @@ from typing import Dict, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from sam3.model.model_misc import SAM3Output from sam3.model.model_misc import SAM3Output
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
from sam3.model.vl_combiner import SAM3VLBackbone from sam3.model.vl_combiner import SAM3VLBackbone
from sam3.perflib.nms import nms_masks from sam3.perflib.nms import nms_masks
from sam3.train.data.collator import BatchedDatapoint from sam3.train.data.collator import BatchedDatapoint
from .act_ckpt_utils import activation_ckpt_wrapper from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy from .box_ops import box_cxcywh_to_xyxy
from .geometry_encoders import Prompt from .geometry_encoders import Prompt
from .model_misc import inverse_sigmoid from .model_misc import inverse_sigmoid
@@ -661,9 +656,9 @@ class Sam3Image(torch.nn.Module):
inference_state["original_heights"], inference_state["original_heights"],
inference_state["original_widths"], inference_state["original_widths"],
) )
assert ( assert batch_size == len(orig_heights) == len(orig_widths), (
batch_size == len(orig_heights) == len(orig_widths) f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
), f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}" )
feats = [ feats = [
feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
for feat, feat_size in zip( for feat, feat_size in zip(

View File

@@ -6,9 +6,7 @@ from typing import Dict, List
import numpy as np import numpy as np
import PIL import PIL
import torch import torch
from sam3.model import box_ops from sam3.model import box_ops
from sam3.model.data_misc import FindStage, interpolate from sam3.model.data_misc import FindStage, interpolate
from torchvision.transforms import v2 from torchvision.transforms import v2
@@ -83,9 +81,9 @@ class Sam3Processor:
if not isinstance(images, list): if not isinstance(images, list):
raise ValueError("Images must be a list of PIL images or tensors") raise ValueError("Images must be a list of PIL images or tensors")
assert len(images) > 0, "Images list must not be empty" assert len(images) > 0, "Images list must not be empty"
assert isinstance( assert isinstance(images[0], PIL.Image.Image), (
images[0], PIL.Image.Image "Images must be a list of PIL images"
), "Images must be a list of PIL images" )
state["original_heights"] = [image.height for image in images] state["original_heights"] = [image.height for image in images]
state["original_widths"] = [image.width for image in images] state["original_widths"] = [image.width for image in images]

View File

@@ -6,11 +6,8 @@ import logging
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from sam3.model.memory import SimpleMaskEncoder from sam3.model.memory import SimpleMaskEncoder
from sam3.model.sam3_tracker_utils import get_1d_sine_pe, select_closest_cond_frames from sam3.model.sam3_tracker_utils import get_1d_sine_pe, select_closest_cond_frames
from sam3.sam.mask_decoder import MaskDecoder, MLP from sam3.sam.mask_decoder import MaskDecoder, MLP
from sam3.sam.prompt_encoder import PromptEncoder from sam3.sam.prompt_encoder import PromptEncoder
from sam3.sam.transformer import TwoWayTransformer from sam3.sam.transformer import TwoWayTransformer

View File

@@ -6,7 +6,6 @@ import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from numpy.typing import NDArray from numpy.typing import NDArray
from sam3.model.edt import edt_triton from sam3.model.edt import edt_triton

View File

@@ -6,7 +6,6 @@ import logging
from collections import OrderedDict from collections import OrderedDict
import torch import torch
from sam3.model.sam3_tracker_base import concat_points, NO_OBJ_SCORE, Sam3TrackerBase from sam3.model.sam3_tracker_base import concat_points, NO_OBJ_SCORE, Sam3TrackerBase
from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores
from sam3.model.utils.sam2_utils import load_video_frames from sam3.model.utils.sam2_utils import load_video_frames

View File

@@ -16,7 +16,6 @@ import numpy.typing as npt
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
from sam3 import perflib from sam3 import perflib
from sam3.logger import get_logger from sam3.logger import get_logger
from sam3.model.box_ops import fast_diag_box_iou from sam3.model.box_ops import fast_diag_box_iou
@@ -620,9 +619,9 @@ class Sam3VideoBase(nn.Module):
num_obj_dropped_due_to_limit, num_obj_dropped_due_to_limit,
trk_id_to_max_iou_high_conf_det, trk_id_to_max_iou_high_conf_det,
] ]
assert ( assert len(update_plan) == NUM_BROADCAST_ITEMS, (
len(update_plan) == NUM_BROADCAST_ITEMS f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
), f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}" )
self.broadcast_python_obj_cpu(update_plan, src=0) self.broadcast_python_obj_cpu(update_plan, src=0)
elif self.rank > 0 and self.world_size > 1: elif self.rank > 0 and self.world_size > 1:
update_plan = [ update_plan = [
@@ -842,9 +841,9 @@ class Sam3VideoBase(nn.Module):
binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0 binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0
batch_size = tracker_low_res_masks_global.size(0) batch_size = tracker_low_res_masks_global.size(0)
if batch_size > 0: if batch_size > 0:
assert ( assert len(obj_ids_global) == batch_size, (
len(obj_ids_global) == batch_size f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
), f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}" )
NEVER_OCCLUDED = -1 NEVER_OCCLUDED = -1
ALWAYS_OCCLUDED = 100000 # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic ALWAYS_OCCLUDED = 100000 # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic
last_occluded_prev = torch.cat( last_occluded_prev = torch.cat(
@@ -1023,9 +1022,9 @@ class Sam3VideoBase(nn.Module):
reverse: bool = False, reverse: bool = False,
): ):
# Suppress overlapping masks for objects that were most recently occluded # Suppress overlapping masks for objects that were most recently occluded
assert ( assert binary_low_res_masks.dtype == torch.bool, (
binary_low_res_masks.dtype == torch.bool f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
), f"Expected boolean tensor, got {binary_low_res_masks.dtype}" )
to_suppress = torch.zeros( to_suppress = torch.zeros(
binary_low_res_masks.size(0), binary_low_res_masks.size(0),
device=binary_low_res_masks.device, device=binary_low_res_masks.device,
@@ -1130,9 +1129,9 @@ class Sam3VideoBase(nn.Module):
num_frames_propagated += 1 num_frames_propagated += 1
# only 1 frames should be propagated # only 1 frames should be propagated
assert ( assert num_frames_propagated == 1 and out_frame_idx == frame_idx, (
num_frames_propagated == 1 and out_frame_idx == frame_idx f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}"
), f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}" )
assert isinstance(out_obj_ids, list) assert isinstance(out_obj_ids, list)
obj_ids_local.extend(out_obj_ids) obj_ids_local.extend(out_obj_ids)
low_res_masks_list.append(out_low_res_masks.squeeze(1)) low_res_masks_list.append(out_low_res_masks.squeeze(1))
@@ -1189,9 +1188,9 @@ class Sam3VideoBase(nn.Module):
assert det_masks.is_floating_point(), "float tensor expected (do not binarize)" assert det_masks.is_floating_point(), "float tensor expected (do not binarize)"
assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)" assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)"
assert ( assert trk_masks.size(0) == len(trk_obj_ids), (
trk_masks.size(0) == len(trk_obj_ids) f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
), f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}" )
if trk_masks.size(0) == 0: if trk_masks.size(0) == 0:
# all detections are new # all detections are new
new_det_fa_inds = np.arange(det_masks.size(0)) new_det_fa_inds = np.arange(det_masks.size(0))
@@ -1655,9 +1654,9 @@ class Sam3VideoBase(nn.Module):
# a) first, expand "confirmation_data" to include new masklets added in this frame # a) first, expand "confirmation_data" to include new masklets added in this frame
status_prev = confirmation_data["status"] status_prev = confirmation_data["status"]
consecutive_det_num_prev = confirmation_data["consecutive_det_num"] consecutive_det_num_prev = confirmation_data["consecutive_det_num"]
assert ( assert status_prev.shape == obj_ids_all_gpu_prev.shape, (
status_prev.shape == obj_ids_all_gpu_prev.shape f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
), f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}" )
obj_id_to_updated_idx = { obj_id_to_updated_idx = {
obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated) obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated)

View File

@@ -9,7 +9,6 @@ import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
from sam3 import perflib from sam3 import perflib
from sam3.logger import get_logger from sam3.logger import get_logger
from sam3.model.act_ckpt_utils import clone_output_wrapper from sam3.model.act_ckpt_utils import clone_output_wrapper
@@ -555,7 +554,9 @@ class Sam3VideoInference(Sam3VideoBase):
assert ( assert (
"cached_frame_outputs" in inference_state "cached_frame_outputs" in inference_state
and frame_idx in inference_state["cached_frame_outputs"] and frame_idx in inference_state["cached_frame_outputs"]
), "No cached outputs found. Ensure normal propagation has run first to populate the cache." ), (
"No cached outputs found. Ensure normal propagation has run first to populate the cache."
)
cached_outputs = inference_state["cached_frame_outputs"][frame_idx] cached_outputs = inference_state["cached_frame_outputs"][frame_idx]
obj_id_to_mask = cached_outputs.copy() obj_id_to_mask = cached_outputs.copy()
@@ -563,9 +564,9 @@ class Sam3VideoInference(Sam3VideoBase):
# Update with refined masks if provided # Update with refined masks if provided
if refined_obj_id_to_mask is not None: if refined_obj_id_to_mask is not None:
for obj_id, refined_mask in refined_obj_id_to_mask.items(): for obj_id, refined_mask in refined_obj_id_to_mask.items():
assert ( assert refined_mask is not None, (
refined_mask is not None f"Refined mask data must be provided for obj_id {obj_id}"
), f"Refined mask data must be provided for obj_id {obj_id}" )
obj_id_to_mask[obj_id] = refined_mask obj_id_to_mask[obj_id] = refined_mask
return obj_id_to_mask return obj_id_to_mask
@@ -660,12 +661,12 @@ class Sam3VideoInference(Sam3VideoBase):
for i, thresh in enumerate(new_det_score_thresh_list): for i, thresh in enumerate(new_det_score_thresh_list):
self.new_det_thresh = thresh self.new_det_thresh = thresh
for num_objects in num_objects_list: for num_objects in num_objects_list:
logger.info(f"{i+1}/{num_rounds} warming up model compilation") logger.info(f"{i + 1}/{num_rounds} warming up model compilation")
self.add_prompt( self.add_prompt(
inference_state, frame_idx=start_frame_idx, text_str="cat" inference_state, frame_idx=start_frame_idx, text_str="cat"
) )
logger.info( logger.info(
f"{i+1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects" f"{i + 1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects"
) )
inference_state = self.add_fake_objects_to_inference_state( inference_state = self.add_fake_objects_to_inference_state(
inference_state, num_objects, frame_idx=start_frame_idx inference_state, num_objects, frame_idx=start_frame_idx
@@ -690,7 +691,7 @@ class Sam3VideoInference(Sam3VideoBase):
pass pass
self.reset_state(inference_state) self.reset_state(inference_state)
logger.info( logger.info(
f"{i+1}/{num_rounds} warming up model compilation -- completed round {i+1} out of {num_rounds}" f"{i + 1}/{num_rounds} warming up model compilation -- completed round {i + 1} out of {num_rounds}"
) )
# Warm up Tracker memory encoder with varying input shapes # Warm up Tracker memory encoder with varying input shapes
@@ -854,12 +855,12 @@ class Sam3VideoInference(Sam3VideoBase):
logger.debug("Running add_prompt on frame %d", frame_idx) logger.debug("Running add_prompt on frame %d", frame_idx)
num_frames = inference_state["num_frames"] num_frames = inference_state["num_frames"]
assert ( assert text_str is not None or boxes_xywh is not None, (
text_str is not None or boxes_xywh is not None "at least one type of prompt (text, boxes) must be provided"
), "at least one type of prompt (text, boxes) must be provided" )
assert ( assert 0 <= frame_idx < num_frames, (
0 <= frame_idx < num_frames f"{frame_idx=} is out of range for a total of {num_frames} frames"
), f"{frame_idx=} is out of range for a total of {num_frames} frames" )
# since it's a semantic prompt, we start over # since it's a semantic prompt, we start over
self.reset_state(inference_state) self.reset_state(inference_state)
@@ -1200,9 +1201,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
"propagation_partial", "propagation_partial",
"propagation_fetch", "propagation_fetch",
] ]
assert ( assert action_type in instance_actions + propagation_actions, (
action_type in instance_actions + propagation_actions f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}"
), f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}" )
action = { action = {
"type": action_type, "type": action_type,
"frame_idx": frame_idx, "frame_idx": frame_idx,
@@ -1370,12 +1371,12 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
): ):
if points is not None: if points is not None:
# Tracker instance prompts # Tracker instance prompts
assert ( assert text_str is None and boxes_xywh is None, (
text_str is None and boxes_xywh is None "When points are provided, text_str and boxes_xywh must be None."
), "When points are provided, text_str and boxes_xywh must be None." )
assert ( assert obj_id is not None, (
obj_id is not None "When points are provided, obj_id must be provided."
), "When points are provided, obj_id must be provided." )
return self.add_tracker_new_points( return self.add_tracker_new_points(
inference_state, inference_state,
frame_idx, frame_idx,
@@ -1491,9 +1492,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
tracker_states = self._get_tracker_inference_states_by_obj_ids( tracker_states = self._get_tracker_inference_states_by_obj_ids(
inference_state, [obj_id] inference_state, [obj_id]
) )
assert ( assert len(tracker_states) == 1, (
len(tracker_states) == 1 f"[rank={self.rank}] Multiple Tracker inference states found for the same object id."
), f"[rank={self.rank}] Multiple Tracker inference states found for the same object id." )
tracker_state = tracker_states[0] tracker_state = tracker_states[0]
# log # log

View File

@@ -16,7 +16,6 @@ from typing import List, Optional
import psutil import psutil
import torch import torch
from sam3.logger import get_logger from sam3.logger import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -170,7 +169,7 @@ class Sam3VideoPredictor:
): ):
"""Remove an object from tracking.""" """Remove an object from tracking."""
logger.debug( logger.debug(
f"remove object {obj_id} in session {session_id}: " f"{is_user_action=}" f"remove object {obj_id} in session {session_id}: {is_user_action=}"
) )
session = self._get_session(session_id) session = self._get_session(session_id)
inference_state = session["state"] inference_state = session["state"]

View File

@@ -318,9 +318,9 @@ class VETextEncoder(nn.Module):
# The text is already encoded, use as is. # The text is already encoded, use as is.
text_attention_mask, text_memory_resized, tokenized = text text_attention_mask, text_memory_resized, tokenized = text
inputs_embeds = tokenized["inputs_embeds"] inputs_embeds = tokenized["inputs_embeds"]
assert ( assert input_boxes is None or len(input_boxes) == 0, (
input_boxes is None or len(input_boxes) == 0 "Can't replace boxes in text if it's already encoded"
), "Can't replace boxes in text if it's already encoded" )
# Note that the input_embeds are returned in pytorch's convention (sequence first) # Note that the input_embeds are returned in pytorch's convention (sequence first)
return ( return (

View File

@@ -708,9 +708,9 @@ class ViT(nn.Module):
self.retain_cls_token = retain_cls_token self.retain_cls_token = retain_cls_token
if self.retain_cls_token: if self.retain_cls_token:
assert pretrain_use_cls_token assert pretrain_use_cls_token
assert ( assert len(window_block_indexes) == 0, (
len(window_block_indexes) == 0 "windowing not supported with cls token"
), "windowing not supported with cls token" )
assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token" assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"

View File

@@ -9,7 +9,6 @@ from typing import List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.attention import sdpa_kernel, SDPBackend from torch.nn.attention import sdpa_kernel, SDPBackend
from .act_ckpt_utils import activation_ckpt_wrapper from .act_ckpt_utils import activation_ckpt_wrapper

View File

@@ -6,7 +6,6 @@ import os
from typing import Optional from typing import Optional
import pkg_resources import pkg_resources
import torch import torch
import torch.nn as nn import torch.nn as nn
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download

View File

@@ -36,9 +36,9 @@ def connected_components_cpu(input_tensor: torch.Tensor):
if input_tensor.dim() == 4 and input_tensor.shape[1] == 1: if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
input_tensor = input_tensor.squeeze(1) input_tensor = input_tensor.squeeze(1)
else: else:
assert ( assert input_tensor.dim() == 3, (
input_tensor.dim() == 3 "Input tensor must be (B, H, W) or (B, 1, H, W)."
), "Input tensor must be (B, H, W) or (B, 1, H, W)." )
batch_size = input_tensor.shape[0] batch_size = input_tensor.shape[0]
labels_list = [] labels_list = []
@@ -67,9 +67,9 @@ def connected_components(input_tensor: torch.Tensor):
if input_tensor.dim() == 3: if input_tensor.dim() == 3:
input_tensor = input_tensor.unsqueeze(1) input_tensor = input_tensor.unsqueeze(1)
assert ( assert input_tensor.dim() == 4 and input_tensor.shape[1] == 1, (
input_tensor.dim() == 4 and input_tensor.shape[1] == 1 "Input tensor must be (B, H, W) or (B, 1, H, W)."
), "Input tensor must be (B, H, W) or (B, 1, H, W)." )
if input_tensor.is_cuda: if input_tensor.is_cuda:
if HAS_CC_TORCH: if HAS_CC_TORCH:

View File

@@ -6,7 +6,6 @@ import logging
import numpy as np import numpy as np
import torch import torch
from sam3.perflib.masks_ops import mask_iou from sam3.perflib.masks_ops import mask_iou

View File

@@ -407,16 +407,16 @@ def connected_components_triton(input_tensor: torch.Tensor):
- A BxHxW output tensor with dense labels. Background is 0. - A BxHxW output tensor with dense labels. Background is 0.
- A BxHxW tensor with the size of the connected component for each pixel. - A BxHxW tensor with the size of the connected component for each pixel.
""" """
assert ( assert input_tensor.is_cuda and input_tensor.is_contiguous(), (
input_tensor.is_cuda and input_tensor.is_contiguous() "Input tensor must be a contiguous CUDA tensor."
), "Input tensor must be a contiguous CUDA tensor." )
out_shape = input_tensor.shape out_shape = input_tensor.shape
if input_tensor.dim() == 4 and input_tensor.shape[1] == 1: if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
input_tensor = input_tensor.squeeze(1) input_tensor = input_tensor.squeeze(1)
else: else:
assert ( assert input_tensor.dim() == 3, (
input_tensor.dim() == 3 "Input tensor must be (B, H, W) or (B, 1, H, W)."
), "Input tensor must be (B, H, W) or (B, 1, H, W)." )
B, H, W = input_tensor.shape B, H, W = input_tensor.shape
numel = B * H * W numel = B * H * W

View File

@@ -202,9 +202,9 @@ class MaskDecoder(nn.Module):
assert image_embeddings.shape[0] == tokens.shape[0] assert image_embeddings.shape[0] == tokens.shape[0]
src = image_embeddings src = image_embeddings
src = src + dense_prompt_embeddings src = src + dense_prompt_embeddings
assert ( assert image_pe.size(0) == 1, (
image_pe.size(0) == 1 "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" )
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape b, c, h, w = src.shape

View File

@@ -8,7 +8,6 @@ from typing import Tuple, Type
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from sam3.sam.rope import apply_rotary_enc, apply_rotary_enc_real, compute_axial_cis from sam3.sam.rope import apply_rotary_enc, apply_rotary_enc_real, compute_axial_cis
from torch import nn, Tensor from torch import nn, Tensor
@@ -205,9 +204,9 @@ class Attention(nn.Module):
self.internal_dim = embedding_dim // downsample_rate self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads self.num_heads = num_heads
self.use_fa3 = use_fa3 self.use_fa3 = use_fa3
assert ( assert self.internal_dim % num_heads == 0, (
self.internal_dim % num_heads == 0 "num_heads must divide embedding_dim."
), "num_heads must divide embedding_dim." )
self.q_proj = nn.Linear(embedding_dim, self.internal_dim) self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)

View File

@@ -142,9 +142,9 @@ class COCO_FROM_JSON:
self.prompts = {} self.prompts = {}
for loc_dict in prompts: for loc_dict in prompts:
self.prompts[int(loc_dict["id"])] = loc_dict["name"] self.prompts[int(loc_dict["id"])] = loc_dict["name"]
assert len(self.prompts) == len( assert len(self.prompts) == len(self._sorted_cat_ids), (
self._sorted_cat_ids "Number of prompts must match number of categories"
), "Number of prompts must match number of categories" )
def getDatapointIds(self): def getDatapointIds(self):
"""Return all datapoint indices for training.""" """Return all datapoint indices for training."""

View File

@@ -6,7 +6,6 @@ from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_data
from typing import Any, get_args, get_origin, List, Union from typing import Any, get_args, get_origin, List, Union
import torch import torch
from sam3.model.data_misc import ( from sam3.model.data_misc import (
BatchedDatapoint, BatchedDatapoint,
BatchedFindTarget, BatchedFindTarget,
@@ -217,9 +216,9 @@ def collate_fn_api(
text_batch.append(q.query_text) text_batch.append(q.query_text)
stages[stage_id].text_ids.append(text_batch.index(q.query_text)) stages[stage_id].text_ids.append(text_batch.index(q.query_text))
assert ( assert q.inference_metadata is not None, (
q.inference_metadata is not None "inference_metadata must be provided when FindQueryLoaded is created."
), "inference_metadata must be provided when FindQueryLoaded is created." )
for f in fields(q.inference_metadata): for f in fields(q.inference_metadata):
getattr(find_metadatas[stage_id], f.name).append( getattr(find_metadatas[stage_id], f.name).append(
getattr(q.inference_metadata, f.name) getattr(q.inference_metadata, f.name)

View File

@@ -19,10 +19,8 @@ import torch.utils.data
import torchvision import torchvision
from decord import cpu, VideoReader from decord import cpu, VideoReader
from iopath.common.file_io import g_pathmgr from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage from PIL import Image as PILImage
from PIL.Image import DecompressionBombError from PIL.Image import DecompressionBombError
from sam3.model.box_ops import box_xywh_to_xyxy from sam3.model.box_ops import box_xywh_to_xyxy
from torchvision.datasets.vision import VisionDataset from torchvision.datasets.vision import VisionDataset
@@ -234,9 +232,9 @@ class CustomCocoDetectionAPI(VisionDataset):
if self.coco is not None: if self.coco is not None:
return return
assert g_pathmgr.isfile( assert g_pathmgr.isfile(self.annFile), (
self.annFile f"please provide valid annotation file. Missing: {self.annFile}"
), f"please provide valid annotation file. Missing: {self.annFile}" )
annFile = g_pathmgr.get_local_path(self.annFile) annFile = g_pathmgr.get_local_path(self.annFile)
if self.coco is not None: if self.coco is not None:
@@ -326,9 +324,9 @@ class CustomCocoDetectionAPI(VisionDataset):
else: else:
num_queries_per_stage = stage2num_queries.most_common(1)[0][1] num_queries_per_stage = stage2num_queries.most_common(1)[0][1]
for stage, num_queries in stage2num_queries.items(): for stage, num_queries in stage2num_queries.items():
assert ( assert num_queries == num_queries_per_stage, (
num_queries == num_queries_per_stage f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"
), f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}" )
for query_id, query in enumerate(queries): for query_id, query in enumerate(queries):
h, w = id2imsize[query["image_id"]] h, w = id2imsize[query["image_id"]]

View File

@@ -3,7 +3,6 @@
# pyre-unsafe # pyre-unsafe
import copy import copy
import io import io
import json import json
import logging import logging
@@ -16,7 +15,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch import torch
import torchvision import torchvision
# from decord import cpu, VideoReader # from decord import cpu, VideoReader
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
@@ -220,9 +218,9 @@ class VideoGroundingDataset(Sam3ImageDataset):
for query in filtered_queries: for query in filtered_queries:
ptr_x_is_empty = query["ptr_x_query_id"] in [None, -1] ptr_x_is_empty = query["ptr_x_query_id"] in [None, -1]
ptr_y_is_empty = query["ptr_y_query_id"] in [None, -1] ptr_y_is_empty = query["ptr_y_query_id"] in [None, -1]
assert ( assert ptr_x_is_empty and ptr_y_is_empty, (
ptr_x_is_empty and ptr_y_is_empty "Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers"
), "Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers" )
query["query_processing_order"] = stage_id_old2new[ query["query_processing_order"] = stage_id_old2new[
query["query_processing_order"] query["query_processing_order"]
] ]

View File

@@ -9,11 +9,8 @@ import torch
import torch.distributed import torch.distributed
import torch.nn.functional as F import torch.nn.functional as F
import torchmetrics import torchmetrics
from sam3.model import box_ops from sam3.model import box_ops
from sam3.model.data_misc import interpolate from sam3.model.data_misc import interpolate
from sam3.train.loss.sigmoid_focal_loss import ( from sam3.train.loss.sigmoid_focal_loss import (
triton_sigmoid_focal_loss, triton_sigmoid_focal_loss,
triton_sigmoid_focal_loss_reduce, triton_sigmoid_focal_loss_reduce,
@@ -327,7 +324,9 @@ class IABCEMdetr(LossWithWeights):
if num_det_queries is not None: if num_det_queries is not None:
logging.warning("note: it's not needed to set num_det_queries anymore") logging.warning("note: it's not needed to set num_det_queries anymore")
if self.use_separate_loss_for_det_and_trk: if self.use_separate_loss_for_det_and_trk:
assert not self.weak_loss, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead" assert not self.weak_loss, (
"Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
)
self.det_exhaustive_loss_scale_pos = det_exhaustive_loss_scale_pos self.det_exhaustive_loss_scale_pos = det_exhaustive_loss_scale_pos
self.det_exhaustive_loss_scale_neg = det_exhaustive_loss_scale_neg self.det_exhaustive_loss_scale_neg = det_exhaustive_loss_scale_neg
self.det_non_exhaustive_loss_scale_pos = det_non_exhaustive_loss_scale_pos self.det_non_exhaustive_loss_scale_pos = det_non_exhaustive_loss_scale_pos
@@ -342,7 +341,9 @@ class IABCEMdetr(LossWithWeights):
and det_non_exhaustive_loss_scale_neg == 1.0 and det_non_exhaustive_loss_scale_neg == 1.0
and trk_loss_scale_pos == 1.0 and trk_loss_scale_pos == 1.0
and trk_loss_scale_neg == 1.0 and trk_loss_scale_neg == 1.0
), "If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0" ), (
"If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0"
)
def get_loss(self, outputs, targets, indices, num_boxes): def get_loss(self, outputs, targets, indices, num_boxes):
assert len(outputs["pred_logits"].shape) > 2, "Incorrect predicted logits shape" assert len(outputs["pred_logits"].shape) > 2, "Incorrect predicted logits shape"
@@ -443,7 +444,9 @@ class IABCEMdetr(LossWithWeights):
pass pass
if self.weak_loss: if self.weak_loss:
assert not self.use_separate_loss_for_det_and_trk, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead" assert not self.use_separate_loss_for_det_and_trk, (
"Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
)
# nullify the negative loss for the non-exhaustive classes # nullify the negative loss for the non-exhaustive classes
assert loss_bce.shape[0] == targets["is_exhaustive"].shape[0] assert loss_bce.shape[0] == targets["is_exhaustive"].shape[0]
@@ -497,9 +500,9 @@ class IABCEMdetr(LossWithWeights):
loss_bce = loss_bce.mean() loss_bce = loss_bce.mean()
else: else:
assert isinstance(self.pad_n_queries, int) assert isinstance(self.pad_n_queries, int)
assert ( assert loss_bce.size(1) < self.pad_n_queries, (
loss_bce.size(1) < self.pad_n_queries f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions."
), f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions." )
loss_bce = loss_bce.sum() / (self.pad_n_queries * loss_bce.size(0)) loss_bce = loss_bce.sum() / (self.pad_n_queries * loss_bce.size(0))
bce_f1 = torchmetrics.functional.f1_score( bce_f1 = torchmetrics.functional.f1_score(

View File

@@ -3,9 +3,7 @@
# pyre-unsafe # pyre-unsafe
import torch import torch
from sam3.model.model_misc import SAM3Output from sam3.model.model_misc import SAM3Output
from sam3.train.utils.distributed import get_world_size from sam3.train.utils.distributed import get_world_size
from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks

View File

@@ -103,9 +103,9 @@ def dilation(mask, kernel_size):
assert mask.ndim == 3 assert mask.ndim == 3
kernel_size = int(kernel_size) kernel_size = int(kernel_size)
assert ( assert kernel_size % 2 == 1, (
kernel_size % 2 == 1 f"Dilation expects a odd kernel size, got {kernel_size}"
), f"Dilation expects a odd kernel size, got {kernel_size}" )
if mask.is_cuda: if mask.is_cuda:
m = mask.unsqueeze(1).to(torch.float16) m = mask.unsqueeze(1).to(torch.float16)

View File

@@ -8,7 +8,6 @@ Modules to compute the matching cost and solve the corresponding LSAP.
import numpy as np import numpy as np
import torch import torch
from sam3.model.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou from sam3.model.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
from scipy.optimize import linear_sum_assignment from scipy.optimize import linear_sum_assignment
from torch import nn from torch import nn
@@ -60,9 +59,9 @@ class HungarianMatcher(nn.Module):
self.cost_bbox = cost_bbox self.cost_bbox = cost_bbox
self.cost_giou = cost_giou self.cost_giou = cost_giou
self.norm = nn.Sigmoid() if focal_loss else nn.Softmax(-1) self.norm = nn.Sigmoid() if focal_loss else nn.Softmax(-1)
assert ( assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
cost_class != 0 or cost_bbox != 0 or cost_giou != 0 "all costs cant be 0"
), "all costs cant be 0" )
self.focal_loss = focal_loss self.focal_loss = focal_loss
self.focal_alpha = focal_alpha self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma self.focal_gamma = focal_gamma
@@ -197,9 +196,9 @@ class BinaryHungarianMatcher(nn.Module):
self.cost_bbox = cost_bbox self.cost_bbox = cost_bbox
self.cost_giou = cost_giou self.cost_giou = cost_giou
self.norm = nn.Sigmoid() self.norm = nn.Sigmoid()
assert ( assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
cost_class != 0 or cost_bbox != 0 or cost_giou != 0 "all costs cant be 0"
), "all costs cant be 0" )
@torch.no_grad() @torch.no_grad()
def forward(self, outputs, batched_targets, repeats=0, repeat_batch=1): def forward(self, outputs, batched_targets, repeats=0, repeat_batch=1):
@@ -322,9 +321,9 @@ class BinaryFocalHungarianMatcher(nn.Module):
self.alpha = alpha self.alpha = alpha
self.gamma = gamma self.gamma = gamma
self.stable = stable self.stable = stable
assert ( assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
cost_class != 0 or cost_bbox != 0 or cost_giou != 0 "all costs cant be 0"
), "all costs cant be 0" )
@torch.no_grad() @torch.no_grad()
def forward(self, outputs, batched_targets, repeats=1, repeat_batch=1): def forward(self, outputs, batched_targets, repeats=1, repeat_batch=1):
@@ -470,9 +469,9 @@ class BinaryHungarianMatcherV2(nn.Module):
self.cost_bbox = cost_bbox self.cost_bbox = cost_bbox
self.cost_giou = cost_giou self.cost_giou = cost_giou
self.norm = nn.Sigmoid() self.norm = nn.Sigmoid()
assert ( assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
cost_class != 0 or cost_bbox != 0 or cost_giou != 0 "all costs cant be 0"
), "all costs cant be 0" )
self.focal = focal self.focal = focal
if focal: if focal:
self.alpha = alpha self.alpha = alpha

View File

@@ -22,7 +22,6 @@ from typing import (
) )
import hydra import hydra
import torch import torch
import torch.nn as nn import torch.nn as nn
from omegaconf import DictConfig from omegaconf import DictConfig
@@ -212,9 +211,9 @@ def unix_module_cls_pattern_to_parameter_names(
"match any classes in the model" "match any classes in the model"
) )
matching_parameters = module_cls_to_param_names[module_cls] matching_parameters = module_cls_to_param_names[module_cls]
assert ( assert len(matching_parameters) > 0, (
len(matching_parameters) > 0 f"module_cls_name {module_cls_name} does not contain any parameters in the model"
), f"module_cls_name {module_cls_name} does not contain any parameters in the model" )
logging.info( logging.info(
f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} " f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
) )
@@ -240,9 +239,9 @@ def unix_param_pattern_to_parameter_names(
allowed_parameter_names = [] allowed_parameter_names = []
for param_name in filter_param_names: for param_name in filter_param_names:
matching_parameters = set(fnmatch.filter(parameter_names, param_name)) matching_parameters = set(fnmatch.filter(parameter_names, param_name))
assert ( assert len(matching_parameters) >= 1, (
len(matching_parameters) >= 1 f"param_name {param_name} does not match any parameters in the model"
), f"param_name {param_name} does not match any parameters in the model" )
logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}") logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
allowed_parameter_names.append(matching_parameters) allowed_parameter_names.append(matching_parameters)
return set.union(*allowed_parameter_names) return set.union(*allowed_parameter_names)

View File

@@ -12,13 +12,10 @@ from copy import deepcopy
import submitit import submitit
import torch import torch
from hydra import compose, initialize_config_module from hydra import compose, initialize_config_module
from hydra.utils import instantiate from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr from iopath.common.file_io import g_pathmgr
from omegaconf import OmegaConf from omegaconf import OmegaConf
from sam3.train.utils.train_utils import makedir, register_omegaconf_resolvers from sam3.train.utils.train_utils import makedir, register_omegaconf_resolvers
from tqdm import tqdm from tqdm import tqdm
@@ -212,9 +209,9 @@ def main(args) -> None:
}, },
} }
if "include_nodes" in submitit_conf: if "include_nodes" in submitit_conf:
assert ( assert len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes, (
len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes "Not enough nodes"
), "Not enough nodes" )
job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join( job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(
submitit_conf["include_nodes"] submitit_conf["include_nodes"]
) )

View File

@@ -15,28 +15,22 @@ from dataclasses import dataclass, field
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from hydra.utils import instantiate from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr from iopath.common.file_io import g_pathmgr
from sam3.model.data_misc import BatchedDatapoint from sam3.model.data_misc import BatchedDatapoint
from sam3.model.model_misc import SAM3Output from sam3.model.model_misc import SAM3Output
from sam3.model.utils.misc import copy_data_to_device from sam3.model.utils.misc import copy_data_to_device
from sam3.train.optim.optimizer import construct_optimizer from sam3.train.optim.optimizer import construct_optimizer
from sam3.train.utils.checkpoint_utils import ( from sam3.train.utils.checkpoint_utils import (
assert_skipped_parameters_are_frozen, assert_skipped_parameters_are_frozen,
exclude_params_matching_unix_pattern, exclude_params_matching_unix_pattern,
load_state_dict_into_model, load_state_dict_into_model,
with_check_parameter_frozen, with_check_parameter_frozen,
) )
from sam3.train.utils.distributed import all_reduce_max, barrier, get_rank from sam3.train.utils.distributed import all_reduce_max, barrier, get_rank
from sam3.train.utils.logger import Logger, setup_logging from sam3.train.utils.logger import Logger, setup_logging
from sam3.train.utils.train_utils import ( from sam3.train.utils.train_utils import (
AverageMeter, AverageMeter,
@@ -215,9 +209,9 @@ class Trainer:
set_seeds(seed_value, self.max_epochs, self.distributed_rank) set_seeds(seed_value, self.max_epochs, self.distributed_rank)
log_env_variables() log_env_variables()
assert ( assert is_dist_avail_and_initialized(), (
is_dist_avail_and_initialized() "Torch distributed needs to be initialized before calling the trainer."
), "Torch distributed needs to be initialized before calling the trainer." )
self._setup_components() # Except Optimizer everything is setup here. self._setup_components() # Except Optimizer everything is setup here.
self._move_to_device() self._move_to_device()
@@ -227,9 +221,9 @@ class Trainer:
self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f") self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")
if self.checkpoint_conf.resume_from is not None: if self.checkpoint_conf.resume_from is not None:
assert os.path.exists( assert os.path.exists(self.checkpoint_conf.resume_from), (
self.checkpoint_conf.resume_from f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!" )
dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt") dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
if self.distributed_rank == 0 and not os.path.exists(dst): if self.distributed_rank == 0 and not os.path.exists(dst):
# Copy the "resume_from" checkpoint to the checkpoint folder # Copy the "resume_from" checkpoint to the checkpoint folder
@@ -477,9 +471,9 @@ class Trainer:
return self.loss[key] return self.loss[key]
assert key != "all", "Loss must be specified for key='all'" assert key != "all", "Loss must be specified for key='all'"
assert ( assert "default" in self.loss, (
"default" in self.loss f"Key {key} not found in losss, and no default provided"
), f"Key {key} not found in losss, and no default provided" )
return self.loss["default"] return self.loss["default"]
def _find_meter(self, phase: str, key: str): def _find_meter(self, phase: str, key: str):
@@ -922,12 +916,12 @@ class Trainer:
self.optim.zero_grad(set_to_none=True) self.optim.zero_grad(set_to_none=True)
if self.gradient_accumulation_steps > 1: if self.gradient_accumulation_steps > 1:
assert isinstance( assert isinstance(batch, list), (
batch, list f"Expected a list of batches, got {type(batch)}"
), f"Expected a list of batches, got {type(batch)}" )
assert ( assert len(batch) == self.gradient_accumulation_steps, (
len(batch) == self.gradient_accumulation_steps f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
), f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}" )
accum_steps = len(batch) accum_steps = len(batch)
else: else:
accum_steps = 1 accum_steps = 1
@@ -1039,9 +1033,9 @@ class Trainer:
def _check_val_key_match(self, val_keys, phase): def _check_val_key_match(self, val_keys, phase):
if val_keys is not None: if val_keys is not None:
# Check if there are any duplicates # Check if there are any duplicates
assert len(val_keys) == len( assert len(val_keys) == len(set(val_keys)), (
set(val_keys) f"Duplicate keys in val datasets, keys: {val_keys}"
), f"Duplicate keys in val datasets, keys: {val_keys}" )
# Check that the keys match the meter keys # Check that the keys match the meter keys
if self.meters_conf is not None and phase in self.meters_conf: if self.meters_conf is not None and phase in self.meters_conf:
@@ -1055,9 +1049,9 @@ class Trainer:
loss_keys = set(self.loss_conf.keys()) - set(["all"]) loss_keys = set(self.loss_conf.keys()) - set(["all"])
if "default" not in loss_keys: if "default" not in loss_keys:
for k in val_keys: for k in val_keys:
assert ( assert k in loss_keys, (
k in loss_keys f"Error: key {k} is not defined in the losses, and no default is set"
), f"Error: key {k} is not defined in the losses, and no default is set" )
def _setup_components(self): def _setup_components(self):
# Get the keys for all the val datasets, if any # Get the keys for all the val datasets, if any

View File

@@ -14,7 +14,6 @@ import PIL
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
from sam3.model.box_ops import box_xyxy_to_cxcywh from sam3.model.box_ops import box_xyxy_to_cxcywh
from sam3.model.data_misc import interpolate from sam3.model.data_misc import interpolate
@@ -277,9 +276,9 @@ class RandomSizeCrop:
max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1)) max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1))
) )
result_img, result_target = crop(img, target, [j, i, h, w]) result_img, result_target = crop(img, target, [j, i, h, w])
assert ( assert len(result_target["boxes"]) == init_boxes, (
len(result_target["boxes"]) == init_boxes f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}"
), f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}" )
return result_img, result_target return result_img, result_target
else: else:

View File

@@ -7,7 +7,6 @@ Transforms and data augmentation for both image + bbox.
""" """
import logging import logging
import numbers import numbers
import random import random
from collections.abc import Sequence from collections.abc import Sequence
@@ -17,9 +16,7 @@ import torch
import torchvision.transforms as T import torchvision.transforms as T
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
import torchvision.transforms.v2.functional as Fv2 import torchvision.transforms.v2.functional as Fv2
from PIL import Image as PILImage from PIL import Image as PILImage
from sam3.model.box_ops import box_xyxy_to_cxcywh, masks_to_boxes from sam3.model.box_ops import box_xyxy_to_cxcywh, masks_to_boxes
from sam3.train.data.sam3_image_dataset import Datapoint from sam3.train.data.sam3_image_dataset import Datapoint
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode

View File

@@ -4,12 +4,10 @@
import logging import logging
import random import random
from collections import defaultdict from collections import defaultdict
from typing import List, Optional, Union from typing import List, Optional, Union
import torch import torch
from sam3.train.data.sam3_image_dataset import Datapoint, FindQuery, Object from sam3.train.data.sam3_image_dataset import Datapoint, FindQuery, Object
@@ -381,9 +379,9 @@ class FlexibleFilterFindGetQueries:
if len(new_find_queries) == 0: if len(new_find_queries) == 0:
start_with_zero_check = True start_with_zero_check = True
assert ( assert start_with_zero_check, (
start_with_zero_check "Invalid Find queries, they need to start at query_processing_order = 0"
), "Invalid Find queries, they need to start at query_processing_order = 0" )
datapoint.find_queries = new_find_queries datapoint.find_queries = new_find_queries

View File

@@ -7,7 +7,6 @@ import numpy as np
import torch import torch
from PIL import Image as PILImage from PIL import Image as PILImage
from pycocotools import mask as mask_util from pycocotools import mask as mask_util
from sam3.train.data.sam3_image_dataset import Datapoint from sam3.train.data.sam3_image_dataset import Datapoint
from torchvision.ops import masks_to_boxes from torchvision.ops import masks_to_boxes
@@ -250,9 +249,9 @@ class RandomGeometricInputsAPI:
def _get_target_object(self, datapoint, query): def _get_target_object(self, datapoint, query):
img = datapoint.images[query.image_id] img = datapoint.images[query.image_id]
targets = query.object_ids_output targets = query.object_ids_output
assert ( assert len(targets) == 1, (
len(targets) == 1 "Geometric queries only support a single target object."
), "Geometric queries only support a single target object." )
target_idx = targets[0] target_idx = targets[0]
return img.objects[target_idx] return img.objects[target_idx]

View File

@@ -5,12 +5,9 @@
import numpy as np import numpy as np
import pycocotools.mask as mask_utils import pycocotools.mask as mask_utils
import torch import torch
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
from PIL import Image as PILImage from PIL import Image as PILImage
from sam3.model.box_ops import masks_to_boxes from sam3.model.box_ops import masks_to_boxes
from sam3.train.data.sam3_image_dataset import Datapoint from sam3.train.data.sam3_image_dataset import Datapoint

View File

@@ -36,9 +36,9 @@ def unix_pattern_to_parameter_names(
parameter_names = [] parameter_names = []
for param_name in constraints: for param_name in constraints:
matching_parameters = set(fnmatch.filter(all_parameter_names, param_name)) matching_parameters = set(fnmatch.filter(all_parameter_names, param_name))
assert ( assert len(matching_parameters) > 0, (
len(matching_parameters) > 0 f"param_names {param_name} don't match any param in the given names."
), f"param_names {param_name} don't match any param in the given names." )
parameter_names.append(matching_parameters) parameter_names.append(matching_parameters)
return set.union(*parameter_names) return set.union(*parameter_names)

View File

@@ -10,10 +10,8 @@ import uuid
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
from hydra.utils import instantiate from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr from iopath.common.file_io import g_pathmgr
from numpy import ndarray from numpy import ndarray
from sam3.train.utils.train_utils import get_machine_local_and_dist_rank, makedir from sam3.train.utils.train_utils import get_machine_local_and_dist_rank, makedir
from torch import Tensor from torch import Tensor
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter

View File

@@ -11,7 +11,6 @@ from datetime import timedelta
from typing import Optional from typing import Optional
import hydra import hydra
import numpy as np import numpy as np
import omegaconf import omegaconf
import torch import torch
@@ -83,9 +82,9 @@ def get_machine_local_and_dist_rank():
""" """
local_rank = int(os.environ.get("LOCAL_RANK", None)) local_rank = int(os.environ.get("LOCAL_RANK", None))
distributed_rank = int(os.environ.get("RANK", None)) distributed_rank = int(os.environ.get("RANK", None))
assert ( assert local_rank is not None and distributed_rank is not None, (
local_rank is not None and distributed_rank is not None "Please the set the RANK and LOCAL_RANK environment variables."
), "Please the set the RANK and LOCAL_RANK environment variables." )
return local_rank, distributed_rank return local_rank, distributed_rank

View File

@@ -158,22 +158,22 @@ def plot_mask(mask, color="r", ax=None):
def normalize_bbox(bbox_xywh, img_w, img_h): def normalize_bbox(bbox_xywh, img_w, img_h):
# Assumes bbox_xywh is in XYWH format # Assumes bbox_xywh is in XYWH format
if isinstance(bbox_xywh, list): if isinstance(bbox_xywh, list):
assert ( assert len(bbox_xywh) == 4, (
len(bbox_xywh) == 4 "bbox_xywh list must have 4 elements. Batching not support except for torch tensors."
), "bbox_xywh list must have 4 elements. Batching not support except for torch tensors." )
normalized_bbox = bbox_xywh.copy() normalized_bbox = bbox_xywh.copy()
normalized_bbox[0] /= img_w normalized_bbox[0] /= img_w
normalized_bbox[1] /= img_h normalized_bbox[1] /= img_h
normalized_bbox[2] /= img_w normalized_bbox[2] /= img_w
normalized_bbox[3] /= img_h normalized_bbox[3] /= img_h
else: else:
assert isinstance( assert isinstance(bbox_xywh, torch.Tensor), (
bbox_xywh, torch.Tensor "Only torch tensors are supported for batching."
), "Only torch tensors are supported for batching." )
normalized_bbox = bbox_xywh.clone() normalized_bbox = bbox_xywh.clone()
assert ( assert normalized_bbox.size(-1) == 4, (
normalized_bbox.size(-1) == 4 "bbox_xywh tensor must have last dimension of size 4."
), "bbox_xywh tensor must have last dimension of size 4." )
normalized_bbox[..., 0] /= img_w normalized_bbox[..., 0] /= img_w
normalized_bbox[..., 1] /= img_h normalized_bbox[..., 1] /= img_h
normalized_bbox[..., 2] /= img_w normalized_bbox[..., 2] /= img_w
@@ -244,10 +244,10 @@ def visualize_formatted_frame_output(
num_outputs = len(outputs_list) num_outputs = len(outputs_list)
if titles is None: if titles is None:
titles = [f"Set {i+1}" for i in range(num_outputs)] titles = [f"Set {i + 1}" for i in range(num_outputs)]
assert ( assert len(titles) == num_outputs, (
len(titles) == num_outputs "length of `titles` should match that of `outputs_list` if not None."
), "length of `titles` should match that of `outputs_list` if not None." )
_, axes = plt.subplots(1, num_outputs, figsize=figsize) _, axes = plt.subplots(1, num_outputs, figsize=figsize)
if num_outputs == 1: if num_outputs == 1:
@@ -703,9 +703,9 @@ def get_all_annotations_for_frame(
# Get the frame # Get the frame
video_df_current = video_df[video_df.id == video_id] video_df_current = video_df[video_df.id == video_id]
assert ( assert len(video_df_current) == 1, (
len(video_df_current) == 1 f"Expected 1 video row, got {len(video_df_current)}"
), f"Expected 1 video row, got {len(video_df_current)}" )
video_row = video_df_current.iloc[0] video_row = video_df_current.iloc[0]
file_name = video_row.file_names[frame_idx] file_name = video_row.file_names[frame_idx]
file_path = os.path.join( file_path = os.path.join(
@@ -796,7 +796,7 @@ def visualize_prompt_overlay(
ax.text( ax.text(
x_img + 5, x_img + 5,
y_img - 5, y_img - 5,
f"P{i+1}", f"P{i + 1}",
color=color, color=color,
fontsize=10, fontsize=10,
weight="bold", weight="bold",
@@ -828,7 +828,7 @@ def visualize_prompt_overlay(
ax.text( ax.text(
x_img, x_img,
y_img - 5, y_img - 5,
f"B{i+1}", f"B{i + 1}",
color=color, color=color,
fontsize=10, fontsize=10,
weight="bold", weight="bold",

View File

@@ -11,7 +11,6 @@ from concurrent.futures import as_completed, ThreadPoolExecutor
from pathlib import Path from pathlib import Path
import yt_dlp import yt_dlp
from utils import ( from utils import (
annotation_files, annotation_files,
config, config,
@@ -244,9 +243,9 @@ def download_sav():
def main(): def main():
assert len(sys.argv) > 1, "You have to provide the name of the dataset" assert len(sys.argv) > 1, "You have to provide the name of the dataset"
dataset_name = sys.argv[1] dataset_name = sys.argv[1]
assert ( assert dataset_name in annotation_files, (
dataset_name in annotation_files f"The dataset can be one of {list(annotation_files.keys())}"
), f"The dataset can be one of {list(annotation_files.keys())}" )
if dataset_name == "yt1b": if dataset_name == "yt1b":
download_youtube() download_youtube()

View File

@@ -68,9 +68,9 @@ def process_image(args):
def main(): def main():
assert len(sys.argv) > 1, "You have to provide the name of the dataset" assert len(sys.argv) > 1, "You have to provide the name of the dataset"
dataset_name = sys.argv[1] dataset_name = sys.argv[1]
assert ( assert dataset_name in annotation_files, (
dataset_name in annotation_files f"The dataset can be one of {list(annotation_files.keys())}"
), f"The dataset can be one of {list(annotation_files.keys())}" )
all_outputs = [] all_outputs = []
for file in annotation_files[dataset_name]: for file in annotation_files[dataset_name]:
with open(os.path.join(config["path_annotations"], file), "r") as f: with open(os.path.join(config["path_annotations"], file), "r") as f:

View File

@@ -51,7 +51,7 @@ def main(args, n_workers=20):
paths = [ paths = [
( (
raw_folder_food_images raw_folder_food_images
/ f'{Path(each).stem.split("_")[-1]}{Path(each).suffix}', / f"{Path(each).stem.split('_')[-1]}{Path(each).suffix}",
processed_folder / each, processed_folder / each,
) )
for each in img_filenames for each in img_filenames

View File

@@ -3,7 +3,6 @@
# pyre-unsafe # pyre-unsafe
import argparse import argparse
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import os import os
from functools import partial from functools import partial

View File

@@ -58,9 +58,9 @@ class YtVideoPrep:
df = self.yt1b_start_end_time_df[ df = self.yt1b_start_end_time_df[
self.yt1b_start_end_time_df.saco_yt1b_id == self.saco_yt1b_id self.yt1b_start_end_time_df.saco_yt1b_id == self.saco_yt1b_id
] ]
assert ( assert len(df) == 1, (
len(df) == 1 f"Expected exactly 1 row for saco_yt1b_id: {self.saco_yt1b_id}, found {len(df)}"
), f"Expected exactly 1 row for saco_yt1b_id: {self.saco_yt1b_id}, found {len(df)}" )
id_and_frame_map_row = df.iloc[0] id_and_frame_map_row = df.iloc[0]
yt_video_id = ( yt_video_id = (
@@ -82,9 +82,9 @@ class YtVideoPrep:
def download_youtube_video(self): def download_youtube_video(self):
video_url = f"https://youtube.com/watch?v={self.yt_video_id}" video_url = f"https://youtube.com/watch?v={self.yt_video_id}"
assert os.path.exists( assert os.path.exists(self.cookies_file), (
self.cookies_file f"Cookies file '{self.cookies_file}' not found. Must have it to download videos."
), f"Cookies file '{self.cookies_file}' not found. Must have it to download videos." )
outtmpl = self.raw_video_path outtmpl = self.raw_video_path