Support gradient checkpointing
#41 · opened by muelletm
(This is the trace you get when trying to use mpt-7b with qlora. More details here: https://github.com/artidoro/qlora/issues/10)
/opt/conda/lib/python3.10/site-packages/peft/utils/other.py:76: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.
  warnings.warn(
Traceback (most recent call last):
  File "/code/qlora/qlora.py", line 758, in <module>
    train()
  File "/code/qlora/qlora.py", line 590, in train
    model = get_accelerate_model(args, checkpoint_dir)
  File "/code/qlora/qlora.py", line 295, in get_accelerate_model
    model = prepare_model_for_int8_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
  File "/opt/conda/lib/python3.10/site-packages/peft/utils/other.py", line 80, in prepare_model_for_int8_training
    return prepare_model_for_kbit_training(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/peft/utils/other.py", line 69, in prepare_model_for_kbit_training
    model.gradient_checkpointing_enable()
  File "/opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py", line 1620, in gradient_checkpointing_enable
    raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
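Until the MPT modeling code opts into checkpointing, one workaround is to disable gradient checkpointing when preparing the model. A minimal sketch, assuming a recent peft version (the model name and 4-bit loading flag here are just for illustration):

from transformers import AutoModelForCausalLM
from peft import prepare_model_for_kbit_training

model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b",
    trust_remote_code=True,  # MPT ships custom modeling code on the Hub
    load_in_4bit=True,
)

# Passing use_gradient_checkpointing=False skips the call to
# model.gradient_checkpointing_enable() that raises the ValueError above,
# at the cost of higher activation memory during training.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)

The trade-off: without checkpointing, all intermediate activations are kept, so a 7B model needs noticeably more GPU memory per batch.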
Small update: @cekal has added a version of mpt-7b that fixes this (and other problems):
https://huggingface.co/cekal/mpt-7b-peft-compatible
Maybe their code can be merged into this repo?
@cekal, what changes were needed for gradient checkpointing in the PEFT library?
@cekal @sam-mosaic any updates?
This looks like the change that enables gradient checkpointing?
https://huggingface.co/cekal/mpt-7b-peft-compatible/commit/a5eab52c1c61c1d50a4e01428949f6ff90c73c48
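For reference, enabling checkpointing in a transformers-style model usually means opting in via supports_gradient_checkpointing and wrapping each block in torch.utils.checkpoint. A minimal sketch of that pattern, with assumed class names (MPTModel, MPTPreTrainedModel) that may not match the actual commit:

import torch
import torch.nn as nn
from transformers import PreTrainedModel


class MPTModel(nn.Module):  # stand-in for the real MPT backbone
    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                # Recompute this block's activations during backward
                # instead of storing them: compute traded for memory.
                hidden_states = torch.utils.checkpoint.checkpoint(block, hidden_states)
            else:
                hidden_states = block(hidden_states)
        return hidden_states


class MPTPreTrainedModel(PreTrainedModel):
    # Without this flag, gradient_checkpointing_enable() raises the
    # ValueError shown in the trace above.
    supports_gradient_checkpointing = True

    def _set_gradient_checkpointing(self, module, value=False):
        # transformers calls this hook on submodules when checkpointing
        # is toggled; flip the flag on the backbone that uses it.
        if isinstance(module, MPTModel):
            module.gradient_checkpointing = value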
Hmm, ok, but was that change propagated to the other MPT models, like mpt-7b-chat and mpt-7b-8k-chat?
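One quick way to check is to load each checkpoint and inspect the supports_gradient_checkpointing flag set by its remote modeling code (a sketch; assumes the mosaicml/ Hub repo names and enough memory to load the weights):

from transformers import AutoModelForCausalLM

for name in ("mosaicml/mpt-7b-chat", "mosaicml/mpt-7b-8k-chat"):
    model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True)
    # True only if the repo's custom code opts into checkpointing.
    print(name, getattr(model, "supports_gradient_checkpointing", False))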
Thanks, noted. Feel free to close the issue.