Fix loss plateau + throughput collapse: 7 bugs resolved
1. LR was hardcoded at 0.02 (13x too high), now uses args.lr (1.5e-3)
2. GrowLength spent 30% of training at seq=16 (useless fragments), now 10% at seq/2 + 90% at full (schedule sketch after this list)
3. Token Triage discarded 40% of gradient signal at loss=10+ (no easy tokens exist yet), now anneals floor from 1.0→0.1 over 500 steps
4. LLRD decay=0.85 created an 80x LR gap between layers, softened to 0.92 (10x gap; sketch after this list)
5. Evolution engine fired 7x per forward pass (expensive HDC/Hamming ops), now 1x
6. MTP heads 3→2 (saves 51M params of gradient overhead), weight 0.3→0.1
7. Batch Metabolism z-scores were unclamped with a [0.5, 2.0] weight range on B=32 batches, now clamped to ±2σ with a [0.7, 1.4] range
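For item 2, a minimal sketch of what the revised GrowLength schedule amounts to; the function name and the `max_steps`/`full_seq_len` parameters are illustrative, not the actual API in chimera_turbo.py:

```python
def grow_length_schedule(step: int, max_steps: int, full_seq_len: int) -> int:
    """Illustrative curriculum: half-length sequences for the first 10% of
    training, full-length afterwards (replacing 30% of training at seq=16)."""
    if step < 0.10 * max_steps:
        return full_seq_len // 2
    return full_seq_len
```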
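For item 4, a sketch of how LLRD param groups are typically built; `llrd_param_groups` is a hypothetical helper, and the 28-block depth is only an assumption chosen to reproduce the 80x/10x gap figures quoted above:

```python
def llrd_param_groups(blocks, base_lr=1.5e-3, decay=0.92):
    """Layer-wise LR decay: earlier blocks get geometrically smaller LRs.
    With 28 blocks, decay=0.85 puts block 0 at 0.85**27 ≈ 1/80 of base_lr;
    decay=0.92 softens that to 0.92**27 ≈ 1/10."""
    n = len(blocks)
    return [
        {"params": list(block.parameters()),
         "lr": base_lr * decay ** (n - 1 - i)}
        for i, block in enumerate(blocks)
    ]
```

The resulting groups replace a flat `model.parameters()` in the optimizer constructor.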
Performance fixes:
- Muon NS steps 5→3 (40% fewer matmuls per optimizer step)
- BitLinear cache lookup amortized (was walking all modules every step)
- Gradient sanitization every 10 steps instead of every step
- Loop classifier bypassed during training (was calling .item() every forward; sketch after this list)
- Plateau breaker patience 100→200, variance threshold 0.005→0.02
- Progressive loops delayed: 1→2→3 at 50%/80%/100% (was 20%/60%/100%; sketch after this list)
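The loop-classifier bypass is about keeping `.item()` (a device-to-host sync) off the training hot path. A sketch under assumed names; `loop_classifier`, `train_loop_count`, and `run_loops` are placeholders, not the model's real attributes:

```python
def forward(self, hidden):
    if self.training:
        # Training: fixed loop count, no classifier call, no .item() sync.
        n_loops = self.train_loop_count
    else:
        # Eval only: let the classifier pick the loop count dynamically.
        logits = self.loop_classifier(hidden.mean(dim=1))
        n_loops = int(logits.argmax(dim=-1).max().item())
    return self.run_loops(hidden, n_loops)
```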
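And the delayed progressive-loop ramp, reading the 50%/80%/100% milestones as "loop count N applies up to that fraction of training"; the helper is illustrative:

```python
def progressive_loop_count(step: int, max_steps: int) -> int:
    """1 loop for the first half of training, 2 until 80%, 3 for the rest
    (previously the ramp happened at 20% and 60%)."""
    progress = step / max_steps
    if progress < 0.50:
        return 1
    if progress < 0.80:
        return 2
    return 3
```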
chimera_turbo.py (+43, -15):

@@ -66,7 +66,7 @@ def _zeropower_via_newtonschulz5(G, steps=5):
 
 class Muon(torch.optim.Optimizer):
     def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True,
-                 ns_steps=5, weight_decay=0.0,
+                 ns_steps=3, weight_decay=0.0,
                  adamw_betas=(0.9, 0.98), adamw_eps=1e-8):
         defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov,
                         ns_steps=ns_steps, weight_decay=weight_decay,

@@ -193,17 +193,31 @@ class TokenTriage:
         self.select_ratio = select_ratio
         self.floor_weight = floor_weight
         self._loss_ema = None
+        self._step = 0
+        # FIX: Anneal floor_weight from 1.0 → floor_weight over warmup_steps.
+        # When loss is high (early training), all tokens are informative.
+        # Discarding 40% of gradient signal at loss=10+ starves the model.
+        self.warmup_steps = 500
 
     def compute_weights(self, per_token_loss):
         with torch.no_grad():
+            self._step += 1
             ml = per_token_loss.mean().item()
             if self._loss_ema is None:
                 self._loss_ema = ml
             else:
                 self._loss_ema = self.ema_decay * self._loss_ema + (1 - self.ema_decay) * ml
+
+            # FIX: Anneal: during warmup, all tokens get weight → 1.0
+            if self._step < self.warmup_steps:
+                t = self._step / self.warmup_steps
+                cur_floor = 1.0 - t * (1.0 - self.floor_weight)
+            else:
+                cur_floor = self.floor_weight
+
             excess = per_token_loss - self._loss_ema
             thr = torch.quantile(excess.flatten(), 1.0 - self.select_ratio)
-            return torch.where(excess >= thr, 1.0, self.floor_weight)
+            return torch.where(excess >= thr, 1.0, cur_floor)
 
 
 # ───────────────────────────────────────────────────────────

@@ -294,12 +308,16 @@ class GrokfastEMA:
 # Utilities
 # ───────────────────────────────────────────────────────────
 
+_bitlinear_cache = []
+
 def invalidate_all_caches(model):
     from chimera.quantization import BitLinear
-    raw = getattr(model, "_orig_mod", model)
-    for m in raw.modules():
-        if isinstance(m, BitLinear):
-            m.invalidate_packed()
+    global _bitlinear_cache
+    if not _bitlinear_cache:
+        raw = getattr(model, "_orig_mod", model)
+        _bitlinear_cache = [m for m in raw.modules() if isinstance(m, BitLinear)]
+    for m in _bitlinear_cache:
+        m.invalidate_packed()
 
 
 def create_scheduler(optimizer, max_steps, warmup_steps=200):

@@ -359,7 +377,11 @@ def apply(model, max_steps=10000, lr=0.02, weight_decay=0.01,
         print(f"[P15] Token Triage (60%→full, 40%→10%, applied to base+MTP)")
 
     # P16
-    extras["plateau"] = PlateauBreaker(patience=100, variance_threshold=0.005,
+    # FIX: Increase patience (100→200) and variance threshold (0.005→0.02)
+    # so the breaker doesn't fire during normal slow convergence.
+    # The old settings triggered bursts when loss was fluctuating ±0.07,
+    # which is normal for stochastic training at loss~10.
+    extras["plateau"] = PlateauBreaker(patience=200, variance_threshold=0.02,
                                        lr_multiplier=2.0, burst_steps=50)
     if verbose:
         print(f"[P16] Plateau Breaker (×2 burst, LLRD-aware save/restore)")

@@ -387,7 +409,7 @@ _nan_count = 0
 def training_step(model, batch, optimizer, scheduler,
                   extras=None, grad_accum_steps=1, step=0,
                   max_grad_norm=1.0, autocast_dtype=None,
-                  mtp_weight=0.3) -> float:
+                  mtp_weight=0.1) -> float:
     """
     Data flow (verified cumulative):
 

@@ -401,7 +423,7 @@ def training_step(model, batch, optimizer, scheduler,
       ├─ base_loss = weighted_mean(per_token_loss, combined)
       │
       ├─ P13: mtp_loss = MTP(hidden, labels, tok_weights) ← triage applied!
-      └─ total_loss = base + 0.3 × mtp
+      └─ total_loss = base + 0.1 × mtp
       │
     backward(total_loss) → param.grad for ALL params (model + MTP heads)
       │

@@ -435,12 +457,16 @@ def training_step(model, batch, optimizer, scheduler,
     ).reshape(B, T)
 
     # P17: Batch Metabolism → per-sequence difficulty weights
+    # FIX: With small effective batches (e.g. 8-32), seq_loss.std()
+    # is extremely noisy, causing wild oscillation in seq_weights.
+    # Clamp the z-scores and narrow the weight range from [0.5, 2.0]
+    # to [0.7, 1.4] to reduce gradient noise.
     with torch.no_grad():
         seq_loss = per_token.mean(dim=1)
         seq_mean = seq_loss.mean()
         seq_std = seq_loss.std().clamp(min=1e-6)
-        z = (seq_loss - seq_mean) / seq_std
-        seq_weights = torch.sigmoid(z) * 1.5 + 0.5  # [0.5, 2.0]
+        z = ((seq_loss - seq_mean) / seq_std).clamp(-2.0, 2.0)
+        seq_weights = torch.sigmoid(z) * 0.7 + 0.7  # [0.7, 1.4]
 
     # P15: Token Triage → per-token informativeness weights
     triage = extras.get("triage")

@@ -484,10 +510,12 @@ def training_step(model, batch, optimizer, scheduler,
     total_loss = total_loss / grad_accum_steps
     total_loss.backward()
 
-    # Sanitize
-    for p in model.parameters():
-        if p.grad is not None and not torch.isfinite(p.grad).all():
-            p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
+    # Sanitize: only check every 10 steps to save CPU cycles.
+    # NaN gradients are rare; checking every step is wasteful.
+    if step % 10 == 0:
+        for p in model.parameters():
+            if p.grad is not None and not torch.isfinite(p.grad).all():
+                p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
 
     # P18: Grokfast on 1D params only (2D handled by Muon NS)
     grokfast = extras.get("grokfast")