owenisas's picture
Vendor stable-audio-3 for ZeroGPU
6215e7d verified
import math
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger, CometLogger
import os
import torch
import gc
import typing as tp
import torchaudio
from einops import rearrange
from safetensors.torch import save_file
from functools import partial
from torch.nn import functional as F
from ..interface.aeiou import audio_spectrogram_image
from ..inference.sampling import truncated_logistic_normal_rescaled, sample_timesteps_logsnr, sample_timesteps_logsnr_uniform, sample_diffusion
from ..models.diffusion import ConditionedDiffusionModelWrapper
from ..models.inpainting import random_inpaint_mask, MaskType
from ..models.lora import add_lora, get_lora_params, get_lora_state_dict, LoRAParametrization, get_lora_layers, save_lora_safetensors, resolve_adapter_type, prepare_dora_state_dict, cast_base_to_precision
from .utils import create_optimizer_from_config, create_scheduler_from_config, log_audio, log_image, log_metric, get_rank, create_augmented_padding_mask, compute_masked_loss, compute_normalized_mse, resize_padding_mask, StaggeredLogger, compute_per_elem_trim, trim_and_concat
from time import time
class Profiler:
    """Minimal wall-clock profiler that records labelled checkpoints.

    Each call to :meth:`tick` stores the current time together with a label.
    ``repr()`` renders the time elapsed between consecutive checkpoints in
    milliseconds, framed by two 80-character ``=`` rules.
    """

    def __init__(self):
        # The first entry is the unlabelled start time; the first tick is
        # measured against it.
        self.ticks = [[time(), None]]

    def tick(self, msg):
        """Record a checkpoint labelled ``msg`` at the current time."""
        self.ticks.append([time(), msg])

    def __repr__(self):
        banner = 80 * "="
        pieces = [banner + "\n"]
        # Walk consecutive (previous, current) checkpoint pairs.
        for (prev_time, _), (cur_time, label) in zip(self.ticks, self.ticks[1:]):
            delta = cur_time - prev_time
            pieces.append(label + f": {delta*1000:.2f}ms\n")
        pieces.append(banner + "\n\n\n")
        return "".join(pieces)
class DiffusionCondTrainingWrapper(pl.LightningModule):
'''
Wrapper for training a conditional audio diffusion model.
'''
def __init__(
    self,
    model: ConditionedDiffusionModelWrapper,
    lr: float = None,
    mask_loss_weight: float = 0.0,
    mask_padding_attention: bool = False,
    silence_extension_scale_seconds: float = 0.0,
    use_ema: bool = True,
    log_loss_info: bool = False,
    optimizer_configs: dict = None,
    pre_encoded: bool = False,
    cfg_dropout_prob = 0.1,
    timestep_sampler: tp.Literal["uniform", "logit_normal", "trunc_logit_normal", "log_snr", "log_snr_uniform"] = "uniform",
    timestep_sampler_options: tp.Optional[tp.Dict[str, tp.Any]] = None,
    validation_timesteps = [0.1, 0.3, 0.5, 0.7, 0.9],
    p_one_shot: float = 0.0,
    inpainting_config: dict = None,
    use_effective_length_for_schedule: bool = False,
    sample_rate: int = 44100,
    sample_size: int = None,
    loss_normalization: tp.Literal["none", "timestep", "sample", "sample_channel"] = "none",
    loss_norm_eps: float = 1e-6,
    lora_config: tp.Optional[tp.Dict[str, tp.Any]] = None,
    lora_state_dict: tp.Optional[tp.Dict[str, tp.Any]] = None,
    svd_bases_path: tp.Optional[str] = None,
    log_every_n_steps: int = 10,
    ot_coupling: bool = False,
    base_precision: tp.Optional[str] = None,
):
    """Set up the training wrapper around a conditioned diffusion model.

    Args:
        model: The diffusion wrapper to train; stored as ``self.diffusion``.
        lr: Learning rate for the default Adam optimizer. Only used when
            ``optimizer_configs`` is None.
        mask_loss_weight: Forwarded to ``compute_masked_loss`` and used to
            scale the inpainting context loss in ``training_step``.
        mask_padding_attention: Deprecated here; if set and the model does
            not already enable it, the value is propagated to the model.
        silence_extension_scale_seconds: Scale passed to
            ``create_augmented_padding_mask`` during training (active only
            when attention masking is on and the value is > 0).
        use_ema: Accepted for config compatibility. This wrapper never
            builds an EMA model (``self.diffusion_ema`` is always None).
        log_loss_info: Enables the per-sigma-bucket loss logging in
            ``training_step``.
        optimizer_configs: Optimizer/scheduler configuration keyed by
            ``"diffusion"``. Takes precedence over ``lr``.
        pre_encoded: When True, inputs are treated as already-encoded
            latents and the pretransform encoder is not run.
        cfg_dropout_prob: Forwarded to the diffusion model call in training.
        timestep_sampler: Name of the training timestep distribution.
        timestep_sampler_options: Extra options for the ``log_snr`` and
            ``log_snr_uniform`` samplers.
        validation_timesteps: Fixed timesteps evaluated in validation.
        p_one_shot: Probability of replacing a sampled timestep with 1.
        inpainting_config: Enables inpainting training when not None; its
            ``"mask_kwargs"`` entry is forwarded to ``random_inpaint_mask``.
        use_effective_length_for_schedule: Deprecated here; propagated to
            the model when the model does not already enable it.
        sample_rate: Audio sample rate, used for length computations.
        sample_size: Stored as-is; not otherwise used in this method.
        loss_normalization: Mode forwarded to ``compute_normalized_mse``.
        loss_norm_eps: Epsilon forwarded to ``compute_normalized_mse``.
        lora_config: When not None, the base model and conditioner are
            frozen and LoRA parametrizations are attached to both.
        lora_state_dict: Optional adapter weights loaded (non-strictly)
            after the LoRA layers are attached.
        svd_bases_path: Optional path to pre-computed SVD bases, loaded
            with ``torch.load`` and handed to ``add_lora``.
        log_every_n_steps: Interval passed to ``StaggeredLogger``.
        ot_coupling: Enables the minibatch noise permutation in training.
        base_precision: If truthy (LoRA mode), passed to
            ``cast_base_to_precision`` for the model and conditioner.

    Raises:
        AssertionError: If neither ``lr`` nor ``optimizer_configs`` is given.
    """
    import warnings

    super().__init__()
    self.ot_coupling = ot_coupling
    self.diffusion = model
    self.lora_config = lora_config
    if self.lora_config is not None:
        # Don't use EMA with LoRA
        use_ema = False
        # Freeze the pre-trained model weights
        self.diffusion.model.eval().requires_grad_(False)
        self.diffusion.conditioner.eval().requires_grad_(False)
        rank = self.lora_config.get("rank", 8)
        lora_alpha = self.lora_config.get("alpha", rank)
        adapter_type = self.lora_config.get("adapter_type", "lora")
        include = self.lora_config.get("include", None)
        exclude = self.lora_config.get("exclude", None)
        # Resolve legacy "dora" to rows/cols variant
        adapter_type = resolve_adapter_type(adapter_type, lora_state_dict)
        print(f"LoRA config: rank={rank}, alpha={lora_alpha}, adapter_type={adapter_type}")
        if include:
            print(f" include: {include}")
        if exclude:
            print(f" exclude: {exclude}")
        # Load pre-computed SVD bases for -XS adapter types
        svd_bases = None
        if svd_bases_path is not None:
            print(f"Loading SVD bases from {svd_bases_path}")
            svd_bases = torch.load(svd_bases_path, map_location="cpu", weights_only=True)
        elif adapter_type.endswith("-xs"):
            print("WARNING: -XS adapter without svd_bases_path — SVD will be computed per layer")
        # Per-layer-type parametrization factories. Kept under a distinct
        # name so the user-facing ``lora_config`` dict is not shadowed.
        lora_parametrizations = {
            torch.nn.Linear: {
                "weight": partial(LoRAParametrization.from_linear, rank=rank, lora_alpha=lora_alpha, adapter_type=adapter_type),
            },
            torch.nn.Conv1d: {
                "weight": partial(LoRAParametrization.from_conv1d, rank=rank, lora_alpha=lora_alpha, adapter_type=adapter_type),
            }
        }
        # Add LoRA to the model
        add_lora(self.diffusion.model, lora_parametrizations, include=include, exclude=exclude, svd_bases=svd_bases)
        # Add LoRA to the conditioner
        add_lora(self.diffusion.conditioner, lora_parametrizations, include=include, exclude=exclude, svd_bases=svd_bases)
        print("lora layers:", len(get_lora_layers(self.diffusion)))
        if lora_state_dict is not None:
            # Old DoRA checkpoints saved magnitude as 2D (1,fan_in) or (fan_out,1);
            # current code expects 1D. Squeeze so old checkpoints still load.
            prepare_dora_state_dict(lora_state_dict)
            self.diffusion.model.load_state_dict(lora_state_dict, strict=False)
            self.diffusion.conditioner.load_state_dict(lora_state_dict, strict=False)
        # Cast frozen base weights to lower precision if requested
        if base_precision:
            cast_base_to_precision(self.diffusion.model, base_precision)
            cast_base_to_precision(self.diffusion.conditioner, base_precision)
            if self.diffusion.pretransform is not None:
                # Any value other than "bf16"/"bfloat16" selects float16 here.
                self.diffusion.pretransform.to(
                    torch.bfloat16 if base_precision in ("bf16", "bfloat16") else torch.float16
                )
    self.diffusion_ema = None
    self.mask_loss_weight = mask_loss_weight
    # Attention masking for padded tokens
    # Backward compat: if passed from training config, propagate to model
    if mask_padding_attention and not self.diffusion.mask_padding_attention:
        warnings.warn("mask_padding_attention in training config is deprecated. Move to model.diffusion config.", FutureWarning)
        self.diffusion.mask_padding_attention = mask_padding_attention
    self.mask_padding_attention = self.diffusion.mask_padding_attention
    self.silence_extension_scale_seconds = silence_extension_scale_seconds
    self.cfg_dropout_prob = cfg_dropout_prob
    self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
    self.timestep_sampler = timestep_sampler
    self.timestep_sampler_options = {} if timestep_sampler_options is None else timestep_sampler_options
    if self.timestep_sampler == "log_snr":
        self.mean_logsnr = self.timestep_sampler_options.get("mean_logsnr", -1.2)
        self.std_logsnr = self.timestep_sampler_options.get("std_logsnr", 2.0)
    elif self.timestep_sampler == "log_snr_uniform":
        self.min_logsnr = self.timestep_sampler_options.get("min_logsnr", -6.0)
        self.max_logsnr = self.timestep_sampler_options.get("max_logsnr", 5.0)
    self.p_one_shot = p_one_shot
    self.diffusion_objective = model.diffusion_objective
    self.log_loss_info = log_loss_info
    self._staggered_logger = StaggeredLogger(every_n_steps=log_every_n_steps)
    assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
    if optimizer_configs is None:
        optimizer_configs = {
            "diffusion": {
                "optimizer": {
                    "type": "Adam",
                    "config": {
                        "lr": lr
                    }
                }
            }
        }
    else:
        if lr is not None:
            print("WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
    self.optimizer_configs = optimizer_configs
    self.pre_encoded = pre_encoded
    # Loss normalization by target magnitude
    # Options: "none", "timestep", "sample", "sample_channel"
    self.loss_normalization = loss_normalization
    self.loss_norm_eps = loss_norm_eps
    # Inpainting
    self.inpainting_config = inpainting_config
    # Always define the mask kwargs so code that reads
    # ``module.inpaint_mask_kwargs`` works even without an inpainting config.
    self.inpaint_mask_kwargs = {}
    if self.inpainting_config is not None:
        self.inpaint_mask_kwargs = self.inpainting_config.get("mask_kwargs", {})
    # Per-element schedule shift based on effective (unpadded) sequence length
    # Backward compat: if passed from training config, propagate to model
    if use_effective_length_for_schedule and not self.diffusion.use_effective_length_for_schedule:
        # FutureWarning (not DeprecationWarning) so the message is shown under
        # Python's default warning filters, matching the warning above.
        warnings.warn("use_effective_length_for_schedule in training config is deprecated. Move to model.diffusion config.", FutureWarning)
        self.diffusion.use_effective_length_for_schedule = use_effective_length_for_schedule
    self.use_effective_length_for_schedule = self.diffusion.use_effective_length_for_schedule
    self.sample_rate = sample_rate
    self.sample_size = sample_size
    # FSDP
    self.use_fsdp = False
    # Validation
    # Copy so the shared default list (or a caller's list) is never aliased.
    self.validation_timesteps = list(validation_timesteps)
    # One list of per-batch losses per validation timestep, keyed by metric name.
    self.validation_step_outputs = {
        f'val/loss_{validation_timestep:.1f}': []
        for validation_timestep in self.validation_timesteps
    }
def configure_optimizers(self):
    """Build the optimizer, and optionally a per-step scheduler, for training.

    The parameter set depends on the mode:
    - LoRA mode: only the LoRA parameters of the model and the conditioner.
    - ``MuonAdamW`` optimizer: ``(name, parameter)`` tuples of trainable parameters.
    - Otherwise: all parameters of ``self.diffusion`` that require gradients.

    Returns:
        ``[optimizer]`` when no scheduler is configured, otherwise
        ``([optimizer], [scheduler_config])`` with a step-interval scheduler.
    """
    section = self.optimizer_configs['diffusion']
    optimizer_cfg = section['optimizer']

    if self.lora_config is not None:
        trainable = list(get_lora_params(self.diffusion.model))
        trainable.extend(get_lora_params(self.diffusion.conditioner))
    elif optimizer_cfg.get('type') == 'MuonAdamW':
        # Named tuples let MuonAdamW match fused layer patterns by name
        trainable = [
            (name, param)
            for name, param in self.diffusion.named_parameters()
            if param.requires_grad
        ]
    else:
        # Frozen parts (pretransform, conditioner, etc.) are left out
        trainable = [param for param in self.diffusion.parameters() if param.requires_grad]

    optimizer = create_optimizer_from_config(optimizer_cfg, trainable)

    if "scheduler" not in section:
        return [optimizer]

    scheduler = create_scheduler_from_config(section['scheduler'], optimizer)
    return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
def training_step(self, batch, batch_idx):
    """Run one training step and return the scalar loss.

    Steps, in order: compute conditioning, build/resize the padding mask,
    optionally encode the input with the pretransform, sample timesteps,
    optionally shift them with the model's ``dist_shift``, mix data and
    noise according to the diffusion objective, run the model, and compute
    the (optionally normalized and masked) MSE loss.

    Args:
        batch: ``(reals, metadata)``. ``metadata`` is iterated as a sequence
            of dicts, one per batch element; ``padding_mask`` and
            ``seconds_total`` entries are used when present.
        batch_idx: Unused.

    Returns:
        The training loss tensor.

    Raises:
        ValueError: For an unknown ``timestep_sampler`` or
            ``diffusion_objective``.
    """
    reals, metadata = batch
    p = Profiler()
    # Drop an extra leading singleton dimension on 4-D input.
    if reals.ndim == 4 and reals.shape[0] == 1:
        reals = reals[0]
    diffusion_input = reals
    p.tick("setup")
    #with torch.amp.autocast(device_type="cuda"):
    conditioning = self.diffusion.conditioner(metadata, self.device)
    # Create batch tensor of padding masks from the metadata
    # If padding_mask not provided, assume all positions are valid (no padding)
    if all("padding_mask" in md for md in metadata):
        padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) # Shape (batch_size, sequence_length)
    else:
        # All-True mask: everything is signal, no padding
        padding_masks = torch.ones(diffusion_input.shape[0], diffusion_input.shape[-1], dtype=torch.bool, device=self.device)
    p.tick("conditioning")
    if self.diffusion.pretransform is not None:
        self.diffusion.pretransform.to(self.device)
        if not self.pre_encoded:
            # torch.amp.autocast("cuda") replaces the deprecated
            # torch.cuda.amp.autocast() and matches the rest of this file.
            with torch.amp.autocast("cuda"), torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
                self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad)
                diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
                p.tick("pretransform")
                padding_masks = resize_padding_mask(padding_masks, diffusion_input.shape[-1])
        else:
            # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
            if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
                diffusion_input = diffusion_input / self.diffusion.pretransform.scale
            if padding_masks.shape[-1] != diffusion_input.shape[-1]:
                padding_masks = resize_padding_mask(padding_masks, diffusion_input.shape[-1])
    if self.timestep_sampler == "uniform":
        # Draw uniformly distributed continuous timesteps
        t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
    elif self.timestep_sampler == "logit_normal":
        t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
    elif self.timestep_sampler == "trunc_logit_normal":
        # Draw from logistic truncated normal distribution
        t = truncated_logistic_normal_rescaled(reals.shape[0]).to(self.device)
        # Flip the distribution
        t = 1 - t
    elif self.timestep_sampler == "log_snr":
        t = sample_timesteps_logsnr(reals.shape[0], mean_logsnr=self.mean_logsnr, std_logsnr=self.std_logsnr).to(self.device)
    elif self.timestep_sampler == "log_snr_uniform":
        t = sample_timesteps_logsnr_uniform(reals.shape[0], min_logsnr=self.min_logsnr, max_logsnr=self.max_logsnr).to(self.device)
    else:
        raise ValueError(f"Invalid timestep_sampler: {self.timestep_sampler}")
    if self.diffusion.dist_shift is not None:
        # Compute sequence length for schedule shift
        if self.use_effective_length_for_schedule:
            # Use per-element effective lengths derived from seconds_total (rounded up)
            # This matches inference which computes effective length from seconds_total conditioning
            # Fall back to padding_masks.sum() if seconds_total is not available
            if all("seconds_total" in md for md in metadata):
                downsampling_ratio = self.diffusion.pretransform.downsampling_ratio if self.diffusion.pretransform is not None else 1
                effective_seq_len = torch.tensor(
                    [int(math.ceil(int(md["seconds_total"] * self.sample_rate) / downsampling_ratio)) for md in metadata],
                    device=self.device
                )
            else:
                # Fallback: use padding mask sum
                effective_seq_len = padding_masks.sum(dim=-1)
        else:
            # Use total sequence length (original behavior)
            effective_seq_len = diffusion_input.shape[2]
        # Shift the distribution
        t = self.diffusion.dist_shift.shift(t, effective_seq_len)
    if self.p_one_shot > 0:
        # Set t to 1 with probability p_one_shot
        t = torch.where(torch.rand_like(t) < self.p_one_shot, torch.ones_like(t), t)
    # Calculate the noise schedule parameters for those timesteps
    if self.diffusion_objective == "v":
        # Cosine schedule for the v objective, so that the "v" target below
        # has defined alphas/sigmas (previously an UnboundLocalError).
        alphas, sigmas = torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
    elif self.diffusion_objective in ["rectified_flow", "rf_denoiser"]:
        alphas, sigmas = 1-t, t
    else:
        raise ValueError(f"Unsupported diffusion_objective: {self.diffusion_objective}")
    # Combine the ground truth data and the noise
    alphas = alphas[:, None, None]
    sigmas = sigmas[:, None, None]
    noise = torch.randn_like(diffusion_input)
    # Minibatch OT coupling: find optimal noise permutation for straighter transport paths
    # Based on MelodyFlow (arXiv:2407.03648v2) Section 2.5.2
    # Uses GPU-only Sinkhorn approximation to avoid CPU sync
    if self.ot_coupling and diffusion_input.shape[0] > 1:
        with torch.no_grad():
            # Flatten to [batch, features] for distance computation
            data_flat = diffusion_input.reshape(diffusion_input.shape[0], -1)
            noise_flat = noise.reshape(noise.shape[0], -1)
            # Squared L2 cost via matmul (faster than cdist, same optimal assignment)
            aa = (data_flat * data_flat).sum(dim=1, keepdim=True)
            bb = (noise_flat * noise_flat).sum(dim=1, keepdim=True)
            cost_matrix = aa + bb.T - 2.0 * (data_flat @ noise_flat.T)
            # Sinkhorn assignment (GPU-only, no CPU sync)
            log_P = -cost_matrix / cost_matrix.detach().mean() # normalize for numerical stability
            for _ in range(20):
                log_P = log_P - torch.logsumexp(log_P, dim=1, keepdim=True)
                log_P = log_P - torch.logsumexp(log_P, dim=0, keepdim=True)
            # Sequential assignment from soft permutation matrix (guarantees valid permutation)
            P = log_P.exp()
            B = P.shape[0]
            col_indices = torch.empty(B, dtype=torch.long, device=P.device)
            used = torch.zeros(B, dtype=torch.bool, device=P.device)
            for i in range(B):
                # Exclude columns already taken by earlier rows.
                P[i, used] = -1
                col_indices[i] = P[i].argmax()
                used[col_indices[i]] = True
            noise = noise[col_indices]
    noised_inputs = diffusion_input * alphas + noise * sigmas
    if self.diffusion_objective == "v":
        targets = noise * alphas - diffusion_input * sigmas
    elif self.diffusion_objective in ["rectified_flow", "rf_denoiser"]:
        targets = noise - diffusion_input
    p.tick("noise")
    extra_args = {}
    # Compute downsampling ratio for attention mask creation
    downsampling_ratio = self.diffusion.pretransform.downsampling_ratio if self.diffusion.pretransform is not None else 1
    # Create augmented padding mask with random silence extension
    if self.mask_padding_attention and self.silence_extension_scale_seconds > 0:
        augmented_padding_mask = create_augmented_padding_mask(
            padding_masks,
            silence_extension_scale_seconds=self.silence_extension_scale_seconds,
            sample_rate=self.sample_rate,
            downsampling_ratio=downsampling_ratio,
        )
    else:
        augmented_padding_mask = padding_masks
    # Loss mask defines signal vs padding regions for loss computation
    # - mask_loss_weight controls padding contribution (0 = signal only)
    # - When mask_padding_attention=True: only compute loss on signal (padding saw no attention)
    loss_mask = augmented_padding_mask.to(torch.bool)
    # Pass padding mask for attention masking - model handles prepend extension
    if self.mask_padding_attention:
        extra_args["padding_mask"] = augmented_padding_mask
    if self.inpainting_config is not None:
        # Max mask size is the full sequence length
        max_mask_length = diffusion_input.shape[2]
        # Create a mask of random length for a random slice of the input
        inpaint_masked_input, inpaint_mask = random_inpaint_mask(diffusion_input, padding_masks=augmented_padding_mask, mask_padding=self.mask_padding_attention, **self.inpaint_mask_kwargs)
        conditioning['inpaint_mask'] = [inpaint_mask]
        conditioning['inpaint_masked_input'] = [inpaint_masked_input]
        # Only compute loss on inpainted region (where model is generating)
        loss_mask = loss_mask & ~inpaint_mask.squeeze(1).to(torch.bool)
    output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args)
    p.tick("diffusion")
    if self.log_loss_info:
        # Loss debugging logs
        num_loss_buckets = 10
        bucket_size = 1 / num_loss_buckets
        loss_all = F.mse_loss(output, targets, reduction="none")
        sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
        # gather loss_all across all GPUs
        loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
        # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
        loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
        # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
        debug_log_dict = {
            f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
        }
        self.log_dict(debug_log_dict)
        p.tick("loss_debug")
    # Compute std only over non-padded positions when masking is active
    if loss_mask is not None and self.mask_padding_attention:
        mask_expanded = loss_mask.unsqueeze(1) # [B, 1, T]
        std_data = diffusion_input[mask_expanded.expand_as(diffusion_input)].std()
        std_targets = targets[mask_expanded.expand_as(targets)].std().detach()
    else:
        std_data = diffusion_input.std()
        std_targets = targets.std().detach()
    log_dict = {
        'train/std_data': std_data,
        'train/std_targets': std_targets,
        'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
    }
    p.tick("std_compute")
    # Compute normalized MSE (normalization only affects non-"none" modes)
    mse_loss_full = compute_normalized_mse(output, targets, loss_mask, self.loss_normalization, self.loss_norm_eps)
    p.tick("mse_loss")
    # Compute loss with signal/padding separation (returns already-detached metrics)
    loss, signal_mean, padding_mean = compute_masked_loss(
        mse_loss_full, loss_mask, self.mask_padding_attention, self.mask_loss_weight
    )
    mse_loss = loss
    p.tick("masked_loss")
    # When attention masking is on, compute_masked_loss excludes everything outside
    # loss_mask (which now excludes inpaint context). Add context reconstruction loss
    # so the model learns to preserve context regions during inpainting.
    # (When mask_padding_attention=False, context is already included via mask_loss_weight.)
    context_loss_mean = torch.tensor(0.0, device=loss.device)
    if (self.inpainting_config is not None
            and self.mask_padding_attention
            and self.mask_loss_weight > 0):
        # Context = inpaint_mask=1 (keep) AND padding_mask=1 (real audio, not padding)
        inpaint_context = inpaint_mask.squeeze(1).to(torch.bool) & augmented_padding_mask.to(torch.bool)
        n_ctx = inpaint_context.sum(dim=1) * mse_loss_full.shape[1] # per-sample count
        if n_ctx.sum() > 0:
            context_vals = torch.where(inpaint_context.unsqueeze(1), mse_loss_full, 0.0)
            context_loss_mean = (context_vals.sum(dim=(1, 2)) / (n_ctx + 1e-8)).mean()
            loss = loss + context_loss_mean * self.mask_loss_weight
    # Log separate signal/padding/context losses for monitoring
    log_dict["train/mse_signal"] = signal_mean
    log_dict["train/mse_masked_loss"] = padding_mean
    log_dict["train/mse_context_loss"] = context_loss_mean.detach()
    log_dict["train/mse_loss"] = mse_loss.detach()
    log_dict["train/loss"] = loss.detach()
    # Stash for external callbacks (e.g. loss-by-timestep logging)
    self._last_t = t.detach()
    self._last_per_elem_loss = mse_loss_full.detach().mean(dim=(1, 2))
    self._staggered_logger.log(log_dict, self)
    #p.tick("log_dict")
    #print(f"Profiler: {p}")
    return loss
def validation_step(self, batch, batch_idx):
    """Evaluate the model at each fixed validation timestep for one batch.

    For every entry in ``self.validation_timesteps`` the batch is noised at
    that timestep, the model is run without CFG dropout, and the masked
    loss is appended (as a Python float) to ``self.validation_step_outputs``
    under ``val/loss_<t>``.

    Args:
        batch: ``(reals, metadata)``, with the same layout as in training.
        batch_idx: Unused.

    Raises:
        ValueError: If ``self.diffusion_objective`` is not supported.
    """
    reals, metadata = batch
    # Drop an extra leading singleton dimension on 4-D input.
    if reals.ndim == 4 and reals.shape[0] == 1:
        reals = reals[0]
    diffusion_input = reals
    with torch.amp.autocast("cuda"), torch.no_grad():
        conditioning = self.diffusion.conditioner(metadata, self.device)
    # Create batch tensor of padding masks from the metadata
    if all("padding_mask" in md for md in metadata):
        padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device)
    else:
        padding_masks = torch.ones(diffusion_input.shape[0], diffusion_input.shape[-1], dtype=torch.bool, device=self.device)
    if self.diffusion.pretransform is not None:
        self.diffusion.pretransform.to(self.device)
        if not self.pre_encoded:
            with torch.amp.autocast("cuda"), torch.no_grad():
                self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad)
                diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
                padding_masks = resize_padding_mask(padding_masks, diffusion_input.shape[-1])
        else:
            # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
            if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
                diffusion_input = diffusion_input / self.diffusion.pretransform.scale
            if padding_masks.shape[-1] != diffusion_input.shape[-1]:
                padding_masks = resize_padding_mask(padding_masks, diffusion_input.shape[-1])
    # Use padding mask directly for validation (no silence extension augmentation)
    loss_mask = padding_masks.to(torch.bool)
    extra_args = {}
    if self.mask_padding_attention:
        extra_args["padding_mask"] = padding_masks
    # Set up inpainting conditioning for validation (FULL_MASK: all zeros)
    if self.inpainting_config is not None:
        inpaint_mask = torch.zeros(diffusion_input.shape[0], 1, diffusion_input.shape[2], device=self.device)
        inpaint_masked_input = torch.zeros_like(diffusion_input)
        conditioning['inpaint_mask'] = [inpaint_mask]
        conditioning['inpaint_masked_input'] = [inpaint_masked_input]
    for validation_timestep in self.validation_timesteps:
        t = torch.full((reals.shape[0],), validation_timestep, device=self.device)
        # Calculate the noise schedule parameters for those timesteps
        if self.diffusion_objective == "v":
            # Cosine schedule computed inline: no get_alphas_sigmas helper is
            # imported in the visible part of this module.
            alphas, sigmas = torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
        elif self.diffusion_objective in ["rectified_flow", "rf_denoiser"]:
            alphas, sigmas = 1-t, t
        else:
            raise ValueError(f"Unsupported diffusion_objective: {self.diffusion_objective}")
        # Combine the ground truth data and the noise
        alphas = alphas[:, None, None]
        sigmas = sigmas[:, None, None]
        noise = torch.randn_like(diffusion_input)
        noised_inputs = diffusion_input * alphas + noise * sigmas
        if self.diffusion_objective == "v":
            targets = noise * alphas - diffusion_input * sigmas
        elif self.diffusion_objective in ["rectified_flow", "rf_denoiser"]:
            targets = noise - diffusion_input
        with torch.amp.autocast("cuda"), torch.no_grad():
            output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = 0, **extra_args)
            mse_loss_full = compute_normalized_mse(output, targets, loss_mask, self.loss_normalization, self.loss_norm_eps)
            val_loss, _, _ = compute_masked_loss(
                mse_loss_full, loss_mask, self.mask_padding_attention, self.mask_loss_weight
            )
            self.validation_step_outputs[f'val/loss_{validation_timestep:.1f}'].append(val_loss.item())
def on_validation_epoch_end(self):
    """Aggregate, log, and reset the validation losses collected this epoch.

    For each validation timestep the per-batch losses are averaged locally,
    averaged again across processes via ``all_gather``, and logged under
    ``val/loss_<t>``. The mean over all timesteps is logged as
    ``val/avg_loss``. The per-timestep lists are then cleared.
    """
    for validation_timestep in self.validation_timesteps:
        outputs_key = f'val/loss_{validation_timestep:.1f}'
        step_losses = self.validation_step_outputs[outputs_key]
        # A process that saw no validation batches reports NaN instead of
        # raising ZeroDivisionError; it still takes part in the gather below.
        val_loss = sum(step_losses) / len(step_losses) if step_losses else float("nan")
        # Gather losses across all GPUs
        val_loss = self.all_gather(val_loss).mean().item()
        log_metric(self.logger, outputs_key, val_loss, step=self.global_step)
    # Get average over all timesteps
    val_loss = torch.tensor(list(self.validation_step_outputs.values())).mean()
    # Gather losses across all GPUs
    val_loss = self.all_gather(val_loss).mean().item()
    log_metric(self.logger, 'val/avg_loss', val_loss, step=self.global_step)
    # Reset validation losses
    for validation_timestep in self.validation_timesteps:
        self.validation_step_outputs[f'val/loss_{validation_timestep:.1f}'] = []
def export_model(self, path, use_safetensors=False):
    """Write the diffusion wrapper's state dict to ``path``.

    If an EMA model is present, it replaces ``self.diffusion.model`` before
    the state dict is taken (the replacement is not undone afterwards).

    Args:
        path: Destination file.
        use_safetensors: When truthy, save with ``safetensors``; otherwise
            save with ``torch.save`` under a ``"state_dict"`` key.
    """
    ema = self.diffusion_ema
    if ema is not None:
        self.diffusion.model = ema.ema_model

    if not use_safetensors:
        torch.save({"state_dict": self.diffusion.state_dict()}, path)
        return

    save_file(self.diffusion.state_dict(), path)
def export_lora_safetensors(self, path):
    """Export LoRA weights as a safetensors file with embedded config.

    The LoRA state dicts of the model and the conditioner are merged
    (conditioner entries win on key clashes) and written together with
    ``self.lora_config``.

    Raises:
        ValueError: If the wrapper was not created in LoRA mode.
    """
    if self.lora_config is None:
        raise ValueError("No LoRA config -- this wrapper is not in LoRA mode")

    lora_weights = dict(get_lora_state_dict(self.diffusion.model))
    lora_weights.update(get_lora_state_dict(self.diffusion.conditioner))
    save_lora_safetensors(lora_weights, self.lora_config, path)
def on_save_checkpoint(self, checkpoint):
    """In LoRA mode, reduce the checkpoint to the LoRA weights and config.

    The ``checkpoint`` dict is emptied in place and refilled with only a
    ``state_dict`` (model + conditioner LoRA weights) and ``lora_config``.
    Outside LoRA mode the checkpoint is left untouched.
    """
    if self.lora_config is None:
        return

    checkpoint.clear()
    lora_weights = dict(get_lora_state_dict(self.diffusion.model))
    lora_weights.update(get_lora_state_dict(self.diffusion.conditioner))
    checkpoint['state_dict'] = lora_weights
    checkpoint['lora_config'] = self.lora_config
class DiffusionCondInpaintDemoCallback(pl.Callback):
def __init__(
    self,
    demo_every=2000,
    demo_steps=250,
    sample_size=65536,
    sample_rate=48000,
    demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7],
    demo_conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None,
    inpaint_demo_config: tp.Optional[tp.Dict[str, int]] = None,
    num_demos: int = 0,
    demo_dl=None,
):
    """Configure the demo callback.

    Args:
        demo_every: Step interval between demo generations.
        demo_steps: Number of sampling steps passed to ``sample_diffusion``.
        sample_size: Demo length in audio samples (stored as ``demo_samples``).
        sample_rate: Audio sample rate used for trimming and saving.
        demo_cfg_scales: CFG scales to render; one demo set per scale.
            ``None`` selects the default ``[3, 5, 7]``.
        demo_conditioning: Prompt conditioning dicts for prompt demos;
            ``None`` means no prompt demos.
        inpaint_demo_config: Per-mask-type demo counts, using the keys
            ``num_random_segments``, ``num_full_mask``, ``num_causal`` and
            ``num_random_spans``.
        num_demos: Legacy option. Only used when ``inpaint_demo_config`` is
            None and the value is > 0; masks are then sampled randomly.
        demo_dl: Iterable of demo batches for inpainting demos. It is
            wrapped with ``iter`` once, here.
    """
    super().__init__()
    self.demo_every = demo_every
    self.demo_steps = demo_steps
    self.demo_samples = sample_size
    self.sample_rate = sample_rate
    # Copy so the shared default list (or the caller's list) is never
    # aliased, and so None does not crash the later per-scale loops.
    self.demo_cfg_scales = list(demo_cfg_scales) if demo_cfg_scales is not None else [3, 5, 7]
    self.demo_conditioning = demo_conditioning or []
    self.last_demo_step = -1
    # Map config keys to MaskType enum
    self._mask_type_map = {
        "num_random_segments": MaskType.RANDOM_SEGMENTS,
        "num_full_mask": MaskType.FULL_MASK,
        "num_causal": MaskType.CAUSAL_MASK,
        "num_random_spans": MaskType.RANDOM_SPANS,
    }
    # Legacy fallback: if no inpaint_demo_config but num_demos is set,
    # use num_demos items with random mask sampling (old behavior)
    if inpaint_demo_config is not None:
        self.inpaint_demo_config = inpaint_demo_config
        self.legacy_inpaint_demos = False
    else:
        self.inpaint_demo_config = {}
        self.legacy_inpaint_demos = num_demos > 0
    # Always defined (0 outside legacy mode) so attribute access is safe.
    self.legacy_num_demos = num_demos if self.legacy_inpaint_demos else 0
    # Total inpainting demos needed from batch
    if self.legacy_inpaint_demos:
        self.num_inpaint_demos = self.legacy_num_demos
    else:
        self.num_inpaint_demos = sum(
            self.inpaint_demo_config.get(k, 0) for k in self._mask_type_map
        )
    self.demo_dl = iter(demo_dl) if demo_dl is not None else None
    self._teacher_demo_done = False
def _generate_prompt_demos(self, module, trainer, is_rank_zero=True):
    """Generate full t2m demos from specified prompts (FULL_MASK).

    Args:
        module: Training wrapper providing the diffusion model, device and
            objective.
        trainer: Unused.
        is_rank_zero: Controls progress printing and the tqdm bar.

    Returns:
        ``(all_audio, all_context_masks)`` with one entry per CFG scale, or
        two empty lists when no prompt conditioning is configured. Every
        context-mask entry is the same all-zero mask object.
    """
    prompts = self.demo_conditioning
    if not prompts:
        return [], []

    batch = len(prompts)
    latent_len = self.demo_samples
    if module.diffusion.pretransform is not None:
        latent_len = latent_len // module.diffusion.pretransform.downsampling_ratio

    # Conditioning from prompts
    conditioning = module.diffusion.conditioner(prompts, module.device)

    # FULL_MASK: all-zero inpaint conditioning
    channels = module.diffusion.io_channels
    inpaint_mask = torch.zeros(batch, 1, latent_len, device=module.device)
    inpaint_masked_input = torch.zeros(batch, channels, latent_len, device=module.device)
    conditioning['inpaint_mask'] = [inpaint_mask]
    conditioning['inpaint_masked_input'] = [inpaint_masked_input]
    cond_inputs = module.diffusion.get_conditioning_inputs(conditioning)

    # The same starting noise is reused for every CFG scale.
    noise = torch.randn(batch, channels, latent_len, device=module.device)
    noise = noise.to(next(module.diffusion.parameters()).dtype)

    per_elem_trim = compute_per_elem_trim(prompts, self.sample_rate, margin_seconds=2)

    if module.diffusion_ema is not None:
        model = module.diffusion_ema.ema_model
    else:
        model = module.diffusion.model

    all_audio = []
    for cfg_scale in self.demo_cfg_scales:
        if is_rank_zero:
            print(f"Generating prompt demos for cfg scale {cfg_scale}")
        with torch.amp.autocast("cuda"):
            generated = sample_diffusion(
                model=model,
                noise=noise,
                cond_inputs=cond_inputs,
                diffusion_objective=module.diffusion_objective,
                steps=self.demo_steps,
                cfg_scale=cfg_scale,
                conditioning=prompts,
                sample_rate=self.sample_rate,
                pretransform=module.diffusion.pretransform,
                mask_padding_attention=module.diffusion.mask_padding_attention,
                use_effective_length_for_schedule=module.diffusion.use_effective_length_for_schedule,
                headroom_seconds=5.0,
                dist_shift=module.diffusion.sampling_dist_shift,
                batch_cfg=True,
                disable_tqdm=not is_rank_zero,
                decode=True
            )
        all_audio.append(trim_and_concat(generated, per_elem_trim))

    # Latent-resolution all-zeros mask (no context for prompt demos),
    # trimmed to match the per-element audio durations
    if module.diffusion.pretransform is not None:
        ds_ratio = module.diffusion.pretransform.downsampling_ratio
    else:
        ds_ratio = 1
    if per_elem_trim is None:
        latent_trim = None
    else:
        latent_trim = [None if t is None else t // ds_ratio for t in per_elem_trim]
    latent_mask = torch.zeros(batch, 1, latent_len)
    context_mask = trim_and_concat(latent_mask, latent_trim).squeeze(0).cpu()
    all_context_masks = [context_mask] * len(self.demo_cfg_scales)

    # Drop references to the intermediates before releasing cached GPU memory.
    del noise, conditioning, cond_inputs, inpaint_mask, inpaint_masked_input
    torch.cuda.empty_cache()
    return all_audio, all_context_masks
def _generate_inpaint_demos(self, module, trainer, is_rank_zero=True):
    """Generate inpainting demos from batch data with forced mask types.

    Pulls the next batch from ``self.demo_dl``, encodes or rescales it as
    in training, builds inpainting masks (randomly in legacy mode, or per
    mask type according to ``self.inpaint_demo_config``), and samples one
    demo set per CFG scale from the same starting noise.

    Args:
        module: Training wrapper providing the diffusion model, device and
            objective.
        trainer: Unused.
        is_rank_zero: Controls progress printing and the tqdm bar.

    Returns:
        ``(all_audio, all_context_masks)`` with one entry per CFG scale, or
        two empty lists when no inpainting demos are configured or no demo
        data loader was supplied.
    """
    if self.num_inpaint_demos == 0 or self.demo_dl is None:
        return [], []
    demo_reals, metadata = next(self.demo_dl)
    # Drop an extra leading singleton dimension on 4-D input.
    if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
        demo_reals = demo_reals[0]
    demo_reals = demo_reals[:self.num_inpaint_demos]
    metadata = metadata[:self.num_inpaint_demos]
    model_dtype = next(module.diffusion.parameters()).dtype
    demo_reals = demo_reals.to(module.device, dtype=model_dtype)
    if not module.pre_encoded:
        if module.diffusion.pretransform is not None:
            module.diffusion.pretransform.to(module.device)
            demo_reals = module.diffusion.pretransform.encode(demo_reals)
    else:
        if hasattr(module.diffusion.pretransform, "scale") and module.diffusion.pretransform.scale != 1.0:
            demo_reals = demo_reals / module.diffusion.pretransform.scale
    # Same fallback as the training/validation steps: without padding masks
    # in the metadata, treat every position as signal.
    if all("padding_mask" in md for md in metadata):
        padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(module.device)
    else:
        padding_masks = torch.ones(demo_reals.shape[0], demo_reals.shape[-1], dtype=torch.bool, device=module.device)
    if padding_masks.shape[-1] != demo_reals.shape[-1]:
        padding_masks = resize_padding_mask(padding_masks, demo_reals.shape[-1])
    mask_padding = module.diffusion.mask_padding_attention
    # The module only defines inpaint_mask_kwargs when it has an inpainting
    # config; fall back to no extra kwargs instead of raising AttributeError.
    inpaint_mask_kwargs = getattr(module, "inpaint_mask_kwargs", {})
    if self.legacy_inpaint_demos:
        # Legacy: random mask type sampling (old behavior)
        masked_input, mask = random_inpaint_mask(
            demo_reals, padding_masks=padding_masks,
            mask_padding=mask_padding,
            **inpaint_mask_kwargs
        )
    else:
        # New: forced mask types per config
        all_masks = []
        all_masked_inputs = []
        idx = 0
        for config_key, mask_type in self._mask_type_map.items():
            count = self.inpaint_demo_config.get(config_key, 0)
            if count == 0:
                continue
            subset_reals = demo_reals[idx:idx+count]
            if subset_reals.shape[0] == 0:
                # The batch held fewer items than the configured demo count;
                # nothing is left for this or any later mask type.
                break
            subset_padding = padding_masks[idx:idx+count]
            mi, m = random_inpaint_mask(
                subset_reals, padding_masks=subset_padding,
                mask_padding=mask_padding, force_mask_type=mask_type,
                **inpaint_mask_kwargs
            )
            all_masks.append(m)
            all_masked_inputs.append(mi)
            idx += count
        mask = torch.cat(all_masks, dim=0)
        masked_input = torch.cat(all_masked_inputs, dim=0)
    conditioning = module.diffusion.conditioner(metadata, module.device)
    conditioning['inpaint_mask'] = [mask]
    conditioning['inpaint_masked_input'] = [masked_input]
    cond_inputs = module.diffusion.get_conditioning_inputs(conditioning)
    demo_samples = demo_reals.shape[2]
    # The same starting noise is reused for every CFG scale.
    noise = torch.randn(demo_reals.shape[0], module.diffusion.io_channels, demo_samples, device=module.device)
    noise = noise.to(model_dtype)
    per_elem_trim = compute_per_elem_trim(metadata, self.sample_rate, margin_seconds=2)
    # Trim and concatenate context mask at latent resolution,
    # using same trimming basis as audio (per_elem_trim // ds_ratio)
    ds_ratio = module.diffusion.pretransform.downsampling_ratio if module.diffusion.pretransform is not None else 1
    latent_trim = [t // ds_ratio if t is not None else None for t in per_elem_trim] if per_elem_trim is not None else None
    # Zero out padding region in mask for display — the mask is initialized to 1,
    # so without mask_padding the padding frames show as false context in the overlay
    display_mask = mask * padding_masks.unsqueeze(1)
    context_mask = trim_and_concat(display_mask, latent_trim).squeeze(0).cpu()
    model = module.diffusion_ema.ema_model if module.diffusion_ema is not None else module.diffusion.model
    all_audio = []
    all_context_masks = []
    for cfg_scale in self.demo_cfg_scales:
        if is_rank_zero:
            print(f"Generating inpaint demos for cfg scale {cfg_scale}")
        with torch.amp.autocast("cuda"):
            fakes = sample_diffusion(
                model=model,
                noise=noise,
                cond_inputs=cond_inputs,
                diffusion_objective=module.diffusion_objective,
                steps=self.demo_steps,
                cfg_scale=cfg_scale,
                conditioning=metadata,
                sample_rate=self.sample_rate,
                pretransform=module.diffusion.pretransform,
                mask_padding_attention=module.diffusion.mask_padding_attention,
                use_effective_length_for_schedule=module.diffusion.use_effective_length_for_schedule,
                headroom_seconds=5.0,
                dist_shift=module.diffusion.sampling_dist_shift,
                batch_cfg=True,
                disable_tqdm=not is_rank_zero,
                decode=True
            )
        fakes = trim_and_concat(fakes, per_elem_trim)
        all_audio.append(fakes)
        all_context_masks.append(context_mask)
    # Drop references to the intermediates before releasing cached GPU memory.
    del noise, conditioning, cond_inputs, mask, masked_input, padding_masks, demo_reals
    torch.cuda.empty_cache()
    return all_audio, all_context_masks
@torch.no_grad()
def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx):
    '''
    Generate and log demo audio once every `self.demo_every` global steps.

    For each entry in `self.demo_cfg_scales`, the prompt demos and the inpaint
    demos are concatenated along the last axis, peak-normalised, written to a
    wav file and logged together with a spectrogram image (rank zero only).
    On the first demo, if the module exposes a teacher model (`_teacher` or
    `teacher_model`), a one-off teacher ODE warmup diagnostic is logged too.

    Args:
        trainer: The Lightning trainer; `global_step` and `logger` are read.
        module: The training wrapper whose diffusion model is sampled.
        outputs, batch, batch_idx: Unused; part of the callback hook signature.

    Raises:
        Exception: Any failure of the main demo generation is printed on rank
            zero and re-raised. Failures inside the teacher diagnostic are
            printed and swallowed (the diagnostic is best-effort).
    '''
    # Skip steps that are not on the demo schedule, and never demo twice on
    # the same global step.
    if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
        return

    is_rank_zero = get_rank() == 0
    module.eval()
    self.last_demo_step = trainer.global_step

    try:
        # Generate both types of demos, freeing intermediates between phases
        prompt_audio, prompt_masks = self._generate_prompt_demos(module, trainer, is_rank_zero)
        torch.cuda.empty_cache()
        inpaint_audio, inpaint_masks = self._generate_inpaint_demos(module, trainer, is_rank_zero)
        torch.cuda.empty_cache()

        # Combine per cfg scale (prompt_audio and inpaint_audio have one entry per cfg scale)
        if is_rank_zero:
            for i, cfg_scale in enumerate(self.demo_cfg_scales):
                parts = []
                mask_parts = []
                if i < len(prompt_audio):
                    parts.append(prompt_audio[i])
                    mask_parts.append(prompt_masks[i])
                if i < len(inpaint_audio):
                    parts.append(inpaint_audio[i])
                    mask_parts.append(inpaint_masks[i])
                if not parts:
                    continue
                self._save_and_log_demo(
                    trainer,
                    torch.cat(parts, dim=-1),
                    torch.cat(mask_parts, dim=-1),
                    filename=f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav',
                    audio_key=f'demo_cfg_{cfg_scale}',
                    image_key=f'demo_melspec_left_cfg_{cfg_scale}',
                )

        # Teacher ODE warmup diagnostic: mirror the exact ODE warmup sample_diffusion call
        # and decode the target to verify teacher output quality.
        # Only runs on the first demo.
        teacher_ref = getattr(module, '_teacher', None)
        if teacher_ref is None:
            teacher_ref = getattr(module, 'teacher_model', None)
        if not self._teacher_demo_done and teacher_ref is not None:
            # Mark as done before trying, so a failing diagnostic is not retried.
            self._teacher_demo_done = True
            if is_rank_zero:
                print("Generating teacher ODE warmup diagnostic")
            try:
                self._run_teacher_diagnostic(trainer, module, teacher_ref, is_rank_zero)
            except Exception as e:
                # Best-effort diagnostic: report the failure but keep training.
                if is_rank_zero:
                    print(f"Teacher ODE warmup diagnostic failed: {e}")
                    import traceback
                    traceback.print_exc()
    except Exception as e:
        if is_rank_zero:
            print(f'{type(e).__name__}: {e}')
        # Bare raise keeps the original traceback intact.
        raise
    finally:
        gc.collect()
        torch.cuda.empty_cache()
        module.train()

def _save_and_log_demo(self, trainer, audio, context_mask, filename, audio_key, image_key):
    '''
    Peak-normalise `audio` to int16, save it as `filename`, and log audio + spectrogram.

    The temporary wav is removed only when the logger is a WandbLogger or
    CometLogger; with any other logger the file is left on disk.
    '''
    audio = audio.to(torch.float32)
    # Guard against an all-zero clip: dividing by a zero peak would yield NaNs,
    # and casting NaN to int16 is undefined.
    peak = audio.abs().max().clamp(min=1e-8)
    audio = audio.div(peak).mul(32767).to(torch.int16).cpu()
    torchaudio.save(filename, audio, self.sample_rate)
    log_audio(trainer.logger, audio_key, filename, self.sample_rate)
    log_image(trainer.logger, image_key, audio_spectrogram_image(audio, context_mask=context_mask))
    if isinstance(trainer.logger, (WandbLogger, CometLogger)):
        os.remove(filename)

def _run_teacher_diagnostic(self, trainer, module, teacher_ref, is_rank_zero):
    '''
    Sample prompt and inpaint targets from the teacher model and log them once.

    Generates both prompt and inpaint demos, consistent with the main callback.
    Sampling runs on every rank; decoding and logging happen on rank zero only.
    Relies on the caller running under `torch.no_grad()`.
    '''
    pretransform = module.diffusion.pretransform  # Shared pretransform (not on teacher)
    io_channels = teacher_ref.io_channels
    ode_warmup_config = getattr(module, 'ode_warmup_config', {})
    teacher_cfg = getattr(module, 'ode_warmup_cfg', self.demo_cfg_scales[0])
    ode_steps = getattr(module, 'ode_n_sampling_steps', 20)
    mask_padding = module.diffusion.mask_padding_attention
    ds_ratio = pretransform.downsampling_ratio if pretransform is not None else 1

    def to_latent_trim(per_elem_trim):
        # Convert per-element trims to latent resolution (per_elem_trim // ds_ratio).
        if per_elem_trim is None:
            return None
        return [t // ds_ratio if t is not None else None for t in per_elem_trim]

    def sample_teacher(noise, cond_inputs, conditioning):
        # Single place for the teacher sampling call so both demo types use
        # identical settings; latents are returned undecoded (decode=False).
        return sample_diffusion(
            model=teacher_ref.model,
            noise=noise,
            cond_inputs=cond_inputs,
            diffusion_objective=teacher_ref.diffusion_objective,
            steps=ode_steps,
            cfg_scale=teacher_cfg,
            conditioning=conditioning,
            sample_rate=teacher_ref.sample_rate,
            pretransform=pretransform,
            mask_padding_attention=mask_padding,
            use_effective_length_for_schedule=module.diffusion.use_effective_length_for_schedule,
            padding_mask=None,
            dist_shift=teacher_ref.sampling_dist_shift,
            sampler_type=ode_warmup_config.get('sampler', 'dpmpp'),
            batch_cfg=True,
            disable_tqdm=not is_rank_zero,
            decode=False,
        )

    def decode_target(target, per_elem_trim):
        # Without a pretransform the sampled target is used as-is.
        audio = target.float()
        if pretransform is not None:
            audio = pretransform.decode(audio)
        return trim_and_concat(audio, per_elem_trim)

    # --- Teacher prompt demos (FULL_MASK, same as _generate_prompt_demos) ---
    prompt_target = None
    prompt_per_elem_trim = None
    prompt_context_mask = None
    demo_cond = self.demo_conditioning
    if demo_cond:
        num_demos = len(demo_cond)
        demo_samples = self.demo_samples
        if pretransform is not None:
            demo_samples = demo_samples // ds_ratio

        teacher_conditioning = teacher_ref.conditioner(demo_cond, module.device)
        # All-zero mask and masked input: nothing is given as context.
        teacher_conditioning['inpaint_mask'] = [
            torch.zeros(num_demos, 1, demo_samples, device=module.device)
        ]
        teacher_conditioning['inpaint_masked_input'] = [
            torch.zeros(num_demos, io_channels, demo_samples, device=module.device)
        ]
        teacher_cond_inputs = teacher_ref.get_conditioning_inputs(teacher_conditioning)

        noise = torch.randn(num_demos, io_channels, demo_samples, device=module.device)
        noise = noise.to(next(teacher_ref.parameters()).dtype)
        prompt_per_elem_trim = compute_per_elem_trim(demo_cond, self.sample_rate, margin_seconds=2)
        prompt_target = sample_teacher(noise, teacher_cond_inputs, demo_cond)
        prompt_context_mask = trim_and_concat(
            torch.zeros(num_demos, 1, demo_samples), to_latent_trim(prompt_per_elem_trim)
        ).squeeze(0).cpu()

    # --- Teacher inpaint demos (same mask logic as _generate_inpaint_demos) ---
    inpaint_target = None
    inpaint_per_elem_trim = None
    inpaint_context_mask = None
    inpaint_batch = None
    if self.num_inpaint_demos > 0 and self.demo_dl is not None:
        # Only the fetch is guarded, so an unrelated StopIteration raised
        # further down is not misreported as an exhausted loader.
        try:
            inpaint_batch = next(self.demo_dl)
        except StopIteration:
            if is_rank_zero:
                print("Teacher diagnostic: no inpaint batch available from demo_dl")

    if inpaint_batch is not None:
        inpaint_reals, inpaint_metadata = inpaint_batch
        # Drop a leading singleton dimension if the loader added one.
        if inpaint_reals.ndim == 4 and inpaint_reals.shape[0] == 1:
            inpaint_reals = inpaint_reals[0]
        inpaint_reals = inpaint_reals[:self.num_inpaint_demos]
        inpaint_metadata = inpaint_metadata[:self.num_inpaint_demos]
        inpaint_reals = inpaint_reals.to(module.device)

        if not module.pre_encoded:
            if pretransform is not None:
                inpaint_reals = pretransform.encode(inpaint_reals)
        elif hasattr(pretransform, "scale") and pretransform.scale != 1.0:
            # Pre-encoded inputs are divided by the pretransform scale.
            inpaint_reals = inpaint_reals / pretransform.scale

        inpaint_padding_masks = torch.stack(
            [md["padding_mask"][0] for md in inpaint_metadata], dim=0
        ).to(module.device)

        if self.legacy_inpaint_demos:
            masked_input, mask = random_inpaint_mask(
                inpaint_reals, padding_masks=inpaint_padding_masks,
                mask_padding=mask_padding, **module.inpaint_mask_kwargs
            )
        else:
            # Build one consecutive slice of the batch per configured mask type.
            all_masks = []
            all_masked_inputs = []
            idx = 0
            for config_key, mask_type in self._mask_type_map.items():
                count = self.inpaint_demo_config.get(config_key, 0)
                if count == 0:
                    continue
                mi, m = random_inpaint_mask(
                    inpaint_reals[idx:idx + count],
                    padding_masks=inpaint_padding_masks[idx:idx + count],
                    mask_padding=mask_padding, force_mask_type=mask_type,
                    **module.inpaint_mask_kwargs
                )
                all_masks.append(m)
                all_masked_inputs.append(mi)
                idx += count
            mask = torch.cat(all_masks, dim=0)
            masked_input = torch.cat(all_masked_inputs, dim=0)

        inpaint_teacher_cond = teacher_ref.conditioner(inpaint_metadata, module.device)
        inpaint_teacher_cond['inpaint_mask'] = [mask]
        inpaint_teacher_cond['inpaint_masked_input'] = [masked_input]
        inpaint_cond_inputs = teacher_ref.get_conditioning_inputs(inpaint_teacher_cond)

        inpaint_noise = torch.randn(
            inpaint_reals.shape[0], io_channels, inpaint_reals.shape[2], device=module.device
        ).to(next(teacher_ref.parameters()).dtype)
        inpaint_per_elem_trim = compute_per_elem_trim(inpaint_metadata, self.sample_rate, margin_seconds=2)
        inpaint_target = sample_teacher(inpaint_noise, inpaint_cond_inputs, inpaint_metadata)

        # Context mask for overlay (same as _generate_inpaint_demos): zero the
        # padding region so it is not drawn as context.
        display_mask = mask * inpaint_padding_masks.unsqueeze(1)
        inpaint_context_mask = trim_and_concat(
            display_mask, to_latent_trim(inpaint_per_elem_trim)
        ).squeeze(0).cpu()

    # --- Combine and log (same pattern as main callback) ---
    if is_rank_zero:
        parts = []
        mask_parts = []
        if prompt_target is not None:
            parts.append(decode_target(prompt_target, prompt_per_elem_trim))
            mask_parts.append(prompt_context_mask)
        if inpaint_target is not None:
            parts.append(decode_target(inpaint_target, inpaint_per_elem_trim))
            mask_parts.append(inpaint_context_mask)
        if parts:
            self._save_and_log_demo(
                trainer,
                torch.cat(parts, dim=-1),
                torch.cat(mask_parts, dim=-1),
                filename=f'demo_teacher_target_{trainer.global_step:08}.wav',
                audio_key='demo_teacher_target',
                image_key='demo_teacher_target_melspec',
            )