perf: re-enable torch.compile now that STE uses detach() trick (zero graph breaks)
chimera_turbo.py (+50 −63)
@@ -4,30 +4,27 @@ Usage: import chimera_turbo; chimera_turbo.apply(model, max_steps=N)
 
 Integrated paradigms:
     P-TURBO-1: STE + AdamW (replaces MeZO → fixes convergence + 50x fewer forwards)
-    P-TURBO-2: torch.compile
+    P-TURBO-2: torch.compile (now possible — STE uses detach() trick, zero graph breaks)
     P-TURBO-3: Optimal threading + tcmalloc detection
     P-TURBO-4: IPEX bf16/AMX if available
-    P-TURBO-5:
+    P-TURBO-5: Invalidate BitLinear packed caches after optimizer step
     P-TURBO-6: INT8 ternary forward path (VNNI/AMX dispatch)
 
-    - Add profile_bottleneck() for quick diagnosis.
-    - Better bf16 autocast handling: skip autocast if CPU has no AMX/AVX512-BF16.
+    v3 changes (after quantization.py STE migration):
+    - torch.compile RE-ENABLED: _RoundTernarySTE replaced with detach() trick
+      in quantization.py → zero graph breaks, Inductor can fuse quantize+linear.
+    - Compile uses fullgraph=False as safety net for any remaining breaks
+      in non-BitLinear modules (evolution engine, loop controller, etc.)
+    - grad_accum_steps fix from v2 preserved.
 """
 
 import math
 import os
-import sys
 import warnings
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from typing import Optional, Dict, Any, Tuple
 from contextlib import nullcontext
 
 # ───────────────────────────────────────────────────────────
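Note: the "detach() trick" referenced above replaces a custom torch.autograd.Function with plain tensor arithmetic. A minimal sketch of the idea (assumed shape only; the actual implementation lives in chimera/quantization.py and is not part of this diff):

    def ste_ternary(w: torch.Tensor) -> torch.Tensor:
        # Forward: hard ternary quantization to {-1, 0, +1}
        w_q = torch.round(torch.clamp(w, -1.0, 1.0))
        # Backward: identity gradient (straight-through); the correction
        # term is detached, so autograd sees only `w`
        return w + (w_q - w).detach()

Because this is ordinary traceable arithmetic, Dynamo can compile it without the per-call graph breaks that the old _RoundTernarySTE autograd.Function caused.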
@@ -38,7 +35,6 @@ def detect_cpu_info() -> Dict[str, Any]:
     """Detect CPU capabilities for optimal configuration."""
     info = {}
 
-    # Physical cores (not hyperthreads)
     try:
         physical = len(os.sched_getaffinity(0))
         import multiprocessing
@@ -50,7 +46,6 @@ def detect_cpu_info() -> Dict[str, Any]:
         info["logical_cores"] = multiprocessing.cpu_count()
         info["physical_cores"] = info["logical_cores"] // 2
 
-    # CPU capability
    try:
         info["capability"] = torch.backends.cpu.get_cpu_capability()
     except Exception:
@@ -62,7 +57,6 @@ def detect_cpu_info() -> Dict[str, Any]:
     info["has_avx512_bf16"] = "avx512_bf16" in cap or info["has_amx"]
     info["has_vnni"] = info["has_avx512"]
 
-    # IPEX available?
     try:
         import intel_extension_for_pytorch
         info["ipex_available"] = True
@@ -70,7 +64,6 @@ def detect_cpu_info() -> Dict[str, Any]:
     except ImportError:
         info["ipex_available"] = False
 
-    # tcmalloc loaded?
     info["tcmalloc"] = "tcmalloc" in os.environ.get("LD_PRELOAD", "")
 
     return info
@@ -79,14 +72,9 @@ def detect_cpu_info() -> Dict[str, Any]:
 def configure_threading(cpu_info: Dict[str, Any], reserve_for_io: int = 1):
     """Set optimal threading for CPU training."""
     n_compute = max(1, cpu_info["physical_cores"] - reserve_for_io)
-
-    # Only set num_threads — interop threads can only be set once before
-    # any tensor ops, and train_hyper.py already sets them at import time.
     torch.set_num_threads(n_compute)
-
     os.environ["OMP_NUM_THREADS"] = str(n_compute)
     os.environ["MKL_NUM_THREADS"] = str(n_compute)
-
     return n_compute
 
 
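Note: a possible call sequence for the two helpers above (a sketch; the field names match this file):

    cpu_info = detect_cpu_info()
    n_threads = configure_threading(cpu_info, reserve_for_io=1)
    print(f"compute threads: {n_threads}, "
          f"AMX: {cpu_info['has_amx']}, VNNI: {cpu_info['has_vnni']}")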
@@ -105,7 +93,7 @@ def create_optimizer(
     Create optimizer for STE-based ternary training (replaces MeZO).
 
     Based on BitNet b1.58 Reloaded (2407.09527):
     - lr=1e-3 for <300M params
     - weight_decay=0.05
     - AdamW with β=(0.9, 0.95)
     """
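Note: the cited hyperparameters map onto a plain AdamW construction. A minimal sketch assuming a single parameter group (the actual create_optimizer body is outside this hunk):

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=1e-3,            # the paper's setting for <300M params
        betas=(0.9, 0.95),
        weight_decay=0.05,
    )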
@@ -149,16 +137,11 @@ def create_scheduler(optimizer, max_steps: int, warmup_steps: int = 500):
 
 
 # ───────────────────────────────────────────────────────────
 # P-TURBO-5 : Invalidate BitLinear packed caches
 # ───────────────────────────────────────────────────────────
 
 def invalidate_all_caches(model: nn.Module):
-    """Call after optimizer.step() to force BitLinear re-quantization.
-
-    In training mode, BitLinear._forward_train() recomputes quantized
-    weights every call via STE, so the packed cache is not used.
-    This is still good practice for eval steps between training.
-    """
+    """Call after optimizer.step() to force BitLinear re-quantization."""
     from chimera.quantization import BitLinear
     for m in model.modules():
         if isinstance(m, BitLinear):
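Note: per the docstring, the intended call site is right after the optimizer update. A hypothetical training-loop fragment:

    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    invalidate_all_caches(model)  # packed ternary weights are stale after step()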
@@ -194,41 +177,43 @@ def try_ipex_optimize(
         print("[TURBO-4] IPEX fp32 (no bf16 hardware support detected)")
 
     model, optimizer = ipex.optimize(
-        model,
-        optimizer=optimizer,
-        dtype=dtype,
-        level="O1",
-        inplace=True,
+        model, optimizer=optimizer, dtype=dtype, level="O1", inplace=True,
     )
-
     return model, optimizer
 
 
 # ───────────────────────────────────────────────────────────
 # P-TURBO-2 : torch.compile
 # ───────────────────────────────────────────────────────────
 
 def try_compile_model(model: nn.Module, mode: str = "reduce-overhead") -> nn.Module:
     """
-    CURRENTLY DISABLED: _RoundTernarySTE (torch.autograd.Function) causes
-    84+ graph breaks across 28 layers × 3 BitLinear. This makes torch.compile
-    slower than eager mode due to recompilation overhead.
-
-        return torch.round(torch.clamp(w, -1.0, 1.0))
-
-    @torch.library.register_autograd("chimera::ste_ternary", ...)
+    Compile model with torch.compile for kernel fusion.
+
+    Now possible because STE uses the detach() trick (zero graph breaks
+    in BitLinear). Uses fullgraph=False as safety net for any remaining
+    breaks in non-BitLinear modules (evolution engine, grammar FST, etc.)
+
+    Expected speedup: 1.5-3x from fusing quantize + linear + activation.
+    First call is slow (compilation); subsequent calls are fast.
     """
+    if not hasattr(torch, "compile"):
+        warnings.warn("torch.compile not available (PyTorch < 2.0)")
+        return model
+
+    try:
+        compiled = torch.compile(
+            model,
+            backend="inductor",
+            mode=mode,
+            fullgraph=False,  # safety net for non-BitLinear graph breaks
+        )
+        print(f"[TURBO-2] torch.compile enabled (backend=inductor, mode={mode})")
+        print("    First few steps will be slow (compilation). Then 1.5-3x speedup.")
+        return compiled
+    except Exception as e:
+        warnings.warn(f"torch.compile failed: {e}. Running in eager mode.")
+        return model
 
 
 # ───────────────────────────────────────────────────────────
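Note: since the first forward pass pays the Inductor compilation cost, a warm-up call right after compiling keeps it out of timed training steps. A sketch; dummy_input is a stand-in for a real batch:

    model = try_compile_model(model, mode="reduce-overhead")
    _ = model(dummy_input)  # one-time compilation happens here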
@@ -259,7 +244,7 @@ ternary_matmul_int8(
 
 
 # ───────────────────────────────────────────────────────────
 # MAIN: apply()
 # ───────────────────────────────────────────────────────────
 
 def apply(
@@ -268,7 +253,7 @@ def apply(
     lr: float = 1e-3,
     weight_decay: float = 0.05,
     warmup_steps: int = 500,
-    use_compile: bool = False,
+    use_compile: bool = True,   # ← RE-ENABLED (STE detach trick = zero graph breaks)
     use_ipex: bool = True,
     use_lion: bool = False,
     verbose: bool = True,
@@ -282,7 +267,7 @@ def apply(
 
     if verbose:
         print("=" * 65)
-        print("CHIMERA TURBO
+        print("CHIMERA TURBO v3 — CPU Acceleration Layer")
         print("=" * 65)
         print(f"  Physical cores: {cpu_info['physical_cores']}")
         print(f"  CPU capability: {cpu_info['capability']}")
@@ -312,14 +297,13 @@
     if use_compile:
         model = try_compile_model(model)
 
-    # ──
+    # ── Warnings ──
     if verbose:
         if not cpu_info["has_avx512_bf16"]:
             print()
             print("  ⚠️  No hardware BF16 support detected (need AVX512-BF16 or AMX).")
             print("      BF16 autocast may be SLOWER than fp32 on this CPU.")
             print("      Consider --no-bf16 flag if training is slow.")
-
         if not cpu_info["tcmalloc"]:
             print()
             print("  ⚠️  tcmalloc not detected. For +10-25% speedup:")
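Note: a hypothetical end-to-end call following the module docstring's Usage line; the unpacked return values and the max_steps value are assumptions, not confirmed by this diff:

    import chimera_turbo
    model, optimizer, scheduler = chimera_turbo.apply(
        model,
        max_steps=10_000,   # assumed
        use_compile=True,   # the new default after this commit
    )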
@@ -347,11 +331,8 @@ def training_step(
     """
     Single training step with all turbo optimizations active.
 
-    Handles: autocast, gradient accumulation, clipping, cache invalidation.
-
     IMPORTANT: grad_accum_steps should be 1 if the DataLoader already provides
-    the full effective batch. Set >1 only to accumulate gradients
-    across multiple forward passes.
+    the full effective batch. Set >1 only for memory-constrained scenarios.
     """
     is_accum_step = (step + 1) % grad_accum_steps == 0
 
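Note: the grad_accum_steps contract above corresponds to the usual loss-scaling pattern. A generic sketch (the real training_step body is outside this hunk; compute_loss is a hypothetical helper and the clip value is an assumption):

    loss = compute_loss(model, batch) / grad_accum_steps  # average over micro-batches
    loss.backward()
    if is_accum_step:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        invalidate_all_caches(model)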
@@ -412,7 +393,11 @@ def profile_model(model: nn.Module, dummy_input: torch.Tensor, steps: int = 5):
 
 
 def count_compile_graph_breaks(model: nn.Module, dummy_input: torch.Tensor):
-    """Count how many graph breaks torch.compile would produce."""
+    """Count how many graph breaks torch.compile would produce.
+
+    After the STE detach() migration, BitLinear should produce ZERO breaks.
+    Remaining breaks come from non-BitLinear modules (evolution, grammar, etc.)
+    """
     try:
         import torch._dynamo as dynamo
         explanation = dynamo.explain(model)(dummy_input)
@@ -422,6 +407,8 @@ def count_compile_graph_breaks(model: nn.Module, dummy_input: torch.Tensor):
             print(f"    [{i+1}] {reason}")
         if n_breaks > 10:
             print(f"    ... and {n_breaks - 10} more")
+        if n_breaks == 0:
+            print("  ✅ Zero graph breaks → full model is compilable!")
         return n_breaks
     except Exception as e:
         print(f"[TURBO-DIAG] dynamo.explain failed: {e}")
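Note: a quick way to verify the commit's zero-breaks claim on a concrete model; dummy_input is assumed to match the model's input shape:

    n_breaks = count_compile_graph_breaks(model, dummy_input)
    if n_breaks == 0:
        model = try_compile_model(model)  # safe: the whole graph compiles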