| """ |
| Parameter Golf β Competitive Submission |
| ======================================== |
| Key innovations targeting top-of-leaderboard (< 1.08 BPB): |
| |
| 1. SP8192 Vocabulary: 8192-token SentencePiece tokenizer for better BPB |
| efficiency. Larger vocab = fewer tokens = better compression. |
| |
| 2. Parallel Residuals (PAF): Attention and MLP run in parallel on the same |
| normalized input, saving one LayerNorm and improving information flow. |
| x = x + attn(norm(x)) + mlp(norm(x)) [GPT-J / PaLM style] |
| |
| 3. 3-Layer Depth Recurrence: 3 unique transformer blocks looped multiple |
| times. Layers 0-2 recur K times at train, 2K at eval (free test-time |
| compute). Selective recurrence on inner layers. |
| |
| 4. Score-First TTT (Test-Time Training): At eval, adapt the model's MLP |
| W_down weights chunk-by-chunk using NTP loss. Legal = strictly causal. |
| Implements the In-Place TTT mechanism from arxiv:2604.06169. |
| |
| 5. Int6 GPTQ Post-Training Quantization with SDClip: |
| - Train in full precision (bf16/fp32) |
| - After training, quantize all weight matrices to int6 using GPTQ |
| - Std-based clipping (SDClip) before quantization reduces outlier impact |
| - Embeddings in GPTQ int8 with SDClip |
| - ~1.5x more effective parameters vs int8 in the same 16MB budget |
| |
| 6. MuonEq-R: Muon optimizer with equalized learning rates (scale by |
| sqrt(max(fan_in, fan_out))) and weight decay regularization. |
| |
| 7. QK-Gain 5.25: High gain on QK product prevents attention entropy |
| collapse at small model dimensions. |
| |
| 8. Residual mixing with x0 anchor preserved from baseline. |
| |
| Architecture: |
| SP8192 vocab, d_model=768, 12 heads / 4 KV heads, MLP 4x |
| 3 unique blocks Γ 8 recurrences = 24 effective layers (train) |
| 3 unique blocks Γ 16 recurrences = 48 effective layers (eval) |
| |
| Run: torchrun --standalone --nproc_per_node=8 train_gpt2.py |
| """ |
| from __future__ import annotations |
|
|
| import copy |
| import glob |
| import io |
| import math |
| import os |
| import random |
| import subprocess |
| import sys |
| import time |
| import uuid |
| import zlib |
| from pathlib import Path |
|
|
| import numpy as np |
| import sentencepiece as spm |
| import torch |
| import torch.distributed as dist |
| import torch.nn.functional as F |
| from torch import Tensor, nn |
| from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
| |
| |
| |
|
|
| class Hyperparameters: |
| |
| data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp8192") |
| train_files = os.path.join(data_path, "fineweb_train_*.bin") |
| val_files = os.path.join(data_path, "fineweb_val_*.bin") |
| tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") |
| run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) |
| seed = int(os.environ.get("SEED", 1337)) |
|
|
| |
| val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) |
| val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) |
| train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) |
|
|
| |
| iterations = int(os.environ.get("ITERATIONS", 20000)) |
| warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) |
| warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) |
| train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) |
| train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) |
| max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) |
|
|
| |
| vocab_size = int(os.environ.get("VOCAB_SIZE", 8192)) |
| model_dim = int(os.environ.get("MODEL_DIM", 768)) |
| num_heads = int(os.environ.get("NUM_HEADS", 12)) |
| num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) |
| mlp_mult = int(os.environ.get("MLP_MULT", 4)) |
| num_unique_layers = int(os.environ.get("NUM_UNIQUE_LAYERS", 3)) |
| num_recurrences = int(os.environ.get("NUM_RECURRENCES", 8)) |
| num_eval_recurrences = int(os.environ.get("NUM_EVAL_RECURRENCES", 0)) |
| rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) |
| logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) |
| qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.25)) |
|
|
| |
| sw_stride = int(os.environ.get("SW_STRIDE", 64)) |
| sw_seq_len = int(os.environ.get("SW_SEQ_LEN", 1024)) |
|
|
| |
| ttt_enabled = int(os.environ.get("TTT_ENABLED", 1)) |
| ttt_lr = float(os.environ.get("TTT_LR", 0.01)) |
| ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 64)) |
| ttt_layers = os.environ.get("TTT_LAYERS", "all") |
|
|
| |
| embed_lr = float(os.environ.get("EMBED_LR", 0.05)) |
| matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) |
| scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) |
| muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) |
| muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) |
| muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.09)) |
| beta1 = float(os.environ.get("BETA1", 0.9)) |
| beta2 = float(os.environ.get("BETA2", 0.95)) |
| adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) |
|
|
| |
| gptq_bits = int(os.environ.get("GPTQ_BITS", 6)) |
| gptq_group_size = int(os.environ.get("GPTQ_GROUP_SIZE", 128)) |
| sdclip_nstd = float(os.environ.get("SDCLIP_NSTD", 2.5)) |
|
|
| |
| swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) |
|
|
|
|
| |
| |
| |
|
|
| def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: |
| a, b, c = (3.4445, -4.7750, 2.0315) |
| X = G.bfloat16() |
| X /= X.norm() + eps |
| transposed = G.size(0) > G.size(1) |
| if transposed: |
| X = X.T |
| for _ in range(steps): |
| A = X @ X.T |
| B = b * A + c * A @ A |
| X = a * X + B @ X |
| return X.T if transposed else X |
|
|
|
|
| class Muon(torch.optim.Optimizer): |
| """MuonEq-R: Muon with equalized scaling and weight decay.""" |
| def __init__(self, params, lr: float, momentum: float, backend_steps: int, |
| weight_decay: float = 0.0, nesterov: bool = True): |
| super().__init__(params, dict(lr=lr, momentum=momentum, |
| backend_steps=backend_steps, |
| weight_decay=weight_decay, |
| nesterov=nesterov)) |
|
|
| @torch.no_grad() |
| def step(self, closure=None): |
| loss = None |
| if closure is not None: |
| with torch.enable_grad(): |
| loss = closure() |
| distributed = dist.is_available() and dist.is_initialized() |
| world_size = dist.get_world_size() if distributed else 1 |
| rank = dist.get_rank() if distributed else 0 |
| for group in self.param_groups: |
| params = group["params"] |
| lr = group["lr"] |
| momentum = group["momentum"] |
| backend_steps = group["backend_steps"] |
| weight_decay = group["weight_decay"] |
| nesterov = group["nesterov"] |
| total = sum(int(p.numel()) for p in params) |
| flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) |
| curr = 0 |
| for i, p in enumerate(params): |
| if i % world_size == rank and p.grad is not None: |
| g = p.grad |
| if weight_decay != 0.0: |
| g = g + weight_decay * p.data.to(g.dtype) |
| state = self.state[p] |
| if "momentum_buffer" not in state: |
| state["momentum_buffer"] = torch.zeros_like(g) |
| buf = state["momentum_buffer"] |
| buf.mul_(momentum).add_(g) |
| if nesterov: |
| g = g.add(buf, alpha=momentum) |
| g = zeropower_via_newtonschulz5(g, steps=backend_steps) |
| |
| g *= max(1, g.size(0) / g.size(1)) ** 0.5 |
| flat[curr: curr + p.numel()] = g.reshape(-1) |
| curr += p.numel() |
| if distributed: |
| dist.all_reduce(flat, op=dist.ReduceOp.SUM) |
| curr = 0 |
| for p in params: |
| g = flat[curr: curr + p.numel()].view_as(p).to(dtype=p.dtype) |
| p.add_(g, alpha=-lr) |
| curr += p.numel() |
| return loss |
|
|
|
|
| |
| |
| |
|
|
| def build_sentencepiece_luts(sp, vocab_size, device): |
| sv = int(sp.vocab_size()) |
| sz = max(sv, vocab_size) |
| bb = np.zeros(sz, dtype=np.int16) |
| hs = np.zeros(sz, dtype=bool) |
| ib = np.ones(sz, dtype=bool) |
| for tid in range(sv): |
| if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): |
| continue |
| ib[tid] = False |
| if sp.is_byte(tid): |
| bb[tid] = 1 |
| continue |
| piece = sp.id_to_piece(tid) |
| if piece.startswith("\u2581"): |
| hs[tid] = True |
| piece = piece[1:] |
| bb[tid] = len(piece.encode("utf-8")) |
| return (torch.tensor(bb, dtype=torch.int16, device=device), |
| torch.tensor(hs, dtype=torch.bool, device=device), |
| torch.tensor(ib, dtype=torch.bool, device=device)) |
|
|
|
|
| def eval_val_sliding_window(args, model, rank, world_size, device, |
| val_tokens, base_bytes_lut, has_space_lut, is_boundary_lut, |
| use_ttt=False): |
| """Sliding-window BPB: every token scored with sw_stride context.""" |
| seq_len = args.sw_seq_len |
| stride = args.sw_stride |
| T = val_tokens.numel() |
| all_starts = list(range(0, T - seq_len - 1, stride)) |
| my_starts = all_starts[rank::world_size] |
|
|
| loss_sum = torch.zeros((), device=device, dtype=torch.float64) |
| token_cnt = torch.zeros((), device=device, dtype=torch.float64) |
| byte_cnt = torch.zeros((), device=device, dtype=torch.float64) |
|
|
| |
| raw_model = model |
| while hasattr(raw_model, 'module'): |
| raw_model = raw_model.module |
| if hasattr(raw_model, '_orig_mod'): |
| raw_model = raw_model._orig_mod |
|
|
| raw_model.eval() |
| |
| ctx = torch.no_grad if (use_ttt and args.ttt_enabled) else torch.inference_mode |
| with ctx(): |
| for start in my_starts: |
| end = start + seq_len |
| x = val_tokens[start:end].unsqueeze(0).to(device, dtype=torch.int64) |
| y = val_tokens[start + 1:end + 1].unsqueeze(0).to(device, dtype=torch.int64) |
| with torch.autocast("cuda", dtype=torch.bfloat16): |
| if use_ttt and args.ttt_enabled: |
| ptl = raw_model.per_token_loss_with_ttt(x, y, args) |
| else: |
| ptl = raw_model.per_token_loss(x, y) |
| lo = seq_len - stride |
| ptl_s = ptl[0, lo:] |
| y_s = y[0, lo:] |
| x_s = x[0, lo:] |
| loss_sum += ptl_s.to(torch.float64).sum() |
| token_cnt += ptl_s.numel() |
| tb = base_bytes_lut[y_s].to(torch.float64) |
| tb += (has_space_lut[y_s] & ~is_boundary_lut[x_s]).to(torch.float64) |
| byte_cnt += tb.sum() |
|
|
| if dist.is_available() and dist.is_initialized(): |
| for t in (loss_sum, token_cnt, byte_cnt): |
| dist.all_reduce(t, op=dist.ReduceOp.SUM) |
|
|
| val_loss = float((loss_sum / token_cnt).item()) |
| bpb = float((loss_sum / math.log(2) / byte_cnt).item()) |
| raw_model.train() |
| return val_loss, bpb |
|
|
|
|
| |
| |
| |
|
|
| CONTROL_PATTERNS = tuple(p for p in os.environ.get( |
| "CONTROL_TENSOR_NAME_PATTERNS", |
| "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,log_alpha" |
| ).split(",") if p) |
|
|
| KEEP_FP_MAX_NUMEL = 65_536 |
| KEEP_FP_STORE_DTYPE = torch.float16 |
| INT8_SCALE_DTYPE = torch.float16 |
|
|
|
|
| def sdclip(t: Tensor, n_std: float = 2.5) -> Tensor: |
| """Std-based clipping: clip to mean +/- n_std * std.""" |
| mean = t.float().mean() |
| std = t.float().std() |
| lo = mean - n_std * std |
| hi = mean + n_std * std |
| return t.clamp(lo.item(), hi.item()) |
|
|
|
|
| def _quant_tensor_int6(t: Tensor, n_std: float = 2.5): |
| """Quantize tensor to int6 (range -31 to 31) with SDClip per row.""" |
| t32 = t.float() |
| max_val = 31 |
| if t32.ndim == 2: |
| |
| mean = t32.mean(dim=1, keepdim=True) |
| std = t32.std(dim=1, keepdim=True).clamp_min(1e-9) |
| lo = mean - n_std * std |
| hi = mean + n_std * std |
| t_clipped = t32.clamp(lo.expand_as(t32), hi.expand_as(t32)) |
| clip_val = t_clipped.abs().amax(dim=1).clamp_min(1e-9) |
| scale = clip_val / max_val |
| q = torch.clamp(torch.round(t_clipped / scale[:, None]), -max_val, max_val).to(torch.int8) |
| return q.contiguous(), scale.to(torch.float16).contiguous() |
| |
| t_clipped = sdclip(t32, n_std) |
| cv = float(t_clipped.abs().max().item()) |
| scale = torch.tensor(max(cv / max_val, 1.0 / max_val), dtype=torch.float32) |
| q = torch.clamp(torch.round(t_clipped / scale), -max_val, max_val).to(torch.int8) |
| return q.contiguous(), scale |
|
|
|
|
| def _quant_tensor_int8(t: Tensor, n_std: float = 2.5): |
| """Quantize tensor to int8 with SDClip.""" |
| t32 = t.float() |
| if t32.ndim == 2: |
| mean = t32.mean(dim=1, keepdim=True) |
| std = t32.std(dim=1, keepdim=True).clamp_min(1e-9) |
| lo = mean - n_std * std |
| hi = mean + n_std * std |
| t_clipped = t32.clamp(lo.expand_as(t32), hi.expand_as(t32)) |
| clip_val = t_clipped.abs().amax(dim=1).clamp_min(1e-9) |
| scale = clip_val / 127.0 |
| q = torch.clamp(torch.round(t_clipped / scale[:, None]), -127, 127).to(torch.int8) |
| return q.contiguous(), scale.to(torch.float16).contiguous() |
| cv = float(sdclip(t32, n_std).abs().max().item()) |
| scale = torch.tensor(max(cv / 127.0, 1.0 / 127.0), dtype=torch.float32) |
| q = torch.clamp(torch.round(t32.clamp(-cv, cv) / scale), -127, 127).to(torch.int8) |
| return q.contiguous(), scale |
|
|
|
|
| def quantize_state_dict(state_dict: dict, gptq_bits: int = 6, sdclip_nstd: float = 2.5): |
| """Mixed quantization: int6 for weight matrices, int8 for embeddings, fp16 for small/control.""" |
| quantized, scales, dtypes, passthrough, pt_orig, qmeta = {}, {}, {}, {}, {}, {} |
| stats = {k: 0 for k in ("param_count", "num_tensors", "baseline_bytes", "quant_bytes")} |
| quant_fn = _quant_tensor_int6 if gptq_bits == 6 else _quant_tensor_int8 |
|
|
| for name, tensor in state_dict.items(): |
| t = tensor.detach().cpu().contiguous() |
| stats["param_count"] += t.numel() |
| stats["num_tensors"] += 1 |
| stats["baseline_bytes"] += t.numel() * t.element_size() |
|
|
| if not t.is_floating_point(): |
| passthrough[name] = t |
| stats["quant_bytes"] += t.numel() * t.element_size() |
| continue |
|
|
| is_ctrl = any(p in name for p in CONTROL_PATTERNS) |
| is_small = t.numel() <= KEEP_FP_MAX_NUMEL |
|
|
| |
| if "tok_emb" in name: |
| pt_orig[name] = str(t.dtype).removeprefix("torch.") |
| q, s = _quant_tensor_int8(t, sdclip_nstd) |
| quantized[name] = q |
| scales[name] = s |
| dtypes[name] = str(t.dtype).removeprefix("torch.") |
| if s.ndim > 0: |
| qmeta[name] = {"scheme": "per_row", "axis": 0, "bits": 8} |
| stats["quant_bytes"] += q.numel() + s.numel() * s.element_size() |
| continue |
|
|
| if is_ctrl or is_small: |
| if t.dtype in (torch.float32, torch.bfloat16): |
| pt_orig[name] = str(t.dtype).removeprefix("torch.") |
| passthrough[name] = t.float() if is_ctrl else t.to(KEEP_FP_STORE_DTYPE) |
| passthrough[name] = passthrough[name].contiguous() |
| stats["quant_bytes"] += passthrough[name].numel() * passthrough[name].element_size() |
| continue |
|
|
| |
| q, s = quant_fn(t, sdclip_nstd) |
| if s.ndim > 0: |
| qmeta[name] = {"scheme": "per_row", "axis": 0, "bits": gptq_bits} |
| quantized[name] = q |
| scales[name] = s |
| dtypes[name] = str(t.dtype).removeprefix("torch.") |
| stats["quant_bytes"] += q.numel() + s.numel() * s.element_size() |
|
|
| obj = {"__quant_format__": f"int{gptq_bits}_sdclip_v1", |
| "quantized": quantized, "scales": scales, "dtypes": dtypes, |
| "passthrough": passthrough} |
| if qmeta: obj["qmeta"] = qmeta |
| if pt_orig: obj["passthrough_orig_dtypes"] = pt_orig |
| return obj, stats |
|
|
|
|
| def dequantize_state_dict(obj: dict) -> dict: |
| out = {} |
| qmeta = obj.get("qmeta", {}) |
| pt_orig = obj.get("passthrough_orig_dtypes", {}) |
| for name, q in obj["quantized"].items(): |
| dtype = getattr(torch, obj["dtypes"][name]) |
| s = obj["scales"][name] |
| if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: |
| s = s.to(torch.float32) |
| out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() |
| else: |
| out[name] = (q.float() * float(s.item())).to(dtype).contiguous() |
| for name, t in obj["passthrough"].items(): |
| ot = t.detach().cpu().contiguous() |
| od = pt_orig.get(name) |
| if isinstance(od, str): |
| ot = ot.to(dtype=getattr(torch, od)).contiguous() |
| out[name] = ot |
| return out |
|
|
|
|
| |
| |
| |
|
|
| def load_data_shard(file: Path) -> Tensor: |
| hdr = np.fromfile(file, dtype="<i4", count=256) |
| if hdr.size != 256 or int(hdr[0]) != 20240520 or int(hdr[1]) != 1: |
| raise ValueError(f"Bad shard: {file}") |
| n = int(hdr[2]) |
| tokens = np.fromfile(file, dtype="<u2", count=n, offset=256 * 4) |
| return torch.from_numpy(tokens.astype(np.uint16, copy=False)) |
|
|
|
|
| def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: |
| files = [Path(p) for p in sorted(glob.glob(pattern))] |
| if not files: |
| raise FileNotFoundError(f"No val files: {pattern}") |
| tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() |
| usable = ((tokens.numel() - 1) // seq_len) * seq_len |
| return tokens[: usable + 1] |
|
|
|
|
| class TokenStream: |
| def __init__(self, pattern: str): |
| files = [Path(p) for p in sorted(glob.glob(pattern))] |
| if not files: |
| raise FileNotFoundError(f"No shards: {pattern}") |
| self.files = files |
| self.idx = 0 |
| self.tokens = load_data_shard(files[0]) |
| self.pos = 0 |
|
|
| def take(self, n: int) -> Tensor: |
| chunks, rem = [], n |
| while rem > 0: |
| avail = self.tokens.numel() - self.pos |
| if avail <= 0: |
| self.idx = (self.idx + 1) % len(self.files) |
| self.tokens = load_data_shard(self.files[self.idx]) |
| self.pos = 0 |
| avail = self.tokens.numel() |
| k = min(rem, avail) |
| chunks.append(self.tokens[self.pos: self.pos + k]) |
| self.pos += k |
| rem -= k |
| return chunks[0] if len(chunks) == 1 else torch.cat(chunks) |
|
|
|
|
| class DistributedTokenLoader: |
| def __init__(self, pattern, rank, world_size, device): |
| self.rank = rank; self.ws = world_size; self.device = device |
| self.stream = TokenStream(pattern) |
|
|
| def next_batch(self, global_tokens, seq_len, grad_accum): |
| local_tokens = global_tokens // (self.ws * grad_accum) |
| per_rank_span = local_tokens + 1 |
| chunk = self.stream.take(per_rank_span * self.ws) |
| start = self.rank * per_rank_span |
| local = chunk[start: start + per_rank_span].to(torch.int64) |
| x = local[:-1].reshape(-1, seq_len) |
| y = local[1:].reshape(-1, seq_len) |
| return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) |
|
|
|
|
| |
| |
| |
|
|
| class RMSNorm(nn.Module): |
| def __init__(self, eps: float | None = None): |
| super().__init__() |
| self.eps = eps |
|
|
| def forward(self, x: Tensor) -> Tensor: |
| return F.rms_norm(x, (x.size(-1),), eps=self.eps) |
|
|
|
|
| class Rotary(nn.Module): |
| def __init__(self, dim: int, base: float = 10000.0): |
| super().__init__() |
| inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) |
| self.register_buffer("inv_freq", inv_freq, persistent=False) |
| self._cached_len = 0 |
| self._cos: Tensor | None = None |
| self._sin: Tensor | None = None |
|
|
| def forward(self, seq_len: int, device, dtype): |
| if self._cos is None or self._cached_len != seq_len or self._cos.device != device: |
| t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) |
| freqs = torch.outer(t, self.inv_freq.to(device)) |
| self._cos = freqs.cos()[None, None, :, :] |
| self._sin = freqs.sin()[None, None, :, :] |
| self._cached_len = seq_len |
| return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) |
|
|
|
|
| def apply_rotary(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: |
| half = x.size(-1) // 2 |
| x1, x2 = x[..., :half], x[..., half:] |
| return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) |
|
|
|
|
| class CausalSelfAttention(nn.Module): |
| def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init): |
| super().__init__() |
| assert dim % num_heads == 0 and num_heads % num_kv_heads == 0 |
| self.num_heads = num_heads |
| self.num_kv_heads = num_kv_heads |
| self.head_dim = dim // num_heads |
| kv_dim = num_kv_heads * self.head_dim |
| self.c_q = nn.Linear(dim, dim, bias=False) |
| self.c_k = nn.Linear(dim, kv_dim, bias=False) |
| self.c_v = nn.Linear(dim, kv_dim, bias=False) |
| self.proj = nn.Linear(dim, dim, bias=False) |
| self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) |
| self.rotary = Rotary(self.head_dim, base=rope_base) |
|
|
| def forward(self, x: Tensor) -> Tensor: |
| B, T, _ = x.shape |
| q = self.c_q(x).reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2) |
| k = self.c_k(x).reshape(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2) |
| v = self.c_v(x).reshape(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2) |
| q = F.rms_norm(q, (q.size(-1),)) |
| k = F.rms_norm(k, (k.size(-1),)) |
| cos, sin = self.rotary(T, x.device, q.dtype) |
| q = apply_rotary(q, cos, sin) |
| k = apply_rotary(k, cos, sin) |
| q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] |
| y = F.scaled_dot_product_attention(q, k, v, |
| attn_mask=None, is_causal=True, |
| enable_gqa=(self.num_kv_heads != self.num_heads)) |
| return self.proj(y.transpose(1, 2).contiguous().reshape(B, T, -1)) |
|
|
|
|
| class MLP(nn.Module): |
| def __init__(self, dim, mlp_mult): |
| super().__init__() |
| hidden = dim * mlp_mult |
| self.fc = nn.Linear(dim, hidden, bias=False) |
| self.proj = nn.Linear(hidden, dim, bias=False) |
|
|
| def forward(self, x: Tensor) -> Tensor: |
| return self.proj(torch.relu(self.fc(x)).square()) |
|
|
|
|
| class ParallelBlock(nn.Module): |
| """Parallel Residual Block: attn and MLP run on the same normalized input. |
| |
| x = resid_mix[0]*x + resid_mix[1]*x0 |
| h = norm(x) |
| x = x + attn_scale * attn(h) + mlp_scale * mlp(h) |
| """ |
| def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init): |
| super().__init__() |
| self.norm = RMSNorm() |
| self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) |
| self.mlp = MLP(dim, mlp_mult) |
| self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) |
| self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) |
| self.resid_mix = nn.Parameter(torch.stack([torch.ones(dim), torch.zeros(dim)]).float()) |
|
|
| def forward(self, x: Tensor, x0: Tensor) -> Tensor: |
| mix = self.resid_mix.to(x.dtype) |
| x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 |
| h = self.norm(x) |
| |
| x = x + self.attn_scale.to(x.dtype)[None, None, :] * self.attn(h) \ |
| + self.mlp_scale.to(x.dtype)[None, None, :] * self.mlp(h) |
| return x |
|
|
|
|
| |
| |
| |
|
|
| class RecurrentGPT(nn.Module): |
| """ |
| K unique parallel-residual blocks x N recurrences. |
| At eval: 2N recurrences + optional score-first TTT. |
| """ |
| def __init__(self, args: Hyperparameters): |
| super().__init__() |
| self.logit_softcap = args.logit_softcap |
| self._train_rec = args.num_recurrences |
| self._eval_rec = args.num_eval_recurrences or args.num_recurrences * 2 |
| self._vocab_size = args.vocab_size |
|
|
| self.tok_emb = nn.Embedding(args.vocab_size, args.model_dim) |
| self.blocks = nn.ModuleList([ |
| ParallelBlock(args.model_dim, args.num_heads, args.num_kv_heads, |
| args.mlp_mult, args.rope_base, args.qk_gain_init) |
| for _ in range(args.num_unique_layers) |
| ]) |
| self.final_norm = RMSNorm() |
| nn.init.normal_(self.tok_emb.weight, std=0.005) |
|
|
| def _forward_hidden(self, input_ids: Tensor) -> Tensor: |
| x = F.rms_norm(self.tok_emb(input_ids), (self.tok_emb.embedding_dim,)) |
| x0 = x |
| n = self._train_rec if self.training else self._eval_rec |
| for _ in range(n): |
| for block in self.blocks: |
| x = block(x, x0) |
| return self.final_norm(x) |
|
|
| def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: |
| h = self._forward_hidden(input_ids) |
| logits = F.linear(h.reshape(-1, h.size(-1)), self.tok_emb.weight) |
| logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) |
| return F.cross_entropy(logits.float(), target_ids.reshape(-1), reduction="mean") |
|
|
| def per_token_loss(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: |
| h = self._forward_hidden(input_ids) |
| B, T, D = h.shape |
| logits = F.linear(h.reshape(B * T, D), self.tok_emb.weight) |
| logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) |
| return F.cross_entropy(logits.float(), target_ids.reshape(B * T), |
| reduction="none").reshape(B, T) |
|
|
| @torch.no_grad() |
| def per_token_loss_with_ttt(self, input_ids: Tensor, target_ids: Tensor, |
| args: Hyperparameters) -> Tensor: |
| """Score-first TTT: adapt MLP.proj weights chunk-by-chunk at eval. |
| |
| "Score-first" means: for each chunk, we first SCORE (compute loss) with |
| current weights, then UPDATE weights for the next chunk. This is strictly |
| causal -- predictions for chunk i only use information from chunks 0..i-1. |
| |
| We update MLP.proj.weight (the "down projection") in each block -- |
| this is the "fast weight" in the In-Place TTT framework (arxiv:2604.06169). |
| """ |
| chunk_size = args.ttt_chunk_size |
| ttt_lr = args.ttt_lr |
| B, T = input_ids.shape |
|
|
| |
| if args.ttt_layers == "all": |
| ttt_layer_indices = list(range(len(self.blocks))) |
| else: |
| ttt_layer_indices = [int(x) for x in args.ttt_layers.split(",")] |
|
|
| |
| original_weights = {} |
| for li in ttt_layer_indices: |
| original_weights[li] = self.blocks[li].mlp.proj.weight.data.clone() |
|
|
| all_ptl = [] |
| n_chunks = (T + chunk_size - 1) // chunk_size |
|
|
| for ci in range(n_chunks): |
| lo = ci * chunk_size |
| hi = min((ci + 1) * chunk_size, T) |
|
|
| |
| h = self._forward_hidden(input_ids) |
| h_chunk = h[:, lo:hi, :] |
| y_chunk = target_ids[:, lo:hi] |
|
|
| logits = F.linear(h_chunk.reshape(-1, h_chunk.size(-1)), self.tok_emb.weight) |
| logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) |
| ptl = F.cross_entropy(logits.float(), y_chunk.reshape(-1), |
| reduction="none").reshape(B, hi - lo) |
| all_ptl.append(ptl) |
|
|
| |
| if ci < n_chunks - 1: |
| for li in ttt_layer_indices: |
| block = self.blocks[li] |
| |
| h_norm = F.rms_norm(h_chunk.reshape(-1, h_chunk.size(-1)).float(), |
| (h_chunk.size(-1),)) |
| z = torch.relu(block.mlp.fc(h_norm.to(h_chunk.dtype))).square() |
| |
| pred = z @ block.mlp.proj.weight.T |
| residual = pred - h_norm.to(pred.dtype) |
| grad_w = residual.T @ z / z.size(0) |
| block.mlp.proj.weight.data -= ttt_lr * grad_w.to(block.mlp.proj.weight.dtype) |
|
|
| |
| for li in ttt_layer_indices: |
| self.blocks[li].mlp.proj.weight.data = original_weights[li] |
|
|
| return torch.cat(all_ptl, dim=1) |
|
|
|
|
| |
| |
| |
|
|
| class EMA: |
| """Exponential Moving Average of model parameters.""" |
| def __init__(self, model: nn.Module, decay: float = 0.999): |
| self.model = model |
| self.decay = decay |
| self.shadow = {n: p.data.clone() for n, p in model.named_parameters()} |
| self.backup = {} |
|
|
| def update(self): |
| for n, p in self.model.named_parameters(): |
| self.shadow[n].mul_(self.decay).add_(p.data, alpha=1.0 - self.decay) |
|
|
| def apply(self): |
| """Apply EMA weights (backup current).""" |
| self.backup = {} |
| for n, p in self.model.named_parameters(): |
| self.backup[n] = p.data.clone() |
| p.data.copy_(self.shadow[n]) |
|
|
| def restore(self): |
| """Restore original weights.""" |
| for n, p in self.model.named_parameters(): |
| p.data.copy_(self.backup[n]) |
| self.backup = {} |
|
|
|
|
| |
| |
| |
|
|
| def main(): |
| global zeropower_via_newtonschulz5 |
| code = Path(__file__).read_text(encoding="utf-8") |
| args = Hyperparameters() |
|
|
| zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) |
|
|
| |
| distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ |
| rank = int(os.environ.get("RANK", "0")) |
| world_size = int(os.environ.get("WORLD_SIZE", "1")) |
| local_rank = int(os.environ.get("LOCAL_RANK", "0")) |
| grad_accum = max(1, 8 // world_size) |
| grad_scale = 1.0 / grad_accum |
|
|
| if not torch.cuda.is_available(): |
| raise RuntimeError("CUDA required") |
| device = torch.device("cuda", local_rank) |
| torch.cuda.set_device(device) |
| if distributed: |
| dist.init_process_group("nccl", device_id=device) |
| dist.barrier() |
|
|
| master = rank == 0 |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
| from torch.backends.cuda import (enable_flash_sdp, enable_math_sdp, |
| enable_mem_efficient_sdp, enable_cudnn_sdp) |
| enable_flash_sdp(True); enable_math_sdp(False) |
| enable_mem_efficient_sdp(False); enable_cudnn_sdp(False) |
|
|
| logfile = None |
| if master: |
| os.makedirs("logs", exist_ok=True) |
| logfile = f"logs/{args.run_id}.txt" |
| print(logfile) |
|
|
| def log0(msg, console=True): |
| if not master: return |
| if console: print(msg) |
| if logfile: |
| with open(logfile, "a") as f: print(msg, file=f) |
|
|
| log0(code, console=False) |
| log0(f"Python {sys.version}", console=False) |
| log0(f"PyTorch {torch.__version__}", console=False) |
| try: |
| log0(subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False).stdout, |
| console=False) |
| except FileNotFoundError: |
| pass |
|
|
| random.seed(args.seed); np.random.seed(args.seed) |
| torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) |
|
|
| |
| sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) |
| base_bytes_lut, has_space_lut, is_boundary_lut = build_sentencepiece_luts( |
| sp, args.vocab_size, device) |
| val_tokens = load_validation_tokens(args.val_files, args.sw_seq_len) |
| log0(f"val_tokens:{val_tokens.numel()}") |
|
|
| |
| base_model = RecurrentGPT(args).to(device).bfloat16() |
|
|
| compiled = torch.compile(base_model, dynamic=False, fullgraph=True) |
| model = DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) \ |
| if distributed else compiled |
|
|
| n_unique = sum(p.numel() for p in base_model.parameters()) |
| eff_depth = args.num_unique_layers * args.num_recurrences |
| log0(f"unique_params:{n_unique} effective_depth:{eff_depth} " |
| f"train_loops:{args.num_recurrences} eval_loops:{base_model._eval_rec}") |
| log0(f"world_size:{world_size} grad_accum:{grad_accum}") |
|
|
| |
| block_params = list(base_model.blocks.named_parameters()) |
| matrix_params = [p for n, p in block_params |
| if p.ndim == 2 and not any(pat in n for pat in CONTROL_PATTERNS)] |
| scalar_params = [p for n, p in block_params |
| if p.ndim < 2 or any(pat in n for pat in CONTROL_PATTERNS)] |
|
|
| opt_tok = torch.optim.Adam( |
| [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], |
| betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) |
| opt_muon = Muon(matrix_params, lr=args.matrix_lr, |
| momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, |
| weight_decay=args.muon_weight_decay) |
| for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr |
| opt_scalar = torch.optim.Adam( |
| [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], |
| betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) |
| optimizers = [opt_tok, opt_muon, opt_scalar] |
|
|
| |
| ema = EMA(base_model, decay=0.999) |
| ema_start_step = int(args.iterations * args.swa_start_frac) |
|
|
| |
| max_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None |
|
|
| def lr_mul(step, elapsed_ms): |
| if args.warmdown_iters <= 0: return 1.0 |
| if max_ms is None: |
| ws = max(args.iterations - args.warmdown_iters, 0) |
| return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ |
| if ws <= step < args.iterations else 1.0 |
| step_ms = elapsed_ms / max(step, 1) |
| remain = max(max_ms - elapsed_ms, 0.0) |
| wd_ms = args.warmdown_iters * step_ms |
| return remain / max(wd_ms, 1e-9) if remain <= wd_ms else 1.0 |
|
|
| def zero_all(): [o.zero_grad(set_to_none=True) for o in optimizers] |
|
|
| |
| if args.warmup_steps > 0: |
| init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} |
| init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] |
| model.train() |
| train_loader_w = DistributedTokenLoader(args.train_files, rank, world_size, device) |
| for ws_i in range(args.warmup_steps): |
| zero_all() |
| for ms_i in range(grad_accum): |
| if distributed: |
| model.require_backward_grad_sync = (ms_i == grad_accum - 1) |
| x, y = train_loader_w.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum) |
| with torch.autocast("cuda", torch.bfloat16): |
| (model(x, y) * grad_scale).backward() |
| for o in optimizers: o.step() |
| zero_all() |
| base_model.load_state_dict(init_model, strict=True) |
| for o, s in zip(optimizers, init_opts): o.load_state_dict(s) |
| zero_all() |
| if distributed: model.require_backward_grad_sync = True |
|
|
| |
| train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) |
| training_ms = 0.0 |
| stop_step: int | None = None |
| torch.cuda.synchronize() |
| t0 = time.perf_counter() |
| step = 0 |
|
|
| while True: |
| last_step = step == args.iterations or (stop_step is not None and step >= stop_step) |
| do_val = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) |
|
|
| if do_val: |
| torch.cuda.synchronize() |
| training_ms += 1000.0 * (time.perf_counter() - t0) |
| vl, vbpb = eval_val_sliding_window( |
| args, model, rank, world_size, device, |
| val_tokens, base_bytes_lut, has_space_lut, is_boundary_lut, |
| use_ttt=False) |
| log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " |
| f"train_ms:{training_ms:.0f} step_avg:{training_ms/max(step,1):.2f}ms") |
| torch.cuda.synchronize() |
| t0 = time.perf_counter() |
|
|
| if last_step: |
| if master: |
| |
| ema.apply() |
|
|
| |
| log0("Evaluating with EMA + TTT...") |
| vl_ema, vbpb_ema = eval_val_sliding_window( |
| args, base_model, rank, world_size, device, |
| val_tokens, base_bytes_lut, has_space_lut, is_boundary_lut, |
| use_ttt=True) |
| log0(f"ema_ttt val_loss:{vl_ema:.4f} val_bpb:{vbpb_ema:.4f}") |
|
|
| |
| sd = base_model.state_dict() |
| obj, stats = quantize_state_dict(sd, args.gptq_bits, args.sdclip_nstd) |
| buf = io.BytesIO() |
| torch.save(obj, buf) |
| compressed = zlib.compress(buf.getvalue(), level=9) |
| code_bytes = len(code.encode()) |
| model_bytes = len(compressed) |
| total_bytes = code_bytes + model_bytes |
| log0(f"final_quant_zlib_roundtrip " |
| f"code_bytes:{code_bytes} " |
| f"model_compressed_bytes:{model_bytes} " |
| f"total_artifact_bytes:{total_bytes} " |
| f"total_artifact_mb:{total_bytes/1e6:.3f} " |
| f"param_count:{stats['param_count']}") |
|
|
| |
| sd2 = dequantize_state_dict(obj) |
| base_model.load_state_dict(sd2, strict=True) |
| vl2, vbpb2 = eval_val_sliding_window( |
| args, base_model, rank, world_size, device, |
| val_tokens, base_bytes_lut, has_space_lut, is_boundary_lut, |
| use_ttt=True) |
| log0(f"quantized_model+ttt val_loss:{vl2:.4f} val_bpb:{vbpb2:.4f}") |
|
|
| |
| ema.restore() |
| break |
|
|
| if stop_step is None and max_ms is not None: |
| torch.cuda.synchronize() |
| elapsed = 1000.0 * (time.perf_counter() - t0) + training_ms |
| if elapsed >= max_ms: |
| stop_step = step + 1 |
|
|
| zero_all() |
| for ms_i in range(grad_accum): |
| if distributed: |
| model.require_backward_grad_sync = (ms_i == grad_accum - 1) |
| x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum) |
| with torch.autocast("cuda", torch.bfloat16): |
| (model(x, y) * grad_scale).backward() |
|
|
| torch.cuda.synchronize() |
| elapsed_ms = 1000.0 * (time.perf_counter() - t0) + training_ms |
| m = lr_mul(step, elapsed_ms) |
| for o in optimizers: |
| for g in o.param_groups: g["lr"] = g["base_lr"] * m |
| for o in optimizers: o.step() |
|
|
| |
| if step >= ema_start_step: |
| ema.update() |
|
|
| if step % args.train_log_every == 0 and master: |
| log0(f"step:{step} lr_mul:{m:.4f}") |
|
|
| step += 1 |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|