import torch from ltx_core.components.protocols import DiffusionStepProtocol from ltx_core.utils import to_velocity class EulerDiffusionStep(DiffusionStepProtocol): """ First-order Euler method for diffusion sampling. Takes a single step from the current noise level (sigma) to the next by computing velocity from the denoised prediction and applying: sample + velocity * dt. """ def step( self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int, **_kwargs ) -> torch.Tensor: sigma = sigmas[step_index] sigma_next = sigmas[step_index + 1] dt = sigma_next - sigma velocity = to_velocity(sample, sigma, denoised_sample) return (sample.to(torch.float32) + velocity.to(torch.float32) * dt).to(sample.dtype) class Res2sDiffusionStep(DiffusionStepProtocol): """ Second-order diffusion step for res_2s sampling with SDE noise injection. Used by the res_2s denoising loop. Advances the sample from the current sigma to the next by mixing a deterministic update (from the denoised prediction) with injected noise via ``get_sde_coeff``, producing variance-preserving transitions. """ @staticmethod def get_sde_coeff( sigma_next: torch.Tensor, sigma_up: torch.Tensor | None = None, sigma_down: torch.Tensor | None = None, sigma_max: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute SDE coefficients (alpha_ratio, sigma_down, sigma_up) for the step. Given either ``sigma_down`` or ``sigma_up``, returns the mixing coefficients used for variance-preserving noise injection. If ``sigma_up`` is provided, ``sigma_down`` and ``alpha_ratio`` are derived; if ``sigma_down`` is provided, ``sigma_up`` and ``alpha_ratio`` are derived. """ if sigma_down is not None: alpha_ratio = (1 - sigma_next) / (1 - sigma_down) sigma_up = (sigma_next**2 - sigma_down**2 * alpha_ratio**2).clamp(min=0) ** 0.5 elif sigma_up is not None: # Fallback to avoid sqrt(neg_num) sigma_up.clamp_(max=sigma_next * 0.9999) sigmax = sigma_max if sigma_max is not None else torch.ones_like(sigma_next) sigma_signal = sigmax - sigma_next sigma_residual = (sigma_next**2 - sigma_up**2).clamp(min=0) ** 0.5 alpha_ratio = sigma_signal + sigma_residual sigma_down = sigma_residual / alpha_ratio else: alpha_ratio = torch.ones_like(sigma_next) sigma_down = sigma_next sigma_up = torch.zeros_like(sigma_next) sigma_up = torch.nan_to_num(sigma_up if sigma_up is not None else torch.zeros_like(sigma_next), 0.0) # Replace NaNs in sigma_down with corresponding sigma_next elements (float32) nan_mask = torch.isnan(sigma_down) sigma_down[nan_mask] = sigma_next[nan_mask].to(sigma_down.dtype) alpha_ratio = torch.nan_to_num(alpha_ratio, 1.0) return alpha_ratio, sigma_down, sigma_up def step( self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int, noise: torch.Tensor, eta: float = 0.5, ) -> torch.Tensor: """Advance one step with SDE noise injection via get_sde_coeff. Args: sample: Current noisy sample. denoised_sample: Denoised prediction from the model. sigmas: Noise schedule tensor. step_index: Current step index in the schedule. noise: Random noise tensor for stochastic injection. eta: Controls stochastic noise injection strength (0=deterministic, 1=maximum). Default 0.5. Returns: Next sample with SDE noise injection applied. """ sigma = sigmas[step_index] sigma_next = sigmas[step_index + 1] alpha_ratio, sigma_down, sigma_up = self.get_sde_coeff(sigma_next, sigma_up=sigma_next * eta) output_dtype = denoised_sample.dtype if torch.any(sigma_up == 0) or torch.any(sigma_next == 0): return denoised_sample # Extract epsilon prediction eps_next = (sample - denoised_sample) / (sigma - sigma_next) denoised_next = sample - sigma * eps_next # Mix deterministic and stochastic components x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise return x_noised.to(output_dtype)