Lgr54HFi committed on
Commit a97a233 · verified · 1 Parent(s): ec200d2

fix: lower max_grad_norm 1.0→0.5 to prevent NaN with ternary STE training

Files changed (1): chimera_turbo.py (+24 -230)
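
Context for the fix: with ternary weights trained through a straight-through estimator (STE), the backward pass treats the quantizer as identity, so saturated weights can keep accumulating gradient — one plausible mechanism for the NaN spikes the commit message cites under bf16 autocast. A minimal sketch of a clamp-aware ternary STE of the kind the v6 note mentions (illustrative only; the class name is hypothetical, and the project's real implementation lives in chimera/quantization.py, which this diff does not show):

    import torch

    class ClampAwareTernarySTE(torch.autograd.Function):
        # Forward: quantize to {-1, 0, +1} * scale (BitNet-style per-tensor scale).
        # Backward: straight-through, but zeroed where the quantizer saturated;
        # a plain STE passes gradients through unconditionally.
        @staticmethod
        def forward(ctx, w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
            ctx.save_for_backward(w, scale)
            return (w / scale).round().clamp(-1, 1) * scale

        @staticmethod
        def backward(ctx, grad_out: torch.Tensor):
            w, scale = ctx.saved_tensors
            inside = (w.abs() <= scale).to(grad_out.dtype)  # the clamp-aware gate
            return grad_out * inside, None  # no gradient w.r.t. scale

    # Example call, with a BitNet-style scale:
    # w_q = ClampAwareTernarySTE.apply(w, w.abs().mean().clamp(min=1e-8))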
chimera_turbo.py CHANGED
@@ -10,10 +10,7 @@ Paradigmes intégrés:
 P-TURBO-5: Invalidate BitLinear packed caches after optimizer step
 P-TURBO-6: INT8 ternary forward path (VNNI/AMX dispatch)
 
-v5 changes:
-- Fix IPEX version mismatch crash: IPEX for PyTorch 2.8 installed with
-  PyTorch 2.11 calls os.exit(127) which doesn't exist → AttributeError.
-  Now catches Exception (not just ImportError) on IPEX import.
+v6: lower max_grad_norm 1.0→0.5, clamp-aware STE in quantization.py
 """
 
 import math
@@ -25,14 +22,9 @@ import torch.nn.functional as F
 from typing import Optional, Dict, Any, Tuple
 from contextlib import nullcontext
 
-# ═══════════════════════════════════════════════════════════
-# P-TURBO-3 : Threading + Environment
-# ═══════════════════════════════════════════════════════════
 
 def detect_cpu_info() -> Dict[str, Any]:
-    """Detect CPU capabilities for optimal configuration."""
     info = {}
-
     try:
         physical = len(os.sched_getaffinity(0))
         import multiprocessing
@@ -43,35 +35,26 @@ def detect_cpu_info() -> Dict[str, Any]:
         import multiprocessing
         info["logical_cores"] = multiprocessing.cpu_count()
         info["physical_cores"] = info["logical_cores"] // 2
-
     try:
         info["capability"] = torch.backends.cpu.get_cpu_capability()
     except Exception:
         info["capability"] = "unknown"
-
     cap = (info["capability"] or "").lower()
     info["has_amx"] = "amx" in cap
     info["has_avx512"] = "avx512" in cap or "avx512_vnni" in cap
     info["has_avx512_bf16"] = "avx512_bf16" in cap or info["has_amx"]
     info["has_vnni"] = info["has_avx512"]
-
-    # IPEX import can crash in many ways: ImportError (not installed),
-    # SystemExit (version mismatch), AttributeError (buggy os.exit in IPEX),
-    # RuntimeError, etc. Catch broadly.
     try:
         import intel_extension_for_pytorch
         info["ipex_available"] = True
         info["ipex_version"] = intel_extension_for_pytorch.__version__
     except Exception:
         info["ipex_available"] = False
-
     info["tcmalloc"] = "tcmalloc" in os.environ.get("LD_PRELOAD", "")
-
     return info
 
 
 def configure_threading(cpu_info: Dict[str, Any], reserve_for_io: int = 1):
-    """Set optimal threading for CPU training."""
     n_compute = max(1, cpu_info["physical_cores"] - reserve_for_io)
     torch.set_num_threads(n_compute)
     os.environ["OMP_NUM_THREADS"] = str(n_compute)
@@ -79,28 +62,11 @@ def configure_threading(cpu_info: Dict[str, Any], reserve_for_io: int = 1):
     return n_compute
 
 
-# ═══════════════════════════════════════════════════════════
-# P-TURBO-1 : STE + AdamW (remplace MeZO)
-# ═══════════════════════════════════════════════════════════
-
 def create_optimizer(
-    model: nn.Module,
-    lr: float = 1e-3,
-    weight_decay: float = 0.05,
-    use_lion: bool = False,
-    betas: Tuple[float, float] = (0.9, 0.95),
+    model: nn.Module, lr: float = 1e-3, weight_decay: float = 0.05,
+    use_lion: bool = False, betas: Tuple[float, float] = (0.9, 0.95),
 ) -> torch.optim.Optimizer:
-    """
-    Create optimizer for STE-based ternary training (replaces MeZO).
-
-    Based on BitNet b1.58 Reloaded (2407.09527):
-    - lr=1e-3 for <300M params
-    - weight_decay=0.05
-    - AdamW with β=(0.9, 0.95)
-    """
-    decay_params = []
-    no_decay_params = []
-
+    decay_params, no_decay_params = [], []
     for name, param in model.named_parameters():
         if not param.requires_grad:
             continue
@@ -108,237 +74,118 @@ def create_optimizer(
             no_decay_params.append(param)
         else:
             decay_params.append(param)
-
     param_groups = [
         {"params": decay_params, "weight_decay": weight_decay},
         {"params": no_decay_params, "weight_decay": 0.0},
     ]
-
     if use_lion:
         try:
             from lion_pytorch import Lion
             return Lion(param_groups, lr=lr * 0.3, betas=(0.95, 0.98))
         except ImportError:
             warnings.warn("lion-pytorch not installed, falling back to AdamW")
-
     return torch.optim.AdamW(param_groups, lr=lr, betas=betas, fused=False)
 
 
 def create_scheduler(optimizer, max_steps: int, warmup_steps: int = 500):
-    """Cosine schedule with linear warmup — standard BitNet recipe."""
     from torch.optim.lr_scheduler import LambdaLR
-
     def lr_lambda(step):
         if step < warmup_steps:
             return step / max(1, warmup_steps)
         progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
         return max(0.01, 0.5 * (1.0 + math.cos(math.pi * progress)))
-
     return LambdaLR(optimizer, lr_lambda)
 
 
-# ═══════════════════════════════════════════════════════════
-# P-TURBO-5 : Invalidate BitLinear packed caches
-# ═══════════════════════════════════════════════════════════
-
 def invalidate_all_caches(model: nn.Module):
-    """Call after optimizer.step() to force BitLinear re-quantization."""
     from chimera.quantization import BitLinear
     for m in model.modules():
         if isinstance(m, BitLinear):
             m.invalidate_packed()
 
 
-# ═══════════════════════════════════════════════════════════
-# P-TURBO-4 : IPEX Integration
-# ═══════════════════════════════════════════════════════════
-
-def try_ipex_optimize(
-    model: nn.Module,
-    optimizer: torch.optim.Optimizer,
-    cpu_info: Dict[str, Any],
-    dtype: Optional[torch.dtype] = None,
-) -> Tuple[nn.Module, torch.optim.Optimizer]:
-    """Apply IPEX optimization if available and beneficial."""
+def try_ipex_optimize(model, optimizer, cpu_info, dtype=None):
     if not cpu_info.get("ipex_available"):
         print("[TURBO-4] IPEX not available — skipping")
         return model, optimizer
-
     try:
         import intel_extension_for_pytorch as ipex
     except Exception:
         print("[TURBO-4] IPEX import failed — skipping")
         return model, optimizer
-
     if dtype is None:
         if cpu_info["has_amx"]:
             dtype = torch.bfloat16
-            print("[TURBO-4] IPEX + AMX bf16 enabled (Sapphire Rapids+)")
+            print("[TURBO-4] IPEX + AMX bf16 enabled")
         elif cpu_info["has_avx512"]:
             dtype = torch.bfloat16
             print("[TURBO-4] IPEX + AVX-512 bf16 enabled")
         else:
             dtype = torch.float32
-            print("[TURBO-4] IPEX fp32 (no bf16 hardware support detected)")
-
-    model, optimizer = ipex.optimize(
-        model, optimizer=optimizer, dtype=dtype, level="O1", inplace=True,
-    )
+            print("[TURBO-4] IPEX fp32")
+    model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=dtype, level="O1", inplace=True)
     return model, optimizer
 
 
-# ═══════════════════════════════════════════════════════════
-# P-TURBO-2 : torch.compile
-# ═══════════════════════════════════════════════════════════
-
 def try_compile_model(model: nn.Module, mode: str = "default") -> nn.Module:
-    """
-    Compile model with torch.compile for kernel fusion.
-
-    Uses mode='default' for CPU stability. Do NOT use 'reduce-overhead'
-    on CPU — it corrupts the glibc heap allocator.
-
-    Expected: first ~10 steps slow (compilation), then ~1.5-2x speedup.
-    """
     if not hasattr(torch, "compile"):
-        warnings.warn("torch.compile not available (PyTorch < 2.0)")
         return model
-
     try:
-        compiled = torch.compile(
-            model,
-            backend="inductor",
-            mode=mode,
-            fullgraph=False,
-        )
-        print(f"[TURBO-2] torch.compile enabled (backend=inductor, mode={mode})")
-        print(f"   First few steps will be slow (compilation). Then ~1.5-2x speedup.")
+        compiled = torch.compile(model, backend="inductor", mode=mode, fullgraph=False)
+        print(f"[TURBO-2] torch.compile enabled (mode={mode})")
        return compiled
     except Exception as e:
-        warnings.warn(f"torch.compile failed: {e}. Running in eager mode.")
+        warnings.warn(f"torch.compile failed: {e}. Eager mode.")
         return model
 
 
-# ═══════════════════════════════════════════════════════════
-# P-TURBO-6 : INT8 Ternary Forward Path
-# ═══════════════════════════════════════════════════════════
-
-def ternary_matmul_int8(
-    x: torch.Tensor,
-    w_ternary: torch.Tensor,
-    w_scale: torch.Tensor,
-) -> torch.Tensor:
-    """INT8 ternary matmul using torch._int_mm (dispatches to VNNI/AMX)."""
-    B, S, K = x.shape
-    x_flat = x.reshape(-1, K)
-
-    x_abs_max = x_flat.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
-    x_scale = x_abs_max / 127.0
-    x_int8 = (x_flat / x_scale).round().clamp(-128, 127).to(torch.int8)
-    w_int8 = w_ternary.to(torch.int8)
-
-    try:
-        out_int32 = torch._int_mm(x_int8, w_int8.t())
-        out = out_int32.float() * x_scale * w_scale
-    except RuntimeError:
-        out = F.linear(x_flat.float(), w_ternary.float()) * w_scale
-
-    return out.reshape(B, S, -1)
-
-
-# ═══════════════════════════════════════════════════════════
-# MAIN: apply()
-# ═══════════════════════════════════════════════════════════
-
 def apply(
-    model: nn.Module,
-    max_steps: int = 10000,
-    lr: float = 1e-3,
-    weight_decay: float = 0.05,
-    warmup_steps: int = 500,
-    use_compile: bool = True,
-    use_ipex: bool = True,
-    use_lion: bool = False,
-    verbose: bool = True,
+    model: nn.Module, max_steps: int = 10000, lr: float = 1e-3,
+    weight_decay: float = 0.05, warmup_steps: int = 500,
+    use_compile: bool = True, use_ipex: bool = True,
+    use_lion: bool = False, verbose: bool = True,
 ) -> Tuple[nn.Module, torch.optim.Optimizer, Any]:
-    """
-    Apply all turbo optimizations to ch1mera model.
-
-    Returns: (model, optimizer, scheduler)
-    """
     cpu_info = detect_cpu_info()
-
     if verbose:
         print("=" * 65)
-        print("CHIMERA TURBO v5 — CPU Acceleration Layer")
+        print("CHIMERA TURBO v6 — CPU Acceleration Layer")
         print("=" * 65)
-        print(f"   Physical cores: {cpu_info['physical_cores']}")
-        print(f"   CPU capability: {cpu_info['capability']}")
+        print(f"   Cores: {cpu_info['physical_cores']}   CPU: {cpu_info['capability']}")
         print(f"   AMX: {cpu_info['has_amx']}   AVX-512: {cpu_info['has_avx512']}   BF16 hw: {cpu_info['has_avx512_bf16']}")
-        print(f"   IPEX: {cpu_info['ipex_available']}")
-        print(f"   tcmalloc: {cpu_info['tcmalloc']}")
+        print(f"   IPEX: {cpu_info['ipex_available']}   tcmalloc: {cpu_info['tcmalloc']}")
 
-    # ── Threading ──
     n_threads = configure_threading(cpu_info)
     if verbose:
         print(f"[TURBO-3] Compute threads: {n_threads}")
 
-    # ── Optimizer (replaces MeZO) ──
     optimizer = create_optimizer(model, lr=lr, weight_decay=weight_decay, use_lion=use_lion)
     scheduler = create_scheduler(optimizer, max_steps=max_steps, warmup_steps=warmup_steps)
     if verbose:
-        opt_name = type(optimizer).__name__
         n_params = sum(p.numel() for g in optimizer.param_groups for p in g["params"])
-        print(f"[TURBO-1] {opt_name} (lr={lr}, wd={weight_decay}) — {n_params:,} params")
-        print(f"   STE backprop: 1 forward + 1 backward per step")
+        print(f"[TURBO-1] AdamW (lr={lr}, wd={weight_decay}) — {n_params:,} params")
 
-    # ── IPEX ──
     if use_ipex:
         model, optimizer = try_ipex_optimize(model, optimizer, cpu_info)
-
-    # ── torch.compile ──
     if use_compile:
         model = try_compile_model(model, mode="default")
 
-    # ── Warnings ──
     if verbose:
         if not cpu_info["has_avx512_bf16"]:
-            print()
-            print("   ⚠️  No hardware BF16 support detected (need AVX512-BF16 or AMX).")
-            print("   BF16 autocast may be SLOWER than fp32 on this CPU.")
-            print("   Consider --no-bf16 flag if training is slow.")
+            print("   ⚠️  No BF16 hw — use --no-bf16")
         if not cpu_info["tcmalloc"]:
-            print()
-            print("   ⚠️  tcmalloc not detected. For +10-25% speedup:")
-            print("      sudo apt install google-perftools")
-            print("      LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 python train_hyper.py ...")
+            print("   ⚠️  No tcmalloc — LD_PRELOAD=...libtcmalloc.so.4 for +15%")
         print("=" * 65)
 
     return model, optimizer, scheduler
 
 
-# ═══════════════════════════════════════════════════════════
-# Training step helper
-# ═══════════════════════════════════════════════════════════
-
 def training_step(
-    model: nn.Module,
-    batch,
-    optimizer: torch.optim.Optimizer,
-    scheduler,
-    grad_accum_steps: int = 1,
-    step: int = 0,
-    max_grad_norm: float = 1.0,
+    model: nn.Module, batch, optimizer: torch.optim.Optimizer, scheduler,
+    grad_accum_steps: int = 1, step: int = 0,
+    max_grad_norm: float = 0.5,  # ← lowered from 1.0 to prevent NaN
     autocast_dtype: Optional[torch.dtype] = torch.bfloat16,
 ) -> float:
-    """
-    Single training step with all turbo optimizations active.
-
-    IMPORTANT: grad_accum_steps should be 1 if the DataLoader already provides
-    the full effective batch. Set >1 only for memory-constrained scenarios.
-    """
     is_accum_step = (step + 1) % grad_accum_steps == 0
-
     ctx = torch.autocast(device_type="cpu", dtype=autocast_dtype) if autocast_dtype else nullcontext()
     with ctx:
         if isinstance(batch, dict):
@@ -351,64 +198,11 @@ def training_step(
         loss_val = loss.item()
         if grad_accum_steps > 1:
             loss = loss / grad_accum_steps
-
     loss.backward()
-
     if is_accum_step:
         torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
         optimizer.step()
         scheduler.step()
         optimizer.zero_grad(set_to_none=True)
         invalidate_all_caches(model)
-
     return loss_val
-
-
-# ═══════════════════════════════════════════════════════════
-# Diagnostic tools
-# ═══════════════════════════════════════════════════════════
-
-def profile_model(model: nn.Module, dummy_input: torch.Tensor, steps: int = 5):
-    """Profile forward+backward to find bottlenecks."""
-    print("\n[TURBO-DIAG] Profiling...")
-
-    for _ in range(2):
-        out = model(dummy_input)
-        if hasattr(out, "loss") and out.loss is not None:
-            out.loss.backward()
-        elif isinstance(out, torch.Tensor):
-            out.sum().backward()
-        model.zero_grad(set_to_none=True)
-
-    with torch.profiler.profile(
-        activities=[torch.profiler.ProfilerActivity.CPU],
-        record_shapes=True,
-        with_stack=True,
-    ) as prof:
-        for _ in range(steps):
-            out = model(dummy_input)
-            loss = out.loss if (hasattr(out, "loss") and out.loss is not None) else out.sum()
-            loss.backward()
-            model.zero_grad(set_to_none=True)
-
-    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
-    return prof
-
-
-def count_compile_graph_breaks(model: nn.Module, dummy_input: torch.Tensor):
-    """Count how many graph breaks torch.compile would produce."""
-    try:
-        import torch._dynamo as dynamo
-        explanation = dynamo.explain(model)(dummy_input)
-        n_breaks = len(explanation.break_reasons)
-        print(f"\n[TURBO-DIAG] Graph breaks: {n_breaks}")
-        for i, reason in enumerate(explanation.break_reasons[:10]):
-            print(f"   [{i+1}] {reason}")
-        if n_breaks > 10:
-            print(f"   ... and {n_breaks - 10} more")
-        if n_breaks == 0:
-            print("   ✅ Zero graph breaks — full model is compilable!")
-        return n_breaks
-    except Exception as e:
-        print(f"[TURBO-DIAG] dynamo.explain failed: {e}")
-        return -1
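
For orientation, a minimal usage sketch of the module as it stands after this commit — `model` and `dataloader` are stand-ins, and the import assumes chimera_turbo.py is importable as a top-level module:

    import chimera_turbo as turbo

    # apply() wires up threading, AdamW + cosine schedule, IPEX and torch.compile.
    model, optimizer, scheduler = turbo.apply(model, max_steps=10000, lr=1e-3)

    for step, batch in enumerate(dataloader):
        # training_step() now clips at max_grad_norm=0.5 by default and
        # invalidates BitLinear packed caches after each optimizer step.
        loss = turbo.training_step(model, batch, optimizer, scheduler, step=step)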
 
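The cosine schedule in create_scheduler is untouched by this commit; for reference, the multiplier lr_lambda produces at a few checkpoints (computed from the code above, with warmup_steps=500 and max_steps=10000):

    step 0      -> 0.00   (linear warmup from zero)
    step 250    -> 0.50
    step 500    -> 1.00   (warmup done; cosine decay begins)
    step 5250   -> 0.50   (cosine midpoint)
    step 10000  -> 0.01   (clamped by the max(0.01, ...) floor)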
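
If NaNs persist even at max_grad_norm=0.5, a diagnostic sketch (not part of this repo): clip_grad_norm_ returns the total norm computed *before* clipping, so it can double as a spike probe and a step-skip guard:

    import torch
    from torch import nn

    def clip_and_guard(model: nn.Module, max_norm: float = 0.5) -> float:
        # Pre-clip total gradient norm; log it to see how often clipping fires.
        total = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        if not torch.isfinite(total):
            # Backward already produced inf/NaN: drop these grads and skip the step.
            for p in model.parameters():
                p.grad = None
            return float("nan")
        return float(total)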