feat: add grouped poly norm CUDA kernel with scores and hidden_clamp fusion
Browse filesHand-written CUDA kernel for GroupedFusedMulPolyNorm (MoE).
Fuses polynomial normalization, mul, scores multiplication, and
hidden_clamp into a single kernel launch per forward/backward.
- 4 kernel variants: fwd/bwd × vectorized(width=8)/scalar
- scores: nullable, always fp32, fused into fwd output and bwd gradients
- hidden_clamp: < 0 disabled, >= 0 clamps input/mul/output with correct
backward gradient masking (recompute output for mask, no extra memory)
- ROCm compatible (64-bit warp sync mask)
- C++ op registration for torch.compile (register_fake in Python layer)
- ~3x faster than torch.compile'd PyTorch reference on B200
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- activation/grouped_poly_norm.cu +657 -0
- build.toml +2 -0
- torch-ext/torch_binding.cpp +30 -0
- torch-ext/torch_binding.h +30 -0
activation/grouped_poly_norm.cu
ADDED
|
@@ -0,0 +1,657 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/Functions.h>
|
| 2 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 3 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 4 |
+
#include <torch/all.h>
|
| 5 |
+
|
| 6 |
+
#include "assert_utils.h"
|
| 7 |
+
#include "cuda_compat.h"
|
| 8 |
+
#include "dispatch_utils.h"
|
| 9 |
+
|
| 10 |
+
namespace motif {
|
| 11 |
+
|
| 12 |
+
template <typename type, int N> struct alignas(sizeof(type) * N) vec_t {
|
| 13 |
+
type data[N];
|
| 14 |
+
};
|
| 15 |
+
|
| 16 |
+
__device__ __forceinline__ int find_expert(const int32_t *__restrict__ offsets,
|
| 17 |
+
int num_experts, int row) {
|
| 18 |
+
int lo = 0, hi = num_experts;
|
| 19 |
+
#pragma unroll 6
|
| 20 |
+
for (int i = 0; i < 12; ++i) {
|
| 21 |
+
if (lo >= hi) break;
|
| 22 |
+
int mid = (lo + hi) >> 1;
|
| 23 |
+
if (offsets[mid] <= row)
|
| 24 |
+
lo = mid + 1;
|
| 25 |
+
else
|
| 26 |
+
hi = mid;
|
| 27 |
+
}
|
| 28 |
+
return lo;
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
__device__ __forceinline__ float4 warp_reduce_f4(float4 v) {
|
| 32 |
+
#ifndef USE_ROCM
|
| 33 |
+
constexpr unsigned int FULL_MASK = 0xffffffff;
|
| 34 |
+
#else
|
| 35 |
+
constexpr unsigned long long FULL_MASK = 0xffffffffffffffffULL;
|
| 36 |
+
#endif
|
| 37 |
+
#pragma unroll
|
| 38 |
+
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
| 39 |
+
v.x += __shfl_xor_sync(FULL_MASK, v.x, mask);
|
| 40 |
+
v.y += __shfl_xor_sync(FULL_MASK, v.y, mask);
|
| 41 |
+
v.z += __shfl_xor_sync(FULL_MASK, v.z, mask);
|
| 42 |
+
v.w += __shfl_xor_sync(FULL_MASK, v.w, mask);
|
| 43 |
+
}
|
| 44 |
+
return v;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
template <int BLOCK_SIZE>
|
| 48 |
+
__device__ __forceinline__ float4 block_reduce_f4(float4 v) {
|
| 49 |
+
constexpr int NUM_WARPS = BLOCK_SIZE / WARP_SIZE;
|
| 50 |
+
__shared__ float4 warp_results[NUM_WARPS];
|
| 51 |
+
|
| 52 |
+
v = warp_reduce_f4(v);
|
| 53 |
+
const int warp_id = threadIdx.x / WARP_SIZE;
|
| 54 |
+
const int lane_id = threadIdx.x % WARP_SIZE;
|
| 55 |
+
|
| 56 |
+
if (lane_id == 0) warp_results[warp_id] = v;
|
| 57 |
+
__syncthreads();
|
| 58 |
+
|
| 59 |
+
if (warp_id == 0 && lane_id < NUM_WARPS)
|
| 60 |
+
v = warp_results[lane_id];
|
| 61 |
+
else
|
| 62 |
+
v = make_float4(0.f, 0.f, 0.f, 0.f);
|
| 63 |
+
|
| 64 |
+
if (warp_id == 0) v = warp_reduce_f4(v);
|
| 65 |
+
|
| 66 |
+
__shared__ float4 result;
|
| 67 |
+
if (threadIdx.x == 0) result = v;
|
| 68 |
+
__syncthreads();
|
| 69 |
+
return result;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
// ---------------------------------------------------------------------------
|
| 73 |
+
// Grouped PolyNorm Forward — vectorized (width > 0)
|
| 74 |
+
// Pass 1: accumulate sum_x2, sum_x4, sum_x6 for RMS stats
|
| 75 |
+
// Pass 2: compute poly * mul output, save inv_rms for backward
|
| 76 |
+
// ---------------------------------------------------------------------------
|
| 77 |
+
template <typename scalar_t, typename acc_t, int width, int BLOCK_SIZE>
|
| 78 |
+
__global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
| 79 |
+
grouped_poly_norm_fwd_kernel(
|
| 80 |
+
scalar_t *__restrict__ output,
|
| 81 |
+
acc_t *__restrict__ inv_rms,
|
| 82 |
+
const scalar_t *__restrict__ input,
|
| 83 |
+
const scalar_t *__restrict__ mul,
|
| 84 |
+
const scalar_t *__restrict__ weight,
|
| 85 |
+
const scalar_t *__restrict__ bias,
|
| 86 |
+
const int32_t *__restrict__ offsets,
|
| 87 |
+
const float *__restrict__ scores, // nullable, always fp32
|
| 88 |
+
const acc_t eps, const int D, const int num_experts,
|
| 89 |
+
const int expert_offset,
|
| 90 |
+
const acc_t hidden_clamp) { // < 0 = disabled
|
| 91 |
+
using v_t = vec_t<scalar_t, width>;
|
| 92 |
+
const bool do_clamp = (hidden_clamp >= acc_t(0));
|
| 93 |
+
|
| 94 |
+
const int row = blockIdx.x;
|
| 95 |
+
const int vec_d = D / width;
|
| 96 |
+
const int64_t base = (int64_t)row * vec_d;
|
| 97 |
+
|
| 98 |
+
const v_t *__restrict__ in_v = reinterpret_cast<const v_t *>(input) + base;
|
| 99 |
+
|
| 100 |
+
const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
|
| 101 |
+
const acc_t w0 = weight[eidx * 3 + 0];
|
| 102 |
+
const acc_t w1 = weight[eidx * 3 + 1];
|
| 103 |
+
const acc_t w2 = weight[eidx * 3 + 2];
|
| 104 |
+
const acc_t b_val = bias[eidx];
|
| 105 |
+
const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);
|
| 106 |
+
|
| 107 |
+
// Pass 1: RMS stats (on clamped input if enabled)
|
| 108 |
+
acc_t s2 = 0, s4 = 0, s6 = 0;
|
| 109 |
+
for (int i = threadIdx.x; i < vec_d; i += BLOCK_SIZE) {
|
| 110 |
+
v_t xv = in_v[i];
|
| 111 |
+
#pragma unroll
|
| 112 |
+
for (int j = 0; j < width; ++j) {
|
| 113 |
+
acc_t x = xv.data[j];
|
| 114 |
+
if (do_clamp) x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
|
| 115 |
+
acc_t x2 = x * x;
|
| 116 |
+
s2 += x2;
|
| 117 |
+
s4 += x2 * x2;
|
| 118 |
+
s6 += x2 * x2 * x2;
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
float4 sums = block_reduce_f4<BLOCK_SIZE>(make_float4(s2, s4, s6, 0.f));
|
| 123 |
+
|
| 124 |
+
const acc_t inv_d = acc_t(1) / D;
|
| 125 |
+
const acc_t ir1 = rsqrtf(sums.x * inv_d + eps);
|
| 126 |
+
const acc_t ir2 = rsqrtf(sums.y * inv_d + eps);
|
| 127 |
+
const acc_t ir3 = rsqrtf(sums.z * inv_d + eps);
|
| 128 |
+
|
| 129 |
+
// Save inv_rms for backward
|
| 130 |
+
if (threadIdx.x == 0) {
|
| 131 |
+
inv_rms[row * 3 + 0] = ir1;
|
| 132 |
+
inv_rms[row * 3 + 1] = ir2;
|
| 133 |
+
inv_rms[row * 3 + 2] = ir3;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
const acc_t w2ir1 = w2 * ir1;
|
| 137 |
+
const acc_t w1ir2 = w1 * ir2;
|
| 138 |
+
const acc_t w0ir3 = w0 * ir3;
|
| 139 |
+
|
| 140 |
+
// Pass 2: output = poly * mul (with clamping)
|
| 141 |
+
const v_t *__restrict__ m_v = reinterpret_cast<const v_t *>(mul) + base;
|
| 142 |
+
v_t *__restrict__ out_v = reinterpret_cast<v_t *>(output) + base;
|
| 143 |
+
|
| 144 |
+
for (int i = threadIdx.x; i < vec_d; i += BLOCK_SIZE) {
|
| 145 |
+
v_t xv = in_v[i];
|
| 146 |
+
v_t mv = m_v[i];
|
| 147 |
+
v_t ov;
|
| 148 |
+
#pragma unroll
|
| 149 |
+
for (int j = 0; j < width; ++j) {
|
| 150 |
+
acc_t x = xv.data[j];
|
| 151 |
+
if (do_clamp) x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
|
| 152 |
+
acc_t m = (acc_t)mv.data[j];
|
| 153 |
+
if (do_clamp) m = fminf(fmaxf(m, -hidden_clamp), hidden_clamp);
|
| 154 |
+
acc_t x2 = x * x;
|
| 155 |
+
acc_t x3 = x2 * x;
|
| 156 |
+
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
|
| 157 |
+
acc_t out_val = poly * m * score;
|
| 158 |
+
if (do_clamp) out_val = fminf(fmaxf(out_val, -hidden_clamp), hidden_clamp);
|
| 159 |
+
ov.data[j] = (scalar_t)out_val;
|
| 160 |
+
}
|
| 161 |
+
out_v[i] = ov;
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
// Scalar fallback forward
|
| 166 |
+
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
|
| 167 |
+
__global__ void __launch_bounds__(BLOCK_SIZE)
|
| 168 |
+
grouped_poly_norm_fwd_scalar(
|
| 169 |
+
scalar_t *__restrict__ output,
|
| 170 |
+
acc_t *__restrict__ inv_rms,
|
| 171 |
+
const scalar_t *__restrict__ input,
|
| 172 |
+
const scalar_t *__restrict__ mul,
|
| 173 |
+
const scalar_t *__restrict__ weight,
|
| 174 |
+
const scalar_t *__restrict__ bias,
|
| 175 |
+
const int32_t *__restrict__ offsets,
|
| 176 |
+
const float *__restrict__ scores, // nullable, always fp32
|
| 177 |
+
const acc_t eps, const int D, const int num_experts,
|
| 178 |
+
const int expert_offset,
|
| 179 |
+
const acc_t hidden_clamp) {
|
| 180 |
+
const bool do_clamp = (hidden_clamp >= acc_t(0));
|
| 181 |
+
const int row = blockIdx.x;
|
| 182 |
+
const int64_t off = (int64_t)row * D;
|
| 183 |
+
|
| 184 |
+
const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
|
| 185 |
+
const acc_t w0 = weight[eidx * 3], w1 = weight[eidx * 3 + 1], w2 = weight[eidx * 3 + 2];
|
| 186 |
+
const acc_t b_val = bias[eidx];
|
| 187 |
+
const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);
|
| 188 |
+
|
| 189 |
+
acc_t s2 = 0, s4 = 0, s6 = 0;
|
| 190 |
+
for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
|
| 191 |
+
acc_t x = input[off + i];
|
| 192 |
+
if (do_clamp) x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
|
| 193 |
+
acc_t x2 = x * x;
|
| 194 |
+
s2 += x2; s4 += x2 * x2; s6 += x2 * x2 * x2;
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
float4 sums = block_reduce_f4<BLOCK_SIZE>(make_float4(s2, s4, s6, 0.f));
|
| 198 |
+
const acc_t inv_d = acc_t(1) / D;
|
| 199 |
+
const acc_t ir1 = rsqrtf(sums.x * inv_d + eps);
|
| 200 |
+
const acc_t ir2 = rsqrtf(sums.y * inv_d + eps);
|
| 201 |
+
const acc_t ir3 = rsqrtf(sums.z * inv_d + eps);
|
| 202 |
+
|
| 203 |
+
if (threadIdx.x == 0) {
|
| 204 |
+
inv_rms[row * 3] = ir1; inv_rms[row * 3 + 1] = ir2; inv_rms[row * 3 + 2] = ir3;
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
const acc_t w2ir1 = w2 * ir1, w1ir2 = w1 * ir2, w0ir3 = w0 * ir3;
|
| 208 |
+
for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
|
| 209 |
+
acc_t x = input[off + i];
|
| 210 |
+
if (do_clamp) x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
|
| 211 |
+
acc_t m = (acc_t)mul[off + i];
|
| 212 |
+
if (do_clamp) m = fminf(fmaxf(m, -hidden_clamp), hidden_clamp);
|
| 213 |
+
acc_t x2 = x * x, x3 = x2 * x;
|
| 214 |
+
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
|
| 215 |
+
acc_t out_val = poly * m * score;
|
| 216 |
+
if (do_clamp) out_val = fminf(fmaxf(out_val, -hidden_clamp), hidden_clamp);
|
| 217 |
+
output[off + i] = (scalar_t)out_val;
|
| 218 |
+
}
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
// ---------------------------------------------------------------------------
|
| 222 |
+
// Grouped PolyNorm Backward — vectorized (width > 0)
|
| 223 |
+
// Weight/bias grads use atomicAdd directly (no temp buffer + scatter_add).
|
| 224 |
+
// ---------------------------------------------------------------------------
|
| 225 |
+
template <typename scalar_t, typename acc_t, int width, int BLOCK_SIZE>
|
| 226 |
+
__global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
| 227 |
+
grouped_poly_norm_bwd_kernel(
|
| 228 |
+
scalar_t *__restrict__ grad_input,
|
| 229 |
+
scalar_t *__restrict__ grad_mul,
|
| 230 |
+
float *__restrict__ weight_grad, // [num_total_experts, 3] fp32
|
| 231 |
+
float *__restrict__ bias_grad, // [num_total_experts] fp32
|
| 232 |
+
const scalar_t *__restrict__ grad_output,
|
| 233 |
+
const scalar_t *__restrict__ input,
|
| 234 |
+
const scalar_t *__restrict__ mul,
|
| 235 |
+
const scalar_t *__restrict__ weight,
|
| 236 |
+
const scalar_t *__restrict__ bias,
|
| 237 |
+
const int32_t *__restrict__ offsets,
|
| 238 |
+
const acc_t *__restrict__ inv_rms,
|
| 239 |
+
const float *__restrict__ scores, // nullable, always fp32
|
| 240 |
+
acc_t *__restrict__ grad_scores, // nullable (null when scores is null)
|
| 241 |
+
const acc_t eps, const int D, const int num_experts,
|
| 242 |
+
const int expert_offset,
|
| 243 |
+
const acc_t hidden_clamp) {
|
| 244 |
+
using v_t = vec_t<scalar_t, width>;
|
| 245 |
+
const bool do_clamp = (hidden_clamp >= acc_t(0));
|
| 246 |
+
|
| 247 |
+
const int row = blockIdx.x;
|
| 248 |
+
const int vec_d = D / width;
|
| 249 |
+
const int64_t base = (int64_t)row * vec_d;
|
| 250 |
+
|
| 251 |
+
const v_t *__restrict__ in_v = reinterpret_cast<const v_t *>(input) + base;
|
| 252 |
+
const v_t *__restrict__ go_v = reinterpret_cast<const v_t *>(grad_output) + base;
|
| 253 |
+
const v_t *__restrict__ m_v = reinterpret_cast<const v_t *>(mul) + base;
|
| 254 |
+
|
| 255 |
+
const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
|
| 256 |
+
const acc_t w0 = weight[eidx * 3 + 0];
|
| 257 |
+
const acc_t w1 = weight[eidx * 3 + 1];
|
| 258 |
+
const acc_t w2 = weight[eidx * 3 + 2];
|
| 259 |
+
const acc_t b_val = bias[eidx];
|
| 260 |
+
const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);
|
| 261 |
+
|
| 262 |
+
const acc_t ir1 = inv_rms[row * 3 + 0];
|
| 263 |
+
const acc_t ir2 = inv_rms[row * 3 + 1];
|
| 264 |
+
const acc_t ir3 = inv_rms[row * 3 + 2];
|
| 265 |
+
|
| 266 |
+
const acc_t w2ir1 = w2 * ir1;
|
| 267 |
+
const acc_t w1ir2 = w1 * ir2;
|
| 268 |
+
const acc_t w0ir3 = w0 * ir3;
|
| 269 |
+
|
| 270 |
+
// ---- Pass 1: dot products (with clamp masks) ----
|
| 271 |
+
acc_t sdpx = 0, sdpx2 = 0, sdpx3 = 0, sdp = 0;
|
| 272 |
+
|
| 273 |
+
for (int i = threadIdx.x; i < vec_d; i += BLOCK_SIZE) {
|
| 274 |
+
v_t xv = in_v[i];
|
| 275 |
+
v_t gv = go_v[i];
|
| 276 |
+
v_t mv = m_v[i];
|
| 277 |
+
|
| 278 |
+
#pragma unroll
|
| 279 |
+
for (int j = 0; j < width; ++j) {
|
| 280 |
+
acc_t x_orig = xv.data[j];
|
| 281 |
+
acc_t x = do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
|
| 282 |
+
acc_t m_orig = (acc_t)mv.data[j];
|
| 283 |
+
acc_t m = do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
|
| 284 |
+
acc_t go = (acc_t)gv.data[j];
|
| 285 |
+
|
| 286 |
+
// Output clamp mask: recompute pre-clamp output
|
| 287 |
+
if (do_clamp) {
|
| 288 |
+
acc_t x2 = x * x, x3 = x2 * x;
|
| 289 |
+
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
|
| 290 |
+
acc_t out_pre = poly * m * score;
|
| 291 |
+
if (fabsf(out_pre) > hidden_clamp) go = acc_t(0);
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
acc_t x2 = x * x;
|
| 295 |
+
acc_t dp = go * m * score;
|
| 296 |
+
sdp += dp;
|
| 297 |
+
sdpx += dp * x;
|
| 298 |
+
sdpx2 += dp * x2;
|
| 299 |
+
sdpx3 += dp * x2 * x;
|
| 300 |
+
}
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
float4 sums = block_reduce_f4<BLOCK_SIZE>(make_float4(sdpx, sdpx2, sdpx3, sdp));
|
| 304 |
+
|
| 305 |
+
const acc_t inv_d = acc_t(1) / D;
|
| 306 |
+
const acc_t s1 = sums.x * inv_d;
|
| 307 |
+
const acc_t s2 = sums.y * inv_d;
|
| 308 |
+
const acc_t s3 = sums.z * inv_d;
|
| 309 |
+
const acc_t bias_grad_val = sums.w;
|
| 310 |
+
|
| 311 |
+
const acc_t cx = w2 * s1 * ir1 * ir1;
|
| 312 |
+
const acc_t cx2 = w1 * s2 * ir2 * ir2;
|
| 313 |
+
const acc_t cx3 = w0 * s3 * ir3 * ir3;
|
| 314 |
+
|
| 315 |
+
// ---- Pass 2: grad_input + grad_mul + weight grads + grad_scores ----
|
| 316 |
+
acc_t dw0 = 0, dw1 = 0, dw2 = 0, gs_acc = 0;
|
| 317 |
+
|
| 318 |
+
v_t *__restrict__ gi_v = reinterpret_cast<v_t *>(grad_input) + base;
|
| 319 |
+
v_t *__restrict__ gm_v = reinterpret_cast<v_t *>(grad_mul) + base;
|
| 320 |
+
|
| 321 |
+
for (int i = threadIdx.x; i < vec_d; i += BLOCK_SIZE) {
|
| 322 |
+
v_t xv = in_v[i];
|
| 323 |
+
v_t gv = go_v[i];
|
| 324 |
+
v_t mv = m_v[i];
|
| 325 |
+
v_t gi, gm;
|
| 326 |
+
|
| 327 |
+
#pragma unroll
|
| 328 |
+
for (int j = 0; j < width; ++j) {
|
| 329 |
+
acc_t x_orig = xv.data[j];
|
| 330 |
+
acc_t x = do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
|
| 331 |
+
acc_t m_orig = (acc_t)mv.data[j];
|
| 332 |
+
acc_t m = do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
|
| 333 |
+
acc_t x2 = x * x;
|
| 334 |
+
acc_t x3 = x2 * x;
|
| 335 |
+
acc_t go = (acc_t)gv.data[j];
|
| 336 |
+
|
| 337 |
+
// Output clamp mask
|
| 338 |
+
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1;
|
| 339 |
+
if (do_clamp) {
|
| 340 |
+
acc_t out_pre = (poly + b_val) * m * score;
|
| 341 |
+
if (fabsf(out_pre) > hidden_clamp) go = acc_t(0);
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
acc_t dp = go * m * score;
|
| 345 |
+
|
| 346 |
+
// grad_mul with mul clamp mask
|
| 347 |
+
acc_t gm_val = go * (poly + b_val) * score;
|
| 348 |
+
if (do_clamp && fabsf(m_orig) > hidden_clamp) gm_val = acc_t(0);
|
| 349 |
+
gm.data[j] = (scalar_t)gm_val;
|
| 350 |
+
|
| 351 |
+
// grad_input with input clamp mask
|
| 352 |
+
acc_t g = ir1 * (w2 * dp - x * cx);
|
| 353 |
+
g += acc_t(2) * x * ir2 * (w1 * dp - x2 * cx2);
|
| 354 |
+
g += acc_t(3) * x2 * ir3 * (w0 * dp - x3 * cx3);
|
| 355 |
+
if (do_clamp && fabsf(x_orig) > hidden_clamp) g = acc_t(0);
|
| 356 |
+
gi.data[j] = (scalar_t)g;
|
| 357 |
+
|
| 358 |
+
dw0 += dp * x3 * ir3;
|
| 359 |
+
dw1 += dp * x2 * ir2;
|
| 360 |
+
dw2 += dp * x * ir1;
|
| 361 |
+
gs_acc += go * (poly + b_val) * m; // grad_scores accumulator
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
gi_v[i] = gi;
|
| 365 |
+
gm_v[i] = gm;
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
// Reduce weight grads + grad_scores (.w channel)
|
| 369 |
+
float4 wg = block_reduce_f4<BLOCK_SIZE>(make_float4(dw0, dw1, dw2, gs_acc));
|
| 370 |
+
|
| 371 |
+
if (threadIdx.x == 0) {
|
| 372 |
+
atomicAdd(&weight_grad[eidx * 3 + 0], wg.x);
|
| 373 |
+
atomicAdd(&weight_grad[eidx * 3 + 1], wg.y);
|
| 374 |
+
atomicAdd(&weight_grad[eidx * 3 + 2], wg.z);
|
| 375 |
+
atomicAdd(&bias_grad[eidx], bias_grad_val);
|
| 376 |
+
if (grad_scores != nullptr) {
|
| 377 |
+
grad_scores[row] = wg.w;
|
| 378 |
+
}
|
| 379 |
+
}
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
// ---------------------------------------------------------------------------
|
| 383 |
+
// Scalar fallback (width == 0)
|
| 384 |
+
// ---------------------------------------------------------------------------
|
| 385 |
+
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
|
| 386 |
+
__global__ void __launch_bounds__(BLOCK_SIZE)
|
| 387 |
+
grouped_poly_norm_bwd_scalar(
|
| 388 |
+
scalar_t *__restrict__ grad_input,
|
| 389 |
+
scalar_t *__restrict__ grad_mul,
|
| 390 |
+
float *__restrict__ weight_grad,
|
| 391 |
+
float *__restrict__ bias_grad,
|
| 392 |
+
const scalar_t *__restrict__ grad_output,
|
| 393 |
+
const scalar_t *__restrict__ input,
|
| 394 |
+
const scalar_t *__restrict__ mul,
|
| 395 |
+
const scalar_t *__restrict__ weight,
|
| 396 |
+
const scalar_t *__restrict__ bias,
|
| 397 |
+
const int32_t *__restrict__ offsets,
|
| 398 |
+
const acc_t *__restrict__ inv_rms,
|
| 399 |
+
const float *__restrict__ scores, // nullable, always fp32
|
| 400 |
+
acc_t *__restrict__ grad_scores, // nullable
|
| 401 |
+
const acc_t eps, const int D, const int num_experts,
|
| 402 |
+
const int expert_offset,
|
| 403 |
+
const acc_t hidden_clamp) {
|
| 404 |
+
const bool do_clamp = (hidden_clamp >= acc_t(0));
|
| 405 |
+
const int row = blockIdx.x;
|
| 406 |
+
const int64_t off = (int64_t)row * D;
|
| 407 |
+
|
| 408 |
+
const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
|
| 409 |
+
const acc_t w0 = weight[eidx * 3], w1 = weight[eidx * 3 + 1], w2 = weight[eidx * 3 + 2];
|
| 410 |
+
const acc_t b_val = bias[eidx];
|
| 411 |
+
const acc_t ir1 = inv_rms[row * 3], ir2 = inv_rms[row * 3 + 1], ir3 = inv_rms[row * 3 + 2];
|
| 412 |
+
const acc_t w2ir1 = w2 * ir1, w1ir2 = w1 * ir2, w0ir3 = w0 * ir3;
|
| 413 |
+
const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);
|
| 414 |
+
|
| 415 |
+
// Pass 1: dot products with clamp masks
|
| 416 |
+
acc_t sdpx = 0, sdpx2 = 0, sdpx3 = 0, sdp = 0;
|
| 417 |
+
for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
|
| 418 |
+
acc_t x_orig = input[off + i];
|
| 419 |
+
acc_t x = do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
|
| 420 |
+
acc_t m_orig = (acc_t)mul[off + i];
|
| 421 |
+
acc_t m = do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
|
| 422 |
+
acc_t go = (acc_t)grad_output[off + i];
|
| 423 |
+
|
| 424 |
+
if (do_clamp) {
|
| 425 |
+
acc_t x2 = x * x, x3 = x2 * x;
|
| 426 |
+
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
|
| 427 |
+
acc_t out_pre = poly * m * score;
|
| 428 |
+
if (fabsf(out_pre) > hidden_clamp) go = acc_t(0);
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
acc_t x2 = x * x;
|
| 432 |
+
acc_t dp = go * m * score;
|
| 433 |
+
sdp += dp; sdpx += dp * x; sdpx2 += dp * x2; sdpx3 += dp * x2 * x;
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
float4 sums = block_reduce_f4<BLOCK_SIZE>(make_float4(sdpx, sdpx2, sdpx3, sdp));
|
| 437 |
+
const acc_t inv_d = acc_t(1) / D;
|
| 438 |
+
const acc_t s1 = sums.x * inv_d, s2 = sums.y * inv_d, s3 = sums.z * inv_d;
|
| 439 |
+
const acc_t cx = w2 * s1 * ir1 * ir1, cx2 = w1 * s2 * ir2 * ir2, cx3 = w0 * s3 * ir3 * ir3;
|
| 440 |
+
|
| 441 |
+
// Pass 2: grads with clamp masks
|
| 442 |
+
acc_t dw0 = 0, dw1 = 0, dw2 = 0, gs_acc = 0;
|
| 443 |
+
for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
|
| 444 |
+
acc_t x_orig = input[off + i], m_orig = (acc_t)mul[off + i];
|
| 445 |
+
acc_t x = do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
|
| 446 |
+
acc_t m = do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
|
| 447 |
+
acc_t go = (acc_t)grad_output[off + i];
|
| 448 |
+
acc_t x2 = x * x, x3 = x2 * x;
|
| 449 |
+
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1;
|
| 450 |
+
|
| 451 |
+
// Output clamp mask
|
| 452 |
+
if (do_clamp) {
|
| 453 |
+
acc_t out_pre = (poly + b_val) * m * score;
|
| 454 |
+
if (fabsf(out_pre) > hidden_clamp) go = acc_t(0);
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
acc_t dp = go * m * score;
|
| 458 |
+
|
| 459 |
+
acc_t gm_val = go * (poly + b_val) * score;
|
| 460 |
+
if (do_clamp && fabsf(m_orig) > hidden_clamp) gm_val = acc_t(0);
|
| 461 |
+
grad_mul[off + i] = (scalar_t)gm_val;
|
| 462 |
+
|
| 463 |
+
acc_t g = ir1 * (w2 * dp - x * cx) + acc_t(2) * x * ir2 * (w1 * dp - x2 * cx2)
|
| 464 |
+
+ acc_t(3) * x2 * ir3 * (w0 * dp - x3 * cx3);
|
| 465 |
+
if (do_clamp && fabsf(x_orig) > hidden_clamp) g = acc_t(0);
|
| 466 |
+
grad_input[off + i] = (scalar_t)g;
|
| 467 |
+
|
| 468 |
+
dw0 += dp * x3 * ir3; dw1 += dp * x2 * ir2; dw2 += dp * x * ir1;
|
| 469 |
+
gs_acc += go * (poly + b_val) * m;
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
float4 wg = block_reduce_f4<BLOCK_SIZE>(make_float4(dw0, dw1, dw2, gs_acc));
|
| 473 |
+
if (threadIdx.x == 0) {
|
| 474 |
+
atomicAdd(&weight_grad[eidx * 3 + 0], wg.x);
|
| 475 |
+
atomicAdd(&weight_grad[eidx * 3 + 1], wg.y);
|
| 476 |
+
atomicAdd(&weight_grad[eidx * 3 + 2], wg.z);
|
| 477 |
+
atomicAdd(&bias_grad[eidx], sums.w);
|
| 478 |
+
if (grad_scores != nullptr) {
|
| 479 |
+
grad_scores[row] = wg.w;
|
| 480 |
+
}
|
| 481 |
+
}
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
} // namespace motif
|
| 485 |
+
|
| 486 |
+
// ---------------------------------------------------------------------------
|
| 487 |
+
// Internal helpers — shared kernel dispatch
|
| 488 |
+
// ---------------------------------------------------------------------------
|
| 489 |
+
#define FWD_LAUNCH(width_val, scalar_type_name) \
|
| 490 |
+
MOTIF_DISPATCH_FLOATING_TYPES( \
|
| 491 |
+
input.scalar_type(), scalar_type_name, [&] { \
|
| 492 |
+
motif::grouped_poly_norm_fwd_kernel<scalar_t, float, width_val, BLOCK> \
|
| 493 |
+
<<<grid, block, 0, stream>>>( \
|
| 494 |
+
output.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(), \
|
| 495 |
+
input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(), \
|
| 496 |
+
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), \
|
| 497 |
+
offsets.data_ptr<int32_t>(), scores_ptr, \
|
| 498 |
+
(float)eps, D, num_experts, (int)expert_offset, \
|
| 499 |
+
(float)hidden_clamp); \
|
| 500 |
+
})
|
| 501 |
+
|
| 502 |
+
static std::tuple<torch::Tensor, torch::Tensor>
|
| 503 |
+
_fwd_impl(const torch::Tensor &input, const torch::Tensor &mul,
|
| 504 |
+
const torch::Tensor &weight, const torch::Tensor &bias,
|
| 505 |
+
const torch::Tensor &offsets, const float *scores_ptr,
|
| 506 |
+
double eps, int64_t expert_offset, double hidden_clamp) {
|
| 507 |
+
const int D = input.size(-1);
|
| 508 |
+
const int64_t N = input.size(0);
|
| 509 |
+
const int num_experts = offsets.size(0);
|
| 510 |
+
constexpr int BLOCK = 128;
|
| 511 |
+
dim3 grid(N); dim3 block(BLOCK);
|
| 512 |
+
|
| 513 |
+
auto output = torch::empty_like(input);
|
| 514 |
+
auto inv_rms = torch::empty({N, 3}, input.options().dtype(torch::kFloat));
|
| 515 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
| 516 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 517 |
+
|
| 518 |
+
if (D % 8 == 0 && input.element_size() == 2)
|
| 519 |
+
FWD_LAUNCH(8, "grouped_poly_norm_fwd");
|
| 520 |
+
else if (D % 4 == 0 && input.element_size() == 4)
|
| 521 |
+
FWD_LAUNCH(4, "grouped_poly_norm_fwd");
|
| 522 |
+
else {
|
| 523 |
+
MOTIF_DISPATCH_FLOATING_TYPES(
|
| 524 |
+
input.scalar_type(), "grouped_poly_norm_fwd_scalar", [&] {
|
| 525 |
+
motif::grouped_poly_norm_fwd_scalar<scalar_t, float, BLOCK>
|
| 526 |
+
<<<grid, block, 0, stream>>>(
|
| 527 |
+
output.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(),
|
| 528 |
+
input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(),
|
| 529 |
+
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
|
| 530 |
+
offsets.data_ptr<int32_t>(), scores_ptr,
|
| 531 |
+
(float)eps, D, num_experts, (int)expert_offset,
|
| 532 |
+
(float)hidden_clamp);
|
| 533 |
+
});
|
| 534 |
+
}
|
| 535 |
+
return {output, inv_rms};
|
| 536 |
+
}
|
| 537 |
+
#undef FWD_LAUNCH
|
| 538 |
+
|
| 539 |
+
#define BWD_LAUNCH(width_val, scalar_type_name, kernel_name) \
|
| 540 |
+
MOTIF_DISPATCH_FLOATING_TYPES( \
|
| 541 |
+
input.scalar_type(), scalar_type_name, [&] { \
|
| 542 |
+
motif::kernel_name<scalar_t, float, width_val, BLOCK> \
|
| 543 |
+
<<<grid, block, 0, stream>>>( \
|
| 544 |
+
input_grad.data_ptr<scalar_t>(), \
|
| 545 |
+
mul_grad.data_ptr<scalar_t>(), \
|
| 546 |
+
wg_f32.data_ptr<float>(), bg_f32.data_ptr<float>(), \
|
| 547 |
+
grad_output.data_ptr<scalar_t>(), \
|
| 548 |
+
input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(), \
|
| 549 |
+
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), \
|
| 550 |
+
offsets.data_ptr<int32_t>(), inv_rms.data_ptr<float>(), \
|
| 551 |
+
scores_ptr, gs_ptr, \
|
| 552 |
+
(float)eps, D, num_experts, (int)expert_offset, \
|
| 553 |
+
(float)hidden_clamp); \
|
| 554 |
+
})
|
| 555 |
+
|
| 556 |
+
static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
|
| 557 |
+
torch::Tensor>
|
| 558 |
+
_bwd_impl(const torch::Tensor &grad_output, const torch::Tensor &input,
|
| 559 |
+
const torch::Tensor &mul, const torch::Tensor &weight,
|
| 560 |
+
const torch::Tensor &bias, const torch::Tensor &offsets,
|
| 561 |
+
const torch::Tensor &inv_rms, const float *scores_ptr,
|
| 562 |
+
float *gs_ptr, int64_t N,
|
| 563 |
+
double eps, int64_t expert_offset, double hidden_clamp) {
|
| 564 |
+
const int D = input.size(-1);
|
| 565 |
+
const int num_experts = offsets.size(0);
|
| 566 |
+
constexpr int BLOCK = 128;
|
| 567 |
+
dim3 grid(N); dim3 block(BLOCK);
|
| 568 |
+
|
| 569 |
+
auto input_grad = torch::empty_like(input);
|
| 570 |
+
auto mul_grad = torch::empty_like(mul);
|
| 571 |
+
auto wg_f32 = torch::zeros({weight.size(0), 3}, input.options().dtype(torch::kFloat));
|
| 572 |
+
auto bg_f32 = torch::zeros({bias.size(0)}, input.options().dtype(torch::kFloat));
|
| 573 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
| 574 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 575 |
+
|
| 576 |
+
if (D % 8 == 0 && input.element_size() == 2)
|
| 577 |
+
BWD_LAUNCH(8, "grouped_poly_norm_bwd", grouped_poly_norm_bwd_kernel);
|
| 578 |
+
else if (D % 4 == 0 && input.element_size() == 4)
|
| 579 |
+
BWD_LAUNCH(4, "grouped_poly_norm_bwd", grouped_poly_norm_bwd_kernel);
|
| 580 |
+
else {
|
| 581 |
+
MOTIF_DISPATCH_FLOATING_TYPES(
|
| 582 |
+
input.scalar_type(), "grouped_poly_norm_bwd_scalar", [&] {
|
| 583 |
+
motif::grouped_poly_norm_bwd_scalar<scalar_t, float, BLOCK>
|
| 584 |
+
<<<grid, block, 0, stream>>>(
|
| 585 |
+
input_grad.data_ptr<scalar_t>(),
|
| 586 |
+
mul_grad.data_ptr<scalar_t>(),
|
| 587 |
+
wg_f32.data_ptr<float>(), bg_f32.data_ptr<float>(),
|
| 588 |
+
grad_output.data_ptr<scalar_t>(),
|
| 589 |
+
input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(),
|
| 590 |
+
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
|
| 591 |
+
offsets.data_ptr<int32_t>(), inv_rms.data_ptr<float>(),
|
| 592 |
+
scores_ptr, gs_ptr,
|
| 593 |
+
(float)eps, D, num_experts, (int)expert_offset,
|
| 594 |
+
(float)hidden_clamp);
|
| 595 |
+
});
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
auto weight_grad = wg_f32.to(weight.dtype());
|
| 599 |
+
auto bias_grad = bg_f32.unsqueeze(-1).to(bias.dtype());
|
| 600 |
+
// gs_f32 handled by caller
|
| 601 |
+
return {input_grad, mul_grad, weight_grad, bias_grad, torch::Tensor()};
|
| 602 |
+
}
|
| 603 |
+
#undef BWD_LAUNCH
|
| 604 |
+
|
| 605 |
+
// ---------------------------------------------------------------------------
|
| 606 |
+
// Public API: without scores
|
| 607 |
+
// ---------------------------------------------------------------------------
|
| 608 |
+
std::tuple<torch::Tensor, torch::Tensor>
|
| 609 |
+
grouped_poly_norm_forward(
|
| 610 |
+
const torch::Tensor &input, const torch::Tensor &mul,
|
| 611 |
+
const torch::Tensor &weight, const torch::Tensor &bias,
|
| 612 |
+
const torch::Tensor &offsets, double eps, int64_t expert_offset) {
|
| 613 |
+
return _fwd_impl(input, mul, weight, bias, offsets, nullptr, eps, expert_offset, -1.0);
|
| 614 |
+
}
|
| 615 |
+
|
| 616 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
| 617 |
+
grouped_poly_norm_backward(
|
| 618 |
+
const torch::Tensor &grad_output, const torch::Tensor &input,
|
| 619 |
+
const torch::Tensor &mul, const torch::Tensor &weight,
|
| 620 |
+
const torch::Tensor &bias, const torch::Tensor &offsets,
|
| 621 |
+
const torch::Tensor &inv_rms, double eps, int64_t expert_offset) {
|
| 622 |
+
const int64_t N = input.size(0);
|
| 623 |
+
auto [ig, mg, wg, bg, _] = _bwd_impl(
|
| 624 |
+
grad_output, input, mul, weight, bias, offsets, inv_rms,
|
| 625 |
+
nullptr, nullptr, N, eps, expert_offset, -1.0);
|
| 626 |
+
return {ig, mg, wg, bg};
|
| 627 |
+
}
|
| 628 |
+
|
| 629 |
+
// ---------------------------------------------------------------------------
|
| 630 |
+
// Public API: with scores
|
| 631 |
+
// ---------------------------------------------------------------------------
|
| 632 |
+
std::tuple<torch::Tensor, torch::Tensor>
|
| 633 |
+
grouped_poly_norm_forward_scored(
|
| 634 |
+
const torch::Tensor &input, const torch::Tensor &mul,
|
| 635 |
+
const torch::Tensor &weight, const torch::Tensor &bias,
|
| 636 |
+
const torch::Tensor &offsets, const torch::Tensor &scores,
|
| 637 |
+
double eps, int64_t expert_offset, double hidden_clamp) {
|
| 638 |
+
return _fwd_impl(input, mul, weight, bias, offsets,
|
| 639 |
+
scores.data_ptr<float>(), eps, expert_offset, hidden_clamp);
|
| 640 |
+
}
|
| 641 |
+
|
| 642 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
| 643 |
+
grouped_poly_norm_backward_scored(
|
| 644 |
+
const torch::Tensor &grad_output, const torch::Tensor &input,
|
| 645 |
+
const torch::Tensor &mul, const torch::Tensor &weight,
|
| 646 |
+
const torch::Tensor &bias, const torch::Tensor &offsets,
|
| 647 |
+
const torch::Tensor &inv_rms, const torch::Tensor &scores,
|
| 648 |
+
double eps, int64_t expert_offset, double hidden_clamp) {
|
| 649 |
+
const int64_t N = input.size(0);
|
| 650 |
+
auto gs_f32 = torch::empty({N}, input.options().dtype(torch::kFloat));
|
| 651 |
+
auto [ig, mg, wg, bg, _] = _bwd_impl(
|
| 652 |
+
grad_output, input, mul, weight, bias, offsets, inv_rms,
|
| 653 |
+
scores.data_ptr<float>(), gs_f32.data_ptr<float>(), N,
|
| 654 |
+
eps, expert_offset, hidden_clamp);
|
| 655 |
+
auto gs = gs_f32.unsqueeze(-1);
|
| 656 |
+
return {ig, mg, wg, bg, gs};
|
| 657 |
+
}
|
build.toml
CHANGED
|
@@ -19,6 +19,7 @@ src = [
|
|
| 19 |
"activation/fused_mul_poly_norm.cu",
|
| 20 |
"activation/rms_norm.cu",
|
| 21 |
"activation/fused_add_rms_norm.cu",
|
|
|
|
| 22 |
"activation/cuda_compat.h",
|
| 23 |
"activation/dispatch_utils.h",
|
| 24 |
"activation/assert_utils.h",
|
|
@@ -33,6 +34,7 @@ src = [
|
|
| 33 |
"activation/fused_mul_poly_norm.cu",
|
| 34 |
"activation/rms_norm.cu",
|
| 35 |
"activation/fused_add_rms_norm.cu",
|
|
|
|
| 36 |
"activation/cuda_compat.h",
|
| 37 |
"activation/dispatch_utils.h",
|
| 38 |
"activation/assert_utils.h",
|
|
|
|
| 19 |
"activation/fused_mul_poly_norm.cu",
|
| 20 |
"activation/rms_norm.cu",
|
| 21 |
"activation/fused_add_rms_norm.cu",
|
| 22 |
+
"activation/grouped_poly_norm.cu",
|
| 23 |
"activation/cuda_compat.h",
|
| 24 |
"activation/dispatch_utils.h",
|
| 25 |
"activation/assert_utils.h",
|
|
|
|
| 34 |
"activation/fused_mul_poly_norm.cu",
|
| 35 |
"activation/rms_norm.cu",
|
| 36 |
"activation/fused_add_rms_norm.cu",
|
| 37 |
+
"activation/grouped_poly_norm.cu",
|
| 38 |
"activation/cuda_compat.h",
|
| 39 |
"activation/dispatch_utils.h",
|
| 40 |
"activation/assert_utils.h",
|
torch-ext/torch_binding.cpp
CHANGED
|
@@ -48,6 +48,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|
| 48 |
"(Tensor, Tensor)");
|
| 49 |
ops.impl("fused_add_rms_norm_backward", torch::kCUDA,
|
| 50 |
&fused_add_rms_norm_backward);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
}
|
| 52 |
|
| 53 |
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
|
|
| 48 |
"(Tensor, Tensor)");
|
| 49 |
ops.impl("fused_add_rms_norm_backward", torch::kCUDA,
|
| 50 |
&fused_add_rms_norm_backward);
|
| 51 |
+
|
| 52 |
+
// grouped_poly_norm (without scores)
|
| 53 |
+
ops.def("grouped_poly_norm_forward("
|
| 54 |
+
"Tensor input, Tensor mul, Tensor weight, "
|
| 55 |
+
"Tensor bias, Tensor offsets, "
|
| 56 |
+
"float eps, int expert_offset) -> (Tensor, Tensor)");
|
| 57 |
+
ops.impl("grouped_poly_norm_forward", torch::kCUDA,
|
| 58 |
+
&grouped_poly_norm_forward);
|
| 59 |
+
|
| 60 |
+
ops.def("grouped_poly_norm_backward("
|
| 61 |
+
"Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
|
| 62 |
+
"Tensor bias, Tensor offsets, Tensor inv_rms, "
|
| 63 |
+
"float eps, int expert_offset) -> (Tensor, Tensor, Tensor, Tensor)");
|
| 64 |
+
ops.impl("grouped_poly_norm_backward", torch::kCUDA,
|
| 65 |
+
&grouped_poly_norm_backward);
|
| 66 |
+
|
| 67 |
+
// grouped_poly_norm (with scores)
|
| 68 |
+
ops.def("grouped_poly_norm_forward_scored("
|
| 69 |
+
"Tensor input, Tensor mul, Tensor weight, "
|
| 70 |
+
"Tensor bias, Tensor offsets, Tensor scores, "
|
| 71 |
+
"float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor)");
|
| 72 |
+
ops.impl("grouped_poly_norm_forward_scored", torch::kCUDA,
|
| 73 |
+
&grouped_poly_norm_forward_scored);
|
| 74 |
+
|
| 75 |
+
ops.def("grouped_poly_norm_backward_scored("
|
| 76 |
+
"Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
|
| 77 |
+
"Tensor bias, Tensor offsets, Tensor inv_rms, Tensor scores, "
|
| 78 |
+
"float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
|
| 79 |
+
ops.impl("grouped_poly_norm_backward_scored", torch::kCUDA,
|
| 80 |
+
&grouped_poly_norm_backward_scored);
|
| 81 |
}
|
| 82 |
|
| 83 |
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
torch-ext/torch_binding.h
CHANGED
|
@@ -35,3 +35,33 @@ std::tuple<torch::Tensor, torch::Tensor> fused_add_rms_norm_backward(
|
|
| 35 |
const torch::Tensor &output_grad, const torch::Tensor &add_output_grad,
|
| 36 |
const torch::Tensor &input, const torch::Tensor &weight, double eps,
|
| 37 |
bool need_input_grad);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
const torch::Tensor &output_grad, const torch::Tensor &add_output_grad,
|
| 36 |
const torch::Tensor &input, const torch::Tensor &weight, double eps,
|
| 37 |
bool need_input_grad);
|
| 38 |
+
|
| 39 |
+
// Without scores
|
| 40 |
+
std::tuple<torch::Tensor, torch::Tensor>
|
| 41 |
+
grouped_poly_norm_forward(
|
| 42 |
+
const torch::Tensor &input, const torch::Tensor &mul,
|
| 43 |
+
const torch::Tensor &weight, const torch::Tensor &bias,
|
| 44 |
+
const torch::Tensor &offsets, double eps, int64_t expert_offset);
|
| 45 |
+
|
| 46 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
| 47 |
+
grouped_poly_norm_backward(
|
| 48 |
+
const torch::Tensor &grad_output, const torch::Tensor &input,
|
| 49 |
+
const torch::Tensor &mul, const torch::Tensor &weight,
|
| 50 |
+
const torch::Tensor &bias, const torch::Tensor &offsets,
|
| 51 |
+
const torch::Tensor &inv_rms, double eps, int64_t expert_offset);
|
| 52 |
+
|
| 53 |
+
// With scores (hidden_clamp < 0 = disabled)
|
| 54 |
+
std::tuple<torch::Tensor, torch::Tensor>
|
| 55 |
+
grouped_poly_norm_forward_scored(
|
| 56 |
+
const torch::Tensor &input, const torch::Tensor &mul,
|
| 57 |
+
const torch::Tensor &weight, const torch::Tensor &bias,
|
| 58 |
+
const torch::Tensor &offsets, const torch::Tensor &scores,
|
| 59 |
+
double eps, int64_t expert_offset, double hidden_clamp);
|
| 60 |
+
|
| 61 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
| 62 |
+
grouped_poly_norm_backward_scored(
|
| 63 |
+
const torch::Tensor &grad_output, const torch::Tensor &input,
|
| 64 |
+
const torch::Tensor &mul, const torch::Tensor &weight,
|
| 65 |
+
const torch::Tensor &bias, const torch::Tensor &offsets,
|
| 66 |
+
const torch::Tensor &inv_rms, const torch::Tensor &scores,
|
| 67 |
+
double eps, int64_t expert_offset, double hidden_clamp);
|