import torch
import typing as tp
from tqdm import trange, tqdm
import torch.distributions as dist

from ..data.utils import create_padding_mask_from_lengths, compute_effective_seq_len_from_conditioning


def build_schedule(
    steps: int,
    sigma_max: float = 1.0,
    dist_shift = None,
    effective_seq_len: tp.Union[int, torch.Tensor, None] = None,
    fallback_seq_len: tp.Optional[int] = None,
    include_endpoint: bool = True,
    device: tp.Union[str, torch.device] = "cpu",
) -> torch.Tensor:
    """Build a timestep schedule for diffusion sampling.

    The schedule starts as a linear ramp from ``sigma_max`` down to 0 and is
    optionally warped by ``dist_shift.shift(t, seq_len)``.

    Returns a 1D tensor of shape (N,) where N = steps+1 (if include_endpoint)
    or steps (if not), OR a 2D tensor of shape (batch_size, N) when
    effective_seq_len is a tensor and dist_shift produces per-element schedules.

    Args:
        steps: Number of sampling steps.
        sigma_max: Starting noise level (1.0 for full generation, <1.0 for variations).
        dist_shift: Optional distribution shift object (FluxDistributionShift,
            DistributionShift, LogSNRShift, etc.). Applied to warp the linear schedule.
        effective_seq_len: Sequence length for dist_shift. Scalar int or tensor of
            shape (batch_size,) for per-element schedules.
        fallback_seq_len: Fallback when effective_seq_len is None (typically x.shape[-1]).
        include_endpoint: If True, schedule includes 0 as final value (RF samplers).
            If False, excludes 0 (v-diffusion DDIM).
        device: Device for the output tensor.
    """
    # Both variants share the same (steps + 1)-point grid; dropping the
    # endpoint simply removes the trailing 0.
    t = torch.linspace(sigma_max, 0, steps + 1, device=device)
    if not include_endpoint:
        t = t[:-1]

    if dist_shift is None:
        return t

    seq_len = effective_seq_len if effective_seq_len is not None else fallback_seq_len
    if isinstance(seq_len, torch.Tensor):
        # Clamp per-element sequence lengths to avoid zeros causing log/NaN issues
        seq_len = torch.clamp(seq_len, min=1)
    elif seq_len is not None:
        # Clamp scalar sequence length to at least 1
        seq_len = max(int(seq_len), 1)

    t = dist_shift.shift(t, seq_len)

    # Ensure the first timestep remains aligned with sigma_max after shifting.
    # This keeps the schedule consistent with the initialization in
    # sample_diffusion(), which mixes init_data using sigma_max.
    # The ellipsis index covers both the 1D and the batched 2D layout; an
    # empty schedule has no first entry to pin.
    if isinstance(t, torch.Tensor) and t.ndim >= 1 and t.shape[-1] > 0:
        t[..., 0] = t.new_tensor(sigma_max)

    return t


def _logsnr_to_timesteps(logsnr: torch.Tensor) -> torch.Tensor:
    """Convert logSNR values to timesteps, clamped to [1e-4, 1 - 1e-4]."""
    # Since logSNR = ln((1-t)/t), solving for t gives
    # t = 1 / (1 + exp(logsnr)) = sigmoid(-logsnr).
    # The clamp keeps t away from the exact endpoints for numerical stability.
    return torch.sigmoid(-logsnr).clamp(1e-4, 1 - 1e-4)


def sample_timesteps_logsnr(batch_size, mean_logsnr=-1.2, std_logsnr=2.0):
    """
    Sample timesteps for diffusion training by sampling logSNR values from a
    Gaussian and converting them to t.

    Args:
        batch_size (int): Number of timesteps to sample
        mean_logsnr (float): Mean of the logSNR Gaussian distribution
        std_logsnr (float): Standard deviation of the logSNR Gaussian distribution

    Returns:
        torch.Tensor: Tensor of shape (batch_size,) containing timestep values t,
            clamped to [1e-4, 1 - 1e-4]
    """
    logsnr = torch.randn(batch_size) * std_logsnr + mean_logsnr
    return _logsnr_to_timesteps(logsnr)


def sample_timesteps_logsnr_uniform(batch_size, min_logsnr=-6, max_logsnr=5.0):
    """
    Sample timesteps for diffusion training by sampling logSNR values from a
    uniform distribution and converting them to t.

    Args:
        batch_size (int): Number of timesteps to sample
        min_logsnr (float): Minimum logSNR value
        max_logsnr (float): Maximum logSNR value

    Returns:
        torch.Tensor: Tensor of shape (batch_size,) containing timestep values t,
            clamped to [1e-4, 1 - 1e-4]
    """
    logsnr = torch.rand(batch_size) * (max_logsnr - min_logsnr) + min_logsnr
    return _logsnr_to_timesteps(logsnr)


def truncated_logistic_normal_rescaled(shape, left_trunc=0.075, right_trunc=1):
    """
    Sample from a truncated logistic-normal distribution, linearly rescaled
    from [left_trunc, right_trunc] onto [0, 1].

    shape: shape of the output tensor
    left_trunc: left truncation point, fraction of probability to be discarded
    right_trunc: right truncation boundary, should be 1 (never seen at test time)
    """
    normal_dist = dist.Normal(0, 1)

    # Step 1: Draw standard-normal logits and push them through the normal CDF
    # to obtain uniform values in [0, 1].
    cdf_values = normal_dist.cdf(torch.randn(shape))

    # Step 2: Express the truncation bounds in CDF space. float() makes sure an
    # integer bound (such as the default right_trunc=1) becomes a floating
    # point tensor before logit is applied.
    lower_bound = normal_dist.cdf(torch.logit(torch.tensor(float(left_trunc))))
    upper_bound = normal_dist.cdf(torch.logit(torch.tensor(float(right_trunc))))

    # Step 3: Rescale the uniform values into the truncated CDF region
    truncated_cdf_values = lower_bound + (upper_bound - lower_bound) * cdf_values

    # Step 4: Map back to logistic-normal space using the inverse CDF
    truncated_samples = torch.sigmoid(normal_dist.icdf(truncated_cdf_values))

    # Step 5: Rescale so that left_trunc maps to 0 and right_trunc maps to 1
    return (truncated_samples - left_trunc) / (right_trunc - left_trunc)


def sample_discrete_euler(model, x, sigmas, callback=None, disable_tqdm=False, **extra_args):
    """Draws samples from a model given starting noise. Euler method

    Args:
        model: Callable invoked as model(x, t, **extra_args); its output is
            used as the velocity in the update x <- x + dt * v.
        x: Starting tensor. The per-element path and the callback broadcast
            timesteps as (batch_size, 1, 1), i.e. they expect a 3D x.
        sigmas: Pre-computed schedule tensor. Shape (steps+1,) for global schedule
            or (batch_size, steps+1) for per-element schedules.
        callback: Optional callable receiving a dict with keys
            'x', 't', 'sigma', 'i', 'denoised' before each update.
        disable_tqdm: Whether to disable the progress bar.
        **extra_args: Forwarded to the model.

    Returns:
        The tensor x after the last step of the schedule.
    """
    t = sigmas.to(x.device)
    # Per-element schedules are (batch_size, steps+1); a global schedule is (steps+1,)
    per_element_schedule = t.dim() == 2
    num_steps = t.shape[-1] - 1

    for i in tqdm(range(num_steps), disable=disable_tqdm):
        if per_element_schedule:
            t_curr_tensor = t[:, i].to(x.dtype)  # (batch_size,)
            t_prev = t[:, i + 1].to(x.dtype)  # (batch_size,)
            # Reshape for broadcasting with x: (batch_size,) -> (batch_size, 1, 1)
            dt_broadcast = (t_prev - t_curr_tensor).view(-1, 1, 1)
        else:
            t_curr = t[i]
            # Broadcast the scalar timestep to one value per batch element
            t_curr_tensor = t_curr * torch.ones((x.shape[0],), dtype=x.dtype, device=x.device)
            dt_broadcast = t[i + 1] - t_curr

        v = model(x, t_curr_tensor, **extra_args)

        if callback is not None:
            denoised = x - t_curr_tensor[:, None, None] * v
            callback({'x': x, 't': t_curr_tensor, 'sigma': t_curr_tensor, 'i': i, 'denoised': denoised})

        x = x + dt_broadcast * v

    return x


def sample_rk4(model, x, sigmas, callback=None, disable_tqdm=False, **extra_args):
    """Draws samples from a model given starting noise. 4th-order Runge-Kutta

    Args:
        model: Callable invoked as model(x, t, **extra_args), four times per step.
        x: Starting tensor.
        sigmas: Pre-computed schedule tensor of shape (steps+1,).
            Per-element schedules not supported for RK4.
        callback: Optional callable receiving a dict with keys
            'x', 't', 'sigma', 'i', 'denoised' after the first model
            evaluation of each step.
        disable_tqdm: Whether to disable the progress bar.
        **extra_args: Forwarded to the model.

    Returns:
        The tensor x after the last step of the schedule.

    Raises:
        ValueError: If sigmas is not 1-dimensional.
    """
    t = sigmas.to(x.device)
    if t.dim() != 1:
        # Iterating a (batch_size, steps+1) schedule would walk over batch
        # rows instead of time steps, so reject it explicitly.
        raise ValueError(
            "sample_rk4 only supports a global schedule of shape (steps+1,); "
            f"got sigmas with shape {tuple(t.shape)}"
        )

    # Tensor of ones used to broadcast scalar timesteps to (batch_size,)
    ts = x.new_ones([x.shape[0]])
    num_steps = max(t.shape[0] - 1, 0)

    for i in trange(num_steps, disable=disable_tqdm):
        t_curr = t[i]
        t_prev = t[i + 1]
        dt = t_prev - t_curr  # negative: we integrate from high t towards 0
        t_mid_tensor = (t_curr + dt / 2) * ts

        k1 = model(x, t_curr * ts, **extra_args)

        if callback is not None:
            denoised = x - t_curr * k1
            callback({'x': x, 't': t_curr, 'sigma': t_curr, 'i': i, 'denoised': denoised})

        k2 = model(x + dt / 2 * k1, t_mid_tensor, **extra_args)
        k3 = model(x + dt / 2 * k2, t_mid_tensor, **extra_args)

        # Clamp t_prev to avoid evaluating model at exactly t=0
        # (models aren't trained at t=0 and may return garbage/NaN)
        t_prev_eval = t_prev.clamp(min=1e-5)
        k4 = model(x + dt * k3, t_prev_eval * ts, **extra_args)

        x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

    return x


def sample_flow_dpmpp(model, x, sigmas, callback=None, disable_tqdm=False, **extra_args):
    """Draws samples from a model given starting noise. DPM-Solver++ for RF models

    The first step, and any step whose target time is 0, use a first-order
    update; all other steps apply a second-order correction based on the
    previous step's denoised estimate.

    Args:
        model: Callable invoked as model(x, t, **extra_args).
        x: Starting tensor. Per-element schedules are broadcast as
            (batch_size, 1, 1), i.e. they expect a 3D x.
        sigmas: Pre-computed schedule tensor. Shape (steps+1,) for global schedule
            or (batch_size, steps+1) for per-element schedules.
        callback: Optional callable receiving a dict with keys
            'x', 'i', 't', 'sigma', 'denoised' after each model evaluation.
        disable_tqdm: Whether to disable the progress bar.
        **extra_args: Forwarded to the model.

    Returns:
        The tensor x after the last step, in the dtype of the input x.
    """
    t = sigmas.to(x.device)
    # Per-element schedules are (batch_size, steps+1); a global schedule is (steps+1,)
    per_element_schedule = t.dim() == 2
    num_steps = t.shape[-1] - 1
    # Remember the working dtype: the schedule arithmetic below stays in the
    # schedule's own dtype (so the 1e-10 clamps do not underflow in half
    # precision), while x is cast back after every update so it is not
    # silently promoted.
    x_dtype = x.dtype
    old_denoised = None

    def log_snr(tt):
        # Clamp to avoid numerical issues with log(0) and division by zero,
        # which would otherwise produce inf/-inf values and propagate NaNs.
        return ((1 - tt).clamp(min=1e-10) / tt.clamp(min=1e-10)).log()

    for i in trange(num_steps, disable=disable_tqdm):
        if per_element_schedule:
            t_curr = t[:, i]  # (batch_size,)
            t_next = t[:, i + 1]  # (batch_size,)
            t_prev = t[:, i - 1] if i > 0 else None
            # Reshape for broadcasting with x: (batch_size,) -> (batch_size, 1, 1)
            t_curr_broadcast = t_curr.view(-1, 1, 1)
            t_next_broadcast = t_next.view(-1, 1, 1)
            t_curr_tensor = t_curr  # already (batch_size,)
        else:
            t_curr = t[i]
            t_next = t[i + 1]
            t_prev = t[i - 1] if i > 0 else None
            t_curr_broadcast = t_curr
            t_next_broadcast = t_next
            t_curr_tensor = t_curr.expand(x.shape[0])

        # Feed the model timesteps in the same dtype as x, matching the
        # Euler and ping-pong samplers.
        model_output = model(x, t_curr_tensor.to(x_dtype), **extra_args)
        denoised = x - t_curr_broadcast * model_output

        if callback is not None:
            callback({'x': x, 'i': i, 't': t_curr, 'sigma': t_curr, 'denoised': denoised})

        alpha_t = 1 - t_next_broadcast

        # For rectified flow, compute the DPM++ coefficient directly without log_snr
        # to avoid numerical issues at t=0 or t=1:
        #   (-h).expm1() = (t_next - t_curr) / [(1 - t_next) * t_curr]
        # Note: t_next < t_curr, so this is negative.
        dt = t_next_broadcast - t_curr_broadcast
        # Clamp to avoid division by zero when t_curr or t_next are at boundaries
        dpmpp_coeff = dt / ((1 - t_next_broadcast).clamp(min=1e-10) * t_curr_broadcast.clamp(min=1e-10))
        x_scale = t_next_broadcast / t_curr_broadcast.clamp(min=1e-10)

        is_first_step = old_denoised is None
        # For per-element schedules every element must have reached 0
        is_last_step = bool((t_next == 0).all())

        if is_first_step or is_last_step:
            # First-order update using the directly computed coefficient
            denoised_d = denoised
        else:
            # Second-order update: extrapolate the denoised estimate using the
            # previous step's estimate.
            t_prev_broadcast = t_prev.view(-1, 1, 1) if per_element_schedule else t_prev
            # r = h_last / h in log-SNR space, with
            # h = log_snr(t_next) - log_snr(t_curr) and
            # h_last = log_snr(t_curr) - log_snr(t_prev)
            h = log_snr(t_next_broadcast) - log_snr(t_curr_broadcast)
            h_last = log_snr(t_curr_broadcast) - log_snr(t_prev_broadcast)
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised

        x = (x_scale * x - alpha_t * dpmpp_coeff * denoised_d).to(x_dtype)
        old_denoised = denoised

    return x


def sample_flow_pingpong(model, x, sigmas, callback=None, disable_tqdm=False, **extra_args):
    """Draws samples from a model given starting noise. Ping-pong sampling for distilled models

    Each step estimates the clean sample from the current state and then
    re-noises that estimate to the next schedule value with fresh Gaussian
    noise.

    Args:
        model: Callable invoked as model(x, t, **extra_args).
        x: Starting tensor. Per-element schedules are broadcast as
            (batch_size, 1, 1), i.e. they expect a 3D x.
        sigmas: Pre-computed schedule tensor. Shape (steps+1,) for global schedule
            or (batch_size, steps+1) for per-element schedules.
        callback: Optional callable receiving a dict with keys
            'x', 'i', 't', 'sigma', 'sigma_hat', 'denoised' after each
            model evaluation.
        disable_tqdm: Whether to disable the progress bar.
        **extra_args: Forwarded to the model.

    Returns:
        The tensor x after the last step of the schedule.
    """
    t = sigmas.to(x.device)
    # Per-element schedules are (batch_size, steps+1); a global schedule is (steps+1,)
    per_element_schedule = t.dim() == 2
    num_steps = t.shape[-1] - 1

    for i in trange(num_steps, disable=disable_tqdm):
        if per_element_schedule:
            t_curr = t[:, i].to(x.dtype)  # (batch_size,)
            t_next = t[:, i + 1].to(x.dtype)  # (batch_size,)
            # Reshape for broadcasting with x: (batch_size,) -> (batch_size, 1, 1)
            t_curr_broadcast = t_curr.view(-1, 1, 1)
            t_next_broadcast = t_next.view(-1, 1, 1)
            t_curr_tensor = t_curr  # already (batch_size,)
        else:
            t_curr = t[i].to(x.dtype)
            t_next = t[i + 1].to(x.dtype)
            t_curr_broadcast = t_curr
            t_next_broadcast = t_next
            # Broadcast the scalar timestep to one value per batch element
            t_curr_tensor = t_curr * torch.ones((x.shape[0],), dtype=x.dtype, device=x.device)

        denoised = x - t_curr_broadcast * model(x, t_curr_tensor, **extra_args)

        if callback is not None:
            callback({'x': x, 'i': i, 't': t_curr, 'sigma': t_curr, 'sigma_hat': t_curr, 'denoised': denoised})

        x = (1 - t_next_broadcast) * denoised + t_next_broadcast * torch.randn_like(x)

    return x


@torch.no_grad()
def sample_diffusion(
    model,
    noise: torch.Tensor,
    cond_inputs: dict,
    diffusion_objective: str,
    steps: int,
    cfg_scale: float = 1.0,
    # Varlen support
    conditioning: tp.Optional[tp.List[dict]] = None,
    sample_rate: int = 44100,
    pretransform = None,
    mask_padding_attention: bool = False,
    use_effective_length_for_schedule: bool = False,
    headroom_seconds: float = 5.0,
    padding_mask: tp.Optional[torch.Tensor] = None,
    # Timestep schedule
    dist_shift = None,
    # Sampler options
    sampler_type: str = None,
    batch_cfg: bool = True,
    rescale_cfg: bool = False,
    # CFG options
    apg_scale: float = 1.0,
    # Init data (variation / img2img)
    init_data: tp.Optional[torch.Tensor] = None,
    init_noise_level: float = 1.0,
    # Other
    callback = None,
    disable_tqdm: bool = False,
    decode: bool = True,
    chunked_decode: tp.Optional[bool] = None,
    **sampler_kwargs
) -> torch.Tensor:
    """
    Unified sampling function for diffusion models.

    Handles varlen support (padding_mask + effective_seq_len), timestep
    scheduling, init_data for variation/img2img, sampler dispatch and optional
    decoding through the pretransform.

    Args:
        model: The diffusion model backbone (model.model, not the wrapper)
        noise: Initial noise tensor of shape (B, C, T)
        cond_inputs: Pre-processed conditioning inputs dict (merged positive + negative)
        diffusion_objective: "rectified_flow" or "rf_denoiser". Any other value
            raises ValueError in this function.
        steps: Number of sampling steps
        cfg_scale: Classifier-free guidance scale
        conditioning: List of conditioning dicts (for computing varlen from seconds_total)
        sample_rate: Audio sample rate
        pretransform: Optional pretransform for decoding latents and computing downsampling_ratio
        mask_padding_attention: Whether to create padding_mask for attention
        use_effective_length_for_schedule: Whether to use effective_seq_len for dist_shift
        headroom_seconds: Extra seconds beyond seconds_total for valid region
        padding_mask: Optional pre-computed padding mask (B, T). If provided, skips internal
            mask computation. Use this to ensure consistency with training masks.
        dist_shift: Distribution shift object for warping the timestep schedule, or None
        sampler_type: One of "euler", "rk4", "dpmpp", "pingpong". Defaults to
            "pingpong" for "rf_denoiser" and "euler" otherwise. "rk4" rejects
            per-element (2D) schedules.
        batch_cfg: Whether to use batched CFG
        rescale_cfg: Whether to use rescaled CFG
        apg_scale: APG (Adaptive Projected Guidance) scale. 1.0 = full APG, 0.0 = vanilla CFG
        init_data: Optional pre-encoded latent tensor for variation/img2img (shape: B, C, T)
        init_noise_level: Noise level (sigma_max) when using init_data. 1.0 = full noise (no variation).
        callback: Optional callback for progress reporting
        disable_tqdm: Whether to disable progress bar
        decode: Whether to decode latents using pretransform
        chunked_decode: Forwarded as ``chunked=`` to ``pretransform.decode``
        **sampler_kwargs: Additional kwargs passed to sampler

    Returns:
        Generated samples (decoded audio if decode=True and a pretransform is
        given, else latents)

    Raises:
        ValueError: If diffusion_objective or sampler_type is not supported.
    """
    if diffusion_objective not in ("rectified_flow", "rf_denoiser"):
        raise ValueError(f"Unknown diffusion_objective: {diffusion_objective}")

    device = noise.device
    latent_seq_len = noise.shape[-1]

    # Compute downsampling ratio
    downsampling_ratio = pretransform.downsampling_ratio if pretransform is not None else 1

    # Default sampler_type per objective
    if sampler_type is None:
        sampler_type = "pingpong" if diffusion_objective == "rf_denoiser" else "euler"

    # Compute effective_seq_len for dist_shift if enabled
    effective_seq_len = None
    if use_effective_length_for_schedule and conditioning is not None:
        effective_seq_len = compute_effective_seq_len_from_conditioning(
            conditioning, sample_rate, downsampling_ratio, device
        )

    # Create padding_mask for attention if enabled (skip if pre-computed mask provided)
    if padding_mask is None and mask_padding_attention and conditioning is not None:
        raw_effective_len = compute_effective_seq_len_from_conditioning(
            conditioning, sample_rate, downsampling_ratio, device
        )
        if raw_effective_len is not None:
            headroom_tokens = int(headroom_seconds * sample_rate / downsampling_ratio)
            valid_lengths = (raw_effective_len + headroom_tokens).clamp(max=latent_seq_len).long()
            padding_mask = create_padding_mask_from_lengths(valid_lengths, latent_seq_len)

    # Determine sigma_max for schedule
    sigma_max = init_noise_level if init_data is not None else 1.0

    # Mix init_data with noise for variation/img2img
    if init_data is not None:
        noise = init_data * (1 - sigma_max) + noise * sigma_max

    # Build common sampler kwargs (conditioning + model-level params only).
    # disable_tqdm and callback are passed explicitly to samplers that use them,
    # not included here, to avoid leaking into model forward() calls.
    common_kwargs = {
        **cond_inputs,
        "cfg_scale": cfg_scale,
        "batch_cfg": batch_cfg,
        "rescale_cfg": rescale_cfg,
        "padding_mask": padding_mask,
        "apg_scale": apg_scale,
        **sampler_kwargs
    }

    # Remove v-diffusion-specific kwargs that don't apply to RF
    for v_only_key in ("sigma_min", "sigma_max", "rho"):
        common_kwargs.pop(v_only_key, None)

    # Build schedule
    sigmas = build_schedule(
        steps=steps,
        sigma_max=sigma_max,
        dist_shift=dist_shift,
        effective_seq_len=effective_seq_len,
        fallback_seq_len=latent_seq_len,
        include_endpoint=True,
        device=device
    )

    # Route to sampler
    rf_samplers = {
        "euler": sample_discrete_euler,
        "rk4": sample_rk4,
        "dpmpp": sample_flow_dpmpp,
        "pingpong": sample_flow_pingpong,
    }
    sampler_fn = rf_samplers.get(sampler_type)
    if sampler_fn is None:
        raise ValueError(f"Unknown sampler_type for {diffusion_objective}: {sampler_type}")

    sampled = sampler_fn(
        model, noise, sigmas=sigmas, callback=callback, disable_tqdm=disable_tqdm, **common_kwargs
    )

    # Decode if requested
    if decode and pretransform is not None:
        sampled = sampled.to(next(pretransform.parameters()).dtype)
        sampled = pretransform.decode(sampled, chunked=chunked_decode)

        # Zero out audio beyond valid region (padding positions decode to garbage)
        if padding_mask is not None:
            # Expand the latent-resolution mask to decoded-sample resolution
            audio_mask = padding_mask.unsqueeze(1).repeat_interleave(downsampling_ratio, dim=-1)
            # Trim or pad to match sampled length
            length_gap = sampled.shape[-1] - audio_mask.shape[-1]
            if length_gap < 0:
                audio_mask = audio_mask[..., :sampled.shape[-1]]
            elif length_gap > 0:
                audio_mask = torch.nn.functional.pad(audio_mask, (0, length_gap), value=False)
            sampled = sampled * audio_mask.to(sampled.dtype)

    return sampled