| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
|
|
|
|
| |
@triton.jit
def srms_norm_fw(X, Y, V, stride, N, eps, BLOCK_SIZE_N: tl.constexpr):
    """Fused simple-RMS-norm forward: each program normalizes one row of X
    into Y and stores the row's reciprocal RMS into V for the backward pass.
    """
    # One program instance per row of the (M, N) input.
    row = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE_N)
    valid = col_offsets < N

    # Load the row in fp32 for a numerically stable reduction; lanes past N
    # are padded with zeros so they do not contribute to the sum.
    vals = tl.load(X + row * stride + col_offsets, mask=valid, other=0.0).to(tl.float32)
    vals = tl.where(valid, vals, 0.0)

    # Mean of squares over the true row length, then reciprocal standard dev.
    mean_sq = tl.sum(vals * vals, axis=0) / N
    rstd = 1.0 / tl.sqrt(mean_sq + eps)

    # Persist rstd (reused by the backward kernel), then write the output row.
    tl.store(V + row, rstd)
    out = vals * rstd
    tl.store(Y + row * stride + col_offsets, out, mask=valid)
|
|
|
|
| |
| |
@triton.jit
def srms_norm_bwd_dx_fused(
    DX, DY,
    X, V,
    stride, N,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Fused simple-RMS-norm backward: each program computes one row of DX
    from the upstream gradient DY, the saved input X and the rstd in V.
    """
    row = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE_N)
    valid = col_offsets < N

    # Flat offsets into this program's row.
    base = row * stride + col_offsets

    # Inputs: activations, upstream gradient, per-row rstd from the forward.
    x = tl.load(X + base, mask=valid, other=0)
    dy = tl.load(DY + base, mask=valid, other=0)
    rstd = tl.load(V + row)

    # Recover the normalized activations; zero padded lanes so the
    # reduction below only sees the true row.
    xhat = tl.where(valid, x * rstd, 0.)
    wdy = tl.where(valid, dy, 0.)

    # dx = (dy - xhat * mean(xhat * dy)) * rstd
    mean1 = tl.sum(xhat * wdy, axis=0) / N
    dx = (wdy - xhat * mean1) * rstd

    tl.store(DX + base, dx, mask=valid)
|
|
|
|
class _SrmsNorm(torch.autograd.Function):
    """Autograd wrapper around the fused Triton simple-RMS-norm kernels.

    Normalizes the last dimension of the input by its root mean square
    (no mean subtraction, no learned affine parameters).
    """

    @staticmethod
    def forward(ctx, x, eps):
        """Compute ``y = x / sqrt(mean(x**2, dim=-1) + eps)``.

        Args:
            x: input tensor; the last dimension is normalized.
            eps: numerical-stability term added to the mean of squares.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # fp16 cannot represent a very small epsilon meaningfully; clamp it.
        if x.dtype == torch.float16:
            eps = max(eps, 1.6e-5)

        y = torch.empty_like(x)

        # Flatten all leading dims so each kernel program handles one row.
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape

        # Per-row reciprocal RMS, saved for the backward pass.
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

        # A single block must span the whole feature dim (one-pass reduction).
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE_N:
            raise RuntimeError(
                "This layer norm doesn't support feature dim >= 64KB.")

        # The kernels assume unit stride along the feature dimension.
        if not x_arg.is_contiguous() or not y.is_contiguous():
            x_arg = x_arg.contiguous()
            y = y.contiguous()

        # Heuristic: more warps for wider rows, capped at 16.
        num_warps = min(max(BLOCK_SIZE_N // 256, 1), 16)

        srms_norm_fw[(M,)](
            x_arg, y, rstd,
            x_arg.stride(0),
            N,
            eps,
            num_warps=num_warps,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
        )

        ctx.save_for_backward(x, rstd)
        ctx.BLOCK_SIZE_N = BLOCK_SIZE_N
        ctx.num_warps = num_warps

        return y.reshape_as(x)

    @staticmethod
    def backward(ctx, dy):
        """Compute dL/dx given dL/dy; ``eps`` receives no gradient."""
        x, rstd = ctx.saved_tensors

        # Match the forward's row-major 2D view of the input.
        x = x.reshape(-1, x.size(-1))
        M, N = x.size()

        # The kernel indexes rows by a single stride; make dy dense.
        dy = dy.contiguous()
        dx = torch.empty_like(dy)

        assert (
            dy.numel() == x.numel()
        ), "Something is wrong in the backward graph, possibly because of an inplace operation after the layernorm"

        # Reuse the launch configuration chosen in the forward pass.
        srms_norm_bwd_dx_fused[(M,)](
            dx, dy, x,
            rstd,
            x.stride(0),
            N,
            BLOCK_SIZE_N=ctx.BLOCK_SIZE_N,
            num_warps=ctx.num_warps,
        )

        dx = dx.reshape_as(dy)
        # forward() took (x, eps): one gradient for x, none for eps.
        return dx, None
|
|
|
|
class SimpleRMSNorm(torch.nn.Module):
    """RMS normalization without learned parameters.

    Scales the last dimension of the input by the reciprocal of its root
    mean square, delegating to a fused Triton autograd function.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        # The custom Function fuses both forward and backward kernels.
        return _SrmsNorm.apply(x, self.eps)
|
|