fix: unify all backward kernels to input-based math + fix test import
- Always save input in setup_context (no output-based path)
- bwd_fused: input-based math (d_sum = sum(dy*x*w), grad = scale*dy*w - dxx*x)
- bwd_large_input_grad: input-based math
- bwd_large_weight_grad: input-based (weight_grad = dy*x*scale, needs inv_rms)
- bwd_scalar: input-based (unchanged)
- pytest.ini: add pythonpath to prevent activation/ CUDA dir namespace shadow
Eliminates bf16 precision loss from output-based x recovery (y/(w*scale)).
All 48 test_rms_norm tests pass. No performance regression.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- activation/rms_norm.cu +81 -94
- tests/pytest.ini +1 -0
- torch-ext/activation/rms_norm.py +1 -7
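For reference, the input-based formulas in the bullets above (with scale = inv_rms = 1/sqrt(mean(x^2) + eps)) can be checked against autograd with a small PyTorch sketch. Shapes and the eps value here are illustrative only, not taken from the kernels:

import torch

def rms_norm_bwd_reference(dy, x, w, eps=1e-6):
    # Input-based backward math from the commit message; the kernels accumulate
    # in fp32, here we simply follow the tensor dtype.
    d = x.shape[-1]
    scale = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)  # inv_rms
    d_sum = (dy * x * w).sum(-1, keepdim=True)                  # d_sum = sum(dy * x * w)
    dxx = d_sum * scale.pow(3) / d                              # dxx = d_sum * scale^3 / d
    input_grad = scale * dy * w - dxx * x                       # scale*dy*w - dxx*x
    weight_grad = (dy * x * scale).sum(0)                       # dy*x*scale summed over tokens
    return input_grad, weight_grad

x = torch.randn(4, 64, dtype=torch.float64, requires_grad=True)
w = torch.randn(64, dtype=torch.float64, requires_grad=True)
y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * w
dy = torch.randn_like(y)
gx, gw = torch.autograd.grad(y, (x, w), dy)
rx, rw = rms_norm_bwd_reference(dy.detach(), x.detach(), w.detach())
print(torch.allclose(gx, rx), torch.allclose(gw, rw))  # True True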
activation/rms_norm.cu
CHANGED

@@ -223,20 +223,17 @@ __global__ void rms_norm_fwd_scalar(scalar_t *__restrict__ out,
 // Backward: single-pass with register caching (NVECS vecs per thread)
 // ---------------------------------------------------------------------------
 // ---------------------------------------------------------------------------
-// Backward input_grad (dim > 8192):
+// Backward input_grad (dim > 8192): input-based, 2-pass
 // Architecture: 1 block/token, 256 threads, __launch_bounds__(256, 4)
-//
-//
-// input_grad = scale*dy*w - dxx*y/(w*scale)
-// Pass 1: read dy + y → d_sum (warp shuffle reduction)
-// Pass 2: read dy + y + w (L1 hit) → write input_grad
+// Pass 1: read dy + x + w → d_sum = sum(dy * x * w)
+// Pass 2: read dy + x + w (L1 hit) → input_grad = scale*dy*w - dxx*x
 // Weight grad computed by separate bwd_large_weight_grad kernel
 // ---------------------------------------------------------------------------
 template <typename scalar_t, typename acc_t, int width>
 __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_input_grad(
     scalar_t *__restrict__ input_grad,
     const scalar_t *__restrict__ output_grad, // dy
-    const scalar_t *__restrict__
+    const scalar_t *__restrict__ input, // x
     const scalar_t *__restrict__ weight, const acc_t *__restrict__ inv_rms,
     const int d) {
   using vec_t = type_vec_t<scalar_t, width>;
@@ -245,8 +242,7 @@ __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_input_grad(
   const int vec_d = d / width;
   const int64_t vec_offset = token_idx * vec_d;
 
-  const vec_t *__restrict__
-      reinterpret_cast<const vec_t *>(output);
+  const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
   const vec_t *__restrict__ output_grad_vec =
       reinterpret_cast<const vec_t *>(output_grad);
   const vec_t *__restrict__ weight_vec =
@@ -254,46 +250,44 @@ __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_input_grad(
 
   acc_t scale = inv_rms[token_idx];
 
-  // Pass 1: d_sum =
-  acc_t
+  // Pass 1: d_sum = sum(dy * x * w)
+  acc_t d_sum = 0.0f;
   for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
-    vec_t
+    vec_t x_vec = input_vec[vec_offset + vidx];
     vec_t dy_vec = output_grad_vec[vec_offset + vidx];
+    vec_t w_vec = weight_vec[vidx];
 #pragma unroll
     for (int i = 0; i < width; ++i) {
-
-
+      d_sum += static_cast<acc_t>(dy_vec.data[i]) *
+               static_cast<acc_t>(x_vec.data[i]) *
+               static_cast<acc_t>(w_vec.data[i]);
     }
   }
 
-
+  d_sum = block_reduce_sum(d_sum);
 
-  // d_sum = dy_y_sum / scale, dxx = d_sum * scale^3 / d = dy_y_sum * scale^2 /
-  // d
   __shared__ acc_t s_dxx;
   if (threadIdx.x == 0) {
-    s_dxx =
+    s_dxx = d_sum * scale * scale * scale / d;
   }
   __syncthreads();
   acc_t dxx = s_dxx;
 
-  // Pass 2: input_grad = scale * dy * w - dxx *
+  // Pass 2: input_grad = scale * dy * w - dxx * x
   vec_t *__restrict__ input_grad_vec = reinterpret_cast<vec_t *>(input_grad);
 
   for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
-    vec_t
+    vec_t x_vec = input_vec[vec_offset + vidx];
     vec_t dy_vec = output_grad_vec[vec_offset + vidx];
     vec_t w_vec = weight_vec[vidx];
 
     vec_t in_grad_vec;
 #pragma unroll
     for (int i = 0; i < width; ++i) {
-      acc_t
+      acc_t x = x_vec.data[i];
       acc_t dy = dy_vec.data[i];
       acc_t w = w_vec.data[i];
-
-      // dxx*y/(w*scale)
-      in_grad_vec.data[i] = scale * dy * w - dxx * y / (w * scale);
+      in_grad_vec.data[i] = scale * dy * w - dxx * x;
     }
     input_grad_vec[vec_offset + vidx] = in_grad_vec;
   }
@@ -302,8 +296,7 @@ __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_input_grad(
 // ---------------------------------------------------------------------------
 // Backward fused (dim ≤ 8192): multi-token block, input_grad + weight_grad
 // Architecture: 32 tokens/block, 128 threads, dynamic shared mem [d] fp32
-//
-// d_sum = sum(dy * x * w), input_grad = scale*dy*w - dxx*x
+// Input-based math: d_sum = sum(dy * x * w), input_grad = scale*dy*w - dxx*x
 // 2-token batched float2 reduction (syncthreads halved)
 // Weight grad accumulated in shared memory → atomicAdd to global
 // Single kernel: no separate weight_grad kernel needed
@@ -313,13 +306,13 @@ __global__ void rms_norm_bwd_fused(
     scalar_t *__restrict__ input_grad,
     acc_t *__restrict__ weight_grad_acc, // [d] — atomicAdd target
     const scalar_t *__restrict__ output_grad,
-    const scalar_t *__restrict__
-    const
-    const int
+    const scalar_t *__restrict__ input, const scalar_t *__restrict__ weight,
+    const acc_t *__restrict__ inv_rms, const int d, const int64_t num_tokens,
+    const int tpb) {
 
   using vec_t = type_vec_t<scalar_t, width>;
 
-  extern __shared__ acc_t wg_shared[]; // [d] float32 — accumulates
+  extern __shared__ acc_t wg_shared[]; // [d] float32 — accumulates weight grad
 
   const int vec_d = d / width;
 
@@ -328,8 +321,7 @@ __global__ void rms_norm_bwd_fused(
     wg_shared[idx] = 0.0f;
   __syncthreads();
 
-  const vec_t *__restrict__
-      reinterpret_cast<const vec_t *>(output);
+  const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
   const vec_t *__restrict__ output_grad_vec =
       reinterpret_cast<const vec_t *>(output_grad);
   const vec_t *__restrict__ weight_vec =
@@ -345,50 +337,52 @@ __global__ void rms_norm_bwd_fused(
     acc_t s0 = inv_rms[t], s1 = inv_rms[t + 1];
     int64_t off0 = t * vec_d, off1 = (t + 1) * vec_d;
 
-    // Pass 1:
-    acc_t
+    // Pass 1: d_sum = sum(dy * x * w)
+    acc_t dsum0 = 0.0f, dsum1 = 0.0f;
     for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
-      vec_t
+      vec_t x0 = input_vec[off0 + vidx];
      vec_t dy0 = output_grad_vec[off0 + vidx];
-      vec_t
+      vec_t x1 = input_vec[off1 + vidx];
       vec_t dy1 = output_grad_vec[off1 + vidx];
+      vec_t w = weight_vec[vidx];
 #pragma unroll
       for (int i = 0; i < width; ++i) {
-
-
-
-
+        acc_t wi = w.data[i];
+        dsum0 += static_cast<acc_t>(dy0.data[i]) *
+                 static_cast<acc_t>(x0.data[i]) * wi;
+        dsum1 += static_cast<acc_t>(dy1.data[i]) *
+                 static_cast<acc_t>(x1.data[i]) * wi;
       }
     }
 
-    float2 sums = block_reduce_sum2(make_float2(
+    float2 sums = block_reduce_sum2(make_float2(dsum0, dsum1));
 
-    // dxx =
+    // dxx = d_sum * scale^3 / d
     __shared__ acc_t sd0, sd1;
     if (threadIdx.x == 0) {
-      sd0 = sums.x * s0 * s0 / d;
-      sd1 = sums.y * s1 * s1 / d;
+      sd0 = sums.x * s0 * s0 * s0 / d;
+      sd1 = sums.y * s1 * s1 * s1 / d;
     }
     __syncthreads();
     acc_t dxx0 = sd0, dxx1 = sd1;
 
-    // Pass 2: input_grad + wg_shared accumulate (L1 hit on
+    // Pass 2: input_grad + wg_shared accumulate (L1 hit on x, dy)
     for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
-      vec_t
+      vec_t x0 = input_vec[off0 + vidx];
       vec_t dy0 = output_grad_vec[off0 + vidx];
-      vec_t
+      vec_t x1 = input_vec[off1 + vidx];
       vec_t dy1 = output_grad_vec[off1 + vidx];
       vec_t w = weight_vec[vidx];
 
       vec_t g0, g1;
 #pragma unroll
       for (int i = 0; i < width; ++i) {
-        acc_t
-        acc_t
-        g0.data[i] = s0 * di0 * wi - dxx0 *
-        g1.data[i] = s1 * di1 * wi - dxx1 *
-        //
-        wg_shared[vidx * width + i] += di0 *
+        acc_t xi0 = x0.data[i], di0 = dy0.data[i], wi = w.data[i];
+        acc_t xi1 = x1.data[i], di1 = dy1.data[i];
+        g0.data[i] = s0 * di0 * wi - dxx0 * xi0;
+        g1.data[i] = s1 * di1 * wi - dxx1 * xi1;
+        // weight_grad = dy * x * scale
+        wg_shared[vidx * width + i] += di0 * xi0 * s0 + di1 * xi1 * s1;
       }
       input_grad_vec[off0 + vidx] = g0;
       input_grad_vec[off1 + vidx] = g1;
@@ -399,44 +393,44 @@ __global__ void rms_norm_bwd_fused(
   if (t < token_end) {
     acc_t scale = inv_rms[t];
     int64_t vec_offset = t * vec_d;
-    acc_t
+    acc_t d_sum = 0.0f;
 
     for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
-      vec_t
+      vec_t x_vec = input_vec[vec_offset + vidx];
      vec_t dy_vec = output_grad_vec[vec_offset + vidx];
+      vec_t w_vec = weight_vec[vidx];
 #pragma unroll
       for (int i = 0; i < width; ++i)
-
-
+        d_sum += static_cast<acc_t>(dy_vec.data[i]) *
+                 static_cast<acc_t>(x_vec.data[i]) *
+                 static_cast<acc_t>(w_vec.data[i]);
     }
-
+    d_sum = block_reduce_sum(d_sum);
 
     __shared__ acc_t s_dxx;
     if (threadIdx.x == 0)
-      s_dxx =
+      s_dxx = d_sum * scale * scale * scale / d;
     __syncthreads();
     acc_t dxx = s_dxx;
 
     for (int64_t vidx = threadIdx.x; vidx < vec_d; vidx += blockDim.x) {
-      vec_t
+      vec_t x_vec = input_vec[vec_offset + vidx];
       vec_t dy_vec = output_grad_vec[vec_offset + vidx];
-      vec_t w_vec = weight_vec[vidx];
       vec_t gv;
 #pragma unroll
       for (int i = 0; i < width; ++i) {
-        acc_t
-        gv.data[i] =
-
+        acc_t x = x_vec.data[i], dy = dy_vec.data[i];
+        gv.data[i] =
+            scale * dy * static_cast<acc_t>(weight_vec[vidx].data[i]) - dxx * x;
+        wg_shared[vidx * width + i] += dy * x * scale;
       }
       input_grad_vec[vec_offset + vidx] = gv;
     }
  }
 
-  // AtomicAdd accumulated weight grad
-  const scalar_t *__restrict__ w_ptr = weight;
+  // AtomicAdd accumulated weight grad
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
-
-    atomicAdd(&weight_grad_acc[idx], wg_shared[idx] / w);
+    atomicAdd(&weight_grad_acc[idx], wg_shared[idx]);
  }
 }
 
@@ -444,7 +438,7 @@ __global__ void rms_norm_bwd_fused(
 // Backward weight_grad (dim > 8192): column-parallel with vec8 loads
 // Architecture: 2D grid (col_blocks × token_chunks), 256 threads
 // Each thread handles 8 columns (vec8), iterates over chunk_size tokens
-//
+// Input-based: weight_grad_i = sum_t(dy * x * scale)
 // Writes partial_wg[chunk, d], host reduces with at::sum_out
 // __launch_bounds__(256, 4)
 // ---------------------------------------------------------------------------
@@ -452,11 +446,11 @@ template <typename scalar_t, typename acc_t, int TILE_T, int VEC_W>
 __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_weight_grad(
     acc_t *__restrict__ partial_wg, // [num_chunks, d]
     const scalar_t *__restrict__ output_grad, // [num_tokens, d]
-    const scalar_t *__restrict__
-    const scalar_t *__restrict__ weight,
+    const scalar_t *__restrict__ input, // [num_tokens, d]
+    const scalar_t *__restrict__ weight, // [d]
+    const acc_t *__restrict__ inv_rms, // [num_tokens]
     const int d, const int64_t num_tokens, const int64_t chunk_size) {
-  // weight_grad_i = (
-  // No inv_rms needed! Just dy, y, and w.
+  // weight_grad_i = sum_t(dy_t * x_t * scale_t)
 
   using vec_t = type_vec_t<scalar_t, VEC_W>;
 
@@ -465,19 +459,15 @@ __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_weight_grad(
   const int64_t token_start = blockIdx.y * chunk_size;
   const int64_t token_end = min(token_start + chunk_size, num_tokens);
 
-  const vec_t *__restrict__
-      reinterpret_cast<const vec_t *>(output);
+  const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
   const vec_t *__restrict__ grad_vec =
       reinterpret_cast<const vec_t *>(output_grad);
-  const vec_t *__restrict__ weight_vec =
-      reinterpret_cast<const vec_t *>(weight);
   const int vec_d = d / VEC_W;
 
-
-  acc_t dy_y_acc[VEC_W];
+  acc_t wg_acc[VEC_W];
 #pragma unroll
   for (int i = 0; i < VEC_W; i++)
-
+    wg_acc[i] = 0.0f;
 
   int vec_col = blockIdx.x * blockDim.x + threadIdx.x;
 
@@ -487,23 +477,22 @@ __global__ __launch_bounds__(256, 4) void rms_norm_bwd_large_weight_grad(
     if (vec_col < vec_d) {
       for (int r = 0; r < tile_size; r++) {
         int64_t t = t_base + r;
-
+        acc_t scale = inv_rms[t];
+        vec_t x_v = input_vec[t * vec_d + vec_col];
         vec_t dy_v = grad_vec[t * vec_d + vec_col];
 #pragma unroll
         for (int i = 0; i < VEC_W; i++) {
-
-
+          wg_acc[i] += static_cast<acc_t>(dy_v.data[i]) *
+                       static_cast<acc_t>(x_v.data[i]) * scale;
         }
       }
     }
   }
 
-  // Write results
+  // Write results directly
   if (vec_col < vec_d) {
-    vec_t w_v = weight_vec[vec_col];
     for (int i = 0; i < VEC_W; i++) {
-
-      partial_wg[blockIdx.y * d + col_base + i] = dy_y_acc[i] / w;
+      partial_wg[blockIdx.y * d + col_base + i] = wg_acc[i];
     }
   }
 }
@@ -665,24 +654,22 @@ rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
     block_size = ((block_size + 31) / 32) * 32;
     block_size = std::max(block_size, 32);
     size_t smem = d * sizeof(float);
-    //
-    // autograd)
+    // 'output' C++ arg receives input (saved by Python autograd)
     MOTIF_DISPATCH_FLOATING_TYPES(
         output.scalar_type(), "rms_norm_bwd_fused", [&] {
           motif::rms_norm_bwd_fused<scalar_t, float, 8>
              <<<num_blocks_mt, block_size, smem, stream>>>(
                  input_grad.data_ptr<scalar_t>(), wg_acc.data_ptr<float>(),
                  output_grad.data_ptr<scalar_t>(),
-                 output.data_ptr<scalar_t>(),
-
-                 num_tokens, tpb);
+                 output.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
+                 inv_rms.data_ptr<float>(), d, num_tokens, tpb);
         });
 
     if (weight_grad.defined()) {
       weight_grad.copy_(wg_acc);
    }
  } else {
-    // Large dims (d > 8192):
+    // Large dims (d > 8192): input-based bwd + column-parallel weight grad
    MOTIF_DISPATCH_FLOATING_TYPES(
        output.scalar_type(), "rms_norm_bwd_large_input_grad", [&] {
          motif::rms_norm_bwd_large_input_grad<scalar_t, float, 8>
@@ -712,7 +699,7 @@ rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
                  partial_wg.data_ptr<float>(),
                  output_grad.data_ptr<scalar_t>(),
                  output.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
-                 d, num_tokens, chunk_size);
+                 inv_rms.data_ptr<float>(), d, num_tokens, chunk_size);
        });
 
    torch::Tensor acc =
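The bf16 precision loss mentioned in the commit message comes from reconstructing x out of the stored output: y has already been rounded to bf16, so y / (w * scale) can never return the exact input, and it is undefined outright wherever w is zero. A small self-contained sketch (shapes and eps are illustrative, not taken from the kernels):

import torch

torch.manual_seed(0)
d = 4096
x = torch.randn(8, d, dtype=torch.bfloat16)        # activations already live in bf16
w = torch.randn(d, dtype=torch.bfloat16)           # rms_norm weight (gamma)
scale = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)  # inv_rms

y = (x.float() * scale * w.float()).to(torch.bfloat16)   # bf16 output stored by forward
x_recovered = y.float() / (w.float() * scale)             # old output-based recovery
print((x_recovered - x.float()).abs().max().item())       # nonzero: inherits y's bf16 rounding
print((x.to(torch.bfloat16).float() - x.float()).abs().max().item())  # 0.0: saved input is exact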
tests/pytest.ini
CHANGED

@@ -1,3 +1,4 @@
 [pytest]
 log_cli = true
 log_cli_level = INFO
+pythonpath = ../torch-ext
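The pythonpath entry works around an import-order problem: when pytest runs from the repository root, the bare activation/ directory holding the CUDA sources is importable as an empty namespace package and shadows the real torch-ext/activation package. A hypothetical check of which activation gets imported (the directory layout is an assumption based on the file list above):

# Without the pythonpath entry, the bare CUDA source dir can win the import:
import activation
print(activation.__file__)   # typically None -> namespace package (the activation/ source dir)
print(activation.__path__)   # points at activation/ (the .cu sources), not torch-ext/activation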
torch-ext/activation/rms_norm.py
CHANGED

@@ -15,13 +15,7 @@ class RMSNormFunction(torch.autograd.Function):
     def setup_context(ctx, inputs, outputs):
         input, weight, eps = inputs
         output, inv_rms = outputs
-
-        # Large dims: save output (output-based backward, less memory traffic)
-        # Small dims: save input (multitoken backward, avoids division overhead)
-        if d > 8192:
-            ctx.save_for_backward(output, weight, inv_rms)
-        else:
-            ctx.save_for_backward(input, weight, inv_rms)
+        ctx.save_for_backward(input, weight, inv_rms)
         ctx.eps = eps
 
     @staticmethod