Add _supports_flash_attn_2 to Llama 2 32k
#37 opened by arshzahed
modeling_flash_llama.py  (+1 -0)
@@ -499,6 +499,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamaDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
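For context, `_supports_flash_attn_2` is the class attribute that transformers checks before honoring a Flash Attention 2 request at load time; without it, loading with flash attention enabled raises an error. Below is a minimal sketch of how the model would be loaded once this flag is set. The repo id is an assumption for illustration, and it presumes flash-attn is installed and a compatible GPU is available:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed repo id for illustration; substitute the actual
# Llama 2 32k checkpoint this PR targets.
model_id = "togethercomputer/LLaMA-2-7B-32K"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,      # required for the custom modeling_flash_llama.py
    torch_dtype=torch.bfloat16,  # flash-attn supports fp16/bf16 only
    attn_implementation="flash_attention_2",  # older transformers versions: use_flash_attention_2=True
)
```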