| import os |
| import re |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchaudio |
| import numpy as np |
| import pytorch_lightning as pl |
| import random |
| import librosa |
| from os.path import basename, exists, join |
| from torch.utils.data import Dataset, DataLoader |
| import hydra |
| import utils |
| import torchaudio |
| from transformers import AutoFeatureExtractor |
| from torchaudio.transforms import Resample |
| from tqdm import tqdm |
|
|
class DataModule(pl.LightningDataModule):
    """Lightning data module wrapping :class:`FSDataset`.

    Args:
        cfg: hydra config; ``cfg.dataset.<phase>`` must provide
            ``batch_size`` and ``shuffle`` for each phase it serves.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Hydra changes the working directory at runtime; remember the original.
        self.ocwd = hydra.utils.get_original_cwd()

    def get_loader(self, phase):
        """Build a DataLoader for ``phase`` ('train' or 'val')."""
        phase_cfg = self.cfg.dataset.get(phase)
        dataset = FSDataset(phase, self.cfg)
        return DataLoader(
            dataset,
            batch_size=phase_cfg.batch_size,
            shuffle=phase_cfg.shuffle,
            num_workers=8,
            collate_fn=dataset.collate_fn,
            pin_memory=True,
            persistent_workers=False,
        )

    def train_dataloader(self):
        return self.get_loader('train')

    def val_dataloader(self):
        return self.get_loader('val')

    def test_dataloader(self):
        # Not implemented; returns None.
        pass
|
|
class FSDataset(Dataset):
    """Dataset batching wav and w2v-BERT acoustic features.

    Each item is a fixed-length random crop of a 24 kHz waveform, plus the
    w2v-BERT 2.0 feature-extractor output computed on a padded 16 kHz copy
    of the same crop.

    Args:
        phase: 'train', 'val' or 'test'; selects ``cfg.dataset.<phase>``.
        cfg: hydra config.
    """
    def __init__(self, phase, cfg):
        self.phase = phase
        self.cfg = cfg
        self.phase_cfg = cfg.dataset.get(phase)
        self.ocwd = hydra.utils.get_original_cwd()

        # Target rate for load_wav (cfg.preprocess.audio.sr); __getitem__
        # itself always normalizes to 24 kHz regardless of this value.
        self.sr = cfg.preprocess.audio.sr

        self.filelist = self.get_filelist(self.phase_cfg.filelist)
        # Crop length in samples at 24 kHz.
        self.min_audio_length = cfg.dataset.min_audio_length
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
        self.resample_to_16k = Resample(24000, 16000)
        # Cache of Resample transforms keyed by source sample rate, so the
        # resampling kernel is not rebuilt for every item (see __getitem__).
        self._resamplers = {}

    def __len__(self):
        return len(self.filelist)

    def load_wav(self, path):
        """Load ``path`` with librosa, resampled to ``self.sr``.

        NOTE(review): unused by __getitem__, which loads with torchaudio
        instead; kept only for backward compatibility with external callers.
        """
        wav, sr = librosa.load(path, sr=self.sr)
        return wav

    def get_filelist(self, fpath):
        """Read a TSV filelist; the first column is the wav path.

        Blank lines are skipped.
        """
        with open(fpath, 'r') as f:
            flist = [l.strip().split('\t')[0] for l in f if l.strip()]
        return flist

    def _resample_to_24k(self, wav, sr):
        """Resample ``wav`` from ``sr`` to 24 kHz, reusing a cached transform."""
        if sr not in self._resamplers:
            self._resamplers[sr] = Resample(sr, 24000)
        return self._resamplers[sr](wav)

    def __getitem__(self, idx):
        """Return ``{'wav': (min_audio_length,) @24 kHz, 'feat': w2v-BERT features}``."""
        wavpath = self.filelist[idx]

        try:
            wav, sr = torchaudio.load(wavpath)
        except Exception as e:
            # Best-effort: substitute silence so one corrupt file does not
            # abort the whole epoch.
            print(f"Error loading {wavpath}: {e}")
            wav = torch.zeros((1, self.min_audio_length))
            sr = self.sr

        if sr != 24000:
            wav = self._resample_to_24k(wav, sr)

        wav = wav[0, :]  # keep first channel only
        length = wav.shape[0]

        # Pad short clips up to the crop length, then take a random crop.
        if length < self.min_audio_length:
            wav = F.pad(wav, (0, self.min_audio_length - length))
            length = wav.shape[0]

        i = random.randint(0, length - self.min_audio_length)
        wav = wav[i:i + self.min_audio_length]

        # 10 ms (160 samples @16 kHz) of padding on each side before
        # feature extraction.
        wav_16k = self.resample_to_16k(wav)
        wav_16k_pad = F.pad(wav_16k, (160, 160))

        feat = self.feature_extractor(
            wav_16k_pad, sampling_rate=16000, return_tensors="pt"
        ).data['input_features'].squeeze(0)

        return {
            'wav': wav,
            'feat': feat,
        }

    def collate_fn(self, bs):
        """Stack per-item tensors; note per-item 'feat' becomes 'feats' in the batch."""
        wavs = torch.stack([b['wav'] for b in bs])
        feats = torch.stack([b['feat'] for b in bs])
        return {
            'wav': wavs,
            'feats': feats,
        }
|
|
@hydra.main(config_path='config', config_name='default', version_base=None)
def main(cfg):
    """Smoke-test the data pipeline by iterating the validation dataloader once.

    Args:
        cfg: hydra config (config/default.yaml).
    """
    data_module = DataModule(cfg)
    # NOTE: intentionally iterates the *validation* split, not train.
    val_loader = data_module.val_dataloader()

    for batch in tqdm(val_loader, desc="Processing batches", unit="batch"):
        # Touch the batch so loading/collation actually runs for every file.
        wavs = batch['wav']
|
|
| if __name__ == "__main__": |
| main() |