File size: 5,541 Bytes
6d5047c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Diffusion process and DDIM sampling for motion generation."""

import math
from typing import Optional, Tuple

import torch
from torch import nn


def get_beta_schedule(
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
) -> torch.Tensor:
    """Return the cosine beta schedule (Nichol & Dhariwal, "Improved DDPM").

    Args:
        num_diffusion_timesteps: Number of betas to generate.
        max_beta: Clamp applied to each beta. Without it the final beta reaches
            1.0 exactly, since alpha_bar(1) == 0.

    Returns:
        1-D float32 tensor of shape (num_diffusion_timesteps,) with values in
        (0, max_beta].
    """

    def alpha_bar(t: float) -> float:
        # Cumulative signal level at fractional time t in [0, 1]; the 0.008
        # offset keeps the betas near t = 0 from becoming vanishingly small.
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clamped at max_beta.
    betas = [
        min(
            1
            - alpha_bar((i + 1) / num_diffusion_timesteps)
            / alpha_bar(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float)


class Diffusion(torch.nn.Module):
    """Cosine-schedule diffusion process: betas, alphas, and DDIM step mapping.

    All schedule tensors are registered as non-persistent buffers (excluded
    from the state_dict). ``calc_diffusion_vars`` may be called repeatedly to
    re-derive them for a different subsampled schedule — ``register_buffer``
    overwrites an existing buffer of the same name in place.
    """

    def __init__(self, num_base_steps: int):
        """Set up cosine beta schedule and precompute diffusion variables for num_base_steps."""
        super().__init__()
        self.num_base_steps = num_base_steps
        betas_base = get_beta_schedule(self.num_base_steps)
        self.register_buffer("betas_base", betas_base, persistent=False)
        alphas_cumprod_base = torch.cumprod(1.0 - self.betas_base, dim=0)
        self.register_buffer("alphas_cumprod_base", alphas_cumprod_base, persistent=False)
        # Default schedule keeps every base timestep.
        use_timesteps, _ = self.space_timesteps(self.num_base_steps)
        self.calc_diffusion_vars(use_timesteps)

    def extra_repr(self) -> str:
        return f"num_base_steps={self.num_base_steps}"

    @property
    def device(self) -> torch.device:
        """Device of the precomputed schedule buffers."""
        return self.betas_base.device

    def space_timesteps(self, num_denoising_steps: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (use_timesteps, map_tensor) for a subsampled denoising schedule of
        num_denoising_steps.

        Both returned tensors have length num_denoising_steps and hold the
        base-schedule timestep indices to keep, evenly spaced over
        [0, num_base_steps - 1].
        """
        nsteps_train = self.num_base_steps
        frac_stride = (nsteps_train - 1) / max(1, num_denoising_steps - 1)
        # BUGFIX: stride over the *subsampled* step count so exactly
        # num_denoising_steps timesteps are selected. The previous code strode
        # over all nsteps_train base steps, yielding a wrong-length schedule
        # whose tail was clamped to the final timestep.
        steps = torch.arange(num_denoising_steps, device=self.device)
        use_timesteps = torch.round(steps * frac_stride).to(torch.long)
        use_timesteps = torch.clamp(use_timesteps, max=nsteps_train - 1)
        # map_tensor mirrors use_timesteps (the original computed
        # arange(nsteps_train)[use_timesteps], which is the identity gather).
        map_tensor = use_timesteps.clone()
        return use_timesteps, map_tensor

    def calc_diffusion_vars(self, use_timesteps: torch.Tensor) -> None:
        """Update buffers (betas, alphas, alphas_cumprod, etc.) for the given subsampled
        timesteps."""
        # Respaced betas: beta'_i = 1 - abar_{t_i} / abar_{t_{i-1}} over the
        # kept base timesteps (abar_{-1} defined as 1).
        alphas_cumprod = self.alphas_cumprod_base[use_timesteps]
        last_alpha_cumprod = torch.cat([torch.tensor([1.0]).to(alphas_cumprod), alphas_cumprod[:-1]])
        betas = 1.0 - alphas_cumprod / last_alpha_cumprod
        self.register_buffer("betas", betas, persistent=False)

        alphas = 1.0 - self.betas
        self.register_buffer("alphas", alphas, persistent=False)
        # The cumulative product telescopes back to the gathered alphas_cumprod;
        # clamp away zeros before the reciprocals taken below.
        alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        alphas_cumprod = torch.clamp(alphas_cumprod, min=1e-9)
        self.register_buffer("alphas_cumprod", alphas_cumprod, persistent=False)

        alphas_cumprod_prev = torch.cat([torch.tensor([1.0]).to(self.alphas_cumprod), self.alphas_cumprod[:-1]])
        self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev, persistent=False)

        # 1 / sqrt(abar_t): scales x_t when recovering eps from a predicted x_0.
        sqrt_recip_alphas_cumprod = torch.rsqrt(self.alphas_cumprod)
        self.register_buffer("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod, persistent=False)

        # sqrt(1/abar_t - 1), written directly instead of the original
        # rsqrt(abar / (1 - abar)) double-reciprocal form.
        sqrt_recipm1_alphas_cumprod = torch.sqrt((1.0 - self.alphas_cumprod) / self.alphas_cumprod)
        self.register_buffer("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_alphas_cumprod, persistent=False)

        posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.register_buffer("posterior_variance", posterior_variance, persistent=False)

        # sqrt(abar_t): direct form instead of the original rsqrt(1 / abar_t),
        # which took two reciprocals (two roundings) to compute a square root.
        sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.register_buffer("sqrt_alphas_cumprod", sqrt_alphas_cumprod, persistent=False)

        sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            sqrt_one_minus_alphas_cumprod,
            persistent=False,
        )

    def q_sample(
        self,
        x_start: torch.Tensor,
        t: torch.Tensor,
        noise: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Sample from q(x_t | x_0): sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise.

        Args:
            x_start: Clean sample x_0; indexing below assumes it is 3-D with the
                batch dimension first — TODO confirm expected layout at call sites.
            t: Long tensor of per-sample timestep indices into the current schedule.
            noise: Optional Gaussian noise with the same shape as x_start;
                drawn with randn_like when omitted.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape

        xt = (
            self.sqrt_alphas_cumprod[t, None, None] * x_start
            + self.sqrt_one_minus_alphas_cumprod[t, None, None] * noise
        )
        return xt


class DDIMSampler(nn.Module):
    """Deterministic DDIM sampler (eta = 0)."""

    def __init__(self, diffusion: "Diffusion"):
        """Store the Diffusion process whose schedule buffers drive each step."""
        super().__init__()
        self.diffusion = diffusion

    def forward(
        self,
        use_timesteps: torch.Tensor,
        x_t: torch.Tensor,
        pred_xstart: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """One deterministic DDIM update from x_t toward x_{t-1}.

        Defined as ``forward`` rather than ``__call__``: overriding ``__call__``
        on an nn.Module bypasses nn.Module's call machinery (hooks etc.), while
        ``sampler(...)`` call sites behave the same either way.

        Args:
            use_timesteps: Subsampled schedule forwarded to calc_diffusion_vars.
            x_t: Current noisy sample.
            pred_xstart: Model prediction of the clean sample x_0.
            t: Long tensor of indices into the respaced schedule.

        Returns:
            The previous-step sample x_{t-1}.
        """
        # Refresh the diffusion buffers for this (possibly respaced) schedule.
        self.diffusion.calc_diffusion_vars(use_timesteps)
        # Recover the noise eps implied by (x_t, predicted x_0).
        eps = (
            self.diffusion.sqrt_recip_alphas_cumprod[t, None, None] * x_t - pred_xstart
        ) / self.diffusion.sqrt_recipm1_alphas_cumprod[t, None, None]
        alpha_bar_prev = self.diffusion.alphas_cumprod_prev[t, None, None]
        # DDIM with eta = 0: x_{t-1} = sqrt(abar_{t-1}) x_0 + sqrt(1 - abar_{t-1}) eps
        x = pred_xstart * torch.sqrt(alpha_bar_prev) + torch.sqrt(1 - alpha_bar_prev) * eps
        return x