Spaces:
Running on Zero
Running on Zero
| import torch | |
| import typing as tp | |
| from tqdm import trange, tqdm | |
| import torch.distributions as dist | |
| from ..data.utils import create_padding_mask_from_lengths, compute_effective_seq_len_from_conditioning | |
def build_schedule(
    steps: int,
    sigma_max: float = 1.0,
    dist_shift=None,
    effective_seq_len: tp.Union[int, torch.Tensor, None] = None,
    fallback_seq_len: tp.Optional[int] = None,
    include_endpoint: bool = True,
    device: tp.Union[str, torch.device] = "cpu",
) -> torch.Tensor:
    """Build a timestep schedule for diffusion sampling.

    The base schedule is a linear ramp from ``sigma_max`` down to 0 over
    ``steps + 1`` points; the trailing 0 is dropped when ``include_endpoint``
    is False. If ``dist_shift`` is given, the ramp is passed through
    ``dist_shift.shift(t, seq_len)`` and the first entry of the result is then
    reset to ``sigma_max`` so the schedule still starts at the requested
    noise level.

    Args:
        steps: Number of sampling steps.
        sigma_max: Starting noise level (first value of the schedule).
        dist_shift: Optional object exposing ``shift(t, seq_len)``; used to
            warp the linear schedule. ``None`` leaves the ramp untouched.
        effective_seq_len: Sequence length handed to ``dist_shift``. Either a
            scalar or a tensor (presumably of shape ``(batch_size,)`` for
            per-element schedules - confirm against the shift implementation).
            Values are clamped to at least 1.
        fallback_seq_len: Used as the sequence length when
            ``effective_seq_len`` is None.
        include_endpoint: If True the schedule has ``steps + 1`` entries and
            ends at 0; if False it has ``steps`` entries and excludes 0.
        device: Device on which the linear ramp is created.

    Returns:
        The schedule tensor. Its shape is ``(steps + 1,)`` or ``(steps,)``
        for the plain ramp; when ``dist_shift`` is used the shape is whatever
        ``shift`` returns (the last axis is treated as the time axis).
    """
    # Both variants sample the same (steps + 1)-point grid; excluding the
    # endpoint simply drops the final 0.
    t = torch.linspace(sigma_max, 0, steps + 1, device=device)
    if not include_endpoint:
        t = t[:-1]

    if dist_shift is None:
        return t

    seq_len = effective_seq_len if effective_seq_len is not None else fallback_seq_len
    if isinstance(seq_len, torch.Tensor):
        # Clamp per-element lengths so a zero length cannot reach the shift.
        seq_len = torch.clamp(seq_len, min=1)
    elif seq_len is not None:
        seq_len = max(int(seq_len), 1)

    t = dist_shift.shift(t, seq_len)

    # Pin the first timestep back to sigma_max after shifting. Work on a copy:
    # the tensor returned by shift() may be an expanded (overlapping) view or
    # may still be referenced by the shift object, and must not be written to
    # in place. Empty schedules have no first entry to pin.
    if isinstance(t, torch.Tensor) and t.dim() > 0 and t.shape[-1] > 0:
        t = t.clone()
        t[..., 0] = sigma_max

    return t
def sample_timesteps_logsnr(batch_size, mean_logsnr=-1.2, std_logsnr=2.0, device=None, generator=None):
    """Sample training timesteps by drawing Gaussian logSNR values.

    Each logSNR value is drawn from ``N(mean_logsnr, std_logsnr**2)`` and
    mapped to a timestep with ``t = sigmoid(-logsnr)``, i.e. the inverse of
    ``logsnr = ln((1 - t) / t)``.

    Args:
        batch_size (int): Number of timesteps to sample.
        mean_logsnr (float): Mean of the logSNR Gaussian.
        std_logsnr (float): Standard deviation of the logSNR Gaussian.
        device: Optional device on which to draw the samples. ``None`` keeps
            PyTorch's default device.
        generator (torch.Generator, optional): Random generator to draw from.
            ``None`` uses the global PyTorch RNG.

    Returns:
        torch.Tensor: Tensor of shape ``(batch_size,)`` with timesteps clamped
        to ``[1e-4, 1 - 1e-4]``.
    """
    logsnr = torch.randn(batch_size, generator=generator, device=device) * std_logsnr + mean_logsnr

    # logsnr = ln((1 - t) / t)  <=>  t = 1 / (1 + exp(logsnr)) = sigmoid(-logsnr)
    t = torch.sigmoid(-logsnr)

    # Keep t strictly inside (0, 1) for numerical stability downstream.
    return t.clamp(1e-4, 1 - 1e-4)
def sample_timesteps_logsnr_uniform(batch_size, min_logsnr=-6, max_logsnr=5.0, device=None, generator=None):
    """Sample training timesteps by drawing uniform logSNR values.

    Each logSNR value is drawn uniformly from ``[min_logsnr, max_logsnr)`` and
    mapped to a timestep with ``t = sigmoid(-logsnr)``, i.e. the inverse of
    ``logsnr = ln((1 - t) / t)``.

    Args:
        batch_size (int): Number of timesteps to sample.
        min_logsnr (float): Lower end of the logSNR range.
        max_logsnr (float): Upper end of the logSNR range.
        device: Optional device on which to draw the samples. ``None`` keeps
            PyTorch's default device.
        generator (torch.Generator, optional): Random generator to draw from.
            ``None`` uses the global PyTorch RNG.

    Returns:
        torch.Tensor: Tensor of shape ``(batch_size,)`` with timesteps clamped
        to ``[1e-4, 1 - 1e-4]``.
    """
    uniform = torch.rand(batch_size, generator=generator, device=device)
    logsnr = uniform * (max_logsnr - min_logsnr) + min_logsnr

    # logsnr = ln((1 - t) / t)  <=>  t = 1 / (1 + exp(logsnr)) = sigmoid(-logsnr)
    t = torch.sigmoid(-logsnr)

    # Keep t strictly inside (0, 1) for numerical stability downstream.
    return t.clamp(1e-4, 1 - 1e-4)
def truncated_logistic_normal_rescaled(shape, left_trunc=0.075, right_trunc=1):
    """Sample a truncated logistic-normal distribution rescaled to [0, 1].

    A standard normal is truncated (via inverse-CDF sampling) to the interval
    ``[logit(left_trunc), logit(right_trunc)]``, passed through a sigmoid so
    the samples lie in ``[left_trunc, right_trunc]``, and finally mapped
    linearly so that ``left_trunc -> 0`` and ``right_trunc -> 1``.

    Args:
        shape: Shape of the output tensor.
        left_trunc: Lower truncation bound, expressed in sigmoid (0..1) space.
        right_trunc: Upper truncation bound, expressed in sigmoid (0..1) space.

    Returns:
        torch.Tensor: Tensor of the requested shape with values in [0, 1]
        (up to floating-point rounding).

    Raises:
        ValueError: If the bounds do not satisfy
            ``0 <= left_trunc < right_trunc <= 1``. Such bounds would
            otherwise produce NaNs or a division by zero.
    """
    if not 0 <= left_trunc < right_trunc <= 1:
        raise ValueError(
            f"Expected 0 <= left_trunc < right_trunc <= 1, got "
            f"left_trunc={left_trunc}, right_trunc={right_trunc}"
        )

    normal_dist = dist.Normal(0, 1)

    # Draw standard-normal values and push them through the normal CDF; the
    # result is uniform on [0, 1]. (Kept as randn + cdf so the RNG stream is
    # consumed exactly as before.)
    cdf_values = normal_dist.cdf(torch.randn(shape))

    # CDF mass at the two truncation points. The bounds are converted to
    # floating-point tensors explicitly so integer arguments (e.g. the default
    # right_trunc=1) do not depend on implicit dtype promotion in torch.logit.
    lower_bound = normal_dist.cdf(torch.logit(torch.tensor(float(left_trunc))))
    upper_bound = normal_dist.cdf(torch.logit(torch.tensor(float(right_trunc))))

    # Squeeze the uniform values into [lower_bound, upper_bound] and invert
    # the CDF to obtain truncated-normal samples, then map to (0, 1).
    truncated_cdf_values = lower_bound + (upper_bound - lower_bound) * cdf_values
    truncated_samples = torch.sigmoid(normal_dist.icdf(truncated_cdf_values))

    # Linear rescale: left_trunc -> 0, right_trunc -> 1.
    return (truncated_samples - left_trunc) / (right_trunc - left_trunc)
def sample_discrete_euler(model, x, sigmas, callback=None, disable_tqdm=False, **extra_args):
    """Draw samples from a model given starting noise, using the Euler method.

    At every step the model output ``v`` is evaluated at the current timestep
    and ``x`` is advanced by ``x + (t_next - t_curr) * v``.

    Args:
        model: Callable invoked as ``model(x, t, **extra_args)`` where ``t``
            has shape ``(batch_size,)``; its output is used as the update
            direction ``v``.
        x: Starting tensor; the first axis is the batch axis. Any rank >= 1
            is supported.
        sigmas: Pre-computed schedule tensor. Shape ``(steps+1,)`` for a
            global schedule or ``(batch_size, steps+1)`` for per-element
            schedules.
        callback: Optional callable receiving a dict with keys ``'x'``,
            ``'t'``, ``'sigma'``, ``'i'`` and ``'denoised'`` once per step,
            before ``x`` is updated.
        disable_tqdm: If True, the progress bar is disabled.
        **extra_args: Forwarded unchanged to ``model``.

    Returns:
        The tensor ``x`` after all steps have been applied.
    """
    t = sigmas
    # A 2-D schedule means one schedule per batch element.
    per_element_schedule = t.dim() == 2
    t = t.to(x.device)
    num_steps = t.shape[-1] - 1

    # Shape that lets a (batch_size,) vector broadcast against x of any rank,
    # e.g. (-1, 1, 1) for a (B, C, T) input.
    bcast_shape = (-1,) + (1,) * (x.dim() - 1)

    for i in tqdm(range(num_steps), disable=disable_tqdm):
        if per_element_schedule:
            t_curr_tensor = t[:, i].to(x.dtype)  # (batch_size,)
            t_prev = t[:, i + 1].to(x.dtype)  # (batch_size,)
            dt_broadcast = (t_prev - t_curr_tensor).view(bcast_shape)
        else:
            t_curr = t[i]
            # Expand the scalar timestep to one value per batch element.
            t_curr_tensor = t_curr * torch.ones((x.shape[0],), dtype=x.dtype, device=x.device)
            # Scalar step size; broadcasts against x as-is.
            dt_broadcast = t[i + 1] - t_curr

        v = model(x, t_curr_tensor, **extra_args)

        if callback is not None:
            denoised = x - t_curr_tensor.view(bcast_shape) * v
            callback({'x': x, 't': t_curr_tensor, 'sigma': t_curr_tensor, 'i': i, 'denoised': denoised})

        x = x + dt_broadcast * v

    return x
def sample_rk4(model, x, sigmas, callback=None, disable_tqdm=False, **extra_args):
    """Draw samples from a model given starting noise, using 4th-order Runge-Kutta.

    Args:
        model: Callable invoked as ``model(x, t, **extra_args)`` where ``t``
            has shape ``(batch_size,)``.
        x: Starting tensor; the first axis is the batch axis.
        sigmas: Pre-computed schedule tensor of shape ``(steps+1,)``.
            Per-element (2-D) schedules are not supported by this sampler.
        callback: Optional callable receiving a dict with keys ``'x'``,
            ``'t'``, ``'sigma'``, ``'i'`` and ``'denoised'`` once per step,
            after the first model evaluation of that step.
        disable_tqdm: If True, the progress bar is disabled.
        **extra_args: Forwarded unchanged to ``model``.

    Returns:
        The tensor ``x`` after all steps have been applied.

    Raises:
        ValueError: If ``sigmas`` is not 1-D. A 2-D schedule would otherwise
            be iterated row by row (over the batch axis instead of the time
            axis) and silently produce wrong results.
    """
    t = sigmas.to(x.device)
    if t.dim() != 1:
        raise ValueError(
            f"sample_rk4 expects a 1D schedule of shape (steps+1,), got shape {tuple(t.shape)}; "
            "per-element schedules are not supported for RK4"
        )

    # Vector of ones used to broadcast a scalar timestep to (batch_size,).
    ts = x.new_ones([x.shape[0]])

    # zip() has no length, so tell tqdm how many steps to expect.
    num_steps = max(t.shape[0] - 1, 0)
    step_pairs = zip(t[:-1], t[1:])

    for i, (t_curr, t_prev) in enumerate(tqdm(step_pairs, total=num_steps, disable=disable_tqdm)):
        t_curr_tensor = t_curr * ts
        dt = t_prev - t_curr  # negative: the schedule runs from high t to low t

        k1 = model(x, t_curr_tensor, **extra_args)

        if callback is not None:
            denoised = x - t_curr * k1
            callback({'x': x, 't': t_curr, 'sigma': t_curr, 'i': i, 'denoised': denoised})

        t_mid_tensor = (t_curr + dt / 2) * ts
        k2 = model(x + dt / 2 * k1, t_mid_tensor, **extra_args)
        k3 = model(x + dt / 2 * k2, t_mid_tensor, **extra_args)

        # Clamp t_prev to avoid evaluating the model at exactly t=0
        # (models aren't trained at t=0 and may return garbage/NaN).
        t_prev_eval = t_prev.clamp(min=1e-5)
        k4 = model(x + dt * k3, t_prev_eval * ts, **extra_args)

        x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

    return x
def sample_flow_dpmpp(model, x, sigmas, callback=None, disable_tqdm=False, **extra_args):
    """Draw samples from a model given starting noise, using DPM-Solver++ for RF models.

    The first step, and any step whose target timestep is exactly 0, use a
    first-order update. All other steps apply a second-order correction that
    combines the current and the previous denoised estimates.

    Args:
        model: Callable invoked as ``model(x, t, **extra_args)`` where ``t``
            has shape ``(batch_size,)``. The denoised estimate is computed as
            ``x - t * model_output``.
        x: Starting tensor; the first axis is the batch axis. Any rank >= 1
            is supported.
        sigmas: Pre-computed schedule tensor. Shape ``(steps+1,)`` for a
            global schedule or ``(batch_size, steps+1)`` for per-element
            schedules.
        callback: Optional callable receiving a dict with keys ``'x'``,
            ``'i'``, ``'t'``, ``'sigma'`` and ``'denoised'`` once per step,
            before ``x`` is updated.
        disable_tqdm: If True, the progress bar is disabled.
        **extra_args: Forwarded unchanged to ``model``.

    Returns:
        The tensor ``x`` after all steps have been applied.
    """
    t = sigmas
    # A 2-D schedule means one schedule per batch element.
    per_element_schedule = t.dim() == 2
    t = t.to(x.device)
    num_steps = t.shape[-1] - 1

    # Shape that lets a (batch_size,) vector broadcast against x of any rank,
    # e.g. (-1, 1, 1) for a (B, C, T) input.
    bcast_shape = (-1,) + (1,) * (x.dim() - 1)

    old_denoised = None

    def log_snr(t_val):
        # Clamp both terms so t=0 and t=1 give large finite values instead of
        # +/-inf, which would otherwise propagate NaNs.
        return ((1 - t_val).clamp(min=1e-10) / t_val.clamp(min=1e-10)).log()

    for i in trange(num_steps, disable=disable_tqdm):
        if per_element_schedule:
            t_curr = t[:, i]  # (batch_size,)
            t_next = t[:, i + 1]  # (batch_size,)
            t_curr_broadcast = t_curr.view(bcast_shape)
            t_next_broadcast = t_next.view(bcast_shape)
            t_curr_tensor = t_curr
        else:
            t_curr = t[i]
            t_next = t[i + 1]
            t_curr_broadcast = t_curr
            t_next_broadcast = t_next
            t_curr_tensor = t_curr.expand(x.shape[0])

        model_output = model(x, t_curr_tensor, **extra_args)
        denoised = x - t_curr_broadcast * model_output

        if callback is not None:
            callback({'x': x, 'i': i, 't': t_curr, 'sigma': t_curr, 'denoised': denoised})

        alpha_t = 1 - t_next_broadcast

        # DPM++ coefficient computed directly rather than through log_snr, to
        # avoid numerical issues at t=0 or t=1:
        #   (-h).expm1() = (t_next - t_curr) / [(1 - t_next) * t_curr]
        # t_next < t_curr, so this is negative. The denominator is clamped to
        # avoid division by zero at the schedule boundaries.
        dt = t_next_broadcast - t_curr_broadcast
        dpmpp_coeff = dt / ((1 - t_next_broadcast).clamp(min=1e-10) * t_curr_broadcast.clamp(min=1e-10))

        is_first_step = old_denoised is None
        is_last_step = (t_next_broadcast == 0).all() if per_element_schedule else (t_next == 0)

        # Ratio applied to the previous x in both update rules.
        x_scale = t_next_broadcast / t_curr_broadcast.clamp(min=1e-10)

        if is_first_step or is_last_step:
            # First-order update.
            x = x_scale * x - alpha_t * dpmpp_coeff * denoised
        else:
            # Second-order update. old_denoised is only set after step 0, so
            # i >= 1 here and index i - 1 is valid.
            if per_element_schedule:
                t_prev_broadcast = t[:, i - 1].view(bcast_shape)
            else:
                t_prev_broadcast = t[i - 1]

            # r = h_last / h in log-SNR space, where
            #   h      = log_snr(t_next) - log_snr(t_curr)
            #   h_last = log_snr(t_curr) - log_snr(t_prev)
            h = log_snr(t_next_broadcast) - log_snr(t_curr_broadcast)
            h_last = log_snr(t_curr_broadcast) - log_snr(t_prev_broadcast)
            r = h_last / h

            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = x_scale * x - alpha_t * dpmpp_coeff * denoised_d

        old_denoised = denoised

    return x
def sample_flow_pingpong(model, x, sigmas, callback=None, disable_tqdm=False, **extra_args):
    """Draw samples from a model given starting noise, using ping-pong sampling.

    Each step computes a denoised estimate ``x - t_curr * model(x, t_curr)``
    and then re-noises it to the next level with fresh Gaussian noise:
    ``(1 - t_next) * denoised + t_next * noise``.

    Args:
        model: Callable invoked as ``model(x, t, **extra_args)`` where ``t``
            has shape ``(batch_size,)``.
        x: Starting tensor; the first axis is the batch axis. Any rank >= 1
            is supported.
        sigmas: Pre-computed schedule tensor. Shape ``(steps+1,)`` for a
            global schedule or ``(batch_size, steps+1)`` for per-element
            schedules.
        callback: Optional callable receiving a dict with keys ``'x'``,
            ``'i'``, ``'t'``, ``'sigma'``, ``'sigma_hat'`` and ``'denoised'``
            once per step, before ``x`` is updated.
        disable_tqdm: If True, the progress bar is disabled.
        **extra_args: Forwarded unchanged to ``model``.

    Returns:
        The tensor ``x`` after all steps have been applied.
    """
    t = sigmas
    # A 2-D schedule means one schedule per batch element.
    per_element_schedule = t.dim() == 2
    t = t.to(x.device)
    num_steps = t.shape[-1] - 1

    # Shape that lets a (batch_size,) vector broadcast against x of any rank,
    # e.g. (-1, 1, 1) for a (B, C, T) input.
    bcast_shape = (-1,) + (1,) * (x.dim() - 1)

    for i in trange(num_steps, disable=disable_tqdm):
        if per_element_schedule:
            t_curr = t[:, i].to(x.dtype)  # (batch_size,)
            t_next = t[:, i + 1].to(x.dtype)  # (batch_size,)
            t_curr_broadcast = t_curr.view(bcast_shape)
            t_next_broadcast = t_next.view(bcast_shape)
            t_curr_tensor = t_curr
        else:
            t_curr = t[i].to(x.dtype)
            t_next = t[i + 1].to(x.dtype)
            t_curr_broadcast = t_curr
            t_next_broadcast = t_next
            # Expand the scalar timestep to one value per batch element.
            t_curr_tensor = t_curr * torch.ones((x.shape[0],), dtype=x.dtype, device=x.device)

        denoised = x - t_curr_broadcast * model(x, t_curr_tensor, **extra_args)

        if callback is not None:
            callback({'x': x, 'i': i, 't': t_curr, 'sigma': t_curr, 'sigma_hat': t_curr, 'denoised': denoised})

        # Re-noise the denoised estimate to the next noise level.
        x = (1 - t_next_broadcast) * denoised + t_next_broadcast * torch.randn_like(x)

    return x
def sample_diffusion(
    model,
    noise: torch.Tensor,
    cond_inputs: dict,
    diffusion_objective: str,
    steps: int,
    cfg_scale: float = 1.0,
    # Varlen support
    conditioning: tp.Optional[tp.List[dict]] = None,
    sample_rate: int = 44100,
    pretransform=None,
    mask_padding_attention: bool = False,
    use_effective_length_for_schedule: bool = False,
    headroom_seconds: float = 5.0,
    padding_mask: tp.Optional[torch.Tensor] = None,
    # Timestep schedule
    dist_shift=None,
    # Sampler options
    sampler_type: tp.Optional[str] = None,
    batch_cfg: bool = True,
    rescale_cfg: bool = False,
    # CFG options
    apg_scale: float = 1.0,
    # Init data (variation / img2img)
    init_data: tp.Optional[torch.Tensor] = None,
    init_noise_level: float = 1.0,
    # Other
    callback=None,
    disable_tqdm: bool = False,
    decode: bool = True,
    chunked_decode: tp.Optional[bool] = None,
    **sampler_kwargs
) -> torch.Tensor:
    """Unified sampling entry point for diffusion models.

    Handles varlen support (padding mask and effective sequence length),
    timestep scheduling, init-data mixing for variation/img2img, sampler
    dispatch and optional decoding through the pretransform.

    Args:
        model: The diffusion model backbone, called by the samplers as
            ``model(x, t, **kwargs)``.
        noise: Initial noise tensor of shape (B, C, T).
        cond_inputs: Pre-processed conditioning inputs; merged into the
            keyword arguments forwarded to the model.
        diffusion_objective: ``"rectified_flow"`` or ``"rf_denoiser"``. Any
            other value (including ``"v"``) raises ``ValueError`` in this
            function.
        steps: Number of sampling steps.
        cfg_scale: Classifier-free guidance scale (forwarded to the model).
        conditioning: List of conditioning dicts, used to compute the
            effective sequence length.
        sample_rate: Audio sample rate.
        pretransform: Optional pretransform; provides ``downsampling_ratio``
            and ``decode``.
        mask_padding_attention: Whether to build a padding mask from
            ``conditioning`` when none is supplied.
        use_effective_length_for_schedule: Whether to pass the effective
            sequence length to ``dist_shift``.
        headroom_seconds: Extra seconds added to the effective length when
            building the padding mask.
        padding_mask: Optional pre-computed padding mask (B, T). If provided,
            the internal mask computation is skipped.
        dist_shift: Distribution-shift object used to warp the timestep
            schedule, or None.
        sampler_type: One of ``"euler"``, ``"rk4"``, ``"dpmpp"``,
            ``"pingpong"``. Defaults to ``"pingpong"`` for ``"rf_denoiser"``
            and ``"euler"`` otherwise.
        batch_cfg: Forwarded to the model.
        rescale_cfg: Forwarded to the model.
        apg_scale: Forwarded to the model.
        init_data: Optional latent tensor (B, C, T) mixed with ``noise`` as
            ``init_data * (1 - sigma_max) + noise * sigma_max``.
        init_noise_level: Value used as ``sigma_max`` when ``init_data`` is
            given; otherwise ``sigma_max`` is 1.0.
        callback: Optional per-step callback passed to the sampler.
        disable_tqdm: Whether to disable the sampler's progress bar.
        decode: Whether to decode the result with ``pretransform``.
        chunked_decode: Passed as ``chunked=`` to ``pretransform.decode``.
        **sampler_kwargs: Additional keyword arguments forwarded to the model
            (``sigma_min``, ``sigma_max`` and ``rho`` are dropped).

    Returns:
        Generated samples: decoded output if ``decode`` is True and a
        pretransform is given, otherwise latents.

    Raises:
        ValueError: If ``diffusion_objective`` or ``sampler_type`` is not one
            of the supported values.
    """
    device = noise.device
    latent_seq_len = noise.shape[-1]

    downsampling_ratio = pretransform.downsampling_ratio if pretransform is not None else 1

    # Default sampler_type per objective
    if sampler_type is None:
        sampler_type = "pingpong" if diffusion_objective == "rf_denoiser" else "euler"

    # The conditioning-derived length feeds both the schedule shift and the
    # attention padding mask; compute it at most once.
    # NOTE(review): assumes compute_effective_seq_len_from_conditioning is
    # side-effect free, so one call can serve both consumers - confirm.
    wants_schedule_len = use_effective_length_for_schedule and conditioning is not None
    wants_mask = padding_mask is None and mask_padding_attention and conditioning is not None
    conditioning_len = None
    if wants_schedule_len or wants_mask:
        conditioning_len = compute_effective_seq_len_from_conditioning(
            conditioning, sample_rate, downsampling_ratio, device
        )

    effective_seq_len = conditioning_len if wants_schedule_len else None

    # Create padding_mask for attention if enabled (skipped when a mask was supplied)
    if wants_mask and conditioning_len is not None:
        headroom_tokens = int(headroom_seconds * sample_rate / downsampling_ratio)
        valid_lengths = (conditioning_len + headroom_tokens).clamp(max=latent_seq_len).long()
        padding_mask = create_padding_mask_from_lengths(valid_lengths, latent_seq_len)

    # sigma_max for the schedule: the init noise level when varying init_data,
    # otherwise full noise.
    sigma_max = init_noise_level if init_data is not None else 1.0

    # Mix init_data with noise for variation/img2img
    if init_data is not None:
        noise = init_data * (1 - sigma_max) + noise * sigma_max

    # Conditioning + model-level params only. disable_tqdm and callback are
    # passed explicitly to the samplers so they do not leak into model calls.
    common_kwargs = {
        **cond_inputs,
        "cfg_scale": cfg_scale,
        "batch_cfg": batch_cfg,
        "rescale_cfg": rescale_cfg,
        "padding_mask": padding_mask,
        "apg_scale": apg_scale,
        **sampler_kwargs
    }

    if diffusion_objective not in ("rectified_flow", "rf_denoiser"):
        raise ValueError(f"Unknown diffusion_objective: {diffusion_objective}")

    # Remove v-diffusion-specific kwargs that don't apply to RF
    for v_only_key in ("sigma_min", "sigma_max", "rho"):
        common_kwargs.pop(v_only_key, None)

    rf_samplers = {
        "euler": sample_discrete_euler,
        "rk4": sample_rk4,
        "dpmpp": sample_flow_dpmpp,
        "pingpong": sample_flow_pingpong,
    }
    sampler_fn = rf_samplers.get(sampler_type)
    if sampler_fn is None:
        raise ValueError(f"Unknown sampler_type for {diffusion_objective}: {sampler_type}")

    sigmas = build_schedule(
        steps=steps, sigma_max=sigma_max,
        dist_shift=dist_shift, effective_seq_len=effective_seq_len,
        fallback_seq_len=latent_seq_len, include_endpoint=True, device=device
    )

    sampled = sampler_fn(
        model, noise, sigmas=sigmas, callback=callback, disable_tqdm=disable_tqdm, **common_kwargs
    )

    # Decode if requested
    if decode and pretransform is not None:
        # Match the pretransform's parameter dtype; a pretransform without
        # parameters leaves the dtype unchanged instead of raising StopIteration.
        first_param = next(pretransform.parameters(), None)
        if first_param is not None:
            sampled = sampled.to(first_param.dtype)
        sampled = pretransform.decode(sampled, chunked=chunked_decode)

        # Zero out audio beyond valid region (padding positions decode to garbage)
        if padding_mask is not None:
            audio_mask = padding_mask.unsqueeze(1).repeat_interleave(downsampling_ratio, dim=-1)
            # Trim or pad the mask to match the decoded length
            decoded_len = sampled.shape[-1]
            if audio_mask.shape[-1] > decoded_len:
                audio_mask = audio_mask[..., :decoded_len]
            elif audio_mask.shape[-1] < decoded_len:
                audio_mask = torch.nn.functional.pad(audio_mask, (0, decoded_len - audio_mask.shape[-1]), value=False)
            sampled = sampled * audio_mask.to(sampled.dtype)

    return sampled