PhuongLT committed on
Commit
7b84203
·
0 Parent(s):

Initial clean commit for Space

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +39 -0
  2. .gitignore +1 -0
  3. .vscode/settings.json +5 -0
  4. LICENSE +21 -0
  5. Models/multi_phoaudio_gemini/config_phoaudio_gemini_small.yml +35 -0
  6. Modules/__init__.py +1 -0
  7. Modules/__pycache__/__init__.cpython-310.pyc +0 -0
  8. Modules/__pycache__/discriminators.cpython-310.pyc +0 -0
  9. Modules/__pycache__/istftnet.cpython-310.pyc +0 -0
  10. Modules/__pycache__/utils.cpython-310.pyc +0 -0
  11. Modules/diffusion/__init__.py +1 -0
  12. Modules/diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  13. Modules/diffusion/__pycache__/diffusion.cpython-310.pyc +0 -0
  14. Modules/diffusion/__pycache__/modules.cpython-310.pyc +0 -0
  15. Modules/diffusion/__pycache__/sampler.cpython-310.pyc +0 -0
  16. Modules/diffusion/__pycache__/utils.cpython-310.pyc +0 -0
  17. Modules/diffusion/diffusion.py +94 -0
  18. Modules/diffusion/modules.py +693 -0
  19. Modules/diffusion/sampler.py +691 -0
  20. Modules/diffusion/utils.py +82 -0
  21. Modules/discriminators.py +188 -0
  22. Modules/hifigan.py +477 -0
  23. Modules/istftnet.py +530 -0
  24. Modules/slmadv.py +195 -0
  25. Modules/utils.py +14 -0
  26. README.md +12 -0
  27. Utils_extend_v1/.ipynb_checkpoints/__init__-checkpoint.py +3 -0
  28. Utils_extend_v1/ASR/.ipynb_checkpoints/config-checkpoint.yml +3 -0
  29. Utils_extend_v1/ASR/.ipynb_checkpoints/layers-checkpoint.py +3 -0
  30. Utils_extend_v1/ASR/.ipynb_checkpoints/model_struct-checkpoint.txt +3 -0
  31. Utils_extend_v1/ASR/.ipynb_checkpoints/models-checkpoint.py +3 -0
  32. Utils_extend_v1/ASR/__init__.py +3 -0
  33. Utils_extend_v1/ASR/__pycache__/__init__.cpython-310.pyc +3 -0
  34. Utils_extend_v1/ASR/__pycache__/__init__.cpython-312.pyc +3 -0
  35. Utils_extend_v1/ASR/__pycache__/layers.cpython-310.pyc +3 -0
  36. Utils_extend_v1/ASR/__pycache__/layers.cpython-312.pyc +3 -0
  37. Utils_extend_v1/ASR/__pycache__/models.cpython-310.pyc +3 -0
  38. Utils_extend_v1/ASR/__pycache__/models.cpython-312.pyc +3 -0
  39. Utils_extend_v1/ASR/config.yml +3 -0
  40. Utils_extend_v1/ASR/epoch_00080.pth +3 -0
  41. Utils_extend_v1/ASR/epoch_extend_186.pth +3 -0
  42. Utils_extend_v1/ASR/layers.py +3 -0
  43. Utils_extend_v1/ASR/model_struct.txt +3 -0
  44. Utils_extend_v1/ASR/models.py +3 -0
  45. Utils_extend_v1/JDC/.ipynb_checkpoints/model-checkpoint.py +3 -0
  46. Utils_extend_v1/JDC/__init__.py +3 -0
  47. Utils_extend_v1/JDC/__pycache__/__init__.cpython-310.pyc +3 -0
  48. Utils_extend_v1/JDC/__pycache__/__init__.cpython-312.pyc +3 -0
  49. Utils_extend_v1/JDC/__pycache__/model.cpython-310.pyc +3 -0
  50. Utils_extend_v1/JDC/__pycache__/model.cpython-312.pyc +3 -0
.gitattributes ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.t7 filter=lfs diff=lfs merge=lfs -text
37
+ *.wav filter=lfs diff=lfs merge=lfs -text
38
+ ref_voice/** filter=lfs diff=lfs merge=lfs -text
39
+ Utils_extend_v1/** filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ Models/multi_phoaudio_gemini/*.pth
.vscode/settings.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "python-envs.defaultEnvManager": "ms-python.python:conda",
3
+ "python-envs.defaultPackageManager": "ms-python.python:conda",
4
+ "python-envs.pythonProjects": []
5
+ }
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Aaron (Yinghao) Li
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Models/multi_phoaudio_gemini/config_phoaudio_gemini_small.yml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {ASR_config: Utils_extend_v1/ASR/config.yml, ASR_path: Utils_extend_v1/ASR/epoch_extend_186.pth,
2
+ F0_path: Utils_extend_v1/JDC/bst.t7, PLBERT_dir: Utils_extend_v1/PLBERT/, batch_size: 8,
3
+ data_params: {OOD_data: /home/xdep/data/jupyterhub/users/datnvt/data/custom_datasets/text_gemini_phoaudio_multi_speaker_small_v1/ood_multi_phoaudio.txt,
4
+ min_length: 50, root_path: /home/xdep/data/jupyterhub/users/datnvt/project/styletts2/custom_datasets/wavs_gemini_phoaudio_multi_speaker_small_v1,
5
+ train_data: /home/xdep/data/jupyterhub/users/datnvt/data/custom_datasets/text_gemini_phoaudio_multi_speaker_small_v1/train_filtered.txt,
6
+ val_data: /home/xdep/data/jupyterhub/users/datnvt/data/custom_datasets/text_gemini_phoaudio_multi_speaker_small_v1/validation_list.no_brackets.txt},
7
+ device: cuda, epochs_1st: 200, epochs_2nd: 150, extend_PLBERT: true, first_stage_path: '',
8
+ load_only_params: false, log_dir: Models/phoaudio/combine_phoaudio_gemini_small,
9
+ log_interval: 100, loss_params: {TMA_epoch: 50, diff_epoch: 0, joint_epoch: 0, lambda_F0: 1.0,
10
+ lambda_ce: 20.0, lambda_diff: 1.0, lambda_dur: 1.0, lambda_gen: 1.0, lambda_mel: 5.0,
11
+ lambda_mono: 1.0, lambda_norm: 1.0, lambda_s2s: 1.0, lambda_slm: 1.0, lambda_sty: 1.0},
12
+ max_len: 400, model_params: {decoder: {gen_istft_hop_size: 5, gen_istft_n_fft: 20,
13
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]], resblock_kernel_sizes: [
14
+ 3, 7, 11], type: istftnet, upsample_initial_channel: 512, upsample_kernel_sizes: [
15
+ 20, 12], upsample_rates: [10, 6]}, diffusion: {dist: {estimate_sigma_data: true,
16
+ mean: -3.0, sigma_data: 0.29304919641906935, std: 1.0}, embedding_mask_proba: 0.1,
17
+ transformer: {head_features: 64, multiplier: 2, num_heads: 8, num_layers: 3}},
18
+ dim_in: 64, dropout: 0.2, hidden_dim: 512, max_conv_dim: 512, max_dur: 50, multispeaker: true,
19
+ n_layer: 3, n_mels: 80, n_token: 186, slm: {hidden: 768, initial_channel: 64,
20
+ model: microsoft/wavlm-base-plus, nlayers: 13, sr: 16000}, style_dim: 128},
21
+ optimizer_params: {bert_lr: 1.0e-05, ft_lr: 1.0e-05, lr: 0.0001}, preprocess_params: {
22
+ spect_params: {hop_length: 300, n_fft: 2048, win_length: 1200}, sr: 24000}, pretrained_model: Models/phoaudio/combine_phoaudio_gemini_small/epoch_2nd_00003.pth,
23
+ save_freq: 1, second_stage_load_pretrained: true, slmadv_params: {batch_percentage: 0.5,
24
+ iter: 10, max_len: 500, min_len: 400, scale: 0.01, sig: 1.5, thresh: 5}, symbol: {
25
+ extend: "-124567\u032A", letters: ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz,
26
+ letters_ipa: "\u0251\u0250\u0252\xE6\u0253\u0299\u03B2\u0254\u0255\xE7\u0257\u0256\
27
+ \xF0\u02A4\u0259\u0258\u025A\u025B\u025C\u025D\u025E\u025F\u0284\u0261\u0260\
28
+ \u0262\u029B\u0266\u0267\u0127\u0265\u029C\u0268\u026A\u029D\u026D\u026C\u026B\
29
+ \u026E\u029F\u0271\u026F\u0270\u014B\u0273\u0272\u0274\xF8\u0275\u0278\u03B8\
30
+ \u0153\u0276\u0298\u0279\u027A\u027E\u027B\u0280\u0281\u027D\u0282\u0283\u0288\
31
+ \u02A7\u0289\u028A\u028B\u2C71\u028C\u0263\u0264\u028D\u03C7\u028E\u028F\u0291\
32
+ \u0290\u0292\u0294\u02A1\u0295\u02A2\u01C0\u01C1\u01C2\u01C3\u02C8\u02CC\u02D0\
33
+ \u02D1\u02BC\u02B4\u02B0\u02B1\u02B2\u02B7\u02E0\u02E4\u02DE\u2193\u2191\u2192\
34
+ \u2197\u2198'\u0329'\u1D7B", pad: $, punctuation: ";:,.!?\xA1\xBF\u2014\u2026\
35
+ \"\xAB\xBB\u201C\u201D "}}
Modules/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
Modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (142 Bytes). View file
 
Modules/__pycache__/discriminators.cpython-310.pyc ADDED
Binary file (6.04 kB). View file
 
Modules/__pycache__/istftnet.cpython-310.pyc ADDED
Binary file (16.6 kB). View file
 
Modules/__pycache__/utils.cpython-310.pyc ADDED
Binary file (750 Bytes). View file
 
Modules/diffusion/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
Modules/diffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (152 Bytes). View file
 
Modules/diffusion/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (3.64 kB). View file
 
Modules/diffusion/__pycache__/modules.cpython-310.pyc ADDED
Binary file (16.2 kB). View file
 
Modules/diffusion/__pycache__/sampler.cpython-310.pyc ADDED
Binary file (22.1 kB). View file
 
Modules/diffusion/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.49 kB). View file
 
Modules/diffusion/diffusion.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import pi
2
+ from random import randint
3
+ from typing import Any, Optional, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ from einops import rearrange
7
+ from torch import Tensor, nn
8
+ from tqdm import tqdm
9
+
10
+ from .utils import *
11
+ from .sampler import *
12
+
13
+ """
14
+ Diffusion Classes (generic for 1d data)
15
+ """
16
+
17
+
18
class Model1d(nn.Module):
    """Container that delegates to an externally attached diffusion module.

    The constructor builds no sub-networks: ``unet`` and ``diffusion`` both
    start as ``None`` and must be assigned by the caller before ``forward``
    or ``sample`` can be used.
    """

    def __init__(self, unet_type: str = "base", **kwargs):
        """Initialise the empty container.

        Args:
            unet_type: Accepted for interface compatibility; not read here.
            **kwargs: Passed through ``groupby("diffusion_", ...)``; neither
                of the two returned groups is consumed by this constructor.
        """
        super().__init__()
        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
        self.unet = None
        self.diffusion = None

    def _require_diffusion(self):
        """Return the attached diffusion module or fail with a clear message."""
        diffusion = self.diffusion
        if diffusion is None:
            # Without this guard the caller only sees
            # "'NoneType' object is not callable", which hides the real cause.
            raise RuntimeError(
                "Model1d.diffusion is not set; assign a diffusion module "
                "before calling forward() or sample()"
            )
        return diffusion

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """Call the attached diffusion module on ``x``.

        Raises:
            RuntimeError: If no diffusion module has been attached.
        """
        return self._require_diffusion()(x, **kwargs)

    def sample(self, *args, **kwargs) -> Tensor:
        """Forward all arguments to the attached diffusion module's ``sample``.

        Raises:
            RuntimeError: If no diffusion module has been attached.
        """
        return self._require_diffusion().sample(*args, **kwargs)
30
+
31
+
32
+ """
33
+ Audio Diffusion Classes (specific for 1d audio data)
34
+ """
35
+
36
+
37
def get_default_model_kwargs():
    """Return the default keyword arguments used to build the audio models.

    A new dict, holding a new ``UniformDistribution`` instance, is created on
    every call, so callers may mutate the result freely.
    """
    defaults = {
        "channels": 128,
        "patch_size": 16,
        "multipliers": [1, 2, 4, 4, 4, 4, 4],
        "factors": [4, 4, 4, 2, 2, 2],
        "num_blocks": [2, 2, 2, 2, 2, 2],
        "attentions": [0, 0, 0, 1, 1, 1, 1],
        "attention_heads": 8,
        "attention_features": 64,
        "attention_multiplier": 2,
        "attention_use_rel_pos": False,
        "diffusion_type": "v",
        "diffusion_sigma_distribution": UniformDistribution(),
    }
    return defaults
52
+
53
+
54
def get_default_sampling_kwargs():
    """Return the default keyword arguments used when sampling.

    Each call builds a new dict with new ``LinearSchedule`` and ``VSampler``
    instances and ``clamp`` set to ``True``.
    """
    return {
        "sigma_schedule": LinearSchedule(),
        "sampler": VSampler(),
        "clamp": True,
    }
56
+
57
+
58
class AudioDiffusionModel(Model1d):
    """``Model1d`` pre-configured with the default model and sampling kwargs."""

    def __init__(self, **kwargs):
        """Build the model; caller-supplied ``kwargs`` override the defaults."""
        merged = get_default_model_kwargs()
        merged.update(kwargs)
        super().__init__(**merged)

    def sample(self, *args, **kwargs):
        """Sample via the parent; ``kwargs`` override the default sampling kwargs."""
        merged = get_default_sampling_kwargs()
        merged.update(kwargs)
        return super().sample(*args, **merged)
64
+
65
+
66
class AudioDiffusionConditional(Model1d):
    """``Model1d`` configured for embedding-conditioned diffusion.

    Adds ``unet_type="cfg"`` and the context-embedding settings to the default
    model kwargs, and injects ``embedding_mask_proba`` / ``embedding_scale``
    defaults into ``forward`` / ``sample`` respectively.
    """

    def __init__(
        self,
        embedding_features: int,
        embedding_max_length: int,
        embedding_mask_proba: float = 0.1,
        **kwargs,
    ):
        """Build the conditional model.

        Args:
            embedding_features: Forwarded as ``context_embedding_features``.
            embedding_max_length: Forwarded as ``context_embedding_max_length``.
            embedding_mask_proba: Default ``embedding_mask_proba`` passed to
                ``forward`` when the caller does not supply one.
            **kwargs: Overrides for any of the default constructor kwargs.
        """
        default_kwargs = dict(
            **get_default_model_kwargs(),
            unet_type="cfg",
            context_embedding_features=embedding_features,
            context_embedding_max_length=embedding_max_length,
        )
        super().__init__(**{**default_kwargs, **kwargs})
        # Assigned only after the parent constructor has run: nn.Module requires
        # its own __init__ to be called before attributes are set on the child.
        self.embedding_mask_proba = embedding_mask_proba

    def forward(self, *args, **kwargs):
        """Run the parent ``forward`` with ``embedding_mask_proba`` defaulted."""
        default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
        return super().forward(*args, **{**default_kwargs, **kwargs})

    def sample(self, *args, **kwargs):
        """Run the parent ``sample`` with sampling defaults and ``embedding_scale=5.0``."""
        default_kwargs = dict(
            **get_default_sampling_kwargs(),
            embedding_scale=5.0,
        )
        return super().sample(*args, **{**default_kwargs, **kwargs})
93
+
94
+
Modules/diffusion/modules.py ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import floor, log, pi
2
+ from typing import Any, List, Optional, Sequence, Tuple, Union
3
+
4
+ from .utils import *
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange, reduce, repeat
9
+ from einops.layers.torch import Rearrange
10
+ from einops_exts import rearrange_many
11
+ from torch import Tensor, einsum
12
+
13
+
14
+ """
15
+ Utils
16
+ """
17
+
18
class AdaLayerNorm(nn.Module):
    """Layer normalisation whose scale and shift are predicted from a style vector."""

    def __init__(self, style_dim, channels, eps=1e-5):
        """Create the style projection.

        Args:
            style_dim: Size of the last dimension of the style input ``s``.
            channels: Size of the normalised dimension.
            eps: Epsilon passed to ``layer_norm``.
        """
        super().__init__()
        self.channels = channels
        self.eps = eps

        # One projection yields both modulation terms, hence 2 * channels.
        self.fc = nn.Linear(style_dim, channels * 2)

    def forward(self, x, s):
        """Normalise ``x`` over ``channels`` and modulate it with ``s``.

        ``x`` is transposed twice before normalisation and twice afterwards;
        after the first pair its last dimension must equal ``channels``.
        The result is ``(1 + gamma) * layer_norm(x) + beta`` where ``gamma``
        and ``beta`` are the two halves of ``fc(s)``.
        """
        x = x.transpose(-1, -2).transpose(1, -1)

        style = self.fc(s)
        style = style.view(style.size(0), style.size(1), 1)
        # Split into scale / shift, then move the channel axis last so both
        # broadcast against the normalised input.
        gamma, beta = (
            part.transpose(1, -1) for part in torch.chunk(style, chunks=2, dim=1)
        )

        normed = nn.functional.layer_norm(x, (self.channels,), eps=self.eps)
        modulated = (1 + gamma) * normed + beta
        return modulated.transpose(1, -1).transpose(-1, -2)
39
+
40
class StyleTransformer1d(nn.Module):
    """Stack of ``StyleTransformerBlock`` layers conditioned on time, features and an embedding.

    The input is concatenated with the embedding along the last axis, passed
    through ``num_layers`` blocks (with the context mapping added before each
    block), averaged over axis 1 and projected back to ``channels`` with a
    kernel-size-1 convolution.
    """

    def __init__(
        self,
        num_layers: int,
        channels: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        use_context_time: bool = True,
        use_rel_pos: bool = False,
        context_features_multiplier: int = 1,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
        context_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
        embedding_max_length: int = 512,
    ):
        """Build the blocks, output projection and context mappings.

        Args:
            num_layers: Number of ``StyleTransformerBlock`` layers.
            channels: Channel count of the input and of the output.
            num_heads: Attention heads per block.
            head_features: Features per attention head.
            multiplier: Feed-forward width multiplier inside each block.
            use_context_time: Whether a time embedding feeds the mapping.
            use_rel_pos: Forwarded to each block.
            context_features_multiplier: Accepted but not used by this class.
            rel_pos_num_buckets: Forwarded to each block.
            rel_pos_max_distance: Forwarded to each block.
            context_features: Size of the ``features`` input; also passed to
                each block as ``style_dim``.
            context_embedding_features: Size of the embedding's last axis. It
                is added to ``channels``, so it must be an int in practice.
            embedding_max_length: ``max_length`` of the ``FixedEmbedding``.
        """
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                StyleTransformerBlock(
                    features=channels + context_embedding_features,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                    style_dim=context_features,
                    use_rel_pos=use_rel_pos,
                    rel_pos_num_buckets=rel_pos_num_buckets,
                    rel_pos_max_distance=rel_pos_max_distance,
                )
                for _ in range(num_layers)
            ]
        )

        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            nn.Conv1d(
                in_channels=channels + context_embedding_features,
                out_channels=channels,
                kernel_size=1,
            ),
        )

        use_context_features = exists(context_features)
        self.use_context_features = use_context_features
        self.use_context_time = use_context_time

        if use_context_time or use_context_features:
            # The mapping is added to the concatenated (input + embedding)
            # tensor, so it shares that tensor's feature size.
            context_mapping_features = channels + context_embedding_features

            self.to_mapping = nn.Sequential(
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
            )

        if use_context_time:
            assert exists(context_mapping_features)
            self.to_time = nn.Sequential(
                TimePositionalEmbedding(
                    dim=channels, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        if use_context_features:
            assert exists(context_features) and exists(context_mapping_features)
            self.to_features = nn.Sequential(
                nn.Linear(
                    in_features=context_features, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        self.fixed_embedding = FixedEmbedding(
            max_length=embedding_max_length, features=context_embedding_features
        )

    def get_mapping(
        self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
    ) -> Optional[Tensor]:
        """Combines context time features and features into mapping"""
        items, mapping = [], None
        # Compute time features
        if self.use_context_time:
            assert_message = "use_context_time=True but no time features provided"
            assert exists(time), assert_message
            items += [self.to_time(time)]
        # Compute features
        if self.use_context_features:
            assert_message = "context_features exists but no features provided"
            assert exists(features), assert_message
            items += [self.to_features(features)]

        # Compute joint mapping
        if self.use_context_time or self.use_context_features:
            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
            mapping = self.to_mapping(mapping)

        return mapping

    def run(self, x, time, embedding, features):
        """Run one conditioned pass through all blocks and the output projection."""
        mapping = self.get_mapping(time, features)
        # Repeat x along axis 1 to the embedding's length, then concatenate
        # the two on the feature axis.
        x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], dim=-1)
        mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)

        for block in self.blocks:
            x = x + mapping
            x = block(x, features)

        # Collapse axis 1 to a single position before projecting to `channels`.
        x = x.mean(dim=1).unsqueeze(1)
        x = self.to_out(x)
        x = x.transpose(-1, -2)

        return x

    def forward(self, x: Tensor,
                time: Tensor,
                embedding_mask_proba: float = 0.0,
                embedding: Optional[Tensor] = None,
                features: Optional[Tensor] = None,
                embedding_scale: float = 1.0) -> Tensor:
        """Apply the transformer with optional embedding masking and guidance.

        Args:
            x: Input tensor.
            time: Time input for the context mapping.
            embedding_mask_proba: When > 0, probability handed to ``rand_bool``
                to choose batch items whose embedding is replaced by the fixed
                embedding.
            embedding: Conditioning embedding. Typed Optional, but required.
            features: Context features, also given to every block.
            embedding_scale: When != 1.0, the conditioned and fixed-embedding
                outputs are combined as
                ``out_masked + (out - out_masked) * embedding_scale``.

        Returns:
            The output of ``run`` (or the guided combination of two runs).
        """
        # `embedding` is dereferenced right away; fail with a clear message
        # rather than an AttributeError on None.
        assert exists(embedding), "embedding must be provided"

        b, device = embedding.shape[0], embedding.device
        fixed_embedding = self.fixed_embedding(embedding)
        if embedding_mask_proba > 0.0:
            # Randomly mask embedding
            batch_mask = rand_bool(
                shape=(b, 1, 1), proba=embedding_mask_proba, device=device
            )
            embedding = torch.where(batch_mask, fixed_embedding, embedding)

        if embedding_scale != 1.0:
            # Compute both normal and fixed embedding outputs
            out = self.run(x, time, embedding=embedding, features=features)
            out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
            # Scale conditional output using classifier-free guidance
            return out_masked + (out - out_masked) * embedding_scale

        return self.run(x, time, embedding=embedding, features=features)
186
+
187
+
188
class StyleTransformerBlock(nn.Module):
    """Residual block: style attention, optional cross attention, feed-forward."""

    def __init__(
        self,
        features: int,
        num_heads: int,
        head_features: int,
        style_dim: int,
        multiplier: int,
        use_rel_pos: bool,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
        context_features: Optional[int] = None,
    ):
        """Build the sub-layers.

        Cross attention is only created when ``context_features`` is given
        and greater than zero.
        """
        super().__init__()

        self.use_cross_attention = exists(context_features) and context_features > 0

        # Arguments shared by the self- and cross-attention layers.
        attention_kwargs = dict(
            features=features,
            style_dim=style_dim,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )

        self.attention = StyleAttention(**attention_kwargs)

        if self.use_cross_attention:
            self.cross_attention = StyleAttention(
                context_features=context_features, **attention_kwargs
            )

        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        """Apply each sub-layer in turn, adding its output to its input."""
        attended = self.attention(x, s)
        x = attended + x
        if self.use_cross_attention:
            crossed = self.cross_attention(x, s, context=context)
            x = crossed + x
        transformed = self.feed_forward(x)
        return transformed + x
235
+
236
class StyleAttention(nn.Module):
    """Attention whose query and context inputs are normalised with ``AdaLayerNorm``."""

    def __init__(
        self,
        features: int,
        *,
        style_dim: int,
        head_features: int,
        num_heads: int,
        context_features: Optional[int] = None,
        use_rel_pos: bool,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
    ):
        """Build the norms, projections and the attention core.

        ``context_features`` is stored as given (possibly ``None``); the
        layers themselves fall back to ``features`` via ``default``.
        """
        super().__init__()
        self.context_features = context_features
        inner_features = head_features * num_heads
        kv_in_features = default(context_features, features)

        self.norm = AdaLayerNorm(style_dim, features)
        self.norm_context = AdaLayerNorm(style_dim, kv_in_features)
        self.to_q = nn.Linear(
            in_features=features, out_features=inner_features, bias=False
        )
        # Keys and values come out of one projection and are split in forward.
        self.to_kv = nn.Linear(
            in_features=kv_in_features, out_features=inner_features * 2, bias=False
        )
        self.attention = AttentionBase(
            features,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )

    def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        """Attend from ``x`` to ``context`` (or to ``x`` itself when none is given)."""
        if self.context_features:
            assert exists(context), "You must provide a context when using context_features"
        # Fall back to self-attention on the un-normalised input.
        source = default(context, x)
        # Both inputs are normalised with the same style vector.
        normed_x = self.norm(x, s)
        normed_source = self.norm_context(source, s)

        q = self.to_q(normed_x)
        k, v = torch.chunk(self.to_kv(normed_source), chunks=2, dim=-1)
        return self.attention(q, k, v)
282
+
283
class Transformer1d(nn.Module):
    """Stack of ``TransformerBlock`` layers conditioned on time, features and an embedding.

    The input is concatenated with the embedding along the last axis, passed
    through ``num_layers`` blocks (with the context mapping, when there is
    one, added before each block), averaged over axis 1 and projected back to
    ``channels`` with a kernel-size-1 convolution.
    """

    def __init__(
        self,
        num_layers: int,
        channels: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        use_context_time: bool = True,
        use_rel_pos: bool = False,
        context_features_multiplier: int = 1,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
        context_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
        embedding_max_length: int = 512,
    ):
        """Build the blocks, output projection and context mappings.

        Args:
            num_layers: Number of ``TransformerBlock`` layers.
            channels: Channel count of the input and of the output.
            num_heads: Attention heads per block.
            head_features: Features per attention head.
            multiplier: Feed-forward width multiplier inside each block.
            use_context_time: Whether a time embedding feeds the mapping.
            use_rel_pos: Forwarded to each block.
            context_features_multiplier: Accepted but not used by this class.
            rel_pos_num_buckets: Forwarded to each block.
            rel_pos_max_distance: Forwarded to each block.
            context_features: Size of the ``features`` input, if any.
            context_embedding_features: Size of the embedding's last axis. It
                is added to ``channels``, so it must be an int in practice.
            embedding_max_length: ``max_length`` of the ``FixedEmbedding``.
        """
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    features=channels + context_embedding_features,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                    use_rel_pos=use_rel_pos,
                    rel_pos_num_buckets=rel_pos_num_buckets,
                    rel_pos_max_distance=rel_pos_max_distance,
                )
                for _ in range(num_layers)
            ]
        )

        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            nn.Conv1d(
                in_channels=channels + context_embedding_features,
                out_channels=channels,
                kernel_size=1,
            ),
        )

        use_context_features = exists(context_features)
        self.use_context_features = use_context_features
        self.use_context_time = use_context_time

        if use_context_time or use_context_features:
            # The mapping is added to the concatenated (input + embedding)
            # tensor, so it shares that tensor's feature size.
            context_mapping_features = channels + context_embedding_features

            self.to_mapping = nn.Sequential(
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
            )

        if use_context_time:
            assert exists(context_mapping_features)
            self.to_time = nn.Sequential(
                TimePositionalEmbedding(
                    dim=channels, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        if use_context_features:
            assert exists(context_features) and exists(context_mapping_features)
            self.to_features = nn.Sequential(
                nn.Linear(
                    in_features=context_features, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        self.fixed_embedding = FixedEmbedding(
            max_length=embedding_max_length, features=context_embedding_features
        )

    def get_mapping(
        self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
    ) -> Optional[Tensor]:
        """Combines context time features and features into mapping"""
        items, mapping = [], None
        # Compute time features
        if self.use_context_time:
            assert_message = "use_context_time=True but no time features provided"
            assert exists(time), assert_message
            items += [self.to_time(time)]
        # Compute features
        if self.use_context_features:
            assert_message = "context_features exists but no features provided"
            assert exists(features), assert_message
            items += [self.to_features(features)]

        # Compute joint mapping
        if self.use_context_time or self.use_context_features:
            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
            mapping = self.to_mapping(mapping)

        return mapping

    def run(self, x, time, embedding, features):
        """Run one conditioned pass through all blocks and the output projection."""
        mapping = self.get_mapping(time, features)
        # Repeat x along axis 1 to the embedding's length, then concatenate
        # the two on the feature axis.
        x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], dim=-1)
        # get_mapping returns None when neither time nor feature context is
        # enabled; in that case the blocks run without an added mapping.
        if mapping is not None:
            mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)

        for block in self.blocks:
            if mapping is not None:
                x = x + mapping
            x = block(x)

        # Collapse axis 1 to a single position before projecting to `channels`.
        x = x.mean(dim=1).unsqueeze(1)
        x = self.to_out(x)
        x = x.transpose(-1, -2)

        return x

    def forward(self, x: Tensor,
                time: Tensor,
                embedding_mask_proba: float = 0.0,
                embedding: Optional[Tensor] = None,
                features: Optional[Tensor] = None,
                embedding_scale: float = 1.0) -> Tensor:
        """Apply the transformer with optional embedding masking and guidance.

        Args:
            x: Input tensor.
            time: Time input for the context mapping.
            embedding_mask_proba: When > 0, probability handed to ``rand_bool``
                to choose batch items whose embedding is replaced by the fixed
                embedding.
            embedding: Conditioning embedding. Typed Optional, but required.
            features: Context features for the mapping.
            embedding_scale: When != 1.0, the conditioned and fixed-embedding
                outputs are combined as
                ``out_masked + (out - out_masked) * embedding_scale``.

        Returns:
            The output of ``run`` (or the guided combination of two runs).
        """
        # `embedding` is dereferenced right away; fail with a clear message
        # rather than an AttributeError on None.
        assert exists(embedding), "embedding must be provided"

        b, device = embedding.shape[0], embedding.device
        fixed_embedding = self.fixed_embedding(embedding)
        if embedding_mask_proba > 0.0:
            # Randomly mask embedding
            batch_mask = rand_bool(
                shape=(b, 1, 1), proba=embedding_mask_proba, device=device
            )
            embedding = torch.where(batch_mask, fixed_embedding, embedding)

        if embedding_scale != 1.0:
            # Compute both normal and fixed embedding outputs
            out = self.run(x, time, embedding=embedding, features=features)
            out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
            # Scale conditional output using classifier-free guidance
            return out_masked + (out - out_masked) * embedding_scale

        return self.run(x, time, embedding=embedding, features=features)
428
+
429
+
430
+ """
431
+ Attention Components
432
+ """
433
+
434
+
435
class RelativePositionBias(nn.Module):
    """Learned per-head attention bias indexed by bucketed relative position."""

    def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
        """Store the bucketing settings and create the bias embedding table."""
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.num_heads = num_heads
        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position: Tensor, num_buckets: int, max_distance: int
    ):
        """Map signed relative positions to bucket indices.

        Half of the buckets serve negative offsets and half serve
        non-negative ones. Within each half, distances below a quarter of
        ``num_buckets`` get their own bucket; larger ones are binned on a
        logarithmic scale and capped at the half's last bucket.
        """
        half = num_buckets // 2
        # Non-negative offsets are shifted into the upper half of the table.
        bucket = (relative_position >= 0).to(torch.long) * half
        distance = torch.abs(relative_position)

        exact_limit = half // 2
        log_ratio = torch.log(distance.float() / exact_limit) / log(
            max_distance / exact_limit
        )
        coarse = exact_limit + (log_ratio * (half - exact_limit)).long()
        coarse = torch.clamp(coarse, max=half - 1)

        bucket += torch.where(distance < exact_limit, distance, coarse)
        return bucket

    def forward(self, num_queries: int, num_keys: int) -> Tensor:
        """Return the bias with shape ``(1, num_heads, num_queries, num_keys)``."""
        device = self.relative_attention_bias.weight.device
        # Queries are taken to be the last `num_queries` positions of the keys.
        q_pos = torch.arange(
            num_keys - num_queries, num_keys, dtype=torch.long, device=device
        )
        k_pos = torch.arange(num_keys, dtype=torch.long, device=device)
        rel_pos = k_pos.unsqueeze(0) - q_pos.unsqueeze(1)

        buckets = self._relative_position_bucket(
            rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
        )

        # The lookup gives (queries, keys, heads); move heads forward and add
        # a leading unit axis.
        per_pair = self.relative_attention_bias(buckets)
        return per_pair.permute(2, 0, 1).unsqueeze(0)
482
+
483
+
484
def FeedForward(features: int, multiplier: int) -> nn.Module:
    """Build a Linear -> GELU -> Linear stack.

    The hidden width is ``features * multiplier``; input and output widths
    are both ``features``.
    """
    hidden_features = features * multiplier
    layers = [
        nn.Linear(in_features=features, out_features=hidden_features),
        nn.GELU(),
        nn.Linear(in_features=hidden_features, out_features=features),
    ]
    return nn.Sequential(*layers)
491
+
492
+
493
class AttentionBase(nn.Module):
    """Multi-head attention core operating on already-projected q, k and v."""

    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        use_rel_pos: bool,
        out_features: Optional[int] = None,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
    ):
        """Set up the scaling, optional relative-position bias and output projection.

        ``out_features`` falls back to ``features`` when it is ``None``.
        """
        super().__init__()
        self.scale = head_features ** -0.5
        self.num_heads = num_heads
        self.use_rel_pos = use_rel_pos
        inner_features = head_features * num_heads

        if use_rel_pos:
            assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
            self.rel_pos = RelativePositionBias(
                num_buckets=rel_pos_num_buckets,
                max_distance=rel_pos_max_distance,
                num_heads=num_heads,
            )

        projected_features = features if out_features is None else out_features
        self.to_out = nn.Linear(
            in_features=inner_features, out_features=projected_features
        )

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Compute attention over ``q``, ``k``, ``v`` and project the merged heads."""
        # Separate the head axis out of the feature axis.
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)

        scores = einsum("... n d, ... m d -> ... n m", q, k)
        if self.use_rel_pos:
            # The bias is added before scaling, so it is scaled as well.
            scores = scores + self.rel_pos(*scores.shape[-2:])
        scores = scores * self.scale

        weights = scores.softmax(dim=-1)

        mixed = einsum("... n m, ... m d -> ... n d", weights, v)
        # Fold the heads back into the feature axis.
        merged = rearrange(mixed, "b h n d -> b n (h d)")
        return self.to_out(merged)
536
+
537
+
538
class Attention(nn.Module):
    """Pre-norm attention layer with learned q and k/v projections.

    Without ``context_features`` the layer attends over its own input
    (self-attention); with it, keys and values are computed from a separate
    ``context`` tensor, which then becomes mandatory in ``forward``.
    """

    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        out_features: Optional[int] = None,
        context_features: Optional[int] = None,
        use_rel_pos: bool,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
    ):
        super().__init__()
        # Keep the raw setting: forward() uses it to decide if context is required
        self.context_features = context_features
        inner_features = head_features * num_heads
        kv_features = default(context_features, features)

        self.norm = nn.LayerNorm(features)
        self.norm_context = nn.LayerNorm(kv_features)
        self.to_q = nn.Linear(
            in_features=features, out_features=inner_features, bias=False
        )
        # Keys and values share one projection and are split in forward()
        self.to_kv = nn.Linear(
            in_features=kv_features, out_features=inner_features * 2, bias=False
        )

        self.attention = AttentionBase(
            features,
            out_features=out_features,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        """Attend from ``x`` over ``context`` (or over ``x`` when none is given)."""
        assert_message = "You must provide a context when using context_features"
        assert not self.context_features or exists(context), assert_message
        # Fall back to self-attention when no context is supplied
        source = default(context, x)
        queries_in = self.norm(x)
        source = self.norm_context(source)
        q = self.to_q(queries_in)
        k, v = torch.chunk(self.to_kv(source), chunks=2, dim=-1)
        return self.attention(q, k, v)
585
+
586
+
587
+ """
588
+ Transformer Blocks
589
+ """
590
+
591
+
592
class TransformerBlock(nn.Module):
    """Residual self-attention, optional cross-attention, then feed-forward.

    Cross-attention is only built (and only run) when ``context_features``
    is given and greater than zero.
    """

    def __init__(
        self,
        features: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        use_rel_pos: bool,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
        context_features: Optional[int] = None,
    ):
        super().__init__()

        self.use_cross_attention = exists(context_features) and context_features > 0

        # Settings shared by the self- and cross-attention layers
        attention_kwargs = dict(
            features=features,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )

        self.attention = Attention(**attention_kwargs)

        if self.use_cross_attention:
            self.cross_attention = Attention(
                context_features=context_features, **attention_kwargs
            )

        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        """Apply each sub-layer with a residual connection around it."""
        x = x + self.attention(x)
        if self.use_cross_attention:
            x = x + self.cross_attention(x, context=context)
        return x + self.feed_forward(x)
636
+
637
+
638
+
639
+ """
640
+ Time Embeddings
641
+ """
642
+
643
+
644
class SinusoidalEmbedding(nn.Module):
    """Fixed sinusoidal embedding of scalar positions.

    ``dim // 2`` geometrically spaced frequencies (from 1 down to 1/10000)
    are applied to the input; the sines and cosines are concatenated, so the
    output width is ``2 * (dim // 2)``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """Embed a 1-D tensor of positions into shape ``(len(x), 2 * (dim // 2))``."""
        device, half_dim = x.device, self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when dim is 2 or 3
        # (half_dim == 1); the single frequency is then exp(0) == 1.
        step = torch.tensor(log(10000) / max(half_dim - 1, 1), device=device)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -step)
        # Outer product: one row per position, one column per frequency
        angles = x.unsqueeze(-1) * freqs
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
655
+
656
+
657
class LearnedPositionalEmbedding(nn.Module):
    """Fourier features with learned frequencies, used for continuous time.

    For an even ``dim`` the output has ``dim + 1`` columns: the raw input
    followed by ``dim // 2`` sines and ``dim // 2`` cosines.
    """

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x: Tensor) -> Tensor:
        """Embed a 1-D tensor ``x`` of shape ``(b,)`` into ``(b, dim + 1)``."""
        column = x[:, None]
        phases = column * self.weights[None, :] * 2 * pi
        # Raw value first, then sin and cos features
        return torch.cat((column, phases.sin(), phases.cos()), dim=-1)
672
+
673
+
674
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Return a learned Fourier time embedding projected to ``out_features``.

    ``LearnedPositionalEmbedding(dim)`` emits ``dim + 1`` columns (the raw
    input plus the sin/cos features), hence the ``dim + 1`` input width of
    the linear layer.
    """
    fourier = LearnedPositionalEmbedding(dim)
    projection = nn.Linear(in_features=dim + 1, out_features=out_features)
    return nn.Sequential(fourier, projection)
679
+
680
class FixedEmbedding(nn.Module):
    """Learned per-position embedding table, broadcast over the batch."""

    def __init__(self, max_length: int, features: int):
        super().__init__()
        self.max_length = max_length
        self.embedding = nn.Embedding(max_length, features)

    def forward(self, x: Tensor) -> Tensor:
        """Return position embeddings of shape ``(batch, length, features)``.

        Only the first two dimensions and the device of ``x`` are used.
        """
        batch_size, length = x.shape[0], x.shape[1]
        assert_message = "Input sequence length must be <= max_length"
        assert length <= self.max_length, assert_message
        positions = torch.arange(length, device=x.device)
        table = self.embedding(positions)
        # "n d -> b n d": the same table for every batch element
        return table.unsqueeze(0).repeat(batch_size, 1, 1)
Modules/diffusion/sampler.py ADDED
@@ -0,0 +1,691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import atan, cos, pi, sin, sqrt
2
+ from typing import Any, Callable, List, Optional, Tuple, Type
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange, reduce
8
+ from torch import Tensor
9
+
10
+ from .utils import *
11
+
12
+ """
13
+ Diffusion Training
14
+ """
15
+
16
+ """ Distributions """
17
+
18
+
19
class Distribution:
    """Interface for samplers of per-example noise levels.

    Subclasses implement ``__call__`` and return ``num_samples`` values on
    ``device``.
    """

    def __call__(self, num_samples: int, device: torch.device):
        raise NotImplementedError()
22
+
23
+
24
class LogNormalDistribution(Distribution):
    """Log-normal noise levels: ``exp(mean + std * N(0, 1))``."""

    def __init__(self, mean: float, std: float):
        self.mean = mean
        self.std = std

    def __call__(
        self, num_samples: int, device: torch.device = torch.device("cpu")
    ) -> Tensor:
        gaussian = torch.randn((num_samples,), device=device)
        return (self.mean + self.std * gaussian).exp()
34
+
35
+
36
class UniformDistribution(Distribution):
    """Noise levels drawn uniformly from ``[0, 1)``."""

    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
        samples = torch.rand(num_samples, device=device)
        return samples
39
+
40
+
41
class VKDistribution(Distribution):
    """Noise levels sampled by inverting the CDF ``2 / pi * atan(sigma / sigma_data)``.

    Samples are restricted to ``[min_value, max_value]``.
    """

    def __init__(
        self,
        min_value: float = 0.0,
        max_value: float = float("inf"),
        sigma_data: float = 1.0,
    ):
        self.min_value = min_value
        self.max_value = max_value
        self.sigma_data = sigma_data

    def __call__(
        self, num_samples: int, device: torch.device = torch.device("cpu")
    ) -> Tensor:
        sigma_data = self.sigma_data
        min_cdf = atan(self.min_value / sigma_data) * 2 / pi
        max_cdf = atan(self.max_value / sigma_data) * 2 / pi
        # Inverse-CDF sampling needs a *uniform* draw in [min_cdf, max_cdf).
        # The previous Gaussian draw (torch.randn) could leave that interval
        # and produce negative or out-of-range sigmas.
        u = (max_cdf - min_cdf) * torch.rand((num_samples,), device=device) + min_cdf
        return torch.tan(u * pi / 2) * sigma_data
60
+
61
+
62
+ """ Diffusion Classes """
63
+
64
+
65
def pad_dims(x: Tensor, ndim: int) -> Tensor:
    """Append ``ndim`` trailing singleton dimensions to ``x``."""
    trailing_ones = (1,) * ndim
    return x.view(*x.shape, *trailing_ones)
68
+
69
+
70
def clip(x: Tensor, dynamic_threshold: float = 0.0):
    """Clamp ``x`` to ``[-1, 1]``, optionally with dynamic thresholding.

    With ``dynamic_threshold == 0.0`` this is a plain clamp. Otherwise, for
    each batch item the ``dynamic_threshold`` quantile of ``|x|`` (but at
    least 1.0) is used as the clamp bound and the result is divided by it.
    """
    if dynamic_threshold == 0.0:
        return x.clamp(-1.0, 1.0)

    # One quantile per batch item, taken over all remaining dimensions
    per_item = x.reshape(x.shape[0], -1)
    bound = torch.quantile(per_item.abs(), dynamic_threshold, dim=-1)
    # Never shrink the range below [-1, 1]
    bound.clamp_(min=1.0)
    # Add trailing singleton dims so the bound broadcasts against x
    bound = bound.view(*bound.shape, *((1,) * (x.ndim - bound.ndim)))
    return x.clamp(-bound, bound) / bound
84
+
85
+
86
def to_batch(
    batch_size: int,
    device: torch.device,
    x: Optional[float] = None,
    xs: Optional[Tensor] = None,
) -> Tensor:
    """Return a per-item tensor from either a scalar or a ready-made batch.

    Exactly one of ``x`` and ``xs`` must be given. A scalar ``x`` is repeated
    ``batch_size`` times and moved to ``device``; a tensor ``xs`` is returned
    unchanged.
    """
    has_scalar = x is not None
    has_batch = xs is not None
    assert has_scalar ^ has_batch, "Either x or xs must be provided"
    if has_scalar:
        # Same value for every batch item
        return torch.full(size=(batch_size,), fill_value=x).to(device)
    return xs
98
+
99
+
100
class Diffusion(nn.Module):
    """Base diffusion class.

    Subclasses set ``alias`` to a short identifier and implement
    ``denoise_fn`` (used when sampling) and ``forward`` (the training loss).
    """

    # Short identifier of the diffusion type; overridden by subclasses
    alias: str = ""

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Denoise ``x_noisy`` given per-item ``sigmas`` or one scalar ``sigma``."""
        raise NotImplementedError("Diffusion class missing denoise_fn")

    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        """Compute the training loss for ``x``."""
        raise NotImplementedError("Diffusion class missing forward function")
117
+
118
+
119
class VDiffusion(Diffusion):
    """Diffusion trained on the v-objective.

    Input and noise are mixed with weights ``cos``/``sin`` of
    ``sigma * pi / 2``; the network is trained to predict
    ``noise * alpha - x * beta``.
    """

    alias = "v"

    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(cos, sin)`` of ``sigmas * pi / 2``."""
        angle = sigmas * pi / 2
        return torch.cos(angle), torch.sin(angle)

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Run the network on ``x_noisy`` with one noise level per batch item."""
        sigmas = to_batch(
            x=sigma,
            xs=sigmas,
            batch_size=x_noisy.shape[0],
            device=x_noisy.device,
        )
        return self.net(x_noisy, sigmas, **kwargs)

    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
        """Return the MSE between the network output and the v-target."""
        # One noise level per batch element
        sigmas = self.sigma_distribution(num_samples=x.shape[0], device=x.device)

        if noise is None:
            noise = torch.randn_like(x)

        # Pad to (b, 1, 1) so the weights broadcast against 3-D inputs
        alpha, beta = self.get_alpha_beta(sigmas[:, None, None])
        noisy_input = x * alpha + noise * beta
        v_target = noise * alpha - x * beta

        prediction = self.denoise_fn(noisy_input, sigmas, **kwargs)
        return F.mse_loss(prediction, v_target)
163
+
164
+
165
class KDiffusion(Diffusion):
    """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364

    The network is wrapped with sigma-dependent input/output/skip scalings
    and trained with a sigma-weighted MSE against the clean input.
    """

    alias = "k"

    def __init__(
        self,
        net: nn.Module,
        *,
        sigma_distribution: Distribution,
        sigma_data: float,  # data distribution standard deviation
        dynamic_threshold: float = 0.0,
    ):
        super().__init__()
        self.net = net
        self.sigma_data = sigma_data
        self.sigma_distribution = sigma_distribution
        # Stored only; not read by any method of this class
        self.dynamic_threshold = dynamic_threshold

    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        """Return ``(c_skip, c_out, c_in, c_noise)`` for per-item ``sigmas``.

        ``c_noise`` keeps shape ``(b,)``; the other three are shaped
        ``(b, 1, 1)`` for broadcasting.
        """
        data_std = self.sigma_data
        c_noise = torch.log(sigmas) * 0.25
        sigmas = sigmas[:, None, None]
        total_var = sigmas ** 2 + data_std ** 2
        c_skip = (data_std ** 2) / total_var
        c_out = sigmas * data_std * (data_std ** 2 + sigmas ** 2) ** -0.5
        c_in = total_var ** -0.5
        return c_skip, c_out, c_in, c_noise

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Return the denoised estimate ``c_skip * x_noisy + c_out * net(...)``."""
        sigmas = to_batch(
            x=sigma,
            xs=sigmas,
            batch_size=x_noisy.shape[0],
            device=x_noisy.device,
        )

        # Scaled network prediction plus skip connection
        c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
        prediction = self.net(c_in * x_noisy, c_noise, **kwargs)
        return c_skip * x_noisy + c_out * prediction

    def loss_weight(self, sigmas: Tensor) -> Tensor:
        """Per-item loss weight ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``."""
        numerator = sigmas ** 2 + self.sigma_data ** 2
        return numerator * (sigmas * self.sigma_data) ** -2

    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
        """Return the sigma-weighted denoising loss for ``x``."""
        batch_size, device = x.shape[0], x.device
        # Local import on purpose: `from .utils import *` re-exports
        # functools.reduce, which shadows the einops `reduce` at module level.
        from einops import rearrange, reduce

        # One noise level per batch element
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")

        # Corrupt the input
        if noise is None:
            noise = torch.randn_like(x)
        x_noisy = x + sigmas_padded * noise

        x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)

        # Mean squared error per batch item, then weighted and averaged
        per_item = reduce(
            F.mse_loss(x_denoised, x, reduction="none"), "b ... -> b", "mean"
        )
        weighted = per_item * self.loss_weight(sigmas)
        return weighted.mean()
235
+
236
+
237
class VKDiffusion(Diffusion):
    """Diffusion with Karras-style scalings and a v-objective target.

    ``sigma_data`` is fixed to 1.0 and the network is conditioned on
    ``t = atan(sigma) * 2 / pi`` instead of on sigma itself.
    """

    alias = "vk"

    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution

    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        """Return ``(c_skip, c_out, c_in)``, each shaped ``(b, 1, 1)``."""
        data_std = 1.0
        sigmas = sigmas[:, None, None]
        c_skip = (data_std ** 2) / (sigmas ** 2 + data_std ** 2)
        # Note the negative sign on c_out
        c_out = -sigmas * data_std * (data_std ** 2 + sigmas ** 2) ** -0.5
        c_in = (sigmas ** 2 + data_std ** 2) ** -0.5
        return c_skip, c_out, c_in

    def sigma_to_t(self, sigmas: Tensor) -> Tensor:
        """Map sigma to ``t = atan(sigma) * 2 / pi``."""
        return sigmas.atan() / pi * 2

    def t_to_sigma(self, t: Tensor) -> Tensor:
        """Inverse of ``sigma_to_t``."""
        return (t * pi / 2).tan()

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Return the denoised estimate ``c_skip * x_noisy + c_out * net(...)``."""
        sigmas = to_batch(
            x=sigma,
            xs=sigmas,
            batch_size=x_noisy.shape[0],
            device=x_noisy.device,
        )

        # Scaled network prediction plus skip connection
        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
        prediction = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
        return c_skip * x_noisy + c_out * prediction

    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
        """Return the MSE between the raw network output and the v-target."""
        # One noise level per batch element
        sigmas = self.sigma_distribution(num_samples=x.shape[0], device=x.device)

        # Corrupt the input
        if noise is None:
            noise = torch.randn_like(x)
        x_noisy = x + sigmas[:, None, None] * noise

        # Raw network output (not yet combined with the skip connection)
        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
        prediction = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)

        # Target that makes c_skip * x_noisy + c_out * target equal x;
        # the small constant guards the division
        v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)

        return F.mse_loss(prediction, v_target)
297
+
298
+
299
+ """
300
+ Diffusion Sampling
301
+ """
302
+
303
+ """ Schedules """
304
+
305
+
306
class Schedule(nn.Module):
    """Interface implemented by the sampling schedules.

    ``forward`` maps a number of steps to a tensor of noise levels.
    """

    def forward(self, num_steps: int, device: torch.device) -> Tensor:
        raise NotImplementedError()
311
+
312
+
313
class LinearSchedule(Schedule):
    """``num_steps`` noise levels decreasing linearly from 1 towards 0 (0 excluded)."""

    def forward(self, num_steps: int, device: Any) -> Tensor:
        # Create the schedule on the requested device (it was previously
        # ignored, leaving the sigmas on the CPU).
        sigmas = torch.linspace(1, 0, num_steps + 1, device=device)[:-1]
        return sigmas
317
+
318
+
319
class KarrasSchedule(Schedule):
    """https://arxiv.org/abs/2206.00364 equation 5

    Interpolates between ``sigma_max`` and ``sigma_min`` in
    ``sigma ** (1 / rho)`` space and appends a final 0, so the returned
    tensor has ``num_steps + 1`` entries.
    """

    def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def forward(self, num_steps: int, device: Any) -> Tensor:
        rho_inv = 1.0 / self.rho
        steps = torch.arange(num_steps, device=device, dtype=torch.float32)
        # max(..., 1) avoids 0 / 0 = NaN for a single-step schedule, which
        # then starts (and stays) at sigma_max.
        last_step = max(num_steps - 1, 1)
        sigmas = (
            self.sigma_max ** rho_inv
            + (steps / last_step)
            * (self.sigma_min ** rho_inv - self.sigma_max ** rho_inv)
        ) ** self.rho
        # Trailing sigma = 0
        sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
        return sigmas
338
+
339
+
340
+ """ Samplers """
341
+
342
+
343
class Sampler(nn.Module):
    """Interface for samplers that turn noise into a sample.

    ``diffusion_types`` lists the diffusion classes a sampler can be
    combined with.
    """

    diffusion_types: List[Type[Diffusion]] = []

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Generate a sample from ``noise`` using denoiser ``fn``."""
        raise NotImplementedError()

    def inpaint(
        self,
        source: Tensor,
        mask: Tensor,
        fn: Callable,
        sigmas: Tensor,
        num_steps: int,
        num_resamples: int,
    ) -> Tensor:
        """Inpaint ``source``; optional, not supported unless overridden."""
        raise NotImplementedError("Inpainting not available with current sampler")
362
+
363
+
364
class VSampler(Sampler):
    """Sampler for ``VDiffusion`` models."""

    diffusion_types = [VDiffusion]

    def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
        """Return ``(cos, sin)`` of ``sigma * pi / 2``."""
        angle = sigma * pi / 2
        alpha = cos(angle)
        beta = sin(angle)
        return alpha, beta

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Run ``num_steps - 1`` denoising steps and return the last prediction.

        Raises:
            ValueError: if ``num_steps < 2`` (no step would run, so there is
                no prediction to return).
        """
        if num_steps < 2:
            raise ValueError("VSampler requires num_steps >= 2")

        x = sigmas[0] * noise
        alpha, beta = self.get_alpha_beta(sigmas[0].item())

        last_index = num_steps - 2
        for i in range(num_steps - 1):
            # The loop ends at num_steps - 2; the old check against
            # num_steps - 1 never matched.
            is_last = i == last_index

            x_denoised = fn(x, sigma=sigmas[i])
            x_pred = x * alpha - x_denoised * beta
            x_eps = x * beta + x_denoised * alpha

            # Re-noise to the next level only if another step follows;
            # the returned x_pred does not depend on it.
            if not is_last:
                alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
                x = x_pred * alpha + x_eps * beta

        return x_pred
392
+
393
+
394
class KarrasSampler(Sampler):
    """https://arxiv.org/abs/2206.00364 algorithm 1

    ``step`` implements the stochastic second-order (Heun) update of
    algorithm 2 in the same paper.
    """

    diffusion_types = [KDiffusion, VKDiffusion]

    def __init__(
        self,
        s_tmin: float = 0,
        s_tmax: float = float("inf"),
        s_churn: float = 0.0,
        s_noise: float = 1.0,
    ):
        super().__init__()
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise
        self.s_churn = s_churn

    def step(
        self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
    ) -> Tensor:
        """Algorithm 2 (step)"""
        # Select temporarily increased noise level
        sigma_hat = sigma + gamma * sigma
        # Add noise to move from sigma to sigma_hat
        epsilon = self.s_noise * torch.randn_like(x)
        x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
        # Evaluate ∂x/∂sigma at sigma_hat
        d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
        # Take euler step from sigma_hat to sigma_next
        x_next = x_hat + (sigma_next - sigma_hat) * d
        # Second order correction
        if sigma_next != 0:
            model_out_next = fn(x_next, sigma=sigma_next)
            d_prime = (x_next - model_out_next) / sigma_next
            # The step size is (sigma_next - sigma_hat), as in the Euler
            # step above. The former (sigma - sigma_hat) is zero whenever
            # gamma == 0, which made the whole step a no-op.
            x_next = x_hat + 0.5 * (sigma_next - sigma_hat) * (d + d_prime)
        return x_next

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Run ``num_steps - 1`` steps along ``sigmas`` starting from scaled noise."""
        x = sigmas[0] * noise
        # Compute gammas: extra noise only for sigmas inside [s_tmin, s_tmax]
        gammas = torch.where(
            (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
            min(self.s_churn / num_steps, sqrt(2) - 1),
            0.0,
        )
        # Denoise to sample
        for i in range(num_steps - 1):
            x = self.step(
                x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i]  # type: ignore # noqa
            )

        return x
449
+
450
+
451
class AEulerSampler(Sampler):
    """Euler sampler that re-injects noise after every step."""

    diffusion_types = [KDiffusion, VKDiffusion]

    def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
        """Split the move to ``sigma_next`` into ``(sigma_up, sigma_down)``.

        ``sigma_down`` is the level reached deterministically and
        ``sigma_up`` the noise added back; their squares sum to
        ``sigma_next ** 2``.
        """
        up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
        down = sqrt(sigma_next ** 2 - up ** 2)
        return up, down

    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
        """Take one Euler step to ``sigma_down``, then add ``sigma_up`` noise."""
        sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
        # Derivative at sigma (∂x/∂sigma)
        slope = (x - fn(x, sigma=sigma)) / sigma
        # Deterministic Euler update
        stepped = x + slope * (sigma_down - sigma)
        # Fresh noise
        return stepped + torch.randn_like(x) * sigma_up

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Run ``num_steps - 1`` steps along ``sigmas`` starting from scaled noise."""
        x = sigmas[0] * noise
        for index in range(num_steps - 1):
            current, following = sigmas[index], sigmas[index + 1]
            x = self.step(x, fn=fn, sigma=current, sigma_next=following)  # type: ignore # noqa
        return x
479
+
480
+
481
class ADPM2Sampler(Sampler):
    """https://www.desmos.com/calculator/jbxjlqd9mb

    Second-order sampler that evaluates the denoiser at a midpoint noise
    level and re-injects noise after each step. Also supports inpainting.
    """

    diffusion_types = [KDiffusion, VKDiffusion]

    def __init__(self, rho: float = 1.0):
        super().__init__()
        self.rho = rho

    def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
        """Return ``(sigma_up, sigma_down, sigma_mid)`` for one step.

        ``sigma_mid`` is the mean of ``sigma`` and ``sigma_down`` taken in
        ``sigma ** (1 / rho)`` space.
        """
        exponent = 1 / self.rho
        up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
        down = sqrt(sigma_next ** 2 - up ** 2)
        mid = ((sigma ** exponent + down ** exponent) / 2) ** self.rho
        return up, down, mid

    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
        """Take one midpoint step to ``sigma_down``, then add ``sigma_up`` noise."""
        sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
        # Derivative at sigma (∂x/∂sigma)
        slope = (x - fn(x, sigma=sigma)) / sigma
        # Provisional move to the midpoint level
        x_mid = x + slope * (sigma_mid - sigma)
        # Derivative at sigma_mid (∂x_mid/∂sigma_mid)
        slope_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
        # Full move using the midpoint slope
        x = x + slope_mid * (sigma_down - sigma)
        # Fresh noise
        return x + torch.randn_like(x) * sigma_up

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Run ``num_steps - 1`` steps along ``sigmas`` starting from scaled noise."""
        x = sigmas[0] * noise
        for index in range(num_steps - 1):
            current, following = sigmas[index], sigmas[index + 1]
            x = self.step(x, fn=fn, sigma=current, sigma_next=following)  # type: ignore # noqa
        return x

    def inpaint(
        self,
        source: Tensor,
        mask: Tensor,
        fn: Callable,
        sigmas: Tensor,
        num_steps: int,
        num_resamples: int,
    ) -> Tensor:
        """Regenerate the region where ``mask`` is False, keeping ``source`` elsewhere.

        ``mask`` must be boolean because it is inverted with ``~``. At every
        noise level the step is repeated ``num_resamples`` times, re-noising
        between repetitions.
        """
        x = sigmas[0] * torch.randn_like(source)

        for index in range(num_steps - 1):
            current, following = sigmas[index], sigmas[index + 1]
            # Bring the known content to the current noise level
            noisy_source = source + current * torch.randn_like(source)
            for resample in range(num_resamples):
                # Overwrite the known region, then denoise everything
                x = noisy_source * mask + x * ~mask
                x = self.step(x, fn=fn, sigma=current, sigma_next=following)  # type: ignore # noqa
                # Go back up to the current level unless this was the last pass
                if resample < num_resamples - 1:
                    renoise = sqrt(current ** 2 - following ** 2)
                    x = x + renoise * torch.randn_like(x)

        return source * mask + x * ~mask
545
+
546
+
547
+ """ Main Classes """
548
+
549
+
550
class DiffusionSampler(nn.Module):
    """Bundle a diffusion model's denoiser with a sampler and a sigma schedule.

    Construction asserts that the sampler lists a diffusion type with the
    same ``alias`` as ``diffusion``.
    """

    def __init__(
        self,
        diffusion: Diffusion,
        *,
        sampler: Sampler,
        sigma_schedule: Schedule,
        num_steps: Optional[int] = None,
        clamp: bool = True,
    ):
        super().__init__()
        self.denoise_fn = diffusion.denoise_fn
        self.sampler = sampler
        self.sigma_schedule = sigma_schedule
        self.num_steps = num_steps
        self.clamp = clamp

        # Compatibility check between sampler and diffusion type
        supported_aliases = [t.alias for t in sampler.diffusion_types]
        sampler_class = sampler.__class__.__name__
        diffusion_class = diffusion.__class__.__name__
        message = f"{sampler_class} incompatible with {diffusion_class}"
        assert diffusion.alias in supported_aliases, message

    def forward(
        self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
    ) -> Tensor:
        """Sample from ``noise``; extra ``kwargs`` are forwarded to the denoiser.

        ``num_steps`` falls back to the value given at construction; the
        result is clamped to ``[-1, 1]`` when ``self.clamp`` is set.
        """
        num_steps = default(num_steps, self.num_steps)  # type: ignore
        assert exists(num_steps), "Parameter `num_steps` must be provided"
        sigmas = self.sigma_schedule(num_steps, noise.device)

        def fn(*args, **call_kwargs):
            # Caller-level kwargs (e.g. conditioning for a conditional
            # unet) override per-call ones
            return self.denoise_fn(*args, **{**call_kwargs, **kwargs})

        sample = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
        if self.clamp:
            sample = sample.clamp(-1.0, 1.0)
        return sample
587
+
588
+
589
class DiffusionInpainter(nn.Module):
    """Bundle a diffusion model's denoiser with a sampler's ``inpaint`` routine."""

    def __init__(
        self,
        diffusion: Diffusion,
        *,
        num_steps: int,
        num_resamples: int,
        sampler: Sampler,
        sigma_schedule: Schedule,
    ):
        super().__init__()
        self.denoise_fn = diffusion.denoise_fn
        self.num_steps = num_steps
        self.num_resamples = num_resamples
        self.inpaint_fn = sampler.inpaint
        self.sigma_schedule = sigma_schedule

    @torch.no_grad()
    def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
        """Inpaint ``inpaint`` under ``inpaint_mask`` with gradients disabled."""
        sigmas = self.sigma_schedule(self.num_steps, inpaint.device)
        return self.inpaint_fn(
            source=inpaint,
            mask=inpaint_mask,
            fn=self.denoise_fn,
            sigmas=sigmas,
            num_steps=self.num_steps,
            num_resamples=self.num_resamples,
        )
617
+
618
+
619
def sequential_mask(like: Tensor, start: int) -> Tensor:
    """Return a boolean mask shaped like ``like``.

    The mask is True before index ``start`` along dimension 2 and False
    from ``start`` onwards.
    """
    tail_length = like.shape[2] - start
    mask = torch.ones_like(like, dtype=torch.bool)
    # Zeros are cast to False on assignment into the boolean mask
    mask[:, :, start:] = torch.zeros((tail_length,), device=like.device)
    return mask
624
+
625
+
626
class SpanBySpanComposer(nn.Module):
    """Extend a sequence by repeatedly inpainting the second half of a window.

    Each round keeps the first half of the window as known context, inpaints
    the second half, then moves the generated half into the context slot.
    """

    def __init__(
        self,
        inpainter: DiffusionInpainter,
        *,
        num_spans: int,
    ):
        super().__init__()
        self.inpainter = inpainter
        self.num_spans = num_spans

    def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
        """Generate ``num_spans`` half-length spans continuing ``start``.

        With ``keep_start`` the two halves of ``start`` are prepended to the
        output. Spans are concatenated along dimension 2.
        """
        half = start.shape[2] // 2

        pieces = list(start.chunk(chunks=2, dim=-1)) if keep_start else []
        # Window whose first half holds the context (second half of `start`)
        window = torch.zeros_like(start)
        window[:, :, :half] = start[:, :, half:]
        window_mask = sequential_mask(like=start, start=half)

        for _ in range(self.num_spans):
            generated = self.inpainter(inpaint=window, inpaint_mask=window_mask)
            new_half = generated[:, :, half:]
            # The new half becomes the context of the next round
            window[:, :, :half] = new_half
            pieces.append(new_half)

        return torch.cat(pieces, dim=2)
656
+
657
+
658
class XDiffusion(nn.Module):
    """Wrapper that selects a diffusion implementation by its alias.

    ``type`` must be the ``alias`` of ``VDiffusion``, ``KDiffusion`` or
    ``VKDiffusion``; remaining keyword arguments go to that class.
    """

    def __init__(self, type: str, net: nn.Module, **kwargs):
        super().__init__()

        diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
        aliases = [t.alias for t in diffusion_classes]  # type: ignore
        message = f"type='{type}' must be one of {*aliases,}"
        assert type in aliases, message
        self.net = net

        # Instantiate the class whose alias matches
        for diffusion_cls in diffusion_classes:
            if diffusion_cls.alias == type:  # type: ignore
                self.diffusion = diffusion_cls(net=net, **kwargs)

    def forward(self, *args, **kwargs) -> Tensor:
        """Delegate to the wrapped diffusion's ``forward``."""
        return self.diffusion(*args, **kwargs)

    def sample(
        self,
        noise: Tensor,
        num_steps: int,
        sigma_schedule: Schedule,
        sampler: Sampler,
        clamp: bool,
        **kwargs,
    ) -> Tensor:
        """Sample from ``noise`` via a freshly built ``DiffusionSampler``.

        Extra ``kwargs`` are forwarded to the sampler call.
        """
        runner = DiffusionSampler(
            diffusion=self.diffusion,
            sampler=sampler,
            sigma_schedule=sigma_schedule,
            num_steps=num_steps,
            clamp=clamp,
        )
        return runner(noise, **kwargs)
Modules/diffusion/utils.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import reduce
2
+ from inspect import isfunction
3
+ from math import ceil, floor, log2, pi
4
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch import Generator, Tensor
10
+ from typing_extensions import TypeGuard
11
+
12
+ T = TypeVar("T")
13
+
14
+
15
+ def exists(val: Optional[T]) -> TypeGuard[T]:
16
+ return val is not None
17
+
18
+
19
def iff(condition: bool, value: T) -> Optional[T]:
    """Return ``value`` when ``condition`` is truthy, otherwise None."""
    if condition:
        return value
    return None
21
+
22
+
23
+ def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
24
+ return isinstance(obj, list) or isinstance(obj, tuple)
25
+
26
+
27
def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    """Return ``val`` unless it is None, else the fallback ``d``.

    A fallback that is a plain Python function (including a lambda) is
    called to produce the value; any other object, callable or not, is
    returned as-is.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
31
+
32
+
33
def to_list(val: Union[T, Sequence[T]]) -> List[T]:
    """Coerce ``val`` to a list.

    A list is returned unchanged (same object), a tuple is converted, and
    anything else is wrapped in a one-element list.
    """
    if isinstance(val, list):
        return val
    if isinstance(val, tuple):
        return list(val)
    return [val]  # type: ignore
39
+
40
+
41
def prod(vals: Sequence[int]) -> int:
    """Return the product of ``vals``; the empty product is 1."""
    # The explicit initial value makes an empty sequence return 1 instead
    # of raising TypeError.
    return reduce(lambda acc, val: acc * val, vals, 1)
43
+
44
+
45
def closest_power_2(x: float) -> int:
    """Return the power of two nearest to ``x``.

    Only the two powers bracketing ``x`` are considered; when both are
    equally far away the smaller one is returned.
    """
    exponent = log2(x)
    lower, upper = floor(exponent), ceil(exponent)
    # Strict comparison keeps the lower exponent on ties
    if abs(x - 2 ** upper) < abs(x - 2 ** lower):
        chosen = upper
    else:
        chosen = lower
    return 2 ** int(chosen)
50
+
51
def rand_bool(shape, proba, device = None):
    """Return a boolean tensor of ``shape`` whose entries are True with probability ``proba``.

    ``proba`` equal to 1 or 0 yields an all-True / all-False tensor without
    drawing random numbers.
    """
    if proba == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if proba == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    probabilities = torch.full(shape, proba, device=device)
    return torch.bernoulli(probabilities).to(torch.bool)
58
+
59
+
60
+ """
61
+ Kwargs Utils
62
+ """
63
+
64
+
65
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    """Split ``d`` into (entries whose key starts with ``prefix``, the rest).

    Keys are kept unchanged, prefix included.
    """
    with_prefix: Dict = {}
    without_prefix: Dict = {}
    for key in d.keys():
        bucket = with_prefix if key.startswith(prefix) else without_prefix
        bucket[key] = d[key]
    return with_prefix, without_prefix
71
+
72
+
73
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    """Split ``d`` by key prefix, stripping the prefix unless ``keep_prefix``.

    Returns:
        ``(matching, rest)``: the entries whose key starts with ``prefix``
        and all other entries.
    """
    matching, rest = group_dict_by_prefix(prefix, d)
    if not keep_prefix:
        cut = len(prefix)
        matching = {key[cut:]: value for key, value in matching.items()}
    return matching, rest
79
+
80
+
81
def prefix_dict(prefix: str, d: Dict) -> Dict:
    """Return a copy of ``d`` with ``prefix`` prepended to the ``str()`` of each key."""
    prefixed: Dict = {}
    for key, value in d.items():
        prefixed[prefix + str(key)] = value
    return prefixed
Modules/discriminators.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, spectral_norm
6
+
7
+ from .utils import get_padding
8
+
9
+ LRELU_SLOPE = 0.1
10
+
11
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window tensor handed straight to ``torch.stft``.

    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    x_stft = torch.stft(x, fft_size, hop_size, win_length, window,
                        return_complex=True)
    # The result is complex: take the magnitude directly. Indexing the last
    # axis (frames) for "real"/"imag" parts was dead code and raised
    # IndexError whenever the input produced a single frame.
    return torch.abs(x_stft).transpose(2, 1)
28
+
29
class SpecDiscriminator(nn.Module):
    """2-D convolutional discriminator operating on a magnitude spectrogram.

    Args:
        fft_size (int): FFT size passed to ``stft``.
        shift_size (int): Hop size passed to ``stft``.
        win_length (int): Window length passed to ``stft``.
        window (str): Name of a ``torch`` window factory (e.g. "hann_window"),
            called with ``win_length``.
        use_spectral_norm (bool): Wrap convolutions in spectral norm instead
            of weight norm.
    """

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
        super(SpecDiscriminator, self).__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        # Kept as a plain attribute (not a buffer) so state_dict keys are unchanged.
        self.window = getattr(torch, window)(win_length)
        self.discriminators = nn.ModuleList([
            norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
        ])

        self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))

    def forward(self, y):
        """Score a waveform batch.

        Dimension 1 of ``y`` is squeezed before the STFT and a channel
        dimension is re-added afterwards.

        Returns:
            tuple: (flattened scores, list of feature maps from every layer
            including the output layer).
        """
        fmap = []
        y = y.squeeze(1)
        # Use y.device rather than y.get_device(): the latter returns -1 for
        # CPU tensors, which is not a valid target for Tensor.to().
        y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
        y = y.unsqueeze(1)
        for d in self.discriminators:
            y = d(y)
            y = F.leaky_relu(y, LRELU_SLOPE)
            fmap.append(y)

        y = self.out(y)
        fmap.append(y)

        return torch.flatten(y, 1, -1), fmap
64
+
65
class MultiResSpecDiscriminator(torch.nn.Module):
    """Bank of ``SpecDiscriminator`` modules at several STFT resolutions.

    One discriminator is built per (fft_size, hop_size, win_length) triple,
    so any number of resolutions can be configured; the defaults give the
    same three discriminators as before.

    Args:
        fft_sizes (list[int]): FFT size per resolution.
        hop_sizes (list[int]): Hop size per resolution.
        win_lengths (list[int]): Window length per resolution.
        window (str): Window factory name forwarded to every discriminator.
    """

    def __init__(self,
                 fft_sizes=[1024, 2048, 512],
                 hop_sizes=[120, 240, 50],
                 win_lengths=[600, 1200, 240],
                 window="hann_window"):

        super(MultiResSpecDiscriminator, self).__init__()
        # The default lists are only read, never mutated.
        self.discriminators = nn.ModuleList([
            SpecDiscriminator(fft_size, hop_size, win_length, window)
            for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, win_lengths)
        ])

    def forward(self, y, y_hat):
        """Run every discriminator on the real (``y``) and generated (``y_hat``) input.

        Returns:
            tuple: (real scores, generated scores, real feature maps,
            generated feature maps), each a list with one entry per
            discriminator.
        """
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
94
+
95
+
96
class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the waveform into a 2-D grid of width ``period``.

    Args:
        period (int): Number of columns the time axis is folded into.
        kernel_size (int): Kernel height of the strided convolutions.
        stride (int): Stride (along the folded time axis) of the first four
            convolutions.
        use_spectral_norm (bool): Use spectral norm instead of weight norm.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        # Padding follows kernel_size instead of a hard-coded 5, so non-default
        # kernels are padded consistently.
        # NOTE(review): assumes get_padding(5, 1) == 2, so the default layout is unchanged.
        pad = (get_padding(kernel_size, 1), 0)
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=pad)),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=pad)),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=pad)),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=pad)),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=pad)),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Score a (batch, channels, time) input.

        Returns:
            tuple: (flattened scores, list of feature maps from every layer
            including the output layer).
        """
        fmap = []

        # 1d to 2d: right-pad (reflect) the time axis to a multiple of the
        # period, then fold it into (time // period, period).
        b, c, t = x.shape
        if t % self.period != 0:
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
130
+
131
+
132
class MultiPeriodDiscriminator(torch.nn.Module):
    """Ensemble of ``DiscriminatorP`` modules with periods 2, 3, 5, 7 and 11."""

    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(period) for period in (2, 3, 5, 7, 11)]
        )

    def forward(self, y, y_hat):
        """Run every period discriminator on the real (``y``) and generated (``y_hat``) input.

        Returns:
            tuple: (real scores, generated scores, real feature maps,
            generated feature maps), each a list with one entry per
            discriminator.
        """
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []
        for disc in self.discriminators:
            score_real, fmap_real = disc(y)
            score_fake, fmap_fake = disc(y_hat)
            real_scores.append(score_real)
            real_fmaps.append(fmap_real)
            fake_scores.append(score_fake)
            fake_fmaps.append(fmap_fake)

        return real_scores, fake_scores, real_fmaps, fake_fmaps
157
+
158
class WavLMDiscriminator(nn.Module):
    """1-D convolutional discriminator over stacked hidden features.

    The input is expected to have ``slm_hidden * slm_layers`` channels (the
    in-channel count of the first convolution).

    Args:
        slm_hidden (int): Hidden size of one feature layer.
        slm_layers (int): Number of feature layers stacked along channels.
        initial_channel (int): Channel count after the 1x1 input projection.
        use_spectral_norm (bool): Use spectral norm instead of weight norm.
    """

    def __init__(self, slm_hidden=768,
                 slm_layers=13,
                 initial_channel=64,
                 use_spectral_norm=False):
        super(WavLMDiscriminator, self).__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.pre = norm_f(Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0))

        self.convs = nn.ModuleList([
            norm_f(nn.Conv1d(initial_channel, initial_channel * 2, kernel_size=5, padding=2)),
            norm_f(nn.Conv1d(initial_channel * 2, initial_channel * 4, kernel_size=5, padding=2)),
            norm_f(nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)),
        ])

        self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return the flattened score tensor for ``x``.

        Unlike the other discriminators in this module, only the scores are
        returned (no feature maps), so none are collected.
        """
        x = self.pre(x)

        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
        x = self.conv_post(x)
        x = torch.flatten(x, 1, -1)

        return x
Modules/hifigan.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+ from .utils import init_weights, get_padding
7
+
8
+ import math
9
+ import random
10
+ import numpy as np
11
+
12
+ LRELU_SLOPE = 0.1
13
+
14
class AdaIN1d(nn.Module):
    """Adaptive instance normalisation for 1-D feature maps.

    A linear layer maps the style vector to ``2 * num_features`` values that
    are split into a per-channel scale offset and a shift, applied to the
    instance-normalised input.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm1d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        """Return ``(1 + gamma) * norm(x) + beta`` with gamma/beta predicted from ``s``."""
        style = self.fc(s)
        # Add a trailing axis so scale and shift broadcast over the last axis of x.
        style = style.view(style.size(0), style.size(1), 1)
        scale, shift = style.chunk(2, dim=1)
        normalized = self.norm(x)
        return normalized * (1 + scale) + shift
25
+
26
class AdaINResBlock1(torch.nn.Module):
    """Residual block of three (AdaIN, Snake, dilated conv, AdaIN, Snake, conv) stages.

    Args:
        channels (int): Channel count, preserved by every convolution.
        kernel_size (int): Kernel size of all convolutions.
        dilation (sequence of 3 ints): Dilation of the first convolution of
            each stage; the second convolution always uses dilation 1.
        style_dim (int): Size of the style vector fed to the AdaIN layers.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
        super(AdaINResBlock1, self).__init__()

        def make_conv(rate):
            # Weight-normed conv padded via get_padding for this dilation.
            return weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=rate,
                                      padding=get_padding(kernel_size, rate)))

        self.convs1 = nn.ModuleList([make_conv(dilation[idx]) for idx in range(3)])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

        self.adain1 = nn.ModuleList([AdaIN1d(style_dim, channels) for _ in range(3)])
        self.adain2 = nn.ModuleList([AdaIN1d(style_dim, channels) for _ in range(3)])

        # One learnable Snake frequency per channel and per convolution.
        self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for _ in self.convs1])
        self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for _ in self.convs2])

    @staticmethod
    def _snake(h, alpha):
        """Snake1D activation: ``h + sin(alpha * h) ** 2 / alpha``."""
        return h + (1 / alpha) * (torch.sin(alpha * h) ** 2)

    def forward(self, x, s):
        """Apply the three residual stages to ``x`` conditioned on style ``s``."""
        stages = zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2)
        for conv_a, conv_b, norm_a, norm_b, alpha_a, alpha_b in stages:
            h = self._snake(norm_a(x, s), alpha_a)
            h = conv_a(h)
            h = self._snake(norm_b(h, s), alpha_b)
            h = conv_b(h)
            x = h + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution of the block."""
        for conv in (*self.convs1, *self.convs2):
            remove_weight_norm(conv)
81
+
82
class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, upsample_scale, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    upsample_scale: factor by which the per-step phase increments are
        downsampled before accumulation and the phase is upsampled afterwards
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0):
        """Return a float32 voiced mask: 1 where ``f0 > voiced_threshold``, else 0."""
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
            where dim indicates fundamental tone and overtones
        """
        # convert to F0 in rad. The integer part n can be ignored
        # because 2 * np.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2],
                              device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            # Accumulate the phase at a reduced rate (downsample the
            # increments, cumsum, rescale and upsample the phase).
            rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
                                                         scale_factor=1/self.upsample_scale,
                                                         mode="linear").transpose(1, 2)

            phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
            phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
                                                    scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
            sines = torch.sin(phase)

        else:
            # If necessary, make sure that the first time step of every
            # voiced segments is sin(pi) or cos(0)
            # This is used for pulse-train generation

            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batch needs to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segments
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # get the sines
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv, noise = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
                  f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        output noise: the additive noise already included in sine_tensor
        """
        # Harmonic multipliers 1..harmonic_num+1, created directly on f0's
        # device (no per-call host tensor followed by a copy).
        harmonics = torch.arange(1, self.harmonic_num + 2, device=f0.device,
                                 dtype=torch.float32).view(1, 1, -1)
        # fundamental component and overtones
        fn = torch.multiply(f0, harmonics)

        # generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
        #        for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
219
+
220
+
221
class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    upsample_scale: forwarded to the sine generator
    harmonic_num: number of harmonic above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshod: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source, uv = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super().__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # sine generator producing the fundamental plus overtones
        self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # learnable merge of all harmonics into a single excitation channel
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source, uv = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # harmonic branch: the sine generation itself is not differentiated,
        # only the merge layer below receives gradients
        with torch.no_grad():
            harmonics, uv, _ = self.l_sin_gen(x)
        merged = self.l_linear(harmonics)
        sine_merge = self.l_tanh(merged)

        # noise branch, same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
269
def padDiff(x):
    """Return the forward difference of ``x`` along its second-to-last axis.

    The first pad drops the leading step and appends a zero, so subtracting
    ``x`` gives ``x[t + 1] - x[t]``; the second pad trims the final step,
    whose value is only the negated last input.
    """
    shifted = F.pad(x, (0, 0, -1, 1), 'constant', 0)
    delta = shifted - x
    return F.pad(delta, (0, 0, 0, -1), 'constant', 0)
271
+
272
class Generator(torch.nn.Module):
    """Style-conditioned upsampling vocoder with a harmonic-plus-noise source.

    Each upsampling stage applies a Snake activation, a transposed
    convolution, adds the down-projected harmonic source, and averages a
    bank of ``AdaINResBlock1`` blocks. The source module is built for a
    sampling rate of 24000 Hz.

    Args:
        style_dim (int): Size of the style vector.
        resblock_kernel_sizes (list[int]): Kernel size per residual block of a stage.
        upsample_rates (list[int]): Upsampling factor per stage.
        upsample_initial_channel (int): Channel count of the generator input;
            halved after every stage.
        resblock_dilation_sizes (list[list[int]]): Dilations per residual block.
        upsample_kernel_sizes (list[int]): Transposed-conv kernel size per stage.
    """

    def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        resblock = AdaINResBlock1

        self.m_source = SourceModuleHnNSF(
            sampling_rate=24000,
            upsample_scale=np.prod(upsample_rates),
            harmonic_num=8, voiced_threshod=10)

        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
        self.noise_convs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.noise_res = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            c_cur = upsample_initial_channel // (2 ** (i + 1))

            self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel//(2**i),
                                                        upsample_initial_channel//(2**(i+1)),
                                                        k, u, padding=(u//2 + u%2), output_padding=u%2)))

            if i + 1 < len(upsample_rates):
                # Strided conv brings the full-rate source down to this stage's rate.
                stride_f0 = np.prod(upsample_rates[i + 1:])
                self.noise_convs.append(Conv1d(
                    1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
                self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
            else:
                # Last stage already runs at the output rate.
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
                self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))

        self.resblocks = nn.ModuleList()

        # Snake frequencies: one before the first stage, one after each stage.
        self.alphas = nn.ParameterList()
        self.alphas.append(nn.Parameter(torch.ones(1, upsample_initial_channel, 1)))

        for i in range(len(self.ups)):
            ch = upsample_initial_channel//(2**(i+1))
            self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))

            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock(ch, k, d, style_dim))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, s, f0):
        """Synthesise a waveform from features ``x``, style ``s`` and frame-rate ``f0``."""
        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t

        har_source, noi_source, uv = self.m_source(f0)
        har_source = har_source.transpose(1, 2)

        for i in range(self.num_upsamples):
            x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
            x_source = self.noise_convs[i](har_source)
            x_source = self.noise_res[i](x_source, s)

            x = self.ups[i](x)
            x = x + x_source

            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x, s)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x, s)
            x = xs / self.num_kernels
        x = x + (1 / self.alphas[i+1]) * (torch.sin(self.alphas[i+1] * x) ** 2)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from all weight-normed layers.

        This generator has no ``conv_pre`` layer, so none is touched; the
        source-path residual blocks (``noise_res``) are included because
        their convolutions are weight-normed as well.
        """
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        for l in self.noise_res:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_post)
357
+
358
+
359
class AdainResBlk1d(nn.Module):
    """Style-conditioned residual block with optional 2x temporal upsampling.

    Args:
        dim_in (int): Input channels.
        dim_out (int): Output channels; a 1x1 conv shortcut is learned when
            it differs from ``dim_in``.
        style_dim (int): Size of the style vector.
        actv (nn.Module): Activation applied after each AdaIN layer.
        upsample: ``'none'`` disables upsampling; any other value enables it.
        dropout_p (float): Dropout probability before each convolution.
    """

    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
                 upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)

        # Residual path: learned depthwise transposed conv doubles the length
        # when upsampling is enabled, identity otherwise.
        if upsample != 'none':
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
        else:
            self.pool = nn.Identity()

    def _build_weights(self, dim_in, dim_out, style_dim):
        """Create the convolutions, AdaIN layers and optional 1x1 shortcut."""
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        """Skip path: upsample, then project channels when needed."""
        skip = self.upsample(x)
        return self.conv1x1(skip) if self.learned_sc else skip

    def _residual(self, x, s):
        """Main path: (AdaIN, activation, pool, conv) then (AdaIN, activation, conv)."""
        h = self.actv(self.norm1(x, s))
        h = self.pool(h)
        h = self.conv1(self.dropout(h))
        h = self.actv(self.norm2(h, s))
        h = self.conv2(self.dropout(h))
        return h

    def forward(self, x, s):
        """Return ``(residual + shortcut) / sqrt(2)``."""
        residual = self._residual(x, s)
        skip = self._shortcut(x)
        return (residual + skip) / math.sqrt(2)
404
+
405
class UpSample1d(nn.Module):
    """Nearest-neighbour 2x upsampling, or identity when ``layer_type == 'none'``."""

    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        """Return ``x`` unchanged for ``'none'``, otherwise ``x`` upsampled by 2."""
        if self.layer_type != 'none':
            return F.interpolate(x, scale_factor=2, mode='nearest')
        return x
415
+
416
class Decoder(nn.Module):
    """Decode text-aligned features, F0 and energy curves into a waveform.

    The features are concatenated with strided-conv projections of the F0
    and energy curves, passed through style-conditioned residual blocks and
    finally through ``Generator``.

    Args:
        dim_in (int): Channel count of the ``asr`` input.
        F0_channel (int): Accepted for interface compatibility; not used here.
        style_dim (int): Size of the style vector.
        dim_out (int): Accepted for interface compatibility; not used here.
        resblock_kernel_sizes, upsample_rates, upsample_initial_channel,
        resblock_dilation_sizes, upsample_kernel_sizes: forwarded to ``Generator``.
    """

    def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
                 resblock_kernel_sizes = [3,7,11],
                 upsample_rates = [10,5,3,2],
                 upsample_initial_channel=512,
                 resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
                 upsample_kernel_sizes=[20,10,6,4]):
        super().__init__()

        self.decode = nn.ModuleList()

        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)

        # Inputs are widened by F0 (1), energy (1) and the 64-channel asr residual.
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))

        self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))

        self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))

        self.asr_res = nn.Sequential(
            weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
        )

        self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes)

    @staticmethod
    def _smooth(curve, width):
        """Moving-average ``curve`` over the time axis with a window of ``width`` steps.

        The averaging kernel is created on the curve's own device and dtype,
        so training works on any device rather than only on CUDA.
        """
        kernel = torch.ones(1, 1, width, device=curve.device, dtype=curve.dtype)
        return nn.functional.conv1d(curve.unsqueeze(1), kernel, padding=width//2).squeeze(1) / width

    def forward(self, asr, F0_curve, N, s):
        """Synthesise a waveform.

        In training mode the F0 and energy curves are randomly smoothed
        (window 0/3/7 for F0, 0/3/7/15 for energy; 0 means no smoothing).
        """
        if self.training:
            downlist = [0, 3, 7]
            F0_down = downlist[random.randint(0, 2)]
            downlist = [0, 3, 7, 15]
            N_down = downlist[random.randint(0, 3)]
            if F0_down:
                F0_curve = self._smooth(F0_curve, F0_down)
            if N_down:
                N = self._smooth(N, N_down)

        F0 = self.F0_conv(F0_curve.unsqueeze(1))
        N = self.N_conv(N.unsqueeze(1))

        x = torch.cat([asr, F0, N], dim=1)
        x = self.encode(x, s)

        asr_res = self.asr_res(asr)

        # Conditioning is re-concatenated before every block until the
        # upsampling block has run; after it the time resolution differs.
        res = True
        for block in self.decode:
            if res:
                x = torch.cat([x, asr_res, F0, N], dim=1)
            x = block(x, s)
            if block.upsample_type != "none":
                res = False

        x = self.generator(x, s, F0_curve)
        return x
476
+
477
+
Modules/istftnet.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+ from .utils import init_weights, get_padding
7
+
8
+ import math
9
+ import random
10
+ import numpy as np
11
+ from scipy.signal import get_window
12
+
13
+ LRELU_SLOPE = 0.1
14
+
15
class AdaIN1d(nn.Module):
    """Adaptive instance normalisation for 1-D feature maps.

    A linear layer maps the style vector to ``2 * num_features`` values that
    are split into a per-channel scale offset and a shift, applied to the
    instance-normalised input.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm1d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        """Return ``(1 + gamma) * norm(x) + beta`` with gamma/beta predicted from ``s``."""
        style = self.fc(s)
        # Add a trailing axis so scale and shift broadcast over the last axis of x.
        style = style.view(style.size(0), style.size(1), 1)
        scale, shift = style.chunk(2, dim=1)
        normalized = self.norm(x)
        return normalized * (1 + scale) + shift
26
+
27
class AdaINResBlock1(torch.nn.Module):
    """Residual block of three (AdaIN, Snake, dilated conv, AdaIN, Snake, conv) stages.

    Args:
        channels (int): Channel count, preserved by every convolution.
        kernel_size (int): Kernel size of all convolutions.
        dilation (sequence of 3 ints): Dilation of the first convolution of
            each stage; the second convolution always uses dilation 1.
        style_dim (int): Size of the style vector fed to the AdaIN layers.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
        super(AdaINResBlock1, self).__init__()

        def make_conv(rate):
            # Weight-normed conv padded via get_padding for this dilation.
            return weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=rate,
                                      padding=get_padding(kernel_size, rate)))

        self.convs1 = nn.ModuleList([make_conv(dilation[idx]) for idx in range(3)])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

        self.adain1 = nn.ModuleList([AdaIN1d(style_dim, channels) for _ in range(3)])
        self.adain2 = nn.ModuleList([AdaIN1d(style_dim, channels) for _ in range(3)])

        # One learnable Snake frequency per channel and per convolution.
        self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for _ in self.convs1])
        self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for _ in self.convs2])

    @staticmethod
    def _snake(h, alpha):
        """Snake1D activation: ``h + sin(alpha * h) ** 2 / alpha``."""
        return h + (1 / alpha) * (torch.sin(alpha * h) ** 2)

    def forward(self, x, s):
        """Apply the three residual stages to ``x`` conditioned on style ``s``."""
        stages = zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2)
        for conv_a, conv_b, norm_a, norm_b, alpha_a, alpha_b in stages:
            h = self._snake(norm_a(x, s), alpha_a)
            h = conv_a(h)
            h = self._snake(norm_b(h, s), alpha_b)
            h = conv_b(h)
            x = h + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution of the block."""
        for conv in (*self.convs1, *self.convs2):
            remove_weight_norm(conv)
82
+
83
class TorchSTFT(torch.nn.Module):
    """STFT / inverse-STFT pair built on ``torch.stft`` and ``torch.istft``.

    Args:
        filter_length (int): FFT size.
        hop_length (int): Hop size between frames.
        win_length (int): Window length.
        window (str): Window name passed to ``scipy.signal.get_window``.
    """

    def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        window_np = get_window(window, win_length, fftbins=True).astype(np.float32)
        self.window = torch.from_numpy(window_np)

    def transform(self, input_data):
        """Return the (magnitude, phase) of the STFT of ``input_data``."""
        spectrum = torch.stft(
            input_data,
            self.filter_length, self.hop_length, self.win_length,
            window=self.window.to(input_data.device),
            return_complex=True)
        magnitude = torch.abs(spectrum)
        phase = torch.angle(spectrum)
        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Rebuild a waveform from magnitude and phase; output gains a channel axis."""
        spectrum = magnitude * torch.exp(phase * 1j)
        waveform = torch.istft(
            spectrum,
            self.filter_length, self.hop_length, self.win_length,
            window=self.window.to(magnitude.device))
        # unsqueeze to stay consistent with conv_transpose1d implementation
        return waveform.unsqueeze(-2)

    def forward(self, input_data):
        """Analyse then resynthesise ``input_data``.

        The intermediate magnitude and phase are stored on the module as
        ``self.magnitude`` and ``self.phase``.
        """
        self.magnitude, self.phase = self.transform(input_data)
        return self.inverse(self.magnitude, self.phase)
110
+
111
class SineGen(torch.nn.Module):
    """Sine-waveform generator driven by an F0 contour.

    SineGen(samp_rate, upsample_scale, harmonic_num=0,
            sine_amp=0.1, noise_std=0.003,
            voiced_threshold=0,
            flag_for_pulse=False)

    samp_rate: sampling rate in Hz
    upsample_scale: factor by which the per-step phase increments are
        downsampled before accumulation; the accumulated phase is then
        upsampled by the same factor
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of the sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)

    Note: when flag_for_pulse is True, the first time step of a voiced
    segment is always sin(np.pi) or cos(0).
    """

    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super().__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # fundamental + overtones
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0):
        """Return a float32 voiced/unvoiced mask: 1 where f0 > threshold."""
        return (f0 > self.voiced_threshold).type(torch.float32)

    def _f02sine(self, f0_values):
        """Convert per-step frequencies into sine (or cosine) waveforms.

        f0_values: (batchsize, length, dim)
            where dim indicates fundamental tone and overtones
        """
        # Convert to F0 in rad. The integer part n can be ignored
        # because 2 * np.pi * n doesn't affect phase.
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2],
                              device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            # Normal case: accumulate the phase at a reduced time resolution
            # (downsample by upsample_scale), then bring the phase back to
            # the original resolution. The phase is multiplied by
            # upsample_scale before upsampling to compensate for the shorter
            # cumulative sum.
            rad_values = torch.nn.functional.interpolate(
                rad_values.transpose(1, 2),
                scale_factor=1 / self.upsample_scale,
                mode="linear").transpose(1, 2)

            phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
            phase = torch.nn.functional.interpolate(
                phase.transpose(1, 2) * self.upsample_scale,
                scale_factor=self.upsample_scale,
                mode="linear").transpose(1, 2)
            sines = torch.sin(phase)

        else:
            # If necessary, make sure that the first time step of every
            # voiced segment is sin(pi) or cos(0).
            # This is used for pulse-train generation.

            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batch items need to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segment
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # get the sines
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """sine_tensor, uv, noise = forward(f0)

        input F0: tensor(batchsize, length, dim=1)
                  f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize, length, dim)
        output uv: tensor(batchsize, length, 1)
        output noise: tensor(batchsize, length, dim), the additive noise
                  that was mixed into sine_tensor
        """
        # Frequencies of the fundamental and its overtones: f0 * [1, 2, ...].
        # The multipliers are created directly on f0's device.
        harmonics = torch.arange(1, self.harmonic_num + 2,
                                 dtype=torch.float32,
                                 device=f0.device).view(1, 1, -1)
        fn = torch.multiply(f0, harmonics)

        # generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
        #        for voiced regions it is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
248
+
249
+
250
class SourceModuleHnNSF(torch.nn.Module):
    """Source module for hn-nsf.

    SourceModule(sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)

    sampling_rate: sampling rate in Hz
    upsample_scale: forwarded unchanged to the internal SineGen
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that the amplitude of noise in unvoiced regions is decided
        by sine_amp
    voiced_threshod: threshold to set U/V given F0 (default: 0)

    Sine_source, noise_source, uv = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length, 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super().__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # produces the sine waveforms (fundamental + harmonics)
        self.l_sin_gen = SineGen(
            sampling_rate,
            upsample_scale,
            harmonic_num,
            sine_amp,
            add_noise_std,
            voiced_threshod,
        )

        # merges the source harmonics into a single excitation
        n_components = harmonic_num + 1
        self.l_linear = torch.nn.Linear(n_components, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """Return ``(sine_merge, noise, uv)`` for the F0 input ``x``.

        The sine generator runs without gradient tracking; its harmonics are
        merged by the linear layer followed by tanh. The noise source has the
        shape of ``uv`` and is scaled by ``sine_amp / 3``.
        """
        # source for harmonic branch
        with torch.no_grad():
            harmonics, uv, _unused_noise = self.l_sin_gen(x)
        merged = self.l_linear(harmonics)
        sine_merge = self.l_tanh(merged)

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
298
def padDiff(x):
    """Difference between consecutive entries along the second-to-last axis.

    ``x`` is shifted by one step along that axis (first entry dropped, a zero
    entry appended), the original is subtracted, and the trailing entry that
    was compared against the zero padding is cropped off again.
    """
    shifted = F.pad(x, (0, 0, -1, 1), 'constant', 0)
    delta = shifted - x
    return F.pad(delta, (0, 0, 0, -1), 'constant', 0)
300
+
301
+
302
class Generator(torch.nn.Module):
    """Upsampling generator that predicts an STFT and inverts it to audio.

    The network upsamples the input features with transposed convolutions,
    mixes in a harmonic source derived from F0 at every stage, and finally
    predicts ``gen_istft_n_fft + 2`` channels that are split into a magnitude
    half (through ``exp``) and a phase half (through ``sin``) and passed to
    ``TorchSTFT.inverse``.

    Args:
        style_dim: size of the style vector given to every AdaIN res block.
        resblock_kernel_sizes: kernel size of each parallel res block.
        upsample_rates: stride of each transposed-convolution stage.
        upsample_initial_channel: channel count of the input features; it is
            halved at every upsampling stage.
        resblock_dilation_sizes: dilations of each parallel res block.
        upsample_kernel_sizes: kernel size of each transposed convolution.
        gen_istft_n_fft: FFT size (also window length) of the final iSTFT.
        gen_istft_hop_size: hop length of the final iSTFT.
    """

    def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
        super(Generator, self).__init__()

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        resblock = AdaINResBlock1

        # Total factor between the F0 frame rate and the output sample rate.
        total_upsample = np.prod(upsample_rates) * gen_istft_hop_size

        self.m_source = SourceModuleHnNSF(
            sampling_rate=24000,
            upsample_scale=total_upsample,
            harmonic_num=8, voiced_threshod=10)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=total_upsample)
        self.noise_convs = nn.ModuleList()
        self.noise_res = nn.ModuleList()

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                ConvTranspose1d(upsample_initial_channel // (2 ** i),
                                upsample_initial_channel // (2 ** (i + 1)),
                                k, u, padding=(k - u) // 2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock(ch, k, d, style_dim))

            # Per-stage projection of the harmonic spectrogram onto the
            # feature channels of this stage.
            if i + 1 < len(upsample_rates):
                # Remaining upsampling after this stage: stride the source
                # down so it matches this stage's time resolution.
                stride_f0 = np.prod(upsample_rates[i + 1:])
                self.noise_convs.append(Conv1d(
                    gen_istft_n_fft + 2, ch, kernel_size=stride_f0 * 2,
                    stride=stride_f0, padding=(stride_f0 + 1) // 2))
                self.noise_res.append(resblock(ch, 7, [1, 3, 5], style_dim))
            else:
                self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, ch, kernel_size=1))
                self.noise_res.append(resblock(ch, 11, [1, 3, 5], style_dim))

        self.post_n_fft = gen_istft_n_fft
        # ``ch`` is the channel count of the last upsampling stage.
        self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
        self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)

    def forward(self, x, s, f0):
        """Synthesise a waveform from features ``x``, style ``s`` and ``f0``.

        ``f0`` gets a channel axis, is upsampled by the total upsampling
        factor and turned into a harmonic source whose STFT magnitude and
        phase are concatenated along the channel axis. That tensor is
        projected and added to the features after every upsampling stage.

        Returns:
            The output of ``self.stft.inverse`` for the predicted magnitude
            and phase.
        """
        with torch.no_grad():
            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t

            har_source, noi_source, uv = self.m_source(f0)
            har_source = har_source.transpose(1, 2).squeeze(1)
            har_spec, har_phase = self.stft.transform(har_source)
            har = torch.cat([har_spec, har_phase], dim=1)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x_source = self.noise_convs[i](har)
            x_source = self.noise_res[i](x_source, s)

            x = self.ups[i](x)
            if i == self.num_upsamples - 1:
                # one extra frame on the left at the last stage
                x = self.reflection_pad(x)

            x = x + x_source
            # average the outputs of the parallel res blocks of this stage
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x, s)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x, s)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        spec = torch.exp(x[:, :self.post_n_fft // 2 + 1, :])
        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
        return self.stft.inverse(spec, phase)

    def fw_phase(self, x, s):
        """Run the upsampling stack without the harmonic source.

        Returns:
            ``(spec, phase)`` as predicted by ``conv_post``; no inverse STFT
            is applied.
        """
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x, s)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x, s)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.reflection_pad(x)
        x = self.conv_post(x)
        spec = torch.exp(x[:, :self.post_n_fft // 2 + 1, :])
        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
        return spec, phase

    def remove_weight_norm(self):
        """Strip weight normalisation from the layers of this generator.

        This class never creates a ``conv_pre`` layer, so (unlike the
        previous version, which raised AttributeError) only layers that
        exist are touched. The ``noise_res`` blocks are the same block type
        as ``resblocks`` and are stripped as well.
        """
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        for l in self.noise_res:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_post)
408
+
409
+
410
class AdainResBlk1d(nn.Module):
    """Residual 1-D block with two AdaIN-normalised convolutions.

    The residual path is AdaIN -> activation -> pool -> conv -> AdaIN ->
    activation -> conv; the shortcut path is an optional x2 upsampling
    followed by a 1x1 convolution when the channel count changes. Both paths
    are summed and divided by sqrt(2).

    Args:
        dim_in: input channels.
        dim_out: output channels.
        style_dim: size of the style vector consumed by the AdaIN layers.
        actv: activation module applied after each normalisation.
        upsample: ``'none'`` keeps the time resolution; any other value
            enables the stride-2 transposed convolution in the residual path
            (and is forwarded to ``UpSample1d`` for the shortcut path).
        dropout_p: dropout probability applied before each convolution.
    """

    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
                 upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        # a learned shortcut is only needed when the channel count changes
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)

        if upsample != 'none':
            # depthwise transposed convolution that doubles the length
            self.pool = weight_norm(nn.ConvTranspose1d(
                dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in,
                padding=1, output_padding=1))
        else:
            self.pool = nn.Identity()

    def _build_weights(self, dim_in, dim_out, style_dim):
        """Create the convolutions, AdaIN layers and optional 1x1 shortcut."""
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        """Shortcut path: upsample, then project channels if required."""
        skip = self.upsample(x)
        return self.conv1x1(skip) if self.learned_sc else skip

    def _residual(self, x, s):
        """Residual path conditioned on the style vector ``s``."""
        h = self.actv(self.norm1(x, s))
        h = self.pool(h)
        h = self.conv1(self.dropout(h))
        h = self.actv(self.norm2(h, s))
        return self.conv2(self.dropout(h))

    def forward(self, x, s):
        """Return ``(residual(x, s) + shortcut(x)) / sqrt(2)``."""
        residual = self._residual(x, s)
        shortcut = self._shortcut(x)
        return (residual + shortcut) / math.sqrt(2)
455
+
456
class UpSample1d(nn.Module):
    """Optional x2 nearest-neighbour upsampling.

    With ``layer_type == 'none'`` the input is returned untouched; any other
    value doubles the last axis with ``F.interpolate(mode='nearest')``.
    """

    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        """Return ``x`` itself or ``x`` upsampled by a factor of two."""
        if self.layer_type != 'none':
            return F.interpolate(x, scale_factor=2, mode='nearest')
        return x
466
+
467
class Decoder(nn.Module):
    """Decoder that turns aligned text features, F0 and energy into audio.

    ``asr`` features are concatenated with stride-2 downsampled F0 and N
    curves, passed through one encoding and four decoding AdaIN res blocks
    (the last one upsamples), and the result is handed to ``Generator``
    together with the full-resolution F0 curve.

    Args:
        dim_in: channel count of ``asr``.
        F0_channel: unused by this class; kept for interface compatibility.
        style_dim: size of the style vector ``s``.
        dim_out: unused by this class; kept for interface compatibility.
        resblock_kernel_sizes, upsample_rates, upsample_initial_channel,
        resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft,
        gen_istft_hop_size: forwarded unchanged to ``Generator``.
    """

    def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
                 resblock_kernel_sizes = [3,7,11],
                 upsample_rates = [10, 6],
                 upsample_initial_channel=512,
                 resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
                 upsample_kernel_sizes=[20, 12],
                 gen_istft_n_fft=20, gen_istft_hop_size=5):
        super().__init__()

        self.decode = nn.ModuleList()

        # +2 channels: the downsampled F0 and N curves
        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)

        # +2 (F0, N) and +64 (asr_res) channels are re-injected before
        # every decode block
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))

        self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))

        self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))

        self.asr_res = nn.Sequential(
            weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
        )

        self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
                                   upsample_initial_channel, resblock_dilation_sizes,
                                   upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)

    @staticmethod
    def _box_smooth(curve, width):
        """Moving average of odd window ``width`` over a 2-D curve.

        The averaging kernel is created on the device and with the dtype of
        ``curve`` so the smoothing works on CPU as well as on any
        accelerator.
        """
        kernel = torch.ones(1, 1, width, device=curve.device, dtype=curve.dtype)
        smoothed = nn.functional.conv1d(curve.unsqueeze(1), kernel, padding=width // 2)
        return smoothed.squeeze(1) / width

    def forward(self, asr, F0_curve, N, s):
        """Decode ``asr``, ``F0_curve`` and ``N`` into a waveform.

        In training mode the F0 and N curves are randomly smoothed with a
        moving average (window 3 or 7 for F0; 3, 7 or 15 for N; or left
        untouched) as an augmentation.

        Returns:
            The output of ``self.generator`` for the decoded features, the
            style ``s`` and the (possibly smoothed) ``F0_curve``.
        """
        if self.training:
            # a window of 0 means "no smoothing"
            downlist = [0, 3, 7]
            F0_down = downlist[random.randint(0, 2)]
            downlist = [0, 3, 7, 15]
            N_down = downlist[random.randint(0, 3)]
            if F0_down:
                F0_curve = self._box_smooth(F0_curve, F0_down)
            if N_down:
                N = self._box_smooth(N, N_down)

        F0 = self.F0_conv(F0_curve.unsqueeze(1))
        N = self.N_conv(N.unsqueeze(1))

        x = torch.cat([asr, F0, N], axis=1)
        x = self.encode(x, s)

        asr_res = self.asr_res(asr)

        # Re-inject the conditioning before each block until the first
        # upsampling block has run (the time resolution no longer matches
        # afterwards).
        res = True
        for block in self.decode:
            if res:
                x = torch.cat([x, asr_res, F0, N], axis=1)
            x = block(x, s)
            if block.upsample_type != "none":
                res = False

        x = self.generator(x, s, F0_curve)
        return x
529
+
530
+
Modules/slmadv.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
class SLMAdversarialLoss(torch.nn.Module):
    """Adversarial loss between generated and real speech clips.

    Speech is synthesised from reference text with the wrapped ``model``
    (style sampled by ``sampler`` or taken from ``s_trg``), cropped to short
    clips, and scored with the discriminator/generator losses of ``wl``.

    Args:
        model: container exposing ``bert``, ``bert_encoder``, ``predictor``,
            ``text_encoder`` and ``decoder``.
        wl: object exposing ``discriminator``, ``discriminator_forward`` and
            ``generator``.
        sampler: callable producing style vectors from noise and text
            embeddings.
        min_len, max_len: bounds used (halved) to clamp the clip length.
        batch_percentage: upper bound on the fraction of the batch turned
            into clips (prevents OOM due to longer lengths).
        skip_update: the discriminator loss is only computed every
            ``skip_update`` iterations.
        sig: standard deviation of the Gaussian used for the differentiable
            duration upsampling.
    """

    def __init__(self, model, wl, sampler, min_len, max_len, batch_percentage=0.5, skip_update=10, sig=1.5):
        super(SLMAdversarialLoss, self).__init__()
        self.model = model
        self.wl = wl
        self.sampler = sampler

        self.min_len = min_len
        self.max_len = max_len
        self.batch_percentage = batch_percentage

        self.sig = sig
        self.skip_update = skip_update

    def forward(self, iters, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_text, ref_lengths, use_ind, s_trg, ref_s=None):
        """Compute the discriminator and generator losses for one batch.

        Returns:
            ``None`` when fewer than two usable clips could be cut from the
            batch; otherwise ``(d_loss, gen_loss, y_pred)`` where ``d_loss``
            is ``0`` on iterations that skip the discriminator update and
            ``y_pred`` is the generated audio as a NumPy array.
        """
        device = ref_text.device
        text_mask = length_to_mask(ref_lengths).to(device)
        bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
        d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)

        if use_ind and np.random.rand() < 0.5:
            s_preds = s_trg
        else:
            num_steps = np.random.randint(3, 5)
            sampler_kwargs = {
                "noise": torch.randn_like(s_trg).unsqueeze(1).to(device),
                "embedding": bert_dur,
                "embedding_scale": 1,
                "embedding_mask_proba": 0.1,
                "num_steps": num_steps,
            }
            if ref_s is not None:
                # reference from the same speaker as the embedding
                sampler_kwargs["features"] = ref_s
            s_preds = self.sampler(**sampler_kwargs).squeeze(1)

        # the part of the style vector from column 128 on drives the predictor
        s_dur = s_preds[:, 128:]

        d, _ = self.model.predictor(d_en, s_dur,
                                    ref_lengths,
                                    torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(device),
                                    text_mask)

        output_lengths = []
        attn_preds = []

        # differentiable duration modeling
        for _s2s_pred, _text_length in zip(d, ref_lengths):

            _s2s_pred_org = _s2s_pred[:_text_length, :]

            _s2s_pred = torch.sigmoid(_s2s_pred_org)
            _dur_pred = _s2s_pred.sum(axis=-1)

            # total number of output frames for this utterance
            n_frames = int(torch.round(_s2s_pred.sum()).item())

            t = torch.arange(0, n_frames).unsqueeze(0).expand((len(_s2s_pred), n_frames)).to(device)
            # centre of each token's duration span
            loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2

            # one Gaussian kernel per token
            h = torch.exp(-0.5 * torch.square(t - (n_frames - loc.unsqueeze(-1))) / (self.sig)**2)

            out = torch.nn.functional.conv1d(_s2s_pred_org.unsqueeze(0),
                                             h.unsqueeze(1),
                                             padding=h.shape[-1] - 1, groups=int(_text_length))[..., :n_frames]
            attn_preds.append(F.softmax(out.squeeze(), dim=0))

            output_lengths.append(n_frames)

        max_len = max(output_lengths)

        with torch.no_grad():
            t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)

        s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to(device)
        for bib in range(len(output_lengths)):
            s2s_attn[bib, :ref_lengths[bib], :output_lengths[bib]] = attn_preds[bib]

        asr_pred = t_en @ s2s_attn

        _, p_pred = self.model.predictor(d_en, s_dur,
                                         ref_lengths,
                                         s2s_attn,
                                         text_mask)

        mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
        mel_len = min(mel_len, self.max_len // 2)

        # get clips

        en = []
        p_en = []
        sp = []

        wav = []

        for bib in range(len(output_lengths)):
            mel_length_pred = output_lengths[bib]
            mel_length_gt = int(mel_input_length[bib].item() / 2)
            if mel_length_gt <= mel_len or mel_length_pred <= mel_len:
                continue

            sp.append(s_preds[bib])

            random_start = np.random.randint(0, mel_length_pred - mel_len)
            en.append(asr_pred[bib, :, random_start:random_start+mel_len])
            p_en.append(p_pred[bib, :, random_start:random_start+mel_len])

            # get ground truth clips
            random_start = np.random.randint(0, mel_length_gt - mel_len)
            y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
            wav.append(torch.from_numpy(y).to(device))

            if len(wav) >= self.batch_percentage * len(waves):  # prevent OOM due to longer lengths
                break

        if len(sp) <= 1:
            return None

        sp = torch.stack(sp)
        wav = torch.stack(wav).float()
        en = torch.stack(en)
        p_en = torch.stack(p_en)

        F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:])
        y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])

        # discriminator loss
        if (iters + 1) % self.skip_update == 0:
            if np.random.randint(0, 2) == 0:
                wav = y_rec_gt_pred
                use_rec = True
            else:
                use_rec = False

            crop_size = min(wav.size(-1), y_pred.size(-1))
            if use_rec:  # use reconstructed (shorter lengths), do length invariant regularization
                if wav.size(-1) > y_pred.size(-1):
                    real_GP = wav[:, :, :crop_size]
                    out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
                    out_org = self.wl.discriminator_forward(wav.detach().squeeze())
                    loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])

                    if np.random.randint(0, 2) == 0:
                        d_loss = self.wl.discriminator(real_GP.detach().squeeze(), y_pred.detach().squeeze()).mean()
                    else:
                        d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
                else:
                    real_GP = y_pred[:, :, :crop_size]
                    out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
                    out_org = self.wl.discriminator_forward(y_pred.detach().squeeze())
                    loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])

                    if np.random.randint(0, 2) == 0:
                        d_loss = self.wl.discriminator(wav.detach().squeeze(), real_GP.detach().squeeze()).mean()
                    else:
                        d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()

                # regularization (ignore length variation)
                d_loss += loss_reg

                out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze())
                out_rec = self.wl.discriminator_forward(y_rec_gt_pred.detach().squeeze())

                # regularization (ignore reconstruction artifacts)
                d_loss += F.l1_loss(out_gt, out_rec)

            else:
                d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
        else:
            d_loss = 0

        # generator loss
        gen_loss = self.wl.generator(y_pred.squeeze())

        gen_loss = gen_loss.mean()

        return d_loss, gen_loss, y_pred.detach().cpu().numpy()
191
+
192
def length_to_mask(lengths):
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    The result has shape ``(len(lengths), lengths.max())``; an entry is
    ``True`` where the 1-based position exceeds that row's length, i.e. at
    padded positions.
    """
    n_rows = lengths.shape[0]
    positions = torch.arange(lengths.max()).unsqueeze(0)
    positions = positions.expand(n_rows, -1).type_as(lengths)
    return torch.gt(positions + 1, lengths.unsqueeze(1))
Modules/utils.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise the weight of convolution-like modules in place.

    Modules whose class name contains ``"Conv"`` get their ``weight`` filled
    with samples from a normal distribution ``N(mean, std)``; every other
    module is left untouched. Intended for use with ``Module.apply``.
    """
    is_conv = "Conv" in m.__class__.__name__
    if is_conv:
        m.weight.data.normal_(mean, std)
5
+
6
+
7
def apply_weight_norm(m):
    """Wrap convolution-like modules with weight normalisation in place.

    Modules whose class name contains ``"Conv"`` are passed to
    ``torch.nn.utils.weight_norm``; every other module is left untouched.
    Intended for use with ``Module.apply``.
    """
    # This module has no top-level imports, so ``weight_norm`` is imported
    # here; without it the call below raised NameError.
    from torch.nn.utils import weight_norm

    if "Conv" in m.__class__.__name__:
        weight_norm(m)
11
+
12
+
13
def get_padding(kernel_size, dilation=1):
    """Return the padding ``int((kernel_size * dilation - dilation) / 2)``.

    For odd kernel sizes this is the amount of padding on each side that
    keeps a stride-1 dilated convolution's output the same length as its
    input.
    """
    effective_span = kernel_size * dilation - dilation
    return int(effective_span / 2)
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Styletts2
3
+ emoji: 🦀
4
+ colorFrom: blue
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.49.1
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Utils_extend_v1/.ipynb_checkpoints/__init__-checkpoint.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b
3
+ size 1
Utils_extend_v1/ASR/.ipynb_checkpoints/config-checkpoint.yml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97bcbaad6c3198ef383a461374f4e88f495a7649d7b860b6088f09ada9e99ee8
3
+ size 481
Utils_extend_v1/ASR/.ipynb_checkpoints/layers-checkpoint.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f7335e967b23ae8571c421164681a6582284d4b3839900edc0237a408f34705
3
+ size 13454
Utils_extend_v1/ASR/.ipynb_checkpoints/model_struct-checkpoint.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e26691350743d0339c05a128805fc9792afe45504f2a4506687f52d28c31546
3
+ size 11444
Utils_extend_v1/ASR/.ipynb_checkpoints/models-checkpoint.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32fb38b3e45e3dbc25af8fdf6df45343ebdc2833aa8798049ee2c99559e8fa36
3
+ size 7272
Utils_extend_v1/ASR/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b
3
+ size 1
Utils_extend_v1/ASR/__pycache__/__init__.cpython-310.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fdb39bf0760c77cb9d7a489ce3a6034aabc6a681717b169584b23af51bb92a1
3
+ size 154
Utils_extend_v1/ASR/__pycache__/__init__.cpython-312.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7508768ca99aa543e7594a99b98100aca17dd69a8dab59faab23570d90066cb6
3
+ size 153
Utils_extend_v1/ASR/__pycache__/layers.cpython-310.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b54f3010d18c0c11a09e10efd112c677852350697eba8506b270302faf12ea9
3
+ size 11044
Utils_extend_v1/ASR/__pycache__/layers.cpython-312.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:800919299847e8bad3c9202a048791b4860226ebbeb8ada1d0c585dac485be16
3
+ size 17822
Utils_extend_v1/ASR/__pycache__/models.cpython-310.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54657f28006d723d16a19db10bb32d6f6cadeecc49210483274ded0c5a99fe2c
3
+ size 6120
Utils_extend_v1/ASR/__pycache__/models.cpython-312.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7ffd90ec09013a70c9f89f5ac774aa28fbfec824ff6d2758d0dcdf178da2815
3
+ size 11262
Utils_extend_v1/ASR/config.yml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97bcbaad6c3198ef383a461374f4e88f495a7649d7b860b6088f09ada9e99ee8
3
+ size 481
Utils_extend_v1/ASR/epoch_00080.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fedd55a1234b0c56e1e8b509c74edf3a5e2f27106a66038a4a946047a775bd6c
3
+ size 94552811
Utils_extend_v1/ASR/epoch_extend_186.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5684b106b63ab8edd2ee534dca1f2d7639bc2f250b4a80e325b55e424231b123
3
+ size 31558302
Utils_extend_v1/ASR/layers.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f7335e967b23ae8571c421164681a6582284d4b3839900edc0237a408f34705
3
+ size 13454
Utils_extend_v1/ASR/model_struct.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e26691350743d0339c05a128805fc9792afe45504f2a4506687f52d28c31546
3
+ size 11444
Utils_extend_v1/ASR/models.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32fb38b3e45e3dbc25af8fdf6df45343ebdc2833aa8798049ee2c99559e8fa36
3
+ size 7272
Utils_extend_v1/JDC/.ipynb_checkpoints/model-checkpoint.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61ae501ab5a967e2988aa8a5498a482dfd7fdd82c040a251003636ec4f08b4aa
3
+ size 7649
Utils_extend_v1/JDC/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b
3
+ size 1
Utils_extend_v1/JDC/__pycache__/__init__.cpython-310.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:233c2596619ac16af181e97a37848ede369c76c51b3b619b851192b8b077b89e
3
+ size 154
Utils_extend_v1/JDC/__pycache__/__init__.cpython-312.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12e5ef68039465a2ad82f93971a04ad0beef8637206a8b92a49bfa81cf3bf518
3
+ size 153
Utils_extend_v1/JDC/__pycache__/model.cpython-310.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d3a9914f96138a6850c3bca221996e154dd354b68bbe43b701ebd19b0facc83
3
+ size 4782
Utils_extend_v1/JDC/__pycache__/model.cpython-312.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c98923cd7c469e5a0e861baf0175e1381c7657f6054bcb21f811c2f98ca200a5
3
+ size 9454