# Dramabox / ltx2 / ltx_pipelines / utils / samplers.py
# Manmay — DramaBox Space: initial app + vendored ltx2 (08c5e28, verified)
import logging
from dataclasses import replace
from functools import partial
from typing import Callable
import torch
from tqdm import tqdm
from ltx_core.components.diffusion_steps import Res2sDiffusionStep
from ltx_core.components.protocols import DiffusionStepProtocol
from ltx_core.model.transformer import X0Model
from ltx_core.utils import to_denoised, to_velocity
from ltx_pipelines.utils.helpers import post_process_latent, timesteps_from_mask
from ltx_pipelines.utils.res2s import get_res2s_coefficients
from ltx_pipelines.utils.types import Denoiser, LatentState
logger = logging.getLogger(__name__)
def _step_state(
    state: LatentState | None,
    denoised: torch.Tensor | None,
    stepper: DiffusionStepProtocol,
    sigmas: torch.Tensor,
    step_idx: int,
) -> LatentState | None:
    """Advance a single modality by one diffusion step.

    The denoised estimate is first passed through ``post_process_latent`` with
    the state's ``denoise_mask`` and ``clean_latent``. ``stepper.step`` then
    produces the next latent, which is stored in a copy of ``state``.

    If either ``state`` or ``denoised`` is ``None`` (absent modality), ``state``
    is returned unchanged.
    """
    if state is not None and denoised is not None:
        cleaned = post_process_latent(denoised, state.denoise_mask, state.clean_latent)
        next_latent = stepper.step(state.latent, cleaned, sigmas, step_idx)
        return replace(state, latent=next_latent)
    return state
def euler_denoising_loop(
    sigmas: torch.Tensor,
    video_state: LatentState | None,
    audio_state: LatentState | None,
    stepper: DiffusionStepProtocol,
    transformer: X0Model,
    denoiser: Denoiser,
) -> tuple[LatentState | None, LatentState | None]:
    """
    Run the joint audio-video denoising loop over a diffusion schedule.

    Each iteration makes one denoiser call, then advances both modalities with
    ``stepper``. A modality given as ``None`` stays ``None`` throughout.

    ### Parameters
    sigmas:
        1D tensor of noise levels (diffusion sigmas). One step is taken per
        element except the last.
    video_state:
        Video :class:`LatentState` to denoise, or ``None`` when there is no video.
    audio_state:
        Audio :class:`LatentState` to denoise, or ``None`` when there is no audio.
    stepper:
        A :class:`DiffusionStepProtocol` implementation. Its ``step`` receives the
        current latent, the denoised estimate, the full ``sigmas`` schedule and the
        step index, and returns the next latent.
    transformer:
        Diffusion model handed to ``denoiser`` on every step.
    denoiser:
        A :class:`Denoiser` callable. It is called as
        ``denoiser(transformer, video_state, audio_state, sigmas, step_index)``
        and returns ``(denoised_video, denoised_audio)``.

    ### Returns
    tuple[LatentState | None, LatentState | None]
        The ``(video_state, audio_state)`` pair after the final step.
    """
    progress = tqdm(sigmas[:-1])
    for step_idx, _sigma in enumerate(progress):
        video_x0, audio_x0 = denoiser(transformer, video_state, audio_state, sigmas, step_idx)
        # Tuple elements evaluate left to right: video is stepped before audio.
        video_state, audio_state = (
            _step_state(video_state, video_x0, stepper, sigmas, step_idx),
            _step_state(audio_state, audio_x0, stepper, sigmas, step_idx),
        )
    return video_state, audio_state
def _heun_correct(
    state: LatentState | None,
    state_pred: LatentState | None,
    denoised_1: torch.Tensor | None,
    denoised_2: torch.Tensor | None,
    sigma_curr: torch.Tensor,
    sigma_next: torch.Tensor,
) -> LatentState | None:
    """Apply the Heun corrector to one modality.

    ``state`` is the latent at ``sigma_curr``, and ``denoised_1`` is its
    denoised estimate. ``state_pred`` is the predictor result at
    ``sigma_next``, and ``denoised_2`` is the denoiser's estimate for it. The
    returned state starts from ``state.latent`` and moves by the mean of the
    two velocities times ``sigma_next - sigma_curr``.

    If there is no first estimate, ``state`` is returned unchanged (absent
    modality). If there is no corrector estimate, the predictor result is
    returned, so the modality still advances instead of staying at
    ``sigma_curr``.
    """
    if state is None or denoised_1 is None:
        return state
    if denoised_2 is None or state_pred is None:
        # Corrector unavailable: keep the predictor step rather than discarding it.
        return state_pred if state_pred is not None else state
    d1 = post_process_latent(denoised_1, state.denoise_mask, state.clean_latent)
    d2 = post_process_latent(denoised_2, state.denoise_mask, state.clean_latent)
    v1 = to_velocity(state.latent, sigma_curr, d1)
    v2 = to_velocity(state_pred.latent, sigma_next, d2)
    v_avg = 0.5 * (v1 + v2)
    dt = sigma_next - sigma_curr
    # Do the update in float32, then cast back to the latent's own dtype.
    new_latent = (state.latent.to(torch.float32) + v_avg.to(torch.float32) * dt).to(state.latent.dtype)
    return replace(state, latent=new_latent)


def heun_denoising_loop(
    sigmas: torch.Tensor,
    video_state: LatentState | None,
    audio_state: LatentState | None,
    stepper: DiffusionStepProtocol,
    transformer: X0Model,
    denoiser: Denoiser,
) -> tuple[LatentState | None, LatentState | None]:
    """Heun (2nd-order predictor-corrector) variant of :func:`euler_denoising_loop`.

    Port of the ``jkass_quality`` sampler from JK-AceStep-Nodes. Each step runs
    in two stages:

    1. Predictor: one ordinary ``stepper`` step, using the denoiser evaluated at
       ``sigmas[step_idx]``.
    2. Corrector: the denoiser is evaluated again on the predicted latents at
       ``sigmas[step_idx + 1]``. The original latent then advances with the
       average of the two velocities.

    This costs ~2x model calls per step. The last step, and any step whose next
    sigma is 0, keep only the predictor result. A modality whose corrector
    estimate is missing also keeps its predictor result for that step.

    Parameters and return value are the same as :func:`euler_denoising_loop`.
    """
    n_steps = len(sigmas) - 1
    for step_idx in tqdm(range(n_steps)):
        sigma_curr = sigmas[step_idx]
        sigma_next = sigmas[step_idx + 1]
        # Predictor stage.
        denoised_video_1, denoised_audio_1 = denoiser(
            transformer, video_state, audio_state, sigmas, step_idx,
        )
        video_pred = _step_state(video_state, denoised_video_1, stepper, sigmas, step_idx)
        audio_pred = _step_state(audio_state, denoised_audio_1, stepper, sigmas, step_idx)
        if step_idx == n_steps - 1 or float(sigma_next) == 0.0:
            # No corrector on the final step / when landing on sigma 0.
            video_state, audio_state = video_pred, audio_pred
            continue
        # Corrector stage: re-evaluate at the predicted point (next sigma).
        denoised_video_2, denoised_audio_2 = denoiser(
            transformer, video_pred, audio_pred, sigmas, step_idx + 1,
        )
        video_state = _heun_correct(
            video_state, video_pred, denoised_video_1, denoised_video_2, sigma_curr, sigma_next
        )
        audio_state = _heun_correct(
            audio_state, audio_pred, denoised_audio_1, denoised_audio_2, sigma_curr, sigma_next
        )
    return (video_state, audio_state)
def gradient_estimating_euler_denoising_loop(
    sigmas: torch.Tensor,
    video_state: LatentState | None,
    audio_state: LatentState | None,
    stepper: DiffusionStepProtocol,
    transformer: X0Model,
    denoiser: Denoiser,
    ge_gamma: float = 2.0,
) -> tuple[LatentState | None, LatentState | None]:
    """
    Joint audio-video denoising loop with gradient-estimation velocity correction.

    The interface matches :func:`euler_denoising_loop`, plus ``ge_gamma``. On
    every step after the first, the velocity implied by the denoised estimate
    is extrapolated from the previous step's velocity:
    ``v = v_prev + ge_gamma * (v_curr - v_prev)``. The corrected denoised
    sample derived from ``v`` is what goes to ``stepper``. With
    ``ge_gamma == 1`` this reduces to the plain velocity. When the next sigma
    is 0, the post-processed denoised estimate becomes the final latent and the
    loop returns immediately.

    ### Parameters
    ge_gamma:
        Gradient-estimation coefficient scaling the velocity difference.
        Default 2.0. Paper: https://openreview.net/pdf?id=o2ND9v0CeK

    ### Returns
    tuple[LatentState | None, LatentState | None]
        Same as :func:`euler_denoising_loop`.
    """

    def _ge_correct(
        latent: torch.Tensor, x0: torch.Tensor, sigma: float, prev_velocity: torch.Tensor | None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Returns the *uncorrected* velocity (carried to the next step) and the
        # possibly-corrected denoised sample.
        velocity = to_velocity(latent, sigma, x0)
        if prev_velocity is None:
            return velocity, x0
        corrected = prev_velocity + ge_gamma * (velocity - prev_velocity)
        return velocity, to_denoised(latent, corrected, sigma)

    prev_video_v: torch.Tensor | None = None
    prev_audio_v: torch.Tensor | None = None

    for step_idx, _sigma in enumerate(tqdm(sigmas[:-1])):
        x0_video, x0_audio = denoiser(transformer, video_state, audio_state, sigmas, step_idx)
        has_video = video_state is not None and x0_video is not None
        has_audio = audio_state is not None and x0_audio is not None

        if has_video:
            x0_video = post_process_latent(x0_video, video_state.denoise_mask, video_state.clean_latent)
        if has_audio:
            x0_audio = post_process_latent(x0_audio, audio_state.denoise_mask, audio_state.clean_latent)

        if sigmas[step_idx + 1] == 0:
            # Landing on sigma 0: the clean estimate is the final sample.
            if has_video:
                video_state = replace(video_state, latent=x0_video)
            if has_audio:
                audio_state = replace(audio_state, latent=x0_audio)
            return video_state, audio_state

        sigma = sigmas[step_idx]
        if has_video:
            prev_video_v, x0_video = _ge_correct(video_state.latent, x0_video, sigma, prev_video_v)
            next_video = stepper.step(video_state.latent, x0_video, sigmas, step_idx)
            video_state = replace(video_state, latent=next_video)
        if has_audio:
            prev_audio_v, x0_audio = _ge_correct(audio_state.latent, x0_audio, sigma, prev_audio_v)
            next_audio = stepper.step(audio_state.latent, x0_audio, sigmas, step_idx)
            audio_state = replace(audio_state, latent=next_audio)

    return (video_state, audio_state)
def _channelwise_normalize(x: torch.Tensor) -> torch.Tensor:
    """Standardize ``x`` in place over its last two dimensions and return it.

    The mean over dims ``(-2, -1)`` is subtracted first. The centered tensor is
    then divided by its (unbiased) std over the same dims. The input tensor
    itself is modified.
    """
    slice_dims = (-2, -1)
    x.sub_(x.mean(dim=slice_dims, keepdim=True))
    scale = x.std(dim=slice_dims, keepdim=True)
    return x.div_(scale)
def _get_new_noise(x: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
    """Draw fresh Gaussian noise with the shape of ``x``.

    Only ``x.shape`` is used. The noise is float64 and lives on
    ``generator.device``, regardless of ``x``'s dtype or device. It is first
    standardized over the whole tensor, then per slice over the last two
    dimensions via :func:`_channelwise_normalize`.
    """
    raw = torch.randn(x.shape, generator=generator, dtype=torch.float64, device=generator.device)
    standardized = raw.sub(raw.mean()).div(raw.std())
    return _channelwise_normalize(standardized)
def _inject_sde_noise(
    state: LatentState,
    sample: torch.Tensor,
    denoised_sample: torch.Tensor,
    step_noise_generator: torch.Generator,
    new_noise_fn: Callable[[torch.Tensor, torch.Generator], torch.Tensor],
    stepper: DiffusionStepProtocol,
    sigmas: torch.Tensor,
    step_idx: int,
    legacy_mode: bool = False,
    eta: float = 0.5,
) -> torch.Tensor:
    """Take one noisy ``stepper`` step from ``sample`` toward ``denoised_sample``.

    Fresh noise is drawn with ``new_noise_fn(state.latent, step_noise_generator)``.
    It is passed to ``stepper.step`` together with ``eta``.

    - ``legacy_mode=True``: the step uses ``sigmas`` / ``step_idx`` as given,
      and the result is re-masked with ``post_process_latent``.
    - ``legacy_mode=False``: a two-element schedule is built from
      ``timesteps_from_mask`` applied to ``sigmas[step_idx]`` and
      ``sigmas[step_idx + 1]`` (presumably per-token sigmas scaled by
      ``state.denoise_mask`` — confirm against that helper). The step is taken
      at index 0 of it and the result is returned without re-masking.
    """
    # Noise is drawn before anything else, consuming the generator once per call.
    noise = new_noise_fn(state.latent, step_noise_generator)
    if legacy_mode:
        step_sigmas, step_index = sigmas, step_idx
    else:
        mask = state.denoise_mask.double()
        step_sigmas = torch.stack(
            [
                timesteps_from_mask(mask, sigmas[step_idx].double()),
                timesteps_from_mask(mask, sigmas[step_idx + 1].double()),
            ]
        )
        step_index = 0
    stepped = stepper.step(
        sample=sample,
        denoised_sample=denoised_sample,
        sigmas=step_sigmas,
        step_index=step_index,
        noise=noise,
        eta=eta,
    )
    if not legacy_mode:
        return stepped
    return post_process_latent(stepped, state.denoise_mask, state.clean_latent)
def res2s_audio_video_denoising_loop(  # noqa: PLR0913,PLR0915,PLR0912
    sigmas: torch.Tensor,
    video_state: LatentState | None,
    audio_state: LatentState | None,
    stepper: DiffusionStepProtocol,
    transformer: X0Model,
    denoiser: Denoiser,
    noise_seed: int = -1,
    noise_seed_substep: int | None = None,
    eta: float = 0.5,
    bongmath: bool = True,
    bongmath_max_iter: int = 100,
    new_noise_fn: Callable[[torch.Tensor, torch.Generator], torch.Tensor] = _get_new_noise,
    model_dtype: torch.dtype = torch.bfloat16,
    legacy_mode: bool = True,
) -> tuple[LatentState | None, LatentState | None]:
    """
    Joint audio-video denoising loop using the res_2s second-order sampler.

    Iterates over the diffusion schedule with a two-stage Runge-Kutta step:
    evaluates the denoiser at the current point and at a midpoint (with SDE
    noise), then combines both with RK coefficients. Supports anchor-point
    refinement (bong iteration) and SDE noise injection. Requires
    :class:`Res2sDiffusionStep` as ``stepper``.

    Either modality may be ``None`` (absent), but not both. Working latents are
    kept in float64 and cast to ``model_dtype`` when written back into a state.

    ### Parameters
    sigmas:
        1D tensor of noise levels defining the schedule. If its last element is
        0, that element is replaced by ``[0.0011, 0.0]`` so the log step sizes
        stay finite. The loop then steps down to 0.0011, and one final denoiser
        call at that level produces the returned latents.
    video_state:
        The current video :class:`LatentState`, or ``None`` if video is absent.
    audio_state:
        The current audio :class:`LatentState`, or ``None`` if audio is absent.
    stepper:
        Must be a :class:`Res2sDiffusionStep`. It is used (via
        :func:`_inject_sde_noise`) for the SDE noise steps at both the substep
        and the full-step level.
    transformer:
        The diffusion model passed to the denoiser at each step.
    denoiser:
        Callable implementing :class:`Denoiser`. Called once per step at the
        current point and once at the substep point. If the schedule ends at 0,
        it is called once more at the end.
    noise_seed:
        Seed for step-level SDE noise; substep seed defaults to ``noise_seed + 10000``.
    noise_seed_substep:
        Optional seed for substep SDE noise; if None, derived from ``noise_seed``.
    eta:
        Controls stochastic noise injection strength (0=deterministic, 1=maximum).
        Applies to main diffusion steps; substeps always use 0.5. Default 0.5.
    bongmath:
        Whether to run iterative anchor refinement (bong iteration). Even when
        enabled, it only runs on steps where ``h < 0.5`` and ``sigma > 0.03``.
    bongmath_max_iter:
        Number of bong refinement iterations. There is no convergence check, so
        exactly this many iterations run whenever refinement is active.
    new_noise_fn:
        Callable ``(latent, generator) -> noise`` for SDE injection.
    model_dtype:
        Dtype the float64 working latents are cast to when stored in the states
        and passed to the denoiser (e.g. bfloat16).
    legacy_mode:
        Forwarded to :func:`_inject_sde_noise`.
        - True: the noise step uses the scalar schedule, and the result is
          re-masked with ``post_process_latent``.
        - False: per-token sigmas come from ``timesteps_from_mask``.

    ### Returns
    tuple[LatentState | None, LatentState | None]
        Final ``(video_state, audio_state)`` after the denoising loop.

    ### Raises
    ValueError
        If both ``video_state`` and ``audio_state`` are ``None``, or if
        ``stepper`` is not a :class:`Res2sDiffusionStep`.
    """
    # Determine device from whichever state is present
    present_state = video_state or audio_state
    if present_state is None:
        raise ValueError("At least one of video_state or audio_state must be provided")
    state_device = present_state.latent.device
    # Initialize noise generators with different seeds
    if noise_seed_substep is None:
        noise_seed_substep = noise_seed + 10000  # Offset to ensure different seeds
    step_noise_generator = torch.Generator(device=state_device).manual_seed(noise_seed)
    substep_noise_generator = torch.Generator(device=state_device).manual_seed(noise_seed_substep)
    # Shared configuration for both SDE injection points; they differ only in
    # which generator they draw from and in eta.
    sde_noise_injecting_fn = partial(
        _inject_sde_noise, stepper=stepper, new_noise_fn=new_noise_fn, legacy_mode=legacy_mode
    )
    step_noise_injecting_fn = partial(sde_noise_injecting_fn, step_noise_generator=step_noise_generator, eta=eta)
    # substep eta is always default 0.5 for compatibility with original implementation.
    substep_noise_injecting_fn = partial(sde_noise_injecting_fn, step_noise_generator=substep_noise_generator, eta=0.5)
    if not isinstance(stepper, Res2sDiffusionStep):
        raise ValueError("stepper must be an instance of Res2sDiffusionStep")
    # Counted on the ORIGINAL schedule: if the schedule is extended below, the
    # last of these steps ends at 0.0011 instead of 0, and the remaining hop to
    # 0 is done by the final denoise after the loop.
    n_full_steps = len(sigmas) - 1
    # inject minimal sigma value to avoid division by zero
    if sigmas[-1] == 0:
        sigmas = torch.cat([sigmas[:-1], torch.tensor([0.0011, 0.0], device=sigmas.device)], dim=0)
    # Step sizes in log-sigma space: h_i = log(sigma_i / sigma_{i+1}).
    hs = -torch.log(sigmas[1:].double().cpu() / (sigmas[:-1].double().cpu()))
    # Initialize phi cache for reuse across loop iterations
    phi_cache = {}
    c2 = 0.5  # Midpoint for res_2s
    for step_idx in tqdm(range(n_full_steps)):
        sigma = sigmas[step_idx].double()
        sigma_next = sigmas[step_idx + 1].double()
        # Initialize anchor point (float64 copy of the current latent)
        x_anchor_video = video_state.latent.clone().double() if video_state is not None else None
        x_anchor_audio = audio_state.latent.clone().double() if audio_state is not None else None
        # ====================================================================
        # STAGE 1: Evaluate at current point
        # ====================================================================
        denoised_video_1, denoised_audio_1 = denoiser(transformer, video_state, audio_state, sigmas, step_idx)
        if video_state is not None and denoised_video_1 is not None:
            denoised_video_1 = post_process_latent(denoised_video_1, video_state.denoise_mask, video_state.clean_latent)
        if audio_state is not None and denoised_audio_1 is not None:
            denoised_audio_1 = post_process_latent(denoised_audio_1, audio_state.denoise_mask, audio_state.clean_latent)
        h = hs[step_idx].item()
        # Compute RK coefficients (pass phi_cache for caching)
        a21, b1, b2 = get_res2s_coefficients(h, phi_cache, c2)
        # Compute substep sigma, sqrt is a hardcode for c2 = 0.5
        # (geometric mean of sigma and sigma_next = midpoint in log-sigma space)
        sub_sigma = torch.sqrt(sigma * sigma_next)
        # ====================================================================
        # Compute substep x using RK coefficient a21
        # ====================================================================
        if x_anchor_video is not None and denoised_video_1 is not None:
            eps_1_video = denoised_video_1.double() - x_anchor_video
            x_mid_video = x_anchor_video.double() + h * a21 * eps_1_video
        else:
            eps_1_video = None
            x_mid_video = None
        if x_anchor_audio is not None and denoised_audio_1 is not None:
            eps_1_audio = denoised_audio_1.double() - x_anchor_audio
            x_mid_audio = x_anchor_audio.double() + h * a21 * eps_1_audio
        else:
            eps_1_audio = None
            x_mid_audio = None
        # ====================================================================
        # SDE noise injection at substep
        # ====================================================================
        # Video draws from the substep generator before audio; keep this order
        # to preserve seeded reproducibility.
        if x_mid_video is not None and video_state is not None:
            x_mid_video = substep_noise_injecting_fn(
                state=video_state,
                sample=x_anchor_video,
                denoised_sample=x_mid_video,
                sigmas=torch.stack([sigma, sub_sigma]),
                step_idx=0,
            )
        if x_mid_audio is not None and audio_state is not None:
            x_mid_audio = substep_noise_injecting_fn(
                state=audio_state,
                sample=x_anchor_audio,
                denoised_sample=x_mid_audio,
                sigmas=torch.stack([sigma, sub_sigma]),
                step_idx=0,
            )
        # ====================================================================
        # ITERATIVE REFINEMENT (Bong Iteration)
        # ====================================================================
        # Fixed-point iteration that re-solves the anchor so that
        # x_mid == x_anchor + h * a21 * (denoised_1 - x_anchor) holds for the
        # noise-injected x_mid. Runs a fixed number of iterations (no early exit).
        if bongmath and h < 0.5 and sigma > 0.03:
            for _ in range(bongmath_max_iter):
                if x_mid_video is not None and eps_1_video is not None:
                    x_anchor_video = x_mid_video - h * a21 * eps_1_video
                    eps_1_video = denoised_video_1.double() - x_anchor_video
                if x_mid_audio is not None and eps_1_audio is not None:
                    x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio
                    eps_1_audio = denoised_audio_1.double() - x_anchor_audio
        # ====================================================================
        # STAGE 2: Evaluate at substep point (WITH NOISE)
        # ====================================================================
        mid_video_state = (
            replace(video_state, latent=x_mid_video.to(model_dtype))
            if video_state is not None and x_mid_video is not None
            else None
        )
        mid_audio_state = (
            replace(audio_state, latent=x_mid_audio.to(model_dtype))
            if audio_state is not None and x_mid_audio is not None
            else None
        )
        # The denoiser gets a one-element schedule so that index 0 is sub_sigma.
        denoised_video_2, denoised_audio_2 = denoiser(
            transformer,
            video_state=mid_video_state,
            audio_state=mid_audio_state,
            sigmas=torch.stack([sub_sigma]).to(sigmas.device),
            step_index=0,
        )
        if video_state is not None and denoised_video_2 is not None:
            denoised_video_2 = post_process_latent(denoised_video_2, video_state.denoise_mask, video_state.clean_latent)
        if audio_state is not None and denoised_audio_2 is not None:
            denoised_audio_2 = post_process_latent(denoised_audio_2, audio_state.denoise_mask, audio_state.clean_latent)
        # ====================================================================
        # FINAL COMBINATION: Compute x_next using RK coefficients
        # ====================================================================
        if x_anchor_video is not None and eps_1_video is not None and denoised_video_2 is not None:
            eps_2_video = denoised_video_2.double() - x_anchor_video
            x_next_video = x_anchor_video + h * (b1 * eps_1_video + b2 * eps_2_video)
        else:
            x_next_video = None
        if x_anchor_audio is not None and eps_1_audio is not None and denoised_audio_2 is not None:
            eps_2_audio = denoised_audio_2.double() - x_anchor_audio
            x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
        else:
            x_next_audio = None
        # ====================================================================
        # SDE NOISE INJECTION AT STEP LEVEL
        # ====================================================================
        # Uses the full (possibly extended) schedule at step_idx; video draws
        # from the step generator before audio.
        if x_next_video is not None and video_state is not None:
            x_next_video = step_noise_injecting_fn(
                state=video_state,
                sample=x_anchor_video,
                denoised_sample=x_next_video,
                sigmas=sigmas,
                step_idx=step_idx,
            )
        if x_next_audio is not None and audio_state is not None:
            x_next_audio = step_noise_injecting_fn(
                state=audio_state,
                sample=x_anchor_audio,
                denoised_sample=x_next_audio,
                sigmas=sigmas,
                step_idx=step_idx,
            )
        # Update states (a modality without a new latent keeps its old one)
        if video_state is not None and x_next_video is not None:
            video_state = replace(video_state, latent=x_next_video.to(model_dtype))
        if audio_state is not None and x_next_audio is not None:
            audio_state = replace(audio_state, latent=x_next_audio.to(model_dtype))
    # Final step if we need to fully remove the noise: the schedule was extended
    # to end at 0.0011 -> 0, so take the denoiser's clean estimate at index
    # n_full_steps (sigma 0.0011) as the output.
    if sigmas[-1] == 0:
        denoised_video_1, denoised_audio_1 = denoiser(transformer, video_state, audio_state, sigmas, n_full_steps)
        if video_state is not None and denoised_video_1 is not None:
            denoised_video_1 = post_process_latent(denoised_video_1, video_state.denoise_mask, video_state.clean_latent)
            video_state = replace(video_state, latent=denoised_video_1.to(model_dtype))
        if audio_state is not None and denoised_audio_1 is not None:
            denoised_audio_1 = post_process_latent(denoised_audio_1, audio_state.denoise_mask, audio_state.clean_latent)
            audio_state = replace(audio_state, latent=denoised_audio_1.to(model_dtype))
    return video_state, audio_state