from dataclasses import dataclass, field import inspect import logging from typing import Optional, List, Union, Dict, Tuple, Any from transformers.configuration_utils import PretrainedConfig import mlx.core as mx # Define a custom float tensor type using the provided data type class FloatTensor: def __init__(self, data): if data is not None: self.tensor = mx.array(data, dtype=mx.float32) else: self.tensor = None def __repr__(self): return repr(self.tensor) # Define a custom LongTensor class class LongTensor: def __init__(self, data=None): if data is not None: self.tensor = mx.array(data, dtype=mx.int64) else: self.tensor = None def assign(self, data): self.tensor = mx.array(data, dtype=mx.int64) def __repr__(self): return repr(self.tensor) @dataclass class BaseModelOutputWithPast: """ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). Args: last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ last_hidden_state: FloatTensor = None past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None hidden_states: Optional[Tuple[FloatTensor, ...]] = None attentions: Optional[Tuple[FloatTensor, ...]] = None @dataclass class Cache: """ Base, abstract class for all caches. The actual data structure is specific to each subclass. """ def update( self, key_states: mx.array, value_states: mx.array, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[mx.array, mx.array]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Parameters: key_states (`mx.array`): The new key states to cache. value_states (`mx.array`): The new value states to cache. layer_idx (`int`): The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. These are specific to each subclass and allow new types of cache to be created. Return: A tuple containing the updated key and value states. """ raise NotImplementedError("Make sure to implement `update` in a subclass.") def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # TODO: deprecate this function in favor of `cache_position` raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states, if there is any.""" raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.") def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: """Given the sequence length of the new inputs, returns the usable length of the cache.""" # Cache without size limit -> all cache is usable # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache # length, we will need to evict part of the cache (and thus not all cache is usable) max_length = self.get_max_length() previous_seq_length = self.get_seq_length(layer_idx) if max_length is not None and previous_seq_length + new_seq_length > max_length: return max_length - new_seq_length return previous_seq_length # def reorder_cache(self, beam_idx: LongTensor): # """Reorders the cache for beam search, given the selected beam indices.""" # for layer_idx in range(len(self.key_cache)): # device = self.key_cache[layer_idx].device # self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) # device = self.value_cache[layer_idx].device # self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) @property def seen_tokens(self): logging.warning( "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " "model input instead." ) if hasattr(self, "_seen_tokens"): return self._seen_tokens else: return None class DynamicCache(Cache): """ A cache that grows dynamically as more tokens are generated. This is the default for generative models. It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]`. """ def __init__(self) -> None: self.key_cache: List[mx.array] = [] self.value_cache: List[mx.array] = [] self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen def __getitem__(self, layer_idx: int) -> List[Tuple[mx.array]]: """ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the sequence length. """ if layer_idx < len(self): return (self.key_cache[layer_idx], self.value_cache[layer_idx]) else: raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") def __iter__(self): """ Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over keys and values """ for layer_idx in range(len(self)): yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) def __len__(self): """ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds to the number of layers in the model. """ return len(self.key_cache) def update( self, key_states: mx.array, value_states: mx.array, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[mx.array, mx.array]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Parameters: key_states (`mx.array`): The new key states to cache. value_states (`mx.array`): The new value states to cache. layer_idx (`int`): The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. Return: A tuple containing the updated key and value states. """ # Update the number of seen tokens if layer_idx == 0: self._seen_tokens += key_states.shape[-2] # Update the cache if len(self.key_cache) <= layer_idx: self.key_cache.append(key_states) self.value_cache.append(value_states) else: self.key_cache[layer_idx] = mx.concatenate([self.key_cache[layer_idx], key_states], dim=-2) self.value_cache[layer_idx] = mx.concatenate([self.value_cache[layer_idx], value_states], dim=-2) return self.key_cache[layer_idx], self.value_cache[layer_idx] def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # TODO: deprecate this function in favor of `cache_position` if len(self.key_cache) <= layer_idx: return 0 return self.key_cache[layer_idx].shape[-2] def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" return None def to_legacy_cache(self) -> Tuple[Tuple[mx.array], Tuple[mx.array]]: """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format.""" legacy_cache = () for layer_idx in range(len(self)): legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) return legacy_cache @classmethod def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None) -> "DynamicCache": """Converts a cache in the legacy cache format into an equivalent `DynamicCache`.""" cache = cls() if past_key_values is not None: for layer_idx in range(len(past_key_values)): key_states, value_states = past_key_values[layer_idx] cache.update(key_states, value_states, layer_idx) return cache @dataclass class CausalLMOutputWithPast(): loss: Optional[FloatTensor] = None logits: FloatTensor = None past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None hidden_states: Optional[Tuple[FloatTensor, ...]] = None attentions: Optional[Tuple[FloatTensor, ...]] = None