# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 """Diffusion process and DDIM sampling for motion generation.""" import math from typing import Optional, Tuple import torch from torch import nn def get_beta_schedule( num_diffusion_timesteps: int, max_beta: Optional[float] = 0.999, ) -> torch.Tensor: """Get cosine beta schedule.""" def alpha_bar(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) return torch.tensor(betas, dtype=torch.float) class Diffusion(torch.nn.Module): """Cosine-schedule diffusion process: betas, alphas, and DDIM step mapping.""" def __init__(self, num_base_steps: int): """Set up cosine beta schedule and precompute diffusion variables for num_base_steps.""" super().__init__() self.num_base_steps = num_base_steps betas_base = get_beta_schedule(self.num_base_steps) self.register_buffer("betas_base", betas_base, persistent=False) alphas_cumprod_base = torch.cumprod(1.0 - self.betas_base, dim=0) self.register_buffer("alphas_cumprod_base", alphas_cumprod_base, persistent=False) use_timesteps, _ = self.space_timesteps(self.num_base_steps) self.calc_diffusion_vars(use_timesteps) def extra_repr(self) -> str: return f"num_base_steps={self.num_base_steps}" @property def device(self): return self.betas_base.device def space_timesteps(self, num_denoising_steps: int) -> Tuple[torch.Tensor, torch.Tensor]: """Return (use_timesteps, map_tensor) for a subsampled denoising schedule of num_denoising_steps.""" nsteps_train = self.num_base_steps frac_stride = (nsteps_train - 1) / max(1, num_denoising_steps - 1) use_timesteps = torch.round(torch.arange(nsteps_train, device=self.device) * frac_stride).to(torch.long) use_timesteps = torch.clamp(use_timesteps, max=nsteps_train - 1) map_tensor = torch.arange(nsteps_train, device=self.device, dtype=torch.long)[use_timesteps] return use_timesteps, map_tensor def calc_diffusion_vars(self, use_timesteps: torch.Tensor) -> None: """Update buffers (betas, alphas, alphas_cumprod, etc.) for the given subsampled timesteps.""" alphas_cumprod = self.alphas_cumprod_base[use_timesteps] last_alpha_cumprod = torch.cat([torch.tensor([1.0]).to(alphas_cumprod), alphas_cumprod[:-1]]) betas = 1.0 - alphas_cumprod / last_alpha_cumprod self.register_buffer("betas", betas, persistent=False) alphas = 1.0 - self.betas self.register_buffer("alphas", alphas, persistent=False) alphas_cumprod = torch.cumprod(self.alphas, dim=0) alphas_cumprod = torch.clamp(alphas_cumprod, min=1e-9) self.register_buffer("alphas_cumprod", alphas_cumprod, persistent=False) alphas_cumprod_prev = torch.cat([torch.tensor([1.0]).to(self.alphas_cumprod), self.alphas_cumprod[:-1]]) self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev, persistent=False) sqrt_recip_alphas_cumprod = torch.rsqrt(self.alphas_cumprod) self.register_buffer("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod, persistent=False) sqrt_recipm1_alphas_cumprod = torch.rsqrt(self.alphas_cumprod / (1.0 - self.alphas_cumprod)) self.register_buffer("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_alphas_cumprod, persistent=False) posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) self.register_buffer("posterior_variance", posterior_variance, persistent=False) sqrt_alphas_cumprod = torch.rsqrt(1.0 / self.alphas_cumprod) self.register_buffer("sqrt_alphas_cumprod", sqrt_alphas_cumprod, persistent=False) sqrt_one_minus_alphas_cumprod = torch.rsqrt(1.0 / (1.0 - self.alphas_cumprod)) self.register_buffer( "sqrt_one_minus_alphas_cumprod", sqrt_one_minus_alphas_cumprod, persistent=False, ) def q_sample( self, x_start: torch.Tensor, t: torch.Tensor, noise: torch.Tensor = None, ): if noise is None: noise = torch.randn_like(x_start) assert noise.shape == x_start.shape xt = ( self.sqrt_alphas_cumprod[t, None, None] * x_start + self.sqrt_one_minus_alphas_cumprod[t, None, None] * noise ) return xt class DDIMSampler(nn.Module): """Deterministic DDIM sampler (eta = 0).""" def __init__(self, diffusion: Diffusion): super().__init__() self.diffusion = diffusion def __call__( self, use_timesteps: torch.Tensor, x_t: torch.Tensor, pred_xstart: torch.Tensor, t: torch.Tensor, ) -> torch.Tensor: self.diffusion.calc_diffusion_vars(use_timesteps) eps = ( self.diffusion.sqrt_recip_alphas_cumprod[t, None, None] * x_t - pred_xstart ) / self.diffusion.sqrt_recipm1_alphas_cumprod[t, None, None] alpha_bar_prev = self.diffusion.alphas_cumprod_prev[t, None, None] x = pred_xstart * torch.sqrt(alpha_bar_prev) + torch.sqrt(1 - alpha_bar_prev) * eps return x