| import torch |
| import torchaudio |
| import torch.nn as nn |
| import pandas as pd |
| import random |
| from torch.utils.data import Dataset |
|
|
|
|
class TSED_Val(Dataset):
    """Validation dataset for target-sound-event detection (TSED).

    Reads a tab-separated metadata file listing audio clips, drops rows
    whose ``duration`` is zero, and serves fixed-length (``seg_length``
    seconds, zero-padded) single-channel waveforms with their filenames.

    Args:
        file_list: Path to a TSV metadata file with at least ``filename``
            and ``duration`` columns.
        data_dir: Prefix prepended to each ``filename`` by plain string
            concatenation — it must end with a path separator.
        seg_length: Clip length in seconds; shorter audio is zero-padded,
            longer audio is truncated to its first ``seg_length`` seconds.
        sr: Required sample rate; files at any other rate are rejected.
        norm: If True, peak-normalize each clip.
        mono: If True, average all channels to mono; otherwise 1-channel
            audio passes through and 2-channel audio is randomly downmixed
            or reduced to one randomly chosen channel.
        **kwargs: Ignored; accepted so shared config dicts can be splatted.
    """

    def __init__(self, file_list, data_dir,
                 seg_length=10, sr=16000,
                 norm=True, mono=True,
                 **kwargs):
        self.data_dir = data_dir
        meta = pd.read_csv(file_list, sep='\t')
        # Zero-duration rows are empty/unreadable clips; drop them up front.
        meta = meta[meta['duration'] != 0]
        self.meta = meta

        self.seg_len = seg_length
        self.sr = sr

        self.norm = norm
        self.mono = mono

    def load_audio(self, audio_path):
        """Load one file and return a 1-D tensor of ``seg_len * sr`` samples.

        Raises:
            ValueError: If the file's sample rate differs from ``self.sr``,
                or (when ``mono`` is False) it has more than two channels.
        """
        y, sr = torchaudio.load(audio_path)
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently admit wrong-rate audio.
        if sr != self.sr:
            raise ValueError(
                "Expected sample rate {} but got {} for {}".format(
                    self.sr, sr, audio_path))

        if self.mono:
            y = torch.mean(y, dim=0, keepdim=True)
        else:
            if y.shape[0] == 1:
                pass
            elif y.shape[0] == 2:
                # Stereo augmentation: coin-flip between a mono downmix and
                # keeping one randomly chosen channel.
                if random.choice([True, False]):
                    y = torch.mean(y, dim=0, keepdim=True)
                else:
                    channel = random.choice([0, 1])
                    y = y[channel, :].unsqueeze(0)
            else:
                raise ValueError("Unsupported number of channels: {}".format(y.shape[0]))

        total_length = y.shape[-1]

        # Always take the first segment; anything past seg_len seconds is
        # dropped, and shorter audio is left zero-padded at the tail.
        start = 0
        end = min(start + self.seg_len * self.sr, total_length)

        audio_clip = torch.zeros(self.seg_len * self.sr)
        audio_clip[:end - start] = y[0, start:end]

        if self.norm:
            # Peak-normalize; eps guards against division by zero on a
            # completely silent clip.
            eps = 1e-9
            max_val = torch.max(torch.abs(audio_clip))
            audio_clip = audio_clip / (max_val + eps)

        return audio_clip

    def __getitem__(self, index):
        """Return ``(waveform, filename)`` for metadata row ``index``."""
        row = self.meta.iloc[index]
        # Plain concatenation — relies on data_dir ending with a separator.
        audio = self.load_audio(self.data_dir + row['filename'])
        return audio, row['filename']

    def __len__(self):
        return len(self.meta)