Support gradient checkpointing

#41
by muelletm - opened

(This is the trace you get when trying to use mpt-7b with qlora. More details here: https://github.com/artidoro/qlora/issues/10)

/opt/conda/lib/python3.10/site-packages/peft/utils/other.py:76: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.
  warnings.warn(
Traceback (most recent call last):
  File "/code/qlora/qlora.py", line 758, in <module>
    train()
  File "/code/qlora/qlora.py", line 590, in train
    model = get_accelerate_model(args, checkpoint_dir)
  File "/code/qlora/qlora.py", line 295, in get_accelerate_model
    model = prepare_model_for_int8_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
  File "/opt/conda/lib/python3.10/site-packages/peft/utils/other.py", line 80, in prepare_model_for_int8_training
    return prepare_model_for_kbit_training(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/peft/utils/other.py", line 69, in prepare_model_for_kbit_training
    model.gradient_checkpointing_enable()
  File "/opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py", line 1620, in gradient_checkpointing_enable
    raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
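For context, the error comes from a guard in transformers' `gradient_checkpointing_enable`: a model class has to opt in explicitly, and the stock MPT remote code doesn't. Paraphrased from the 4.x source (details vary by version):

```python
from functools import partial

# Paraphrased sketch of the guard in transformers' modeling_utils.py;
# the real PreTrainedModel subclasses torch.nn.Module.
class PreTrainedModel:
    supports_gradient_checkpointing = False  # models must opt in

    def gradient_checkpointing_enable(self):
        if not self.supports_gradient_checkpointing:
            raise ValueError(
                f"{self.__class__.__name__} does not support gradient checkpointing."
            )
        # flip the per-module checkpointing flag on every submodule
        self.apply(partial(self._set_gradient_checkpointing, value=True))
```

Since MPT's custom modeling code leaves `supports_gradient_checkpointing` at the default `False`, the call coming from PEFT fails before any training starts.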

Small update: @cekal has published a version of mpt-7b that fixes this (and other problems):

https://huggingface.co/cekal/mpt-7b-peft-compatible

Maybe their code can be merged into this repo?

@cekal, what changes were needed to make gradient checkpointing work with the PEFT library?
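I haven't diffed that fork, but to satisfy the guard above a model generally needs to opt in and wrap its blocks with `torch.utils.checkpoint`. A minimal, self-contained sketch of the pattern (`ToyBlock`/`ToyModel` are made-up names for illustration, not MPT's real classes):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyBlock(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.ff = nn.Linear(d, d)

    def forward(self, x):
        return torch.relu(self.ff(x))

class ToyModel(nn.Module):
    supports_gradient_checkpointing = True  # the attribute transformers checks

    def __init__(self, d=16, n_layers=4):
        super().__init__()
        self.blocks = nn.ModuleList(ToyBlock(d) for _ in range(n_layers))
        self.gradient_checkpointing = False

    def gradient_checkpointing_enable(self):
        self.gradient_checkpointing = True

    def forward(self, x):
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                # don't store this block's activations; recompute them
                # during backward to trade compute for memory
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

model = ToyModel().train()
model.gradient_checkpointing_enable()
out = model(torch.randn(2, 16, requires_grad=True))
out.sum().backward()  # activations are recomputed on the fly
```

Applied to MPT, the same idea would mean setting the class attribute on the pretrained-model base class and checkpointing each transformer block in the forward pass.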

@cekal @sam-mosaic any updates?

Hmm, OK, but was that fix propagated to the other MPT models, like mpt-7b-chat and mpt-7b-8k-chat?


Just ran into this issue today. Gradient checkpointing works if you use the Composer trainer, but it doesn't work through the Hugging Face Trainer.
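To be concrete about the Hugging Face path: setting the flag below makes the Trainer call `model.gradient_checkpointing_enable()` internally, which hits the same ValueError as the qlora trace above (a minimal, hypothetical config; only the checkpointing flag matters here):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    gradient_checkpointing=True,  # Trainer will call gradient_checkpointing_enable()
)
```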

Thanks, noted. Feel free to close the issue.
