ALJIACHI committed on
Commit
ebd0760
·
verified ·
1 Parent(s): 2f77769

Upload modeling.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling.py +24 -7
modeling.py CHANGED
@@ -205,15 +205,17 @@ class RotaryEmbedding(torch.nn.Module):
205
  self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
206
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
207
 
 
 
 
 
 
 
 
 
208
  def forward(self, x, seq_len=None):
209
  # x: [bs, num_attention_heads, seq_len, head_size]
210
- if seq_len > self.max_seq_len_cached:
211
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
212
-
213
- return (
214
- self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
215
- self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
216
- )
217
 
218
 
219
  class NTKScalingRotaryEmbedding(RotaryEmbedding):
@@ -250,6 +252,21 @@ class NTKScalingRotaryEmbedding(RotaryEmbedding):
250
  self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
251
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
  class RMSNorm(nn.Module):
255
  def __init__(self, hidden_size, eps=1e-6):
 
205
  self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
206
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
207
 
208
+ def _compute_cos_sin(self, seq_len, device, dtype):
209
+ """Compute cos/sin from scratch — avoids persistent buffer corruption on Python 3.13."""
210
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
211
+ t = torch.arange(seq_len, device=device, dtype=torch.float32)
212
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
213
+ emb = torch.cat((freqs, freqs), dim=-1)
214
+ return emb.cos().to(dtype=dtype), emb.sin().to(dtype=dtype)
215
+
216
  def forward(self, x, seq_len=None):
217
  # x: [bs, num_attention_heads, seq_len, head_size]
218
+ return self._compute_cos_sin(seq_len, x.device, x.dtype)
 
 
 
 
 
 
219
 
220
 
221
  class NTKScalingRotaryEmbedding(RotaryEmbedding):
 
252
  self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
253
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
254
 
255
+ def _compute_cos_sin(self, seq_len, device, dtype):
256
+ """Compute NTK-scaled cos/sin from scratch — avoids persistent buffer corruption."""
257
+ base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
258
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
259
+ if self.mixed_b is None:
260
+ inv_freq = inv_freq / self.scaling_factor ** (2 / self.dim)
261
+ else:
262
+ a = torch.tensor(self.scaling_factor, device=device).log() / (self.dim / 2) ** self.mixed_b
263
+ lambda_1_m = (a * torch.arange(1, self.dim // 2 + 1, device=device, dtype=torch.float32) ** self.mixed_b).exp()
264
+ inv_freq = inv_freq / lambda_1_m
265
+ t = torch.arange(seq_len, device=device, dtype=torch.float32)
266
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
267
+ emb = torch.cat((freqs, freqs), dim=-1)
268
+ return emb.cos().to(dtype=dtype), emb.sin().to(dtype=dtype)
269
+
270
 
271
  class RMSNorm(nn.Module):
272
  def __init__(self, hidden_size, eps=1e-6):