# Tests for the fused_mul_grouped_poly_norm CUDA kernel.
# Source path: activation/tests/test_fused_mul_grouped_poly_norm.py
import pytest
import torch
from grouped_poly_norm import _has_cuda_ops, fused_mul_grouped_poly_norm_ref
if _has_cuda_ops:
from grouped_poly_norm import fused_mul_grouped_poly_norm
from .utils import assert_close
# Dtypes exercised by the parametrized tests (fp32 plus both half precisions).
DTYPES = [torch.float, torch.bfloat16, torch.float16]
# Total token (row) counts for the main correctness tests.
NUM_TOKENS = [4096, 8192]
# Hidden dimensions.
D = [256, 1280]
# Expert counts, including a large configuration (384).
NUM_EXPERTS_LIST = [8, 384]
# Offsets into the weight/bias tables, to exercise non-zero expert_offset.
EXPERT_OFFSETS = [0, 4]
# RNG seeds for input generation.
SEEDS = [0]
# Only test on cuda:0 to avoid cross-device issues.
CUDA_DEVICES = ["cuda:0"]
def _counts_to_offsets(counts_list, device):
"""Convert list of counts to cumsum offsets tensor."""
return torch.cumsum(torch.tensor(counts_list,
device=device,
dtype=torch.int32),
dim=0).to(torch.int32)
def _make_inputs(total_tokens,
                 hidden_dim,
                 num_experts,
                 dtype,
                 device,
                 seed=42,
                 expert_offset=0):
    """Build deterministic test inputs with a random per-expert token split.

    Returns (input, mul, weight, bias, offsets).  weight/bias carry
    ``expert_offset + num_experts`` rows so non-zero offsets can be tested.
    """
    torch.manual_seed(seed)
    # Assign each token to an expert uniformly at random, then count.
    uniform = torch.ones(num_experts) / num_experts
    token_experts = torch.multinomial(uniform, total_tokens, replacement=True)
    counts = torch.bincount(token_experts, minlength=num_experts).tolist()
    total_experts = expert_offset + num_experts
    # Scale inputs down so x^3 stays in range for bf16 (|x| > 40 can overflow).
    input_t = 0.5 * torch.randn(
        total_tokens, hidden_dim, device=device, dtype=dtype)
    mul_t = 0.5 * torch.randn(
        total_tokens, hidden_dim, device=device, dtype=dtype)
    weight = torch.ones(total_experts, 3, device=device, dtype=dtype) / 3
    weight = weight + 0.01 * torch.randn(
        total_experts, 3, device=device, dtype=dtype)
    bias = 0.01 * torch.randn(total_experts, 1, device=device, dtype=dtype)
    return input_t, mul_t, weight, bias, _counts_to_offsets(counts, device)
def _make_scores(total_tokens, device, dtype=torch.float32):
"""Create random scores (N, 1) in fp32."""
return torch.rand(total_tokens, 1, device=device, dtype=dtype) * 0.5 + 0.5
def _run_ref(input_t,
             mul_t,
             weight,
             bias,
             offsets,
             expert_offset=0,
             scores=None,
             hidden_clamp=None):
    """Run reference forward + backward; return (out, dinput, dmul, dw, db, dscores).

    dscores is None when ``scores`` is not supplied.
    """
    # Fresh leaf tensors so gradients from this run don't leak elsewhere.
    inp = input_t.clone().detach().requires_grad_(True)
    m = mul_t.clone().detach().requires_grad_(True)
    w = weight.clone().detach().requires_grad_(True)
    b = bias.clone().detach().requires_grad_(True)
    if scores is None:
        s = None
    else:
        s = scores.clone().detach().requires_grad_(True)
    out = fused_mul_grouped_poly_norm_ref(inp,
                                          m,
                                          w,
                                          b,
                                          offsets,
                                          expert_offset=expert_offset,
                                          scores=s,
                                          hidden_clamp=hidden_clamp)
    out.sum().backward()
    s_grad = s.grad if s is not None else None
    return out, inp.grad, m.grad, w.grad, b.grad, s_grad
def _run_cuda(input_t,
              mul_t,
              weight,
              bias,
              offsets,
              expert_offset=0,
              scores=None,
              hidden_clamp=None):
    """Run CUDA forward + backward; return (out, dinput, dmul, dw, db, dscores).

    dscores is None when ``scores`` is not supplied.
    """
    # Fresh leaf tensors so gradients from this run don't leak elsewhere.
    inp = input_t.clone().detach().requires_grad_(True)
    m = mul_t.clone().detach().requires_grad_(True)
    w = weight.clone().detach().requires_grad_(True)
    b = bias.clone().detach().requires_grad_(True)
    if scores is None:
        s = None
    else:
        s = scores.clone().detach().requires_grad_(True)
    out = fused_mul_grouped_poly_norm(inp,
                                      m,
                                      w,
                                      b,
                                      offsets,
                                      expert_offset=expert_offset,
                                      scores=s,
                                      hidden_clamp=hidden_clamp)
    out.sum().backward()
    s_grad = s.grad if s is not None else None
    return out, inp.grad, m.grad, w.grad, b.grad, s_grad
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_forward(
    num_tokens: int,
    d: int,
    num_experts: int,
    dtype: torch.dtype,
    expert_offset: int,
    seed: int,
    device: str,
) -> None:
    """CUDA forward output should match PyTorch reference."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        num_tokens, d, num_experts, dtype, device, seed,
        expert_offset=expert_offset)
    ref_out = fused_mul_grouped_poly_norm_ref(
        input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
    cuda_out = fused_mul_grouped_poly_norm(
        input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
    assert ref_out.shape == cuda_out.shape == (num_tokens, d)
    assert ref_out.dtype == cuda_out.dtype == dtype
    # Half-precision dtypes get a looser tolerance.
    if dtype == torch.float32:
        atol = rtol = 1e-4
    else:
        atol = rtol = 1e-2
    assert_close(ref_out, cuda_out, atol=atol, rtol=rtol)
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_backward(
    num_tokens: int,
    d: int,
    num_experts: int,
    dtype: torch.dtype,
    expert_offset: int,
    seed: int,
    device: str,
) -> None:
    """CUDA backward gradients should match PyTorch reference."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        num_tokens, d, num_experts, dtype, device, seed,
        expert_offset=expert_offset)
    ref = _run_ref(input_t, mul_t, weight, bias, offsets,
                   expert_offset=expert_offset)
    tri = _run_cuda(input_t, mul_t, weight, bias, offsets,
                    expert_offset=expert_offset)
    if dtype == torch.float32:
        atol = rtol = 1e-4
    else:
        atol = rtol = 5e-2
    # Compare input/mul/weight/bias gradients (positions 1..4 of the result).
    for grad_ref, grad_tri in zip(ref[1:5], tri[1:5]):
        assert_close(grad_ref, grad_tri, atol=atol, rtol=rtol)
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_zero_token_experts(
    dtype: torch.dtype,
    expert_offset: int,
    device: str,
) -> None:
    """Correctness when some experts receive 0 tokens.

    Experts 2 and 5 are given zero tokens; forward/backward must still
    match the reference and the zero-token experts must receive exactly
    zero weight/bias gradients.
    """
    torch.set_default_device(device)
    # Experts 2 and 5 intentionally receive zero tokens.
    counts = [100, 50, 0, 80, 30, 0, 60, 40]
    total = sum(counts)
    num_experts = 8
    total_experts = expert_offset + num_experts
    hidden_dim = 256
    torch.manual_seed(42)
    # Scale by 0.5 to keep values in range for half-precision dtypes.
    input_t = torch.randn(total, hidden_dim, device=device, dtype=dtype) * 0.5
    mul_t = torch.randn(total, hidden_dim, device=device, dtype=dtype) * 0.5
    weight = torch.ones(total_experts, 3, device=device, dtype=dtype) / 3
    bias = torch.zeros(total_experts, 1, device=device, dtype=dtype)
    offsets = _counts_to_offsets(counts, device)
    out_ref = fused_mul_grouped_poly_norm_ref(input_t,
                                              mul_t,
                                              weight,
                                              bias,
                                              offsets,
                                              expert_offset=expert_offset)
    out_tri = fused_mul_grouped_poly_norm(input_t,
                                          mul_t,
                                          weight,
                                          bias,
                                          offsets,
                                          expert_offset=expert_offset)
    if dtype == torch.float32:
        assert_close(out_ref, out_tri, atol=1e-4, rtol=1e-4)
    else:
        assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2)
    # Check backward with zero-token experts
    _, _, _, w_grad_ref, b_grad_ref, _ = _run_ref(input_t,
                                                  mul_t,
                                                  weight,
                                                  bias,
                                                  offsets,
                                                  expert_offset=expert_offset)
    _, _, _, w_grad_tri, b_grad_tri, _ = _run_cuda(input_t,
                                                   mul_t,
                                                   weight,
                                                   bias,
                                                   offsets,
                                                   expert_offset=expert_offset)
    if dtype == torch.float32:
        atol, rtol = 1e-3, 1e-3
    else:
        atol, rtol = 5e-2, 5e-2
    assert_close(w_grad_ref, w_grad_tri, atol=atol, rtol=rtol)
    assert_close(b_grad_ref, b_grad_tri, atol=atol, rtol=rtol)
    # Verify zero-token experts have zero weight/bias gradients
    for eidx in [2, 5]:
        # Rows in weight/bias are shifted by expert_offset.
        wi = eidx + expert_offset
        assert w_grad_tri[wi].abs().max() == 0, (
            f"Expert {eidx} (weight idx {wi}) should have zero weight grad "
            f"but got max={w_grad_tri[wi].abs().max().item()}")
        assert b_grad_tri[wi].abs().max() == 0, (
            f"Expert {eidx} (weight idx {wi}) should have zero bias grad "
            f"but got max={b_grad_tri[wi].abs().max().item()}")
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_no_nan_inf(
    dtype: torch.dtype,
    expert_offset: int,
    device: str,
) -> None:
    """Output and gradients should not contain NaN or Inf."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        4096, 256, 8, dtype, device, expert_offset=expert_offset)
    out, inp_grad, mul_grad, w_grad, b_grad, _ = _run_cuda(
        input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
    assert not out.isnan().any(), "Output contains NaN"
    assert not out.isinf().any(), "Output contains Inf"
    named_grads = {
        "input": inp_grad,
        "mul": mul_grad,
        "weight": w_grad,
        "bias": b_grad,
    }
    for name, grad in named_grads.items():
        assert not grad.isnan().any(), f"{name}_grad contains NaN"
        assert not grad.isinf().any(), f"{name}_grad contains Inf"
# ---------------------------------------------------------------------------
# Scores tests
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("num_experts", [8, 48])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_scores_forward(
    num_tokens,
    d,
    num_experts,
    dtype,
    device,
):
    """Forward with scores should match reference."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        num_tokens, d, num_experts, dtype, device)
    scores = _make_scores(num_tokens, device)
    ref_out = fused_mul_grouped_poly_norm_ref(
        input_t, mul_t, weight, bias, offsets, scores=scores)
    cuda_out = fused_mul_grouped_poly_norm(
        input_t, mul_t, weight, bias, offsets, scores=scores)
    if dtype == torch.float32:
        atol = rtol = 1e-4
    else:
        atol = rtol = 1e-2
    assert_close(ref_out, cuda_out, atol=atol, rtol=rtol)
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("num_experts", [8, 48])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_scores_backward(
    num_tokens,
    d,
    num_experts,
    dtype,
    device,
):
    """Backward with scores should match reference."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        num_tokens, d, num_experts, dtype, device)
    scores = _make_scores(num_tokens, device)
    _, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
        input_t, mul_t, weight, bias, offsets, scores=scores)
    _, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_cuda(
        input_t, mul_t, weight, bias, offsets, scores=scores)
    if dtype == torch.float32:
        atol = rtol = 1e-4
        # weight/bias grads use atomicAdd accumulation across tokens,
        # so allow slightly higher tolerance for fp32
        wb_tol = 5e-4
    else:
        atol = rtol = 5e-2
        wb_tol = 5e-2
    assert_close(ig_ref, ig_tri, atol=atol, rtol=rtol)
    assert_close(mg_ref, mg_tri, atol=atol, rtol=rtol)
    assert_close(wg_ref, wg_tri, atol=wb_tol, rtol=wb_tol)
    assert_close(bg_ref, bg_tri, atol=wb_tol, rtol=wb_tol)
    assert_close(sg_ref, sg_tri, atol=atol, rtol=rtol)
# ---------------------------------------------------------------------------
# Hidden clamp tests
# ---------------------------------------------------------------------------
# Clamp thresholds passed as hidden_clamp to the kernels under test.
CLAMP_VALUES = [10.0, 1.0, 0.5]
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_hidden_clamp_forward(
    num_tokens,
    d,
    num_experts,
    dtype,
    hidden_clamp,
    device,
):
    """Forward with hidden_clamp should match reference."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        num_tokens, d, num_experts, dtype, device)
    scores = _make_scores(num_tokens, device)
    ref_out = fused_mul_grouped_poly_norm_ref(
        input_t, mul_t, weight, bias, offsets, scores=scores,
        hidden_clamp=hidden_clamp)
    cuda_out = fused_mul_grouped_poly_norm(
        input_t, mul_t, weight, bias, offsets, scores=scores,
        hidden_clamp=hidden_clamp)
    if dtype == torch.float32:
        atol = rtol = 1e-4
    else:
        atol = rtol = 1e-2
    assert_close(ref_out, cuda_out, atol=atol, rtol=rtol)
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_grouped_poly_norm_hidden_clamp_backward(
    num_tokens,
    d,
    num_experts,
    dtype,
    hidden_clamp,
    device,
):
    """Backward with hidden_clamp should match reference."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        num_tokens, d, num_experts, dtype, device)
    scores = _make_scores(num_tokens, device)
    _, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
        input_t, mul_t, weight, bias, offsets, scores=scores,
        hidden_clamp=hidden_clamp)
    _, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_cuda(
        input_t, mul_t, weight, bias, offsets, scores=scores,
        hidden_clamp=hidden_clamp)
    if dtype == torch.float32:
        atol = rtol = 1e-4
        # weight/bias grads use atomicAdd accumulation across tokens,
        # so allow slightly higher tolerance for fp32
        wb_tol = 5e-4
    else:
        atol = rtol = 5e-2
        wb_tol = 5e-2
    assert_close(ig_ref, ig_tri, atol=atol, rtol=rtol)
    assert_close(mg_ref, mg_tri, atol=atol, rtol=rtol)
    assert_close(wg_ref, wg_tri, atol=wb_tol, rtol=wb_tol)
    assert_close(bg_ref, bg_tri, atol=wb_tol, rtol=wb_tol)
    assert_close(sg_ref, sg_tri, atol=atol, rtol=rtol)
# ---------------------------------------------------------------------------
# Padding-aware tests
# ---------------------------------------------------------------------------
# Numbers of extra (invalid) rows appended after the valid tokens.
PADDING_SIZES = [64, 256]
def _make_padded_inputs(num_valid_tokens,
                        num_padding,
                        hidden_dim,
                        num_experts,
                        dtype,
                        device,
                        seed=42,
                        expert_offset=0):
    """Create inputs with extra padding rows (large values) beyond valid tokens."""
    torch.manual_seed(seed)
    # Distribute the valid tokens uniformly at random across experts.
    uniform = torch.ones(num_experts) / num_experts
    token_experts = torch.multinomial(uniform, num_valid_tokens,
                                      replacement=True)
    counts = torch.bincount(token_experts, minlength=num_experts).tolist()
    total_experts = expert_offset + num_experts
    total_rows = num_valid_tokens + num_padding
    input_t = torch.randn(total_rows, hidden_dim, device=device,
                          dtype=dtype) * 0.5
    mul_t = torch.randn(total_rows, hidden_dim, device=device,
                        dtype=dtype) * 0.5
    # Fill padding rows with large values so any contamination is visible.
    input_t[num_valid_tokens:] = torch.randn(
        num_padding, hidden_dim, device=device, dtype=dtype) * 5
    mul_t[num_valid_tokens:] = torch.randn(
        num_padding, hidden_dim, device=device, dtype=dtype) * 5
    weight = torch.ones(total_experts, 3, device=device, dtype=dtype) / 3
    weight = weight + torch.randn(total_experts, 3, device=device,
                                  dtype=dtype) * 0.01
    bias = torch.randn(total_experts, 1, device=device, dtype=dtype) * 0.01
    return input_t, mul_t, weight, bias, _counts_to_offsets(counts, device)
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_forward(num_tokens, num_padding, d, num_experts, dtype,
                        device):
    """Forward with padded input: valid rows correct, padding rows zero."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
        num_tokens, num_padding, d, num_experts, dtype, device)
    num_valid = int(offsets[-1].item())
    total_rows = num_tokens + num_padding
    ref_out = fused_mul_grouped_poly_norm_ref(input_t, mul_t, weight, bias,
                                              offsets)
    cuda_out = fused_mul_grouped_poly_norm(input_t, mul_t, weight, bias,
                                           offsets)
    assert cuda_out.shape == (total_rows, d)
    assert ref_out.shape == (total_rows, d)
    if dtype == torch.float32:
        atol = rtol = 1e-4
    else:
        atol = rtol = 1e-2
    assert_close(cuda_out[:num_valid], ref_out[:num_valid], atol=atol,
                 rtol=rtol)
    assert cuda_out[num_valid:].abs().max() == 0, \
        f"Padding rows not zero: max={cuda_out[num_valid:].abs().max().item()}"
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_backward(num_tokens, num_padding, d, num_experts, dtype,
                         device):
    """Backward with padded input: dW/dB not corrupted, padding grads zero."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
        num_tokens, num_padding, d, num_experts, dtype, device)
    num_valid = int(offsets[-1].item())
    # Reference sees only the valid rows; CUDA sees the full padded input.
    _, ig_ref, mg_ref, wg_ref, bg_ref, _ = _run_ref(
        input_t[:num_valid], mul_t[:num_valid], weight, bias, offsets)
    _, ig_cuda, mg_cuda, wg_cuda, bg_cuda, _ = _run_cuda(
        input_t, mul_t, weight, bias, offsets)
    if dtype == torch.float32:
        atol = rtol = 1e-4
        wb_tol = 5e-4
    else:
        atol = rtol = 5e-2
        wb_tol = 5e-2
    # Weight/bias grads must match (not corrupted by padding)
    assert_close(wg_cuda, wg_ref, atol=wb_tol, rtol=wb_tol)
    assert_close(bg_cuda, bg_ref, atol=wb_tol, rtol=wb_tol)
    # Valid row grads must match
    assert_close(ig_cuda[:num_valid], ig_ref, atol=atol, rtol=rtol)
    assert_close(mg_cuda[:num_valid], mg_ref, atol=atol, rtol=rtol)
    # Padding row grads must be zero
    assert ig_cuda[num_valid:].abs().max() == 0, \
        "Padding rows in input grad should be zero"
    assert mg_cuda[num_valid:].abs().max() == 0, \
        "Padding rows in mul grad should be zero"
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_forward_scored(num_tokens, num_padding, d, num_experts, dtype,
                               device):
    """Forward with padded input + scores: valid rows correct, padding zero."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
        num_tokens, num_padding, d, num_experts, dtype, device)
    num_valid = int(offsets[-1].item())
    # Scores cover both valid and padding rows.
    scores = _make_scores(num_tokens + num_padding, device)
    ref_out = fused_mul_grouped_poly_norm_ref(
        input_t, mul_t, weight, bias, offsets, scores=scores)
    cuda_out = fused_mul_grouped_poly_norm(
        input_t, mul_t, weight, bias, offsets, scores=scores)
    if dtype == torch.float32:
        atol = rtol = 1e-4
    else:
        atol = rtol = 1e-2
    assert_close(cuda_out[:num_valid], ref_out[:num_valid], atol=atol,
                 rtol=rtol)
    assert cuda_out[num_valid:].abs().max() == 0, \
        "Padding rows not zero with scores"
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_backward_scored(num_tokens, num_padding, d, num_experts, dtype,
                                device):
    """Backward with padded input + scores: grads correct, padding zero."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
        num_tokens, num_padding, d, num_experts, dtype, device)
    num_valid = int(offsets[-1].item())
    scores = _make_scores(num_tokens + num_padding, device)
    # Reference sees only valid rows; CUDA sees the full padded tensors.
    _, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
        input_t[:num_valid], mul_t[:num_valid], weight, bias, offsets,
        scores=scores[:num_valid])
    _, ig_cuda, mg_cuda, wg_cuda, bg_cuda, sg_cuda = _run_cuda(
        input_t, mul_t, weight, bias, offsets, scores=scores)
    if dtype == torch.float32:
        atol = rtol = 1e-4
        wb_tol = 5e-4
    else:
        atol = rtol = 5e-2
        wb_tol = 5e-2
    assert_close(wg_cuda, wg_ref, atol=wb_tol, rtol=wb_tol)
    assert_close(bg_cuda, bg_ref, atol=wb_tol, rtol=wb_tol)
    assert_close(ig_cuda[:num_valid], ig_ref, atol=atol, rtol=rtol)
    assert_close(mg_cuda[:num_valid], mg_ref, atol=atol, rtol=rtol)
    assert_close(sg_cuda[:num_valid], sg_ref, atol=atol, rtol=rtol)
    # Padding rows must receive no gradient at all.
    for grad in (ig_cuda, mg_cuda, sg_cuda):
        assert grad[num_valid:].abs().max() == 0
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("hidden_clamp", [10.0, 1.0])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_forward_scored_clamp(num_tokens, num_padding, d, num_experts,
                                     dtype, hidden_clamp, device):
    """Forward with padded input + scores + hidden_clamp."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
        num_tokens, num_padding, d, num_experts, dtype, device)
    num_valid = int(offsets[-1].item())
    scores = _make_scores(num_tokens + num_padding, device)
    ref_out = fused_mul_grouped_poly_norm_ref(
        input_t, mul_t, weight, bias, offsets, scores=scores,
        hidden_clamp=hidden_clamp)
    cuda_out = fused_mul_grouped_poly_norm(
        input_t, mul_t, weight, bias, offsets, scores=scores,
        hidden_clamp=hidden_clamp)
    if dtype == torch.float32:
        atol = rtol = 1e-4
    else:
        atol = rtol = 1e-2
    assert_close(cuda_out[:num_valid], ref_out[:num_valid], atol=atol,
                 rtol=rtol)
    assert cuda_out[num_valid:].abs().max() == 0, \
        "Padding rows not zero with scores+clamp"
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
@pytest.mark.parametrize("num_tokens", [512, 4096])
@pytest.mark.parametrize("num_padding", PADDING_SIZES)
@pytest.mark.parametrize("d", [256, 1280])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("hidden_clamp", [10.0, 1.0])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_padded_backward_scored_clamp(num_tokens, num_padding, d, num_experts,
                                      dtype, hidden_clamp, device):
    """Backward with padded input + scores + hidden_clamp: all grads correct."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_padded_inputs(
        num_tokens, num_padding, d, num_experts, dtype, device)
    num_valid = int(offsets[-1].item())
    scores = _make_scores(num_tokens + num_padding, device)
    # Reference sees only valid rows; CUDA sees the full padded tensors.
    _, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
        input_t[:num_valid], mul_t[:num_valid], weight, bias, offsets,
        scores=scores[:num_valid], hidden_clamp=hidden_clamp)
    _, ig_cuda, mg_cuda, wg_cuda, bg_cuda, sg_cuda = _run_cuda(
        input_t, mul_t, weight, bias, offsets, scores=scores,
        hidden_clamp=hidden_clamp)
    if dtype == torch.float32:
        atol = rtol = 1e-4
        wb_tol = 5e-4
    else:
        atol = rtol = 5e-2
        wb_tol = 5e-2
    assert_close(wg_cuda, wg_ref, atol=wb_tol, rtol=wb_tol)
    assert_close(bg_cuda, bg_ref, atol=wb_tol, rtol=wb_tol)
    assert_close(ig_cuda[:num_valid], ig_ref, atol=atol, rtol=rtol)
    assert_close(mg_cuda[:num_valid], mg_ref, atol=atol, rtol=rtol)
    assert_close(sg_cuda[:num_valid], sg_ref, atol=atol, rtol=rtol)
    # Padding rows must receive no gradient at all.
    for grad in (ig_cuda, mg_cuda, sg_cuda):
        assert grad[num_valid:].abs().max() == 0