10 Ideas That Turned a 45-Minute Training Run Into 90 Seconds

Community Article Published March 20, 2026

The modded-nanogpt competition took Andrej Karpathy's clean nanoGPT baseline and asked a simple question: how fast can we train GPT-2 (124M) to 3.28 validation loss on 8xH100s? The answer went from 45 minutes to under 2 minutes across 77+ community-submitted records. Here are the ten most interesting ideas that emerged, with code showing how each one would look applied to nanoGPT.



1. ReLU² Activation

The competition replaced GELU with squared ReLU (ReLU²) in the MLP blocks. The idea is simple: apply ReLU, then square the result. This produces sparser activations than GELU. Every negative pre-activation (roughly half of them for zero-mean inputs) becomes exactly zero, whereas GELU only shrinks negative inputs toward zero. Squaring then suppresses small positive values while amplifying large ones. The result is cheaper to compute (no erf evaluation) and empirically trains just as well or better at this model size, possibly because the induced sparsity acts as a natural regularizer.

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc   = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
-       x = self.gelu(x)
+       x = F.relu(x) ** 2  # ReLU² — sparse, fast, and effective
        x = self.c_proj(x)
        x = self.dropout(x)
        return x
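A quick way to check the sparsity claim is to push Gaussian pre-activations through both nonlinearities. This is a standalone sketch, not competition code; the `relu_squared` helper name is illustrative, and random normal samples stand in for real `c_fc` outputs.

```python
import torch
import torch.nn.functional as F

def relu_squared(x: torch.Tensor) -> torch.Tensor:
    """ReLU² as in the MLP above: 0 for x <= 0, x**2 otherwise."""
    return F.relu(x).square()

torch.manual_seed(0)
pre = torch.randn(1_000_000)  # stand-in for c_fc outputs: zero-mean, roughly Gaussian

relu_sq = relu_squared(pre)
gelu = F.gelu(pre)

# Fraction of activations that are exactly zero
print("ReLU² zeros:", (relu_sq == 0).float().mean().item())  # about half
print("GELU zeros: ", (gelu == 0).float().mean().item())     # essentially none

# Squaring shrinks small positive values and amplifies large ones (0, 0.25, 9)
print(relu_squared(torch.tensor([-1.0, 0.5, 3.0])))
```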

2. Rotary Position Embeddings (RoPE)

NanoGPT uses learned absolute positional embeddings: a separate embedding table (wpe) indexed by position. RoPE instead encodes position by rotating query and key vectors in pairs of dimensions, with rotation angles that vary by frequency. Because the dot product between a query rotated for position m and a key rotated for position n depends only on the offset m − n, attention becomes naturally sensitive to relative rather than absolute position. RoPE also tends to extend to longer sequences more gracefully than a fixed-size learned table, and it removes the wpe parameter matrix from the model entirely. The modded variant further extends this with YaRN frequency interpolation to reach 64K context.

class RotaryEmbedding(nn.Module):

    def __init__(self, dim, max_seq_len=1024, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
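        t = torch.arange(max_seq_len).float()
        freqs = torch.outer(t, inv_freq)           # (T, dim/2)
        # Duplicate so cos/sin cover the full head_dim, matching the half-split
        # rotation in apply_rope (x1 = first half, x2 = second half)
        emb = torch.cat((freqs, freqs), dim=-1)    # (T, dim)
        self.register_buffer("cos", emb.cos(), persistent=False)
        self.register_buffer("sin", emb.sin(), persistent=False)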

    def forward(self, x, seq_len):
        return self.cos[:seq_len], self.sin[:seq_len]


def apply_rope(x, cos, sin):
    """Apply rotary embeddings to query or key tensor."""
    # x: (B, n_head, T, head_dim)
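    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    rotated = torch.cat((-x2, x1), dim=-1)
    # Cast the float32 cos/sin tables so bf16 activations stay bf16
    cos, sin = cos.to(x.dtype), sin.to(x.dtype)
    return x * cos + rotated * sin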


class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        head_dim = config.n_embd // config.n_head
+       self.rope = RotaryEmbedding(head_dim, max_seq_len=config.block_size)

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

+       cos, sin = self.rope(q, T)
+       q = apply_rope(q, cos, sin)
+       k = apply_rope(k, cos, sin)

        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
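The relative-position property can be checked numerically. The standalone sketch below rebuilds the half-split rotation in float64, with cos/sin tables covering the full head dimension, and confirms that the score between a rotated query and key depends only on their offset and that rotation preserves vector norms. The helper names (`rope_tables`, `score`) are illustrative.

```python
import torch

def rope_tables(T: int, dim: int, base: float = 10000.0, dtype=torch.float64):
    """cos/sin tables of shape (T, dim) for half-split RoPE."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=dtype) / dim))
    freqs = torch.outer(torch.arange(T, dtype=dtype), inv_freq)  # (T, dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)                      # (T, dim)
    return emb.cos(), emb.sin()

def apply_rope(x, cos, sin):
    """Rotate dimension pairs (i, i + dim/2) by position-dependent angles."""
    d = x.shape[-1] // 2
    rotated = torch.cat((-x[..., d:], x[..., :d]), dim=-1)
    return x * cos + rotated * sin

def score(q, k, m, n, cos, sin):
    """Unscaled attention logit between a query at position m and a key at position n."""
    return torch.dot(apply_rope(q, cos[m], sin[m]), apply_rope(k, cos[n], sin[n]))

torch.manual_seed(0)
dim, T = 64, 32
cos, sin = rope_tables(T, dim)
q = torch.randn(dim, dtype=torch.float64)
k = torch.randn(dim, dtype=torch.float64)

# Same offset (m - n = 3) at different absolute positions gives the same score
print(score(q, k, 5, 2, cos, sin).item(), score(q, k, 20, 17, cos, sin).item())
```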

3. QK-Norm

A deceptively simple change: normalize the query and key vectors before computing attention scores. Without normalization, the dot products between Q and K can grow in magnitude during training, leading to extremely peaked softmax distributions that starve gradients. QK-Norm keeps the attention logits in a stable range throughout training, which allows higher learning rates and faster convergence. It's implemented as an RMSNorm (cheaper than LayerNorm since it skips the mean-centering step).

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (x * rms).to(x.dtype) * self.weight


class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        head_dim = config.n_embd // config.n_head
+       self.q_norm = RMSNorm(head_dim)
+       self.k_norm = RMSNorm(head_dim)

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

+       q = self.q_norm(q)
+       k = self.k_norm(k)

        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
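Why the logits stay in range follows from Cauchy-Schwarz: after RMSNorm (with the learnable gain at its initial value of 1, an assumption for this sketch) each head vector has L2 norm at most √head_dim, so a scaled logit q·k/√head_dim can never exceed √head_dim in magnitude, no matter how large the raw activations grow. A learned gain rescales that bound but does not remove it. The standalone check below uses deliberately huge inputs; `rms_normalize` is an illustrative helper.

```python
import math
import torch

def rms_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """RMSNorm with unit gain: divide each vector by its root-mean-square."""
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

torch.manual_seed(0)
head_dim = 64
q = torch.randn(4, head_dim) * 100  # deliberately huge activations
k = torch.randn(4, head_dim) * 100

raw_logits = q @ k.T / math.sqrt(head_dim)
normed_logits = rms_normalize(q) @ rms_normalize(k).T / math.sqrt(head_dim)

# Raw logits are in the thousands; normalized ones are bounded by sqrt(64) = 8
print(raw_logits.abs().max().item(), normed_logits.abs().max().item())
```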

4. The Muon Optimizer

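The single biggest algorithmic breakthrough of the competition. Muon replaces AdamW for the hidden 2D weight matrices (embeddings, the output head, and scalar parameters stay on an Adam-style optimizer) with a fundamentally different update rule: compute the momentum-smoothed gradient, then orthogonalize it before applying the update. Orthogonalizing G = UΣVᵀ means replacing it with its polar factor UVᵀ, the nearest matrix with orthonormal rows or columns (a point on the Stiefel manifold), which sets every singular value of the update to 1. Directions the raw gradient barely moves along then get as large a step as the dominant ones, which is the sense in which the updates are better conditioned; UVᵀ is also the steepest-descent direction under the spectral norm. An SVD every step would be too slow, so the polar factor (a matrix sign function) is approximated with 5 iterations of a Newton-Schulz-style odd-polynomial recurrence. The Polar Express method supplies a different precomputed coefficient triple for each iteration, optimized for that step.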

class Muon(torch.optim.Optimizer):
    """Muon optimizer for 2D weight matrices.
    Uses momentum + orthogonalization via Newton-Schulz iterations."""

    # Precomputed Polar Express coefficients (a, b, c) for 5 iterations
    POLAR_COEFFS = [
        (8.1566, -22.4833, 15.8788),
        (4.0429,  -2.8089,  0.5000),
        (3.8917,  -2.7725,  0.5061),
        (3.2858,  -2.3681,  0.4645),
        (2.3465,  -1.7098,  0.4232),
    ]

    def __init__(self, params, lr=0.05, momentum=0.95):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)

    @torch.no_grad()
    def _orthogonalize(self, grad):
        """Project gradient onto orthonormal manifold via Polar Express."""
        # Work on the "wide" orientation (rows <= cols) so the Gram matrix
        # A = X @ X^T is the smaller of the two possible products
        transposed = grad.shape[0] > grad.shape[1]
        if transposed:
            grad = grad.T
        # Normalize by the Frobenius norm, which upper-bounds the spectral
        # norm, so every singular value is at most 1
        grad = grad / (grad.norm() + 1e-7)
        # 5 iterations of: X = (a*I + b*A + c*A²) @ X where A = X @ X^T
        for a, b, c in self.POLAR_COEFFS:
            A = grad @ grad.T
            grad = a * grad + (b * A + c * A @ A) @ grad
        if transposed:
            grad = grad.T
        return grad

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group["lr"]
            beta = group["momentum"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(p.grad)
                buf = state["momentum_buffer"]
                buf.mul_(beta).add_(p.grad)
                update = self._orthogonalize(buf)
                p.add_(update, alpha=-lr)
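To see what the five polynomial steps achieve, the standalone sketch below applies the same recurrence and coefficients as the Muon class to a random 128×64 matrix and compares the result with the exact polar factor UVᵀ from `torch.linalg.svd`. The singular values land close to 1 rather than exactly on it, since a fixed five-step budget trades exactness for speed; Muon only needs an approximately orthogonal update. The `polar_express` function name is illustrative.

```python
import torch

# Same per-iteration (a, b, c) coefficients as the Muon class
POLAR_COEFFS = [
    (8.1566, -22.4833, 15.8788),
    (4.0429,  -2.8089,  0.5000),
    (3.8917,  -2.7725,  0.5061),
    (3.2858,  -2.3681,  0.4645),
    (2.3465,  -1.7098,  0.4232),
]

def polar_express(G: torch.Tensor) -> torch.Tensor:
    """Approximate the polar factor U @ Vh of G with 5 polynomial iterations."""
    X = G.float()
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T  # wide orientation: X @ X.T is the smaller Gram matrix
    X = X / (X.norm() + 1e-7)  # Frobenius norm >= spectral norm, so sigma <= 1
    for a, b, c in POLAR_COEFFS:
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

torch.manual_seed(0)
G = torch.randn(128, 64)        # a "tall" gradient-shaped matrix
O = polar_express(G)

U, S, Vh = torch.linalg.svd(G, full_matrices=False)
exact = U @ Vh                  # exact polar factor

sv = torch.linalg.svdvals(O)
print("singular values of O:", sv.min().item(), "to", sv.max().item())
print("cosine to exact:", torch.nn.functional.cosine_similarity(O.flatten(), exact.flatten(), dim=0).item())
```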

5. Multi-Token Prediction

Instead of only predicting the next token at each position, multi-token prediction (MTP) also trains the model to predict the tokens 2 and 3 positions ahead. In the sketch below, each lookahead distance gets its own output head (a full n_embd × vocab_size projection), and the losses are combined with declining weights, since the prediction 3 tokens ahead matters less than the next-token prediction. This provides a richer training signal per forward pass: every position receives gradients from several supervision targets, improving token efficiency. In the competition, the MTP weights are scheduled to shift from long-range to short-range as training progresses.

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        # ... (standard nanoGPT init) ...
+       # Multi-token prediction heads for 2 and 3 tokens ahead
+       self.mtp_heads = nn.ModuleList([
+           nn.Linear(config.n_embd, config.vocab_size, bias=False)
+           for _ in range(2)  # heads for +2 and +3 token prediction
+       ])

    def forward(self, idx, targets=None):
        b, t = idx.size()
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(torch.arange(t, device=idx.device))
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # Standard next-token prediction (weight = 1.0)
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )

+           # Multi-token prediction: predict +2 and +3 tokens ahead.
+           # nanoGPT's targets are already shifted by one (targets[:, j] is
+           # idx[:, j + 1]), so the token k steps ahead of position j is
+           # targets[:, j + k - 1]: offset the targets by k - 1, not k.
+           mtp_weights = [0.5, 0.25]  # declining weights for further predictions
+           for i, (head, w) in enumerate(zip(self.mtp_heads, mtp_weights)):
+               offset = i + 1  # k - 1 for k = 2 and k = 3
+               if t > offset:
+                   mtp_logits = head(x[:, :-offset, :])
+                   mtp_targets = targets[:, offset:]  # tokens k steps ahead
+                   loss = loss + w * F.cross_entropy(
+                       mtp_logits.reshape(-1, mtp_logits.size(-1)),
+                       mtp_targets.reshape(-1),
+                   )
        else:
            logits = self.lm_head(x[:, [-1], :])
            loss = None

        return logits, loss
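Target indexing for lookahead heads is easy to get off by one, so it is worth checking on a toy sequence whose token ids equal their positions. The `lookahead_targets` helper below is illustrative, not part of nanoGPT; it assumes nanoGPT's convention that `targets` is the input shifted left by one.

```python
import torch

def lookahead_targets(targets: torch.Tensor, k: int) -> torch.Tensor:
    """Targets for predicting k tokens ahead, given nanoGPT-style targets.

    Since targets[:, j] == idx[:, j + 1], the token k steps ahead of
    position j is targets[:, j + k - 1]; only the first T - (k - 1)
    positions have such a target.
    """
    return targets[:, k - 1:]

# Toy sequence whose token ids equal their positions
T = 8
stream = torch.arange(T + 1).unsqueeze(0)      # one extra token for the final target
idx, targets = stream[:, :-1], stream[:, 1:]   # idx = 0..7, targets = 1..8

for k in (1, 2, 3):
    tgt = lookahead_targets(targets, k)
    positions = idx[:, : tgt.size(1)]          # positions that have a k-ahead target
    print(k, (tgt - positions).unique().tolist())  # the gap is always exactly k
```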

6. Sliding Window Attention

Full causal attention lets every token attend to all previous tokens, which is expensive and arguably wasteful for early layers that primarily need local context. Sliding window attention restricts each token to the W most recent tokens (itself included), cutting the cost of attention from O(T²) toward O(T×W). That saving only materializes with a kernel that skips fully masked blocks, such as PyTorch's FlexAttention; the dense boolean mask in the sketch below reproduces the attention pattern but still computes the full T×T score matrix. The competition takes this further with progressive windows: early layers get small windows (384 tokens), deeper layers get larger ones (up to 1,664 tokens). This mirrors the intuition that shallow layers handle local syntax while deep layers need global context.

class CausalSelfAttention(nn.Module):

    def __init__(self, config, layer_idx=0):
        super().__init__()
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
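+       # Progressive window sizes: early layers see locally, deep layers see globally.
+       # Layers past the end of the list reuse its last entry (GPT-2 124M has 12);
+       # nanoGPT's Block(config) must be changed to pass layer_idx through.
+       window_schedule = [384, 384, 384, 896, 896, 896, 1408, 1408, 1664, 1664]
+       self.window_size = window_schedule[min(layer_idx, len(window_schedule) - 1)]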

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

+       # Build a sliding window causal mask
+       if T > self.window_size:
+           # Causal mask: attend only to the last `window_size` tokens
+           mask = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
+           window_mask = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(
+               diagonal=-(self.window_size - 1)
+           )
+           mask = mask & window_mask
+           y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+       else:
+           y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
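The same mask can be built without materializing two full T×T matrices of ones by comparing broadcast position indices. The sketch below (the helper name is illustrative) does that and checks two properties: with a window covering the whole sequence it reduces to ordinary causal attention, and with a narrow window the last query's output depends only on the last W keys and values.

```python
import torch
import torch.nn.functional as F

def sliding_window_causal_mask(T: int, W: int, device=None) -> torch.Tensor:
    """Boolean (T, T) mask: True where query i may attend to key j."""
    i = torch.arange(T, device=device).unsqueeze(1)  # query positions, (T, 1)
    j = torch.arange(T, device=device).unsqueeze(0)  # key positions, (1, T)
    # Key j is visible if it is not in the future and at most W - 1 steps back
    return (j <= i) & (j > i - W)

torch.manual_seed(0)
B, H, T, D, W = 1, 2, 16, 8, 4
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

mask = sliding_window_causal_mask(T, W)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Each query i sees min(i + 1, W) keys
print(mask.sum(dim=1))
```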

7. U-Net Skip Connections

In a standard transformer, every block reads from and adds to a single residual stream, so an early representation reaches later layers only as part of a running sum that each intermediate block keeps writing into. U-Net skip connections add cross-layer paths, for example feeding the output of block 3 directly into the input of block 6. This gives the later layers direct access to earlier representations without having to pick them out of that accumulated sum. The idea is borrowed from U-Net in computer vision, where encoder-decoder skip connections are essential for preserving spatial detail. In the competition, this was one of the later records that provided a clean accuracy win at near-zero computational cost.

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte  = nn.Embedding(config.vocab_size, config.n_embd),
            wpe  = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h    = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+       # Learnable gate for the skip connection, initialized to zero so training
+       # starts from the plain residual network
+       self.skip_gate = nn.Parameter(torch.zeros(1))
+       # Which blocks to connect: the output of block 3 feeds the input of block 6
+       self.skip_from = 3
+       self.skip_to = 6

    def forward(self, idx, targets=None):
        b, t = idx.size()
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(torch.arange(t, device=idx.device))
        x = self.transformer.drop(tok_emb + pos_emb)

+       saved = None
        for i, block in enumerate(self.transformer.h):
+           if i == self.skip_to and saved is not None:
+               x = x + self.skip_gate * saved  # inject into the input of block 6
            x = block(x)
+           if i == self.skip_from:
+               saved = x  # save the output of block 3

        x = self.transformer.ln_f(x)
        # ... (rest of forward as usual)
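A single skip generalizes naturally to the encoder-decoder pairing the U-Net analogy suggests: the first half of the blocks push their outputs onto a stack, and the second half pop them in reverse order, each through its own zero-initialized gate. This is a hypothetical generalization written for illustration; the `UNetStack` name and the mirror-pairing rule are assumptions, not a transcription of a specific record. The toy blocks in the usage example just add 1, so the arithmetic can be followed by hand.

```python
import torch
import torch.nn as nn

class UNetStack(nn.Module):
    """Hypothetical mirrored skips: the output of block j feeds the input of block n-1-j."""

    def __init__(self, blocks):
        super().__init__()
        assert len(blocks) % 2 == 0, "use an even number of blocks"
        self.blocks = nn.ModuleList(blocks)
        self.n_enc = len(blocks) // 2
        # One learnable gate per skip, zero-initialized so training starts
        # from the plain residual network
        self.skip_gates = nn.Parameter(torch.zeros(self.n_enc))

    def forward(self, x):
        skips = []
        for i, block in enumerate(self.blocks):
            if i >= self.n_enc:
                # Decoder half: pop the most recent encoder output (mirror pairing)
                x = x + self.skip_gates[i - self.n_enc] * skips.pop()
            x = block(x)
            if i < self.n_enc:
                skips.append(x)  # encoder half: remember this block's output
        return x


class AddOne(nn.Module):
    """Toy block so the result can be checked by hand."""
    def forward(self, x):
        return x + 1


stack = UNetStack([AddOne() for _ in range(4)])
x = torch.zeros(1)
print(stack(x))  # gates at 0: plain stack, 0 + 4 = 4
with torch.no_grad():
    stack.skip_gates.fill_(1.0)
# Block 1's output (2) joins block 2's input, block 0's output (1) joins block 3's: 7
print(stack(x))
```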

8. Value Embeddings

Standard transformers embed input tokens once and then rely entirely on the transformer layers to refine representations. Value embeddings add separate embedding tables whose outputs are gated directly into the attention mechanism's value stream. This gives the model a shortcut path: raw token-identity information reaches each layer's values without having to survive the transformations of all the earlier layers. Attention still decides which positions' values get mixed, so the shortcut bypasses depth rather than the query-key routing. Multiple value embedding tables (3–5 in the competition) are used, each gated by a learned scalar, so the model can control how much direct embedding information leaks through at each layer.

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
+       # Separate value embedding table, gated into the value stream
+       self.val_embed = nn.Embedding(config.vocab_size, config.n_embd)
+       self.val_gate = nn.Parameter(torch.zeros(1))  # starts at 0, learned

    def forward(self, x, input_ids=None):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

+       # Mix in value embeddings looked up from the original token ids
+       if input_ids is not None:
+           ve = self.val_embed(input_ids)  # (B, T, C)
+           ve = ve.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+           v = v + self.val_gate * ve

        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
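Because the embedding is added to v before attention, the attention weights still decide whose value embeddings each position receives. The standalone toy check below makes that visible: it zeroes the projected values so only the gated value embeddings remain, then confirms that under causal attention position 0, which can only attend to itself, receives exactly its own gated embedding, while the last position gets a mix. The shapes, the fixed gate of 0.7, and the zeroed v are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D, vocab = 1, 2, 6, 4, 10
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
v = torch.zeros(B, H, T, D)               # zero the projected values to isolate the embedding path
input_ids = torch.arange(T).unsqueeze(0)  # distinct token ids 0..5, shape (B, T)

val_embed = nn.Embedding(vocab, H * D)    # plays the role of self.val_embed (n_embd = H * D)
gate = 0.7                                # stands in for the learned self.val_gate
with torch.no_grad():
    ve = val_embed(input_ids).view(B, T, H, D).transpose(1, 2)  # (B, H, T, D)
    y = F.scaled_dot_product_attention(q, k, v + gate * ve, is_causal=True)

# Position 0 can only attend to itself, so it gets exactly its own gated embedding
print(torch.allclose(y[:, :, 0], gate * ve[:, :, 0], atol=1e-6))
# The last position receives an attention-weighted mix of all six tokens' embeddings
print(torch.allclose(y[:, :, -1], gate * ve[:, :, -1], atol=1e-6))
```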

9. Logit Soft-Capping

During training, logit values can spike to extreme magnitudes — a single logit at +50 dominates the softmax so completely that the cross-entropy loss produces near-zero gradients for all other tokens. Logit soft-capping applies a smooth ceiling using tanh to compress logits into a bounded range before computing the loss. With a cap of 30 (i.e., 30 * tanh(logits / 30)), logits near zero pass through nearly unchanged, but extreme values get smoothly squashed. This stabilizes training and allows more aggressive learning rates without divergence.

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        # ... (standard nanoGPT init) ...
+       self.logit_cap = 30.0

    def forward(self, idx, targets=None):
        # ... (standard transformer forward) ...
        x = self.transformer.ln_f(x)

        if targets is not None:
            logits = self.lm_head(x)
+           # Soft-cap logits to prevent extreme values
+           logits = self.logit_cap * torch.tanh(logits / self.logit_cap)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )
        else:
            logits = self.lm_head(x[:, [-1], :])
+           logits = self.logit_cap * torch.tanh(logits / self.logit_cap)
            loss = None

        return logits, loss
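A few numbers make the shape of the cap concrete. The standalone sketch below (the `soft_cap` helper name is illustrative) applies the formula with a cap of 30 and prints both the capped values and their local slope 1 − tanh²(x/30), the factor by which gradients flowing back through the cap are scaled.

```python
import torch

def soft_cap(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    """Smoothly squash logits into (-cap, cap): cap * tanh(logits / cap)."""
    return cap * torch.tanh(logits / cap)

x = torch.tensor([0.5, 5.0, 30.0, 50.0, 100.0], requires_grad=True)
y = soft_cap(x)
y.sum().backward()

print(y)       # small logits nearly unchanged, large ones squashed below 30
print(x.grad)  # local slope: close to 1 near zero, close to 0 for huge logits
```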

10. Cautious Weight Decay

Standard weight decay nudges all parameters toward zero every step, regardless of what the gradient is doing. Cautious weight decay only applies the decay when the parameter's sign matches the gradient's sign — meaning both the gradient and the decay are pushing in the same direction. When they disagree (the gradient wants to increase a weight but decay wants to shrink it), the decay is skipped entirely. This prevents the regularizer from fighting the optimizer on parameters that are actively being pushed away from zero, leading to cleaner optimization dynamics and better final loss.

class CautiousAdamW(torch.optim.Optimizer):
    """AdamW with sign-gated weight decay."""

    def __init__(self, params, lr=6e-4, betas=(0.9, 0.95), eps=1e-8,
                 weight_decay=0.1):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group["lr"]
            wd = group["weight_decay"]
            b1, b2 = group["betas"]
            eps = group["eps"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                if len(state) == 0:
                    state["step"] = 0
                    state["m"] = torch.zeros_like(p)
                    state["v"] = torch.zeros_like(p)

                state["step"] += 1
                m, v = state["m"], state["v"]

                # Standard Adam moment updates
                m.mul_(b1).add_(grad, alpha=1 - b1)
                v.mul_(b2).addcmul_(grad, grad, value=1 - b2)

                # Bias correction
                bc1 = 1 - b1 ** state["step"]
                bc2 = 1 - b2 ** state["step"]
                m_hat = m / bc1
                v_hat = v / bc2

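+               # Cautious weight decay, decoupled and applied before the Adam
+               # step as in AdamW: only decay where sign(p) == sign(grad), i.e.
+               # where the gradient step also pushes the weight toward zero
+               if wd > 0:
+                   mask = (p.sign() == grad.sign()).to(p.dtype)
+                   p.mul_(1 - lr * wd * mask)

                # Adam update
                update = m_hat / (v_hat.sqrt() + eps)
                p.add_(update, alpha=-lr)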
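The sign rule is easiest to see on four hand-picked weights. The standalone sketch below (the in-place `cautious_decay_` helper is illustrative) applies only the decay part, with lr × wd = 0.1, so exactly the entries whose gradient step already points toward zero shrink by 10%.

```python
import torch

def cautious_decay_(p: torch.Tensor, grad: torch.Tensor, lr: float, wd: float) -> None:
    """In-place decoupled weight decay, applied only where sign(p) == sign(grad)."""
    mask = (p.sign() == grad.sign()).to(p.dtype)
    p.mul_(1 - lr * wd * mask)

p = torch.tensor([1.0, 1.0, -1.0, -1.0])
grad = torch.tensor([0.5, -0.5, -0.5, 0.5])
# Entries 0 and 2: the gradient step (-lr * grad) also moves the weight toward 0,
# so decay applies. Entries 1 and 3: the gradient pushes away from 0, so it is skipped.
cautious_decay_(p, grad, lr=1.0, wd=0.1)
print(p)  # decayed only at indices 0 and 2
```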

Closing Thoughts

What makes the modded-nanogpt competition remarkable isn't any single idea; it's the compounding of 77+ incremental improvements, each validated against a strict reproducibility protocol. Many of these ideas (RoPE, ReLU², QK-Norm) are well known in the literature, but they had rarely been combined and tuned together this aggressively. Others (Muon, value embeddings) are genuinely novel contributions that emerged from the competitive pressure, while outside research such as the Polar Express method for computing Muon's orthogonalization was folded in along the way.

The result: training that took 10 billion tokens in the baseline now reaches the same loss with ~500 million tokens, and what took 45 minutes now takes 90 seconds. That's a 20x improvement in data efficiency and a 30x improvement in wall-clock time — all on the same hardware, the same dataset, and the same target.

The code is open. The records are public. The leaderboard is still accepting submissions.
