niobures commited on
Commit
3b5bc00
·
verified ·
1 Parent(s): 3daf776

SepReformer (code, models, paper)

Browse files
Files changed (38) hide show
  1. .gitattributes +1 -0
  2. Separate and Reconstruct. Asymmetric Encoder-Decoder for Speech Separation.pdf +3 -0
  3. code/SepReformer.zip +3 -0
  4. code/sepformer-tse.zip +3 -0
  5. models/SepReformer/SepReformer_Base_WSJ0/configs.yaml +139 -0
  6. models/SepReformer/SepReformer_Base_WSJ0/dataset.py +165 -0
  7. models/SepReformer/SepReformer_Base_WSJ0/engine.py +216 -0
  8. models/SepReformer/SepReformer_Base_WSJ0/log/scratch_weights/epoch.0180.pth +3 -0
  9. models/SepReformer/SepReformer_Base_WSJ0/main.py +47 -0
  10. models/SepReformer/SepReformer_Base_WSJ0/model.py +53 -0
  11. models/SepReformer/SepReformer_Base_WSJ0/modules/module.py +283 -0
  12. models/SepReformer/SepReformer_Base_WSJ0/modules/network.py +252 -0
  13. models/SepReformer/SepReformer_Large_DM_WHAM/configs.yaml +129 -0
  14. models/SepReformer/SepReformer_Large_DM_WHAM/dataset.py +177 -0
  15. models/SepReformer/SepReformer_Large_DM_WHAM/engine.py +192 -0
  16. models/SepReformer/SepReformer_Large_DM_WHAM/main.py +44 -0
  17. models/SepReformer/SepReformer_Large_DM_WHAM/model.py +53 -0
  18. models/SepReformer/SepReformer_Large_DM_WHAM/modules/module.py +286 -0
  19. models/SepReformer/SepReformer_Large_DM_WHAM/modules/network.py +252 -0
  20. models/SepReformer/SepReformer_Large_DM_WHAMR/configs.yaml +131 -0
  21. models/SepReformer/SepReformer_Large_DM_WHAMR/dataset.py +187 -0
  22. models/SepReformer/SepReformer_Large_DM_WHAMR/engine.py +192 -0
  23. models/SepReformer/SepReformer_Large_DM_WHAMR/main.py +44 -0
  24. models/SepReformer/SepReformer_Large_DM_WHAMR/model.py +53 -0
  25. models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/module.cpython-310.pyc +0 -0
  26. models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/module.cpython-38.pyc +0 -0
  27. models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/network.cpython-310.pyc +0 -0
  28. models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/network.cpython-38.pyc +0 -0
  29. models/SepReformer/SepReformer_Large_DM_WHAMR/modules/module.py +283 -0
  30. models/SepReformer/SepReformer_Large_DM_WHAMR/modules/network.py +252 -0
  31. models/SepReformer/SepReformer_Large_DM_WSJ0/configs.yaml +128 -0
  32. models/SepReformer/SepReformer_Large_DM_WSJ0/dataset.py +171 -0
  33. models/SepReformer/SepReformer_Large_DM_WSJ0/engine.py +192 -0
  34. models/SepReformer/SepReformer_Large_DM_WSJ0/main.py +44 -0
  35. models/SepReformer/SepReformer_Large_DM_WSJ0/model.py +53 -0
  36. models/SepReformer/SepReformer_Large_DM_WSJ0/modules/module.py +283 -0
  37. models/SepReformer/SepReformer_Large_DM_WSJ0/modules/network.py +252 -0
  38. models/SepReformer/source.txt +1 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Separate[[:space:]]and[[:space:]]Reconstruct.[[:space:]]Asymmetric[[:space:]]Encoder-Decoder[[:space:]]for[[:space:]]Speech[[:space:]]Separation.pdf filter=lfs diff=lfs merge=lfs -text
Separate and Reconstruct. Asymmetric Encoder-Decoder for Speech Separation.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97766a617509db55e816689f3f3e8e52c03b06ebf04f98f03e298a5556a4e898
3
+ size 1952305
code/SepReformer.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc9b31d464b79b6ac037879b160445f53a4de9e4a411cce0954fc24c0ff7706d
3
+ size 16944535
code/sepformer-tse.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13905e009d88354fcb21579e11f5dcebe3114eb2d238eab0fa4cc11f9cc237ea
3
+ size 114198916
models/SepReformer/SepReformer_Base_WSJ0/configs.yaml ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ project: "[Project] SepReformer" ### Don't change
2
+ notes: "SepReformer final version" ### Insert changes here (please write details!)
3
+ # ------------------------------------------------------------------------------------------------------------------------------ #
4
+ config:
5
+ # ------------------------------------------------------------ #
6
+ dataset:
7
+ max_len: 32000
8
+ sampling_rate: 8000
9
+ scp_dir: "data/scp_ss_8k"
10
+ train:
11
+ mixture: "tr_mix.scp"
12
+ spk1: "tr_s1.scp"
13
+ spk2: "tr_s2.scp"
14
+ dynamic_mixing: false
15
+ valid:
16
+ mixture: "cv_mix.scp"
17
+ spk1: "cv_s1.scp"
18
+ spk2: "cv_s2.scp"
19
+ test:
20
+ mixture: "tt_mix.scp"
21
+ spk1: "tt_s1.scp"
22
+ spk2: "tt_s2.scp"
23
+ # ------------------------------------------------------------ #
24
+ dataloader:
25
+ batch_size: 2
26
+ pin_memory: false
27
+ num_workers: 12
28
+ drop_last: false
29
+ # ------------------------------------------------------------ #
30
+ model:
31
+ num_stages: &var_model_num_stages 4 # R
32
+ num_spks: &var_model_num_spks 2
33
+ module_audio_enc:
34
+ in_channels: 1
35
+ out_channels: &var_model_audio_enc_out_channels 256
36
+ kernel_size: &var_model_audio_enc_kernel_size 16 # L
37
+ stride: &var_model_audio_enc_stride 4 # S
38
+ groups: 1
39
+ bias: false
40
+ module_feature_projector:
41
+ num_channels: *var_model_audio_enc_out_channels
42
+ in_channels: *var_model_audio_enc_out_channels
43
+ out_channels: &feature_projector_out_channels 128 # F
44
+ kernel_size: 1
45
+ bias: false
46
+ module_separator:
47
+ num_stages: *var_model_num_stages
48
+ relative_positional_encoding:
49
+ in_channels: *feature_projector_out_channels
50
+ num_heads: 8
51
+ maxlen: 2000
52
+ embed_v: false
53
+ enc_stage:
54
+ global_blocks:
55
+ in_channels: *feature_projector_out_channels
56
+ num_mha_heads: 8
57
+ dropout_rate: 0.05
58
+ local_blocks:
59
+ in_channels: *feature_projector_out_channels
60
+ kernel_size: 65
61
+ dropout_rate: 0.05
62
+ down_conv_layer:
63
+ in_channels: *feature_projector_out_channels
64
+ samp_kernel_size: &var_model_samp_kernel_size 5
65
+ spk_split_stage:
66
+ in_channels: *feature_projector_out_channels
67
+ num_spks: *var_model_num_spks
68
+ simple_fusion:
69
+ out_channels: *feature_projector_out_channels
70
+ dec_stage:
71
+ num_spks: *var_model_num_spks
72
+ global_blocks:
73
+ in_channels: *feature_projector_out_channels
74
+ num_mha_heads: 8
75
+ dropout_rate: 0.05
76
+ local_blocks:
77
+ in_channels: *feature_projector_out_channels
78
+ kernel_size: 65
79
+ dropout_rate: 0.05
80
+ spk_attention:
81
+ in_channels: *feature_projector_out_channels
82
+ num_mha_heads: 8
83
+ dropout_rate: 0.05
84
+ module_output_layer:
85
+ in_channels: *var_model_audio_enc_out_channels
86
+ out_channels: *feature_projector_out_channels
87
+ num_spks: *var_model_num_spks
88
+ module_audio_dec:
89
+ in_channels: *var_model_audio_enc_out_channels
90
+ out_channels: 1
91
+ kernel_size: *var_model_audio_enc_kernel_size
92
+ stride: *var_model_audio_enc_stride
93
+ bias: false
94
+ # ------------------------------------------------------------ #
95
+ criterion:
96
+ name: ["PIT_SISNR_mag", "PIT_SISNR_time", "PIT_SISNRi", "PIT_SDRi"] ### Choose a torch.nn's loss function class(=attribute) e.g. ["L1Loss", "MSELoss", "CrossEntropyLoss", ...]
97
+ PIT_SISNR_mag:
98
+ frame_length: 512
99
+ frame_shift: 128
100
+ window: 'hann'
101
+ num_stages: *var_model_num_stages
102
+ num_spks: *var_model_num_spks
103
+ scale_inv: true
104
+ mel_opt: false
105
+ PIT_SISNR_time:
106
+ num_spks: *var_model_num_spks
107
+ scale_inv: true
108
+ PIT_SISNRi:
109
+ num_spks: *var_model_num_spks
110
+ scale_inv: true
111
+ PIT_SDRi:
112
+ dump: 0
113
+ # ------------------------------------------------------------ #
114
+ optimizer:
115
+ name: ["AdamW"] ### Choose a torch.optim's class(=attribute) e.g. ["Adam", "AdamW", "SGD", ...]
116
+ AdamW:
117
+ lr: 1.0e-3
118
+ weight_decay: 1.0e-2
119
+ # ------------------------------------------------------------ #
120
+ scheduler:
121
+ name: ["ReduceLROnPlateau", "WarmupConstantSchedule"] ### Choose a torch.optim.lr_scheduler's class(=attribute) e.g. ["StepLR", "ReduceLROnPlateau", "Custom"]
122
+ ReduceLROnPlateau:
123
+ mode: "min"
124
+ min_lr: 1.0e-10
125
+ factor: 0.8
126
+ patience: 2
127
+ WarmupConstantSchedule:
128
+ warmup_steps: 1000
129
+ # ------------------------------------------------------------ #
130
+ check_computations:
131
+ dummy_len: 16000
132
+ # ------------------------------------------------------------ #
133
+ engine:
134
+ max_epoch: 200
135
+ gpuid: "0" ### "0"(single-gpu) or "0, 1" (multi-gpu)
136
+ mvn: false
137
+ clip_norm: 5
138
+ start_scheduling: 50
139
+ test_epochs: [100, 120, 150, 170]
models/SepReformer/SepReformer_Base_WSJ0/dataset.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import librosa as audio_lib
5
+ import numpy as np
6
+
7
+ from utils import util_dataset
8
+ from utils.decorators import *
9
+ from loguru import logger
10
+ from torch.utils.data import Dataset, DataLoader
11
+
12
@logger_wraps()
def get_dataloaders(args, dataset_config, loader_config):
    """Create one DataLoader per data partition.

    In any "test" engine mode only the test partition is built; otherwise
    train/valid/test are all built. The test partition always uses batch
    size 1 so per-utterance metrics can be written out.

    Arguments:
        args: parsed CLI args (uses args.engine_mode)
        dataset_config: the "dataset" section of configs.yaml
        loader_config: the "dataloader" section of configs.yaml

    Returns:
        dict mapping partition name -> torch DataLoader
    """
    partitions = ["test"] if "test" in args.engine_mode else ["train", "valid", "test"]
    loaders = {}
    for name in partitions:
        part_cfg = dataset_config[name]
        scp_dir = dataset_config["scp_dir"]
        mix_scp = os.path.join(scp_dir, part_cfg["mixture"])
        # Every key starting with "spk" points at one source-speaker scp.
        src_scps = [os.path.join(scp_dir, part_cfg[k]) for k in part_cfg if k.startswith("spk")]
        # Dynamic mixing is a train-only augmentation.
        use_dm = part_cfg["dynamic_mixing"] if name == "train" else False
        part_dataset = MyDataset(
            max_len=dataset_config["max_len"],
            fs=dataset_config["sampling_rate"],
            partition=name,
            wave_scp_srcs=src_scps,
            wave_scp_mix=mix_scp,
            dynamic_mixing=use_dm)
        loaders[name] = DataLoader(
            dataset=part_dataset,
            batch_size=1 if name == "test" else loader_config["batch_size"],
            shuffle=True,  # only train: (name == 'train') / all: True
            pin_memory=loader_config["pin_memory"],
            num_workers=loader_config["num_workers"],
            drop_last=loader_config["drop_last"],
            collate_fn=_collate)
    return loaders
38
+
39
+
40
+ def _collate(egs):
41
+ """
42
+ Transform utterance index into a minbatch
43
+
44
+ Arguments:
45
+ index: a list type [{},{},{}]
46
+
47
+ Returns:
48
+ input_sizes: a tensor correspond to utterance length
49
+ input_feats: packed sequence to feed networks
50
+ source_attr/target_attr: dictionary contains spectrogram/phase needed in loss computation
51
+ """
52
+ def __prepare_target_rir(dict_lsit, index):
53
+ return torch.nn.utils.rnn.pad_sequence([torch.tensor(d["src"][index], dtype=torch.float32) for d in dict_lsit], batch_first=True)
54
+ if type(egs) is not list: raise ValueError("Unsupported index type({})".format(type(egs)))
55
+ num_spks = 2 # you need to set this paramater by yourself
56
+ dict_list = sorted([eg for eg in egs], key=lambda x: x['num_sample'], reverse=True)
57
+ mixture = torch.nn.utils.rnn.pad_sequence([torch.tensor(d['mix'], dtype=torch.float32) for d in dict_list], batch_first=True)
58
+ src = [__prepare_target_rir(dict_list, index) for index in range(num_spks)]
59
+ input_sizes = torch.tensor([d['num_sample'] for d in dict_list], dtype=torch.float32)
60
+ key = [d['key'] for d in dict_list]
61
+ return input_sizes, mixture, src, key
62
+
63
+
64
@logger_wraps()
class MyDataset(Dataset):
    """Speech-separation dataset driven by Kaldi-style .scp path lists.

    Each item is a dict with keys:
        "num_sample": mixture length in samples,
        "mix": the mixture waveform,
        "src": list of per-speaker source waveforms,
        "key": utterance id.
    """

    def __init__(self, max_len, fs, partition, wave_scp_srcs, wave_scp_mix, wave_scp_noise=None, dynamic_mixing=False, speed_list=None):
        # partition: "train" / "valid" / "test" — non-test partitions are
        # randomly cropped to max_len in the load methods below.
        self.partition = partition
        for wave_scp_src in wave_scp_srcs:
            if not os.path.exists(wave_scp_src): raise FileNotFoundError(f"Could not find file {wave_scp_src}")
        self.max_len = max_len  # maximum crop length in samples
        self.fs = fs  # target sampling rate passed to librosa.load
        # One {key: wav-path} dict per source speaker, plus one for mixtures.
        self.wave_dict_srcs = [util_dataset.parse_scps(wave_scp_src) for wave_scp_src in wave_scp_srcs]
        self.wave_dict_mix = util_dataset.parse_scps(wave_scp_mix)
        self.wave_dict_noise = util_dataset.parse_scps(wave_scp_noise) if wave_scp_noise else None
        self.wave_keys = list(self.wave_dict_mix.keys())
        logger.info(f"Create MyDataset for {wave_scp_mix} with {len(self.wave_dict_mix)} utterances")
        # NOTE(review): `speed_list` is accepted but never used, and
        # `_dynamic_mixing` below calls `self.speed_aug`, which is never
        # defined in this class — enabling dynamic_mixing will raise
        # AttributeError. Presumably a speed-perturbation transform should
        # be built here from `speed_list`; confirm against the other
        # SepReformer variants before using dynamic mixing.
        self.dynamic_mixing = dynamic_mixing

    def __len__(self):
        # One item per mixture entry.
        return len(self.wave_dict_mix)

    def __contains__(self, key):
        return key in self.wave_dict_mix

    def _dynamic_mixing(self, key):
        """Build an on-the-fly mixture: pair `key` with a random other utterance."""
        def __match_length(wav, len_data) :
            # Randomly crop `wav` to exactly `len_data` samples.
            leftover = len(wav) - len_data
            idx = random.randint(0,leftover)
            wav = wav[idx:idx+len_data]
            return wav

        samps_src = []
        src_len = []
        # Dynamic source choice: resample until the random utterance's
        # speaker ids (WSJ0-style key fields) differ from both speakers
        # of `key`, i.e. avoid mixing a speaker with themselves.
        while True:
            key_random = random.choice(list(self.wave_dict_srcs[0].keys()))
            tmp1 = key.split('_')[1][:3] != key_random.split('_')[3][:3]
            tmp2 = key.split('_')[3][:3] != key_random.split('_')[1][:3]
            if tmp1 and tmp2: break

        # Randomly swap which scp list provides each of the two sources.
        idx1, idx2 = (0, 1) if random.random() > 0.5 else (1, 0)
        files = [self.wave_dict_srcs[idx1][key], self.wave_dict_srcs[idx2][key_random]]

        # load
        for file in files:
            if not os.path.exists(file): raise FileNotFoundError("Input file {} do not exists!".format(file))
            samps_tmp, _ = audio_lib.load(file, sr=self.fs)
            # mixing with random gains drawn from +/-2.5 dB
            gain = pow(10,-random.uniform(-2.5,2.5)/20)
            # Speed Augmentation
            # NOTE(review): self.speed_aug is not defined anywhere in this
            # class — this line raises AttributeError; confirm intent.
            samps_tmp = np.array(self.speed_aug(torch.tensor(samps_tmp))[0])
            samps_src.append(gain*samps_tmp)
            src_len.append(len(samps_tmp))

        # matching the audio length: crop both sources to the shorter one
        min_len = min(src_len)

        samps_src = [__match_length(s, min_len) for s in samps_src]
        samps_mix = sum(samps_src)

        # truncate so the sample length is a multiple of the encoder stride (4)
        if len(samps_mix)%4 != 0:
            remains = len(samps_mix)%4
            samps_mix = samps_mix[:-remains]
            samps_src = [s[:-remains] for s in samps_src]

        if self.partition != "test":
            # random crop to max_len for train/valid
            if len(samps_mix) > self.max_len:
                start = random.randint(0, len(samps_mix)-self.max_len)
                samps_mix = samps_mix[start:start+self.max_len]
                samps_src = [s[start:start+self.max_len] for s in samps_src]
        return samps_mix, samps_src

    def _direct_load(self, key):
        """Load the pre-mixed mixture and its sources for `key` from disk."""
        samps_src = []
        files = [wave_dict_src[key] for wave_dict_src in self.wave_dict_srcs]
        for file in files:
            if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
            samps_tmp, _ = audio_lib.load(file, sr=self.fs)
            samps_src.append(samps_tmp)

        file = self.wave_dict_mix[key]
        if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
        samps_mix, _ = audio_lib.load(file, sr=self.fs)

        # Truncate so the length is a multiple of the encoder stride (4)
        if len(samps_mix) % 4 != 0:
            remains = len(samps_mix) % 4
            samps_mix = samps_mix[:-remains]
            samps_src = [s[:-remains] for s in samps_src]

        if self.partition != "test":
            # random crop to max_len for train/valid
            if len(samps_mix) > self.max_len:
                start = random.randint(0,len(samps_mix)-self.max_len)
                samps_mix = samps_mix[start:start+self.max_len]
                samps_src = [s[start:start+self.max_len] for s in samps_src]

        return samps_mix, samps_src

    def __getitem__(self, index):
        key = self.wave_keys[index]
        if any(key not in self.wave_dict_srcs[i] for i in range(len(self.wave_dict_srcs))) or key not in self.wave_dict_mix: raise KeyError(f"Could not find utterance {key}")
        samps_mix, samps_src = self._dynamic_mixing(key) if self.dynamic_mixing else self._direct_load(key)
        return {"num_sample": samps_mix.shape[0], "mix": samps_mix, "src": samps_src, "key": key}
models/SepReformer/SepReformer_Base_WSJ0/engine.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import csv
4
+ import time
5
+ import soundfile as sf
6
+ import librosa
7
+ from loguru import logger
8
+ from tqdm import tqdm
9
+ from utils import util_engine, functions
10
+ from utils.decorators import *
11
+ from torch.utils.tensorboard import SummaryWriter
12
+
13
+
14
@logger_wraps()
class Engine(object):
    """Training / validation / test driver for the SepReformer model.

    Owns the model, dataloaders, criterions, optimizer and schedulers;
    `run()` executes the full train/valid/test loop, `_inference_sample()`
    separates a single wav file.
    """

    def __init__(self, args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device):

        ''' Default setting '''
        self.engine_mode = args.engine_mode
        self.out_wav_dir = args.out_wav_dir
        self.config = config
        self.gpuid = gpuid
        self.device = device
        self.model = model.to(self.device)
        self.dataloaders = dataloaders # self.dataloaders['train'] or ['valid'] or ['test']
        self.PIT_SISNR_mag_loss, self.PIT_SISNR_time_loss, self.PIT_SISNRi_loss, self.PIT_SDRi_loss = criterions
        self.main_optimizer = optimizers[0]
        self.main_scheduler, self.warmup_scheduler = schedulers
        # FIX: run() passes self.wandb_run to save_checkpoint_per_best but the
        # attribute was never assigned (AttributeError at the first checkpoint).
        # No wandb run is created by this engine, so pass an explicit None.
        self.wandb_run = None

        self.pretrain_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "pretrain_weights")
        os.makedirs(self.pretrain_weights_path, exist_ok=True)
        self.scratch_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "scratch_weights")
        os.makedirs(self.scratch_weights_path, exist_ok=True)

        # Resume from pretrained weights when any checkpoint file is present,
        # otherwise train from scratch.
        # FIX: the extension tuple contained '.pt' twice; '.pth' was clearly
        # intended (checkpoints are saved as epoch.XXXX.pth).
        self.checkpoint_path = self.pretrain_weights_path if any(file.endswith(('.pt', '.pth', '.pkl')) for file in os.listdir(self.pretrain_weights_path)) else self.scratch_weights_path
        self.start_epoch = util_engine.load_last_checkpoint_n_get_epoch(self.checkpoint_path, self.model, self.main_optimizer, location=self.device)

        # Logging: parameter count / MACs summary on a dummy input
        util_engine.model_params_mac_summary(
            model=self.model,
            input=torch.randn(1, self.config['check_computations']['dummy_len']).to(self.device),
            dummy_input=torch.rand(1, self.config['check_computations']['dummy_len']).to(self.device),
            metrics=['ptflops', 'thop', 'torchinfo']
            # metrics=['ptflops']
        )

        logger.info(f"Clip gradient by 2-norm {self.config['engine']['clip_norm']}")

    @logger_wraps()
    def _train(self, dataloader, epoch):
        """Run one training epoch; returns (time loss, freq loss, #batches)."""
        self.model.train()
        tot_loss_freq = [0 for _ in range(self.model.num_stages)]
        tot_loss_time, num_batch = 0, 0
        pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:25}{r_bar}{bar:-10b}', colour="YELLOW", dynamic_ncols=True)
        for input_sizes, mixture, src, _ in dataloader:
            nnet_input = mixture
            nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
            num_batch += 1
            pbar.update(1)
            # Scheduler learning rate for warm-up (iteration-based update for transformers)
            if epoch == 1: self.warmup_scheduler.step()
            nnet_input = nnet_input.to(self.device)
            self.main_optimizer.zero_grad()
            estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
            # Multi-stage auxiliary (frequency-domain) losses.
            # FIX: removed dead `cur_loss_s_bn = 0` immediately overwritten by the list.
            cur_loss_s_bn = []
            for idx, estim_src_value in enumerate(estim_src_bn):
                cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
                tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
            cur_loss_s = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
            tot_loss_time += cur_loss_s.item() / self.config['model']['num_spks']
            # Auxiliary-loss weight: constant 0.4 for the first 100 epochs,
            # then decayed by 0.8 every 5 epochs.
            alpha = 0.4 * 0.8**(1+(epoch-101)//5) if epoch > 100 else 0.4
            cur_loss = (1-alpha) * cur_loss_s + alpha * sum(cur_loss_s_bn) / len(cur_loss_s_bn)
            cur_loss = cur_loss / self.config['model']['num_spks']
            cur_loss.backward()
            if self.config['engine']['clip_norm']: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['engine']['clip_norm'])
            self.main_optimizer.step()
            dict_loss = {"T_Loss": tot_loss_time / num_batch}
            dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
            pbar.set_postfix(dict_loss)
        pbar.close()
        tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
        return tot_loss_time / num_batch, tot_loss_freq / num_batch, num_batch

    @logger_wraps()
    def _validate(self, dataloader):
        """Run one validation pass; returns (time loss, freq loss, #batches)."""
        self.model.eval()
        tot_loss_freq = [0 for _ in range(self.model.num_stages)]
        tot_loss_time, num_batch = 0, 0
        pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="RED", dynamic_ncols=True)
        with torch.inference_mode():
            for input_sizes, mixture, src, _ in dataloader:
                nnet_input = mixture
                nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
                nnet_input = nnet_input.to(self.device)
                num_batch += 1
                pbar.update(1)
                estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
                cur_loss_s_bn = []
                for idx, estim_src_value in enumerate(estim_src_bn):
                    cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
                    tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
                cur_loss_s_SDR = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
                tot_loss_time += cur_loss_s_SDR.item() / self.config['model']['num_spks']
                dict_loss = {"T_Loss":tot_loss_time / num_batch}
                dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
                pbar.set_postfix(dict_loss)
        pbar.close()
        tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
        return tot_loss_time / num_batch, tot_loss_SDRi if False else tot_loss_freq / num_batch, num_batch

    @logger_wraps()
    def _test(self, dataloader, wav_dir=None):
        """Evaluate on the test set, writing per-utterance SISNRi/SDRi CSVs.

        In "test_save" mode the mixture and separated wavs are also written
        to `wav_dir` (default: <module dir>/wav_out).
        Returns (mean SISNRi, mean SDRi, #batches).
        """
        self.model.eval()
        total_loss_SISNRi, total_loss_SDRi, num_batch = 0, 0, 0
        pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="grey", dynamic_ncols=True)
        with torch.inference_mode():
            csv_file_name_sisnr = os.path.join(os.path.dirname(__file__),'test_SISNRi_value.csv')
            csv_file_name_sdr = os.path.join(os.path.dirname(__file__),'test_SDRi_value.csv')
            with open(csv_file_name_sisnr, 'w', newline='') as csvfile_sisnr, open(csv_file_name_sdr, 'w', newline='') as csvfile_sdr:
                idx = 0
                writer_sisnr = csv.writer(csvfile_sisnr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
                writer_sdr = csv.writer(csvfile_sdr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
                for input_sizes, mixture, src, key in dataloader:
                    if len(key) > 1:
                        # FIX: was `raise("batch size is not one!!")`, which
                        # raises TypeError (str is not an exception class).
                        raise ValueError("batch size is not one!!")
                    nnet_input = mixture.to(self.device)
                    num_batch += 1
                    pbar.update(1)
                    estim_src, _ = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
                    cur_loss_SISNRi, cur_loss_SISNRi_src = self.PIT_SISNRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src, eps=1.0e-15)
                    total_loss_SISNRi += cur_loss_SISNRi.item() / self.config['model']['num_spks']
                    cur_loss_SDRi, cur_loss_SDRi_src = self.PIT_SDRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src)
                    total_loss_SDRi += cur_loss_SDRi.item() / self.config['model']['num_spks']
                    writer_sisnr.writerow([key[0][:-4]] + [cur_loss_SISNRi_src[i].item() for i in range(self.config['model']['num_spks'])])
                    writer_sdr.writerow([key[0][:-4]] + [cur_loss_SDRi_src[i].item() for i in range(self.config['model']['num_spks'])])
                    if self.engine_mode == "test_save":
                        # FIX: `== None` -> `is None`
                        if wav_dir is None: wav_dir = os.path.join(os.path.dirname(__file__),"wav_out")
                        if wav_dir and not os.path.exists(wav_dir): os.makedirs(wav_dir)
                        mixture = torch.squeeze(mixture).cpu().data.numpy()
                        sf.write(os.path.join(wav_dir,key[0][:-4]+str(idx)+'_mixture.wav'), 0.5*mixture/max(abs(mixture)), 8000)
                        for i in range(self.config['model']['num_spks']):
                            src = torch.squeeze(estim_src[i]).cpu().data.numpy()
                            sf.write(os.path.join(wav_dir,key[0][:-4]+str(idx)+'_out_'+str(i)+'.wav'), 0.5*src/max(abs(src)), 8000)
                    idx += 1
                    dict_loss = {"SiSNRi": total_loss_SISNRi/num_batch, "SDRi": total_loss_SDRi/num_batch}
                    pbar.set_postfix(dict_loss)
        pbar.close()
        return total_loss_SISNRi/num_batch, total_loss_SDRi/num_batch, num_batch

    @logger_wraps()
    def _inference_sample(self, sample):
        """Separate a single wav file; writes *_in.wav and *_out_<i>.wav next to it."""
        self.model.eval()
        self.fs = self.config["dataset"]["sampling_rate"]
        mixture, _ = librosa.load(sample,sr=self.fs)
        mixture = torch.tensor(mixture, dtype=torch.float32)[None]
        # Pad to a multiple of the encoder stride so the model input is valid.
        self.stride = self.config["model"]["module_audio_enc"]["stride"]
        remains = mixture.shape[-1] % self.stride
        if remains != 0:
            padding = self.stride - remains
            mixture_padded = torch.nn.functional.pad(mixture, (0, padding), "constant", 0)
        else:
            mixture_padded = mixture

        with torch.inference_mode():
            nnet_input = mixture_padded.to(self.device)
            estim_src, _ = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
            mixture = torch.squeeze(mixture).cpu().numpy()
            sf.write(sample[:-4]+'_in.wav', 0.9*mixture/max(abs(mixture)), self.fs)
            for i in range(self.config['model']['num_spks']):
                # Trim the padding back off before writing.
                src = torch.squeeze(estim_src[i][...,:mixture.shape[-1]]).cpu().data.numpy()
                sf.write(sample[:-4]+'_out_'+str(i)+'.wav', 0.9*src/max(abs(src)), self.fs)


    @logger_wraps()
    def run(self):
        """Dispatch according to engine_mode: test-only, or train/valid/test loop."""
        with torch.cuda.device(self.device):
            writer_src = SummaryWriter(os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/tensorboard"))
            if "test" in self.engine_mode:
                on_test_start = time.time()
                test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'], self.out_wav_dir)
                on_test_end = time.time()
                logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
                logger.info(f"Testing done!")
            else:
                start_time = time.time()
                if self.start_epoch > 1:
                    init_loss_time, init_loss_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
                else:
                    init_loss_time, init_loss_freq = 0, 0
                end_time = time.time()
                logger.info(f"[INIT] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: Loss_t = {init_loss_time:.4f} dB | Loss_f = {init_loss_freq:.4f} dB | Speed = ({end_time-start_time:.2f}s)")
                # FIX: best-loss tracking was re-initialized INSIDE the epoch
                # loop, so every epoch compared against the initial loss and
                # the "best checkpoint" logic never accumulated. Initialize
                # once, before the loop.
                valid_loss_best = init_loss_time
                for epoch in range(self.start_epoch, self.config['engine']['max_epoch']):
                    train_start_time = time.time()
                    train_loss_src_time, train_loss_src_freq, train_num_batch = self._train(self.dataloaders['train'], epoch)
                    train_end_time = time.time()
                    valid_start_time = time.time()
                    valid_loss_src_time, valid_loss_src_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
                    valid_end_time = time.time()
                    if epoch > self.config['engine']['start_scheduling']: self.main_scheduler.step(valid_loss_src_time)
                    logger.info(f"[TRAIN] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {train_loss_src_time:.4f} dB | Loss_f = {train_loss_src_freq:.4f} dB | Speed = ({train_end_time - train_start_time:.2f}s/{train_num_batch:d})")
                    logger.info(f"[VALID] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {valid_loss_src_time:.4f} dB | Loss_f = {valid_loss_src_freq:.4f} dB | Speed = ({valid_end_time - valid_start_time:.2f}s/{valid_num_batch:d})")
                    if epoch in self.config['engine']['test_epochs']:
                        on_test_start = time.time()
                        test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'])
                        on_test_end = time.time()
                        logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
                    valid_loss_best = util_engine.save_checkpoint_per_best(valid_loss_best, valid_loss_src_time, train_loss_src_time, epoch, self.model, self.main_optimizer, self.checkpoint_path, self.wandb_run)
                    # Logging to monitoring tools (Tensorboard && Wandb)
                    writer_src.add_scalars("Metrics", {
                        'Loss_train_time': train_loss_src_time,
                        'Loss_valid_time': valid_loss_src_time}, epoch)
                    # FIX: add_scalars requires a dict of scalars; a single
                    # float must use add_scalar (the original raised at runtime).
                    writer_src.add_scalar("Learning Rate", self.main_optimizer.param_groups[0]['lr'], epoch)
                    writer_src.flush()
                logger.info(f"Training for {self.config['engine']['max_epoch']} epoches done!")
models/SepReformer/SepReformer_Base_WSJ0/log/scratch_weights/epoch.0180.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14569febefb19900026a350c7b31ca6a927ce4bac7fa83269902d8c6437f0d11
3
+ size 134
models/SepReformer/SepReformer_Base_WSJ0/main.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from loguru import logger
4
+ from .dataset import get_dataloaders
5
+ from .model import Model
6
+ from .engine import Engine
7
+ from utils import util_system, util_implement
8
+ from utils.decorators import *
9
+
10
+ # Setup logger
11
+ log_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/system_log.log")
12
+ logger.add(log_file_path, level="DEBUG", mode="w")
13
+
14
@logger_wraps()
def main(args):
    """Entry point: build config, data, model, and engine, then run.

    Reads configs.yaml next to this file, constructs dataloaders, the
    SepReformer model, criterions/optimizers/schedulers, and dispatches the
    Engine (or single-sample inference when engine_mode == 'infer_sample').
    """
    # --- configuration: parse configs.yaml located next to this module ---
    yaml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs.yaml")
    yaml_dict = util_system.parse_yaml(yaml_path)
    config = yaml_dict["config"]  # wandb login success or fail

    # --- data: one DataLoader per partition (train / valid / test / ...) ---
    dataloaders = get_dataloaders(args, config["dataset"], config["dataloader"])

    # --- model: the separation network -----------------------------------
    model = Model(**config["model"])

    # --- devices: parse "0" or "0, 1" into a gpu-id tuple ----------------
    gpuid = tuple(int(g) for g in config["engine"]["gpuid"].split(','))
    device = torch.device(f'cuda:{gpuid[0]}')

    # --- criterion / optimizer / scheduler factories ---------------------
    criterions = util_implement.CriterionFactory(config["criterion"], device).get_criterions()
    optimizers = util_implement.OptimizerFactory(config["optimizer"], model.parameters()).get_optimizers()
    schedulers = util_implement.SchedulerFactory(config["scheduler"], optimizers).get_schedulers()

    # --- engine: run training/testing, or separate one sample file -------
    engine = Engine(args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device)
    if args.engine_mode == 'infer_sample':
        engine._inference_sample(args.sample_file)
    else:
        engine.run()
models/SepReformer/SepReformer_Base_WSJ0/model.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('../')
3
+
4
+ import torch
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ from utils.decorators import *
9
+ from .modules.module import *
10
+
11
+
12
@logger_wraps()
class Model(torch.nn.Module):
    """SepReformer separation network.

    Pipeline: AudioEncoder -> FeatureProjector -> Separator -> OutputLayer ->
    AudioDecoder, producing `num_spks` separated waveforms. Each separator
    stage additionally feeds its own masking output layer and decoder so the
    intermediate stages can receive an auxiliary (deep-supervision) loss.
    """
    def __init__(self,
                 num_stages: int,
                 num_spks: int,
                 module_audio_enc: dict,
                 module_feature_projector: dict,
                 module_separator: dict,
                 module_output_layer: dict,
                 module_audio_dec: dict):
        super().__init__()
        self.num_stages = num_stages
        self.num_spks = num_spks
        self.audio_encoder = AudioEncoder(**module_audio_enc)
        self.feature_projector = FeatureProjector(**module_feature_projector)
        self.separator = Separator(**module_separator)
        self.out_layer = OutputLayer(**module_output_layer)
        self.audio_decoder = AudioDecoder(**module_audio_dec)

        # Aux_loss: one masking output layer + decoder pair per separator stage.
        self.out_layer_bn = torch.nn.ModuleList([])
        self.decoder_bn = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            self.out_layer_bn.append(OutputLayer(**module_output_layer, masking=True))
            self.decoder_bn.append(AudioDecoder(**module_audio_dec))

    def forward(self, x):
        """x: raw mixture waveform [B, T] (or [T]).

        Returns:
            audio: list of num_spks separated waveforms from the final stage
            audio_aux: per-stage lists of num_spks auxiliary waveforms,
                trimmed to the input length
        """
        encoder_output = self.audio_encoder(x)
        projected_feature = self.feature_projector(encoder_output)
        last_stage_output, each_stage_outputs = self.separator(projected_feature)
        out_layer_output = self.out_layer(last_stage_output, encoder_output)
        each_spk_output = [out_layer_output[idx] for idx in range(self.num_spks)]
        audio = [self.audio_decoder(each_spk_output[idx]) for idx in range(self.num_spks)]

        # Aux_loss: decode every intermediate stage output as well.
        audio_aux = []
        for idx, each_stage_output in enumerate(each_stage_outputs):
            # interpolate() replaces the deprecated F.upsample (same default 'nearest' mode).
            each_stage_output = self.out_layer_bn[idx](torch.nn.functional.interpolate(each_stage_output, size=encoder_output.shape[-1]), encoder_output)
            out_aux = [each_stage_output[jdx] for jdx in range(self.num_spks)]
            audio_aux.append([self.decoder_bn[idx](out_aux[jdx])[..., :x.shape[-1]] for jdx in range(self.num_spks)])

        return audio, audio_aux
models/SepReformer/SepReformer_Base_WSJ0/modules/module.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('../')
3
+
4
+ import torch
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ from utils.decorators import *
9
+ from .network import *
10
+
11
+
12
class AudioEncoder(torch.nn.Module):
    """Strided 1-D convolutional front-end mapping a raw waveform to a latent
    feature map, followed by a GELU non-linearity."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, groups: int, bias: bool):
        super().__init__()
        self.conv1d = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, groups=groups, bias=bias)
        self.gelu = torch.nn.GELU()

    def forward(self, x: torch.Tensor):
        # Insert the channel axis: [T] -> [1, T], or [B, T] -> [B, 1, T].
        if x.dim() == 1:
            x = x.unsqueeze(0)
        else:
            x = x.unsqueeze(1)
        return self.gelu(self.conv1d(x))
24
+
25
class FeatureProjector(torch.nn.Module):
    """Single-group GroupNorm (layer-norm style over all channels) followed by
    a pointwise Conv1d projecting encoder features to the model width."""

    def __init__(self, num_channels: int, in_channels: int, out_channels: int, kernel_size: int, bias: bool):
        super().__init__()
        self.norm = torch.nn.GroupNorm(num_groups=1, num_channels=num_channels, eps=1e-8)
        self.conv1d = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias)

    def forward(self, x: torch.Tensor):
        # Normalize across all channels jointly, then project.
        return self.conv1d(self.norm(x))
36
+
37
+
38
class Separator(torch.nn.Module):
    """U-Net style separator over the projected encoder features.

    Temporal contracting path: `num_stages` SepEncStage blocks (each halves the
    time axis) followed by a bottleneck stage and a speaker-split block.
    Temporal expanding path: `num_stages` SepDecStage blocks with skip fusion.
    Returns the final stage output plus every stage's input features, which the
    model decodes separately for auxiliary supervision.
    """
    def __init__(self, num_stages: int, relative_positional_encoding: dict, enc_stage: dict, spk_split_stage: dict, simple_fusion: dict, dec_stage: dict):
        super().__init__()

        class RelativePositionalEncoding(torch.nn.Module):
            """Relative position embeddings shared by the attention blocks."""
            def __init__(self, in_channels: int, num_heads: int, maxlen: int, embed_v=False):
                super().__init__()
                self.in_channels = in_channels
                self.num_heads = num_heads
                self.embedding_dim = self.in_channels // self.num_heads
                self.maxlen = maxlen
                self.pe_k = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim)
                self.pe_v = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim) if embed_v else None

            def forward(self, pos_seq: torch.Tensor):
                # Clamp relative offsets, then shift into the [0, 2*maxlen) embedding range.
                pos_seq.clamp_(-self.maxlen, self.maxlen - 1)
                pos_seq += self.maxlen
                pe_k_output = self.pe_k(pos_seq)
                pe_v_output = self.pe_v(pos_seq) if self.pe_v is not None else None
                return pe_k_output, pe_v_output

        class SepEncStage(torch.nn.Module):
            """Two (global block + local block) pairs, optionally followed by a
            stride-2 depthwise down-convolution."""
            def __init__(self, global_blocks: dict, local_blocks: dict, down_conv_layer: dict, down_conv=True):
                super().__init__()

                class DownConvLayer(torch.nn.Module):
                    def __init__(self, in_channels: int, samp_kernel_size: int):
                        """Depthwise stride-2 conv + BatchNorm + GELU for temporal downsampling."""
                        super().__init__()
                        self.down_conv = torch.nn.Conv1d(
                            in_channels=in_channels, out_channels=in_channels, kernel_size=samp_kernel_size, stride=2, padding=(samp_kernel_size-1)//2, groups=in_channels)
                        self.BN = torch.nn.BatchNorm1d(num_features=in_channels)
                        self.gelu = torch.nn.GELU()

                    def forward(self, x: torch.Tensor):
                        # (B, T, C) -> (B, C, T) for conv/BN, then back to (B, T, C).
                        x = x.permute([0, 2, 1])
                        x = self.down_conv(x)
                        x = self.BN(x)
                        x = self.gelu(x)
                        x = x.permute([0, 2, 1])
                        return x

                self.g_block_1 = GlobalBlock(**global_blocks)
                self.l_block_1 = LocalBlock(**local_blocks)

                self.g_block_2 = GlobalBlock(**global_blocks)
                self.l_block_2 = LocalBlock(**local_blocks)

                self.downconv = DownConvLayer(**down_conv_layer) if down_conv == True else None

            def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
                '''
                x: [B, N, T]
                Returns (downsampled output, pre-downsampling skip features).
                '''
                x = self.g_block_1(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_1(x)
                x = x.permute(0, 2, 1).contiguous()

                x = self.g_block_2(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_2(x)
                x = x.permute(0, 2, 1).contiguous()

                skip = x
                if self.downconv:
                    x = x.permute(0, 2, 1).contiguous()
                    x = self.downconv(x)
                    x = x.permute(0, 2, 1).contiguous()
                # [BK, S, N]
                return x, skip

        class SpkSplitStage(torch.nn.Module):
            """Expands features into num_spks per-speaker streams packed on the batch axis."""
            def __init__(self, in_channels: int, num_spks: int):
                super().__init__()
                self.linear = torch.nn.Sequential(
                    torch.nn.Conv1d(in_channels, 4*in_channels*num_spks, kernel_size=1),
                    torch.nn.GLU(dim=-2),
                    torch.nn.Conv1d(2*in_channels*num_spks, in_channels*num_spks, kernel_size=1))
                self.norm = torch.nn.GroupNorm(1, in_channels, eps=1e-8)
                self.num_spks = num_spks

            def forward(self, x: torch.Tensor):
                x = self.linear(x)
                B, _, T = x.shape
                # (B, spks*C, T) -> (B*spks, C, T): speakers are folded into the batch axis.
                x = x.view(B*self.num_spks, -1, T).contiguous()
                x = self.norm(x)
                return x

        class SepDecStage(torch.nn.Module):
            """Three (global block + local block + speaker attention) triplets."""
            def __init__(self, num_spks: int, global_blocks: dict, local_blocks: dict, spk_attention: dict):
                super().__init__()

                self.g_block_1 = GlobalBlock(**global_blocks)
                self.l_block_1 = LocalBlock(**local_blocks)
                self.spk_attn_1 = SpkAttention(**spk_attention)

                self.g_block_2 = GlobalBlock(**global_blocks)
                self.l_block_2 = LocalBlock(**local_blocks)
                self.spk_attn_2 = SpkAttention(**spk_attention)

                self.g_block_3 = GlobalBlock(**global_blocks)
                self.l_block_3 = LocalBlock(**local_blocks)
                self.spk_attn_3 = SpkAttention(**spk_attention)

                self.num_spk = num_spks

            def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
                '''
                x: [B, N, T]
                '''
                # [BS, K, H]
                x = self.g_block_1(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_1(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_1(x, self.num_spk)

                x = self.g_block_2(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_2(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_2(x, self.num_spk)

                x = self.g_block_3(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_3(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_3(x, self.num_spk)

                skip = x

                return x, skip

        self.num_stages = num_stages
        self.pos_emb = RelativePositionalEncoding(**relative_positional_encoding)

        # Temporal Contracting Part
        self.enc_stages = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            self.enc_stages.append(SepEncStage(**enc_stage, down_conv=True))

        self.bottleneck_G = SepEncStage(**enc_stage, down_conv=False)
        self.spk_split_block = SpkSplitStage(**spk_split_stage)

        # Temporal Expanding Part
        self.simple_fusion = torch.nn.ModuleList([])
        self.dec_stages = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            self.simple_fusion.append(torch.nn.Conv1d(in_channels=simple_fusion['out_channels']*2, out_channels=simple_fusion['out_channels'], kernel_size=1))
            self.dec_stages.append(SepDecStage(**dec_stage))

    def forward(self, input: torch.Tensor):
        '''input: [B, N, L]'''
        # Pad so the time axis divides evenly across all downsampling stages.
        x, _ = self.pad_signal(input)
        len_x = x.shape[-1]
        # Temporal Contracting Part
        # Relative offsets are computed at the most downsampled resolution.
        pos_seq = torch.arange(0, len_x//2**self.num_stages).long().to(x.device)
        pos_seq = pos_seq[:, None] - pos_seq[None, :]
        pos_k, _ = self.pos_emb(pos_seq)
        skip = []
        for idx in range(self.num_stages):
            x, skip_ = self.enc_stages[idx](x, pos_k)
            skip_ = self.spk_split_block(skip_)
            skip.append(skip_)
        x, _ = self.bottleneck_G(x, pos_k)
        x = self.spk_split_block(x)  # B, 2F, T

        each_stage_outputs = []
        # Temporal Expanding Part
        for idx in range(self.num_stages):
            each_stage_outputs.append(x)
            idx_en = self.num_stages - (idx + 1)
            # interpolate() replaces the deprecated F.upsample (same default 'nearest' mode).
            x = torch.nn.functional.interpolate(x, size=skip[idx_en].shape[-1])
            x = torch.cat([x, skip[idx_en]], dim=1)
            x = self.simple_fusion[idx](x)
            x, _ = self.dec_stages[idx](x, pos_k)

        last_stage_output = x
        return last_stage_output, each_stage_outputs

    def pad_signal(self, input: torch.Tensor):
        """Zero-pad the time axis up to a multiple of 2**num_stages.

        Accepts (T), (B, T) or (B, 1, T); returns (padded (B, 1|N, T'), pad amount).
        """
        if input.dim() not in [1, 2, 3]: raise RuntimeError("Input can only be 2 or 3 dimensional.")
        # Separate ifs (not elif) so a 1-D input is promoted all the way to 3-D;
        # the original elif chain left 1-D input at 2-D and crashed on size(2).
        if input.dim() == 1: input = input.unsqueeze(0)
        if input.dim() == 2: input = input.unsqueeze(1)
        L = 2**self.num_stages
        batch_size = input.size(0)
        ndim = input.size(1)
        nframe = input.size(2)
        padded_len = (nframe//L + 1)*L
        rest = 0 if nframe%L == 0 else padded_len - nframe
        if rest > 0:
            # Plain torch.zeros replaces the deprecated torch.autograd.Variable wrapper.
            pad = torch.zeros(batch_size, ndim, rest, dtype=input.dtype, device=input.device)
            input = torch.cat([input, pad], dim=-1)
        return input, rest
235
+
236
+
237
class OutputLayer(torch.nn.Module):
    """Maps separator features back to the encoder dimensionality and, when
    `masking` is set, gates the (replicated) encoder output with a ReLU mask.

    Output shape: [num_spks, B, N, L].
    """
    def __init__(self, in_channels: int, out_channels: int, num_spks: int, masking: bool = False):
        super().__init__()
        # feature expansion back
        self.masking = masking
        self.spe_block = Masking(in_channels, Activation_mask="ReLU", concat_opt=None)
        self.num_spks = num_spks
        # GLU bottleneck: out_channels -> 4*out_channels -> (GLU) 2*out_channels -> in_channels
        self.end_conv1x1 = torch.nn.Sequential(
            torch.nn.Linear(out_channels, 4*out_channels),
            torch.nn.GLU(),
            torch.nn.Linear(2*out_channels, in_channels))

    def forward(self, x: torch.Tensor, input: torch.Tensor):
        # Trim any separator padding so the time axes match the encoder output.
        x = x[...,:input.shape[-1]]
        x = x.permute([0, 2, 1])
        x = self.end_conv1x1(x)
        x = x.permute([0, 2, 1])
        B, N, L = x.shape
        # The leading dim packs batch*speakers; recover the true batch size.
        B = B // self.num_spks

        if self.masking:
            # Replicate the encoder output per speaker so the mask can gate it.
            # NOTE(review): expand+transpose+view assumes the packed layout is
            # (batch, speaker)-ordered as produced by SpkSplitStage -- confirm.
            input = input.expand(self.num_spks, B, N, L).transpose(0,1).contiguous()
            input = input.view(B*self.num_spks, N, L)
            x = self.spe_block(x, input)

        x = x.view(B, self.num_spks, N, L)
        # [spks, B, N, L]
        x = x.transpose(0, 1)
        return x
266
+
267
+
268
class AudioDecoder(torch.nn.ConvTranspose1d):
    '''
    Decoder of the TasNet
    This module can be seen as the gradient of Conv1d with respect to its input.
    It is also known as a fractionally-strided convolution
    or a deconvolution (although it is not an actual deconvolution operation).
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        # x: [B, N, L] (3D) or [N, L] (2D, promoted to a single-item batch)
        if x.dim() not in [2, 3]:
            # self.__class__.__name__ fixes the original self.__name__, which
            # raised AttributeError instead of this RuntimeError; the message
            # now also matches the 2/3D check above (it claimed "3/4D").
            raise RuntimeError("{} accept 2/3D tensor as input".format(self.__class__.__name__))
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
        # Drop the singleton channel axis; collapse fully only when that still
        # leaves a 1-D waveform (single-item batch).
        x = torch.squeeze(x, dim=1) if torch.squeeze(x).dim() == 1 else torch.squeeze(x)
        return x
models/SepReformer/SepReformer_Base_WSJ0/modules/network.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import numpy
4
+ from utils.decorators import *
5
+
6
+
7
class LayerScale(torch.nn.Module):
    """Learnable per-channel scaling of a residual branch, initialized near zero."""

    def __init__(self, dims, input_size, Layer_scale_init=1.0e-5):
        super().__init__()
        # Parameter shape is chosen so it broadcasts over the input's leading axes.
        shapes = {1: (input_size,), 2: (1, input_size), 3: (1, 1, input_size)}
        if dims in shapes:
            init = torch.ones(*shapes[dims]) * Layer_scale_init
            self.layer_scale = torch.nn.Parameter(init, requires_grad=True)

    def forward(self, x):
        return x * self.layer_scale
19
+
20
class Masking(torch.nn.Module):
    """Gated masking: activates `x` into a mask and multiplies it onto `skip`.

    Keyword options:
        concat_opt: when truthy, `x` and `skip` are first concatenated on the
            channel axis and fused by a pointwise conv before gating.
    """
    def __init__(self, input_dim, Activation_mask='Sigmoid', **options):
        super(Masking, self).__init__()

        self.options = options
        # .get() keeps the module usable when concat_opt is simply omitted;
        # the original subscript raised KeyError in that case.
        if self.options.get('concat_opt'):
            self.pw_conv = torch.nn.Conv1d(input_dim*2, input_dim, 1, stride=1, padding=0)

        if Activation_mask == 'Sigmoid':
            self.gate_act = torch.nn.Sigmoid()
        elif Activation_mask == 'ReLU':
            self.gate_act = torch.nn.ReLU()

    def forward(self, x, skip):
        """Return gate_act(fuse(x[, skip])) * skip."""
        if self.options.get('concat_opt'):
            y = torch.cat([x, skip], dim=-2)
            y = self.pw_conv(y)
        else:
            y = x
        y = self.gate_act(y) * skip

        return y
44
+
45
+
46
class GCFN(torch.nn.Module):
    """Gated convolutional feed-forward block with a LayerScale residual.

    (B, T, C) -> expand to 6C, depthwise conv over time, GLU back to 3C,
    project to C, then add to the input through a learnable LayerScale gain.
    """

    def __init__(self, in_channels, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        self.net1 = torch.nn.Sequential(
            torch.nn.LayerNorm(in_channels),
            torch.nn.Linear(in_channels, in_channels*6))
        self.depthwise = torch.nn.Conv1d(in_channels*6, in_channels*6, 3, padding=1, groups=in_channels*6)
        self.net2 = torch.nn.Sequential(
            torch.nn.GLU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(in_channels*3, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        expanded = self.net1(x)
        # Channel-first only for the depthwise temporal convolution.
        conv_out = self.depthwise(expanded.transpose(1, 2).contiguous())
        gated = self.net2(conv_out.transpose(1, 2).contiguous())
        return x + self.Layer_scale(gated)
67
+
68
+
69
class MultiHeadAttention(torch.nn.Module):
    """
    Multi-Head Attention layer with optional relative positional bias.
    :param int n_head: the number of heads
    :param int in_channels: the number of features (model width)
    :param float dropout_rate: dropout rate
    :param float Layer_scale_init: initial value of the LayerScale residual gain
    """
    def __init__(self, n_head: int, in_channels: int, dropout_rate: float, Layer_scale_init=1.0e-5):
        super().__init__()
        assert in_channels % n_head == 0
        self.d_k = in_channels // n_head # We assume d_v always equals d_k
        self.h = n_head
        self.layer_norm = torch.nn.LayerNorm(in_channels)
        self.linear_q = torch.nn.Linear(in_channels, in_channels)
        self.linear_k = torch.nn.Linear(in_channels, in_channels)
        self.linear_v = torch.nn.Linear(in_channels, in_channels)
        self.linear_out = torch.nn.Linear(in_channels, in_channels)
        self.attn = None  # last attention weights, kept for inspection/debugging
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x, pos_k, mask):
        """
        Compute 'Scaled Dot Product Attention' (pre-LayerNorm self-attention).
        :param torch.Tensor x: input features (batch, time, d_model)
        :param torch.Tensor pos_k: relative positional key embeddings, or None
        :param torch.Tensor mask: (batch, time1, time2), or None
        :return torch.Tensor: attentioned and transformed `value` (batch, time1, d_model)
            weighted by the query dot key attention (batch, head, time1, time2),
            gated by LayerScale
        """
        n_batch = x.size(0)
        x = self.layer_norm(x)
        q = self.linear_q(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
        k = self.linear_k(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
        v = self.linear_v(x).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2) # (batch, head, time2, d_k)
        v = v.transpose(1, 2) # (batch, head, time2, d_k)
        # Content-content attention term.
        A = torch.matmul(q, k.transpose(-2, -1))
        reshape_q = q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0,1)
        if pos_k is not None:
            # Content-position term from the shared relative-position embeddings.
            B = torch.matmul(reshape_q, pos_k.transpose(-2, -1))
            B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1))
            scores = (A + B) / math.sqrt(self.d_k)
        else:
            scores = A / math.sqrt(self.d_k)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
            # Most negative finite value of the score dtype, so masked positions
            # receive ~0 softmax weight; masked_fill afterwards zeroes them exactly.
            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
        return self.Layer_scale(self.dropout(self.linear_out(x))) # (batch, time1, d_model)
125
+
126
class EGA(torch.nn.Module):
    """Efficient global attention: attend at a pooled (downsampled) time
    resolution, then upsample the result and gate it back onto the
    full-resolution features."""

    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'self_attn': MultiHeadAttention(
                n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            'linear': torch.nn.Sequential(
                torch.nn.LayerNorm(normalized_shape=in_channels),
                torch.nn.Linear(in_features=in_channels, out_features=in_channels),
                torch.nn.Sigmoid())
        })

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """
        :param torch.Tensor x: features (batch, channels, time)
        :param torch.Tensor pos_k: relative positional embeddings; its length
            fixes the pooled attention resolution
        :return torch.Tensor: gated features (batch, time, channels)
        """
        # Pool the time axis down to the length covered by the positional table.
        down_len = pos_k.shape[0]
        x_down = torch.nn.functional.adaptive_avg_pool1d(input=x, output_size=down_len)
        x = x.permute([0, 2, 1])
        x_down = x_down.permute([0, 2, 1])
        x_down = self.block['self_attn'](x_down, pos_k, None)
        x_down = x_down.permute([0, 2, 1])
        # interpolate() replaces the deprecated F.upsample (same default 'nearest' mode).
        x_downup = torch.nn.functional.interpolate(input=x_down, size=x.shape[1])
        x_downup = x_downup.permute([0, 2, 1])
        # Sigmoid gate decides how much global context each position receives.
        x = x + self.block['linear'](x) * x_downup

        return x
156
+
157
+
158
+
159
class CLA(torch.nn.Module):
    """Convolutional local-attention block: GLU-gated pointwise expansion, a
    depthwise temporal convolution, BatchNorm, and a GELU projection, wrapped
    in a LayerScale residual. Operates on (B, T, C)."""

    def __init__(self, in_channels, kernel_size, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(in_channels)
        self.linear1 = torch.nn.Linear(in_channels, in_channels*2)
        self.GLU = torch.nn.GLU()
        self.dw_conv_1d = torch.nn.Conv1d(in_channels, in_channels, kernel_size, padding='same', groups=in_channels)
        self.linear2 = torch.nn.Linear(in_channels, 2*in_channels)
        self.BN = torch.nn.BatchNorm1d(2*in_channels)
        self.linear3 = torch.nn.Sequential(
            torch.nn.GELU(),
            torch.nn.Linear(2*in_channels, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        # Channel-first only around the conv and BatchNorm ops.
        branch = self.GLU(self.linear1(self.layer_norm(x)))
        branch = self.dw_conv_1d(branch.transpose(1, 2))          # B, C, T
        branch = self.linear2(branch.transpose(1, 2))             # B, T, 2C
        branch = self.BN(branch.transpose(1, 2)).transpose(1, 2)  # normalize per channel
        branch = self.linear3(branch)
        return x + self.Layer_scale(branch)
188
+
189
class GlobalBlock(torch.nn.Module):
    """EGA global attention followed by a GCFN feed-forward.

    Input is channel-first (batch, channels, time); EGA returns time-major
    features, and the final permute restores the channel-first layout.
    """

    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'ega': EGA(
                num_mha_heads=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            'gcfn': GCFN(in_channels=in_channels, dropout_rate=dropout_rate)
        })

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """x: (batch, channels, time); returns (batch, channels, time)."""
        attended = self.block['ega'](x, pos_k)   # -> (batch, time, channels)
        refined = self.block['gcfn'](attended)
        return refined.permute([0, 2, 1])        # back to channel-first
210
+
211
+
212
class LocalBlock(torch.nn.Module):
    """CLA local convolutional attention followed by a GCFN feed-forward;
    time-major (batch, time, channels) in and out."""

    def __init__(self, in_channels: int, kernel_size: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'cla': CLA(in_channels, kernel_size, dropout_rate),
            'gcfn': GCFN(in_channels, dropout_rate)
        })

    def forward(self, x: torch.Tensor):
        """x: (batch, time, channels); returns the same shape."""
        return self.block['gcfn'](self.block['cla'](x))
225
+
226
+
227
class SpkAttention(torch.nn.Module):
    """Attention across the speaker axis: at each time step the num_spk
    speaker streams (packed into the batch dim) attend to each other,
    followed by a GCFN feed-forward.
    """
    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.self_attn = MultiHeadAttention(n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate)
        self.feed_forward = GCFN(in_channels=in_channels, dropout_rate=dropout_rate)

    def forward(self, x: torch.Tensor, num_spk: int):
        """
        :param torch.Tensor x: packed features (batch*num_spk, channels, time)
        :param int num_spk: number of speaker streams packed into the batch dim
        :rtype: torch.Tensor of the same (batch*num_spk, channels, time) shape
        """
        B, F, T = x.shape
        # Unpack speakers, then fold (batch, time) together so the attention
        # sequence length becomes num_spk: (B/num_spk * T, num_spk, F).
        x = x.view(B//num_spk, num_spk, F, T).contiguous()
        x = x.permute([0, 3, 1, 2]).contiguous()
        x = x.view(-1, num_spk, F).contiguous()
        # Residual self-attention over the speaker axis (no positional bias, no mask).
        x = x + self.self_attn(x, None, None)
        # Restore the original packed (batch*num_spk, channels, time) layout.
        x = x.view(B//num_spk, T, num_spk, F).contiguous()
        x = x.permute([0, 2, 3, 1]).contiguous()
        x = x.view(B, F, T).contiguous()
        # GCFN expects time-major input, hence the permute sandwich.
        x = x.permute([0, 2, 1])
        x = self.feed_forward(x)
        x = x.permute([0, 2, 1])

        return x
models/SepReformer/SepReformer_Large_DM_WHAM/configs.yaml ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ config:
2
+ dataset:
3
+ max_len : 32000
4
+ sampling_rate: 8000
5
+ scp_dir: "data/scp_ss_8k_wham"
6
+ train:
7
+ mixture: "tr_mix.scp"
8
+ spk1: "tr_s1.scp"
9
+ spk2: "tr_s2.scp"
10
+ noise: "tr_n.scp"
11
+ dynamic_mixing: true
12
+ valid:
13
+ mixture: "cv_mix.scp"
14
+ spk1: "cv_s1.scp"
15
+ spk2: "cv_s2.scp"
16
+ test:
17
+ mixture: "tt_mix.scp"
18
+ spk1: "tt_s1.scp"
19
+ spk2: "tt_s2.scp"
20
+ dataloader:
21
+ batch_size: 2
22
+ pin_memory: false
23
+ num_workers: 12
24
+ drop_last: false
25
+ model:
26
+ num_stages: &var_model_num_stages 4 # R
27
+ num_spks: &var_model_num_spks 2
28
+ module_audio_enc:
29
+ in_channels: 1
30
+ out_channels: &var_model_audio_enc_out_channels 256
31
+ kernel_size: &var_model_audio_enc_kernel_size 16 # L
32
+ stride: &var_model_audio_enc_stride 4 # S
33
+ groups: 1
34
+ bias: false
35
+ module_feature_projector:
36
+ num_channels: *var_model_audio_enc_out_channels
37
+ in_channels: *var_model_audio_enc_out_channels
38
+ out_channels: &feature_projector_out_channels 256 # F
39
+ kernel_size: 1
40
+ bias: false
41
+ module_separator:
42
+ num_stages: *var_model_num_stages
43
+ relative_positional_encoding:
44
+ in_channels: *feature_projector_out_channels
45
+ num_heads: 8
46
+ maxlen: 2000
47
+ embed_v: false
48
+ enc_stage:
49
+ global_blocks:
50
+ in_channels: *feature_projector_out_channels
51
+ num_mha_heads: 8
52
+ dropout_rate: 0.1
53
+ local_blocks:
54
+ in_channels: *feature_projector_out_channels
55
+ kernel_size: 65
56
+ dropout_rate: 0.1
57
+ down_conv_layer:
58
+ in_channels: *feature_projector_out_channels
59
+ samp_kernel_size: &var_model_samp_kernel_size 5
60
+ spk_split_stage:
61
+ in_channels: *feature_projector_out_channels
62
+ num_spks: *var_model_num_spks
63
+ simple_fusion:
64
+ out_channels: *feature_projector_out_channels
65
+ dec_stage:
66
+ num_spks: *var_model_num_spks
67
+ global_blocks:
68
+ in_channels: *feature_projector_out_channels
69
+ num_mha_heads: 8
70
+ dropout_rate: 0.1
71
+ local_blocks:
72
+ in_channels: *feature_projector_out_channels
73
+ kernel_size: 65
74
+ dropout_rate: 0.1
75
+ spk_attention:
76
+ in_channels: *feature_projector_out_channels
77
+ num_mha_heads: 8
78
+ dropout_rate: 0.1
79
+ module_output_layer:
80
+ in_channels: *var_model_audio_enc_out_channels
81
+ out_channels: *feature_projector_out_channels
82
+ num_spks: *var_model_num_spks
83
+ module_audio_dec:
84
+ in_channels: *var_model_audio_enc_out_channels
85
+ out_channels: 1
86
+ kernel_size: *var_model_audio_enc_kernel_size
87
+ stride: *var_model_audio_enc_stride
88
+ bias: false
89
+ criterion: ### Ref: https://pytorch.org/docs/stable/nn.html#loss-functions
90
+ name: ["PIT_SISNR_mag", "PIT_SISNR_time", "PIT_SISNRi", "PIT_SDRi"] ### Choose a torch.nn's loss function class(=attribute) e.g. ["L1Loss", "MSELoss", "CrossEntropyLoss", ...] / You can also build your optimizer :)
91
+ PIT_SISNR_mag:
92
+ frame_length: 512
93
+ frame_shift: 128
94
+ window: 'hann'
95
+ num_stages: *var_model_num_stages
96
+ num_spks: *var_model_num_spks
97
+ scale_inv: true
98
+ mel_opt: false
99
+ PIT_SISNR_time:
100
+ num_spks: *var_model_num_spks
101
+ scale_inv: true
102
+ PIT_SISNRi:
103
+ num_spks: *var_model_num_spks
104
+ scale_inv: true
105
+ PIT_SDRi:
106
+ dump: 0
107
+ optimizer: ### Ref: https://pytorch.org/docs/stable/optim.html#algorithms
108
+ name: ["AdamW"] ### Choose a torch.optim's class(=attribute) e.g. ["Adam", "AdamW", "SGD", ...] / You can also build your optimizer :)
109
+ AdamW:
110
+ lr: 2.0e-4
111
+ weight_decay: 1.0e-2
112
+ scheduler: ### Ref(+ find "How to adjust learning rate"): https://pytorch.org/docs/stable/optim.html#algorithms
113
+ name: ["ReduceLROnPlateau", "WarmupConstantSchedule"] ### Choose a torch.optim.lr_scheduler's class(=attribute) e.g. ["StepLR", "ReduceLROnPlateau", "Custom"] / You can also build your scheduler :)
114
+ ReduceLROnPlateau:
115
+ mode: "min"
116
+ min_lr: 1.0e-10
117
+ factor: 0.8
118
+ patience: 3
119
+ WarmupConstantSchedule:
120
+ warmup_steps: 1000
121
+ check_computations:
122
+ dummy_len: 16000
123
+ engine:
124
+ max_epoch: 200
125
+ gpuid: "1" ### "0"(single-gpu) or "0, 1" (multi-gpu)
126
+ mvn: false
127
+ clip_norm: 5
128
+ start_scheduling: 50
129
+ test_epochs: [50, 80, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 199]
models/SepReformer/SepReformer_Large_DM_WHAM/dataset.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import librosa as audio_lib
5
+ import numpy as np
6
+
7
+ from utils import util_dataset
8
+ from utils.decorators import *
9
+ from loguru import logger
10
+ from torch.utils.data import Dataset, DataLoader
11
+
12
+
13
@logger_wraps()
def get_dataloaders(args, dataset_config, loader_config):
    """Build one DataLoader per partition (test-only when the engine mode is a test mode)."""
    partitions = ["test"] if "test" in args.engine_mode else ["train", "valid", "test"]
    dataloaders = {}
    for partition in partitions:
        part_cfg = dataset_config[partition]
        scp_dir = dataset_config["scp_dir"]
        mix_scp = os.path.join(scp_dir, part_cfg['mixture'])
        # Every key starting with 'spk' contributes one source scp path.
        src_scps = [os.path.join(scp_dir, part_cfg[key]) for key in part_cfg if key.startswith('spk')]
        noise_scp = os.path.join(scp_dir, part_cfg['noise']) if 'noise' in part_cfg else None
        # Dynamic mixing is a training-time augmentation only.
        use_dm = part_cfg["dynamic_mixing"] if partition == 'train' else False
        dataset = MyDataset(
            max_len=dataset_config['max_len'],
            fs=dataset_config['sampling_rate'],
            partition=partition,
            wave_scp_srcs=src_scps,
            wave_scp_mix=mix_scp,
            wave_scp_noise=noise_scp,
            dynamic_mixing=use_dm)
        is_test = partition == 'test'
        # Test runs one utterance at a time, in order; other partitions shuffle.
        dataloaders[partition] = DataLoader(
            dataset=dataset,
            batch_size=1 if is_test else loader_config["batch_size"],
            shuffle=not is_test,
            pin_memory=loader_config["pin_memory"],
            num_workers=loader_config["num_workers"],
            drop_last=loader_config["drop_last"],
            collate_fn=_collate)
    return dataloaders
41
+
42
+
43
+ def _collate(egs):
44
+ """
45
+ Transform utterance index into a minbatch
46
+
47
+ Arguments:
48
+ index: a list type [{},{},{}]
49
+
50
+ Returns:
51
+ input_sizes: a tensor correspond to utterance length
52
+ input_feats: packed sequence to feed networks
53
+ source_attr/target_attr: dictionary contains spectrogram/phase needed in loss computation
54
+ """
55
+ def __prepare_target_rir(dict_lsit, index):
56
+ return torch.nn.utils.rnn.pad_sequence([torch.tensor(d["src"][index], dtype=torch.float32) for d in dict_lsit], batch_first=True)
57
+ if type(egs) is not list: raise ValueError("Unsupported index type({})".format(type(egs)))
58
+ num_spks = 2 # you need to set this paramater by yourself
59
+ dict_list = sorted([eg for eg in egs], key=lambda x: x['num_sample'], reverse=True)
60
+ mixture = torch.nn.utils.rnn.pad_sequence([torch.tensor(d['mix'], dtype=torch.float32) for d in dict_list], batch_first=True)
61
+ src = [__prepare_target_rir(dict_list, index) for index in range(num_spks)]
62
+ input_sizes = torch.tensor([d['num_sample'] for d in dict_list], dtype=torch.float32)
63
+ key = [d['key'] for d in dict_list]
64
+ return input_sizes, mixture, src, key
65
+
66
+
67
@logger_wraps()
class MyDataset(Dataset):
    """Speech-separation dataset driven by scp files (key -> wav path).

    Serves one utterance per index as a dict:
        {"num_sample": length, "mix": mixture waveform,
         "src": list of per-speaker source waveforms, "key": utterance id}

    Two loading modes:
      * direct: load the pre-mixed mixture and its matching sources;
      * dynamic mixing (train only): re-mix one "true" source for the key with
        a randomly chosen source from another utterance plus the keyed noise,
        each with random gains.
    """
    def __init__(self, max_len, fs, partition, wave_scp_srcs, wave_scp_mix, wave_scp_noise=None, dynamic_mixing=False, speed_list=None):
        # max_len: training crop length in samples; fs: target sampling rate.
        # wave_scp_srcs: one scp path per speaker; wave_scp_mix: mixture scp.
        # speed_list is accepted but unused here (presumably for speed
        # perturbation elsewhere — TODO confirm against callers).
        self.partition = partition
        for wave_scp_src in wave_scp_srcs:
            if not os.path.exists(wave_scp_src): raise FileNotFoundError(f"Could not find file {wave_scp_src}")
        self.max_len = max_len
        self.fs = fs
        # Each dict maps utterance key -> wav file path.
        self.wave_dict_srcs = [util_dataset.parse_scps(wave_scp_src) for wave_scp_src in wave_scp_srcs]
        self.wave_dict_mix = util_dataset.parse_scps(wave_scp_mix)
        self.wave_dict_noise = util_dataset.parse_scps(wave_scp_noise) if wave_scp_noise else None
        self.wave_keys = list(self.wave_dict_mix.keys())
        logger.info(f"Create MyDataset for {wave_scp_mix} with {len(self.wave_dict_mix)} utterances")
        self.dynamic_mixing = dynamic_mixing

    def __len__(self):
        return len(self.wave_dict_mix)

    def __contains__(self, key):
        return key in self.wave_dict_mix


    def _dynamic_mixing(self, key):
        """Build a fresh random 2-source + noise mixture for utterance `key`."""
        def __match_length(wav, len_data) :
            # Take a random crop of exactly len_data samples (assumes
            # len(wav) >= len_data, which min(src_len) below guarantees).
            leftover = len(wav) - len_data
            idx = random.randint(0,leftover)
            wav = wav[idx:idx+len_data]
            return wav

        samps_src = []
        src_len = [self.max_len]

        # Dynamic source choice: pair the keyed source with a random partner
        # utterance, with the speaker slots randomly swapped.
        key_random = random.choice(list(self.wave_dict_srcs[0].keys()))

        idx1, idx2 = (0, 1) if random.random() > 0.5 else (1, 0)
        files = [self.wave_dict_srcs[idx1][key], self.wave_dict_srcs[idx2][key_random]]

        # Load both sources, RMS-normalize to the first source's level.
        for idx, file in enumerate(files):
            if not os.path.exists(file): raise FileNotFoundError("Input file {} do not exists!".format(file))
            samps_tmp, _ = audio_lib.load(file, sr=self.fs)

            if idx == 0: ref_rms = np.sqrt(np.mean(np.square(samps_tmp)))
            curr_rms = np.sqrt(np.mean(np.square(samps_tmp)))

            # NOTE(review): divides by curr_rms with no guard — an all-zero
            # source file would produce NaNs here.
            norm_factor = ref_rms / curr_rms
            samps_tmp *= norm_factor

            # Mixing with a random gain in the [-5, 5] dB range.
            gain = pow(10,-random.uniform(-5,5)/20)
            samps_tmp = np.array(torch.tensor(samps_tmp))
            samps_src.append(gain*samps_tmp)
            src_len.append(len(samps_tmp))

        # Matching the audio length (shortest of max_len and the sources).
        min_len = min(src_len)

        # Add the noise source for this key, normalized and gain-randomized
        # the same way as the speech sources.
        file_noise = self.wave_dict_noise[key]
        samps_noise, _ = audio_lib.load(file_noise, sr=self.fs)
        curr_rms = np.sqrt(np.mean(np.square(samps_noise)))
        norm_factor = ref_rms / curr_rms
        samps_noise *= norm_factor
        gain_noise = pow(10,-random.uniform(-5,5)/20)
        samps_noise = samps_noise*gain_noise
        src_len.append(len(samps_noise))

        # Truncate everything to the common length via random crops, then mix.
        min_len = min(src_len)
        samps_src = [__match_length(s, min_len) for s in samps_src]
        samps_noise = __match_length(samps_noise, min_len)
        samps_mix = sum(samps_src) + samps_noise

        # Keep lengths divisible by 4 (presumably to match the model's
        # downsampling factor — TODO confirm against the separator config).
        if len(samps_mix)%4 != 0:
            remains = len(samps_mix)%4
            samps_mix = samps_mix[:-remains]
            samps_src = [s[:-remains] for s in samps_src]

        return samps_mix, samps_src

    def _direct_load(self, key):
        """Load the pre-mixed mixture and its reference sources for `key`."""
        samps_src = []
        files = [wave_dict_src[key] for wave_dict_src in self.wave_dict_srcs]
        for file in files:
            if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
            samps_tmp, _ = audio_lib.load(file, sr=self.fs)
            samps_src.append(samps_tmp)

        file = self.wave_dict_mix[key]
        if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
        samps_mix, _ = audio_lib.load(file, sr=self.fs)
        # Truncate samples so the length is divisible by 4 (see note above).
        if len(samps_mix) % 4 != 0:
            remains = len(samps_mix) % 4
            samps_mix = samps_mix[:-remains]
            samps_src = [s[:-remains] for s in samps_src]

        # Outside of test, crop long utterances to max_len at a random offset.
        if self.partition != "test":
            if len(samps_mix) > self.max_len:
                start = random.randint(0,len(samps_mix)-self.max_len)
                samps_mix = samps_mix[start:start+self.max_len]
                samps_src = [s[start:start+self.max_len] for s in samps_src]

        return samps_mix, samps_src

    def __getitem__(self, index):
        """Return one utterance dict; raises KeyError if any scp lacks the key."""
        key = self.wave_keys[index]
        if any(key not in self.wave_dict_srcs[i] for i in range(len(self.wave_dict_srcs))) or key not in self.wave_dict_mix: raise KeyError(f"Could not find utterance {key}")
        samps_mix, samps_src = self._dynamic_mixing(key) if self.dynamic_mixing else self._direct_load(key)
        return {"num_sample": samps_mix.shape[0], "mix": samps_mix, "src": samps_src, "key": key}
models/SepReformer/SepReformer_Large_DM_WHAM/engine.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import csv
4
+ import time
5
+ import soundfile as sf
6
+
7
+ from loguru import logger
8
+ from tqdm import tqdm
9
+ from utils import util_engine, functions
10
+ from utils.decorators import *
11
+ from torch.utils.tensorboard import SummaryWriter
12
+
13
+
14
@logger_wraps()
class Engine(object):
    """Training / validation / test driver for the SepReformer model.

    Wires together the model, dataloaders, PIT losses, optimizer and
    schedulers; resumes from the latest checkpoint; and runs either the
    train/valid/test loop or a standalone test pass depending on
    ``args.engine_mode``.
    """
    def __init__(self, args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device):

        ''' Default setting '''
        self.engine_mode = args.engine_mode
        self.out_wav_dir = args.out_wav_dir
        self.config = config
        self.gpuid = gpuid
        self.device = device
        self.model = model.to(self.device)
        self.dataloaders = dataloaders # self.dataloaders['train'] or ['valid'] or ['test']
        self.PIT_SISNR_mag_loss, self.PIT_SISNR_time_loss, self.PIT_SISNRi_loss, self.PIT_SDRi_loss = criterions
        self.main_optimizer = optimizers[0]
        self.main_scheduler, self.warmup_scheduler = schedulers

        self.pretrain_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "pretrain_weights")
        os.makedirs(self.pretrain_weights_path, exist_ok=True)
        self.scratch_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "scratch_weights")
        os.makedirs(self.scratch_weights_path, exist_ok=True)

        # Prefer pretrained weights when any checkpoint file is present.
        # BUGFIX: the original tuple was ('.pt', '.pt', '.pkl') — a duplicated
        # '.pt' instead of '.pth', so shipped '*.pth' checkpoints were missed.
        self.checkpoint_path = self.pretrain_weights_path if any(file.endswith(('.pt', '.pth', '.pkl')) for file in os.listdir(self.pretrain_weights_path)) else self.scratch_weights_path
        self.start_epoch = util_engine.load_last_checkpoint_n_get_epoch(self.checkpoint_path, self.model, self.main_optimizer, location=self.device)

        # Logging: parameter count / MAC summary on a dummy input.
        util_engine.model_params_mac_summary(
            model=self.model,
            input=torch.randn(1, self.config['check_computations']['dummy_len']).to(self.device),
            dummy_input=torch.rand(1, self.config['check_computations']['dummy_len']).to(self.device),
            metrics=['ptflops', 'thop', 'torchinfo']
            # metrics=['ptflops']
        )

        logger.info(f"Clip gradient by 2-norm {self.config['engine']['clip_norm']}")

    @logger_wraps()
    def _train(self, dataloader, epoch):
        """One training epoch; returns (avg time loss, avg freq loss, #batches)."""
        self.model.train()
        tot_loss_freq = [0 for _ in range(self.model.num_stages)]
        tot_loss_time, num_batch = 0, 0
        pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:25}{r_bar}{bar:-10b}', colour="YELLOW", dynamic_ncols=True)
        for input_sizes, mixture, src, _ in dataloader:
            nnet_input = mixture
            nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
            num_batch += 1
            pbar.update(1)
            # Scheduler learning rate for warm-up (Iteration-based update for transformers)
            if epoch == 1: self.warmup_scheduler.step()
            nnet_input = nnet_input.to(self.device)
            self.main_optimizer.zero_grad()
            estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
            # Per-stage auxiliary (frequency-domain) losses.
            cur_loss_s_bn = []
            for idx, estim_src_value in enumerate(estim_src_bn):
                cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
                tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
            cur_loss_s = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
            tot_loss_time += cur_loss_s.item() / self.config['model']['num_spks']
            # Auxiliary-loss weight: constant 0.4 for the first 100 epochs,
            # then decayed by 0.8 every 5 epochs.
            alpha = 0.4 * 0.8**(1+(epoch-101)//5) if epoch > 100 else 0.4
            cur_loss = (1-alpha) * cur_loss_s + alpha * sum(cur_loss_s_bn) / len(cur_loss_s_bn)
            cur_loss = cur_loss / self.config['model']['num_spks']
            cur_loss.backward()
            if self.config['engine']['clip_norm']: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['engine']['clip_norm'])
            self.main_optimizer.step()
            dict_loss = {"T_Loss": tot_loss_time / num_batch}
            dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
            pbar.set_postfix(dict_loss)
        pbar.close()
        tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
        return tot_loss_time / num_batch, tot_loss_freq / num_batch, num_batch

    @logger_wraps()
    def _validate(self, dataloader):
        """One validation pass; returns (avg time loss, avg freq loss, #batches)."""
        self.model.eval()
        tot_loss_freq = [0 for _ in range(self.model.num_stages)]
        tot_loss_time, num_batch = 0, 0
        pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="RED", dynamic_ncols=True)
        with torch.inference_mode():
            for input_sizes, mixture, src, _ in dataloader:
                nnet_input = mixture
                nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
                nnet_input = nnet_input.to(self.device)
                num_batch += 1
                pbar.update(1)
                estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
                cur_loss_s_bn = []
                for idx, estim_src_value in enumerate(estim_src_bn):
                    cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
                    tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
                cur_loss_s_SDR = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
                tot_loss_time += cur_loss_s_SDR.item() / self.config['model']['num_spks']
                dict_loss = {"T_Loss":tot_loss_time / num_batch}
                dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
                pbar.set_postfix(dict_loss)
        pbar.close()
        tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
        return tot_loss_time / num_batch, tot_loss_freq / num_batch, num_batch

    @logger_wraps()
    def _test(self, dataloader, wav_dir=None):
        """Evaluate SI-SNRi / SDRi per utterance (batch size must be 1), log to
        CSV, and optionally write separated wavs when engine_mode == 'test_save'.
        Returns (avg SI-SNRi, avg SDRi, #batches)."""
        self.model.eval()
        total_loss_SISNRi, total_loss_SDRi, num_batch = 0, 0, 0
        pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="grey", dynamic_ncols=True)
        with torch.inference_mode():
            csv_file_name_sisnr = os.path.join(os.path.dirname(__file__),'test_SISNRi_value.csv')
            csv_file_name_sdr = os.path.join(os.path.dirname(__file__),'test_SDRi_value.csv')
            with open(csv_file_name_sisnr, 'w', newline='') as csvfile_sisnr, open(csv_file_name_sdr, 'w', newline='') as csvfile_sdr:
                idx = 0
                writer_sisnr = csv.writer(csvfile_sisnr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
                writer_sdr = csv.writer(csvfile_sdr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
                for input_sizes, mixture, src, key in dataloader:
                    if len(key) > 1:
                        # BUGFIX: the original `raise("...")` raised a plain
                        # string, which is a TypeError in Python 3.
                        raise ValueError("batch size is not one!!")
                    nnet_input = mixture.to(self.device)
                    num_batch += 1
                    pbar.update(1)
                    estim_src, _ = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
                    cur_loss_SISNRi, cur_loss_SISNRi_src = self.PIT_SISNRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src, eps=1.0e-15)
                    total_loss_SISNRi += cur_loss_SISNRi.item() / self.config['model']['num_spks']
                    cur_loss_SDRi, cur_loss_SDRi_src = self.PIT_SDRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src)
                    total_loss_SDRi += cur_loss_SDRi.item() / self.config['model']['num_spks']
                    writer_sisnr.writerow([key[0][:-4]] + [cur_loss_SISNRi_src[i].item() for i in range(self.config['model']['num_spks'])])
                    writer_sdr.writerow([key[0][:-4]] + [cur_loss_SDRi_src[i].item() for i in range(self.config['model']['num_spks'])])
                    if self.engine_mode == "test_save":
                        if wav_dir is None: wav_dir = os.path.join(os.path.dirname(__file__),"wav_out")
                        if wav_dir and not os.path.exists(wav_dir): os.makedirs(wav_dir)
                        mixture = torch.squeeze(mixture).cpu().data.numpy()
                        # Peak-normalize to 0.5 before writing 8 kHz wavs.
                        sf.write(os.path.join(wav_dir,key[0][:-4]+str(idx)+'_mixture.wav'), 0.5*mixture/max(abs(mixture)), 8000)
                        for i in range(self.config['model']['num_spks']):
                            src = torch.squeeze(estim_src[i]).cpu().data.numpy()
                            sf.write(os.path.join(wav_dir,key[0][:-4]+str(idx)+'_out_'+str(i)+'.wav'), 0.5*src/max(abs(src)), 8000)
                        idx += 1
                    dict_loss = {"SiSNRi": total_loss_SISNRi/num_batch, "SDRi": total_loss_SDRi/num_batch}
                    pbar.set_postfix(dict_loss)
        pbar.close()
        return total_loss_SISNRi/num_batch, total_loss_SDRi/num_batch, num_batch

    @logger_wraps()
    def run(self):
        """Dispatch: run a test pass (engine_mode contains 'test') or the full
        train/valid loop with per-epoch logging and best-checkpoint saving."""
        with torch.cuda.device(self.device):
            writer_src = SummaryWriter(os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/tensorboard"))
            if "test" in self.engine_mode:
                on_test_start = time.time()
                test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'], self.out_wav_dir)
                on_test_end = time.time()
                logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
                logger.info(f"Testing done!")
            else:
                start_time = time.time()
                if self.start_epoch > 1:
                    init_loss_time, init_loss_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
                else:
                    init_loss_time, init_loss_freq = 0, 0
                end_time = time.time()
                logger.info(f"[INIT] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: Loss_t = {init_loss_time:.4f} dB | Loss_f = {init_loss_freq:.4f} dB | Speed = ({end_time-start_time:.2f}s)")
                # BUGFIX: initialize the running best *before* the epoch loop.
                # The original reset it every epoch, so checkpoints were always
                # compared against the initial loss, not the running best.
                valid_loss_best = init_loss_time
                for epoch in range(self.start_epoch, self.config['engine']['max_epoch']):
                    train_start_time = time.time()
                    train_loss_src_time, train_loss_src_freq, train_num_batch = self._train(self.dataloaders['train'], epoch)
                    train_end_time = time.time()
                    valid_start_time = time.time()
                    valid_loss_src_time, valid_loss_src_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
                    valid_end_time = time.time()
                    if epoch > self.config['engine']['start_scheduling']: self.main_scheduler.step(valid_loss_src_time)
                    logger.info(f"[TRAIN] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {train_loss_src_time:.4f} dB | Loss_f = {train_loss_src_freq:.4f} dB | Speed = ({train_end_time - train_start_time:.2f}s/{train_num_batch:d})")
                    logger.info(f"[VALID] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {valid_loss_src_time:.4f} dB | Loss_f = {valid_loss_src_freq:.4f} dB | Speed = ({valid_end_time - valid_start_time:.2f}s/{valid_num_batch:d})")
                    if epoch in self.config['engine']['test_epochs']:
                        on_test_start = time.time()
                        test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'])
                        on_test_end = time.time()
                        logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
                    valid_loss_best = util_engine.save_checkpoint_per_best(valid_loss_best, valid_loss_src_time, train_loss_src_time, epoch, self.model, self.main_optimizer, self.checkpoint_path)
                    # Logging to monitoring tools (Tensorboard && Wandb)
                    writer_src.add_scalars("Metrics", {
                        'Loss_train_time': train_loss_src_time,
                        'Loss_valid_time': valid_loss_src_time}, epoch)
                    writer_src.add_scalar("Learning Rate", self.main_optimizer.param_groups[0]['lr'], epoch)
                    writer_src.flush()
                logger.info(f"Training for {self.config['engine']['max_epoch']} epoches done!")
+ logger.info(f"Training for {self.config['engine']['max_epoch']} epoches done!")
models/SepReformer/SepReformer_Large_DM_WHAM/main.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from loguru import logger
4
+ from .dataset import get_dataloaders
5
+ from .model import Model
6
+ from .engine import Engine
7
+ from utils import util_system, util_implement
8
+ from utils.decorators import *
9
+
10
# Setup logger: mirror console logging into a per-model debug log file,
# overwriting it ("w") on every run.
log_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/system_log.log")
logger.add(log_file_path, level="DEBUG", mode="w")
13
+
14
@logger_wraps()
def main(args):
    """Entry point: assemble dataloaders, model, losses, optimizers and
    schedulers from configs.yaml, then hand everything to the Engine."""
    # --- Configuration -----------------------------------------------------
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs.yaml")
    config = util_system.parse_yaml(config_path)["config"]

    # --- Data: one DataLoader per partition (train / valid / test) ---------
    dataloaders = get_dataloaders(args, config["dataset"], config["dataloader"])

    # --- Model -------------------------------------------------------------
    model = Model(**config["model"])

    # --- Devices: first listed GPU hosts the model -------------------------
    gpuid = tuple(int(g) for g in config["engine"]["gpuid"].split(','))
    device = torch.device(f'cuda:{gpuid[0]}')

    # --- Training components (criterion / optimizer / scheduler) ----------
    criterions = util_implement.CriterionFactory(config["criterion"], device).get_criterions()
    optimizers = util_implement.OptimizerFactory(config["optimizer"], model.parameters()).get_optimizers()
    schedulers = util_implement.SchedulerFactory(config["scheduler"], optimizers).get_schedulers()

    # --- Run ---------------------------------------------------------------
    engine = Engine(args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device)
    engine.run()
models/SepReformer/SepReformer_Large_DM_WHAM/model.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('../')
3
+
4
+ import torch
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ from utils.decorators import *
9
+ from .modules.module import *
10
+
11
+
12
@logger_wraps()
class Model(torch.nn.Module):
    """SepReformer separation model.

    Pipeline: waveform -> AudioEncoder -> FeatureProjector -> Separator
    (multi-stage) -> OutputLayer -> per-speaker AudioDecoder waveforms.
    Each intermediate separator stage also feeds its own output layer and
    decoder, producing auxiliary waveforms for multi-stage supervision.
    """
    def __init__(self,
                 num_stages: int,
                 num_spks: int,
                 module_audio_enc: dict,
                 module_feature_projector: dict,
                 module_separator: dict,
                 module_output_layer: dict,
                 module_audio_dec: dict):
        super().__init__()
        self.num_stages = num_stages
        self.num_spks = num_spks
        self.audio_encoder = AudioEncoder(**module_audio_enc)
        self.feature_projector = FeatureProjector(**module_feature_projector)
        self.separator = Separator(**module_separator)
        self.out_layer = OutputLayer(**module_output_layer)
        self.audio_decoder = AudioDecoder(**module_audio_dec)

        # Aux_loss: one masking output layer + decoder per separator stage.
        self.out_layer_bn = torch.nn.ModuleList([])
        self.decoder_bn = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            self.out_layer_bn.append(OutputLayer(**module_output_layer, masking=True))
            self.decoder_bn.append(AudioDecoder(**module_audio_dec))

    def forward(self, x):
        """Separate mixture ``x`` -> (per-speaker waveforms, per-stage aux waveforms)."""
        encoder_output = self.audio_encoder(x)
        projected_feature = self.feature_projector(encoder_output)
        last_stage_output, each_stage_outputs = self.separator(projected_feature)
        out_layer_output = self.out_layer(last_stage_output, encoder_output)
        each_spk_output = [out_layer_output[idx] for idx in range(self.num_spks)]
        audio = [self.audio_decoder(each_spk_output[idx]) for idx in range(self.num_spks)]

        # Aux_loss: bring each intermediate stage back up to the encoder's
        # temporal resolution. FIX: torch.nn.functional.upsample is deprecated;
        # interpolate is the supported equivalent (same default 'nearest' mode).
        audio_aux = []
        for idx, each_stage_output in enumerate(each_stage_outputs):
            each_stage_output = self.out_layer_bn[idx](torch.nn.functional.interpolate(each_stage_output, size=encoder_output.shape[-1]), encoder_output)
            out_aux = [each_stage_output[jdx] for jdx in range(self.num_spks)]
            # Trim decoded aux audio back to the input length.
            audio_aux.append([self.decoder_bn[idx](out_aux[jdx])[...,:x.shape[-1]] for jdx in range(self.num_spks)])

        return audio, audio_aux
models/SepReformer/SepReformer_Large_DM_WHAM/modules/module.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('../')
3
+
4
+ import torch
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ from utils.decorators import *
9
+ from .network import *
10
+
11
+
12
class AudioEncoder(torch.nn.Module):
    """Convolutional waveform front-end: 1-D conv followed by GELU.

    Maps [T] or [B, T] waveforms to a [B, out_channels, T'] feature map.
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, groups: int, bias: bool):
        super().__init__()
        self.conv1d = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, stride=stride, groups=groups, bias=bias)
        self.gelu = torch.nn.GELU()

    def forward(self, x: torch.Tensor):
        # Insert the channel axis: [T] -> [1, T], or [B, T] -> [B, 1, T].
        if len(x.shape) == 1:
            x = x.unsqueeze(dim=0)
        else:
            x = x.unsqueeze(dim=1)
        return self.gelu(self.conv1d(x))
24
+
25
class FeatureProjector(torch.nn.Module):
    """Global layer norm followed by a pointwise projection of encoder features."""
    def __init__(self, num_channels: int, in_channels: int, out_channels: int, kernel_size: int, bias: bool):
        super().__init__()
        # num_groups=1 => normalization over all channels and time (gLN).
        self.norm = torch.nn.GroupNorm(num_groups=1, num_channels=num_channels, eps=1e-8)
        self.conv1d = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, bias=bias)

    def forward(self, x: torch.Tensor):
        # Normalize, then project channels: [B, in_channels, T] -> [B, out_channels, T].
        return self.conv1d(self.norm(x))
36
+
37
+
38
class Separator(torch.nn.Module):
    """U-Net-style separator: a temporal contracting path (encoder stages with
    downsampling), a bottleneck, speaker splitting, and a temporal expanding
    path (decoder stages with upsampling + skip fusion).

    Returns the last decoder-stage features plus each stage's intermediate
    output (used for auxiliary losses upstream).
    """
    def __init__(self, num_stages: int, relative_positional_encoding: dict, enc_stage: dict, spk_split_stage: dict, simple_fusion:dict, dec_stage: dict):
        super().__init__()

        # NOTE: the building-block classes are defined locally inside __init__
        # (mirroring the original layout) and instantiated right below.

        class RelativePositionalEncoding(torch.nn.Module):
            """Learned relative position embeddings for the attention blocks."""
            def __init__(self, in_channels: int, num_heads: int, maxlen: int, embed_v=False):
                super().__init__()
                self.in_channels = in_channels
                self.num_heads = num_heads
                # One embedding per head dimension.
                self.embedding_dim = self.in_channels // self.num_heads
                self.maxlen = maxlen
                # 2*maxlen entries: offsets in [-maxlen, maxlen) after shifting.
                self.pe_k = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim)
                self.pe_v = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim) if embed_v else None

            def forward(self, pos_seq: torch.Tensor):
                # NOTE(review): clamp_/+= mutate the caller's pos_seq in place.
                pos_seq.clamp_(-self.maxlen, self.maxlen - 1)
                pos_seq += self.maxlen
                pe_k_output = self.pe_k(pos_seq)
                pe_v_output = self.pe_v(pos_seq) if self.pe_v is not None else None
                return pe_k_output, pe_v_output

        class SepEncStage(torch.nn.Module):
            """One contracting stage: 2x (global block + local block), then an
            optional stride-2 depthwise down-conv. Also returns the pre-downsample
            features as the skip connection."""
            def __init__(self, global_blocks: dict, local_blocks: dict, down_conv_layer: dict, down_conv=True):
                super().__init__()

                class DownConvLayer(torch.nn.Module):
                    def __init__(self, in_channels: int, samp_kernel_size: int):
                        """Construct an EncoderLayer object."""
                        super().__init__()
                        # Depthwise (groups=in_channels) stride-2 conv halves time.
                        self.down_conv = torch.nn.Conv1d(
                            in_channels=in_channels, out_channels=in_channels, kernel_size=samp_kernel_size, stride=2, padding=(samp_kernel_size-1)//2, groups=in_channels)
                        self.BN = torch.nn.BatchNorm1d(num_features=in_channels)
                        self.gelu = torch.nn.GELU()

                    def forward(self, x: torch.Tensor):
                        # Input arrives channel-last here; convert for Conv1d/BN.
                        x = x.permute([0, 2, 1])
                        x = self.down_conv(x)
                        x = self.BN(x)
                        x = self.gelu(x)
                        x = x.permute([0, 2, 1])
                        return x

                self.g_block_1 = GlobalBlock(**global_blocks)
                self.l_block_1 = LocalBlock(**local_blocks)

                self.g_block_2 = GlobalBlock(**global_blocks)
                self.l_block_2 = LocalBlock(**local_blocks)

                self.downconv = DownConvLayer(**down_conv_layer) if down_conv == True else None

            def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
                '''
                x: [B, N, T]
                '''
                # Local blocks operate channel-last, hence the permutes.
                x = self.g_block_1(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_1(x)
                x = x.permute(0, 2, 1).contiguous()

                x = self.g_block_2(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_2(x)
                x = x.permute(0, 2, 1).contiguous()

                # Skip connection is taken before downsampling.
                skip = x
                if self.downconv:
                    x = x.permute(0, 2, 1).contiguous()
                    x = self.downconv(x)
                    x = x.permute(0, 2, 1).contiguous()
                # [BK, S, N]
                return x, skip

        class SpkSplitStage(torch.nn.Module):
            """Expand features into num_spks per-speaker copies: [B, N, T] ->
            [B*num_spks, N, T] via a GLU-gated pointwise projection."""
            def __init__(self, in_channels: int, num_spks: int):
                super().__init__()
                self.linear = torch.nn.Sequential(
                    torch.nn.Conv1d(in_channels, 4*in_channels*num_spks, kernel_size=1),
                    torch.nn.GLU(dim=-2),
                    torch.nn.Conv1d(2*in_channels*num_spks, in_channels*num_spks, kernel_size=1))
                self.norm = torch.nn.GroupNorm(1, in_channels, eps=1e-8)
                self.num_spks = num_spks

            def forward(self, x: torch.Tensor):
                x = self.linear(x)
                B, _, T = x.shape
                # Fold the speaker dimension into the batch dimension.
                x = x.view(B*self.num_spks,-1, T).contiguous()
                x = self.norm(x)
                return x

        class SepDecStage(torch.nn.Module):
            """One expanding stage: 3x (global block + local block + cross-speaker
            attention)."""
            def __init__(self, num_spks: int, global_blocks: dict, local_blocks: dict, spk_attention: dict):
                super().__init__()

                self.g_block_1 = GlobalBlock(**global_blocks)
                self.l_block_1 = LocalBlock(**local_blocks)
                self.spk_attn_1 = SpkAttention(**spk_attention)

                self.g_block_2 = GlobalBlock(**global_blocks)
                self.l_block_2 = LocalBlock(**local_blocks)
                self.spk_attn_2 = SpkAttention(**spk_attention)

                self.g_block_3 = GlobalBlock(**global_blocks)
                self.l_block_3 = LocalBlock(**local_blocks)
                self.spk_attn_3 = SpkAttention(**spk_attention)

                self.num_spk = num_spks

            def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
                '''
                x: [B, N, T]
                '''
                # [BS, K, H]
                x = self.g_block_1(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_1(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_1(x, self.num_spk)

                x = self.g_block_2(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_2(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_2(x, self.num_spk)

                x = self.g_block_3(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_3(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_3(x, self.num_spk)

                skip = x

                return x, skip

        self.num_stages = num_stages
        self.pos_emb = RelativePositionalEncoding(**relative_positional_encoding)

        # Temporal Contracting Part
        self.enc_stages = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            self.enc_stages.append(SepEncStage(**enc_stage, down_conv=True))

        self.bottleneck_G = SepEncStage(**enc_stage, down_conv=False)

        # One speaker-split block per skip connection plus one for the bottleneck.
        self.spk_split_blocks = torch.nn.ModuleList([])
        for _ in range(self.num_stages+1):
            self.spk_split_blocks.append(SpkSplitStage(**spk_split_stage))

        # Temporal Expanding Part
        self.simple_fusion = torch.nn.ModuleList([])
        self.dec_stages = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            # Fusion: concat(upsampled, skip) -> pointwise conv back to one width.
            self.simple_fusion.append(torch.nn.Conv1d(in_channels=simple_fusion['out_channels']*2,out_channels=simple_fusion['out_channels'], kernel_size=1))
            self.dec_stages.append(SepDecStage(**dec_stage))

    def forward(self, input: torch.Tensor):
        '''input: [B, N, L]'''
        # Pad so the length divides 2**num_stages (one halving per enc stage).
        x, _ = self.pad_signal(input)
        len_x = x.shape[-1]
        # Temporal Contracting Part
        # Relative position indices for the shortest (bottleneck) resolution.
        pos_seq = torch.arange(0, len_x//2**self.num_stages).long().to(x.device)
        pos_seq = pos_seq[:, None] - pos_seq[None, :]
        pos_k, _ = self.pos_emb(pos_seq)
        skip = []
        for idx in range(self.num_stages):
            x, skip_ = self.enc_stages[idx](x, pos_k)
            skip_ = self.spk_split_blocks[idx](skip_)
            skip.append(skip_)
        x, _ = self.bottleneck_G(x, pos_k)
        x = self.spk_split_blocks[-1](x) # B, 2F, T

        each_stage_outputs = []
        # Temporal Expanding Part
        for idx in range(self.num_stages):
            # Record the stage input (pre-upsampling) for auxiliary losses.
            each_stage_outputs.append(x)
            idx_en = self.num_stages - (idx + 1)
            x = torch.nn.functional.upsample(x, skip[idx_en].shape[-1])
            x = torch.cat([x,skip[idx_en]],dim=1)
            x = self.simple_fusion[idx](x)
            x, _ = self.dec_stages[idx](x, pos_k)

        last_stage_output = x
        return last_stage_output, each_stage_outputs

    def pad_signal(self, input: torch.Tensor):
        """Zero-pad the time axis up to a multiple of 2**num_stages.

        Accepts (B, T) or (B, 1, T); returns (padded_input, pad_amount).
        """
        # (B, T) or (B, 1, T)
        if input.dim() == 1: input = input.unsqueeze(0)
        elif input.dim() not in [2, 3]: raise RuntimeError("Input can only be 2 or 3 dimensional.")
        elif input.dim() == 2: input = input.unsqueeze(1)
        L = 2**self.num_stages
        batch_size = input.size(0)
        ndim = input.size(1)
        nframe = input.size(2)
        padded_len = (nframe//L + 1)*L
        rest = 0 if nframe%L == 0 else padded_len - nframe
        if rest > 0:
            pad = torch.autograd.Variable(torch.zeros(batch_size, ndim, rest)).type(input.type()).to(input.device)
            input = torch.cat([input, pad], dim=-1)
        return input, rest
+ return input, rest
238
+
239
+
240
class OutputLayer(torch.nn.Module):
    """Map separator features back to encoder width and regroup per speaker.

    Input ``x`` has the speaker dimension folded into the batch
    ([B*num_spks, C, T]); the output is reshaped to [num_spks, B, N, L].
    When ``masking`` is True, the projected features gate the (replicated)
    encoder output via the Masking block instead of being used directly.
    """
    def __init__(self, in_channels: int, out_channels: int, num_spks: int, masking: bool = False):
        super().__init__()
        # feature expansion back
        self.masking = masking
        self.spe_block = Masking(in_channels, Activation_mask="ReLU", concat_opt=None)
        self.num_spks = num_spks
        # GLU-gated pointwise projection: out_channels -> in_channels.
        self.end_conv1x1 = torch.nn.Sequential(
            torch.nn.Linear(out_channels, 4*out_channels),
            torch.nn.GLU(),
            torch.nn.Linear(2*out_channels, in_channels))

    def forward(self, x: torch.Tensor, input: torch.Tensor):
        # Trim x to the encoder output length before projecting.
        x = x[...,:input.shape[-1]]
        # Linear layers act on the channel axis, so go channel-last and back.
        x = x.permute([0, 2, 1])
        x = self.end_conv1x1(x)
        x = x.permute([0, 2, 1])
        B, N, L = x.shape
        # Recover the true batch size (speakers were folded into the batch).
        B = B // self.num_spks

        if self.masking:
            # Replicate the encoder output once per speaker, aligned with x's
            # [B*num_spks, N, L] layout, then apply the ReLU gate.
            input = input.expand(self.num_spks, B, N, L).transpose(0,1).contiguous()
            input = input.view(B*self.num_spks, N, L)
            x = self.spe_block(x, input)

        x = x.view(B, self.num_spks, N, L)
        # [spks, B, N, L]
        x = x.transpose(0, 1)
        return x
269
+
270
+
271
class AudioDecoder(torch.nn.ConvTranspose1d):
    '''
    Decoder of the TasNet.
    This module can be seen as the gradient of Conv1d with respect to its input.
    It is also known as a fractionally-strided convolution
    or a deconvolution (although it is not an actual deconvolution operation).

    Accepts [B, N, L] (or [N, L], promoted with a singleton channel axis) and
    squeezes singleton dimensions out of the transposed-conv output.
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        # x: [B, N, L]
        if x.dim() not in [2, 3]:
            # BUGFIX: the original used self.__name__, which does not exist on
            # instances and raised AttributeError instead of this RuntimeError.
            raise RuntimeError("{} accept 3/4D tensor as input".format(type(self).__name__))
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
        # Drop singleton dims; keep at least 1-D for single-sample outputs.
        x = torch.squeeze(x, dim=1) if torch.squeeze(x).dim() == 1 else torch.squeeze(x)
        return x
+ return x
models/SepReformer/SepReformer_Large_DM_WHAM/modules/network.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import numpy
4
+ from utils.decorators import *
5
+
6
+
7
class LayerScale(torch.nn.Module):
    """Per-channel learnable residual scaling (LayerScale).

    :param dims: rank of the broadcast shape: 1 -> (C,), 2 -> (1, C), 3 -> (1, 1, C)
    :param input_size: number of channels C being scaled
    :param Layer_scale_init: initial scale value for every channel
    :raises ValueError: if ``dims`` is not 1, 2 or 3
    """
    def __init__(self, dims, input_size, Layer_scale_init=1.0e-5):
        super().__init__()
        # Fixed: an unsupported ``dims`` used to leave ``layer_scale`` undefined,
        # deferring the failure to the first forward() call (AttributeError).
        if dims not in (1, 2, 3):
            raise ValueError(f"dims must be 1, 2 or 3, got {dims}")
        # leading singleton axes broadcast over batch/time
        shape = (1,) * (dims - 1) + (input_size,)
        self.layer_scale = torch.nn.Parameter(torch.ones(*shape) * Layer_scale_init, requires_grad=True)

    def forward(self, x):
        """Scale ``x`` channel-wise by the learned factors (broadcast multiply)."""
        return x * self.layer_scale
19
+
20
class Masking(torch.nn.Module):
    """Gated masking block: activates ``x`` and multiplies it onto ``skip``.

    :param input_dim: channel dimension of the inputs
    :param Activation_mask: gate nonlinearity, 'Sigmoid' (default) or 'ReLU'
    :param options: optional ``concat_opt``; when truthy, ``x`` and ``skip`` are
        concatenated on the channel axis and fused by a 1x1 conv before gating.
    :raises ValueError: on an unknown ``Activation_mask``
    """
    def __init__(self, input_dim, Activation_mask='Sigmoid', **options):
        super(Masking, self).__init__()

        self.options = options
        # Fixed: use .get() so construction no longer raises KeyError when the
        # caller omits the ``concat_opt`` keyword entirely.
        if self.options.get('concat_opt'):
            self.pw_conv = torch.nn.Conv1d(input_dim*2, input_dim, 1, stride=1, padding=0)

        if Activation_mask == 'Sigmoid':
            self.gate_act = torch.nn.Sigmoid()
        elif Activation_mask == 'ReLU':
            self.gate_act = torch.nn.ReLU()
        else:
            # Fixed: an unknown activation used to leave ``gate_act`` unset and
            # fail only at forward() time with AttributeError.
            raise ValueError(f"Unsupported Activation_mask: {Activation_mask!r}")

    def forward(self, x, skip):
        """Return ``gate(x) * skip``, optionally fusing x and skip first."""
        if self.options.get('concat_opt'):
            y = torch.cat([x, skip], dim=-2)
            y = self.pw_conv(y)
        else:
            y = x
        y = self.gate_act(y) * skip

        return y
44
+
45
+
46
class GCFN(torch.nn.Module):
    """Gated convolutional feed-forward network with a LayerScale residual.

    Expands channels 6x, runs a depthwise conv over time, gates with GLU
    (halving to 3x), then projects back to ``in_channels``.
    """
    def __init__(self, in_channels, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        hidden = in_channels * 6
        self.net1 = torch.nn.Sequential(
            torch.nn.LayerNorm(in_channels),
            torch.nn.Linear(in_channels, hidden))
        self.depthwise = torch.nn.Conv1d(hidden, hidden, 3, padding=1, groups=hidden)
        self.net2 = torch.nn.Sequential(
            torch.nn.GLU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(in_channels * 3, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        # (B, T, C) -> expand -> depthwise over time -> gate/project -> residual
        h = self.net1(x)
        h = self.depthwise(h.permute(0, 2, 1).contiguous())
        h = self.net2(h.permute(0, 2, 1).contiguous())
        return x + self.Layer_scale(h)
67
+
68
+
69
class MultiHeadAttention(torch.nn.Module):
    """
    Multi-Head Attention layer with optional relative positional bias and a
    LayerScale-scaled output (used here as a pre-norm residual branch).
    :param int n_head: the number of heads
    :param int in_channels: the number of features (model dimension)
    :param float dropout_rate: dropout rate
    :param Layer_scale_init: initial value for the LayerScale output scaling
    """
    def __init__(self, n_head: int, in_channels: int, dropout_rate: float, Layer_scale_init=1.0e-5):
        super().__init__()
        assert in_channels % n_head == 0
        self.d_k = in_channels // n_head # We assume d_v always equals d_k
        self.h = n_head
        self.layer_norm = torch.nn.LayerNorm(in_channels)
        self.linear_q = torch.nn.Linear(in_channels, in_channels)
        self.linear_k = torch.nn.Linear(in_channels, in_channels)
        self.linear_v = torch.nn.Linear(in_channels, in_channels)
        self.linear_out = torch.nn.Linear(in_channels, in_channels)
        self.attn = None  # last attention map, kept for inspection/debugging
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x, pos_k, mask):
        """
        Compute 'Scaled Dot Product Attention' (self-attention: q, k, v from x).
        :param torch.Tensor x: input features (batch, time, d_model); layer norm
            is applied here (pre-norm), so callers pass the raw features
        :param torch.Tensor pos_k: relative positional key table, or None to
            skip the content-position bias
        :param torch.Tensor mask: (batch, time1, time2) mask, or None
        :return torch.Tensor: attentined and transformed `value` (batch, time1, d_model)
            weighted by the query dot key attention (batch, head, time1, time2)
        """
        n_batch = x.size(0)
        x = self.layer_norm(x)
        q = self.linear_q(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
        k = self.linear_k(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
        v = self.linear_v(x).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2) # (batch, head, time2, d_k)
        v = v.transpose(1, 2) # (batch, head, time2, d_k)
        # content-content scores
        A = torch.matmul(q, k.transpose(-2, -1))
        # fold batch and head together so the positional matmul broadcasts per head
        reshape_q = q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0,1)
        if pos_k is not None:
            # content-position bias term, reshaped back to (batch, head, t1, t2)
            B = torch.matmul(reshape_q, pos_k.transpose(-2, -1))
            B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1))
            scores = (A + B) / math.sqrt(self.d_k)
        else:
            scores = A / math.sqrt(self.d_k)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
            # smallest representable value for the score dtype so softmax -> ~0
            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            # zero masked positions again after softmax to kill any residue
            self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
        return self.Layer_scale(self.dropout(self.linear_out(x))) # (batch, time1, d_model)
125
+
126
class EGA(torch.nn.Module):
    """Efficient Global Attention: attends over a pooled (downsampled) sequence
    and gates the upsampled result back onto the full-resolution features."""
    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'self_attn': MultiHeadAttention(
                n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            'linear': torch.nn.Sequential(
                torch.nn.LayerNorm(normalized_shape=in_channels),
                torch.nn.Linear(in_features=in_channels, out_features=in_channels),
                torch.nn.Sigmoid())
        })

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """
        Compute encoded features.
        :param torch.Tensor x: encoded source features (batch, channels, time)
        :param torch.Tensor pos_k: relative positional keys; its first axis also
            sets the pooled sequence length used for attention
        :rtype: torch.Tensor (batch, time, channels)
        """
        down_len = pos_k.shape[0]
        # pool time down to the positional-encoding length before attention
        x_down = torch.nn.functional.adaptive_avg_pool1d(input=x, output_size=down_len)
        x = x.permute([0, 2, 1])            # (B, T, C)
        x_down = x_down.permute([0, 2, 1])  # (B, down_len, C)
        x_down = self.block['self_attn'](x_down, pos_k, None)
        x_down = x_down.permute([0, 2, 1])
        # Fixed: F.upsample is deprecated; F.interpolate is its drop-in
        # replacement (same default 'nearest' mode).
        x_downup = torch.nn.functional.interpolate(input=x_down, size=x.shape[1])
        x_downup = x_downup.permute([0, 2, 1])
        # sigmoid gate controls how much global context is injected per position
        x = x + self.block['linear'](x) * x_downup

        return x
156
+
157
+
158
+
159
class CLA(torch.nn.Module):
    """Convolutional Local Attention block: GLU-gated depthwise convolution
    over time, batch-normalized, with a LayerScale residual connection."""
    def __init__(self, in_channels, kernel_size, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(in_channels)
        self.linear1 = torch.nn.Linear(in_channels, in_channels*2)
        self.GLU = torch.nn.GLU()
        self.dw_conv_1d = torch.nn.Conv1d(in_channels, in_channels, kernel_size, padding='same', groups=in_channels)
        self.linear2 = torch.nn.Linear(in_channels, 2*in_channels)
        self.BN = torch.nn.BatchNorm1d(2*in_channels)
        self.linear3 = torch.nn.Sequential(
            torch.nn.GELU(),
            torch.nn.Linear(2*in_channels, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        # x: (B, T, C)
        h = self.GLU(self.linear1(self.layer_norm(x)))   # (B, T, C) after gating
        h = self.dw_conv_1d(h.transpose(1, 2))           # (B, C, T): depthwise over time
        h = self.linear2(h.transpose(1, 2))              # (B, T, 2C)
        h = self.BN(h.transpose(1, 2))                   # (B, 2C, T): BN over channels
        h = self.linear3(h.transpose(1, 2))              # (B, T, C)
        return x + self.Layer_scale(h)
188
+
189
class GlobalBlock(torch.nn.Module):
    """EGA global attention followed by a GCFN feed-forward stage."""
    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'ega': EGA(
                num_mha_heads=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            'gcfn': GCFN(in_channels=in_channels, dropout_rate=dropout_rate)
        })

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """
        Compute encoded features.
        :param torch.Tensor x: encoded source features (batch, channels, time)
        :param torch.Tensor pos_k: relative positional keys for the EGA stage
        :rtype: torch.Tensor (batch, channels, time)
        """
        y = self.block['ega'](x, pos_k)   # EGA returns (B, T, C)
        y = self.block['gcfn'](y)         # (B, T, C)
        return y.permute([0, 2, 1])       # back to (B, C, T)
210
+
211
+
212
class LocalBlock(torch.nn.Module):
    """CLA local convolution block followed by a GCFN feed-forward stage;
    operates on and returns (batch, time, channels) features."""
    def __init__(self, in_channels: int, kernel_size: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'cla': CLA(in_channels, kernel_size, dropout_rate),
            'gcfn': GCFN(in_channels, dropout_rate)
        })

    def forward(self, x: torch.Tensor):
        """Apply the local conv block, then the feed-forward (both residual)."""
        return self.block['gcfn'](self.block['cla'](x))
225
+
226
+
227
class SpkAttention(torch.nn.Module):
    """Attention across the speaker axis: features at the same time step from
    different speakers attend to each other, followed by a GCFN feed-forward."""
    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.self_attn = MultiHeadAttention(n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate)
        self.feed_forward = GCFN(in_channels=in_channels, dropout_rate=dropout_rate)

    def forward(self, x: torch.Tensor, num_spk: int):
        """
        Compute encoded features.
        :param torch.Tensor x: source features (batch*num_spk, channels, time);
            the speaker axis is assumed folded into the batch axis
        :param int num_spk: number of speakers folded into the batch
        :rtype: torch.Tensor (batch*num_spk, channels, time)
        """
        B, F, T = x.shape  # here B is actually batch*num_spk
        # unfold speakers, then treat each (batch, time) pair as one attention
        # "sequence" of length num_spk
        x = x.view(B//num_spk, num_spk, F, T).contiguous()
        x = x.permute([0, 3, 1, 2]).contiguous()   # (batch, T, spk, F)
        x = x.view(-1, num_spk, F).contiguous()    # (batch*T, spk, F)
        x = x + self.self_attn(x, None, None)      # residual attention over speakers
        # restore the original (batch*num_spk, F, T) layout
        x = x.view(B//num_spk, T, num_spk, F).contiguous()
        x = x.permute([0, 2, 3, 1]).contiguous()
        x = x.view(B, F, T).contiguous()
        x = x.permute([0, 2, 1])                   # (B, T, F) for the feed-forward
        x = self.feed_forward(x)
        x = x.permute([0, 2, 1])

        return x
models/SepReformer/SepReformer_Large_DM_WHAMR/configs.yaml ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ config:
2
+ dataset:
3
+ max_len : 32000
4
+ sampling_rate: 8000
5
+ scp_dir: "data/scp_ss_8k_whamr"
6
+ train:
7
+ mixture: "tr_mix.scp"
8
+ spk1: "tr_s1.scp"
9
+ spk2: "tr_s2.scp"
10
+ spk1_reverb: "tr_s1_reverb.scp"
11
+ spk2_reverb: "tr_s2_reverb.scp"
12
+ noise: "tr_n.scp"
13
+ dynamic_mixing: true
14
+ valid:
15
+ mixture: "cv_mix.scp"
16
+ spk1: "cv_s1.scp"
17
+ spk2: "cv_s2.scp"
18
+ test:
19
+ mixture: "tt_mix.scp"
20
+ spk1: "tt_s1.scp"
21
+ spk2: "tt_s2.scp"
22
+ dataloader:
23
+ batch_size: 2
24
+ pin_memory: false
25
+ num_workers: 12
26
+ drop_last: false
27
+ model:
28
+ num_stages: &var_model_num_stages 4 # R
29
+ num_spks: &var_model_num_spks 2
30
+ module_audio_enc:
31
+ in_channels: 1
32
+ out_channels: &var_model_audio_enc_out_channels 256
33
+ kernel_size: &var_model_audio_enc_kernel_size 16 # L
34
+ stride: &var_model_audio_enc_stride 4 # S
35
+ groups: 1
36
+ bias: false
37
+ module_feature_projector:
38
+ num_channels: *var_model_audio_enc_out_channels
39
+ in_channels: *var_model_audio_enc_out_channels
40
+ out_channels: &feature_projector_out_channels 256 # F
41
+ kernel_size: 1
42
+ bias: false
43
+ module_separator:
44
+ num_stages: *var_model_num_stages
45
+ relative_positional_encoding:
46
+ in_channels: *feature_projector_out_channels
47
+ num_heads: 8
48
+ maxlen: 2000
49
+ embed_v: false
50
+ enc_stage:
51
+ global_blocks:
52
+ in_channels: *feature_projector_out_channels
53
+ num_mha_heads: 8
54
+ dropout_rate: 0.1
55
+ local_blocks:
56
+ in_channels: *feature_projector_out_channels
57
+ kernel_size: 65
58
+ dropout_rate: 0.1
59
+ down_conv_layer:
60
+ in_channels: *feature_projector_out_channels
61
+ samp_kernel_size: &var_model_samp_kernel_size 5
62
+ spk_split_stage:
63
+ in_channels: *feature_projector_out_channels
64
+ num_spks: *var_model_num_spks
65
+ simple_fusion:
66
+ out_channels: *feature_projector_out_channels
67
+ dec_stage:
68
+ num_spks: *var_model_num_spks
69
+ global_blocks:
70
+ in_channels: *feature_projector_out_channels
71
+ num_mha_heads: 8
72
+ dropout_rate: 0.1
73
+ local_blocks:
74
+ in_channels: *feature_projector_out_channels
75
+ kernel_size: 65
76
+ dropout_rate: 0.1
77
+ spk_attention:
78
+ in_channels: *feature_projector_out_channels
79
+ num_mha_heads: 8
80
+ dropout_rate: 0.1
81
+ module_output_layer:
82
+ in_channels: *var_model_audio_enc_out_channels
83
+ out_channels: *feature_projector_out_channels
84
+ num_spks: *var_model_num_spks
85
+ module_audio_dec:
86
+ in_channels: *var_model_audio_enc_out_channels
87
+ out_channels: 1
88
+ kernel_size: *var_model_audio_enc_kernel_size
89
+ stride: *var_model_audio_enc_stride
90
+ bias: false
91
+ criterion: ### Ref: https://pytorch.org/docs/stable/nn.html#loss-functions
92
+ name: ["PIT_SISNR_mag", "PIT_SISNR_time", "PIT_SISNRi", "PIT_SDRi"] ### Choose a torch.nn's loss function class(=attribute) e.g. ["L1Loss", "MSELoss", "CrossEntropyLoss", ...] / You can also build your optimizer :)
93
+ PIT_SISNR_mag:
94
+ frame_length: 512
95
+ frame_shift: 128
96
+ window: 'hann'
97
+ num_stages: *var_model_num_stages
98
+ num_spks: *var_model_num_spks
99
+ scale_inv: true
100
+ mel_opt: false
101
+ PIT_SISNR_time:
102
+ num_spks: *var_model_num_spks
103
+ scale_inv: true
104
+ PIT_SISNRi:
105
+ num_spks: *var_model_num_spks
106
+ scale_inv: true
107
+ PIT_SDRi:
108
+ dump: 0
109
+ optimizer: ### Ref: https://pytorch.org/docs/stable/optim.html#algorithms
110
+ name: ["AdamW"] ### Choose a torch.optim's class(=attribute) e.g. ["Adam", "AdamW", "SGD", ...] / You can also build your optimizer :)
111
+ AdamW:
112
+ lr: 2.0e-4
113
+ weight_decay: 1.0e-2
114
+ scheduler: ### Ref(+ find "How to adjust learning rate"): https://pytorch.org/docs/stable/optim.html#algorithms
115
+ name: ["ReduceLROnPlateau", "WarmupConstantSchedule"] ### Choose a torch.optim.lr_scheduler's class(=attribute) e.g. ["StepLR", "ReduceLROnPlateau", "Custom"] / You can also build your scheduler :)
116
+ ReduceLROnPlateau:
117
+ mode: "min"
118
+ min_lr: 1.0e-10
119
+ factor: 0.8
120
+ patience: 2
121
+ WarmupConstantSchedule:
122
+ warmup_steps: 1000
123
+ check_computations:
124
+ dummy_len: 16000
125
+ engine:
126
+ max_epoch: 200
127
+ gpuid: "0" ### "0"(single-gpu) or "0, 1" (multi-gpu)
128
+ mvn: false
129
+ clip_norm: 5
130
+ start_scheduling: 50
131
+ test_epochs: [50, 80, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 199]
models/SepReformer/SepReformer_Large_DM_WHAMR/dataset.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import librosa as audio_lib
5
+ import numpy as np
6
+
7
+ from utils import util_dataset
8
+ from utils.decorators import *
9
+ from loguru import logger
10
+ from torch.utils.data import Dataset, DataLoader
11
+
12
@logger_wraps()
def get_dataloaders(args, dataset_config, loader_config):
    """Build one DataLoader per partition ("train"/"valid"/"test").

    In any "test" engine mode only the test partition is built. Scp paths are
    resolved relative to ``dataset_config["scp_dir"]``; dynamic mixing is
    honoured only for the train partition.

    :param args: parsed CLI arguments; only ``args.engine_mode`` is read here
    :param dataset_config: the ``config["dataset"]`` mapping from configs.yaml
    :param loader_config: the ``config["dataloader"]`` mapping from configs.yaml
    :returns: dict mapping partition name -> torch DataLoader
    """
    # create dataset object for each partition
    partitions = ["test"] if "test" in args.engine_mode else ["train", "valid", "test"]
    dataloaders = {}
    for partition in partitions:
        scp_config_mix = os.path.join(dataset_config["scp_dir"], dataset_config[partition]['mixture'])
        # every key starting with 'spk' (spk1, spk2, spk1_reverb, ...) is a source scp
        scp_config_spk = [os.path.join(dataset_config["scp_dir"], dataset_config[partition][spk_key]) for spk_key in dataset_config[partition] if spk_key.startswith('spk')]
        scp_config_noise = os.path.join(dataset_config["scp_dir"], dataset_config[partition]['noise']) if 'noise' in dataset_config[partition] else None
        dynamic_mixing = dataset_config[partition]["dynamic_mixing"] if partition == 'train' else False
        dataset = MyDataset(
            max_len = dataset_config['max_len'],
            fs = dataset_config['sampling_rate'],
            partition = partition,
            wave_scp_srcs = scp_config_spk,
            wave_scp_mix = scp_config_mix,
            wave_scp_noise = scp_config_noise,
            dynamic_mixing = dynamic_mixing)
        dataloader = DataLoader(
            dataset = dataset,
            # test always evaluates utterance-by-utterance
            batch_size = 1 if partition == 'test' else loader_config["batch_size"],
            shuffle = True, # only train: (partition == 'train') / all: True
            pin_memory = loader_config["pin_memory"],
            num_workers = loader_config["num_workers"],
            drop_last = loader_config["drop_last"],
            collate_fn = _collate)
        dataloaders[partition] = dataloader
    return dataloaders
40
+
41
+
42
+ def _collate(egs):
43
+ """
44
+ Transform utterance index into a minbatch
45
+
46
+ Arguments:
47
+ index: a list type [{},{},{}]
48
+
49
+ Returns:
50
+ input_sizes: a tensor correspond to utterance length
51
+ input_feats: packed sequence to feed networks
52
+ source_attr/target_attr: dictionary contains spectrogram/phase needed in loss computation
53
+ """
54
+ def __prepare_target_rir(dict_lsit, index):
55
+ return torch.nn.utils.rnn.pad_sequence([torch.tensor(d["src"][index], dtype=torch.float32) for d in dict_lsit], batch_first=True)
56
+ if type(egs) is not list: raise ValueError("Unsupported index type({})".format(type(egs)))
57
+ num_spks = 2 # you need to set this paramater by yourself
58
+ dict_list = sorted([eg for eg in egs], key=lambda x: x['num_sample'], reverse=True)
59
+ mixture = torch.nn.utils.rnn.pad_sequence([torch.tensor(d['mix'], dtype=torch.float32) for d in dict_list], batch_first=True)
60
+ src = [__prepare_target_rir(dict_list, index) for index in range(num_spks)]
61
+ input_sizes = torch.tensor([d['num_sample'] for d in dict_list], dtype=torch.float32)
62
+ key = [d['key'] for d in dict_list]
63
+ return input_sizes, mixture, src, key
64
+
65
+
66
@logger_wraps()
class MyDataset(Dataset):
    """WHAMR-style source-separation dataset with optional dynamic mixing.

    ``wave_scp_srcs`` appears to hold, in order, the dry target scps (one per
    speaker) followed by their reverberant counterparts — this is implied by
    the ``idx+2`` indexing in ``_dynamic_mixing`` and the ``-2`` offset in
    ``__getitem__``'s key check; TODO confirm against the scp layout used by
    ``get_dataloaders`` / configs.yaml.
    """
    def __init__(self, max_len, fs, partition, wave_scp_srcs, wave_scp_mix, wave_scp_noise, dynamic_mixing, speed_list=None):
        # partition: "train" / "valid" / "test"; max_len: crop length in samples;
        # fs: target sampling rate. NOTE: speed_list is accepted but unused here.
        self.partition = partition
        for wave_scp_src in wave_scp_srcs:
            if not os.path.exists(wave_scp_src): raise FileNotFoundError(f"Could not find file {wave_scp_src}")
        self.max_len = max_len
        self.fs = fs
        # key -> wav-path mappings, one dict per source scp
        self.wave_dict_srcs = [util_dataset.parse_scps(wave_scp_src) for wave_scp_src in wave_scp_srcs]
        self.wave_dict_mix = util_dataset.parse_scps(wave_scp_mix)
        self.wave_dict_noise = util_dataset.parse_scps(wave_scp_noise) if wave_scp_noise else None
        self.wave_keys = list(self.wave_dict_mix.keys())
        logger.info(f"Create MyDataset for {wave_scp_mix} with {len(self.wave_dict_mix)} utterances")
        self.dynamic_mixing = dynamic_mixing

    def __len__(self):
        return len(self.wave_dict_mix)

    def __contains__(self, key):
        return key in self.wave_dict_mix

    def _dynamic_mixing(self, key):
        """Create an on-the-fly mixture: this utterance's first source plus a
        randomly drawn second source and the matching noise, with random gains."""
        def __match_length(wav, len_data):
            # random crop of length len_data (assumes len(wav) >= len_data)
            leftover = len(wav) - len_data
            idx = random.randint(0,leftover)
            wav = wav[idx:idx+len_data]
            return wav

        samps_src_reverb = []
        samps_src = []
        src_len = [self.max_len]
        # dyanmic source choice
        # checking whether it is the same speaker
        # NOTE(review): despite the comment above, no same-speaker check is
        # actually performed — key_random may even equal ``key``
        key_random = random.choice(list(self.wave_dict_srcs[0].keys()))
        idx1, idx2 = (0, 1) if random.random() > 0.5 else (1, 0)
        files = [self.wave_dict_srcs[idx1][key], self.wave_dict_srcs[idx2][key_random]]
        # reverberant versions live at offset +2 in wave_dict_srcs
        files_reverb = [self.wave_dict_srcs[idx1+2][key], self.wave_dict_srcs[idx2+2][key_random]]

        # load
        for idx, file in enumerate(files_reverb):
            if not os.path.exists(file):
                raise FileNotFoundError("Input file {} do not exists!".format(file))
            samps_tmp_reverb, _ = audio_lib.load(file, sr=self.fs)
            samps_tmp, _ = audio_lib.load(files[idx], sr=self.fs)
            # mixing with random gains

            # RMS-normalize every source to the first (dry) source's level
            if idx == 0: ref_rms = np.sqrt(np.mean(np.square(samps_tmp)))
            curr_rms = np.sqrt(np.mean(np.square(samps_tmp)))

            norm_factor = ref_rms / curr_rms
            samps_tmp *= norm_factor
            samps_tmp_reverb *= norm_factor

            # random relative gain in [-3, 3] dB, identical for dry and reverb
            gain = pow(10,-random.uniform(-3,3)/20)
            # Speed Augmentation
            samps_src_reverb.append(gain*samps_tmp_reverb)
            samps_src.append(gain*samps_tmp)
            src_len.append(len(samps_tmp))


        # matching the audio length
        min_len = min(src_len)

        # add noise source (random relative gain in [-3, 6] dB)
        file_noise = self.wave_dict_noise[key]
        samps_noise, _ = audio_lib.load(file_noise, sr=self.fs)
        curr_rms = np.sqrt(np.mean(np.square(samps_noise)))
        norm_factor = ref_rms / curr_rms
        samps_noise *= norm_factor
        gain_noise = pow(10,-random.uniform(-6,3)/20)
        samps_noise = samps_noise*gain_noise
        src_len.append(len(samps_noise))

        # truncate everything to the shortest source; dry and reverb are stacked
        # first so both receive the same random crop offset
        min_len = min(src_len)
        samps_src_stack = [np.stack([samps_src_reverb[idx], samps_src[idx]],axis=-1) for idx in range(len(samps_src_reverb))]
        samps_src_stack = [__match_length(s, min_len) for s in samps_src_stack]
        samps_src_reverb = [s[...,0] for s in samps_src_stack]
        samps_src = [s[...,1] for s in samps_src_stack]
        samps_noise = __match_length(samps_noise, min_len)
        # network input is reverberant speech + noise; targets stay dry
        samps_mix = sum(samps_src_reverb) + samps_noise


        # keep the length divisible by 4 — presumably to match the encoder
        # stride of 4 in configs.yaml; TODO confirm
        if len(samps_mix)%4 != 0:
            remains = len(samps_mix)%4
            samps_mix = samps_mix[:-remains]
            samps_src = [s[:-remains] for s in samps_src]

        return samps_mix, samps_src

    def _direct_load(self, key):
        """Load the pre-mixed mixture and the first two (dry) sources for ``key``."""
        samps_src = []
        files = [self.wave_dict_srcs[0][key], self.wave_dict_srcs[1][key]]
        # files = [wave_dict_src[key] for wave_dict_src in self.wave_dict_srcs]
        for file in files:
            if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
            samps_tmp, _ = audio_lib.load(file, sr=self.fs)
            samps_src.append(samps_tmp)

        file = self.wave_dict_mix[key]
        if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
        samps_mix, _ = audio_lib.load(file, sr=self.fs)

        # Truncate samples as needed (length kept divisible by 4, see above)
        if len(samps_mix) % 4 != 0:
            remains = len(samps_mix) % 4
            samps_mix = samps_mix[:-remains]
            samps_src = [s[:-remains] for s in samps_src]

        # train/valid: random crop down to max_len; test keeps the full length
        if self.partition != "test":
            if len(samps_mix) > self.max_len:
                start = random.randint(0,len(samps_mix)-self.max_len)
                samps_mix = samps_mix[start:start+self.max_len]
                samps_src = [s[start:start+self.max_len] for s in samps_src]

        return samps_mix, samps_src

    def __getitem__(self, index):
        # returns the dict consumed by ``_collate``
        key = self.wave_keys[index]
        if any(key not in self.wave_dict_srcs[i] for i in range(len(self.wave_dict_srcs)-2)) or key not in self.wave_dict_mix: raise KeyError(f"Could not find utterance {key}")
        samps_mix, samps_src = self._dynamic_mixing(key) if self.dynamic_mixing else self._direct_load(key)
        return {"num_sample": samps_mix.shape[0], "mix": samps_mix, "src": samps_src, "key": key}
models/SepReformer/SepReformer_Large_DM_WHAMR/engine.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import csv
4
+ import time
5
+ import soundfile as sf
6
+
7
+ from loguru import logger
8
+ from tqdm import tqdm
9
+ from utils import util_engine, functions
10
+ from utils.decorators import *
11
+ from torch.utils.tensorboard import SummaryWriter
12
+
13
+
14
+ @logger_wraps()
15
+ class Engine(object):
16
+ def __init__(self, args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device):
17
+
18
+ ''' Default setting '''
19
+ self.engine_mode = args.engine_mode
20
+ self.out_wav_dir = args.out_wav_dir
21
+ self.config = config
22
+ self.gpuid = gpuid
23
+ self.device = device
24
+ self.model = model.to(self.device)
25
+ self.dataloaders = dataloaders # self.dataloaders['train'] or ['valid'] or ['test']
26
+ self.PIT_SISNR_mag_loss, self.PIT_SISNR_time_loss, self.PIT_SISNRi_loss, self.PIT_SDRi_loss = criterions
27
+ self.main_optimizer = optimizers[0]
28
+ self.main_scheduler, self.warmup_scheduler = schedulers
29
+
30
+ self.pretrain_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "pretrain_weights")
31
+ os.makedirs(self.pretrain_weights_path, exist_ok=True)
32
+ self.scratch_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "scratch_weights")
33
+ os.makedirs(self.scratch_weights_path, exist_ok=True)
34
+
35
+ self.checkpoint_path = self.pretrain_weights_path if any(file.endswith(('.pt', '.pt', '.pkl')) for file in os.listdir(self.pretrain_weights_path)) else self.scratch_weights_path
36
+ self.start_epoch = util_engine.load_last_checkpoint_n_get_epoch(self.checkpoint_path, self.model, self.main_optimizer, location=self.device)
37
+
38
+ # Logging
39
+ util_engine.model_params_mac_summary(
40
+ model=self.model,
41
+ input=torch.randn(1, self.config['check_computations']['dummy_len']).to(self.device),
42
+ dummy_input=torch.rand(1, self.config['check_computations']['dummy_len']).to(self.device),
43
+ metrics=['ptflops', 'thop', 'torchinfo']
44
+ # metrics=['ptflops']
45
+ )
46
+
47
+ logger.info(f"Clip gradient by 2-norm {self.config['engine']['clip_norm']}")
48
+
49
+ @logger_wraps()
50
+ def _train(self, dataloader, epoch):
51
+ self.model.train()
52
+ tot_loss_freq = [0 for _ in range(self.model.num_stages)]
53
+ tot_loss_time, num_batch = 0, 0
54
+ pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:25}{r_bar}{bar:-10b}', colour="YELLOW", dynamic_ncols=True)
55
+ for input_sizes, mixture, src, _ in dataloader:
56
+ nnet_input = mixture
57
+ nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
58
+ num_batch += 1
59
+ pbar.update(1)
60
+ # Scheduler learning rate for warm-up (Iteration-based update for transformers)
61
+ if epoch == 1: self.warmup_scheduler.step()
62
+ nnet_input = nnet_input.to(self.device)
63
+ self.main_optimizer.zero_grad()
64
+ estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
65
+ cur_loss_s_bn = 0
66
+ cur_loss_s_bn = []
67
+ for idx, estim_src_value in enumerate(estim_src_bn):
68
+ cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
69
+ tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
70
+ cur_loss_s = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
71
+ tot_loss_time += cur_loss_s.item() / self.config['model']['num_spks']
72
+ alpha = 0.4 * 0.8**(1+(epoch-101)//5) if epoch > 100 else 0.4
73
+ cur_loss = (1-alpha) * cur_loss_s + alpha * sum(cur_loss_s_bn) / len(cur_loss_s_bn)
74
+ cur_loss = cur_loss / self.config['model']['num_spks']
75
+ cur_loss.backward()
76
+ if self.config['engine']['clip_norm']: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['engine']['clip_norm'])
77
+ self.main_optimizer.step()
78
+ dict_loss = {"T_Loss": tot_loss_time / num_batch}
79
+ dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
80
+ pbar.set_postfix(dict_loss)
81
+ pbar.close()
82
+ tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
83
+ return tot_loss_time / num_batch, tot_loss_freq / num_batch, num_batch
84
+
85
    @logger_wraps()
    def _validate(self, dataloader):
        """Run one validation pass under inference mode (no gradient updates).

        Mirrors ``_train``'s loss bookkeeping so the numbers are comparable.
        :returns: (mean time-domain loss, mean frequency-domain loss, batch count)
        """
        self.model.eval()
        tot_loss_freq = [0 for _ in range(self.model.num_stages)]
        tot_loss_time, num_batch = 0, 0
        pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="RED", dynamic_ncols=True)
        with torch.inference_mode():
            for input_sizes, mixture, src, _ in dataloader:
                nnet_input = mixture
                # optional cepstral mean-variance normalization of the input
                nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
                nnet_input = nnet_input.to(self.device)
                num_batch += 1
                pbar.update(1)
                estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
                cur_loss_s_bn = []
                for idx, estim_src_value in enumerate(estim_src_bn):
                    # per-stage magnitude-domain PIT loss
                    cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
                    tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
                cur_loss_s_SDR = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
                tot_loss_time += cur_loss_s_SDR.item() / self.config['model']['num_spks']
                dict_loss = {"T_Loss":tot_loss_time / num_batch}
                dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
                pbar.set_postfix(dict_loss)
        pbar.close()
        tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
        return tot_loss_time / num_batch, tot_loss_freq / num_batch, num_batch
111
+
112
@logger_wraps()
def _test(self, dataloader, wav_dir=None):
    """Evaluate the model on the test set.

    Computes per-utterance SI-SNRi and SDRi, writes the per-speaker values to
    CSV files next to this module, optionally saves separated waveforms when
    ``engine_mode == "test_save"``, and returns the averaged metrics.

    Args:
        dataloader: test DataLoader yielding ``(input_sizes, mixture, src, key)``;
            must be built with batch_size == 1.
        wav_dir: output directory for separated wavs in "test_save" mode;
            defaults to ``<module_dir>/wav_out``.

    Returns:
        Tuple of (mean SI-SNRi, mean SDRi, number of batches processed).
    """
    self.model.eval()
    total_loss_SISNRi, total_loss_SDRi, num_batch = 0, 0, 0
    pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="grey", dynamic_ncols=True)
    with torch.inference_mode():
        csv_file_name_sisnr = os.path.join(os.path.dirname(__file__), 'test_SISNRi_value.csv')
        csv_file_name_sdr = os.path.join(os.path.dirname(__file__), 'test_SDRi_value.csv')
        with open(csv_file_name_sisnr, 'w', newline='') as csvfile_sisnr, open(csv_file_name_sdr, 'w', newline='') as csvfile_sdr:
            idx = 0
            writer_sisnr = csv.writer(csvfile_sisnr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
            writer_sdr = csv.writer(csvfile_sdr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
            for input_sizes, mixture, src, key in dataloader:
                # BUGFIX: `raise("...")` raises a TypeError in Python 3
                # (a str is not an exception); raise a real exception type.
                if len(key) > 1:
                    raise RuntimeError("batch size is not one!!")
                nnet_input = mixture.to(self.device)
                num_batch += 1
                pbar.update(1)
                estim_src, _ = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
                # Per-utterance metrics; divide by num_spks to average over speakers.
                cur_loss_SISNRi, cur_loss_SISNRi_src = self.PIT_SISNRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src, eps=1.0e-15)
                total_loss_SISNRi += cur_loss_SISNRi.item() / self.config['model']['num_spks']
                cur_loss_SDRi, cur_loss_SDRi_src = self.PIT_SDRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src)
                total_loss_SDRi += cur_loss_SDRi.item() / self.config['model']['num_spks']
                # One CSV row per utterance: key (extension stripped) + per-speaker values.
                writer_sisnr.writerow([key[0][:-4]] + [cur_loss_SISNRi_src[i].item() for i in range(self.config['model']['num_spks'])])
                writer_sdr.writerow([key[0][:-4]] + [cur_loss_SDRi_src[i].item() for i in range(self.config['model']['num_spks'])])
                if self.engine_mode == "test_save":
                    # FIX: identity/equality — use `is None` for None checks.
                    if wav_dir is None: wav_dir = os.path.join(os.path.dirname(__file__), "wav_out")
                    if wav_dir and not os.path.exists(wav_dir): os.makedirs(wav_dir)
                    mixture = torch.squeeze(mixture).cpu().data.numpy()
                    # Peak-normalize to 0.5 full scale; 8 kHz sample rate.
                    sf.write(os.path.join(wav_dir, key[0][:-4] + str(idx) + '_mixture.wav'), 0.5 * mixture / max(abs(mixture)), 8000)
                    for i in range(self.config['model']['num_spks']):
                        src = torch.squeeze(estim_src[i]).cpu().data.numpy()
                        sf.write(os.path.join(wav_dir, key[0][:-4] + str(idx) + '_out_' + str(i) + '.wav'), 0.5 * src / max(abs(src)), 8000)
                    idx += 1
                dict_loss = {"SiSNRi": total_loss_SISNRi / num_batch, "SDRi": total_loss_SDRi / num_batch}
                pbar.set_postfix(dict_loss)
    pbar.close()
    return total_loss_SISNRi / num_batch, total_loss_SDRi / num_batch, num_batch
150
+
151
@logger_wraps()
def run(self):
    """Run the engine.

    In any "test" engine mode this performs a single evaluation pass.
    Otherwise it runs the full training loop: train + validate each epoch,
    periodic testing, best-checkpoint saving, and TensorBoard logging.
    """
    with torch.cuda.device(self.device):
        writer_src = SummaryWriter(os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/tensorboard"))
        if "test" in self.engine_mode:
            on_test_start = time.time()
            test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'], self.out_wav_dir)
            on_test_end = time.time()
            logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
            logger.info(f"Testing done!")
        else:
            # Establish a baseline validation loss (only when resuming).
            start_time = time.time()
            if self.start_epoch > 1:
                init_loss_time, init_loss_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
            else:
                init_loss_time, init_loss_freq = 0, 0
            end_time = time.time()
            logger.info(f"[INIT] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: Loss_t = {init_loss_time:.4f} dB | Loss_f = {init_loss_freq:.4f} dB | Speed = ({end_time-start_time:.2f}s)")
            # BUGFIX: track the best validation loss ACROSS epochs. The
            # original assigned `valid_loss_best = init_loss_time` inside the
            # epoch loop, resetting "best" to the initial loss every epoch,
            # so save_checkpoint_per_best compared against the wrong baseline.
            valid_loss_best = init_loss_time
            for epoch in range(self.start_epoch, self.config['engine']['max_epoch']):
                train_start_time = time.time()
                train_loss_src_time, train_loss_src_freq, train_num_batch = self._train(self.dataloaders['train'], epoch)
                train_end_time = time.time()
                valid_start_time = time.time()
                valid_loss_src_time, valid_loss_src_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
                valid_end_time = time.time()
                # Plateau scheduling only after the warm-up period.
                if epoch > self.config['engine']['start_scheduling']: self.main_scheduler.step(valid_loss_src_time)
                logger.info(f"[TRAIN] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {train_loss_src_time:.4f} dB | Loss_f = {train_loss_src_freq:.4f} dB | Speed = ({train_end_time - train_start_time:.2f}s/{train_num_batch:d})")
                logger.info(f"[VALID] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {valid_loss_src_time:.4f} dB | Loss_f = {valid_loss_src_freq:.4f} dB | Speed = ({valid_end_time - valid_start_time:.2f}s/{valid_num_batch:d})")
                if epoch in self.config['engine']['test_epochs']:
                    on_test_start = time.time()
                    test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'])
                    on_test_end = time.time()
                    logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
                valid_loss_best = util_engine.save_checkpoint_per_best(valid_loss_best, valid_loss_src_time, train_loss_src_time, epoch, self.model, self.main_optimizer, self.checkpoint_path)
                # Logging to monitoring tools (Tensorboard && Wandb)
                writer_src.add_scalars("Metrics", {
                    'Loss_train_time': train_loss_src_time,
                    'Loss_valid_time': valid_loss_src_time}, epoch)
                writer_src.add_scalar("Learning Rate", self.main_optimizer.param_groups[0]['lr'], epoch)
                writer_src.flush()
            logger.info(f"Training for {self.config['engine']['max_epoch']} epoches done!")
models/SepReformer/SepReformer_Large_DM_WHAMR/main.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from loguru import logger
4
+ from .dataset import get_dataloaders
5
+ from .model import Model
6
+ from .engine import Engine
7
+ from utils import util_system, util_implement
8
+ from utils.decorators import *
9
+
10
+ # Setup logger
11
+ log_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/system_log.log")
12
+ logger.add(log_file_path, level="DEBUG", mode="w")
13
+
14
@logger_wraps()
def main(args):
    """Entry point: assemble configuration, data, model and engine, then run."""

    # --- Configuration (configs.yaml next to this module) ---------------
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs.yaml")
    yaml_dict = util_system.parse_yaml(config_path)
    config = yaml_dict["config"]  # wandb login success or fail

    # --- Data loaders [train / valid / test / etc...] --------------------
    dataloaders = get_dataloaders(args, config["dataset"], config["dataloader"])

    # --- Network model ---------------------------------------------------
    model = Model(**config["model"])

    # --- GPU ids & primary device ----------------------------------------
    gpuid = tuple(int(gid) for gid in config["engine"]["gpuid"].split(','))
    device = torch.device(f'cuda:{gpuid[0]}')

    # --- Criterions / optimizers / schedulers ----------------------------
    criterions = util_implement.CriterionFactory(config["criterion"], device).get_criterions()
    optimizers = util_implement.OptimizerFactory(config["optimizer"], model.parameters()).get_optimizers()
    schedulers = util_implement.SchedulerFactory(config["scheduler"], optimizers).get_schedulers()

    # --- Build and run the engine ----------------------------------------
    engine = Engine(args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device)
    engine.run()
models/SepReformer/SepReformer_Large_DM_WHAMR/model.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('../')
3
+
4
+ import torch
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ from utils.decorators import *
9
+ from .modules.module import *
10
+
11
+
12
@logger_wraps()
class Model(torch.nn.Module):
    """SepReformer separation model.

    Pipeline: audio encoder -> feature projector -> multi-stage separator ->
    per-speaker output layer -> audio decoder. Each separator stage also
    produces an auxiliary output (upsampled, masked, decoded) used for deep
    supervision during training.
    """

    def __init__(self,
                 num_stages: int,
                 num_spks: int,
                 module_audio_enc: dict,
                 module_feature_projector: dict,
                 module_separator: dict,
                 module_output_layer: dict,
                 module_audio_dec: dict):
        super().__init__()
        self.num_stages = num_stages
        self.num_spks = num_spks
        self.audio_encoder = AudioEncoder(**module_audio_enc)
        self.feature_projector = FeatureProjector(**module_feature_projector)
        self.separator = Separator(**module_separator)
        self.out_layer = OutputLayer(**module_output_layer)
        self.audio_decoder = AudioDecoder(**module_audio_dec)

        # Aux_loss: one masking output layer + decoder per separator stage.
        self.out_layer_bn = torch.nn.ModuleList([])
        self.decoder_bn = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            self.out_layer_bn.append(OutputLayer(**module_output_layer, masking=True))
            self.decoder_bn.append(AudioDecoder(**module_audio_dec))

    def forward(self, x):
        """Separate the mixture waveform ``x`` into ``num_spks`` sources.

        Returns:
            audio: list of ``num_spks`` separated waveforms (final stage).
            audio_aux: list (one entry per separator stage) of per-speaker
                auxiliary waveforms, trimmed to the input length.
        """
        encoder_output = self.audio_encoder(x)
        projected_feature = self.feature_projector(encoder_output)
        last_stage_output, each_stage_outputs = self.separator(projected_feature)
        out_layer_output = self.out_layer(last_stage_output, encoder_output)
        each_spk_output = [out_layer_output[idx] for idx in range(self.num_spks)]
        audio = [self.audio_decoder(each_spk_output[idx]) for idx in range(self.num_spks)]

        # Aux_loss: decode every intermediate stage as well.
        audio_aux = []
        for idx, each_stage_output in enumerate(each_stage_outputs):
            # FIX: torch.nn.functional.upsample is deprecated; interpolate is
            # the drop-in replacement with identical default (nearest) mode.
            each_stage_output = self.out_layer_bn[idx](torch.nn.functional.interpolate(each_stage_output, encoder_output.shape[-1]), encoder_output)
            out_aux = [each_stage_output[jdx] for jdx in range(self.num_spks)]
            audio_aux.append([self.decoder_bn[idx](out_aux[jdx])[..., :x.shape[-1]] for jdx in range(self.num_spks)])

        return audio, audio_aux
models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/module.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/module.cpython-38.pyc ADDED
Binary file (11 kB). View file
 
models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/network.cpython-310.pyc ADDED
Binary file (8.98 kB). View file
 
models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/network.cpython-38.pyc ADDED
Binary file (9.09 kB). View file
 
models/SepReformer/SepReformer_Large_DM_WHAMR/modules/module.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('../')
3
+
4
+ import torch
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ from utils.decorators import *
9
+ from .network import *
10
+
11
+
12
class AudioEncoder(torch.nn.Module):
    """Strided 1-D convolutional front-end followed by a GELU.

    Maps a raw waveform to a latent feature sequence.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, groups: int, bias: bool):
        super().__init__()
        self.conv1d = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, stride=stride, groups=groups, bias=bias)
        self.gelu = torch.nn.GELU()

    def forward(self, x: torch.Tensor):
        # [T] -> [1, T] (single unbatched channel) OR [B, T] -> [B, 1, T]
        channel_dim = 0 if len(x.shape) == 1 else 1
        x = torch.unsqueeze(x, dim=channel_dim)
        return self.gelu(self.conv1d(x))
24
+
25
class FeatureProjector(torch.nn.Module):
    """Normalize encoder features and project them to the separator width."""

    def __init__(self, num_channels: int, in_channels: int, out_channels: int, kernel_size: int, bias: bool):
        super().__init__()
        # Single-group GroupNorm == global layer norm over (channel, time).
        self.norm = torch.nn.GroupNorm(num_groups=1, num_channels=num_channels, eps=1e-8)
        self.conv1d = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias)

    def forward(self, x: torch.Tensor):
        return self.conv1d(self.norm(x))
36
+
37
+
38
class Separator(torch.nn.Module):
    """SepReformer separator: a U-Net-style temporal contracting/expanding
    stack of global+local transformer blocks with speaker splitting at the
    bottleneck and skip fusion on the way back up."""

    def __init__(self, num_stages: int, relative_positional_encoding: dict, enc_stage: dict, spk_split_stage: dict, simple_fusion:dict, dec_stage: dict):
        super().__init__()

        class RelativePositionalEncoding(torch.nn.Module):
            # Learned relative-position embeddings shared by the attention
            # blocks; one table for keys (pe_k), optionally one for values.
            def __init__(self, in_channels: int, num_heads: int, maxlen: int, embed_v=False):
                super().__init__()
                self.in_channels = in_channels
                self.num_heads = num_heads
                # Per-head embedding width.
                self.embedding_dim = self.in_channels // self.num_heads
                self.maxlen = maxlen
                self.pe_k = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim)
                self.pe_v = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim) if embed_v else None

            def forward(self, pos_seq: torch.Tensor):
                # NOTE(review): clamp_/+= mutate the caller's pos_seq in place.
                # Shift signed offsets [-maxlen, maxlen) into table indices [0, 2*maxlen).
                pos_seq.clamp_(-self.maxlen, self.maxlen - 1)
                pos_seq += self.maxlen
                pe_k_output = self.pe_k(pos_seq)
                pe_v_output = self.pe_v(pos_seq) if self.pe_v is not None else None
                return pe_k_output, pe_v_output

        class SepEncStage(torch.nn.Module):
            # One contracting stage: two (global block -> local block) pairs,
            # then an optional stride-2 depthwise downsampling conv.
            def __init__(self, global_blocks: dict, local_blocks: dict, down_conv_layer: dict, down_conv=True):
                super().__init__()

                class DownConvLayer(torch.nn.Module):
                    def __init__(self, in_channels: int, samp_kernel_size: int):
                        """Construct an EncoderLayer object."""
                        super().__init__()
                        # Depthwise stride-2 conv: halves the time axis.
                        self.down_conv = torch.nn.Conv1d(
                            in_channels=in_channels, out_channels=in_channels, kernel_size=samp_kernel_size, stride=2, padding=(samp_kernel_size-1)//2, groups=in_channels)
                        self.BN = torch.nn.BatchNorm1d(num_features=in_channels)
                        self.gelu = torch.nn.GELU()

                    def forward(self, x: torch.Tensor):
                        # Expects [B, T, N]; conv operates on [B, N, T].
                        x = x.permute([0, 2, 1])
                        x = self.down_conv(x)
                        x = self.BN(x)
                        x = self.gelu(x)
                        x = x.permute([0, 2, 1])
                        return x

                self.g_block_1 = GlobalBlock(**global_blocks)
                self.l_block_1 = LocalBlock(**local_blocks)

                self.g_block_2 = GlobalBlock(**global_blocks)
                self.l_block_2 = LocalBlock(**local_blocks)

                self.downconv = DownConvLayer(**down_conv_layer) if down_conv == True else None

            def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
                '''
                x: [B, N, T]
                '''
                # Global blocks take [B, N, T] and return [B, N, T]; local
                # blocks work on [B, T, N], hence the permutes around them.
                x = self.g_block_1(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_1(x)
                x = x.permute(0, 2, 1).contiguous()

                x = self.g_block_2(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_2(x)
                x = x.permute(0, 2, 1).contiguous()

                # Full-resolution features are kept as the skip connection;
                # the downsampled tensor feeds the next (coarser) stage.
                skip = x
                if self.downconv:
                    x = x.permute(0, 2, 1).contiguous()
                    x = self.downconv(x)
                    x = x.permute(0, 2, 1).contiguous()
                # [BK, S, N]
                return x, skip

        class SpkSplitStage(torch.nn.Module):
            # Expand channels num_spks-fold via a gated pointwise stack, then
            # fold the speaker axis into the batch axis ([B, spks*F, T] ->
            # [B*spks, F, T]) and normalize.
            def __init__(self, in_channels: int, num_spks: int):
                super().__init__()
                self.linear = torch.nn.Sequential(
                    torch.nn.Conv1d(in_channels, 4*in_channels*num_spks, kernel_size=1),
                    torch.nn.GLU(dim=-2),
                    torch.nn.Conv1d(2*in_channels*num_spks, in_channels*num_spks, kernel_size=1))
                self.norm = torch.nn.GroupNorm(1, in_channels, eps=1e-8)
                self.num_spks = num_spks

            def forward(self, x: torch.Tensor):
                x = self.linear(x)
                B, _, T = x.shape
                x = x.view(B*self.num_spks, -1, T).contiguous()
                x = self.norm(x)
                return x

        class SepDecStage(torch.nn.Module):
            # One expanding stage: three (global -> local -> cross-speaker
            # attention) triplets. Speakers live on the batch axis here.
            def __init__(self, num_spks: int, global_blocks: dict, local_blocks: dict, spk_attention: dict):
                super().__init__()

                self.g_block_1 = GlobalBlock(**global_blocks)
                self.l_block_1 = LocalBlock(**local_blocks)
                self.spk_attn_1 = SpkAttention(**spk_attention)

                self.g_block_2 = GlobalBlock(**global_blocks)
                self.l_block_2 = LocalBlock(**local_blocks)
                self.spk_attn_2 = SpkAttention(**spk_attention)

                self.g_block_3 = GlobalBlock(**global_blocks)
                self.l_block_3 = LocalBlock(**local_blocks)
                self.spk_attn_3 = SpkAttention(**spk_attention)

                self.num_spk = num_spks

            def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
                '''
                x: [B, N, T]
                '''
                # [BS, K, H]
                x = self.g_block_1(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_1(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_1(x, self.num_spk)

                x = self.g_block_2(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_2(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_2(x, self.num_spk)

                x = self.g_block_3(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_3(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_3(x, self.num_spk)

                skip = x

                return x, skip

        self.num_stages = num_stages
        self.pos_emb = RelativePositionalEncoding(**relative_positional_encoding)

        # Temporal Contracting Part
        self.enc_stages = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            self.enc_stages.append(SepEncStage(**enc_stage, down_conv=True))

        self.bottleneck_G = SepEncStage(**enc_stage, down_conv=False)
        self.spk_split_block = SpkSplitStage(**spk_split_stage)

        # Temporal Expanding Part
        self.simple_fusion = torch.nn.ModuleList([])
        self.dec_stages = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            # Fusion conv halves channels after concatenating with the skip.
            self.simple_fusion.append(torch.nn.Conv1d(in_channels=simple_fusion['out_channels']*2, out_channels=simple_fusion['out_channels'], kernel_size=1))
            self.dec_stages.append(SepDecStage(**dec_stage))

    def forward(self, input: torch.Tensor):
        '''input: [B, N, L]'''
        # feature projection
        # Pad the time axis to a multiple of 2**num_stages so each
        # downsampling stage halves it exactly.
        x, _ = self.pad_signal(input)
        len_x = x.shape[-1]
        # Temporal Contracting Part
        # Relative-position table sized for the coarsest (bottleneck) length.
        pos_seq = torch.arange(0, len_x//2**self.num_stages).long().to(x.device)
        pos_seq = pos_seq[:, None] - pos_seq[None, :]
        pos_k, _ = self.pos_emb(pos_seq)
        skip = []
        for idx in range(self.num_stages):
            x, skip_ = self.enc_stages[idx](x, pos_k)
            # Skips are speaker-split so they match the decoder's batch axis.
            skip_ = self.spk_split_block(skip_)
            skip.append(skip_)
        x, _ = self.bottleneck_G(x, pos_k)
        x = self.spk_split_block(x) # B, 2F, T

        each_stage_outputs = []
        # Temporal Expanding Part
        for idx in range(self.num_stages):
            # Stage input is recorded BEFORE upsampling/fusion for aux losses.
            each_stage_outputs.append(x)
            idx_en = self.num_stages - (idx + 1)
            # NOTE(review): F.upsample is deprecated in favor of F.interpolate
            # (same default nearest mode) — consider updating.
            x = torch.nn.functional.upsample(x, skip[idx_en].shape[-1])
            x = torch.cat([x, skip[idx_en]], dim=1)
            x = self.simple_fusion[idx](x)
            x, _ = self.dec_stages[idx](x, pos_k)

        last_stage_output = x
        return last_stage_output, each_stage_outputs

    def pad_signal(self, input: torch.Tensor):
        # (B, T) or (B, 1, T)
        # Normalize to [B, C, T], then right-pad time to a multiple of
        # 2**num_stages. Returns (padded input, number of padded samples).
        if input.dim() == 1: input = input.unsqueeze(0)
        elif input.dim() not in [2, 3]: raise RuntimeError("Input can only be 2 or 3 dimensional.")
        elif input.dim() == 2: input = input.unsqueeze(1)
        L = 2**self.num_stages
        batch_size = input.size(0)
        ndim = input.size(1)
        nframe = input.size(2)
        padded_len = (nframe//L + 1)*L
        rest = 0 if nframe%L == 0 else padded_len - nframe
        if rest > 0:
            pad = torch.autograd.Variable(torch.zeros(batch_size, ndim, rest)).type(input.type()).to(input.device)
            input = torch.cat([input, pad], dim=-1)
        return input, rest
235
+
236
+
237
class OutputLayer(torch.nn.Module):
    """Expand separator features back to encoder width, optionally gate them
    against the encoder output (ReLU mask), and regroup per speaker."""

    def __init__(self, in_channels: int, out_channels: int, num_spks: int, masking: bool = False):
        super().__init__()
        # feature expansion back
        self.masking = masking
        self.spe_block = Masking(in_channels, Activation_mask="ReLU", concat_opt=None)
        self.num_spks = num_spks
        # Gated expansion: F -> 4F -> GLU -> 2F -> N
        self.end_conv1x1 = torch.nn.Sequential(
            torch.nn.Linear(out_channels, 4 * out_channels),
            torch.nn.GLU(),
            torch.nn.Linear(2 * out_channels, in_channels))

    def forward(self, x: torch.Tensor, input: torch.Tensor):
        # Trim to the encoder's time length, then expand channel-wise
        # (linear layers act on the last dim, hence the permutes).
        y = x[..., :input.shape[-1]].permute([0, 2, 1])
        y = self.end_conv1x1(y).permute([0, 2, 1])
        batched, feat, frames = y.shape
        batch = batched // self.num_spks
        if self.masking:
            # Tile the mixture encoding once per speaker and gate it.
            ref = input.expand(self.num_spks, batch, feat, frames).transpose(0, 1).contiguous()
            ref = ref.view(batch * self.num_spks, feat, frames)
            y = self.spe_block(y, ref)
        # [B*spks, N, L] -> [spks, B, N, L]
        y = y.view(batch, self.num_spks, feat, frames)
        return y.transpose(0, 1)
266
+
267
+
268
class AudioDecoder(torch.nn.ConvTranspose1d):
    '''
    Decoder of the TasNet
    This module can be seen as the gradient of Conv1d with respect to its input.
    It is also known as a fractionally-strided convolution
    or a deconvolution (although it is not an actual deconvolution operation).
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        # x: [B, N, L] (batched) or [N, L] (unbatched; a batch dim is added).
        # BUGFIX: instances have no `self.__name__` — the original error path
        # raised AttributeError instead of the intended message. Also the
        # message said "3/4D" while the check accepts 2/3-D tensors.
        if x.dim() not in [2, 3]:
            raise RuntimeError("{} accept 2/3D tensor as input".format(type(self).__name__))
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
        # Drop the singleton channel dim; keep a batch dim for batch size 1.
        x = torch.squeeze(x, dim=1) if torch.squeeze(x).dim() == 1 else torch.squeeze(x)
        return x
models/SepReformer/SepReformer_Large_DM_WHAMR/modules/network.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import numpy
4
+ from utils.decorators import *
5
+
6
+
7
class LayerScale(torch.nn.Module):
    """Learnable per-channel residual scaling (LayerScale).

    `dims` selects how many leading broadcast dimensions the scale carries:
    1 -> (C,), 2 -> (1, C), 3 -> (1, 1, C).
    """

    def __init__(self, dims, input_size, Layer_scale_init=1.0e-5):
        super().__init__()
        shapes = {1: (input_size,), 2: (1, input_size), 3: (1, 1, input_size)}
        shape = shapes.get(dims)
        if shape is not None:
            self.layer_scale = torch.nn.Parameter(
                torch.ones(shape) * Layer_scale_init, requires_grad=True)

    def forward(self, x):
        return x * self.layer_scale
19
+
20
class Masking(torch.nn.Module):
    """Gated masking block: applies an activation to `x` (optionally after
    fusing it with `skip` via a pointwise conv) and multiplies by `skip`.

    Args:
        input_dim: channel count of `x` / `skip`.
        Activation_mask: 'Sigmoid' or 'ReLU' gate activation.
        **options: recognizes `concat_opt`; truthy enables concat+conv fusion.
    """

    def __init__(self, input_dim, Activation_mask='Sigmoid', **options):
        super().__init__()

        self.options = options
        # ROBUSTNESS FIX: use .get() so omitting the `concat_opt` kwarg
        # behaves like concat_opt=None instead of raising KeyError.
        if self.options.get('concat_opt'):
            self.pw_conv = torch.nn.Conv1d(input_dim * 2, input_dim, 1, stride=1, padding=0)

        if Activation_mask == 'Sigmoid':
            self.gate_act = torch.nn.Sigmoid()
        elif Activation_mask == 'ReLU':
            self.gate_act = torch.nn.ReLU()

    def forward(self, x, skip):
        if self.options.get('concat_opt'):
            # Fuse mask features with the reference before gating.
            y = torch.cat([x, skip], dim=-2)
            y = self.pw_conv(y)
        else:
            y = x
        y = self.gate_act(y) * skip

        return y
44
+
45
+
46
class GCFN(torch.nn.Module):
    """Gated convolutional feed-forward network with a LayerScale residual.

    Pre-norm expansion to 6C, a depthwise 3-tap conv over time, GLU gating
    back to 3C, projection to C, and a scaled residual connection.
    """

    def __init__(self, in_channels, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        self.net1 = torch.nn.Sequential(
            torch.nn.LayerNorm(in_channels),
            torch.nn.Linear(in_channels, in_channels * 6))
        self.depthwise = torch.nn.Conv1d(
            in_channels * 6, in_channels * 6, 3, padding=1, groups=in_channels * 6)
        self.net2 = torch.nn.Sequential(
            torch.nn.GLU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(in_channels * 3, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        # [B, T, C] -> expand -> depthwise conv over time ([B, 6C, T]) -> gate.
        h = self.net1(x)
        h = self.depthwise(h.permute(0, 2, 1).contiguous())
        h = self.net2(h.permute(0, 2, 1).contiguous())
        return x + self.Layer_scale(h)
67
+
68
+
69
class MultiHeadAttention(torch.nn.Module):
    """
    Multi-Head Attention layer with optional learned relative-position bias.
    :param int n_head: the number of heads
    :param int in_channels: the number of features
    :param float dropout_rate: dropout rate
    :param float Layer_scale_init: initial value for the LayerScale residual gain
    """
    def __init__(self, n_head: int, in_channels: int, dropout_rate: float, Layer_scale_init=1.0e-5):
        super().__init__()
        assert in_channels % n_head == 0
        self.d_k = in_channels // n_head # We assume d_v always equals d_k
        self.h = n_head
        # Pre-norm: input is normalized before the q/k/v projections.
        self.layer_norm = torch.nn.LayerNorm(in_channels)
        self.linear_q = torch.nn.Linear(in_channels, in_channels)
        self.linear_k = torch.nn.Linear(in_channels, in_channels)
        self.linear_v = torch.nn.Linear(in_channels, in_channels)
        self.linear_out = torch.nn.Linear(in_channels, in_channels)
        # Last attention map is stored for inspection/debugging.
        self.attn = None
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x, pos_k, mask):
        """
        Compute 'Scaled Dot Product Attention'.
        :param torch.Tensor x: input features (batch, time, in_channels)
        :param torch.Tensor pos_k: relative-position key embeddings, or None
        :param torch.Tensor mask: (batch, time1, time2), or None
        :return torch.Tensor: attentioned and transformed `value` (batch, time1, d_model)
            weighted by the query dot key attention (batch, head, time1, time2)
        """
        n_batch = x.size(0)
        x = self.layer_norm(x)
        q = self.linear_q(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
        k = self.linear_k(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
        v = self.linear_v(x).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2) # (batch, head, time2, d_k)
        v = v.transpose(1, 2) # (batch, head, time2, d_k)
        # Content-content term.
        A = torch.matmul(q, k.transpose(-2, -1))
        reshape_q = q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0,1)
        if pos_k is not None:
            # Content-position term (relative attention bias, Transformer-XL style).
            B = torch.matmul(reshape_q, pos_k.transpose(-2, -1))
            B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1))
            scores = (A + B) / math.sqrt(self.d_k)
        else:
            scores = A / math.sqrt(self.d_k)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
            # Fill masked positions with the dtype's most negative value so
            # softmax assigns them ~zero weight, then zero them explicitly.
            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
        return self.Layer_scale(self.dropout(self.linear_out(x))) # (batch, time1, d_model)
125
+
126
class EGA(torch.nn.Module):
    """Efficient Global Attention.

    Pools the time axis down to the relative-position table length, runs
    self-attention on the short sequence, upsamples the result back, and adds
    it to the input through a learned sigmoid gate.
    """

    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'self_attn': MultiHeadAttention(
                n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            'linear': torch.nn.Sequential(
                torch.nn.LayerNorm(normalized_shape=in_channels),
                torch.nn.Linear(in_features=in_channels, out_features=in_channels),
                torch.nn.Sigmoid())
        })

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """
        Compute encoded features.
        :param torch.Tensor x: encoded source features (batch, size, time)
        :param torch.Tensor pos_k: relative-position key embeddings; its first
            dim fixes the pooled sequence length
        :rtype: torch.Tensor (batch, time, size)
        """
        down_len = pos_k.shape[0]
        # Pool the time axis to match the position table.
        x_down = torch.nn.functional.adaptive_avg_pool1d(input=x, output_size=down_len)
        x = x.permute([0, 2, 1])
        x_down = x_down.permute([0, 2, 1])
        x_down = self.block['self_attn'](x_down, pos_k, None)
        x_down = x_down.permute([0, 2, 1])
        # FIX: F.upsample is deprecated; F.interpolate is the drop-in
        # replacement with the same default (nearest) mode.
        x_downup = torch.nn.functional.interpolate(input=x_down, size=x.shape[1])
        x_downup = x_downup.permute([0, 2, 1])
        # Gated residual: sigmoid(linear(x)) modulates the global context.
        x = x + self.block['linear'](x) * x_downup

        return x
156
+
157
+
158
+
159
class CLA(torch.nn.Module):
    """Convolutional local attention block: gated expansion, depthwise conv
    over time, BatchNorm'd expansion, projection, and a LayerScale residual."""

    def __init__(self, in_channels, kernel_size, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(in_channels)
        self.linear1 = torch.nn.Linear(in_channels, in_channels * 2)
        self.GLU = torch.nn.GLU()
        self.dw_conv_1d = torch.nn.Conv1d(
            in_channels, in_channels, kernel_size, padding='same', groups=in_channels)
        self.linear2 = torch.nn.Linear(in_channels, 2 * in_channels)
        self.BN = torch.nn.BatchNorm1d(2 * in_channels)
        self.linear3 = torch.nn.Sequential(
            torch.nn.GELU(),
            torch.nn.Linear(2 * in_channels, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        # Pre-norm gated expansion: C -> 2C -> GLU -> C  (operates on [B, T, C]).
        y = self.GLU(self.linear1(self.layer_norm(x)))
        # Depthwise conv needs [B, C, T].
        y = self.dw_conv_1d(y.permute([0, 2, 1]))
        # Back to [B, T, C] for the linear expansion to 2C.
        y = self.linear2(y.permute(0, 2, 1))
        # BatchNorm1d normalizes over [B, 2C, T]; restore [B, T, 2C] after.
        y = self.BN(y.permute(0, 2, 1)).permute(0, 2, 1)
        y = self.linear3(y)

        return x + self.Layer_scale(y)
188
+
189
class GlobalBlock(torch.nn.Module):
    """Global transformer block: EGA attention followed by a GCFN FFN.

    Takes [B, C, T] and returns the processed features permuted to [B, C, T]
    (the GCFN output [B, T, C] is transposed before returning).
    """

    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'ega': EGA(
                num_mha_heads=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            'gcfn': GCFN(in_channels=in_channels, dropout_rate=dropout_rate)
        })

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """
        Compute encoded features.
        :param torch.Tensor x: encoded source features (batch, size, time)
        :param torch.Tensor pos_k: relative-position key embeddings
        :rtype: torch.Tensor
        """
        h = self.block['ega'](x, pos_k)
        h = self.block['gcfn'](h)
        return h.permute([0, 2, 1])
210
+
211
+
212
class LocalBlock(torch.nn.Module):
    """Local transformer block: CLA (convolutional local attention) followed
    by a GCFN feed-forward network. Operates on [B, T, C]."""

    def __init__(self, in_channels: int, kernel_size: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'cla': CLA(in_channels, kernel_size, dropout_rate),
            'gcfn': GCFN(in_channels, dropout_rate)
        })

    def forward(self, x: torch.Tensor):
        return self.block['gcfn'](self.block['cla'](x))
225
+
226
+
227
class SpkAttention(torch.nn.Module):
    """Cross-speaker attention: attends over the speaker axis at every time
    step (speakers are folded into the batch dimension of the input)."""

    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.self_attn = MultiHeadAttention(n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate)
        self.feed_forward = GCFN(in_channels=in_channels, dropout_rate=dropout_rate)

    def forward(self, x: torch.Tensor, num_spk: int):
        """
        Compute encoded features.
        :param torch.Tensor x: features (batch*num_spk, size, time) — the
            speaker axis is folded into the batch axis
        :param int num_spk: number of speakers folded into the batch
        :rtype: torch.Tensor of the same shape as x
        """
        B, F, T = x.shape
        # Unfold speakers, then fold (batch, time) together so attention runs
        # over the length-num_spk speaker sequence at each time step.
        x = x.view(B//num_spk, num_spk, F, T).contiguous()
        x = x.permute([0, 3, 1, 2]).contiguous()
        x = x.view(-1, num_spk, F).contiguous()
        # Residual self-attention across speakers (no positions, no mask).
        x = x + self.self_attn(x, None, None)
        # Restore the original (batch*spk, F, T) layout.
        x = x.view(B//num_spk, T, num_spk, F).contiguous()
        x = x.permute([0, 2, 3, 1]).contiguous()
        x = x.view(B, F, T).contiguous()
        # GCFN expects [B, T, F]; permute around it.
        x = x.permute([0, 2, 1])
        x = self.feed_forward(x)
        x = x.permute([0, 2, 1])

        return x
models/SepReformer/SepReformer_Large_DM_WSJ0/configs.yaml ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ config:
2
+ dataset:
3
+ max_len : 32000
4
+ sampling_rate: 8000
5
+ scp_dir: "data/scp_ss_8k"
6
+ train:
7
+ mixture: "tr_mix.scp"
8
+ spk1: "tr_s1.scp"
9
+ spk2: "tr_s2.scp"
10
+ dynamic_mixing: true
11
+ valid:
12
+ mixture: "cv_mix.scp"
13
+ spk1: "cv_s1.scp"
14
+ spk2: "cv_s2.scp"
15
+ test:
16
+ mixture: "tt_mix.scp"
17
+ spk1: "tt_s1.scp"
18
+ spk2: "tt_s2.scp"
19
+ dataloader:
20
+ batch_size: 2
21
+ pin_memory: false
22
+ num_workers: 12
23
+ drop_last: false
24
+ model:
25
+ num_stages: &var_model_num_stages 4 # R
26
+ num_spks: &var_model_num_spks 2
27
+ module_audio_enc:
28
+ in_channels: 1
29
+ out_channels: &var_model_audio_enc_out_channels 256
30
+ kernel_size: &var_model_audio_enc_kernel_size 16 # L
31
+ stride: &var_model_audio_enc_stride 4 # S
32
+ groups: 1
33
+ bias: false
34
+ module_feature_projector:
35
+ num_channels: *var_model_audio_enc_out_channels
36
+ in_channels: *var_model_audio_enc_out_channels
37
+ out_channels: &feature_projector_out_channels 256 # F
38
+ kernel_size: 1
39
+ bias: false
40
+ module_separator:
41
+ num_stages: *var_model_num_stages
42
+ relative_positional_encoding:
43
+ in_channels: *feature_projector_out_channels
44
+ num_heads: 8
45
+ maxlen: 2000
46
+ embed_v: false
47
+ enc_stage:
48
+ global_blocks:
49
+ in_channels: *feature_projector_out_channels
50
+ num_mha_heads: 8
51
+ dropout_rate: 0.1
52
+ local_blocks:
53
+ in_channels: *feature_projector_out_channels
54
+ kernel_size: 65
55
+ dropout_rate: 0.1
56
+ down_conv_layer:
57
+ in_channels: *feature_projector_out_channels
58
+ samp_kernel_size: &var_model_samp_kernel_size 5
59
+ spk_split_stage:
60
+ in_channels: *feature_projector_out_channels
61
+ num_spks: *var_model_num_spks
62
+ simple_fusion:
63
+ out_channels: *feature_projector_out_channels
64
+ dec_stage:
65
+ num_spks: *var_model_num_spks
66
+ global_blocks:
67
+ in_channels: *feature_projector_out_channels
68
+ num_mha_heads: 8
69
+ dropout_rate: 0.1
70
+ local_blocks:
71
+ in_channels: *feature_projector_out_channels
72
+ kernel_size: 65
73
+ dropout_rate: 0.1
74
+ spk_attention:
75
+ in_channels: *feature_projector_out_channels
76
+ num_mha_heads: 8
77
+ dropout_rate: 0.1
78
+ module_output_layer:
79
+ in_channels: *var_model_audio_enc_out_channels
80
+ out_channels: *feature_projector_out_channels
81
+ num_spks: *var_model_num_spks
82
+ module_audio_dec:
83
+ in_channels: *var_model_audio_enc_out_channels
84
+ out_channels: 1
85
+ kernel_size: *var_model_audio_enc_kernel_size
86
+ stride: *var_model_audio_enc_stride
87
+ bias: false
88
+ criterion: ### Ref: https://pytorch.org/docs/stable/nn.html#loss-functions
89
+ name: ["PIT_SISNR_mag", "PIT_SISNR_time", "PIT_SISNRi", "PIT_SDRi"] ### Choose a torch.nn's loss function class(=attribute) e.g. ["L1Loss", "MSELoss", "CrossEntropyLoss", ...] / You can also build your optimizer :)
90
+ PIT_SISNR_mag:
91
+ frame_length: 512
92
+ frame_shift: 128
93
+ window: 'hann'
94
+ num_stages: *var_model_num_stages
95
+ num_spks: *var_model_num_spks
96
+ scale_inv: true
97
+ mel_opt: false
98
+ PIT_SISNR_time:
99
+ num_spks: *var_model_num_spks
100
+ scale_inv: true
101
+ PIT_SISNRi:
102
+ num_spks: *var_model_num_spks
103
+ scale_inv: true
104
+ PIT_SDRi:
105
+ dump: 0
106
+ optimizer: ### Ref: https://pytorch.org/docs/stable/optim.html#algorithms
107
+ name: ["AdamW"] ### Choose a torch.optim's class(=attribute) e.g. ["Adam", "AdamW", "SGD", ...] / You can also build your optimizer :)
108
+ AdamW:
109
+ lr: 2.0e-4
110
+ weight_decay: 1.0e-2
111
+ scheduler: ### Ref(+ find "How to adjust learning rate"): https://pytorch.org/docs/stable/optim.html#algorithms
112
+ name: ["ReduceLROnPlateau", "WarmupConstantSchedule"] ### Choose a torch.optim.lr_scheduler's class(=attribute) e.g. ["StepLR", "ReduceLROnPlateau", "Custom"] / You can also build your scheduler :)
113
+ ReduceLROnPlateau:
114
+ mode: "min"
115
+ min_lr: 1.0e-10
116
+ factor: 0.8
117
+ patience: 2
118
+ WarmupConstantSchedule:
119
+ warmup_steps: 1000
120
+ check_computations:
121
+ dummy_len: 16000
122
+ engine:
123
+ max_epoch: 200
124
+ gpuid: "0" ### "0"(single-gpu) or "0, 1" (multi-gpu)
125
+ mvn: false
126
+ clip_norm: 5
127
+ start_scheduling: 50
128
+ test_epochs: [50, 80, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 199]
models/SepReformer/SepReformer_Large_DM_WSJ0/dataset.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import librosa as audio_lib
5
+ import numpy as np
6
+
7
+ from utils import util_dataset
8
+ from utils.decorators import *
9
+ from loguru import logger
10
+ from torch.utils.data import Dataset, DataLoader
11
+
12
+ @logger_wraps()
13
+ def get_dataloaders(args, dataset_config, loader_config):
14
+ # create dataset object for each partition
15
+ partitions = ["test"] if "test" in args.engine_mode else ["train", "valid", "test"]
16
+ dataloaders = {}
17
+ for partition in partitions:
18
+ scp_config_mix = os.path.join(dataset_config["scp_dir"], dataset_config[partition]['mixture'])
19
+ scp_config_spk = [os.path.join(dataset_config["scp_dir"], dataset_config[partition][spk_key]) for spk_key in dataset_config[partition] if spk_key.startswith('spk')]
20
+ dynamic_mixing = dataset_config[partition]["dynamic_mixing"] if partition == 'train' else False
21
+ dataset = MyDataset(
22
+ max_len = dataset_config['max_len'],
23
+ fs = dataset_config['sampling_rate'],
24
+ partition = partition,
25
+ wave_scp_srcs = scp_config_spk,
26
+ wave_scp_mix = scp_config_mix,
27
+ dynamic_mixing = dynamic_mixing)
28
+ dataloader = DataLoader(
29
+ dataset = dataset,
30
+ batch_size = 1 if partition == 'test' else loader_config["batch_size"],
31
+ shuffle = True, # only train: (partition == 'train') / all: True
32
+ pin_memory = loader_config["pin_memory"],
33
+ num_workers = loader_config["num_workers"],
34
+ drop_last = loader_config["drop_last"],
35
+ collate_fn = _collate)
36
+ dataloaders[partition] = dataloader
37
+ return dataloaders
38
+
39
+
40
+ def _collate(egs):
41
+ """
42
+ Transform utterance index into a minbatch
43
+
44
+ Arguments:
45
+ index: a list type [{},{},{}]
46
+
47
+ Returns:
48
+ input_sizes: a tensor correspond to utterance length
49
+ input_feats: packed sequence to feed networks
50
+ source_attr/target_attr: dictionary contains spectrogram/phase needed in loss computation
51
+ """
52
+ def __prepare_target_rir(dict_lsit, index):
53
+ return torch.nn.utils.rnn.pad_sequence([torch.tensor(d["src"][index], dtype=torch.float32) for d in dict_lsit], batch_first=True)
54
+ if type(egs) is not list: raise ValueError("Unsupported index type({})".format(type(egs)))
55
+ num_spks = 2 # you need to set this paramater by yourself
56
+ dict_list = sorted([eg for eg in egs], key=lambda x: x['num_sample'], reverse=True)
57
+ mixture = torch.nn.utils.rnn.pad_sequence([torch.tensor(d['mix'], dtype=torch.float32) for d in dict_list], batch_first=True)
58
+ src = [__prepare_target_rir(dict_list, index) for index in range(num_spks)]
59
+ input_sizes = torch.tensor([d['num_sample'] for d in dict_list], dtype=torch.float32)
60
+ key = [d['key'] for d in dict_list]
61
+ return input_sizes, mixture, src, key
62
+
63
+
64
+ @logger_wraps()
65
+ class MyDataset(Dataset):
66
+ def __init__(self, max_len, fs, partition, wave_scp_srcs, wave_scp_mix, dynamic_mixing, speed_list=None):
67
+ self.partition = partition
68
+ for wave_scp_src in wave_scp_srcs:
69
+ if not os.path.exists(wave_scp_src): raise FileNotFoundError(f"Could not find file {wave_scp_src}")
70
+ self.max_len = max_len
71
+ self.fs = fs
72
+ self.wave_dict_srcs = [util_dataset.parse_scps(wave_scp_src) for wave_scp_src in wave_scp_srcs]
73
+ self.wave_dict_mix = util_dataset.parse_scps(wave_scp_mix)
74
+ self.wave_keys = list(self.wave_dict_mix.keys())
75
+ logger.info(f"Create MyDataset for {wave_scp_mix} with {len(self.wave_dict_mix)} utterances")
76
+ self.dynamic_mixing = dynamic_mixing
77
+
78
+ def __len__(self):
79
+ return len(self.wave_dict_mix)
80
+
81
+ def __contains__(self, key):
82
+ return key in self.wave_dict_mix
83
+
84
+ def _dynamic_mixing(self, key):
85
+ def __match_length(wav, len_data) :
86
+ leftover = len(wav) - len_data
87
+ idx = random.randint(0,leftover)
88
+ wav = wav[idx:idx+len_data]
89
+ return wav
90
+
91
+ samps_src = []
92
+ src_len = []
93
+ # dyanmic source choice
94
+ # checking whether it is the same speaker
95
+ while True:
96
+ key_random = random.choice(list(self.wave_dict_srcs[0].keys()))
97
+ tmp1 = key.split('_')[1][:3] != key_random.split('_')[3][:3]
98
+ tmp2 = key.split('_')[3][:3] != key_random.split('_')[1][:3]
99
+ if tmp1 and tmp2: break
100
+
101
+ idx1, idx2 = (0, 1) if random.random() > 0.5 else (1, 0)
102
+ files = [self.wave_dict_srcs[idx1][key], self.wave_dict_srcs[idx2][key_random]]
103
+
104
+ # load
105
+ for idx, file in enumerate(files):
106
+ if not os.path.exists(file): raise FileNotFoundError("Input file {} do not exists!".format(file))
107
+ samps_tmp, _ = audio_lib.load(file, sr=self.fs)
108
+
109
+ if idx == 0: ref_rms = np.sqrt(np.mean(np.square(samps_tmp)))
110
+ curr_rms = np.sqrt(np.mean(np.square(samps_tmp)))
111
+
112
+ norm_factor = ref_rms / curr_rms
113
+ samps_tmp *= norm_factor
114
+
115
+ # mixing with random gains
116
+ gain = pow(10,-random.uniform(-5,5)/20)
117
+ samps_tmp = np.array(torch.tensor(samps_tmp))
118
+ samps_src.append(gain*samps_tmp)
119
+ src_len.append(len(samps_tmp))
120
+
121
+ # matching the audio length
122
+ min_len = min(src_len)
123
+
124
+ # add noise source dynamically if needed
125
+ samps_src = [__match_length(s, min_len) for s in samps_src]
126
+ samps_mix = sum(samps_src)
127
+
128
+ # ! truncated along to the sample Length "L"
129
+ if len(samps_mix)%4 != 0:
130
+ remains = len(samps_mix)%4
131
+ samps_mix = samps_mix[:-remains]
132
+ samps_src = [s[:-remains] for s in samps_src]
133
+
134
+ if self.partition != "test":
135
+ if len(samps_mix) > self.max_len:
136
+ start = random.randint(0, len(samps_mix)-self.max_len)
137
+ samps_mix = samps_mix[start:start+self.max_len]
138
+ samps_src = [s[start:start+self.max_len] for s in samps_src]
139
+ return samps_mix, samps_src
140
+
141
+ def _direct_load(self, key):
142
+ samps_src = []
143
+ files = [wave_dict_src[key] for wave_dict_src in self.wave_dict_srcs]
144
+ for file in files:
145
+ if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
146
+ samps_tmp, _ = audio_lib.load(file, sr=self.fs)
147
+ samps_src.append(samps_tmp)
148
+
149
+ file = self.wave_dict_mix[key]
150
+ if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
151
+ samps_mix, _ = audio_lib.load(file, sr=self.fs)
152
+
153
+ # Truncate samples as needed
154
+ if len(samps_mix) % 4 != 0:
155
+ remains = len(samps_mix) % 4
156
+ samps_mix = samps_mix[:-remains]
157
+ samps_src = [s[:-remains] for s in samps_src]
158
+
159
+ if self.partition != "test":
160
+ if len(samps_mix) > self.max_len:
161
+ start = random.randint(0,len(samps_mix)-self.max_len)
162
+ samps_mix = samps_mix[start:start+self.max_len]
163
+ samps_src = [s[start:start+self.max_len] for s in samps_src]
164
+
165
+ return samps_mix, samps_src
166
+
167
+ def __getitem__(self, index):
168
+ key = self.wave_keys[index]
169
+ if any(key not in self.wave_dict_srcs[i] for i in range(len(self.wave_dict_srcs))) or key not in self.wave_dict_mix: raise KeyError(f"Could not find utterance {key}")
170
+ samps_mix, samps_src = self._dynamic_mixing(key) if self.dynamic_mixing else self._direct_load(key)
171
+ return {"num_sample": samps_mix.shape[0], "mix": samps_mix, "src": samps_src, "key": key}
models/SepReformer/SepReformer_Large_DM_WSJ0/engine.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import csv
4
+ import time
5
+ import soundfile as sf
6
+
7
+ from loguru import logger
8
+ from tqdm import tqdm
9
+ from utils import util_engine, functions
10
+ from utils.decorators import *
11
+ from torch.utils.tensorboard import SummaryWriter
12
+
13
+
14
+ @logger_wraps()
15
+ class Engine(object):
16
+ def __init__(self, args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device):
17
+
18
+ ''' Default setting '''
19
+ self.engine_mode = args.engine_mode
20
+ self.out_wav_dir = args.out_wav_dir
21
+ self.config = config
22
+ self.gpuid = gpuid
23
+ self.device = device
24
+ self.model = model.to(self.device)
25
+ self.dataloaders = dataloaders # self.dataloaders['train'] or ['valid'] or ['test']
26
+ self.PIT_SISNR_mag_loss, self.PIT_SISNR_time_loss, self.PIT_SISNRi_loss, self.PIT_SDRi_loss = criterions
27
+ self.main_optimizer = optimizers[0]
28
+ self.main_scheduler, self.warmup_scheduler = schedulers
29
+
30
+ self.pretrain_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "pretrain_weights")
31
+ os.makedirs(self.pretrain_weights_path, exist_ok=True)
32
+ self.scratch_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "scratch_weights")
33
+ os.makedirs(self.scratch_weights_path, exist_ok=True)
34
+
35
+ self.checkpoint_path = self.pretrain_weights_path if any(file.endswith(('.pt', '.pt', '.pkl')) for file in os.listdir(self.pretrain_weights_path)) else self.scratch_weights_path
36
+ self.start_epoch = util_engine.load_last_checkpoint_n_get_epoch(self.checkpoint_path, self.model, self.main_optimizer, location=self.device)
37
+
38
+ # Logging
39
+ util_engine.model_params_mac_summary(
40
+ model=self.model,
41
+ input=torch.randn(1, self.config['check_computations']['dummy_len']).to(self.device),
42
+ dummy_input=torch.rand(1, self.config['check_computations']['dummy_len']).to(self.device),
43
+ metrics=['ptflops', 'thop', 'torchinfo']
44
+ # metrics=['ptflops']
45
+ )
46
+
47
+ logger.info(f"Clip gradient by 2-norm {self.config['engine']['clip_norm']}")
48
+
49
+ @logger_wraps()
50
+ def _train(self, dataloader, epoch):
51
+ self.model.train()
52
+ tot_loss_freq = [0 for _ in range(self.model.num_stages)]
53
+ tot_loss_time, num_batch = 0, 0
54
+ pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:25}{r_bar}{bar:-10b}', colour="YELLOW", dynamic_ncols=True)
55
+ for input_sizes, mixture, src, _ in dataloader:
56
+ nnet_input = mixture
57
+ nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
58
+ num_batch += 1
59
+ pbar.update(1)
60
+ # Scheduler learning rate for warm-up (Iteration-based update for transformers)
61
+ if epoch == 1: self.warmup_scheduler.step()
62
+ nnet_input = nnet_input.to(self.device)
63
+ self.main_optimizer.zero_grad()
64
+ estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
65
+ cur_loss_s_bn = 0
66
+ cur_loss_s_bn = []
67
+ for idx, estim_src_value in enumerate(estim_src_bn):
68
+ cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
69
+ tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
70
+ cur_loss_s = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
71
+ tot_loss_time += cur_loss_s.item() / self.config['model']['num_spks']
72
+ alpha = 0.4 * 0.8**(1+(epoch-101)//5) if epoch > 100 else 0.4
73
+ cur_loss = (1-alpha) * cur_loss_s + alpha * sum(cur_loss_s_bn) / len(cur_loss_s_bn)
74
+ cur_loss = cur_loss / self.config['model']['num_spks']
75
+ cur_loss.backward()
76
+ if self.config['engine']['clip_norm']: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['engine']['clip_norm'])
77
+ self.main_optimizer.step()
78
+ dict_loss = {"T_Loss": tot_loss_time / num_batch}
79
+ dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
80
+ pbar.set_postfix(dict_loss)
81
+ pbar.close()
82
+ tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
83
+ return tot_loss_time / num_batch, tot_loss_freq / num_batch, num_batch
84
+
85
+ @logger_wraps()
86
+ def _validate(self, dataloader):
87
+ self.model.eval()
88
+ tot_loss_freq = [0 for _ in range(self.model.num_stages)]
89
+ tot_loss_time, num_batch = 0, 0
90
+ pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="RED", dynamic_ncols=True)
91
+ with torch.inference_mode():
92
+ for input_sizes, mixture, src, _ in dataloader:
93
+ nnet_input = mixture
94
+ nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
95
+ nnet_input = nnet_input.to(self.device)
96
+ num_batch += 1
97
+ pbar.update(1)
98
+ estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
99
+ cur_loss_s_bn = []
100
+ for idx, estim_src_value in enumerate(estim_src_bn):
101
+ cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
102
+ tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
103
+ cur_loss_s_SDR = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
104
+ tot_loss_time += cur_loss_s_SDR.item() / self.config['model']['num_spks']
105
+ dict_loss = {"T_Loss":tot_loss_time / num_batch}
106
+ dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
107
+ pbar.set_postfix(dict_loss)
108
+ pbar.close()
109
+ tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
110
+ return tot_loss_time / num_batch, tot_loss_freq / num_batch, num_batch
111
+
112
+ @logger_wraps()
113
+ def _test(self, dataloader, wav_dir=None):
114
+ self.model.eval()
115
+ total_loss_SISNRi, total_loss_SDRi, num_batch = 0, 0, 0
116
+ pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="grey", dynamic_ncols=True)
117
+ with torch.inference_mode():
118
+ csv_file_name_sisnr = os.path.join(os.path.dirname(__file__),'test_SISNRi_value.csv')
119
+ csv_file_name_sdr = os.path.join(os.path.dirname(__file__),'test_SDRi_value.csv')
120
+ with open(csv_file_name_sisnr, 'w', newline='') as csvfile_sisnr, open(csv_file_name_sdr, 'w', newline='') as csvfile_sdr:
121
+ idx = 0
122
+ writer_sisnr = csv.writer(csvfile_sisnr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
123
+ writer_sdr = csv.writer(csvfile_sdr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
124
+ for input_sizes, mixture, src, key in dataloader:
125
+ if len(key) > 1:
126
+ raise("batch size is not one!!")
127
+ nnet_input = mixture.to(self.device)
128
+ num_batch += 1
129
+ pbar.update(1)
130
+ estim_src, _ = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
131
+ cur_loss_SISNRi, cur_loss_SISNRi_src = self.PIT_SISNRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src, eps=1.0e-15)
132
+ total_loss_SISNRi += cur_loss_SISNRi.item() / self.config['model']['num_spks']
133
+ cur_loss_SDRi, cur_loss_SDRi_src = self.PIT_SDRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src)
134
+ total_loss_SDRi += cur_loss_SDRi.item() / self.config['model']['num_spks']
135
+ writer_sisnr.writerow([key[0][:-4]] + [cur_loss_SISNRi_src[i].item() for i in range(self.config['model']['num_spks'])])
136
+ writer_sdr.writerow([key[0][:-4]] + [cur_loss_SDRi_src[i].item() for i in range(self.config['model']['num_spks'])])
137
+ if self.engine_mode == "test_save":
138
+ if wav_dir == None: wav_dir = os.path.join(os.path.dirname(__file__),"wav_out")
139
+ if wav_dir and not os.path.exists(wav_dir): os.makedirs(wav_dir)
140
+ mixture = torch.squeeze(mixture).cpu().data.numpy()
141
+ sf.write(os.path.join(wav_dir,key[0][:-4]+str(idx)+'_mixture.wav'), 0.5*mixture/max(abs(mixture)), 8000)
142
+ for i in range(self.config['model']['num_spks']):
143
+ src = torch.squeeze(estim_src[i]).cpu().data.numpy()
144
+ sf.write(os.path.join(wav_dir,key[0][:-4]+str(idx)+'_out_'+str(i)+'.wav'), 0.5*src/max(abs(src)), 8000)
145
+ idx += 1
146
+ dict_loss = {"SiSNRi": total_loss_SISNRi/num_batch, "SDRi": total_loss_SDRi/num_batch}
147
+ pbar.set_postfix(dict_loss)
148
+ pbar.close()
149
+ return total_loss_SISNRi/num_batch, total_loss_SDRi/num_batch, num_batch
150
+
151
+ @logger_wraps()
152
+ def run(self):
153
+ with torch.cuda.device(self.device):
154
+ writer_src = SummaryWriter(os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/tensorboard"))
155
+ if "test" in self.engine_mode:
156
+ on_test_start = time.time()
157
+ test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'], self.out_wav_dir)
158
+ on_test_end = time.time()
159
+ logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
160
+ logger.info(f"Testing done!")
161
+ else:
162
+ start_time = time.time()
163
+ if self.start_epoch > 1:
164
+ init_loss_time, init_loss_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
165
+ else:
166
+ init_loss_time, init_loss_freq = 0, 0
167
+ end_time = time.time()
168
+ logger.info(f"[INIT] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: Loss_t = {init_loss_time:.4f} dB | Loss_f = {init_loss_freq:.4f} dB | Speed = ({end_time-start_time:.2f}s)")
169
+ for epoch in range(self.start_epoch, self.config['engine']['max_epoch']):
170
+ valid_loss_best = init_loss_time
171
+ train_start_time = time.time()
172
+ train_loss_src_time, train_loss_src_freq, train_num_batch = self._train(self.dataloaders['train'], epoch)
173
+ train_end_time = time.time()
174
+ valid_start_time = time.time()
175
+ valid_loss_src_time, valid_loss_src_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
176
+ valid_end_time = time.time()
177
+ if epoch > self.config['engine']['start_scheduling']: self.main_scheduler.step(valid_loss_src_time)
178
+ logger.info(f"[TRAIN] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {train_loss_src_time:.4f} dB | Loss_f = {train_loss_src_freq:.4f} dB | Speed = ({train_end_time - train_start_time:.2f}s/{train_num_batch:d})")
179
+ logger.info(f"[VALID] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {valid_loss_src_time:.4f} dB | Loss_f = {valid_loss_src_freq:.4f} dB | Speed = ({valid_end_time - valid_start_time:.2f}s/{valid_num_batch:d})")
180
+ if epoch in self.config['engine']['test_epochs']:
181
+ on_test_start = time.time()
182
+ test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'])
183
+ on_test_end = time.time()
184
+ logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
185
+ valid_loss_best = util_engine.save_checkpoint_per_best(valid_loss_best, valid_loss_src_time, train_loss_src_time, epoch, self.model, self.main_optimizer, self.checkpoint_path)
186
+ # Logging to monitoring tools (Tensorboard && Wandb)
187
+ writer_src.add_scalars("Metrics", {
188
+ 'Loss_train_time': train_loss_src_time,
189
+ 'Loss_valid_time': valid_loss_src_time}, epoch)
190
+ writer_src.add_scalar("Learning Rate", self.main_optimizer.param_groups[0]['lr'], epoch)
191
+ writer_src.flush()
192
+ logger.info(f"Training for {self.config['engine']['max_epoch']} epoches done!")
models/SepReformer/SepReformer_Large_DM_WSJ0/main.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from loguru import logger
4
+ from .dataset import get_dataloaders
5
+ from .model import Model
6
+ from .engine import Engine
7
+ from utils import util_system, util_implement
8
+ from utils.decorators import *
9
+
10
+ # Setup logger
11
+ log_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/system_log.log")
12
+ logger.add(log_file_path, level="DEBUG", mode="w")
13
+
14
+ @logger_wraps()
15
+ def main(args):
16
+
17
+ ''' Build Setting '''
18
+ # Call configuration file (configs.yaml)
19
+ yaml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs.yaml")
20
+ yaml_dict = util_system.parse_yaml(yaml_path)
21
+
22
+ # Run wandb and get configuration
23
+ config = yaml_dict["config"] # wandb login success or fail
24
+
25
+ # Call DataLoader [train / valid / test / etc...]
26
+ dataloaders = get_dataloaders(args, config["dataset"], config["dataloader"])
27
+
28
+ ''' Build Model '''
29
+ # Call network model
30
+ model = Model(**config["model"])
31
+
32
+ ''' Build Engine '''
33
+ # Call gpu id & device
34
+ gpuid = tuple(map(int, config["engine"]["gpuid"].split(',')))
35
+ device = torch.device(f'cuda:{gpuid[0]}')
36
+
37
+ # Call Implement [criterion / optimizer / scheduler]
38
+ criterions = util_implement.CriterionFactory(config["criterion"], device).get_criterions()
39
+ optimizers = util_implement.OptimizerFactory(config["optimizer"], model.parameters()).get_optimizers()
40
+ schedulers = util_implement.SchedulerFactory(config["scheduler"], optimizers).get_schedulers()
41
+
42
+ # Call & Run Engine
43
+ engine = Engine(args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device)
44
+ engine.run()
models/SepReformer/SepReformer_Large_DM_WSJ0/model.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('../')
3
+
4
+ import torch
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ from utils.decorators import *
9
+ from .modules.module import *
10
+
11
+
12
+ @logger_wraps()
13
+ class Model(torch.nn.Module):
14
+ def __init__(self,
15
+ num_stages: int,
16
+ num_spks: int,
17
+ module_audio_enc: dict,
18
+ module_feature_projector: dict,
19
+ module_separator: dict,
20
+ module_output_layer: dict,
21
+ module_audio_dec: dict):
22
+ super().__init__()
23
+ self.num_stages = num_stages
24
+ self.num_spks = num_spks
25
+ self.audio_encoder = AudioEncoder(**module_audio_enc)
26
+ self.feature_projector = FeatureProjector(**module_feature_projector)
27
+ self.separator = Separator(**module_separator)
28
+ self.out_layer = OutputLayer(**module_output_layer)
29
+ self.audio_decoder = AudioDecoder(**module_audio_dec)
30
+
31
+ # Aux_loss
32
+ self.out_layer_bn = torch.nn.ModuleList([])
33
+ self.decoder_bn = torch.nn.ModuleList([])
34
+ for _ in range(self.num_stages):
35
+ self.out_layer_bn.append(OutputLayer(**module_output_layer, masking=True))
36
+ self.decoder_bn.append(AudioDecoder(**module_audio_dec))
37
+
38
+ def forward(self, x):
39
+ encoder_output = self.audio_encoder(x)
40
+ projected_feature = self.feature_projector(encoder_output)
41
+ last_stage_output, each_stage_outputs = self.separator(projected_feature)
42
+ out_layer_output = self.out_layer(last_stage_output, encoder_output)
43
+ each_spk_output = [out_layer_output[idx] for idx in range(self.num_spks)]
44
+ audio = [self.audio_decoder(each_spk_output[idx]) for idx in range(self.num_spks)]
45
+
46
+ # Aux_loss
47
+ audio_aux = []
48
+ for idx, each_stage_output in enumerate(each_stage_outputs):
49
+ each_stage_output = self.out_layer_bn[idx](torch.nn.functional.upsample(each_stage_output, encoder_output.shape[-1]), encoder_output)
50
+ out_aux = [each_stage_output[jdx] for jdx in range(self.num_spks)]
51
+ audio_aux.append([self.decoder_bn[idx](out_aux[jdx])[...,:x.shape[-1]] for jdx in range(self.num_spks)])
52
+
53
+ return audio, audio_aux
models/SepReformer/SepReformer_Large_DM_WSJ0/modules/module.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('../')
3
+
4
+ import torch
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ from utils.decorators import *
9
+ from .network import *
10
+
11
+
12
+ class AudioEncoder(torch.nn.Module):
13
+ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, groups: int, bias: bool):
14
+ super().__init__()
15
+ self.conv1d = torch.nn.Conv1d(
16
+ in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, groups=groups, bias=bias)
17
+ self.gelu = torch.nn.GELU()
18
+
19
+ def forward(self, x: torch.Tensor):
20
+ x = torch.unsqueeze(x, dim=0) if len(x.shape) == 1 else torch.unsqueeze(x, dim=1) # [T] - >[1, T] OR [B, T] -> [B, 1, T]
21
+ x = self.conv1d(x)
22
+ x = self.gelu(x)
23
+ return x
24
+
25
+ class FeatureProjector(torch.nn.Module):
26
+ def __init__(self, num_channels: int, in_channels: int, out_channels: int, kernel_size: int, bias: bool):
27
+ super().__init__()
28
+ self.norm = torch.nn.GroupNorm(num_groups=1, num_channels=num_channels, eps=1e-8)
29
+ self.conv1d = torch.nn.Conv1d(
30
+ in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias)
31
+
32
+ def forward(self, x: torch.Tensor):
33
+ x = self.norm(x)
34
+ x = self.conv1d(x)
35
+ return x
36
+
37
+
38
+ class Separator(torch.nn.Module):
39
+ def __init__(self, num_stages: int, relative_positional_encoding: dict, enc_stage: dict, spk_split_stage: dict, simple_fusion:dict, dec_stage: dict):
40
+ super().__init__()
41
+
42
+ class RelativePositionalEncoding(torch.nn.Module):
43
+ def __init__(self, in_channels: int, num_heads: int, maxlen: int, embed_v=False):
44
+ super().__init__()
45
+ self.in_channels = in_channels
46
+ self.num_heads = num_heads
47
+ self.embedding_dim = self.in_channels // self.num_heads
48
+ self.maxlen = maxlen
49
+ self.pe_k = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim)
50
+ self.pe_v = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim) if embed_v else None
51
+
52
+ def forward(self, pos_seq: torch.Tensor):
53
+ pos_seq.clamp_(-self.maxlen, self.maxlen - 1)
54
+ pos_seq += self.maxlen
55
+ pe_k_output = self.pe_k(pos_seq)
56
+ pe_v_output = self.pe_v(pos_seq) if self.pe_v is not None else None
57
+ return pe_k_output, pe_v_output
58
+
59
+ class SepEncStage(torch.nn.Module):
60
+ def __init__(self, global_blocks: dict, local_blocks: dict, down_conv_layer: dict, down_conv=True):
61
+ super().__init__()
62
+
63
+ class DownConvLayer(torch.nn.Module):
64
+ def __init__(self, in_channels: int, samp_kernel_size: int):
65
+ """Construct an EncoderLayer object."""
66
+ super().__init__()
67
+ self.down_conv = torch.nn.Conv1d(
68
+ in_channels=in_channels, out_channels=in_channels, kernel_size=samp_kernel_size, stride=2, padding=(samp_kernel_size-1)//2, groups=in_channels)
69
+ self.BN = torch.nn.BatchNorm1d(num_features=in_channels)
70
+ self.gelu = torch.nn.GELU()
71
+
72
+ def forward(self, x: torch.Tensor):
73
+ x = x.permute([0, 2, 1])
74
+ x = self.down_conv(x)
75
+ x = self.BN(x)
76
+ x = self.gelu(x)
77
+ x = x.permute([0, 2, 1])
78
+ return x
79
+
80
+ self.g_block_1 = GlobalBlock(**global_blocks)
81
+ self.l_block_1 = LocalBlock(**local_blocks)
82
+
83
+ self.g_block_2 = GlobalBlock(**global_blocks)
84
+ self.l_block_2 = LocalBlock(**local_blocks)
85
+
86
+ self.downconv = DownConvLayer(**down_conv_layer) if down_conv == True else None
87
+
88
+ def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
89
+ '''
90
+ x: [B, N, T]
91
+ '''
92
+ x = self.g_block_1(x, pos_k)
93
+ x = x.permute(0, 2, 1).contiguous()
94
+ x = self.l_block_1(x)
95
+ x = x.permute(0, 2, 1).contiguous()
96
+
97
+ x = self.g_block_2(x, pos_k)
98
+ x = x.permute(0, 2, 1).contiguous()
99
+ x = self.l_block_2(x)
100
+ x = x.permute(0, 2, 1).contiguous()
101
+
102
+ skip = x
103
+ if self.downconv:
104
+ x = x.permute(0, 2, 1).contiguous()
105
+ x = self.downconv(x)
106
+ x = x.permute(0, 2, 1).contiguous()
107
+ # [BK, S, N]
108
+ return x, skip
109
+
110
+ class SpkSplitStage(torch.nn.Module):
111
+ def __init__(self, in_channels: int, num_spks: int):
112
+ super().__init__()
113
+ self.linear = torch.nn.Sequential(
114
+ torch.nn.Conv1d(in_channels, 4*in_channels*num_spks, kernel_size=1),
115
+ torch.nn.GLU(dim=-2),
116
+ torch.nn.Conv1d(2*in_channels*num_spks, in_channels*num_spks, kernel_size=1))
117
+ self.norm = torch.nn.GroupNorm(1, in_channels, eps=1e-8)
118
+ self.num_spks = num_spks
119
+
120
+ def forward(self, x: torch.Tensor):
121
+ x = self.linear(x)
122
+ B, _, T = x.shape
123
+ x = x.view(B*self.num_spks,-1, T).contiguous()
124
+ x = self.norm(x)
125
+ return x
126
+
127
+ class SepDecStage(torch.nn.Module):
128
+ def __init__(self, num_spks: int, global_blocks: dict, local_blocks: dict, spk_attention: dict):
129
+ super().__init__()
130
+
131
+ self.g_block_1 = GlobalBlock(**global_blocks)
132
+ self.l_block_1 = LocalBlock(**local_blocks)
133
+ self.spk_attn_1 = SpkAttention(**spk_attention)
134
+
135
+ self.g_block_2 = GlobalBlock(**global_blocks)
136
+ self.l_block_2 = LocalBlock(**local_blocks)
137
+ self.spk_attn_2 = SpkAttention(**spk_attention)
138
+
139
+ self.g_block_3 = GlobalBlock(**global_blocks)
140
+ self.l_block_3 = LocalBlock(**local_blocks)
141
+ self.spk_attn_3 = SpkAttention(**spk_attention)
142
+
143
+ self.num_spk = num_spks
144
+
145
+ def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
146
+ '''
147
+ x: [B, N, T]
148
+ '''
149
+ # [BS, K, H]
150
+ x = self.g_block_1(x, pos_k)
151
+ x = x.permute(0, 2, 1).contiguous()
152
+ x = self.l_block_1(x)
153
+ x = x.permute(0, 2, 1).contiguous()
154
+ x = self.spk_attn_1(x, self.num_spk)
155
+
156
+ x = self.g_block_2(x, pos_k)
157
+ x = x.permute(0, 2, 1).contiguous()
158
+ x = self.l_block_2(x)
159
+ x = x.permute(0, 2, 1).contiguous()
160
+ x = self.spk_attn_2(x, self.num_spk)
161
+
162
+ x = self.g_block_3(x, pos_k)
163
+ x = x.permute(0, 2, 1).contiguous()
164
+ x = self.l_block_3(x)
165
+ x = x.permute(0, 2, 1).contiguous()
166
+ x = self.spk_attn_3(x, self.num_spk)
167
+
168
+ skip = x
169
+
170
+ return x, skip
171
+
172
+ self.num_stages = num_stages
173
+ self.pos_emb = RelativePositionalEncoding(**relative_positional_encoding)
174
+
175
+ # Temporal Contracting Part
176
+ self.enc_stages = torch.nn.ModuleList([])
177
+ for _ in range(self.num_stages):
178
+ self.enc_stages.append(SepEncStage(**enc_stage, down_conv=True))
179
+
180
+ self.bottleneck_G = SepEncStage(**enc_stage, down_conv=False)
181
+ self.spk_split_block = SpkSplitStage(**spk_split_stage)
182
+
183
+ # Temporal Expanding Part
184
+ self.simple_fusion = torch.nn.ModuleList([])
185
+ self.dec_stages = torch.nn.ModuleList([])
186
+ for _ in range(self.num_stages):
187
+ self.simple_fusion.append(torch.nn.Conv1d(in_channels=simple_fusion['out_channels']*2,out_channels=simple_fusion['out_channels'], kernel_size=1))
188
+ self.dec_stages.append(SepDecStage(**dec_stage))
189
+
190
+ def forward(self, input: torch.Tensor):
191
+ '''input: [B, N, L]'''
192
+ # feature projection
193
+ x, _ = self.pad_signal(input)
194
+ len_x = x.shape[-1]
195
+ # Temporal Contracting Part
196
+ pos_seq = torch.arange(0, len_x//2**self.num_stages).long().to(x.device)
197
+ pos_seq = pos_seq[:, None] - pos_seq[None, :]
198
+ pos_k, _ = self.pos_emb(pos_seq)
199
+ skip = []
200
+ for idx in range(self.num_stages):
201
+ x, skip_ = self.enc_stages[idx](x, pos_k)
202
+ skip_ = self.spk_split_block(skip_)
203
+ skip.append(skip_)
204
+ x, _ = self.bottleneck_G(x, pos_k)
205
+ x = self.spk_split_block(x) # B, 2F, T
206
+
207
+ each_stage_outputs = []
208
+ # Temporal Expanding Part
209
+ for idx in range(self.num_stages):
210
+ each_stage_outputs.append(x)
211
+ idx_en = self.num_stages - (idx + 1)
212
+ x = torch.nn.functional.upsample(x, skip[idx_en].shape[-1])
213
+ x = torch.cat([x,skip[idx_en]],dim=1)
214
+ x = self.simple_fusion[idx](x)
215
+ x, _ = self.dec_stages[idx](x, pos_k)
216
+
217
+ last_stage_output = x
218
+ return last_stage_output, each_stage_outputs
219
+
220
+ def pad_signal(self, input: torch.Tensor):
221
+ # (B, T) or (B, 1, T)
222
+ if input.dim() == 1: input = input.unsqueeze(0)
223
+ elif input.dim() not in [2, 3]: raise RuntimeError("Input can only be 2 or 3 dimensional.")
224
+ elif input.dim() == 2: input = input.unsqueeze(1)
225
+ L = 2**self.num_stages
226
+ batch_size = input.size(0)
227
+ ndim = input.size(1)
228
+ nframe = input.size(2)
229
+ padded_len = (nframe//L + 1)*L
230
+ rest = 0 if nframe%L == 0 else padded_len - nframe
231
+ if rest > 0:
232
+ pad = torch.autograd.Variable(torch.zeros(batch_size, ndim, rest)).type(input.type()).to(input.device)
233
+ input = torch.cat([input, pad], dim=-1)
234
+ return input, rest
235
+
236
+
237
+ class OutputLayer(torch.nn.Module):
238
+ def __init__(self, in_channels: int, out_channels: int, num_spks: int, masking: bool = False):
239
+ super().__init__()
240
+ # feature expansion back
241
+ self.masking = masking
242
+ self.spe_block = Masking(in_channels, Activation_mask="ReLU", concat_opt=None)
243
+ self.num_spks = num_spks
244
+ self.end_conv1x1 = torch.nn.Sequential(
245
+ torch.nn.Linear(out_channels, 4*out_channels),
246
+ torch.nn.GLU(),
247
+ torch.nn.Linear(2*out_channels, in_channels))
248
+
249
+ def forward(self, x: torch.Tensor, input: torch.Tensor):
250
+ x = x[...,:input.shape[-1]]
251
+ x = x.permute([0, 2, 1])
252
+ x = self.end_conv1x1(x)
253
+ x = x.permute([0, 2, 1])
254
+ B, N, L = x.shape
255
+ B = B // self.num_spks
256
+
257
+ if self.masking:
258
+ input = input.expand(self.num_spks, B, N, L).transpose(0,1).contiguous()
259
+ input = input.view(B*self.num_spks, N, L)
260
+ x = self.spe_block(x, input)
261
+
262
+ x = x.view(B, self.num_spks, N, L)
263
+ # [spks, B, N, L]
264
+ x = x.transpose(0, 1)
265
+ return x
266
+
267
+
268
+ class AudioDecoder(torch.nn.ConvTranspose1d):
269
+ '''
270
+ Decoder of the TasNet
271
+ This module can be seen as the gradient of Conv1d with respect to its input.
272
+ It is also known as a fractionally-strided convolution
273
+ or a deconvolution (although it is not an actual deconvolution operation).
274
+ '''
275
+ def __init__(self, *args, **kwargs):
276
+ super().__init__(*args, **kwargs)
277
+
278
+ def forward(self, x):
279
+ # x: [B, N, L]
280
+ if x.dim() not in [2, 3]: raise RuntimeError("{} accept 3/4D tensor as input".format(self.__name__))
281
+ x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
282
+ x = torch.squeeze(x, dim=1) if torch.squeeze(x).dim() == 1 else torch.squeeze(x)
283
+ return x
models/SepReformer/SepReformer_Large_DM_WSJ0/modules/network.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import numpy
4
+ from utils.decorators import *
5
+
6
+
7
class LayerScale(torch.nn.Module):
    """Learnable per-channel residual scaling (LayerScale), initialised near zero.

    `dims` selects the rank of the scale parameter so it broadcasts against the
    caller's tensor layout: 1 -> (C,), 2 -> (1, C), 3 -> (1, 1, C).
    """

    def __init__(self, dims, input_size, Layer_scale_init=1.0e-5):
        super().__init__()
        shapes = {1: (input_size,), 2: (1, input_size), 3: (1, 1, input_size)}
        if dims in shapes:
            init = torch.full(shapes[dims], Layer_scale_init)
            self.layer_scale = torch.nn.Parameter(init, requires_grad=True)

    def forward(self, x):
        # Broadcast multiply; gradients flow into `layer_scale`.
        return self.layer_scale * x
19
+
20
class Masking(torch.nn.Module):
    """Gated masking: activates `x` into a mask and applies it to `skip`.

    If the `concat_opt` option is truthy, `x` and `skip` are first concatenated
    along the channel axis and compressed back with a 1x1 convolution before
    the gate activation is applied.

    :param input_dim: number of channels of `x` and `skip`
    :param Activation_mask: 'Sigmoid' or 'ReLU' gate nonlinearity
    """
    def __init__(self, input_dim, Activation_mask='Sigmoid', **options):
        super(Masking, self).__init__()

        self.options = options
        # FIX: use .get() so callers may omit `concat_opt` entirely
        # (previously a missing key raised KeyError here and in forward()).
        if self.options.get('concat_opt'):
            self.pw_conv = torch.nn.Conv1d(input_dim*2, input_dim, 1, stride=1, padding=0)

        if Activation_mask == 'Sigmoid':
            self.gate_act = torch.nn.Sigmoid()
        elif Activation_mask == 'ReLU':
            self.gate_act = torch.nn.ReLU()

    def forward(self, x, skip):
        if self.options.get('concat_opt'):
            # Fuse mask source and skip features, then compress channels back.
            y = torch.cat([x, skip], dim=-2)
            y = self.pw_conv(y)
        else:
            y = x
        # Gate the skip connection with the activated mask.
        y = self.gate_act(y) * skip

        return y
44
+
45
+
46
class GCFN(torch.nn.Module):
    """Gated convolutional feed-forward network with a LayerScale residual.

    Pre-norm linear expansion to 6x channels, a depthwise temporal conv, a GLU
    gate (halving to 3x), and a projection back to `in_channels`.
    """

    def __init__(self, in_channels, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        hidden = in_channels * 6
        self.net1 = torch.nn.Sequential(
            torch.nn.LayerNorm(in_channels),
            torch.nn.Linear(in_channels, hidden))
        # Depthwise conv mixes a 3-frame local context independently per channel.
        self.depthwise = torch.nn.Conv1d(hidden, hidden, 3, padding=1, groups=hidden)
        self.net2 = torch.nn.Sequential(
            torch.nn.GLU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(in_channels * 3, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        out = self.net1(x)
        # Conv1d wants (B, C, T); transpose around the depthwise pass only.
        out = self.depthwise(out.transpose(1, 2).contiguous())
        out = self.net2(out.transpose(1, 2).contiguous())
        # Residual connection scaled by the learnable LayerScale gain.
        return x + self.Layer_scale(out)
67
+
68
+
69
class MultiHeadAttention(torch.nn.Module):
    """
    Pre-norm multi-head self-attention with an optional relative positional bias.

    :param int n_head: number of attention heads
    :param int in_channels: feature dimension (must be divisible by n_head)
    :param float dropout_rate: dropout applied to attention weights and output
    :param float Layer_scale_init: initial value of the residual LayerScale gain
    """
    def __init__(self, n_head: int, in_channels: int, dropout_rate: float, Layer_scale_init=1.0e-5):
        super().__init__()
        assert in_channels % n_head == 0
        self.d_k = in_channels // n_head  # per-head dim; we assume d_v always equals d_k
        self.h = n_head
        self.layer_norm = torch.nn.LayerNorm(in_channels)
        self.linear_q = torch.nn.Linear(in_channels, in_channels)
        self.linear_k = torch.nn.Linear(in_channels, in_channels)
        self.linear_v = torch.nn.Linear(in_channels, in_channels)
        self.linear_out = torch.nn.Linear(in_channels, in_channels)
        # Last attention map is kept for inspection/visualisation after forward().
        self.attn = None
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x, pos_k, mask):
        """
        Compute 'Scaled Dot Product Attention' (self-attention over `x`).

        :param torch.Tensor x: input features (batch, time, in_channels)
        :param torch.Tensor pos_k: relative positional keys, or None to skip
            the positional bias term
        :param torch.Tensor mask: (batch, time1, time2); zero entries are
            excluded from attention. May be None.
        :return torch.Tensor: attended features (batch, time1, d_model), scaled
            by LayerScale (the caller adds the residual connection)
        """
        n_batch = x.size(0)
        x = self.layer_norm(x)  # pre-norm
        q = self.linear_q(x).view(n_batch, -1, self.h, self.d_k) #(b, t, h, d_k)
        k = self.linear_k(x).view(n_batch, -1, self.h, self.d_k) #(b, t, h, d_k)
        v = self.linear_v(x).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2) # (batch, head, time2, d_k)
        v = v.transpose(1, 2) # (batch, head, time2, d_k)
        # Content-content scores.
        A = torch.matmul(q, k.transpose(-2, -1))
        # Fold batch and heads together so q can be matmul'ed against pos_k.
        reshape_q = q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0,1)
        if pos_k is not None:
            # Content-position bias from the relative positional keys.
            B = torch.matmul(reshape_q, pos_k.transpose(-2, -1))
            B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1))
            scores = (A + B) / math.sqrt(self.d_k)
        else:
            scores = A / math.sqrt(self.d_k)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
            # Most negative representable value for this dtype, so softmax ~ 0.
            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            # Re-zero masked positions in case an entire row was masked out.
            self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
        return self.Layer_scale(self.dropout(self.linear_out(x))) # (batch, time1, d_model)
125
+
126
class EGA(torch.nn.Module):
    """Efficient Global Attention: attends over a pooled (downsampled) sequence
    and gates the upsampled result back onto the full-resolution features."""

    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'self_attn': MultiHeadAttention(
                n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            # Sigmoid gate controlling how much global context is injected.
            'linear': torch.nn.Sequential(
                torch.nn.LayerNorm(normalized_shape=in_channels),
                torch.nn.Linear(in_features=in_channels, out_features=in_channels),
                torch.nn.Sigmoid())
        })

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """
        Compute globally-attended features.

        :param torch.Tensor x: input features (batch, channels, time)
        :param torch.Tensor pos_k: relative positional keys; its first dim also
            fixes the downsampled attention length
        :return torch.Tensor: output features (batch, time, channels)
        """
        down_len = pos_k.shape[0]
        # Pool the sequence down to the positional-encoding length before attention.
        x_down = torch.nn.functional.adaptive_avg_pool1d(input=x, output_size=down_len)
        x = x.permute([0, 2, 1])
        x_down = x_down.permute([0, 2, 1])
        x_down = self.block['self_attn'](x_down, pos_k, None)
        x_down = x_down.permute([0, 2, 1])
        # FIX: F.upsample is deprecated; F.interpolate is the drop-in
        # replacement with identical semantics (default mode='nearest').
        x_downup = torch.nn.functional.interpolate(input=x_down, size=x.shape[1])
        x_downup = x_downup.permute([0, 2, 1])
        # Element-wise gate decides how much upsampled global context to add.
        x = x + self.block['linear'](x) * x_downup

        return x
156
+
157
+
158
+
159
class CLA(torch.nn.Module):
    """Convolutional Local Attention block: gated pointwise expansion, a
    depthwise temporal convolution, BatchNorm, and a GELU projection, joined
    to the input through a LayerScale-gated residual connection."""

    def __init__(self, in_channels, kernel_size, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(in_channels)
        self.linear1 = torch.nn.Linear(in_channels, in_channels*2)
        self.GLU = torch.nn.GLU()
        # Depthwise: one filter per channel, 'same' padding preserves length.
        self.dw_conv_1d = torch.nn.Conv1d(in_channels, in_channels, kernel_size, padding='same', groups=in_channels)
        self.linear2 = torch.nn.Linear(in_channels, 2*in_channels)
        self.BN = torch.nn.BatchNorm1d(2*in_channels)
        self.linear3 = torch.nn.Sequential(
            torch.nn.GELU(),
            torch.nn.Linear(2*in_channels, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        # Layout is (B, T, F); transpose only around the channel-first ops
        # (depthwise conv and BatchNorm1d expect (B, C, T)).
        out = self.GLU(self.linear1(self.layer_norm(x)))
        out = self.dw_conv_1d(out.transpose(1, 2)).transpose(1, 2)
        out = self.BN(self.linear2(out).transpose(1, 2)).transpose(1, 2)
        out = self.linear3(out)
        # Residual connection scaled by the learnable LayerScale gain.
        return x + self.Layer_scale(out)
188
+
189
class GlobalBlock(torch.nn.Module):
    """EGA global attention followed by a GCFN feed-forward network."""

    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'ega': EGA(
                num_mha_heads=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            'gcfn': GCFN(in_channels=in_channels, dropout_rate=dropout_rate)
        })

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """
        :param torch.Tensor x: encoded features (batch, channels, time)
        :param torch.Tensor pos_k: relative positional keys for the EGA stage
        :return torch.Tensor: features transposed back to channel-first layout
        """
        out = self.block['gcfn'](self.block['ega'](x, pos_k))
        return out.permute([0, 2, 1])
210
+
211
+
212
class LocalBlock(torch.nn.Module):
    """CLA convolution module followed by a GCFN feed-forward network."""

    def __init__(self, in_channels: int, kernel_size: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'cla': CLA(in_channels, kernel_size, dropout_rate),
            'gcfn': GCFN(in_channels, dropout_rate)
        })

    def forward(self, x: torch.Tensor):
        """Apply local convolutional attention, then the feed-forward net."""
        return self.block['gcfn'](self.block['cla'](x))
225
+
226
+
227
class SpkAttention(torch.nn.Module):
    """Attention across the speaker axis: at every time frame, the `num_spk`
    separated streams (folded into the batch dimension) attend to each other,
    followed by a GCFN feed-forward applied along time."""

    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.self_attn = MultiHeadAttention(n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate)
        self.feed_forward = GCFN(in_channels=in_channels, dropout_rate=dropout_rate)

    def forward(self, x: torch.Tensor, num_spk: int):
        """
        :param torch.Tensor x: stacked speaker features (batch*num_spk, F, T)
        :param int num_spk: number of speaker streams folded into the batch dim
        :return torch.Tensor: features of the same shape (batch*num_spk, F, T)
        """
        B, F, T = x.shape
        # (B, F, T) -> (B//spk, spk, F, T): recover the speaker axis...
        x = x.view(B//num_spk, num_spk, F, T).contiguous()
        x = x.permute([0, 3, 1, 2]).contiguous()
        # ...then merge batch and time so attention runs over the speaker axis.
        x = x.view(-1, num_spk, F).contiguous()
        x = x + self.self_attn(x, None, None)  # residual speaker attention (no pos_k, no mask)
        # Undo the reshaping back to the original (B, F, T) layout.
        x = x.view(B//num_spk, T, num_spk, F).contiguous()
        x = x.permute([0, 2, 3, 1]).contiguous()
        x = x.view(B, F, T).contiguous()
        # GCFN operates on (B, T, F); transpose around the feed-forward pass.
        x = x.permute([0, 2, 1])
        x = self.feed_forward(x)
        x = x.permute([0, 2, 1])

        return x
models/SepReformer/source.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ https://github.com/dmlguq456/SepReformer