pszemraj committed on
Commit
0688e28
1 Parent(s): e19495a

add MPTBlock to _no_split_modules

Browse files

Unlike the [standard transformers modeling scripts](https://github.com/huggingface/transformers/blob/94056b57beb4499f4f74d5d88a41e8266cc01778/src/transformers/models/longt5/modeling_longt5.py#L1273) that I am used to, this repo keeps blocks.py as a separate file, so it took me a bit to find the block class name.

Files changed (1) hide show
  1. modeling_mpt.py +1 -1
modeling_mpt.py CHANGED
@@ -32,7 +32,7 @@ class MPTPreTrainedModel(PreTrainedModel):
32
  config_class = MPTConfig
33
  base_model_prefix = "model"
34
  supports_gradient_checkpointing = True
35
- _no_split_modules = []
36
 
37
 
38
  class MPTModel(MPTPreTrainedModel):
 
32
  config_class = MPTConfig
33
  base_model_prefix = "model"
34
  supports_gradient_checkpointing = True
35
+ _no_split_modules = ["MPTBlock"]
36
 
37
 
38
  class MPTModel(MPTPreTrainedModel):