# Blue / app.py — commit 64122a5: "Soften vocoder edge trim and shorter decode fades."
"""
Gradio Space for BlueTTS — multilingual ONNX TTS.
Upstream: https://github.com/maxmelichov/BlueTTS
"""
import os
import re
import sys
import subprocess
import json
import time
import base64
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Dict
from unicodedata import normalize as uni_normalize
import numpy as np
from num2words import num2words
import gradio as gr
import onnxruntime as ort
from download_models import download_blue_models, download_renikud
# Download models if missing.
# A text encoder file under ~1 KB is treated as a failed or partial earlier
# download, so the fetch is retried in that case as well.
if not os.path.exists("onnx_models/text_encoder.onnx") or os.path.getsize("onnx_models/text_encoder.onnx") < 1000:
    print("Models missing or invalid, downloading via huggingface_hub...")
    download_blue_models()
    download_renikud()
# ============================================================
# Vocabulary & Normalization
# ============================================================
# Piper-compatible IPA/symbol -> token-id table (ids 0..156).
_PIPER_MAP: dict[str, int] = {
    "_": 0, "^": 1, "$": 2, " ": 3, "!": 4, "'": 5, "(": 6, ")": 7, ",": 8, "-": 9, ".": 10,
    ":": 11, ";": 12, "?": 13, "a": 14, "b": 15, "c": 16, "d": 17, "e": 18, "f": 19,
    "h": 20, "i": 21, "j": 22, "k": 23, "l": 24, "m": 25, "n": 26, "o": 27, "p": 28, "q": 29, "r": 30, "s": 31, "t": 32, "u": 33,
    "v": 34, "w": 35, "x": 36, "y": 37, "z": 38, "æ": 39, "ç": 40, "ð": 41, "ø": 42, "ħ": 43, "ŋ": 44, "œ": 45,
    "ǀ": 46, "ǁ": 47, "ǂ": 48, "ǃ": 49, "ɐ": 50, "ɑ": 51, "ɒ": 52, "ɓ": 53, "ɔ": 54, "ɕ": 55,
    "ɖ": 56, "ɗ": 57, "ɘ": 58, "ə": 59, "ɚ": 60, "ɛ": 61, "ɜ": 62, "ɞ": 63, "ɟ": 64, "ɠ": 65, "ɡ": 66, "ɢ": 67,
    "ɣ": 68, "ɤ": 69, "ɥ": 70, "ɦ": 71, "ɧ": 72, "ɨ": 73, "ɪ": 74, "ɫ": 75, "ɬ": 76, "ɭ": 77, "ɮ": 78, "ɯ": 79,
    "ɰ": 80, "ɱ": 81, "ɲ": 82, "ɳ": 83, "ɴ": 84, "ɵ": 85, "ɶ": 86, "ɸ": 87, "ɹ": 88, "ɺ": 89, "ɻ": 90, "ɽ": 91,
    "ɾ": 92, "ʀ": 93, "ʁ": 94, "ʂ": 95, "ʃ": 96, "ʄ": 97, "ʈ": 98, "ʉ": 99, "ʊ": 100, "ʋ": 101, "ʌ": 102, "ʍ": 103,
    "ʎ": 104, "ʏ": 105, "ʐ": 106, "ʑ": 107, "ʒ": 108, "ʔ": 109, "ʕ": 110, "ʘ": 111, "ʙ": 112, "ʛ": 113, "ʜ": 114, "ʝ": 115,
    "ʟ": 116, "ʡ": 117, "ʢ": 118, "ʲ": 119, "ˈ": 120, "ˌ": 121, "ː": 122, "ˑ": 123, "˞": 124,
    "β": 125, "θ": 126, "χ": 127, "ᵻ": 128, "ⱱ": 129, "0": 130, "1": 131, "2": 132, "3": 133, "4": 134,
    "5": 135, "6": 136, "7": 137, "8": 138, "9": 139, "\u0327": 140, "\u0303": 141, "\u032A": 142, "\u032F": 143, "\u0329": 144,
    "ʰ": 145, "ˤ": 146, "ε": 147, "↓": 148, "#": 149, '"': 150, "↑": 151, "\u033A": 152, "\u033B": 153, "g": 154, "ʦ": 155, "X": 156,
}
# Extra symbols beyond the Piper set: uppercase Latin, affricates, tone/intonation
# marks, extra punctuation, and combining diacritics (ids 157..243).
_EXTENDED_MAP: dict[str, int] = {
    "A": 157, "B": 158, "C": 159, "D": 160, "E": 161, "F": 162, "G": 163, "H": 164, "I": 165, "J": 166, "K": 167, "L": 168, "M": 169, "N": 170,
    "O": 171, "P": 172, "Q": 173, "R": 174, "S": 175, "T": 176, "U": 177, "V": 178, "W": 179, "Y": 180, "Z": 181,
    "ʤ": 182, "ɝ": 183, "ʧ": 184, "ʼ": 185, "ʴ": 186, "ʱ": 187, "ʷ": 188, "ˠ": 189, "→": 190, "↗": 191, "↘": 192,
    "¡": 193, "¿": 194, "…": 195, "«": 196, "»": 197, "*": 198, "~": 199, "/": 200, "\\": 201, "&": 202,
    "\u0361": 203, "\u035C": 204, "\u0325": 205, "\u032C": 206, "\u0339": 207, "\u031C": 208, "\u031D": 209, "\u031E": 210, "\u031F": 211, "\u0320": 212, "\u0330": 213, "\u0334": 214, "\u031A": 215, "\u0318": 216, "\u0319": 217, "\u0348": 218, "\u0306": 219, "\u0308": 220, "\u031B": 221, "\u0324": 222, "\u033C": 223,
    "\u02C0": 224, "\u02C1": 225, "\u02BE": 226, "\u02BF": 227, "\u02BB": 228, "\u02C9": 229, "\u02CA": 230, "\u02CB": 231, "\u02C6": 232,
    "\u02E5": 233, "\u02E6": 234, "\u02E7": 235, "\u02E8": 236, "\u02E9": 237, "\u0300": 238, "\u0301": 239, "\u0302": 240, "\u0304": 241, "\u030C": 242, "\u0307": 243,
}
# Highest id used by the Piper-compatible region above.
PIPER_REGION_END = 156
# Ids at and above this value are reserved for per-language tokens.
LANG_REGION_START = 244
LANG_REGION_SIZE = 140
VOCAB_SIZE = LANG_REGION_START + LANG_REGION_SIZE
PAD_ID = 0
BOS_ID = 1
EOS_ID = 2
# Language code -> dedicated language-token id.
# NOTE(review): the offsets (he=0, en=1, es=2, de=8, it=9) look model-defined;
# confirm against the exported checkpoint before adding languages.
LANG_ID: dict[str, int] = {
    "en": LANG_REGION_START + 1,
    "he": LANG_REGION_START + 0,
    "es": LANG_REGION_START + 2,
    "de": LANG_REGION_START + 8,
    "it": LANG_REGION_START + 9,
}
LANG_NAMES: dict[int, str] = {v: k for k, v in LANG_ID.items()}
# Accepted alternate spellings of language codes.
LANG_CODE_ALIASES: dict[str, str] = {"ge": "de"}
CHAR_TO_ID: dict[str, int] = {**_PIPER_MAP, **_EXTENDED_MAP}
ID_TO_CHAR: dict[int, str] = {v: k for k, v in CHAR_TO_ID.items()}
# Make language-token ids round-trip to a readable "<lang>" marker.
for _lang_name, _lang_idx in LANG_ID.items():
    ID_TO_CHAR[_lang_idx] = f"<{_lang_name}>"
def normalize_text(text: str, lang: str = "en") -> str:
    """Canonicalise *text* for the vocabulary.

    NFD-decomposes the string, folds curly quotes and dash variants onto plain
    ASCII, collapses whitespace, and (for Hebrew) maps ``r``/``g`` to their IPA
    symbols ``ʁ``/``ɡ``.
    """
    cleaned = uni_normalize("NFD", text.strip())
    # Typographic variants folded onto the plain characters present in the vocab.
    for fancy, plain in (
        ("\u201c", '"'), ("\u201d", '"'),
        ("\u2018", "'"), ("\u2019", "'"),
        ("´", "'"), ("`", "'"),
        ("–", "-"), ("‑", "-"), ("—", "-"),
    ):
        cleaned = cleaned.replace(fancy, plain)
    if lang == "he":
        # Hebrew uses the uvular fricative and IPA script g, not ASCII r/g.
        cleaned = cleaned.replace("r", "ʁ").replace("g", "ɡ")
    return re.sub(r"\s+", " ", cleaned).strip()
def text_to_indices(text: str, lang: str = "en") -> list[int]:
    """Encode *text* as vocabulary ids, prefixed with the language token.

    Characters outside the vocabulary map to ``PAD_ID``. Raises ``ValueError``
    for a language code that is neither known nor aliased.
    """
    canonical = LANG_CODE_ALIASES.get(lang, lang)
    if canonical not in LANG_ID:
        raise ValueError(f"Unknown language '{canonical}'")
    ids = [LANG_ID[canonical]]
    for ch in text:
        ids.append(CHAR_TO_ID.get(ch, PAD_ID))
    return ids
def text_to_indices_multilang(text: str, base_lang: str = "en") -> list[int]:
    """Encode text that may embed ``<xx>...</xx>`` language spans.

    A language-id token is emitted whenever the active language changes;
    unknown tags fall back to *base_lang*, and characters outside the
    vocabulary map to ``PAD_ID``.
    """
    base_lang = LANG_CODE_ALIASES.get(base_lang, base_lang)
    if base_lang not in LANG_ID:
        raise ValueError(f"Unknown language '{base_lang}'")
    if "<" not in text:
        # Fast path: no inline tags at all.
        return text_to_indices(text, lang=base_lang)
    # Decompose into (lang, fragment) pairs; the regex also tolerates the
    # sloppy <xx>...<xx> closing form.
    pieces: list[tuple[str, str]] = []
    cursor = 0
    for match in re.finditer(r"<(\w+)>(.*?)(?:</\1>|<\1>)", text, flags=re.DOTALL):
        if match.start() > cursor:
            pieces.append((base_lang, text[cursor:match.start()]))
        tagged = LANG_CODE_ALIASES.get(match.group(1), match.group(1))
        pieces.append((tagged if tagged in LANG_ID else base_lang, match.group(2)))
        cursor = match.end()
    if cursor < len(text):
        pieces.append((base_lang, text[cursor:]))
    ids = [LANG_ID[base_lang]]
    active = base_lang
    for seg_lang, fragment in pieces:
        if seg_lang != active:
            ids.append(LANG_ID.get(seg_lang, LANG_ID[base_lang]))
            active = seg_lang
        ids.extend(CHAR_TO_ID.get(ch, PAD_ID) for ch in fragment)
    return ids
# Max IPA characters per synthesis forward pass (ONNX). Independent of Renikud clause splitting.
BLUE_SYNTH_MAX_CHUNK_LEN = 150
# ============================================================
# Text Processing & Chunking
# ============================================================
@dataclass
class Style:
    """Container for precomputed voice-style tensors loaded from JSON."""
    # Style embedding consumed by the text-to-latent model.
    ttl: Any
    # Optional style embedding for the duration predictor; None when absent.
    dp: Optional[Any] = None
def _hard_split_chunk(s: str, max_len: int) -> List[str]:
"""Split ``s`` into segments of at most ``max_len`` chars (prefer last space)."""
s = s.strip()
if not s or max_len <= 0:
return [s] if s else []
if len(s) <= max_len:
return [s]
out: List[str] = []
start = 0
n = len(s)
while start < n:
end = min(start + max_len, n)
if end < n:
window = s[start:end]
cut = window.rfind(" ")
if cut > max(max_len // 4, 8):
end = start + cut
piece = s[start:end].strip()
if piece:
out.append(piece)
start = end
while start < n and s[start] == " ":
start += 1
return out
def _split_oversized_hebrew_clause(part: str, max_clause_chars: int) -> List[str]:
"""Only used when a single sentence is longer than ``max_clause_chars``."""
p = part.strip()
if not p:
return []
if len(p) <= max_clause_chars:
return [p]
if re.search(r":\s", p):
pieces = [x.strip() for x in re.split(r"(?<=:)\s+", p) if x.strip()]
if len(pieces) > 1:
out: List[str] = []
for x in pieces:
out.extend(_split_oversized_hebrew_clause(x, max_clause_chars))
return out
if re.search(r"[\u0590-\u05ff]-\s+[\u0590-\u05ff]", p):
pieces = [x.strip() for x in re.split(r"(?<=[\u0590-\u05ff])-\s+", p) if x.strip()]
if len(pieces) > 1:
out2: List[str] = []
for x in pieces:
out2.extend(_split_oversized_hebrew_clause(x, max_clause_chars))
return out2
if re.search(r",\s", p):
pieces = [x.strip() for x in re.split(r",\s+", p) if x.strip()]
if len(pieces) > 1:
out3: List[str] = []
for x in pieces:
out3.extend(_split_oversized_hebrew_clause(x, max_clause_chars))
return out3
return _hard_split_chunk(p, max_clause_chars)
def _split_hebrew_prephoneme(text: str, max_clause_chars: int = 96) -> List[str]:
"""Split raw Hebrew before Renikud G2P.
By default only sentence boundaries (``.?!``); colon / hyphen / comma splits run
only when one sentence is longer than ``max_clause_chars``.
"""
t = text.strip()
if not t:
return []
t = re.sub(r"\.+", ".", t)
t = re.sub(r"\?+", "?", t)
t = re.sub(r"!+", "!", t)
t = t.replace("…", ".")
t = re.sub(r"\s+", " ", t)
def refine_one(s: str) -> List[str]:
s = s.strip()
if not s:
return []
out: List[str] = []
for sent in re.split(r"(?<=[.!?])\s+", s):
sent = sent.strip()
if not sent:
continue
out.extend(_split_oversized_hebrew_clause(sent, max_clause_chars))
return out
clauses: List[str] = []
for block in re.split(r"\n+", t):
block = block.strip()
if block:
clauses.extend(refine_one(block))
return clauses if clauses else [t]
def chunk_text(text: str, max_len: int = 300) -> List[str]:
    """Split IPA/text into sentence-boundary chunks no longer than max_len chars.

    Inline ``<xx>``/``</xx>`` language tags are kept balanced per chunk: a tag
    still open at a chunk boundary is closed there and re-opened on the next
    chunk.
    """
    # Force a paragraph break after a sentence that ends just before a closing
    # language tag, so tagged spans are not merged across sentences.
    text = re.sub(r"([.!?])(</[a-z]{2,8}>)\s+", r"\1\2\n\n", text)
    # Sentence split on .!? + space, guarded against common abbreviations and
    # single capital initials.
    pattern = (
        r"(?<!Mr\.)(?<!Mrs\.)(?<!Ms\.)(?<!Dr\.)(?<!Prof\.)(?<!Sr\.)(?<!Jr\.)"
        r"(?<!Ph\.D\.)(?<!etc\.)(?<!e\.g\.)(?<!i\.e\.)(?<!vs\.)(?<!Inc\.)"
        r"(?<!Ltd\.)(?<!Co\.)(?<!Corp\.)(?<!St\.)(?<!Ave\.)(?<!Blvd\.)"
        r"(?<!\b[A-Z]\.)(?<=[.!?])\s+"
    )
    chunks: List[str] = []
    for paragraph in re.split(r"\n\s*\n+", text.strip()):
        paragraph = paragraph.strip()
        if not paragraph:
            continue
        current = ""
        # Greedily pack sentences into chunks of at most max_len characters.
        for sentence in re.split(pattern, paragraph):
            if len(current) + len(sentence) + 1 <= max_len:
                current += (" " if current else "") + sentence
            else:
                if current:
                    chunks.append(current.strip())
                if len(sentence) > max_len:
                    chunks.extend(_hard_split_chunk(sentence, max_len))
                    current = ""
                else:
                    current = sentence
        if current:
            chunks.append(current.strip())
    base = chunks if chunks else ([text.strip()] if text.strip() else [])
    # Second pass: enforce max_len even on chunks the packing step produced.
    out: List[str] = []
    for c in base:
        out.extend(_hard_split_chunk(c, max_len))
    # Re-balance language tags across chunk boundaries.
    fixed_out = []
    active_tag = None
    for c in out:
        c = c.strip()
        if not c:
            continue
        if active_tag and not c.startswith(f"<{active_tag}>"):
            c = f"<{active_tag}>" + c
        for m in re.finditer(r"<(/)?([a-z]{2,8})>", c):
            is_close = bool(m.group(1))
            tag = m.group(2)
            if is_close:
                if active_tag == tag:
                    active_tag = None
            else:
                active_tag = tag
        if active_tag and not c.endswith(f"</{active_tag}>"):
            c = c + f"</{active_tag}>"
        fixed_out.append(c)
    return fixed_out or ([text.strip()] if text.strip() else [])
class TextProcessor:
    """Converts raw text into IPA phoneme strings.

    Hebrew goes through the Renikud ONNX G2P model; other supported languages
    go through espeak-ng (phonemizer backend first, then a subprocess
    fallback). Inline ``<xx>...</xx>`` spans are phonemized per language.
    """

    # Project language codes -> espeak-ng voice names.
    _ESPEAK_MAP = {
        "en": "en-us", "en-us": "en-us", "de": "de", "ge": "de", "it": "it",
        "es": "es",
    }
    # Matches <xx>...</xx> (or the sloppy <xx>...<xx> closing form).
    _INLINE_LANG_PAIR = re.compile(r"<(\w+)>(.*?)(?:</\1>|<\1>)", re.DOTALL)

    def __init__(
        self,
        renikud_path: Optional[str] = None,
        *,
        renikud_max_clause_chars: int = 96,
    ):
        """Load the optional Renikud G2P model and set up espeak-ng.

        renikud_path: path to the Renikud ONNX weights; when None, a local
            ``model.onnx`` is used if present.
        renikud_max_clause_chars: max clause length fed to Renikud per call.
        """
        self.renikud = None
        self._renikud_max_clause_chars = renikud_max_clause_chars
        self._espeak_backends: Dict[str, Any] = {}
        self._espeak_separator: Any = None
        self._espeak_ready = False
        if renikud_path is None and os.path.exists("model.onnx"):
            renikud_path = "model.onnx"
        self._renikud_path = renikud_path
        if renikud_path and os.path.exists(renikud_path):
            try:
                from renikud_onnx import G2P
                self.renikud = G2P(renikud_path)
                print(f"[INFO] Loaded Renikud G2P from {renikud_path}")
            except ImportError as e:
                raise RuntimeError(
                    "Hebrew G2P needs the `renikud-onnx` package. Install project deps: uv sync"
                ) from e
        self._init_espeak()

    def _init_espeak(self):
        """Set up the espeak-ng library path once at startup (cross-platform via espeakng-loader)."""
        if self._espeak_ready:
            return
        try:
            import espeakng_loader
            from phonemizer.backend.espeak.wrapper import EspeakWrapper
            from phonemizer.separator import Separator
            EspeakWrapper.set_library(espeakng_loader.get_library_path())
            EspeakWrapper.set_data_path(espeakng_loader.get_data_path())
            # No phone/syllable separators: output is a plain IPA stream.
            self._espeak_separator = Separator(phone="", word=" ", syllable="")
            self._espeak_ready = True
            print("[INFO] espeak-ng initialised via espeakng-loader")
        except Exception as e:
            # Non-fatal: _espeak_phonemize falls back to the binary, then to
            # returning the input unchanged.
            print(f"[WARN] espeak-ng setup failed: {e}")

    def _get_espeak_backend(self, espeak_lang: str) -> Any:
        """Return a cached EspeakBackend for *espeak_lang*, creating it on first use."""
        if espeak_lang not in self._espeak_backends:
            from phonemizer.backend import EspeakBackend
            print(f"[INFO] Loading espeak backend for '{espeak_lang}'…")
            self._espeak_backends[espeak_lang] = EspeakBackend(
                espeak_lang, preserve_punctuation=True,
                with_stress=True, language_switch="remove-flags",
            )
            print(f"[INFO] espeak backend for '{espeak_lang}' ready")
        return self._espeak_backends[espeak_lang]

    def _hebrew_requires_renikud_error(self) -> ValueError:
        """Build (not raise) the error used when Hebrew input has no G2P model."""
        return ValueError(
            "Hebrew text requires the Renikud ONNX weights (not bundled with the wheel). "
            f"Download: https://huggingface.co/thewh1teagle/renikud/resolve/main/model.onnx\n"
            "Then pass renikud_path='model.onnx' (or an absolute path) to the TTS class. "
            "The `renikud-onnx` PyPI package is a project dependency."
        )

    def _espeak_phonemize(self, text: str, lang: str) -> str:
        """Phonemize *text* via espeak-ng; return *text* unchanged on failure.

        Order: phonemizer library backend, then the espeak-ng binary via
        subprocess, then give up and pass the text through.
        """
        espeak_lang = self._ESPEAK_MAP.get(lang)
        if espeak_lang is None:
            # Language has no espeak mapping: pass through untouched.
            return text
        if not self._espeak_ready:
            self._init_espeak()
        if self._espeak_ready:
            try:
                backend = self._get_espeak_backend(espeak_lang)
                raw = backend.phonemize(
                    [text], separator=self._espeak_separator
                )[0]
                return normalize_text(raw, lang=lang)
            except Exception as e:
                print(f"[WARN] Phonemizer backend failed for lang={lang}: {e}")
        # Fallback: shell out to the espeak-ng binary directly.
        try:
            result = subprocess.run(
                ["espeak-ng", "-q", "--ipa=1", "-v", espeak_lang, text],
                check=True,
                capture_output=True,
                text=True,
            )
            raw = result.stdout.replace("\n", " ").strip()
            return normalize_text(raw, lang=lang)
        except Exception as e:
            print(f"[WARN] espeak-ng fallback failed for lang={lang}: {e}")
        return text

    def _phonemize_segment(self, content: str, lang: str) -> str:
        """Phonemize one single-language segment.

        Hebrew characters anywhere in the segment force the Renikud path,
        regardless of the declared *lang*; unknown languages fall back to "en".
        """
        content = content.strip()
        if not content:
            return ""
        lang = LANG_CODE_ALIASES.get(lang, lang)
        if lang not in LANG_ID:
            lang = "en"
        has_hebrew = any("\u0590" <= c <= "\u05ff" for c in content)
        if has_hebrew:
            if self.renikud is None:
                raise self._hebrew_requires_renikud_error()
            clauses = _split_hebrew_prephoneme(content, self._renikud_max_clause_chars)
            ipa_parts = [
                normalize_text(self.renikud.phonemize(c), lang="he")
                for c in clauses
                if c.strip()
            ]
            return re.sub(r"\s+", " ", " ".join(ipa_parts)).strip()
        if lang == "he":
            # Declared Hebrew but no Hebrew script: assume it is already IPA.
            return normalize_text(content, lang="he")
        return self._espeak_phonemize(content, lang)

    def _phonemize_mixed(self, text: str, base_lang: str) -> str:
        """Phonemize text containing inline language tags, keeping the tags."""
        base_lang = LANG_CODE_ALIASES.get(base_lang, base_lang)
        if base_lang not in LANG_ID:
            raise ValueError(f"Unknown base_lang {base_lang!r}. Available: {list(LANG_ID.keys())}.")
        pieces: List[str] = []
        last_end = 0
        for m in self._INLINE_LANG_PAIR.finditer(text):
            if m.start() > last_end:
                # Untagged stretch before this span: base language.
                chunk = text[last_end:m.start()]
                p = self._phonemize_segment(chunk, base_lang)
                if p:
                    pieces.append(p)
            open_tag = m.group(1)
            seg_lang = LANG_CODE_ALIASES.get(open_tag, open_tag)
            if seg_lang not in LANG_ID:
                seg_lang = base_lang
            inner_ipa = self._phonemize_segment(m.group(2), seg_lang)
            # Re-wrap with the original (unaliased) tag so downstream encoding
            # sees the same tag name.
            pieces.append(f"<{open_tag}>{inner_ipa}</{open_tag}>")
            last_end = m.end()
        if last_end < len(text):
            p = self._phonemize_segment(text[last_end:], base_lang)
            if p:
                pieces.append(p)
        return re.sub(r"\s+", " ", " ".join(pieces)).strip()

    def phonemize(self, text: str, lang: str = "en") -> str:
        """Public entry point: text -> IPA for *lang*, handling inline tags."""
        if self._INLINE_LANG_PAIR.search(text):
            return self._phonemize_mixed(text, base_lang=lang)
        is_hebrew = any("\u0590" <= c <= "\u05ff" for c in text)
        if lang == "he" or is_hebrew:
            if not is_hebrew:
                # Declared Hebrew but no Hebrew script: treat as ready IPA.
                return normalize_text(text, lang="he")
            if self.renikud is not None:
                clauses = _split_hebrew_prephoneme(text, self._renikud_max_clause_chars)
                ipa_parts = [
                    normalize_text(self.renikud.phonemize(c), lang="he")
                    for c in clauses
                    if c.strip()
                ]
                return re.sub(r"\s+", " ", " ".join(ipa_parts)).strip()
            raise self._hebrew_requires_renikud_error()
        return self._espeak_phonemize(text, lang)
# ============================================================
# BlueTTS Core
# ============================================================
class BlueTTS:
    """ONNX-based TTS pipeline.

    Stages: phonemes -> vocab ids -> text encoder -> duration predictor ->
    iterative flow-matching latent sampler (with optional classifier-free
    guidance) -> vocoder -> waveform. Long inputs are synthesized per chunk
    and joined with short silences.
    """

    def __init__(
        self,
        onnx_dir: str,
        config_path: str = "tts.json",
        style_json: Optional[str] = None,
        steps: int = 32,
        cfg_scale: float = 3.0,
        speed: float = 1.0,
        seed: int = 42,
        use_gpu: bool = False,
        chunk_len: int = BLUE_SYNTH_MAX_CHUNK_LEN,
        silence_sec: float = 0.15,
        fade_duration: float = 0.02,
        renikud_path: Optional[str] = None,
    ):
        """Load all ONNX sessions and side files.

        onnx_dir: directory with the exported ONNX models.
        config_path: tts.json with normalizer/latent/audio settings.
        style_json: optional voice-style JSON applied at inference.
        steps: flow-matching integration steps.
        cfg_scale: classifier-free-guidance scale (1.0 disables CFG).
        speed: speech-rate divisor applied to predicted durations.
        seed: RNG seed for the initial latent noise (deterministic output).
        chunk_len: max phoneme characters per forward pass.
        silence_sec: silence inserted between chunks.
        fade_duration: per-chunk fade-in/out length in seconds.
        renikud_path: Hebrew G2P weights; auto-discovered when None.
        """
        self.onnx_dir = onnx_dir
        self.style_json = style_json
        self.steps = steps
        self.cfg_scale = cfg_scale
        self.speed = speed
        self.seed = seed
        self.chunk_len = chunk_len
        self.silence_sec = silence_sec
        self.fade_duration = fade_duration
        if renikud_path is None:
            # Look in the working directory first, then next to the models.
            if os.path.exists("model.onnx"):
                renikud_path = "model.onnx"
            elif os.path.exists(os.path.join(onnx_dir, "model.onnx")):
                renikud_path = os.path.join(onnx_dir, "model.onnx")
        self._load_config(config_path)
        self._init_sessions(use_gpu)
        self._load_stats()
        self._load_uncond()
        self._load_shuffle_keys()
        self._text_proc = TextProcessor(renikud_path)

    def _load_config(self, config_path: str):
        """Set built-in defaults, then override from *config_path* if present."""
        self.normalizer_scale = 1.0
        self.latent_dim = 24
        self.chunk_compress_factor = 6
        self.hop_length = 512
        self.sample_rate = 44100
        if config_path and os.path.exists(config_path):
            with open(config_path) as f:
                cfg = json.load(f)
            self.normalizer_scale = float(cfg.get("ttl", {}).get("normalizer", {}).get("scale", self.normalizer_scale))
            self.latent_dim = int(cfg.get("ttl", {}).get("latent_dim", self.latent_dim))
            self.chunk_compress_factor = int(cfg.get("ttl", {}).get("chunk_compress_factor", self.chunk_compress_factor))
            self.sample_rate = int(cfg.get("ae", {}).get("sample_rate", self.sample_rate))
            self.hop_length = int(cfg.get("ae", {}).get("encoder", {}).get("spec_processor", {}).get("hop_length", self.hop_length))
        # Channel count of the time-compressed latent fed to the backbone.
        self.compressed_channels = self.latent_dim * self.chunk_compress_factor

    def _init_sessions(self, use_gpu: bool):
        """Create the onnxruntime sessions for every pipeline stage.

        Model filenames are probed in order of preference (e.g. backbone_keys
        before backbone before vector_estimator) to support several export
        layouts.
        """
        available = ort.get_available_providers()
        if use_gpu:
            providers = [p for p in ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"] if p in available]
        else:
            providers = [p for p in ["OpenVINOExecutionProvider", "CPUExecutionProvider"] if p in available]
        opts = ort.SessionOptions()
        opts.log_severity_level = 3
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Use a quarter of the cores by default; overridable via env vars.
        cpu_cores = max(1, (os.cpu_count() or 4) // 4)
        opts.intra_op_num_threads = int(os.environ.get("ORT_INTRA", cpu_cores))
        opts.inter_op_num_threads = int(os.environ.get("ORT_INTER", 1))
        self._opts = opts
        self._providers = providers
        self._text_enc = self._load_session("text_encoder.onnx")
        self._ref_enc = self._load_session("reference_encoder.onnx", required=False)
        vf_name = "backbone_keys.onnx" if os.path.exists(os.path.join(self.onnx_dir, "backbone_keys.onnx")) else "backbone.onnx"
        if not os.path.exists(os.path.join(self.onnx_dir, vf_name)):
            vf_name = "vector_estimator.onnx"
        self._vf_model_name = vf_name.replace(".onnx", "")
        self._vf = self._load_session(vf_name)
        self._vocoder = self._load_session("vocoder.onnx")
        dp_name = "duration_predictor.onnx" if os.path.exists(os.path.join(self.onnx_dir, "duration_predictor.onnx")) else "length_pred.onnx"
        self._dp = self._load_session(dp_name, required=False)
        self._dp_style = self._load_session("length_pred_style.onnx", required=False)
        # Inspect the backbone's declared inputs once so feeds can be built
        # to match whichever export variant was loaded.
        vf_inputs = {i.name for i in self._vf.get_inputs()}
        self._vf_inputs = vf_inputs
        self._vf_supports_style_keys = "style_keys" in vf_inputs
        self._vf_uses_text_emb = "text_emb" in vf_inputs and "text_context" not in vf_inputs

    def _load_session(self, name: str, required: bool = True) -> Optional[ort.InferenceSession]:
        """Open an ONNX session, preferring a ``.slim.onnx`` variant if present.

        Returns None for optional models that are missing; raises for
        required ones.
        """
        base = os.path.join(self.onnx_dir, name)
        slim = base.replace(".onnx", ".slim.onnx")
        path = slim if os.path.exists(slim) else base
        if not os.path.exists(path):
            if required:
                raise FileNotFoundError(f"Model not found: {base}")
            return None
        return ort.InferenceSession(path, sess_options=self._opts, providers=self._providers)

    def _load_stats(self):
        """Load latent mean/std (and optional normalizer scale) from stats.npz."""
        stats_path = os.path.join(self.onnx_dir, "stats.npz")
        self.mean = self.std = None
        if os.path.exists(stats_path):
            stats = np.load(stats_path)
            self.mean = stats["mean"].astype(np.float32)
            self.std = stats["std"].astype(np.float32)
            if self.mean.ndim == 1:
                # Broadcastable (1, C, 1) layout for per-channel normalization.
                self.mean = self.mean.reshape(1, -1, 1)
                self.std = self.std.reshape(1, -1, 1)
            if self.mean.ndim == 3:
                # Trust the stats file over the config-derived channel count.
                self.compressed_channels = int(self.mean.shape[1])
            if "normalizer_scale" in stats.files:
                self.normalizer_scale = float(stats["normalizer_scale"].item() if stats["normalizer_scale"].ndim == 0 else stats["normalizer_scale"][0])

    def _load_uncond(self):
        """Load unconditional embeddings (for CFG) from uncond.npz if present."""
        uncond_path = os.path.join(self.onnx_dir, "uncond.npz")
        self._u_text = self._u_ref = self._u_keys = self._cond_keys = None
        if os.path.exists(uncond_path):
            u = np.load(uncond_path)
            self._u_text = u["u_text"]
            self._u_ref = u["u_ref"]
            self._u_keys = u.get("u_keys") if "u_keys" in u.files else None
            self._cond_keys = u.get("cond_keys") if "cond_keys" in u.files else None

    def _load_shuffle_keys(self):
        """Load per-model extra input tensors from keys.npz ("model/input" names)."""
        self._model_keys: dict = {}
        keys_path = os.path.join(self.onnx_dir, "keys.npz")
        if not os.path.exists(keys_path):
            return
        data = np.load(keys_path)
        for k in data.files:
            parts = k.split("/", 1)
            if len(parts) == 2:
                model, inp = parts
                self._model_keys.setdefault(model, {})[inp] = data[k]

    def create(self, phonemes: str, lang: str = "en") -> Tuple[np.ndarray, int]:
        """Synthesize *phonemes* chunk by chunk; returns (waveform, sample_rate)."""
        chunks = chunk_text(phonemes, self.chunk_len)
        silence = np.zeros(int(self.silence_sec * self.sample_rate), dtype=np.float32)
        parts = []
        for i, chunk in enumerate(chunks):
            parts.append(self._infer_chunk(chunk, lang=lang))
            if i < len(chunks) - 1:
                # Short pause between chunks, but not after the last one.
                parts.append(silence)
        wav = np.concatenate(parts) if parts else np.array([], dtype=np.float32)
        return wav, self.sample_rate

    def synthesize(self, text: str, lang: str = "en") -> Tuple[np.ndarray, int]:
        """Full pipeline: raw text -> phonemes -> (waveform, sample_rate)."""
        phonemes = self._text_proc.phonemize(text, lang=lang)
        return self.create(phonemes, lang=lang)

    def _run(self, sess: ort.InferenceSession, feed: dict, model_name: str):
        """Run a session, merging in any extra inputs registered for *model_name*."""
        keys = self._model_keys.get(model_name)
        if keys:
            feed = {**feed, **keys}
        return sess.run(None, feed)

    def _load_style_json(self, path: str):
        """Read a voice-style JSON; returns (style_ttl, style_keys, style_dp, z_ref),
        each None when absent. 2-D arrays gain a leading batch axis."""
        with open(path) as f:
            j = json.load(f)
        def _arr(key):
            if key not in j:
                return None
            a = np.array(j[key]["data"], dtype=np.float32)
            return a[None] if a.ndim == 2 else a
        style_ttl = _arr("style_ttl")
        style_keys = _arr("style_keys")
        style_dp = _arr("style_dp")
        z_ref = _arr("z_ref")
        return style_ttl, style_keys, style_dp, z_ref

    def _extract_style(self, z_ref_norm: np.ndarray):
        """Run the reference encoder on a normalized reference latent.

        The reference is padded or truncated to a fixed 256 frames, with a
        mask marking the valid portion. Returns (ref_values, ref_keys).
        """
        if self._ref_enc is None:
            raise ValueError("Reference encoder not loaded.")
        TARGET = 256
        B, C, T = z_ref_norm.shape
        if T < TARGET:
            pad = TARGET - T
            z = np.pad(z_ref_norm, ((0, 0), (0, 0), (0, pad)))
            mask = np.zeros((B, 1, TARGET), dtype=np.float32)
            mask[:, :, :T] = 1.0
        else:
            z = z_ref_norm[:, :, :TARGET]
            mask = np.ones((B, 1, TARGET), dtype=np.float32)
        # The mask input name varies between exports; probe the declared names.
        ref_names = [i.name for i in self._ref_enc.get_inputs()]
        feed = {"z_ref": z}
        if "mask" in ref_names:
            feed["mask"] = mask
        elif "ref_mask" in ref_names:
            feed["ref_mask"] = mask
        elif len(ref_names) >= 2:
            feed[ref_names[1]] = mask
        ref_values, ref_keys = self._run(self._ref_enc, feed, "reference_encoder")[:2]
        return ref_values, ref_keys

    def _infer_chunk(self, phonemes: str, lang: str = "en") -> np.ndarray:
        """Synthesize one phoneme chunk end-to-end into a waveform array."""
        if self.mean is None or self.std is None:
            raise ValueError("stats.npz not loaded.")
        style_ttl = style_keys = style_dp = z_ref = None
        if self.style_json:
            style_ttl, style_keys, style_dp, z_ref = self._load_style_json(self.style_json)
        if z_ref is None and style_ttl is None:
            raise ValueError("Provide style_json with z_ref or style_ttl content.")
        # Duration predictor gets the tag-stripped text; the text encoder gets
        # the full multi-language token sequence.
        text_plain = re.sub(r"</?[a-z]{2,8}>", "", phonemes)
        indices_dp = text_to_indices(text_plain, lang=lang)
        ids_dp = np.array([indices_dp], dtype=np.int64)
        mask_dp = np.ones((1, 1, len(indices_dp)), dtype=np.float32)
        indices_full = text_to_indices_multilang(phonemes, base_lang=lang)
        text_ids = np.array([indices_full], dtype=np.int64)
        text_mask = np.ones((1, 1, len(indices_full)), dtype=np.float32)
        z_ref_norm = None
        if z_ref is not None:
            z_ref_norm = ((z_ref - self.mean) / self.std) * float(self.normalizer_scale)
            # Drop the trailing ~5% of the reference (at least 2 frames) and
            # cap it at 150 frames.
            T = z_ref_norm.shape[2]
            tail = max(2, int(T * 0.05))
            z_ref_norm = z_ref_norm[:, :, : max(1, T - tail)]
            if z_ref_norm.shape[2] > 150:
                z_ref_norm = z_ref_norm[:, :, :150]
        if style_ttl is not None:
            ref_values = style_ttl
        else:
            ref_values, style_keys = self._extract_style(z_ref_norm)
        if ref_values.ndim == 2:
            ref_values = ref_values[None]
        if style_keys is not None and style_keys.ndim == 2:
            style_keys = style_keys[None]
        ref_keys = style_keys if style_keys is not None else ref_values
        # Build the text-encoder feed against whichever input names this
        # export declares.
        te_names = {i.name for i in self._text_enc.get_inputs()}
        te_feed = {"text_ids": text_ids}
        if "text_mask" in te_names:
            te_feed["text_mask"] = text_mask
        if "style_ttl" in te_names:
            te_feed["style_ttl"] = ref_values
        elif "ref_values" in te_names:
            te_feed["ref_values"] = ref_values
        else:
            raise ValueError("Unknown text encoder input names.")
        if "ref_keys" in te_names:
            te_feed["ref_keys"] = ref_keys
        elif "used_ref_keys" in te_names:
            te_feed["used_ref_keys"] = ref_keys
        text_emb = self._run(self._text_enc, te_feed, "text_encoder")[0]
        T_lat = self._predict_duration(ids_dp, mask_dp, z_ref_norm, style_dp)
        x = self._flow_matching(text_emb, ref_values, text_mask, T_lat)
        return self._decode(x)

    def _predict_duration(self, text_ids, text_mask, z_ref_norm, style_dp) -> int:
        """Predict the latent length (frames) for the chunk.

        Tries the style duration predictor, then the reference-based one, then
        a heuristic of 1.3 frames per token. The result is speed-scaled and
        clamped to [10, min(3*len+20, 600, 800)].
        """
        T_lat = None
        if style_dp is not None and self._dp_style is not None:
            out = self._run(self._dp_style, {"text_ids": text_ids, "style_dp": style_dp, "text_mask": text_mask}, "length_pred_style")
            val = float(np.squeeze(out[0]))
            if np.isfinite(val):
                T_lat = int(np.round(val / max(self.speed, 1e-6)))
        if T_lat is None and z_ref_norm is not None and self._dp is not None:
            ref_len = int(z_ref_norm.shape[2])
            out = self._run(self._dp, {
                "text_ids": text_ids,
                "z_ref": z_ref_norm.astype(np.float32),
                "text_mask": text_mask,
                "ref_mask": np.ones((1, 1, ref_len), dtype=np.float32),
            }, "length_pred")
            val = float(np.squeeze(out[0]))
            if np.isfinite(val):
                T_lat = int(np.round(val / max(self.speed, 1e-6)))
        if T_lat is None:
            # Heuristic fallback when no duration model produced a finite value.
            T_lat = int(text_ids.shape[1] * 1.3)
        txt_len = int(np.sum(text_mask))
        T_cap = max(20, min(txt_len * 3 + 20, 600))
        T_lat = min(max(int(T_lat), 1), T_cap, 800)
        return max(10, T_lat)

    def _flow_matching(self, text_emb, ref_values, text_mask, T_lat) -> np.ndarray:
        """Iteratively refine a noise latent with the backbone over self.steps.

        Each step feeds the current latent and conditioning; with cfg_scale !=
        1.0 and unconditional embeddings available, classifier-free guidance
        blends conditional and unconditional predictions. The backbone output
        replaces the latent each step (the model integrates internally via
        current_step/total_step).
        """
        # Seeded RNG makes the initial noise (and thus output) deterministic.
        rng = np.random.RandomState(self.seed)
        x = rng.randn(1, self.compressed_channels, T_lat).astype(np.float32)
        latent_mask = np.ones((1, 1, T_lat), dtype=np.float32)
        vf_inputs = self._vf_inputs
        cond_keys = None
        if self._vf_supports_style_keys and self._cond_keys is not None:
            cond_keys = self._cond_keys.astype(np.float32)
            if cond_keys.ndim == 2:
                cond_keys = cond_keys[None]
        u_text = self._u_text.astype(np.float32) if self._u_text is not None else None
        u_ref = self._u_ref.astype(np.float32) if self._u_ref is not None else None
        u_keys = self._u_keys.astype(np.float32) if self._u_keys is not None else None
        u_text_mask = np.ones((1, 1, 1), dtype=np.float32)
        for i in range(self.steps):
            t_val = np.array([float(i)], dtype=np.float32)
            total_t = np.array([float(self.steps)], dtype=np.float32)
            # Assemble the conditional feed from whatever inputs this backbone
            # export declares.
            feed: dict = {}
            if "noisy_latent" in vf_inputs:
                feed["noisy_latent"] = x
            if "text_emb" in vf_inputs:
                feed["text_emb"] = text_emb
            elif "text_context" in vf_inputs:
                feed["text_context"] = text_emb
            if "style_ttl" in vf_inputs:
                feed["style_ttl"] = ref_values
            elif "ref_values" in vf_inputs:
                feed["ref_values"] = ref_values
            if "latent_mask" in vf_inputs:
                feed["latent_mask"] = latent_mask
            if "text_mask" in vf_inputs:
                feed["text_mask"] = text_mask
            if "current_step" in vf_inputs:
                feed["current_step"] = t_val
            if "total_step" in vf_inputs:
                feed["total_step"] = total_t
            if "style_keys" in vf_inputs and cond_keys is not None:
                feed["style_keys"] = cond_keys
            if "style_mask" in vf_inputs:
                feed["style_mask"] = np.ones((1, 1, ref_values.shape[1]), dtype=np.float32)
            den_cond = self._run(self._vf, feed, self._vf_model_name)[0]
            if self.cfg_scale != 1.0 and u_text is not None:
                # Classifier-free guidance: rerun with unconditional embeddings
                # and extrapolate toward the conditional prediction.
                feed_u = dict(feed)
                if "text_emb" in vf_inputs:
                    feed_u["text_emb"] = u_text
                elif "text_context" in vf_inputs:
                    feed_u["text_context"] = u_text
                if "style_ttl" in vf_inputs:
                    feed_u["style_ttl"] = u_ref
                elif "ref_values" in vf_inputs:
                    feed_u["ref_values"] = u_ref
                if "text_mask" in vf_inputs:
                    feed_u["text_mask"] = u_text_mask
                if "style_keys" in vf_inputs:
                    feed_u["style_keys"] = u_keys
                if "style_mask" in vf_inputs:
                    feed_u["style_mask"] = np.ones((1, 1, u_ref.shape[1]), dtype=np.float32)
                den_uncond = self._run(self._vf, feed_u, self._vf_model_name)[0]
                x = den_uncond + self.cfg_scale * (den_cond - den_uncond)
            else:
                x = den_cond
        return x

    def _apply_fade(self, wav: np.ndarray) -> np.ndarray:
        """Apply linear fade-in/out of self.fade_duration seconds to *wav*.

        Returns the input untouched when it is too short to fade on both ends.
        """
        fade_samples = int(self.fade_duration * self.sample_rate)
        if fade_samples == 0 or len(wav) < 2 * fade_samples:
            return wav
        wav = wav.copy()
        wav[:fade_samples] *= np.linspace(0.0, 1.0, fade_samples, dtype=np.float32)
        wav[-fade_samples:] *= np.linspace(1.0, 0.0, fade_samples, dtype=np.float32)
        return wav

    def _decode(self, z_pred: np.ndarray) -> np.ndarray:
        """De-normalize and de-compress the predicted latent, then vocode it.

        The (B, latent_dim*factor, T) latent is unfolded back to
        (B, latent_dim, T*factor) before the vocoder. Five hop-lengths of
        audio are trimmed from each edge to hide vocoder boundary artifacts,
        then a short fade is applied.
        """
        if float(self.normalizer_scale) not in (0.0, 1.0):
            z_unnorm = (z_pred / float(self.normalizer_scale)) * self.std + self.mean
        else:
            z_unnorm = z_pred * self.std + self.mean
        B, C, T = z_unnorm.shape
        z_dec = (
            z_unnorm.reshape(B, self.latent_dim, self.chunk_compress_factor, T)
            .transpose(0, 1, 3, 2)
            .reshape(B, self.latent_dim, T * self.chunk_compress_factor)
        )
        wav = self._run(self._vocoder, {"latent": z_dec}, "vocoder")[0]
        frame_len = int(self.hop_length * 5)
        if wav.shape[-1] > 2 * frame_len:
            wav = wav[..., frame_len:-frame_len]
        wav = wav.squeeze()
        return self._apply_fade(wav)
def load_voice_style(style_paths: List[str]) -> Style:
    """Stack one or more exported voice-style JSON files into a batched Style.

    Tensor dimensions are read from the first file; all files are assumed to
    share them. ``dp`` stays None when the first file has no ``style_dp``.
    """
    count = len(style_paths)
    with open(style_paths[0]) as fh:
        head = json.load(fh)
    t_dims = head["style_ttl"]["dims"]
    ttl_batch = np.zeros([count, t_dims[1], t_dims[2]], dtype=np.float32)
    dp_batch: Optional[np.ndarray] = None
    if "style_dp" in head:
        d_dims = head["style_dp"]["dims"]
        dp_batch = np.zeros([count, d_dims[1], d_dims[2]], dtype=np.float32)
    for idx, json_path in enumerate(style_paths):
        with open(json_path) as fh:
            payload = json.load(fh)
        ttl_batch[idx] = np.array(payload["style_ttl"]["data"], dtype=np.float32).reshape(t_dims[1], t_dims[2])
        if dp_batch is not None and "style_dp" in payload:
            dp_batch[idx] = np.array(payload["style_dp"]["data"], dtype=np.float32).reshape(d_dims[1], d_dims[2])
    return Style(ttl=ttl_batch, dp=dp_batch)
# ============================================================
# Gradio App Logic
# ============================================================
# Path to the Renikud Hebrew G2P ONNX weights.
RENIKUD_PATH = "renikud.onnx"
# Directory holding the BlueTTS ONNX model files.
ONNX_MODELS_DIR = "onnx_models"
# UI voice name -> voice-style JSON path.
VOICES = {
    "Female": "voices/female1.json",
    "Male": "voices/male1.json",
}
# One BlueTTS instance per voice, loaded eagerly at startup.
tts_models = {name: BlueTTS(ONNX_MODELS_DIR, style_json=path, renikud_path=RENIKUD_PATH) for name, path in VOICES.items()}
def expand_numbers(text: str, lang: str = "en") -> str:
    """Spell out every digit run in *text* as words for *lang*.

    Best-effort: if num2words rejects the language (or anything else goes
    wrong), the input is returned unchanged.
    """
    def _spell(match):
        return num2words(int(match.group()), lang=lang)

    try:
        return re.sub(r"\d+", _spell, text)
    except Exception:
        # Deliberate pass-through: synthesis should proceed even when number
        # expansion is unavailable for this language.
        return text
def synthesize_text(text: str, voice: str, lang: str, steps: int = 8, speed: float = 1.0):
    """Gradio handler: synthesize *text* and return ((sr, wav), stats-HTML)."""
    started = time.time()
    model = tts_models[voice]
    # NOTE(review): this mutates shared per-voice model state; concurrent
    # requests with different steps/speed could race — confirm queueing.
    model.steps, model.speed = steps, speed
    wav, sr = model.synthesize(expand_numbers(text, lang=lang), lang=lang)
    elapsed = time.time() - started
    duration = len(wav) / sr if len(wav) > 0 else 0.0
    realtime_factor = elapsed / duration if duration > 0 else 0
    return (sr, wav), _stats_html(elapsed, duration, realtime_factor)
def _stats_html(proc_time, audio_dur, rtf):
    """Render processing time, audio duration, and real-time factor as HTML pills."""
    return f"""
    <div class="stats-bar">
        <span class="stat-pill">⏱ {proc_time:.2f}s</span>
        <span class="stat-pill">🔊 {audio_dur:.1f}s audio</span>
        <span class="stat-pill">⚡ {rtf:.2f}x RTF</span>
    </div>"""
# (text, voice, language) triples shown as clickable examples in the UI.
EXAMPLES = [
    ["The power to change begins the moment you believe it's possible!", "Female", "en"],
    ["הכוח לשנות מתחיל ברגע שבו אתה מאמין שזה אפשרי!", "Male", "he"],
    ["¡El poder de cambiar comienza en el momento en que crees que es posible!", "Female", "es"],
    ["Il potere di cambiare inizia nel momento in cui credi che sia possibile!", "Male", "it"],
    ["Die Kraft zur Veränderung beginnt in dem Moment, in dem du glaubst, dass es möglich ist!", "Female", "de"],
]
def _load_font_face() -> str:
font_path = "fonts/EuclidCircularB.woff2"
if os.path.exists(font_path):
with open(font_path, "rb") as f:
b64 = base64.b64encode(f.read()).decode()
return f"""@font-face {{
font-family: 'EuclidCircularB';
src: url(data:font/woff2;base64,{b64}) format('woff2');
font-weight: 100 900;
font-style: normal;
}}"""
return ""
# Page stylesheet: optional embedded font face + dark-theme overrides for the
# Gradio widgets (header, card, inputs, sliders, button, stats pills, examples
# table).  Applied to the app via gr.Blocks / launch.
css = _load_font_face() + """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;500&display=swap');
* { box-sizing: border-box; }
body, .gradio-container {
background: #0a0a0f !important;
font-family: 'EuclidCircularB', sans-serif !important;
color: #e8e8f0 !important;
}
.gradio-container { max-width: 900px !important; margin: 0 auto !important; padding: 2rem 1.5rem !important; }
/* Header */
.app-header { text-align: center; margin-bottom: 2.5rem; padding: 2rem 0 1rem; }
.app-header h1 {
font-size: 2.8rem; font-weight: 600; letter-spacing: -0.03em;
background: linear-gradient(135deg, #60a5fa 0%, #a78bfa 50%, #34d399 100%);
-webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text;
margin: 0 0 0.5rem;
}
.app-header p { color: #6b7280; font-size: 1rem; margin: 0 0 1rem; }
.app-header .github-link {
display: inline-flex; align-items: center; gap: 0.4rem;
margin-top: 0.75rem; padding: 0.45rem 1rem;
font-size: 0.9rem; font-weight: 500; text-decoration: none !important;
color: #93c5fd !important; border: 1px solid #2a3f5c; border-radius: 999px;
background: rgba(96, 165, 250, 0.08); transition: background 0.15s, border-color 0.15s, color 0.15s;
}
.app-header .github-link:hover {
color: #bfdbfe !important; border-color: #60a5fa; background: rgba(96, 165, 250, 0.14);
}
/* Card */
.card {
background: #111118;
border: 1px solid #1e1e2e;
border-radius: 16px;
padding: 1.5rem;
margin-bottom: 1rem;
}
/* Textarea */
.big-input textarea {
background: #0d0d14 !important;
border: 1px solid #2a2a3e !important;
border-radius: 10px !important;
color: #e8e8f0 !important;
font-size: 1.1rem !important;
font-family: 'Inter', sans-serif !important;
line-height: 1.6 !important;
padding: 1rem !important;
resize: vertical !important;
transition: border-color 0.2s !important;
unicode-bidi: plaintext !important;
}
.big-input textarea:focus {
border-color: #60a5fa !important;
outline: none !important;
box-shadow: 0 0 0 3px rgba(96,165,250,0.1) !important;
}
/* Shared label style */
.gradio-textbox label span,
.gradio-dropdown label span,
.gradio-slider label span {
color: #9ca3af !important;
font-size: 0.75rem !important;
font-weight: 600 !important;
text-transform: uppercase !important;
letter-spacing: 0.06em !important;
}
/* ── Controls rows ─────────────────────────────────────── */
.controls-row {
margin-top: 1rem;
display: flex !important;
flex-direction: column !important;
gap: 0.75rem !important;
}
/* Row 1: Language + Voice side by side */
.ctrl-row1,
.ctrl-row2 {
display: flex !important;
flex-direction: row !important;
gap: 0.75rem !important;
align-items: flex-start !important;
width: 100% !important;
}
/* Language dropdown takes ~40%, Voice takes ~60% */
.ctrl-lang { flex: 2 !important; min-width: 0 !important; }
.ctrl-voice { flex: 3 !important; min-width: 0 !important; }
/* Quality + Speed each take 50% */
.ctrl-steps,
.ctrl-speed { flex: 1 !important; min-width: 0 !important; }
/* Dropdown styling */
.ctrl-lang .gradio-dropdown > label > div,
.ctrl-lang .gradio-dropdown > label > div > div,
.ctrl-voice .gradio-dropdown > label > div,
.ctrl-voice .gradio-dropdown > label > div > div {
background: #0d0d14 !important;
border: 1px solid #2a2a3e !important;
border-radius: 8px !important;
color: #e8e8f0 !important;
}
/* Sliders */
.ctrl-steps .gradio-slider,
.ctrl-speed .gradio-slider { width: 100% !important; }
input[type=range] { accent-color: #60a5fa !important; }
/* Generate button */
.gen-btn {
background: linear-gradient(135deg, #3b82f6, #8b5cf6) !important;
border: none !important;
border-radius: 10px !important;
color: #fff !important;
font-size: 1rem !important;
font-weight: 600 !important;
padding: 0.75rem 2rem !important;
cursor: pointer !important;
transition: opacity 0.2s, transform 0.1s !important;
width: 100% !important;
margin-top: 1rem !important;
letter-spacing: 0.02em !important;
}
.gen-btn:hover { opacity: 0.85 !important; transform: translateY(-1px) !important; }
.gen-btn:active { transform: translateY(0) !important; }
/* Audio output */
.gradio-audio { background: #111118 !important; border: 1px solid #1e1e2e !important; border-radius: 12px !important; }
/* Stats bar */
.stats-bar {
display: flex; gap: 0.75rem; flex-wrap: wrap;
margin-top: 0.75rem; padding: 0.75rem 0;
}
.stat-pill {
background: #1a1a2e; border: 1px solid #2a2a4e;
border-radius: 20px; padding: 0.3rem 0.9rem;
font-family: 'JetBrains Mono', monospace;
font-size: 0.8rem; color: #a78bfa;
}
/* Examples */
.examples-section { margin-top: 1.5rem; }
.examples-section h3 { color: #6b7280; font-size: 0.8rem; font-weight: 500; text-transform: uppercase; letter-spacing: 0.1em; margin-bottom: 0.75rem; }
.label-wrap span { color: #6b7280 !important; font-size: 0.78rem !important; font-weight: 500 !important; text-transform: uppercase !important; letter-spacing: 0.08em !important; }
table.examples { width: 100% !important; border-collapse: separate !important; border-spacing: 0 4px !important; }
table.examples thead tr th { color: #4b5563 !important; font-size: 0.72rem !important; font-weight: 600 !important; text-transform: uppercase !important; letter-spacing: 0.08em !important; padding: 0.25rem 0.75rem !important; }
table.examples td { padding: 0.55rem 0.75rem !important; font-size: 0.9rem !important; color: #c4c4d4 !important; cursor: pointer !important; background: #111118 !important; border-top: 1px solid #1e1e2e !important; border-bottom: 1px solid #1e1e2e !important; }
table.examples td:first-child { border-left: 1px solid #1e1e2e !important; border-radius: 8px 0 0 8px !important; }
table.examples td:last-child { border-right: 1px solid #1e1e2e !important; border-radius: 0 8px 8px 0 !important; }
table.examples tr:hover td { background: #1a1a2e !important; border-color: #2a2a4e !important; color: #e8e8f0 !important; }
/* Dropdown base */
.gradio-dropdown select, .gradio-dropdown input {
background: #0d0d14 !important;
border: 1px solid #2a2a3e !important;
color: #e8e8f0 !important;
border-radius: 8px !important;
}
/* Responsive */
@media (max-width: 640px) {
.app-header h1 { font-size: 2rem; }
.gradio-container { padding: 1rem !important; }
}
"""
# ------------------------------------------------------------------
# UI.  Fix: `css` and `theme` are gr.Blocks() constructor arguments;
# Blocks.launch() does not accept them, so the original call both
# crashed with a TypeError on current Gradio and never applied the
# custom stylesheet.  They now go on gr.Blocks and launch() is plain.
# ------------------------------------------------------------------
with gr.Blocks(title="BlueTTS — Multilingual TTS", theme=gr.themes.Base(), css=css) as demo:
    # Header banner with project link.
    gr.HTML("""
    <div class="app-header">
        <h1>BlueTTS</h1>
        <p>Lightning-fast multilingual text-to-speech · English · Hebrew · Spanish · German · Italian</p>
        <a class="github-link" href="https://github.com/maxmelichov/BlueTTS" target="_blank" rel="noopener noreferrer">GitHub · maxmelichov/BlueTTS</a>
    </div>
    """)
    with gr.Column(elem_classes="card"):
        text_input = gr.Textbox(
            label="Text",
            placeholder="Type or paste text here…",
            lines=4,
            elem_classes="big-input",
            value="Great ideas become real when a small team keeps building every single day.",
        )
        # Two control rows: language+voice, then quality+speed.
        with gr.Column(elem_classes="controls-row"):
            with gr.Row(elem_classes="ctrl-row1"):
                with gr.Column(elem_classes="ctrl-lang"):
                    lang_input = gr.Dropdown(
                        choices=[("English 🇺🇸", "en"), ("Hebrew 🇮🇱", "he"), ("Spanish 🇪🇸", "es"), ("German 🇩🇪", "de"), ("Italian 🇮🇹", "it")],
                        value="en", label="Language",
                    )
                with gr.Column(elem_classes="ctrl-voice"):
                    voice_input = gr.Dropdown(
                        choices=list(VOICES.keys()), value="Female", label="Voice",
                    )
            with gr.Row(elem_classes="ctrl-row2"):
                with gr.Column(elem_classes="ctrl-steps"):
                    steps_input = gr.Slider(2, 16, 8, step=1, label="Quality (steps)")
                with gr.Column(elem_classes="ctrl-speed"):
                    speed_input = gr.Slider(0.5, 2.0, 1.0, step=0.05, label="Speed")
        btn = gr.Button("⚡ Generate Speech", elem_classes="gen-btn")
        audio_out = gr.Audio(label="Output", type="numpy", autoplay=True)
        stats_out = gr.HTML()
    gr.Examples(
        examples=EXAMPLES,
        inputs=[text_input, voice_input, lang_input],
        label="Examples",
    )
    btn.click(
        synthesize_text,
        inputs=[text_input, voice_input, lang_input, steps_input, speed_input],
        outputs=[audio_out, stats_out],
    )
    # Set dir="auto" on the textarea so Hebrew text is automatically RTL.
    # NOTE(review): recent Gradio versions sanitize <script> inside gr.HTML;
    # if this never fires, move the snippet to gr.Blocks(js=...) — confirm.
    gr.HTML("""
    <script>
    (function applyDirAuto() {
        const ta = document.querySelector('.big-input textarea');
        if (ta) { ta.setAttribute('dir', 'auto'); return; }
        const obs = new MutationObserver(() => {
            const ta = document.querySelector('.big-input textarea');
            if (ta) { ta.setAttribute('dir', 'auto'); obs.disconnect(); }
        });
        obs.observe(document.body, { childList: true, subtree: true });
    })();
    </script>
    """)

if __name__ == "__main__":
    demo.launch()