Spaces:
Running on Zero
Running on Zero
| import math | |
| import pytorch_lightning as pl | |
| from pytorch_lightning.loggers import WandbLogger, CometLogger | |
| import os | |
| import torch | |
| import gc | |
| import typing as tp | |
| import torchaudio | |
| from einops import rearrange | |
| from safetensors.torch import save_file | |
| from functools import partial | |
| from torch.nn import functional as F | |
| from ..interface.aeiou import audio_spectrogram_image | |
| from ..inference.sampling import truncated_logistic_normal_rescaled, sample_timesteps_logsnr, sample_timesteps_logsnr_uniform, sample_diffusion | |
| from ..models.diffusion import ConditionedDiffusionModelWrapper | |
| from ..models.inpainting import random_inpaint_mask, MaskType | |
| from ..models.lora import add_lora, get_lora_params, get_lora_state_dict, LoRAParametrization, get_lora_layers, save_lora_safetensors, resolve_adapter_type, prepare_dora_state_dict, cast_base_to_precision | |
| from .utils import create_optimizer_from_config, create_scheduler_from_config, log_audio, log_image, log_metric, get_rank, create_augmented_padding_mask, compute_masked_loss, compute_normalized_mse, resize_padding_mask, StaggeredLogger, compute_per_elem_trim, trim_and_concat | |
| from time import time | |
| class Profiler: | |
| def __init__(self): | |
| self.ticks = [[time(), None]] | |
| def tick(self, msg): | |
| self.ticks.append([time(), msg]) | |
| def __repr__(self): | |
| rep = 80 * "=" + "\n" | |
| for i in range(1, len(self.ticks)): | |
| msg = self.ticks[i][1] | |
| ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] | |
| rep += msg + f": {ellapsed*1000:.2f}ms\n" | |
| rep += 80 * "=" + "\n\n\n" | |
| return rep | |
| class DiffusionCondTrainingWrapper(pl.LightningModule): | |
| ''' | |
| Wrapper for training a conditional audio diffusion model. | |
| ''' | |
| def __init__( | |
| self, | |
| model: ConditionedDiffusionModelWrapper, | |
| lr: float = None, | |
| mask_loss_weight: float = 0.0, | |
| mask_padding_attention: bool = False, | |
| silence_extension_scale_seconds: float = 0.0, | |
| use_ema: bool = True, | |
| log_loss_info: bool = False, | |
| optimizer_configs: dict = None, | |
| pre_encoded: bool = False, | |
| cfg_dropout_prob = 0.1, | |
| timestep_sampler: tp.Literal["uniform", "logit_normal", "trunc_logit_normal", "log_snr", "log_snr_uniform"] = "uniform", | |
| timestep_sampler_options: tp.Optional[tp.Dict[str, tp.Any]] = None, | |
| validation_timesteps = [0.1, 0.3, 0.5, 0.7, 0.9], | |
| p_one_shot: float = 0.0, | |
| inpainting_config: dict = None, | |
| use_effective_length_for_schedule: bool = False, | |
| sample_rate: int = 44100, | |
| sample_size: int = None, | |
| loss_normalization: tp.Literal["none", "timestep", "sample", "sample_channel"] = "none", | |
| loss_norm_eps: float = 1e-6, | |
| lora_config: tp.Optional[tp.Dict[str, tp.Any]] = None, | |
| lora_state_dict: tp.Optional[tp.Dict[str, tp.Any]] = None, | |
| svd_bases_path: tp.Optional[str] = None, | |
| log_every_n_steps: int = 10, | |
| ot_coupling: bool = False, | |
| base_precision: tp.Optional[str] = None, | |
| ): | |
| super().__init__() | |
| self.ot_coupling = ot_coupling | |
| self.diffusion = model | |
| self.lora_config = lora_config | |
| if self.lora_config is not None: | |
| # Don't use EMA with LoRA | |
| use_ema = False | |
| # Freeze the pre-trained model weights | |
| self.diffusion.model.eval().requires_grad_(False) | |
| self.diffusion.conditioner.eval().requires_grad_(False) | |
| rank = self.lora_config.get("rank", 8) | |
| lora_alpha = self.lora_config.get("alpha", rank) | |
| adapter_type = self.lora_config.get("adapter_type", "lora") | |
| include = self.lora_config.get("include", None) | |
| exclude = self.lora_config.get("exclude", None) | |
| # Resolve legacy "dora" to rows/cols variant | |
| adapter_type = resolve_adapter_type(adapter_type, lora_state_dict) | |
| print(f"LoRA config: rank={rank}, alpha={lora_alpha}, adapter_type={adapter_type}") | |
| if include: | |
| print(f" include: {include}") | |
| if exclude: | |
| print(f" exclude: {exclude}") | |
| # Load pre-computed SVD bases for -XS adapter types | |
| svd_bases = None | |
| if svd_bases_path is not None: | |
| print(f"Loading SVD bases from {svd_bases_path}") | |
| svd_bases = torch.load(svd_bases_path, map_location="cpu", weights_only=True) | |
| elif adapter_type.endswith("-xs"): | |
| print("WARNING: -XS adapter without svd_bases_path — SVD will be computed per layer") | |
| lora_config = { | |
| torch.nn.Linear: { | |
| "weight": partial(LoRAParametrization.from_linear, rank=rank, lora_alpha=lora_alpha, adapter_type=adapter_type), | |
| }, | |
| torch.nn.Conv1d: { | |
| "weight": partial(LoRAParametrization.from_conv1d, rank=rank, lora_alpha=lora_alpha, adapter_type=adapter_type), | |
| } | |
| } | |
| # Add LoRA to the model | |
| add_lora(self.diffusion.model, lora_config, include=include, exclude=exclude, svd_bases=svd_bases) | |
| # Add LoRA to the conditioner | |
| add_lora(self.diffusion.conditioner, lora_config, include=include, exclude=exclude, svd_bases=svd_bases) | |
| print("lora layers:", len(get_lora_layers(self.diffusion))) | |
| if lora_state_dict is not None: | |
| # Old DoRA checkpoints saved magnitude as 2D (1,fan_in) or (fan_out,1); | |
| # current code expects 1D. Squeeze so old checkpoints still load. | |
| prepare_dora_state_dict(lora_state_dict) | |
| self.diffusion.model.load_state_dict(lora_state_dict, strict=False) | |
| self.diffusion.conditioner.load_state_dict(lora_state_dict, strict=False) | |
| # Cast frozen base weights to lower precision if requested | |
| if base_precision: | |
| cast_base_to_precision(self.diffusion.model, base_precision) | |
| cast_base_to_precision(self.diffusion.conditioner, base_precision) | |
| if self.diffusion.pretransform is not None: | |
| self.diffusion.pretransform.to( | |
| torch.bfloat16 if base_precision in ("bf16", "bfloat16") else torch.float16 | |
| ) | |
| self.diffusion_ema = None | |
| self.mask_loss_weight = mask_loss_weight | |
| # Attention masking for padded tokens | |
| # Backward compat: if passed from training config, propagate to model | |
| if mask_padding_attention and not self.diffusion.mask_padding_attention: | |
| import warnings | |
| warnings.warn("mask_padding_attention in training config is deprecated. Move to model.diffusion config.", FutureWarning) | |
| self.diffusion.mask_padding_attention = mask_padding_attention | |
| self.mask_padding_attention = self.diffusion.mask_padding_attention | |
| self.silence_extension_scale_seconds = silence_extension_scale_seconds | |
| self.cfg_dropout_prob = cfg_dropout_prob | |
| self.rng = torch.quasirandom.SobolEngine(1, scramble=True) | |
| self.timestep_sampler = timestep_sampler | |
| self.timestep_sampler_options = {} if timestep_sampler_options is None else timestep_sampler_options | |
| if self.timestep_sampler == "log_snr": | |
| self.mean_logsnr = self.timestep_sampler_options.get("mean_logsnr", -1.2) | |
| self.std_logsnr = self.timestep_sampler_options.get("std_logsnr", 2.0) | |
| elif self.timestep_sampler == "log_snr_uniform": | |
| self.min_logsnr = self.timestep_sampler_options.get("min_logsnr", -6.0) | |
| self.max_logsnr = self.timestep_sampler_options.get("max_logsnr", 5.0) | |
| self.p_one_shot = p_one_shot | |
| self.diffusion_objective = model.diffusion_objective | |
| self.log_loss_info = log_loss_info | |
| self._staggered_logger = StaggeredLogger(every_n_steps=log_every_n_steps) | |
| assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" | |
| if optimizer_configs is None: | |
| optimizer_configs = { | |
| "diffusion": { | |
| "optimizer": { | |
| "type": "Adam", | |
| "config": { | |
| "lr": lr | |
| } | |
| } | |
| } | |
| } | |
| else: | |
| if lr is not None: | |
| print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") | |
| self.optimizer_configs = optimizer_configs | |
| self.pre_encoded = pre_encoded | |
| # Loss normalization by target magnitude | |
| # Options: "none", "timestep", "sample", "sample_channel" | |
| self.loss_normalization = loss_normalization | |
| self.loss_norm_eps = loss_norm_eps | |
| # Inpainting | |
| self.inpainting_config = inpainting_config | |
| if self.inpainting_config is not None: | |
| self.inpaint_mask_kwargs = self.inpainting_config.get("mask_kwargs", {}) | |
| # Per-element schedule shift based on effective (unpadded) sequence length | |
| # Backward compat: if passed from training config, propagate to model | |
| if use_effective_length_for_schedule and not self.diffusion.use_effective_length_for_schedule: | |
| import warnings | |
| warnings.warn("use_effective_length_for_schedule in training config is deprecated. Move to model.diffusion config.", DeprecationWarning) | |
| self.diffusion.use_effective_length_for_schedule = use_effective_length_for_schedule | |
| self.use_effective_length_for_schedule = self.diffusion.use_effective_length_for_schedule | |
| self.sample_rate = sample_rate | |
| self.sample_size = sample_size | |
| # FSDP | |
| self.use_fsdp = False | |
| # Validation | |
| self.validation_timesteps = validation_timesteps | |
| self.validation_step_outputs = {} | |
| for validation_timestep in self.validation_timesteps: | |
| self.validation_step_outputs[f'val/loss_{validation_timestep:.1f}'] = [] | |
| def configure_optimizers(self): | |
| diffusion_opt_config = self.optimizer_configs['diffusion'] | |
| if self.lora_config is not None: | |
| opt_params = [*get_lora_params(self.diffusion.model), *get_lora_params(self.diffusion.conditioner)] | |
| elif diffusion_opt_config['optimizer'].get('type') == 'MuonAdamW': | |
| # Pass (name, param) tuples so MuonAdamW can match fused layer patterns | |
| opt_params = [(n, p) for n, p in self.diffusion.named_parameters() if p.requires_grad] | |
| else: | |
| # Only include parameters that require gradients (excludes frozen pretransform, conditioner, etc.) | |
| opt_params = [p for p in self.diffusion.parameters() if p.requires_grad] | |
| opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], opt_params) | |
| if "scheduler" in diffusion_opt_config: | |
| sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff) | |
| sched_diff_config = { | |
| "scheduler": sched_diff, | |
| "interval": "step" | |
| } | |
| return [opt_diff], [sched_diff_config] | |
| return [opt_diff] | |
| def training_step(self, batch, batch_idx): | |
| reals, metadata = batch | |
| p = Profiler() | |
| if reals.ndim == 4 and reals.shape[0] == 1: | |
| reals = reals[0] | |
| diffusion_input = reals | |
| p.tick("setup") | |
| #with torch.amp.autocast(device_type="cuda"): | |
| conditioning = self.diffusion.conditioner(metadata, self.device) | |
| # Create batch tensor of padding masks from the metadata | |
| # If padding_mask not provided, assume all positions are valid (no padding) | |
| if all("padding_mask" in md for md in metadata): | |
| padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) # Shape (batch_size, sequence_length) | |
| else: | |
| # All-True mask: everything is signal, no padding | |
| padding_masks = torch.ones(diffusion_input.shape[0], diffusion_input.shape[-1], dtype=torch.bool, device=self.device) | |
| p.tick("conditioning") | |
| if self.diffusion.pretransform is not None: | |
| self.diffusion.pretransform.to(self.device) | |
| if not self.pre_encoded: | |
| with torch.cuda.amp.autocast(), torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): | |
| self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad) | |
| diffusion_input = self.diffusion.pretransform.encode(diffusion_input) | |
| p.tick("pretransform") | |
| padding_masks = resize_padding_mask(padding_masks, diffusion_input.shape[-1]) | |
| else: | |
| # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run | |
| if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: | |
| diffusion_input = diffusion_input / self.diffusion.pretransform.scale | |
| if padding_masks.shape[-1] != diffusion_input.shape[-1]: | |
| padding_masks = resize_padding_mask(padding_masks, diffusion_input.shape[-1]) | |
| if self.timestep_sampler == "uniform": | |
| # Draw uniformly distributed continuous timesteps | |
| t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) | |
| elif self.timestep_sampler == "logit_normal": | |
| t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) | |
| elif self.timestep_sampler == "trunc_logit_normal": | |
| # Draw from logistic truncated normal distribution | |
| t = truncated_logistic_normal_rescaled(reals.shape[0]).to(self.device) | |
| # Flip the distribution | |
| t = 1 - t | |
| elif self.timestep_sampler == "log_snr": | |
| t = sample_timesteps_logsnr(reals.shape[0], mean_logsnr=self.mean_logsnr, std_logsnr=self.std_logsnr).to(self.device) | |
| elif self.timestep_sampler == "log_snr_uniform": | |
| t = sample_timesteps_logsnr_uniform(reals.shape[0], min_logsnr=self.min_logsnr, max_logsnr=self.max_logsnr).to(self.device) | |
| else: | |
| raise ValueError(f"Invalid timestep_sampler: {self.timestep_sampler}") | |
| if self.diffusion.dist_shift is not None: | |
| # Compute sequence length for schedule shift | |
| if self.use_effective_length_for_schedule: | |
| # Use per-element effective lengths derived from seconds_total (rounded up) | |
| # This matches inference which computes effective length from seconds_total conditioning | |
| # Fall back to padding_masks.sum() if seconds_total is not available | |
| if all("seconds_total" in md for md in metadata): | |
| downsampling_ratio = self.diffusion.pretransform.downsampling_ratio if self.diffusion.pretransform is not None else 1 | |
| effective_seq_len = torch.tensor( | |
| [int(math.ceil(int(md["seconds_total"] * self.sample_rate) / downsampling_ratio)) for md in metadata], | |
| device=self.device | |
| ) | |
| else: | |
| # Fallback: use padding mask sum | |
| effective_seq_len = padding_masks.sum(dim=-1) | |
| else: | |
| # Use total sequence length (original behavior) | |
| effective_seq_len = diffusion_input.shape[2] | |
| # Shift the distribution | |
| t = self.diffusion.dist_shift.shift(t, effective_seq_len) | |
| if self.p_one_shot > 0: | |
| # Set t to 1 with probability p_one_shot | |
| t = torch.where(torch.rand_like(t) < self.p_one_shot, torch.ones_like(t), t) | |
| # Calculate the noise schedule parameters for those timesteps | |
| if self.diffusion_objective in ["rectified_flow", "rf_denoiser"]: | |
| alphas, sigmas = 1-t, t | |
| # Combine the ground truth data and the noise | |
| alphas = alphas[:, None, None] | |
| sigmas = sigmas[:, None, None] | |
| noise = torch.randn_like(diffusion_input) | |
| # Minibatch OT coupling: find optimal noise permutation for straighter transport paths | |
| # Based on MelodyFlow (arXiv:2407.03648v2) Section 2.5.2 | |
| # Uses GPU-only Sinkhorn approximation to avoid CPU sync | |
| if self.ot_coupling and diffusion_input.shape[0] > 1: | |
| with torch.no_grad(): | |
| # Flatten to [batch, features] for distance computation | |
| data_flat = diffusion_input.reshape(diffusion_input.shape[0], -1) | |
| noise_flat = noise.reshape(noise.shape[0], -1) | |
| # Squared L2 cost via matmul (faster than cdist, same optimal assignment) | |
| aa = (data_flat * data_flat).sum(dim=1, keepdim=True) | |
| bb = (noise_flat * noise_flat).sum(dim=1, keepdim=True) | |
| cost_matrix = aa + bb.T - 2.0 * (data_flat @ noise_flat.T) | |
| # Sinkhorn assignment (GPU-only, no CPU sync) | |
| log_P = -cost_matrix / cost_matrix.detach().mean() # normalize for numerical stability | |
| for _ in range(20): | |
| log_P = log_P - torch.logsumexp(log_P, dim=1, keepdim=True) | |
| log_P = log_P - torch.logsumexp(log_P, dim=0, keepdim=True) | |
| # Sequential assignment from soft permutation matrix (guarantees valid permutation) | |
| P = log_P.exp() | |
| B = P.shape[0] | |
| col_indices = torch.empty(B, dtype=torch.long, device=P.device) | |
| used = torch.zeros(B, dtype=torch.bool, device=P.device) | |
| for i in range(B): | |
| P[i, used] = -1 | |
| col_indices[i] = P[i].argmax() | |
| used[col_indices[i]] = True | |
| noise = noise[col_indices] | |
| noised_inputs = diffusion_input * alphas + noise * sigmas | |
| if self.diffusion_objective == "v": | |
| targets = noise * alphas - diffusion_input * sigmas | |
| elif self.diffusion_objective in ["rectified_flow", "rf_denoiser"]: | |
| targets = noise - diffusion_input | |
| p.tick("noise") | |
| extra_args = {} | |
| # Compute downsampling ratio for attention mask creation | |
| downsampling_ratio = self.diffusion.pretransform.downsampling_ratio if self.diffusion.pretransform is not None else 1 | |
| # Create augmented padding mask with random silence extension | |
| if self.mask_padding_attention and self.silence_extension_scale_seconds > 0: | |
| augmented_padding_mask = create_augmented_padding_mask( | |
| padding_masks, | |
| silence_extension_scale_seconds=self.silence_extension_scale_seconds, | |
| sample_rate=self.sample_rate, | |
| downsampling_ratio=downsampling_ratio, | |
| ) | |
| else: | |
| augmented_padding_mask = padding_masks | |
| # Loss mask defines signal vs padding regions for loss computation | |
| # - mask_loss_weight controls padding contribution (0 = signal only) | |
| # - When mask_padding_attention=True: only compute loss on signal (padding saw no attention) | |
| loss_mask = augmented_padding_mask.to(torch.bool) | |
| # Pass padding mask for attention masking - model handles prepend extension | |
| if self.mask_padding_attention: | |
| extra_args["padding_mask"] = augmented_padding_mask | |
| if self.inpainting_config is not None: | |
| # Max mask size is the full sequence length | |
| max_mask_length = diffusion_input.shape[2] | |
| # Create a mask of random length for a random slice of the input | |
| inpaint_masked_input, inpaint_mask = random_inpaint_mask(diffusion_input, padding_masks=augmented_padding_mask, mask_padding=self.mask_padding_attention, **self.inpaint_mask_kwargs) | |
| conditioning['inpaint_mask'] = [inpaint_mask] | |
| conditioning['inpaint_masked_input'] = [inpaint_masked_input] | |
| # Only compute loss on inpainted region (where model is generating) | |
| loss_mask = loss_mask & ~inpaint_mask.squeeze(1).to(torch.bool) | |
| output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args) | |
| p.tick("diffusion") | |
| if self.log_loss_info: | |
| # Loss debugging logs | |
| num_loss_buckets = 10 | |
| bucket_size = 1 / num_loss_buckets | |
| loss_all = F.mse_loss(output, targets, reduction="none") | |
| sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze() | |
| # gather loss_all across all GPUs | |
| loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") | |
| # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size | |
| loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) | |
| # Log bucketed losses with corresponding sigma bucket values, if it's not NaN | |
| debug_log_dict = { | |
| f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) | |
| } | |
| self.log_dict(debug_log_dict) | |
| p.tick("loss_debug") | |
| # Compute std only over non-padded positions when masking is active | |
| if loss_mask is not None and self.mask_padding_attention: | |
| mask_expanded = loss_mask.unsqueeze(1) # [B, 1, T] | |
| std_data = diffusion_input[mask_expanded.expand_as(diffusion_input)].std() | |
| std_targets = targets[mask_expanded.expand_as(targets)].std().detach() | |
| else: | |
| std_data = diffusion_input.std() | |
| std_targets = targets.std().detach() | |
| log_dict = { | |
| 'train/std_data': std_data, | |
| 'train/std_targets': std_targets, | |
| 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] | |
| } | |
| p.tick("std_compute") | |
| # Compute normalized MSE (normalization only affects non-"none" modes) | |
| mse_loss_full = compute_normalized_mse(output, targets, loss_mask, self.loss_normalization, self.loss_norm_eps) | |
| p.tick("mse_loss") | |
| # Compute loss with signal/padding separation (returns already-detached metrics) | |
| loss, signal_mean, padding_mean = compute_masked_loss( | |
| mse_loss_full, loss_mask, self.mask_padding_attention, self.mask_loss_weight | |
| ) | |
| mse_loss = loss | |
| p.tick("masked_loss") | |
| # When attention masking is on, compute_masked_loss excludes everything outside | |
| # loss_mask (which now excludes inpaint context). Add context reconstruction loss | |
| # so the model learns to preserve context regions during inpainting. | |
| # (When mask_padding_attention=False, context is already included via mask_loss_weight.) | |
| context_loss_mean = torch.tensor(0.0, device=loss.device) | |
| if (self.inpainting_config is not None | |
| and self.mask_padding_attention | |
| and self.mask_loss_weight > 0): | |
| # Context = inpaint_mask=1 (keep) AND padding_mask=1 (real audio, not padding) | |
| inpaint_context = inpaint_mask.squeeze(1).to(torch.bool) & augmented_padding_mask.to(torch.bool) | |
| n_ctx = inpaint_context.sum(dim=1) * mse_loss_full.shape[1] # per-sample count | |
| if n_ctx.sum() > 0: | |
| context_vals = torch.where(inpaint_context.unsqueeze(1), mse_loss_full, 0.0) | |
| context_loss_mean = (context_vals.sum(dim=(1, 2)) / (n_ctx + 1e-8)).mean() | |
| loss = loss + context_loss_mean * self.mask_loss_weight | |
| # Log separate signal/padding/context losses for monitoring | |
| log_dict["train/mse_signal"] = signal_mean | |
| log_dict["train/mse_masked_loss"] = padding_mean | |
| log_dict["train/mse_context_loss"] = context_loss_mean.detach() | |
| log_dict["train/mse_loss"] = mse_loss.detach() | |
| log_dict["train/loss"] = loss.detach() | |
| # Stash for external callbacks (e.g. loss-by-timestep logging) | |
| self._last_t = t.detach() | |
| self._last_per_elem_loss = mse_loss_full.detach().mean(dim=(1, 2)) | |
| self._staggered_logger.log(log_dict, self) | |
| #p.tick("log_dict") | |
| #print(f"Profiler: {p}") | |
| return loss | |
| def validation_step(self, batch, batch_idx): | |
| reals, metadata = batch | |
| if reals.ndim == 4 and reals.shape[0] == 1: | |
| reals = reals[0] | |
| diffusion_input = reals | |
| with torch.amp.autocast("cuda"), torch.no_grad(): | |
| conditioning = self.diffusion.conditioner(metadata, self.device) | |
| # Create batch tensor of padding masks from the metadata | |
| if all("padding_mask" in md for md in metadata): | |
| padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) | |
| else: | |
| padding_masks = torch.ones(diffusion_input.shape[0], diffusion_input.shape[-1], dtype=torch.bool, device=self.device) | |
| if self.diffusion.pretransform is not None: | |
| self.diffusion.pretransform.to(self.device) | |
| if not self.pre_encoded: | |
| with torch.amp.autocast("cuda"), torch.no_grad(): | |
| self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad) | |
| diffusion_input = self.diffusion.pretransform.encode(diffusion_input) | |
| padding_masks = resize_padding_mask(padding_masks, diffusion_input.shape[-1]) | |
| else: | |
| # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run | |
| if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: | |
| diffusion_input = diffusion_input / self.diffusion.pretransform.scale | |
| if padding_masks.shape[-1] != diffusion_input.shape[-1]: | |
| padding_masks = resize_padding_mask(padding_masks, diffusion_input.shape[-1]) | |
| # Use padding mask directly for validation (no silence extension augmentation) | |
| loss_mask = padding_masks.to(torch.bool) | |
| extra_args = {} | |
| if self.mask_padding_attention: | |
| extra_args["padding_mask"] = padding_masks | |
| # Set up inpainting conditioning for validation (FULL_MASK: all zeros) | |
| if self.inpainting_config is not None: | |
| inpaint_mask = torch.zeros(diffusion_input.shape[0], 1, diffusion_input.shape[2], device=self.device) | |
| inpaint_masked_input = torch.zeros_like(diffusion_input) | |
| conditioning['inpaint_mask'] = [inpaint_mask] | |
| conditioning['inpaint_masked_input'] = [inpaint_masked_input] | |
| for validation_timestep in self.validation_timesteps: | |
| t = torch.full((reals.shape[0],), validation_timestep, device=self.device) | |
| # Calculate the noise schedule parameters for those timesteps | |
| if self.diffusion_objective in ["v"]: | |
| alphas, sigmas = get_alphas_sigmas(t) | |
| elif self.diffusion_objective in ["rectified_flow", "rf_denoiser"]: | |
| alphas, sigmas = 1-t, t | |
| # Combine the ground truth data and the noise | |
| alphas = alphas[:, None, None] | |
| sigmas = sigmas[:, None, None] | |
| noise = torch.randn_like(diffusion_input) | |
| noised_inputs = diffusion_input * alphas + noise * sigmas | |
| if self.diffusion_objective == "v": | |
| targets = noise * alphas - diffusion_input * sigmas | |
| elif self.diffusion_objective in ["rectified_flow", "rf_denoiser"]: | |
| targets = noise - diffusion_input | |
| with torch.amp.autocast("cuda"), torch.no_grad(): | |
| output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = 0, **extra_args) | |
| mse_loss_full = compute_normalized_mse(output, targets, loss_mask, self.loss_normalization, self.loss_norm_eps) | |
| val_loss, _, _ = compute_masked_loss( | |
| mse_loss_full, loss_mask, self.mask_padding_attention, self.mask_loss_weight | |
| ) | |
| self.validation_step_outputs[f'val/loss_{validation_timestep:.1f}'].append(val_loss.item()) | |
| def on_validation_epoch_end(self): | |
| log_dict = {} | |
| for validation_timestep in self.validation_timesteps: | |
| outputs_key = f'val/loss_{validation_timestep:.1f}' | |
| val_loss = sum(self.validation_step_outputs[outputs_key]) / len(self.validation_step_outputs[outputs_key]) | |
| # Gather losses across all GPUs | |
| val_loss = self.all_gather(val_loss).mean().item() | |
| log_metric(self.logger, outputs_key, val_loss, step=self.global_step) | |
| # Get average over all timesteps | |
| val_loss = torch.tensor([val for val in self.validation_step_outputs.values()]).mean() | |
| # Gather losses across all GPUs | |
| val_loss = self.all_gather(val_loss).mean().item() | |
| log_metric(self.logger, 'val/avg_loss', val_loss, step=self.global_step) | |
| # Reset validation losses | |
| for validation_timestep in self.validation_timesteps: | |
| self.validation_step_outputs[f'val/loss_{validation_timestep:.1f}'] = [] | |
| def export_model(self, path, use_safetensors=False): | |
| if self.diffusion_ema is not None: | |
| self.diffusion.model = self.diffusion_ema.ema_model | |
| if use_safetensors: | |
| save_file(self.diffusion.state_dict(), path) | |
| else: | |
| torch.save({"state_dict": self.diffusion.state_dict()}, path) | |
| def export_lora_safetensors(self, path): | |
| """Export LoRA weights as a safetensors file with embedded config.""" | |
| if self.lora_config is None: | |
| raise ValueError("No LoRA config -- this wrapper is not in LoRA mode") | |
| state_dict = { | |
| **get_lora_state_dict(self.diffusion.model), | |
| **get_lora_state_dict(self.diffusion.conditioner) | |
| } | |
| save_lora_safetensors(state_dict, self.lora_config, path) | |
| def on_save_checkpoint(self, checkpoint): | |
| if self.lora_config is not None: | |
| checkpoint.clear() | |
| checkpoint['state_dict'] = { | |
| **get_lora_state_dict(self.diffusion.model), | |
| **get_lora_state_dict(self.diffusion.conditioner) | |
| } | |
| checkpoint['lora_config'] = self.lora_config | |
| class DiffusionCondInpaintDemoCallback(pl.Callback): | |
| def __init__( | |
| self, | |
| demo_every=2000, | |
| demo_steps=250, | |
| sample_size=65536, | |
| sample_rate=48000, | |
| demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], | |
| demo_conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None, | |
| inpaint_demo_config: tp.Optional[tp.Dict[str, int]] = None, | |
| num_demos: int = 0, | |
| demo_dl=None, | |
| ): | |
| super().__init__() | |
| self.demo_every = demo_every | |
| self.demo_steps = demo_steps | |
| self.demo_samples = sample_size | |
| self.sample_rate = sample_rate | |
| self.demo_cfg_scales = demo_cfg_scales | |
| self.demo_conditioning = demo_conditioning or [] | |
| self.last_demo_step = -1 | |
| # Map config keys to MaskType enum | |
| self._mask_type_map = { | |
| "num_random_segments": MaskType.RANDOM_SEGMENTS, | |
| "num_full_mask": MaskType.FULL_MASK, | |
| "num_causal": MaskType.CAUSAL_MASK, | |
| "num_random_spans": MaskType.RANDOM_SPANS, | |
| } | |
| # Legacy fallback: if no inpaint_demo_config but num_demos is set, | |
| # use num_demos items with random mask sampling (old behavior) | |
| if inpaint_demo_config is not None: | |
| self.inpaint_demo_config = inpaint_demo_config | |
| self.legacy_inpaint_demos = False | |
| elif num_demos > 0: | |
| self.inpaint_demo_config = {} | |
| self.legacy_inpaint_demos = True | |
| self.legacy_num_demos = num_demos | |
| else: | |
| self.inpaint_demo_config = {} | |
| self.legacy_inpaint_demos = False | |
| # Total inpainting demos needed from batch | |
| if self.legacy_inpaint_demos: | |
| self.num_inpaint_demos = self.legacy_num_demos | |
| else: | |
| self.num_inpaint_demos = sum( | |
| self.inpaint_demo_config.get(k, 0) for k in self._mask_type_map | |
| ) | |
| if demo_dl is not None: | |
| self.demo_dl = iter(demo_dl) | |
| else: | |
| self.demo_dl = None | |
| self._teacher_demo_done = False | |
| def _generate_prompt_demos(self, module, trainer, is_rank_zero=True): | |
| """Generate full t2m demos from specified prompts (FULL_MASK).""" | |
| if not self.demo_conditioning: | |
| return [], [] | |
| demo_cond = self.demo_conditioning | |
| num_demos = len(demo_cond) | |
| demo_samples = self.demo_samples | |
| if module.diffusion.pretransform is not None: | |
| demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio | |
| # Conditioning from prompts | |
| conditioning = module.diffusion.conditioner(demo_cond, module.device) | |
| # FULL_MASK: all-zero inpaint conditioning | |
| io_channels = module.diffusion.io_channels | |
| inpaint_mask = torch.zeros(num_demos, 1, demo_samples, device=module.device) | |
| inpaint_masked_input = torch.zeros(num_demos, io_channels, demo_samples, device=module.device) | |
| conditioning['inpaint_mask'] = [inpaint_mask] | |
| conditioning['inpaint_masked_input'] = [inpaint_masked_input] | |
| cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) | |
| noise = torch.randn(num_demos, io_channels, demo_samples, device=module.device) | |
| model_dtype = next(module.diffusion.parameters()).dtype | |
| noise = noise.to(model_dtype) | |
| per_elem_trim = compute_per_elem_trim(demo_cond, self.sample_rate, margin_seconds=2) | |
| model = module.diffusion_ema.ema_model if module.diffusion_ema is not None else module.diffusion.model | |
| all_audio = [] | |
| all_context_masks = [] | |
| for cfg_scale in self.demo_cfg_scales: | |
| if is_rank_zero: | |
| print(f"Generating prompt demos for cfg scale {cfg_scale}") | |
| with torch.amp.autocast("cuda"): | |
| fakes = sample_diffusion( | |
| model=model, | |
| noise=noise, | |
| cond_inputs=cond_inputs, | |
| diffusion_objective=module.diffusion_objective, | |
| steps=self.demo_steps, | |
| cfg_scale=cfg_scale, | |
| conditioning=demo_cond, | |
| sample_rate=self.sample_rate, | |
| pretransform=module.diffusion.pretransform, | |
| mask_padding_attention=module.diffusion.mask_padding_attention, | |
| use_effective_length_for_schedule=module.diffusion.use_effective_length_for_schedule, | |
| headroom_seconds=5.0, | |
| dist_shift=module.diffusion.sampling_dist_shift, | |
| batch_cfg=True, | |
| disable_tqdm=not is_rank_zero, | |
| decode=True | |
| ) | |
| fakes = trim_and_concat(fakes, per_elem_trim) | |
| all_audio.append(fakes) | |
| # Latent-resolution all-zeros mask (no context for prompt demos), | |
| # trimmed to match the per-element audio durations | |
| ds_ratio = module.diffusion.pretransform.downsampling_ratio if module.diffusion.pretransform is not None else 1 | |
| latent_trim = [t // ds_ratio if t is not None else None for t in per_elem_trim] if per_elem_trim is not None else None | |
| latent_mask = torch.zeros(num_demos, 1, demo_samples) | |
| context_mask = trim_and_concat(latent_mask, latent_trim).squeeze(0).cpu() | |
| all_context_masks = [context_mask] * len(self.demo_cfg_scales) | |
| del noise, conditioning, cond_inputs, inpaint_mask, inpaint_masked_input | |
| torch.cuda.empty_cache() | |
| return all_audio, all_context_masks | |
| def _generate_inpaint_demos(self, module, trainer, is_rank_zero=True): | |
| """Generate inpainting demos from batch data with forced mask types.""" | |
| if self.num_inpaint_demos == 0 or self.demo_dl is None: | |
| return [], [] | |
| demo_reals, metadata = next(self.demo_dl) | |
| if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: | |
| demo_reals = demo_reals[0] | |
| demo_reals = demo_reals[:self.num_inpaint_demos] | |
| metadata = metadata[:self.num_inpaint_demos] | |
| model_dtype = next(module.diffusion.parameters()).dtype | |
| demo_reals = demo_reals.to(module.device, dtype=model_dtype) | |
| if not module.pre_encoded: | |
| if module.diffusion.pretransform is not None: | |
| module.diffusion.pretransform.to(module.device) | |
| demo_reals = module.diffusion.pretransform.encode(demo_reals) | |
| else: | |
| if hasattr(module.diffusion.pretransform, "scale") and module.diffusion.pretransform.scale != 1.0: | |
| demo_reals = demo_reals / module.diffusion.pretransform.scale | |
| padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(module.device) | |
| if padding_masks.shape[-1] != demo_reals.shape[-1]: | |
| padding_masks = resize_padding_mask(padding_masks, demo_reals.shape[-1]) | |
| mask_padding = module.diffusion.mask_padding_attention | |
| if self.legacy_inpaint_demos: | |
| # Legacy: random mask type sampling (old behavior) | |
| masked_input, mask = random_inpaint_mask( | |
| demo_reals, padding_masks=padding_masks, | |
| mask_padding=mask_padding, | |
| **module.inpaint_mask_kwargs | |
| ) | |
| else: | |
| # New: forced mask types per config | |
| all_masks = [] | |
| all_masked_inputs = [] | |
| idx = 0 | |
| for config_key, mask_type in self._mask_type_map.items(): | |
| count = self.inpaint_demo_config.get(config_key, 0) | |
| if count == 0: | |
| continue | |
| subset_reals = demo_reals[idx:idx+count] | |
| subset_padding = padding_masks[idx:idx+count] | |
| mi, m = random_inpaint_mask( | |
| subset_reals, padding_masks=subset_padding, | |
| mask_padding=mask_padding, force_mask_type=mask_type, | |
| **module.inpaint_mask_kwargs | |
| ) | |
| all_masks.append(m) | |
| all_masked_inputs.append(mi) | |
| idx += count | |
| mask = torch.cat(all_masks, dim=0) | |
| masked_input = torch.cat(all_masked_inputs, dim=0) | |
| conditioning = module.diffusion.conditioner(metadata, module.device) | |
| conditioning['inpaint_mask'] = [mask] | |
| conditioning['inpaint_masked_input'] = [masked_input] | |
| cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) | |
| demo_samples = demo_reals.shape[2] | |
| noise = torch.randn(demo_reals.shape[0], module.diffusion.io_channels, demo_samples, device=module.device) | |
| model_dtype = next(module.diffusion.parameters()).dtype | |
| noise = noise.to(model_dtype) | |
| per_elem_trim = compute_per_elem_trim(metadata, self.sample_rate, margin_seconds=2) | |
| # Trim and concatenate context mask at latent resolution, | |
| # using same trimming basis as audio (per_elem_trim // ds_ratio) | |
| ds_ratio = module.diffusion.pretransform.downsampling_ratio if module.diffusion.pretransform is not None else 1 | |
| latent_trim = [t // ds_ratio if t is not None else None for t in per_elem_trim] if per_elem_trim is not None else None | |
| # Zero out padding region in mask for display — the mask is initialized to 1, | |
| # so without mask_padding the padding frames show as false context in the overlay | |
| display_mask = mask * padding_masks.unsqueeze(1) | |
| context_mask = trim_and_concat(display_mask, latent_trim).squeeze(0).cpu() | |
| model = module.diffusion_ema.ema_model if module.diffusion_ema is not None else module.diffusion.model | |
| all_audio = [] | |
| all_context_masks = [] | |
| for cfg_scale in self.demo_cfg_scales: | |
| if is_rank_zero: | |
| print(f"Generating inpaint demos for cfg scale {cfg_scale}") | |
| with torch.amp.autocast("cuda"): | |
| fakes = sample_diffusion( | |
| model=model, | |
| noise=noise, | |
| cond_inputs=cond_inputs, | |
| diffusion_objective=module.diffusion_objective, | |
| steps=self.demo_steps, | |
| cfg_scale=cfg_scale, | |
| conditioning=metadata, | |
| sample_rate=self.sample_rate, | |
| pretransform=module.diffusion.pretransform, | |
| mask_padding_attention=module.diffusion.mask_padding_attention, | |
| use_effective_length_for_schedule=module.diffusion.use_effective_length_for_schedule, | |
| headroom_seconds=5.0, | |
| dist_shift=module.diffusion.sampling_dist_shift, | |
| batch_cfg=True, | |
| disable_tqdm=not is_rank_zero, | |
| decode=True | |
| ) | |
| fakes = trim_and_concat(fakes, per_elem_trim) | |
| all_audio.append(fakes) | |
| all_context_masks.append(context_mask) | |
| del noise, conditioning, cond_inputs, mask, masked_input, padding_masks, demo_reals | |
| torch.cuda.empty_cache() | |
| return all_audio, all_context_masks | |
| def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): | |
| if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: | |
| return | |
| is_rank_zero = get_rank() == 0 | |
| module.eval() | |
| self.last_demo_step = trainer.global_step | |
| try: | |
| # Generate both types of demos, freeing intermediates between phases | |
| prompt_audio, prompt_masks = self._generate_prompt_demos(module, trainer, is_rank_zero) | |
| torch.cuda.empty_cache() | |
| inpaint_audio, inpaint_masks = self._generate_inpaint_demos(module, trainer, is_rank_zero) | |
| torch.cuda.empty_cache() | |
| # Combine per cfg scale (prompt_audio and inpaint_audio have one entry per cfg scale) | |
| if is_rank_zero: | |
| for i, cfg_scale in enumerate(self.demo_cfg_scales): | |
| parts = [] | |
| mask_parts = [] | |
| if i < len(prompt_audio): | |
| parts.append(prompt_audio[i]) | |
| mask_parts.append(prompt_masks[i]) | |
| if i < len(inpaint_audio): | |
| parts.append(inpaint_audio[i]) | |
| mask_parts.append(inpaint_masks[i]) | |
| if not parts: | |
| continue | |
| combined_audio = torch.cat(parts, dim=-1) | |
| combined_mask = torch.cat(mask_parts, dim=-1) if mask_parts else None | |
| filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' | |
| combined_audio = combined_audio.to(torch.float32).div(torch.max(torch.abs(combined_audio))).mul(32767).to(torch.int16).cpu() | |
| torchaudio.save(filename, combined_audio, self.sample_rate) | |
| log_audio(trainer.logger, f'demo_cfg_{cfg_scale}', filename, self.sample_rate) | |
| log_image(trainer.logger, f'demo_melspec_left_cfg_{cfg_scale}', audio_spectrogram_image(combined_audio, context_mask=combined_mask)) | |
| if isinstance(trainer.logger, (WandbLogger, CometLogger)): | |
| os.remove(filename) | |
| # Teacher ODE warmup diagnostic: mirror the exact ODE warmup sample_diffusion call | |
| # and decode the target to verify teacher output quality. | |
| # Only runs on the first demo. | |
| # Generates both prompt and inpaint demos, consistent with the main callback. | |
| teacher_ref = getattr(module, '_teacher', None) or getattr(module, 'teacher_model', None) | |
| if not self._teacher_demo_done and teacher_ref is not None: | |
| self._teacher_demo_done = True | |
| if is_rank_zero: | |
| print("Generating teacher ODE warmup diagnostic") | |
| try: | |
| pretransform = module.diffusion.pretransform # Shared pretransform (not on teacher) | |
| io_channels = teacher_ref.io_channels | |
| ode_warmup_config = getattr(module, 'ode_warmup_config', {}) | |
| teacher_cfg = getattr(module, 'ode_warmup_cfg', self.demo_cfg_scales[0]) | |
| ode_steps = getattr(module, 'ode_n_sampling_steps', 20) | |
| mask_padding = module.diffusion.mask_padding_attention | |
| ds_ratio = pretransform.downsampling_ratio if pretransform is not None else 1 | |
| # --- Teacher prompt demos (FULL_MASK, same as _generate_prompt_demos) --- | |
| prompt_target = None | |
| prompt_per_elem_trim = None | |
| prompt_context_mask = None | |
| demo_cond = self.demo_conditioning | |
| if demo_cond: | |
| num_demos = len(demo_cond) | |
| demo_samples = self.demo_samples | |
| if pretransform is not None: | |
| demo_samples = demo_samples // ds_ratio | |
| with torch.no_grad(): | |
| teacher_conditioning = teacher_ref.conditioner(demo_cond, module.device) | |
| inpaint_mask = torch.zeros(num_demos, 1, demo_samples, device=module.device) | |
| inpaint_masked_input = torch.zeros(num_demos, io_channels, demo_samples, device=module.device) | |
| teacher_conditioning['inpaint_mask'] = [inpaint_mask] | |
| teacher_conditioning['inpaint_masked_input'] = [inpaint_masked_input] | |
| with torch.no_grad(): | |
| teacher_cond_inputs = teacher_ref.get_conditioning_inputs(teacher_conditioning) | |
| noise = torch.randn(num_demos, io_channels, demo_samples, device=module.device) | |
| noise = noise.to(next(teacher_ref.parameters()).dtype) | |
| prompt_per_elem_trim = compute_per_elem_trim(demo_cond, self.sample_rate, margin_seconds=2) | |
| prompt_target = sample_diffusion( | |
| model=teacher_ref.model, | |
| noise=noise, | |
| cond_inputs=teacher_cond_inputs, | |
| diffusion_objective=teacher_ref.diffusion_objective, | |
| steps=ode_steps, | |
| cfg_scale=teacher_cfg, | |
| conditioning=demo_cond, | |
| sample_rate=teacher_ref.sample_rate, | |
| pretransform=pretransform, | |
| mask_padding_attention=mask_padding, | |
| use_effective_length_for_schedule=module.diffusion.use_effective_length_for_schedule, | |
| padding_mask=None, | |
| dist_shift=teacher_ref.sampling_dist_shift, | |
| sampler_type=ode_warmup_config.get('sampler', 'dpmpp'), | |
| batch_cfg=True, | |
| disable_tqdm=not is_rank_zero, | |
| decode=False, | |
| ) | |
| prompt_latent_trim = [t // ds_ratio if t is not None else None for t in prompt_per_elem_trim] if prompt_per_elem_trim is not None else None | |
| prompt_context_mask = trim_and_concat( | |
| torch.zeros(num_demos, 1, demo_samples), prompt_latent_trim | |
| ).squeeze(0).cpu() | |
| # --- Teacher inpaint demos (same mask logic as _generate_inpaint_demos) --- | |
| inpaint_target = None | |
| inpaint_per_elem_trim = None | |
| inpaint_context_mask = None | |
| if self.num_inpaint_demos > 0 and self.demo_dl is not None: | |
| try: | |
| inpaint_reals, inpaint_metadata = next(self.demo_dl) | |
| if inpaint_reals.ndim == 4 and inpaint_reals.shape[0] == 1: | |
| inpaint_reals = inpaint_reals[0] | |
| inpaint_reals = inpaint_reals[:self.num_inpaint_demos] | |
| inpaint_metadata = inpaint_metadata[:self.num_inpaint_demos] | |
| inpaint_reals = inpaint_reals.to(module.device) | |
| if not module.pre_encoded: | |
| if pretransform is not None: | |
| inpaint_reals = pretransform.encode(inpaint_reals) | |
| else: | |
| if hasattr(pretransform, "scale") and pretransform.scale != 1.0: | |
| inpaint_reals = inpaint_reals / pretransform.scale | |
| inpaint_padding_masks = torch.stack( | |
| [md["padding_mask"][0] for md in inpaint_metadata], dim=0 | |
| ).to(module.device) | |
| if self.legacy_inpaint_demos: | |
| masked_input, mask = random_inpaint_mask( | |
| inpaint_reals, padding_masks=inpaint_padding_masks, | |
| mask_padding=mask_padding, **module.inpaint_mask_kwargs | |
| ) | |
| else: | |
| all_masks = [] | |
| all_masked_inputs = [] | |
| idx = 0 | |
| for config_key, mask_type in self._mask_type_map.items(): | |
| count = self.inpaint_demo_config.get(config_key, 0) | |
| if count == 0: | |
| continue | |
| mi, m = random_inpaint_mask( | |
| inpaint_reals[idx:idx+count], | |
| padding_masks=inpaint_padding_masks[idx:idx+count], | |
| mask_padding=mask_padding, force_mask_type=mask_type, | |
| **module.inpaint_mask_kwargs | |
| ) | |
| all_masks.append(m) | |
| all_masked_inputs.append(mi) | |
| idx += count | |
| mask = torch.cat(all_masks, dim=0) | |
| masked_input = torch.cat(all_masked_inputs, dim=0) | |
| with torch.no_grad(): | |
| inpaint_teacher_cond = teacher_ref.conditioner(inpaint_metadata, module.device) | |
| inpaint_teacher_cond['inpaint_mask'] = [mask] | |
| inpaint_teacher_cond['inpaint_masked_input'] = [masked_input] | |
| with torch.no_grad(): | |
| inpaint_cond_inputs = teacher_ref.get_conditioning_inputs(inpaint_teacher_cond) | |
| inpaint_samples = inpaint_reals.shape[2] | |
| inpaint_noise = torch.randn( | |
| inpaint_reals.shape[0], io_channels, inpaint_samples, device=module.device | |
| ).to(next(teacher_ref.parameters()).dtype) | |
| inpaint_per_elem_trim = compute_per_elem_trim(inpaint_metadata, self.sample_rate, margin_seconds=2) | |
| inpaint_target = sample_diffusion( | |
| model=teacher_ref.model, | |
| noise=inpaint_noise, | |
| cond_inputs=inpaint_cond_inputs, | |
| diffusion_objective=teacher_ref.diffusion_objective, | |
| steps=ode_steps, | |
| cfg_scale=teacher_cfg, | |
| conditioning=inpaint_metadata, | |
| sample_rate=teacher_ref.sample_rate, | |
| pretransform=pretransform, | |
| mask_padding_attention=mask_padding, | |
| use_effective_length_for_schedule=module.diffusion.use_effective_length_for_schedule, | |
| padding_mask=None, | |
| dist_shift=teacher_ref.sampling_dist_shift, | |
| sampler_type=ode_warmup_config.get('sampler', 'dpmpp'), | |
| batch_cfg=True, | |
| disable_tqdm=not is_rank_zero, | |
| decode=False, | |
| ) | |
| # Context mask for overlay (same as _generate_inpaint_demos) | |
| display_mask = mask * inpaint_padding_masks.unsqueeze(1) | |
| inpaint_latent_trim = [t // ds_ratio if t is not None else None for t in inpaint_per_elem_trim] if inpaint_per_elem_trim is not None else None | |
| inpaint_context_mask = trim_and_concat(display_mask, inpaint_latent_trim).squeeze(0).cpu() | |
| except StopIteration: | |
| if is_rank_zero: | |
| print("Teacher diagnostic: no inpaint batch available from demo_dl") | |
| # --- Combine and log (same pattern as main callback) --- | |
| if is_rank_zero: | |
| parts = [] | |
| mask_parts = [] | |
| if prompt_target is not None: | |
| decoded_prompt = pretransform.decode(prompt_target.float()) | |
| decoded_prompt = trim_and_concat(decoded_prompt, prompt_per_elem_trim) | |
| parts.append(decoded_prompt) | |
| mask_parts.append(prompt_context_mask) | |
| if inpaint_target is not None: | |
| decoded_inpaint = pretransform.decode(inpaint_target.float()) | |
| decoded_inpaint = trim_and_concat(decoded_inpaint, inpaint_per_elem_trim) | |
| parts.append(decoded_inpaint) | |
| mask_parts.append(inpaint_context_mask) | |
| if parts: | |
| combined_audio = torch.cat(parts, dim=-1) | |
| combined_mask = torch.cat(mask_parts, dim=-1) if mask_parts else None | |
| filename = f'demo_teacher_target_{trainer.global_step:08}.wav' | |
| combined_audio = combined_audio.to(torch.float32).div(torch.max(torch.abs(combined_audio))).mul(32767).to(torch.int16).cpu() | |
| torchaudio.save(filename, combined_audio, self.sample_rate) | |
| log_audio(trainer.logger, f'demo_teacher_target', filename, self.sample_rate) | |
| log_image(trainer.logger, f'demo_teacher_target_melspec', audio_spectrogram_image(combined_audio, context_mask=combined_mask)) | |
| os.remove(filename) | |
| del prompt_target, inpaint_target | |
| except Exception as e: | |
| if is_rank_zero: | |
| print(f"Teacher ODE warmup diagnostic failed: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| except Exception as e: | |
| if is_rank_zero: | |
| print(f'{type(e).__name__}: {e}') | |
| raise e | |
| finally: | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| module.train() |