Merge pull request #22 from MotifTechnologies/jangwoong/mla-rope-fa4-port
Browse files- benchmarks/bench_mla_rope.yaml +81 -0
- benchmarks/cases/mla_rope.py +97 -0
- benchmarks/run_cases.py +23 -3
- tests/test_mla_rope_grad.py +142 -0
- torch-ext/activation/__init__.py +6 -0
- torch-ext/activation/fused_rope.py +502 -0
benchmarks/bench_mla_rope.yaml
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
apiVersion: trainer.kubeflow.org/v1alpha1
|
| 2 |
+
kind: TrainJob
|
| 3 |
+
metadata:
|
| 4 |
+
name: jangwoong-mla-rope-bench
|
| 5 |
+
namespace: kbm-g-np-motif
|
| 6 |
+
spec:
|
| 7 |
+
managedBy: trainer.kubeflow.org/trainjob-controller
|
| 8 |
+
podTemplateOverrides:
|
| 9 |
+
- spec:
|
| 10 |
+
containers:
|
| 11 |
+
- name: node
|
| 12 |
+
volumeMounts:
|
| 13 |
+
- mountPath: /dev/shm
|
| 14 |
+
name: shm
|
| 15 |
+
- mountPath: /mair
|
| 16 |
+
name: mair
|
| 17 |
+
volumes:
|
| 18 |
+
- emptyDir:
|
| 19 |
+
medium: Memory
|
| 20 |
+
sizeLimit: 64Gi
|
| 21 |
+
name: shm
|
| 22 |
+
- name: mair
|
| 23 |
+
persistentVolumeClaim:
|
| 24 |
+
claimName: mair
|
| 25 |
+
targetJobs:
|
| 26 |
+
- name: node
|
| 27 |
+
runtimeRef:
|
| 28 |
+
apiGroup: trainer.kubeflow.org
|
| 29 |
+
kind: ClusterTrainingRuntime
|
| 30 |
+
name: torch-distributed
|
| 31 |
+
suspend: false
|
| 32 |
+
trainer:
|
| 33 |
+
args:
|
| 34 |
+
- /bin/bash
|
| 35 |
+
- '-c'
|
| 36 |
+
- |
|
| 37 |
+
set -e
|
| 38 |
+
ACTIVATIONPATH=/mair/team-sys/jangwoong/activation
|
| 39 |
+
DATESTAMP=$(date +'%y_%m_%d_%H_%M')
|
| 40 |
+
SAVE_PATH=$ACTIVATIONPATH/benchmarks/results/mla_rope/${DATESTAMP}
|
| 41 |
+
mkdir -p $SAVE_PATH
|
| 42 |
+
|
| 43 |
+
pip install triton pandas
|
| 44 |
+
|
| 45 |
+
# Build activation from local source (copy to /tmp to avoid NFS race)
|
| 46 |
+
mkdir -p /tmp/activation_src && rm -rf /tmp/activation_src/* && \
|
| 47 |
+
rsync -a --exclude=build $ACTIVATIONPATH/ /tmp/activation_src/ && \
|
| 48 |
+
pip install --no-build-isolation /tmp/activation_src 2>&1 | tail -50
|
| 49 |
+
|
| 50 |
+
python -c "import activation; print('fused_q_rope_inplace:', activation.fused_q_rope_inplace); print('fused_kv_split_rope_cat:', activation.fused_kv_split_rope_cat)"
|
| 51 |
+
|
| 52 |
+
nvidia-smi | tee $SAVE_PATH/nvidia_smi.txt
|
| 53 |
+
|
| 54 |
+
echo "=== MLA RoPE benchmark ==="
|
| 55 |
+
cd $ACTIVATIONPATH/benchmarks
|
| 56 |
+
CUDA_VISIBLE_DEVICES=0 python run_cases.py --case mla_rope --dtype bf16 \
|
| 57 |
+
--save-path $SAVE_PATH/bench 2>&1 | tee $SAVE_PATH/bench.log
|
| 58 |
+
|
| 59 |
+
echo "=== Done. Results at: $SAVE_PATH ==="
|
| 60 |
+
exit 0
|
| 61 |
+
env:
|
| 62 |
+
- name: PYTHONUNBUFFERED
|
| 63 |
+
value: '1'
|
| 64 |
+
- name: PYTORCH_ALLOC_CONF
|
| 65 |
+
value: expandable_segments:True
|
| 66 |
+
- name: CUDA_LAUNCH_BLOCKING
|
| 67 |
+
value: '0'
|
| 68 |
+
- name: OMP_NUM_THREADS
|
| 69 |
+
value: '1'
|
| 70 |
+
image: ghcr.io/motiftechnologies/llm-training:v0.1.8
|
| 71 |
+
numNodes: 1
|
| 72 |
+
numProcPerNode: 1
|
| 73 |
+
resourcesPerNode:
|
| 74 |
+
limits:
|
| 75 |
+
cpu: '96'
|
| 76 |
+
memory: 1024Gi
|
| 77 |
+
nvidia.com/gpu: '8'
|
| 78 |
+
requests:
|
| 79 |
+
cpu: '96'
|
| 80 |
+
memory: 1024Gi
|
| 81 |
+
nvidia.com/gpu: '8'
|
benchmarks/cases/mla_rope.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MLA RoPE case: fused (activation Triton) vs vanilla (PyTorch native).
|
| 2 |
+
|
| 3 |
+
MLA head dims are fixed by motif3 spec (H_q=80, H_kv=16, D_nope=128,
|
| 4 |
+
D_rope=64, D_v=128); the benchmark's only sweep axes are (bs, sl).
|
| 5 |
+
The framework's ``dim`` axis is a dummy here — pass 0 in configs.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
import activation
|
| 12 |
+
|
| 13 |
+
from common.diff_engine import DiffCase
|
| 14 |
+
|
| 15 |
+
# ---- MLA shapes (motif3_seq) -----------------------------------------------
|
| 16 |
+
H_Q, H_KV = 80, 16
|
| 17 |
+
D_NOPE, D_ROPE, D_V = 128, 64, 128
|
| 18 |
+
D_QK = D_NOPE + D_ROPE # 192
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ---- reference (PyTorch native) --------------------------------------------
|
| 22 |
+
def _precompute_freqs_cis(dim, end, theta=10000.0):
|
| 23 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
|
| 24 |
+
t = torch.arange(end, dtype=torch.float32)
|
| 25 |
+
freqs = torch.outer(t, freqs)
|
| 26 |
+
return torch.polar(torch.ones_like(freqs), freqs)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _apply_rotary_emb_single(x, freqs_cis):
|
| 30 |
+
x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 31 |
+
freqs_cis = freqs_cis[: x_.shape[1]].view(1, x_.shape[1], 1, x_.shape[3])
|
| 32 |
+
out = torch.view_as_real(x_ * freqs_cis).flatten(3)
|
| 33 |
+
return out.type_as(x)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _reorder(qk, rope_dim):
|
| 37 |
+
B, S = qk.shape[0], qk.shape[1]
|
| 38 |
+
qk = qk.view(B, S, -1, rope_dim // 2, 2).transpose(3, 4)
|
| 39 |
+
return qk.reshape(B, S, -1, rope_dim)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _vanilla(q, kv_latent, k_pe, freqs_cis):
    """Pure-PyTorch reference for the MLA RoPE path.

    Splits q into nope/rope sections, rotates the rope section, rotates the
    head-shared k_pe, then splits kv_latent into k_nope/v and concatenates the
    broadcast k_pe onto every kv head. Returns (q_total, k_full, v).
    """
    q_nope, q_pe = torch.split(q, [D_NOPE, D_ROPE], dim=-1)
    roped_q = _reorder(_apply_rotary_emb_single(q_pe, freqs_cis), D_ROPE)
    q_total = torch.cat([q_nope, roped_q], dim=-1)

    roped_kpe = _reorder(_apply_rotary_emb_single(k_pe.unsqueeze(2), freqs_cis), D_ROPE)
    k_nope, v = torch.split(kv_latent, [D_NOPE, D_V], dim=-1)
    k_full = torch.cat([k_nope, roped_kpe.expand(-1, -1, H_KV, -1)], dim=-1)
    return q_total, k_full, v
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _fused(q, kv_latent, k_pe, freqs_cis):
    """Triton-fused MLA RoPE path; mirrors _vanilla's (q_total, k_full, v) output.

    q's rope section is rotated by the activation kernel (name suggests it
    mutates q in place — confirm against the kernel's contract).
    """
    q_total = activation.fused_q_rope_inplace(q, freqs_cis, D_NOPE, D_ROPE)
    # k_pe RoPE stays PyTorch native (head-shared, too small for custom kernel)
    k_pe_roped = _reorder(_apply_rotary_emb_single(k_pe.unsqueeze(2), freqs_cis), D_ROPE)
    # Fused split(kv_latent) + broadcast(k_pe) + cat in one kernel launch.
    k_full, v = activation.fused_kv_split_rope_cat(kv_latent, k_pe_roped, D_NOPE, D_V, D_ROPE)
    return q_total, k_full, v
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class _VanillaModule(nn.Module):
    """nn.Module wrapper over the PyTorch-native reference path (for the bench harness)."""

    def forward(self, q, kv_latent, k_pe, freqs_cis):
        return _vanilla(q, kv_latent, k_pe, freqs_cis)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class _FusedModule(nn.Module):
    """nn.Module wrapper over the Triton-fused path (for the bench harness)."""

    def forward(self, q, kv_latent, k_pe, freqs_cis):
        return _fused(q, kv_latent, k_pe, freqs_cis)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class MLARoPE(DiffCase):
    """DiffCase for MLA RoPE: fused activation kernels vs PyTorch reference.

    Head dims are the module-level motif3 constants; the framework's ``dim``
    axis is unused here (``hidden`` is ignored).
    """

    # Framework calls build_inputs(bs, sl, dim, dtype, eps) — dim unused.
    def build_inputs(self, bs, sl, hidden, dtype, eps):
        """Return fresh random q / kv_latent / k_pe leaves plus the RoPE table.

        NOTE(review): tensors are created without an explicit device —
        presumably the harness sets a default CUDA device; confirm.
        """
        return {
            "q": (torch.randn(bs, sl, H_Q, D_QK, dtype=dtype) * 0.5).requires_grad_(True),
            "kv_latent": (torch.randn(bs, sl, H_KV, D_NOPE + D_V, dtype=dtype) * 0.5).requires_grad_(True),
            "k_pe": (torch.randn(bs, sl, D_ROPE, dtype=dtype) * 0.5).requires_grad_(True),
            # fp32/complex64 table, length sl — covers every position once.
            "freqs_cis": _precompute_freqs_cis(D_ROPE, sl),
        }

    def make_naive(self, I):
        # Reference implementation under benchmark.
        return _VanillaModule()

    def make_cuda(self, I):
        # Fused implementation under benchmark.
        return _FusedModule()

    def forward(self, obj, I):
        # fused_q_rope_inplace needs non-leaf q; wrap both paths for fairness
        q_in = I["q"] * 1.0
        kv_in = I["kv_latent"] * 1.0
        kpe_in = I["k_pe"] * 1.0
        return obj(q_in, kv_in, kpe_in, I["freqs_cis"])

    def grad_inputs(self, I):
        # Leaves whose .grad the framework compares after backward.
        return [I["q"], I["kv_latent"], I["k_pe"]]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
CASE = MLARoPE()
|
benchmarks/run_cases.py
CHANGED
|
@@ -62,7 +62,8 @@ def main():
|
|
| 62 |
ap = argparse.ArgumentParser()
|
| 63 |
ap.add_argument(
|
| 64 |
"--case",
|
| 65 |
-
choices=["rms", "add_rms", "poly", "mul_poly", "grouped_mul_poly"
|
|
|
|
| 66 |
required=True)
|
| 67 |
ap.add_argument("--plot", action="store_true")
|
| 68 |
ap.add_argument(
|
|
@@ -95,12 +96,26 @@ def main():
|
|
| 95 |
case: DiffCase = mod.CASE
|
| 96 |
|
| 97 |
# Correctness checks across multiple configs
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
print(
|
| 100 |
f"Checking correctness: bs={bs}, sl={sl}, D={hid} "
|
| 101 |
f"(N={bs*sl})...",
|
| 102 |
end=" ")
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
print("✅")
|
| 105 |
|
| 106 |
for dtype_name, dtype in dtypes:
|
|
@@ -109,6 +124,7 @@ def main():
|
|
| 109 |
print(f"{'=' * 60}\n")
|
| 110 |
|
| 111 |
save_dir = os.path.join(args.save_path, args.case, dtype_name)
|
|
|
|
| 112 |
is_grouped = args.case == "grouped_mul_poly"
|
| 113 |
|
| 114 |
if args.plot:
|
|
@@ -118,6 +134,8 @@ def main():
|
|
| 118 |
dim = [1280]
|
| 119 |
elif "poly" in args.case:
|
| 120 |
dim = [8192, 16384]
|
|
|
|
|
|
|
| 121 |
else:
|
| 122 |
dim = [2048, 4096]
|
| 123 |
configs = list(
|
|
@@ -170,6 +188,8 @@ def main():
|
|
| 170 |
dim = [1280]
|
| 171 |
elif "poly" in args.case:
|
| 172 |
dim = [8192, 16384]
|
|
|
|
|
|
|
| 173 |
else:
|
| 174 |
dim = [2048, 4096]
|
| 175 |
configs = list(
|
|
|
|
| 62 |
ap = argparse.ArgumentParser()
|
| 63 |
ap.add_argument(
|
| 64 |
"--case",
|
| 65 |
+
choices=["rms", "add_rms", "poly", "mul_poly", "grouped_mul_poly",
|
| 66 |
+
"mla_rope"],
|
| 67 |
required=True)
|
| 68 |
ap.add_argument("--plot", action="store_true")
|
| 69 |
ap.add_argument(
|
|
|
|
| 96 |
case: DiffCase = mod.CASE
|
| 97 |
|
| 98 |
# Correctness checks across multiple configs
|
| 99 |
+
# NOTE: calculate_diff positionally calls build_inputs(hidden_size, bs, sl);
|
| 100 |
+
# bench framework positionally calls build_inputs(bs, sl, dim). These
|
| 101 |
+
# disagree — rms-style cases don't care (all 3 axes are flat dims), but
|
| 102 |
+
# mla_rope does. We match the bench convention in cases/mla_rope.py, so
|
| 103 |
+
# we swap arg names at the correctness call site below for that case.
|
| 104 |
+
if args.case == "mla_rope":
|
| 105 |
+
cfgs = [(1, 1024, 0), (4, 4096, 0), (8, 4096, 0)] # (bs, sl, dummy)
|
| 106 |
+
else:
|
| 107 |
+
cfgs = [(2, 128, 4096), (8, 4096, 1280), (1, 32768, 1280)]
|
| 108 |
+
for bs, sl, hid in cfgs:
|
| 109 |
print(
|
| 110 |
f"Checking correctness: bs={bs}, sl={sl}, D={hid} "
|
| 111 |
f"(N={bs*sl})...",
|
| 112 |
end=" ")
|
| 113 |
+
if args.case == "mla_rope":
|
| 114 |
+
# Swap so positional (hidden_size, batch_size, seq_len) maps to
|
| 115 |
+
# our build_inputs(bs, sl, dim) as (bs, sl, dummy).
|
| 116 |
+
calculate_diff(case, batch_size=sl, seq_len=hid, hidden_size=bs)
|
| 117 |
+
else:
|
| 118 |
+
calculate_diff(case, batch_size=bs, seq_len=sl, hidden_size=hid)
|
| 119 |
print("✅")
|
| 120 |
|
| 121 |
for dtype_name, dtype in dtypes:
|
|
|
|
| 124 |
print(f"{'=' * 60}\n")
|
| 125 |
|
| 126 |
save_dir = os.path.join(args.save_path, args.case, dtype_name)
|
| 127 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 128 |
is_grouped = args.case == "grouped_mul_poly"
|
| 129 |
|
| 130 |
if args.plot:
|
|
|
|
| 134 |
dim = [1280]
|
| 135 |
elif "poly" in args.case:
|
| 136 |
dim = [8192, 16384]
|
| 137 |
+
elif args.case == "mla_rope":
|
| 138 |
+
dim = [0] # MLA head dims are fixed; dim axis is a dummy
|
| 139 |
else:
|
| 140 |
dim = [2048, 4096]
|
| 141 |
configs = list(
|
|
|
|
| 188 |
dim = [1280]
|
| 189 |
elif "poly" in args.case:
|
| 190 |
dim = [8192, 16384]
|
| 191 |
+
elif args.case == "mla_rope":
|
| 192 |
+
dim = [0] # MLA head dims are fixed; dim axis is a dummy
|
| 193 |
else:
|
| 194 |
dim = [2048, 4096]
|
| 195 |
configs = list(
|
tests/test_mla_rope_grad.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Numerical parity test: activation MLA RoPE kernels vs PyTorch reference.
|
| 2 |
+
|
| 3 |
+
The activation package exposes two Triton kernels for Motif3 MLA attention:
|
| 4 |
+
* fused_q_rope_inplace — in-place RoPE on q's rope section
|
| 5 |
+
* fused_kv_split_rope_cat — split kv_latent + register-broadcast k_pe to H heads + cat
|
| 6 |
+
|
| 7 |
+
This test runs both the fused path and a pure-PyTorch reference over identical
|
| 8 |
+
inputs (forward + backward) and compares all outputs and input gradients.
|
| 9 |
+
|
| 10 |
+
Self-contained: the reference RoPE implementation lives in this file (no
|
| 11 |
+
upstream model code dependency).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import pytest
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
import activation
|
| 18 |
+
|
| 19 |
+
from .utils import assert_close
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Realistic motif3_seq per-GPU shapes (B=local_batch_size, H_q/H_kv per MLA spec).
|
| 23 |
+
SHAPES = [
|
| 24 |
+
# (B, S, H_q, H_kv, D_nope, D_rope, D_v)
|
| 25 |
+
(8, 4096, 80, 16, 128, 64, 128),
|
| 26 |
+
]
|
| 27 |
+
DTYPES = [torch.bfloat16]
|
| 28 |
+
SEEDS = [0]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ------------------------------------------------------------------ reference
|
| 32 |
+
|
| 33 |
+
def _precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
|
| 34 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 35 |
+
t = torch.arange(end, dtype=torch.float32)
|
| 36 |
+
freqs = torch.outer(t, freqs)
|
| 37 |
+
return torch.polar(torch.ones_like(freqs), freqs) # complex64
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _apply_rotary_emb_single(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
| 41 |
+
"""[B, S, H, D] interleaved → rotated, in interleaved layout."""
|
| 42 |
+
x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 43 |
+
freqs_cis = freqs_cis[: x_.shape[1]].view(1, x_.shape[1], 1, x_.shape[3])
|
| 44 |
+
out = torch.view_as_real(x_ * freqs_cis).flatten(3)
|
| 45 |
+
return out.type_as(x)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _reorder_headdim_elements_rope(qk: torch.Tensor, B: int, S: int, rope_dim: int) -> torch.Tensor:
|
| 49 |
+
"""Interleaved [r0,i0,r1,i1,...] → contiguous [r0,r1,...,i0,i1,...]."""
|
| 50 |
+
qk = qk.view(B, S, -1, rope_dim // 2, 2)
|
| 51 |
+
qk = qk.transpose(3, 4)
|
| 52 |
+
return qk.reshape(B, S, -1, rope_dim)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def vanilla_path(q, kv_latent, k_pe, freqs_cis, B, S, H_kv, D_nope, D_rope, D_v):
    """Pure-PyTorch reference: returns (q_total, k_full, v) for one MLA step."""
    # Q: rotate only the rope tail, keep the nope head untouched.
    q_nope, q_pe = torch.split(q, [D_nope, D_rope], dim=-1)
    q_pe = _apply_rotary_emb_single(q_pe, freqs_cis)
    q_pe = _reorder_headdim_elements_rope(q_pe, B, S, D_rope)
    q_total = torch.cat([q_nope, q_pe], dim=-1)
    # k_pe (head-shared, H=1)
    k_pe_4d = k_pe.unsqueeze(2)
    k_pe_roped = _apply_rotary_emb_single(k_pe_4d, freqs_cis)
    k_pe_roped = _reorder_headdim_elements_rope(k_pe_roped, B, S, D_rope)
    # KV split + expand + cat: broadcast the single rotated k_pe head to all H_kv heads.
    k_nope, v = torch.split(kv_latent, [D_nope, D_v], dim=-1)
    k_full = torch.cat([k_nope, k_pe_roped.expand(-1, -1, H_kv, -1)], dim=-1)
    return q_total, k_full, v
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def fused_path(q, kv_latent, k_pe, freqs_cis, B, S, H_kv, D_nope, D_rope, D_v):
    """Triton-fused path under test; same (q_total, k_full, v) contract as vanilla_path."""
    q_total = activation.fused_q_rope_inplace(q, freqs_cis, D_nope, D_rope)
    # k_pe RoPE stays PyTorch native (head-shared; standalone Triton kernel was
    # launch-bound on B200, no measurable win — see PR #22).
    k_pe_4d = k_pe.unsqueeze(2)
    k_pe_roped = _apply_rotary_emb_single(k_pe_4d, freqs_cis)
    k_pe_roped = _reorder_headdim_elements_rope(k_pe_roped, B, S, D_rope)
    # Fused split + head-broadcast + cat replaces three separate torch ops.
    k_full, v = activation.fused_kv_split_rope_cat(
        kv_latent, k_pe_roped, D_nope, D_v, D_rope
    )
    return q_total, k_full, v
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ------------------------------------------------------------------ harness
|
| 85 |
+
|
| 86 |
+
def _run_with_grad(path_fn, q, kv_latent, k_pe, freqs_cis, **shape_kwargs):
|
| 87 |
+
# Inputs come in as leaves; thread through a no-op so the in-place fused_q
|
| 88 |
+
# kernel sees a non-leaf (mirrors the real model where q is a Linear output).
|
| 89 |
+
q_leaf, kv_leaf, kpe_leaf = (
|
| 90 |
+
q.clone().detach().requires_grad_(True),
|
| 91 |
+
kv_latent.clone().detach().requires_grad_(True),
|
| 92 |
+
k_pe.clone().detach().requires_grad_(True),
|
| 93 |
+
)
|
| 94 |
+
q_in, kv_in, kpe_in = q_leaf * 1.0, kv_leaf * 1.0, kpe_leaf * 1.0
|
| 95 |
+
|
| 96 |
+
q_total, k_full, v = path_fn(q_in, kv_in, kpe_in, freqs_cis, **shape_kwargs)
|
| 97 |
+
loss = (q_total.float() ** 2).sum() + (k_full.float() ** 2).sum() + (v.float() ** 2).sum()
|
| 98 |
+
loss.backward()
|
| 99 |
+
|
| 100 |
+
return (
|
| 101 |
+
q_total.detach(), k_full.detach(), v.detach(),
|
| 102 |
+
q_leaf.grad.detach(), kv_leaf.grad.detach(), kpe_leaf.grad.detach(),
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# ------------------------------------------------------------------ test
|
| 107 |
+
|
| 108 |
+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
def test_mla_rope_fused_vs_reference(shape, dtype, seed):
    """Forward + backward parity of the fused path against the PyTorch reference."""
    B, S, H_q, H_kv, D_nope, D_rope, D_v = shape
    D_qk = D_nope + D_rope
    device = "cuda"

    torch.manual_seed(seed)
    freqs_cis = _precompute_freqs_cis(D_rope, S).to(device)

    # Scale by 0.5 to keep bf16 values well inside range.
    q = (torch.randn(B, S, H_q, D_qk, device=device, dtype=dtype) * 0.5)
    kv_latent = (torch.randn(B, S, H_kv, D_nope + D_v, device=device, dtype=dtype) * 0.5)
    k_pe = (torch.randn(B, S, D_rope, device=device, dtype=dtype) * 0.5)

    kw = dict(B=B, S=S, H_kv=H_kv, D_nope=D_nope, D_rope=D_rope, D_v=D_v)
    van_q, van_k, van_v, van_gq, van_gkv, van_gkpe = _run_with_grad(
        vanilla_path, q, kv_latent, k_pe, freqs_cis, **kw
    )
    our_q, our_k, our_v, our_gq, our_gkv, our_gkpe = _run_with_grad(
        fused_path, q, kv_latent, k_pe, freqs_cis, **kw
    )

    # Forward outputs: small bf16 jitter expected on the q rope rotation
    # (Triton fp32 accum vs inductor fp32 complex_mul order).
    assert_close(our_q.float(), van_q.float(), atol=1e-2, rtol=1e-2)
    # KV path is bit-exact (just slice + register broadcast + store).
    assert_close(our_k.float(), van_k.float(), atol=0.0, rtol=0.0)
    assert_close(our_v.float(), van_v.float(), atol=0.0, rtol=0.0)

    # Input grads.
    assert_close(our_gq.float(), van_gq.float(), atol=1e-2, rtol=1e-2)
    assert_close(our_gkv.float(), van_gkv.float(), atol=0.0, rtol=0.0)
    assert_close(our_gkpe.float(), van_gkpe.float(), atol=0.0, rtol=0.0)
|
torch-ext/activation/__init__.py
CHANGED
|
@@ -2,6 +2,10 @@ import torch
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from .grouped_poly_norm import fused_mul_grouped_poly_norm
|
| 6 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 7 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
|
@@ -50,6 +54,8 @@ __all__ = [
|
|
| 50 |
"fused_mul_grouped_poly_norm",
|
| 51 |
"rms_norm",
|
| 52 |
"fused_add_rms_norm",
|
|
|
|
|
|
|
| 53 |
"layers",
|
| 54 |
"parallel_style",
|
| 55 |
"ops",
|
|
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
| 5 |
+
from .fused_rope import (
|
| 6 |
+
fused_kv_split_rope_cat,
|
| 7 |
+
fused_q_rope_inplace,
|
| 8 |
+
)
|
| 9 |
from .grouped_poly_norm import fused_mul_grouped_poly_norm
|
| 10 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 11 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
|
|
|
| 54 |
"fused_mul_grouped_poly_norm",
|
| 55 |
"rms_norm",
|
| 56 |
"fused_add_rms_norm",
|
| 57 |
+
"fused_q_rope_inplace",
|
| 58 |
+
"fused_kv_split_rope_cat",
|
| 59 |
"layers",
|
| 60 |
"parallel_style",
|
| 61 |
"ops",
|
torch-ext/activation/fused_rope.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Fused MLA RoPE kernels for Motif3 GDLAttention.
|
| 2 |
+
|
| 3 |
+
Applies RoPE to the input tensor and outputs in contiguous format
|
| 4 |
+
[real..., imag...] so no reorder_headdim_elements_rope is needed.
|
| 5 |
+
|
| 6 |
+
Registered as torch custom_op for torch.compile compatibility (no graph break).
|
| 7 |
+
|
| 8 |
+
Reference: Megatron-LM fused_mla_yarn_rope_apply.py
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import triton
|
| 13 |
+
import triton.language as tl
|
| 14 |
+
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
# Static configs per kernel — no autotune.
|
| 17 |
+
#
|
| 18 |
+
# Picked from a full-sweep dump (TRITON_PRINT_AUTOTUNING=1) on motif3_seq,
|
| 19 |
+
# B=8, S=4096, n_layers=4, 8×B200, EP=8 (plurality across 8 ranks; see
|
| 20 |
+
# PR description for details). These are the shapes the kernels see at
|
| 21 |
+
# the current prod workload; hard-coding avoids autotune warm-up cost and
|
| 22 |
+
# tie-break jitter between near-equivalent configs.
|
| 23 |
+
#
|
| 24 |
+
# If model config changes materially (different H, D, batch-size/seq-len
|
| 25 |
+
# regime), re-run the dump and update these. Autotune key used to be
|
| 26 |
+
# (rope_dim, head_num); same shape invariants still apply at runtime.
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
|
| 29 |
+
# BLOCK_H, num_warps, num_stages per kernel. H values denote head_num
|
| 30 |
+
# (q is per-head H=80, kv fused expands head-shared k_pe to H_kv=16 in motif3_seq).
|
| 31 |
+
_CFG_KV_ROPE_FWD = dict(BLOCK_H=32, num_warps=8, num_stages=2) # kv fused, H=80
|
| 32 |
+
_CFG_KV_ROPE_BWD = dict(BLOCK_H=16, num_warps=4, num_stages=2) # kv bwd; BLOCK_H must cover H_kv=16 (single program per token)
|
| 33 |
+
_CFG_Q_ROPE_INPLACE_FWD = dict(BLOCK_H=8, num_warps=1, num_stages=2) # q in-place, H=80
|
| 34 |
+
_CFG_Q_ROPE_BWD = dict(BLOCK_H=4, num_warps=2, num_stages=2) # q bwd, H=80
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
# Phase 2: Fused KV split + RoPE + expand + cat kernel
|
| 40 |
+
# Fuses: split(kv_latent, [k_dim, v_dim]) + RoPE(k_pe) + expand + cat → key, value
|
| 41 |
+
# Reference: Megatron rotary_fwd_kv_kernel
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
@triton.jit
def _kv_rope_fwd_kernel(
    KV,       # [B*S, H, k_nope_dim + v_dim]
    K_PE,     # [B*S, 1, rope_dim] (already RoPE'd, contiguous format)
    O_KEY,    # [B*S, H, k_nope_dim + rope_dim] output
    O_VALUE,  # [B*S, H, v_dim] output
    rope_dim: tl.constexpr,
    k_nope_dim: tl.constexpr,
    v_dim: tl.constexpr,
    head_num: tl.constexpr,
    stride_kv_token,
    stride_kv_head,
    stride_pe_token,
    stride_k_token,
    stride_k_head,
    stride_v_token,
    stride_v_head,
    BLOCK_H: tl.constexpr,
):
    """Fused split(kv_latent) + broadcast(k_pe) + cat → (key, value).

    Grid: (num tokens, ceil(H / BLOCK_H)); each program handles BLOCK_H heads
    of one token. No arithmetic is done on the values — this is a pure
    slice/broadcast/store kernel, so outputs are bit-exact copies of inputs.
    """
    pid_token = tl.program_id(0)
    pid_hblock = tl.program_id(1)

    # Base pointers for this (token, head-block) tile.
    KV_ptr = KV + pid_token * stride_kv_token + pid_hblock * BLOCK_H * stride_kv_head
    K_ptr = O_KEY + pid_token * stride_k_token + pid_hblock * BLOCK_H * stride_k_head
    V_ptr = O_VALUE + pid_token * stride_v_token + pid_hblock * BLOCK_H * stride_v_head

    h_off = tl.arange(0, BLOCK_H)[:, None]
    # Mask trailing heads when H is not a multiple of BLOCK_H.
    mask = (pid_hblock * BLOCK_H + h_off) < head_num

    # Read k_nope from KV
    k_nope_off = h_off * stride_kv_head + tl.arange(0, k_nope_dim)[None, :]
    k_nope = tl.load(KV_ptr + k_nope_off, mask=mask)

    # Read v from KV
    v_off = h_off * stride_kv_head + k_nope_dim + tl.arange(0, v_dim)[None, :]
    v = tl.load(KV_ptr + v_off, mask=mask)

    # Read k_pe (shared across all heads, already RoPE'd)
    # K_PE is [B*S, 1, rope_dim], broadcast to all heads
    pe_ptr = K_PE + pid_token * stride_pe_token
    k_pe = tl.load(pe_ptr + tl.arange(0, rope_dim)[None, :])
    k_pe = k_pe.broadcast_to(BLOCK_H, rope_dim)

    # Write key = [k_nope | k_pe]
    k_nope_out = h_off * stride_k_head + tl.arange(0, k_nope_dim)[None, :]
    tl.store(K_ptr + k_nope_out, k_nope, mask=mask)
    k_pe_out = h_off * stride_k_head + k_nope_dim + tl.arange(0, rope_dim)[None, :]
    tl.store(K_ptr + k_pe_out, k_pe, mask=mask)

    # Write value
    v_out = h_off * stride_v_head + tl.arange(0, v_dim)[None, :]
    tl.store(V_ptr + v_out, v, mask=mask)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@triton.jit
def _kv_rope_bwd_kernel(
    DO_KEY,    # [B*S, H, k_nope_dim + rope_dim] grad_key (in)
    DO_VALUE,  # [B*S, H, v_dim] grad_value (in)
    DKV,       # [B*S, H, k_nope_dim + v_dim] grad_kv_latent (out)
    DKPE,      # [B*S, 1, rope_dim] grad_k_pe (out, stored in DKPE's dtype)
    rope_dim: tl.constexpr,
    k_nope_dim: tl.constexpr,
    v_dim: tl.constexpr,
    head_num: tl.constexpr,
    stride_dk_token, stride_dk_head,
    stride_dv_token, stride_dv_head,
    stride_dkv_token, stride_dkv_head,
    stride_dkpe_token,
    BLOCK_H: tl.constexpr,
):
    """Reverse of _kv_rope_fwd_kernel.
    Forward did: key_nope = slice(kv_latent), key_rope = broadcast(k_pe, H),
    value = slice(kv_latent).
    Backward therefore: grad_kv_latent_nope = grad_key_nope, grad_kv_latent_v = grad_value,
    grad_k_pe = sum(grad_key_rope, dim=H).
    Single program per token; BLOCK_H must cover head_num (H_kv=16 in MLA),
    otherwise the head-sum for grad_k_pe would be partial.
    """
    pid_token = tl.program_id(0)

    DK_ptr = DO_KEY + pid_token * stride_dk_token
    DV_ptr = DO_VALUE + pid_token * stride_dv_token
    DKV_ptr = DKV + pid_token * stride_dkv_token

    h_off = tl.arange(0, BLOCK_H)[:, None]
    mask = h_off < head_num

    # 1. grad_key nope → grad_kv_latent nope section
    nope_in = h_off * stride_dk_head + tl.arange(0, k_nope_dim)[None, :]
    nope_out = h_off * stride_dkv_head + tl.arange(0, k_nope_dim)[None, :]
    nope_data = tl.load(DK_ptr + nope_in, mask=mask)
    tl.store(DKV_ptr + nope_out, nope_data, mask=mask)

    # 2. grad_value → grad_kv_latent v section
    v_in = h_off * stride_dv_head + tl.arange(0, v_dim)[None, :]
    v_out = h_off * stride_dkv_head + k_nope_dim + tl.arange(0, v_dim)[None, :]
    v_data = tl.load(DV_ptr + v_in, mask=mask)
    tl.store(DKV_ptr + v_out, v_data, mask=mask)

    # 3. grad_k_pe = sum over H of grad_key rope section
    # Accumulate in fp32 for precision across H heads.
    rope_in = h_off * stride_dk_head + k_nope_dim + tl.arange(0, rope_dim)[None, :]
    rope_grads = tl.load(DK_ptr + rope_in, mask=mask, other=0.0).to(tl.float32)
    summed = tl.sum(rope_grads, axis=0)  # [rope_dim], fp32
    pe_out = pid_token * stride_dkpe_token + tl.arange(0, rope_dim)
    tl.store(DKPE + pe_out, summed.to(DKPE.dtype.element_ty))
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
@torch.library.custom_op("motif::kv_rope_fwd", mutates_args=())
def _kv_rope_fwd(
    kv_latent: torch.Tensor,  # [B*S, H, k_nope_dim + v_dim]
    k_pe: torch.Tensor,  # [B*S, 1, rope_dim] (already RoPE'd)
    k_nope_dim: int,
    v_dim: int,
    rope_dim: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward launch: fused KV split + head-broadcast of the shared k_pe.

    Builds, in one kernel pass over flattened tokens:
        key   [B*S, H, k_nope_dim + rope_dim] =
              [kv_latent nope slice | k_pe broadcast across H]
        value [B*S, H, v_dim]                 = kv_latent v slice

    Registered as a torch custom op (paired with `_kv_rope_fwd_fake`) so the
    Triton launch is traceable instead of forcing a graph break.
    """
    assert kv_latent.stride(-1) == 1 and k_pe.stride(-1) == 1, (
        "fused_rope kernel requires last-dim unit stride"
    )
    # MLA convention: k_pe is head-shared (single head broadcast to all Q heads
    # via register-level broadcast inside the kernel). GQA/MQA variants that
    # give k_pe its own head dim are not supported by this kernel.
    assert k_pe.shape[1] == 1, f"k_pe must be head-shared (shape[1]==1), got {k_pe.shape}"
    B_S, H, _ = kv_latent.shape
    # new_empty: outputs fully overwritten by the kernel, no zero-init needed.
    key = kv_latent.new_empty(B_S, H, k_nope_dim + rope_dim)
    value = kv_latent.new_empty(B_S, H, v_dim)
    # 2D grid: one program per (token, block of heads).
    grid = lambda META: (B_S, triton.cdiv(H, META["BLOCK_H"]))
    _kv_rope_fwd_kernel[grid](
        kv_latent, k_pe, key, value,
        rope_dim, k_nope_dim, v_dim, H,
        kv_latent.stride(0), kv_latent.stride(1),
        k_pe.stride(0),
        key.stride(0), key.stride(1),
        value.stride(0), value.stride(1),
        **_CFG_KV_ROPE_FWD,
    )
    return key, value
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@_kv_rope_fwd.register_fake
def _kv_rope_fwd_fake(kv_latent, k_pe, k_nope_dim, v_dim, rope_dim):
    """Meta/fake implementation: mirror the real op's output shapes/dtypes
    without launching a kernel (used by torch.compile tracing)."""
    tokens, heads, _ = kv_latent.shape
    return (
        kv_latent.new_empty(tokens, heads, k_nope_dim + rope_dim),
        kv_latent.new_empty(tokens, heads, v_dim),
    )
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
@torch.library.custom_op("motif::kv_rope_bwd", mutates_args=())
def _kv_rope_bwd(
    grad_key: torch.Tensor,  # [B*S, H, k_nope_dim + rope_dim]
    grad_value: torch.Tensor,  # [B*S, H, v_dim]
    k_nope_dim: int,
    v_dim: int,
    rope_dim: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Backward launch for motif::kv_rope_fwd.

    One fused kernel computes:
        grad_kv_latent[..., :k_nope_dim] = grad_key nope slice
        grad_kv_latent[..., k_nope_dim:] = grad_value
        grad_k_pe                        = sum over H of grad_key rope slice

    Returns:
        grad_kv_latent: [B*S, H, k_nope_dim + v_dim]
        grad_k_pe:      [B*S, 1, rope_dim] (head-shared, matches forward input)
    """
    assert grad_key.stride(-1) == 1 and grad_value.stride(-1) == 1, (
        "fused_rope kernel requires last-dim unit stride"
    )
    B_S, H, _ = grad_key.shape
    # The kernel runs a single program per token and covers all heads with one
    # BLOCK_H tile (mask `h_off < head_num`). Heads beyond BLOCK_H would be
    # silently skipped, leaving `new_empty` output rows uninitialized and
    # grad_k_pe under-summed — fail loudly instead of producing garbage grads.
    assert H <= _CFG_KV_ROPE_BWD["BLOCK_H"], (
        f"_kv_rope_bwd requires head_num <= BLOCK_H, "
        f"got H={H} > BLOCK_H={_CFG_KV_ROPE_BWD['BLOCK_H']}"
    )
    # grad_kv_latent layout matches forward input: [nope | v]
    grad_kv_latent = grad_key.new_empty(B_S, H, k_nope_dim + v_dim)
    # grad_k_pe: head-shared, matches forward input shape
    grad_k_pe = grad_key.new_empty(B_S, 1, rope_dim)
    # Single program per token; BLOCK_H >= H is guaranteed by the assert above.
    grid = (B_S,)
    _kv_rope_bwd_kernel[grid](
        grad_key, grad_value, grad_kv_latent, grad_k_pe,
        rope_dim, k_nope_dim, v_dim, H,
        grad_key.stride(0), grad_key.stride(1),
        grad_value.stride(0), grad_value.stride(1),
        grad_kv_latent.stride(0), grad_kv_latent.stride(1),
        grad_k_pe.stride(0),
        **_CFG_KV_ROPE_BWD,
    )
    return grad_kv_latent, grad_k_pe
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
@_kv_rope_bwd.register_fake
def _kv_rope_bwd_fake(grad_key, grad_value, k_nope_dim, v_dim, rope_dim):
    """Meta/fake implementation: allocate outputs with the real op's shapes,
    no kernel launch (used by torch.compile tracing)."""
    tokens, heads, _ = grad_key.shape
    grad_kv_latent = grad_key.new_empty(tokens, heads, k_nope_dim + v_dim)
    grad_k_pe = grad_key.new_empty(tokens, 1, rope_dim)
    return grad_kv_latent, grad_k_pe
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class FusedKVRoPE(torch.autograd.Function):
    """Autograd wrapper around the fused KV split / k_pe-broadcast Triton ops.

    forward:  kv_latent [B, S, H, k_nope+v], k_pe [B, S, 1, rope] (already
              RoPE'd) -> key [B, S, H, k_nope+rope], value [B, S, H, v].
    backward: a single fused kernel produces grad_kv_latent and the
              head-summed grad_k_pe, replacing a slice + cat + sum chain.
    """

    @staticmethod
    def forward(ctx, kv_latent, k_pe, k_nope_dim, v_dim, rope_dim):
        batch, seq, heads, latent_dim = kv_latent.shape
        tokens = batch * seq
        # Flatten (B, S) -> tokens for the kernel's 3D layout.
        key_flat, value_flat = _kv_rope_fwd(
            kv_latent.reshape(tokens, heads, latent_dim),
            k_pe.reshape(tokens, 1, rope_dim),
            k_nope_dim, v_dim, rope_dim,
        )
        ctx.dims = (k_nope_dim, v_dim, rope_dim)
        ctx.bsh = (batch, seq, heads)
        return (
            key_flat.view(batch, seq, heads, k_nope_dim + rope_dim),
            value_flat.view(batch, seq, heads, v_dim),
        )

    @staticmethod
    def backward(ctx, grad_key, grad_value):
        k_nope_dim, v_dim, rope_dim = ctx.dims
        batch, seq, heads = ctx.bsh
        tokens = batch * seq

        # One Triton kernel does: nope copy + v copy + head-sum of rope section.
        dkv_flat, dkpe_flat = _kv_rope_bwd(
            grad_key.contiguous().reshape(tokens, heads, k_nope_dim + rope_dim),
            grad_value.contiguous().reshape(tokens, heads, v_dim),
            k_nope_dim, v_dim, rope_dim,
        )
        # Non-tensor forward args (the three dims) get no gradient.
        return (
            dkv_flat.view(batch, seq, heads, k_nope_dim + v_dim),
            dkpe_flat.view(batch, seq, 1, rope_dim),
            None,
            None,
            None,
        )
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# ---------------------------------------------------------------------------
|
| 266 |
+
# Q RoPE backward kernel — used by FusedQRoPEInplace.backward (below).
|
| 267 |
+
# Out-of-place: reads contiguous grad_out, writes interleaved grad_in; nope
|
| 268 |
+
# gradient is copied through unchanged.
|
| 269 |
+
# ---------------------------------------------------------------------------
|
| 270 |
+
|
| 271 |
+
@triton.jit
def _q_rope_bwd_kernel(
    DO,  # [B*S, H, nope_dim + rope_dim] grad (contiguous rope fmt)
    DQ,  # [B*S, H, nope_dim + rope_dim] grad output (interleaved rope fmt)
    COS,  # [S, rope_dim // 2]
    SIN,  # [S, rope_dim // 2]
    nope_dim: tl.constexpr,
    rope_dim: tl.constexpr,
    head_num: tl.constexpr,
    seq_len,
    stride_do_token,
    stride_do_head,
    stride_dq_token,
    stride_dq_head,
    BLOCK_H: tl.constexpr,
):
    """Q RoPE backward: pass nope grad through, inverse-rotate the rope grad.

    The forward mapped interleaved pairs (x1, x2) to contiguous halves via
        real = x1*cos - x2*sin,  imag = x1*sin + x2*cos;
    here the transpose of that rotation is applied to (d_real, d_imag) and
    written back in interleaved layout, matching the Q input format.
    Grid: (tokens, head-blocks).
    """
    HALF: tl.constexpr = rope_dim // 2

    pid_token = tl.program_id(0)
    pid_hblock = tl.program_id(1)

    # Position within the sequence selects the cos/sin row (tokens are
    # flattened as B*S, so pid_token % seq_len recovers the seq index).
    pos = pid_token % seq_len

    DO_ptr = DO + pid_token * stride_do_token + pid_hblock * BLOCK_H * stride_do_head
    DQ_ptr = DQ + pid_token * stride_dq_token + pid_hblock * BLOCK_H * stride_dq_head

    h_off = tl.arange(0, BLOCK_H)[:, None]
    # Mask heads past head_num in the last (partial) head block.
    mask = (pid_hblock * BLOCK_H + h_off) < head_num

    # Copy nope grad as-is
    nope_off_in = h_off * stride_do_head + tl.arange(0, nope_dim)[None, :]
    nope_grad = tl.load(DO_ptr + nope_off_in, mask=mask)
    nope_off_out = h_off * stride_dq_head + tl.arange(0, nope_dim)[None, :]
    tl.store(DQ_ptr + nope_off_out, nope_grad, mask=mask)

    # Inverse RoPE: contiguous → interleaved
    # cos/sin indexed as a flat row-major [S, HALF] table.
    cos = tl.load(COS + pos * HALF + tl.arange(0, HALF))
    sin = tl.load(SIN + pos * HALF + tl.arange(0, HALF))
    cos = cos.expand_dims(0).broadcast_to(BLOCK_H, HALF)
    sin = sin.expand_dims(0).broadcast_to(BLOCK_H, HALF)

    # Incoming grad rope section is contiguous: [real half | imag half].
    real_off = h_off * stride_do_head + nope_dim + tl.arange(0, HALF)[None, :]
    imag_off = real_off + HALF
    # Accumulate the rotation in fp32 for precision.
    d_real = tl.load(DO_ptr + real_off, mask=mask).to(tl.float32)
    d_imag = tl.load(DO_ptr + imag_off, mask=mask).to(tl.float32)

    # Transpose (inverse) of the forward rotation matrix.
    dx1 = d_real * cos + d_imag * sin
    dx2 = -d_real * sin + d_imag * cos

    # Output grad rope section is interleaved: [dx1_0, dx2_0, dx1_1, dx2_1, ...].
    x1_off = h_off * stride_dq_head + nope_dim + tl.arange(0, HALF)[None, :] * 2
    x2_off = x1_off + 1
    tl.store(DQ_ptr + x1_off, dx1, mask=mask)
    tl.store(DQ_ptr + x2_off, dx2, mask=mask)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
@torch.library.custom_op("motif::q_rope_bwd", mutates_args=())
def _q_rope_bwd(
    grad_out: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    nope_dim: int,
    rope_dim: int,
    seq_len: int,
) -> torch.Tensor:
    """Backward launch for Q RoPE: nope grad copy + inverse rotation.

    Args:
        grad_out: [B*S, H, nope_dim + rope_dim], rope section in contiguous fmt.
        cos/sin:  [seq_len, rope_dim // 2] rotation tables.

    Returns:
        grad_in: same shape as grad_out, rope section in interleaved fmt
        (matching the pre-RoPE Q layout).
    """
    assert grad_out.stride(-1) == 1, "fused_rope kernel requires last-dim unit stride"
    # The kernel flat-indexes the tables as row-major [seq_len, HALF]
    # (COS + pos * HALF + arange); a non-contiguous table would be read wrong.
    assert cos.is_contiguous() and sin.is_contiguous(), (
        "q_rope_bwd requires contiguous cos/sin tables"
    )
    B_S, H, D = grad_out.shape
    grad_in = torch.empty_like(grad_out)
    # 2D grid: one program per (token, block of heads).
    grid = lambda META: (B_S, triton.cdiv(H, META["BLOCK_H"]))
    _q_rope_bwd_kernel[grid](
        grad_out, grad_in, cos, sin,
        nope_dim, rope_dim, H, seq_len,
        grad_out.stride(0), grad_out.stride(1),
        grad_in.stride(0), grad_in.stride(1),
        **_CFG_Q_ROPE_BWD,
    )
    return grad_in
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
@_q_rope_bwd.register_fake
def _q_rope_bwd_fake(grad_out, cos, sin, nope_dim, rope_dim, seq_len):
    """Meta/fake implementation: output mirrors grad_out (shape/dtype/layout)."""
    grad_in = torch.empty_like(grad_out)
    return grad_in
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
# ---------------------------------------------------------------------------
|
| 355 |
+
# In-place Q RoPE kernel (eliminates cat + nope copy)
|
| 356 |
+
# Modifies q[..., nope_dim:] from interleaved → contiguous format IN-PLACE.
|
| 357 |
+
# nope section [..., :nope_dim] is untouched.
|
| 358 |
+
# Forward: in-place on Q (mutates)
|
| 359 |
+
# Backward: out-of-place via existing _q_rope_bwd_kernel (nope grad copy + inverse rope)
|
| 360 |
+
# ---------------------------------------------------------------------------
|
| 361 |
+
|
| 362 |
+
@triton.jit
def _q_rope_inplace_fwd_kernel(
    Q,  # [B*S, H, nope_dim + rope_dim] modified in-place on [..., nope_dim:]
    COS,  # [S, rope_dim // 2]
    SIN,  # [S, rope_dim // 2]
    nope_dim: tl.constexpr,
    rope_dim: tl.constexpr,
    head_num: tl.constexpr,
    seq_len,
    stride_q_token,
    stride_q_head,
    BLOCK_H: tl.constexpr,
):
    """Apply RoPE to Q's rope section in place.

    Reads the rope section in interleaved pair layout, rotates each
    (x1, x2) pair by (cos, sin), and writes the result back into the same
    slots in contiguous [real half | imag half] layout. The nope section
    is never touched. Grid: (tokens, head-blocks).
    """
    HALF: tl.constexpr = rope_dim // 2

    pid_token = tl.program_id(0)
    pid_hblock = tl.program_id(1)

    # Tokens are flattened B*S, so pid_token % seq_len recovers the sequence
    # position that selects the cos/sin row (flat row-major [S, HALF] table).
    pos = pid_token % seq_len
    cos = tl.load(COS + pos * HALF + tl.arange(0, HALF))
    sin = tl.load(SIN + pos * HALF + tl.arange(0, HALF))
    cos = cos.expand_dims(0).broadcast_to(BLOCK_H, HALF)
    sin = sin.expand_dims(0).broadcast_to(BLOCK_H, HALF)

    Q_ptr = Q + pid_token * stride_q_token + pid_hblock * BLOCK_H * stride_q_head

    h_off = tl.arange(0, BLOCK_H)[:, None]
    # Mask heads past head_num in the last (partial) head block.
    mask = (pid_hblock * BLOCK_H + h_off) < head_num

    # Read rope section interleaved: [r0,i0,r1,i1,...]
    x1_off = h_off * stride_q_head + nope_dim + tl.arange(0, HALF)[None, :] * 2
    x2_off = x1_off + 1
    # Rotate in fp32 for precision, regardless of Q's storage dtype.
    x1 = tl.load(Q_ptr + x1_off, mask=mask).to(tl.float32)
    x2 = tl.load(Q_ptr + x2_off, mask=mask).to(tl.float32)

    out_real = x1 * cos - x2 * sin
    out_imag = x1 * sin + x2 * cos

    # Write back to SAME rope section in contiguous format: [r0..r31, i0..i31]
    # Safe in-place: each program owns its (token, head-block) slots, and the
    # loads above complete before the stores overwrite them.
    real_off = h_off * stride_q_head + nope_dim + tl.arange(0, HALF)[None, :]
    imag_off = real_off + HALF
    tl.store(Q_ptr + real_off, out_real, mask=mask)
    tl.store(Q_ptr + imag_off, out_imag, mask=mask)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
@torch.library.custom_op("motif::q_rope_inplace_fwd", mutates_args=("q",))
def _q_rope_inplace_fwd(
    q: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    nope_dim: int,
    rope_dim: int,
    seq_len: int,
) -> None:
    """In-place Q RoPE launch: rotate q[..., nope_dim:] interleaved → contiguous.

    Args:
        q: [B*S, H, nope_dim + rope_dim]; mutated on its rope section.
        cos/sin: [seq_len, rope_dim // 2] rotation tables.
    """
    # In-place op: `q` is mutated on [..., nope_dim:]; no return value by design
    # (declared via `mutates_args=("q",)` on the custom_op).
    assert q.stride(-1) == 1, "fused_rope kernel requires last-dim unit stride"
    # The kernel flat-indexes the tables as row-major [seq_len, HALF]
    # (COS + pos * HALF + arange); a non-contiguous table would be read wrong.
    assert cos.is_contiguous() and sin.is_contiguous(), (
        "q_rope_inplace_fwd requires contiguous cos/sin tables"
    )
    B_S, H, _ = q.shape
    # 2D grid: one program per (token, block of heads).
    grid = lambda META: (B_S, triton.cdiv(H, META["BLOCK_H"]))
    _q_rope_inplace_fwd_kernel[grid](
        q, cos, sin,
        nope_dim, rope_dim, H, seq_len,
        q.stride(0), q.stride(1),
        **_CFG_Q_ROPE_INPLACE_FWD,
    )
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
class FusedQRoPEInplace(torch.autograd.Function):
    """Autograd wrapper for the in-place Q RoPE kernel.

    Forward mutates `q`'s rope section in place (interleaved → contiguous)
    and declares the mutation via `ctx.mark_dirty`; backward is out-of-place
    through the `_q_rope_bwd` custom op.
    """

    @staticmethod
    def forward(ctx, q, cos, sin, nope_dim, rope_dim, seq_len):
        # q: [B, S, H, nope_dim + rope_dim] mutated in-place on [..., nope_dim:]
        B, S, H, D = q.shape
        assert D == nope_dim + rope_dim
        # Require full contiguity so that `reshape(B*S, H, D)` is guaranteed to
        # be a view — otherwise reshape silently copies and the in-place mutation
        # would not reach the original `q` that `ctx.mark_dirty` targets.
        assert q.is_contiguous(), "FusedQRoPEInplace requires contiguous q"
        _q_rope_inplace_fwd(q.reshape(B * S, H, D), cos, sin, nope_dim, rope_dim, seq_len)
        ctx.mark_dirty(q)
        ctx.save_for_backward(cos, sin)
        ctx.nope_dim = nope_dim
        ctx.rope_dim = rope_dim
        ctx.seq_len = seq_len
        ctx.shape = (B, S, H, D)
        # Return the mutated `q` itself (not a new tensor) so autograd edges
        # flow through; the underlying op is declared `-> None` (in-place).
        return q

    @staticmethod
    def backward(ctx, grad_out):
        # grad_out: [B, S, H, nope_dim + rope_dim] rope section in contiguous grad fmt
        # Produce grad_in: same shape, rope section in interleaved grad fmt (matches Q input layout)
        cos, sin = ctx.saved_tensors
        B, S, H, D = ctx.shape
        # Reuse existing Q backward kernel (copies nope grad + inverse-ropes pe section)
        grad_in_3d = _q_rope_bwd(
            grad_out.contiguous().reshape(B * S, H, D),
            cos, sin, ctx.nope_dim, ctx.rope_dim, ctx.seq_len,
        )
        # Non-tensor forward args (cos/sin are tensors but need no grad here;
        # dims and seq_len are ints) all receive None.
        return grad_in_3d.view(B, S, H, D), None, None, None, None, None
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def fused_q_rope_inplace(q, freqs_cis, nope_dim, rope_dim):
    """In-place fused Q RoPE. Mutates q[..., nope_dim:] from interleaved to contiguous format.

    Fused equivalent of:
        q_nope, q_pe = split(q, [nope_dim, rope_dim])
        q_pe = fused_apply_rope(q_pe, freqs_cis)
        q_total = cat([q_nope, q_pe])

    Saves the cat copy (~415 µs/layer for Motif3) compared to out-of-place variants.

    Args:
        q: [B, S, H, nope_dim + rope_dim] from wq_b output. Will be mutated.
        freqs_cis: [max_seq_len, rope_dim//2] complex64

    Returns:
        Same tensor `q`, now with rope section in contiguous format.
    """
    seq_len = q.shape[1]
    # Split the complex rotation table into contiguous real/imag planes for
    # the kernel, truncated to the current sequence length.
    rot = freqs_cis[:seq_len]
    return FusedQRoPEInplace.apply(
        q, rot.real.contiguous(), rot.imag.contiguous(), nope_dim, rope_dim, seq_len
    )
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def fused_kv_split_rope_cat(kv_latent, k_pe, k_nope_dim, v_dim, rope_dim):
    """Fused KV split + k_pe head-expand + cat, with autograd. No graph break.

    Fused equivalent of:
        k_nope, v = split(kv_latent, [k_nope_dim, v_dim])
        k_full = cat([k_nope, k_pe.expand(-1, -1, H, -1)])

    Args:
        kv_latent: [B, S, H, k_nope_dim + v_dim]
        k_pe: [B, S, 1, rope_dim] (already RoPE'd, contiguous format)

    Returns:
        key: [B, S, H, k_nope_dim + rope_dim]
        value: [B, S, H, v_dim]
    """
    key, value = FusedKVRoPE.apply(kv_latent, k_pe, k_nope_dim, v_dim, rope_dim)
    return key, value
|