# This code is referenced from https://github.com/dhansmair/flamingo-mini
import torch
from einops import rearrange, repeat
from einops_exts import rearrange_many
from torch import einsum, nn
import math
import torch.nn.functional as F
from .configuration_gecko import GeckoConfig
from transformers.activations import ACT2FN
from torch.nn.init import trunc_normal_
from functools import partial


def feed_forward_layer(dim: int, mult: int = 4, activation: str = 'gelu'):
    """Feed-forward layer with the given activation function."""
    activations = dict(gelu=nn.GELU, relu=nn.ReLU)
    assert activation in activations, f'activation can only be one of {activations.keys()}'

    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        activations[activation](),
        nn.Linear(inner_dim, dim, bias=False),
    )


class PerceiverAttentionLayer(nn.Module):
    """Perceiver attention layer."""

    def __init__(self, dim: int, dim_head: int = 64, heads: int = 8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head

        inner_dim = dim_head * heads

        # Trainable components of PerceiverAttentionLayer
        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, features, latents):
        """Latent vectors cross-attend to the visual features.

        Args:
            features: Batch of visual features with shape (batch_size, n_tokens, dim)
            latents: Learnt latent vectors used to compute the queries, with shape (batch_size, n_latents, dim)

        Returns:
            Attention output with shape (batch_size, n_latents, dim)
        """
        assert features.ndim == 3
        assert latents.ndim == 3
        assert features.shape[0] == latents.shape[0]
        assert features.shape[2] == latents.shape[2]

        n_heads = self.heads
        n_batch, n_features, dim = features.shape
        n_queries = latents.shape[1]

        # Layer normalization
        x = self.norm_media(features)
        latents = self.norm_latents(latents)

        # Compute the queries from the latents, for all attention heads simultaneously
        q = self.to_q(latents)
        q = rearrange(q, 'b q (h d) -> b h q d', h=n_heads)
        assert q.shape == torch.Size([n_batch, n_heads, n_queries, self.dim_head])

        # Keys and values are computed from the visual features concatenated with the latents
        kv_input = torch.cat((x, latents), dim=-2)
        n_features_latents = n_features + n_queries

        k = self.to_k(kv_input)
        v = self.to_v(kv_input)
        k, v = rearrange_many((k, v), 'b f (h d) -> b h f d', h=n_heads)
        assert v.shape == torch.Size([n_batch, n_heads, n_features_latents, self.dim_head])

        q = q * self.scale

        # Attention scores; subtract the per-row max for numerical stability before the softmax
        sim = einsum('b h q d, b h f d -> b h q f', q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        alphas = sim.softmax(dim=-1)

        # Weighted sum of the values, then merge the heads back into a single dimension
        out = einsum('b h q f, b h f v -> b h q v', alphas, v)
        out = rearrange(out, 'b h q v -> b q (h v)')
        return self.to_out(out)


class GeckoResamplerProjector(nn.Module):
    """Perceiver Resampler with multi-head attention layers."""

    def __init__(
        self,
        config: GeckoConfig,
        num_queries: int = 64,
        depth: int = 2,
        dim_head: int = 32,
        heads: int = 4,
        ff_mult: int = 2,
    ):
        super().__init__()
        self.dim = config.text_config.hidden_size
        self.num_queries = num_queries

        self.latents = nn.Parameter(torch.randn(self.num_queries, self.dim))  # type: ignore[reportPrivateUsage]
        self.linear = nn.Linear(config.vision_config.hidden_size, self.dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttentionLayer(dim=self.dim, dim_head=dim_head, heads=heads),
                        feed_forward_layer(dim=self.dim, mult=ff_mult, activation=config.projector_hidden_act),
                    ]
                )
            )

        # Final layer normalization over the text hidden dimension
        self.norm = nn.LayerNorm(self.dim)

    def forward(self, x_f: torch.Tensor):
        """Run the perceiver resampler on the input visual embeddings.

        Args:
            x_f: Input visual embeddings of shape (batch_size, num_tokens, d_visual)

        Returns:
            Resampled features of shape (batch_size, num_queries, d_text), where d_text is the text hidden size
        """
        assert x_f.ndim == 3

        # Project the visual features into the text hidden dimension
        x_f = self.linear(x_f)
        batch_size, num_tokens, dim = x_f.shape
        assert dim == self.dim

        # Copy the latents for every element in the batch
        x = repeat(self.latents, 'q d -> b q d', b=batch_size)

        # Apply the attention and feed-forward layers with residual connections
        for attn, ffw in self.layers:
            x = x + attn(x_f, x)
            x = x + ffw(x)

        assert x.shape == torch.Size([batch_size, self.num_queries, self.dim])

        norm = self.norm(x)
        return norm


class GeckoMLPProjector(nn.Module):
    """Two-layer MLP that projects visual features into the text embedding space."""

    def __init__(self, config: GeckoConfig):
        super().__init__()
        self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
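

# --- Minimal usage sketch (illustrative only, not part of the module API) ---
# The block below is a shape-check example, not the project's own test code. It
# assumes a config object exposing vision_config.hidden_size,
# text_config.hidden_size and projector_hidden_act, as GeckoConfig does; a
# SimpleNamespace stands in for it here. Because this module uses a relative
# import, the guard only makes sense when the file is executed as a module
# within its package (e.g. via `python -m ...`).
if __name__ == '__main__':
    from types import SimpleNamespace

    # Dummy stand-in for GeckoConfig with only the attributes accessed above;
    # the hidden sizes are arbitrary example values.
    dummy_config = SimpleNamespace(
        vision_config=SimpleNamespace(hidden_size=1024),
        text_config=SimpleNamespace(hidden_size=2048),
        projector_hidden_act='gelu',
    )

    # Fake batch of vision-encoder outputs: (batch, num vision tokens, d_visual)
    image_features = torch.randn(2, 576, 1024)

    # The resampler compresses the token sequence down to num_queries latents
    resampler = GeckoResamplerProjector(dummy_config, num_queries=64)
    print(resampler(image_features).shape)  # torch.Size([2, 64, 2048])

    # The MLP projector keeps the token count and only changes the hidden size
    mlp_projector = GeckoMLPProjector(dummy_config)
    print(mlp_projector(image_features).shape)  # torch.Size([2, 576, 2048])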