Spaces: Running on Zero
| import logging | |
| from dataclasses import replace | |
| from functools import partial | |
| from typing import Callable | |
| import torch | |
| from tqdm import tqdm | |
| from ltx_core.components.diffusion_steps import Res2sDiffusionStep | |
| from ltx_core.components.protocols import DiffusionStepProtocol | |
| from ltx_core.model.transformer import X0Model | |
| from ltx_core.utils import to_denoised, to_velocity | |
| from ltx_pipelines.utils.helpers import post_process_latent, timesteps_from_mask | |
| from ltx_pipelines.utils.res2s import get_res2s_coefficients | |
| from ltx_pipelines.utils.types import Denoiser, LatentState | |
# Module-scoped logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
| def _step_state( | |
| state: LatentState | None, | |
| denoised: torch.Tensor | None, | |
| stepper: DiffusionStepProtocol, | |
| sigmas: torch.Tensor, | |
| step_idx: int, | |
| ) -> LatentState | None: | |
| """Advance one diffusion step for a single modality, or return ``None`` if absent.""" | |
| if state is None or denoised is None: | |
| return state | |
| denoised = post_process_latent(denoised, state.denoise_mask, state.clean_latent) | |
| return replace(state, latent=stepper.step(state.latent, denoised, sigmas, step_idx)) | |
def euler_denoising_loop(
    sigmas: torch.Tensor,
    video_state: LatentState | None,
    audio_state: LatentState | None,
    stepper: DiffusionStepProtocol,
    transformer: X0Model,
    denoiser: Denoiser,
) -> tuple[LatentState | None, LatentState | None]:
    """
    Run the joint audio-video denoising loop over a diffusion schedule.
    Either ``video_state`` or ``audio_state`` may be ``None``; an absent
    modality is passed through unchanged.
    ### Parameters
    sigmas:
        1D tensor of noise levels defining the sampling schedule. One step is
        taken per consecutive pair, i.e. ``len(sigmas) - 1`` steps in total.
    video_state:
        Current video :class:`LatentState`, or ``None`` when video is absent.
    audio_state:
        Current audio :class:`LatentState`, or ``None`` when audio is absent.
    stepper:
        A :class:`DiffusionStepProtocol` implementation. It receives the
        current latent, its (post-processed) denoised estimate, the whole
        ``sigmas`` schedule and the current step index.
    transformer:
        Diffusion model forwarded to ``denoiser`` on every step.
    denoiser:
        A :class:`Denoiser` callable, invoked as
        ``denoiser(transformer, video_state, audio_state, sigmas, step_index)``.
        It must return ``(denoised_video, denoised_audio)``.
    ### Returns
    tuple[LatentState | None, LatentState | None]
        The ``(video_state, audio_state)`` pair after the last step.
    """
    num_steps = len(sigmas) - 1
    for step_index in tqdm(range(num_steps)):
        video_prediction, audio_prediction = denoiser(transformer, video_state, audio_state, sigmas, step_index)
        video_state = _step_state(video_state, video_prediction, stepper, sigmas, step_index)
        audio_state = _step_state(audio_state, audio_prediction, stepper, sigmas, step_index)
    return video_state, audio_state
def _heun_step_state(
    state: LatentState | None,
    state_pred: LatentState | None,
    denoised_1: torch.Tensor | None,
    denoised_2: torch.Tensor | None,
    sigma_curr: torch.Tensor,
    sigma_next: torch.Tensor,
    dt: torch.Tensor,
) -> LatentState | None:
    """Heun corrector for one modality: advance ``state`` by the mean of the start/end velocities.

    ``state_pred`` is the Euler predictor result at ``sigma_next``.
    ``denoised_1`` and ``denoised_2`` are the raw model predictions at
    ``state`` and at ``state_pred``. When the corrector cannot be formed
    (a prediction is missing), the predictor result is returned instead, so
    the modality still advances by one step.
    """
    if state is None or state_pred is None:
        return state
    if denoised_1 is None or denoised_2 is None:
        # No corrector evaluation available: keep the Euler predictor result rather than
        # dropping the step. If denoised_1 is None, _step_state already returned `state`
        # itself as the prediction, so this leaves the state untouched in that case.
        return state_pred
    d1 = post_process_latent(denoised_1, state.denoise_mask, state.clean_latent)
    d2 = post_process_latent(denoised_2, state.denoise_mask, state.clean_latent)
    v1 = to_velocity(state.latent, sigma_curr, d1)
    v2 = to_velocity(state_pred.latent, sigma_next, d2)
    v_avg = 0.5 * (v1 + v2)
    # Accumulate in float32 to limit precision loss for low-precision latents, then cast back.
    new_latent = (state.latent.to(torch.float32) + v_avg.to(torch.float32) * dt).to(state.latent.dtype)
    return replace(state, latent=new_latent)


def heun_denoising_loop(
    sigmas: torch.Tensor,
    video_state: LatentState | None,
    audio_state: LatentState | None,
    stepper: DiffusionStepProtocol,
    transformer: X0Model,
    denoiser: Denoiser,
) -> tuple[LatentState | None, LatentState | None]:
    """Heun (2nd-order predictor-corrector) variant of :func:`euler_denoising_loop`.

    Port of the ``jkass_quality`` sampler from JK-AceStep-Nodes. It uses about
    2x model calls per step. The last step (or any step whose next sigma is 0)
    falls back to the plain ``stepper`` update.

    Parameters and return value are the same as for :func:`euler_denoising_loop`.
    The corrector combines velocities with an explicit ``x + v_avg * dt`` update,
    where ``dt = sigmas[i + 1] - sigmas[i]``. Only the predictor goes through
    ``stepper``.
    """
    n = len(sigmas) - 1
    for step_idx in tqdm(range(n)):
        sigma_curr = sigmas[step_idx]
        sigma_next = sigmas[step_idx + 1]
        dt = sigma_next - sigma_curr

        # Predictor: ordinary stepper update from the current point.
        denoised_video_1, denoised_audio_1 = denoiser(
            transformer, video_state, audio_state, sigmas, step_idx,
        )
        video_pred = _step_state(video_state, denoised_video_1, stepper, sigmas, step_idx)
        audio_pred = _step_state(audio_state, denoised_audio_1, stepper, sigmas, step_idx)

        # No corrector on the final step or when landing on sigma 0. A velocity at
        # sigma 0 is presumably undefined (to_velocity is not shown here — verify).
        if step_idx == n - 1 or float(sigma_next) == 0.0:
            video_state, audio_state = video_pred, audio_pred
            continue

        # Corrector: evaluate the model at the predicted point, i.e. at sigmas[step_idx + 1].
        denoised_video_2, denoised_audio_2 = denoiser(
            transformer, video_pred, audio_pred, sigmas, step_idx + 1,
        )
        video_state = _heun_step_state(
            video_state, video_pred, denoised_video_1, denoised_video_2, sigma_curr, sigma_next, dt
        )
        audio_state = _heun_step_state(
            audio_state, audio_pred, denoised_audio_1, denoised_audio_2, sigma_curr, sigma_next, dt
        )
    return (video_state, audio_state)
def gradient_estimating_euler_denoising_loop(
    sigmas: torch.Tensor,
    video_state: LatentState | None,
    audio_state: LatentState | None,
    stepper: DiffusionStepProtocol,
    transformer: X0Model,
    denoiser: Denoiser,
    ge_gamma: float = 2.0,
) -> tuple[LatentState | None, LatentState | None]:
    """
    Joint audio-video denoising loop with gradient-estimation (GE) velocity correction.
    Takes the same arguments as :func:`euler_denoising_loop`, plus ``ge_gamma``.

    On every step except the first, the model velocity is replaced by
    ``v_prev + ge_gamma * (v_curr - v_prev)``. This is converted back to a
    denoised estimate before calling ``stepper.step``. The velocity carried
    to the next step is the *uncorrected* ``v_curr``. When the next sigma is
    exactly 0, each present modality's latent is set to its post-processed
    denoised prediction and the loop returns immediately.
    ### Parameters
    ge_gamma:
        Gradient-estimation coefficient scaling the velocity correction term.
        Default is 2.0. Paper: https://openreview.net/pdf?id=o2ND9v0CeK
    ### Returns
    tuple[LatentState | None, LatentState | None]
        The final ``(video_state, audio_state)``, as in :func:`euler_denoising_loop`.
    """

    def _masked(state, prediction):
        # Re-apply the conditioning mask to a raw prediction when both pieces exist.
        if state is None or prediction is None:
            return prediction
        return post_process_latent(prediction, state.denoise_mask, state.clean_latent)

    def _corrected(noisy, prediction, sigma, last_velocity):
        # Returns (uncorrected current velocity, possibly GE-corrected denoised estimate).
        velocity = to_velocity(noisy, sigma, prediction)
        if last_velocity is None:
            return velocity, prediction
        blended = ge_gamma * (velocity - last_velocity) + last_velocity
        return velocity, to_denoised(noisy, blended, sigma)

    def _advance(state, prediction, last_velocity, step_index):
        # One GE-corrected stepper update for a single modality.
        if state is None or prediction is None:
            return state, last_velocity
        last_velocity, prediction = _corrected(state.latent, prediction, sigmas[step_index], last_velocity)
        stepped = stepper.step(state.latent, prediction, sigmas, step_index)
        return replace(state, latent=stepped), last_velocity

    def _snap_to(state, prediction):
        # Final step: adopt the denoised prediction as the latent.
        if state is None or prediction is None:
            return state
        return replace(state, latent=prediction)

    video_velocity = None
    audio_velocity = None
    for step_index in tqdm(range(len(sigmas) - 1)):
        raw_video, raw_audio = denoiser(transformer, video_state, audio_state, sigmas, step_index)
        video_x0 = _masked(video_state, raw_video)
        audio_x0 = _masked(audio_state, raw_audio)

        if sigmas[step_index + 1] == 0:
            video_state = _snap_to(video_state, video_x0)
            audio_state = _snap_to(audio_state, audio_x0)
            return video_state, audio_state

        video_state, video_velocity = _advance(video_state, video_x0, video_velocity, step_index)
        audio_state, audio_velocity = _advance(audio_state, audio_x0, audio_velocity, step_index)
    return (video_state, audio_state)
| def _channelwise_normalize(x: torch.Tensor) -> torch.Tensor: | |
| return x.sub_(x.mean(dim=(-2, -1), keepdim=True)).div_(x.std(dim=(-2, -1), keepdim=True)) | |
def _get_new_noise(x: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
    """Draw standardized float64 Gaussian noise with the same shape as ``x``.

    Only ``x.shape`` is used. The noise is sampled from ``generator`` on
    ``generator.device``, which is not necessarily ``x.device``. It is first
    standardized globally (zero mean, unit std over all elements), then again
    over its last two dimensions via :func:`_channelwise_normalize`.
    """
    raw = torch.randn(x.shape, generator=generator, dtype=torch.float64, device=generator.device)
    mu, sd = raw.mean(), raw.std()
    return _channelwise_normalize((raw - mu) / sd)
def _inject_sde_noise(
    state: LatentState,
    sample: torch.Tensor,
    denoised_sample: torch.Tensor,
    step_noise_generator: torch.Generator,
    new_noise_fn: Callable[[torch.Tensor, torch.Generator], torch.Tensor],
    stepper: DiffusionStepProtocol,
    sigmas: torch.Tensor,
    step_idx: int,
    legacy_mode: bool = False,
    eta: float = 0.5,
) -> torch.Tensor:
    """Run one noisy ``stepper.step`` from ``sample`` towards ``denoised_sample``.

    Fresh noise is always drawn first via
    ``new_noise_fn(state.latent, step_noise_generator)``, whatever the mode,
    so the generator is advanced on every call.

    * ``legacy_mode=True``: the stepper gets ``sigmas`` and ``step_idx``
      unchanged. Its output is then re-masked with :func:`post_process_latent`
      using the state's ``denoise_mask`` and ``clean_latent``.
    * ``legacy_mode=False``: per-token sigmas for positions ``step_idx`` and
      ``step_idx + 1`` are derived from ``state.denoise_mask`` with
      :func:`timesteps_from_mask`. They are stacked into a 2-entry schedule
      and stepped with index 0; no post-processing is applied.
    """
    noise = new_noise_fn(state.latent, step_noise_generator)

    if legacy_mode:
        step_sigmas, step_index = sigmas, step_idx
    else:
        # Work on a copy so the caller's schedule tensor is never aliased downstream.
        schedule = sigmas.clone()
        current = timesteps_from_mask(state.denoise_mask.double(), schedule[step_idx].double())
        upcoming = timesteps_from_mask(state.denoise_mask.double(), schedule[step_idx + 1].double())
        step_sigmas, step_index = torch.stack([current, upcoming]), 0

    stepped = stepper.step(
        sample=sample,
        denoised_sample=denoised_sample,
        sigmas=step_sigmas,
        step_index=step_index,
        noise=noise,
        eta=eta,
    )
    if not legacy_mode:
        return stepped
    return post_process_latent(stepped, state.denoise_mask, state.clean_latent)
def res2s_audio_video_denoising_loop(  # noqa: PLR0913,PLR0915,PLR0912
    sigmas: torch.Tensor,
    video_state: LatentState | None,
    audio_state: LatentState | None,
    stepper: DiffusionStepProtocol,
    transformer: X0Model,
    denoiser: Denoiser,
    noise_seed: int = -1,
    noise_seed_substep: int | None = None,
    eta: float = 0.5,
    bongmath: bool = True,
    bongmath_max_iter: int = 100,
    new_noise_fn: Callable[[torch.Tensor, torch.Generator], torch.Tensor] = _get_new_noise,
    model_dtype: torch.dtype = torch.bfloat16,
    legacy_mode: bool = True,
) -> tuple[LatentState | None, LatentState | None]:
    """
    Joint audio-video denoising loop using the res_2s second-order sampler.
    Iterates over the diffusion schedule with a two-stage Runge-Kutta step:
    evaluates the denoiser at the current point and at a midpoint (with SDE
    noise), then combines both with RK coefficients. Supports anchor-point
    refinement (bong iteration) and optional SDE noise injection. Requires
    :class:`Res2sDiffusionStep` as ``stepper``.
    Either modality may be ``None`` (absent), but not both.
    Model calls: two per step, plus one extra final call when ``sigmas[-1] == 0``.
    ### Parameters
    sigmas:
        1D schedule of noise levels. If the last entry is exactly 0, it is
        replaced internally by ``[0.0011, 0.0]``. Every RK step then ends at a
        positive sigma (keeping the log step sizes finite), and a final extra
        denoiser call replaces the latent with its clean estimate.
    video_state:
        Current video :class:`LatentState`, or ``None`` if video is absent.
    audio_state:
        Current audio :class:`LatentState`, or ``None`` if audio is absent.
    stepper:
        Must be a :class:`Res2sDiffusionStep`. It is only used for the SDE
        noise injection (through ``_inject_sde_noise``); the RK update itself
        is computed inline.
    transformer:
        The diffusion model passed to the denoiser at each step.
    denoiser:
        Callable implementing :class:`Denoiser`.
    noise_seed:
        Seed for step-level SDE noise; substep seed defaults to ``noise_seed + 10000``.
    noise_seed_substep:
        Optional seed for substep SDE noise; if None, derived from ``noise_seed``.
    eta:
        Controls stochastic noise injection strength (0=deterministic, 1=maximum).
        Applies to main diffusion steps; substeps always use 0.5. Default 0.5.
    bongmath:
        Whether to run iterative anchor refinement (bong iteration). It only
        runs when the log step size ``h < 0.5`` and the current sigma is > 0.03.
    bongmath_max_iter:
        Number of bong refinement iterations when enabled. There is no
        early-exit convergence check, so exactly this many are run.
    new_noise_fn:
        Callable ``(latent, generator) -> noise`` for SDE injection.
    model_dtype:
        Dtype the float64 working latents are cast to when stored back into
        the states (e.g. bfloat16).
    legacy_mode:
        Forwarded to ``_inject_sde_noise``. When ``True``, the SDE step uses the
        scalar schedule and re-applies ``post_process_latent`` to its output.
        When ``False``, per-token sigmas derived from the denoise mask are used.
    ### Returns
    tuple[LatentState | None, LatentState | None]
        Final ``(video_state, audio_state)`` after the denoising loop.
    ### Raises
    ValueError
        If both states are ``None``, or if ``stepper`` is not a
        :class:`Res2sDiffusionStep`.
    """
    # Determine device from whichever state is present.
    # NOTE(review): relies on LatentState instances being truthy; an explicit
    # `is not None` check would not depend on LatentState defining no __bool__/__len__.
    present_state = video_state or audio_state
    if present_state is None:
        raise ValueError("At least one of video_state or audio_state must be provided")
    state_device = present_state.latent.device
    # Two independent generators: one for step-level noise, one for substep noise.
    # Each is shared by video and audio, so the draw order below (video before
    # audio) determines which noise each modality receives.
    if noise_seed_substep is None:
        noise_seed_substep = noise_seed + 10000  # Offset to ensure different seeds
    step_noise_generator = torch.Generator(device=state_device).manual_seed(noise_seed)
    substep_noise_generator = torch.Generator(device=state_device).manual_seed(noise_seed_substep)
    sde_noise_injecting_fn = partial(
        _inject_sde_noise, stepper=stepper, new_noise_fn=new_noise_fn, legacy_mode=legacy_mode
    )
    step_noise_injecting_fn = partial(sde_noise_injecting_fn, step_noise_generator=step_noise_generator, eta=eta)
    # substep eta is always default 0.5 for compatibility with original implementation.
    substep_noise_injecting_fn = partial(sde_noise_injecting_fn, step_noise_generator=substep_noise_generator, eta=0.5)
    if not isinstance(stepper, Res2sDiffusionStep):
        raise ValueError("stepper must be an instance of Res2sDiffusionStep")
    # Number of RK steps, counted on the caller's schedule (before the 0.0011 insertion below).
    n_full_steps = len(sigmas) - 1
    # Replace a terminal sigma of 0 by [0.0011, 0]: the RK steps then end at 0.0011, so
    # -log(sigma_next / sigma) stays finite; the trailing 0 triggers the final clean-up call.
    if sigmas[-1] == 0:
        sigmas = torch.cat([sigmas[:-1], torch.tensor([0.0011, 0.0], device=sigmas.device)], dim=0)
    # Step sizes in log-sigma space: h_i = -log(sigma_{i+1} / sigma_i), computed in float64 on CPU.
    hs = -torch.log(sigmas[1:].double().cpu() / (sigmas[:-1].double().cpu()))
    # Phi-function cache shared across iterations (filled by get_res2s_coefficients).
    phi_cache = {}
    c2 = 0.5  # Midpoint for res_2s
    for step_idx in tqdm(range(n_full_steps)):
        sigma = sigmas[step_idx].double()
        sigma_next = sigmas[step_idx + 1].double()
        # Anchor point x_0 of the RK step (float64 working copy of the current latent).
        x_anchor_video = video_state.latent.clone().double() if video_state is not None else None
        x_anchor_audio = audio_state.latent.clone().double() if audio_state is not None else None
        # ====================================================================
        # STAGE 1: Evaluate at current point
        # ====================================================================
        denoised_video_1, denoised_audio_1 = denoiser(transformer, video_state, audio_state, sigmas, step_idx)
        if video_state is not None and denoised_video_1 is not None:
            denoised_video_1 = post_process_latent(denoised_video_1, video_state.denoise_mask, video_state.clean_latent)
        if audio_state is not None and denoised_audio_1 is not None:
            denoised_audio_1 = post_process_latent(denoised_audio_1, audio_state.denoise_mask, audio_state.clean_latent)
        h = hs[step_idx].item()
        # Compute RK coefficients (pass phi_cache for caching)
        a21, b1, b2 = get_res2s_coefficients(h, phi_cache, c2)
        # Substep sigma = sigma * exp(-c2 * h); with c2 = 0.5 this is the geometric mean
        # sqrt(sigma * sigma_next). The sqrt hardcodes c2 = 0.5.
        sub_sigma = torch.sqrt(sigma * sigma_next)
        # ====================================================================
        # Compute substep x using RK coefficient a21
        # ====================================================================
        # eps_1 is the difference between the stage-1 prediction and the anchor;
        # x_mid = x_anchor + h * a21 * eps_1.
        if x_anchor_video is not None and denoised_video_1 is not None:
            eps_1_video = denoised_video_1.double() - x_anchor_video
            x_mid_video = x_anchor_video.double() + h * a21 * eps_1_video
        else:
            eps_1_video = None
            x_mid_video = None
        if x_anchor_audio is not None and denoised_audio_1 is not None:
            eps_1_audio = denoised_audio_1.double() - x_anchor_audio
            x_mid_audio = x_anchor_audio.double() + h * a21 * eps_1_audio
        else:
            eps_1_audio = None
            x_mid_audio = None
        # ====================================================================
        # SDE noise injection at substep
        # ====================================================================
        # Step from sigma to sub_sigma using the substep generator (video drawn before audio).
        if x_mid_video is not None and video_state is not None:
            x_mid_video = substep_noise_injecting_fn(
                state=video_state,
                sample=x_anchor_video,
                denoised_sample=x_mid_video,
                sigmas=torch.stack([sigma, sub_sigma]),
                step_idx=0,
            )
        if x_mid_audio is not None and audio_state is not None:
            x_mid_audio = substep_noise_injecting_fn(
                state=audio_state,
                sample=x_anchor_audio,
                denoised_sample=x_mid_audio,
                sigmas=torch.stack([sigma, sub_sigma]),
                step_idx=0,
            )
        # ====================================================================
        # ITERATIVE REFINEMENT (Bong Iteration)
        # ====================================================================
        # Holding the (noised) x_mid and the stage-1 prediction fixed, repeatedly re-solve
        # x_mid = x_anchor + h * a21 * (denoised_1 - x_anchor) for x_anchor. Only for small
        # steps (h < 0.5) at sigma > 0.03; runs a fixed number of iterations.
        if bongmath and h < 0.5 and sigma > 0.03:
            for _ in range(bongmath_max_iter):
                if x_mid_video is not None and eps_1_video is not None:
                    x_anchor_video = x_mid_video - h * a21 * eps_1_video
                    eps_1_video = denoised_video_1.double() - x_anchor_video
                if x_mid_audio is not None and eps_1_audio is not None:
                    x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio
                    eps_1_audio = denoised_audio_1.double() - x_anchor_audio
        # ====================================================================
        # STAGE 2: Evaluate at substep point (WITH NOISE)
        # ====================================================================
        mid_video_state = (
            replace(video_state, latent=x_mid_video.to(model_dtype))
            if video_state is not None and x_mid_video is not None
            else None
        )
        mid_audio_state = (
            replace(audio_state, latent=x_mid_audio.to(model_dtype))
            if audio_state is not None and x_mid_audio is not None
            else None
        )
        # The denoiser sees a one-element schedule [sub_sigma] at index 0.
        denoised_video_2, denoised_audio_2 = denoiser(
            transformer,
            video_state=mid_video_state,
            audio_state=mid_audio_state,
            sigmas=torch.stack([sub_sigma]).to(sigmas.device),
            step_index=0,
        )
        if video_state is not None and denoised_video_2 is not None:
            denoised_video_2 = post_process_latent(denoised_video_2, video_state.denoise_mask, video_state.clean_latent)
        if audio_state is not None and denoised_audio_2 is not None:
            denoised_audio_2 = post_process_latent(denoised_audio_2, audio_state.denoise_mask, audio_state.clean_latent)
        # ====================================================================
        # FINAL COMBINATION: Compute x_next using RK coefficients
        # ====================================================================
        # x_next = x_anchor + h * (b1 * eps_1 + b2 * eps_2), relative to the (possibly refined) anchor.
        if x_anchor_video is not None and eps_1_video is not None and denoised_video_2 is not None:
            eps_2_video = denoised_video_2.double() - x_anchor_video
            x_next_video = x_anchor_video + h * (b1 * eps_1_video + b2 * eps_2_video)
        else:
            x_next_video = None
        if x_anchor_audio is not None and eps_1_audio is not None and denoised_audio_2 is not None:
            eps_2_audio = denoised_audio_2.double() - x_anchor_audio
            x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
        else:
            x_next_audio = None
        # ====================================================================
        # SDE NOISE INJECTION AT STEP LEVEL
        # ====================================================================
        # Step from sigmas[step_idx] to sigmas[step_idx + 1] with the step generator and the caller's eta.
        if x_next_video is not None and video_state is not None:
            x_next_video = step_noise_injecting_fn(
                state=video_state,
                sample=x_anchor_video,
                denoised_sample=x_next_video,
                sigmas=sigmas,
                step_idx=step_idx,
            )
        if x_next_audio is not None and audio_state is not None:
            x_next_audio = step_noise_injecting_fn(
                state=audio_state,
                sample=x_anchor_audio,
                denoised_sample=x_next_audio,
                sigmas=sigmas,
                step_idx=step_idx,
            )
        # Update states (cast the float64 working latents back to model_dtype).
        if video_state is not None and x_next_video is not None:
            video_state = replace(video_state, latent=x_next_video.to(model_dtype))
        if audio_state is not None and x_next_audio is not None:
            audio_state = replace(audio_state, latent=x_next_audio.to(model_dtype))
    # Final step if we need to fully remove the noise: evaluate at sigmas[n_full_steps]
    # (0.0011 after the insertion above) and adopt the denoised prediction as the latent.
    if sigmas[-1] == 0:
        denoised_video_1, denoised_audio_1 = denoiser(transformer, video_state, audio_state, sigmas, n_full_steps)
        if video_state is not None and denoised_video_1 is not None:
            denoised_video_1 = post_process_latent(denoised_video_1, video_state.denoise_mask, video_state.clean_latent)
            video_state = replace(video_state, latent=denoised_video_1.to(model_dtype))
        if audio_state is not None and denoised_audio_1 is not None:
            denoised_audio_1 = post_process_latent(denoised_audio_1, audio_state.denoise_mask, audio_state.clean_latent)
            audio_state = replace(audio_state, latent=denoised_audio_1.to(model_dtype))
    return video_state, audio_state