apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)

Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: itamaro

Differential Revision: D90476315

fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
This commit is contained in:
Bowie Chen
2026-01-11 23:16:49 -08:00
committed by meta-codesync[bot]
parent 7b89b8fc3f
commit 11dec2936d
69 changed files with 445 additions and 522 deletions

View File

@@ -7,7 +7,6 @@ Misc functions, including distributed helpers.
import collections
import re
from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union
@@ -29,9 +28,9 @@ def interpolate(
input, size, scale_factor, mode, align_corners
)
assert (
input.shape[0] != 0 or input.shape[1] != 0
), "At least one of the two first dimensions must be non zero"
assert input.shape[0] != 0 or input.shape[1] != 0, (
"At least one of the two first dimensions must be non zero"
)
if input.shape[1] == 0:
# Pytorch doesn't support null dimension on the channel dimension, so we transpose to fake a null batch dim

View File

@@ -9,18 +9,13 @@ Inspired from Pytorch's version, adds the pre-norm variant
from typing import Any, Dict, List, Optional
import numpy as np
import torch
from sam3.sam.transformer import RoPEAttention
from torch import nn, Tensor
from torchvision.ops.roi_align import RoIAlign
from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .model_misc import (
gen_sineembed_for_position,
get_activation_fn,
@@ -444,9 +439,9 @@ class TransformerDecoder(nn.Module):
- valid_ratios/spatial_shapes: bs, nlevel, 2
"""
if memory_mask is not None:
assert (
self.boxRPB == "none"
), "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
assert self.boxRPB == "none", (
"inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
)
apply_dac = apply_dac if apply_dac is not None else self.dac
if apply_dac:
@@ -516,18 +511,18 @@ class TransformerDecoder(nn.Module):
query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model
if self.boxRPB != "none" and reference_boxes is not None:
assert (
spatial_shapes.shape[0] == 1
), "only single scale support implemented"
assert spatial_shapes.shape[0] == 1, (
"only single scale support implemented"
)
memory_mask = self._get_rpb_matrix(
reference_boxes,
(spatial_shapes[0, 0], spatial_shapes[0, 1]),
)
memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W)
if self.training:
assert (
self.use_act_checkpoint
), "Activation checkpointing not enabled in the decoder"
assert self.use_act_checkpoint, (
"Activation checkpointing not enabled in the decoder"
)
output, presence_out = activation_ckpt_wrapper(layer)(
tgt=output,
tgt_query_pos=query_pos,
@@ -676,9 +671,9 @@ class TransformerEncoderCrossAttention(nn.Module):
src_pos[0],
)
assert (
src.shape[1] == prompt.shape[1]
), "Batch size must be the same for src and prompt"
assert src.shape[1] == prompt.shape[1], (
"Batch size must be the same for src and prompt"
)
output = src

View File

@@ -322,9 +322,9 @@ class TransformerEncoder(nn.Module):
return reference_points
def _prepare_multilevel_features(self, srcs, masks, pos_embeds):
assert (
len(srcs) == self.num_feature_levels
), "mismatch between expected and received # of feature levels"
assert len(srcs) == self.num_feature_levels, (
"mismatch between expected and received # of feature levels"
)
src_flatten = []
mask_flatten = []
@@ -406,9 +406,9 @@ class TransformerEncoder(nn.Module):
- spatial_shapes: Spatial dimensions of each feature level
- valid_ratios: Valid ratios for each feature level
"""
assert (
len(src) == self.num_feature_levels
), "must be equal to num_feature_levels"
assert len(src) == self.num_feature_levels, (
"must be equal to num_feature_levels"
)
if src_key_padding_masks is not None:
assert len(src_key_padding_masks) == self.num_feature_levels
if pos is not None:
@@ -538,9 +538,9 @@ class TransformerEncoderFusion(TransformerEncoder):
else None
)
else:
assert all(
x.dim == 4 for x in src
), "expected list of (bs, c, h, w) tensors"
assert all(x.dim == 4 for x in src), (
"expected list of (bs, c, h, w) tensors"
)
if self.add_pooled_text_to_img_feat:
# Fusion: Add mean pooled text to image features

View File

@@ -11,7 +11,6 @@ from typing_extensions import override
from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .model_misc import get_clones
@@ -148,54 +147,42 @@ class Prompt:
)
# Dimension checks
assert (
box_embeddings is not None
and list(box_embeddings.shape[:2])
== [
box_seq_len,
bs,
]
), f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
assert (
box_mask is not None
and list(box_mask.shape)
== [
bs,
box_seq_len,
]
), f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
assert (
point_embeddings is not None
and list(point_embeddings.shape[:2])
== [
point_seq_len,
bs,
]
), f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
assert (
point_mask is not None
and list(point_mask.shape)
== [
bs,
point_seq_len,
]
), f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
assert (
box_labels is not None
and list(box_labels.shape)
== [
box_seq_len,
bs,
]
), f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
assert (
point_labels is not None
and list(point_labels.shape)
== [
point_seq_len,
bs,
]
), f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
assert box_embeddings is not None and list(box_embeddings.shape[:2]) == [
box_seq_len,
bs,
], (
f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
)
assert box_mask is not None and list(box_mask.shape) == [
bs,
box_seq_len,
], (
f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
)
assert point_embeddings is not None and list(point_embeddings.shape[:2]) == [
point_seq_len,
bs,
], (
f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
)
assert point_mask is not None and list(point_mask.shape) == [
bs,
point_seq_len,
], (
f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
)
assert box_labels is not None and list(box_labels.shape) == [
box_seq_len,
bs,
], (
f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
)
assert point_labels is not None and list(point_labels.shape) == [
point_seq_len,
bs,
], (
f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
)
assert (
# Allowed to be None, we leave it to the encoder to check for validity before encoding.
mask_embeddings is None
@@ -204,41 +191,41 @@ class Prompt:
mask_seq_len,
bs,
]
), f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
assert (
mask_mask is None
or list(mask_mask.shape)
== [
bs,
mask_seq_len,
]
), f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
), (
f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
)
assert mask_mask is None or list(mask_mask.shape) == [
bs,
mask_seq_len,
], (
f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
)
# Device checks
assert (
box_embeddings is not None and box_embeddings.device == device
), f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
assert (
box_mask is not None and box_mask.device == device
), f"Expected box mask to be on device {device}, got {box_mask.device}"
assert (
box_labels is not None and box_labels.device == device
), f"Expected box labels to be on device {device}, got {box_labels.device}"
assert (
point_embeddings is not None and point_embeddings.device == device
), f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
assert (
point_mask is not None and point_mask.device == device
), f"Expected point mask to be on device {device}, got {point_mask.device}"
assert (
point_labels is not None and point_labels.device == device
), f"Expected point labels to be on device {device}, got {point_labels.device}"
assert (
mask_embeddings is None or mask_embeddings.device == device
), f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
assert (
mask_mask is None or mask_mask.device == device
), f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
assert box_embeddings is not None and box_embeddings.device == device, (
f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
)
assert box_mask is not None and box_mask.device == device, (
f"Expected box mask to be on device {device}, got {box_mask.device}"
)
assert box_labels is not None and box_labels.device == device, (
f"Expected box labels to be on device {device}, got {box_labels.device}"
)
assert point_embeddings is not None and point_embeddings.device == device, (
f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
)
assert point_mask is not None and point_mask.device == device, (
f"Expected point mask to be on device {device}, got {point_mask.device}"
)
assert point_labels is not None and point_labels.device == device, (
f"Expected point labels to be on device {device}, got {point_labels.device}"
)
assert mask_embeddings is None or mask_embeddings.device == device, (
f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
)
assert mask_mask is None or mask_mask.device == device, (
f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
)
self.box_embeddings = box_embeddings
self.point_embeddings = point_embeddings
@@ -264,30 +251,30 @@ class Prompt:
if point_embeddings is not None:
point_seq_len = point_embeddings.shape[0]
if bs is not None:
assert (
bs == point_embeddings.shape[1]
), f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
assert bs == point_embeddings.shape[1], (
f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
)
else:
bs = point_embeddings.shape[1]
if device is not None:
assert (
device == point_embeddings.device
), "Device mismatch between box and point embeddings"
assert device == point_embeddings.device, (
"Device mismatch between box and point embeddings"
)
else:
device = point_embeddings.device
if mask_embeddings is not None:
mask_seq_len = mask_embeddings.shape[0]
if bs is not None:
assert (
bs == mask_embeddings.shape[1]
), f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
assert bs == mask_embeddings.shape[1], (
f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
)
else:
bs = mask_embeddings.shape[1]
if device is not None:
assert (
device == mask_embeddings.device
), "Device mismatch between box/point and mask embeddings."
assert device == mask_embeddings.device, (
"Device mismatch between box/point and mask embeddings."
)
else:
device = mask_embeddings.device
@@ -539,9 +526,9 @@ class SequenceGeometryEncoder(nn.Module):
if add_cls:
self.cls_embed = torch.nn.Embedding(1, self.d_model)
assert (
points_direct_project or points_pos_enc or points_pool
), "Error: need at least one way to encode points"
assert points_direct_project or points_pos_enc or points_pool, (
"Error: need at least one way to encode points"
)
assert (
encode_boxes_as_points
or boxes_direct_project
@@ -583,16 +570,16 @@ class SequenceGeometryEncoder(nn.Module):
self.encode = None
if num_layers > 0:
assert (
add_cls
), "It's currently highly recommended to add a CLS when using a transformer"
assert add_cls, (
"It's currently highly recommended to add a CLS when using a transformer"
)
self.encode = get_clones(layer, num_layers)
self.encode_norm = nn.LayerNorm(self.d_model)
if mask_encoder is not None:
assert isinstance(
mask_encoder, MaskEncoder
), f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
assert isinstance(mask_encoder, MaskEncoder), (
f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
)
if add_mask_label:
self.mask_label_embed = torch.nn.Embedding(2, self.d_model)
self.add_mask_label = add_mask_label
@@ -701,16 +688,15 @@ class SequenceGeometryEncoder(nn.Module):
img_feats: torch.Tensor = None,
):
n_masks, bs = masks.shape[:2]
assert (
n_masks == 1
), "We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
assert (
list(attn_mask.shape)
== [
bs,
n_masks,
]
), f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
assert n_masks == 1, (
"We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
)
assert list(attn_mask.shape) == [
bs,
n_masks,
], (
f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
)
masks, pos = self.mask_encoder(
masks=masks.flatten(0, 1).float(),
pix_feat=img_feats,

View File

@@ -13,9 +13,7 @@ import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from sam3.logger import get_logger
from tqdm import tqdm

View File

@@ -248,7 +248,9 @@ class UniversalSegmentationHead(SegmentationHead):
self.d_model = hidden_dim
if dot_product_scorer is not None:
assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake"
assert presence_head, (
"Specifying a dot product scorer without a presence head is likely a mistake"
)
self.presence_head = None
if presence_head:

View File

@@ -62,9 +62,9 @@ class SimpleMaskDownSampler(nn.Module):
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
self.interpol_size = interpol_size
if self.interpol_size is not None:
assert isinstance(
self.interpol_size, (list, tuple)
), f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
assert isinstance(self.interpol_size, (list, tuple)), (
f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
)
self.interpol_size = list(interpol_size)
assert len(self.interpol_size) == 2

View File

@@ -330,9 +330,9 @@ class SAM3Output(list):
self.output = output
else:
self.output = []
assert isinstance(
iter_mode, SAM3Output.IterMode
), f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"
assert isinstance(iter_mode, SAM3Output.IterMode), (
f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"
)
self.iter_mode = iter_mode
# We create a weak reference to self to be used in the lambda functions.
@@ -411,9 +411,9 @@ class SAM3Output(list):
return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode)
def append(self, item: list):
assert isinstance(
item, list
), f"Only list items are supported. Got {type(item)}"
assert isinstance(item, list), (
f"Only list items are supported. Got {type(item)}"
)
self.output.append(item)
def __repr__(self):

View File

@@ -8,7 +8,6 @@ from copy import deepcopy
from typing import List, Optional, Tuple
import torch
import torch.nn as nn

View File

@@ -7,15 +7,12 @@
# LICENSE file in the root directory of this source tree.
import logging
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from PIL.Image import Image
from sam3.model.sam3_tracker_base import Sam3TrackerBase
from sam3.model.utils.sam1_utils import SAM2Transforms
@@ -97,9 +94,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
input_image = self._transforms(image)
input_image = input_image[None, ...].to(self.device)
assert (
len(input_image.shape) == 4 and input_image.shape[1] == 3
), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
assert len(input_image.shape) == 4 and input_image.shape[1] == 3, (
f"input_image must be of size 1x3xHxW, got {input_image.shape}"
)
logging.info("Computing image embeddings for the provided image...")
backbone_out = self.model.forward_image(input_image)
(
@@ -136,17 +133,17 @@ class SAM3InteractiveImagePredictor(nn.Module):
assert isinstance(image_list, list)
self._orig_hw = []
for image in image_list:
assert isinstance(
image, np.ndarray
), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
assert isinstance(image, np.ndarray), (
"Images are expected to be an np.ndarray in RGB format, and of shape HWC"
)
self._orig_hw.append(image.shape[:2])
# Transform the image to the form expected by the model
img_batch = self._transforms.forward_batch(image_list)
img_batch = img_batch.to(self.device)
batch_size = img_batch.shape[0]
assert (
len(img_batch.shape) == 4 and img_batch.shape[1] == 3
), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
assert len(img_batch.shape) == 4 and img_batch.shape[1] == 3, (
f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
)
logging.info("Computing image embeddings for the provided images...")
backbone_out = self.model.forward_image(img_batch)
(
@@ -302,9 +299,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
):
unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
if point_coords is not None:
assert (
point_labels is not None
), "point_labels must be supplied if point_coords is supplied."
assert point_labels is not None, (
"point_labels must be supplied if point_coords is supplied."
)
point_coords = torch.as_tensor(
point_coords, dtype=torch.float, device=self.device
)
@@ -441,9 +438,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
raise RuntimeError(
"An image must be set with .set_image(...) to generate an embedding."
)
assert (
self._features is not None
), "Features must exist if an image has been set."
assert self._features is not None, (
"Features must exist if an image has been set."
)
return self._features["image_embed"]
@property

View File

@@ -8,19 +8,14 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from sam3.model.model_misc import SAM3Output
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
from sam3.model.vl_combiner import SAM3VLBackbone
from sam3.perflib.nms import nms_masks
from sam3.train.data.collator import BatchedDatapoint
from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .geometry_encoders import Prompt
from .model_misc import inverse_sigmoid
@@ -661,9 +656,9 @@ class Sam3Image(torch.nn.Module):
inference_state["original_heights"],
inference_state["original_widths"],
)
assert (
batch_size == len(orig_heights) == len(orig_widths)
), f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
assert batch_size == len(orig_heights) == len(orig_widths), (
f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
)
feats = [
feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
for feat, feat_size in zip(

View File

@@ -6,9 +6,7 @@ from typing import Dict, List
import numpy as np
import PIL
import torch
from sam3.model import box_ops
from sam3.model.data_misc import FindStage, interpolate
from torchvision.transforms import v2
@@ -83,9 +81,9 @@ class Sam3Processor:
if not isinstance(images, list):
raise ValueError("Images must be a list of PIL images or tensors")
assert len(images) > 0, "Images list must not be empty"
assert isinstance(
images[0], PIL.Image.Image
), "Images must be a list of PIL images"
assert isinstance(images[0], PIL.Image.Image), (
"Images must be a list of PIL images"
)
state["original_heights"] = [image.height for image in images]
state["original_widths"] = [image.width for image in images]

View File

@@ -6,11 +6,8 @@ import logging
import torch
import torch.nn.functional as F
from sam3.model.memory import SimpleMaskEncoder
from sam3.model.sam3_tracker_utils import get_1d_sine_pe, select_closest_cond_frames
from sam3.sam.mask_decoder import MaskDecoder, MLP
from sam3.sam.prompt_encoder import PromptEncoder
from sam3.sam.transformer import TwoWayTransformer

View File

@@ -6,7 +6,6 @@ import numpy as np
import torch
import torch.nn.functional as F
from numpy.typing import NDArray
from sam3.model.edt import edt_triton

View File

@@ -6,7 +6,6 @@ import logging
from collections import OrderedDict
import torch
from sam3.model.sam3_tracker_base import concat_points, NO_OBJ_SCORE, Sam3TrackerBase
from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores
from sam3.model.utils.sam2_utils import load_video_frames

View File

@@ -16,7 +16,6 @@ import numpy.typing as npt
import torch
import torch.distributed as dist
import torch.nn.functional as F
from sam3 import perflib
from sam3.logger import get_logger
from sam3.model.box_ops import fast_diag_box_iou
@@ -620,9 +619,9 @@ class Sam3VideoBase(nn.Module):
num_obj_dropped_due_to_limit,
trk_id_to_max_iou_high_conf_det,
]
assert (
len(update_plan) == NUM_BROADCAST_ITEMS
), f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
assert len(update_plan) == NUM_BROADCAST_ITEMS, (
f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
)
self.broadcast_python_obj_cpu(update_plan, src=0)
elif self.rank > 0 and self.world_size > 1:
update_plan = [
@@ -842,9 +841,9 @@ class Sam3VideoBase(nn.Module):
binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0
batch_size = tracker_low_res_masks_global.size(0)
if batch_size > 0:
assert (
len(obj_ids_global) == batch_size
), f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
assert len(obj_ids_global) == batch_size, (
f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
)
NEVER_OCCLUDED = -1
ALWAYS_OCCLUDED = 100000 # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic
last_occluded_prev = torch.cat(
@@ -1023,9 +1022,9 @@ class Sam3VideoBase(nn.Module):
reverse: bool = False,
):
# Suppress overlapping masks for objects that were most recently occluded
assert (
binary_low_res_masks.dtype == torch.bool
), f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
assert binary_low_res_masks.dtype == torch.bool, (
f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
)
to_suppress = torch.zeros(
binary_low_res_masks.size(0),
device=binary_low_res_masks.device,
@@ -1130,9 +1129,9 @@ class Sam3VideoBase(nn.Module):
num_frames_propagated += 1
# only 1 frames should be propagated
assert (
num_frames_propagated == 1 and out_frame_idx == frame_idx
), f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}"
assert num_frames_propagated == 1 and out_frame_idx == frame_idx, (
f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}"
)
assert isinstance(out_obj_ids, list)
obj_ids_local.extend(out_obj_ids)
low_res_masks_list.append(out_low_res_masks.squeeze(1))
@@ -1189,9 +1188,9 @@ class Sam3VideoBase(nn.Module):
assert det_masks.is_floating_point(), "float tensor expected (do not binarize)"
assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)"
assert (
trk_masks.size(0) == len(trk_obj_ids)
), f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
assert trk_masks.size(0) == len(trk_obj_ids), (
f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
)
if trk_masks.size(0) == 0:
# all detections are new
new_det_fa_inds = np.arange(det_masks.size(0))
@@ -1655,9 +1654,9 @@ class Sam3VideoBase(nn.Module):
# a) first, expand "confirmation_data" to include new masklets added in this frame
status_prev = confirmation_data["status"]
consecutive_det_num_prev = confirmation_data["consecutive_det_num"]
assert (
status_prev.shape == obj_ids_all_gpu_prev.shape
), f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
assert status_prev.shape == obj_ids_all_gpu_prev.shape, (
f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
)
obj_id_to_updated_idx = {
obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated)

View File

@@ -9,7 +9,6 @@ import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from sam3 import perflib
from sam3.logger import get_logger
from sam3.model.act_ckpt_utils import clone_output_wrapper
@@ -555,7 +554,9 @@ class Sam3VideoInference(Sam3VideoBase):
assert (
"cached_frame_outputs" in inference_state
and frame_idx in inference_state["cached_frame_outputs"]
), "No cached outputs found. Ensure normal propagation has run first to populate the cache."
), (
"No cached outputs found. Ensure normal propagation has run first to populate the cache."
)
cached_outputs = inference_state["cached_frame_outputs"][frame_idx]
obj_id_to_mask = cached_outputs.copy()
@@ -563,9 +564,9 @@ class Sam3VideoInference(Sam3VideoBase):
# Update with refined masks if provided
if refined_obj_id_to_mask is not None:
for obj_id, refined_mask in refined_obj_id_to_mask.items():
assert (
refined_mask is not None
), f"Refined mask data must be provided for obj_id {obj_id}"
assert refined_mask is not None, (
f"Refined mask data must be provided for obj_id {obj_id}"
)
obj_id_to_mask[obj_id] = refined_mask
return obj_id_to_mask
@@ -660,12 +661,12 @@ class Sam3VideoInference(Sam3VideoBase):
for i, thresh in enumerate(new_det_score_thresh_list):
self.new_det_thresh = thresh
for num_objects in num_objects_list:
logger.info(f"{i+1}/{num_rounds} warming up model compilation")
logger.info(f"{i + 1}/{num_rounds} warming up model compilation")
self.add_prompt(
inference_state, frame_idx=start_frame_idx, text_str="cat"
)
logger.info(
f"{i+1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects"
f"{i + 1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects"
)
inference_state = self.add_fake_objects_to_inference_state(
inference_state, num_objects, frame_idx=start_frame_idx
@@ -690,7 +691,7 @@ class Sam3VideoInference(Sam3VideoBase):
pass
self.reset_state(inference_state)
logger.info(
f"{i+1}/{num_rounds} warming up model compilation -- completed round {i+1} out of {num_rounds}"
f"{i + 1}/{num_rounds} warming up model compilation -- completed round {i + 1} out of {num_rounds}"
)
# Warm up Tracker memory encoder with varying input shapes
@@ -854,12 +855,12 @@ class Sam3VideoInference(Sam3VideoBase):
logger.debug("Running add_prompt on frame %d", frame_idx)
num_frames = inference_state["num_frames"]
assert (
text_str is not None or boxes_xywh is not None
), "at least one type of prompt (text, boxes) must be provided"
assert (
0 <= frame_idx < num_frames
), f"{frame_idx=} is out of range for a total of {num_frames} frames"
assert text_str is not None or boxes_xywh is not None, (
"at least one type of prompt (text, boxes) must be provided"
)
assert 0 <= frame_idx < num_frames, (
f"{frame_idx=} is out of range for a total of {num_frames} frames"
)
# since it's a semantic prompt, we start over
self.reset_state(inference_state)
@@ -1200,9 +1201,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
"propagation_partial",
"propagation_fetch",
]
assert (
action_type in instance_actions + propagation_actions
), f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}"
assert action_type in instance_actions + propagation_actions, (
f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}"
)
action = {
"type": action_type,
"frame_idx": frame_idx,
@@ -1370,12 +1371,12 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
):
if points is not None:
# Tracker instance prompts
assert (
text_str is None and boxes_xywh is None
), "When points are provided, text_str and boxes_xywh must be None."
assert (
obj_id is not None
), "When points are provided, obj_id must be provided."
assert text_str is None and boxes_xywh is None, (
"When points are provided, text_str and boxes_xywh must be None."
)
assert obj_id is not None, (
"When points are provided, obj_id must be provided."
)
return self.add_tracker_new_points(
inference_state,
frame_idx,
@@ -1491,9 +1492,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
tracker_states = self._get_tracker_inference_states_by_obj_ids(
inference_state, [obj_id]
)
assert (
len(tracker_states) == 1
), f"[rank={self.rank}] Multiple Tracker inference states found for the same object id."
assert len(tracker_states) == 1, (
f"[rank={self.rank}] Multiple Tracker inference states found for the same object id."
)
tracker_state = tracker_states[0]
# log

View File

@@ -16,7 +16,6 @@ from typing import List, Optional
import psutil
import torch
from sam3.logger import get_logger
logger = get_logger(__name__)
@@ -170,7 +169,7 @@ class Sam3VideoPredictor:
):
"""Remove an object from tracking."""
logger.debug(
f"remove object {obj_id} in session {session_id}: " f"{is_user_action=}"
f"remove object {obj_id} in session {session_id}: {is_user_action=}"
)
session = self._get_session(session_id)
inference_state = session["state"]

View File

@@ -318,9 +318,9 @@ class VETextEncoder(nn.Module):
# The text is already encoded, use as is.
text_attention_mask, text_memory_resized, tokenized = text
inputs_embeds = tokenized["inputs_embeds"]
assert (
input_boxes is None or len(input_boxes) == 0
), "Can't replace boxes in text if it's already encoded"
assert input_boxes is None or len(input_boxes) == 0, (
"Can't replace boxes in text if it's already encoded"
)
# Note that the input_embeds are returned in pytorch's convention (sequence first)
return (

View File

@@ -708,9 +708,9 @@ class ViT(nn.Module):
self.retain_cls_token = retain_cls_token
if self.retain_cls_token:
assert pretrain_use_cls_token
assert (
len(window_block_indexes) == 0
), "windowing not supported with cls token"
assert len(window_block_indexes) == 0, (
"windowing not supported with cls token"
)
assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"

View File

@@ -9,7 +9,6 @@ from typing import List, Optional
import torch
import torch.nn as nn
from torch.nn.attention import sdpa_kernel, SDPBackend
from .act_ckpt_utils import activation_ckpt_wrapper