apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)
Summary: Formats the covered files with pyfmt. paintitblack Reviewed By: itamaro Differential Revision: D90476315 fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
This commit is contained in:
committed by
meta-codesync[bot]
parent
7b89b8fc3f
commit
11dec2936d
@@ -7,7 +7,6 @@ Misc functions, including distributed helpers.
|
||||
|
||||
import collections
|
||||
import re
|
||||
|
||||
from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
|
||||
from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union
|
||||
|
||||
@@ -29,9 +28,9 @@ def interpolate(
|
||||
input, size, scale_factor, mode, align_corners
|
||||
)
|
||||
|
||||
assert (
|
||||
input.shape[0] != 0 or input.shape[1] != 0
|
||||
), "At least one of the two first dimensions must be non zero"
|
||||
assert input.shape[0] != 0 or input.shape[1] != 0, (
|
||||
"At least one of the two first dimensions must be non zero"
|
||||
)
|
||||
|
||||
if input.shape[1] == 0:
|
||||
# Pytorch doesn't support null dimension on the channel dimension, so we transpose to fake a null batch dim
|
||||
|
||||
@@ -9,18 +9,13 @@ Inspired from Pytorch's version, adds the pre-norm variant
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.sam.transformer import RoPEAttention
|
||||
|
||||
from torch import nn, Tensor
|
||||
from torchvision.ops.roi_align import RoIAlign
|
||||
|
||||
from .act_ckpt_utils import activation_ckpt_wrapper
|
||||
|
||||
from .box_ops import box_cxcywh_to_xyxy
|
||||
|
||||
from .model_misc import (
|
||||
gen_sineembed_for_position,
|
||||
get_activation_fn,
|
||||
@@ -444,9 +439,9 @@ class TransformerDecoder(nn.Module):
|
||||
- valid_ratios/spatial_shapes: bs, nlevel, 2
|
||||
"""
|
||||
if memory_mask is not None:
|
||||
assert (
|
||||
self.boxRPB == "none"
|
||||
), "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
|
||||
assert self.boxRPB == "none", (
|
||||
"inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
|
||||
)
|
||||
|
||||
apply_dac = apply_dac if apply_dac is not None else self.dac
|
||||
if apply_dac:
|
||||
@@ -516,18 +511,18 @@ class TransformerDecoder(nn.Module):
|
||||
query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model
|
||||
|
||||
if self.boxRPB != "none" and reference_boxes is not None:
|
||||
assert (
|
||||
spatial_shapes.shape[0] == 1
|
||||
), "only single scale support implemented"
|
||||
assert spatial_shapes.shape[0] == 1, (
|
||||
"only single scale support implemented"
|
||||
)
|
||||
memory_mask = self._get_rpb_matrix(
|
||||
reference_boxes,
|
||||
(spatial_shapes[0, 0], spatial_shapes[0, 1]),
|
||||
)
|
||||
memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W)
|
||||
if self.training:
|
||||
assert (
|
||||
self.use_act_checkpoint
|
||||
), "Activation checkpointing not enabled in the decoder"
|
||||
assert self.use_act_checkpoint, (
|
||||
"Activation checkpointing not enabled in the decoder"
|
||||
)
|
||||
output, presence_out = activation_ckpt_wrapper(layer)(
|
||||
tgt=output,
|
||||
tgt_query_pos=query_pos,
|
||||
@@ -676,9 +671,9 @@ class TransformerEncoderCrossAttention(nn.Module):
|
||||
src_pos[0],
|
||||
)
|
||||
|
||||
assert (
|
||||
src.shape[1] == prompt.shape[1]
|
||||
), "Batch size must be the same for src and prompt"
|
||||
assert src.shape[1] == prompt.shape[1], (
|
||||
"Batch size must be the same for src and prompt"
|
||||
)
|
||||
|
||||
output = src
|
||||
|
||||
|
||||
@@ -322,9 +322,9 @@ class TransformerEncoder(nn.Module):
|
||||
return reference_points
|
||||
|
||||
def _prepare_multilevel_features(self, srcs, masks, pos_embeds):
|
||||
assert (
|
||||
len(srcs) == self.num_feature_levels
|
||||
), "mismatch between expected and received # of feature levels"
|
||||
assert len(srcs) == self.num_feature_levels, (
|
||||
"mismatch between expected and received # of feature levels"
|
||||
)
|
||||
|
||||
src_flatten = []
|
||||
mask_flatten = []
|
||||
@@ -406,9 +406,9 @@ class TransformerEncoder(nn.Module):
|
||||
- spatial_shapes: Spatial dimensions of each feature level
|
||||
- valid_ratios: Valid ratios for each feature level
|
||||
"""
|
||||
assert (
|
||||
len(src) == self.num_feature_levels
|
||||
), "must be equal to num_feature_levels"
|
||||
assert len(src) == self.num_feature_levels, (
|
||||
"must be equal to num_feature_levels"
|
||||
)
|
||||
if src_key_padding_masks is not None:
|
||||
assert len(src_key_padding_masks) == self.num_feature_levels
|
||||
if pos is not None:
|
||||
@@ -538,9 +538,9 @@ class TransformerEncoderFusion(TransformerEncoder):
|
||||
else None
|
||||
)
|
||||
else:
|
||||
assert all(
|
||||
x.dim == 4 for x in src
|
||||
), "expected list of (bs, c, h, w) tensors"
|
||||
assert all(x.dim == 4 for x in src), (
|
||||
"expected list of (bs, c, h, w) tensors"
|
||||
)
|
||||
|
||||
if self.add_pooled_text_to_img_feat:
|
||||
# Fusion: Add mean pooled text to image features
|
||||
|
||||
@@ -11,7 +11,6 @@ from typing_extensions import override
|
||||
|
||||
from .act_ckpt_utils import activation_ckpt_wrapper
|
||||
from .box_ops import box_cxcywh_to_xyxy
|
||||
|
||||
from .model_misc import get_clones
|
||||
|
||||
|
||||
@@ -148,54 +147,42 @@ class Prompt:
|
||||
)
|
||||
|
||||
# Dimension checks
|
||||
assert (
|
||||
box_embeddings is not None
|
||||
and list(box_embeddings.shape[:2])
|
||||
== [
|
||||
box_seq_len,
|
||||
bs,
|
||||
]
|
||||
), f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
|
||||
assert (
|
||||
box_mask is not None
|
||||
and list(box_mask.shape)
|
||||
== [
|
||||
bs,
|
||||
box_seq_len,
|
||||
]
|
||||
), f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
|
||||
assert (
|
||||
point_embeddings is not None
|
||||
and list(point_embeddings.shape[:2])
|
||||
== [
|
||||
point_seq_len,
|
||||
bs,
|
||||
]
|
||||
), f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
|
||||
assert (
|
||||
point_mask is not None
|
||||
and list(point_mask.shape)
|
||||
== [
|
||||
bs,
|
||||
point_seq_len,
|
||||
]
|
||||
), f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
|
||||
assert (
|
||||
box_labels is not None
|
||||
and list(box_labels.shape)
|
||||
== [
|
||||
box_seq_len,
|
||||
bs,
|
||||
]
|
||||
), f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
|
||||
assert (
|
||||
point_labels is not None
|
||||
and list(point_labels.shape)
|
||||
== [
|
||||
point_seq_len,
|
||||
bs,
|
||||
]
|
||||
), f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
|
||||
assert box_embeddings is not None and list(box_embeddings.shape[:2]) == [
|
||||
box_seq_len,
|
||||
bs,
|
||||
], (
|
||||
f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
|
||||
)
|
||||
assert box_mask is not None and list(box_mask.shape) == [
|
||||
bs,
|
||||
box_seq_len,
|
||||
], (
|
||||
f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
|
||||
)
|
||||
assert point_embeddings is not None and list(point_embeddings.shape[:2]) == [
|
||||
point_seq_len,
|
||||
bs,
|
||||
], (
|
||||
f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
|
||||
)
|
||||
assert point_mask is not None and list(point_mask.shape) == [
|
||||
bs,
|
||||
point_seq_len,
|
||||
], (
|
||||
f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
|
||||
)
|
||||
assert box_labels is not None and list(box_labels.shape) == [
|
||||
box_seq_len,
|
||||
bs,
|
||||
], (
|
||||
f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
|
||||
)
|
||||
assert point_labels is not None and list(point_labels.shape) == [
|
||||
point_seq_len,
|
||||
bs,
|
||||
], (
|
||||
f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
|
||||
)
|
||||
assert (
|
||||
# Allowed to be None, we leave it to the encoder to check for validity before encoding.
|
||||
mask_embeddings is None
|
||||
@@ -204,41 +191,41 @@ class Prompt:
|
||||
mask_seq_len,
|
||||
bs,
|
||||
]
|
||||
), f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
|
||||
assert (
|
||||
mask_mask is None
|
||||
or list(mask_mask.shape)
|
||||
== [
|
||||
bs,
|
||||
mask_seq_len,
|
||||
]
|
||||
), f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
|
||||
), (
|
||||
f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
|
||||
)
|
||||
assert mask_mask is None or list(mask_mask.shape) == [
|
||||
bs,
|
||||
mask_seq_len,
|
||||
], (
|
||||
f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
|
||||
)
|
||||
|
||||
# Device checks
|
||||
assert (
|
||||
box_embeddings is not None and box_embeddings.device == device
|
||||
), f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
|
||||
assert (
|
||||
box_mask is not None and box_mask.device == device
|
||||
), f"Expected box mask to be on device {device}, got {box_mask.device}"
|
||||
assert (
|
||||
box_labels is not None and box_labels.device == device
|
||||
), f"Expected box labels to be on device {device}, got {box_labels.device}"
|
||||
assert (
|
||||
point_embeddings is not None and point_embeddings.device == device
|
||||
), f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
|
||||
assert (
|
||||
point_mask is not None and point_mask.device == device
|
||||
), f"Expected point mask to be on device {device}, got {point_mask.device}"
|
||||
assert (
|
||||
point_labels is not None and point_labels.device == device
|
||||
), f"Expected point labels to be on device {device}, got {point_labels.device}"
|
||||
assert (
|
||||
mask_embeddings is None or mask_embeddings.device == device
|
||||
), f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
|
||||
assert (
|
||||
mask_mask is None or mask_mask.device == device
|
||||
), f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
|
||||
assert box_embeddings is not None and box_embeddings.device == device, (
|
||||
f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
|
||||
)
|
||||
assert box_mask is not None and box_mask.device == device, (
|
||||
f"Expected box mask to be on device {device}, got {box_mask.device}"
|
||||
)
|
||||
assert box_labels is not None and box_labels.device == device, (
|
||||
f"Expected box labels to be on device {device}, got {box_labels.device}"
|
||||
)
|
||||
assert point_embeddings is not None and point_embeddings.device == device, (
|
||||
f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
|
||||
)
|
||||
assert point_mask is not None and point_mask.device == device, (
|
||||
f"Expected point mask to be on device {device}, got {point_mask.device}"
|
||||
)
|
||||
assert point_labels is not None and point_labels.device == device, (
|
||||
f"Expected point labels to be on device {device}, got {point_labels.device}"
|
||||
)
|
||||
assert mask_embeddings is None or mask_embeddings.device == device, (
|
||||
f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
|
||||
)
|
||||
assert mask_mask is None or mask_mask.device == device, (
|
||||
f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
|
||||
)
|
||||
|
||||
self.box_embeddings = box_embeddings
|
||||
self.point_embeddings = point_embeddings
|
||||
@@ -264,30 +251,30 @@ class Prompt:
|
||||
if point_embeddings is not None:
|
||||
point_seq_len = point_embeddings.shape[0]
|
||||
if bs is not None:
|
||||
assert (
|
||||
bs == point_embeddings.shape[1]
|
||||
), f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
|
||||
assert bs == point_embeddings.shape[1], (
|
||||
f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
|
||||
)
|
||||
else:
|
||||
bs = point_embeddings.shape[1]
|
||||
if device is not None:
|
||||
assert (
|
||||
device == point_embeddings.device
|
||||
), "Device mismatch between box and point embeddings"
|
||||
assert device == point_embeddings.device, (
|
||||
"Device mismatch between box and point embeddings"
|
||||
)
|
||||
else:
|
||||
device = point_embeddings.device
|
||||
|
||||
if mask_embeddings is not None:
|
||||
mask_seq_len = mask_embeddings.shape[0]
|
||||
if bs is not None:
|
||||
assert (
|
||||
bs == mask_embeddings.shape[1]
|
||||
), f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
|
||||
assert bs == mask_embeddings.shape[1], (
|
||||
f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
|
||||
)
|
||||
else:
|
||||
bs = mask_embeddings.shape[1]
|
||||
if device is not None:
|
||||
assert (
|
||||
device == mask_embeddings.device
|
||||
), "Device mismatch between box/point and mask embeddings."
|
||||
assert device == mask_embeddings.device, (
|
||||
"Device mismatch between box/point and mask embeddings."
|
||||
)
|
||||
else:
|
||||
device = mask_embeddings.device
|
||||
|
||||
@@ -539,9 +526,9 @@ class SequenceGeometryEncoder(nn.Module):
|
||||
if add_cls:
|
||||
self.cls_embed = torch.nn.Embedding(1, self.d_model)
|
||||
|
||||
assert (
|
||||
points_direct_project or points_pos_enc or points_pool
|
||||
), "Error: need at least one way to encode points"
|
||||
assert points_direct_project or points_pos_enc or points_pool, (
|
||||
"Error: need at least one way to encode points"
|
||||
)
|
||||
assert (
|
||||
encode_boxes_as_points
|
||||
or boxes_direct_project
|
||||
@@ -583,16 +570,16 @@ class SequenceGeometryEncoder(nn.Module):
|
||||
|
||||
self.encode = None
|
||||
if num_layers > 0:
|
||||
assert (
|
||||
add_cls
|
||||
), "It's currently highly recommended to add a CLS when using a transformer"
|
||||
assert add_cls, (
|
||||
"It's currently highly recommended to add a CLS when using a transformer"
|
||||
)
|
||||
self.encode = get_clones(layer, num_layers)
|
||||
self.encode_norm = nn.LayerNorm(self.d_model)
|
||||
|
||||
if mask_encoder is not None:
|
||||
assert isinstance(
|
||||
mask_encoder, MaskEncoder
|
||||
), f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
|
||||
assert isinstance(mask_encoder, MaskEncoder), (
|
||||
f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
|
||||
)
|
||||
if add_mask_label:
|
||||
self.mask_label_embed = torch.nn.Embedding(2, self.d_model)
|
||||
self.add_mask_label = add_mask_label
|
||||
@@ -701,16 +688,15 @@ class SequenceGeometryEncoder(nn.Module):
|
||||
img_feats: torch.Tensor = None,
|
||||
):
|
||||
n_masks, bs = masks.shape[:2]
|
||||
assert (
|
||||
n_masks == 1
|
||||
), "We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
|
||||
assert (
|
||||
list(attn_mask.shape)
|
||||
== [
|
||||
bs,
|
||||
n_masks,
|
||||
]
|
||||
), f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
|
||||
assert n_masks == 1, (
|
||||
"We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
|
||||
)
|
||||
assert list(attn_mask.shape) == [
|
||||
bs,
|
||||
n_masks,
|
||||
], (
|
||||
f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
|
||||
)
|
||||
masks, pos = self.mask_encoder(
|
||||
masks=masks.flatten(0, 1).float(),
|
||||
pix_feat=img_feats,
|
||||
|
||||
@@ -13,9 +13,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from sam3.logger import get_logger
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
@@ -248,7 +248,9 @@ class UniversalSegmentationHead(SegmentationHead):
|
||||
self.d_model = hidden_dim
|
||||
|
||||
if dot_product_scorer is not None:
|
||||
assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake"
|
||||
assert presence_head, (
|
||||
"Specifying a dot product scorer without a presence head is likely a mistake"
|
||||
)
|
||||
|
||||
self.presence_head = None
|
||||
if presence_head:
|
||||
|
||||
@@ -62,9 +62,9 @@ class SimpleMaskDownSampler(nn.Module):
|
||||
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
|
||||
self.interpol_size = interpol_size
|
||||
if self.interpol_size is not None:
|
||||
assert isinstance(
|
||||
self.interpol_size, (list, tuple)
|
||||
), f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
|
||||
assert isinstance(self.interpol_size, (list, tuple)), (
|
||||
f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
|
||||
)
|
||||
self.interpol_size = list(interpol_size)
|
||||
assert len(self.interpol_size) == 2
|
||||
|
||||
|
||||
@@ -330,9 +330,9 @@ class SAM3Output(list):
|
||||
self.output = output
|
||||
else:
|
||||
self.output = []
|
||||
assert isinstance(
|
||||
iter_mode, SAM3Output.IterMode
|
||||
), f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"
|
||||
assert isinstance(iter_mode, SAM3Output.IterMode), (
|
||||
f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"
|
||||
)
|
||||
|
||||
self.iter_mode = iter_mode
|
||||
# We create a weak reference to self to be used in the lambda functions.
|
||||
@@ -411,9 +411,9 @@ class SAM3Output(list):
|
||||
return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode)
|
||||
|
||||
def append(self, item: list):
|
||||
assert isinstance(
|
||||
item, list
|
||||
), f"Only list items are supported. Got {type(item)}"
|
||||
assert isinstance(item, list), (
|
||||
f"Only list items are supported. Got {type(item)}"
|
||||
)
|
||||
self.output.append(item)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -8,7 +8,6 @@ from copy import deepcopy
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
|
||||
@@ -7,15 +7,12 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import torch.nn as nn
|
||||
from PIL.Image import Image
|
||||
|
||||
from sam3.model.sam3_tracker_base import Sam3TrackerBase
|
||||
from sam3.model.utils.sam1_utils import SAM2Transforms
|
||||
|
||||
@@ -97,9 +94,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
|
||||
input_image = self._transforms(image)
|
||||
input_image = input_image[None, ...].to(self.device)
|
||||
|
||||
assert (
|
||||
len(input_image.shape) == 4 and input_image.shape[1] == 3
|
||||
), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
|
||||
assert len(input_image.shape) == 4 and input_image.shape[1] == 3, (
|
||||
f"input_image must be of size 1x3xHxW, got {input_image.shape}"
|
||||
)
|
||||
logging.info("Computing image embeddings for the provided image...")
|
||||
backbone_out = self.model.forward_image(input_image)
|
||||
(
|
||||
@@ -136,17 +133,17 @@ class SAM3InteractiveImagePredictor(nn.Module):
|
||||
assert isinstance(image_list, list)
|
||||
self._orig_hw = []
|
||||
for image in image_list:
|
||||
assert isinstance(
|
||||
image, np.ndarray
|
||||
), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
|
||||
assert isinstance(image, np.ndarray), (
|
||||
"Images are expected to be an np.ndarray in RGB format, and of shape HWC"
|
||||
)
|
||||
self._orig_hw.append(image.shape[:2])
|
||||
# Transform the image to the form expected by the model
|
||||
img_batch = self._transforms.forward_batch(image_list)
|
||||
img_batch = img_batch.to(self.device)
|
||||
batch_size = img_batch.shape[0]
|
||||
assert (
|
||||
len(img_batch.shape) == 4 and img_batch.shape[1] == 3
|
||||
), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
|
||||
assert len(img_batch.shape) == 4 and img_batch.shape[1] == 3, (
|
||||
f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
|
||||
)
|
||||
logging.info("Computing image embeddings for the provided images...")
|
||||
backbone_out = self.model.forward_image(img_batch)
|
||||
(
|
||||
@@ -302,9 +299,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
|
||||
):
|
||||
unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
|
||||
if point_coords is not None:
|
||||
assert (
|
||||
point_labels is not None
|
||||
), "point_labels must be supplied if point_coords is supplied."
|
||||
assert point_labels is not None, (
|
||||
"point_labels must be supplied if point_coords is supplied."
|
||||
)
|
||||
point_coords = torch.as_tensor(
|
||||
point_coords, dtype=torch.float, device=self.device
|
||||
)
|
||||
@@ -441,9 +438,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
|
||||
raise RuntimeError(
|
||||
"An image must be set with .set_image(...) to generate an embedding."
|
||||
)
|
||||
assert (
|
||||
self._features is not None
|
||||
), "Features must exist if an image has been set."
|
||||
assert self._features is not None, (
|
||||
"Features must exist if an image has been set."
|
||||
)
|
||||
return self._features["image_embed"]
|
||||
|
||||
@property
|
||||
|
||||
@@ -8,19 +8,14 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sam3.model.model_misc import SAM3Output
|
||||
|
||||
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
|
||||
from sam3.model.vl_combiner import SAM3VLBackbone
|
||||
from sam3.perflib.nms import nms_masks
|
||||
|
||||
from sam3.train.data.collator import BatchedDatapoint
|
||||
|
||||
from .act_ckpt_utils import activation_ckpt_wrapper
|
||||
|
||||
from .box_ops import box_cxcywh_to_xyxy
|
||||
|
||||
from .geometry_encoders import Prompt
|
||||
from .model_misc import inverse_sigmoid
|
||||
|
||||
@@ -661,9 +656,9 @@ class Sam3Image(torch.nn.Module):
|
||||
inference_state["original_heights"],
|
||||
inference_state["original_widths"],
|
||||
)
|
||||
assert (
|
||||
batch_size == len(orig_heights) == len(orig_widths)
|
||||
), f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
|
||||
assert batch_size == len(orig_heights) == len(orig_widths), (
|
||||
f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
|
||||
)
|
||||
feats = [
|
||||
feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
|
||||
for feat, feat_size in zip(
|
||||
|
||||
@@ -6,9 +6,7 @@ from typing import Dict, List
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from sam3.model import box_ops
|
||||
|
||||
from sam3.model.data_misc import FindStage, interpolate
|
||||
from torchvision.transforms import v2
|
||||
|
||||
@@ -83,9 +81,9 @@ class Sam3Processor:
|
||||
if not isinstance(images, list):
|
||||
raise ValueError("Images must be a list of PIL images or tensors")
|
||||
assert len(images) > 0, "Images list must not be empty"
|
||||
assert isinstance(
|
||||
images[0], PIL.Image.Image
|
||||
), "Images must be a list of PIL images"
|
||||
assert isinstance(images[0], PIL.Image.Image), (
|
||||
"Images must be a list of PIL images"
|
||||
)
|
||||
|
||||
state["original_heights"] = [image.height for image in images]
|
||||
state["original_widths"] = [image.width for image in images]
|
||||
|
||||
@@ -6,11 +6,8 @@ import logging
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sam3.model.memory import SimpleMaskEncoder
|
||||
|
||||
from sam3.model.sam3_tracker_utils import get_1d_sine_pe, select_closest_cond_frames
|
||||
|
||||
from sam3.sam.mask_decoder import MaskDecoder, MLP
|
||||
from sam3.sam.prompt_encoder import PromptEncoder
|
||||
from sam3.sam.transformer import TwoWayTransformer
|
||||
|
||||
@@ -6,7 +6,6 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from sam3.model.edt import edt_triton
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.model.sam3_tracker_base import concat_points, NO_OBJ_SCORE, Sam3TrackerBase
|
||||
from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores
|
||||
from sam3.model.utils.sam2_utils import load_video_frames
|
||||
|
||||
@@ -16,7 +16,6 @@ import numpy.typing as npt
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sam3 import perflib
|
||||
from sam3.logger import get_logger
|
||||
from sam3.model.box_ops import fast_diag_box_iou
|
||||
@@ -620,9 +619,9 @@ class Sam3VideoBase(nn.Module):
|
||||
num_obj_dropped_due_to_limit,
|
||||
trk_id_to_max_iou_high_conf_det,
|
||||
]
|
||||
assert (
|
||||
len(update_plan) == NUM_BROADCAST_ITEMS
|
||||
), f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
|
||||
assert len(update_plan) == NUM_BROADCAST_ITEMS, (
|
||||
f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
|
||||
)
|
||||
self.broadcast_python_obj_cpu(update_plan, src=0)
|
||||
elif self.rank > 0 and self.world_size > 1:
|
||||
update_plan = [
|
||||
@@ -842,9 +841,9 @@ class Sam3VideoBase(nn.Module):
|
||||
binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0
|
||||
batch_size = tracker_low_res_masks_global.size(0)
|
||||
if batch_size > 0:
|
||||
assert (
|
||||
len(obj_ids_global) == batch_size
|
||||
), f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
|
||||
assert len(obj_ids_global) == batch_size, (
|
||||
f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
|
||||
)
|
||||
NEVER_OCCLUDED = -1
|
||||
ALWAYS_OCCLUDED = 100000 # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic
|
||||
last_occluded_prev = torch.cat(
|
||||
@@ -1023,9 +1022,9 @@ class Sam3VideoBase(nn.Module):
|
||||
reverse: bool = False,
|
||||
):
|
||||
# Suppress overlapping masks for objects that were most recently occluded
|
||||
assert (
|
||||
binary_low_res_masks.dtype == torch.bool
|
||||
), f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
|
||||
assert binary_low_res_masks.dtype == torch.bool, (
|
||||
f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
|
||||
)
|
||||
to_suppress = torch.zeros(
|
||||
binary_low_res_masks.size(0),
|
||||
device=binary_low_res_masks.device,
|
||||
@@ -1130,9 +1129,9 @@ class Sam3VideoBase(nn.Module):
|
||||
num_frames_propagated += 1
|
||||
|
||||
# only 1 frames should be propagated
|
||||
assert (
|
||||
num_frames_propagated == 1 and out_frame_idx == frame_idx
|
||||
), f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}"
|
||||
assert num_frames_propagated == 1 and out_frame_idx == frame_idx, (
|
||||
f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}"
|
||||
)
|
||||
assert isinstance(out_obj_ids, list)
|
||||
obj_ids_local.extend(out_obj_ids)
|
||||
low_res_masks_list.append(out_low_res_masks.squeeze(1))
|
||||
@@ -1189,9 +1188,9 @@ class Sam3VideoBase(nn.Module):
|
||||
|
||||
assert det_masks.is_floating_point(), "float tensor expected (do not binarize)"
|
||||
assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)"
|
||||
assert (
|
||||
trk_masks.size(0) == len(trk_obj_ids)
|
||||
), f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
|
||||
assert trk_masks.size(0) == len(trk_obj_ids), (
|
||||
f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
|
||||
)
|
||||
if trk_masks.size(0) == 0:
|
||||
# all detections are new
|
||||
new_det_fa_inds = np.arange(det_masks.size(0))
|
||||
@@ -1655,9 +1654,9 @@ class Sam3VideoBase(nn.Module):
|
||||
# a) first, expand "confirmation_data" to include new masklets added in this frame
|
||||
status_prev = confirmation_data["status"]
|
||||
consecutive_det_num_prev = confirmation_data["consecutive_det_num"]
|
||||
assert (
|
||||
status_prev.shape == obj_ids_all_gpu_prev.shape
|
||||
), f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
|
||||
assert status_prev.shape == obj_ids_all_gpu_prev.shape, (
|
||||
f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
|
||||
)
|
||||
|
||||
obj_id_to_updated_idx = {
|
||||
obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated)
|
||||
|
||||
@@ -9,7 +9,6 @@ import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sam3 import perflib
|
||||
from sam3.logger import get_logger
|
||||
from sam3.model.act_ckpt_utils import clone_output_wrapper
|
||||
@@ -555,7 +554,9 @@ class Sam3VideoInference(Sam3VideoBase):
|
||||
assert (
|
||||
"cached_frame_outputs" in inference_state
|
||||
and frame_idx in inference_state["cached_frame_outputs"]
|
||||
), "No cached outputs found. Ensure normal propagation has run first to populate the cache."
|
||||
), (
|
||||
"No cached outputs found. Ensure normal propagation has run first to populate the cache."
|
||||
)
|
||||
cached_outputs = inference_state["cached_frame_outputs"][frame_idx]
|
||||
|
||||
obj_id_to_mask = cached_outputs.copy()
|
||||
@@ -563,9 +564,9 @@ class Sam3VideoInference(Sam3VideoBase):
|
||||
# Update with refined masks if provided
|
||||
if refined_obj_id_to_mask is not None:
|
||||
for obj_id, refined_mask in refined_obj_id_to_mask.items():
|
||||
assert (
|
||||
refined_mask is not None
|
||||
), f"Refined mask data must be provided for obj_id {obj_id}"
|
||||
assert refined_mask is not None, (
|
||||
f"Refined mask data must be provided for obj_id {obj_id}"
|
||||
)
|
||||
obj_id_to_mask[obj_id] = refined_mask
|
||||
|
||||
return obj_id_to_mask
|
||||
@@ -660,12 +661,12 @@ class Sam3VideoInference(Sam3VideoBase):
|
||||
for i, thresh in enumerate(new_det_score_thresh_list):
|
||||
self.new_det_thresh = thresh
|
||||
for num_objects in num_objects_list:
|
||||
logger.info(f"{i+1}/{num_rounds} warming up model compilation")
|
||||
logger.info(f"{i + 1}/{num_rounds} warming up model compilation")
|
||||
self.add_prompt(
|
||||
inference_state, frame_idx=start_frame_idx, text_str="cat"
|
||||
)
|
||||
logger.info(
|
||||
f"{i+1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects"
|
||||
f"{i + 1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects"
|
||||
)
|
||||
inference_state = self.add_fake_objects_to_inference_state(
|
||||
inference_state, num_objects, frame_idx=start_frame_idx
|
||||
@@ -690,7 +691,7 @@ class Sam3VideoInference(Sam3VideoBase):
|
||||
pass
|
||||
self.reset_state(inference_state)
|
||||
logger.info(
|
||||
f"{i+1}/{num_rounds} warming up model compilation -- completed round {i+1} out of {num_rounds}"
|
||||
f"{i + 1}/{num_rounds} warming up model compilation -- completed round {i + 1} out of {num_rounds}"
|
||||
)
|
||||
|
||||
# Warm up Tracker memory encoder with varying input shapes
|
||||
@@ -854,12 +855,12 @@ class Sam3VideoInference(Sam3VideoBase):
|
||||
logger.debug("Running add_prompt on frame %d", frame_idx)
|
||||
|
||||
num_frames = inference_state["num_frames"]
|
||||
assert (
|
||||
text_str is not None or boxes_xywh is not None
|
||||
), "at least one type of prompt (text, boxes) must be provided"
|
||||
assert (
|
||||
0 <= frame_idx < num_frames
|
||||
), f"{frame_idx=} is out of range for a total of {num_frames} frames"
|
||||
assert text_str is not None or boxes_xywh is not None, (
|
||||
"at least one type of prompt (text, boxes) must be provided"
|
||||
)
|
||||
assert 0 <= frame_idx < num_frames, (
|
||||
f"{frame_idx=} is out of range for a total of {num_frames} frames"
|
||||
)
|
||||
|
||||
# since it's a semantic prompt, we start over
|
||||
self.reset_state(inference_state)
|
||||
@@ -1200,9 +1201,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
|
||||
"propagation_partial",
|
||||
"propagation_fetch",
|
||||
]
|
||||
assert (
|
||||
action_type in instance_actions + propagation_actions
|
||||
), f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}"
|
||||
assert action_type in instance_actions + propagation_actions, (
|
||||
f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}"
|
||||
)
|
||||
action = {
|
||||
"type": action_type,
|
||||
"frame_idx": frame_idx,
|
||||
@@ -1370,12 +1371,12 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
|
||||
):
|
||||
if points is not None:
|
||||
# Tracker instance prompts
|
||||
assert (
|
||||
text_str is None and boxes_xywh is None
|
||||
), "When points are provided, text_str and boxes_xywh must be None."
|
||||
assert (
|
||||
obj_id is not None
|
||||
), "When points are provided, obj_id must be provided."
|
||||
assert text_str is None and boxes_xywh is None, (
|
||||
"When points are provided, text_str and boxes_xywh must be None."
|
||||
)
|
||||
assert obj_id is not None, (
|
||||
"When points are provided, obj_id must be provided."
|
||||
)
|
||||
return self.add_tracker_new_points(
|
||||
inference_state,
|
||||
frame_idx,
|
||||
@@ -1491,9 +1492,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
|
||||
tracker_states = self._get_tracker_inference_states_by_obj_ids(
|
||||
inference_state, [obj_id]
|
||||
)
|
||||
assert (
|
||||
len(tracker_states) == 1
|
||||
), f"[rank={self.rank}] Multiple Tracker inference states found for the same object id."
|
||||
assert len(tracker_states) == 1, (
|
||||
f"[rank={self.rank}] Multiple Tracker inference states found for the same object id."
|
||||
)
|
||||
tracker_state = tracker_states[0]
|
||||
|
||||
# log
|
||||
|
||||
@@ -16,7 +16,6 @@ from typing import List, Optional
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from sam3.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -170,7 +169,7 @@ class Sam3VideoPredictor:
|
||||
):
|
||||
"""Remove an object from tracking."""
|
||||
logger.debug(
|
||||
f"remove object {obj_id} in session {session_id}: " f"{is_user_action=}"
|
||||
f"remove object {obj_id} in session {session_id}: {is_user_action=}"
|
||||
)
|
||||
session = self._get_session(session_id)
|
||||
inference_state = session["state"]
|
||||
|
||||
@@ -318,9 +318,9 @@ class VETextEncoder(nn.Module):
|
||||
# The text is already encoded, use as is.
|
||||
text_attention_mask, text_memory_resized, tokenized = text
|
||||
inputs_embeds = tokenized["inputs_embeds"]
|
||||
assert (
|
||||
input_boxes is None or len(input_boxes) == 0
|
||||
), "Can't replace boxes in text if it's already encoded"
|
||||
assert input_boxes is None or len(input_boxes) == 0, (
|
||||
"Can't replace boxes in text if it's already encoded"
|
||||
)
|
||||
|
||||
# Note that the input_embeds are returned in pytorch's convention (sequence first)
|
||||
return (
|
||||
|
||||
@@ -708,9 +708,9 @@ class ViT(nn.Module):
|
||||
self.retain_cls_token = retain_cls_token
|
||||
if self.retain_cls_token:
|
||||
assert pretrain_use_cls_token
|
||||
assert (
|
||||
len(window_block_indexes) == 0
|
||||
), "windowing not supported with cls token"
|
||||
assert len(window_block_indexes) == 0, (
|
||||
"windowing not supported with cls token"
|
||||
)
|
||||
|
||||
assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torch.nn.attention import sdpa_kernel, SDPBackend
|
||||
|
||||
from .act_ckpt_utils import activation_ckpt_wrapper
|
||||
|
||||
Reference in New Issue
Block a user