import logging
from dataclasses import replace
from functools import partial
from typing import Callable

import torch
from tqdm import tqdm

from ltx_core.components.diffusion_steps import Res2sDiffusionStep
from ltx_core.components.protocols import DiffusionStepProtocol
from ltx_core.model.transformer import X0Model
from ltx_core.utils import to_denoised, to_velocity
from ltx_pipelines.utils.helpers import post_process_latent, timesteps_from_mask
from ltx_pipelines.utils.res2s import get_res2s_coefficients
from ltx_pipelines.utils.types import Denoiser, LatentState

logger = logging.getLogger(__name__)

# res_2s works with log-space step sizes h = -log(sigma_next / sigma), which is infinite when
# sigma_next == 0. A terminal 0 in the schedule is therefore preceded by this small sigma so the
# last RK step has a finite h; the final jump to 0 is done with a plain denoiser call.
_RES2S_MIN_SIGMA = 0.0011


def _step_state(
    state: LatentState | None,
    denoised: torch.Tensor | None,
    stepper: DiffusionStepProtocol,
    sigmas: torch.Tensor,
    step_idx: int,
) -> LatentState | None:
    """Advance one diffusion step for a single modality, or return ``state`` unchanged if absent.

    The denoised estimate is first passed through ``post_process_latent`` with the state's
    ``denoise_mask`` / ``clean_latent``, then ``stepper.step`` produces the new latent.
    """
    if state is None or denoised is None:
        return state
    denoised = post_process_latent(denoised, state.denoise_mask, state.clean_latent)
    return replace(state, latent=stepper.step(state.latent, denoised, sigmas, step_idx))


def euler_denoising_loop(
    sigmas: torch.Tensor,
    video_state: LatentState | None,
    audio_state: LatentState | None,
    stepper: DiffusionStepProtocol,
    transformer: X0Model,
    denoiser: Denoiser,
) -> tuple[LatentState | None, LatentState | None]:
    """
    Perform the joint audio-video denoising loop over a diffusion schedule.

    Either ``video_state`` or ``audio_state`` may be ``None`` for absent modalities;
    the absent modality is passed through unchanged.

    ### Parameters
    sigmas:
        A 1D tensor of noise levels (diffusion sigmas) defining the sampling schedule.
        All steps except the last element are iterated over.
    video_state:
        The current video :class:`LatentState`, or ``None`` if video is absent.
    audio_state:
        The current audio :class:`LatentState`, or ``None`` if audio is absent.
    stepper:
        An implementation of :class:`DiffusionStepProtocol` that updates a latent
        given the current latent, its denoised estimate, the full ``sigmas``
        schedule, and the current step index.
    transformer:
        The diffusion model passed to the denoiser at each step.
    denoiser:
        A callable implementing :class:`Denoiser`. It is invoked as
        ``denoiser(transformer, video_state, audio_state, sigmas, step_index)``
        and must return ``(denoised_video, denoised_audio)``.

    ### Returns
    tuple[LatentState | None, LatentState | None]
        Final ``(video_state, audio_state)`` after the denoising loop.
    """
    for step_idx, _ in enumerate(tqdm(sigmas[:-1])):
        denoised_video, denoised_audio = denoiser(transformer, video_state, audio_state, sigmas, step_idx)

        video_state = _step_state(video_state, denoised_video, stepper, sigmas, step_idx)
        audio_state = _step_state(audio_state, denoised_audio, stepper, sigmas, step_idx)

    return (video_state, audio_state)


def _heun_correct_state(
    state: LatentState | None,
    state_pred: LatentState | None,
    denoised_1: torch.Tensor | None,
    denoised_2: torch.Tensor | None,
    sigma_curr: torch.Tensor,
    sigma_next: torch.Tensor,
) -> LatentState | None:
    """Apply the Heun corrector to one modality, or return ``state`` unchanged if absent.

    Averages the velocity at the current point (from ``denoised_1``) with the velocity at the
    predicted point ``state_pred`` (from ``denoised_2``), both obtained via ``to_velocity``, and
    steps ``state.latent`` by ``sigma_next - sigma_curr``. The update is done in float32 and cast
    back to the latent's original dtype. This corrector does not call the configured stepper.
    """
    if state is None or state_pred is None or denoised_1 is None or denoised_2 is None:
        return state
    d1 = post_process_latent(denoised_1, state.denoise_mask, state.clean_latent)
    d2 = post_process_latent(denoised_2, state.denoise_mask, state.clean_latent)
    v1 = to_velocity(state.latent, sigma_curr, d1)
    v2 = to_velocity(state_pred.latent, sigma_next, d2)
    v_avg = 0.5 * (v1 + v2)
    dt = sigma_next - sigma_curr
    new_lat = (state.latent.to(torch.float32) + v_avg.to(torch.float32) * dt).to(state.latent.dtype)
    return replace(state, latent=new_lat)


def heun_denoising_loop(
    sigmas: torch.Tensor,
    video_state: LatentState | None,
    audio_state: LatentState | None,
    stepper: DiffusionStepProtocol,
    transformer: X0Model,
    denoiser: Denoiser,
) -> tuple[LatentState | None, LatentState | None]:
    """Heun (2nd-order predictor-corrector) variant of :func:`euler_denoising_loop`.

    Port of the ``jkass_quality`` sampler from JK-AceStep-Nodes. Each step first takes a
    predictor step with ``stepper``. It then re-evaluates the denoiser at the predicted latent
    (schedule index ``step_idx + 1``) and averages the two velocities. This costs ~2x model calls
    per step. On the last step, and on any step whose next sigma is 0, only the predictor step is
    applied (plain Euler when ``stepper`` is an Euler step).

    Parameters and return value are the same as :func:`euler_denoising_loop`.
    """
    n_steps = len(sigmas) - 1
    for step_idx in tqdm(range(n_steps)):
        sigma_curr = sigmas[step_idx]
        sigma_next = sigmas[step_idx + 1]

        # Predictor: a regular step with the configured stepper.
        denoised_video_1, denoised_audio_1 = denoiser(
            transformer,
            video_state,
            audio_state,
            sigmas,
            step_idx,
        )
        video_pred = _step_state(video_state, denoised_video_1, stepper, sigmas, step_idx)
        audio_pred = _step_state(audio_state, denoised_audio_1, stepper, sigmas, step_idx)

        # No corrector on the final step or when stepping to sigma == 0.
        if step_idx == n_steps - 1 or float(sigma_next) == 0.0:
            video_state, audio_state = video_pred, audio_pred
            continue

        # Corrector: evaluate the model at the predicted point.
        denoised_video_2, denoised_audio_2 = denoiser(
            transformer,
            video_pred,
            audio_pred,
            sigmas,
            step_idx + 1,
        )

        video_state = _heun_correct_state(
            video_state, video_pred, denoised_video_1, denoised_video_2, sigma_curr, sigma_next
        )
        audio_state = _heun_correct_state(
            audio_state, audio_pred, denoised_audio_1, denoised_audio_2, sigma_curr, sigma_next
        )

    return (video_state, audio_state)


def gradient_estimating_euler_denoising_loop(
    sigmas: torch.Tensor,
    video_state: LatentState | None,
    audio_state: LatentState | None,
    stepper: DiffusionStepProtocol,
    transformer: X0Model,
    denoiser: Denoiser,
    ge_gamma: float = 2.0,
) -> tuple[LatentState | None, LatentState | None]:
    """
    Perform the joint audio-video denoising loop using gradient-estimation sampling.

    Same interface as :func:`euler_denoising_loop` with an additional ``ge_gamma``
    parameter for velocity correction. From the second step on, the denoised estimate
    is recomputed from the corrected velocity
    ``ge_gamma * (v_current - v_previous) + v_previous`` before stepping. If the next sigma
    is 0, the loop stops and the latents are set directly to the (post-processed) denoised
    estimates.

    ### Parameters
    ge_gamma:
        Gradient estimation coefficient controlling the velocity correction term.
        Default is 2.0.
        Paper: https://openreview.net/pdf?id=o2ND9v0CeK

    ### Returns
    tuple[LatentState | None, LatentState | None]
        See :func:`euler_denoising_loop` for return value description.
    """

    previous_audio_velocity = None
    previous_video_velocity = None

    def update_velocity_and_sample(
        noisy_sample: torch.Tensor,
        denoised_sample: torch.Tensor,
        sigma: torch.Tensor | float,
        previous_velocity: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(current_velocity, denoised)``; the denoised sample is corrected only when a
        previous velocity exists."""
        current_velocity = to_velocity(noisy_sample, sigma, denoised_sample)
        if previous_velocity is not None:
            delta_v = current_velocity - previous_velocity
            total_velocity = ge_gamma * delta_v + previous_velocity
            denoised_sample = to_denoised(noisy_sample, total_velocity, sigma)
        return current_velocity, denoised_sample

    for step_idx, _ in enumerate(tqdm(sigmas[:-1])):
        denoised_video, denoised_audio = denoiser(transformer, video_state, audio_state, sigmas, step_idx)

        if video_state is not None and denoised_video is not None:
            denoised_video = post_process_latent(denoised_video, video_state.denoise_mask, video_state.clean_latent)
        if audio_state is not None and denoised_audio is not None:
            denoised_audio = post_process_latent(denoised_audio, audio_state.denoise_mask, audio_state.clean_latent)

        # Final jump to sigma == 0: take the denoised estimate as the result.
        if sigmas[step_idx + 1] == 0:
            if video_state is not None and denoised_video is not None:
                video_state = replace(video_state, latent=denoised_video)
            if audio_state is not None and denoised_audio is not None:
                audio_state = replace(audio_state, latent=denoised_audio)
            return video_state, audio_state

        if video_state is not None and denoised_video is not None:
            previous_video_velocity, denoised_video = update_velocity_and_sample(
                video_state.latent, denoised_video, sigmas[step_idx], previous_video_velocity
            )
            video_state = replace(
                video_state, latent=stepper.step(video_state.latent, denoised_video, sigmas, step_idx)
            )
        if audio_state is not None and denoised_audio is not None:
            previous_audio_velocity, denoised_audio = update_velocity_and_sample(
                audio_state.latent, denoised_audio, sigmas[step_idx], previous_audio_velocity
            )
            audio_state = replace(
                audio_state, latent=stepper.step(audio_state.latent, denoised_audio, sigmas, step_idx)
            )

    return (video_state, audio_state)


def _channelwise_normalize(x: torch.Tensor) -> torch.Tensor:
    """Normalize ``x`` IN PLACE to zero mean / unit std over its last two dims; returns ``x``."""
    return x.sub_(x.mean(dim=(-2, -1), keepdim=True)).div_(x.std(dim=(-2, -1), keepdim=True))


def _get_new_noise(x: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
    """Draw float64 Gaussian noise shaped like ``x`` (only its shape is used) on ``generator``'s
    device, standardize it globally, then normalize it with :func:`_channelwise_normalize`."""
    noise = torch.randn(x.shape, generator=generator, dtype=torch.float64, device=generator.device)
    noise = (noise - noise.mean()) / noise.std()
    return _channelwise_normalize(noise)


def _inject_sde_noise(
    state: LatentState,
    sample: torch.Tensor,
    denoised_sample: torch.Tensor,
    step_noise_generator: torch.Generator,
    new_noise_fn: Callable[[torch.Tensor, torch.Generator], torch.Tensor],
    stepper: DiffusionStepProtocol,
    sigmas: torch.Tensor,
    step_idx: int,
    legacy_mode: bool = False,
    eta: float = 0.5,
) -> torch.Tensor:
    """Take one ``stepper.step`` from ``sample`` toward ``denoised_sample`` with fresh SDE noise.

    Noise is drawn with ``new_noise_fn(state.latent, step_noise_generator)``.

    - Non-legacy mode: the schedule passed to the stepper is rebuilt as a 2-entry stack of
      per-token timesteps, via ``timesteps_from_mask(state.denoise_mask, sigma)``, for
      ``sigmas[step_idx]`` and ``sigmas[step_idx + 1]``. The stepper is then called with index 0.
    - Legacy mode: the given ``sigmas``/``step_idx`` are used as-is, and the result is passed
      through ``post_process_latent`` with the state's mask and clean latent.
    """
    sigmas_copy = sigmas.clone()
    new_noise = new_noise_fn(state.latent, step_noise_generator)
    if not legacy_mode:
        timesteps = timesteps_from_mask(state.denoise_mask.double(), sigmas_copy[step_idx].double())
        next_timesteps = timesteps_from_mask(state.denoise_mask.double(), sigmas_copy[step_idx + 1].double())
        sigmas = torch.stack([timesteps, next_timesteps])
        step_idx = 0

    x_next = stepper.step(
        sample=sample,
        denoised_sample=denoised_sample,
        sigmas=sigmas,
        step_index=step_idx,
        noise=new_noise,
        eta=eta,
    )
    if legacy_mode:
        x_next = post_process_latent(x_next, state.denoise_mask, state.clean_latent)
    return x_next


def res2s_audio_video_denoising_loop(  # noqa: PLR0913,PLR0915,PLR0912
    sigmas: torch.Tensor,
    video_state: LatentState | None,
    audio_state: LatentState | None,
    stepper: DiffusionStepProtocol,
    transformer: X0Model,
    denoiser: Denoiser,
    noise_seed: int = -1,
    noise_seed_substep: int | None = None,
    eta: float = 0.5,
    bongmath: bool = True,
    bongmath_max_iter: int = 100,
    new_noise_fn: Callable[[torch.Tensor, torch.Generator], torch.Tensor] = _get_new_noise,
    model_dtype: torch.dtype = torch.bfloat16,
    legacy_mode: bool = True,
) -> tuple[LatentState | None, LatentState | None]:
    """
    Joint audio-video denoising loop using the res_2s second-order sampler.

    Iterates over the diffusion schedule with a two-stage Runge-Kutta step. It evaluates the
    denoiser at the current point and at a midpoint (with SDE noise), then combines both with
    RK coefficients. It supports anchor-point refinement (bong iteration) and optional SDE noise
    injection. Requires :class:`Res2sDiffusionStep` as ``stepper``. Either modality may be
    ``None`` (absent), but not both.

    ### Parameters
    sigmas:
        1D sigma schedule. If its last entry is 0, it is replaced internally by
        ``[0.0011, 0.0]``. The RK steps then end at 0.0011, and a final denoiser call at that
        sigma sets the latents to the denoised estimates.
    video_state:
        The current video :class:`LatentState`, or ``None`` if video is absent.
    audio_state:
        The current audio :class:`LatentState`, or ``None`` if audio is absent.
    stepper:
        Must be a :class:`Res2sDiffusionStep`; used for SDE noise injection.
    transformer:
        The diffusion model passed to the denoiser at each step.
    denoiser:
        Callable implementing :class:`Denoiser`.
    noise_seed:
        Seed for step-level SDE noise; substep seed defaults to ``noise_seed + 10000``.
    noise_seed_substep:
        Optional seed for substep SDE noise; if None, derived from ``noise_seed``.
    eta:
        Controls stochastic noise injection strength (0=deterministic, 1=maximum).
        Applies to main diffusion steps; substeps always use 0.5. Default 0.5.
    bongmath:
        Whether to run iterative anchor refinement (bong iteration). It only runs when the
        step size ``h < 0.5`` and ``sigma > 0.03``.
    bongmath_max_iter:
        Number of bong refinement iterations when enabled. All iterations are run; there is
        no early stop.
    new_noise_fn:
        Callable ``(latent, generator) -> noise`` for SDE injection.
    model_dtype:
        Dtype for latent state updates (e.g. bfloat16).
    legacy_mode:
        If True, SDE noise steps use the global sigma schedule and post-process the result with
        ``post_process_latent``. If False, they use per-token timesteps derived from each state's
        ``denoise_mask``.

    ### Returns
    tuple[LatentState | None, LatentState | None]
        Final ``(video_state, audio_state)`` after the denoising loop.

    ### Raises
    ValueError
        If both states are ``None`` or ``stepper`` is not a :class:`Res2sDiffusionStep`.
    """
    # Determine device from whichever state is present (explicit None test, not truthiness).
    present_state = video_state if video_state is not None else audio_state
    if present_state is None:
        raise ValueError("At least one of video_state or audio_state must be provided")
    # Validate before allocating generators or building any helpers.
    if not isinstance(stepper, Res2sDiffusionStep):
        raise ValueError("stepper must be an instance of Res2sDiffusionStep")
    state_device = present_state.latent.device

    # Initialize noise generators with different seeds
    if noise_seed_substep is None:
        noise_seed_substep = noise_seed + 10000  # Offset to ensure different seeds
    step_noise_generator = torch.Generator(device=state_device).manual_seed(noise_seed)
    substep_noise_generator = torch.Generator(device=state_device).manual_seed(noise_seed_substep)

    sde_noise_injecting_fn = partial(
        _inject_sde_noise, stepper=stepper, new_noise_fn=new_noise_fn, legacy_mode=legacy_mode
    )
    step_noise_injecting_fn = partial(sde_noise_injecting_fn, step_noise_generator=step_noise_generator, eta=eta)
    # substep eta is always default 0.5 for compatibility with original implementation.
    substep_noise_injecting_fn = partial(sde_noise_injecting_fn, step_noise_generator=substep_noise_generator, eta=0.5)

    # Counted on the ORIGINAL schedule: with a terminal 0, the RK steps end at _RES2S_MIN_SIGMA.
    n_full_steps = len(sigmas) - 1

    # inject minimal sigma value so the last RK step has a finite log-space step size
    if sigmas[-1] == 0:
        sigmas = torch.cat([sigmas[:-1], torch.tensor([_RES2S_MIN_SIGMA, 0.0], device=sigmas.device)], dim=0)

    # Compute step sizes in hyperbolic space
    hs = -torch.log(sigmas[1:].double().cpu() / (sigmas[:-1].double().cpu()))

    # Initialize phi cache for reuse across loop iterations
    phi_cache = {}
    c2 = 0.5  # Midpoint for res_2s

    for step_idx in tqdm(range(n_full_steps)):
        sigma = sigmas[step_idx].double()
        sigma_next = sigmas[step_idx + 1].double()

        # Initialize anchor point
        x_anchor_video = video_state.latent.clone().double() if video_state is not None else None
        x_anchor_audio = audio_state.latent.clone().double() if audio_state is not None else None

        # ====================================================================
        # STAGE 1: Evaluate at current point
        # ====================================================================
        denoised_video_1, denoised_audio_1 = denoiser(transformer, video_state, audio_state, sigmas, step_idx)
        if video_state is not None and denoised_video_1 is not None:
            denoised_video_1 = post_process_latent(denoised_video_1, video_state.denoise_mask, video_state.clean_latent)
        if audio_state is not None and denoised_audio_1 is not None:
            denoised_audio_1 = post_process_latent(denoised_audio_1, audio_state.denoise_mask, audio_state.clean_latent)

        h = hs[step_idx].item()

        # Compute RK coefficients (pass phi_cache for caching)
        a21, b1, b2 = get_res2s_coefficients(h, phi_cache, c2)

        # Compute substep sigma, sqrt is a hardcode for c2 = 0.5
        sub_sigma = torch.sqrt(sigma * sigma_next)

        # ====================================================================
        # Compute substep x using RK coefficient a21
        # ====================================================================
        if x_anchor_video is not None and denoised_video_1 is not None:
            eps_1_video = denoised_video_1.double() - x_anchor_video
            x_mid_video = x_anchor_video.double() + h * a21 * eps_1_video
        else:
            eps_1_video = None
            x_mid_video = None

        if x_anchor_audio is not None and denoised_audio_1 is not None:
            eps_1_audio = denoised_audio_1.double() - x_anchor_audio
            x_mid_audio = x_anchor_audio.double() + h * a21 * eps_1_audio
        else:
            eps_1_audio = None
            x_mid_audio = None

        # ====================================================================
        # SDE noise injection at substep
        # ====================================================================
        if x_mid_video is not None and video_state is not None:
            x_mid_video = substep_noise_injecting_fn(
                state=video_state,
                sample=x_anchor_video,
                denoised_sample=x_mid_video,
                sigmas=torch.stack([sigma, sub_sigma]),
                step_idx=0,
            )
        if x_mid_audio is not None and audio_state is not None:
            x_mid_audio = substep_noise_injecting_fn(
                state=audio_state,
                sample=x_anchor_audio,
                denoised_sample=x_mid_audio,
                sigmas=torch.stack([sigma, sub_sigma]),
                step_idx=0,
            )

        # ====================================================================
        # ITERATIVE REFINEMENT (Bong Iteration)
        # Re-derive the anchor from the (noised) midpoint: repeatedly solve
        # x_mid = x_anchor + h * a21 * (denoised_1 - x_anchor) for x_anchor.
        # ====================================================================
        if bongmath and h < 0.5 and sigma > 0.03:
            for _ in range(bongmath_max_iter):
                if x_mid_video is not None and eps_1_video is not None:
                    x_anchor_video = x_mid_video - h * a21 * eps_1_video
                    eps_1_video = denoised_video_1.double() - x_anchor_video
                if x_mid_audio is not None and eps_1_audio is not None:
                    x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio
                    eps_1_audio = denoised_audio_1.double() - x_anchor_audio

        # ====================================================================
        # STAGE 2: Evaluate at substep point (WITH NOISE)
        # ====================================================================
        mid_video_state = (
            replace(video_state, latent=x_mid_video.to(model_dtype))
            if video_state is not None and x_mid_video is not None
            else None
        )
        mid_audio_state = (
            replace(audio_state, latent=x_mid_audio.to(model_dtype))
            if audio_state is not None and x_mid_audio is not None
            else None
        )

        denoised_video_2, denoised_audio_2 = denoiser(
            transformer,
            video_state=mid_video_state,
            audio_state=mid_audio_state,
            sigmas=torch.stack([sub_sigma]).to(sigmas.device),
            step_index=0,
        )
        if video_state is not None and denoised_video_2 is not None:
            denoised_video_2 = post_process_latent(denoised_video_2, video_state.denoise_mask, video_state.clean_latent)
        if audio_state is not None and denoised_audio_2 is not None:
            denoised_audio_2 = post_process_latent(denoised_audio_2, audio_state.denoise_mask, audio_state.clean_latent)

        # ====================================================================
        # FINAL COMBINATION: Compute x_next using RK coefficients
        # ====================================================================
        if x_anchor_video is not None and eps_1_video is not None and denoised_video_2 is not None:
            eps_2_video = denoised_video_2.double() - x_anchor_video
            x_next_video = x_anchor_video + h * (b1 * eps_1_video + b2 * eps_2_video)
        else:
            x_next_video = None

        if x_anchor_audio is not None and eps_1_audio is not None and denoised_audio_2 is not None:
            eps_2_audio = denoised_audio_2.double() - x_anchor_audio
            x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
        else:
            x_next_audio = None

        # ====================================================================
        # SDE NOISE INJECTION AT STEP LEVEL
        # ====================================================================
        if x_next_video is not None and video_state is not None:
            x_next_video = step_noise_injecting_fn(
                state=video_state,
                sample=x_anchor_video,
                denoised_sample=x_next_video,
                sigmas=sigmas,
                step_idx=step_idx,
            )
        if x_next_audio is not None and audio_state is not None:
            x_next_audio = step_noise_injecting_fn(
                state=audio_state,
                sample=x_anchor_audio,
                denoised_sample=x_next_audio,
                sigmas=sigmas,
                step_idx=step_idx,
            )

        # Update states
        if video_state is not None and x_next_video is not None:
            video_state = replace(video_state, latent=x_next_video.to(model_dtype))
        if audio_state is not None and x_next_audio is not None:
            audio_state = replace(audio_state, latent=x_next_audio.to(model_dtype))

    # Final step if we need to fully remove the noise: denoise at sigmas[n_full_steps]
    # (== _RES2S_MIN_SIGMA) and take the denoised estimate as the result.
    if sigmas[-1] == 0:
        denoised_video_1, denoised_audio_1 = denoiser(transformer, video_state, audio_state, sigmas, n_full_steps)
        if video_state is not None and denoised_video_1 is not None:
            denoised_video_1 = post_process_latent(denoised_video_1, video_state.denoise_mask, video_state.clean_latent)
            video_state = replace(video_state, latent=denoised_video_1.to(model_dtype))
        if audio_state is not None and denoised_audio_1 is not None:
            denoised_audio_1 = post_process_latent(denoised_audio_1, audio_state.denoise_mask, audio_state.clean_latent)
            audio_state = replace(audio_state, latent=denoised_audio_1.to(model_dtype))

    return video_state, audio_state