"""Tests for the fused multiply + grouped PolyNorm op.

Compares the CUDA kernel (``fused_mul_grouped_poly_norm``) against the pure
PyTorch reference (``fused_mul_grouped_poly_norm_ref``) for forward and
backward, across dtypes, expert counts, expert offsets, optional per-token
scores, hidden-value clamping, and trailing padding rows.
"""
import pytest
import torch

from grouped_poly_norm import _has_cuda_ops, fused_mul_grouped_poly_norm_ref

# The CUDA entry point only exists when the extension was built with CUDA ops.
if _has_cuda_ops:
    from grouped_poly_norm import fused_mul_grouped_poly_norm

from .utils import assert_close

# Parametrization grids shared by the main forward/backward tests.
DTYPES = [torch.float, torch.bfloat16, torch.float16]
NUM_TOKENS = [4096, 8192]
D = [256, 1280]
NUM_EXPERTS_LIST = [8, 384]
EXPERT_OFFSETS = [0, 4]
SEEDS = [0]
# Only test on cuda:0 to avoid cross-device issues.
CUDA_DEVICES = ["cuda:0"]


def _counts_to_offsets(counts_list, device):
    """Convert list of counts to cumsum offsets tensor."""
    # int32 both before and after cumsum: cumsum may promote, so cast back.
    return torch.cumsum(torch.tensor(counts_list, device=device,
                                     dtype=torch.int32),
                        dim=0).to(torch.int32)


def _make_inputs(total_tokens, hidden_dim, num_experts, dtype, device,
                 seed=42, expert_offset=0):
    """Create deterministic test inputs with random token distribution.

    Returns (input_t, mul_t, weight, bias, offsets) where input_t/mul_t are
    (total_tokens, hidden_dim), weight is (expert_offset + num_experts, 3),
    bias is (expert_offset + num_experts, 1), and offsets is the int32 cumsum
    of per-expert token counts.
    """
    torch.manual_seed(seed)
    # Assign each token to an expert uniformly at random.
    probs = torch.ones(num_experts) / num_experts
    assignments = torch.multinomial(probs, total_tokens, replacement=True)
    counts = torch.bincount(assignments, minlength=num_experts).tolist()
    # Weight/bias must have expert_offset + num_experts rows
    total_experts = expert_offset + num_experts
    # Scale inputs to avoid overflow in bf16 (x^3 can overflow for |x| > 40)
    input_t = torch.randn(total_tokens, hidden_dim, device=device,
                          dtype=dtype) * 0.5
    mul_t = torch.randn(total_tokens, hidden_dim, device=device,
                        dtype=dtype) * 0.5
    # Weights near 1/3 with small jitter; bias near zero.
    weight = (torch.ones(total_experts, 3, device=device, dtype=dtype) / 3 +
              torch.randn(total_experts, 3, device=device, dtype=dtype) * 0.01)
    bias = torch.randn(total_experts, 1, device=device, dtype=dtype) * 0.01
    offsets = _counts_to_offsets(counts, device)
    return input_t, mul_t, weight, bias, offsets


def _make_scores(total_tokens, device, dtype=torch.float32):
    """Create random scores (N, 1) in fp32."""
    # Uniform in [0.5, 1.0) so scores are bounded away from zero.
    return torch.rand(total_tokens, 1, device=device, dtype=dtype) * 0.5 + 0.5


def _run_ref(input_t, mul_t, weight, bias, offsets, expert_offset=0,
             scores=None, hidden_clamp=None):
    """Run reference forward + backward, return output and grads.

    Returns a 6-tuple (out, input_grad, mul_grad, weight_grad, bias_grad,
    scores_grad); scores_grad is None when ``scores`` is None.
    """
    # Clone+detach so each run owns fresh leaf tensors with grads.
    inp = input_t.clone().detach().requires_grad_(True)
    m = mul_t.clone().detach().requires_grad_(True)
    w = weight.clone().detach().requires_grad_(True)
    b = bias.clone().detach().requires_grad_(True)
    s = scores.clone().detach().requires_grad_(
        True) if scores is not None else None
    out = fused_mul_grouped_poly_norm_ref(inp, m, w, b, offsets,
                                          expert_offset=expert_offset,
                                          scores=s,
                                          hidden_clamp=hidden_clamp)
    # sum() gives an all-ones upstream gradient.
    out.sum().backward()
    grads = (out, inp.grad, m.grad, w.grad, b.grad)
    return grads + (s.grad, ) if s is not None else grads + (None, )


def _run_cuda(input_t, mul_t, weight, bias, offsets, expert_offset=0,
              scores=None, hidden_clamp=None):
    """Run CUDA forward + backward, return output and grads."""
    # Mirrors _run_ref exactly, but through the CUDA kernel.
    inp = input_t.clone().detach().requires_grad_(True)
    m = mul_t.clone().detach().requires_grad_(True)
    w = weight.clone().detach().requires_grad_(True)
    b = bias.clone().detach().requires_grad_(True)
    s = scores.clone().detach().requires_grad_(
        True) if scores is not None else None
    out = fused_mul_grouped_poly_norm(inp, m, w, b, offsets,
                                      expert_offset=expert_offset,
                                      scores=s,
                                      hidden_clamp=hidden_clamp)
    out.sum().backward()
    grads = (out, inp.grad, m.grad, w.grad, b.grad)
    return grads + (s.grad, ) if s is not None else grads + (None, )


@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_forward(
    num_tokens: int,
    d: int,
    num_experts: int,
    dtype: torch.dtype,
    expert_offset: int,
    seed: int,
    device: str,
) -> None:
    """CUDA forward output should match PyTorch reference."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        num_tokens, d, num_experts, dtype, device, seed,
        expert_offset=expert_offset)
    out_ref = fused_mul_grouped_poly_norm_ref(input_t, mul_t, weight, bias,
                                              offsets,
                                              expert_offset=expert_offset)
    out_tri = fused_mul_grouped_poly_norm(input_t, mul_t, weight, bias,
                                          offsets,
                                          expert_offset=expert_offset)
    assert out_ref.shape == out_tri.shape == (num_tokens, d)
    assert out_ref.dtype == out_tri.dtype == dtype
    # Looser tolerance for half-precision dtypes.
    if dtype == torch.float32:
        assert_close(out_ref, out_tri, atol=1e-4, rtol=1e-4)
    else:
        assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2)


@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_backward(
    num_tokens: int,
    d: int,
    num_experts: int,
    dtype: torch.dtype,
    expert_offset: int,
    seed: int,
    device: str,
) -> None:
    """CUDA backward gradients should match PyTorch reference."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        num_tokens, d, num_experts, dtype, device, seed,
        expert_offset=expert_offset)
    _, inp_grad_ref, mul_grad_ref, w_grad_ref, b_grad_ref, _ = _run_ref(
        input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
    _, inp_grad_tri, mul_grad_tri, w_grad_tri, b_grad_tri, _ = _run_cuda(
        input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
    if dtype == torch.float32:
        atol, rtol = 1e-4, 1e-4
    else:
        atol, rtol = 5e-2, 5e-2
    assert_close(inp_grad_ref, inp_grad_tri, atol=atol, rtol=rtol)
    assert_close(mul_grad_ref, mul_grad_tri, atol=atol, rtol=rtol)
    assert_close(w_grad_ref, w_grad_tri, atol=atol, rtol=rtol)
    assert_close(b_grad_ref, b_grad_tri, atol=atol, rtol=rtol)


@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_zero_token_experts(
    dtype: torch.dtype,
    expert_offset: int,
    device: str,
) -> None:
    """Correctness when some experts receive 0 tokens."""
    torch.set_default_device(device)
    # Experts 2 and 5 receive no tokens at all.
    counts = [100, 50, 0, 80, 30, 0, 60, 40]
    total = sum(counts)
    num_experts = 8
    total_experts = expert_offset + num_experts
    hidden_dim = 256
    torch.manual_seed(42)
    input_t = torch.randn(total, hidden_dim, device=device, dtype=dtype) * 0.5
    mul_t = torch.randn(total, hidden_dim, device=device, dtype=dtype) * 0.5
    weight = torch.ones(total_experts, 3, device=device, dtype=dtype) / 3
    bias = torch.zeros(total_experts, 1, device=device, dtype=dtype)
    offsets = _counts_to_offsets(counts, device)
    out_ref = fused_mul_grouped_poly_norm_ref(input_t, mul_t, weight, bias,
                                              offsets,
                                              expert_offset=expert_offset)
    out_tri = fused_mul_grouped_poly_norm(input_t, mul_t, weight, bias,
                                          offsets,
                                          expert_offset=expert_offset)
    if dtype == torch.float32:
        assert_close(out_ref, out_tri, atol=1e-4, rtol=1e-4)
    else:
        assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2)
    # Check backward with zero-token experts
    _, _, _, w_grad_ref, b_grad_ref, _ = _run_ref(
        input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
    _, _, _, w_grad_tri, b_grad_tri, _ = _run_cuda(
        input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
    if dtype == torch.float32:
        atol, rtol = 1e-3, 1e-3
    else:
        atol, rtol = 5e-2, 5e-2
    assert_close(w_grad_ref, w_grad_tri, atol=atol, rtol=rtol)
    assert_close(b_grad_ref, b_grad_tri, atol=atol, rtol=rtol)
    # Verify zero-token experts have zero weight/bias gradients
    for eidx in [2, 5]:
        # Row index in weight/bias is shifted by expert_offset.
        wi = eidx + expert_offset
        assert w_grad_tri[wi].abs().max() == 0, (
            f"Expert {eidx} (weight idx {wi}) should have zero weight grad "
            f"but got max={w_grad_tri[wi].abs().max().item()}")
        assert b_grad_tri[wi].abs().max() == 0, (
            f"Expert {eidx} (weight idx {wi}) should have zero bias grad "
            f"but got max={b_grad_tri[wi].abs().max().item()}")


@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_no_nan_inf(
    dtype: torch.dtype,
    expert_offset: int,
    device: str,
) -> None:
    """Output and gradients should not contain NaN or Inf."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        4096, 256, 8, dtype, device, expert_offset=expert_offset)
    out, inp_grad, mul_grad, w_grad, b_grad, _ = _run_cuda(
        input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
    assert not out.isnan().any(), "Output contains NaN"
    assert not out.isinf().any(), "Output contains Inf"
    for name, grad in [("input", inp_grad), ("mul", mul_grad),
                       ("weight", w_grad), ("bias", b_grad)]:
        assert not grad.isnan().any(), f"{name}_grad contains NaN"
        assert not grad.isinf().any(), f"{name}_grad contains Inf"


# ---------------------------------------------------------------------------
# Scores tests
# ---------------------------------------------------------------------------


@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("num_experts", [8, 48])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_scores_forward(
    num_tokens,
    d,
    num_experts,
    dtype,
    device,
):
    """Forward with scores should match reference."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        num_tokens, d, num_experts, dtype, device)
    scores = _make_scores(num_tokens, device)
    out_ref = fused_mul_grouped_poly_norm_ref(input_t, mul_t, weight, bias,
                                              offsets, scores=scores)
    out_tri = fused_mul_grouped_poly_norm(input_t, mul_t, weight, bias,
                                          offsets, scores=scores)
    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
    assert_close(out_ref, out_tri, atol=atol, rtol=rtol)


@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("num_experts", [8, 48])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_scores_backward(
    num_tokens,
    d,
    num_experts,
    dtype,
    device,
):
    """Backward with scores should match reference."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        num_tokens, d, num_experts, dtype, device)
    scores = _make_scores(num_tokens, device)
    out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
        input_t, mul_t, weight, bias, offsets, scores=scores)
    out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_cuda(
        input_t, mul_t, weight, bias, offsets, scores=scores)
    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
    # weight/bias grads use atomicAdd accumulation across tokens,
    # so allow slightly higher tolerance for fp32
    wg_atol = 5e-4 if dtype == torch.float32 else 5e-2
    assert_close(ig_ref, ig_tri, atol=atol, rtol=rtol)
    assert_close(mg_ref, mg_tri, atol=atol, rtol=rtol)
    assert_close(wg_ref, wg_tri, atol=wg_atol, rtol=wg_atol)
    assert_close(bg_ref, bg_tri, atol=wg_atol, rtol=wg_atol)
    assert_close(sg_ref, sg_tri, atol=atol, rtol=rtol)


# ---------------------------------------------------------------------------
# Hidden clamp tests
# ---------------------------------------------------------------------------

CLAMP_VALUES = [10.0, 1.0, 0.5]


@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_hidden_clamp_forward(
    num_tokens,
    d,
    num_experts,
    dtype,
    hidden_clamp,
    device,
):
    """Forward with hidden_clamp should match reference."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        num_tokens, d, num_experts, dtype, device)
    scores = _make_scores(num_tokens, device)
    out_ref = fused_mul_grouped_poly_norm_ref(input_t, mul_t, weight, bias,
                                              offsets, scores=scores,
                                              hidden_clamp=hidden_clamp)
    out_tri = fused_mul_grouped_poly_norm(input_t, mul_t, weight, bias,
                                          offsets, scores=scores,
                                          hidden_clamp=hidden_clamp)
    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
    assert_close(out_ref, out_tri, atol=atol, rtol=rtol)


@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_hidden_clamp_backward(
    num_tokens,
    d,
    num_experts,
    dtype,
    hidden_clamp,
    device,
):
    """Backward with hidden_clamp should match reference."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        num_tokens, d, num_experts, dtype, device)
    scores = _make_scores(num_tokens, device)
    out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
        input_t, mul_t, weight, bias, offsets, scores=scores,
        hidden_clamp=hidden_clamp)
    out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_cuda(
        input_t, mul_t, weight, bias, offsets, scores=scores,
        hidden_clamp=hidden_clamp)
    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
    # weight/bias grads use atomicAdd accumulation across tokens,
    # so allow slightly higher tolerance for fp32
    wg_atol = 5e-4 if dtype == torch.float32 else 5e-2
    assert_close(ig_ref, ig_tri, atol=atol, rtol=rtol)
    assert_close(mg_ref, mg_tri, atol=atol, rtol=rtol)
    assert_close(wg_ref, wg_tri, atol=wg_atol, rtol=wg_atol)
    assert_close(bg_ref, bg_tri, atol=wg_atol, rtol=wg_atol)
    assert_close(sg_ref, sg_tri, atol=atol, rtol=rtol)


# ---------------------------------------------------------------------------
# Padding-aware tests
# ---------------------------------------------------------------------------

PADDING_SIZES = [64, 256]


def _make_padded_inputs(num_valid_tokens, num_padding, hidden_dim,
                        num_experts, dtype, device, seed=42, expert_offset=0):
    """Create inputs with extra padding rows (large values) beyond valid tokens.

    Rows [0, num_valid_tokens) are normal inputs; rows beyond that are
    large-magnitude "poison" values so any kernel that touches them shows up
    as a mismatch. ``offsets`` only covers the valid rows.
    """
    torch.manual_seed(seed)
    probs = torch.ones(num_experts) / num_experts
    assignments = torch.multinomial(probs, num_valid_tokens, replacement=True)
    counts = torch.bincount(assignments, minlength=num_experts).tolist()
    total_experts = expert_offset + num_experts
    M = num_valid_tokens + num_padding
    input_t = torch.randn(M, hidden_dim, device=device, dtype=dtype) * 0.5
    mul_t = torch.randn(M, hidden_dim, device=device, dtype=dtype) * 0.5
    # Padding rows get large values to catch contamination
    input_t[num_valid_tokens:] = torch.randn(
        num_padding, hidden_dim, device=device, dtype=dtype) * 5
    mul_t[num_valid_tokens:] = torch.randn(
        num_padding, hidden_dim, device=device, dtype=dtype) * 5
    weight = (torch.ones(total_experts, 3, device=device, dtype=dtype) / 3 +
              torch.randn(total_experts, 3, device=device, dtype=dtype) * 0.01)
    bias = torch.randn(total_experts, 1, device=device, dtype=dtype) * 0.01
    offsets = _counts_to_offsets(counts, device)
    return input_t, mul_t, weight, bias, offsets


@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_forward(num_tokens, num_padding, d, num_experts, dtype,
                        device):
    """Forward with padded input: valid rows correct, padding rows zero."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
        num_tokens, num_padding, d, num_experts, dtype, device)
    # Last offset is the total valid-token count.
    num_valid = int(offsets[-1].item())
    M = num_tokens + num_padding
    out_ref = fused_mul_grouped_poly_norm_ref(input_t, mul_t, weight, bias,
                                              offsets)
    out_cuda = fused_mul_grouped_poly_norm(input_t, mul_t, weight, bias,
                                           offsets)
    assert out_cuda.shape == (M, d)
    assert out_ref.shape == (M, d)
    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
    assert_close(out_cuda[:num_valid], out_ref[:num_valid], atol=atol,
                 rtol=rtol)
    assert out_cuda[num_valid:].abs().max() == 0, \
        f"Padding rows not zero: max={out_cuda[num_valid:].abs().max().item()}"


@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_backward(num_tokens, num_padding, d, num_experts, dtype,
                         device):
    """Backward with padded input: dW/dB not corrupted, padding grads zero."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
        num_tokens, num_padding, d, num_experts, dtype, device)
    num_valid = int(offsets[-1].item())
    # Reference backward on valid-only rows
    _, ig_ref, mg_ref, wg_ref, bg_ref, _ = _run_ref(input_t[:num_valid],
                                                    mul_t[:num_valid],
                                                    weight, bias, offsets)
    # CUDA backward on full padded input
    _, ig_cuda, mg_cuda, wg_cuda, bg_cuda, _ = _run_cuda(
        input_t, mul_t, weight, bias, offsets)
    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
    wg_atol = 5e-4 if dtype == torch.float32 else 5e-2
    # Weight/bias grads must match (not corrupted by padding)
    assert_close(wg_cuda, wg_ref, atol=wg_atol, rtol=wg_atol)
    assert_close(bg_cuda, bg_ref, atol=wg_atol, rtol=wg_atol)
    # Valid row grads must match
    assert_close(ig_cuda[:num_valid], ig_ref, atol=atol, rtol=rtol)
    assert_close(mg_cuda[:num_valid], mg_ref, atol=atol, rtol=rtol)
    # Padding row grads must be zero
    assert ig_cuda[num_valid:].abs().max() == 0, \
        "Padding rows in input grad should be zero"
    assert mg_cuda[num_valid:].abs().max() == 0, \
        "Padding rows in mul grad should be zero"


@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_forward_scored(num_tokens, num_padding, d, num_experts, dtype,
                               device):
    """Forward with padded input + scores: valid rows correct, padding zero."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
        num_tokens, num_padding, d, num_experts, dtype, device)
    num_valid = int(offsets[-1].item())
    M = num_tokens + num_padding
    # Scores cover padding rows too; the kernel must still ignore them.
    scores = _make_scores(M, device)
    out_ref = fused_mul_grouped_poly_norm_ref(input_t, mul_t, weight, bias,
                                              offsets, scores=scores)
    out_cuda = fused_mul_grouped_poly_norm(input_t, mul_t, weight, bias,
                                           offsets, scores=scores)
    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
    assert_close(out_cuda[:num_valid], out_ref[:num_valid], atol=atol,
                 rtol=rtol)
    assert out_cuda[num_valid:].abs().max() == 0, \
        "Padding rows not zero with scores"


@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_backward_scored(num_tokens, num_padding, d, num_experts,
                                dtype, device):
    """Backward with padded input + scores: grads correct, padding zero."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
        num_tokens, num_padding, d, num_experts, dtype, device)
    num_valid = int(offsets[-1].item())
    M = num_tokens + num_padding
    scores = _make_scores(M, device)
    # Reference on valid-only
    _, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
        input_t[:num_valid], mul_t[:num_valid], weight, bias, offsets,
        scores=scores[:num_valid])
    # CUDA on full padded
    _, ig_cuda, mg_cuda, wg_cuda, bg_cuda, sg_cuda = _run_cuda(input_t, mul_t,
                                                               weight, bias,
                                                               offsets,
                                                               scores=scores)
    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
    wg_atol = 5e-4 if dtype == torch.float32 else 5e-2
    assert_close(wg_cuda, wg_ref, atol=wg_atol, rtol=wg_atol)
    assert_close(bg_cuda, bg_ref, atol=wg_atol, rtol=wg_atol)
    assert_close(ig_cuda[:num_valid], ig_ref, atol=atol, rtol=rtol)
    assert_close(mg_cuda[:num_valid], mg_ref, atol=atol, rtol=rtol)
    assert_close(sg_cuda[:num_valid], sg_ref, atol=atol, rtol=rtol)
    # Padding rows must receive zero gradient everywhere.
    assert ig_cuda[num_valid:].abs().max() == 0
    assert mg_cuda[num_valid:].abs().max() == 0
    assert sg_cuda[num_valid:].abs().max() == 0


@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("hidden_clamp", [10.0, 1.0])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_forward_scored_clamp(num_tokens, num_padding, d, num_experts,
                                     dtype, hidden_clamp, device):
    """Forward with padded input + scores + hidden_clamp."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
        num_tokens, num_padding, d, num_experts, dtype, device)
    num_valid = int(offsets[-1].item())
    M = num_tokens + num_padding
    scores = _make_scores(M, device)
    out_ref = fused_mul_grouped_poly_norm_ref(input_t, mul_t, weight, bias,
                                              offsets, scores=scores,
                                              hidden_clamp=hidden_clamp)
    out_cuda = fused_mul_grouped_poly_norm(input_t, mul_t, weight, bias,
                                           offsets, scores=scores,
                                           hidden_clamp=hidden_clamp)
    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
    assert_close(out_cuda[:num_valid], out_ref[:num_valid], atol=atol,
                 rtol=rtol)
    assert out_cuda[num_valid:].abs().max() == 0, \
        "Padding rows not zero with scores+clamp"


@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("hidden_clamp", [10.0, 1.0])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_backward_scored_clamp(num_tokens, num_padding, d, num_experts,
                                      dtype, hidden_clamp, device):
    """Backward with padded input + scores + hidden_clamp: all grads correct."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
        num_tokens, num_padding, d, num_experts, dtype, device)
    num_valid = int(offsets[-1].item())
    M = num_tokens + num_padding
    scores = _make_scores(M, device)
    # Reference on valid-only
    _, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
        input_t[:num_valid], mul_t[:num_valid], weight, bias, offsets,
        scores=scores[:num_valid], hidden_clamp=hidden_clamp)
    # CUDA on full padded
    _, ig_cuda, mg_cuda, wg_cuda, bg_cuda, sg_cuda = _run_cuda(
        input_t, mul_t, weight, bias, offsets, scores=scores,
        hidden_clamp=hidden_clamp)
    atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
    wg_atol = 5e-4 if dtype == torch.float32 else 5e-2
    assert_close(wg_cuda, wg_ref, atol=wg_atol, rtol=wg_atol)
    assert_close(bg_cuda, bg_ref, atol=wg_atol, rtol=wg_atol)
    assert_close(ig_cuda[:num_valid], ig_ref, atol=atol, rtol=rtol)
    assert_close(mg_cuda[:num_valid], mg_ref, atol=atol, rtol=rtol)
    assert_close(sg_cuda[:num_valid], sg_ref, atol=atol, rtol=rtol)
    # Padding rows must receive zero gradient everywhere.
    assert ig_cuda[num_valid:].abs().max() == 0
    assert mg_cuda[num_valid:].abs().max() == 0
    assert sg_cuda[num_valid:].abs().max() == 0