|
|
| import os.path as osp
|
| import random
|
| import numpy as np
|
| import random
|
| import soundfile as sf
|
| import librosa
|
|
|
| import torch
|
| import torchaudio
|
| import torch.utils.data
|
| import torch.distributed as dist
|
| from multiprocessing import Pool
|
|
|
| import logging
|
# Module-level logger. Its threshold is lowered to DEBUG here; this module
# does not attach any handler to it.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)
|
|
|
| import pandas as pd
|
|
|
class TextCleaner:
    """Convert a string into a list of symbol indexes.

    Each character of the input is looked up in ``symbol_dict``. Characters
    missing from the table are skipped; when ``debug`` is true a warning is
    printed for each one.
    """

    def __init__(self, symbol_dict, debug=True):
        # character -> index table supplied by the caller
        self.word_index_dictionary = symbol_dict
        # when true, print a warning for every unknown character
        self.debug = debug

    def __call__(self, text):
        """Return the indexes of every known character in ``text``, in order."""
        result = []
        for symbol in text:
            try:
                index = self.word_index_dictionary[symbol]
            except KeyError:
                if self.debug:
                    print("\nWARNING UNKNOWN IPA CHARACTERS/LETTERS: ", symbol)
                    print("To ignore set 'debug' to false in the config")
                continue
            result.append(index)
        return result
|
|
|
# Seed the global Python and NumPy RNGs at import time.
np.random.seed(1)
random.seed(1)

# STFT settings for the mel front-end.
SPECT_PARAMS = {
    "n_fft": 2048,
    "win_length": 1200,
    "hop_length": 300
}

# Mel filterbank settings.
MEL_PARAMS = {
    "n_mels": 80,
}

# Module-level mel transform, built from the parameter dicts above so the
# values cannot drift apart (same arguments as n_mels=80, n_fft=2048,
# win_length=1200, hop_length=300).
# NOTE(review): no sample_rate is passed, so torchaudio's default is used even
# though audio is resampled to 24 kHz before reaching this transform. Changing
# it would alter the features a trained model expects — confirm before touching.
to_mel = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS, **SPECT_PARAMS)

# Normalization constants applied to the log-mel spectrogram in preprocess().
mean, std = -4, 4
|
|
|
def preprocess(wave):
    """Convert a NumPy waveform into a normalized log-mel spectrogram tensor.

    Steps:
    - cast the samples to float32;
    - run the module-level ``to_mel`` transform;
    - prepend a singleton dimension;
    - apply ``log(1e-5 + x)``;
    - standardize with the module-level ``mean`` and ``std`` constants.
    """
    samples = torch.from_numpy(wave).float()
    spectrogram = to_mel(samples).unsqueeze(0)
    log_spectrogram = torch.log(1e-5 + spectrogram)
    return (log_spectrogram - mean) / std
|
|
|
class FilePathDataset(torch.utils.data.Dataset):
    """Dataset of ``path|text`` entries pairing audio files with transcripts.

    Each item is loaded as follows:
    - the audio is read from ``root_path``;
    - multi-channel audio is reduced to its first channel;
    - the audio is resampled to ``sr`` if needed;
    - 12000 zero samples are added on each side;
    - the result is converted to a normalized log-mel spectrogram.

    The transcript is mapped to symbol indexes and bracketed with index 0.

    Args:
        data_list: iterable of ``"path|text"`` strings; blank lines are skipped.
        root_path: directory the audio paths are relative to.
        symbol_dict: character -> index table for ``TextCleaner``.
        sr: target sample rate; audio at any other rate is resampled to it.
        data_augmentation: stored flag, forced off when ``validation`` is true.
        validation: marks a validation dataset.
        debug: forwarded to ``TextCleaner`` (warn on unknown characters).
    """

    def __init__(self,
                 data_list,
                 root_path,
                 symbol_dict,
                 sr=24000,
                 data_augmentation=False,
                 validation=False,
                 debug=True
                 ):
        # Blank lines (e.g. a trailing newline in the list file) would
        # otherwise become entries with an empty path and fail at load time.
        self.data_list = [l.strip().split('|') for l in data_list if l.strip()]
        self.text_cleaner = TextCleaner(symbol_dict, debug)
        self.sr = sr

        self.df = pd.DataFrame(self.data_list)

        self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)

        self.mean, self.std = -4, 4
        self.data_augmentation = data_augmentation and (not validation)
        self.max_mel_length = 192

        self.root_path = root_path

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return ``(mel, text_indexes, path, wave)`` for entry ``idx``.

        The mel spectrogram is trimmed to an even number of frames.
        """
        data = self.data_list[idx]
        path = data[0]

        wave, text_tensor = self._load_tensor(data)

        acoustic_feature = preprocess(wave).squeeze()
        # Drop the last frame when the frame count is odd.
        length_feature = acoustic_feature.size(1)
        acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)]

        return acoustic_feature, text_tensor, path, wave

    def _load_tensor(self, data):
        """Load the waveform and transcript indexes for one ``[path, text]`` entry."""
        wave_path, text = data
        wave, sr = sf.read(osp.join(self.root_path, wave_path))
        if wave.ndim > 1:
            # Multi-channel audio is read as (frames, channels): keep channel 0.
            # Checking ndim (rather than shape[-1] == 2) also handles >2 channels
            # and does not misfire on a 2-sample mono clip.
            wave = wave[:, 0]
        if sr != self.sr:
            # Honour the configured target rate instead of a hard-coded 24000.
            wave = librosa.resample(wave, orig_sr=sr, target_sr=self.sr)
            print(wave_path, sr)

        # 12000 zero samples on each side (0.5 s at the default 24 kHz).
        wave = np.concatenate([np.zeros([12000]), wave, np.zeros([12000])], axis=0)

        # Bracket the symbol sequence with index 0 at both ends.
        indexes = self.text_cleaner(text)
        text = torch.LongTensor([0, *indexes, 0])

        return wave, text

    def _load_data(self, data):
        """Return the entry's mel spectrogram, randomly cropped to at most
        ``self.max_mel_length`` frames."""
        wave, _ = self._load_tensor(data)
        mel_tensor = preprocess(wave).squeeze()

        mel_length = mel_tensor.size(1)
        if mel_length > self.max_mel_length:
            # randint's upper bound is exclusive; +1 makes the last valid start
            # offset (mel_length - max_mel_length) reachable.
            random_start = np.random.randint(0, mel_length - self.max_mel_length + 1)
            mel_tensor = mel_tensor[:, random_start:random_start + self.max_mel_length]

        return mel_tensor
|
|
|
|
|
class Collater(object):
    """Pad and stack dataset items into a batch.

    Each item is a ``(mel, text, path, wave)`` tuple. ``mel`` is
    ``(n_mels, frames)`` and ``text`` is a 1-D index tensor. Items are sorted
    longest-first by mel frame count before padding.

    Args:
        return_wave (bool): stored as ``self.return_wave``. ``__call__``
            always returns the waveforms regardless of this flag.
    """

    def __init__(self, return_wave=False):
        self.text_pad_index = 0
        self.min_mel_length = 192
        self.max_mel_length = 192
        self.return_wave = return_wave

    def __call__(self, batch):
        """Collate ``batch``.

        Returns a tuple ``(waves, texts, input_lengths, mels, output_lengths)``:
        - ``waves``: list of raw waves, in sorted order;
        - ``texts``: ``(B, max_text_len)`` long tensor padded with
          ``self.text_pad_index``;
        - ``input_lengths``: text lengths;
        - ``mels``: ``(B, n_mels, max_frames)`` float tensor zero-padded on
          the right;
        - ``output_lengths``: mel frame counts.
        """
        batch_size = len(batch)

        # Sort longest-first by mel frame count.
        lengths = [b[0].shape[1] for b in batch]
        batch_indexes = np.argsort(lengths)[::-1]
        batch = [batch[bid] for bid in batch_indexes]

        nmels = batch[0][0].size(0)
        max_mel_length = max(b[0].shape[1] for b in batch)
        max_text_length = max(b[1].shape[0] for b in batch)

        mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
        # Pad with the configured pad index rather than a hard-coded 0.
        texts = torch.full((batch_size, max_text_length), self.text_pad_index, dtype=torch.long)

        input_lengths = torch.zeros(batch_size).long()
        output_lengths = torch.zeros(batch_size).long()
        waves = [None for _ in range(batch_size)]

        for bid, (mel, text, _path, wave) in enumerate(batch):
            mel_size = mel.size(1)
            text_size = text.size(0)
            mels[bid, :, :mel_size] = mel
            texts[bid, :text_size] = text
            input_lengths[bid] = text_size
            output_lengths[bid] = mel_size
            waves[bid] = wave

        return waves, texts, input_lengths, mels, output_lengths
|
|
|
|
|
def get_length(wave_path, root_path):
    """Return the frame count of an audio file, rescaled as if sampled at 24 kHz.

    Only the file header is read (via ``sf.info``). The result is a float.
    """
    full_path = osp.join(root_path, wave_path)
    header = sf.info(full_path)
    rate_ratio = 24000 / header.samplerate
    return header.frames * rate_ratio
|
|
|
def build_dataloader(path_list,
                     root_path,
                     symbol_dict,
                     validation=False,
                     batch_size=4,
                     num_workers=1,
                     device='cpu',
                     collate_config=None,
                     dataset_config=None):
    """Build a DataLoader over ``path_list`` that batches clips of similar length.

    Args:
        path_list: ``"path|text"`` lines passed to ``FilePathDataset``.
        root_path: directory the audio paths are relative to.
        symbol_dict: character -> index table for text encoding.
        validation: when true, batches are neither shuffled nor truncated.
        batch_size: samples per batch.
        num_workers: DataLoader workers. Audio lengths are probed with a
            process pool of ``2 * num_workers``, or serially when it is 0.
        device: memory is pinned unless this is ``"cpu"``.
        collate_config: extra keyword arguments for ``Collater`` (None -> {}).
        dataset_config: extra keyword arguments for ``FilePathDataset`` (None -> {}).

    Returns:
        A ``torch.utils.data.DataLoader`` driven by a length-bucketing ``BatchSampler``.
    """
    # None sentinels instead of shared mutable {} defaults.
    collate_config = {} if collate_config is None else collate_config
    dataset_config = {} if dataset_config is None else dataset_config

    dataset = FilePathDataset(path_list, root_path, symbol_dict, validation=validation, **dataset_config)
    collate_fn = Collater(**collate_config)

    print("Getting sample lengths...")

    num_processes = num_workers * 2
    if num_processes != 0:
        list_of_tuples = [(d[0], root_path) for d in dataset.data_list]
        with Pool(processes=num_processes) as pool:
            sample_lengths = pool.starmap(get_length, list_of_tuples, chunksize=16)
    else:
        sample_lengths = [get_length(d[0], root_path) for d in dataset.data_list]

    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_sampler=BatchSampler(
            sample_lengths,
            batch_size,
            shuffle=(not validation),
            drop_last=(not validation),
            num_replicas=1,
            rank=0,
        ),
        collate_fn=collate_fn,
        pin_memory=(device != "cpu"),
    )

    return data_loader
|
|
|
|
|
class BatchSampler(torch.utils.data.Sampler):
    """Yield batches of dataset indexes drawn from length buckets ("time bins").

    Each sample length (in samples) is mapped to a bin by ``get_time_bin``.
    Samples whose bin is -1 are never yielded. Every batch comes from a
    single bin. Within a bin, indexes are split across replicas with a
    ``DistributedSampler`` and grouped into batches of ``batch_sizes``.

    Args:
        sample_lengths: per-sample lengths, indexed like the dataset.
        batch_sizes: number of samples per batch.
        num_replicas: number of replicas. None -> ``dist.get_world_size()``.
        rank: this replica's rank. None -> ``dist.get_rank()``.
        shuffle: shuffle the bin order and the indexes within each bin.
        drop_last: drop incomplete trailing batches (and uneven replica tails).
    """

    def __init__(
        self,
        sample_lengths,
        batch_sizes,
        num_replicas=None,
        rank=None,
        shuffle=True,
        drop_last=False,
    ):
        self.batch_sizes = batch_sizes
        self.num_replicas = dist.get_world_size() if num_replicas is None else num_replicas
        self.rank = dist.get_rank() if rank is None else rank
        self.shuffle = shuffle
        self.drop_last = drop_last

        self.time_bins = {}  # bin number -> list of dataset indexes
        self.epoch = 0
        self.total_len = 0
        self.last_bin = None  # bin of the most recently yielded batch

        for i, length in enumerate(sample_lengths):
            bin_num = self.get_time_bin(length)
            if bin_num != -1:
                self.time_bins.setdefault(bin_num, []).append(i)

        # Use the resolved replica count. The raw ``num_replicas`` argument may
        # be None, and multiplying by it would raise a TypeError.
        total_batch = self.batch_sizes * self.num_replicas
        for val in self.time_bins.values():
            self.total_len += len(val) // total_batch
            if not self.drop_last and len(val) % total_batch != 0:
                self.total_len += 1

    def __iter__(self):
        sampler_order = list(self.time_bins.keys())

        if self.shuffle:
            sampler_indices = torch.randperm(len(sampler_order)).tolist()
        else:
            sampler_indices = list(range(len(sampler_order)))

        for index in sampler_indices:
            key = sampler_order[index]
            current_bin = self.time_bins[key]
            # Named so it does not shadow the module-level ``dist``
            # (torch.distributed) alias.
            replica_sampler = torch.utils.data.distributed.DistributedSampler(
                current_bin,
                num_replicas=self.num_replicas,
                rank=self.rank,
                shuffle=self.shuffle,
                drop_last=self.drop_last,
            )
            replica_sampler.set_epoch(self.epoch)
            sampler = torch.utils.data.sampler.BatchSampler(
                replica_sampler, self.batch_sizes, self.drop_last
            )
            for item_list in sampler:
                self.last_bin = key
                # Positions within the bin -> dataset indexes.
                yield [current_bin[i] for i in item_list]

    def __len__(self):
        return self.total_len

    def set_epoch(self, epoch):
        """Set the epoch forwarded to each per-bin ``DistributedSampler``."""
        self.epoch = epoch

    def get_time_bin(self, sample_count):
        """Map a length in samples to a bin number, or -1 if it is too short.

        The length is converted to frames at 300 samples per frame. Lengths
        under 20 frames get -1. Longer ones fall into bins 20 frames wide,
        starting at 20 frames.
        """
        frames = sample_count // 300
        if frames >= 20:
            return (frames - 20) // 20
        return -1