CLIWorks commited on
Commit
edb6a10
·
verified ·
1 Parent(s): b3b689e

Upload train_spider.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_spider.py +1353 -0
train_spider.py ADDED
@@ -0,0 +1,1353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Spider-FLEXITOKENS training pipeline.
3
+
4
+ Byte-level pretraining on FineWeb-Edu with boundary predictor curriculum.
5
+ Architecture: RDT (2 prelude + 6 recurrent + 2 coda) with:
6
+ - SharedProjectionMoE (32 experts, top-2, shared_inter=6144, rank=256)
7
+ - MLA (Multi-Latent Attention) with compressed KV cache + sliding window
8
+ - Engram conditional memory at recurrent layers 1 and 4
9
+ - BoundaryPredictor + downsample/upsample for FlexiToken integration
10
+ - LTI Injection + ACT Halting + LoRA Adapter
11
+ - 256k context (YaRN factor=8.0), sliding_window=8192
12
+ - 272-token byte-level vocab (256 bytes + 16 specials)
13
+
14
+ Usage:
15
+ Single GPU:
16
+ python train_spider.py
17
+ Multi-GPU:
18
+ torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") train_spider.py
19
+ Resume from checkpoint:
20
+ python train_spider.py --resume checkpoints/spider-step5000.pt
21
+ Quick smoke test:
22
+ python train_spider.py --max_steps 50 --mock_data
23
+ """
24
+
25
+ import os
26
+ import math
27
+ import re
28
+ import sys
29
+ import time
30
+ import argparse
31
+ from contextlib import nullcontext
32
+ from dataclasses import dataclass, field
33
+ from typing import Dict, List, Optional, Tuple
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+ import torch.distributed as dist
39
+ from torch.nn import CrossEntropyLoss
40
+ from torch.utils.data import IterableDataset, DataLoader, get_worker_info
41
+
42
+ from datasets import load_dataset
43
+
44
+ try:
45
+ import bitsandbytes as bnb
46
+ AdamW8bit = bnb.optim.AdamW8bit
47
+ Adam8bit = bnb.optim.Adam8bit
48
+ _HAS_8BIT = True
49
+ except ImportError:
50
+ _HAS_8BIT = False
51
+ AdamW8bit = None
52
+ Adam8bit = None
53
+
54
+ from spider import (
55
+ SpiderConfig,
56
+ SpiderForConditionalGeneration,
57
+ SENTINEL_TOKENS,
58
+ )
59
+
60
+ try:
61
+ from loguru import logger
62
+ logger.remove()
63
+ logger.add(sys.stderr, format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}")
64
+ logger.add("train_spider.log", rotation="100 MB", retention="10 days",
65
+ format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}")
66
+ except ImportError:
67
+ import logging
68
+ logging.basicConfig(level=logging.INFO)
69
+ class _LoguruShim:
70
+ def info(self, msg): logging.info(msg)
71
+ def success(self, msg): logging.info(msg)
72
+ def warning(self, msg): logging.warning(msg)
73
+ def error(self, msg): logging.error(msg)
74
+ logger = _LoguruShim()
75
+
76
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
77
+
78
+
79
+ # ============================================================================
80
+ # Byte-Level Dataset
81
+ # ============================================================================
82
+
83
+ BOS_ID = SENTINEL_TOKENS['BOS'] # 257
84
+ EOS_ID = SENTINEL_TOKENS['EOS'] # 258
85
+ PAD_ID = SENTINEL_TOKENS['PAD'] # 256
86
+
87
+
88
+ class ByteLevelDataset(IterableDataset):
89
+ """Streaming byte-level dataset from FineWeb-Edu.
90
+
91
+ Per D-23: FineWeb-Edu (English first), per-sample UTF-8 byte encoding.
92
+ Per D-24: Curriculum ordering (English -> multilingual -> code -> math).
93
+ Per D-34: Streaming only, no local download.
94
+
95
+ Each sample is encoded as raw UTF-8 bytes with BOS/EOS sentinel tokens.
96
+ Vocab: 272 tokens (256 bytes + 16 specials). Max 8192 bytes per sample.
97
+ """
98
+
99
+ def __init__(
100
+ self,
101
+ dataset_name: str = "HuggingFaceFW/fineweb-edu",
102
+ subset: str = "sample-10BT",
103
+ split: str = "train",
104
+ seq_len: int = 8192,
105
+ max_bytes: int = 8192,
106
+ rank: int = 0,
107
+ world_size: int = 1,
108
+ ):
109
+ self.seq_len = seq_len
110
+ self.max_bytes = max_bytes
111
+ self.dataset_name = dataset_name
112
+ self.subset = subset
113
+ self.split = split
114
+ self.rank = rank
115
+ self.world_size = world_size
116
+
117
+ def _encode_sample(self, text: str) -> List[int]:
118
+ """Encode text as UTF-8 bytes with BOS/EOS, truncated to max_bytes."""
119
+ byte_ids = list(text.encode('utf-8'))[:self.max_bytes]
120
+ return [BOS_ID] + byte_ids + [EOS_ID]
121
+
122
+ def _pad_or_truncate(self, ids: List[int]) -> List[int]:
123
+ """Pad or truncate to seq_len, mask padding with -100 for labels."""
124
+ ids = ids[:self.seq_len]
125
+ ids = ids + [PAD_ID] * (self.seq_len - len(ids))
126
+ return ids
127
+
128
+ def __iter__(self):
129
+ worker = get_worker_info()
130
+ num_workers = worker.num_workers if worker else 1
131
+ worker_id = worker.id if worker else 0
132
+ total_shards = self.world_size * num_workers
133
+ shard_index = self.rank * num_workers + worker_id
134
+
135
+ ds = load_dataset(
136
+ self.dataset_name,
137
+ name=self.subset,
138
+ split=self.split,
139
+ streaming=True,
140
+ ).shard(num_shards=total_shards, index=shard_index)
141
+
142
+ buf = []
143
+ for sample in ds:
144
+ text = sample.get("text", "")
145
+ if not text:
146
+ continue
147
+ byte_ids = self._encode_sample(text)
148
+ buf.extend(byte_ids)
149
+ while len(buf) >= self.seq_len + 1:
150
+ chunk = buf[:self.seq_len + 1]
151
+ buf = buf[self.seq_len + 1:]
152
+ x = torch.tensor(chunk[:-1], dtype=torch.long)
153
+ y = torch.tensor(chunk[1:], dtype=torch.long)
154
+ y[y == PAD_ID] = -100
155
+ yield x, y
156
+
157
+
158
+ class MockByteLevelDataset(IterableDataset):
159
+ """In-memory byte-level dataset for testing (no network required).
160
+
161
+ Uses a fixed set of text samples in multiple languages to verify
162
+ byte-level encoding, BOS/EOS placement, and multilingual handling.
163
+ """
164
+
165
+ SAMPLES = [
166
+ "Hello world, this is a test of the byte-level encoding system.",
167
+ "The quick brown fox jumps over the lazy dog.",
168
+ "Spider is a recurrent latent reasoning architecture with engram memory.",
169
+ "Boundary predictors learn to merge byte sequences into meaningful tokens.",
170
+ "FineWeb-Edu contains high-quality educational content for pretraining.",
171
+ "Это текст на русском языке для проверки многозычной поддержки.",
172
+ "తెలుగు భాష యొక్క పరీక్ష కోసం నమూనా వచనం.",
173
+ "中文文本用于测试多语言字节编码支持。",
174
+ "def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)",
175
+ "The integral of x^2 from 0 to 1 equals 1/3.",
176
+ ]
177
+
178
+ def __init__(self, seq_len: int = 512, max_bytes: int = 512, num_samples: int = 1000):
179
+ self.seq_len = seq_len
180
+ self.max_bytes = max_bytes
181
+ self.num_samples = num_samples
182
+
183
+ def __iter__(self):
184
+ buf = []
185
+ count = 0
186
+ while count < self.num_samples:
187
+ for text in self.SAMPLES:
188
+ byte_ids = list(text.encode('utf-8'))[:self.max_bytes]
189
+ ids = [BOS_ID] + byte_ids + [EOS_ID]
190
+ buf.extend(ids)
191
+ while len(buf) >= self.seq_len + 1:
192
+ chunk = buf[:self.seq_len + 1]
193
+ buf = buf[self.seq_len + 1:]
194
+ x = torch.tensor(chunk[:-1], dtype=torch.long)
195
+ y = torch.tensor(chunk[1:], dtype=torch.long)
196
+ y[y == PAD_ID] = -100
197
+ yield x, y
198
+ count += 1
199
+ if count >= self.num_samples:
200
+ return
201
+
202
+
203
+ # ============================================================================
204
+ # Curriculum Scheduler
205
+ # ============================================================================
206
+
207
+ class CurriculumScheduler:
208
+ """Training curriculum scheduler per D-24 and D-25.
209
+
210
+ Manages dataset switching across training phases and boundary predictor
211
+ curriculum mode (fixed top-k vs adaptive threshold).
212
+
213
+ Phases:
214
+ 0-30%: English (FineWeb-Edu), fixed top-k BP (D-25)
215
+ 30-50%: English + multilingual, adaptive BP
216
+ 50-70%: English + multilingual + code, adaptive BP
217
+ 70-90%: English + multilingual + code + math, adaptive BP
218
+ 90-100%: Mixed + multimodal, adaptive BP
219
+ """
220
+
221
+ def __init__(
222
+ self,
223
+ total_steps: int,
224
+ fixed_compression_k: float = 3.3,
225
+ adaptive_threshold: float = 0.5,
226
+ ):
227
+ self.total_steps = total_steps
228
+ self.fixed_compression_k = fixed_compression_k
229
+ self.adaptive_threshold = adaptive_threshold
230
+ self.curriculum_switch_step = int(0.3 * total_steps)
231
+
232
+ def get_phase(self, step: int) -> int:
233
+ if step < int(0.3 * self.total_steps):
234
+ return 1
235
+ elif step < int(0.5 * self.total_steps):
236
+ return 2
237
+ elif step < int(0.7 * self.total_steps):
238
+ return 3
239
+ elif step < int(0.9 * self.total_steps):
240
+ return 4
241
+ else:
242
+ return 5
243
+
244
+ def is_fixed_bp(self, step: int) -> bool:
245
+ """Return True if BP should use fixed top-k boundaries (D-25)."""
246
+ return step < self.curriculum_switch_step
247
+
248
+ def get_fixed_k(self, seq_len: int) -> int:
249
+ """Number of boundary positions for fixed top-k (3.3x compression)."""
250
+ return max(1, int(seq_len / self.fixed_compression_k))
251
+
252
+ def get_boundaries(
253
+ self,
254
+ soft_boundaries: torch.Tensor,
255
+ step: int,
256
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
257
+ """Compute hard boundaries based on curriculum phase.
258
+
259
+ During fixed phase (first 30% of steps): top-k boundaries with
260
+ straight-through estimator. During adaptive phase: threshold-based.
261
+
262
+ Args:
263
+ soft_boundaries: [B, L] boundary probabilities from BoundaryPredictor
264
+ step: Current training step
265
+
266
+ Returns:
267
+ Tuple of (soft_boundaries, hard_boundaries), each [B, L]
268
+ """
269
+ if self.is_fixed_bp(step):
270
+ k = self.get_fixed_k(soft_boundaries.shape[-1])
271
+ topk_vals, topk_idx = soft_boundaries.topk(k, dim=-1)
272
+ hard_boundaries = torch.zeros_like(soft_boundaries)
273
+ hard_boundaries.scatter_(-1, topk_idx, 1.0)
274
+ hard_boundaries = (
275
+ hard_boundaries - soft_boundaries.detach() + soft_boundaries
276
+ )
277
+ else:
278
+ hard_boundaries = (soft_boundaries > self.adaptive_threshold).float()
279
+ hard_boundaries = (
280
+ hard_boundaries - soft_boundaries.detach() + soft_boundaries
281
+ )
282
+
283
+ return soft_boundaries, hard_boundaries
284
+
285
+
286
+ # ============================================================================
287
+ # BP Loss (D-26)
288
+ # ============================================================================
289
+
290
+ def compute_bp_loss(
291
+ soft_boundaries: torch.Tensor,
292
+ hard_boundaries: torch.Tensor,
293
+ seq_len: int,
294
+ binomial_weight: float = 0.1,
295
+ pred_prior: float = 0.303,
296
+ ) -> torch.Tensor:
297
+ """Compute boundary predictor loss per D-26: BCE + binomial prior.
298
+
299
+ During fixed phase: BCE on boundary decisions vs uniform target.
300
+ During adaptive phase: binomial prior loss only.
301
+
302
+ Args:
303
+ soft_boundaries: [B, L] boundary probabilities
304
+ hard_boundaries: [B, L] binary boundary decisions
305
+ seq_len: Sequence length
306
+ binomial_weight: Weight for binomial prior term (0.1 per D-26)
307
+ pred_prior: Expected fraction of boundary positions (1/3.3 ≈ 0.303)
308
+
309
+ Returns:
310
+ Scalar BP loss tensor
311
+ """
312
+ B = soft_boundaries.shape[0]
313
+
314
+ # BCE loss: encourage boundary probability to match expected compression
315
+ target_rate = 1.0 / 3.3
316
+ target = torch.full_like(soft_boundaries, target_rate)
317
+ bce_loss = F.binary_cross_entropy(soft_boundaries, target)
318
+
319
+ # Binomial prior: regularize number of predicted boundaries
320
+ sum_preds = hard_boundaries.sum(dim=-1) # [B]
321
+ binomial = torch.distributions.binomial.Binomial(
322
+ total_count=float(seq_len),
323
+ probs=pred_prior,
324
+ )
325
+ log_prob = binomial.log_prob(sum_preds)
326
+ binomial_loss = -log_prob.mean() / seq_len
327
+
328
+ return bce_loss + binomial_weight * binomial_loss
329
+
330
+
331
+ # ============================================================================
332
+ # Recurrent Monitor (drift/collapse detection)
333
+ # ============================================================================
334
+
335
+ class RecurrentMonitor:
336
+ """Monitors recurrent dynamics across loops during training.
337
+
338
+ Catches representation drift, expert collapse, and engram instability
339
+ before they corrupt training. Per CONTEXT: representation drift across
340
+ loops is the #1 failure mode for recurrent architectures.
341
+
342
+ Logged metrics (every log_interval steps):
343
+ - loop_norms: L2 norm of hidden states after each loop (drift detection)
344
+ - routing_entropy: entropy of expert routing weights per loop (collapse detection)
345
+ - engram_norms: L2 norm of engram residuals at layers 1 and 4 (memory stability)
346
+ - halt_distribution: fraction of tokens halting at each loop (ACT health)
347
+ - loop_grad_norms: gradient norms per recurrent layer (gradient health)
348
+ """
349
+
350
+ def __init__(
351
+ self,
352
+ drift_threshold: float = 10.0,
353
+ collapse_threshold: float = 1.0,
354
+ ):
355
+ self.drift_threshold = drift_threshold
356
+ self.collapse_threshold = collapse_threshold
357
+
358
+ def compute_routing_entropy(self, router_logits: torch.Tensor) -> float:
359
+ """Compute routing entropy from router logits.
360
+
361
+ Args:
362
+ router_logits: [B, L, num_experts] raw router logits
363
+
364
+ Returns:
365
+ Scalar entropy value (higher = more diverse routing)
366
+ """
367
+ p = F.softmax(router_logits, dim=-1).mean(dim=(0, 1)) # [num_experts]
368
+ entropy = -(p * (p + 1e-10).log()).sum().item()
369
+ return entropy
370
+
371
+ def check_health(self, metrics: Dict, step: int) -> List[str]:
372
+ """Check for drift, collapse, or instability.
373
+
374
+ Args:
375
+ metrics: Dict with keys: loop_norms, routing_entropy, engram_norms, halt_distribution
376
+ step: Current training step
377
+
378
+ Returns:
379
+ List of warning strings (empty if healthy)
380
+ """
381
+ warnings = []
382
+
383
+ # Drift detection: hidden norm ratio between first and last loop
384
+ norms = metrics.get('loop_norms', [])
385
+ if len(norms) >= 2 and norms[0] > 0:
386
+ drift_ratio = norms[-1] / norms[0]
387
+ if drift_ratio > self.drift_threshold:
388
+ warnings.append(
389
+ f"DRIFT WARNING step {step}: loop norm ratio {drift_ratio:.1f}x "
390
+ f"(loop_1={norms[0]:.2f}, loop_{len(norms)}={norms[-1]:.2f})"
391
+ )
392
+
393
+ # Collapse detection: low routing entropy
394
+ entropies = metrics.get('routing_entropy', [])
395
+ if entropies and min(entropies) < self.collapse_threshold:
396
+ warnings.append(
397
+ f"COLLAPSE WARNING step {step}: routing entropy {min(entropies):.2f} "
398
+ f"< threshold {self.collapse_threshold}"
399
+ )
400
+
401
+ return warnings
402
+
403
+
404
+ # ============================================================================
405
+ # BP Curriculum Trainer
406
+ # ============================================================================
407
+
408
+ class BPCurriculumTrainer:
409
+ """Training wrapper for Spider-FLEXITOKENS with BP curriculum.
410
+
411
+ Manages:
412
+ - BP freeze/unfreeze during warmup (D-27)
413
+ - Fixed -> adaptive boundary curriculum (D-25)
414
+ - Dual loss: LM CE + MoE aux + BP (BCE + binomial prior) (D-26)
415
+ - Per-loop gradient clipping for expert cores
416
+ - RecurrentMonitor integration for drift/collapse detection
417
+ """
418
+
419
+ def __init__(
420
+ self,
421
+ model: SpiderForConditionalGeneration,
422
+ optimizer: torch.optim.Optimizer,
423
+ engram_optimizer: Optional[torch.optim.Optimizer],
424
+ curriculum: CurriculumScheduler,
425
+ monitor: RecurrentMonitor,
426
+ warmup_steps: int,
427
+ base_lr: float,
428
+ bp_loss_weight: float = 0.1,
429
+ grad_clip: float = 1.0,
430
+ expert_core_grad_clip: float = 0.5,
431
+ ):
432
+ self.model = model
433
+ self.optimizer = optimizer
434
+ self.engram_optimizer = engram_optimizer
435
+ self.curriculum = curriculum
436
+ self.monitor = monitor
437
+ self.warmup_steps = warmup_steps
438
+ self.base_lr = base_lr
439
+ self.bp_loss_weight = bp_loss_weight
440
+ self.grad_clip = grad_clip
441
+ self.expert_core_grad_clip = expert_core_grad_clip
442
+ self._bp_frozen = False
443
+ self.bp_optimizer = None
444
+
445
+ def freeze_bp(self):
446
+ """Freeze boundary predictor params during warmup (D-27)."""
447
+ for name, param in self.model.named_parameters():
448
+ if 'boundary_predictor' in name:
449
+ param.requires_grad = False
450
+ self._bp_frozen = True
451
+
452
+ def unfreeze_bp(self):
453
+ """Unfreeze BP at 0.1x base LR after warmup (D-27)."""
454
+ bp_param_names = set()
455
+ bp_params = []
456
+ for name, param in self.model.named_parameters():
457
+ if 'boundary_predictor' in name:
458
+ param.requires_grad = True
459
+ bp_params.append(param)
460
+ bp_param_names.add(name)
461
+ self._bp_frozen = False
462
+
463
+ # Create separate optimizer for BP params with 0.1x base LR (D-27)
464
+ bp_lr = self.base_lr * 0.1
465
+ self.bp_optimizer = torch.optim.Adam(
466
+ bp_params, lr=bp_lr, betas=(0.9, 0.95), eps=1e-8
467
+ )
468
+
469
+ def train_step(
470
+ self,
471
+ input_ids: torch.Tensor,
472
+ labels: torch.Tensor,
473
+ step: int,
474
+ n_loops: int = 6,
475
+ amp_ctx: Optional[nullcontext] = None,
476
+ sdpa_ctx: Optional[nullcontext] = None,
477
+ ) -> Tuple[torch.Tensor, Dict]:
478
+ """Single training step with dual loss and monitoring.
479
+
480
+ Args:
481
+ input_ids: [B, L] byte-level token IDs
482
+ labels: [B, L] target token IDs (with -100 for padding)
483
+ step: Current training step
484
+ n_loops: Number of recurrent loops
485
+ amp_ctx: Optional autocast context
486
+ sdpa_ctx: Optional SDPA kernel context
487
+
488
+ Returns:
489
+ Tuple of (total_loss, metrics_dict)
490
+ """
491
+ amp_ctx = amp_ctx or nullcontext()
492
+ sdpa_ctx = sdpa_ctx or nullcontext()
493
+
494
+ # BP freeze/unfreeze logic (D-27)
495
+ if step == 0 and self.warmup_steps > 0:
496
+ self.freeze_bp()
497
+ if self._bp_frozen and step >= self.warmup_steps:
498
+ self.unfreeze_bp()
499
+
500
+ with amp_ctx, sdpa_ctx:
501
+ # Override boundaries based on curriculum phase
502
+ output = self.model(input_ids, labels=labels, n_loops=n_loops)
503
+
504
+ lm_loss = output['loss']
505
+ aux_loss = output['aux_loss']
506
+ soft_boundaries = output['soft_boundaries']
507
+ hard_boundaries = output['hard_boundaries']
508
+
509
+ # Apply curriculum override for hard_boundaries
510
+ soft_boundaries, hard_boundaries = self.curriculum.get_boundaries(
511
+ soft_boundaries, step
512
+ )
513
+
514
+ # BP dual loss (D-26)
515
+ seq_len = input_ids.shape[-1]
516
+ if not self._bp_frozen:
517
+ bp_loss = compute_bp_loss(soft_boundaries, hard_boundaries, seq_len)
518
+ else:
519
+ bp_loss = torch.tensor(0.0, device=input_ids.device)
520
+
521
+ # Total loss: LM + MoE aux + BP
522
+ if isinstance(aux_loss, torch.Tensor):
523
+ total_loss = lm_loss + self.model.config.router_aux_loss_coef * aux_loss
524
+ else:
525
+ total_loss = lm_loss + self.model.config.router_aux_loss_coef * aux_loss
526
+ total_loss = total_loss + self.bp_loss_weight * bp_loss
527
+
528
+ # Collect monitoring metrics
529
+ metrics = {
530
+ 'lm_loss': lm_loss.item() if isinstance(lm_loss, torch.Tensor) else lm_loss,
531
+ 'aux_loss': aux_loss.item() if isinstance(aux_loss, torch.Tensor) else aux_loss,
532
+ 'bp_loss': bp_loss.item() if isinstance(bp_loss, torch.Tensor) else bp_loss,
533
+ 'bp_frozen': self._bp_frozen,
534
+ 'curriculum_phase': self.curriculum.get_phase(step),
535
+ 'is_fixed_bp': self.curriculum.is_fixed_bp(step),
536
+ }
537
+
538
+ return total_loss, metrics
539
+
540
+ def clip_gradients(self) -> float:
541
+ """Clip gradients: global + per-loop expert core clipping.
542
+
543
+ Standard: clip_grad_norm_(all params, max_norm=1.0)
544
+ Expert cores: tighter clip at 0.5 to prevent drift.
545
+ """
546
+ # Global gradient clipping
547
+ grad_norm = nn.utils.clip_grad_norm_(
548
+ self.model.parameters(), max_norm=self.grad_clip
549
+ )
550
+
551
+ # Per-loop expert core clipping (tighter)
552
+ expert_core_params = []
553
+ for name, param in self.model.named_parameters():
554
+ if ('W_gate' in name or 'W_transform' in name) and param.grad is not None:
555
+ expert_core_params.append(param)
556
+
557
+ if expert_core_params:
558
+ nn.utils.clip_grad_norm_(
559
+ expert_core_params, max_norm=self.expert_core_grad_clip
560
+ )
561
+
562
+ return grad_norm.item() if isinstance(grad_norm, torch.Tensor) else float(grad_norm)
563
+
564
+
565
+ # ============================================================================
566
+ # LR Schedule
567
+ # ============================================================================
568
+
569
+ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
570
+ """Cosine learning rate with linear warmup."""
571
+ if step < warmup:
572
+ return max_lr * step / warmup
573
+ if step >= total:
574
+ return min_lr
575
+ decay = (step - warmup) / (total - warmup)
576
+ return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
577
+
578
+
579
+ # ============================================================================
580
+ # Checkpointing
581
+ # ============================================================================
582
+
583
+ def save_step_checkpoint(model, optimizer, step, epoch, cfg, ckpt_dir, master, ddp=False, trainer=None, current_best_loss=float("inf")):
584
+ """Save full checkpoint (model + optimizer) and keep only the last 2."""
585
+ if ddp:
586
+ from torch.distributed.fsdp import (
587
+ FullyShardedDataParallel as FSDP,
588
+ StateDictType,
589
+ FullStateDictConfig,
590
+ )
591
+ with FSDP.state_dict_type(
592
+ model,
593
+ StateDictType.FULL_STATE_DICT,
594
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
595
+ ):
596
+ model_state = model.state_dict()
597
+ optim_state = FSDP.optim_state_dict(model, optimizer)
598
+ else:
599
+ model_state = model.state_dict()
600
+ optim_state = optimizer.state_dict()
601
+
602
+ if not master:
603
+ return None, 0
604
+
605
+ os.makedirs(ckpt_dir, exist_ok=True)
606
+ ckpt_path = os.path.join(ckpt_dir, f"spider-step{step}.pt")
607
+ tmp_path = ckpt_path + ".tmp"
608
+ torch.save(
609
+ {
610
+ "step": step,
611
+ "epoch": epoch,
612
+ "model_state_dict": model_state,
613
+ "optimizer_state_dict": optim_state,
614
+ "cfg": cfg,
615
+ "bp_optimizer_state_dict": (
616
+ trainer.bp_optimizer.state_dict() if trainer and trainer.bp_optimizer else None
617
+ ),
618
+ "best_loss": current_best_loss,
619
+ },
620
+ tmp_path,
621
+ )
622
+ os.replace(tmp_path, ckpt_path)
623
+ size_mb = os.path.getsize(ckpt_path) / (1024 * 1024)
624
+
625
+ # Keep only the last 2 step checkpoints
626
+ step_pattern = re.compile(r"spider-step\d+\.pt$")
627
+ step_ckpts = sorted(
628
+ [os.path.join(ckpt_dir, f) for f in os.listdir(ckpt_dir) if step_pattern.search(f)],
629
+ key=os.path.getmtime,
630
+ )
631
+ while len(step_ckpts) > 2:
632
+ old = step_ckpts.pop(0)
633
+ os.remove(old)
634
+
635
+ return ckpt_path, size_mb
636
+
637
+
638
+ def save_full_checkpoint(model, optimizer, step, epoch, cfg, ckpt_dir, master, ddp=False, ckpt_name="full", trainer=None, current_best_loss=float("inf")):
639
+ """Save full checkpoint with custom name."""
640
+ if ddp:
641
+ from torch.distributed.fsdp import (
642
+ FullyShardedDataParallel as FSDP,
643
+ StateDictType,
644
+ FullStateDictConfig,
645
+ )
646
+ with FSDP.state_dict_type(
647
+ model,
648
+ StateDictType.FULL_STATE_DICT,
649
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
650
+ ):
651
+ model_state = model.state_dict()
652
+ optim_state = FSDP.optim_state_dict(model, optimizer)
653
+ else:
654
+ model_state = model.state_dict()
655
+ optim_state = optimizer.state_dict()
656
+
657
+ if not master:
658
+ return None, 0
659
+
660
+ os.makedirs(ckpt_dir, exist_ok=True)
661
+ final_path = os.path.join(ckpt_dir, f"spider-{ckpt_name}.pt")
662
+ tmp_path = final_path + ".tmp"
663
+ torch.save(
664
+ {
665
+ "step": step,
666
+ "epoch": epoch,
667
+ "model_state_dict": model_state,
668
+ "optimizer_state_dict": optim_state,
669
+ "cfg": cfg,
670
+ "bp_optimizer_state_dict": (
671
+ trainer.bp_optimizer.state_dict() if trainer and trainer.bp_optimizer else None
672
+ ),
673
+ "best_loss": current_best_loss,
674
+ },
675
+ tmp_path,
676
+ )
677
+ os.replace(tmp_path, final_path)
678
+ size_mb = os.path.getsize(final_path) / (1024 * 1024)
679
+ return final_path, size_mb
680
+
681
+
682
+ def load_checkpoint(model, optimizer, path, ddp=False):
683
+ """Load model + optimizer state from checkpoint.
684
+
685
+ Handles cross-optimizer resume (e.g. 8bit Adam on local → standard AdamW
686
+ on remote): if optimizer state dict keys mismatch, we skip the optimizer
687
+ state and log a warning. The model weights always load successfully.
688
+
689
+ Returns: (step, epoch, bp_optim_state, saved_best_loss)
690
+ """
691
+ ckpt = torch.load(path, map_location="cpu", weights_only=False)
692
+ model.load_state_dict(ckpt["model_state_dict"])
693
+ try:
694
+ optimizer.load_state_dict(ckpt["optimizer_state_dict"])
695
+ except (ValueError, KeyError, RuntimeError) as e:
696
+ logger.warning(
697
+ f"Optimizer state mismatch (likely 8bit→standard cross-resume): {e}. "
698
+ f"Skipping optimizer state — optimizer will reinitialize."
699
+ )
700
+ bp_optim_state = ckpt.get("bp_optimizer_state_dict", None)
701
+ saved_best_loss = ckpt.get("best_loss", float("inf"))
702
+ return int(ckpt["step"]), int(ckpt.get("epoch", 0)), bp_optim_state, saved_best_loss
703
+
704
+
705
+ # ============================================================================
706
+ # DeepSpeed Config (fallback for RTX 4060 8GB)
707
+ # ============================================================================
708
+
709
+ DEEPSPEED_ZERO3_CONFIG = {
710
+ "bf16": {"enabled": True},
711
+ "zero_optimization": {
712
+ "stage": 3,
713
+ "offload_optimizer": {
714
+ "device": "cpu",
715
+ "pin_memory": True,
716
+ },
717
+ "offload_param": {
718
+ "device": "cpu",
719
+ "pin_memory": True,
720
+ },
721
+ "overlap_comm": True,
722
+ "contiguous_gradients": True,
723
+ },
724
+ "gradient_accumulation_steps": 1,
725
+ "gradient_clipping": 1.0,
726
+ "train_batch_size": "auto",
727
+ "train_micro_batch_size_per_gpu": "auto",
728
+ }
729
+
730
+
731
+ # ============================================================================
732
+ # Precision Mode (MXFP8 / NVFP4 / FP8_DYNAMIC / BF16)
733
+ # ============================================================================
734
+
735
+ import enum
736
+
737
+ class PrecisionMode(enum.Enum):
738
+ BF16 = "bf16"
739
+ FP8_DYNAMIC = "fp8_dynamic"
740
+ MXFP8 = "mxfp8"
741
+ NVFP4 = "nvfp4"
742
+
743
+
744
+ def detect_precision_mode() -> PrecisionMode:
745
+ """Auto-detect best available precision mode based on GPU + libraries.
746
+
747
+ Fallback chain: MXFP8/NVFP4 → FP8_DYNAMIC → BF16
748
+
749
+ - MXFP8: Requires Blackwell+ (sm120+), torchao with float8 training,
750
+ block-wise scaling (128x128). Best accuracy among FP8 options.
751
+ - NVFP4: Requires Blackwell+ (sm120+), fbgemm-gpu-genai with NVFP4
752
+ kernels. Most aggressive compression (4-bit weights).
753
+ - FP8_DYNAMIC: Requires Ada Lovelace+ (sm89+), torchao float8.
754
+ Row-wise dynamic scaling. Good speed/accuracy tradeoff.
755
+ - BF16: Fallback for all GPUs. Standard mixed precision.
756
+ """
757
+ if not torch.cuda.is_available():
758
+ return PrecisionMode.BF16
759
+
760
+ cc = torch.cuda.get_device_capability()
761
+ major, minor = cc
762
+
763
+ # Check for torchao float8 training support
764
+ _has_torchao_fp8 = False
765
+ try:
766
+ from torchao.float8 import convert_to_float8_training
767
+ _has_torchao_fp8 = True
768
+ except ImportError:
769
+ pass
770
+
771
+ # Check for fbgemm NVFP4 support
772
+ _has_nvfp4 = False
773
+ try:
774
+ from torchao.quantization import NVFP4Config # type: ignore[attr-defined]
775
+ _has_nvfp4 = True
776
+ except (ImportError, AttributeError):
777
+ try:
778
+ import fbgemm_gpu.genai # type: ignore[import-untyped]
779
+ _has_nvfp4 = True
780
+ except (ImportError, ModuleNotFoundError):
781
+ pass
782
+
783
+ # Blackwell+ (sm120+): MXFP8 or NVFP4
784
+ if major >= 12:
785
+ if _has_torchao_fp8:
786
+ return PrecisionMode.MXFP8
787
+ if _has_nvfp4:
788
+ return PrecisionMode.NVFP4
789
+
790
+ # Ada Lovelace+ (sm89+): FP8 dynamic
791
+ if (major, minor) >= (8, 9) and _has_torchao_fp8:
792
+ return PrecisionMode.FP8_DYNAMIC
793
+
794
+ return PrecisionMode.BF16
795
+
796
+
797
+ def configure_fp8_training(model, mode: PrecisionMode):
798
+ """Apply torchao float8 training conversion to model.
799
+
800
+ FP8 training swaps nn.Linear layers with Float8Linear, which performs
801
+ dynamic quantization of activations and weights to float8_e4m3fn during
802
+ forward/backward, with high-precision accumulation.
803
+
804
+ Two recipes:
805
+ - MXFP8 (rowwise_with_gw_hp): Row-wise scaling + high-precision grad weight.
806
+ Best accuracy. Requires sm120+ hardware.
807
+ - FP8_DYNAMIC (rowwise): Row-wise dynamic scaling. Good tradeoff.
808
+ Requires sm89+ hardware.
809
+
810
+ Gradient computation stays in bf16/fp32 for stability.
811
+ """
812
+ from torchao.float8 import convert_to_float8_training, Float8LinearConfig
813
+
814
+ if mode == PrecisionMode.MXFP8:
815
+ recipe_name = "rowwise_with_gw_hp"
816
+ elif mode == PrecisionMode.FP8_DYNAMIC:
817
+ recipe_name = "rowwise"
818
+ else:
819
+ return model
820
+
821
+ base = Float8LinearConfig.from_recipe_name(recipe_name)
822
+ config = Float8LinearConfig(
823
+ cast_config_input=base.cast_config_input,
824
+ cast_config_weight=base.cast_config_weight,
825
+ cast_config_grad_output=base.cast_config_grad_output,
826
+ cast_config_input_for_grad_weight=base.cast_config_input_for_grad_weight,
827
+ cast_config_weight_for_grad_input=base.cast_config_weight_for_grad_input,
828
+ cast_config_grad_output_for_grad_weight=base.cast_config_grad_output_for_grad_weight,
829
+ gemm_config_output=base.gemm_config_output,
830
+ gemm_config_grad_input=base.gemm_config_grad_input,
831
+ gemm_config_grad_weight=base.gemm_config_grad_weight,
832
+ enable_fsdp_float8_all_gather=base.enable_fsdp_float8_all_gather,
833
+ round_scales_to_power_of_2=base.round_scales_to_power_of_2,
834
+ pad_inner_dim=True,
835
+ )
836
+
837
+ def module_filter_fn(mod, fqn):
838
+ skip = any(s in fqn for s in (
839
+ "boundary_predictor",
840
+ "loop_embedding",
841
+ "engram",
842
+ "layernorm",
843
+ "norm",
844
+ "embed_tokens",
845
+ "lm_head",
846
+ "halt_predictor",
847
+ "gate",
848
+ ))
849
+ return not skip
850
+
851
+ model = convert_to_float8_training(
852
+ model,
853
+ module_filter_fn=module_filter_fn,
854
+ config=config,
855
+ )
856
+ return model
857
+
858
+
859
+ def configure_nvfp4_training(model):
860
+ """Apply NVFP4 weight-only quantization for training on Blackwell.
861
+
862
+ NVFP4 uses 4-bit floating-point weights with 8-bit scaling factors.
863
+ Activations stay in bf16/fp8. Requires fbgemm-gpu-genai kernels.
864
+
865
+ Falls back to FP8_DYNAMIC if NVFP4 kernels unavailable.
866
+ """
867
+ try:
868
+ from torchao.quantization import NVFP4Config, quantize_
869
+ quantize_(model, NVFP4Config())
870
+ return model
871
+ except (ImportError, AttributeError, RuntimeError):
872
+ logger.warning("NVFP4 not available, falling back to FP8_DYNAMIC")
873
+ return configure_fp8_training(model, PrecisionMode.FP8_DYNAMIC)
874
+
875
+
876
+ def try_unsloth():
877
+ """Attempt to apply Unsloth patches. Returns (available, FastLanguageModel)."""
878
+ try:
879
+ from unsloth import FastLanguageModel
880
+ return True, FastLanguageModel
881
+ except (ImportError, Exception):
882
+ return False, None
883
+
884
+
885
+ # ============================================================================
886
+ # Main Training Loop
887
+ # ============================================================================
888
+
889
+ def parse_args():
890
+ parser = argparse.ArgumentParser(description="Spider-FLEXITOKENS training")
891
+ parser.add_argument("--resume", type=str, default="", help="Path to checkpoint to resume from")
892
+ parser.add_argument("--max_steps", type=int, default=0, help="Override max training steps")
893
+ parser.add_argument("--mock_data", action="store_true", help="Use mock data (no network)")
894
+ parser.add_argument("--seq_len", type=int, default=0, help="Override sequence length")
895
+ parser.add_argument("--micro_batch", type=int, default=0, help="Override micro batch size")
896
+ parser.add_argument("--n_loops", type=int, default=0, help="Override number of loops")
897
+ parser.add_argument("--lr", type=float, default=0, help="Override learning rate")
898
+ parser.add_argument("--ckpt_dir", type=str, default="checkpoints-spider", help="Checkpoint directory")
899
+ parser.add_argument("--no_unsloth", action="store_true", help="Skip Unsloth even if available")
900
+ parser.add_argument(
901
+ "--precision", type=str, default="auto",
902
+ choices=["auto", "bf16", "fp8_dynamic", "mxfp8", "nvfp4"],
903
+ help="Training precision: auto (detect), bf16, fp8_dynamic, mxfp8, nvfp4",
904
+ )
905
+ return parser.parse_args()
906
+
907
+
908
+ def main():
909
+ global best_loss
910
+ args = parse_args()
911
+
912
+ # ------------------------------------------------------------------
913
+ # Distributed init
914
+ # ------------------------------------------------------------------
915
+ ddp = int(os.environ.get("RANK", -1)) != -1
916
+ if ddp:
917
+ dist.init_process_group("nccl")
918
+ rank = int(os.environ["RANK"])
919
+ local_rank = int(os.environ["LOCAL_RANK"])
920
+ world_size = int(os.environ["WORLD_SIZE"])
921
+ device = f"cuda:{local_rank}"
922
+ torch.cuda.set_device(device)
923
+ else:
924
+ rank = local_rank = 0
925
+ world_size = 1
926
+ device = "cuda" if torch.cuda.is_available() else "cpu"
927
+ master = rank == 0
928
+
929
+ # ------------------------------------------------------------------
930
+ # Hyperparameters
931
+ # ------------------------------------------------------------------
932
+ seq_len = args.seq_len or int(os.environ.get("SEQ_LEN", "2048"))
933
+ micro_batch = args.micro_batch or int(os.environ.get("MICRO_BATCH", "4"))
934
+ target_tokens = int(os.environ.get("TARGET_TOKENS", "10_000_000_000"))
935
+ grad_accum = int(os.environ.get("GRAD_ACCUM", "1"))
936
+ n_loops = args.n_loops or int(os.environ.get("N_LOOPS", "6"))
937
+ lr = args.lr or float(os.environ.get("LR", "3e-4"))
938
+ wd = 0.1
939
+ warmup_steps = 200
940
+ log_every = 10
941
+ ckpt_every = int(os.environ.get("CKPT_EVERY", "500"))
942
+ ckpt_dir = args.ckpt_dir
943
+
944
+ global_batch_tok = world_size * micro_batch * grad_accum * seq_len
945
+ total_steps = target_tokens // global_batch_tok
946
+ if args.max_steps > 0:
947
+ total_steps = min(total_steps, args.max_steps)
948
+
949
+ if master:
950
+ logger.info(
951
+ f"[Spider-FLEXITOKENS] hidden=2048 | 6 recurrent | 32 experts top-2 | "
952
+ f"n_loops={n_loops} | seq_len={seq_len} | micro_batch={micro_batch} | "
953
+ f"grad_accum={grad_accum} | global_batch_tokens={global_batch_tok:,} | "
954
+ f"total_steps={total_steps:,}"
955
+ )
956
+ logger.info(
957
+ f"Byte-level vocab: 272 | Context: 256k (YaRN-8) | "
958
+ f"Sliding window: 8192 | BP curriculum: fixed 30% -> adaptive | "
959
+ f"Gradient checkpointing: enabled | Precision: {prec_mode.value}"
960
+ )
961
+
962
+ # ------------------------------------------------------------------
963
+ # Model + Precision Mode
964
+ # ------------------------------------------------------------------
965
+ cfg = SpiderConfig()
966
+ bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
967
+ amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
968
+
969
+ # Resolve precision mode: CLI override or auto-detect
970
+ if args.precision == "auto":
971
+ prec_mode = detect_precision_mode()
972
+ else:
973
+ prec_mode = PrecisionMode(args.precision)
974
+
975
+ if master:
976
+ logger.info(f"Precision mode: {prec_mode.value}")
977
+
978
+ model = SpiderForConditionalGeneration(cfg).to(amp_dtype)
979
+ model.gradient_checkpointing_enable()
980
+ model.enable_input_require_grads()
981
+
982
+ # Apply FP8/MXFP8/NVFP4 quantization (before Unsloth, before FSDP)
983
+ if prec_mode in (PrecisionMode.MXFP8, PrecisionMode.FP8_DYNAMIC):
984
+ try:
985
+ model = configure_fp8_training(model, prec_mode)
986
+ if master:
987
+ logger.info(f"torchao FP8 training enabled: {prec_mode.value}")
988
+ except Exception as e:
989
+ if master:
990
+ logger.warning(f"FP8 training setup failed ({e}), falling back to BF16")
991
+ prec_mode = PrecisionMode.BF16
992
+ elif prec_mode == PrecisionMode.NVFP4:
993
+ try:
994
+ model = configure_nvfp4_training(model)
995
+ if master:
996
+ logger.info("NVFP4 training enabled")
997
+ except Exception as e:
998
+ if master:
999
+ logger.warning(f"NVFP4 setup failed ({e}), falling back to FP8_DYNAMIC")
1000
+ try:
1001
+ model = configure_fp8_training(model, PrecisionMode.FP8_DYNAMIC)
1002
+ prec_mode = PrecisionMode.FP8_DYNAMIC
1003
+ if master:
1004
+ logger.info("Fallback: FP8_DYNAMIC training enabled")
1005
+ except Exception as e2:
1006
+ if master:
1007
+ logger.warning(f"FP8 fallback also failed ({e2}), using BF16")
1008
+ prec_mode = PrecisionMode.BF16
1009
+
1010
+ # Unsloth (optional, per D-35): applies MoE kernel optimizations,
1011
+ # gradient checkpointing, and memory-efficient attention
1012
+ use_unsloth = False
1013
+ if not args.no_unsloth and not ddp:
1014
+ use_unsloth_available, FastLanguageModel_cls = try_unsloth()
1015
+ if use_unsloth_available:
1016
+ try:
1017
+ # Unsloth patches: SDPA optimization, memory-efficient GC
1018
+ # For MoE: set UNSLOTH_MOE_BACKEND=grouped_mm (default)
1019
+ os.environ.setdefault("UNSLOTH_MOE_BACKEND", "grouped_mm")
1020
+ use_unsloth = True
1021
+ if master:
1022
+ logger.info("Unsloth MoE + training patches applied")
1023
+ except Exception as e:
1024
+ if master:
1025
+ logger.warning(f"Unsloth patching failed: {e}")
1026
+ if not use_unsloth and master:
1027
+ logger.info("Unsloth not available, using standard PyTorch training")
1028
+
1029
+ if ddp:
1030
+ from torch.distributed.fsdp import (
1031
+ FullyShardedDataParallel as FSDP,
1032
+ ShardingStrategy,
1033
+ MixedPrecision,
1034
+ )
1035
+ from torch.distributed.fsdp.wrap import ModuleWrapPolicy
1036
+ from spider import SpiderDenseLayer, SpiderRecurrentLayer
1037
+
1038
+ mp_policy = MixedPrecision(
1039
+ param_dtype=amp_dtype,
1040
+ reduce_dtype=amp_dtype,
1041
+ buffer_dtype=amp_dtype,
1042
+ )
1043
+ wrap_policy = ModuleWrapPolicy({SpiderDenseLayer, SpiderRecurrentLayer})
1044
+ model = FSDP(
1045
+ model,
1046
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
1047
+ mixed_precision=mp_policy,
1048
+ auto_wrap_policy=wrap_policy,
1049
+ device_id=local_rank,
1050
+ )
1051
+ else:
1052
+ model = model.to(device)
1053
+
1054
+ if master:
1055
+ n_params = sum(p.numel() for p in model.parameters())
1056
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
1057
+ logger.info(
1058
+ f"Parameters: {n_params:,} total | {trainable:,} trainable | "
1059
+ f"Precision: {prec_mode.value} | AMP dtype: {amp_dtype}"
1060
+ )
1061
+
1062
+ # ------------------------------------------------------------------
1063
+ # Optimizer — 8-bit Adam for BF16 on small GPUs; standard AdamW for FP8+
1064
+ # When FP8/MXFP8/NVFP4 is active, weight memory is already halved,
1065
+ # so 8-bit Adam is less critical and can conflict with Float8Linear.
1066
+ # Dual optimizer for Engram embeddings (per mythos pattern).
1067
+ # ------------------------------------------------------------------
1068
+ engram_params_list = [
1069
+ p for n, p in model.named_parameters()
1070
+ if 'engram' in n and 'embed' in n and 'proj' not in n
1071
+ ]
1072
+ backbone_params = [
1073
+ p for n, p in model.named_parameters()
1074
+ if not ('engram' in n and 'embed' in n and 'proj' not in n)
1075
+ ]
1076
+
1077
+ use_8bit_optimizer = _HAS_8BIT and prec_mode == PrecisionMode.BF16
1078
+
1079
+ if use_8bit_optimizer:
1080
+ optimizer = AdamW8bit(
1081
+ backbone_params, lr=lr, weight_decay=wd,
1082
+ betas=(0.9, 0.95), eps=1e-8,
1083
+ )
1084
+ if engram_params_list:
1085
+ engram_optimizer = Adam8bit(
1086
+ engram_params_list, lr=lr * 5,
1087
+ betas=(0.9, 0.95), eps=1e-8,
1088
+ )
1089
+ else:
1090
+ engram_optimizer = None
1091
+ if master:
1092
+ logger.info("Optimizer: 8-bit AdamW (bf16 mode, saves ~50% optimizer VRAM)")
1093
+ else:
1094
+ optimizer = torch.optim.AdamW(
1095
+ backbone_params, lr=lr, weight_decay=wd,
1096
+ betas=(0.9, 0.95), foreach=True, eps=1e-8,
1097
+ )
1098
+ if engram_params_list:
1099
+ engram_optimizer = torch.optim.Adam(
1100
+ engram_params_list, lr=lr * 5,
1101
+ betas=(0.9, 0.95), eps=1e-8,
1102
+ )
1103
+ else:
1104
+ engram_optimizer = None
1105
+ if master:
1106
+ logger.info(f"Optimizer: standard AdamW ({prec_mode.value} mode)")
1107
+
1108
+ # ------------------------------------------------------------------
1109
+ # Curriculum + Monitor + Trainer
1110
+ # ------------------------------------------------------------------
1111
+ curriculum = CurriculumScheduler(total_steps=total_steps)
1112
+ monitor = RecurrentMonitor()
1113
+ trainer = BPCurriculumTrainer(
1114
+ model=model,
1115
+ optimizer=optimizer,
1116
+ engram_optimizer=engram_optimizer,
1117
+ curriculum=curriculum,
1118
+ monitor=monitor,
1119
+ warmup_steps=warmup_steps,
1120
+ base_lr=lr,
1121
+ )
1122
+
1123
+ # ------------------------------------------------------------------
1124
+ # Resume from checkpoint
1125
+ # ------------------------------------------------------------------
1126
+ start_step = 0
1127
+ start_epoch = 1
1128
+ bp_optim_state_to_load = None
1129
+ if args.resume and os.path.exists(args.resume):
1130
+ if master:
1131
+ logger.info(f"Resuming from checkpoint: {args.resume}")
1132
+ start_step, start_epoch, bp_optim_state_to_load, saved_best = load_checkpoint(
1133
+ model, optimizer, args.resume, ddp
1134
+ )
1135
+ best_loss = saved_best
1136
+ if master:
1137
+ logger.info(f"Resumed at step {start_step}, epoch {start_epoch}, best_loss={best_loss:.4f}")
1138
+ else:
1139
+ # Auto-resume from latest checkpoint in ckpt_dir
1140
+ existing_ckpts = sorted(
1141
+ [os.path.join(ckpt_dir, f) for f in os.listdir(ckpt_dir)
1142
+ if f.startswith("spider-") and f.endswith(".pt") and not f.endswith(".tmp")]
1143
+ ) if os.path.isdir(ckpt_dir) else []
1144
+ if existing_ckpts:
1145
+ latest = existing_ckpts[-1]
1146
+ if master:
1147
+ logger.info(f"Auto-resuming from: {latest}")
1148
+ start_step, start_epoch, bp_optim_state_to_load, saved_best = load_checkpoint(
1149
+ model, optimizer, latest, ddp
1150
+ )
1151
+ best_loss = saved_best
1152
+ if master:
1153
+ logger.info(f"Resumed at step {start_step}, epoch {start_epoch}, best_loss={best_loss:.4f}")
1154
+
1155
+ # Restore BP optimizer state if available (after trainer is created,
1156
+ # BP optimizer is initialized during first unfreeze_bp() call)
1157
+ if bp_optim_state_to_load and trainer.bp_optimizer:
1158
+ try:
1159
+ trainer.bp_optimizer.load_state_dict(bp_optim_state_to_load)
1160
+ if master:
1161
+ logger.info("Restored BP optimizer state from checkpoint")
1162
+ except (ValueError, KeyError, RuntimeError) as e:
1163
+ if master:
1164
+ logger.warning(f"BP optimizer state mismatch, skipping: {e}")
1165
+
1166
+ # ------------------------------------------------------------------
1167
+ # Dataset + DataLoader
1168
+ # ------------------------------------------------------------------
1169
+ if args.mock_data:
1170
+ dataset = MockByteLevelDataset(seq_len=seq_len)
1171
+ else:
1172
+ dataset = ByteLevelDataset(
1173
+ seq_len=seq_len,
1174
+ rank=rank,
1175
+ world_size=world_size,
1176
+ )
1177
+
1178
+ loader = DataLoader(
1179
+ dataset,
1180
+ batch_size=micro_batch,
1181
+ num_workers=4 if not args.mock_data else 0,
1182
+ pin_memory=True,
1183
+ prefetch_factor=1 if not args.mock_data else None,
1184
+ )
1185
+
1186
+ # ------------------------------------------------------------------
1187
+ # AMP + SDPA contexts
1188
+ # ------------------------------------------------------------------
1189
+ amp_ctx = (
1190
+ torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
1191
+ if "cuda" in device
1192
+ else nullcontext()
1193
+ )
1194
+ amp_ctx = nullcontext() if ddp else amp_ctx
1195
+
1196
+ try:
1197
+ from torch.nn.attention import sdpa_kernel
1198
+ sdpa_ctx = sdpa_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=True)
1199
+ except Exception:
1200
+ sdpa_ctx = nullcontext()
1201
+
1202
+ # ------------------------------------------------------------------
1203
+ # Training loop
1204
+ # ------------------------------------------------------------------
1205
+ if master:
1206
+ os.makedirs(ckpt_dir, exist_ok=True)
1207
+
1208
+ model.train()
1209
+ data_iter = iter(loader)
1210
+ t0 = time.perf_counter()
1211
+ step = start_step
1212
+ epoch = start_epoch
1213
+ tokens_in_epoch = 0
1214
+ tokens_per_epoch = target_tokens
1215
+
1216
+ while step < total_steps:
1217
+ cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
1218
+ for g in optimizer.param_groups:
1219
+ g["lr"] = cur_lr
1220
+ if engram_optimizer:
1221
+ for g in engram_optimizer.param_groups:
1222
+ g["lr"] = cur_lr * 5
1223
+
1224
+ optimizer.zero_grad()
1225
+ if engram_optimizer:
1226
+ engram_optimizer.zero_grad()
1227
+ if trainer.bp_optimizer:
1228
+ trainer.bp_optimizer.zero_grad()
1229
+ loss_accum = 0.0
1230
+ metrics_accum = {}
1231
+
1232
+ for micro_step in range(grad_accum):
1233
+ try:
1234
+ x, y = next(data_iter)
1235
+ except StopIteration:
1236
+ data_iter = iter(loader)
1237
+ x, y = next(data_iter)
1238
+
1239
+ x = x.to(device, non_blocking=True)
1240
+ y = y.to(device, non_blocking=True)
1241
+
1242
+ sync = (
1243
+ nullcontext()
1244
+ if (not ddp or micro_step == grad_accum - 1)
1245
+ else model.no_sync()
1246
+ )
1247
+ with sync:
1248
+ total_loss, metrics = trainer.train_step(
1249
+ x, y, step, n_loops=n_loops,
1250
+ amp_ctx=amp_ctx, sdpa_ctx=sdpa_ctx,
1251
+ )
1252
+ total_loss = total_loss / grad_accum
1253
+ total_loss.backward()
1254
+
1255
+ if master and step == start_step and micro_step == 0:
1256
+ peak_vram = torch.cuda.max_memory_allocated() / 1024**3
1257
+ logger.info(f"First forward+backward | Peak VRAM: {peak_vram:.1f}GB")
1258
+
1259
+ loss_accum += total_loss.item()
1260
+ for k, v in metrics.items():
1261
+ if k not in metrics_accum:
1262
+ metrics_accum[k] = 0.0
1263
+ if isinstance(v, (int, float)):
1264
+ metrics_accum[k] += v / grad_accum
1265
+
1266
+ # Gradient clipping
1267
+ grad_norm = trainer.clip_gradients()
1268
+ optimizer.step()
1269
+ if engram_optimizer:
1270
+ engram_optimizer.step()
1271
+ if trainer.bp_optimizer:
1272
+ for g in trainer.bp_optimizer.param_groups:
1273
+ g["lr"] = cur_lr * 0.1
1274
+ trainer.bp_optimizer.step()
1275
+ step += 1
1276
+ tokens_in_epoch += global_batch_tok
1277
+
1278
+ # Health checks
1279
+ if master and step % log_every == 0:
1280
+ health_warnings = monitor.check_health(metrics_accum, step)
1281
+ for w in health_warnings:
1282
+ logger.warning(w)
1283
+
1284
+ # Logging
1285
+ if master and step % log_every == 0:
1286
+ dt = time.perf_counter() - t0
1287
+ tok_per_sec = global_batch_tok * log_every / dt
1288
+ tokens_seen = step * global_batch_tok
1289
+ bp_status = "FIXED" if metrics_accum.get('is_fixed_bp', True) else "ADAPTIVE"
1290
+ bp_frozen = "FROZEN" if metrics_accum.get('bp_frozen', False) else "ACTIVE"
1291
+ logger.info(
1292
+ f"Epoch {epoch} | step {step:6d}/{total_steps} | "
1293
+ f"loss {loss_accum:.4f} | lm {metrics_accum.get('lm_loss', 0):.4f} | "
1294
+ f"aux {metrics_accum.get('aux_loss', 0):.4f} | "
1295
+ f"bp {metrics_accum.get('bp_loss', 0):.4f} [{bp_status}/{bp_frozen}] | "
1296
+ f"gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} | "
1297
+ f"{tok_per_sec / 1e6:.2f}M tok/s | {tokens_seen / 1e9:.2f}B tokens"
1298
+ )
1299
+ t0 = time.perf_counter()
1300
+
1301
+ # Checkpointing
1302
+ if step % ckpt_every == 0:
1303
+ ckpt_path, size_mb = save_step_checkpoint(
1304
+ model, optimizer, step, epoch, cfg, ckpt_dir, master, ddp, trainer,
1305
+ current_best_loss=best_loss,
1306
+ )
1307
+ if master and ckpt_path:
1308
+ logger.info(f"Saved step checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1309
+
1310
+ # Epoch boundary
1311
+ if tokens_in_epoch >= tokens_per_epoch:
1312
+ epoch_loss = loss_accum
1313
+ if master:
1314
+ logger.info(f"Epoch {epoch} complete | loss={epoch_loss:.4f}")
1315
+ ckpt_path, size_mb = save_full_checkpoint(
1316
+ model, optimizer, step, epoch, cfg, ckpt_dir, master, ddp, f"ep{epoch}", trainer,
1317
+ current_best_loss=best_loss,
1318
+ )
1319
+ if master and ckpt_path:
1320
+ logger.info(f"Saved epoch checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1321
+
1322
+ if epoch_loss < best_loss:
1323
+ best_loss = epoch_loss
1324
+ ckpt_path, size_mb = save_full_checkpoint(
1325
+ model, optimizer, step, epoch, cfg, ckpt_dir, master, ddp, "best", trainer,
1326
+ current_best_loss=best_loss,
1327
+ )
1328
+ if master and ckpt_path:
1329
+ logger.info(f"Saved best checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1330
+
1331
+ epoch += 1
1332
+ tokens_in_epoch = 0
1333
+
1334
+ # Final checkpoint
1335
+ if step > start_step and master:
1336
+ ckpt_path, size_mb = save_full_checkpoint(
1337
+ model, optimizer, step, epoch, cfg, ckpt_dir, master, ddp, f"final-ep{epoch}", trainer,
1338
+ current_best_loss=best_loss,
1339
+ )
1340
+ if ckpt_path:
1341
+ logger.info(f"Saved final checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1342
+
1343
+ if ddp:
1344
+ dist.barrier()
1345
+ dist.destroy_process_group()
1346
+
1347
+ if master:
1348
+ logger.info("Training complete.")
1349
+
1350
+
1351
+ if __name__ == "__main__":
1352
+ best_loss = float("inf")
1353
+ main()