wyldecat and Claude Opus 4.6 (1M context) committed
Commit 79a877a · 1 Parent(s): 09ecd67

refactor: replace warp shuffle with CUB BlockReduce


- Replace hand-written warp_reduce_sum/block_reduce_sum with CUB BlockReduce
- Add BLOCK_SIZE template param to fwd_small for compile-time CUB sizing
- Dispatch fwd_small with matched BlockReduce<float, 32/64/128/256>
- Simpler code, no performance regression (small dims slightly faster)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
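For context, here is a minimal, self-contained sketch of the CUB BlockReduce pattern the diff below adopts. It is illustrative only (the kernel name `demo_block_sum` and its buffers are not part of this repo): `cub::BlockReduce` is sized by a compile-time block size, its `TempStorage` lives in shared memory, and only thread 0 receives the reduced value, which is why the kernels then broadcast it through a `__shared__` scalar.

```cuda
// Illustrative sketch (not repo code): block-wide sum with CUB BlockReduce.
#include <cub/block/block_reduce.cuh>

template <int BLOCK_SIZE>
__global__ void demo_block_sum(const float *__restrict__ in,
                               float *__restrict__ out, int n) {
  // Each thread accumulates a partial sum over a strided slice of the input.
  float partial = 0.0f;
  for (int i = threadIdx.x; i < n; i += blockDim.x)
    partial += in[i];

  // BlockReduce is sized at compile time; this is why the host code below
  // dispatches rms_norm_fwd_small over BLOCK_SIZE = 32/64/128/256.
  using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  // Passing blockDim.x as the valid-item count mirrors how the kernels in
  // this commit call Sum()/Reduce().
  float total = BlockReduce(temp_storage).Sum(partial, blockDim.x);

  // Only thread 0 holds the full sum; the RMSNorm kernels re-broadcast it
  // through a __shared__ scalar before the normalization pass.
  if (threadIdx.x == 0)
    *out = total;
}
```

A launch such as `demo_block_sum<256><<<1, 256>>>(in, out, n)` matches the sized instantiation; the runtime block size should not exceed the template BLOCK_SIZE.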

activation/rms_norm.cu CHANGED
@@ -13,62 +13,12 @@ template <typename type, int N> struct alignas(sizeof(type) * N) type_vec_t {
   type data[N];
 };
 
-// Warp shuffle block reduction
-// Uses ~128 bytes shared memory vs ~8KB for CUB
-__device__ __forceinline__ float warp_reduce_sum(float val) {
-#pragma unroll
-  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
-    val += __shfl_down_sync(0xffffffff, val, offset);
-  return val;
-}
-
-__device__ __forceinline__ float2 warp_reduce_sum2(float2 v) {
-#pragma unroll
-  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
-    v.x += __shfl_down_sync(0xffffffff, v.x, offset);
-    v.y += __shfl_down_sync(0xffffffff, v.y, offset);
-  }
-  return v;
-}
-
-__device__ __forceinline__ float2 block_reduce_sum2(float2 v) {
-  __shared__ float warp_x[32];
-  __shared__ float warp_y[32];
-  int lane = threadIdx.x % WARP_SIZE;
-  int warp = threadIdx.x / WARP_SIZE;
-  int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
-
-  v = warp_reduce_sum2(v);
-  if (lane == 0) {
-    warp_x[warp] = v.x;
-    warp_y[warp] = v.y;
-  }
-  __syncthreads();
-
-  if (warp == 0) {
-    v.x = (lane < num_warps) ? warp_x[lane] : 0.0f;
-    v.y = (lane < num_warps) ? warp_y[lane] : 0.0f;
-    v = warp_reduce_sum2(v);
-  }
-  return v;
-}
-
-__device__ __forceinline__ float block_reduce_sum(float val) {
-  __shared__ float warp_sums[32];
-  int lane = threadIdx.x % WARP_SIZE;
-  int warp = threadIdx.x / WARP_SIZE;
-  int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
-
-  val = warp_reduce_sum(val);
-  if (lane == 0)
-    warp_sums[warp] = val;
-  __syncthreads();
-
-  val = (lane < num_warps) ? warp_sums[lane] : 0.0f;
-  if (warp == 0)
-    val = warp_reduce_sum(val);
-  return val;
-}
+// Float2 sum operator for CUB BlockReduce
+struct Float2SumOp {
+  __device__ float2 operator()(const float2 &a, const float2 &b) const {
+    return make_float2(a.x + b.x, a.y + b.y);
+  }
+};
 
 // ---------------------------------------------------------------------------
 // Forward (dim ≤ 2048): single-pass with register caching
@@ -78,7 +28,8 @@ __device__ __forceinline__ float block_reduce_sum(float val) {
 // Pass 2: write output from cache (no second global read)
 // Also writes inv_rms[token] for backward
 // ---------------------------------------------------------------------------
-template <typename scalar_t, typename acc_t, int width, int NVECS>
+template <typename scalar_t, typename acc_t, int width, int NVECS,
+          int BLOCK_SIZE = 256>
 __global__ void rms_norm_fwd_small(scalar_t *__restrict__ out,
                                    acc_t *__restrict__ inv_rms_out,
                                    const scalar_t *__restrict__ input,
@@ -103,7 +54,11 @@ __global__ void rms_norm_fwd_small(scalar_t *__restrict__ out,
     }
   }
 
-  sum_square = block_reduce_sum(sum_square);
+  {
+    using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
+    __shared__ typename BlockReduce::TempStorage reduceStore;
+    sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x);
+  }
 
   __shared__ acc_t s_scale;
   if (threadIdx.x == 0) {
@@ -159,7 +114,11 @@ __global__ void rms_norm_fwd_large(scalar_t *__restrict__ out,
     }
   }
 
-  sum_square = block_reduce_sum(sum_square);
+  {
+    using BlockReduce = cub::BlockReduce<float, 256>;
+    __shared__ typename BlockReduce::TempStorage reduceStore;
+    sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x);
+  }
 
   __shared__ acc_t s_scale;
   if (threadIdx.x == 0) {
@@ -203,7 +162,11 @@ __global__ void rms_norm_fwd_scalar(scalar_t *__restrict__ out,
     sum_square += x * x;
   }
 
-  sum_square = block_reduce_sum(sum_square);
+  {
+    using BlockReduce = cub::BlockReduce<float, 256>;
+    __shared__ typename BlockReduce::TempStorage reduceStore;
+    sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x);
+  }
 
   __shared__ acc_t s_scale;
   if (threadIdx.x == 0) {
@@ -264,7 +227,11 @@ __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_input_grad(
     }
   }
 
-  d_sum = block_reduce_sum(d_sum);
+  {
+    using BlockReduce = cub::BlockReduce<float, 256>;
+    __shared__ typename BlockReduce::TempStorage reduceStore;
+    d_sum = BlockReduce(reduceStore).Sum(d_sum, blockDim.x);
+  }
 
   __shared__ acc_t s_dxx;
   if (threadIdx.x == 0) {
@@ -331,6 +298,13 @@ __global__ void rms_norm_bwd_fused(
   int64_t token_start = static_cast<int64_t>(blockIdx.x) * tpb;
   int64_t token_end = min(token_start + tpb, num_tokens);
 
+  // Shared TempStorage for CUB block reductions — declared once, reused per
+  // iteration
+  using BlockReduce2 = cub::BlockReduce<float2, 256>;
+  __shared__ typename BlockReduce2::TempStorage reduceStore2;
+  using BlockReduce = cub::BlockReduce<float, 256>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+
   // Process tokens in pairs with float2 reduction
   int64_t t = token_start;
   for (; t + 1 < token_end; t += 2) {
@@ -355,7 +329,9 @@ __global__ void rms_norm_bwd_fused(
       }
     }
 
-    float2 sums = block_reduce_sum2(make_float2(dsum0, dsum1));
+    float2 sums =
+        BlockReduce2(reduceStore2)
+            .Reduce(make_float2(dsum0, dsum1), Float2SumOp{}, blockDim.x);
 
     // dxx = d_sum * scale^3 / d
     __shared__ acc_t sd0, sd1;
@@ -405,7 +381,7 @@ __global__ void rms_norm_bwd_fused(
                static_cast<acc_t>(x_vec.data[i]) *
                static_cast<acc_t>(w_vec.data[i]);
     }
-    d_sum = block_reduce_sum(d_sum);
+    d_sum = BlockReduce(reduceStore).Sum(d_sum, blockDim.x);
 
     __shared__ acc_t s_dxx;
     if (threadIdx.x == 0)
@@ -520,7 +496,11 @@ __global__ void rms_norm_bwd_scalar(scalar_t *__restrict__ input_grad,
     d_sum += dy * x * w;
   }
 
-  d_sum = block_reduce_sum(d_sum);
+  {
+    using BlockReduce = cub::BlockReduce<float, 256>;
+    __shared__ typename BlockReduce::TempStorage reduceStore;
+    d_sum = BlockReduce(reduceStore).Sum(d_sum, blockDim.x);
+  }
 
   __shared__ acc_t s_dxx;
   if (threadIdx.x == 0) {
@@ -574,15 +554,26 @@ rms_norm(const torch::Tensor &input, // [..., d]
     // Single-pass: 1 vec per thread (dim <= 2048)
     int block_size = ((vec_d + 31) / 32) * 32;
     block_size = std::max(block_size, 32);
-    dim3 block(block_size);
-    MOTIF_DISPATCH_FLOATING_TYPES(
-        input.scalar_type(), "rms_norm_fwd_small", [&] {
-          motif::rms_norm_fwd_small<scalar_t, float, 8, 1>
-              <<<grid, block, 0, stream>>>(
-                  out.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(),
-                  input.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
-                  eps, d);
-        });
+
+#define LAUNCH_FWD_SMALL(BS)                                                  \
+  MOTIF_DISPATCH_FLOATING_TYPES(                                              \
+      input.scalar_type(), "rms_norm_fwd_small", [&] {                        \
+        motif::rms_norm_fwd_small<scalar_t, float, 8, 1, BS>                  \
+            <<<grid, dim3(BS), 0, stream>>>(                                  \
+                out.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(),          \
+                input.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), eps, \
+                d);                                                           \
+      })
+
+    if (block_size <= 32)
+      LAUNCH_FWD_SMALL(32);
+    else if (block_size <= 64)
+      LAUNCH_FWD_SMALL(64);
+    else if (block_size <= 128)
+      LAUNCH_FWD_SMALL(128);
+    else
+      LAUNCH_FWD_SMALL(256);
+#undef LAUNCH_FWD_SMALL
   } else {
     // Large dims: 2-pass
     dim3 block(block_2pass);
@@ -619,23 +610,23 @@ rms_norm(const torch::Tensor &input, // [..., d]
 
 std::tuple<torch::Tensor, torch::Tensor>
 rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
-                  const torch::Tensor &output,      // [..., d] — forward output y
-                  const torch::Tensor &weight,      // [d]
-                  const torch::Tensor &inv_rms,     // [num_tokens]
+                  const torch::Tensor &input,       // [..., d]
+                  const torch::Tensor &weight,      // [d]
+                  const torch::Tensor &inv_rms,     // [num_tokens]
                   double eps) {
-  torch::Tensor input_grad = torch::empty_like(output);
+  torch::Tensor input_grad = torch::empty_like(input);
   torch::Tensor weight_grad = torch::empty_like(weight);
 
   AssertTensorContiguous(output_grad, "output_grad");
-  AssertTensorContiguous(output, "output");
+  AssertTensorContiguous(input, "input");
   AssertTensorContiguous(weight, "weight");
   AssertTensorContiguous(inv_rms, "inv_rms");
 
-  int d = output.size(-1);
-  int64_t num_tokens = output.numel() / output.size(-1);
+  int d = input.size(-1);
+  int64_t num_tokens = input.numel() / input.size(-1);
 
   dim3 grid(num_tokens);
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   if (d % 8 == 0) {
@@ -648,20 +639,20 @@ rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
     int64_t num_blocks_mt = (num_tokens + tpb - 1) / tpb;
 
     torch::Tensor wg_acc =
-        torch::zeros({d}, output.options().dtype(torch::kFloat));
+        torch::zeros({d}, input.options().dtype(torch::kFloat));
 
     int block_size = std::min(vec_d, 128);
     block_size = ((block_size + 31) / 32) * 32;
     block_size = std::max(block_size, 32);
     size_t smem = d * sizeof(float);
-    // 'output' C++ arg receives input (saved by Python autograd)
+    // Small/medium dims (d <= 8192): multi-token block with shared memory
     MOTIF_DISPATCH_FLOATING_TYPES(
-        output.scalar_type(), "rms_norm_bwd_fused", [&] {
+        input.scalar_type(), "rms_norm_bwd_fused", [&] {
          motif::rms_norm_bwd_fused<scalar_t, float, 8>
              <<<num_blocks_mt, block_size, smem, stream>>>(
                  input_grad.data_ptr<scalar_t>(), wg_acc.data_ptr<float>(),
                  output_grad.data_ptr<scalar_t>(),
-                 output.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
+                 input.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
                  inv_rms.data_ptr<float>(), d, num_tokens, tpb);
        });
 
@@ -671,20 +662,20 @@ rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
   } else {
     // Large dims (d > 8192): input-based bwd + column-parallel weight grad
     MOTIF_DISPATCH_FLOATING_TYPES(
-        output.scalar_type(), "rms_norm_bwd_large_input_grad", [&] {
+        input.scalar_type(), "rms_norm_bwd_large_input_grad", [&] {
          motif::rms_norm_bwd_large_input_grad<scalar_t, float, 8>
              <<<grid, dim3(256), 0, stream>>>(
                  input_grad.data_ptr<scalar_t>(),
                  output_grad.data_ptr<scalar_t>(),
-                 output.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
+                 input.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
                  inv_rms.data_ptr<float>(), d);
        });
 
     if (weight_grad.defined()) {
       int64_t chunk_size = 256;
       int64_t num_chunks = (num_tokens + chunk_size - 1) / chunk_size;
-      torch::Tensor partial_wg = torch::empty(
-          {num_chunks, d}, output.options().dtype(torch::kFloat));
+      torch::Tensor partial_wg =
+          torch::empty({num_chunks, d}, input.options().dtype(torch::kFloat));
 
       constexpr int TILE_T = 64;
       constexpr int VEC_W = 8;
@@ -692,13 +683,13 @@ rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
       int cols_per_block = wg_threads * VEC_W;
       dim3 wg_grid((d + cols_per_block - 1) / cols_per_block, num_chunks);
       MOTIF_DISPATCH_FLOATING_TYPES(
-          output.scalar_type(), "rms_norm_bwd_large_weight_grad", [&] {
+          input.scalar_type(), "rms_norm_bwd_large_weight_grad", [&] {
            motif::rms_norm_bwd_large_weight_grad<scalar_t, float, TILE_T,
                                                  VEC_W>
                <<<wg_grid, wg_threads, 0, stream>>>(
                    partial_wg.data_ptr<float>(),
                    output_grad.data_ptr<scalar_t>(),
-                    output.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
+                    input.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
                    inv_rms.data_ptr<float>(), d, num_tokens, chunk_size);
          });
 
@@ -712,19 +703,19 @@ rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
     // Scalar fallback: temp buffer + at::sum_out (still uses input-based for
     // scalar)
     torch::Tensor temp_weight_grad =
-        torch::empty({num_tokens, d}, output.options().dtype(torch::kFloat));
+        torch::empty({num_tokens, d}, input.options().dtype(torch::kFloat));
 
     int block_size = std::min(d, 256);
     block_size = ((block_size + 31) / 32) * 32;
     block_size = std::max(block_size, 32);
     dim3 block(block_size);
     MOTIF_DISPATCH_FLOATING_TYPES(
-        output.scalar_type(), "rms_norm_bwd_scalar", [&] {
+        input.scalar_type(), "rms_norm_bwd_scalar", [&] {
          motif::rms_norm_bwd_scalar<scalar_t, float>
              <<<grid, block, 0, stream>>>(
                  input_grad.data_ptr<scalar_t>(),
                  temp_weight_grad.data_ptr<float>(),
-                  output_grad.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
+                  output_grad.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
                  weight.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(), d);
        });

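The fused backward path above reduces two tokens per iteration as a single float2 with a custom reduction operator. A standalone sketch of that Reduce-with-functor pattern follows; the functor `Float2Add` and kernel `demo_pair_sum` are illustrative names, not the repo's code, which uses `Float2SumOp` inside `rms_norm_bwd_fused`.

```cuda
// Illustrative sketch (not repo code): reducing a float2 pair with a functor.
#include <cub/block/block_reduce.cuh>

struct Float2Add {
  __device__ float2 operator()(const float2 &a, const float2 &b) const {
    return make_float2(a.x + b.x, a.y + b.y);
  }
};

template <int BLOCK_SIZE>
__global__ void demo_pair_sum(const float *__restrict__ a,
                              const float *__restrict__ b,
                              float2 *__restrict__ out, int n) {
  // Two independent partial sums per thread, packed into a float2, so one
  // block reduction covers both (bwd_fused does this for a pair of tokens).
  float2 partial = make_float2(0.0f, 0.0f);
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    partial.x += a[i];
    partial.y += b[i];
  }

  using BlockReduce2 = cub::BlockReduce<float2, BLOCK_SIZE>;
  __shared__ typename BlockReduce2::TempStorage temp_storage;

  // Sum() requires operator+ on the element type; float2 has none, so the
  // reduction goes through Reduce() with an explicit functor.
  float2 total =
      BlockReduce2(temp_storage).Reduce(partial, Float2Add{}, blockDim.x);

  // As with the scalar case, only thread 0 holds the reduced pair.
  if (threadIdx.x == 0)
    *out = total;
}
```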
torch-ext/torch_binding.cpp CHANGED
@@ -20,7 +20,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "rms_norm(Tensor input, Tensor weight, float eps) -> (Tensor, Tensor)");
   ops.impl("rms_norm", torch::kCUDA, &rms_norm);
 
-  ops.def("rms_norm_backward(Tensor output_grad, Tensor output, Tensor weight, "
+  ops.def("rms_norm_backward(Tensor output_grad, Tensor input, Tensor weight, "
           "Tensor inv_rms, float eps) -> (Tensor, Tensor)");
   ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward);
 
torch-ext/torch_binding.h CHANGED
@@ -14,7 +14,7 @@ void poly_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
 std::tuple<torch::Tensor, torch::Tensor>
 rms_norm(const torch::Tensor &input, const torch::Tensor &weights, double eps);
 std::tuple<torch::Tensor, torch::Tensor>
-rms_norm_backward(const torch::Tensor &output_grad, const torch::Tensor &output,
+rms_norm_backward(const torch::Tensor &output_grad, const torch::Tensor &input,
                   const torch::Tensor &weight, const torch::Tensor &inv_rms,
                   double eps);
 