Kernels
wyldecat Claude Opus 4.6 (1M context) committed on
Commit
09ecd67
·
1 Parent(s): 9dcee96

fix: unify all backward kernels to input-based math + fix test import

Browse files

- Always save input in setup_context (no output-based path)
- bwd_fused: input-based math (d_sum = sum(dy*x*w), grad = scale*dy*w - dxx*x)
- bwd_large_input_grad: input-based math
- bwd_large_weight_grad: input-based (weight_grad = dy*x*scale, needs inv_rms)
- bwd_scalar: input-based (unchanged)
- pytest.ini: add pythonpath to prevent activation/ CUDA dir namespace shadow

Eliminates bf16 precision loss from output-based x recovery (y/(w*scale)).
All 48 test_rms_norm tests pass. No performance regression.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

activation/rms_norm.cu CHANGED
@@ -223,20 +223,17 @@ __global__ void rms_norm_fwd_scalar(scalar_t *__restrict__ out,
223
  // Backward: single-pass with register caching (NVECS vecs per thread)
224
  // ---------------------------------------------------------------------------
225
  // ---------------------------------------------------------------------------
226
- // Backward input_grad (dim > 8192): output-based, 2-pass
227
  // Architecture: 1 block/token, 256 threads, __launch_bounds__(256, 4)
228
- // Uses forward output y instead of input x (saved by autograd):
229
- // d_sum = (1/scale) * sum(dy * y) no weight read in pass 1
230
- // input_grad = scale*dy*w - dxx*y/(w*scale)
231
- // Pass 1: read dy + y → d_sum (warp shuffle reduction)
232
- // Pass 2: read dy + y + w (L1 hit) → write input_grad
233
  // Weight grad computed by separate bwd_large_weight_grad kernel
234
  // ---------------------------------------------------------------------------
235
  template <typename scalar_t, typename acc_t, int width>
236
  __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_input_grad(
237
  scalar_t *__restrict__ input_grad,
238
  const scalar_t *__restrict__ output_grad, // dy
239
- const scalar_t *__restrict__ output, // y = x * w * scale
240
  const scalar_t *__restrict__ weight, const acc_t *__restrict__ inv_rms,
241
  const int d) {
242
  using vec_t = type_vec_t<scalar_t, width>;
@@ -245,8 +242,7 @@ __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_input_grad(
245
  const int vec_d = d / width;
246
  const int64_t vec_offset = token_idx * vec_d;
247
 
248
- const vec_t *__restrict__ output_vec =
249
- reinterpret_cast<const vec_t *>(output);
250
  const vec_t *__restrict__ output_grad_vec =
251
  reinterpret_cast<const vec_t *>(output_grad);
252
  const vec_t *__restrict__ weight_vec =
@@ -254,46 +250,44 @@ __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_input_grad(
254
 
255
  acc_t scale = inv_rms[token_idx];
256
 
257
- // Pass 1: d_sum = (1/scale) * sum(dy * y) only reads dy and y!
258
- acc_t dy_y_sum = 0.0f;
259
  for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
260
- vec_t y_vec = output_vec[vec_offset + vidx];
261
  vec_t dy_vec = output_grad_vec[vec_offset + vidx];
 
262
  #pragma unroll
263
  for (int i = 0; i < width; ++i) {
264
- dy_y_sum += static_cast<acc_t>(dy_vec.data[i]) *
265
- static_cast<acc_t>(y_vec.data[i]);
 
266
  }
267
  }
268
 
269
- dy_y_sum = block_reduce_sum(dy_y_sum);
270
 
271
- // d_sum = dy_y_sum / scale, dxx = d_sum * scale^3 / d = dy_y_sum * scale^2 /
272
- // d
273
  __shared__ acc_t s_dxx;
274
  if (threadIdx.x == 0) {
275
- s_dxx = dy_y_sum * scale * scale / d;
276
  }
277
  __syncthreads();
278
  acc_t dxx = s_dxx;
279
 
280
- // Pass 2: input_grad = scale * dy * w - dxx * y / (w * scale)
281
  vec_t *__restrict__ input_grad_vec = reinterpret_cast<vec_t *>(input_grad);
282
 
283
  for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
284
- vec_t y_vec = output_vec[vec_offset + vidx];
285
  vec_t dy_vec = output_grad_vec[vec_offset + vidx];
286
  vec_t w_vec = weight_vec[vidx];
287
 
288
  vec_t in_grad_vec;
289
  #pragma unroll
290
  for (int i = 0; i < width; ++i) {
291
- acc_t y = y_vec.data[i];
292
  acc_t dy = dy_vec.data[i];
293
  acc_t w = w_vec.data[i];
294
- // x = y / (w * scale), so: scale*dy*w - dxx*x = scale*dy*w -
295
- // dxx*y/(w*scale)
296
- in_grad_vec.data[i] = scale * dy * w - dxx * y / (w * scale);
297
  }
298
  input_grad_vec[vec_offset + vidx] = in_grad_vec;
299
  }
@@ -302,8 +296,7 @@ __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_input_grad(
302
  // ---------------------------------------------------------------------------
303
  // Backward fused (dim ≤ 8192): multi-token block, input_grad + weight_grad
304
  // Architecture: 32 tokens/block, 128 threads, dynamic shared mem [d] fp32
305
- // Uses saved input (not output) input-based math:
306
- // d_sum = sum(dy * x * w), input_grad = scale*dy*w - dxx*x
307
  // 2-token batched float2 reduction (syncthreads halved)
308
  // Weight grad accumulated in shared memory → atomicAdd to global
309
  // Single kernel: no separate weight_grad kernel needed
@@ -313,13 +306,13 @@ __global__ void rms_norm_bwd_fused(
313
  scalar_t *__restrict__ input_grad,
314
  acc_t *__restrict__ weight_grad_acc, // [d] — atomicAdd target
315
  const scalar_t *__restrict__ output_grad,
316
- const scalar_t *__restrict__ output, // forward output y
317
- const scalar_t *__restrict__ weight, const acc_t *__restrict__ inv_rms,
318
- const int d, const int64_t num_tokens, const int tpb) {
319
 
320
  using vec_t = type_vec_t<scalar_t, width>;
321
 
322
- extern __shared__ acc_t wg_shared[]; // [d] float32 — accumulates sum(dy*y)
323
 
324
  const int vec_d = d / width;
325
 
@@ -328,8 +321,7 @@ __global__ void rms_norm_bwd_fused(
328
  wg_shared[idx] = 0.0f;
329
  __syncthreads();
330
 
331
- const vec_t *__restrict__ output_vec =
332
- reinterpret_cast<const vec_t *>(output);
333
  const vec_t *__restrict__ output_grad_vec =
334
  reinterpret_cast<const vec_t *>(output_grad);
335
  const vec_t *__restrict__ weight_vec =
@@ -345,50 +337,52 @@ __global__ void rms_norm_bwd_fused(
345
  acc_t s0 = inv_rms[t], s1 = inv_rms[t + 1];
346
  int64_t off0 = t * vec_d, off1 = (t + 1) * vec_d;
347
 
348
- // Pass 1: dy_y_sum = sum(dy * y) no weight needed for d_sum!
349
- acc_t dys0 = 0.0f, dys1 = 0.0f;
350
  for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
351
- vec_t y0 = output_vec[off0 + vidx];
352
  vec_t dy0 = output_grad_vec[off0 + vidx];
353
- vec_t y1 = output_vec[off1 + vidx];
354
  vec_t dy1 = output_grad_vec[off1 + vidx];
 
355
  #pragma unroll
356
  for (int i = 0; i < width; ++i) {
357
- dys0 +=
358
- static_cast<acc_t>(dy0.data[i]) * static_cast<acc_t>(y0.data[i]);
359
- dys1 +=
360
- static_cast<acc_t>(dy1.data[i]) * static_cast<acc_t>(y1.data[i]);
 
361
  }
362
  }
363
 
364
- float2 sums = block_reduce_sum2(make_float2(dys0, dys1));
365
 
366
- // dxx = dy_y_sum * scale^2 / d
367
  __shared__ acc_t sd0, sd1;
368
  if (threadIdx.x == 0) {
369
- sd0 = sums.x * s0 * s0 / d;
370
- sd1 = sums.y * s1 * s1 / d;
371
  }
372
  __syncthreads();
373
  acc_t dxx0 = sd0, dxx1 = sd1;
374
 
375
- // Pass 2: input_grad + wg_shared accumulate (L1 hit on y, dy)
376
  for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
377
- vec_t y0 = output_vec[off0 + vidx];
378
  vec_t dy0 = output_grad_vec[off0 + vidx];
379
- vec_t y1 = output_vec[off1 + vidx];
380
  vec_t dy1 = output_grad_vec[off1 + vidx];
381
  vec_t w = weight_vec[vidx];
382
 
383
  vec_t g0, g1;
384
  #pragma unroll
385
  for (int i = 0; i < width; ++i) {
386
- acc_t yi0 = y0.data[i], di0 = dy0.data[i], wi = w.data[i];
387
- acc_t yi1 = y1.data[i], di1 = dy1.data[i];
388
- g0.data[i] = s0 * di0 * wi - dxx0 * yi0 / (wi * s0);
389
- g1.data[i] = s1 * di1 * wi - dxx1 * yi1 / (wi * s1);
390
- // wg_shared accumulates sum(dy * y), divide by w at the end
391
- wg_shared[vidx * width + i] += di0 * yi0 + di1 * yi1;
392
  }
393
  input_grad_vec[off0 + vidx] = g0;
394
  input_grad_vec[off1 + vidx] = g1;
@@ -399,44 +393,44 @@ __global__ void rms_norm_bwd_fused(
399
  if (t < token_end) {
400
  acc_t scale = inv_rms[t];
401
  int64_t vec_offset = t * vec_d;
402
- acc_t dy_y_sum = 0.0f;
403
 
404
  for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
405
- vec_t y_vec = output_vec[vec_offset + vidx];
406
  vec_t dy_vec = output_grad_vec[vec_offset + vidx];
 
407
  #pragma unroll
408
  for (int i = 0; i < width; ++i)
409
- dy_y_sum += static_cast<acc_t>(dy_vec.data[i]) *
410
- static_cast<acc_t>(y_vec.data[i]);
 
411
  }
412
- dy_y_sum = block_reduce_sum(dy_y_sum);
413
 
414
  __shared__ acc_t s_dxx;
415
  if (threadIdx.x == 0)
416
- s_dxx = dy_y_sum * scale * scale / d;
417
  __syncthreads();
418
  acc_t dxx = s_dxx;
419
 
420
  for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
421
- vec_t y_vec = output_vec[vec_offset + vidx];
422
  vec_t dy_vec = output_grad_vec[vec_offset + vidx];
423
- vec_t w_vec = weight_vec[vidx];
424
  vec_t gv;
425
  #pragma unroll
426
  for (int i = 0; i < width; ++i) {
427
- acc_t y = y_vec.data[i], dy = dy_vec.data[i], w = w_vec.data[i];
428
- gv.data[i] = scale * dy * w - dxx * y / (w * scale);
429
- wg_shared[vidx * width + i] += dy * y;
 
430
  }
431
  input_grad_vec[vec_offset + vidx] = gv;
432
  }
433
  }
434
 
435
- // AtomicAdd accumulated weight grad: wg_shared has sum(dy*y), divide by w
436
- const scalar_t *__restrict__ w_ptr = weight;
437
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
438
- acc_t w = static_cast<acc_t>(w_ptr[idx]);
439
- atomicAdd(&weight_grad_acc[idx], wg_shared[idx] / w);
440
  }
441
  }
442
 
@@ -444,7 +438,7 @@ __global__ void rms_norm_bwd_fused(
444
  // Backward weight_grad (dim > 8192): column-parallel with vec8 loads
445
  // Architecture: 2D grid (col_blocks × token_chunks), 256 threads
446
  // Each thread handles 8 columns (vec8), iterates over chunk_size tokens
447
- // Output-based: weight_grad_i = sum_t(dy*y) / w_i — no inv_rms needed
448
  // Writes partial_wg[chunk, d], host reduces with at::sum_out
449
  // __launch_bounds__(256, 4)
450
  // ---------------------------------------------------------------------------
@@ -452,11 +446,11 @@ template <typename scalar_t, typename acc_t, int TILE_T, int VEC_W>
452
  __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_weight_grad(
453
  acc_t *__restrict__ partial_wg, // [num_chunks, d]
454
  const scalar_t *__restrict__ output_grad, // [num_tokens, d]
455
- const scalar_t *__restrict__ output, // [num_tokens, d] — forward output y
456
- const scalar_t *__restrict__ weight, // [d]
 
457
  const int d, const int64_t num_tokens, const int64_t chunk_size) {
458
- // weight_grad_i = (1/w_i) * sum_t(dy * y)
459
- // No inv_rms needed! Just dy, y, and w.
460
 
461
  using vec_t = type_vec_t<scalar_t, VEC_W>;
462
 
@@ -465,19 +459,15 @@ __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_weight_grad(
465
  const int64_t token_start = blockIdx.y * chunk_size;
466
  const int64_t token_end = min(token_start + chunk_size, num_tokens);
467
 
468
- const vec_t *__restrict__ output_vec =
469
- reinterpret_cast<const vec_t *>(output);
470
  const vec_t *__restrict__ grad_vec =
471
  reinterpret_cast<const vec_t *>(output_grad);
472
- const vec_t *__restrict__ weight_vec =
473
- reinterpret_cast<const vec_t *>(weight);
474
  const int vec_d = d / VEC_W;
475
 
476
- // Accumulate sum_t(dy * y) in registers
477
- acc_t dy_y_acc[VEC_W];
478
  #pragma unroll
479
  for (int i = 0; i < VEC_W; i++)
480
- dy_y_acc[i] = 0.0f;
481
 
482
  int vec_col = blockIdx.x * blockDim.x + threadIdx.x;
483
 
@@ -487,23 +477,22 @@ __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_weight_grad(
487
  if (vec_col < vec_d) {
488
  for (int r = 0; r < tile_size; r++) {
489
  int64_t t = t_base + r;
490
- vec_t y_v = output_vec[t * vec_d + vec_col];
 
491
  vec_t dy_v = grad_vec[t * vec_d + vec_col];
492
  #pragma unroll
493
  for (int i = 0; i < VEC_W; i++) {
494
- dy_y_acc[i] += static_cast<acc_t>(dy_v.data[i]) *
495
- static_cast<acc_t>(y_v.data[i]);
496
  }
497
  }
498
  }
499
  }
500
 
501
- // Write results: weight_grad_i = dy_y_acc_i / w_i
502
  if (vec_col < vec_d) {
503
- vec_t w_v = weight_vec[vec_col];
504
  for (int i = 0; i < VEC_W; i++) {
505
- acc_t w = static_cast<acc_t>(w_v.data[i]);
506
- partial_wg[blockIdx.y * d + col_base + i] = dy_y_acc[i] / w;
507
  }
508
  }
509
  }
@@ -665,24 +654,22 @@ rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
665
  block_size = ((block_size + 31) / 32) * 32;
666
  block_size = std::max(block_size, 32);
667
  size_t smem = d * sizeof(float);
668
- // For dim <= 8192: 'output' arg is actually input (saved by Python
669
- // autograd)
670
  MOTIF_DISPATCH_FLOATING_TYPES(
671
  output.scalar_type(), "rms_norm_bwd_fused", [&] {
672
  motif::rms_norm_bwd_fused<scalar_t, float, 8>
673
  <<<num_blocks_mt, block_size, smem, stream>>>(
674
  input_grad.data_ptr<scalar_t>(), wg_acc.data_ptr<float>(),
675
  output_grad.data_ptr<scalar_t>(),
676
- output.data_ptr<scalar_t>(), // actually input for dim<=8192
677
- weight.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(), d,
678
- num_tokens, tpb);
679
  });
680
 
681
  if (weight_grad.defined()) {
682
  weight_grad.copy_(wg_acc);
683
  }
684
  } else {
685
- // Large dims (d > 8192): output-based bwd + column-parallel weight grad
686
  MOTIF_DISPATCH_FLOATING_TYPES(
687
  output.scalar_type(), "rms_norm_bwd_large_input_grad", [&] {
688
  motif::rms_norm_bwd_large_input_grad<scalar_t, float, 8>
@@ -712,7 +699,7 @@ rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
712
  partial_wg.data_ptr<float>(),
713
  output_grad.data_ptr<scalar_t>(),
714
  output.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
715
- d, num_tokens, chunk_size);
716
  });
717
 
718
  torch::Tensor acc =
 
223
  // Backward: single-pass with register caching (NVECS vecs per thread)
224
  // ---------------------------------------------------------------------------
225
  // ---------------------------------------------------------------------------
226
+ // Backward input_grad (dim > 8192): input-based, 2-pass
227
  // Architecture: 1 block/token, 256 threads, __launch_bounds__(256, 4)
228
+ // Pass 1: read dy + x + w → d_sum = sum(dy * x * w)
229
+ // Pass 2: read dy + x + w (L1 hit) → input_grad = scale*dy*w - dxx*x
 
 
 
230
  // Weight grad computed by separate bwd_large_weight_grad kernel
231
  // ---------------------------------------------------------------------------
232
  template <typename scalar_t, typename acc_t, int width>
233
  __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_input_grad(
234
  scalar_t *__restrict__ input_grad,
235
  const scalar_t *__restrict__ output_grad, // dy
236
+ const scalar_t *__restrict__ input, // x
237
  const scalar_t *__restrict__ weight, const acc_t *__restrict__ inv_rms,
238
  const int d) {
239
  using vec_t = type_vec_t<scalar_t, width>;
 
242
  const int vec_d = d / width;
243
  const int64_t vec_offset = token_idx * vec_d;
244
 
245
+ const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
 
246
  const vec_t *__restrict__ output_grad_vec =
247
  reinterpret_cast<const vec_t *>(output_grad);
248
  const vec_t *__restrict__ weight_vec =
 
250
 
251
  acc_t scale = inv_rms[token_idx];
252
 
253
+ // Pass 1: d_sum = sum(dy * x * w)
254
+ acc_t d_sum = 0.0f;
255
  for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
256
+ vec_t x_vec = input_vec[vec_offset + vidx];
257
  vec_t dy_vec = output_grad_vec[vec_offset + vidx];
258
+ vec_t w_vec = weight_vec[vidx];
259
  #pragma unroll
260
  for (int i = 0; i < width; ++i) {
261
+ d_sum += static_cast<acc_t>(dy_vec.data[i]) *
262
+ static_cast<acc_t>(x_vec.data[i]) *
263
+ static_cast<acc_t>(w_vec.data[i]);
264
  }
265
  }
266
 
267
+ d_sum = block_reduce_sum(d_sum);
268
 
 
 
269
  __shared__ acc_t s_dxx;
270
  if (threadIdx.x == 0) {
271
+ s_dxx = d_sum * scale * scale * scale / d;
272
  }
273
  __syncthreads();
274
  acc_t dxx = s_dxx;
275
 
276
+ // Pass 2: input_grad = scale * dy * w - dxx * x
277
  vec_t *__restrict__ input_grad_vec = reinterpret_cast<vec_t *>(input_grad);
278
 
279
  for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
280
+ vec_t x_vec = input_vec[vec_offset + vidx];
281
  vec_t dy_vec = output_grad_vec[vec_offset + vidx];
282
  vec_t w_vec = weight_vec[vidx];
283
 
284
  vec_t in_grad_vec;
285
  #pragma unroll
286
  for (int i = 0; i < width; ++i) {
287
+ acc_t x = x_vec.data[i];
288
  acc_t dy = dy_vec.data[i];
289
  acc_t w = w_vec.data[i];
290
+ in_grad_vec.data[i] = scale * dy * w - dxx * x;
 
 
291
  }
292
  input_grad_vec[vec_offset + vidx] = in_grad_vec;
293
  }
 
296
  // ---------------------------------------------------------------------------
297
  // Backward fused (dim ≤ 8192): multi-token block, input_grad + weight_grad
298
  // Architecture: 32 tokens/block, 128 threads, dynamic shared mem [d] fp32
299
+ // Input-based math: d_sum = sum(dy * x * w), input_grad = scale*dy*w - dxx*x
 
300
  // 2-token batched float2 reduction (syncthreads halved)
301
  // Weight grad accumulated in shared memory → atomicAdd to global
302
  // Single kernel: no separate weight_grad kernel needed
 
306
  scalar_t *__restrict__ input_grad,
307
  acc_t *__restrict__ weight_grad_acc, // [d] — atomicAdd target
308
  const scalar_t *__restrict__ output_grad,
309
+ const scalar_t *__restrict__ input, const scalar_t *__restrict__ weight,
310
+ const acc_t *__restrict__ inv_rms, const int d, const int64_t num_tokens,
311
+ const int tpb) {
312
 
313
  using vec_t = type_vec_t<scalar_t, width>;
314
 
315
+ extern __shared__ acc_t wg_shared[]; // [d] float32 — accumulates weight grad
316
 
317
  const int vec_d = d / width;
318
 
 
321
  wg_shared[idx] = 0.0f;
322
  __syncthreads();
323
 
324
+ const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
 
325
  const vec_t *__restrict__ output_grad_vec =
326
  reinterpret_cast<const vec_t *>(output_grad);
327
  const vec_t *__restrict__ weight_vec =
 
337
  acc_t s0 = inv_rms[t], s1 = inv_rms[t + 1];
338
  int64_t off0 = t * vec_d, off1 = (t + 1) * vec_d;
339
 
340
+ // Pass 1: d_sum = sum(dy * x * w)
341
+ acc_t dsum0 = 0.0f, dsum1 = 0.0f;
342
  for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
343
+ vec_t x0 = input_vec[off0 + vidx];
344
  vec_t dy0 = output_grad_vec[off0 + vidx];
345
+ vec_t x1 = input_vec[off1 + vidx];
346
  vec_t dy1 = output_grad_vec[off1 + vidx];
347
+ vec_t w = weight_vec[vidx];
348
  #pragma unroll
349
  for (int i = 0; i < width; ++i) {
350
+ acc_t wi = w.data[i];
351
+ dsum0 += static_cast<acc_t>(dy0.data[i]) *
352
+ static_cast<acc_t>(x0.data[i]) * wi;
353
+ dsum1 += static_cast<acc_t>(dy1.data[i]) *
354
+ static_cast<acc_t>(x1.data[i]) * wi;
355
  }
356
  }
357
 
358
+ float2 sums = block_reduce_sum2(make_float2(dsum0, dsum1));
359
 
360
+ // dxx = d_sum * scale^3 / d
361
  __shared__ acc_t sd0, sd1;
362
  if (threadIdx.x == 0) {
363
+ sd0 = sums.x * s0 * s0 * s0 / d;
364
+ sd1 = sums.y * s1 * s1 * s1 / d;
365
  }
366
  __syncthreads();
367
  acc_t dxx0 = sd0, dxx1 = sd1;
368
 
369
+ // Pass 2: input_grad + wg_shared accumulate (L1 hit on x, dy)
370
  for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
371
+ vec_t x0 = input_vec[off0 + vidx];
372
  vec_t dy0 = output_grad_vec[off0 + vidx];
373
+ vec_t x1 = input_vec[off1 + vidx];
374
  vec_t dy1 = output_grad_vec[off1 + vidx];
375
  vec_t w = weight_vec[vidx];
376
 
377
  vec_t g0, g1;
378
  #pragma unroll
379
  for (int i = 0; i < width; ++i) {
380
+ acc_t xi0 = x0.data[i], di0 = dy0.data[i], wi = w.data[i];
381
+ acc_t xi1 = x1.data[i], di1 = dy1.data[i];
382
+ g0.data[i] = s0 * di0 * wi - dxx0 * xi0;
383
+ g1.data[i] = s1 * di1 * wi - dxx1 * xi1;
384
+ // weight_grad = dy * x * scale
385
+ wg_shared[vidx * width + i] += di0 * xi0 * s0 + di1 * xi1 * s1;
386
  }
387
  input_grad_vec[off0 + vidx] = g0;
388
  input_grad_vec[off1 + vidx] = g1;
 
393
  if (t < token_end) {
394
  acc_t scale = inv_rms[t];
395
  int64_t vec_offset = t * vec_d;
396
+ acc_t d_sum = 0.0f;
397
 
398
  for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
399
+ vec_t x_vec = input_vec[vec_offset + vidx];
400
  vec_t dy_vec = output_grad_vec[vec_offset + vidx];
401
+ vec_t w_vec = weight_vec[vidx];
402
  #pragma unroll
403
  for (int i = 0; i < width; ++i)
404
+ d_sum += static_cast<acc_t>(dy_vec.data[i]) *
405
+ static_cast<acc_t>(x_vec.data[i]) *
406
+ static_cast<acc_t>(w_vec.data[i]);
407
  }
408
+ d_sum = block_reduce_sum(d_sum);
409
 
410
  __shared__ acc_t s_dxx;
411
  if (threadIdx.x == 0)
412
+ s_dxx = d_sum * scale * scale * scale / d;
413
  __syncthreads();
414
  acc_t dxx = s_dxx;
415
 
416
  for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
417
+ vec_t x_vec = input_vec[vec_offset + vidx];
418
  vec_t dy_vec = output_grad_vec[vec_offset + vidx];
 
419
  vec_t gv;
420
  #pragma unroll
421
  for (int i = 0; i < width; ++i) {
422
+ acc_t x = x_vec.data[i], dy = dy_vec.data[i];
423
+ gv.data[i] =
424
+ scale * dy * static_cast<acc_t>(weight_vec[vidx].data[i]) - dxx * x;
425
+ wg_shared[vidx * width + i] += dy * x * scale;
426
  }
427
  input_grad_vec[vec_offset + vidx] = gv;
428
  }
429
  }
430
 
431
+ // AtomicAdd accumulated weight grad
 
432
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
433
+ atomicAdd(&weight_grad_acc[idx], wg_shared[idx]);
 
434
  }
435
  }
436
 
 
438
  // Backward weight_grad (dim > 8192): column-parallel with vec8 loads
439
  // Architecture: 2D grid (col_blocks × token_chunks), 256 threads
440
  // Each thread handles 8 columns (vec8), iterates over chunk_size tokens
441
+ // Input-based: weight_grad_i = sum_t(dy * x * scale)
442
  // Writes partial_wg[chunk, d], host reduces with at::sum_out
443
  // __launch_bounds__(256, 4)
444
  // ---------------------------------------------------------------------------
 
446
  __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_weight_grad(
447
  acc_t *__restrict__ partial_wg, // [num_chunks, d]
448
  const scalar_t *__restrict__ output_grad, // [num_tokens, d]
449
+ const scalar_t *__restrict__ input, // [num_tokens, d]
450
+ const scalar_t *__restrict__ weight, // [d]
451
+ const acc_t *__restrict__ inv_rms, // [num_tokens]
452
  const int d, const int64_t num_tokens, const int64_t chunk_size) {
453
+ // weight_grad_i = sum_t(dy_t * x_t * scale_t)
 
454
 
455
  using vec_t = type_vec_t<scalar_t, VEC_W>;
456
 
 
459
  const int64_t token_start = blockIdx.y * chunk_size;
460
  const int64_t token_end = min(token_start + chunk_size, num_tokens);
461
 
462
+ const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
 
463
  const vec_t *__restrict__ grad_vec =
464
  reinterpret_cast<const vec_t *>(output_grad);
 
 
465
  const int vec_d = d / VEC_W;
466
 
467
+ acc_t wg_acc[VEC_W];
 
468
  #pragma unroll
469
  for (int i = 0; i < VEC_W; i++)
470
+ wg_acc[i] = 0.0f;
471
 
472
  int vec_col = blockIdx.x * blockDim.x + threadIdx.x;
473
 
 
477
  if (vec_col < vec_d) {
478
  for (int r = 0; r < tile_size; r++) {
479
  int64_t t = t_base + r;
480
+ acc_t scale = inv_rms[t];
481
+ vec_t x_v = input_vec[t * vec_d + vec_col];
482
  vec_t dy_v = grad_vec[t * vec_d + vec_col];
483
  #pragma unroll
484
  for (int i = 0; i < VEC_W; i++) {
485
+ wg_acc[i] += static_cast<acc_t>(dy_v.data[i]) *
486
+ static_cast<acc_t>(x_v.data[i]) * scale;
487
  }
488
  }
489
  }
490
  }
491
 
492
+ // Write results directly
493
  if (vec_col < vec_d) {
 
494
  for (int i = 0; i < VEC_W; i++) {
495
+ partial_wg[blockIdx.y * d + col_base + i] = wg_acc[i];
 
496
  }
497
  }
498
  }
 
654
  block_size = ((block_size + 31) / 32) * 32;
655
  block_size = std::max(block_size, 32);
656
  size_t smem = d * sizeof(float);
657
+ // 'output' C++ arg receives input (saved by Python autograd)
 
658
  MOTIF_DISPATCH_FLOATING_TYPES(
659
  output.scalar_type(), "rms_norm_bwd_fused", [&] {
660
  motif::rms_norm_bwd_fused<scalar_t, float, 8>
661
  <<<num_blocks_mt, block_size, smem, stream>>>(
662
  input_grad.data_ptr<scalar_t>(), wg_acc.data_ptr<float>(),
663
  output_grad.data_ptr<scalar_t>(),
664
+ output.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
665
+ inv_rms.data_ptr<float>(), d, num_tokens, tpb);
 
666
  });
667
 
668
  if (weight_grad.defined()) {
669
  weight_grad.copy_(wg_acc);
670
  }
671
  } else {
672
+ // Large dims (d > 8192): input-based bwd + column-parallel weight grad
673
  MOTIF_DISPATCH_FLOATING_TYPES(
674
  output.scalar_type(), "rms_norm_bwd_large_input_grad", [&] {
675
  motif::rms_norm_bwd_large_input_grad<scalar_t, float, 8>
 
699
  partial_wg.data_ptr<float>(),
700
  output_grad.data_ptr<scalar_t>(),
701
  output.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
702
+ inv_rms.data_ptr<float>(), d, num_tokens, chunk_size);
703
  });
704
 
705
  torch::Tensor acc =
tests/pytest.ini CHANGED
@@ -1,3 +1,4 @@
1
  [pytest]
2
  log_cli = true
3
  log_cli_level = INFO
 
 
1
  [pytest]
2
  log_cli = true
3
  log_cli_level = INFO
4
+ pythonpath = ../torch-ext
torch-ext/activation/rms_norm.py CHANGED
@@ -15,13 +15,7 @@ class RMSNormFunction(torch.autograd.Function):
15
  def setup_context(ctx, inputs, outputs):
16
  input, weight, eps = inputs
17
  output, inv_rms = outputs
18
- d = input.size(-1)
19
- # Large dims: save output (output-based backward, less memory traffic)
20
- # Small dims: save input (multitoken backward, avoids division overhead)
21
- if d > 8192:
22
- ctx.save_for_backward(output, weight, inv_rms)
23
- else:
24
- ctx.save_for_backward(input, weight, inv_rms)
25
  ctx.eps = eps
26
 
27
  @staticmethod
 
15
  def setup_context(ctx, inputs, outputs):
16
  input, weight, eps = inputs
17
  output, inv_rms = outputs
18
+ ctx.save_for_backward(input, weight, inv_rms)
 
 
 
 
 
 
19
  ctx.eps = eps
20
 
21
  @staticmethod