apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)

Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: itamaro

Differential Revision: D90476315

fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
This commit is contained in:
Bowie Chen
2026-01-11 23:16:49 -08:00
committed by meta-codesync[bot]
parent 7b89b8fc3f
commit 11dec2936d
69 changed files with 445 additions and 522 deletions

View File

@@ -11,7 +11,6 @@ from typing_extensions import override
from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .model_misc import get_clones
@@ -148,54 +147,42 @@ class Prompt:
)
# Dimension checks
assert (
box_embeddings is not None
and list(box_embeddings.shape[:2])
== [
box_seq_len,
bs,
]
), f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
assert (
box_mask is not None
and list(box_mask.shape)
== [
bs,
box_seq_len,
]
), f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
assert (
point_embeddings is not None
and list(point_embeddings.shape[:2])
== [
point_seq_len,
bs,
]
), f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
assert (
point_mask is not None
and list(point_mask.shape)
== [
bs,
point_seq_len,
]
), f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
assert (
box_labels is not None
and list(box_labels.shape)
== [
box_seq_len,
bs,
]
), f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
assert (
point_labels is not None
and list(point_labels.shape)
== [
point_seq_len,
bs,
]
), f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
assert box_embeddings is not None and list(box_embeddings.shape[:2]) == [
box_seq_len,
bs,
], (
f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
)
assert box_mask is not None and list(box_mask.shape) == [
bs,
box_seq_len,
], (
f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
)
assert point_embeddings is not None and list(point_embeddings.shape[:2]) == [
point_seq_len,
bs,
], (
f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
)
assert point_mask is not None and list(point_mask.shape) == [
bs,
point_seq_len,
], (
f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
)
assert box_labels is not None and list(box_labels.shape) == [
box_seq_len,
bs,
], (
f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
)
assert point_labels is not None and list(point_labels.shape) == [
point_seq_len,
bs,
], (
f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
)
assert (
# Allowed to be None, we leave it to the encoder to check for validity before encoding.
mask_embeddings is None
@@ -204,41 +191,41 @@ class Prompt:
mask_seq_len,
bs,
]
), f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
assert (
mask_mask is None
or list(mask_mask.shape)
== [
bs,
mask_seq_len,
]
), f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
), (
f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
)
assert mask_mask is None or list(mask_mask.shape) == [
bs,
mask_seq_len,
], (
f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
)
# Device checks
assert (
box_embeddings is not None and box_embeddings.device == device
), f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
assert (
box_mask is not None and box_mask.device == device
), f"Expected box mask to be on device {device}, got {box_mask.device}"
assert (
box_labels is not None and box_labels.device == device
), f"Expected box labels to be on device {device}, got {box_labels.device}"
assert (
point_embeddings is not None and point_embeddings.device == device
), f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
assert (
point_mask is not None and point_mask.device == device
), f"Expected point mask to be on device {device}, got {point_mask.device}"
assert (
point_labels is not None and point_labels.device == device
), f"Expected point labels to be on device {device}, got {point_labels.device}"
assert (
mask_embeddings is None or mask_embeddings.device == device
), f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
assert (
mask_mask is None or mask_mask.device == device
), f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
assert box_embeddings is not None and box_embeddings.device == device, (
f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
)
assert box_mask is not None and box_mask.device == device, (
f"Expected box mask to be on device {device}, got {box_mask.device}"
)
assert box_labels is not None and box_labels.device == device, (
f"Expected box labels to be on device {device}, got {box_labels.device}"
)
assert point_embeddings is not None and point_embeddings.device == device, (
f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
)
assert point_mask is not None and point_mask.device == device, (
f"Expected point mask to be on device {device}, got {point_mask.device}"
)
assert point_labels is not None and point_labels.device == device, (
f"Expected point labels to be on device {device}, got {point_labels.device}"
)
assert mask_embeddings is None or mask_embeddings.device == device, (
f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
)
assert mask_mask is None or mask_mask.device == device, (
f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
)
self.box_embeddings = box_embeddings
self.point_embeddings = point_embeddings
@@ -264,30 +251,30 @@ class Prompt:
if point_embeddings is not None:
point_seq_len = point_embeddings.shape[0]
if bs is not None:
assert (
bs == point_embeddings.shape[1]
), f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
assert bs == point_embeddings.shape[1], (
f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
)
else:
bs = point_embeddings.shape[1]
if device is not None:
assert (
device == point_embeddings.device
), "Device mismatch between box and point embeddings"
assert device == point_embeddings.device, (
"Device mismatch between box and point embeddings"
)
else:
device = point_embeddings.device
if mask_embeddings is not None:
mask_seq_len = mask_embeddings.shape[0]
if bs is not None:
assert (
bs == mask_embeddings.shape[1]
), f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
assert bs == mask_embeddings.shape[1], (
f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
)
else:
bs = mask_embeddings.shape[1]
if device is not None:
assert (
device == mask_embeddings.device
), "Device mismatch between box/point and mask embeddings."
assert device == mask_embeddings.device, (
"Device mismatch between box/point and mask embeddings."
)
else:
device = mask_embeddings.device
@@ -539,9 +526,9 @@ class SequenceGeometryEncoder(nn.Module):
if add_cls:
self.cls_embed = torch.nn.Embedding(1, self.d_model)
assert (
points_direct_project or points_pos_enc or points_pool
), "Error: need at least one way to encode points"
assert points_direct_project or points_pos_enc or points_pool, (
"Error: need at least one way to encode points"
)
assert (
encode_boxes_as_points
or boxes_direct_project
@@ -583,16 +570,16 @@ class SequenceGeometryEncoder(nn.Module):
self.encode = None
if num_layers > 0:
assert (
add_cls
), "It's currently highly recommended to add a CLS when using a transformer"
assert add_cls, (
"It's currently highly recommended to add a CLS when using a transformer"
)
self.encode = get_clones(layer, num_layers)
self.encode_norm = nn.LayerNorm(self.d_model)
if mask_encoder is not None:
assert isinstance(
mask_encoder, MaskEncoder
), f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
assert isinstance(mask_encoder, MaskEncoder), (
f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
)
if add_mask_label:
self.mask_label_embed = torch.nn.Embedding(2, self.d_model)
self.add_mask_label = add_mask_label
@@ -701,16 +688,15 @@ class SequenceGeometryEncoder(nn.Module):
img_feats: torch.Tensor = None,
):
n_masks, bs = masks.shape[:2]
assert (
n_masks == 1
), "We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
assert (
list(attn_mask.shape)
== [
bs,
n_masks,
]
), f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
assert n_masks == 1, (
"We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
)
assert list(attn_mask.shape) == [
bs,
n_masks,
], (
f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
)
masks, pos = self.mask_encoder(
masks=masks.flatten(0, 1).float(),
pix_feat=img_feats,