# This code is referenced from https://github.com/dhansmair/flamingo-mini
import torch
from einops import rearrange, repeat
from einops_exts import rearrange_many
from torch import einsum, nn
from transformers.activations import ACT2FN

from .configuration_gecko import GeckoConfig


def feed_forward_layer(dim: int, mult: int = 4, activation: str = 'gelu'):
    """Pre-norm feed-forward layer with the given activation function."""
    activations = dict(gelu=nn.GELU, relu=nn.ReLU)
    assert activation in activations, f'activation can only be one of {list(activations.keys())}'

    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        activations[activation](),
        nn.Linear(inner_dim, dim, bias=False),
    )
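
# Usage sketch (the dimensions below are illustrative assumptions):
# >>> ffw = feed_forward_layer(dim=512, mult=4, activation='gelu')
# >>> ffw(torch.randn(2, 64, 512)).shape
# torch.Size([2, 64, 512])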


class PerceiverAttentionLayer(nn.Module):
    """Perceiver attention layer: learned latents cross-attend to visual features."""

    def __init__(self, dim: int, dim_head: int = 64, heads: int = 8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads

        # Trainable components of PerceiverAttentionLayer
        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, features, latents):
        """Cross-attend the latent vectors to the visual features.

        Args:
            features: Batch of visual features with shape (batch_size, n_tokens, dim)
            latents: Learnt latent vectors, used to compute the queries, with shape
                (batch_size, n_latents, dim)

        Returns:
            Attention output with shape (batch_size, n_latents, dim)
        """
        assert features.ndim == 3
        assert latents.ndim == 3
        assert features.shape[0] == latents.shape[0]
        assert features.shape[2] == latents.shape[2]

        n_heads = self.heads
        n_batch, n_features, dim = features.shape
        n_queries = latents.shape[1]

        # Layer normalization
        x = self.norm_media(features)
        latents = self.norm_latents(latents)

        # Compute the queries from the latents, for all attention heads simultaneously
        q = self.to_q(latents)
        q = rearrange(q, 'b q (h d) -> b h q d', h=n_heads)
        assert q.shape == torch.Size([n_batch, n_heads, n_queries, self.dim_head])

        # Keys and values are computed from the visual features concatenated with
        # the latents, so the latents also attend to themselves
        kv_input = torch.cat((x, latents), dim=-2)
        n_features_latents = n_features + n_queries
        k = self.to_k(kv_input)
        v = self.to_v(kv_input)
        k, v = rearrange_many((k, v), 'b f (h d) -> b h f d', h=n_heads)
        assert v.shape == torch.Size([n_batch, n_heads, n_features_latents, self.dim_head])

        q = q * self.scale

        # Attention scores, with the row-wise max subtracted before the softmax
        # for numerical stability (this does not change the result)
        sim = einsum('b h q d, b h f d -> b h q f', q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        alphas = sim.softmax(dim=-1)

        out = einsum('b h q f, b h f v -> b h q v', alphas, v)
        out = rearrange(out, 'b h q v -> b q (h v)')
        return self.to_out(out)
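
# Usage sketch: 64 latent queries attend over 256 visual tokens; all shapes
# below are illustrative assumptions.
# >>> attn = PerceiverAttentionLayer(dim=512, dim_head=64, heads=8)
# >>> features = torch.randn(2, 256, 512)  # visual tokens
# >>> latents = torch.randn(2, 64, 512)    # latent queries
# >>> attn(features, latents).shape
# torch.Size([2, 64, 512])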


class GeckoResamplerProjector(nn.Module):
    """Perceiver resampler with multi-head attention layers."""

    def __init__(
        self,
        config: GeckoConfig,
        num_queries: int = 64,
        depth: int = 2,
        dim_head: int = 32,
        heads: int = 4,
        ff_mult: int = 2,
    ):
        super().__init__()
        self.dim = config.text_config.hidden_size
        self.num_queries = num_queries

        # Learned latent queries, shared across the batch
        self.latents = nn.Parameter(torch.randn(self.num_queries, self.dim))
        # Project the visual features into the text embedding space
        self.linear = nn.Linear(config.vision_config.hidden_size, self.dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttentionLayer(dim=self.dim, dim_head=dim_head, heads=heads),
                        feed_forward_layer(dim=self.dim, mult=ff_mult, activation=config.projector_hidden_act),
                    ]
                )
            )

        # Final layer normalization over the output embedding dimension
        self.norm = nn.LayerNorm(self.dim)

    def forward(self, x_f: torch.Tensor):
        """Run the perceiver resampler on the input visual embeddings.

        Args:
            x_f: Input visual embeddings of shape (batch_size, num_tokens, d_visual)

        Returns:
            Resampled features of shape (batch_size, num_queries, d_text), where
            d_text is the text model's hidden size
        """
        assert x_f.ndim == 3

        x_f = self.linear(x_f)
        batch_size, num_tokens, dim = x_f.shape
        assert dim == self.dim

        # Copy the latents for every element in the batch
        x = repeat(self.latents, 'q d -> b q d', b=batch_size)

        # Apply the attention and feed-forward layers with residual connections
        for attn, ffw in self.layers:
            x = x + attn(x_f, x)
            x = x + ffw(x)

        assert x.shape == torch.Size([batch_size, self.num_queries, self.dim])
        return self.norm(x)
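
# Usage sketch: compress a variable number of visual tokens down to
# `num_queries` tokens in the text embedding space. The 576 tokens (a 24x24
# patch grid) are an illustrative assumption; `config` is a GeckoConfig.
# >>> resampler = GeckoResamplerProjector(config, num_queries=64)
# >>> out = resampler(torch.randn(2, 576, config.vision_config.hidden_size))
# >>> out.shape  # (2, 64, config.text_config.hidden_size)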


class GeckoMLPProjector(nn.Module):
    """Simple two-layer MLP projector from the vision to the text embedding space."""

    def __init__(self, config: GeckoConfig):
        super().__init__()
        self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
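

if __name__ == '__main__':
    # Minimal smoke test: a sketch, not part of the original module, and it
    # needs the package context for the relative import above. The
    # SimpleNamespace only mimics the three GeckoConfig attributes the
    # projectors read; all sizes are illustrative assumptions.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        text_config=SimpleNamespace(hidden_size=512),
        vision_config=SimpleNamespace(hidden_size=768),
        projector_hidden_act='gelu',
    )
    vision_tokens = torch.randn(2, 576, 768)  # e.g. a 24x24 ViT patch grid

    resampler = GeckoResamplerProjector(cfg, num_queries=64)
    print(resampler(vision_tokens).shape)  # torch.Size([2, 64, 512])

    mlp = GeckoMLPProjector(cfg)
    print(mlp(vision_tokens).shape)  # torch.Size([2, 576, 512])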