| from typing import Optional, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from librosa.util import pad_center |
| from scipy.signal import get_window |
|
|
|
|
class STFT(torch.nn.Module):
    def __init__(
        self,
        filter_length=1024,
        hop_length=512,
        win_length: Optional[int] = None,
        window="hann",
        use_torch_stft=True,
    ):
        """Short-time Fourier transform module.

        Two backends: (1) ``torch.stft`` (default), used by ``forward``; or
        (2) an explicit real-DFT basis applied frame-by-frame via
        unfold/matmul, with a matching overlap-add inverse. The explicit path
        is only guaranteed to reconstruct correctly with 50% overlap
        (``hop_length == filter_length // 2``).

        Keyword Arguments:
            filter_length {int} -- FFT size / analysis filter length (default: {1024})
            hop_length {int} -- Hop between frames; restrict to 50% overlap for the
                conv-based path (default: {512})
            win_length {Optional[int]} -- Window length; falls back to
                ``filter_length`` when None/0 (default: {None})
            window {str} -- Window type passed to ``scipy.signal.get_window``
                (options include bartlett, hann, hamming, blackman, blackmanharris)
                (default: {'hann'})
            use_torch_stft {bool} -- If True, ``forward`` delegates to
                ``torch.stft`` and no DFT basis is built (default: {True})

        Raises:
            ValueError -- if ``win_length`` exceeds ``filter_length``.
        """
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.pad_amount = int(self.filter_length / 2)
        # BUGFIX: resolve the win_length default *before* the torch-stft early
        # return. Previously self.win_length could stay None, and forward()
        # crashed on `self.win_length * factor` for a default-constructed
        # instance.
        if not win_length:
            win_length = filter_length
        if filter_length < win_length:
            # Explicit raise instead of `assert` (asserts vanish under -O).
            raise ValueError(
                f"win_length ({win_length}) must not exceed "
                f"filter_length ({filter_length})"
            )
        self.win_length = win_length
        self.hann_window = {}  # cache of Hann windows keyed by (keyshift, device)
        self.use_torch_stft = use_torch_stft

        if use_torch_stft:
            # torch.stft builds its own FFT internally; no basis needed.
            return

        # Explicit real DFT basis: stack [Re; Im] of the DFT of the identity,
        # truncated to the one-sided spectrum (filter_length // 2 + 1 bins).
        fourier_basis = np.fft.fft(np.eye(self.filter_length))
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )
        forward_basis = torch.FloatTensor(fourier_basis)
        inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis))

        # Analysis window, zero-padded (centered) up to filter_length.
        fft_window = get_window(window, win_length, fftbins=True)
        fft_window = pad_center(fft_window, size=filter_length)
        fft_window = torch.from_numpy(fft_window).float()

        # Bake the window into both bases so transform/inverse are one matmul.
        forward_basis *= fft_window
        inverse_basis = (inverse_basis.T * fft_window).T

        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())
        self.register_buffer("fft_window", fft_window.float())

    def __call__(
        self,
        input_data: torch.Tensor,
        keyshift: int = 0,
        speed: int = 1,
        center: bool = True,
    ) -> torch.Tensor:
        # Typed pass-through so IDEs/type-checkers see forward()'s signature.
        return super().__call__(input_data, keyshift, speed, center)

    def transform(
        self,
        input_data: torch.Tensor,
        return_phase=False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Take input data (audio) to STFT domain via the explicit DFT basis.

        Only valid when the module was built with ``use_torch_stft=False``
        (otherwise ``forward_basis`` does not exist).

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
            return_phase {bool} -- Also return the per-bin phase (default: {False})

        Returns:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Phase of STFT, same shape as magnitude; only when
                ``return_phase`` is True
        """
        # Center-pad so frames are aligned with transform/inverse round-trips.
        input_data = F.pad(
            input_data,
            (self.pad_amount, self.pad_amount),
            mode="reflect",
        )
        # (batch, frames, filter_length) -> (batch, filter_length, frames)
        forward_transform = input_data.unfold(
            1, self.filter_length, self.hop_length
        ).permute(0, 2, 1)
        forward_transform = torch.matmul(self.forward_basis, forward_transform)
        # First `cutoff` rows are the real parts, the rest the imaginary parts
        # (matches the vstack layout built in __init__).
        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]
        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        if return_phase:
            phase = torch.atan2(imag_part.data, real_part.data)
            return magnitude, phase
        else:
            return magnitude

    def inverse(
        self,
        magnitude: torch.Tensor,
        phase: torch.Tensor,
    ) -> torch.Tensor:
        """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
        by the ``transform`` function.

        Arguments:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Phase of STFT with shape (num_batch,
                num_frequencies, num_frames)

        Returns:
            inverse_transform {tensor} -- Reconstructed audio given magnitude and
                phase. Of shape (num_batch, num_samples)
        """
        # Rebuild the stacked [Re; Im] representation expected by the basis.
        cat = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )
        # Overlap-add frames back into a 1-D signal.
        fold = torch.nn.Fold(
            output_size=(1, (cat.size(-1) - 1) * self.hop_length + self.filter_length),
            kernel_size=(1, self.filter_length),
            stride=(1, self.hop_length),
        )
        inverse_transform = torch.matmul(self.inverse_basis, cat)
        inverse_transform: torch.Tensor = fold(inverse_transform)[
            :, 0, 0, self.pad_amount : -self.pad_amount
        ]
        # Normalize by the overlap-added squared window (COLA denominator).
        # NOTE(review): no guard against zeros here — assumes a window/overlap
        # combination (e.g. hann at 50%) whose squared sum never vanishes in
        # the retained region; confirm before using exotic windows.
        window_square_sum = (
            self.fft_window.pow(2).repeat(cat.size(-1), 1).T.unsqueeze(0)
        )
        window_square_sum = fold(window_square_sum)[
            :, 0, 0, self.pad_amount : -self.pad_amount
        ]
        inverse_transform /= window_square_sum
        return inverse_transform

    def forward(
        self,
        input_data: torch.Tensor,
        keyshift: int = 0,
        speed: int = 1,
        center: bool = True,
    ) -> torch.Tensor:
        """Compute the magnitude spectrogram of ``input_data``.

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
            keyshift {int} -- Pitch shift in semitones; scales the FFT/window
                sizes by 2**(keyshift/12) (torch.stft path only) (default: {0})
            speed {int} -- Multiplier applied to the hop length
                (torch.stft path only) (default: {1})
            center {bool} -- Passed through to ``torch.stft`` (default: {True})

        Returns:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
        """
        factor = 2 ** (keyshift / 12)
        n_fft_new = int(np.round(self.filter_length * factor))
        win_length_new = int(np.round(self.win_length * factor))
        hop_length_new = int(np.round(self.hop_length * speed))
        if self.use_torch_stft:
            # One cached window per (keyshift, device) pair; keyshift fully
            # determines win_length_new, so the key is sufficient.
            keyshift_key = str(keyshift) + "_" + str(input_data.device)
            if keyshift_key not in self.hann_window:
                # BUGFIX: build the window at win_length_new (the length handed
                # to torch.stft), not self.win_length — torch.stft requires
                # len(window) == win_length, so the old code failed for any
                # keyshift != 0.
                self.hann_window[keyshift_key] = torch.hann_window(
                    win_length_new,
                ).to(input_data.device)
            fft = torch.stft(
                input_data,
                n_fft=n_fft_new,
                hop_length=hop_length_new,
                win_length=win_length_new,
                window=self.hann_window[keyshift_key],
                center=center,
                return_complex=True,
            )
            return torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
        # NOTE(review): the conv-based path ignores keyshift/speed/center and
        # always uses the constructor's filter/hop lengths.
        return self.transform(input_data)
|
|