import inspect

import torch
from einops import rearrange
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from .helpers import PerceiverResampler


def unwrap_fsdp(m):
    """Recursively unwrap FSDP-wrapped modules to reach the underlying module."""
    if isinstance(m, FSDP):
        return unwrap_fsdp(m.module)
    return m


def accepts_parameter(func, parameter_name):
    """Return True if `func`'s signature accepts a parameter named `parameter_name`."""
    signature = inspect.signature(func)
    return parameter_name in signature.parameters
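
# Illustrative use of the helpers above (hypothetical `encoder`): for an
# encoder defined with `def forward(self, x, return_all_features=False)`,
# `accepts_parameter(encoder.forward, "return_all_features")` returns True and
# `accepts_parameter(encoder.forward, "output_attentions")` returns False.
# `unwrap_fsdp` matters because FSDP replaces the wrapped module's forward with
# a generic `*args, **kwargs` wrapper, so the signature must be read from the
# inner module.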

class Flamingo(nn.Module):
    def __init__(
        self,
        vision_encoder: nn.Module,
        lang_encoder: nn.Module,
        eoc_token_id: int,
        media_token_id: int,
        vis_dim: int,
        cross_attn_every_n_layers: int = 1,
        gradient_checkpointing: bool = False,
        enable_init_network_params: bool = False,
        initializer_range: float = 0.02,
    ):
        """
        Args:
            vision_encoder (nn.Module): HF CLIPModel
            lang_encoder (nn.Module): HF causal language model
            eoc_token_id (int): Token id for <|endofchunk|>
            media_token_id (int): Token id for <image>
            vis_dim (int): Dimension of the visual features. Visual features are
                projected to match this shape along the last dimension.
            cross_attn_every_n_layers (int, optional): How often to apply cross-attention
                after a transformer layer. Defaults to 1.
            gradient_checkpointing (bool, optional): Whether to use gradient checkpointing.
                Defaults to False.
            enable_init_network_params (bool, optional): Whether to explicitly initialize
                the newly added network parameters. Defaults to False.
            initializer_range (float, optional): Std of the initializer used when
                enable_init_network_params is True. Defaults to 0.02.
        """
        super().__init__()
        self.eoc_token_id = eoc_token_id
        self.media_token_id = media_token_id
        self.vis_dim = vis_dim
        if hasattr(lang_encoder.config, "d_model"):
            self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
        else:
            self.lang_dim = lang_encoder.config.hidden_size

        self.vision_encoder = (
            vision_encoder.visual
            if hasattr(vision_encoder, "visual")
            else vision_encoder
        )
        self.perceiver = PerceiverResampler(
            dim=self.vis_dim,
            enable_init_network_params=enable_init_network_params,
            initializer_range=initializer_range,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.lang_encoder = lang_encoder
        self.lang_encoder.init_flamingo(
            media_token_id=media_token_id,
            lang_hidden_size=self.lang_dim,
            vis_hidden_size=self.vis_dim,
            cross_attn_every_n_layers=cross_attn_every_n_layers,
            gradient_checkpointing=gradient_checkpointing,
            enable_init_network_params=enable_init_network_params,
            initializer_range=initializer_range,
        )
        self._use_gradient_checkpointing = gradient_checkpointing
        self.perceiver._use_gradient_checkpointing = gradient_checkpointing

    def forward(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        clear_conditioned_layers: bool = True,
        past_key_values=None,
        use_cache: bool = False,
    ):
        """
        Forward pass of Flamingo.

        Args:
            vision_x (torch.Tensor): Vision input, shape (B, T_img, F, C, H, W) with F=1
            lang_x (torch.Tensor): Language input ids, shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            labels (torch.Tensor, optional): Labels. Defaults to None.
            clear_conditioned_layers (bool, optional): If True, clear the conditioned
                layers once the forward pass is completed. Set this to False if the
                same set of images will be reused in a subsequent forward pass.
            past_key_values: Pre-computed values to pass to the language model.
                See past_key_values documentation in Hugging Face CausalLM models.
            use_cache (bool, optional): Whether to use cached key values.
                See use_cache documentation in Hugging Face CausalLM models.
        """
        assert (
            self.lang_encoder.initialized_flamingo
        ), "Flamingo layers are not initialized. Please call `init_flamingo` first."
        assert (
            self.lang_encoder._use_cached_vision_x or vision_x is not None
        ), "Must provide either vision_x or have precached media using cache_media()."

        if self.lang_encoder._use_cached_vision_x:
            # Case: use cached media; vision_x should already be cached and no new
            # vision-related inputs should be provided.
            assert vision_x is None, (
                "Expect vision_x to be None when media has been cached using"
                " cache_media(). Try uncache_media() first."
            )
            assert self.lang_encoder.is_conditioned()
        else:
            # Case: do not use caching (i.e. this is a standard forward pass).
            self._encode_vision_x(vision_x=vision_x)
            self._condition_media_locations(input_ids=lang_x)

        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()

        return output
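
    # Illustrative forward call (hypothetical `model` and `tokenizer`, assuming a
    # 224x224 vision encoder and a tokenizer that maps <image> to media_token_id):
    #
    #   vision_x = torch.randn(2, 1, 1, 3, 224, 224)  # (B, T_img, F, C, H, W), F=1
    #   batch = tokenizer(["<image> a photo of a cat <|endofchunk|>"] * 2,
    #                     return_tensors="pt")
    #   out = model(
    #       vision_x=vision_x,
    #       lang_x=batch["input_ids"],
    #       attention_mask=batch["attention_mask"],
    #       labels=batch["input_ids"],
    #   )
    #   out.loss.backward()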
    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):
        """
        Generate text conditioned on vision and language inputs.

        Args:
            vision_x (torch.Tensor): Vision input, shape (B, T_img, F, C, H, W).
                Images in the same chunk are collated along T_img, and frames are
                collated along F. Currently only F=1 is supported (single-frame videos).
            lang_x (torch.Tensor): Language input, shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            **kwargs: see generate documentation in Hugging Face CausalLM models.
                Some notable kwargs:
                    max_length (int, optional): Maximum length of the output. Defaults to None.
                    num_beams (int, optional): Number of beams. Defaults to 1.
                    max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
                    temperature (float, optional): Temperature. Defaults to 1.0.
                    top_k (int, optional): Top k. Defaults to 50.
                    top_p (float, optional): Top p. Defaults to 1.0.
                    no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
                    length_penalty (float, optional): Length penalty. Defaults to 1.0.
                    num_return_sequences (int, optional): Number of return sequences.
                        Defaults to 1.
                    do_sample (bool, optional): Do sample. Defaults to False.
                    early_stopping (bool, optional): Early stopping. Defaults to False.

        Returns:
            torch.Tensor: lang_x with generated tokens appended to it
        """
        num_beams = kwargs.pop("num_beams", 1)
        if num_beams > 1:
            # Beam search runs num_beams hypotheses per sample, so the vision
            # inputs must be replicated to match the expanded batch.
            vision_x = vision_x.repeat_interleave(num_beams, dim=0)

        self.lang_encoder._use_cached_vision_x = True
        self._encode_vision_x(vision_x=vision_x)

        # eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
        output = self.lang_encoder.generate(
            input_ids=lang_x,
            attention_mask=attention_mask,
            # eos_token_id=eos_token_id,
            num_beams=num_beams,
            **kwargs,
        )

        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_vision_x = False
        return output

    def _encode_vision_x(self, vision_x: torch.Tensor):
        """
        Compute media tokens from vision input by passing it through the vision
        encoder and conditioning the language model.

        Args:
            vision_x (torch.Tensor): Vision input, shape (B, T_img, F, C, H, W).
                Images in the same chunk are collated along T_img, and frames are
                collated along F. Currently only F=1 is supported (single-frame videos).

        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        b, T, F = vision_x.shape[:3]
        assert F == 1, "Only single frame supported"

        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
        with torch.no_grad():
            # FSDP hides the wrapped module's forward signature, so inspect the
            # unwrapped module to see whether it supports return_all_features.
            module_to_inspect = unwrap_fsdp(self.vision_encoder)
            if accepts_parameter(module_to_inspect.forward, "return_all_features"):
                vision_x = self.vision_encoder(vision_x, return_all_features=True)
            else:
                vision_x = self.vision_encoder(vision_x)[1]
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)

        vision_x = self.perceiver(vision_x)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)

    def _condition_media_locations(self, input_ids: torch.Tensor):
        """
        Compute the media token locations from lang_x and condition the language
        model on these.

        Args:
            input_ids (torch.Tensor): Language input ids, shape (B, T_txt)
        """
        media_locations = input_ids == self.media_token_id

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_media_locations(media_locations)

    def cache_media(self, input_ids: torch.Tensor, vision_x: torch.Tensor):
        """
        Pre-cache a prompt/sequence of images/text for log-likelihood evaluations.
        All subsequent calls to forward() will generate attending to the LAST
        image in vision_x. This is not meant to be used to cache things for
        generate().

        Args:
            input_ids (torch.Tensor): Language input ids, shape (B, T_txt)
            vision_x (torch.Tensor): Vision input, shape (B, T_img, F, C, H, W).
                Images in the same chunk are collated along T_img, and frames are
                collated along F. Currently only F=1 is supported (single-frame videos).
        """
        self._encode_vision_x(vision_x=vision_x)
        self._condition_media_locations(input_ids=input_ids)
        self.lang_encoder._use_cached_vision_x = True

    def uncache_media(self):
        """Clear all conditioning."""
        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_vision_x = False
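
# Minimal sketch of the pre-caching flow for log-likelihood scoring
# (hypothetical `model`, `context_ids`, `vision_x`, and `candidates`):
#
#   model.cache_media(input_ids=context_ids, vision_x=vision_x)
#   for continuation_ids in candidates:
#       out = model(
#           vision_x=None,  # media is already cached, so no new vision input
#           lang_x=continuation_ids,
#           clear_conditioned_layers=False,  # keep conditioning across candidates
#       )
#       # ...score continuation tokens from out.logits...
#   model.uncache_media()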