| |
| |
| |
| |
|
|
| import torch |
| import torchaudio |
| import json |
| import os |
| import numpy as np |
| import librosa |
| from torch.nn.utils.rnn import pad_sequence |
| from modules import whisper_extractor as whisper |
|
|
|
|
class TorchaudioDataset(torch.utils.data.Dataset):
    """Waveform dataset that decodes audio files with torchaudio.

    Each item is loaded from disk, resampled to the target rate if needed,
    and down-mixed to a mono 1-D tensor.

    Args:
        cfg: config object; ``cfg.preprocess`` supplies the metadata file
            locations (only read when ``metadata`` is None).
        dataset: dataset name, used as the subdirectory under
            ``cfg.preprocess.processed_dir``.
        sr: target sample rate in Hz.
        accelerator: optional accelerator; when given, its device is stored
            on ``self.device``, otherwise CUDA is used if available.
        metadata: optional pre-loaded list of utterance-info dicts (each
            must carry a "Path" key); when None, the train and valid JSON
            metadata files are read and concatenated.
    """

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        assert isinstance(dataset, str)

        self.sr = sr
        self.cfg = cfg

        if metadata is None:
            self.train_metadata_path = os.path.join(
                cfg.preprocess.processed_dir, dataset, cfg.preprocess.train_file
            )
            self.valid_metadata_path = os.path.join(
                cfg.preprocess.processed_dir, dataset, cfg.preprocess.valid_file
            )
            self.metadata = self.get_metadata()
        else:
            self.metadata = metadata

        if accelerator is not None:
            self.device = accelerator.device
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

    def get_metadata(self):
        """Load the train and valid metadata JSON files and return the
        concatenated list (train entries first)."""
        metadata = []
        with open(self.train_metadata_path, "r", encoding="utf-8") as t:
            metadata.extend(json.load(t))
        with open(self.valid_metadata_path, "r", encoding="utf-8") as v:
            metadata.extend(json.load(v))
        return metadata

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, index):
        """Return ``(utt_info, wav, length)`` where ``wav`` is a mono 1-D
        float tensor at ``self.sr`` and ``length`` is its sample count."""
        utt_info = self.metadata[index]
        wav_path = utt_info["Path"]

        wav, sr = torchaudio.load(wav_path)

        if sr != self.sr:
            wav = torchaudio.functional.resample(wav, sr, self.sr)
        # Down-mix multi-channel audio to mono. The mean over the channel
        # dim with keepdim=True always leaves exactly one channel, so the
        # former `assert wav.shape[0] == 1` was dead code and is removed
        # (asserts are also stripped under `python -O`).
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
        wav = wav.squeeze(0)

        length = wav.shape[0]

        return utt_info, wav, length
|
|
|
|
class LibrosaDataset(TorchaudioDataset):
    """Dataset variant that decodes audio with librosa instead of torchaudio.

    Items are ``(utt_info, wav, length)`` with ``wav`` a mono 1-D tensor.
    """

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        super().__init__(cfg, dataset, sr, accelerator, metadata)

    def __getitem__(self, index):
        entry = self.metadata[index]
        # librosa resamples to self.sr while loading and returns a numpy array.
        samples, _ = librosa.load(entry["Path"], sr=self.sr)
        waveform = torch.from_numpy(samples)
        return entry, waveform, waveform.shape[0]
|
|
|
|
class FFmpegDataset(TorchaudioDataset):
    """Dataset variant that decodes audio with whisper's ffmpeg-based loader.

    Items are ``(utt_info, wav, length)`` with ``wav`` a 1-D tensor.

    NOTE(review): unlike the sibling datasets, ``self.sr`` is not passed to
    the loader here — presumably whisper.load_audio decodes at its own fixed
    rate; confirm that it matches the configured ``sr``.
    """

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        super().__init__(cfg, dataset, sr, accelerator, metadata)

    def __getitem__(self, index):
        entry = self.metadata[index]
        waveform = torch.from_numpy(whisper.load_audio(entry["Path"]))
        return entry, waveform, waveform.shape[0]
|
|
|
|
def collate_batch(batch_list):
    """Collate ``(metadata, wav, length)`` triples into one padded batch.

    Args:
        batch_list: list of (metadata, wav, length) tuples.

    Returns:
        Tuple of (list of metadata dicts, 2-D tensor of wavs right-padded
        with zeros to the longest item, list of original lengths).
    """
    metadata, wavs, lens = zip(*batch_list)
    padded = pad_sequence(list(wavs), batch_first=True)
    return list(metadata), padded, list(lens)
|
|