Fix loss rebound: lower Muon LR (0.02→0.008), clamp ternary latents, steeper cosine decay
Root cause: Muon's Newton-Schulz (NS) orthogonalized update + momentum=0.95 pushed ternary latent
weights outside the STE clamp zone [-1, 1] after ~230 steps at LR=0.02. The
clamp-aware STE gradient is ZERO for weights outside [-1, 1], so those weights
become permanently dead (no gradient to recover them). This progressively
degraded model capacity, causing loss to rebound from 5.43 back to 6.48.
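The dead-weight mechanism can be illustrated with a minimal sketch of a clamp-aware STE in plain Python (assumed form of the quantizer; `ste_forward`/`ste_backward` are illustrative names, not this repo's API):

```python
# Sketch of a clamp-aware straight-through estimator (STE) for ternary weights.
# Forward: quantize the latent fp weight to {-1, 0, +1}.
# Backward: pass the upstream gradient through only where the latent weight
# lies inside [-1, 1]; outside that zone the gradient is exactly zero.

def ste_forward(w: float) -> float:
    """Quantize a latent weight to the nearest ternary value."""
    return float(max(-1, min(1, round(w))))

def ste_backward(w: float, upstream_grad: float) -> float:
    """Clamp-aware STE gradient: zero outside [-1, 1]."""
    return upstream_grad if -1.0 <= w <= 1.0 else 0.0

# A weight pushed to 1.3 by accumulated momentum receives no gradient:
# nothing in the loss can pull it back into the active zone, so it is dead.
print(ste_backward(0.7, 1.0))  # gradient flows
print(ste_backward(1.3, 1.0))  # zero gradient: dead weight
```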
Three fixes:
1. LR 0.02→0.008: Standard Muon LR is 0.02 for dense fp32 weights with
unbounded range. Ternary STE restricts the useful weight range to [-1,1]
and the gradient-active zone to the same interval. The per-step weight
perturbation must be proportionally smaller. 0.008 gives ~2.5x slower
convergence but prevents overshoot past the STE boundary.
2. Latent weight clamping to [-2, 2]: After every Muon 2D update, clamp
weights to [-2, 2]. This is a safety net — weights that drift past
±1 from accumulated momentum are pulled back into the gradient-active
zone. The ±2 bound (not ±1) allows slight overshoot that round() in
the STE forward still maps correctly to {-1, 0, +1}.
3. Cosine min_ratio 0.01→0.05: The old schedule kept LR near peak for
too long. With ternary weights, you want to reach a low-LR fine-tuning
regime faster. At 5% of peak (0.008 * 0.05 = 0.0004), the per-step
update is small enough to fine-tune within the ternary basin without
escaping it.
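The scaling argument in fix 1 can be made concrete with a back-of-envelope bound: under heavy-ball momentum beta, a run of perfectly aligned updates accumulates toward an effective per-step magnitude of lr / (1 - beta) (geometric-series limit). This deliberately ignores the NS normalization scale, so it is a worst-case caricature, not the exact Muon step size:

```python
# Worst-case accumulated per-step move under momentum (illustrative only).
beta = 0.95
for lr in (0.02, 0.008):
    effective = lr / (1 - beta)  # limit of lr * (1 + beta + beta^2 + ...)
    print(f"lr={lr}: worst-case per-step move ~ {effective:.2f} weight units")
# lr=0.02 allows ~0.4 weight units per step (20% of the full [-1, 1] range);
# lr=0.008 drops that to ~0.16, leaving far less room to overshoot the
# STE boundary before the gradient can react.
```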
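Fix 2's safety net amounts to one clamp after each optimizer step. A pure-Python sketch (in the actual PyTorch code this would be something like an in-place `p.data.clamp_(-2.0, 2.0)` over the Muon-managed 2D params after `optimizer.step()`; the helper name here is hypothetical):

```python
# Post-step safety net: pull drifted latent weights back toward the
# gradient-active zone. The +/-2 bound intentionally allows slight overshoot
# past +/-1, which the STE forward still quantizes to {-1, 0, +1}.

def clamp_latents(weights, lo=-2.0, hi=2.0):
    """Clamp every latent weight into [lo, hi]."""
    return [min(hi, max(lo, w)) for w in weights]

drifted = [-2.7, -0.4, 1.1, 2.3]
print(clamp_latents(drifted))  # [-2.0, -0.4, 1.1, 2.0]
```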
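Fix 3's floor can be checked numerically. A sketch of a cosine schedule with a `min_ratio` floor (assumed shape, warmup omitted; the project's actual scheduler may differ in detail):

```python
import math

def cosine_lr(step, total_steps, peak=0.008, min_ratio=0.05):
    """Cosine decay from peak down to min_ratio * peak (no warmup)."""
    floor = peak * min_ratio
    t = min(step, total_steps) / total_steps  # progress in [0, 1]
    return floor + 0.5 * (peak - floor) * (1.0 + math.cos(math.pi * t))

print(cosine_lr(0, 1000))     # ~ 0.008  (peak)
print(cosine_lr(1000, 1000))  # ~ 0.0004 (5% floor: 0.008 * 0.05)
```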
@@ -53,15 +53,12 @@ def train_standard_loop(args, model, config, loader, compute_loss, optimizer, us
 def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer):
     use_compile = getattr(args, "compile", False)
 
-    # Muon
-    #
-    #
-    #
-    #
-
-    # 28-layer stack. Each MTP head doubles that. At loss=13 the model can't
-    # predict token+1, so token+2 is noise. Re-enable once loss < 5.
-    muon_lr = max(args.lr, 0.02)
+    # Muon LR for ternary BitLinear: standard Muon uses 0.02 for dense fp32/bf16
+    # weights, but ternary STE has a much narrower useful weight range [-1, 1].
+    # The NS unit-orthogonal update + momentum accumulation causes overshoot
+    # past step ~230, pushing weights outside the STE clamp zone (zero gradient).
+    # Optimal for ternary: 0.008 peak with aggressive cosine decay.
+    muon_lr = 0.008
     muon_warmup = min(args.warmup, 100)
     model, optimizer, scheduler, extras = chimera_turbo.apply(
         model,