| import torch |
| import torchaudio |
| import torch.nn as nn |
| import pandas as pd |
| import random |
| from torch.utils.data import Dataset |
| import ast |
|
|
|
|
class TSED_AS(Dataset):
    """Text-queried sound event detection dataset.

    Each item pairs a fixed-length audio clip with CLAP text embeddings for a
    sampled set of sound-event classes and the corresponding frame-level
    activity targets (one row per sampled class).

    Args:
        data_dir: Directory prefix for the audio files named in the metadata CSV.
        clap_dir: Directory prefix for per-class CLAP embeddings ('<label>.pt').
        meta_dir: Path to the metadata CSV. Must have 'file_name' and 'duration'
            columns; 'ids' / 'pos_ids' / 'neg_ids' / 'removed_ids' columns are
            required depending on `sample_method`.
        label_dir: Path to the strong-label CSV ('filename', 'label', 'onset',
            'offset'); only read when `label_type == 'strong'`.
        class_list: Path to the class CSV mapping 'id' -> 'label'.
        seg_length: Clip length in seconds.
        sr: Expected audio sample rate; asserted against the file's rate.
        label_sr: Frame rate (frames/sec) of the activity targets.
        label_per_audio: [num_positive, num_negative] classes to sample per item
            in 'balance' mode. For 'random' mode an int total is also accepted;
            a [pos, neg] pair is interpreted as pos + neg total.
        norm: If True, peak-normalize each clip.
        mono: If True, always downmix to mono. Otherwise stereo input is
            randomly either downmixed or reduced to one random channel
            (a cheap augmentation); >2 channels raise ValueError.
        label_type: 'strong' for frame-level targets; any other value yields
            all-zero targets and skips loading the label CSV.
        debug: Unused; accepted for config compatibility.
        sample_method: 'fix' | 'random' | 'balance' class-sampling strategy.
        neg_removed_weight: Sampling weight of 'removed' classes relative to
            plain negatives (weight 1.0) in 'balance' mode.
        **kwargs: Ignored; allows passing a shared config dict.
    """

    def __init__(self, data_dir, clap_dir, meta_dir, label_dir, class_list,
                 seg_length=10, sr=16000, label_sr=25, label_per_audio=[10, 10],
                 norm=True, mono=True, label_type='strong', debug=False, sample_method='random',
                 neg_removed_weight=0.25,
                 **kwargs):

        self.data_dir = data_dir
        self.clap_dir = clap_dir

        # Drop zero-duration (unreadable/empty) entries up front.
        meta = pd.read_csv(meta_dir)
        meta = meta[meta['duration'] != 0]
        self.meta = meta

        # Strong labels only; weak mode produces all-zero targets in load_label.
        if label_type == 'strong':
            self.label = pd.read_csv(label_dir)
        else:
            self.label = None

        self.label_per_audio = label_per_audio

        self.class_list = pd.read_csv(class_list)
        # id -> human-readable label, used to locate CLAP embedding files.
        self.class_dict = dict(self.class_list.set_index('id')['label'])
        # Sorted pool of all class ids for 'random' sampling.
        self.cls_ids = sorted(self.class_list['id'].unique().tolist())
        self.sample_method = sample_method

        self.seg_len = seg_length
        self.sr = sr
        self.label_sr = label_sr
        self.label_type = label_type

        self.norm = norm
        self.mono = mono

        self.neg_removed_weight = neg_removed_weight

    def load_audio(self, audio_path):
        """Load an audio file and return a 1-D tensor of exactly
        seg_len * sr samples (first seg_len seconds, zero-padded tail),
        optionally peak-normalized."""
        y, sr = torchaudio.load(audio_path)
        assert sr == self.sr, f"expected {self.sr} Hz, got {sr} Hz: {audio_path}"

        if self.mono:
            y = torch.mean(y, dim=0, keepdim=True)
        else:
            if y.shape[0] == 1:
                pass
            elif y.shape[0] == 2:
                # Channel augmentation: randomly downmix or keep one channel.
                if random.choice([True, False]):
                    y = torch.mean(y, dim=0, keepdim=True)
                else:
                    channel = random.choice([0, 1])
                    y = y[channel, :].unsqueeze(0)
            else:
                raise ValueError("Unsupported number of channels: {}".format(y.shape[0]))

        total_length = y.shape[-1]

        # Always crop from the start; shorter files are zero-padded.
        start = 0
        end = min(start + self.seg_len * self.sr, total_length)

        audio_clip = torch.zeros(self.seg_len * self.sr)
        audio_clip[:end - start] = y[0, start:end]

        if self.norm:
            # eps guards against division by zero on silent clips.
            eps = 1e-9
            max_val = torch.max(torch.abs(audio_clip))
            audio_clip = audio_clip / (max_val + eps)

        return audio_clip

    def load_label(self, filelabel, event_label):
        """Rasterize onset/offset annotations for one class into a
        (1, seg_len * label_sr) binary frame-activity target.

        Args:
            filelabel: DataFrame of this file's events ('label', 'onset',
                'offset' in seconds); ignored unless label_type == 'strong'.
            event_label: Class label string to rasterize.
        """
        n_frames = self.seg_len * self.label_sr
        target = torch.zeros(n_frames)
        if self.label_type == 'strong':
            events = filelabel[filelabel['label'] == event_label]
            for i in range(len(events)):
                row = events.iloc[i]
                # Clamp to the target span so out-of-range annotations
                # (e.g. negative onsets) cannot wrap around via slicing.
                onset = max(0, round(row['onset'] * self.label_sr))
                offset = min(n_frames, round(row['offset'] * self.label_sr))
                target[onset:offset] = 1
        # Non-strong label types: all-zero target.
        return target.unsqueeze(0)

    def __getitem__(self, index):
        """Return (audio, cls_tokens, labels, file_name) for one clip.

        audio: (seg_len*sr,) waveform; cls_tokens: (C, D) stacked CLAP
        embeddings; labels: (C, seg_len*label_sr) frame targets.
        """
        row = self.meta.iloc[index]
        audio = self.load_audio(self.data_dir + row['file_name'])

        # Choose which class ids to query for this clip.
        if self.sample_method == 'fix':
            # CSV cells hold stringified lists; parse them like the
            # 'balance' branch does (pass through if already a list).
            ids = row['ids']
            cls_list = ast.literal_eval(ids) if isinstance(ids, str) else ids
        elif self.sample_method == 'random':
            # label_per_audio may be an int or a [pos, neg] pair;
            # random.sample requires an int k (the original list default
            # raised TypeError here).
            k = (self.label_per_audio if isinstance(self.label_per_audio, int)
                 else sum(self.label_per_audio))
            cls_list = random.sample(self.cls_ids, k)
        elif self.sample_method == 'balance':
            pos_ids = ast.literal_eval(row['pos_ids'])
            neg_ids = ast.literal_eval(row['neg_ids'])
            removed_ids = ast.literal_eval(row['removed_ids'])
            N_p, N_n = self.label_per_audio
            # Too few positives: top up the negative quota instead.
            if len(pos_ids) < N_p:
                N_n += N_p - len(pos_ids)
            assert len(neg_ids) + len(removed_ids) >= N_n

            sampled_pos = random.sample(pos_ids, min(N_p, len(pos_ids)))

            # Down-weight 'removed' classes relative to plain negatives.
            # NOTE(review): random.choices samples WITH replacement, so
            # duplicate negatives are possible — confirm this is intended.
            candidates = neg_ids + removed_ids
            weights = [1.0] * len(neg_ids) + [self.neg_removed_weight] * len(removed_ids)
            sampled_neg = random.choices(candidates, weights=weights, k=min(N_n, len(candidates)))

            cls_list = sampled_pos + sampled_neg
        else:
            raise ValueError(f"Unknown sample_method: {self.sample_method!r}")

        cls_tokens = []
        labels = []

        # Strong labels exist only in 'strong' mode; load_label never touches
        # filelabel otherwise, so None is safe (original crashed here).
        if self.label is not None:
            filelabel = self.label[self.label['filename'] == row['file_name']]
        else:
            filelabel = None

        for cls_id in cls_list:
            event_label = self.class_dict[cls_id]
            cls = torch.load(self.clap_dir + event_label + '.pt')
            cls_tokens.append(cls)
            label = self.load_label(filelabel, event_label)
            labels.append(label)

        cls_tokens = torch.cat(cls_tokens, dim=0)
        labels = torch.cat(labels, dim=0)

        return audio, cls_tokens, labels, row['file_name']

    def __len__(self):
        """Number of usable (non-zero-duration) clips."""
        return len(self.meta)