Kernels
wyldecat Claude Opus 4.6 (1M context) commited on
Commit
6436ad6
·
1 Parent(s): 344ed39

style: apply yapf, isort, and clang-format

Browse files

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

README.md CHANGED
@@ -251,109 +251,6 @@ print(poly_norm(x))
251
  > | Forward | 0.7 ms | 2.1 ms | **3.0x** |
252
  > | Backward | 1.4 ms | 3.7 ms | **2.6x** |
253
 
254
- #### B200 Results (bf16)
255
-
256
- <details>
257
- <summary>Forward Performance</summary>
258
-
259
- | batch_size | seq_len | Naive (us) | Compiled (us) | CUDA (us) | CUDA vs Naive |
260
- |-----------|---------|-----------|--------------|------------|-----------------|
261
- | 1 | 1024 | 294.54 | 73.46 | 64.33 | 4.58x |
262
- | 1 | 2048 | 373.50 | 94.88 | 65.26 | 5.72x |
263
- | 1 | 4096 | 372.65 | 94.90 | 66.90 | 5.57x |
264
- | 1 | 8192 | 486.98 | 102.33 | 72.71 | 6.70x |
265
- | 2 | 4096 | 486.66 | 101.87 | 72.27 | 6.73x |
266
- | 2 | 8192 | 950.62 | 106.96 | 90.06 | 10.56x |
267
- | 4 | 4096 | 950.72 | 107.17 | 71.28 | 13.34x |
268
- | 4 | 8192 | 1779.12 | 198.91 | 96.93 | 18.35x |
269
- | 8 | 4096 | 1778.73 | 199.10 | 96.88 | 18.36x |
270
- | 8 | 8192 | 3384.03 | 381.91 | 179.57 | 18.85x |
271
-
272
- </details>
273
-
274
- <details>
275
- <summary>Backward Performance</summary>
276
-
277
- | batch_size | seq_len | Naive (us) | Compiled (us) | CUDA (us) | CUDA vs Naive |
278
- |-----------|---------|-----------|--------------|------------|-----------------|
279
- | 1 | 1024 | 1690.61 | 999.66 | 1017.66 | 1.66x |
280
- | 1 | 8192 | 1680.39 | 906.43 | 906.41 | 1.85x |
281
- | 2 | 8192 | 2466.73 | 870.74 | 862.78 | 2.86x |
282
- | 4 | 4096 | 2466.04 | 942.62 | 945.68 | 2.61x |
283
- | 4 | 8192 | 4543.10 | 941.01 | 908.30 | 5.00x |
284
- | 8 | 4096 | 4542.91 | 814.73 | 900.01 | 5.05x |
285
- | 8 | 8192 | 8599.41 | 956.81 | 955.07 | 9.00x |
286
-
287
- </details>
288
-
289
- <details>
290
- <summary>Forward + Backward Combined</summary>
291
-
292
- | batch_size | seq_len | Naive (us) | Compiled (us) | CUDA (us) | CUDA vs Naive | CUDA vs Compiled |
293
- |-----------|---------|-----------|--------------|------------|-----------------|-------------------|
294
- | 1 | 1024 | 1985.15 | 1073.12 | 1081.99 | 1.83x | 0.99x |
295
- | 1 | 4096 | 2085.10 | 974.32 | 960.73 | 2.17x | 1.01x |
296
- | 1 | 8192 | 2167.37 | 1008.76 | 979.12 | 2.21x | 1.03x |
297
- | 2 | 4096 | 2083.49 | 1001.03 | 965.30 | 2.16x | 1.04x |
298
- | 2 | 8192 | 3417.35 | 977.70 | 952.84 | 3.59x | 1.03x |
299
- | 4 | 4096 | 3416.76 | 1049.79 | 1016.97 | 3.36x | 1.03x |
300
- | 4 | 8192 | 6322.22 | 1139.92 | 1005.23 | 6.29x | 1.13x |
301
- | 8 | 4096 | 6321.64 | 1013.83 | 996.89 | 6.34x | 1.02x |
302
- | 8 | 8192 | 11983.44 | 1338.71 | 1134.64 | 10.56x | 1.18x |
303
-
304
- </details>
305
-
306
- #### B200 Results (fp32)
307
-
308
- <details>
309
- <summary>Forward Performance</summary>
310
-
311
- | batch_size | seq_len | Naive (us) | Compiled (us) | CUDA (us) | CUDA vs Naive |
312
- |-----------|---------|-----------|--------------|------------|-----------------|
313
- | 1 | 1024 | 318.05 | 83.29 | 64.24 | 4.95x |
314
- | 1 | 2048 | 311.14 | 95.19 | 63.64 | 4.89x |
315
- | 1 | 8192 | 401.78 | 101.61 | 68.21 | 5.89x |
316
- | 2 | 4096 | 403.42 | 100.97 | 68.01 | 5.93x |
317
- | 2 | 8192 | 803.31 | 130.51 | 68.21 | 11.78x |
318
- | 4 | 4096 | 802.86 | 130.61 | 66.97 | 11.99x |
319
- | 4 | 8192 | 1505.96 | 246.77 | 100.49 | 14.99x |
320
- | 8 | 4096 | 1507.87 | 246.84 | 100.23 | 15.04x |
321
- | 8 | 8192 | 2856.93 | 476.34 | 184.40 | 15.49x |
322
-
323
- </details>
324
-
325
- <details>
326
- <summary>Backward Performance</summary>
327
-
328
- | batch_size | seq_len | Naive (us) | Compiled (us) | CUDA (us) | CUDA vs Naive |
329
- |-----------|---------|-----------|--------------|------------|-----------------|
330
- | 1 | 1024 | 1604.25 | 989.30 | 1114.12 | 1.44x |
331
- | 1 | 8192 | 1996.40 | 1117.71 | 1115.47 | 1.79x |
332
- | 2 | 8192 | 2353.87 | 1119.41 | 1118.57 | 2.10x |
333
- | 4 | 4096 | 2358.47 | 1102.23 | 1125.16 | 2.10x |
334
- | 4 | 8192 | 4346.92 | 1125.33 | 1135.36 | 3.83x |
335
- | 8 | 4096 | 4347.47 | 1104.27 | 1119.63 | 3.88x |
336
- | 8 | 8192 | 8226.50 | 1172.66 | 1197.68 | 6.87x |
337
-
338
- </details>
339
-
340
- <details>
341
- <summary>Forward + Backward Combined</summary>
342
-
343
- | batch_size | seq_len | Naive (us) | Compiled (us) | CUDA (us) | CUDA vs Naive | CUDA vs Compiled |
344
- |-----------|---------|-----------|--------------|------------|-----------------|-------------------|
345
- | 1 | 1024 | 1922.30 | 1072.59 | 1178.36 | 1.63x | 0.91x |
346
- | 1 | 4096 | 2367.77 | 1208.69 | 1192.07 | 1.99x | 1.01x |
347
- | 1 | 8192 | 2398.19 | 1219.32 | 1183.69 | 2.03x | 1.03x |
348
- | 2 | 4096 | 2401.39 | 1248.87 | 1154.72 | 2.08x | 1.08x |
349
- | 2 | 8192 | 3157.18 | 1249.92 | 1186.77 | 2.66x | 1.05x |
350
- | 4 | 4096 | 3161.33 | 1232.84 | 1192.13 | 2.65x | 1.03x |
351
- | 4 | 8192 | 5852.88 | 1372.10 | 1235.86 | 4.74x | 1.11x |
352
- | 8 | 4096 | 5855.34 | 1351.11 | 1219.85 | 4.80x | 1.11x |
353
- | 8 | 8192 | 11083.43 | 1649.00 | 1382.07 | 8.02x | 1.19x |
354
-
355
- </details>
356
-
357
  ## Pre-commit Hooks
358
 
359
  This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
 
251
  > | Forward | 0.7 ms | 2.1 ms | **3.0x** |
252
  > | Backward | 1.4 ms | 3.7 ms | **2.6x** |
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  ## Pre-commit Hooks
255
 
256
  This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
activation/grouped_poly_norm.cu CHANGED
@@ -18,7 +18,8 @@ __device__ __forceinline__ int find_expert(const int32_t *__restrict__ offsets,
18
  int lo = 0, hi = num_experts;
19
  #pragma unroll 6
20
  for (int i = 0; i < 12; ++i) {
21
- if (lo >= hi) break;
 
22
  int mid = (lo + hi) >> 1;
23
  if (offsets[mid] <= row)
24
  lo = mid + 1;
@@ -53,7 +54,8 @@ __device__ __forceinline__ float4 block_reduce_f4(float4 v) {
53
  const int warp_id = threadIdx.x / WARP_SIZE;
54
  const int lane_id = threadIdx.x % WARP_SIZE;
55
 
56
- if (lane_id == 0) warp_results[warp_id] = v;
 
57
  __syncthreads();
58
 
59
  if (warp_id == 0 && lane_id < NUM_WARPS)
@@ -61,10 +63,12 @@ __device__ __forceinline__ float4 block_reduce_f4(float4 v) {
61
  else
62
  v = make_float4(0.f, 0.f, 0.f, 0.f);
63
 
64
- if (warp_id == 0) v = warp_reduce_f4(v);
 
65
 
66
  __shared__ float4 result;
67
- if (threadIdx.x == 0) result = v;
 
68
  __syncthreads();
69
  return result;
70
  }
@@ -77,17 +81,14 @@ __device__ __forceinline__ float4 block_reduce_f4(float4 v) {
77
  template <typename scalar_t, typename acc_t, int width, int BLOCK_SIZE>
78
  __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
79
  grouped_poly_norm_fwd_kernel(
80
- scalar_t *__restrict__ output,
81
- acc_t *__restrict__ inv_rms,
82
- const scalar_t *__restrict__ input,
83
- const scalar_t *__restrict__ mul,
84
- const scalar_t *__restrict__ weight,
85
- const scalar_t *__restrict__ bias,
86
  const int32_t *__restrict__ offsets,
87
- const float *__restrict__ scores, // nullable, always fp32
88
  const acc_t eps, const int D, const int num_experts,
89
  const int expert_offset,
90
- const acc_t hidden_clamp) { // < 0 = disabled
91
  using v_t = vec_t<scalar_t, width>;
92
  const bool do_clamp = (hidden_clamp >= acc_t(0));
93
 
@@ -111,7 +112,8 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
111
  #pragma unroll
112
  for (int j = 0; j < width; ++j) {
113
  acc_t x = xv.data[j];
114
- if (do_clamp) x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
 
115
  acc_t x2 = x * x;
116
  s2 += x2;
117
  s4 += x2 * x2;
@@ -148,14 +150,17 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
148
  #pragma unroll
149
  for (int j = 0; j < width; ++j) {
150
  acc_t x = xv.data[j];
151
- if (do_clamp) x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
 
152
  acc_t m = (acc_t)mv.data[j];
153
- if (do_clamp) m = fminf(fmaxf(m, -hidden_clamp), hidden_clamp);
 
154
  acc_t x2 = x * x;
155
  acc_t x3 = x2 * x;
156
  acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
157
  acc_t out_val = poly * m * score;
158
- if (do_clamp) out_val = fminf(fmaxf(out_val, -hidden_clamp), hidden_clamp);
 
159
  ov.data[j] = (scalar_t)out_val;
160
  }
161
  out_v[i] = ov;
@@ -164,34 +169,33 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
164
 
165
  // Scalar fallback forward
166
  template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
167
- __global__ void __launch_bounds__(BLOCK_SIZE)
168
- grouped_poly_norm_fwd_scalar(
169
- scalar_t *__restrict__ output,
170
- acc_t *__restrict__ inv_rms,
171
- const scalar_t *__restrict__ input,
172
- const scalar_t *__restrict__ mul,
173
- const scalar_t *__restrict__ weight,
174
- const scalar_t *__restrict__ bias,
175
- const int32_t *__restrict__ offsets,
176
- const float *__restrict__ scores, // nullable, always fp32
177
- const acc_t eps, const int D, const int num_experts,
178
- const int expert_offset,
179
- const acc_t hidden_clamp) {
180
  const bool do_clamp = (hidden_clamp >= acc_t(0));
181
  const int row = blockIdx.x;
182
  const int64_t off = (int64_t)row * D;
183
 
184
  const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
185
- const acc_t w0 = weight[eidx * 3], w1 = weight[eidx * 3 + 1], w2 = weight[eidx * 3 + 2];
 
186
  const acc_t b_val = bias[eidx];
187
  const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);
188
 
189
  acc_t s2 = 0, s4 = 0, s6 = 0;
190
  for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
191
  acc_t x = input[off + i];
192
- if (do_clamp) x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
 
193
  acc_t x2 = x * x;
194
- s2 += x2; s4 += x2 * x2; s6 += x2 * x2 * x2;
 
 
195
  }
196
 
197
  float4 sums = block_reduce_f4<BLOCK_SIZE>(make_float4(s2, s4, s6, 0.f));
@@ -201,19 +205,24 @@ __global__ void __launch_bounds__(BLOCK_SIZE)
201
  const acc_t ir3 = rsqrtf(sums.z * inv_d + eps);
202
 
203
  if (threadIdx.x == 0) {
204
- inv_rms[row * 3] = ir1; inv_rms[row * 3 + 1] = ir2; inv_rms[row * 3 + 2] = ir3;
 
 
205
  }
206
 
207
  const acc_t w2ir1 = w2 * ir1, w1ir2 = w1 * ir2, w0ir3 = w0 * ir3;
208
  for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
209
  acc_t x = input[off + i];
210
- if (do_clamp) x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
 
211
  acc_t m = (acc_t)mul[off + i];
212
- if (do_clamp) m = fminf(fmaxf(m, -hidden_clamp), hidden_clamp);
 
213
  acc_t x2 = x * x, x3 = x2 * x;
214
  acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
215
  acc_t out_val = poly * m * score;
216
- if (do_clamp) out_val = fminf(fmaxf(out_val, -hidden_clamp), hidden_clamp);
 
217
  output[off + i] = (scalar_t)out_val;
218
  }
219
  }
@@ -225,22 +234,17 @@ __global__ void __launch_bounds__(BLOCK_SIZE)
225
  template <typename scalar_t, typename acc_t, int width, int BLOCK_SIZE>
226
  __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
227
  grouped_poly_norm_bwd_kernel(
228
- scalar_t *__restrict__ grad_input,
229
- scalar_t *__restrict__ grad_mul,
230
- float *__restrict__ weight_grad, // [num_total_experts, 3] fp32
231
- float *__restrict__ bias_grad, // [num_total_experts] fp32
232
  const scalar_t *__restrict__ grad_output,
233
- const scalar_t *__restrict__ input,
234
- const scalar_t *__restrict__ mul,
235
- const scalar_t *__restrict__ weight,
236
- const scalar_t *__restrict__ bias,
237
- const int32_t *__restrict__ offsets,
238
- const acc_t *__restrict__ inv_rms,
239
- const float *__restrict__ scores, // nullable, always fp32
240
- acc_t *__restrict__ grad_scores, // nullable (null when scores is null)
241
  const acc_t eps, const int D, const int num_experts,
242
- const int expert_offset,
243
- const acc_t hidden_clamp) {
244
  using v_t = vec_t<scalar_t, width>;
245
  const bool do_clamp = (hidden_clamp >= acc_t(0));
246
 
@@ -249,7 +253,8 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
249
  const int64_t base = (int64_t)row * vec_d;
250
 
251
  const v_t *__restrict__ in_v = reinterpret_cast<const v_t *>(input) + base;
252
- const v_t *__restrict__ go_v = reinterpret_cast<const v_t *>(grad_output) + base;
 
253
  const v_t *__restrict__ m_v = reinterpret_cast<const v_t *>(mul) + base;
254
 
255
  const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
@@ -278,9 +283,11 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
278
  #pragma unroll
279
  for (int j = 0; j < width; ++j) {
280
  acc_t x_orig = xv.data[j];
281
- acc_t x = do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
 
282
  acc_t m_orig = (acc_t)mv.data[j];
283
- acc_t m = do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
 
284
  acc_t go = (acc_t)gv.data[j];
285
 
286
  // Output clamp mask: recompute pre-clamp output
@@ -288,7 +295,8 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
288
  acc_t x2 = x * x, x3 = x2 * x;
289
  acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
290
  acc_t out_pre = poly * m * score;
291
- if (fabsf(out_pre) > hidden_clamp) go = acc_t(0);
 
292
  }
293
 
294
  acc_t x2 = x * x;
@@ -300,7 +308,8 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
300
  }
301
  }
302
 
303
- float4 sums = block_reduce_f4<BLOCK_SIZE>(make_float4(sdpx, sdpx2, sdpx3, sdp));
 
304
 
305
  const acc_t inv_d = acc_t(1) / D;
306
  const acc_t s1 = sums.x * inv_d;
@@ -327,9 +336,11 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
327
  #pragma unroll
328
  for (int j = 0; j < width; ++j) {
329
  acc_t x_orig = xv.data[j];
330
- acc_t x = do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
 
331
  acc_t m_orig = (acc_t)mv.data[j];
332
- acc_t m = do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
 
333
  acc_t x2 = x * x;
334
  acc_t x3 = x2 * x;
335
  acc_t go = (acc_t)gv.data[j];
@@ -338,27 +349,30 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
338
  acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1;
339
  if (do_clamp) {
340
  acc_t out_pre = (poly + b_val) * m * score;
341
- if (fabsf(out_pre) > hidden_clamp) go = acc_t(0);
 
342
  }
343
 
344
  acc_t dp = go * m * score;
345
 
346
  // grad_mul with mul clamp mask
347
  acc_t gm_val = go * (poly + b_val) * score;
348
- if (do_clamp && fabsf(m_orig) > hidden_clamp) gm_val = acc_t(0);
 
349
  gm.data[j] = (scalar_t)gm_val;
350
 
351
  // grad_input with input clamp mask
352
  acc_t g = ir1 * (w2 * dp - x * cx);
353
  g += acc_t(2) * x * ir2 * (w1 * dp - x2 * cx2);
354
  g += acc_t(3) * x2 * ir3 * (w0 * dp - x3 * cx3);
355
- if (do_clamp && fabsf(x_orig) > hidden_clamp) g = acc_t(0);
 
356
  gi.data[j] = (scalar_t)g;
357
 
358
  dw0 += dp * x3 * ir3;
359
  dw1 += dp * x2 * ir2;
360
  dw2 += dp * x * ir1;
361
- gs_acc += go * (poly + b_val) * m; // grad_scores accumulator
362
  }
363
 
364
  gi_v[i] = gi;
@@ -383,32 +397,27 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
383
  // Scalar fallback (width == 0)
384
  // ---------------------------------------------------------------------------
385
  template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
386
- __global__ void __launch_bounds__(BLOCK_SIZE)
387
- grouped_poly_norm_bwd_scalar(
388
- scalar_t *__restrict__ grad_input,
389
- scalar_t *__restrict__ grad_mul,
390
- float *__restrict__ weight_grad,
391
- float *__restrict__ bias_grad,
392
- const scalar_t *__restrict__ grad_output,
393
- const scalar_t *__restrict__ input,
394
- const scalar_t *__restrict__ mul,
395
- const scalar_t *__restrict__ weight,
396
- const scalar_t *__restrict__ bias,
397
- const int32_t *__restrict__ offsets,
398
- const acc_t *__restrict__ inv_rms,
399
- const float *__restrict__ scores, // nullable, always fp32
400
- acc_t *__restrict__ grad_scores, // nullable
401
- const acc_t eps, const int D, const int num_experts,
402
- const int expert_offset,
403
- const acc_t hidden_clamp) {
404
  const bool do_clamp = (hidden_clamp >= acc_t(0));
405
  const int row = blockIdx.x;
406
  const int64_t off = (int64_t)row * D;
407
 
408
  const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
409
- const acc_t w0 = weight[eidx * 3], w1 = weight[eidx * 3 + 1], w2 = weight[eidx * 3 + 2];
 
410
  const acc_t b_val = bias[eidx];
411
- const acc_t ir1 = inv_rms[row * 3], ir2 = inv_rms[row * 3 + 1], ir3 = inv_rms[row * 3 + 2];
 
412
  const acc_t w2ir1 = w2 * ir1, w1ir2 = w1 * ir2, w0ir3 = w0 * ir3;
413
  const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);
414
 
@@ -416,34 +425,44 @@ __global__ void __launch_bounds__(BLOCK_SIZE)
416
  acc_t sdpx = 0, sdpx2 = 0, sdpx3 = 0, sdp = 0;
417
  for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
418
  acc_t x_orig = input[off + i];
419
- acc_t x = do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
 
420
  acc_t m_orig = (acc_t)mul[off + i];
421
- acc_t m = do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
 
422
  acc_t go = (acc_t)grad_output[off + i];
423
 
424
  if (do_clamp) {
425
  acc_t x2 = x * x, x3 = x2 * x;
426
  acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
427
  acc_t out_pre = poly * m * score;
428
- if (fabsf(out_pre) > hidden_clamp) go = acc_t(0);
 
429
  }
430
 
431
  acc_t x2 = x * x;
432
  acc_t dp = go * m * score;
433
- sdp += dp; sdpx += dp * x; sdpx2 += dp * x2; sdpx3 += dp * x2 * x;
 
 
 
434
  }
435
 
436
- float4 sums = block_reduce_f4<BLOCK_SIZE>(make_float4(sdpx, sdpx2, sdpx3, sdp));
 
437
  const acc_t inv_d = acc_t(1) / D;
438
  const acc_t s1 = sums.x * inv_d, s2 = sums.y * inv_d, s3 = sums.z * inv_d;
439
- const acc_t cx = w2 * s1 * ir1 * ir1, cx2 = w1 * s2 * ir2 * ir2, cx3 = w0 * s3 * ir3 * ir3;
 
440
 
441
  // Pass 2: grads with clamp masks
442
  acc_t dw0 = 0, dw1 = 0, dw2 = 0, gs_acc = 0;
443
  for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
444
  acc_t x_orig = input[off + i], m_orig = (acc_t)mul[off + i];
445
- acc_t x = do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
446
- acc_t m = do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
 
 
447
  acc_t go = (acc_t)grad_output[off + i];
448
  acc_t x2 = x * x, x3 = x2 * x;
449
  acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1;
@@ -451,21 +470,27 @@ __global__ void __launch_bounds__(BLOCK_SIZE)
451
  // Output clamp mask
452
  if (do_clamp) {
453
  acc_t out_pre = (poly + b_val) * m * score;
454
- if (fabsf(out_pre) > hidden_clamp) go = acc_t(0);
 
455
  }
456
 
457
  acc_t dp = go * m * score;
458
 
459
  acc_t gm_val = go * (poly + b_val) * score;
460
- if (do_clamp && fabsf(m_orig) > hidden_clamp) gm_val = acc_t(0);
 
461
  grad_mul[off + i] = (scalar_t)gm_val;
462
 
463
- acc_t g = ir1 * (w2 * dp - x * cx) + acc_t(2) * x * ir2 * (w1 * dp - x2 * cx2)
464
- + acc_t(3) * x2 * ir3 * (w0 * dp - x3 * cx3);
465
- if (do_clamp && fabsf(x_orig) > hidden_clamp) g = acc_t(0);
 
 
466
  grad_input[off + i] = (scalar_t)g;
467
 
468
- dw0 += dp * x3 * ir3; dw1 += dp * x2 * ir2; dw2 += dp * x * ir1;
 
 
469
  gs_acc += go * (poly + b_val) * m;
470
  }
471
 
@@ -487,28 +512,27 @@ __global__ void __launch_bounds__(BLOCK_SIZE)
487
  // Internal helpers — shared kernel dispatch
488
  // ---------------------------------------------------------------------------
489
  #define FWD_LAUNCH(width_val, scalar_type_name) \
490
- MOTIF_DISPATCH_FLOATING_TYPES( \
491
- input.scalar_type(), scalar_type_name, [&] { \
492
- motif::grouped_poly_norm_fwd_kernel<scalar_t, float, width_val, BLOCK> \
493
- <<<grid, block, 0, stream>>>( \
494
- output.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(), \
495
- input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(), \
496
- weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), \
497
- offsets.data_ptr<int32_t>(), scores_ptr, \
498
- (float)eps, D, num_experts, (int)expert_offset, \
499
- (float)hidden_clamp); \
500
- })
501
 
502
  static std::tuple<torch::Tensor, torch::Tensor>
503
  _fwd_impl(const torch::Tensor &input, const torch::Tensor &mul,
504
- const torch::Tensor &weight, const torch::Tensor &bias,
505
- const torch::Tensor &offsets, const float *scores_ptr,
506
- double eps, int64_t expert_offset, double hidden_clamp) {
507
  const int D = input.size(-1);
508
  const int64_t N = input.size(0);
509
  const int num_experts = offsets.size(0);
510
  constexpr int BLOCK = 128;
511
- dim3 grid(N); dim3 block(BLOCK);
 
512
 
513
  auto output = torch::empty_like(input);
514
  auto inv_rms = torch::empty({N, 3}, input.options().dtype(torch::kFloat));
@@ -527,9 +551,8 @@ _fwd_impl(const torch::Tensor &input, const torch::Tensor &mul,
527
  output.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(),
528
  input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(),
529
  weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
530
- offsets.data_ptr<int32_t>(), scores_ptr,
531
- (float)eps, D, num_experts, (int)expert_offset,
532
- (float)hidden_clamp);
533
  });
534
  }
535
  return {output, inv_rms};
@@ -537,39 +560,37 @@ _fwd_impl(const torch::Tensor &input, const torch::Tensor &mul,
537
  #undef FWD_LAUNCH
538
 
539
  #define BWD_LAUNCH(width_val, scalar_type_name, kernel_name) \
540
- MOTIF_DISPATCH_FLOATING_TYPES( \
541
- input.scalar_type(), scalar_type_name, [&] { \
542
- motif::kernel_name<scalar_t, float, width_val, BLOCK> \
543
- <<<grid, block, 0, stream>>>( \
544
- input_grad.data_ptr<scalar_t>(), \
545
- mul_grad.data_ptr<scalar_t>(), \
546
- wg_f32.data_ptr<float>(), bg_f32.data_ptr<float>(), \
547
- grad_output.data_ptr<scalar_t>(), \
548
- input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(), \
549
- weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), \
550
- offsets.data_ptr<int32_t>(), inv_rms.data_ptr<float>(), \
551
- scores_ptr, gs_ptr, \
552
- (float)eps, D, num_experts, (int)expert_offset, \
553
- (float)hidden_clamp); \
554
- })
555
 
556
  static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
557
- torch::Tensor>
558
  _bwd_impl(const torch::Tensor &grad_output, const torch::Tensor &input,
559
- const torch::Tensor &mul, const torch::Tensor &weight,
560
- const torch::Tensor &bias, const torch::Tensor &offsets,
561
- const torch::Tensor &inv_rms, const float *scores_ptr,
562
- float *gs_ptr, int64_t N,
563
- double eps, int64_t expert_offset, double hidden_clamp) {
564
  const int D = input.size(-1);
565
  const int num_experts = offsets.size(0);
566
  constexpr int BLOCK = 128;
567
- dim3 grid(N); dim3 block(BLOCK);
 
568
 
569
  auto input_grad = torch::empty_like(input);
570
  auto mul_grad = torch::empty_like(mul);
571
- auto wg_f32 = torch::zeros({weight.size(0), 3}, input.options().dtype(torch::kFloat));
572
- auto bg_f32 = torch::zeros({bias.size(0)}, input.options().dtype(torch::kFloat));
 
 
573
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
574
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
575
 
@@ -583,15 +604,13 @@ _bwd_impl(const torch::Tensor &grad_output, const torch::Tensor &input,
583
  motif::grouped_poly_norm_bwd_scalar<scalar_t, float, BLOCK>
584
  <<<grid, block, 0, stream>>>(
585
  input_grad.data_ptr<scalar_t>(),
586
- mul_grad.data_ptr<scalar_t>(),
587
- wg_f32.data_ptr<float>(), bg_f32.data_ptr<float>(),
588
- grad_output.data_ptr<scalar_t>(),
589
  input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(),
590
  weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
591
  offsets.data_ptr<int32_t>(), inv_rms.data_ptr<float>(),
592
- scores_ptr, gs_ptr,
593
- (float)eps, D, num_experts, (int)expert_offset,
594
- (float)hidden_clamp);
595
  });
596
  }
597
 
@@ -606,54 +625,56 @@ _bwd_impl(const torch::Tensor &grad_output, const torch::Tensor &input,
606
  // Public API: without scores
607
  // ---------------------------------------------------------------------------
608
  std::tuple<torch::Tensor, torch::Tensor>
609
- grouped_poly_norm_forward(
610
- const torch::Tensor &input, const torch::Tensor &mul,
611
- const torch::Tensor &weight, const torch::Tensor &bias,
612
- const torch::Tensor &offsets, double eps, int64_t expert_offset,
613
- double hidden_clamp) {
614
- return _fwd_impl(input, mul, weight, bias, offsets, nullptr, eps, expert_offset, hidden_clamp);
 
615
  }
616
 
617
  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
618
- grouped_poly_norm_backward(
619
- const torch::Tensor &grad_output, const torch::Tensor &input,
620
- const torch::Tensor &mul, const torch::Tensor &weight,
621
- const torch::Tensor &bias, const torch::Tensor &offsets,
622
- const torch::Tensor &inv_rms, double eps, int64_t expert_offset,
623
- double hidden_clamp) {
 
624
  const int64_t N = input.size(0);
625
- auto [ig, mg, wg, bg, _] = _bwd_impl(
626
- grad_output, input, mul, weight, bias, offsets, inv_rms,
627
- nullptr, nullptr, N, eps, expert_offset, hidden_clamp);
628
  return {ig, mg, wg, bg};
629
  }
630
 
631
  // ---------------------------------------------------------------------------
632
  // Public API: with scores
633
  // ---------------------------------------------------------------------------
634
- std::tuple<torch::Tensor, torch::Tensor>
635
- grouped_poly_norm_forward_scored(
636
  const torch::Tensor &input, const torch::Tensor &mul,
637
  const torch::Tensor &weight, const torch::Tensor &bias,
638
- const torch::Tensor &offsets, const torch::Tensor &scores,
639
- double eps, int64_t expert_offset, double hidden_clamp) {
640
- return _fwd_impl(input, mul, weight, bias, offsets,
641
- scores.data_ptr<float>(), eps, expert_offset, hidden_clamp);
642
  }
643
 
644
- std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
 
645
  grouped_poly_norm_backward_scored(
646
  const torch::Tensor &grad_output, const torch::Tensor &input,
647
  const torch::Tensor &mul, const torch::Tensor &weight,
648
  const torch::Tensor &bias, const torch::Tensor &offsets,
649
- const torch::Tensor &inv_rms, const torch::Tensor &scores,
650
- double eps, int64_t expert_offset, double hidden_clamp) {
651
  const int64_t N = input.size(0);
652
  auto gs_f32 = torch::empty({N}, input.options().dtype(torch::kFloat));
653
- auto [ig, mg, wg, bg, _] = _bwd_impl(
654
- grad_output, input, mul, weight, bias, offsets, inv_rms,
655
- scores.data_ptr<float>(), gs_f32.data_ptr<float>(), N,
656
- eps, expert_offset, hidden_clamp);
657
  auto gs = gs_f32.unsqueeze(-1);
658
  return {ig, mg, wg, bg, gs};
659
  }
 
18
  int lo = 0, hi = num_experts;
19
  #pragma unroll 6
20
  for (int i = 0; i < 12; ++i) {
21
+ if (lo >= hi)
22
+ break;
23
  int mid = (lo + hi) >> 1;
24
  if (offsets[mid] <= row)
25
  lo = mid + 1;
 
54
  const int warp_id = threadIdx.x / WARP_SIZE;
55
  const int lane_id = threadIdx.x % WARP_SIZE;
56
 
57
+ if (lane_id == 0)
58
+ warp_results[warp_id] = v;
59
  __syncthreads();
60
 
61
  if (warp_id == 0 && lane_id < NUM_WARPS)
 
63
  else
64
  v = make_float4(0.f, 0.f, 0.f, 0.f);
65
 
66
+ if (warp_id == 0)
67
+ v = warp_reduce_f4(v);
68
 
69
  __shared__ float4 result;
70
+ if (threadIdx.x == 0)
71
+ result = v;
72
  __syncthreads();
73
  return result;
74
  }
 
81
  template <typename scalar_t, typename acc_t, int width, int BLOCK_SIZE>
82
  __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
83
  grouped_poly_norm_fwd_kernel(
84
+ scalar_t *__restrict__ output, acc_t *__restrict__ inv_rms,
85
+ const scalar_t *__restrict__ input, const scalar_t *__restrict__ mul,
86
+ const scalar_t *__restrict__ weight, const scalar_t *__restrict__ bias,
 
 
 
87
  const int32_t *__restrict__ offsets,
88
+ const float *__restrict__ scores, // nullable, always fp32
89
  const acc_t eps, const int D, const int num_experts,
90
  const int expert_offset,
91
+ const acc_t hidden_clamp) { // < 0 = disabled
92
  using v_t = vec_t<scalar_t, width>;
93
  const bool do_clamp = (hidden_clamp >= acc_t(0));
94
 
 
112
  #pragma unroll
113
  for (int j = 0; j < width; ++j) {
114
  acc_t x = xv.data[j];
115
+ if (do_clamp)
116
+ x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
117
  acc_t x2 = x * x;
118
  s2 += x2;
119
  s4 += x2 * x2;
 
150
  #pragma unroll
151
  for (int j = 0; j < width; ++j) {
152
  acc_t x = xv.data[j];
153
+ if (do_clamp)
154
+ x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
155
  acc_t m = (acc_t)mv.data[j];
156
+ if (do_clamp)
157
+ m = fminf(fmaxf(m, -hidden_clamp), hidden_clamp);
158
  acc_t x2 = x * x;
159
  acc_t x3 = x2 * x;
160
  acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
161
  acc_t out_val = poly * m * score;
162
+ if (do_clamp)
163
+ out_val = fminf(fmaxf(out_val, -hidden_clamp), hidden_clamp);
164
  ov.data[j] = (scalar_t)out_val;
165
  }
166
  out_v[i] = ov;
 
169
 
170
  // Scalar fallback forward
171
  template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
172
+ __global__ void __launch_bounds__(BLOCK_SIZE) grouped_poly_norm_fwd_scalar(
173
+ scalar_t *__restrict__ output, acc_t *__restrict__ inv_rms,
174
+ const scalar_t *__restrict__ input, const scalar_t *__restrict__ mul,
175
+ const scalar_t *__restrict__ weight, const scalar_t *__restrict__ bias,
176
+ const int32_t *__restrict__ offsets,
177
+ const float *__restrict__ scores, // nullable, always fp32
178
+ const acc_t eps, const int D, const int num_experts,
179
+ const int expert_offset, const acc_t hidden_clamp) {
 
 
 
 
 
180
  const bool do_clamp = (hidden_clamp >= acc_t(0));
181
  const int row = blockIdx.x;
182
  const int64_t off = (int64_t)row * D;
183
 
184
  const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
185
+ const acc_t w0 = weight[eidx * 3], w1 = weight[eidx * 3 + 1],
186
+ w2 = weight[eidx * 3 + 2];
187
  const acc_t b_val = bias[eidx];
188
  const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);
189
 
190
  acc_t s2 = 0, s4 = 0, s6 = 0;
191
  for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
192
  acc_t x = input[off + i];
193
+ if (do_clamp)
194
+ x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
195
  acc_t x2 = x * x;
196
+ s2 += x2;
197
+ s4 += x2 * x2;
198
+ s6 += x2 * x2 * x2;
199
  }
200
 
201
  float4 sums = block_reduce_f4<BLOCK_SIZE>(make_float4(s2, s4, s6, 0.f));
 
205
  const acc_t ir3 = rsqrtf(sums.z * inv_d + eps);
206
 
207
  if (threadIdx.x == 0) {
208
+ inv_rms[row * 3] = ir1;
209
+ inv_rms[row * 3 + 1] = ir2;
210
+ inv_rms[row * 3 + 2] = ir3;
211
  }
212
 
213
  const acc_t w2ir1 = w2 * ir1, w1ir2 = w1 * ir2, w0ir3 = w0 * ir3;
214
  for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
215
  acc_t x = input[off + i];
216
+ if (do_clamp)
217
+ x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
218
  acc_t m = (acc_t)mul[off + i];
219
+ if (do_clamp)
220
+ m = fminf(fmaxf(m, -hidden_clamp), hidden_clamp);
221
  acc_t x2 = x * x, x3 = x2 * x;
222
  acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
223
  acc_t out_val = poly * m * score;
224
+ if (do_clamp)
225
+ out_val = fminf(fmaxf(out_val, -hidden_clamp), hidden_clamp);
226
  output[off + i] = (scalar_t)out_val;
227
  }
228
  }
 
234
  template <typename scalar_t, typename acc_t, int width, int BLOCK_SIZE>
235
  __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
236
  grouped_poly_norm_bwd_kernel(
237
+ scalar_t *__restrict__ grad_input, scalar_t *__restrict__ grad_mul,
238
+ float *__restrict__ weight_grad, // [num_total_experts, 3] fp32
239
+ float *__restrict__ bias_grad, // [num_total_experts] fp32
 
240
  const scalar_t *__restrict__ grad_output,
241
+ const scalar_t *__restrict__ input, const scalar_t *__restrict__ mul,
242
+ const scalar_t *__restrict__ weight, const scalar_t *__restrict__ bias,
243
+ const int32_t *__restrict__ offsets, const acc_t *__restrict__ inv_rms,
244
+ const float *__restrict__ scores, // nullable, always fp32
245
+ acc_t *__restrict__ grad_scores, // nullable (null when scores is null)
 
 
 
246
  const acc_t eps, const int D, const int num_experts,
247
+ const int expert_offset, const acc_t hidden_clamp) {
 
248
  using v_t = vec_t<scalar_t, width>;
249
  const bool do_clamp = (hidden_clamp >= acc_t(0));
250
 
 
253
  const int64_t base = (int64_t)row * vec_d;
254
 
255
  const v_t *__restrict__ in_v = reinterpret_cast<const v_t *>(input) + base;
256
+ const v_t *__restrict__ go_v =
257
+ reinterpret_cast<const v_t *>(grad_output) + base;
258
  const v_t *__restrict__ m_v = reinterpret_cast<const v_t *>(mul) + base;
259
 
260
  const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
 
283
  #pragma unroll
284
  for (int j = 0; j < width; ++j) {
285
  acc_t x_orig = xv.data[j];
286
+ acc_t x =
287
+ do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
288
  acc_t m_orig = (acc_t)mv.data[j];
289
+ acc_t m =
290
+ do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
291
  acc_t go = (acc_t)gv.data[j];
292
 
293
  // Output clamp mask: recompute pre-clamp output
 
295
  acc_t x2 = x * x, x3 = x2 * x;
296
  acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
297
  acc_t out_pre = poly * m * score;
298
+ if (fabsf(out_pre) > hidden_clamp)
299
+ go = acc_t(0);
300
  }
301
 
302
  acc_t x2 = x * x;
 
308
  }
309
  }
310
 
311
+ float4 sums =
312
+ block_reduce_f4<BLOCK_SIZE>(make_float4(sdpx, sdpx2, sdpx3, sdp));
313
 
314
  const acc_t inv_d = acc_t(1) / D;
315
  const acc_t s1 = sums.x * inv_d;
 
336
  #pragma unroll
337
  for (int j = 0; j < width; ++j) {
338
  acc_t x_orig = xv.data[j];
339
+ acc_t x =
340
+ do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
341
  acc_t m_orig = (acc_t)mv.data[j];
342
+ acc_t m =
343
+ do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
344
  acc_t x2 = x * x;
345
  acc_t x3 = x2 * x;
346
  acc_t go = (acc_t)gv.data[j];
 
349
  acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1;
350
  if (do_clamp) {
351
  acc_t out_pre = (poly + b_val) * m * score;
352
+ if (fabsf(out_pre) > hidden_clamp)
353
+ go = acc_t(0);
354
  }
355
 
356
  acc_t dp = go * m * score;
357
 
358
  // grad_mul with mul clamp mask
359
  acc_t gm_val = go * (poly + b_val) * score;
360
+ if (do_clamp && fabsf(m_orig) > hidden_clamp)
361
+ gm_val = acc_t(0);
362
  gm.data[j] = (scalar_t)gm_val;
363
 
364
  // grad_input with input clamp mask
365
  acc_t g = ir1 * (w2 * dp - x * cx);
366
  g += acc_t(2) * x * ir2 * (w1 * dp - x2 * cx2);
367
  g += acc_t(3) * x2 * ir3 * (w0 * dp - x3 * cx3);
368
+ if (do_clamp && fabsf(x_orig) > hidden_clamp)
369
+ g = acc_t(0);
370
  gi.data[j] = (scalar_t)g;
371
 
372
  dw0 += dp * x3 * ir3;
373
  dw1 += dp * x2 * ir2;
374
  dw2 += dp * x * ir1;
375
+ gs_acc += go * (poly + b_val) * m; // grad_scores accumulator
376
  }
377
 
378
  gi_v[i] = gi;
 
397
  // Scalar fallback (width == 0)
398
  // ---------------------------------------------------------------------------
399
  template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
400
+ __global__ void __launch_bounds__(BLOCK_SIZE) grouped_poly_norm_bwd_scalar(
401
+ scalar_t *__restrict__ grad_input, scalar_t *__restrict__ grad_mul,
402
+ float *__restrict__ weight_grad, float *__restrict__ bias_grad,
403
+ const scalar_t *__restrict__ grad_output,
404
+ const scalar_t *__restrict__ input, const scalar_t *__restrict__ mul,
405
+ const scalar_t *__restrict__ weight, const scalar_t *__restrict__ bias,
406
+ const int32_t *__restrict__ offsets, const acc_t *__restrict__ inv_rms,
407
+ const float *__restrict__ scores, // nullable, always fp32
408
+ acc_t *__restrict__ grad_scores, // nullable
409
+ const acc_t eps, const int D, const int num_experts,
410
+ const int expert_offset, const acc_t hidden_clamp) {
 
 
 
 
 
 
 
411
  const bool do_clamp = (hidden_clamp >= acc_t(0));
412
  const int row = blockIdx.x;
413
  const int64_t off = (int64_t)row * D;
414
 
415
  const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
416
+ const acc_t w0 = weight[eidx * 3], w1 = weight[eidx * 3 + 1],
417
+ w2 = weight[eidx * 3 + 2];
418
  const acc_t b_val = bias[eidx];
419
+ const acc_t ir1 = inv_rms[row * 3], ir2 = inv_rms[row * 3 + 1],
420
+ ir3 = inv_rms[row * 3 + 2];
421
  const acc_t w2ir1 = w2 * ir1, w1ir2 = w1 * ir2, w0ir3 = w0 * ir3;
422
  const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);
423
 
 
425
  acc_t sdpx = 0, sdpx2 = 0, sdpx3 = 0, sdp = 0;
426
  for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
427
  acc_t x_orig = input[off + i];
428
+ acc_t x =
429
+ do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
430
  acc_t m_orig = (acc_t)mul[off + i];
431
+ acc_t m =
432
+ do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
433
  acc_t go = (acc_t)grad_output[off + i];
434
 
435
  if (do_clamp) {
436
  acc_t x2 = x * x, x3 = x2 * x;
437
  acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
438
  acc_t out_pre = poly * m * score;
439
+ if (fabsf(out_pre) > hidden_clamp)
440
+ go = acc_t(0);
441
  }
442
 
443
  acc_t x2 = x * x;
444
  acc_t dp = go * m * score;
445
+ sdp += dp;
446
+ sdpx += dp * x;
447
+ sdpx2 += dp * x2;
448
+ sdpx3 += dp * x2 * x;
449
  }
450
 
451
+ float4 sums =
452
+ block_reduce_f4<BLOCK_SIZE>(make_float4(sdpx, sdpx2, sdpx3, sdp));
453
  const acc_t inv_d = acc_t(1) / D;
454
  const acc_t s1 = sums.x * inv_d, s2 = sums.y * inv_d, s3 = sums.z * inv_d;
455
+ const acc_t cx = w2 * s1 * ir1 * ir1, cx2 = w1 * s2 * ir2 * ir2,
456
+ cx3 = w0 * s3 * ir3 * ir3;
457
 
458
  // Pass 2: grads with clamp masks
459
  acc_t dw0 = 0, dw1 = 0, dw2 = 0, gs_acc = 0;
460
  for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
461
  acc_t x_orig = input[off + i], m_orig = (acc_t)mul[off + i];
462
+ acc_t x =
463
+ do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
464
+ acc_t m =
465
+ do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
466
  acc_t go = (acc_t)grad_output[off + i];
467
  acc_t x2 = x * x, x3 = x2 * x;
468
  acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1;
 
470
  // Output clamp mask
471
  if (do_clamp) {
472
  acc_t out_pre = (poly + b_val) * m * score;
473
+ if (fabsf(out_pre) > hidden_clamp)
474
+ go = acc_t(0);
475
  }
476
 
477
  acc_t dp = go * m * score;
478
 
479
  acc_t gm_val = go * (poly + b_val) * score;
480
+ if (do_clamp && fabsf(m_orig) > hidden_clamp)
481
+ gm_val = acc_t(0);
482
  grad_mul[off + i] = (scalar_t)gm_val;
483
 
484
+ acc_t g = ir1 * (w2 * dp - x * cx) +
485
+ acc_t(2) * x * ir2 * (w1 * dp - x2 * cx2) +
486
+ acc_t(3) * x2 * ir3 * (w0 * dp - x3 * cx3);
487
+ if (do_clamp && fabsf(x_orig) > hidden_clamp)
488
+ g = acc_t(0);
489
  grad_input[off + i] = (scalar_t)g;
490
 
491
+ dw0 += dp * x3 * ir3;
492
+ dw1 += dp * x2 * ir2;
493
+ dw2 += dp * x * ir1;
494
  gs_acc += go * (poly + b_val) * m;
495
  }
496
 
 
512
  // Internal helpers — shared kernel dispatch
513
  // ---------------------------------------------------------------------------
514
  #define FWD_LAUNCH(width_val, scalar_type_name) \
515
+ MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), scalar_type_name, [&] { \
516
+ motif::grouped_poly_norm_fwd_kernel<scalar_t, float, width_val, BLOCK> \
517
+ <<<grid, block, 0, stream>>>( \
518
+ output.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(), \
519
+ input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(), \
520
+ weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), \
521
+ offsets.data_ptr<int32_t>(), scores_ptr, (float)eps, D, \
522
+ num_experts, (int)expert_offset, (float)hidden_clamp); \
523
+ })
 
 
524
 
525
  static std::tuple<torch::Tensor, torch::Tensor>
526
  _fwd_impl(const torch::Tensor &input, const torch::Tensor &mul,
527
+ const torch::Tensor &weight, const torch::Tensor &bias,
528
+ const torch::Tensor &offsets, const float *scores_ptr, double eps,
529
+ int64_t expert_offset, double hidden_clamp) {
530
  const int D = input.size(-1);
531
  const int64_t N = input.size(0);
532
  const int num_experts = offsets.size(0);
533
  constexpr int BLOCK = 128;
534
+ dim3 grid(N);
535
+ dim3 block(BLOCK);
536
 
537
  auto output = torch::empty_like(input);
538
  auto inv_rms = torch::empty({N, 3}, input.options().dtype(torch::kFloat));
 
551
  output.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(),
552
  input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(),
553
  weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
554
+ offsets.data_ptr<int32_t>(), scores_ptr, (float)eps, D,
555
+ num_experts, (int)expert_offset, (float)hidden_clamp);
 
556
  });
557
  }
558
  return {output, inv_rms};
 
560
  #undef FWD_LAUNCH
561
 
562
  #define BWD_LAUNCH(width_val, scalar_type_name, kernel_name) \
563
+ MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), scalar_type_name, [&] { \
564
+ motif::kernel_name<scalar_t, float, width_val, BLOCK> \
565
+ <<<grid, block, 0, stream>>>( \
566
+ input_grad.data_ptr<scalar_t>(), mul_grad.data_ptr<scalar_t>(), \
567
+ wg_f32.data_ptr<float>(), bg_f32.data_ptr<float>(), \
568
+ grad_output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
569
+ mul.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
570
+ bias.data_ptr<scalar_t>(), offsets.data_ptr<int32_t>(), \
571
+ inv_rms.data_ptr<float>(), scores_ptr, gs_ptr, (float)eps, D, \
572
+ num_experts, (int)expert_offset, (float)hidden_clamp); \
573
+ })
 
 
 
 
574
 
575
  static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
576
+ torch::Tensor>
577
  _bwd_impl(const torch::Tensor &grad_output, const torch::Tensor &input,
578
+ const torch::Tensor &mul, const torch::Tensor &weight,
579
+ const torch::Tensor &bias, const torch::Tensor &offsets,
580
+ const torch::Tensor &inv_rms, const float *scores_ptr, float *gs_ptr,
581
+ int64_t N, double eps, int64_t expert_offset, double hidden_clamp) {
 
582
  const int D = input.size(-1);
583
  const int num_experts = offsets.size(0);
584
  constexpr int BLOCK = 128;
585
+ dim3 grid(N);
586
+ dim3 block(BLOCK);
587
 
588
  auto input_grad = torch::empty_like(input);
589
  auto mul_grad = torch::empty_like(mul);
590
+ auto wg_f32 =
591
+ torch::zeros({weight.size(0), 3}, input.options().dtype(torch::kFloat));
592
+ auto bg_f32 =
593
+ torch::zeros({bias.size(0)}, input.options().dtype(torch::kFloat));
594
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
595
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
596
 
 
604
  motif::grouped_poly_norm_bwd_scalar<scalar_t, float, BLOCK>
605
  <<<grid, block, 0, stream>>>(
606
  input_grad.data_ptr<scalar_t>(),
607
+ mul_grad.data_ptr<scalar_t>(), wg_f32.data_ptr<float>(),
608
+ bg_f32.data_ptr<float>(), grad_output.data_ptr<scalar_t>(),
 
609
  input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(),
610
  weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
611
  offsets.data_ptr<int32_t>(), inv_rms.data_ptr<float>(),
612
+ scores_ptr, gs_ptr, (float)eps, D, num_experts,
613
+ (int)expert_offset, (float)hidden_clamp);
 
614
  });
615
  }
616
 
 
625
  // Public API: without scores
626
  // ---------------------------------------------------------------------------
627
  std::tuple<torch::Tensor, torch::Tensor>
628
+ grouped_poly_norm_forward(const torch::Tensor &input, const torch::Tensor &mul,
629
+ const torch::Tensor &weight,
630
+ const torch::Tensor &bias,
631
+ const torch::Tensor &offsets, double eps,
632
+ int64_t expert_offset, double hidden_clamp) {
633
+ return _fwd_impl(input, mul, weight, bias, offsets, nullptr, eps,
634
+ expert_offset, hidden_clamp);
635
  }
636
 
637
  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
638
+ grouped_poly_norm_backward(const torch::Tensor &grad_output,
639
+ const torch::Tensor &input, const torch::Tensor &mul,
640
+ const torch::Tensor &weight,
641
+ const torch::Tensor &bias,
642
+ const torch::Tensor &offsets,
643
+ const torch::Tensor &inv_rms, double eps,
644
+ int64_t expert_offset, double hidden_clamp) {
645
  const int64_t N = input.size(0);
646
+ auto [ig, mg, wg, bg, _] =
647
+ _bwd_impl(grad_output, input, mul, weight, bias, offsets, inv_rms,
648
+ nullptr, nullptr, N, eps, expert_offset, hidden_clamp);
649
  return {ig, mg, wg, bg};
650
  }
651
 
652
  // ---------------------------------------------------------------------------
653
  // Public API: with scores
654
  // ---------------------------------------------------------------------------
655
+ std::tuple<torch::Tensor, torch::Tensor> grouped_poly_norm_forward_scored(
 
656
  const torch::Tensor &input, const torch::Tensor &mul,
657
  const torch::Tensor &weight, const torch::Tensor &bias,
658
+ const torch::Tensor &offsets, const torch::Tensor &scores, double eps,
659
+ int64_t expert_offset, double hidden_clamp) {
660
+ return _fwd_impl(input, mul, weight, bias, offsets, scores.data_ptr<float>(),
661
+ eps, expert_offset, hidden_clamp);
662
  }
663
 
664
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
665
+ torch::Tensor>
666
  grouped_poly_norm_backward_scored(
667
  const torch::Tensor &grad_output, const torch::Tensor &input,
668
  const torch::Tensor &mul, const torch::Tensor &weight,
669
  const torch::Tensor &bias, const torch::Tensor &offsets,
670
+ const torch::Tensor &inv_rms, const torch::Tensor &scores, double eps,
671
+ int64_t expert_offset, double hidden_clamp) {
672
  const int64_t N = input.size(0);
673
  auto gs_f32 = torch::empty({N}, input.options().dtype(torch::kFloat));
674
+ auto [ig, mg, wg, bg, _] =
675
+ _bwd_impl(grad_output, input, mul, weight, bias, offsets, inv_rms,
676
+ scores.data_ptr<float>(), gs_f32.data_ptr<float>(), N, eps,
677
+ expert_offset, hidden_clamp);
678
  auto gs = gs_f32.unsqueeze(-1);
679
  return {ig, mg, wg, bg, gs};
680
  }
benchmarks/cases/grouped_mul_poly.py CHANGED
@@ -5,10 +5,8 @@ from common.diff_engine import DiffCase
5
 
6
  torch._functorch.config.donated_buffer = False
7
 
8
- from grouped_poly_norm import (
9
- fused_mul_grouped_poly_norm,
10
- fused_mul_grouped_poly_norm_ref,
11
- )
12
 
13
  # 384 / 8 (EP) = 48 experts per rank
14
  # total_tokens = bs * sl, which equals per-rank tokens
@@ -28,9 +26,14 @@ class GroupedRefModule(torch.nn.Module):
28
  self.expert_offset = expert_offset
29
 
30
  def forward(self, x, mul):
31
- return fused_mul_grouped_poly_norm_ref(x, mul, self.weight, self.bias,
32
- self.offsets, self.eps,
33
- expert_offset=self.expert_offset)
 
 
 
 
 
34
 
35
 
36
  class GroupedCUDAModule(torch.nn.Module):
@@ -45,8 +48,12 @@ class GroupedCUDAModule(torch.nn.Module):
45
  self.expert_offset = expert_offset
46
 
47
  def forward(self, x, mul):
48
- return fused_mul_grouped_poly_norm(x, mul, self.weight, self.bias,
49
- self.offsets, self.eps,
 
 
 
 
50
  expert_offset=self.expert_offset)
51
 
52
 
@@ -66,29 +73,28 @@ class GroupedMulPoly(DiffCase):
66
  probs = torch.ones(num_experts) / num_experts
67
  assignments = torch.multinomial(probs, total_tokens, replacement=True)
68
  counts = torch.bincount(assignments, minlength=num_experts).tolist()
69
- offsets = torch.cumsum(
70
- torch.tensor(counts, dtype=torch.int32), dim=0)
71
 
72
  return {
73
  "x":
74
- torch.randn(total_tokens, hidden, dtype=dtype,
75
- requires_grad=True) * 0.5,
76
  "mul":
77
- torch.randn(total_tokens, hidden, dtype=dtype,
78
- requires_grad=True) * 0.5,
79
  "weight":
80
- torch.ones(num_experts, 3, dtype=dtype) / 3 +
81
- torch.randn(num_experts, 3, dtype=dtype) * 0.01,
82
  "bias":
83
- torch.randn(num_experts, 1, dtype=dtype) * 0.01,
84
  "offsets":
85
- offsets,
86
  "dim":
87
- hidden,
88
  "eps":
89
- eps,
90
  "dtype":
91
- dtype,
92
  }
93
 
94
  def make_naive(self, I):
 
5
 
6
  torch._functorch.config.donated_buffer = False
7
 
8
+ from grouped_poly_norm import (fused_mul_grouped_poly_norm,
9
+ fused_mul_grouped_poly_norm_ref)
 
 
10
 
11
  # 384 / 8 (EP) = 48 experts per rank
12
  # total_tokens = bs * sl, which equals per-rank tokens
 
26
  self.expert_offset = expert_offset
27
 
28
  def forward(self, x, mul):
29
+ return fused_mul_grouped_poly_norm_ref(
30
+ x,
31
+ mul,
32
+ self.weight,
33
+ self.bias,
34
+ self.offsets,
35
+ self.eps,
36
+ expert_offset=self.expert_offset)
37
 
38
 
39
  class GroupedCUDAModule(torch.nn.Module):
 
48
  self.expert_offset = expert_offset
49
 
50
  def forward(self, x, mul):
51
+ return fused_mul_grouped_poly_norm(x,
52
+ mul,
53
+ self.weight,
54
+ self.bias,
55
+ self.offsets,
56
+ self.eps,
57
  expert_offset=self.expert_offset)
58
 
59
 
 
73
  probs = torch.ones(num_experts) / num_experts
74
  assignments = torch.multinomial(probs, total_tokens, replacement=True)
75
  counts = torch.bincount(assignments, minlength=num_experts).tolist()
76
+ offsets = torch.cumsum(torch.tensor(counts, dtype=torch.int32), dim=0)
 
77
 
78
  return {
79
  "x":
80
+ torch.randn(total_tokens, hidden, dtype=dtype,
81
+ requires_grad=True) * 0.5,
82
  "mul":
83
+ torch.randn(total_tokens, hidden, dtype=dtype,
84
+ requires_grad=True) * 0.5,
85
  "weight":
86
+ torch.ones(num_experts, 3, dtype=dtype) / 3 +
87
+ torch.randn(num_experts, 3, dtype=dtype) * 0.01,
88
  "bias":
89
+ torch.randn(num_experts, 1, dtype=dtype) * 0.01,
90
  "offsets":
91
+ offsets,
92
  "dim":
93
+ hidden,
94
  "eps":
95
+ eps,
96
  "dtype":
97
+ dtype,
98
  }
99
 
100
  def make_naive(self, I):
benchmarks/profile_bwd.py DELETED
@@ -1,146 +0,0 @@
1
- """Profiling script for grouped polynorm backward kernel using torch.profiler."""
2
- import argparse
3
- import torch
4
- import torch.cuda
5
- from torch.profiler import profile, ProfilerActivity
6
- from grouped_poly_norm import fused_mul_grouped_poly_norm
7
-
8
- torch.set_default_device("cuda")
9
-
10
-
11
- def make_inputs(N, D, num_experts):
12
- torch.manual_seed(42)
13
- probs = torch.ones(num_experts) / num_experts
14
- assignments = torch.multinomial(probs, N, replacement=True)
15
- counts = torch.bincount(assignments, minlength=num_experts).tolist()
16
- offsets = torch.cumsum(
17
- torch.tensor(counts, dtype=torch.int32), dim=0)
18
-
19
- x = torch.randn(N, D, dtype=torch.bfloat16, requires_grad=True) * 0.5
20
- m = torch.randn(N, D, dtype=torch.bfloat16, requires_grad=True) * 0.5
21
- w = (torch.ones(num_experts, 3, dtype=torch.bfloat16) / 3
22
- ).requires_grad_(True)
23
- b = (torch.randn(num_experts, 1, dtype=torch.bfloat16) * 0.01
24
- ).requires_grad_(True)
25
- return x, m, w, b, offsets
26
-
27
-
28
- def main():
29
- parser = argparse.ArgumentParser()
30
- parser.add_argument("--tokens", type=int, default=4096)
31
- parser.add_argument("--dim", type=int, default=1280)
32
- parser.add_argument("--experts", type=int, default=48)
33
- parser.add_argument("--output", type=str, default="/tmp/profile")
34
- args = parser.parse_args()
35
-
36
- N, D, num_experts = args.tokens, args.dim, args.experts
37
-
38
- # Warmup (fresh inputs each time to avoid graph reuse issues)
39
- for _ in range(3):
40
- x, m, w, b, offsets = make_inputs(N, D, num_experts)
41
- out = fused_mul_grouped_poly_norm(x, m, w, b, offsets)
42
- out.sum().backward()
43
- torch.cuda.synchronize()
44
-
45
- # Profiled: mimic do_bench — forward once, backward multiple times with retain_graph
46
- x, m, w, b, offsets = make_inputs(N, D, num_experts)
47
- out = fused_mul_grouped_poly_norm(x, m, w, b, offsets)
48
- gin = [x, m] + [w, b]
49
- g = [torch.randn_like(out)]
50
-
51
- # Warmup backward
52
- for _ in range(5):
53
- torch.autograd.grad(out, gin, g, retain_graph=True, allow_unused=True)
54
- torch.cuda.synchronize()
55
-
56
- with profile(
57
- activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
58
- record_shapes=True,
59
- with_stack=True,
60
- ) as prof:
61
- for _ in range(100):
62
- torch.autograd.grad(out, gin, g, retain_graph=True, allow_unused=True)
63
- torch.cuda.synchronize()
64
-
65
- # Print kernel-level stats
66
- print(f"\n=== Kernel Table (N={N}, D={D}) ===")
67
- print(prof.key_averages().table(
68
- sort_by="cuda_time_total", row_limit=20))
69
-
70
- # Export chrome trace
71
- trace_path = f"{args.output}_trace_N{N}.json"
72
- prof.export_chrome_trace(trace_path)
73
- print(f"\nTrace exported to {trace_path}")
74
-
75
- # === Occupancy analysis from Triton kernel metadata ===
76
- print(f"\n=== Occupancy Analysis ===")
77
-
78
- props = torch.cuda.get_device_properties(0)
79
- print(f"GPU: {props.name}")
80
- print(f"SMs: {props.multi_processor_count}")
81
- print(f"Max threads/SM: {props.max_threads_per_multi_processor}")
82
- print(f"Regs/SM: {props.regs_per_multiprocessor}")
83
- print(f"Shared mem/block: {props.shared_memory_per_block} bytes")
84
-
85
- # Get register info from Triton compiled cubins
86
- try:
87
- import glob
88
- import json
89
- import subprocess
90
- cache_dir = os.path.expanduser("~/.triton/cache")
91
-
92
- # Find metadata JSON files
93
- json_files = sorted(glob.glob(f"{cache_dir}/**/*.json", recursive=True),
94
- key=os.path.getmtime, reverse=True)
95
- print(f"\nFound {len(json_files)} compiled kernel metadata files")
96
- for jf in json_files[:10]:
97
- try:
98
- with open(jf) as f:
99
- meta = json.load(f)
100
- if isinstance(meta, dict):
101
- n_regs = meta.get('num_regs', meta.get('n_regs', None))
102
- n_spills = meta.get('num_spills', meta.get('n_spills', None))
103
- name = meta.get('name', os.path.basename(jf))
104
- shared = meta.get('shared', None)
105
- if n_regs is not None:
106
- print(f" {name}: regs={n_regs}, spills={n_spills}, shared={shared}")
107
- except Exception:
108
- pass
109
-
110
- # Also try cuobjdump on recent cubins
111
- cubin_files = sorted(glob.glob(f"{cache_dir}/**/*.cubin", recursive=True),
112
- key=os.path.getmtime, reverse=True)
113
- print(f"\nFound {len(cubin_files)} cubins, inspecting latest:")
114
- for cb in cubin_files[:5]:
115
- try:
116
- result = subprocess.run(
117
- ["cuobjdump", "-res-usage", cb],
118
- capture_output=True, text=True, timeout=5)
119
- if result.returncode == 0 and result.stdout.strip():
120
- print(f"\n {os.path.basename(cb)}:")
121
- for line in result.stdout.strip().split('\n'):
122
- print(f" {line}")
123
- except Exception as e:
124
- print(f" cuobjdump failed: {e}")
125
- break
126
- except Exception as e:
127
- print(f"Cache inspection error: {e}")
128
-
129
- # Calculate theoretical occupancy for different register counts
130
- print("\n=== Theoretical Occupancy (num_warps=4, 128 threads/block) ===")
131
- threads_per_block = 128
132
- max_threads = props.max_threads_per_multi_processor
133
- total_regs = props.regs_per_multiprocessor
134
- for n_regs in [64, 96, 128, 160, 192, 224, 256]:
135
- regs_per_block = n_regs * threads_per_block
136
- max_blocks_by_regs = total_regs // regs_per_block
137
- max_blocks_by_threads = max_threads // threads_per_block
138
- blocks = min(max_blocks_by_regs, max_blocks_by_threads, 32)
139
- active_threads = blocks * threads_per_block
140
- occupancy = active_threads / max_threads * 100
141
- print(f" {n_regs:3d} regs/thread -> {blocks:2d} blocks/SM -> "
142
- f"{active_threads:4d} threads -> {occupancy:.1f}% occupancy")
143
-
144
-
145
- if __name__ == "__main__":
146
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmarks/run_cases.py CHANGED
@@ -28,8 +28,10 @@ def plot_result(r_path, columns=None):
28
  import pandas as pd
29
  df = pd.read_csv(r_path + ".csv")
30
  if columns is None:
31
- columns = [c for c in ["Naive", "Compiled", "Cuda", "Triton"]
32
- if c in df.columns]
 
 
33
  plt.figure(figsize=(12, 6))
34
  ax = df.plot(x="config", y=columns, kind="bar", ax=plt.gca())
35
  ax.set_title("Speedup over torch (higher is better)\n" + make_title_tag(),
@@ -64,6 +66,11 @@ def main():
64
  default="bf16",
65
  help="Data type for benchmarking (default: bf16)",
66
  )
 
 
 
 
 
67
  args = ap.parse_args()
68
 
69
  dtype_map = {
@@ -81,12 +88,14 @@ def main():
81
  mod = importlib.import_module(f"cases.{args.case}")
82
  case: DiffCase = mod.CASE
83
 
84
- calculate_diff(
85
- case,
86
- batch_size=2,
87
- seq_len=128,
88
- hidden_size=4096,
89
- )
 
 
90
 
91
  for dtype_name, dtype in dtypes:
92
  print(f"\n{'=' * 60}")
@@ -161,28 +170,40 @@ def main():
161
  itertools.product(dim, batch_size_range, seq_length_range))
162
 
163
  if is_grouped:
164
- csv_line_vals = ("naive", "compiled", "cuda", "speedup")
165
- csv_line_names = {
166
  "naive": "Naive",
167
  "compiled": "Compiled",
168
  "cuda": "Triton",
169
  "speedup": "SpeedUp",
170
  }
 
 
 
 
 
 
 
 
171
  else:
172
- csv_line_vals = ("naive", "cuda", "speedup")
173
- csv_line_names = {
174
  "naive": "Naive",
175
  "cuda": "Cuda",
176
  "speedup": "SpeedUp",
177
  }
 
 
178
 
179
  bench = make_fwd_benchmark_for_case(
180
  case=case,
181
  configs=configs,
182
  plot_name=f"{args.case}-{dtype_name}-fwd-perf",
183
  dtype=dtype,
184
- line_vals=csv_line_vals,
185
- line_names=csv_line_names,
 
 
186
  )
187
 
188
  bench.run(print_data=True, save_path=save_dir)
@@ -192,8 +213,10 @@ def main():
192
  configs=configs,
193
  plot_name=f"{args.case}-{dtype_name}-bwd-perf",
194
  dtype=dtype,
195
- line_vals=csv_line_vals,
196
- line_names=csv_line_names,
 
 
197
  )
198
 
199
  bench.run(print_data=True, save_path=save_dir)
 
28
  import pandas as pd
29
  df = pd.read_csv(r_path + ".csv")
30
  if columns is None:
31
+ columns = [
32
+ c for c in ["Naive", "Compiled", "Cuda", "Triton"]
33
+ if c in df.columns
34
+ ]
35
  plt.figure(figsize=(12, 6))
36
  ax = df.plot(x="config", y=columns, kind="bar", ax=plt.gca())
37
  ax.set_title("Speedup over torch (higher is better)\n" + make_title_tag(),
 
66
  default="bf16",
67
  help="Data type for benchmarking (default: bf16)",
68
  )
69
+ ap.add_argument(
70
+ "--profile",
71
+ action="store_true",
72
+ help="Export chrome traces for backward benchmarks",
73
+ )
74
  args = ap.parse_args()
75
 
76
  dtype_map = {
 
88
  mod = importlib.import_module(f"cases.{args.case}")
89
  case: DiffCase = mod.CASE
90
 
91
+ # Correctness checks across multiple configs
92
+ for bs, sl, hid in [(2, 128, 4096), (8, 4096, 1280), (1, 32768, 1280)]:
93
+ print(
94
+ f"Checking correctness: bs={bs}, sl={sl}, D={hid} "
95
+ f"(N={bs*sl})...",
96
+ end=" ")
97
+ calculate_diff(case, batch_size=bs, seq_len=sl, hidden_size=hid)
98
+ print("✅")
99
 
100
  for dtype_name, dtype in dtypes:
101
  print(f"\n{'=' * 60}")
 
170
  itertools.product(dim, batch_size_range, seq_length_range))
171
 
172
  if is_grouped:
173
+ fwd_line_vals = ("naive", "compiled", "cuda", "speedup")
174
+ fwd_line_names = {
175
  "naive": "Naive",
176
  "compiled": "Compiled",
177
  "cuda": "Triton",
178
  "speedup": "SpeedUp",
179
  }
180
+ bwd_line_vals = ("naive", "compiled", "compiled_cuda",
181
+ "speedup")
182
+ bwd_line_names = {
183
+ "naive": "Naive",
184
+ "compiled": "Compiled",
185
+ "compiled_cuda": "CompiledCUDA",
186
+ "speedup": "SpeedUp",
187
+ }
188
  else:
189
+ fwd_line_vals = ("naive", "cuda", "speedup")
190
+ fwd_line_names = {
191
  "naive": "Naive",
192
  "cuda": "Cuda",
193
  "speedup": "SpeedUp",
194
  }
195
+ bwd_line_vals = fwd_line_vals
196
+ bwd_line_names = fwd_line_names
197
 
198
  bench = make_fwd_benchmark_for_case(
199
  case=case,
200
  configs=configs,
201
  plot_name=f"{args.case}-{dtype_name}-fwd-perf",
202
  dtype=dtype,
203
+ line_vals=fwd_line_vals,
204
+ line_names=fwd_line_names,
205
+ profile=args.profile,
206
+ profile_dir=os.path.join(save_dir, "traces"),
207
  )
208
 
209
  bench.run(print_data=True, save_path=save_dir)
 
213
  configs=configs,
214
  plot_name=f"{args.case}-{dtype_name}-bwd-perf",
215
  dtype=dtype,
216
+ line_vals=bwd_line_vals,
217
+ line_names=bwd_line_names,
218
+ profile=args.profile,
219
+ profile_dir=os.path.join(save_dir, "traces"),
220
  )
221
 
222
  bench.run(print_data=True, save_path=save_dir)
registration.h CHANGED
@@ -2,8 +2,8 @@
2
 
3
  // Local build compatibility shim for kernel-builder's registration.h
4
 
5
- #include <torch/library.h>
6
  #include <torch/extension.h>
 
7
 
8
  // TORCH_LIBRARY_EXPAND may not be defined in all PyTorch versions
9
  #ifndef TORCH_LIBRARY_EXPAND
@@ -11,4 +11,5 @@
11
  #endif
12
 
13
  // Generate the PyInit_<name> entry point for the shared library
14
- #define REGISTER_EXTENSION(name) PYBIND11_MODULE(name, m) {}
 
 
2
 
3
  // Local build compatibility shim for kernel-builder's registration.h
4
 
 
5
  #include <torch/extension.h>
6
+ #include <torch/library.h>
7
 
8
  // TORCH_LIBRARY_EXPAND may not be defined in all PyTorch versions
9
  #ifndef TORCH_LIBRARY_EXPAND
 
11
  #endif
12
 
13
  // Generate the PyInit_<name> entry point for the shared library
14
+ #define REGISTER_EXTENSION(name) \
15
+ PYBIND11_MODULE(name, m) {}
setup.py CHANGED
@@ -44,9 +44,9 @@ NVCC_FLAGS = [
44
  "--use_fast_math",
45
  "-std=c++17",
46
  # Generate code for common architectures
47
- "-gencode=arch=compute_80,code=sm_80", # A100
48
- "-gencode=arch=compute_89,code=sm_89", # L40/4090
49
- "-gencode=arch=compute_90,code=sm_90", # H100
50
  ]
51
 
52
  # Check for B200 support (sm_100, requires CUDA 12.8+)
 
44
  "--use_fast_math",
45
  "-std=c++17",
46
  # Generate code for common architectures
47
+ "-gencode=arch=compute_80,code=sm_80", # A100
48
+ "-gencode=arch=compute_89,code=sm_89", # L40/4090
49
+ "-gencode=arch=compute_90,code=sm_90", # H100
50
  ]
51
 
52
  # Check for B200 support (sm_100, requires CUDA 12.8+)
tests/test_fused_mul_grouped_poly_norm.py CHANGED
@@ -1,10 +1,6 @@
1
  import pytest
2
  import torch
3
-
4
- from grouped_poly_norm import (
5
- _has_cuda_ops,
6
- fused_mul_grouped_poly_norm_ref,
7
- )
8
 
9
  if _has_cuda_ops:
10
  from grouped_poly_norm import fused_mul_grouped_poly_norm
@@ -23,12 +19,19 @@ CUDA_DEVICES = ["cuda:0"]
23
 
24
  def _counts_to_offsets(counts_list, device):
25
  """Convert list of counts to cumsum offsets tensor."""
26
- return torch.cumsum(
27
- torch.tensor(counts_list, device=device, dtype=torch.int32), dim=0).to(torch.int32)
28
-
29
-
30
- def _make_inputs(total_tokens, hidden_dim, num_experts, dtype, device,
31
- seed=42, expert_offset=0):
 
 
 
 
 
 
 
32
  """Create deterministic test inputs with random token distribution."""
33
  torch.manual_seed(seed)
34
 
@@ -57,40 +60,64 @@ def _make_scores(total_tokens, device, dtype=torch.float32):
57
  return torch.rand(total_tokens, 1, device=device, dtype=dtype) * 0.5 + 0.5
58
 
59
 
60
- def _run_ref(input_t, mul_t, weight, bias, offsets, expert_offset=0,
61
- scores=None, hidden_clamp=None):
 
 
 
 
 
 
62
  """Run reference forward + backward, return output and grads."""
63
  inp = input_t.clone().detach().requires_grad_(True)
64
  m = mul_t.clone().detach().requires_grad_(True)
65
  w = weight.clone().detach().requires_grad_(True)
66
  b = bias.clone().detach().requires_grad_(True)
67
- s = scores.clone().detach().requires_grad_(True) if scores is not None else None
68
-
69
- out = fused_mul_grouped_poly_norm_ref(inp, m, w, b, offsets,
 
 
 
 
 
70
  expert_offset=expert_offset,
71
- scores=s, hidden_clamp=hidden_clamp)
 
72
  out.sum().backward()
73
 
74
  grads = (out, inp.grad, m.grad, w.grad, b.grad)
75
- return grads + (s.grad,) if s is not None else grads + (None,)
76
 
77
 
78
- def _run_cuda(input_t, mul_t, weight, bias, offsets, expert_offset=0,
79
- scores=None, hidden_clamp=None):
 
 
 
 
 
 
80
  """Run CUDA forward + backward, return output and grads."""
81
  inp = input_t.clone().detach().requires_grad_(True)
82
  m = mul_t.clone().detach().requires_grad_(True)
83
  w = weight.clone().detach().requires_grad_(True)
84
  b = bias.clone().detach().requires_grad_(True)
85
- s = scores.clone().detach().requires_grad_(True) if scores is not None else None
86
-
87
- out = fused_mul_grouped_poly_norm(inp, m, w, b, offsets,
 
 
 
 
 
88
  expert_offset=expert_offset,
89
- scores=s, hidden_clamp=hidden_clamp)
 
90
  out.sum().backward()
91
 
92
  grads = (out, inp.grad, m.grad, w.grad, b.grad)
93
- return grads + (s.grad,) if s is not None else grads + (None,)
94
 
95
 
96
  @pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@@ -113,15 +140,26 @@ def test_fused_mul_grouped_poly_norm_forward(
113
  """CUDA forward output should match PyTorch reference."""
114
  torch.set_default_device(device)
115
  input_t, mul_t, weight, bias, offsets = _make_inputs(
116
- num_tokens, d, num_experts, dtype, device, seed,
 
 
 
 
 
117
  expert_offset=expert_offset)
118
 
119
- out_ref = fused_mul_grouped_poly_norm_ref(input_t, mul_t, weight, bias,
 
 
 
120
  offsets,
121
  expert_offset=expert_offset)
122
- out_tri = fused_mul_grouped_poly_norm(input_t, mul_t, weight, bias,
123
- offsets,
124
- expert_offset=expert_offset)
 
 
 
125
 
126
  assert out_ref.shape == out_tri.shape == (num_tokens, d)
127
  assert out_ref.dtype == out_tri.dtype == dtype
@@ -152,7 +190,12 @@ def test_fused_mul_grouped_poly_norm_backward(
152
  """CUDA backward gradients should match PyTorch reference."""
153
  torch.set_default_device(device)
154
  input_t, mul_t, weight, bias, offsets = _make_inputs(
155
- num_tokens, d, num_experts, dtype, device, seed,
 
 
 
 
 
156
  expert_offset=expert_offset)
157
 
158
  _, inp_grad_ref, mul_grad_ref, w_grad_ref, b_grad_ref, _ = _run_ref(
@@ -195,12 +238,18 @@ def test_fused_mul_grouped_poly_norm_zero_token_experts(
195
  bias = torch.zeros(total_experts, 1, device=device, dtype=dtype)
196
  offsets = _counts_to_offsets(counts, device)
197
 
198
- out_ref = fused_mul_grouped_poly_norm_ref(input_t, mul_t, weight, bias,
 
 
 
199
  offsets,
200
  expert_offset=expert_offset)
201
- out_tri = fused_mul_grouped_poly_norm(input_t, mul_t, weight, bias,
202
- offsets,
203
- expert_offset=expert_offset)
 
 
 
204
 
205
  if dtype == torch.float32:
206
  assert_close(out_ref, out_tri, atol=1e-4, rtol=1e-4)
@@ -208,12 +257,18 @@ def test_fused_mul_grouped_poly_norm_zero_token_experts(
208
  assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2)
209
 
210
  # Check backward with zero-token experts
211
- _, _, _, w_grad_ref, b_grad_ref, _ = _run_ref(input_t, mul_t, weight, bias,
 
 
 
 
 
 
 
 
 
212
  offsets,
213
  expert_offset=expert_offset)
214
- _, _, _, w_grad_tri, b_grad_tri, _ = _run_cuda(input_t, mul_t, weight, bias,
215
- offsets,
216
- expert_offset=expert_offset)
217
 
218
  if dtype == torch.float32:
219
  atol, rtol = 1e-3, 1e-3
@@ -270,7 +325,11 @@ def test_fused_mul_grouped_poly_norm_no_nan_inf(
270
  @pytest.mark.parametrize("dtype", DTYPES)
271
  @pytest.mark.parametrize("device", CUDA_DEVICES)
272
  def test_fused_mul_grouped_poly_norm_scores_forward(
273
- num_tokens, d, num_experts, dtype, device,
 
 
 
 
274
  ):
275
  """Forward with scores should match reference."""
276
  torch.set_default_device(device)
@@ -278,10 +337,18 @@ def test_fused_mul_grouped_poly_norm_scores_forward(
278
  num_tokens, d, num_experts, dtype, device)
279
  scores = _make_scores(num_tokens, device)
280
 
281
- out_ref = fused_mul_grouped_poly_norm_ref(
282
- input_t, mul_t, weight, bias, offsets, scores=scores)
283
- out_tri = fused_mul_grouped_poly_norm(
284
- input_t, mul_t, weight, bias, offsets, scores=scores)
 
 
 
 
 
 
 
 
285
 
286
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
287
  assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
@@ -294,7 +361,11 @@ def test_fused_mul_grouped_poly_norm_scores_forward(
294
  @pytest.mark.parametrize("dtype", DTYPES)
295
  @pytest.mark.parametrize("device", CUDA_DEVICES)
296
  def test_fused_mul_grouped_poly_norm_scores_backward(
297
- num_tokens, d, num_experts, dtype, device,
 
 
 
 
298
  ):
299
  """Backward with scores should match reference."""
300
  torch.set_default_device(device)
@@ -302,10 +373,18 @@ def test_fused_mul_grouped_poly_norm_scores_backward(
302
  num_tokens, d, num_experts, dtype, device)
303
  scores = _make_scores(num_tokens, device)
304
 
305
- out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
306
- input_t, mul_t, weight, bias, offsets, scores=scores)
307
- out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_cuda(
308
- input_t, mul_t, weight, bias, offsets, scores=scores)
 
 
 
 
 
 
 
 
309
 
310
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
311
  # weight/bias grads use atomicAdd accumulation across tokens,
@@ -332,7 +411,12 @@ CLAMP_VALUES = [10.0, 1.0, 0.5]
332
  @pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
333
  @pytest.mark.parametrize("device", CUDA_DEVICES)
334
  def test_fused_mul_grouped_poly_norm_hidden_clamp_forward(
335
- num_tokens, d, num_experts, dtype, hidden_clamp, device,
 
 
 
 
 
336
  ):
337
  """Forward with hidden_clamp should match reference."""
338
  torch.set_default_device(device)
@@ -340,12 +424,20 @@ def test_fused_mul_grouped_poly_norm_hidden_clamp_forward(
340
  num_tokens, d, num_experts, dtype, device)
341
  scores = _make_scores(num_tokens, device)
342
 
343
- out_ref = fused_mul_grouped_poly_norm_ref(
344
- input_t, mul_t, weight, bias, offsets,
345
- scores=scores, hidden_clamp=hidden_clamp)
346
- out_tri = fused_mul_grouped_poly_norm(
347
- input_t, mul_t, weight, bias, offsets,
348
- scores=scores, hidden_clamp=hidden_clamp)
 
 
 
 
 
 
 
 
349
 
350
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
351
  assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
@@ -359,7 +451,12 @@ def test_fused_mul_grouped_poly_norm_hidden_clamp_forward(
359
  @pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
360
  @pytest.mark.parametrize("device", CUDA_DEVICES)
361
  def test_fused_mul_grouped_poly_norm_hidden_clamp_backward(
362
- num_tokens, d, num_experts, dtype, hidden_clamp, device,
 
 
 
 
 
363
  ):
364
  """Backward with hidden_clamp should match reference."""
365
  torch.set_default_device(device)
@@ -368,11 +465,21 @@ def test_fused_mul_grouped_poly_norm_hidden_clamp_backward(
368
  scores = _make_scores(num_tokens, device)
369
 
370
  out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
371
- input_t, mul_t, weight, bias, offsets,
372
- scores=scores, hidden_clamp=hidden_clamp)
 
 
 
 
 
373
  out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_cuda(
374
- input_t, mul_t, weight, bias, offsets,
375
- scores=scores, hidden_clamp=hidden_clamp)
 
 
 
 
 
376
 
377
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
378
  # weight/bias grads use atomicAdd accumulation across tokens,
 
1
  import pytest
2
  import torch
3
+ from grouped_poly_norm import _has_cuda_ops, fused_mul_grouped_poly_norm_ref
 
 
 
 
4
 
5
  if _has_cuda_ops:
6
  from grouped_poly_norm import fused_mul_grouped_poly_norm
 
19
 
20
  def _counts_to_offsets(counts_list, device):
21
  """Convert list of counts to cumsum offsets tensor."""
22
+ return torch.cumsum(torch.tensor(counts_list,
23
+ device=device,
24
+ dtype=torch.int32),
25
+ dim=0).to(torch.int32)
26
+
27
+
28
+ def _make_inputs(total_tokens,
29
+ hidden_dim,
30
+ num_experts,
31
+ dtype,
32
+ device,
33
+ seed=42,
34
+ expert_offset=0):
35
  """Create deterministic test inputs with random token distribution."""
36
  torch.manual_seed(seed)
37
 
 
60
  return torch.rand(total_tokens, 1, device=device, dtype=dtype) * 0.5 + 0.5
61
 
62
 
63
+ def _run_ref(input_t,
64
+ mul_t,
65
+ weight,
66
+ bias,
67
+ offsets,
68
+ expert_offset=0,
69
+ scores=None,
70
+ hidden_clamp=None):
71
  """Run reference forward + backward, return output and grads."""
72
  inp = input_t.clone().detach().requires_grad_(True)
73
  m = mul_t.clone().detach().requires_grad_(True)
74
  w = weight.clone().detach().requires_grad_(True)
75
  b = bias.clone().detach().requires_grad_(True)
76
+ s = scores.clone().detach().requires_grad_(
77
+ True) if scores is not None else None
78
+
79
+ out = fused_mul_grouped_poly_norm_ref(inp,
80
+ m,
81
+ w,
82
+ b,
83
+ offsets,
84
  expert_offset=expert_offset,
85
+ scores=s,
86
+ hidden_clamp=hidden_clamp)
87
  out.sum().backward()
88
 
89
  grads = (out, inp.grad, m.grad, w.grad, b.grad)
90
+ return grads + (s.grad, ) if s is not None else grads + (None, )
91
 
92
 
93
+ def _run_cuda(input_t,
94
+ mul_t,
95
+ weight,
96
+ bias,
97
+ offsets,
98
+ expert_offset=0,
99
+ scores=None,
100
+ hidden_clamp=None):
101
  """Run CUDA forward + backward, return output and grads."""
102
  inp = input_t.clone().detach().requires_grad_(True)
103
  m = mul_t.clone().detach().requires_grad_(True)
104
  w = weight.clone().detach().requires_grad_(True)
105
  b = bias.clone().detach().requires_grad_(True)
106
+ s = scores.clone().detach().requires_grad_(
107
+ True) if scores is not None else None
108
+
109
+ out = fused_mul_grouped_poly_norm(inp,
110
+ m,
111
+ w,
112
+ b,
113
+ offsets,
114
  expert_offset=expert_offset,
115
+ scores=s,
116
+ hidden_clamp=hidden_clamp)
117
  out.sum().backward()
118
 
119
  grads = (out, inp.grad, m.grad, w.grad, b.grad)
120
+ return grads + (s.grad, ) if s is not None else grads + (None, )
121
 
122
 
123
  @pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
 
140
  """CUDA forward output should match PyTorch reference."""
141
  torch.set_default_device(device)
142
  input_t, mul_t, weight, bias, offsets = _make_inputs(
143
+ num_tokens,
144
+ d,
145
+ num_experts,
146
+ dtype,
147
+ device,
148
+ seed,
149
  expert_offset=expert_offset)
150
 
151
+ out_ref = fused_mul_grouped_poly_norm_ref(input_t,
152
+ mul_t,
153
+ weight,
154
+ bias,
155
  offsets,
156
  expert_offset=expert_offset)
157
+ out_tri = fused_mul_grouped_poly_norm(input_t,
158
+ mul_t,
159
+ weight,
160
+ bias,
161
+ offsets,
162
+ expert_offset=expert_offset)
163
 
164
  assert out_ref.shape == out_tri.shape == (num_tokens, d)
165
  assert out_ref.dtype == out_tri.dtype == dtype
 
190
  """CUDA backward gradients should match PyTorch reference."""
191
  torch.set_default_device(device)
192
  input_t, mul_t, weight, bias, offsets = _make_inputs(
193
+ num_tokens,
194
+ d,
195
+ num_experts,
196
+ dtype,
197
+ device,
198
+ seed,
199
  expert_offset=expert_offset)
200
 
201
  _, inp_grad_ref, mul_grad_ref, w_grad_ref, b_grad_ref, _ = _run_ref(
 
238
  bias = torch.zeros(total_experts, 1, device=device, dtype=dtype)
239
  offsets = _counts_to_offsets(counts, device)
240
 
241
+ out_ref = fused_mul_grouped_poly_norm_ref(input_t,
242
+ mul_t,
243
+ weight,
244
+ bias,
245
  offsets,
246
  expert_offset=expert_offset)
247
+ out_tri = fused_mul_grouped_poly_norm(input_t,
248
+ mul_t,
249
+ weight,
250
+ bias,
251
+ offsets,
252
+ expert_offset=expert_offset)
253
 
254
  if dtype == torch.float32:
255
  assert_close(out_ref, out_tri, atol=1e-4, rtol=1e-4)
 
257
  assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2)
258
 
259
  # Check backward with zero-token experts
260
+ _, _, _, w_grad_ref, b_grad_ref, _ = _run_ref(input_t,
261
+ mul_t,
262
+ weight,
263
+ bias,
264
+ offsets,
265
+ expert_offset=expert_offset)
266
+ _, _, _, w_grad_tri, b_grad_tri, _ = _run_cuda(input_t,
267
+ mul_t,
268
+ weight,
269
+ bias,
270
  offsets,
271
  expert_offset=expert_offset)
 
 
 
272
 
273
  if dtype == torch.float32:
274
  atol, rtol = 1e-3, 1e-3
 
325
  @pytest.mark.parametrize("dtype", DTYPES)
326
  @pytest.mark.parametrize("device", CUDA_DEVICES)
327
  def test_fused_mul_grouped_poly_norm_scores_forward(
328
+ num_tokens,
329
+ d,
330
+ num_experts,
331
+ dtype,
332
+ device,
333
  ):
334
  """Forward with scores should match reference."""
335
  torch.set_default_device(device)
 
337
  num_tokens, d, num_experts, dtype, device)
338
  scores = _make_scores(num_tokens, device)
339
 
340
+ out_ref = fused_mul_grouped_poly_norm_ref(input_t,
341
+ mul_t,
342
+ weight,
343
+ bias,
344
+ offsets,
345
+ scores=scores)
346
+ out_tri = fused_mul_grouped_poly_norm(input_t,
347
+ mul_t,
348
+ weight,
349
+ bias,
350
+ offsets,
351
+ scores=scores)
352
 
353
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
354
  assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
 
361
  @pytest.mark.parametrize("dtype", DTYPES)
362
  @pytest.mark.parametrize("device", CUDA_DEVICES)
363
  def test_fused_mul_grouped_poly_norm_scores_backward(
364
+ num_tokens,
365
+ d,
366
+ num_experts,
367
+ dtype,
368
+ device,
369
  ):
370
  """Backward with scores should match reference."""
371
  torch.set_default_device(device)
 
373
  num_tokens, d, num_experts, dtype, device)
374
  scores = _make_scores(num_tokens, device)
375
 
376
+ out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(input_t,
377
+ mul_t,
378
+ weight,
379
+ bias,
380
+ offsets,
381
+ scores=scores)
382
+ out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_cuda(input_t,
383
+ mul_t,
384
+ weight,
385
+ bias,
386
+ offsets,
387
+ scores=scores)
388
 
389
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
390
  # weight/bias grads use atomicAdd accumulation across tokens,
 
411
  @pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
412
  @pytest.mark.parametrize("device", CUDA_DEVICES)
413
  def test_fused_mul_grouped_poly_norm_hidden_clamp_forward(
414
+ num_tokens,
415
+ d,
416
+ num_experts,
417
+ dtype,
418
+ hidden_clamp,
419
+ device,
420
  ):
421
  """Forward with hidden_clamp should match reference."""
422
  torch.set_default_device(device)
 
424
  num_tokens, d, num_experts, dtype, device)
425
  scores = _make_scores(num_tokens, device)
426
 
427
+ out_ref = fused_mul_grouped_poly_norm_ref(input_t,
428
+ mul_t,
429
+ weight,
430
+ bias,
431
+ offsets,
432
+ scores=scores,
433
+ hidden_clamp=hidden_clamp)
434
+ out_tri = fused_mul_grouped_poly_norm(input_t,
435
+ mul_t,
436
+ weight,
437
+ bias,
438
+ offsets,
439
+ scores=scores,
440
+ hidden_clamp=hidden_clamp)
441
 
442
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
443
  assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
 
451
  @pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
452
  @pytest.mark.parametrize("device", CUDA_DEVICES)
453
  def test_fused_mul_grouped_poly_norm_hidden_clamp_backward(
454
+ num_tokens,
455
+ d,
456
+ num_experts,
457
+ dtype,
458
+ hidden_clamp,
459
+ device,
460
  ):
461
  """Backward with hidden_clamp should match reference."""
462
  torch.set_default_device(device)
 
465
  scores = _make_scores(num_tokens, device)
466
 
467
  out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
468
+ input_t,
469
+ mul_t,
470
+ weight,
471
+ bias,
472
+ offsets,
473
+ scores=scores,
474
+ hidden_clamp=hidden_clamp)
475
  out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_cuda(
476
+ input_t,
477
+ mul_t,
478
+ weight,
479
+ bias,
480
+ offsets,
481
+ scores=scores,
482
+ hidden_clamp=hidden_clamp)
483
 
484
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
485
  # weight/bias grads use atomicAdd accumulation across tokens,
torch-ext/activation/grouped_poly_norm.py CHANGED
@@ -37,37 +37,42 @@ _has_cuda_ops = _ops is not None and hasattr(_ops, 'grouped_poly_norm_forward')
37
  # Register fake (meta) tensor implementations for torch.compile
38
  if _has_cuda_ops:
39
  try:
 
40
  @torch.library.register_fake("_activation::grouped_poly_norm_forward")
41
  def _fwd_fake(input, mul, weight, bias, offsets, eps, expert_offset,
42
- hidden_clamp):
43
  return (torch.empty_like(input),
44
- torch.empty(input.shape[0], 3, dtype=torch.float32,
 
 
45
  device=input.device))
46
 
47
  @torch.library.register_fake("_activation::grouped_poly_norm_backward")
48
  def _bwd_fake(grad_output, input, mul, weight, bias, offsets, inv_rms,
49
- eps, expert_offset, hidden_clamp):
50
- return (torch.empty_like(input),
51
- torch.empty_like(mul),
52
- torch.empty_like(weight),
53
- torch.empty_like(bias))
54
-
55
- @torch.library.register_fake("_activation::grouped_poly_norm_forward_scored")
56
- def _fwd_scored_fake(input, mul, weight, bias, offsets, scores,
57
- eps, expert_offset, hidden_clamp):
58
  return (torch.empty_like(input),
59
- torch.empty(input.shape[0], 3, dtype=torch.float32,
 
 
60
  device=input.device))
61
 
62
- @torch.library.register_fake("_activation::grouped_poly_norm_backward_scored")
 
63
  def _bwd_scored_fake(grad_output, input, mul, weight, bias, offsets,
64
- inv_rms, scores, eps, expert_offset,
65
- hidden_clamp):
66
- return (torch.empty_like(input),
67
- torch.empty_like(mul),
68
- torch.empty_like(weight),
69
- torch.empty_like(bias),
70
- torch.empty(input.shape[0], 1, dtype=torch.float32,
71
  device=input.device))
72
  except Exception:
73
  pass # already registered
@@ -111,7 +116,8 @@ def fused_mul_grouped_poly_norm_ref(
111
  orig_dtype = input.dtype
112
 
113
  token_positions = torch.arange(input.shape[0], device=input.device)
114
- expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
 
115
 
116
  weight_fp32 = weight.float()
117
  bias_fp32 = bias.float()
@@ -182,25 +188,25 @@ if _has_cuda_ops:
182
  """With scores — same pattern, adds scores + hidden_clamp."""
183
 
184
  @staticmethod
185
- def forward(input, mul, weight, bias, offsets, scores,
186
- eps, expert_offset, hidden_clamp):
187
  input = input.contiguous()
188
  mul = mul.contiguous()
189
  assert scores.dtype == torch.float32, \
190
  f"scores must be float32, got {scores.dtype}"
191
  scores = scores.contiguous()
192
  output, inv_rms = _ops.grouped_poly_norm_forward_scored(
193
- input, mul, weight, bias, offsets, scores, eps,
194
- expert_offset, hidden_clamp)
195
  return output, inv_rms
196
 
197
  @staticmethod
198
  def setup_context(ctx, inputs, output):
199
- (input, mul, weight, bias, offsets, scores,
200
- eps, expert_offset, hidden_clamp) = inputs
201
  _, inv_rms = output
202
- ctx.save_for_backward(input, mul, weight, bias, offsets,
203
- inv_rms, scores)
204
  ctx.eps = eps
205
  ctx.expert_offset = expert_offset
206
  ctx.hidden_clamp = hidden_clamp
@@ -244,13 +250,14 @@ if _has_cuda_ops:
244
  """
245
  clamp_val = -1.0 if hidden_clamp is None else float(hidden_clamp)
246
  if scores is not None:
247
- output, _ = _GroupedPolyNormScoredFn.apply(
248
- input, mul, weight, bias, offsets, scores, eps,
249
- expert_offset, clamp_val)
 
250
  else:
251
- output, _ = _GroupedPolyNormFn.apply(
252
- input, mul, weight, bias, offsets, eps, expert_offset,
253
- clamp_val)
254
  return output
255
 
256
  else:
 
37
  # Register fake (meta) tensor implementations for torch.compile
38
  if _has_cuda_ops:
39
  try:
40
+
41
  @torch.library.register_fake("_activation::grouped_poly_norm_forward")
42
  def _fwd_fake(input, mul, weight, bias, offsets, eps, expert_offset,
43
+ hidden_clamp):
44
  return (torch.empty_like(input),
45
+ torch.empty(input.shape[0],
46
+ 3,
47
+ dtype=torch.float32,
48
  device=input.device))
49
 
50
  @torch.library.register_fake("_activation::grouped_poly_norm_backward")
51
  def _bwd_fake(grad_output, input, mul, weight, bias, offsets, inv_rms,
52
+ eps, expert_offset, hidden_clamp):
53
+ return (torch.empty_like(input), torch.empty_like(mul),
54
+ torch.empty_like(weight), torch.empty_like(bias))
55
+
56
+ @torch.library.register_fake(
57
+ "_activation::grouped_poly_norm_forward_scored")
58
+ def _fwd_scored_fake(input, mul, weight, bias, offsets, scores, eps,
59
+ expert_offset, hidden_clamp):
 
60
  return (torch.empty_like(input),
61
+ torch.empty(input.shape[0],
62
+ 3,
63
+ dtype=torch.float32,
64
  device=input.device))
65
 
66
+ @torch.library.register_fake(
67
+ "_activation::grouped_poly_norm_backward_scored")
68
  def _bwd_scored_fake(grad_output, input, mul, weight, bias, offsets,
69
+ inv_rms, scores, eps, expert_offset,
70
+ hidden_clamp):
71
+ return (torch.empty_like(input), torch.empty_like(mul),
72
+ torch.empty_like(weight), torch.empty_like(bias),
73
+ torch.empty(input.shape[0],
74
+ 1,
75
+ dtype=torch.float32,
76
  device=input.device))
77
  except Exception:
78
  pass # already registered
 
116
  orig_dtype = input.dtype
117
 
118
  token_positions = torch.arange(input.shape[0], device=input.device)
119
+ expert_idx = torch.bucketize(token_positions, offsets,
120
+ right=True) + expert_offset
121
 
122
  weight_fp32 = weight.float()
123
  bias_fp32 = bias.float()
 
188
  """With scores — same pattern, adds scores + hidden_clamp."""
189
 
190
  @staticmethod
191
+ def forward(input, mul, weight, bias, offsets, scores, eps,
192
+ expert_offset, hidden_clamp):
193
  input = input.contiguous()
194
  mul = mul.contiguous()
195
  assert scores.dtype == torch.float32, \
196
  f"scores must be float32, got {scores.dtype}"
197
  scores = scores.contiguous()
198
  output, inv_rms = _ops.grouped_poly_norm_forward_scored(
199
+ input, mul, weight, bias, offsets, scores, eps, expert_offset,
200
+ hidden_clamp)
201
  return output, inv_rms
202
 
203
  @staticmethod
204
  def setup_context(ctx, inputs, output):
205
+ (input, mul, weight, bias, offsets, scores, eps, expert_offset,
206
+ hidden_clamp) = inputs
207
  _, inv_rms = output
208
+ ctx.save_for_backward(input, mul, weight, bias, offsets, inv_rms,
209
+ scores)
210
  ctx.eps = eps
211
  ctx.expert_offset = expert_offset
212
  ctx.hidden_clamp = hidden_clamp
 
250
  """
251
  clamp_val = -1.0 if hidden_clamp is None else float(hidden_clamp)
252
  if scores is not None:
253
+ output, _ = _GroupedPolyNormScoredFn.apply(input, mul, weight,
254
+ bias, offsets, scores,
255
+ eps, expert_offset,
256
+ clamp_val)
257
  else:
258
+ output, _ = _GroupedPolyNormFn.apply(input, mul, weight, bias,
259
+ offsets, eps, expert_offset,
260
+ clamp_val)
261
  return output
262
 
263
  else:
torch-ext/torch_binding.cpp CHANGED
@@ -50,32 +50,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
50
  &fused_add_rms_norm_backward);
51
 
52
  // grouped_poly_norm (without scores, hidden_clamp < 0 = disabled)
53
- ops.def("grouped_poly_norm_forward("
54
- "Tensor input, Tensor mul, Tensor weight, "
55
- "Tensor bias, Tensor offsets, "
56
- "float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor)");
 
57
  ops.impl("grouped_poly_norm_forward", torch::kCUDA,
58
  &grouped_poly_norm_forward);
59
 
60
  ops.def("grouped_poly_norm_backward("
61
  "Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
62
  "Tensor bias, Tensor offsets, Tensor inv_rms, "
63
- "float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor, Tensor, Tensor)");
 
64
  ops.impl("grouped_poly_norm_backward", torch::kCUDA,
65
  &grouped_poly_norm_backward);
66
 
67
  // grouped_poly_norm (with scores)
68
- ops.def("grouped_poly_norm_forward_scored("
69
- "Tensor input, Tensor mul, Tensor weight, "
70
- "Tensor bias, Tensor offsets, Tensor scores, "
71
- "float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor)");
 
72
  ops.impl("grouped_poly_norm_forward_scored", torch::kCUDA,
73
  &grouped_poly_norm_forward_scored);
74
 
75
  ops.def("grouped_poly_norm_backward_scored("
76
  "Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
77
  "Tensor bias, Tensor offsets, Tensor inv_rms, Tensor scores, "
78
- "float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
 
79
  ops.impl("grouped_poly_norm_backward_scored", torch::kCUDA,
80
  &grouped_poly_norm_backward_scored);
81
  }
 
50
  &fused_add_rms_norm_backward);
51
 
52
  // grouped_poly_norm (without scores, hidden_clamp < 0 = disabled)
53
+ ops.def(
54
+ "grouped_poly_norm_forward("
55
+ "Tensor input, Tensor mul, Tensor weight, "
56
+ "Tensor bias, Tensor offsets, "
57
+ "float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor)");
58
  ops.impl("grouped_poly_norm_forward", torch::kCUDA,
59
  &grouped_poly_norm_forward);
60
 
61
  ops.def("grouped_poly_norm_backward("
62
  "Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
63
  "Tensor bias, Tensor offsets, Tensor inv_rms, "
64
+ "float eps, int expert_offset, float hidden_clamp) -> (Tensor, "
65
+ "Tensor, Tensor, Tensor)");
66
  ops.impl("grouped_poly_norm_backward", torch::kCUDA,
67
  &grouped_poly_norm_backward);
68
 
69
  // grouped_poly_norm (with scores)
70
+ ops.def(
71
+ "grouped_poly_norm_forward_scored("
72
+ "Tensor input, Tensor mul, Tensor weight, "
73
+ "Tensor bias, Tensor offsets, Tensor scores, "
74
+ "float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor)");
75
  ops.impl("grouped_poly_norm_forward_scored", torch::kCUDA,
76
  &grouped_poly_norm_forward_scored);
77
 
78
  ops.def("grouped_poly_norm_backward_scored("
79
  "Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
80
  "Tensor bias, Tensor offsets, Tensor inv_rms, Tensor scores, "
81
+ "float eps, int expert_offset, float hidden_clamp) -> (Tensor, "
82
+ "Tensor, Tensor, Tensor, Tensor)");
83
  ops.impl("grouped_poly_norm_backward_scored", torch::kCUDA,
84
  &grouped_poly_norm_backward_scored);
85
  }