| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| from __future__ import unicode_literals |
|
|
|
|
| import torch.utils.data as data |
| |
| import soundfile as sf |
| import PIL |
| import os |
| import os.path |
| import pickle |
| import random |
| import numpy as np |
| import pandas as pd |
| from scipy import signal |
|
|
| from miscc.config import cfg |
|
|
|
|
class TextDataset(data.Dataset):
    """Dataset yielding ``(RIR, embedding)`` pairs for one split of a data directory.

    The files read by this class are:

    * ``<data_dir>/<split>/filenames.pickle`` -- a sized, indexable collection
      of sample keys.
    * ``<data_dir>/<split>/embeddings.pickle`` -- an object indexed by those
      keys; each entry is converted to a float32 array.
    * ``<data_dir>/RIR/<key>.wav`` -- one audio file per key, read with
      ``soundfile``.

    Args:
        data_dir: Root directory of the dataset.
        split: Name of the sub-directory holding the two pickle files.
        rirsize: Stored as ``self.rirsize``. It does not currently influence
            the length of the returned waveform, which is always
            ``RIR_CROP_LENGTH`` samples.
    """

    # Number of samples every RIR is zero-padded or cropped to.
    RIR_CROP_LENGTH = 4096

    def __init__(self, data_dir, split='train', rirsize=4096):
        self.rirsize = rirsize
        self.data = []  # not used inside this class; kept for external users
        self.data_dir = data_dir
        self.bbox = None  # not used inside this class; kept for external users

        split_dir = os.path.join(data_dir, split)
        self.filenames = self.load_filenames(split_dir)
        self.embeddings = self.load_embedding(split_dir)

    @staticmethod
    def _fit_length(samples, target_len):
        """Zero-pad at the end, or crop, ``samples`` to ``target_len`` entries."""
        missing = target_len - samples.size
        if missing > 0:
            return np.concatenate([samples, np.zeros(missing)])
        return samples[0:target_len]

    def get_RIR(self, RIR_path):
        """Read an audio file and return it as a fixed-length float32 array.

        The waveform is zero-padded or cropped to ``RIR_CROP_LENGTH`` samples
        and a leading axis of size 1 is added, so a one-dimensional recording
        comes back with shape ``(1, RIR_CROP_LENGTH)``.

        NOTE(review): assumes the file is mono (``sf.read`` returns a 1-D
        array) -- confirm for the dataset in use.

        Args:
            RIR_path: Path of the audio file to read.

        Returns:
            A float32 ``numpy.ndarray``.
        """
        # The sample rate returned by soundfile is not needed here.
        wav, _ = sf.read(RIR_path)
        rir = self._fit_length(wav, self.RIR_CROP_LENGTH)
        return rir.astype('float32')[np.newaxis, ...]

    def load_embedding(self, data_dir):
        """Load and return the object pickled in ``<data_dir>/embeddings.pickle``.

        Args:
            data_dir: Directory containing ``embeddings.pickle``.

        Returns:
            The unpickled object, unchanged.
        """
        filepath = os.path.join(data_dir, 'embeddings.pickle')
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # dataset files from a trusted source.
        with open(filepath, 'rb') as f:
            return pickle.load(f)

    def load_filenames(self, data_dir):
        """Load the sample keys pickled in ``<data_dir>/filenames.pickle``.

        Prints the path and the number of entries that were loaded.

        Args:
            data_dir: Directory containing ``filenames.pickle``.

        Returns:
            The unpickled object, unchanged.
        """
        filepath = os.path.join(data_dir, 'filenames.pickle')
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # dataset files from a trusted source.
        with open(filepath, 'rb') as f:
            filenames = pickle.load(f)
        print('Load filenames from: %s (%d)' % (filepath, len(filenames)))
        return filenames

    def __getitem__(self, index):
        """Return the ``(RIR, embedding)`` pair for the sample at ``index``.

        Args:
            index: Position of the sample key in ``self.filenames``.

        Returns:
            A tuple ``(RIR, embedding)`` of float32 arrays: the waveform from
            ``get_RIR`` and the embedding stored under the sample's key.
        """
        key = self.filenames[index]
        # Keys are interpolated verbatim so that any layout they encode
        # (for example sub-directories) is preserved.
        RIR_name = '%s/RIR/%s.wav' % (self.data_dir, key)
        RIR = self.get_RIR(RIR_name)
        embedding = np.array(self.embeddings[key]).astype('float32')
        return RIR, embedding

    def __len__(self):
        """Return the number of sample keys in the split."""
        return len(self.filenames)
|
|