# Differential Revision: D90237984
# fbshipit-source-id: 526fd760f303bf31be4f743bdcd77760496de0de
# (source listing: 331 lines, 11 KiB, Python)
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
|
|
|
# pyre-unsafe
|
|
|
|
from collections import OrderedDict
|
|
from typing import Callable, List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.checkpoint import checkpoint
|
|
|
|
from .model_misc import LayerScale
|
|
|
|
|
|
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: multi-head attention followed by an MLP.

    Both sub-layers are residual and pre-normalized::

        x = q_x + ls_1(attn(ln_1(q_x), k, v))
        x = x + ls_2(mlp(ln_2(x)))

    ``ls_1`` / ``ls_2`` are LayerScale modules when ``ls_init_value`` is given
    and identities otherwise.

    Args:
        d_model: Input/output feature width.
        n_head: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``d_model``.
        ls_init_value: Initial LayerScale value, or None to disable LayerScale.
        act_layer: Zero-argument factory for the MLP activation module.
        norm_layer: Factory called with ``d_model`` to build each norm layer.
        is_cross_attention: If True, also build ``ln_1_kv`` so that ``k_x`` /
            ``v_x`` passed to :meth:`forward` are normalized and attended to.
            Defaults to False, in which case ``k_x`` / ``v_x`` are ignored by
            :meth:`forward` (the block self-attends).
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        is_cross_attention: bool = False,
    ):
        super().__init__()
        # Attention (batch_first=True: inputs are (batch, seq, d_model))
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

        # LayerNorm, LayerScale
        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)

        self.ls_1 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )
        self.ls_2 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )

        # MLP
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, mlp_width)),
                    # The key stays "gelu" whatever act_layer is (presumably
                    # for checkpoint-key compatibility).
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(mlp_width, d_model)),
                ]
            )
        )

        # Built last and only on request, so the default construction order,
        # parameter names and RNG consumption are unchanged.
        if is_cross_attention:
            self.ln_1_kv = norm_layer(d_model)

    def attention(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run ``self.attn`` and return only the attended output.

        Missing keys/values default to ``q_x`` (self-attention). Boolean masks
        are passed through unchanged; any other mask is cast to ``q_x.dtype``.
        """
        k_x = q_x if k_x is None else k_x
        v_x = q_x if v_x is None else v_x
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            # Additive (float) masks are cast so they match the activation
            # dtype, e.g. when running in reduced precision.
            attn_mask = attn_mask.to(q_x.dtype)
        return self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)[0]

    def forward(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers with residual connections.

        Args:
            q_x: Query input, also the residual stream.
            k_x, v_x: Optional keys/values. They are used (after ``ln_1_kv``)
                only if the block was built with ``is_cross_attention=True``;
                otherwise they are ignored and the block self-attends.
            attn_mask: Optional attention mask forwarded to ``attention``.

        Returns:
            Tensor with the same shape as ``q_x``.
        """
        # Keys/values are honored only when a dedicated kv-norm exists.
        has_kv_norm = hasattr(self, "ln_1_kv")
        k_x = self.ln_1_kv(k_x) if has_kv_norm and k_x is not None else None
        v_x = self.ln_1_kv(v_x) if has_kv_norm and v_x is not None else None

        x = q_x + self.ls_1(
            self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
        )
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x
|
|
|
|
|
|
class Transformer(nn.Module):
    """A stack of ``layers`` ResidualAttentionBlocks applied in sequence.

    Args:
        width: Feature width of every block.
        layers: Number of blocks.
        heads: Attention heads per block.
        mlp_ratio: MLP width multiplier inside each block.
        ls_init_value: LayerScale init value, or None to disable LayerScale.
        act_layer: Activation factory forwarded to each block.
        norm_layer: Norm-layer factory forwarded to each block.
        compile_mode: If not None, ``forward`` is replaced by
            ``torch.compile(self.forward, mode=compile_mode, fullgraph=True)``.
        use_act_checkpoint: If True, each block runs under activation
            checkpointing while the module is in training mode.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = use_act_checkpoint
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                )
                for _ in range(layers)
            ]
        )

        if compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=compile_mode, fullgraph=True
            )
            if self.grad_checkpointing:
                # NOTE(review): this flips a process-global dynamo setting,
                # presumably to keep DDP graph splitting from interfering with
                # activation checkpointing under compile — confirm.
                torch._dynamo.config.optimize_ddp = False

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run ``x`` through every block with the same ``attn_mask``.

        Activation checkpointing is used only when enabled, in training mode,
        and not under TorchScript.
        """
        for block in self.resblocks:
            if (
                self.grad_checkpointing
                and not torch.jit.is_scripting()
                and self.training
            ):
                # Positional args map to the block's (q_x, k_x, v_x, attn_mask).
                x = checkpoint(block, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = block(x, attn_mask=attn_mask)
        return x
|
|
|
|
|
|
def text_global_pool(
    x: torch.Tensor, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select a pooled feature per sequence and the remaining token features.

    Args:
        x: Token features, indexed as ``(batch, seq, ...)``.
        text: Token ids, required for ``"argmax"`` pooling. The position of the
            largest id in each row is the one that gets pooled.
        pool_type: ``"first"``, ``"last"`` or ``"argmax"``. Any other value
            (e.g. ``"none"``) disables pooling.

    Returns:
        ``(pooled, tokens)``:
        - ``"first"`` / ``"last"``: the pooled position is excluded from
          ``tokens``.
        - ``"argmax"``: ``tokens`` is ``x`` unchanged.
        - Any other value: both elements are ``x``.

    Raises:
        ValueError: if ``pool_type == "argmax"`` and ``text`` is None.
    """
    if pool_type == "first":
        pooled, tokens = x[:, 0], x[:, 1:]
    elif pool_type == "last":
        pooled, tokens = x[:, -1], x[:, :-1]
    elif pool_type == "argmax":
        # take features from the eot embedding (assumes the eot token has the
        # highest id in each sequence)
        if text is None:
            raise ValueError("text is required when pool_type='argmax'")
        batch_idx = torch.arange(x.shape[0], device=x.device)
        pooled, tokens = x[batch_idx, text.argmax(dim=-1)], x
    else:
        pooled = tokens = x
    return pooled, tokens
|
|
|
|
|
|
class TextTransformer(nn.Module):
    """CLIP-style text tower.

    The pipeline is: token + positional embeddings, a Transformer stack, a
    final norm, optional pooling, then an output projection.

    Args:
        context_length: Maximum sequence length; sizes the positional table and
            the causal mask.
        vocab_size: Number of token embeddings.
        width: Transformer width.
        heads: Attention heads per block.
        layers: Number of blocks.
        mlp_ratio: MLP width multiplier inside each block.
        ls_init_value: LayerScale init value, or None to disable it.
        output_dim: Width of the projected (pooled) output.
        no_causal_mask: If True, no attention mask is used; otherwise a causal
            mask is applied.
        pool_type: One of "first", "last", "argmax", "none" (no pooling).
        proj_bias: If True, project with ``nn.Linear`` (with bias). Otherwise
            multiply by a raw ``(width, output_dim)`` parameter.
        act_layer, norm_layer: Factories forwarded to the Transformer blocks.
            ``norm_layer`` also builds ``ln_final``.
        output_tokens: If True, ``forward`` returns ``(pooled, tokens)``.
        use_ln_post: Apply ``ln_final`` after the Transformer (identity if False).
        compile_mode, use_act_checkpoint: Forwarded to ``Transformer``.

    NOTE(review): ``positional_embedding`` and the raw ``text_projection``
    matrix are allocated with ``torch.empty`` and not initialized here;
    presumably they are loaded from a checkpoint — confirm before training
    from scratch.
    """

    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        output_dim: int = 512,
        no_causal_mask: bool = False,
        pool_type: str = "none",  # no pooling
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        output_tokens: bool = False,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        assert pool_type in ("first", "last", "argmax", "none")

        # Plain configuration attributes.
        self.pool_type = pool_type
        self.heads = heads
        self.output_dim = output_dim
        self.width = width
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.num_pos = context_length
        self.output_tokens = output_tokens

        # Submodules/parameters: creation order is significant (RNG use and
        # state_dict ordering), so keep it as is.
        self.token_embedding = nn.Embedding(vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.empty(context_length, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.ln_final = nn.Identity() if not use_ln_post else norm_layer(width)

        if not no_causal_mask:
            # Non-persistent: rebuilt at construction, never saved in state_dict.
            self.register_buffer(
                "attn_mask", self.build_causal_mask(), persistent=False
            )
        else:
            self.attn_mask = None

        self.text_projection = (
            nn.Linear(width, output_dim)
            if proj_bias
            else nn.Parameter(torch.empty(width, output_dim))
        )

    def build_causal_mask(self) -> torch.Tensor:
        """Return an additive causal mask of shape ``(num_pos, num_pos)``.

        Entries strictly above the diagonal are ``-inf`` (later positions are
        hidden). The diagonal and everything below it are ``0``.
        """
        size = self.num_pos
        return torch.full((size, size), float("-inf")).triu_(1)

    def forward(
        self, text: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Encode token ids.

        Args:
            text: Token ids indexed as ``(batch, seq_len)``. ``seq_len`` is
                expected to be at most ``context_length``.

        Returns:
            The projected pooled features. If ``output_tokens`` is set, returns
            ``(pooled, tokens)`` instead, where ``tokens`` are not projected.
        """
        seq_len = text.shape[1]

        # Crop the full-context mask to this batch's length.
        mask = self.attn_mask
        if mask is not None:
            mask = mask[:seq_len, :seq_len]

        hidden = self.token_embedding(text) + self.positional_embedding[:seq_len]
        hidden = self.ln_final(self.transformer(hidden, attn_mask=mask))

        pooled, tokens = text_global_pool(hidden, text, pool_type=self.pool_type)

        proj = self.text_projection
        if proj is not None:
            pooled = proj(pooled) if isinstance(proj, nn.Linear) else pooled @ proj

        if self.output_tokens:
            return pooled, tokens
        return pooled
|
|
|
|
|
|
class VETextEncoder(nn.Module):
    """Text encoder wrapping a ``TextTransformer`` plus a width-resizing layer.

    Strings are tokenized with ``tokenizer``, encoded, and the per-token
    features are projected from ``width`` to ``d_model`` by ``resizer``.

    Args:
        d_model: Output width of ``resizer``.
        tokenizer: Callable invoked as ``tokenizer(text, context_length=...)``.
            Its result is moved with ``.to(device)``. Positions whose token id
            is 0 are treated as padding.
        width, heads, layers: TextTransformer size.
        context_length: Tokenization / positional-embedding length.
        vocab_size: Tokenizer vocabulary size.
        use_ln_post: Apply the TextTransformer's final norm.
        compile_mode: Forwarded to the TextTransformer (``torch.compile`` mode).
        use_act_checkpoint: Forwarded to the TextTransformer.
    """

    def __init__(
        self,
        d_model: int,
        tokenizer: Callable,
        width: int = 1024,
        heads: int = 16,
        layers: int = 24,
        context_length: int = 32,
        vocab_size: int = 49408,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = True,
    ):
        super().__init__()
        self.context_length = context_length
        self.use_ln_post = use_ln_post
        self.tokenizer = tokenizer

        self.encoder = TextTransformer(
            context_length=self.context_length,
            vocab_size=vocab_size,
            width=width,
            heads=heads,
            layers=layers,
            # we want the tokens, not just the pooled output
            output_tokens=True,
            use_ln_post=use_ln_post,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.resizer = nn.Linear(self.encoder.width, d_model)

    def forward(
        self,
        text: Union[List[str], Tuple[torch.Tensor, torch.Tensor, dict]],
        input_boxes: Optional[List] = None,
        device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``text``, or pass an already-encoded triple through.

        Args:
            text: Either a list of strings to tokenize and encode, or a
                pre-encoded tuple ``(attention_mask, memory, extras)``. In the
                tuple form, ``extras["inputs_embeds"]`` holds the token
                embeddings.
            input_boxes: Not supported; must be None or empty.
            device: Device the tokenized ids are moved to.

        Returns:
            ``(text_attention_mask, text_memory_resized, inputs_embeds)``.
            ``inputs_embeds`` is returned after ``transpose(0, 1)``.
            For string input:
            - the mask is ``[b, seq_len]`` and True where the token id is 0;
            - the memory is ``[seq_len, b, d_model]``.

        Raises:
            ValueError: if ``input_boxes`` is non-empty.
        """
        if isinstance(text[0], str):
            # no use case for this
            if input_boxes is not None and len(input_boxes) != 0:
                raise ValueError("not supported")

            # Encode the text
            tokenized = self.tokenizer(text, context_length=self.context_length).to(
                device
            )  # [b, seq_len]
            # True at positions holding a non-zero token id.
            text_attention_mask = tokenized != 0

            # manually embed the tokens
            inputs_embeds = self.encoder.token_embedding(
                tokenized
            )  # [b, seq_len, width]
            # The pooled output is discarded; only per-token features are used.
            _, text_memory = self.encoder(tokenized)  # [b, seq_len, width]

            assert text_memory.shape[1] == inputs_embeds.shape[1]
            # Invert the attention mask: PyTorch attention masks use True for
            # positions to ignore, so the result is True at padding (id 0).
            text_attention_mask = ~text_attention_mask
            # Transpose memory because pytorch's attention expects sequence first
            text_memory = text_memory.transpose(0, 1)
            # Resize the encoder hidden states to the decoder's d_model
            text_memory_resized = self.resizer(text_memory)
        else:
            # The text is already encoded, use as is.
            text_attention_mask, text_memory_resized, tokenized = text
            inputs_embeds = tokenized["inputs_embeds"]
            if input_boxes is not None and len(input_boxes) != 0:
                raise ValueError("Can't replace boxes in text if it's already encoded")

        # Note that the inputs_embeds are returned in pytorch's convention (sequence first)
        return (
            text_attention_mask,
            text_memory_resized,
            inputs_embeds.transpose(0, 1),
        )
|