AccentVector / app.py
NewGame's picture
update details
7f59768
"""Gradio demo for Accent Vectors.
Lets users synthesise speech with a controllable accent directly in the
browser — no local setup required.
Models are downloaded from Hugging Face on first use and cached for the
lifetime of the Space instance.
"""
import os
import json
import tempfile
import gradio as gr
import torch
from huggingface_hub import snapshot_download
from accent_task_vectors.inference import load_xtts_model, attach_lora_adapter
from accent_task_vectors.inference.inference import _scale_lora
# ---------------------------------------------------------------------------
# Model registry (mirrors download_checkpoints.py)
# ---------------------------------------------------------------------------
PRETRAINED_REPO = "NewGame/pretrained-xtts"
# (output language, speaker accent) -> Hugging Face repo holding the LoRA adapter.
MODELS = {
    ("English", "English"): "NewGame/english-accent-english-xtts",
    ("English", "Hindi"): "NewGame/hindi-accent-english-xtts",
    ("English", "German"): "NewGame/german-accent-english-xtts",
    ("English", "French"): "NewGame/french-accent-english-xtts",
    ("English", "Spanish"): "NewGame/spanish-accent-english-xtts",
    ("English", "Mandarin"): "NewGame/mandarin-accent-english-xtts",
    ("Spanish", "English"): "NewGame/english-accent-spanish-xtts",
    ("German", "English"): "NewGame/english-accent-german-xtts",
    ("Mandarin", "English"): "NewGame/english-accent-mandarin-xtts",
}
# Language code passed to the TTS model
LANGUAGE_CODES = {
    "English": "en",
    "Spanish": "es",
    "German": "de",
    "Mandarin": "zh-cn",
}
# Accents available for each output language. Derived from MODELS (single
# source of truth) so the UI can never offer a combination with no adapter.
# dict.fromkeys / list comprehension keep MODELS' insertion order, so languages
# and accents appear exactly in the order they are listed above.
ACCENTS_BY_LANGUAGE = {
    lang: [accent for (model_lang, accent) in MODELS if model_lang == lang]
    for lang in dict.fromkeys(model_lang for model_lang, _ in MODELS)
}
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Root for every downloaded artefact; override with the MODEL_CACHE_DIR env var.
CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "model_cache")
# Where the base (pretrained) XTTS snapshot is stored, under CACHE_DIR.
PRETRAINED_DIR = os.path.join(CACHE_DIR, "pretrained")
# Config key -> file name inside PRETRAINED_DIR. Any of these keys found in an
# adapter's config.json is rewritten to point at the local pretrained file.
_PRETRAINED_PATH_FIELDS = dict(
    mel_norm_file="mel_stats.pth",
    dvae_checkpoint="dvae.pth",
    xtts_checkpoint="model.pth",
    tokenizer_file="vocab.json",
)
# ---------------------------------------------------------------------------
# In-memory model cache
# ---------------------------------------------------------------------------
# Keyed by (language, accent1, accent2 or None); values are loaded TTS objects.
_model_cache: dict = dict()
# Same keys -> (coeff1, coeff2): the adapter coefficients currently applied to
# the corresponding cached model.
_current_coeffs: dict = dict()
# Run on the GPU whenever torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    _device = "cuda"
else:
    _device = "cpu"
def _patch_config(config_path: str, pretrained_dir: str) -> None:
with open(config_path) as f:
config = json.load(f)
abs_pretrained = os.path.abspath(pretrained_dir)
changed = False
def _patch(obj):
nonlocal changed
if isinstance(obj, dict):
for key, filename in _PRETRAINED_PATH_FIELDS.items():
if key in obj:
new_val = os.path.join(abs_pretrained, filename)
if obj[key] != new_val:
obj[key] = new_val
changed = True
for v in obj.values():
_patch(v)
_patch(config)
if changed:
with open(config_path, "w") as f:
json.dump(config, f, indent=2)
def _ensure_pretrained() -> None:
    """Download the pretrained XTTS snapshot into PRETRAINED_DIR unless already complete.

    Completion is recorded with a marker file that is written only after
    ``snapshot_download`` returns successfully. Checking the marker instead of
    merely the directory means a PRETRAINED_DIR left behind by a failed or
    interrupted download is retried rather than being used as if complete.
    """
    marker = os.path.join(PRETRAINED_DIR, ".download_complete")
    if os.path.isfile(marker):
        return
    print(f"Downloading pretrained model from {PRETRAINED_REPO} …")
    snapshot_download(
        repo_id=PRETRAINED_REPO,
        repo_type="model",
        local_dir=PRETRAINED_DIR,
    )
    # Only reached when the download did not raise.
    open(marker, "w", encoding="utf-8").close()
def _download_lora(language: str, accent: str) -> str:
    """Download the LoRA adapter for (language, accent) if needed; return its local directory.

    Only ``config.json`` and ``lora/best_model/**`` are fetched from the repo
    listed in MODELS. A marker file written after a successful download marks
    the directory as complete, so a directory left by a failed or interrupted
    download is fetched again instead of being reused. On every call the
    adapter's config.json is re-pointed at PRETRAINED_DIR.
    """
    lora_dir = os.path.join(CACHE_DIR, f"{accent.lower()}-accent-{language.lower()}")
    marker = os.path.join(lora_dir, ".download_complete")
    if not os.path.isfile(marker):
        repo_id = MODELS[(language, accent)]
        print(f"Downloading LoRA adapter from {repo_id} …")
        snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            local_dir=lora_dir,
            allow_patterns=["config.json", "lora/best_model/**"],
        )
        # Only reached when the download did not raise.
        open(marker, "w", encoding="utf-8").close()
    _patch_config(os.path.join(lora_dir, "config.json"), PRETRAINED_DIR)
    return lora_dir
def _load_model(
    language: str,
    accent1: str,
    accent2: str | None,
    max_cached: int | None = None,
):
    """Return the TTS model for ``(language, accent1, accent2)``, loading it if needed.

    A freshly loaded model has its adapter(s) at coefficient 1.0: "default" for
    *accent1* and, when *accent2* is given, "other" for *accent2*, with both set
    active. A cached model is returned as-is; its adapters sit at whatever
    coefficients ``_current_coeffs[key]`` records.

    Every cache entry is a full XTTS model, so the cache is bounded. At most
    *max_cached* models are kept, evicting the least recently used first. The
    default comes from the ``MAX_CACHED_MODELS`` environment variable, else 2.
    """
    key = (language, accent1, accent2)
    if key in _model_cache:
        # Re-insert so dict order tracks recency (oldest entry first).
        tts = _model_cache.pop(key)
        _model_cache[key] = tts
        return tts

    if max_cached is None:
        max_cached = int(os.environ.get("MAX_CACHED_MODELS", "2"))

    # Fetch everything before touching the cache, so a failed download neither
    # evicts a usable model nor leaves a loaded-but-uncached model in memory.
    _ensure_pretrained()
    lora_dir1 = _download_lora(language, accent1)
    lora_dir2 = _download_lora(language, accent2) if accent2 is not None else None

    # Free room for the new model before allocating it.
    evicted = False
    while _model_cache and len(_model_cache) >= max(max_cached, 1):
        oldest = next(iter(_model_cache))
        del _model_cache[oldest]
        _current_coeffs.pop(oldest, None)
        evicted = True
    if evicted:
        import gc

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    checkpoint_path = os.path.join(PRETRAINED_DIR, "checkpoint_0.pth")
    config_path = os.path.join(lora_dir1, "config.json")
    tts = load_xtts_model(checkpoint_path, config_path, device=_device)
    tts = attach_lora_adapter(
        tts,
        lora_path=os.path.join(lora_dir1, "lora", "best_model"),
        adapter_name="default",
        scaling_coef=1.0,
    )
    if lora_dir2 is not None:
        tts = attach_lora_adapter(
            tts,
            lora_path=os.path.join(lora_dir2, "lora", "best_model"),
            adapter_name="other",
            scaling_coef=1.0,
        )
        # Make both adapters active at the same time.
        tts.synthesizer.tts_model.set_adapter(["default", "other"])

    _model_cache[key] = tts
    _current_coeffs[key] = (1.0, 1.0)
    return tts
# ---------------------------------------------------------------------------
# Inference function called by Gradio
# ---------------------------------------------------------------------------
# Smallest coefficient ever applied to an adapter. Adapters are rescaled
# multiplicatively relative to the previously applied coefficient, so a true 0
# would make the next rescale divide by zero and could never be undone.
# NOTE(review): assumes an adapter at this strength is inaudible — confirm.
_MIN_COEFF = 1e-3


def _effective_coeff(coeff: float) -> float:
    """Return *coeff* clamped to at least ``_MIN_COEFF`` so it can serve as a rescaling divisor."""
    return max(float(coeff), _MIN_COEFF)


def synthesise(
    text: str,
    speaker_audio: str,
    language: str,
    accent1: str,
    coeff1: float,
    enable_second: bool,
    accent2: str,
    coeff2: float,
):
    """Synthesise *text* with the selected accent adapter(s); return the output .wav path.

    *coeff1* sets the strength of the *accent1* adapter. When *enable_second* is
    set, *coeff2* sets the strength of the *accent2* adapter. Coefficients below
    ``_MIN_COEFF`` are clamped up to it. The audio is written to a fresh
    temporary file, which is not deleted here.

    Raises:
        gr.Error: for empty text, missing reference audio, or a
            (language, accent) pair without an adapter in ``MODELS``.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to synthesise.")
    if speaker_audio is None:
        raise gr.Error("Please upload a reference speaker audio file.")
    if (language, accent1) not in MODELS:
        raise gr.Error(f"Unsupported combination: language={language}, accent={accent1}.")
    accent2_key = accent2 if enable_second else None
    if enable_second and (language, accent2) not in MODELS:
        raise gr.Error(f"Unsupported combination: language={language}, accent={accent2}.")

    tts = _load_model(language, accent1, accent2_key)
    key = (language, accent1, accent2_key)

    # Each adapter currently sits at the coefficient recorded in _current_coeffs.
    # Scale it by target/previous and record the new value straight away, so the
    # bookkeeping matches the model even if a later step raises.
    prev1, prev2 = _current_coeffs[key]
    target1 = _effective_coeff(coeff1)
    if target1 != prev1:
        _scale_lora(tts, target1 / prev1, adapter_name="default")
    _current_coeffs[key] = (target1, prev2)
    if accent2_key is not None:
        target2 = _effective_coeff(coeff2)
        if target2 != prev2:
            _scale_lora(tts, target2 / prev2, adapter_name="other")
        _current_coeffs[key] = (target1, target2)

    # delete=False: the file must outlive this function so Gradio can serve it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_path = tmp.name
    tts.tts_to_file(
        text=text,
        speaker_wav=speaker_audio,
        language=LANGUAGE_CODES[language],
        file_path=output_path,
    )
    return output_path
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def update_accent_choices(language: str):
    """Return a Dropdown update listing the accents available for *language*.

    The first available accent is preselected. An unknown or cleared language
    yields an empty, unselected dropdown instead of raising IndexError.
    """
    accents = ACCENTS_BY_LANGUAGE.get(language, [])
    return gr.update(choices=accents, value=accents[0] if accents else None)
# Page layout: header, then a row with an input column and an output column,
# then usage notes.
with gr.Blocks(title="Accent Vectors") as demo:
    gr.Markdown(
        """
# Accent Vectors
Synthesise speech with a controllable accent — pick the output **language**,
the speaker's **accent**, upload a short reference audio clip, and type your text.
> **Paper:** *Accent Vector: Controllable Accent Manipulation for Multilingual TTS
> Without Accented Data* (submitted to Interspeech 2026)
"""
    )
    with gr.Row():
        # Inputs column.
        with gr.Column():
            text_input = gr.Textbox(
                label="Text to synthesise",
                placeholder="Type something here…",
                lines=3,
            )
            speaker_audio = gr.Audio(
                label="Reference speaker audio (3–10 s)",
                type="filepath",
            )
            with gr.Row():
                language_dd = gr.Dropdown(
                    label="Output language",
                    choices=list(ACCENTS_BY_LANGUAGE.keys()),
                    value="English",
                )
                accent1_dd = gr.Dropdown(
                    label="Speaker accent",
                    choices=ACCENTS_BY_LANGUAGE["English"],
                    value="English",
                )
            coeff1_slider = gr.Slider(
                label="Accent strength",
                minimum=0.0, maximum=1.0, step=0.05, value=1.0,
            )
            with gr.Accordion("Mix a second accent (optional)", open=False):
                enable_second = gr.Checkbox(label="Enable second accent", value=False)
                accent2_dd = gr.Dropdown(
                    label="Second accent",
                    choices=ACCENTS_BY_LANGUAGE["English"],
                    value="Hindi",
                    interactive=True,
                )
                coeff2_slider = gr.Slider(
                    label="Second accent strength",
                    minimum=0.0, maximum=1.0, step=0.05, value=0.5,
                )
            generate_btn = gr.Button("Generate", variant="primary")
        # Output column.
        with gr.Column():
            audio_output = gr.Audio(label="Generated speech", type="filepath")

    # One listener refreshes both accent dropdowns in a single round trip.
    # queue=False keeps this cheap UI update from waiting behind queued
    # synthesis jobs.
    language_dd.change(
        fn=lambda lang: (update_accent_choices(lang), update_accent_choices(lang)),
        inputs=language_dd,
        outputs=[accent1_dd, accent2_dd],
        queue=False,
    )
    generate_btn.click(
        fn=synthesise,
        inputs=[
            text_input, speaker_audio,
            language_dd, accent1_dd, coeff1_slider,
            enable_second, accent2_dd, coeff2_slider,
        ],
        outputs=audio_output,
    )
    gr.Markdown(
        """
---
### How to use
1. **Output language** — the language the model will speak in.
2. **Speaker accent** — the L1 accent of the target speaker style.
3. **Reference audio** — a clean 3–10 second clip of any speaker; the model
clones the voice while applying the chosen accent.
4. **Accent strength** — LoRA adapter contribution (0 = no accent effect, 1 = full).
5. **Mix a second accent** — optionally blend two accents together by enabling
a second adapter and setting its strength independently.
Models are downloaded automatically on first use.
"""
    )
def _launch() -> None:
    """Start the Gradio server for the ``demo`` app defined above."""
    demo.launch()


if __name__ == "__main__":
    _launch()