apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)
Summary: Formats the covered files with pyfmt. paintitblack Reviewed By: itamaro Differential Revision: D90476315 fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
This commit is contained in:
committed by
meta-codesync[bot]
parent
7b89b8fc3f
commit
11dec2936d
@@ -11,7 +11,6 @@ from typing_extensions import override
|
||||
|
||||
from .act_ckpt_utils import activation_ckpt_wrapper
|
||||
from .box_ops import box_cxcywh_to_xyxy
|
||||
|
||||
from .model_misc import get_clones
|
||||
|
||||
|
||||
@@ -148,54 +147,42 @@ class Prompt:
|
||||
)
|
||||
|
||||
# Dimension checks
|
||||
assert (
|
||||
box_embeddings is not None
|
||||
and list(box_embeddings.shape[:2])
|
||||
== [
|
||||
box_seq_len,
|
||||
bs,
|
||||
]
|
||||
), f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
|
||||
assert (
|
||||
box_mask is not None
|
||||
and list(box_mask.shape)
|
||||
== [
|
||||
bs,
|
||||
box_seq_len,
|
||||
]
|
||||
), f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
|
||||
assert (
|
||||
point_embeddings is not None
|
||||
and list(point_embeddings.shape[:2])
|
||||
== [
|
||||
point_seq_len,
|
||||
bs,
|
||||
]
|
||||
), f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
|
||||
assert (
|
||||
point_mask is not None
|
||||
and list(point_mask.shape)
|
||||
== [
|
||||
bs,
|
||||
point_seq_len,
|
||||
]
|
||||
), f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
|
||||
assert (
|
||||
box_labels is not None
|
||||
and list(box_labels.shape)
|
||||
== [
|
||||
box_seq_len,
|
||||
bs,
|
||||
]
|
||||
), f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
|
||||
assert (
|
||||
point_labels is not None
|
||||
and list(point_labels.shape)
|
||||
== [
|
||||
point_seq_len,
|
||||
bs,
|
||||
]
|
||||
), f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
|
||||
assert box_embeddings is not None and list(box_embeddings.shape[:2]) == [
|
||||
box_seq_len,
|
||||
bs,
|
||||
], (
|
||||
f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
|
||||
)
|
||||
assert box_mask is not None and list(box_mask.shape) == [
|
||||
bs,
|
||||
box_seq_len,
|
||||
], (
|
||||
f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
|
||||
)
|
||||
assert point_embeddings is not None and list(point_embeddings.shape[:2]) == [
|
||||
point_seq_len,
|
||||
bs,
|
||||
], (
|
||||
f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
|
||||
)
|
||||
assert point_mask is not None and list(point_mask.shape) == [
|
||||
bs,
|
||||
point_seq_len,
|
||||
], (
|
||||
f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
|
||||
)
|
||||
assert box_labels is not None and list(box_labels.shape) == [
|
||||
box_seq_len,
|
||||
bs,
|
||||
], (
|
||||
f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
|
||||
)
|
||||
assert point_labels is not None and list(point_labels.shape) == [
|
||||
point_seq_len,
|
||||
bs,
|
||||
], (
|
||||
f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
|
||||
)
|
||||
assert (
|
||||
# Allowed to be None, we leave it to the encoder to check for validity before encoding.
|
||||
mask_embeddings is None
|
||||
@@ -204,41 +191,41 @@ class Prompt:
|
||||
mask_seq_len,
|
||||
bs,
|
||||
]
|
||||
), f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
|
||||
assert (
|
||||
mask_mask is None
|
||||
or list(mask_mask.shape)
|
||||
== [
|
||||
bs,
|
||||
mask_seq_len,
|
||||
]
|
||||
), f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
|
||||
), (
|
||||
f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
|
||||
)
|
||||
assert mask_mask is None or list(mask_mask.shape) == [
|
||||
bs,
|
||||
mask_seq_len,
|
||||
], (
|
||||
f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
|
||||
)
|
||||
|
||||
# Device checks
|
||||
assert (
|
||||
box_embeddings is not None and box_embeddings.device == device
|
||||
), f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
|
||||
assert (
|
||||
box_mask is not None and box_mask.device == device
|
||||
), f"Expected box mask to be on device {device}, got {box_mask.device}"
|
||||
assert (
|
||||
box_labels is not None and box_labels.device == device
|
||||
), f"Expected box labels to be on device {device}, got {box_labels.device}"
|
||||
assert (
|
||||
point_embeddings is not None and point_embeddings.device == device
|
||||
), f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
|
||||
assert (
|
||||
point_mask is not None and point_mask.device == device
|
||||
), f"Expected point mask to be on device {device}, got {point_mask.device}"
|
||||
assert (
|
||||
point_labels is not None and point_labels.device == device
|
||||
), f"Expected point labels to be on device {device}, got {point_labels.device}"
|
||||
assert (
|
||||
mask_embeddings is None or mask_embeddings.device == device
|
||||
), f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
|
||||
assert (
|
||||
mask_mask is None or mask_mask.device == device
|
||||
), f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
|
||||
assert box_embeddings is not None and box_embeddings.device == device, (
|
||||
f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
|
||||
)
|
||||
assert box_mask is not None and box_mask.device == device, (
|
||||
f"Expected box mask to be on device {device}, got {box_mask.device}"
|
||||
)
|
||||
assert box_labels is not None and box_labels.device == device, (
|
||||
f"Expected box labels to be on device {device}, got {box_labels.device}"
|
||||
)
|
||||
assert point_embeddings is not None and point_embeddings.device == device, (
|
||||
f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
|
||||
)
|
||||
assert point_mask is not None and point_mask.device == device, (
|
||||
f"Expected point mask to be on device {device}, got {point_mask.device}"
|
||||
)
|
||||
assert point_labels is not None and point_labels.device == device, (
|
||||
f"Expected point labels to be on device {device}, got {point_labels.device}"
|
||||
)
|
||||
assert mask_embeddings is None or mask_embeddings.device == device, (
|
||||
f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
|
||||
)
|
||||
assert mask_mask is None or mask_mask.device == device, (
|
||||
f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
|
||||
)
|
||||
|
||||
self.box_embeddings = box_embeddings
|
||||
self.point_embeddings = point_embeddings
|
||||
@@ -264,30 +251,30 @@ class Prompt:
|
||||
if point_embeddings is not None:
|
||||
point_seq_len = point_embeddings.shape[0]
|
||||
if bs is not None:
|
||||
assert (
|
||||
bs == point_embeddings.shape[1]
|
||||
), f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
|
||||
assert bs == point_embeddings.shape[1], (
|
||||
f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
|
||||
)
|
||||
else:
|
||||
bs = point_embeddings.shape[1]
|
||||
if device is not None:
|
||||
assert (
|
||||
device == point_embeddings.device
|
||||
), "Device mismatch between box and point embeddings"
|
||||
assert device == point_embeddings.device, (
|
||||
"Device mismatch between box and point embeddings"
|
||||
)
|
||||
else:
|
||||
device = point_embeddings.device
|
||||
|
||||
if mask_embeddings is not None:
|
||||
mask_seq_len = mask_embeddings.shape[0]
|
||||
if bs is not None:
|
||||
assert (
|
||||
bs == mask_embeddings.shape[1]
|
||||
), f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
|
||||
assert bs == mask_embeddings.shape[1], (
|
||||
f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
|
||||
)
|
||||
else:
|
||||
bs = mask_embeddings.shape[1]
|
||||
if device is not None:
|
||||
assert (
|
||||
device == mask_embeddings.device
|
||||
), "Device mismatch between box/point and mask embeddings."
|
||||
assert device == mask_embeddings.device, (
|
||||
"Device mismatch between box/point and mask embeddings."
|
||||
)
|
||||
else:
|
||||
device = mask_embeddings.device
|
||||
|
||||
@@ -539,9 +526,9 @@ class SequenceGeometryEncoder(nn.Module):
|
||||
if add_cls:
|
||||
self.cls_embed = torch.nn.Embedding(1, self.d_model)
|
||||
|
||||
assert (
|
||||
points_direct_project or points_pos_enc or points_pool
|
||||
), "Error: need at least one way to encode points"
|
||||
assert points_direct_project or points_pos_enc or points_pool, (
|
||||
"Error: need at least one way to encode points"
|
||||
)
|
||||
assert (
|
||||
encode_boxes_as_points
|
||||
or boxes_direct_project
|
||||
@@ -583,16 +570,16 @@ class SequenceGeometryEncoder(nn.Module):
|
||||
|
||||
self.encode = None
|
||||
if num_layers > 0:
|
||||
assert (
|
||||
add_cls
|
||||
), "It's currently highly recommended to add a CLS when using a transformer"
|
||||
assert add_cls, (
|
||||
"It's currently highly recommended to add a CLS when using a transformer"
|
||||
)
|
||||
self.encode = get_clones(layer, num_layers)
|
||||
self.encode_norm = nn.LayerNorm(self.d_model)
|
||||
|
||||
if mask_encoder is not None:
|
||||
assert isinstance(
|
||||
mask_encoder, MaskEncoder
|
||||
), f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
|
||||
assert isinstance(mask_encoder, MaskEncoder), (
|
||||
f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
|
||||
)
|
||||
if add_mask_label:
|
||||
self.mask_label_embed = torch.nn.Embedding(2, self.d_model)
|
||||
self.add_mask_label = add_mask_label
|
||||
@@ -701,16 +688,15 @@ class SequenceGeometryEncoder(nn.Module):
|
||||
img_feats: torch.Tensor = None,
|
||||
):
|
||||
n_masks, bs = masks.shape[:2]
|
||||
assert (
|
||||
n_masks == 1
|
||||
), "We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
|
||||
assert (
|
||||
list(attn_mask.shape)
|
||||
== [
|
||||
bs,
|
||||
n_masks,
|
||||
]
|
||||
), f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
|
||||
assert n_masks == 1, (
|
||||
"We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
|
||||
)
|
||||
assert list(attn_mask.shape) == [
|
||||
bs,
|
||||
n_masks,
|
||||
], (
|
||||
f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
|
||||
)
|
||||
masks, pos = self.mask_encoder(
|
||||
masks=masks.flatten(0, 1).float(),
|
||||
pix_feat=img_feats,
|
||||
|
||||
Reference in New Issue
Block a user