Lgr54HFi committed on
Commit 6d5c935 · verified · 1 Parent(s): d83bada

Upload chimera/training/loops.py

Files changed (1)
  1. chimera/training/loops.py +13 -9
chimera/training/loops.py CHANGED
@@ -53,26 +53,30 @@ def train_standard_loop(args, model, config, loader, compute_loss, optimizer, us
 def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer):
     use_compile = getattr(args, "compile", False)
 
-    # FIX: Use args.lr instead of hardcoded 0.02.
-    # FIX: Use args.warmup instead of hardcoded 200.
-    # FIX: Reduce MTP heads from 3->2 to cut 51M params of overhead.
-    # FIX: Soften LLRD decay (0.85->0.92) so early layers still learn.
-    # FIX: Lower Grokfast lambda (2.0->1.0) to reduce gradient amplification noise.
+    # Muon needs higher LR than AdamW: NS orthogonalization normalizes the
+    # update direction, so LR controls step SIZE, not direction stability.
+    # 0.02 is the standard Muon LR; the CLI default 1.5e-3 was for AdamW.
+    # Warmup shortened: NS already provides early stability.
+    #
+    # MTP DISABLED (mtp_heads=0): lm_head (256->200073) costs 4x the entire
+    # 28-layer stack. Each MTP head doubles that. At loss=13 the model can't
+    # predict token+1, so token+2 is noise. Re-enable once loss < 5.
+    muon_lr = max(args.lr, 0.02)
+    muon_warmup = min(args.warmup, 100)
     model, optimizer, scheduler, extras = chimera_turbo.apply(
         model,
         max_steps=args.max_steps,
-        lr=args.lr,
+        lr=muon_lr,
         weight_decay=0.01,
-        warmup_steps=args.warmup,
+        warmup_steps=muon_warmup,
         use_compile=use_compile,
-        mtp_heads=2,
+        mtp_heads=0,
         llrd_decay=0.92,
         grokfast_alpha=0.98,
         grokfast_lambda=1.0,
     )
     model.train()
 
-    # Progressive looping
     loop_sched = ProgressiveLoopScheduler(args.max_steps, max_loops=3)
     cur_loops = 1
 
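The Muon comment in the added lines leans on how Newton-Schulz (NS) orthogonalization behaves. For reference, below is a minimal sketch of the quintic NS iteration used by the public Muon optimizer; the coefficients and the zeropower_via_newtonschulz name follow that reference implementation and are only assumed to resemble what chimera_turbo.apply wires up internally.

import torch

def zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Quintic Newton-Schulz iteration: pushes the singular values of G toward 1
    # while keeping its singular vectors, i.e. approximately orthogonalizes the
    # update matrix. Coefficients follow the reference Muon implementation.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.float()
    X = X / (X.norm() + 1e-7)  # keep the spectral norm near/below 1 so the iteration converges
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X

Because the orthogonalized update has roughly unit-scale singular values regardless of the raw gradient's magnitude, the learning rate only controls how far each step moves, which is why Muon's usual 0.02 sits an order of magnitude above an AdamW-style 1.5e-3.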
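The MTP figures in the comment can be sanity-checked from the numbers it quotes. Assuming an untied output projection of shape hidden_size x vocab_size (the 256 and 200073 in the "lm_head (256->200073)" note), one head costs about 51M parameters, matching the "51M params of overhead" in the removed FIX comment, and the previous mtp_heads=2 setting carried roughly twice that in extra output projections. A quick check, with all values taken from the comments above:

hidden_size = 256      # from "lm_head (256->200073)"
vocab_size = 200_073

params_per_head = hidden_size * vocab_size
print(f"{params_per_head / 1e6:.1f}M per output head")       # ~51.2M
print(f"{2 * params_per_head / 1e6:.1f}M for mtp_heads=2")   # ~102.4M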