Kernels
3v324v23 Claude Opus 4.6 (1M context) committed on
Commit
0c42208
·
1 Parent(s): 7e86d2e

test: numerical parity for MLA RoPE fused kernels vs PyTorch reference

Browse files

Adds tests/test_mla_rope_grad.py: forward + backward parity between
fused_q_rope_inplace + fused_kv_split_rope_cat (composed with the
PyTorch-native head-shared k_pe RoPE) versus the pre-fusion path
(split + view_as_complex/complex_mul + reorder_headdim + cat).

Self-contained reference — no upstream model code dependency. Uses
existing `tests/utils.assert_close` and pytest parametrize convention.

Shapes match motif3_seq at local_batch_size=8 on a single GPU
(B=8, S=4096, H_q=80, H_kv=16, D_qk=192, D_v=128, D_rope=64, bf16).
KV-side outputs and grads are bit-exact (tol=0); q_total / grad_q match
within bf16 rounding (rtol=atol=1e-2 in fp32-promoted comparison).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Files changed (1) hide show
  1. tests/test_mla_rope_grad.py +142 -0
tests/test_mla_rope_grad.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Numerical parity test: activation MLA RoPE kernels vs PyTorch reference.
2
+
3
+ The activation package exposes two Triton kernels for Motif3 MLA attention:
4
+ * fused_q_rope_inplace — in-place RoPE on q's rope section
5
+ * fused_kv_split_rope_cat — split kv_latent + register-broadcast k_pe to H heads + cat
6
+
7
+ This test runs both the fused path and a pure-PyTorch reference over identical
8
+ inputs (forward + backward) and compares all outputs and input gradients.
9
+
10
+ Self-contained: the reference RoPE implementation lives in this file (no
11
+ upstream model code dependency).
12
+ """
13
+
14
+ import pytest
15
+ import torch
16
+
17
+ import activation
18
+
19
+ from .utils import assert_close
20
+
21
+
22
# Realistic motif3_seq per-GPU shapes (B=local_batch_size, H_q/H_kv per MLA spec).
SHAPES = [
    # (B, S, H_q, H_kv, D_nope, D_rope, D_v)
    (8, 4096, 80, 16, 128, 64, 128),
]
# bf16 only — matches the production dtype the fused kernels run under.
DTYPES = [torch.bfloat16]
# One fixed seed keeps the parity comparison deterministic run-to-run.
SEEDS = [0]
29
+
30
+
31
+ # ------------------------------------------------------------------ reference
32
+
33
+ def _precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
34
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
35
+ t = torch.arange(end, dtype=torch.float32)
36
+ freqs = torch.outer(t, freqs)
37
+ return torch.polar(torch.ones_like(freqs), freqs) # complex64
38
+
39
+
40
+ def _apply_rotary_emb_single(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
41
+ """[B, S, H, D] interleaved → rotated, in interleaved layout."""
42
+ x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
43
+ freqs_cis = freqs_cis[: x_.shape[1]].view(1, x_.shape[1], 1, x_.shape[3])
44
+ out = torch.view_as_real(x_ * freqs_cis).flatten(3)
45
+ return out.type_as(x)
46
+
47
+
48
+ def _reorder_headdim_elements_rope(qk: torch.Tensor, B: int, S: int, rope_dim: int) -> torch.Tensor:
49
+ """Interleaved [r0,i0,r1,i1,...] → contiguous [r0,r1,...,i0,i1,...]."""
50
+ qk = qk.view(B, S, -1, rope_dim // 2, 2)
51
+ qk = qk.transpose(3, 4)
52
+ return qk.reshape(B, S, -1, rope_dim)
53
+
54
+
55
def vanilla_path(q, kv_latent, k_pe, freqs_cis, B, S, H_kv, D_nope, D_rope, D_v):
    """Pure-PyTorch reference for the pre-fusion MLA RoPE path.

    Returns (q_total, k_full, v) with the rope sections rotated and reordered
    into the contiguous [reals..., imags...] layout that the fused kernels emit.
    """
    # Query: rotate only the rope tail, reorder, re-concatenate with the nope part.
    q_nope, q_rope = torch.split(q, [D_nope, D_rope], dim=-1)
    q_rope = _reorder_headdim_elements_rope(
        _apply_rotary_emb_single(q_rope, freqs_cis), B, S, D_rope
    )
    q_total = torch.cat([q_nope, q_rope], dim=-1)

    # k_pe is head-shared: lift to a single-head 4-D tensor, rotate, reorder.
    k_pe_rot = _apply_rotary_emb_single(k_pe.unsqueeze(2), freqs_cis)
    k_pe_rot = _reorder_headdim_elements_rope(k_pe_rot, B, S, D_rope)

    # KV latent: split into k_nope/v, broadcast the shared k_pe across H_kv heads.
    k_nope, v = torch.split(kv_latent, [D_nope, D_v], dim=-1)
    k_full = torch.cat([k_nope, k_pe_rot.expand(-1, -1, H_kv, -1)], dim=-1)
    return q_total, k_full, v
69
+
70
+
71
def fused_path(q, kv_latent, k_pe, freqs_cis, B, S, H_kv, D_nope, D_rope, D_v):
    """Fused-kernel MLA RoPE path under test.

    q goes through the in-place Triton rope kernel; the kv side goes through
    the fused split + broadcast + cat kernel. The k_pe rotation itself stays
    PyTorch native (head-shared; a standalone Triton kernel was launch-bound
    on B200 with no measurable win — see PR #22).
    """
    q_total = activation.fused_q_rope_inplace(q, freqs_cis, D_nope, D_rope)

    k_pe_rot = _apply_rotary_emb_single(k_pe.unsqueeze(2), freqs_cis)
    k_pe_rot = _reorder_headdim_elements_rope(k_pe_rot, B, S, D_rope)

    k_full, v = activation.fused_kv_split_rope_cat(
        kv_latent, k_pe_rot, D_nope, D_v, D_rope
    )
    return q_total, k_full, v
82
+
83
+
84
+ # ------------------------------------------------------------------ harness
85
+
86
+ def _run_with_grad(path_fn, q, kv_latent, k_pe, freqs_cis, **shape_kwargs):
87
+ # Inputs come in as leaves; thread through a no-op so the in-place fused_q
88
+ # kernel sees a non-leaf (mirrors the real model where q is a Linear output).
89
+ q_leaf, kv_leaf, kpe_leaf = (
90
+ q.clone().detach().requires_grad_(True),
91
+ kv_latent.clone().detach().requires_grad_(True),
92
+ k_pe.clone().detach().requires_grad_(True),
93
+ )
94
+ q_in, kv_in, kpe_in = q_leaf * 1.0, kv_leaf * 1.0, kpe_leaf * 1.0
95
+
96
+ q_total, k_full, v = path_fn(q_in, kv_in, kpe_in, freqs_cis, **shape_kwargs)
97
+ loss = (q_total.float() ** 2).sum() + (k_full.float() ** 2).sum() + (v.float() ** 2).sum()
98
+ loss.backward()
99
+
100
+ return (
101
+ q_total.detach(), k_full.detach(), v.detach(),
102
+ q_leaf.grad.detach(), kv_leaf.grad.detach(), kpe_leaf.grad.detach(),
103
+ )
104
+
105
+
106
+ # ------------------------------------------------------------------ test
107
+
108
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
def test_mla_rope_fused_vs_reference(shape, dtype, seed):
    """Forward/backward parity: fused Triton MLA RoPE path vs PyTorch reference."""
    B, S, H_q, H_kv, D_nope, D_rope, D_v = shape
    D_qk = D_nope + D_rope
    device = "cuda"

    torch.manual_seed(seed)
    freqs_cis = _precompute_freqs_cis(D_rope, S).to(device)

    # 0.5 scale keeps the squared-loss magnitudes tame in bf16.
    def _rand(*dims):
        return torch.randn(*dims, device=device, dtype=dtype) * 0.5

    q = _rand(B, S, H_q, D_qk)
    kv_latent = _rand(B, S, H_kv, D_nope + D_v)
    k_pe = _rand(B, S, D_rope)

    kw = dict(B=B, S=S, H_kv=H_kv, D_nope=D_nope, D_rope=D_rope, D_v=D_v)
    ref_q, ref_k, ref_v, ref_gq, ref_gkv, ref_gkpe = _run_with_grad(
        vanilla_path, q, kv_latent, k_pe, freqs_cis, **kw
    )
    got_q, got_k, got_v, got_gq, got_gkv, got_gkpe = _run_with_grad(
        fused_path, q, kv_latent, k_pe, freqs_cis, **kw
    )

    # Forward outputs. The q rope rotation carries small bf16 jitter (Triton
    # fp32 accum vs inductor fp32 complex_mul order); the kv path is a pure
    # slice + register broadcast + store and must be bit-exact.
    assert_close(got_q.float(), ref_q.float(), atol=1e-2, rtol=1e-2)
    assert_close(got_k.float(), ref_k.float(), atol=0.0, rtol=0.0)
    assert_close(got_v.float(), ref_v.float(), atol=0.0, rtol=0.0)

    # Input gradients: same tolerance split as the forward pass.
    assert_close(got_gq.float(), ref_gq.float(), atol=1e-2, rtol=1e-2)
    assert_close(got_gkv.float(), ref_gkv.float(), atol=0.0, rtol=0.0)
    assert_close(got_gkpe.float(), ref_gkpe.float(), atol=0.0, rtol=0.0)