"""Triton kernels for video denoising (used by VideoHead).""" import torch import torch.nn as nn from math import ceil as _ceil from .ternary_scale import _HAS_TRITON if _HAS_TRITON: import triton import triton.language as tl @triton.jit def _triton_video_denoise_fwd_kernel( latent, pred_noise, out, TOTAL: tl.constexpr, ALPHA: tl.constexpr, BLOCK: tl.constexpr, ): offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) mask = offsets < TOTAL l = tl.load(latent + offsets, mask=mask, other=0.0) p = tl.load(pred_noise + offsets, mask=mask, other=0.0) beta = 1.0 - ALPHA inv_sqrt = 1.0 / tl.sqrt(ALPHA + 0.00000001) tl.store(out + offsets, (l - beta * p) * inv_sqrt, mask=mask) @triton.jit def _triton_video_denoise_bwd_kernel( grad_out, grad_latent, grad_pred, TOTAL: tl.constexpr, ALPHA: tl.constexpr, BLOCK: tl.constexpr, ): offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) mask = offsets < TOTAL g = tl.load(grad_out + offsets, mask=mask, other=0.0) beta = 1.0 - ALPHA inv_sqrt = 1.0 / tl.sqrt(ALPHA + 0.00000001) tl.store(grad_latent + offsets, g * inv_sqrt, mask=mask) tl.store(grad_pred + offsets, -beta * g * inv_sqrt, mask=mask) class _TritonVideoDenoiseFn(torch.autograd.Function): @staticmethod def forward(ctx, latent, pred_noise, alpha): latent_c = latent.contiguous() pred_c = pred_noise.contiguous() out = torch.empty_like(latent_c) total = latent_c.numel() block = 256 grid = (_ceil_div(total, block),) alpha_f = float(alpha) _triton_video_denoise_fwd_kernel[grid]( latent_c, pred_c, out, total, alpha_f, BLOCK=block, ) ctx.alpha = alpha_f ctx.shape = latent.shape return out.reshape_as(latent) @staticmethod def backward(ctx, grad_out): grad_c = grad_out.contiguous() grad_latent = torch.empty_like(grad_c) grad_pred = torch.empty_like(grad_c) total = grad_c.numel() block = 256 grid = (_ceil_div(total, block),) _triton_video_denoise_bwd_kernel[grid]( grad_c, grad_latent, grad_pred, total, ctx.alpha, BLOCK=block, ) return grad_latent.reshape(ctx.shape), grad_pred.reshape(ctx.shape), None def video_denoise_step(latent, pred_noise, alpha): if _HAS_TRITON and latent.is_cuda and pred_noise.is_cuda and _TritonVideoDenoiseFn is not None: return _TritonVideoDenoiseFn.apply(latent, pred_noise, alpha) return (latent - (1 - alpha) * pred_noise) / (alpha ** 0.5 + 1e-8)