"""Numerical parity test: activation MLA RoPE kernels vs PyTorch reference.

The activation package exposes two Triton kernels for Motif3 MLA attention:

* ``fused_q_rope_inplace``    — in-place RoPE on q's rope section
* ``fused_kv_split_rope_cat`` — split kv_latent + register-broadcast k_pe to
  H heads + cat

This test runs both the fused path and a pure-PyTorch reference over identical
inputs (forward + backward) and compares all outputs and input gradients.

Self-contained: the reference RoPE implementation lives in this file (no
upstream model code dependency).
"""

import pytest
import torch

import activation

from .utils import assert_close

# Realistic motif3_seq per-GPU shapes (B=local_batch_size, H_q/H_kv per MLA spec).
SHAPES = [
    # (B, S, H_q, H_kv, D_nope, D_rope, D_v)
    (8, 4096, 80, 16, 128, 64, 128),
]
DTYPES = [torch.bfloat16]
SEEDS = [0]


# ------------------------------------------------------------------ reference


def _precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Build the RoPE rotation table as unit-magnitude complex numbers.

    Args:
        dim: rope head dimension (number of real channels; expected to be even).
        end: number of positions to precompute.
        theta: RoPE base.

    Returns:
        Complex tensor of shape ``[end, dim // 2]`` whose ``[t, j]`` entry is
        ``exp(1j * t / theta ** (2 * j / dim))``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64


def _apply_rotary_emb_single(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate an interleaved-layout tensor by RoPE; the output stays interleaved.

    Args:
        x: ``[B, S, H, D]`` where channel pairs ``(2k, 2k + 1)`` are treated as
            the real / imaginary parts of one complex value.
        freqs_cis: table from ``_precompute_freqs_cis``; only its first ``S``
            rows are used.

    Returns:
        Tensor with the same shape and dtype as ``x``. ``x`` is upcast to fp32
        before the complex multiply and cast back at the end.
    """
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Broadcast the per-position rotation over batch (dim 0) and heads (dim 2).
    freqs_cis = freqs_cis[: x_.shape[1]].view(1, x_.shape[1], 1, x_.shape[3])
    out = torch.view_as_real(x_ * freqs_cis).flatten(3)
    return out.type_as(x)


def _reorder_headdim_elements_rope(
    qk: torch.Tensor, B: int, S: int, rope_dim: int
) -> torch.Tensor:
    """Interleaved [r0,i0,r1,i1,...] → contiguous [r0,r1,...,i0,i1,...]."""
    qk = qk.view(B, S, -1, rope_dim // 2, 2)
    qk = qk.transpose(3, 4)
    return qk.reshape(B, S, -1, rope_dim)


def _rope_k_pe(
    k_pe: torch.Tensor, freqs_cis: torch.Tensor, B: int, S: int, D_rope: int
) -> torch.Tensor:
    """Reference RoPE for the head-shared rope key.

    Shared by ``vanilla_path`` and ``fused_path`` so both feed identical k_pe
    into their respective KV stages.

    Args:
        k_pe: ``[B, S, D_rope]`` rope key shared by all KV heads.

    Returns:
        ``[B, S, 1, D_rope]`` rotated key in the half-split
        ``[r0, r1, ..., i0, i1, ...]`` layout.
    """
    k_pe_4d = k_pe.unsqueeze(2)  # add a singleton head dim
    k_pe_roped = _apply_rotary_emb_single(k_pe_4d, freqs_cis)
    return _reorder_headdim_elements_rope(k_pe_roped, B, S, D_rope)


def vanilla_path(q, kv_latent, k_pe, freqs_cis, B, S, H_kv, D_nope, D_rope, D_v):
    """Pure-PyTorch reference. Returns ``(q_total, k_full, v)``."""
    # Q: rotate only the rope section, then re-join it with the nope section.
    q_nope, q_pe = torch.split(q, [D_nope, D_rope], dim=-1)
    q_pe = _apply_rotary_emb_single(q_pe, freqs_cis)
    q_pe = _reorder_headdim_elements_rope(q_pe, B, S, D_rope)
    q_total = torch.cat([q_nope, q_pe], dim=-1)

    # k_pe (head-shared, H=1)
    k_pe_roped = _rope_k_pe(k_pe, freqs_cis, B, S, D_rope)

    # KV split + expand + cat
    k_nope, v = torch.split(kv_latent, [D_nope, D_v], dim=-1)
    k_full = torch.cat([k_nope, k_pe_roped.expand(-1, -1, H_kv, -1)], dim=-1)
    return q_total, k_full, v


def fused_path(q, kv_latent, k_pe, freqs_cis, B, S, H_kv, D_nope, D_rope, D_v):
    """Path under test: activation Triton kernels. Returns ``(q_total, k_full, v)``.

    ``fused_q_rope_inplace`` mutates ``q`` (see module docstring), so callers
    must pass a tensor they do not need afterwards.
    """
    q_total = activation.fused_q_rope_inplace(q, freqs_cis, D_nope, D_rope)

    # k_pe RoPE stays PyTorch native (head-shared; standalone Triton kernel was
    # launch-bound on B200, no measurable win — see PR #22).
    k_pe_roped = _rope_k_pe(k_pe, freqs_cis, B, S, D_rope)

    k_full, v = activation.fused_kv_split_rope_cat(
        kv_latent, k_pe_roped, D_nope, D_v, D_rope
    )
    return q_total, k_full, v


# ------------------------------------------------------------------ harness


def _run_with_grad(path_fn, q, kv_latent, k_pe, freqs_cis, **shape_kwargs):
    """Run ``path_fn`` forward + backward on private copies of the inputs.

    The loss is the fp32 sum of squares of all three outputs, so every output
    element contributes a gradient.

    Returns:
        ``(q_total, k_full, v, dq, dkv_latent, dk_pe)``, all detached.
    """
    # Fresh leaves per call: the caller's tensors are shared between paths and
    # must not be touched by the in-place q kernel.
    q_leaf, kv_leaf, kpe_leaf = (
        q.clone().detach().requires_grad_(True),
        kv_latent.clone().detach().requires_grad_(True),
        k_pe.clone().detach().requires_grad_(True),
    )
    # Thread through a no-op so the in-place fused_q kernel sees a non-leaf
    # (mirrors the real model where q is a Linear output).
    q_in, kv_in, kpe_in = q_leaf * 1.0, kv_leaf * 1.0, kpe_leaf * 1.0

    q_total, k_full, v = path_fn(q_in, kv_in, kpe_in, freqs_cis, **shape_kwargs)

    loss = (
        (q_total.float() ** 2).sum()
        + (k_full.float() ** 2).sum()
        + (v.float() ** 2).sum()
    )
    loss.backward()
    return (
        q_total.detach(),
        k_full.detach(),
        v.detach(),
        q_leaf.grad.detach(),
        kv_leaf.grad.detach(),
        kpe_leaf.grad.detach(),
    )


# ------------------------------------------------------------------ test


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
def test_mla_rope_fused_vs_reference(shape, dtype, seed):
    B, S, H_q, H_kv, D_nope, D_rope, D_v = shape
    D_qk = D_nope + D_rope
    device = "cuda"

    torch.manual_seed(seed)
    freqs_cis = _precompute_freqs_cis(D_rope, S).to(device)
    q = torch.randn(B, S, H_q, D_qk, device=device, dtype=dtype) * 0.5
    kv_latent = torch.randn(B, S, H_kv, D_nope + D_v, device=device, dtype=dtype) * 0.5
    k_pe = torch.randn(B, S, D_rope, device=device, dtype=dtype) * 0.5

    kw = dict(B=B, S=S, H_kv=H_kv, D_nope=D_nope, D_rope=D_rope, D_v=D_v)
    van_q, van_k, van_v, van_gq, van_gkv, van_gkpe = _run_with_grad(
        vanilla_path, q, kv_latent, k_pe, freqs_cis, **kw
    )
    our_q, our_k, our_v, our_gq, our_gkv, our_gkpe = _run_with_grad(
        fused_path, q, kv_latent, k_pe, freqs_cis, **kw
    )

    # Forward outputs: small bf16 jitter expected on the q rope rotation
    # (Triton fp32 accumulation vs the eager complex64 multiply in the reference).
    assert_close(our_q.float(), van_q.float(), atol=1e-2, rtol=1e-2)
    # KV path is bit-exact (just slice + register broadcast + store).
    assert_close(our_k.float(), van_k.float(), atol=0.0, rtol=0.0)
    assert_close(our_v.float(), van_v.float(), atol=0.0, rtol=0.0)

    # Input grads.
    assert_close(our_gq.float(), van_gq.float(), atol=1e-2, rtol=1e-2)
    assert_close(our_gkv.float(), van_gkv.float(), atol=0.0, rtol=0.0)
    # NOTE(review): exact match assumes the fused backward reduces the H_kv
    # broadcast of k_pe with the same result as autograd's expand backward —
    # confirm if this ever flakes.
    assert_close(our_gkpe.float(), van_gkpe.float(), atol=0.0, rtol=0.0)