| |
| """Spider-FLEXITOKENS training pipeline. |
| |
| Byte-level pretraining on FineWeb-Edu with boundary predictor curriculum. |
| Architecture: RDT (2 prelude + 6 recurrent + 2 coda) with: |
| - SharedProjectionMoE (32 experts, top-2, shared_inter=6144, rank=256) |
| - MLA (Multi-Latent Attention) with compressed KV cache + sliding window |
| - Engram conditional memory at recurrent layers 1 and 4 |
| - BoundaryPredictor + downsample/upsample for FlexiToken integration |
| - LTI Injection + ACT Halting + LoRA Adapter |
| - 256k context (YaRN factor=8.0), sliding_window=8192 |
| - 272-token byte-level vocab (256 bytes + 16 specials) |
| |
| Usage: |
| Single GPU: |
| python train_spider.py |
| Multi-GPU: |
| torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") train_spider.py |
| Resume from checkpoint: |
| python train_spider.py --resume checkpoints/spider-step5000.pt |
| Quick smoke test: |
| python train_spider.py --max_steps 50 --mock_data |
| """ |
|
|
| import os |
| import math |
| import re |
| import sys |
| import time |
| import argparse |
| from contextlib import nullcontext |
| from dataclasses import dataclass, field |
| from typing import Dict, List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.distributed as dist |
| from torch.nn import CrossEntropyLoss |
| from torch.utils.data import IterableDataset, DataLoader, get_worker_info |
|
|
| from datasets import load_dataset |
|
|
| try: |
| import bitsandbytes as bnb |
| AdamW8bit = bnb.optim.AdamW8bit |
| Adam8bit = bnb.optim.Adam8bit |
| _HAS_8BIT = True |
| except ImportError: |
| _HAS_8BIT = False |
| AdamW8bit = None |
| Adam8bit = None |
|
|
| from spider import ( |
| SpiderConfig, |
| SpiderForConditionalGeneration, |
| SENTINEL_TOKENS, |
| ) |
|
|
# Logging: prefer loguru (stderr plus a rotating file sink); when loguru is not
# installed, fall back to a tiny stdlib-backed shim exposing the same methods.
try:
    from loguru import logger

    _LOG_FORMAT = "{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}"
    logger.remove()  # drop loguru's default handler before installing ours
    logger.add(sys.stderr, format=_LOG_FORMAT)
    logger.add(
        "train_spider.log",
        rotation="100 MB",
        retention="10 days",
        format=_LOG_FORMAT,
    )
except ImportError:
    import logging

    logging.basicConfig(level=logging.INFO)

    class _LoguruShim:
        """Stand-in for the loguru methods (info/success/warning/error) this script calls."""

        def info(self, msg):
            logging.info(msg)

        def success(self, msg):
            # The stdlib has no SUCCESS level; report it at INFO.
            logging.info(msg)

        def warning(self, msg):
            logging.warning(msg)

        def error(self, msg):
            logging.error(msg)

    logger = _LoguruShim()


# Configure the CUDA caching allocator unless the user already set it:
# setdefault never overrides an existing value.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
|
|
|
| |
| |
| |
|
|
# Sentinel token ids from the model package's SENTINEL_TOKENS table.
# BOS/EOS frame every encoded sample; label positions equal to PAD_ID are
# masked to -100 by the datasets below.
BOS_ID, EOS_ID, PAD_ID = (SENTINEL_TOKENS[key] for key in ("BOS", "EOS", "PAD"))
|
|
|
|
class ByteLevelDataset(IterableDataset):
    """Streaming byte-level dataset from FineWeb-Edu.

    Per D-23: FineWeb-Edu (English first), per-sample UTF-8 byte encoding.
    Per D-34: Streaming only, no local download.
    NOTE(review): the D-24 curriculum ordering (English -> multilingual ->
    code -> math) is not implemented by this class; it streams only the single
    configured ``dataset_name`` / ``subset`` / ``split``.

    Each non-empty ``text`` field is encoded as raw UTF-8 bytes truncated to
    ``max_bytes`` (default 8192) and framed as ``[BOS] + bytes + [EOS]``.
    Encoded samples are packed back-to-back into one stream, which is cut into
    ``seq_len + 1`` token chunks; each chunk yields ``(x, y)`` with
    ``x = chunk[:-1]`` and ``y = chunk[1:]`` as ``torch.long`` tensors.
    Vocab: 272 tokens (256 bytes + 16 specials).

    Args:
        dataset_name: Hugging Face dataset id passed to ``load_dataset``.
        subset: dataset config name (``name=`` in ``load_dataset``).
        split: dataset split to stream.
        seq_len: length of each yielded ``x`` / ``y`` tensor.
        max_bytes: per-sample byte budget before BOS/EOS are added.
        rank: this process's rank, used to pick a stream shard.
        world_size: total number of ranks, used to pick a stream shard.
    """

    def __init__(
        self,
        dataset_name: str = "HuggingFaceFW/fineweb-edu",
        subset: str = "sample-10BT",
        split: str = "train",
        seq_len: int = 8192,
        max_bytes: int = 8192,
        rank: int = 0,
        world_size: int = 1,
    ):
        self.seq_len = seq_len
        self.max_bytes = max_bytes
        self.dataset_name = dataset_name
        self.subset = subset
        self.split = split
        self.rank = rank
        self.world_size = world_size

    def _encode_sample(self, text: str) -> List[int]:
        """Encode text as UTF-8 bytes with BOS/EOS, truncated to max_bytes.

        Truncation is byte-level and may split a multi-byte character.
        """
        # Slice the bytes object *before* converting to ints so long documents
        # don't materialise a Python int list that is mostly thrown away.
        return [BOS_ID, *text.encode("utf-8")[: self.max_bytes], EOS_ID]

    def _pad_or_truncate(self, ids: List[int]) -> List[int]:
        """Truncate ``ids`` to ``seq_len`` and right-pad with PAD_ID.

        Only the ids are padded; no labels are built and no -100 masking is
        applied here. Not called by this class (packing in ``__iter__`` never
        pads).
        """
        ids = ids[: self.seq_len]
        return ids + [PAD_ID] * (self.seq_len - len(ids))

    def __iter__(self):
        # Every (rank, dataloader worker) pair requests its own shard index out
        # of world_size * num_workers shards.
        worker = get_worker_info()
        num_workers = worker.num_workers if worker else 1
        worker_id = worker.id if worker else 0
        total_shards = self.world_size * num_workers
        shard_index = self.rank * num_workers + worker_id

        ds = load_dataset(
            self.dataset_name,
            name=self.subset,
            split=self.split,
            streaming=True,
        ).shard(num_shards=total_shards, index=shard_index)

        chunk_len = self.seq_len + 1
        buf: List[int] = []
        for sample in ds:
            text = sample.get("text", "")
            if not text:
                continue
            buf.extend(self._encode_sample(text))
            # NOTE: consecutive chunks do not overlap, so the transition from one
            # chunk's last token to the next chunk's first token is never a
            # training pair.
            while len(buf) >= chunk_len:
                chunk = buf[:chunk_len]
                del buf[:chunk_len]  # trim in place instead of rebuilding the list
                x = torch.tensor(chunk[:-1], dtype=torch.long)
                y = torch.tensor(chunk[1:], dtype=torch.long)
                # Packing inserts no PAD tokens, so this is presumably a no-op
                # safeguard; kept so PAD targets never contribute to the loss.
                y[y == PAD_ID] = -100
                yield x, y
|
|
|
|
class MockByteLevelDataset(IterableDataset):
    """In-memory byte-level dataset for testing (no network required).

    Cycles over a fixed set of text samples (English, Russian, Telugu, Chinese,
    code, math) and packs them the same way as the streaming dataset: each
    sample becomes ``[BOS] + utf8_bytes[:max_bytes] + [EOS]``, the samples are
    concatenated, and the stream is cut into ``seq_len + 1`` token chunks that
    yield ``(x, y) = (chunk[:-1], chunk[1:])`` as ``torch.long`` tensors.
    Iteration stops after ``num_samples`` chunks have been yielded.
    """

    SAMPLES = [
        "Hello world, this is a test of the byte-level encoding system.",
        "The quick brown fox jumps over the lazy dog.",
        "Spider is a recurrent latent reasoning architecture with engram memory.",
        "Boundary predictors learn to merge byte sequences into meaningful tokens.",
        "FineWeb-Edu contains high-quality educational content for pretraining.",
        # Russian: "This is Russian text for testing multilingual support."
        "Это текст на русском языке для проверки многоязычной поддержки.",
        "తెలుగు భాష యొక్క పరీక్ష కోసం నమూనా వచనం.",
        "中文文本用于测试多语言字节编码支持。",
        "def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)",
        "The integral of x^2 from 0 to 1 equals 1/3.",
    ]

    def __init__(self, seq_len: int = 512, max_bytes: int = 512, num_samples: int = 1000):
        self.seq_len = seq_len
        self.max_bytes = max_bytes
        self.num_samples = num_samples

    def __iter__(self):
        chunk_len = self.seq_len + 1
        buf: List[int] = []
        count = 0
        while count < self.num_samples:
            for text in self.SAMPLES:
                # Slice the bytes before widening them to Python ints.
                buf.append(BOS_ID)
                buf.extend(text.encode("utf-8")[: self.max_bytes])
                buf.append(EOS_ID)
                while len(buf) >= chunk_len:
                    chunk = buf[:chunk_len]
                    del buf[:chunk_len]
                    x = torch.tensor(chunk[:-1], dtype=torch.long)
                    y = torch.tensor(chunk[1:], dtype=torch.long)
                    y[y == PAD_ID] = -100
                    yield x, y
                    count += 1
                    if count >= self.num_samples:
                        return
|
|
|
|
| |
| |
| |
|
|
class CurriculumScheduler:
    """Step-based training curriculum (D-24, D-25).

    Maps a training step to a curriculum phase (1-5) and turns soft boundary
    probabilities into hard boundaries. Before 30% of ``total_steps`` the hard
    boundaries are the fixed top-k positions; afterwards they come from
    thresholding at ``adaptive_threshold``.

    Phases (fractions of ``total_steps``):
        1: 0-30%   English (FineWeb-Edu), fixed top-k BP (D-25)
        2: 30-50%  English + multilingual, adaptive BP
        3: 50-70%  English + multilingual + code, adaptive BP
        4: 70-90%  English + multilingual + code + math, adaptive BP
        5: 90-100% Mixed + multimodal, adaptive BP
    This class only reports the phase number; it does not switch datasets.
    """

    # Upper bounds (as fractions of total_steps) of phases 1..4; phase 5 is the rest.
    _PHASE_FRACTIONS = (0.3, 0.5, 0.7, 0.9)

    def __init__(
        self,
        total_steps: int,
        fixed_compression_k: float = 3.3,
        adaptive_threshold: float = 0.5,
    ):
        self.total_steps = total_steps
        self.fixed_compression_k = fixed_compression_k
        self.adaptive_threshold = adaptive_threshold
        # End of the fixed top-k BP regime (start of phase 2).
        self.curriculum_switch_step = int(0.3 * total_steps)

    def get_phase(self, step: int) -> int:
        """Return the 1-based curriculum phase for ``step``."""
        for phase, fraction in enumerate(self._PHASE_FRACTIONS, start=1):
            if step < int(fraction * self.total_steps):
                return phase
        return len(self._PHASE_FRACTIONS) + 1

    def is_fixed_bp(self, step: int) -> bool:
        """True while the BP should use fixed top-k boundaries (D-25)."""
        return step < self.curriculum_switch_step

    def get_fixed_k(self, seq_len: int) -> int:
        """Boundary count for fixed top-k: ``seq_len / fixed_compression_k``, at least 1."""
        return max(1, int(seq_len / self.fixed_compression_k))

    def get_boundaries(
        self,
        soft_boundaries: torch.Tensor,
        step: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Derive hard boundaries from ``soft_boundaries`` for the given step.

        Fixed regime: the ``get_fixed_k(L)`` highest-probability positions along
        the last axis are set to 1. Adaptive regime: positions whose
        probability exceeds ``adaptive_threshold`` are set to 1. In both cases
        a straight-through estimator is applied, so the forward value is the
        hard decision while gradients flow to ``soft_boundaries`` unchanged.

        Args:
            soft_boundaries: [B, L] boundary probabilities from BoundaryPredictor
            step: current training step

        Returns:
            ``(soft_boundaries, hard_boundaries)``, both [B, L].
        """
        if not self.is_fixed_bp(step):
            hard = (soft_boundaries > self.adaptive_threshold).float()
        else:
            k = self.get_fixed_k(soft_boundaries.shape[-1])
            _, top_positions = soft_boundaries.topk(k, dim=-1)
            hard = torch.zeros_like(soft_boundaries)
            hard.scatter_(-1, top_positions, 1.0)

        # Straight-through estimator: value of `hard`, gradient of `soft_boundaries`.
        hard = hard - soft_boundaries.detach() + soft_boundaries
        return soft_boundaries, hard
|
|
|
|
| |
| |
| |
|
|
def compute_bp_loss(
    soft_boundaries: torch.Tensor,
    hard_boundaries: torch.Tensor,
    seq_len: int,
    binomial_weight: float = 0.1,
    pred_prior: float = 0.303,
    bce_target: float = 1.0 / 3.3,
) -> torch.Tensor:
    """Boundary-predictor loss (D-26): BCE toward a target rate + binomial prior.

    Both terms are always computed; this function has no notion of curriculum
    phase (the phase-dependent behavior earlier documented here was never
    implemented; callers decide whether to call it at all).

    - BCE: every soft probability is pulled toward the constant ``bce_target``
      (default 1/3.3, i.e. about one boundary every 3.3 bytes).
    - Binomial prior: the per-sequence count of hard boundaries is scored under
      ``Binomial(seq_len, pred_prior)``; its mean negative log-likelihood is
      divided by ``seq_len``.

    Args:
        soft_boundaries: [B, L] boundary probabilities in [0, 1]
        hard_boundaries: [B, L] boundary decisions (may carry straight-through
            gradients, in which case the binomial term back-propagates into them)
        seq_len: number of Binomial trials (normally L)
        binomial_weight: weight of the binomial term (0.1 per D-26)
        pred_prior: Binomial success probability (1/3.3 ≈ 0.303)
        bce_target: constant BCE target rate; defaults to the previously
            hard-coded 1/3.3

    Returns:
        Scalar tensor ``bce + binomial_weight * binomial_nll / seq_len``.
    """
    target = torch.full_like(soft_boundaries, bce_target)
    bce_loss = F.binary_cross_entropy(soft_boundaries, target)

    boundary_counts = hard_boundaries.sum(dim=-1)
    binomial = torch.distributions.binomial.Binomial(
        total_count=float(seq_len),
        probs=pred_prior,
    )
    binomial_loss = -binomial.log_prob(boundary_counts).mean() / seq_len

    return bce_loss + binomial_weight * binomial_loss
|
|
|
|
| |
| |
| |
|
|
class RecurrentMonitor:
    """Monitors recurrent dynamics across loops during training.

    Catches representation drift and expert collapse before they corrupt
    training. Per CONTEXT: representation drift across loops is the #1 failure
    mode for recurrent architectures.

    ``check_health`` inspects these metric keys:
    - loop_norms: L2 norm of hidden states after each loop (drift detection)
    - routing_entropy: entropy of expert routing per loop (collapse detection)
    Other keys (e.g. engram_norms, halt_distribution, loop_grad_norms) are
    accepted but not checked here — presumably they are only logged by the caller.

    Args:
        drift_threshold: warn when last-loop norm / first-loop norm exceeds this
        collapse_threshold: warn when any routing entropy falls below this
    """

    def __init__(
        self,
        drift_threshold: float = 10.0,
        collapse_threshold: float = 1.0,
    ):
        self.drift_threshold = drift_threshold
        self.collapse_threshold = collapse_threshold

    def compute_routing_entropy(self, router_logits: torch.Tensor) -> float:
        """Entropy of the average routing distribution over all tokens.

        Softmax is taken over the last (expert) axis, probabilities are averaged
        over every leading axis, and the entropy of that mean distribution is
        returned. Works for ``[B, L, E]``, ``[N, E]`` and ``[E]`` logits; the
        previous ``mean(dim=(0, 1))`` only handled 3-D input and collapsed the
        expert axis of 2-D input.

        Args:
            router_logits: [..., num_experts] raw router logits

        Returns:
            Scalar entropy in nats (higher = more diverse routing; the maximum is
            ``log(num_experts)``).
        """
        # fp32 softmax: bf16 logits lose too much precision for a stable entropy.
        probs = F.softmax(router_logits.float(), dim=-1)
        mean_probs = probs.reshape(-1, probs.shape[-1]).mean(dim=0)
        return -(mean_probs * (mean_probs + 1e-10).log()).sum().item()

    def check_health(self, metrics: Dict, step: int) -> List[str]:
        """Check for loop-norm drift and routing collapse.

        Args:
            metrics: dict that may contain 'loop_norms' and 'routing_entropy'
                (sequences of floats); missing keys are treated as empty
            step: current training step (used in the messages)

        Returns:
            List of warning strings (empty if healthy).
        """
        warnings = []

        norms = metrics.get('loop_norms', [])
        if len(norms) >= 2 and norms[0] > 0:
            drift_ratio = norms[-1] / norms[0]
            if drift_ratio > self.drift_threshold:
                warnings.append(
                    f"DRIFT WARNING step {step}: loop norm ratio {drift_ratio:.1f}x "
                    f"(loop_1={norms[0]:.2f}, loop_{len(norms)}={norms[-1]:.2f})"
                )

        entropies = metrics.get('routing_entropy', [])
        if entropies:
            lowest = min(entropies)
            if lowest < self.collapse_threshold:
                warnings.append(
                    f"COLLAPSE WARNING step {step}: routing entropy {lowest:.2f} "
                    f"< threshold {self.collapse_threshold}"
                )

        return warnings
|
|
|
|
| |
| |
| |
|
|
class BPCurriculumTrainer:
    """Training wrapper for Spider-FLEXITOKENS with the boundary-predictor curriculum.

    Manages:
    - BP freeze during warmup and unfreeze at 0.1x base LR afterwards (D-27).
      The state is derived from the step, so resumed runs end up in the same
      state as an uninterrupted run.
    - Fixed -> adaptive boundary curriculum via the CurriculumScheduler (D-25).
    - Combined loss: LM CE + router aux + BP (BCE + binomial prior) (D-26).
    - Global gradient clipping plus a tighter joint clip on expert-core params.
    The ``monitor`` is stored for callers; no method here invokes it.
    """

    def __init__(
        self,
        model: "SpiderForConditionalGeneration",
        optimizer: torch.optim.Optimizer,
        engram_optimizer: Optional[torch.optim.Optimizer],
        curriculum: "CurriculumScheduler",
        monitor: "RecurrentMonitor",
        warmup_steps: int,
        base_lr: float,
        bp_loss_weight: float = 0.1,
        grad_clip: float = 1.0,
        expert_core_grad_clip: float = 0.5,
    ):
        self.model = model
        self.optimizer = optimizer
        self.engram_optimizer = engram_optimizer
        self.curriculum = curriculum
        self.monitor = monitor
        self.warmup_steps = warmup_steps
        self.base_lr = base_lr
        self.bp_loss_weight = bp_loss_weight
        self.grad_clip = grad_clip
        self.expert_core_grad_clip = expert_core_grad_clip
        self._bp_frozen = False    # True while BP params have requires_grad=False
        self._bp_unfrozen = False  # True once unfreeze_bp() has run
        self.bp_optimizer = None   # dedicated BP optimizer, created by unfreeze_bp()

    def freeze_bp(self):
        """Freeze params whose name contains 'boundary_predictor' (D-27)."""
        for name, param in self.model.named_parameters():
            if 'boundary_predictor' in name:
                param.requires_grad = False
        self._bp_frozen = True

    def unfreeze_bp(self):
        """Unfreeze BP params and give them an Adam optimizer at 0.1x base LR (D-27).

        If no parameter name contains 'boundary_predictor' (presumably the case
        under FSDP with flattened parameters), no BP optimizer is created:
        ``torch.optim.Adam`` rejects an empty parameter list.

        NOTE(review): main() (as visible) puts every non-engram-embedding
        parameter, BP params included, into the backbone optimizer, so once BP
        params have gradients both optimizers may update them — confirm intent.
        """
        bp_params = [
            param
            for name, param in self.model.named_parameters()
            if 'boundary_predictor' in name
        ]
        for param in bp_params:
            param.requires_grad = True
        self._bp_frozen = False
        self._bp_unfrozen = True

        if not bp_params:
            self.bp_optimizer = None
            return
        self.bp_optimizer = torch.optim.Adam(
            bp_params, lr=self.base_lr * 0.1, betas=(0.9, 0.95), eps=1e-8
        )

    def _sync_bp_freeze(self, step: int) -> None:
        """Bring the BP freeze state in line with ``step``.

        Fresh runs behave as before: freeze at the first step, unfreeze once
        ``step >= warmup_steps``. A run resumed mid-warmup now also freezes, and
        a run resumed after warmup now also gets its BP optimizer. With
        ``warmup_steps <= 0`` nothing happens unless the BP was frozen manually.
        """
        if step >= self.warmup_steps:
            if self._bp_frozen or (self.warmup_steps > 0 and not self._bp_unfrozen):
                self.unfreeze_bp()
        elif not self._bp_frozen and not self._bp_unfrozen:
            # step < warmup_steps here, which implies warmup_steps > 0.
            self.freeze_bp()

    def train_step(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        step: int,
        n_loops: int = 6,
        amp_ctx: Optional[nullcontext] = None,
        sdpa_ctx: Optional[nullcontext] = None,
    ) -> Tuple[torch.Tensor, Dict]:
        """Run the forward pass and assemble the combined training loss.

        Args:
            input_ids: [B, L] byte-level token IDs
            labels: [B, L] target token IDs (-100 = ignored)
            step: current training step
            n_loops: number of recurrent loops passed to the model
            amp_ctx: optional autocast context (defaults to a no-op)
            sdpa_ctx: optional SDPA kernel context (defaults to a no-op)

        Returns:
            ``(total_loss, metrics)``: ``total_loss = lm + router_aux_loss_coef *
            aux + bp_loss_weight * bp``, and a dict of Python scalars/flags.
        """
        amp_ctx = amp_ctx or nullcontext()
        sdpa_ctx = sdpa_ctx or nullcontext()

        self._sync_bp_freeze(step)

        with amp_ctx, sdpa_ctx:
            output = self.model(input_ids, labels=labels, n_loops=n_loops)

        lm_loss = output['loss']
        aux_loss = output['aux_loss']
        # The model's own hard boundaries are not used: the curriculum re-derives
        # them from the soft probabilities (top-k or threshold, by step).
        soft_boundaries, hard_boundaries = self.curriculum.get_boundaries(
            output['soft_boundaries'], step
        )

        # Computed outside autocast: F.binary_cross_entropy is rejected inside
        # a CUDA autocast region.
        if self._bp_frozen:
            bp_loss = torch.tensor(0.0, device=input_ids.device)
        else:
            bp_loss = compute_bp_loss(
                soft_boundaries, hard_boundaries, input_ids.shape[-1]
            )

        total_loss = (
            lm_loss
            + self.model.config.router_aux_loss_coef * aux_loss
            + self.bp_loss_weight * bp_loss
        )

        def _as_number(value):
            return value.item() if isinstance(value, torch.Tensor) else value

        metrics = {
            'lm_loss': _as_number(lm_loss),
            'aux_loss': _as_number(aux_loss),
            'bp_loss': _as_number(bp_loss),
            'bp_frozen': self._bp_frozen,
            'curriculum_phase': self.curriculum.get_phase(step),
            'is_fixed_bp': self.curriculum.is_fixed_bp(step),
        }
        return total_loss, metrics

    def clip_gradients(self) -> float:
        """Clip gradients globally, then clip expert-core params more tightly.

        1. All parameters are clipped to total norm ``grad_clip``. For an
           FSDP-wrapped model this uses ``FSDP.clip_grad_norm_``, which reduces
           the norm across ranks; ``nn.utils.clip_grad_norm_`` would only see
           each rank's local shard.
        2. Parameters whose name contains 'W_gate' or 'W_transform' (and that
           have a gradient) are then clipped jointly to ``expert_core_grad_clip``.

        Returns:
            The total gradient norm measured in step 1, before clipping.
        """
        fsdp_cls = None
        if torch.distributed.is_available():
            from torch.distributed.fsdp import FullyShardedDataParallel as fsdp_cls

        if fsdp_cls is not None and isinstance(self.model, fsdp_cls):
            grad_norm = self.model.clip_grad_norm_(self.grad_clip)
        else:
            grad_norm = nn.utils.clip_grad_norm_(
                self.model.parameters(), max_norm=self.grad_clip
            )

        # NOTE(review): under FSDP with flattened parameters the W_gate /
        # W_transform names are presumably not visible, and this clip would only
        # measure the local shard — confirm expert-core clipping in DDP runs.
        expert_core_params = [
            param
            for name, param in self.model.named_parameters()
            if ('W_gate' in name or 'W_transform' in name) and param.grad is not None
        ]
        if expert_core_params:
            nn.utils.clip_grad_norm_(
                expert_core_params, max_norm=self.expert_core_grad_clip
            )

        return grad_norm.item() if isinstance(grad_norm, torch.Tensor) else float(grad_norm)
|
|
|
|
| |
| |
| |
|
|
def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
    """Learning rate at ``step``: linear warmup to ``max_lr``, then cosine decay.

    - ``step < warmup``: ``max_lr * step / warmup`` (so step 0 gives 0.0).
    - ``warmup <= step < total``: cosine from ``max_lr`` down toward ``min_lr``.
    - ``step >= total``: ``min_lr``.
    """
    in_warmup = step < warmup
    if in_warmup:
        return max_lr * step / warmup
    past_end = step >= total
    if past_end:
        return min_lr
    progress = (step - warmup) / (total - warmup)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
|
|
|
|
| |
| |
| |
|
|
def save_step_checkpoint(model, optimizer, step, epoch, cfg, ckpt_dir, master, ddp=False, trainer=None, current_best_loss=float("inf")):
    """Save a full checkpoint as ``spider-step<step>.pt`` and keep only the last 2.

    The write is delegated to ``save_full_checkpoint`` (which names files
    ``spider-<ckpt_name>.pt``), so both savers share one code path and produce
    identical contents. When ``ddp`` is true every rank must call this, because
    the FSDP full-state gather inside the save is collective; only ``master``
    writes files and prunes old step checkpoints.

    Returns:
        ``(ckpt_path, size_mb)`` on master, ``(None, 0)`` on other ranks.
    """
    ckpt_path, size_mb = save_full_checkpoint(
        model,
        optimizer,
        step,
        epoch,
        cfg,
        ckpt_dir,
        master,
        ddp=ddp,
        ckpt_name=f"step{step}",
        trainer=trainer,
        current_best_loss=current_best_loss,
    )
    if ckpt_path is None:
        return None, 0

    # Keep the two most recently *written* step checkpoints. Ordering by mtime
    # rather than by the step in the file name means a run rolled back to an
    # earlier step keeps its newest saves instead of deleting them.
    step_pattern = re.compile(r"spider-step\d+\.pt$")
    step_ckpts = sorted(
        (
            os.path.join(ckpt_dir, name)
            for name in os.listdir(ckpt_dir)
            if step_pattern.search(name)
        ),
        key=os.path.getmtime,
    )
    for old in step_ckpts[:-2]:
        try:
            os.remove(old)
        except FileNotFoundError:
            pass  # already removed (e.g. by hand); nothing left to prune

    return ckpt_path, size_mb
|
|
|
|
def save_full_checkpoint(model, optimizer, step, epoch, cfg, ckpt_dir, master, ddp=False, ckpt_name="full", trainer=None, current_best_loss=float("inf")):
    """Save a full checkpoint (model + optimizer) as ``<ckpt_dir>/spider-<ckpt_name>.pt``.

    When ``ddp`` is true the model is treated as FSDP-wrapped. In that case a
    full (unsharded) model state dict is gathered, offloaded to CPU and
    materialized on rank 0 only, together with the matching full optimizer
    state dict. The gather is collective, so every rank must call this
    function; only ``master`` writes the file.

    The file is written to ``<path>.tmp`` and moved into place with
    ``os.replace``, so an interrupted save never clobbers the previous
    checkpoint. A partial ``.tmp`` left by a failed save is removed before the
    error is re-raised.

    Stored keys: step, epoch, model_state_dict, optimizer_state_dict, cfg,
    bp_optimizer_state_dict (None when the trainer has no BP optimizer),
    best_loss.

    Returns:
        ``(path, size_mb)`` on master, ``(None, 0)`` on other ranks.
    """
    if ddp:
        from torch.distributed.fsdp import (
            FullyShardedDataParallel as FSDP,
            StateDictType,
            FullStateDictConfig,
        )
        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            model_state = model.state_dict()
            optim_state = FSDP.optim_state_dict(model, optimizer)
    else:
        model_state = model.state_dict()
        optim_state = optimizer.state_dict()

    if not master:
        return None, 0

    os.makedirs(ckpt_dir, exist_ok=True)
    final_path = os.path.join(ckpt_dir, f"spider-{ckpt_name}.pt")
    tmp_path = final_path + ".tmp"
    payload = {
        "step": step,
        "epoch": epoch,
        "model_state_dict": model_state,
        "optimizer_state_dict": optim_state,
        "cfg": cfg,
        "bp_optimizer_state_dict": (
            trainer.bp_optimizer.state_dict() if trainer and trainer.bp_optimizer else None
        ),
        "best_loss": current_best_loss,
    }
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, final_path)
    except BaseException:
        # Don't leave a partial (potentially multi-GB) .tmp file behind.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

    size_mb = os.path.getsize(final_path) / (1024 * 1024)
    return final_path, size_mb
|
|
|
|
def load_checkpoint(model, optimizer, path, ddp=False):
    """Load model + optimizer state from a checkpoint written by the save helpers.

    If ``model`` is FSDP-wrapped, the full (unsharded) state dict is loaded
    under ``StateDictType.FULL_STATE_DICT``. The full optimizer state is
    converted with ``FSDP.optim_state_dict_to_load`` before being applied.
    Detection uses ``isinstance`` on the model, so this works whether it is
    called before or after wrapping. ``ddp`` is kept for backward compatibility
    and is not consulted.

    Handles cross-optimizer resume (e.g. 8-bit Adam locally -> standard AdamW
    remotely): if the optimizer state cannot be applied (ValueError, KeyError
    or RuntimeError), it is skipped with a warning and the optimizer starts
    fresh. Model weights are always loaded (errors there propagate).

    Returns:
        ``(step, epoch, bp_optim_state, saved_best_loss)``. ``epoch`` defaults
        to 0, ``bp_optim_state`` to None and ``saved_best_loss`` to inf when
        absent from the checkpoint.
    """
    # weights_only=False runs the full unpickler (needed for the non-tensor
    # "cfg" entry) — only load checkpoints from a trusted source.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)

    fsdp_cls = None
    if torch.distributed.is_available():
        from torch.distributed.fsdp import FullyShardedDataParallel as fsdp_cls
    is_fsdp = fsdp_cls is not None and isinstance(model, fsdp_cls)

    if is_fsdp:
        from torch.distributed.fsdp import FullStateDictConfig, StateDictType

        # Every rank reads the whole file and supplies the full state dict.
        state_ctx = fsdp_cls.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
        )
    else:
        state_ctx = nullcontext()

    with state_ctx:
        model.load_state_dict(ckpt["model_state_dict"])
        try:
            optim_state = ckpt["optimizer_state_dict"]
            if is_fsdp:
                # Keyword arguments: the positional order of this API changed
                # between PyTorch releases, the parameter names did not.
                optim_state = fsdp_cls.optim_state_dict_to_load(
                    model=model, optim=optimizer, optim_state_dict=optim_state
                )
            optimizer.load_state_dict(optim_state)
        except (ValueError, KeyError, RuntimeError) as e:
            logger.warning(
                f"Optimizer state mismatch (likely 8bit→standard cross-resume): {e}. "
                f"Skipping optimizer state — optimizer will reinitialize."
            )

    bp_optim_state = ckpt.get("bp_optimizer_state_dict", None)
    saved_best_loss = ckpt.get("best_loss", float("inf"))
    return int(ckpt["step"]), int(ckpt.get("epoch", 0)), bp_optim_state, saved_best_loss
|
|
|
|
| |
| |
| |
|
|
# DeepSpeed ZeRO stage-3 config: bf16, with both optimizer state and
# parameters offloaded to pinned CPU memory. Not referenced by the code
# visible in this file.
# NOTE(review): the "auto" batch sizes are placeholders that an integration
# layer (e.g. HF Trainer / Accelerate) is expected to resolve — confirm before
# passing this dict to deepspeed directly.
DEEPSPEED_ZERO3_CONFIG = dict(
    bf16=dict(enabled=True),
    zero_optimization=dict(
        stage=3,
        offload_optimizer=dict(device="cpu", pin_memory=True),
        offload_param=dict(device="cpu", pin_memory=True),
        overlap_comm=True,
        contiguous_gradients=True,
    ),
    gradient_accumulation_steps=1,
    gradient_clipping=1.0,
    train_batch_size="auto",
    train_micro_batch_size_per_gpu="auto",
)
|
|
|
|
| |
| |
| |
|
|
| import enum |
|
|
@enum.unique
class PrecisionMode(enum.Enum):
    """Training numeric-precision modes; values match the ``--precision`` CLI choices (besides "auto")."""

    BF16 = "bf16"                # bf16 mixed precision, universal fallback
    FP8_DYNAMIC = "fp8_dynamic"  # torchao float8, "rowwise" recipe
    MXFP8 = "mxfp8"              # torchao float8, "rowwise_with_gw_hp" recipe
    NVFP4 = "nvfp4"              # NVFP4 quantization via torchao
|
|
|
|
def detect_precision_mode() -> PrecisionMode:
    """Auto-detect the best available precision mode from the GPU and installed libraries.

    Fallback chain: MXFP8/NVFP4 -> FP8_DYNAMIC -> BF16.

    - MXFP8: Blackwell-class GPU (compute capability >= 10: 10.x data-center
      parts such as B200/GB200, 12.x consumer parts) with torchao float8.
      NOTE: configure_fp8_training maps this mode to torchao's
      "rowwise_with_gw_hp" recipe, not to MX block scaling.
    - NVFP4: Blackwell-class GPU without torchao float8 but with an NVFP4
      backend importable (torchao.quantization.NVFP4Config or fbgemm_gpu.genai).
    - FP8_DYNAMIC: compute capability >= 8.9 (Ada Lovelace, Hopper, ...) with
      torchao float8 (row-wise dynamic scaling).
    - BF16: fallback, including every CPU-only run.
    """
    if not torch.cuda.is_available():
        return PrecisionMode.BF16

    major, minor = torch.cuda.get_device_capability()

    # Probe imports only for availability; the names themselves are unused here.
    try:
        from torchao.float8 import convert_to_float8_training  # noqa: F401
        has_torchao_fp8 = True
    except ImportError:
        has_torchao_fp8 = False

    has_nvfp4 = False
    try:
        from torchao.quantization import NVFP4Config  # noqa: F401
        has_nvfp4 = True
    except (ImportError, AttributeError):
        try:
            import fbgemm_gpu.genai  # noqa: F401
            has_nvfp4 = True
        except ImportError:
            pass

    # Blackwell starts at compute capability 10.0 (B100/B200/GB200); the old
    # `major >= 12` test only matched consumer RTX 50-series (12.x) parts.
    if major >= 10:
        if has_torchao_fp8:
            return PrecisionMode.MXFP8
        if has_nvfp4:
            return PrecisionMode.NVFP4

    if (major, minor) >= (8, 9) and has_torchao_fp8:
        return PrecisionMode.FP8_DYNAMIC

    return PrecisionMode.BF16
|
|
|
|
def configure_fp8_training(model, mode: PrecisionMode):
    """Swap eligible ``nn.Linear`` layers for torchao float8 training layers.

    Float8 training quantizes activations and weights to float8 dynamically in
    forward/backward with high-precision accumulation.

    Recipes:
    - MXFP8 -> "rowwise_with_gw_hp": row-wise scaling plus a high-precision
      grad-weight GEMM (best accuracy).
    - FP8_DYNAMIC -> "rowwise": row-wise dynamic scaling.
    Any other mode returns ``model`` unchanged *without* importing torchao, so
    BF16/NVFP4 callers do not need torchao installed.

    The named recipe's config is reused as-is except ``pad_inner_dim=True``.
    Modules whose fully-qualified name contains any of the skip markers below
    stay in high precision.

    Returns:
        Whatever ``convert_to_float8_training`` returns for ``model``.
    """
    if mode == PrecisionMode.MXFP8:
        recipe_name = "rowwise_with_gw_hp"
    elif mode == PrecisionMode.FP8_DYNAMIC:
        recipe_name = "rowwise"
    else:
        return model

    from dataclasses import replace as dataclass_replace
    from torchao.float8 import Float8LinearConfig, convert_to_float8_training

    # Copy every field of the recipe config and only flip pad_inner_dim. This
    # replaces a hand-written field-by-field copy that broke whenever torchao
    # added or removed a config field.
    config = dataclass_replace(
        Float8LinearConfig.from_recipe_name(recipe_name),
        pad_inner_dim=True,
    )

    # Substring match on the module FQN. "norm" already covers "layernorm".
    # NOTE(review): "gate" matches any name containing it (routers, but also
    # e.g. gate projections) — confirm that is intended.
    skip_markers = (
        "boundary_predictor",
        "loop_embedding",
        "engram",
        "layernorm",
        "norm",
        "embed_tokens",
        "lm_head",
        "halt_predictor",
        "gate",
    )

    def module_filter_fn(mod, fqn):
        return not any(marker in fqn for marker in skip_markers)

    return convert_to_float8_training(
        model,
        module_filter_fn=module_filter_fn,
        config=config,
    )
|
|
|
|
def configure_nvfp4_training(model):
    """Apply torchao NVFP4 quantization to ``model`` (intended for Blackwell).

    NVFP4 stores 4-bit floating-point weights with 8-bit scale factors.
    ``quantize_`` is applied to ``model`` and the same object is returned.
    NOTE(review): presumably quantize_ mutates the model in place; if it fails
    midway, the FP8 fallback below runs on a partially quantized model.

    If the NVFP4 config cannot be imported or applied (ImportError,
    AttributeError or RuntimeError), a warning is logged and
    ``configure_fp8_training(model, PrecisionMode.FP8_DYNAMIC)`` is returned
    instead.
    """
    try:
        from torchao.quantization import NVFP4Config, quantize_

        quantize_(model, NVFP4Config())
    except (ImportError, AttributeError, RuntimeError):
        logger.warning("NVFP4 not available, falling back to FP8_DYNAMIC")
        return configure_fp8_training(model, PrecisionMode.FP8_DYNAMIC)
    return model
|
|
|
|
def try_unsloth():
    """Try to import Unsloth.

    Returns:
        ``(True, FastLanguageModel)`` if the import succeeds, else ``(False, None)``.
    """
    try:
        from unsloth import FastLanguageModel
    except Exception:
        # Deliberately broad (ImportError is a subclass): presumably importing
        # unsloth can fail in other ways too, e.g. without a supported GPU.
        return False, None
    return True, FastLanguageModel
|
|
|
|
| |
| |
| |
|
|
def parse_args(argv: Optional[List[str]] = None):
    """Parse command-line options for Spider-FLEXITOKENS training.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``
            exactly as before; passing a list lets tests and notebooks call the
            parser directly.

    Returns:
        ``argparse.Namespace``. Numeric overrides default to 0, and ``--resume``
        defaults to ""; main() treats these falsy values as "not set" and falls
        back to environment variables / defaults via ``or``.
    """
    parser = argparse.ArgumentParser(description="Spider-FLEXITOKENS training")
    parser.add_argument("--resume", type=str, default="", help="Path to checkpoint to resume from")
    parser.add_argument("--max_steps", type=int, default=0, help="Override max training steps")
    parser.add_argument("--mock_data", action="store_true", help="Use mock data (no network)")
    parser.add_argument("--seq_len", type=int, default=0, help="Override sequence length")
    parser.add_argument("--micro_batch", type=int, default=0, help="Override micro batch size")
    parser.add_argument("--n_loops", type=int, default=0, help="Override number of loops")
    parser.add_argument("--lr", type=float, default=0, help="Override learning rate")
    parser.add_argument("--ckpt_dir", type=str, default="checkpoints-spider", help="Checkpoint directory")
    parser.add_argument("--no_unsloth", action="store_true", help="Skip Unsloth even if available")
    parser.add_argument(
        "--precision", type=str, default="auto",
        choices=["auto", "bf16", "fp8_dynamic", "mxfp8", "nvfp4"],
        help="Training precision: auto (detect), bf16, fp8_dynamic, mxfp8, nvfp4",
    )
    return parser.parse_args(argv)
|
|
|
|
| def main(): |
| global best_loss |
| args = parse_args() |
|
|
| |
| |
| |
| ddp = int(os.environ.get("RANK", -1)) != -1 |
| if ddp: |
| dist.init_process_group("nccl") |
| rank = int(os.environ["RANK"]) |
| local_rank = int(os.environ["LOCAL_RANK"]) |
| world_size = int(os.environ["WORLD_SIZE"]) |
| device = f"cuda:{local_rank}" |
| torch.cuda.set_device(device) |
| else: |
| rank = local_rank = 0 |
| world_size = 1 |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| master = rank == 0 |
|
|
| |
| |
| |
| seq_len = args.seq_len or int(os.environ.get("SEQ_LEN", "2048")) |
| micro_batch = args.micro_batch or int(os.environ.get("MICRO_BATCH", "4")) |
| target_tokens = int(os.environ.get("TARGET_TOKENS", "10_000_000_000")) |
| grad_accum = int(os.environ.get("GRAD_ACCUM", "1")) |
| n_loops = args.n_loops or int(os.environ.get("N_LOOPS", "6")) |
| lr = args.lr or float(os.environ.get("LR", "3e-4")) |
| wd = 0.1 |
| warmup_steps = 200 |
| log_every = 10 |
| ckpt_every = int(os.environ.get("CKPT_EVERY", "500")) |
| ckpt_dir = args.ckpt_dir |
|
|
| global_batch_tok = world_size * micro_batch * grad_accum * seq_len |
| total_steps = target_tokens // global_batch_tok |
| if args.max_steps > 0: |
| total_steps = min(total_steps, args.max_steps) |
|
|
| if master: |
| logger.info( |
| f"[Spider-FLEXITOKENS] hidden=2048 | 6 recurrent | 32 experts top-2 | " |
| f"n_loops={n_loops} | seq_len={seq_len} | micro_batch={micro_batch} | " |
| f"grad_accum={grad_accum} | global_batch_tokens={global_batch_tok:,} | " |
| f"total_steps={total_steps:,}" |
| ) |
| logger.info( |
| f"Byte-level vocab: 272 | Context: 256k (YaRN-8) | " |
| f"Sliding window: 8192 | BP curriculum: fixed 30% -> adaptive | " |
| f"Gradient checkpointing: enabled | Precision: {prec_mode.value}" |
| ) |
|
|
| |
| |
| |
| cfg = SpiderConfig() |
| bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported() |
| amp_dtype = torch.bfloat16 if bf16_ok else torch.float16 |
|
|
| |
| if args.precision == "auto": |
| prec_mode = detect_precision_mode() |
| else: |
| prec_mode = PrecisionMode(args.precision) |
|
|
| if master: |
| logger.info(f"Precision mode: {prec_mode.value}") |
|
|
| model = SpiderForConditionalGeneration(cfg).to(amp_dtype) |
| model.gradient_checkpointing_enable() |
| model.enable_input_require_grads() |
|
|
| |
| if prec_mode in (PrecisionMode.MXFP8, PrecisionMode.FP8_DYNAMIC): |
| try: |
| model = configure_fp8_training(model, prec_mode) |
| if master: |
| logger.info(f"torchao FP8 training enabled: {prec_mode.value}") |
| except Exception as e: |
| if master: |
| logger.warning(f"FP8 training setup failed ({e}), falling back to BF16") |
| prec_mode = PrecisionMode.BF16 |
| elif prec_mode == PrecisionMode.NVFP4: |
| try: |
| model = configure_nvfp4_training(model) |
| if master: |
| logger.info("NVFP4 training enabled") |
| except Exception as e: |
| if master: |
| logger.warning(f"NVFP4 setup failed ({e}), falling back to FP8_DYNAMIC") |
| try: |
| model = configure_fp8_training(model, PrecisionMode.FP8_DYNAMIC) |
| prec_mode = PrecisionMode.FP8_DYNAMIC |
| if master: |
| logger.info("Fallback: FP8_DYNAMIC training enabled") |
| except Exception as e2: |
| if master: |
| logger.warning(f"FP8 fallback also failed ({e2}), using BF16") |
| prec_mode = PrecisionMode.BF16 |
|
|
| |
| |
| use_unsloth = False |
| if not args.no_unsloth and not ddp: |
| use_unsloth_available, FastLanguageModel_cls = try_unsloth() |
| if use_unsloth_available: |
| try: |
| |
| |
| os.environ.setdefault("UNSLOTH_MOE_BACKEND", "grouped_mm") |
| use_unsloth = True |
| if master: |
| logger.info("Unsloth MoE + training patches applied") |
| except Exception as e: |
| if master: |
| logger.warning(f"Unsloth patching failed: {e}") |
| if not use_unsloth and master: |
| logger.info("Unsloth not available, using standard PyTorch training") |
|
|
| if ddp: |
| from torch.distributed.fsdp import ( |
| FullyShardedDataParallel as FSDP, |
| ShardingStrategy, |
| MixedPrecision, |
| ) |
| from torch.distributed.fsdp.wrap import ModuleWrapPolicy |
| from spider import SpiderDenseLayer, SpiderRecurrentLayer |
|
|
| mp_policy = MixedPrecision( |
| param_dtype=amp_dtype, |
| reduce_dtype=amp_dtype, |
| buffer_dtype=amp_dtype, |
| ) |
| wrap_policy = ModuleWrapPolicy({SpiderDenseLayer, SpiderRecurrentLayer}) |
| model = FSDP( |
| model, |
| sharding_strategy=ShardingStrategy.FULL_SHARD, |
| mixed_precision=mp_policy, |
| auto_wrap_policy=wrap_policy, |
| device_id=local_rank, |
| ) |
| else: |
| model = model.to(device) |
|
|
| if master: |
| n_params = sum(p.numel() for p in model.parameters()) |
| trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| logger.info( |
| f"Parameters: {n_params:,} total | {trainable:,} trainable | " |
| f"Precision: {prec_mode.value} | AMP dtype: {amp_dtype}" |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| engram_params_list = [ |
| p for n, p in model.named_parameters() |
| if 'engram' in n and 'embed' in n and 'proj' not in n |
| ] |
| backbone_params = [ |
| p for n, p in model.named_parameters() |
| if not ('engram' in n and 'embed' in n and 'proj' not in n) |
| ] |
|
|
| use_8bit_optimizer = _HAS_8BIT and prec_mode == PrecisionMode.BF16 |
|
|
| if use_8bit_optimizer: |
| optimizer = AdamW8bit( |
| backbone_params, lr=lr, weight_decay=wd, |
| betas=(0.9, 0.95), eps=1e-8, |
| ) |
| if engram_params_list: |
| engram_optimizer = Adam8bit( |
| engram_params_list, lr=lr * 5, |
| betas=(0.9, 0.95), eps=1e-8, |
| ) |
| else: |
| engram_optimizer = None |
| if master: |
| logger.info("Optimizer: 8-bit AdamW (bf16 mode, saves ~50% optimizer VRAM)") |
| else: |
| optimizer = torch.optim.AdamW( |
| backbone_params, lr=lr, weight_decay=wd, |
| betas=(0.9, 0.95), foreach=True, eps=1e-8, |
| ) |
| if engram_params_list: |
| engram_optimizer = torch.optim.Adam( |
| engram_params_list, lr=lr * 5, |
| betas=(0.9, 0.95), eps=1e-8, |
| ) |
| else: |
| engram_optimizer = None |
| if master: |
| logger.info(f"Optimizer: standard AdamW ({prec_mode.value} mode)") |
|
|
| |
| |
| |
| curriculum = CurriculumScheduler(total_steps=total_steps) |
| monitor = RecurrentMonitor() |
| trainer = BPCurriculumTrainer( |
| model=model, |
| optimizer=optimizer, |
| engram_optimizer=engram_optimizer, |
| curriculum=curriculum, |
| monitor=monitor, |
| warmup_steps=warmup_steps, |
| base_lr=lr, |
| ) |
|
|
| |
| |
| |
| start_step = 0 |
| start_epoch = 1 |
| bp_optim_state_to_load = None |
| if args.resume and os.path.exists(args.resume): |
| if master: |
| logger.info(f"Resuming from checkpoint: {args.resume}") |
| start_step, start_epoch, bp_optim_state_to_load, saved_best = load_checkpoint( |
| model, optimizer, args.resume, ddp |
| ) |
| best_loss = saved_best |
| if master: |
| logger.info(f"Resumed at step {start_step}, epoch {start_epoch}, best_loss={best_loss:.4f}") |
| else: |
| |
| existing_ckpts = sorted( |
| [os.path.join(ckpt_dir, f) for f in os.listdir(ckpt_dir) |
| if f.startswith("spider-") and f.endswith(".pt") and not f.endswith(".tmp")] |
| ) if os.path.isdir(ckpt_dir) else [] |
| if existing_ckpts: |
| latest = existing_ckpts[-1] |
| if master: |
| logger.info(f"Auto-resuming from: {latest}") |
| start_step, start_epoch, bp_optim_state_to_load, saved_best = load_checkpoint( |
| model, optimizer, latest, ddp |
| ) |
| best_loss = saved_best |
| if master: |
| logger.info(f"Resumed at step {start_step}, epoch {start_epoch}, best_loss={best_loss:.4f}") |
|
|
| |
| |
| if bp_optim_state_to_load and trainer.bp_optimizer: |
| try: |
| trainer.bp_optimizer.load_state_dict(bp_optim_state_to_load) |
| if master: |
| logger.info("Restored BP optimizer state from checkpoint") |
| except (ValueError, KeyError, RuntimeError) as e: |
| if master: |
| logger.warning(f"BP optimizer state mismatch, skipping: {e}") |
|
|
| |
| |
| |
| if args.mock_data: |
| dataset = MockByteLevelDataset(seq_len=seq_len) |
| else: |
| dataset = ByteLevelDataset( |
| seq_len=seq_len, |
| rank=rank, |
| world_size=world_size, |
| ) |
|
|
| loader = DataLoader( |
| dataset, |
| batch_size=micro_batch, |
| num_workers=4 if not args.mock_data else 0, |
| pin_memory=True, |
| prefetch_factor=1 if not args.mock_data else None, |
| ) |
|
|
| |
| |
| |
| amp_ctx = ( |
| torch.amp.autocast(device_type="cuda", dtype=amp_dtype) |
| if "cuda" in device |
| else nullcontext() |
| ) |
| amp_ctx = nullcontext() if ddp else amp_ctx |
|
|
| try: |
| from torch.nn.attention import sdpa_kernel |
| sdpa_ctx = sdpa_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=True) |
| except Exception: |
| sdpa_ctx = nullcontext() |
|
|
| |
| |
| |
| if master: |
| os.makedirs(ckpt_dir, exist_ok=True) |
|
|
| model.train() |
| data_iter = iter(loader) |
| t0 = time.perf_counter() |
| step = start_step |
| epoch = start_epoch |
| tokens_in_epoch = 0 |
| tokens_per_epoch = target_tokens |
|
|
| while step < total_steps: |
| cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1) |
| for g in optimizer.param_groups: |
| g["lr"] = cur_lr |
| if engram_optimizer: |
| for g in engram_optimizer.param_groups: |
| g["lr"] = cur_lr * 5 |
|
|
| optimizer.zero_grad() |
| if engram_optimizer: |
| engram_optimizer.zero_grad() |
| if trainer.bp_optimizer: |
| trainer.bp_optimizer.zero_grad() |
| loss_accum = 0.0 |
| metrics_accum = {} |
|
|
| for micro_step in range(grad_accum): |
| try: |
| x, y = next(data_iter) |
| except StopIteration: |
| data_iter = iter(loader) |
| x, y = next(data_iter) |
|
|
| x = x.to(device, non_blocking=True) |
| y = y.to(device, non_blocking=True) |
|
|
| sync = ( |
| nullcontext() |
| if (not ddp or micro_step == grad_accum - 1) |
| else model.no_sync() |
| ) |
| with sync: |
| total_loss, metrics = trainer.train_step( |
| x, y, step, n_loops=n_loops, |
| amp_ctx=amp_ctx, sdpa_ctx=sdpa_ctx, |
| ) |
| total_loss = total_loss / grad_accum |
| total_loss.backward() |
|
|
| if master and step == start_step and micro_step == 0: |
| peak_vram = torch.cuda.max_memory_allocated() / 1024**3 |
| logger.info(f"First forward+backward | Peak VRAM: {peak_vram:.1f}GB") |
|
|
| loss_accum += total_loss.item() |
| for k, v in metrics.items(): |
| if k not in metrics_accum: |
| metrics_accum[k] = 0.0 |
| if isinstance(v, (int, float)): |
| metrics_accum[k] += v / grad_accum |
|
|
| |
| grad_norm = trainer.clip_gradients() |
| optimizer.step() |
| if engram_optimizer: |
| engram_optimizer.step() |
| if trainer.bp_optimizer: |
| for g in trainer.bp_optimizer.param_groups: |
| g["lr"] = cur_lr * 0.1 |
| trainer.bp_optimizer.step() |
| step += 1 |
| tokens_in_epoch += global_batch_tok |
|
|
| |
| if master and step % log_every == 0: |
| health_warnings = monitor.check_health(metrics_accum, step) |
| for w in health_warnings: |
| logger.warning(w) |
|
|
| |
| if master and step % log_every == 0: |
| dt = time.perf_counter() - t0 |
| tok_per_sec = global_batch_tok * log_every / dt |
| tokens_seen = step * global_batch_tok |
| bp_status = "FIXED" if metrics_accum.get('is_fixed_bp', True) else "ADAPTIVE" |
| bp_frozen = "FROZEN" if metrics_accum.get('bp_frozen', False) else "ACTIVE" |
| logger.info( |
| f"Epoch {epoch} | step {step:6d}/{total_steps} | " |
| f"loss {loss_accum:.4f} | lm {metrics_accum.get('lm_loss', 0):.4f} | " |
| f"aux {metrics_accum.get('aux_loss', 0):.4f} | " |
| f"bp {metrics_accum.get('bp_loss', 0):.4f} [{bp_status}/{bp_frozen}] | " |
| f"gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} | " |
| f"{tok_per_sec / 1e6:.2f}M tok/s | {tokens_seen / 1e9:.2f}B tokens" |
| ) |
| t0 = time.perf_counter() |
|
|
| |
| if step % ckpt_every == 0: |
| ckpt_path, size_mb = save_step_checkpoint( |
| model, optimizer, step, epoch, cfg, ckpt_dir, master, ddp, trainer, |
| current_best_loss=best_loss, |
| ) |
| if master and ckpt_path: |
| logger.info(f"Saved step checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)") |
|
|
| |
| if tokens_in_epoch >= tokens_per_epoch: |
| epoch_loss = loss_accum |
| if master: |
| logger.info(f"Epoch {epoch} complete | loss={epoch_loss:.4f}") |
| ckpt_path, size_mb = save_full_checkpoint( |
| model, optimizer, step, epoch, cfg, ckpt_dir, master, ddp, f"ep{epoch}", trainer, |
| current_best_loss=best_loss, |
| ) |
| if master and ckpt_path: |
| logger.info(f"Saved epoch checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)") |
|
|
| if epoch_loss < best_loss: |
| best_loss = epoch_loss |
| ckpt_path, size_mb = save_full_checkpoint( |
| model, optimizer, step, epoch, cfg, ckpt_dir, master, ddp, "best", trainer, |
| current_best_loss=best_loss, |
| ) |
| if master and ckpt_path: |
| logger.info(f"Saved best checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)") |
|
|
| epoch += 1 |
| tokens_in_epoch = 0 |
|
|
| |
| if step > start_step and master: |
| ckpt_path, size_mb = save_full_checkpoint( |
| model, optimizer, step, epoch, cfg, ckpt_dir, master, ddp, f"final-ep{epoch}", trainer, |
| current_best_loss=best_loss, |
| ) |
| if ckpt_path: |
| logger.info(f"Saved final checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)") |
|
|
| if ddp: |
| dist.barrier() |
| dist.destroy_process_group() |
|
|
| if master: |
| logger.info("Training complete.") |
|
|
|
|
if __name__ == "__main__":
    # Script entry point. Seed the module-level `best_loss` with +infinity,
    # the intended starting value for the `epoch_loss < best_loss` check in
    # main(), then run training.
    #
    # NOTE(review): main() both reads `best_loss` (it is passed as
    # current_best_loss when saving checkpoints) and assigns to it (after
    # resuming, and when a new best epoch loss is found). Unless main()
    # declares `global best_loss`, those assignments make the name local to
    # main(). In that case this module-level value is never consulted, and a
    # fresh (non-resumed) run would raise UnboundLocalError at the first
    # checkpoint save. Confirm against the start of main().
    best_loss: float = math.inf
    main()
|
|