SepReformer (code, models, paper)
Browse files- .gitattributes +1 -0
- Separate and Reconstruct. Asymmetric Encoder-Decoder for Speech Separation.pdf +3 -0
- code/SepReformer.zip +3 -0
- code/sepformer-tse.zip +3 -0
- models/SepReformer/SepReformer_Base_WSJ0/configs.yaml +139 -0
- models/SepReformer/SepReformer_Base_WSJ0/dataset.py +165 -0
- models/SepReformer/SepReformer_Base_WSJ0/engine.py +216 -0
- models/SepReformer/SepReformer_Base_WSJ0/log/scratch_weights/epoch.0180.pth +3 -0
- models/SepReformer/SepReformer_Base_WSJ0/main.py +47 -0
- models/SepReformer/SepReformer_Base_WSJ0/model.py +53 -0
- models/SepReformer/SepReformer_Base_WSJ0/modules/module.py +283 -0
- models/SepReformer/SepReformer_Base_WSJ0/modules/network.py +252 -0
- models/SepReformer/SepReformer_Large_DM_WHAM/configs.yaml +129 -0
- models/SepReformer/SepReformer_Large_DM_WHAM/dataset.py +177 -0
- models/SepReformer/SepReformer_Large_DM_WHAM/engine.py +192 -0
- models/SepReformer/SepReformer_Large_DM_WHAM/main.py +44 -0
- models/SepReformer/SepReformer_Large_DM_WHAM/model.py +53 -0
- models/SepReformer/SepReformer_Large_DM_WHAM/modules/module.py +286 -0
- models/SepReformer/SepReformer_Large_DM_WHAM/modules/network.py +252 -0
- models/SepReformer/SepReformer_Large_DM_WHAMR/configs.yaml +131 -0
- models/SepReformer/SepReformer_Large_DM_WHAMR/dataset.py +187 -0
- models/SepReformer/SepReformer_Large_DM_WHAMR/engine.py +192 -0
- models/SepReformer/SepReformer_Large_DM_WHAMR/main.py +44 -0
- models/SepReformer/SepReformer_Large_DM_WHAMR/model.py +53 -0
- models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/module.cpython-310.pyc +0 -0
- models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/module.cpython-38.pyc +0 -0
- models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/network.cpython-310.pyc +0 -0
- models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/network.cpython-38.pyc +0 -0
- models/SepReformer/SepReformer_Large_DM_WHAMR/modules/module.py +283 -0
- models/SepReformer/SepReformer_Large_DM_WHAMR/modules/network.py +252 -0
- models/SepReformer/SepReformer_Large_DM_WSJ0/configs.yaml +128 -0
- models/SepReformer/SepReformer_Large_DM_WSJ0/dataset.py +171 -0
- models/SepReformer/SepReformer_Large_DM_WSJ0/engine.py +192 -0
- models/SepReformer/SepReformer_Large_DM_WSJ0/main.py +44 -0
- models/SepReformer/SepReformer_Large_DM_WSJ0/model.py +53 -0
- models/SepReformer/SepReformer_Large_DM_WSJ0/modules/module.py +283 -0
- models/SepReformer/SepReformer_Large_DM_WSJ0/modules/network.py +252 -0
- models/SepReformer/source.txt +1 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
Separate[[:space:]]and[[:space:]]Reconstruct.[[:space:]]Asymmetric[[:space:]]Encoder-Decoder[[:space:]]for[[:space:]]Speech[[:space:]]Separation.pdf filter=lfs diff=lfs merge=lfs -text
|
Separate and Reconstruct. Asymmetric Encoder-Decoder for Speech Separation.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:97766a617509db55e816689f3f3e8e52c03b06ebf04f98f03e298a5556a4e898
|
| 3 |
+
size 1952305
|
code/SepReformer.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dc9b31d464b79b6ac037879b160445f53a4de9e4a411cce0954fc24c0ff7706d
|
| 3 |
+
size 16944535
|
code/sepformer-tse.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:13905e009d88354fcb21579e11f5dcebe3114eb2d238eab0fa4cc11f9cc237ea
|
| 3 |
+
size 114198916
|
models/SepReformer/SepReformer_Base_WSJ0/configs.yaml
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
project: "[Project] SepReformer" ### Dont't change
|
| 2 |
+
notes: "SepReformer final version" ### Insert schanges(plz write details !!!)
|
| 3 |
+
# ------------------------------------------------------------------------------------------------------------------------------ #
|
| 4 |
+
config:
|
| 5 |
+
# ------------------------------------------------------------ #
|
| 6 |
+
dataset:
|
| 7 |
+
max_len: 32000
|
| 8 |
+
sampling_rate: 8000
|
| 9 |
+
scp_dir: "data/scp_ss_8k"
|
| 10 |
+
train:
|
| 11 |
+
mixture: "tr_mix.scp"
|
| 12 |
+
spk1: "tr_s1.scp"
|
| 13 |
+
spk2: "tr_s2.scp"
|
| 14 |
+
dynamic_mixing: false
|
| 15 |
+
valid:
|
| 16 |
+
mixture: "cv_mix.scp"
|
| 17 |
+
spk1: "cv_s1.scp"
|
| 18 |
+
spk2: "cv_s2.scp"
|
| 19 |
+
test:
|
| 20 |
+
mixture: "tt_mix.scp"
|
| 21 |
+
spk1: "tt_s1.scp"
|
| 22 |
+
spk2: "tt_s2.scp"
|
| 23 |
+
# ------------------------------------------------------------ #
|
| 24 |
+
dataloader:
|
| 25 |
+
batch_size: 2
|
| 26 |
+
pin_memory: false
|
| 27 |
+
num_workers: 12
|
| 28 |
+
drop_last: false
|
| 29 |
+
# ------------------------------------------------------------ #
|
| 30 |
+
model:
|
| 31 |
+
num_stages: &var_model_num_stages 4 # R
|
| 32 |
+
num_spks: &var_model_num_spks 2
|
| 33 |
+
module_audio_enc:
|
| 34 |
+
in_channels: 1
|
| 35 |
+
out_channels: &var_model_audio_enc_out_channels 256
|
| 36 |
+
kernel_size: &var_model_audio_enc_kernel_size 16 # L
|
| 37 |
+
stride: &var_model_audio_enc_stride 4 # S
|
| 38 |
+
groups: 1
|
| 39 |
+
bias: false
|
| 40 |
+
module_feature_projector:
|
| 41 |
+
num_channels: *var_model_audio_enc_out_channels
|
| 42 |
+
in_channels: *var_model_audio_enc_out_channels
|
| 43 |
+
out_channels: &feature_projector_out_channels 128 # F
|
| 44 |
+
kernel_size: 1
|
| 45 |
+
bias: false
|
| 46 |
+
module_separator:
|
| 47 |
+
num_stages: *var_model_num_stages
|
| 48 |
+
relative_positional_encoding:
|
| 49 |
+
in_channels: *feature_projector_out_channels
|
| 50 |
+
num_heads: 8
|
| 51 |
+
maxlen: 2000
|
| 52 |
+
embed_v: false
|
| 53 |
+
enc_stage:
|
| 54 |
+
global_blocks:
|
| 55 |
+
in_channels: *feature_projector_out_channels
|
| 56 |
+
num_mha_heads: 8
|
| 57 |
+
dropout_rate: 0.05
|
| 58 |
+
local_blocks:
|
| 59 |
+
in_channels: *feature_projector_out_channels
|
| 60 |
+
kernel_size: 65
|
| 61 |
+
dropout_rate: 0.05
|
| 62 |
+
down_conv_layer:
|
| 63 |
+
in_channels: *feature_projector_out_channels
|
| 64 |
+
samp_kernel_size: &var_model_samp_kernel_size 5
|
| 65 |
+
spk_split_stage:
|
| 66 |
+
in_channels: *feature_projector_out_channels
|
| 67 |
+
num_spks: *var_model_num_spks
|
| 68 |
+
simple_fusion:
|
| 69 |
+
out_channels: *feature_projector_out_channels
|
| 70 |
+
dec_stage:
|
| 71 |
+
num_spks: *var_model_num_spks
|
| 72 |
+
global_blocks:
|
| 73 |
+
in_channels: *feature_projector_out_channels
|
| 74 |
+
num_mha_heads: 8
|
| 75 |
+
dropout_rate: 0.05
|
| 76 |
+
local_blocks:
|
| 77 |
+
in_channels: *feature_projector_out_channels
|
| 78 |
+
kernel_size: 65
|
| 79 |
+
dropout_rate: 0.05
|
| 80 |
+
spk_attention:
|
| 81 |
+
in_channels: *feature_projector_out_channels
|
| 82 |
+
num_mha_heads: 8
|
| 83 |
+
dropout_rate: 0.05
|
| 84 |
+
module_output_layer:
|
| 85 |
+
in_channels: *var_model_audio_enc_out_channels
|
| 86 |
+
out_channels: *feature_projector_out_channels
|
| 87 |
+
num_spks: *var_model_num_spks
|
| 88 |
+
module_audio_dec:
|
| 89 |
+
in_channels: *var_model_audio_enc_out_channels
|
| 90 |
+
out_channels: 1
|
| 91 |
+
kernel_size: *var_model_audio_enc_kernel_size
|
| 92 |
+
stride: *var_model_audio_enc_stride
|
| 93 |
+
bias: false
|
| 94 |
+
# ------------------------------------------------------------ #
|
| 95 |
+
criterion:
|
| 96 |
+
name: ["PIT_SISNR_mag", "PIT_SISNR_time", "PIT_SISNRi", "PIT_SDRi"] ### Choose a torch.nn's loss function class(=attribute) e.g. ["L1Loss", "MSELoss", "CrossEntropyLoss", ...]
|
| 97 |
+
PIT_SISNR_mag:
|
| 98 |
+
frame_length: 512
|
| 99 |
+
frame_shift: 128
|
| 100 |
+
window: 'hann'
|
| 101 |
+
num_stages: *var_model_num_stages
|
| 102 |
+
num_spks: *var_model_num_spks
|
| 103 |
+
scale_inv: true
|
| 104 |
+
mel_opt: false
|
| 105 |
+
PIT_SISNR_time:
|
| 106 |
+
num_spks: *var_model_num_spks
|
| 107 |
+
scale_inv: true
|
| 108 |
+
PIT_SISNRi:
|
| 109 |
+
num_spks: *var_model_num_spks
|
| 110 |
+
scale_inv: true
|
| 111 |
+
PIT_SDRi:
|
| 112 |
+
dump: 0
|
| 113 |
+
# ------------------------------------------------------------ #
|
| 114 |
+
optimizer:
|
| 115 |
+
name: ["AdamW"] ### Choose a torch.optim's class(=attribute) e.g. ["Adam", "AdamW", "SGD", ...]
|
| 116 |
+
AdamW:
|
| 117 |
+
lr: 1.0e-3
|
| 118 |
+
weight_decay: 1.0e-2
|
| 119 |
+
# ------------------------------------------------------------ #
|
| 120 |
+
scheduler:
|
| 121 |
+
name: ["ReduceLROnPlateau", "WarmupConstantSchedule"] ### Choose a torch.optim.lr_scheduler's class(=attribute) e.g. ["StepLR", "ReduceLROnPlateau", "Custom"]
|
| 122 |
+
ReduceLROnPlateau:
|
| 123 |
+
mode: "min"
|
| 124 |
+
min_lr: 1.0e-10
|
| 125 |
+
factor: 0.8
|
| 126 |
+
patience: 2
|
| 127 |
+
WarmupConstantSchedule:
|
| 128 |
+
warmup_steps: 1000
|
| 129 |
+
# ------------------------------------------------------------ #
|
| 130 |
+
check_computations:
|
| 131 |
+
dummy_len: 16000
|
| 132 |
+
# ------------------------------------------------------------ #
|
| 133 |
+
engine:
|
| 134 |
+
max_epoch: 200
|
| 135 |
+
gpuid: "0" ### "0"(single-gpu) or "0, 1" (multi-gpu)
|
| 136 |
+
mvn: false
|
| 137 |
+
clip_norm: 5
|
| 138 |
+
start_scheduling: 50
|
| 139 |
+
test_epochs: [100, 120, 150, 170]
|
models/SepReformer/SepReformer_Base_WSJ0/dataset.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
import librosa as audio_lib
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from utils import util_dataset
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
from loguru import logger
|
| 10 |
+
from torch.utils.data import Dataset, DataLoader
|
| 11 |
+
|
| 12 |
+
@logger_wraps()
|
| 13 |
+
def get_dataloaders(args, dataset_config, loader_config):
|
| 14 |
+
# create dataset object for each partition
|
| 15 |
+
partitions = ["test"] if "test" in args.engine_mode else ["train", "valid", "test"]
|
| 16 |
+
dataloaders = {}
|
| 17 |
+
for partition in partitions:
|
| 18 |
+
scp_config_mix = os.path.join(dataset_config["scp_dir"], dataset_config[partition]['mixture'])
|
| 19 |
+
scp_config_spk = [os.path.join(dataset_config["scp_dir"], dataset_config[partition][spk_key]) for spk_key in dataset_config[partition] if spk_key.startswith('spk')]
|
| 20 |
+
dynamic_mixing = dataset_config[partition]["dynamic_mixing"] if partition == 'train' else False
|
| 21 |
+
dataset = MyDataset(
|
| 22 |
+
max_len = dataset_config['max_len'],
|
| 23 |
+
fs = dataset_config['sampling_rate'],
|
| 24 |
+
partition = partition,
|
| 25 |
+
wave_scp_srcs = scp_config_spk,
|
| 26 |
+
wave_scp_mix = scp_config_mix,
|
| 27 |
+
dynamic_mixing = dynamic_mixing)
|
| 28 |
+
dataloader = DataLoader(
|
| 29 |
+
dataset = dataset,
|
| 30 |
+
batch_size = 1 if partition == 'test' else loader_config["batch_size"],
|
| 31 |
+
shuffle = True, # only train: (partition == 'train') / all: True
|
| 32 |
+
pin_memory = loader_config["pin_memory"],
|
| 33 |
+
num_workers = loader_config["num_workers"],
|
| 34 |
+
drop_last = loader_config["drop_last"],
|
| 35 |
+
collate_fn = _collate)
|
| 36 |
+
dataloaders[partition] = dataloader
|
| 37 |
+
return dataloaders
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _collate(egs):
|
| 41 |
+
"""
|
| 42 |
+
Transform utterance index into a minbatch
|
| 43 |
+
|
| 44 |
+
Arguments:
|
| 45 |
+
index: a list type [{},{},{}]
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
input_sizes: a tensor correspond to utterance length
|
| 49 |
+
input_feats: packed sequence to feed networks
|
| 50 |
+
source_attr/target_attr: dictionary contains spectrogram/phase needed in loss computation
|
| 51 |
+
"""
|
| 52 |
+
def __prepare_target_rir(dict_lsit, index):
|
| 53 |
+
return torch.nn.utils.rnn.pad_sequence([torch.tensor(d["src"][index], dtype=torch.float32) for d in dict_lsit], batch_first=True)
|
| 54 |
+
if type(egs) is not list: raise ValueError("Unsupported index type({})".format(type(egs)))
|
| 55 |
+
num_spks = 2 # you need to set this paramater by yourself
|
| 56 |
+
dict_list = sorted([eg for eg in egs], key=lambda x: x['num_sample'], reverse=True)
|
| 57 |
+
mixture = torch.nn.utils.rnn.pad_sequence([torch.tensor(d['mix'], dtype=torch.float32) for d in dict_list], batch_first=True)
|
| 58 |
+
src = [__prepare_target_rir(dict_list, index) for index in range(num_spks)]
|
| 59 |
+
input_sizes = torch.tensor([d['num_sample'] for d in dict_list], dtype=torch.float32)
|
| 60 |
+
key = [d['key'] for d in dict_list]
|
| 61 |
+
return input_sizes, mixture, src, key
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@logger_wraps()
|
| 65 |
+
class MyDataset(Dataset):
|
| 66 |
+
def __init__(self, max_len, fs, partition, wave_scp_srcs, wave_scp_mix, wave_scp_noise=None, dynamic_mixing=False, speed_list=None):
|
| 67 |
+
self.partition = partition
|
| 68 |
+
for wave_scp_src in wave_scp_srcs:
|
| 69 |
+
if not os.path.exists(wave_scp_src): raise FileNotFoundError(f"Could not find file {wave_scp_src}")
|
| 70 |
+
self.max_len = max_len
|
| 71 |
+
self.fs = fs
|
| 72 |
+
self.wave_dict_srcs = [util_dataset.parse_scps(wave_scp_src) for wave_scp_src in wave_scp_srcs]
|
| 73 |
+
self.wave_dict_mix = util_dataset.parse_scps(wave_scp_mix)
|
| 74 |
+
self.wave_dict_noise = util_dataset.parse_scps(wave_scp_noise) if wave_scp_noise else None
|
| 75 |
+
self.wave_keys = list(self.wave_dict_mix.keys())
|
| 76 |
+
logger.info(f"Create MyDataset for {wave_scp_mix} with {len(self.wave_dict_mix)} utterances")
|
| 77 |
+
self.dynamic_mixing = dynamic_mixing
|
| 78 |
+
|
| 79 |
+
def __len__(self):
|
| 80 |
+
return len(self.wave_dict_mix)
|
| 81 |
+
|
| 82 |
+
def __contains__(self, key):
|
| 83 |
+
return key in self.wave_dict_mix
|
| 84 |
+
|
| 85 |
+
def _dynamic_mixing(self, key):
|
| 86 |
+
def __match_length(wav, len_data) :
|
| 87 |
+
leftover = len(wav) - len_data
|
| 88 |
+
idx = random.randint(0,leftover)
|
| 89 |
+
wav = wav[idx:idx+len_data]
|
| 90 |
+
return wav
|
| 91 |
+
|
| 92 |
+
samps_src = []
|
| 93 |
+
src_len = []
|
| 94 |
+
# dyanmic source choice
|
| 95 |
+
# checking whether it is the same speaker
|
| 96 |
+
while True:
|
| 97 |
+
key_random = random.choice(list(self.wave_dict_srcs[0].keys()))
|
| 98 |
+
tmp1 = key.split('_')[1][:3] != key_random.split('_')[3][:3]
|
| 99 |
+
tmp2 = key.split('_')[3][:3] != key_random.split('_')[1][:3]
|
| 100 |
+
if tmp1 and tmp2: break
|
| 101 |
+
|
| 102 |
+
idx1, idx2 = (0, 1) if random.random() > 0.5 else (1, 0)
|
| 103 |
+
files = [self.wave_dict_srcs[idx1][key], self.wave_dict_srcs[idx2][key_random]]
|
| 104 |
+
|
| 105 |
+
# load
|
| 106 |
+
for file in files:
|
| 107 |
+
if not os.path.exists(file): raise FileNotFoundError("Input file {} do not exists!".format(file))
|
| 108 |
+
samps_tmp, _ = audio_lib.load(file, sr=self.fs)
|
| 109 |
+
# mixing with random gains
|
| 110 |
+
gain = pow(10,-random.uniform(-2.5,2.5)/20)
|
| 111 |
+
# Speed Augmentation
|
| 112 |
+
samps_tmp = np.array(self.speed_aug(torch.tensor(samps_tmp))[0])
|
| 113 |
+
samps_src.append(gain*samps_tmp)
|
| 114 |
+
src_len.append(len(samps_tmp))
|
| 115 |
+
|
| 116 |
+
# matching the audio length
|
| 117 |
+
min_len = min(src_len)
|
| 118 |
+
|
| 119 |
+
samps_src = [__match_length(s, min_len) for s in samps_src]
|
| 120 |
+
samps_mix = sum(samps_src)
|
| 121 |
+
|
| 122 |
+
# ! truncated along to the sample Length "L"
|
| 123 |
+
if len(samps_mix)%4 != 0:
|
| 124 |
+
remains = len(samps_mix)%4
|
| 125 |
+
samps_mix = samps_mix[:-remains]
|
| 126 |
+
samps_src = [s[:-remains] for s in samps_src]
|
| 127 |
+
|
| 128 |
+
if self.partition != "test":
|
| 129 |
+
if len(samps_mix) > self.max_len:
|
| 130 |
+
start = random.randint(0, len(samps_mix)-self.max_len)
|
| 131 |
+
samps_mix = samps_mix[start:start+self.max_len]
|
| 132 |
+
samps_src = [s[start:start+self.max_len] for s in samps_src]
|
| 133 |
+
return samps_mix, samps_src
|
| 134 |
+
|
| 135 |
+
def _direct_load(self, key):
|
| 136 |
+
samps_src = []
|
| 137 |
+
files = [wave_dict_src[key] for wave_dict_src in self.wave_dict_srcs]
|
| 138 |
+
for file in files:
|
| 139 |
+
if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
|
| 140 |
+
samps_tmp, _ = audio_lib.load(file, sr=self.fs)
|
| 141 |
+
samps_src.append(samps_tmp)
|
| 142 |
+
|
| 143 |
+
file = self.wave_dict_mix[key]
|
| 144 |
+
if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
|
| 145 |
+
samps_mix, _ = audio_lib.load(file, sr=self.fs)
|
| 146 |
+
|
| 147 |
+
# Truncate samples as needed
|
| 148 |
+
if len(samps_mix) % 4 != 0:
|
| 149 |
+
remains = len(samps_mix) % 4
|
| 150 |
+
samps_mix = samps_mix[:-remains]
|
| 151 |
+
samps_src = [s[:-remains] for s in samps_src]
|
| 152 |
+
|
| 153 |
+
if self.partition != "test":
|
| 154 |
+
if len(samps_mix) > self.max_len:
|
| 155 |
+
start = random.randint(0,len(samps_mix)-self.max_len)
|
| 156 |
+
samps_mix = samps_mix[start:start+self.max_len]
|
| 157 |
+
samps_src = [s[start:start+self.max_len] for s in samps_src]
|
| 158 |
+
|
| 159 |
+
return samps_mix, samps_src
|
| 160 |
+
|
| 161 |
+
def __getitem__(self, index):
|
| 162 |
+
key = self.wave_keys[index]
|
| 163 |
+
if any(key not in self.wave_dict_srcs[i] for i in range(len(self.wave_dict_srcs))) or key not in self.wave_dict_mix: raise KeyError(f"Could not find utterance {key}")
|
| 164 |
+
samps_mix, samps_src = self._dynamic_mixing(key) if self.dynamic_mixing else self._direct_load(key)
|
| 165 |
+
return {"num_sample": samps_mix.shape[0], "mix": samps_mix, "src": samps_src, "key": key}
|
models/SepReformer/SepReformer_Base_WSJ0/engine.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import csv
|
| 4 |
+
import time
|
| 5 |
+
import soundfile as sf
|
| 6 |
+
import librosa
|
| 7 |
+
from loguru import logger
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from utils import util_engine, functions
|
| 10 |
+
from utils.decorators import *
|
| 11 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@logger_wraps()
|
| 15 |
+
class Engine(object):
|
| 16 |
+
def __init__(self, args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device):
|
| 17 |
+
|
| 18 |
+
''' Default setting '''
|
| 19 |
+
self.engine_mode = args.engine_mode
|
| 20 |
+
self.out_wav_dir = args.out_wav_dir
|
| 21 |
+
self.config = config
|
| 22 |
+
self.gpuid = gpuid
|
| 23 |
+
self.device = device
|
| 24 |
+
self.model = model.to(self.device)
|
| 25 |
+
self.dataloaders = dataloaders # self.dataloaders['train'] or ['valid'] or ['test']
|
| 26 |
+
self.PIT_SISNR_mag_loss, self.PIT_SISNR_time_loss, self.PIT_SISNRi_loss, self.PIT_SDRi_loss = criterions
|
| 27 |
+
self.main_optimizer = optimizers[0]
|
| 28 |
+
self.main_scheduler, self.warmup_scheduler = schedulers
|
| 29 |
+
|
| 30 |
+
self.pretrain_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "pretrain_weights")
|
| 31 |
+
os.makedirs(self.pretrain_weights_path, exist_ok=True)
|
| 32 |
+
self.scratch_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "scratch_weights")
|
| 33 |
+
os.makedirs(self.scratch_weights_path, exist_ok=True)
|
| 34 |
+
|
| 35 |
+
self.checkpoint_path = self.pretrain_weights_path if any(file.endswith(('.pt', '.pt', '.pkl')) for file in os.listdir(self.pretrain_weights_path)) else self.scratch_weights_path
|
| 36 |
+
self.start_epoch = util_engine.load_last_checkpoint_n_get_epoch(self.checkpoint_path, self.model, self.main_optimizer, location=self.device)
|
| 37 |
+
|
| 38 |
+
# Logging
|
| 39 |
+
util_engine.model_params_mac_summary(
|
| 40 |
+
model=self.model,
|
| 41 |
+
input=torch.randn(1, self.config['check_computations']['dummy_len']).to(self.device),
|
| 42 |
+
dummy_input=torch.rand(1, self.config['check_computations']['dummy_len']).to(self.device),
|
| 43 |
+
metrics=['ptflops', 'thop', 'torchinfo']
|
| 44 |
+
# metrics=['ptflops']
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
logger.info(f"Clip gradient by 2-norm {self.config['engine']['clip_norm']}")
|
| 48 |
+
|
| 49 |
+
@logger_wraps()
|
| 50 |
+
def _train(self, dataloader, epoch):
|
| 51 |
+
self.model.train()
|
| 52 |
+
tot_loss_freq = [0 for _ in range(self.model.num_stages)]
|
| 53 |
+
tot_loss_time, num_batch = 0, 0
|
| 54 |
+
pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:25}{r_bar}{bar:-10b}', colour="YELLOW", dynamic_ncols=True)
|
| 55 |
+
for input_sizes, mixture, src, _ in dataloader:
|
| 56 |
+
nnet_input = mixture
|
| 57 |
+
nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
|
| 58 |
+
num_batch += 1
|
| 59 |
+
pbar.update(1)
|
| 60 |
+
# Scheduler learning rate for warm-up (Iteration-based update for transformers)
|
| 61 |
+
if epoch == 1: self.warmup_scheduler.step()
|
| 62 |
+
nnet_input = nnet_input.to(self.device)
|
| 63 |
+
self.main_optimizer.zero_grad()
|
| 64 |
+
estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
|
| 65 |
+
cur_loss_s_bn = 0
|
| 66 |
+
cur_loss_s_bn = []
|
| 67 |
+
for idx, estim_src_value in enumerate(estim_src_bn):
|
| 68 |
+
cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
|
| 69 |
+
tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
|
| 70 |
+
cur_loss_s = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
|
| 71 |
+
tot_loss_time += cur_loss_s.item() / self.config['model']['num_spks']
|
| 72 |
+
alpha = 0.4 * 0.8**(1+(epoch-101)//5) if epoch > 100 else 0.4
|
| 73 |
+
cur_loss = (1-alpha) * cur_loss_s + alpha * sum(cur_loss_s_bn) / len(cur_loss_s_bn)
|
| 74 |
+
cur_loss = cur_loss / self.config['model']['num_spks']
|
| 75 |
+
cur_loss.backward()
|
| 76 |
+
if self.config['engine']['clip_norm']: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['engine']['clip_norm'])
|
| 77 |
+
self.main_optimizer.step()
|
| 78 |
+
dict_loss = {"T_Loss": tot_loss_time / num_batch}
|
| 79 |
+
dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
|
| 80 |
+
pbar.set_postfix(dict_loss)
|
| 81 |
+
pbar.close()
|
| 82 |
+
tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
|
| 83 |
+
return tot_loss_time / num_batch, tot_loss_freq / num_batch, num_batch
|
| 84 |
+
|
| 85 |
+
@logger_wraps()
|
| 86 |
+
def _validate(self, dataloader):
|
| 87 |
+
self.model.eval()
|
| 88 |
+
tot_loss_freq = [0 for _ in range(self.model.num_stages)]
|
| 89 |
+
tot_loss_time, num_batch = 0, 0
|
| 90 |
+
pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="RED", dynamic_ncols=True)
|
| 91 |
+
with torch.inference_mode():
|
| 92 |
+
for input_sizes, mixture, src, _ in dataloader:
|
| 93 |
+
nnet_input = mixture
|
| 94 |
+
nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
|
| 95 |
+
nnet_input = nnet_input.to(self.device)
|
| 96 |
+
num_batch += 1
|
| 97 |
+
pbar.update(1)
|
| 98 |
+
estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
|
| 99 |
+
cur_loss_s_bn = []
|
| 100 |
+
for idx, estim_src_value in enumerate(estim_src_bn):
|
| 101 |
+
cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
|
| 102 |
+
tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
|
| 103 |
+
cur_loss_s_SDR = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
|
| 104 |
+
tot_loss_time += cur_loss_s_SDR.item() / self.config['model']['num_spks']
|
| 105 |
+
dict_loss = {"T_Loss":tot_loss_time / num_batch}
|
| 106 |
+
dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
|
| 107 |
+
pbar.set_postfix(dict_loss)
|
| 108 |
+
pbar.close()
|
| 109 |
+
tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
|
| 110 |
+
return tot_loss_time / num_batch, tot_loss_freq / num_batch, num_batch
|
| 111 |
+
|
| 112 |
+
@logger_wraps()
|
| 113 |
+
def _test(self, dataloader, wav_dir=None):
|
| 114 |
+
self.model.eval()
|
| 115 |
+
total_loss_SISNRi, total_loss_SDRi, num_batch = 0, 0, 0
|
| 116 |
+
pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="grey", dynamic_ncols=True)
|
| 117 |
+
with torch.inference_mode():
|
| 118 |
+
csv_file_name_sisnr = os.path.join(os.path.dirname(__file__),'test_SISNRi_value.csv')
|
| 119 |
+
csv_file_name_sdr = os.path.join(os.path.dirname(__file__),'test_SDRi_value.csv')
|
| 120 |
+
with open(csv_file_name_sisnr, 'w', newline='') as csvfile_sisnr, open(csv_file_name_sdr, 'w', newline='') as csvfile_sdr:
|
| 121 |
+
idx = 0
|
| 122 |
+
writer_sisnr = csv.writer(csvfile_sisnr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
|
| 123 |
+
writer_sdr = csv.writer(csvfile_sdr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
|
| 124 |
+
for input_sizes, mixture, src, key in dataloader:
|
| 125 |
+
if len(key) > 1:
|
| 126 |
+
raise("batch size is not one!!")
|
| 127 |
+
nnet_input = mixture.to(self.device)
|
| 128 |
+
num_batch += 1
|
| 129 |
+
pbar.update(1)
|
| 130 |
+
estim_src, _ = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
|
| 131 |
+
cur_loss_SISNRi, cur_loss_SISNRi_src = self.PIT_SISNRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src, eps=1.0e-15)
|
| 132 |
+
total_loss_SISNRi += cur_loss_SISNRi.item() / self.config['model']['num_spks']
|
| 133 |
+
cur_loss_SDRi, cur_loss_SDRi_src = self.PIT_SDRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src)
|
| 134 |
+
total_loss_SDRi += cur_loss_SDRi.item() / self.config['model']['num_spks']
|
| 135 |
+
writer_sisnr.writerow([key[0][:-4]] + [cur_loss_SISNRi_src[i].item() for i in range(self.config['model']['num_spks'])])
|
| 136 |
+
writer_sdr.writerow([key[0][:-4]] + [cur_loss_SDRi_src[i].item() for i in range(self.config['model']['num_spks'])])
|
| 137 |
+
if self.engine_mode == "test_save":
|
| 138 |
+
if wav_dir == None: wav_dir = os.path.join(os.path.dirname(__file__),"wav_out")
|
| 139 |
+
if wav_dir and not os.path.exists(wav_dir): os.makedirs(wav_dir)
|
| 140 |
+
mixture = torch.squeeze(mixture).cpu().data.numpy()
|
| 141 |
+
sf.write(os.path.join(wav_dir,key[0][:-4]+str(idx)+'_mixture.wav'), 0.5*mixture/max(abs(mixture)), 8000)
|
| 142 |
+
for i in range(self.config['model']['num_spks']):
|
| 143 |
+
src = torch.squeeze(estim_src[i]).cpu().data.numpy()
|
| 144 |
+
sf.write(os.path.join(wav_dir,key[0][:-4]+str(idx)+'_out_'+str(i)+'.wav'), 0.5*src/max(abs(src)), 8000)
|
| 145 |
+
idx += 1
|
| 146 |
+
dict_loss = {"SiSNRi": total_loss_SISNRi/num_batch, "SDRi": total_loss_SDRi/num_batch}
|
| 147 |
+
pbar.set_postfix(dict_loss)
|
| 148 |
+
pbar.close()
|
| 149 |
+
return total_loss_SISNRi/num_batch, total_loss_SDRi/num_batch, num_batch
|
| 150 |
+
|
| 151 |
+
@logger_wraps()
|
| 152 |
+
def _inference_sample(self, sample):
|
| 153 |
+
self.model.eval()
|
| 154 |
+
self.fs = self.config["dataset"]["sampling_rate"]
|
| 155 |
+
mixture, _ = librosa.load(sample,sr=self.fs)
|
| 156 |
+
mixture = torch.tensor(mixture, dtype=torch.float32)[None]
|
| 157 |
+
self.stride = self.config["model"]["module_audio_enc"]["stride"]
|
| 158 |
+
remains = mixture.shape[-1] % self.stride
|
| 159 |
+
if remains != 0:
|
| 160 |
+
padding = self.stride - remains
|
| 161 |
+
mixture_padded = torch.nn.functional.pad(mixture, (0, padding), "constant", 0)
|
| 162 |
+
else:
|
| 163 |
+
mixture_padded = mixture
|
| 164 |
+
|
| 165 |
+
with torch.inference_mode():
|
| 166 |
+
nnet_input = mixture_padded.to(self.device)
|
| 167 |
+
estim_src, _ = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
|
| 168 |
+
mixture = torch.squeeze(mixture).cpu().numpy()
|
| 169 |
+
sf.write(sample[:-4]+'_in.wav', 0.9*mixture/max(abs(mixture)), self.fs)
|
| 170 |
+
for i in range(self.config['model']['num_spks']):
|
| 171 |
+
src = torch.squeeze(estim_src[i][...,:mixture.shape[-1]]).cpu().data.numpy()
|
| 172 |
+
sf.write(sample[:-4]+'_out_'+str(i)+'.wav', 0.9*src/max(abs(src)), self.fs)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
@logger_wraps()
def run(self):
    # Top-level entry point: run a test pass when engine_mode contains
    # "test"; otherwise run the train/validate loop with LR scheduling,
    # best-checkpoint saving, and TensorBoard logging.
    with torch.cuda.device(self.device):
        writer_src = SummaryWriter(os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/tensorboard"))
        if "test" in self.engine_mode:
            on_test_start = time.time()
            test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'], self.out_wav_dir)
            on_test_end = time.time()
            logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
            logger.info(f"Testing done!")
        else:
            # Measure the initial validation loss when resuming from a checkpoint.
            start_time = time.time()
            if self.start_epoch > 1:
                init_loss_time, init_loss_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
            else:
                init_loss_time, init_loss_freq = 0, 0
            end_time = time.time()
            logger.info(f"[INIT] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: Loss_t = {init_loss_time:.4f} dB | Loss_f = {init_loss_freq:.4f} dB | Speed = ({end_time-start_time:.2f}s)")
            # FIX: track the best validation loss ACROSS epochs. The original
            # re-assigned `valid_loss_best = init_loss_time` inside the loop,
            # so the updated value returned by save_checkpoint_per_best was
            # discarded and every epoch compared against the initial loss.
            valid_loss_best = init_loss_time
            for epoch in range(self.start_epoch, self.config['engine']['max_epoch']):
                train_start_time = time.time()
                train_loss_src_time, train_loss_src_freq, train_num_batch = self._train(self.dataloaders['train'], epoch)
                train_end_time = time.time()
                valid_start_time = time.time()
                valid_loss_src_time, valid_loss_src_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
                valid_end_time = time.time()
                # Plateau-style LR scheduling on the validation time-domain loss.
                if epoch > self.config['engine']['start_scheduling']: self.main_scheduler.step(valid_loss_src_time)
                logger.info(f"[TRAIN] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {train_loss_src_time:.4f} dB | Loss_f = {train_loss_src_freq:.4f} dB | Speed = ({train_end_time - train_start_time:.2f}s/{train_num_batch:d})")
                logger.info(f"[VALID] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {valid_loss_src_time:.4f} dB | Loss_f = {valid_loss_src_freq:.4f} dB | Speed = ({valid_end_time - valid_start_time:.2f}s/{valid_num_batch:d})")
                if epoch in self.config['engine']['test_epochs']:
                    on_test_start = time.time()
                    test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'])
                    on_test_end = time.time()
                    logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
                valid_loss_best = util_engine.save_checkpoint_per_best(valid_loss_best, valid_loss_src_time, train_loss_src_time, epoch, self.model, self.main_optimizer, self.checkpoint_path, self.wandb_run)
                # Logging to monitoring tools (Tensorboard && Wandb)
                writer_src.add_scalars("Metrics", {
                    'Loss_train_time': train_loss_src_time,
                    'Loss_valid_time': valid_loss_src_time}, epoch)
                # FIX: a single scalar must be logged with add_scalar();
                # add_scalars() expects a dict of tag->value and raises on a
                # raw float like param_groups[0]['lr'].
                writer_src.add_scalar("Learning Rate", self.main_optimizer.param_groups[0]['lr'], epoch)
                writer_src.flush()
            logger.info(f"Training for {self.config['engine']['max_epoch']} epoches done!")
|
models/SepReformer/SepReformer_Base_WSJ0/log/scratch_weights/epoch.0180.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:14569febefb19900026a350c7b31ca6a927ce4bac7fa83269902d8c6437f0d11
|
| 3 |
+
size 134
|
models/SepReformer/SepReformer_Base_WSJ0/main.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from loguru import logger
|
| 4 |
+
from .dataset import get_dataloaders
|
| 5 |
+
from .model import Model
|
| 6 |
+
from .engine import Engine
|
| 7 |
+
from utils import util_system, util_implement
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
|
| 10 |
+
# Setup logger
|
| 11 |
+
log_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/system_log.log")
|
| 12 |
+
logger.add(log_file_path, level="DEBUG", mode="w")
|
| 13 |
+
|
| 14 |
+
@logger_wraps()
def main(args):
    """Build dataloaders, model, and engine from configs.yaml, then run."""

    # --- Configuration -----------------------------------------------------
    # Load configs.yaml located next to this file.
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs.yaml")
    parsed_yaml = util_system.parse_yaml(config_path)
    config = parsed_yaml["config"]  # wandb login success or fail

    # --- Data: train / valid / test / etc... -------------------------------
    dataloaders = get_dataloaders(args, config["dataset"], config["dataloader"])

    # --- Model --------------------------------------------------------------
    model = Model(**config["model"])

    # --- Devices: parse the comma-separated GPU id list ---------------------
    gpuid = tuple(int(gid) for gid in config["engine"]["gpuid"].split(','))
    device = torch.device(f'cuda:{gpuid[0]}')

    # --- Criterion / optimizer / scheduler factories ------------------------
    criterions = util_implement.CriterionFactory(config["criterion"], device).get_criterions()
    optimizers = util_implement.OptimizerFactory(config["optimizer"], model.parameters()).get_optimizers()
    schedulers = util_implement.SchedulerFactory(config["scheduler"], optimizers).get_schedulers()

    # --- Engine: single-sample inference or the full run --------------------
    engine = Engine(args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device)
    if args.engine_mode == 'infer_sample':
        engine._inference_sample(args.sample_file)
    else:
        engine.run()
|
models/SepReformer/SepReformer_Base_WSJ0/model.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('../')
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import warnings
|
| 6 |
+
warnings.filterwarnings('ignore')
|
| 7 |
+
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
from .modules.module import *
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@logger_wraps()
class Model(torch.nn.Module):
    """
    SepReformer separation model.

    Pipeline: audio encoder -> feature projector -> separator ->
    output layer -> audio decoder. Each separator stage also feeds an
    auxiliary output head (mask + decoder) for multi-stage supervision.
    """
    def __init__(self,
                 num_stages: int,
                 num_spks: int,
                 module_audio_enc: dict,
                 module_feature_projector: dict,
                 module_separator: dict,
                 module_output_layer: dict,
                 module_audio_dec: dict):
        super().__init__()
        self.num_stages = num_stages
        self.num_spks = num_spks
        self.audio_encoder = AudioEncoder(**module_audio_enc)
        self.feature_projector = FeatureProjector(**module_feature_projector)
        self.separator = Separator(**module_separator)
        self.out_layer = OutputLayer(**module_output_layer)
        self.audio_decoder = AudioDecoder(**module_audio_dec)

        # Aux_loss: one masking output layer + one decoder per separator stage.
        self.out_layer_bn = torch.nn.ModuleList([])
        self.decoder_bn = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            self.out_layer_bn.append(OutputLayer(**module_output_layer, masking=True))
            self.decoder_bn.append(AudioDecoder(**module_audio_dec))

    def forward(self, x):
        """Separate mixture `x` into per-speaker audio plus per-stage aux audio.

        Returns (audio, audio_aux): a list of num_spks waveforms and, per
        stage, a list of num_spks auxiliary waveforms trimmed to len(x).
        """
        encoder_output = self.audio_encoder(x)
        projected_feature = self.feature_projector(encoder_output)
        last_stage_output, each_stage_outputs = self.separator(projected_feature)
        out_layer_output = self.out_layer(last_stage_output, encoder_output)
        each_spk_output = [out_layer_output[idx] for idx in range(self.num_spks)]
        audio = [self.audio_decoder(each_spk_output[idx]) for idx in range(self.num_spks)]

        # Aux_loss: bring each intermediate stage up to the encoder resolution,
        # mask the encoder output with it, decode, and trim to the input length.
        audio_aux = []
        for idx, each_stage_output in enumerate(each_stage_outputs):
            # FIX: interpolate() replaces the deprecated F.upsample alias
            # (same operation; default nearest-neighbor mode).
            each_stage_output = self.out_layer_bn[idx](torch.nn.functional.interpolate(each_stage_output, encoder_output.shape[-1]), encoder_output)
            out_aux = [each_stage_output[jdx] for jdx in range(self.num_spks)]
            audio_aux.append([self.decoder_bn[idx](out_aux[jdx])[...,:x.shape[-1]] for jdx in range(self.num_spks)])

        return audio, audio_aux
|
models/SepReformer/SepReformer_Base_WSJ0/modules/module.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('../')
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import warnings
|
| 6 |
+
warnings.filterwarnings('ignore')
|
| 7 |
+
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
from .network import *
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class AudioEncoder(torch.nn.Module):
    """Convolutional front-end: strided 1-D convolution followed by GELU.

    Accepts a raw waveform of shape [T] or [B, T] and inserts the channel
    axis itself before convolving.
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, groups: int, bias: bool):
        super().__init__()
        self.conv1d = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, groups=groups, bias=bias)
        self.gelu = torch.nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Insert the channel axis: [T] -> [1, T], or [B, T] -> [B, 1, T].
        channel_dim = 0 if x.dim() == 1 else 1
        features = self.conv1d(x.unsqueeze(channel_dim))
        return self.gelu(features)
|
| 24 |
+
|
| 25 |
+
class FeatureProjector(torch.nn.Module):
    """Global layer norm (GroupNorm with one group) followed by a 1x1-style
    Conv1d projection of the encoder features."""
    def __init__(self, num_channels: int, in_channels: int, out_channels: int, kernel_size: int, bias: bool):
        super().__init__()
        self.norm = torch.nn.GroupNorm(num_groups=1, num_channels=num_channels, eps=1e-8)
        self.conv1d = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize, then project channel dimension.
        return self.conv1d(self.norm(x))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Separator(torch.nn.Module):
    # U-Net style separator: `num_stages` contracting stages halve the time
    # resolution, a bottleneck splits features per speaker, and `num_stages`
    # expanding stages upsample and fuse the speaker-split skip connections.
    def __init__(self, num_stages: int, relative_positional_encoding: dict, enc_stage: dict, spk_split_stage: dict, simple_fusion:dict, dec_stage: dict):
        super().__init__()

        # NOTE: the building-block classes are declared inside __init__, so
        # they stay local to Separator; they capture nothing from this scope.
        class RelativePositionalEncoding(torch.nn.Module):
            # Learned relative-position embeddings (keys, and optionally
            # values) indexed by clamped pairwise time distances.
            def __init__(self, in_channels: int, num_heads: int, maxlen: int, embed_v=False):
                super().__init__()
                self.in_channels = in_channels
                self.num_heads = num_heads
                self.embedding_dim = self.in_channels // self.num_heads  # per-head dim
                self.maxlen = maxlen
                self.pe_k = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim)
                self.pe_v = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim) if embed_v else None

            def forward(self, pos_seq: torch.Tensor):
                # Clamp distances to [-maxlen, maxlen-1] and shift into the
                # embedding index range [0, 2*maxlen). NOTE: both operations
                # are in-place and mutate the caller's pos_seq tensor.
                pos_seq.clamp_(-self.maxlen, self.maxlen - 1)
                pos_seq += self.maxlen
                pe_k_output = self.pe_k(pos_seq)
                pe_v_output = self.pe_v(pos_seq) if self.pe_v is not None else None
                return pe_k_output, pe_v_output

        class SepEncStage(torch.nn.Module):
            # One contracting stage: two (global, local) block pairs, then an
            # optional stride-2 depthwise down-convolution that halves T.
            def __init__(self, global_blocks: dict, local_blocks: dict, down_conv_layer: dict, down_conv=True):
                super().__init__()

                class DownConvLayer(torch.nn.Module):
                    def __init__(self, in_channels: int, samp_kernel_size: int):
                        """Construct an EncoderLayer object."""
                        super().__init__()
                        # Depthwise conv with stride 2; padding keeps T/2 length.
                        self.down_conv = torch.nn.Conv1d(
                            in_channels=in_channels, out_channels=in_channels, kernel_size=samp_kernel_size, stride=2, padding=(samp_kernel_size-1)//2, groups=in_channels)
                        self.BN = torch.nn.BatchNorm1d(num_features=in_channels)
                        self.gelu = torch.nn.GELU()

                    def forward(self, x: torch.Tensor):
                        # Input arrives channel-last here; conv/BN need (B, C, T).
                        x = x.permute([0, 2, 1])
                        x = self.down_conv(x)
                        x = self.BN(x)
                        x = self.gelu(x)
                        x = x.permute([0, 2, 1])
                        return x

                self.g_block_1 = GlobalBlock(**global_blocks)
                self.l_block_1 = LocalBlock(**local_blocks)

                self.g_block_2 = GlobalBlock(**global_blocks)
                self.l_block_2 = LocalBlock(**local_blocks)

                self.downconv = DownConvLayer(**down_conv_layer) if down_conv == True else None

            def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
                '''
                x: [B, N, T]
                '''
                # NOTE(review): the permutes below alternate between (B, N, T)
                # and channel-last layouts around the global/local blocks —
                # verify against GlobalBlock/LocalBlock layout conventions.
                x = self.g_block_1(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_1(x)
                x = x.permute(0, 2, 1).contiguous()

                x = self.g_block_2(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_2(x)
                x = x.permute(0, 2, 1).contiguous()

                # Skip connection is taken BEFORE downsampling (full resolution).
                skip = x
                if self.downconv:
                    x = x.permute(0, 2, 1).contiguous()
                    x = self.downconv(x)
                    x = x.permute(0, 2, 1).contiguous()
                # [BK, S, N]
                return x, skip

        class SpkSplitStage(torch.nn.Module):
            # Expands features to num_spks copies via a gated 1x1 bottleneck,
            # then folds the speaker axis into the batch dimension.
            def __init__(self, in_channels: int, num_spks: int):
                super().__init__()
                self.linear = torch.nn.Sequential(
                    torch.nn.Conv1d(in_channels, 4*in_channels*num_spks, kernel_size=1),
                    torch.nn.GLU(dim=-2),
                    torch.nn.Conv1d(2*in_channels*num_spks, in_channels*num_spks, kernel_size=1))
                self.norm = torch.nn.GroupNorm(1, in_channels, eps=1e-8)
                self.num_spks = num_spks

            def forward(self, x: torch.Tensor):
                x = self.linear(x)
                B, _, T = x.shape
                # [B, num_spks*C, T] -> [B*num_spks, C, T]
                x = x.view(B*self.num_spks,-1, T).contiguous()
                x = self.norm(x)
                return x

        class SepDecStage(torch.nn.Module):
            # One expanding stage: three (global, local, speaker-attention)
            # triples operating on the speaker-folded batch.
            def __init__(self, num_spks: int, global_blocks: dict, local_blocks: dict, spk_attention: dict):
                super().__init__()

                self.g_block_1 = GlobalBlock(**global_blocks)
                self.l_block_1 = LocalBlock(**local_blocks)
                self.spk_attn_1 = SpkAttention(**spk_attention)

                self.g_block_2 = GlobalBlock(**global_blocks)
                self.l_block_2 = LocalBlock(**local_blocks)
                self.spk_attn_2 = SpkAttention(**spk_attention)

                self.g_block_3 = GlobalBlock(**global_blocks)
                self.l_block_3 = LocalBlock(**local_blocks)
                self.spk_attn_3 = SpkAttention(**spk_attention)

                self.num_spk = num_spks

            def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
                '''
                x: [B, N, T]
                '''
                # [BS, K, H]
                x = self.g_block_1(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_1(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_1(x, self.num_spk)

                x = self.g_block_2(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_2(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_2(x, self.num_spk)

                x = self.g_block_3(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_3(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_3(x, self.num_spk)

                skip = x

                return x, skip

        self.num_stages = num_stages
        self.pos_emb = RelativePositionalEncoding(**relative_positional_encoding)

        # Temporal Contracting Part
        self.enc_stages = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            self.enc_stages.append(SepEncStage(**enc_stage, down_conv=True))

        # Bottleneck keeps the time resolution (no down-conv).
        self.bottleneck_G = SepEncStage(**enc_stage, down_conv=False)
        self.spk_split_block = SpkSplitStage(**spk_split_stage)

        # Temporal Expanding Part
        self.simple_fusion = torch.nn.ModuleList([])
        self.dec_stages = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            # 1x1 conv merges the upsampled path with the matching skip.
            self.simple_fusion.append(torch.nn.Conv1d(in_channels=simple_fusion['out_channels']*2,out_channels=simple_fusion['out_channels'], kernel_size=1))
            self.dec_stages.append(SepDecStage(**dec_stage))

    def forward(self, input: torch.Tensor):
        '''input: [B, N, L]'''
        # feature projection
        x, _ = self.pad_signal(input)
        len_x = x.shape[-1]
        # Temporal Contracting Part
        # Pairwise relative distances at the bottleneck time resolution
        # (len_x / 2**num_stages); shared by every attention block.
        pos_seq = torch.arange(0, len_x//2**self.num_stages).long().to(x.device)
        pos_seq = pos_seq[:, None] - pos_seq[None, :]
        pos_k, _ = self.pos_emb(pos_seq)
        skip = []
        for idx in range(self.num_stages):
            x, skip_ = self.enc_stages[idx](x, pos_k)
            skip_ = self.spk_split_block(skip_)  # speaker-split the skip too
            skip.append(skip_)
        x, _ = self.bottleneck_G(x, pos_k)
        x = self.spk_split_block(x)  # B, 2F, T

        each_stage_outputs = []
        # Temporal Expanding Part
        for idx in range(self.num_stages):
            # Pre-upsample output of each stage feeds the auxiliary losses.
            each_stage_outputs.append(x)
            idx_en = self.num_stages - (idx + 1)  # matching contracting stage
            # NOTE: F.upsample is a deprecated alias of F.interpolate.
            x = torch.nn.functional.upsample(x, skip[idx_en].shape[-1])
            x = torch.cat([x,skip[idx_en]],dim=1)
            x = self.simple_fusion[idx](x)
            x, _ = self.dec_stages[idx](x, pos_k)

        last_stage_output = x
        return last_stage_output, each_stage_outputs

    def pad_signal(self, input: torch.Tensor):
        # Zero-pad the time axis so its length is divisible by 2**num_stages
        # (each contracting stage halves T). Returns (padded, pad_amount).
        # (B, T) or (B, 1, T)
        # NOTE(review): a 1-D input is unsqueezed to 2-D but then indexed with
        # size(2) below, which would fail; callers appear to pass 3-D only.
        if input.dim() == 1: input = input.unsqueeze(0)
        elif input.dim() not in [2, 3]: raise RuntimeError("Input can only be 2 or 3 dimensional.")
        elif input.dim() == 2: input = input.unsqueeze(1)
        L = 2**self.num_stages
        batch_size = input.size(0)
        ndim = input.size(1)
        nframe = input.size(2)
        padded_len = (nframe//L + 1)*L
        rest = 0 if nframe%L == 0 else padded_len - nframe
        if rest > 0:
            # torch.autograd.Variable is a legacy no-op wrapper, kept as-is.
            pad = torch.autograd.Variable(torch.zeros(batch_size, ndim, rest)).type(input.type()).to(input.device)
            input = torch.cat([input, pad], dim=-1)
        return input, rest
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class OutputLayer(torch.nn.Module):
    # Projects separator features back to the encoder dimension and, when
    # `masking` is set, applies them as a multiplicative (ReLU-gated) mask
    # on the encoder output. Returns features with a leading speaker axis.
    def __init__(self, in_channels: int, out_channels: int, num_spks: int, masking: bool = False):
        super().__init__()
        # feature expansion back
        self.masking = masking
        self.spe_block = Masking(in_channels, Activation_mask="ReLU", concat_opt=None)
        self.num_spks = num_spks
        # Gated channel expansion: out_channels -> 4x -> GLU halves -> in_channels.
        self.end_conv1x1 = torch.nn.Sequential(
            torch.nn.Linear(out_channels, 4*out_channels),
            torch.nn.GLU(),
            torch.nn.Linear(2*out_channels, in_channels))

    def forward(self, x: torch.Tensor, input: torch.Tensor):
        # x: separator output with speakers folded into the batch,
        #    presumably [B*num_spks, out_channels, T'] — TODO confirm.
        # input: encoder output (mixture features) [B, in_channels, L].
        x = x[...,:input.shape[-1]]   # trim any padding beyond encoder length
        x = x.permute([0, 2, 1])      # channel-last for the Linear layers
        x = self.end_conv1x1(x)
        x = x.permute([0, 2, 1])
        B, N, L = x.shape
        B = B // self.num_spks        # recover the true batch size

        if self.masking:
            # Tile the mixture features once per speaker, matching x's
            # speaker-folded batch layout, then gate them with the mask.
            input = input.expand(self.num_spks, B, N, L).transpose(0,1).contiguous()
            input = input.view(B*self.num_spks, N, L)
            x = self.spe_block(x, input)

        x = x.view(B, self.num_spks, N, L)
        # [spks, B, N, L]
        x = x.transpose(0, 1)
        return x
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class AudioDecoder(torch.nn.ConvTranspose1d):
    '''
    Decoder of the TasNet
    This module can be seen as the gradient of Conv1d with respect to its input.
    It is also known as a fractionally-strided convolution
    or a deconvolution (although it is not an actual deconvolution operation).
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        # x: [B, N, L] (or [N, L]); returns the reconstructed waveform with
        # singleton dimensions squeezed away.
        if x.dim() not in [2, 3]:
            # FIX: instances have no `__name__` attribute, so the original
            # `self.__name__` raised AttributeError instead of the intended
            # RuntimeError; the message also wrongly said "3/4D".
            raise RuntimeError("{} accepts 2/3D tensor as input".format(type(self).__name__))
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
        # Squeeze the channel axis; keep batch axis unless output is a single frame.
        x = torch.squeeze(x, dim=1) if torch.squeeze(x).dim() == 1 else torch.squeeze(x)
        return x
|
models/SepReformer/SepReformer_Base_WSJ0/modules/network.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
import numpy
|
| 4 |
+
from utils.decorators import *
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LayerScale(torch.nn.Module):
    """Learnable per-channel residual scaling (LayerScale).

    `dims` selects how many leading broadcast axes precede the channel axis
    (1 -> [C], 2 -> [1, C], 3 -> [1, 1, C]); every channel starts at
    `Layer_scale_init`.
    """
    def __init__(self, dims, input_size, Layer_scale_init=1.0e-5):
        super().__init__()
        if dims == 1:
            shape = (input_size,)
        elif dims == 2:
            shape = (1, input_size)
        elif dims == 3:
            shape = (1, 1, input_size)
        self.layer_scale = torch.nn.Parameter(torch.ones(*shape) * Layer_scale_init, requires_grad=True)

    def forward(self, x):
        # Broadcast-multiply the input by the learned scales.
        return x * self.layer_scale
|
| 19 |
+
|
| 20 |
+
class Masking(torch.nn.Module):
    """Gated masking block: apply an activation gate to `x` (optionally after
    concatenating it with `skip` through a pointwise conv) and multiply the
    gate onto `skip`.

    :param input_dim: channel count of `x` and `skip`
    :param Activation_mask: 'Sigmoid' or 'ReLU' gate nonlinearity
    :param options: may contain 'concat_opt'; truthy enables the concat path
    """
    def __init__(self, input_dim, Activation_mask='Sigmoid', **options):
        super(Masking, self).__init__()

        self.options = options
        # FIX: .get() tolerates a missing 'concat_opt' key; the original
        # subscript raised KeyError when the kwarg was omitted.
        if self.options.get('concat_opt'):
            self.pw_conv = torch.nn.Conv1d(input_dim*2, input_dim, 1, stride=1, padding=0)

        if Activation_mask == 'Sigmoid':
            self.gate_act = torch.nn.Sigmoid()
        elif Activation_mask == 'ReLU':
            self.gate_act = torch.nn.ReLU()

    def forward(self, x, skip):
        # Optional fusion of x and skip before gating.
        if self.options.get('concat_opt'):
            y = torch.cat([x, skip], dim=-2)  # concatenate along channels
            y = self.pw_conv(y)
        else:
            y = x
        # Gate value in [0, inf) or [0, 1] multiplies the skip features.
        y = self.gate_act(y) * skip

        return y
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class GCFN(torch.nn.Module):
    """Gated convolutional feed-forward network.

    Channel-last input is expanded 6x, passed through a depthwise temporal
    convolution, gated (GLU) back down, projected to the input width, and
    added as a LayerScale-weighted residual.
    """
    def __init__(self, in_channels, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        self.net1 = torch.nn.Sequential(
            torch.nn.LayerNorm(in_channels),
            torch.nn.Linear(in_channels, in_channels*6))
        self.depthwise = torch.nn.Conv1d(in_channels*6, in_channels*6, 3, padding=1, groups=in_channels*6)
        self.net2 = torch.nn.Sequential(
            torch.nn.GLU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(in_channels*3, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        hidden = self.net1(x)
        # The depthwise conv needs (B, C, T); the rest of the block is channel-last.
        hidden = self.depthwise(hidden.transpose(1, 2).contiguous())
        hidden = self.net2(hidden.transpose(1, 2).contiguous())
        # Scaled residual connection.
        return x + self.Layer_scale(hidden)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class MultiHeadAttention(torch.nn.Module):
    """
    Multi-Head Attention layer with optional relative positional bias
    and a LayerScale on the output projection (pre-norm, no residual).
    :param int n_head: the number of head s
    :param int n_feat: the number of features
    :param float dropout_rate: dropout rate
    """
    def __init__(self, n_head: int, in_channels: int, dropout_rate: float, Layer_scale_init=1.0e-5):
        super().__init__()
        assert in_channels % n_head == 0
        self.d_k = in_channels // n_head  # We assume d_v always equals d_k
        self.h = n_head
        self.layer_norm = torch.nn.LayerNorm(in_channels)
        self.linear_q = torch.nn.Linear(in_channels, in_channels)
        self.linear_k = torch.nn.Linear(in_channels, in_channels)
        self.linear_v = torch.nn.Linear(in_channels, in_channels)
        self.linear_out = torch.nn.Linear(in_channels, in_channels)
        self.attn = None  # last attention map, kept for inspection/debugging
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x, pos_k, mask):
        """
        Compute 'Scaled Dot Product Attention'.
        :param torch.Tensor x: input (batch, time1, d_model); layer-normed here
        :param torch.Tensor pos_k: relative position key embeddings, or None;
            the matmul/view below implies shape (time1, time2, d_k) — TODO confirm
        :param torch.Tensor mask: (batch, time1, time2)
        :param torch.nn.Dropout dropout:
        :return torch.Tensor: attentined and transformed `value` (batch, time1, d_model)
            weighted by the query dot key attention (batch, head, time1, time2)
        """
        n_batch = x.size(0)
        x = self.layer_norm(x)
        q = self.linear_q(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
        k = self.linear_k(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
        v = self.linear_v(x).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
        # Content term: q·kᵀ per head.
        A = torch.matmul(q, k.transpose(-2, -1))
        # Fold batch and heads together so the positional term can be computed
        # with one batched matmul against pos_k: (time1, batch*head, d_k).
        reshape_q = q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0,1)
        if pos_k is not None:
            # Positional term q·pos_kᵀ, reshaped back to (batch, head, t1, t2).
            B = torch.matmul(reshape_q, pos_k.transpose(-2, -1))
            B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1))
            scores = (A + B) / math.sqrt(self.d_k)
        else:
            scores = A / math.sqrt(self.d_k)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, time1, time2)
            # Most negative finite value for the score dtype, so masked
            # positions vanish after softmax.
            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, v)  # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
        return self.Layer_scale(self.dropout(self.linear_out(x)))  # (batch, time1, d_model)
|
| 125 |
+
|
| 126 |
+
class EGA(torch.nn.Module):
    """Efficient global attention: run self-attention at a pooled (coarse)
    time resolution fixed by `pos_k`, upsample the result back, and gate it
    onto the input through a learned sigmoid gate."""
    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'self_attn': MultiHeadAttention(
                n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            'linear': torch.nn.Sequential(
                torch.nn.LayerNorm(normalized_shape=in_channels),
                torch.nn.Linear(in_features=in_channels, out_features=in_channels),
                torch.nn.Sigmoid())
        })

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """
        Compute encoded features.
        :param torch.Tensor x: encoded source features (batch, size, time)
        :param torch.Tensor pos_k: relative position embeddings; its first
            dimension fixes the pooled attention length
        :return torch.Tensor: gated output, channel-last (batch, time, size)
        """
        down_len = pos_k.shape[0]
        # Pool the time axis down to the attention resolution.
        x_down = torch.nn.functional.adaptive_avg_pool1d(input=x, output_size=down_len)
        x = x.permute([0, 2, 1])
        x_down = x_down.permute([0, 2, 1])
        x_down = self.block['self_attn'](x_down, pos_k, None)
        x_down = x_down.permute([0, 2, 1])
        # FIX: interpolate() replaces the deprecated F.upsample alias
        # (identical behavior; default nearest-neighbor mode).
        x_downup = torch.nn.functional.interpolate(input=x_down, size=x.shape[1])
        x_downup = x_downup.permute([0, 2, 1])
        # Sigmoid gate decides, per element, how much global context to mix in.
        x = x + self.block['linear'](x) * x_downup

        return x
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class CLA(torch.nn.Module):
    # Convolutional Local Attention block: GLU-gated pointwise expansion, a
    # depthwise 1-D convolution over time, BatchNorm, and a GELU projection
    # back to the input width, added residually through a learnable LayerScale.
    def __init__(self, in_channels, kernel_size, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(in_channels)
        self.linear1 = torch.nn.Linear(in_channels, in_channels*2)
        self.GLU = torch.nn.GLU()
        # padding='same' keeps the time length; groups=in_channels => depthwise.
        self.dw_conv_1d = torch.nn.Conv1d(in_channels, in_channels, kernel_size, padding='same', groups=in_channels)
        self.linear2 = torch.nn.Linear(in_channels, 2*in_channels)
        self.BN = torch.nn.BatchNorm1d(2*in_channels)
        self.linear3 = torch.nn.Sequential(
            torch.nn.GELU(),
            torch.nn.Linear(2*in_channels, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        # x: (B, T, F)
        y = self.layer_norm(x)
        y = self.linear1(y)  # (B, T, 2F)
        y = self.GLU(y)  # gate halves channels back to (B, T, F)
        y = y.permute([0, 2, 1]) # B, F, T
        y = self.dw_conv_1d(y)
        y = y.permute(0, 2, 1) # B, T, F  (original comment said 2F; conv keeps F)
        y = self.linear2(y)  # (B, T, 2F)
        y = y.permute(0, 2, 1) # B, 2F, T -- BatchNorm1d normalizes the channel dim
        y = self.BN(y)
        y = y.permute(0, 2, 1) # B, T, 2F
        y = self.linear3(y)  # back to (B, T, F)

        return x + self.Layer_scale(y)
|
| 188 |
+
|
| 189 |
+
class GlobalBlock(torch.nn.Module):
    """Global transformer block: efficient global attention (EGA) followed by
    a gated convolutional feed-forward network (GCFN)."""

    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'ega': EGA(
                num_mha_heads=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            'gcfn': GCFN(in_channels=in_channels, dropout_rate=dropout_rate)
        })

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """
        Compute encoded features.
        :param torch.Tensor x: encoded source features
        :param torch.Tensor pos_k: relative positional encoding for EGA
        :rtype: torch.Tensor permuted to (batch, size, time)
        """
        attended = self.block['ega'](x, pos_k)
        refined = self.block['gcfn'](attended)
        return refined.permute([0, 2, 1])
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class LocalBlock(torch.nn.Module):
    """Local block: convolutional local attention (CLA) followed by a gated
    convolutional feed-forward network (GCFN)."""

    def __init__(self, in_channels: int, kernel_size: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'cla': CLA(in_channels, kernel_size, dropout_rate),
            'gcfn': GCFN(in_channels, dropout_rate)
        })

    def forward(self, x: torch.Tensor):
        """Run the input through CLA then GCFN and return the result."""
        return self.block['gcfn'](self.block['cla'](x))
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class SpkAttention(torch.nn.Module):
    # Speaker-wise attention: attends across the speaker axis at every time
    # step, then refines with a GCFN feed-forward network.
    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.self_attn = MultiHeadAttention(n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate)
        self.feed_forward = GCFN(in_channels=in_channels, dropout_rate=dropout_rate)

    def forward(self, x: torch.Tensor, num_spk: int):
        """
        Attend over the speaker dimension.
        :param torch.Tensor x: features (batch*num_spk, size, time); the batch
            axis is assumed to fold num_spk speaker streams per utterance
            (B must be divisible by num_spk) -- TODO confirm against caller
        :param int num_spk: number of speaker streams folded into the batch
        :rtype: torch.Tensor with the same (batch*num_spk, size, time) shape
        """
        B, F, T = x.shape
        # Unfold speakers out of the batch, then flatten (batch, time) so the
        # attention sequence length becomes num_spk.
        x = x.view(B//num_spk, num_spk, F, T).contiguous()
        x = x.permute([0, 3, 1, 2]).contiguous()
        x = x.view(-1, num_spk, F).contiguous()
        x = x + self.self_attn(x, None, None)  # residual speaker attention
        # Restore the original (batch*num_spk, size, time) layout.
        x = x.view(B//num_spk, T, num_spk, F).contiguous()
        x = x.permute([0, 2, 3, 1]).contiguous()
        x = x.view(B, F, T).contiguous()
        # GCFN operates on (batch, time, size); permute in and back out.
        x = x.permute([0, 2, 1])
        x = self.feed_forward(x)
        x = x.permute([0, 2, 1])

        return x
|
models/SepReformer/SepReformer_Large_DM_WHAM/configs.yaml
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
config:
|
| 2 |
+
dataset:
|
| 3 |
+
max_len : 32000
|
| 4 |
+
sampling_rate: 8000
|
| 5 |
+
scp_dir: "data/scp_ss_8k_wham"
|
| 6 |
+
train:
|
| 7 |
+
mixture: "tr_mix.scp"
|
| 8 |
+
spk1: "tr_s1.scp"
|
| 9 |
+
spk2: "tr_s2.scp"
|
| 10 |
+
noise: "tr_n.scp"
|
| 11 |
+
dynamic_mixing: true
|
| 12 |
+
valid:
|
| 13 |
+
mixture: "cv_mix.scp"
|
| 14 |
+
spk1: "cv_s1.scp"
|
| 15 |
+
spk2: "cv_s2.scp"
|
| 16 |
+
test:
|
| 17 |
+
mixture: "tt_mix.scp"
|
| 18 |
+
spk1: "tt_s1.scp"
|
| 19 |
+
spk2: "tt_s2.scp"
|
| 20 |
+
dataloader:
|
| 21 |
+
batch_size: 2
|
| 22 |
+
pin_memory: false
|
| 23 |
+
num_workers: 12
|
| 24 |
+
drop_last: false
|
| 25 |
+
model:
|
| 26 |
+
num_stages: &var_model_num_stages 4 # R
|
| 27 |
+
num_spks: &var_model_num_spks 2
|
| 28 |
+
module_audio_enc:
|
| 29 |
+
in_channels: 1
|
| 30 |
+
out_channels: &var_model_audio_enc_out_channels 256
|
| 31 |
+
kernel_size: &var_model_audio_enc_kernel_size 16 # L
|
| 32 |
+
stride: &var_model_audio_enc_stride 4 # S
|
| 33 |
+
groups: 1
|
| 34 |
+
bias: false
|
| 35 |
+
module_feature_projector:
|
| 36 |
+
num_channels: *var_model_audio_enc_out_channels
|
| 37 |
+
in_channels: *var_model_audio_enc_out_channels
|
| 38 |
+
out_channels: &feature_projector_out_channels 256 # F
|
| 39 |
+
kernel_size: 1
|
| 40 |
+
bias: false
|
| 41 |
+
module_separator:
|
| 42 |
+
num_stages: *var_model_num_stages
|
| 43 |
+
relative_positional_encoding:
|
| 44 |
+
in_channels: *feature_projector_out_channels
|
| 45 |
+
num_heads: 8
|
| 46 |
+
maxlen: 2000
|
| 47 |
+
embed_v: false
|
| 48 |
+
enc_stage:
|
| 49 |
+
global_blocks:
|
| 50 |
+
in_channels: *feature_projector_out_channels
|
| 51 |
+
num_mha_heads: 8
|
| 52 |
+
dropout_rate: 0.1
|
| 53 |
+
local_blocks:
|
| 54 |
+
in_channels: *feature_projector_out_channels
|
| 55 |
+
kernel_size: 65
|
| 56 |
+
dropout_rate: 0.1
|
| 57 |
+
down_conv_layer:
|
| 58 |
+
in_channels: *feature_projector_out_channels
|
| 59 |
+
samp_kernel_size: &var_model_samp_kernel_size 5
|
| 60 |
+
spk_split_stage:
|
| 61 |
+
in_channels: *feature_projector_out_channels
|
| 62 |
+
num_spks: *var_model_num_spks
|
| 63 |
+
simple_fusion:
|
| 64 |
+
out_channels: *feature_projector_out_channels
|
| 65 |
+
dec_stage:
|
| 66 |
+
num_spks: *var_model_num_spks
|
| 67 |
+
global_blocks:
|
| 68 |
+
in_channels: *feature_projector_out_channels
|
| 69 |
+
num_mha_heads: 8
|
| 70 |
+
dropout_rate: 0.1
|
| 71 |
+
local_blocks:
|
| 72 |
+
in_channels: *feature_projector_out_channels
|
| 73 |
+
kernel_size: 65
|
| 74 |
+
dropout_rate: 0.1
|
| 75 |
+
spk_attention:
|
| 76 |
+
in_channels: *feature_projector_out_channels
|
| 77 |
+
num_mha_heads: 8
|
| 78 |
+
dropout_rate: 0.1
|
| 79 |
+
module_output_layer:
|
| 80 |
+
in_channels: *var_model_audio_enc_out_channels
|
| 81 |
+
out_channels: *feature_projector_out_channels
|
| 82 |
+
num_spks: *var_model_num_spks
|
| 83 |
+
module_audio_dec:
|
| 84 |
+
in_channels: *var_model_audio_enc_out_channels
|
| 85 |
+
out_channels: 1
|
| 86 |
+
kernel_size: *var_model_audio_enc_kernel_size
|
| 87 |
+
stride: *var_model_audio_enc_stride
|
| 88 |
+
bias: false
|
| 89 |
+
criterion: ### Ref: https://pytorch.org/docs/stable/nn.html#loss-functions
|
| 90 |
+
name: ["PIT_SISNR_mag", "PIT_SISNR_time", "PIT_SISNRi", "PIT_SDRi"] ### Choose a torch.nn's loss function class(=attribute) e.g. ["L1Loss", "MSELoss", "CrossEntropyLoss", ...] / You can also build your optimizer :)
|
| 91 |
+
PIT_SISNR_mag:
|
| 92 |
+
frame_length: 512
|
| 93 |
+
frame_shift: 128
|
| 94 |
+
window: 'hann'
|
| 95 |
+
num_stages: *var_model_num_stages
|
| 96 |
+
num_spks: *var_model_num_spks
|
| 97 |
+
scale_inv: true
|
| 98 |
+
mel_opt: false
|
| 99 |
+
PIT_SISNR_time:
|
| 100 |
+
num_spks: *var_model_num_spks
|
| 101 |
+
scale_inv: true
|
| 102 |
+
PIT_SISNRi:
|
| 103 |
+
num_spks: *var_model_num_spks
|
| 104 |
+
scale_inv: true
|
| 105 |
+
PIT_SDRi:
|
| 106 |
+
dump: 0
|
| 107 |
+
optimizer: ### Ref: https://pytorch.org/docs/stable/optim.html#algorithms
|
| 108 |
+
name: ["AdamW"] ### Choose a torch.optim's class(=attribute) e.g. ["Adam", "AdamW", "SGD", ...] / You can also build your optimizer :)
|
| 109 |
+
AdamW:
|
| 110 |
+
lr: 2.0e-4
|
| 111 |
+
weight_decay: 1.0e-2
|
| 112 |
+
scheduler: ### Ref(+ find "How to adjust learning rate"): https://pytorch.org/docs/stable/optim.html#algorithms
|
| 113 |
+
name: ["ReduceLROnPlateau", "WarmupConstantSchedule"] ### Choose a torch.optim.lr_scheduler's class(=attribute) e.g. ["StepLR", "ReduceLROnPlateau", "Custom"] / You can also build your scheduler :)
|
| 114 |
+
ReduceLROnPlateau:
|
| 115 |
+
mode: "min"
|
| 116 |
+
min_lr: 1.0e-10
|
| 117 |
+
factor: 0.8
|
| 118 |
+
patience: 3
|
| 119 |
+
WarmupConstantSchedule:
|
| 120 |
+
warmup_steps: 1000
|
| 121 |
+
check_computations:
|
| 122 |
+
dummy_len: 16000
|
| 123 |
+
engine:
|
| 124 |
+
max_epoch: 200
|
| 125 |
+
gpuid: "1" ### "0"(single-gpu) or "0, 1" (multi-gpu)
|
| 126 |
+
mvn: false
|
| 127 |
+
clip_norm: 5
|
| 128 |
+
start_scheduling: 50
|
| 129 |
+
test_epochs: [50, 80, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 199]
|
models/SepReformer/SepReformer_Large_DM_WHAM/dataset.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
import librosa as audio_lib
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from utils import util_dataset
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
from loguru import logger
|
| 10 |
+
from torch.utils.data import Dataset, DataLoader
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@logger_wraps()
def get_dataloaders(args, dataset_config, loader_config):
    """Build one DataLoader per partition.

    Only the "test" partition is built when the engine mode contains "test";
    otherwise "train", "valid" and "test" are all created. Test loaders use
    batch size 1 and no shuffling; dynamic mixing applies to training only.
    """
    partitions = ["test"] if "test" in args.engine_mode else ["train", "valid", "test"]
    dataloaders = {}
    for part in partitions:
        part_cfg = dataset_config[part]
        scp_dir = dataset_config["scp_dir"]
        mix_scp = os.path.join(scp_dir, part_cfg['mixture'])
        src_scps = [os.path.join(scp_dir, part_cfg[k]) for k in part_cfg if k.startswith('spk')]
        noise_scp = os.path.join(scp_dir, part_cfg['noise']) if 'noise' in part_cfg else None
        use_dm = part_cfg["dynamic_mixing"] if part == 'train' else False
        is_test = (part == 'test')
        dataset = MyDataset(
            max_len=dataset_config['max_len'],
            fs=dataset_config['sampling_rate'],
            partition=part,
            wave_scp_srcs=src_scps,
            wave_scp_mix=mix_scp,
            wave_scp_noise=noise_scp,
            dynamic_mixing=use_dm)
        dataloaders[part] = DataLoader(
            dataset=dataset,
            batch_size=1 if is_test else loader_config["batch_size"],
            shuffle=not is_test,  # only train: (partition == 'train') / all: True
            pin_memory=loader_config["pin_memory"],
            num_workers=loader_config["num_workers"],
            drop_last=loader_config["drop_last"],
            collate_fn=_collate)
    return dataloaders
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _collate(egs):
|
| 44 |
+
"""
|
| 45 |
+
Transform utterance index into a minbatch
|
| 46 |
+
|
| 47 |
+
Arguments:
|
| 48 |
+
index: a list type [{},{},{}]
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
input_sizes: a tensor correspond to utterance length
|
| 52 |
+
input_feats: packed sequence to feed networks
|
| 53 |
+
source_attr/target_attr: dictionary contains spectrogram/phase needed in loss computation
|
| 54 |
+
"""
|
| 55 |
+
def __prepare_target_rir(dict_lsit, index):
|
| 56 |
+
return torch.nn.utils.rnn.pad_sequence([torch.tensor(d["src"][index], dtype=torch.float32) for d in dict_lsit], batch_first=True)
|
| 57 |
+
if type(egs) is not list: raise ValueError("Unsupported index type({})".format(type(egs)))
|
| 58 |
+
num_spks = 2 # you need to set this paramater by yourself
|
| 59 |
+
dict_list = sorted([eg for eg in egs], key=lambda x: x['num_sample'], reverse=True)
|
| 60 |
+
mixture = torch.nn.utils.rnn.pad_sequence([torch.tensor(d['mix'], dtype=torch.float32) for d in dict_list], batch_first=True)
|
| 61 |
+
src = [__prepare_target_rir(dict_list, index) for index in range(num_spks)]
|
| 62 |
+
input_sizes = torch.tensor([d['num_sample'] for d in dict_list], dtype=torch.float32)
|
| 63 |
+
key = [d['key'] for d in dict_list]
|
| 64 |
+
return input_sizes, mixture, src, key
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@logger_wraps()
class MyDataset(Dataset):
    # Speech-separation dataset: loads mixture / per-speaker source (and
    # optional noise) waveforms listed in .scp files, optionally synthesizing
    # training mixtures on the fly ("dynamic mixing").
    def __init__(self, max_len, fs, partition, wave_scp_srcs, wave_scp_mix, wave_scp_noise=None, dynamic_mixing=False, speed_list=None):
        # max_len: cap (in samples) on train/valid utterance length
        # fs: target sampling rate passed to librosa.load
        # partition: "train" | "valid" | "test" (test skips random cropping)
        # speed_list: accepted but unused in this implementation
        self.partition = partition
        for wave_scp_src in wave_scp_srcs:
            if not os.path.exists(wave_scp_src): raise FileNotFoundError(f"Could not find file {wave_scp_src}")
        self.max_len = max_len
        self.fs = fs
        # key -> wav-path mappings parsed from the scp files
        self.wave_dict_srcs = [util_dataset.parse_scps(wave_scp_src) for wave_scp_src in wave_scp_srcs]
        self.wave_dict_mix = util_dataset.parse_scps(wave_scp_mix)
        self.wave_dict_noise = util_dataset.parse_scps(wave_scp_noise) if wave_scp_noise else None
        self.wave_keys = list(self.wave_dict_mix.keys())
        logger.info(f"Create MyDataset for {wave_scp_mix} with {len(self.wave_dict_mix)} utterances")
        self.dynamic_mixing = dynamic_mixing

    def __len__(self):
        return len(self.wave_dict_mix)

    def __contains__(self, key):
        return key in self.wave_dict_mix


    def _dynamic_mixing(self, key):
        # Synthesize a fresh 2-speaker mixture: speaker 1 comes from `key`,
        # speaker 2 from a randomly drawn utterance; both get random gains,
        # plus the noise track associated with `key`.
        def __match_length(wav, len_data) :
            # Random crop of `wav` to exactly `len_data` samples.
            # NOTE(review): assumes len(wav) >= len_data -- verify upstream.
            leftover = len(wav) - len_data
            idx = random.randint(0,leftover)
            wav = wav[idx:idx+len_data]
            return wav

        samps_src = []
        # Seed the length list with max_len so the final crop never exceeds it.
        src_len = [self.max_len]

        # dynamic source choice (second speaker drawn at random)
        key_random = random.choice(list(self.wave_dict_srcs[0].keys()))

        # Randomly swap which scp list provides which speaker.
        idx1, idx2 = (0, 1) if random.random() > 0.5 else (1, 0)
        files = [self.wave_dict_srcs[idx1][key], self.wave_dict_srcs[idx2][key_random]]

        # load
        for idx, file in enumerate(files):
            if not os.path.exists(file): raise FileNotFoundError("Input file {} do not exists!".format(file))
            samps_tmp, _ = audio_lib.load(file, sr=self.fs)

            # RMS-normalize every source to the level of the first one.
            if idx == 0: ref_rms = np.sqrt(np.mean(np.square(samps_tmp)))
            curr_rms = np.sqrt(np.mean(np.square(samps_tmp)))

            norm_factor = ref_rms / curr_rms
            samps_tmp *= norm_factor

            # mixing with random gains (uniform in [-5, 5] dB)
            gain = pow(10,-random.uniform(-5,5)/20)
            samps_tmp = np.array(torch.tensor(samps_tmp))
            samps_src.append(gain*samps_tmp)
            src_len.append(len(samps_tmp))

        # matching the audio length
        min_len = min(src_len)

        # add noise source (tied to the original utterance key)
        file_noise = self.wave_dict_noise[key]
        samps_noise, _ = audio_lib.load(file_noise, sr=self.fs)
        curr_rms = np.sqrt(np.mean(np.square(samps_noise)))
        norm_factor = ref_rms / curr_rms
        samps_noise *= norm_factor
        gain_noise = pow(10,-random.uniform(-5,5)/20)
        samps_noise = samps_noise*gain_noise
        src_len.append(len(samps_noise))

        # truncate all tracks to the shortest length (capped at max_len)
        min_len = min(src_len)
        samps_src = [__match_length(s, min_len) for s in samps_src]
        samps_noise = __match_length(samps_noise, min_len)
        samps_mix = sum(samps_src) + samps_noise

        # Keep length divisible by 4 -- presumably the encoder stride; confirm.
        if len(samps_mix)%4 != 0:
            remains = len(samps_mix)%4
            samps_mix = samps_mix[:-remains]
            samps_src = [s[:-remains] for s in samps_src]

        return samps_mix, samps_src

    def _direct_load(self, key):
        # Load the pre-mixed mixture and its reference sources from disk.
        samps_src = []
        files = [wave_dict_src[key] for wave_dict_src in self.wave_dict_srcs]
        for file in files:
            if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
            samps_tmp, _ = audio_lib.load(file, sr=self.fs)
            samps_src.append(samps_tmp)

        file = self.wave_dict_mix[key]
        if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
        samps_mix, _ = audio_lib.load(file, sr=self.fs)
        # Truncate samples as needed (length must be divisible by 4)
        if len(samps_mix) % 4 != 0:
            remains = len(samps_mix) % 4
            samps_mix = samps_mix[:-remains]
            samps_src = [s[:-remains] for s in samps_src]

        # Random fixed-length crop for train/valid; test keeps full utterances.
        if self.partition != "test":
            if len(samps_mix) > self.max_len:
                start = random.randint(0,len(samps_mix)-self.max_len)
                samps_mix = samps_mix[start:start+self.max_len]
                samps_src = [s[start:start+self.max_len] for s in samps_src]

        return samps_mix, samps_src

    def __getitem__(self, index):
        # Returns the dict consumed by _collate: length, mixture, sources, key.
        key = self.wave_keys[index]
        if any(key not in self.wave_dict_srcs[i] for i in range(len(self.wave_dict_srcs))) or key not in self.wave_dict_mix: raise KeyError(f"Could not find utterance {key}")
        samps_mix, samps_src = self._dynamic_mixing(key) if self.dynamic_mixing else self._direct_load(key)
        return {"num_sample": samps_mix.shape[0], "mix": samps_mix, "src": samps_src, "key": key}
|
models/SepReformer/SepReformer_Large_DM_WHAM/engine.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import csv
|
| 4 |
+
import time
|
| 5 |
+
import soundfile as sf
|
| 6 |
+
|
| 7 |
+
from loguru import logger
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from utils import util_engine, functions
|
| 10 |
+
from utils.decorators import *
|
| 11 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@logger_wraps()
|
| 15 |
+
class Engine(object):
|
| 16 |
+
def __init__(self, args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device):
|
| 17 |
+
|
| 18 |
+
''' Default setting '''
|
| 19 |
+
self.engine_mode = args.engine_mode
|
| 20 |
+
self.out_wav_dir = args.out_wav_dir
|
| 21 |
+
self.config = config
|
| 22 |
+
self.gpuid = gpuid
|
| 23 |
+
self.device = device
|
| 24 |
+
self.model = model.to(self.device)
|
| 25 |
+
self.dataloaders = dataloaders # self.dataloaders['train'] or ['valid'] or ['test']
|
| 26 |
+
self.PIT_SISNR_mag_loss, self.PIT_SISNR_time_loss, self.PIT_SISNRi_loss, self.PIT_SDRi_loss = criterions
|
| 27 |
+
self.main_optimizer = optimizers[0]
|
| 28 |
+
self.main_scheduler, self.warmup_scheduler = schedulers
|
| 29 |
+
|
| 30 |
+
self.pretrain_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "pretrain_weights")
|
| 31 |
+
os.makedirs(self.pretrain_weights_path, exist_ok=True)
|
| 32 |
+
self.scratch_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "scratch_weights")
|
| 33 |
+
os.makedirs(self.scratch_weights_path, exist_ok=True)
|
| 34 |
+
|
| 35 |
+
self.checkpoint_path = self.pretrain_weights_path if any(file.endswith(('.pt', '.pt', '.pkl')) for file in os.listdir(self.pretrain_weights_path)) else self.scratch_weights_path
|
| 36 |
+
self.start_epoch = util_engine.load_last_checkpoint_n_get_epoch(self.checkpoint_path, self.model, self.main_optimizer, location=self.device)
|
| 37 |
+
|
| 38 |
+
# Logging
|
| 39 |
+
util_engine.model_params_mac_summary(
|
| 40 |
+
model=self.model,
|
| 41 |
+
input=torch.randn(1, self.config['check_computations']['dummy_len']).to(self.device),
|
| 42 |
+
dummy_input=torch.rand(1, self.config['check_computations']['dummy_len']).to(self.device),
|
| 43 |
+
metrics=['ptflops', 'thop', 'torchinfo']
|
| 44 |
+
# metrics=['ptflops']
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
logger.info(f"Clip gradient by 2-norm {self.config['engine']['clip_norm']}")
|
| 48 |
+
|
| 49 |
+
@logger_wraps()
|
| 50 |
+
def _train(self, dataloader, epoch):
|
| 51 |
+
self.model.train()
|
| 52 |
+
tot_loss_freq = [0 for _ in range(self.model.num_stages)]
|
| 53 |
+
tot_loss_time, num_batch = 0, 0
|
| 54 |
+
pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:25}{r_bar}{bar:-10b}', colour="YELLOW", dynamic_ncols=True)
|
| 55 |
+
for input_sizes, mixture, src, _ in dataloader:
|
| 56 |
+
nnet_input = mixture
|
| 57 |
+
nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
|
| 58 |
+
num_batch += 1
|
| 59 |
+
pbar.update(1)
|
| 60 |
+
# Scheduler learning rate for warm-up (Iteration-based update for transformers)
|
| 61 |
+
if epoch == 1: self.warmup_scheduler.step()
|
| 62 |
+
nnet_input = nnet_input.to(self.device)
|
| 63 |
+
self.main_optimizer.zero_grad()
|
| 64 |
+
estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
|
| 65 |
+
cur_loss_s_bn = 0
|
| 66 |
+
cur_loss_s_bn = []
|
| 67 |
+
for idx, estim_src_value in enumerate(estim_src_bn):
|
| 68 |
+
cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
|
| 69 |
+
tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
|
| 70 |
+
cur_loss_s = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
|
| 71 |
+
tot_loss_time += cur_loss_s.item() / self.config['model']['num_spks']
|
| 72 |
+
alpha = 0.4 * 0.8**(1+(epoch-101)//5) if epoch > 100 else 0.4
|
| 73 |
+
cur_loss = (1-alpha) * cur_loss_s + alpha * sum(cur_loss_s_bn) / len(cur_loss_s_bn)
|
| 74 |
+
cur_loss = cur_loss / self.config['model']['num_spks']
|
| 75 |
+
cur_loss.backward()
|
| 76 |
+
if self.config['engine']['clip_norm']: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['engine']['clip_norm'])
|
| 77 |
+
self.main_optimizer.step()
|
| 78 |
+
dict_loss = {"T_Loss": tot_loss_time / num_batch}
|
| 79 |
+
dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
|
| 80 |
+
pbar.set_postfix(dict_loss)
|
| 81 |
+
pbar.close()
|
| 82 |
+
tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
|
| 83 |
+
return tot_loss_time / num_batch, tot_loss_freq / num_batch, num_batch
|
| 84 |
+
|
| 85 |
+
    @logger_wraps()
    def _validate(self, dataloader):
        # Run one validation pass under inference mode (no gradients).
        # Returns (mean time-domain loss, mean frequency-domain loss averaged
        # over stages, number of batches) -- mirrors _train without the
        # backward/optimizer steps.
        self.model.eval()
        tot_loss_freq = [0 for _ in range(self.model.num_stages)]
        tot_loss_time, num_batch = 0, 0
        pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="RED", dynamic_ncols=True)
        with torch.inference_mode():
            for input_sizes, mixture, src, _ in dataloader:
                nnet_input = mixture
                # Optional mean-variance normalization of the input waveform.
                nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
                nnet_input = nnet_input.to(self.device)
                num_batch += 1
                pbar.update(1)
                estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
                # Per-stage frequency-domain PIT losses.
                cur_loss_s_bn = []
                for idx, estim_src_value in enumerate(estim_src_bn):
                    cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
                    tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
                cur_loss_s_SDR = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
                tot_loss_time += cur_loss_s_SDR.item() / self.config['model']['num_spks']
                dict_loss = {"T_Loss":tot_loss_time / num_batch}
                dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
                pbar.set_postfix(dict_loss)
        pbar.close()
        tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
        return tot_loss_time / num_batch, tot_loss_freq / num_batch, num_batch
|
| 111 |
+
|
| 112 |
+
@logger_wraps()
|
| 113 |
+
def _test(self, dataloader, wav_dir=None):
|
| 114 |
+
self.model.eval()
|
| 115 |
+
total_loss_SISNRi, total_loss_SDRi, num_batch = 0, 0, 0
|
| 116 |
+
pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="grey", dynamic_ncols=True)
|
| 117 |
+
with torch.inference_mode():
|
| 118 |
+
csv_file_name_sisnr = os.path.join(os.path.dirname(__file__),'test_SISNRi_value.csv')
|
| 119 |
+
csv_file_name_sdr = os.path.join(os.path.dirname(__file__),'test_SDRi_value.csv')
|
| 120 |
+
with open(csv_file_name_sisnr, 'w', newline='') as csvfile_sisnr, open(csv_file_name_sdr, 'w', newline='') as csvfile_sdr:
|
| 121 |
+
idx = 0
|
| 122 |
+
writer_sisnr = csv.writer(csvfile_sisnr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
|
| 123 |
+
writer_sdr = csv.writer(csvfile_sdr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
|
| 124 |
+
for input_sizes, mixture, src, key in dataloader:
|
| 125 |
+
if len(key) > 1:
|
| 126 |
+
raise("batch size is not one!!")
|
| 127 |
+
nnet_input = mixture.to(self.device)
|
| 128 |
+
num_batch += 1
|
| 129 |
+
pbar.update(1)
|
| 130 |
+
estim_src, _ = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
|
| 131 |
+
cur_loss_SISNRi, cur_loss_SISNRi_src = self.PIT_SISNRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src, eps=1.0e-15)
|
| 132 |
+
total_loss_SISNRi += cur_loss_SISNRi.item() / self.config['model']['num_spks']
|
| 133 |
+
cur_loss_SDRi, cur_loss_SDRi_src = self.PIT_SDRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src)
|
| 134 |
+
total_loss_SDRi += cur_loss_SDRi.item() / self.config['model']['num_spks']
|
| 135 |
+
writer_sisnr.writerow([key[0][:-4]] + [cur_loss_SISNRi_src[i].item() for i in range(self.config['model']['num_spks'])])
|
| 136 |
+
writer_sdr.writerow([key[0][:-4]] + [cur_loss_SDRi_src[i].item() for i in range(self.config['model']['num_spks'])])
|
| 137 |
+
if self.engine_mode == "test_save":
|
| 138 |
+
if wav_dir == None: wav_dir = os.path.join(os.path.dirname(__file__),"wav_out")
|
| 139 |
+
if wav_dir and not os.path.exists(wav_dir): os.makedirs(wav_dir)
|
| 140 |
+
mixture = torch.squeeze(mixture).cpu().data.numpy()
|
| 141 |
+
sf.write(os.path.join(wav_dir,key[0][:-4]+str(idx)+'_mixture.wav'), 0.5*mixture/max(abs(mixture)), 8000)
|
| 142 |
+
for i in range(self.config['model']['num_spks']):
|
| 143 |
+
src = torch.squeeze(estim_src[i]).cpu().data.numpy()
|
| 144 |
+
sf.write(os.path.join(wav_dir,key[0][:-4]+str(idx)+'_out_'+str(i)+'.wav'), 0.5*src/max(abs(src)), 8000)
|
| 145 |
+
idx += 1
|
| 146 |
+
dict_loss = {"SiSNRi": total_loss_SISNRi/num_batch, "SDRi": total_loss_SDRi/num_batch}
|
| 147 |
+
pbar.set_postfix(dict_loss)
|
| 148 |
+
pbar.close()
|
| 149 |
+
return total_loss_SISNRi/num_batch, total_loss_SDRi/num_batch, num_batch
|
| 150 |
+
|
| 151 |
+
@logger_wraps()
|
| 152 |
+
def run(self):
|
| 153 |
+
with torch.cuda.device(self.device):
|
| 154 |
+
writer_src = SummaryWriter(os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/tensorboard"))
|
| 155 |
+
if "test" in self.engine_mode:
|
| 156 |
+
on_test_start = time.time()
|
| 157 |
+
test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'], self.out_wav_dir)
|
| 158 |
+
on_test_end = time.time()
|
| 159 |
+
logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
|
| 160 |
+
logger.info(f"Testing done!")
|
| 161 |
+
else:
|
| 162 |
+
start_time = time.time()
|
| 163 |
+
if self.start_epoch > 1:
|
| 164 |
+
init_loss_time, init_loss_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
|
| 165 |
+
else:
|
| 166 |
+
init_loss_time, init_loss_freq = 0, 0
|
| 167 |
+
end_time = time.time()
|
| 168 |
+
logger.info(f"[INIT] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: Loss_t = {init_loss_time:.4f} dB | Loss_f = {init_loss_freq:.4f} dB | Speed = ({end_time-start_time:.2f}s)")
|
| 169 |
+
for epoch in range(self.start_epoch, self.config['engine']['max_epoch']):
|
| 170 |
+
valid_loss_best = init_loss_time
|
| 171 |
+
train_start_time = time.time()
|
| 172 |
+
train_loss_src_time, train_loss_src_freq, train_num_batch = self._train(self.dataloaders['train'], epoch)
|
| 173 |
+
train_end_time = time.time()
|
| 174 |
+
valid_start_time = time.time()
|
| 175 |
+
valid_loss_src_time, valid_loss_src_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
|
| 176 |
+
valid_end_time = time.time()
|
| 177 |
+
if epoch > self.config['engine']['start_scheduling']: self.main_scheduler.step(valid_loss_src_time)
|
| 178 |
+
logger.info(f"[TRAIN] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {train_loss_src_time:.4f} dB | Loss_f = {train_loss_src_freq:.4f} dB | Speed = ({train_end_time - train_start_time:.2f}s/{train_num_batch:d})")
|
| 179 |
+
logger.info(f"[VALID] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {valid_loss_src_time:.4f} dB | Loss_f = {valid_loss_src_freq:.4f} dB | Speed = ({valid_end_time - valid_start_time:.2f}s/{valid_num_batch:d})")
|
| 180 |
+
if epoch in self.config['engine']['test_epochs']:
|
| 181 |
+
on_test_start = time.time()
|
| 182 |
+
test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'])
|
| 183 |
+
on_test_end = time.time()
|
| 184 |
+
logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
|
| 185 |
+
valid_loss_best = util_engine.save_checkpoint_per_best(valid_loss_best, valid_loss_src_time, train_loss_src_time, epoch, self.model, self.main_optimizer, self.checkpoint_path)
|
| 186 |
+
# Logging to monitoring tools (Tensorboard && Wandb)
|
| 187 |
+
writer_src.add_scalars("Metrics", {
|
| 188 |
+
'Loss_train_time': train_loss_src_time,
|
| 189 |
+
'Loss_valid_time': valid_loss_src_time}, epoch)
|
| 190 |
+
writer_src.add_scalar("Learning Rate", self.main_optimizer.param_groups[0]['lr'], epoch)
|
| 191 |
+
writer_src.flush()
|
| 192 |
+
logger.info(f"Training for {self.config['engine']['max_epoch']} epoches done!")
|
models/SepReformer/SepReformer_Large_DM_WHAM/main.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from loguru import logger
|
| 4 |
+
from .dataset import get_dataloaders
|
| 5 |
+
from .model import Model
|
| 6 |
+
from .engine import Engine
|
| 7 |
+
from utils import util_system, util_implement
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
|
| 10 |
+
# Setup logger
|
| 11 |
+
log_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/system_log.log")
|
| 12 |
+
logger.add(log_file_path, level="DEBUG", mode="w")
|
| 13 |
+
|
| 14 |
+
@logger_wraps()
def main(args):
    """Assemble dataloaders, model, and training components from configs.yaml,
    then hand everything to the Engine and run it."""

    # --- Configuration --------------------------------------------------
    # Parse the YAML configuration that sits next to this file.
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs.yaml")
    parsed_yaml = util_system.parse_yaml(config_path)
    config = parsed_yaml["config"]  # wandb login success or fail

    # --- Data -----------------------------------------------------------
    # DataLoaders keyed by split [train / valid / test / etc...]
    dataloaders = get_dataloaders(args, config["dataset"], config["dataloader"])

    # --- Model ----------------------------------------------------------
    model = Model(**config["model"])

    # --- Devices --------------------------------------------------------
    # GPU ids come as a comma-separated string; the first id hosts the model.
    gpuid = tuple(map(int, config["engine"]["gpuid"].split(',')))
    device = torch.device(f'cuda:{gpuid[0]}')

    # --- Training components [criterion / optimizer / scheduler] ---------
    criterions = util_implement.CriterionFactory(config["criterion"], device).get_criterions()
    optimizers = util_implement.OptimizerFactory(config["optimizer"], model.parameters()).get_optimizers()
    schedulers = util_implement.SchedulerFactory(config["scheduler"], optimizers).get_schedulers()

    # --- Run ------------------------------------------------------------
    engine = Engine(args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device)
    engine.run()
|
models/SepReformer/SepReformer_Large_DM_WHAM/model.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('../')
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import warnings
|
| 6 |
+
warnings.filterwarnings('ignore')
|
| 7 |
+
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
from .modules.module import *
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@logger_wraps()
class Model(torch.nn.Module):
    """SepReformer separation model.

    Pipeline: audio encoder -> feature projector -> separator ->
    output layer -> audio decoder, plus one auxiliary (masking) output
    layer + decoder per separator stage for the multi-stage aux loss.
    """
    def __init__(self,
                 num_stages: int,
                 num_spks: int,
                 module_audio_enc: dict,
                 module_feature_projector: dict,
                 module_separator: dict,
                 module_output_layer: dict,
                 module_audio_dec: dict):
        super().__init__()
        self.num_stages = num_stages
        self.num_spks = num_spks
        self.audio_encoder = AudioEncoder(**module_audio_enc)
        self.feature_projector = FeatureProjector(**module_feature_projector)
        self.separator = Separator(**module_separator)
        self.out_layer = OutputLayer(**module_output_layer)
        self.audio_decoder = AudioDecoder(**module_audio_dec)

        # Aux_loss: one masking output layer + decoder per separator stage.
        self.out_layer_bn = torch.nn.ModuleList([])
        self.decoder_bn = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            self.out_layer_bn.append(OutputLayer(**module_output_layer, masking=True))
            self.decoder_bn.append(AudioDecoder(**module_audio_dec))

    def forward(self, x):
        """x: mixture waveform. Returns (audio, audio_aux): the final
        per-speaker waveforms and, per stage, the auxiliary waveforms
        trimmed to the input length."""
        encoder_output = self.audio_encoder(x)
        projected_feature = self.feature_projector(encoder_output)
        last_stage_output, each_stage_outputs = self.separator(projected_feature)
        out_layer_output = self.out_layer(last_stage_output, encoder_output)
        each_spk_output = [out_layer_output[idx] for idx in range(self.num_spks)]
        audio = [self.audio_decoder(each_spk_output[idx]) for idx in range(self.num_spks)]

        # Aux_loss: decode each intermediate stage, upsampled to the encoder frame rate.
        audio_aux = []
        for idx, each_stage_output in enumerate(each_stage_outputs):
            # FIX: torch.nn.functional.upsample is deprecated; interpolate is the
            # drop-in replacement (same default 'nearest' mode).
            each_stage_output = self.out_layer_bn[idx](torch.nn.functional.interpolate(each_stage_output, encoder_output.shape[-1]), encoder_output)
            out_aux = [each_stage_output[jdx] for jdx in range(self.num_spks)]
            audio_aux.append([self.decoder_bn[idx](out_aux[jdx])[...,:x.shape[-1]] for jdx in range(self.num_spks)])

        return audio, audio_aux
|
models/SepReformer/SepReformer_Large_DM_WHAM/modules/module.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('../')
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import warnings
|
| 6 |
+
warnings.filterwarnings('ignore')
|
| 7 |
+
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
from .network import *
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class AudioEncoder(torch.nn.Module):
    """Convolutional front-end: raw waveform -> latent feature map with GELU."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, groups: int, bias: bool):
        super().__init__()
        # Attribute names (conv1d / gelu) are preserved so checkpoint keys match.
        self.conv1d = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, stride=stride, groups=groups, bias=bias)
        self.gelu = torch.nn.GELU()

    def forward(self, x: torch.Tensor):
        # Insert the channel axis: [T] -> [1, T] (unbatched) or [B, T] -> [B, 1, T].
        if x.dim() == 1:
            x = x.unsqueeze(0)
        else:
            x = x.unsqueeze(1)
        return self.gelu(self.conv1d(x))
|
| 24 |
+
|
| 25 |
+
class FeatureProjector(torch.nn.Module):
    """Normalize encoder features (global LayerNorm via GroupNorm(1)) and
    project them to the separator's channel dimension with a 1-D conv."""

    def __init__(self, num_channels: int, in_channels: int, out_channels: int, kernel_size: int, bias: bool):
        super().__init__()
        # Attribute names (norm / conv1d) are preserved so checkpoint keys match.
        self.norm = torch.nn.GroupNorm(num_groups=1, num_channels=num_channels, eps=1e-8)
        self.conv1d = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, bias=bias)

    def forward(self, x: torch.Tensor):
        # [B, C_in, T] -> normalize -> [B, C_out, T]
        return self.conv1d(self.norm(x))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Separator(torch.nn.Module):
    """U-Net-style separator: a Temporal Contracting Part (enc_stages plus a
    bottleneck), per-stage speaker splitting, and a Temporal Expanding Part
    (dec_stages) that fuses each skip connection before decoding.
    """
    def __init__(self, num_stages: int, relative_positional_encoding: dict, enc_stage: dict, spk_split_stage: dict, simple_fusion:dict, dec_stage: dict):
        super().__init__()

        # NOTE: these helper classes are deliberately defined inside __init__
        # (original project style); they are used only to build this module's
        # children below.
        class RelativePositionalEncoding(torch.nn.Module):
            # Learned relative-position embeddings shared by the attention blocks.
            def __init__(self, in_channels: int, num_heads: int, maxlen: int, embed_v=False):
                super().__init__()
                self.in_channels = in_channels
                self.num_heads = num_heads
                self.embedding_dim = self.in_channels // self.num_heads  # per-head dim
                self.maxlen = maxlen
                self.pe_k = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim)
                self.pe_v = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim) if embed_v else None

            def forward(self, pos_seq: torch.Tensor):
                # Clamp offsets to the supported window, then shift into
                # [0, 2*maxlen) for the embedding lookup.
                # NOTE(review): clamp_ and += mutate pos_seq in place — callers
                # must not reuse the original tensor afterwards.
                pos_seq.clamp_(-self.maxlen, self.maxlen - 1)
                pos_seq += self.maxlen
                pe_k_output = self.pe_k(pos_seq)
                pe_v_output = self.pe_v(pos_seq) if self.pe_v is not None else None
                return pe_k_output, pe_v_output

        class SepEncStage(torch.nn.Module):
            # One contracting stage: two (global, local) block pairs followed by
            # an optional stride-2 depthwise down-convolution.
            def __init__(self, global_blocks: dict, local_blocks: dict, down_conv_layer: dict, down_conv=True):
                super().__init__()

                class DownConvLayer(torch.nn.Module):
                    def __init__(self, in_channels: int, samp_kernel_size: int):
                        """Depthwise stride-2 conv + BatchNorm + GELU; halves the time axis."""
                        super().__init__()
                        self.down_conv = torch.nn.Conv1d(
                            in_channels=in_channels, out_channels=in_channels, kernel_size=samp_kernel_size, stride=2, padding=(samp_kernel_size-1)//2, groups=in_channels)
                        self.BN = torch.nn.BatchNorm1d(num_features=in_channels)
                        self.gelu = torch.nn.GELU()

                    def forward(self, x: torch.Tensor):
                        # Input arrives channel-last; Conv1d/BatchNorm1d need [B, C, T].
                        x = x.permute([0, 2, 1])
                        x = self.down_conv(x)
                        x = self.BN(x)
                        x = self.gelu(x)
                        x = x.permute([0, 2, 1])
                        return x

                self.g_block_1 = GlobalBlock(**global_blocks)
                self.l_block_1 = LocalBlock(**local_blocks)

                self.g_block_2 = GlobalBlock(**global_blocks)
                self.l_block_2 = LocalBlock(**local_blocks)

                self.downconv = DownConvLayer(**down_conv_layer) if down_conv == True else None

            def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
                '''
                x: [B, N, T]
                '''
                # Global blocks consume [B, N, T]; local blocks consume [B, T, N].
                x = self.g_block_1(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_1(x)
                x = x.permute(0, 2, 1).contiguous()

                x = self.g_block_2(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_2(x)
                x = x.permute(0, 2, 1).contiguous()

                # Full-resolution features become the skip connection; the
                # downsampled features feed the next stage.
                skip = x
                if self.downconv:
                    x = x.permute(0, 2, 1).contiguous()
                    x = self.downconv(x)
                    x = x.permute(0, 2, 1).contiguous()
                # [BK, S, N]
                return x, skip

        class SpkSplitStage(torch.nn.Module):
            # Expands channels num_spks-fold and folds the speaker axis into the
            # batch dimension: [B, N, T] -> [B*num_spks, N, T].
            def __init__(self, in_channels: int, num_spks: int):
                super().__init__()
                self.linear = torch.nn.Sequential(
                    torch.nn.Conv1d(in_channels, 4*in_channels*num_spks, kernel_size=1),
                    torch.nn.GLU(dim=-2),
                    torch.nn.Conv1d(2*in_channels*num_spks, in_channels*num_spks, kernel_size=1))
                self.norm = torch.nn.GroupNorm(1, in_channels, eps=1e-8)
                self.num_spks = num_spks

            def forward(self, x: torch.Tensor):
                x = self.linear(x)
                B, _, T = x.shape
                # Fold the speaker copies into the batch dimension.
                x = x.view(B*self.num_spks,-1, T).contiguous()
                x = self.norm(x)
                return x

        class SepDecStage(torch.nn.Module):
            # One expanding stage: three (global, local, speaker-attention) triplets.
            def __init__(self, num_spks: int, global_blocks: dict, local_blocks: dict, spk_attention: dict):
                super().__init__()

                self.g_block_1 = GlobalBlock(**global_blocks)
                self.l_block_1 = LocalBlock(**local_blocks)
                self.spk_attn_1 = SpkAttention(**spk_attention)

                self.g_block_2 = GlobalBlock(**global_blocks)
                self.l_block_2 = LocalBlock(**local_blocks)
                self.spk_attn_2 = SpkAttention(**spk_attention)

                self.g_block_3 = GlobalBlock(**global_blocks)
                self.l_block_3 = LocalBlock(**local_blocks)
                self.spk_attn_3 = SpkAttention(**spk_attention)

                self.num_spk = num_spks

            def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
                '''
                x: [B, N, T]
                '''
                # [BS, K, H]
                x = self.g_block_1(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_1(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_1(x, self.num_spk)

                x = self.g_block_2(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_2(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_2(x, self.num_spk)

                x = self.g_block_3(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_3(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_3(x, self.num_spk)

                skip = x

                return x, skip

        self.num_stages = num_stages
        self.pos_emb = RelativePositionalEncoding(**relative_positional_encoding)

        # Temporal Contracting Part
        self.enc_stages = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            self.enc_stages.append(SepEncStage(**enc_stage, down_conv=True))

        self.bottleneck_G = SepEncStage(**enc_stage, down_conv=False)

        # One speaker-split block per contracting stage plus one for the bottleneck.
        self.spk_split_blocks = torch.nn.ModuleList([])
        for _ in range(self.num_stages+1):
            self.spk_split_blocks.append(SpkSplitStage(**spk_split_stage))

        # Temporal Expanding Part
        self.simple_fusion = torch.nn.ModuleList([])
        self.dec_stages = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            # 1x1 conv merges the upsampled features with the matching skip.
            self.simple_fusion.append(torch.nn.Conv1d(in_channels=simple_fusion['out_channels']*2,out_channels=simple_fusion['out_channels'], kernel_size=1))
            self.dec_stages.append(SepDecStage(**dec_stage))

    def forward(self, input: torch.Tensor):
        '''input: [B, N, L]'''
        # feature projection
        x, _ = self.pad_signal(input)
        len_x = x.shape[-1]
        # Temporal Contracting Part
        # Relative offsets computed at the bottleneck resolution
        # (len_x / 2**num_stages after the stride-2 down-convs).
        pos_seq = torch.arange(0, len_x//2**self.num_stages).long().to(x.device)
        pos_seq = pos_seq[:, None] - pos_seq[None, :]
        pos_k, _ = self.pos_emb(pos_seq)
        skip = []
        for idx in range(self.num_stages):
            x, skip_ = self.enc_stages[idx](x, pos_k)
            skip_ = self.spk_split_blocks[idx](skip_)  # fold speakers into batch
            skip.append(skip_)
        x, _ = self.bottleneck_G(x, pos_k)
        x = self.spk_split_blocks[-1](x) # B, 2F, T

        each_stage_outputs = []
        # Temporal Expanding Part
        for idx in range(self.num_stages):
            each_stage_outputs.append(x)  # pre-stage features feed the aux losses
            idx_en = self.num_stages - (idx + 1)  # index of the matching contracting stage
            x = torch.nn.functional.upsample(x, skip[idx_en].shape[-1])
            x = torch.cat([x,skip[idx_en]],dim=1)
            x = self.simple_fusion[idx](x)
            x, _ = self.dec_stages[idx](x, pos_k)

        last_stage_output = x
        return last_stage_output, each_stage_outputs

    def pad_signal(self, input: torch.Tensor):
        # Right-pad the time axis with zeros so its length is a multiple of
        # 2**num_stages (required by the stride-2 down-convolutions).
        # Returns (padded_input, pad_amount).
        # (B, T) or (B, 1, T)
        # NOTE(review): a 1-D input is unsqueezed only once (stays 2-D) because
        # the elif chain stops after the first branch — confirm callers never
        # pass 1-D tensors.
        if input.dim() == 1: input = input.unsqueeze(0)
        elif input.dim() not in [2, 3]: raise RuntimeError("Input can only be 2 or 3 dimensional.")
        elif input.dim() == 2: input = input.unsqueeze(1)
        L = 2**self.num_stages
        batch_size = input.size(0)
        ndim = input.size(1)
        nframe = input.size(2)
        padded_len = (nframe//L + 1)*L
        rest = 0 if nframe%L == 0 else padded_len - nframe
        if rest > 0:
            pad = torch.autograd.Variable(torch.zeros(batch_size, ndim, rest)).type(input.type()).to(input.device)
            input = torch.cat([input, pad], dim=-1)
        return input, rest
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class OutputLayer(torch.nn.Module):
    # Projects separator features back to the encoder dimension, optionally
    # gates the shared encoder output with a ReLU mask, and reshapes the
    # flattened (B*num_spks) batch into per-speaker outputs.
    def __init__(self, in_channels: int, out_channels: int, num_spks: int, masking: bool = False):
        super().__init__()
        # feature expansion back
        self.masking = masking
        # ReLU-gated masking block (no concat path: concat_opt=None).
        self.spe_block = Masking(in_channels, Activation_mask="ReLU", concat_opt=None)
        self.num_spks = num_spks
        # Linear -> GLU -> Linear expands from the separator dim back to the
        # encoder dim (GLU halves the 4x expansion to 2x).
        self.end_conv1x1 = torch.nn.Sequential(
            torch.nn.Linear(out_channels, 4*out_channels),
            torch.nn.GLU(),
            torch.nn.Linear(2*out_channels, in_channels))

    def forward(self, x: torch.Tensor, input: torch.Tensor):
        # x: separator output with speakers folded into the batch dim;
        # input: encoder output the mask is applied to.
        x = x[...,:input.shape[-1]]   # trim padding added inside the separator
        x = x.permute([0, 2, 1])      # channel-last for the Linear layers
        x = self.end_conv1x1(x)
        x = x.permute([0, 2, 1])      # back to [B*num_spks, N, L]
        B, N, L = x.shape
        B = B // self.num_spks        # recover the true batch size

        if self.masking:
            # Replicate the encoder output once per speaker and gate it
            # elementwise with ReLU(x) via spe_block.
            input = input.expand(self.num_spks, B, N, L).transpose(0,1).contiguous()
            input = input.view(B*self.num_spks, N, L)
            x = self.spe_block(x, input)

        # Unfold the speaker axis out of the batch dimension.
        x = x.view(B, self.num_spks, N, L)
        # [spks, B, N, L]
        x = x.transpose(0, 1)
        return x
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class AudioDecoder(torch.nn.ConvTranspose1d):
    '''
    Decoder of the TasNet
    This module can be seen as the gradient of Conv1d with respect to its input.
    It is also known as a fractionally-strided convolution
    or a deconvolution (although it is not an actual deconvolution operation).
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        # x: [B, N, L] (3-D) or [N, L] (2-D; a batch axis is inserted before
        # the transposed convolution).
        if x.dim() not in [2, 3]:
            # FIX: was "self.__name__", which raised AttributeError instead of
            # the intended RuntimeError (nn.Module instances have no __name__).
            raise RuntimeError("{} accept 3/4D tensor as input".format(type(self).__name__))
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
        # Drop singleton dims; keep a 1-D waveform as [T] via the dim=1 squeeze.
        x = torch.squeeze(x, dim=1) if torch.squeeze(x).dim() == 1 else torch.squeeze(x)
        return x
|
models/SepReformer/SepReformer_Large_DM_WHAM/modules/network.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
import numpy
|
| 4 |
+
from utils.decorators import *
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LayerScale(torch.nn.Module):
    """Learnable per-channel scaling, initialized to a small constant
    (LayerScale). `dims` selects how many axes the parameter broadcasts over."""

    def __init__(self, dims, input_size, Layer_scale_init=1.0e-5):
        super().__init__()
        # Parameter shape depends on which axis carries the channels.
        shape_by_dims = {1: (input_size,), 2: (1, input_size), 3: (1, 1, input_size)}
        if dims in shape_by_dims:
            self.layer_scale = torch.nn.Parameter(
                torch.ones(*shape_by_dims[dims]) * Layer_scale_init, requires_grad=True)

    def forward(self, x):
        return x * self.layer_scale
|
| 19 |
+
|
| 20 |
+
class Masking(torch.nn.Module):
    """Gated masking: activation(x) * skip.

    options['concat_opt'] (required key): when truthy, x and skip are first
    concatenated along the channel axis and reduced by a pointwise conv.
    Supported gates: 'Sigmoid' (default) and 'ReLU'.
    """

    def __init__(self, input_dim, Activation_mask='Sigmoid', **options):
        super(Masking, self).__init__()
        self.options = options
        # Optional fusion path: concat(x, skip) -> 1x1 conv back to input_dim.
        if self.options['concat_opt']:
            self.pw_conv = torch.nn.Conv1d(input_dim*2, input_dim, 1, stride=1, padding=0)
        # Select the gate nonlinearity.
        if Activation_mask == 'Sigmoid':
            self.gate_act = torch.nn.Sigmoid()
        elif Activation_mask == 'ReLU':
            self.gate_act = torch.nn.ReLU()

    def forward(self, x, skip):
        if self.options['concat_opt']:
            gate_in = torch.cat([x, skip], dim=-2)
            gate_in = self.pw_conv(gate_in)
        else:
            gate_in = x
        return self.gate_act(gate_in) * skip
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class GCFN(torch.nn.Module):
    """Gated Convolutional Feed-forward Network.

    LayerNorm -> 6x linear expansion -> depthwise conv over time -> GLU ->
    projection back to the input width, added residually through a
    LayerScale gate. Input/output shape: [B, T, C].
    """

    def __init__(self, in_channels, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        # Attribute names are preserved so checkpoint keys match.
        self.net1 = torch.nn.Sequential(
            torch.nn.LayerNorm(in_channels),
            torch.nn.Linear(in_channels, in_channels*6))
        self.depthwise = torch.nn.Conv1d(in_channels*6, in_channels*6, 3, padding=1, groups=in_channels*6)
        self.net2 = torch.nn.Sequential(
            torch.nn.GLU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(in_channels*3, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        h = self.net1(x)                        # [B, T, 6C]
        h = h.transpose(1, 2).contiguous()      # [B, 6C, T] for the depthwise conv
        h = self.depthwise(h)
        h = h.transpose(1, 2).contiguous()      # back to [B, T, 6C]
        h = self.net2(h)                        # [B, T, C]
        return x + self.Layer_scale(h)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class MultiHeadAttention(torch.nn.Module):
    """
    Multi-Head Attention layer with optional relative positional bias.
    :param int n_head: the number of heads
    :param int in_channels: the number of features
    :param float dropout_rate: dropout rate
    """
    def __init__(self, n_head: int, in_channels: int, dropout_rate: float, Layer_scale_init=1.0e-5):
        super().__init__()
        assert in_channels % n_head == 0
        self.d_k = in_channels // n_head  # We assume d_v always equals d_k
        self.h = n_head
        self.layer_norm = torch.nn.LayerNorm(in_channels)  # pre-norm
        self.linear_q = torch.nn.Linear(in_channels, in_channels)
        self.linear_k = torch.nn.Linear(in_channels, in_channels)
        self.linear_v = torch.nn.Linear(in_channels, in_channels)
        self.linear_out = torch.nn.Linear(in_channels, in_channels)
        self.attn = None  # last attention map, kept for inspection after forward
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x, pos_k, mask):
        """
        Compute 'Scaled Dot Product Attention' (pre-norm; LayerScaled output).
        :param torch.Tensor x: input features (batch, time1, d_model)
        :param torch.Tensor pos_k: relative-position key embeddings, or None
        :param torch.Tensor mask: (batch, time1, time2), or None
        :return torch.Tensor: attentioned and transformed `value` (batch, time1, d_model)
            weighted by the query dot key attention (batch, head, time1, time2)
        """
        n_batch = x.size(0)
        x = self.layer_norm(x)
        q = self.linear_q(x).view(n_batch, -1, self.h, self.d_k)  # (b, t, h, d_k)
        k = self.linear_k(x).view(n_batch, -1, self.h, self.d_k)  # (b, t, h, d_k)
        v = self.linear_v(x).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2) # (batch, head, time2, d_k)
        v = v.transpose(1, 2) # (batch, head, time2, d_k)
        # Content-content attention term.
        A = torch.matmul(q, k.transpose(-2, -1))
        # Content-position term: fold batch and head together so every query row
        # is matmul'ed against the shared relative-position keys.
        reshape_q = q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0,1)
        if pos_k is not None:
            B = torch.matmul(reshape_q, pos_k.transpose(-2, -1))
            B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1))
            scores = (A + B) / math.sqrt(self.d_k)
        else:
            scores = A / math.sqrt(self.d_k)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
            # Most negative representable value of the score dtype blanks out
            # masked positions before the softmax.
            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
        return self.Layer_scale(self.dropout(self.linear_out(x))) # (batch, time1, d_model)
|
| 125 |
+
|
| 126 |
+
class EGA(torch.nn.Module):
    """Efficient Global Attention: attend at a downsampled time resolution,
    then gate the upsampled result back into the full-rate features."""
    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'self_attn': MultiHeadAttention(
                n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            'linear': torch.nn.Sequential(
                torch.nn.LayerNorm(normalized_shape=in_channels),
                torch.nn.Linear(in_features=in_channels, out_features=in_channels),
                torch.nn.Sigmoid())
        })

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """
        Compute encoded features.
        :param torch.Tensor x: source features, channel-first [B, C, T]
        :param torch.Tensor pos_k: relative-position keys; their first dim
            fixes the downsampled attention length
        :return torch.Tensor: gated features, channel-last (batch, T, C)
        """
        down_len = pos_k.shape[0]
        # Pool the time axis down to the attention resolution.
        x_down = torch.nn.functional.adaptive_avg_pool1d(input=x, output_size=down_len)
        x = x.permute([0, 2, 1])
        x_down = x_down.permute([0, 2, 1])
        x_down = self.block['self_attn'](x_down, pos_k, None)
        x_down = x_down.permute([0, 2, 1])
        # FIX: torch.nn.functional.upsample is deprecated; interpolate is the
        # drop-in replacement (same default 'nearest' mode).
        x_downup = torch.nn.functional.interpolate(input=x_down, size=x.shape[1])
        x_downup = x_downup.permute([0, 2, 1])
        # Sigmoid gate decides, per element, how much global context to inject.
        x = x + self.block['linear'](x) * x_downup

        return x
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class CLA(torch.nn.Module):
    """Convolutional Local Attention block.

    LayerNorm -> GLU (from a 2x linear expansion) -> depthwise conv over time
    -> 2x linear + BatchNorm -> GELU/linear projection, added back residually
    through a LayerScale gate. Input/output shape: [B, T, C].
    """

    def __init__(self, in_channels, kernel_size, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        # Attribute names are preserved so checkpoint keys match.
        self.layer_norm = torch.nn.LayerNorm(in_channels)
        self.linear1 = torch.nn.Linear(in_channels, in_channels*2)
        self.GLU = torch.nn.GLU()
        self.dw_conv_1d = torch.nn.Conv1d(in_channels, in_channels, kernel_size, padding='same', groups=in_channels)
        self.linear2 = torch.nn.Linear(in_channels, 2*in_channels)
        self.BN = torch.nn.BatchNorm1d(2*in_channels)
        self.linear3 = torch.nn.Sequential(
            torch.nn.GELU(),
            torch.nn.Linear(2*in_channels, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        residual = x
        h = self.GLU(self.linear1(self.layer_norm(x)))          # [B, T, C]
        h = self.dw_conv_1d(h.transpose(1, 2)).transpose(1, 2)  # depthwise over time
        h = self.BN(self.linear2(h).transpose(1, 2)).transpose(1, 2)  # [B, T, 2C]
        h = self.linear3(h)                                     # [B, T, C]
        return residual + self.Layer_scale(h)
|
| 188 |
+
|
| 189 |
+
class GlobalBlock(torch.nn.Module):
    """EGA attention followed by a GCFN feed-forward; output is (B, C, T)."""

    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'ega': EGA(
                num_mha_heads=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            'gcfn': GCFN(in_channels=in_channels, dropout_rate=dropout_rate)
        })

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """
        Compute encoded features.
        :param torch.Tensor x: encoded source features (batch, channels, time)
        :param torch.Tensor pos_k: relative positional encoding for the EGA stage
        :rtype: torch.Tensor
        """
        out = self.block['ega'](x, pos_k)
        out = self.block['gcfn'](out)
        # EGA/GCFN operate in (B, T, C); restore channel-first layout.
        return out.permute([0, 2, 1])
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class LocalBlock(torch.nn.Module):
    """CLA convolutional block followed by a GCFN feed-forward."""

    def __init__(self, in_channels: int, kernel_size: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'cla': CLA(in_channels, kernel_size, dropout_rate),
            'gcfn': GCFN(in_channels, dropout_rate)
        })

    def forward(self, x: torch.Tensor):
        """Apply CLA then GCFN; both preserve the input shape."""
        out = self.block['cla'](x)
        return self.block['gcfn'](out)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class SpkAttention(torch.nn.Module):
    """Attention across the speaker dimension.

    Reshapes the (batch*num_spk, F, T) input so each time frame attends over
    the speaker axis, adds a residual, then applies a GCFN feed-forward.
    """

    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.self_attn = MultiHeadAttention(n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate)
        self.feed_forward = GCFN(in_channels=in_channels, dropout_rate=dropout_rate)

    def forward(self, x: torch.Tensor, num_spk: int):
        """
        Compute encoded features.
        :param torch.Tensor x: features stacked per speaker, shape (B*num_spk, F, T)
        :param int num_spk: number of speakers folded into the batch dimension
        :rtype: torch.Tensor with the same (B*num_spk, F, T) shape
        """
        B, F, T = x.shape
        batch = B // num_spk
        # (B, F, T) -> (batch*T, num_spk, F): speakers become the sequence axis.
        y = x.view(batch, num_spk, F, T).contiguous()
        y = y.permute([0, 3, 1, 2]).contiguous()
        y = y.view(-1, num_spk, F).contiguous()
        y = y + self.self_attn(y, None, None)
        # Invert the reshape back to (B, F, T).
        y = y.view(batch, T, num_spk, F).contiguous()
        y = y.permute([0, 2, 3, 1]).contiguous()
        y = y.view(B, F, T).contiguous()
        # GCFN expects (B, T, F); permute in and back out.
        y = self.feed_forward(y.permute([0, 2, 1]))

        return y.permute([0, 2, 1])
|
models/SepReformer/SepReformer_Large_DM_WHAMR/configs.yaml
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
config:
|
| 2 |
+
dataset:
|
| 3 |
+
max_len : 32000
|
| 4 |
+
sampling_rate: 8000
|
| 5 |
+
scp_dir: "data/scp_ss_8k_whamr"
|
| 6 |
+
train:
|
| 7 |
+
mixture: "tr_mix.scp"
|
| 8 |
+
spk1: "tr_s1.scp"
|
| 9 |
+
spk2: "tr_s2.scp"
|
| 10 |
+
spk1_reverb: "tr_s1_reverb.scp"
|
| 11 |
+
spk2_reverb: "tr_s2_reverb.scp"
|
| 12 |
+
noise: "tr_n.scp"
|
| 13 |
+
dynamic_mixing: true
|
| 14 |
+
valid:
|
| 15 |
+
mixture: "cv_mix.scp"
|
| 16 |
+
spk1: "cv_s1.scp"
|
| 17 |
+
spk2: "cv_s2.scp"
|
| 18 |
+
test:
|
| 19 |
+
mixture: "tt_mix.scp"
|
| 20 |
+
spk1: "tt_s1.scp"
|
| 21 |
+
spk2: "tt_s2.scp"
|
| 22 |
+
dataloader:
|
| 23 |
+
batch_size: 2
|
| 24 |
+
pin_memory: false
|
| 25 |
+
num_workers: 12
|
| 26 |
+
drop_last: false
|
| 27 |
+
model:
|
| 28 |
+
num_stages: &var_model_num_stages 4 # R
|
| 29 |
+
num_spks: &var_model_num_spks 2
|
| 30 |
+
module_audio_enc:
|
| 31 |
+
in_channels: 1
|
| 32 |
+
out_channels: &var_model_audio_enc_out_channels 256
|
| 33 |
+
kernel_size: &var_model_audio_enc_kernel_size 16 # L
|
| 34 |
+
stride: &var_model_audio_enc_stride 4 # S
|
| 35 |
+
groups: 1
|
| 36 |
+
bias: false
|
| 37 |
+
module_feature_projector:
|
| 38 |
+
num_channels: *var_model_audio_enc_out_channels
|
| 39 |
+
in_channels: *var_model_audio_enc_out_channels
|
| 40 |
+
out_channels: &feature_projector_out_channels 256 # F
|
| 41 |
+
kernel_size: 1
|
| 42 |
+
bias: false
|
| 43 |
+
module_separator:
|
| 44 |
+
num_stages: *var_model_num_stages
|
| 45 |
+
relative_positional_encoding:
|
| 46 |
+
in_channels: *feature_projector_out_channels
|
| 47 |
+
num_heads: 8
|
| 48 |
+
maxlen: 2000
|
| 49 |
+
embed_v: false
|
| 50 |
+
enc_stage:
|
| 51 |
+
global_blocks:
|
| 52 |
+
in_channels: *feature_projector_out_channels
|
| 53 |
+
num_mha_heads: 8
|
| 54 |
+
dropout_rate: 0.1
|
| 55 |
+
local_blocks:
|
| 56 |
+
in_channels: *feature_projector_out_channels
|
| 57 |
+
kernel_size: 65
|
| 58 |
+
dropout_rate: 0.1
|
| 59 |
+
down_conv_layer:
|
| 60 |
+
in_channels: *feature_projector_out_channels
|
| 61 |
+
samp_kernel_size: &var_model_samp_kernel_size 5
|
| 62 |
+
spk_split_stage:
|
| 63 |
+
in_channels: *feature_projector_out_channels
|
| 64 |
+
num_spks: *var_model_num_spks
|
| 65 |
+
simple_fusion:
|
| 66 |
+
out_channels: *feature_projector_out_channels
|
| 67 |
+
dec_stage:
|
| 68 |
+
num_spks: *var_model_num_spks
|
| 69 |
+
global_blocks:
|
| 70 |
+
in_channels: *feature_projector_out_channels
|
| 71 |
+
num_mha_heads: 8
|
| 72 |
+
dropout_rate: 0.1
|
| 73 |
+
local_blocks:
|
| 74 |
+
in_channels: *feature_projector_out_channels
|
| 75 |
+
kernel_size: 65
|
| 76 |
+
dropout_rate: 0.1
|
| 77 |
+
spk_attention:
|
| 78 |
+
in_channels: *feature_projector_out_channels
|
| 79 |
+
num_mha_heads: 8
|
| 80 |
+
dropout_rate: 0.1
|
| 81 |
+
module_output_layer:
|
| 82 |
+
in_channels: *var_model_audio_enc_out_channels
|
| 83 |
+
out_channels: *feature_projector_out_channels
|
| 84 |
+
num_spks: *var_model_num_spks
|
| 85 |
+
module_audio_dec:
|
| 86 |
+
in_channels: *var_model_audio_enc_out_channels
|
| 87 |
+
out_channels: 1
|
| 88 |
+
kernel_size: *var_model_audio_enc_kernel_size
|
| 89 |
+
stride: *var_model_audio_enc_stride
|
| 90 |
+
bias: false
|
| 91 |
+
criterion: ### Ref: https://pytorch.org/docs/stable/nn.html#loss-functions
|
| 92 |
+
name: ["PIT_SISNR_mag", "PIT_SISNR_time", "PIT_SISNRi", "PIT_SDRi"] ### Choose a torch.nn's loss function class(=attribute) e.g. ["L1Loss", "MSELoss", "CrossEntropyLoss", ...] / You can also build your optimizer :)
|
| 93 |
+
PIT_SISNR_mag:
|
| 94 |
+
frame_length: 512
|
| 95 |
+
frame_shift: 128
|
| 96 |
+
window: 'hann'
|
| 97 |
+
num_stages: *var_model_num_stages
|
| 98 |
+
num_spks: *var_model_num_spks
|
| 99 |
+
scale_inv: true
|
| 100 |
+
mel_opt: false
|
| 101 |
+
PIT_SISNR_time:
|
| 102 |
+
num_spks: *var_model_num_spks
|
| 103 |
+
scale_inv: true
|
| 104 |
+
PIT_SISNRi:
|
| 105 |
+
num_spks: *var_model_num_spks
|
| 106 |
+
scale_inv: true
|
| 107 |
+
PIT_SDRi:
|
| 108 |
+
dump: 0
|
| 109 |
+
optimizer: ### Ref: https://pytorch.org/docs/stable/optim.html#algorithms
|
| 110 |
+
name: ["AdamW"] ### Choose a torch.optim's class(=attribute) e.g. ["Adam", "AdamW", "SGD", ...] / You can also build your optimizer :)
|
| 111 |
+
AdamW:
|
| 112 |
+
lr: 2.0e-4
|
| 113 |
+
weight_decay: 1.0e-2
|
| 114 |
+
scheduler: ### Ref(+ find "How to adjust learning rate"): https://pytorch.org/docs/stable/optim.html#algorithms
|
| 115 |
+
name: ["ReduceLROnPlateau", "WarmupConstantSchedule"] ### Choose a torch.optim.lr_scheduler's class(=attribute) e.g. ["StepLR", "ReduceLROnPlateau", "Custom"] / You can also build your scheduler :)
|
| 116 |
+
ReduceLROnPlateau:
|
| 117 |
+
mode: "min"
|
| 118 |
+
min_lr: 1.0e-10
|
| 119 |
+
factor: 0.8
|
| 120 |
+
patience: 2
|
| 121 |
+
WarmupConstantSchedule:
|
| 122 |
+
warmup_steps: 1000
|
| 123 |
+
check_computations:
|
| 124 |
+
dummy_len: 16000
|
| 125 |
+
engine:
|
| 126 |
+
max_epoch: 200
|
| 127 |
+
gpuid: "0" ### "0"(single-gpu) or "0, 1" (multi-gpu)
|
| 128 |
+
mvn: false
|
| 129 |
+
clip_norm: 5
|
| 130 |
+
start_scheduling: 50
|
| 131 |
+
test_epochs: [50, 80, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 199]
|
models/SepReformer/SepReformer_Large_DM_WHAMR/dataset.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
import librosa as audio_lib
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from utils import util_dataset
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
from loguru import logger
|
| 10 |
+
from torch.utils.data import Dataset, DataLoader
|
| 11 |
+
|
| 12 |
+
@logger_wraps()
def get_dataloaders(args, dataset_config, loader_config):
    """Build one DataLoader per partition.

    In test engine modes only the 'test' partition is constructed;
    otherwise 'train', 'valid' and 'test' are all built. Dynamic mixing
    is honored only for the training partition.
    """
    partitions = ["test"] if "test" in args.engine_mode else ["train", "valid", "test"]
    loaders = {}
    scp_dir = dataset_config["scp_dir"]
    for part in partitions:
        part_cfg = dataset_config[part]
        mix_path = os.path.join(scp_dir, part_cfg['mixture'])
        # Every key starting with 'spk' (dry and reverberant sources alike).
        src_paths = [os.path.join(scp_dir, part_cfg[k]) for k in part_cfg if k.startswith('spk')]
        noise_path = os.path.join(scp_dir, part_cfg['noise']) if 'noise' in part_cfg else None
        use_dm = part_cfg["dynamic_mixing"] if part == 'train' else False
        dataset = MyDataset(
            max_len=dataset_config['max_len'],
            fs=dataset_config['sampling_rate'],
            partition=part,
            wave_scp_srcs=src_paths,
            wave_scp_mix=mix_path,
            wave_scp_noise=noise_path,
            dynamic_mixing=use_dm)
        loaders[part] = DataLoader(
            dataset=dataset,
            batch_size=1 if part == 'test' else loader_config["batch_size"],
            shuffle=True,  # only train: (partition == 'train') / all: True
            pin_memory=loader_config["pin_memory"],
            num_workers=loader_config["num_workers"],
            drop_last=loader_config["drop_last"],
            collate_fn=_collate)
    return loaders
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _collate(egs):
|
| 43 |
+
"""
|
| 44 |
+
Transform utterance index into a minbatch
|
| 45 |
+
|
| 46 |
+
Arguments:
|
| 47 |
+
index: a list type [{},{},{}]
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
input_sizes: a tensor correspond to utterance length
|
| 51 |
+
input_feats: packed sequence to feed networks
|
| 52 |
+
source_attr/target_attr: dictionary contains spectrogram/phase needed in loss computation
|
| 53 |
+
"""
|
| 54 |
+
def __prepare_target_rir(dict_lsit, index):
|
| 55 |
+
return torch.nn.utils.rnn.pad_sequence([torch.tensor(d["src"][index], dtype=torch.float32) for d in dict_lsit], batch_first=True)
|
| 56 |
+
if type(egs) is not list: raise ValueError("Unsupported index type({})".format(type(egs)))
|
| 57 |
+
num_spks = 2 # you need to set this paramater by yourself
|
| 58 |
+
dict_list = sorted([eg for eg in egs], key=lambda x: x['num_sample'], reverse=True)
|
| 59 |
+
mixture = torch.nn.utils.rnn.pad_sequence([torch.tensor(d['mix'], dtype=torch.float32) for d in dict_list], batch_first=True)
|
| 60 |
+
src = [__prepare_target_rir(dict_list, index) for index in range(num_spks)]
|
| 61 |
+
input_sizes = torch.tensor([d['num_sample'] for d in dict_list], dtype=torch.float32)
|
| 62 |
+
key = [d['key'] for d in dict_list]
|
| 63 |
+
return input_sizes, mixture, src, key
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@logger_wraps()
class MyDataset(Dataset):
    """Speech-separation dataset for WHAMR-style data.

    Loads mixture/source (and optional noise) waveforms listed in .scp files.
    When ``dynamic_mixing`` is enabled (training only), mixtures are
    synthesized on the fly from randomly paired sources plus noise;
    otherwise the pre-mixed waveform is loaded directly from disk.
    """

    def __init__(self, max_len, fs, partition, wave_scp_srcs, wave_scp_mix, wave_scp_noise, dynamic_mixing, speed_list=None):
        # partition: "train" / "valid" / "test" — controls cropping in _direct_load.
        self.partition = partition
        for wave_scp_src in wave_scp_srcs:
            if not os.path.exists(wave_scp_src): raise FileNotFoundError(f"Could not find file {wave_scp_src}")
        self.max_len = max_len  # crop length in samples (non-test partitions)
        self.fs = fs  # target sampling rate passed to librosa.load
        # key -> wav-path maps, one dict per scp file.
        self.wave_dict_srcs = [util_dataset.parse_scps(wave_scp_src) for wave_scp_src in wave_scp_srcs]
        self.wave_dict_mix = util_dataset.parse_scps(wave_scp_mix)
        self.wave_dict_noise = util_dataset.parse_scps(wave_scp_noise) if wave_scp_noise else None
        self.wave_keys = list(self.wave_dict_mix.keys())
        logger.info(f"Create MyDataset for {wave_scp_mix} with {len(self.wave_dict_mix)} utterances")
        self.dynamic_mixing = dynamic_mixing

    def __len__(self):
        # Number of utterances equals the number of mixture entries.
        return len(self.wave_dict_mix)

    def __contains__(self, key):
        return key in self.wave_dict_mix

    def _dynamic_mixing(self, key):
        """Synthesize a training mixture on the fly.

        Pairs the keyed source with a randomly chosen second source,
        RMS-normalizes all sources to the first one, applies random gains,
        adds RMS-normalized noise, and random-crops everything to a common
        length. Returns (samps_mix, samps_src) as numpy arrays.
        """
        def __match_length(wav, len_data):
            # Random crop of wav down to len_data samples.
            leftover = len(wav) - len_data
            idx = random.randint(0,leftover)
            wav = wav[idx:idx+len_data]
            return wav

        samps_src_reverb = []
        samps_src = []
        # max_len caps the final crop length even if all files are longer.
        src_len = [self.max_len]
        # dyanmic source choice
        # checking whether it is the same speaker
        key_random = random.choice(list(self.wave_dict_srcs[0].keys()))
        # Randomly swap which speaker slot each utterance occupies.
        idx1, idx2 = (0, 1) if random.random() > 0.5 else (1, 0)
        files = [self.wave_dict_srcs[idx1][key], self.wave_dict_srcs[idx2][key_random]]
        # NOTE(review): assumes the scp list is ordered [s1, s2, s1_reverb, s2_reverb]
        # so idx+2 addresses the reverberant copy — confirm against configs.yaml keys.
        files_reverb = [self.wave_dict_srcs[idx1+2][key], self.wave_dict_srcs[idx2+2][key_random]]

        # load
        for idx, file in enumerate(files_reverb):
            if not os.path.exists(file):
                raise FileNotFoundError("Input file {} do not exists!".format(file))
            samps_tmp_reverb, _ = audio_lib.load(file, sr=self.fs)
            samps_tmp, _ = audio_lib.load(files[idx], sr=self.fs)
            # mixing with random gains

            # Normalize every source to the first (dry) source's RMS level.
            if idx == 0: ref_rms = np.sqrt(np.mean(np.square(samps_tmp)))
            curr_rms = np.sqrt(np.mean(np.square(samps_tmp)))

            norm_factor = ref_rms / curr_rms
            samps_tmp *= norm_factor
            samps_tmp_reverb *= norm_factor

            # Random relative gain drawn from a +/-3 dB range.
            gain = pow(10,-random.uniform(-3,3)/20)
            # Speed Augmentation
            samps_src_reverb.append(gain*samps_tmp_reverb)
            samps_src.append(gain*samps_tmp)
            src_len.append(len(samps_tmp))


        # matching the audio length
        min_len = min(src_len)

        # add noise source
        file_noise = self.wave_dict_noise[key]
        samps_noise, _ = audio_lib.load(file_noise, sr=self.fs)
        curr_rms = np.sqrt(np.mean(np.square(samps_noise)))
        norm_factor = ref_rms / curr_rms
        samps_noise *= norm_factor
        # Noise gain drawn from a -3..+6 dB range relative to the reference.
        gain_noise = pow(10,-random.uniform(-6,3)/20)
        samps_noise = samps_noise*gain_noise
        src_len.append(len(samps_noise))

        # truncate
        min_len = min(src_len)
        # Stack each dry/reverb pair so one random crop applies to both copies.
        samps_src_stack = [np.stack([samps_src_reverb[idx], samps_src[idx]],axis=-1) for idx in range(len(samps_src_reverb))]
        samps_src_stack = [__match_length(s, min_len) for s in samps_src_stack]
        samps_src_reverb = [s[...,0] for s in samps_src_stack]
        samps_src = [s[...,1] for s in samps_src_stack]
        samps_noise = __match_length(samps_noise, min_len)
        # Mixture uses the reverberant sources; targets are the dry ones.
        samps_mix = sum(samps_src_reverb) + samps_noise


        # Trim so the length is divisible by 4 (matches the stride-4 encoder).
        if len(samps_mix)%4 != 0:
            remains = len(samps_mix)%4
            samps_mix = samps_mix[:-remains]
            samps_src = [s[:-remains] for s in samps_src]

        return samps_mix, samps_src

    def _direct_load(self, key):
        """Load the pre-mixed waveform and its source targets from disk.

        Returns (samps_mix, samps_src); non-test partitions are randomly
        cropped to at most max_len samples.
        """
        samps_src = []
        samps_src = []
        files = [self.wave_dict_srcs[0][key], self.wave_dict_srcs[1][key]]
        # files = [wave_dict_src[key] for wave_dict_src in self.wave_dict_srcs]
        for file in files:
            if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
            samps_tmp, _ = audio_lib.load(file, sr=self.fs)
            samps_src.append(samps_tmp)

        file = self.wave_dict_mix[key]
        if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
        samps_mix, _ = audio_lib.load(file, sr=self.fs)

        # Truncate samples as needed
        if len(samps_mix) % 4 != 0:
            remains = len(samps_mix) % 4
            samps_mix = samps_mix[:-remains]
            samps_src = [s[:-remains] for s in samps_src]

        # Random fixed-length crop for train/valid; test keeps full length.
        if self.partition != "test":
            if len(samps_mix) > self.max_len:
                start = random.randint(0,len(samps_mix)-self.max_len)
                samps_mix = samps_mix[start:start+self.max_len]
                samps_src = [s[start:start+self.max_len] for s in samps_src]

        return samps_mix, samps_src

    def __getitem__(self, index):
        # Returns a dict consumed by _collate.
        key = self.wave_keys[index]
        if any(key not in self.wave_dict_srcs[i] for i in range(len(self.wave_dict_srcs)-2)) or key not in self.wave_dict_mix: raise KeyError(f"Could not find utterance {key}")
        samps_mix, samps_src = self._dynamic_mixing(key) if self.dynamic_mixing else self._direct_load(key)
        return {"num_sample": samps_mix.shape[0], "mix": samps_mix, "src": samps_src, "key": key}
|
models/SepReformer/SepReformer_Large_DM_WHAMR/engine.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import csv
|
| 4 |
+
import time
|
| 5 |
+
import soundfile as sf
|
| 6 |
+
|
| 7 |
+
from loguru import logger
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from utils import util_engine, functions
|
| 10 |
+
from utils.decorators import *
|
| 11 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@logger_wraps()
|
| 15 |
+
class Engine(object):
|
| 16 |
+
def __init__(self, args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device):
    """Set up model, data, losses, optimizers and checkpoint-resume state.

    :param args: CLI namespace (reads engine_mode, out_wav_dir)
    :param config: parsed configs.yaml dict
    :param model: separation network
    :param dataloaders: dict with 'train'/'valid'/'test' loaders
    :param criterions: (PIT_SISNR_mag, PIT_SISNR_time, PIT_SISNRi, PIT_SDRi)
    :param optimizers: sequence; only the first optimizer is used
    :param schedulers: (main_scheduler, warmup_scheduler)
    :param gpuid: device ids passed to data_parallel
    :param device: primary torch device
    """
    ''' Default setting '''
    self.engine_mode = args.engine_mode
    self.out_wav_dir = args.out_wav_dir
    self.config = config
    self.gpuid = gpuid
    self.device = device
    self.model = model.to(self.device)
    self.dataloaders = dataloaders  # self.dataloaders['train'] or ['valid'] or ['test']
    self.PIT_SISNR_mag_loss, self.PIT_SISNR_time_loss, self.PIT_SISNRi_loss, self.PIT_SDRi_loss = criterions
    self.main_optimizer = optimizers[0]
    self.main_scheduler, self.warmup_scheduler = schedulers

    self.pretrain_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "pretrain_weights")
    os.makedirs(self.pretrain_weights_path, exist_ok=True)
    self.scratch_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "scratch_weights")
    os.makedirs(self.scratch_weights_path, exist_ok=True)

    # Resume from the pretrained directory if it contains any checkpoint,
    # otherwise fall back to scratch weights.
    # BUGFIX: the original tuple was ('.pt', '.pt', '.pkl') — '.pt' listed
    # twice and '.pth' (the extension the checkpoints are actually saved
    # with, e.g. epoch.0180.pth) never matched.
    checkpoint_exts = ('.pt', '.pth', '.pkl')
    self.checkpoint_path = self.pretrain_weights_path if any(file.endswith(checkpoint_exts) for file in os.listdir(self.pretrain_weights_path)) else self.scratch_weights_path
    self.start_epoch = util_engine.load_last_checkpoint_n_get_epoch(self.checkpoint_path, self.model, self.main_optimizer, location=self.device)

    # Logging
    util_engine.model_params_mac_summary(
        model=self.model,
        input=torch.randn(1, self.config['check_computations']['dummy_len']).to(self.device),
        dummy_input=torch.rand(1, self.config['check_computations']['dummy_len']).to(self.device),
        metrics=['ptflops', 'thop', 'torchinfo']
        # metrics=['ptflops']
    )

    logger.info(f"Clip gradient by 2-norm {self.config['engine']['clip_norm']}")
|
| 48 |
+
|
| 49 |
+
@logger_wraps()
def _train(self, dataloader, epoch):
    """Run one training epoch.

    :param dataloader: training DataLoader yielding (input_sizes, mixture, src, key)
    :param epoch: 1-based epoch index (epoch 1 drives warm-up scheduling)
    :return: (avg time-domain loss, avg freq-domain loss, num_batches)
    """
    self.model.train()
    tot_loss_freq = [0 for _ in range(self.model.num_stages)]
    tot_loss_time, num_batch = 0, 0
    pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:25}{r_bar}{bar:-10b}', colour="YELLOW", dynamic_ncols=True)
    for input_sizes, mixture, src, _ in dataloader:
        nnet_input = mixture
        nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
        num_batch += 1
        pbar.update(1)
        # Scheduler learning rate for warm-up (Iteration-based update for transformers)
        if epoch == 1: self.warmup_scheduler.step()
        nnet_input = nnet_input.to(self.device)
        self.main_optimizer.zero_grad()
        estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
        # Per-stage frequency-domain auxiliary losses on intermediate outputs.
        # (BUGFIX: removed the dead `cur_loss_s_bn = 0` that was immediately
        # overwritten by the list assignment below.)
        cur_loss_s_bn = []
        for idx, estim_src_value in enumerate(estim_src_bn):
            cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
            tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
        cur_loss_s = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
        tot_loss_time += cur_loss_s.item() / self.config['model']['num_spks']
        # Auxiliary-loss weight: constant 0.4 until epoch 100, then decays
        # by 0.8 every 5 epochs.
        alpha = 0.4 * 0.8**(1+(epoch-101)//5) if epoch > 100 else 0.4
        cur_loss = (1-alpha) * cur_loss_s + alpha * sum(cur_loss_s_bn) / len(cur_loss_s_bn)
        cur_loss = cur_loss / self.config['model']['num_spks']
        cur_loss.backward()
        if self.config['engine']['clip_norm']: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['engine']['clip_norm'])
        self.main_optimizer.step()
        dict_loss = {"T_Loss": tot_loss_time / num_batch}
        dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
        pbar.set_postfix(dict_loss)
    pbar.close()
    tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
    return tot_loss_time / num_batch, tot_loss_freq / num_batch, num_batch
|
| 84 |
+
|
| 85 |
+
@logger_wraps()
def _validate(self, dataloader):
    """Run one validation pass (no gradients).

    Mirrors _train's loss accumulation but performs no optimizer step.

    :param dataloader: validation DataLoader yielding (input_sizes, mixture, src, key)
    :return: (avg time-domain loss, avg freq-domain loss, num_batches)
    """
    self.model.eval()
    # One frequency-loss accumulator per separator stage.
    tot_loss_freq = [0 for _ in range(self.model.num_stages)]
    tot_loss_time, num_batch = 0, 0
    pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="RED", dynamic_ncols=True)
    with torch.inference_mode():
        for input_sizes, mixture, src, _ in dataloader:
            nnet_input = mixture
            # Optional cepstral mean-variance normalization of the input.
            nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
            nnet_input = nnet_input.to(self.device)
            num_batch += 1
            pbar.update(1)
            estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
            # Per-stage frequency-domain losses on intermediate outputs.
            cur_loss_s_bn = []
            for idx, estim_src_value in enumerate(estim_src_bn):
                cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
                tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
            cur_loss_s_SDR = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
            tot_loss_time += cur_loss_s_SDR.item() / self.config['model']['num_spks']
            dict_loss = {"T_Loss":tot_loss_time / num_batch}
            dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
            pbar.set_postfix(dict_loss)
    pbar.close()
    # Average the per-stage accumulators into a single scalar.
    tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
    return tot_loss_time / num_batch, tot_loss_freq / num_batch, num_batch
|
| 111 |
+
|
| 112 |
+
@logger_wraps()
def _test(self, dataloader, wav_dir=None):
    """Evaluate SI-SNRi and SDRi on the test set (batch size must be 1).

    Writes per-utterance metrics to CSV next to this file; when engine_mode
    is "test_save", also writes peak-normalized mixture/output wavs.

    :param dataloader: test DataLoader (batch_size=1)
    :param wav_dir: output directory for wavs; defaults to ./wav_out
    :return: (avg SI-SNRi, avg SDRi, num_batches)
    """
    self.model.eval()
    total_loss_SISNRi, total_loss_SDRi, num_batch = 0, 0, 0
    pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="grey", dynamic_ncols=True)
    with torch.inference_mode():
        csv_file_name_sisnr = os.path.join(os.path.dirname(__file__),'test_SISNRi_value.csv')
        csv_file_name_sdr = os.path.join(os.path.dirname(__file__),'test_SDRi_value.csv')
        with open(csv_file_name_sisnr, 'w', newline='') as csvfile_sisnr, open(csv_file_name_sdr, 'w', newline='') as csvfile_sdr:
            idx = 0
            writer_sisnr = csv.writer(csvfile_sisnr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
            writer_sdr = csv.writer(csvfile_sdr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
            for input_sizes, mixture, src, key in dataloader:
                # BUGFIX: the original `raise("batch size is not one!!")`
                # raised a str, which is itself a TypeError at runtime;
                # raise a proper exception instead.
                if len(key) > 1:
                    raise ValueError("batch size is not one!!")
                nnet_input = mixture.to(self.device)
                num_batch += 1
                pbar.update(1)
                estim_src, _ = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
                cur_loss_SISNRi, cur_loss_SISNRi_src = self.PIT_SISNRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src, eps=1.0e-15)
                total_loss_SISNRi += cur_loss_SISNRi.item() / self.config['model']['num_spks']
                cur_loss_SDRi, cur_loss_SDRi_src = self.PIT_SDRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src)
                total_loss_SDRi += cur_loss_SDRi.item() / self.config['model']['num_spks']
                # Per-utterance, per-speaker metric rows keyed by utterance id.
                writer_sisnr.writerow([key[0][:-4]] + [cur_loss_SISNRi_src[i].item() for i in range(self.config['model']['num_spks'])])
                writer_sdr.writerow([key[0][:-4]] + [cur_loss_SDRi_src[i].item() for i in range(self.config['model']['num_spks'])])
                if self.engine_mode == "test_save":
                    if wav_dir is None: wav_dir = os.path.join(os.path.dirname(__file__),"wav_out")
                    if wav_dir and not os.path.exists(wav_dir): os.makedirs(wav_dir)
                    # Peak-normalize to 0.5 before writing 8 kHz wavs.
                    mixture = torch.squeeze(mixture).cpu().data.numpy()
                    sf.write(os.path.join(wav_dir,key[0][:-4]+str(idx)+'_mixture.wav'), 0.5*mixture/max(abs(mixture)), 8000)
                    for i in range(self.config['model']['num_spks']):
                        src = torch.squeeze(estim_src[i]).cpu().data.numpy()
                        sf.write(os.path.join(wav_dir,key[0][:-4]+str(idx)+'_out_'+str(i)+'.wav'), 0.5*src/max(abs(src)), 8000)
                idx += 1
                dict_loss = {"SiSNRi": total_loss_SISNRi/num_batch, "SDRi": total_loss_SDRi/num_batch}
                pbar.set_postfix(dict_loss)
    pbar.close()
    return total_loss_SISNRi/num_batch, total_loss_SDRi/num_batch, num_batch
|
| 150 |
+
|
| 151 |
+
@logger_wraps()
def run(self):
    """Top-level entry point for the engine.

    In any "test" engine_mode, runs a single evaluation pass over the test
    set and logs SI-SNRi / SDRi. Otherwise runs the train/validate loop up
    to config['engine']['max_epoch'], stepping the LR scheduler on the
    validation loss, periodically evaluating on the test set, saving a
    checkpoint whenever validation improves, and logging scalars to
    TensorBoard.
    """
    with torch.cuda.device(self.device):
        writer_src = SummaryWriter(os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/tensorboard"))
        if "test" in self.engine_mode:
            on_test_start = time.time()
            test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'], self.out_wav_dir)
            on_test_end = time.time()
            logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
            logger.info(f"Testing done!")
        else:
            start_time = time.time()
            # Establish a baseline validation loss when resuming from a checkpoint.
            if self.start_epoch > 1:
                init_loss_time, init_loss_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
            else:
                init_loss_time, init_loss_freq = 0, 0
            end_time = time.time()
            logger.info(f"[INIT] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: Loss_t = {init_loss_time:.4f} dB | Loss_f = {init_loss_freq:.4f} dB | Speed = ({end_time-start_time:.2f}s)")
            # BUGFIX: initialize the best-loss tracker ONCE, before the epoch
            # loop. The original reset it to init_loss_time at the top of every
            # epoch, so save_checkpoint_per_best compared each epoch against the
            # initial loss instead of the running best.
            valid_loss_best = init_loss_time
            for epoch in range(self.start_epoch, self.config['engine']['max_epoch']):
                train_start_time = time.time()
                train_loss_src_time, train_loss_src_freq, train_num_batch = self._train(self.dataloaders['train'], epoch)
                train_end_time = time.time()
                valid_start_time = time.time()
                valid_loss_src_time, valid_loss_src_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
                valid_end_time = time.time()
                # Plateau-style scheduler: step on the validation time-domain loss.
                if epoch > self.config['engine']['start_scheduling']: self.main_scheduler.step(valid_loss_src_time)
                logger.info(f"[TRAIN] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {train_loss_src_time:.4f} dB | Loss_f = {train_loss_src_freq:.4f} dB | Speed = ({train_end_time - train_start_time:.2f}s/{train_num_batch:d})")
                logger.info(f"[VALID] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {valid_loss_src_time:.4f} dB | Loss_f = {valid_loss_src_freq:.4f} dB | Speed = ({valid_end_time - valid_start_time:.2f}s/{valid_num_batch:d})")
                # Optional mid-training evaluation at configured epochs.
                if epoch in self.config['engine']['test_epochs']:
                    on_test_start = time.time()
                    test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'])
                    on_test_end = time.time()
                    logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
                # Persist a checkpoint only when validation improves; returns the new best.
                valid_loss_best = util_engine.save_checkpoint_per_best(valid_loss_best, valid_loss_src_time, train_loss_src_time, epoch, self.model, self.main_optimizer, self.checkpoint_path)
                # Logging to monitoring tools (Tensorboard && Wandb)
                writer_src.add_scalars("Metrics", {
                    'Loss_train_time': train_loss_src_time,
                    'Loss_valid_time': valid_loss_src_time}, epoch)
                writer_src.add_scalar("Learning Rate", self.main_optimizer.param_groups[0]['lr'], epoch)
                writer_src.flush()
            logger.info(f"Training for {self.config['engine']['max_epoch']} epoches done!")
|
models/SepReformer/SepReformer_Large_DM_WHAMR/main.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from loguru import logger
|
| 4 |
+
from .dataset import get_dataloaders
|
| 5 |
+
from .model import Model
|
| 6 |
+
from .engine import Engine
|
| 7 |
+
from utils import util_system, util_implement
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
|
| 10 |
+
# Setup logger
# Writes a DEBUG-level log file next to this module; mode="w" truncates the
# previous run's log each time the module is imported.
log_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/system_log.log")
logger.add(log_file_path, level="DEBUG", mode="w")
|
| 13 |
+
|
| 14 |
+
@logger_wraps()
def main(args):
    """Assemble the data, model, and training utilities, then run the engine.

    Reads configs.yaml next to this file, builds the dataloaders and network
    from it, resolves the GPU devices, instantiates the criterion/optimizer/
    scheduler factories, and finally hands everything to Engine.run().
    """
    # --- configuration (configs.yaml lives beside this module) ---
    yaml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs.yaml")
    yaml_dict = util_system.parse_yaml(yaml_path)
    config = yaml_dict["config"]  # wandb login success or fail

    # --- dataloaders [train / valid / test / etc...] ---
    dataloaders = get_dataloaders(args, config["dataset"], config["dataloader"])

    # --- network model ---
    model = Model(**config["model"])

    # --- gpu ids & primary device ---
    gpuid = tuple(map(int, config["engine"]["gpuid"].split(',')))
    device = torch.device(f'cuda:{gpuid[0]}')

    # --- criterion / optimizer / scheduler factories ---
    criterions = util_implement.CriterionFactory(config["criterion"], device).get_criterions()
    optimizers = util_implement.OptimizerFactory(config["optimizer"], model.parameters()).get_optimizers()
    schedulers = util_implement.SchedulerFactory(config["scheduler"], optimizers).get_schedulers()

    # --- build and run the engine ---
    engine = Engine(args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device)
    engine.run()
|
models/SepReformer/SepReformer_Large_DM_WHAMR/model.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('../')
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import warnings
|
| 6 |
+
warnings.filterwarnings('ignore')
|
| 7 |
+
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
from .modules.module import *
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@logger_wraps()
class Model(torch.nn.Module):
    """SepReformer separation network.

    Pipeline: AudioEncoder -> FeatureProjector -> Separator -> OutputLayer
    -> AudioDecoder (one decode per speaker). Each intermediate separator
    stage also gets its own output layer + decoder ("*_bn" lists) to produce
    auxiliary waveforms for deep-supervision losses.
    """
    def __init__(self,
                 num_stages: int,
                 num_spks: int,
                 module_audio_enc: dict,
                 module_feature_projector: dict,
                 module_separator: dict,
                 module_output_layer: dict,
                 module_audio_dec: dict):
        super().__init__()
        self.num_stages = num_stages
        self.num_spks = num_spks
        self.audio_encoder = AudioEncoder(**module_audio_enc)
        self.feature_projector = FeatureProjector(**module_feature_projector)
        self.separator = Separator(**module_separator)
        self.out_layer = OutputLayer(**module_output_layer)
        self.audio_decoder = AudioDecoder(**module_audio_dec)

        # Aux_loss: one masking output layer + decoder per separator stage.
        self.out_layer_bn = torch.nn.ModuleList([])
        self.decoder_bn = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            self.out_layer_bn.append(OutputLayer(**module_output_layer, masking=True))
            self.decoder_bn.append(AudioDecoder(**module_audio_dec))

    def forward(self, x):
        """Separate mixture waveform(s) x into per-speaker waveforms.

        Returns:
            audio: list of num_spks decoded waveforms from the final stage.
            audio_aux: per-stage lists of num_spks auxiliary waveforms,
                each trimmed to the input length, for auxiliary losses.
        """
        encoder_output = self.audio_encoder(x)
        projected_feature = self.feature_projector(encoder_output)
        last_stage_output, each_stage_outputs = self.separator(projected_feature)
        out_layer_output = self.out_layer(last_stage_output, encoder_output)
        each_spk_output = [out_layer_output[idx] for idx in range(self.num_spks)]
        audio = [self.audio_decoder(each_spk_output[idx]) for idx in range(self.num_spks)]

        # Aux_loss: upsample each stage output to the encoder resolution, then decode.
        # MODERNIZED: F.upsample is deprecated; F.interpolate with the default
        # mode='nearest' is the drop-in replacement.
        audio_aux = []
        for idx, each_stage_output in enumerate(each_stage_outputs):
            each_stage_output = self.out_layer_bn[idx](torch.nn.functional.interpolate(each_stage_output, size=encoder_output.shape[-1]), encoder_output)
            out_aux = [each_stage_output[jdx] for jdx in range(self.num_spks)]
            # Trim each decoded aux waveform back to the original input length.
            audio_aux.append([self.decoder_bn[idx](out_aux[jdx])[...,:x.shape[-1]] for jdx in range(self.num_spks)])

        return audio, audio_aux
|
models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/module.cpython-310.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/module.cpython-38.pyc
ADDED
|
Binary file (11 kB). View file
|
|
|
models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/network.cpython-310.pyc
ADDED
|
Binary file (8.98 kB). View file
|
|
|
models/SepReformer/SepReformer_Large_DM_WHAMR/modules/__pycache__/network.cpython-38.pyc
ADDED
|
Binary file (9.09 kB). View file
|
|
|
models/SepReformer/SepReformer_Large_DM_WHAMR/modules/module.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('../')
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import warnings
|
| 6 |
+
warnings.filterwarnings('ignore')
|
| 7 |
+
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
from .network import *
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class AudioEncoder(torch.nn.Module):
    """Strided 1-D convolutional front-end: waveform -> feature frames, with GELU."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, groups: int, bias: bool):
        super().__init__()
        self.conv1d = torch.nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            bias=bias,
        )
        self.gelu = torch.nn.GELU()

    def forward(self, x: torch.Tensor):
        # Insert the channel axis: [T] -> [1, T] (unbatched), [B, T] -> [B, 1, T].
        with_channel = x.unsqueeze(0) if x.dim() == 1 else x.unsqueeze(1)
        return self.gelu(self.conv1d(with_channel))
|
| 24 |
+
|
| 25 |
+
class FeatureProjector(torch.nn.Module):
    """Normalize encoder features, then project channels with a 1x1-style conv."""

    def __init__(self, num_channels: int, in_channels: int, out_channels: int, kernel_size: int, bias: bool):
        super().__init__()
        # num_groups=1 makes this a global layer norm over all channels.
        self.norm = torch.nn.GroupNorm(num_groups=1, num_channels=num_channels, eps=1e-8)
        self.conv1d = torch.nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            bias=bias,
        )

    def forward(self, x: torch.Tensor):
        return self.conv1d(self.norm(x))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Separator(torch.nn.Module):
    """U-Net-style separator: a temporal contracting path, a bottleneck that
    splits features per speaker, and a temporal expanding path that fuses
    encoder skip connections. The stage submodules are declared as classes
    inside __init__, keeping them private to this module.
    """
    def __init__(self, num_stages: int, relative_positional_encoding: dict, enc_stage: dict, spk_split_stage: dict, simple_fusion:dict, dec_stage: dict):
        super().__init__()
        
        class RelativePositionalEncoding(torch.nn.Module):
            # Learned relative-position embeddings shared by every attention stage.
            def __init__(self, in_channels: int, num_heads: int, maxlen: int, embed_v=False):
                super().__init__()
                self.in_channels = in_channels
                self.num_heads = num_heads
                self.embedding_dim = self.in_channels // self.num_heads  # per-head dim
                self.maxlen = maxlen
                # Relative offsets in [-maxlen, maxlen) are shifted to [0, 2*maxlen) for lookup.
                self.pe_k = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim)
                self.pe_v = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim) if embed_v else None
            
            def forward(self, pos_seq: torch.Tensor):
                # NOTE(review): clamp_ and += mutate pos_seq in place — the
                # caller's tensor is modified as a side effect.
                pos_seq.clamp_(-self.maxlen, self.maxlen - 1)
                pos_seq += self.maxlen
                pe_k_output = self.pe_k(pos_seq)
                pe_v_output = self.pe_v(pos_seq) if self.pe_v is not None else None
                return pe_k_output, pe_v_output
        
        class SepEncStage(torch.nn.Module):
            """One contracting stage: two (global, local) block pairs, then an
            optional depthwise strided conv that halves the time resolution."""
            def __init__(self, global_blocks: dict, local_blocks: dict, down_conv_layer: dict, down_conv=True):
                super().__init__()
                
                class DownConvLayer(torch.nn.Module):
                    """Depthwise strided conv (stride 2) + BatchNorm + GELU."""
                    def __init__(self, in_channels: int, samp_kernel_size: int):
                        super().__init__()
                        self.down_conv = torch.nn.Conv1d(
                            in_channels=in_channels, out_channels=in_channels, kernel_size=samp_kernel_size, stride=2, padding=(samp_kernel_size-1)//2, groups=in_channels)
                        self.BN = torch.nn.BatchNorm1d(num_features=in_channels)
                        self.gelu = torch.nn.GELU()
                    
                    def forward(self, x: torch.Tensor):
                        # Permute so the conv runs over the last axis, then restore.
                        x = x.permute([0, 2, 1])
                        x = self.down_conv(x)
                        x = self.BN(x)
                        x = self.gelu(x)
                        x = x.permute([0, 2, 1])
                        return x
                
                self.g_block_1 = GlobalBlock(**global_blocks)
                self.l_block_1 = LocalBlock(**local_blocks)
                
                self.g_block_2 = GlobalBlock(**global_blocks)
                self.l_block_2 = LocalBlock(**local_blocks)
                
                self.downconv = DownConvLayer(**down_conv_layer) if down_conv == True else None
            
            def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
                '''
                x: [B, N, T]
                Returns (downsampled x, pre-downsampling skip connection).
                '''
                # Global blocks take [B, N, T]; local blocks take [B, T, N],
                # hence the permute around each local block.
                x = self.g_block_1(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_1(x)
                x = x.permute(0, 2, 1).contiguous()
                
                x = self.g_block_2(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_2(x)
                x = x.permute(0, 2, 1).contiguous()
                
                skip = x
                if self.downconv:
                    x = x.permute(0, 2, 1).contiguous()
                    x = self.downconv(x)
                    x = x.permute(0, 2, 1).contiguous()
                # [BK, S, N]
                return x, skip
        
        class SpkSplitStage(torch.nn.Module):
            """Expand features to num_spks copies: [B, N, T] -> [B*num_spks, N, T]."""
            def __init__(self, in_channels: int, num_spks: int):
                super().__init__()
                self.linear = torch.nn.Sequential(
                    torch.nn.Conv1d(in_channels, 4*in_channels*num_spks, kernel_size=1),
                    torch.nn.GLU(dim=-2),
                    torch.nn.Conv1d(2*in_channels*num_spks, in_channels*num_spks, kernel_size=1))
                self.norm = torch.nn.GroupNorm(1, in_channels, eps=1e-8)
                self.num_spks = num_spks
            
            def forward(self, x: torch.Tensor):
                x = self.linear(x)
                B, _, T = x.shape
                # Fold the speaker axis into the batch dimension.
                x = x.view(B*self.num_spks,-1, T).contiguous()
                x = self.norm(x)
                return x
        
        class SepDecStage(torch.nn.Module):
            """One expanding stage: three (global, local, speaker-attention) triplets."""
            def __init__(self, num_spks: int, global_blocks: dict, local_blocks: dict, spk_attention: dict):
                super().__init__()
                
                self.g_block_1 = GlobalBlock(**global_blocks)
                self.l_block_1 = LocalBlock(**local_blocks)
                self.spk_attn_1 = SpkAttention(**spk_attention)
                
                self.g_block_2 = GlobalBlock(**global_blocks)
                self.l_block_2 = LocalBlock(**local_blocks)
                self.spk_attn_2 = SpkAttention(**spk_attention)
                
                self.g_block_3 = GlobalBlock(**global_blocks)
                self.l_block_3 = LocalBlock(**local_blocks)
                self.spk_attn_3 = SpkAttention(**spk_attention)
                
                self.num_spk = num_spks
            
            def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
                '''
                x: [B, N, T] (speaker axis folded into B)
                '''
                # [BS, K, H]
                x = self.g_block_1(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_1(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_1(x, self.num_spk)
                
                x = self.g_block_2(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_2(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_2(x, self.num_spk)
                
                x = self.g_block_3(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_3(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_3(x, self.num_spk)
                
                skip = x
                
                return x, skip
        
        self.num_stages = num_stages
        self.pos_emb = RelativePositionalEncoding(**relative_positional_encoding)
        
        # Temporal Contracting Part
        self.enc_stages = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            self.enc_stages.append(SepEncStage(**enc_stage, down_conv=True))
        
        self.bottleneck_G = SepEncStage(**enc_stage, down_conv=False)
        self.spk_split_block = SpkSplitStage(**spk_split_stage)
        
        # Temporal Expanding Part
        self.simple_fusion = torch.nn.ModuleList([])
        self.dec_stages = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            # 1x1 conv fusing [decoder path ; encoder skip] back to one width.
            self.simple_fusion.append(torch.nn.Conv1d(in_channels=simple_fusion['out_channels']*2,out_channels=simple_fusion['out_channels'], kernel_size=1))
            self.dec_stages.append(SepDecStage(**dec_stage))
    
    def forward(self, input: torch.Tensor):
        '''input: [B, N, L] projected encoder features.
        Returns (last_stage_output, each_stage_outputs) where the list holds
        the pre-upsample tensor entering each expanding stage (for aux losses).
        '''
        # feature projection
        x, _ = self.pad_signal(input)
        len_x = x.shape[-1]
        # Temporal Contracting Part
        # Relative-position index matrix at the bottleneck resolution.
        pos_seq = torch.arange(0, len_x//2**self.num_stages).long().to(x.device)
        pos_seq = pos_seq[:, None] - pos_seq[None, :]
        pos_k, _ = self.pos_emb(pos_seq)
        skip = []
        for idx in range(self.num_stages):
            x, skip_ = self.enc_stages[idx](x, pos_k)
            skip_ = self.spk_split_block(skip_)
            skip.append(skip_)
        x, _ = self.bottleneck_G(x, pos_k)
        x = self.spk_split_block(x) # B, 2F, T
        
        each_stage_outputs = []
        # Temporal Expanding Part
        for idx in range(self.num_stages):
            each_stage_outputs.append(x)
            idx_en = self.num_stages - (idx + 1)
            # NOTE(review): F.upsample is deprecated in favor of F.interpolate.
            x = torch.nn.functional.upsample(x, skip[idx_en].shape[-1])
            x = torch.cat([x,skip[idx_en]],dim=1)
            x = self.simple_fusion[idx](x)
            x, _ = self.dec_stages[idx](x, pos_k)
        
        last_stage_output = x
        return last_stage_output, each_stage_outputs
    
    def pad_signal(self, input: torch.Tensor):
        """Zero-pad the time axis to a multiple of 2**num_stages.

        Returns (padded input, number of padded frames)."""
        # (B, T) or (B, 1, T)
        if input.dim() == 1: input = input.unsqueeze(0)
        elif input.dim() not in [2, 3]: raise RuntimeError("Input can only be 2 or 3 dimensional.")
        elif input.dim() == 2: input = input.unsqueeze(1)
        L = 2**self.num_stages
        batch_size = input.size(0)
        ndim = input.size(1)
        nframe = input.size(2)
        padded_len = (nframe//L + 1)*L
        rest = 0 if nframe%L == 0 else padded_len - nframe
        if rest > 0:
            pad = torch.autograd.Variable(torch.zeros(batch_size, ndim, rest)).type(input.type()).to(input.device)
            input = torch.cat([input, pad], dim=-1)
        return input, rest
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class OutputLayer(torch.nn.Module):
    """Map separator features back to encoder width and regroup per speaker.

    Expects x with the speaker axis folded into the batch dimension
    (B*num_spks, ., .); returns [num_spks, B, N, L]. With masking=True the
    features gate the (replicated) encoder output instead of being used
    directly.
    """
    def __init__(self, in_channels: int, out_channels: int, num_spks: int, masking: bool = False):
        super().__init__()
        # feature expansion back
        self.masking = masking
        self.spe_block = Masking(in_channels, Activation_mask="ReLU", concat_opt=None)
        self.num_spks = num_spks
        # Linear -> GLU -> Linear over the channel axis (applied time-major).
        self.end_conv1x1 = torch.nn.Sequential(
            torch.nn.Linear(out_channels, 4*out_channels),
            torch.nn.GLU(),
            torch.nn.Linear(2*out_channels, in_channels))
    
    def forward(self, x: torch.Tensor, input: torch.Tensor):
        # Trim the (possibly padded) separator output to the encoder length.
        x = x[...,:input.shape[-1]]
        x = x.permute([0, 2, 1])
        x = self.end_conv1x1(x)
        x = x.permute([0, 2, 1])
        B, N, L = x.shape
        # Recover the true batch size; x's batch axis is B*num_spks here.
        B = B // self.num_spks
        
        if self.masking:
            # Replicate the encoder output once per speaker so each speaker's
            # features can gate its own copy: [B, N, L] -> [B*num_spks, N, L].
            input = input.expand(self.num_spks, B, N, L).transpose(0,1).contiguous()
            input = input.view(B*self.num_spks, N, L)
            x = self.spe_block(x, input)
        
        x = x.view(B, self.num_spks, N, L)
        # [spks, B, N, L]
        x = x.transpose(0, 1)
        return x
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class AudioDecoder(torch.nn.ConvTranspose1d):
    '''
    Decoder of the TasNet
    This module can be seen as the gradient of Conv1d with respect to its input.
    It is also known as a fractionally-strided convolution
    or a deconvolution (although it is not an actual deconvolution operation).
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """Transpose-convolve features back to a waveform.

        Accepts [B, N, L] or [N, L]; squeezes the singleton channel axis
        from the result. Raises RuntimeError for any other rank.
        """
        # BUGFIX: the original raised via 'self.__name__', which does not exist
        # on nn.Module instances (the raise itself crashed with AttributeError),
        # and the message claimed "3/4D" while the check accepts 2/3-D input.
        if x.dim() not in [2, 3]: raise RuntimeError("{} accept 2/3D tensor as input".format(type(self).__name__))
        # Add a batch axis for 2-D input before the transposed convolution.
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
        # Drop singleton dims; keep at least 1-D for a fully-squeezed result.
        x = torch.squeeze(x, dim=1) if torch.squeeze(x).dim() == 1 else torch.squeeze(x)
        return x
|
models/SepReformer/SepReformer_Large_DM_WHAMR/modules/network.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
import numpy
|
| 4 |
+
from utils.decorators import *
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LayerScale(torch.nn.Module):
    """Learnable per-channel scaling of residual branches.

    `dims` selects the broadcast shape of the scale parameter:
    1 -> (C,), 2 -> (1, C), 3 -> (1, 1, C).
    """
    def __init__(self, dims, input_size, Layer_scale_init=1.0e-5):
        super().__init__()
        if dims == 1:
            self.layer_scale = torch.nn.Parameter(torch.full((input_size,), Layer_scale_init), requires_grad=True)
        elif dims == 2:
            self.layer_scale = torch.nn.Parameter(torch.full((1, input_size), Layer_scale_init), requires_grad=True)
        elif dims == 3:
            self.layer_scale = torch.nn.Parameter(torch.full((1, 1, input_size), Layer_scale_init), requires_grad=True)

    def forward(self, x):
        return x * self.layer_scale
|
| 19 |
+
|
| 20 |
+
class Masking(torch.nn.Module):
    """Gated masking block: an activation of x (optionally after fusing x with
    the skip via a pointwise conv) multiplies the skip connection.

    Args:
        input_dim: channel width of x and skip.
        Activation_mask: 'Sigmoid' or 'ReLU' gate non-linearity.
        **options: recognized key 'concat_opt' — truthy enables the
            concat + pointwise-conv fusion path.
    """
    def __init__(self, input_dim, Activation_mask='Sigmoid', **options):
        super(Masking, self).__init__()

        self.options = options
        # ROBUSTNESS: .get() tolerates a missing 'concat_opt' key; the original
        # indexed self.options['concat_opt'] and raised KeyError when callers
        # omitted the kwarg entirely.
        if self.options.get('concat_opt'):
            self.pw_conv = torch.nn.Conv1d(input_dim*2, input_dim, 1, stride=1, padding=0)

        if Activation_mask == 'Sigmoid':
            self.gate_act = torch.nn.Sigmoid()
        elif Activation_mask == 'ReLU':
            self.gate_act = torch.nn.ReLU()

    def forward(self, x, skip):
        """Return gate(fuse(x[, skip])) * skip."""
        if self.options.get('concat_opt'):
            # Fuse along the channel axis, then project back to input_dim.
            y = torch.cat([x, skip], dim=-2)
            y = self.pw_conv(y)
        else:
            y = x
        y = self.gate_act(y) * skip

        return y
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class GCFN(torch.nn.Module):
    """Gated convolutional feed-forward network with a LayerScale residual.

    LayerNorm -> Linear (x6 width) -> depthwise conv over time -> GLU ->
    Linear back to C, wrapped as x + scale * f(x).
    """
    def __init__(self, in_channels, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        self.net1 = torch.nn.Sequential(
            torch.nn.LayerNorm(in_channels),
            torch.nn.Linear(in_channels, in_channels*6))
        self.depthwise = torch.nn.Conv1d(in_channels*6, in_channels*6, 3, padding=1, groups=in_channels*6)
        self.net2 = torch.nn.Sequential(
            torch.nn.GLU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(in_channels*3, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        hidden = self.net1(x)                                   # [B, T, 6C]
        # Conv1d wants channels first; transpose in, transpose back out.
        hidden = self.depthwise(hidden.permute(0, 2, 1).contiguous())
        hidden = self.net2(hidden.permute(0, 2, 1).contiguous())  # [B, T, C]
        return x + self.Layer_scale(hidden)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class MultiHeadAttention(torch.nn.Module):
    """
    Multi-Head Attention layer with optional relative positional bias.

    :param int n_head: the number of heads
    :param int in_channels: the number of features
    :param float dropout_rate: dropout rate
    :param float Layer_scale_init: initial value of the LayerScale residual scale
    """
    def __init__(self, n_head: int, in_channels: int, dropout_rate: float, Layer_scale_init=1.0e-5):
        super().__init__()
        assert in_channels % n_head == 0
        self.d_k = in_channels // n_head # We assume d_v always equals d_k
        self.h = n_head
        self.layer_norm = torch.nn.LayerNorm(in_channels)
        self.linear_q = torch.nn.Linear(in_channels, in_channels)
        self.linear_k = torch.nn.Linear(in_channels, in_channels)
        self.linear_v = torch.nn.Linear(in_channels, in_channels)
        self.linear_out = torch.nn.Linear(in_channels, in_channels)
        # Last-computed attention map is kept for inspection/debugging.
        self.attn = None
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)
    
    def forward(self, x, pos_k, mask):
        """
        Compute 'Scaled Dot Product Attention'.
        :param torch.Tensor x: input features (batch, time1, d_model)
        :param torch.Tensor pos_k: relative-position key embeddings, or None
        :param torch.Tensor mask: (batch, time1, time2), or None
        :return torch.Tensor: attentioned and transformed `value` (batch, time1, d_model)
             weighted by the query dot key attention (batch, head, time1, time2)
        """
        n_batch = x.size(0)
        # Pre-norm before the projections.
        x = self.layer_norm(x)
        q = self.linear_q(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
        k = self.linear_k(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
        v = self.linear_v(x).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2) # (batch, head, time2, d_k)
        v = v.transpose(1, 2) # (batch, head, time2, d_k)
        # Content-content term.
        A = torch.matmul(q, k.transpose(-2, -1))
        reshape_q = q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0,1)
        if pos_k is not None:
            # Content-position term from the relative positional embeddings.
            B = torch.matmul(reshape_q, pos_k.transpose(-2, -1))
            B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1))
            scores = (A + B) / math.sqrt(self.d_k)
        else:
            scores = A / math.sqrt(self.d_k)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
            # Fill masked positions with the dtype's minimum before softmax,
            # then zero them again afterwards so they contribute nothing.
            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
        return self.Layer_scale(self.dropout(self.linear_out(x))) # (batch, time1, d_model)
|
| 125 |
+
|
| 126 |
+
class EGA(torch.nn.Module):
    """Efficient global attention: self-attend at a downsampled rate, upsample
    the result back, and add it gated by a learned sigmoid of the input.
    """
    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'self_attn': MultiHeadAttention(
                n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            'linear': torch.nn.Sequential(
                torch.nn.LayerNorm(normalized_shape=in_channels),
                torch.nn.Linear(in_features=in_channels, out_features=in_channels),
                torch.nn.Sigmoid())
        })

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """
        Compute encoded features.
        :param torch.Tensor x: encoded source features (batch, size, max_time_in)
        :param torch.Tensor pos_k: relative-position embeddings; pos_k.shape[0]
            sets the downsampled attention length
        :rtype: torch.Tensor (batch, max_time_in, size)
        """
        # Pool the time axis down to the attention length implied by pos_k.
        down_len = pos_k.shape[0]
        x_down = torch.nn.functional.adaptive_avg_pool1d(input=x, output_size=down_len)
        x = x.permute([0, 2, 1])
        x_down = x_down.permute([0, 2, 1])
        x_down = self.block['self_attn'](x_down, pos_k, None)
        x_down = x_down.permute([0, 2, 1])
        # MODERNIZED: F.upsample is deprecated; F.interpolate with the default
        # mode='nearest' is the drop-in replacement.
        x_downup = torch.nn.functional.interpolate(input=x_down, size=x.shape[1])
        x_downup = x_downup.permute([0, 2, 1])
        # Residual add, gated per position/channel by a sigmoid of x.
        x = x + self.block['linear'](x) * x_downup

        return x
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class CLA(torch.nn.Module):
    """Convolutional Local Attention block.

    LayerNorm -> Linear (x2 width) -> GLU -> depthwise Conv1d -> Linear
    (x2 width) -> BatchNorm -> GELU/Linear/Dropout projection, added back
    to the input through a learnable LayerScale residual.
    """

    def __init__(self, in_channels, kernel_size, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        hidden = 2 * in_channels
        self.layer_norm = torch.nn.LayerNorm(in_channels)
        self.linear1 = torch.nn.Linear(in_channels, hidden)
        self.GLU = torch.nn.GLU()
        self.dw_conv_1d = torch.nn.Conv1d(in_channels, in_channels, kernel_size, padding='same', groups=in_channels)
        self.linear2 = torch.nn.Linear(in_channels, hidden)
        self.BN = torch.nn.BatchNorm1d(hidden)
        self.linear3 = torch.nn.Sequential(
            torch.nn.GELU(),
            torch.nn.Linear(hidden, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        """Apply the local block; input/output shape is (B, T, F)."""
        out = self.GLU(self.linear1(self.layer_norm(x)))
        # Depthwise conv operates channel-first: (B, F, T).
        out = self.dw_conv_1d(out.transpose(1, 2)).transpose(1, 2)
        out = self.linear2(out)
        # BatchNorm1d also expects channel-first layout.
        out = self.BN(out.transpose(1, 2)).transpose(1, 2)
        out = self.linear3(out)
        # Scaled residual connection.
        return x + self.Layer_scale(out)
|
| 188 |
+
|
| 189 |
+
class GlobalBlock(torch.nn.Module):
    """Global modelling block: EGA attention followed by a GCFN feed-forward."""

    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        submodules = {
            'ega': EGA(
                num_mha_heads=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            'gcfn': GCFN(in_channels=in_channels, dropout_rate=dropout_rate),
        }
        self.block = torch.nn.ModuleDict(submodules)

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """
        Compute encoded features.

        :param torch.Tensor x: input features, channel-first (batch, channels, time)
        :param torch.Tensor pos_k: relative positional encoding consumed by EGA
        :rtype: torch.Tensor, channel-first (batch, channels, time)
        """
        # EGA returns (B, T, C); GCFN keeps that layout.
        out = self.block['ega'](x, pos_k)
        out = self.block['gcfn'](out)
        # Restore channel-first layout for the caller.
        return out.permute([0, 2, 1])
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class LocalBlock(torch.nn.Module):
    """Local modelling block: CLA convolution module followed by a GCFN feed-forward."""

    def __init__(self, in_channels: int, kernel_size: int, dropout_rate: float):
        super().__init__()
        submodules = {
            'cla': CLA(in_channels, kernel_size, dropout_rate),
            'gcfn': GCFN(in_channels, dropout_rate),
        }
        self.block = torch.nn.ModuleDict(submodules)

    def forward(self, x: torch.Tensor):
        """Run CLA then GCFN on *x* and return the result."""
        out = self.block['cla'](x)
        out = self.block['gcfn'](out)
        return out
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class SpkAttention(torch.nn.Module):
    """Inter-speaker attention.

    Reshapes a batch that stacks the speakers along the batch dimension so
    that self-attention runs ACROSS the speaker axis at every time step,
    then applies a GCFN feed-forward on the time axis.
    """

    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        # Attention over the (tiny) speaker axis; no positional encoding needed.
        self.self_attn = MultiHeadAttention(n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate)
        self.feed_forward = GCFN(in_channels=in_channels, dropout_rate=dropout_rate)

    def forward(self, x: torch.Tensor, num_spk: int):
        """
        Compute speaker-attended features.

        :param torch.Tensor x: features (batch*num_spk, channels, time);
            B below must be divisible by num_spk
        :param int num_spk: number of speakers stacked in the batch dimension
        :rtype: torch.Tensor, same shape as the input
        """
        B, F, T = x.shape
        # (B, F, T) -> (B/num_spk, num_spk, F, T): recover the speaker axis.
        x = x.view(B//num_spk, num_spk, F, T).contiguous()
        # -> (B/num_spk, T, num_spk, F): move time next to batch.
        x = x.permute([0, 3, 1, 2]).contiguous()
        # -> (B/num_spk * T, num_spk, F): each time step is a "sequence"
        # of length num_spk, so attention mixes information across speakers.
        x = x.view(-1, num_spk, F).contiguous()
        x = x + self.self_attn(x, None, None)
        # Invert the reshape/permute chain back to (B, F, T).
        x = x.view(B//num_spk, T, num_spk, F).contiguous()
        x = x.permute([0, 2, 3, 1]).contiguous()
        x = x.view(B, F, T).contiguous()
        # GCFN expects (B, T, F); transpose, apply, transpose back.
        x = x.permute([0, 2, 1])
        x = self.feed_forward(x)
        x = x.permute([0, 2, 1])
        return x
|
models/SepReformer/SepReformer_Large_DM_WSJ0/configs.yaml
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
config:
|
| 2 |
+
dataset:
|
| 3 |
+
max_len : 32000
|
| 4 |
+
sampling_rate: 8000
|
| 5 |
+
scp_dir: "data/scp_ss_8k"
|
| 6 |
+
train:
|
| 7 |
+
mixture: "tr_mix.scp"
|
| 8 |
+
spk1: "tr_s1.scp"
|
| 9 |
+
spk2: "tr_s2.scp"
|
| 10 |
+
dynamic_mixing: true
|
| 11 |
+
valid:
|
| 12 |
+
mixture: "cv_mix.scp"
|
| 13 |
+
spk1: "cv_s1.scp"
|
| 14 |
+
spk2: "cv_s2.scp"
|
| 15 |
+
test:
|
| 16 |
+
mixture: "tt_mix.scp"
|
| 17 |
+
spk1: "tt_s1.scp"
|
| 18 |
+
spk2: "tt_s2.scp"
|
| 19 |
+
dataloader:
|
| 20 |
+
batch_size: 2
|
| 21 |
+
pin_memory: false
|
| 22 |
+
num_workers: 12
|
| 23 |
+
drop_last: false
|
| 24 |
+
model:
|
| 25 |
+
num_stages: &var_model_num_stages 4 # R
|
| 26 |
+
num_spks: &var_model_num_spks 2
|
| 27 |
+
module_audio_enc:
|
| 28 |
+
in_channels: 1
|
| 29 |
+
out_channels: &var_model_audio_enc_out_channels 256
|
| 30 |
+
kernel_size: &var_model_audio_enc_kernel_size 16 # L
|
| 31 |
+
stride: &var_model_audio_enc_stride 4 # S
|
| 32 |
+
groups: 1
|
| 33 |
+
bias: false
|
| 34 |
+
module_feature_projector:
|
| 35 |
+
num_channels: *var_model_audio_enc_out_channels
|
| 36 |
+
in_channels: *var_model_audio_enc_out_channels
|
| 37 |
+
out_channels: &feature_projector_out_channels 256 # F
|
| 38 |
+
kernel_size: 1
|
| 39 |
+
bias: false
|
| 40 |
+
module_separator:
|
| 41 |
+
num_stages: *var_model_num_stages
|
| 42 |
+
relative_positional_encoding:
|
| 43 |
+
in_channels: *feature_projector_out_channels
|
| 44 |
+
num_heads: 8
|
| 45 |
+
maxlen: 2000
|
| 46 |
+
embed_v: false
|
| 47 |
+
enc_stage:
|
| 48 |
+
global_blocks:
|
| 49 |
+
in_channels: *feature_projector_out_channels
|
| 50 |
+
num_mha_heads: 8
|
| 51 |
+
dropout_rate: 0.1
|
| 52 |
+
local_blocks:
|
| 53 |
+
in_channels: *feature_projector_out_channels
|
| 54 |
+
kernel_size: 65
|
| 55 |
+
dropout_rate: 0.1
|
| 56 |
+
down_conv_layer:
|
| 57 |
+
in_channels: *feature_projector_out_channels
|
| 58 |
+
samp_kernel_size: &var_model_samp_kernel_size 5
|
| 59 |
+
spk_split_stage:
|
| 60 |
+
in_channels: *feature_projector_out_channels
|
| 61 |
+
num_spks: *var_model_num_spks
|
| 62 |
+
simple_fusion:
|
| 63 |
+
out_channels: *feature_projector_out_channels
|
| 64 |
+
dec_stage:
|
| 65 |
+
num_spks: *var_model_num_spks
|
| 66 |
+
global_blocks:
|
| 67 |
+
in_channels: *feature_projector_out_channels
|
| 68 |
+
num_mha_heads: 8
|
| 69 |
+
dropout_rate: 0.1
|
| 70 |
+
local_blocks:
|
| 71 |
+
in_channels: *feature_projector_out_channels
|
| 72 |
+
kernel_size: 65
|
| 73 |
+
dropout_rate: 0.1
|
| 74 |
+
spk_attention:
|
| 75 |
+
in_channels: *feature_projector_out_channels
|
| 76 |
+
num_mha_heads: 8
|
| 77 |
+
dropout_rate: 0.1
|
| 78 |
+
module_output_layer:
|
| 79 |
+
in_channels: *var_model_audio_enc_out_channels
|
| 80 |
+
out_channels: *feature_projector_out_channels
|
| 81 |
+
num_spks: *var_model_num_spks
|
| 82 |
+
module_audio_dec:
|
| 83 |
+
in_channels: *var_model_audio_enc_out_channels
|
| 84 |
+
out_channels: 1
|
| 85 |
+
kernel_size: *var_model_audio_enc_kernel_size
|
| 86 |
+
stride: *var_model_audio_enc_stride
|
| 87 |
+
bias: false
|
| 88 |
+
criterion: ### Ref: https://pytorch.org/docs/stable/nn.html#loss-functions
|
| 89 |
+
name: ["PIT_SISNR_mag", "PIT_SISNR_time", "PIT_SISNRi", "PIT_SDRi"] ### Choose a torch.nn's loss function class(=attribute) e.g. ["L1Loss", "MSELoss", "CrossEntropyLoss", ...] / You can also build your optimizer :)
|
| 90 |
+
PIT_SISNR_mag:
|
| 91 |
+
frame_length: 512
|
| 92 |
+
frame_shift: 128
|
| 93 |
+
window: 'hann'
|
| 94 |
+
num_stages: *var_model_num_stages
|
| 95 |
+
num_spks: *var_model_num_spks
|
| 96 |
+
scale_inv: true
|
| 97 |
+
mel_opt: false
|
| 98 |
+
PIT_SISNR_time:
|
| 99 |
+
num_spks: *var_model_num_spks
|
| 100 |
+
scale_inv: true
|
| 101 |
+
PIT_SISNRi:
|
| 102 |
+
num_spks: *var_model_num_spks
|
| 103 |
+
scale_inv: true
|
| 104 |
+
PIT_SDRi:
|
| 105 |
+
dump: 0
|
| 106 |
+
optimizer: ### Ref: https://pytorch.org/docs/stable/optim.html#algorithms
|
| 107 |
+
name: ["AdamW"] ### Choose a torch.optim's class(=attribute) e.g. ["Adam", "AdamW", "SGD", ...] / You can also build your optimizer :)
|
| 108 |
+
AdamW:
|
| 109 |
+
lr: 2.0e-4
|
| 110 |
+
weight_decay: 1.0e-2
|
| 111 |
+
scheduler: ### Ref(+ find "How to adjust learning rate"): https://pytorch.org/docs/stable/optim.html#algorithms
|
| 112 |
+
name: ["ReduceLROnPlateau", "WarmupConstantSchedule"] ### Choose a torch.optim.lr_scheduler's class(=attribute) e.g. ["StepLR", "ReduceLROnPlateau", "Custom"] / You can also build your scheduler :)
|
| 113 |
+
ReduceLROnPlateau:
|
| 114 |
+
mode: "min"
|
| 115 |
+
min_lr: 1.0e-10
|
| 116 |
+
factor: 0.8
|
| 117 |
+
patience: 2
|
| 118 |
+
WarmupConstantSchedule:
|
| 119 |
+
warmup_steps: 1000
|
| 120 |
+
check_computations:
|
| 121 |
+
dummy_len: 16000
|
| 122 |
+
engine:
|
| 123 |
+
max_epoch: 200
|
| 124 |
+
gpuid: "0" ### "0"(single-gpu) or "0, 1" (multi-gpu)
|
| 125 |
+
mvn: false
|
| 126 |
+
clip_norm: 5
|
| 127 |
+
start_scheduling: 50
|
| 128 |
+
test_epochs: [50, 80, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 199]
|
models/SepReformer/SepReformer_Large_DM_WSJ0/dataset.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
import librosa as audio_lib
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from utils import util_dataset
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
from loguru import logger
|
| 10 |
+
from torch.utils.data import Dataset, DataLoader
|
| 11 |
+
|
| 12 |
+
@logger_wraps()
def get_dataloaders(args, dataset_config, loader_config):
    """Build a DataLoader per partition.

    Only the "test" partition is built when the engine mode contains "test";
    otherwise train/valid/test are all created. Dynamic mixing is enabled
    only for the training partition (and only if the config requests it).
    """
    partitions = ["test"] if "test" in args.engine_mode else ["train", "valid", "test"]

    def _build(partition):
        part_cfg = dataset_config[partition]
        mix_path = os.path.join(dataset_config["scp_dir"], part_cfg['mixture'])
        # Every key starting with 'spk' points at one source scp file.
        src_paths = [os.path.join(dataset_config["scp_dir"], part_cfg[spk_key])
                     for spk_key in part_cfg if spk_key.startswith('spk')]
        use_dm = part_cfg["dynamic_mixing"] if partition == 'train' else False
        dataset = MyDataset(
            max_len=dataset_config['max_len'],
            fs=dataset_config['sampling_rate'],
            partition=partition,
            wave_scp_srcs=src_paths,
            wave_scp_mix=mix_path,
            dynamic_mixing=use_dm)
        return DataLoader(
            dataset=dataset,
            # Test evaluates one utterance at a time.
            batch_size=1 if partition == 'test' else loader_config["batch_size"],
            shuffle=True,  # only train: (partition == 'train') / all: True
            pin_memory=loader_config["pin_memory"],
            num_workers=loader_config["num_workers"],
            drop_last=loader_config["drop_last"],
            collate_fn=_collate)

    return {partition: _build(partition) for partition in partitions}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _collate(egs):
|
| 41 |
+
"""
|
| 42 |
+
Transform utterance index into a minbatch
|
| 43 |
+
|
| 44 |
+
Arguments:
|
| 45 |
+
index: a list type [{},{},{}]
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
input_sizes: a tensor correspond to utterance length
|
| 49 |
+
input_feats: packed sequence to feed networks
|
| 50 |
+
source_attr/target_attr: dictionary contains spectrogram/phase needed in loss computation
|
| 51 |
+
"""
|
| 52 |
+
def __prepare_target_rir(dict_lsit, index):
|
| 53 |
+
return torch.nn.utils.rnn.pad_sequence([torch.tensor(d["src"][index], dtype=torch.float32) for d in dict_lsit], batch_first=True)
|
| 54 |
+
if type(egs) is not list: raise ValueError("Unsupported index type({})".format(type(egs)))
|
| 55 |
+
num_spks = 2 # you need to set this paramater by yourself
|
| 56 |
+
dict_list = sorted([eg for eg in egs], key=lambda x: x['num_sample'], reverse=True)
|
| 57 |
+
mixture = torch.nn.utils.rnn.pad_sequence([torch.tensor(d['mix'], dtype=torch.float32) for d in dict_list], batch_first=True)
|
| 58 |
+
src = [__prepare_target_rir(dict_list, index) for index in range(num_spks)]
|
| 59 |
+
input_sizes = torch.tensor([d['num_sample'] for d in dict_list], dtype=torch.float32)
|
| 60 |
+
key = [d['key'] for d in dict_list]
|
| 61 |
+
return input_sizes, mixture, src, key
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@logger_wraps()
class MyDataset(Dataset):
    """Speech-separation dataset over scp-indexed wav files.

    Each item is a dict with the mixture waveform, the per-speaker source
    waveforms, the sample count, and the utterance key. With
    dynamic_mixing=True the mixture is synthesized on the fly from two
    randomly paired, RMS-normalized, randomly gained sources; otherwise the
    pre-mixed wav file is loaded directly.
    """

    def __init__(self, max_len, fs, partition, wave_scp_srcs, wave_scp_mix, dynamic_mixing, speed_list=None):
        # NOTE(review): Dataset.__init__ is never called (no super().__init__());
        # harmless for torch.utils.data.Dataset today, but worth confirming.
        # speed_list is accepted but unused here — presumably reserved for
        # speed-perturbation augmentation; verify against other variants.
        self.partition = partition
        for wave_scp_src in wave_scp_srcs:
            if not os.path.exists(wave_scp_src): raise FileNotFoundError(f"Could not find file {wave_scp_src}")
        self.max_len = max_len          # max samples per training/valid chunk
        self.fs = fs                    # target sampling rate for librosa.load
        # One key->path dict per speaker, plus one for the mixtures.
        self.wave_dict_srcs = [util_dataset.parse_scps(wave_scp_src) for wave_scp_src in wave_scp_srcs]
        self.wave_dict_mix = util_dataset.parse_scps(wave_scp_mix)
        self.wave_keys = list(self.wave_dict_mix.keys())
        logger.info(f"Create MyDataset for {wave_scp_mix} with {len(self.wave_dict_mix)} utterances")
        self.dynamic_mixing = dynamic_mixing

    def __len__(self):
        # Number of utterances, one per mixture key.
        return len(self.wave_dict_mix)

    def __contains__(self, key):
        return key in self.wave_dict_mix

    def _dynamic_mixing(self, key):
        """Synthesize a mixture from the source of `key` and a random partner.

        Returns (samps_mix, samps_src): the summed mixture and the list of
        individual (gained, length-matched) source signals.
        """
        def __match_length(wav, len_data) :
            # Random crop of wav down to len_data samples.
            # NOTE: requires len(wav) >= len_data (randint upper bound).
            leftover = len(wav) - len_data
            idx = random.randint(0,leftover)
            wav = wav[idx:idx+len_data]
            return wav

        samps_src = []
        src_len = []
        # dyanmic source choice
        # checking whether it is the same speaker
        # NOTE(review): the key.split('_')[1]/[3] slices assume WSJ0-style
        # mixture keys of the form "<id1>_<snr1>_<id2>_<snr2>" where the
        # first three chars of a field identify the speaker — confirm the
        # scp key format before reusing on another corpus.
        while True:
            key_random = random.choice(list(self.wave_dict_srcs[0].keys()))
            tmp1 = key.split('_')[1][:3] != key_random.split('_')[3][:3]
            tmp2 = key.split('_')[3][:3] != key_random.split('_')[1][:3]
            if tmp1 and tmp2: break

        # Randomly pick which scp list supplies each of the two sources.
        idx1, idx2 = (0, 1) if random.random() > 0.5 else (1, 0)
        files = [self.wave_dict_srcs[idx1][key], self.wave_dict_srcs[idx2][key_random]]

        # load
        for idx, file in enumerate(files):
            if not os.path.exists(file): raise FileNotFoundError("Input file {} do not exists!".format(file))
            samps_tmp, _ = audio_lib.load(file, sr=self.fs)

            # RMS-normalize every source to the first source's level.
            if idx == 0: ref_rms = np.sqrt(np.mean(np.square(samps_tmp)))
            curr_rms = np.sqrt(np.mean(np.square(samps_tmp)))

            norm_factor = ref_rms / curr_rms
            samps_tmp *= norm_factor

            # mixing with random gains
            # Uniform gain in [-5, +5] dB (10^(-x/20)).
            gain = pow(10,-random.uniform(-5,5)/20)
            samps_tmp = np.array(torch.tensor(samps_tmp))
            samps_src.append(gain*samps_tmp)
            src_len.append(len(samps_tmp))

        # matching the audio length
        min_len = min(src_len)

        # add noise source dynamically if needed
        samps_src = [__match_length(s, min_len) for s in samps_src]
        samps_mix = sum(samps_src)

        # ! truncated along to the sample Length "L"
        # Trim so the length is a multiple of 4 (presumably the encoder
        # stride S=4 — confirm against the model config).
        if len(samps_mix)%4 != 0:
            remains = len(samps_mix)%4
            samps_mix = samps_mix[:-remains]
            samps_src = [s[:-remains] for s in samps_src]

        # For train/valid, crop a random max_len-sample window.
        if self.partition != "test":
            if len(samps_mix) > self.max_len:
                start = random.randint(0, len(samps_mix)-self.max_len)
                samps_mix = samps_mix[start:start+self.max_len]
                samps_src = [s[start:start+self.max_len] for s in samps_src]
        return samps_mix, samps_src

    def _direct_load(self, key):
        """Load the pre-mixed mixture and its sources for `key` from disk."""
        samps_src = []
        files = [wave_dict_src[key] for wave_dict_src in self.wave_dict_srcs]
        for file in files:
            if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
            samps_tmp, _ = audio_lib.load(file, sr=self.fs)
            samps_src.append(samps_tmp)

        file = self.wave_dict_mix[key]
        if not os.path.exists(file): raise FileNotFoundError(f"Input file {file} do not exists!")
        samps_mix, _ = audio_lib.load(file, sr=self.fs)

        # Truncate samples as needed
        # Same multiple-of-4 trimming as in _dynamic_mixing.
        if len(samps_mix) % 4 != 0:
            remains = len(samps_mix) % 4
            samps_mix = samps_mix[:-remains]
            samps_src = [s[:-remains] for s in samps_src]

        # For train/valid, crop a random max_len-sample window.
        if self.partition != "test":
            if len(samps_mix) > self.max_len:
                start = random.randint(0,len(samps_mix)-self.max_len)
                samps_mix = samps_mix[start:start+self.max_len]
                samps_src = [s[start:start+self.max_len] for s in samps_src]

        return samps_mix, samps_src

    def __getitem__(self, index):
        # Resolve the utterance key, validate it, and dispatch to either
        # dynamic mixing or direct loading.
        key = self.wave_keys[index]
        if any(key not in self.wave_dict_srcs[i] for i in range(len(self.wave_dict_srcs))) or key not in self.wave_dict_mix: raise KeyError(f"Could not find utterance {key}")
        samps_mix, samps_src = self._dynamic_mixing(key) if self.dynamic_mixing else self._direct_load(key)
        return {"num_sample": samps_mix.shape[0], "mix": samps_mix, "src": samps_src, "key": key}
|
models/SepReformer/SepReformer_Large_DM_WSJ0/engine.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import csv
|
| 4 |
+
import time
|
| 5 |
+
import soundfile as sf
|
| 6 |
+
|
| 7 |
+
from loguru import logger
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from utils import util_engine, functions
|
| 10 |
+
from utils.decorators import *
|
| 11 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@logger_wraps()
class Engine(object):
    """Training/validation/test driver for the separation model.

    Owns the model, data loaders, criterions, optimizer, and schedulers;
    `run()` dispatches to test-only mode or the full training loop based on
    `engine_mode`.
    """

    def __init__(self, args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device):
        ''' Default setting '''
        self.engine_mode = args.engine_mode
        self.out_wav_dir = args.out_wav_dir
        self.config = config
        self.gpuid = gpuid
        self.device = device
        self.model = model.to(self.device)
        self.dataloaders = dataloaders # self.dataloaders['train'] or ['valid'] or ['test']
        self.PIT_SISNR_mag_loss, self.PIT_SISNR_time_loss, self.PIT_SISNRi_loss, self.PIT_SDRi_loss = criterions
        self.main_optimizer = optimizers[0]
        self.main_scheduler, self.warmup_scheduler = schedulers

        self.pretrain_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "pretrain_weights")
        os.makedirs(self.pretrain_weights_path, exist_ok=True)
        self.scratch_weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "scratch_weights")
        os.makedirs(self.scratch_weights_path, exist_ok=True)

        # Prefer pretrained weights when any checkpoint file is present.
        # FIX: the tuple was ('.pt', '.pt', '.pkl') — '.pt' duplicated where
        # '.pth' was intended, so .pth checkpoints were never detected.
        self.checkpoint_path = self.pretrain_weights_path if any(file.endswith(('.pt', '.pth', '.pkl')) for file in os.listdir(self.pretrain_weights_path)) else self.scratch_weights_path
        self.start_epoch = util_engine.load_last_checkpoint_n_get_epoch(self.checkpoint_path, self.model, self.main_optimizer, location=self.device)

        # Logging
        util_engine.model_params_mac_summary(
            model=self.model,
            input=torch.randn(1, self.config['check_computations']['dummy_len']).to(self.device),
            dummy_input=torch.rand(1, self.config['check_computations']['dummy_len']).to(self.device),
            metrics=['ptflops', 'thop', 'torchinfo']
            # metrics=['ptflops']
        )

        logger.info(f"Clip gradient by 2-norm {self.config['engine']['clip_norm']}")

    @logger_wraps()
    def _train(self, dataloader, epoch):
        """One training epoch; returns (time loss, freq loss, num batches)."""
        self.model.train()
        tot_loss_freq = [0 for _ in range(self.model.num_stages)]
        tot_loss_time, num_batch = 0, 0
        pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:25}{r_bar}{bar:-10b}', colour="YELLOW", dynamic_ncols=True)
        for input_sizes, mixture, src, _ in dataloader:
            nnet_input = mixture
            nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
            num_batch += 1
            pbar.update(1)
            # Scheduler learning rate for warm-up (Iteration-based update for transformers)
            if epoch == 1: self.warmup_scheduler.step()
            nnet_input = nnet_input.to(self.device)
            self.main_optimizer.zero_grad()
            estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
            # Per-stage frequency-domain losses (multi-stage supervision).
            cur_loss_s_bn = []
            for idx, estim_src_value in enumerate(estim_src_bn):
                cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
                tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
            cur_loss_s = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
            tot_loss_time += cur_loss_s.item() / self.config['model']['num_spks']
            # Auxiliary-loss weight: constant 0.4 until epoch 100, then
            # decayed by 0.8 every 5 epochs.
            alpha = 0.4 * 0.8**(1+(epoch-101)//5) if epoch > 100 else 0.4
            cur_loss = (1-alpha) * cur_loss_s + alpha * sum(cur_loss_s_bn) / len(cur_loss_s_bn)
            cur_loss = cur_loss / self.config['model']['num_spks']
            cur_loss.backward()
            if self.config['engine']['clip_norm']: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['engine']['clip_norm'])
            self.main_optimizer.step()
            dict_loss = {"T_Loss": tot_loss_time / num_batch}
            dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
            pbar.set_postfix(dict_loss)
        pbar.close()
        tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
        return tot_loss_time / num_batch, tot_loss_freq / num_batch, num_batch

    @logger_wraps()
    def _validate(self, dataloader):
        """One validation pass; returns (time loss, freq loss, num batches)."""
        self.model.eval()
        tot_loss_freq = [0 for _ in range(self.model.num_stages)]
        tot_loss_time, num_batch = 0, 0
        pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="RED", dynamic_ncols=True)
        with torch.inference_mode():
            for input_sizes, mixture, src, _ in dataloader:
                nnet_input = mixture
                nnet_input = functions.apply_cmvn(nnet_input) if self.config['engine']['mvn'] else nnet_input
                nnet_input = nnet_input.to(self.device)
                num_batch += 1
                pbar.update(1)
                estim_src, estim_src_bn = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
                cur_loss_s_bn = []
                for idx, estim_src_value in enumerate(estim_src_bn):
                    cur_loss_s_bn.append(self.PIT_SISNR_mag_loss(estims=estim_src_value, idx=idx, input_sizes=input_sizes, target_attr=src))
                    tot_loss_freq[idx] += cur_loss_s_bn[idx].item() / (self.config['model']['num_spks'])
                cur_loss_s_SDR = self.PIT_SISNR_time_loss(estims=estim_src, input_sizes=input_sizes, target_attr=src)
                tot_loss_time += cur_loss_s_SDR.item() / self.config['model']['num_spks']
                dict_loss = {"T_Loss":tot_loss_time / num_batch}
                dict_loss.update({'F_Loss_' + str(idx): loss / num_batch for idx, loss in enumerate(tot_loss_freq)})
                pbar.set_postfix(dict_loss)
        pbar.close()
        tot_loss_freq = sum(tot_loss_freq) / len(tot_loss_freq)
        return tot_loss_time / num_batch, tot_loss_freq / num_batch, num_batch

    @logger_wraps()
    def _test(self, dataloader, wav_dir=None):
        """Evaluate SISNRi/SDRi per utterance, write per-key CSVs, and
        optionally dump separated wavs; returns (SISNRi, SDRi, num batches)."""
        self.model.eval()
        total_loss_SISNRi, total_loss_SDRi, num_batch = 0, 0, 0
        pbar = tqdm(total=len(dataloader), unit='batches', bar_format='{l_bar}{bar:5}{r_bar}{bar:-10b}', colour="grey", dynamic_ncols=True)
        with torch.inference_mode():
            csv_file_name_sisnr = os.path.join(os.path.dirname(__file__),'test_SISNRi_value.csv')
            csv_file_name_sdr = os.path.join(os.path.dirname(__file__),'test_SDRi_value.csv')
            with open(csv_file_name_sisnr, 'w', newline='') as csvfile_sisnr, open(csv_file_name_sdr, 'w', newline='') as csvfile_sdr:
                idx = 0
                writer_sisnr = csv.writer(csvfile_sisnr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
                writer_sdr = csv.writer(csvfile_sdr, quotechar='|', quoting=csv.QUOTE_MINIMAL)
                for input_sizes, mixture, src, key in dataloader:
                    if len(key) > 1:
                        # FIX: `raise("...")` raises a str, which is itself a
                        # TypeError at runtime; raise a proper exception.
                        raise ValueError("batch size is not one!!")
                    nnet_input = mixture.to(self.device)
                    num_batch += 1
                    pbar.update(1)
                    estim_src, _ = torch.nn.parallel.data_parallel(self.model, nnet_input, device_ids=self.gpuid)
                    cur_loss_SISNRi, cur_loss_SISNRi_src = self.PIT_SISNRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src, eps=1.0e-15)
                    total_loss_SISNRi += cur_loss_SISNRi.item() / self.config['model']['num_spks']
                    cur_loss_SDRi, cur_loss_SDRi_src = self.PIT_SDRi_loss(estims=estim_src, mixture=mixture, input_sizes=input_sizes, target_attr=src)
                    total_loss_SDRi += cur_loss_SDRi.item() / self.config['model']['num_spks']
                    writer_sisnr.writerow([key[0][:-4]] + [cur_loss_SISNRi_src[i].item() for i in range(self.config['model']['num_spks'])])
                    writer_sdr.writerow([key[0][:-4]] + [cur_loss_SDRi_src[i].item() for i in range(self.config['model']['num_spks'])])
                    if self.engine_mode == "test_save":
                        # `is None` is the correct identity check.
                        if wav_dir is None: wav_dir = os.path.join(os.path.dirname(__file__),"wav_out")
                        if wav_dir and not os.path.exists(wav_dir): os.makedirs(wav_dir)
                        mixture = torch.squeeze(mixture).cpu().data.numpy()
                        # Peak-normalize to 0.5 before writing 8 kHz wavs.
                        sf.write(os.path.join(wav_dir,key[0][:-4]+str(idx)+'_mixture.wav'), 0.5*mixture/max(abs(mixture)), 8000)
                        for i in range(self.config['model']['num_spks']):
                            src = torch.squeeze(estim_src[i]).cpu().data.numpy()
                            sf.write(os.path.join(wav_dir,key[0][:-4]+str(idx)+'_out_'+str(i)+'.wav'), 0.5*src/max(abs(src)), 8000)
                        idx += 1
                    dict_loss = {"SiSNRi": total_loss_SISNRi/num_batch, "SDRi": total_loss_SDRi/num_batch}
                    pbar.set_postfix(dict_loss)
        pbar.close()
        return total_loss_SISNRi/num_batch, total_loss_SDRi/num_batch, num_batch

    @logger_wraps()
    def run(self):
        """Entry point: test-only mode, or full train/validate/test loop."""
        with torch.cuda.device(self.device):
            writer_src = SummaryWriter(os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/tensorboard"))
            if "test" in self.engine_mode:
                on_test_start = time.time()
                test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'], self.out_wav_dir)
                on_test_end = time.time()
                logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
                logger.info(f"Testing done!")
            else:
                start_time = time.time()
                # Seed the "best" validation loss from the resumed model.
                if self.start_epoch > 1:
                    init_loss_time, init_loss_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
                else:
                    init_loss_time, init_loss_freq = 0, 0
                end_time = time.time()
                logger.info(f"[INIT] Loss(time/mini-batch) \n - Epoch {self.start_epoch:2d}: Loss_t = {init_loss_time:.4f} dB | Loss_f = {init_loss_freq:.4f} dB | Speed = ({end_time-start_time:.2f}s)")
                # FIX: this was reset inside the epoch loop, which discarded
                # the running best every epoch and made best-checkpoint
                # tracking meaningless; initialize it once before the loop.
                valid_loss_best = init_loss_time
                for epoch in range(self.start_epoch, self.config['engine']['max_epoch']):
                    train_start_time = time.time()
                    train_loss_src_time, train_loss_src_freq, train_num_batch = self._train(self.dataloaders['train'], epoch)
                    train_end_time = time.time()
                    valid_start_time = time.time()
                    valid_loss_src_time, valid_loss_src_freq, valid_num_batch = self._validate(self.dataloaders['valid'])
                    valid_end_time = time.time()
                    if epoch > self.config['engine']['start_scheduling']: self.main_scheduler.step(valid_loss_src_time)
                    logger.info(f"[TRAIN] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {train_loss_src_time:.4f} dB | Loss_f = {train_loss_src_freq:.4f} dB | Speed = ({train_end_time - train_start_time:.2f}s/{train_num_batch:d})")
                    logger.info(f"[VALID] Loss(time/mini-batch) \n - Epoch {epoch:2d}: Loss_t = {valid_loss_src_time:.4f} dB | Loss_f = {valid_loss_src_freq:.4f} dB | Speed = ({valid_end_time - valid_start_time:.2f}s/{valid_num_batch:d})")
                    if epoch in self.config['engine']['test_epochs']:
                        on_test_start = time.time()
                        test_loss_src_time_1, test_loss_src_time_2, test_num_batch = self._test(self.dataloaders['test'])
                        on_test_end = time.time()
                        logger.info(f"[TEST] Loss(time/mini-batch) \n - Epoch {epoch:2d}: SISNRi = {test_loss_src_time_1:.4f} dB | SDRi = {test_loss_src_time_2:.4f} dB | Speed = ({on_test_end - on_test_start:.2f}s/{test_num_batch:d})")
                    valid_loss_best = util_engine.save_checkpoint_per_best(valid_loss_best, valid_loss_src_time, train_loss_src_time, epoch, self.model, self.main_optimizer, self.checkpoint_path)
                    # Logging to monitoring tools (Tensorboard && Wandb)
                    writer_src.add_scalars("Metrics", {
                        'Loss_train_time': train_loss_src_time,
                        'Loss_valid_time': valid_loss_src_time}, epoch)
                    writer_src.add_scalar("Learning Rate", self.main_optimizer.param_groups[0]['lr'], epoch)
                    writer_src.flush()
                logger.info(f"Training for {self.config['engine']['max_epoch']} epoches done!")
|
models/SepReformer/SepReformer_Large_DM_WSJ0/main.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from loguru import logger
|
| 4 |
+
from .dataset import get_dataloaders
|
| 5 |
+
from .model import Model
|
| 6 |
+
from .engine import Engine
|
| 7 |
+
from utils import util_system, util_implement
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
|
| 10 |
+
# Setup logger
|
| 11 |
+
log_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log/system_log.log")
|
| 12 |
+
logger.add(log_file_path, level="DEBUG", mode="w")
|
| 13 |
+
|
| 14 |
+
@logger_wraps()
def main(args):
    """Entry point: load configs.yaml, build data/model/engine, and start training."""
    # --- Build Setting ---
    # Locate and parse the configuration file that sits next to this script.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    yaml_dict = util_system.parse_yaml(os.path.join(base_dir, "configs.yaml"))

    # Run wandb and get configuration
    config = yaml_dict["config"]  # wandb login success or fail

    # DataLoaders for each split [train / valid / test / ...]
    dataloaders = get_dataloaders(args, config["dataset"], config["dataloader"])

    # --- Build Model ---
    model = Model(**config["model"])

    # --- Build Engine ---
    # GPU ids come as a comma-separated string, e.g. "0,1"; first id hosts the model.
    gpuid = tuple(int(gid) for gid in config["engine"]["gpuid"].split(','))
    device = torch.device(f'cuda:{gpuid[0]}')

    # Criterion / optimizer / scheduler factories from the shared utils package.
    criterions = util_implement.CriterionFactory(config["criterion"], device).get_criterions()
    optimizers = util_implement.OptimizerFactory(config["optimizer"], model.parameters()).get_optimizers()
    schedulers = util_implement.SchedulerFactory(config["scheduler"], optimizers).get_schedulers()

    # Build and run the training engine.
    engine = Engine(args, config, model, dataloaders, criterions, optimizers, schedulers, gpuid, device)
    engine.run()
|
models/SepReformer/SepReformer_Large_DM_WSJ0/model.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('../')
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import warnings
|
| 6 |
+
warnings.filterwarnings('ignore')
|
| 7 |
+
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
from .modules.module import *
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@logger_wraps()
class Model(torch.nn.Module):
    """SepReformer separation model: encoder -> projector -> separator -> output layer -> decoder.

    forward() returns the per-speaker waveforms plus a list of auxiliary
    (per-stage) waveform lists used for multi-stage supervision during training.
    """
    def __init__(self,
                 num_stages: int,
                 num_spks: int,
                 module_audio_enc: dict,
                 module_feature_projector: dict,
                 module_separator: dict,
                 module_output_layer: dict,
                 module_audio_dec: dict):
        """Build all sub-modules from their config dicts.

        Args:
            num_stages: number of separator stages (one aux output head per stage).
            num_spks: number of speakers to separate.
            module_*: keyword-argument dicts forwarded to each sub-module constructor.
        """
        super().__init__()
        self.num_stages = num_stages
        self.num_spks = num_spks
        self.audio_encoder = AudioEncoder(**module_audio_enc)
        self.feature_projector = FeatureProjector(**module_feature_projector)
        self.separator = Separator(**module_separator)
        self.out_layer = OutputLayer(**module_output_layer)
        self.audio_decoder = AudioDecoder(**module_audio_dec)

        # Aux_loss: one masking output layer + one decoder per separator stage.
        self.out_layer_bn = torch.nn.ModuleList([])
        self.decoder_bn = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            self.out_layer_bn.append(OutputLayer(**module_output_layer, masking=True))
            self.decoder_bn.append(AudioDecoder(**module_audio_dec))

    def forward(self, x):
        """Separate mixture waveform `x` into `num_spks` sources.

        Returns:
            audio: list of num_spks decoded waveforms from the final stage.
            audio_aux: list (per stage) of lists (per speaker) of aux waveforms,
                each trimmed to the input length.
        """
        encoder_output = self.audio_encoder(x)
        projected_feature = self.feature_projector(encoder_output)
        last_stage_output, each_stage_outputs = self.separator(projected_feature)
        out_layer_output = self.out_layer(last_stage_output, encoder_output)
        each_spk_output = [out_layer_output[idx] for idx in range(self.num_spks)]
        audio = [self.audio_decoder(each_spk_output[idx]) for idx in range(self.num_spks)]

        # Aux_loss: decode every intermediate stage output for deep supervision.
        # NOTE: F.upsample is deprecated; F.interpolate is the supported equivalent
        # (same default 'nearest' mode), so behavior is unchanged.
        audio_aux = []
        for idx, each_stage_output in enumerate(each_stage_outputs):
            upsampled = torch.nn.functional.interpolate(each_stage_output, size=encoder_output.shape[-1])
            each_stage_output = self.out_layer_bn[idx](upsampled, encoder_output)
            out_aux = [each_stage_output[jdx] for jdx in range(self.num_spks)]
            # Trim each decoded aux waveform back to the original input length.
            audio_aux.append([self.decoder_bn[idx](out_aux[jdx])[..., :x.shape[-1]] for jdx in range(self.num_spks)])

        return audio, audio_aux
|
models/SepReformer/SepReformer_Large_DM_WSJ0/modules/module.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('../')
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import warnings
|
| 6 |
+
warnings.filterwarnings('ignore')
|
| 7 |
+
|
| 8 |
+
from utils.decorators import *
|
| 9 |
+
from .network import *
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class AudioEncoder(torch.nn.Module):
    """Convolutional front-end: raw waveform -> latent feature map with a GELU."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, groups: int, bias: bool):
        super().__init__()
        self.conv1d = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, stride=stride, groups=groups, bias=bias)
        self.gelu = torch.nn.GELU()

    def forward(self, x: torch.Tensor):
        """Insert a channel axis, then convolve + activate.

        [T] -> [1, T] (unbatched) or [B, T] -> [B, 1, T].
        """
        channel_dim = 0 if x.dim() == 1 else 1
        features = self.conv1d(x.unsqueeze(channel_dim))
        return self.gelu(features)
|
| 24 |
+
|
| 25 |
+
class FeatureProjector(torch.nn.Module):
    """Normalize encoder features (global LayerNorm via 1-group GroupNorm) and
    project them to the separator's channel dimension with a 1-D convolution."""

    def __init__(self, num_channels: int, in_channels: int, out_channels: int, kernel_size: int, bias: bool):
        super().__init__()
        self.norm = torch.nn.GroupNorm(num_groups=1, num_channels=num_channels, eps=1e-8)
        self.conv1d = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, bias=bias)

    def forward(self, x: torch.Tensor):
        """Normalize then project: [B, C_in, T] -> [B, C_out, T']."""
        return self.conv1d(self.norm(x))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Separator(torch.nn.Module):
    """U-Net-style separator: a temporal contracting path (encoder stages with
    downsampling), a bottleneck that splits features per speaker, and a temporal
    expanding path (decoder stages with upsampling + skip fusion).

    The stage sub-modules are declared as classes nested inside __init__ so they
    are private to this Separator; they close over GlobalBlock/LocalBlock/
    SpkAttention imported from .network.
    """
    def __init__(self, num_stages: int, relative_positional_encoding: dict, enc_stage: dict, spk_split_stage: dict, simple_fusion:dict, dec_stage: dict):
        super().__init__()

        class RelativePositionalEncoding(torch.nn.Module):
            # Learned relative position embeddings shared by all attention blocks.
            def __init__(self, in_channels: int, num_heads: int, maxlen: int, embed_v=False):
                super().__init__()
                self.in_channels = in_channels
                self.num_heads = num_heads
                # Per-head embedding dimension.
                self.embedding_dim = self.in_channels // self.num_heads
                self.maxlen = maxlen
                # 2*maxlen entries cover relative offsets in [-maxlen, maxlen).
                self.pe_k = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim)
                self.pe_v = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim) if embed_v else None

            def forward(self, pos_seq: torch.Tensor):
                # NOTE: clamp_/ += mutate pos_seq in place — callers should not
                # reuse the tensor afterwards expecting the raw offsets.
                pos_seq.clamp_(-self.maxlen, self.maxlen - 1)
                pos_seq += self.maxlen  # shift offsets into valid embedding indices [0, 2*maxlen)
                pe_k_output = self.pe_k(pos_seq)
                pe_v_output = self.pe_v(pos_seq) if self.pe_v is not None else None
                return pe_k_output, pe_v_output

        class SepEncStage(torch.nn.Module):
            # One contracting stage: 2 x (global block + local block), then an
            # optional stride-2 depthwise downsampling conv.
            def __init__(self, global_blocks: dict, local_blocks: dict, down_conv_layer: dict, down_conv=True):
                super().__init__()

                class DownConvLayer(torch.nn.Module):
                    # Depthwise stride-2 conv + BatchNorm + GELU: halves the time axis.
                    def __init__(self, in_channels: int, samp_kernel_size: int):
                        """Depthwise temporal downsampler (stride 2, 'same'-style padding)."""
                        super().__init__()
                        self.down_conv = torch.nn.Conv1d(
                            in_channels=in_channels, out_channels=in_channels, kernel_size=samp_kernel_size, stride=2, padding=(samp_kernel_size-1)//2, groups=in_channels)
                        self.BN = torch.nn.BatchNorm1d(num_features=in_channels)
                        self.gelu = torch.nn.GELU()

                    def forward(self, x: torch.Tensor):
                        # [B, T, C] -> [B, C, T] for the conv, then back.
                        x = x.permute([0, 2, 1])
                        x = self.down_conv(x)
                        x = self.BN(x)
                        x = self.gelu(x)
                        x = x.permute([0, 2, 1])
                        return x

                self.g_block_1 = GlobalBlock(**global_blocks)
                self.l_block_1 = LocalBlock(**local_blocks)

                self.g_block_2 = GlobalBlock(**global_blocks)
                self.l_block_2 = LocalBlock(**local_blocks)

                # Final stage (bottleneck) is built with down_conv=False.
                self.downconv = DownConvLayer(**down_conv_layer) if down_conv == True else None

            def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
                '''
                x: [B, N, T]
                '''
                # Global/local blocks alternate; permutes swap between channel-first
                # (global blocks) and time-first (local blocks) layouts.
                x = self.g_block_1(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_1(x)
                x = x.permute(0, 2, 1).contiguous()

                x = self.g_block_2(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_2(x)
                x = x.permute(0, 2, 1).contiguous()

                # Skip connection is taken BEFORE downsampling, at full resolution.
                skip = x
                if self.downconv:
                    x = x.permute(0, 2, 1).contiguous()
                    x = self.downconv(x)
                    x = x.permute(0, 2, 1).contiguous()
                # [BK, S, N]
                return x, skip

        class SpkSplitStage(torch.nn.Module):
            # Expands features into num_spks per-speaker streams folded into the
            # batch dimension: [B, C, T] -> [B*num_spks, C, T].
            def __init__(self, in_channels: int, num_spks: int):
                super().__init__()
                self.linear = torch.nn.Sequential(
                    torch.nn.Conv1d(in_channels, 4*in_channels*num_spks, kernel_size=1),
                    torch.nn.GLU(dim=-2),
                    torch.nn.Conv1d(2*in_channels*num_spks, in_channels*num_spks, kernel_size=1))
                self.norm = torch.nn.GroupNorm(1, in_channels, eps=1e-8)
                self.num_spks = num_spks

            def forward(self, x: torch.Tensor):
                x = self.linear(x)
                B, _, T = x.shape
                # Fold the speaker axis into the batch axis.
                x = x.view(B*self.num_spks,-1, T).contiguous()
                x = self.norm(x)
                return x

        class SepDecStage(torch.nn.Module):
            # One expanding stage: 3 x (global + local + cross-speaker attention).
            def __init__(self, num_spks: int, global_blocks: dict, local_blocks: dict, spk_attention: dict):
                super().__init__()

                self.g_block_1 = GlobalBlock(**global_blocks)
                self.l_block_1 = LocalBlock(**local_blocks)
                self.spk_attn_1 = SpkAttention(**spk_attention)

                self.g_block_2 = GlobalBlock(**global_blocks)
                self.l_block_2 = LocalBlock(**local_blocks)
                self.spk_attn_2 = SpkAttention(**spk_attention)

                self.g_block_3 = GlobalBlock(**global_blocks)
                self.l_block_3 = LocalBlock(**local_blocks)
                self.spk_attn_3 = SpkAttention(**spk_attention)

                self.num_spk = num_spks

            def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
                '''
                x: [B, N, T]
                '''
                # [BS, K, H]
                x = self.g_block_1(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_1(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_1(x, self.num_spk)

                x = self.g_block_2(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_2(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_2(x, self.num_spk)

                x = self.g_block_3(x, pos_k)
                x = x.permute(0, 2, 1).contiguous()
                x = self.l_block_3(x)
                x = x.permute(0, 2, 1).contiguous()
                x = self.spk_attn_3(x, self.num_spk)

                skip = x

                return x, skip

        self.num_stages = num_stages
        self.pos_emb = RelativePositionalEncoding(**relative_positional_encoding)

        # Temporal Contracting Part
        self.enc_stages = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            self.enc_stages.append(SepEncStage(**enc_stage, down_conv=True))

        # Bottleneck has no further downsampling; then per-speaker split.
        self.bottleneck_G = SepEncStage(**enc_stage, down_conv=False)
        self.spk_split_block = SpkSplitStage(**spk_split_stage)

        # Temporal Expanding Part
        self.simple_fusion = torch.nn.ModuleList([])
        self.dec_stages = torch.nn.ModuleList([])
        for _ in range(self.num_stages):
            # 1x1 conv fuses [upsampled decoder features ; encoder skip] -> channels.
            self.simple_fusion.append(torch.nn.Conv1d(in_channels=simple_fusion['out_channels']*2,out_channels=simple_fusion['out_channels'], kernel_size=1))
            self.dec_stages.append(SepDecStage(**dec_stage))

    def forward(self, input: torch.Tensor):
        '''input: [B, N, L]'''
        # feature projection
        x, _ = self.pad_signal(input)
        len_x = x.shape[-1]
        # Temporal Contracting Part
        # Relative-offset matrix for the bottleneck resolution (len_x / 2^num_stages).
        pos_seq = torch.arange(0, len_x//2**self.num_stages).long().to(x.device)
        pos_seq = pos_seq[:, None] - pos_seq[None, :]
        pos_k, _ = self.pos_emb(pos_seq)
        skip = []
        for idx in range(self.num_stages):
            x, skip_ = self.enc_stages[idx](x, pos_k)
            # Skips are speaker-split so they match the decoder's folded batch.
            skip_ = self.spk_split_block(skip_)
            skip.append(skip_)
        x, _ = self.bottleneck_G(x, pos_k)
        x = self.spk_split_block(x) # B, 2F, T

        each_stage_outputs = []
        # Temporal Expanding Part
        for idx in range(self.num_stages):
            # Record the pre-upsampling features for auxiliary (deep) supervision.
            each_stage_outputs.append(x)
            idx_en = self.num_stages - (idx + 1)  # matching encoder skip (reverse order)
            x = torch.nn.functional.upsample(x, skip[idx_en].shape[-1])
            x = torch.cat([x,skip[idx_en]],dim=1)
            x = self.simple_fusion[idx](x)
            x, _ = self.dec_stages[idx](x, pos_k)

        last_stage_output = x
        return last_stage_output, each_stage_outputs

    def pad_signal(self, input: torch.Tensor):
        """Right-pad the time axis so its length is divisible by 2**num_stages.

        Returns the padded tensor and the number of padding frames added.
        """
        # (B, T) or (B, 1, T)
        if input.dim() == 1: input = input.unsqueeze(0)
        elif input.dim() not in [2, 3]: raise RuntimeError("Input can only be 2 or 3 dimensional.")
        elif input.dim() == 2: input = input.unsqueeze(1)
        L = 2**self.num_stages
        batch_size = input.size(0)
        ndim = input.size(1)
        nframe = input.size(2)
        padded_len = (nframe//L + 1)*L
        rest = 0 if nframe%L == 0 else padded_len - nframe
        if rest > 0:
            pad = torch.autograd.Variable(torch.zeros(batch_size, ndim, rest)).type(input.type()).to(input.device)
            input = torch.cat([input, pad], dim=-1)
        return input, rest
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class OutputLayer(torch.nn.Module):
    """Expand separator features back to the encoder dimension and reshape to a
    per-speaker layout; optionally gate the encoder mixture features (masking)."""

    def __init__(self, in_channels: int, out_channels: int, num_spks: int, masking: bool = False):
        super().__init__()
        # feature expansion back to the encoder channel count
        self.masking = masking
        self.spe_block = Masking(in_channels, Activation_mask="ReLU", concat_opt=None)
        self.num_spks = num_spks
        self.end_conv1x1 = torch.nn.Sequential(
            torch.nn.Linear(out_channels, 4*out_channels),
            torch.nn.GLU(),
            torch.nn.Linear(2*out_channels, in_channels))

    def forward(self, x: torch.Tensor, input: torch.Tensor):
        """x: [B*num_spks, C, T] separator output; input: [B, N, L] encoder features.

        Returns a [num_spks, B, N, L] tensor (index by speaker first).
        """
        # Trim to the reference length, then expand channels via the GLU stack
        # (Linear layers operate on the last axis, hence the permutes).
        x = x[..., :input.shape[-1]].permute([0, 2, 1])
        x = self.end_conv1x1(x).permute([0, 2, 1])
        total, chans, frames = x.shape
        batch = total // self.num_spks

        if self.masking:
            # Replicate the encoder features once per speaker (speaker folded
            # into the batch axis) and apply the gated mask against them.
            ref = input.expand(self.num_spks, batch, chans, frames).transpose(0, 1).contiguous()
            ref = ref.view(batch * self.num_spks, chans, frames)
            x = self.spe_block(x, ref)

        # Unfold speakers out of the batch axis -> [spks, B, N, L].
        return x.view(batch, self.num_spks, chans, frames).transpose(0, 1)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class AudioDecoder(torch.nn.ConvTranspose1d):
    '''
    Decoder of the TasNet
    This module can be seen as the gradient of Conv1d with respect to its input.
    It is also known as a fractionally-strided convolution
    or a deconvolution (although it is not an actual deconvolution operation).
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        # x: [B, N, L] (3D) or [N, L] (2D, channel axis is inserted)
        if x.dim() not in [2, 3]:
            # BUGFIX: `self.__name__` does not exist on nn.Module instances and
            # raised AttributeError instead of the intended RuntimeError; the
            # message also claimed 3/4D while the check accepts 2/3D.
            raise RuntimeError("{} accepts 2/3D tensor as input".format(type(self).__name__))
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
        # Drop singleton dims; keep the batch axis when full squeeze would be 1-D.
        squeezed = torch.squeeze(x)
        x = torch.squeeze(x, dim=1) if squeezed.dim() == 1 else squeezed
        return x
|
models/SepReformer/SepReformer_Large_DM_WSJ0/modules/network.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
import numpy
|
| 4 |
+
from utils.decorators import *
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LayerScale(torch.nn.Module):
    """Learnable per-channel scaling of a residual branch (LayerScale).

    `dims` selects the parameter's broadcast rank: 1 -> (C,), 2 -> (1, C),
    3 -> (1, 1, C), with every entry initialized to `Layer_scale_init`.
    """

    def __init__(self, dims, input_size, Layer_scale_init=1.0e-5):
        super().__init__()
        shapes = {1: (input_size,), 2: (1, input_size), 3: (1, 1, input_size)}
        shape = shapes.get(dims)
        # As in the original: an unsupported `dims` simply creates no parameter.
        if shape is not None:
            init = torch.ones(*shape) * Layer_scale_init
            self.layer_scale = torch.nn.Parameter(init, requires_grad=True)

    def forward(self, x):
        """Scale `x` elementwise by the learned (broadcast) factors."""
        return x * self.layer_scale
|
| 19 |
+
|
| 20 |
+
class Masking(torch.nn.Module):
    """Gated masking block: activation(x) * skip, optionally fusing x with skip
    through a pointwise conv first (when `concat_opt` is truthy).

    Args:
        input_dim: channel count of x / skip.
        Activation_mask: 'Sigmoid' or 'ReLU' gate nonlinearity.
        **options: 'concat_opt' enables the concat + 1x1-conv fusion path.
    """
    def __init__(self, input_dim, Activation_mask='Sigmoid', **options):
        super().__init__()

        self.options = options
        # BUGFIX: use .get() — the original indexed options['concat_opt'] and
        # raised KeyError whenever the kwarg was simply not passed.
        if self.options.get('concat_opt'):
            self.pw_conv = torch.nn.Conv1d(input_dim*2, input_dim, 1, stride=1, padding=0)

        if Activation_mask == 'Sigmoid':
            self.gate_act = torch.nn.Sigmoid()
        elif Activation_mask == 'ReLU':
            self.gate_act = torch.nn.ReLU()
        else:
            # Fail fast: the original silently left gate_act undefined and only
            # crashed (AttributeError) at forward time.
            raise ValueError("Unsupported Activation_mask: {}".format(Activation_mask))

    def forward(self, x, skip):
        """Return gate(x') * skip, where x' is x or pw_conv([x; skip])."""
        if self.options.get('concat_opt'):
            y = torch.cat([x, skip], dim=-2)  # concat along the channel axis
            y = self.pw_conv(y)
        else:
            y = x
        y = self.gate_act(y) * skip

        return y
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class GCFN(torch.nn.Module):
    """Gated convolutional feed-forward network with a LayerScale residual.

    LayerNorm -> Linear (x6) -> depthwise Conv1d over time -> GLU (-> x3)
    -> Linear (-> C), wrapped as x + LayerScale(branch). Input is [B, T, C].
    """

    def __init__(self, in_channels, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        self.net1 = torch.nn.Sequential(
            torch.nn.LayerNorm(in_channels),
            torch.nn.Linear(in_channels, in_channels*6))
        self.depthwise = torch.nn.Conv1d(in_channels*6, in_channels*6, 3, padding=1, groups=in_channels*6)
        self.net2 = torch.nn.Sequential(
            torch.nn.GLU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(in_channels*3, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        """Residual gated feed-forward; preserves the input shape."""
        branch = self.net1(x)
        # Depthwise conv runs channel-first, so hop to [B, C, T] and back.
        branch = self.depthwise(branch.permute(0, 2, 1).contiguous())
        branch = self.net2(branch.permute(0, 2, 1).contiguous())
        return x + self.Layer_scale(branch)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class MultiHeadAttention(torch.nn.Module):
    """
    Multi-Head Attention layer with optional relative positional bias.
    :param int n_head: the number of heads
    :param int in_channels: the number of features (model dimension)
    :param float dropout_rate: dropout rate applied to attention weights and output
    """
    def __init__(self, n_head: int, in_channels: int, dropout_rate: float, Layer_scale_init=1.0e-5):
        super().__init__()
        assert in_channels % n_head == 0
        self.d_k = in_channels // n_head # We assume d_v always equals d_k
        self.h = n_head
        # Pre-norm: LayerNorm is applied to the input before the projections.
        self.layer_norm = torch.nn.LayerNorm(in_channels)
        self.linear_q = torch.nn.Linear(in_channels, in_channels)
        self.linear_k = torch.nn.Linear(in_channels, in_channels)
        self.linear_v = torch.nn.Linear(in_channels, in_channels)
        self.linear_out = torch.nn.Linear(in_channels, in_channels)
        # Last-computed attention map, kept for inspection/debugging.
        self.attn = None
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x, pos_k, mask):
        """
        Compute 'Scaled Dot Product Attention' (self-attention: q, k, v all from x).
        :param torch.Tensor x: input features (batch, time, in_channels)
        :param torch.Tensor pos_k: relative positional key embeddings (time, time, d_k) or None
        :param torch.Tensor mask: (batch, time1, time2) attention mask or None
        :return torch.Tensor: attended and transformed `value` (batch, time1, d_model)
            weighted by the query dot key attention (batch, head, time1, time2)
        """
        n_batch = x.size(0)
        x = self.layer_norm(x)
        q = self.linear_q(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
        k = self.linear_k(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
        v = self.linear_v(x).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2) # (batch, head, time2, d_k)
        v = v.transpose(1, 2) # (batch, head, time2, d_k)
        # Content-content term.
        A = torch.matmul(q, k.transpose(-2, -1))
        reshape_q = q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0,1)
        if pos_k is not None:
            # Content-position term: query against relative position embeddings.
            B = torch.matmul(reshape_q, pos_k.transpose(-2, -1))
            B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1))
            scores = (A + B) / math.sqrt(self.d_k)
        else:
            scores = A / math.sqrt(self.d_k)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
            # Most negative representable value for the score dtype.
            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            # Zero out masked weights again after softmax (fully-masked rows
            # would otherwise hold uniform garbage).
            self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
        return self.Layer_scale(self.dropout(self.linear_out(x))) # (batch, time1, d_model)
|
| 125 |
+
|
| 126 |
+
class EGA(torch.nn.Module):
    """Efficient Global Attention: runs self-attention on a pooled (shorter)
    copy of the sequence, then re-injects the result through a learned gate.

    Args:
        in_channels: feature/channel dimension.
        num_mha_heads: number of attention heads.
        dropout_rate: dropout inside the attention block.
    """
    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'self_attn': MultiHeadAttention(
                n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            'linear': torch.nn.Sequential(
                torch.nn.LayerNorm(normalized_shape=in_channels),
                torch.nn.Linear(in_features=in_channels, out_features=in_channels),
                torch.nn.Sigmoid())
        })

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """
        Compute encoded features.
        :param torch.Tensor x: encoded source features (batch, channels, time)
        :param torch.Tensor pos_k: relative positional embeddings; its first
            dimension fixes the pooled attention length
        :return: gated-residual features, permuted to (batch, time, channels)
        """
        # Pool the time axis down to the attention resolution.
        down_len = pos_k.shape[0]
        x_down = torch.nn.functional.adaptive_avg_pool1d(input=x, output_size=down_len)
        x = x.permute([0, 2, 1])
        x_down = x_down.permute([0, 2, 1])
        x_down = self.block['self_attn'](x_down, pos_k, None)
        x_down = x_down.permute([0, 2, 1])
        # BUGFIX(deprecation): F.upsample is deprecated; F.interpolate is the
        # supported equivalent with the same default 'nearest' mode.
        x_downup = torch.nn.functional.interpolate(input=x_down, size=x.shape[1])
        x_downup = x_downup.permute([0, 2, 1])
        # Gated residual injection of the globally-attended features.
        x = x + self.block['linear'](x) * x_downup

        return x
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class CLA(torch.nn.Module):
    """Convolutional local aggregation block with a LayerScale residual.

    LayerNorm -> Linear(2C) -> GLU -> depthwise Conv1d over time -> Linear(2C)
    -> BatchNorm -> GELU/Linear(C)/Dropout, wrapped as x + LayerScale(branch).
    Input layout is [B, T, C].
    """

    def __init__(self, in_channels, kernel_size, dropout_rate, Layer_scale_init=1.0e-5):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(in_channels)
        self.linear1 = torch.nn.Linear(in_channels, in_channels*2)
        self.GLU = torch.nn.GLU()
        self.dw_conv_1d = torch.nn.Conv1d(in_channels, in_channels, kernel_size, padding='same', groups=in_channels)
        self.linear2 = torch.nn.Linear(in_channels, 2*in_channels)
        self.BN = torch.nn.BatchNorm1d(2*in_channels)
        self.linear3 = torch.nn.Sequential(
            torch.nn.GELU(),
            torch.nn.Linear(2*in_channels, in_channels),
            torch.nn.Dropout(dropout_rate))
        self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)

    def forward(self, x):
        """Residual local-conv branch; preserves the input shape [B, T, C]."""
        branch = self.GLU(self.linear1(self.layer_norm(x)))
        # Depthwise conv and BatchNorm both expect channel-first layout,
        # hence the transpose round-trips.
        branch = self.dw_conv_1d(branch.permute([0, 2, 1])).permute(0, 2, 1)
        branch = self.linear2(branch)
        branch = self.BN(branch.permute(0, 2, 1)).permute(0, 2, 1)
        branch = self.linear3(branch)
        return x + self.Layer_scale(branch)
|
| 188 |
+
|
| 189 |
+
class GlobalBlock(torch.nn.Module):
    """Global modeling unit: EGA attention followed by a GCFN feed-forward,
    with a final axis permute back to channel-first layout."""

    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'ega': EGA(
                num_mha_heads=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
            'gcfn': GCFN(in_channels=in_channels, dropout_rate=dropout_rate)
        })

    def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
        """
        Compute encoded features.
        :param torch.Tensor x: encoded source features (batch, channels, time)
        :param torch.Tensor pos_k: relative positional embeddings for the EGA block
        :return: features permuted back to channel-first layout
        """
        attended = self.block['ega'](x, pos_k)
        refined = self.block['gcfn'](attended)
        return refined.permute([0, 2, 1])
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class LocalBlock(torch.nn.Module):
    """Local modeling unit: convolutional aggregation (CLA) followed by a
    GCFN feed-forward. Operates on [B, T, C] tensors."""

    def __init__(self, in_channels: int, kernel_size: int, dropout_rate: float):
        super().__init__()
        self.block = torch.nn.ModuleDict({
            'cla': CLA(in_channels, kernel_size, dropout_rate),
            'gcfn': GCFN(in_channels, dropout_rate)
        })

    def forward(self, x: torch.Tensor):
        """Apply CLA then GCFN; the input shape is preserved."""
        return self.block['gcfn'](self.block['cla'](x))
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class SpkAttention(torch.nn.Module):
    # Cross-speaker attention: attends across the speaker axis at each time
    # step (speakers are folded into the batch axis on input), then refines
    # with a GCFN feed-forward.
    def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
        super().__init__()
        self.self_attn = MultiHeadAttention(n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate)
        self.feed_forward = GCFN(in_channels=in_channels, dropout_rate=dropout_rate)

    def forward(self, x: torch.Tensor, num_spk: int):
        """
        Compute encoded features.
        :param torch.Tensor x: features (B, F, T) where B = true_batch * num_spk
        :param int num_spk: number of speakers folded into the batch dimension
        :return: tensor of the same shape (B, F, T)
        """
        B, F, T = x.shape
        # Unfold speakers out of the batch: (B, F, T) -> (B/num_spk, num_spk, F, T).
        x = x.view(B//num_spk, num_spk, F, T).contiguous()
        # Bring time forward so each (batch, time) pair becomes one attention
        # "sequence" over the num_spk speaker tokens.
        x = x.permute([0, 3, 1, 2]).contiguous()
        x = x.view(-1, num_spk, F).contiguous()
        # Residual self-attention across speakers (no positional bias, no mask).
        x = x + self.self_attn(x, None, None)
        # Invert the reshapes: back to (B, F, T) with speakers folded again.
        x = x.view(B//num_spk, T, num_spk, F).contiguous()
        x = x.permute([0, 2, 3, 1]).contiguous()
        x = x.view(B, F, T).contiguous()
        # GCFN expects time-first layout (B, T, F).
        x = x.permute([0, 2, 1])
        x = self.feed_forward(x)
        x = x.permute([0, 2, 1])

        return x
|
models/SepReformer/source.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
https://github.com/dmlguq456/SepReformer
|