| import torch |
| import torch.nn.functional as F |
| from audiotools import AudioSignal |
| from audiotools import STFTParams |
| from torch import nn |
| import typing |
| from typing import List |
|
|
class L1Loss(nn.L1Loss):
    """L1 loss between AudioSignals (or plain tensors).

    Compares ``audio_data`` by default, but any attribute of an
    AudioSignal can be selected by name.

    Parameters
    ----------
    attribute : str, optional
        Attribute of signal to compare, defaults to ``audio_data``.
    weight : float, optional
        Weight of this loss, defaults to 1.0. It is only stored on the
        instance; ``forward`` does not apply it.
    **kwargs
        Forwarded unchanged to ``torch.nn.L1Loss``.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
    """

    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
        # Set up the nn.Module state first, then attach our own attributes.
        super().__init__(**kwargs)
        self.attribute = attribute
        self.weight = weight

    def forward(self, x: AudioSignal, y: AudioSignal):
        """
        Parameters
        ----------
        x : AudioSignal
            Estimate AudioSignal (a tensor is passed through as-is).
        y : AudioSignal
            Reference AudioSignal (a tensor is passed through as-is).

        Returns
        -------
        torch.Tensor
            L1 loss between AudioSignal attributes.
        """
        # Resolve each argument on its own so that a signal/tensor mix works
        # instead of only looking at the type of ``x``.
        if isinstance(x, AudioSignal):
            x = getattr(x, self.attribute)
        if isinstance(y, AudioSignal):
            y = getattr(y, self.attribute)
        return super().forward(x, y)
|
|
|
|
class SISDRLoss(nn.Module):
    """
    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
    of estimated and reference audio signals or aligned features.

    Parameters
    ----------
    scaling : bool, optional
        Whether to use scale-invariant (True) or
        signal-to-noise ratio (False), by default True
    reduction : str, optional
        How to reduce across the batch: 'mean' or 'sum'. Any other value
        returns the unreduced per-item values. By default 'mean'
    zero_mean : bool, optional
        Zero mean the references and estimates before
        computing the loss, by default True
    clip_min : float, optional
        The minimum possible loss value. Helps network
        to not focus on making already good examples better, by default None
    weight : float, optional
        Weight of this loss, defaults to 1.0. It is only stored on the
        instance; ``forward`` does not apply it.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
    """

    def __init__(
        self,
        scaling: bool = True,
        reduction: str = "mean",
        zero_mean: bool = True,
        clip_min: typing.Optional[float] = None,
        weight: float = 1.0,
    ):
        # Set up the nn.Module state first, then attach our own attributes.
        super().__init__()
        self.scaling = scaling
        self.reduction = reduction
        self.zero_mean = zero_mean
        self.clip_min = clip_min
        self.weight = weight

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Compute the negated SDR (lower is better) between two batches.

        Parameters
        ----------
        x : AudioSignal or torch.Tensor
            Treated as the *reference* by this implementation.
        y : AudioSignal or torch.Tensor
            Treated as the *estimate* by this implementation.

        Returns
        -------
        torch.Tensor
            ``-10 * log10(signal / noise + eps)``, optionally clamped from
            below by ``clip_min`` and reduced according to ``reduction``.
        """
        eps = 1e-8

        # Unwrap each argument independently so signal/tensor mixes work.
        references = x.audio_data if isinstance(x, AudioSignal) else x
        estimates = y.audio_data if isinstance(y, AudioSignal) else y

        # Flatten everything after the batch axis: (nb, ...) -> (nb, n, 1).
        nb = references.shape[0]
        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)

        if self.zero_mean:
            mean_reference = references.mean(dim=1, keepdim=True)
            mean_estimate = estimates.mean(dim=1, keepdim=True)
        else:
            mean_reference = 0
            mean_estimate = 0

        _references = references - mean_reference
        _estimates = estimates - mean_estimate

        # Energy of the reference and its inner product with the estimate.
        references_projection = (_references**2).sum(dim=-2) + eps
        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps

        # Optimal scaling of the reference; 1 disables scale invariance.
        scale = (
            (references_on_estimates / references_projection).unsqueeze(1)
            if self.scaling
            else 1
        )

        e_true = scale * _references
        e_res = _estimates - e_true

        signal = (e_true**2).sum(dim=1)
        noise = (e_res**2).sum(dim=1)
        # NOTE(review): ``noise`` is not regularised, so an exact match
        # divides by zero here. Numerics are left as in the upstream code.
        sdr = -10 * torch.log10(signal / noise + eps)

        if self.clip_min is not None:
            sdr = torch.clamp(sdr, min=self.clip_min)

        if self.reduction == "mean":
            sdr = sdr.mean()
        elif self.reduction == "sum":
            sdr = sdr.sum()
        return sdr
|
|
|
|
class MultiScaleSTFTLoss(nn.Module):
    """Computes the multi-scale STFT loss from [1].

    Parameters
    ----------
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512].
        The hop length of each STFT is a quarter of its window length.
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. It is only stored on the
        instance; ``forward`` does not apply it.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False
    window_type : str, optional
        Window type stored in each ``STFTParams`` and handed to
        ``AudioSignal.stft``, by default None

    References
    ----------

    1.  Engel, Jesse, Chenjie Gu, and Adam Roberts.
        "DDSP: Differentiable Digital Signal Processing."
        International Conference on Learning Representations. 2019.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
        self,
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        window_type: str = None,
    ):
        super().__init__()
        # One STFT configuration per requested resolution.
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.loss_fn = loss_fn
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.clamp_eps = clamp_eps
        self.weight = weight
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes multi-scale STFT between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Multi-scale STFT loss (the float ``0.0`` when no window
            lengths were configured).
        """
        loss = 0.0
        for s in self.stft_params:
            # Forward ``match_stride`` too; previously it was stored in the
            # params but never reached the STFT call, so the constructor
            # option had no effect.
            x.stft(
                s.window_length,
                s.hop_length,
                s.window_type,
                match_stride=s.match_stride,
            )
            y.stft(
                s.window_length,
                s.hop_length,
                s.window_type,
                match_stride=s.match_stride,
            )
            # Log-magnitude term; clamping keeps log10 away from zero.
            loss += self.log_weight * self.loss_fn(
                x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
                y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            # Raw-magnitude term.
            loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
        return loss
|
|
|
|
class MelSpectrogramLoss(nn.Module):
    """Compute distance between mel spectrograms. Can be used
    in a multi-scale way.

    The per-scale lists (``n_mels``, ``mel_fmin``, ``mel_fmax`` and
    ``window_lengths``) are iterated together with ``zip``, so the shortest
    of them determines how many scales are evaluated.

    Parameters
    ----------
    n_mels : List[int]
        Number of mels per STFT, by default [150, 80],
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512].
        The hop length of each STFT is a quarter of its window length.
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. It is only stored on the
        instance; ``forward`` does not apply it.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False
    mel_fmin : List[float], optional
        Per-scale ``mel_fmin`` passed to ``AudioSignal.mel_spectrogram``,
        by default [0.0, 0.0]
    mel_fmax : List[float], optional
        Per-scale ``mel_fmax`` passed to ``AudioSignal.mel_spectrogram``,
        by default [None, None]
    window_type : str, optional
        Window type stored in each ``STFTParams`` and passed on with the
        STFT keyword arguments, by default None

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
        self,
        n_mels: List[int] = [150, 80],
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        mel_fmin: List[float] = [0.0, 0.0],
        mel_fmax: List[float] = [None, None],
        window_type: str = None,
    ):
        super().__init__()
        # One STFT configuration per requested resolution.
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        # Copy the per-scale lists so the shared default lists in the
        # signature can never be mutated through an instance.
        self.n_mels = list(n_mels)
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.weight = weight
        self.mel_fmin = list(mel_fmin)
        self.mel_fmax = list(mel_fmax)
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes mel loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Mel loss (the float ``0.0`` when no scales were configured).
        """
        loss = 0.0
        for n_mels, fmin, fmax, s in zip(
            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
        ):
            # Include ``match_stride``; previously it was stored in the
            # params but never forwarded, so the constructor option had
            # no effect.
            kwargs = {
                "window_length": s.window_length,
                "hop_length": s.hop_length,
                "window_type": s.window_type,
                "match_stride": s.match_stride,
            }
            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)

            # Log-mel term; clamping keeps log10 away from zero.
            loss += self.log_weight * self.loss_fn(
                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            # Raw-mel term.
            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
        return loss
|
|
|
|
class GANLoss(nn.Module):
    """
    Least-squares adversarial losses built around a discriminator.

    The discriminator is called on the ``audio_data`` of a generated and a
    ground-truth signal. Its output is indexed as a sequence of per-
    discriminator sequences whose last element is the score and whose
    earlier elements are intermediate feature maps. Discriminator and
    generator losses are exposed as separate methods.
    """

    def __init__(self, discriminator):
        super().__init__()
        self.discriminator = discriminator

    def forward(self, fake, real):
        """Run the discriminator on both signals; returns ``(d_fake, d_real)``."""
        outs_fake = self.discriminator(fake.audio_data)
        outs_real = self.discriminator(real.audio_data)
        return outs_fake, outs_real

    def discriminator_loss(self, fake, real):
        """Push scores of fake inputs towards 0 and of real inputs towards 1."""
        # The fake signal is cloned and detached so no gradient reaches
        # whatever produced it.
        outs_fake, outs_real = self.forward(fake.clone().detach(), real)

        total = 0
        for maps_fake, maps_real in zip(outs_fake, outs_real):
            score_fake = maps_fake[-1]
            total += torch.mean(score_fake ** 2)
            score_real = maps_real[-1]
            total += torch.mean((1 - score_real) ** 2)
        return total

    def generator_loss(self, fake, real):
        """Return ``(adversarial_loss, feature_matching_loss)`` for the generator."""
        outs_fake, outs_real = self.forward(fake, real)

        # Adversarial term: scores of fake inputs should approach 1.
        adversarial = 0
        for maps_fake in outs_fake:
            adversarial += torch.mean((1 - maps_fake[-1]) ** 2)

        # Feature matching: L1 distance on every map except the final score,
        # with the real-side features detached.
        feature_matching = 0
        for idx, maps_fake in enumerate(outs_fake):
            for layer, feat_fake in enumerate(maps_fake[:-1]):
                feat_real = outs_real[idx][layer].detach()
                feature_matching += F.l1_loss(feat_fake, feat_real)
        return adversarial, feature_matching
| |
| |
|
|