x54-729
commited on
Commit
•
2d0920c
1
Parent(s):
2750ce8
fix flash attention import
Browse files- configuration_internlm2.py +9 -2
- modeling_internlm2.py +4 -2
configuration_internlm2.py
CHANGED
@@ -169,5 +169,12 @@ class InternLM2Config(PretrainedConfig):
|
|
169 |
raise ValueError(
|
170 |
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
171 |
)
|
172 |
-
if
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
raise ValueError(
|
170 |
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
171 |
)
|
172 |
+
if (
|
173 |
+
rope_scaling_factor is None
|
174 |
+
or not isinstance(rope_scaling_factor, (float, int))
|
175 |
+
or rope_scaling_factor < 1.0
|
176 |
+
):
|
177 |
+
raise ValueError(
|
178 |
+
f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
|
179 |
+
f"of type {type(rope_scaling_factor)}"
|
180 |
+
)
|
modeling_internlm2.py
CHANGED
@@ -40,7 +40,6 @@ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
|
40 |
from transformers.utils import (
|
41 |
add_start_docstrings,
|
42 |
add_start_docstrings_to_model_forward,
|
43 |
-
is_flash_attn_2_available,
|
44 |
is_flash_attn_greater_or_equal_2_10,
|
45 |
logging,
|
46 |
replace_return_docstrings,
|
@@ -53,9 +52,12 @@ except Exception:
|
|
53 |
|
54 |
from .configuration_internlm2 import InternLM2Config
|
55 |
|
56 |
-
|
|
|
57 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
58 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
|
|
|
|
|
59 |
|
60 |
|
61 |
logger = logging.get_logger(__name__)
|
|
|
40 |
from transformers.utils import (
|
41 |
add_start_docstrings,
|
42 |
add_start_docstrings_to_model_forward,
|
|
|
43 |
is_flash_attn_greater_or_equal_2_10,
|
44 |
logging,
|
45 |
replace_return_docstrings,
|
|
|
52 |
|
53 |
from .configuration_internlm2 import InternLM2Config
|
54 |
|
55 |
+
|
56 |
+
try:
|
57 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
58 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
|
59 |
+
except:
|
60 |
+
pass
|
61 |
|
62 |
|
63 |
logger = logging.get_logger(__name__)
|