| import pytest |
| import torch |
| from grouped_poly_norm import _has_cuda_ops, fused_mul_grouped_poly_norm_ref |
|
|
| if _has_cuda_ops: |
| from grouped_poly_norm import fused_mul_grouped_poly_norm |
|
|
| from .utils import assert_close |
|
|
# Parameter grids shared by the parametrized tests below.
DTYPES = [torch.float, torch.bfloat16, torch.float16]  # torch.float is float32
NUM_TOKENS = [4096, 8192]  # total number of tokens routed across experts
D = [256, 1280]  # hidden dimension of input/mul tensors
NUM_EXPERTS_LIST = [8, 384]  # number of experts that receive tokens
EXPERT_OFFSETS = [0, 4]  # starting index into the weight/bias expert axis
SEEDS = [0]

CUDA_DEVICES = ["cuda:0"]
|
|
|
|
| def _counts_to_offsets(counts_list, device): |
| """Convert list of counts to cumsum offsets tensor.""" |
| return torch.cumsum(torch.tensor(counts_list, |
| device=device, |
| dtype=torch.int32), |
| dim=0).to(torch.int32) |
|
|
|
|
def _make_inputs(total_tokens,
                 hidden_dim,
                 num_experts,
                 dtype,
                 device,
                 seed=42,
                 expert_offset=0):
    """Build deterministic (input, mul, weight, bias, offsets) test tensors.

    Tokens are assigned to experts by a uniform multinomial draw, so the
    per-expert counts vary with the arguments but are reproducible for a
    fixed seed.
    """
    torch.manual_seed(seed)

    # Uniform routing: every expert is equally likely for every token.
    probs = torch.ones(num_experts) / num_experts
    assignments = torch.multinomial(probs, total_tokens, replacement=True)
    counts = torch.bincount(assignments, minlength=num_experts).tolist()

    # The weight/bias tables may hold extra leading experts (expert_offset).
    total_experts = expert_offset + num_experts

    # Small-magnitude inputs; weights hover around 1/3 and biases near zero.
    input_t = torch.randn(total_tokens, hidden_dim, device=device,
                          dtype=dtype) * 0.5
    mul_t = torch.randn(total_tokens, hidden_dim, device=device,
                        dtype=dtype) * 0.5
    weight = (torch.ones(total_experts, 3, device=device, dtype=dtype) / 3 +
              torch.randn(total_experts, 3, device=device, dtype=dtype) * 0.01)
    bias = torch.randn(total_experts, 1, device=device, dtype=dtype) * 0.01

    return input_t, mul_t, weight, bias, _counts_to_offsets(counts, device)
|
|
|
|
| def _make_scores(total_tokens, device, dtype=torch.float32): |
| """Create random scores (N, 1) in fp32.""" |
| return torch.rand(total_tokens, 1, device=device, dtype=dtype) * 0.5 + 0.5 |
|
|
|
|
def _run_ref(input_t,
             mul_t,
             weight,
             bias,
             offsets,
             expert_offset=0,
             scores=None,
             hidden_clamp=None):
    """Reference forward + backward; returns (out, input/mul/weight/bias/scores grads)."""
    # Detached, grad-enabled leaf copies so the caller's tensors stay untouched.
    leaves = [
        t.clone().detach().requires_grad_(True)
        for t in (input_t, mul_t, weight, bias)
    ]
    score_leaf = None
    if scores is not None:
        score_leaf = scores.clone().detach().requires_grad_(True)

    out = fused_mul_grouped_poly_norm_ref(*leaves,
                                          offsets,
                                          expert_offset=expert_offset,
                                          scores=score_leaf,
                                          hidden_clamp=hidden_clamp)
    out.sum().backward()

    inp, mul, w, b = leaves
    score_grad = score_leaf.grad if score_leaf is not None else None
    return (out, inp.grad, mul.grad, w.grad, b.grad, score_grad)
|
|
|
|
def _run_cuda(input_t,
              mul_t,
              weight,
              bias,
              offsets,
              expert_offset=0,
              scores=None,
              hidden_clamp=None):
    """CUDA forward + backward; returns (out, input/mul/weight/bias/scores grads)."""
    # Work on detached leaf copies so gradients never leak into the inputs.
    leaves = [
        t.clone().detach().requires_grad_(True)
        for t in (input_t, mul_t, weight, bias)
    ]
    score_leaf = None
    if scores is not None:
        score_leaf = scores.clone().detach().requires_grad_(True)

    out = fused_mul_grouped_poly_norm(*leaves,
                                      offsets,
                                      expert_offset=expert_offset,
                                      scores=score_leaf,
                                      hidden_clamp=hidden_clamp)
    out.sum().backward()

    inp, mul, w, b = leaves
    score_grad = score_leaf.grad if score_leaf is not None else None
    return (out, inp.grad, mul.grad, w.grad, b.grad, score_grad)
|
|
|
|
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_forward(
    num_tokens: int,
    d: int,
    num_experts: int,
    dtype: torch.dtype,
    expert_offset: int,
    seed: int,
    device: str,
) -> None:
    """CUDA forward output should match PyTorch reference."""
    torch.set_default_device(device)
    args = _make_inputs(num_tokens,
                        d,
                        num_experts,
                        dtype,
                        device,
                        seed,
                        expert_offset=expert_offset)

    out_ref = fused_mul_grouped_poly_norm_ref(*args,
                                              expert_offset=expert_offset)
    out_tri = fused_mul_grouped_poly_norm(*args, expert_offset=expert_offset)

    # Shape and dtype must agree before comparing values.
    assert out_ref.shape == out_tri.shape == (num_tokens, d)
    assert out_ref.dtype == out_tri.dtype == dtype

    # Reduced-precision dtypes get a looser tolerance.
    tol = 1e-4 if dtype == torch.float32 else 1e-2
    assert_close(out_ref, out_tri, atol=tol, rtol=tol)
|
|
|
|
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_backward(
    num_tokens: int,
    d: int,
    num_experts: int,
    dtype: torch.dtype,
    expert_offset: int,
    seed: int,
    device: str,
) -> None:
    """CUDA backward gradients should match PyTorch reference."""
    torch.set_default_device(device)
    args = _make_inputs(num_tokens,
                        d,
                        num_experts,
                        dtype,
                        device,
                        seed,
                        expert_offset=expert_offset)

    ref = _run_ref(*args, expert_offset=expert_offset)
    tri = _run_cuda(*args, expert_offset=expert_offset)

    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)

    # Result tuples are (out, input, mul, weight, bias, scores); compare the
    # four gradient slots, skipping the output and the unused scores grad.
    for grad_ref, grad_tri in zip(ref[1:5], tri[1:5]):
        assert_close(grad_ref, grad_tri, atol=atol, rtol=rtol)
|
|
|
|
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_zero_token_experts(
    dtype: torch.dtype,
    expert_offset: int,
    device: str,
) -> None:
    """Correctness when some experts receive 0 tokens."""
    torch.set_default_device(device)
    hidden_dim = 256
    num_experts = 8
    counts = [100, 50, 0, 80, 30, 0, 60, 40]  # experts 2 and 5 get no tokens
    total = sum(counts)
    total_experts = expert_offset + num_experts

    torch.manual_seed(42)
    input_t = torch.randn(total, hidden_dim, device=device, dtype=dtype) * 0.5
    mul_t = torch.randn(total, hidden_dim, device=device, dtype=dtype) * 0.5
    weight = torch.ones(total_experts, 3, device=device, dtype=dtype) / 3
    bias = torch.zeros(total_experts, 1, device=device, dtype=dtype)
    offsets = _counts_to_offsets(counts, device)
    args = (input_t, mul_t, weight, bias, offsets)

    out_ref = fused_mul_grouped_poly_norm_ref(*args,
                                              expert_offset=expert_offset)
    out_tri = fused_mul_grouped_poly_norm(*args, expert_offset=expert_offset)

    fwd_tol = 1e-4 if dtype == torch.float32 else 1e-2
    assert_close(out_ref, out_tri, atol=fwd_tol, rtol=fwd_tol)

    # Backward: weight/bias grads must still match with empty experts present.
    _, _, _, w_grad_ref, b_grad_ref, _ = _run_ref(
        *args, expert_offset=expert_offset)
    _, _, _, w_grad_tri, b_grad_tri, _ = _run_cuda(
        *args, expert_offset=expert_offset)

    atol, rtol = (1e-3, 1e-3) if dtype == torch.float32 else (5e-2, 5e-2)
    assert_close(w_grad_ref, w_grad_tri, atol=atol, rtol=rtol)
    assert_close(b_grad_ref, b_grad_tri, atol=atol, rtol=rtol)

    # Experts that saw no tokens must contribute exactly zero grad.
    for eidx in (2, 5):
        wi = eidx + expert_offset
        assert w_grad_tri[wi].abs().max() == 0, (
            f"Expert {eidx} (weight idx {wi}) should have zero weight grad "
            f"but got max={w_grad_tri[wi].abs().max().item()}")
        assert b_grad_tri[wi].abs().max() == 0, (
            f"Expert {eidx} (weight idx {wi}) should have zero bias grad "
            f"but got max={b_grad_tri[wi].abs().max().item()}")
|
|
|
|
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_no_nan_inf(
    dtype: torch.dtype,
    expert_offset: int,
    device: str,
) -> None:
    """Output and gradients should not contain NaN or Inf."""
    torch.set_default_device(device)
    args = _make_inputs(4096, 256, 8, dtype, device,
                        expert_offset=expert_offset)

    out, inp_grad, mul_grad, w_grad, b_grad, _ = _run_cuda(
        *args, expert_offset=expert_offset)

    assert not out.isnan().any(), "Output contains NaN"
    assert not out.isinf().any(), "Output contains Inf"

    named_grads = {
        "input": inp_grad,
        "mul": mul_grad,
        "weight": w_grad,
        "bias": b_grad,
    }
    for name, grad in named_grads.items():
        assert not grad.isnan().any(), f"{name}_grad contains NaN"
        assert not grad.isinf().any(), f"{name}_grad contains Inf"
|
|
|
|
| |
| |
| |
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("num_experts", [8, 48])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_scores_forward(
    num_tokens,
    d,
    num_experts,
    dtype,
    device,
):
    """Forward with scores should match reference."""
    torch.set_default_device(device)
    args = _make_inputs(num_tokens, d, num_experts, dtype, device)
    scores = _make_scores(num_tokens, device)

    ref_out = fused_mul_grouped_poly_norm_ref(*args, scores=scores)
    cuda_out = fused_mul_grouped_poly_norm(*args, scores=scores)

    tol = 1e-4 if dtype == torch.float32 else 1e-2
    assert_close(ref_out, cuda_out, atol=tol, rtol=tol)
|
|
|
|
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("num_experts", [8, 48])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_scores_backward(
    num_tokens,
    d,
    num_experts,
    dtype,
    device,
):
    """Backward with scores should match reference."""
    torch.set_default_device(device)
    args = _make_inputs(num_tokens, d, num_experts, dtype, device)
    scores = _make_scores(num_tokens, device)

    ref = _run_ref(*args, scores=scores)
    tri = _run_cuda(*args, scores=scores)

    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
    # Weight/bias grads reduce over many tokens, so they get their own tol.
    wg_tol = 5e-4 if dtype == torch.float32 else 5e-2

    # Tuples are (out, input, mul, weight, bias, scores).
    assert_close(ref[1], tri[1], atol=atol, rtol=rtol)
    assert_close(ref[2], tri[2], atol=atol, rtol=rtol)
    assert_close(ref[3], tri[3], atol=wg_tol, rtol=wg_tol)
    assert_close(ref[4], tri[4], atol=wg_tol, rtol=wg_tol)
    assert_close(ref[5], tri[5], atol=atol, rtol=rtol)
|
|
|
|
| |
| |
| |
# Hidden-state clamp thresholds exercised by the hidden_clamp tests.
CLAMP_VALUES = [10.0, 1.0, 0.5]
|
|
|
|
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_hidden_clamp_forward(
    num_tokens,
    d,
    num_experts,
    dtype,
    hidden_clamp,
    device,
):
    """Forward with hidden_clamp should match reference."""
    torch.set_default_device(device)
    args = _make_inputs(num_tokens, d, num_experts, dtype, device)
    scores = _make_scores(num_tokens, device)

    ref_out = fused_mul_grouped_poly_norm_ref(*args,
                                              scores=scores,
                                              hidden_clamp=hidden_clamp)
    cuda_out = fused_mul_grouped_poly_norm(*args,
                                           scores=scores,
                                           hidden_clamp=hidden_clamp)

    tol = 1e-4 if dtype == torch.float32 else 1e-2
    assert_close(ref_out, cuda_out, atol=tol, rtol=tol)
|
|
|
|
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_hidden_clamp_backward(
    num_tokens,
    d,
    num_experts,
    dtype,
    hidden_clamp,
    device,
):
    """Backward with hidden_clamp should match reference."""
    torch.set_default_device(device)
    args = _make_inputs(num_tokens, d, num_experts, dtype, device)
    scores = _make_scores(num_tokens, device)

    ref = _run_ref(*args, scores=scores, hidden_clamp=hidden_clamp)
    tri = _run_cuda(*args, scores=scores, hidden_clamp=hidden_clamp)

    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
    # Weight/bias grads reduce over many tokens, so they get their own tol.
    wg_tol = 5e-4 if dtype == torch.float32 else 5e-2

    # Tuples are (out, input, mul, weight, bias, scores).
    assert_close(ref[1], tri[1], atol=atol, rtol=rtol)
    assert_close(ref[2], tri[2], atol=atol, rtol=rtol)
    assert_close(ref[3], tri[3], atol=wg_tol, rtol=wg_tol)
    assert_close(ref[4], tri[4], atol=wg_tol, rtol=wg_tol)
    assert_close(ref[5], tri[5], atol=atol, rtol=rtol)
|
|
|
|
| |
| |
| |
# Numbers of extra padding rows appended after the valid tokens in padded tests.
PADDING_SIZES = [64, 256]
|
|
|
|
def _make_padded_inputs(num_valid_tokens,
                        num_padding,
                        hidden_dim,
                        num_experts,
                        dtype,
                        device,
                        seed=42,
                        expert_offset=0):
    """Create inputs with extra padding rows (large values) beyond valid tokens."""
    torch.manual_seed(seed)

    # Uniformly route only the valid tokens across experts.
    probs = torch.ones(num_experts) / num_experts
    assignments = torch.multinomial(probs, num_valid_tokens, replacement=True)
    counts = torch.bincount(assignments, minlength=num_experts).tolist()

    total_experts = expert_offset + num_experts
    M = num_valid_tokens + num_padding

    input_t = torch.randn(M, hidden_dim, device=device, dtype=dtype) * 0.5
    mul_t = torch.randn(M, hidden_dim, device=device, dtype=dtype) * 0.5
    # Overwrite the tail rows with large-magnitude values so a kernel that
    # wrongly touches padding produces visibly wrong results.
    input_t[num_valid_tokens:] = torch.randn(
        num_padding, hidden_dim, device=device, dtype=dtype) * 5
    mul_t[num_valid_tokens:] = torch.randn(
        num_padding, hidden_dim, device=device, dtype=dtype) * 5

    weight = (torch.ones(total_experts, 3, device=device, dtype=dtype) / 3 +
              torch.randn(total_experts, 3, device=device, dtype=dtype) * 0.01)
    bias = torch.randn(total_experts, 1, device=device, dtype=dtype) * 0.01

    return input_t, mul_t, weight, bias, _counts_to_offsets(counts, device)
|
|
|
|
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_forward(num_tokens, num_padding, d, num_experts, dtype,
                        device):
    """Forward with padded input: valid rows correct, padding rows zero."""
    torch.set_default_device(device)
    args = _make_padded_inputs(num_tokens, num_padding, d, num_experts, dtype,
                               device)
    offsets = args[4]
    num_valid = int(offsets[-1].item())
    M = num_tokens + num_padding

    out_ref = fused_mul_grouped_poly_norm_ref(*args)
    out_cuda = fused_mul_grouped_poly_norm(*args)

    # Both outputs keep the padded row count.
    assert out_cuda.shape == (M, d)
    assert out_ref.shape == (M, d)

    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
    assert_close(out_cuda[:num_valid],
                 out_ref[:num_valid],
                 atol=atol,
                 rtol=rtol)
    assert out_cuda[num_valid:].abs().max() == 0, \
        f"Padding rows not zero: max={out_cuda[num_valid:].abs().max().item()}"
|
|
|
|
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_backward(num_tokens, num_padding, d, num_experts, dtype,
                         device):
    """Backward with padded input: dW/dB not corrupted, padding grads zero."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
        num_tokens, num_padding, d, num_experts, dtype, device)
    num_valid = int(offsets[-1].item())

    # Reference sees only the valid rows; CUDA sees the full padded tensors.
    _, ig_ref, mg_ref, wg_ref, bg_ref, _ = _run_ref(input_t[:num_valid],
                                                    mul_t[:num_valid], weight,
                                                    bias, offsets)
    _, ig_cuda, mg_cuda, wg_cuda, bg_cuda, _ = _run_cuda(
        input_t, mul_t, weight, bias, offsets)

    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
    wg_atol = 5e-4 if dtype == torch.float32 else 5e-2

    # Weight/bias grads must be unaffected by the padding rows.
    assert_close(wg_cuda, wg_ref, atol=wg_atol, rtol=wg_atol)
    assert_close(bg_cuda, bg_ref, atol=wg_atol, rtol=wg_atol)

    # Per-token grads agree on the valid prefix.
    assert_close(ig_cuda[:num_valid], ig_ref, atol=atol, rtol=rtol)
    assert_close(mg_cuda[:num_valid], mg_ref, atol=atol, rtol=rtol)

    # Padding rows must receive exactly zero gradient.
    assert ig_cuda[num_valid:].abs().max() == 0, \
        "Padding rows in input grad should be zero"
    assert mg_cuda[num_valid:].abs().max() == 0, \
        "Padding rows in mul grad should be zero"
|
|
|
|
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_forward_scored(num_tokens, num_padding, d, num_experts, dtype,
                               device):
    """Forward with padded input + scores: valid rows correct, padding zero."""
    torch.set_default_device(device)
    args = _make_padded_inputs(num_tokens, num_padding, d, num_experts, dtype,
                               device)
    num_valid = int(args[4][-1].item())
    M = num_tokens + num_padding
    scores = _make_scores(M, device)

    out_ref = fused_mul_grouped_poly_norm_ref(*args, scores=scores)
    out_cuda = fused_mul_grouped_poly_norm(*args, scores=scores)

    tol = 1e-4 if dtype == torch.float32 else 1e-2
    assert_close(out_cuda[:num_valid],
                 out_ref[:num_valid],
                 atol=tol,
                 rtol=tol)
    assert out_cuda[num_valid:].abs().max() == 0, \
        "Padding rows not zero with scores"
|
|
|
|
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_backward_scored(num_tokens, num_padding, d, num_experts, dtype,
                                device):
    """Backward with padded input + scores: grads correct, padding zero."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
        num_tokens, num_padding, d, num_experts, dtype, device)
    num_valid = int(offsets[-1].item())
    M = num_tokens + num_padding
    scores = _make_scores(M, device)

    # Reference runs on the valid prefix only; CUDA on the padded tensors.
    _, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
        input_t[:num_valid],
        mul_t[:num_valid],
        weight,
        bias,
        offsets,
        scores=scores[:num_valid])
    _, ig_cuda, mg_cuda, wg_cuda, bg_cuda, sg_cuda = _run_cuda(input_t,
                                                               mul_t,
                                                               weight,
                                                               bias,
                                                               offsets,
                                                               scores=scores)

    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
    wg_atol = 5e-4 if dtype == torch.float32 else 5e-2

    assert_close(wg_cuda, wg_ref, atol=wg_atol, rtol=wg_atol)
    assert_close(bg_cuda, bg_ref, atol=wg_atol, rtol=wg_atol)
    assert_close(ig_cuda[:num_valid], ig_ref, atol=atol, rtol=rtol)
    assert_close(mg_cuda[:num_valid], mg_ref, atol=atol, rtol=rtol)
    assert_close(sg_cuda[:num_valid], sg_ref, atol=atol, rtol=rtol)

    # Padding rows must get zero grad in every per-token gradient.
    for padded_grad in (ig_cuda, mg_cuda, sg_cuda):
        assert padded_grad[num_valid:].abs().max() == 0
|
|
|
|
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("hidden_clamp", [10.0, 1.0])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_forward_scored_clamp(num_tokens, num_padding, d, num_experts,
                                     dtype, hidden_clamp, device):
    """Forward with padded input + scores + hidden_clamp."""
    torch.set_default_device(device)
    args = _make_padded_inputs(num_tokens, num_padding, d, num_experts, dtype,
                               device)
    num_valid = int(args[4][-1].item())
    M = num_tokens + num_padding
    scores = _make_scores(M, device)

    out_ref = fused_mul_grouped_poly_norm_ref(*args,
                                              scores=scores,
                                              hidden_clamp=hidden_clamp)
    out_cuda = fused_mul_grouped_poly_norm(*args,
                                           scores=scores,
                                           hidden_clamp=hidden_clamp)

    tol = 1e-4 if dtype == torch.float32 else 1e-2
    assert_close(out_cuda[:num_valid],
                 out_ref[:num_valid],
                 atol=tol,
                 rtol=tol)
    assert out_cuda[num_valid:].abs().max() == 0, \
        "Padding rows not zero with scores+clamp"
|
|
|
|
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("hidden_clamp", [10.0, 1.0])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_backward_scored_clamp(num_tokens, num_padding, d, num_experts,
                                      dtype, hidden_clamp, device):
    """Backward with padded input + scores + hidden_clamp: all grads correct."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
        num_tokens, num_padding, d, num_experts, dtype, device)
    num_valid = int(offsets[-1].item())
    M = num_tokens + num_padding
    scores = _make_scores(M, device)

    # Reference runs on the valid prefix only; CUDA on the padded tensors.
    _, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
        input_t[:num_valid],
        mul_t[:num_valid],
        weight,
        bias,
        offsets,
        scores=scores[:num_valid],
        hidden_clamp=hidden_clamp)
    _, ig_cuda, mg_cuda, wg_cuda, bg_cuda, sg_cuda = _run_cuda(
        input_t,
        mul_t,
        weight,
        bias,
        offsets,
        scores=scores,
        hidden_clamp=hidden_clamp)

    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
    wg_atol = 5e-4 if dtype == torch.float32 else 5e-2

    assert_close(wg_cuda, wg_ref, atol=wg_atol, rtol=wg_atol)
    assert_close(bg_cuda, bg_ref, atol=wg_atol, rtol=wg_atol)
    assert_close(ig_cuda[:num_valid], ig_ref, atol=atol, rtol=rtol)
    assert_close(mg_cuda[:num_valid], mg_ref, atol=atol, rtol=rtol)
    assert_close(sg_cuda[:num_valid], sg_ref, atol=atol, rtol=rtol)

    # Padding rows must get zero grad in every per-token gradient.
    for padded_grad in (ig_cuda, mg_cuda, sg_cuda):
        assert padded_grad[num_valid:].abs().max() == 0
|
|