ibcplateformes Claude Opus 4.6 committed on
Commit
1a64668
·
1 Parent(s): 2af6a2f

Bundle Seed-VC modules directly in repo for ZeroGPU compatibility

Browse files

ZeroGPU workers run in separate processes without access to git-cloned
repos at runtime. Solution: include Seed-VC source modules (modules/,
configs/, hf_utils.py) directly in the repo.

- Rewrite inference.py following official app_svc.py implementation
- Use transformers.WhisperModel for speech tokenization (not custom module)
- Simplify setup.py (no more git clone needed)
- Add all necessary Seed-VC modules: commons, audio, rmvpe, campplus,
bigvgan, diffusion_transformer, flow_matching, length_regulator, etc.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. configs/astral_quantization/default_2048.yml +40 -0
  2. configs/astral_quantization/default_32.yml +40 -0
  3. configs/config.json +1 -0
  4. configs/hifigan.yml +25 -0
  5. configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml +98 -0
  6. configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml +91 -0
  7. configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml +82 -0
  8. configs/v2/vc_wrapper.yaml +105 -0
  9. hf_utils.py +12 -0
  10. modules/__init__.py +0 -0
  11. modules/astral_quantization/__init__.py +0 -0
  12. modules/astral_quantization/bsq.py +569 -0
  13. modules/astral_quantization/convnext.py +209 -0
  14. modules/astral_quantization/default_model.py +73 -0
  15. modules/astral_quantization/transformer.py +254 -0
  16. modules/audio.py +82 -0
  17. modules/bigvgan/__init__.py +0 -0
  18. modules/bigvgan/activations.py +120 -0
  19. modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  20. modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  21. modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  22. modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  23. modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  24. modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  25. modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  26. modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  27. modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  28. modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  29. modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  30. modules/bigvgan/bigvgan.py +492 -0
  31. modules/bigvgan/config.json +63 -0
  32. modules/bigvgan/env.py +18 -0
  33. modules/bigvgan/meldataset.py +354 -0
  34. modules/bigvgan/utils.py +99 -0
  35. modules/campplus/DTDNN.py +138 -0
  36. modules/campplus/__init__.py +0 -0
  37. modules/campplus/classifier.py +70 -0
  38. modules/campplus/layers.py +267 -0
  39. modules/commons.py +476 -0
  40. modules/diffusion_transformer.py +537 -0
  41. modules/encodec.py +292 -0
  42. modules/flow_matching.py +167 -0
  43. modules/hifigan/__init__.py +0 -0
  44. modules/hifigan/f0_predictor.py +55 -0
  45. modules/hifigan/generator.py +454 -0
  46. modules/length_regulator.py +141 -0
  47. modules/rmvpe.py +637 -0
  48. modules/v2/__init__.py +0 -0
  49. modules/v2/ar.py +763 -0
  50. modules/v2/cfm.py +173 -0
configs/astral_quantization/default_2048.yml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: modules.astral_quantization.default_model.AstralQuantizer
2
+ tokenizer_name: "openai/whisper-small"
3
+ ssl_model_name: "facebook/hubert-large-ll60k"
4
+ ssl_output_layer: 18
5
+ encoder:
6
+ _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
7
+ dim: 512
8
+ num_blocks: 12
9
+ intermediate_dim: 1536
10
+ dilation: 1
11
+ input_dim: 1024
12
+ quantizer:
13
+ _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
14
+ codebook_size: 2048 # codebook size, must be a power of 2
15
+ dim: 512
16
+ entropy_loss_weight: 0.1
17
+ diversity_gamma: 1.0
18
+ spherical: True
19
+ enable_entropy_loss: True
20
+ soft_entropy_loss: True
21
+ decoder:
22
+ _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
23
+ dim: 512
24
+ num_blocks: 12
25
+ intermediate_dim: 1536
26
+ dilation: 1
27
+ output_dim: 1024
28
+ gin_channels: 192
29
+ asr_decoder:
30
+ _target_: modules.astral_quantization.asr_decoder.ASRDecoder
31
+ hidden_dim: 768
32
+ num_heads: 12
33
+ depth: 12
34
+ block_size: 4096
35
+ in_channels: 512
36
+ n_vocab: 51866
37
+ bos_id: 50528
38
+ eos_id: 50527
39
+ dropout_rate: 0.0
40
+ attn_dropout_rate: 0.0
configs/astral_quantization/default_32.yml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# AstralQuantizer preset with a 32-entry BSQ codebook.
# All `_target_` entries are fully qualified import paths rooted at the
# repository's `modules/astral_quantization/` package, matching the layout
# used by configs/astral_quantization/default_2048.yml.
_target_: modules.astral_quantization.default_model.AstralQuantizer
tokenizer_name: "openai/whisper-small"
ssl_model_name: "facebook/hubert-large-ll60k"
ssl_output_layer: 18
encoder:
  _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
  dim: 512
  num_blocks: 12
  intermediate_dim: 1536
  dilation: 1
  input_dim: 1024
quantizer:
  _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
  codebook_size: 32  # codebook size, must be a power of 2
  dim: 512
  entropy_loss_weight: 0.1
  diversity_gamma: 1.0
  spherical: True
  enable_entropy_loss: True
  soft_entropy_loss: True
decoder:
  _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
  dim: 512
  num_blocks: 12
  intermediate_dim: 1536
  dilation: 1
  output_dim: 1024
  gin_channels: 192
asr_decoder:
  _target_: modules.astral_quantization.asr_decoder.ASRDecoder
  hidden_dim: 768
  num_heads: 12
  depth: 12
  block_size: 4096
  in_channels: 512
  n_vocab: 51866
  bos_id: 50528
  eos_id: 50527
  dropout_rate: 0.0
  attn_dropout_rate: 0.0
configs/config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"reference_audio_path": "D:/FAcodec/test_waves/kobe_0.wav", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "\u9ea6\u514b\u98ce (Razer BlackShark V2 HS 2.4", "sg_output_device": "\u626c\u58f0\u5668 (Razer BlackShark V2 HS 2.4", "sr_type": "sr_model", "diffusion_steps": 10.0, "inference_cfg_rate": 0.0, "max_prompt_length": 3.0, "block_time": 0.7, "crossfade_length": 0.04, "extra_time": 0.5, "extra_time_right": 0.02}
configs/hifigan.yml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hift:
2
+ in_channels: 80
3
+ base_channels: 512
4
+ nb_harmonics: 8
5
+ sampling_rate: 22050
6
+ nsf_alpha: 0.1
7
+ nsf_sigma: 0.003
8
+ nsf_voiced_threshold: 10
9
+ upsample_rates: [8, 8]
10
+ upsample_kernel_sizes: [16, 16]
11
+ istft_params:
12
+ n_fft: 16
13
+ hop_len: 4
14
+ resblock_kernel_sizes: [3, 7, 11]
15
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
16
+ source_resblock_kernel_sizes: [7, 11]
17
+ source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
18
+ lrelu_slope: 0.1
19
+ audio_limit: 0.99
20
+ f0_predictor:
21
+ num_class: 1
22
+ in_channels: 80
23
+ cond_channels: 512
24
+
25
+ pretrained_model_path: "checkpoints/hift.pt"
configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "./runs"
2
+ save_freq: 1
3
+ log_interval: 10
4
+ save_interval: 1000
5
+ device: "cuda"
6
+ epochs: 1000 # number of epochs for first stage training (pre-training)
7
+ batch_size: 1
8
+ batch_length: 100 # maximum duration of audio in a batch (in seconds)
9
+ max_len: 80 # maximum number of frames
10
+ pretrained_model: "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth"
11
+ pretrained_encoder: ""
12
+ load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
13
+
14
+ preprocess_params:
15
+ sr: 44100
16
+ spect_params:
17
+ n_fft: 2048
18
+ win_length: 2048
19
+ hop_length: 512
20
+ n_mels: 128
21
+ fmin: 0
22
+ fmax: "None"
23
+
24
+ model_params:
25
+ dit_type: "DiT" # uDiT or DiT
26
+ reg_loss_type: "l1" # l1 or l2
27
+
28
+ timbre_shifter:
29
+ se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
30
+ ckpt_path: './modules/openvoice/checkpoints_v2/converter'
31
+
32
+ vocoder:
33
+ type: "bigvgan"
34
+ name: "nvidia/bigvgan_v2_44khz_128band_512x"
35
+
36
+ speech_tokenizer:
37
+ type: 'whisper'
38
+ name: "openai/whisper-small"
39
+
40
+ style_encoder:
41
+ dim: 192
42
+ campplus_path: "campplus_cn_common.bin"
43
+
44
+ DAC:
45
+ encoder_dim: 64
46
+ encoder_rates: [2, 5, 5, 6]
47
+ decoder_dim: 1536
48
+ decoder_rates: [ 6, 5, 5, 2 ]
49
+ sr: 24000
50
+
51
+ length_regulator:
52
+ channels: 768
53
+ is_discrete: false
54
+ in_channels: 768
55
+ content_codebook_size: 2048
56
+ sampling_ratios: [1, 1, 1, 1]
57
+ vector_quantize: false
58
+ n_codebooks: 1
59
+ quantizer_dropout: 0.0
60
+ f0_condition: true
61
+ n_f0_bins: 256
62
+
63
+ DiT:
64
+ hidden_dim: 768
65
+ num_heads: 12
66
+ depth: 17
67
+ class_dropout_prob: 0.1
68
+ block_size: 8192
69
+ in_channels: 128
70
+ style_condition: true
71
+ final_layer_type: 'mlp'
72
+ target: 'mel' # mel or codec
73
+ content_dim: 768
74
+ content_codebook_size: 1024
75
+ content_type: 'discrete'
76
+ f0_condition: true
77
+ n_f0_bins: 256
78
+ content_codebooks: 1
79
+ is_causal: false
80
+ long_skip_connection: false
81
+ zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
82
+ time_as_token: false
83
+ style_as_token: false
84
+ uvit_skip_connection: true
85
+ add_resblock_in_transformer: false
86
+
87
+ wavenet:
88
+ hidden_dim: 768
89
+ num_layers: 8
90
+ kernel_size: 5
91
+ dilation_rate: 1
92
+ p_dropout: 0.2
93
+ style_condition: true
94
+
95
+ loss_params:
96
+ base_lr: 0.0001
97
+ lambda_mel: 45
98
+ lambda_kl: 1.0
configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "./runs"
2
+ save_freq: 1
3
+ log_interval: 10
4
+ save_interval: 1000
5
+ device: "cuda"
6
+ epochs: 1000 # number of epochs for first stage training (pre-training)
7
+ batch_size: 2
8
+ batch_length: 100 # maximum duration of audio in a batch (in seconds)
9
+ max_len: 80 # maximum number of frames
10
+ pretrained_model: "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth"
11
+ pretrained_encoder: ""
12
+ load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
13
+
14
+ preprocess_params:
15
+ sr: 22050
16
+ spect_params:
17
+ n_fft: 1024
18
+ win_length: 1024
19
+ hop_length: 256
20
+ n_mels: 80
21
+ fmin: 0
22
+ fmax: "None"
23
+
24
+ model_params:
25
+ dit_type: "DiT" # uDiT or DiT
26
+ reg_loss_type: "l1" # l1 or l2
27
+
28
+ timbre_shifter:
29
+ se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
30
+ ckpt_path: './modules/openvoice/checkpoints_v2/converter'
31
+
32
+ speech_tokenizer:
33
+ type: 'whisper'
34
+ name: "openai/whisper-small"
35
+
36
+ style_encoder:
37
+ dim: 192
38
+ campplus_path: "campplus_cn_common.bin"
39
+
40
+ vocoder:
41
+ type: "bigvgan"
42
+ name: "nvidia/bigvgan_v2_22khz_80band_256x"
43
+
44
+ length_regulator:
45
+ channels: 512
46
+ is_discrete: false
47
+ in_channels: 768
48
+ content_codebook_size: 2048
49
+ sampling_ratios: [1, 1, 1, 1]
50
+ vector_quantize: false
51
+ n_codebooks: 1
52
+ quantizer_dropout: 0.0
53
+ f0_condition: false
54
+ n_f0_bins: 512
55
+
56
+ DiT:
57
+ hidden_dim: 512
58
+ num_heads: 8
59
+ depth: 13
60
+ class_dropout_prob: 0.1
61
+ block_size: 8192
62
+ in_channels: 80
63
+ style_condition: true
64
+ final_layer_type: 'wavenet'
65
+ target: 'mel' # mel or codec
66
+ content_dim: 512
67
+ content_codebook_size: 1024
68
+ content_type: 'discrete'
69
+ f0_condition: false
70
+ n_f0_bins: 512
71
+ content_codebooks: 1
72
+ is_causal: false
73
+ long_skip_connection: true
74
+ zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
75
+ time_as_token: false
76
+ style_as_token: false
77
+ uvit_skip_connection: true
78
+ add_resblock_in_transformer: false
79
+
80
+ wavenet:
81
+ hidden_dim: 512
82
+ num_layers: 8
83
+ kernel_size: 5
84
+ dilation_rate: 1
85
+ p_dropout: 0.2
86
+ style_condition: true
87
+
88
+ loss_params:
89
+ base_lr: 0.0001
90
+ lambda_mel: 45
91
+ lambda_kl: 1.0
configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "./runs/"
2
+ save_freq: 1
3
+ log_interval: 10
4
+ save_interval: 500
5
+ device: "cuda"
6
+ epochs: 1000 # number of epochs for first stage training (pre-training)
7
+ batch_size: 2
8
+ batch_length: 100 # maximum duration of audio in a batch (in seconds)
9
+ max_len: 80 # maximum number of frames
10
+ pretrained_model: "DiT_uvit_tat_xlsr_ema.pth"
11
+ pretrained_encoder: ""
12
+ load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
13
+
14
+ preprocess_params:
15
+ sr: 22050
16
+ spect_params:
17
+ n_fft: 1024
18
+ win_length: 1024
19
+ hop_length: 256
20
+ n_mels: 80
21
+ fmin: 0
22
+ fmax: 8000
23
+
24
+ model_params:
25
+ dit_type: "DiT" # uDiT or DiT
26
+ reg_loss_type: "l1" # l1 or l2
27
+ diffusion_type: "flow"
28
+
29
+ timbre_shifter:
30
+ se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
31
+ ckpt_path: './modules/openvoice/checkpoints_v2/converter'
32
+
33
+ vocoder:
34
+ type: "hifigan"
35
+
36
+ speech_tokenizer:
37
+ type: 'xlsr'
38
+ output_layer: 12
39
+ name: 'facebook/wav2vec2-xls-r-300m'
40
+
41
+ style_encoder:
42
+ dim: 192
43
+ campplus_path: "campplus_cn_common.bin"
44
+
45
+ length_regulator:
46
+ channels: 384
47
+ is_discrete: false
48
+ in_channels: 1024
49
+ content_codebook_size: 1024
50
+ sampling_ratios: [1, 1, 1, 1]
51
+ vector_quantize: false
52
+ n_codebooks: 2
53
+ quantizer_dropout: 0.0
54
+ f0_condition: false
55
+ n_f0_bins: 512
56
+
57
+ DiT:
58
+ hidden_dim: 384
59
+ num_heads: 6
60
+ depth: 9
61
+ class_dropout_prob: 0.1
62
+ block_size: 8192
63
+ in_channels: 80
64
+ style_condition: true
65
+ final_layer_type: 'mlp'
66
+ target: 'mel' # mel or betavae
67
+ content_dim: 384
68
+ content_codebook_size: 1024
69
+ content_type: 'discrete'
70
+ f0_condition: false
71
+ n_f0_bins: 512
72
+ content_codebooks: 1
73
+ is_causal: false
74
+ long_skip_connection: false
75
+ zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
76
+ time_as_token: true
77
+ style_as_token: true
78
+ uvit_skip_connection: true
79
+ add_resblock_in_transformer: false
80
+
81
+ loss_params:
82
+ base_lr: 0.0001
configs/v2/vc_wrapper.yaml ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: modules.v2.vc_wrapper.VoiceConversionWrapper
2
+ sr: 22050
3
+ hop_size: 256
4
+ mel_fn:
5
+ _target_: modules.audio.mel_spectrogram
6
+ _partial_: true
7
+ n_fft: 1024
8
+ win_size: 1024
9
+ hop_size: 256
10
+ num_mels: 80
11
+ sampling_rate: 22050
12
+ fmin: 0
13
+ fmax: null
14
+ center: False
15
+ cfm:
16
+ _target_: modules.v2.cfm.CFM
17
+ estimator:
18
+ _target_: modules.v2.dit_wrapper.DiT
19
+ time_as_token: true
20
+ style_as_token: true
21
+ uvit_skip_connection: false
22
+ block_size: 8192
23
+ depth: 13
24
+ num_heads: 8
25
+ hidden_dim: 512
26
+ in_channels: 80
27
+ content_dim: 512
28
+ style_encoder_dim: 192
29
+ class_dropout_prob: 0.1
30
+ dropout_rate: 0.0
31
+ attn_dropout_rate: 0.0
32
+ cfm_length_regulator:
33
+ _target_: modules.v2.length_regulator.InterpolateRegulator
34
+ channels: 512
35
+ is_discrete: true
36
+ codebook_size: 2048
37
+ sampling_ratios: [ 1, 1, 1, 1 ]
38
+ f0_condition: false
39
+ ar:
40
+ _target_: modules.v2.ar.NaiveWrapper
41
+ model:
42
+ _target_: modules.v2.ar.NaiveTransformer
43
+ config:
44
+ _target_: modules.v2.ar.NaiveModelArgs
45
+ dropout: 0.0
46
+ rope_base: 10000.0
47
+ dim: 768
48
+ head_dim: 64
49
+ n_local_heads: 2
50
+ intermediate_size: 2304
51
+ n_head: 12
52
+ n_layer: 12
53
+ vocab_size: 2049 # 1 + 1 for eos
54
+ ar_length_regulator:
55
+ _target_: modules.v2.length_regulator.InterpolateRegulator
56
+ channels: 768
57
+ is_discrete: true
58
+ codebook_size: 32
59
+ sampling_ratios: [ ]
60
+ f0_condition: false
61
+ style_encoder:
62
+ _target_: modules.campplus.DTDNN.CAMPPlus
63
+ feat_dim: 80
64
+ embedding_size: 192
65
+ content_extractor_narrow:
66
+ _target_: modules.astral_quantization.default_model.AstralQuantizer
67
+ tokenizer_name: "openai/whisper-small"
68
+ ssl_model_name: "facebook/hubert-large-ll60k"
69
+ ssl_output_layer: 18
70
+ skip_ssl: true
71
+ encoder: &bottleneck_encoder
72
+ _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
73
+ dim: 512
74
+ num_blocks: 12
75
+ intermediate_dim: 1536
76
+ dilation: 1
77
+ input_dim: 1024
78
+ quantizer:
79
+ _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
80
+ codebook_size: 32 # codebook size, must be a power of 2
81
+ dim: 512
82
+ entropy_loss_weight: 0.1
83
+ diversity_gamma: 1.0
84
+ spherical: True
85
+ enable_entropy_loss: True
86
+ soft_entropy_loss: True
87
+ content_extractor_wide:
88
+ _target_: modules.astral_quantization.default_model.AstralQuantizer
89
+ tokenizer_name: "openai/whisper-small"
90
+ ssl_model_name: "facebook/hubert-large-ll60k"
91
+ ssl_output_layer: 18
92
+ encoder: *bottleneck_encoder
93
+ quantizer:
94
+ _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
95
+ codebook_size: 2048 # codebook size, must be a power of 2
96
+ dim: 512
97
+ entropy_loss_weight: 0.1
98
+ diversity_gamma: 1.0
99
+ spherical: True
100
+ enable_entropy_loss: True
101
+ soft_entropy_loss: True
102
+ vocoder:
103
+ _target_: modules.bigvgan.bigvgan.BigVGAN.from_pretrained
104
+ pretrained_model_name_or_path: "nvidia/bigvgan_v2_22khz_80band_256x"
105
+ use_cuda_kernel: false
hf_utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from huggingface_hub import hf_hub_download
3
+
4
+
5
def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename=None):
    """Fetch a checkpoint (and optionally its config) from the Hugging Face Hub.

    Files are downloaded through ``hf_hub_download`` using ``./checkpoints``
    as the cache directory, which is created if it does not exist yet.

    Args:
        repo_id: Hub repository identifier passed to ``hf_hub_download``.
        model_filename: Name of the weight file inside the repository.
        config_filename: Optional name of a config file in the same
            repository. When ``None`` only the model is downloaded.

    Returns:
        The local model path when ``config_filename`` is ``None``; otherwise
        a ``(model_path, config_path)`` tuple.
    """
    cache_root = "./checkpoints"
    os.makedirs(cache_root, exist_ok=True)

    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir=cache_root)
    if config_filename is not None:
        config_path = hf_hub_download(repo_id=repo_id, filename=config_filename, cache_dir=cache_root)
        return model_path, config_path

    return model_path
modules/__init__.py ADDED
File without changes
modules/astral_quantization/__init__.py ADDED
File without changes
modules/astral_quantization/bsq.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Lookup Free Quantization
3
+ Proposed in https://arxiv.org/abs/2310.05737
4
+
5
+ In the simplest setup, each dimension is quantized into {-1, 1}.
6
+ An entropy penalty is used to encourage utilization.
7
+ """
8
+
9
+ from math import log2, ceil
10
+ from functools import partial, cache
11
+ from collections import namedtuple
12
+ from contextlib import nullcontext
13
+
14
+ import torch.distributed as dist
15
+ from torch.distributed import nn as dist_nn
16
+
17
+ import torch
18
+ from torch import nn, einsum
19
+ import torch.nn.functional as F
20
+ from torch.nn import Module
21
+ from torch.amp import autocast
22
+
23
+ from einops import rearrange, reduce, pack, unpack
24
+
25
+ # constants
26
+
27
+ Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])
28
+
29
+ LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
30
+
31
+ # distributed helpers
32
+
33
@cache
def is_distributed():
    """Report whether torch.distributed is initialized with more than one rank.

    The answer is memoized by ``functools.cache``, so only the first call
    actually queries the process group.
    """
    if not dist.is_initialized():
        return False
    return dist.get_world_size() > 1
36
+
37
def maybe_distributed_mean(t):
    """Average a tensor over all distributed ranks.

    Args:
        t: Tensor to average.

    Returns:
        ``t`` itself when ``is_distributed()`` is false; otherwise the
        all-reduced tensor divided by the world size.
    """
    if not is_distributed():
        return t

    # ``torch.distributed.nn.all_reduce`` is the autograd-aware collective and
    # hands back the reduced tensor. Capture that result instead of assuming
    # ``t`` was updated in place, otherwise the division below would only
    # shrink the local value rather than produce a cross-rank mean.
    t = dist_nn.all_reduce(t)
    return t / dist.get_world_size()
44
+
45
+ # helper functions
46
+
47
def exists(v):
    """Return ``True`` when ``v`` is anything other than ``None``."""
    return not (v is None)
49
+
50
def identity(t):
    """Return the argument unchanged (no-op stand-in for an optional transform)."""
    return t
52
+
53
def default(*args):
    """Pick the first argument that is not ``None``.

    If the chosen argument is callable it is invoked with no arguments and
    its result is returned (this allows lazily computed fallbacks); otherwise
    the argument itself is returned. Returns ``None`` when every argument is
    ``None`` or no arguments are given.
    """
    for candidate in args:
        if candidate is None:
            continue
        if callable(candidate):
            return candidate()
        return candidate
    return None
58
+
59
def pack_one(t, pattern):
    """Pack a single tensor with einops ``pack``.

    Returns the ``(packed, packed_shapes)`` pair produced by ``pack`` so the
    operation can later be reversed.
    """
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
61
+
62
def unpack_one(t, ps, pattern):
    """Undo a single-tensor pack: run einops ``unpack`` and return its first element."""
    pieces = unpack(t, ps, pattern)
    return pieces[0]
64
+
65
def l2norm(t):
    """Rescale ``t`` to unit L2 norm along its last dimension."""
    return F.normalize(t, p=2.0, dim=-1)
67
+
68
+ # entropy
69
+
70
def log(t, eps = 1e-5):
    """Natural log of ``t`` after flooring every element at ``eps``.

    The floor keeps the result finite for zero (or negative) entries.
    """
    floored = torch.clamp(t, min = eps)
    return torch.log(floored)
72
+
73
def entropy(prob):
    """Entropy of ``prob`` summed over its last dimension.

    Uses the module's clamped ``log`` helper, so zero entries do not
    produce ``nan``.
    """
    neg_p_log_p = -prob * log(prob)
    return neg_p_log_p.sum(dim=-1)
75
+
76
+ # cosine sim linear
77
+
78
class CosineSimLinear(Module):
    """Bias-free projection that outputs scaled cosine similarities.

    The input is L2-normalized along its last dimension and the weight matrix
    along its first dimension before they are multiplied, and the product is
    multiplied by ``scale``.

    Args:
        dim_in: Size of the input feature dimension.
        dim_out: Size of the output feature dimension.
        scale: Constant multiplier applied to the output.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        scale = 1.
    ):
        super().__init__()
        # Weight is stored as (dim_in, dim_out) and normalized on every
        # forward pass rather than at construction time.
        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))
        self.scale = scale

    def forward(self, x):
        """Project ``x`` (last dim ``dim_in``) to ``dim_out`` features."""
        unit_x = F.normalize(x, dim = -1)
        unit_w = F.normalize(self.weight, dim = 0)
        return torch.matmul(unit_x, unit_w) * self.scale
93
+
94
def soft_entropy_loss(u, tau=1.0, gamma=1.0):
    """
    Compute the soft entropy loss for Binary Spherical Quantization (BSQ).

    Each latent dimension is treated as an independent binary variable whose
    two codebook corners are ``-1/sqrt(L)`` and ``+1/sqrt(L)``. The loss is the
    mean per-sample entropy minus ``gamma`` times the entropy of the
    probabilities averaged over all samples.

    Args:
        u (torch.Tensor): Input latent embeddings of shape (batch_size, L).
            Extra leading dimensions are accepted and treated as additional
            batch dimensions; the last dimension is always L.
        tau (float): Temperature scaling factor.
        gamma (float): Weight for the second entropy term.

    Returns:
        torch.Tensor: Scalar soft entropy loss.
    """
    eps = 1e-5  # floor applied before log so saturated sigmoids give 0, not nan

    def _binary_entropy(p):
        # entropy over the trailing axis of size 2 (the two corners)
        return -(p * p.clamp(min=eps).log()).sum(dim=-1)

    # Binary quantization: generate implicit codebook corners
    L = u.size(-1)  # dimensionality of codebook
    corners = torch.tensor([-1.0, 1.0], device=u.device) / (L ** 0.5)

    # Soft quantization probabilities q_hat(c|u) for every dimension.
    # The corners must live on a NEW trailing axis: (..., L, 1) * (2,)
    # broadcasts to (..., L, 2). Broadcasting them against an existing axis
    # either fails or mixes corner and feature dimensions.
    prob_matrix = torch.sigmoid(2 * tau * u.unsqueeze(-1) * corners)

    # Entropy of q_hat(c|u) (independent along each dimension)
    entropy_term1 = _binary_entropy(prob_matrix).mean()

    # Expected probabilities for dataset entropy (approximation): average over
    # every leading (batch-like) dimension, giving shape (L, 2)
    expected_probs = prob_matrix.reshape(-1, L, 2).mean(dim=0)
    entropy_term2 = _binary_entropy(expected_probs).mean()

    # Final entropy loss
    return entropy_term1 - gamma * entropy_term2
125
+
126
+ # class
127
+
128
+ class BinarySphericalQuantize(Module):
129
+ def __init__(
130
+ self,
131
+ *,
132
+ dim = None,
133
+ codebook_size = None,
134
+ entropy_loss_weight = 0.1,
135
+ commitment_loss_weight = 0.,
136
+ diversity_gamma = 1.,
137
+ straight_through_activation = nn.Identity(),
138
+ num_codebooks = 1,
139
+ keep_num_codebooks_dim = None,
140
+ codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer
141
+ frac_per_sample_entropy = 0.25, # make less than 1. to only use a random fraction of the probs for per sample entropy
142
+ has_projections = None,
143
+ projection_has_bias = True,
144
+ soft_clamp_input_value = None,
145
+ cosine_sim_project_in = False,
146
+ cosine_sim_project_in_scale = None,
147
+ channel_first = None,
148
+ experimental_softplus_entropy_loss = False,
149
+ entropy_loss_offset = 5., # how much to shift the loss before softplus
150
+ spherical = True, # from https://arxiv.org/abs/2406.07548
151
+ force_quantization_f32 = True, # will force the quantization step to be full precision
152
+ enable_entropy_loss = True,
153
+ soft_entropy_loss = True,
154
+ ):
155
+ super().__init__()
156
+
157
+ # some assert validations
158
+
159
+ assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
160
+ assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
161
+
162
+ codebook_size = default(codebook_size, lambda: 2 ** dim)
163
+ self.codebook_size = codebook_size
164
+
165
+ codebook_dim = int(log2(codebook_size))
166
+ codebook_dims = codebook_dim * num_codebooks
167
+ dim = default(dim, codebook_dims)
168
+
169
+ has_projections = default(has_projections, dim != codebook_dims)
170
+
171
+ if cosine_sim_project_in:
172
+ cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
173
+ project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
174
+ else:
175
+ project_in_klass = partial(nn.Linear, bias = projection_has_bias)
176
+
177
+ self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity()
178
+ self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
179
+ self.has_projections = has_projections
180
+
181
+ self.dim = dim
182
+ self.codebook_dim = codebook_dim
183
+ self.num_codebooks = num_codebooks
184
+
185
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
186
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
187
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
188
+
189
+ # channel first
190
+
191
+ self.channel_first = channel_first
192
+
193
+ # straight through activation
194
+
195
+ self.activation = straight_through_activation
196
+
197
+ # whether to use BSQ (binary spherical quantization)
198
+
199
+ self.spherical = spherical
200
+ self.maybe_l2norm = (lambda t: l2norm(t) * self.codebook_scale) if spherical else identity
201
+
202
+ # entropy aux loss related weights
203
+
204
+ assert 0 < frac_per_sample_entropy <= 1.
205
+ self.frac_per_sample_entropy = frac_per_sample_entropy
206
+
207
+ self.diversity_gamma = diversity_gamma
208
+ self.entropy_loss_weight = entropy_loss_weight
209
+
210
+ # codebook scale
211
+
212
+ self.codebook_scale = codebook_scale
213
+
214
+ # commitment loss
215
+
216
+ self.commitment_loss_weight = commitment_loss_weight
217
+
218
+ # whether to soft clamp the input value from -value to value
219
+
220
+ self.soft_clamp_input_value = soft_clamp_input_value
221
+ assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale
222
+
223
+ # whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions)
224
+
225
+ self.entropy_loss_offset = entropy_loss_offset
226
+ self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss
227
+
228
+ # for no auxiliary loss, during inference
229
+
230
+ self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
231
+ self.register_buffer('zero', torch.tensor(0.), persistent = False)
232
+
233
+ # whether to force quantization step to be f32
234
+
235
+ self.force_quantization_f32 = force_quantization_f32
236
+
237
+ # codes
238
+ self.enable_entropy_loss = enable_entropy_loss
239
+ self.soft_entropy_loss = soft_entropy_loss
240
+ if codebook_size <= 100000:
241
+ all_codes = torch.arange(codebook_size)
242
+ bits = ((all_codes[..., None].int() & self.mask) != 0).float()
243
+ codebook = self.bits_to_codes(bits)
244
+
245
+ self.register_buffer('codebook', codebook.float(), persistent = False)
246
+ else:
247
+ all_codes = torch.arange(pow(2, 16))
248
+ mask = 2 ** torch.arange(16 - 1, -1, -1)
249
+ bits = ((all_codes[..., None].int() & mask) != 0).float()
250
+ codebook = self.bits_to_codes(bits)
251
+
252
+ self.register_buffer('codebook', codebook.float(), persistent = False)
253
+
254
+ def bits_to_codes(self, bits):
255
+ return bits * self.codebook_scale * 2 - self.codebook_scale
256
+
257
+ @property
258
+ def dtype(self):
259
+ return self.codebook.dtype
260
+
261
+ def indices_to_codes(
262
+ self,
263
+ indices,
264
+ project_out = True
265
+ ):
266
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
267
+ should_transpose = default(self.channel_first, is_img_or_video)
268
+
269
+ if not self.keep_num_codebooks_dim:
270
+ indices = rearrange(indices, '... -> ... 1')
271
+
272
+ # indices to codes, which are bits of either -1 or 1
273
+
274
+ bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
275
+
276
+ codes = self.bits_to_codes(bits)
277
+
278
+ codes = self.maybe_l2norm(codes)
279
+
280
+ codes = rearrange(codes, '... c d -> ... (c d)')
281
+
282
+ # whether to project codes out to original dimensions
283
+ # if the input feature dimensions were not log2(codebook size)
284
+
285
+ if project_out:
286
+ codes = self.project_out(codes)
287
+
288
+ # rearrange codes back to original shape
289
+
290
+ if should_transpose:
291
+ codes = rearrange(codes, 'b ... d -> b d ...')
292
+
293
+ return codes
294
+
295
+ def bits_to_z(self, bits):
296
+ # assert bits must contain only -1 and 1
297
+ assert torch.all(bits.abs() == 1)
298
+ quantized = bits.float()
299
+ quantized = self.maybe_l2norm(quantized)
300
+ z = self.project_out(quantized)
301
+ return z
302
+
303
def forward(
    self,
    x,
    inv_temperature = 100.,
    return_loss_breakdown = False,
    mask = None,
    return_bits = False
):
    """
    Quantize `x` to signed binary codes and compute the auxiliary losses.

    einstein notation
    b - batch
    n - sequence (or flattened spatial dimensions)
    d - feature dimension, which is also log2(codebook size)
    c - number of codebook dim

    Args:
        x: input features; the last axis must equal `self.dim` once a
            channel-first input has been transposed and flattened.
        inv_temperature: scale applied to the (negated) distances before the
            softmax used by the entropy loss.
        return_loss_breakdown: when true, additionally return a
            `LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)`.
        mask: optional boolean mask used to select positions for the entropy
            and commitment losses.
        return_bits: when true, return the raw `b n c d` tensor of
            +/- `codebook_scale` values right after the sign step, skipping
            indices, losses and the output projection.

    Returns:
        `Return(x, indices, aux_loss)`, optionally paired with the loss
        breakdown, or the bits tensor when `return_bits` is set.
    """

    is_img_or_video = x.ndim >= 4
    should_transpose = default(self.channel_first, is_img_or_video)

    # standardize image or video into (batch, seq, dimension)

    if should_transpose:
        x = rearrange(x, 'b d ... -> b ... d')
        x, ps = pack_one(x, 'b * d')

    assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

    x = self.project_in(x)

    # maybe soft clamp

    if exists(self.soft_clamp_input_value):
        clamp_value = self.soft_clamp_input_value
        x = (x / clamp_value).tanh() * clamp_value

    # split out number of codebooks

    x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)

    # maybe l2norm

    x = self.maybe_l2norm(x)

    # whether to force quantization step to be full precision or not

    force_f32 = self.force_quantization_f32

    quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext

    with quantization_context():

        if force_f32:
            orig_dtype = x.dtype
            x = x.float()

        # quantize by eq 3.

        original_input = x

        codebook_value = torch.ones_like(x) * self.codebook_scale
        quantized = torch.where(x > 0, codebook_value, -codebook_value)
        if return_bits:
            return quantized

        # calculate indices: sum the per-bit weights in `self.mask` over the positive bits

        indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')

        # maybe l2norm

        quantized = self.maybe_l2norm(quantized)

        # use straight-through gradients (optionally with custom activation fn) if training

        if self.training:
            x = self.activation(x)
            x = x + (quantized - x).detach()
        else:
            x = quantized

        # entropy aux loss
        if self.soft_entropy_loss:
            entropy_aux_loss = soft_entropy_loss(x, tau=1.0, gamma=1.0)
            # this path does not produce the two entropy terms; define them so
            # that `return_loss_breakdown = True` does not hit unbound names
            per_sample_entropy = codebook_entropy = self.zero
        elif self.training and self.enable_entropy_loss:

            # previously only bound when force_f32 was set, which raised
            # UnboundLocalError on the non-f32 path
            codebook = self.codebook.float() if force_f32 else self.codebook

            codebook = self.maybe_l2norm(codebook)

            # account for mask, then flatten batch and sequence.
            # a separate name is used so that `original_input` keeps its
            # `b n c d` shape for the commitment loss further down
            entropy_input = original_input
            if exists(mask):
                entropy_input = entropy_input[mask]
            entropy_input = rearrange(entropy_input, 'b n ... -> (b n) ...')

            # whether to only use a fraction of probs, for reducing memory

            if self.frac_per_sample_entropy < 1.:
                # NOTE(review): keeps 16 randomly chosen positions of the
                # feature axis regardless of the actual fraction value - confirm
                rand_mask = torch.randn(self.codebook_dim).argsort(dim = -1) < 16

                sampled_input = entropy_input[..., rand_mask]

                sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)

                per_sample_probs = (-sampled_distance * inv_temperature).softmax(dim = -1)
            else:
                # the same as euclidean distance up to a constant
                distance = -2 * einsum('... i d, j d -> ... i j', entropy_input, codebook)

                per_sample_probs = (-distance * inv_temperature).softmax(dim = -1)

            # calculate per sample entropy

            per_sample_entropy = entropy(per_sample_probs).mean()

            # distribution over all available tokens in the batch

            avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')

            avg_prob = maybe_distributed_mean(avg_prob)

            codebook_entropy = entropy(avg_prob).mean()

            # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
            # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch

            entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
        else:
            # if not training, just return dummy 0
            entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero

        # whether to make the entropy loss positive or not through a (shifted) softplus

        if self.training and self.experimental_softplus_entropy_loss:
            entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)

        # commit loss

        if self.training and self.commitment_loss_weight > 0.:

            commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')

            if exists(mask):
                commit_loss = commit_loss[mask]

            commit_loss = commit_loss.mean()
        else:
            commit_loss = self.zero

        # input back to original dtype if needed

        if force_f32:
            x = x.type(orig_dtype)

    # merge back codebook dim

    x = rearrange(x, 'b n c d -> b n (c d)')

    # project out to feature dimension if needed

    x = self.project_out(x)

    # reconstitute image or video dimensions

    if should_transpose:
        x = unpack_one(x, ps, 'b * d')
        x = rearrange(x, 'b ... d -> b d ...')

        indices = unpack_one(indices, ps, 'b * c')

    # whether to remove single codebook dim

    if not self.keep_num_codebooks_dim:
        indices = rearrange(indices, '... 1 -> ...')

    # complete aux loss

    aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight

    # returns

    ret = Return(x, indices, aux_loss)

    if not return_loss_breakdown:
        return ret

    return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
497
+
498
class GroupedResidualBSQ(Module):
    """Quantize a feature tensor in independent groups.

    The feature axis (axis 1 when `accept_image_fmap` is set, otherwise the
    last axis) is split into `groups` equal chunks and each chunk is handed to
    its own `LFQ` built with `dim // groups` and the remaining keyword
    arguments.
    """

    def __init__(
        self,
        *,
        dim,
        groups = 1,
        accept_image_fmap = False,
        **kwargs
    ):
        super().__init__()
        self.dim = dim
        self.groups = groups
        assert (dim % groups) == 0

        self.accept_image_fmap = accept_image_fmap

        # one quantizer per group, all configured identically
        self.rvqs = nn.ModuleList(
            [LFQ(dim = dim // groups, **kwargs) for _ in range(groups)]
        )

        self.codebook_size = self.rvqs[0].codebook_size

    @property
    def codebooks(self):
        # NOTE(review): reads `codebooks` from every quantizer, while the LFQ
        # forward pass in this file reads `self.codebook` - confirm the attribute exists
        return torch.stack(tuple(quantizer.codebooks for quantizer in self.rvqs))

    @property
    def split_dim(self):
        """Axis along which features are split into groups."""
        if self.accept_image_fmap:
            return 1
        return -1

    def get_codes_from_indices(self, indices):
        """Stack the codes each group quantizer returns for its slice of `indices`."""
        per_group = [
            quantizer.get_codes_from_indices(group_indices)
            for quantizer, group_indices in zip(self.rvqs, indices)
        ]
        return torch.stack(tuple(per_group))

    def get_output_from_indices(self, indices):
        """Concatenate the per-group outputs for `indices` along the split axis."""
        per_group = [
            quantizer.get_output_from_indices(group_indices)
            for quantizer, group_indices in zip(self.rvqs, indices)
        ]
        return torch.cat(tuple(per_group), dim = self.split_dim)

    def forward(
        self,
        x,
        return_all_codes = False
    ):
        """Quantize every group and merge the results.

        Returns a tuple `(quantized, all_indices, *rest)`: the quantized chunks
        concatenated along the split axis, the per-group indices stacked on a
        new leading axis, and any further per-group outputs as tuples.
        `return_all_codes` is accepted but not used.
        """
        axis = self.split_dim
        assert x.shape[axis] == self.dim

        # split the feature dimension into groups
        chunks = x.chunk(self.groups, dim = axis)

        # run each group's quantizer on its own chunk
        per_group = tuple(quantizer(chunk) for quantizer, chunk in zip(self.rvqs, chunks))

        # transpose "one result tuple per group" into "one tuple per output kind"
        quantized, all_indices, *maybe_aux_loss = tuple(zip(*per_group))

        merged = torch.cat(quantized, dim = axis)
        stacked_indices = torch.stack(all_indices)

        return (merged, stacked_indices, *maybe_aux_loss)
modules/astral_quantization/convnext.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import List
5
+
6
+
7
class ConvNextV2LayerNorm(nn.Module):
    r"""Layer normalization over the channel axis for two memory layouts.

    ``channels_last`` (default) inputs keep the channels on the final axis and
    are passed to ``torch.nn.functional.layer_norm``.  ``channels_first``
    inputs keep the channels on axis 1; they are normalized by hand in float32
    and the affine parameters are broadcast as ``[None, :, None]``, so that
    path expects a 3-D ``(batch, channels, length)`` tensor.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError(f"Unsupported data format: {self.data_format}")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            return torch.nn.functional.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        if self.data_format == "channels_first":
            # statistics are computed in float32, then cast back to the input dtype
            incoming_dtype = x.dtype
            values = x.float()
            centered = values - values.mean(1, keepdim=True)
            variance = centered.pow(2).mean(1, keepdim=True)
            normalized = (centered / torch.sqrt(variance + self.eps)).to(dtype=incoming_dtype)
            return self.weight[None, :, None] * normalized + self.bias[None, :, None]
        return x
37
+
38
+
39
class GRN(nn.Module):
    """Response normalization with learnable scale/shift and a residual path.

    The L2 norm is taken over axis 1 and divided by its mean over the last
    axis; ``gamma`` and ``beta`` have shape ``(1, 1, dim)`` and start at zero,
    so a freshly constructed layer returns its input.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        norm_over_axis1 = x.norm(p=2, dim=1, keepdim=True)
        # 1e-6 keeps the division finite when the mean norm is zero
        relative_norm = norm_over_axis1 / (norm_over_axis1.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * relative_norm) + self.beta + x
49
+
50
class InterpolationLayer(nn.Module):
    """Parameter-free layer that linearly resamples the last axis to ``target_len``."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, target_len: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # extra positional/keyword arguments are accepted and ignored
        return F.interpolate(x, size=target_len, mode='linear')
58
+
59
class ConvNeXtV2Stage(nn.Module):
    """Stack of ConvNeXtV2 blocks with optional resampling around chosen blocks.

    Input and output are channel-first ``(B, D, T)`` tensors.  Before the block
    whose index appears in ``downsample_layer_indices`` /
    ``upsample_layer_indices`` / ``interpolation_layer_indices`` the matching
    strided conv, transposed conv or linear interpolation is applied.

    Args:
        dim: channel width of the blocks.
        intermediate_dim: hidden width inside each block.
        num_blocks: number of ConvNeXtV2 blocks.
        dilation: dilation of each block's depthwise convolution.
        downsample_layer_indices / downsample_factors: block indices and the
            stride (= kernel size) of the conv applied before them.
        upsample_layer_indices / upsample_factors: same, using transposed convs.
        interpolation_layer_indices: block indices before which the sequence
            is interpolated to ``kwargs['target_len']``.
        input_dim / output_dim: when given and different from ``dim``, a 1x1
            conv projects into / out of the stage.
        gin_channels: when positive, a 1x1 conv maps the conditioning tensor
            ``kwargs['g']`` to ``dim`` channels and adds it to the input.
    """

    def __init__(
        self,
        dim: int = 512,
        intermediate_dim: int = 2048,
        num_blocks: int = 1,
        dilation: int = 1,
        downsample_layer_indices: List[int] = None,
        downsample_factors: List[int] = None,
        upsample_layer_indices: List[int] = None,
        upsample_factors: List[int] = None,
        interpolation_layer_indices: List[int] = None,
        input_dim: int = None,
        output_dim: int = None,
        gin_channels: int = 0,
    ):
        super().__init__()
        # maybe downsample layers
        if downsample_layer_indices is not None:
            assert downsample_factors is not None
            self.downsample_blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        ConvNextV2LayerNorm(dim, data_format="channels_first"),
                        nn.Conv1d(
                            dim, dim, kernel_size=downsample_factor, stride=downsample_factor
                        ),
                    ) for _, downsample_factor in zip(downsample_layer_indices, downsample_factors)
                ]
            )
            self.downsample_layer_indices = downsample_layer_indices
        else:
            self.downsample_blocks = nn.ModuleList()
            self.downsample_layer_indices = []

        # maybe upsample layers
        if upsample_layer_indices is not None:
            assert upsample_factors is not None
            self.upsample_blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        ConvNextV2LayerNorm(dim, data_format="channels_first"),
                        nn.ConvTranspose1d(
                            dim, dim, kernel_size=upsample_factor, stride=upsample_factor
                        ),
                    ) for _, upsample_factor in zip(upsample_layer_indices, upsample_factors)
                ]
            )
            self.upsample_layer_indices = upsample_layer_indices
        else:
            self.upsample_blocks = nn.ModuleList()
            self.upsample_layer_indices = []

        # maybe interpolation layers
        if interpolation_layer_indices is not None:
            self.interpolation_blocks = nn.ModuleList(
                [
                    InterpolationLayer()
                    for _ in interpolation_layer_indices
                ]
            )
            self.interpolation_layer_indices = interpolation_layer_indices
        else:
            self.interpolation_blocks = nn.ModuleList()
            self.interpolation_layer_indices = []

        # main blocks
        self.blocks = nn.ModuleList(
            [
                ConvNeXtV2Block(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    dilation=dilation,
                )
                for _ in range(num_blocks)
            ]
        )
        # maybe input and output projections
        if input_dim is not None and input_dim != dim:
            self.input_projection = nn.Conv1d(input_dim, dim, kernel_size=1)
        else:
            self.input_projection = nn.Identity()
        if output_dim is not None and output_dim != dim:
            self.output_projection = nn.Conv1d(dim, output_dim, kernel_size=1)
        else:
            self.output_projection = nn.Identity()

        if gin_channels > 0:
            self.gin = nn.Conv1d(gin_channels, dim, kernel_size=1)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Run the stage on ``x``.

        Reads ``kwargs['g']`` when a conditioning conv exists and
        ``kwargs['target_len']`` when an interpolation layer is reached.
        """
        x = self.input_projection(x)  # B, D, T
        if hasattr(self, 'gin'):
            g = kwargs['g']
            x = x + self.gin(g)

        # zero-pad the time axis to a multiple of the product of all strides
        if len(self.downsample_blocks) > 0:
            total_stride = 1
            for downsample_block in self.downsample_blocks:
                total_stride *= downsample_block[1].stride[0]
            # pad_len is always in [1, total_stride]: an input that is already
            # a multiple still receives one full extra stride of zeros, which
            # is kept so that output lengths match the previous behaviour
            pad_len = total_stride - x.size(-1) % total_stride
            # F.pad always appends exactly pad_len frames; slicing the input
            # for a zeros_like template under-padded inputs shorter than pad_len
            x = F.pad(x, (0, pad_len))

        # main blocks
        for layer_idx, block in enumerate(self.blocks):
            if layer_idx in self.downsample_layer_indices:
                x = self.downsample_blocks[self.downsample_layer_indices.index(layer_idx)](x)
            if layer_idx in self.upsample_layer_indices:
                x = self.upsample_blocks[self.upsample_layer_indices.index(layer_idx)](x)
            if layer_idx in self.interpolation_layer_indices:
                x = self.interpolation_blocks[self.interpolation_layer_indices.index(layer_idx)](x, target_len=kwargs['target_len'])
            x = block(x)
        x = self.output_projection(x)
        return x

    def setup_caches(self, *args, **kwargs):
        """No-op; accepts and ignores any arguments."""
        pass
177
+
178
+
179
class ConvNeXtV2Block(nn.Module):
    """Residual block: depthwise conv, norm, then a pointwise MLP with GRN.

    Takes and returns a channel-first ``(b, d, n)`` tensor; the pointwise
    layers run on the transposed ``(b, n, d)`` view.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        kernel_size = 7
        # this padding keeps the sequence length unchanged for the odd kernel
        same_padding = (dilation * (kernel_size - 1)) // 2
        # depthwise conv: one filter per channel
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=same_padding,
            groups=dim,
            dilation=dilation,
        )
        self.norm = ConvNextV2LayerNorm(dim, data_format="channels_first")
        # pointwise/1x1 convs, implemented with linear layers
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.norm(self.dwconv(x))
        hidden = hidden.transpose(1, 2)  # b d n -> b n d
        hidden = self.pwconv2(self.grn(self.act(self.pwconv1(hidden))))
        hidden = hidden.transpose(1, 2)  # b n d -> b d n
        return x + hidden
modules/astral_quantization/default_model.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModel, Wav2Vec2FeatureExtractor
3
+
4
class AstralQuantizer(torch.nn.Module):
    """SSL feature extractor followed by an encoder and a quantizer.

    Args:
        tokenizer_name: name passed to ``AutoTokenizer.from_pretrained``.
        ssl_model_name: name passed to ``Wav2Vec2FeatureExtractor.from_pretrained``
            and, unless ``skip_ssl`` is set, ``AutoModel.from_pretrained``.
        ssl_output_layer: number of SSL encoder layers that are kept.
        encoder: module called as ``encoder(features, feature_lens)`` on
            channel-first SSL features.
        quantizer: module whose first two outputs are the quantized features
            and their indices.
        skip_ssl: when true no SSL model is loaded and one has to be passed to
            ``forward``.
    """

    def __init__(
        self,
        tokenizer_name: str,
        ssl_model_name: str,
        ssl_output_layer: int,
        encoder: torch.nn.Module,
        quantizer: torch.nn.Module,
        skip_ssl: bool = False,
    ):
        super().__init__()
        self.encoder = encoder
        self.quantizer = quantizer
        self.tokenizer_name = tokenizer_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Load SSL model from Huggingface
        self.ssl_model_name = ssl_model_name
        self.ssl_output_layer = ssl_output_layer
        self.ssl_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(ssl_model_name)

        if skip_ssl:  # in case the same SSL model has been loaded somewhere else
            self.ssl_model = None
        else:
            self.ssl_model = AutoModel.from_pretrained(ssl_model_name).eval()
            # keep only the first `ssl_output_layer` layers and drop the final norm
            self.ssl_model.encoder.layers = self.ssl_model.encoder.layers[:ssl_output_layer]
            self.ssl_model.encoder.layer_norm = torch.nn.Identity()

    def load_separate_checkpoint(self, checkpoint_path):
        """Load encoder/quantizer weights from the ``'net'`` entry of a checkpoint.

        A leading ``"module."`` is stripped from every parameter name.  The
        ``'decoder'`` and ``'predictor'`` entries are only loaded when this
        instance has a non-None ``decoder`` / ``asr_decoder`` attribute.
        """
        # NOTE: torch.load unpickles the file; only use it on trusted checkpoints
        params = torch.load(checkpoint_path, map_location='cpu')['net']
        for state in params.values():
            for k in list(state.keys()):
                if k.startswith("module."):
                    state[k[len("module."):]] = state.pop(k)
        self.encoder.load_state_dict(params['encoder'])
        self.quantizer.load_state_dict(params['vq'])
        # __init__ never defines these attributes, so a plain `self.decoder`
        # raised AttributeError; treat a missing attribute like None
        decoder = getattr(self, 'decoder', None)
        if decoder is not None:
            decoder.load_state_dict(params['decoder'])
        asr_decoder = getattr(self, 'asr_decoder', None)
        if asr_decoder is not None:
            asr_decoder.load_state_dict(params['predictor'], strict=False)

    def forward(self, waves_16k, wave_16k_lens, ssl_model=None):
        """Quantize a batch of waveforms.

        Args:
            waves_16k: batch of waveforms indexed as ``waves_16k[i, :length]``;
                passed to the feature extractor with ``sampling_rate=16000``.
            wave_16k_lens: valid length of every waveform.
            ssl_model: external SSL model, used when none was loaded in ``__init__``.

        Returns:
            ``(x_quantized, indices, feature_lens)``.
        """
        ssl_fn = self.ssl_model if self.ssl_model is not None else ssl_model
        assert ssl_fn is not None, "In case in-class SSL model loading is skipped, external ssl_model must be provided"
        # the feature extractor takes a list of un-padded numpy waveforms
        waves_16k_input_list = [
            wave[:wave_16k_lens[bib]].cpu().numpy()
            for bib, wave in enumerate(waves_16k)
        ]
        alt_inputs = self.ssl_feature_extractor(
            waves_16k_input_list,
            return_tensors='pt',
            return_attention_mask=True,
            padding=True,
            sampling_rate=16000
        ).to(waves_16k.device)
        feature_lens = alt_inputs.data['attention_mask'].sum(-1) // 320  # frame rate of hubert is 50 Hz

        outputs = ssl_fn(
            alt_inputs.input_values,
            attention_mask=alt_inputs.attention_mask,
        )
        last_hidden_states = outputs.last_hidden_state
        # trim to the longest valid length, then cap the lengths at what is available
        last_hidden_states = last_hidden_states[:, :feature_lens.max(), :]
        feature_lens = feature_lens.clamp(max=last_hidden_states.size(1))
        last_hidden_states = last_hidden_states.transpose(1, 2)
        x_hidden = self.encoder(last_hidden_states, feature_lens)
        x_hidden = x_hidden.transpose(1, 2)
        x_quantized, indices = self.quantizer(x_hidden)[:2]
        return x_quantized, indices, feature_lens
modules/astral_quantization/transformer.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ from dataclasses import dataclass
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import Tensor
12
+ from torch.nn import functional as F
13
+ import time
14
+
15
def find_multiple(n: int, k: int) -> int:
    """Return ``n`` rounded up to the nearest multiple of ``k``."""
    remainder = n % k
    return n if remainder == 0 else n + k - remainder
19
+
20
class AdaptiveLayerNorm(nn.Module):
    r"""Wrap a norm layer and modulate its output with a conditioning embedding.

    A linear layer maps the embedding to ``2 * d_model`` values that are split
    into a scale and a shift applied to the normalized input.  Without an
    embedding the wrapped norm is applied unchanged.
    """

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        normalized = self.norm(input)
        if embedding is None:
            return normalized
        # first half of the projection is the scale, second half the shift
        scale, shift = self.project_layer(embedding).split(self.d_model, dim=-1)
        return scale * normalized + shift
39
+
40
+
41
@dataclass
class ModelArgs:
    """Hyper-parameters of the transformer defined in this file.

    ``n_local_heads == -1`` and ``intermediate_size is None`` are placeholders
    that ``__post_init__`` replaces with derived values.
    """

    block_size: int = 2048
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    has_cross_attention: bool = False
    context_dim: int = 0
    is_causal: bool = False
    dropout_rate: float = 0.1
    attn_dropout_rate: float = 0.1

    def __post_init__(self):
        # -1 means "as many key/value heads as query heads"
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        # head_dim is used exactly as configured; it is not derived from dim // n_head
        if self.intermediate_size is not None:
            return
        # two thirds of 4 * dim, rounded up to a multiple of 256
        two_thirds_hidden = int(2 * (4 * self.dim) / 3)
        self.intermediate_size = find_multiple(two_thirds_hidden, 256)
67
+
68
class Transformer(nn.Module):
    """Stack of ``TransformerBlock`` layers followed by an adaptive norm.

    The rotary table (``freqs_cis``) and a lower-triangular boolean mask
    (``causal_mask``) are precomputed for ``config.block_size`` positions and
    registered as buffers.
    """

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
        self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))

        self.max_batch_size = -1
        self.max_seq_length = config.block_size
        freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
                                         self.config.rope_base)
        self.register_buffer("freqs_cis", freqs_cis)

        causal_mask = torch.tril(
            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
        )
        self.register_buffer("causal_mask", causal_mask)

    def forward(self,
                x: Tensor,
                c: Tensor,
                input_pos: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                context: Optional[Tensor] = None,
                context_input_pos: Optional[Tensor] = None,
                cross_attention_mask: Optional[Tensor] = None,
                ) -> Tensor:
        """Run all layers on ``x`` conditioned on ``c``.

        Args:
            x: input sequence; ``x.size(1)`` is used as the sequence length.
            c: conditioning passed to every adaptive norm.
            input_pos: positions used to index the rotary table (and the last
                axis of ``mask``); defaults to ``0 .. x.size(1) - 1``.
            mask: attention mask; when omitted the top-left square of the
                causal buffer is used.
            context: optional cross-attention input.
            context_input_pos: rotary positions of ``context``; defaults to
                ``0 .. context.size(1) - 1`` when ``context`` is given.
            cross_attention_mask: mask forwarded to the cross-attention.
        """
        # indexing the buffers with None inserts an axis instead of selecting
        # positions, which only worked when the sequence filled the whole
        # block; default to consecutive positions instead
        if input_pos is None:
            input_pos = torch.arange(x.size(1), device=x.device)
        if mask is None:
            mask = self.causal_mask[:x.size(1), :x.size(1)]
        else:
            mask = mask[..., input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        if context is not None:
            if context_input_pos is None:
                context_input_pos = torch.arange(context.size(1), device=context.device)
            context_freqs_cis = self.freqs_cis[context_input_pos]
        else:
            context_freqs_cis = None
        for layer in self.layers:
            x = layer(x, c, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask)
        x = self.norm(x, c)
        return x
110
+
111
+
112
class TransformerBlock(nn.Module):
    """Pre-norm block: self-attention, optional cross-attention, feed-forward.

    Every sub-layer sits behind an adaptive RMS norm conditioned on ``c`` and
    is added back to its input as a residual.
    """

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()

        def adaptive_norm():
            return AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))

        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = adaptive_norm()
        self.attention_norm = adaptive_norm()

        self.has_cross_attention = bool(config.has_cross_attention)
        if self.has_cross_attention:
            self.cross_attention = Attention(config, is_cross_attention=True)
            self.cross_attention_norm = adaptive_norm()

    def forward(self,
                x: Tensor,
                c: Tensor,
                freqs_cis: Tensor,
                mask: Tensor,
                context: Optional[Tensor] = None,
                context_freqs_cis: Optional[Tensor] = None,
                cross_attention_mask: Optional[Tensor] = None,
                ) -> Tensor:
        hidden = x + self.attention(self.attention_norm(x, c), freqs_cis, mask)
        if self.has_cross_attention:
            cross_input = self.cross_attention_norm(hidden, c)
            hidden = hidden + self.cross_attention(
                cross_input, freqs_cis, cross_attention_mask, context, context_freqs_cis
            )
        return hidden + self.feed_forward(self.ffn_norm(hidden, c))
143
+
144
+
145
class Attention(nn.Module):
    """Multi-head attention with rotary embeddings and shared key/value heads.

    Self-attention uses one fused ``wqkv`` projection; cross-attention uses
    ``wq`` on the input and ``wkv`` on the context.  Keys and values have
    ``n_local_heads`` heads that are repeated to match the ``n_head`` query heads.
    """

    def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        if is_cross_attention:
            self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
            self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
        else:
            self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.attn_dropout_rate = config.attn_dropout_rate

    def forward(self,
                x: Tensor,
                freqs_cis: Tensor,
                mask: Tensor,
                context: Optional[Tensor] = None,
                context_freqs_cis: Optional[Tensor] = None,
                ) -> Tensor:
        """Attend from ``x`` to itself, or to ``context`` when it is given.

        ``freqs_cis`` rotates the queries; the keys use ``context_freqs_cis``
        when provided and ``freqs_cis`` otherwise.  Attention dropout is only
        active in training mode.
        """
        bsz, seqlen, _ = x.shape

        # the fused projection is laid out as [q | k | v] with widths
        # n_head*head_dim, n_local_heads*head_dim, n_local_heads*head_dim;
        # splitting q with the kv width failed whenever n_local_heads != n_head
        q_size = self.n_head * self.head_dim
        kv_size = self.n_local_heads * self.head_dim
        if context is None:
            q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
            context_seqlen = seqlen
        else:
            q = self.wq(x)
            k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
            context_seqlen = context.shape[1]

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        # repeat each key/value head so that every query head has a partner
        repeats = self.n_head // self.n_local_heads
        k = k.repeat_interleave(repeats, dim=1)
        v = v.repeat_interleave(repeats, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.attn_dropout_rate if self.training else 0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)

        y = self.wo(y)
        return y
201
+
202
+
203
class FeedForward(nn.Module):
    """Gated feed-forward layer: ``w2(dropout(silu(w1(x)) * w3(x)))``."""

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x: Tensor) -> Tensor:
        gate = F.silu(self.w1(x))
        gated = gate * self.w3(x)
        return self.w2(self.dropout(gated))
213
+
214
+
215
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned scale.

    The normalization itself runs in float32 and is cast back to the input
    dtype before the weight is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
227
+
228
+
229
def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000,
    dtype: torch.dtype = torch.bfloat16
) -> Tensor:
    """Build the rotary-embedding table.

    Returns a ``(seq_len, n_elem // 2, 2)`` tensor of ``dtype`` whose last
    axis holds the real and imaginary part of ``exp(i * position * freq)``,
    with ``freq = 1 / base ** (2k / n_elem)``.
    """
    exponents = torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    # unit-magnitude complex numbers at the given angles
    rotations = torch.polar(torch.ones_like(angles), angles)
    table = torch.stack([rotations.real, rotations.imag], dim=-1)
    return table.to(dtype=dtype)
239
+
240
+
241
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive feature pairs of ``x`` by the entries of ``freqs_cis``.

    ``x`` is indexed as ``(batch, seq, heads, head_dim)``; ``freqs_cis`` holds
    one (real, imag) pair per position and feature pair.  The computation
    runs in float32 and the result is cast back to the dtype of ``x``.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    # broadcast the table over the batch and head axes
    table = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)
    first, second = pairs[..., 0], pairs[..., 1]
    real, imag = table[..., 0], table[..., 1]
    # complex multiplication (first + i*second) * (real + i*imag)
    rotated = torch.stack(
        [
            first * real - second * imag,
            second * real + first * imag,
        ],
        -1,
    )
    return rotated.flatten(3).type_as(x)
254
+
modules/audio.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.utils.data
4
+ from librosa.filters import mel as librosa_mel_fn
5
+ from scipy.io.wavfile import read
6
+
7
+ MAX_WAV_VALUE = 32768.0
8
+
9
+
10
def load_wav(full_path):
    """Read a WAV file and return ``(data, sampling_rate)``.

    Note that the order is swapped relative to ``scipy.io.wavfile.read``.
    """
    rate, samples = read(full_path)
    return samples, rate
13
+
14
+
15
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Return ``log(max(x, clip_val) * C)`` element-wise (NumPy)."""
    floored = np.clip(x, a_min=clip_val, a_max=None)
    return np.log(floored * C)
17
+
18
+
19
def dynamic_range_decompression(x, C=1):
    """Return ``exp(x) / C`` element-wise (NumPy)."""
    expanded = np.exp(x)
    return expanded / C
21
+
22
+
23
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Return ``log(max(x, clip_val) * C)`` element-wise (torch)."""
    floored = torch.clamp(x, min=clip_val)
    return torch.log(floored * C)
25
+
26
+
27
def dynamic_range_decompression_torch(x, C=1):
    """Return ``exp(x) / C`` element-wise (torch)."""
    expanded = torch.exp(x)
    return expanded / C
29
+
30
+
31
def spectral_normalize_torch(magnitudes):
    """Log-compress ``magnitudes`` with the default compression settings."""
    return dynamic_range_compression_torch(magnitudes)
34
+
35
+
36
def spectral_de_normalize_torch(magnitudes):
    """Undo the default log-compression applied by the normalize counterpart."""
    return dynamic_range_decompression_torch(magnitudes)
39
+
40
+
41
# Module-level caches filled lazily by mel_spectrogram: mel filterbanks and
# Hann windows, stored under string keys built from the call parameters and
# the tensor's device.
mel_basis = {}
hann_window = {}
43
+
44
+
45
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed mel spectrogram of ``y``.

    Args:
        y: waveform tensor; it is unsqueezed at axis 1 for reflect padding,
            i.e. presumably ``(batch, samples)`` - confirm against callers.
            A message is printed when values fall outside [-1, 1].
        n_fft: FFT size.
        num_mels: number of mel bands.
        sampling_rate: sampling rate used to build the mel filterbank.
        hop_size: STFT hop length.
        win_size: Hann window length.
        fmin, fmax: frequency range of the mel filterbank.
        center: forwarded to ``torch.stft``.

    Returns:
        The mel magnitudes after ``spectral_normalize_torch``.
    """
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    # The cached filterbank depends on every argument given to the mel builder
    # and the window on its length.  Keys that only contained the sampling
    # rate, fmax and device silently reused a mismatching filterbank or window
    # when a later call changed n_fft, num_mels, fmin or win_size.
    device = y.device
    mel_key = f"{n_fft}_{num_mels}_{sampling_rate}_{fmin}_{fmax}_{device}"
    window_key = f"{win_size}_{device}"
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(device)
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(win_size).to(device)

    # reflect-pad by (n_fft - hop_size) / 2 on both sides before the STFT
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[window_key],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    # magnitude; the epsilon keeps the square root away from exactly zero
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec
modules/bigvgan/__init__.py ADDED
File without changes
modules/bigvgan/activations.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ from torch import nn, sin, pow
6
+ from torch.nn import Parameter
7
+
8
+
9
class Snake(nn.Module):
    '''
    Sine-based periodic activation with one trainable frequency per channel.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(1, 256, 100)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: number of channels
            - alpha: initial value of the frequency parameter
            - alpha_trainable: whether alpha receives gradients
            - alpha_logscale: store alpha in log scale (initialized to zeros)
            alpha is initialized to 1 by default, higher values = higher-frequency.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # log scale starts from zeros, linear scale from ones; both times `alpha`
        initial = torch.zeros(in_features) if alpha_logscale else torch.ones(in_features)
        self.alpha = Parameter(initial * alpha)
        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2 (xa)
        '''
        frequency = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            frequency = torch.exp(frequency)
        periodic = torch.sin(x * frequency).pow(2)
        return x + (1.0 / (frequency + self.no_div_by_zero)) * periodic
60
+
61
+
62
class SnakeBeta(nn.Module):
    '''
    Snake variant with a separate per-channel magnitude parameter.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(1, 256, 100)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: number of channels
            - alpha: initial value used for both alpha and beta
            - alpha_trainable: whether alpha and beta receive gradients
            - alpha_logscale: store both parameters in log scale (initialized to zeros)
            alpha is initialized to 1 by default, higher values = higher-frequency.
            beta is initialized to 1 by default, higher values = higher-magnitude.
        '''
        super(SnakeBeta, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # log scale starts from zeros, linear scale from ones; beta uses the
        # same initial value as alpha
        make_initial = torch.zeros if alpha_logscale else torch.ones
        self.alpha = Parameter(make_initial(in_features) * alpha)
        self.beta = Parameter(make_initial(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        '''
        frequency = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        magnitude = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            frequency = torch.exp(frequency)
            magnitude = torch.exp(magnitude)
        periodic = torch.sin(x * frequency).pow(2)
        return x + (1.0 / (magnitude + self.no_div_by_zero)) * periodic
modules/bigvgan/alias_free_activation/cuda/__init__.py ADDED
File without changes
modules/bigvgan/alias_free_activation/cuda/activation1d.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from ..torch.resample import UpSample1d, DownSample1d
7
+
8
+ # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
9
+ from ..cuda import load
10
+
11
+ anti_alias_activation_cuda = load.load()
12
+
13
+
14
class FusedAntiAliasActivation(torch.autograd.Function):
    """
    Autograd wrapper around the fused anti-alias activation CUDA extension (forward only).

    Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
    The hyperparameters are hard-coded in the kernel to maximize speed.
    NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters.
    """

    @staticmethod
    def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
        """Run the extension's `forward` on the inputs, filters and alpha/beta and return its result."""
        activation_results = anti_alias_activation_cuda.forward(
            inputs, up_ftr, down_ftr, alpha, beta
        )

        return activation_results

    @staticmethod
    def backward(ctx, output_grads):
        """
        Not supported: always raises NotImplementedError.

        The previous body had a `return output_grads, None, None` after the
        raise. It could never execute, and it returned three values for a
        forward that takes five inputs, so it has been removed.
        """
        raise NotImplementedError(
            "FusedAntiAliasActivation does not implement a backward pass"
        )
33
+
34
+
35
class Activation1d(nn.Module):
    """
    Upsample -> activation -> downsample wrapper with an optional fused path.

    When `fused` is False the three steps run as separate torch modules.
    When `fused` is True the input, both resampling filters and the
    activation's alpha/beta are handed to `FusedAntiAliasActivation.apply`.
    """

    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
        fused: bool = True,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

        # Whether to use fused CUDA kernel or not
        self.fused = fused

    def forward(self, x):
        # Plain torch path: three separate module calls.
        if not self.fused:
            return self.downsample(self.act(self.upsample(x)))

        act = self.act
        # Snake uses same params for alpha and beta;
        # Snakebeta uses different params for alpha and beta.
        shares_alpha = act.__class__.__name__ == "Snake"
        beta = act.alpha.data if shares_alpha else act.beta.data
        alpha = act.alpha.data

        if not act.alpha_logscale:
            # Exp baked into cuda kernel, cancel it out with a log
            alpha = torch.log(alpha)
            beta = torch.log(beta)

        return FusedAntiAliasActivation.apply(
            x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
        )
modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* coding=utf-8
2
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <torch/extension.h>
18
+
19
+ extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
20
+
21
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
+ m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
23
+ }
modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* coding=utf-8
2
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <ATen/ATen.h>
18
+ #include <cuda.h>
19
+ #include <cuda_runtime.h>
20
+ #include <cuda_fp16.h>
21
+ #include <cuda_profiler_api.h>
22
+ #include <ATen/cuda/CUDAContext.h>
23
+ #include <torch/extension.h>
24
+ #include "type_shim.h"
25
+ #include <assert.h>
26
+ #include <cfloat>
27
+ #include <limits>
28
+ #include <stdint.h>
29
+ #include <c10/macros/Macros.h>
30
+
31
+ namespace
32
+ {
33
+ // Hard-coded hyperparameters
34
+ // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
35
+ constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
36
+ constexpr int BUFFER_SIZE = 32;
37
+ constexpr int FILTER_SIZE = 12;
38
+ constexpr int HALF_FILTER_SIZE = 6;
39
+ constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
40
+ constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
41
+ constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
42
+
43
+ template <typename input_t, typename output_t, typename acc_t>
44
+ __global__ void anti_alias_activation_forward(
45
+ output_t *dst,
46
+ const input_t *src,
47
+ const input_t *up_ftr,
48
+ const input_t *down_ftr,
49
+ const input_t *alpha,
50
+ const input_t *beta,
51
+ int batch_size,
52
+ int channels,
53
+ int seq_len)
54
+ {
55
+ // Up and downsample filters
56
+ input_t up_filter[FILTER_SIZE];
57
+ input_t down_filter[FILTER_SIZE];
58
+
59
+ // Load data from global memory including extra indices reserved for replication paddings
60
+ input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
61
+ input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
62
+
63
+ // Output stores downsampled output before writing to dst
64
+ output_t output[BUFFER_SIZE];
65
+
66
+ // blockDim/threadIdx = (128, 1, 1)
67
+ // gridDim/blockIdx = (seq_blocks, channels, batches)
68
+ int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
69
+ int local_offset = threadIdx.x * BUFFER_SIZE;
70
+ int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
71
+
72
+ // intermediate have double the seq_len
73
+ int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
74
+ int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
75
+
76
+ // Get values needed for replication padding before moving pointer
77
+ const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
78
+ input_t seq_left_most_value = right_most_pntr[0];
79
+ input_t seq_right_most_value = right_most_pntr[seq_len - 1];
80
+
81
+ // Move src and dst pointers
82
+ src += block_offset + local_offset;
83
+ dst += block_offset + local_offset;
84
+
85
+ // Alpha and beta values for snake activations. Applies exp by default
86
+ alpha = alpha + blockIdx.y;
87
+ input_t alpha_val = expf(alpha[0]);
88
+ beta = beta + blockIdx.y;
89
+ input_t beta_val = expf(beta[0]);
90
+
91
+ #pragma unroll
92
+ for (int it = 0; it < FILTER_SIZE; it += 1)
93
+ {
94
+ up_filter[it] = up_ftr[it];
95
+ down_filter[it] = down_ftr[it];
96
+ }
97
+
98
+ // Apply replication padding for upsampling, matching torch impl
99
+ #pragma unroll
100
+ for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
101
+ {
102
+ int element_index = seq_offset + it; // index for element
103
+ if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
104
+ {
105
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
106
+ }
107
+ if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
108
+ {
109
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
110
+ }
111
+ if ((element_index >= 0) && (element_index < seq_len))
112
+ {
113
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
114
+ }
115
+ }
116
+
117
+ // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampling conv later
118
+ #pragma unroll
119
+ for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
120
+ {
121
+ input_t acc = 0.0;
122
+ int element_index = intermediate_seq_offset + it; // index for intermediate
123
+ #pragma unroll
124
+ for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
125
+ {
126
+ if ((element_index + f_idx) >= 0)
127
+ {
128
+ acc += up_filter[f_idx] * elements[it + f_idx];
129
+ }
130
+ }
131
+ intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
132
+ }
133
+
134
+ // Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampling conv later
135
+ double no_div_by_zero = 0.000000001;
136
+ #pragma unroll
137
+ for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
138
+ {
139
+ intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
140
+ }
141
+
142
+ // Apply replication padding before downsampling conv from intermediates
143
+ #pragma unroll
144
+ for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
145
+ {
146
+ intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
147
+ }
148
+ #pragma unroll
149
+ for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
150
+ {
151
+ intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
152
+ }
153
+
154
+ // Apply downsample strided convolution (assuming stride=2) from intermediates
155
+ #pragma unroll
156
+ for (int it = 0; it < BUFFER_SIZE; it += 1)
157
+ {
158
+ input_t acc = 0.0;
159
+ #pragma unroll
160
+ for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
161
+ {
162
+ // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
163
+ acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
164
+ }
165
+ output[it] = acc;
166
+ }
167
+
168
+ // Write output to dst
169
+ #pragma unroll
170
+ for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
171
+ {
172
+ int element_index = seq_offset + it;
173
+ if (element_index < seq_len)
174
+ {
175
+ dst[it] = output[it];
176
+ }
177
+ }
178
+
179
+ }
180
+
181
+ template <typename input_t, typename output_t, typename acc_t>
182
+ void dispatch_anti_alias_activation_forward(
183
+ output_t *dst,
184
+ const input_t *src,
185
+ const input_t *up_ftr,
186
+ const input_t *down_ftr,
187
+ const input_t *alpha,
188
+ const input_t *beta,
189
+ int batch_size,
190
+ int channels,
191
+ int seq_len)
192
+ {
193
+ if (seq_len == 0)
194
+ {
195
+ return;
196
+ }
197
+ else
198
+ {
199
+ // Use 128 threads per block to maximize gpu utilization
200
+ constexpr int threads_per_block = 128;
201
+ constexpr int seq_len_per_block = 4096;
202
+ int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
203
+ dim3 blocks(blocks_per_seq_len, channels, batch_size);
204
+ dim3 threads(threads_per_block, 1, 1);
205
+
206
+ anti_alias_activation_forward<input_t, output_t, acc_t>
207
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
208
+ }
209
+ }
210
+ }
211
+
212
+ extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
213
+ {
214
+ // Input is a 3d tensor with dimensions [batches, channels, seq_len]
215
+ const int batches = input.size(0);
216
+ const int channels = input.size(1);
217
+ const int seq_len = input.size(2);
218
+
219
+ // Output
220
+ auto act_options = input.options().requires_grad(false);
221
+
222
+ torch::Tensor anti_alias_activation_results =
223
+ torch::empty({batches, channels, seq_len}, act_options);
224
+
225
+ void *input_ptr = static_cast<void *>(input.data_ptr());
226
+ void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
227
+ void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
228
+ void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
229
+ void *beta_ptr = static_cast<void *>(beta.data_ptr());
230
+ void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
231
+
232
+ DISPATCH_FLOAT_HALF_AND_BFLOAT(
233
+ input.scalar_type(),
234
+ "dispatch anti alias activation_forward",
235
+ dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
236
+ reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
237
+ reinterpret_cast<const scalar_t *>(input_ptr),
238
+ reinterpret_cast<const scalar_t *>(up_filter_ptr),
239
+ reinterpret_cast<const scalar_t *>(down_filter_ptr),
240
+ reinterpret_cast<const scalar_t *>(alpha_ptr),
241
+ reinterpret_cast<const scalar_t *>(beta_ptr),
242
+ batches,
243
+ channels,
244
+ seq_len););
245
+ return anti_alias_activation_results;
246
+ }
modules/bigvgan/alias_free_activation/cuda/compat.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ /*This code is copied from NVIDIA apex:
18
+ * https://github.com/NVIDIA/apex
19
+ * with minor changes. */
20
+
21
+ #ifndef TORCH_CHECK
22
+ #define TORCH_CHECK AT_CHECK
23
+ #endif
24
+
25
+ #ifdef VERSION_GE_1_3
26
+ #define DATA_PTR data_ptr
27
+ #else
28
+ #define DATA_PTR data
29
+ #endif
modules/bigvgan/alias_free_activation/cuda/load.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import os
5
+ import pathlib
6
+ import subprocess
7
+
8
+ from torch.utils import cpp_extension
9
+
10
+ """
11
+ Setting this param to a list has a problem of generating different compilation commands (with different order of architectures) and leading to recompilation of fused kernels.
12
+ Set it to an empty string to avoid recompilation and assign arch flags explicitly in extra_cuda_cflags below
13
+ """
14
+ os.environ["TORCH_CUDA_ARCH_LIST"] = ""
15
+
16
+
17
def load():
    """
    Build (via `cpp_extension.load`) and return the `anti_alias_activation_cuda` extension.

    The sources sit next to this file and the build output goes into a
    `build` sub-directory. sm_70 code is always generated; sm_80 is added
    when the CUDA major version reported by nvcc is at least 11.
    """
    # Check if cuda 11 is installed for compute capability 8.0
    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag = ["-gencode", "arch=compute_80,code=sm_80"]
    else:
        cc_flag = []

    # Build path
    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / "build"
    _create_build_dir(buildpath)

    # Same flag order as before: base flags, half/constexpr/lambda switches, then arch flags.
    nvcc_flags = [
        "-O3",
        "-gencode",
        "arch=compute_70,code=sm_70",
        "--use_fast_math",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
    ] + cc_flag

    return cpp_extension.load(
        name="anti_alias_activation_cuda",
        sources=[
            srcpath / "anti_alias_activation.cpp",
            srcpath / "anti_alias_activation_cuda.cu",
        ],
        build_directory=buildpath,
        extra_cflags=["-O3"],
        extra_cuda_cflags=nvcc_flags,
        verbose=True,
    )
66
+
67
+
68
+ def _get_cuda_bare_metal_version(cuda_dir):
69
+ raw_output = subprocess.check_output(
70
+ [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
71
+ )
72
+ output = raw_output.split()
73
+ release_idx = output.index("release") + 1
74
+ release = output[release_idx].split(".")
75
+ bare_metal_major = release[0]
76
+ bare_metal_minor = release[1][0]
77
+
78
+ return raw_output, bare_metal_major, bare_metal_minor
79
+
80
+
81
+ def _create_build_dir(buildpath):
82
+ try:
83
+ os.mkdir(buildpath)
84
+ except OSError:
85
+ if not os.path.isdir(buildpath):
86
+ print(f"Creation of the build directory {buildpath} failed")
modules/bigvgan/alias_free_activation/cuda/type_shim.h ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <ATen/ATen.h>
18
+ #include "compat.h"
19
+
20
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
21
+ switch (TYPE) \
22
+ { \
23
+ case at::ScalarType::Float: \
24
+ { \
25
+ using scalar_t = float; \
26
+ __VA_ARGS__; \
27
+ break; \
28
+ } \
29
+ case at::ScalarType::Half: \
30
+ { \
31
+ using scalar_t = at::Half; \
32
+ __VA_ARGS__; \
33
+ break; \
34
+ } \
35
+ case at::ScalarType::BFloat16: \
36
+ { \
37
+ using scalar_t = at::BFloat16; \
38
+ __VA_ARGS__; \
39
+ break; \
40
+ } \
41
+ default: \
42
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
43
+ }
44
+
45
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
46
+ switch (TYPEIN) \
47
+ { \
48
+ case at::ScalarType::Float: \
49
+ { \
50
+ using scalar_t_in = float; \
51
+ switch (TYPEOUT) \
52
+ { \
53
+ case at::ScalarType::Float: \
54
+ { \
55
+ using scalar_t_out = float; \
56
+ __VA_ARGS__; \
57
+ break; \
58
+ } \
59
+ case at::ScalarType::Half: \
60
+ { \
61
+ using scalar_t_out = at::Half; \
62
+ __VA_ARGS__; \
63
+ break; \
64
+ } \
65
+ case at::ScalarType::BFloat16: \
66
+ { \
67
+ using scalar_t_out = at::BFloat16; \
68
+ __VA_ARGS__; \
69
+ break; \
70
+ } \
71
+ default: \
72
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
73
+ } \
74
+ break; \
75
+ } \
76
+ case at::ScalarType::Half: \
77
+ { \
78
+ using scalar_t_in = at::Half; \
79
+ using scalar_t_out = at::Half; \
80
+ __VA_ARGS__; \
81
+ break; \
82
+ } \
83
+ case at::ScalarType::BFloat16: \
84
+ { \
85
+ using scalar_t_in = at::BFloat16; \
86
+ using scalar_t_out = at::BFloat16; \
87
+ __VA_ARGS__; \
88
+ break; \
89
+ } \
90
+ default: \
91
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
92
+ }
modules/bigvgan/alias_free_activation/torch/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from .filter import *
5
+ from .resample import *
6
+ from .act import *
modules/bigvgan/alias_free_activation/torch/act.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from .resample import UpSample1d, DownSample1d
6
+
7
+
8
class Activation1d(nn.Module):
    """Apply `activation` between an `UpSample1d` and a `DownSample1d` stage."""

    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B,C,T]
    def forward(self, x):
        """Return downsample(act(upsample(x)))."""
        upsampled = self.upsample(x)
        activated = self.act(upsampled)
        return self.downsample(activated)
modules/bigvgan/alias_free_activation/torch/filter.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
if "sinc" in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(
            x == 0,
            torch.tensor(1.0, device=x.device, dtype=x.dtype),
            torch.sin(math.pi * x) / math.pi / x,
        )


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(
    cutoff, half_width, kernel_size
):  # return filter [1,1,kernel_size]
    """
    Build a Kaiser-windowed sinc low-pass kernel.

    Args:
        cutoff: value used as `2 * cutoff` inside the sinc; 0 yields an all-zero kernel.
        half_width: sets the Kaiser window beta through `delta_f = 4 * half_width`.
        kernel_size: number of taps (even or odd).

    Returns:
        Tensor of shape [1, 1, kernel_size]. For a non-zero cutoff the taps
        are normalized to sum to 1.
    """
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size

    if cutoff == 0:
        # An all-zero kernel must not be normalized (0 / 0 gives NaN), and it
        # still has to come back as a [1, 1, kernel_size] tensor.
        taps = torch.zeros(kernel_size, dtype=window.dtype)
    else:
        taps = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small
        # leakage of the constant component in the input signal.
        taps = taps / taps.sum()

    # Local renamed from `filter` so the builtin is no longer shadowed.
    return taps.view(1, 1, kernel_size)
63
+
64
+
65
class LowPassFilter1d(nn.Module):
    """Depthwise 1-D low-pass filter using a kernel from `kaiser_sinc_filter1d`."""

    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        """
        kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
        """
        super().__init__()
        if cutoff < 0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")

        half = kernel_size // 2
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        # Even kernels pad one sample less on the left than on the right.
        self.pad_left = half - 1 if self.even else half
        self.pad_right = half
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.register_buffer(
            "filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        )

    # Input [B, C, T]
    def forward(self, x):
        """Optionally pad `x`, then convolve every channel with the same kernel."""
        _, channels, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        # One copy of the kernel per channel, applied group-wise.
        kernel = self.filter.expand(channels, -1, -1)
        return F.conv1d(x, kernel, stride=self.stride, groups=channels)
modules/bigvgan/alias_free_activation/torch/resample.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from .filter import LowPassFilter1d
7
+ from .filter import kaiser_sinc_filter1d
8
+
9
+
10
class UpSample1d(nn.Module):
    """Upsample by `ratio` with a transposed convolution over a Kaiser-sinc kernel."""

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        if kernel_size is None:
            # Default kernel size derived from the ratio.
            kernel_size = int(6 * ratio // 2) * 2
        self.kernel_size = kernel_size
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1

        # Amount trimmed from each side of the transposed-conv output.
        base_trim = self.pad * self.stride
        overhang = self.kernel_size - self.stride
        self.pad_left = base_trim + overhang // 2
        self.pad_right = base_trim + (overhang + 1) // 2

        self.register_buffer(
            "filter",
            kaiser_sinc_filter1d(
                cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
            ),
        )

    # x: [B, C, T]
    def forward(self, x):
        """Replicate-pad, transposed-convolve per channel, scale by `ratio`, trim the edges."""
        _, channels, _ = x.shape

        padded = F.pad(x, (self.pad, self.pad), mode="replicate")
        kernel = self.filter.expand(channels, -1, -1)
        upsampled = self.ratio * F.conv_transpose1d(
            padded, kernel, stride=self.stride, groups=channels
        )
        return upsampled[..., self.pad_left : -self.pad_right]
39
+
40
+
41
class DownSample1d(nn.Module):
    """Downsample by `ratio` using a strided `LowPassFilter1d`."""

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        if kernel_size is None:
            # Default kernel size derived from the ratio.
            kernel_size = int(6 * ratio // 2) * 2
        self.kernel_size = kernel_size
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        """Return the strided low-pass output for `x`."""
        return self.lowpass(x)
modules/bigvgan/bigvgan.py ADDED
@@ -0,0 +1,492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import os
8
+ import json
9
+ from pathlib import Path
10
+ from typing import Optional, Union, Dict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.nn import Conv1d, ConvTranspose1d
15
+ from torch.nn.utils import weight_norm, remove_weight_norm
16
+
17
+ from . import activations
18
+ from .utils import init_weights, get_padding
19
+ from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
20
+ from .env import AttrDict
21
+
22
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
23
+
24
+
25
def load_hparams_from_json(path) -> AttrDict:
    """
    Read a JSON hyperparameter file and wrap the parsed object in an AttrDict.

    Args:
        path: location of the JSON file.

    Returns:
        AttrDict built from the parsed JSON content.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    with open(path, encoding="utf-8") as f:
        return AttrDict(json.load(f))
29
+
30
+
31
+ class AMPBlock1(torch.nn.Module):
32
+ """
33
+ AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
34
+ AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1
35
+
36
+ Args:
37
+ h (AttrDict): Hyperparameters.
38
+ channels (int): Number of convolution channels.
39
+ kernel_size (int): Size of the convolution kernel. Default is 3.
40
+ dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
41
+ activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
42
+ """
43
+
44
def __init__(
    self,
    h: AttrDict,
    channels: int,
    kernel_size: int = 3,
    dilation: tuple = (1, 3, 5),
    activation: str = None,
):
    """
    Build the two convolution stacks and one periodic activation per convolution.

    Args:
        h (AttrDict): Hyperparameters; ``use_cuda_kernel`` and ``snake_logscale`` are read here.
        channels (int): Number of input and output channels of every convolution.
        kernel_size (int): Size of the convolution kernel. Default is 3.
        dilation (tuple): One dilation rate per layer of ``convs1``; ``convs2`` gets the same number of layers with dilation 1.
        activation (str): Either 'snake' or 'snakebeta'.

    Raises:
        NotImplementedError: If ``activation`` is neither 'snake' nor 'snakebeta'.
    """
    super().__init__()

    self.h = h

    def make_conv(rate):
        # Weight-normalised Conv1d whose padding is derived from kernel size and dilation.
        return weight_norm(
            Conv1d(
                channels,
                channels,
                kernel_size,
                stride=1,
                dilation=rate,
                padding=get_padding(kernel_size, rate),
            )
        )

    # Dilated convolutions, one per requested dilation rate.
    self.convs1 = nn.ModuleList([make_conv(rate) for rate in dilation])
    self.convs1.apply(init_weights)

    # Companion convolutions with a fixed dilation of 1, one per dilated convolution.
    self.convs2 = nn.ModuleList([make_conv(1) for _ in dilation])
    self.convs2.apply(init_weights)

    # Total number of conv layers
    self.num_layers = len(self.convs1) + len(self.convs2)

    # The CUDA activation wrapper is imported lazily, only when it was requested.
    if self.h.get("use_cuda_kernel", False):
        from .alias_free_activation.cuda.activation1d import Activation1d
    else:
        Activation1d = TorchActivation1d

    # Resolve the periodic activation class, then create one instance per conv layer.
    if activation == "snake":
        periodic_cls = activations.Snake
    elif activation == "snakebeta":
        periodic_cls = activations.SnakeBeta
    else:
        raise NotImplementedError(
            "activation incorrectly specified. check the config file and look for 'activation'."
        )

    self.activations = nn.ModuleList(
        [
            Activation1d(
                activation=periodic_cls(channels, alpha_logscale=h.snake_logscale)
            )
            for _ in range(self.num_layers)
        ]
    )
131
+
132
def forward(self, x):
    """
    Run every residual stage: activation, dilated conv, activation, dilation-1 conv, skip add.

    Even-indexed activations precede the ``convs1`` layers and odd-indexed ones precede
    the ``convs2`` layers.
    """
    before_dilated = self.activations[0::2]
    before_fixed = self.activations[1::2]
    stages = zip(before_dilated, self.convs1, before_fixed, self.convs2)
    for act_in, conv_dilated, act_mid, conv_fixed in stages:
        branch = conv_dilated(act_in(x))
        branch = conv_fixed(act_mid(branch))
        x = branch + x

    return x
142
+
143
def remove_weight_norm(self):
    """Strip weight normalisation from every layer of ``convs1``, then of ``convs2``."""
    for conv in [*self.convs1, *self.convs2]:
        remove_weight_norm(conv)
148
+
149
+
150
class AMPBlock2(torch.nn.Module):
    """
    AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
    Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1

    Args:
        h (AttrDict): Hyperparameters; ``use_cuda_kernel`` and ``snake_logscale`` are read here.
        channels (int): Number of convolution channels.
        kernel_size (int): Size of the convolution kernel. Default is 3.
        dilation (tuple): Dilation rates for the convolutions. One convolution is created per rate. Default is (1, 3, 5).
        activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.

    Raises:
        NotImplementedError: If ``activation`` is neither 'snake' nor 'snakebeta'.
    """

    def __init__(
        self,
        h: AttrDict,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple = (1, 3, 5),
        activation: str = None,
    ):
        super().__init__()

        self.h = h

        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs.apply(init_weights)

        self.num_layers = len(self.convs)  # Total number of conv layers

        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
            from .alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d

        # Activation functions: one periodic activation in front of every convolution
        if activation == "snake":
            periodic_cls = activations.Snake
        elif activation == "snakebeta":
            periodic_cls = activations.SnakeBeta
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

        self.activations = nn.ModuleList(
            [
                Activation1d(
                    activation=periodic_cls(channels, alpha_logscale=h.snake_logscale)
                )
                for _ in range(self.num_layers)
            ]
        )

    def forward(self, x):
        """Apply each (activation, conv) pair with a residual connection and return the result."""
        for c, a in zip(self.convs, self.activations):
            xt = a(x)
            xt = c(xt)
            x = xt + x

        # The result must be returned: BigVGAN.forward sums and averages the block outputs.
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution of the block."""
        for l in self.convs:
            remove_weight_norm(l)
241
+
242
+
243
class BigVGAN(
    torch.nn.Module,
    PyTorchModelHubMixin,
    library_name="bigvgan",
    repo_url="https://github.com/NVIDIA/BigVGAN",
    docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
    pipeline_tag="audio-to-audio",
    license="mit",
    tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
):
    """
    Neural vocoder whose residual blocks (resblocks) use anti-aliased periodic activations.
    Since BigVGAN-v2 the AMP (anti-aliased multi-periodicity) blocks can optionally run on optimized CUDA kernels.

    Args:
        h (AttrDict): Hyperparameters.
        use_cuda_kernel (bool): If True, loads optimized CUDA kernels for AMP. Inference only; training is not supported with CUDA kernels.

    Note:
        - `use_cuda_kernel` is meant for inference only, as training with CUDA kernels is not supported.
        - The activation function must be correctly specified in the hyperparameters (h.activation).
    """

    def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
        super().__init__()
        self.h = h
        # Stored in the hyperparameters so that the AMP blocks created below read the same flag.
        self.h["use_cuda_kernel"] = use_cuda_kernel

        # The CUDA activation wrapper is imported lazily, only when it was requested.
        if self.h.get("use_cuda_kernel", False):
            from .alias_free_activation.cuda.activation1d import Activation1d
        else:
            Activation1d = TorchActivation1d

        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)

        # Pre-conv
        self.conv_pre = weight_norm(
            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
        )

        # Pick the AMP block flavour; "1" selects AMPBlock1 and "2" selects AMPBlock2.
        if h.resblock not in ("1", "2"):
            raise ValueError(
                f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
            )
        resblock_class = AMPBlock1 if h.resblock == "1" else AMPBlock2

        # Transposed conv-based upsamplers (no anti-aliasing); channels halve at every stage.
        self.ups = nn.ModuleList()
        for stage, (rate, kernel) in enumerate(
            zip(h.upsample_rates, h.upsample_kernel_sizes)
        ):
            upsampler = weight_norm(
                ConvTranspose1d(
                    h.upsample_initial_channel // (2 ** stage),
                    h.upsample_initial_channel // (2 ** (stage + 1)),
                    kernel,
                    rate,
                    padding=(kernel - rate) // 2,
                )
            )
            self.ups.append(nn.ModuleList([upsampler]))

        # Residual AMP blocks: one block per (kernel size, dilation set) for every upsampling stage.
        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (stage + 1))
            for kernel, dilations in zip(
                h.resblock_kernel_sizes, h.resblock_dilation_sizes
            ):
                self.resblocks.append(
                    resblock_class(h, ch, kernel, dilations, activation=h.activation)
                )

        # Post-conv activation, sized for the channel count of the last stage.
        if h.activation == "snake":
            activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
        elif h.activation == "snakebeta":
            activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

        self.activation_post = Activation1d(activation=activation_post)

        # Whether to use bias for the final conv_post. Default to True for backward compatibility
        self.use_bias_at_final = h.get("use_bias_at_final", True)
        self.conv_post = weight_norm(
            Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
        )

        # Weight initialization
        for up_stage in self.ups:
            up_stage.apply(init_weights)
        self.conv_post.apply(init_weights)

        # Final tanh activation. Defaults to True for backward compatibility
        self.use_tanh_at_final = h.get("use_tanh_at_final", True)

    def forward(self, x):
        """Run pre-conv, the upsampling/AMP stages and post-conv; the output is bounded to [-1, 1]."""
        # Pre-conv
        x = self.conv_pre(x)

        for stage in range(self.num_upsamples):
            # Upsampling
            for upsampler in self.ups[stage]:
                x = upsampler(x)

            # AMP blocks: accumulate the outputs of this stage's blocks, then average them.
            first_block = stage * self.num_kernels
            xs = None
            for offset in range(self.num_kernels):
                block_out = self.resblocks[first_block + offset](x)
                if xs is None:
                    xs = block_out
                else:
                    xs += block_out
            x = xs / self.num_kernels

        # Post-conv
        x = self.conv_post(self.activation_post(x))

        # Final tanh activation
        if self.use_tanh_at_final:
            return torch.tanh(x)

        return torch.clamp(x, min=-1.0, max=1.0)  # Bound the output to [-1, 1]

    def remove_weight_norm(self):
        """Strip weight normalisation from all layers; a ValueError is reported and skipped."""
        try:
            print("Removing weight norm...")
            for up_stage in self.ups:
                for upsampler in up_stage:
                    remove_weight_norm(upsampler)
            for block in self.resblocks:
                block.remove_weight_norm()
            remove_weight_norm(self.conv_pre)
            remove_weight_norm(self.conv_post)
        except ValueError:
            print("[INFO] Model already removed weight norm. Skipping!")

    # Additional methods for huggingface_hub support
    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights and config.json from a Pytorch model to a local directory."""

        weights_path = save_directory / "bigvgan_generator.pt"
        torch.save({"generator": self.state_dict()}, weights_path)

        with open(save_directory / "config.json", "w") as config_file:
            json.dump(self.h, config_file, indent=4)

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: str,
        cache_dir: str,
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",  # Additional argument
        strict: bool = False,  # Additional argument
        use_cuda_kernel: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model."""

        # Options shared by both hub downloads (config and weights).
        hub_options = dict(
            repo_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            token=token,
            local_files_only=local_files_only,
        )

        # Download and load hyperparameters (h) used by BigVGAN
        if os.path.isdir(model_id):
            print("Loading config.json from local directory")
            config_file = os.path.join(model_id, "config.json")
        else:
            config_file = hf_hub_download(filename="config.json", **hub_options)
        h = load_hparams_from_json(config_file)

        # instantiate BigVGAN using h
        if use_cuda_kernel:
            for warning in (
                "[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!",
                "[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!",
                "[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis",
            ):
                print(warning)
        model = cls(h, use_cuda_kernel=use_cuda_kernel)

        # Download and load pretrained generator weight
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, "bigvgan_generator.pt")
        else:
            print(f"Loading weights from {model_id}")
            model_file = hf_hub_download(
                filename="bigvgan_generator.pt", **hub_options
            )

        checkpoint_dict = torch.load(model_file, map_location=map_location)
        generator_state = checkpoint_dict["generator"]

        try:
            model.load_state_dict(generator_state)
        except RuntimeError:
            # Retry with weight norm removed from the model.
            print(
                "[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
            )
            model.remove_weight_norm()
            model.load_state_dict(generator_state)

        return model
modules/bigvgan/config.json ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 0,
4
+ "batch_size": 32,
5
+ "learning_rate": 0.0001,
6
+ "adam_b1": 0.8,
7
+ "adam_b2": 0.99,
8
+ "lr_decay": 0.9999996,
9
+ "seed": 1234,
10
+
11
+ "upsample_rates": [4,4,2,2,2,2],
12
+ "upsample_kernel_sizes": [8,8,4,4,4,4],
13
+ "upsample_initial_channel": 1536,
14
+ "resblock_kernel_sizes": [3,7,11],
15
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16
+
17
+ "use_tanh_at_final": false,
18
+ "use_bias_at_final": false,
19
+
20
+ "activation": "snakebeta",
21
+ "snake_logscale": true,
22
+
23
+ "use_cqtd_instead_of_mrd": true,
24
+ "cqtd_filters": 128,
25
+ "cqtd_max_filters": 1024,
26
+ "cqtd_filters_scale": 1,
27
+ "cqtd_dilations": [1, 2, 4],
28
+ "cqtd_hop_lengths": [512, 256, 256],
29
+ "cqtd_n_octaves": [9, 9, 9],
30
+ "cqtd_bins_per_octaves": [24, 36, 48],
31
+
32
+ "mpd_reshapes": [2, 3, 5, 7, 11],
33
+ "use_spectral_norm": false,
34
+ "discriminator_channel_mult": 1,
35
+
36
+ "use_multiscale_melloss": true,
37
+ "lambda_melloss": 15,
38
+
39
+ "clip_grad_norm": 500,
40
+
41
+ "segment_size": 65536,
42
+ "num_mels": 80,
43
+ "num_freq": 1025,
44
+ "n_fft": 1024,
45
+ "hop_size": 256,
46
+ "win_size": 1024,
47
+
48
+ "sampling_rate": 22050,
49
+
50
+ "fmin": 0,
51
+ "fmax": null,
52
+ "fmax_for_loss": null,
53
+
54
+ "normalize_volume": true,
55
+
56
+ "num_workers": 4,
57
+
58
+ "dist_config": {
59
+ "dist_backend": "nccl",
60
+ "dist_url": "tcp://localhost:54321",
61
+ "world_size": 1
62
+ }
63
+ }
modules/bigvgan/env.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import os
5
+ import shutil
6
+
7
+
8
class AttrDict(dict):
    """Dictionary whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the mapping itself as the attribute namespace, so d.key and d["key"] are the same slot.
        self.__dict__ = self
12
+
13
+
14
def build_env(config, config_name, path):
    """Copy the file ``config`` to ``path/config_name`` unless it is already that exact path string."""
    destination = os.path.join(path, config_name)
    if config == destination:
        return

    os.makedirs(path, exist_ok=True)
    shutil.copyfile(config, destination)
modules/bigvgan/meldataset.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import math
8
+ import os
9
+ import random
10
+ import torch
11
+ import torch.utils.data
12
+ import numpy as np
13
+ from librosa.util import normalize
14
+ from scipy.io.wavfile import read
15
+ from librosa.filters import mel as librosa_mel_fn
16
+ import pathlib
17
+ from tqdm import tqdm
18
+
19
+ MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
20
+
21
+
22
def load_wav(full_path, sr_target):
    """
    Read a wav file and return ``(data, sampling_rate)``.

    Raises:
        RuntimeError: If the file's sampling rate differs from ``sr_target``.
    """
    rate, samples = read(full_path)
    if rate == sr_target:
        return samples, rate

    raise RuntimeError(
        f"Sampling rate of the file {full_path} is {rate} Hz, but the model requires {sr_target} Hz"
    )
29
+
30
+
31
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Return ``log(max(x, clip_val) * C)`` computed with NumPy."""
    floored = np.clip(x, a_min=clip_val, a_max=None)
    return np.log(floored * C)
33
+
34
+
35
def dynamic_range_decompression(x, C=1):
    """Return ``exp(x) / C`` computed with NumPy."""
    expanded = np.exp(x)
    return expanded / C
37
+
38
+
39
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Return ``log(clamp(x, min=clip_val) * C)`` computed with torch."""
    floored = torch.clamp(x, min=clip_val)
    return torch.log(floored * C)
41
+
42
+
43
def dynamic_range_decompression_torch(x, C=1):
    """Return ``exp(x) / C`` computed with torch."""
    expanded = torch.exp(x)
    return expanded / C
45
+
46
+
47
def spectral_normalize_torch(magnitudes):
    """Log-compress ``magnitudes`` using the default settings of dynamic_range_compression_torch."""
    compressed = dynamic_range_compression_torch(magnitudes)
    return compressed
49
+
50
+
51
def spectral_de_normalize_torch(magnitudes):
    """Undo the log compression using the default settings of dynamic_range_decompression_torch."""
    expanded = dynamic_range_decompression_torch(magnitudes)
    return expanded
53
+
54
+
55
+ mel_basis_cache = {}
56
+ hann_window_cache = {}
57
+
58
+
59
def mel_spectrogram(
    y: torch.Tensor,
    n_fft: int,
    num_mels: int,
    sampling_rate: int,
    hop_size: int,
    win_size: int,
    fmin: int,
    fmax: int = None,
    center: bool = False,
) -> torch.Tensor:
    """
    Compute the log-compressed mel spectrogram of a waveform.

    The mel filterbank comes from librosa.filters.mel and the STFT uses a Hann window
    (torch.stft). Filterbank and window are cached per parameter set and device.

    Args:
        y (torch.Tensor): Input signal.
        n_fft (int): FFT size.
        num_mels (int): Number of mel bins.
        sampling_rate (int): Sampling rate of the input signal.
        hop_size (int): Hop size for STFT.
        win_size (int): Window size for STFT.
        fmin (int): Minimum frequency for mel filterbank.
        fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
        center (bool): Forwarded to torch.stft as ``center``. Default is False.

    Returns:
        torch.Tensor: Mel spectrogram.
    """
    # Out-of-range samples are only reported, never corrected.
    lowest = torch.min(y)
    if lowest < -1.0:
        print(f"[WARNING] Min value of input waveform signal is {lowest}")
    highest = torch.max(y)
    if highest > 1.0:
        print(f"[WARNING] Max value of input waveform signal is {highest}")

    device = y.device
    cache_key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"

    # Build the filterbank and window once per configuration and device.
    if cache_key not in mel_basis_cache:
        filterbank = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis_cache[cache_key] = torch.from_numpy(filterbank).float().to(device)
        hann_window_cache[cache_key] = torch.hann_window(win_size).to(device)

    mel_basis = mel_basis_cache[cache_key]
    hann_window = hann_window_cache[cache_key]

    # Reflect-pad both ends by (n_fft - hop_size) // 2 before the STFT.
    edge = (n_fft - hop_size) // 2
    padded = torch.nn.functional.pad(y.unsqueeze(1), (edge, edge), mode="reflect")
    padded = padded.squeeze(1)

    stft_out = torch.stft(
        padded,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    # Magnitude with a small epsilon under the square root.
    magnitude = torch.sqrt(torch.view_as_real(stft_out).pow(2).sum(-1) + 1e-9)

    return spectral_normalize_torch(torch.matmul(mel_basis, magnitude))
129
+
130
+
131
def get_mel_spectrogram(wav, h):
    """
    Compute the mel spectrogram of ``wav`` with the settings stored in ``h``.

    Args:
        wav (torch.Tensor): Input waveform.
        h: Hyperparameters object with attributes n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax.

    Returns:
        torch.Tensor: Mel spectrogram.
    """
    return mel_spectrogram(
        y=wav,
        n_fft=h.n_fft,
        num_mels=h.num_mels,
        sampling_rate=h.sampling_rate,
        hop_size=h.hop_size,
        win_size=h.win_size,
        fmin=h.fmin,
        fmax=h.fmax,
    )
152
+
153
+
154
def _read_wav_filelist(list_path, wavs_dir):
    """Return the ``.wav`` paths under ``wavs_dir`` named by the first '|' field of each non-empty line."""
    with open(list_path, "r", encoding="utf-8") as fi:
        return [
            os.path.join(wavs_dir, line.split("|")[0] + ".wav")
            for line in fi.read().split("\n")
            if line
        ]


def _first_entry(files):
    """Return the first path for logging, or a placeholder when the filelist is empty."""
    return files[0] if files else "<empty filelist>"


def get_dataset_filelist(a):
    """
    Read the training, validation and unseen-validation filelists referenced by ``a``.

    Args:
        a: Object with attributes ``input_training_file``, ``input_validation_file``,
            ``input_wavs_dir``, ``list_input_unseen_validation_file`` and
            ``list_input_unseen_wavs_dir`` (the last two are indexed in parallel).

    Returns:
        tuple: ``(training_files, validation_files, list_unseen_validation_files)``; the
        last item holds one list of paths per unseen-validation filelist.
    """
    training_files = _read_wav_filelist(a.input_training_file, a.input_wavs_dir)
    print(f"first training file: {_first_entry(training_files)}")

    validation_files = _read_wav_filelist(a.input_validation_file, a.input_wavs_dir)
    print(f"first validation file: {_first_entry(validation_files)}")

    list_unseen_validation_files = []
    for i, unseen_list_path in enumerate(a.list_input_unseen_validation_file):
        # Indexed (not zipped) so a missing wav directory still fails loudly.
        unseen_validation_files = _read_wav_filelist(
            unseen_list_path, a.list_input_unseen_wavs_dir[i]
        )
        # The log line is informational only; an empty filelist must not crash it.
        print(
            f"first unseen {i}th validation fileset: {_first_entry(unseen_validation_files)}"
        )
        list_unseen_validation_files.append(unseen_validation_files)

    return training_files, validation_files, list_unseen_validation_files
188
+
189
+
190
class MelDataset(torch.utils.data.Dataset):
    """
    Dataset yielding ``(mel, audio, filename, mel_loss)`` tuples for the listed wav files.

    When ``fine_tuning`` is False the mel is computed from the audio; when it is True the
    mel is loaded from a ``.npy`` file in ``base_mels_path`` named after the wav file.
    """

    def __init__(
        self,
        training_files,
        hparams,
        segment_size,
        n_fft,
        num_mels,
        hop_size,
        win_size,
        sampling_rate,
        fmin,
        fmax,
        split=True,
        shuffle=True,
        n_cache_reuse=1,
        device=None,
        fmax_loss=None,
        fine_tuning=False,
        base_mels_path=None,
        is_seen=True,
    ):
        # NOTE: the caller's list is stored by reference, so the shuffle below reorders it in place.
        self.audio_files = training_files
        # Seeds Python's global `random` module, not a private generator.
        random.seed(1234)
        if shuffle:
            random.shuffle(self.audio_files)
        self.hparams = hparams
        self.is_seen = is_seen
        # The dataset name is taken from the leading path component(s) of the first file.
        if self.is_seen:
            self.name = pathlib.Path(self.audio_files[0]).parts[0]
        else:
            self.name = "-".join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/")

        self.segment_size = segment_size
        self.sampling_rate = sampling_rate
        self.split = split
        self.n_fft = n_fft
        self.num_mels = num_mels
        self.hop_size = hop_size
        self.win_size = win_size
        self.fmin = fmin
        self.fmax = fmax
        self.fmax_loss = fmax_loss
        # Last loaded waveform and the number of further __getitem__ calls that will reuse it.
        self.cached_wav = None
        self.n_cache_reuse = n_cache_reuse
        self._cache_ref_count = 0
        self.device = device
        self.fine_tuning = fine_tuning
        self.base_mels_path = base_mels_path

        # Fail early if any listed file is missing (assert-based, so skipped under `python -O`).
        print("[INFO] checking dataset integrity...")
        for i in tqdm(range(len(self.audio_files))):
            assert os.path.exists(
                self.audio_files[i]
            ), f"{self.audio_files[i]} not found"

    def __getitem__(self, index):
        """Return ``(mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())`` for one item."""
        filename = self.audio_files[index]
        if self._cache_ref_count == 0:
            audio, sampling_rate = load_wav(filename, self.sampling_rate)
            audio = audio / MAX_WAV_VALUE
            if not self.fine_tuning:
                audio = normalize(audio) * 0.95
            self.cached_wav = audio
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR"
                )
            self._cache_ref_count = self.n_cache_reuse
        else:
            # NOTE(review): the cached waveform is reused whatever `index` is, while `filename`
            # still follows `index` — confirm this is intended when n_cache_reuse > 0.
            audio = self.cached_wav
            self._cache_ref_count -= 1

        audio = torch.FloatTensor(audio)
        # Add a leading dimension; the code below slices and pads along dimension 1.
        audio = audio.unsqueeze(0)

        if not self.fine_tuning:
            if self.split:
                # Take a random window of segment_size samples, or right-pad when the clip is shorter.
                if audio.size(1) >= self.segment_size:
                    max_audio_start = audio.size(1) - self.segment_size
                    audio_start = random.randint(0, max_audio_start)
                    audio = audio[:, audio_start : audio_start + self.segment_size]
                else:
                    audio = torch.nn.functional.pad(
                        audio, (0, self.segment_size - audio.size(1)), "constant"
                    )

                mel = mel_spectrogram(
                    audio,
                    self.n_fft,
                    self.num_mels,
                    self.sampling_rate,
                    self.hop_size,
                    self.win_size,
                    self.fmin,
                    self.fmax,
                    center=False,
                )
            else:  # Validation step
                # Trim the audio to a whole multiple of self.hop_size for evaluation
                if (audio.size(1) % self.hop_size) != 0:
                    audio = audio[:, : -(audio.size(1) % self.hop_size)]
                mel = mel_spectrogram(
                    audio,
                    self.n_fft,
                    self.num_mels,
                    self.sampling_rate,
                    self.hop_size,
                    self.win_size,
                    self.fmin,
                    self.fmax,
                    center=False,
                )
                # After trimming, every mel frame must correspond to exactly hop_size samples.
                assert (
                    audio.shape[1] == mel.shape[2] * self.hop_size
                ), f"audio shape {audio.shape} mel shape {mel.shape}"

        else:
            # Fine-tuning: load the mel stored as <wav basename>.npy under base_mels_path.
            mel = np.load(
                os.path.join(
                    self.base_mels_path,
                    os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
                )
            )
            mel = torch.from_numpy(mel)

            if len(mel.shape) < 3:
                mel = mel.unsqueeze(0)

            if self.split:
                frames_per_seg = math.ceil(self.segment_size / self.hop_size)

                if audio.size(1) >= self.segment_size:
                    # Pick a random mel window and cut the matching span of audio samples.
                    mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
                    mel = mel[:, :, mel_start : mel_start + frames_per_seg]
                    audio = audio[
                        :,
                        mel_start
                        * self.hop_size : (mel_start + frames_per_seg)
                        * self.hop_size,
                    ]
                else:
                    # Clip shorter than one segment: right-pad both the mel and the audio.
                    mel = torch.nn.functional.pad(
                        mel, (0, frames_per_seg - mel.size(2)), "constant"
                    )
                    audio = torch.nn.functional.pad(
                        audio, (0, self.segment_size - audio.size(1)), "constant"
                    )

        # Second mel computed from the final audio, using fmax_loss as the upper frequency.
        mel_loss = mel_spectrogram(
            audio,
            self.n_fft,
            self.num_mels,
            self.sampling_rate,
            self.hop_size,
            self.win_size,
            self.fmin,
            self.fmax_loss,
            center=False,
        )

        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())

    def __len__(self):
        """Return the number of listed audio files."""
        return len(self.audio_files)
modules/bigvgan/utils.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import glob
5
+ import os
6
+ import matplotlib
7
+ import torch
8
+ from torch.nn.utils import weight_norm
9
+
10
+ matplotlib.use("Agg")
11
+ import matplotlib.pylab as plt
12
+ from .meldataset import MAX_WAV_VALUE
13
+ from scipy.io.wavfile import write
14
+
15
+
16
def plot_spectrogram(spectrogram):
    """Draw ``spectrogram`` as an image with a colorbar and return the (closed) figure."""
    figure, axes = plt.subplots(figsize=(10, 2))
    image = axes.imshow(
        spectrogram, aspect="auto", origin="lower", interpolation="none"
    )
    plt.colorbar(image, ax=axes)

    # Render once, then close the current pyplot figure.
    figure.canvas.draw()
    plt.close()

    return figure
25
+
26
+
27
def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
    """Like plot_spectrogram, but with the color scale limited to ``[1e-6, clip_max]``."""
    figure, axes = plt.subplots(figsize=(10, 2))
    image_options = {
        "aspect": "auto",
        "origin": "lower",
        "interpolation": "none",
        "vmin": 1e-6,
        "vmax": clip_max,
    }
    image = axes.imshow(spectrogram, **image_options)
    plt.colorbar(image, ax=axes)

    # Render once, then close the current pyplot figure.
    figure.canvas.draw()
    plt.close()

    return figure
43
+
44
+
45
def init_weights(m, mean=0.0, std=0.01):
    """Re-draw ``m.weight`` from N(mean, std) when the module's class name contains "Conv"."""
    is_conv = "Conv" in m.__class__.__name__
    if is_conv:
        m.weight.data.normal_(mean, std)
49
+
50
+
51
def apply_weight_norm(m):
    """Wrap ``m`` with weight normalisation when its class name contains "Conv"."""
    is_conv = "Conv" in m.__class__.__name__
    if is_conv:
        weight_norm(m)
55
+
56
+
57
def get_padding(kernel_size, dilation=1):
    """Return ``int((kernel_size * dilation - dilation) / 2)``, the padding used for the convolutions."""
    dilated_span = kernel_size * dilation - dilation
    return int(dilated_span / 2)
59
+
60
+
61
def load_checkpoint(filepath, device):
    """Load a checkpoint with torch.load, mapping storages to ``device``; asserts the file exists."""
    assert os.path.isfile(filepath)
    print(f"Loading '{filepath}'")
    loaded = torch.load(filepath, map_location=device)
    print("Complete.")
    return loaded
67
+
68
+
69
def save_checkpoint(filepath, obj):
    """Serialise ``obj`` to ``filepath`` with torch.save, logging before and after."""
    print(f"Saving checkpoint to {filepath}")
    torch.save(obj, f=filepath)
    print("Complete.")
73
+
74
+
75
def scan_checkpoint(cp_dir, prefix, renamed_file=None):
    """
    Locate a checkpoint to resume from inside ``cp_dir``.

    Files named ``prefix`` plus exactly eight characters take priority; the one that
    sorts last is returned. Otherwise ``renamed_file`` is returned if it exists in
    ``cp_dir``. Returns None when nothing is found.
    """
    numbered = glob.glob(os.path.join(cp_dir, prefix + "????????"))
    if numbered:
        latest = max(numbered)
        print(f"[INFO] Resuming from checkpoint: '{latest}'")
        return latest

    # No pattern-based checkpoint: fall back to the explicitly renamed file, if one was given.
    if not renamed_file:
        return None

    candidate = os.path.join(cp_dir, renamed_file)
    if not os.path.isfile(candidate):
        return None

    print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
    return candidate
93
+
94
+
95
def save_audio(audio, path, sr):
    """Scale ``audio`` by MAX_WAV_VALUE, cast it to int16 and write it to ``path`` at rate ``sr``."""
    # wav: torch with 1d shape
    scaled = audio * MAX_WAV_VALUE
    pcm = scaled.cpu().numpy().astype("int16")
    write(path, sr, pcm)
modules/campplus/DTDNN.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ from collections import OrderedDict
5
+
6
+ import torch
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+
10
+ from modules.campplus.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, BasicResBlock, get_nonlinear
11
+
12
+
13
class FCM(nn.Module):
    """2-D convolutional front-end whose output merges the channel and feature axes."""

    def __init__(self,
                 block=BasicResBlock,
                 num_blocks=[2, 2],
                 m_channels=32,
                 feat_dim=80):
        super().__init__()
        # Running input width used by _make_layer while stacking blocks.
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)

        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)

        self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``, the rest use 1."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Add a channel axis, run the conv stack, then merge dimensions 1 and 2 of the result."""
        out = x.unsqueeze(1)
        out = F.relu(self.bn1(self.conv1(out)))
        out = self.layer2(self.layer1(out))
        out = F.relu(self.bn2(self.conv2(out)))

        batch, channels, feats, frames = out.shape
        return out.reshape(batch, channels * feats, frames)
49
+
50
class CAMPPlus(nn.Module):
    """
    Embedding network: FCM front-end, dense TDNN blocks, statistics pooling and a dense projection.

    Args:
        feat_dim: Feature dimension passed to the FCM front-end.
        embedding_size: Output size of the final dense layer.
        growth_rate: Output channels added by each layer of a dense TDNN block.
        bn_size: Multiplier applied to ``growth_rate`` to get the bottleneck channels.
        init_channels: Output channels of the first TDNN layer.
        config_str: Normalisation/non-linearity spec forwarded to the layers.
        memory_efficient: Forwarded to CAMDenseTDNNBlock.
    """

    def __init__(self,
                 feat_dim=80,
                 embedding_size=512,
                 growth_rate=32,
                 bn_size=4,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True):
        super(CAMPPlus, self).__init__()

        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels

        self.xvector = nn.Sequential(
            OrderedDict([

                ('tdnn',
                 TDNNLayer(channels,
                           init_channels,
                           5,
                           stride=2,
                           dilation=1,
                           padding=-1,
                           config_str=config_str)),
            ]))
        channels = init_channels
        # Three dense blocks; each is followed by a transit layer that halves the channel count.
        for i, (num_layers, kernel_size,
                dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
            block = CAMDenseTDNNBlock(num_layers=num_layers,
                                      in_channels=channels,
                                      out_channels=growth_rate,
                                      bn_channels=bn_size * growth_rate,
                                      kernel_size=kernel_size,
                                      dilation=dilation,
                                      config_str=config_str,
                                      memory_efficient=memory_efficient)
            self.xvector.add_module('block%d' % (i + 1), block)
            channels = channels + num_layers * growth_rate
            self.xvector.add_module(
                'transit%d' % (i + 1),
                TransitLayer(channels,
                             channels // 2,
                             bias=False,
                             config_str=config_str))
            channels //= 2

        self.xvector.add_module(
            'out_nonlinear', get_nonlinear(config_str, channels))

        # Pooling and the final projection live outside ``xvector`` so that forward() can
        # pass ``x_lens`` to the pooling layer; load_state_dict() remaps older checkpoints.
        self.stats = StatsPool()
        self.dense = DenseLayer(channels * 2, embedding_size, config_str='batchnorm_')

        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def load_state_dict(self, state_dict, strict=True):
        """
        Load ``state_dict``, remapping keys from a previous version of the model where
        the stats and dense layers were part of xvector.

        Returns:
            The result of ``nn.Module.load_state_dict`` (missing and unexpected keys).
        """
        legacy_prefixes = (('xvector.stats', 'stats'), ('xvector.dense', 'dense'))
        new_state_dict = {}

        # Remap keys for compatibility
        for key, value in state_dict.items():
            new_key = key
            for old_prefix, new_prefix in legacy_prefixes:
                if key.startswith(old_prefix):
                    new_key = new_prefix + key[len(old_prefix):]
                    break
            new_state_dict[new_key] = value

        # Propagate the base-class result so callers can inspect missing/unexpected keys.
        return super(CAMPPlus, self).load_state_dict(new_state_dict, strict)

    def forward(self, x, x_lens=None):
        """Compute the embedding of ``x`` given as (B,T,F); ``x_lens`` is passed to the pooling layer."""
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = self.head(x)
        x = self.xvector(x)
        x = self.stats(x, x_lens)
        x = self.dense(x)
        return x
modules/campplus/__init__.py ADDED
File without changes
modules/campplus/classifier.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from modules.campplus.layers import DenseLayer
9
+
10
+
11
class CosineClassifier(nn.Module):
    """Classifier whose logits are cosine similarities to learned class vectors.

    ``num_blocks`` optional ``DenseLayer`` blocks (batchnorm only) are applied
    before the cosine projection onto ``out_neurons`` weight rows.
    """

    def __init__(
        self,
        input_dim,
        num_blocks=0,
        inter_dim=512,
        out_neurons=1000,
    ):
        super().__init__()
        self.blocks = nn.ModuleList()

        feature_dim = input_dim
        for _ in range(num_blocks):
            block = DenseLayer(feature_dim, inter_dim, config_str='batchnorm')
            self.blocks.append(block)
            feature_dim = inter_dim

        # Allocated uninitialised, then filled by Xavier-uniform init below.
        self.weight = nn.Parameter(torch.FloatTensor(out_neurons, feature_dim))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        # x: [B, dim]
        for block in self.blocks:
            x = block(x)

        # Both the features and the class weights are L2-normalised, so the
        # linear product is a cosine similarity.
        unit_x = F.normalize(x)
        unit_w = F.normalize(self.weight)
        return F.linear(unit_x, unit_w)
42
+
43
class LinearClassifier(nn.Module):
    """ReLU, optional ``DenseLayer`` blocks, then a biased linear projection."""

    def __init__(
        self,
        input_dim,
        num_blocks=0,
        inter_dim=512,
        out_neurons=1000,
    ):
        super().__init__()
        self.blocks = nn.ModuleList()

        # In-place ReLU: the tensor passed to forward() is modified.
        self.nonlinear = nn.ReLU(inplace=True)

        feature_dim = input_dim
        for _ in range(num_blocks):
            self.blocks.append(DenseLayer(feature_dim, inter_dim, bias=True))
            feature_dim = inter_dim

        self.linear = nn.Linear(feature_dim, out_neurons, bias=True)

    def forward(self, x):
        # x: [B, dim]
        hidden = self.nonlinear(x)
        for block in self.blocks:
            hidden = block(hidden)
        return self.linear(hidden)
modules/campplus/layers.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torch.utils.checkpoint as cp
7
+ from torch import nn
8
+
9
+
10
def get_nonlinear(config_str, channels):
    """Build an ``nn.Sequential`` from a dash-separated spec such as 'batchnorm-relu'.

    Recognised tokens: ``relu``, ``prelu``, ``batchnorm`` and ``batchnorm_``
    (batch norm without affine parameters). Both batch-norm variants are
    registered under the module name 'batchnorm'.

    :raises ValueError: on any unrecognised token.
    """
    # Each entry returns (registered_name, module); construction is deferred
    # so only the requested modules are instantiated.
    factories = {
        'relu': lambda: ('relu', nn.ReLU(inplace=True)),
        'prelu': lambda: ('prelu', nn.PReLU(channels)),
        'batchnorm': lambda: ('batchnorm', nn.BatchNorm1d(channels)),
        'batchnorm_': lambda: ('batchnorm', nn.BatchNorm1d(channels, affine=False)),
    }

    stack = nn.Sequential()
    for token in config_str.split('-'):
        factory = factories.get(token)
        if factory is None:
            raise ValueError('Unexpected module ({}).'.format(token))
        module_name, module = factory()
        stack.add_module(module_name, module)
    return stack
25
+
26
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
    """Concatenate the mean and standard deviation of ``x`` taken along ``dim``.

    The two statistics are joined on the last axis. With ``keepdim`` the
    result gets an extra axis inserted at ``dim``. ``eps`` is accepted but
    not used by this implementation.
    """
    pooled = torch.cat(
        [x.mean(dim=dim), x.std(dim=dim, unbiased=unbiased)],
        dim=-1,
    )
    return pooled.unsqueeze(dim=dim) if keepdim else pooled
33
+
34
def masked_statistics_pooling(x, x_lens, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
    """Per-item mean/std pooling that ignores frames beyond each item's length.

    For item ``i`` only ``x[i, :, :x_lens[i]]`` contributes. The per-item
    statistics are stacked along a new leading axis. ``eps`` is accepted but
    not used by this implementation.
    """
    def _pool_one(index, valid_len):
        trimmed = x[index, :, :valid_len]
        return torch.cat(
            [trimmed.mean(dim=dim), trimmed.std(dim=dim, unbiased=unbiased)],
            dim=-1,
        )

    per_item = [_pool_one(i, n) for i, n in enumerate(x_lens)]
    pooled = torch.stack(per_item, dim=0)
    if not keepdim:
        return pooled
    return pooled.unsqueeze(dim=dim)
45
+
46
+
47
class StatsPool(nn.Module):
    """Module wrapper choosing masked or plain statistics pooling."""

    def forward(self, x, x_lens=None):
        # Without lengths every frame counts; with lengths each item is
        # trimmed before its statistics are computed.
        if x_lens is None:
            return statistics_pooling(x)
        return masked_statistics_pooling(x, x_lens)
52
+
53
+
54
class TDNNLayer(nn.Module):
    """1-D convolution followed by the nonlinearity stack named in ``config_str``.

    A negative ``padding`` requests symmetric padding derived from the kernel
    size and dilation, which requires an odd kernel size.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu'):
        super().__init__()
        if padding < 0:
            assert kernel_size % 2 == 1, \
                'Expect equal paddings, but got even kernel size ({})'.format(kernel_size)
            padding = (kernel_size - 1) // 2 * dilation
        conv_kwargs = dict(stride=stride, padding=padding, dilation=dilation, bias=bias)
        self.linear = nn.Conv1d(in_channels, out_channels, kernel_size, **conv_kwargs)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        return self.nonlinear(self.linear(x))
82
+
83
+
84
class CAMLayer(nn.Module):
    """Convolution whose output is gated by a sigmoid mask computed from context.

    The context is the sum of the global mean over the last axis and a
    segment-wise pooled copy of the input; it passes through two 1x1
    convolutions (with a ReLU between them) and a sigmoid to form the mask.
    """

    def __init__(self,
                 bn_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation,
                 bias,
                 reduction=2):
        super().__init__()
        squeezed = bn_channels // reduction
        self.linear_local = nn.Conv1d(bn_channels,
                                      out_channels,
                                      kernel_size,
                                      stride=stride,
                                      padding=padding,
                                      dilation=dilation,
                                      bias=bias)
        self.linear1 = nn.Conv1d(bn_channels, squeezed, 1)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Conv1d(squeezed, out_channels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        local = self.linear_local(x)
        global_ctx = x.mean(-1, keepdim=True)
        context = global_ctx + self.seg_pooling(x)
        hidden = self.relu(self.linear1(context))
        mask = self.sigmoid(self.linear2(hidden))
        return local * mask

    def seg_pooling(self, x, seg_len=100, stype='avg'):
        """Pool ``x`` in windows of ``seg_len`` and stretch back to the input length.

        :raises ValueError: if ``stype`` is neither 'avg' nor 'max'.
        """
        if stype not in ('avg', 'max'):
            raise ValueError('Wrong segment pooling type.')
        pool = F.avg_pool1d if stype == 'avg' else F.max_pool1d
        pooled = pool(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)

        # Repeat every pooled value seg_len times, then trim the overshoot
        # introduced by ceil_mode so the length matches the input again.
        pooled_shape = pooled.shape
        stretched = pooled.unsqueeze(-1).expand(*pooled_shape, seg_len)
        stretched = stretched.reshape(*pooled_shape[:-1], -1)
        return stretched[..., :x.shape[-1]]
125
+
126
+
127
class CAMDenseTDNNLayer(nn.Module):
    """Bottleneck (nonlinearity + 1x1 conv) followed by a ``CAMLayer``.

    With ``memory_efficient`` set, the bottleneck is run through
    ``torch.utils.checkpoint`` while the module is in training mode.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super().__init__()
        assert kernel_size % 2 == 1, \
            'Expect equal paddings, but got even kernel size ({})'.format(kernel_size)
        same_padding = (kernel_size - 1) // 2 * dilation
        self.memory_efficient = memory_efficient
        self.nonlinear1 = get_nonlinear(config_str, in_channels)
        self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
        self.nonlinear2 = get_nonlinear(config_str, bn_channels)
        self.cam_layer = CAMLayer(bn_channels,
                                  out_channels,
                                  kernel_size,
                                  stride=stride,
                                  padding=same_padding,
                                  dilation=dilation,
                                  bias=bias)

    def bn_function(self, x):
        """Bottleneck path: first nonlinearity, then the 1x1 convolution."""
        activated = self.nonlinear1(x)
        return self.linear1(activated)

    def forward(self, x):
        use_checkpoint = self.training and self.memory_efficient
        if use_checkpoint:
            bottleneck = cp.checkpoint(self.bn_function, x)
        else:
            bottleneck = self.bn_function(x)
        return self.cam_layer(self.nonlinear2(bottleneck))
164
+
165
+
166
class CAMDenseTDNNBlock(nn.ModuleList):
    """Densely connected stack of ``CAMDenseTDNNLayer`` modules.

    Layer ``i`` (0-based) receives ``in_channels + i * out_channels`` input
    channels, because every layer's output is concatenated onto its input
    along dim 1 before being passed on.
    """

    def __init__(self,
                 num_layers,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super().__init__()
        for index in range(num_layers):
            grown_channels = in_channels + index * out_channels
            self.add_module(
                'tdnnd%d' % (index + 1),
                CAMDenseTDNNLayer(in_channels=grown_channels,
                                  out_channels=out_channels,
                                  bn_channels=bn_channels,
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  dilation=dilation,
                                  bias=bias,
                                  config_str=config_str,
                                  memory_efficient=memory_efficient))

    def forward(self, x):
        for dense_layer in self:
            new_features = dense_layer(x)
            x = torch.cat([x, new_features], dim=1)
        return x
195
+
196
+
197
class TransitLayer(nn.Module):
    """Nonlinearity stack followed by a 1x1 convolution changing the channel count."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=True,
                 config_str='batchnorm-relu'):
        super().__init__()
        self.nonlinear = get_nonlinear(config_str, in_channels)
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        # Note the order: the nonlinearity is applied before the projection.
        return self.linear(self.nonlinear(x))
211
+
212
+
213
class DenseLayer(nn.Module):
    """1x1 convolution followed by a nonlinearity stack.

    Two-dimensional inputs are given a temporary trailing axis of length one
    for the convolution and squeezed back afterwards.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=False,
                 config_str='batchnorm-relu'):
        super().__init__()
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        if x.dim() == 2:
            projected = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
        else:
            projected = self.linear(x)
        return self.nonlinear(projected)
230
+
231
+
232
class BasicResBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages with a residual shortcut.

    ``stride`` is applied to the first spatial axis only (``(stride, 1)``).
    The shortcut is an identity unless the stride or the channel count
    changes, in which case it is a 1x1 convolution plus batch norm.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=(stride, 1), padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=(stride, 1), bias=False),
                nn.BatchNorm2d(out_planes))
        else:
            # An empty Sequential passes its input through unchanged.
            shortcut = nn.Sequential()
        self.shortcut = shortcut

    def forward(self, x):
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return F.relu(out)
modules/commons.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from munch import Munch
7
+ import json
8
+ import argparse
9
+
10
def str2bool(v):
    """Convert a command-line style value to a bool.

    Booleans are returned unchanged. Strings are matched case-insensitively
    against yes/true/t/y/1 and no/false/f/n/0.

    :raises argparse.ArgumentTypeError: if the string matches neither set.
    """
    if isinstance(v, bool):
        return v

    truthy = ("yes", "true", "t", "y", "1")
    falsy = ("no", "false", "f", "n", "0")
    lowered = v.lower()
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
19
+
20
class AttrDict(dict):
    """Dictionary whose items are also readable and writable as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Aliasing the instance namespace to the dict itself makes
        # ``obj.key`` and ``obj["key"]`` refer to the same storage.
        self.__dict__ = self
24
+
25
+
26
def init_weights(m, mean=0.0, std=0.01):
    """Draw the weights of convolution-like modules from N(mean, std).

    A module qualifies when its class name contains "Conv"; every other
    module is left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
30
+
31
+
32
def get_padding(kernel_size, dilation=1):
    """Return ``int((kernel_size * dilation - dilation) / 2)``, the padding for a dilated kernel."""
    total_span = kernel_size * dilation - dilation
    return int(total_span / 2)
34
+
35
+
36
def convert_pad_shape(pad_shape):
    """Reverse a list of per-dimension pad pairs and flatten it into one list.

    The flat, last-dimension-first ordering is the layout passed to ``F.pad``
    by callers in this file.
    """
    flat = []
    for pair in pad_shape[::-1]:
        flat.extend(pair)
    return flat
40
+
41
+
42
def intersperse(lst, item):
    """Return a list with ``item`` before, between and after the elements of ``lst``."""
    padded_len = 2 * len(lst) + 1
    result = [item] * padded_len
    # Odd positions receive the original elements; even positions keep `item`.
    result[1::2] = lst
    return result
46
+
47
+
48
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q), computed element-wise from means and log-scale parameters."""
    variance_p = torch.exp(2.0 * logs_p)
    squared_mean_gap = (m_p - m_q) ** 2
    inverse_variance_q = torch.exp(-2.0 * logs_q)

    kl = (logs_q - logs_p) - 0.5
    kl += 0.5 * (variance_p + squared_mean_gap) * inverse_variance_q
    return kl
55
+
56
+
57
def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    # Rescale the uniform draws into [0.00001, 0.99999) so neither logarithm
    # below is evaluated at 0.
    u = torch.rand(shape) * 0.99998 + 0.00001
    inner = -torch.log(u)
    return -torch.log(inner)
61
+
62
+
63
def rand_gumbel_like(x):
    """Return Gumbel noise with the same size, dtype and device as ``x``."""
    noise = rand_gumbel(x.size())
    return noise.to(dtype=x.dtype, device=x.device)
66
+
67
+
68
def slice_segments(x, ids_str, segment_size=4):
    """Cut one window of ``segment_size`` along the last axis from every batch item.

    ``ids_str[i]`` is the start index of the window taken from ``x[i]``.
    """
    segments = torch.zeros_like(x[:, :, :segment_size])
    for batch_idx in range(x.size(0)):
        start = ids_str[batch_idx]
        segments[batch_idx] = x[batch_idx, :, start:start + segment_size]
    return segments
75
+
76
+
77
def slice_segments_audio(x, ids_str, segment_size=4):
    """Two-dimensional variant of ``slice_segments``.

    Takes ``x[i, ids_str[i] : ids_str[i] + segment_size]`` for every row.
    """
    segments = torch.zeros_like(x[:, :segment_size])
    for row in range(x.size(0)):
        start = ids_str[row]
        segments[row] = x[row, start:start + segment_size]
    return segments
84
+
85
+
86
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Slice a randomly positioned window of ``segment_size`` from each batch item.

    :param x: tensor unpacked as (b, d, t).
    :param x_lengths: usable length(s); defaults to the full length ``t``.
    :return: ``(segments, start_indices)``.
    """
    batch, _, total_len = x.size()
    if x_lengths is None:
        x_lengths = total_len
    max_start = x_lengths - segment_size + 1
    # Uniform draws scaled to the valid start range; clip(0) guards against a
    # negative range when an item is shorter than the window.
    draws = torch.rand([batch]).to(device=x.device)
    ids_str = (draws * max_start).clip(0).to(dtype=torch.long)
    return slice_segments(x, ids_str, segment_size), ids_str
96
+
97
+
98
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Build a sinusoidal timing signal of shape ``(1, channels, length)``.

    Half of the channels carry sines and the other half cosines of the
    position scaled by geometrically spaced timescales between
    ``min_timescale`` and ``max_timescale``; an odd ``channels`` gets one
    extra all-zero channel.

    :param length: number of positions.
    :param channels: number of output channels.
    :param min_timescale: smallest timescale.
    :param max_timescale: largest timescale.
    :return: float tensor of shape ``(1, channels, length)``.
    """
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    # With a single timescale (channels of 2 or 3) the original divisor
    # ``num_timescales - 1`` was zero and raised ZeroDivisionError; clamp it
    # to 1 so small channel counts work. Larger counts are unaffected.
    log_timescale_increment = math.log(
        float(max_timescale) / float(min_timescale)
    ) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    # Pad one zero row when `channels` is odd so the view below fits.
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal
112
+
113
+
114
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    """Add the sinusoidal timing signal to ``x``, which is unpacked as (b, channels, length)."""
    _, channels, length = x.size()
    timing = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return x + timing
118
+
119
+
120
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    """Concatenate the sinusoidal timing signal onto ``x`` along ``axis``."""
    _, channels, length = x.size()
    timing = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return torch.cat([x, timing], axis)
124
+
125
+
126
def subsequent_mask(length):
    """Return a lower-triangular mask of ones with shape (1, 1, length, length)."""
    lower = torch.tril(torch.ones(length, length))
    return lower[None, None]
129
+
130
+
131
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # Gated activation: the first n_channels[0] channels of the sum go through
    # tanh, the remaining channels through sigmoid, and the two are multiplied.
    split_at = n_channels[0]
    summed = input_a + input_b
    tanh_part = torch.tanh(summed[:, :split_at, :])
    gate_part = torch.sigmoid(summed[:, split_at:, :])
    return tanh_part * gate_part
139
+
140
+
141
def convert_pad_shape(pad_shape):
    """Flatten per-dimension pad pairs in reverse dimension order.

    NOTE: this repeats an identical definition that appears earlier in the
    module; being later, this one is what the name is bound to after import.
    """
    flattened = []
    for dim_pads in reversed(pad_shape[:]):
        flattened += list(dim_pads)
    return flattened
145
+
146
+
147
def shift_1d(x):
    """Shift ``x`` one step to the right along the last axis, zero-filling the front."""
    # Pad one element at the start of the last axis, then drop the final
    # element so the length is unchanged.
    front_pad = convert_pad_shape([[0, 0], [0, 0], [1, 0]])
    padded = F.pad(x, front_pad)
    return padded[:, :, :-1]
150
+
151
+
152
def sequence_mask(length, max_length=None):
    """Boolean mask where entry ``[i, j]`` is True when ``j < length[i]``.

    ``max_length`` defaults to the largest value in ``length``.
    """
    width = length.max() if max_length is None else max_length
    positions = torch.arange(width, dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]
157
+
158
+
159
def avg_with_mask(x, mask):
    """Average ``x`` weighted by ``mask``: ``sum(x * mask) / sum(mask)``.

    A 2-D mask gets a singleton axis inserted at position 1, and a mask whose
    second axis has size one is expanded to the shape of ``x``. The mask must
    be of float dtype.
    """
    assert mask.dtype == torch.float, "Mask should be float"

    weights = mask.unsqueeze(1) if mask.ndim == 2 else mask
    if weights.shape[1] == 1:
        weights = weights.expand_as(x)

    weighted_total = (x * weights).sum()
    return weighted_total / weights.sum()
169
+
170
+
171
def generate_path(duration, mask):
    """
    Expand per-step durations into an alignment path.

    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    batch, _, t_y, t_x = mask.shape

    # Cumulative durations mark where each input step's span ends.
    span_ends = torch.cumsum(duration, -1).view(batch * t_x)
    path = sequence_mask(span_ends, t_y).to(mask.dtype).view(batch, t_x, t_y)

    # Subtracting the copy shifted by one step along axis 1 leaves only the
    # frames belonging to each individual step.
    shifted = F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path - shifted
    return path.unsqueeze(1).transpose(2, 3) * mask
187
+
188
+
189
def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp gradients element-wise and return their total norm before clamping.

    :param parameters: a tensor or an iterable of tensors; entries without a
        gradient are ignored.
    :param clip_value: gradients are clamped in place to
        ``[-clip_value, clip_value]``; ``None`` disables clamping.
    :param norm_type: order of the norm used for the returned total.
    :return: the total gradient norm as a Python number.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    with_grad = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    accumulated = 0
    for param in with_grad:
        grad = param.grad.data
        # The norm is measured first, so it reflects the unclamped gradient.
        accumulated += grad.norm(norm_type).item() ** norm_type
        if clip_value is not None:
            grad.clamp_(min=-clip_value, max=clip_value)
    return accumulated ** (1.0 / norm_type)
205
+
206
+
207
def log_norm(x, mean=-4, std=4, dim=2):
    """
    normalized log mel -> mel -> norm -> log(norm)
    """
    # Undo the normalisation, leave the log domain, take the norm along
    # `dim`, then return to the log domain.
    mel = torch.exp(x * std + mean)
    return torch.log(mel.norm(dim=dim))
213
+
214
+
215
def load_F0_models(path):
    """Build a ``JDCNet`` F0 model and load its weights from ``path``.

    The checkpoint is read on CPU and its "net" entry is used as the state
    dict. The returned model has had ``.train()`` called on it.
    """
    # Imported lazily so the JDC package is only needed when this is called.
    from .JDC.model import JDCNet

    checkpoint = torch.load(path, map_location="cpu")
    F0_model = JDCNet(num_class=1, seq_len=192)
    F0_model.load_state_dict(checkpoint["net"])
    F0_model.train()

    return F0_model
225
+
226
+
227
def modify_w2v_forward(self, output_layer=15):
    """
    Build a replacement forward function for a w2v encoder that stops early,
    so the output of an intermediate layer is returned as the final hidden
    state.

    The returned closure is not bound to the encoder automatically; it reads
    the encoder through the captured ``self``.

    :param self: the encoder whose ``dropout``, ``embed_positions``,
        ``layers``, ``config`` and gradient-checkpointing attributes are used.
    :param output_layer: 1-based index of the last layer to run; the loop
        breaks once layer ``output_layer - 1`` (0-based) has been processed.
    :return: the new ``forward`` callable.
    """
    from transformers.modeling_outputs import BaseModelOutput

    def forward(
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        # Keep the un-expanded mask; it is handed to each layer separately.
        conv_attention_mask = attention_mask
        if attention_mask is not None:
            # make sure padded tokens output 0
            hidden_states = hidden_states.masked_fill(
                ~attention_mask.bool().unsqueeze(-1), 0.0
            )

            # extend attention_mask: masked positions become the most negative
            # value of the dtype, kept positions become 0 (additive mask)
            attention_mask = 1.0 - attention_mask[:, None, None, :].to(
                dtype=hidden_states.dtype
            )
            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
            attention_mask = attention_mask.expand(
                attention_mask.shape[0],
                1,
                attention_mask.shape[-1],
                attention_mask.shape[-1],
            )

        hidden_states = self.dropout(hidden_states)

        if self.embed_positions is not None:
            relative_position_embeddings = self.embed_positions(hidden_states)
        else:
            relative_position_embeddings = None

        # Hard-coded off here, so skipped layers are never executed.
        deepspeed_zero3_is_enabled = False

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            # Layers are only dropped in training mode.
            skip_the_layer = (
                True
                if self.training and (dropout_probability < self.config.layerdrop)
                else False
            )
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # under deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        relative_position_embeddings,
                        output_attentions,
                        conv_attention_mask,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        relative_position_embeddings=relative_position_embeddings,
                        output_attentions=output_attentions,
                        conv_attention_mask=conv_attention_mask,
                    )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                # Placeholder so the attention bookkeeping below still indexes.
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

            # Early exit: layers after `output_layer` are never run.
            if i == output_layer - 1:
                break

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    return forward
332
+
333
+
334
+ MATPLOTLIB_FLAG = False
335
+
336
+
337
def plot_spectrogram_to_numpy(spectrogram):
    """Render a spectrogram with Matplotlib and return it as an RGB image array.

    On the first call the Matplotlib backend is switched to "Agg" and the
    "matplotlib" logger is set to WARNING; the module-level
    ``MATPLOTLIB_FLAG`` records that this has been done.

    :param spectrogram: 2-D array-like accepted by ``Axes.imshow``.
    :return: ``uint8`` array of shape (height, width, 3).
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        import logging

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # Read the rendered pixels through the Agg RGBA buffer. The previous
    # np.fromstring(canvas.tostring_rgb(), sep="") relied on two deprecated
    # APIs (NumPy's binary-mode fromstring and Matplotlib's tostring_rgb).
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop alpha and copy so the result is a writable array that does not
    # alias the canvas buffer, as the old fromstring() result was.
    data = rgba[:, :, :3].copy()
    # Close this specific figure rather than whatever is "current".
    plt.close(fig)
    return data
362
+
363
+
364
def normalize_f0(f0_sequence):
    """Z-normalise the log2 F0 of voiced frames and mark unvoiced frames with -1.

    Frames with ``f0 > 0`` are treated as voiced. Their log2 values are
    standardised using the mean and standard deviation over the voiced
    frames; frames with ``f0 <= 0`` are set to -1.

    Edge cases handled explicitly:
    * no voiced frame at all: every frame is -1 and no statistics are
      computed (this avoids empty-slice warnings);
    * voiced frames with (numerically) zero spread: they are set to 0
      instead of the NaN a division by zero would give.

    :param f0_sequence: 1-D NumPy array of F0 values.
    :return: array shaped and typed like ``f0_sequence``.
    """
    # Below this spread the voiced frames are considered constant.
    min_std = 1e-8

    voiced_indices = np.where(f0_sequence > 0)[0]

    # Create the normalized F0 sequence with unvoiced frames marked as -1
    normalized_sequence = np.zeros_like(f0_sequence)
    normalized_sequence[f0_sequence <= 0] = -1

    if voiced_indices.size == 0:
        return normalized_sequence

    # Convert to log scale
    log_f0 = np.log2(f0_sequence[voiced_indices])

    # Calculate mean and standard deviation
    mean_f0 = np.mean(log_f0)
    std_f0 = np.std(log_f0)

    # Normalize the voiced frames
    if std_f0 > min_std:
        normalized_f0 = (log_f0 - mean_f0) / std_f0
    else:
        normalized_f0 = np.zeros_like(log_f0)
    normalized_sequence[voiced_indices] = normalized_f0

    return normalized_sequence
385
+
386
+
387
def build_model(args, stage="DiT"):
    """Instantiate the networks for ``stage`` and return them in a ``Munch``.

    Only the "DiT" stage is supported; it yields ``cfm`` and
    ``length_regulator``. Optional length-regulator settings fall back to
    defaults when absent from the config.

    :raises ValueError: for any other ``stage``.
    """
    if stage != "DiT":
        raise ValueError(f"Unknown stage: {stage}")

    from modules.flow_matching import CFM
    from modules.length_regulator import InterpolateRegulator

    lr_cfg = args.length_regulator
    length_regulator = InterpolateRegulator(
        channels=lr_cfg.channels,
        sampling_ratios=lr_cfg.sampling_ratios,
        is_discrete=lr_cfg.is_discrete,
        in_channels=getattr(lr_cfg, "in_channels", None),
        codebook_size=lr_cfg.content_codebook_size,
        f0_condition=getattr(lr_cfg, "f0_condition", False),
        n_f0_bins=getattr(lr_cfg, "n_f0_bins", 512),
    )
    cfm = CFM(args)
    return Munch(
        cfm=cfm,
        length_regulator=length_regulator,
    )
410
+
411
+
412
def load_checkpoint(
    model,
    optimizer,
    path,
    load_only_params=True,
    ignore_modules=[],
    is_distributed=False,
    load_ema=False,
):
    """Load a checkpoint into a mapping of sub-models and optionally an optimizer.

    :param model: mapping from name to module; iterated by key and indexed
        with ``model[key]``.
    :param optimizer: only touched when ``load_only_params`` is False, in
        which case it must provide ``load_state_dict`` and
        ``load_scheduler_state_dict``.
    :param path: checkpoint file passed to ``torch.load`` (loaded on CPU).
    :param load_only_params: when True, optimizer/scheduler state is not
        restored and epoch/iters are reported as 0.
    :param ignore_modules: names in ``model`` to leave untouched. The default
        list is shared between calls but is only read here.
    :param is_distributed: when False, a leading "module." is stripped from
        parameter names.
    :param load_ema: when True and the checkpoint has an "ema" entry, the EMA
        values replace the regular parameters before loading.
    :return: ``(model, optimizer, epoch, iters)``.
    """
    state = torch.load(path, map_location="cpu")
    params = state["net"]
    if load_ema and "ema" in state:
        print("Loading EMA")
        for key in model:
            # EMA tensors are matched to parameters by position, so `i` only
            # advances for names that are not skipped below.
            i = 0
            for param_name in params[key]:
                if "input_pos" in param_name:
                    continue
                assert params[key][param_name].shape == state["ema"][key][0][i].shape
                params[key][param_name] = state["ema"][key][0][i].clone()
                i += 1
    for key in model:
        if key in params and key not in ignore_modules:
            if not is_distributed:
                # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
                for k in list(params[key].keys()):
                    if k.startswith("module."):
                        params[key][k[len("module.") :]] = params[key][k]
                        del params[key][k]
            model_state_dict = model[key].state_dict()
            # Keep only the key/value pairs whose shapes match the model
            filtered_state_dict = {
                k: v
                for k, v in params[key].items()
                if k in model_state_dict and v.shape == model_state_dict[k].shape
            }
            # NOTE(review): this set also contains keys that are simply absent
            # from the model, although the message below only mentions shape
            # mismatches.
            skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
            if skipped_keys:
                print(
                    f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
                )
            print("%s loaded" % key)
            model[key].load_state_dict(filtered_state_dict, strict=False)
    # Every sub-model is switched to eval mode, whether or not it was loaded.
    _ = [model[key].eval() for key in model]

    if not load_only_params:
        epoch = state["epoch"] + 1
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
        optimizer.load_scheduler_state_dict(state["scheduler"])

    else:
        epoch = 0
        iters = 0

    return model, optimizer, epoch, iters
468
+
469
+
470
def recursive_munch(d):
    """Recursively convert dicts (including those nested inside lists) into ``Munch`` objects.

    Values that are neither dict nor list are returned unchanged.
    """
    if isinstance(d, dict):
        return Munch((key, recursive_munch(value)) for key, value in d.items())
    if isinstance(d, list):
        return [recursive_munch(element) for element in d]
    return d
modules/diffusion_transformer.py ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import math
4
+
5
+ # from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer
6
+ from modules.wavenet import WN
7
+ from modules.commons import sequence_mask
8
+
9
+ from torch.nn.utils import weight_norm
10
+
11
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
12
+ # All rights reserved.
13
+
14
+ # This source code is licensed under the license found in the
15
+ # LICENSE file in the root directory of this source tree.
16
+ from dataclasses import dataclass
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch import Tensor
22
+ from torch.nn import functional as F
23
+
24
+
25
def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the nearest multiple of ``k`` (``n`` itself if already one)."""
    remainder = n % k
    return n if remainder == 0 else n + k - remainder
29
+
30
class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization

    Wraps a normalisation module. When a conditioning embedding is supplied
    it is projected to a per-feature scale and shift applied to the
    normalised input; otherwise the wrapped norm is used on its own.
    """

    def __init__(self, d_model, norm) -> None:
        super().__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        normalised = self.norm(input)
        if embedding is None:
            return normalised
        # The projection is 2 * d_model wide: the first half is the scale,
        # the second half the shift.
        projected = self.project_layer(embedding)
        scale, shift = torch.split(
            projected,
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        return scale * normalised + shift
49
+
50
+
51
@dataclass
class ModelArgs:
    """Hyper-parameters for the transformer defined in this module.

    Two fields are derived in ``__post_init__`` when left at their sentinel
    values: ``n_local_heads`` (-1 means "same as ``n_head``") and
    ``intermediate_size`` (None means 2/3 of ``4 * dim``, rounded up to a
    multiple of 256).
    """

    block_size: int = 2048
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    has_cross_attention: bool = False
    context_dim: int = 0
    uvit_skip_connection: bool = False
    time_as_token: bool = False

    def __post_init__(self):
        if self.intermediate_size is None:
            widened = 4 * self.dim
            two_thirds = int(2 * widened / 3)
            self.intermediate_size = find_multiple(two_thirds, 256)
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        # head_dim is deliberately kept as configured rather than being
        # derived from dim // n_head.
76
+
77
+ class Transformer(nn.Module):
78
+ def __init__(self, config: ModelArgs) -> None:
79
+ super().__init__()
80
+ self.config = config
81
+
82
+ self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
83
+ self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
84
+
85
+ self.freqs_cis: Optional[Tensor] = None
86
+ self.mask_cache: Optional[Tensor] = None
87
+ self.max_batch_size = -1
88
+ self.max_seq_length = -1
89
+
90
+ def setup_caches(self, max_batch_size, max_seq_length, use_kv_cache=False):
91
+ if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
92
+ return
93
+ head_dim = self.config.dim // self.config.n_head
94
+ max_seq_length = find_multiple(max_seq_length, 8)
95
+ self.max_seq_length = max_seq_length
96
+ self.max_batch_size = max_batch_size
97
+ dtype = self.norm.project_layer.weight.dtype
98
+ device = self.norm.project_layer.weight.device
99
+
100
+ self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
101
+ self.config.rope_base, dtype).to(device)
102
+ self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device)
103
+ self.use_kv_cache = use_kv_cache
104
+ self.uvit_skip_connection = self.config.uvit_skip_connection
105
+ if self.uvit_skip_connection:
106
+ self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2]
107
+ self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2]
108
+ else:
109
+ self.layers_emit_skip = []
110
+ self.layers_receive_skip = []
111
+
112
def forward(self,
            x: Tensor,
            c: Tensor,
            input_pos: Optional[Tensor] = None,
            mask: Optional[Tensor] = None,
            context: Optional[Tensor] = None,
            context_input_pos: Optional[Tensor] = None,
            cross_attention_mask: Optional[Tensor] = None,
            ) -> Tensor:
    """Run every transformer layer and the final conditioned norm.

    When ``mask`` is omitted it is sliced out of the cached causal mask:
    rows are always gathered at ``input_pos``; columns are gathered at
    ``input_pos`` as well unless the module is in eval mode with
    ``use_kv_cache`` enabled.

    With U-ViT skips enabled, outputs of layers in ``layers_emit_skip`` are
    pushed on a stack and popped (last in, first out) as ``skip_in_x`` for
    layers in ``layers_receive_skip``.

    Returns:
        ``self.norm(x, c)`` applied to the output of the last layer.
    """
    assert self.freqs_cis is not None, "Caches must be initialized first"

    if mask is None:  # in case of non-causal model
        mask = self.causal_mask[None, None, input_pos]
        kv_cached_inference = (not self.training) and self.use_kv_cache
        if not kv_cached_inference:
            mask = mask[..., input_pos]

    rope = self.freqs_cis[input_pos]
    context_rope = None if context is None else self.freqs_cis[context_input_pos]

    pending_skips = []
    for idx, block in enumerate(self.layers):
        if self.uvit_skip_connection and idx in self.layers_receive_skip:
            skip_in = pending_skips.pop()
        else:
            skip_in = None
        x = block(x, c, input_pos, rope, mask, context, context_rope, cross_attention_mask, skip_in)
        if self.uvit_skip_connection and idx in self.layers_emit_skip:
            pending_skips.append(x)

    return self.norm(x, c)
144
+
145
@classmethod
def from_name(cls, name: str):
    """Alternate constructor: build the model from ``ModelArgs.from_name(name)``."""
    model_args = ModelArgs.from_name(name)
    return cls(model_args)
148
+
149
+
150
class TransformerBlock(nn.Module):
    """One transformer layer: self-attention, optional cross-attention, feed-forward.

    Each sub-layer is preceded by an ``AdaptiveLayerNorm`` wrapping an
    ``RMSNorm`` and is added back residually. When the config enables U-ViT
    skips, an incoming skip tensor is concatenated with the input on the
    last axis and projected back to ``config.dim``.
    """

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()

        def make_norm():
            return AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))

        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = make_norm()
        self.attention_norm = make_norm()

        self.has_cross_attention = bool(config.has_cross_attention)
        if self.has_cross_attention:
            self.cross_attention = Attention(config, is_cross_attention=True)
            self.cross_attention_norm = make_norm()

        self.uvit_skip_connection = bool(config.uvit_skip_connection)
        if self.uvit_skip_connection:
            self.skip_in_linear = nn.Linear(config.dim * 2, config.dim)

        self.time_as_token = config.time_as_token

    def forward(self,
                x: Tensor,
                c: Tensor,
                input_pos: Tensor,
                freqs_cis: Tensor,
                mask: Tensor,
                context: Optional[Tensor] = None,
                context_freqs_cis: Optional[Tensor] = None,
                cross_attention_mask: Optional[Tensor] = None,
                skip_in_x: Optional[Tensor] = None,
                ) -> Tensor:
        """Apply the layer to ``x`` and return a tensor produced by the residual sums.

        The conditioning ``c`` is replaced by ``None`` for every norm when
        ``time_as_token`` is set.
        """
        cond = None if self.time_as_token else c

        if self.uvit_skip_connection and skip_in_x is not None:
            merged = torch.cat([x, skip_in_x], dim=-1)
            x = self.skip_in_linear(merged)

        attn_out = self.attention(self.attention_norm(x, cond), freqs_cis, mask, input_pos)
        h = x + attn_out

        if self.has_cross_attention:
            cross_out = self.cross_attention(
                self.cross_attention_norm(h, cond),
                freqs_cis,
                cross_attention_mask,
                input_pos,
                context,
                context_freqs_cis,
            )
            h = h + cross_out

        return h + self.feed_forward(self.ffn_norm(h, cond))
192
+
193
+
194
class Attention(nn.Module):
    """Multi-head attention with rotary embeddings and optional grouped KV heads.

    In self-attention mode a single fused ``wqkv`` projection produces
    queries (``n_head`` heads) followed by keys and values
    (``n_local_heads`` heads each). In cross-attention mode queries come
    from ``wq`` applied to ``x`` and keys/values from ``wkv`` applied to
    ``context``.
    """

    def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        if is_cross_attention:
            self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
            self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
        else:
            self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim

    def forward(self,
                x: Tensor,
                freqs_cis: Tensor,
                mask: Tensor,
                input_pos: Optional[Tensor] = None,
                context: Optional[Tensor] = None,
                context_freqs_cis: Optional[Tensor] = None,
                ) -> Tensor:
        """Attend from ``x`` to itself, or to ``context`` when it is given.

        Args:
            x: Query-side input; unpacked as ``(bsz, seqlen, _)``.
            freqs_cis: Rotary table applied to the queries (and to the keys
                when ``context_freqs_cis`` is ``None``).
            mask: Passed to ``scaled_dot_product_attention`` as ``attn_mask``.
            input_pos: Forwarded to ``kv_cache.update`` when a cache is set.
            context: Key/value source for cross-attention.
            context_freqs_cis: Rotary table for the keys in cross-attention.

        Returns:
            The output projection ``wo`` applied to the merged heads.
        """
        bsz, seqlen, _ = x.shape

        q_size = self.n_head * self.head_dim
        kv_size = self.n_local_heads * self.head_dim
        if context is None:
            # The fused projection is laid out as [q | k | v]; the query
            # chunk spans n_head heads, which differs from kv_size whenever
            # n_local_heads != n_head.
            q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
            context_seqlen = seqlen
        else:
            q = self.wq(x)
            k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
            context_seqlen = context.shape[1]

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)

        # (bsz, seq, heads, head_dim) -> (bsz, heads, seq, head_dim)
        q, k, v = (tensor.transpose(1, 2) for tensor in (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Repeat each KV head so keys/values line up with the query heads.
        group_size = self.n_head // self.n_local_heads
        k = k.repeat_interleave(group_size, dim=1)
        v = v.repeat_interleave(group_size, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)

        return self.wo(y)
261
+
262
+
263
class FeedForward(nn.Module):
    """Gated feed-forward block computing ``w2(silu(w1(x)) * w3(x))``.

    All three projections are bias-free linear layers; ``w1`` and ``w3`` map
    ``config.dim`` to ``config.intermediate_size`` and ``w2`` maps back.
    """

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        dim, hidden = config.dim, config.intermediate_size
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
272
+
273
+
274
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale.

    The normalisation is computed in float32 and cast back to the input
    dtype before being multiplied by ``weight`` (initialised to ones).
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), reduced over the last axis.
        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight
286
+
287
+
288
def precompute_freqs_cis(
        seq_len: int, n_elem: int, base: int = 10000,
        dtype: torch.dtype = torch.bfloat16
) -> Tensor:
    """Precompute the rotary-embedding table.

    Returns:
        A tensor of shape ``(seq_len, n_elem // 2, 2)`` in ``dtype`` whose
        last axis holds ``(cos, sin)`` of ``position * base**(-2i / n_elem)``.
    """
    exponents = torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    # Unit-magnitude complex numbers: real part is cos, imaginary part is sin.
    unit = torch.polar(torch.ones_like(angles), angles)
    table = torch.stack([unit.real, unit.imag], dim=-1)
    return table.to(dtype=dtype)
298
+
299
+
300
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive feature pairs of ``x`` by the angles in ``freqs_cis``.

    ``x`` is reshaped so its last axis splits into pairs; ``freqs_cis`` is
    viewed as ``(1, x.size(1), 1, pairs, 2)`` holding ``(cos, sin)`` and
    broadcast over the first and third axes. The computation runs in
    float32 and the result is cast back to ``x``'s dtype.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    table = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)

    first, second = pairs[..., 0], pairs[..., 1]
    cos, sin = table[..., 0], table[..., 1]

    rotated = torch.stack(
        [
            first * cos - second * sin,
            second * cos + first * sin,
        ],
        -1,
    )
    return rotated.flatten(3).type_as(x)
313
+
314
+
315
def modulate(x, shift, scale):
    """Return ``x * (1 + scale) + shift`` with ``shift``/``scale`` given a new axis 1 for broadcasting."""
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
317
+
318
+
319
+ #################################################################################
320
+ # Embedding Layers for Timesteps and Class Labels #
321
+ #################################################################################
322
+
323
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps: sinusoidal features followed by a two-layer MLP.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = 10000
        self.scale = 1000

        # Geometric frequency ladder from 1 down towards 1 / max_period.
        half = frequency_embedding_size // 2
        steps = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(self.max_period) * steps / half)
        self.register_buffer("freqs", freqs)

    def timestep_embedding(self, t):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :return: an (N, frequency_embedding_size) Tensor laid out as
                 [cos | sin], plus one zero column when the size is odd.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py

        angles = self.scale * t[:, None].float() * self.freqs[None]
        parts = [torch.cos(angles), torch.sin(angles)]
        if self.frequency_embedding_size % 2:
            parts.append(torch.zeros_like(parts[0][:, :1]))
        return torch.cat(parts, dim=-1)

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t))
365
+
366
+
367
class StyleEmbedder(nn.Module):
    """
    Projects style vectors to the hidden size through a weight-normalised linear layer.

    An embedding table with one row is allocated when ``dropout_prob > 0``
    (zero rows otherwise); ``forward`` in this class never reads it.
    """
    def __init__(self, input_size, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
        self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
        self.input_size = input_size
        self.dropout_prob = dropout_prob

    def forward(self, labels, train, force_drop_ids=None):
        drop_requested = force_drop_ids is not None
        if drop_requested or (train and self.dropout_prob > 0):
            # NOTE(review): ``token_drop`` is not defined in this class, so
            # this path raises AttributeError unless a subclass or mixin
            # provides it. Its result is also returned without going
            # through ``style_in`` -- confirm the intended behaviour.
            return self.token_drop(labels, force_drop_ids)
        return self.style_in(labels)
387
+
388
class FinalLayer(nn.Module):
    """
    The final layer of DiT: parameter-free LayerNorm, adaptive shift/scale, linear output.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        out_features = patch_size * patch_size * out_channels
        self.linear = weight_norm(nn.Linear(hidden_size, out_features, bias=True))
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # The conditioning yields one shift and one scale vector per item.
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        normed = self.norm_final(x)
        # Same arithmetic as modulate(): broadcast over axis 1 of x.
        modulated = normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return self.linear(modulated)
406
+
407
class DiT(torch.nn.Module):
    """Diffusion-transformer estimator built from an ``args`` config object.

    The network concatenates the noisy input, the prompt and the projected
    content condition (plus the style vector when it is not used as a
    token), merges them with a linear layer, runs the ``Transformer`` and
    finishes with either a WaveNet head or a small MLP, depending on
    ``args.DiT.final_layer_type``.
    """

    def __init__(
            self,
            args
    ):
        super().__init__()
        dit_cfg = args.DiT
        hidden_dim = dit_cfg.hidden_dim
        in_channels = dit_cfg.in_channels
        # Fixed transformer block size and length of the position buffer.
        max_positions = 16384

        # Optional switches: absent attributes default to False.
        self.time_as_token = getattr(dit_cfg, 'time_as_token', False)
        self.style_as_token = getattr(dit_cfg, 'style_as_token', False)
        self.uvit_skip_connection = getattr(dit_cfg, 'uvit_skip_connection', False)

        model_args = ModelArgs(
            block_size=max_positions,
            n_layer=dit_cfg.depth,
            n_head=dit_cfg.num_heads,
            dim=hidden_dim,
            head_dim=hidden_dim // dit_cfg.num_heads,
            vocab_size=1024,
            uvit_skip_connection=self.uvit_skip_connection,
            time_as_token=self.time_as_token,
        )
        self.transformer = Transformer(model_args)
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.num_heads = dit_cfg.num_heads

        self.x_embedder = weight_norm(nn.Linear(in_channels, hidden_dim, bias=True))

        self.content_type = dit_cfg.content_type  # 'discrete' or 'continuous'
        self.content_codebook_size = dit_cfg.content_codebook_size  # for discrete content
        self.content_dim = dit_cfg.content_dim  # for continuous content
        self.cond_embedder = nn.Embedding(dit_cfg.content_codebook_size, hidden_dim)  # discrete content
        self.cond_projection = nn.Linear(dit_cfg.content_dim, hidden_dim, bias=True)  # continuous content

        self.is_causal = dit_cfg.is_causal

        self.t_embedder = TimestepEmbedder(hidden_dim)

        self.register_buffer("input_pos", torch.arange(max_positions))

        self.final_layer_type = dit_cfg.final_layer_type  # mlp or wavenet
        if self.final_layer_type == 'wavenet':
            wn_cfg = args.wavenet
            self.t_embedder2 = TimestepEmbedder(wn_cfg.hidden_dim)
            self.conv1 = nn.Linear(hidden_dim, wn_cfg.hidden_dim)
            self.conv2 = nn.Conv1d(wn_cfg.hidden_dim, in_channels, 1)
            self.wavenet = WN(hidden_channels=wn_cfg.hidden_dim,
                              kernel_size=wn_cfg.kernel_size,
                              dilation_rate=wn_cfg.dilation_rate,
                              n_layers=wn_cfg.num_layers,
                              gin_channels=wn_cfg.hidden_dim,
                              p_dropout=wn_cfg.p_dropout,
                              causal=False)
            self.final_layer = FinalLayer(wn_cfg.hidden_dim, 1, wn_cfg.hidden_dim)
            # residual connection from transformer output to final output
            self.res_projection = nn.Linear(hidden_dim, wn_cfg.hidden_dim)
            self.wavenet_style_condition = wn_cfg.style_condition
            assert dit_cfg.style_condition == wn_cfg.style_condition
        else:
            self.final_mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, in_channels),
            )
        self.transformer_style_condition = dit_cfg.style_condition

        self.class_dropout_prob = dit_cfg.class_dropout_prob
        self.content_mask_embedder = nn.Embedding(1, hidden_dim)

        self.long_skip_connection = dit_cfg.long_skip_connection
        self.skip_linear = nn.Linear(hidden_dim + in_channels, hidden_dim)

        # Input width of the merge layer: projected content + noisy input +
        # prompt, plus the raw style vector only when style is concatenated
        # per frame rather than prepended as a token.
        merge_in_features = (hidden_dim + in_channels * 2 +
                             args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token))
        self.cond_x_merge_linear = nn.Linear(merge_in_features, hidden_dim)
        if self.style_as_token:
            self.style_in = nn.Linear(args.style_encoder.dim, hidden_dim)

    def setup_caches(self, max_batch_size, max_seq_length):
        """Initialise the transformer caches with the KV cache disabled."""
        self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)

    def forward(self, x, prompt_x, x_lens, t, style, cond, mask_content=False):
        """Run the estimator on ``x`` at time ``t``.

        ``x`` and ``prompt_x`` are transposed on axes 1 and 2 before being
        concatenated with the projected ``cond`` on the last axis, and the
        output is transposed back. Conditioning is zeroed ("class dropout")
        at random with probability ``class_dropout_prob`` during training,
        or whenever ``mask_content`` is truthy in eval mode.
        """
        if self.training:
            class_dropout = bool(torch.rand(1) < self.class_dropout_prob)
        else:
            class_dropout = bool(mask_content)

        # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
        cond_in_module = self.cond_projection

        _, _, T = x.size()

        t1 = self.t_embedder(t)  # (N, D)

        cond = cond_in_module(cond)

        x = x.transpose(1, 2)
        prompt_x = prompt_x.transpose(1, 2)

        x_in = torch.cat([x, prompt_x, cond], dim=-1)
        if self.transformer_style_condition and not self.style_as_token:
            per_frame_style = style[:, None, :].repeat(1, T, 1)
            x_in = torch.cat([x_in, per_frame_style], dim=-1)
        if class_dropout:
            # Keep the noisy input channels, zero every conditioning channel.
            x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0
        x_in = self.cond_x_merge_linear(x_in)  # (N, T, D)

        if self.style_as_token:
            style = self.style_in(style)
            if class_dropout:
                style = torch.zeros_like(style)
            x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
        if self.time_as_token:
            x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)

        # Lengths grow by one for each prepended token.
        x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)
        input_pos = self.input_pos[:x_in.size(1)]  # (T,)
        if self.is_causal:
            x_mask_expanded = None
        else:
            x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1)

        x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded)
        # Drop the prepended time/style tokens again.
        if self.time_as_token:
            x_res = x_res[:, 1:]
        if self.style_as_token:
            x_res = x_res[:, 1:]

        if self.long_skip_connection:
            x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))

        if self.final_layer_type == 'wavenet':
            x = self.conv1(x_res)
            x = x.transpose(1, 2)
            t2 = self.t_embedder2(t)
            wavenet_out = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2)
            x = wavenet_out + self.res_projection(x_res)  # long residual connection
            x = self.final_layer(x, t1).transpose(1, 2)
            x = self.conv2(x)
        else:
            x = self.final_mlp(x_res)
            x = x.transpose(1, 2)
        return x
modules/encodec.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Convolutional layers wrappers and utilities."""
8
+
9
+ import math
10
+ import typing as tp
11
+ import warnings
12
+
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from torch.nn.utils import spectral_norm, weight_norm
17
+
18
+ import typing as tp
19
+
20
+ import einops
21
+
22
+
23
class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves the last (time) axis next to the
    batch axis before running the normalization and moves it back right after,
    so that the remaining axes end up last for ``nn.LayerNorm``.
    """
    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        # 'b ... t -> b t ...'
        x = x.movedim(-1, 1)
        x = super().forward(x)
        # 'b t ... -> b ... t'; the result must be returned, otherwise every
        # caller receives None.
        return x.movedim(1, -1)
36
+
37
+
38
CONV_NORMALIZATIONS = frozenset((
    'none', 'weight_norm', 'spectral_norm',
    'time_layer_norm', 'layer_norm', 'time_group_norm',
))


def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    """Wrap ``module`` with weight or spectral normalisation when requested.

    Any other name in ``CONV_NORMALIZATIONS`` needs no reparametrisation and
    returns ``module`` unchanged. Unknown names fail the assertion.
    """
    assert norm in CONV_NORMALIZATIONS
    wrappers = {'weight_norm': weight_norm, 'spectral_norm': spectral_norm}
    wrap = wrappers.get(norm)
    if wrap is None:
        return module
    return wrap(module)
52
+
53
+
54
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Return the normalisation module to apply after ``module``.

    ``'layer_norm'`` yields a ``ConvLayerNorm`` and ``'time_group_norm'`` a
    single-group ``nn.GroupNorm``, both sized by ``module.out_channels``;
    every other accepted name yields ``nn.Identity``. Requesting
    ``'time_group_norm'`` with ``causal=True`` raises ``ValueError``.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm not in ('layer_norm', 'time_group_norm'):
        return nn.Identity()

    if norm == 'time_group_norm' and causal:
        raise ValueError("GroupNorm doesn't support causal evaluation.")

    assert isinstance(module, nn.modules.conv._ConvNd)
    if norm == 'layer_norm':
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
69
+
70
+
71
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Return how many samples to append so the last convolution window is full.

    See `pad_for_conv1d`.
    """
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    # Start offset of the last (rounded-up) frame, then its required extent.
    last_frame_start = (math.ceil(n_frames) - 1) * stride
    ideal_length = last_frame_start + (kernel_size - padding_total)
    return ideal_length - length
79
+
80
+
81
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Right-pad ``x`` so that the last convolution window is full.

    The extra samples go at the end. Without them a transposed convolution
    cannot rebuild an output of the original length, because some time
    steps are dropped even when padding is used.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step !
    """
    missing = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, missing))
94
+
95
+
96
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.

    Args:
        x: Tensor padded along its last axis.
        paddings: ``(left, right)`` amounts; both must be non-negative.
        mode: ``'zero'`` (the default) is treated as ``F.pad``'s
            ``'constant'`` mode; any other value is forwarded unchanged.
        value: Fill value for constant padding (0 by default).

    For ``'reflect'``, when the input is no longer than the largest padding,
    extra zeros are inserted on the right before reflecting and trimmed off
    again afterwards.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)

    if mode == 'zero':
        # F.pad has no 'zero' mode; the default call used to raise.
        mode = 'constant'

    if mode != 'reflect':
        return F.pad(x, paddings, mode, value)

    max_pad = max(padding_left, padding_right)
    extra_pad = 0
    if length <= max_pad:
        extra_pad = max_pad - length + 1
        x = F.pad(x, (0, extra_pad))
    padded = F.pad(x, paddings, mode, value)
    end = padded.shape[-1] - extra_pad
    return padded[..., :end]
114
+
115
+
116
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip ``(left, right)`` samples from the last axis of x. Only for 1d!

    A right padding of zero is handled correctly because the end index is
    computed explicitly instead of using a negative slice.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    total = x.shape[-1]
    assert (left + right) <= total
    return x[..., left: total - right]
123
+
124
+
125
class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Positional and extra keyword arguments are forwarded to ``nn.Conv1d``.
    ``norm_kwargs`` (``None`` means no extra options) is forwarded to
    ``get_norm_module``.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        # None instead of a shared mutable ``{}`` default.
        norm_options = {} if norm_kwargs is None else norm_kwargs
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_options)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x
140
+
141
+
142
class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Positional and extra keyword arguments are forwarded to ``nn.Conv2d``.
    ``norm_kwargs`` (``None`` means no extra options) is forwarded to
    ``get_norm_module``, always with ``causal=False``.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        # None instead of a shared mutable ``{}`` default.
        norm_options = {} if norm_kwargs is None else norm_kwargs
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_options)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x
157
+
158
+
159
class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Positional and extra keyword arguments are forwarded to
    ``nn.ConvTranspose1d``. ``norm_kwargs`` (``None`` means no extra
    options) is forwarded to ``get_norm_module``.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        # None instead of a shared mutable ``{}`` default.
        norm_options = {} if norm_kwargs is None else norm_kwargs
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_options)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x
174
+
175
+
176
class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Positional and extra keyword arguments are forwarded to
    ``nn.ConvTranspose2d``. ``norm_kwargs`` (``None`` means no extra
    options) is forwarded to ``get_norm_module``, always with
    ``causal=False``.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        # None instead of a shared mutable ``{}`` default.
        norm_options = {} if norm_kwargs is None else norm_kwargs
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_options)
        # Recorded for parity with the other Norm* wrappers in this file.
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x
190
+
191
+
192
class SConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.

    The input is padded so that the output has ``ceil(T / stride)`` frames:
    entirely on the left (plus any extra right padding needed to complete
    the last window) when ``causal`` is set, split between both sides
    otherwise. Extra ``**kwargs`` are accepted but not used.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
                 pad_mode: str = 'reflect', **kwargs):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
        # None instead of a shared mutable ``{}`` default; always hand a
        # real dict to the wrapped layer.
        norm_options = {} if norm_kwargs is None else norm_kwargs
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_options)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        # Unpacking doubles as a check that the input is 3-D (B, C, T).
        _batch, _channels, _frames = x.shape
        conv = self.conv.conv
        kernel_size = conv.kernel_size[0]
        stride = conv.stride[0]
        dilation = conv.dilation[0]
        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
        return self.conv(x)
229
+
230
+
231
class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.

    After the transposed convolution, ``kernel_size - stride`` samples of
    fixed padding are trimmed: according to ``trim_right_ratio`` in causal
    mode, split between both sides otherwise. Extra ``**kwargs`` are
    accepted but not used.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        # None instead of a shared mutable ``{}`` default; always hand a
        # real dict to the wrapped layer.
        norm_options = {} if norm_kwargs is None else norm_kwargs
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_options)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
        padding_left = padding_total - padding_right
        return unpad1d(y, (padding_left, padding_right))
271
+
272
class SLSTM(nn.Module):
    """
    LSTM that takes and returns the convolutional layout.

    The input is permuted with ``(2, 0, 1)`` before the LSTM and the output
    with ``(1, 2, 0)`` afterwards, with an optional residual connection
    around the LSTM.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)
        self.hidden = None

    def forward(self, x):
        seq_first = x.permute(2, 0, 1)
        if self.training:
            out, _ = self.lstm(seq_first)
        else:
            # NOTE(review): in eval mode the LSTM state is stored on the
            # module and fed into the next call, and nothing in this class
            # resets it -- presumably intended for chunked/streaming
            # inference; confirm callers clear ``hidden`` between
            # unrelated inputs.
            out, self.hidden = self.lstm(seq_first, self.hidden)
        if self.skip:
            out = out + seq_first
        return out.permute(1, 2, 0)
modules/flow_matching.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from modules.diffusion_transformer import DiT
7
+ from modules.commons import sequence_mask
8
+
9
+ from tqdm import tqdm
10
+
11
class BASECFM(torch.nn.Module, ABC):
    """Conditional flow-matching wrapper around an ``estimator`` network.

    Subclasses assign ``self.estimator``. ``forward`` computes the training
    loss; ``inference`` integrates the learned ODE with fixed Euler steps,
    optionally using classifier-free guidance.
    """

    def __init__(
            self,
            args,
    ):
        super().__init__()
        self.sigma_min = 1e-6

        self.estimator = None

        self.in_channels = args.DiT.in_channels

        self.criterion = torch.nn.MSELoss() if args.reg_loss_type == "l2" else torch.nn.L1Loss()

        # Optional flag: defaults to False when the config does not set it.
        self.zero_prompt_speech_token = getattr(args.DiT, 'zero_prompt_speech_token', False)

    @torch.inference_mode()
    def inference(self, mu, x_lens, prompt, style, f0, n_timesteps, temperature=1.0, inference_cfg_rate=0.5):
        """Sample by integrating the flow from noise over ``n_timesteps`` Euler steps.

        Args:
            mu (torch.Tensor): content condition; axis 0 is the batch and
                axis 1 the number of frames (the noise is drawn as
                ``(batch, in_channels, mu.size(1))``).
            x_lens: lengths forwarded to the estimator.
            prompt (torch.Tensor): prompt features; its last axis is time.
            style (torch.Tensor): style vector forwarded to the estimator.
            f0: accepted for interface compatibility; not used by the solver.
            n_timesteps (int): number of Euler steps.
            temperature (float, optional): scale of the initial noise. Defaults to 1.0.
            inference_cfg_rate (float, optional): classifier-free guidance
                strength; values <= 0 disable guidance.

        Returns:
            sample: generated features
                shape: (batch_size, in_channels, frames)
        """
        B, T = mu.size(0), mu.size(1)
        z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate)

    def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5):
        """
        Fixed euler solver for ODEs.

        ``x`` is modified in place (its prompt region is zeroed), and ``mu``
        is modified in place when ``zero_prompt_speech_token`` is set.

        Args:
            x (torch.Tensor): random noise
                shape: (batch_size, in_channels, frames)
            x_lens: lengths forwarded to the estimator.
            prompt (torch.Tensor): prompt features copied into the first
                ``prompt.size(-1)`` frames of the estimator's prompt input.
            mu (torch.Tensor): content condition forwarded to the estimator.
            style (torch.Tensor): style vector forwarded to the estimator.
            f0: unused.
            t_span (torch.Tensor): integration time points
                shape: (n_timesteps + 1,)
            inference_cfg_rate (float): classifier-free guidance strength.

        Returns:
            The state after the last step. With a single time point no step
            is taken and the prompt-masked input is returned.
        """
        t = t_span[0]

        # apply prompt
        prompt_len = prompt.size(-1)
        prompt_x = torch.zeros_like(x)
        prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
        x[..., :prompt_len] = 0
        if self.zero_prompt_speech_token:
            # NOTE(review): this slices the LAST axis of mu, while inference()
            # treats axis 1 of mu as time -- confirm which axis is meant
            # before changing it (trained checkpoints may rely on it).
            mu[..., :prompt_len] = 0

        use_cfg = inference_cfg_rate > 0
        if use_cfg:
            # Conditioned and null (all-zero) inputs stacked on the batch
            # axis; these do not change between steps.
            stacked_prompt_x = torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0)
            stacked_style = torch.cat([style, torch.zeros_like(style)], dim=0)
            stacked_mu = torch.cat([mu, torch.zeros_like(mu)], dim=0)

        for step in tqdm(range(1, len(t_span))):
            dt = t_span[step] - t_span[step - 1]
            if use_cfg:
                stacked_x = torch.cat([x, x], dim=0)
                stacked_t = torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0)

                # Perform a single forward pass for both original and CFG inputs
                stacked_dphi_dt = self.estimator(
                    stacked_x, stacked_prompt_x, x_lens, stacked_t, stacked_style, stacked_mu,
                )

                # Split the output back into the original and CFG components
                dphi_dt, cfg_dphi_dt = stacked_dphi_dt.chunk(2, dim=0)

                # Apply CFG formula
                dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt
            else:
                dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu)

            x = x + dt * dphi_dt
            t = t + dt
            # The prompt region is never generated; keep it at zero.
            x[:, :, :prompt_len] = 0

        return x

    def forward(self, x1, x_lens, prompt_lens, mu, style):
        """Computes diffusion loss

        ``mu`` is modified in place when ``zero_prompt_speech_token`` is set.

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            x_lens: per-item target lengths; the loss stops at this frame.
            prompt_lens: per-item prompt lengths; the loss starts at this frame.
            mu (torch.Tensor): content condition forwarded to the estimator.
            style (torch.Tensor): style vector forwarded to the estimator.

        Returns:
            loss: conditional flow matching loss, averaged over the batch
            y: ``estimator_out + (1 - sigma_min) * z``
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, _ = x1.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        prompt = torch.zeros_like(x1)
        for bib in range(b):
            prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]]
            # range covered by prompt are set to 0
            y[bib, :, :prompt_lens[bib]] = 0
            if self.zero_prompt_speech_token:
                mu[bib, :, :prompt_lens[bib]] = 0

        # NOTE(review): prompt_lens is passed as a 7th positional argument;
        # in DiT.forward that slot is ``mask_content`` -- confirm this is
        # intended, especially for evaluation-mode loss computation.
        estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(1).squeeze(1), style, mu, prompt_lens)
        loss = 0
        for bib in range(b):
            # Only the non-prompt, non-padded frames contribute to the loss.
            start, stop = prompt_lens[bib], x_lens[bib]
            loss += self.criterion(estimator_out[bib, :, start:stop], u[bib, :, start:stop])
        loss /= b

        return loss, estimator_out + (1 - self.sigma_min) * z
156
+
157
+
158
+
159
class CFM(BASECFM):
    """Flow-matching model whose estimator is chosen by ``args.dit_type``.

    Only ``"DiT"`` is supported; any other value raises
    ``NotImplementedError``.
    """

    def __init__(self, args):
        super().__init__(args)
        if args.dit_type != "DiT":
            raise NotImplementedError(f"Unknown diffusion type {args.dit_type}")
        self.estimator = DiT(args)
modules/hifigan/__init__.py ADDED
File without changes
modules/hifigan/f0_predictor.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn.utils import weight_norm
17
+
18
+
19
class ConvRNNF0Predictor(nn.Module):
    """Predict a non-negative value per frame from a channel-first feature map.

    The network is five weight-normalised ``Conv1d`` layers (kernel 3,
    padding 1), each followed by ``ELU``, and a final linear classifier.
    """

    def __init__(self,
                 num_class: int = 1,
                 in_channels: int = 80,
                 cond_channels: int = 512
                 ):
        super().__init__()

        self.num_class = num_class

        # First conv maps in_channels -> cond_channels; the remaining four
        # keep the width. Layers are created in order so that the Sequential
        # indices (0, 2, 4, 6, 8 for convs) stay the same.
        widths = [in_channels] + [cond_channels] * 5
        stack = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stack.append(weight_norm(nn.Conv1d(n_in, n_out, kernel_size=3, padding=1)))
            stack.append(nn.ELU())
        self.condnet = nn.Sequential(*stack)
        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the conv stack on ``x``, classify each frame, return absolute values.

        ``x`` is fed to ``Conv1d`` directly, so it is channel-first; the
        trailing class dimension is squeezed when it has size 1.
        """
        hidden = self.condnet(x).transpose(1, 2)
        scores = self.classifier(hidden).squeeze(-1)
        return torch.abs(scores)
modules/hifigan/generator.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+
17
+ import typing as tp
18
+ import numpy as np
19
+ from scipy.signal import get_window
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch.nn import Conv1d
24
+ from torch.nn import ConvTranspose1d
25
+ from torch.nn.utils import remove_weight_norm
26
+ from torch.nn.utils import weight_norm
27
+ from torch.distributions.uniform import Uniform
28
+
29
+ from torch import sin
30
+ from torch.nn.parameter import Parameter
31
+
32
+
33
+ """hifigan based generator implementation.
34
+
35
+ This code is modified from https://github.com/jik876/hifi-gan
36
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
37
+ https://github.com/NVIDIA/BigVGAN
38
+
39
+ """
40
class Snake(nn.Module):
    '''
    Sine-based periodic activation: x + (1/alpha) * sin^2(alpha * x).
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - per-channel parameter (trainable unless disabled)
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(1, 256, 10)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: number of channels (length of the alpha vector)
            - alpha: scale applied to the initial alpha vector
            - alpha_trainable: value assigned to alpha.requires_grad
            - alpha_logscale: when True alpha is stored in log space and
              starts from zeros; otherwise it starts from ones
        '''
        super(Snake, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # Log-scale alphas start at zero, linear-scale alphas at one.
        if alpha_logscale:
            seed = torch.zeros(in_features)
        else:
            seed = torch.ones(in_features)
        self.alpha = Parameter(seed * alpha)
        self.alpha.requires_grad = alpha_trainable

        # Small constant added to the denominator in forward().
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Apply the activation elementwise, broadcasting alpha over batch and time.
        '''
        # Reshape (C,) -> (1, C, 1) so it lines up with x as [B, C, T].
        alpha = self.alpha[None, :, None]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        inv_alpha = 1.0 / (alpha + self.no_div_by_zero)
        return x + inv_alpha * torch.sin(x * alpha).pow(2)
91
+
92
def get_padding(kernel_size, dilation=1):
    """Return ``int((kernel_size * dilation - dilation) / 2)``.

    Used in this file as the ``padding`` argument of dilated, stride-1
    convolutions.
    """
    span = kernel_size * dilation - dilation
    return int(span / 2)
94
+
95
+
96
def init_weights(m, mean=0.0, std=0.01):
    """Fill ``m.weight`` in place with N(mean, std) when ``m``'s class name contains "Conv".

    Modules whose class name does not contain "Conv" are left untouched, so
    the function can be passed to ``Module.apply``.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
100
+
101
+
102
+
103
class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN.

    For every entry of ``dilations`` the block holds a dilated conv, a
    non-dilated conv and two Snake activations; each such group is applied
    as a residual branch in ``forward``.
    """
    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: tp.List[int] = [1, 3, 5],
    ):
        super(ResBlock, self).__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        def _normed_conv(rate):
            # Stride-1 conv padded via get_padding for this kernel/dilation.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )

        # Convs are created pairwise (dilated, then plain) per dilation.
        for rate in dilations:
            self.convs1.append(_normed_conv(rate))
            self.convs2.append(_normed_conv(1))

        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)

        self.activations1 = nn.ModuleList(
            [Snake(channels, alpha_logscale=False) for _ in self.convs1]
        )
        self.activations2 = nn.ModuleList(
            [Snake(channels, alpha_logscale=False) for _ in self.convs2]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each activation/conv/activation/conv branch and add it back to ``x``."""
        stages = zip(self.activations1, self.convs1, self.activations2, self.convs2)
        for act_in, conv_dilated, act_mid, conv_plain in stages:
            branch = conv_dilated(act_in(x))
            branch = conv_plain(act_mid(branch))
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every conv in the block."""
        for conv_dilated, conv_plain in zip(self.convs1, self.convs2):
            # Resolves to the module-level torch helper, not this method.
            remove_weight_norm(conv_dilated)
            remove_weight_norm(conv_plain)
164
+
165
class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sampling_rate = samp_rate
        self.harmonic_num = harmonic_num
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        """Return a float32 mask that is 1 where ``f0`` exceeds the voiced threshold."""
        voiced = f0 > self.voiced_threshold
        return voiced.type(torch.float32)

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: tuple (sine_waves, uv, noise); sine_waves and noise are
                 [B, harmonic_num + 1, sample_len], uv is [B, 1, sample_len]
        """
        n_bands = self.harmonic_num + 1
        batch, length = f0.size(0), f0.size(-1)

        # Per-sample normalised frequency of the fundamental and each overtone.
        freq_ratio = torch.zeros((batch, n_bands, length)).to(f0.device)
        for band in range(n_bands):
            freq_ratio[:, band: band + 1, :] = f0 * (band + 1) / self.sampling_rate

        # Integrate frequency into phase (wrapped to one cycle), then add a
        # random initial phase per overtone; the fundamental keeps phase 0.
        phase_acc = 2 * np.pi * (torch.cumsum(freq_ratio, dim=-1) % 1)
        phase_dist = Uniform(low=-np.pi, high=np.pi)
        phase_offset = phase_dist.sample(sample_shape=(batch, n_bands, 1)).to(freq_ratio.device)
        phase_offset[:, 0, :] = 0

        # generate sine waveforms
        sine_waves = self.sine_amp * torch.sin(phase_acc + phase_offset)

        # generate uv signal
        uv = self._f02uv(f0)

        # Noise std is noise_std in voiced regions and sine_amp / 3 in
        # unvoiced regions.
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # Zero the sines where unvoiced, then add the noise everywhere.
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
228
+
229
+
230
class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    upsample_scale: accepted for interface compatibility; not used by this class
    harmonic_num: number of harmonic above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshod: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source, uv = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # Sine generator producing the fundamental plus overtones.
        self.l_sin_gen = SineGen(
            samp_rate=sampling_rate,
            harmonic_num=harmonic_num,
            sine_amp=sine_amp,
            noise_std=add_noise_std,
            voiced_threshold=voiced_threshod,
        )

        # Linear + tanh collapse the harmonics into one excitation channel.
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source, uv = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # Harmonic branch: the generator works channel-first, so transpose
        # on the way in and back on the way out.
        with torch.no_grad():
            harmonics, voiced, _ = self.l_sin_gen(x.transpose(1, 2))
            harmonics = harmonics.transpose(1, 2)
            voiced = voiced.transpose(1, 2)
        merged = self.l_tanh(self.l_linear(harmonics))

        # Noise branch, same shape as the voiced mask.
        noise = torch.randn_like(voiced) * self.sine_amp / 3
        return merged, noise, voiced
280
+
281
+
282
class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493

    The generator upsamples the input features with transposed convolutions,
    fuses in an STFT of a sine-based source signal derived from F0 at every
    upsampling stage, and converts the final magnitude/phase prediction to a
    waveform with an inverse STFT.
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        nb_harmonics: int = 8,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: tp.List[int] = [8, 8],
        upsample_kernel_sizes: tp.List[int] = [16, 16],
        istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        source_resblock_kernel_sizes: tp.List[int] = [7, 11],
        source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
        lrelu_slope: float = 0.1,
        audio_limit: float = 0.99,
        f0_predictor: torch.nn.Module = None,
    ):
        """Build the source module, upsampling path, source-fusion path and output conv.

        Args:
            in_channels: channel count of the input features.
            base_channels: width after ``conv_pre``; halved at each upsampling stage.
            nb_harmonics: number of overtones passed to the source module.
            sampling_rate: sampling rate passed to the source module.
            nsf_alpha: sine amplitude passed to the source module.
            nsf_sigma: additive noise std passed to the source module.
            nsf_voiced_threshold: voiced/unvoiced F0 threshold for the source module.
            upsample_rates: stride of each transposed convolution.
            upsample_kernel_sizes: kernel size of each transposed convolution.
            istft_params: dict with "n_fft" and "hop_len" used by the STFT/iSTFT.
            resblock_kernel_sizes / resblock_dilation_sizes: one ResBlock per
                (kernel, dilations) pair is created for every upsampling stage.
            source_resblock_kernel_sizes / source_resblock_dilation_sizes:
                ResBlock settings for the source-fusion path, one per stage.
            lrelu_slope: negative slope used before each upsampling layer.
            audio_limit: output waveform is clamped to [-audio_limit, audio_limit].
            f0_predictor: module called on the input features when ``forward``
                is given no F0.
        """
        super(HiFTGenerator, self).__init__()

        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold)
        # F0 is upsampled by the total upsampling factor times the iSTFT hop.
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Down: project the source STFT to each stage's width and time
        # resolution. NOTE: these convs are created WITHOUT weight_norm.
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
                                          source_resblock_dilation_sizes)):
            if u == 1:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
                )

            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d)
            )

        # num_kernels ResBlocks per upsampling stage, stored flat.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d))

        # Output has n_fft + 2 channels: n_fft // 2 + 1 magnitude bins
        # followed by n_fft // 2 + 1 phase bins (see forward()).
        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        # Kept as a plain tensor attribute (not a buffer); moved to the
        # right device on each use.
        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.f0_predictor = f0_predictor

    def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
        """Upsample ``f0`` and return the harmonic source signal, channel-first."""
        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t

        har_source, _, _ = self.m_source(f0)
        return har_source.transpose(1, 2)

    def _stft(self, x):
        """Return the real and imaginary parts of the STFT of ``x``."""
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        """Inverse STFT from magnitude and phase; magnitude is clipped at 1e2 first."""
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
        return inverse_transform

    def forward(self, x: torch.Tensor, f0=None) -> torch.Tensor:
        """Synthesise a waveform from features ``x``.

        Args:
            x: input features, fed channel-first into ``conv_pre``.
            f0: optional F0 track; when None it is predicted from ``x`` by
                ``self.f0_predictor``.

        Returns:
            Waveform clamped to [-audio_limit, audio_limit].
        """
        if f0 is None:
            f0 = self.f0_predictor(x)
        s = self._f02source(f0)

        # Real and imaginary parts of the source STFT stacked on the channel axis.
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            # fusion
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            # Average the outputs of this stage's parallel ResBlocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # No slope is passed here, so the function's default is used rather
        # than self.lrelu_slope.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundancy

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every layer that was wrapped with it.

        Only ``ups``, ``conv_pre``, ``conv_post`` and the ResBlocks (main and
        source path) carry weight norm. ``source_downs`` are plain ``Conv1d``
        layers and ``m_source`` contains no weight-normalised layer, so they
        are deliberately left alone: the previous implementation referenced a
        non-existent ``self.source_module`` attribute and tried to remove
        weight norm from ``source_downs``, which made this method always raise.
        """
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        for l in self.source_resblocks:
            l.remove_weight_norm()

    @torch.inference_mode()
    def inference(self, mel: torch.Tensor, f0=None) -> torch.Tensor:
        """Run ``forward`` under ``torch.inference_mode``."""
        return self.forward(x=mel, f0=f0)
modules/length_regulator.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from modules.commons import sequence_mask
6
+ import numpy as np
7
+ from dac.nn.quantize import VectorQuantize
8
+
9
+ # f0_bin = 256
10
# F0 range (Hz) and its mel-scale equivalents used by f0_to_coarse.
f0_max = 1100.0
f0_min = 50.0
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)


def f0_to_coarse(f0, f0_bin):
    """Quantise an F0 tensor (Hz) into integer bins on a mel scale.

    Positive mel values are mapped linearly so that ``f0_min`` lands on bin 1
    and ``f0_max`` on bin ``f0_bin - 1``; non-positive values end up in bin 1.

    NOTE(review): values that round to ``f0_bin`` or above are mapped to 0,
    not to ``f0_bin - 1`` - the final step never fires for ``f0_bin > 0``
    because the step before it has already zeroed those entries. This is kept
    as-is; confirm whether it is intended before changing it.
    """
    mel = 1127 * torch.log(1 + f0 / 700)
    scale = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
    offset = f0_mel_min * scale - 1.
    mel = torch.where(mel > 0, mel * scale - offset, mel)

    bins = torch.round(mel).long()
    # Everything below 1 (negative or zero) becomes bin 1.
    bins = bins.clamp(min=1)
    # Entries at or above f0_bin are zeroed.
    bins = bins.masked_fill(bins >= f0_bin, 0)
    bins = bins + ((bins >= f0_bin) * (f0_bin - 1))
    return bins
27
+
28
class InterpolateRegulator(nn.Module):
    """Embed or project content features and resample them to a target length.

    Discrete inputs are looked up in one or more embedding tables; continuous
    inputs go through a linear projection. The result is (optionally)
    nearest-neighbour interpolated to ``ylens.max()`` frames, optionally
    combined with an F0 embedding, passed through a small conv stack and
    masked by the per-sample lengths.
    """

    def __init__(
            self,
            channels: int,
            sampling_ratios: Tuple,
            is_discrete: bool = False,
            in_channels: int = None,  # only applies to continuous input
            vector_quantize: bool = False,  # whether to use vector quantization, only applies to continuous input
            codebook_size: int = 1024,  # for discrete only
            out_channels: int = None,
            groups: int = 1,
            n_codebooks: int = 1,  # number of codebooks
            quantizer_dropout: float = 0.0,  # dropout for quantizer
            f0_condition: bool = False,
            n_f0_bins: int = 512,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels

        # One conv/norm/activation triple per sampling ratio, followed by a
        # 1x1 output conv. Interpolation is enabled only when at least one
        # ratio is given.
        self.interpolate = len(sampling_ratios) > 0
        stages = []
        for _ in sampling_ratios:
            stages.append(nn.Conv1d(channels, channels, 3, 1, 1))
            stages.append(nn.GroupNorm(groups, channels))
            stages.append(nn.Mish())
        stages.append(nn.Conv1d(channels, out_channels, 1, 1))
        self.model = nn.Sequential(*stages)

        self.embedding = nn.Embedding(codebook_size, channels)
        self.is_discrete = is_discrete

        self.mask_token = nn.Parameter(torch.zeros(1, channels))

        self.n_codebooks = n_codebooks
        if n_codebooks > 1:
            extra = n_codebooks - 1
            self.extra_codebooks = nn.ModuleList(
                [nn.Embedding(codebook_size, channels) for _ in range(extra)]
            )
            self.extra_codebook_mask_tokens = nn.ParameterList(
                [nn.Parameter(torch.zeros(1, channels)) for _ in range(extra)]
            )
        self.quantizer_dropout = quantizer_dropout

        if f0_condition:
            self.f0_embedding = nn.Embedding(n_f0_bins, channels)
            self.f0_condition = f0_condition
            self.n_f0_bins = n_f0_bins
            self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
            self.f0_mask = nn.Parameter(torch.zeros(1, channels))
        else:
            self.f0_condition = False

        if not is_discrete:
            self.content_in_proj = nn.Linear(in_channels, channels)
            if vector_quantize:
                self.vq = VectorQuantize(channels, codebook_size, 8)

    def forward(self, x, ylens=None, n_quantizers=None, f0=None):
        """Regulate ``x`` to the lengths in ``ylens``.

        Returns a 5-tuple ``(features * mask, lengths, codes, commitment_loss,
        codebook_loss)``; the last three entries are ``None`` unless the
        module was built with a vector quantizer.
        """
        batch = x.shape[0]

        # Decide, per sample, how many codebooks are used.
        if self.training:
            # The first `n_dropout` samples get a random codebook count.
            n_quantizers = torch.ones((batch,)) * self.n_codebooks
            dropout = torch.randint(1, self.n_codebooks + 1, (batch,))
            n_dropout = int(batch * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(x.device)
        else:
            active = self.n_codebooks if n_quantizers is None else n_quantizers
            n_quantizers = torch.ones((batch,), device=x.device) * active

        if not self.is_discrete:
            x = self.content_in_proj(x)
        elif self.n_codebooks > 1:
            assert len(x.size()) == 3
            summed = self.embedding(x[:, 0])
            for level, table in enumerate(self.extra_codebooks, start=1):
                # Codebook `level` contributes only where enough quantizers are active.
                keep = (n_quantizers > level)[..., None, None]
                summed = summed + keep * table(x[:, level])
            x = summed
        elif self.n_codebooks == 1:
            x = self.embedding(x) if len(x.size()) == 2 else self.embedding(x[:, 0])

        # x in (B, T, D)
        mask = sequence_mask(ylens).unsqueeze(-1)
        if self.interpolate:
            x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
        else:
            # No resampling: trim the mask and lengths to the actual frame count.
            x = x.transpose(1, 2).contiguous()
            mask = mask[:, :x.size(2), :]
            ylens = ylens.clamp(max=x.size(2)).long()

        if self.f0_condition:
            if f0 is None:
                x = x + self.f0_mask.unsqueeze(-1)
            else:
                coarse_f0 = f0_to_coarse(f0, self.n_f0_bins)
                coarse_f0 = coarse_f0.clamp(0, self.n_f0_bins - 1).long()
                f0_emb = self.f0_embedding(coarse_f0)
                f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
                x = x + f0_emb

        out = self.model(x).transpose(1, 2).contiguous()
        if not hasattr(self, 'vq'):
            return out * mask, ylens, None, None, None

        out_q, commitment_loss, codebook_loss, codes, out = self.vq(out.transpose(1, 2))
        out_q = out_q.transpose(1, 2)
        return out_q * mask, ylens, codes, commitment_loss, codebook_loss
modules/rmvpe.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+ import os
3
+ from typing import List, Optional, Tuple
4
+ import numpy as np
5
+ import torch
6
+
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from librosa.util import normalize, pad_center, tiny
10
+ from scipy.signal import get_window
11
+
12
+ import logging
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class STFT(torch.nn.Module):
    """STFT / inverse STFT implemented with fixed Fourier basis matrices."""

    def __init__(
        self, filter_length=1024, hop_length=512, win_length=None, window="hann"
    ):
        """
        Build the windowed forward and inverse Fourier bases and register them as buffers.

        Keyword Arguments:
            filter_length {int} -- Length of filters used (default: {1024})
            hop_length {int} -- Hop length of STFT (default: {512})
            win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
                equals the filter length). (default: {None})
            window {str} -- Window type handed to scipy's get_window
                (default: {'hann'})
        """
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length or filter_length
        self.window = window
        self.forward_transform = None
        self.pad_amount = int(self.filter_length / 2)

        # Keep the non-redundant half of the DFT matrix and stack its real
        # part on top of its imaginary part.
        n_bins = int((self.filter_length / 2 + 1))
        dft = np.fft.fft(np.eye(self.filter_length))
        stacked = np.vstack([np.real(dft[:n_bins, :]), np.imag(dft[:n_bins, :])])
        forward_basis = torch.FloatTensor(stacked)
        inverse_basis = torch.FloatTensor(np.linalg.pinv(stacked))

        assert filter_length >= self.win_length
        # get window and zero center pad it to filter_length
        win = get_window(window, self.win_length, fftbins=True)
        win = pad_center(win, size=filter_length)
        fft_window = torch.from_numpy(win).float()

        # window the bases
        forward_basis *= fft_window
        inverse_basis = (inverse_basis.T * fft_window).T

        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())
        self.register_buffer("fft_window", fft_window.float())

    def transform(self, input_data, return_phase=False):
        """Take input data (audio) to STFT domain.

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)

        Returns:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Phase of STFT with shape (num_batch,
                num_frequencies, num_frames); only returned when
                return_phase is True
        """
        padded = F.pad(
            input_data,
            (self.pad_amount, self.pad_amount),
            mode="reflect",
        )
        # Slice into overlapping frames, then project every frame onto the basis.
        frames = padded.unfold(1, self.filter_length, self.hop_length).permute(0, 2, 1)
        spectrum = torch.matmul(self.forward_basis, frames)

        # The first n_bins rows are the real part, the rest the imaginary part.
        n_bins = int((self.filter_length / 2) + 1)
        real_part = spectrum[:, :n_bins, :]
        imag_part = spectrum[:, n_bins:, :]
        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        if not return_phase:
            return magnitude
        phase = torch.atan2(imag_part.data, real_part.data)
        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
        by the ```transform``` function.

        Arguments:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Phase of STFT with shape (num_batch,
                num_frequencies, num_frames)

        Returns:
            inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
                shape (num_batch, num_samples)
        """
        stacked = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )
        n_frames = stacked.size(-1)

        # Fold performs the overlap-add of the per-frame reconstructions.
        overlap_add = torch.nn.Fold(
            output_size=(1, (n_frames - 1) * self.hop_length + self.filter_length),
            kernel_size=(1, self.filter_length),
            stride=(1, self.hop_length),
        )
        trim = slice(self.pad_amount, -self.pad_amount)

        frames = torch.matmul(self.inverse_basis, stacked)
        inverse_transform = overlap_add(frames)[:, 0, 0, trim]

        # Normalise by the overlap-added squared window.
        window_sq = self.fft_window.pow(2).repeat(n_frames, 1).T.unsqueeze(0)
        window_square_sum = overlap_add(window_sq)[:, 0, 0, trim]
        inverse_transform /= window_square_sum
        return inverse_transform

    def forward(self, input_data):
        """Take input data (audio) to STFT domain and then back to audio.

        The intermediate magnitude and phase are stored on the module as
        ``self.magnitude`` and ``self.phase``.

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)

        Returns:
            reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
                shape (num_batch, num_samples)
        """
        self.magnitude, self.phase = self.transform(input_data, return_phase=True)
        return self.inverse(self.magnitude, self.phase)
145
+
146
+
147
+ from time import time as ttime
148
+
149
+
150
class BiGRU(nn.Module):
    """Bidirectional, batch-first GRU that returns only the output sequence."""

    def __init__(self, input_features, hidden_features, num_layers):
        super(BiGRU, self).__init__()
        self.gru = nn.GRU(
            input_size=input_features,
            hidden_size=hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        """Return the GRU output sequence, dropping the final hidden state."""
        sequence, _final_hidden = self.gru(x)
        return sequence
163
+
164
+
165
class ConvBlockRes(nn.Module):
    """Two 3x3 conv/BatchNorm/ReLU stages with a residual connection.

    When the input and output channel counts differ, a 1x1 convolution
    (``self.shortcut``) projects the input before it is added; otherwise the
    attribute is not created and the input is added unchanged.
    """

    def __init__(self, in_channels, out_channels, momentum=0.01):
        super(ConvBlockRes, self).__init__()

        def _stage(n_in):
            # Bias-free 3x3 conv that preserves spatial size, then BN and ReLU.
            return [
                nn.Conv2d(
                    in_channels=n_in,
                    out_channels=out_channels,
                    kernel_size=(3, 3),
                    stride=(1, 1),
                    padding=(1, 1),
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels, momentum=momentum),
                nn.ReLU(),
            ]

        self.conv = nn.Sequential(*_stage(in_channels), *_stage(out_channels))
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))

    def forward(self, x: torch.Tensor):
        """Return ``conv(x)`` plus the (optionally projected) input."""
        out = self.conv(x)
        skip = self.shortcut(x) if hasattr(self, "shortcut") else x
        return out + skip
199
+
200
+
201
class Encoder(nn.Module):
    """Stack of ``ResEncoderBlock`` stages applied after an input BatchNorm.

    The channel count doubles and the tracked ``in_size`` halves after each
    stage. ``latent_channels`` records ``[out_channels, in_size]`` per stage;
    ``out_size`` / ``out_channel`` hold the values left after the last stage
    (``out_channel`` is therefore twice the last stage's channel count).
    """

    def __init__(
        self,
        in_channels,
        in_size,
        n_encoders,
        kernel_size,
        n_blocks,
        out_channels=16,
        momentum=0.01,
    ):
        super().__init__()
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()
        self.latent_channels = []

        stage_in, stage_out, size = in_channels, out_channels, in_size
        for _ in range(self.n_encoders):
            self.layers.append(
                ResEncoderBlock(
                    stage_in, stage_out, kernel_size, n_blocks, momentum=momentum
                )
            )
            self.latent_channels.append([stage_out, size])
            stage_in = stage_out
            stage_out *= 2
            size //= 2

        self.out_size = size
        self.out_channel = stage_out

    def forward(self, x: torch.Tensor):
        """Return ``(x, concat_tensors)``: final output and per-stage skip tensors."""
        skips: List[torch.Tensor] = []
        x = self.bn(x)
        for stage in self.layers:
            # Each stage yields (pre-pool features, pooled features).
            skip, x = stage(x)
            skips.append(skip)
        return x, skips
237
+
238
+
239
class ResEncoderBlock(nn.Module):
    """``n_blocks`` chained ``ConvBlockRes`` layers with optional average pooling.

    With ``kernel_size`` set, ``forward`` returns ``(features, pooled)``;
    with ``kernel_size=None`` no pooling layer exists and it returns only
    ``features``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
    ):
        super().__init__()
        self.n_blocks = n_blocks
        # First block changes the channel count; the rest keep it.
        blocks = [ConvBlockRes(in_channels, out_channels, momentum)]
        blocks.extend(
            ConvBlockRes(out_channels, out_channels, momentum)
            for _ in range(n_blocks - 1)
        )
        self.conv = nn.ModuleList(blocks)
        self.kernel_size = kernel_size
        if self.kernel_size is not None:
            self.pool = nn.AvgPool2d(kernel_size=kernel_size)

    def forward(self, x):
        for block in self.conv:
            x = block(x)
        if self.kernel_size is None:
            return x
        return x, self.pool(x)
260
+
261
+
262
class Intermediate(nn.Module):
    """``n_inters`` un-pooled ``ResEncoderBlock`` stages applied in sequence."""

    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
        super().__init__()
        self.n_inters = n_inters
        # kernel_size=None makes each stage return a single tensor (no pooling).
        stages = [ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)]
        stages.extend(
            ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
            for _ in range(self.n_inters - 1)
        )
        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        for stage in self.layers:
            x = stage(x)
        return x
279
+
280
+
281
class ResDecoderBlock(nn.Module):
    """Transposed-conv upsampling followed by residual blocks over a skip concat.

    ``forward`` upsamples ``x``, concatenates ``concat_tensor`` on the channel
    axis (hence the first residual block takes ``out_channels * 2`` inputs) and
    runs ``n_blocks`` ``ConvBlockRes`` layers.
    """

    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
        super().__init__()
        if stride == (1, 2):
            out_padding = (0, 1)
        else:
            out_padding = (1, 1)
        self.n_blocks = n_blocks
        upsample = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=stride,
            padding=(1, 1),
            output_padding=out_padding,
            bias=False,
        )
        self.conv1 = nn.Sequential(
            upsample,
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        blocks = [ConvBlockRes(out_channels * 2, out_channels, momentum)]
        blocks.extend(
            ConvBlockRes(out_channels, out_channels, momentum)
            for _ in range(n_blocks - 1)
        )
        self.conv2 = nn.ModuleList(blocks)

    def forward(self, x, concat_tensor):
        merged = torch.cat((self.conv1(x), concat_tensor), dim=1)
        for block in self.conv2:
            merged = block(merged)
        return merged
310
+
311
+
312
class Decoder(nn.Module):
    """Chain of ``ResDecoderBlock`` stages, halving the channel count each time."""

    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
        super().__init__()
        self.layers = nn.ModuleList()
        self.n_decoders = n_decoders
        channels = in_channels
        for _ in range(self.n_decoders):
            halved = channels // 2
            self.layers.append(
                ResDecoderBlock(channels, halved, stride, n_blocks, momentum)
            )
            channels = halved

    def forward(self, x: torch.Tensor, concat_tensors: List[torch.Tensor]):
        # Skip tensors are consumed from last to first.
        for depth, stage in enumerate(self.layers, start=1):
            x = stage(x, concat_tensors[-depth])
        return x
328
+
329
+
330
class DeepUnet(nn.Module):
    """Encoder -> Intermediate -> Decoder network with skip connections.

    The encoder is built with a fixed ``in_size`` of 128, and ``kernel_size``
    is used both as the encoder pooling size and as the decoder stride.
    """

    def __init__(
        self,
        kernel_size,
        n_blocks,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super().__init__()
        self.encoder = Encoder(
            in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
        )
        # Encoder.out_channel is twice the channel count of its last stage.
        bottleneck_channels = self.encoder.out_channel
        self.intermediate = Intermediate(
            bottleneck_channels // 2,
            bottleneck_channels,
            inter_layers,
            n_blocks,
        )
        self.decoder = Decoder(
            bottleneck_channels, en_de_layers, kernel_size, n_blocks
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        encoded, skips = self.encoder(x)
        return self.decoder(self.intermediate(encoded), skips)
359
+
360
+
361
class E2E(nn.Module):
    """U-Net + 3-channel conv + classifier head over mel frames.

    ``forward`` takes a mel tensor, swaps its last two axes, adds a channel
    axis, runs the U-Net and a 3x3 conv, flattens the 3 channels with the mel
    axis and applies the head (optional BiGRU, Linear, Dropout, Sigmoid).

    Arguments:
        n_blocks: residual blocks per U-Net stage.
        n_gru: number of GRU layers in the head; a falsy value selects a
            plain linear head.
        kernel_size: pooling size / decoder stride passed to ``DeepUnet``.
        en_de_layers, inter_layers, in_channels, en_out_channels: forwarded
            to ``DeepUnet``.
    """

    # Sizes the head is built for. The GRU branch originally hard-coded them,
    # while the linear branch read non-existent ``nn.N_MELS`` / ``nn.N_CLASS``
    # attributes and raised AttributeError whenever ``n_gru`` was falsy.
    N_MELS = 128
    N_CLASS = 360

    def __init__(
        self,
        n_blocks,
        n_gru,
        kernel_size,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super(E2E, self).__init__()
        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        if n_gru:
            self.fc = nn.Sequential(
                BiGRU(3 * self.N_MELS, 256, n_gru),
                nn.Linear(512, self.N_CLASS),  # 512 = 2 directions * 256 hidden
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            self.fc = nn.Sequential(
                nn.Linear(3 * self.N_MELS, self.N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )

    def forward(self, mel):
        # Swap the last two axes and insert a singleton channel axis.
        mel = mel.transpose(-1, -2).unsqueeze(1)
        # Move the 3 conv channels next to the last axis and merge them.
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x
401
+
402
+
403
+ from librosa.filters import mel
404
+
405
+
406
class MelSpectrogram(torch.nn.Module):
    """Log-mel spectrogram extractor with optional key shift and speed change.

    The mel filterbank is built once (HTK scale) and stored as the
    ``mel_basis`` buffer. Hann windows are cached per (keyshift, device).
    """

    def __init__(
        self,
        is_half,
        n_mel_channels,
        sampling_rate,
        win_length,
        hop_length,
        n_fft=None,
        mel_fmin=0,
        mel_fmax=None,
        clamp=1e-5,
    ):
        super().__init__()
        # n_fft falls back to the window length when not given.
        if n_fft is None:
            n_fft = win_length
        self.hann_window = {}
        filterbank = mel(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
            htk=True,
        )
        self.register_buffer("mel_basis", torch.from_numpy(filterbank).float())
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.clamp = clamp
        self.is_half = is_half

    def forward(self, audio, keyshift=0, speed=1, center=True):
        """Return the clamped log-mel spectrogram of ``audio``.

        ``keyshift`` (in semitones) rescales the FFT and window sizes by
        ``2 ** (keyshift / 12)``; ``speed`` rescales the hop length.
        """
        pitch_ratio = 2 ** (keyshift / 12)
        n_fft_new = int(np.round(self.n_fft * pitch_ratio))
        win_length_new = int(np.round(self.win_length * pitch_ratio))
        hop_length_new = int(np.round(self.hop_length * speed))

        device = audio.device
        window_key = f"{keyshift}_{device}"
        if window_key not in self.hann_window:
            self.hann_window[window_key] = torch.hann_window(win_length_new).to(device)

        if "privateuseone" in str(device):
            # This device type uses the conv-based STFT module instead of
            # torch.stft; it is created on first use and then reused as-is.
            if not hasattr(self, "stft"):
                self.stft = STFT(
                    filter_length=n_fft_new,
                    hop_length=hop_length_new,
                    win_length=win_length_new,
                    window="hann",
                ).to(device)
            magnitude = self.stft.transform(audio)
        else:
            spectrum = torch.stft(
                audio,
                n_fft=n_fft_new,
                hop_length=hop_length_new,
                win_length=win_length_new,
                window=self.hann_window[window_key],
                center=center,
                return_complex=True,
            )
            magnitude = torch.sqrt(spectrum.real.pow(2) + spectrum.imag.pow(2))

        if keyshift != 0:
            # Bring the frequency axis back to the size mel_basis expects and
            # rescale for the changed window length.
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)
            if resize < size:
                magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new

        mel_output = torch.matmul(self.mel_basis, magnitude)
        if self.is_half == True:
            mel_output = mel_output.half()
        return torch.log(torch.clamp(mel_output, min=self.clamp))
481
+
482
+
483
class RMVPE:
    """Pitch (f0) estimator wrapping the ``E2E`` network and a mel front-end.

    Arguments:
        model_path: checkpoint path handed to ``torch.load`` (unused on the
            ONNX path, which reads ``$rmvpe_root/rmvpe.onnx`` instead).
        is_half: run the torch model and mel output in float16 when truthy.
        device: target device; auto-selected (cuda:0, then mps, then cpu)
            when ``None``. A device whose string contains "privateuseone"
            selects the ONNX Runtime / DirectML path.
        use_jit: accepted for interface compatibility; not used here.
    """

    def __init__(self, model_path: str, is_half, device=None, use_jit=False):
        self.resample_kernel = {}
        self.is_half = is_half
        if device is None:
            if torch.cuda.is_available():
                device = "cuda:0"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        self.device = device
        self.mel_extractor = MelSpectrogram(
            is_half, 128, 16000, 1024, 160, None, 30, 8000
        ).to(device)
        if "privateuseone" in str(device):
            import onnxruntime as ort

            ort_session = ort.InferenceSession(
                "%s/rmvpe.onnx" % os.environ["rmvpe_root"],
                providers=["DmlExecutionProvider"],
            )
            self.model = ort_session
        else:
            if str(self.device) == "cuda":
                self.device = torch.device("cuda:0")

            def get_default_model():
                model = E2E(4, 1, (2, 2))
                # NOTE(security): torch.load unpickles the checkpoint; only
                # load files from a trusted source.
                ckpt = torch.load(model_path, map_location="cpu")
                model.load_state_dict(ckpt)
                model.eval()
                if is_half:
                    model = model.half()
                else:
                    model = model.float()
                return model

            self.model = get_default_model()

            self.model = self.model.to(device)
        # Cent value of each of the 360 output bins, zero-padded by 4 on both
        # sides (368 entries) to match the padded salience used when decoding.
        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))

    def mel2hidden(self, mel):
        """Run the model on ``mel`` and return its output cut to the input frames."""
        with torch.no_grad():
            n_frames = mel.shape[-1]
            # Pad the frame axis up to the next multiple of 32.
            n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
            if n_pad > 0:
                mel = F.pad(mel, (0, n_pad), mode="constant")
            if "privateuseone" in str(self.device):
                onnx_input_name = self.model.get_inputs()[0].name
                onnx_outputs_names = self.model.get_outputs()[0].name
                hidden = self.model.run(
                    [onnx_outputs_names],
                    input_feed={onnx_input_name: mel.cpu().numpy()},
                )[0]
            else:
                mel = mel.half() if self.is_half else mel.float()
                hidden = self.model(mel)
            return hidden[:, :n_frames]

    def decode(self, hidden, thred=0.03):
        """Convert a (frames, 360) salience array to f0; unvoiced frames are 0."""
        cents_pred = self.to_local_average_cents(hidden, thred=thred)
        f0 = 10 * (2 ** (cents_pred / 1200))
        # A cent value of 0 marks "no pitch" and maps to exactly 10.
        f0[f0 == 10] = 0
        return f0

    def infer_from_audio(self, audio, thred=0.03):
        """Estimate f0 for a single waveform (numpy array or tensor)."""
        if not torch.is_tensor(audio):
            audio = torch.from_numpy(audio)
        mel = self.mel_extractor(
            audio.float().to(self.device).unsqueeze(0), center=True
        )
        hidden = self.mel2hidden(mel)
        if "privateuseone" not in str(self.device):
            hidden = hidden.squeeze(0).cpu().numpy()
        else:
            # The ONNX path already yields a numpy array.
            hidden = hidden[0]
        if self.is_half:
            hidden = hidden.astype("float32")

        return self.decode(hidden, thred=thred)

    def infer_from_audio_batch(self, audio, thred=0.03):
        """Estimate f0 for a batch of waveforms; returns a tensor on ``self.device``."""
        if not torch.is_tensor(audio):
            audio = torch.from_numpy(audio)
        mel = self.mel_extractor(
            audio.float().to(self.device), center=True
        )
        hidden = self.mel2hidden(mel)
        if "privateuseone" not in str(self.device):
            hidden = hidden.cpu().numpy()
        if self.is_half:
            hidden = hidden.astype("float32")

        f0s = np.stack(
            [self.decode(hidden[bib], thred=thred) for bib in range(hidden.shape[0])]
        )
        return torch.from_numpy(f0s).to(self.device)

    def to_local_average_cents(self, salience, thred=0.05):
        """Salience-weighted mean cent value in a 9-bin window around each peak.

        Frames whose peak salience is ``<= thred`` get 0.
        """
        center = np.argmax(salience, axis=1)  # (frames,) peak bin index
        salience = np.pad(salience, ((0, 0), (4, 4)))  # (frames, 368)
        # After padding the peak sits at center + 4, so the window
        # [peak - 4, peak + 5) starts at ``center`` in padded coordinates.
        window = center[:, None] + np.arange(9)  # (frames, 9)
        todo_salience = np.take_along_axis(salience, window, axis=1)
        todo_cents_mapping = self.cents_mapping[window]
        product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
        weight_sum = np.sum(todo_salience, 1)  # (frames,)
        # All-zero windows give 0/0 here; they are zeroed by the threshold below.
        with np.errstate(divide="ignore", invalid="ignore"):
            devided = product_sum / weight_sum  # (frames,)
        maxx = np.max(salience, axis=1)  # (frames,)
        devided[maxx <= thred] = 0
        return devided
modules/v2/__init__.py ADDED
File without changes
modules/v2/ar.py ADDED
@@ -0,0 +1,763 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import json
3
+ import math
4
+ from collections import OrderedDict
5
+ from functools import partial, wraps
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Optional, Tuple, List
9
+ from tqdm import tqdm
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from einops import rearrange
14
+ from torch import Tensor
15
+ from torch.nn import functional as F
16
+ from torch.utils.checkpoint import checkpoint
17
+
18
+
19
def find_multiple(n: int, k: int) -> int:
    """Return ``n`` rounded up to the nearest multiple of ``k``."""
    remainder = n % k
    return n if remainder == 0 else n + k - remainder
23
+
24
def l2norm(t, groups = 1):
    """L2-normalize the last axis of ``t`` independently within ``groups`` equal chunks."""
    grouped = rearrange(t, '... (g d) -> ... g d', g = groups)
    normalized = F.normalize(grouped, p = 2, dim = -1)
    return rearrange(normalized, '... g d -> ... (g d)')
28
+
29
@dataclass
class BaseModelArgs:
    """Hyper-parameters of the transformer defined in this module.

    ``__post_init__`` fills in derived values: ``n_local_heads`` defaults to
    ``n_head``, ``intermediate_size`` defaults to 2/3 of ``4 * dim`` rounded
    up to a multiple of 256, and ``head_dim`` is always recomputed as
    ``dim // n_head`` (a value passed by the caller is overwritten).
    """

    model_type: str = "base"

    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    # None means "derive from dim" (see __post_init__).
    intermediate_size: Optional[int] = None
    # -1 means "same as n_head".
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    max_seq_len: int = 4096
    dropout: float = 0.0
    tie_word_embeddings: bool = True
    attention_qkv_bias: bool = False

    # Gradient checkpointing
    use_gradient_checkpointing: bool = False

    # Initialize the model
    initializer_range: float = 0.02

    qk_norm: bool = False
    layerscale: bool = False

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head

    def save(self, path: str):
        """Write all fields to ``path`` as indented, key-sorted JSON.

        The file is opened as UTF-8 explicitly: ``ensure_ascii=False`` emits
        raw non-ASCII characters, which the platform default encoding may not
        be able to encode.
        """
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)


@dataclass
class NaiveModelArgs(BaseModelArgs):
    """``BaseModelArgs`` with ``model_type`` defaulting to ``"naive"``."""

    model_type: str = "naive"
73
+
74
+
75
class KVCache(nn.Module):
    """Pre-allocated key/value buffers of shape (batch, heads, seq, head_dim)."""

    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        for name in ("k_cache", "v_cache"):
            self.register_buffer(name, torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write ``k_val`` / ``v_val`` at ``input_pos`` and return the full caches.

        input_pos: [S], k_val: [B, H, S, D]. The buffers are modified in place
        and the returned tensors are the buffers themselves.
        """
        assert input_pos.shape[0] == k_val.shape[2]

        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache
94
+
95
+
96
@dataclass
class TransformerForwardResult:
    """Training-time output: logits paired with their per-position targets."""

    # Logits over the vocabulary for every sequence position.
    token_logits: Tensor
    # Target token ids aligned with token_logits.
    token_targets: Tensor


@dataclass
class BaseTransformerForwardResult:
    """Output of the base transformer's forward passes."""

    # Vocabulary logits computed from the normalized hidden states.
    logits: Tensor
    # Hidden states of the last layer, before the final norm.
    hidden_states: Tensor
106
+
107
+
108
class BaseTransformer(nn.Module):
    """Decoder-only transformer operating on already-embedded inputs.

    ``forward`` is the full-sequence path (training / cache-free inference);
    ``forward_generate`` is the KV-cache path and requires ``setup_caches``
    to have been called first.
    """

    def __init__(
        self,
        config: BaseModelArgs,
        init_weights: bool = True,
    ) -> None:
        super().__init__()
        self.config = config

        # Slow transformer
        self.embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)

        if self.config.tie_word_embeddings is False:
            self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        rotary_table = precompute_freqs_cis(
            config.max_seq_len,
            config.dim // config.n_head,
            config.rope_base,
        )
        self.register_buffer("freqs_cis", rotary_table, persistent=False)

        lower_triangle = torch.tril(
            torch.ones(config.max_seq_len, config.max_seq_len, dtype=torch.bool)
        )
        self.register_buffer("causal_mask", lower_triangle, persistent=False)

        # NOTE: the output projection is (re)created unconditionally, so it
        # exists even when tie_word_embeddings is True; forward_generate
        # always uses it.
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # For kv cache
        self.max_batch_size = -1
        self.max_seq_len = -1

        if init_weights:
            self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = "cuda"
    ):
        """Allocate a ``KVCache`` on every layer unless the current ones are big enough."""
        caches_large_enough = (
            self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size
        )
        if caches_large_enough:
            return

        head_dim = self.config.dim // self.config.n_head
        # Cache length is rounded up to a multiple of 8.
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        for block in self.layers:
            block.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_len,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            ).to(device)

    def embed_base(self, x: Tensor, x_lens: Tensor) -> Tensor:
        """Fill positions past each row's length with the last vocab id and embed.

        ``x`` is modified in place. Despite the annotation, the return value
        is the tuple ``(x, embeddings)``.
        """
        fill_id = self.config.vocab_size - 1
        for row in range(x.size(0)):
            x[row, x_lens[row]:] = fill_id

        return x, self.embeddings(x)

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
    ) -> BaseTransformerForwardResult:
        """Run all layers over ``inp`` (embeddings, batch first) without a KV cache."""
        seq_len = inp.size(1)

        # Here we want to merge the embeddings of the codebooks
        x = inp.clone()

        if input_pos is not None:
            freqs_cis = self.freqs_cis[input_pos]
        else:
            freqs_cis = self.freqs_cis[:seq_len].repeat(inp.size(0), 1, 1, 1)

        # Note that the causal mask here follows the definition of
        # scaled_dot_product_attention: FALSE means masked out.
        # key_padding_mask uses the opposite convention (TRUE masks out),
        # so it is inverted before being combined.
        mask = None
        if key_padding_mask is not None:
            causal = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
            mask = causal & key_padding_mask[:, None, None, :].logical_not()

        use_checkpointing = self.config.use_gradient_checkpointing and self.training
        for layer in self.layers:
            if use_checkpointing:
                x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
            else:
                x = layer(x, freqs_cis, mask)

        # We got slow_out here
        slow_out = self.norm(x)

        if self.config.tie_word_embeddings:
            token_logits = F.linear(slow_out, self.embeddings.weight)
        else:
            token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )

    def forward_generate(
        self,
        inp: Tensor,
        input_pos: Optional[Tensor] = None,
        kv_pos: Optional[Tensor] = None,
        return_all: bool = False,
    ) -> BaseTransformerForwardResult:
        """KV-cache decoding step; returns logits for the last position only.

        ``return_all`` is accepted but not used.
        """
        # This is used for generation, optimized for torch compile
        x = inp

        mask = self.causal_mask[None, None, kv_pos, :self.max_seq_len]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        for layer in self.layers:
            x = layer(x, freqs_cis, mask, input_pos=kv_pos)

        x = x[:, -1:]

        # We got slow_out here
        slow_out = self.norm(x)
        token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear / Embedding weights."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
279
+
280
class NaiveTransformer(BaseTransformer):
    """``BaseTransformer`` that also builds next-token training targets."""

    def __init__(self, config: NaiveModelArgs) -> None:
        # Initialization is deferred so it runs once, after construction.
        super().__init__(config, init_weights=False)
        self.apply(self._init_weights)

    def forward(
        self,
        inp: Tensor,
        cond_lens: Tensor,
        target: Tensor,
        target_lens: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Run the base forward pass and align ``target`` with the logits.

        For row ``b`` the targets occupy positions ``cond_lens[b] + 1`` up to
        ``cond_lens[b] + target_lens[b]``, followed by one end token
        (``vocab_size - 1``). Every other position is -100.
        """
        token_logits = super().forward(
            inp=inp,
            key_padding_mask=key_padding_mask,
            input_pos=input_pos,
        ).logits

        # construct targets for token_logits
        token_targets = torch.full(
            (token_logits.size(0), token_logits.size(1)),
            -100,
            dtype=torch.long,
            device=target.device,
        )
        end_token = self.config.vocab_size - 1
        for row in range(token_targets.size(0)):
            first = cond_lens[row] + 1
            stop = first + target_lens[row]
            token_targets[row, first:stop] = target[row, :target_lens[row]]
            token_targets[row, stop] = end_token

        return TransformerForwardResult(
            token_logits=token_logits,
            token_targets=token_targets,
        )

    def infer_slow(self, inp: Tensor, input_pos: Optional[Tensor] = None):
        """Sample the next token from the last position without a KV cache."""
        # no kv cache used
        result = super().forward(inp, input_pos=input_pos)
        last_logits = result.logits[:, -1]
        sampled, _ = topk_sampling(last_logits, top_k=-1, top_p=1.0)
        return sampled

    def forward_generate(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        kv_pos: Optional[Tensor] = None,
        vq_masks: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Delegate to the base KV-cache step.

        ``vq_masks`` is passed positionally into the base method's fourth
        parameter (``return_all``).
        """
        return super().forward_generate(x, input_pos, kv_pos, vq_masks)
329
+
330
class NaiveWrapper(nn.Module):
    """Training / generation wrapper around ``NaiveTransformer``.

    Sequences are laid out as ``[sep, cond..., sep, x...]`` where ``sep`` is
    the learned ``sep_token_emb`` vector and ``x`` are embedded token ids.
    """

    def __init__(self, model: NaiveTransformer) -> None:
        super().__init__()
        self.model = model
        self.sep_token_emb = nn.Parameter(torch.randn(model.config.dim))

    def setup_caches(self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = "cuda"):
        """Forward the KV-cache allocation to the wrapped model."""
        self.model.setup_caches(max_batch_size, max_seq_len, dtype, device)

    def forward(self, cond: Tensor, cond_lens: Tensor, x: Tensor, x_lens: Tensor) -> torch.Tensor:
        """Return the cross-entropy loss for predicting ``x`` given ``cond``."""
        batch = x.size(0)
        sep = self.sep_token_emb.expand(batch, 1, -1)
        _, x_emb = self.model.embed_base(x, x_lens)

        # Per-sample sequence: sep + valid cond frames + sep + valid targets.
        sequences = [
            torch.cat([
                sep[b:b + 1],
                cond[b:b + 1, :cond_lens[b]],
                sep[b:b + 1],
                x_emb[b:b + 1, :x_lens[b]]], dim=1)
            for b in range(batch)
        ]
        max_len = max([seq.size(1) for seq in sequences])
        # Zero-pad every sequence on the right up to the longest one.
        emb_seq = torch.cat([
            F.pad(seq, (0, 0, 0, max_len - seq.size(1)), value=0)
            for seq in sequences
        ], dim=0)

        # Positions restart from 0 at the second separator.
        input_pos = torch.zeros(emb_seq.size(0), emb_seq.size(1), device=emb_seq.device, dtype=torch.long)
        for b in range(batch):
            cond_span = cond_lens[b] + 1
            target_span = x_lens[b] + 1
            input_pos[b, :cond_span] = torch.arange(cond_span, device=emb_seq.device)
            input_pos[b, cond_span: cond_span + target_span] = torch.arange(target_span, device=emb_seq.device)

        out = self.model(emb_seq, cond_lens, x, x_lens, input_pos=input_pos)
        return F.cross_entropy(out.token_logits.transpose(1, 2), out.token_targets.long(), ignore_index=-100)

    @torch.no_grad()
    def infer(self, cond: Tensor) -> torch.Tensor:
        """Cache-free decoding of up to 4000 tokens, stopping at the end token."""
        sep = self.sep_token_emb.expand(1, 1, -1)
        emb_seq = torch.cat([sep, cond, sep], dim=1)
        end_token = self.model.config.vocab_size - 1
        pred_codes = []
        input_pos = torch.arange(cond.size(1) + 1, device=cond.device)
        for step in tqdm(range(4000)):
            input_pos = torch.cat([input_pos, torch.LongTensor([step]).to(cond.device)], dim=0)
            base = self.model.infer_slow(emb_seq, input_pos)
            if base == end_token:
                break
            new_emb = self.model.embed_base(base, torch.LongTensor([1]).to(base.device))[1]
            emb_seq = torch.cat([emb_seq, new_emb], dim=1)
            pred_codes.append(base)
        return torch.cat(pred_codes, dim=-1)

    @torch.no_grad()
    def generate(
        self,
        prompt_text,
        prompt_target,
        compiled_decode_fn = None,
        **sampling_kwargs,
    ):
        """KV-cache decoding conditioned on ``prompt_text`` and a target prefix.

        The whole prompt is processed in one (uncompiled) call; then up to
        4000 single-token steps follow. The end token is suppressed for the
        prefill step and while fewer than 10 tokens have been produced.
        """
        end_token = self.model.config.vocab_size - 1
        sep = self.sep_token_emb.expand(1, 1, -1)
        emb_seq = torch.cat([sep, prompt_text, sep], dim=1)
        input_pos = torch.arange(prompt_text.size(1) + 1, device=emb_seq.device)
        # The second separator restarts the position counter at 0.
        input_pos = torch.cat([input_pos, torch.LongTensor([0]).to(emb_seq.device)])
        prompt_len = torch.LongTensor([prompt_target.size(1)]).to(prompt_target.device)
        prompt_target_emb = self.model.embed_base(prompt_target, prompt_len)[1]
        emb_seq = torch.cat([emb_seq, prompt_target_emb], dim=1)
        input_pos = torch.cat([input_pos, torch.arange(prompt_target_emb.size(1)).to(input_pos.device) + 1])

        pred_codes = []
        kv_pos = torch.arange(emb_seq.size(1), device=emb_seq.device)
        next_tokens = self.decode_one_token_ar(emb_seq, input_pos, kv_pos, suppress_tokens=[end_token], **sampling_kwargs)
        pred_base = next_tokens[0]
        pred_codes.append(pred_base)
        new_emb = self.model.embed_base(pred_base.unsqueeze(0), torch.LongTensor([1]).to(pred_base.device))[1]
        emb_seq = torch.cat([emb_seq, new_emb], dim=1)

        for _ in tqdm(range(4000)):
            suppress_eos = len(pred_codes) < 10
            input_pos = input_pos[-1:] + 1
            kv_pos = kv_pos[-1:] + 1
            next_tokens = self.decode_one_token_ar(
                emb_seq[:, -1:].reshape(1, 1, -1),
                input_pos.reshape(1),
                kv_pos.reshape(1),
                previous_tokens=torch.cat(pred_codes),
                suppress_tokens=[end_token] if suppress_eos else None,
                compiled_decode_fn=compiled_decode_fn,
                **sampling_kwargs)
            pred_base = next_tokens[0]
            if pred_base == end_token:
                break
            pred_codes.append(pred_base.clone())
            new_emb = self.model.embed_base(pred_base.unsqueeze(0), torch.LongTensor([1]).to(pred_base.device))[1]
            emb_seq = torch.cat([emb_seq, new_emb], dim=1)
        return torch.stack(pred_codes, dim=-1)

    def decode_one_token_ar(
        self,
        x: torch.Tensor,
        input_pos: torch.Tensor,
        kv_pos: torch.Tensor,
        previous_tokens: torch.Tensor = None,
        compiled_decode_fn = None,
        **sampling_kwargs,
    ) -> torch.Tensor:
        """Run one decoding step and sample a token from the resulting logits.

        Uses ``compiled_decode_fn`` when given, otherwise the model's
        ``forward_generate``. Remaining keyword arguments go to ``sample``.
        """
        if compiled_decode_fn is None:
            step_out = self.model.forward_generate(x, input_pos, kv_pos)
        else:
            step_out = compiled_decode_fn(x, input_pos, kv_pos)

        sampling_kwargs_main = sampling_kwargs.copy()
        history = previous_tokens[0] if previous_tokens is not None else None
        sampled = sample(
            step_out.logits,
            previous_tokens=history,
            **sampling_kwargs_main,
        )[0]
        return torch.stack([sampled], dim=0)
450
+
451
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention then feed-forward, each with a residual."""

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
        super().__init__()
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
    ) -> Tensor:
        attended = self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        h = x + attended
        return h + self.feed_forward(self.ffn_norm(h))
465
+
466
+
467
class Attention(nn.Module):
    """Self-attention with a fused QKV projection, rotary position embedding,
    optional L2 normalisation of queries/keys and an optional KV cache.

    Queries use ``n_head`` heads while keys/values use ``n_local_heads`` heads;
    keys and values are repeated along the head axis to match the query heads
    before attention is computed.
    """

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(
            config.dim, total_head_dim, bias=config.attention_qkv_bias
        )
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        # Set externally when incremental decoding is used; None disables caching.
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        self._register_load_state_dict_pre_hook(self.load_hook)
        self.qk_norm = config.qk_norm
        self.qk_norm_groups = 1
        self.qk_norm_scale = 10
        self.qk_norm_dim_scale = False
        self.qk_norm_q_scale = self.qk_norm_k_scale = 1

        # qk_norm_dim_scale is fixed to False above, so the learned per-head
        # scales below are currently never created.
        if self.qk_norm and self.qk_norm_dim_scale:
            self.qk_norm_q_scale = nn.Parameter(torch.ones(self.n_head, 1, self.head_dim))
            self.qk_norm_k_scale = nn.Parameter(torch.ones(self.n_head, 1, self.head_dim))

    def load_hook(self, state_dict, prefix, *args):
        """State-dict pre-hook: merge separate ``wq``/``wk``/``wv`` weights
        into the fused ``wqkv`` weight expected by this module."""
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """Compute self-attention over ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, dim)``.
            freqs_cis: Rotary table passed to ``apply_rotary_emb``.
            mask: Attention mask, or ``None`` to use causal attention on the
                SDPA path.
            input_pos: Positions forwarded to ``kv_cache.update`` when a cache
                is attached.

        Returns:
            Tensor of shape ``(batch, seq_len, dim)`` after the output projection.
        """
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        if self.qk_norm:
            qk_l2norm = partial(l2norm, groups=self.qk_norm_groups)
            q, k = map(qk_l2norm, (q, k))
            # NOTE(review): self.qk_norm_scale is stored on the module but is
            # not passed to the attention kernels below; the previous code only
            # copied it into an unused local. Numerics are left unchanged.
            q = q * self.qk_norm_q_scale
            k = k * self.qk_norm_k_scale

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Expand key/value heads so they line up with the query heads.
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        dropout_p = self.dropout if self.training else 0.0
        if self.use_sdpa:
            if mask is None:
                y = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    dropout_p=dropout_p,
                    is_causal=True,
                    # No third party attn_mask here to use flash_attention
                )
            else:
                y = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=mask,
                    dropout_p=dropout_p,
                )
        else:
            y = self.eq_scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=dropout_p,
            )

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        return self.wo(y)

    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        """Plain (non-fused) scaled dot-product attention.

        Slower than the fused kernel, but avoids the CUDA errors the fused
        path can raise.

        Args:
            query, key, value: Tensors whose last two dims are
                ``(length, head_dim)``.
            attn_mask: Optional mask broadcastable to the attention scores.
                A boolean mask keeps positions that are ``True``; any other
                dtype is added to the scores.
            dropout_p: Dropout probability applied to the attention weights.
                The caller is responsible for passing 0.0 outside training.
        """
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            # Out-of-place ops so a mask with a real batch/head dimension can
            # broadcast; writing in place into the (1, 1, L, S) bias raised a
            # shape error for any mask larger than the bias.
            if attn_mask.dtype == torch.bool:
                attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias = attn_bias + attn_mask.to(attn_bias.dtype)

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

        return attn_weight @ value
596
+
597
+
598
class FeedForward(nn.Module):
    """Gated feed-forward network: ``w2(dropout(silu(w1(x)) * w3(x)))``."""

    def __init__(self, config: BaseModelArgs) -> None:
        super().__init__()
        model_dim = config.dim
        hidden_dim = config.intermediate_size
        # None of the three projections carries a bias term.
        self.w1 = nn.Linear(model_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(model_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, model_dim, bias=False)
        self.dropout = nn.Dropout(p=config.dropout)

    def forward(self, x: Tensor) -> Tensor:
        """Project up through the SiLU gate and linear branch, then back down."""
        gate = F.silu(self.w1(x))
        gated = gate * self.w3(x)
        return self.w2(self.dropout(gated))
608
+
609
+
610
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned
    per-feature gain (initialised to ones)."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension (``eps`` added for stability)."""
        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x: Tensor) -> Tensor:
        """Normalise in float32, cast back to the input dtype, then apply the gain."""
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
622
+
623
+
624
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    """Build the rotary-embedding lookup table.

    Args:
        seq_len: Number of positions to tabulate.
        n_elem: Rotary dimension; one frequency is produced per pair of elements.
        base: Base of the geometric frequency progression.

    Returns:
        A bfloat16 tensor of shape ``(seq_len, n_elem // 2, 2)`` whose last
        axis holds the (real, imaginary) parts of the unit rotation for each
        position/frequency pair.
    """
    exponents = torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    # Unit-magnitude complex numbers e^{i * angle}.
    rotations = torch.polar(torch.ones_like(angles), angles)
    table = torch.stack([rotations.real, rotations.imag], dim=-1)
    return table.to(dtype=torch.bfloat16)
633
+
634
+
635
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive element pairs of ``x`` by the tabulated rotations.

    Args:
        x: 4-D tensor; its last dimension is split into (real, imaginary) pairs.
        freqs_cis: Rotation table viewable as
            ``(x.size(0), x.size(1), 1, x.size(-1) // 2, 2)`` with the last
            axis holding (real, imaginary) parts.

    Returns:
        The rotated tensor with the same shape and dtype as ``x``.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    rot = freqs_cis.view(x.size(0), pairs.size(1), 1, pairs.size(3), 2)

    x_re, x_im = pairs[..., 0], pairs[..., 1]
    rot_re, rot_im = rot[..., 0], rot[..., 1]

    # Complex multiplication (x_re + i x_im) * (rot_re + i rot_im).
    rotated = torch.stack(
        [
            x_re * rot_re - x_im * rot_im,
            x_im * rot_re + x_re * rot_im,
        ],
        -1,
    )
    return rotated.flatten(3).type_as(x)
648
+
649
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Mask a batch of logits with top-k and/or nucleus (top-p) filtering.

    The input tensor is modified in place and also returned.

    Args:
        logits: Logits of shape (batch size, vocabulary size).
        top_k: When > 0, keep only the ``top_k`` highest-scoring tokens.
        top_p: When < 1.0, keep the smallest prefix of tokens (sorted by
            score) whose cumulative probability reaches ``top_p``.
            Nucleus filtering is described in Holtzman et al.
            (http://arxiv.org/abs/1904.09751).
        filter_value: Value written into the masked positions.
        min_tokens_to_keep: Lower bound on the number of surviving tokens per row.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        # Clamp k into [min_tokens_to_keep, vocab size].
        k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        kth_best = torch.topk(logits, k)[0][..., -1, None]
        # Everything scoring below the k-th best logit is dropped.
        logits.masked_fill_(logits < kth_best, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

        # Candidates for removal: everything past the probability threshold.
        drop_sorted = cumulative > top_p
        if min_tokens_to_keep > 1:
            # The shift below adds one more kept token on top of these.
            drop_sorted[..., :min_tokens_to_keep] = 0
        # Shift right so the first token crossing the threshold is kept too.
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = 0

        # Map the decision from sorted order back to vocabulary order.
        drop = drop_sorted.scatter(1, sorted_indices, drop_sorted)
        logits.masked_fill_(drop, filter_value)
    return logits
692
+
693
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Sample one token per row after temperature scaling and top-k/top-p filtering.

    Args:
        logits: Logits of shape (batch size, vocabulary size).
        top_k: Number of highest-scoring tokens kept by the filter.
        top_p: Cumulative-probability threshold for nucleus filtering.
        temperature: Divisor applied to the logits; higher values make
            low-probability tokens more likely. Skipped when equal to 1.0.

    Returns:
        ``(token, current_logprobs)``: the sampled token ids of shape
        ``(batch, 1)`` and the log-probability of each sampled token.
    """
    scaled = logits if temperature == 1.0 else logits / temperature
    filtered = top_k_top_p_filtering(scaled, top_k=top_k, top_p=top_p)

    token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)

    logprobs = F.log_softmax(filtered.float(), dim=-1)
    row_ids = torch.arange(logprobs.shape[0])
    current_logprobs = logprobs[row_ids, token.squeeze(1)]
    return token, current_logprobs
711
+
712
def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample the next token from the last position of the first batch item.

    Args:
        logits: Model logits; only ``logits[0, -1]`` is used.
        previous_tokens: Optional token history forwarded to ``logits_to_probs``.
        **sampling_kwargs: Extra keyword arguments for ``logits_to_probs``.

    Returns:
        ``(idx_next, probs)``: the sampled index and the distribution it was
        drawn from.
    """
    last_step_logits = logits[0, -1]
    probs = logits_to_probs(
        logits=last_step_logits,
        previous_tokens=previous_tokens,
        **sampling_kwargs,
    )
    return multinomial_sample_one_no_sync(probs), probs
722
+
723
def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    """Draw one index from ``probs_sort`` along the last dimension.

    Each probability is divided by independent Exponential(1) noise and the
    arg-max is taken. Returns an int32 tensor that keeps the reduced
    dimension with size 1.
    """
    noise = torch.empty_like(probs_sort).exponential_(1)
    winner = torch.argmax(probs_sort / noise, dim=-1, keepdim=True)
    return winner.to(dtype=torch.int)
728
+
729
+
730
def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    suppress_tokens: Optional[List[int]] = None,
    temperature: torch.Tensor = 0.7,
    top_p: torch.Tensor = 0.7,
    repetition_penalty: torch.Tensor = 1.5,
) -> torch.Tensor:
    """Turn a 1-D logits vector into a sampling distribution.

    Steps, in order: repetition penalty, token suppression, top-p filtering,
    temperature scaling, softmax. The first two steps write into ``logits``
    in place.

    Args:
        logits: 1-D logits over the vocabulary.
        previous_tokens: Optional indices of already generated tokens to penalise.
        suppress_tokens: Optional token ids whose logit is forced to ``-inf``.
        temperature: Softmax temperature, floored at ``1e-5``.
        top_p: Cumulative-probability threshold; the top token is always kept.
        repetition_penalty: Factor by which logits of previous tokens are
            made less likely.

    Returns:
        Probabilities with the same shape as ``logits``.
    """
    if previous_tokens is not None:
        previous_tokens = previous_tokens.long()
        seen = torch.gather(logits, dim=0, index=previous_tokens)
        # Negative logits are pushed further down, positive ones shrunk.
        penalized = torch.where(
            seen < 0, seen * repetition_penalty, seen / repetition_penalty
        )
        logits.scatter_(dim=0, index=previous_tokens, src=penalized)

    if suppress_tokens is not None:
        for banned in suppress_tokens:
            logits[banned] = -float("Inf")

    # Top-p: find which tokens (in sorted order) lie beyond the threshold.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.nn.functional.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    drop_sorted = cum_probs > top_p
    drop_sorted[0] = False  # keep at least one option
    drop = drop_sorted.scatter(dim=0, index=sorted_indices, src=drop_sorted)
    filtered = logits.masked_fill(drop, -float("Inf"))

    scaled = filtered / max(temperature, 1e-5)
    return torch.nn.functional.softmax(scaled, dim=-1)
modules/v2/cfm.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm import tqdm
3
+
4
class CFM(torch.nn.Module):
    """Conditional flow matching around a velocity ``estimator``.

    ``forward`` computes the training loss; ``inference`` integrates the
    estimated velocity field with a fixed-step Euler solver, optionally using
    classifier-free guidance.
    """

    def __init__(
        self,
        estimator: torch.nn.Module,
    ):
        super().__init__()
        self.sigma_min = 1e-6
        self.estimator = estimator
        self.in_channels = estimator.in_channels
        self.criterion = torch.nn.L1Loss()

    @torch.inference_mode()
    def inference(self,
                  mu: torch.Tensor,
                  x_lens: torch.Tensor,
                  prompt: torch.Tensor,
                  style: torch.Tensor,
                  n_timesteps=10,
                  temperature=1.0,
                  inference_cfg_rate=(0.5, 0.5),
                  random_voice=False,
                  ):
        """Generate a sample by integrating from noise.

        Args:
            mu (torch.Tensor): conditioning sequence. Dim 0 is the batch and
                dim 1 is used as the number of generated frames.
            x_lens (torch.Tensor): per-item lengths forwarded to the estimator
                shape: (batch_size,)
            prompt (torch.Tensor): prompt features; its last dim is the prompt length
            style (torch.Tensor): style vector forwarded to the estimator
            n_timesteps (int): number of Euler steps
            temperature (float, optional): scale of the initial noise. Defaults to 1.0.
            inference_cfg_rate (sequence of float, optional): the two
                classifier-free-guidance strengths, see ``solve_euler``.
                Defaults to (0.5, 0.5).
            random_voice (bool, optional): drop the prompt and style
                conditions entirely. Defaults to False.

        Returns:
            sample: generated features
                shape: (batch_size, in_channels, mu.size(1))
        """
        B, T = mu.size(0), mu.size(1)
        z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        # Warp the uniform grid; algebraically this equals 1 - cos(pi/2 * t).
        t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
        return self.solve_euler(z, x_lens, prompt, mu, style, t_span, inference_cfg_rate, random_voice)

    def _guidance_plan(self, prompt_x, style, mu, inference_cfg_rate, random_voice):
        """Pick the condition variants and the mixing rule for the guidance mode.

        Returns ``(prompts, styles, mus, combine)``: three equally long lists
        of condition tensors (one entry per estimator pass) and a callable
        that merges the per-pass estimator outputs into one velocity.
        """
        no_prompt = torch.zeros_like(prompt_x)
        no_style = torch.zeros_like(style)
        no_mu = torch.zeros_like(mu)

        if random_voice:
            # No prompt/style at all; guide on the content condition only.
            rate = inference_cfg_rate[0]
            return (
                [no_prompt, no_prompt], [no_style, no_style], [mu, no_mu],
                lambda cond_txt, uncond: (1.0 + rate) * cond_txt - rate * uncond,
            )
        if all(i == 0 for i in inference_cfg_rate):
            # Guidance disabled: a single fully conditioned pass.
            return [prompt_x], [style], [mu], lambda cond: cond
        if inference_cfg_rate[0] == 0:
            # Classifier-Free Guidance inference introduced in VoiceBox
            rate = inference_cfg_rate[1]
            return (
                [prompt_x, no_prompt], [style, no_style], [mu, mu],
                lambda cond_txt_spk, cond_txt: (1.0 + rate) * cond_txt_spk - rate * cond_txt,
            )
        if inference_cfg_rate[1] == 0:
            rate = inference_cfg_rate[0]
            return (
                [prompt_x, no_prompt], [style, no_style], [mu, no_mu],
                lambda cond_txt_spk, uncond: (1.0 + rate) * cond_txt_spk - rate * uncond,
            )
        # Multi-condition Classifier-Free Guidance inference introduced in MegaTTS3
        rate_0, rate_1 = inference_cfg_rate[0], inference_cfg_rate[1]
        return (
            [prompt_x, no_prompt, no_prompt], [style, no_style, no_style], [mu, mu, no_mu],
            lambda cond_txt_spk, cond_txt, uncond:
                (1.0 + rate_0 + rate_1) * cond_txt_spk - rate_0 * uncond - rate_1 * cond_txt,
        )

    def solve_euler(self, x, x_lens, prompt, mu, style, t_span, inference_cfg_rate=(0.5, 0.5), random_voice=False,):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise; its prompt region is zeroed in place
            x_lens (torch.Tensor): per-item lengths forwarded to the estimator
                shape: (batch_size,)
            prompt (torch.Tensor): prompt features; its last dim is the prompt length
            mu (torch.Tensor): conditioning sequence forwarded to the estimator
            style (torch.Tensor): style vector forwarded to the estimator
            t_span (torch.Tensor): time grid
                shape: (n_timesteps + 1,)
            inference_cfg_rate (sequence of float, optional): guidance
                strengths. Entry 0 weighs the pass with ``mu`` zeroed, entry 1
                weighs the pass with prompt and style zeroed but ``mu`` kept.
                All zeros disables guidance. Defaults to (0.5, 0.5).
            random_voice (bool, optional): zero prompt and style in every pass
                and guide with entry 0 only. Defaults to False.

        Returns:
            torch.Tensor: the integrated sample, same shape as ``x``, with the
            prompt region set to 0.
        """
        t, dt = t_span[0], t_span[1] - t_span[0]

        # apply prompt
        prompt_len = prompt.size(-1)
        prompt_x = torch.zeros_like(x)
        prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
        x[..., :prompt_len] = 0

        # The stacked conditions do not change between steps, so build them once.
        prompts, styles, mus, combine = self._guidance_plan(
            prompt_x, style, mu, inference_cfg_rate, random_voice
        )
        n_pass = len(prompts)
        stacked_batch = n_pass * x.size(0)
        prompt_in = torch.cat(prompts, dim=0)
        style_in = torch.cat(styles, dim=0)
        mu_in = torch.cat(mus, dim=0)
        lens_in = torch.cat([x_lens] * n_pass, dim=0)

        for step in tqdm(range(1, len(t_span))):
            x_in = x if n_pass == 1 else torch.cat([x] * n_pass, dim=0)
            # One timestep entry per stacked batch row.
            t_in = t.reshape(1).repeat(stacked_batch)
            estimate = self.estimator(x_in, prompt_in, lens_in, t_in, style_in, mu_in)
            # chunk() splits per pass for any batch size; fixed [0:1]/[1:2]
            # slices were only correct when the batch size was 1.
            dphi_dt = combine(*estimate.chunk(n_pass, dim=0))

            x = x + dt * dphi_dt
            t = t + dt
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
            # Keep the prompt region empty after every update.
            x[:, :, :prompt_len] = 0

        return x

    def forward(self, x1, x_lens, prompt_lens, mu, style):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            x_lens (torch.Tensor): valid length of each target
                shape: (batch_size,)
            prompt_lens (torch.Tensor): number of leading frames of each
                target that are given as prompt and excluded from the loss
                shape: (batch_size,)
            mu (torch.Tensor): conditioning sequence forwarded to the estimator
            style (torch.Tensor): style vector forwarded to the estimator

        Returns:
            loss: conditional flow matching loss (L1), averaged over the batch
        """
        b, _, _ = x1.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z
        prompt = torch.zeros_like(x1)
        for bib in range(b):
            prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]]
            # range covered by prompt are set to 0
            y[bib, :, :prompt_lens[bib]] = 0

        # reshape(b) keeps a (batch,) vector even when b == 1, where squeeze()
        # would collapse it to a 0-d tensor.
        estimator_out = self.estimator(y, prompt, x_lens, t.reshape(b), style, mu)
        loss = 0
        for bib in range(b):
            # Only frames after the prompt and inside the valid length count.
            loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]])
        loss /= b

        return loss