PhuongLT committed on
Commit 81d41bd · 1 Parent(s): 7a7d6aa

gemini version

.gitignore CHANGED
@@ -1 +1,6 @@
- Models/multi_phoaudio_gemini/*.pth
+ Models/*/*.pth
+ data/
+ explore/
+ .env
+ __pycache__/
+ */__pycache__/
Models/gemini_vi/config_gemini_vi_en.yml ADDED
@@ -0,0 +1,125 @@
+ log_dir: "Models/phoaudio/combine_gemini_vi_en"
+ first_stage_path: ""
+ save_freq: 1
+ log_interval: 50
+ device: "cuda"
+ epochs_1st: 200 # number of epochs for first stage training (pre-training)
+ epochs_2nd: 150 # number of epochs for second stage training (joint training)
+ batch_size: 8
+ max_len: 400 # maximum number of frames
+ # max_len: 800
+ pretrained_model: "Models/phoaudio/combine_gemini_vi_en/epoch_2nd_00029.pth"
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for the 2nd stage
+ load_only_params: false # set to true if you do not want to load epoch numbers and optimizer parameters
+ 
+ F0_path: "Utils_extend_v1/JDC/bst.t7"
+ ASR_config: "Utils_extend_v1/ASR/config.yml"
+ ASR_path: "Utils_extend_v1/ASR/epoch_extend_186.pth"
+ PLBERT_dir: 'Utils_extend_v1/PLBERT/'
+ extend_PLBERT: true # set to true if you want to extend the PLBERT model
+ 
+ data_params:
+   train_data: "/home/xdep/data/jupyterhub/users/datnvt/data/custom_datasets/text_gemini_vi_en/train_list.txt"
+   val_data: "/home/xdep/data/jupyterhub/users/datnvt/data/custom_datasets/text_gemini_vi_en/validation_vi_list.txt"
+   root_path: "/home/xdep/data/jupyterhub/users/datnvt/data/custom_datasets/wavs_gemini_phoaudio_multi_speaker_small_v1"
+   OOD_data: "/home/xdep/data/jupyterhub/users/datnvt/data/custom_datasets/text_gemini_phoaudio_multi_speaker_small_v1/ood_multi_phoaudio.txt"
+   min_length: 50 # sample until texts of this size are obtained for OOD texts
+ 
+ symbol: # 178 symbols in total
+   pad: "$"
+   punctuation: ';:,.!?¡¿—…"«»“” '
+   letters: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+   letters_ipa: "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+   extend: "-124567̪" # add more symbols here
+ 
+ preprocess_params:
+   sr: 24000
+   spect_params:
+     n_fft: 2048
+     win_length: 1200
+     hop_length: 300
+ 
+ model_params:
+   multispeaker: true
+ 
+   dim_in: 64
+   hidden_dim: 512
+   max_conv_dim: 512
+   n_layer: 3
+   n_mels: 80
+ 
+   n_token: 186 # number of phoneme tokens
+   max_dur: 50 # maximum duration of a single phoneme
+   style_dim: 128 # style vector size
+ 
+   dropout: 0.2
+ 
+   # config for decoder
+   decoder:
+     type: 'istftnet' # either hifigan or istftnet
+     resblock_kernel_sizes: [3,7,11]
+     upsample_rates: [10, 6]
+     upsample_initial_channel: 512
+     resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
+     upsample_kernel_sizes: [20, 12]
+     gen_istft_n_fft: 20
+     gen_istft_hop_size: 5
+ 
+   # speech language model config
+   slm:
+     model: 'microsoft/wavlm-base-plus'
+     sr: 16000 # sampling rate of SLM
+     hidden: 768 # hidden size of SLM
+     nlayers: 13 # number of layers of SLM
+     initial_channel: 64 # initial channels of SLM discriminator head
+ 
+   # style diffusion model config
+   diffusion:
+     embedding_mask_proba: 0.1
+     # transformer config
+     transformer:
+       num_layers: 3
+       num_heads: 8
+       head_features: 64
+       multiplier: 2
+ 
+     # diffusion distribution config
+     dist:
+       sigma_data: 0.2 # placeholder; used only if estimate_sigma_data is set to false
+       estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
+       mean: -3.0
+       std: 1.0
+ 
+ loss_params:
+   lambda_mel: 5. # mel reconstruction loss
+   lambda_gen: 1. # generator loss
+   lambda_slm: 1. # slm feature matching loss
+ 
+   lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
+   lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
+   TMA_epoch: 10 # TMA starting epoch (1st stage)
+ 
+   lambda_F0: 1. # F0 reconstruction loss (2nd stage)
+   lambda_norm: 1. # norm reconstruction loss (2nd stage)
+   lambda_dur: 1. # duration loss (2nd stage)
+   lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
+   lambda_sty: 1. # style reconstruction loss (2nd stage)
+   lambda_diff: 1. # score matching loss (2nd stage)
+ 
+   diff_epoch: 5 # style diffusion starting epoch (2nd stage)
+   joint_epoch: 10 # joint training starting epoch (2nd stage)
+ 
+ optimizer_params:
+   lr: 0.0001 # general learning rate
+   bert_lr: 0.00001 # learning rate for PLBERT
+   ft_lr: 0.00001 # learning rate for acoustic modules
+ 
+ slmadv_params:
+   min_len: 400 # minimum length of samples
+   max_len: 500 # maximum length of samples
+   batch_percentage: 0.5 # to prevent out-of-memory, use only half of the original batch size
+   iter: 10 # update the discriminator every this many generator updates
+   thresh: 5 # gradient norm above which the gradient is scaled
+   scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
+   sig: 1.5 # sigma for differentiable duration modeling
+ 
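For reference, max_len counts mel frames, so with hop_length 300 at sr 24000 the training window is 400 × 300 / 24000 = 5 seconds of audio. A minimal sketch of loading this config the way the apps below do; the recursive_munch helper mirrors the one the repo's utils provides (this re-implementation is illustrative, not the repo's exact code):

    import yaml
    from munch import Munch

    def recursive_munch(d):
        # same idea as utils.recursive_munch in the repo: nested dicts -> attribute access
        if isinstance(d, dict):
            return Munch((k, recursive_munch(v)) for k, v in d.items())
        if isinstance(d, list):
            return [recursive_munch(v) for v in d]
        return d

    config = yaml.safe_load(open("Models/gemini_vi/config_gemini_vi_en.yml"))
    model_params = recursive_munch(config["model_params"])
    print(model_params.decoder.type)  # 'istftnet'
    print(config["max_len"] * config["preprocess_params"]["spect_params"]["hop_length"]
          / config["preprocess_params"]["sr"])  # 5.0 seconds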
Models/styles_speaker_parallel.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3e92865ecb487a924469f82fdbfdb9dad41cdbd9a58c866d71a323265686ee13
+ size 2091581
__pycache__/models.cpython-310.pyc CHANGED
Binary files a/__pycache__/models.cpython-310.pyc and b/__pycache__/models.cpython-310.pyc differ
 
__pycache__/text_utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/text_utils.cpython-310.pyc and b/__pycache__/text_utils.cpython-310.pyc differ
 
app.py CHANGED
@@ -1,40 +1,30 @@
  # -*- coding: utf-8 -*-
  """
- Gradio app.py - wired to your 'inference_one' implementation
- - Reference voice: upload OR choose from train_ref/
- - Uses phonemize_text(), compute_style(), inference_one() exactly like your snippet
  """
  
- import os
- import time
- import yaml
- import numpy as np
- import torch
- import torchaudio
- import librosa
- import gradio as gr
- 
  from munch import Munch
  
- # -----------------------------
- # Reproducibility
- # -----------------------------
- torch.manual_seed(0)
- torch.backends.cudnn.benchmark = False
- torch.backends.cudnn.deterministic = True
- np.random.seed(0)
- 
- # -----------------------------
- # Device / sample-rate
- # -----------------------------
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- SR_OUT = 24000  # target audio rate for synthesis
  
- # -----------------------------
- # External modules from the project
- # -----------------------------
- from models import *  # noqa: F401,F403
- from utils import *  # noqa: F401,F403
  from models import build_model
  from text_utils import TextCleaner
  from Utils_extend_v1.PLBERT.util import load_plbert
@@ -42,24 +32,21 @@ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
  
  textcleaner = TextCleaner()
  
- # -----------------------------
- # Config / model loading
- # -----------------------------
  from huggingface_hub import hf_hub_download
- hf_hub_download(
-     repo_id="ltphuongunited/styletts2_vi",
-     filename="epoch_2nd_00058.pth",
-     local_dir="Models/multi_phoaudio_gemini",
-     local_dir_use_symlinks=False,
- )
- 
- CONFIG_PATH = os.getenv("MODEL_CONFIG", "Models/multi_phoaudio_gemini/config_phoaudio_gemini_small.yml")
- CHECKPOINT_PTH = os.getenv("MODEL_CKPT", "Models/multi_phoaudio_gemini/epoch_2nd_00058.pth")
- 
- # Load config
  config = yaml.safe_load(open(CONFIG_PATH))
  
- # Build components
  ASR_config = config.get("ASR_config", False)
  ASR_path = config.get("ASR_path", False)
  F0_path = config.get("F0_path", False)
@@ -68,34 +55,24 @@ PLBERT_dir = config.get("PLBERT_dir", False)
  text_aligner = load_ASR_models(ASR_path, ASR_config)
  pitch_extractor = load_F0_models(F0_path)
  plbert = load_plbert(PLBERT_dir)
- 
- model_params = recursive_munch(config["model_params"])
  model = build_model(model_params, text_aligner, pitch_extractor, plbert)
  
- # to device & eval
  _ = [model[k].to(DEVICE) for k in model]
  _ = [model[k].eval() for k in model]
  
- # Load checkpoint
- if not os.path.isfile(CHECKPOINT_PTH):
-     raise FileNotFoundError(f"Checkpoint not found at '{CHECKPOINT_PTH}'")
- ckpt = torch.load(CHECKPOINT_PTH, map_location="cpu")
- params = ckpt["net"]
  for key in model:
-     if key in params:
          try:
-             model[key].load_state_dict(params[key])
          except Exception:
              from collections import OrderedDict
-             state_dict = params[key]
              new_state = OrderedDict()
-             for k, v in state_dict.items():
-                 name = k[7:]  # strip 'module.' if present
-                 new_state[name] = v
              model[key].load_state_dict(new_state, strict=False)
- _ = [model[k].eval() for k in model]
  
- # Diffusion sampler
  sampler = DiffusionSampler(
      model.diffusion.diffusion,
      sampler=ADPM2Sampler(),
@@ -103,115 +80,211 @@ sampler = DiffusionSampler(
      clamp=False,
  )
  
- # -----------------------------
- # Audio helper: mel preprocessing
- # -----------------------------
- _to_mel = torchaudio.transforms.MelSpectrogram(
-     n_mels=80, n_fft=2048, win_length=1200, hop_length=300
  )
- _MEAN, _STD = -4.0, 4.0
  
  def length_to_mask(lengths: torch.LongTensor) -> torch.Tensor:
      mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
      mask = torch.gt(mask + 1, lengths.unsqueeze(1))
      return mask
  
- def preprocess(wave: np.ndarray) -> torch.Tensor:
-     """Same name as your snippet: np.float -> mel (normed)"""
-     wave_tensor = torch.from_numpy(wave).float()
-     mel_tensor = _to_mel(wave_tensor)
-     mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - _MEAN) / _STD
-     return mel_tensor
- 
- # -----------------------------
- # Phonemizer (vi)
- # -----------------------------
- import phonemizer
- vi_phonemizer = phonemizer.backend.EspeakBackend(language="vi", preserve_punctuation=True, with_stress=True)
- global_phonemizer = vi_phonemizer
- 
- def phonemize_text(text: str) -> str:
-     ps = global_phonemizer.phonemize([text])[0]
-     return ps.replace("(en)", "").replace("(vi)", "").strip()
- 
- # -----------------------------
- # Style extractor (from file path)
- # -----------------------------
- def compute_style(model, path, device):
-     """Compute style/prosody reference from a wav file path"""
-     wave, sr = librosa.load(path, sr=None, mono=True)
-     audio, _ = librosa.effects.trim(wave, top_db=30)
-     if sr != SR_OUT:
-         audio = librosa.resample(audio, sr, SR_OUT)
-     mel_tensor = preprocess(audio).to(device)
- 
-     with torch.no_grad():
-         ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
-         ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
-     return torch.cat([ref_s, ref_p], dim=1)  # [1, 256]
- 
- # Style extractor (from numpy array)
- def compute_style_from_numpy(model, arr: np.ndarray, sr: int, device):
-     if arr.ndim > 1:
-         arr = librosa.to_mono(arr.T)
-     audio, _ = librosa.effects.trim(arr, top_db=30)
-     if sr != SR_OUT:
-         audio = librosa.resample(audio, sr, SR_OUT)
-     mel_tensor = preprocess(audio).to(device)
-     with torch.no_grad():
-         ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
-         ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
-     return torch.cat([ref_s, ref_p], dim=1)
- 
- # -----------------------------
- # Inference (your exact logic)
- # -----------------------------
- # Tunables (can expose to UI later)
- ALPHA = 0.3
- BETA = 0.7
- DIFFUSION_STEPS = 5
- EMBEDDING_SCALE = 1.0
- 
- def inference_one(text, ref_feat, ipa_text=None,
-                   alpha=ALPHA, beta=BETA, diffusion_steps=DIFFUSION_STEPS, embedding_scale=EMBEDDING_SCALE):
-     # text -> phonemes -> tokens
-     ps = ipa_text if ipa_text is not None else phonemize_text(text)
      tokens = textcleaner(ps)
-     tokens.insert(0, 0)  # prepend BOS
-     tokens = torch.LongTensor(tokens).to(DEVICE).unsqueeze(0)  # [1, T]
  
      with torch.no_grad():
-         input_lengths = torch.LongTensor([tokens.shape[-1]]).to(DEVICE)
-         text_mask = length_to_mask(input_lengths).to(DEVICE)
- 
-         # encoders
          t_en = model.text_encoder(tokens, input_lengths, text_mask)
          bert_d = model.bert(tokens, attention_mask=(~text_mask).int())
          d_en = model.bert_encoder(bert_d).transpose(-1, -2)
  
-         # diffusion for style latent
-         s_pred = sampler(
-             noise=torch.randn((1, 256)).unsqueeze(1).to(DEVICE),
-             embedding=bert_d,
-             embedding_scale=embedding_scale,
-             features=ref_feat,  # [1, 256]
-             num_steps=diffusion_steps,
-         ).squeeze(1)  # [1, 256]
- 
-         s = s_pred[:, 128:]    # prosody
-         ref = s_pred[:, :128]  # timbre
- 
-         # blend with real ref features
          ref = alpha * ref + (1 - alpha) * ref_feat[:, :128]
          s = beta * s + (1 - beta) * ref_feat[:, 128:]
  
-         # duration prediction
          d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
          x, _ = model.predictor.lstm(d)
          duration = torch.sigmoid(model.predictor.duration_proj(x)).sum(axis=-1)
          pred_dur = torch.round(duration.squeeze()).clamp(min=1)
  
-         # alignment
         T = int(pred_dur.sum().item())
         pred_aln = torch.zeros(input_lengths.item(), T, device=DEVICE)
         c = 0
@@ -220,119 +293,225 @@ def inference_one(text, ref_feat, ipa_text=None,
              pred_aln[i, c:c+span] = 1.0
              c += span
  
-         # prosody enc
          en = (d.transpose(-1, -2) @ pred_aln.unsqueeze(0))
          if model_params.decoder.type == "hifigan":
-             asr_new = torch.zeros_like(en); asr_new[:, :, 0] = en[:, :, 0]; asr_new[:, :, 1:] = en[:, :, 0:-1]; en = asr_new
  
          F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
- 
-         # content (ASR-aligned)
          asr = (t_en @ pred_aln.unsqueeze(0))
          if model_params.decoder.type == "hifigan":
-             asr_new = torch.zeros_like(asr); asr_new[:, :, 0] = asr[:, :, 0]; asr_new[:, :, 1:] = asr[:, :, 0:-1]; asr = asr_new
  
-         # decode
          out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))
  
      wav = out.squeeze().detach().cpu().numpy()
      if wav.shape[-1] > 50:
-         wav = wav[..., :-50]
-     return wav, ps
- 
- # -----------------------------
  # Gradio UI
- # -----------------------------
- 
- 
- SR_OUT = 24000
- ROOT_REF = "ref_voice"
- EXTS = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
- 
- # -------- scan ref_voice/<id>_<speaker>/*.wav --------
- def scan_ref_voice(root=ROOT_REF):
-     """
-     return:
-       speakers: list[str]                 # e.g.: ["0_Fonos.vn", "1_James_A._Robinson", ...]
-       files_by_spk: dict[str, list[str]]  # speaker_dir -> [full_path, ...]
-     """
-     speakers, files_by_spk = [], {}
-     if not os.path.isdir(root):
-         return speakers, files_by_spk
- 
-     for spk_dir in sorted(os.listdir(root)):
-         full_dir = os.path.join(root, spk_dir)
-         if not os.path.isdir(full_dir) or spk_dir.startswith("."):
-             continue
-         lst = []
-         for fn in sorted(os.listdir(full_dir)):
-             if os.path.splitext(fn)[1].lower() in EXTS:
-                 lst.append(os.path.join(full_dir, fn))
-         if lst:
-             speakers.append(spk_dir)
-             files_by_spk[spk_dir] = lst
-     return speakers, files_by_spk
- 
- SPEAKERS, FILES_BY_SPK = scan_ref_voice()
- 
- with gr.Blocks(title="StyleTTS2-vi Demo ✨") as demo:
-     gr.Markdown("# StyleTTS2-vi Demo ✨")
  
      with gr.Row():
          with gr.Column():
-             text_inp = gr.Textbox(label="Text", lines=4,
-                                   value="Thời tiết hôm nay tại Hà Nội, nhiệt độ khoảng 27 độ C, có nắng nhẹ, rất hợp lý để mình đi dạo công viên nhé.")
- 
-             # --- a single audio box (takes a filepath) ---
-             ref_audio = gr.Audio(
-                 label="Reference Audio",
-                 type="filepath",                   # receives a file path
-                 sources=["upload","microphone"],   # still allows upload/mic
-                 interactive=True,
-             )
-             ref_path = gr.Textbox(label="Đường dẫn reference", interactive=False)
- 
-             # --- pick a speaker -> show its files ---
-             spk_dd = gr.Dropdown(
-                 label="Speaker",
-                 choices=["(None)"] + SPEAKERS,
-                 value="(None)",
-             )
-             file_dd = gr.Dropdown(
-                 label="Voice in speaker",
-                 choices=["(None)"],
-                 value="(None)",
             )
  
-             # when a speaker is picked -> refresh the file list
-             def on_pick_speaker(spk):
-                 if spk == "(None)":
-                     return gr.update(choices=["(None)"], value="(None)")
-                 files = FILES_BY_SPK.get(spk, [])
-                 # show only basenames to keep the list compact
-                 labels = [os.path.basename(p) for p in files]
-                 # labels map to paths by index; preselect the first entry
-                 return gr.update(choices=labels, value=(labels[0] if labels else "(None)"))
- 
-             spk_dd.change(on_pick_speaker, inputs=spk_dd, outputs=file_dd)
- 
-             # map a label (basename) -> the full path for the current speaker
-             def on_pick_file(spk, label):
-                 if spk == "(None)" or label == "(None)":
-                     return gr.update(value=None), ""
-                 files = FILES_BY_SPK.get(spk, [])
-                 # find the file by basename
-                 for p in files:
-                     if os.path.basename(p) == label:
-                         return gr.update(value=p), p  # set the Audio + display the path
-                 return gr.update(value=None), ""
- 
-             file_dd.change(on_pick_file, inputs=[spk_dd, file_dd], outputs=[ref_audio, ref_path])
- 
-             # if the user uploads / records, show the temp file path right away
-             def on_audio_changed(fp):
-                 return fp or ""
-             ref_audio.change(on_audio_changed, inputs=ref_audio, outputs=ref_path)
  
              btn = gr.Button("Đọc 🔊🔥", variant="primary")
  
@@ -340,35 +519,25 @@ with gr.Blocks(title="StyleTTS2-vi Demo ✨") as demo:
              out_audio = gr.Audio(label="Synthesised Audio", type="numpy")
              metrics = gr.JSON(label="Metrics")
  
-     # ---- Inference: run from a filepath ----
-     def _run(text, ref_fp):
-         # ref_fp is a string path (because type='filepath')
-         if isinstance(ref_fp, str) and os.path.isfile(ref_fp):
-             wav, _ = librosa.load(ref_fp, sr=SR_OUT, mono=True)
-             ref_feat = compute_style_from_numpy(model, wav, SR_OUT, DEVICE)
-             ref_src = ref_fp
-         else:
-             ref_feat = torch.zeros(1, 256).to(DEVICE)
-             ref_src = "(None)"
- 
-         t0 = time.time()
-         wav, ps = inference_one(text, ref_feat)
-         wav = wav.astype(np.float32)
-         gen_time = time.time() - t0
-         rtf = gen_time / max(1e-6, len(wav)/SR_OUT)
- 
-         info = {
-             "Ref path": ref_src,
-             "Phonemes": ps,
-             "Sample rate": SR_OUT,
-             "RTF": round(float(rtf), 3),
-             "Device": DEVICE,
-         }
-         return (SR_OUT, wav), info
- 
-     btn.click(_run, inputs=[text_inp, ref_audio], outputs=[out_audio, metrics])
- 
  
  if __name__ == "__main__":
      demo.launch()

  # -*- coding: utf-8 -*-
  """
+ Gradio app.py - StyleTTS2-vi with precomputed style embeddings (.pth)
+ - UI for alpha/beta/metrics
+ - Style Mixer: 4 fixed slots (Kore, Puck, Algenib, Leda), weight-only controls; auto-normalize
+ - Always shows the 4 reference samples (accordion)
+ - No more speaker dropdown & automatic reference sample
  """
  
+ import os, re, glob, time, yaml, torch, librosa, numpy as np, gradio as gr
  from munch import Munch
+ from soe_vinorm import SoeNormalizer
  
+ # ==============================================================
+ # Basic configuration
+ # ==============================================================
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ SR_OUT = 24000
+ ALPHA, BETA, DIFFUSION_STEPS, EMBEDDING_SCALE = 0.0, 0.0, 5, 1.0
  
+ REF_DIR = "ref_voice"  # directory holding the reference audio samples (.wav)
+ 
+ # ==============================================================
+ # StyleTTS2 module imports
+ # ==============================================================
+ from models import *
+ from utils import *
  from models import build_model
  from text_utils import TextCleaner
  from Utils_extend_v1.PLBERT.util import load_plbert
  
  textcleaner = TextCleaner()
  
+ # ==============================================================
+ # Load model and checkpoint
+ # ==============================================================
  from huggingface_hub import hf_hub_download
+ # hf_hub_download(
+ #     repo_id="ltphuongunited/styletts2_vi",
+ #     filename="gemini_2nd_00045.pth",
+ #     local_dir="Models/gemini_vi",
+ #     local_dir_use_symlinks=False,
+ # )
+ 
+ CHECKPOINT_PTH = "Models/gemini_vi/gemini_2nd_00045.pth"
+ CONFIG_PATH = "Models/gemini_vi/config_gemini_vi_en.yml"
  config = yaml.safe_load(open(CONFIG_PATH))
  
  ASR_config = config.get("ASR_config", False)
  ASR_path = config.get("ASR_path", False)
  F0_path = config.get("F0_path", False)
  
  text_aligner = load_ASR_models(ASR_path, ASR_config)
  pitch_extractor = load_F0_models(F0_path)
  plbert = load_plbert(PLBERT_dir)
+ model_params = recursive_munch(config["model_params"])
  model = build_model(model_params, text_aligner, pitch_extractor, plbert)
  
  _ = [model[k].to(DEVICE) for k in model]
  _ = [model[k].eval() for k in model]
  
+ ckpt = torch.load(CHECKPOINT_PTH, map_location="cpu")["net"]
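+ # NOTE: checkpoints saved from nn.DataParallel prefix every weight with
+ # 'module.'; the except branch below strips that prefix (k[7:]) and retries
+ # the load with strict=False.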
  for key in model:
+     if key in ckpt:
          try:
+             model[key].load_state_dict(ckpt[key])
          except Exception:
              from collections import OrderedDict
              new_state = OrderedDict()
+             for k, v in ckpt[key].items():
+                 new_state[k[7:]] = v
              model[key].load_state_dict(new_state, strict=False)
  
  sampler = DiffusionSampler(
      model.diffusion.diffusion,
      sampler=ADPM2Sampler(),
      sigma_schedule=KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0),
      clamp=False,
  )
  
+ # ==============================================================
84
+ # Phonemizer
85
+ # ==============================================================
86
+ import phonemizer
87
+ vi_phonemizer = phonemizer.backend.EspeakBackend(
88
+ language="vi", preserve_punctuation=True, with_stress=True
89
  )
90
+
91
+ def phonemize_text(text: str) -> str:
92
+ ps = vi_phonemizer.phonemize([text])[0]
93
+ return ps.replace("(en)", "").replace("(vi)", "").strip()
94
 
95
  def length_to_mask(lengths: torch.LongTensor) -> torch.Tensor:
96
  mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
97
  mask = torch.gt(mask + 1, lengths.unsqueeze(1))
98
  return mask
99
 
100
+ # ==============================================================
101
+ # Load style embeddings đã tính sẵn
102
+ # ==============================================================
103
+ STYLE_PTH = "Models/styles_speaker_parallel.pth"
104
+ print(f"Loading precomputed styles: {STYLE_PTH}")
105
+ styles_dict = torch.load(STYLE_PTH, map_location=DEVICE)
106
+
107
+ # fallback speaker nếu mixer rỗng
108
+ SPEAKER_ORDER_PREF = ["Kore", "Puck", "Algenib", "Leda"]
109
+ DEFAULT_SPK = next((s for s in SPEAKER_ORDER_PREF if s in styles_dict), list(styles_dict.keys())[0])
110
+
111
+ def get_style_by_length(speaker: str, phoneme_len: int):
112
+ spk_tensor = styles_dict[speaker] # [510, 1, 256] hoặc [510, 256]
113
+ idx = min(max(phoneme_len, 1), spk_tensor.shape[0]) - 1
114
+ feat = spk_tensor[idx]
115
+ # ép về [1,256]
116
+ if feat.ndim == 3: # [1,1,256]
117
+ feat = feat.squeeze(0)
118
+ if feat.ndim == 2: # [1,256]
119
+ feat = feat.squeeze(0)
120
+ return feat.unsqueeze(0).to(DEVICE) # [1,256]
121
+
122
+ # ==============================================================
123
+ # Style mixing utils
124
+ # ==============================================================
125
+ def parse_mix_spec(spec: str) -> dict:
126
+ """Parse 'Kore:0.75,Puck:0.25' -> {'Kore':0.75,'Puck':0.25} (lọc lỗi, gộp trùng)."""
127
+ mix = {}
128
+ if not spec or not isinstance(spec, str):
129
+ return mix
130
+ for part in spec.split(","):
131
+ if ":" not in part:
132
+ continue
133
+ k, v = part.split(":", 1)
134
+ k = (k or "").strip()
135
+ if not k:
136
+ continue
137
+ try:
138
+ w = float((v or "").strip())
139
+ except Exception:
140
+ continue
141
+ if not np.isfinite(w) or w <= 0:
142
+ continue
143
+ mix[k] = mix.get(k, 0.0) + w
144
+ return mix
145
+
146
+ def get_style_mixed_by_length(mix_dict: dict, phoneme_len: int):
147
+ """Trộn style của nhiều speaker theo trọng số. Trả về [1,256] trên DEVICE."""
148
+ if not mix_dict:
149
+ return get_style_by_length(DEFAULT_SPK, phoneme_len)
150
+
151
+ total = sum(max(0.0, float(w)) for w in mix_dict.values())
152
+ if total <= 0:
153
+ return get_style_by_length(DEFAULT_SPK, phoneme_len)
154
+
155
+ mix_feat = None
156
+ for spk, w in mix_dict.items():
157
+ if spk not in styles_dict:
158
+ print(f"[WARN] Speaker '{spk}' không trong styles_dict, bỏ qua.")
159
+ continue
160
+ feat_i = get_style_by_length(spk, phoneme_len) # [1,256]
161
+ wi = float(w) / total
162
+ mix_feat = feat_i * wi if mix_feat is None else mix_feat + feat_i * wi
163
+
164
+ if mix_feat is None:
165
+ return get_style_by_length(DEFAULT_SPK, phoneme_len)
166
+ return mix_feat # [1,256]
167
+
168
+ # ==============================================================
169
+ # Audio postprocess (librosa): trim + denoise + remove internal silence
170
+ # ==============================================================
171
+ def _simple_spectral_denoise(y, sr, n_fft=1024, hop=256, prop_decrease=0.8):
172
+ if y.size == 0:
173
+ return y
174
+ D = librosa.stft(y, n_fft=n_fft, hop_length=hop, win_length=n_fft)
175
+ S = np.abs(D)
176
+ noise = np.median(S, axis=1, keepdims=True)
177
+ S_clean = S - prop_decrease * noise
178
+ S_clean = np.maximum(S_clean, 0.0)
179
+ gain = S_clean / (S + 1e-8)
180
+ D_denoised = D * gain
181
+ y_out = librosa.istft(D_denoised, hop_length=hop, win_length=n_fft, length=len(y))
182
+ return y_out
183
+
184
+ def _concat_with_crossfade(segments, crossfade_samples=0):
185
+ if not segments:
186
+ return np.array([], dtype=np.float32)
187
+ out = segments[0].astype(np.float32, copy=True)
188
+ for seg in segments[1:]:
189
+ seg = seg.astype(np.float32, copy=False)
190
+ if crossfade_samples > 0 and out.size > 0 and seg.size > 0:
191
+ cf = min(crossfade_samples, out.size, seg.size)
192
+ fade_out = np.linspace(1.0, 0.0, cf, dtype=np.float32)
193
+ fade_in = 1.0 - fade_out
194
+ tail = out[-cf:] * fade_out + seg[:cf] * fade_in
195
+ out = np.concatenate([out[:-cf], tail, seg[cf:]], axis=0)
196
+ else:
197
+ out = np.concatenate([out, seg], axis=0)
198
+ return out
199
+
200
+ def _reduce_internal_silence(y, sr, top_db=30, min_keep_ms=40, crossfade_ms=8):
201
+ if y.size == 0:
202
+ return y
203
+ intervals = librosa.effects.split(y, top_db=top_db)
204
+ if intervals.size == 0:
205
+ return y
206
+ min_keep = int(sr * (min_keep_ms / 1000.0))
207
+ segs = []
208
+ for s, e in intervals:
209
+ if e - s >= min_keep:
210
+ segs.append(y[s:e])
211
+ if not segs:
212
+ return y
213
+ crossfade = int(sr * (crossfade_ms / 1000.0))
214
+ y_out = _concat_with_crossfade(segs, crossfade_samples=crossfade)
215
+ return y_out
216
+
217
+ def postprocess_audio(y, sr,
218
+ trim_top_db=30,
219
+ denoise=True,
220
+ denoise_n_fft=1024,
221
+ denoise_hop=256,
222
+ denoise_strength=0.8,
223
+ remove_internal_silence=True,
224
+ split_top_db=30,
225
+ min_keep_ms=40,
226
+ crossfade_ms=8):
227
+ if y.size == 0:
228
+ return y.astype(np.float32)
229
+ y_trim, _ = librosa.effects.trim(y, top_db=trim_top_db)
230
+ if denoise:
231
+ y_trim = _simple_spectral_denoise(
232
+ y_trim, sr, n_fft=denoise_n_fft, hop=denoise_hop, prop_decrease=denoise_strength
233
+ )
234
+ if remove_internal_silence:
235
+ y_trim = _reduce_internal_silence(
236
+ y_trim, sr, top_db=split_top_db, min_keep_ms=min_keep_ms, crossfade_ms=crossfade_ms
237
+ )
238
+ y_trim = np.nan_to_num(y_trim, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
239
+ m = np.max(np.abs(y_trim)) + 1e-8
240
+ if m > 1.0:
241
+ y_trim = y_trim / m
242
+ return y_trim
243
+
+ # ==============================================================
+ # Inference core
+ # ==============================================================
+ def inference_one(text, ref_feat, alpha=ALPHA, beta=BETA,
+                   diffusion_steps=DIFFUSION_STEPS, embedding_scale=EMBEDDING_SCALE):
+     ps = phonemize_text(text)
      tokens = textcleaner(ps)
+     tokens.insert(0, 0)
+     tokens = torch.LongTensor(tokens).unsqueeze(0).to(DEVICE)
+     input_lengths = torch.LongTensor([tokens.shape[-1]]).to(DEVICE)
+     text_mask = length_to_mask(input_lengths).to(DEVICE)
  
      with torch.no_grad():
          t_en = model.text_encoder(tokens, input_lengths, text_mask)
          bert_d = model.bert(tokens, attention_mask=(~text_mask).int())
          d_en = model.bert_encoder(bert_d).transpose(-1, -2)
  
+         if alpha == 0 and beta == 0:
+             s_pred = ref_feat.clone()  # [1, 256]
+         else:
+             s_pred = sampler(
+                 noise=torch.randn((1, 256)).unsqueeze(1).to(DEVICE),
+                 embedding=bert_d,
+                 embedding_scale=embedding_scale,
+                 features=ref_feat,  # [1, 256]
+                 num_steps=diffusion_steps,
+             ).squeeze(1)  # [1, 256]
+ 
+         s, ref = s_pred[:, 128:], s_pred[:, :128]
          ref = alpha * ref + (1 - alpha) * ref_feat[:, :128]
          s = beta * s + (1 - beta) * ref_feat[:, 128:]
  
+         # --- Metrics (cosine) ---
+         def cosine_sim(a, b):
+             return torch.nn.functional.cosine_similarity(a, b, dim=1).mean().item()
+         simi_timbre = cosine_sim(s_pred[:, :128], ref_feat[:, :128])
+         simi_prosody = cosine_sim(s_pred[:, 128:], ref_feat[:, 128:])
+ 
+         # --- Duration / Alignment ---
          d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
          x, _ = model.predictor.lstm(d)
          duration = torch.sigmoid(model.predictor.duration_proj(x)).sum(axis=-1)
          pred_dur = torch.round(duration.squeeze()).clamp(min=1)
  
          T = int(pred_dur.sum().item())
          pred_aln = torch.zeros(input_lengths.item(), T, device=DEVICE)
          c = 0
          for i in range(input_lengths.item()):
              span = int(pred_dur[i].item())
              pred_aln[i, c:c+span] = 1.0
              c += span
  
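+         # Expand phoneme-level features to frame level with the hard alignment;
+         # for the hifigan decoder, the torch.cat below shifts features right by
+         # one frame (equivalent to the old zeros_like/copy construction).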
          en = (d.transpose(-1, -2) @ pred_aln.unsqueeze(0))
          if model_params.decoder.type == "hifigan":
+             en = torch.cat([en[:, :, :1], en[:, :, :-1]], dim=2)
  
          F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
          asr = (t_en @ pred_aln.unsqueeze(0))
          if model_params.decoder.type == "hifigan":
+             asr = torch.cat([asr[:, :, :1], asr[:, :, :-1]], dim=2)
  
          out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))
  
      wav = out.squeeze().detach().cpu().numpy()
      if wav.shape[-1] > 50:
+         wav = wav[:-50]
+ 
+     # Post-processing: trim + denoise + drop internal silence
+     wav = postprocess_audio(
+         wav, SR_OUT,
+         trim_top_db=30,
+         denoise=True,
+         denoise_n_fft=1024, denoise_hop=256, denoise_strength=0.8,
+         remove_internal_silence=True,
+         split_top_db=30, min_keep_ms=40, crossfade_ms=8
+     )
+     return wav, ps, simi_timbre, simi_prosody
+ 
+ # ==============================================================
+ # Ref-audio mapping (scan ./ref_voice for a sample file per speaker)
+ # ==============================================================
+ def _norm(s: str) -> str:
+     import unicodedata
+     s = unicodedata.normalize("NFKD", s)
+     s = "".join([c for c in s if not unicodedata.combining(c)])
+     s = s.lower()
+     s = re.sub(r"[^a-z0-9_\-\.]+", "", s)
+     return s
+ 
+ def build_ref_map(ref_dir: str) -> dict:
+     paths = glob.glob(os.path.join(ref_dir, "**", "*.wav"), recursive=True)
+     by_name = {}
+     for p in paths:
+         fname = os.path.basename(p)
+         by_name[_norm(fname)] = p
+ 
+     spk_map = {}
+     speakers = list(styles_dict.keys()) if isinstance(styles_dict, dict) else ["Kore", "Algenib", "Puck", "Leda"]
+ 
+     for spk in speakers:
+         spk_n = _norm(spk)
+         hit = None
+         for k, p in by_name.items():
+             if f"_{spk_n}_" in k:
+                 hit = p
+                 break
+         if not hit:
+             for k, p in by_name.items():
+                 if spk_n in k:
+                     hit = p
+                     break
+         if hit:
+             spk_map[spk] = hit
+     return spk_map
+ 
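+ # e.g. "0000000_Kore_Quân_sự.wav" normalizes to "0000000_kore_quan_su.wav",
+ # so the "_kore_" substring match resolves speaker "Kore" to that file.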
360
+
361
+ def get_ref_path_for_speaker(spk: str):
362
+ return REF_MAP.get(spk)
363
+
364
+ # ==============================================================
365
+ # Wrapper cho Gradio (nhận speaker_mix_spec là string ẩn)
366
+ # ==============================================================
367
+ def run_inference(text, alpha, beta, speaker_mix_spec):
368
+ normalizer = SoeNormalizer()
369
+ text = normalizer.normalize(text).replace(" ,", ",").replace(" .", ".")
370
+
371
+ ps = phonemize_text(text)
372
+ phoneme_len = len(ps.replace(" ", ""))
373
+
374
+ mix_dict = parse_mix_spec(speaker_mix_spec)
375
+ if len(mix_dict) > 0:
376
+ ref_feat = get_style_mixed_by_length(mix_dict, phoneme_len)
377
+ ref_idx = min(phoneme_len, 510)
378
+ total = sum(mix_dict.values())
379
+ mix_info = {k: round(float(v / total), 3) for k, v in mix_dict.items()}
380
+ chosen_speakers = list(mix_dict.keys())
381
+ else:
382
+ ref_feat = get_style_by_length(DEFAULT_SPK, phoneme_len)
383
+ ref_idx = min(phoneme_len, 510)
384
+ mix_info = {DEFAULT_SPK: 1.0}
385
+ chosen_speakers = [DEFAULT_SPK]
386
+
387
+ t0 = time.time()
388
+ wav, ps_out, simi_timbre, simi_prosody = inference_one(
389
+ text, ref_feat, alpha=float(alpha), beta=float(beta)
390
+ )
391
+ gen_time = time.time() - t0
392
+ rtf = gen_time / max(1e-6, len(wav) / SR_OUT)
393
+
394
+ info = {
395
+ "Text after soe_vinorms:": text,
396
+ "Speakers": chosen_speakers,
397
+ "Mix weights (normalized)": mix_info,
398
+ "Phonemes": ps_out,
399
+ "Phoneme length": phoneme_len,
400
+ "Ref index": ref_idx,
401
+ "simi_timbre": round(float(simi_timbre), 4),
402
+ "simi_prosody": round(float(simi_prosody), 4),
403
+ "alpha": float(alpha),
404
+ "beta": float(beta),
405
+ "RTF": round(float(rtf), 3),
406
+ "Device": DEVICE,
407
+ }
408
+ return (SR_OUT, wav.astype(np.float32)), info
409
+
+ # ==============================================================
+ # UI helper: build the FIXED mix-spec over the 4 speakers
+ # ==============================================================
+ def _build_mix_spec_ui_fixed(normalize, w1, w2, w3, w4, order):
+     pairs = [(order[0], float(w1 or 0.0)),
+              (order[1], float(w2 or 0.0)),
+              (order[2], float(w3 or 0.0)),
+              (order[3], float(w4 or 0.0))]
+     pairs = [(s, w) for s, w in pairs if w > 0]
+ 
+     if not pairs:
+         return "", {}, "**Sum:** 0.000"
+ 
+     total = sum(w for _, w in pairs)
+     if normalize and total > 0:
+         pairs = [(s, w/total) for s, w in pairs]
+ 
+     acc = {}
+     for s, w in pairs:
+         acc[s] = acc.get(s, 0.0) + w
+ 
+     mix_spec = ",".join([f"{s}:{w:.4f}" for s, w in acc.items()])
+     mix_view = {"weights": {s: round(w, 3) for s, w in acc.items()}, "normalized": bool(normalize)}
+     sum_md = f"**Sum:** {round(sum(acc.values()), 3)}"
+     return mix_spec, mix_view, sum_md
+ 
+ # ==============================================================
  # Gradio UI
+ # ==============================================================
+ with gr.Blocks(title="StyleTTS2-vi Demo") as demo:
+     gr.Markdown("# StyleTTS2-vi Demo")
  
      with gr.Row():
          with gr.Column():
+             text_inp = gr.Textbox(
+                 label="Text",
+                 lines=4,
+                 value="Trăng treo lửng trên đỉnh núi chơ vơ, ánh sáng bàng bạc phủ lên bãi đá ngổn ngang. Con dế thổn thức trong khe cỏ, tiếng gió hun hút lùa qua hốc núi trập trùng. Dưới thung lũng, đàn trâu gặm cỏ ung dung, hơi sương vẩn đục, lảng bảng giữa đồng khuya tĩnh mịch."
              )
  
+             # Speakers available in styles_dict
+             spk_choices = list(styles_dict.keys()) if isinstance(styles_dict, dict) else ["Kore", "Algenib", "Puck", "Leda"]
+ 
+             # FIXED slot order for the mixer
+             fixed_order = [s for s in ["Kore", "Puck", "Algenib", "Leda"] if s in spk_choices]
+             if len(fixed_order) < 4:
+                 for s in spk_choices:
+                     if s not in fixed_order:
+                         fixed_order.append(s)
+                     if len(fixed_order) == 4:
+                         break
+ 
+             # === Always show the 4 voice samples ===
+             with gr.Accordion("Reference samples", open=True):
+                 with gr.Row():
+                     spk0 = fixed_order[0] if len(fixed_order) > 0 else "Kore"
+                     spk1 = fixed_order[1] if len(fixed_order) > 1 else "Puck"
+                     with gr.Column():
+                         gr.Markdown(f"**{spk0}**")
+                         gr.Audio(value=get_ref_path_for_speaker(spk0), label=f"{spk0} sample", type="filepath", interactive=False)
+                     with gr.Column():
+                         gr.Markdown(f"**{spk1}**")
+                         gr.Audio(value=get_ref_path_for_speaker(spk1), label=f"{spk1} sample", type="filepath", interactive=False)
+                 with gr.Row():
+                     spk2 = fixed_order[2] if len(fixed_order) > 2 else "Algenib"
+                     spk3 = fixed_order[3] if len(fixed_order) > 3 else "Leda"
+                     with gr.Column():
+                         gr.Markdown(f"**{spk2}**")
+                         gr.Audio(value=get_ref_path_for_speaker(spk2), label=f"{spk2} sample", type="filepath", interactive=False)
+                     with gr.Column():
+                         gr.Markdown(f"**{spk3}**")
+                         gr.Audio(value=get_ref_path_for_speaker(spk3), label=f"{spk3} sample", type="filepath", interactive=False)
+ 
+             # ---- Style Mixer with 4 fixed slots ----
+             with gr.Accordion("Style Mixer", open=True):
+                 normalize_ck = gr.Checkbox(value=True, label="Normalize weights to 1")
+ 
+                 # Row 1: Kore & Puck
+                 with gr.Row(equal_height=True):
+                     with gr.Column():
+                         gr.Markdown(f"**{fixed_order[0]}**")
+                         w1 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Weight 1", container=False)
+                     with gr.Column():
+                         gr.Markdown(f"**{fixed_order[1]}**")
+                         w2 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Weight 2", container=False)
+ 
+                 # Row 2: Algenib & Leda
+                 with gr.Row(equal_height=True):
+                     with gr.Column():
+                         gr.Markdown(f"**{fixed_order[2]}**")
+                         w3 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Weight 3", container=False)
+                     with gr.Column():
+                         gr.Markdown(f"**{fixed_order[3]}**")
+                         w4 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Weight 4", container=False)
+ 
+                 mix_sum_md = gr.Markdown("**Sum:** 0.000")
+                 mix_view_json = gr.JSON(label="Mixer weights (view)")
+                 mix_spec_state = gr.State("")        # mix-spec string for the backend
+                 order_state = gr.State(fixed_order)  # keeps the fixed slot order
+ 
+             with gr.Row():
+                 alpha_n = gr.Number(value=ALPHA, label="alpha diffusion (0-1, timbre)", precision=3)
+                 beta_n = gr.Number(value=BETA, label="beta diffusion (0-1, prosody)", precision=3)
  
              btn = gr.Button("Đọc 🔊🔥", variant="primary")
  
          with gr.Column():
              out_audio = gr.Audio(label="Synthesised Audio", type="numpy")
              metrics = gr.JSON(label="Metrics")
  
+     # Any weight/normalize change -> rebuild the fixed spec + update sum/JSON view
+     def _ui_build_wrapper_fixed(normalize, w1, w2, w3, w4, order):
+         spec, view, summ = _build_mix_spec_ui_fixed(normalize, w1, w2, w3, w4, order)
+         return spec, view, summ
+ 
+     for comp in [normalize_ck, w1, w2, w3, w4]:
+         comp.change(
+             _ui_build_wrapper_fixed,
+             inputs=[normalize_ck, w1, w2, w3, w4, order_state],
+             outputs=[mix_spec_state, mix_view_json, mix_sum_md]
+         )
+ 
+     # Read button: uses mix_spec_state; if empty => fall back to DEFAULT_SPK
+     btn.click(
+         run_inference,
+         inputs=[text_inp, alpha_n, beta_n, mix_spec_state],
+         outputs=[out_audio, metrics]
+     )
  
  if __name__ == "__main__":
      demo.launch()
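The fixed-slot mixer ultimately reduces to a convex combination of per-speaker style vectors. A minimal, self-contained sketch of that math, using plain [1, 256] tensors in place of the real length-indexed style bank (the names here are illustrative, not from the repo):

    import torch

    def mix_styles(styles, weights):
        # styles: speaker -> [1, 256] style vector; weights: speaker -> raw weight
        total = sum(w for w in weights.values() if w > 0)
        out = torch.zeros(1, 256)
        for spk, w in weights.items():
            if w > 0 and spk in styles:
                out += (w / total) * styles[spk]
        return out  # first 128 dims feed the timbre path, last 128 the prosody path

    styles = {"Kore": torch.randn(1, 256), "Puck": torch.randn(1, 256)}
    mixed = mix_styles(styles, {"Kore": 0.75, "Puck": 0.25})
    print(mixed.shape)  # torch.Size([1, 256])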
app2.py ADDED
@@ -0,0 +1,419 @@
+ # -*- coding: utf-8 -*-
+ """
+ Gradio app.py - wired to your 'inference_one' implementation
+ - Reference voice: upload OR choose from train_ref/
+ - Uses phonemize_text(), compute_style(), inference_one() exactly like your snippet
+ - NOW: adds UI sliders for alpha and beta and threads them into inference
+ """
+ 
+ import os
+ import time
+ import yaml
+ import numpy as np
+ import torch
+ import torchaudio
+ import librosa
+ import gradio as gr
+ 
+ from munch import Munch
+ 
+ # -----------------------------
+ # Reproducibility
+ # -----------------------------
+ torch.manual_seed(0)
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = True
+ np.random.seed(0)
+ 
+ # -----------------------------
+ # Device / sample-rate
+ # -----------------------------
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ SR_OUT = 24000  # target audio rate for synthesis
+ 
+ # -----------------------------
+ # External modules from the project
+ # -----------------------------
+ from models import *  # noqa: F401,F403
+ from utils import *  # noqa: F401,F403
+ from models import build_model
+ from text_utils import TextCleaner
+ from Utils_extend_v1.PLBERT.util import load_plbert
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
+ 
+ textcleaner = TextCleaner()
+ 
+ # -----------------------------
+ # Config / model loading
+ # -----------------------------
+ from huggingface_hub import hf_hub_download
+ hf_hub_download(
+     repo_id="ltphuongunited/styletts2_vi",
+     filename="epoch_2nd_00058.pth",
+     local_dir="Models/multi_phoaudio_gemini",
+     local_dir_use_symlinks=False,
+ )
+ 
+ # CONFIG_PATH = os.getenv("MODEL_CONFIG", "Models/multi_phoaudio_gemini/config_phoaudio_gemini_small.yml")
+ # CHECKPOINT_PTH = os.getenv("MODEL_CKPT", "Models/multi_phoaudio_gemini/epoch_2nd_00058.pth")
+ 
+ CHECKPOINT_PTH = "Models/gemini_vi/gemini_2nd_00045.pth"
+ CONFIG_PATH = "Models/gemini_vi/config_gemini_vi_en.yml"
+ 
+ # Load config
+ config = yaml.safe_load(open(CONFIG_PATH))
+ 
+ # Build components
+ ASR_config = config.get("ASR_config", False)
+ ASR_path = config.get("ASR_path", False)
+ F0_path = config.get("F0_path", False)
+ PLBERT_dir = config.get("PLBERT_dir", False)
+ 
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
+ pitch_extractor = load_F0_models(F0_path)
+ plbert = load_plbert(PLBERT_dir)
+ 
+ model_params = recursive_munch(config["model_params"])
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
+ 
+ # to device & eval
+ _ = [model[k].to(DEVICE) for k in model]
+ _ = [model[k].eval() for k in model]
+ 
+ # Load checkpoint
+ if not os.path.isfile(CHECKPOINT_PTH):
+     raise FileNotFoundError(f"Checkpoint not found at '{CHECKPOINT_PTH}'")
+ ckpt = torch.load(CHECKPOINT_PTH, map_location="cpu")
+ params = ckpt["net"]
+ for key in model:
+     if key in params:
+         try:
+             model[key].load_state_dict(params[key])
+         except Exception:
+             from collections import OrderedDict
+             state_dict = params[key]
+             new_state = OrderedDict()
+             for k, v in state_dict.items():
+                 name = k[7:]  # strip 'module.' if present
+                 new_state[name] = v
+             model[key].load_state_dict(new_state, strict=False)
+ _ = [model[k].eval() for k in model]
+ 
+ # Diffusion sampler
+ sampler = DiffusionSampler(
+     model.diffusion.diffusion,
+     sampler=ADPM2Sampler(),
+     sigma_schedule=KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0),
+     clamp=False,
+ )
+ 
+ # -----------------------------
+ # Audio helper: mel preprocessing
+ # -----------------------------
+ _to_mel = torchaudio.transforms.MelSpectrogram(
+     n_mels=80, n_fft=2048, win_length=1200, hop_length=300
+ )
+ _MEAN, _STD = -4.0, 4.0
+ 
+ def length_to_mask(lengths: torch.LongTensor) -> torch.Tensor:
+     mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+     mask = torch.gt(mask + 1, lengths.unsqueeze(1))
+     return mask
+ 
+ def preprocess(wave: np.ndarray) -> torch.Tensor:
+     """Same name as your snippet: np.float -> mel (normed)"""
+     wave_tensor = torch.from_numpy(wave).float()
+     mel_tensor = _to_mel(wave_tensor)
+     mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - _MEAN) / _STD
+     return mel_tensor
+ 
+ # -----------------------------
+ # Phonemizer (vi)
+ # -----------------------------
+ import phonemizer
+ vi_phonemizer = phonemizer.backend.EspeakBackend(language="vi", preserve_punctuation=True, with_stress=True)
+ global_phonemizer = vi_phonemizer
+ 
+ def phonemize_text(text: str) -> str:
+     ps = global_phonemizer.phonemize([text])[0]
+     return ps.replace("(en)", "").replace("(vi)", "").strip()
+ 
+ # -----------------------------
+ # Style extractor (from file path)
+ # -----------------------------
+ 
+ def compute_style(model, path, device):
+     """Compute style/prosody reference from a wav file path"""
+     wave, sr = librosa.load(path, sr=None, mono=True)
+     audio, _ = librosa.effects.trim(wave, top_db=30)
+     if sr != SR_OUT:
+         audio = librosa.resample(audio, sr, SR_OUT)
+     mel_tensor = preprocess(audio).to(device)
+ 
+     with torch.no_grad():
+         ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
+         ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
+     return torch.cat([ref_s, ref_p], dim=1)  # [1, 256]
+ 
+ # Style extractor (from numpy array)
+ 
+ def compute_style_from_numpy(model, arr: np.ndarray, sr: int, device):
+     if arr.ndim > 1:
+         arr = librosa.to_mono(arr.T)
+     audio, _ = librosa.effects.trim(arr, top_db=30)
+     if sr != SR_OUT:
+         audio = librosa.resample(audio, sr, SR_OUT)
+     mel_tensor = preprocess(audio).to(device)
+     with torch.no_grad():
+         ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
+         ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
+     return torch.cat([ref_s, ref_p], dim=1)
+ 
+ # -----------------------------
+ # Inference (your exact logic)
+ # -----------------------------
+ # Tunables (still as defaults; UI will override)
+ ALPHA = 0.3
+ BETA = 0.7
+ DIFFUSION_STEPS = 5
+ EMBEDDING_SCALE = 1.0
+ 
+ def inference_one(text, ref_feat, ipa_text=None,
+                   alpha=ALPHA, beta=BETA, diffusion_steps=DIFFUSION_STEPS, embedding_scale=EMBEDDING_SCALE):
+     # text -> phonemes -> tokens
+     ps = ipa_text if ipa_text is not None else phonemize_text(text)
+     tokens = textcleaner(ps)
+     tokens.insert(0, 0)  # prepend BOS
+     tokens = torch.LongTensor(tokens).to(DEVICE).unsqueeze(0)  # [1, T]
+ 
+     with torch.no_grad():
+         input_lengths = torch.LongTensor([tokens.shape[-1]]).to(DEVICE)
+         text_mask = length_to_mask(input_lengths).to(DEVICE)
+ 
+         # encoders
+         t_en = model.text_encoder(tokens, input_lengths, text_mask)
+         bert_d = model.bert(tokens, attention_mask=(~text_mask).int())
+         d_en = model.bert_encoder(bert_d).transpose(-1, -2)
+ 
+         if alpha == 0 and beta == 0:
+             print("Ignore Diffusion")
+             ref = ref_feat[:, :128]
+             s = ref_feat[:, 128:]
+             simi_timbre, simi_prosody = 1, 1
+         else:
+             print("Have Diffusion")
+             # diffusion for style latent
+             s_pred = sampler(
+                 noise=torch.randn((1, 256)).unsqueeze(1).to(DEVICE),
+                 embedding=bert_d,
+                 embedding_scale=embedding_scale,
+                 features=ref_feat,  # [1, 256]
+                 num_steps=diffusion_steps,
+             ).squeeze(1)  # [1, 256]
+ 
+             s = s_pred[:, 128:]    # prosody
+             ref = s_pred[:, :128]  # timbre
+ 
+             # blend with real ref features
+             ref = alpha * ref + (1 - alpha) * ref_feat[:, :128]
+             s = beta * s + (1 - beta) * ref_feat[:, 128:]
+ 
+             with torch.no_grad():
+                 ref0 = ref_feat[:, :128]  # original timbre
+                 s0 = ref_feat[:, 128:]    # original prosody
+ 
+                 eps = 1e-8
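+                 # relL2 = ||Δ|| / ||base|| and SNR = 20*log10(||base|| / ||Δ||)
+                 # measure how far the diffusion output drifts from the reference;
+                 # stats() computes them for logging but only returns the cosine.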
+ 
+                 def stats(name, new, base):
+                     delta = new - base
+                     l2_delta = torch.norm(delta, dim=1)        # ||Δ||
+                     l2_base = torch.norm(base, dim=1) + eps    # ||x||
+                     rel_l2 = (l2_delta / l2_base)              # ||Δ|| / ||x||
+                     mae = torch.mean(torch.abs(delta), dim=1)  # MAE
+                     cos_sim = F.cosine_similarity(new, base, dim=1)  # cos(new, base)
+                     snr_db = 20.0 * torch.log10(l2_base / (l2_delta + eps))  # SNR ~ 20*log10(||x||/||Δ||)
+                     # # Inference batches are usually size 1, but print per sample to stay general
+                     # for i in range(new.shape[0]):
+                     #     print(f"[{name}][sample {i}] "
+                     #           f"L2Δ={l2_delta[i]:.4f} | relL2={rel_l2[i]:.4f} | MAE={mae[i]:.6f} | "
+                     #           f"cos={cos_sim[i]:.4f} | SNR={snr_db[i]:.2f} dB")
+                     return cos_sim
+ 
+                 simi_timbre = stats("REF(timbre)", s_pred[:, :128], ref_feat[:, :128]).detach().cpu().squeeze().item()
+                 simi_prosody = stats("S(prosody)", s_pred[:, 128:], ref_feat[:, 128:]).detach().cpu().squeeze().item()
+ 
+         # duration prediction
+         d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
+         x, _ = model.predictor.lstm(d)
+         duration = torch.sigmoid(model.predictor.duration_proj(x)).sum(axis=-1)
+         pred_dur = torch.round(duration.squeeze()).clamp(min=1)
+ 
+         # alignment
+         T = int(pred_dur.sum().item())
+         pred_aln = torch.zeros(input_lengths.item(), T, device=DEVICE)
+         c = 0
+         for i in range(input_lengths.item()):
+             span = int(pred_dur[i].item())
+             pred_aln[i, c:c+span] = 1.0
+             c += span
+ 
+         # prosody enc
+         en = (d.transpose(-1, -2) @ pred_aln.unsqueeze(0))
+         if model_params.decoder.type == "hifigan":
+             asr_new = torch.zeros_like(en); asr_new[:, :, 0] = en[:, :, 0]; asr_new[:, :, 1:] = en[:, :, 0:-1]; en = asr_new
+ 
+         F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
+ 
+         # content (ASR-aligned)
+         asr = (t_en @ pred_aln.unsqueeze(0))
+         if model_params.decoder.type == "hifigan":
+             asr_new = torch.zeros_like(asr); asr_new[:, :, 0] = asr[:, :, 0]; asr_new[:, :, 1:] = asr[:, :, 0:-1]; asr = asr_new
+ 
+         # decode
+         out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))
+ 
+     wav = out.squeeze().detach().cpu().numpy()
+     if wav.shape[-1] > 50:
+         wav = wav[..., :-50]
+     return wav, ps, simi_timbre, simi_prosody
+ 
+ # -----------------------------
+ # Gradio UI
+ # -----------------------------
+ 
+ SR_OUT = 24000
+ ROOT_REF = "ref_voice"
+ EXTS = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
+ 
+ # -------- scan ref_voice/<id>_<speaker>/*.wav --------
+ 
+ def scan_ref_voice(root=ROOT_REF):
+     """
+     return:
+       speakers: list[str]                 # e.g.: ["0_Fonos.vn", "1_James_A._Robinson", ...]
+       files_by_spk: dict[str, list[str]]  # speaker_dir -> [full_path, ...]
+     """
+     speakers, files_by_spk = [], {}
+     if not os.path.isdir(root):
+         return speakers, files_by_spk
+ 
+     for spk_dir in sorted(os.listdir(root)):
+         full_dir = os.path.join(root, spk_dir)
+         if not os.path.isdir(full_dir) or spk_dir.startswith("."):
+             continue
+         lst = []
+         for fn in sorted(os.listdir(full_dir)):
+             if os.path.splitext(fn)[1].lower() in EXTS:
+                 lst.append(os.path.join(full_dir, fn))
+         if lst:
+             speakers.append(spk_dir)
+             files_by_spk[spk_dir] = lst
+     return speakers, files_by_spk
+ 
+ SPEAKERS, FILES_BY_SPK = scan_ref_voice()
+ 
+ with gr.Blocks(title="StyleTTS2-vi Demo ✨") as demo:
+     gr.Markdown("# StyleTTS2-vi Demo ✨")
+ 
+     with gr.Row():
+         with gr.Column():
+             text_inp = gr.Textbox(label="Text", lines=4,
+                                   value="Thời tiết hôm nay tại Hà Nội, nhiệt độ khoảng 27 độ C, có nắng nhẹ, rất hợp lý để mình đi dạo công viên nhé.")
+ 
+             # --- a single audio box (takes a filepath) ---
+             ref_audio = gr.Audio(
+                 label="Reference Audio",
+                 type="filepath",                   # receives a file path
+                 sources=["upload","microphone"],   # still allows upload/mic
+                 interactive=True,
+             )
+             ref_path = gr.Textbox(label="Đường dẫn reference", interactive=False)
+ 
+             # --- pick a speaker -> show its files ---
+             spk_dd = gr.Dropdown(
+                 label="Speaker",
+                 choices=["(None)"] + SPEAKERS,
+                 value="(None)",
+             )
+             file_dd = gr.Dropdown(
+                 label="Voice in speaker",
+                 choices=["(None)"],
+                 value="(None)",
+             )
+ 
+             # when a speaker is picked -> refresh the file list
+             def on_pick_speaker(spk):
+                 if spk == "(None)":
+                     return gr.update(choices=["(None)"], value="(None)")
+                 files = FILES_BY_SPK.get(spk, [])
+                 # show only basenames to keep the list compact
+                 labels = [os.path.basename(p) for p in files]
+                 # labels map to paths by index; preselect the first entry
+                 return gr.update(choices=labels, value=(labels[0] if labels else "(None)"))
+ 
+             spk_dd.change(on_pick_speaker, inputs=spk_dd, outputs=file_dd)
+ 
+             # map a label (basename) -> the full path for the current speaker
+             def on_pick_file(spk, label):
+                 if spk == "(None)" or label == "(None)":
+                     return gr.update(value=None), ""
+                 files = FILES_BY_SPK.get(spk, [])
+                 # find the file by basename
+                 for p in files:
+                     if os.path.basename(p) == label:
+                         return gr.update(value=p), p  # set the Audio + display the path
+                 return gr.update(value=None), ""
+ 
+             file_dd.change(on_pick_file, inputs=[spk_dd, file_dd], outputs=[ref_audio, ref_path])
+ 
+             # if the user uploads / records, show the temp file path right away
+             def on_audio_changed(fp):
+                 return fp or ""
+             ref_audio.change(on_audio_changed, inputs=ref_audio, outputs=ref_path)
+ 
+             # --- NEW: alpha/beta numeric inputs ---
+             with gr.Row():
+                 alpha_n = gr.Number(value=ALPHA, label="alpha (0-1, timbre)", precision=3)
+                 beta_n = gr.Number(value=BETA, label="beta (0-1, prosody)", precision=3)
+ 
+             btn = gr.Button("Đọc 🔊🔥", variant="primary")
+ 
+         with gr.Column():
+             out_audio = gr.Audio(label="Synthesised Audio", type="numpy")
+             metrics = gr.JSON(label="Metrics")
+ 
+     # ---- Inference: run from a filepath ----
+     def _run(text, ref_fp, alpha, beta):
+         # ref_fp is a string path (because type='filepath')
+         if isinstance(ref_fp, str) and os.path.isfile(ref_fp):
+             wav, _ = librosa.load(ref_fp, sr=SR_OUT, mono=True)
+             ref_feat = compute_style_from_numpy(model, wav, SR_OUT, DEVICE)
+             ref_src = ref_fp
+         else:
+             ref_feat = torch.zeros(1, 256).to(DEVICE)
+             ref_src = "(None)"
+ 
+         t0 = time.time()
+         wav, ps, simi_timbre, simi_prosody = inference_one(text, ref_feat, alpha=float(alpha), beta=float(beta))
+         wav = wav.astype(np.float32)
+         gen_time = time.time() - t0
+         rtf = gen_time / max(1e-6, len(wav)/SR_OUT)
+ 
+         info = {
+             "simi_timbre": round(float(simi_timbre), 4),
+             "simi_prosody": round(float(simi_prosody), 4),
+             "Phonemes": ps,
+             "Sample rate": SR_OUT,
+             "RTF": round(float(rtf), 3),
+             "Device": DEVICE,
+         }
+         return (SR_OUT, wav), info
+ 
+     btn.click(_run, inputs=[text_inp, ref_audio, alpha_n, beta_n], outputs=[out_audio, metrics])
+ 
+ if __name__ == "__main__":
+     demo.launch()
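Both apps build the hard alignment the same way: phoneme i fills pred_dur[i] consecutive frame columns with ones. A tiny worked example of that construction, detached from the model:

    import torch

    pred_dur = torch.tensor([2, 3, 1])      # frames predicted per phoneme
    T = int(pred_dur.sum().item())          # 6 frames in total
    pred_aln = torch.zeros(len(pred_dur), T)
    c = 0
    for i in range(len(pred_dur)):
        span = int(pred_dur[i].item())
        pred_aln[i, c:c+span] = 1.0
        c += span
    # pred_aln:
    # [[1., 1., 0., 0., 0., 0.],
    #  [0., 0., 1., 1., 1., 0.],
    #  [0., 0., 0., 0., 0., 1.]]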
ref_voice/0000000_Kore_Quân_sự.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:abfd3f8771395bbcb0789f2f5e61fab93ac186672d0ff756fbe5c44854bb4cc3
+ size 284730
ref_voice/0000001_Algenib_Giáo_dục.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5ddceacd307b7f1c0e9ba086bf3744a42c4fd128f94f1e50af9465c0d7a72eff
+ size 411450
ref_voice/0000002_Puck_Giáo_dục.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec31b58f8ea3606eeb7c785d9bf66db952e15ad5fe8c74875771f29a37976dff
+ size 760890
ref_voice/0000003_Leda_Giáo_dục.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7fef8463fceb431353e8a6155cb571fe10827983cf72a85772822c6656bd72ba
+ size 672570
ref_voice/{5_kore_gemini-train-kore-sample_020996.wav → eng/5_kore_gemini-train-kore-sample_020996.wav} RENAMED
File without changes
ref_voice/{6_puck_gemini-train-puck-sample_017190.wav → eng/6_puck_gemini-train-puck-sample_017190.wav} RENAMED
File without changes
requirements.txt CHANGED
@@ -14,4 +14,5 @@ tqdm
  typing
  typing-extensions
  git+https://github.com/resemble-ai/monotonic_align.git
- phonemizer
+ phonemizer
+ soe-vinorm
train_second.py CHANGED
@@ -349,7 +349,7 @@ def main(config_path):
          s_trg = torch.cat([gs, s_dur], dim=-1).detach()  # ground truth for denoiser
  
          bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
-         d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+         d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
  
          # denoiser training
          if epoch >= diff_epoch: