Current CUDA Device does not support bfloat16. Please switch dtype to float16.

#4
by marcos777123 - opened

Hello,

My CUDA version is 12.1 and cuDNN is 8.9.7.
I set: model_dtype, torch_dtype = 'bf16', torch.float16 #torch.bfloat16
and now it is running...
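Instead of editing the dtype by hand, the choice can be made automatically. This is a minimal sketch, not the model repo's own code. It assumes the rule many model repos use: native bfloat16 needs a GPU with CUDA compute capability 8.0 or higher (Ampere and newer), and anything older gets float16. The helper names `choose_dtype_name` and `pick_torch_dtype` are hypothetical. PyTorch also provides `torch.cuda.is_bf16_supported()` if you prefer to let it decide.

```python
from typing import Optional, Tuple

try:
    import torch
except ImportError:  # lets the pure helper below run without PyTorch installed
    torch = None


def choose_dtype_name(capability: Tuple[int, int]) -> str:
    """Map a (major, minor) CUDA compute capability to a dtype name.

    Assumption: bfloat16 is treated as supported only on compute capability >= 8.0.
    """
    major, _minor = capability
    return "bfloat16" if major >= 8 else "float16"


def pick_torch_dtype() -> Optional["torch.dtype"]:
    """Return torch.bfloat16 or torch.float16 for the current GPU, or None without CUDA."""
    if torch is None or not torch.cuda.is_available():
        return None
    name = choose_dtype_name(torch.cuda.get_device_capability())
    return getattr(torch, name)
```

For example, on a T4 (compute capability 7.5) `pick_torch_dtype()` returns `torch.float16`, which matches the manual switch above. On an A100 (8.0) it keeps `torch.bfloat16`.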
