NewGame committed on
Commit
c07485b
·
1 Parent(s): e9e1122

add the second accent feature

Browse files
Files changed (1) hide show
  1. app.py +88 -43
app.py CHANGED
@@ -8,6 +8,7 @@ lifetime of the Space instance.
8
  """
9
 
10
  import os
 
11
  import tempfile
12
 
13
  import gradio as gr
@@ -15,6 +16,7 @@ import torch
15
  from huggingface_hub import snapshot_download
16
 
17
  from accent_task_vectors.inference import load_xtts_model, attach_lora_adapter
 
18
 
19
  # ---------------------------------------------------------------------------
20
  # Model registry (mirrors download_checkpoints.py)
@@ -54,10 +56,9 @@ ACCENTS_BY_LANGUAGE = {
54
  # Paths
55
  # ---------------------------------------------------------------------------
56
 
57
- CACHE_DIR = os.environ.get("MODEL_CACHE_DIR", "model_cache")
58
  PRETRAINED_DIR = os.path.join(CACHE_DIR, "pretrained")
59
 
60
- # Keys in config.json that hold pretrained model paths
61
  _PRETRAINED_PATH_FIELDS = {
62
  "mel_norm_file": "mel_stats.pth",
63
  "dvae_checkpoint": "dvae.pth",
@@ -66,17 +67,17 @@ _PRETRAINED_PATH_FIELDS = {
66
  }
67
 
68
  # ---------------------------------------------------------------------------
69
- # In-memory model cache {(language, accent): tts}
 
 
70
  # ---------------------------------------------------------------------------
71
 
72
- _model_cache: dict = {}
 
73
  _device = "cuda" if torch.cuda.is_available() else "cpu"
74
 
75
 
76
  def _patch_config(config_path: str, pretrained_dir: str) -> None:
77
- """Rewrite pretrained model paths in config.json to point to local dir."""
78
- import json
79
-
80
  with open(config_path) as f:
81
  config = json.load(f)
82
 
@@ -103,7 +104,6 @@ def _patch_config(config_path: str, pretrained_dir: str) -> None:
103
 
104
 
105
  def _ensure_pretrained() -> None:
106
- """Download the base pretrained XTTS model if not already cached."""
107
  if not os.path.isdir(PRETRAINED_DIR):
108
  print(f"Downloading pretrained model from {PRETRAINED_REPO} …")
109
  snapshot_download(
@@ -113,18 +113,11 @@ def _ensure_pretrained() -> None:
113
  )
114
 
115
 
116
- def _load_model(language: str, accent: str) -> object:
117
- """Return a cached (or freshly loaded) TTS model for the given combination."""
118
- key = (language, accent)
119
- if key in _model_cache:
120
- return _model_cache[key]
121
-
122
- _ensure_pretrained()
123
-
124
- repo_id = MODELS[key]
125
  lora_dir = os.path.join(CACHE_DIR, f"{accent.lower()}-accent-{language.lower()}")
126
-
127
  if not os.path.isdir(lora_dir):
 
128
  print(f"Downloading LoRA adapter from {repo_id} …")
129
  snapshot_download(
130
  repo_id=repo_id,
@@ -133,15 +126,33 @@ def _load_model(language: str, accent: str) -> object:
133
  allow_patterns=["config.json", "lora/best_model/**"],
134
  )
135
  _patch_config(os.path.join(lora_dir, "config.json"), PRETRAINED_DIR)
 
 
 
 
 
 
 
 
136
 
 
 
 
137
  checkpoint_path = os.path.join(PRETRAINED_DIR, "checkpoint_0.pth")
138
- config_path = os.path.join(lora_dir, "config.json")
139
- lora_path = os.path.join(lora_dir, "lora", "best_model")
140
 
141
  tts = load_xtts_model(checkpoint_path, config_path, device=_device)
142
- tts = attach_lora_adapter(tts, lora_path=lora_path)
 
 
 
 
 
 
143
 
144
- _model_cache[key] = tts
 
145
  return tts
146
 
147
 
@@ -149,22 +160,37 @@ def _load_model(language: str, accent: str) -> object:
149
  # Inference function called by Gradio
150
  # ---------------------------------------------------------------------------
151
 
152
- def synthesise(text: str, speaker_audio: str, language: str, accent: str, lora_coeff: float):
 
 
 
 
 
 
 
 
 
153
  if not text.strip():
154
  raise gr.Error("Please enter some text to synthesise.")
155
  if speaker_audio is None:
156
  raise gr.Error("Please upload a reference speaker audio file.")
157
- if (language, accent) not in MODELS:
158
- raise gr.Error(f"Unsupported combination: language={language}, accent={accent}.")
159
 
160
- tts = _load_model(language, accent)
161
 
162
- # Scale LoRA if needed
163
- if lora_coeff != 1.0:
164
- from accent_task_vectors.inference.inference import _scale_lora
165
- # Reset to 1.0 first, then apply desired coefficient
166
- _scale_lora(tts, lora_coeff / getattr(tts, "_last_lora_coeff", 1.0))
167
- tts._last_lora_coeff = lora_coeff
 
 
 
 
 
 
168
 
169
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
170
  output_path = tmp.name
@@ -211,34 +237,52 @@ the speaker's **accent**, upload a short reference audio clip, and type your tex
211
  label="Reference speaker audio (3–10 s)",
212
  type="filepath",
213
  )
 
214
  with gr.Row():
215
  language_dd = gr.Dropdown(
216
  label="Output language",
217
  choices=list(ACCENTS_BY_LANGUAGE.keys()),
218
  value="English",
219
  )
220
- accent_dd = gr.Dropdown(
221
  label="Speaker accent",
222
  choices=ACCENTS_BY_LANGUAGE["English"],
223
  value="English",
224
  )
225
- lora_coeff = gr.Slider(
226
- label="Accent strength (LoRA coefficient)",
227
- minimum=0.0,
228
- maximum=2.0,
229
- step=0.05,
230
- value=1.0,
231
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  generate_btn = gr.Button("Generate", variant="primary")
233
 
234
  with gr.Column():
235
  audio_output = gr.Audio(label="Generated speech", type="filepath")
236
 
237
- language_dd.change(fn=update_accent_choices, inputs=language_dd, outputs=accent_dd)
 
 
238
 
239
  generate_btn.click(
240
  fn=synthesise,
241
- inputs=[text_input, speaker_audio, language_dd, accent_dd, lora_coeff],
 
 
 
 
242
  outputs=audio_output,
243
  )
244
 
@@ -250,8 +294,9 @@ the speaker's **accent**, upload a short reference audio clip, and type your tex
250
  2. **Speaker accent** — the L1 accent of the target speaker style.
251
  3. **Reference audio** — a clean 3–10 second clip of any speaker; the model
252
  clones the voice while applying the chosen accent.
253
- 4. **Accent strength** — scale the LoRA adapter contribution (1.0 = default,
254
- 0 = no accent modification, >1 = stronger accent).
 
255
 
256
  Models are downloaded automatically on first use.
257
  """
 
8
  """
9
 
10
  import os
11
+ import json
12
  import tempfile
13
 
14
  import gradio as gr
 
16
  from huggingface_hub import snapshot_download
17
 
18
  from accent_task_vectors.inference import load_xtts_model, attach_lora_adapter
19
+ from accent_task_vectors.inference.inference import _scale_lora
20
 
21
  # ---------------------------------------------------------------------------
22
  # Model registry (mirrors download_checkpoints.py)
 
56
  # Paths
57
  # ---------------------------------------------------------------------------
58
 
59
+ CACHE_DIR = os.environ.get("MODEL_CACHE_DIR", "model_cache")
60
  PRETRAINED_DIR = os.path.join(CACHE_DIR, "pretrained")
61
 
 
62
  _PRETRAINED_PATH_FIELDS = {
63
  "mel_norm_file": "mel_stats.pth",
64
  "dvae_checkpoint": "dvae.pth",
 
67
  }
68
 
69
  # ---------------------------------------------------------------------------
70
+ # In-memory model cache
71
+ # _model_cache: (language, accent1, accent2|None) -> tts
72
+ # _current_coeffs: same key -> (coeff1, coeff2)
73
  # ---------------------------------------------------------------------------
74
 
75
+ _model_cache: dict = {}
76
+ _current_coeffs: dict = {}
77
  _device = "cuda" if torch.cuda.is_available() else "cpu"
78
 
79
 
80
  def _patch_config(config_path: str, pretrained_dir: str) -> None:
 
 
 
81
  with open(config_path) as f:
82
  config = json.load(f)
83
 
 
104
 
105
 
106
  def _ensure_pretrained() -> None:
 
107
  if not os.path.isdir(PRETRAINED_DIR):
108
  print(f"Downloading pretrained model from {PRETRAINED_REPO} …")
109
  snapshot_download(
 
113
  )
114
 
115
 
116
+ def _download_lora(language: str, accent: str) -> str:
117
+ """Download a LoRA adapter if needed; return its local directory."""
 
 
 
 
 
 
 
118
  lora_dir = os.path.join(CACHE_DIR, f"{accent.lower()}-accent-{language.lower()}")
 
119
  if not os.path.isdir(lora_dir):
120
+ repo_id = MODELS[(language, accent)]
121
  print(f"Downloading LoRA adapter from {repo_id} …")
122
  snapshot_download(
123
  repo_id=repo_id,
 
126
  allow_patterns=["config.json", "lora/best_model/**"],
127
  )
128
  _patch_config(os.path.join(lora_dir, "config.json"), PRETRAINED_DIR)
129
+ return lora_dir
130
+
131
+
132
+ def _load_model(language: str, accent1: str, accent2: str | None):
133
+ """Return a cached TTS model with adapter(s) loaded at coeff=1.0."""
134
+ key = (language, accent1, accent2)
135
+ if key in _model_cache:
136
+ return _model_cache[key]
137
 
138
+ _ensure_pretrained()
139
+
140
+ lora_dir1 = _download_lora(language, accent1)
141
  checkpoint_path = os.path.join(PRETRAINED_DIR, "checkpoint_0.pth")
142
+ config_path = os.path.join(lora_dir1, "config.json")
143
+ lora_path1 = os.path.join(lora_dir1, "lora", "best_model")
144
 
145
  tts = load_xtts_model(checkpoint_path, config_path, device=_device)
146
+ tts = attach_lora_adapter(tts, lora_path=lora_path1, adapter_name="default", scaling_coef=1.0)
147
+
148
+ if accent2 is not None:
149
+ lora_dir2 = _download_lora(language, accent2)
150
+ lora_path2 = os.path.join(lora_dir2, "lora", "best_model")
151
+ tts = attach_lora_adapter(tts, lora_path=lora_path2, adapter_name="other", scaling_coef=1.0)
152
+ tts.synthesizer.tts_model.set_adapter(["default", "other"])
153
 
154
+ _model_cache[key] = tts
155
+ _current_coeffs[key] = (1.0, 1.0)
156
  return tts
157
 
158
 
 
160
  # Inference function called by Gradio
161
  # ---------------------------------------------------------------------------
162
 
163
+ def synthesise(
164
+ text: str,
165
+ speaker_audio: str,
166
+ language: str,
167
+ accent1: str,
168
+ coeff1: float,
169
+ enable_second: bool,
170
+ accent2: str,
171
+ coeff2: float,
172
+ ):
173
  if not text.strip():
174
  raise gr.Error("Please enter some text to synthesise.")
175
  if speaker_audio is None:
176
  raise gr.Error("Please upload a reference speaker audio file.")
177
+ if (language, accent1) not in MODELS:
178
+ raise gr.Error(f"Unsupported combination: language={language}, accent={accent1}.")
179
 
180
+ accent2_key = accent2 if enable_second else None
181
 
182
+ if enable_second and (language, accent2) not in MODELS:
183
+ raise gr.Error(f"Unsupported combination: language={language}, accent={accent2}.")
184
+
185
+ tts = _load_model(language, accent1, accent2_key)
186
+ key = (language, accent1, accent2_key)
187
+
188
+ # Rescale adapters from their current cached coefficients to the desired ones
189
+ prev_coeff1, prev_coeff2 = _current_coeffs[key]
190
+ _scale_lora(tts, coeff1 / prev_coeff1, adapter_name="default")
191
+ if accent2_key is not None:
192
+ _scale_lora(tts, coeff2 / prev_coeff2, adapter_name="other")
193
+ _current_coeffs[key] = (coeff1, coeff2 if accent2_key else 1.0)
194
 
195
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
196
  output_path = tmp.name
 
237
  label="Reference speaker audio (3–10 s)",
238
  type="filepath",
239
  )
240
+
241
  with gr.Row():
242
  language_dd = gr.Dropdown(
243
  label="Output language",
244
  choices=list(ACCENTS_BY_LANGUAGE.keys()),
245
  value="English",
246
  )
247
+ accent1_dd = gr.Dropdown(
248
  label="Speaker accent",
249
  choices=ACCENTS_BY_LANGUAGE["English"],
250
  value="English",
251
  )
252
+ coeff1_slider = gr.Slider(
253
+ label="Accent strength",
254
+ minimum=0.0, maximum=1.0, step=0.05, value=1.0,
 
 
 
255
  )
256
+
257
+ with gr.Accordion("Mix a second accent (optional)", open=False):
258
+ enable_second = gr.Checkbox(label="Enable second accent", value=False)
259
+ accent2_dd = gr.Dropdown(
260
+ label="Second accent",
261
+ choices=ACCENTS_BY_LANGUAGE["English"],
262
+ value="Hindi",
263
+ interactive=True,
264
+ )
265
+ coeff2_slider = gr.Slider(
266
+ label="Second accent strength",
267
+ minimum=0.0, maximum=1.0, step=0.05, value=0.5,
268
+ )
269
+
270
  generate_btn = gr.Button("Generate", variant="primary")
271
 
272
  with gr.Column():
273
  audio_output = gr.Audio(label="Generated speech", type="filepath")
274
 
275
+ # Update both accent dropdowns when language changes
276
+ language_dd.change(fn=update_accent_choices, inputs=language_dd, outputs=accent1_dd)
277
+ language_dd.change(fn=update_accent_choices, inputs=language_dd, outputs=accent2_dd)
278
 
279
  generate_btn.click(
280
  fn=synthesise,
281
+ inputs=[
282
+ text_input, speaker_audio,
283
+ language_dd, accent1_dd, coeff1_slider,
284
+ enable_second, accent2_dd, coeff2_slider,
285
+ ],
286
  outputs=audio_output,
287
  )
288
 
 
294
  2. **Speaker accent** — the L1 accent of the target speaker style.
295
  3. **Reference audio** — a clean 3–10 second clip of any speaker; the model
296
  clones the voice while applying the chosen accent.
297
+ 4. **Accent strength** — LoRA adapter contribution (0 = no accent effect, 1 = full).
298
+ 5. **Mix a second accent** — optionally blend two accents together by enabling
299
+ a second adapter and setting its strength independently.
300
 
301
  Models are downloaded automatically on first use.
302
  """