# Commit 6215e7d (verified) by owenisas: "Vendor stable-audio-3 for ZeroGPU"
import math
import random
import torch
from torch import nn
from typing import Optional, Tuple
from torchaudio import transforms as T
class PadCrop(nn.Module):
    """Crop or zero-pad a ``[channels, samples]`` signal to exactly ``n_samples``.

    Args:
        n_samples: target length of the last axis.
        randomize: when True the crop window starts at an offset drawn with
            ``torch.randint`` from ``[0, max(0, samples - n_samples)]``;
            when False the window always starts at sample 0.
    """

    def __init__(self, n_samples, randomize=True):
        super().__init__()
        self.n_samples = n_samples
        self.randomize = randomize

    def __call__(self, signal):
        channels, length = signal.shape
        target = self.n_samples

        if self.randomize:
            # For signals no longer than the target this range is [0, 1) -> offset 0.
            begin = torch.randint(0, max(0, length - target) + 1, []).item()
        else:
            begin = 0

        # Copy the (possibly shorter) window into a zero buffer of the target length.
        keep = min(length, target)
        padded = signal.new_zeros([channels, target])
        padded[:, :keep] = signal[:, begin:begin + target]
        return padded
class PadCrop_Normalized_T(nn.Module):
    """Crop/pad a ``[channels, samples]`` clip and report where the crop sits.

    Args:
        n_samples: target chunk length in samples.
        sample_rate: sample rate used to express the offset/total in seconds.
        randomize: pick a random crop offset (via ``random.randint``) when the
            source is longer than ``n_samples``; otherwise crop from 0.
        pad: when the source is shorter than ``n_samples``, zero-pad it to the
            target length (True) or return it at its natural length (False).
    """

    def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True, pad: bool = True):
        super().__init__()
        self.n_samples = n_samples
        self.sample_rate = sample_rate
        self.randomize = randomize
        self.pad = pad

    def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int, torch.Tensor]:
        """Return ``(chunk, t_start, t_end, seconds_start, seconds_total, padding_mask)``.

        ``t_start``/``t_end`` are the crop bounds divided by
        ``max(source_len, n_samples)``; ``padding_mask`` has 1 for real samples
        and 0 for zero padding, in ``source``'s dtype and device.
        """
        n_channels, n_samples = source.shape
        target = self.n_samples

        slack = max(0, n_samples - target)
        if self.randomize and n_samples > target:
            offset = random.randint(0, slack)
        else:
            offset = 0

        # slack + target == max(n_samples, target)
        span = slack + target
        t_start = offset / span
        t_end = (offset + target) / span

        seconds_start = math.floor(offset / self.sample_rate)
        seconds_total = math.ceil(n_samples / self.sample_rate)

        like = dict(dtype=source.dtype, device=source.device)
        if n_samples < target and self.pad:
            # Short clip: zero-pad to the target length and mask the padding out.
            chunk = torch.zeros(n_channels, target, **like)
            chunk[:, :n_samples] = source
            padding_mask = torch.zeros(target, **like)
            padding_mask[:n_samples] = 1
        elif n_samples < target:
            # Short clip with padding disabled: keep its natural length.
            chunk = source
            padding_mask = torch.ones(n_samples, **like)
        else:
            # Long enough: slicing returns a view, no copy.
            chunk = source[:, offset:offset + target]
            padding_mask = torch.ones(target, **like)

        return chunk, t_start, t_end, seconds_start, seconds_total, padding_mask
def strip_trailing_silence(audio, sample_rate, threshold_db=-60, min_silence_duration=0.1):
    """Strip silence from the end of an audio tensor.

    The signal is analysed in 10 ms frames (at least one sample each). A frame
    counts as active when its peak absolute amplitude across all channels is
    above ``threshold_db`` (dB relative to amplitude 1.0). Everything after the
    last active frame is removed, but only when that trailing stretch lasts at
    least ``min_silence_duration`` seconds.

    Args:
        audio: tensor [channels, samples]
        sample_rate: audio sample rate
        threshold_db: dB threshold at or below which a frame is considered silent
        min_silence_duration: minimum trailing silence duration in seconds to strip

    Returns:
        ``audio[:, :content_end]`` when enough trailing silence is found.
        ``audio`` unchanged when the trailing silence is too short, or when the
        clip is shorter than one frame.
        ``audio[:, :0]`` (empty) when every frame is silent.
    """
    n_samples = audio.shape[-1]
    hop_length = max(1, int(sample_rate * 0.01))  # 10ms frames
    min_silence_samples = int(sample_rate * min_silence_duration)

    # Clips shorter than a single frame are returned untouched.
    if n_samples < hop_length:
        return audio

    # Ceil-divide so the trailing partial frame is analysed too. With floor
    # division, sound in the last < hop_length samples was ignored and could be
    # cut off (or the whole clip reported as silent).
    n_frames = -(-n_samples // hop_length)
    pad = n_frames * hop_length - n_samples

    # Work in float32 for precision; zero padding is silent, so it cannot mark
    # a frame as active.
    audio_f = audio.float()
    if pad:
        audio_f = nn.functional.pad(audio_f, (0, pad))

    frames = audio_f.reshape(audio_f.shape[0], n_frames, hop_length)
    frame_peak = frames.abs().amax(dim=(0, 2))  # [n_frames] - max across channels and samples
    frame_db = 20 * torch.log10(frame_peak + 1e-10)

    # Find last frame above threshold
    active = (frame_db > threshold_db).nonzero(as_tuple=True)[0]
    if active.numel() == 0:
        # Entire audio is silent
        return audio[:, :0]

    last_active_frame = active[-1].item()
    content_end = min((last_active_frame + 1) * hop_length, n_samples)

    # Only strip if trailing silence is long enough
    if n_samples - content_end < min_silence_samples:
        return audio
    return audio[:, :content_end]
class PhaseFlipper(nn.Module):
    """Negate the signal with probability ``p`` (random polarity inversion).

    Uses ``random.random()``; when no flip happens the input object itself is
    returned.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def __call__(self, signal):
        should_flip = random.random() < self.p
        if should_flip:
            return -signal
        return signal
class Mono(nn.Module):
    """Down-mix to one channel by averaging over dim 0 (keeping that dim).

    Tensors with fewer than two dimensions are returned as-is.
    """

    def __call__(self, signal):
        if signal.dim() <= 1:
            return signal
        return signal.mean(dim=0, keepdim=True)
class Stereo(nn.Module):
    """Coerce a signal to two channels.

    - 1-D ``[s]``: duplicated to ``[2, s]``.
    - ``[1, s]``: duplicated to ``[2, s]``.
    - ``[c, s]`` with ``c > 2``: truncated to the first two channels.
    - Anything else (already stereo, or more than 2 dims): returned unchanged.
    """

    def __call__(self, signal):
        ndim = signal.dim()
        if ndim == 1:
            return signal.unsqueeze(0).repeat(2, 1)
        if ndim == 2:
            channels = signal.shape[0]
            if channels == 1:
                return signal.repeat(2, 1)
            if channels > 2:
                return signal[:2, :]
        return signal
class VolumeNorm(nn.Module):
"Volume normalization and augmentation of a signal [LUFS standard]"
def __init__(self, params=[-16, 2], sample_rate=16000, energy_threshold=1e-6):
super().__init__()
self.loudness = T.Loudness(sample_rate)
self.value = params[0]
self.gain_range = [-params[1], params[1]]
self.energy_threshold = energy_threshold
def __call__(self, signal):
"""
signal: torch.Tensor [channels, time]
"""
# avoid do normalisation for silence
energy = torch.mean(signal**2)
if energy < self.energy_threshold:
return signal
input_loudness = self.loudness(signal)
# Generate a random target loudness within the specified range
target_loudness = self.value + (torch.rand(1).item() * (self.gain_range[1] - self.gain_range[0]) + self.gain_range[0])
delta_loudness = target_loudness - input_loudness
gain = torch.pow(10.0, delta_loudness / 20.0)
output = gain * signal
# Check for potentially clipped samples
if torch.max(torch.abs(output)) >= 1.0:
output = self.declip(output)
return output
def declip(self, signal):
"""
Declip the signal by scaling down if any samples are clipped
"""
max_val = torch.max(torch.abs(signal))
if max_val > 1.0:
signal = signal / max_val
signal *= 0.95
return signal
def create_padding_mask_from_lengths(
    valid_lengths: torch.Tensor,
    total_seq_len: int,
) -> torch.Tensor:
    """
    Build a boolean mask marking the valid (non-padded) prefix of each row.

    Args:
        valid_lengths: Tensor of shape (batch_size,) holding each row's valid length
        total_seq_len: Width of the mask (full sequence length)
    Returns:
        Boolean tensor (batch_size, total_seq_len); entry [b, t] is True iff
        t < valid_lengths[b]. Created on valid_lengths' device.
    """
    steps = torch.arange(total_seq_len, device=valid_lengths.device)
    # Broadcast (1, T) positions against (B, 1) lengths -> (B, T).
    return steps[None, :] < valid_lengths[:, None]
def compute_effective_seq_len_from_conditioning(
    conditioning: list,
    sample_rate: int,
    downsampling_ratio: int = 1,
    device: str = "cuda"
) -> Optional[torch.Tensor]:
    """
    Turn each conditioning dict's ``seconds_total`` into a latent sequence length.

    Each length is ``ceil(int(seconds_total * sample_rate) / downsampling_ratio)``.

    Args:
        conditioning: List of conditioning dicts, one per batch element
        sample_rate: Audio sample rate
        downsampling_ratio: Pretransform downsampling ratio (1 if no pretransform)
        device: Device for the returned tensor
    Returns:
        float32 tensor of shape (batch_size,), or None when conditioning is None,
        empty, or any dict lacks ``seconds_total`` (callers then use full length).
    """
    if conditioning is None or not any("seconds_total" in item for item in conditioning):
        return None

    # None marks an element without seconds_total; one such marker disables the result.
    lengths = [
        math.ceil(int(item["seconds_total"] * sample_rate) / downsampling_ratio)
        if "seconds_total" in item
        else None
        for item in conditioning
    ]
    if None in lengths:
        return None

    return torch.tensor(lengths, dtype=torch.float32, device=device)