"""End-to-end TTS inference using the 4 CoreML components.
Pipeline (mirrors supertonic.core.Supertonic):
text -> tokenize
-> duration_predictor -> frame count
-> text_encoder -> text embedding
-> sample noisy latent ~ N(0, I)
-> vector_estimator x 8 (flow-matching ODE step, runs on ANE)
-> vocoder -> 44.1 kHz waveform
All four mlpackages are static-shape buckets at T=L=320. The driver pads
inputs to that bucket and trims outputs.
The supertonic-3 model truncates long prompts at its content limit
(~13.7 s natively; CoreML's bucket leak extends this to ~16.7 s, which is
still too short for long inputs). The `--auto-pad` mode performs a two-pass
synthesis: once unpadded, to find the natural endpoint, and once with a
filler sentence appended, which gives the model more frames in which to
render the full original prompt. It then trims the second pass at the
silence gap between the original and the appended content. Recommended for
prompts longer than ~5 s.
Usage:
python inference.py --text "Hello, world." --voice F1 --lang en
python inference.py --text "<longer prompt>" --voice F5 --lang en --auto-pad
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from pathlib import Path
import coremltools as ct
import numpy as np
import soundfile as sf
# Directory containing this script; model, voice and tokenizer assets are
# resolved relative to it.
HERE = Path(__file__).parent

# Static CoreML shape buckets: text length (T) and latent length (L).
T_BUCKET: int = 320
L_BUCKET: int = 320

# Output waveform sample rate in Hz.
SAMPLE_RATE: int = 44100

# Latent geometry: the noisy latent has LATENT_DIM * CHUNK_COMPRESS_FACTOR
# channels, and each latent frame decodes to
# BASE_CHUNK_SIZE * CHUNK_COMPRESS_FACTOR waveform samples.
LATENT_DIM: int = 24
CHUNK_COMPRESS_FACTOR: int = 6
BASE_CHUNK_SIZE: int = 512

# Number of vector-estimator (flow-matching ODE) steps per synthesis.
DEFAULT_TOTAL_STEPS: int = 8

# The predicted duration is divided by this factor (>1 means faster speech).
DEFAULT_SPEED: float = 1.05

# Filler suffix appended to the prompt in auto-pad mode.
DEFAULT_AUTO_PAD: str = " And with that, the gentle silence wrapped itself around the room."
def _pad(arr: np.ndarray, axis: int, target: int) -> np.ndarray:
    """Zero-pad ``arr`` at the end of ``axis`` so that axis has length ``target``.

    An array already at least ``target`` long on ``axis`` is returned
    unchanged (the same object; it is never truncated).
    """
    shortfall = target - arr.shape[axis]
    if shortfall <= 0:
        return arr
    widths = [(0, 0) for _ in range(arr.ndim)]
    widths[axis] = (0, shortfall)
    return np.pad(arr, widths)
def _load_voice(name: str) -> tuple[np.ndarray, np.ndarray]:
    """Load voice preset ``voice_styles/<name>.json`` from next to this script.

    The JSON has ``style_ttl`` and ``style_dp`` entries, each holding a flat
    ``data`` list and its ``dims``; both are rebuilt as float32 arrays.

    Returns:
        ``(style_ttl, style_dp)``.
    """
    preset_path = HERE / "voice_styles" / f"{name}.json"
    preset = json.loads(preset_path.read_text())

    def as_array(entry):
        flat = np.array(entry["data"], dtype=np.float32)
        return flat.reshape(*entry["dims"])

    style_ttl = as_array(preset["style_ttl"])
    style_dp = as_array(preset["style_dp"])
    return style_ttl, style_dp
def _load_tokenizer(indexer_path: Path):
    """Build the text tokenizer from the official ``supertonic`` package.

    Wraps ``supertonic.core.UnicodeProcessor`` (per the package: 31
    languages, abbreviation expansion, punctuation rules, etc.), constructed
    from the indexer JSON at ``indexer_path``.
    Install with: pip install supertonic

    Raises:
        RuntimeError: if ``supertonic.core`` cannot be imported; the original
            ImportError is chained as the cause.
    """
    try:
        from supertonic.core import UnicodeProcessor as _Processor
    except ImportError as exc:
        message = (
            "supertonic package is required for tokenization. "
            "Install with: pip install supertonic"
        )
        raise RuntimeError(message) from exc
    return _Processor(str(indexer_path))
def _last_loud_window(audio: np.ndarray, thresh: float = 0.025, win_s: float = 0.05) -> int:
    """Return the index of the last ``win_s``-second window whose RMS exceeds ``thresh``.

    Only complete windows are analysed; a trailing partial window is ignored.
    Returns 0 when no window is loud, which is indistinguishable from
    "window 0 is the last loud one".
    """
    samples_per_win = int(win_s * SAMPLE_RATE)
    n_wins = len(audio) // samples_per_win
    framed = audio[: n_wins * samples_per_win].reshape(n_wins, samples_per_win)
    window_rms = np.sqrt(np.mean(framed ** 2, axis=1))
    loud_idx = np.where(window_rms > thresh)[0]
    if len(loud_idx) == 0:
        return 0
    return int(loud_idx[-1])
def trim_padded(unpad: np.ndarray, padded: np.ndarray, *,
                thresh: float = 0.025, win_s: float = 0.05) -> np.ndarray:
    """Cut the auto-pad synthesis at the silence gap before the filler suffix.

    ``padded`` is analysed in ``win_s``-second RMS windows. The search range
    starts at the last loud window of ``unpad`` and ends at the last loud
    window of ``padded``. Within that range, every run of two or more
    consecutive quiet windows (RMS strictly below ``thresh``) is a candidate
    gap. The longest candidate wins; ties go to the earliest one.

    The audio before the chosen gap gets a 60 ms linear fade-out, followed by
    0.5 s of digital silence. If no candidate gap exists, ``padded`` is
    returned up to the end of its last loud window, with no fade and no
    silence tail (that span may include the appended filler).

    Args:
        unpad: Waveform from the unpadded pass. Only used to locate where the
            gap search may start.
        padded: Waveform from the pass with the filler suffix appended.
        thresh: RMS level separating loud (> thresh) from quiet (< thresh)
            windows. Also forwarded to ``_last_loud_window``.
        win_s: Analysis window length in seconds. Also forwarded to
            ``_last_loud_window``.

    Returns:
        The trimmed waveform.
    """
    win = int(win_s * SAMPLE_RATE)
    n = len(padded) // win
    rms = np.sqrt(np.mean(padded[: n * win].reshape(n, win) ** 2, axis=1))
    quiet = rms < thresh

    floor = _last_loud_window(unpad, thresh, win_s)
    ceil_ = _last_loud_window(padded, thresh, win_s) + 1

    # Collect (start_window, length_in_windows) for every quiet run of at
    # least two windows inside [floor, ceil_).
    gaps: list[tuple[int, int]] = []
    j = floor
    while j < ceil_ - 1:
        if quiet[j] and quiet[j + 1]:
            start = j
            while j < ceil_ and quiet[j]:
                j += 1
            gaps.append((start, j - start))
        else:
            j += 1

    if not gaps:
        return padded[: ceil_ * win]

    # Longest gap wins; the earlier start breaks ties.
    start_win, _ = max(gaps, key=lambda g: (g[1], -g[0]))
    out = padded[: start_win * win].copy()
    fade = min(int(0.06 * SAMPLE_RATE), len(out))
    if fade:  # guard: out[-0:] would select the whole array
        out[-fade:] *= np.linspace(1.0, 0.0, fade, dtype=np.float32)
    return np.concatenate([out, np.zeros(int(0.5 * SAMPLE_RATE), dtype=np.float32)])
class Supertonic3CoreML:
    """Driver for the four supertonic-3 CoreML components.

    The mlpackages are loaded from ``HERE / quant`` and the tokenizer from
    ``HERE / "unicode_indexer.json"``. Every model input is padded to the
    static ``T_BUCKET`` / ``L_BUCKET`` shapes.
    """

    def __init__(self, quant: str = "fp16"):
        """Load the duration predictor, text encoder, vector estimator,
        vocoder and tokenizer.

        Args:
            quant: Subdirectory of ``HERE`` holding the ``.mlpackage`` files.
        """
        d = HERE / quant
        self.dp = ct.models.MLModel(str(d / "duration_predictor.mlpackage"))
        self.te = ct.models.MLModel(str(d / "text_encoder.mlpackage"))
        self.ve = ct.models.MLModel(str(d / "vector_estimator.mlpackage"))
        self.voc = ct.models.MLModel(str(d / "vocoder.mlpackage"))
        self.tok = _load_tokenizer(HERE / "unicode_indexer.json")

    def _synth(self, text: str, voice: str, lang: str, seed: int,
               total_steps: int, speed: float, full_bucket: bool) -> np.ndarray:
        """Run one synthesis pass through all four models.

        Args:
            text: Prompt to speak.
            voice: Voice preset name (``voice_styles/<voice>.json``).
            lang: Language code passed to the tokenizer.
            seed: Seed for the initial noisy latent.
            total_steps: Number of vector-estimator steps (must be >= 1).
            speed: The predicted duration is divided by this.
            full_bucket: If true, return the vocoder output for the whole
                ``L_BUCKET``. Otherwise trim it to the predicted number of
                latent frames.

        Returns:
            The first item of the vocoder's ``wav_tts`` output.

        Raises:
            ValueError: If ``total_steps < 1`` or the tokenized text is longer
                than ``T_BUCKET``.
        """
        if total_steps < 1:
            # With zero steps the vocoder would just render the initial noise.
            raise ValueError(f"total_steps must be >= 1, got {total_steps}")

        text_ids, text_mask = self.tok([text], lang)
        n_tokens = text_ids.shape[1]
        if n_tokens > T_BUCKET:
            # _pad never truncates, so an over-long sequence would reach the
            # static-shape models and fail with an opaque shape error.
            raise ValueError(
                f"text tokenizes to {n_tokens} ids, more than the static "
                f"T_BUCKET={T_BUCKET}; shorten or split the prompt"
            )
        style_ttl, style_dp = _load_voice(voice)
        text_ids_p = _pad(text_ids.astype(np.int32), 1, T_BUCKET)
        text_mask_p = _pad(text_mask.astype(np.float32), 2, T_BUCKET)

        # Predicted length in seconds; a larger speed gives shorter audio.
        dur = float(self.dp.predict({"text_ids": text_ids_p, "style_dp": style_dp,
                                     "text_mask": text_mask_p})["duration"][0]) / speed
        text_emb = self.te.predict({"text_ids": text_ids_p, "style_ttl": style_ttl,
                                    "text_mask": text_mask_p})["text_emb"]

        # Each latent frame decodes to frame_samples waveform samples. Round
        # the duration up to whole frames and clamp to [1, L_BUCKET].
        frame_samples = BASE_CHUNK_SIZE * CHUNK_COMPRESS_FACTOR
        L_real = max(1, min(L_BUCKET, (int(dur * SAMPLE_RATE) + frame_samples - 1) // frame_samples))

        # Private RNG: it yields the same samples as np.random.seed(seed) +
        # np.random.randn, without clobbering the caller's global NumPy state.
        rng = np.random.RandomState(seed)
        xt = rng.randn(1, LATENT_DIM * CHUNK_COMPRESS_FACTOR, L_real).astype(np.float32)
        xt = _pad(xt, 2, L_BUCKET)
        # 1 over the L_real real frames, 0 over the padding added by _pad.
        latent_mask = _pad(np.ones((1, 1, L_real), dtype=np.float32), 2, L_BUCKET)

        total_step_arr = np.array([float(total_steps)], dtype=np.float32)
        for step in range(total_steps):
            xt = self.ve.predict({
                "noisy_latent": xt, "text_emb": text_emb, "style_ttl": style_ttl,
                "text_mask": text_mask_p, "latent_mask": latent_mask,
                "current_step": np.array([float(step)], dtype=np.float32),
                "total_step": total_step_arr,
            })["denoised_latent"]

        wav = self.voc.predict({"latent": xt})["wav_tts"][0]
        if full_bucket:
            return wav
        return wav[: L_real * frame_samples]

    def synthesize(self, text: str, voice: str = "F1", lang: str = "en", seed: int = 0,
                   total_steps: int = DEFAULT_TOTAL_STEPS, speed: float = DEFAULT_SPEED,
                   auto_pad: str | None = DEFAULT_AUTO_PAD) -> np.ndarray:
        """Synthesize speech for ``text``.

        If ``auto_pad`` is a string (the default here; the CLI's ``--auto-pad``
        is opt-in), two full-bucket passes run. The first uses ``text`` and
        the second uses ``text + auto_pad``; the second pass is then cut with
        ``trim_padded``. If ``auto_pad`` is None, a single pass trimmed to the
        predicted duration is returned.

        Raises:
            ValueError: Propagated from ``_synth`` for invalid ``total_steps``
                or over-long text.
        """
        if auto_pad is None:
            return self._synth(text, voice, lang, seed, total_steps, speed, full_bucket=False)
        unpad_audio = self._synth(text, voice, lang, seed, total_steps, speed, full_bucket=True)
        pad_audio = self._synth(text + auto_pad, voice, lang, seed, total_steps, speed, full_bucket=True)
        return trim_padded(unpad_audio, pad_audio)
def main() -> int:
    """CLI entry point: parse arguments, synthesize speech and write a WAV file.

    Returns:
        Process exit status (0 on success). Invalid arguments exit via
        ``argparse`` with status 2.
    """
    ap = argparse.ArgumentParser(description="Supertonic-3 CoreML text-to-speech.")
    ap.add_argument("--text", required=True, help="Text to synthesize")
    ap.add_argument("--voice", default="F1", choices=[f"F{i}" for i in range(1, 6)] + [f"M{i}" for i in range(1, 6)])
    ap.add_argument("--lang", default="en")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--total-steps", type=int, default=DEFAULT_TOTAL_STEPS)
    ap.add_argument("--speed", type=float, default=DEFAULT_SPEED,
                    help="Speech speed factor; the predicted duration is divided by it.")
    ap.add_argument("--auto-pad", nargs="?", const=DEFAULT_AUTO_PAD, default=None,
                    help="2-pass synthesis with filler suffix + auto-trim (recommended).")
    ap.add_argument("--quant", default="fp16", choices=["fp16"])
    ap.add_argument("--out", default="out.wav")
    args = ap.parse_args()
    if args.total_steps < 1:
        ap.error("--total-steps must be >= 1")
    if args.speed <= 0:
        ap.error("--speed must be > 0")

    # perf_counter is monotonic, so wall-clock adjustments cannot skew timings.
    t0 = time.perf_counter()
    tts = Supertonic3CoreML(quant=args.quant)
    print(f"Loaded models in {time.perf_counter() - t0:.2f}s")

    t0 = time.perf_counter()
    audio = tts.synthesize(args.text, voice=args.voice, lang=args.lang, seed=args.seed,
                           total_steps=args.total_steps, speed=args.speed,
                           auto_pad=args.auto_pad)
    elapsed = time.perf_counter() - t0
    dur = len(audio) / SAMPLE_RATE
    sf.write(args.out, audio, SAMPLE_RATE)
    print(f"Synthesized {dur:.2f}s of audio in {elapsed:.2f}s -> {args.out}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
|