jupyterjazz
commited on
Commit
•
d8cbc92
1
Parent(s):
a6bb16f
fix: update frequencies when updating the rope base value
Browse files
rotary.py
CHANGED
@@ -495,6 +495,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|
495 |
def base(self, new_base):
|
496 |
if new_base > 0:
|
497 |
self._base = float(new_base)
|
|
|
498 |
else:
|
499 |
raise ValueError("Rotary base value must be positive")
|
500 |
|
|
|
495 |
def base(self, new_base):
|
496 |
if new_base > 0:
|
497 |
self._base = float(new_base)
|
498 |
+
self.inv_freq = self._compute_inv_freq(device=self.inv_freq.device)
|
499 |
else:
|
500 |
raise ValueError("Rotary base value must be positive")
|
501 |
|