Spaces:
Running on Zero
Running on Zero
File size: 9,171 Bytes
6215e7d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 | import math
import random
import torch
from torch import nn
from typing import Optional, Tuple
from torchaudio import transforms as T
class PadCrop(nn.Module):
    """Crop or right-pad a ``(channels, samples)`` signal to ``n_samples``.

    With ``randomize`` enabled, the window start is drawn with
    ``torch.randint`` from ``[0, max(0, samples - n_samples)]``; otherwise
    the window starts at sample 0. Positions past the end of the input are
    left as zeros.
    """

    def __init__(self, n_samples, randomize=True):
        super().__init__()
        self.n_samples = n_samples
        self.randomize = randomize

    def __call__(self, signal):
        channels, length = signal.shape
        target = self.n_samples
        if self.randomize:
            # randint's upper bound is exclusive, hence the +1 so the last
            # valid window can be chosen (and 0 is the only choice when short).
            offset = torch.randint(0, max(0, length - target) + 1, []).item()
        else:
            offset = 0
        result = signal.new_zeros([channels, target])
        result[:, :min(length, target)] = signal[:, offset:offset + target]
        return result
class PadCrop_Normalized_T(nn.Module):
    """Crop/pad a ``(channels, samples)`` tensor to ``n_samples`` and report timing.

    Returns ``(chunk, t_start, t_end, seconds_start, seconds_total, padding_mask)``:

    * ``t_start`` / ``t_end`` locate the chunk within the source, normalised
      by ``max(samples, n_samples)``.
    * ``seconds_start`` is the floored start time and ``seconds_total`` the
      ceiled source duration, both in seconds of ``sample_rate``.
    * ``padding_mask`` has the source dtype/device, with 1 where real audio
      is present and 0 over zero padding.

    A random start (via ``random.randint``) is only drawn when ``randomize``
    is set and the source is strictly longer than ``n_samples``. Shorter
    sources are zero-padded when ``pad`` is true, otherwise returned as-is.
    """

    def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True, pad: bool = True):
        super().__init__()
        self.n_samples = n_samples
        self.sample_rate = sample_rate
        self.randomize = randomize
        self.pad = pad

    def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int, torch.Tensor]:
        n_channels, n_samples = source.shape
        target = self.n_samples
        like = dict(dtype=source.dtype, device=source.device)

        excess = max(0, n_samples - target)
        if self.randomize and n_samples > target:
            offset = random.randint(0, excess)
        else:
            offset = 0

        # Normalised position of the window within the (possibly padded) source.
        span = excess + target
        t_start = offset / span
        t_end = (offset + target) / span

        seconds_start = math.floor(offset / self.sample_rate)
        seconds_total = math.ceil(n_samples / self.sample_rate)

        if n_samples >= target:
            # Long enough: slicing returns a view of `source`, no copy.
            chunk = source[:, offset:offset + target]
            padding_mask = torch.ones(target, **like)
        elif self.pad:
            # Short: copy into a zero buffer and mark only the real samples.
            chunk = torch.zeros(n_channels, target, **like)
            chunk[:, :n_samples] = source
            padding_mask = torch.zeros(target, **like)
            padding_mask[:n_samples] = 1
        else:
            # Short and padding disabled: hand back the input unchanged.
            chunk = source
            padding_mask = torch.ones(n_samples, **like)

        return chunk, t_start, t_end, seconds_start, seconds_total, padding_mask
def strip_trailing_silence(audio, sample_rate, threshold_db=-60, min_silence_duration=0.1):
    """Strip silence from the end of an audio tensor.

    The signal is split into 10 ms frames. A frame is "active" when its peak
    absolute amplitude, taken across all channels, exceeds ``threshold_db``
    (relative to an amplitude of 1.0). Everything after the last active
    frame is trailing silence; it is removed only if it lasts at least
    ``min_silence_duration`` seconds. A clip with no active frame is treated
    as consisting entirely of trailing silence and follows the same rule.

    Args:
        audio: tensor [channels, samples]. Samples are always the last dim;
            any leading dims (or none, i.e. a 1-D signal) are accepted.
        sample_rate: audio sample rate in Hz.
        threshold_db: dB threshold at or below which audio is considered silent.
        min_silence_duration: minimum trailing silence duration in seconds to strip.

    Returns:
        ``audio[..., :content_end]`` (same dtype as the input) when the
        trailing silence is long enough, otherwise ``audio`` itself.
    """
    n_samples = audio.shape[-1]
    if n_samples == 0:
        return audio

    hop_length = max(1, int(sample_rate * 0.01))  # 10ms frames
    min_silence_samples = int(sample_rate * min_silence_duration)

    # Work in float32 and fold all leading dims into one "channel" axis.
    audio_f = audio.float().reshape(-1, n_samples)
    # Zero-pad up to a whole number of frames so the final partial frame is
    # analysed as well (previously its samples were silently ignored, so
    # content there could be cut off). Zeros cannot raise a frame's peak.
    n_frames = -(-n_samples // hop_length)  # ceil division
    audio_f = torch.nn.functional.pad(audio_f, (0, n_frames * hop_length - n_samples))
    frames = audio_f.reshape(audio_f.shape[0], n_frames, hop_length)
    frame_peak = frames.abs().amax(dim=(0, 2))  # [n_frames] - max across channels and samples
    frame_db = 20 * torch.log10(frame_peak + 1e-10)

    active = (frame_db > threshold_db).nonzero(as_tuple=True)[0]
    if active.numel() == 0:
        # Entire clip is silent: all of it counts as trailing silence.
        content_end = 0
    else:
        # The padded last frame may extend past the real end; clamp.
        content_end = min((active[-1].item() + 1) * hop_length, n_samples)

    # Only strip if trailing silence is long enough
    if (n_samples - content_end) < min_silence_samples:
        return audio
    return audio[..., :content_end]
class PhaseFlipper(nn.Module):
    """Negate the signal (polarity inversion) with probability ``p``.

    One ``random.random()`` draw is made per call; the signal is negated
    when that draw is below ``p``, and returned untouched otherwise.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def __call__(self, signal):
        should_flip = random.random() < self.p
        if should_flip:
            return -signal
        return signal
class Mono(nn.Module):
    """Average over dim 0 (keeping it as size 1) to downmix to one channel.

    Tensors with fewer than two dimensions are returned unchanged.
    """

    def __call__(self, signal):
        if signal.dim() <= 1:
            return signal
        return signal.mean(dim=0, keepdim=True)
class Stereo(nn.Module):
    """Coerce a signal to two channels.

    * ``(s,)``    -> ``(2, s)`` by duplicating the signal.
    * ``(1, s)``  -> ``(2, s)`` by duplicating the channel.
    * ``(c, s)`` with ``c > 2`` -> the first two channels.
    Anything else (already two channels, or other ranks) is returned as-is.
    """

    def __call__(self, signal):
        rank = signal.dim()
        if rank == 1:  # s -> 2, s
            return signal.unsqueeze(0).repeat(2, 1)
        if rank == 2:
            channels = signal.shape[0]
            if channels == 1:  # 1, s -> 2, s
                return signal.repeat(2, 1)
            if channels > 2:  # c, s -> 2, s
                return signal[:2, :]
        return signal
class VolumeNorm(nn.Module):
    """Volume normalization and augmentation of a signal [LUFS standard].

    Measures loudness with ``torchaudio.transforms.Loudness`` and applies a
    gain so the output targets ``params[0]`` LUFS plus a uniform random
    offset drawn from ``[-params[1], params[1]]`` dB. Inputs whose mean
    energy is below ``energy_threshold``, or whose measured loudness is not
    finite, are returned unchanged. Outputs peaking at or above 1.0 are
    passed through :meth:`declip`.
    """

    def __init__(self, params=(-16, 2), sample_rate=16000, energy_threshold=1e-6):
        # Tuple default: a list default would be one object shared by every
        # instance (only read here, but a standing hazard).
        super().__init__()
        self.loudness = T.Loudness(sample_rate)
        self.value = params[0]
        self.gain_range = [-params[1], params[1]]
        self.energy_threshold = energy_threshold

    def __call__(self, signal):
        """
        signal: torch.Tensor [channels, time]
        """
        # Skip normalisation for (near-)silent input
        energy = torch.mean(signal ** 2)
        if energy < self.energy_threshold:
            return signal

        input_loudness = self.loudness(signal)
        # Loudness gating can produce a non-finite value (e.g. NaN when every
        # block is gated out for short/quiet input); a NaN/inf gain would
        # corrupt the whole signal, so leave it untouched instead.
        if not torch.isfinite(input_loudness).all():
            return signal

        # Generate a random target loudness within the specified range
        low, high = self.gain_range
        target_loudness = self.value + (torch.rand(1).item() * (high - low) + low)
        delta_loudness = target_loudness - input_loudness
        gain = torch.pow(10.0, delta_loudness / 20.0)
        output = gain * signal

        # Check for potentially clipped samples
        if torch.max(torch.abs(output)) >= 1.0:
            output = self.declip(output)
        return output

    def declip(self, signal):
        """Rescale a signal to leave headroom below full scale.

        If the peak absolute value exceeds 1.0 the signal is first divided
        by that peak; the result is then always multiplied by 0.95. A new
        tensor is returned and the input is never modified.
        """
        max_val = torch.max(torch.abs(signal))
        if max_val > 1.0:
            signal = signal / max_val
        # Out-of-place on purpose: the former `signal *= 0.95` mutated the
        # caller's tensor whenever the division above did not run.
        return signal * 0.95
def create_padding_mask_from_lengths(
    valid_lengths: torch.Tensor,
    total_seq_len: int,
) -> torch.Tensor:
    """
    Build a boolean validity mask from per-sample sequence lengths.

    Entry ``[b, t]`` is True exactly when ``t < valid_lengths[b]``.

    Args:
        valid_lengths: Tensor of shape (batch_size,) holding each sample's valid length
        total_seq_len: Number of positions in each row of the mask

    Returns:
        Boolean tensor of shape (batch_size, total_seq_len); True = valid, False = padding
    """
    steps = torch.arange(total_seq_len, device=valid_lengths.device)
    # Broadcast a (1, T) row of positions against a (B, 1) column of lengths.
    return steps[None, :] < valid_lengths[:, None]
def compute_effective_seq_len_from_conditioning(
    conditioning: list,
    sample_rate: int,
    downsampling_ratio: int = 1,
    device: str = "cuda"
) -> Optional[torch.Tensor]:
    """
    Derive per-item latent sequence lengths from ``seconds_total`` entries.

    Each length is ``ceil(int(seconds_total * sample_rate) / downsampling_ratio)``.

    Args:
        conditioning: List of conditioning dicts, one per batch element
        sample_rate: Audio sample rate
        downsampling_ratio: Pretransform downsampling ratio (1 if no pretransform)
        device: Device to place the tensor on

    Returns:
        float32 tensor of shape (batch_size,) with latent-space lengths, or
        None when ``conditioning`` is None/empty or any dict lacks ``seconds_total``
    """
    if conditioning is None:
        return None
    # Bail out early when no dict carries timing information at all.
    if all("seconds_total" not in cond for cond in conditioning):
        return None

    def _latent_len(cond):
        if "seconds_total" not in cond:
            return None  # marker: this item has no timing info
        n_audio = int(cond["seconds_total"] * sample_rate)
        return math.ceil(n_audio / downsampling_ratio)

    lengths = [_latent_len(cond) for cond in conditioning]
    # One missing entry means no complete length vector; fall back to full length.
    if any(length is None for length in lengths):
        return None
    return torch.tensor(lengths, dtype=torch.float32, device=device)
|