test: numerical parity for MLA RoPE fused kernels vs PyTorch reference
Adds tests/test_mla_rope_grad.py: forward + backward parity between
fused_q_rope_inplace + fused_kv_split_rope_cat (composed with the
PyTorch-native head-shared k_pe RoPE) versus the pre-fusion path
(split + view_as_complex/complex_mul + reorder_headdim + cat).
Self-contained reference — no upstream model code dependency. Uses
existing `tests/utils.assert_close` and pytest parametrize convention.
Shapes match motif3_seq at local_batch_size=8 on a single GPU
(B=8, S=4096, H_q=80, H_kv=16, D_qk=192, D_v=128, D_rope=64, bf16).
KV-side outputs and grads are bit-exact (tol=0); q_total / grad_q match
within bf16 rounding (rtol=atol=1e-2 in fp32-promoted comparison).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- tests/test_mla_rope_grad.py +142 -0
tests/test_mla_rope_grad.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Numerical parity test: activation MLA RoPE kernels vs PyTorch reference.
|
| 2 |
+
|
| 3 |
+
The activation package exposes two Triton kernels for Motif3 MLA attention:
|
| 4 |
+
* fused_q_rope_inplace — in-place RoPE on q's rope section
|
| 5 |
+
* fused_kv_split_rope_cat — split kv_latent + register-broadcast k_pe to H heads + cat
|
| 6 |
+
|
| 7 |
+
This test runs both the fused path and a pure-PyTorch reference over identical
|
| 8 |
+
inputs (forward + backward) and compares all outputs and input gradients.
|
| 9 |
+
|
| 10 |
+
Self-contained: the reference RoPE implementation lives in this file (no
|
| 11 |
+
upstream model code dependency).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import pytest
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
import activation
|
| 18 |
+
|
| 19 |
+
from .utils import assert_close
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Realistic motif3_seq per-GPU shapes (B=local_batch_size, H_q/H_kv per MLA spec).
SHAPES = [
    # (B, S, H_q, H_kv, D_nope, D_rope, D_v)
    (8, 4096, 80, 16, 128, 64, 128),
]
# Production dtype; the comparison tolerances below are calibrated for bf16.
DTYPES = [torch.bfloat16]
# Single seed: given fixed inputs the fused-vs-reference comparison is deterministic.
SEEDS = [0]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ------------------------------------------------------------------ reference
|
| 32 |
+
|
| 33 |
+
def _precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
|
| 34 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 35 |
+
t = torch.arange(end, dtype=torch.float32)
|
| 36 |
+
freqs = torch.outer(t, freqs)
|
| 37 |
+
return torch.polar(torch.ones_like(freqs), freqs) # complex64
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _apply_rotary_emb_single(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
| 41 |
+
"""[B, S, H, D] interleaved → rotated, in interleaved layout."""
|
| 42 |
+
x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 43 |
+
freqs_cis = freqs_cis[: x_.shape[1]].view(1, x_.shape[1], 1, x_.shape[3])
|
| 44 |
+
out = torch.view_as_real(x_ * freqs_cis).flatten(3)
|
| 45 |
+
return out.type_as(x)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _reorder_headdim_elements_rope(qk: torch.Tensor, B: int, S: int, rope_dim: int) -> torch.Tensor:
|
| 49 |
+
"""Interleaved [r0,i0,r1,i1,...] → contiguous [r0,r1,...,i0,i1,...]."""
|
| 50 |
+
qk = qk.view(B, S, -1, rope_dim // 2, 2)
|
| 51 |
+
qk = qk.transpose(3, 4)
|
| 52 |
+
return qk.reshape(B, S, -1, rope_dim)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def vanilla_path(q, kv_latent, k_pe, freqs_cis, B, S, H_kv, D_nope, D_rope, D_v):
    """Pure-PyTorch reference: split + complex-mul RoPE + reorder + cat.

    Returns (q_total, k_full, v) with the same shapes/layouts the fused path
    produces.
    """
    # --- query: RoPE only the trailing D_rope slice, then reassemble ---
    q_nope, q_pe = torch.split(q, [D_nope, D_rope], dim=-1)
    q_pe = _apply_rotary_emb_single(q_pe, freqs_cis)
    q_pe = _reorder_headdim_elements_rope(q_pe, B, S, D_rope)
    q_total = torch.cat([q_nope, q_pe], dim=-1)
    # --- k_pe: single shared head (H=1), rotated then de-interleaved ---
    k_pe_roped = _apply_rotary_emb_single(k_pe.unsqueeze(2), freqs_cis)
    k_pe_roped = _reorder_headdim_elements_rope(k_pe_roped, B, S, D_rope)
    # --- kv latent: split into k_nope / v, broadcast k_pe across H_kv heads ---
    k_nope, v = torch.split(kv_latent, [D_nope, D_v], dim=-1)
    k_full = torch.cat([k_nope, k_pe_roped.expand(-1, -1, H_kv, -1)], dim=-1)
    return q_total, k_full, v
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def fused_path(q, kv_latent, k_pe, freqs_cis, B, S, H_kv, D_nope, D_rope, D_v):
    """Fused Triton path under test; mirrors vanilla_path's outputs."""
    # In-place RoPE over q's trailing D_rope slice.
    q_total = activation.fused_q_rope_inplace(q, freqs_cis, D_nope, D_rope)
    # k_pe RoPE stays PyTorch native (head-shared; standalone Triton kernel was
    # launch-bound on B200, no measurable win — see PR #22).
    k_pe_roped = _apply_rotary_emb_single(k_pe.unsqueeze(2), freqs_cis)
    k_pe_roped = _reorder_headdim_elements_rope(k_pe_roped, B, S, D_rope)
    # Fused split + head-broadcast + concat.
    k_full, v = activation.fused_kv_split_rope_cat(
        kv_latent, k_pe_roped, D_nope, D_v, D_rope
    )
    return q_total, k_full, v
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ------------------------------------------------------------------ harness
|
| 85 |
+
|
| 86 |
+
def _run_with_grad(path_fn, q, kv_latent, k_pe, freqs_cis, **shape_kwargs):
|
| 87 |
+
# Inputs come in as leaves; thread through a no-op so the in-place fused_q
|
| 88 |
+
# kernel sees a non-leaf (mirrors the real model where q is a Linear output).
|
| 89 |
+
q_leaf, kv_leaf, kpe_leaf = (
|
| 90 |
+
q.clone().detach().requires_grad_(True),
|
| 91 |
+
kv_latent.clone().detach().requires_grad_(True),
|
| 92 |
+
k_pe.clone().detach().requires_grad_(True),
|
| 93 |
+
)
|
| 94 |
+
q_in, kv_in, kpe_in = q_leaf * 1.0, kv_leaf * 1.0, kpe_leaf * 1.0
|
| 95 |
+
|
| 96 |
+
q_total, k_full, v = path_fn(q_in, kv_in, kpe_in, freqs_cis, **shape_kwargs)
|
| 97 |
+
loss = (q_total.float() ** 2).sum() + (k_full.float() ** 2).sum() + (v.float() ** 2).sum()
|
| 98 |
+
loss.backward()
|
| 99 |
+
|
| 100 |
+
return (
|
| 101 |
+
q_total.detach(), k_full.detach(), v.detach(),
|
| 102 |
+
q_leaf.grad.detach(), kv_leaf.grad.detach(), kpe_leaf.grad.detach(),
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# ------------------------------------------------------------------ test
|
| 107 |
+
|
| 108 |
+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
def test_mla_rope_fused_vs_reference(shape, dtype, seed):
    """Forward + backward parity of the fused MLA RoPE path vs the PyTorch reference."""
    B, S, H_q, H_kv, D_nope, D_rope, D_v = shape
    D_qk = D_nope + D_rope
    device = "cuda"

    # Fixed seed so both paths see byte-identical random inputs.
    torch.manual_seed(seed)
    freqs_cis = _precompute_freqs_cis(D_rope, S).to(device)

    # 0.5 scale keeps squared-sum loss magnitudes moderate in bf16.
    q = (torch.randn(B, S, H_q, D_qk, device=device, dtype=dtype) * 0.5)
    kv_latent = (torch.randn(B, S, H_kv, D_nope + D_v, device=device, dtype=dtype) * 0.5)
    k_pe = (torch.randn(B, S, D_rope, device=device, dtype=dtype) * 0.5)

    kw = dict(B=B, S=S, H_kv=H_kv, D_nope=D_nope, D_rope=D_rope, D_v=D_v)
    van_q, van_k, van_v, van_gq, van_gkv, van_gkpe = _run_with_grad(
        vanilla_path, q, kv_latent, k_pe, freqs_cis, **kw
    )
    our_q, our_k, our_v, our_gq, our_gkv, our_gkpe = _run_with_grad(
        fused_path, q, kv_latent, k_pe, freqs_cis, **kw
    )

    # Forward outputs: small bf16 jitter expected on the q rope rotation
    # (Triton fp32 accum vs inductor fp32 complex_mul order).
    assert_close(our_q.float(), van_q.float(), atol=1e-2, rtol=1e-2)
    # KV path is bit-exact (just slice + register broadcast + store).
    assert_close(our_k.float(), van_k.float(), atol=0.0, rtol=0.0)
    assert_close(our_v.float(), van_v.float(), atol=0.0, rtol=0.0)

    # Input grads.
    assert_close(our_gq.float(), van_gq.float(), atol=1e-2, rtol=1e-2)
    assert_close(our_gkv.float(), van_gkv.float(), atol=0.0, rtol=0.0)
    assert_close(our_gkpe.float(), van_gkpe.float(), atol=0.0, rtol=0.0)
|