Files
sam3_local/sam3/model/text_encoder_ve.py
Bowie Chen 11dec2936d apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)
Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: itamaro

Differential Revision: D90476315

fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
2026-01-11 23:16:49 -08:00

331 lines
11 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
from collections import OrderedDict
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from .model_misc import LayerScale
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: multi-head attention followed by an MLP.

    Each sub-layer sits on a residual branch, and each branch output is passed
    through ``ls_1`` / ``ls_2`` (a ``LayerScale`` or a no-op ``nn.Identity``)
    before being added back.

    Args:
        d_model: Width of the block's input and output features.
        n_head: Number of attention heads.
        mlp_ratio: MLP hidden width expressed as a multiple of ``d_model``.
        ls_init_value: Value forwarded to ``LayerScale``; ``None`` disables
            layer scaling on both residual branches.
        act_layer: Zero-argument factory for the MLP activation module.
        norm_layer: Factory building a normalization module from a width.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
    ):
        super().__init__()

        def make_layer_scale() -> nn.Module:
            # With no init value the residual branch is passed through unscaled.
            if ls_init_value is None:
                return nn.Identity()
            return LayerScale(d_model, ls_init_value)

        # Attention operates on batch-first tensors.
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

        # One norm and one (optional) layer scale per residual branch.
        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        self.ls_1 = make_layer_scale()
        self.ls_2 = make_layer_scale()

        # Feed-forward network: expand, activate, project back.
        hidden_width = int(d_model * mlp_ratio)
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = nn.Linear(d_model, hidden_width)
        mlp_layers["gelu"] = act_layer()
        mlp_layers["c_proj"] = nn.Linear(hidden_width, d_model)
        self.mlp = nn.Sequential(mlp_layers)

    def attention(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run multi-head attention and return only the attended output.

        ``k_x`` and ``v_x`` fall back to ``q_x`` when omitted. A non-boolean
        ``attn_mask`` is cast to the dtype of ``q_x``; boolean masks are
        forwarded unchanged.
        """
        keys = q_x if k_x is None else k_x
        values = q_x if v_x is None else v_x

        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q_x.dtype)

        outputs = self.attn(
            q_x, keys, values, need_weights=False, attn_mask=attn_mask
        )
        # The module returns (output, weights); only the output is needed.
        return outputs[0]

    def forward(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the attention and MLP residual branches to ``q_x``.

        ``k_x`` / ``v_x`` are only used when the instance carries an
        ``ln_1_kv`` attribute; ``__init__`` here never creates one, so unless
        it is attached externally they are replaced by ``None`` and the block
        attends over ``q_x`` alone.
        """
        has_kv_norm = hasattr(self, "ln_1_kv")
        keys = self.ln_1_kv(k_x) if (has_kv_norm and k_x is not None) else None
        values = self.ln_1_kv(v_x) if (has_kv_norm and v_x is not None) else None

        # Attention branch.
        attended = self.attention(
            q_x=self.ln_1(q_x), k_x=keys, v_x=values, attn_mask=attn_mask
        )
        hidden = q_x + self.ls_1(attended)

        # Feed-forward branch.
        feed_forward = self.mlp(self.ln_2(hidden))
        return hidden + self.ls_2(feed_forward)
class Transformer(nn.Module):
    """Sequential stack of ``ResidualAttentionBlock`` layers.

    Args:
        width: Feature width handed to every block.
        layers: Number of blocks in the stack.
        heads: Number of attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``width``.
        ls_init_value: Layer-scale init value forwarded to every block.
        act_layer: Zero-argument factory for the MLP activation.
        norm_layer: Factory building a normalization module from a width.
        compile_mode: When not ``None``, ``forward`` is wrapped with
            ``torch.compile`` using this mode and ``fullgraph=True``.
        use_act_checkpoint: Run each block through activation checkpointing
            while the module is in training mode.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = use_act_checkpoint

        def build_block() -> nn.Module:
            return ResidualAttentionBlock(
                width,
                heads,
                mlp_ratio,
                ls_init_value=ls_init_value,
                act_layer=act_layer,
                norm_layer=norm_layer,
            )

        self.resblocks = nn.ModuleList([build_block() for _ in range(layers)])

        if compile_mode is None:
            return

        # Shadow the bound method on this instance with its compiled version.
        self.forward = torch.compile(
            self.forward, mode=compile_mode, fullgraph=True
        )
        if self.grad_checkpointing:
            # NOTE(review): this flips a process-wide dynamo setting; it is
            # presumably needed to combine compile with checkpointing -- confirm.
            torch._dynamo.config.optimize_ddp = False

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Feed ``x`` through every block in order and return the result."""
        use_checkpoint = (
            self.grad_checkpointing
            and not torch.jit.is_scripting()
            and self.training
        )
        for block in self.resblocks:
            if use_checkpoint:
                # Positional arguments map to (q_x, k_x, v_x, attn_mask).
                x = checkpoint(block, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = block(x, attn_mask=attn_mask)
        return x
def text_global_pool(
    x: torch.Tensor, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split encoder output into a pooled feature and per-token features.

    Args:
        x: Encoder output, indexed along its second axis per position
            (presumably ``[batch, seq_len, width]`` -- confirm with callers).
        text: Token ids aligned with ``x``; only read for ``"argmax"`` pooling.
        pool_type: One of
            ``"first"``  -- pooled is position 0, tokens are the remainder;
            ``"last"``   -- pooled is the final position, tokens are the rest;
            ``"argmax"`` -- pooled is taken at the position of the largest
            token id in each row of ``text``, tokens are all of ``x``;
            any other value -- no pooling, both outputs are ``x`` itself.

    Returns:
        A ``(pooled, tokens)`` tuple.

    Raises:
        AssertionError: If ``pool_type`` is ``"argmax"`` and ``text`` is ``None``.
    """
    if pool_type == "first":
        pooled, tokens = x[:, 0], x[:, 1:]
    elif pool_type == "last":
        pooled, tokens = x[:, -1], x[:, :-1]
    elif pool_type == "argmax":
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        assert text is not None, "text token ids are required for pool_type='argmax'"
        # Build the batch index on x's device so indexing does not have to
        # move a default-device index tensor over to the accelerator.
        batch_idx = torch.arange(x.shape[0], device=x.device)
        pooled, tokens = x[batch_idx, text.argmax(dim=-1)], x
    else:
        pooled = tokens = x
    return pooled, tokens
class TextTransformer(nn.Module):
    """Token-embedding text encoder built on ``Transformer``.

    Token ids are embedded, offset by a learned positional embedding, run
    through the transformer stack (optionally under a causal mask), normalized,
    pooled according to ``pool_type`` and finally projected.

    Args:
        context_length: Number of positions in the positional embedding and
            side length of the causal mask.
        vocab_size: Size of the token embedding table.
        width: Transformer feature width.
        heads: Attention heads per block.
        layers: Number of transformer blocks.
        mlp_ratio: MLP hidden width as a multiple of ``width``.
        ls_init_value: Layer-scale init value forwarded to the blocks.
        output_dim: Output width of the text projection.
        no_causal_mask: Skip building the causal attention mask.
        pool_type: One of ``"first"``, ``"last"``, ``"argmax"``, ``"none"``.
        proj_bias: Use an ``nn.Linear`` projection instead of a bare matrix.
        act_layer: Factory for the MLP activation.
        norm_layer: Factory building a normalization module from a width.
        output_tokens: Return ``(pooled, tokens)`` rather than only ``pooled``.
        use_ln_post: Apply ``norm_layer`` after the transformer stack.
        compile_mode: Forwarded to ``Transformer``.
        use_act_checkpoint: Forwarded to ``Transformer``.
    """

    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        output_dim: int = 512,
        no_causal_mask: bool = False,
        pool_type: str = "none",  # no pooling
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        output_tokens: bool = False,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        assert pool_type in ("first", "last", "argmax", "none")

        self.output_tokens = output_tokens
        self.num_pos = context_length
        self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pool_type = pool_type

        self.token_embedding = nn.Embedding(self.vocab_size, width)
        # NOTE(review): allocated uninitialized and never initialized in this
        # class -- presumably filled from a checkpoint; confirm.
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()

        if not no_causal_mask:
            # Non-persistent: the mask is rebuilt here, not stored in state_dict.
            self.register_buffer(
                "attn_mask", self.build_causal_mask(), persistent=False
            )
        else:
            self.attn_mask = None

        # Same caveat as positional_embedding for the bare-matrix variant.
        self.text_projection = (
            nn.Linear(width, output_dim)
            if proj_bias
            else nn.Parameter(torch.empty(width, output_dim))
        )

    def build_causal_mask(self) -> torch.Tensor:
        """Return the ``num_pos x num_pos`` additive causal attention mask.

        Entries strictly above the main diagonal are ``-inf`` (blocked); the
        diagonal and everything below it are ``0`` (allowed).
        """
        blocked = torch.full((self.num_pos, self.num_pos), float("-inf"))
        # Keep -inf only above the diagonal; triu_ zeroes the remainder.
        return blocked.triu_(1)

    def forward(
        self, text: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Encode token ids.

        Args:
            text: Token ids; the second axis is the sequence axis.

        Returns:
            The projected pooled features, or ``(pooled, tokens)`` when
            ``output_tokens`` is set.
        """
        n_tokens = text.shape[1]
        hidden = self.token_embedding(text)  # [batch_size, n_ctx, d_model]

        # Crop the mask and positional embedding to the actual sequence length.
        causal = self.attn_mask
        if causal is not None:
            causal = causal[:n_tokens, :n_tokens]
        hidden = hidden + self.positional_embedding[:n_tokens]

        hidden = self.transformer(hidden, attn_mask=causal)
        hidden = self.ln_final(hidden)
        pooled, tokens = text_global_pool(hidden, text, pool_type=self.pool_type)

        projection = self.text_projection
        if projection is not None:
            if isinstance(projection, nn.Linear):
                pooled = projection(pooled)
            else:
                pooled = pooled @ projection

        return (pooled, tokens) if self.output_tokens else pooled
class VETextEncoder(nn.Module):
    """Text encoder wrapper: tokenize, encode, and resize to ``d_model``.

    Args:
        d_model: Output width produced by the ``resizer`` linear layer.
        tokenizer: Callable invoked as ``tokenizer(text, context_length=...)``.
        width: Feature width of the underlying ``TextTransformer``.
        heads: Attention heads of the underlying encoder.
        layers: Number of transformer blocks of the underlying encoder.
        context_length: Context length passed to tokenizer and encoder.
        vocab_size: Vocabulary size of the underlying encoder.
        use_ln_post: Forwarded to ``TextTransformer``.
        compile_mode: Forwarded to ``TextTransformer``.
        use_act_checkpoint: Forwarded to ``TextTransformer``.
    """

    def __init__(
        self,
        d_model: int,
        tokenizer: Callable,
        width: int = 1024,
        heads: int = 16,
        layers: int = 24,
        context_length: int = 32,
        vocab_size: int = 49408,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = True,
    ):
        super().__init__()
        self.context_length = context_length
        self.use_ln_post = use_ln_post
        self.tokenizer = tokenizer

        encoder_kwargs = dict(
            context_length=self.context_length,
            vocab_size=vocab_size,
            width=width,
            heads=heads,
            layers=layers,
            # per-token features are needed, not only the pooled output
            output_tokens=True,
            use_ln_post=use_ln_post,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.encoder = TextTransformer(**encoder_kwargs)
        # Maps encoder features to the width expected downstream.
        self.resizer = nn.Linear(self.encoder.width, d_model)

    def forward(
        self,
        text: Union[List[str], Tuple[torch.Tensor, torch.Tensor, dict]],
        input_boxes: Optional[List] = None,
        device: torch.device = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode raw strings, or pass through an already-encoded triple.

        Args:
            text: Either a list of strings, or a pre-encoded
                ``(attention_mask, resized_memory, tokenized)`` tuple whose
                last item is a mapping holding ``"inputs_embeds"``.
            input_boxes: Must be ``None`` or empty; anything else is rejected.
            device: Target passed to ``.to()`` on the tokenizer output.

        Returns:
            ``(text_attention_mask, text_memory_resized, inputs_embeds)`` with
            ``inputs_embeds`` transposed on its first two axes. For string
            input the mask is ``True`` where the token id equals 0 and the
            memory is likewise transposed before being resized.

        Raises:
            AssertionError: If ``input_boxes`` is non-empty.
        """
        if not isinstance(text[0], str):
            # Already encoded upstream: unpack and reuse as is.
            text_attention_mask, text_memory_resized, tokenized = text
            inputs_embeds = tokenized["inputs_embeds"]
            assert input_boxes is None or len(input_boxes) == 0, (
                "Can't replace boxes in text if it's already encoded"
            )
        else:
            # no use case for this
            assert input_boxes is None or len(input_boxes) == 0, "not supported"

            token_ids = self.tokenizer(text, context_length=self.context_length)
            tokenized = token_ids.to(device)  # [b, seq_len]

            # Embed the tokens directly; these embeddings are returned as well.
            inputs_embeds = self.encoder.token_embedding(
                tokenized
            )  # [b, seq_len, width]
            text_memory = self.encoder(tokenized)[1]  # [b, seq_len, width]
            assert text_memory.shape[1] == inputs_embeds.shape[1]

            # Inverted relative to "non-zero token": pytorch's transformer
            # uses the opposite mask convention.
            text_attention_mask = (tokenized != 0).bool().ne(1)

            # Sequence-first layout, as pytorch's attention expects, then
            # project to the decoder's d_model.
            text_memory_resized = self.resizer(text_memory.transpose(0, 1))

        # inputs_embeds are returned in pytorch's convention (sequence first)
        sequence_first_embeds = inputs_embeds.transpose(0, 1)
        return (text_attention_mask, text_memory_resized, sequence_first_embeds)