FalconForCausalLM does not support Flash Attention 2.0 yet
#98
opened by Menouar
Dear Repository Owners,
The Falcon model loaded directly from the transformers library supports Flash Attention 2:
import torch
from transformers import FalconForCausalLM

# model_id and bnb_config (a BitsAndBytesConfig) are defined earlier
model = FalconForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)
However, the same model loaded through AutoModelForCausalLM with trust_remote_code=True, which uses the modelling code hosted on the Hub, does not support it:
from transformers import AutoModelForCausalLM

# Same arguments as above, but this time the Hub's remote code is used
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)
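A quick way to confirm what is happening, assuming the repo ships its own modelling code via an auto_map entry in its config, is to check which module the loaded class comes from:

# With trust_remote_code=True the class is loaded from the repo's own
# modelling file (a transformers_modules.* module) rather than from
# transformers.models.falcon, so the in-library Flash Attention 2
# support does not apply to it.
print(type(model).__module__)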
I encountered the following error:
ValueError: FalconForCausalLM does not support Flash Attention 2.0 yet.
This discrepancy seems to occur because the model was originally hosted on the Hub with custom modelling code and was only later incorporated into the transformers library.
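A possible workaround, assuming the checkpoint's config declares model_type "falcon", is to drop trust_remote_code=True so that AutoModelForCausalLM resolves to the in-library FalconForCausalLM (a sketch only; model_id and bnb_config are the same placeholders as above):

# Without trust_remote_code, transformers maps model_type "falcon" to its
# own FalconForCausalLM, which does support Flash Attention 2.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)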
@Rocketknight1
Thanks