"""SnapKV monkey-patch for Llama flash-attention (written against Transformers 4.37).

The module provides:

* ``llama_flash_attn2_forward`` - a replacement ``forward`` for
  ``LlamaFlashAttention2`` that compresses the prompt KV cache with
  ``SnapKVCluster`` before it is stored.
* ``prepare_inputs_for_generation_llama`` - a replacement
  ``prepare_inputs_for_generation`` that resets the per-layer length counter
  at the start of every generation.
* ``replace_llama`` - installs both replacements on the Transformers classes.
"""

import math
import time
import warnings
from importlib.metadata import PackageNotFoundError, version
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers.cache_utils import Cache, DynamicCache
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.utils import logging

logger = logging.get_logger(__name__)

# Defaults written onto ``self.config`` by ``init_snapkv`` for every attribute
# the config does not already define.
_SNAPKV_CONFIG_DEFAULTS = {
    "window_size": 64,
    "max_capacity_prompt": 4096,
    "kernel_size": 13,
    "pooling": "avgpool",
}


# https://github.com/huggingface/transformers/blob/v4.37-release/src/transformers/models/llama/modeling_llama.py
def llama_flash_attn2_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Flash-attention forward pass with SnapKV compression of the prompt cache.

    When the keys of this call cover the whole sequence (``key_states.shape[-2]
    == kv_seq_len``) the compressed keys/values returned by
    ``self.kv_cluster.update_kv`` are written to ``past_key_value`` while
    attention for this call still runs on the uncompressed states. Otherwise
    the new keys/values are appended to the cache and attention runs on the
    states returned by ``past_key_value.update``.

    Side effects: creates ``self.kv_cluster`` on first use (via
    ``init_snapkv``) and maintains ``self.kv_seq_len``, a running count of the
    tokens seen by this layer.

    Args:
        hidden_states: Input unpacked as ``(bsz, q_len, _)``.
        attention_mask: Forwarded to ``self._flash_attention_forward`` and to
            ``update_kv``.
        position_ids: Forwarded to ``apply_rotary_pos_emb``.
        past_key_value: Cache object updated in place; may be ``None``.
        output_attentions: Ignored; always forced to ``False``.
        use_cache: Unused by this function.
        **kwargs: Only the deprecated ``padding_mask`` key is consumed.

    Returns:
        ``(attn_output, None, past_key_value)``.

    Raises:
        ValueError: If a cache is supplied but ``self.layer_idx`` is ``None``.
    """
    # [SnapKV] register kv_cluster
    init_snapkv(self)

    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )
        # overwrite attention_mask with padding_mask
        attention_mask = kwargs.pop("padding_mask")

    # LlamaFlashAttention2 attention does not support output_attentions
    output_attentions = False

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Split the projections into heads: (bsz, heads, q_len, head_dim).
    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        # [SnapKV] Prefer the layer's own running counter when it is set and
        # non-zero; otherwise fall back to the length reported by the cache.
        if getattr(self, "kv_seq_len", 0) != 0:
            kv_seq_len += self.kv_seq_len
        else:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )

    # [SnapKV] KV heads are repeated before the cache update so that
    # update_kv sees keys and queries with the same number of heads.
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        if key_states.shape[-2] == kv_seq_len:
            # [SnapKV] The keys cover the full sequence: store the compressed
            # KV in the cache and keep the full states for this call.
            self.kv_seq_len = kv_seq_len  # [SnapKV] register kv_seq_len
            key_states_compress, value_states_compress = self.kv_cluster.update_kv(
                key_states,
                query_states,
                value_states,
                attention_mask,
                self.num_key_value_groups,
            )
            past_key_value.update(
                key_states_compress, value_states_compress, self.layer_idx, cache_kwargs
            )
        else:
            self.kv_seq_len += q_len
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

    # TODO: These transpose are quite inefficient but Flash Attention requires the layout
    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
    # to be able to avoid many of these transpose/reshape/view.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in the correct dtype just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
    # in fp32. (LlamaRMSNorm handles it correctly)
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}."
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    attn_output = self._flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        q_len,
        dropout=dropout_rate,
    )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)

    # Attention weights are never produced here (output_attentions is forced
    # to False above), so bind the name unconditionally.
    attn_weights = None

    return attn_output, attn_weights, past_key_value


def prepare_inputs_for_generation_llama(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    **kwargs,
):
    """Build the model inputs for one generation step.

    When ``past_key_values`` is ``None`` every layer's
    ``self_attn.kv_seq_len`` counter is reset to 0. For a non-``Cache``
    ``past_key_values`` the past length is read from the first layer's
    ``kv_seq_len`` counter rather than from the cached tensors.

    Returns:
        A dict with ``input_ids`` (or ``inputs_embeds`` when they are given and
        there is no cache), ``position_ids``, ``past_key_values``,
        ``use_cache`` and ``attention_mask``.
    """
    if past_key_values is None:  # [SnapKV]
        for layer in self.model.layers:
            layer.self_attn.kv_seq_len = 0

    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # [SnapKV] use the layer's running token counter as the past length.
            cache_length = past_length = self.model.layers[0].self_attn.kv_seq_len
            max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs


llama_flash_attn2_forward_4_37 = llama_flash_attn2_forward
prepare_inputs_for_generation_llama_4_37 = prepare_inputs_for_generation_llama


@torch.no_grad()
def rope_forward(self, x, seq_len):
    """Compute rotary-embedding ``cos``/``sin`` tables for positions ``0..seq_len-1``.

    ``x`` is used only for its device and dtype. The tables are computed in
    float32 with autocast disabled and cast to ``x.dtype`` on return.
    """
    # x: [bs, num_attention_heads, seq_len, head_size]
    position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0)
    inv_freq_expanded = (
        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    )
    position_ids_expanded = position_ids[:, None, :].float()
    # Force float32 since bfloat16 loses precision on long contexts
    # See https://github.com/huggingface/transformers/pull/29285
    device_type = x.device.type
    device_type = (
        device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
    )
    with torch.autocast(device_type=device_type, enabled=False):
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
            1, 2
        )
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


##################
# perform qk calculation and get indices
# this version will not update in inference mode


# Copied from transformers.models.llama.modeling_llama.repeat_kv
# NOTE: this definition rebinds the ``repeat_kv`` name imported at the top of
# the file; functions in this module resolve to this copy at call time.
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class SnapKVCluster:
    """Select which prompt key/value positions to keep in the KV cache.

    The last ``window_size`` positions are always kept. From the remaining
    prefix, the ``max_capacity_prompt - window_size`` positions with the
    highest pooled attention (as seen by the last ``window_size`` queries) are
    kept, chosen independently per head.
    """

    def __init__(
        self,
        window_size=64,
        max_capacity_prompt=256 + 64,
        kernel_size=5,
        pooling="avgpool",
    ):
        # Construction and reconfiguration share one code path.
        self.reset(
            window_size=window_size,
            max_capacity_prompt=max_capacity_prompt,
            kernel_size=kernel_size,
            pooling=pooling,
        )

    def reset(
        self,
        window_size=64,
        max_capacity_prompt=256 + 64,
        kernel_size=5,
        pooling="avgpool",
    ):
        """Replace all settings; the capacity must exceed the window size."""
        self.window_size = window_size
        self.max_capacity_prompt = max_capacity_prompt
        assert self.max_capacity_prompt - self.window_size > 0
        self.kernel_size = kernel_size
        self.pooling = pooling

    def update_kv(
        self,
        key_states,
        query_states,
        value_states,
        attention_mask,
        num_key_value_groups,
    ):
        """Return compressed ``(key_states, value_states)`` for a prompt.

        Args:
            key_states: Unpacked alongside the queries as
                ``(bsz, num_heads, q_len, head_dim)``; must have the same
                sequence length as ``query_states``.
            query_states: Queries for the same positions.
            value_states: Values for the same positions.
            attention_mask: Accepted for interface compatibility; not used.
            num_key_value_groups: Accepted for interface compatibility; not used.

        Returns:
            The inputs unchanged when ``q_len < max_capacity_prompt``;
            otherwise keys/values holding the selected prefix positions
            followed by the last ``window_size`` positions.

        Raises:
            ValueError: If compression is needed and ``self.pooling`` is
                neither ``"avgpool"`` nor ``"maxpool"``.
        """
        # check if prefix phase
        assert key_states.shape[-2] == query_states.shape[-2]
        bsz, num_heads, q_len, head_dim = query_states.shape
        if q_len < self.max_capacity_prompt:
            return key_states, value_states

        # Validate the pooling mode before doing the expensive attention
        # computation; the error raised is the same as before.
        if self.pooling == "avgpool":
            pool_fn = F.avg_pool1d
        elif self.pooling == "maxpool":
            pool_fn = F.max_pool1d
        else:
            raise ValueError("Pooling method not supported")

        window = self.window_size

        # Scores of the last `window` queries against every key.
        attn_weights = torch.matmul(
            query_states[..., -window:, :], key_states.transpose(2, 3)
        ) / math.sqrt(head_dim)

        # Causal mask for the window-vs-window corner: 0 on and below the
        # diagonal, the dtype's minimum value above it.
        mask = torch.full(
            (window, window),
            torch.finfo(attn_weights.dtype).min,
            device=attn_weights.device,
        )
        mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
        causal_mask = mask[None, None, :, :]
        attn_weights[:, :, -window:, -window:] += causal_mask

        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)

        # Total attention each prefix position receives from the window queries.
        attn_weights_sum = attn_weights[:, :, -window:, :-window].sum(dim=-2)

        # Pool along the sequence axis so neighbours of a high-scoring
        # position also score highly.
        attn_cache = pool_fn(
            attn_weights_sum,
            kernel_size=self.kernel_size,
            padding=self.kernel_size // 2,
            stride=1,
        )

        indices = attn_cache.topk(self.max_capacity_prompt - window, dim=-1).indices
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)

        k_past_compress = key_states[:, :, :-window, :].gather(dim=2, index=indices)
        v_past_compress = value_states[:, :, :-window, :].gather(dim=2, index=indices)
        k_cur = key_states[:, :, -window:, :]
        v_cur = value_states[:, :, -window:, :]

        key_states = torch.cat([k_past_compress, k_cur], dim=2)
        value_states = torch.cat([v_past_compress, v_cur], dim=2)
        return key_states, value_states


def init_snapkv(self):
    """Attach a ``SnapKVCluster`` to ``self`` as ``kv_cluster`` if it has none.

    Any of ``window_size``, ``max_capacity_prompt``, ``kernel_size`` and
    ``pooling`` missing from ``self.config`` is first set on the config to its
    default (64, 4096, 13 and ``"avgpool"``); the cluster is then built from
    the config values.
    """
    if hasattr(self, "kv_cluster"):
        return

    for attr_name, default_value in _SNAPKV_CONFIG_DEFAULTS.items():
        if not hasattr(self.config, attr_name):
            setattr(self.config, attr_name, default_value)

    self.kv_cluster = SnapKVCluster(
        window_size=self.config.window_size,
        max_capacity_prompt=self.config.max_capacity_prompt,
        kernel_size=self.config.kernel_size,
        pooling=self.config.pooling,
    )


############


def check_version():
    """Return the installed Transformers version string, or ``None``.

    ``None`` is returned (after printing a message) when no ``transformers``
    distribution metadata can be found.
    """
    try:
        return version("transformers")
    except PackageNotFoundError as e:
        print(f"Transformers not installed: {e}")
        return None


def replace_llama():
    """Install the SnapKV replacements on the Transformers Llama classes.

    Warns when the installed Transformers major.minor version is not in the
    tested list; the replacements are installed either way.
    """
    transformers_version = check_version()
    version_list = ["4.37"]

    # Compare on "major.minor" so that, for example, "4.37.2" matches "4.37"
    # while an unrelated version that merely contains the text does not.
    is_tested = False
    if transformers_version is not None:
        major_minor = ".".join(transformers_version.split(".")[:2])
        is_tested = major_minor in version_list

    if not is_tested:
        warnings.warn(
            f"Transformers version {transformers_version} might not be compatible with SnapKV. SnapKV is tested with Transformers version {version_list}."
        )

    transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation = (
        prepare_inputs_for_generation_llama_4_37
    )
    transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward = (
        llama_flash_attn2_forward_4_37
    )