# Source: ARBS/arbitor/kernel/triton_video.py
# Uploaded by CLIWorks using huggingface_hub (commit d8bc908, verified).
"""Triton kernels for video denoising (used by VideoHead)."""
import torch
import torch.nn as nn
from math import ceil as _ceil
from .ternary_scale import _HAS_TRITON
if _HAS_TRITON:
    import triton
    import triton.language as tl

    @triton.jit
    def _triton_video_denoise_fwd_kernel(
        latent, pred_noise, out,
        TOTAL, ALPHA, BLOCK: tl.constexpr,
    ):
        """Element-wise forward pass of the denoise step.

        Computes ``out = (latent - (1 - ALPHA) * pred_noise) / sqrt(ALPHA + 1e-8)``
        over flat buffers of ``TOTAL`` elements. Each program instance handles
        ``BLOCK`` consecutive elements.

        ``TOTAL`` and ``ALPHA`` are ordinary run-time arguments rather than
        ``tl.constexpr`` so that a new tensor size or a new alpha value does
        not force a fresh kernel compilation; only ``BLOCK`` needs to be a
        compile-time constant because it sizes ``tl.arange``.
        """
        offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        # The last program may run past the end of the buffers; mask it off.
        mask = offsets < TOTAL
        noisy = tl.load(latent + offsets, mask=mask, other=0.0)
        noise = tl.load(pred_noise + offsets, mask=mask, other=0.0)
        beta = 1.0 - ALPHA
        inv_sqrt = 1.0 / tl.sqrt(ALPHA + 0.00000001)
        tl.store(out + offsets, (noisy - beta * noise) * inv_sqrt, mask=mask)

    @triton.jit
    def _triton_video_denoise_bwd_kernel(
        grad_out, grad_latent, grad_pred,
        TOTAL, ALPHA, BLOCK: tl.constexpr,
    ):
        """Element-wise backward pass of the denoise step.

        Given the upstream gradient ``grad_out``, writes

        * ``grad_latent = grad_out / sqrt(ALPHA + 1e-8)``
        * ``grad_pred   = -(1 - ALPHA) * grad_out / sqrt(ALPHA + 1e-8)``

        which are the partial derivatives of the forward expression with
        respect to ``latent`` and ``pred_noise``. ``TOTAL`` and ``ALPHA`` are
        run-time arguments for the same reason as in the forward kernel.
        """
        offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < TOTAL
        upstream = tl.load(grad_out + offsets, mask=mask, other=0.0)
        beta = 1.0 - ALPHA
        inv_sqrt = 1.0 / tl.sqrt(ALPHA + 0.00000001)
        tl.store(grad_latent + offsets, upstream * inv_sqrt, mask=mask)
        tl.store(grad_pred + offsets, -beta * upstream * inv_sqrt, mask=mask)
class _TritonVideoDenoiseFn(torch.autograd.Function):
    """Autograd function that runs the video denoise step through Triton.

    Forward computes ``(latent - (1 - alpha) * pred_noise) / sqrt(alpha + 1e-8)``
    element-wise via ``_triton_video_denoise_fwd_kernel``; backward produces
    gradients for ``latent`` and ``pred_noise`` via
    ``_triton_video_denoise_bwd_kernel``. ``alpha`` is converted to a Python
    float and receives no gradient.
    """

    @staticmethod
    def forward(ctx, latent, pred_noise, alpha):
        """Run the forward kernel and stash what backward needs.

        Args:
            ctx: autograd context; ``alpha`` and the input shape are saved on it.
            latent: tensor to denoise.
            pred_noise: predicted noise; indexed with the same flat offsets
                as ``latent``, so it must hold at least as many elements.
            alpha: scalar convertible with ``float()``.

        Returns:
            Tensor with the shape of ``latent``.
        """
        # The kernels address memory as flat 1-D buffers, so force a dense layout.
        latent_c = latent.contiguous()
        pred_c = pred_noise.contiguous()
        out = torch.empty_like(latent_c)
        total = latent_c.numel()
        block = 256
        alpha_f = float(alpha)
        if total > 0:  # nothing to launch for an empty tensor
            # Integer ceiling division: one program per `block` elements.
            grid = ((total + block - 1) // block,)
            _triton_video_denoise_fwd_kernel[grid](
                latent_c, pred_c, out,
                total, alpha_f, BLOCK=block,
            )
        ctx.alpha = alpha_f
        ctx.shape = latent.shape
        return out.reshape_as(latent)

    @staticmethod
    def backward(ctx, grad_out):
        """Run the backward kernel.

        Returns:
            ``(grad_latent, grad_pred_noise, None)`` -- both gradients are
            reshaped to the shape saved in forward; the trailing ``None`` is
            the (absent) gradient for ``alpha``.
        """
        grad_c = grad_out.contiguous()
        grad_latent = torch.empty_like(grad_c)
        grad_pred = torch.empty_like(grad_c)
        total = grad_c.numel()
        block = 256
        if total > 0:  # nothing to launch for an empty tensor
            grid = ((total + block - 1) // block,)
            _triton_video_denoise_bwd_kernel[grid](
                grad_c, grad_latent, grad_pred,
                total, ctx.alpha, BLOCK=block,
            )
        return grad_latent.reshape(ctx.shape), grad_pred.reshape(ctx.shape), None
def video_denoise_step(latent, pred_noise, alpha):
    """Apply one denoise step, using the Triton kernel when it is applicable.

    The Triton path is taken only when Triton is available, both tensors are
    on the same CUDA device, and they have identical shapes. The kernel walks
    both tensors with the same flat offsets, so a broadcastable or differently
    shaped ``pred_noise`` must go through the eager expression instead, which
    follows normal PyTorch broadcasting.

    Args:
        latent: tensor to denoise.
        pred_noise: predicted noise tensor.
        alpha: scalar blend factor. The Triton path converts it with
            ``float()``; the eager path uses it as given.

    Returns:
        The denoised tensor.
    """
    use_triton = (
        _HAS_TRITON
        and latent.is_cuda
        and pred_noise.is_cuda
        and _TritonVideoDenoiseFn is not None
        and latent.shape == pred_noise.shape
        and latent.device == pred_noise.device
    )
    if use_triton:
        return _TritonVideoDenoiseFn.apply(latent, pred_noise, alpha)
    # NOTE(review): this divides by (sqrt(alpha) + 1e-8) while the Triton
    # kernels in this file use sqrt(alpha + 1e-8); the two differ noticeably
    # only for alpha near zero -- confirm which form is intended.
    return (latent - (1 - alpha) * pred_noise) / (alpha ** 0.5 + 1e-8)