Lgr54HFi committed on
Commit 76e1136 · verified · 1 Parent(s): acc06f5

feat: P12 Muon optimizer, P13 Multi-Token Prediction, P14 EMA Self-Distillation

Three new paradigms for revolutionary sample efficiency:

P12 Muon: Newton-Schulz orthogonalized momentum for 2D weight matrices.
Same loss in 52% of FLOPs vs AdamW (arxiv 2502.16982). AdamW fallback
for 1D params (biases, norms, embeddings).

P13 MTP: predict next 3 tokens instead of 1. Each forward pass yields
3x gradient signal. Implemented as auxiliary loss heads sharing the trunk.

P14 EMA Self-Distillation: EMA copy of model acts as teacher. KL loss
between student and EMA soft targets gives dense signal across full vocab
vs sparse one-hot labels. α=0.5, T=2.0 (Baby Llama recipe, arxiv 2308.02019).
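A minimal usage sketch of the new API surface (the model and train_loader names below are placeholders; the real entry point lives in the ch1mera training scripts). apply() now returns a 4-tuple whose extras dict carries the P13 loss module and the P14 distiller, and training_step() threads both through the loss:

import chimera_turbo

def train(model, train_loader, max_steps=10_000):
    # apply() wires up Muon (P12), the MTP heads (P13) and the EMA distiller (P14)
    model, optimizer, scheduler, extras = chimera_turbo.apply(
        model, max_steps=max_steps, lr=0.02,
        use_muon=True, use_mtp=True, use_distill=True,
    )
    for step, batch in enumerate(train_loader):
        if step >= max_steps:
            break
        # training_step() handles NaN guarding, grad accumulation and the EMA teacher update
        loss = chimera_turbo.training_step(
            model, batch, optimizer, scheduler,
            extras=extras, grad_accum_steps=1, step=step,
        )
        if step % 100 == 0:
            print(f"step {step}: loss {loss:.4f}")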

Files changed (1)
  1. chimera_turbo.py +335 -115
chimera_turbo.py CHANGED
@@ -1,10 +1,15 @@
1
  """
2
  chimera_turbo.py — Drop-in CPU acceleration for ch1mera 5.3
3
- Usage: import chimera_turbo; chimera_turbo.apply(model, max_steps=N)
4
 
5
- v8: BitNet-paper aligned hyperparams β2=0.98, wd=0.01, warmup=750
6
  """
7
 
 
8
  import math
9
  import os
10
  import warnings
@@ -15,12 +20,16 @@ from typing import Optional, Dict, Any, Tuple
15
  from contextlib import nullcontext
16
 
17
 
18
  def detect_cpu_info() -> Dict[str, Any]:
19
  info = {}
20
  try:
21
- physical = len(os.sched_getaffinity(0))
22
  import multiprocessing
23
  logical = multiprocessing.cpu_count()
 
24
  info["physical_cores"] = logical // 2 if logical == physical else physical
25
  info["logical_cores"] = logical
26
  except Exception:
@@ -33,9 +42,8 @@ def detect_cpu_info() -> Dict[str, Any]:
33
  info["capability"] = "unknown"
34
  cap = (info["capability"] or "").lower()
35
  info["has_amx"] = "amx" in cap
36
- info["has_avx512"] = "avx512" in cap or "avx512_vnni" in cap
37
  info["has_avx512_bf16"] = "avx512_bf16" in cap or info["has_amx"]
38
- info["has_vnni"] = info["has_avx512"]
39
  try:
40
  import intel_extension_for_pytorch
41
  info["ipex_available"] = True
@@ -45,173 +53,381 @@ def detect_cpu_info() -> Dict[str, Any]:
45
  return info
46
 
47
 
48
- def configure_threading(cpu_info: Dict[str, Any], reserve_for_io: int = 1):
49
- n_compute = max(1, cpu_info["physical_cores"] - reserve_for_io)
50
- torch.set_num_threads(n_compute)
51
- os.environ["OMP_NUM_THREADS"] = str(n_compute)
52
- os.environ["MKL_NUM_THREADS"] = str(n_compute)
53
- return n_compute
54
 
55
 
56
- def create_optimizer(
57
- model: nn.Module,
58
- lr: float = 1.5e-3, # ← BitNet interpolated: 125M→2.4e-3, 350M→1.2e-3
59
- weight_decay: float = 0.01, # ← BitNet original (2310.11453 Table 5)
60
- use_lion: bool = False,
61
- betas: Tuple[float, float] = (0.9, 0.98), # ← BitNet: β2=0.98 NOT 0.95/0.999
62
- ) -> torch.optim.Optimizer:
63
- """AdamW with BitNet-paper hyperparameters.
64
 
65
- Key differences from standard:
66
- - β2=0.98 (not 0.999): faster variance adaptation for ternary noise
67
- - wd=0.01: original BitNet paper value, more stable than 0.05 for from-scratch
68
- lr=1.5e-3: interpolated from BitNet Table 5 (125M→2.4e-3, 350M→1.2e-3)
 
 
 
69
  """
70
- decay_params, no_decay_params = [], []
71
- for name, param in model.named_parameters():
72
- if not param.requires_grad:
73
  continue
74
- if param.ndim <= 1 or "bias" in name or "norm" in name or "embed" in name:
75
- no_decay_params.append(param)
76
- else:
77
- decay_params.append(param)
78
- param_groups = [
79
- {"params": decay_params, "weight_decay": weight_decay},
80
- {"params": no_decay_params, "weight_decay": 0.0},
81
- ]
82
- if use_lion:
83
- try:
84
- from lion_pytorch import Lion
85
- return Lion(param_groups, lr=lr * 0.3, betas=(0.95, 0.98))
86
- except ImportError:
87
- pass
88
- return torch.optim.AdamW(param_groups, lr=lr, betas=betas, eps=1e-8, fused=False)
89
-
90
-
91
- def create_scheduler(optimizer, max_steps: int, warmup_steps: int = 750):
92
- """Cosine decay with 750-step warmup (BitNet paper-exact)."""
93
- from torch.optim.lr_scheduler import LambdaLR
94
- def lr_lambda(step):
95
- if step < warmup_steps:
96
- return step / max(1, warmup_steps)
97
- progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
98
- return max(0.01, 0.5 * (1.0 + math.cos(math.pi * progress)))
99
- return LambdaLR(optimizer, lr_lambda)
100
 
 
 
101
 
102
- def invalidate_all_caches(model: nn.Module):
103
  from chimera.quantization import BitLinear
104
- for m in model.modules():
 
105
  if isinstance(m, BitLinear):
106
  m.invalidate_packed()
107
 
108
 
109
- def try_ipex_optimize(model, optimizer, cpu_info, dtype=None):
110
- if not cpu_info.get("ipex_available"):
111
- return model, optimizer
112
- try:
113
- import intel_extension_for_pytorch as ipex
114
- except Exception:
115
- return model, optimizer
116
- if dtype is None:
117
- dtype = torch.bfloat16 if (cpu_info["has_amx"] or cpu_info["has_avx512"]) else torch.float32
118
- model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=dtype, level="O1", inplace=True)
119
- return model, optimizer
120
121
 
122
- def try_compile_model(model: nn.Module, mode: str = "default") -> nn.Module:
123
- if not hasattr(torch, "compile"):
124
- return model
125
- try:
126
- compiled = torch.compile(model, backend="inductor", mode=mode, fullgraph=False)
127
- print(f"[TURBO-2] torch.compile enabled (mode={mode})")
128
- return compiled
129
- except Exception:
130
- return model
131
132
 
133
  def apply(
134
- model: nn.Module, max_steps: int = 10000, lr: float = 1.5e-3,
135
- weight_decay: float = 0.01, warmup_steps: int = 750,
136
- use_compile: bool = True, use_ipex: bool = True,
137
- use_lion: bool = False, verbose: bool = True,
138
- ) -> Tuple[nn.Module, torch.optim.Optimizer, Any]:
139
  cpu_info = detect_cpu_info()
140
  if verbose:
141
  print("=" * 65)
142
- print("CHIMERA TURBO v8BitNet-aligned hyperparams")
143
  print("=" * 65)
144
  print(f" Cores: {cpu_info['physical_cores']} CPU: {cpu_info['capability']}")
145
- print(f" IPEX: {cpu_info['ipex_available']} tcmalloc: {cpu_info['tcmalloc']}")
146
  n_threads = configure_threading(cpu_info)
147
  if verbose:
148
  print(f"[TURBO-3] Threads: {n_threads}")
149
- optimizer = create_optimizer(model, lr=lr, weight_decay=weight_decay)
150
- scheduler = create_scheduler(optimizer, max_steps=max_steps, warmup_steps=warmup_steps)
151
- if verbose:
152
- n_params = sum(p.numel() for g in optimizer.param_groups for p in g["params"])
153
- print(f"[TURBO-1] AdamW (lr={lr}, β=(0.9,0.98), wd={weight_decay}) — {n_params:,} params")
154
- if use_ipex:
155
- model, optimizer = try_ipex_optimize(model, optimizer, cpu_info)
156
- if use_compile:
157
- model = try_compile_model(model, mode="default")
158
  if verbose:
159
- if not cpu_info["has_avx512_bf16"]:
160
- print(" ⚠️ No BF16 hw — use --no-bf16")
161
- if not cpu_info["tcmalloc"]:
162
  print(" ⚠️ No tcmalloc — LD_PRELOAD=...libtcmalloc.so.4 for +15%")
163
  print("=" * 65)
164
- return model, optimizer, scheduler
165
 
 
166
 
167
- _nan_count = 0
168
- _MAX_CONSECUTIVE_NAN = 5
169
170
 
171
  def training_step(
172
- model: nn.Module, batch, optimizer: torch.optim.Optimizer, scheduler,
173
- grad_accum_steps: int = 1, step: int = 0,
174
- max_grad_norm: float = 1.0, # ← raised back to 1.0 (papers use none, this is light)
175
- autocast_dtype: Optional[torch.dtype] = torch.bfloat16,
176
  ) -> float:
177
- """NaN-safe training step with BitNet-aligned grad clipping.
178
 
179
- BitNet papers use NO grad clipping. We keep a light clip (1.0) as safety
180
- net for the evolution engine side-effects, but it should rarely activate.
181
  """
182
  global _nan_count
183
-
184
  is_accum_step = (step + 1) % grad_accum_steps == 0
185
  ctx = torch.autocast(device_type="cpu", dtype=autocast_dtype) if autocast_dtype else nullcontext()
186
 
187
  with ctx:
188
  if isinstance(batch, dict):
189
- outputs = model(batch["input_ids"], labels=batch.get("labels"))
190
- elif isinstance(batch, (tuple, list)):
191
- outputs = model(*batch)
192
  else:
193
  outputs = model(batch)
194
- loss = outputs if isinstance(outputs, torch.Tensor) else outputs.loss
195
- loss_val = loss.item()
 
197
- # ── NaN detection ──
 
 
198
  if not math.isfinite(loss_val):
199
  _nan_count += 1
200
  optimizer.zero_grad(set_to_none=True)
201
- if _nan_count >= _MAX_CONSECUTIVE_NAN:
202
  for pg in optimizer.param_groups:
203
  pg["lr"] *= 0.5
204
- print(f" [NaN] {_nan_count}x — LR halved to {optimizer.param_groups[0]['lr']:.2e}")
205
  _nan_count = 0
206
  return loss_val
207
 
208
  _nan_count = 0
209
 
210
  if grad_accum_steps > 1:
211
- loss = loss / grad_accum_steps
212
- loss.backward()
213
 
214
- # Sanitize any NaN grads from evolution engine
215
  for p in model.parameters():
216
  if p.grad is not None and not torch.isfinite(p.grad).all():
217
  p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
@@ -223,4 +439,8 @@ def training_step(
223
  optimizer.zero_grad(set_to_none=True)
224
  invalidate_all_caches(model)
225
 
226
  return loss_val
 
1
  """
2
  chimera_turbo.py — Drop-in CPU acceleration for ch1mera 5.3
 
3
 
4
+ v9: Muon optimizer + Multi-Token Prediction + EMA Self-Distillation
5
+
6
+ New paradigms:
7
+ P12 Muon optimizer — 2× token efficiency via NS-orthogonalized momentum
8
+ P13 Multi-Token Predict — 3× gradient signal per forward pass
9
+ P14 EMA Self-Distill — dense soft targets from EMA teacher copy
10
  """
11
 
12
+ import copy
13
  import math
14
  import os
15
  import warnings
 
20
  from contextlib import nullcontext
21
 
22
 
23
+ # ═══════════════════════════════════════════════════════════
24
+ # P-TURBO-3 : CPU Detection + Threading
25
+ # ═══════════════════════════════════════════════════════════
26
+
27
  def detect_cpu_info() -> Dict[str, Any]:
28
  info = {}
29
  try:
 
30
  import multiprocessing
31
  logical = multiprocessing.cpu_count()
32
+ physical = len(os.sched_getaffinity(0))
33
  info["physical_cores"] = logical // 2 if logical == physical else physical
34
  info["logical_cores"] = logical
35
  except Exception:
 
42
  info["capability"] = "unknown"
43
  cap = (info["capability"] or "").lower()
44
  info["has_amx"] = "amx" in cap
45
+ info["has_avx512"] = "avx512" in cap
46
  info["has_avx512_bf16"] = "avx512_bf16" in cap or info["has_amx"]
 
47
  try:
48
  import intel_extension_for_pytorch
49
  info["ipex_available"] = True
 
53
  return info
54
 
55
 
56
+ def configure_threading(cpu_info, reserve=1):
57
+ n = max(1, cpu_info["physical_cores"] - reserve)
58
+ torch.set_num_threads(n)
59
+ os.environ["OMP_NUM_THREADS"] = str(n)
60
+ os.environ["MKL_NUM_THREADS"] = str(n)
61
+ return n
62
+
63
+
64
+ # ═══════════════════════════════════════════════════════════
65
+ # P12 — Muon Optimizer (arxiv 2502.16982)
66
+ # ═══════════════════════════════════════════════════════════
67
 
68
+ def _zeropower_via_newtonschulz5(G, steps=5):
69
+ """Newton-Schulz iteration for polar factor. Pure PyTorch, CPU-safe."""
70
+ assert G.ndim == 2
71
+ a, b, c = 3.4445, -4.7750, 2.0315
72
+ X = G.T if G.size(0) > G.size(1) else G.clone()
73
+ X = X / (X.norm() + 1e-7)
74
+ for _ in range(steps):
75
+ A = X @ X.T
76
+ X = a * X + (b * A + c * A @ A) @ X
77
+ return X.T if G.size(0) > G.size(1) else X
78
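A quick sanity check (illustrative, not part of the commit) that the NS-5 iteration pushes the singular values of an arbitrary gradient matrix toward 1, which is what makes the resulting update approximately orthogonal; the tuned coefficients trade exactness for speed, so the spread stays well away from machine precision:

import torch
from chimera_turbo import _zeropower_via_newtonschulz5  # assumes chimera_turbo.py is importable

G = torch.randn(256, 512)
O = _zeropower_via_newtonschulz5(G, steps=5)
S = torch.linalg.svdvals(O)
print(S.min().item(), S.max().item())  # roughly 0.7 to 1.2, clustered around 1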
79
 
80
+ class Muon(torch.optim.Optimizer):
81
+ """Muon: MomentUm Orthogonalized by Newton-schulz.
82
+
83
+ 2D weight matrices: SGD momentum → NS orthogonalize → scaled update.
84
+ Everything else (bias, norm, embed): standard AdamW.
85
+
86
+ ~2× token efficiency vs AdamW (arxiv 2502.16982, Table 3).
87
  """
88
+ def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True,
89
+ ns_steps=5, weight_decay=0.0,
90
+ adamw_betas=(0.9, 0.98), adamw_eps=1e-8):
91
+ defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov,
92
+ ns_steps=ns_steps, weight_decay=weight_decay,
93
+ adamw_betas=adamw_betas, adamw_eps=adamw_eps)
94
+ super().__init__(params, defaults)
95
+
96
+ @torch.no_grad()
97
+ def step(self):
98
+ for group in self.param_groups:
99
+ lr = group["lr"]
100
+ wd = group["weight_decay"]
101
+ mu = group["momentum"]
102
+ b1, b2 = group["adamw_betas"]
103
+
104
+ for p in group["params"]:
105
+ if p.grad is None:
106
+ continue
107
+ g = p.grad
108
+ s = self.state[p]
109
+
110
+ # ── Muon path: 2D matrices (not embeddings) ──
111
+ if p.ndim == 2 and not getattr(p, "_is_embed", False):
112
+ if "buf" not in s:
113
+ s["buf"] = torch.zeros_like(g)
114
+ s["buf"].mul_(mu).add_(g)
115
+ ns_in = s["buf"] * mu + g if group["nesterov"] else s["buf"]
116
+ O = _zeropower_via_newtonschulz5(ns_in, group["ns_steps"])
117
+ scale = math.sqrt(max(1, p.size(0) / p.size(1)))
118
+ if wd > 0:
119
+ p.mul_(1 - lr * wd)
120
+ p.add_(O, alpha=-lr * scale)
121
+
122
+ # ── AdamW path: 1D params, embeddings ──
123
+ else:
124
+ if "m" not in s:
125
+ s["m"] = torch.zeros_like(g)
126
+ s["v"] = torch.zeros_like(g)
127
+ s["t"] = 0
128
+ s["t"] += 1
129
+ s["m"].mul_(b1).add_(g, alpha=1 - b1)
130
+ s["v"].mul_(b2).addcmul_(g, g, value=1 - b2)
131
+ bc1 = 1 - b1 ** s["t"]
132
+ bc2 = 1 - b2 ** s["t"]
133
+ alr = lr * math.sqrt(bc2) / bc1
134
+ if wd > 0:
135
+ p.mul_(1 - lr * wd)
136
+ p.addcdiv_(s["m"], s["v"].sqrt().add_(group["adamw_eps"]), value=-alr)
137
+
138
+
139
+ def create_muon_optimizer(model, lr=0.02, momentum=0.95, weight_decay=0.01):
140
+ """Create Muon optimizer with proper param group splitting."""
141
+ params = []
142
+ for name, p in model.named_parameters():
143
+ if not p.requires_grad:
144
  continue
145
+ if any(k in name for k in ["embed", "lm_head", "wte", "wpe"]):
146
+ p._is_embed = True
147
+ params.append(p)
148
+ return Muon(
149
+ [{"params": params}],
150
+ lr=lr, momentum=momentum, weight_decay=weight_decay,
151
+ adamw_betas=(0.9, 0.98), adamw_eps=1e-8,
152
+ )
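How the single param group is routed inside Muon.step(): 2D weight matrices take the Newton-Schulz branch, while 1D params and anything whose name marks it as an embedding fall through to the built-in AdamW branch. A small illustration with a toy module (the Toy class below is an assumption for demonstration, not part of ch1mera):

import torch.nn as nn
from chimera_turbo import create_muon_optimizer  # assumes chimera_turbo.py is importable

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 32)   # name contains "embed" -> tagged _is_embed -> AdamW branch
        self.proj = nn.Linear(32, 64)        # 2D weight -> Muon branch; its 1D bias -> AdamW branch
        self.norm = nn.LayerNorm(64)         # 1D params -> AdamW branch

toy = Toy()
opt = create_muon_optimizer(toy, lr=0.02)
for name, p in toy.named_parameters():
    branch = "muon" if p.ndim == 2 and not getattr(p, "_is_embed", False) else "adamw"
    print(f"{name:12s} {str(tuple(p.shape)):10s} -> {branch}")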
153
+
154
+
155
+ # ═══════════════════════════════════════════════════════════
156
+ # P13 — Multi-Token Prediction (arxiv 2404.19737)
157
+ # ═══════════════════════════════════════════════════════════
158
 
159
+ class MultiTokenPredictionLoss(nn.Module):
160
+ """Auxiliary loss: predict next N tokens instead of just 1.
161
 
162
+ Each forward pass yields N× gradient signal from the same hidden states.
163
+ Heads are lightweight linear projections sharing the trunk.
164
+ """
165
+ def __init__(self, hidden_size: int, vocab_size: int, n_future: int = 3):
166
+ super().__init__()
167
+ self.n_future = n_future
168
+ # Extra heads for tokens +2, +3, ... (head for +1 is the main lm_head)
169
+ self.extra_heads = nn.ModuleList([
170
+ nn.Linear(hidden_size, vocab_size, bias=False)
171
+ for _ in range(n_future - 1)
172
+ ])
173
+ # Init small to not destabilize early training
174
+ for head in self.extra_heads:
175
+ nn.init.normal_(head.weight, std=0.006)
176
+
177
+ def forward(self, hidden_states: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
178
+ """Compute auxiliary MTP loss.
179
+
180
+ Args:
181
+ hidden_states: [B, T, H] from trunk (before lm_head)
182
+ labels: [B, T] target token ids
183
+
184
+ Returns:
185
+ Scalar auxiliary loss (mean over all future positions and heads)
186
+ """
187
+ total_loss = torch.tensor(0.0, device=hidden_states.device)
188
+ count = 0
189
+ for k, head in enumerate(self.extra_heads):
190
+ shift = k + 2 # head 0 predicts +2, head 1 predicts +3, etc.
191
+ if shift >= labels.size(1):
192
+ continue
193
+ # Hidden states predict token at position +shift
194
+ logits = head(hidden_states[:, :-shift]) # [B, T-shift, V]
195
+ targets = labels[:, shift:] # [B, T-shift]
196
+ seq_len = min(logits.size(1), targets.size(1))
197
+ loss = F.cross_entropy(
198
+ logits[:, :seq_len].reshape(-1, logits.size(-1)),
199
+ targets[:, :seq_len].reshape(-1),
200
+ ignore_index=-100,
201
+ )
202
+ if torch.isfinite(loss):
203
+ total_loss = total_loss + loss
204
+ count += 1
205
+ return total_loss / max(count, 1)
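Shape sketch for the auxiliary heads (toy tensors, not the real model): head k predicts the token k+2 positions ahead, so both logits and targets are trimmed by the shift before the cross-entropy:

import torch
from chimera_turbo import MultiTokenPredictionLoss  # assumes chimera_turbo.py is importable

B, T, H, V = 2, 16, 64, 1000
mtp = MultiTokenPredictionLoss(hidden_size=H, vocab_size=V, n_future=3)  # extra heads for +2 and +3
hidden = torch.randn(B, T, H)              # trunk hidden states, before the main lm_head
labels = torch.randint(0, V, (B, T))
aux = mtp(hidden, labels)                  # mean CE over the extra heads
print(aux)                                 # scalar; training_step() adds base_loss + mtp_weight * aux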
206
+
207
+
208
+ # ═══════════════════════════════════════════════════════════
209
+ # P14 — EMA Self-Distillation (arxiv 2308.02019)
210
+ # ═══════════════════════════════════════════════════════════
211
+
212
+ class EMASelfDistiller:
213
+ """Maintain EMA copy of model as teacher for self-distillation.
214
+
215
+ The EMA model's soft targets provide dense gradient signal across
216
+ the full vocabulary, vs sparse one-hot labels from hard targets.
217
+
218
+ α=0.5 blends hard CE and soft KL. T=2.0 temperature.
219
+ Recipe from Baby Llama (arxiv 2308.02019).
220
+ """
221
+ def __init__(self, model: nn.Module, decay: float = 0.999, alpha: float = 0.5,
222
+ temperature: float = 2.0):
223
+ self.decay = decay
224
+ self.alpha = alpha
225
+ self.temperature = temperature
226
+ # Deep copy for EMA — no gradients needed
227
+ self.ema_model = copy.deepcopy(model)
228
+ for p in self.ema_model.parameters():
229
+ p.requires_grad_(False)
230
+ self.ema_model.eval()
231
+
232
+ @torch.no_grad()
233
+ def update(self, model: nn.Module):
234
+ """Update EMA weights. Call after optimizer.step()."""
235
+ for p_ema, p in zip(self.ema_model.parameters(), model.parameters()):
236
+ p_ema.data.mul_(self.decay).add_(p.data, alpha=1 - self.decay)
237
+
238
+ def distillation_loss(self, student_logits: torch.Tensor,
239
+ hard_targets: torch.Tensor,
240
+ input_ids: torch.Tensor) -> torch.Tensor:
241
+ """Compute blended hard + soft distillation loss."""
242
+ T = self.temperature
243
+
244
+ # Hard loss (standard CE)
245
+ seq_len = min(student_logits.size(1), hard_targets.size(1))
246
+ hard_loss = F.cross_entropy(
247
+ student_logits[:, :seq_len].reshape(-1, student_logits.size(-1)),
248
+ hard_targets[:, :seq_len].reshape(-1),
249
+ ignore_index=-100,
250
+ )
251
+
252
+ # Soft loss (KL from EMA teacher)
253
+ with torch.no_grad():
254
+ teacher_out = self.ema_model(input_ids)
255
+ teacher_logits = teacher_out.logits if hasattr(teacher_out, "logits") else teacher_out[1]
256
+
257
+ t_seq = min(student_logits.size(1), teacher_logits.size(1))
258
+ soft_student = F.log_softmax(student_logits[:, :t_seq] / T, dim=-1)
259
+ soft_teacher = F.softmax(teacher_logits[:, :t_seq] / T, dim=-1)
260
+ kl_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * (T * T)
261
+
262
+ if not torch.isfinite(kl_loss):
263
+ return hard_loss
264
+
265
+ return self.alpha * hard_loss + (1 - self.alpha) * kl_loss
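Lifecycle sketch: the EMA copy is created once, supplies soft targets every step, and is nudged toward the student after each optimizer step. TinyLM and the toy batch below are assumptions for illustration only; any model that returns an object with a .logits field works the same way:

import torch
import torch.nn as nn
from types import SimpleNamespace
from chimera_turbo import EMASelfDistiller  # assumes chimera_turbo.py is importable

class TinyLM(nn.Module):
    def __init__(self, vocab=100, hidden=32):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)
    def forward(self, ids, labels=None):
        return SimpleNamespace(logits=self.lm_head(self.embed(ids)))

model = TinyLM()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
distiller = EMASelfDistiller(model, decay=0.999, alpha=0.5, temperature=2.0)

ids = torch.randint(0, 100, (2, 16))
loss = distiller.distillation_loss(model(ids).logits, ids, ids)  # hard targets = inputs, demo only
loss.backward()
opt.step(); opt.zero_grad(set_to_none=True)
distiller.update(model)  # EMA teacher follows the student: p_ema = 0.999 * p_ema + 0.001 * p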
266
+
267
+
268
+ # ═══════════════════════════════════════════════════════════
269
+ # Cache invalidation
270
+ # ═══════════════════════════════════════════════════════════
271
+
272
+ def invalidate_all_caches(model):
273
  from chimera.quantization import BitLinear
274
+ raw = getattr(model, "_orig_mod", model)
275
+ for m in raw.modules():
276
  if isinstance(m, BitLinear):
277
  m.invalidate_packed()
278
 
279
 
280
+ # ═══════════════════════════════════════════════════════════
281
+ # Scheduler
282
+ # ═══════════════════════════════════════════════════════════
283
 
284
+ def create_scheduler(optimizer, max_steps, warmup_steps=200):
285
+ """Short warmup (200 steps) then cosine decay. Warmup=750 was too long."""
286
+ from torch.optim.lr_scheduler import LambdaLR
287
+ def lr_lambda(step):
288
+ if step < warmup_steps:
289
+ return step / max(1, warmup_steps)
290
+ progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
291
+ return max(0.01, 0.5 * (1.0 + math.cos(math.pi * progress)))
292
+ return LambdaLR(optimizer, lr_lambda)
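A tiny numeric illustration of the schedule (the helper below mirrors the lr_lambda defined above; it is not part of the module): the multiplier ramps linearly to 1.0 over the 200 warmup steps, then cosine-decays to a 0.01 floor at max_steps:

import math

def lr_mult(step, max_steps=10_000, warmup=200):
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, max_steps - warmup)
    return max(0.01, 0.5 * (1.0 + math.cos(math.pi * progress)))

print([round(lr_mult(s), 3) for s in (0, 100, 200, 5_000, 10_000)])
# [0.0, 0.5, 1.0, 0.516, 0.01]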
293
294
 
295
+ # ═══════════════════════════════════════════════════════════
296
+ # MAIN: apply()
297
+ # ═══════════════════════════════════════════════════════════
298
 
299
  def apply(
300
+ model, max_steps=10000, lr=0.02, weight_decay=0.01,
301
+ warmup_steps=200, use_compile=False, use_ipex=True,
302
+ use_muon=True, use_mtp=True, use_distill=True,
303
+ mtp_heads=3, verbose=True,
304
+ ):
305
+ """Apply all turbo + revolutionary paradigms.
306
+
307
+ Returns: (model, optimizer, scheduler, extras)
308
+ where extras = dict with 'mtp_loss_fn', 'distiller', etc.
309
+ """
310
  cpu_info = detect_cpu_info()
311
  if verbose:
312
  print("=" * 65)
313
+ print("CHIMERA TURBO v9Revolutionary Training Paradigms")
314
  print("=" * 65)
315
  print(f" Cores: {cpu_info['physical_cores']} CPU: {cpu_info['capability']}")
316
+
317
  n_threads = configure_threading(cpu_info)
318
  if verbose:
319
  print(f"[TURBO-3] Threads: {n_threads}")
320
+
321
+ # ── P12: Muon optimizer ──
322
+ if use_muon:
323
+ optimizer = create_muon_optimizer(model, lr=lr, weight_decay=weight_decay)
324
+ if verbose:
325
+ n_2d = sum(p.numel() for p in model.parameters()
326
+ if p.requires_grad and p.ndim == 2
327
+ and not getattr(p, "_is_embed", False))
328
+ n_1d = sum(p.numel() for p in model.parameters()
329
+ if p.requires_grad and (p.ndim < 2 or getattr(p, "_is_embed", False)))
330
+ print(f"[P12] Muon optimizer (lr={lr}, NS-5 orthogonalization)")
331
+ print(f" Muon: {n_2d:,} params | AdamW fallback: {n_1d:,} params")
332
+ else:
333
+ from chimera_turbo_legacy import create_optimizer
334
+ optimizer = create_optimizer(model, lr=lr, weight_decay=weight_decay)
335
+
336
+ scheduler = create_scheduler(optimizer, max_steps, warmup_steps)
337
+
338
+ # ── P13: Multi-Token Prediction ──
339
+ extras = {}
340
+ if use_mtp:
341
+ raw = getattr(model, "_orig_mod", model)
342
+ h = raw.config["hidden_size"]
343
+ v = raw.config["vocab_size"]
344
+ mtp = MultiTokenPredictionLoss(h, v, n_future=mtp_heads)
345
+ extras["mtp"] = mtp
346
+ if verbose:
347
+ print(f"[P13] Multi-Token Prediction ({mtp_heads} heads → {mtp_heads}× gradient signal)")
348
+
349
+ # ── P14: EMA Self-Distillation ──
350
+ if use_distill:
351
+ distiller = EMASelfDistiller(model, decay=0.999, alpha=0.5, temperature=2.0)
352
+ extras["distiller"] = distiller
353
+ if verbose:
354
+ print(f"[P14] EMA Self-Distillation (α=0.5, T=2.0, decay=0.999)")
355
+
356
  if verbose:
357
+ if not cpu_info.get("tcmalloc"):
 
 
358
  print(" ⚠️ No tcmalloc — LD_PRELOAD=...libtcmalloc.so.4 for +15%")
359
  print("=" * 65)
 
360
 
361
+ return model, optimizer, scheduler, extras
362
 
 
 
363
 
364
+ # ═══════════════════════════════════════════════════════════
365
+ # Training step with all paradigms
366
+ # ═══════════════════════════════════════════════════════════
367
+
368
+ _nan_count = 0
369
 
370
  def training_step(
371
+ model, batch, optimizer, scheduler,
372
+ extras=None, grad_accum_steps=1, step=0,
373
+ max_grad_norm=1.0, autocast_dtype=None,
374
+ mtp_weight=0.3, distill_weight=0.5,
375
  ) -> float:
376
+ """Training step with Muon + MTP + EMA distillation.
377
 
378
+ Loss = distill_loss (blended hard+soft) + mtp_weight * mtp_aux_loss
 
379
  """
380
  global _nan_count
381
+ extras = extras or {}
382
  is_accum_step = (step + 1) % grad_accum_steps == 0
383
  ctx = torch.autocast(device_type="cpu", dtype=autocast_dtype) if autocast_dtype else nullcontext()
384
 
385
  with ctx:
386
  if isinstance(batch, dict):
387
+ input_ids = batch["input_ids"]
388
+ labels = batch.get("labels", input_ids)
389
+ outputs = model(input_ids, labels=labels)
390
  else:
391
  outputs = model(batch)
392
+ input_ids = batch
393
+ labels = batch
394
+
395
+ # ── Base loss ──
396
+ distiller = extras.get("distiller")
397
+ if distiller is not None and hasattr(outputs, "logits"):
398
+ # P14: distillation loss replaces raw CE
399
+ base_loss = distiller.distillation_loss(outputs.logits, labels, input_ids)
400
+ else:
401
+ base_loss = outputs.loss if hasattr(outputs, "loss") else outputs
402
+
403
+ # ── P13: MTP auxiliary loss ──
404
+ mtp = extras.get("mtp")
405
+ if mtp is not None and hasattr(outputs, "hidden_states") and outputs.hidden_states is not None:
406
+ mtp_loss = mtp(outputs.hidden_states, labels)
407
+ total_loss = base_loss + mtp_weight * mtp_loss
408
+ else:
409
+ total_loss = base_loss
410
 
411
+ loss_val = total_loss.item()
412
+
413
+ # ── NaN guard ──
414
  if not math.isfinite(loss_val):
415
  _nan_count += 1
416
  optimizer.zero_grad(set_to_none=True)
417
+ if _nan_count >= 5:
418
  for pg in optimizer.param_groups:
419
  pg["lr"] *= 0.5
420
+ print(f" [NaN] — LR halved to {optimizer.param_groups[0]['lr']:.2e}")
421
  _nan_count = 0
422
  return loss_val
423
 
424
  _nan_count = 0
425
 
426
  if grad_accum_steps > 1:
427
+ total_loss = total_loss / grad_accum_steps
428
+ total_loss.backward()
429
 
430
+ # Sanitize grads
431
  for p in model.parameters():
432
  if p.grad is not None and not torch.isfinite(p.grad).all():
433
  p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
 
439
  optimizer.zero_grad(set_to_none=True)
440
  invalidate_all_caches(model)
441
 
442
+ # P14: update EMA teacher
443
+ if "distiller" in extras:
444
+ extras["distiller"].update(model)
445
+
446
  return loss_val