Commit
•
91a0561
1
Parent(s):
b3121c8
fix rmsnorm init weight bug. (#59)
Browse files- fix rmsnorm init weight bug. (9d3d7be563d07295abb119ff28714aa9267580b8)
Co-authored-by: Ben <[email protected]>
- modeling_chatglm.py +1 -1
modeling_chatglm.py
CHANGED
@@ -181,7 +181,7 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten
|
|
181 |
class RMSNorm(torch.nn.Module):
|
182 |
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
|
183 |
super().__init__()
|
184 |
-
self.weight = torch.nn.Parameter(torch.
|
185 |
self.eps = eps
|
186 |
|
187 |
def forward(self, hidden_states: torch.Tensor):
|
|
|
181 |
class RMSNorm(torch.nn.Module):
|
182 |
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
|
183 |
super().__init__()
|
184 |
+
self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
|
185 |
self.eps = eps
|
186 |
|
187 |
def forward(self, hidden_states: torch.Tensor):
|