Kernels
wyldecat Claude Opus 4.6 (1M context) committed on
Commit
0045757
·
1 Parent(s): 60a628a

feat: add grouped poly norm CUDA kernel with scores and hidden_clamp fusion

Browse files

Hand-written CUDA kernel for GroupedFusedMulPolyNorm (MoE).
Fuses polynomial normalization, mul, scores multiplication, and
hidden_clamp into a single kernel launch per forward/backward.

- 4 kernel variants: fwd/bwd × vectorized(width=8)/scalar
- scores: nullable, always fp32, fused into fwd output and bwd gradients
- hidden_clamp: < 0 disabled, >= 0 clamps input/mul/output with correct
backward gradient masking (recompute output for mask, no extra memory)
- ROCm compatible (64-bit warp sync mask)
- C++ op registration for torch.compile (register_fake in Python layer)
- ~3x faster than torch.compile'd PyTorch reference on B200

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

activation/grouped_poly_norm.cu ADDED
@@ -0,0 +1,657 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/Functions.h>
2
+ #include <ATen/cuda/CUDAContext.h>
3
+ #include <c10/cuda/CUDAGuard.h>
4
+ #include <torch/all.h>
5
+
6
+ #include "assert_utils.h"
7
+ #include "cuda_compat.h"
8
+ #include "dispatch_utils.h"
9
+
10
+ namespace motif {
11
+
12
// Fixed-size vector wrapper whose alignment equals its total byte size,
// so a vec_t<half, 8> or vec_t<float, 4> load/store compiles to a single
// aligned 128-bit memory transaction (coalesced across the warp).
template <typename type, int N> struct alignas(sizeof(type) * N) vec_t {
  type data[N];
};
15
+
16
// Binary search over per-expert cumulative row offsets: returns the index
// of the first entry in `offsets` that is strictly greater than `row`,
// i.e. the expert that owns this row. `offsets` must be non-decreasing.
//
// Fix: the original used a fixed 12-iteration loop (`#pragma unroll 6`,
// break on lo >= hi), which silently returns a wrong index whenever
// num_experts > 4096. A plain while-loop has no such cap and costs the
// same O(log num_experts) iterations.
__device__ __forceinline__ int find_expert(const int32_t *__restrict__ offsets,
                                           int num_experts, int row) {
  int lo = 0, hi = num_experts;
  while (lo < hi) {
    int mid = (lo + hi) >> 1;
    if (offsets[mid] <= row)
      lo = mid + 1;
    else
      hi = mid;
  }
  return lo;
}
30
+
31
// Butterfly (XOR-shuffle) sum reduction of a float4 within one warp.
// After log2(WARP_SIZE) rounds every lane holds the full warp sum in all
// four components. All lanes of the warp must participate.
__device__ __forceinline__ float4 warp_reduce_f4(float4 v) {
#ifndef USE_ROCM
  constexpr unsigned int FULL_MASK = 0xffffffff;
#else
  // ROCm wavefronts are 64 lanes wide, so the participation mask is 64-bit.
  constexpr unsigned long long FULL_MASK = 0xffffffffffffffffULL;
#endif
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
    v.x += __shfl_xor_sync(FULL_MASK, v.x, mask);
    v.y += __shfl_xor_sync(FULL_MASK, v.y, mask);
    v.z += __shfl_xor_sync(FULL_MASK, v.z, mask);
    v.w += __shfl_xor_sync(FULL_MASK, v.w, mask);
  }
  return v;
}
46
+
47
// Block-wide sum reduction of a float4, two-level: warp-local reduce, then
// warp 0 reduces the per-warp partials, and the result is broadcast through
// shared memory so EVERY thread returns the full block sum.
// Contains __syncthreads(): all threads of the block must call this, and
// BLOCK_SIZE must be a multiple of WARP_SIZE.
template <int BLOCK_SIZE>
__device__ __forceinline__ float4 block_reduce_f4(float4 v) {
  constexpr int NUM_WARPS = BLOCK_SIZE / WARP_SIZE;
  __shared__ float4 warp_results[NUM_WARPS];

  v = warp_reduce_f4(v);
  const int warp_id = threadIdx.x / WARP_SIZE;
  const int lane_id = threadIdx.x % WARP_SIZE;

  // One partial per warp.
  if (lane_id == 0) warp_results[warp_id] = v;
  __syncthreads();

  // Warp 0 loads the partials (zero-padding the unused lanes) and reduces.
  if (warp_id == 0 && lane_id < NUM_WARPS)
    v = warp_results[lane_id];
  else
    v = make_float4(0.f, 0.f, 0.f, 0.f);

  if (warp_id == 0) v = warp_reduce_f4(v);

  // Broadcast the final sum to the whole block. The trailing barrier also
  // makes back-to-back calls within one kernel safe: `warp_results` is not
  // rewritten until every thread has consumed `result`.
  __shared__ float4 result;
  if (threadIdx.x == 0) result = v;
  __syncthreads();
  return result;
}
71
+
72
// ---------------------------------------------------------------------------
// Grouped PolyNorm Forward — vectorized (width > 0)
// Pass 1: accumulate sum_x2, sum_x4, sum_x6 for RMS stats
// Pass 2: compute poly * mul output, save inv_rms for backward
//
// Launch: one block per row (grid.x = N rows). Preconditions: D divisible by
// `width`, and input/mul/output aligned for vec_t loads (16B when
// width*sizeof(scalar_t) == 16). `scores` may be null (treated as 1.0).
// `hidden_clamp < 0` disables all clamping; otherwise input, mul, and the
// final output are each clamped to [-hidden_clamp, hidden_clamp].
// The three inverse RMS values per row are written to `inv_rms` ([N, 3])
// for reuse in the backward kernel.
// ---------------------------------------------------------------------------
template <typename scalar_t, typename acc_t, int width, int BLOCK_SIZE>
__global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
grouped_poly_norm_fwd_kernel(
    scalar_t *__restrict__ output,
    acc_t *__restrict__ inv_rms,
    const scalar_t *__restrict__ input,
    const scalar_t *__restrict__ mul,
    const scalar_t *__restrict__ weight,
    const scalar_t *__restrict__ bias,
    const int32_t *__restrict__ offsets,
    const float *__restrict__ scores, // nullable, always fp32
    const acc_t eps, const int D, const int num_experts,
    const int expert_offset,
    const acc_t hidden_clamp) { // < 0 = disabled
  using v_t = vec_t<scalar_t, width>;
  const bool do_clamp = (hidden_clamp >= acc_t(0));

  const int row = blockIdx.x;
  const int vec_d = D / width;
  const int64_t base = (int64_t)row * vec_d; // int64 to avoid overflow on large N*D

  const v_t *__restrict__ in_v = reinterpret_cast<const v_t *>(input) + base;

  // Per-row expert lookup; expert_offset shifts into the global expert table
  // (e.g. for expert-parallel shards).
  const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
  const acc_t w0 = weight[eidx * 3 + 0];
  const acc_t w1 = weight[eidx * 3 + 1];
  const acc_t w2 = weight[eidx * 3 + 2];
  const acc_t b_val = bias[eidx];
  const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);

  // Pass 1: RMS stats (on clamped input if enabled)
  acc_t s2 = 0, s4 = 0, s6 = 0;
  for (int i = threadIdx.x; i < vec_d; i += BLOCK_SIZE) {
    v_t xv = in_v[i];
#pragma unroll
    for (int j = 0; j < width; ++j) {
      acc_t x = xv.data[j];
      if (do_clamp) x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
      acc_t x2 = x * x;
      s2 += x2;
      s4 += x2 * x2;
      s6 += x2 * x2 * x2;
    }
  }

  float4 sums = block_reduce_f4<BLOCK_SIZE>(make_float4(s2, s4, s6, 0.f));

  // inverse RMS of x, x^2, x^3 respectively: 1 / sqrt(mean(x^(2k)) + eps)
  const acc_t inv_d = acc_t(1) / D;
  const acc_t ir1 = rsqrtf(sums.x * inv_d + eps);
  const acc_t ir2 = rsqrtf(sums.y * inv_d + eps);
  const acc_t ir3 = rsqrtf(sums.z * inv_d + eps);

  // Save inv_rms for backward
  if (threadIdx.x == 0) {
    inv_rms[row * 3 + 0] = ir1;
    inv_rms[row * 3 + 1] = ir2;
    inv_rms[row * 3 + 2] = ir3;
  }

  // Fold the per-row inverse RMS into the polynomial weights once.
  const acc_t w2ir1 = w2 * ir1;
  const acc_t w1ir2 = w1 * ir2;
  const acc_t w0ir3 = w0 * ir3;

  // Pass 2: output = poly * mul (with clamping)
  const v_t *__restrict__ m_v = reinterpret_cast<const v_t *>(mul) + base;
  v_t *__restrict__ out_v = reinterpret_cast<v_t *>(output) + base;

  for (int i = threadIdx.x; i < vec_d; i += BLOCK_SIZE) {
    v_t xv = in_v[i];
    v_t mv = m_v[i];
    v_t ov;
#pragma unroll
    for (int j = 0; j < width; ++j) {
      acc_t x = xv.data[j];
      if (do_clamp) x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
      acc_t m = (acc_t)mv.data[j];
      if (do_clamp) m = fminf(fmaxf(m, -hidden_clamp), hidden_clamp);
      acc_t x2 = x * x;
      acc_t x3 = x2 * x;
      // poly = w0*x^3/rms(x^3) + w1*x^2/rms(x^2) + w2*x/rms(x) + bias
      acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
      acc_t out_val = poly * m * score;
      if (do_clamp) out_val = fminf(fmaxf(out_val, -hidden_clamp), hidden_clamp);
      ov.data[j] = (scalar_t)out_val;
    }
    out_v[i] = ov;
  }
}
164
+
165
// Scalar fallback forward: same math as grouped_poly_norm_fwd_kernel but
// element-wise, used when D is not divisible by the vector width or the
// dtype/alignment rules out vectorized loads. One block per row.
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
__global__ void __launch_bounds__(BLOCK_SIZE)
grouped_poly_norm_fwd_scalar(
    scalar_t *__restrict__ output,
    acc_t *__restrict__ inv_rms,
    const scalar_t *__restrict__ input,
    const scalar_t *__restrict__ mul,
    const scalar_t *__restrict__ weight,
    const scalar_t *__restrict__ bias,
    const int32_t *__restrict__ offsets,
    const float *__restrict__ scores, // nullable, always fp32
    const acc_t eps, const int D, const int num_experts,
    const int expert_offset,
    const acc_t hidden_clamp) {
  const bool do_clamp = (hidden_clamp >= acc_t(0));
  const int row = blockIdx.x;
  const int64_t off = (int64_t)row * D;

  // Per-row expert coefficients and optional routing score.
  const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
  const acc_t w0 = weight[eidx * 3], w1 = weight[eidx * 3 + 1], w2 = weight[eidx * 3 + 2];
  const acc_t b_val = bias[eidx];
  const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);

  // Pass 1: second/fourth/sixth moments of the (optionally clamped) input.
  acc_t s2 = 0, s4 = 0, s6 = 0;
  for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
    acc_t x = input[off + i];
    if (do_clamp) x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
    acc_t x2 = x * x;
    s2 += x2; s4 += x2 * x2; s6 += x2 * x2 * x2;
  }

  float4 sums = block_reduce_f4<BLOCK_SIZE>(make_float4(s2, s4, s6, 0.f));
  const acc_t inv_d = acc_t(1) / D;
  const acc_t ir1 = rsqrtf(sums.x * inv_d + eps);
  const acc_t ir2 = rsqrtf(sums.y * inv_d + eps);
  const acc_t ir3 = rsqrtf(sums.z * inv_d + eps);

  // Persist inverse RMS values for the backward pass.
  if (threadIdx.x == 0) {
    inv_rms[row * 3] = ir1; inv_rms[row * 3 + 1] = ir2; inv_rms[row * 3 + 2] = ir3;
  }

  // Pass 2: out = (w0*x^3*ir3 + w1*x^2*ir2 + w2*x*ir1 + b) * mul * score.
  const acc_t w2ir1 = w2 * ir1, w1ir2 = w1 * ir2, w0ir3 = w0 * ir3;
  for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
    acc_t x = input[off + i];
    if (do_clamp) x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
    acc_t m = (acc_t)mul[off + i];
    if (do_clamp) m = fminf(fmaxf(m, -hidden_clamp), hidden_clamp);
    acc_t x2 = x * x, x3 = x2 * x;
    acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
    acc_t out_val = poly * m * score;
    if (do_clamp) out_val = fminf(fmaxf(out_val, -hidden_clamp), hidden_clamp);
    output[off + i] = (scalar_t)out_val;
  }
}
220
+
221
// ---------------------------------------------------------------------------
// Grouped PolyNorm Backward — vectorized (width > 0)
// Weight/bias grads use atomicAdd directly (no temp buffer + scatter_add).
//
// One block per row. Consumes the inv_rms values saved by the forward pass
// (so `eps` is unused here). When clamping is enabled, the forward output is
// recomputed per element to derive the output-clamp mask instead of storing
// it — gradients are zeroed wherever clamp saturated input, mul, or output.
// grad_scores must be null iff scores is null.
// ---------------------------------------------------------------------------
template <typename scalar_t, typename acc_t, int width, int BLOCK_SIZE>
__global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
grouped_poly_norm_bwd_kernel(
    scalar_t *__restrict__ grad_input,
    scalar_t *__restrict__ grad_mul,
    float *__restrict__ weight_grad, // [num_total_experts, 3] fp32
    float *__restrict__ bias_grad,   // [num_total_experts] fp32
    const scalar_t *__restrict__ grad_output,
    const scalar_t *__restrict__ input,
    const scalar_t *__restrict__ mul,
    const scalar_t *__restrict__ weight,
    const scalar_t *__restrict__ bias,
    const int32_t *__restrict__ offsets,
    const acc_t *__restrict__ inv_rms,
    const float *__restrict__ scores,  // nullable, always fp32
    acc_t *__restrict__ grad_scores,   // nullable (null when scores is null)
    const acc_t eps, const int D, const int num_experts,
    const int expert_offset,
    const acc_t hidden_clamp) {
  using v_t = vec_t<scalar_t, width>;
  const bool do_clamp = (hidden_clamp >= acc_t(0));

  const int row = blockIdx.x;
  const int vec_d = D / width;
  const int64_t base = (int64_t)row * vec_d;

  const v_t *__restrict__ in_v = reinterpret_cast<const v_t *>(input) + base;
  const v_t *__restrict__ go_v = reinterpret_cast<const v_t *>(grad_output) + base;
  const v_t *__restrict__ m_v = reinterpret_cast<const v_t *>(mul) + base;

  const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
  const acc_t w0 = weight[eidx * 3 + 0];
  const acc_t w1 = weight[eidx * 3 + 1];
  const acc_t w2 = weight[eidx * 3 + 2];
  const acc_t b_val = bias[eidx];
  const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);

  // Inverse RMS of x, x^2, x^3 saved by the forward kernel.
  const acc_t ir1 = inv_rms[row * 3 + 0];
  const acc_t ir2 = inv_rms[row * 3 + 1];
  const acc_t ir3 = inv_rms[row * 3 + 2];

  const acc_t w2ir1 = w2 * ir1;
  const acc_t w1ir2 = w1 * ir2;
  const acc_t w0ir3 = w0 * ir3;

  // ---- Pass 1: dot products (with clamp masks) ----
  // sdpx^k = sum_j dp_j * x_j^k where dp = go * m * score; these feed the
  // RMS-normalization correction terms and (sdp) the bias gradient.
  acc_t sdpx = 0, sdpx2 = 0, sdpx3 = 0, sdp = 0;

  for (int i = threadIdx.x; i < vec_d; i += BLOCK_SIZE) {
    v_t xv = in_v[i];
    v_t gv = go_v[i];
    v_t mv = m_v[i];

#pragma unroll
    for (int j = 0; j < width; ++j) {
      acc_t x_orig = xv.data[j];
      acc_t x = do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
      acc_t m_orig = (acc_t)mv.data[j];
      acc_t m = do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
      acc_t go = (acc_t)gv.data[j];

      // Output clamp mask: recompute pre-clamp output
      if (do_clamp) {
        acc_t x2 = x * x, x3 = x2 * x;
        acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
        acc_t out_pre = poly * m * score;
        if (fabsf(out_pre) > hidden_clamp) go = acc_t(0);
      }

      acc_t x2 = x * x;
      acc_t dp = go * m * score;
      sdp += dp;
      sdpx += dp * x;
      sdpx2 += dp * x2;
      sdpx3 += dp * x2 * x;
    }
  }

  float4 sums = block_reduce_f4<BLOCK_SIZE>(make_float4(sdpx, sdpx2, sdpx3, sdp));

  const acc_t inv_d = acc_t(1) / D;
  const acc_t s1 = sums.x * inv_d;
  const acc_t s2 = sums.y * inv_d;
  const acc_t s3 = sums.z * inv_d;
  const acc_t bias_grad_val = sums.w; // d out / d bias = m * score, summed

  // Correction coefficients from differentiating through each 1/rms term.
  const acc_t cx = w2 * s1 * ir1 * ir1;
  const acc_t cx2 = w1 * s2 * ir2 * ir2;
  const acc_t cx3 = w0 * s3 * ir3 * ir3;

  // ---- Pass 2: grad_input + grad_mul + weight grads + grad_scores ----
  acc_t dw0 = 0, dw1 = 0, dw2 = 0, gs_acc = 0;

  v_t *__restrict__ gi_v = reinterpret_cast<v_t *>(grad_input) + base;
  v_t *__restrict__ gm_v = reinterpret_cast<v_t *>(grad_mul) + base;

  for (int i = threadIdx.x; i < vec_d; i += BLOCK_SIZE) {
    v_t xv = in_v[i];
    v_t gv = go_v[i];
    v_t mv = m_v[i];
    v_t gi, gm;

#pragma unroll
    for (int j = 0; j < width; ++j) {
      acc_t x_orig = xv.data[j];
      acc_t x = do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
      acc_t m_orig = (acc_t)mv.data[j];
      acc_t m = do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
      acc_t x2 = x * x;
      acc_t x3 = x2 * x;
      acc_t go = (acc_t)gv.data[j];

      // Output clamp mask
      acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1;
      if (do_clamp) {
        acc_t out_pre = (poly + b_val) * m * score;
        if (fabsf(out_pre) > hidden_clamp) go = acc_t(0);
      }

      acc_t dp = go * m * score;

      // grad_mul with mul clamp mask
      acc_t gm_val = go * (poly + b_val) * score;
      if (do_clamp && fabsf(m_orig) > hidden_clamp) gm_val = acc_t(0);
      gm.data[j] = (scalar_t)gm_val;

      // grad_input with input clamp mask: per-power terms
      // d/dx [w_k * x^k / rms(x^k)] = k*x^(k-1)*ir_k*(w_k*dp - x^k*c_k)
      acc_t g = ir1 * (w2 * dp - x * cx);
      g += acc_t(2) * x * ir2 * (w1 * dp - x2 * cx2);
      g += acc_t(3) * x2 * ir3 * (w0 * dp - x3 * cx3);
      if (do_clamp && fabsf(x_orig) > hidden_clamp) g = acc_t(0);
      gi.data[j] = (scalar_t)g;

      dw0 += dp * x3 * ir3;
      dw1 += dp * x2 * ir2;
      dw2 += dp * x * ir1;
      gs_acc += go * (poly + b_val) * m; // grad_scores accumulator
    }
  }

  // Reduce weight grads + grad_scores (.w channel)
  float4 wg = block_reduce_f4<BLOCK_SIZE>(make_float4(dw0, dw1, dw2, gs_acc));

  // Single atomic per block per expert slot; grad_scores is per-row so a
  // plain store suffices.
  if (threadIdx.x == 0) {
    atomicAdd(&weight_grad[eidx * 3 + 0], wg.x);
    atomicAdd(&weight_grad[eidx * 3 + 1], wg.y);
    atomicAdd(&weight_grad[eidx * 3 + 2], wg.z);
    atomicAdd(&bias_grad[eidx], bias_grad_val);
    if (grad_scores != nullptr) {
      grad_scores[row] = wg.w;
    }
  }
}
381
+
382
// ---------------------------------------------------------------------------
// Scalar fallback (width == 0)
// Element-wise version of grouped_poly_norm_bwd_kernel for shapes/dtypes
// that cannot use vectorized loads. One block per row; `eps` unused
// (inv_rms comes precomputed from the forward pass).
// ---------------------------------------------------------------------------
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
__global__ void __launch_bounds__(BLOCK_SIZE)
grouped_poly_norm_bwd_scalar(
    scalar_t *__restrict__ grad_input,
    scalar_t *__restrict__ grad_mul,
    float *__restrict__ weight_grad,
    float *__restrict__ bias_grad,
    const scalar_t *__restrict__ grad_output,
    const scalar_t *__restrict__ input,
    const scalar_t *__restrict__ mul,
    const scalar_t *__restrict__ weight,
    const scalar_t *__restrict__ bias,
    const int32_t *__restrict__ offsets,
    const acc_t *__restrict__ inv_rms,
    const float *__restrict__ scores,  // nullable, always fp32
    acc_t *__restrict__ grad_scores,   // nullable
    const acc_t eps, const int D, const int num_experts,
    const int expert_offset,
    const acc_t hidden_clamp) {
  const bool do_clamp = (hidden_clamp >= acc_t(0));
  const int row = blockIdx.x;
  const int64_t off = (int64_t)row * D;

  // Per-row expert coefficients, forward RMS stats, and routing score.
  const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
  const acc_t w0 = weight[eidx * 3], w1 = weight[eidx * 3 + 1], w2 = weight[eidx * 3 + 2];
  const acc_t b_val = bias[eidx];
  const acc_t ir1 = inv_rms[row * 3], ir2 = inv_rms[row * 3 + 1], ir3 = inv_rms[row * 3 + 2];
  const acc_t w2ir1 = w2 * ir1, w1ir2 = w1 * ir2, w0ir3 = w0 * ir3;
  const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);

  // Pass 1: dot products with clamp masks
  // (sum of dp*x^k, dp = go*m*score, go zeroed where the output clamped)
  acc_t sdpx = 0, sdpx2 = 0, sdpx3 = 0, sdp = 0;
  for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
    acc_t x_orig = input[off + i];
    acc_t x = do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
    acc_t m_orig = (acc_t)mul[off + i];
    acc_t m = do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
    acc_t go = (acc_t)grad_output[off + i];

    // Recompute pre-clamp forward output to derive the output-clamp mask.
    if (do_clamp) {
      acc_t x2 = x * x, x3 = x2 * x;
      acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
      acc_t out_pre = poly * m * score;
      if (fabsf(out_pre) > hidden_clamp) go = acc_t(0);
    }

    acc_t x2 = x * x;
    acc_t dp = go * m * score;
    sdp += dp; sdpx += dp * x; sdpx2 += dp * x2; sdpx3 += dp * x2 * x;
  }

  float4 sums = block_reduce_f4<BLOCK_SIZE>(make_float4(sdpx, sdpx2, sdpx3, sdp));
  const acc_t inv_d = acc_t(1) / D;
  const acc_t s1 = sums.x * inv_d, s2 = sums.y * inv_d, s3 = sums.z * inv_d;
  // Correction coefficients from differentiating the 1/rms(x^k) factors.
  const acc_t cx = w2 * s1 * ir1 * ir1, cx2 = w1 * s2 * ir2 * ir2, cx3 = w0 * s3 * ir3 * ir3;

  // Pass 2: grads with clamp masks
  acc_t dw0 = 0, dw1 = 0, dw2 = 0, gs_acc = 0;
  for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
    acc_t x_orig = input[off + i], m_orig = (acc_t)mul[off + i];
    acc_t x = do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
    acc_t m = do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
    acc_t go = (acc_t)grad_output[off + i];
    acc_t x2 = x * x, x3 = x2 * x;
    acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1;

    // Output clamp mask
    if (do_clamp) {
      acc_t out_pre = (poly + b_val) * m * score;
      if (fabsf(out_pre) > hidden_clamp) go = acc_t(0);
    }

    acc_t dp = go * m * score;

    // grad_mul, zeroed where the mul input itself was clamp-saturated.
    acc_t gm_val = go * (poly + b_val) * score;
    if (do_clamp && fabsf(m_orig) > hidden_clamp) gm_val = acc_t(0);
    grad_mul[off + i] = (scalar_t)gm_val;

    // grad_input, zeroed where the input itself was clamp-saturated.
    acc_t g = ir1 * (w2 * dp - x * cx) + acc_t(2) * x * ir2 * (w1 * dp - x2 * cx2)
            + acc_t(3) * x2 * ir3 * (w0 * dp - x3 * cx3);
    if (do_clamp && fabsf(x_orig) > hidden_clamp) g = acc_t(0);
    grad_input[off + i] = (scalar_t)g;

    dw0 += dp * x3 * ir3; dw1 += dp * x2 * ir2; dw2 += dp * x * ir1;
    gs_acc += go * (poly + b_val) * m;
  }

  // One atomic per block into the shared expert gradient slots; the
  // per-row grad_scores entry needs only a plain store.
  float4 wg = block_reduce_f4<BLOCK_SIZE>(make_float4(dw0, dw1, dw2, gs_acc));
  if (threadIdx.x == 0) {
    atomicAdd(&weight_grad[eidx * 3 + 0], wg.x);
    atomicAdd(&weight_grad[eidx * 3 + 1], wg.y);
    atomicAdd(&weight_grad[eidx * 3 + 2], wg.z);
    atomicAdd(&bias_grad[eidx], sums.w);
    if (grad_scores != nullptr) {
      grad_scores[row] = wg.w;
    }
  }
}
483
+
484
+ } // namespace motif
485
+
486
// ---------------------------------------------------------------------------
// Internal helpers — shared kernel dispatch
// ---------------------------------------------------------------------------
// Launch helper for the vectorized forward kernel: dispatches on the input
// floating dtype and instantiates the kernel with the given vector width.
// Expects `grid`, `block`, `stream`, `scores_ptr` and the tensor args to be
// in scope at the expansion site (_fwd_impl).
#define FWD_LAUNCH(width_val, scalar_type_name)                              \
  MOTIF_DISPATCH_FLOATING_TYPES(                                             \
      input.scalar_type(), scalar_type_name, [&] {                           \
        motif::grouped_poly_norm_fwd_kernel<scalar_t, float, width_val, BLOCK> \
            <<<grid, block, 0, stream>>>(                                    \
                output.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(),      \
                input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(),        \
                weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),      \
                offsets.data_ptr<int32_t>(), scores_ptr,                     \
                (float)eps, D, num_experts, (int)expert_offset,              \
                (float)hidden_clamp);                                        \
      })
501
+
502
// Shared forward implementation for the scored/unscored public entry points.
// Allocates the output and per-row inv_rms ([N, 3], fp32) buffers and picks
// a kernel variant: width-8 vectors for 2-byte dtypes, width-4 for 4-byte
// dtypes (D permitting), otherwise the scalar fallback.
// scores_ptr may be null (score treated as 1); hidden_clamp < 0 disables
// clamping. Returns {output, inv_rms}.
static std::tuple<torch::Tensor, torch::Tensor>
_fwd_impl(const torch::Tensor &input, const torch::Tensor &mul,
          const torch::Tensor &weight, const torch::Tensor &bias,
          const torch::Tensor &offsets, const float *scores_ptr,
          double eps, int64_t expert_offset, double hidden_clamp) {
  const int D = input.size(-1);
  const int64_t N = input.size(0);
  const int num_experts = offsets.size(0);
  constexpr int BLOCK = 128;
  dim3 grid(N); dim3 block(BLOCK);

  auto output = torch::empty_like(input);
  auto inv_rms = torch::empty({N, 3}, input.options().dtype(torch::kFloat));

  // Empty batch (possible for MoE shards that received no tokens):
  // a launch with gridDim.x == 0 is an invalid configuration, so return
  // the (empty) outputs without launching.
  if (N == 0) return {output, inv_rms};

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (D % 8 == 0 && input.element_size() == 2)
    FWD_LAUNCH(8, "grouped_poly_norm_fwd");
  else if (D % 4 == 0 && input.element_size() == 4)
    FWD_LAUNCH(4, "grouped_poly_norm_fwd");
  else {
    MOTIF_DISPATCH_FLOATING_TYPES(
        input.scalar_type(), "grouped_poly_norm_fwd_scalar", [&] {
          motif::grouped_poly_norm_fwd_scalar<scalar_t, float, BLOCK>
              <<<grid, block, 0, stream>>>(
                  output.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(),
                  input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(),
                  weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
                  offsets.data_ptr<int32_t>(), scores_ptr,
                  (float)eps, D, num_experts, (int)expert_offset,
                  (float)hidden_clamp);
        });
  }
  return {output, inv_rms};
}
537
+ #undef FWD_LAUNCH
538
+
539
// Launch helper for the vectorized backward kernel: dispatches on the input
// floating dtype and instantiates `kernel_name` with the given vector width.
// Expects `grid`, `block`, `stream`, the grad/output tensors, `scores_ptr`
// and `gs_ptr` to be in scope at the expansion site (_bwd_impl).
#define BWD_LAUNCH(width_val, scalar_type_name, kernel_name)                 \
  MOTIF_DISPATCH_FLOATING_TYPES(                                             \
      input.scalar_type(), scalar_type_name, [&] {                           \
        motif::kernel_name<scalar_t, float, width_val, BLOCK>                \
            <<<grid, block, 0, stream>>>(                                    \
                input_grad.data_ptr<scalar_t>(),                             \
                mul_grad.data_ptr<scalar_t>(),                               \
                wg_f32.data_ptr<float>(), bg_f32.data_ptr<float>(),          \
                grad_output.data_ptr<scalar_t>(),                            \
                input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(),        \
                weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),      \
                offsets.data_ptr<int32_t>(), inv_rms.data_ptr<float>(),      \
                scores_ptr, gs_ptr,                                          \
                (float)eps, D, num_experts, (int)expert_offset,              \
                (float)hidden_clamp);                                        \
      })
555
+
556
// Shared backward implementation for the scored/unscored public entry points.
// Weight/bias gradients are accumulated in fp32 scratch (zero-initialized,
// atomicAdd from the kernels) and cast back to the parameter dtype at the
// end. gs_ptr (per-row grad_scores, fp32) may be null together with
// scores_ptr. Returns {grad_input, grad_mul, weight_grad, bias_grad,
// <empty>} — the grad_scores tensor itself is assembled by the caller.
static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
                  torch::Tensor>
_bwd_impl(const torch::Tensor &grad_output, const torch::Tensor &input,
          const torch::Tensor &mul, const torch::Tensor &weight,
          const torch::Tensor &bias, const torch::Tensor &offsets,
          const torch::Tensor &inv_rms, const float *scores_ptr,
          float *gs_ptr, int64_t N,
          double eps, int64_t expert_offset, double hidden_clamp) {
  const int D = input.size(-1);
  const int num_experts = offsets.size(0);
  constexpr int BLOCK = 128;
  dim3 grid(N); dim3 block(BLOCK);

  auto input_grad = torch::empty_like(input);
  auto mul_grad = torch::empty_like(mul);
  auto wg_f32 = torch::zeros({weight.size(0), 3}, input.options().dtype(torch::kFloat));
  auto bg_f32 = torch::zeros({bias.size(0)}, input.options().dtype(torch::kFloat));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Empty batch: gridDim.x == 0 is an invalid launch configuration. The
  // zero-filled weight/bias grads are already the correct result, so skip
  // the launch entirely.
  if (N > 0) {
    if (D % 8 == 0 && input.element_size() == 2)
      BWD_LAUNCH(8, "grouped_poly_norm_bwd", grouped_poly_norm_bwd_kernel);
    else if (D % 4 == 0 && input.element_size() == 4)
      BWD_LAUNCH(4, "grouped_poly_norm_bwd", grouped_poly_norm_bwd_kernel);
    else {
      MOTIF_DISPATCH_FLOATING_TYPES(
          input.scalar_type(), "grouped_poly_norm_bwd_scalar", [&] {
            motif::grouped_poly_norm_bwd_scalar<scalar_t, float, BLOCK>
                <<<grid, block, 0, stream>>>(
                    input_grad.data_ptr<scalar_t>(),
                    mul_grad.data_ptr<scalar_t>(),
                    wg_f32.data_ptr<float>(), bg_f32.data_ptr<float>(),
                    grad_output.data_ptr<scalar_t>(),
                    input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(),
                    weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
                    offsets.data_ptr<int32_t>(), inv_rms.data_ptr<float>(),
                    scores_ptr, gs_ptr,
                    (float)eps, D, num_experts, (int)expert_offset,
                    (float)hidden_clamp);
          });
    }
  }

  // Cast fp32 accumulators back to the parameter dtype; bias grad gains a
  // trailing singleton dim to match the bias parameter layout.
  auto weight_grad = wg_f32.to(weight.dtype());
  auto bias_grad = bg_f32.unsqueeze(-1).to(bias.dtype());
  // gs_f32 handled by caller
  return {input_grad, mul_grad, weight_grad, bias_grad, torch::Tensor()};
}
603
+ #undef BWD_LAUNCH
604
+
605
// ---------------------------------------------------------------------------
// Public API: without scores
// ---------------------------------------------------------------------------
// Forward without routing scores: score is implicitly 1 and hidden_clamp is
// disabled (passed as -1.0). Returns {output, inv_rms}; inv_rms must be fed
// back into grouped_poly_norm_backward.
std::tuple<torch::Tensor, torch::Tensor>
grouped_poly_norm_forward(
    const torch::Tensor &input, const torch::Tensor &mul,
    const torch::Tensor &weight, const torch::Tensor &bias,
    const torch::Tensor &offsets, double eps, int64_t expert_offset) {
  return _fwd_impl(input, mul, weight, bias, offsets, nullptr, eps, expert_offset, -1.0);
}
615
+
616
// Backward without routing scores (scores/grad_scores null, clamp disabled).
// Returns {grad_input, grad_mul, weight_grad, bias_grad}.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
grouped_poly_norm_backward(
    const torch::Tensor &grad_output, const torch::Tensor &input,
    const torch::Tensor &mul, const torch::Tensor &weight,
    const torch::Tensor &bias, const torch::Tensor &offsets,
    const torch::Tensor &inv_rms, double eps, int64_t expert_offset) {
  const int64_t N = input.size(0);
  auto [ig, mg, wg, bg, _] = _bwd_impl(
      grad_output, input, mul, weight, bias, offsets, inv_rms,
      nullptr, nullptr, N, eps, expert_offset, -1.0);
  return {ig, mg, wg, bg};
}
628
+
629
// ---------------------------------------------------------------------------
// Public API: with scores
// ---------------------------------------------------------------------------
// Forward with per-row routing scores fused into the output, plus optional
// hidden_clamp (< 0 disables). `scores` must be fp32 (data_ptr<float> throws
// otherwise) — assumes it is contiguous with one value per row; verify at
// the Python call site. Returns {output, inv_rms}.
std::tuple<torch::Tensor, torch::Tensor>
grouped_poly_norm_forward_scored(
    const torch::Tensor &input, const torch::Tensor &mul,
    const torch::Tensor &weight, const torch::Tensor &bias,
    const torch::Tensor &offsets, const torch::Tensor &scores,
    double eps, int64_t expert_offset, double hidden_clamp) {
  return _fwd_impl(input, mul, weight, bias, offsets,
                   scores.data_ptr<float>(), eps, expert_offset, hidden_clamp);
}
641
+
642
// Backward with scores: additionally returns grad_scores, computed in fp32
// and reshaped to [N, 1]. `scores` must be fp32.
// Returns {grad_input, grad_mul, weight_grad, bias_grad, grad_scores}.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
grouped_poly_norm_backward_scored(
    const torch::Tensor &grad_output, const torch::Tensor &input,
    const torch::Tensor &mul, const torch::Tensor &weight,
    const torch::Tensor &bias, const torch::Tensor &offsets,
    const torch::Tensor &inv_rms, const torch::Tensor &scores,
    double eps, int64_t expert_offset, double hidden_clamp) {
  const int64_t N = input.size(0);
  // Per-row grad_scores buffer filled by the kernel (one store per block).
  auto gs_f32 = torch::empty({N}, input.options().dtype(torch::kFloat));
  auto [ig, mg, wg, bg, _] = _bwd_impl(
      grad_output, input, mul, weight, bias, offsets, inv_rms,
      scores.data_ptr<float>(), gs_f32.data_ptr<float>(), N,
      eps, expert_offset, hidden_clamp);
  auto gs = gs_f32.unsqueeze(-1);
  return {ig, mg, wg, bg, gs};
}
build.toml CHANGED
@@ -19,6 +19,7 @@ src = [
19
  "activation/fused_mul_poly_norm.cu",
20
  "activation/rms_norm.cu",
21
  "activation/fused_add_rms_norm.cu",
 
22
  "activation/cuda_compat.h",
23
  "activation/dispatch_utils.h",
24
  "activation/assert_utils.h",
@@ -33,6 +34,7 @@ src = [
33
  "activation/fused_mul_poly_norm.cu",
34
  "activation/rms_norm.cu",
35
  "activation/fused_add_rms_norm.cu",
 
36
  "activation/cuda_compat.h",
37
  "activation/dispatch_utils.h",
38
  "activation/assert_utils.h",
 
19
  "activation/fused_mul_poly_norm.cu",
20
  "activation/rms_norm.cu",
21
  "activation/fused_add_rms_norm.cu",
22
+ "activation/grouped_poly_norm.cu",
23
  "activation/cuda_compat.h",
24
  "activation/dispatch_utils.h",
25
  "activation/assert_utils.h",
 
34
  "activation/fused_mul_poly_norm.cu",
35
  "activation/rms_norm.cu",
36
  "activation/fused_add_rms_norm.cu",
37
+ "activation/grouped_poly_norm.cu",
38
  "activation/cuda_compat.h",
39
  "activation/dispatch_utils.h",
40
  "activation/assert_utils.h",
torch-ext/torch_binding.cpp CHANGED
@@ -48,6 +48,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
48
  "(Tensor, Tensor)");
49
  ops.impl("fused_add_rms_norm_backward", torch::kCUDA,
50
  &fused_add_rms_norm_backward);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  }
52
 
53
  REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
 
48
  "(Tensor, Tensor)");
49
  ops.impl("fused_add_rms_norm_backward", torch::kCUDA,
50
  &fused_add_rms_norm_backward);
51
+
52
+ // grouped_poly_norm (without scores)
53
+ ops.def("grouped_poly_norm_forward("
54
+ "Tensor input, Tensor mul, Tensor weight, "
55
+ "Tensor bias, Tensor offsets, "
56
+ "float eps, int expert_offset) -> (Tensor, Tensor)");
57
+ ops.impl("grouped_poly_norm_forward", torch::kCUDA,
58
+ &grouped_poly_norm_forward);
59
+
60
+ ops.def("grouped_poly_norm_backward("
61
+ "Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
62
+ "Tensor bias, Tensor offsets, Tensor inv_rms, "
63
+ "float eps, int expert_offset) -> (Tensor, Tensor, Tensor, Tensor)");
64
+ ops.impl("grouped_poly_norm_backward", torch::kCUDA,
65
+ &grouped_poly_norm_backward);
66
+
67
+ // grouped_poly_norm (with scores)
68
+ ops.def("grouped_poly_norm_forward_scored("
69
+ "Tensor input, Tensor mul, Tensor weight, "
70
+ "Tensor bias, Tensor offsets, Tensor scores, "
71
+ "float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor)");
72
+ ops.impl("grouped_poly_norm_forward_scored", torch::kCUDA,
73
+ &grouped_poly_norm_forward_scored);
74
+
75
+ ops.def("grouped_poly_norm_backward_scored("
76
+ "Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
77
+ "Tensor bias, Tensor offsets, Tensor inv_rms, Tensor scores, "
78
+ "float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
79
+ ops.impl("grouped_poly_norm_backward_scored", torch::kCUDA,
80
+ &grouped_poly_norm_backward_scored);
81
  }
82
 
83
  REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h CHANGED
@@ -35,3 +35,33 @@ std::tuple<torch::Tensor, torch::Tensor> fused_add_rms_norm_backward(
35
  const torch::Tensor &output_grad, const torch::Tensor &add_output_grad,
36
  const torch::Tensor &input, const torch::Tensor &weight, double eps,
37
  bool need_input_grad);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  const torch::Tensor &output_grad, const torch::Tensor &add_output_grad,
36
  const torch::Tensor &input, const torch::Tensor &weight, double eps,
37
  bool need_input_grad);
38
+
39
+ // Without scores
40
+ std::tuple<torch::Tensor, torch::Tensor>
41
+ grouped_poly_norm_forward(
42
+ const torch::Tensor &input, const torch::Tensor &mul,
43
+ const torch::Tensor &weight, const torch::Tensor &bias,
44
+ const torch::Tensor &offsets, double eps, int64_t expert_offset);
45
+
46
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
47
+ grouped_poly_norm_backward(
48
+ const torch::Tensor &grad_output, const torch::Tensor &input,
49
+ const torch::Tensor &mul, const torch::Tensor &weight,
50
+ const torch::Tensor &bias, const torch::Tensor &offsets,
51
+ const torch::Tensor &inv_rms, double eps, int64_t expert_offset);
52
+
53
+ // With scores (hidden_clamp < 0 = disabled)
54
+ std::tuple<torch::Tensor, torch::Tensor>
55
+ grouped_poly_norm_forward_scored(
56
+ const torch::Tensor &input, const torch::Tensor &mul,
57
+ const torch::Tensor &weight, const torch::Tensor &bias,
58
+ const torch::Tensor &offsets, const torch::Tensor &scores,
59
+ double eps, int64_t expert_offset, double hidden_clamp);
60
+
61
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
62
+ grouped_poly_norm_backward_scored(
63
+ const torch::Tensor &grad_output, const torch::Tensor &input,
64
+ const torch::Tensor &mul, const torch::Tensor &weight,
65
+ const torch::Tensor &bias, const torch::Tensor &offsets,
66
+ const torch::Tensor &inv_rms, const torch::Tensor &scores,
67
+ double eps, int64_t expert_offset, double hidden_clamp);