Lgr54HFi commited on
Commit
974e9c4
·
verified ·
1 Parent(s): 9897d01

feat: v10 — P15 Selective Token Triage, P16 Plateau Breaker, P17 Batch Metabolism\n\nThree new paradigms fusionné dans le concept 'Adaptive Token Metabolism':\n\nP15 Token Triage (inspiré Rho-1, arxiv 2404.07965):\nCompute per-token excess loss vs EMA baseline. Top 60% tokens get\nfull gradient, bottom 40% get 0.1× gradient. No reference model needed —\nuses running EMA of per-position loss as baseline. This focuses\n~90% of gradient energy on the actually-learnable tokens.\n\nP16 Plateau Breaker:\nTrack loss EMA variance. When loss stagnates (variance < threshold\nfor 100 steps), trigger a 'warm restart': boost LR by 3× for 50 steps\nthen decay back. Inspired by SGDR (arxiv 1608.03983) but adaptive.\n\nP17 Batch Metabolism (Online Hard Example Mining for LLM):\nWithin each batch, weight sequences by their loss relative to\nbatch mean. High-loss sequences get weight up to 2×, easy ones\nget 0.5×. The model 'digests' harder examples more thoroughly."

Browse files
Files changed (1) hide show
  1. chimera_turbo.py +250 -164
chimera_turbo.py CHANGED
@@ -1,12 +1,9 @@
1
  """
2
  chimera_turbo.py — Drop-in CPU acceleration for ch1mera 5.3
3
 
4
- v9: Muon optimizer + Multi-Token Prediction + EMA Self-Distillation
5
 
6
- New paradigms:
7
- P12 Muon optimizer — 2× token efficiency via NS-orthogonalized momentum
8
- P13 Multi-Token Predict — 3× gradient signal per forward pass
9
- P14 EMA Self-Distill — dense soft targets from EMA teacher copy
10
  """
11
 
12
  import copy
@@ -16,34 +13,29 @@ import warnings
16
  import torch
17
  import torch.nn as nn
18
  import torch.nn.functional as F
19
- from typing import Optional, Dict, Any, Tuple
20
  from contextlib import nullcontext
 
21
 
22
 
23
  # ═══════════════════════════════════════════════════════════
24
- # P-TURBO-3 : CPU Detection + Threading
25
  # ═══════════════════════════════════════════════════════════
26
 
27
- def detect_cpu_info() -> Dict[str, Any]:
28
  info = {}
29
  try:
30
  import multiprocessing
31
  logical = multiprocessing.cpu_count()
32
  physical = len(os.sched_getaffinity(0))
33
  info["physical_cores"] = logical // 2 if logical == physical else physical
34
- info["logical_cores"] = logical
35
  except Exception:
36
  import multiprocessing
37
- info["logical_cores"] = multiprocessing.cpu_count()
38
- info["physical_cores"] = info["logical_cores"] // 2
39
  try:
40
  info["capability"] = torch.backends.cpu.get_cpu_capability()
41
  except Exception:
42
  info["capability"] = "unknown"
43
- cap = (info["capability"] or "").lower()
44
- info["has_amx"] = "amx" in cap
45
- info["has_avx512"] = "avx512" in cap
46
- info["has_avx512_bf16"] = "avx512_bf16" in cap or info["has_amx"]
47
  try:
48
  import intel_extension_for_pytorch
49
  info["ipex_available"] = True
@@ -66,7 +58,6 @@ def configure_threading(cpu_info, reserve=1):
66
  # ═══════════════════════════════════════════════════════════
67
 
68
  def _zeropower_via_newtonschulz5(G, steps=5):
69
- """Newton-Schulz iteration for polar factor. Pure PyTorch, CPU-safe."""
70
  assert G.ndim == 2
71
  a, b, c = 3.4445, -4.7750, 2.0315
72
  X = G.T if G.size(0) > G.size(1) else G.clone()
@@ -78,13 +69,6 @@ def _zeropower_via_newtonschulz5(G, steps=5):
78
 
79
 
80
  class Muon(torch.optim.Optimizer):
81
- """Muon: MomentUm Orthogonalized by Newton-schulz.
82
-
83
- 2D weight matrices: SGD momentum → NS orthogonalize → scaled update.
84
- Everything else (bias, norm, embed): standard AdamW.
85
-
86
- ~2× token efficiency vs AdamW (arxiv 2502.16982, Table 3).
87
- """
88
  def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True,
89
  ns_steps=5, weight_decay=0.0,
90
  adamw_betas=(0.9, 0.98), adamw_eps=1e-8):
@@ -96,18 +80,12 @@ class Muon(torch.optim.Optimizer):
96
  @torch.no_grad()
97
  def step(self):
98
  for group in self.param_groups:
99
- lr = group["lr"]
100
- wd = group["weight_decay"]
101
- mu = group["momentum"]
102
  b1, b2 = group["adamw_betas"]
103
-
104
  for p in group["params"]:
105
  if p.grad is None:
106
  continue
107
- g = p.grad
108
- s = self.state[p]
109
-
110
- # ── Muon path: 2D matrices (not embeddings) ──
111
  if p.ndim == 2 and not getattr(p, "_is_embed", False):
112
  if "buf" not in s:
113
  s["buf"] = torch.zeros_like(g)
@@ -118,8 +96,6 @@ class Muon(torch.optim.Optimizer):
118
  if wd > 0:
119
  p.mul_(1 - lr * wd)
120
  p.add_(O, alpha=-lr * scale)
121
-
122
- # ── AdamW path: 1D params, embeddings ──
123
  else:
124
  if "m" not in s:
125
  s["m"] = torch.zeros_like(g)
@@ -128,16 +104,13 @@ class Muon(torch.optim.Optimizer):
128
  s["t"] += 1
129
  s["m"].mul_(b1).add_(g, alpha=1 - b1)
130
  s["v"].mul_(b2).addcmul_(g, g, value=1 - b2)
131
- bc1 = 1 - b1 ** s["t"]
132
- bc2 = 1 - b2 ** s["t"]
133
- alr = lr * math.sqrt(bc2) / bc1
134
  if wd > 0:
135
  p.mul_(1 - lr * wd)
136
  p.addcdiv_(s["m"], s["v"].sqrt().add_(group["adamw_eps"]), value=-alr)
137
 
138
 
139
  def create_muon_optimizer(model, lr=0.02, momentum=0.95, weight_decay=0.01):
140
- """Create Muon optimizer with proper param group splitting."""
141
  params = []
142
  for name, p in model.named_parameters():
143
  if not p.requires_grad:
@@ -145,11 +118,8 @@ def create_muon_optimizer(model, lr=0.02, momentum=0.95, weight_decay=0.01):
145
  if any(k in name for k in ["embed", "lm_head", "wte", "wpe"]):
146
  p._is_embed = True
147
  params.append(p)
148
- return Muon(
149
- [{"params": params}],
150
- lr=lr, momentum=momentum, weight_decay=weight_decay,
151
- adamw_betas=(0.9, 0.98), adamw_eps=1e-8,
152
- )
153
 
154
 
155
  # ═══════════════════════════════════════════════════════════
@@ -157,52 +127,31 @@ def create_muon_optimizer(model, lr=0.02, momentum=0.95, weight_decay=0.01):
157
  # ═══════════════════════════════════════════════════════════
158
 
159
  class MultiTokenPredictionLoss(nn.Module):
160
- """Auxiliary loss: predict next N tokens instead of just 1.
161
-
162
- Each forward pass yields N× gradient signal from the same hidden states.
163
- Heads are lightweight linear projections sharing the trunk.
164
- """
165
- def __init__(self, hidden_size: int, vocab_size: int, n_future: int = 3):
166
  super().__init__()
167
- self.n_future = n_future
168
- # Extra heads for tokens +2, +3, ... (head for +1 is the main lm_head)
169
  self.extra_heads = nn.ModuleList([
170
  nn.Linear(hidden_size, vocab_size, bias=False)
171
  for _ in range(n_future - 1)
172
  ])
173
- # Init small to not destabilize early training
174
- for head in self.extra_heads:
175
- nn.init.normal_(head.weight, std=0.006)
176
-
177
- def forward(self, hidden_states: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
178
- """Compute auxiliary MTP loss.
179
 
180
- Args:
181
- hidden_states: [B, T, H] from trunk (before lm_head)
182
- labels: [B, T] target token ids
183
-
184
- Returns:
185
- Scalar auxiliary loss (mean over all future positions and heads)
186
- """
187
- total_loss = torch.tensor(0.0, device=hidden_states.device)
188
- count = 0
189
  for k, head in enumerate(self.extra_heads):
190
- shift = k + 2 # head 0 predicts +2, head 1 predicts +3, etc.
191
  if shift >= labels.size(1):
192
  continue
193
- # Hidden states predict token at position +shift
194
- logits = head(hidden_states[:, :-shift]) # [B, T-shift, V]
195
- targets = labels[:, shift:] # [B, T-shift]
196
- seq_len = min(logits.size(1), targets.size(1))
197
  loss = F.cross_entropy(
198
- logits[:, :seq_len].reshape(-1, logits.size(-1)),
199
- targets[:, :seq_len].reshape(-1),
200
- ignore_index=-100,
201
- )
202
  if torch.isfinite(loss):
203
- total_loss = total_loss + loss
204
  count += 1
205
- return total_loss / max(count, 1)
206
 
207
 
208
  # ═══════════════════════════════════════════════════════════
@@ -210,63 +159,189 @@ class MultiTokenPredictionLoss(nn.Module):
210
  # ═══════════════════════════════════════════════════════════
211
 
212
  class EMASelfDistiller:
213
- """Maintain EMA copy of model as teacher for self-distillation.
214
-
215
- The EMA model's soft targets provide dense gradient signal across
216
- the full vocabulary, vs sparse one-hot labels from hard targets.
217
-
218
- α=0.5 blends hard CE and soft KL. T=2.0 temperature.
219
- Recipe from Baby Llama (arxiv 2308.02019).
220
- """
221
- def __init__(self, model: nn.Module, decay: float = 0.999, alpha: float = 0.5,
222
- temperature: float = 2.0):
223
- self.decay = decay
224
- self.alpha = alpha
225
- self.temperature = temperature
226
- # Deep copy for EMA — no gradients needed
227
  self.ema_model = copy.deepcopy(model)
228
  for p in self.ema_model.parameters():
229
  p.requires_grad_(False)
230
  self.ema_model.eval()
231
 
232
  @torch.no_grad()
233
- def update(self, model: nn.Module):
234
- """Update EMA weights. Call after optimizer.step()."""
235
  for p_ema, p in zip(self.ema_model.parameters(), model.parameters()):
236
  p_ema.data.mul_(self.decay).add_(p.data, alpha=1 - self.decay)
237
 
238
- def distillation_loss(self, student_logits: torch.Tensor,
239
- hard_targets: torch.Tensor,
240
- input_ids: torch.Tensor) -> torch.Tensor:
241
- """Compute blended hard + soft distillation loss."""
242
  T = self.temperature
243
-
244
- # Hard loss (standard CE)
245
- seq_len = min(student_logits.size(1), hard_targets.size(1))
246
  hard_loss = F.cross_entropy(
247
- student_logits[:, :seq_len].reshape(-1, student_logits.size(-1)),
248
- hard_targets[:, :seq_len].reshape(-1),
249
- ignore_index=-100,
250
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
- # Soft loss (KL from EMA teacher)
253
  with torch.no_grad():
254
- teacher_out = self.ema_model(input_ids)
255
- teacher_logits = teacher_out.logits if hasattr(teacher_out, "logits") else teacher_out[1]
 
 
 
256
 
257
- t_seq = min(student_logits.size(1), teacher_logits.size(1))
258
- soft_student = F.log_softmax(student_logits[:, :t_seq] / T, dim=-1)
259
- soft_teacher = F.softmax(teacher_logits[:, :t_seq] / T, dim=-1)
260
- kl_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * (T * T)
261
 
262
- if not torch.isfinite(kl_loss):
263
- return hard_loss
 
264
 
265
- return self.alpha * hard_loss + (1 - self.alpha) * kl_loss
 
266
 
267
 
268
  # ═══════════════════════════════════════════════════════════
269
- # Cache invalidation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  # ═══════════════════════════════════════════════════════════
271
 
272
  def invalidate_all_caches(model):
@@ -277,12 +352,7 @@ def invalidate_all_caches(model):
277
  m.invalidate_packed()
278
 
279
 
280
- # ═══════════════════════════════════════════════════════════
281
- # Scheduler
282
- # ═══════════════════════════════════════════════════════════
283
-
284
  def create_scheduler(optimizer, max_steps, warmup_steps=200):
285
- """Short warmup (200 steps) then cosine decay. Warmup=750 was too long."""
286
  from torch.optim.lr_scheduler import LambdaLR
287
  def lr_lambda(step):
288
  if step < warmup_steps:
@@ -300,69 +370,76 @@ def apply(
300
  model, max_steps=10000, lr=0.02, weight_decay=0.01,
301
  warmup_steps=200, use_compile=False, use_ipex=True,
302
  use_muon=True, use_mtp=True, use_distill=True,
 
303
  mtp_heads=3, verbose=True,
304
  ):
305
- """Apply all turbo + revolutionary paradigms.
306
-
307
- Returns: (model, optimizer, scheduler, extras)
308
- where extras = dict with 'mtp_loss_fn', 'distiller', etc.
309
- """
310
  cpu_info = detect_cpu_info()
311
  if verbose:
312
  print("=" * 65)
313
- print("CHIMERA TURBO v9Revolutionary Training Paradigms")
314
  print("=" * 65)
315
  print(f" Cores: {cpu_info['physical_cores']} CPU: {cpu_info['capability']}")
316
 
317
  n_threads = configure_threading(cpu_info)
318
  if verbose:
319
- print(f"[TURBO-3] Threads: {n_threads}")
320
 
321
- # ── P12: Muon optimizer ──
322
  if use_muon:
323
  optimizer = create_muon_optimizer(model, lr=lr, weight_decay=weight_decay)
324
  if verbose:
325
- n_2d = sum(p.numel() for p in model.parameters()
326
- if p.requires_grad and p.ndim == 2
327
- and not getattr(p, "_is_embed", False))
328
- n_1d = sum(p.numel() for p in model.parameters()
329
- if p.requires_grad and (p.ndim < 2 or getattr(p, "_is_embed", False)))
330
- print(f"[P12] Muon optimizer (lr={lr}, NS-5 orthogonalization)")
331
- print(f" Muon: {n_2d:,} params | AdamW fallback: {n_1d:,} params")
332
  else:
333
- from chimera_turbo_legacy import create_optimizer
334
- optimizer = create_optimizer(model, lr=lr, weight_decay=weight_decay)
335
 
336
  scheduler = create_scheduler(optimizer, max_steps, warmup_steps)
337
 
338
- # ── P13: Multi-Token Prediction ──
339
  extras = {}
 
 
 
340
  if use_mtp:
341
- raw = getattr(model, "_orig_mod", model)
342
- h = raw.config["hidden_size"]
343
- v = raw.config["vocab_size"]
344
- mtp = MultiTokenPredictionLoss(h, v, n_future=mtp_heads)
345
- extras["mtp"] = mtp
346
  if verbose:
347
- print(f"[P13] Multi-Token Prediction ({mtp_heads} heads → {mtp_heads}× gradient signal)")
348
 
349
- # ── P14: EMA Self-Distillation ──
350
  if use_distill:
351
- distiller = EMASelfDistiller(model, decay=0.999, alpha=0.5, temperature=2.0)
352
- extras["distiller"] = distiller
 
 
 
 
 
 
 
 
 
 
 
 
353
  if verbose:
354
- print(f"[P14] EMA Self-Distillation (α=0.5, T=2.0, decay=0.999)")
 
 
 
 
 
 
355
 
356
  if verbose:
357
- if not cpu_info.get("tcmalloc"):
358
- print(" ⚠️ No tcmalloc — LD_PRELOAD=...libtcmalloc.so.4 for +15%")
359
  print("=" * 65)
360
 
361
  return model, optimizer, scheduler, extras
362
 
363
 
364
  # ═══════════════════════════════════════════════════════════
365
- # Training step with all paradigms
366
  # ═══════════════════════════════════════════════════════════
367
 
368
  _nan_count = 0
@@ -371,12 +448,8 @@ def training_step(
371
  model, batch, optimizer, scheduler,
372
  extras=None, grad_accum_steps=1, step=0,
373
  max_grad_norm=1.0, autocast_dtype=None,
374
- mtp_weight=0.3, distill_weight=0.5,
375
  ) -> float:
376
- """Training step with Muon + MTP + EMA distillation.
377
-
378
- Loss = distill_loss (blended hard+soft) + mtp_weight * mtp_aux_loss
379
- """
380
  global _nan_count
381
  extras = extras or {}
382
  is_accum_step = (step + 1) % grad_accum_steps == 0
@@ -389,18 +462,28 @@ def training_step(
389
  outputs = model(input_ids, labels=labels)
390
  else:
391
  outputs = model(batch)
392
- input_ids = batch
393
- labels = batch
394
 
395
- # ── Base loss ──
 
 
 
 
396
  distiller = extras.get("distiller")
397
- if distiller is not None and hasattr(outputs, "logits"):
398
- # P14: distillation loss replaces raw CE
399
- base_loss = distiller.distillation_loss(outputs.logits, labels, input_ids)
 
 
 
 
 
 
 
400
  else:
401
  base_loss = outputs.loss if hasattr(outputs, "loss") else outputs
402
 
403
- # ── P13: MTP auxiliary loss ──
404
  mtp = extras.get("mtp")
405
  if mtp is not None and hasattr(outputs, "hidden_states") and outputs.hidden_states is not None:
406
  mtp_loss = mtp(outputs.hidden_states, labels)
@@ -423,6 +506,11 @@ def training_step(
423
 
424
  _nan_count = 0
425
 
 
 
 
 
 
426
  if grad_accum_steps > 1:
427
  total_loss = total_loss / grad_accum_steps
428
  total_loss.backward()
@@ -438,8 +526,6 @@ def training_step(
438
  scheduler.step()
439
  optimizer.zero_grad(set_to_none=True)
440
  invalidate_all_caches(model)
441
-
442
- # P14: update EMA teacher
443
  if "distiller" in extras:
444
  extras["distiller"].update(model)
445
 
 
1
  """
2
  chimera_turbo.py — Drop-in CPU acceleration for ch1mera 5.3
3
 
4
+ v10: Adaptive Token Metabolism — P15 Token Triage + P16 Plateau Breaker + P17 Batch Metabolism
5
 
6
+ Stack: Muon + MTP + EMA Distill + Token Triage + Plateau Breaker + Batch Metabolism
 
 
 
7
  """
8
 
9
  import copy
 
13
  import torch
14
  import torch.nn as nn
15
  import torch.nn.functional as F
16
+ from typing import Optional, Dict, Any, Tuple, List
17
  from contextlib import nullcontext
18
+ from collections import deque
19
 
20
 
21
  # ═══════════════════════════════════════════════════════════
22
+ # CPU Detection + Threading
23
  # ═══════════════════════════════════════════════════════════
24
 
25
+ def detect_cpu_info():
26
  info = {}
27
  try:
28
  import multiprocessing
29
  logical = multiprocessing.cpu_count()
30
  physical = len(os.sched_getaffinity(0))
31
  info["physical_cores"] = logical // 2 if logical == physical else physical
 
32
  except Exception:
33
  import multiprocessing
34
+ info["physical_cores"] = multiprocessing.cpu_count() // 2
 
35
  try:
36
  info["capability"] = torch.backends.cpu.get_cpu_capability()
37
  except Exception:
38
  info["capability"] = "unknown"
 
 
 
 
39
  try:
40
  import intel_extension_for_pytorch
41
  info["ipex_available"] = True
 
58
  # ═══════════════════════════════════════════════════════════
59
 
60
  def _zeropower_via_newtonschulz5(G, steps=5):
 
61
  assert G.ndim == 2
62
  a, b, c = 3.4445, -4.7750, 2.0315
63
  X = G.T if G.size(0) > G.size(1) else G.clone()
 
69
 
70
 
71
  class Muon(torch.optim.Optimizer):
 
 
 
 
 
 
 
72
  def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True,
73
  ns_steps=5, weight_decay=0.0,
74
  adamw_betas=(0.9, 0.98), adamw_eps=1e-8):
 
80
  @torch.no_grad()
81
  def step(self):
82
  for group in self.param_groups:
83
+ lr, wd, mu = group["lr"], group["weight_decay"], group["momentum"]
 
 
84
  b1, b2 = group["adamw_betas"]
 
85
  for p in group["params"]:
86
  if p.grad is None:
87
  continue
88
+ g, s = p.grad, self.state[p]
 
 
 
89
  if p.ndim == 2 and not getattr(p, "_is_embed", False):
90
  if "buf" not in s:
91
  s["buf"] = torch.zeros_like(g)
 
96
  if wd > 0:
97
  p.mul_(1 - lr * wd)
98
  p.add_(O, alpha=-lr * scale)
 
 
99
  else:
100
  if "m" not in s:
101
  s["m"] = torch.zeros_like(g)
 
104
  s["t"] += 1
105
  s["m"].mul_(b1).add_(g, alpha=1 - b1)
106
  s["v"].mul_(b2).addcmul_(g, g, value=1 - b2)
107
+ alr = lr * math.sqrt(1 - b2 ** s["t"]) / (1 - b1 ** s["t"])
 
 
108
  if wd > 0:
109
  p.mul_(1 - lr * wd)
110
  p.addcdiv_(s["m"], s["v"].sqrt().add_(group["adamw_eps"]), value=-alr)
111
 
112
 
113
  def create_muon_optimizer(model, lr=0.02, momentum=0.95, weight_decay=0.01):
 
114
  params = []
115
  for name, p in model.named_parameters():
116
  if not p.requires_grad:
 
118
  if any(k in name for k in ["embed", "lm_head", "wte", "wpe"]):
119
  p._is_embed = True
120
  params.append(p)
121
+ return Muon([{"params": params}], lr=lr, momentum=momentum,
122
+ weight_decay=weight_decay, adamw_betas=(0.9, 0.98))
 
 
 
123
 
124
 
125
  # ═══════════════════════════════════════════════════════════
 
127
  # ═══════════════════════════════════════════════════════════
128
 
129
  class MultiTokenPredictionLoss(nn.Module):
130
+ def __init__(self, hidden_size, vocab_size, n_future=3):
 
 
 
 
 
131
  super().__init__()
 
 
132
  self.extra_heads = nn.ModuleList([
133
  nn.Linear(hidden_size, vocab_size, bias=False)
134
  for _ in range(n_future - 1)
135
  ])
136
+ for h in self.extra_heads:
137
+ nn.init.normal_(h.weight, std=0.006)
 
 
 
 
138
 
139
+ def forward(self, hidden_states, labels):
140
+ total, count = 0.0, 0
 
 
 
 
 
 
 
141
  for k, head in enumerate(self.extra_heads):
142
+ shift = k + 2
143
  if shift >= labels.size(1):
144
  continue
145
+ logits = head(hidden_states[:, :-shift])
146
+ targets = labels[:, shift:]
147
+ sl = min(logits.size(1), targets.size(1))
 
148
  loss = F.cross_entropy(
149
+ logits[:, :sl].reshape(-1, logits.size(-1)),
150
+ targets[:, :sl].reshape(-1), ignore_index=-100)
 
 
151
  if torch.isfinite(loss):
152
+ total = total + loss
153
  count += 1
154
+ return total / max(count, 1) if isinstance(total, torch.Tensor) else torch.tensor(0.0)
155
 
156
 
157
  # ═══════════════════════════════════════════════════════════
 
159
  # ═══════════════════════════════════════════════════════════
160
 
161
  class EMASelfDistiller:
162
+ def __init__(self, model, decay=0.999, alpha=0.5, temperature=2.0):
163
+ self.decay, self.alpha, self.temperature = decay, alpha, temperature
 
 
 
 
 
 
 
 
 
 
 
 
164
  self.ema_model = copy.deepcopy(model)
165
  for p in self.ema_model.parameters():
166
  p.requires_grad_(False)
167
  self.ema_model.eval()
168
 
169
  @torch.no_grad()
170
+ def update(self, model):
 
171
  for p_ema, p in zip(self.ema_model.parameters(), model.parameters()):
172
  p_ema.data.mul_(self.decay).add_(p.data, alpha=1 - self.decay)
173
 
174
+ def distillation_loss(self, student_logits, hard_targets, input_ids):
 
 
 
175
  T = self.temperature
176
+ sl = min(student_logits.size(1), hard_targets.size(1))
 
 
177
  hard_loss = F.cross_entropy(
178
+ student_logits[:, :sl].reshape(-1, student_logits.size(-1)),
179
+ hard_targets[:, :sl].reshape(-1), ignore_index=-100)
180
+ with torch.no_grad():
181
+ t_out = self.ema_model(input_ids)
182
+ t_logits = t_out.logits if hasattr(t_out, "logits") else t_out[1]
183
+ tsl = min(student_logits.size(1), t_logits.size(1))
184
+ soft_s = F.log_softmax(student_logits[:, :tsl] / T, dim=-1)
185
+ soft_t = F.softmax(t_logits[:, :tsl] / T, dim=-1)
186
+ kl = F.kl_div(soft_s, soft_t, reduction="batchmean") * T * T
187
+ if not torch.isfinite(kl):
188
+ return hard_loss
189
+ return self.alpha * hard_loss + (1 - self.alpha) * kl
190
+
191
+
192
+ # ═══════════════════════════════════════════════════════════
193
+ # P15 — Token Triage (inspiré Rho-1, arxiv 2404.07965)
194
+ # ═══════════════════════════════════════════════════════════
195
+
196
+ class TokenTriage:
197
+ """Selective token-level gradient weighting without a reference model.
198
+
199
+ Instead of a separate reference model (expensive), use a running EMA
200
+ of per-token loss as the "expected" loss baseline. Tokens with excess
201
+ loss (actual - EMA) above the 40th percentile get full gradient;
202
+ tokens below get 10% gradient. This focuses ~90% of learning on the
203
+ actually-informative tokens.
204
+
205
+ Inspired by Rho-1 (arxiv 2404.07965) but self-referential: the model
206
+ IS its own reference, via temporal smoothing.
207
+ """
208
+ def __init__(self, ema_decay=0.99, select_ratio=0.6, floor_weight=0.1):
209
+ self.ema_decay = ema_decay
210
+ self.select_ratio = select_ratio # top 60% tokens get full weight
211
+ self.floor_weight = floor_weight # bottom 40% get 10% weight
212
+ self._loss_ema = None # scalar EMA of mean token loss
213
+
214
+ def weighted_loss(self, logits, targets):
215
+ """Compute token-weighted CE loss.
216
+
217
+ Returns weighted loss where informative tokens contribute more.
218
+ """
219
+ B, T, V = logits.shape
220
+ # Per-token loss (no reduction)
221
+ per_token = F.cross_entropy(
222
+ logits.reshape(-1, V), targets.reshape(-1),
223
+ ignore_index=-100, reduction="none"
224
+ ).reshape(B, T)
225
 
 
226
  with torch.no_grad():
227
+ mean_loss = per_token.mean().item()
228
+ if self._loss_ema is None:
229
+ self._loss_ema = mean_loss
230
+ else:
231
+ self._loss_ema = self.ema_decay * self._loss_ema + (1 - self.ema_decay) * mean_loss
232
 
233
+ # Excess loss = how much harder this token is than expected
234
+ excess = per_token - self._loss_ema
 
 
235
 
236
+ # Top select_ratio% by excess loss → weight 1.0, rest → floor_weight
237
+ threshold = torch.quantile(excess.flatten(), 1.0 - self.select_ratio)
238
+ weights = torch.where(excess >= threshold, 1.0, self.floor_weight)
239
 
240
+ # Weighted mean
241
+ return (per_token * weights).sum() / weights.sum()
242
 
243
 
244
  # ═══════════════════════════════════════════════════════════
245
+ # P16 — Plateau Breaker (adaptive warm restarts)
246
+ # ═══════════════════════════════════════════════════════════
247
+
248
+ class PlateauBreaker:
249
+ """Detect loss plateaus and inject LR boosts to escape.
250
+
251
+ Tracks loss variance over a window. When variance drops below a
252
+ threshold for patience steps, temporarily boosts LR by multiplier
253
+ for burst_steps, then decays back. Like SGDR warm restarts but
254
+ triggered adaptively by loss stagnation.
255
+ """
256
+ def __init__(self, patience=100, variance_threshold=0.005,
257
+ lr_multiplier=3.0, burst_steps=50):
258
+ self.patience = patience
259
+ self.var_threshold = variance_threshold
260
+ self.lr_mult = lr_multiplier
261
+ self.burst_steps = burst_steps
262
+ self._history = deque(maxlen=patience)
263
+ self._stagnant_count = 0
264
+ self._burst_remaining = 0
265
+ self._base_lr = None
266
+ self.total_bursts = 0
267
+
268
+ def check_and_adjust(self, loss_val, optimizer, step):
269
+ """Call every step. Returns True if burst was triggered."""
270
+ if not math.isfinite(loss_val):
271
+ return False
272
+
273
+ self._history.append(loss_val)
274
+
275
+ # During burst: decay LR back to base over burst_steps
276
+ if self._burst_remaining > 0:
277
+ self._burst_remaining -= 1
278
+ if self._burst_remaining == 0 and self._base_lr is not None:
279
+ for pg in optimizer.param_groups:
280
+ pg["lr"] = self._base_lr
281
+ self._base_lr = None
282
+ return False
283
+
284
+ if len(self._history) < self.patience:
285
+ return False
286
+
287
+ # Check variance
288
+ vals = list(self._history)
289
+ mean = sum(vals) / len(vals)
290
+ var = sum((v - mean) ** 2 for v in vals) / len(vals)
291
+
292
+ if var < self.var_threshold:
293
+ self._stagnant_count += 1
294
+ else:
295
+ self._stagnant_count = 0
296
+
297
+ if self._stagnant_count >= self.patience // 2:
298
+ # TRIGGER BURST
299
+ self._base_lr = optimizer.param_groups[0]["lr"]
300
+ burst_lr = self._base_lr * self.lr_mult
301
+ for pg in optimizer.param_groups:
302
+ pg["lr"] = burst_lr
303
+ self._burst_remaining = self.burst_steps
304
+ self._stagnant_count = 0
305
+ self.total_bursts += 1
306
+ print(f" [P16] Plateau detected! LR burst: {self._base_lr:.2e} → {burst_lr:.2e} for {self.burst_steps} steps")
307
+ return True
308
+ return False
309
+
310
+
311
+ # ═══════════════════════════════════════════════════════════
312
+ # P17 — Batch Metabolism (Online Hard Example Mining for LLM)
313
+ # ═══════════════════════════════════════════════════════════
314
+
315
+ def batch_metabolism_loss(logits, targets, min_weight=0.5, max_weight=2.0):
316
+ """Weight sequences within a batch by their relative difficulty.
317
+
318
+ Hard sequences (above-average loss) get up to max_weight.
319
+ Easy sequences (below-average loss) get down to min_weight.
320
+ The model "digests" harder examples more thoroughly.
321
+ """
322
+ B, T, V = logits.shape
323
+ # Per-sequence loss
324
+ per_token = F.cross_entropy(
325
+ logits.reshape(-1, V), targets.reshape(-1),
326
+ ignore_index=-100, reduction="none"
327
+ ).reshape(B, T)
328
+ seq_loss = per_token.mean(dim=1) # [B]
329
+
330
+ with torch.no_grad():
331
+ # Normalize: center on mean, scale to [min_weight, max_weight]
332
+ mean_loss = seq_loss.mean()
333
+ std_loss = seq_loss.std().clamp(min=1e-6)
334
+ # z-score → sigmoid → rescale to [min_weight, max_weight]
335
+ z = (seq_loss - mean_loss) / std_loss
336
+ weights = torch.sigmoid(z) * (max_weight - min_weight) + min_weight # [B]
337
+
338
+ # Weighted mean across batch
339
+ weighted = (per_token * weights.unsqueeze(1)).sum() / (weights.sum() * T)
340
+ return weighted
341
+
342
+
343
+ # ═══════════════════════════════════════════════════════════
344
+ # Utilities
345
  # ═══════════════════════════════════════════════════════════
346
 
347
  def invalidate_all_caches(model):
 
352
  m.invalidate_packed()
353
 
354
 
 
 
 
 
355
  def create_scheduler(optimizer, max_steps, warmup_steps=200):
 
356
  from torch.optim.lr_scheduler import LambdaLR
357
  def lr_lambda(step):
358
  if step < warmup_steps:
 
370
  model, max_steps=10000, lr=0.02, weight_decay=0.01,
371
  warmup_steps=200, use_compile=False, use_ipex=True,
372
  use_muon=True, use_mtp=True, use_distill=True,
373
+ use_triage=True, use_plateau_breaker=True, use_metabolism=True,
374
  mtp_heads=3, verbose=True,
375
  ):
 
 
 
 
 
376
  cpu_info = detect_cpu_info()
377
  if verbose:
378
  print("=" * 65)
379
+ print("CHIMERA TURBO v10Adaptive Token Metabolism")
380
  print("=" * 65)
381
  print(f" Cores: {cpu_info['physical_cores']} CPU: {cpu_info['capability']}")
382
 
383
  n_threads = configure_threading(cpu_info)
384
  if verbose:
385
+ print(f"[TURBO] Threads: {n_threads}")
386
 
387
+ # P12: Muon
388
  if use_muon:
389
  optimizer = create_muon_optimizer(model, lr=lr, weight_decay=weight_decay)
390
  if verbose:
391
+ n_muon = sum(p.numel() for p in model.parameters()
392
+ if p.requires_grad and p.ndim == 2 and not getattr(p, "_is_embed", False))
393
+ print(f"[P12] Muon (lr={lr}, NS-5) — {n_muon:,} params orthogonalized")
 
 
 
 
394
  else:
395
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr * 0.05,
396
+ betas=(0.9, 0.98), weight_decay=weight_decay)
397
 
398
  scheduler = create_scheduler(optimizer, max_steps, warmup_steps)
399
 
 
400
  extras = {}
401
+ raw = getattr(model, "_orig_mod", model)
402
+
403
+ # P13: MTP
404
  if use_mtp:
405
+ h, v = raw.config["hidden_size"], raw.config["vocab_size"]
406
+ extras["mtp"] = MultiTokenPredictionLoss(h, v, n_future=mtp_heads)
 
 
 
407
  if verbose:
408
+ print(f"[P13] Multi-Token Prediction ({mtp_heads} heads)")
409
 
410
+ # P14: EMA Distillation
411
  if use_distill:
412
+ extras["distiller"] = EMASelfDistiller(model, decay=0.999, alpha=0.5, temperature=2.0)
413
+ if verbose:
414
+ print(f"[P14] EMA Self-Distillation (α=0.5, T=2.0)")
415
+
416
+ # P15: Token Triage
417
+ if use_triage:
418
+ extras["triage"] = TokenTriage(ema_decay=0.99, select_ratio=0.6, floor_weight=0.1)
419
+ if verbose:
420
+ print(f"[P15] Token Triage (top 60% tokens → full grad, bottom 40% → 10%)")
421
+
422
+ # P16: Plateau Breaker
423
+ if use_plateau_breaker:
424
+ extras["plateau"] = PlateauBreaker(patience=100, variance_threshold=0.005,
425
+ lr_multiplier=3.0, burst_steps=50)
426
  if verbose:
427
+ print(f"[P16] Plateau Breaker (detect stagnation → LR burst ×3)")
428
+
429
+ # P17: Batch Metabolism
430
+ if use_metabolism:
431
+ extras["metabolism"] = True
432
+ if verbose:
433
+ print(f"[P17] Batch Metabolism (hard examples → 2× weight)")
434
 
435
  if verbose:
 
 
436
  print("=" * 65)
437
 
438
  return model, optimizer, scheduler, extras
439
 
440
 
441
  # ═══════════════════════════════════════════════════════════
442
+ # Training step ALL paradigms active
443
  # ═══════════════════════════════════════════════════════════
444
 
445
  _nan_count = 0
 
448
  model, batch, optimizer, scheduler,
449
  extras=None, grad_accum_steps=1, step=0,
450
  max_grad_norm=1.0, autocast_dtype=None,
451
+ mtp_weight=0.3,
452
  ) -> float:
 
 
 
 
453
  global _nan_count
454
  extras = extras or {}
455
  is_accum_step = (step + 1) % grad_accum_steps == 0
 
462
  outputs = model(input_ids, labels=labels)
463
  else:
464
  outputs = model(batch)
465
+ input_ids = labels = batch
 
466
 
467
+ logits = outputs.logits if hasattr(outputs, "logits") else None
468
+
469
+ # ── Compute main loss ──
470
+ triage = extras.get("triage")
471
+ metabolism = extras.get("metabolism")
472
  distiller = extras.get("distiller")
473
+
474
+ if logits is not None and triage is not None:
475
+ # P15: Token Triage — selective token weighting
476
+ base_loss = triage.weighted_loss(logits, labels)
477
+ elif logits is not None and metabolism:
478
+ # P17: Batch Metabolism — sequence-level weighting
479
+ base_loss = batch_metabolism_loss(logits, labels)
480
+ elif distiller is not None and logits is not None:
481
+ # P14: EMA distillation
482
+ base_loss = distiller.distillation_loss(logits, labels, input_ids)
483
  else:
484
  base_loss = outputs.loss if hasattr(outputs, "loss") else outputs
485
 
486
+ # ── P13: MTP auxiliary ──
487
  mtp = extras.get("mtp")
488
  if mtp is not None and hasattr(outputs, "hidden_states") and outputs.hidden_states is not None:
489
  mtp_loss = mtp(outputs.hidden_states, labels)
 
506
 
507
  _nan_count = 0
508
 
509
+ # ── P16: Plateau Breaker ──
510
+ plateau = extras.get("plateau")
511
+ if plateau is not None:
512
+ plateau.check_and_adjust(loss_val, optimizer, step)
513
+
514
  if grad_accum_steps > 1:
515
  total_loss = total_loss / grad_accum_steps
516
  total_loss.backward()
 
526
  scheduler.step()
527
  optimizer.zero_grad(set_to_none=True)
528
  invalidate_all_caches(model)
 
 
529
  if "distiller" in extras:
530
  extras["distiller"].update(model)
531