Lgr54HFi committed on
Commit 20ad65d · verified · 1 Parent(s): 11c11f8

fix: turbo v2 — disable compile (84 graph breaks), fix grad_accum, add diagnostics

Files changed (1)
  1. chimera_turbo.py +161 -282
chimera_turbo.py CHANGED
@@ -1,25 +1,33 @@
1
  """
2
  chimera_turbo.py — Drop-in CPU acceleration for ch1mera 5.3
3
- Usage: import chimera_turbo; chimera_turbo.apply(model, optimizer, args)
4
 
5
  Integrated paradigms:
6
  P-TURBO-1: STE + AdamW (replaces MeZO → fixes convergence + 50x fewer forwards)
7
- P-TURBO-2: torch.compile regional (2-3x kernel fusion)
8
  P-TURBO-3: Threading optimal + tcmalloc detection
9
  P-TURBO-4: IPEX bf16/AMX if available
10
- P-TURBO-5: Cache quantized weights across micro-batches
11
  P-TURBO-6: INT8 ternary forward path (VNNI/AMX dispatch)
12
  P-TURBO-7: Arrow mmap dataset
13
  """
14
 
 
15
  import os
16
  import sys
17
  import warnings
18
  import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
- from typing import Optional, Dict, Any, Tuple
22
- from functools import wraps
23
  from contextlib import nullcontext
24
 
25
  # ═══════════════════════════════════════════════════════════
@@ -29,11 +37,10 @@ from contextlib import nullcontext
29
  def detect_cpu_info() -> Dict[str, Any]:
30
  """Detect CPU capabilities for optimal configuration."""
31
  info = {}
32
-
33
  # Physical cores (not hyperthreads)
34
  try:
35
  physical = len(os.sched_getaffinity(0))
36
- # Heuristic: if thread count is even, likely HT enabled → halve
37
  import multiprocessing
38
  logical = multiprocessing.cpu_count()
39
  info["physical_cores"] = logical // 2 if logical == physical else physical
@@ -42,18 +49,19 @@ def detect_cpu_info() -> Dict[str, Any]:
42
  import multiprocessing
43
  info["logical_cores"] = multiprocessing.cpu_count()
44
  info["physical_cores"] = info["logical_cores"] // 2
45
-
46
  # CPU capability
47
  try:
48
  info["capability"] = torch.backends.cpu.get_cpu_capability()
49
  except Exception:
50
  info["capability"] = "unknown"
51
-
52
- # AMX support (Sapphire Rapids+)
53
- info["has_amx"] = "amx" in info["capability"].lower() if info["capability"] else False
54
- info["has_avx512"] = "avx512" in info["capability"].lower() if info["capability"] else False
55
- info["has_vnni"] = info["has_avx512"] # VNNI comes with AVX-512 Ice Lake+
56
-
 
57
  # IPEX available?
58
  try:
59
  import intel_extension_for_pytorch
@@ -61,23 +69,24 @@ def detect_cpu_info() -> Dict[str, Any]:
61
  info["ipex_version"] = intel_extension_for_pytorch.__version__
62
  except ImportError:
63
  info["ipex_available"] = False
64
-
65
  # tcmalloc loaded?
66
  info["tcmalloc"] = "tcmalloc" in os.environ.get("LD_PRELOAD", "")
67
-
68
  return info
69
 
70
 
71
  def configure_threading(cpu_info: Dict[str, Any], reserve_for_io: int = 1):
72
  """Set optimal threading for CPU training."""
73
  n_compute = max(1, cpu_info["physical_cores"] - reserve_for_io)
74
-
 
 
75
  torch.set_num_threads(n_compute)
76
- torch.set_num_interop_threads(min(4, reserve_for_io + 1))
77
-
78
  os.environ["OMP_NUM_THREADS"] = str(n_compute)
79
  os.environ["MKL_NUM_THREADS"] = str(n_compute)
80
-
81
  return n_compute
82
 
83
 
@@ -94,19 +103,15 @@ def create_optimizer(
94
  ) -> torch.optim.Optimizer:
95
  """
96
  Create optimizer for STE-based ternary training (replaces MeZO).
97
-
98
  Based on BitNet b1.58 Reloaded (2407.09527):
99
  - lr=1e-3 for <300M params (NOT 1e-2, that's for 3B+)
100
  - weight_decay=0.05
101
  - AdamW with β=(0.9, 0.95)
102
-
103
- The STE is already in BitLinear — just use a normal optimizer.
104
- MeZO needed 528 forward passes per step; this needs 1 forward + 1 backward.
105
  """
106
- # Separate weight decay groups (no WD on bias, layernorm, embeddings)
107
  decay_params = []
108
  no_decay_params = []
109
-
110
  for name, param in model.named_parameters():
111
  if not param.requires_grad:
112
  continue
@@ -114,172 +119,50 @@ def create_optimizer(
114
  no_decay_params.append(param)
115
  else:
116
  decay_params.append(param)
117
-
118
  param_groups = [
119
  {"params": decay_params, "weight_decay": weight_decay},
120
  {"params": no_decay_params, "weight_decay": 0.0},
121
  ]
122
-
123
  if use_lion:
124
  try:
125
  from lion_pytorch import Lion
126
  return Lion(param_groups, lr=lr * 0.3, betas=(0.95, 0.98))
127
  except ImportError:
128
  warnings.warn("lion-pytorch not installed, falling back to AdamW")
129
-
130
  return torch.optim.AdamW(param_groups, lr=lr, betas=betas, fused=False)
131
 
132
 
133
  def create_scheduler(optimizer, max_steps: int, warmup_steps: int = 500):
134
  """Cosine schedule with linear warmup — standard BitNet recipe."""
135
  from torch.optim.lr_scheduler import LambdaLR
136
- import math
137
-
138
  def lr_lambda(step):
139
  if step < warmup_steps:
140
  return step / max(1, warmup_steps)
141
  progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
142
  return max(0.01, 0.5 * (1.0 + math.cos(math.pi * progress)))
143
-
144
  return LambdaLR(optimizer, lr_lambda)
145
 
146
 
147
  # ═══════════════════════════════════════════════════════════
148
- # P-TURBO-5 : Quantized Weight Cache
149
  # ═══════════════════════════════════════════════════════════
150
 
151
- class QuantCacheMixin:
152
- """
153
- Mixin for BitLinear to cache quantized weights during gradient accumulation.
154
-
155
- Without cache: quantize weights on every micro-batch forward pass
156
- With cache: quantize once, reuse across accumulation steps
157
- Invalidate after optimizer.step()
158
- """
159
- _quant_cache: Optional[torch.Tensor] = None
160
- _cache_valid: bool = False
161
-
162
- def get_quantized_weight(self):
163
- """Override in your BitLinear. Returns quantized weight + scale."""
164
- raise NotImplementedError
165
-
166
- def cached_quantized_weight(self):
167
- if not self._cache_valid or self._quant_cache is None:
168
- self._quant_cache = self.get_quantized_weight()
169
- self._cache_valid = True
170
- return self._quant_cache
171
-
172
- def invalidate_cache(self):
173
- self._cache_valid = False
174
- self._quant_cache = None
175
-
176
-
177
  def invalidate_all_caches(model: nn.Module):
178
- """Call after optimizer.step() to force re-quantization."""
179
- for m in model.modules():
180
- if hasattr(m, "invalidate_cache"):
181
- m.invalidate_cache()
182
-
183
 
184
- # ═══════════════════════════════════════════════════════════
185
- # P-TURBO-6 : INT8 Ternary Forward Path
186
- # ═══════════════════════════════════════════════════════════
187
-
188
- def ternary_matmul_int8(
189
- x: torch.Tensor, # [B, S, K] float
190
- w_ternary: torch.Tensor, # [N, K] float {-1, 0, 1}
191
- w_scale: torch.Tensor, # scalar
192
- ) -> torch.Tensor:
193
- """
194
- INT8 ternary matmul using torch._int_mm (dispatches to VNNI/AMX).
195
-
196
- For inference-in-training (eval steps) or forward pass if
197
- your hardware has VNNI/AMX support.
198
-
199
- Speedup: 2-4x over float GEMM for ternary weights.
200
  """
201
- B, S, K = x.shape
202
- x_flat = x.reshape(-1, K) # [B*S, K]
203
-
204
- # Quantize activations to int8
205
- x_abs_max = x_flat.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
206
- x_scale = x_abs_max / 127.0
207
- x_int8 = (x_flat / x_scale).round().clamp(-128, 127).to(torch.int8)
208
-
209
- # Weights: already ternary, just cast
210
- w_int8 = w_ternary.to(torch.int8) # {-1, 0, 1} fits in int8
211
-
212
- # INT8 GEMM — uses hardware VNNI/AMX if available
213
- # torch._int_mm requires 2D inputs, both int8, K divisible by some alignment
214
- try:
215
- out_int32 = torch._int_mm(x_int8, w_int8.t()) # [B*S, N]
216
- out = out_int32.float() * x_scale * w_scale
217
- except RuntimeError:
218
- # Fallback if alignment requirements not met
219
- out = F.linear(x_flat.float(), w_ternary.float()) * w_scale
220
-
221
- return out.reshape(B, S, -1)
222
-
223
-
224
- # ═══════════════════════════════════════════════════════════
225
- # P-TURBO-2 : torch.compile (Regional)
226
- # ═══════════════════════════════════════════════════════════
227
-
228
- def try_compile_model(model: nn.Module, mode: str = "reduce-overhead") -> nn.Module:
229
- """
230
- Attempt torch.compile with graceful fallback.
231
-
232
- Uses regional compilation: compiles sub-modules individually
233
- to work around graph breaks from STE custom autograd functions.
234
- """
235
- if not hasattr(torch, "compile"):
236
- warnings.warn("torch.compile not available (PyTorch < 2.0)")
237
- return model
238
-
239
- # First: diagnose graph breaks
240
- try:
241
- import torch._dynamo as dynamo
242
-
243
- # Try compiling individual attention/MLP blocks instead of full model
244
- compiled_count = 0
245
- for name, module in model.named_modules():
246
- # Skip the top-level model and BitLinear (STE graph breaks)
247
- if module is model:
248
- continue
249
- # Compile "clean" blocks: attention, MLP, norms
250
- module_type = type(module).__name__.lower()
251
- if any(k in module_type for k in ["attention", "mlp", "feedforward", "norm"]):
252
- try:
253
- compiled = torch.compile(
254
- module,
255
- backend="inductor",
256
- mode=mode,
257
- fullgraph=False,
258
- )
259
- # Replace in parent
260
- parent_name = ".".join(name.split(".")[:-1])
261
- child_name = name.split(".")[-1]
262
- parent = model
263
- if parent_name:
264
- for part in parent_name.split("."):
265
- parent = getattr(parent, part)
266
- setattr(parent, child_name, compiled)
267
- compiled_count += 1
268
- except Exception as e:
269
- pass # Skip modules that can't be compiled
270
-
271
- if compiled_count == 0:
272
- # Fallback: try compiling the whole model with fullgraph=False
273
- model = torch.compile(model, backend="inductor", mode=mode, fullgraph=False)
274
- print(f"[TURBO-2] Compiled full model (fullgraph=False)")
275
- else:
276
- print(f"[TURBO-2] Compiled {compiled_count} sub-modules (regional)")
277
-
278
- return model
279
-
280
- except Exception as e:
281
- warnings.warn(f"torch.compile failed: {e}. Running in eager mode.")
282
- return model
283
 
284
 
285
  # ═══════════════════════════════════════════════════════════
@@ -296,21 +179,20 @@ def try_ipex_optimize(
296
  if not cpu_info.get("ipex_available"):
297
  print("[TURBO-4] IPEX not available — install: pip install intel-extension-for-pytorch")
298
  return model, optimizer
299
-
300
  import intel_extension_for_pytorch as ipex
301
-
302
- # Choose dtype based on hardware
303
  if dtype is None:
304
  if cpu_info["has_amx"]:
305
- dtype = torch.bfloat16 # AMX tiles → massive bf16 speedup
306
  print("[TURBO-4] IPEX + AMX bf16 enabled (Sapphire Rapids+)")
307
  elif cpu_info["has_avx512"]:
308
- dtype = torch.bfloat16 # Moderate benefit with AVX-512
309
  print("[TURBO-4] IPEX + AVX-512 bf16 enabled")
310
  else:
311
- dtype = torch.float32 # bf16 slower than fp32 without hardware support
312
  print("[TURBO-4] IPEX fp32 (no bf16 hardware support detected)")
313
-
314
  model, optimizer = ipex.optimize(
315
  model,
316
  optimizer=optimizer,
@@ -318,76 +200,62 @@ def try_ipex_optimize(
318
  level="O1",
319
  inplace=True,
320
  )
321
-
322
  return model, optimizer
323
 
324
 
325
  # ═══════════════════════════════════════════════════════════
326
- # P-TURBO-7 : Arrow mmap Dataset
327
  # ═══════════════════════════════════════════════════════════
328
 
329
- def prepare_arrow_dataset(
330
- dataset_name: str = "roneneldan/TinyStories",
331
- split: str = "train",
332
- tokenizer=None,
333
- seq_len: int = 32,
334
- max_tokens: int = 500_000,
335
- cache_dir: str = "./cache/arrow",
336
- num_proc: int = 4,
337
- ):
338
  """
339
- Prepare dataset as Arrow mmap format for zero-copy loading.
340
-
341
- Replaces streaming + custom .pt cache with HF datasets Arrow backend.
342
- Benefits: zero-copy to PyTorch, random access, efficient memory via mmap.
343
  """
344
- from datasets import load_dataset, Dataset
345
- from pathlib import Path
346
-
347
- cache_path = Path(cache_dir) / f"{dataset_name.replace('/', '_')}_{split}_{max_tokens}_seq{seq_len}"
348
-
349
- if cache_path.exists():
350
- print(f"[TURBO-7] Loading cached Arrow dataset from {cache_path}")
351
- dataset = Dataset.load_from_disk(str(cache_path))
352
- return dataset.with_format("torch")
353
-
354
- print(f"[TURBO-7] Preparing Arrow dataset from {dataset_name}...")
355
-
356
- # Load and tokenize
357
- raw = load_dataset(dataset_name, split=split, streaming=True)
358
-
359
- # Collect tokens
360
- all_tokens = []
361
- total = 0
362
- for example in raw:
363
- text = example.get("text", "")
364
- if tokenizer is not None:
365
- tokens = tokenizer.encode(text)
366
- else:
367
- # Fallback: assume pre-tokenized or return text
368
- tokens = text
369
- if isinstance(tokens, list):
370
- all_tokens.extend(tokens)
371
- total += len(tokens)
372
- if total >= max_tokens:
373
- break
374
-
375
- all_tokens = all_tokens[:max_tokens]
376
-
377
- # Chunk into sequences
378
- n_seqs = len(all_tokens) // seq_len
379
- chunks = [all_tokens[i * seq_len:(i + 1) * seq_len] for i in range(n_seqs)]
380
-
381
- dataset = Dataset.from_dict({
382
- "input_ids": chunks,
383
- })
384
-
385
- # Save as Arrow
386
- cache_path.parent.mkdir(parents=True, exist_ok=True)
387
- dataset.save_to_disk(str(cache_path))
388
- print(f"[TURBO-7] Saved {n_seqs} sequences to {cache_path}")
389
-
390
- return dataset.with_format("torch")
391
 
392
 
393
  # ═══════════════════════════════════════════════════════════
@@ -400,80 +268,70 @@ def apply(
400
  lr: float = 1e-3,
401
  weight_decay: float = 0.05,
402
  warmup_steps: int = 500,
403
- use_compile: bool = True,
404
  use_ipex: bool = True,
405
  use_lion: bool = False,
406
  verbose: bool = True,
407
  ) -> Tuple[nn.Module, torch.optim.Optimizer, Any]:
408
  """
409
  Apply all turbo optimizations to ch1mera model.
410
-
411
  Returns: (model, optimizer, scheduler)
412
-
413
- Usage in train_hyper.py:
414
- import chimera_turbo
415
- model, optimizer, scheduler = chimera_turbo.apply(
416
- model, max_steps=10000, lr=1e-3
417
- )
418
- # Then use normal training loop:
419
- for step, batch in enumerate(dataloader):
420
- loss = model(batch).loss
421
- loss.backward()
422
- if (step + 1) % grad_accum == 0:
423
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
424
- optimizer.step()
425
- scheduler.step()
426
- optimizer.zero_grad(set_to_none=True)
427
- chimera_turbo.invalidate_all_caches(model)
428
  """
429
- # ── Step 1: Detect CPU ──
430
  cpu_info = detect_cpu_info()
431
-
432
  if verbose:
433
  print("=" * 65)
434
- print("CHIMERA TURBO — CPU Acceleration Layer")
435
  print("=" * 65)
436
  print(f" Physical cores: {cpu_info['physical_cores']}")
437
  print(f" CPU capability: {cpu_info['capability']}")
438
- print(f" AMX: {cpu_info['has_amx']} AVX-512: {cpu_info['has_avx512']}")
439
  print(f" IPEX: {cpu_info['ipex_available']}")
440
  print(f" tcmalloc: {cpu_info['tcmalloc']}")
441
-
442
- # ── Step 2: Threading ──
443
  n_threads = configure_threading(cpu_info)
444
  if verbose:
445
- print(f"[TURBO-3] Threads: {n_threads} compute + {torch.get_num_interop_threads()} interop")
446
-
447
- # ── Step 3: Optimizer (replaces MeZO) ──
448
  optimizer = create_optimizer(model, lr=lr, weight_decay=weight_decay, use_lion=use_lion)
449
  scheduler = create_scheduler(optimizer, max_steps=max_steps, warmup_steps=warmup_steps)
450
  if verbose:
451
  opt_name = type(optimizer).__name__
452
  n_params = sum(p.numel() for g in optimizer.param_groups for p in g["params"])
453
  print(f"[TURBO-1] {opt_name} (lr={lr}, wd={weight_decay}) — {n_params:,} params")
454
- print(f" Replaces MeZO: 528 forwards/step → 1 forward + 1 backward")
455
-
456
- # ── Step 4: IPEX ──
457
  if use_ipex:
458
  model, optimizer = try_ipex_optimize(model, optimizer, cpu_info)
459
-
460
- # ── Step 5: torch.compile ──
461
  if use_compile:
462
  model = try_compile_model(model)
463
-
 
464
  if verbose:
465
  if not cpu_info["tcmalloc"]:
466
  print()
467
  print(" ⚠️ tcmalloc not detected. For +10-25% speedup:")
468
  print(" sudo apt install google-perftools")
469
  print(" LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 python train_hyper.py ...")
470
  print("=" * 65)
471
-
472
  return model, optimizer, scheduler
473
 
474
 
475
  # ═══════════════════════════════════════════════════════════
476
- # Training loop helper
477
  # ═══════════════════════════════════════════════════════════
478
 
479
  def training_step(
@@ -488,12 +346,15 @@ def training_step(
488
  ) -> float:
489
  """
490
  Single training step with all turbo optimizations active.
491
-
492
  Handles: autocast, gradient accumulation, clipping, cache invalidation.
 
 
 
 
493
  """
494
  is_accum_step = (step + 1) % grad_accum_steps == 0
495
-
496
- # Forward + backward
497
  ctx = torch.autocast(device_type="cpu", dtype=autocast_dtype) if autocast_dtype else nullcontext()
498
  with ctx:
499
  if isinstance(batch, dict):
@@ -503,37 +364,38 @@ def training_step(
503
  else:
504
  outputs = model(batch)
505
  loss = outputs if isinstance(outputs, torch.Tensor) else outputs.loss
506
- loss = loss / grad_accum_steps
507
-
 
 
508
  loss.backward()
509
-
510
  if is_accum_step:
511
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
512
  optimizer.step()
513
  scheduler.step()
514
  optimizer.zero_grad(set_to_none=True)
515
  invalidate_all_caches(model)
516
-
517
- return loss.item() * grad_accum_steps
518
 
519
 
520
  # ═══════════════════════════════════════════════════════════
521
- # Diagnostic tool
522
  # ═══════════════════════════════════════════════════════════
523
 
524
  def profile_model(model: nn.Module, dummy_input: torch.Tensor, steps: int = 5):
525
  """Profile forward+backward to find bottlenecks."""
526
  print("\n[TURBO-DIAG] Profiling...")
527
-
528
- # Warmup
529
  for _ in range(2):
530
  out = model(dummy_input)
531
- if hasattr(out, "loss"):
532
  out.loss.backward()
533
- else:
534
  out.sum().backward()
535
  model.zero_grad(set_to_none=True)
536
-
537
  with torch.profiler.profile(
538
  activities=[torch.profiler.ProfilerActivity.CPU],
539
  record_shapes=True,
@@ -541,9 +403,26 @@ def profile_model(model: nn.Module, dummy_input: torch.Tensor, steps: int = 5):
541
  ) as prof:
542
  for _ in range(steps):
543
  out = model(dummy_input)
544
- loss = out.loss if hasattr(out, "loss") else out.sum()
545
  loss.backward()
546
  model.zero_grad(set_to_none=True)
547
-
548
  print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
549
  return prof
1
  """
2
  chimera_turbo.py — Drop-in CPU acceleration for ch1mera 5.3
3
+ Usage: import chimera_turbo; chimera_turbo.apply(model, max_steps=N)
4
 
5
  Integrated paradigms:
6
  P-TURBO-1: STE + AdamW (replaces MeZO → fixes convergence + 50x fewer forwards)
7
+ P-TURBO-2: torch.compile regional — DISABLED (84 graph breaks from _RoundTernarySTE)
8
  P-TURBO-3: Threading optimal + tcmalloc detection
9
  P-TURBO-4: IPEX bf16/AMX if available
10
+ P-TURBO-5: Cache quantized weights across micro-batches (via BitLinear's existing cache)
11
  P-TURBO-6: INT8 ternary forward path (VNNI/AMX dispatch)
12
  P-TURBO-7: Arrow mmap dataset
13
+
14
+ v2 changes:
15
+ - torch.compile DISABLED by default: _RoundTernarySTE (autograd.Function) causes
16
+ 84+ graph breaks (28 layers × 3 BitLinear each). Net effect is SLOWER than eager.
17
+ Re-enable only after migrating the STE to a torch.library custom op (torch.round + register_autograd).
18
+ - Fix grad_accum_steps logic: DataLoader already provides eff_batch, don't double-accumulate.
19
+ - Add count_compile_graph_breaks() for quick graph-break diagnosis.
20
+ - Better bf16 autocast handling: skip autocast if CPU has no AMX/AVX512-BF16.
21
  """
22
 
23
+ import math
24
  import os
25
  import sys
26
  import warnings
27
  import torch
28
  import torch.nn as nn
29
  import torch.nn.functional as F
30
+ from typing import Optional, Dict, Any, Tuple, List
 
31
  from contextlib import nullcontext
32
 
33
  # ═══════════════════════════════════════════════════════════
 
37
  def detect_cpu_info() -> Dict[str, Any]:
38
  """Detect CPU capabilities for optimal configuration."""
39
  info = {}
40
+
41
  # Physical cores (not hyperthreads)
42
  try:
43
  physical = len(os.sched_getaffinity(0))
 
44
  import multiprocessing
45
  logical = multiprocessing.cpu_count()
46
  info["physical_cores"] = logical // 2 if logical == physical else physical
 
49
  import multiprocessing
50
  info["logical_cores"] = multiprocessing.cpu_count()
51
  info["physical_cores"] = info["logical_cores"] // 2
52
+
53
  # CPU capability
54
  try:
55
  info["capability"] = torch.backends.cpu.get_cpu_capability()
56
  except Exception:
57
  info["capability"] = "unknown"
58
+
59
+ cap = (info["capability"] or "").lower()
60
+ info["has_amx"] = "amx" in cap
61
+ info["has_avx512"] = "avx512" in cap or "avx512_vnni" in cap
62
+ info["has_avx512_bf16"] = "avx512_bf16" in cap or info["has_amx"]
63
+ info["has_vnni"] = info["has_avx512"]
64
+
65
  # IPEX available?
66
  try:
67
  import intel_extension_for_pytorch
 
69
  info["ipex_version"] = intel_extension_for_pytorch.__version__
70
  except ImportError:
71
  info["ipex_available"] = False
72
+
73
  # tcmalloc loaded?
74
  info["tcmalloc"] = "tcmalloc" in os.environ.get("LD_PRELOAD", "")
75
+
76
  return info
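# Example of the capability flags this produces (hypothetical AVX-512 machine
# without AMX; values are illustrative, not measured on real hardware):
#
#   >>> detect_cpu_info()
#   {'physical_cores': 8, 'capability': 'AVX512',
#    'has_amx': False, 'has_avx512': True, 'has_avx512_bf16': False,
#    'has_vnni': True, 'ipex_available': False, 'tcmalloc': False}
#
# apply() later keys on has_amx / has_avx512_bf16 to decide bf16 autocast.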
77
 
78
 
79
  def configure_threading(cpu_info: Dict[str, Any], reserve_for_io: int = 1):
80
  """Set optimal threading for CPU training."""
81
  n_compute = max(1, cpu_info["physical_cores"] - reserve_for_io)
82
+
83
+ # Only set num_threads — interop threads can only be set once before
84
+ # any tensor ops, and train_hyper.py already sets them at import time.
85
  torch.set_num_threads(n_compute)
86
+
 
87
  os.environ["OMP_NUM_THREADS"] = str(n_compute)
88
  os.environ["MKL_NUM_THREADS"] = str(n_compute)
89
+
90
  return n_compute
91
 
92
 
 
103
  ) -> torch.optim.Optimizer:
104
  """
105
  Create optimizer for STE-based ternary training (replaces MeZO).
106
+
107
  Based on BitNet b1.58 Reloaded (2407.09527):
108
  - lr=1e-3 for <300M params (NOT 1e-2, that's for 3B+)
109
  - weight_decay=0.05
110
  - AdamW with β=(0.9, 0.95)
 
 
 
111
  """
 
112
  decay_params = []
113
  no_decay_params = []
114
+
115
  for name, param in model.named_parameters():
116
  if not param.requires_grad:
117
  continue
 
119
  no_decay_params.append(param)
120
  else:
121
  decay_params.append(param)
122
+
123
  param_groups = [
124
  {"params": decay_params, "weight_decay": weight_decay},
125
  {"params": no_decay_params, "weight_decay": 0.0},
126
  ]
127
+
128
  if use_lion:
129
  try:
130
  from lion_pytorch import Lion
131
  return Lion(param_groups, lr=lr * 0.3, betas=(0.95, 0.98))
132
  except ImportError:
133
  warnings.warn("lion-pytorch not installed, falling back to AdamW")
134
+
135
  return torch.optim.AdamW(param_groups, lr=lr, betas=betas, fused=False)
136
 
137
 
138
  def create_scheduler(optimizer, max_steps: int, warmup_steps: int = 500):
139
  """Cosine schedule with linear warmup — standard BitNet recipe."""
140
  from torch.optim.lr_scheduler import LambdaLR
141
+
 
142
  def lr_lambda(step):
143
  if step < warmup_steps:
144
  return step / max(1, warmup_steps)
145
  progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
146
  return max(0.01, 0.5 * (1.0 + math.cos(math.pi * progress)))
147
+
148
  return LambdaLR(optimizer, lr_lambda)
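# Quick shape-check of the warmup + cosine schedule (illustrative numbers:
# max_steps=10_000, warmup_steps=500; the optimizer and parameter are dummies):
_p = torch.zeros(1, requires_grad=True)
_opt = torch.optim.SGD([_p], lr=1e-3)
_sched = create_scheduler(_opt, max_steps=10_000, warmup_steps=500)
# multiplier is step/500 during warmup, then 0.5*(1+cos(pi*progress)) with a
# 0.01 floor, so the effective lr goes 0 → 1e-3 → 1e-5 over training.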
149
 
150
 
151
  # ═══════════════════════════════════════════════════════════
152
+ # P-TURBO-5 : Invalidate BitLinear packed caches after optimizer step
153
  # ═══════════════════════════════════════════════════════════
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  def invalidate_all_caches(model: nn.Module):
156
+ """Call after optimizer.step() to force BitLinear re-quantization.
157
 
158
+ In training mode, BitLinear._forward_train() recomputes quantized
159
+ weights every call via STE, so the packed cache is not used.
160
+ This is still good practice for eval steps between training.
161
  """
162
+ from chimera.quantization import BitLinear
163
+ for m in model.modules():
164
+ if isinstance(m, BitLinear):
165
+ m.invalidate_packed()
166
 
167
 
168
  # ═══════════════════════════════════════════════════════════
 
179
  if not cpu_info.get("ipex_available"):
180
  print("[TURBO-4] IPEX not available — install: pip install intel-extension-for-pytorch")
181
  return model, optimizer
182
+
183
  import intel_extension_for_pytorch as ipex
184
+
 
185
  if dtype is None:
186
  if cpu_info["has_amx"]:
187
+ dtype = torch.bfloat16
188
  print("[TURBO-4] IPEX + AMX bf16 enabled (Sapphire Rapids+)")
189
  elif cpu_info["has_avx512"]:
190
+ dtype = torch.bfloat16
191
  print("[TURBO-4] IPEX + AVX-512 bf16 enabled")
192
  else:
193
+ dtype = torch.float32
194
  print("[TURBO-4] IPEX fp32 (no bf16 hardware support detected)")
195
+
196
  model, optimizer = ipex.optimize(
197
  model,
198
  optimizer=optimizer,
 
200
  level="O1",
201
  inplace=True,
202
  )
203
+
204
  return model, optimizer
205
 
206
 
207
  # ═══════════════════════════════════════════════════════════
208
+ # P-TURBO-2 : torch.compile DISABLED by default
209
  # ═══════════════════════════════════════════════════════════
210
 
211
+ def try_compile_model(model: nn.Module, mode: str = "reduce-overhead") -> nn.Module:
212
  """
213
+ Attempt torch.compile with graceful fallback.
214
+
215
+ CURRENTLY DISABLED: _RoundTernarySTE (torch.autograd.Function) causes
216
+ 84+ graph breaks across 28 layers × 3 BitLinear. This makes torch.compile
217
+ slower than eager mode due to recompilation overhead.
218
+
219
+ To re-enable: migrate STE to use torch library custom ops:
220
+ @torch.library.custom_op("chimera::ste_ternary", mutates_args=())
221
+ def ste_ternary(w: torch.Tensor) -> torch.Tensor:
222
+ return torch.round(torch.clamp(w, -1.0, 1.0))
223
+
224
+ @ste_ternary.register_fake
225
+ def _(w): return torch.empty_like(w)
226
+
227
+ @torch.library.register_autograd("chimera::ste_ternary", ...)
228
  """
229
+ print("[TURBO-2] torch.compile SKIPPED (84 graph breaks from STE autograd.Function)")
230
+ print(" To enable: migrate _RoundTernarySTE to torch.library.custom_op")
231
+ return model
232
+
233
+
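# Sketch of the migration named in the docstring above (requires torch >= 2.4;
# not part of this commit). Wrapping the ternary rounding in a torch.library
# custom op lets Dynamo treat it as one opaque node instead of breaking the
# graph. The pass-through backward is the usual STE assumption; adapt it to
# whatever _RoundTernarySTE.backward actually computes.

@torch.library.custom_op("chimera::ste_ternary", mutates_args=())
def ste_ternary(w: torch.Tensor) -> torch.Tensor:
    # forward: clamp to [-1, 1], then round to {-1, 0, 1}
    return torch.round(torch.clamp(w, -1.0, 1.0))

@ste_ternary.register_fake
def _(w: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(w)

def _ste_backward(ctx, grad_out):
    return grad_out  # straight-through: gradient flows unchanged

torch.library.register_autograd("chimera::ste_ternary", _ste_backward)

# BitLinear.forward would then call torch.ops.chimera.ste_ternary(w_scaled)
# instead of _RoundTernarySTE.apply(w_scaled).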
234
+ # ═══════════════════════════════════════════════════════════
235
+ # P-TURBO-6 : INT8 Ternary Forward Path
236
+ # ═══════════════════════════════════════════════════════════
237
+
238
+ def ternary_matmul_int8(
239
+ x: torch.Tensor,
240
+ w_ternary: torch.Tensor,
241
+ w_scale: torch.Tensor,
242
+ ) -> torch.Tensor:
243
+ """INT8 ternary matmul using torch._int_mm (dispatches to VNNI/AMX)."""
244
+ B, S, K = x.shape
245
+ x_flat = x.reshape(-1, K)
246
+
247
+ x_abs_max = x_flat.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
248
+ x_scale = x_abs_max / 127.0
249
+ x_int8 = (x_flat / x_scale).round().clamp(-128, 127).to(torch.int8)
250
+ w_int8 = w_ternary.to(torch.int8)
251
+
252
+ try:
253
+ out_int32 = torch._int_mm(x_int8, w_int8.t())
254
+ out = out_int32.float() * x_scale * w_scale
255
+ except RuntimeError:
256
+ out = F.linear(x_flat.float(), w_ternary.float()) * w_scale
257
+
258
+ return out.reshape(B, S, -1)
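# Illustrative call (shapes are assumptions, not taken from ch1mera configs):
# ternary weights stored as float {-1, 0, 1} with one per-tensor scale.
#   x : [B=2, S=32, K=512] activations, w : [N=1024, K=512], y : [2, 32, 1024]
# If torch._int_mm rejects the shapes (alignment), the float fallback runs.
x = torch.randn(2, 32, 512)
w = torch.randint(-1, 2, (1024, 512)).float()
y = ternary_matmul_int8(x, w, w_scale=torch.tensor(0.02))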
259
 
260
 
261
  # ═══════════════════════════════════════════════════════════
 
268
  lr: float = 1e-3,
269
  weight_decay: float = 0.05,
270
  warmup_steps: int = 500,
271
+ use_compile: bool = False, # ← DISABLED by default (was True)
272
  use_ipex: bool = True,
273
  use_lion: bool = False,
274
  verbose: bool = True,
275
  ) -> Tuple[nn.Module, torch.optim.Optimizer, Any]:
276
  """
277
  Apply all turbo optimizations to ch1mera model.
278
+
279
  Returns: (model, optimizer, scheduler)
280
  """
 
281
  cpu_info = detect_cpu_info()
282
+
283
  if verbose:
284
  print("=" * 65)
285
+ print("CHIMERA TURBO v2 — CPU Acceleration Layer")
286
  print("=" * 65)
287
  print(f" Physical cores: {cpu_info['physical_cores']}")
288
  print(f" CPU capability: {cpu_info['capability']}")
289
+ print(f" AMX: {cpu_info['has_amx']} AVX-512: {cpu_info['has_avx512']} BF16 hw: {cpu_info['has_avx512_bf16']}")
290
  print(f" IPEX: {cpu_info['ipex_available']}")
291
  print(f" tcmalloc: {cpu_info['tcmalloc']}")
292
+
293
+ # ── Threading ──
294
  n_threads = configure_threading(cpu_info)
295
  if verbose:
296
+ print(f"[TURBO-3] Compute threads: {n_threads}")
297
+
298
+ # ── Optimizer (replaces MeZO) ──
299
  optimizer = create_optimizer(model, lr=lr, weight_decay=weight_decay, use_lion=use_lion)
300
  scheduler = create_scheduler(optimizer, max_steps=max_steps, warmup_steps=warmup_steps)
301
  if verbose:
302
  opt_name = type(optimizer).__name__
303
  n_params = sum(p.numel() for g in optimizer.param_groups for p in g["params"])
304
  print(f"[TURBO-1] {opt_name} (lr={lr}, wd={weight_decay}) — {n_params:,} params")
305
+ print(f" STE backprop: 1 forward + 1 backward per step")
306
+
307
+ # ── IPEX ──
308
  if use_ipex:
309
  model, optimizer = try_ipex_optimize(model, optimizer, cpu_info)
310
+
311
+ # ── torch.compile ──
312
  if use_compile:
313
  model = try_compile_model(model)
314
+
315
+ # ── Autocast recommendation ──
316
  if verbose:
317
+ if not cpu_info["has_avx512_bf16"]:
318
+ print()
319
+ print(" ⚠️ No hardware BF16 support detected (need AVX512-BF16 or AMX).")
320
+ print(" BF16 autocast may be SLOWER than fp32 on this CPU.")
321
+ print(" Consider --no-bf16 flag if training is slow.")
322
+
323
  if not cpu_info["tcmalloc"]:
324
  print()
325
  print(" ⚠️ tcmalloc not detected. For +10-25% speedup:")
326
  print(" sudo apt install google-perftools")
327
  print(" LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 python train_hyper.py ...")
328
  print("=" * 65)
329
+
330
  return model, optimizer, scheduler
331
 
332
 
333
  # ═══════════════════════════════════════════════════════════
334
+ # Training step helper
335
  # ═══════════════════════════════════════════════════════════
336
 
337
  def training_step(
 
346
  ) -> float:
347
  """
348
  Single training step with all turbo optimizations active.
349
+
350
  Handles: autocast, gradient accumulation, clipping, cache invalidation.
351
+
352
+ IMPORTANT: grad_accum_steps should be 1 if the DataLoader already provides
353
+ the full effective batch. Set >1 only if you want to split a large batch
354
+ across multiple forward passes.
355
  """
356
  is_accum_step = (step + 1) % grad_accum_steps == 0
357
+
 
358
  ctx = torch.autocast(device_type="cpu", dtype=autocast_dtype) if autocast_dtype else nullcontext()
359
  with ctx:
360
  if isinstance(batch, dict):
 
364
  else:
365
  outputs = model(batch)
366
  loss = outputs if isinstance(outputs, torch.Tensor) else outputs.loss
367
+ loss_val = loss.item()
368
+ if grad_accum_steps > 1:
369
+ loss = loss / grad_accum_steps
370
+
371
  loss.backward()
372
+
373
  if is_accum_step:
374
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
375
  optimizer.step()
376
  scheduler.step()
377
  optimizer.zero_grad(set_to_none=True)
378
  invalidate_all_caches(model)
379
+
380
+ return loss_val
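# End-to-end sketch of the v2 loop. Assumptions (not taken verbatim from
# train_hyper.py): `train_loader` yields dict batches that already form the
# full effective batch, so grad_accum_steps=1; the argument order for
# training_step follows the names used in its body and may differ from the
# real signature; bf16 autocast only when the CPU has BF16 hardware.
cpu = detect_cpu_info()
model, optimizer, scheduler = apply(model, max_steps=10_000, lr=1e-3)
bf16 = torch.bfloat16 if cpu["has_avx512_bf16"] else None

for step, batch in enumerate(train_loader):
    loss = training_step(
        model, batch, optimizer, scheduler,
        step=step,
        grad_accum_steps=1,              # DataLoader already provides eff_batch
        autocast_dtype=bf16,
    )
    if step % 100 == 0:
        print(f"step {step}  loss {loss:.4f}")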
381
 
382
 
383
  # ═══════════════════════════════════════════════════════════
384
+ # Diagnostic tools
385
  # ═══════════════════════════════════════════════════════════
386
 
387
  def profile_model(model: nn.Module, dummy_input: torch.Tensor, steps: int = 5):
388
  """Profile forward+backward to find bottlenecks."""
389
  print("\n[TURBO-DIAG] Profiling...")
390
+
 
391
  for _ in range(2):
392
  out = model(dummy_input)
393
+ if hasattr(out, "loss") and out.loss is not None:
394
  out.loss.backward()
395
+ elif isinstance(out, torch.Tensor):
396
  out.sum().backward()
397
  model.zero_grad(set_to_none=True)
398
+
399
  with torch.profiler.profile(
400
  activities=[torch.profiler.ProfilerActivity.CPU],
401
  record_shapes=True,
 
403
  ) as prof:
404
  for _ in range(steps):
405
  out = model(dummy_input)
406
+ loss = out.loss if (hasattr(out, "loss") and out.loss is not None) else out.sum()
407
  loss.backward()
408
  model.zero_grad(set_to_none=True)
409
+
410
  print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
411
  return prof
412
+
413
+
414
+ def count_compile_graph_breaks(model: nn.Module, dummy_input: torch.Tensor):
415
+ """Count how many graph breaks torch.compile would produce."""
416
+ try:
417
+ import torch._dynamo as dynamo
418
+ explanation = dynamo.explain(model)(dummy_input)
419
+ n_breaks = len(explanation.break_reasons)
420
+ print(f"\n[TURBO-DIAG] Graph breaks: {n_breaks}")
421
+ for i, reason in enumerate(explanation.break_reasons[:10]):
422
+ print(f" [{i+1}] {reason}")
423
+ if n_breaks > 10:
424
+ print(f" ... and {n_breaks - 10} more")
425
+ return n_breaks
426
+ except Exception as e:
427
+ print(f"[TURBO-DIAG] dynamo.explain failed: {e}")
428
+ return -1