Using `self.transformer.wte.weight` directly for the LM head breaks HF accelerate `device_map="auto"` inference on multi-GPU

#46
opened by shijie-wu

Hi,

Using `self.transformer.wte.weight` directly for the LM head
https://huggingface.co/mosaicml/mpt-7b/blob/053e1a33a6e7043aefaa3f5d13c48269a5511cff/modeling_mpt.py#L239
breaks HF accelerate's `device_map="auto"` inference on multi-GPU setups, because the final layer gets placed on a different GPU than `self.transformer.wte`. This could be fixed by adding a dummy LM head and tying its parameters to the embedding, as all GPT-style models in transformers do.
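A minimal PyTorch sketch of the two layouts, assuming toy sizes and a single `nn.Linear` standing in for the transformer blocks. The class names `DirectHeadLM` and `TiedHeadLM` are hypothetical and this is not the actual MPT code. In the first model the logits come from `F.linear(h, self.wte.weight)`, so no module owns the output projection. In the second, `lm_head` is a real submodule whose `weight` is the very same `Parameter` as `wte.weight`, which presumably gives a device-map planner a final module to place while still keeping one shared weight.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DirectHeadLM(nn.Module):
    """Hypothetical toy model: logits reuse wte.weight directly (the reported pattern)."""

    def __init__(self, vocab_size: int, d_model: int) -> None:
        super().__init__()
        self.wte = nn.Embedding(vocab_size, d_model)
        self.block = nn.Linear(d_model, d_model)  # stand-in for the transformer blocks

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        h = self.block(self.wte(input_ids))
        # No module owns the output projection; the embedding weight is read
        # directly, wherever the device map happened to put wte.
        return F.linear(h, self.wte.weight)


class TiedHeadLM(nn.Module):
    """Hypothetical toy model: a real lm_head module tied to the embedding (the suggested fix)."""

    def __init__(self, vocab_size: int, d_model: int) -> None:
        super().__init__()
        self.wte = nn.Embedding(vocab_size, d_model)
        self.block = nn.Linear(d_model, d_model)  # stand-in for the transformer blocks
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Tie: both modules now hold the same Parameter (shape vocab_size x d_model).
        self.lm_head.weight = self.wte.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        h = self.block(self.wte(input_ids))
        return self.lm_head(h)


torch.manual_seed(0)
direct = DirectHeadLM(vocab_size=10, d_model=4)
tied = TiedHeadLM(vocab_size=10, d_model=4)
# Copy wte/block weights; lm_head.weight is absent from the source
# state dict but follows wte automatically through the tie.
tied.load_state_dict(direct.state_dict(), strict=False)

ids = torch.tensor([[1, 2, 3]])
print(torch.allclose(direct(ids), tied(ids)))  # True: same logits, different module layout
print([name for name, _ in tied.named_children()])  # ['wte', 'block', 'lm_head']
print(sum(p.numel() for p in tied.parameters()))  # 60: the shared weight is counted once
```

Tying does not add parameters: `parameters()` yields the shared tensor only once, so both toy models hold 60 parameters. The only structural change is that the output projection now lives in its own named submodule.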

Working on it. We are testing with FSDP first to make sure nothing breaks: https://github.com/mosaicml/llm-foundry/pull/225

Fixed as of this PR: https://huggingface.co/mosaicml/mpt-7b/discussions/47

Please give it a try!

abhi-mosaic changed discussion status to closed
