Lgr54HFi committed on
Commit
dd57d33
Β·
verified Β·
1 Parent(s): 31b0fdf

perf: re-enable torch.compile now that STE uses detach() trick (zero graph breaks)

Browse files
Files changed (1) hide show
  1. chimera_turbo.py +50 -63
chimera_turbo.py CHANGED
@@ -4,30 +4,27 @@ Usage: import chimera_turbo; chimera_turbo.apply(model, max_steps=N)
4
 
5
  Paradigmes intégrés:
6
  P-TURBO-1: STE + AdamW (remplace MeZO → fix convergence + 50x moins de forwards)
7
- P-TURBO-2: torch.compile regional β€” DISABLED (84 graph breaks from _RoundTernarySTE)
8
  P-TURBO-3: Threading optimal + tcmalloc detection
9
  P-TURBO-4: IPEX bf16/AMX si disponible
10
- P-TURBO-5: Cache poids quantifiΓ©s inter micro-batch (via BitLinear existing cache)
11
  P-TURBO-6: INT8 ternary forward path (VNNI/AMX dispatch)
12
- P-TURBO-7: Arrow mmap dataset
13
-
14
- v2 changes:
15
- - torch.compile DISABLED by default: _RoundTernarySTE (autograd.Function) causes
16
- 84+ graph breaks (28 layers × 3 BitLinear each). Net effect is SLOWER than eager.
17
- Re-enable only after migrating STE to functional torch (torch.round + custom_vjp).
18
- - Fix grad_accum_steps logic: DataLoader already provides eff_batch, don't double-accumulate.
19
- - Add profile_bottleneck() for quick diagnosis.
20
- - Better bf16 autocast handling: skip autocast if CPU has no AMX/AVX512-BF16.
21
  """
22
 
23
  import math
24
  import os
25
- import sys
26
  import warnings
27
  import torch
28
  import torch.nn as nn
29
  import torch.nn.functional as F
30
- from typing import Optional, Dict, Any, Tuple, List
31
  from contextlib import nullcontext
32
 
33
  # ═══════════════════════════════════════════════════════════
@@ -38,7 +35,6 @@ def detect_cpu_info() -> Dict[str, Any]:
38
  """Detect CPU capabilities for optimal configuration."""
39
  info = {}
40
 
41
- # Physical cores (not hyperthreads)
42
  try:
43
  physical = len(os.sched_getaffinity(0))
44
  import multiprocessing
@@ -50,7 +46,6 @@ def detect_cpu_info() -> Dict[str, Any]:
50
  info["logical_cores"] = multiprocessing.cpu_count()
51
  info["physical_cores"] = info["logical_cores"] // 2
52
 
53
- # CPU capability
54
  try:
55
  info["capability"] = torch.backends.cpu.get_cpu_capability()
56
  except Exception:
@@ -62,7 +57,6 @@ def detect_cpu_info() -> Dict[str, Any]:
62
  info["has_avx512_bf16"] = "avx512_bf16" in cap or info["has_amx"]
63
  info["has_vnni"] = info["has_avx512"]
64
 
65
- # IPEX available?
66
  try:
67
  import intel_extension_for_pytorch
68
  info["ipex_available"] = True
@@ -70,7 +64,6 @@ def detect_cpu_info() -> Dict[str, Any]:
70
  except ImportError:
71
  info["ipex_available"] = False
72
 
73
- # tcmalloc loaded?
74
  info["tcmalloc"] = "tcmalloc" in os.environ.get("LD_PRELOAD", "")
75
 
76
  return info
@@ -79,14 +72,9 @@ def detect_cpu_info() -> Dict[str, Any]:
79
  def configure_threading(cpu_info: Dict[str, Any], reserve_for_io: int = 1):
80
  """Set optimal threading for CPU training."""
81
  n_compute = max(1, cpu_info["physical_cores"] - reserve_for_io)
82
-
83
- # Only set num_threads β€” interop threads can only be set once before
84
- # any tensor ops, and train_hyper.py already sets them at import time.
85
  torch.set_num_threads(n_compute)
86
-
87
  os.environ["OMP_NUM_THREADS"] = str(n_compute)
88
  os.environ["MKL_NUM_THREADS"] = str(n_compute)
89
-
90
  return n_compute
91
 
92
 
@@ -105,7 +93,7 @@ def create_optimizer(
105
  Create optimizer for STE-based ternary training (replaces MeZO).
106
 
107
  Based on BitNet b1.58 Reloaded (2407.09527):
108
- - lr=1e-3 for <300M params (NOT 1e-2, that's for 3B+)
109
  - weight_decay=0.05
110
  - AdamW with β=(0.9, 0.95)
111
  """
@@ -149,16 +137,11 @@ def create_scheduler(optimizer, max_steps: int, warmup_steps: int = 500):
149
 
150
 
151
  # ═══════════════════════════════════════════════════════════
152
- # P-TURBO-5 : Invalidate BitLinear packed caches after optimizer step
153
  # ═══════════════════════════════════════════════════════════
154
 
155
  def invalidate_all_caches(model: nn.Module):
156
- """Call after optimizer.step() to force BitLinear re-quantization.
157
-
158
- In training mode, BitLinear._forward_train() recomputes quantized
159
- weights every call via STE, so the packed cache is not used.
160
- This is still good practice for eval steps between training.
161
- """
162
  from chimera.quantization import BitLinear
163
  for m in model.modules():
164
  if isinstance(m, BitLinear):
@@ -194,41 +177,43 @@ def try_ipex_optimize(
194
  print("[TURBO-4] IPEX fp32 (no bf16 hardware support detected)")
195
 
196
  model, optimizer = ipex.optimize(
197
- model,
198
- optimizer=optimizer,
199
- dtype=dtype,
200
- level="O1",
201
- inplace=True,
202
  )
203
-
204
  return model, optimizer
205
 
206
 
207
  # ═══════════════════════════════════════════════════════════
208
- # P-TURBO-2 : torch.compile β€” DISABLED by default
209
  # ═══════════════════════════════════════════════════════════
210
 
211
  def try_compile_model(model: nn.Module, mode: str = "reduce-overhead") -> nn.Module:
212
  """
213
- Attempt torch.compile with graceful fallback.
214
-
215
- CURRENTLY DISABLED: _RoundTernarySTE (torch.autograd.Function) causes
216
- 84+ graph breaks across 28 layers Γ— 3 BitLinear. This makes torch.compile
217
- slower than eager mode due to recompilation overhead.
218
 
219
- To re-enable: migrate STE to use torch library custom ops:
220
- @torch.library.custom_op("chimera::ste_ternary", mutates_args=())
221
- def ste_ternary(w: torch.Tensor) -> torch.Tensor:
222
- return torch.round(torch.clamp(w, -1.0, 1.0))
223
 
224
- @ste_ternary.register_fake
225
- def _(w): return torch.empty_like(w)
226
-
227
- @torch.library.register_autograd("chimera::ste_ternary", ...)
228
  """
229
- print("[TURBO-2] torch.compile SKIPPED (84 graph breaks from STE autograd.Function)")
230
- print(" To enable: migrate _RoundTernarySTE to torch.library.custom_op")
231
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
 
234
  # ═══════════════════════════════════════════════════════════
@@ -259,7 +244,7 @@ def ternary_matmul_int8(
259
 
260
 
261
  # ═══════════════════════════════════════════════════════════
262
- # MAIN: apply() β€” Point d'entrΓ©e unique
263
  # ═══════════════════════════════════════════════════════════
264
 
265
  def apply(
@@ -268,7 +253,7 @@ def apply(
268
  lr: float = 1e-3,
269
  weight_decay: float = 0.05,
270
  warmup_steps: int = 500,
271
- use_compile: bool = False, # ← DISABLED by default (was True)
272
  use_ipex: bool = True,
273
  use_lion: bool = False,
274
  verbose: bool = True,
@@ -282,7 +267,7 @@ def apply(
282
 
283
  if verbose:
284
  print("=" * 65)
285
- print("CHIMERA TURBO v2 β€” CPU Acceleration Layer")
286
  print("=" * 65)
287
  print(f" Physical cores: {cpu_info['physical_cores']}")
288
  print(f" CPU capability: {cpu_info['capability']}")
@@ -312,14 +297,13 @@ def apply(
312
  if use_compile:
313
  model = try_compile_model(model)
314
 
315
- # ── Autocast recommendation ──
316
  if verbose:
317
  if not cpu_info["has_avx512_bf16"]:
318
  print()
319
  print(" ⚠️ No hardware BF16 support detected (need AVX512-BF16 or AMX).")
320
  print(" BF16 autocast may be SLOWER than fp32 on this CPU.")
321
  print(" Consider --no-bf16 flag if training is slow.")
322
-
323
  if not cpu_info["tcmalloc"]:
324
  print()
325
  print(" ⚠️ tcmalloc not detected. For +10-25% speedup:")
@@ -347,11 +331,8 @@ def training_step(
347
  """
348
  Single training step with all turbo optimizations active.
349
 
350
- Handles: autocast, gradient accumulation, clipping, cache invalidation.
351
-
352
  IMPORTANT: grad_accum_steps should be 1 if the DataLoader already provides
353
- the full effective batch. Set >1 only if you want to split a large batch
354
- across multiple forward passes.
355
  """
356
  is_accum_step = (step + 1) % grad_accum_steps == 0
357
 
@@ -412,7 +393,11 @@ def profile_model(model: nn.Module, dummy_input: torch.Tensor, steps: int = 5):
412
 
413
 
414
  def count_compile_graph_breaks(model: nn.Module, dummy_input: torch.Tensor):
415
- """Count how many graph breaks torch.compile would produce."""
 
 
 
 
416
  try:
417
  import torch._dynamo as dynamo
418
  explanation = dynamo.explain(model)(dummy_input)
@@ -422,6 +407,8 @@ def count_compile_graph_breaks(model: nn.Module, dummy_input: torch.Tensor):
422
  print(f" [{i+1}] {reason}")
423
  if n_breaks > 10:
424
  print(f" ... and {n_breaks - 10} more")
 
 
425
  return n_breaks
426
  except Exception as e:
427
  print(f"[TURBO-DIAG] dynamo.explain failed: {e}")
 
4
 
5
  Paradigmes intégrés:
6
  P-TURBO-1: STE + AdamW (remplace MeZO → fix convergence + 50x moins de forwards)
7
+ P-TURBO-2: torch.compile (now possible — STE uses detach() trick, zero graph breaks)
8
  P-TURBO-3: Threading optimal + tcmalloc detection
9
  P-TURBO-4: IPEX bf16/AMX si disponible
10
+ P-TURBO-5: Invalidate BitLinear packed caches after optimizer step
11
  P-TURBO-6: INT8 ternary forward path (VNNI/AMX dispatch)
12
+
13
+ v3 changes (after quantization.py STE migration):
14
+ - torch.compile RE-ENABLED: _RoundTernarySTE replaced with detach() trick
15
+ in quantization.py β€” zero graph breaks, Inductor can fuse quantize+linear.
16
+ - Compile uses fullgraph=False as safety net for any remaining breaks
17
+ in non-BitLinear modules (evolution engine, loop controller, etc.)
18
+ - grad_accum_steps fix from v2 preserved.
 
 
19
  """
20
 
21
  import math
22
  import os
 
23
  import warnings
24
  import torch
25
  import torch.nn as nn
26
  import torch.nn.functional as F
27
+ from typing import Optional, Dict, Any, Tuple
28
  from contextlib import nullcontext
29
 
30
  # ═══════════════════════════════════════════════════════════
 
35
  """Detect CPU capabilities for optimal configuration."""
36
  info = {}
37
 
 
38
  try:
39
  physical = len(os.sched_getaffinity(0))
40
  import multiprocessing
 
46
  info["logical_cores"] = multiprocessing.cpu_count()
47
  info["physical_cores"] = info["logical_cores"] // 2
48
 
 
49
  try:
50
  info["capability"] = torch.backends.cpu.get_cpu_capability()
51
  except Exception:
 
57
  info["has_avx512_bf16"] = "avx512_bf16" in cap or info["has_amx"]
58
  info["has_vnni"] = info["has_avx512"]
59
 
 
60
  try:
61
  import intel_extension_for_pytorch
62
  info["ipex_available"] = True
 
64
  except ImportError:
65
  info["ipex_available"] = False
66
 
 
67
  info["tcmalloc"] = "tcmalloc" in os.environ.get("LD_PRELOAD", "")
68
 
69
  return info
 
72
  def configure_threading(cpu_info: Dict[str, Any], reserve_for_io: int = 1):
73
  """Set optimal threading for CPU training."""
74
  n_compute = max(1, cpu_info["physical_cores"] - reserve_for_io)
 
 
 
75
  torch.set_num_threads(n_compute)
 
76
  os.environ["OMP_NUM_THREADS"] = str(n_compute)
77
  os.environ["MKL_NUM_THREADS"] = str(n_compute)
 
78
  return n_compute
79
 
80
 
 
93
  Create optimizer for STE-based ternary training (replaces MeZO).
94
 
95
  Based on BitNet b1.58 Reloaded (2407.09527):
96
+ - lr=1e-3 for <300M params
97
  - weight_decay=0.05
98
  - AdamW with β=(0.9, 0.95)
99
  """
 
137
 
138
 
139
  # ═══════════════════════════════════════════════════════════
140
+ # P-TURBO-5 : Invalidate BitLinear packed caches
141
  # ═══════════════════════════════════════════════════════════
142
 
143
  def invalidate_all_caches(model: nn.Module):
144
+ """Call after optimizer.step() to force BitLinear re-quantization."""
 
 
 
 
 
145
  from chimera.quantization import BitLinear
146
  for m in model.modules():
147
  if isinstance(m, BitLinear):
 
177
  print("[TURBO-4] IPEX fp32 (no bf16 hardware support detected)")
178
 
179
  model, optimizer = ipex.optimize(
180
+ model, optimizer=optimizer, dtype=dtype, level="O1", inplace=True,
 
 
 
 
181
  )
 
182
  return model, optimizer
183
 
184
 
185
  # ═══════════════════════════════════════════════════════════
186
+ # P-TURBO-2 : torch.compile
187
  # ═══════════════════════════════════════════════════════════
188
 
189
  def try_compile_model(model: nn.Module, mode: str = "reduce-overhead") -> nn.Module:
190
  """
191
+ Compile model with torch.compile for kernel fusion.
 
 
 
 
192
 
193
+ Now possible because STE uses the detach() trick (zero graph breaks
194
+ in BitLinear). Uses fullgraph=False as safety net for any remaining
195
+ breaks in non-BitLinear modules (evolution engine, grammar FST, etc.)
 
196
 
197
+ Expected speedup: 1.5-3x from fusing quantize + linear + activation.
198
+ First call is slow (compilation); subsequent calls are fast.
 
 
199
  """
200
+ if not hasattr(torch, "compile"):
201
+ warnings.warn("torch.compile not available (PyTorch < 2.0)")
202
+ return model
203
+
204
+ try:
205
+ compiled = torch.compile(
206
+ model,
207
+ backend="inductor",
208
+ mode=mode,
209
+ fullgraph=False, # safety net for non-BitLinear graph breaks
210
+ )
211
+ print(f"[TURBO-2] torch.compile enabled (backend=inductor, mode={mode})")
212
+ print(f" First few steps will be slow (compilation). Then 1.5-3x speedup.")
213
+ return compiled
214
+ except Exception as e:
215
+ warnings.warn(f"torch.compile failed: {e}. Running in eager mode.")
216
+ return model
217
 
218
 
219
  # ═══════════════════════════════════════════════════════════
 
244
 
245
 
246
  # ═══════════════════════════════════════════════════════════
247
+ # MAIN: apply()
248
  # ═══════════════════════════════════════════════════════════
249
 
250
  def apply(
 
253
  lr: float = 1e-3,
254
  weight_decay: float = 0.05,
255
  warmup_steps: int = 500,
256
+ use_compile: bool = True, # ← RE-ENABLED (STE detach trick = zero graph breaks)
257
  use_ipex: bool = True,
258
  use_lion: bool = False,
259
  verbose: bool = True,
 
267
 
268
  if verbose:
269
  print("=" * 65)
270
+ print("CHIMERA TURBO v3 β€” CPU Acceleration Layer")
271
  print("=" * 65)
272
  print(f" Physical cores: {cpu_info['physical_cores']}")
273
  print(f" CPU capability: {cpu_info['capability']}")
 
297
  if use_compile:
298
  model = try_compile_model(model)
299
 
300
+ # ── Warnings ──
301
  if verbose:
302
  if not cpu_info["has_avx512_bf16"]:
303
  print()
304
  print(" ⚠️ No hardware BF16 support detected (need AVX512-BF16 or AMX).")
305
  print(" BF16 autocast may be SLOWER than fp32 on this CPU.")
306
  print(" Consider --no-bf16 flag if training is slow.")
 
307
  if not cpu_info["tcmalloc"]:
308
  print()
309
  print(" ⚠️ tcmalloc not detected. For +10-25% speedup:")
 
331
  """
332
  Single training step with all turbo optimizations active.
333
 
 
 
334
  IMPORTANT: grad_accum_steps should be 1 if the DataLoader already provides
335
+ the full effective batch. Set >1 only for memory-constrained scenarios.
 
336
  """
337
  is_accum_step = (step + 1) % grad_accum_steps == 0
338
 
 
393
 
394
 
395
  def count_compile_graph_breaks(model: nn.Module, dummy_input: torch.Tensor):
396
+ """Count how many graph breaks torch.compile would produce.
397
+
398
+ After the STE detach() migration, BitLinear should produce ZERO breaks.
399
+ Remaining breaks come from non-BitLinear modules (evolution, grammar, etc.)
400
+ """
401
  try:
402
  import torch._dynamo as dynamo
403
  explanation = dynamo.explain(model)(dummy_input)
 
407
  print(f" [{i+1}] {reason}")
408
  if n_breaks > 10:
409
  print(f" ... and {n_breaks - 10} more")
410
+ if n_breaks == 0:
411
+ print(" βœ… Zero graph breaks β€” full model is compilable!")
412
  return n_breaks
413
  except Exception as e:
414
  print(f"[TURBO-DIAG] dynamo.explain failed: {e}")