refactor: replace warp shuffle with CUB BlockReduce
Browse files- Replace hand-written warp_reduce_sum/block_reduce_sum with CUB BlockReduce
- Add BLOCK_SIZE template param to fwd_small for compile-time CUB sizing
- Dispatch fwd_small with matched BlockReduce<float, 32/64/128/256>
- Simpler code, no performance regression (small dims slightly faster)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- activation/rms_norm.cu +84 -93
- torch-ext/torch_binding.cpp +1 -1
- torch-ext/torch_binding.h +1 -1
activation/rms_norm.cu
CHANGED
|
@@ -13,62 +13,12 @@ template <typename type, int N> struct alignas(sizeof(type) * N) type_vec_t {
|
|
| 13 |
type data[N];
|
| 14 |
};
|
| 15 |
|
| 16 |
-
//
|
| 17 |
-
|
| 18 |
-
__device__
|
| 19 |
-
|
| 20 |
-
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
|
| 21 |
-
val += __shfl_down_sync(0xffffffff, val, offset);
|
| 22 |
-
return val;
|
| 23 |
-
}
|
| 24 |
-
|
| 25 |
-
__device__ __forceinline__ float2 warp_reduce_sum2(float2 v) {
|
| 26 |
-
#pragma unroll
|
| 27 |
-
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
|
| 28 |
-
v.x += __shfl_down_sync(0xffffffff, v.x, offset);
|
| 29 |
-
v.y += __shfl_down_sync(0xffffffff, v.y, offset);
|
| 30 |
-
}
|
| 31 |
-
return v;
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
__device__ __forceinline__ float2 block_reduce_sum2(float2 v) {
|
| 35 |
-
__shared__ float warp_x[32];
|
| 36 |
-
__shared__ float warp_y[32];
|
| 37 |
-
int lane = threadIdx.x % WARP_SIZE;
|
| 38 |
-
int warp = threadIdx.x / WARP_SIZE;
|
| 39 |
-
int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
|
| 40 |
-
|
| 41 |
-
v = warp_reduce_sum2(v);
|
| 42 |
-
if (lane == 0) {
|
| 43 |
-
warp_x[warp] = v.x;
|
| 44 |
-
warp_y[warp] = v.y;
|
| 45 |
-
}
|
| 46 |
-
__syncthreads();
|
| 47 |
-
|
| 48 |
-
if (warp == 0) {
|
| 49 |
-
v.x = (lane < num_warps) ? warp_x[lane] : 0.0f;
|
| 50 |
-
v.y = (lane < num_warps) ? warp_y[lane] : 0.0f;
|
| 51 |
-
v = warp_reduce_sum2(v);
|
| 52 |
}
|
| 53 |
-
|
| 54 |
-
}
|
| 55 |
-
|
| 56 |
-
__device__ __forceinline__ float block_reduce_sum(float val) {
|
| 57 |
-
__shared__ float warp_sums[32];
|
| 58 |
-
int lane = threadIdx.x % WARP_SIZE;
|
| 59 |
-
int warp = threadIdx.x / WARP_SIZE;
|
| 60 |
-
int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
|
| 61 |
-
|
| 62 |
-
val = warp_reduce_sum(val);
|
| 63 |
-
if (lane == 0)
|
| 64 |
-
warp_sums[warp] = val;
|
| 65 |
-
__syncthreads();
|
| 66 |
-
|
| 67 |
-
val = (lane < num_warps) ? warp_sums[lane] : 0.0f;
|
| 68 |
-
if (warp == 0)
|
| 69 |
-
val = warp_reduce_sum(val);
|
| 70 |
-
return val;
|
| 71 |
-
}
|
| 72 |
|
| 73 |
// ---------------------------------------------------------------------------
|
| 74 |
// Forward (dim ≤ 2048): single-pass with register caching
|
|
@@ -78,7 +28,8 @@ __device__ __forceinline__ float block_reduce_sum(float val) {
|
|
| 78 |
// Pass 2: write output from cache (no second global read)
|
| 79 |
// Also writes inv_rms[token] for backward
|
| 80 |
// ---------------------------------------------------------------------------
|
| 81 |
-
template <typename scalar_t, typename acc_t, int width, int NVECS
|
|
|
|
| 82 |
__global__ void rms_norm_fwd_small(scalar_t *__restrict__ out,
|
| 83 |
acc_t *__restrict__ inv_rms_out,
|
| 84 |
const scalar_t *__restrict__ input,
|
|
@@ -103,7 +54,11 @@ __global__ void rms_norm_fwd_small(scalar_t *__restrict__ out,
|
|
| 103 |
}
|
| 104 |
}
|
| 105 |
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
__shared__ acc_t s_scale;
|
| 109 |
if (threadIdx.x == 0) {
|
|
@@ -159,7 +114,11 @@ __global__ void rms_norm_fwd_large(scalar_t *__restrict__ out,
|
|
| 159 |
}
|
| 160 |
}
|
| 161 |
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
__shared__ acc_t s_scale;
|
| 165 |
if (threadIdx.x == 0) {
|
|
@@ -203,7 +162,11 @@ __global__ void rms_norm_fwd_scalar(scalar_t *__restrict__ out,
|
|
| 203 |
sum_square += x * x;
|
| 204 |
}
|
| 205 |
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
__shared__ acc_t s_scale;
|
| 209 |
if (threadIdx.x == 0) {
|
|
@@ -264,7 +227,11 @@ __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_input_grad(
|
|
| 264 |
}
|
| 265 |
}
|
| 266 |
|
| 267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
__shared__ acc_t s_dxx;
|
| 270 |
if (threadIdx.x == 0) {
|
|
@@ -331,6 +298,13 @@ __global__ void rms_norm_bwd_fused(
|
|
| 331 |
int64_t token_start = static_cast<int64_t>(blockIdx.x) * tpb;
|
| 332 |
int64_t token_end = min(token_start + tpb, num_tokens);
|
| 333 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
// Process tokens in pairs with float2 reduction
|
| 335 |
int64_t t = token_start;
|
| 336 |
for (; t + 1 < token_end; t += 2) {
|
|
@@ -355,7 +329,9 @@ __global__ void rms_norm_bwd_fused(
|
|
| 355 |
}
|
| 356 |
}
|
| 357 |
|
| 358 |
-
float2 sums =
|
|
|
|
|
|
|
| 359 |
|
| 360 |
// dxx = d_sum * scale^3 / d
|
| 361 |
__shared__ acc_t sd0, sd1;
|
|
@@ -405,7 +381,7 @@ __global__ void rms_norm_bwd_fused(
|
|
| 405 |
static_cast<acc_t>(x_vec.data[i]) *
|
| 406 |
static_cast<acc_t>(w_vec.data[i]);
|
| 407 |
}
|
| 408 |
-
d_sum =
|
| 409 |
|
| 410 |
__shared__ acc_t s_dxx;
|
| 411 |
if (threadIdx.x == 0)
|
|
@@ -520,7 +496,11 @@ __global__ void rms_norm_bwd_scalar(scalar_t *__restrict__ input_grad,
|
|
| 520 |
d_sum += dy * x * w;
|
| 521 |
}
|
| 522 |
|
| 523 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
|
| 525 |
__shared__ acc_t s_dxx;
|
| 526 |
if (threadIdx.x == 0) {
|
|
@@ -574,15 +554,26 @@ rms_norm(const torch::Tensor &input, // [..., d]
|
|
| 574 |
// Single-pass: 1 vec per thread (dim <= 2048)
|
| 575 |
int block_size = ((vec_d + 31) / 32) * 32;
|
| 576 |
block_size = std::max(block_size, 32);
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 586 |
} else {
|
| 587 |
// Large dims: 2-pass
|
| 588 |
dim3 block(block_2pass);
|
|
@@ -619,23 +610,23 @@ rms_norm(const torch::Tensor &input, // [..., d]
|
|
| 619 |
|
| 620 |
std::tuple<torch::Tensor, torch::Tensor>
|
| 621 |
rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
|
| 622 |
-
const torch::Tensor &
|
| 623 |
-
const torch::Tensor &weight,
|
| 624 |
-
const torch::Tensor &inv_rms,
|
| 625 |
double eps) {
|
| 626 |
-
torch::Tensor input_grad = torch::empty_like(
|
| 627 |
torch::Tensor weight_grad = torch::empty_like(weight);
|
| 628 |
|
| 629 |
AssertTensorContiguous(output_grad, "output_grad");
|
| 630 |
-
AssertTensorContiguous(
|
| 631 |
AssertTensorContiguous(weight, "weight");
|
| 632 |
AssertTensorContiguous(inv_rms, "inv_rms");
|
| 633 |
|
| 634 |
-
int d =
|
| 635 |
-
int64_t num_tokens =
|
| 636 |
|
| 637 |
dim3 grid(num_tokens);
|
| 638 |
-
const at::cuda::OptionalCUDAGuard device_guard(device_of(
|
| 639 |
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 640 |
|
| 641 |
if (d % 8 == 0) {
|
|
@@ -648,20 +639,20 @@ rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
|
|
| 648 |
int64_t num_blocks_mt = (num_tokens + tpb - 1) / tpb;
|
| 649 |
|
| 650 |
torch::Tensor wg_acc =
|
| 651 |
-
torch::zeros({d},
|
| 652 |
|
| 653 |
int block_size = std::min(vec_d, 128);
|
| 654 |
block_size = ((block_size + 31) / 32) * 32;
|
| 655 |
block_size = std::max(block_size, 32);
|
| 656 |
size_t smem = d * sizeof(float);
|
| 657 |
-
//
|
| 658 |
MOTIF_DISPATCH_FLOATING_TYPES(
|
| 659 |
-
|
| 660 |
motif::rms_norm_bwd_fused<scalar_t, float, 8>
|
| 661 |
<<<num_blocks_mt, block_size, smem, stream>>>(
|
| 662 |
input_grad.data_ptr<scalar_t>(), wg_acc.data_ptr<float>(),
|
| 663 |
output_grad.data_ptr<scalar_t>(),
|
| 664 |
-
|
| 665 |
inv_rms.data_ptr<float>(), d, num_tokens, tpb);
|
| 666 |
});
|
| 667 |
|
|
@@ -671,20 +662,20 @@ rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
|
|
| 671 |
} else {
|
| 672 |
// Large dims (d > 8192): input-based bwd + column-parallel weight grad
|
| 673 |
MOTIF_DISPATCH_FLOATING_TYPES(
|
| 674 |
-
|
| 675 |
motif::rms_norm_bwd_large_input_grad<scalar_t, float, 8>
|
| 676 |
<<<grid, dim3(256), 0, stream>>>(
|
| 677 |
input_grad.data_ptr<scalar_t>(),
|
| 678 |
output_grad.data_ptr<scalar_t>(),
|
| 679 |
-
|
| 680 |
inv_rms.data_ptr<float>(), d);
|
| 681 |
});
|
| 682 |
|
| 683 |
if (weight_grad.defined()) {
|
| 684 |
int64_t chunk_size = 256;
|
| 685 |
int64_t num_chunks = (num_tokens + chunk_size - 1) / chunk_size;
|
| 686 |
-
torch::Tensor partial_wg =
|
| 687 |
-
{num_chunks, d},
|
| 688 |
|
| 689 |
constexpr int TILE_T = 64;
|
| 690 |
constexpr int VEC_W = 8;
|
|
@@ -692,13 +683,13 @@ rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
|
|
| 692 |
int cols_per_block = wg_threads * VEC_W;
|
| 693 |
dim3 wg_grid((d + cols_per_block - 1) / cols_per_block, num_chunks);
|
| 694 |
MOTIF_DISPATCH_FLOATING_TYPES(
|
| 695 |
-
|
| 696 |
motif::rms_norm_bwd_large_weight_grad<scalar_t, float, TILE_T,
|
| 697 |
VEC_W>
|
| 698 |
<<<wg_grid, wg_threads, 0, stream>>>(
|
| 699 |
partial_wg.data_ptr<float>(),
|
| 700 |
output_grad.data_ptr<scalar_t>(),
|
| 701 |
-
|
| 702 |
inv_rms.data_ptr<float>(), d, num_tokens, chunk_size);
|
| 703 |
});
|
| 704 |
|
|
@@ -712,19 +703,19 @@ rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
|
|
| 712 |
// Scalar fallback: temp buffer + at::sum_out (still uses input-based for
|
| 713 |
// scalar)
|
| 714 |
torch::Tensor temp_weight_grad =
|
| 715 |
-
torch::empty({num_tokens, d},
|
| 716 |
|
| 717 |
int block_size = std::min(d, 256);
|
| 718 |
block_size = ((block_size + 31) / 32) * 32;
|
| 719 |
block_size = std::max(block_size, 32);
|
| 720 |
dim3 block(block_size);
|
| 721 |
MOTIF_DISPATCH_FLOATING_TYPES(
|
| 722 |
-
|
| 723 |
motif::rms_norm_bwd_scalar<scalar_t, float>
|
| 724 |
<<<grid, block, 0, stream>>>(
|
| 725 |
input_grad.data_ptr<scalar_t>(),
|
| 726 |
temp_weight_grad.data_ptr<float>(),
|
| 727 |
-
output_grad.data_ptr<scalar_t>(),
|
| 728 |
weight.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(), d);
|
| 729 |
});
|
| 730 |
|
|
|
|
| 13 |
type data[N];
|
| 14 |
};
|
| 15 |
|
| 16 |
+
// Float2 sum operator for CUB BlockReduce
|
| 17 |
+
struct Float2SumOp {
|
| 18 |
+
__device__ float2 operator()(const float2 &a, const float2 &b) const {
|
| 19 |
+
return make_float2(a.x + b.x, a.y + b.y);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
}
|
| 21 |
+
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
// ---------------------------------------------------------------------------
|
| 24 |
// Forward (dim ≤ 2048): single-pass with register caching
|
|
|
|
| 28 |
// Pass 2: write output from cache (no second global read)
|
| 29 |
// Also writes inv_rms[token] for backward
|
| 30 |
// ---------------------------------------------------------------------------
|
| 31 |
+
template <typename scalar_t, typename acc_t, int width, int NVECS,
|
| 32 |
+
int BLOCK_SIZE = 256>
|
| 33 |
__global__ void rms_norm_fwd_small(scalar_t *__restrict__ out,
|
| 34 |
acc_t *__restrict__ inv_rms_out,
|
| 35 |
const scalar_t *__restrict__ input,
|
|
|
|
| 54 |
}
|
| 55 |
}
|
| 56 |
|
| 57 |
+
{
|
| 58 |
+
using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
|
| 59 |
+
__shared__ typename BlockReduce::TempStorage reduceStore;
|
| 60 |
+
sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x);
|
| 61 |
+
}
|
| 62 |
|
| 63 |
__shared__ acc_t s_scale;
|
| 64 |
if (threadIdx.x == 0) {
|
|
|
|
| 114 |
}
|
| 115 |
}
|
| 116 |
|
| 117 |
+
{
|
| 118 |
+
using BlockReduce = cub::BlockReduce<float, 256>;
|
| 119 |
+
__shared__ typename BlockReduce::TempStorage reduceStore;
|
| 120 |
+
sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x);
|
| 121 |
+
}
|
| 122 |
|
| 123 |
__shared__ acc_t s_scale;
|
| 124 |
if (threadIdx.x == 0) {
|
|
|
|
| 162 |
sum_square += x * x;
|
| 163 |
}
|
| 164 |
|
| 165 |
+
{
|
| 166 |
+
using BlockReduce = cub::BlockReduce<float, 256>;
|
| 167 |
+
__shared__ typename BlockReduce::TempStorage reduceStore;
|
| 168 |
+
sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x);
|
| 169 |
+
}
|
| 170 |
|
| 171 |
__shared__ acc_t s_scale;
|
| 172 |
if (threadIdx.x == 0) {
|
|
|
|
| 227 |
}
|
| 228 |
}
|
| 229 |
|
| 230 |
+
{
|
| 231 |
+
using BlockReduce = cub::BlockReduce<float, 256>;
|
| 232 |
+
__shared__ typename BlockReduce::TempStorage reduceStore;
|
| 233 |
+
d_sum = BlockReduce(reduceStore).Sum(d_sum, blockDim.x);
|
| 234 |
+
}
|
| 235 |
|
| 236 |
__shared__ acc_t s_dxx;
|
| 237 |
if (threadIdx.x == 0) {
|
|
|
|
| 298 |
int64_t token_start = static_cast<int64_t>(blockIdx.x) * tpb;
|
| 299 |
int64_t token_end = min(token_start + tpb, num_tokens);
|
| 300 |
|
| 301 |
+
// Shared TempStorage for CUB block reductions — declared once, reused per
|
| 302 |
+
// iteration
|
| 303 |
+
using BlockReduce2 = cub::BlockReduce<float2, 256>;
|
| 304 |
+
__shared__ typename BlockReduce2::TempStorage reduceStore2;
|
| 305 |
+
using BlockReduce = cub::BlockReduce<float, 256>;
|
| 306 |
+
__shared__ typename BlockReduce::TempStorage reduceStore;
|
| 307 |
+
|
| 308 |
// Process tokens in pairs with float2 reduction
|
| 309 |
int64_t t = token_start;
|
| 310 |
for (; t + 1 < token_end; t += 2) {
|
|
|
|
| 329 |
}
|
| 330 |
}
|
| 331 |
|
| 332 |
+
float2 sums =
|
| 333 |
+
BlockReduce2(reduceStore2)
|
| 334 |
+
.Reduce(make_float2(dsum0, dsum1), Float2SumOp{}, blockDim.x);
|
| 335 |
|
| 336 |
// dxx = d_sum * scale^3 / d
|
| 337 |
__shared__ acc_t sd0, sd1;
|
|
|
|
| 381 |
static_cast<acc_t>(x_vec.data[i]) *
|
| 382 |
static_cast<acc_t>(w_vec.data[i]);
|
| 383 |
}
|
| 384 |
+
d_sum = BlockReduce(reduceStore).Sum(d_sum, blockDim.x);
|
| 385 |
|
| 386 |
__shared__ acc_t s_dxx;
|
| 387 |
if (threadIdx.x == 0)
|
|
|
|
| 496 |
d_sum += dy * x * w;
|
| 497 |
}
|
| 498 |
|
| 499 |
+
{
|
| 500 |
+
using BlockReduce = cub::BlockReduce<float, 256>;
|
| 501 |
+
__shared__ typename BlockReduce::TempStorage reduceStore;
|
| 502 |
+
d_sum = BlockReduce(reduceStore).Sum(d_sum, blockDim.x);
|
| 503 |
+
}
|
| 504 |
|
| 505 |
__shared__ acc_t s_dxx;
|
| 506 |
if (threadIdx.x == 0) {
|
|
|
|
| 554 |
// Single-pass: 1 vec per thread (dim <= 2048)
|
| 555 |
int block_size = ((vec_d + 31) / 32) * 32;
|
| 556 |
block_size = std::max(block_size, 32);
|
| 557 |
+
|
| 558 |
+
#define LAUNCH_FWD_SMALL(BS) \
|
| 559 |
+
MOTIF_DISPATCH_FLOATING_TYPES( \
|
| 560 |
+
input.scalar_type(), "rms_norm_fwd_small", [&] { \
|
| 561 |
+
motif::rms_norm_fwd_small<scalar_t, float, 8, 1, BS> \
|
| 562 |
+
<<<grid, dim3(BS), 0, stream>>>( \
|
| 563 |
+
out.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(), \
|
| 564 |
+
input.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), eps, \
|
| 565 |
+
d); \
|
| 566 |
+
})
|
| 567 |
+
|
| 568 |
+
if (block_size <= 32)
|
| 569 |
+
LAUNCH_FWD_SMALL(32);
|
| 570 |
+
else if (block_size <= 64)
|
| 571 |
+
LAUNCH_FWD_SMALL(64);
|
| 572 |
+
else if (block_size <= 128)
|
| 573 |
+
LAUNCH_FWD_SMALL(128);
|
| 574 |
+
else
|
| 575 |
+
LAUNCH_FWD_SMALL(256);
|
| 576 |
+
#undef LAUNCH_FWD_SMALL
|
| 577 |
} else {
|
| 578 |
// Large dims: 2-pass
|
| 579 |
dim3 block(block_2pass);
|
|
|
|
| 610 |
|
| 611 |
std::tuple<torch::Tensor, torch::Tensor>
|
| 612 |
rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
|
| 613 |
+
const torch::Tensor &input, // [..., d]
|
| 614 |
+
const torch::Tensor &weight, // [d]
|
| 615 |
+
const torch::Tensor &inv_rms, // [num_tokens]
|
| 616 |
double eps) {
|
| 617 |
+
torch::Tensor input_grad = torch::empty_like(input);
|
| 618 |
torch::Tensor weight_grad = torch::empty_like(weight);
|
| 619 |
|
| 620 |
AssertTensorContiguous(output_grad, "output_grad");
|
| 621 |
+
AssertTensorContiguous(input, "input");
|
| 622 |
AssertTensorContiguous(weight, "weight");
|
| 623 |
AssertTensorContiguous(inv_rms, "inv_rms");
|
| 624 |
|
| 625 |
+
int d = input.size(-1);
|
| 626 |
+
int64_t num_tokens = input.numel() / input.size(-1);
|
| 627 |
|
| 628 |
dim3 grid(num_tokens);
|
| 629 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
| 630 |
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 631 |
|
| 632 |
if (d % 8 == 0) {
|
|
|
|
| 639 |
int64_t num_blocks_mt = (num_tokens + tpb - 1) / tpb;
|
| 640 |
|
| 641 |
torch::Tensor wg_acc =
|
| 642 |
+
torch::zeros({d}, input.options().dtype(torch::kFloat));
|
| 643 |
|
| 644 |
int block_size = std::min(vec_d, 128);
|
| 645 |
block_size = ((block_size + 31) / 32) * 32;
|
| 646 |
block_size = std::max(block_size, 32);
|
| 647 |
size_t smem = d * sizeof(float);
|
| 648 |
+
// Small/medium dims (d <= 8192): multi-token block with shared memory
|
| 649 |
MOTIF_DISPATCH_FLOATING_TYPES(
|
| 650 |
+
input.scalar_type(), "rms_norm_bwd_fused", [&] {
|
| 651 |
motif::rms_norm_bwd_fused<scalar_t, float, 8>
|
| 652 |
<<<num_blocks_mt, block_size, smem, stream>>>(
|
| 653 |
input_grad.data_ptr<scalar_t>(), wg_acc.data_ptr<float>(),
|
| 654 |
output_grad.data_ptr<scalar_t>(),
|
| 655 |
+
input.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
|
| 656 |
inv_rms.data_ptr<float>(), d, num_tokens, tpb);
|
| 657 |
});
|
| 658 |
|
|
|
|
| 662 |
} else {
|
| 663 |
// Large dims (d > 8192): input-based bwd + column-parallel weight grad
|
| 664 |
MOTIF_DISPATCH_FLOATING_TYPES(
|
| 665 |
+
input.scalar_type(), "rms_norm_bwd_large_input_grad", [&] {
|
| 666 |
motif::rms_norm_bwd_large_input_grad<scalar_t, float, 8>
|
| 667 |
<<<grid, dim3(256), 0, stream>>>(
|
| 668 |
input_grad.data_ptr<scalar_t>(),
|
| 669 |
output_grad.data_ptr<scalar_t>(),
|
| 670 |
+
input.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
|
| 671 |
inv_rms.data_ptr<float>(), d);
|
| 672 |
});
|
| 673 |
|
| 674 |
if (weight_grad.defined()) {
|
| 675 |
int64_t chunk_size = 256;
|
| 676 |
int64_t num_chunks = (num_tokens + chunk_size - 1) / chunk_size;
|
| 677 |
+
torch::Tensor partial_wg =
|
| 678 |
+
torch::empty({num_chunks, d}, input.options().dtype(torch::kFloat));
|
| 679 |
|
| 680 |
constexpr int TILE_T = 64;
|
| 681 |
constexpr int VEC_W = 8;
|
|
|
|
| 683 |
int cols_per_block = wg_threads * VEC_W;
|
| 684 |
dim3 wg_grid((d + cols_per_block - 1) / cols_per_block, num_chunks);
|
| 685 |
MOTIF_DISPATCH_FLOATING_TYPES(
|
| 686 |
+
input.scalar_type(), "rms_norm_bwd_large_weight_grad", [&] {
|
| 687 |
motif::rms_norm_bwd_large_weight_grad<scalar_t, float, TILE_T,
|
| 688 |
VEC_W>
|
| 689 |
<<<wg_grid, wg_threads, 0, stream>>>(
|
| 690 |
partial_wg.data_ptr<float>(),
|
| 691 |
output_grad.data_ptr<scalar_t>(),
|
| 692 |
+
input.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
|
| 693 |
inv_rms.data_ptr<float>(), d, num_tokens, chunk_size);
|
| 694 |
});
|
| 695 |
|
|
|
|
| 703 |
// Scalar fallback: temp buffer + at::sum_out (still uses input-based for
|
| 704 |
// scalar)
|
| 705 |
torch::Tensor temp_weight_grad =
|
| 706 |
+
torch::empty({num_tokens, d}, input.options().dtype(torch::kFloat));
|
| 707 |
|
| 708 |
int block_size = std::min(d, 256);
|
| 709 |
block_size = ((block_size + 31) / 32) * 32;
|
| 710 |
block_size = std::max(block_size, 32);
|
| 711 |
dim3 block(block_size);
|
| 712 |
MOTIF_DISPATCH_FLOATING_TYPES(
|
| 713 |
+
input.scalar_type(), "rms_norm_bwd_scalar", [&] {
|
| 714 |
motif::rms_norm_bwd_scalar<scalar_t, float>
|
| 715 |
<<<grid, block, 0, stream>>>(
|
| 716 |
input_grad.data_ptr<scalar_t>(),
|
| 717 |
temp_weight_grad.data_ptr<float>(),
|
| 718 |
+
output_grad.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
|
| 719 |
weight.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(), d);
|
| 720 |
});
|
| 721 |
|
torch-ext/torch_binding.cpp
CHANGED
|
@@ -20,7 +20,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|
| 20 |
"rms_norm(Tensor input, Tensor weight, float eps) -> (Tensor, Tensor)");
|
| 21 |
ops.impl("rms_norm", torch::kCUDA, &rms_norm);
|
| 22 |
|
| 23 |
-
ops.def("rms_norm_backward(Tensor output_grad, Tensor
|
| 24 |
"Tensor inv_rms, float eps) -> (Tensor, Tensor)");
|
| 25 |
ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward);
|
| 26 |
|
|
|
|
| 20 |
"rms_norm(Tensor input, Tensor weight, float eps) -> (Tensor, Tensor)");
|
| 21 |
ops.impl("rms_norm", torch::kCUDA, &rms_norm);
|
| 22 |
|
| 23 |
+
ops.def("rms_norm_backward(Tensor output_grad, Tensor input, Tensor weight, "
|
| 24 |
"Tensor inv_rms, float eps) -> (Tensor, Tensor)");
|
| 25 |
ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward);
|
| 26 |
|
torch-ext/torch_binding.h
CHANGED
|
@@ -14,7 +14,7 @@ void poly_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
|
|
| 14 |
std::tuple<torch::Tensor, torch::Tensor>
|
| 15 |
rms_norm(const torch::Tensor &input, const torch::Tensor &weights, double eps);
|
| 16 |
std::tuple<torch::Tensor, torch::Tensor>
|
| 17 |
-
rms_norm_backward(const torch::Tensor &output_grad, const torch::Tensor &
|
| 18 |
const torch::Tensor &weight, const torch::Tensor &inv_rms,
|
| 19 |
double eps);
|
| 20 |
|
|
|
|
| 14 |
std::tuple<torch::Tensor, torch::Tensor>
|
| 15 |
rms_norm(const torch::Tensor &input, const torch::Tensor &weights, double eps);
|
| 16 |
std::tuple<torch::Tensor, torch::Tensor>
|
| 17 |
+
rms_norm_backward(const torch::Tensor &output_grad, const torch::Tensor &input,
|
| 18 |
const torch::Tensor &weight, const torch::Tensor &inv_rms,
|
| 19 |
double eps);
|
| 20 |
|