Spaces:
Running on Zero
Running on Zero
| import torch | |
| from ltx_core.components.protocols import DiffusionStepProtocol | |
| from ltx_core.utils import to_velocity | |
class EulerDiffusionStep(DiffusionStepProtocol):
    """
    Single-step explicit Euler update for diffusion sampling.

    The denoised prediction is converted to a velocity with ``to_velocity``
    and the sample is moved along it by the gap between two consecutive
    schedule entries: ``sample + velocity * (sigma_next - sigma)``.
    """

    def step(
        self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int, **_kwargs
    ) -> torch.Tensor:
        """Move ``sample`` from ``sigmas[step_index]`` to ``sigmas[step_index + 1]``.

        Args:
            sample: Current sample.
            denoised_sample: Denoised prediction passed to ``to_velocity``.
            sigmas: Schedule tensor indexed at ``step_index`` and ``step_index + 1``.
            step_index: Index of the current schedule entry.
            **_kwargs: Accepted and ignored.

        Returns:
            The updated sample, cast back to the dtype of ``sample``.
        """
        current_sigma, target_sigma = sigmas[step_index], sigmas[step_index + 1]
        step_size = target_sigma - current_sigma
        flow = to_velocity(sample, current_sigma, denoised_sample)

        # The update itself is accumulated in float32; only the result is
        # converted back to the input dtype.
        advanced = sample.to(torch.float32) + flow.to(torch.float32) * step_size
        return advanced.to(sample.dtype)
| class Res2sDiffusionStep(DiffusionStepProtocol): | |
| """ | |
| Second-order diffusion step for res_2s sampling with SDE noise injection. | |
| Used by the res_2s denoising loop. Advances the sample from the current | |
| sigma to the next by mixing a deterministic update (from the denoised | |
| prediction) with injected noise via ``get_sde_coeff``, producing | |
| variance-preserving transitions. | |
| """ | |
| def get_sde_coeff( | |
| sigma_next: torch.Tensor, | |
| sigma_up: torch.Tensor | None = None, | |
| sigma_down: torch.Tensor | None = None, | |
| sigma_max: torch.Tensor | None = None, | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """ | |
| Compute SDE coefficients (alpha_ratio, sigma_down, sigma_up) for the step. | |
| Given either ``sigma_down`` or ``sigma_up``, returns the mixing | |
| coefficients used for variance-preserving noise injection. If | |
| ``sigma_up`` is provided, ``sigma_down`` and ``alpha_ratio`` are | |
| derived; if ``sigma_down`` is provided, ``sigma_up`` and | |
| ``alpha_ratio`` are derived. | |
| """ | |
| if sigma_down is not None: | |
| alpha_ratio = (1 - sigma_next) / (1 - sigma_down) | |
| sigma_up = (sigma_next**2 - sigma_down**2 * alpha_ratio**2).clamp(min=0) ** 0.5 | |
| elif sigma_up is not None: | |
| # Fallback to avoid sqrt(neg_num) | |
| sigma_up.clamp_(max=sigma_next * 0.9999) | |
| sigmax = sigma_max if sigma_max is not None else torch.ones_like(sigma_next) | |
| sigma_signal = sigmax - sigma_next | |
| sigma_residual = (sigma_next**2 - sigma_up**2).clamp(min=0) ** 0.5 | |
| alpha_ratio = sigma_signal + sigma_residual | |
| sigma_down = sigma_residual / alpha_ratio | |
| else: | |
| alpha_ratio = torch.ones_like(sigma_next) | |
| sigma_down = sigma_next | |
| sigma_up = torch.zeros_like(sigma_next) | |
| sigma_up = torch.nan_to_num(sigma_up if sigma_up is not None else torch.zeros_like(sigma_next), 0.0) | |
| # Replace NaNs in sigma_down with corresponding sigma_next elements (float32) | |
| nan_mask = torch.isnan(sigma_down) | |
| sigma_down[nan_mask] = sigma_next[nan_mask].to(sigma_down.dtype) | |
| alpha_ratio = torch.nan_to_num(alpha_ratio, 1.0) | |
| return alpha_ratio, sigma_down, sigma_up | |
| def step( | |
| self, | |
| sample: torch.Tensor, | |
| denoised_sample: torch.Tensor, | |
| sigmas: torch.Tensor, | |
| step_index: int, | |
| noise: torch.Tensor, | |
| eta: float = 0.5, | |
| ) -> torch.Tensor: | |
| """Advance one step with SDE noise injection via get_sde_coeff. | |
| Args: | |
| sample: Current noisy sample. | |
| denoised_sample: Denoised prediction from the model. | |
| sigmas: Noise schedule tensor. | |
| step_index: Current step index in the schedule. | |
| noise: Random noise tensor for stochastic injection. | |
| eta: Controls stochastic noise injection strength (0=deterministic, 1=maximum). Default 0.5. | |
| Returns: | |
| Next sample with SDE noise injection applied. | |
| """ | |
| sigma = sigmas[step_index] | |
| sigma_next = sigmas[step_index + 1] | |
| alpha_ratio, sigma_down, sigma_up = self.get_sde_coeff(sigma_next, sigma_up=sigma_next * eta) | |
| output_dtype = denoised_sample.dtype | |
| if torch.any(sigma_up == 0) or torch.any(sigma_next == 0): | |
| return denoised_sample | |
| # Extract epsilon prediction | |
| eps_next = (sample - denoised_sample) / (sigma - sigma_next) | |
| denoised_next = sample - sigma * eps_next | |
| # Mix deterministic and stochastic components | |
| x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise | |
| return x_noised.to(output_dtype) | |