Lgr54HFi committed
Commit 3859a82 · verified · Parent(s): 05566cc

feat: loops.py v11 — aligned with GENESIS engine, no distiller overhead

Files changed (1): chimera/training/loops.py (+19 -32)
chimera/training/loops.py CHANGED
@@ -9,20 +9,17 @@ import torch
 
 import chimera_turbo
 
-from .common import cosine_lr, save_final_checkpoint, save_training_checkpoint
+from .common import save_final_checkpoint, save_training_checkpoint
 from .hyper import ProgressiveLoopScheduler
 
 
 def train_fast_loop(args, model, config, loader, compute_loss) -> str:
     optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98))
     os.makedirs(args.output_dir, exist_ok=True)
-    log_f = open(os.path.join(args.output_dir, "log.jsonl"), "w", encoding="utf-8")
     model.train()
     step, total_loss, best_loss, toks = 0, 0.0, float("inf"), 0
     t0 = time.time()
     data_iter = iter(loader)
-    warmup = min(args.warmup, max(1, args.max_steps // 10))
-
     while step < args.max_steps:
         try:
             batch = next(data_iter)
@@ -33,9 +30,6 @@ def train_fast_loop(args, model, config, loader, compute_loss) -> str:
         loss.backward()
         total_loss += float(loss.item())
         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-        cur_lr = cosine_lr(step, warmup, args.max_steps, args.lr, args.lr * 0.1)
-        for pg in optimizer.param_groups:
-            pg["lr"] = cur_lr
         optimizer.step()
         optimizer.zero_grad(set_to_none=True)
         toks += batch["input_ids"].numel()
@@ -53,33 +47,29 @@ def train_fast_loop(args, model, config, loader, compute_loss) -> str:
 
 
 def train_standard_loop(args, model, config, loader, compute_loss, optimizer, use_mezo):
-    # Legacy — unchanged
     pass
 
 
 def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer):
     use_compile = getattr(args, "compile", False)
 
-    # Apply all paradigms: Muon + MTP + EMA Distillation
     model, optimizer, scheduler, extras = chimera_turbo.apply(
         model,
         max_steps=args.max_steps,
-        lr=0.02,  # Muon default LR (10× higher than AdamW, paper-standard)
+        lr=0.02,
         weight_decay=0.01,
-        warmup_steps=200,  # Short warmup for fast ramp
+        warmup_steps=200,
         use_compile=use_compile,
-        use_muon=True,
-        use_mtp=True,
-        use_distill=True,
-        mtp_heads=3,  # Predict next 3 tokens
+        mtp_heads=3,
+        llrd_decay=0.85,
+        grokfast_alpha=0.98,
+        grokfast_lambda=2.0,
     )
     model.train()
 
     # Progressive looping
     loop_sched = ProgressiveLoopScheduler(args.max_steps, max_loops=3)
     cur_loops = 1
-    print(f"[LOOP] Progressive looping: 1→2→3 over {args.max_steps} steps")
-    print(f"[P5] Train mode: BitLinear STE (clamp-aware, NaN-safe)")
 
     use_bf16 = bool(args.bf16)
 
@@ -90,16 +80,14 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
     cur_seq = initial_seq
     eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
     loader = torch.utils.data.DataLoader(
-        dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True
-    )
+        dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
     data_iter = iter(loader)
 
     print(f"\n{'=' * 65}")
-    print(f"Training eff_batch={eff_batch} seq={cur_seq} loops={cur_loops}")
+    print(f"Training batch={eff_batch} seq={cur_seq} loops={cur_loops}")
     print(f"{'=' * 65}\n")
 
     while step < args.max_steps:
-        # ── Seq length scheduling ──
         if grow:
             ns = grow.get_seq_len(step)
             if ns != cur_seq:
@@ -107,19 +95,17 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
                 dataset.set_seq_len(cur_seq)
                 eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
                 loader = torch.utils.data.DataLoader(
-                    dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True
-                )
+                    dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
                 data_iter = iter(loader)
                 print(f" [P1] seq → {cur_seq} batch → {eff_batch}")
 
-        # ── Loop scheduling ──
         new_loops = loop_sched.get_loops(step)
         if new_loops != cur_loops:
             cur_loops = new_loops
             raw = getattr(model, "_orig_mod", model)
             if hasattr(raw, "loop_controller"):
                 raw.loop_controller.loop_default = cur_loops
-                print(f" [LOOP] loops → {cur_loops}")
+            print(f" [LOOP] → {cur_loops}")
 
         if unfreezer:
             unfreezer.update(step)
@@ -132,12 +118,11 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
 
         loss_val = chimera_turbo.training_step(
             model, batch, optimizer, scheduler,
-            extras=extras,
-            grad_accum_steps=1, step=step,
+            extras=extras, grad_accum_steps=1, step=step,
             autocast_dtype=torch.bfloat16 if use_bf16 else None,
         )
 
-        cur_lr = optimizer.param_groups[0]["lr"]
+        cur_lr = optimizer.param_groups[0]["lr"] * optimizer.param_groups[0].get("lr_scale", 1.0)
         if math.isfinite(loss_val):
             total_loss += loss_val
             valid_count += 1
@@ -149,12 +134,12 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
             avg = total_loss / max(1, valid_count)
             ppl = math.exp(min(avg, 20)) if math.isfinite(avg) else float("nan")
             tps = toks / dt if dt > 0 else 0
-            eta = (args.max_steps - step) / (step / dt) / 3600 if dt > 0 else 0
+            eta = (args.max_steps - step) / max(1, step) * (time.time() - t0) / 3600 if step > 0 else 0
             log_f.write(json.dumps({
                 "step": step, "loss": round(avg, 4) if math.isfinite(avg) else None,
                 "ppl": round(ppl, 2) if math.isfinite(ppl) else None,
                 "lr": round(cur_lr, 6), "tok/s": round(tps),
-                "seq_len": cur_seq, "loops": cur_loops,
+                "seq": cur_seq, "loops": cur_loops,
             }) + "\n")
             log_f.flush()
             print(
@@ -165,10 +150,12 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
             total_loss, valid_count, toks, t0 = 0.0, 0, 0, time.time()
 
         if step % args.save_every == 0:
-            d = save_training_checkpoint(model, config, step, os.path.join(args.output_dir, f"ckpt-{step}"))
+            d = save_training_checkpoint(model, config, step,
+                                         os.path.join(args.output_dir, f"ckpt-{step}"))
             print(f" [SAVE] {d}")
 
-    d = save_final_checkpoint(model, config, step, best_loss, os.path.join(args.output_dir, "final"))
+    d = save_final_checkpoint(model, config, step, best_loss,
+                              os.path.join(args.output_dir, "final"))
     log_f.close()
     print(f"\nDONE — best loss {best_loss:.4f} ppl {math.exp(min(best_loss, 20)):.2f}")
     return d
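The effective batch size recomputed at each sequence-length change in the diff above keeps the token budget per optimizer step roughly constant: eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq)), so eff_batch * cur_seq ≈ args.batch_size * args.seq_len. Worked example with batch_size=8 and seq_len=2048: at cur_seq=512, eff_batch = 8 * (2048 // 512) = 32, and each batch still carries 32 * 512 = 16384 tokens, the same as 8 * 2048 at full length.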
 
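The new grokfast_alpha=0.98 / grokfast_lambda=2.0 arguments point at Grokfast-style gradient filtering (Lee et al., 2024): keep an exponential moving average of each parameter's gradient and amplify that slow component before the optimizer step. chimera_turbo's actual implementation is not part of this diff; below is a minimal sketch of the standard recipe, with grokfast_ema_filter and grads_ema as illustrative names only.

    import torch

    def grokfast_ema_filter(model, grads_ema, alpha=0.98, lam=2.0):
        # After loss.backward(): h = alpha*h + (1-alpha)*g per parameter,
        # then g <- g + lam*h, so the optimizer sees the amplified
        # low-frequency gradient component. Pass grads_ema={} on the
        # first call; the dict carries the EMA state across steps.
        for p in model.parameters():
            if p.grad is None:
                continue
            ema = grads_ema.get(p)
            if ema is None:
                ema = torch.zeros_like(p.grad)
            ema.mul_(alpha).add_(p.grad, alpha=1.0 - alpha)
            p.grad.add_(ema, alpha=lam)
            grads_ema[p] = ema
        return grads_ema

A filter like this runs between backward and optimizer.step(); with alpha=0.98 the EMA has an effective horizon of roughly 1 / (1 - alpha) = 50 steps.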
 
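The logging change multiplies the group learning rate by an optional lr_scale entry, which lines up with the new llrd_decay=0.85 argument: layer-wise LR decay typically puts each transformer block in its own param group with a scale that shrinks geometrically from the top block down, and the effective rate is lr * lr_scale, exactly what the log now records. How chimera_turbo builds its groups is not shown in this diff; a hypothetical sketch (llrd_param_groups and blocks are illustrative names):

    def llrd_param_groups(blocks, decay=0.85):
        # One param group per block, shallowest first. The scheduler keeps
        # driving the shared "lr"; the fixed per-group "lr_scale" (1.0 for
        # the top block, decay for the next one down, then decay**2, ...)
        # is applied at step time. PyTorch preserves extra param-group keys,
        # so "lr_scale" rides along untouched.
        n = len(blocks)
        return [
            {"params": list(b.parameters()), "lr_scale": decay ** (n - 1 - i)}
            for i, b in enumerate(blocks)
        ]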
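ProgressiveLoopScheduler comes from .hyper and its implementation is not in this diff; all the training loop relies on is get_loops(step) ramping the recurrence depth 1 → 2 → 3 across training, as the removed print spelled out. A minimal sketch under the assumption of an equal-thirds schedule:

    class ProgressiveLoopScheduler:
        # Ramp loops 1 -> max_loops in equal fractions of training;
        # the real schedule in chimera/training/hyper.py may differ.
        def __init__(self, max_steps, max_loops=3):
            self.max_steps = max_steps
            self.max_loops = max_loops

        def get_loops(self, step):
            frac = min(1.0, step / max(1, self.max_steps))
            return min(self.max_loops, 1 + int(frac * self.max_loops))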
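The rewritten ETA line is a linear extrapolation, hours_remaining = (steps_remaining / steps_done) * hours_elapsed, now guarded against step == 0 rather than dt > 0. Taking the first logging window, where t0 is still the start of the run: with max_steps=10000, step=2500 and 1800 s elapsed, eta = 7500 / 2500 * 1800 / 3600 = 1.5 hours. Note that t0 is reset at every logging interval while step keeps accumulating, so later estimates divide the most recent window's wall-clock time by the total step count.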