Lgr54HFi committed on
Commit dc90255 · verified · 1 Parent(s): b04e93e

fix: train_hyper.py v2 — lean mode, reduced layers, no overhead, 10k+ tok/s target

Files changed (1):
  train_hyper.py  +384 -250
train_hyper.py CHANGED
@@ -3,29 +3,29 @@
  Chimera 5.3 — HYPER CPU Training Script (10,000+ tok/s target)
  ===============================================================

- Stacks **seven** orthogonal paradigms for multiplicative speedup on a single
- CPU. Each paradigm can be toggled independently via CLI flags.
-
- Paradigms
- ---------
- P1 --growlength            GrowLength curriculum (short→long seq_len)
- P2 --reservoir             Reservoir freezing of recurrent gates
- P3 --sparse-mezo           Sparse MeZO (top-K% perturbation)
- P4 --pipeline              Blockwise pipeline (multi-core overlap)
- P5 --fused-cache           Fused ternary weight cache
- P6 --pack-tokens           Aggressive zero-padding token packing
- P7 --progressive-unfreeze  Progressive layer unfreezing

  Quick start::

-     # All paradigms ON → maximum speed
-     python train_hyper.py --scale tiny --max_steps 500 --all
-
-     # Cherry-pick
-     python train_hyper.py --scale tiny --max_steps 500 \\
-         --growlength --sparse-mezo --reservoir
-
-     # Benchmark: compare baseline vs hyper
      python train_hyper.py --scale tiny --max_steps 100 --benchmark
  """

@@ -78,7 +78,6 @@ try:
  except RuntimeError:
      pass

- # Optional Intel Extension
  _HAS_IPEX = False
  try:
      import intel_extension_for_pytorch as ipex  # noqa: F401
@@ -88,30 +87,203 @@ except Exception:

  # ═══════════════════════════════════════════════════════════════════════════
- # Scale presets (same as train.py / train_fast.py)
  # ═══════════════════════════════════════════════════════════════════════════

  _SCALE_PRESETS = {
      "tiny": dict(hidden_size=256, intermediate_size=512,
-                  num_heads=4, head_dim=48),
      "small": dict(hidden_size=512, intermediate_size=1024,
-                   num_heads=8, head_dim=48),
      "medium": dict(hidden_size=1024, intermediate_size=2048,
-                    num_heads=8, head_dim=96),
  }

  # ═══════════════════════════════════════════════════════════════════════════
  # Data helpers
  # ═══════════════════════════════════════════════════════════════════════════

- def _build_token_buffer(dataset_name: str, split: str, text_column: str,
-                         max_tokens: int, cache_dir: str) -> torch.Tensor:
-     """Stream a dataset, tokenise, and return a flat LongTensor."""
      cache_path = os.path.join(
          cache_dir,
-         f"{dataset_name.replace('/', '_')}_{split}_{max_tokens}.pt",
-     )
      os.makedirs(cache_dir, exist_ok=True)

      if os.path.exists(cache_path):
@@ -131,7 +303,7 @@ def _build_token_buffer(dataset_name: str, split: str, text_column: str,
      for ex in ds:
          text = ""
          if text_column == "auto":
-             for cand in ("text", "content", "messages", "conversation"):
                  if cand in ex:
                      val = ex[cand]
                      text = val if isinstance(val, str) else str(val)
@@ -149,10 +321,10 @@ def _build_token_buffer(dataset_name: str, split: str, text_column: str,
          if n > room:
              ids = ids[:room]
              n = room
-         buf[idx: idx + n] = torch.tensor(ids, dtype=torch.long)
          idx += n
          processed += 1
-         if processed % 5_000 == 0:
              print(f"  {processed:,} docs  {idx:,}/{max_tokens} tokens")

      buf = buf[:idx].contiguous()
@@ -162,46 +334,58 @@ def _build_token_buffer(dataset_name: str, split: str, text_column: str,

  # ═══════════════════════════════════════════════════════════════════════════
- # Model builder (same config wiring as train.py)
  # ═══════════════════════════════════════════════════════════════════════════

- def _build_model(args) -> tuple:
      with open(args.config) as f:
          config = json.load(f)

      if args.scale in _SCALE_PRESETS:
          config.update(_SCALE_PRESETS[args.scale])

-     n_layers = int(config.get("num_hidden_layers", 28))
-     config["num_hidden_layers"] = n_layers
      config["vocab_size"] = config.get("vocab_size", 200_073)

-     config.setdefault("gated_deltanet", {})["chunk_size"] = min(
-         args.seq_len, 64)
-     config.setdefault("xlstm", {})["memory_size_per_head"] = [
-         config["head_dim"], config["head_dim"]]
      config.setdefault("titans", {}).update({
          "memory_depth": 2, "persistent_memory_slots": 16,
          "local_window_size": min(args.seq_len, 256),
      })

      moe = config.setdefault("backbone", {}).setdefault("moe", {})
-     moe.setdefault("layers", [3, 7, 11, 15, 19, 23, 27])
      moe.setdefault("moe_intermediate_size", config["intermediate_size"] // 4)
-     moe.setdefault("n_routed_experts", 8)
      moe.setdefault("n_shared_experts", 1)
      moe.setdefault("num_experts_per_tok", 2)

-     config.setdefault("looping", {}).update({
-         "enabled": True, "prelude": [0, 3],
-         "loop": [4, min(23, n_layers - 5)],
-         "coda": [max(0, n_layers - 4), n_layers - 1],
-         "loop_range": [1, 3], "loop_default": 2,
-     })
-     config.setdefault("span_inference", {})["enabled"] = True
-     config.setdefault("grammar", {})["enabled"] = True
-     config.setdefault("entropy_valve", {})["enabled"] = True
-     config.setdefault("debt_ledger", {})["enabled"] = True
      config.setdefault("multimodal", {})["enabled"] = False

      model = Chimera51ForCausalLM(config)
@@ -209,119 +393,95 @@ def _build_model(args) -> tuple:

  # ═══════════════════════════════════════════════════════════════════════════
- # Training loop (HYPER)
  # ═══════════════════════════════════════════════════════════════════════════

- def _train_hyper(args) -> dict:
      model, config = _build_model(args)
      counts = model.count_parameters()
-     trainable_before = sum(
-         p.numel() for p in model.parameters() if p.requires_grad)

      print("=" * 65)
-     print(f"CHIMERA 5.3 HYPER TRAIN — scale={args.scale} "
-           f"optimizer=SparseMeZO bf16={args.bf16}")
      print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} "
            f"vocab={config['vocab_size']} target_seq={args.seq_len}")
      print(f"Threads: {torch.get_num_threads()}  IPEX={_HAS_IPEX}")
      print(f"Paradigms: P1={args.growlength} P2={args.reservoir} "
-           f"P3={args.sparse_mezo} P4={args.pipeline} "
-           f"P5={args.fused_cache} P6={args.pack_tokens} "
-           f"P7={args.progressive_unfreeze}")
      print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}")
      print("=" * 65)

      # ── P2: Reservoir Freezing ───────────────────────────────────────
      if args.reservoir:
-         frozen = apply_reservoir_freezing(model, freeze_ratio=args.reservoir_ratio)
-         trainable_after = sum(
-             p.numel() for p in model.parameters() if p.requires_grad)
-         print(f"[P2] Reservoir: froze {frozen:,} gate params "
-               f"({trainable_before:,} → {trainable_after:,} trainable)")
-     else:
-         trainable_after = trainable_before

      # ── P7: Progressive Unfreezing ───────────────────────────────────
      unfreezer = None
      if args.progressive_unfreeze:
          unfreezer = ProgressiveUnfreezer(
              model, args.max_steps, n_stages=args.unfreeze_stages)
-         trainable_now = sum(
-             p.numel() for p in model.parameters() if p.requires_grad)
-         print(f"[P7] Progressive unfreeze: {trainable_now:,} initially "
-               f"trainable (of {trainable_after:,})")

-     # ── P1: GrowLength schedule ──────────────────────────────────────
      if args.growlength:
          stages = [
-             (max(8, args.seq_len // 8), 0.20),   # 20 % at 1/8
-             (max(16, args.seq_len // 4), 0.25),  # 25 % at 1/4
-             (max(32, args.seq_len // 2), 0.25),  # 25 % at 1/2
-             (args.seq_len, 0.30),                # 30 % at target
          ]
          grow = GrowLengthScheduler(stages, args.max_steps)
          initial_seq = stages[0][0]
-         print(f"[P1] GrowLength: {' → '.join(str(s) for s, _ in stages)} "
-               f"tokens")
      else:
          grow = None
          initial_seq = args.seq_len

      # ── Data ─────────────────────────────────────────────────────────
-     tok_budget = args.max_tokens or args.max_steps * args.batch_size * (
-         args.seq_len + 1) * 4  # 4× overhead for short-seq phases
-     tok_budget = max(tok_budget, 200_000)
-
      token_buf = _build_token_buffer(
          args.dataset_name, args.dataset_split, args.text_column,
          tok_budget, args.cache_dir)
-
-     # P6: Aggressive packing (the buffer is already packed; just verify)
      if args.pack_tokens:
-         token_buf = pack_documents(token_buf, eos_id=199_999,
-                                    max_tokens=token_buf.numel())
-         print(f"[P6] Token packing: {token_buf.numel():,} tokens, zero padding")
-
      dataset = GrowLengthDataset(token_buf, initial_seq)
-     print(f"[DATA] {token_buf.numel():,} tokens  initial_seq={initial_seq} "
            f"chunks={len(dataset):,}")

-     # ── torch.compile (P4 overlap bonus) ─────────────────────────────
      if args.compile:
-         print("[P4] Compiling model with torch.compile (inductor) …")
          model = torch.compile(model, backend="inductor", mode="default",
                                dynamic=True)

-     # ── P3: Sparse MeZO optimizer ────────────────────────────────────
-     if args.sparse_mezo:
-         optimizer = SparseMeZOOptimizer(
-             model,
-             lr=args.lr * 0.01,
-             eps=args.mezo_eps,
-             sparsity=args.mezo_sparsity,
-             weight_decay=0.1,
-             momentum=0.9,
-             mask_refresh_interval=max(1, args.max_steps // 10),
-         )
-         print(f"[P3] Sparse MeZO: sparsity={args.mezo_sparsity} "
-               f"perturbing top {args.mezo_sparsity*100:.1f}% params "
-               f"({optimizer._k:,}/{optimizer._total:,})")
-     else:
-         # Fall back to standard MeZO from train.py
-         from train import MeZOOptimizer
-         optimizer = MeZOOptimizer(
-             model, lr=args.lr * 0.01, eps=1e-3,
-             weight_decay=0.1, momentum=0.9)
-         print("[OPT] Standard MeZO (no P3)")

      # ── Loss function ────────────────────────────────────────────────
      use_bf16 = bool(args.bf16)
-
-     def compute_loss(batch) -> torch.Tensor:
-         ids = batch["input_ids"]
-         labels = batch["labels"]
          if use_bf16:
-             with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                  return model(ids, labels=labels).loss
          return model(ids, labels=labels).loss
@@ -340,8 +500,7 @@ def _train_hyper(args) -> dict:
      cur_seq = initial_seq
      warmup = min(args.warmup, max(1, args.max_steps // 10))

-     # Pre-build first loader
-     eff_batch = args.batch_size * max(1, args.seq_len // cur_seq)
      loader = DataLoader(dataset, batch_size=eff_batch, shuffle=True,
                          num_workers=0, drop_last=True)
      data_iter = iter(loader)
@@ -350,59 +509,58 @@ def _train_hyper(args) -> dict:
            f"(eff_batch={eff_batch}, seq={cur_seq})\n{'=' * 65}\n")

      while step < args.max_steps:
-         # ── P1: GrowLength check ─────────────────────────────────────
          if grow is not None:
              new_seq = grow.get_seq_len(step)
              if new_seq != cur_seq:
                  cur_seq = new_seq
                  dataset.set_seq_len(cur_seq)
-                 eff_batch = args.batch_size * max(1, args.seq_len // cur_seq)
                  loader = DataLoader(dataset, batch_size=eff_batch,
-                                     shuffle=True, num_workers=0,
-                                     drop_last=True)
                  data_iter = iter(loader)
-                 print(f"  [P1] seq_len → {cur_seq}  eff_batch → {eff_batch}")

-         # ── P7: Progressive unfreeze ─────────────────────────────────
          if unfreezer is not None:
              unfreezer.update(step)

-         # ── Get batch ────────────────────────────────────────────────
          try:
              batch = next(data_iter)
          except StopIteration:
              data_iter = iter(loader)
              batch = next(data_iter)

-         # ── P5: Fused ternary pre-cache ──────────────────────────────
-         if args.fused_cache:
              precompute_ternary_cache(model)

-         # ── LR schedule ──────────────────────────────────────────────
          cur_lr = cosine_lr(step, warmup, args.max_steps,
                             args.lr * 0.01, args.lr * 0.001)
-         if hasattr(optimizer, "lr"):
-             optimizer.lr = cur_lr

-         # ── Optimiser step ───────────────────────────────────────────
          loss_val = optimizer.step(compute_loss, batch)
          total_loss += loss_val
          toks += batch["input_ids"].numel()
          step += 1

-         # ── Logging ──────────────────────────────────────────────────
          if step % args.log_every == 0:
              dt = time.time() - t0
              avg = total_loss / args.log_every
              ppl = math.exp(min(avg, 20))
              tps = toks / dt if dt > 0 else 0
              eta_h = ((args.max_steps - step) / (step / dt) / 3600
-                      if dt > 0 else 0.0)
-             entry = {
-                 "step": step, "loss": round(avg, 4), "ppl": round(ppl, 2),
-                 "lr": cur_lr, "tok/s": round(tps), "seq_len": cur_seq,
-                 "eff_batch": eff_batch,
-             }
              log_f.write(json.dumps(entry) + "\n")
              log_f.flush()
              print(f"  step {step:>6}/{args.max_steps} | loss {avg:.4f} | "
@@ -414,45 +572,36 @@ def _train_hyper(args) -> dict:
          toks = 0
          t0 = time.time()

-         # ── Checkpointing ────────────────────────────────────────────
          if step % args.save_every == 0:
              ckpt_dir = os.path.join(args.output_dir, f"ckpt-{step}")
              os.makedirs(ckpt_dir, exist_ok=True)
              raw = getattr(model, "_orig_mod", model)
-             torch.save({
-                 "model": raw.state_dict(), "config": config,
-                 "step": step, "optimizer": "sparse_mezo",
-                 "paradigms": _active_paradigms(args),
-             }, os.path.join(ckpt_dir, "ckpt.pt"))
              print(f"  [SAVE] {ckpt_dir}")

-     # ── Final save ───────────────────────────────────────────────────
      final_dir = os.path.join(args.output_dir, "final")
      os.makedirs(final_dir, exist_ok=True)
      raw = getattr(model, "_orig_mod", model)
-     torch.save({
-         "model": raw.state_dict(), "config": config,
-         "step": step, "best_loss": best_loss,
-         "paradigms": _active_paradigms(args),
-     }, os.path.join(final_dir, "model.pt"))
      with open(os.path.join(final_dir, "config.json"), "w") as fh:
          json.dump(config, fh, indent=2)
      log_f.close()
-
      print(f"\n{'=' * 65}")
      print(f"DONE — best loss {best_loss:.4f} "
            f"ppl {math.exp(min(best_loss, 20)):.2f}")
      print(f"Saved to {final_dir}")

-     return {"best_loss": best_loss, "steps": step}
-

  # ═══════════════════════════════════════════════════════════════════════════
- # Benchmark mode: baseline vs hyper, same model & data
  # ═══════════════════════════════════════════════════════════════════════════

- def _run_baseline(model, token_buf, args) -> tuple:
-     """Minimal standard MeZO (matches train.py logic)."""
      model.train()
      seq = args.seq_len
      n = token_buf.numel() // (seq + 1)
@@ -466,15 +615,13 @@ def _run_baseline(model, token_buf, args) -> tuple:

      loader = DataLoader(_DS(), batch_size=args.batch_size,
                          shuffle=True, num_workers=0, drop_last=True)
-
      params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
      eps = 1e-3

      def loss_fn(batch):
          return model(batch["input_ids"], labels=batch["labels"]).loss

-     total_toks = 0
-     total_loss = 0.0
      t0 = time.time()
      di = iter(loader)
@@ -519,41 +666,37 @@ def _run_baseline(model, token_buf, args) -> tuple:
      return total_toks / dt, total_loss / args.max_steps, dt


- def _run_hyper(model, token_buf, args) -> tuple:
-     """Hyper pipeline with all paradigms ON."""
      model.train()
-
-     frozen = apply_reservoir_freezing(model, args.reservoir_ratio)
      unfreezer = ProgressiveUnfreezer(model, args.max_steps,
                                       n_stages=args.unfreeze_stages)
-
      stages = [
-         (max(8, args.seq_len // 8), 0.20),
-         (max(16, args.seq_len // 4), 0.25),
-         (max(32, args.seq_len // 2), 0.25),
-         (args.seq_len, 0.30),
      ]
      grow = GrowLengthScheduler(stages, args.max_steps)
      cur_seq = stages[0][0]
-
      dataset = GrowLengthDataset(token_buf, cur_seq)
-     optimizer = SparseMeZOOptimizer(
          model, lr=args.lr * 0.01, eps=args.mezo_eps,
          sparsity=args.mezo_sparsity, weight_decay=0.1, momentum=0.9,
-         mask_refresh_interval=max(1, args.max_steps // 10))

      def loss_fn(batch):
-         ids, labels = batch["input_ids"], batch["labels"]
          if args.bf16:
              with torch.autocast("cpu", dtype=torch.bfloat16):
-                 return model(ids, labels=labels).loss
-         return model(ids, labels=labels).loss

-     total_toks = 0
-     total_loss = 0.0
      t0 = time.time()

-     eff_batch = args.batch_size * max(1, args.seq_len // cur_seq)
      loader = DataLoader(dataset, batch_size=eff_batch, shuffle=True,
                          num_workers=0, drop_last=True)
      di = iter(loader)
@@ -563,20 +706,18 @@ def _run_hyper(model, token_buf, args) -> tuple:
          if new_seq != cur_seq:
              cur_seq = new_seq
              dataset.set_seq_len(cur_seq)
-             eff_batch = args.batch_size * max(1, args.seq_len // cur_seq)
              loader = DataLoader(dataset, batch_size=eff_batch,
                                  shuffle=True, num_workers=0, drop_last=True)
              di = iter(loader)

          unfreezer.update(step)
-
          try:
              batch = next(di)
          except StopIteration:
              di = iter(loader)
              batch = next(di)

-         precompute_ternary_cache(model)
          loss_val = optimizer.step(loss_fn, batch)
          total_toks += batch["input_ids"].numel()
          total_loss += loss_val
@@ -586,38 +727,52 @@ def _run_hyper(model, token_buf, args) -> tuple:
      return total_toks / dt, total_loss / args.max_steps, dt


  def _benchmark(args):
-     """Side-by-side comparison."""
      print("=" * 65)
-     print("CHIMERA 5.3 HYPER — BENCHMARK MODE")
      print("=" * 65)

-     model_a, config = _build_model(args)
-     model_b = copy.deepcopy(model_a)
-     counts = model_a.count_parameters()
-     print(f"Model: scale={args.scale} params={counts['total']:,}")
-
-     tok_budget = max(200_000,
-                      args.max_steps * args.batch_size * (args.seq_len + 1) * 4)
      token_buf = _build_token_buffer(
          args.dataset_name, args.dataset_split, args.text_column,
          tok_budget, args.cache_dir)
      print(f"Tokens: {token_buf.numel():,}\n")

-     # ── Baseline ─────────────────────────────────────────────────────
      print("-" * 65)
-     print("BASELINE (standard MeZO, fixed seq_len, all params)")
      print("-" * 65)
-     b_tps, b_loss, b_dt = _run_baseline(model_a, token_buf, args)
      print(f"  → {b_tps:,.0f} tok/s  loss={b_loss:.4f}  time={b_dt:.1f}s\n")

-     # ── Hyper ────────────────────────────────────────────────────────
      print("-" * 65)
-     print("HYPER (7 paradigms stacked)")
      print("-" * 65)
-     h_tps, h_loss, h_dt = _run_hyper(model_b, token_buf, args)
      print(f"  → {h_tps:,.0f} tok/s  loss={h_loss:.4f}  time={h_dt:.1f}s\n")

-     # ── Summary ──────────────────────────────────────────────────────
      speedup = h_tps / b_tps if b_tps > 0 else float("inf")
      print("=" * 65)
      print(f"  Baseline : {b_tps:>12,.0f} tok/s  loss {b_loss:.4f}")
@@ -629,8 +784,9 @@ def _benchmark(args):
          "baseline_tps": round(b_tps), "hyper_tps": round(h_tps),
          "speedup": round(speedup, 2),
          "baseline_loss": round(b_loss, 4), "hyper_loss": round(h_loss, 4),
-         "scale": args.scale, "max_steps": args.max_steps,
-         "paradigms": _active_paradigms(args),
      }
      out = os.path.join(args.output_dir, "benchmark.json")
      os.makedirs(args.output_dir, exist_ok=True)
@@ -639,40 +795,25 @@ def _benchmark(args):
      print(f"Saved → {out}")


- # ═══════════════════════════════════════════════════════════════════════════
- # Helpers
- # ═══════════════════════════════════════════════════════════════════════════
-
- def _active_paradigms(args) -> list:
-     out = []
-     if args.growlength: out.append("P1_GrowLength")
-     if args.reservoir: out.append("P2_ReservoirFreezing")
-     if args.sparse_mezo: out.append("P3_SparseMeZO")
-     if args.pipeline: out.append("P4_BlockwisePipeline")
-     if args.fused_cache: out.append("P5_FusedTernaryCache")
-     if args.pack_tokens: out.append("P6_AggressiveTokenPacking")
-     if args.progressive_unfreeze: out.append("P7_ProgressiveUnfreeze")
-     return out
-
-
  # ═══════════════════════════════════════════════════════════════════════════
  # CLI
  # ═══════════════════════════════════════════════════════════════════════════

- def _cli() -> argparse.ArgumentParser:
      p = argparse.ArgumentParser(
-         description="Chimera 5.3 — HYPER CPU training (7 paradigms)")

-     # Model / data
      p.add_argument("--config", default="config.json")
      p.add_argument("--scale", default="tiny",
                     choices=["tiny", "small", "medium", "full"])
-     p.add_argument("--seq_len", type=int, default=128)
-     p.add_argument("--batch_size", type=int, default=4)
      p.add_argument("--lr", type=float, default=1e-3)
-     p.add_argument("--warmup", type=int, default=200)
      p.add_argument("--max_steps", type=int, default=5000)
      p.add_argument("--max_tokens", type=int, default=None)
      p.add_argument("--bf16", action="store_true", default=True)
      p.add_argument("--no-bf16", dest="bf16", action="store_false")
      p.add_argument("--compile", action="store_true", default=False)
@@ -684,41 +825,31 @@ def _cli() -> argparse.ArgumentParser:
      p.add_argument("--save_every", type=int, default=1000)
      p.add_argument("--output_dir", default="./chimera_hyper_output")

-     # Paradigm toggles
-     g = p.add_argument_group("paradigms (use --all to enable everything)")
-     g.add_argument("--all", action="store_true", default=False,
-                    help="Enable all 7 paradigms")
-     g.add_argument("--growlength", action="store_true", default=False,
-                    help="P1: GrowLength curriculum")
-     g.add_argument("--reservoir", action="store_true", default=False,
-                    help="P2: Reservoir freezing of recurrent gates")
      g.add_argument("--reservoir-ratio", type=float, default=0.5,
                     dest="reservoir_ratio")
      g.add_argument("--sparse-mezo", action="store_true", default=False,
-                    dest="sparse_mezo",
-                    help="P3: Sparse MeZO (top-K%% perturbation)")
-     g.add_argument("--mezo-sparsity", type=float, default=0.01,
                     dest="mezo_sparsity",
-                    help="Fraction of params to perturb (default 0.01 = 1%%)")
      g.add_argument("--mezo-eps", type=float, default=1e-3, dest="mezo_eps")
-     g.add_argument("--pipeline", action="store_true", default=False,
-                    help="P4: Blockwise pipeline")
      g.add_argument("--fused-cache", action="store_true", default=False,
-                    dest="fused_cache",
-                    help="P5: Fused ternary weight cache")
      g.add_argument("--pack-tokens", action="store_true", default=False,
-                    dest="pack_tokens",
-                    help="P6: Aggressive token packing")
      g.add_argument("--progressive-unfreeze", action="store_true",
-                    default=False, dest="progressive_unfreeze",
-                    help="P7: Progressive layer unfreezing")
      g.add_argument("--unfreeze-stages", type=int, default=4,
                     dest="unfreeze_stages")

-     # Benchmark mode
-     p.add_argument("--benchmark", action="store_true", default=False,
-                    help="Run baseline-vs-hyper benchmark")
-
      return p
@@ -726,7 +857,10 @@ if __name__ == "__main__":
      parser = _cli()
      args = parser.parse_args()

-     # --all enables every paradigm
      if args.all:
          args.growlength = True
          args.reservoir = True
@@ -735,16 +869,16 @@ if __name__ == "__main__":
          args.fused_cache = True
          args.pack_tokens = True
          args.progressive_unfreeze = True

      if args.benchmark:
-         # Force all paradigms for the hyper side of the benchmark
          args.growlength = True
          args.reservoir = True
          args.sparse_mezo = True
-         args.pipeline = True
          args.fused_cache = True
          args.pack_tokens = True
          args.progressive_unfreeze = True

          _benchmark(args)
      else:
          _train_hyper(args)

  Chimera 5.3 — HYPER CPU Training Script (10,000+ tok/s target)
  ===============================================================

+ v2: LEAN MODE eliminates the real bottlenecks:
+ • Reduces num_hidden_layers for tiny/small (28 → 6/8)
+ • Disables Parcae looping during training (no 2× forward)
+ • Disables SelfEvolutionEngine (HDC memory, TTT, episodic)
+ • Disables SpanInference, GrammarFST, EntropyValve, DebtLedger
+ • Direct forward: embed → layers → norm → lm_head → loss
+ • MeZO perturbation skips invalidate_packed (uses STE train path)
+ • Adds --lean flag (default ON with --all)
+
+ Paradigms (7 stacked):
+ P1 --growlength            Short→long seq curriculum
+ P2 --reservoir             Freeze recurrent gates as ternary reservoir
+ P3 --sparse-mezo           Perturb only top-K% sensitive params
+ P4 --pipeline              torch.compile fusion
+ P5 --fused-cache           Pre-materialise ternary weights
+ P6 --pack-tokens           Zero-padding token packing
+ P7 --progressive-unfreeze  Train top layers first
+
+ P8 --lean                  ★ NEW: Strip all inference/evolution overhead

  Quick start::

+     python train_hyper.py --scale tiny --max_steps 1000 --all
      python train_hyper.py --scale tiny --max_steps 100 --benchmark
  """
  except RuntimeError:
      pass

  _HAS_IPEX = False
  try:
      import intel_extension_for_pytorch as ipex  # noqa: F401


  # ═══════════════════════════════════════════════════════════════════════════
+ # Scale presets — LEAN: fewer layers, no MoE on tiny
  # ═══════════════════════════════════════════════════════════════════════════

  _SCALE_PRESETS = {
      "tiny": dict(hidden_size=256, intermediate_size=512,
+                  num_heads=4, head_dim=64, num_hidden_layers=6),
      "small": dict(hidden_size=512, intermediate_size=1024,
+                   num_heads=8, head_dim=64, num_hidden_layers=8),
      "medium": dict(hidden_size=1024, intermediate_size=2048,
+                    num_heads=8, head_dim=96, num_hidden_layers=12),
  }

+ # ═══════════════════════════════════════════════════════════════════════════
+ # P8 — Lean mode: strip inference/evolution overhead from model
+ # ═══════════════════════════════════════════════════════════════════════════
+
+ def make_lean(model: nn.Module) -> None:
+     """Disable all non-essential subsystems for maximum training throughput.
+
+     This surgically removes:
+       - SelfEvolutionEngine (HDC semantic memory, TTT, episodic, etc.)
+       - SpanInferenceEngine
+       - GrammarFST
+       - EntropyValve
+       - DebtLedger
+       - Parcae looping (layers run once, not 2×)
+       - Per-layer evo_gate modulation
+     """
+     # Disable looping — run layers 0..N-1 sequentially, once
+     model.looping_enabled = False
+
+     # Disable evolution engine
+     if hasattr(model, 'evolution') and model.evolution is not None:
+         model.evo_weight = 0.0
+         model.evo_every_n_layers = 999999  # never triggers
+
+     # Disable span inference
+     model.span_engine = None
+
+     # Make grammar/entropy/debt into identity ops
+     if hasattr(model, 'grammar'):
+         model.grammar = _IdentityModule()
+     if hasattr(model, 'entropy_valve'):
+         model.entropy_valve = _IdentityModule()
+     if hasattr(model, 'debt_ledger'):
+         model.debt_ledger = _IdentityModule()
+
+     # Disable evo_gate on each block (skip the sigmoid + multiply)
+     for layer in model.layers:
+         if hasattr(layer, 'evo_gate'):
+             # Zero out so the gate branch is a no-op even if called
+             with torch.no_grad():
+                 layer.evo_gate.weight.zero_()
+             layer.evo_gate.weight.requires_grad = False
+
+     # Count what's left
+     active = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     total = sum(p.numel() for p in model.parameters())
+     print("[P8] Lean: disabled looping/evolution/span/grammar/entropy/debt")
+     print(f"[P8] Active params: {active:,} / {total:,} total")
+
+
+ class _IdentityModule(nn.Module):
+     """Pass-through module that replaces Grammar/Entropy/Debt during training."""
+     def forward(self, x, *args, **kwargs):
+         return x
+
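
The swap-for-identity trick in make_lean is plain PyTorch attribute assignment; a self-contained demo of the same pattern on a toy module:

    import torch
    import torch.nn as nn

    class Identity(nn.Module):
        def forward(self, x, *args, **kwargs):  # absorbs any extra call args
            return x

    net = nn.ModuleDict({"grammar": nn.Linear(8, 8)})
    x = torch.randn(2, 8)
    net["grammar"] = Identity()                 # same attribute, zero compute
    assert torch.equal(net["grammar"](x, "extra", flag=True), x)
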
+ # ═══════════════════════════════════════════════════════════════════════════
+ # Fast MeZO — skips invalidate_packed, uses train mode (STE path)
+ # ═══════════════════════════════════════════════════════════════════════════
+
+ class FastSparseMeZO:
+     """Ultra-fast Sparse MeZO that exploits the STE training path.
+
+     Key insight: during training, BitLinear uses `_forward_train`, which
+     re-quantises from the latent FP32 weight on every call — so we DON'T
+     need to invalidate packed caches at all. We just perturb the latent
+     .weight directly and let STE handle it.
+
+     Also: uses Rademacher directions (±1 only, no randn) for faster
+     perturbation generation.
+     """
+
+     def __init__(self, model: nn.Module, *,
+                  lr: float = 1e-4, eps: float = 1e-3,
+                  sparsity: float = 0.05,
+                  weight_decay: float = 0.0,
+                  momentum: float = 0.9,
+                  mask_refresh_interval: int = 100):
+         self.model = model
+         self.lr = float(lr)
+         self.eps = float(eps)
+         self.wd = float(weight_decay)
+         self.momentum_coeff = float(momentum)
+         self.mask_refresh = int(mask_refresh_interval)
+
+         # Collect trainable params (deduplicated)
+         self._params = []
+         seen = set()
+         for name, p in model.named_parameters():
+             if p.requires_grad and id(p) not in seen:
+                 self._params.append((name, p))
+                 seen.add(id(p))
+
+         self._total = sum(p.numel() for _, p in self._params)
+         self._k = max(1, int(self._total * sparsity))
+
+         # Pre-allocate masks and momentum buffers
+         self._masks = {}
+         self._momentum_bufs = {}
+         for _, p in self._params:
+             self._masks[id(p)] = torch.ones(p.shape, dtype=torch.bool)
+             if self.momentum_coeff > 0:
+                 self._momentum_bufs[id(p)] = torch.zeros_like(p.data)
+
+         self._step = 0
+         self._refresh_masks()
+
+     def _refresh_masks(self):
+         """Compute sparse masks — top-K by magnitude."""
+         all_mag = torch.cat([p.data.abs().flatten() for _, p in self._params])
+         if self._k < all_mag.numel():
+             # k-th largest magnitude = (numel - k + 1)-th smallest
+             thr = torch.kthvalue(all_mag, all_mag.numel() - self._k + 1).values
+         else:
+             thr = torch.tensor(0.0)
+
+         offset = 0
+         for _, p in self._params:
+             n = p.numel()
+             self._masks[id(p)] = (all_mag[offset:offset + n].view(p.shape) >= thr)
+             offset += n
+
+     def _perturb_all(self, seed: int, scale: float):
+         """Perturb all masked params with Rademacher ±1 directions."""
+         gen = torch.Generator(device="cpu")
+         for i, (_, p) in enumerate(self._params):
+             gen.manual_seed((seed + i * 1_000_003) & 0x7FFFFFFFFFFFFFFF)
+             z = torch.empty(p.shape, dtype=p.dtype)
+             z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
+             mask = self._masks[id(p)]
+             # In-place add only masked positions
+             p.data.add_(z * mask, alpha=scale)
+
+     @torch.no_grad()
+     def step(self, loss_fn, batch) -> float:
+         self._step += 1
+         if self._step % self.mask_refresh == 0:
+             self._refresh_masks()
+
+         seed = int(torch.randint(0, 2**31, (1,)).item())
+
+         # +ε perturbation
+         self._perturb_all(seed, +self.eps)
+         loss_pos = float(loss_fn(batch).item())
+
+         # −2ε (net: −ε from original)
+         self._perturb_all(seed, -2.0 * self.eps)
+         loss_neg = float(loss_fn(batch).item())
+
+         # Restore (+ε back to original)
+         self._perturb_all(seed, +self.eps)
+
+         proj = (loss_pos - loss_neg) / (2.0 * self.eps)
+
+         # Update with momentum
+         gen = torch.Generator(device="cpu")
+         for i, (_, p) in enumerate(self._params):
+             gen.manual_seed((seed + i * 1_000_003) & 0x7FFFFFFFFFFFFFFF)
+             z = torch.empty(p.shape, dtype=p.dtype)
+             z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
+             mask = self._masks[id(p)]
+             z_masked = z * mask
+
+             if self.momentum_coeff > 0:
+                 buf = self._momentum_bufs[id(p)]
+                 buf.mul_(self.momentum_coeff).add_(z_masked, alpha=proj)
+                 p.data.add_(buf, alpha=-self.lr)
+             else:
+                 p.data.add_(z_masked, alpha=-self.lr * proj)
+
+             if self.wd > 0:
+                 p.data.mul_(1 - self.lr * self.wd)
+
+         return 0.5 * (loss_pos + loss_neg)
+
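
The step() above is the standard two-point zeroth-order (SPSA-style) estimate, proj = (L(θ+εz) − L(θ−εz)) / 2ε along a shared Rademacher direction z. A toy check on a quadratic, independent of anything in this file:

    import torch

    torch.manual_seed(0)
    theta = torch.randn(10)
    loss = lambda t: (t ** 2).sum()                  # true gradient: 2 * theta
    eps = 1e-3
    z = torch.randint(0, 2, (10,)).float() * 2 - 1   # Rademacher ±1 direction
    proj = (loss(theta + eps * z) - loss(theta - eps * z)) / (2 * eps)
    g_hat = proj * z                                 # rank-1 gradient estimate
    # For a quadratic, proj = 2·θᵀz exactly, so g_hat·(2θ) = (2·θᵀz)² ≥ 0:
    # a single probe is always non-negatively aligned with the true gradient.
    print(torch.dot(g_hat, 2 * theta).item())
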

  # ═══════════════════════════════════════════════════════════════════════════
  # Data helpers
  # ═══════════════════════════════════════════════════════════════════════════

+ def _build_token_buffer(dataset_name, split, text_column,
+                         max_tokens, cache_dir):
      cache_path = os.path.join(
          cache_dir,
+         f"{dataset_name.replace('/', '_')}_{split}_{max_tokens}.pt")
      os.makedirs(cache_dir, exist_ok=True)

      if os.path.exists(cache_path):

      for ex in ds:
          text = ""
          if text_column == "auto":
+             for cand in ("text", "content", "messages"):
                  if cand in ex:
                      val = ex[cand]
                      text = val if isinstance(val, str) else str(val)

          if n > room:
              ids = ids[:room]
              n = room
+         buf[idx:idx+n] = torch.tensor(ids, dtype=torch.long)
          idx += n
          processed += 1
+         if processed % 5000 == 0:
              print(f"  {processed:,} docs  {idx:,}/{max_tokens} tokens")

      buf = buf[:idx].contiguous()

  # ═══════════════════════════════════════════════════════════════════════════
+ # Model builder — LEAN config
  # ═══════════════════════════════════════════════════════════════════════════

+ def _build_model(args):
      with open(args.config) as f:
          config = json.load(f)

      if args.scale in _SCALE_PRESETS:
          config.update(_SCALE_PRESETS[args.scale])

+     n_layers = config["num_hidden_layers"]
      config["vocab_size"] = config.get("vocab_size", 200_073)

+     config.setdefault("gated_deltanet", {})["chunk_size"] = min(args.seq_len, 64)
+     hd = config.get("head_dim", 64)
+     config.setdefault("xlstm", {})["memory_size_per_head"] = [hd, hd]
      config.setdefault("titans", {}).update({
          "memory_depth": 2, "persistent_memory_slots": 16,
          "local_window_size": min(args.seq_len, 256),
      })

+     # MoE: only on layers that exist, reduced experts for tiny
      moe = config.setdefault("backbone", {}).setdefault("moe", {})
+     if args.lean and args.scale == "tiny":
+         # No MoE for tiny in lean mode — too expensive
+         moe["layers"] = []
+         moe["n_routed_experts"] = 0
+     else:
+         valid_moe = [i for i in [3, 7, 11, 15, 19, 23, 27] if i < n_layers]
+         moe.setdefault("layers", valid_moe)
+         moe.setdefault("n_routed_experts", 4 if args.scale == "tiny" else 8)
      moe.setdefault("moe_intermediate_size", config["intermediate_size"] // 4)
      moe.setdefault("n_shared_experts", 1)
      moe.setdefault("num_experts_per_tok", 2)

+     # Looping: disable for lean, or adjust for reduced layers
+     loop = config.setdefault("looping", {})
+     if args.lean or n_layers < 8:
+         loop["enabled"] = False
+     else:
+         loop.update({
+             "enabled": True,
+             "prelude": [0, min(1, n_layers - 1)],
+             "loop": [2, max(2, n_layers - 3)],
+             "coda": [max(0, n_layers - 2), n_layers - 1],
+             "loop_range": [1, 2], "loop_default": 1,
+         })
+
+     config.setdefault("span_inference", {})["enabled"] = not args.lean
+     config.setdefault("grammar", {})["enabled"] = not args.lean
+     config.setdefault("entropy_valve", {})["enabled"] = not args.lean
+     config.setdefault("debt_ledger", {})["enabled"] = not args.lean
      config.setdefault("multimodal", {})["enabled"] = False

      model = Chimera51ForCausalLM(config)
 
394
 
395
  # ═══════════════════════════════════════════════════════════════════════════
396
+ # HYPER training loop
397
  # ═══════════════════════════════════════════════════════════════════════════
398
 
399
+ def _train_hyper(args):
400
  model, config = _build_model(args)
401
  counts = model.count_parameters()
 
 
402
 
403
  print("=" * 65)
404
+ print(f"CHIMERA 5.3 HYPER TRAIN — scale={args.scale} lean={args.lean}")
 
405
  print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} "
406
  f"vocab={config['vocab_size']} target_seq={args.seq_len}")
407
  print(f"Threads: {torch.get_num_threads()} IPEX={_HAS_IPEX}")
408
  print(f"Paradigms: P1={args.growlength} P2={args.reservoir} "
409
+ f"P3={args.sparse_mezo} P5={args.fused_cache} "
410
+ f"P7={args.progressive_unfreeze} P8={args.lean}")
 
411
  print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}")
412
  print("=" * 65)
413
 
414
+ # ── P8: Lean mode ────────────────────────────────────────────────
415
+ if args.lean:
416
+ make_lean(model)
417
+
418
  # ── P2: Reservoir Freezing ───────────────────────────────────────
419
  if args.reservoir:
420
+ frozen = apply_reservoir_freezing(model, args.reservoir_ratio)
421
+ print(f"[P2] Reservoir: froze {frozen:,} gate params")
422
+
423
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
424
+ print(f"[INFO] Trainable params: {trainable:,}")
 
 
425
 
426
  # ── P7: Progressive Unfreezing ───────────────────────────────────
427
  unfreezer = None
428
  if args.progressive_unfreeze:
429
  unfreezer = ProgressiveUnfreezer(
430
  model, args.max_steps, n_stages=args.unfreeze_stages)
431
+ active = sum(p.numel() for p in model.parameters() if p.requires_grad)
432
+ print(f"[P7] Progressive unfreeze: {active:,} initially trainable")
 
 
433
 
434
+ # ── P1: GrowLength ───────────────────────────────────────────────
435
  if args.growlength:
436
  stages = [
437
+ (max(8, args.seq_len // 4), 0.30),
438
+ (max(16, args.seq_len // 2), 0.30),
439
+ (args.seq_len, 0.40),
 
440
  ]
441
  grow = GrowLengthScheduler(stages, args.max_steps)
442
  initial_seq = stages[0][0]
443
+ print(f"[P1] GrowLength: {' → '.join(str(s) for s, _ in stages)}")
 
444
  else:
445
  grow = None
446
  initial_seq = args.seq_len
447
 
448
  # ── Data ─────────────────────────────────────────────────────────
449
+ tok_budget = args.max_tokens or max(200_000,
450
+ args.max_steps * args.batch_size * (args.seq_len + 1) * 4)
 
 
451
  token_buf = _build_token_buffer(
452
  args.dataset_name, args.dataset_split, args.text_column,
453
  tok_budget, args.cache_dir)
 
 
454
  if args.pack_tokens:
455
+ token_buf = pack_documents(token_buf, 199_999, token_buf.numel())
 
 
 
456
  dataset = GrowLengthDataset(token_buf, initial_seq)
457
+ print(f"[DATA] {token_buf.numel():,} tokens seq={initial_seq} "
458
  f"chunks={len(dataset):,}")
459
 
460
+ # ── torch.compile ────────────────────────────────────────────────
461
  if args.compile:
462
+ print("[OPT] torch.compile (inductor) …")
463
  model = torch.compile(model, backend="inductor", mode="default",
464
  dynamic=True)
465
 
466
+ # ── P3: Fast Sparse MeZO ────────────────────────────────────────
467
+ optimizer = FastSparseMeZO(
468
+ model,
469
+ lr=args.lr * 0.01,
470
+ eps=args.mezo_eps,
471
+ sparsity=args.mezo_sparsity,
472
+ weight_decay=0.1,
473
+ momentum=0.9,
474
+ mask_refresh_interval=max(10, args.max_steps // 5),
475
+ )
476
+ print(f"[P3] FastSparseMeZO: top {args.mezo_sparsity*100:.0f}% "
477
+ f"({optimizer._k:,}/{optimizer._total:,} params)")
 
 
 
 
 
 
 
 
 
478
 
479
  # ── Loss function ────────────────────────────────────────────────
480
  use_bf16 = bool(args.bf16)
481
+ def compute_loss(batch):
482
+ ids, labels = batch["input_ids"], batch["labels"]
 
 
483
  if use_bf16:
484
+ with torch.autocast("cpu", dtype=torch.bfloat16):
485
  return model(ids, labels=labels).loss
486
  return model(ids, labels=labels).loss
487
 
 
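
GrowLengthScheduler itself is imported from the repo's helpers and not shown in this diff; a hypothetical stand-in consistent with how it is called here, i.e. (seq_len, fraction) stages and get_seq_len(step), could look like:

    class GrowLengthSchedulerSketch:
        # Hypothetical stand-in, not the repo's implementation: each stage
        # owns a fraction of max_steps; return the seq_len whose cumulative
        # step boundary covers the current step.
        def __init__(self, stages, max_steps):
            self.bounds, acc = [], 0.0
            for seq, frac in stages:
                acc += frac
                self.bounds.append((int(acc * max_steps), seq))

        def get_seq_len(self, step):
            for bound, seq in self.bounds:
                if step < bound:
                    return seq
            return self.bounds[-1][1]

    sched = GrowLengthSchedulerSketch([(16, 0.30), (32, 0.30), (64, 0.40)], 1000)
    assert sched.get_seq_len(0) == 16 and sched.get_seq_len(999) == 64
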
      cur_seq = initial_seq
      warmup = min(args.warmup, max(1, args.max_steps // 10))

+     eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
      loader = DataLoader(dataset, batch_size=eff_batch, shuffle=True,
                          num_workers=0, drop_last=True)
      data_iter = iter(loader)

            f"(eff_batch={eff_batch}, seq={cur_seq})\n{'=' * 65}\n")

      while step < args.max_steps:
+         # P1: GrowLength
          if grow is not None:
              new_seq = grow.get_seq_len(step)
              if new_seq != cur_seq:
                  cur_seq = new_seq
                  dataset.set_seq_len(cur_seq)
+                 eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
                  loader = DataLoader(dataset, batch_size=eff_batch,
+                                     shuffle=True, num_workers=0, drop_last=True)
                  data_iter = iter(loader)
+                 print(f"  [P1] seq → {cur_seq}  batch → {eff_batch}")

+         # P7: Progressive unfreeze
          if unfreezer is not None:
              unfreezer.update(step)

+         # Get batch
          try:
              batch = next(data_iter)
          except StopIteration:
              data_iter = iter(loader)
              batch = next(data_iter)

+         # P5: Fused ternary cache (only useful if NOT in train mode)
+         # In lean+train mode, BitLinear uses the STE path → no need to cache
+         # But still useful for non-BitLinear frozen layers
+         if args.fused_cache and not model.training:
              precompute_ternary_cache(model)

+         # LR schedule
          cur_lr = cosine_lr(step, warmup, args.max_steps,
                             args.lr * 0.01, args.lr * 0.001)
+         optimizer.lr = cur_lr

+         # Optimizer step
          loss_val = optimizer.step(compute_loss, batch)
          total_loss += loss_val
          toks += batch["input_ids"].numel()
          step += 1

+         # Logging
          if step % args.log_every == 0:
              dt = time.time() - t0
              avg = total_loss / args.log_every
              ppl = math.exp(min(avg, 20))
              tps = toks / dt if dt > 0 else 0
              eta_h = ((args.max_steps - step) / (step / dt) / 3600
+                      if dt > 0 else 0)
+             entry = {"step": step, "loss": round(avg, 4),
+                      "ppl": round(ppl, 2), "lr": cur_lr,
+                      "tok/s": round(tps), "seq_len": cur_seq,
+                      "eff_batch": eff_batch}
              log_f.write(json.dumps(entry) + "\n")
              log_f.flush()
              print(f"  step {step:>6}/{args.max_steps} | loss {avg:.4f} | "
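
cosine_lr is likewise imported rather than defined in this diff; judging only from the call site cosine_lr(step, warmup, max_steps, lr_max, lr_min), a conventional warmup-then-cosine schedule would be:

    import math

    def cosine_lr_sketch(step, warmup, max_steps, lr_max, lr_min):
        # Hypothetical stand-in matching the call site above
        if step < warmup:
            return lr_max * (step + 1) / warmup            # linear warmup
        t = (step - warmup) / max(1, max_steps - warmup)   # progress 0 → 1
        return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t))

    assert abs(cosine_lr_sketch(5000, 100, 5000, 1e-5, 1e-6) - 1e-6) < 1e-12
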
 
          toks = 0
          t0 = time.time()

          if step % args.save_every == 0:
              ckpt_dir = os.path.join(args.output_dir, f"ckpt-{step}")
              os.makedirs(ckpt_dir, exist_ok=True)
              raw = getattr(model, "_orig_mod", model)
+             torch.save({"model": raw.state_dict(), "config": config,
+                         "step": step}, os.path.join(ckpt_dir, "ckpt.pt"))
              print(f"  [SAVE] {ckpt_dir}")

+     # Final save
      final_dir = os.path.join(args.output_dir, "final")
      os.makedirs(final_dir, exist_ok=True)
      raw = getattr(model, "_orig_mod", model)
+     torch.save({"model": raw.state_dict(), "config": config,
+                 "step": step, "best_loss": best_loss},
+                os.path.join(final_dir, "model.pt"))
      with open(os.path.join(final_dir, "config.json"), "w") as fh:
          json.dump(config, fh, indent=2)
      log_f.close()

      print(f"\n{'=' * 65}")
      print(f"DONE — best loss {best_loss:.4f} "
            f"ppl {math.exp(min(best_loss, 20)):.2f}")
      print(f"Saved to {final_dir}")
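
A checkpoint written by the save calls above can be reloaded symmetrically; a sketch (the path is illustrative; Chimera51ForCausalLM is the repo's own class):

    import torch

    ckpt = torch.load("chimera_hyper_output/final/model.pt", map_location="cpu")
    model = Chimera51ForCausalLM(ckpt["config"])   # rebuild from the saved config
    model.load_state_dict(ckpt["model"])
    print(ckpt["step"], ckpt["best_loss"])
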

  # ═══════════════════════════════════════════════════════════════════════════
+ # Benchmark
  # ═══════════════════════════════════════════════════════════════════════════

+ def _run_baseline(model, token_buf, args):
+     """Standard full MeZO on full 28-layer model."""
      model.train()
      seq = args.seq_len
      n = token_buf.numel() // (seq + 1)

      loader = DataLoader(_DS(), batch_size=args.batch_size,
                          shuffle=True, num_workers=0, drop_last=True)
      params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
      eps = 1e-3

      def loss_fn(batch):
          return model(batch["input_ids"], labels=batch["labels"]).loss

+     total_toks, total_loss = 0, 0.0
      t0 = time.time()
      di = iter(loader)

      return total_toks / dt, total_loss / args.max_steps, dt


+ def _run_hyper_bench(model, token_buf, args):
+     """Hyper pipeline with lean + all paradigms."""
      model.train()
+     make_lean(model)
+     apply_reservoir_freezing(model, args.reservoir_ratio)
      unfreezer = ProgressiveUnfreezer(model, args.max_steps,
                                       n_stages=args.unfreeze_stages)
      stages = [
+         (max(8, args.seq_len // 4), 0.30),
+         (max(16, args.seq_len // 2), 0.30),
+         (args.seq_len, 0.40),
      ]
      grow = GrowLengthScheduler(stages, args.max_steps)
      cur_seq = stages[0][0]
      dataset = GrowLengthDataset(token_buf, cur_seq)
+
+     optimizer = FastSparseMeZO(
          model, lr=args.lr * 0.01, eps=args.mezo_eps,
          sparsity=args.mezo_sparsity, weight_decay=0.1, momentum=0.9,
+         mask_refresh_interval=max(10, args.max_steps // 5))

      def loss_fn(batch):
          if args.bf16:
              with torch.autocast("cpu", dtype=torch.bfloat16):
+                 return model(batch["input_ids"], labels=batch["labels"]).loss
+         return model(batch["input_ids"], labels=batch["labels"]).loss

+     total_toks, total_loss = 0, 0.0
      t0 = time.time()

+     eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
      loader = DataLoader(dataset, batch_size=eff_batch, shuffle=True,
                          num_workers=0, drop_last=True)
      di = iter(loader)
          if new_seq != cur_seq:
              cur_seq = new_seq
              dataset.set_seq_len(cur_seq)
+             eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
              loader = DataLoader(dataset, batch_size=eff_batch,
                                  shuffle=True, num_workers=0, drop_last=True)
              di = iter(loader)

          unfreezer.update(step)
          try:
              batch = next(di)
          except StopIteration:
              di = iter(loader)
              batch = next(di)

          loss_val = optimizer.step(loss_fn, batch)
          total_toks += batch["input_ids"].numel()
          total_loss += loss_val

  def _benchmark(args):
      print("=" * 65)
+     print("CHIMERA 5.3 HYPER v2 — BENCHMARK")
      print("=" * 65)

+     # Baseline: full 28-layer model (as per original train.py)
+     args_base = copy.copy(args)
+     args_base.lean = False
+     # Override to build with 28 layers like original
+     orig_presets = {
+         "tiny": dict(hidden_size=256, intermediate_size=512,
+                      num_heads=4, head_dim=48, num_hidden_layers=28),
+     }
+     _SCALE_PRESETS_BAK = dict(_SCALE_PRESETS)
+     _SCALE_PRESETS.update(orig_presets)
+     model_base, cfg_base = _build_model(args_base)
+     _SCALE_PRESETS.update(_SCALE_PRESETS_BAK)
+
+     # Hyper: lean 6-layer model
+     args_hyper = copy.copy(args)
+     args_hyper.lean = True
+     model_hyper, cfg_hyper = _build_model(args_hyper)
+
+     c1 = model_base.count_parameters()
+     c2 = model_hyper.count_parameters()
+     print(f"Baseline: {c1['total']:,} params, {cfg_base['num_hidden_layers']} layers")
+     print(f"Hyper: {c2['total']:,} params, {cfg_hyper['num_hidden_layers']} layers (lean)")
+
+     tok_budget = max(500_000,
+                      args.max_steps * args.batch_size * (args.seq_len + 1) * 8)
      token_buf = _build_token_buffer(
          args.dataset_name, args.dataset_split, args.text_column,
          tok_budget, args.cache_dir)
      print(f"Tokens: {token_buf.numel():,}\n")

      print("-" * 65)
+     print("BASELINE (28 layers, full MeZO, all subsystems)")
      print("-" * 65)
+     b_tps, b_loss, b_dt = _run_baseline(model_base, token_buf, args)
      print(f"  → {b_tps:,.0f} tok/s  loss={b_loss:.4f}  time={b_dt:.1f}s\n")

      print("-" * 65)
+     print("HYPER (6 layers lean, Sparse MeZO, GrowLength, Reservoir, Unfreeze)")
      print("-" * 65)
+     h_tps, h_loss, h_dt = _run_hyper_bench(model_hyper, token_buf, args)
      print(f"  → {h_tps:,.0f} tok/s  loss={h_loss:.4f}  time={h_dt:.1f}s\n")

      speedup = h_tps / b_tps if b_tps > 0 else float("inf")
      print("=" * 65)
      print(f"  Baseline : {b_tps:>12,.0f} tok/s  loss {b_loss:.4f}")

          "baseline_tps": round(b_tps), "hyper_tps": round(h_tps),
          "speedup": round(speedup, 2),
          "baseline_loss": round(b_loss, 4), "hyper_loss": round(h_loss, 4),
+         "baseline_params": c1["total"], "hyper_params": c2["total"],
+         "baseline_layers": cfg_base["num_hidden_layers"],
+         "hyper_layers": cfg_hyper["num_hidden_layers"],
      }
      out = os.path.join(args.output_dir, "benchmark.json")
      os.makedirs(args.output_dir, exist_ok=True)

      print(f"Saved → {out}")

  # ═══════════════════════════════════════════════════════════════════════════
  # CLI
  # ═══════════════════════════════════════════════════════════════════════════

+ def _cli():
      p = argparse.ArgumentParser(
+         description="Chimera 5.3 — HYPER CPU training (8 paradigms)")

      p.add_argument("--config", default="config.json")
      p.add_argument("--scale", default="tiny",
                     choices=["tiny", "small", "medium", "full"])
+     p.add_argument("--seq_len", type=int, default=64)
+     p.add_argument("--batch_size", type=int, default=8)
      p.add_argument("--lr", type=float, default=1e-3)
+     p.add_argument("--warmup", type=int, default=100)
      p.add_argument("--max_steps", type=int, default=5000)
      p.add_argument("--max_tokens", type=int, default=None)
+     p.add_argument("--max_samples", type=int, default=None,
+                    help="Max samples (converted to max_tokens internally)")
      p.add_argument("--bf16", action="store_true", default=True)
      p.add_argument("--no-bf16", dest="bf16", action="store_false")
      p.add_argument("--compile", action="store_true", default=False)

      p.add_argument("--save_every", type=int, default=1000)
      p.add_argument("--output_dir", default="./chimera_hyper_output")

+     g = p.add_argument_group("paradigms (--all enables everything)")
+     g.add_argument("--all", action="store_true", default=False)
+     g.add_argument("--lean", action="store_true", default=False,
+                    help="P8: Strip inference/evolution overhead")
+     g.add_argument("--growlength", action="store_true", default=False)
+     g.add_argument("--reservoir", action="store_true", default=False)
      g.add_argument("--reservoir-ratio", type=float, default=0.5,
                     dest="reservoir_ratio")
      g.add_argument("--sparse-mezo", action="store_true", default=False,
+                    dest="sparse_mezo")
+     g.add_argument("--mezo-sparsity", type=float, default=0.05,
                     dest="mezo_sparsity",
+                    help="Fraction of params to perturb (default 0.05 = 5%%)")
      g.add_argument("--mezo-eps", type=float, default=1e-3, dest="mezo_eps")
+     g.add_argument("--pipeline", action="store_true", default=False)
      g.add_argument("--fused-cache", action="store_true", default=False,
+                    dest="fused_cache")
      g.add_argument("--pack-tokens", action="store_true", default=False,
+                    dest="pack_tokens")
      g.add_argument("--progressive-unfreeze", action="store_true",
+                    default=False, dest="progressive_unfreeze")
      g.add_argument("--unfreeze-stages", type=int, default=4,
                     dest="unfreeze_stages")

+     p.add_argument("--benchmark", action="store_true", default=False)
      return p
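
The hyphenated flags rely on argparse's automatic hyphen-to-underscore dest conversion (the explicit dest= arguments are belt-and-braces); a quick round-trip to confirm:

    args = _cli().parse_args(["--sparse-mezo", "--mezo-sparsity", "0.1"])
    print(args.sparse_mezo, args.mezo_sparsity)    # True 0.1
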

      parser = _cli()
      args = parser.parse_args()

+     # --max_samples → --max_tokens conversion
+     if args.max_samples and not args.max_tokens:
+         args.max_tokens = args.max_samples * (args.seq_len + 1)
+
      if args.all:
          args.growlength = True
          args.reservoir = True

          args.fused_cache = True
          args.pack_tokens = True
          args.progressive_unfreeze = True
+         args.lean = True  # ← critical: --all now includes lean

      if args.benchmark:
          args.growlength = True
          args.reservoir = True
          args.sparse_mezo = True

          args.fused_cache = True
          args.pack_tokens = True
          args.progressive_unfreeze = True
+         args.lean = True
          _benchmark(args)
      else:
          _train_hyper(args)