jupyterjazz commited on
Commit
d8cbc92
1 Parent(s): a6bb16f

fix: update frequencies when updating the rope base value

Browse files
Files changed (1) hide show
  1. rotary.py +1 -0
rotary.py CHANGED
@@ -495,6 +495,7 @@ class RotaryEmbedding(torch.nn.Module):
495
  def base(self, new_base):
496
  if new_base > 0:
497
  self._base = float(new_base)
 
498
  else:
499
  raise ValueError("Rotary base value must be positive")
500
 
 
495
  def base(self, new_base):
496
  if new_base > 0:
497
  self._base = float(new_base)
498
+ self.inv_freq = self._compute_inv_freq(device=self.inv_freq.device)
499
  else:
500
  raise ValueError("Rotary base value must be positive")
501