# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

from utils import FloatTensor
import mlx.core as mx


# Most negative finite value per floating-point dtype, keyed by the bare dtype
# name (the last dotted component of ``str(dtype)``).
_FINFO_MIN_BY_NAME = {
    "float16": -65504.0,
    "bfloat16": -3.3895313892515355e+38,
    "float32": -3.4028235e+38,
    "float64": -1.7976931348623157e+308,
}


# Custom function to mimic torch.finfo
def get_finfo_min(dtype: mx.Dtype):
    """
    Return the most negative finite value representable in `dtype`.

    The dtype is identified by the last dotted component of `str(dtype)`, so both a bare name such as `"float32"`
    and a qualified one such as `"mlx.core.float32"` are recognised.

    Raises:
        ValueError: if `dtype` is not one of float16, bfloat16, float32 or float64.
    """
    dtype_str = str(dtype)
    dtype_name = dtype_str.rsplit(".", 1)[-1]
    try:
        return _FINFO_MIN_BY_NAME[dtype_name]
    except KeyError:
        raise ValueError(f"Unsupported data type: {dtype_str}") from None


@dataclass
class AttentionMaskConverter:
    """
    Builds additive 4D attention masks (0 where attention is allowed, the dtype's most negative finite value where
    it is blocked) from 2D padding masks and/or causal / sliding-window constraints.

    Args:
        is_causal (`bool`):
            Whether the converter produces causal (uni-directional) masks.
        sliding_window (`int`, *optional*):
            If set, must be strictly positive; it is used to additionally block keys that lie too far in the past
            of each query.
    """

    is_causal: bool
    sliding_window: Optional[int]

    def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
        self.is_causal = is_causal
        self.sliding_window = sliding_window

        if self.sliding_window is not None and self.sliding_window <= 0:
            raise ValueError(
                f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
            )

    def to_causal_4d(
        self,
        batch_size: int,
        query_length: int,
        key_value_length: int,
        dtype: mx.Dtype,
        device: Union[mx.Device, "str"] = "cpu",
    ) -> Optional[mx.array]:
        """
        Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
        bias to upper right hand triangular matrix (causal mask).

        Returns `None` when `query_length` is 1 and no sliding window is configured, since no mask is built in that
        case. `device` is accepted for interface compatibility only and is not used.

        Raises:
            ValueError: if the converter was created with `is_causal=False`.
        """
        if not self.is_causal:
            raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")

        input_shape = (batch_size, query_length)
        past_key_values_length = key_value_length - query_length

        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if input_shape[-1] > 1 or self.sliding_window is not None:
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )

        return causal_4d_mask

    def to_4d(
        self,
        attention_mask_2d: mx.array,
        query_length: int,
        dtype: mx.Dtype,
        key_value_length: Optional[int] = None,
    ) -> mx.array:
        """
        Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
        key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
        causal, a causal mask will be added.

        Raises:
            ValueError: if the converter is causal, a causal mask is needed and `key_value_length` is `None`.
            NotImplementedError: if a sliding window is configured but no causal mask is built.
        """
        input_shape = (attention_mask_2d.shape[0], query_length)

        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
            if key_value_length is None:
                raise ValueError(
                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
                )

            past_key_values_length = key_value_length - query_length
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )
        elif self.sliding_window is not None:
            raise NotImplementedError("Sliding window is currently only implemented for causal masking")

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1])

        if causal_4d_mask is not None:
            # Select rather than add: summing the padding mask and the causal mask could overflow where both
            # already hold the minimum value.
            min_value = mx.array(get_finfo_min(dtype), dtype=dtype)
            expanded_attn_mask = mx.where(expanded_attn_mask != 0, min_value, causal_4d_mask).astype(dtype)

        return expanded_attn_mask

    @staticmethod
    def _make_causal_mask(
        input_ids_shape: Tuple[int, int],
        dtype: mx.Dtype,
        device: Optional[mx.Device] = None,
        past_key_values_length: int = 0,
        sliding_window: Optional[int] = None,
    ):
        """
        Make the causal mask used for uni-directional (causal) self-attention.

        Returns an array of shape `(bsz, 1, tgt_len, tgt_len + past_key_values_length)` holding 0 where a query may
        attend to a key and the most negative finite value of `dtype` elsewhere. `device` is accepted for interface
        compatibility only and is not used.
        """
        bsz, tgt_len = input_ids_shape
        min_value = mx.array(get_finfo_min(dtype), dtype=dtype)
        zero = mx.array(0, dtype=dtype)

        # Within the current chunk, a key is blocked when it lies strictly after the query.
        positions = mx.arange(tgt_len)
        mask = mx.where(positions[None, :] > positions[:, None], min_value, zero)

        if past_key_values_length > 0:
            # Keys coming from the cache precede every query, so they are always visible.
            past_mask = mx.zeros((tgt_len, past_key_values_length), dtype=dtype)
            mask = mx.concatenate([past_mask, mask], axis=-1)

        # add lower triangular sliding window mask if necessary
        if sliding_window is not None:
            diagonal = past_key_values_length - sliding_window - 1

            # Same selection as a lower-triangular matrix with offset `diagonal`: col - row <= diagonal.
            rows = mx.arange(mask.shape[0])[:, None]
            cols = mx.arange(mask.shape[1])[None, :]
            context_mask = (cols - rows) <= diagonal
            mask = mx.where(context_mask, min_value, mask)

        mask = mask.astype(dtype)
        return mx.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))

    @staticmethod
    def _expand_mask(mask: mx.array, dtype: mx.Dtype, tgt_len: Optional[int] = None):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

        Entries equal to 1 in `mask` become 0 in the result; every other entry becomes the most negative finite
        value of `dtype`. `tgt_len` defaults to the source length.
        """
        bsz, src_len = mask.shape
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = mx.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)).astype(dtype)

        inverted_mask = mx.array(1, dtype=dtype) - expanded_mask
        min_value = mx.array(get_finfo_min(dtype), dtype=dtype)

        return mx.where(inverted_mask != 0, min_value, inverted_mask).astype(dtype)

    @staticmethod
    def _unmask_unattended(
        expanded_mask: FloatTensor,
        min_dtype: float,
    ):
        # fmt: off
        """
        Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
        using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        Details: https://github.com/pytorch/pytorch/issues/110213

        `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
        `attention_mask` is [bsz, src_seq_len].

        The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.

        For example, if `expanded_mask` is (e.g. here left-padding case)
        ```
        [[[[0, 0, 0],
           [0, 0, 0],
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[0, 0, 0],
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        then the modified `expanded_mask` will be
        ```
        [[[[1, 1, 1],   <-- modified
           [1, 1, 1],   <-- modified
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[1, 1, 1],   <-- modified
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        """
        # fmt: on
        if expanded_mask.dtype == mx.bool_:
            raise ValueError(
                "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
            )

        # A row whose entries all equal `min_dtype` is fully masked; zero it so that every key is attended.
        fully_masked = mx.all(expanded_mask == min_dtype, axis=-1, keepdims=True)
        return mx.where(fully_masked, mx.zeros_like(expanded_mask), expanded_mask)


def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[mx.array],
    input_shape: Union[mx.array, Tuple, List],
    inputs_embeds: mx.array,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`

    Args:
        attention_mask (`mx.array` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        input_shape (`tuple(int)` or `list(int)`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        inputs_embeds (`mx.array`):
            The embedded inputs as an MLX array. Only its dtype is used.
        past_key_values_length (`int`):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.

    Raises:
        ValueError: if a 4D `attention_mask` does not have shape
            `(batch_size, 1, query_length, key_value_length)`.
    """
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    dtype = inputs_embeds.dtype
    # Coerce to plain ints so that shape arithmetic works even if `input_shape` holds array scalars.
    batch_size = int(input_shape[0])
    query_length = int(input_shape[-1])
    key_value_length = query_length + past_key_values_length

    # 4d mask is passed through the layers
    if attention_mask is not None and len(attention_mask.shape) == 2:
        attention_mask = attn_mask_converter.to_4d(
            attention_mask, query_length, key_value_length=key_value_length, dtype=dtype
        )
    elif attention_mask is not None and len(attention_mask.shape) == 4:
        expected_shape = (batch_size, 1, query_length, key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        else:
            # if the 4D mask has correct shape - invert it and fill with the dtype's most negative finite value
            inverted_mask = mx.array(1, dtype=dtype) - attention_mask.astype(dtype)
            min_value = mx.array(get_finfo_min(dtype), dtype=dtype)
            attention_mask = mx.where(inverted_mask != 0, min_value, inverted_mask).astype(dtype)
    else:
        attention_mask = attn_mask_converter.to_causal_4d(batch_size, query_length, key_value_length, dtype=dtype)

    return attention_mask