import random

import pytest
import torch

import activation

from .utils import assert_close, opcheck

DTYPES = [torch.float, torch.bfloat16, torch.half]
NUM_TOKENS = [7, 83, 256, 2048]
D = [1, 7, 512, 13824]
SEEDS = [0]
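# Use cuda:0 only when a single GPU is visible; otherwise also exercise cuda:1.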
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
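

# Reference implementations: both add the residual in eager PyTorch and return
# rms_norm(x + residual) + (x + residual), so the normalized output and the
# updated hidden state are checked in a single tensor comparison.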
def add_rms_norm_all_naive(x: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, eps: float) -> torch.Tensor:
    h = x + residual
    return torch.nn.functional.rms_norm(h, weight.shape, weight, eps) + h
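

# Same as above, but the normalization itself goes through activation.rms_norm
# instead of torch.nn.functional.rms_norm.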
def add_rms_norm_partial_naive(x: torch.Tensor, residual: torch.Tensor,
                               weight: torch.Tensor,
                               eps: float) -> torch.Tensor:
    h = x + residual
    return activation.rms_norm(h, weight, eps) + h
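

# Functional wrapper around the fused op under test. activation.fused_add_rms_norm
# is assumed to return the normalized output and the added hidden state; summing
# the two mirrors the reference helpers above.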
def fused_add_rms_norm(x: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, eps: float) -> torch.Tensor:
    out, h = activation.fused_add_rms_norm(x, residual, weight, eps)
    return out + h
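

# Check forward outputs and backward gradients of the fused op against both
# references across token counts, hidden sizes, dtypes, and devices.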
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_add_rms_norm(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    torch.set_default_device(device)
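
    # Inputs are leaf tensors with requires_grad=True so their gradients can be
    # compared against the reference paths after backward.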
    x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
    residual = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
    weight = torch.randn(d, dtype=dtype, requires_grad=True)
    eps = 1e-05

    x.retain_grad()
    residual.retain_grad()
    weight.retain_grad()
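
    # Detached clones give each reference path an independent autograd graph over
    # numerically identical inputs.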
    x_ref = x.detach().clone().requires_grad_(True)
    residual_ref = residual.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)

    x_ref2 = x.detach().clone().requires_grad_(True)
    residual_ref2 = residual.detach().clone().requires_grad_(True)
    weight_ref2 = weight.detach().clone().requires_grad_(True)

    torch_fn = add_rms_norm_all_naive
    torch_fn2 = add_rms_norm_partial_naive

    op = activation.ops.fused_add_rms_norm
    fn = fused_add_rms_norm

    layer = activation.layers.FusedAddRMSNorm(d, eps)
    layer.weight = torch.nn.Parameter(weight)
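
    # Validate the raw op via opcheck, then run the functional path, the module
    # path, and both references on the same inputs.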
    opcheck(op, (x, residual, weight, eps))

    out = fn(x, residual, weight, eps)
    mod_out, mod_a_out = layer(x, residual)
    mod_out = mod_out + mod_a_out
    ref_out = torch_fn(x_ref, residual_ref, weight_ref, eps)
    ref_out2 = torch_fn2(x_ref2, residual_ref2, weight_ref2, eps)

    assert_close(out, ref_out, atol=0.05, rtol=0.05)
    assert_close(out, ref_out2)
    assert_close(mod_out, out, atol=0.0, rtol=0.0)
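
    # Backward: push the same unit-norm upstream gradient through the module path
    # and both references, then compare input, residual, and weight gradients.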
    out_grad = torch.randn_like(out)
    out_grad = out_grad / out_grad.norm()

    ref_out.backward(out_grad)
    ref_out2.backward(out_grad)
    mod_out.backward(out_grad)

    assert_close(x.grad, x_ref.grad)
    assert_close(x.grad, x_ref2.grad)
    assert_close(residual.grad, residual_ref.grad)
    assert_close(residual.grad, residual_ref2.grad)
    assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref2.grad, rtol=0.05)