Add _supports_flash_attn_2 to Llama 2 32k

#37
Files changed (1)
  1. modeling_flash_llama.py +1 -0
modeling_flash_llama.py CHANGED
@@ -499,6 +499,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamaDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
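
With this class attribute set, transformers stops rejecting requests for the Flash Attention 2 backend at load time. A minimal usage sketch, assuming the togethercomputer/LLaMA-2-7B-32K model id and a transformers release that accepts the attn_implementation argument (older releases used use_flash_attention_2=True instead):

```python
# Minimal sketch: once _supports_flash_attn_2 = True is set on the model class,
# from_pretrained accepts the "flash_attention_2" attention implementation
# instead of raising a ValueError.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "togethercomputer/LLaMA-2-7B-32K",       # assumed model id for this repo
    torch_dtype=torch.float16,               # FA2 requires fp16 or bf16
    attn_implementation="flash_attention_2",
    trust_remote_code=True,                  # repo ships its own modeling_flash_llama.py
)
```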