| """Triton kernels for video denoising (used by VideoHead).""" |
| import torch |
| import torch.nn as nn |
| from math import ceil as _ceil |
|
|
| from .ternary_scale import _HAS_TRITON |
|
|
if _HAS_TRITON:
    import triton
    import triton.language as tl

    # Forward denoise step, elementwise over flat buffers of TOTAL elements:
    #     out = (latent - (1 - ALPHA) * pred_noise) / sqrt(ALPHA + 1e-8)
    #
    # TOTAL and ALPHA are ordinary runtime arguments rather than tl.constexpr:
    # a constexpr value is baked into the compiled kernel, so every new tensor
    # size or alpha value (e.g. one per diffusion timestep) would force a fresh
    # JIT compilation. Only BLOCK, which tl.arange needs at compile time, stays
    # constexpr. Positional call sites are unaffected by this change.
    @triton.jit
    def _triton_video_denoise_fwd_kernel(
        latent, pred_noise, out,
        TOTAL, ALPHA, BLOCK: tl.constexpr,
    ):
        # Each program handles one contiguous run of BLOCK elements.
        offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        # The last program may extend past TOTAL; mask those lanes off.
        mask = offsets < TOTAL
        lat = tl.load(latent + offsets, mask=mask, other=0.0)
        noise = tl.load(pred_noise + offsets, mask=mask, other=0.0)
        beta = 1.0 - ALPHA
        # The small epsilon keeps the reciprocal finite when ALPHA is 0.
        inv_sqrt = 1.0 / tl.sqrt(ALPHA + 0.00000001)
        tl.store(out + offsets, (lat - beta * noise) * inv_sqrt, mask=mask)

    # Backward of the forward kernel above. Given the upstream gradient g:
    #     grad_latent = g / sqrt(ALPHA + 1e-8)
    #     grad_pred   = -(1 - ALPHA) * g / sqrt(ALPHA + 1e-8)
    #
    # As in the forward kernel, TOTAL and ALPHA are runtime arguments so that
    # changing them does not trigger recompilation.
    @triton.jit
    def _triton_video_denoise_bwd_kernel(
        grad_out, grad_latent, grad_pred,
        TOTAL, ALPHA, BLOCK: tl.constexpr,
    ):
        offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < TOTAL
        g = tl.load(grad_out + offsets, mask=mask, other=0.0)
        beta = 1.0 - ALPHA
        inv_sqrt = 1.0 / tl.sqrt(ALPHA + 0.00000001)
        # Both gradients share the common factor g / sqrt(ALPHA + eps).
        scaled = g * inv_sqrt
        tl.store(grad_latent + offsets, scaled, mask=mask)
        tl.store(grad_pred + offsets, -beta * scaled, mask=mask)
|
|
|
|
class _TritonVideoDenoiseFn(torch.autograd.Function):
    """Autograd wrapper around the Triton video-denoise kernels.

    Forward launches ``_triton_video_denoise_fwd_kernel`` and backward launches
    ``_triton_video_denoise_bwd_kernel``. ``alpha`` is converted with
    ``float()`` and treated as a non-differentiable constant, so its gradient
    is always ``None``.
    """

    @staticmethod
    def forward(ctx, latent, pred_noise, alpha):
        """Run the forward kernel on ``latent`` and ``pred_noise``.

        Args:
            ctx: Autograd context; stores alpha and the input shapes.
            latent: Tensor to denoise.
            pred_noise: Predicted noise, broadcastable against ``latent``.
            alpha: Value convertible with ``float()``.

        Returns:
            Tensor with the broadcast shape of ``latent`` and ``pred_noise``
            and the dtype of ``latent``.
        """
        # The kernel indexes both inputs with the same flat offsets, so they
        # must hold exactly the same number of elements. Broadcasting first
        # prevents out-of-bounds reads when the shapes differ.
        latent_b, pred_b = torch.broadcast_tensors(latent, pred_noise)
        latent_c = latent_b.contiguous()
        pred_c = pred_b.contiguous()
        out = torch.empty_like(latent_c)

        total = latent_c.numel()
        block = 256
        # Integer ceil-division: one program per BLOCK-sized run of elements.
        grid = ((total + block - 1) // block,)
        alpha_f = float(alpha)
        _triton_video_denoise_fwd_kernel[grid](
            latent_c, pred_c, out,
            total, alpha_f, BLOCK=block,
        )

        ctx.alpha = alpha_f
        # Original shapes are needed to undo the broadcast in backward.
        ctx.latent_shape = latent.shape
        ctx.pred_shape = pred_noise.shape
        return out

    @staticmethod
    def backward(ctx, grad_out):
        """Run the backward kernel on the upstream gradient.

        Returns:
            ``(grad_latent, grad_pred_noise, None)``; the last entry is the
            (absent) gradient for ``alpha``.
        """
        grad_c = grad_out.contiguous()
        grad_latent = torch.empty_like(grad_c)
        grad_pred = torch.empty_like(grad_c)

        total = grad_c.numel()
        block = 256
        grid = ((total + block - 1) // block,)
        _triton_video_denoise_bwd_kernel[grid](
            grad_c, grad_latent, grad_pred,
            total, ctx.alpha, BLOCK=block,
        )

        # Sum over any dimensions that forward broadcast; when the input
        # shapes already matched this leaves the gradients unchanged.
        return (
            grad_latent.sum_to_size(ctx.latent_shape),
            grad_pred.sum_to_size(ctx.pred_shape),
            None,
        )
|
|
|
|
def video_denoise_step(latent, pred_noise, alpha):
    """Apply one denoising step to ``latent``.

    Computes ``(latent - (1 - alpha) * pred_noise) / sqrt(alpha + 1e-8)``.
    The Triton implementation is used when Triton is available, both tensors
    are on CUDA, and ``alpha`` is a plain number or a single-element tensor
    that does not require grad. Otherwise the same formula is evaluated with
    regular tensor operations.

    Args:
        latent: Tensor to denoise.
        pred_noise: Predicted noise tensor.
        alpha: Python number or tensor.

    Returns:
        The denoised tensor.
    """
    # The Triton path calls float(alpha): that needs exactly one value and
    # would silently detach a tensor alpha from autograd. Such alphas are
    # routed to the eager path, which handles them natively.
    alpha_is_plain_scalar = not torch.is_tensor(alpha) or (
        alpha.numel() == 1 and not alpha.requires_grad
    )
    use_triton = (
        _HAS_TRITON
        and alpha_is_plain_scalar
        and latent.is_cuda
        and pred_noise.is_cuda
        and _TritonVideoDenoiseFn is not None
    )
    if use_triton:
        return _TritonVideoDenoiseFn.apply(latent, pred_noise, alpha)

    # Epsilon sits inside the square root, exactly as in the Triton kernels,
    # so both paths compute the same function.
    return (latent - (1 - alpha) * pred_noise) / (alpha + 1e-8) ** 0.5
|
|