"""End-to-end TTS inference using the LiteRT (.tflite) + ONNX components.

Architecture:

    text -> tokenize
         -> duration_predictor (.tflite) -> frame count
         -> text_encoder (.tflite)       -> text embedding
         -> sample noisy latent ~ N(0, I)
         -> vector_estimator (.onnx)     -> ODE step x 8
         -> vocoder (.tflite)            -> 44.1 kHz waveform

3 of the 4 components convert cleanly to LiteRT via onnx2tf +
ai-edge-quantizer. `vector_estimator` is kept as ONNX because its rotary
multi-head attention defeats onnx2tf's NCW->NHWC shape inference (and
litert-torch deadlocks on loaded weights with specific patterns). This ONNX
fallback runs on CPU via onnxruntime; the other three run on the LiteRT
runtime (`ai_edge_litert`), which supports true INT4 inference.

Two recommended configurations:

    fp32: fp32/dp + fp32/te + vector_estimator.onnx (--ve-fp32) + fp32/vocoder
          (142 MB tflite + 256 MB ONNX = ~398 MB)
    int4: int4/dp + int4/te + vector_estimator_int8.onnx + int8/vocoder
          (28 MB tflite + 26 MB INT8 vocoder + 65 MB INT8 ONNX = ~119 MB)
          — this is the default configuration.

(The INT4 vocoder is broken — cos ~0 — so we ship INT8 for the vocoder.)

Usage:

    python inference.py --text "Hello, world." --voice F1 --lang en
    python inference.py --text "<long prompt>" --voice F5 --auto-pad
"""
from __future__ import annotations

import argparse
import json
import sys
import time
from pathlib import Path

import numpy as np
import soundfile as sf
import onnxruntime as ort

# Model files, voice styles and the tokenizer index live next to this script.
HERE = Path(__file__).parent

# Fixed sizes the LiteRT inputs are zero-padded up to (presumably the static
# shapes the .tflite graphs were exported with — TODO confirm).
T_BUCKET = 320  # text-token axis of text_ids / text_mask
L_BUCKET = 320  # latent-frame axis fed to the vocoder; also caps L_real

SAMPLE_RATE = 44_100
LATENT_DIM = 24
# Latent channel count is LATENT_DIM * CHUNK_COMPRESS_FACTOR, and one latent
# frame decodes to BASE_CHUNK_SIZE * CHUNK_COMPRESS_FACTOR waveform samples.
CHUNK_COMPRESS_FACTOR = 6
BASE_CHUNK_SIZE = 512

DEFAULT_TOTAL_STEPS = 8  # number of vector_estimator (ODE) steps
DEFAULT_SPEED = 1.05  # predicted duration is divided by this
# Filler suffix used by the 2-pass "auto-pad" synthesis; it is trimmed off.
DEFAULT_AUTO_PAD = " And with that, the gentle silence wrapped itself around the room."
def _pad(arr: np.ndarray, axis: int, target: int) -> np.ndarray:
    """Zero-pad ``arr`` at the end of ``axis`` up to length ``target``.

    Arrays already at least ``target`` long on ``axis`` are returned
    unchanged (never truncated).
    """
    if arr.shape[axis] >= target:
        return arr
    pad = [(0, 0)] * arr.ndim
    pad[axis] = (0, target - arr.shape[axis])
    return np.pad(arr, pad)


def _load_voice(name: str) -> tuple[np.ndarray, np.ndarray]:
    """Load the ``(style_ttl, style_dp)`` float32 style tensors for a voice.

    Reads ``voice_styles/<name>.json``; each part is stored as a flat
    ``data`` list plus its ``dims``.
    """
    # JSON is UTF-8 (RFC 8259); don't depend on the platform locale encoding.
    raw = (HERE / "voice_styles" / f"{name}.json").read_text(encoding="utf-8")
    j = json.loads(raw)

    def r(part):
        return np.array(part["data"], dtype=np.float32).reshape(*part["dims"])

    return r(j["style_ttl"]), r(j["style_dp"])


def _load_tokenizer(indexer_path: Path):
    """Build the ``supertonic`` UnicodeProcessor tokenizer from ``indexer_path``.

    Raises:
        RuntimeError: if the ``supertonic`` package is not installed.
    """
    try:
        from supertonic.core import UnicodeProcessor
    except ImportError as e:
        raise RuntimeError(
            "supertonic package is required for tokenization. "
            "Install with: pip install supertonic"
        ) from e
    return UnicodeProcessor(str(indexer_path))


class TFLiteRunner:
    """Convenience wrapper around ai_edge_litert.Interpreter (true LiteRT
    runtime, supports INT4) — falls back to tf.lite.Interpreter for FP32 if
    ai_edge_litert is unavailable."""

    def __init__(self, path: Path):
        try:
            from ai_edge_litert.interpreter import Interpreter as AILiteRT
        except ImportError:
            import tensorflow as tf
            self._interp = tf.lite.Interpreter(model_path=str(path))
        else:
            self._interp = AILiteRT(model_path=str(path))
        self._interp.allocate_tensors()
        # Full input tensor name -> details, plus last "/"-separated component
        # -> full name, so callers can feed inputs by their short name.
        self._in_details = {d["name"]: d for d in self._interp.get_input_details()}
        self._in_keys = {full.split("/")[-1]: full for full in self._in_details}
        # Only the first output tensor is ever returned by predict().
        self._out = self._interp.get_output_details()[0]

    def _resolve(self, short: str) -> dict:
        """Map a short input name to its tensor details (exact, then substring)."""
        full = self._in_keys.get(short)
        if full is None:
            full = next((k for k in self._in_details if short in k), None)
        if full is None:
            raise KeyError(
                f"model has no input matching {short!r}; "
                f"available inputs: {sorted(self._in_details)}")
        return self._in_details[full]

    def predict(self, feed: dict[str, np.ndarray]) -> np.ndarray:
        """Set every input in ``feed`` (cast to the tensor dtype), invoke the
        interpreter and return the first output tensor.

        Raises:
            KeyError: if a feed key matches no model input.
        """
        for short, value in feed.items():
            d = self._resolve(short)
            v = value if value.dtype == d["dtype"] else value.astype(d["dtype"])
            self._interp.set_tensor(d["index"], v)
        self._interp.invoke()
        return self._interp.get_tensor(self._out["index"])


# RMS level below which a 50 ms analysis window counts as "quiet".
_QUIET_RMS = 0.025


def _window_rms(audio: np.ndarray) -> tuple[np.ndarray, int]:
    """Return (per-window RMS, window length in samples) over non-overlapping
    50 ms windows of ``audio``; a trailing partial window is dropped."""
    win = int(0.05 * SAMPLE_RATE)
    n = len(audio) // win
    rms = np.sqrt(np.mean(audio[: n * win].reshape(n, win) ** 2, axis=1))
    return rms, win


def _last_loud_window(audio: np.ndarray, thresh: float = _QUIET_RMS) -> int:
    """Index of the last 50 ms window whose RMS exceeds ``thresh``.

    Returns 0 when no window is loud (indistinguishable from "window 0 is the
    last loud one").
    """
    rms, _ = _window_rms(audio)
    loud = np.where(rms > thresh)[0]
    return int(loud[-1]) if len(loud) else 0


def trim_padded(unpad: np.ndarray, padded: np.ndarray, *,
                thresh: float = _QUIET_RMS) -> np.ndarray:
    """Cut the filler suffix off a 2-pass "auto-pad" synthesis.

    ``unpad`` is the audio for the text alone and ``padded`` the audio for
    text + filler. Between the last loud window of ``unpad`` and the last
    loud window of ``padded``, find the longest run of >= 2 consecutive quiet
    windows in ``padded`` (earliest run on ties) — presumably the pause before
    the filler — cut there, apply a 60 ms linear fade-out and append 0.5 s of
    silence. If no such run exists, ``padded`` is returned up to the end of
    its last loud window with no fade or trailing silence.
    """
    rms, win = _window_rms(padded)
    floor = _last_loud_window(unpad, thresh)
    ceil_ = _last_loud_window(padded, thresh) + 1

    gaps = []  # (start_window, run_length) of quiet runs
    j = floor
    while j < ceil_ - 1:
        if rms[j] < thresh and rms[j + 1] < thresh:
            start = j
            while j < ceil_ and rms[j] < thresh:
                j += 1
            gaps.append((start, j - start))
        else:
            j += 1

    if not gaps:
        return padded[: ceil_ * win]

    # Longest gap wins; among equal lengths the earliest one.
    start_win, _ = max(gaps, key=lambda g: (g[1], -g[0]))
    out = padded[: start_win * win].copy()
    fade = min(int(0.06 * SAMPLE_RATE), len(out))
    if fade:
        out[-fade:] *= np.linspace(1.0, 0.0, fade, dtype=np.float32)
    return np.concatenate([out, np.zeros(int(0.5 * SAMPLE_RATE), dtype=np.float32)])


class Supertonic3LiteRT:
    """LiteRT TTS with ONNX vector_estimator fallback.

    Pass quants per component; defaults give the recommended (int4 dp/te,
    int8 vocoder, INT8 ONNX vector_estimator) configuration.
    ``ve_fp32=True`` swaps in the full-precision vector_estimator.onnx
    (256 MB) instead of the default INT8 version (65 MB) — audio is audibly
    identical, useful only as a reference."""

    def __init__(self, dp_quant: str = "int4", te_quant: str = "int4",
                 voc_quant: str = "int8", ve_fp32: bool = False):
        # Each quant name is a sub-directory next to this script.
        self.dp = TFLiteRunner(HERE / dp_quant / "duration_predictor.tflite")
        self.te = TFLiteRunner(HERE / te_quant / "text_encoder.tflite")
        self.voc = TFLiteRunner(HERE / voc_quant / "vocoder.tflite")
        ve_name = "vector_estimator.onnx" if ve_fp32 else "vector_estimator_int8.onnx"
        self.ve = ort.InferenceSession(
            str(HERE / ve_name), providers=["CPUExecutionProvider"],
        )
        self.tok = _load_tokenizer(HERE / "unicode_indexer.json")

    def _synth(self, text: str, voice: str, lang: str, seed: int,
               total_steps: int, speed: float, full_bucket: bool) -> np.ndarray:
        """Run one text -> waveform pass.

        Returns the vocoder output for the whole L_BUCKET-frame latent when
        ``full_bucket`` is true, otherwise only the first ``L_real`` frames
        (each BASE_CHUNK_SIZE * CHUNK_COMPRESS_FACTOR samples).

        Raises:
            ValueError: if the tokenized text is longer than T_BUCKET.
        """
        text_ids, text_mask = self.tok([text], lang)
        text_ids = text_ids.astype(np.int64)
        text_mask = text_mask.astype(np.float32)
        T_real = text_ids.shape[1]
        if T_real > T_BUCKET:
            # _pad() never truncates, so LiteRT would otherwise fail with an
            # opaque shape-mismatch error.
            raise ValueError(
                f"text is {T_real} tokens long; the LiteRT graphs accept at "
                f"most {T_BUCKET}. Split the text into shorter pieces.")
        style_ttl, style_dp = _load_voice(voice)

        # The LiteRT graphs take fixed-size (T_BUCKET) text inputs.
        text_ids_p = _pad(text_ids, 1, T_BUCKET)
        text_mask_p = _pad(text_mask, 2, T_BUCKET)
        dp_out = self.dp.predict({"text_ids": text_ids_p, "style_dp": style_dp,
                                  "text_mask": text_mask_p})
        # Flatten first so a (1,) or (1, 1) output both yield a scalar.
        dur = float(np.asarray(dp_out).reshape(-1)[0]) / speed
        text_emb_full = self.te.predict({"text_ids": text_ids_p, "style_ttl": style_ttl,
                                         "text_mask": text_mask_p})
        # ONNX VE accepts native shapes — trim text_emb back to T_real
        # (assumes the text axis is last — TODO confirm).
        text_emb_real = text_emb_full[:, :, :T_real]

        # Seconds -> latent frames, rounded up, clamped to [1, L_BUCKET].
        samples_per_frame = BASE_CHUNK_SIZE * CHUNK_COMPRESS_FACTOR
        L_real = max(1, min(L_BUCKET,
                            (int(dur * SAMPLE_RATE) + samples_per_frame - 1) // samples_per_frame))

        # A private RandomState yields exactly the same noise as
        # np.random.seed(seed) + np.random.randn(...) without touching the
        # caller's global NumPy RNG state.
        rng = np.random.RandomState(seed)
        xt = rng.randn(1, LATENT_DIM * CHUNK_COMPRESS_FACTOR, L_real).astype(np.float32)
        latent_mask = np.ones((1, 1, L_real), dtype=np.float32)
        total_step_arr = np.array([float(total_steps)], dtype=np.float32)
        for step in range(total_steps):
            xt = self.ve.run(None, {
                "noisy_latent": xt,
                "text_emb": text_emb_real,
                "style_ttl": style_ttl,
                "text_mask": text_mask,
                "latent_mask": latent_mask,
                "current_step": np.array([float(step)], dtype=np.float32),
                "total_step": total_step_arr,
            })[0]

        # The vocoder graph takes a fixed L_BUCKET-frame latent.
        wav = self.voc.predict({"latent": _pad(xt, 2, L_BUCKET)})[0]
        if full_bucket:
            return wav
        return wav[: L_real * samples_per_frame]

    def synthesize(self, text: str, voice: str = "F1", lang: str = "en",
                   seed: int = 0, total_steps: int = DEFAULT_TOTAL_STEPS,
                   speed: float = DEFAULT_SPEED,
                   auto_pad: str | None = DEFAULT_AUTO_PAD) -> np.ndarray:
        """Synthesize ``text`` and return the waveform at SAMPLE_RATE.

        With a non-empty ``auto_pad`` the text is synthesized twice (alone and
        with the filler suffix appended) and the filler is trimmed off via
        ``trim_padded``. ``None`` or ``""`` gives a single pass.

        Raises:
            ValueError: if ``speed`` is not positive, ``total_steps`` < 1, or
                the text is longer than T_BUCKET tokens.
        """
        if not speed > 0:
            raise ValueError(f"speed must be positive, got {speed!r}")
        if total_steps < 1:
            raise ValueError(f"total_steps must be >= 1, got {total_steps!r}")
        if not auto_pad:
            # An empty suffix would just repeat the same pass twice.
            return self._synth(text, voice, lang, seed, total_steps, speed,
                               full_bucket=False)
        unpad = self._synth(text, voice, lang, seed, total_steps, speed,
                            full_bucket=True)
        padded = self._synth(text + auto_pad, voice, lang, seed, total_steps, speed,
                             full_bucket=True)
        return trim_padded(unpad, padded)


def main() -> int:
    """Command-line entry point: synthesize ``--text`` and write it to ``--out``."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--text", required=True)
    ap.add_argument("--voice", default="F1",
                    choices=[f"F{i}" for i in range(1, 6)] + [f"M{i}" for i in range(1, 6)])
    ap.add_argument("--lang", default="en")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--total-steps", type=int, default=DEFAULT_TOTAL_STEPS)
    ap.add_argument("--speed", type=float, default=DEFAULT_SPEED,
                    help="Speaking-rate factor; the predicted duration is divided by it.")
    ap.add_argument("--auto-pad", nargs="?", const=DEFAULT_AUTO_PAD, default=None,
                    help="2-pass synthesis with filler suffix + auto-trim (recommended for long prompts).")
    ap.add_argument("--dp-quant", default="int4", choices=["fp32", "int4"])
    ap.add_argument("--te-quant", default="int4", choices=["fp32", "int4"])
    ap.add_argument("--voc-quant", default="int8", choices=["fp32", "int8", "int4"],
                    help="INT4 vocoder is broken (cos ~0) — use int8 or fp32.")
    ap.add_argument("--ve-fp32", action="store_true",
                    help="Use the full-precision vector_estimator.onnx (256 MB) "
                         "instead of the default INT8 ONNX (65 MB).")
    ap.add_argument("--out", default="out.wav")
    args = ap.parse_args()

    if args.voc_quant == "int4":
        print("warning: the INT4 vocoder is broken (cos ~0); use int8 or fp32.",
              file=sys.stderr)

    t0 = time.time()
    tts = Supertonic3LiteRT(dp_quant=args.dp_quant, te_quant=args.te_quant,
                            voc_quant=args.voc_quant, ve_fp32=args.ve_fp32)
    ve_kind = "fp32" if args.ve_fp32 else "INT8"
    print(f"Loaded models in {time.time() - t0:.2f}s "
          f"(dp={args.dp_quant}, te={args.te_quant}, voc={args.voc_quant}, ve={ve_kind})")

    t0 = time.time()
    audio = tts.synthesize(args.text, voice=args.voice, lang=args.lang,
                           seed=args.seed, total_steps=args.total_steps,
                           speed=args.speed, auto_pad=args.auto_pad)
    sf.write(args.out, audio, SAMPLE_RATE)
    print(f"Synthesized {len(audio)/SAMPLE_RATE:.2f}s in {time.time() - t0:.2f}s -> {args.out}")
    return 0


if __name__ == "__main__":
    sys.exit(main())