parameter-golf-v2 / train_gpt2.py
rtferraz's picture
Add full readable training script
680a32f verified
"""
Parameter Golf β€” Competitive Submission
========================================
Key innovations targeting top-of-leaderboard (< 1.08 BPB):
1. SP8192 Vocabulary: 8192-token SentencePiece tokenizer for better BPB
efficiency. Larger vocab = fewer tokens = better compression.
2. Parallel Residuals (PAF): Attention and MLP run in parallel on the same
normalized input, saving one LayerNorm and improving information flow.
x = x + attn(norm(x)) + mlp(norm(x)) [GPT-J / PaLM style]
3. 3-Layer Depth Recurrence: 3 unique transformer blocks looped multiple
times. Layers 0-2 recur K times at train, 2K at eval (free test-time
compute). Selective recurrence on inner layers.
4. Score-First TTT (Test-Time Training): At eval, adapt the model's MLP
W_down weights chunk-by-chunk using NTP loss. Legal = strictly causal.
Implements the In-Place TTT mechanism from arxiv:2604.06169.
5. Int6 GPTQ Post-Training Quantization with SDClip:
- Train in full precision (bf16/fp32)
- After training, quantize all weight matrices to int6 using GPTQ
- Std-based clipping (SDClip) before quantization reduces outlier impact
- Embeddings in GPTQ int8 with SDClip
- ~1.5x more effective parameters vs int8 in the same 16MB budget
6. MuonEq-R: Muon optimizer with equalized learning rates (scale by
sqrt(max(fan_in, fan_out))) and weight decay regularization.
7. QK-Gain 5.25: High gain on QK product prevents attention entropy
collapse at small model dimensions.
8. Residual mixing with x0 anchor preserved from baseline.
Architecture:
SP8192 vocab, d_model=768, 12 heads / 4 KV heads, MLP 4x
3 unique blocks Γ— 8 recurrences = 24 effective layers (train)
3 unique blocks Γ— 16 recurrences = 48 effective layers (eval)
Run: torchrun --standalone --nproc_per_node=8 train_gpt2.py
"""
from __future__ import annotations
import copy
import glob
import io
import math
import os
import random
import subprocess
import sys
import time
import uuid
import zlib
from pathlib import Path
import numpy as np
import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.parallel import DistributedDataParallel as DDP
# ─────────────────────────────────────────────────────────────
# HYPERPARAMETERS
# ─────────────────────────────────────────────────────────────
class Hyperparameters:
# Data paths β€” SP8192 tokenizer and matching data
data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp8192")
train_files = os.path.join(data_path, "fineweb_train_*.bin")
val_files = os.path.join(data_path, "fineweb_val_*.bin")
tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model")
run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
seed = int(os.environ.get("SEED", 1337))
# Validation
val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200))
# Training
iterations = int(os.environ.get("ITERATIONS", 20000))
warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500))
warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))
train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024))
max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
# Model β€” Parallel Residual Recurrent
vocab_size = int(os.environ.get("VOCAB_SIZE", 8192))
model_dim = int(os.environ.get("MODEL_DIM", 768))
num_heads = int(os.environ.get("NUM_HEADS", 12))
num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
mlp_mult = int(os.environ.get("MLP_MULT", 4))
num_unique_layers = int(os.environ.get("NUM_UNIQUE_LAYERS", 3))
num_recurrences = int(os.environ.get("NUM_RECURRENCES", 8))
num_eval_recurrences = int(os.environ.get("NUM_EVAL_RECURRENCES", 0)) # 0 = auto (2Γ—)
rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))
qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.25))
# Sliding window eval
sw_stride = int(os.environ.get("SW_STRIDE", 64))
sw_seq_len = int(os.environ.get("SW_SEQ_LEN", 1024))
# Test-Time Training (TTT)
ttt_enabled = int(os.environ.get("TTT_ENABLED", 1)) # 1 = enable at eval
ttt_lr = float(os.environ.get("TTT_LR", 0.01))
ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 64))
ttt_layers = os.environ.get("TTT_LAYERS", "all") # "all" or comma-sep indices
# Optimizer
embed_lr = float(os.environ.get("EMBED_LR", 0.05))
matrix_lr = float(os.environ.get("MATRIX_LR", 0.04))
scalar_lr = float(os.environ.get("SCALAR_LR", 0.04))
muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.09))
beta1 = float(os.environ.get("BETA1", 0.9))
beta2 = float(os.environ.get("BETA2", 0.95))
adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
# GPTQ quantization config
gptq_bits = int(os.environ.get("GPTQ_BITS", 6))
gptq_group_size = int(os.environ.get("GPTQ_GROUP_SIZE", 128))
sdclip_nstd = float(os.environ.get("SDCLIP_NSTD", 2.5))
# SWA/EMA
swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4))
# ─────────────────────────────────────────────────────────────
# MUON OPTIMIZER (MuonEq-R variant)
# ─────────────────────────────────────────────────────────────
def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
X /= X.norm() + eps
transposed = G.size(0) > G.size(1)
if transposed:
X = X.T
for _ in range(steps):
A = X @ X.T
B = b * A + c * A @ A
X = a * X + B @ X
return X.T if transposed else X
class Muon(torch.optim.Optimizer):
"""MuonEq-R: Muon with equalized scaling and weight decay."""
def __init__(self, params, lr: float, momentum: float, backend_steps: int,
weight_decay: float = 0.0, nesterov: bool = True):
super().__init__(params, dict(lr=lr, momentum=momentum,
backend_steps=backend_steps,
weight_decay=weight_decay,
nesterov=nesterov))
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
distributed = dist.is_available() and dist.is_initialized()
world_size = dist.get_world_size() if distributed else 1
rank = dist.get_rank() if distributed else 0
for group in self.param_groups:
params = group["params"]
lr = group["lr"]
momentum = group["momentum"]
backend_steps = group["backend_steps"]
weight_decay = group["weight_decay"]
nesterov = group["nesterov"]
total = sum(int(p.numel()) for p in params)
flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16)
curr = 0
for i, p in enumerate(params):
if i % world_size == rank and p.grad is not None:
g = p.grad
if weight_decay != 0.0:
g = g + weight_decay * p.data.to(g.dtype)
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(g)
buf = state["momentum_buffer"]
buf.mul_(momentum).add_(g)
if nesterov:
g = g.add(buf, alpha=momentum)
g = zeropower_via_newtonschulz5(g, steps=backend_steps)
# MuonEq-R: scale by sqrt(max(fan_in, fan_out))
g *= max(1, g.size(0) / g.size(1)) ** 0.5
flat[curr: curr + p.numel()] = g.reshape(-1)
curr += p.numel()
if distributed:
dist.all_reduce(flat, op=dist.ReduceOp.SUM)
curr = 0
for p in params:
g = flat[curr: curr + p.numel()].view_as(p).to(dtype=p.dtype)
p.add_(g, alpha=-lr)
curr += p.numel()
return loss
# ─────────────────────────────────────────────────────────────
# BPB EVALUATION UTILITIES
# ─────────────────────────────────────────────────────────────
def build_sentencepiece_luts(sp, vocab_size, device):
sv = int(sp.vocab_size())
sz = max(sv, vocab_size)
bb = np.zeros(sz, dtype=np.int16)
hs = np.zeros(sz, dtype=bool)
ib = np.ones(sz, dtype=bool)
for tid in range(sv):
if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid):
continue
ib[tid] = False
if sp.is_byte(tid):
bb[tid] = 1
continue
piece = sp.id_to_piece(tid)
if piece.startswith("\u2581"):
hs[tid] = True
piece = piece[1:]
bb[tid] = len(piece.encode("utf-8"))
return (torch.tensor(bb, dtype=torch.int16, device=device),
torch.tensor(hs, dtype=torch.bool, device=device),
torch.tensor(ib, dtype=torch.bool, device=device))
def eval_val_sliding_window(args, model, rank, world_size, device,
val_tokens, base_bytes_lut, has_space_lut, is_boundary_lut,
use_ttt=False):
"""Sliding-window BPB: every token scored with sw_stride context."""
seq_len = args.sw_seq_len
stride = args.sw_stride
T = val_tokens.numel()
all_starts = list(range(0, T - seq_len - 1, stride))
my_starts = all_starts[rank::world_size]
loss_sum = torch.zeros((), device=device, dtype=torch.float64)
token_cnt = torch.zeros((), device=device, dtype=torch.float64)
byte_cnt = torch.zeros((), device=device, dtype=torch.float64)
# Get the raw model for TTT
raw_model = model
while hasattr(raw_model, 'module'):
raw_model = raw_model.module
if hasattr(raw_model, '_orig_mod'):
raw_model = raw_model._orig_mod
raw_model.eval()
# TTT modifies weights in-place, so we can't use inference_mode
ctx = torch.no_grad if (use_ttt and args.ttt_enabled) else torch.inference_mode
with ctx():
for start in my_starts:
end = start + seq_len
x = val_tokens[start:end].unsqueeze(0).to(device, dtype=torch.int64)
y = val_tokens[start + 1:end + 1].unsqueeze(0).to(device, dtype=torch.int64)
with torch.autocast("cuda", dtype=torch.bfloat16):
if use_ttt and args.ttt_enabled:
ptl = raw_model.per_token_loss_with_ttt(x, y, args)
else:
ptl = raw_model.per_token_loss(x, y)
lo = seq_len - stride
ptl_s = ptl[0, lo:]
y_s = y[0, lo:]
x_s = x[0, lo:]
loss_sum += ptl_s.to(torch.float64).sum()
token_cnt += ptl_s.numel()
tb = base_bytes_lut[y_s].to(torch.float64)
tb += (has_space_lut[y_s] & ~is_boundary_lut[x_s]).to(torch.float64)
byte_cnt += tb.sum()
if dist.is_available() and dist.is_initialized():
for t in (loss_sum, token_cnt, byte_cnt):
dist.all_reduce(t, op=dist.ReduceOp.SUM)
val_loss = float((loss_sum / token_cnt).item())
bpb = float((loss_sum / math.log(2) / byte_cnt).item())
raw_model.train()
return val_loss, bpb
# ─────────────────────────────────────────────────────────────
# GPTQ Int6 QUANTIZATION with SDClip
# ─────────────────────────────────────────────────────────────
CONTROL_PATTERNS = tuple(p for p in os.environ.get(
"CONTROL_TENSOR_NAME_PATTERNS",
"attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,log_alpha"
).split(",") if p)
KEEP_FP_MAX_NUMEL = 65_536
KEEP_FP_STORE_DTYPE = torch.float16
INT8_SCALE_DTYPE = torch.float16
def sdclip(t: Tensor, n_std: float = 2.5) -> Tensor:
"""Std-based clipping: clip to mean +/- n_std * std."""
mean = t.float().mean()
std = t.float().std()
lo = mean - n_std * std
hi = mean + n_std * std
return t.clamp(lo.item(), hi.item())
def _quant_tensor_int6(t: Tensor, n_std: float = 2.5):
"""Quantize tensor to int6 (range -31 to 31) with SDClip per row."""
t32 = t.float()
max_val = 31 # 6-bit signed: -31 to 31
if t32.ndim == 2:
# Per-row SDClip and quantization
mean = t32.mean(dim=1, keepdim=True)
std = t32.std(dim=1, keepdim=True).clamp_min(1e-9)
lo = mean - n_std * std
hi = mean + n_std * std
t_clipped = t32.clamp(lo.expand_as(t32), hi.expand_as(t32))
clip_val = t_clipped.abs().amax(dim=1).clamp_min(1e-9)
scale = clip_val / max_val
q = torch.clamp(torch.round(t_clipped / scale[:, None]), -max_val, max_val).to(torch.int8)
return q.contiguous(), scale.to(torch.float16).contiguous()
# 1D fallback
t_clipped = sdclip(t32, n_std)
cv = float(t_clipped.abs().max().item())
scale = torch.tensor(max(cv / max_val, 1.0 / max_val), dtype=torch.float32)
q = torch.clamp(torch.round(t_clipped / scale), -max_val, max_val).to(torch.int8)
return q.contiguous(), scale
def _quant_tensor_int8(t: Tensor, n_std: float = 2.5):
"""Quantize tensor to int8 with SDClip."""
t32 = t.float()
if t32.ndim == 2:
mean = t32.mean(dim=1, keepdim=True)
std = t32.std(dim=1, keepdim=True).clamp_min(1e-9)
lo = mean - n_std * std
hi = mean + n_std * std
t_clipped = t32.clamp(lo.expand_as(t32), hi.expand_as(t32))
clip_val = t_clipped.abs().amax(dim=1).clamp_min(1e-9)
scale = clip_val / 127.0
q = torch.clamp(torch.round(t_clipped / scale[:, None]), -127, 127).to(torch.int8)
return q.contiguous(), scale.to(torch.float16).contiguous()
cv = float(sdclip(t32, n_std).abs().max().item())
scale = torch.tensor(max(cv / 127.0, 1.0 / 127.0), dtype=torch.float32)
q = torch.clamp(torch.round(t32.clamp(-cv, cv) / scale), -127, 127).to(torch.int8)
return q.contiguous(), scale
def quantize_state_dict(state_dict: dict, gptq_bits: int = 6, sdclip_nstd: float = 2.5):
"""Mixed quantization: int6 for weight matrices, int8 for embeddings, fp16 for small/control."""
quantized, scales, dtypes, passthrough, pt_orig, qmeta = {}, {}, {}, {}, {}, {}
stats = {k: 0 for k in ("param_count", "num_tensors", "baseline_bytes", "quant_bytes")}
quant_fn = _quant_tensor_int6 if gptq_bits == 6 else _quant_tensor_int8
for name, tensor in state_dict.items():
t = tensor.detach().cpu().contiguous()
stats["param_count"] += t.numel()
stats["num_tensors"] += 1
stats["baseline_bytes"] += t.numel() * t.element_size()
if not t.is_floating_point():
passthrough[name] = t
stats["quant_bytes"] += t.numel() * t.element_size()
continue
is_ctrl = any(p in name for p in CONTROL_PATTERNS)
is_small = t.numel() <= KEEP_FP_MAX_NUMEL
# Embeddings: int8 (higher precision for tied I/O)
if "tok_emb" in name:
pt_orig[name] = str(t.dtype).removeprefix("torch.")
q, s = _quant_tensor_int8(t, sdclip_nstd)
quantized[name] = q
scales[name] = s
dtypes[name] = str(t.dtype).removeprefix("torch.")
if s.ndim > 0:
qmeta[name] = {"scheme": "per_row", "axis": 0, "bits": 8}
stats["quant_bytes"] += q.numel() + s.numel() * s.element_size()
continue
if is_ctrl or is_small:
if t.dtype in (torch.float32, torch.bfloat16):
pt_orig[name] = str(t.dtype).removeprefix("torch.")
passthrough[name] = t.float() if is_ctrl else t.to(KEEP_FP_STORE_DTYPE)
passthrough[name] = passthrough[name].contiguous()
stats["quant_bytes"] += passthrough[name].numel() * passthrough[name].element_size()
continue
# Large weight matrices: int6 with SDClip
q, s = quant_fn(t, sdclip_nstd)
if s.ndim > 0:
qmeta[name] = {"scheme": "per_row", "axis": 0, "bits": gptq_bits}
quantized[name] = q
scales[name] = s
dtypes[name] = str(t.dtype).removeprefix("torch.")
stats["quant_bytes"] += q.numel() + s.numel() * s.element_size()
obj = {"__quant_format__": f"int{gptq_bits}_sdclip_v1",
"quantized": quantized, "scales": scales, "dtypes": dtypes,
"passthrough": passthrough}
if qmeta: obj["qmeta"] = qmeta
if pt_orig: obj["passthrough_orig_dtypes"] = pt_orig
return obj, stats
def dequantize_state_dict(obj: dict) -> dict:
out = {}
qmeta = obj.get("qmeta", {})
pt_orig = obj.get("passthrough_orig_dtypes", {})
for name, q in obj["quantized"].items():
dtype = getattr(torch, obj["dtypes"][name])
s = obj["scales"][name]
if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
s = s.to(torch.float32)
out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous()
else:
out[name] = (q.float() * float(s.item())).to(dtype).contiguous()
for name, t in obj["passthrough"].items():
ot = t.detach().cpu().contiguous()
od = pt_orig.get(name)
if isinstance(od, str):
ot = ot.to(dtype=getattr(torch, od)).contiguous()
out[name] = ot
return out
# ─────────────────────────────────────────────────────────────
# DATA LOADING
# ─────────────────────────────────────────────────────────────
def load_data_shard(file: Path) -> Tensor:
hdr = np.fromfile(file, dtype="<i4", count=256)
if hdr.size != 256 or int(hdr[0]) != 20240520 or int(hdr[1]) != 1:
raise ValueError(f"Bad shard: {file}")
n = int(hdr[2])
tokens = np.fromfile(file, dtype="<u2", count=n, offset=256 * 4)
return torch.from_numpy(tokens.astype(np.uint16, copy=False))
def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
files = [Path(p) for p in sorted(glob.glob(pattern))]
if not files:
raise FileNotFoundError(f"No val files: {pattern}")
tokens = torch.cat([load_data_shard(f) for f in files]).contiguous()
usable = ((tokens.numel() - 1) // seq_len) * seq_len
return tokens[: usable + 1]
class TokenStream:
def __init__(self, pattern: str):
files = [Path(p) for p in sorted(glob.glob(pattern))]
if not files:
raise FileNotFoundError(f"No shards: {pattern}")
self.files = files
self.idx = 0
self.tokens = load_data_shard(files[0])
self.pos = 0
def take(self, n: int) -> Tensor:
chunks, rem = [], n
while rem > 0:
avail = self.tokens.numel() - self.pos
if avail <= 0:
self.idx = (self.idx + 1) % len(self.files)
self.tokens = load_data_shard(self.files[self.idx])
self.pos = 0
avail = self.tokens.numel()
k = min(rem, avail)
chunks.append(self.tokens[self.pos: self.pos + k])
self.pos += k
rem -= k
return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
class DistributedTokenLoader:
def __init__(self, pattern, rank, world_size, device):
self.rank = rank; self.ws = world_size; self.device = device
self.stream = TokenStream(pattern)
def next_batch(self, global_tokens, seq_len, grad_accum):
local_tokens = global_tokens // (self.ws * grad_accum)
per_rank_span = local_tokens + 1
chunk = self.stream.take(per_rank_span * self.ws)
start = self.rank * per_rank_span
local = chunk[start: start + per_rank_span].to(torch.int64)
x = local[:-1].reshape(-1, seq_len)
y = local[1:].reshape(-1, seq_len)
return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
# ─────────────────────────────────────────────────────────────
# TRANSFORMER COMPONENTS β€” Parallel Residual Architecture
# ─────────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
def __init__(self, eps: float | None = None):
super().__init__()
self.eps = eps
def forward(self, x: Tensor) -> Tensor:
return F.rms_norm(x, (x.size(-1),), eps=self.eps)
class Rotary(nn.Module):
def __init__(self, dim: int, base: float = 10000.0):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._cached_len = 0
self._cos: Tensor | None = None
self._sin: Tensor | None = None
def forward(self, seq_len: int, device, dtype):
if self._cos is None or self._cached_len != seq_len or self._cos.device != device:
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq.to(device))
self._cos = freqs.cos()[None, None, :, :]
self._sin = freqs.sin()[None, None, :, :]
self._cached_len = seq_len
return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype)
def apply_rotary(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
half = x.size(-1) // 2
x1, x2 = x[..., :half], x[..., half:]
return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
class CausalSelfAttention(nn.Module):
def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init):
super().__init__()
assert dim % num_heads == 0 and num_heads % num_kv_heads == 0
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = dim // num_heads
kv_dim = num_kv_heads * self.head_dim
self.c_q = nn.Linear(dim, dim, bias=False)
self.c_k = nn.Linear(dim, kv_dim, bias=False)
self.c_v = nn.Linear(dim, kv_dim, bias=False)
self.proj = nn.Linear(dim, dim, bias=False)
self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
self.rotary = Rotary(self.head_dim, base=rope_base)
def forward(self, x: Tensor) -> Tensor:
B, T, _ = x.shape
q = self.c_q(x).reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
k = self.c_k(x).reshape(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
v = self.c_v(x).reshape(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
q = F.rms_norm(q, (q.size(-1),))
k = F.rms_norm(k, (k.size(-1),))
cos, sin = self.rotary(T, x.device, q.dtype)
q = apply_rotary(q, cos, sin)
k = apply_rotary(k, cos, sin)
q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
y = F.scaled_dot_product_attention(q, k, v,
attn_mask=None, is_causal=True,
enable_gqa=(self.num_kv_heads != self.num_heads))
return self.proj(y.transpose(1, 2).contiguous().reshape(B, T, -1))
class MLP(nn.Module):
def __init__(self, dim, mlp_mult):
super().__init__()
hidden = dim * mlp_mult
self.fc = nn.Linear(dim, hidden, bias=False)
self.proj = nn.Linear(hidden, dim, bias=False)
def forward(self, x: Tensor) -> Tensor:
return self.proj(torch.relu(self.fc(x)).square())
class ParallelBlock(nn.Module):
"""Parallel Residual Block: attn and MLP run on the same normalized input.
x = resid_mix[0]*x + resid_mix[1]*x0
h = norm(x)
x = x + attn_scale * attn(h) + mlp_scale * mlp(h)
"""
def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init):
super().__init__()
self.norm = RMSNorm()
self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
self.mlp = MLP(dim, mlp_mult)
self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
self.resid_mix = nn.Parameter(torch.stack([torch.ones(dim), torch.zeros(dim)]).float())
def forward(self, x: Tensor, x0: Tensor) -> Tensor:
mix = self.resid_mix.to(x.dtype)
x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
h = self.norm(x)
# Parallel: both attn and MLP operate on same normalized input
x = x + self.attn_scale.to(x.dtype)[None, None, :] * self.attn(h) \
+ self.mlp_scale.to(x.dtype)[None, None, :] * self.mlp(h)
return x
# ─────────────────────────────────────────────────────────────
# RECURRENT GPT MODEL with Score-First TTT
# ─────────────────────────────────────────────────────────────
class RecurrentGPT(nn.Module):
"""
K unique parallel-residual blocks x N recurrences.
At eval: 2N recurrences + optional score-first TTT.
"""
def __init__(self, args: Hyperparameters):
super().__init__()
self.logit_softcap = args.logit_softcap
self._train_rec = args.num_recurrences
self._eval_rec = args.num_eval_recurrences or args.num_recurrences * 2
self._vocab_size = args.vocab_size
self.tok_emb = nn.Embedding(args.vocab_size, args.model_dim)
self.blocks = nn.ModuleList([
ParallelBlock(args.model_dim, args.num_heads, args.num_kv_heads,
args.mlp_mult, args.rope_base, args.qk_gain_init)
for _ in range(args.num_unique_layers)
])
self.final_norm = RMSNorm()
nn.init.normal_(self.tok_emb.weight, std=0.005)
def _forward_hidden(self, input_ids: Tensor) -> Tensor:
x = F.rms_norm(self.tok_emb(input_ids), (self.tok_emb.embedding_dim,))
x0 = x
n = self._train_rec if self.training else self._eval_rec
for _ in range(n):
for block in self.blocks:
x = block(x, x0)
return self.final_norm(x)
def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
h = self._forward_hidden(input_ids)
logits = F.linear(h.reshape(-1, h.size(-1)), self.tok_emb.weight)
logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap)
return F.cross_entropy(logits.float(), target_ids.reshape(-1), reduction="mean")
def per_token_loss(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
h = self._forward_hidden(input_ids)
B, T, D = h.shape
logits = F.linear(h.reshape(B * T, D), self.tok_emb.weight)
logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap)
return F.cross_entropy(logits.float(), target_ids.reshape(B * T),
reduction="none").reshape(B, T)
@torch.no_grad()
def per_token_loss_with_ttt(self, input_ids: Tensor, target_ids: Tensor,
args: Hyperparameters) -> Tensor:
"""Score-first TTT: adapt MLP.proj weights chunk-by-chunk at eval.
"Score-first" means: for each chunk, we first SCORE (compute loss) with
current weights, then UPDATE weights for the next chunk. This is strictly
causal -- predictions for chunk i only use information from chunks 0..i-1.
We update MLP.proj.weight (the "down projection") in each block --
this is the "fast weight" in the In-Place TTT framework (arxiv:2604.06169).
"""
chunk_size = args.ttt_chunk_size
ttt_lr = args.ttt_lr
B, T = input_ids.shape
# Determine which layers to apply TTT
if args.ttt_layers == "all":
ttt_layer_indices = list(range(len(self.blocks)))
else:
ttt_layer_indices = [int(x) for x in args.ttt_layers.split(",")]
# Save original weights to restore after this sequence
original_weights = {}
for li in ttt_layer_indices:
original_weights[li] = self.blocks[li].mlp.proj.weight.data.clone()
all_ptl = []
n_chunks = (T + chunk_size - 1) // chunk_size
for ci in range(n_chunks):
lo = ci * chunk_size
hi = min((ci + 1) * chunk_size, T)
# Score first: full forward pass with current (possibly updated) weights
h = self._forward_hidden(input_ids) # (B, T, D)
h_chunk = h[:, lo:hi, :]
y_chunk = target_ids[:, lo:hi]
logits = F.linear(h_chunk.reshape(-1, h_chunk.size(-1)), self.tok_emb.weight)
logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap)
ptl = F.cross_entropy(logits.float(), y_chunk.reshape(-1),
reduction="none").reshape(B, hi - lo)
all_ptl.append(ptl)
# Then update: manual gradient step on MLP.proj for next chunk
if ci < n_chunks - 1:
for li in ttt_layer_indices:
block = self.blocks[li]
# Get MLP intermediate activations for this chunk
h_norm = F.rms_norm(h_chunk.reshape(-1, h_chunk.size(-1)).float(),
(h_chunk.size(-1),))
z = torch.relu(block.mlp.fc(h_norm.to(h_chunk.dtype))).square()
# Reconstruction-based update: minimize ||Z @ W^T - h_norm||^2
pred = z @ block.mlp.proj.weight.T
residual = pred - h_norm.to(pred.dtype)
grad_w = residual.T @ z / z.size(0)
block.mlp.proj.weight.data -= ttt_lr * grad_w.to(block.mlp.proj.weight.dtype)
# Restore original weights after processing this sequence
for li in ttt_layer_indices:
self.blocks[li].mlp.proj.weight.data = original_weights[li]
return torch.cat(all_ptl, dim=1)
# ─────────────────────────────────────────────────────────────
# EMA (Exponential Moving Average)
# ─────────────────────────────────────────────────────────────
class EMA:
"""Exponential Moving Average of model parameters."""
def __init__(self, model: nn.Module, decay: float = 0.999):
self.model = model
self.decay = decay
self.shadow = {n: p.data.clone() for n, p in model.named_parameters()}
self.backup = {}
def update(self):
for n, p in self.model.named_parameters():
self.shadow[n].mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)
def apply(self):
"""Apply EMA weights (backup current)."""
self.backup = {}
for n, p in self.model.named_parameters():
self.backup[n] = p.data.clone()
p.data.copy_(self.shadow[n])
def restore(self):
"""Restore original weights."""
for n, p in self.model.named_parameters():
p.data.copy_(self.backup[n])
self.backup = {}
# ─────────────────────────────────────────────────────────────
# TRAINING
# ─────────────────────────────────────────────────────────────
def main():
global zeropower_via_newtonschulz5
code = Path(__file__).read_text(encoding="utf-8")
args = Hyperparameters()
zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)
# -- distributed setup --
distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
grad_accum = max(1, 8 // world_size)
grad_scale = 1.0 / grad_accum
if not torch.cuda.is_available():
raise RuntimeError("CUDA required")
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)
if distributed:
dist.init_process_group("nccl", device_id=device)
dist.barrier()
master = rank == 0
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from torch.backends.cuda import (enable_flash_sdp, enable_math_sdp,
enable_mem_efficient_sdp, enable_cudnn_sdp)
enable_flash_sdp(True); enable_math_sdp(False)
enable_mem_efficient_sdp(False); enable_cudnn_sdp(False)
logfile = None
if master:
os.makedirs("logs", exist_ok=True)
logfile = f"logs/{args.run_id}.txt"
print(logfile)
def log0(msg, console=True):
if not master: return
if console: print(msg)
if logfile:
with open(logfile, "a") as f: print(msg, file=f)
log0(code, console=False)
log0(f"Python {sys.version}", console=False)
log0(f"PyTorch {torch.__version__}", console=False)
try:
log0(subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False).stdout,
console=False)
except FileNotFoundError:
pass
random.seed(args.seed); np.random.seed(args.seed)
torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed)
# -- tokenizer + val data --
sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
base_bytes_lut, has_space_lut, is_boundary_lut = build_sentencepiece_luts(
sp, args.vocab_size, device)
val_tokens = load_validation_tokens(args.val_files, args.sw_seq_len)
log0(f"val_tokens:{val_tokens.numel()}")
# -- model --
base_model = RecurrentGPT(args).to(device).bfloat16()
compiled = torch.compile(base_model, dynamic=False, fullgraph=True)
model = DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) \
if distributed else compiled
n_unique = sum(p.numel() for p in base_model.parameters())
eff_depth = args.num_unique_layers * args.num_recurrences
log0(f"unique_params:{n_unique} effective_depth:{eff_depth} "
f"train_loops:{args.num_recurrences} eval_loops:{base_model._eval_rec}")
log0(f"world_size:{world_size} grad_accum:{grad_accum}")
# -- optimizer --
block_params = list(base_model.blocks.named_parameters())
matrix_params = [p for n, p in block_params
if p.ndim == 2 and not any(pat in n for pat in CONTROL_PATTERNS)]
scalar_params = [p for n, p in block_params
if p.ndim < 2 or any(pat in n for pat in CONTROL_PATTERNS)]
opt_tok = torch.optim.Adam(
[{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}],
betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True)
opt_muon = Muon(matrix_params, lr=args.matrix_lr,
momentum=args.muon_momentum, backend_steps=args.muon_backend_steps,
weight_decay=args.muon_weight_decay)
for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr
opt_scalar = torch.optim.Adam(
[{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True)
optimizers = [opt_tok, opt_muon, opt_scalar]
# -- EMA --
ema = EMA(base_model, decay=0.999)
ema_start_step = int(args.iterations * args.swa_start_frac)
# -- LR schedule --
max_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
def lr_mul(step, elapsed_ms):
if args.warmdown_iters <= 0: return 1.0
if max_ms is None:
ws = max(args.iterations - args.warmdown_iters, 0)
return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \
if ws <= step < args.iterations else 1.0
step_ms = elapsed_ms / max(step, 1)
remain = max(max_ms - elapsed_ms, 0.0)
wd_ms = args.warmdown_iters * step_ms
return remain / max(wd_ms, 1e-9) if remain <= wd_ms else 1.0
def zero_all(): [o.zero_grad(set_to_none=True) for o in optimizers]
# -- warmup --
if args.warmup_steps > 0:
init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()}
init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers]
model.train()
train_loader_w = DistributedTokenLoader(args.train_files, rank, world_size, device)
for ws_i in range(args.warmup_steps):
zero_all()
for ms_i in range(grad_accum):
if distributed:
model.require_backward_grad_sync = (ms_i == grad_accum - 1)
x, y = train_loader_w.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum)
with torch.autocast("cuda", torch.bfloat16):
(model(x, y) * grad_scale).backward()
for o in optimizers: o.step()
zero_all()
base_model.load_state_dict(init_model, strict=True)
for o, s in zip(optimizers, init_opts): o.load_state_dict(s)
zero_all()
if distributed: model.require_backward_grad_sync = True
# -- data + training loop --
train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
training_ms = 0.0
stop_step: int | None = None
torch.cuda.synchronize()
t0 = time.perf_counter()
step = 0
while True:
last_step = step == args.iterations or (stop_step is not None and step >= stop_step)
do_val = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
if do_val:
torch.cuda.synchronize()
training_ms += 1000.0 * (time.perf_counter() - t0)
vl, vbpb = eval_val_sliding_window(
args, model, rank, world_size, device,
val_tokens, base_bytes_lut, has_space_lut, is_boundary_lut,
use_ttt=False)
log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} "
f"train_ms:{training_ms:.0f} step_avg:{training_ms/max(step,1):.2f}ms")
torch.cuda.synchronize()
t0 = time.perf_counter()
if last_step:
if master:
# Apply EMA weights for final model
ema.apply()
# Evaluate with TTT
log0("Evaluating with EMA + TTT...")
vl_ema, vbpb_ema = eval_val_sliding_window(
args, base_model, rank, world_size, device,
val_tokens, base_bytes_lut, has_space_lut, is_boundary_lut,
use_ttt=True)
log0(f"ema_ttt val_loss:{vl_ema:.4f} val_bpb:{vbpb_ema:.4f}")
# Quantize and export
sd = base_model.state_dict()
obj, stats = quantize_state_dict(sd, args.gptq_bits, args.sdclip_nstd)
buf = io.BytesIO()
torch.save(obj, buf)
compressed = zlib.compress(buf.getvalue(), level=9)
code_bytes = len(code.encode())
model_bytes = len(compressed)
total_bytes = code_bytes + model_bytes
log0(f"final_quant_zlib_roundtrip "
f"code_bytes:{code_bytes} "
f"model_compressed_bytes:{model_bytes} "
f"total_artifact_bytes:{total_bytes} "
f"total_artifact_mb:{total_bytes/1e6:.3f} "
f"param_count:{stats['param_count']}")
# Round-trip verify
sd2 = dequantize_state_dict(obj)
base_model.load_state_dict(sd2, strict=True)
vl2, vbpb2 = eval_val_sliding_window(
args, base_model, rank, world_size, device,
val_tokens, base_bytes_lut, has_space_lut, is_boundary_lut,
use_ttt=True)
log0(f"quantized_model+ttt val_loss:{vl2:.4f} val_bpb:{vbpb2:.4f}")
# Restore non-EMA weights
ema.restore()
break
if stop_step is None and max_ms is not None:
torch.cuda.synchronize()
elapsed = 1000.0 * (time.perf_counter() - t0) + training_ms
if elapsed >= max_ms:
stop_step = step + 1
zero_all()
for ms_i in range(grad_accum):
if distributed:
model.require_backward_grad_sync = (ms_i == grad_accum - 1)
x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum)
with torch.autocast("cuda", torch.bfloat16):
(model(x, y) * grad_scale).backward()
torch.cuda.synchronize()
elapsed_ms = 1000.0 * (time.perf_counter() - t0) + training_ms
m = lr_mul(step, elapsed_ms)
for o in optimizers:
for g in o.param_groups: g["lr"] = g["base_lr"] * m
for o in optimizers: o.step()
# EMA update
if step >= ema_start_step:
ema.update()
if step % args.train_log_every == 0 and master:
log0(f"step:{step} lr_mul:{m:.4f}")
step += 1
if __name__ == "__main__":
main()