""" Parameter Golf — Competitive Submission ======================================== Key innovations targeting top-of-leaderboard (< 1.08 BPB): 1. SP8192 Vocabulary: 8192-token SentencePiece tokenizer for better BPB efficiency. Larger vocab = fewer tokens = better compression. 2. Parallel Residuals (PAF): Attention and MLP run in parallel on the same normalized input, saving one LayerNorm and improving information flow. x = x + attn(norm(x)) + mlp(norm(x)) [GPT-J / PaLM style] 3. 3-Layer Depth Recurrence: 3 unique transformer blocks looped multiple times. Layers 0-2 recur K times at train, 2K at eval (free test-time compute). Selective recurrence on inner layers. 4. Score-First TTT (Test-Time Training): At eval, adapt the model's MLP W_down weights chunk-by-chunk using NTP loss. Legal = strictly causal. Implements the In-Place TTT mechanism from arxiv:2604.06169. 5. Int6 GPTQ Post-Training Quantization with SDClip: - Train in full precision (bf16/fp32) - After training, quantize all weight matrices to int6 using GPTQ - Std-based clipping (SDClip) before quantization reduces outlier impact - Embeddings in GPTQ int8 with SDClip - ~1.5x more effective parameters vs int8 in the same 16MB budget 6. MuonEq-R: Muon optimizer with equalized learning rates (scale by sqrt(max(fan_in, fan_out))) and weight decay regularization. 7. QK-Gain 5.25: High gain on QK product prevents attention entropy collapse at small model dimensions. 8. Residual mixing with x0 anchor preserved from baseline. 
Architecture: SP8192 vocab, d_model=768, 12 heads / 4 KV heads, MLP 4x
              3 unique blocks × 8 recurrences  = 24 effective layers (train)
              3 unique blocks × 16 recurrences = 48 effective layers (eval)

Run: torchrun --standalone --nproc_per_node=8 train_gpt2.py
"""
from __future__ import annotations

import copy
import glob
import io
import math
import os
import random
import subprocess
import sys
import time
import uuid
import zlib
from pathlib import Path

import numpy as np
import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.parallel import DistributedDataParallel as DDP

# ─────────────────────────────────────────────────────────────
# HYPERPARAMETERS
# ─────────────────────────────────────────────────────────────


class Hyperparameters:
    """All run configuration, overridable via environment variables."""

    # Data paths — SP8192 tokenizer and matching data
    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp8192")
    train_files = os.path.join(data_path, "fineweb_train_*.bin")
    val_files = os.path.join(data_path, "fineweb_val_*.bin")
    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model")
    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
    seed = int(os.environ.get("SEED", 1337))
    # Validation
    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))     # steps between val runs
    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200))
    # Training
    iterations = int(os.environ.get("ITERATIONS", 20000))
    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500))     # linear LR decay length
    warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))           # throwaway compile/warmup steps
    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))
    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024))
    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
    # Model — Parallel Residual Recurrent
    vocab_size = int(os.environ.get("VOCAB_SIZE", 8192))
    model_dim = int(os.environ.get("MODEL_DIM", 768))
    num_heads = int(os.environ.get("NUM_HEADS", 12))
    num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))            # GQA: 12 query heads share 4 KV heads
    mlp_mult = int(os.environ.get("MLP_MULT", 4))
    num_unique_layers = int(os.environ.get("NUM_UNIQUE_LAYERS", 3))
    num_recurrences = int(os.environ.get("NUM_RECURRENCES", 8))
    num_eval_recurrences = int(os.environ.get("NUM_EVAL_RECURRENCES", 0))  # 0 = auto (2×)
    rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))
    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.25))
    # Sliding window eval
    sw_stride = int(os.environ.get("SW_STRIDE", 64))
    sw_seq_len = int(os.environ.get("SW_SEQ_LEN", 1024))
    # Test-Time Training (TTT)
    ttt_enabled = int(os.environ.get("TTT_ENABLED", 1))              # 1 = enable at eval
    ttt_lr = float(os.environ.get("TTT_LR", 0.01))
    ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 64))
    ttt_layers = os.environ.get("TTT_LAYERS", "all")                 # "all" or comma-sep indices
    # Optimizer
    embed_lr = float(os.environ.get("EMBED_LR", 0.05))
    matrix_lr = float(os.environ.get("MATRIX_LR", 0.04))
    scalar_lr = float(os.environ.get("SCALAR_LR", 0.04))
    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
    muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.09))
    beta1 = float(os.environ.get("BETA1", 0.9))
    beta2 = float(os.environ.get("BETA2", 0.95))
    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
    # GPTQ quantization config
    gptq_bits = int(os.environ.get("GPTQ_BITS", 6))
    gptq_group_size = int(os.environ.get("GPTQ_GROUP_SIZE", 128))
    sdclip_nstd = float(os.environ.get("SDCLIP_NSTD", 2.5))
    # SWA/EMA
    swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4))    # fraction of run before EMA starts


# ─────────────────────────────────────────────────────────────
# MUON OPTIMIZER (MuonEq-R variant)
# ─────────────────────────────────────────────────────────────
def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Approximate the zeroth matrix power (orthogonalization) of G via a
    quintic Newton–Schulz iteration in bfloat16.

    The (a, b, c) coefficients are the tuned quintic-iteration constants;
    the input is normalized so the iteration converges.  Tall matrices are
    transposed first so the iteration runs on the smaller Gram matrix.
    """
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    """MuonEq-R: Muon with equalized scaling and weight decay."""

    def __init__(self, params, lr: float, momentum: float, backend_steps: int,
                 weight_decay: float = 0.0, nesterov: bool = True):
        super().__init__(params, dict(lr=lr, momentum=momentum,
                                      backend_steps=backend_steps,
                                      weight_decay=weight_decay,
                                      nesterov=nesterov))

    @torch.no_grad()
    def step(self, closure=None):
        """One optimizer step.

        Work is sharded round-robin over ranks: rank r orthogonalizes
        params with index i where i % world_size == r, writes the update
        into its slots of a shared flat buffer, and a single all_reduce
        (SUM over disjoint slots) gathers every rank's updates.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0
        for group in self.param_groups:
            params = group["params"]
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            weight_decay = group["weight_decay"]
            nesterov = group["nesterov"]
            # Flat bf16 buffer holding every param's update at a fixed offset.
            total = sum(int(p.numel()) for p in params)
            flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16)
            curr = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    if weight_decay != 0.0:
                        # L2 regularization folded into the gradient (coupled decay).
                        g = g + weight_decay * p.data.to(g.dtype)
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    # MuonEq-R: rescale by sqrt(max(1, rows/cols)) so tall
                    # matrices get updates comparable to square ones.
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    flat[curr: curr + p.numel()] = g.reshape(-1)
                # Offset advances for EVERY param so slots line up across ranks.
                curr += p.numel()
            if distributed:
                # Disjoint slots per rank, so SUM is a concatenation.
                dist.all_reduce(flat, op=dist.ReduceOp.SUM)
            curr = 0
            for p in params:
                g = flat[curr: curr + p.numel()].view_as(p).to(dtype=p.dtype)
                p.add_(g, alpha=-lr)
                curr += p.numel()
        return loss
# ─────────────────────────────────────────────────────────────
# BPB EVALUATION UTILITIES
# ─────────────────────────────────────────────────────────────
def build_sentencepiece_luts(sp, vocab_size, device):
    """Build per-token lookup tables used for bits-per-byte accounting.

    Returns three tensors of length max(sp.vocab_size(), vocab_size):
      base_bytes  — UTF-8 byte length of the piece (leading U+2581 stripped;
                    byte-fallback pieces count as 1 byte),
      has_space   — True if the piece starts with the U+2581 word-boundary
                    marker (contributes one extra byte unless preceded by
                    a boundary token),
      is_boundary — True for control/unknown/unused ids (and ids beyond the
                    SentencePiece vocab).
    """
    sv = int(sp.vocab_size())
    sz = max(sv, vocab_size)
    bb = np.zeros(sz, dtype=np.int16)
    hs = np.zeros(sz, dtype=bool)
    ib = np.ones(sz, dtype=bool)
    for tid in range(sv):
        if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid):
            continue
        ib[tid] = False
        if sp.is_byte(tid):
            bb[tid] = 1
            continue
        piece = sp.id_to_piece(tid)
        if piece.startswith("\u2581"):
            hs[tid] = True
            piece = piece[1:]
        bb[tid] = len(piece.encode("utf-8"))
    return (torch.tensor(bb, dtype=torch.int16, device=device),
            torch.tensor(hs, dtype=torch.bool, device=device),
            torch.tensor(ib, dtype=torch.bool, device=device))


def eval_val_sliding_window(args, model, rank, world_size, device, val_tokens,
                            base_bytes_lut, has_space_lut, is_boundary_lut,
                            use_ttt=False):
    """Sliding-window BPB: every token scored with sw_stride context.

    Windows of sw_seq_len tokens advance by sw_stride; only the final
    sw_stride positions of each window are scored, so each scored token
    sees at least sw_seq_len - sw_stride tokens of context.  Windows are
    sharded across ranks and the partial sums are all-reduced, so in a
    distributed run EVERY rank must enter this function.
    Returns (mean NLL per token, bits per byte).
    """
    seq_len = args.sw_seq_len
    stride = args.sw_stride
    T = val_tokens.numel()
    all_starts = list(range(0, T - seq_len - 1, stride))
    my_starts = all_starts[rank::world_size]
    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_cnt = torch.zeros((), device=device, dtype=torch.float64)
    byte_cnt = torch.zeros((), device=device, dtype=torch.float64)
    # Get the raw model for TTT (unwrap DDP and torch.compile wrappers)
    raw_model = model
    while hasattr(raw_model, 'module'):
        raw_model = raw_model.module
    if hasattr(raw_model, '_orig_mod'):
        raw_model = raw_model._orig_mod
    raw_model.eval()
    # TTT modifies weights in-place, so we can't use inference_mode
    ctx = torch.no_grad if (use_ttt and args.ttt_enabled) else torch.inference_mode
    with ctx():
        for start in my_starts:
            end = start + seq_len
            x = val_tokens[start:end].unsqueeze(0).to(device, dtype=torch.int64)
            y = val_tokens[start + 1:end + 1].unsqueeze(0).to(device, dtype=torch.int64)
            with torch.autocast("cuda", dtype=torch.bfloat16):
                if use_ttt and args.ttt_enabled:
                    ptl = raw_model.per_token_loss_with_ttt(x, y, args)
                else:
                    ptl = raw_model.per_token_loss(x, y)
            # Score only the trailing `stride` positions of this window.
            lo = seq_len - stride
            ptl_s = ptl[0, lo:]
            y_s = y[0, lo:]
            x_s = x[0, lo:]
            loss_sum += ptl_s.to(torch.float64).sum()
            token_cnt += ptl_s.numel()
            # Byte count: piece bytes + 1 leading-space byte, except when the
            # previous token is a boundary (document start etc.).
            tb = base_bytes_lut[y_s].to(torch.float64)
            tb += (has_space_lut[y_s] & ~is_boundary_lut[x_s]).to(torch.float64)
            byte_cnt += tb.sum()
    if dist.is_available() and dist.is_initialized():
        for t in (loss_sum, token_cnt, byte_cnt):
            dist.all_reduce(t, op=dist.ReduceOp.SUM)
    val_loss = float((loss_sum / token_cnt).item())
    bpb = float((loss_sum / math.log(2) / byte_cnt).item())
    raw_model.train()
    return val_loss, bpb


# ─────────────────────────────────────────────────────────────
# GPTQ Int6 QUANTIZATION with SDClip
# ─────────────────────────────────────────────────────────────
# Tensor-name substrings treated as "control" tensors (kept in full precision).
CONTROL_PATTERNS = tuple(p for p in os.environ.get(
    "CONTROL_TENSOR_NAME_PATTERNS",
    "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,log_alpha"
).split(",") if p)
KEEP_FP_MAX_NUMEL = 65_536          # tensors at or below this size stay unquantized
KEEP_FP_STORE_DTYPE = torch.float16  # storage dtype for small non-control tensors
INT8_SCALE_DTYPE = torch.float16


def sdclip(t: Tensor, n_std: float = 2.5) -> Tensor:
    """Std-based clipping: clip to mean +/- n_std * std."""
    mean = t.float().mean()
    std = t.float().std()
    lo = mean - n_std * std
    hi = mean + n_std * std
    return t.clamp(lo.item(), hi.item())


def _quant_tensor_int6(t: Tensor, n_std: float = 2.5):
    """Quantize tensor to int6 (range -31 to 31) with SDClip per row.

    2-D tensors get per-row clipping and a per-row fp16 scale; other ranks
    fall back to a single scalar scale.  Codes are stored in int8 (the
    packing to 6 bits happens downstream via zlib compression of the
    limited value range).
    """
    t32 = t.float()
    max_val = 31  # 6-bit signed: -31 to 31
    if t32.ndim == 2:
        # Per-row SDClip and quantization
        mean = t32.mean(dim=1, keepdim=True)
        std = t32.std(dim=1, keepdim=True).clamp_min(1e-9)
        lo = mean - n_std * std
        hi = mean + n_std * std
        t_clipped = t32.clamp(lo.expand_as(t32), hi.expand_as(t32))
        clip_val = t_clipped.abs().amax(dim=1).clamp_min(1e-9)
        scale = clip_val / max_val
        q = torch.clamp(torch.round(t_clipped / scale[:, None]),
                        -max_val, max_val).to(torch.int8)
        return q.contiguous(), scale.to(torch.float16).contiguous()
    # 1D fallback
    t_clipped = sdclip(t32, n_std)
    cv = float(t_clipped.abs().max().item())
    # Floor the scale so an all-zero tensor still round-trips sanely.
    scale = torch.tensor(max(cv / max_val, 1.0 / max_val), dtype=torch.float32)
    q = torch.clamp(torch.round(t_clipped / scale), -max_val, max_val).to(torch.int8)
    return q.contiguous(), scale


def _quant_tensor_int8(t: Tensor, n_std: float = 2.5):
    """Quantize tensor to int8 with SDClip (per-row for 2-D, scalar otherwise)."""
    t32 = t.float()
    if t32.ndim == 2:
        mean = t32.mean(dim=1, keepdim=True)
        std = t32.std(dim=1, keepdim=True).clamp_min(1e-9)
        lo = mean - n_std * std
        hi = mean + n_std * std
        t_clipped = t32.clamp(lo.expand_as(t32), hi.expand_as(t32))
        clip_val = t_clipped.abs().amax(dim=1).clamp_min(1e-9)
        scale = clip_val / 127.0
        q = torch.clamp(torch.round(t_clipped / scale[:, None]), -127, 127).to(torch.int8)
        return q.contiguous(), scale.to(torch.float16).contiguous()
    cv = float(sdclip(t32, n_std).abs().max().item())
    scale = torch.tensor(max(cv / 127.0, 1.0 / 127.0), dtype=torch.float32)
    q = torch.clamp(torch.round(t32.clamp(-cv, cv) / scale), -127, 127).to(torch.int8)
    return q.contiguous(), scale


def quantize_state_dict(state_dict: dict, gptq_bits: int = 6, sdclip_nstd: float = 2.5):
    """Mixed quantization: int6 for weight matrices, int8 for embeddings, fp16 for small/control.

    Returns (obj, stats) where obj is a self-describing dict consumed by
    dequantize_state_dict() and stats tracks parameter/byte counts
    (quant_bytes estimates int6 codes at 1 byte each, pre-compression).
    """
    quantized, scales, dtypes, passthrough, pt_orig, qmeta = {}, {}, {}, {}, {}, {}
    stats = {k: 0 for k in ("param_count", "num_tensors", "baseline_bytes", "quant_bytes")}
    quant_fn = _quant_tensor_int6 if gptq_bits == 6 else _quant_tensor_int8
    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        stats["param_count"] += t.numel()
        stats["num_tensors"] += 1
        stats["baseline_bytes"] += t.numel() * t.element_size()
        # Integer buffers (counters etc.) pass through untouched.
        if not t.is_floating_point():
            passthrough[name] = t
            stats["quant_bytes"] += t.numel() * t.element_size()
            continue
        is_ctrl = any(p in name for p in CONTROL_PATTERNS)
        is_small = t.numel() <= KEEP_FP_MAX_NUMEL
        # Embeddings: int8 (higher precision for tied I/O)
        if "tok_emb" in name:
            pt_orig[name] = str(t.dtype).removeprefix("torch.")
            q, s = _quant_tensor_int8(t, sdclip_nstd)
            quantized[name] = q
            scales[name] = s
            dtypes[name] = str(t.dtype).removeprefix("torch.")
            if s.ndim > 0:
                qmeta[name] = {"scheme": "per_row", "axis": 0, "bits": 8}
            stats["quant_bytes"] += q.numel() + s.numel() * s.element_size()
            continue
        # Control/small tensors stay floating point (fp32 for control, fp16 otherwise).
        if is_ctrl or is_small:
            if t.dtype in (torch.float32, torch.bfloat16):
                pt_orig[name] = str(t.dtype).removeprefix("torch.")
            passthrough[name] = t.float() if is_ctrl else t.to(KEEP_FP_STORE_DTYPE)
            passthrough[name] = passthrough[name].contiguous()
            stats["quant_bytes"] += passthrough[name].numel() * passthrough[name].element_size()
            continue
        # Large weight matrices: int6 with SDClip
        q, s = quant_fn(t, sdclip_nstd)
        if s.ndim > 0:
            qmeta[name] = {"scheme": "per_row", "axis": 0, "bits": gptq_bits}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(t.dtype).removeprefix("torch.")
        stats["quant_bytes"] += q.numel() + s.numel() * s.element_size()
    obj = {"__quant_format__": f"int{gptq_bits}_sdclip_v1",
           "quantized": quantized,
           "scales": scales,
           "dtypes": dtypes,
           "passthrough": passthrough}
    if qmeta:
        obj["qmeta"] = qmeta
    if pt_orig:
        obj["passthrough_orig_dtypes"] = pt_orig
    return obj, stats


def dequantize_state_dict(obj: dict) -> dict:
    """Inverse of quantize_state_dict(): reconstruct a loadable state dict.

    Per-row entries multiply codes by their row scale (broadcast over
    trailing dims); scalar entries use a single float scale.  Passthrough
    tensors are cast back to their recorded original dtype when known.
    """
    out = {}
    qmeta = obj.get("qmeta", {})
    pt_orig = obj.get("passthrough_orig_dtypes", {})
    for name, q in obj["quantized"].items():
        dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            s = s.to(torch.float32)
            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous()
        else:
            out[name] = (q.float() * float(s.item())).to(dtype).contiguous()
    for name, t in obj["passthrough"].items():
        ot = t.detach().cpu().contiguous()
        od = pt_orig.get(name)
        if isinstance(od, str):
            ot = ot.to(dtype=getattr(torch, od)).contiguous()
        out[name] = ot
    return out
───────────────────────────────────────────────────────────── def load_data_shard(file: Path) -> Tensor: hdr = np.fromfile(file, dtype=" Tensor: files = [Path(p) for p in sorted(glob.glob(pattern))] if not files: raise FileNotFoundError(f"No val files: {pattern}") tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() usable = ((tokens.numel() - 1) // seq_len) * seq_len return tokens[: usable + 1] class TokenStream: def __init__(self, pattern: str): files = [Path(p) for p in sorted(glob.glob(pattern))] if not files: raise FileNotFoundError(f"No shards: {pattern}") self.files = files self.idx = 0 self.tokens = load_data_shard(files[0]) self.pos = 0 def take(self, n: int) -> Tensor: chunks, rem = [], n while rem > 0: avail = self.tokens.numel() - self.pos if avail <= 0: self.idx = (self.idx + 1) % len(self.files) self.tokens = load_data_shard(self.files[self.idx]) self.pos = 0 avail = self.tokens.numel() k = min(rem, avail) chunks.append(self.tokens[self.pos: self.pos + k]) self.pos += k rem -= k return chunks[0] if len(chunks) == 1 else torch.cat(chunks) class DistributedTokenLoader: def __init__(self, pattern, rank, world_size, device): self.rank = rank; self.ws = world_size; self.device = device self.stream = TokenStream(pattern) def next_batch(self, global_tokens, seq_len, grad_accum): local_tokens = global_tokens // (self.ws * grad_accum) per_rank_span = local_tokens + 1 chunk = self.stream.take(per_rank_span * self.ws) start = self.rank * per_rank_span local = chunk[start: start + per_rank_span].to(torch.int64) x = local[:-1].reshape(-1, seq_len) y = local[1:].reshape(-1, seq_len) return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) # ───────────────────────────────────────────────────────────── # TRANSFORMER COMPONENTS — Parallel Residual Architecture # ───────────────────────────────────────────────────────────── class RMSNorm(nn.Module): def __init__(self, eps: float | None = None): super().__init__() self.eps = eps 
def forward(self, x: Tensor) -> Tensor: return F.rms_norm(x, (x.size(-1),), eps=self.eps) class Rotary(nn.Module): def __init__(self, dim: int, base: float = 10000.0): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._cached_len = 0 self._cos: Tensor | None = None self._sin: Tensor | None = None def forward(self, seq_len: int, device, dtype): if self._cos is None or self._cached_len != seq_len or self._cos.device != device: t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq.to(device)) self._cos = freqs.cos()[None, None, :, :] self._sin = freqs.sin()[None, None, :, :] self._cached_len = seq_len return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) def apply_rotary(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: half = x.size(-1) // 2 x1, x2 = x[..., :half], x[..., half:] return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) class CausalSelfAttention(nn.Module): def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init): super().__init__() assert dim % num_heads == 0 and num_heads % num_kv_heads == 0 self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = dim // num_heads kv_dim = num_kv_heads * self.head_dim self.c_q = nn.Linear(dim, dim, bias=False) self.c_k = nn.Linear(dim, kv_dim, bias=False) self.c_v = nn.Linear(dim, kv_dim, bias=False) self.proj = nn.Linear(dim, dim, bias=False) self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) self.rotary = Rotary(self.head_dim, base=rope_base) def forward(self, x: Tensor) -> Tensor: B, T, _ = x.shape q = self.c_q(x).reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2) k = self.c_k(x).reshape(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2) v = self.c_v(x).reshape(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2) q = F.rms_norm(q, 
(q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) cos, sin = self.rotary(T, x.device, q.dtype) q = apply_rotary(q, cos, sin) k = apply_rotary(k, cos, sin) q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads)) return self.proj(y.transpose(1, 2).contiguous().reshape(B, T, -1)) class MLP(nn.Module): def __init__(self, dim, mlp_mult): super().__init__() hidden = dim * mlp_mult self.fc = nn.Linear(dim, hidden, bias=False) self.proj = nn.Linear(hidden, dim, bias=False) def forward(self, x: Tensor) -> Tensor: return self.proj(torch.relu(self.fc(x)).square()) class ParallelBlock(nn.Module): """Parallel Residual Block: attn and MLP run on the same normalized input. x = resid_mix[0]*x + resid_mix[1]*x0 h = norm(x) x = x + attn_scale * attn(h) + mlp_scale * mlp(h) """ def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init): super().__init__() self.norm = RMSNorm() self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_mix = nn.Parameter(torch.stack([torch.ones(dim), torch.zeros(dim)]).float()) def forward(self, x: Tensor, x0: Tensor) -> Tensor: mix = self.resid_mix.to(x.dtype) x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 h = self.norm(x) # Parallel: both attn and MLP operate on same normalized input x = x + self.attn_scale.to(x.dtype)[None, None, :] * self.attn(h) \ + self.mlp_scale.to(x.dtype)[None, None, :] * self.mlp(h) return x # ───────────────────────────────────────────────────────────── # RECURRENT GPT MODEL with Score-First TTT # ───────────────────────────────────────────────────────────── class RecurrentGPT(nn.Module): """ K unique parallel-residual blocks x N recurrences. 
At eval: 2N recurrences + optional score-first TTT. """ def __init__(self, args: Hyperparameters): super().__init__() self.logit_softcap = args.logit_softcap self._train_rec = args.num_recurrences self._eval_rec = args.num_eval_recurrences or args.num_recurrences * 2 self._vocab_size = args.vocab_size self.tok_emb = nn.Embedding(args.vocab_size, args.model_dim) self.blocks = nn.ModuleList([ ParallelBlock(args.model_dim, args.num_heads, args.num_kv_heads, args.mlp_mult, args.rope_base, args.qk_gain_init) for _ in range(args.num_unique_layers) ]) self.final_norm = RMSNorm() nn.init.normal_(self.tok_emb.weight, std=0.005) def _forward_hidden(self, input_ids: Tensor) -> Tensor: x = F.rms_norm(self.tok_emb(input_ids), (self.tok_emb.embedding_dim,)) x0 = x n = self._train_rec if self.training else self._eval_rec for _ in range(n): for block in self.blocks: x = block(x, x0) return self.final_norm(x) def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: h = self._forward_hidden(input_ids) logits = F.linear(h.reshape(-1, h.size(-1)), self.tok_emb.weight) logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) return F.cross_entropy(logits.float(), target_ids.reshape(-1), reduction="mean") def per_token_loss(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: h = self._forward_hidden(input_ids) B, T, D = h.shape logits = F.linear(h.reshape(B * T, D), self.tok_emb.weight) logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) return F.cross_entropy(logits.float(), target_ids.reshape(B * T), reduction="none").reshape(B, T) @torch.no_grad() def per_token_loss_with_ttt(self, input_ids: Tensor, target_ids: Tensor, args: Hyperparameters) -> Tensor: """Score-first TTT: adapt MLP.proj weights chunk-by-chunk at eval. "Score-first" means: for each chunk, we first SCORE (compute loss) with current weights, then UPDATE weights for the next chunk. This is strictly causal -- predictions for chunk i only use information from chunks 0..i-1. 
We update MLP.proj.weight (the "down projection") in each block -- this is the "fast weight" in the In-Place TTT framework (arxiv:2604.06169). """ chunk_size = args.ttt_chunk_size ttt_lr = args.ttt_lr B, T = input_ids.shape # Determine which layers to apply TTT if args.ttt_layers == "all": ttt_layer_indices = list(range(len(self.blocks))) else: ttt_layer_indices = [int(x) for x in args.ttt_layers.split(",")] # Save original weights to restore after this sequence original_weights = {} for li in ttt_layer_indices: original_weights[li] = self.blocks[li].mlp.proj.weight.data.clone() all_ptl = [] n_chunks = (T + chunk_size - 1) // chunk_size for ci in range(n_chunks): lo = ci * chunk_size hi = min((ci + 1) * chunk_size, T) # Score first: full forward pass with current (possibly updated) weights h = self._forward_hidden(input_ids) # (B, T, D) h_chunk = h[:, lo:hi, :] y_chunk = target_ids[:, lo:hi] logits = F.linear(h_chunk.reshape(-1, h_chunk.size(-1)), self.tok_emb.weight) logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) ptl = F.cross_entropy(logits.float(), y_chunk.reshape(-1), reduction="none").reshape(B, hi - lo) all_ptl.append(ptl) # Then update: manual gradient step on MLP.proj for next chunk if ci < n_chunks - 1: for li in ttt_layer_indices: block = self.blocks[li] # Get MLP intermediate activations for this chunk h_norm = F.rms_norm(h_chunk.reshape(-1, h_chunk.size(-1)).float(), (h_chunk.size(-1),)) z = torch.relu(block.mlp.fc(h_norm.to(h_chunk.dtype))).square() # Reconstruction-based update: minimize ||Z @ W^T - h_norm||^2 pred = z @ block.mlp.proj.weight.T residual = pred - h_norm.to(pred.dtype) grad_w = residual.T @ z / z.size(0) block.mlp.proj.weight.data -= ttt_lr * grad_w.to(block.mlp.proj.weight.dtype) # Restore original weights after processing this sequence for li in ttt_layer_indices: self.blocks[li].mlp.proj.weight.data = original_weights[li] return torch.cat(all_ptl, dim=1) # 
# ─────────────────────────────────────────────────────────────
# EMA (Exponential Moving Average)
# ─────────────────────────────────────────────────────────────
class EMA:
    """Exponential Moving Average of model parameters."""

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.model = model
        self.decay = decay
        # shadow: EMA copy of every named parameter, initialized to current values
        self.shadow = {n: p.data.clone() for n, p in model.named_parameters()}
        self.backup = {}

    def update(self):
        """shadow ← decay*shadow + (1-decay)*param, for every parameter."""
        for n, p in self.model.named_parameters():
            self.shadow[n].mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)

    def apply(self):
        """Apply EMA weights (backup current)."""
        self.backup = {}
        for n, p in self.model.named_parameters():
            self.backup[n] = p.data.clone()
            p.data.copy_(self.shadow[n])

    def restore(self):
        """Restore original weights."""
        for n, p in self.model.named_parameters():
            p.data.copy_(self.backup[n])
        self.backup = {}


# ─────────────────────────────────────────────────────────────
# TRAINING
# ─────────────────────────────────────────────────────────────
def main():
    """End-to-end run: train, periodically validate, then final EMA+TTT
    evaluation, int6 quantization round-trip, and artifact-size accounting."""
    global zeropower_via_newtonschulz5
    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # -- distributed setup --
    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    # Keep the global batch fixed at 8 GPU-equivalents of work.
    grad_accum = max(1, 8 // world_size)
    grad_scale = 1.0 / grad_accum
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA required")
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group("nccl", device_id=device)
        dist.barrier()
    master = rank == 0
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    from torch.backends.cuda import (enable_flash_sdp, enable_math_sdp,
                                     enable_mem_efficient_sdp, enable_cudnn_sdp)
    enable_flash_sdp(True); enable_math_sdp(False)
    enable_mem_efficient_sdp(False); enable_cudnn_sdp(False)

    logfile = None
    if master:
        os.makedirs("logs", exist_ok=True)
        logfile = f"logs/{args.run_id}.txt"
        print(logfile)

    def log0(msg, console=True):
        # Master-only logger; safe to call unconditionally from any rank.
        if not master:
            return
        if console:
            print(msg)
        if logfile:
            with open(logfile, "a") as f:
                print(msg, file=f)

    log0(code, console=False)
    log0(f"Python {sys.version}", console=False)
    log0(f"PyTorch {torch.__version__}", console=False)
    try:
        log0(subprocess.run(["nvidia-smi"], capture_output=True, text=True,
                            check=False).stdout, console=False)
    except FileNotFoundError:
        pass
    random.seed(args.seed); np.random.seed(args.seed)
    torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed)

    # -- tokenizer + val data --
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    base_bytes_lut, has_space_lut, is_boundary_lut = build_sentencepiece_luts(
        sp, args.vocab_size, device)
    val_tokens = load_validation_tokens(args.val_files, args.sw_seq_len)
    log0(f"val_tokens:{val_tokens.numel()}")

    # -- model --
    base_model = RecurrentGPT(args).to(device).bfloat16()
    compiled = torch.compile(base_model, dynamic=False, fullgraph=True)
    model = DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) \
        if distributed else compiled
    n_unique = sum(p.numel() for p in base_model.parameters())
    eff_depth = args.num_unique_layers * args.num_recurrences
    log0(f"unique_params:{n_unique} effective_depth:{eff_depth} "
         f"train_loops:{args.num_recurrences} eval_loops:{base_model._eval_rec}")
    log0(f"world_size:{world_size} grad_accum:{grad_accum}")

    # -- optimizer --
    # Matrices go to Muon; vectors/scalars and control tensors go to Adam.
    block_params = list(base_model.blocks.named_parameters())
    matrix_params = [p for n, p in block_params
                     if p.ndim == 2 and not any(pat in n for pat in CONTROL_PATTERNS)]
    scalar_params = [p for n, p in block_params
                     if p.ndim < 2 or any(pat in n for pat in CONTROL_PATTERNS)]
    opt_tok = torch.optim.Adam(
        [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}],
        betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True)
    opt_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum,
                    backend_steps=args.muon_backend_steps,
                    weight_decay=args.muon_weight_decay)
    for g in opt_muon.param_groups:
        g["base_lr"] = args.matrix_lr
    opt_scalar = torch.optim.Adam(
        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
        betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True)
    optimizers = [opt_tok, opt_muon, opt_scalar]

    # -- EMA --
    ema = EMA(base_model, decay=0.999)
    ema_start_step = int(args.iterations * args.swa_start_frac)

    # -- LR schedule --
    max_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None

    def lr_mul(step, elapsed_ms):
        # Constant LR, then linear warmdown — either by iteration count
        # (no wallclock limit) or by projected remaining wallclock time.
        if args.warmdown_iters <= 0:
            return 1.0
        if max_ms is None:
            ws = max(args.iterations - args.warmdown_iters, 0)
            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \
                if ws <= step < args.iterations else 1.0
        step_ms = elapsed_ms / max(step, 1)
        remain = max(max_ms - elapsed_ms, 0.0)
        wd_ms = args.warmdown_iters * step_ms
        return remain / max(wd_ms, 1e-9) if remain <= wd_ms else 1.0

    def zero_all():
        [o.zero_grad(set_to_none=True) for o in optimizers]

    # -- warmup --
    # Run a few real steps to trigger compilation, then roll back model and
    # optimizer state so warmup leaves no trace in the actual run.
    if args.warmup_steps > 0:
        init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()}
        init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers]
        model.train()
        train_loader_w = DistributedTokenLoader(args.train_files, rank, world_size, device)
        for ws_i in range(args.warmup_steps):
            zero_all()
            for ms_i in range(grad_accum):
                if distributed:
                    model.require_backward_grad_sync = (ms_i == grad_accum - 1)
                x, y = train_loader_w.next_batch(args.train_batch_tokens,
                                                 args.train_seq_len, grad_accum)
                with torch.autocast("cuda", torch.bfloat16):
                    (model(x, y) * grad_scale).backward()
            for o in optimizers:
                o.step()
        zero_all()
        base_model.load_state_dict(init_model, strict=True)
        for o, s in zip(optimizers, init_opts):
            o.load_state_dict(s)
        zero_all()
        if distributed:
            model.require_backward_grad_sync = True

    # -- data + training loop --
    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
    training_ms = 0.0
    stop_step: int | None = None
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    step = 0
    while True:
        last_step = step == args.iterations or (stop_step is not None and step >= stop_step)
        do_val = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
        if do_val:
            torch.cuda.synchronize()
            training_ms += 1000.0 * (time.perf_counter() - t0)
            vl, vbpb = eval_val_sliding_window(
                args, model, rank, world_size, device, val_tokens,
                base_bytes_lut, has_space_lut, is_boundary_lut, use_ttt=False)
            log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} "
                 f"train_ms:{training_ms:.0f} step_avg:{training_ms/max(step,1):.2f}ms")
            torch.cuda.synchronize()
            t0 = time.perf_counter()
        if last_step:
            # FIX(review): the final EMA/TTT/quantized evaluations must run on
            # EVERY rank — eval_val_sliding_window() shards windows by rank and
            # calls dist.all_reduce(), so gating this whole section on `master`
            # (as before) deadlocks multi-GPU runs and drops non-rank-0 windows.
            # Only artifact serialization below is master-only; log0 self-guards.
            ema.apply()
            log0("Evaluating with EMA + TTT...")
            vl_ema, vbpb_ema = eval_val_sliding_window(
                args, base_model, rank, world_size, device, val_tokens,
                base_bytes_lut, has_space_lut, is_boundary_lut, use_ttt=True)
            log0(f"ema_ttt val_loss:{vl_ema:.4f} val_bpb:{vbpb_ema:.4f}")
            # Quantize and export (identical on all ranks; serialize on master)
            sd = base_model.state_dict()
            obj, stats = quantize_state_dict(sd, args.gptq_bits, args.sdclip_nstd)
            if master:
                buf = io.BytesIO()
                torch.save(obj, buf)
                compressed = zlib.compress(buf.getvalue(), level=9)
                code_bytes = len(code.encode())
                model_bytes = len(compressed)
                total_bytes = code_bytes + model_bytes
                log0(f"final_quant_zlib_roundtrip "
                     f"code_bytes:{code_bytes} "
                     f"model_compressed_bytes:{model_bytes} "
                     f"total_artifact_bytes:{total_bytes} "
                     f"total_artifact_mb:{total_bytes/1e6:.3f} "
                     f"param_count:{stats['param_count']}")
            # Round-trip verify: load the dequantized weights and re-score.
            sd2 = dequantize_state_dict(obj)
            base_model.load_state_dict(sd2, strict=True)
            vl2, vbpb2 = eval_val_sliding_window(
                args, base_model, rank, world_size, device, val_tokens,
                base_bytes_lut, has_space_lut, is_boundary_lut, use_ttt=True)
            log0(f"quantized_model+ttt val_loss:{vl2:.4f} val_bpb:{vbpb2:.4f}")
            # Restore non-EMA weights
            ema.restore()
            break
        # Wallclock budget: once exceeded, schedule one final step then stop.
        if stop_step is None and max_ms is not None:
            torch.cuda.synchronize()
            elapsed = 1000.0 * (time.perf_counter() - t0) + training_ms
            if elapsed >= max_ms:
                stop_step = step + 1
        zero_all()
        for ms_i in range(grad_accum):
            if distributed:
                # Only sync gradients on the last micro-step.
                model.require_backward_grad_sync = (ms_i == grad_accum - 1)
            x, y = train_loader.next_batch(args.train_batch_tokens,
                                           args.train_seq_len, grad_accum)
            with torch.autocast("cuda", torch.bfloat16):
                (model(x, y) * grad_scale).backward()
        torch.cuda.synchronize()
        elapsed_ms = 1000.0 * (time.perf_counter() - t0) + training_ms
        m = lr_mul(step, elapsed_ms)
        for o in optimizers:
            for g in o.param_groups:
                g["lr"] = g["base_lr"] * m
        for o in optimizers:
            o.step()
        # EMA update
        if step >= ema_start_step:
            ema.update()
        if step % args.train_log_every == 0 and master:
            log0(f"step:{step} lr_mul:{m:.4f}")
        step += 1


if __name__ == "__main__":
    main()