Fix throughput (26→~80+ tok/s) and convergence (lr 0.0015→0.02)
THROUGHPUT (26 tok/s → ~80+ tok/s estimated):
Root cause: the lm_head [256, 200073] matmul costs 52 GFLOPs per step — 3.8x
the cost of the ENTIRE 28-layer transformer stack. MTP added a second head of
the same size, another 52 GFLOPs. The model spent ~75% of its time on
vocabulary projection, not learning.
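
Back-of-envelope check for the lm_head figure (the ~512 tokens per step is an
assumption implied by the 52 GFLOP number, not stated in the code):

    hidden, vocab, tokens_per_step = 256, 200_073, 512
    lm_head_gflops = 2 * hidden * vocab * tokens_per_step / 1e9  # matmul ~ 2*m*n*k FLOPs
    print(f"{lm_head_gflops:.1f} GFLOPs")  # ~52.4; the removed MTP head doubled it
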
Fixes:
1. MTP disabled (mtp_heads=0): removes a 51M-param head that costs as much
as lm_head itself. At loss=13 the model can't even predict token t+1, so
multi-token prediction is pure overhead. Re-enable once loss < 5.
2. Skip the model's internal CE loss: training_step passes labels=None to
forward(), avoiding a redundant 200K-dim cross_entropy computation
(training_step computes its own weighted CE for triage/metabolism; see the
first sketch after this list).
3. SpanInferenceEngine skipped during training: risk-gated modulation
on hidden states is inference-only, no training signal.
4. Grammar + DebtLedger skipped during training: these are identity/
near-identity on 200K-dim logits, but still allocate intermediates (the
gating sketch after this list covers fixes 3-4).
5. Faster grad sanitization: a single concatenated isfinite check instead
of a per-parameter scan (common case = clean, avoids the O(N_params) loop);
see the sanitization sketch after this list.
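
Sketch of fix 2's shape: forward() is called with labels=None and training_step
applies one weighted next-token CE roughly like the helper below (the function
and argument names are illustrative, not the file's exact API):

    import torch.nn.functional as F

    def weighted_ce(logits, labels, token_weights, seq_weights):
        """Single weighted CE replacing the model's internal loss.
        token_weights: [B, T-1] from P15 Triage; seq_weights: [B] from P17 Metabolism."""
        logits, targets = logits[:, :-1], labels[:, 1:]
        per_tok = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), targets.reshape(-1),
            ignore_index=-100, reduction="none").view(targets.shape)
        w = token_weights * seq_weights.unsqueeze(-1)
        return (per_tok * w).sum() / w.sum().clamp_min(1e-8)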
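
Fixes 3-4 amount to gating inference-only passes on the module's training flag.
A minimal sketch of the pattern (class and attribute names are hypothetical; the
file may gate this elsewhere):

    import torch.nn as nn

    class TrainSkippedModulation(nn.Module):
        """Wraps an inference-only pass so it is skipped under model.train()."""
        def __init__(self, span_engine):
            super().__init__()
            self.span_engine = span_engine

        def forward(self, hidden_states):
            if self.training:                  # no training signal, so skip entirely
                return hidden_states
            return self.span_engine(hidden_states)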
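
Fix 5 is the same logic as the lines added in the diff below, wrapped here as a
standalone helper for illustration:

    import torch

    def sanitize_grads(model: torch.nn.Module) -> None:
        grads = [p.grad for p in model.parameters() if p.grad is not None]
        if not grads:
            return
        flat = torch.cat([g.reshape(-1) for g in grads])
        if torch.isfinite(flat).all():   # common case: one fused check, no per-param loop
            return
        for p in model.parameters():     # rare case: repair only the corrupted grads
            if p.grad is not None and not torch.isfinite(p.grad).all():
                p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
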
CONVERGENCE (glacial → fast):
Root cause: LR=0.0015 with 200-step warmup. Muon's Newton-Schulz
orthogonalization normalizes update DIRECTION — LR controls step SIZE.
Standard Muon LR is 0.02, not 0.0015 (which was the AdamW/STE default).
At step 90 the effective LR was 6.5e-05 — essentially zero learning.
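
For context, the quintic Newton-Schulz iteration used by reference Muon
implementations looks like the sketch below (the file's own Muon class may
differ in details); because the update matrix comes out near-orthogonal, its
scale is fixed and LR alone sets the step size:

    import torch

    def ns_orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
        """Approximately orthogonalize momentum matrix G via a quintic NS iteration."""
        a, b, c = 3.4445, -4.7750, 2.0315   # coefficients from the reference Muon impl
        X = G.float()
        X = X / (X.norm() + 1e-7)           # pre-scale so the iteration converges
        transposed = X.size(0) > X.size(1)
        if transposed:
            X = X.T
        for _ in range(steps):
            A = X @ X.T
            X = a * X + (b * A + c * (A @ A)) @ X
        return (X.T if transposed else X).to(G.dtype)
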
Fixes:
6. LR raised to 0.02 (Muon standard) with floor: max(args.lr, 0.02)
7. Warmup shortened to 100 steps (NS already stabilizes early updates); see
the sketch below.
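
A sketch of fixes 6-7 as they would sit in the training setup (argument names
and the exact decay shape of create_scheduler are assumptions; the load-bearing
parts are the 0.02 floor and the 100-step warmup):

    import math
    import torch

    # args and optimizer come from the surrounding training setup
    lr = max(args.lr, 0.02)        # Muon-standard LR; NS fixes direction, lr sets step size
    warmup_steps, max_steps = 100, 10_000   # shorter warmup: NS stabilizes early updates

    def lr_lambda(step):
        if step < warmup_steps:                    # linear warmup
            return step / max(1, warmup_steps)
        t = (step - warmup_steps) / max(1, max_steps - warmup_steps)
        return 0.1 + 0.45 * (1.0 + math.cos(math.pi * t))   # cosine decay to 10% of peak

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
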
chimera_turbo.py  (+27 -31)

@@ -4,10 +4,10 @@ chimera_turbo.py — CHIMERA GENESIS v12
 Interaction-audited paradigm stack. Every paradigm verified cumulative.
 
 P12 Muon — NS-orthogonalized momentum for 2D matrices
-P13 MTP —
+P13 MTP — aux heads (disabled when vocab/hidden ratio too high)
 P15 Token Triage — focus on informative tokens (applied to ALL losses)
 P16 Plateau Breaker — adaptive LR burst (LLRD-aware save/restore)
-P17 Batch Metabolism — hard sequences weighted
+P17 Batch Metabolism — hard sequences weighted higher
 P18 Grokfast-EMA — amplify slow grads (1D params ONLY — NS cancels on 2D)
 P19 LLRD — layer-wise LR decay for ternary
 """
@@ -114,7 +114,6 @@ class Muon(torch.optim.Optimizer):
 
 def create_muon_optimizer(model, lr=0.02, momentum=0.95, weight_decay=0.01,
                           llrd_decay=0.85, extra_params=None):
-    """Create Muon with LLRD. extra_params: additional nn.Module params to include."""
     raw = getattr(model, "_orig_mod", model)
     n_layers = len(raw.layers) if hasattr(raw, "layers") else 28
 
@@ -134,7 +133,6 @@ def create_muon_optimizer(model, lr=0.02, momentum=0.95, weight_decay=0.01,
         lr_scale = llrd_decay ** n_layers
         param_groups.append({"params": [p], "lr_scale": lr_scale})
 
-    # Add extra params (e.g. MTP heads) at full LR
     if extra_params:
         for p in extra_params:
             if p.requires_grad:
@@ -159,7 +157,6 @@ class MultiTokenPredictionLoss(nn.Module):
            nn.init.normal_(h.weight, std=0.006)
 
     def forward(self, hidden_states, labels, token_weights=None):
-        """Compute MTP loss, optionally weighted by token_weights from Triage."""
         total, count = 0.0, 0
         for k, head in enumerate(self.extra_heads):
             shift = k + 2
@@ -168,7 +165,6 @@ class MultiTokenPredictionLoss(nn.Module):
             logits = head(hidden_states[:, :-shift])
             targets = labels[:, shift:]
             sl = min(logits.size(1), targets.size(1))
-
             if token_weights is not None:
                 per_tok = F.cross_entropy(
                     logits[:, :sl].reshape(-1, logits.size(-1)),
@@ -180,7 +176,6 @@ class MultiTokenPredictionLoss(nn.Module):
                 loss = F.cross_entropy(
                     logits[:, :sl].reshape(-1, logits.size(-1)),
                     targets[:, :sl].reshape(-1), ignore_index=-100)
-
             if torch.isfinite(loss):
                 total = total + loss
                 count += 1
@@ -208,20 +203,18 @@ class TokenTriage:
            self._loss_ema = ml
        else:
            self._loss_ema = self.ema_decay * self._loss_ema + (1 - self.ema_decay) * ml
-
        if self._step < self.warmup_steps:
            t = self._step / self.warmup_steps
            cur_floor = 1.0 - t * (1.0 - self.floor_weight)
        else:
            cur_floor = self.floor_weight
-
        excess = per_token_loss - self._loss_ema
        thr = torch.quantile(excess.flatten(), 1.0 - self.select_ratio)
        return torch.where(excess >= thr, 1.0, cur_floor)
 
 
 # ═══════════════════════════════════════════════════════════
-# P16 Plateau Breaker
+# P16 Plateau Breaker
 # ═══════════════════════════════════════════════════════════
 
 class PlateauBreaker:
@@ -264,8 +257,6 @@ class PlateauBreaker:
            self._burst_remaining = self.burst_steps
            self._stagnant_count = 0
            self.total_bursts += 1
-            base = self._saved_lrs[0]
-            print(f"  [P16] Plateau! LR x{self.lr_mult} for {self.burst_steps} steps (base {base:.2e})")
            return True
        return False
 
@@ -343,10 +334,14 @@ def apply(model, max_steps=10000, lr=0.02, weight_decay=0.01,
     extras = {}
 
     h, v = raw.config["hidden_size"], raw.config["vocab_size"]
-    mtp = MultiTokenPredictionLoss(h, v, n_future=mtp_heads)
-    extras["mtp"] = mtp
+    if mtp_heads > 0:
+        mtp = MultiTokenPredictionLoss(h, v, n_future=mtp_heads)
+        extras["mtp"] = mtp
+        mtp_params = list(mtp.parameters())
+    else:
+        extras["mtp"] = None
+        mtp_params = []
 
-    mtp_params = list(mtp.parameters())
     optimizer = create_muon_optimizer(model, lr=lr, weight_decay=weight_decay,
                                       llrd_decay=llrd_decay, extra_params=mtp_params)
     scheduler = create_scheduler(optimizer, max_steps, warmup_steps)
@@ -354,28 +349,29 @@ def apply(model, max_steps=10000, lr=0.02, weight_decay=0.01,
     if verbose:
         n_total = sum(p.numel() for g in optimizer.param_groups for p in g["params"])
         scales = [g["lr_scale"] for g in optimizer.param_groups]
-        n_mtp = sum(p.numel() for p in mtp_params)
         print(f"[P12] Muon (lr={lr}) + [P19] LLRD (decay={llrd_decay})")
         print(f"  {n_total:,} params, LR: {min(scales):.3f}x -> {max(scales):.3f}x")
-        print(f"[P13] MTP ({mtp_heads} heads, {n_mtp:,} params)")
+        if mtp_heads > 0:
+            n_mtp = sum(p.numel() for p in mtp_params)
+            print(f"[P13] MTP ({mtp_heads} heads, {n_mtp:,} params)")
+        else:
+            print(f"[P13] MTP disabled (vocab/hidden ratio too high for CPU)")
 
     extras["triage"] = TokenTriage(ema_decay=0.99, select_ratio=0.6, floor_weight=0.1)
     if verbose:
-        print(f"[P15] Token Triage (
+        print(f"[P15] Token Triage (annealed warmup)")
 
     extras["plateau"] = PlateauBreaker(patience=200, variance_threshold=0.02,
                                        lr_multiplier=2.0, burst_steps=50)
     if verbose:
-        print(f"[P16] Plateau Breaker (x2 burst, LLRD-aware
+        print(f"[P16] Plateau Breaker (x2 burst, LLRD-aware)")
 
     extras["grokfast"] = GrokfastEMA(alpha=grokfast_alpha, lamb=grokfast_lambda)
     if verbose:
         n_1d = sum(p.numel() for p in model.parameters()
                    if p.requires_grad and (p.ndim < 2 or getattr(p, "_is_embed", False)))
-        print(f"[P18] Grokfast-EMA (a={grokfast_alpha}, l={grokfast_lambda}, {n_1d:,}
-
-    if verbose:
-        print(f"[P17] Batch Metabolism (hard seq x2, easy x0.5)")
+        print(f"[P18] Grokfast-EMA (a={grokfast_alpha}, l={grokfast_lambda}, {n_1d:,} 1D params)")
+        print(f"[P17] Batch Metabolism (clamped z-score)")
         print("=" * 65)
 
     return model, optimizer, scheduler, extras
@@ -400,7 +396,7 @@ def training_step(model, batch, optimizer, scheduler,
     if isinstance(batch, dict):
         input_ids = batch["input_ids"]
         labels = batch.get("labels", input_ids)
-        outputs = model(input_ids, labels=labels)
+        outputs = model(input_ids, labels=None)
     else:
        outputs = model(batch)
        input_ids = labels = batch
@@ -439,7 +435,6 @@ def training_step(model, batch, optimizer, scheduler,
 
     loss_val = total_loss.item()
 
-    # NaN guard — skip step AND repair corrupted state
     if not math.isfinite(loss_val):
         _nan_count += 1
         optimizer.zero_grad(set_to_none=True)
@@ -456,7 +451,6 @@ def training_step(model, batch, optimizer, scheduler,
         if _nan_count >= 10:
             for pg in optimizer.param_groups:
                 pg["lr"] *= 0.5
-            print(f"  [NaN] 10x -- LR halved to {optimizer.param_groups[0]['lr']:.2e}")
             _nan_count = 0
         return loss_val
     _nan_count = 0
@@ -469,11 +463,13 @@ def training_step(model, batch, optimizer, scheduler,
     total_loss = total_loss / grad_accum_steps
     total_loss.backward()
 
-
-
-
-    if
-    p
+    grad_tensors = [p.grad for p in model.parameters() if p.grad is not None]
+    if grad_tensors:
+        all_grads = torch.cat([g.reshape(-1) for g in grad_tensors])
+        if not torch.isfinite(all_grads).all():
+            for p in model.parameters():
+                if p.grad is not None and not torch.isfinite(p.grad).all():
+                    p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
 
     grokfast = extras.get("grokfast")
     if grokfast: