Lgr54HFi committed
Commit 0e7327a · verified · Parent: 0e64e3a

Fix NaN cascade: restore per-step gradient sanitization, add weight/momentum repair, harden Newton-Schulz


Three NaN-related bugs:

1. Gradient sanitization was reduced from every step to every 10 steps as a
throughput optimization. But BitLinear STE + complex recurrent layers produce
occasional NaN gradients that MUST be caught immediately: a single uncaught
NaN gradient corrupts Muon's momentum buffer, which then corrupts the weights
on the next optimizer step, causing permanent all-NaN forward passes (see the
first sketch below).
Fix: restore per-step sanitization.

2. The NaN guard only zeroed gradients and halved the LR; it never repaired the
already-corrupted weights or Muon momentum buffers. So once NaN entered
the weights, every subsequent step also produced NaN, triggering the LR-halving
cascade (5 NaN → halve → 5 NaN → halve → ... → LR = 0); see the second sketch below.
Fix: the NaN guard now sanitizes model weights AND optimizer momentum buffers.
Also increased the threshold from 5 to 10 consecutive NaN steps before halving.

3. Newton-Schulz could produce NaN if the input gradient matrix had near-zero
norm (< 1e-12) or already contained NaN values. It now returns a zero matrix in
these cases instead of propagating NaN through the polynomial iterations
(third sketch below).
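
First sketch (illustration only, not code from this commit): why a single
uncaught NaN gradient is fatal. It assumes the standard SGD-style momentum
update Muon applies before orthogonalization, buf = momentum * buf + grad;
buf and momentum are stand-in names, since the actual optimizer lives in
create_muon_optimizer, outside this diff.

import torch

# One NaN gradient enters an otherwise healthy momentum buffer...
buf = torch.zeros(4)
grad = torch.tensor([0.1, float("nan"), 0.1, 0.1])
buf = 0.95 * buf + grad

# ...and no amount of later clean gradients decays it away:
clean = torch.full((4,), 0.1)
for _ in range(100):
    buf = 0.95 * buf + clean
print(buf)  # the NaN entry is still NaN; weights go NaN on the next step

# Per-step sanitization before the optimizer step stops it at the door:
grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)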
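
Second sketch (illustration only, toy layer standing in for the real model):
why zeroing gradients and halving the LR cannot recover by themselves. Once
any weight is NaN, every forward pass is NaN regardless of input, so each
subsequent step trips the guard again and the halving cascade runs the LR to
zero.

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
with torch.no_grad():
    layer.weight[0, 0] = float("nan")  # simulate the corruption

out = layer(torch.ones(1, 4))
print(torch.isfinite(out).all())  # tensor(False), and on every future step too

# The repair the guard now performs: sanitize the weights in place
# (the real fix also sanitizes the optimizer's momentum buffers).
with torch.no_grad():
    for p in layer.parameters():
        if not torch.isfinite(p).all():
            p.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)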
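
Third sketch: the old, unguarded Newton-Schulz core (coefficients as in the
diff below), showing how a single NaN anywhere in G poisons the normalization
and every polynomial iterate after it.

import torch

G = torch.randn(8, 8)
G[0, 0] = float("nan")

a, b, c = 3.4445, -4.7750, 2.0315
X = G / (G.norm() + 1e-7)  # G.norm() is NaN, so X is all NaN
for _ in range(5):
    A = X @ X.T
    X = a * X + (b * A + c * A @ A) @ X
print(torch.isfinite(X).any())  # tensor(False)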

Files changed (1): chimera_turbo.py (+37 -80)

chimera_turbo.py CHANGED
@@ -57,10 +57,15 @@ def _zeropower_via_newtonschulz5(G, steps=5):
     assert G.ndim == 2
     a, b, c = 3.4445, -4.7750, 2.0315
     X = G.T if G.size(0) > G.size(1) else G.clone()
-    X = X / (X.norm() + 1e-7)
+    nrm = X.norm()
+    if nrm < 1e-12 or not torch.isfinite(nrm):
+        return torch.zeros_like(G)
+    X = X / (nrm + 1e-7)
     for _ in range(steps):
         A = X @ X.T
         X = a * X + (b * A + c * A @ A) @ X
+        if not torch.isfinite(X).all():
+            return torch.zeros_like(G)
     return X.T if G.size(0) > G.size(1) else X


@@ -165,7 +170,6 @@ class MultiTokenPredictionLoss(nn.Module):
         sl = min(logits.size(1), targets.size(1))

         if token_weights is not None:
-            # Apply token triage weights to MTP loss too
             per_tok = F.cross_entropy(
                 logits[:, :sl].reshape(-1, logits.size(-1)),
                 targets[:, :sl].reshape(-1), ignore_index=-100, reduction="none"
@@ -194,9 +198,6 @@ class TokenTriage:
         self.floor_weight = floor_weight
         self._loss_ema = None
         self._step = 0
-        # FIX: Anneal floor_weight from 1.0 → floor_weight over warmup_steps.
-        # When loss is high (early training), all tokens are informative.
-        # Discarding 40% of gradient signal at loss=10+ starves the model.
         self.warmup_steps = 500

     def compute_weights(self, per_token_loss):
@@ -208,7 +209,6 @@
         else:
             self._loss_ema = self.ema_decay * self._loss_ema + (1 - self.ema_decay) * ml

-        # FIX: Anneal — during warmup, all tokens get weight ≈ 1.0
         if self._step < self.warmup_steps:
             t = self._step / self.warmup_steps
             cur_floor = 1.0 - t * (1.0 - self.floor_weight)
@@ -234,7 +234,7 @@ class PlateauBreaker:
         self._history = deque(maxlen=patience)
         self._stagnant_count = 0
         self._burst_remaining = 0
-        self._saved_lrs = None  # Save ALL group LRs, not just one
+        self._saved_lrs = None
         self.total_bursts = 0

     def check_and_adjust(self, loss_val, optimizer, step):
@@ -244,7 +244,6 @@
         if self._burst_remaining > 0:
             self._burst_remaining -= 1
             if self._burst_remaining == 0 and self._saved_lrs is not None:
-                # Restore ALL group LRs (preserves LLRD ratios)
                 for pg, saved_lr in zip(optimizer.param_groups, self._saved_lrs):
                     pg["lr"] = saved_lr
                 self._saved_lrs = None
@@ -259,31 +258,23 @@ class PlateauBreaker:
         else:
             self._stagnant_count = 0
         if self._stagnant_count >= self.patience // 2:
-            # Save ALL LRs before burst
             self._saved_lrs = [pg["lr"] for pg in optimizer.param_groups]
             for pg in optimizer.param_groups:
-                pg["lr"] *= self.lr_mult  # Multiply, don't replace → LLRD preserved
+                pg["lr"] *= self.lr_mult
             self._burst_remaining = self.burst_steps
             self._stagnant_count = 0
             self.total_bursts += 1
             base = self._saved_lrs[0]
-            print(f" [P16] Plateau! LR ×{self.lr_mult} for {self.burst_steps} steps (base {base:.2e})")
+            print(f" [P16] Plateau! LR x{self.lr_mult} for {self.burst_steps} steps (base {base:.2e})")
             return True
         return False


 # ═══════════════════════════════════════════════════════════
-# P18 Grokfast-EMA (1D params only — NS cancels on 2D)
+# P18 Grokfast-EMA (1D params only)
 # ═══════════════════════════════════════════════════════════

 class GrokfastEMA:
-    """Amplify slow gradient components for generalization.
-
-    Applied ONLY to 1D params and embeddings (AdamW path).
-    Skipped for 2D matrices because Muon's Newton-Schulz normalisation
-    cancels the amplitude amplification — only direction survives,
-    which Muon already optimises via orthogonalisation.
-    """
     def __init__(self, alpha=0.98, lamb=2.0):
         self.alpha = alpha
         self.lamb = lamb
@@ -294,7 +285,6 @@ class GrokfastEMA:
         for name, p in model.named_parameters():
             if p.grad is None:
                 continue
-            # Skip 2D Muon params — NS normalisation cancels amplitude
             if p.ndim == 2 and not getattr(p, "_is_embed", False):
                 continue
             if name not in self._ema:
@@ -352,12 +342,10 @@ def apply(model, max_steps=10000, lr=0.02, weight_decay=0.01,
     raw = getattr(model, "_orig_mod", model)
     extras = {}

-    # P13: Create MTP FIRST so we can add its params to optimizer
     h, v = raw.config["hidden_size"], raw.config["vocab_size"]
     mtp = MultiTokenPredictionLoss(h, v, n_future=mtp_heads)
     extras["mtp"] = mtp

-    # P12+P19: Muon with LLRD + MTP head params included
     mtp_params = list(mtp.parameters())
     optimizer = create_muon_optimizer(model, lr=lr, weight_decay=weight_decay,
                                       llrd_decay=llrd_decay, extra_params=mtp_params)
@@ -368,40 +356,33 @@
         scales = [g["lr_scale"] for g in optimizer.param_groups]
         n_mtp = sum(p.numel() for p in mtp_params)
         print(f"[P12] Muon (lr={lr}) + [P19] LLRD (decay={llrd_decay})")
-        print(f" {n_total:,} params, LR: {min(scales):.3f}× → {max(scales):.3f}×")
-        print(f"[P13] MTP ({mtp_heads} heads, {n_mtp:,} params — IN optimizer)")
+        print(f" {n_total:,} params, LR: {min(scales):.3f}x -> {max(scales):.3f}x")
+        print(f"[P13] MTP ({mtp_heads} heads, {n_mtp:,} params -- IN optimizer)")

-    # P15
     extras["triage"] = TokenTriage(ema_decay=0.99, select_ratio=0.6, floor_weight=0.1)
     if verbose:
-        print(f"[P15] Token Triage (60%→full, 40%→10%, applied to base+MTP)")
+        print(f"[P15] Token Triage (60%->full, 40%->10%, applied to base+MTP)")

-    # P16
-    # FIX: Increase patience (100→200) and variance threshold (0.005→0.02)
-    # so the breaker doesn't fire during normal slow convergence.
-    # The old settings triggered bursts when loss was fluctuating ±0.07,
-    # which is normal for stochastic training at loss~10.
     extras["plateau"] = PlateauBreaker(patience=200, variance_threshold=0.02,
                                        lr_multiplier=2.0, burst_steps=50)
     if verbose:
-        print(f"[P16] Plateau Breaker (×2 burst, LLRD-aware save/restore)")
+        print(f"[P16] Plateau Breaker (x2 burst, LLRD-aware save/restore)")

-    # P18
     extras["grokfast"] = GrokfastEMA(alpha=grokfast_alpha, lamb=grokfast_lambda)
     if verbose:
         n_1d = sum(p.numel() for p in model.parameters()
                    if p.requires_grad and (p.ndim < 2 or getattr(p, "_is_embed", False)))
-        print(f"[P18] Grokfast-EMA (α={grokfast_alpha}, λ={grokfast_lambda}, {n_1d:,} params — 1D only)")
+        print(f"[P18] Grokfast-EMA (a={grokfast_alpha}, l={grokfast_lambda}, {n_1d:,} params -- 1D only)")

     if verbose:
-        print(f"[P17] Batch Metabolism (hard seq ×2, easy ×0.5)")
+        print(f"[P17] Batch Metabolism (hard seq x2, easy x0.5)")
         print("=" * 65)

     return model, optimizer, scheduler, extras


 # ═══════════════════════════════════════════════════════════
-# Training step — ALL paradigms FUSED + VERIFIED CUMULATIVE
+# Training step
 # ═══════════════════════════════════════════════════════════

 _nan_count = 0
@@ -410,29 +391,6 @@ def training_step(model, batch, optimizer, scheduler,
                   extras=None, grad_accum_steps=1, step=0,
                   max_grad_norm=1.0, autocast_dtype=None,
                   mtp_weight=0.1) -> float:
-    """
-    Data flow (verified cumulative):
-
-    forward(batch) → logits, hidden_states
-      │
-      ├─ per_token_loss = CE(logits, labels, reduction='none')  [B,T]
-      │
-      ├─ P17: seq_weights = sigmoid(z-score(per_seq_loss))  [B]
-      ├─ P15: tok_weights = triage(excess_loss)  [B,T]
-      ├─ combined = tok_weights × seq_weights  [B,T]
-      ├─ base_loss = weighted_mean(per_token_loss, combined)
-      │
-      ├─ P13: mtp_loss = MTP(hidden, labels, tok_weights)  ← triage applied!
-      ├─ total_loss = base + 0.1 × mtp
-      │
-    backward(total_loss) → param.grad for ALL params (model + MTP heads)
-      │
-      ├─ P18: Grokfast amplifies grad on 1D params only (skip 2D/Muon)
-      │
-    optimizer.step() → P12 Muon (2D) + AdamW (1D), P19 LLRD scales per group
-      │
-      └─ P16: Plateau checks loss_val, burst preserves LLRD ratios
-    """
     global _nan_count
     extras = extras or {}
     is_accum = (step + 1) % grad_accum_steps == 0
@@ -456,27 +414,19 @@
         ignore_index=-100, reduction="none"
     ).reshape(B, T)

-    # P17: Batch Metabolism — per-sequence difficulty weights
-    # FIX: With small effective batches (e.g. 8-32), seq_loss.std()
-    # is extremely noisy, causing wild oscillation in seq_weights.
-    # Clamp the z-scores and narrow the weight range from [0.5, 2.0]
-    # to [0.7, 1.4] to reduce gradient noise.
     with torch.no_grad():
         seq_loss = per_token.mean(dim=1)
         seq_mean = seq_loss.mean()
         seq_std = seq_loss.std().clamp(min=1e-6)
         z = ((seq_loss - seq_mean) / seq_std).clamp(-2.0, 2.0)
-        seq_weights = torch.sigmoid(z) * 0.7 + 0.7  # [0.7, 1.4]
+        seq_weights = torch.sigmoid(z) * 0.7 + 0.7

-    # P15: Token Triage — per-token informativeness weights
     triage = extras.get("triage")
     tok_weights = triage.compute_weights(per_token) if triage else torch.ones_like(per_token)

-    # Fuse: multiplicative composition
     combined = tok_weights * seq_weights.unsqueeze(1)
     base_loss = (per_token * combined).sum() / combined.sum()

-    # P13: MTP with Token Triage weights passed through
     mtp = extras.get("mtp")
     hidden = getattr(outputs, "hidden_states", None)
     if mtp is not None and hidden is not None:
@@ -489,19 +439,28 @@

     loss_val = total_loss.item()

-    # NaN guard
+    # NaN guard — skip step AND repair corrupted state
     if not math.isfinite(loss_val):
         _nan_count += 1
         optimizer.zero_grad(set_to_none=True)
-        if _nan_count >= 5:
+        with torch.no_grad():
+            for p in model.parameters():
+                if not torch.isfinite(p.data).all():
+                    p.data.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
+            for group in optimizer.param_groups:
+                for p in group["params"]:
+                    s = optimizer.state.get(p, {})
+                    for key in ("buf", "m", "v"):
+                        if key in s and not torch.isfinite(s[key]).all():
+                            s[key].nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
+        if _nan_count >= 10:
             for pg in optimizer.param_groups:
                 pg["lr"] *= 0.5
-            print(f" [NaN] 5× — LR halved")
+            print(f" [NaN] 10x -- LR halved to {optimizer.param_groups[0]['lr']:.2e}")
             _nan_count = 0
         return loss_val
     _nan_count = 0

-    # P16: Plateau Breaker (before backward, uses loss_val only)
     plateau = extras.get("plateau")
     if plateau:
         plateau.check_and_adjust(loss_val, optimizer, step)
@@ -510,21 +469,19 @@
     total_loss = total_loss / grad_accum_steps
     total_loss.backward()

-    # Sanitize — only check every 10 steps to save CPU cycles.
-    # NaN gradients are rare; checking every step is wasteful.
-    if step % 10 == 0:
-        for p in model.parameters():
-            if p.grad is not None and not torch.isfinite(p.grad).all():
-                p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
+    # Sanitize gradients every step — BitLinear STE + complex recurrent
+    # layers produce occasional NaN gradients that MUST be caught immediately.
+    for p in model.parameters():
+        if p.grad is not None and not torch.isfinite(p.grad).all():
+            p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)

-    # P18: Grokfast on 1D params only (2D handled by Muon NS)
     grokfast = extras.get("grokfast")
     if grokfast:
         grokfast.apply(model)

     if is_accum:
         torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
-        optimizer.step()  # P12 Muon (2D) + AdamW (1D), P19 LLRD via lr_scale
+        optimizer.step()
         scheduler.step()
         optimizer.zero_grad(set_to_none=True)
     invalidate_all_caches(model)