Adding _bnb_4bit_dequantize_and_rescale in `modeling_rwkv5.py` based on the documentation provided on GitHub
#7 opened by kaifahmad

modeling_rwkv5.py  +21 -0  CHANGED
@@ -852,3 +852,24 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
             hidden_states=rwkv_outputs.hidden_states,
             attentions=rwkv_outputs.attentions,
         )
+
+    def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
+        r"""
+        Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
+        be quantized again.
+        """
+        if not is_bitsandbytes_available():
+            raise ImportError("Please install bitsandbytes to use this method.")
+        import bitsandbytes as bnb
+
+        dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)
+
+        dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))
+
+        # re-quantize the model:
+        # we need to put it first on CPU then back to the device
+        # this will create an overhead :/
+        # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
+        # bugs with bnb
+        quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
+        setattr(target_layer, "weight", quant_weight)