style: apply yapf, isort, and clang-format
Browse filesCo-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- README.md +0 -103
- activation/grouped_poly_norm.cu +192 -171
- benchmarks/cases/grouped_mul_poly.py +28 -22
- benchmarks/profile_bwd.py +0 -146
- benchmarks/run_cases.py +39 -16
- registration.h +3 -2
- setup.py +3 -3
- tests/test_fused_mul_grouped_poly_norm.py +168 -61
- torch-ext/activation/grouped_poly_norm.py +42 -35
- torch-ext/torch_binding.cpp +14 -10
README.md
CHANGED
|
@@ -251,109 +251,6 @@ print(poly_norm(x))
|
|
| 251 |
> | Forward | 0.7 ms | 2.1 ms | **3.0x** |
|
| 252 |
> | Backward | 1.4 ms | 3.7 ms | **2.6x** |
|
| 253 |
|
| 254 |
-
#### B200 Results (bf16)
|
| 255 |
-
|
| 256 |
-
<details>
|
| 257 |
-
<summary>Forward Performance</summary>
|
| 258 |
-
|
| 259 |
-
| batch_size | seq_len | Naive (us) | Compiled (us) | CUDA (us) | CUDA vs Naive |
|
| 260 |
-
|-----------|---------|-----------|--------------|------------|-----------------|
|
| 261 |
-
| 1 | 1024 | 294.54 | 73.46 | 64.33 | 4.58x |
|
| 262 |
-
| 1 | 2048 | 373.50 | 94.88 | 65.26 | 5.72x |
|
| 263 |
-
| 1 | 4096 | 372.65 | 94.90 | 66.90 | 5.57x |
|
| 264 |
-
| 1 | 8192 | 486.98 | 102.33 | 72.71 | 6.70x |
|
| 265 |
-
| 2 | 4096 | 486.66 | 101.87 | 72.27 | 6.73x |
|
| 266 |
-
| 2 | 8192 | 950.62 | 106.96 | 90.06 | 10.56x |
|
| 267 |
-
| 4 | 4096 | 950.72 | 107.17 | 71.28 | 13.34x |
|
| 268 |
-
| 4 | 8192 | 1779.12 | 198.91 | 96.93 | 18.35x |
|
| 269 |
-
| 8 | 4096 | 1778.73 | 199.10 | 96.88 | 18.36x |
|
| 270 |
-
| 8 | 8192 | 3384.03 | 381.91 | 179.57 | 18.85x |
|
| 271 |
-
|
| 272 |
-
</details>
|
| 273 |
-
|
| 274 |
-
<details>
|
| 275 |
-
<summary>Backward Performance</summary>
|
| 276 |
-
|
| 277 |
-
| batch_size | seq_len | Naive (us) | Compiled (us) | CUDA (us) | CUDA vs Naive |
|
| 278 |
-
|-----------|---------|-----------|--------------|------------|-----------------|
|
| 279 |
-
| 1 | 1024 | 1690.61 | 999.66 | 1017.66 | 1.66x |
|
| 280 |
-
| 1 | 8192 | 1680.39 | 906.43 | 906.41 | 1.85x |
|
| 281 |
-
| 2 | 8192 | 2466.73 | 870.74 | 862.78 | 2.86x |
|
| 282 |
-
| 4 | 4096 | 2466.04 | 942.62 | 945.68 | 2.61x |
|
| 283 |
-
| 4 | 8192 | 4543.10 | 941.01 | 908.30 | 5.00x |
|
| 284 |
-
| 8 | 4096 | 4542.91 | 814.73 | 900.01 | 5.05x |
|
| 285 |
-
| 8 | 8192 | 8599.41 | 956.81 | 955.07 | 9.00x |
|
| 286 |
-
|
| 287 |
-
</details>
|
| 288 |
-
|
| 289 |
-
<details>
|
| 290 |
-
<summary>Forward + Backward Combined</summary>
|
| 291 |
-
|
| 292 |
-
| batch_size | seq_len | Naive (us) | Compiled (us) | CUDA (us) | CUDA vs Naive | CUDA vs Compiled |
|
| 293 |
-
|-----------|---------|-----------|--------------|------------|-----------------|-------------------|
|
| 294 |
-
| 1 | 1024 | 1985.15 | 1073.12 | 1081.99 | 1.83x | 0.99x |
|
| 295 |
-
| 1 | 4096 | 2085.10 | 974.32 | 960.73 | 2.17x | 1.01x |
|
| 296 |
-
| 1 | 8192 | 2167.37 | 1008.76 | 979.12 | 2.21x | 1.03x |
|
| 297 |
-
| 2 | 4096 | 2083.49 | 1001.03 | 965.30 | 2.16x | 1.04x |
|
| 298 |
-
| 2 | 8192 | 3417.35 | 977.70 | 952.84 | 3.59x | 1.03x |
|
| 299 |
-
| 4 | 4096 | 3416.76 | 1049.79 | 1016.97 | 3.36x | 1.03x |
|
| 300 |
-
| 4 | 8192 | 6322.22 | 1139.92 | 1005.23 | 6.29x | 1.13x |
|
| 301 |
-
| 8 | 4096 | 6321.64 | 1013.83 | 996.89 | 6.34x | 1.02x |
|
| 302 |
-
| 8 | 8192 | 11983.44 | 1338.71 | 1134.64 | 10.56x | 1.18x |
|
| 303 |
-
|
| 304 |
-
</details>
|
| 305 |
-
|
| 306 |
-
#### B200 Results (fp32)
|
| 307 |
-
|
| 308 |
-
<details>
|
| 309 |
-
<summary>Forward Performance</summary>
|
| 310 |
-
|
| 311 |
-
| batch_size | seq_len | Naive (us) | Compiled (us) | CUDA (us) | CUDA vs Naive |
|
| 312 |
-
|-----------|---------|-----------|--------------|------------|-----------------|
|
| 313 |
-
| 1 | 1024 | 318.05 | 83.29 | 64.24 | 4.95x |
|
| 314 |
-
| 1 | 2048 | 311.14 | 95.19 | 63.64 | 4.89x |
|
| 315 |
-
| 1 | 8192 | 401.78 | 101.61 | 68.21 | 5.89x |
|
| 316 |
-
| 2 | 4096 | 403.42 | 100.97 | 68.01 | 5.93x |
|
| 317 |
-
| 2 | 8192 | 803.31 | 130.51 | 68.21 | 11.78x |
|
| 318 |
-
| 4 | 4096 | 802.86 | 130.61 | 66.97 | 11.99x |
|
| 319 |
-
| 4 | 8192 | 1505.96 | 246.77 | 100.49 | 14.99x |
|
| 320 |
-
| 8 | 4096 | 1507.87 | 246.84 | 100.23 | 15.04x |
|
| 321 |
-
| 8 | 8192 | 2856.93 | 476.34 | 184.40 | 15.49x |
|
| 322 |
-
|
| 323 |
-
</details>
|
| 324 |
-
|
| 325 |
-
<details>
|
| 326 |
-
<summary>Backward Performance</summary>
|
| 327 |
-
|
| 328 |
-
| batch_size | seq_len | Naive (us) | Compiled (us) | CUDA (us) | CUDA vs Naive |
|
| 329 |
-
|-----------|---------|-----------|--------------|------------|-----------------|
|
| 330 |
-
| 1 | 1024 | 1604.25 | 989.30 | 1114.12 | 1.44x |
|
| 331 |
-
| 1 | 8192 | 1996.40 | 1117.71 | 1115.47 | 1.79x |
|
| 332 |
-
| 2 | 8192 | 2353.87 | 1119.41 | 1118.57 | 2.10x |
|
| 333 |
-
| 4 | 4096 | 2358.47 | 1102.23 | 1125.16 | 2.10x |
|
| 334 |
-
| 4 | 8192 | 4346.92 | 1125.33 | 1135.36 | 3.83x |
|
| 335 |
-
| 8 | 4096 | 4347.47 | 1104.27 | 1119.63 | 3.88x |
|
| 336 |
-
| 8 | 8192 | 8226.50 | 1172.66 | 1197.68 | 6.87x |
|
| 337 |
-
|
| 338 |
-
</details>
|
| 339 |
-
|
| 340 |
-
<details>
|
| 341 |
-
<summary>Forward + Backward Combined</summary>
|
| 342 |
-
|
| 343 |
-
| batch_size | seq_len | Naive (us) | Compiled (us) | CUDA (us) | CUDA vs Naive | CUDA vs Compiled |
|
| 344 |
-
|-----------|---------|-----------|--------------|------------|-----------------|-------------------|
|
| 345 |
-
| 1 | 1024 | 1922.30 | 1072.59 | 1178.36 | 1.63x | 0.91x |
|
| 346 |
-
| 1 | 4096 | 2367.77 | 1208.69 | 1192.07 | 1.99x | 1.01x |
|
| 347 |
-
| 1 | 8192 | 2398.19 | 1219.32 | 1183.69 | 2.03x | 1.03x |
|
| 348 |
-
| 2 | 4096 | 2401.39 | 1248.87 | 1154.72 | 2.08x | 1.08x |
|
| 349 |
-
| 2 | 8192 | 3157.18 | 1249.92 | 1186.77 | 2.66x | 1.05x |
|
| 350 |
-
| 4 | 4096 | 3161.33 | 1232.84 | 1192.13 | 2.65x | 1.03x |
|
| 351 |
-
| 4 | 8192 | 5852.88 | 1372.10 | 1235.86 | 4.74x | 1.11x |
|
| 352 |
-
| 8 | 4096 | 5855.34 | 1351.11 | 1219.85 | 4.80x | 1.11x |
|
| 353 |
-
| 8 | 8192 | 11083.43 | 1649.00 | 1382.07 | 8.02x | 1.19x |
|
| 354 |
-
|
| 355 |
-
</details>
|
| 356 |
-
|
| 357 |
## Pre-commit Hooks
|
| 358 |
|
| 359 |
This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
|
|
|
|
| 251 |
> | Forward | 0.7 ms | 2.1 ms | **3.0x** |
|
| 252 |
> | Backward | 1.4 ms | 3.7 ms | **2.6x** |
|
| 253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
## Pre-commit Hooks
|
| 255 |
|
| 256 |
This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
|
activation/grouped_poly_norm.cu
CHANGED
|
@@ -18,7 +18,8 @@ __device__ __forceinline__ int find_expert(const int32_t *__restrict__ offsets,
|
|
| 18 |
int lo = 0, hi = num_experts;
|
| 19 |
#pragma unroll 6
|
| 20 |
for (int i = 0; i < 12; ++i) {
|
| 21 |
-
if (lo >= hi)
|
|
|
|
| 22 |
int mid = (lo + hi) >> 1;
|
| 23 |
if (offsets[mid] <= row)
|
| 24 |
lo = mid + 1;
|
|
@@ -53,7 +54,8 @@ __device__ __forceinline__ float4 block_reduce_f4(float4 v) {
|
|
| 53 |
const int warp_id = threadIdx.x / WARP_SIZE;
|
| 54 |
const int lane_id = threadIdx.x % WARP_SIZE;
|
| 55 |
|
| 56 |
-
if (lane_id == 0)
|
|
|
|
| 57 |
__syncthreads();
|
| 58 |
|
| 59 |
if (warp_id == 0 && lane_id < NUM_WARPS)
|
|
@@ -61,10 +63,12 @@ __device__ __forceinline__ float4 block_reduce_f4(float4 v) {
|
|
| 61 |
else
|
| 62 |
v = make_float4(0.f, 0.f, 0.f, 0.f);
|
| 63 |
|
| 64 |
-
if (warp_id == 0)
|
|
|
|
| 65 |
|
| 66 |
__shared__ float4 result;
|
| 67 |
-
if (threadIdx.x == 0)
|
|
|
|
| 68 |
__syncthreads();
|
| 69 |
return result;
|
| 70 |
}
|
|
@@ -77,17 +81,14 @@ __device__ __forceinline__ float4 block_reduce_f4(float4 v) {
|
|
| 77 |
template <typename scalar_t, typename acc_t, int width, int BLOCK_SIZE>
|
| 78 |
__global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
| 79 |
grouped_poly_norm_fwd_kernel(
|
| 80 |
-
scalar_t *__restrict__ output,
|
| 81 |
-
|
| 82 |
-
const scalar_t *__restrict__
|
| 83 |
-
const scalar_t *__restrict__ mul,
|
| 84 |
-
const scalar_t *__restrict__ weight,
|
| 85 |
-
const scalar_t *__restrict__ bias,
|
| 86 |
const int32_t *__restrict__ offsets,
|
| 87 |
-
const float *__restrict__ scores,
|
| 88 |
const acc_t eps, const int D, const int num_experts,
|
| 89 |
const int expert_offset,
|
| 90 |
-
const acc_t hidden_clamp) {
|
| 91 |
using v_t = vec_t<scalar_t, width>;
|
| 92 |
const bool do_clamp = (hidden_clamp >= acc_t(0));
|
| 93 |
|
|
@@ -111,7 +112,8 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
|
| 111 |
#pragma unroll
|
| 112 |
for (int j = 0; j < width; ++j) {
|
| 113 |
acc_t x = xv.data[j];
|
| 114 |
-
if (do_clamp)
|
|
|
|
| 115 |
acc_t x2 = x * x;
|
| 116 |
s2 += x2;
|
| 117 |
s4 += x2 * x2;
|
|
@@ -148,14 +150,17 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
|
| 148 |
#pragma unroll
|
| 149 |
for (int j = 0; j < width; ++j) {
|
| 150 |
acc_t x = xv.data[j];
|
| 151 |
-
if (do_clamp)
|
|
|
|
| 152 |
acc_t m = (acc_t)mv.data[j];
|
| 153 |
-
if (do_clamp)
|
|
|
|
| 154 |
acc_t x2 = x * x;
|
| 155 |
acc_t x3 = x2 * x;
|
| 156 |
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
|
| 157 |
acc_t out_val = poly * m * score;
|
| 158 |
-
if (do_clamp)
|
|
|
|
| 159 |
ov.data[j] = (scalar_t)out_val;
|
| 160 |
}
|
| 161 |
out_v[i] = ov;
|
|
@@ -164,34 +169,33 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
|
| 164 |
|
| 165 |
// Scalar fallback forward
|
| 166 |
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
|
| 167 |
-
__global__ void __launch_bounds__(BLOCK_SIZE)
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
const int32_t *__restrict__ offsets,
|
| 176 |
-
const float *__restrict__ scores, // nullable, always fp32
|
| 177 |
-
const acc_t eps, const int D, const int num_experts,
|
| 178 |
-
const int expert_offset,
|
| 179 |
-
const acc_t hidden_clamp) {
|
| 180 |
const bool do_clamp = (hidden_clamp >= acc_t(0));
|
| 181 |
const int row = blockIdx.x;
|
| 182 |
const int64_t off = (int64_t)row * D;
|
| 183 |
|
| 184 |
const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
|
| 185 |
-
const acc_t w0 = weight[eidx * 3], w1 = weight[eidx * 3 + 1],
|
|
|
|
| 186 |
const acc_t b_val = bias[eidx];
|
| 187 |
const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);
|
| 188 |
|
| 189 |
acc_t s2 = 0, s4 = 0, s6 = 0;
|
| 190 |
for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
|
| 191 |
acc_t x = input[off + i];
|
| 192 |
-
if (do_clamp)
|
|
|
|
| 193 |
acc_t x2 = x * x;
|
| 194 |
-
s2 += x2;
|
|
|
|
|
|
|
| 195 |
}
|
| 196 |
|
| 197 |
float4 sums = block_reduce_f4<BLOCK_SIZE>(make_float4(s2, s4, s6, 0.f));
|
|
@@ -201,19 +205,24 @@ __global__ void __launch_bounds__(BLOCK_SIZE)
|
|
| 201 |
const acc_t ir3 = rsqrtf(sums.z * inv_d + eps);
|
| 202 |
|
| 203 |
if (threadIdx.x == 0) {
|
| 204 |
-
inv_rms[row * 3] = ir1;
|
|
|
|
|
|
|
| 205 |
}
|
| 206 |
|
| 207 |
const acc_t w2ir1 = w2 * ir1, w1ir2 = w1 * ir2, w0ir3 = w0 * ir3;
|
| 208 |
for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
|
| 209 |
acc_t x = input[off + i];
|
| 210 |
-
if (do_clamp)
|
|
|
|
| 211 |
acc_t m = (acc_t)mul[off + i];
|
| 212 |
-
if (do_clamp)
|
|
|
|
| 213 |
acc_t x2 = x * x, x3 = x2 * x;
|
| 214 |
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
|
| 215 |
acc_t out_val = poly * m * score;
|
| 216 |
-
if (do_clamp)
|
|
|
|
| 217 |
output[off + i] = (scalar_t)out_val;
|
| 218 |
}
|
| 219 |
}
|
|
@@ -225,22 +234,17 @@ __global__ void __launch_bounds__(BLOCK_SIZE)
|
|
| 225 |
template <typename scalar_t, typename acc_t, int width, int BLOCK_SIZE>
|
| 226 |
__global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
| 227 |
grouped_poly_norm_bwd_kernel(
|
| 228 |
-
scalar_t *__restrict__ grad_input,
|
| 229 |
-
|
| 230 |
-
float *__restrict__
|
| 231 |
-
float *__restrict__ bias_grad, // [num_total_experts] fp32
|
| 232 |
const scalar_t *__restrict__ grad_output,
|
| 233 |
-
const scalar_t *__restrict__ input,
|
| 234 |
-
const scalar_t *__restrict__
|
| 235 |
-
const
|
| 236 |
-
const
|
| 237 |
-
|
| 238 |
-
const acc_t *__restrict__ inv_rms,
|
| 239 |
-
const float *__restrict__ scores, // nullable, always fp32
|
| 240 |
-
acc_t *__restrict__ grad_scores, // nullable (null when scores is null)
|
| 241 |
const acc_t eps, const int D, const int num_experts,
|
| 242 |
-
const int expert_offset,
|
| 243 |
-
const acc_t hidden_clamp) {
|
| 244 |
using v_t = vec_t<scalar_t, width>;
|
| 245 |
const bool do_clamp = (hidden_clamp >= acc_t(0));
|
| 246 |
|
|
@@ -249,7 +253,8 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
|
| 249 |
const int64_t base = (int64_t)row * vec_d;
|
| 250 |
|
| 251 |
const v_t *__restrict__ in_v = reinterpret_cast<const v_t *>(input) + base;
|
| 252 |
-
const v_t *__restrict__ go_v =
|
|
|
|
| 253 |
const v_t *__restrict__ m_v = reinterpret_cast<const v_t *>(mul) + base;
|
| 254 |
|
| 255 |
const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
|
|
@@ -278,9 +283,11 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
|
| 278 |
#pragma unroll
|
| 279 |
for (int j = 0; j < width; ++j) {
|
| 280 |
acc_t x_orig = xv.data[j];
|
| 281 |
-
acc_t x =
|
|
|
|
| 282 |
acc_t m_orig = (acc_t)mv.data[j];
|
| 283 |
-
acc_t m =
|
|
|
|
| 284 |
acc_t go = (acc_t)gv.data[j];
|
| 285 |
|
| 286 |
// Output clamp mask: recompute pre-clamp output
|
|
@@ -288,7 +295,8 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
|
| 288 |
acc_t x2 = x * x, x3 = x2 * x;
|
| 289 |
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
|
| 290 |
acc_t out_pre = poly * m * score;
|
| 291 |
-
if (fabsf(out_pre) > hidden_clamp)
|
|
|
|
| 292 |
}
|
| 293 |
|
| 294 |
acc_t x2 = x * x;
|
|
@@ -300,7 +308,8 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
|
| 300 |
}
|
| 301 |
}
|
| 302 |
|
| 303 |
-
float4 sums =
|
|
|
|
| 304 |
|
| 305 |
const acc_t inv_d = acc_t(1) / D;
|
| 306 |
const acc_t s1 = sums.x * inv_d;
|
|
@@ -327,9 +336,11 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
|
| 327 |
#pragma unroll
|
| 328 |
for (int j = 0; j < width; ++j) {
|
| 329 |
acc_t x_orig = xv.data[j];
|
| 330 |
-
acc_t x =
|
|
|
|
| 331 |
acc_t m_orig = (acc_t)mv.data[j];
|
| 332 |
-
acc_t m =
|
|
|
|
| 333 |
acc_t x2 = x * x;
|
| 334 |
acc_t x3 = x2 * x;
|
| 335 |
acc_t go = (acc_t)gv.data[j];
|
|
@@ -338,27 +349,30 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
|
| 338 |
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1;
|
| 339 |
if (do_clamp) {
|
| 340 |
acc_t out_pre = (poly + b_val) * m * score;
|
| 341 |
-
if (fabsf(out_pre) > hidden_clamp)
|
|
|
|
| 342 |
}
|
| 343 |
|
| 344 |
acc_t dp = go * m * score;
|
| 345 |
|
| 346 |
// grad_mul with mul clamp mask
|
| 347 |
acc_t gm_val = go * (poly + b_val) * score;
|
| 348 |
-
if (do_clamp && fabsf(m_orig) > hidden_clamp)
|
|
|
|
| 349 |
gm.data[j] = (scalar_t)gm_val;
|
| 350 |
|
| 351 |
// grad_input with input clamp mask
|
| 352 |
acc_t g = ir1 * (w2 * dp - x * cx);
|
| 353 |
g += acc_t(2) * x * ir2 * (w1 * dp - x2 * cx2);
|
| 354 |
g += acc_t(3) * x2 * ir3 * (w0 * dp - x3 * cx3);
|
| 355 |
-
if (do_clamp && fabsf(x_orig) > hidden_clamp)
|
|
|
|
| 356 |
gi.data[j] = (scalar_t)g;
|
| 357 |
|
| 358 |
dw0 += dp * x3 * ir3;
|
| 359 |
dw1 += dp * x2 * ir2;
|
| 360 |
dw2 += dp * x * ir1;
|
| 361 |
-
gs_acc += go * (poly + b_val) * m;
|
| 362 |
}
|
| 363 |
|
| 364 |
gi_v[i] = gi;
|
|
@@ -383,32 +397,27 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
|
| 383 |
// Scalar fallback (width == 0)
|
| 384 |
// ---------------------------------------------------------------------------
|
| 385 |
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
|
| 386 |
-
__global__ void __launch_bounds__(BLOCK_SIZE)
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
const int32_t *__restrict__ offsets,
|
| 398 |
-
const acc_t *__restrict__ inv_rms,
|
| 399 |
-
const float *__restrict__ scores, // nullable, always fp32
|
| 400 |
-
acc_t *__restrict__ grad_scores, // nullable
|
| 401 |
-
const acc_t eps, const int D, const int num_experts,
|
| 402 |
-
const int expert_offset,
|
| 403 |
-
const acc_t hidden_clamp) {
|
| 404 |
const bool do_clamp = (hidden_clamp >= acc_t(0));
|
| 405 |
const int row = blockIdx.x;
|
| 406 |
const int64_t off = (int64_t)row * D;
|
| 407 |
|
| 408 |
const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
|
| 409 |
-
const acc_t w0 = weight[eidx * 3], w1 = weight[eidx * 3 + 1],
|
|
|
|
| 410 |
const acc_t b_val = bias[eidx];
|
| 411 |
-
const acc_t ir1 = inv_rms[row * 3], ir2 = inv_rms[row * 3 + 1],
|
|
|
|
| 412 |
const acc_t w2ir1 = w2 * ir1, w1ir2 = w1 * ir2, w0ir3 = w0 * ir3;
|
| 413 |
const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);
|
| 414 |
|
|
@@ -416,34 +425,44 @@ __global__ void __launch_bounds__(BLOCK_SIZE)
|
|
| 416 |
acc_t sdpx = 0, sdpx2 = 0, sdpx3 = 0, sdp = 0;
|
| 417 |
for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
|
| 418 |
acc_t x_orig = input[off + i];
|
| 419 |
-
acc_t x =
|
|
|
|
| 420 |
acc_t m_orig = (acc_t)mul[off + i];
|
| 421 |
-
acc_t m =
|
|
|
|
| 422 |
acc_t go = (acc_t)grad_output[off + i];
|
| 423 |
|
| 424 |
if (do_clamp) {
|
| 425 |
acc_t x2 = x * x, x3 = x2 * x;
|
| 426 |
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
|
| 427 |
acc_t out_pre = poly * m * score;
|
| 428 |
-
if (fabsf(out_pre) > hidden_clamp)
|
|
|
|
| 429 |
}
|
| 430 |
|
| 431 |
acc_t x2 = x * x;
|
| 432 |
acc_t dp = go * m * score;
|
| 433 |
-
sdp += dp;
|
|
|
|
|
|
|
|
|
|
| 434 |
}
|
| 435 |
|
| 436 |
-
float4 sums =
|
|
|
|
| 437 |
const acc_t inv_d = acc_t(1) / D;
|
| 438 |
const acc_t s1 = sums.x * inv_d, s2 = sums.y * inv_d, s3 = sums.z * inv_d;
|
| 439 |
-
const acc_t cx = w2 * s1 * ir1 * ir1, cx2 = w1 * s2 * ir2 * ir2,
|
|
|
|
| 440 |
|
| 441 |
// Pass 2: grads with clamp masks
|
| 442 |
acc_t dw0 = 0, dw1 = 0, dw2 = 0, gs_acc = 0;
|
| 443 |
for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
|
| 444 |
acc_t x_orig = input[off + i], m_orig = (acc_t)mul[off + i];
|
| 445 |
-
acc_t x =
|
| 446 |
-
|
|
|
|
|
|
|
| 447 |
acc_t go = (acc_t)grad_output[off + i];
|
| 448 |
acc_t x2 = x * x, x3 = x2 * x;
|
| 449 |
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1;
|
|
@@ -451,21 +470,27 @@ __global__ void __launch_bounds__(BLOCK_SIZE)
|
|
| 451 |
// Output clamp mask
|
| 452 |
if (do_clamp) {
|
| 453 |
acc_t out_pre = (poly + b_val) * m * score;
|
| 454 |
-
if (fabsf(out_pre) > hidden_clamp)
|
|
|
|
| 455 |
}
|
| 456 |
|
| 457 |
acc_t dp = go * m * score;
|
| 458 |
|
| 459 |
acc_t gm_val = go * (poly + b_val) * score;
|
| 460 |
-
if (do_clamp && fabsf(m_orig) > hidden_clamp)
|
|
|
|
| 461 |
grad_mul[off + i] = (scalar_t)gm_val;
|
| 462 |
|
| 463 |
-
acc_t g = ir1 * (w2 * dp - x * cx) +
|
| 464 |
-
|
| 465 |
-
|
|
|
|
|
|
|
| 466 |
grad_input[off + i] = (scalar_t)g;
|
| 467 |
|
| 468 |
-
dw0 += dp * x3 * ir3;
|
|
|
|
|
|
|
| 469 |
gs_acc += go * (poly + b_val) * m;
|
| 470 |
}
|
| 471 |
|
|
@@ -487,28 +512,27 @@ __global__ void __launch_bounds__(BLOCK_SIZE)
|
|
| 487 |
// Internal helpers — shared kernel dispatch
|
| 488 |
// ---------------------------------------------------------------------------
|
| 489 |
#define FWD_LAUNCH(width_val, scalar_type_name) \
|
| 490 |
-
MOTIF_DISPATCH_FLOATING_TYPES(
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
<
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
(float)hidden_clamp); \
|
| 500 |
-
})
|
| 501 |
|
| 502 |
static std::tuple<torch::Tensor, torch::Tensor>
|
| 503 |
_fwd_impl(const torch::Tensor &input, const torch::Tensor &mul,
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
const int D = input.size(-1);
|
| 508 |
const int64_t N = input.size(0);
|
| 509 |
const int num_experts = offsets.size(0);
|
| 510 |
constexpr int BLOCK = 128;
|
| 511 |
-
dim3 grid(N);
|
|
|
|
| 512 |
|
| 513 |
auto output = torch::empty_like(input);
|
| 514 |
auto inv_rms = torch::empty({N, 3}, input.options().dtype(torch::kFloat));
|
|
@@ -527,9 +551,8 @@ _fwd_impl(const torch::Tensor &input, const torch::Tensor &mul,
|
|
| 527 |
output.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(),
|
| 528 |
input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(),
|
| 529 |
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
|
| 530 |
-
offsets.data_ptr<int32_t>(), scores_ptr,
|
| 531 |
-
|
| 532 |
-
(float)hidden_clamp);
|
| 533 |
});
|
| 534 |
}
|
| 535 |
return {output, inv_rms};
|
|
@@ -537,39 +560,37 @@ _fwd_impl(const torch::Tensor &input, const torch::Tensor &mul,
|
|
| 537 |
#undef FWD_LAUNCH
|
| 538 |
|
| 539 |
#define BWD_LAUNCH(width_val, scalar_type_name, kernel_name) \
|
| 540 |
-
MOTIF_DISPATCH_FLOATING_TYPES(
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
<
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
scores_ptr, gs_ptr, \
|
| 552 |
-
(float)eps, D, num_experts, (int)expert_offset, \
|
| 553 |
-
(float)hidden_clamp); \
|
| 554 |
-
})
|
| 555 |
|
| 556 |
static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
|
| 557 |
-
|
| 558 |
_bwd_impl(const torch::Tensor &grad_output, const torch::Tensor &input,
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
double eps, int64_t expert_offset, double hidden_clamp) {
|
| 564 |
const int D = input.size(-1);
|
| 565 |
const int num_experts = offsets.size(0);
|
| 566 |
constexpr int BLOCK = 128;
|
| 567 |
-
dim3 grid(N);
|
|
|
|
| 568 |
|
| 569 |
auto input_grad = torch::empty_like(input);
|
| 570 |
auto mul_grad = torch::empty_like(mul);
|
| 571 |
-
auto wg_f32 =
|
| 572 |
-
|
|
|
|
|
|
|
| 573 |
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
| 574 |
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 575 |
|
|
@@ -583,15 +604,13 @@ _bwd_impl(const torch::Tensor &grad_output, const torch::Tensor &input,
|
|
| 583 |
motif::grouped_poly_norm_bwd_scalar<scalar_t, float, BLOCK>
|
| 584 |
<<<grid, block, 0, stream>>>(
|
| 585 |
input_grad.data_ptr<scalar_t>(),
|
| 586 |
-
mul_grad.data_ptr<scalar_t>(),
|
| 587 |
-
|
| 588 |
-
grad_output.data_ptr<scalar_t>(),
|
| 589 |
input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(),
|
| 590 |
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
|
| 591 |
offsets.data_ptr<int32_t>(), inv_rms.data_ptr<float>(),
|
| 592 |
-
scores_ptr, gs_ptr,
|
| 593 |
-
(
|
| 594 |
-
(float)hidden_clamp);
|
| 595 |
});
|
| 596 |
}
|
| 597 |
|
|
@@ -606,54 +625,56 @@ _bwd_impl(const torch::Tensor &grad_output, const torch::Tensor &input,
|
|
| 606 |
// Public API: without scores
|
| 607 |
// ---------------------------------------------------------------------------
|
| 608 |
std::tuple<torch::Tensor, torch::Tensor>
|
| 609 |
-
grouped_poly_norm_forward(
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
return _fwd_impl(input, mul, weight, bias, offsets, nullptr, eps,
|
|
|
|
| 615 |
}
|
| 616 |
|
| 617 |
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
| 618 |
-
grouped_poly_norm_backward(
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
|
|
|
| 624 |
const int64_t N = input.size(0);
|
| 625 |
-
auto [ig, mg, wg, bg, _] =
|
| 626 |
-
grad_output, input, mul, weight, bias, offsets, inv_rms,
|
| 627 |
-
|
| 628 |
return {ig, mg, wg, bg};
|
| 629 |
}
|
| 630 |
|
| 631 |
// ---------------------------------------------------------------------------
|
| 632 |
// Public API: with scores
|
| 633 |
// ---------------------------------------------------------------------------
|
| 634 |
-
std::tuple<torch::Tensor, torch::Tensor>
|
| 635 |
-
grouped_poly_norm_forward_scored(
|
| 636 |
const torch::Tensor &input, const torch::Tensor &mul,
|
| 637 |
const torch::Tensor &weight, const torch::Tensor &bias,
|
| 638 |
-
const torch::Tensor &offsets, const torch::Tensor &scores,
|
| 639 |
-
|
| 640 |
-
return _fwd_impl(input, mul, weight, bias, offsets,
|
| 641 |
-
|
| 642 |
}
|
| 643 |
|
| 644 |
-
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
|
|
|
|
| 645 |
grouped_poly_norm_backward_scored(
|
| 646 |
const torch::Tensor &grad_output, const torch::Tensor &input,
|
| 647 |
const torch::Tensor &mul, const torch::Tensor &weight,
|
| 648 |
const torch::Tensor &bias, const torch::Tensor &offsets,
|
| 649 |
-
const torch::Tensor &inv_rms, const torch::Tensor &scores,
|
| 650 |
-
|
| 651 |
const int64_t N = input.size(0);
|
| 652 |
auto gs_f32 = torch::empty({N}, input.options().dtype(torch::kFloat));
|
| 653 |
-
auto [ig, mg, wg, bg, _] =
|
| 654 |
-
grad_output, input, mul, weight, bias, offsets, inv_rms,
|
| 655 |
-
|
| 656 |
-
|
| 657 |
auto gs = gs_f32.unsqueeze(-1);
|
| 658 |
return {ig, mg, wg, bg, gs};
|
| 659 |
}
|
|
|
|
| 18 |
int lo = 0, hi = num_experts;
|
| 19 |
#pragma unroll 6
|
| 20 |
for (int i = 0; i < 12; ++i) {
|
| 21 |
+
if (lo >= hi)
|
| 22 |
+
break;
|
| 23 |
int mid = (lo + hi) >> 1;
|
| 24 |
if (offsets[mid] <= row)
|
| 25 |
lo = mid + 1;
|
|
|
|
| 54 |
const int warp_id = threadIdx.x / WARP_SIZE;
|
| 55 |
const int lane_id = threadIdx.x % WARP_SIZE;
|
| 56 |
|
| 57 |
+
if (lane_id == 0)
|
| 58 |
+
warp_results[warp_id] = v;
|
| 59 |
__syncthreads();
|
| 60 |
|
| 61 |
if (warp_id == 0 && lane_id < NUM_WARPS)
|
|
|
|
| 63 |
else
|
| 64 |
v = make_float4(0.f, 0.f, 0.f, 0.f);
|
| 65 |
|
| 66 |
+
if (warp_id == 0)
|
| 67 |
+
v = warp_reduce_f4(v);
|
| 68 |
|
| 69 |
__shared__ float4 result;
|
| 70 |
+
if (threadIdx.x == 0)
|
| 71 |
+
result = v;
|
| 72 |
__syncthreads();
|
| 73 |
return result;
|
| 74 |
}
|
|
|
|
| 81 |
template <typename scalar_t, typename acc_t, int width, int BLOCK_SIZE>
|
| 82 |
__global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
| 83 |
grouped_poly_norm_fwd_kernel(
|
| 84 |
+
scalar_t *__restrict__ output, acc_t *__restrict__ inv_rms,
|
| 85 |
+
const scalar_t *__restrict__ input, const scalar_t *__restrict__ mul,
|
| 86 |
+
const scalar_t *__restrict__ weight, const scalar_t *__restrict__ bias,
|
|
|
|
|
|
|
|
|
|
| 87 |
const int32_t *__restrict__ offsets,
|
| 88 |
+
const float *__restrict__ scores, // nullable, always fp32
|
| 89 |
const acc_t eps, const int D, const int num_experts,
|
| 90 |
const int expert_offset,
|
| 91 |
+
const acc_t hidden_clamp) { // < 0 = disabled
|
| 92 |
using v_t = vec_t<scalar_t, width>;
|
| 93 |
const bool do_clamp = (hidden_clamp >= acc_t(0));
|
| 94 |
|
|
|
|
| 112 |
#pragma unroll
|
| 113 |
for (int j = 0; j < width; ++j) {
|
| 114 |
acc_t x = xv.data[j];
|
| 115 |
+
if (do_clamp)
|
| 116 |
+
x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
|
| 117 |
acc_t x2 = x * x;
|
| 118 |
s2 += x2;
|
| 119 |
s4 += x2 * x2;
|
|
|
|
| 150 |
#pragma unroll
|
| 151 |
for (int j = 0; j < width; ++j) {
|
| 152 |
acc_t x = xv.data[j];
|
| 153 |
+
if (do_clamp)
|
| 154 |
+
x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
|
| 155 |
acc_t m = (acc_t)mv.data[j];
|
| 156 |
+
if (do_clamp)
|
| 157 |
+
m = fminf(fmaxf(m, -hidden_clamp), hidden_clamp);
|
| 158 |
acc_t x2 = x * x;
|
| 159 |
acc_t x3 = x2 * x;
|
| 160 |
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
|
| 161 |
acc_t out_val = poly * m * score;
|
| 162 |
+
if (do_clamp)
|
| 163 |
+
out_val = fminf(fmaxf(out_val, -hidden_clamp), hidden_clamp);
|
| 164 |
ov.data[j] = (scalar_t)out_val;
|
| 165 |
}
|
| 166 |
out_v[i] = ov;
|
|
|
|
| 169 |
|
| 170 |
// Scalar fallback forward
|
| 171 |
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
|
| 172 |
+
__global__ void __launch_bounds__(BLOCK_SIZE) grouped_poly_norm_fwd_scalar(
|
| 173 |
+
scalar_t *__restrict__ output, acc_t *__restrict__ inv_rms,
|
| 174 |
+
const scalar_t *__restrict__ input, const scalar_t *__restrict__ mul,
|
| 175 |
+
const scalar_t *__restrict__ weight, const scalar_t *__restrict__ bias,
|
| 176 |
+
const int32_t *__restrict__ offsets,
|
| 177 |
+
const float *__restrict__ scores, // nullable, always fp32
|
| 178 |
+
const acc_t eps, const int D, const int num_experts,
|
| 179 |
+
const int expert_offset, const acc_t hidden_clamp) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
const bool do_clamp = (hidden_clamp >= acc_t(0));
|
| 181 |
const int row = blockIdx.x;
|
| 182 |
const int64_t off = (int64_t)row * D;
|
| 183 |
|
| 184 |
const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
|
| 185 |
+
const acc_t w0 = weight[eidx * 3], w1 = weight[eidx * 3 + 1],
|
| 186 |
+
w2 = weight[eidx * 3 + 2];
|
| 187 |
const acc_t b_val = bias[eidx];
|
| 188 |
const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);
|
| 189 |
|
| 190 |
acc_t s2 = 0, s4 = 0, s6 = 0;
|
| 191 |
for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
|
| 192 |
acc_t x = input[off + i];
|
| 193 |
+
if (do_clamp)
|
| 194 |
+
x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
|
| 195 |
acc_t x2 = x * x;
|
| 196 |
+
s2 += x2;
|
| 197 |
+
s4 += x2 * x2;
|
| 198 |
+
s6 += x2 * x2 * x2;
|
| 199 |
}
|
| 200 |
|
| 201 |
float4 sums = block_reduce_f4<BLOCK_SIZE>(make_float4(s2, s4, s6, 0.f));
|
|
|
|
| 205 |
const acc_t ir3 = rsqrtf(sums.z * inv_d + eps);
|
| 206 |
|
| 207 |
if (threadIdx.x == 0) {
|
| 208 |
+
inv_rms[row * 3] = ir1;
|
| 209 |
+
inv_rms[row * 3 + 1] = ir2;
|
| 210 |
+
inv_rms[row * 3 + 2] = ir3;
|
| 211 |
}
|
| 212 |
|
| 213 |
const acc_t w2ir1 = w2 * ir1, w1ir2 = w1 * ir2, w0ir3 = w0 * ir3;
|
| 214 |
for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
|
| 215 |
acc_t x = input[off + i];
|
| 216 |
+
if (do_clamp)
|
| 217 |
+
x = fminf(fmaxf(x, -hidden_clamp), hidden_clamp);
|
| 218 |
acc_t m = (acc_t)mul[off + i];
|
| 219 |
+
if (do_clamp)
|
| 220 |
+
m = fminf(fmaxf(m, -hidden_clamp), hidden_clamp);
|
| 221 |
acc_t x2 = x * x, x3 = x2 * x;
|
| 222 |
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
|
| 223 |
acc_t out_val = poly * m * score;
|
| 224 |
+
if (do_clamp)
|
| 225 |
+
out_val = fminf(fmaxf(out_val, -hidden_clamp), hidden_clamp);
|
| 226 |
output[off + i] = (scalar_t)out_val;
|
| 227 |
}
|
| 228 |
}
|
|
|
|
| 234 |
template <typename scalar_t, typename acc_t, int width, int BLOCK_SIZE>
|
| 235 |
__global__ void __launch_bounds__(BLOCK_SIZE, 65536 / (BLOCK_SIZE * 64))
|
| 236 |
grouped_poly_norm_bwd_kernel(
|
| 237 |
+
scalar_t *__restrict__ grad_input, scalar_t *__restrict__ grad_mul,
|
| 238 |
+
float *__restrict__ weight_grad, // [num_total_experts, 3] fp32
|
| 239 |
+
float *__restrict__ bias_grad, // [num_total_experts] fp32
|
|
|
|
| 240 |
const scalar_t *__restrict__ grad_output,
|
| 241 |
+
const scalar_t *__restrict__ input, const scalar_t *__restrict__ mul,
|
| 242 |
+
const scalar_t *__restrict__ weight, const scalar_t *__restrict__ bias,
|
| 243 |
+
const int32_t *__restrict__ offsets, const acc_t *__restrict__ inv_rms,
|
| 244 |
+
const float *__restrict__ scores, // nullable, always fp32
|
| 245 |
+
acc_t *__restrict__ grad_scores, // nullable (null when scores is null)
|
|
|
|
|
|
|
|
|
|
| 246 |
const acc_t eps, const int D, const int num_experts,
|
| 247 |
+
const int expert_offset, const acc_t hidden_clamp) {
|
|
|
|
| 248 |
using v_t = vec_t<scalar_t, width>;
|
| 249 |
const bool do_clamp = (hidden_clamp >= acc_t(0));
|
| 250 |
|
|
|
|
| 253 |
const int64_t base = (int64_t)row * vec_d;
|
| 254 |
|
| 255 |
const v_t *__restrict__ in_v = reinterpret_cast<const v_t *>(input) + base;
|
| 256 |
+
const v_t *__restrict__ go_v =
|
| 257 |
+
reinterpret_cast<const v_t *>(grad_output) + base;
|
| 258 |
const v_t *__restrict__ m_v = reinterpret_cast<const v_t *>(mul) + base;
|
| 259 |
|
| 260 |
const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
|
|
|
|
| 283 |
#pragma unroll
|
| 284 |
for (int j = 0; j < width; ++j) {
|
| 285 |
acc_t x_orig = xv.data[j];
|
| 286 |
+
acc_t x =
|
| 287 |
+
do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
|
| 288 |
acc_t m_orig = (acc_t)mv.data[j];
|
| 289 |
+
acc_t m =
|
| 290 |
+
do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
|
| 291 |
acc_t go = (acc_t)gv.data[j];
|
| 292 |
|
| 293 |
// Output clamp mask: recompute pre-clamp output
|
|
|
|
| 295 |
acc_t x2 = x * x, x3 = x2 * x;
|
| 296 |
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
|
| 297 |
acc_t out_pre = poly * m * score;
|
| 298 |
+
if (fabsf(out_pre) > hidden_clamp)
|
| 299 |
+
go = acc_t(0);
|
| 300 |
}
|
| 301 |
|
| 302 |
acc_t x2 = x * x;
|
|
|
|
| 308 |
}
|
| 309 |
}
|
| 310 |
|
| 311 |
+
float4 sums =
|
| 312 |
+
block_reduce_f4<BLOCK_SIZE>(make_float4(sdpx, sdpx2, sdpx3, sdp));
|
| 313 |
|
| 314 |
const acc_t inv_d = acc_t(1) / D;
|
| 315 |
const acc_t s1 = sums.x * inv_d;
|
|
|
|
| 336 |
#pragma unroll
|
| 337 |
for (int j = 0; j < width; ++j) {
|
| 338 |
acc_t x_orig = xv.data[j];
|
| 339 |
+
acc_t x =
|
| 340 |
+
do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
|
| 341 |
acc_t m_orig = (acc_t)mv.data[j];
|
| 342 |
+
acc_t m =
|
| 343 |
+
do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
|
| 344 |
acc_t x2 = x * x;
|
| 345 |
acc_t x3 = x2 * x;
|
| 346 |
acc_t go = (acc_t)gv.data[j];
|
|
|
|
| 349 |
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1;
|
| 350 |
if (do_clamp) {
|
| 351 |
acc_t out_pre = (poly + b_val) * m * score;
|
| 352 |
+
if (fabsf(out_pre) > hidden_clamp)
|
| 353 |
+
go = acc_t(0);
|
| 354 |
}
|
| 355 |
|
| 356 |
acc_t dp = go * m * score;
|
| 357 |
|
| 358 |
// grad_mul with mul clamp mask
|
| 359 |
acc_t gm_val = go * (poly + b_val) * score;
|
| 360 |
+
if (do_clamp && fabsf(m_orig) > hidden_clamp)
|
| 361 |
+
gm_val = acc_t(0);
|
| 362 |
gm.data[j] = (scalar_t)gm_val;
|
| 363 |
|
| 364 |
// grad_input with input clamp mask
|
| 365 |
acc_t g = ir1 * (w2 * dp - x * cx);
|
| 366 |
g += acc_t(2) * x * ir2 * (w1 * dp - x2 * cx2);
|
| 367 |
g += acc_t(3) * x2 * ir3 * (w0 * dp - x3 * cx3);
|
| 368 |
+
if (do_clamp && fabsf(x_orig) > hidden_clamp)
|
| 369 |
+
g = acc_t(0);
|
| 370 |
gi.data[j] = (scalar_t)g;
|
| 371 |
|
| 372 |
dw0 += dp * x3 * ir3;
|
| 373 |
dw1 += dp * x2 * ir2;
|
| 374 |
dw2 += dp * x * ir1;
|
| 375 |
+
gs_acc += go * (poly + b_val) * m; // grad_scores accumulator
|
| 376 |
}
|
| 377 |
|
| 378 |
gi_v[i] = gi;
|
|
|
|
| 397 |
// Scalar fallback (width == 0)
|
| 398 |
// ---------------------------------------------------------------------------
|
| 399 |
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
|
| 400 |
+
__global__ void __launch_bounds__(BLOCK_SIZE) grouped_poly_norm_bwd_scalar(
|
| 401 |
+
scalar_t *__restrict__ grad_input, scalar_t *__restrict__ grad_mul,
|
| 402 |
+
float *__restrict__ weight_grad, float *__restrict__ bias_grad,
|
| 403 |
+
const scalar_t *__restrict__ grad_output,
|
| 404 |
+
const scalar_t *__restrict__ input, const scalar_t *__restrict__ mul,
|
| 405 |
+
const scalar_t *__restrict__ weight, const scalar_t *__restrict__ bias,
|
| 406 |
+
const int32_t *__restrict__ offsets, const acc_t *__restrict__ inv_rms,
|
| 407 |
+
const float *__restrict__ scores, // nullable, always fp32
|
| 408 |
+
acc_t *__restrict__ grad_scores, // nullable
|
| 409 |
+
const acc_t eps, const int D, const int num_experts,
|
| 410 |
+
const int expert_offset, const acc_t hidden_clamp) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
const bool do_clamp = (hidden_clamp >= acc_t(0));
|
| 412 |
const int row = blockIdx.x;
|
| 413 |
const int64_t off = (int64_t)row * D;
|
| 414 |
|
| 415 |
const int eidx = find_expert(offsets, num_experts, row) + expert_offset;
|
| 416 |
+
const acc_t w0 = weight[eidx * 3], w1 = weight[eidx * 3 + 1],
|
| 417 |
+
w2 = weight[eidx * 3 + 2];
|
| 418 |
const acc_t b_val = bias[eidx];
|
| 419 |
+
const acc_t ir1 = inv_rms[row * 3], ir2 = inv_rms[row * 3 + 1],
|
| 420 |
+
ir3 = inv_rms[row * 3 + 2];
|
| 421 |
const acc_t w2ir1 = w2 * ir1, w1ir2 = w1 * ir2, w0ir3 = w0 * ir3;
|
| 422 |
const acc_t score = (scores != nullptr) ? (acc_t)scores[row] : acc_t(1);
|
| 423 |
|
|
|
|
| 425 |
acc_t sdpx = 0, sdpx2 = 0, sdpx3 = 0, sdp = 0;
|
| 426 |
for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
|
| 427 |
acc_t x_orig = input[off + i];
|
| 428 |
+
acc_t x =
|
| 429 |
+
do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
|
| 430 |
acc_t m_orig = (acc_t)mul[off + i];
|
| 431 |
+
acc_t m =
|
| 432 |
+
do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
|
| 433 |
acc_t go = (acc_t)grad_output[off + i];
|
| 434 |
|
| 435 |
if (do_clamp) {
|
| 436 |
acc_t x2 = x * x, x3 = x2 * x;
|
| 437 |
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1 + b_val;
|
| 438 |
acc_t out_pre = poly * m * score;
|
| 439 |
+
if (fabsf(out_pre) > hidden_clamp)
|
| 440 |
+
go = acc_t(0);
|
| 441 |
}
|
| 442 |
|
| 443 |
acc_t x2 = x * x;
|
| 444 |
acc_t dp = go * m * score;
|
| 445 |
+
sdp += dp;
|
| 446 |
+
sdpx += dp * x;
|
| 447 |
+
sdpx2 += dp * x2;
|
| 448 |
+
sdpx3 += dp * x2 * x;
|
| 449 |
}
|
| 450 |
|
| 451 |
+
float4 sums =
|
| 452 |
+
block_reduce_f4<BLOCK_SIZE>(make_float4(sdpx, sdpx2, sdpx3, sdp));
|
| 453 |
const acc_t inv_d = acc_t(1) / D;
|
| 454 |
const acc_t s1 = sums.x * inv_d, s2 = sums.y * inv_d, s3 = sums.z * inv_d;
|
| 455 |
+
const acc_t cx = w2 * s1 * ir1 * ir1, cx2 = w1 * s2 * ir2 * ir2,
|
| 456 |
+
cx3 = w0 * s3 * ir3 * ir3;
|
| 457 |
|
| 458 |
// Pass 2: grads with clamp masks
|
| 459 |
acc_t dw0 = 0, dw1 = 0, dw2 = 0, gs_acc = 0;
|
| 460 |
for (int i = threadIdx.x; i < D; i += BLOCK_SIZE) {
|
| 461 |
acc_t x_orig = input[off + i], m_orig = (acc_t)mul[off + i];
|
| 462 |
+
acc_t x =
|
| 463 |
+
do_clamp ? fminf(fmaxf(x_orig, -hidden_clamp), hidden_clamp) : x_orig;
|
| 464 |
+
acc_t m =
|
| 465 |
+
do_clamp ? fminf(fmaxf(m_orig, -hidden_clamp), hidden_clamp) : m_orig;
|
| 466 |
acc_t go = (acc_t)grad_output[off + i];
|
| 467 |
acc_t x2 = x * x, x3 = x2 * x;
|
| 468 |
acc_t poly = x3 * w0ir3 + x2 * w1ir2 + x * w2ir1;
|
|
|
|
| 470 |
// Output clamp mask
|
| 471 |
if (do_clamp) {
|
| 472 |
acc_t out_pre = (poly + b_val) * m * score;
|
| 473 |
+
if (fabsf(out_pre) > hidden_clamp)
|
| 474 |
+
go = acc_t(0);
|
| 475 |
}
|
| 476 |
|
| 477 |
acc_t dp = go * m * score;
|
| 478 |
|
| 479 |
acc_t gm_val = go * (poly + b_val) * score;
|
| 480 |
+
if (do_clamp && fabsf(m_orig) > hidden_clamp)
|
| 481 |
+
gm_val = acc_t(0);
|
| 482 |
grad_mul[off + i] = (scalar_t)gm_val;
|
| 483 |
|
| 484 |
+
acc_t g = ir1 * (w2 * dp - x * cx) +
|
| 485 |
+
acc_t(2) * x * ir2 * (w1 * dp - x2 * cx2) +
|
| 486 |
+
acc_t(3) * x2 * ir3 * (w0 * dp - x3 * cx3);
|
| 487 |
+
if (do_clamp && fabsf(x_orig) > hidden_clamp)
|
| 488 |
+
g = acc_t(0);
|
| 489 |
grad_input[off + i] = (scalar_t)g;
|
| 490 |
|
| 491 |
+
dw0 += dp * x3 * ir3;
|
| 492 |
+
dw1 += dp * x2 * ir2;
|
| 493 |
+
dw2 += dp * x * ir1;
|
| 494 |
gs_acc += go * (poly + b_val) * m;
|
| 495 |
}
|
| 496 |
|
|
|
|
| 512 |
// Internal helpers — shared kernel dispatch
|
| 513 |
// ---------------------------------------------------------------------------
|
| 514 |
#define FWD_LAUNCH(width_val, scalar_type_name) \
|
| 515 |
+
MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), scalar_type_name, [&] { \
|
| 516 |
+
motif::grouped_poly_norm_fwd_kernel<scalar_t, float, width_val, BLOCK> \
|
| 517 |
+
<<<grid, block, 0, stream>>>( \
|
| 518 |
+
output.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(), \
|
| 519 |
+
input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(), \
|
| 520 |
+
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), \
|
| 521 |
+
offsets.data_ptr<int32_t>(), scores_ptr, (float)eps, D, \
|
| 522 |
+
num_experts, (int)expert_offset, (float)hidden_clamp); \
|
| 523 |
+
})
|
|
|
|
|
|
|
| 524 |
|
| 525 |
static std::tuple<torch::Tensor, torch::Tensor>
|
| 526 |
_fwd_impl(const torch::Tensor &input, const torch::Tensor &mul,
|
| 527 |
+
const torch::Tensor &weight, const torch::Tensor &bias,
|
| 528 |
+
const torch::Tensor &offsets, const float *scores_ptr, double eps,
|
| 529 |
+
int64_t expert_offset, double hidden_clamp) {
|
| 530 |
const int D = input.size(-1);
|
| 531 |
const int64_t N = input.size(0);
|
| 532 |
const int num_experts = offsets.size(0);
|
| 533 |
constexpr int BLOCK = 128;
|
| 534 |
+
dim3 grid(N);
|
| 535 |
+
dim3 block(BLOCK);
|
| 536 |
|
| 537 |
auto output = torch::empty_like(input);
|
| 538 |
auto inv_rms = torch::empty({N, 3}, input.options().dtype(torch::kFloat));
|
|
|
|
| 551 |
output.data_ptr<scalar_t>(), inv_rms.data_ptr<float>(),
|
| 552 |
input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(),
|
| 553 |
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
|
| 554 |
+
offsets.data_ptr<int32_t>(), scores_ptr, (float)eps, D,
|
| 555 |
+
num_experts, (int)expert_offset, (float)hidden_clamp);
|
|
|
|
| 556 |
});
|
| 557 |
}
|
| 558 |
return {output, inv_rms};
|
|
|
|
| 560 |
#undef FWD_LAUNCH
|
| 561 |
|
| 562 |
#define BWD_LAUNCH(width_val, scalar_type_name, kernel_name) \
|
| 563 |
+
MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), scalar_type_name, [&] { \
|
| 564 |
+
motif::kernel_name<scalar_t, float, width_val, BLOCK> \
|
| 565 |
+
<<<grid, block, 0, stream>>>( \
|
| 566 |
+
input_grad.data_ptr<scalar_t>(), mul_grad.data_ptr<scalar_t>(), \
|
| 567 |
+
wg_f32.data_ptr<float>(), bg_f32.data_ptr<float>(), \
|
| 568 |
+
grad_output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
|
| 569 |
+
mul.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
|
| 570 |
+
bias.data_ptr<scalar_t>(), offsets.data_ptr<int32_t>(), \
|
| 571 |
+
inv_rms.data_ptr<float>(), scores_ptr, gs_ptr, (float)eps, D, \
|
| 572 |
+
num_experts, (int)expert_offset, (float)hidden_clamp); \
|
| 573 |
+
})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
|
| 575 |
static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
|
| 576 |
+
torch::Tensor>
|
| 577 |
_bwd_impl(const torch::Tensor &grad_output, const torch::Tensor &input,
|
| 578 |
+
const torch::Tensor &mul, const torch::Tensor &weight,
|
| 579 |
+
const torch::Tensor &bias, const torch::Tensor &offsets,
|
| 580 |
+
const torch::Tensor &inv_rms, const float *scores_ptr, float *gs_ptr,
|
| 581 |
+
int64_t N, double eps, int64_t expert_offset, double hidden_clamp) {
|
|
|
|
| 582 |
const int D = input.size(-1);
|
| 583 |
const int num_experts = offsets.size(0);
|
| 584 |
constexpr int BLOCK = 128;
|
| 585 |
+
dim3 grid(N);
|
| 586 |
+
dim3 block(BLOCK);
|
| 587 |
|
| 588 |
auto input_grad = torch::empty_like(input);
|
| 589 |
auto mul_grad = torch::empty_like(mul);
|
| 590 |
+
auto wg_f32 =
|
| 591 |
+
torch::zeros({weight.size(0), 3}, input.options().dtype(torch::kFloat));
|
| 592 |
+
auto bg_f32 =
|
| 593 |
+
torch::zeros({bias.size(0)}, input.options().dtype(torch::kFloat));
|
| 594 |
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
| 595 |
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 596 |
|
|
|
|
| 604 |
motif::grouped_poly_norm_bwd_scalar<scalar_t, float, BLOCK>
|
| 605 |
<<<grid, block, 0, stream>>>(
|
| 606 |
input_grad.data_ptr<scalar_t>(),
|
| 607 |
+
mul_grad.data_ptr<scalar_t>(), wg_f32.data_ptr<float>(),
|
| 608 |
+
bg_f32.data_ptr<float>(), grad_output.data_ptr<scalar_t>(),
|
|
|
|
| 609 |
input.data_ptr<scalar_t>(), mul.data_ptr<scalar_t>(),
|
| 610 |
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
|
| 611 |
offsets.data_ptr<int32_t>(), inv_rms.data_ptr<float>(),
|
| 612 |
+
scores_ptr, gs_ptr, (float)eps, D, num_experts,
|
| 613 |
+
(int)expert_offset, (float)hidden_clamp);
|
|
|
|
| 614 |
});
|
| 615 |
}
|
| 616 |
|
|
|
|
| 625 |
// Public API: without scores
|
| 626 |
// ---------------------------------------------------------------------------
|
| 627 |
std::tuple<torch::Tensor, torch::Tensor>
|
| 628 |
+
grouped_poly_norm_forward(const torch::Tensor &input, const torch::Tensor &mul,
|
| 629 |
+
const torch::Tensor &weight,
|
| 630 |
+
const torch::Tensor &bias,
|
| 631 |
+
const torch::Tensor &offsets, double eps,
|
| 632 |
+
int64_t expert_offset, double hidden_clamp) {
|
| 633 |
+
return _fwd_impl(input, mul, weight, bias, offsets, nullptr, eps,
|
| 634 |
+
expert_offset, hidden_clamp);
|
| 635 |
}
|
| 636 |
|
| 637 |
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
| 638 |
+
grouped_poly_norm_backward(const torch::Tensor &grad_output,
|
| 639 |
+
const torch::Tensor &input, const torch::Tensor &mul,
|
| 640 |
+
const torch::Tensor &weight,
|
| 641 |
+
const torch::Tensor &bias,
|
| 642 |
+
const torch::Tensor &offsets,
|
| 643 |
+
const torch::Tensor &inv_rms, double eps,
|
| 644 |
+
int64_t expert_offset, double hidden_clamp) {
|
| 645 |
const int64_t N = input.size(0);
|
| 646 |
+
auto [ig, mg, wg, bg, _] =
|
| 647 |
+
_bwd_impl(grad_output, input, mul, weight, bias, offsets, inv_rms,
|
| 648 |
+
nullptr, nullptr, N, eps, expert_offset, hidden_clamp);
|
| 649 |
return {ig, mg, wg, bg};
|
| 650 |
}
|
| 651 |
|
| 652 |
// ---------------------------------------------------------------------------
|
| 653 |
// Public API: with scores
|
| 654 |
// ---------------------------------------------------------------------------
|
| 655 |
+
std::tuple<torch::Tensor, torch::Tensor> grouped_poly_norm_forward_scored(
|
|
|
|
| 656 |
const torch::Tensor &input, const torch::Tensor &mul,
|
| 657 |
const torch::Tensor &weight, const torch::Tensor &bias,
|
| 658 |
+
const torch::Tensor &offsets, const torch::Tensor &scores, double eps,
|
| 659 |
+
int64_t expert_offset, double hidden_clamp) {
|
| 660 |
+
return _fwd_impl(input, mul, weight, bias, offsets, scores.data_ptr<float>(),
|
| 661 |
+
eps, expert_offset, hidden_clamp);
|
| 662 |
}
|
| 663 |
|
| 664 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
|
| 665 |
+
torch::Tensor>
|
| 666 |
grouped_poly_norm_backward_scored(
|
| 667 |
const torch::Tensor &grad_output, const torch::Tensor &input,
|
| 668 |
const torch::Tensor &mul, const torch::Tensor &weight,
|
| 669 |
const torch::Tensor &bias, const torch::Tensor &offsets,
|
| 670 |
+
const torch::Tensor &inv_rms, const torch::Tensor &scores, double eps,
|
| 671 |
+
int64_t expert_offset, double hidden_clamp) {
|
| 672 |
const int64_t N = input.size(0);
|
| 673 |
auto gs_f32 = torch::empty({N}, input.options().dtype(torch::kFloat));
|
| 674 |
+
auto [ig, mg, wg, bg, _] =
|
| 675 |
+
_bwd_impl(grad_output, input, mul, weight, bias, offsets, inv_rms,
|
| 676 |
+
scores.data_ptr<float>(), gs_f32.data_ptr<float>(), N, eps,
|
| 677 |
+
expert_offset, hidden_clamp);
|
| 678 |
auto gs = gs_f32.unsqueeze(-1);
|
| 679 |
return {ig, mg, wg, bg, gs};
|
| 680 |
}
|
benchmarks/cases/grouped_mul_poly.py
CHANGED
|
@@ -5,10 +5,8 @@ from common.diff_engine import DiffCase
|
|
| 5 |
|
| 6 |
torch._functorch.config.donated_buffer = False
|
| 7 |
|
| 8 |
-
from grouped_poly_norm import (
|
| 9 |
-
|
| 10 |
-
fused_mul_grouped_poly_norm_ref,
|
| 11 |
-
)
|
| 12 |
|
| 13 |
# 384 / 8 (EP) = 48 experts per rank
|
| 14 |
# total_tokens = bs * sl, which equals per-rank tokens
|
|
@@ -28,9 +26,14 @@ class GroupedRefModule(torch.nn.Module):
|
|
| 28 |
self.expert_offset = expert_offset
|
| 29 |
|
| 30 |
def forward(self, x, mul):
|
| 31 |
-
return fused_mul_grouped_poly_norm_ref(
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
class GroupedCUDAModule(torch.nn.Module):
|
|
@@ -45,8 +48,12 @@ class GroupedCUDAModule(torch.nn.Module):
|
|
| 45 |
self.expert_offset = expert_offset
|
| 46 |
|
| 47 |
def forward(self, x, mul):
|
| 48 |
-
return fused_mul_grouped_poly_norm(x,
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
expert_offset=self.expert_offset)
|
| 51 |
|
| 52 |
|
|
@@ -66,29 +73,28 @@ class GroupedMulPoly(DiffCase):
|
|
| 66 |
probs = torch.ones(num_experts) / num_experts
|
| 67 |
assignments = torch.multinomial(probs, total_tokens, replacement=True)
|
| 68 |
counts = torch.bincount(assignments, minlength=num_experts).tolist()
|
| 69 |
-
offsets = torch.cumsum(
|
| 70 |
-
torch.tensor(counts, dtype=torch.int32), dim=0)
|
| 71 |
|
| 72 |
return {
|
| 73 |
"x":
|
| 74 |
-
|
| 75 |
-
|
| 76 |
"mul":
|
| 77 |
-
|
| 78 |
-
|
| 79 |
"weight":
|
| 80 |
-
|
| 81 |
-
|
| 82 |
"bias":
|
| 83 |
-
|
| 84 |
"offsets":
|
| 85 |
-
|
| 86 |
"dim":
|
| 87 |
-
|
| 88 |
"eps":
|
| 89 |
-
|
| 90 |
"dtype":
|
| 91 |
-
|
| 92 |
}
|
| 93 |
|
| 94 |
def make_naive(self, I):
|
|
|
|
| 5 |
|
| 6 |
torch._functorch.config.donated_buffer = False
|
| 7 |
|
| 8 |
+
from grouped_poly_norm import (fused_mul_grouped_poly_norm,
|
| 9 |
+
fused_mul_grouped_poly_norm_ref)
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# 384 / 8 (EP) = 48 experts per rank
|
| 12 |
# total_tokens = bs * sl, which equals per-rank tokens
|
|
|
|
| 26 |
self.expert_offset = expert_offset
|
| 27 |
|
| 28 |
def forward(self, x, mul):
|
| 29 |
+
return fused_mul_grouped_poly_norm_ref(
|
| 30 |
+
x,
|
| 31 |
+
mul,
|
| 32 |
+
self.weight,
|
| 33 |
+
self.bias,
|
| 34 |
+
self.offsets,
|
| 35 |
+
self.eps,
|
| 36 |
+
expert_offset=self.expert_offset)
|
| 37 |
|
| 38 |
|
| 39 |
class GroupedCUDAModule(torch.nn.Module):
|
|
|
|
| 48 |
self.expert_offset = expert_offset
|
| 49 |
|
| 50 |
def forward(self, x, mul):
|
| 51 |
+
return fused_mul_grouped_poly_norm(x,
|
| 52 |
+
mul,
|
| 53 |
+
self.weight,
|
| 54 |
+
self.bias,
|
| 55 |
+
self.offsets,
|
| 56 |
+
self.eps,
|
| 57 |
expert_offset=self.expert_offset)
|
| 58 |
|
| 59 |
|
|
|
|
| 73 |
probs = torch.ones(num_experts) / num_experts
|
| 74 |
assignments = torch.multinomial(probs, total_tokens, replacement=True)
|
| 75 |
counts = torch.bincount(assignments, minlength=num_experts).tolist()
|
| 76 |
+
offsets = torch.cumsum(torch.tensor(counts, dtype=torch.int32), dim=0)
|
|
|
|
| 77 |
|
| 78 |
return {
|
| 79 |
"x":
|
| 80 |
+
torch.randn(total_tokens, hidden, dtype=dtype,
|
| 81 |
+
requires_grad=True) * 0.5,
|
| 82 |
"mul":
|
| 83 |
+
torch.randn(total_tokens, hidden, dtype=dtype,
|
| 84 |
+
requires_grad=True) * 0.5,
|
| 85 |
"weight":
|
| 86 |
+
torch.ones(num_experts, 3, dtype=dtype) / 3 +
|
| 87 |
+
torch.randn(num_experts, 3, dtype=dtype) * 0.01,
|
| 88 |
"bias":
|
| 89 |
+
torch.randn(num_experts, 1, dtype=dtype) * 0.01,
|
| 90 |
"offsets":
|
| 91 |
+
offsets,
|
| 92 |
"dim":
|
| 93 |
+
hidden,
|
| 94 |
"eps":
|
| 95 |
+
eps,
|
| 96 |
"dtype":
|
| 97 |
+
dtype,
|
| 98 |
}
|
| 99 |
|
| 100 |
def make_naive(self, I):
|
benchmarks/profile_bwd.py
DELETED
|
@@ -1,146 +0,0 @@
|
|
| 1 |
-
"""Profiling script for grouped polynorm backward kernel using torch.profiler."""
|
| 2 |
-
import argparse
|
| 3 |
-
import torch
|
| 4 |
-
import torch.cuda
|
| 5 |
-
from torch.profiler import profile, ProfilerActivity
|
| 6 |
-
from grouped_poly_norm import fused_mul_grouped_poly_norm
|
| 7 |
-
|
| 8 |
-
torch.set_default_device("cuda")
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def make_inputs(N, D, num_experts):
|
| 12 |
-
torch.manual_seed(42)
|
| 13 |
-
probs = torch.ones(num_experts) / num_experts
|
| 14 |
-
assignments = torch.multinomial(probs, N, replacement=True)
|
| 15 |
-
counts = torch.bincount(assignments, minlength=num_experts).tolist()
|
| 16 |
-
offsets = torch.cumsum(
|
| 17 |
-
torch.tensor(counts, dtype=torch.int32), dim=0)
|
| 18 |
-
|
| 19 |
-
x = torch.randn(N, D, dtype=torch.bfloat16, requires_grad=True) * 0.5
|
| 20 |
-
m = torch.randn(N, D, dtype=torch.bfloat16, requires_grad=True) * 0.5
|
| 21 |
-
w = (torch.ones(num_experts, 3, dtype=torch.bfloat16) / 3
|
| 22 |
-
).requires_grad_(True)
|
| 23 |
-
b = (torch.randn(num_experts, 1, dtype=torch.bfloat16) * 0.01
|
| 24 |
-
).requires_grad_(True)
|
| 25 |
-
return x, m, w, b, offsets
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def main():
|
| 29 |
-
parser = argparse.ArgumentParser()
|
| 30 |
-
parser.add_argument("--tokens", type=int, default=4096)
|
| 31 |
-
parser.add_argument("--dim", type=int, default=1280)
|
| 32 |
-
parser.add_argument("--experts", type=int, default=48)
|
| 33 |
-
parser.add_argument("--output", type=str, default="/tmp/profile")
|
| 34 |
-
args = parser.parse_args()
|
| 35 |
-
|
| 36 |
-
N, D, num_experts = args.tokens, args.dim, args.experts
|
| 37 |
-
|
| 38 |
-
# Warmup (fresh inputs each time to avoid graph reuse issues)
|
| 39 |
-
for _ in range(3):
|
| 40 |
-
x, m, w, b, offsets = make_inputs(N, D, num_experts)
|
| 41 |
-
out = fused_mul_grouped_poly_norm(x, m, w, b, offsets)
|
| 42 |
-
out.sum().backward()
|
| 43 |
-
torch.cuda.synchronize()
|
| 44 |
-
|
| 45 |
-
# Profiled: mimic do_bench — forward once, backward multiple times with retain_graph
|
| 46 |
-
x, m, w, b, offsets = make_inputs(N, D, num_experts)
|
| 47 |
-
out = fused_mul_grouped_poly_norm(x, m, w, b, offsets)
|
| 48 |
-
gin = [x, m] + [w, b]
|
| 49 |
-
g = [torch.randn_like(out)]
|
| 50 |
-
|
| 51 |
-
# Warmup backward
|
| 52 |
-
for _ in range(5):
|
| 53 |
-
torch.autograd.grad(out, gin, g, retain_graph=True, allow_unused=True)
|
| 54 |
-
torch.cuda.synchronize()
|
| 55 |
-
|
| 56 |
-
with profile(
|
| 57 |
-
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
| 58 |
-
record_shapes=True,
|
| 59 |
-
with_stack=True,
|
| 60 |
-
) as prof:
|
| 61 |
-
for _ in range(100):
|
| 62 |
-
torch.autograd.grad(out, gin, g, retain_graph=True, allow_unused=True)
|
| 63 |
-
torch.cuda.synchronize()
|
| 64 |
-
|
| 65 |
-
# Print kernel-level stats
|
| 66 |
-
print(f"\n=== Kernel Table (N={N}, D={D}) ===")
|
| 67 |
-
print(prof.key_averages().table(
|
| 68 |
-
sort_by="cuda_time_total", row_limit=20))
|
| 69 |
-
|
| 70 |
-
# Export chrome trace
|
| 71 |
-
trace_path = f"{args.output}_trace_N{N}.json"
|
| 72 |
-
prof.export_chrome_trace(trace_path)
|
| 73 |
-
print(f"\nTrace exported to {trace_path}")
|
| 74 |
-
|
| 75 |
-
# === Occupancy analysis from Triton kernel metadata ===
|
| 76 |
-
print(f"\n=== Occupancy Analysis ===")
|
| 77 |
-
|
| 78 |
-
props = torch.cuda.get_device_properties(0)
|
| 79 |
-
print(f"GPU: {props.name}")
|
| 80 |
-
print(f"SMs: {props.multi_processor_count}")
|
| 81 |
-
print(f"Max threads/SM: {props.max_threads_per_multi_processor}")
|
| 82 |
-
print(f"Regs/SM: {props.regs_per_multiprocessor}")
|
| 83 |
-
print(f"Shared mem/block: {props.shared_memory_per_block} bytes")
|
| 84 |
-
|
| 85 |
-
# Get register info from Triton compiled cubins
|
| 86 |
-
try:
|
| 87 |
-
import glob
|
| 88 |
-
import json
|
| 89 |
-
import subprocess
|
| 90 |
-
cache_dir = os.path.expanduser("~/.triton/cache")
|
| 91 |
-
|
| 92 |
-
# Find metadata JSON files
|
| 93 |
-
json_files = sorted(glob.glob(f"{cache_dir}/**/*.json", recursive=True),
|
| 94 |
-
key=os.path.getmtime, reverse=True)
|
| 95 |
-
print(f"\nFound {len(json_files)} compiled kernel metadata files")
|
| 96 |
-
for jf in json_files[:10]:
|
| 97 |
-
try:
|
| 98 |
-
with open(jf) as f:
|
| 99 |
-
meta = json.load(f)
|
| 100 |
-
if isinstance(meta, dict):
|
| 101 |
-
n_regs = meta.get('num_regs', meta.get('n_regs', None))
|
| 102 |
-
n_spills = meta.get('num_spills', meta.get('n_spills', None))
|
| 103 |
-
name = meta.get('name', os.path.basename(jf))
|
| 104 |
-
shared = meta.get('shared', None)
|
| 105 |
-
if n_regs is not None:
|
| 106 |
-
print(f" {name}: regs={n_regs}, spills={n_spills}, shared={shared}")
|
| 107 |
-
except Exception:
|
| 108 |
-
pass
|
| 109 |
-
|
| 110 |
-
# Also try cuobjdump on recent cubins
|
| 111 |
-
cubin_files = sorted(glob.glob(f"{cache_dir}/**/*.cubin", recursive=True),
|
| 112 |
-
key=os.path.getmtime, reverse=True)
|
| 113 |
-
print(f"\nFound {len(cubin_files)} cubins, inspecting latest:")
|
| 114 |
-
for cb in cubin_files[:5]:
|
| 115 |
-
try:
|
| 116 |
-
result = subprocess.run(
|
| 117 |
-
["cuobjdump", "-res-usage", cb],
|
| 118 |
-
capture_output=True, text=True, timeout=5)
|
| 119 |
-
if result.returncode == 0 and result.stdout.strip():
|
| 120 |
-
print(f"\n {os.path.basename(cb)}:")
|
| 121 |
-
for line in result.stdout.strip().split('\n'):
|
| 122 |
-
print(f" {line}")
|
| 123 |
-
except Exception as e:
|
| 124 |
-
print(f" cuobjdump failed: {e}")
|
| 125 |
-
break
|
| 126 |
-
except Exception as e:
|
| 127 |
-
print(f"Cache inspection error: {e}")
|
| 128 |
-
|
| 129 |
-
# Calculate theoretical occupancy for different register counts
|
| 130 |
-
print("\n=== Theoretical Occupancy (num_warps=4, 128 threads/block) ===")
|
| 131 |
-
threads_per_block = 128
|
| 132 |
-
max_threads = props.max_threads_per_multi_processor
|
| 133 |
-
total_regs = props.regs_per_multiprocessor
|
| 134 |
-
for n_regs in [64, 96, 128, 160, 192, 224, 256]:
|
| 135 |
-
regs_per_block = n_regs * threads_per_block
|
| 136 |
-
max_blocks_by_regs = total_regs // regs_per_block
|
| 137 |
-
max_blocks_by_threads = max_threads // threads_per_block
|
| 138 |
-
blocks = min(max_blocks_by_regs, max_blocks_by_threads, 32)
|
| 139 |
-
active_threads = blocks * threads_per_block
|
| 140 |
-
occupancy = active_threads / max_threads * 100
|
| 141 |
-
print(f" {n_regs:3d} regs/thread -> {blocks:2d} blocks/SM -> "
|
| 142 |
-
f"{active_threads:4d} threads -> {occupancy:.1f}% occupancy")
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
if __name__ == "__main__":
|
| 146 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/run_cases.py
CHANGED
|
@@ -28,8 +28,10 @@ def plot_result(r_path, columns=None):
|
|
| 28 |
import pandas as pd
|
| 29 |
df = pd.read_csv(r_path + ".csv")
|
| 30 |
if columns is None:
|
| 31 |
-
columns = [
|
| 32 |
-
|
|
|
|
|
|
|
| 33 |
plt.figure(figsize=(12, 6))
|
| 34 |
ax = df.plot(x="config", y=columns, kind="bar", ax=plt.gca())
|
| 35 |
ax.set_title("Speedup over torch (higher is better)\n" + make_title_tag(),
|
|
@@ -64,6 +66,11 @@ def main():
|
|
| 64 |
default="bf16",
|
| 65 |
help="Data type for benchmarking (default: bf16)",
|
| 66 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
args = ap.parse_args()
|
| 68 |
|
| 69 |
dtype_map = {
|
|
@@ -81,12 +88,14 @@ def main():
|
|
| 81 |
mod = importlib.import_module(f"cases.{args.case}")
|
| 82 |
case: DiffCase = mod.CASE
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
| 90 |
|
| 91 |
for dtype_name, dtype in dtypes:
|
| 92 |
print(f"\n{'=' * 60}")
|
|
@@ -161,28 +170,40 @@ def main():
|
|
| 161 |
itertools.product(dim, batch_size_range, seq_length_range))
|
| 162 |
|
| 163 |
if is_grouped:
|
| 164 |
-
|
| 165 |
-
|
| 166 |
"naive": "Naive",
|
| 167 |
"compiled": "Compiled",
|
| 168 |
"cuda": "Triton",
|
| 169 |
"speedup": "SpeedUp",
|
| 170 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
else:
|
| 172 |
-
|
| 173 |
-
|
| 174 |
"naive": "Naive",
|
| 175 |
"cuda": "Cuda",
|
| 176 |
"speedup": "SpeedUp",
|
| 177 |
}
|
|
|
|
|
|
|
| 178 |
|
| 179 |
bench = make_fwd_benchmark_for_case(
|
| 180 |
case=case,
|
| 181 |
configs=configs,
|
| 182 |
plot_name=f"{args.case}-{dtype_name}-fwd-perf",
|
| 183 |
dtype=dtype,
|
| 184 |
-
line_vals=
|
| 185 |
-
line_names=
|
|
|
|
|
|
|
| 186 |
)
|
| 187 |
|
| 188 |
bench.run(print_data=True, save_path=save_dir)
|
|
@@ -192,8 +213,10 @@ def main():
|
|
| 192 |
configs=configs,
|
| 193 |
plot_name=f"{args.case}-{dtype_name}-bwd-perf",
|
| 194 |
dtype=dtype,
|
| 195 |
-
line_vals=
|
| 196 |
-
line_names=
|
|
|
|
|
|
|
| 197 |
)
|
| 198 |
|
| 199 |
bench.run(print_data=True, save_path=save_dir)
|
|
|
|
| 28 |
import pandas as pd
|
| 29 |
df = pd.read_csv(r_path + ".csv")
|
| 30 |
if columns is None:
|
| 31 |
+
columns = [
|
| 32 |
+
c for c in ["Naive", "Compiled", "Cuda", "Triton"]
|
| 33 |
+
if c in df.columns
|
| 34 |
+
]
|
| 35 |
plt.figure(figsize=(12, 6))
|
| 36 |
ax = df.plot(x="config", y=columns, kind="bar", ax=plt.gca())
|
| 37 |
ax.set_title("Speedup over torch (higher is better)\n" + make_title_tag(),
|
|
|
|
| 66 |
default="bf16",
|
| 67 |
help="Data type for benchmarking (default: bf16)",
|
| 68 |
)
|
| 69 |
+
ap.add_argument(
|
| 70 |
+
"--profile",
|
| 71 |
+
action="store_true",
|
| 72 |
+
help="Export chrome traces for backward benchmarks",
|
| 73 |
+
)
|
| 74 |
args = ap.parse_args()
|
| 75 |
|
| 76 |
dtype_map = {
|
|
|
|
| 88 |
mod = importlib.import_module(f"cases.{args.case}")
|
| 89 |
case: DiffCase = mod.CASE
|
| 90 |
|
| 91 |
+
# Correctness checks across multiple configs
|
| 92 |
+
for bs, sl, hid in [(2, 128, 4096), (8, 4096, 1280), (1, 32768, 1280)]:
|
| 93 |
+
print(
|
| 94 |
+
f"Checking correctness: bs={bs}, sl={sl}, D={hid} "
|
| 95 |
+
f"(N={bs*sl})...",
|
| 96 |
+
end=" ")
|
| 97 |
+
calculate_diff(case, batch_size=bs, seq_len=sl, hidden_size=hid)
|
| 98 |
+
print("✅")
|
| 99 |
|
| 100 |
for dtype_name, dtype in dtypes:
|
| 101 |
print(f"\n{'=' * 60}")
|
|
|
|
| 170 |
itertools.product(dim, batch_size_range, seq_length_range))
|
| 171 |
|
| 172 |
if is_grouped:
|
| 173 |
+
fwd_line_vals = ("naive", "compiled", "cuda", "speedup")
|
| 174 |
+
fwd_line_names = {
|
| 175 |
"naive": "Naive",
|
| 176 |
"compiled": "Compiled",
|
| 177 |
"cuda": "Triton",
|
| 178 |
"speedup": "SpeedUp",
|
| 179 |
}
|
| 180 |
+
bwd_line_vals = ("naive", "compiled", "compiled_cuda",
|
| 181 |
+
"speedup")
|
| 182 |
+
bwd_line_names = {
|
| 183 |
+
"naive": "Naive",
|
| 184 |
+
"compiled": "Compiled",
|
| 185 |
+
"compiled_cuda": "CompiledCUDA",
|
| 186 |
+
"speedup": "SpeedUp",
|
| 187 |
+
}
|
| 188 |
else:
|
| 189 |
+
fwd_line_vals = ("naive", "cuda", "speedup")
|
| 190 |
+
fwd_line_names = {
|
| 191 |
"naive": "Naive",
|
| 192 |
"cuda": "Cuda",
|
| 193 |
"speedup": "SpeedUp",
|
| 194 |
}
|
| 195 |
+
bwd_line_vals = fwd_line_vals
|
| 196 |
+
bwd_line_names = fwd_line_names
|
| 197 |
|
| 198 |
bench = make_fwd_benchmark_for_case(
|
| 199 |
case=case,
|
| 200 |
configs=configs,
|
| 201 |
plot_name=f"{args.case}-{dtype_name}-fwd-perf",
|
| 202 |
dtype=dtype,
|
| 203 |
+
line_vals=fwd_line_vals,
|
| 204 |
+
line_names=fwd_line_names,
|
| 205 |
+
profile=args.profile,
|
| 206 |
+
profile_dir=os.path.join(save_dir, "traces"),
|
| 207 |
)
|
| 208 |
|
| 209 |
bench.run(print_data=True, save_path=save_dir)
|
|
|
|
| 213 |
configs=configs,
|
| 214 |
plot_name=f"{args.case}-{dtype_name}-bwd-perf",
|
| 215 |
dtype=dtype,
|
| 216 |
+
line_vals=bwd_line_vals,
|
| 217 |
+
line_names=bwd_line_names,
|
| 218 |
+
profile=args.profile,
|
| 219 |
+
profile_dir=os.path.join(save_dir, "traces"),
|
| 220 |
)
|
| 221 |
|
| 222 |
bench.run(print_data=True, save_path=save_dir)
|
registration.h
CHANGED
|
@@ -2,8 +2,8 @@
|
|
| 2 |
|
| 3 |
// Local build compatibility shim for kernel-builder's registration.h
|
| 4 |
|
| 5 |
-
#include <torch/library.h>
|
| 6 |
#include <torch/extension.h>
|
|
|
|
| 7 |
|
| 8 |
// TORCH_LIBRARY_EXPAND may not be defined in all PyTorch versions
|
| 9 |
#ifndef TORCH_LIBRARY_EXPAND
|
|
@@ -11,4 +11,5 @@
|
|
| 11 |
#endif
|
| 12 |
|
| 13 |
// Generate the PyInit_<name> entry point for the shared library
|
| 14 |
-
#define REGISTER_EXTENSION(name)
|
|
|
|
|
|
| 2 |
|
| 3 |
// Local build compatibility shim for kernel-builder's registration.h
|
| 4 |
|
|
|
|
| 5 |
#include <torch/extension.h>
|
| 6 |
+
#include <torch/library.h>
|
| 7 |
|
| 8 |
// TORCH_LIBRARY_EXPAND may not be defined in all PyTorch versions
|
| 9 |
#ifndef TORCH_LIBRARY_EXPAND
|
|
|
|
| 11 |
#endif
|
| 12 |
|
| 13 |
// Generate the PyInit_<name> entry point for the shared library
|
| 14 |
+
#define REGISTER_EXTENSION(name) \
|
| 15 |
+
PYBIND11_MODULE(name, m) {}
|
setup.py
CHANGED
|
@@ -44,9 +44,9 @@ NVCC_FLAGS = [
|
|
| 44 |
"--use_fast_math",
|
| 45 |
"-std=c++17",
|
| 46 |
# Generate code for common architectures
|
| 47 |
-
"-gencode=arch=compute_80,code=sm_80",
|
| 48 |
-
"-gencode=arch=compute_89,code=sm_89",
|
| 49 |
-
"-gencode=arch=compute_90,code=sm_90",
|
| 50 |
]
|
| 51 |
|
| 52 |
# Check for B200 support (sm_100, requires CUDA 12.8+)
|
|
|
|
| 44 |
"--use_fast_math",
|
| 45 |
"-std=c++17",
|
| 46 |
# Generate code for common architectures
|
| 47 |
+
"-gencode=arch=compute_80,code=sm_80", # A100
|
| 48 |
+
"-gencode=arch=compute_89,code=sm_89", # L40/4090
|
| 49 |
+
"-gencode=arch=compute_90,code=sm_90", # H100
|
| 50 |
]
|
| 51 |
|
| 52 |
# Check for B200 support (sm_100, requires CUDA 12.8+)
|
tests/test_fused_mul_grouped_poly_norm.py
CHANGED
|
@@ -1,10 +1,6 @@
|
|
| 1 |
import pytest
|
| 2 |
import torch
|
| 3 |
-
|
| 4 |
-
from grouped_poly_norm import (
|
| 5 |
-
_has_cuda_ops,
|
| 6 |
-
fused_mul_grouped_poly_norm_ref,
|
| 7 |
-
)
|
| 8 |
|
| 9 |
if _has_cuda_ops:
|
| 10 |
from grouped_poly_norm import fused_mul_grouped_poly_norm
|
|
@@ -23,12 +19,19 @@ CUDA_DEVICES = ["cuda:0"]
|
|
| 23 |
|
| 24 |
def _counts_to_offsets(counts_list, device):
|
| 25 |
"""Convert list of counts to cumsum offsets tensor."""
|
| 26 |
-
return torch.cumsum(
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
"""Create deterministic test inputs with random token distribution."""
|
| 33 |
torch.manual_seed(seed)
|
| 34 |
|
|
@@ -57,40 +60,64 @@ def _make_scores(total_tokens, device, dtype=torch.float32):
|
|
| 57 |
return torch.rand(total_tokens, 1, device=device, dtype=dtype) * 0.5 + 0.5
|
| 58 |
|
| 59 |
|
| 60 |
-
def _run_ref(input_t,
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
"""Run reference forward + backward, return output and grads."""
|
| 63 |
inp = input_t.clone().detach().requires_grad_(True)
|
| 64 |
m = mul_t.clone().detach().requires_grad_(True)
|
| 65 |
w = weight.clone().detach().requires_grad_(True)
|
| 66 |
b = bias.clone().detach().requires_grad_(True)
|
| 67 |
-
s = scores.clone().detach().requires_grad_(
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
expert_offset=expert_offset,
|
| 71 |
-
scores=s,
|
|
|
|
| 72 |
out.sum().backward()
|
| 73 |
|
| 74 |
grads = (out, inp.grad, m.grad, w.grad, b.grad)
|
| 75 |
-
return grads + (s.grad,) if s is not None else grads + (None,)
|
| 76 |
|
| 77 |
|
| 78 |
-
def _run_cuda(input_t,
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
"""Run CUDA forward + backward, return output and grads."""
|
| 81 |
inp = input_t.clone().detach().requires_grad_(True)
|
| 82 |
m = mul_t.clone().detach().requires_grad_(True)
|
| 83 |
w = weight.clone().detach().requires_grad_(True)
|
| 84 |
b = bias.clone().detach().requires_grad_(True)
|
| 85 |
-
s = scores.clone().detach().requires_grad_(
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
expert_offset=expert_offset,
|
| 89 |
-
scores=s,
|
|
|
|
| 90 |
out.sum().backward()
|
| 91 |
|
| 92 |
grads = (out, inp.grad, m.grad, w.grad, b.grad)
|
| 93 |
-
return grads + (s.grad,) if s is not None else grads + (None,)
|
| 94 |
|
| 95 |
|
| 96 |
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
|
|
@@ -113,15 +140,26 @@ def test_fused_mul_grouped_poly_norm_forward(
|
|
| 113 |
"""CUDA forward output should match PyTorch reference."""
|
| 114 |
torch.set_default_device(device)
|
| 115 |
input_t, mul_t, weight, bias, offsets = _make_inputs(
|
| 116 |
-
num_tokens,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
expert_offset=expert_offset)
|
| 118 |
|
| 119 |
-
out_ref = fused_mul_grouped_poly_norm_ref(input_t,
|
|
|
|
|
|
|
|
|
|
| 120 |
offsets,
|
| 121 |
expert_offset=expert_offset)
|
| 122 |
-
out_tri = fused_mul_grouped_poly_norm(input_t,
|
| 123 |
-
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
assert out_ref.shape == out_tri.shape == (num_tokens, d)
|
| 127 |
assert out_ref.dtype == out_tri.dtype == dtype
|
|
@@ -152,7 +190,12 @@ def test_fused_mul_grouped_poly_norm_backward(
|
|
| 152 |
"""CUDA backward gradients should match PyTorch reference."""
|
| 153 |
torch.set_default_device(device)
|
| 154 |
input_t, mul_t, weight, bias, offsets = _make_inputs(
|
| 155 |
-
num_tokens,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
expert_offset=expert_offset)
|
| 157 |
|
| 158 |
_, inp_grad_ref, mul_grad_ref, w_grad_ref, b_grad_ref, _ = _run_ref(
|
|
@@ -195,12 +238,18 @@ def test_fused_mul_grouped_poly_norm_zero_token_experts(
|
|
| 195 |
bias = torch.zeros(total_experts, 1, device=device, dtype=dtype)
|
| 196 |
offsets = _counts_to_offsets(counts, device)
|
| 197 |
|
| 198 |
-
out_ref = fused_mul_grouped_poly_norm_ref(input_t,
|
|
|
|
|
|
|
|
|
|
| 199 |
offsets,
|
| 200 |
expert_offset=expert_offset)
|
| 201 |
-
out_tri = fused_mul_grouped_poly_norm(input_t,
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
if dtype == torch.float32:
|
| 206 |
assert_close(out_ref, out_tri, atol=1e-4, rtol=1e-4)
|
|
@@ -208,12 +257,18 @@ def test_fused_mul_grouped_poly_norm_zero_token_experts(
|
|
| 208 |
assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2)
|
| 209 |
|
| 210 |
# Check backward with zero-token experts
|
| 211 |
-
_, _, _, w_grad_ref, b_grad_ref, _ = _run_ref(input_t,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
offsets,
|
| 213 |
expert_offset=expert_offset)
|
| 214 |
-
_, _, _, w_grad_tri, b_grad_tri, _ = _run_cuda(input_t, mul_t, weight, bias,
|
| 215 |
-
offsets,
|
| 216 |
-
expert_offset=expert_offset)
|
| 217 |
|
| 218 |
if dtype == torch.float32:
|
| 219 |
atol, rtol = 1e-3, 1e-3
|
|
@@ -270,7 +325,11 @@ def test_fused_mul_grouped_poly_norm_no_nan_inf(
|
|
| 270 |
@pytest.mark.parametrize("dtype", DTYPES)
|
| 271 |
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
| 272 |
def test_fused_mul_grouped_poly_norm_scores_forward(
|
| 273 |
-
num_tokens,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
):
|
| 275 |
"""Forward with scores should match reference."""
|
| 276 |
torch.set_default_device(device)
|
|
@@ -278,10 +337,18 @@ def test_fused_mul_grouped_poly_norm_scores_forward(
|
|
| 278 |
num_tokens, d, num_experts, dtype, device)
|
| 279 |
scores = _make_scores(num_tokens, device)
|
| 280 |
|
| 281 |
-
out_ref = fused_mul_grouped_poly_norm_ref(
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
|
| 286 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
|
| 287 |
assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
|
|
@@ -294,7 +361,11 @@ def test_fused_mul_grouped_poly_norm_scores_forward(
|
|
| 294 |
@pytest.mark.parametrize("dtype", DTYPES)
|
| 295 |
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
| 296 |
def test_fused_mul_grouped_poly_norm_scores_backward(
|
| 297 |
-
num_tokens,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
):
|
| 299 |
"""Backward with scores should match reference."""
|
| 300 |
torch.set_default_device(device)
|
|
@@ -302,10 +373,18 @@ def test_fused_mul_grouped_poly_norm_scores_backward(
|
|
| 302 |
num_tokens, d, num_experts, dtype, device)
|
| 303 |
scores = _make_scores(num_tokens, device)
|
| 304 |
|
| 305 |
-
out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
|
| 310 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
|
| 311 |
# weight/bias grads use atomicAdd accumulation across tokens,
|
|
@@ -332,7 +411,12 @@ CLAMP_VALUES = [10.0, 1.0, 0.5]
|
|
| 332 |
@pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
|
| 333 |
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
| 334 |
def test_fused_mul_grouped_poly_norm_hidden_clamp_forward(
|
| 335 |
-
num_tokens,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
):
|
| 337 |
"""Forward with hidden_clamp should match reference."""
|
| 338 |
torch.set_default_device(device)
|
|
@@ -340,12 +424,20 @@ def test_fused_mul_grouped_poly_norm_hidden_clamp_forward(
|
|
| 340 |
num_tokens, d, num_experts, dtype, device)
|
| 341 |
scores = _make_scores(num_tokens, device)
|
| 342 |
|
| 343 |
-
out_ref = fused_mul_grouped_poly_norm_ref(
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
|
| 351 |
assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
|
|
@@ -359,7 +451,12 @@ def test_fused_mul_grouped_poly_norm_hidden_clamp_forward(
|
|
| 359 |
@pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
|
| 360 |
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
| 361 |
def test_fused_mul_grouped_poly_norm_hidden_clamp_backward(
|
| 362 |
-
num_tokens,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
):
|
| 364 |
"""Backward with hidden_clamp should match reference."""
|
| 365 |
torch.set_default_device(device)
|
|
@@ -368,11 +465,21 @@ def test_fused_mul_grouped_poly_norm_hidden_clamp_backward(
|
|
| 368 |
scores = _make_scores(num_tokens, device)
|
| 369 |
|
| 370 |
out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
|
| 371 |
-
input_t,
|
| 372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_cuda(
|
| 374 |
-
input_t,
|
| 375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
|
| 377 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
|
| 378 |
# weight/bias grads use atomicAdd accumulation across tokens,
|
|
|
|
| 1 |
import pytest
|
| 2 |
import torch
|
| 3 |
+
from grouped_poly_norm import _has_cuda_ops, fused_mul_grouped_poly_norm_ref
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
if _has_cuda_ops:
|
| 6 |
from grouped_poly_norm import fused_mul_grouped_poly_norm
|
|
|
|
| 19 |
|
| 20 |
def _counts_to_offsets(counts_list, device):
|
| 21 |
"""Convert list of counts to cumsum offsets tensor."""
|
| 22 |
+
return torch.cumsum(torch.tensor(counts_list,
|
| 23 |
+
device=device,
|
| 24 |
+
dtype=torch.int32),
|
| 25 |
+
dim=0).to(torch.int32)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _make_inputs(total_tokens,
|
| 29 |
+
hidden_dim,
|
| 30 |
+
num_experts,
|
| 31 |
+
dtype,
|
| 32 |
+
device,
|
| 33 |
+
seed=42,
|
| 34 |
+
expert_offset=0):
|
| 35 |
"""Create deterministic test inputs with random token distribution."""
|
| 36 |
torch.manual_seed(seed)
|
| 37 |
|
|
|
|
| 60 |
return torch.rand(total_tokens, 1, device=device, dtype=dtype) * 0.5 + 0.5
|
| 61 |
|
| 62 |
|
| 63 |
+
def _run_ref(input_t,
|
| 64 |
+
mul_t,
|
| 65 |
+
weight,
|
| 66 |
+
bias,
|
| 67 |
+
offsets,
|
| 68 |
+
expert_offset=0,
|
| 69 |
+
scores=None,
|
| 70 |
+
hidden_clamp=None):
|
| 71 |
"""Run reference forward + backward, return output and grads."""
|
| 72 |
inp = input_t.clone().detach().requires_grad_(True)
|
| 73 |
m = mul_t.clone().detach().requires_grad_(True)
|
| 74 |
w = weight.clone().detach().requires_grad_(True)
|
| 75 |
b = bias.clone().detach().requires_grad_(True)
|
| 76 |
+
s = scores.clone().detach().requires_grad_(
|
| 77 |
+
True) if scores is not None else None
|
| 78 |
+
|
| 79 |
+
out = fused_mul_grouped_poly_norm_ref(inp,
|
| 80 |
+
m,
|
| 81 |
+
w,
|
| 82 |
+
b,
|
| 83 |
+
offsets,
|
| 84 |
expert_offset=expert_offset,
|
| 85 |
+
scores=s,
|
| 86 |
+
hidden_clamp=hidden_clamp)
|
| 87 |
out.sum().backward()
|
| 88 |
|
| 89 |
grads = (out, inp.grad, m.grad, w.grad, b.grad)
|
| 90 |
+
return grads + (s.grad, ) if s is not None else grads + (None, )
|
| 91 |
|
| 92 |
|
| 93 |
+
def _run_cuda(input_t,
|
| 94 |
+
mul_t,
|
| 95 |
+
weight,
|
| 96 |
+
bias,
|
| 97 |
+
offsets,
|
| 98 |
+
expert_offset=0,
|
| 99 |
+
scores=None,
|
| 100 |
+
hidden_clamp=None):
|
| 101 |
"""Run CUDA forward + backward, return output and grads."""
|
| 102 |
inp = input_t.clone().detach().requires_grad_(True)
|
| 103 |
m = mul_t.clone().detach().requires_grad_(True)
|
| 104 |
w = weight.clone().detach().requires_grad_(True)
|
| 105 |
b = bias.clone().detach().requires_grad_(True)
|
| 106 |
+
s = scores.clone().detach().requires_grad_(
|
| 107 |
+
True) if scores is not None else None
|
| 108 |
+
|
| 109 |
+
out = fused_mul_grouped_poly_norm(inp,
|
| 110 |
+
m,
|
| 111 |
+
w,
|
| 112 |
+
b,
|
| 113 |
+
offsets,
|
| 114 |
expert_offset=expert_offset,
|
| 115 |
+
scores=s,
|
| 116 |
+
hidden_clamp=hidden_clamp)
|
| 117 |
out.sum().backward()
|
| 118 |
|
| 119 |
grads = (out, inp.grad, m.grad, w.grad, b.grad)
|
| 120 |
+
return grads + (s.grad, ) if s is not None else grads + (None, )
|
| 121 |
|
| 122 |
|
| 123 |
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
|
|
|
|
| 140 |
"""CUDA forward output should match PyTorch reference."""
|
| 141 |
torch.set_default_device(device)
|
| 142 |
input_t, mul_t, weight, bias, offsets = _make_inputs(
|
| 143 |
+
num_tokens,
|
| 144 |
+
d,
|
| 145 |
+
num_experts,
|
| 146 |
+
dtype,
|
| 147 |
+
device,
|
| 148 |
+
seed,
|
| 149 |
expert_offset=expert_offset)
|
| 150 |
|
| 151 |
+
out_ref = fused_mul_grouped_poly_norm_ref(input_t,
|
| 152 |
+
mul_t,
|
| 153 |
+
weight,
|
| 154 |
+
bias,
|
| 155 |
offsets,
|
| 156 |
expert_offset=expert_offset)
|
| 157 |
+
out_tri = fused_mul_grouped_poly_norm(input_t,
|
| 158 |
+
mul_t,
|
| 159 |
+
weight,
|
| 160 |
+
bias,
|
| 161 |
+
offsets,
|
| 162 |
+
expert_offset=expert_offset)
|
| 163 |
|
| 164 |
assert out_ref.shape == out_tri.shape == (num_tokens, d)
|
| 165 |
assert out_ref.dtype == out_tri.dtype == dtype
|
|
|
|
| 190 |
"""CUDA backward gradients should match PyTorch reference."""
|
| 191 |
torch.set_default_device(device)
|
| 192 |
input_t, mul_t, weight, bias, offsets = _make_inputs(
|
| 193 |
+
num_tokens,
|
| 194 |
+
d,
|
| 195 |
+
num_experts,
|
| 196 |
+
dtype,
|
| 197 |
+
device,
|
| 198 |
+
seed,
|
| 199 |
expert_offset=expert_offset)
|
| 200 |
|
| 201 |
_, inp_grad_ref, mul_grad_ref, w_grad_ref, b_grad_ref, _ = _run_ref(
|
|
|
|
| 238 |
bias = torch.zeros(total_experts, 1, device=device, dtype=dtype)
|
| 239 |
offsets = _counts_to_offsets(counts, device)
|
| 240 |
|
| 241 |
+
out_ref = fused_mul_grouped_poly_norm_ref(input_t,
|
| 242 |
+
mul_t,
|
| 243 |
+
weight,
|
| 244 |
+
bias,
|
| 245 |
offsets,
|
| 246 |
expert_offset=expert_offset)
|
| 247 |
+
out_tri = fused_mul_grouped_poly_norm(input_t,
|
| 248 |
+
mul_t,
|
| 249 |
+
weight,
|
| 250 |
+
bias,
|
| 251 |
+
offsets,
|
| 252 |
+
expert_offset=expert_offset)
|
| 253 |
|
| 254 |
if dtype == torch.float32:
|
| 255 |
assert_close(out_ref, out_tri, atol=1e-4, rtol=1e-4)
|
|
|
|
| 257 |
assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2)
|
| 258 |
|
| 259 |
# Check backward with zero-token experts
|
| 260 |
+
_, _, _, w_grad_ref, b_grad_ref, _ = _run_ref(input_t,
|
| 261 |
+
mul_t,
|
| 262 |
+
weight,
|
| 263 |
+
bias,
|
| 264 |
+
offsets,
|
| 265 |
+
expert_offset=expert_offset)
|
| 266 |
+
_, _, _, w_grad_tri, b_grad_tri, _ = _run_cuda(input_t,
|
| 267 |
+
mul_t,
|
| 268 |
+
weight,
|
| 269 |
+
bias,
|
| 270 |
offsets,
|
| 271 |
expert_offset=expert_offset)
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
if dtype == torch.float32:
|
| 274 |
atol, rtol = 1e-3, 1e-3
|
|
|
|
| 325 |
@pytest.mark.parametrize("dtype", DTYPES)
|
| 326 |
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
| 327 |
def test_fused_mul_grouped_poly_norm_scores_forward(
|
| 328 |
+
num_tokens,
|
| 329 |
+
d,
|
| 330 |
+
num_experts,
|
| 331 |
+
dtype,
|
| 332 |
+
device,
|
| 333 |
):
|
| 334 |
"""Forward with scores should match reference."""
|
| 335 |
torch.set_default_device(device)
|
|
|
|
| 337 |
num_tokens, d, num_experts, dtype, device)
|
| 338 |
scores = _make_scores(num_tokens, device)
|
| 339 |
|
| 340 |
+
out_ref = fused_mul_grouped_poly_norm_ref(input_t,
|
| 341 |
+
mul_t,
|
| 342 |
+
weight,
|
| 343 |
+
bias,
|
| 344 |
+
offsets,
|
| 345 |
+
scores=scores)
|
| 346 |
+
out_tri = fused_mul_grouped_poly_norm(input_t,
|
| 347 |
+
mul_t,
|
| 348 |
+
weight,
|
| 349 |
+
bias,
|
| 350 |
+
offsets,
|
| 351 |
+
scores=scores)
|
| 352 |
|
| 353 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
|
| 354 |
assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
|
|
|
|
| 361 |
@pytest.mark.parametrize("dtype", DTYPES)
|
| 362 |
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
| 363 |
def test_fused_mul_grouped_poly_norm_scores_backward(
|
| 364 |
+
num_tokens,
|
| 365 |
+
d,
|
| 366 |
+
num_experts,
|
| 367 |
+
dtype,
|
| 368 |
+
device,
|
| 369 |
):
|
| 370 |
"""Backward with scores should match reference."""
|
| 371 |
torch.set_default_device(device)
|
|
|
|
| 373 |
num_tokens, d, num_experts, dtype, device)
|
| 374 |
scores = _make_scores(num_tokens, device)
|
| 375 |
|
| 376 |
+
out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(input_t,
|
| 377 |
+
mul_t,
|
| 378 |
+
weight,
|
| 379 |
+
bias,
|
| 380 |
+
offsets,
|
| 381 |
+
scores=scores)
|
| 382 |
+
out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_cuda(input_t,
|
| 383 |
+
mul_t,
|
| 384 |
+
weight,
|
| 385 |
+
bias,
|
| 386 |
+
offsets,
|
| 387 |
+
scores=scores)
|
| 388 |
|
| 389 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
|
| 390 |
# weight/bias grads use atomicAdd accumulation across tokens,
|
|
|
|
| 411 |
@pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
|
| 412 |
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
| 413 |
def test_fused_mul_grouped_poly_norm_hidden_clamp_forward(
|
| 414 |
+
num_tokens,
|
| 415 |
+
d,
|
| 416 |
+
num_experts,
|
| 417 |
+
dtype,
|
| 418 |
+
hidden_clamp,
|
| 419 |
+
device,
|
| 420 |
):
|
| 421 |
"""Forward with hidden_clamp should match reference."""
|
| 422 |
torch.set_default_device(device)
|
|
|
|
| 424 |
num_tokens, d, num_experts, dtype, device)
|
| 425 |
scores = _make_scores(num_tokens, device)
|
| 426 |
|
| 427 |
+
out_ref = fused_mul_grouped_poly_norm_ref(input_t,
|
| 428 |
+
mul_t,
|
| 429 |
+
weight,
|
| 430 |
+
bias,
|
| 431 |
+
offsets,
|
| 432 |
+
scores=scores,
|
| 433 |
+
hidden_clamp=hidden_clamp)
|
| 434 |
+
out_tri = fused_mul_grouped_poly_norm(input_t,
|
| 435 |
+
mul_t,
|
| 436 |
+
weight,
|
| 437 |
+
bias,
|
| 438 |
+
offsets,
|
| 439 |
+
scores=scores,
|
| 440 |
+
hidden_clamp=hidden_clamp)
|
| 441 |
|
| 442 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
|
| 443 |
assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
|
|
|
|
| 451 |
@pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
|
| 452 |
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
| 453 |
def test_fused_mul_grouped_poly_norm_hidden_clamp_backward(
|
| 454 |
+
num_tokens,
|
| 455 |
+
d,
|
| 456 |
+
num_experts,
|
| 457 |
+
dtype,
|
| 458 |
+
hidden_clamp,
|
| 459 |
+
device,
|
| 460 |
):
|
| 461 |
"""Backward with hidden_clamp should match reference."""
|
| 462 |
torch.set_default_device(device)
|
|
|
|
| 465 |
scores = _make_scores(num_tokens, device)
|
| 466 |
|
| 467 |
out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
|
| 468 |
+
input_t,
|
| 469 |
+
mul_t,
|
| 470 |
+
weight,
|
| 471 |
+
bias,
|
| 472 |
+
offsets,
|
| 473 |
+
scores=scores,
|
| 474 |
+
hidden_clamp=hidden_clamp)
|
| 475 |
out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_cuda(
|
| 476 |
+
input_t,
|
| 477 |
+
mul_t,
|
| 478 |
+
weight,
|
| 479 |
+
bias,
|
| 480 |
+
offsets,
|
| 481 |
+
scores=scores,
|
| 482 |
+
hidden_clamp=hidden_clamp)
|
| 483 |
|
| 484 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
|
| 485 |
# weight/bias grads use atomicAdd accumulation across tokens,
|
torch-ext/activation/grouped_poly_norm.py
CHANGED
|
@@ -37,37 +37,42 @@ _has_cuda_ops = _ops is not None and hasattr(_ops, 'grouped_poly_norm_forward')
|
|
| 37 |
# Register fake (meta) tensor implementations for torch.compile
|
| 38 |
if _has_cuda_ops:
|
| 39 |
try:
|
|
|
|
| 40 |
@torch.library.register_fake("_activation::grouped_poly_norm_forward")
|
| 41 |
def _fwd_fake(input, mul, weight, bias, offsets, eps, expert_offset,
|
| 42 |
-
|
| 43 |
return (torch.empty_like(input),
|
| 44 |
-
torch.empty(input.shape[0],
|
|
|
|
|
|
|
| 45 |
device=input.device))
|
| 46 |
|
| 47 |
@torch.library.register_fake("_activation::grouped_poly_norm_backward")
|
| 48 |
def _bwd_fake(grad_output, input, mul, weight, bias, offsets, inv_rms,
|
| 49 |
-
|
| 50 |
-
return (torch.empty_like(input),
|
| 51 |
-
torch.empty_like(
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
eps, expert_offset, hidden_clamp):
|
| 58 |
return (torch.empty_like(input),
|
| 59 |
-
torch.empty(input.shape[0],
|
|
|
|
|
|
|
| 60 |
device=input.device))
|
| 61 |
|
| 62 |
-
@torch.library.register_fake(
|
|
|
|
| 63 |
def _bwd_scored_fake(grad_output, input, mul, weight, bias, offsets,
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
return (torch.empty_like(input),
|
| 67 |
-
torch.empty_like(
|
| 68 |
-
torch.
|
| 69 |
-
|
| 70 |
-
|
| 71 |
device=input.device))
|
| 72 |
except Exception:
|
| 73 |
pass # already registered
|
|
@@ -111,7 +116,8 @@ def fused_mul_grouped_poly_norm_ref(
|
|
| 111 |
orig_dtype = input.dtype
|
| 112 |
|
| 113 |
token_positions = torch.arange(input.shape[0], device=input.device)
|
| 114 |
-
expert_idx = torch.bucketize(token_positions, offsets,
|
|
|
|
| 115 |
|
| 116 |
weight_fp32 = weight.float()
|
| 117 |
bias_fp32 = bias.float()
|
|
@@ -182,25 +188,25 @@ if _has_cuda_ops:
|
|
| 182 |
"""With scores — same pattern, adds scores + hidden_clamp."""
|
| 183 |
|
| 184 |
@staticmethod
|
| 185 |
-
def forward(input, mul, weight, bias, offsets, scores,
|
| 186 |
-
|
| 187 |
input = input.contiguous()
|
| 188 |
mul = mul.contiguous()
|
| 189 |
assert scores.dtype == torch.float32, \
|
| 190 |
f"scores must be float32, got {scores.dtype}"
|
| 191 |
scores = scores.contiguous()
|
| 192 |
output, inv_rms = _ops.grouped_poly_norm_forward_scored(
|
| 193 |
-
input, mul, weight, bias, offsets, scores, eps,
|
| 194 |
-
|
| 195 |
return output, inv_rms
|
| 196 |
|
| 197 |
@staticmethod
|
| 198 |
def setup_context(ctx, inputs, output):
|
| 199 |
-
(input, mul, weight, bias, offsets, scores,
|
| 200 |
-
|
| 201 |
_, inv_rms = output
|
| 202 |
-
ctx.save_for_backward(input, mul, weight, bias, offsets,
|
| 203 |
-
|
| 204 |
ctx.eps = eps
|
| 205 |
ctx.expert_offset = expert_offset
|
| 206 |
ctx.hidden_clamp = hidden_clamp
|
|
@@ -244,13 +250,14 @@ if _has_cuda_ops:
|
|
| 244 |
"""
|
| 245 |
clamp_val = -1.0 if hidden_clamp is None else float(hidden_clamp)
|
| 246 |
if scores is not None:
|
| 247 |
-
output, _ = _GroupedPolyNormScoredFn.apply(
|
| 248 |
-
|
| 249 |
-
|
|
|
|
| 250 |
else:
|
| 251 |
-
output, _ = _GroupedPolyNormFn.apply(
|
| 252 |
-
|
| 253 |
-
|
| 254 |
return output
|
| 255 |
|
| 256 |
else:
|
|
|
|
| 37 |
# Register fake (meta) tensor implementations for torch.compile
|
| 38 |
if _has_cuda_ops:
|
| 39 |
try:
|
| 40 |
+
|
| 41 |
@torch.library.register_fake("_activation::grouped_poly_norm_forward")
|
| 42 |
def _fwd_fake(input, mul, weight, bias, offsets, eps, expert_offset,
|
| 43 |
+
hidden_clamp):
|
| 44 |
return (torch.empty_like(input),
|
| 45 |
+
torch.empty(input.shape[0],
|
| 46 |
+
3,
|
| 47 |
+
dtype=torch.float32,
|
| 48 |
device=input.device))
|
| 49 |
|
| 50 |
@torch.library.register_fake("_activation::grouped_poly_norm_backward")
|
| 51 |
def _bwd_fake(grad_output, input, mul, weight, bias, offsets, inv_rms,
|
| 52 |
+
eps, expert_offset, hidden_clamp):
|
| 53 |
+
return (torch.empty_like(input), torch.empty_like(mul),
|
| 54 |
+
torch.empty_like(weight), torch.empty_like(bias))
|
| 55 |
+
|
| 56 |
+
@torch.library.register_fake(
|
| 57 |
+
"_activation::grouped_poly_norm_forward_scored")
|
| 58 |
+
def _fwd_scored_fake(input, mul, weight, bias, offsets, scores, eps,
|
| 59 |
+
expert_offset, hidden_clamp):
|
|
|
|
| 60 |
return (torch.empty_like(input),
|
| 61 |
+
torch.empty(input.shape[0],
|
| 62 |
+
3,
|
| 63 |
+
dtype=torch.float32,
|
| 64 |
device=input.device))
|
| 65 |
|
| 66 |
+
@torch.library.register_fake(
|
| 67 |
+
"_activation::grouped_poly_norm_backward_scored")
|
| 68 |
def _bwd_scored_fake(grad_output, input, mul, weight, bias, offsets,
|
| 69 |
+
inv_rms, scores, eps, expert_offset,
|
| 70 |
+
hidden_clamp):
|
| 71 |
+
return (torch.empty_like(input), torch.empty_like(mul),
|
| 72 |
+
torch.empty_like(weight), torch.empty_like(bias),
|
| 73 |
+
torch.empty(input.shape[0],
|
| 74 |
+
1,
|
| 75 |
+
dtype=torch.float32,
|
| 76 |
device=input.device))
|
| 77 |
except Exception:
|
| 78 |
pass # already registered
|
|
|
|
| 116 |
orig_dtype = input.dtype
|
| 117 |
|
| 118 |
token_positions = torch.arange(input.shape[0], device=input.device)
|
| 119 |
+
expert_idx = torch.bucketize(token_positions, offsets,
|
| 120 |
+
right=True) + expert_offset
|
| 121 |
|
| 122 |
weight_fp32 = weight.float()
|
| 123 |
bias_fp32 = bias.float()
|
|
|
|
| 188 |
"""With scores — same pattern, adds scores + hidden_clamp."""
|
| 189 |
|
| 190 |
@staticmethod
|
| 191 |
+
def forward(input, mul, weight, bias, offsets, scores, eps,
|
| 192 |
+
expert_offset, hidden_clamp):
|
| 193 |
input = input.contiguous()
|
| 194 |
mul = mul.contiguous()
|
| 195 |
assert scores.dtype == torch.float32, \
|
| 196 |
f"scores must be float32, got {scores.dtype}"
|
| 197 |
scores = scores.contiguous()
|
| 198 |
output, inv_rms = _ops.grouped_poly_norm_forward_scored(
|
| 199 |
+
input, mul, weight, bias, offsets, scores, eps, expert_offset,
|
| 200 |
+
hidden_clamp)
|
| 201 |
return output, inv_rms
|
| 202 |
|
| 203 |
@staticmethod
|
| 204 |
def setup_context(ctx, inputs, output):
|
| 205 |
+
(input, mul, weight, bias, offsets, scores, eps, expert_offset,
|
| 206 |
+
hidden_clamp) = inputs
|
| 207 |
_, inv_rms = output
|
| 208 |
+
ctx.save_for_backward(input, mul, weight, bias, offsets, inv_rms,
|
| 209 |
+
scores)
|
| 210 |
ctx.eps = eps
|
| 211 |
ctx.expert_offset = expert_offset
|
| 212 |
ctx.hidden_clamp = hidden_clamp
|
|
|
|
| 250 |
"""
|
| 251 |
clamp_val = -1.0 if hidden_clamp is None else float(hidden_clamp)
|
| 252 |
if scores is not None:
|
| 253 |
+
output, _ = _GroupedPolyNormScoredFn.apply(input, mul, weight,
|
| 254 |
+
bias, offsets, scores,
|
| 255 |
+
eps, expert_offset,
|
| 256 |
+
clamp_val)
|
| 257 |
else:
|
| 258 |
+
output, _ = _GroupedPolyNormFn.apply(input, mul, weight, bias,
|
| 259 |
+
offsets, eps, expert_offset,
|
| 260 |
+
clamp_val)
|
| 261 |
return output
|
| 262 |
|
| 263 |
else:
|
torch-ext/torch_binding.cpp
CHANGED
|
@@ -50,32 +50,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|
| 50 |
&fused_add_rms_norm_backward);
|
| 51 |
|
| 52 |
// grouped_poly_norm (without scores, hidden_clamp < 0 = disabled)
|
| 53 |
-
ops.def(
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
| 57 |
ops.impl("grouped_poly_norm_forward", torch::kCUDA,
|
| 58 |
&grouped_poly_norm_forward);
|
| 59 |
|
| 60 |
ops.def("grouped_poly_norm_backward("
|
| 61 |
"Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
|
| 62 |
"Tensor bias, Tensor offsets, Tensor inv_rms, "
|
| 63 |
-
"float eps, int expert_offset, float hidden_clamp) -> (Tensor,
|
|
|
|
| 64 |
ops.impl("grouped_poly_norm_backward", torch::kCUDA,
|
| 65 |
&grouped_poly_norm_backward);
|
| 66 |
|
| 67 |
// grouped_poly_norm (with scores)
|
| 68 |
-
ops.def(
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
| 72 |
ops.impl("grouped_poly_norm_forward_scored", torch::kCUDA,
|
| 73 |
&grouped_poly_norm_forward_scored);
|
| 74 |
|
| 75 |
ops.def("grouped_poly_norm_backward_scored("
|
| 76 |
"Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
|
| 77 |
"Tensor bias, Tensor offsets, Tensor inv_rms, Tensor scores, "
|
| 78 |
-
"float eps, int expert_offset, float hidden_clamp) -> (Tensor,
|
|
|
|
| 79 |
ops.impl("grouped_poly_norm_backward_scored", torch::kCUDA,
|
| 80 |
&grouped_poly_norm_backward_scored);
|
| 81 |
}
|
|
|
|
| 50 |
&fused_add_rms_norm_backward);
|
| 51 |
|
| 52 |
// grouped_poly_norm (without scores, hidden_clamp < 0 = disabled)
|
| 53 |
+
ops.def(
|
| 54 |
+
"grouped_poly_norm_forward("
|
| 55 |
+
"Tensor input, Tensor mul, Tensor weight, "
|
| 56 |
+
"Tensor bias, Tensor offsets, "
|
| 57 |
+
"float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor)");
|
| 58 |
ops.impl("grouped_poly_norm_forward", torch::kCUDA,
|
| 59 |
&grouped_poly_norm_forward);
|
| 60 |
|
| 61 |
ops.def("grouped_poly_norm_backward("
|
| 62 |
"Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
|
| 63 |
"Tensor bias, Tensor offsets, Tensor inv_rms, "
|
| 64 |
+
"float eps, int expert_offset, float hidden_clamp) -> (Tensor, "
|
| 65 |
+
"Tensor, Tensor, Tensor)");
|
| 66 |
ops.impl("grouped_poly_norm_backward", torch::kCUDA,
|
| 67 |
&grouped_poly_norm_backward);
|
| 68 |
|
| 69 |
// grouped_poly_norm (with scores)
|
| 70 |
+
ops.def(
|
| 71 |
+
"grouped_poly_norm_forward_scored("
|
| 72 |
+
"Tensor input, Tensor mul, Tensor weight, "
|
| 73 |
+
"Tensor bias, Tensor offsets, Tensor scores, "
|
| 74 |
+
"float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor)");
|
| 75 |
ops.impl("grouped_poly_norm_forward_scored", torch::kCUDA,
|
| 76 |
&grouped_poly_norm_forward_scored);
|
| 77 |
|
| 78 |
ops.def("grouped_poly_norm_backward_scored("
|
| 79 |
"Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
|
| 80 |
"Tensor bias, Tensor offsets, Tensor inv_rms, Tensor scores, "
|
| 81 |
+
"float eps, int expert_offset, float hidden_clamp) -> (Tensor, "
|
| 82 |
+
"Tensor, Tensor, Tensor, Tensor)");
|
| 83 |
ops.impl("grouped_poly_norm_backward_scored", torch::kCUDA,
|
| 84 |
&grouped_poly_norm_backward_scored);
|
| 85 |
}
|