| import os |
| import h5py |
| import numpy as np |
| from typing import Any, Tuple |
| import torch |
| import random |
| from pytorch_lightning import LightningDataModule |
| import torchaudio |
| from torchaudio.functional import apply_codec |
| from torch.utils.data import DataLoader, Dataset |
| from typing import Any, Dict, Optional, Tuple |
|
|
def compute_mch_rms_dB(mch_wav, fs=16000, energy_thresh=-50):
    """Return the mean-square level of a waveform in decibels.

    The level is computed over every element of ``mch_wav`` at once (all
    channels and all samples); no activity detection is performed.

    Args:
        mch_wav: Floating-point ``torch.Tensor`` holding the waveform.
        fs: Unused; kept so existing callers keep working.
        energy_thresh: Unused; kept so existing callers keep working.

    Returns:
        A 0-dim ``torch.Tensor`` equal to ``10 * log10(mean(mch_wav ** 2))``.
        The mean square is floored at ``1e-20`` so that digital silence maps
        to -200 dB instead of ``-inf``.
    """
    mean_square = torch.mean(mch_wav ** 2)
    # Stay in torch end-to-end: NumPy ufuncs cannot consume CUDA tensors or
    # tensors that require grad, and mixing Python ``max`` with a tensor made
    # the return type depend on the input.
    # The floor is raised to the dtype's smallest normal number when 1e-20
    # is not representable (e.g. float16), so it never collapses to zero.
    floor = max(1e-20, torch.finfo(mean_square.dtype).tiny)
    mean_square = torch.clamp(mean_square, min=floor)
    return 10 * torch.log10(mean_square)
|
|
def match2(x, d):
    """Estimate, per row, the circular delay of ``d`` relative to ``x``.

    Both inputs are truncated to their common length, then a phase-normalised
    cross-spectrum is transformed back to the lag domain and the lag with the
    largest absolute value is picked.

    Args:
        x: 2-D ``torch.Tensor``; the last dimension is the one that is
            transformed.
        d: 2-D ``torch.Tensor`` with the same number of rows as ``x``.

    Returns:
        A Python list with one integer lag per row, each in
        ``[0, minlen)``. The lag is circular: if ``d == torch.roll(x, k, -1)``
        the returned lag is ``k``.

    Raises:
        AssertionError: If either input is not 2-D.
    """
    assert x.dim() == 2, x.shape
    assert d.dim() == 2, d.shape
    minlen = min(x.shape[-1], d.shape[-1])
    x, d = x[:, :minlen], d[:, :minlen]

    Fx = torch.fft.rfft(x, dim=-1)
    Fd = torch.fft.rfft(d, dim=-1)
    cross = Fd * Fx.conj()
    # Keep only the phase; the small constant avoids division by zero in
    # bins that carry no energy.
    cross = cross / (cross.abs() + 1e-3)
    # Discard the DC bin so a constant offset cannot dominate the peak.
    cross[:, 0] = 0

    # ``n`` must be given explicitly: without it irfft returns
    # ``2 * (bins - 1)`` samples, which is ``minlen - 1`` for odd lengths and
    # skews every lag index.
    corr = torch.fft.irfft(cross, n=minlen, dim=-1)
    return torch.argmax(corr.abs(), dim=-1).tolist()
|
|
def codec_simu(wav, sr=16000, options=None):
    """Pass a waveform through an mp3 encode/decode round trip.

    The decoded signal is brought back to the length of ``wav`` and then
    circularly shifted so that it lines up with the input again.

    Args:
        wav: 2-D ``torch.Tensor`` (``match2`` asserts two dimensions); the
            last dimension is treated as time.
        sr: Sample rate handed to ``apply_codec``.
        options: Optional mapping. Only ``'bitrate'`` is read: either the
            string ``'random'`` (also used when the key or the whole mapping
            is missing) or a number. The mapping is never modified.

    Returns:
        The encoded/decoded waveform with the same last-dimension length as
        ``wav``.
    """
    # Read the setting into a local instead of writing the drawn value back
    # into ``options``: mutating the mapping froze the "random" bitrate to
    # whatever the very first call picked, for every later call.
    bitrate = "random" if options is None else options.get("bitrate", "random")
    if bitrate == "random":
        bitrate = random.choice([24000, 32000, 48000, 64000, 96000, 128000])

    # NOTE(review): the value is passed as the mp3 ``compression`` argument,
    # presumably interpreted as kbps by the encoder - confirm against the
    # torchaudio documentation.
    compression = int(bitrate // 1000)
    wav_encdec = apply_codec(wav, sr, format="mp3", compression=compression)

    num_samples = wav.shape[-1]
    if wav_encdec.shape[-1] >= num_samples:
        wav_encdec = wav_encdec[..., :num_samples]
    else:
        # The decoder returned fewer samples: fill the missing tail with the
        # corresponding samples of the original signal.
        missing_tail = wav[..., wav_encdec.shape[-1]:]
        wav_encdec = torch.cat([wav_encdec, missing_tail], -1)

    # Undo the delay introduced by the codec; the lag estimated for the first
    # row is applied to all rows.
    tau = match2(wav, wav_encdec)
    return torch.roll(wav_encdec, -tau[0], -1)
|
|
def get_wav_files(root_dir):
    """Collect the paths of the usable ``.wav`` files below ``root_dir``.

    A ``.wav`` file is kept when either
    * its directory path contains ``"musdb18hq"`` and its file name does not
      contain ``"mixture"``, or
    * its directory path contains ``"moisesdb"``.

    Args:
        root_dir: Directory that is walked recursively.

    Returns:
        A list of file paths in the order produced by ``os.walk``.
    """

    def _is_wanted(folder, name):
        if not name.endswith('.wav'):
            return False
        musdb_stem = "musdb18hq" in folder and "mixture" not in name
        return musdb_stem or "moisesdb" in folder

    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(root_dir)
        for name in names
        if _is_wanted(folder, name)
    ]
|
|
class MusdbMoisesdbDataset(Dataset):
    """Randomly generated (original, codec-degraded) training pairs.

    Every item is drawn at random, so ``idx`` is ignored and ``num_samples``
    only defines the nominal epoch length. With probability 0.5 an item is a
    sum of randomly chosen, randomly scaled stems; otherwise it is a crop of
    a ready-made mixture.

    Audio is read from HDF5 files stored as
    ``<data_dir>/<instrument or "mixture">/<file>``, each holding a dataset
    named ``'data'``.
    """

    def __init__(
        self,
        data_dir: str,
        codec_type: str,
        codec_options: dict,
        sr: int = 16000,
        segments: int = 10,
        num_stems: int = 4,
        snr_range: Tuple[int, int] = (-10, 10),
        num_samples: int = 1000,
    ) -> None:
        """Store the sampling configuration.

        Args:
            data_dir: Root directory with one sub-directory per instrument
                plus a ``mixture`` sub-directory.
            codec_type: Stored on the instance; not read by this class.
            codec_options: Options forwarded to ``codec_simu``.
            sr: Sample rate in Hz.
            segments: Crop length in seconds (converted to samples here).
            num_stems: Maximum number of stems summed into one item.
            snr_range: Inclusive integer range, in dB, of the random gain
                applied to each stem.
            num_samples: Value reported by ``len()``.
        """
        super().__init__()
        self.data_dir = data_dir
        self.codec_type = codec_type
        self.codec_options = codec_options
        self.segments = int(segments * sr)
        self.sr = sr
        self.num_stems = num_stems
        self.snr_range = snr_range
        self.num_samples = num_samples

        self.instruments = [
            "bass",
            "bowed_strings",
            "drums",
            "guitar",
            "other",
            "other_keys",
            "other_plucked",
            "percussion",
            "piano",
            "vocals",
            "wind",
        ]

    def __len__(self) -> int:
        return self.num_samples

    def _load_random_segment(self, subdir: str) -> torch.Tensor:
        """Read a random ``self.segments``-long crop from a random file in ``subdir``."""
        folder = os.path.join(self.data_dir, subdir)
        h5name = random.choice(os.listdir(folder))
        # The context manager closes the file again; leaving it open leaked
        # one handle per stem on every __getitem__ call.
        with h5py.File(os.path.join(folder, h5name), 'r') as h5file:
            datas = h5file['data']
            random_index = random.randint(0, datas.shape[0] - 1)
            start = random.randint(0, datas.shape[-1] - self.segments)
            # Read only the crop from disk instead of the whole track.
            # The downstream code needs 2-D waveforms, so 'data' is indexed
            # as (item, row, time).
            segment = datas[random_index, :, start:start + self.segments]
        return torch.FloatTensor(segment)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return a freshly sampled ``(ori_wav, codec_wav)`` pair.

        Both tensors are divided by their joint absolute peak (when it is
        non-zero), so the larger of the two peaks becomes 1.
        """
        if random.random() > 0.5:
            # Stems are drawn with replacement, so an instrument may repeat.
            num_selected = random.randint(1, self.num_stems)
            select_stems = random.choices(self.instruments, k=num_selected)
            stem_wavs = []
            for stem in select_stems:
                music_wav = self._load_random_segment(stem)
                rescale_snr = random.randint(self.snr_range[0], self.snr_range[1])
                # dB -> linear amplitude gain.
                stem_wavs.append(music_wav * 10.0 ** (rescale_snr / 20.0))
            ori_wav = torch.stack(stem_wavs).sum(0)
        else:
            ori_wav = self._load_random_segment("mixture")

        # Hand over a copy so the stored configuration can never be altered
        # by the codec simulation (e.g. a "random" setting being overwritten
        # with the value drawn for one item).
        codec_wav = codec_simu(ori_wav, sr=self.sr, options=dict(self.codec_options))

        max_scale = max(ori_wav.abs().max(), codec_wav.abs().max())
        if max_scale > 0:
            ori_wav = ori_wav / max_scale
            codec_wav = codec_wav / max_scale

        return ori_wav, codec_wav
| |
|
|
class MusdbMoisesdbEval(Dataset):
    """Evaluation set of pre-rendered (original, codec-degraded) pairs.

    Every entry of ``data_dir`` is treated as one sample directory that
    contains ``ori_wav.wav`` and ``codec_wav.wav``.
    """

    def __init__(
        self,
        data_dir: str
    ) -> None:
        """Index the sample directories found directly inside ``data_dir``."""
        super().__init__()
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across runs and machines.
        self.data_path = [
            os.path.join(data_dir, name) for name in sorted(os.listdir(data_dir))
        ]

    def __len__(self) -> int:
        return len(self.data_path)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
        """Return ``(ori_wav, codec_wav, sample_dir)`` for sample ``idx``.

        Only the waveform tensors are kept; the sample rates reported by
        ``torchaudio.load`` are discarded.
        """
        sample_dir = self.data_path[idx]
        ori_wav = torchaudio.load(os.path.join(sample_dir, "ori_wav.wav"))[0]
        codec_wav = torchaudio.load(os.path.join(sample_dir, "codec_wav.wav"))[0]

        return ori_wav, codec_wav, sample_dir
| |
class MusdbMoisesdbDataModule(LightningDataModule):
    """Lightning data module pairing the random training set with the eval set.

    Training data comes from ``MusdbMoisesdbDataset`` and validation data
    from ``MusdbMoisesdbEval``. No test dataset is created.
    """

    def __init__(
        self,
        train_dir: str,
        eval_dir: str,
        codec_type: str,
        codec_options: dict,
        sr: int = 16000,
        segments: int = 10,
        num_stems: int = 4,
        snr_range: Tuple[int, int] = (-10, 10),
        num_samples: int = 1000,
        batch_size: int = 32,
        num_workers: int = 4,
    ) -> None:
        """Record all arguments as hyperparameters.

        Args:
            train_dir: ``data_dir`` of the training dataset.
            eval_dir: ``data_dir`` of the evaluation dataset.
            codec_type: Forwarded to the training dataset.
            codec_options: Forwarded to the training dataset.
            sr: Forwarded to the training dataset.
            segments: Forwarded to the training dataset.
            num_stems: Forwarded to the training dataset.
            snr_range: Forwarded to the training dataset.
            num_samples: Forwarded to the training dataset.
            batch_size: Batch size of both data loaders.
            num_workers: Worker count of both data loaders.
        """
        super().__init__()
        # Makes every constructor argument available as ``self.hparams.<name>``.
        self.save_hyperparameters(logger=False)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None

    def setup(self, stage: Optional[str] = None) -> None:
        """Create ``self.data_train`` and ``self.data_val`` if not yet created.

        The datasets are built only while both attributes are still ``None``,
        so repeated calls do not rebuild them.

        :param stage: Accepted for interface compatibility; not used here.
        """
        # Compare against None explicitly: truthiness of a Dataset goes
        # through ``__len__``, so an empty dataset would look "not set up"
        # and be rebuilt on every call.
        if self.data_train is None and self.data_val is None:
            self.data_train = MusdbMoisesdbDataset(
                data_dir=self.hparams.train_dir,
                codec_type=self.hparams.codec_type,
                codec_options=self.hparams.codec_options,
                sr=self.hparams.sr,
                segments=self.hparams.segments,
                num_stems=self.hparams.num_stems,
                snr_range=self.hparams.snr_range,
                num_samples=self.hparams.num_samples,
            )

            self.data_val = MusdbMoisesdbEval(
                data_dir=self.hparams.eval_dir
            )

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training loader."""
        return DataLoader(
            self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (fixed order, same batch size)."""
        return DataLoader(
            self.data_val,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=False,
            pin_memory=True,
        )
|
|