# File size: 2,849 Bytes
# d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Triton kernels for video denoising (used by VideoHead)."""
import torch
import torch.nn as nn
from math import ceil as _ceil

from .ternary_scale import _HAS_TRITON

if _HAS_TRITON:
    import triton
    import triton.language as tl

    @triton.jit
    def _triton_video_denoise_fwd_kernel(
        latent, pred_noise, out,
        n_elements, alpha, BLOCK: tl.constexpr,
    ):
        """Forward: ``out = (latent - (1 - alpha) * pred_noise) / (sqrt(alpha) + 1e-8)``.

        ``latent``, ``pred_noise`` and ``out`` are pointers indexed with the same
        flat offsets; each program handles ``BLOCK`` consecutive elements and
        masks offsets ``>= n_elements``.

        ``n_elements`` and ``alpha`` are runtime arguments rather than
        ``tl.constexpr``: constexpr values are baked into the compiled kernel,
        so passing a new alpha on every denoising step (or a new tensor size)
        would otherwise trigger a fresh JIT compilation each time.
        """
        offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < n_elements
        lat = tl.load(latent + offsets, mask=mask, other=0.0)
        noise = tl.load(pred_noise + offsets, mask=mask, other=0.0)
        beta = 1.0 - alpha
        # Same epsilon placement as the eager fallback in video_denoise_step:
        # divide by (sqrt(alpha) + eps), not sqrt(alpha + eps).
        inv_scale = 1.0 / (tl.sqrt(alpha) + 1e-8)
        tl.store(out + offsets, (lat - beta * noise) * inv_scale, mask=mask)

    @triton.jit
    def _triton_video_denoise_bwd_kernel(
        grad_out, grad_latent, grad_pred,
        n_elements, alpha, BLOCK: tl.constexpr,
    ):
        """Backward of the forward kernel with respect to both tensor inputs.

        With ``s = 1 / (sqrt(alpha) + 1e-8)`` (matching the forward kernel):
        ``grad_latent = grad_out * s`` and
        ``grad_pred = -(1 - alpha) * grad_out * s``.
        ``n_elements`` and ``alpha`` are runtime arguments for the same
        recompilation reason as in the forward kernel.
        """
        offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < n_elements
        g = tl.load(grad_out + offsets, mask=mask, other=0.0)
        inv_scale = 1.0 / (tl.sqrt(alpha) + 1e-8)
        scaled = g * inv_scale
        tl.store(grad_latent + offsets, scaled, mask=mask)
        tl.store(grad_pred + offsets, -(1.0 - alpha) * scaled, mask=mask)


    class _TritonVideoDenoiseFn(torch.autograd.Function):
        """Autograd bridge to the Triton video-denoise kernels.

        ``forward(latent, pred_noise, alpha)`` makes both tensors contiguous so
        the kernel can treat them as flat buffers, then launches the forward
        kernel.  ``backward`` launches the matching gradient kernel.  ``alpha``
        is converted with ``float`` and receives no gradient (``None``).
        """

        @staticmethod
        def forward(ctx, latent, pred_noise, alpha):
            # The kernels read both inputs with identical flat offsets, so a
            # shape or device mismatch would read out of bounds / from the
            # wrong memory instead of broadcasting.
            if pred_noise.shape != latent.shape:
                raise ValueError(
                    "pred_noise must have the same shape as latent, got "
                    f"{tuple(pred_noise.shape)} vs {tuple(latent.shape)}"
                )
            if pred_noise.device != latent.device:
                raise ValueError(
                    "pred_noise and latent must be on the same device, got "
                    f"{pred_noise.device} vs {latent.device}"
                )
            latent_c = latent.contiguous()
            pred_c = pred_noise.contiguous()
            out = torch.empty_like(latent_c)
            total = latent_c.numel()
            alpha_f = float(alpha)
            block = 256
            if total > 0:  # nothing to launch for an empty tensor
                # Ceil-divide so the partial last block is covered; the kernel
                # masks offsets >= total.
                grid = ((total + block - 1) // block,)
                _triton_video_denoise_fwd_kernel[grid](
                    latent_c, pred_c, out,
                    total, alpha_f, BLOCK=block,
                )
            ctx.alpha = alpha_f
            ctx.shape = latent.shape
            return out.reshape_as(latent)

        @staticmethod
        def backward(ctx, grad_out):
            grad_c = grad_out.contiguous()
            grad_latent = torch.empty_like(grad_c)
            grad_pred = torch.empty_like(grad_c)
            total = grad_c.numel()
            block = 256
            if total > 0:
                grid = ((total + block - 1) // block,)
                _triton_video_denoise_bwd_kernel[grid](
                    grad_c, grad_latent, grad_pred,
                    total, ctx.alpha, BLOCK=block,
                )
            # Third slot is alpha's gradient: it entered as a plain float.
            return grad_latent.reshape(ctx.shape), grad_pred.reshape(ctx.shape), None


def video_denoise_step(latent, pred_noise, alpha):
    """Apply one denoising update to ``latent`` given the predicted noise.

    Computes ``(latent - (1 - alpha) * pred_noise) / (sqrt(alpha) + 1e-8)``.

    The fused Triton autograd path is used only when it can reproduce the
    eager expression: both tensors on the same CUDA device with identical
    shape and dtype, and ``alpha`` a plain number or a single-element tensor
    that does not require grad.  The Triton function converts ``alpha`` with
    ``float`` and returns no gradient for it.  Its kernels also index both
    inputs with the same flat offsets, so they cannot broadcast.  Every other
    case falls back to the eager PyTorch expression.  This includes CPU
    tensors, broadcasting between ``latent`` and ``pred_noise``, and
    per-element or differentiable ``alpha`` tensors.

    NOTE(review): keep the Triton kernels' epsilon placement
    (``sqrt(alpha) + 1e-8``) in sync with the eager expression below.

    Args:
        latent: Current latent tensor.
        pred_noise: Predicted noise tensor; broadcast against ``latent`` on
            the eager path.
        alpha: Python number or tensor broadcastable against ``latent``.

    Returns:
        Tensor with the updated latent.
    """
    alpha_is_scalar = not torch.is_tensor(alpha) or (
        alpha.numel() == 1 and not alpha.requires_grad
    )
    # Cheap tensor checks come first; _HAS_TRITON must precede any use of
    # _TritonVideoDenoiseFn, which only exists when Triton imported.
    use_triton = (
        latent.is_cuda
        and pred_noise.is_cuda
        and latent.device == pred_noise.device
        and latent.shape == pred_noise.shape
        and latent.dtype == pred_noise.dtype
        and alpha_is_scalar
        and _HAS_TRITON
        and _TritonVideoDenoiseFn is not None
    )
    if use_triton:
        return _TritonVideoDenoiseFn.apply(latent, pred_noise, alpha)
    return (latent - (1 - alpha) * pred_noise) / (alpha ** 0.5 + 1e-8)