Fix NaN cascade: restore per-step gradient sanitization, add weight/momentum repair, harden Newton-Schulz
Three NaN-related bugs:
1. Gradient sanitization was reduced from every step to every 10 steps as a
throughput optimization. But BitLinear STE + complex recurrent layers produce
occasional NaN gradients that MUST be caught immediately. A single uncaught
NaN gradient corrupts Muon's momentum buffer, which then corrupts the weights
on the next optimizer step, causing permanent all-NaN forward passes.
Fix: restore per-step sanitization.
2. The NaN guard only zeroed gradients and halved the LR, but never repaired the
already-corrupted weights or Muon momentum buffers. So once NaN entered
the weights, every subsequent step also produced NaN, triggering the LR
halving cascade (5 NaN → halve → 5 NaN → halve → ... → LR = 0); a toy
reproduction of this failure mode is sketched below.
Fix: the NaN guard now sanitizes model weights AND optimizer momentum buffers.
Also raised the threshold from 5 to 10 consecutive NaN steps before halving.
3. Newton-Schulz could produce NaN if the input gradient matrix had near-zero
norm (< 1e-12) or already contained NaN values. It now returns a zero matrix in
these cases instead of propagating NaN through the polynomial iterations; a
quick sanity check for the hardened version is sketched below.
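
For reference, a minimal standalone sketch of the cascade in points 1-2. It uses a
toy heavy-ball update, not the repo's actual Muon optimizer, and the names and
numbers are illustrative only: one non-finite gradient poisons the momentum buffer,
the buffer then re-poisons the weights on every later step even with clean
gradients, and the repair is an in-place nan_to_num over both weights and buffer.

    import torch

    w = torch.ones(4)            # toy weights
    buf = torch.zeros(4)         # stand-in for Muon's momentum buffer

    def step(grad, lr=0.1, mu=0.95):
        global w, buf
        buf = mu * buf + grad    # NaN in grad -> NaN in buf, permanently
        w = w - lr * buf         # NaN in buf -> NaN in w on every later step

    step(torch.tensor([0.1, float("nan"), 0.1, 0.1]))  # one bad gradient
    step(torch.full((4,), 0.1))                        # clean gradient afterwards
    print(torch.isfinite(w).all())   # tensor(False): weights stay corrupted

    # Repair pattern from the fix: scrub weights and optimizer state in place
    for t in (w, buf):
        t.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
    print(torch.isfinite(w).all())   # tensor(True)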
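And a quick sanity check for the hardened Newton-Schulz in point 3. It assumes the
patched _zeropower_via_newtonschulz5 is importable from chimera_turbo at module
level (as the hunk header below suggests); the degenerate inputs are made up for
the test.

    import torch
    from chimera_turbo import _zeropower_via_newtonschulz5

    cases = {
        "all-zero grad":  torch.zeros(8, 16),
        "NaN-poisoned":   torch.full((8, 16), float("nan")),
        "tiny-norm grad": torch.randn(8, 16) * 1e-20,
    }
    for name, G in cases.items():
        out = _zeropower_via_newtonschulz5(G, steps=5)
        # The hardened version should return a finite (zero) matrix, never NaN.
        assert out.shape == G.shape and torch.isfinite(out).all(), name
        print(f"{name}: finite output, norm={out.norm().item():.1e}")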
- chimera_turbo.py +37 -80
@@ -57,10 +57,15 @@ def _zeropower_via_newtonschulz5(G, steps=5):
     assert G.ndim == 2
     a, b, c = 3.4445, -4.7750, 2.0315
     X = G.T if G.size(0) > G.size(1) else G.clone()
-
+    nrm = X.norm()
+    if nrm < 1e-12 or not torch.isfinite(nrm):
+        return torch.zeros_like(G)
+    X = X / (nrm + 1e-7)
     for _ in range(steps):
         A = X @ X.T
         X = a * X + (b * A + c * A @ A) @ X
+        if not torch.isfinite(X).all():
+            return torch.zeros_like(G)
     return X.T if G.size(0) > G.size(1) else X
@@ -165,7 +170,6 @@ class MultiTokenPredictionLoss(nn.Module):
         sl = min(logits.size(1), targets.size(1))

         if token_weights is not None:
-            # Apply token triage weights to MTP loss too
             per_tok = F.cross_entropy(
                 logits[:, :sl].reshape(-1, logits.size(-1)),
                 targets[:, :sl].reshape(-1), ignore_index=-100, reduction="none"
@@ -194,9 +198,6 @@ class TokenTriage:
         self.floor_weight = floor_weight
         self._loss_ema = None
         self._step = 0
-        # FIX: Anneal floor_weight from 1.0 → floor_weight over warmup_steps.
-        # When loss is high (early training), all tokens are informative.
-        # Discarding 40% of gradient signal at loss=10+ starves the model.
         self.warmup_steps = 500

     def compute_weights(self, per_token_loss):
@@ -208,7 +209,6 @@ class TokenTriage:
         else:
             self._loss_ema = self.ema_decay * self._loss_ema + (1 - self.ema_decay) * ml

-        # FIX: Anneal - during warmup, all tokens get weight ≈ 1.0
         if self._step < self.warmup_steps:
             t = self._step / self.warmup_steps
             cur_floor = 1.0 - t * (1.0 - self.floor_weight)
@@ -234,7 +234,7 @@ class PlateauBreaker:
         self._history = deque(maxlen=patience)
         self._stagnant_count = 0
         self._burst_remaining = 0
-        self._saved_lrs = None
+        self._saved_lrs = None
         self.total_bursts = 0

     def check_and_adjust(self, loss_val, optimizer, step):
@@ -244,7 +244,6 @@ class PlateauBreaker:
         if self._burst_remaining > 0:
             self._burst_remaining -= 1
             if self._burst_remaining == 0 and self._saved_lrs is not None:
-                # Restore ALL group LRs (preserves LLRD ratios)
                 for pg, saved_lr in zip(optimizer.param_groups, self._saved_lrs):
                     pg["lr"] = saved_lr
                 self._saved_lrs = None
@@ -259,31 +258,23 @@ class PlateauBreaker:
         else:
             self._stagnant_count = 0
         if self._stagnant_count >= self.patience // 2:
-            # Save ALL LRs before burst
             self._saved_lrs = [pg["lr"] for pg in optimizer.param_groups]
             for pg in optimizer.param_groups:
-                pg["lr"] *= self.lr_mult
+                pg["lr"] *= self.lr_mult
             self._burst_remaining = self.burst_steps
             self._stagnant_count = 0
             self.total_bursts += 1
             base = self._saved_lrs[0]
-            print(f" [P16] Plateau! LR
+            print(f" [P16] Plateau! LR x{self.lr_mult} for {self.burst_steps} steps (base {base:.2e})")
             return True
         return False


 # ───────────────────────────────────────────────────────────
-# P18 Grokfast-EMA (1D params only
+# P18 Grokfast-EMA (1D params only)
 # ───────────────────────────────────────────────────────────

 class GrokfastEMA:
-    """Amplify slow gradient components for generalization.
-
-    Applied ONLY to 1D params and embeddings (AdamW path).
-    Skipped for 2D matrices because Muon's Newton-Schulz normalisation
-    cancels the amplitude amplification; only direction survives,
-    which Muon already optimises via orthogonalisation.
-    """
     def __init__(self, alpha=0.98, lamb=2.0):
         self.alpha = alpha
         self.lamb = lamb
@@ -294,7 +285,6 @@ class GrokfastEMA:
         for name, p in model.named_parameters():
             if p.grad is None:
                 continue
-            # Skip 2D Muon params - NS normalisation cancels amplitude
             if p.ndim == 2 and not getattr(p, "_is_embed", False):
                 continue
             if name not in self._ema:
@@ -352,12 +342,10 @@ def apply(model, max_steps=10000, lr=0.02, weight_decay=0.01,
     raw = getattr(model, "_orig_mod", model)
     extras = {}

-    # P13: Create MTP FIRST so we can add its params to optimizer
     h, v = raw.config["hidden_size"], raw.config["vocab_size"]
     mtp = MultiTokenPredictionLoss(h, v, n_future=mtp_heads)
     extras["mtp"] = mtp

-    # P12+P19: Muon with LLRD + MTP head params included
     mtp_params = list(mtp.parameters())
     optimizer = create_muon_optimizer(model, lr=lr, weight_decay=weight_decay,
                                       llrd_decay=llrd_decay, extra_params=mtp_params)
@@ -368,40 +356,33 @@
         scales = [g["lr_scale"] for g in optimizer.param_groups]
         n_mtp = sum(p.numel() for p in mtp_params)
         print(f"[P12] Muon (lr={lr}) + [P19] LLRD (decay={llrd_decay})")
-        print(f" {n_total:,} params, LR: {min(scales):.3f}
-        print(f"[P13] MTP ({mtp_heads} heads, {n_mtp:,} params
+        print(f" {n_total:,} params, LR: {min(scales):.3f}x -> {max(scales):.3f}x")
+        print(f"[P13] MTP ({mtp_heads} heads, {n_mtp:,} params -- IN optimizer)")

-    # P15
     extras["triage"] = TokenTriage(ema_decay=0.99, select_ratio=0.6, floor_weight=0.1)
     if verbose:
-        print(f"[P15] Token Triage (60%
+        print(f"[P15] Token Triage (60%->full, 40%->10%, applied to base+MTP)")

-    # P16
-    # FIX: Increase patience (100→200) and variance threshold (0.005→0.02)
-    # so the breaker doesn't fire during normal slow convergence.
-    # The old settings triggered bursts when loss was fluctuating ±0.07,
-    # which is normal for stochastic training at loss~10.
     extras["plateau"] = PlateauBreaker(patience=200, variance_threshold=0.02,
                                        lr_multiplier=2.0, burst_steps=50)
     if verbose:
-        print(f"[P16] Plateau Breaker (
+        print(f"[P16] Plateau Breaker (x2 burst, LLRD-aware save/restore)")

-    # P18
     extras["grokfast"] = GrokfastEMA(alpha=grokfast_alpha, lamb=grokfast_lambda)
     if verbose:
         n_1d = sum(p.numel() for p in model.parameters()
                    if p.requires_grad and (p.ndim < 2 or getattr(p, "_is_embed", False)))
-        print(f"[P18] Grokfast-EMA (
+        print(f"[P18] Grokfast-EMA (a={grokfast_alpha}, l={grokfast_lambda}, {n_1d:,} params -- 1D only)")

     if verbose:
-        print(f"[P17] Batch Metabolism (hard seq
+        print(f"[P17] Batch Metabolism (hard seq x2, easy x0.5)")
         print("=" * 65)

     return model, optimizer, scheduler, extras


 # ───────────────────────────────────────────────────────────
-# Training step
+# Training step
 # ───────────────────────────────────────────────────────────

 _nan_count = 0
@@ -410,29 +391,6 @@ def training_step(model, batch, optimizer, scheduler,
                   extras=None, grad_accum_steps=1, step=0,
                   max_grad_norm=1.0, autocast_dtype=None,
                   mtp_weight=0.1) -> float:
-    """
-    Data flow (verified cumulative):
-
-    forward(batch) → logits, hidden_states
-      │
-      ├─ per_token_loss = CE(logits, labels, reduction='none')  [B,T]
-      │
-      ├─ P17: seq_weights = sigmoid(z-score(per_seq_loss))  [B]
-      ├─ P15: tok_weights = triage(excess_loss)  [B,T]
-      ├─ combined = tok_weights × seq_weights  [B,T]
-      ├─ base_loss = weighted_mean(per_token_loss, combined)
-      │
-      ├─ P13: mtp_loss = MTP(hidden, labels, tok_weights) ← triage applied!
-      ├─ total_loss = base + 0.1 × mtp
-      │
-    backward(total_loss) → param.grad for ALL params (model + MTP heads)
-      │
-      ├─ P18: Grokfast amplifies grad on 1D params only (skip 2D/Muon)
-      │
-    optimizer.step() → P12 Muon (2D) + AdamW (1D), P19 LLRD scales per group
-      │
-      └─ P16: Plateau checks loss_val, burst preserves LLRD ratios
-    """
     global _nan_count
     extras = extras or {}
     is_accum = (step + 1) % grad_accum_steps == 0
@@ -456,27 +414,19 @@
                 ignore_index=-100, reduction="none"
             ).reshape(B, T)

-        # P17: Batch Metabolism - per-sequence difficulty weights
-        # FIX: With small effective batches (e.g. 8-32), seq_loss.std()
-        # is extremely noisy, causing wild oscillation in seq_weights.
-        # Clamp the z-scores and narrow the weight range from [0.5, 2.0]
-        # to [0.7, 1.4] to reduce gradient noise.
         with torch.no_grad():
             seq_loss = per_token.mean(dim=1)
             seq_mean = seq_loss.mean()
             seq_std = seq_loss.std().clamp(min=1e-6)
             z = ((seq_loss - seq_mean) / seq_std).clamp(-2.0, 2.0)
-            seq_weights = torch.sigmoid(z) * 0.7 + 0.7
+            seq_weights = torch.sigmoid(z) * 0.7 + 0.7

-        # P15: Token Triage - per-token informativeness weights
         triage = extras.get("triage")
         tok_weights = triage.compute_weights(per_token) if triage else torch.ones_like(per_token)

-        # Fuse: multiplicative composition
         combined = tok_weights * seq_weights.unsqueeze(1)
         base_loss = (per_token * combined).sum() / combined.sum()

-        # P13: MTP with Token Triage weights passed through
         mtp = extras.get("mtp")
         hidden = getattr(outputs, "hidden_states", None)
         if mtp is not None and hidden is not None:
@@ -489,19 +439,28 @@

     loss_val = total_loss.item()

-    # NaN guard
+    # NaN guard: skip step AND repair corrupted state
     if not math.isfinite(loss_val):
         _nan_count += 1
         optimizer.zero_grad(set_to_none=True)
+        with torch.no_grad():
+            for p in model.parameters():
+                if not torch.isfinite(p.data).all():
+                    p.data.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
+            for group in optimizer.param_groups:
+                for p in group["params"]:
+                    s = optimizer.state.get(p, {})
+                    for key in ("buf", "m", "v"):
+                        if key in s and not torch.isfinite(s[key]).all():
+                            s[key].nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
-        if _nan_count >= 5:
+        if _nan_count >= 10:
             for pg in optimizer.param_groups:
                 pg["lr"] *= 0.5
-            print(f" [NaN]
+            print(f"  [NaN] 10x -- LR halved to {optimizer.param_groups[0]['lr']:.2e}")
             _nan_count = 0
         return loss_val
     _nan_count = 0

-    # P16: Plateau Breaker (before backward, uses loss_val only)
     plateau = extras.get("plateau")
     if plateau:
         plateau.check_and_adjust(loss_val, optimizer, step)
@@ -510,21 +469,19 @@
     total_loss = total_loss / grad_accum_steps
     total_loss.backward()

-    # Sanitize
-    #
-
-
-
-            p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
+    # Sanitize gradients every step: BitLinear STE + complex recurrent
+    # layers produce occasional NaN gradients that MUST be caught immediately.
+    for p in model.parameters():
+        if p.grad is not None and not torch.isfinite(p.grad).all():
+            p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)

-    # P18: Grokfast on 1D params only (2D handled by Muon NS)
     grokfast = extras.get("grokfast")
     if grokfast:
         grokfast.apply(model)

     if is_accum:
         torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
-        optimizer.step()
+        optimizer.step()
         scheduler.step()
         optimizer.zero_grad(set_to_none=True)
         invalidate_all_caches(model)