Fine-tuning Script: QLoRA w/ Flash Attn fails

#41
by RonanMcGovern - opened

Running the script, but adding `_attn_implementation="flash_attention_2"` when loading the model, yields the following error during training:
```
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.float16.

RuntimeError Traceback (most recent call last)
Cell In[8], line 1
----> 1 trainer.train()

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:1875, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1873 hf_hub_utils.enable_progress_bars()
1874 else:
-> 1875 return inner_training_loop(
1876 args=args,
1877 resume_from_checkpoint=resume_from_checkpoint,
1878 trial=trial,
1879 ignore_keys_for_eval=ignore_keys_for_eval,
1880 )

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2206, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2203 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
2205 with self.accelerator.accumulate(model):
-> 2206 tr_loss_step = self.training_step(model, inputs)
2208 if (
2209 args.logging_nan_inf_filter
2210 and not is_torch_xla_available()
2211 and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
2212 ):
2213 # if loss is nan or inf simply add the average of previous logged losses
2214 tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3184, in Trainer.training_step(self, model, inputs)
3181 return loss_mb.reduce_mean().detach().to(self.args.device)
3183 with self.compute_loss_context_manager():
-> 3184 loss = self.compute_loss(model, inputs)
3186 if self.args.n_gpu > 1:
3187 loss = loss.mean() # mean() to average on multi-gpu parallel training

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3207, in Trainer.compute_loss(self, model, inputs, return_outputs)
3205 else:
3206 labels = None
-> 3207 outputs = model(**inputs)
3208 # Save past state if it exists
3209 # TODO: this needs to be fixed and made cleaner later.
3210 if self.args.past_index >= 0:

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:825, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
824 def forward(*args, **kwargs):
--> 825 return model_forward(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:813, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
812 def __call__(self, *args, **kwargs):
--> 813 return convert_to_fp32(self.model_forward(*args, **kwargs))

File /usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
13 @functools.wraps(func)
14 def decorate_autocast(*args, **kwargs):
15 with autocast_instance:
---> 16 return func(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1829, in Idefics2ForConditionalGeneration.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, labels, use_cache, output_attentions, output_hidden_states, return_dict)
1826 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1828 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1829 outputs = self.model(
1830 input_ids=input_ids,
1831 attention_mask=attention_mask,
1832 position_ids=position_ids,
1833 past_key_values=past_key_values,
1834 inputs_embeds=inputs_embeds,
1835 pixel_values=pixel_values,
1836 pixel_attention_mask=pixel_attention_mask,
1837 image_hidden_states=image_hidden_states,
1838 use_cache=use_cache,
1839 output_attentions=output_attentions,
1840 output_hidden_states=output_hidden_states,
1841 return_dict=return_dict,
1842 )
1844 hidden_states = outputs[0]
1845 logits = self.lm_head(hidden_states)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1649, in Idefics2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, use_cache, output_attentions, output_hidden_states, return_dict)
1643 image_hidden_states = self.vision_model(
1644 pixel_values=pixel_values,
1645 patch_attention_mask=patch_attention_mask,
1646 ).last_hidden_state
1648 # Modality projection & resampling
-> 1649 image_hidden_states = self.connector(
1650 image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
1651 )
1653 elif image_hidden_states is not None:
1654 image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1317, in Idefics2Connector.forward(self, image_hidden_states, attention_mask)
1315 def forward(self, image_hidden_states, attention_mask):
1316 image_hidden_states = self.modality_projection(image_hidden_states)
-> 1317 image_hidden_states = self.perceiver_resampler(context=image_hidden_states, attention_mask=attention_mask)
1318 return image_hidden_states

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1287, in Idefics2PerceiverResampler.forward(self, context, attention_mask)
1285 compressed_context = latents
1286 for perceiver_layer in self.layers:
-> 1287 layer_outputs = perceiver_layer(
1288 compressed_context,
1289 context,
1290 attention_mask=attention_mask,
1291 position_ids=None,
1292 past_key_value=None,
1293 output_attentions=False,
1294 use_cache=False,
1295 )
1297 compressed_context = layer_outputs[0]
1299 compressed_context = self.norm(compressed_context)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1220, in Idefics2PerceiverLayer.forward(self, latents, context, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
1217 latents = self.input_latents_norm(latents)
1218 context = self.input_context_norm(context)
-> 1220 latents, self_attn_weights, present_key_value = self.self_attn(
1221 latents=latents,
1222 context=context,
1223 attention_mask=attention_mask,
1224 )
1225 latents = residual + latents
1226 residual = latents

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1004, in Idefics2PerceiverFlashAttention2.forward(self, latents, context, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
1001 key_states = key_states.transpose(1, 2)
1002 value_states = value_states.transpose(1, 2)
-> 1004 attn_output = self._flash_attention_forward(
1005 query_states,
1006 key_states,
1007 value_states,
1008 attention_mask,
1009 q_len,
1010 dropout=dropout_rate,
1011 use_sliding_windows=False,
1012 )
1014 attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
1015 attn_output = self.o_proj(attn_output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1071, in Idefics2PerceiverFlashAttention2._flash_attention_forward(self, query_states, key_states, value_states, attention_mask, query_length, dropout, softmax_scale, use_sliding_windows)
1068 max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
1070 if not use_sliding_windows:
-> 1071 attn_output_unpad = flash_attn_varlen_func(
1072 query_states,
1073 key_states,
1074 value_states,
1075 cu_seqlens_q=cu_seqlens_q,
1076 cu_seqlens_k=cu_seqlens_k,
1077 max_seqlen_q=max_seqlen_in_batch_q,
1078 max_seqlen_k=max_seqlen_in_batch_k,
1079 dropout_p=dropout,
1080 softmax_scale=softmax_scale,
1081 causal=causal,
1082 )
1083 else:
1084 attn_output_unpad = flash_attn_varlen_func(
1085 query_states,
1086 key_states,
(...)
1095 window_size=(self.config.sliding_window, self.config.sliding_window),
1096 )

File /usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py:1066, in flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_attn_probs, block_table)
995 def flash_attn_varlen_func(
996 q,
997 k,
(...)
1010 block_table=None,
1011 ):
1012 """dropout_p should be set to 0.0 during evaluation
1013 Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
1014 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
(...)
1064 pattern (negative means that location was dropped, nonnegative means it was kept).
1065 """
-> 1066 return FlashAttnVarlenFunc.apply(
1067 q,
1068 k,
1069 v,
1070 cu_seqlens_q,
1071 cu_seqlens_k,
1072 max_seqlen_q,
1073 max_seqlen_k,
1074 dropout_p,
1075 softmax_scale,
1076 causal,
1077 window_size,
1078 alibi_slopes,
1079 deterministic,
1080 return_attn_probs,
1081 block_table,
1082 )

File /usr/local/lib/python3.10/dist-packages/torch/autograd/function.py:539, in Function.apply(cls, *args, **kwargs)
536 if not torch._C._are_functorch_transforms_active():
537 # See NOTE: [functorch vjp and autograd interaction]
538 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 539 return super().apply(*args, **kwargs) # type: ignore[misc]
541 if cls.setup_context == _SingleLevelFunction.setup_context:
542 raise RuntimeError(
543 "In order to use an autograd.Function with functorch transforms "
544 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
545 "staticmethod. For more details, please see "
546 "https://pytorch.org/docs/master/notes/extending.func.html"
547 )

File /usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py:581, in FlashAttnVarlenFunc.forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, block_table)
579 if softmax_scale is None:
580 softmax_scale = q.shape[-1] ** (-0.5)
--> 581 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
582 q,
583 k,
584 v,
585 cu_seqlens_q,
586 cu_seqlens_k,
587 max_seqlen_q,
588 max_seqlen_k,
589 dropout_p,
590 softmax_scale,
591 causal=causal,
592 window_size=window_size,
593 alibi_slopes=alibi_slopes,
594 return_softmax=return_softmax and dropout_p > 0,
595 block_table=block_table,
596 )
597 ctx.save_for_backward(
598 q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
599 )
600 ctx.dropout_p = dropout_p

File /usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py:86, in _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, block_table)
84 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
85 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 86 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
87 q,
88 k,
89 v,
90 None,
91 cu_seqlens_q,
92 cu_seqlens_k,
93 None,
94 block_table,
95 alibi_slopes,
96 max_seqlen_q,
97 max_seqlen_k,
98 dropout_p,
99 softmax_scale,
100 False,
101 causal,
102 window_size[0],
103 window_size[1],
104 return_softmax,
105 None,
106 )
107 # if out.isnan().any() or softmax_lse.isnan().any():
108 # breakpoint()
109 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state

RuntimeError: query and key must have the same dtype

```
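
For reference, the loading step in question looks roughly like this. This is a sketch, not the exact notebook code: the checkpoint id, quantization settings, and dtype below are assumptions and may differ slightly from the script.

```python
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics2ForConditionalGeneration

model_id = "HuggingFaceM4/idefics2-8b"  # assumption: the base checkpoint used by the script

# 4-bit (QLoRA-style) quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

processor = AutoProcessor.from_pretrained(model_id)
model = Idefics2ForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    _attn_implementation="flash_attention_2",  # the added argument that triggers the error
)
```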

Seems similar to this issue: https://github.com/huggingface/transformers/issues/30019

Thanks, yes, that's the issue. I'll subscribe to that on GitHub and comment back here once it's resolved.
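
For anyone hitting this before the upstream fix lands, one interim workaround is to make sure everything that isn't quantized is in the flash-attention compute dtype. A rough sketch, assuming the mismatch comes from modules (e.g. layer norms) left in float32 by the k-bit preparation step and that `model` is the 4-bit model loaded above; note that casting layer norms down to float16 may cost some training stability:

```python
import torch

def cast_fp32_params_to(model, dtype=torch.float16):
    """Cast any parameters still in float32 (e.g. upcasted layer norms)
    down to the flash-attention compute dtype."""
    for _, param in model.named_parameters():
        if param.dtype == torch.float32:
            param.data = param.data.to(dtype)
    return model

# Apply after prepare_model_for_kbit_training / get_peft_model, before building the Trainer:
# model = cast_fp32_params_to(model, dtype=torch.float16)
```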
