Kernels
Jangwoong Kim commited on
Commit
5adea7d
·
unverified ·
2 Parent(s): 79a877a · 536f0b2

Merge pull request #22 from MotifTechnologies/jangwoong/mla-rope-fa4-port

Browse files
benchmarks/bench_mla_rope.yaml ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Kubeflow TrainJob that runs the MLA RoPE micro-benchmark
# (benchmarks/run_cases.py --case mla_rope) on one node. Results are written
# to the shared `mair` PVC under benchmarks/results/mla_rope/<timestamp>.
# NOTE(review): indentation reconstructed from a flattened scrape — verify
# nesting (esp. targetJobs placement) against the applied manifest.
apiVersion: trainer.kubeflow.org/v1alpha1
kind: TrainJob
metadata:
  name: jangwoong-mla-rope-bench
  namespace: kbm-g-np-motif
spec:
  managedBy: trainer.kubeflow.org/trainjob-controller
  podTemplateOverrides:
    - spec:
        containers:
          - name: node
            volumeMounts:
              - mountPath: /dev/shm   # RAM-backed shm for torch/NCCL
                name: shm
              - mountPath: /mair      # shared volume: source tree + results
                name: mair
        volumes:
          - emptyDir:
              medium: Memory
              sizeLimit: 64Gi
            name: shm
          - name: mair
            persistentVolumeClaim:
              claimName: mair
      targetJobs:
        - name: node
  runtimeRef:
    apiGroup: trainer.kubeflow.org
    kind: ClusterTrainingRuntime
    name: torch-distributed
  suspend: false
  trainer:
    args:
      - /bin/bash
      - '-c'
      - |
        set -e
        ACTIVATIONPATH=/mair/team-sys/jangwoong/activation
        DATESTAMP=$(date +'%y_%m_%d_%H_%M')
        SAVE_PATH=$ACTIVATIONPATH/benchmarks/results/mla_rope/${DATESTAMP}
        mkdir -p $SAVE_PATH

        pip install triton pandas

        # Build activation from local source (copy to /tmp to avoid NFS race)
        mkdir -p /tmp/activation_src && rm -rf /tmp/activation_src/* && \
        rsync -a --exclude=build $ACTIVATIONPATH/ /tmp/activation_src/ && \
        pip install --no-build-isolation /tmp/activation_src 2>&1 | tail -50

        # Smoke-check that both fused entry points import before benchmarking.
        python -c "import activation; print('fused_q_rope_inplace:', activation.fused_q_rope_inplace); print('fused_kv_split_rope_cat:', activation.fused_kv_split_rope_cat)"

        nvidia-smi | tee $SAVE_PATH/nvidia_smi.txt

        echo "=== MLA RoPE benchmark ==="
        cd $ACTIVATIONPATH/benchmarks
        CUDA_VISIBLE_DEVICES=0 python run_cases.py --case mla_rope --dtype bf16 \
          --save-path $SAVE_PATH/bench 2>&1 | tee $SAVE_PATH/bench.log

        echo "=== Done. Results at: $SAVE_PATH ==="
        exit 0
    env:
      - name: PYTHONUNBUFFERED
        value: '1'
      - name: PYTORCH_ALLOC_CONF
        value: expandable_segments:True
      - name: CUDA_LAUNCH_BLOCKING
        value: '0'
      - name: OMP_NUM_THREADS
        value: '1'
    image: ghcr.io/motiftechnologies/llm-training:v0.1.8
    numNodes: 1
    numProcPerNode: 1
    resourcesPerNode:
      limits:
        cpu: '96'
        memory: 1024Gi
        # NOTE(review): the script pins CUDA_VISIBLE_DEVICES=0, so 8 GPUs may
        # be over-requested for this job — confirm whether full-node
        # allocation is intentional (exclusive-node scheduling).
        nvidia.com/gpu: '8'
      requests:
        cpu: '96'
        memory: 1024Gi
        nvidia.com/gpu: '8'
benchmarks/cases/mla_rope.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MLA RoPE case: fused (activation Triton) vs vanilla (PyTorch native).
2
+
3
+ MLA head dims are fixed by motif3 spec (H_q=80, H_kv=16, D_nope=128,
4
+ D_rope=64, D_v=128); the benchmark's only sweep axes are (bs, sl).
5
+ The framework's ``dim`` axis is a dummy here — pass 0 in configs.
6
+ """
7
+
8
+ import torch
9
+ from torch import nn
10
+
11
+ import activation
12
+
13
+ from common.diff_engine import DiffCase
14
+
15
# ---- MLA shapes (motif3_seq) -----------------------------------------------
# Fixed by the motif3 model spec (per the module docstring): 80 query heads,
# 16 kv heads, 128-dim non-rotary (nope) section, 64-dim rotary (rope)
# section, 128-dim value section. Only (bs, sl) are swept by the benchmark.
H_Q, H_KV = 80, 16
D_NOPE, D_ROPE, D_V = 128, 64, 128
D_QK = D_NOPE + D_ROPE  # 192
19
+
20
+
21
+ # ---- reference (PyTorch native) --------------------------------------------
22
+ def _precompute_freqs_cis(dim, end, theta=10000.0):
23
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
24
+ t = torch.arange(end, dtype=torch.float32)
25
+ freqs = torch.outer(t, freqs)
26
+ return torch.polar(torch.ones_like(freqs), freqs)
27
+
28
+
29
+ def _apply_rotary_emb_single(x, freqs_cis):
30
+ x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
31
+ freqs_cis = freqs_cis[: x_.shape[1]].view(1, x_.shape[1], 1, x_.shape[3])
32
+ out = torch.view_as_real(x_ * freqs_cis).flatten(3)
33
+ return out.type_as(x)
34
+
35
+
36
+ def _reorder(qk, rope_dim):
37
+ B, S = qk.shape[0], qk.shape[1]
38
+ qk = qk.view(B, S, -1, rope_dim // 2, 2).transpose(3, 4)
39
+ return qk.reshape(B, S, -1, rope_dim)
40
+
41
+
42
def _vanilla(q, kv_latent, k_pe, freqs_cis):
    """PyTorch-native reference path; returns (q_total, k_full, v).

    q's trailing D_ROPE lanes are roped and reordered to contiguous
    [real..., imag...] format; kv_latent is split into k_nope/v; the
    head-shared k_pe is roped once and broadcast across the H_KV heads.
    """
    q_nope, q_pe = torch.split(q, [D_NOPE, D_ROPE], dim=-1)
    q_pe = _reorder(_apply_rotary_emb_single(q_pe, freqs_cis), D_ROPE)
    q_total = torch.cat([q_nope, q_pe], dim=-1)
    # k_pe is head-shared: unsqueeze to H=1 before roping.
    k_pe_roped = _reorder(_apply_rotary_emb_single(k_pe.unsqueeze(2), freqs_cis), D_ROPE)
    k_nope, v = torch.split(kv_latent, [D_NOPE, D_V], dim=-1)
    # expand (not repeat): zero-copy broadcast of the roped k_pe to H_KV heads.
    k_full = torch.cat([k_nope, k_pe_roped.expand(-1, -1, H_KV, -1)], dim=-1)
    return q_total, k_full, v
50
+
51
+
52
def _fused(q, kv_latent, k_pe, freqs_cis):
    """Fused path using the activation Triton kernels; returns (q_total, k_full, v).

    Must produce the same outputs as _vanilla (up to bf16 rounding on the
    q rope section).
    """
    # In-place rope on q's trailing D_ROPE lanes, output in contiguous format.
    q_total = activation.fused_q_rope_inplace(q, freqs_cis, D_NOPE, D_ROPE)
    # k_pe RoPE stays PyTorch native (head-shared, too small for custom kernel)
    k_pe_roped = _reorder(_apply_rotary_emb_single(k_pe.unsqueeze(2), freqs_cis), D_ROPE)
    # Fused split(kv_latent) + broadcast(k_pe_roped) + cat → (k_full, v).
    k_full, v = activation.fused_kv_split_rope_cat(kv_latent, k_pe_roped, D_NOPE, D_V, D_ROPE)
    return q_total, k_full, v
58
+
59
+
60
class _VanillaModule(nn.Module):
    """nn.Module wrapper around the reference path for the diff engine."""

    def forward(self, q, kv_latent, k_pe, freqs_cis):
        return _vanilla(q, kv_latent, k_pe, freqs_cis)
63
+
64
+
65
class _FusedModule(nn.Module):
    """nn.Module wrapper around the fused Triton path for the diff engine."""

    def forward(self, q, kv_latent, k_pe, freqs_cis):
        return _fused(q, kv_latent, k_pe, freqs_cis)
68
+
69
+
70
class MLARoPE(DiffCase):
    """DiffCase comparing the fused Triton MLA RoPE path against PyTorch native.

    Head dims are fixed module-level constants; only (bs, sl) vary.
    """

    # Framework calls build_inputs(bs, sl, dim, dtype, eps) — dim unused.
    def build_inputs(self, bs, sl, hidden, dtype, eps):
        # NOTE(review): tensors are created on the default device; presumably
        # the diff engine moves them to CUDA before the fused path — confirm.
        # 0.5 scale keeps values well inside bf16 range.
        return {
            "q": (torch.randn(bs, sl, H_Q, D_QK, dtype=dtype) * 0.5).requires_grad_(True),
            "kv_latent": (torch.randn(bs, sl, H_KV, D_NOPE + D_V, dtype=dtype) * 0.5).requires_grad_(True),
            "k_pe": (torch.randn(bs, sl, D_ROPE, dtype=dtype) * 0.5).requires_grad_(True),
            "freqs_cis": _precompute_freqs_cis(D_ROPE, sl),  # complex64
        }

    def make_naive(self, I):
        # PyTorch-native reference module.
        return _VanillaModule()

    def make_cuda(self, I):
        # Fused (activation Triton) module.
        return _FusedModule()

    def forward(self, obj, I):
        # fused_q_rope_inplace needs non-leaf q; wrap both paths for fairness
        q_in = I["q"] * 1.0
        kv_in = I["kv_latent"] * 1.0
        kpe_in = I["k_pe"] * 1.0
        return obj(q_in, kv_in, kpe_in, I["freqs_cis"])

    def grad_inputs(self, I):
        # Leaves whose .grad the engine compares after backward.
        return [I["q"], I["kv_latent"], I["k_pe"]]


# Module-level singleton discovered by run_cases.py as `mod.CASE`.
CASE = MLARoPE()
benchmarks/run_cases.py CHANGED
@@ -62,7 +62,8 @@ def main():
62
  ap = argparse.ArgumentParser()
63
  ap.add_argument(
64
  "--case",
65
- choices=["rms", "add_rms", "poly", "mul_poly", "grouped_mul_poly"],
 
66
  required=True)
67
  ap.add_argument("--plot", action="store_true")
68
  ap.add_argument(
@@ -95,12 +96,26 @@ def main():
95
  case: DiffCase = mod.CASE
96
 
97
  # Correctness checks across multiple configs
98
- for bs, sl, hid in [(2, 128, 4096), (8, 4096, 1280), (1, 32768, 1280)]:
 
 
 
 
 
 
 
 
 
99
  print(
100
  f"Checking correctness: bs={bs}, sl={sl}, D={hid} "
101
  f"(N={bs*sl})...",
102
  end=" ")
103
- calculate_diff(case, batch_size=bs, seq_len=sl, hidden_size=hid)
 
 
 
 
 
104
  print("✅")
105
 
106
  for dtype_name, dtype in dtypes:
@@ -109,6 +124,7 @@ def main():
109
  print(f"{'=' * 60}\n")
110
 
111
  save_dir = os.path.join(args.save_path, args.case, dtype_name)
 
112
  is_grouped = args.case == "grouped_mul_poly"
113
 
114
  if args.plot:
@@ -118,6 +134,8 @@ def main():
118
  dim = [1280]
119
  elif "poly" in args.case:
120
  dim = [8192, 16384]
 
 
121
  else:
122
  dim = [2048, 4096]
123
  configs = list(
@@ -170,6 +188,8 @@ def main():
170
  dim = [1280]
171
  elif "poly" in args.case:
172
  dim = [8192, 16384]
 
 
173
  else:
174
  dim = [2048, 4096]
175
  configs = list(
 
62
  ap = argparse.ArgumentParser()
63
  ap.add_argument(
64
  "--case",
65
+ choices=["rms", "add_rms", "poly", "mul_poly", "grouped_mul_poly",
66
+ "mla_rope"],
67
  required=True)
68
  ap.add_argument("--plot", action="store_true")
69
  ap.add_argument(
 
96
  case: DiffCase = mod.CASE
97
 
98
  # Correctness checks across multiple configs
99
+ # NOTE: calculate_diff positionally calls build_inputs(hidden_size, bs, sl);
100
+ # bench framework positionally calls build_inputs(bs, sl, dim). These
101
+ # disagree — rms-style cases don't care (all 3 axes are flat dims), but
102
+ # mla_rope does. We match the bench convention in cases/mla_rope.py, so
103
+ # we swap arg names at the correctness call site below for that case.
104
+ if args.case == "mla_rope":
105
+ cfgs = [(1, 1024, 0), (4, 4096, 0), (8, 4096, 0)] # (bs, sl, dummy)
106
+ else:
107
+ cfgs = [(2, 128, 4096), (8, 4096, 1280), (1, 32768, 1280)]
108
+ for bs, sl, hid in cfgs:
109
  print(
110
  f"Checking correctness: bs={bs}, sl={sl}, D={hid} "
111
  f"(N={bs*sl})...",
112
  end=" ")
113
+ if args.case == "mla_rope":
114
+ # Swap so positional (hidden_size, batch_size, seq_len) maps to
115
+ # our build_inputs(bs, sl, dim) as (bs, sl, dummy).
116
+ calculate_diff(case, batch_size=sl, seq_len=hid, hidden_size=bs)
117
+ else:
118
+ calculate_diff(case, batch_size=bs, seq_len=sl, hidden_size=hid)
119
  print("✅")
120
 
121
  for dtype_name, dtype in dtypes:
 
124
  print(f"{'=' * 60}\n")
125
 
126
  save_dir = os.path.join(args.save_path, args.case, dtype_name)
127
+ os.makedirs(save_dir, exist_ok=True)
128
  is_grouped = args.case == "grouped_mul_poly"
129
 
130
  if args.plot:
 
134
  dim = [1280]
135
  elif "poly" in args.case:
136
  dim = [8192, 16384]
137
+ elif args.case == "mla_rope":
138
+ dim = [0] # MLA head dims are fixed; dim axis is a dummy
139
  else:
140
  dim = [2048, 4096]
141
  configs = list(
 
188
  dim = [1280]
189
  elif "poly" in args.case:
190
  dim = [8192, 16384]
191
+ elif args.case == "mla_rope":
192
+ dim = [0] # MLA head dims are fixed; dim axis is a dummy
193
  else:
194
  dim = [2048, 4096]
195
  configs = list(
tests/test_mla_rope_grad.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Numerical parity test: activation MLA RoPE kernels vs PyTorch reference.
2
+
3
+ The activation package exposes two Triton kernels for Motif3 MLA attention:
4
+ * fused_q_rope_inplace — in-place RoPE on q's rope section
5
+ * fused_kv_split_rope_cat — split kv_latent + register-broadcast k_pe to H heads + cat
6
+
7
+ This test runs both the fused path and a pure-PyTorch reference over identical
8
+ inputs (forward + backward) and compares all outputs and input gradients.
9
+
10
+ Self-contained: the reference RoPE implementation lives in this file (no
11
+ upstream model code dependency).
12
+ """
13
+
14
+ import pytest
15
+ import torch
16
+
17
+ import activation
18
+
19
+ from .utils import assert_close
20
+
21
+
22
# Realistic motif3_seq per-GPU shapes (B=local_batch_size, H_q/H_kv per MLA spec).
SHAPES = [
    # (B, S, H_q, H_kv, D_nope, D_rope, D_v)
    (8, 4096, 80, 16, 128, 64, 128),
]
# bf16 only — the dtype the kernels are tuned for (see module docstring).
DTYPES = [torch.bfloat16]
SEEDS = [0]
29
+
30
+
31
+ # ------------------------------------------------------------------ reference
32
+
33
+ def _precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
34
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
35
+ t = torch.arange(end, dtype=torch.float32)
36
+ freqs = torch.outer(t, freqs)
37
+ return torch.polar(torch.ones_like(freqs), freqs) # complex64
38
+
39
+
40
+ def _apply_rotary_emb_single(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
41
+ """[B, S, H, D] interleaved → rotated, in interleaved layout."""
42
+ x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
43
+ freqs_cis = freqs_cis[: x_.shape[1]].view(1, x_.shape[1], 1, x_.shape[3])
44
+ out = torch.view_as_real(x_ * freqs_cis).flatten(3)
45
+ return out.type_as(x)
46
+
47
+
48
+ def _reorder_headdim_elements_rope(qk: torch.Tensor, B: int, S: int, rope_dim: int) -> torch.Tensor:
49
+ """Interleaved [r0,i0,r1,i1,...] → contiguous [r0,r1,...,i0,i1,...]."""
50
+ qk = qk.view(B, S, -1, rope_dim // 2, 2)
51
+ qk = qk.transpose(3, 4)
52
+ return qk.reshape(B, S, -1, rope_dim)
53
+
54
+
55
def vanilla_path(q, kv_latent, k_pe, freqs_cis, B, S, H_kv, D_nope, D_rope, D_v):
    """Pure-PyTorch reference; returns (q_total, k_full, v)."""
    # Q: rope only the trailing D_rope lanes, then interleaved → contiguous.
    q_nope, q_pe = torch.split(q, [D_nope, D_rope], dim=-1)
    q_pe = _apply_rotary_emb_single(q_pe, freqs_cis)
    q_pe = _reorder_headdim_elements_rope(q_pe, B, S, D_rope)
    q_total = torch.cat([q_nope, q_pe], dim=-1)
    # k_pe (head-shared, H=1)
    k_pe_4d = k_pe.unsqueeze(2)
    k_pe_roped = _apply_rotary_emb_single(k_pe_4d, freqs_cis)
    k_pe_roped = _reorder_headdim_elements_rope(k_pe_roped, B, S, D_rope)
    # KV split + expand + cat (expand = zero-copy broadcast to H_kv heads)
    k_nope, v = torch.split(kv_latent, [D_nope, D_v], dim=-1)
    k_full = torch.cat([k_nope, k_pe_roped.expand(-1, -1, H_kv, -1)], dim=-1)
    return q_total, k_full, v
69
+
70
+
71
def fused_path(q, kv_latent, k_pe, freqs_cis, B, S, H_kv, D_nope, D_rope, D_v):
    """Fused path under test; must match vanilla_path's (q_total, k_full, v)."""
    q_total = activation.fused_q_rope_inplace(q, freqs_cis, D_nope, D_rope)
    # k_pe RoPE stays PyTorch native (head-shared; standalone Triton kernel was
    # launch-bound on B200, no measurable win — see PR #22).
    k_pe_4d = k_pe.unsqueeze(2)
    k_pe_roped = _apply_rotary_emb_single(k_pe_4d, freqs_cis)
    k_pe_roped = _reorder_headdim_elements_rope(k_pe_roped, B, S, D_rope)
    # Fused split + head-broadcast + cat in one kernel.
    k_full, v = activation.fused_kv_split_rope_cat(
        kv_latent, k_pe_roped, D_nope, D_v, D_rope
    )
    return q_total, k_full, v
82
+
83
+
84
+ # ------------------------------------------------------------------ harness
85
+
86
def _run_with_grad(path_fn, q, kv_latent, k_pe, freqs_cis, **shape_kwargs):
    """Run one path forward + backward on fresh leaf copies of the inputs.

    Returns (q_total, k_full, v, grad_q, grad_kv_latent, grad_k_pe), all
    detached, so vanilla and fused runs can be compared tensor-by-tensor.
    """
    # Inputs come in as leaves; thread through a no-op so the in-place fused_q
    # kernel sees a non-leaf (mirrors the real model where q is a Linear output).
    q_leaf, kv_leaf, kpe_leaf = (
        q.clone().detach().requires_grad_(True),
        kv_latent.clone().detach().requires_grad_(True),
        k_pe.clone().detach().requires_grad_(True),
    )
    q_in, kv_in, kpe_in = q_leaf * 1.0, kv_leaf * 1.0, kpe_leaf * 1.0

    q_total, k_full, v = path_fn(q_in, kv_in, kpe_in, freqs_cis, **shape_kwargs)
    # Scalar loss over all three outputs so every input receives gradient;
    # .float() keeps the reduction in fp32.
    loss = (q_total.float() ** 2).sum() + (k_full.float() ** 2).sum() + (v.float() ** 2).sum()
    loss.backward()

    return (
        q_total.detach(), k_full.detach(), v.detach(),
        q_leaf.grad.detach(), kv_leaf.grad.detach(), kpe_leaf.grad.detach(),
    )
104
+
105
+
106
+ # ------------------------------------------------------------------ test
107
+
108
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
def test_mla_rope_fused_vs_reference(shape, dtype, seed):
    """Forward + backward parity: fused Triton path vs PyTorch reference."""
    B, S, H_q, H_kv, D_nope, D_rope, D_v = shape
    D_qk = D_nope + D_rope
    device = "cuda"

    torch.manual_seed(seed)
    freqs_cis = _precompute_freqs_cis(D_rope, S).to(device)

    # 0.5 scale keeps magnitudes modest for bf16.
    q = (torch.randn(B, S, H_q, D_qk, device=device, dtype=dtype) * 0.5)
    kv_latent = (torch.randn(B, S, H_kv, D_nope + D_v, device=device, dtype=dtype) * 0.5)
    k_pe = (torch.randn(B, S, D_rope, device=device, dtype=dtype) * 0.5)

    kw = dict(B=B, S=S, H_kv=H_kv, D_nope=D_nope, D_rope=D_rope, D_v=D_v)
    van_q, van_k, van_v, van_gq, van_gkv, van_gkpe = _run_with_grad(
        vanilla_path, q, kv_latent, k_pe, freqs_cis, **kw
    )
    our_q, our_k, our_v, our_gq, our_gkv, our_gkpe = _run_with_grad(
        fused_path, q, kv_latent, k_pe, freqs_cis, **kw
    )

    # Forward outputs: small bf16 jitter expected on the q rope rotation
    # (Triton fp32 accum vs inductor fp32 complex_mul order).
    assert_close(our_q.float(), van_q.float(), atol=1e-2, rtol=1e-2)
    # KV path is bit-exact (just slice + register broadcast + store).
    assert_close(our_k.float(), van_k.float(), atol=0.0, rtol=0.0)
    assert_close(our_v.float(), van_v.float(), atol=0.0, rtol=0.0)

    # Input grads.
    assert_close(our_gq.float(), van_gq.float(), atol=1e-2, rtol=1e-2)
    assert_close(our_gkv.float(), van_gkv.float(), atol=0.0, rtol=0.0)
    assert_close(our_gkpe.float(), van_gkpe.float(), atol=0.0, rtol=0.0)
torch-ext/activation/__init__.py CHANGED
@@ -2,6 +2,10 @@ import torch
2
 
3
  from . import layers, parallel_style
4
  from ._ops import ops
 
 
 
 
5
  from .grouped_poly_norm import fused_mul_grouped_poly_norm
6
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
7
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -50,6 +54,8 @@ __all__ = [
50
  "fused_mul_grouped_poly_norm",
51
  "rms_norm",
52
  "fused_add_rms_norm",
 
 
53
  "layers",
54
  "parallel_style",
55
  "ops",
 
2
 
3
  from . import layers, parallel_style
4
  from ._ops import ops
5
+ from .fused_rope import (
6
+ fused_kv_split_rope_cat,
7
+ fused_q_rope_inplace,
8
+ )
9
  from .grouped_poly_norm import fused_mul_grouped_poly_norm
10
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
11
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
 
54
  "fused_mul_grouped_poly_norm",
55
  "rms_norm",
56
  "fused_add_rms_norm",
57
+ "fused_q_rope_inplace",
58
+ "fused_kv_split_rope_cat",
59
  "layers",
60
  "parallel_style",
61
  "ops",
torch-ext/activation/fused_rope.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Fused MLA RoPE kernels for Motif3 GDLAttention.
2
+
3
+ Applies RoPE to the input tensor and outputs in contiguous format
4
+ [real..., imag...] so no reorder_headdim_elements_rope is needed.
5
+
6
+ Registered as torch custom_op for torch.compile compatibility (no graph break).
7
+
8
+ Reference: Megatron-LM fused_mla_yarn_rope_apply.py
9
+ """
10
+
11
+ import torch
12
+ import triton
13
+ import triton.language as tl
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # Static configs per kernel — no autotune.
17
+ #
18
+ # Picked from a full-sweep dump (TRITON_PRINT_AUTOTUNING=1) on motif3_seq,
19
+ # B=8, S=4096, n_layers=4, 8×B200, EP=8 (plurality across 8 ranks; see
20
+ # PR description for details). These are the shapes the kernels see at
21
+ # the current prod workload; hard-coding avoids autotune warm-up cost and
22
+ # tie-break jitter between near-equivalent configs.
23
+ #
24
+ # If model config changes materially (different H, D, batch-size/seq-len
25
+ # regime), re-run the dump and update these. Autotune key used to be
26
+ # (rope_dim, head_num); same shape invariants still apply at runtime.
27
+ # ---------------------------------------------------------------------------
28
+
29
# BLOCK_H, num_warps, num_stages per kernel. H values denote head_num
# (q is per-head H=80, kv fused expands head-shared k_pe to H_kv=16 in motif3_seq).
# NOTE(review): the kv fwd kernel iterates kv_latent's H_kv=16 heads (the bwd
# line below and the MLA case both use H_kv=16), so its "H=80" label looked
# like a typo — corrected to H_kv=16. BLOCK_H=32 > 16 is legal (tail masked).
_CFG_KV_ROPE_FWD = dict(BLOCK_H=32, num_warps=8, num_stages=2)  # kv fused, H_kv=16
_CFG_KV_ROPE_BWD = dict(BLOCK_H=16, num_warps=4, num_stages=2)  # kv bwd; BLOCK_H must cover H_kv=16 (single program per token)
_CFG_Q_ROPE_INPLACE_FWD = dict(BLOCK_H=8, num_warps=1, num_stages=2)  # q in-place, H=80
_CFG_Q_ROPE_BWD = dict(BLOCK_H=4, num_warps=2, num_stages=2)  # q bwd, H=80
35
+
36
+
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # Phase 2: Fused KV split + RoPE + expand + cat kernel
40
+ # Fuses: split(kv_latent, [k_dim, v_dim]) + RoPE(k_pe) + expand + cat → key, value
41
+ # Reference: Megatron rotary_fwd_kv_kernel
42
+ # ---------------------------------------------------------------------------
43
+
44
@triton.jit
def _kv_rope_fwd_kernel(
    KV,       # [B*S, H, k_nope_dim + v_dim]
    K_PE,     # [B*S, 1, rope_dim] (already RoPE'd, contiguous format)
    O_KEY,    # [B*S, H, k_nope_dim + rope_dim] output
    O_VALUE,  # [B*S, H, v_dim] output
    rope_dim: tl.constexpr,
    k_nope_dim: tl.constexpr,
    v_dim: tl.constexpr,
    head_num: tl.constexpr,
    stride_kv_token,
    stride_kv_head,
    stride_pe_token,
    stride_k_token,
    stride_k_head,
    stride_v_token,
    stride_v_head,
    BLOCK_H: tl.constexpr,
):
    """key = [kv_latent nope | k_pe broadcast over heads]; value = kv_latent v.

    Grid: (num_tokens, cdiv(head_num, BLOCK_H)). Pure data movement — the
    RoPE itself was applied to K_PE before the launch.
    """
    pid_token = tl.program_id(0)
    pid_hblock = tl.program_id(1)

    KV_ptr = KV + pid_token * stride_kv_token + pid_hblock * BLOCK_H * stride_kv_head
    K_ptr = O_KEY + pid_token * stride_k_token + pid_hblock * BLOCK_H * stride_k_head
    V_ptr = O_VALUE + pid_token * stride_v_token + pid_hblock * BLOCK_H * stride_v_head

    h_off = tl.arange(0, BLOCK_H)[:, None]
    # Guard the tail head-block when head_num is not a multiple of BLOCK_H.
    mask = (pid_hblock * BLOCK_H + h_off) < head_num

    # Read k_nope from KV
    k_nope_off = h_off * stride_kv_head + tl.arange(0, k_nope_dim)[None, :]
    k_nope = tl.load(KV_ptr + k_nope_off, mask=mask)

    # Read v from KV
    v_off = h_off * stride_kv_head + k_nope_dim + tl.arange(0, v_dim)[None, :]
    v = tl.load(KV_ptr + v_off, mask=mask)

    # Read k_pe (shared across all heads, already RoPE'd)
    # K_PE is [B*S, 1, rope_dim], broadcast to all heads
    pe_ptr = K_PE + pid_token * stride_pe_token
    k_pe = tl.load(pe_ptr + tl.arange(0, rope_dim)[None, :])
    k_pe = k_pe.broadcast_to(BLOCK_H, rope_dim)  # register-level broadcast

    # Write key = [k_nope | k_pe]
    k_nope_out = h_off * stride_k_head + tl.arange(0, k_nope_dim)[None, :]
    tl.store(K_ptr + k_nope_out, k_nope, mask=mask)
    k_pe_out = h_off * stride_k_head + k_nope_dim + tl.arange(0, rope_dim)[None, :]
    tl.store(K_ptr + k_pe_out, k_pe, mask=mask)

    # Write value
    v_out = h_off * stride_v_head + tl.arange(0, v_dim)[None, :]
    tl.store(V_ptr + v_out, v, mask=mask)
96
+
97
+
98
@triton.jit
def _kv_rope_bwd_kernel(
    DO_KEY,    # [B*S, H, k_nope_dim + rope_dim] grad_key  (in)
    DO_VALUE,  # [B*S, H, v_dim]                 grad_value(in)
    DKV,       # [B*S, H, k_nope_dim + v_dim]    grad_kv_latent (out)
    DKPE,      # [B*S, 1, rope_dim]              grad_k_pe (out, bf16)
    rope_dim: tl.constexpr,
    k_nope_dim: tl.constexpr,
    v_dim: tl.constexpr,
    head_num: tl.constexpr,
    stride_dk_token, stride_dk_head,
    stride_dv_token, stride_dv_head,
    stride_dkv_token, stride_dkv_head,
    stride_dkpe_token,
    BLOCK_H: tl.constexpr,
):
    """Reverse of _kv_rope_fwd_kernel.
    Forward did: key_nope = slice(kv_latent), key_rope = broadcast(k_pe, H),
    value = slice(kv_latent).
    Backward therefore: grad_kv_latent_nope = grad_key_nope, grad_kv_latent_v = grad_value,
    grad_k_pe = sum(grad_key_rope, dim=H).
    Single program per token; BLOCK_H must cover head_num (H_kv=16 in MLA).
    """
    pid_token = tl.program_id(0)

    DK_ptr = DO_KEY + pid_token * stride_dk_token
    DV_ptr = DO_VALUE + pid_token * stride_dv_token
    DKV_ptr = DKV + pid_token * stride_dkv_token

    h_off = tl.arange(0, BLOCK_H)[:, None]
    mask = h_off < head_num

    # 1. grad_key nope → grad_kv_latent nope section
    nope_in = h_off * stride_dk_head + tl.arange(0, k_nope_dim)[None, :]
    nope_out = h_off * stride_dkv_head + tl.arange(0, k_nope_dim)[None, :]
    nope_data = tl.load(DK_ptr + nope_in, mask=mask)
    tl.store(DKV_ptr + nope_out, nope_data, mask=mask)

    # 2. grad_value → grad_kv_latent v section
    v_in = h_off * stride_dv_head + tl.arange(0, v_dim)[None, :]
    v_out = h_off * stride_dkv_head + k_nope_dim + tl.arange(0, v_dim)[None, :]
    v_data = tl.load(DV_ptr + v_in, mask=mask)
    tl.store(DKV_ptr + v_out, v_data, mask=mask)

    # 3. grad_k_pe = sum over H of grad_key rope section
    # Accumulate in fp32 for precision across H heads.
    # other=0.0 so masked (past-head_num) lanes don't pollute the sum.
    rope_in = h_off * stride_dk_head + k_nope_dim + tl.arange(0, rope_dim)[None, :]
    rope_grads = tl.load(DK_ptr + rope_in, mask=mask, other=0.0).to(tl.float32)
    summed = tl.sum(rope_grads, axis=0)  # [rope_dim], fp32
    pe_out = pid_token * stride_dkpe_token + tl.arange(0, rope_dim)
    tl.store(DKPE + pe_out, summed.to(DKPE.dtype.element_ty))
149
+
150
+
151
@torch.library.custom_op("motif::kv_rope_fwd", mutates_args=())
def _kv_rope_fwd(
    kv_latent: torch.Tensor,  # [B*S, H, k_nope_dim + v_dim]
    k_pe: torch.Tensor,       # [B*S, 1, rope_dim] (already RoPE'd)
    k_nope_dim: int,
    v_dim: int,
    rope_dim: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused split + head-broadcast + cat forward kernel.

    Returns (key [B*S, H, k_nope_dim + rope_dim], value [B*S, H, v_dim]).
    Registered as a custom op so torch.compile traces it without a graph break.
    """
    assert kv_latent.stride(-1) == 1 and k_pe.stride(-1) == 1, (
        "fused_rope kernel requires last-dim unit stride"
    )
    # MLA convention: k_pe is head-shared (single head broadcast to all Q heads
    # via register-level broadcast inside the kernel). GQA/MQA variants that
    # give k_pe its own head dim are not supported by this kernel.
    assert k_pe.shape[1] == 1, f"k_pe must be head-shared (shape[1]==1), got {k_pe.shape}"
    B_S, H, _ = kv_latent.shape
    key = kv_latent.new_empty(B_S, H, k_nope_dim + rope_dim)
    value = kv_latent.new_empty(B_S, H, v_dim)
    # 2D grid: one program per (token, head-block).
    grid = lambda META: (B_S, triton.cdiv(H, META["BLOCK_H"]))
    _kv_rope_fwd_kernel[grid](
        kv_latent, k_pe, key, value,
        rope_dim, k_nope_dim, v_dim, H,
        kv_latent.stride(0), kv_latent.stride(1),
        k_pe.stride(0),
        key.stride(0), key.stride(1),
        value.stride(0), value.stride(1),
        **_CFG_KV_ROPE_FWD,
    )
    return key, value
180
+
181
+
182
@_kv_rope_fwd.register_fake
def _kv_rope_fwd_fake(kv_latent, k_pe, k_nope_dim, v_dim, rope_dim):
    # Shape/dtype-only implementation for meta tracing (torch.compile / FX).
    B_S, H, _ = kv_latent.shape
    key = kv_latent.new_empty(B_S, H, k_nope_dim + rope_dim)
    value = kv_latent.new_empty(B_S, H, v_dim)
    return key, value
188
+
189
+
190
@torch.library.custom_op("motif::kv_rope_bwd", mutates_args=())
def _kv_rope_bwd(
    grad_key: torch.Tensor,    # [B*S, H, k_nope_dim + rope_dim]
    grad_value: torch.Tensor,  # [B*S, H, v_dim]
    k_nope_dim: int,
    v_dim: int,
    rope_dim: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused backward kernel.

    Returns (grad_kv_latent [B*S, H, k_nope_dim + v_dim],
    grad_k_pe [B*S, 1, rope_dim] = head-sum of grad_key's rope section).
    """
    assert grad_key.stride(-1) == 1 and grad_value.stride(-1) == 1, (
        "fused_rope kernel requires last-dim unit stride"
    )
    B_S, H, _ = grad_key.shape
    # grad_kv_latent layout matches forward input: [nope | v]
    grad_kv_latent = grad_key.new_empty(B_S, H, k_nope_dim + v_dim)
    # grad_k_pe: head-shared, matches forward input shape
    grad_k_pe = grad_key.new_empty(B_S, 1, rope_dim)
    # Single program per token; BLOCK_H (=16) must be >= H.
    grid = (B_S,)
    _kv_rope_bwd_kernel[grid](
        grad_key, grad_value, grad_kv_latent, grad_k_pe,
        rope_dim, k_nope_dim, v_dim, H,
        grad_key.stride(0), grad_key.stride(1),
        grad_value.stride(0), grad_value.stride(1),
        grad_kv_latent.stride(0), grad_kv_latent.stride(1),
        grad_k_pe.stride(0),
        **_CFG_KV_ROPE_BWD,
    )
    return grad_kv_latent, grad_k_pe
218
+
219
+
220
@_kv_rope_bwd.register_fake
def _kv_rope_bwd_fake(grad_key, grad_value, k_nope_dim, v_dim, rope_dim):
    # Shape/dtype-only implementation for meta tracing (torch.compile / FX).
    B_S, H, _ = grad_key.shape
    return (
        grad_key.new_empty(B_S, H, k_nope_dim + v_dim),
        grad_key.new_empty(B_S, 1, rope_dim),
    )
227
+
228
+
229
class FusedKVRoPE(torch.autograd.Function):
    """Autograd wrapper: flattens (B, S) into tokens and dispatches to the
    fused kv custom ops for forward and backward."""

    @staticmethod
    def forward(ctx, kv_latent, k_pe, k_nope_dim, v_dim, rope_dim):
        # kv_latent: [B, S, H, k_nope_dim + v_dim]
        # k_pe: [B, S, 1, rope_dim] (already RoPE'd)
        B, S, H, D = kv_latent.shape
        key_3d, value_3d = _kv_rope_fwd(
            kv_latent.reshape(B * S, H, D),
            k_pe.reshape(B * S, 1, rope_dim),
            k_nope_dim, v_dim, rope_dim,
        )
        # Backward needs only dims/shape — no tensors saved.
        ctx.k_nope_dim = k_nope_dim
        ctx.v_dim = v_dim
        ctx.rope_dim = rope_dim
        ctx.shape = (B, S, H)
        return key_3d.view(B, S, H, k_nope_dim + rope_dim), value_3d.view(B, S, H, v_dim)

    @staticmethod
    def backward(ctx, grad_key, grad_value):
        B, S, H = ctx.shape
        k_nope_dim = ctx.k_nope_dim
        v_dim = ctx.v_dim
        rope_dim = ctx.rope_dim

        # Single Triton kernel does: nope copy + v copy + head-sum of rope section.
        # Replaces (slice + cat + sum) inductor path.
        grad_kv_latent_3d, grad_k_pe_3d = _kv_rope_bwd(
            grad_key.contiguous().reshape(B * S, H, k_nope_dim + rope_dim),
            grad_value.contiguous().reshape(B * S, H, v_dim),
            k_nope_dim, v_dim, rope_dim,
        )
        grad_kv_latent = grad_kv_latent_3d.view(B, S, H, k_nope_dim + v_dim)
        grad_k_pe = grad_k_pe_3d.view(B, S, 1, rope_dim)
        # Nones for the three int args (k_nope_dim, v_dim, rope_dim).
        return grad_kv_latent, grad_k_pe, None, None, None
263
+
264
+
265
+ # ---------------------------------------------------------------------------
266
+ # Q RoPE backward kernel — used by FusedQRoPEInplace.backward (below).
267
+ # Out-of-place: reads contiguous grad_out, writes interleaved grad_in; nope
268
+ # gradient is copied through unchanged.
269
+ # ---------------------------------------------------------------------------
270
+
271
@triton.jit
def _q_rope_bwd_kernel(
    DO,   # [B*S, H, nope_dim + rope_dim] grad (contiguous rope fmt)
    DQ,   # [B*S, H, nope_dim + rope_dim] grad output (interleaved rope fmt)
    COS,  # [S, rope_dim // 2]
    SIN,  # [S, rope_dim // 2]
    nope_dim: tl.constexpr,
    rope_dim: tl.constexpr,
    head_num: tl.constexpr,
    seq_len,
    stride_do_token,
    stride_do_head,
    stride_dq_token,
    stride_dq_head,
    BLOCK_H: tl.constexpr,
):
    """Adjoint of the q RoPE: inverse-rotate the rope section's gradient.

    Per pair with angle θ: dx1 = d_real·cosθ + d_imag·sinθ,
    dx2 = −d_real·sinθ + d_imag·cosθ, read from contiguous layout and
    stored interleaved (stride-2). The nope grad is copied through as-is.
    """
    HALF: tl.constexpr = rope_dim // 2

    pid_token = tl.program_id(0)
    pid_hblock = tl.program_id(1)

    # Tokens are flattened B*S; mod recovers the in-sequence position.
    pos = pid_token % seq_len

    DO_ptr = DO + pid_token * stride_do_token + pid_hblock * BLOCK_H * stride_do_head
    DQ_ptr = DQ + pid_token * stride_dq_token + pid_hblock * BLOCK_H * stride_dq_head

    h_off = tl.arange(0, BLOCK_H)[:, None]
    mask = (pid_hblock * BLOCK_H + h_off) < head_num

    # Copy nope grad as-is
    nope_off_in = h_off * stride_do_head + tl.arange(0, nope_dim)[None, :]
    nope_grad = tl.load(DO_ptr + nope_off_in, mask=mask)
    nope_off_out = h_off * stride_dq_head + tl.arange(0, nope_dim)[None, :]
    tl.store(DQ_ptr + nope_off_out, nope_grad, mask=mask)

    # Inverse RoPE: contiguous → interleaved
    cos = tl.load(COS + pos * HALF + tl.arange(0, HALF))
    sin = tl.load(SIN + pos * HALF + tl.arange(0, HALF))
    cos = cos.expand_dims(0).broadcast_to(BLOCK_H, HALF)
    sin = sin.expand_dims(0).broadcast_to(BLOCK_H, HALF)

    # Contiguous layout: [real... | imag...] halves of the rope section.
    real_off = h_off * stride_do_head + nope_dim + tl.arange(0, HALF)[None, :]
    imag_off = real_off + HALF
    d_real = tl.load(DO_ptr + real_off, mask=mask).to(tl.float32)
    d_imag = tl.load(DO_ptr + imag_off, mask=mask).to(tl.float32)

    # fp32 math; implicit downcast to DQ's dtype on store.
    dx1 = d_real * cos + d_imag * sin
    dx2 = -d_real * sin + d_imag * cos

    # Interleaved layout: pair k lives at offsets (2k, 2k+1).
    x1_off = h_off * stride_dq_head + nope_dim + tl.arange(0, HALF)[None, :] * 2
    x2_off = x1_off + 1
    tl.store(DQ_ptr + x1_off, dx1, mask=mask)
    tl.store(DQ_ptr + x2_off, dx2, mask=mask)
325
+
326
@torch.library.custom_op("motif::q_rope_bwd", mutates_args=())
def _q_rope_bwd(
    grad_out: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    nope_dim: int,
    rope_dim: int,
    seq_len: int,
) -> torch.Tensor:
    """Out-of-place Q RoPE backward launcher.

    Takes `grad_out` with its rope section in contiguous format and returns a
    fresh tensor of the same shape whose rope section is in interleaved format
    (matching the Q input layout); the nope section's grad passes through.
    """
    assert grad_out.stride(-1) == 1, "fused_rope kernel requires last-dim unit stride"
    num_tokens, num_heads, _ = grad_out.shape
    grad_in = torch.empty_like(grad_out)

    # One program per (token, head-block); BLOCK_H comes from the config dict.
    def grid(meta):
        return (num_tokens, triton.cdiv(num_heads, meta["BLOCK_H"]))

    _q_rope_bwd_kernel[grid](
        grad_out,
        grad_in,
        cos,
        sin,
        nope_dim,
        rope_dim,
        num_heads,
        seq_len,
        grad_out.stride(0),
        grad_out.stride(1),
        grad_in.stride(0),
        grad_in.stride(1),
        **_CFG_Q_ROPE_BWD,
    )
    return grad_in
347
+
348
+
349
@_q_rope_bwd.register_fake
def _q_rope_bwd_fake(grad_out, cos, sin, nope_dim, rope_dim, seq_len):
    # Fake (meta) implementation for tracing/compile: the real op returns a
    # new tensor with grad_out's shape/dtype/device.
    return torch.empty_like(grad_out)
352
+
353
+
354
# ---------------------------------------------------------------------------
# In-place Q RoPE kernel (eliminates cat + nope copy)
# Modifies q[..., nope_dim:] from interleaved → contiguous format IN-PLACE.
# nope section [..., :nope_dim] is untouched.
# Forward: in-place on Q (mutates)
# Backward: out-of-place via existing _q_rope_bwd_kernel (nope grad copy + inverse rope)
# ---------------------------------------------------------------------------

@triton.jit
def _q_rope_inplace_fwd_kernel(
    Q,    # [B*S, H, nope_dim + rope_dim] modified in-place on [..., nope_dim:]
    COS,  # [S, rope_dim // 2]
    SIN,  # [S, rope_dim // 2]
    nope_dim: tl.constexpr,
    rope_dim: tl.constexpr,
    head_num: tl.constexpr,
    seq_len,
    stride_q_token,
    stride_q_head,
    BLOCK_H: tl.constexpr,
):
    # Number of (real, imag) pairs per head.
    HALF: tl.constexpr = rope_dim // 2

    # Grid: axis 0 = flattened token index over B*S, axis 1 = head block.
    pid_token = tl.program_id(0)
    pid_hblock = tl.program_id(1)

    # Position within the sequence (token axis is B*S flattened, hence mod S).
    pos = pid_token % seq_len
    cos = tl.load(COS + pos * HALF + tl.arange(0, HALF))
    sin = tl.load(SIN + pos * HALF + tl.arange(0, HALF))
    cos = cos.expand_dims(0).broadcast_to(BLOCK_H, HALF)
    sin = sin.expand_dims(0).broadcast_to(BLOCK_H, HALF)

    Q_ptr = Q + pid_token * stride_q_token + pid_hblock * BLOCK_H * stride_q_head

    h_off = tl.arange(0, BLOCK_H)[:, None]
    # Mask off heads beyond head_num in the last (partial) head block.
    mask = (pid_hblock * BLOCK_H + h_off) < head_num

    # Read rope section interleaved: [r0,i0,r1,i1,...]
    x1_off = h_off * stride_q_head + nope_dim + tl.arange(0, HALF)[None, :] * 2
    x2_off = x1_off + 1
    x1 = tl.load(Q_ptr + x1_off, mask=mask).to(tl.float32)
    x2 = tl.load(Q_ptr + x2_off, mask=mask).to(tl.float32)

    # Rotation by +pos; computed in fp32 for accuracy.
    out_real = x1 * cos - x2 * sin
    out_imag = x1 * sin + x2 * cos

    # Write back to SAME rope section in contiguous format: [r0..r31, i0..i31]
    # In-place is safe here: each program reads and writes only its own
    # (token, head-block) slice of the rope section before storing.
    real_off = h_off * stride_q_head + nope_dim + tl.arange(0, HALF)[None, :]
    imag_off = real_off + HALF
    tl.store(Q_ptr + real_off, out_real, mask=mask)
    tl.store(Q_ptr + imag_off, out_imag, mask=mask)
405
+
406
+
407
@torch.library.custom_op("motif::q_rope_inplace_fwd", mutates_args=("q",))
def _q_rope_inplace_fwd(
    q: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    nope_dim: int,
    rope_dim: int,
    seq_len: int,
) -> None:
    """In-place Q RoPE forward launcher.

    Mutates `q` on [..., nope_dim:] (interleaved → contiguous rope format);
    returns nothing by design, as declared via `mutates_args=("q",)` on the
    custom_op.
    """
    assert q.stride(-1) == 1, "fused_rope kernel requires last-dim unit stride"
    num_tokens, num_heads, _ = q.shape

    # One program per (token, head-block); BLOCK_H comes from the config dict.
    def grid(meta):
        return (num_tokens, triton.cdiv(num_heads, meta["BLOCK_H"]))

    _q_rope_inplace_fwd_kernel[grid](
        q,
        cos,
        sin,
        nope_dim,
        rope_dim,
        num_heads,
        seq_len,
        q.stride(0),
        q.stride(1),
        **_CFG_Q_ROPE_INPLACE_FWD,
    )
427
+
428
+
429
class FusedQRoPEInplace(torch.autograd.Function):
    """Autograd wrapper around the in-place Q RoPE kernel.

    Forward mutates q[..., nope_dim:] in place (interleaved → contiguous rope
    format); backward is out-of-place via `_q_rope_bwd`.
    """

    @staticmethod
    def forward(ctx, q, cos, sin, nope_dim, rope_dim, seq_len):
        # q: [B, S, H, nope_dim + rope_dim] mutated in-place on [..., nope_dim:]
        B, S, H, D = q.shape
        assert D == nope_dim + rope_dim
        # Require full contiguity so that `reshape(B*S, H, D)` is guaranteed to
        # be a view — otherwise reshape silently copies and the in-place mutation
        # would not reach the original `q` that `ctx.mark_dirty` targets.
        assert q.is_contiguous(), "FusedQRoPEInplace requires contiguous q"
        _q_rope_inplace_fwd(q.reshape(B * S, H, D), cos, sin, nope_dim, rope_dim, seq_len)
        ctx.mark_dirty(q)
        ctx.save_for_backward(cos, sin)
        ctx.nope_dim = nope_dim
        ctx.rope_dim = rope_dim
        ctx.seq_len = seq_len
        ctx.shape = (B, S, H, D)
        # Return the mutated `q` itself (not a new tensor) so autograd edges
        # flow through; the underlying op is declared `-> None` (in-place).
        return q

    @staticmethod
    def backward(ctx, grad_out):
        # grad_out: [B, S, H, nope_dim + rope_dim] rope section in contiguous grad fmt
        # Produce grad_in: same shape, rope section in interleaved grad fmt (matches Q input layout)
        cos, sin = ctx.saved_tensors
        B, S, H, D = ctx.shape
        # Reuse existing Q backward kernel (copies nope grad + inverse-ropes pe section)
        grad_in_3d = _q_rope_bwd(
            grad_out.contiguous().reshape(B * S, H, D),
            cos, sin, ctx.nope_dim, ctx.rope_dim, ctx.seq_len,
        )
        # Gradient only for `q`; cos/sin and the three int args get None.
        return grad_in_3d.view(B, S, H, D), None, None, None, None, None
462
+
463
+
464
def fused_q_rope_inplace(q, freqs_cis, nope_dim, rope_dim):
    """In-place fused Q RoPE. Modifies q[..., nope_dim:] from interleaved → contiguous format.

    Equivalent to — and replacing — the three-step sequence:
        q_nope, q_pe = split(q, [nope_dim, rope_dim])
        q_pe = fused_apply_rope(q_pe, freqs_cis)
        q_total = cat([q_nope, q_pe])

    Saves the cat copy (~415 µs/layer for Motif3) compared to out-of-place variants.

    Args:
        q: [B, S, H, nope_dim + rope_dim] from wq_b output. Will be mutated.
        freqs_cis: [max_seq_len, rope_dim//2] complex64

    Returns:
        Same tensor `q`, now with rope section in contiguous format.
    """
    seq_len = q.shape[1]
    # Split the complex frequencies into real cos/sin tables for the kernel.
    freqs = freqs_cis[:seq_len]
    cos, sin = freqs.real.contiguous(), freqs.imag.contiguous()
    return FusedQRoPEInplace.apply(q, cos, sin, nope_dim, rope_dim, seq_len)
485
+
486
+
487
def fused_kv_split_rope_cat(kv_latent, k_pe, k_nope_dim, v_dim, rope_dim):
    """Fused KV split + k_pe expand + cat. No graph break.

    Replaces:
        k_nope, v = split(kv_latent, [k_nope_dim, v_dim])
        k_full = cat([k_nope, k_pe.expand(-1,-1,H,-1)])

    Args:
        kv_latent: [B, S, H, k_nope_dim + v_dim]
        k_pe: [B, S, 1, rope_dim] (already RoPE'd, contiguous format)
        k_nope_dim: width of the no-positional-encoding key section.
        v_dim: width of the value section.
        rope_dim: width of the RoPE'd key section shared across heads.

    Returns:
        key: [B, S, H, k_nope_dim + rope_dim]
        value: [B, S, H, v_dim]
    """
    # Delegates to the FusedKVRoPE autograd Function (defined elsewhere in
    # this module).
    return FusedKVRoPE.apply(kv_latent, k_pe, k_nope_dim, v_dim, rope_dim)