| |
| import torch |
| import numpy as np |
| import random |
|
|
|
|
def frame_shift(features, label=None, net_pooling=None):
    """Circularly roll every example by its own random offset along the last axis.

    Each offset is ``int(random.gauss(0, 90))`` and is drawn once per example,
    in batch order.

    Args:
        features: 3-D tensor; axis 0 is the batch and rolling happens along the
            last axis of each example.
        label: optional 3-D tensor rolled together with ``features``.
        net_pooling: required when ``label`` is given. The label offset is
            ``offset // net_pooling``. NOTE(review): presumably this is the
            ratio between feature frames and label frames; confirm with callers.

    Returns:
        The stacked rolled features or, when ``label`` is given, the tuple
        ``(rolled_features, rolled_labels)``.
    """
    batch_size, _, _ = features.shape
    has_label = label is not None
    rolled_features = []
    rolled_labels = []
    for b in range(batch_size):
        offset = int(random.gauss(0, 90))
        rolled_features.append(torch.roll(features[b], offset, dims=-1))
        if has_label:
            # Floor division for both signs: negative offsets round toward
            # -inf (e.g. -5 // 4 == -2).
            rolled_labels.append(torch.roll(label[b], offset // net_pooling, dims=-1))
    if not has_label:
        return torch.stack(rolled_features)
    return torch.stack(rolled_features), torch.stack(rolled_labels)
|
|
|
|
def mixup(features, label=None, permutation=None, c=None, alpha=0.2, beta=0.2, mixup_label_type="soft", returnc=False):
    """Mix each example with another example of the same batch (mixup).

    Computes ``c * x + (1 - c) * x[permutation]`` along axis 0, without
    tracking gradients.

    Args:
        features: tensor whose axis 0 is the batch.
        label: optional tensor indexed with the same permutation.
        permutation: index tensor for the mixing partners. Defaults to
            ``torch.randperm(batch_size)``.
        c: mixing coefficient. When None it is drawn from
            ``np.random.beta(alpha, beta)``; for "hard" it is rescaled into
            ``[0.3, 0.7]``.
        alpha, beta: Beta-distribution parameters used when ``c`` is None.
        mixup_label_type: "soft" mixes labels with the same ``c`` (then clamps
            to [0, 1]); "hard" takes the clamped sum of both labels.
        returnc: when ``label`` is given, also return ``c`` and ``permutation``.

    Returns:
        ``mixed_features`` if ``label`` is None. Otherwise
        ``(mixed_features, mixed_label)``, or
        ``(mixed_features, mixed_label, c, permutation)`` when ``returnc``.

    Raises:
        NotImplementedError: if ``mixup_label_type`` is not "soft" or "hard".
            This is checked up front; previously ``c=None`` with a bad type
            crashed with a TypeError on ``None * features``.
    """
    if mixup_label_type not in ("soft", "hard"):
        raise NotImplementedError(f"mixup_label_type: {mixup_label_type} not implemented. choice in "
                                  f"{'soft', 'hard'}")
    with torch.no_grad():
        batch_size = features.size(0)

        if permutation is None:
            permutation = torch.randperm(batch_size)

        if c is None:
            c = np.random.beta(alpha, beta)
            if mixup_label_type == "hard":
                # Keep both sources clearly audible: c in [0.3, 0.7].
                c = c * 0.4 + 0.3

        # Index only axis 0 so inputs of any rank (including 1-D) work.
        mixed_features = c * features + (1 - c) * features[permutation]
        if label is None:
            return mixed_features

        if mixup_label_type == "soft":
            mixed_label = torch.clamp(c * label + (1 - c) * label[permutation], min=0, max=1)
        else:
            mixed_label = torch.clamp(label + label[permutation], min=0, max=1)

        if returnc:
            return mixed_features, mixed_label, c, permutation
        return mixed_features, mixed_label
|
|
|
|
def _sample_time_span(n_frame, mask_ratios):
    """Draw ``(start, width)`` of one time mask for an axis of ``n_frame`` frames."""
    min_width = int(n_frame / mask_ratios[1])
    max_width = int(n_frame / mask_ratios[0])  # exclusive bound for torch.randint
    # Short inputs (e.g. n_frame < mask_ratios[0]) make [min_width, max_width)
    # empty, which torch.randint rejects. Fall back to exactly min_width instead.
    width = torch.randint(low=min_width, high=max(max_width, min_width + 1), size=(1,)).item()
    # Same start range as before; the max() only guards the degenerate
    # width >= n_frame case, which would otherwise crash torch.randint.
    start = torch.randint(low=0, high=max(n_frame - width, 1), size=(1,)).item()
    return start, width


def time_mask(features, labels=None, net_pooling=None, mask_ratios=(10, 20)):
    """Zero one random contiguous span along the last axis, in place.

    The span width is drawn from
    ``[int(n_frame / mask_ratios[1]), int(n_frame / mask_ratios[0]))``. The same
    span is applied to every example in the batch.

    Args:
        features: 3-D tensor. Modified in place.
        labels: optional 3-D tensor, modified in place. When given, the span is
            drawn in label frames (``labels.shape[-1]``) and applied to
            ``features`` scaled by ``net_pooling``.
        net_pooling: required when ``labels`` is given.
            NOTE(review): presumably the number of feature frames per label
            frame; confirm with callers.
        mask_ratios: ``(r_max, r_min)`` divisors giving the maximum and minimum
            span width as fractions of the frame count.

    Returns:
        ``(features, labels)`` when labels are given, else ``features`` (the
        same tensor objects that were passed in).
    """
    if labels is not None:
        _, _, n_frame = labels.shape
        start, width = _sample_time_span(n_frame, mask_ratios)
        features[:, :, start * net_pooling:(start + width) * net_pooling] = 0
        labels[:, :, start:start + width] = 0
        return features, labels

    _, _, n_frame = features.shape
    start, width = _sample_time_span(n_frame, mask_ratios)
    features[:, :, start:start + width] = 0
    return features
|
|
|
|
def _apply_transforms(features, choice, filter_db_range, filter_bands,
                      filter_minimum_bandwidth, filter_type, freq_mask_ratio, noise_snrs):
    """Apply the enabled augmentations in order: filter, frequency mask, noise."""
    source = features
    if choice[0]:
        features = filt_aug(features, db_range=filter_db_range, n_band=filter_bands,
                            min_bw=filter_minimum_bandwidth, filter_type=filter_type)
    if choice[1]:
        # freq_mask zeroes bins in place. If we still hold the caller's tensor
        # (filt_aug disabled, or it returned its input unchanged), mask a copy
        # so the input and any other view built from it stay untouched.
        if features is source:
            features = features.clone()
        features = freq_mask(features, mask_ratio=freq_mask_ratio)
    if choice[2]:
        features = add_noise(features, snrs=noise_snrs)
    return features


def feature_transformation(features, n_transform, choice, filter_db_range, filter_bands,
                           filter_minimum_bandwidth, filter_type, freq_mask_ratio, noise_snrs):
    """Build a two-element list of (optionally) augmented copies of ``features``.

    Args:
        features: batch passed to ``filt_aug`` / ``freq_mask`` / ``add_noise``.
        n_transform: 2 -> two independently augmented views; 1 -> one augmented
            tensor returned twice; any other value -> the input returned twice,
            unaugmented.
        choice: three flags enabling (filt_aug, freq_mask, add_noise).
        filter_db_range, filter_bands, filter_minimum_bandwidth, filter_type:
            forwarded to ``filt_aug``.
        freq_mask_ratio: forwarded to ``freq_mask`` as ``mask_ratio``.
        noise_snrs: forwarded to ``add_noise`` as ``snrs``.

    Returns:
        A list of two tensors. The caller's ``features`` is never masked in
        place; previously both n_transform == 2 views could share one tensor
        carrying both masks.
    """
    transform_args = (choice, filter_db_range, filter_bands, filter_minimum_bandwidth,
                      filter_type, freq_mask_ratio, noise_snrs)
    if n_transform == 2:
        return [_apply_transforms(features, *transform_args) for _ in range(n_transform)]
    if n_transform == 1:
        features = _apply_transforms(features, *transform_args)
    return [features, features]
|
|
|
|
def filt_aug(features, db_range=(-6, 6), n_band=(3, 6), min_bw=6, filter_type="linear"):
    """Apply FilterAugment: a random gain per frequency band.

    ``features`` is unpacked as ``(batch, n_freq_bin, n_frames)``. The filter
    has shape ``(batch, n_freq_bin, 1)`` and is broadcast over the last axis.

    Args:
        features: 3-D tensor to filter.
        db_range: ``(low, high)`` range, in dB, of the random gains.
        n_band: ``(low, high)`` bounds for the number of bands, drawn with
            ``torch.randint`` (``high`` exclusive).
        min_bw: minimum band width in bins. It is shrunk automatically when
            ``n_freq_band * min_bw`` would exceed ``n_freq_bin``.
        filter_type: "step" (constant gain per band), "linear" (gain linearly
            interpolated in dB across each band), or a number ``p``. With a
            number, a step filter with ``n_band=(2, 5)`` and ``min_bw=4`` is
            used with probability ``p``; otherwise a linear filter with
            ``n_band=(3, 6)`` and ``min_bw=6``. These override the arguments.

    Returns:
        ``features * filter`` as a new tensor, or ``features`` itself when
        fewer than two bands are drawn.

    Raises:
        ValueError: if ``filter_type`` is a string other than "step"/"linear".
    """
    if not isinstance(filter_type, str):
        if torch.rand(1).item() < filter_type:
            filter_type, n_band, min_bw = "step", (2, 5), 4
        else:
            filter_type, n_band, min_bw = "linear", (3, 6), 6
    if filter_type not in ("step", "linear"):
        raise ValueError(f"filter_type must be 'step', 'linear' or a probability, got {filter_type!r}")

    batch_size, n_freq_bin, _ = features.shape
    n_freq_band = torch.randint(low=n_band[0], high=n_band[1], size=(1,)).item()
    if n_freq_band <= 1:
        return features

    # The randint below needs an exclusive upper bound >= 1, i.e.
    # n_freq_band * min_bw <= n_freq_bin. (The old test `... + 1 < 0` let the
    # bound reach 0 and crash.)
    while n_freq_bin - n_freq_band * min_bw < 0:
        min_bw -= 1
    # Sorted random offsets plus k * min_bw keep every band >= min_bw bins wide.
    band_bndry_freqs = torch.sort(torch.randint(0, n_freq_bin - n_freq_band * min_bw + 1,
                                                (n_freq_band - 1,)))[0] + \
        torch.arange(1, n_freq_band) * min_bw
    band_bndry_freqs = torch.cat((torch.tensor([0]), band_bndry_freqs, torch.tensor([n_freq_bin])))

    if filter_type == "step":
        # One uniform dB gain per band and example, converted to linear amplitude.
        band_factors = torch.rand((batch_size, n_freq_band)).to(features) * (db_range[1] - db_range[0]) + db_range[0]
        band_factors = 10 ** (band_factors / 20)
        freq_filt = torch.ones((batch_size, n_freq_bin, 1)).to(features)
        for i in range(n_freq_band):
            freq_filt[:, band_bndry_freqs[i]:band_bndry_freqs[i + 1], :] = \
                band_factors[:, i].unsqueeze(-1).unsqueeze(-1)
    else:  # "linear"
        # n_freq_band + 1 dB anchors per example, interpolated inside each band.
        band_factors = torch.rand((batch_size, n_freq_band + 1)).to(features) * (db_range[1] - db_range[0]) + db_range[0]
        freq_filt = torch.ones((batch_size, n_freq_bin, 1)).to(features)
        for i in range(n_freq_band):
            for j in range(batch_size):
                freq_filt[j, band_bndry_freqs[i]:band_bndry_freqs[i + 1], :] = \
                    torch.linspace(band_factors[j, i], band_factors[j, i + 1],
                                   band_bndry_freqs[i + 1] - band_bndry_freqs[i]).unsqueeze(-1)
        freq_filt = 10 ** (freq_filt / 20)
    return features * freq_filt
|
|
|
|
def freq_mask(features, mask_ratio=16):
    """Zero one random contiguous band along axis 1 of each example, in place.

    Args:
        features: 3-D tensor unpacked as ``(batch, n_freq_bin, _)``. Modified
            in place.
        mask_ratio: the maximum width is ``int(n_freq_bin / mask_ratio)``.
            Widths are drawn from ``[1, max_mask)``; when that range is empty
            (``max_mask <= 1``) a width of 1 is used.

    Returns:
        ``features`` (the same tensor object).
    """
    batch_size, n_freq_bin, _ = features.shape
    max_mask = int(n_freq_bin / mask_ratio)
    if max_mask <= 1:
        # Integer dtype matters: the widths become slice bounds. The old float
        # torch.ones() crashed here, and max_mask == 0 crashed in randint.
        f_widths = torch.ones(batch_size, dtype=torch.long)
    else:
        f_widths = torch.randint(low=1, high=max_mask, size=(batch_size,))

    for i in range(batch_size):
        f_width = int(f_widths[i])
        # max() only guards inputs narrower than the mask (was a randint crash).
        f_low = torch.randint(low=0, high=max(n_freq_bin - f_width, 1), size=(1,)).item()
        features[i, f_low:f_low + f_width, :] = 0
    return features
|
|
|
|
def add_noise(features, snrs=(15, 30), dims=(1, 2)):
    """Add Gaussian noise scaled to a target signal-to-noise ratio in dB.

    Args:
        features: tensor whose axis 0 is the batch.
        snrs: a fixed SNR in dB, or a ``(a, b)`` list/tuple. With a pair, one
            SNR per example is drawn as ``(a - b) * u + b`` with ``u`` from
            ``torch.rand``.
        dims: axes over which the signal standard deviation is measured.

    Returns:
        A new tensor ``features + noise`` where the noise std is
        ``std(features, dims) / 10 ** (snr / 20)``.
    """
    device = features.device
    if not isinstance(snrs, (list, tuple)):
        snr_db = snrs
    else:
        first, second = snrs[0], snrs[1]
        u = torch.rand((features.shape[0],), device=device).reshape(-1, 1, 1)
        snr_db = (first - second) * u + second

    amplitude_ratio = 10 ** (snr_db / 20)
    noise_std = torch.std(features, dim=dims, keepdim=True) / amplitude_ratio
    noise = torch.randn(features.shape, device=device)
    return features + noise * noise_std
|
|
|
|
|
|