# supertonic-3-litert / inference.py
# Author: Reza2kn
# Commit c47c32d (verified): Add INT8 ONNX vector_estimator (65 MB), update README + tags
"""End-to-end TTS inference using the LiteRT (.tflite) + ONNX components.
Architecture:
text -> tokenize
-> duration_predictor (.tflite) -> frame count
-> text_encoder (.tflite) -> text embedding
-> sample noisy latent ~ N(0, I)
-> vector_estimator (.onnx) -> ODE step x 8
-> vocoder (.tflite) -> 44.1 kHz waveform
Three of the four components convert cleanly to LiteRT via onnx2tf +
ai-edge-quantizer. `vector_estimator` is kept as ONNX because its rotary
multi-head attention defeats onnx2tf's NCW -> NHWC shape inference (and
litert-torch deadlocks on loaded weights with specific patterns). This
ONNX fallback runs on CPU via onnxruntime; the other three run on the
LiteRT runtime (`ai_edge_litert`), which supports true INT4 inference
(if `ai_edge_litert` is not installed, `tf.lite.Interpreter` is used instead).
Two recommended configurations:
fp32: fp32/dp + fp32/te + vector_estimator.onnx + fp32/vocoder
(142 MB tflite + 256 MB ONNX = ~398 MB; pass --ve-fp32 to use the fp32 ONNX)
int4 (default): int4/dp + int4/te + vector_estimator_int8.onnx + int8/vocoder
(28 MB tflite + 26 MB INT8 vocoder + 65 MB INT8 ONNX = ~119 MB)
(INT4 vocoder is broken — cos ~0 — so we ship INT8 for vocoder)
Usage:
python inference.py --text "Hello, world." --voice F1 --lang en
python inference.py --text "<longer prompt>" --voice F5 --auto-pad
"""
from __future__ import annotations

import argparse
import json
import sys
import time
import warnings
from pathlib import Path

import numpy as np
import onnxruntime as ort
import soundfile as sf
# Directory of this script; model files, voice presets and the tokenizer
# index are all resolved relative to it.
HERE = Path(__file__).parent

# Fixed lengths the .tflite graphs are padded to.
T_BUCKET = 320  # text tokens fed to duration_predictor / text_encoder
L_BUCKET = 320  # latent frames fed to the vocoder (also caps the frame count)

SAMPLE_RATE = 44100  # output waveform rate in Hz

# The noisy latent has LATENT_DIM * CHUNK_COMPRESS_FACTOR channels, and one
# latent frame covers BASE_CHUNK_SIZE * CHUNK_COMPRESS_FACTOR output samples.
LATENT_DIM = 24
CHUNK_COMPRESS_FACTOR = 6
BASE_CHUNK_SIZE = 512

# Synthesis defaults.
DEFAULT_TOTAL_STEPS = 8  # vector_estimator iterations per utterance
DEFAULT_SPEED = 1.05     # predicted duration is divided by this
# Filler suffix appended for the 2-pass auto-pad synthesis, later trimmed off.
DEFAULT_AUTO_PAD = " And with that, the gentle silence wrapped itself around the room."
def _pad(arr: np.ndarray, axis: int, target: int) -> np.ndarray:
if arr.shape[axis] >= target:
return arr
pad = [(0, 0)] * arr.ndim
pad[axis] = (0, target - arr.shape[axis])
return np.pad(arr, pad)
def _load_voice(name: str) -> tuple[np.ndarray, np.ndarray]:
    """Load the voice-style preset ``voice_styles/<name>.json`` next to this file.

    The JSON is read for two entries, ``style_ttl`` and ``style_dp``, each
    holding a flat ``data`` list and a ``dims`` shape.

    Returns:
        ``(style_ttl, style_dp)`` as float32 arrays reshaped to their ``dims``.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    preset_path = HERE / "voice_styles" / f"{name}.json"
    preset = json.loads(preset_path.read_text(encoding="utf-8"))

    def _tensor(part: dict) -> np.ndarray:
        # Each entry stores a flattened buffer plus the shape to restore.
        return np.array(part["data"], dtype=np.float32).reshape(*part["dims"])

    return _tensor(preset["style_ttl"]), _tensor(preset["style_dp"])
def _load_tokenizer(indexer_path: Path):
try:
from supertonic.core import UnicodeProcessor
except ImportError as e:
raise RuntimeError(
"supertonic package is required for tokenization. "
"Install with: pip install supertonic"
) from e
return UnicodeProcessor(str(indexer_path))
class TFLiteRunner:
    """Convenience wrapper around a single ``.tflite`` model.

    Uses ``ai_edge_litert.interpreter.Interpreter`` (the true LiteRT runtime,
    which supports INT4) when it can be imported, and falls back to
    ``tf.lite.Interpreter`` from TensorFlow otherwise. Inputs are fed by a
    short name (e.g. ``"text_ids"``) that is matched against the model's full
    input tensor names; only the model's first output is returned.
    """

    def __init__(self, path: Path):
        try:
            from ai_edge_litert.interpreter import Interpreter as AILiteRT
            self._interp = AILiteRT(model_path=str(path))
        except ImportError:
            import tensorflow as tf
            self._interp = tf.lite.Interpreter(model_path=str(path))
        self._interp.allocate_tensors()
        # Full input tensor name -> detail dict (index, dtype, ...).
        self._in_details = {d["name"]: d for d in self._interp.get_input_details()}
        # Last '/'-separated component of each full name -> full name, used for
        # exact short-name lookups before falling back to substring matching.
        self._in_keys = {full.split("/")[-1]: full for full in self._in_details}
        self._out = self._interp.get_output_details()[0]

    def _resolve(self, short: str) -> dict:
        """Return the input-detail dict whose tensor name matches ``short``.

        Tries an exact match on the last ``/`` component of each input name
        first, then the first full input name containing ``short``.

        Raises:
            KeyError: if no input name matches; the message lists the
                available input names.
        """
        full = self._in_keys.get(short) or next(
            (k for k in self._in_details if short in k), None)
        if full is None:
            raise KeyError(
                f"no model input matches {short!r}; "
                f"available inputs: {sorted(self._in_details)}")
        return self._in_details[full]

    def predict(self, feed: dict[str, np.ndarray]) -> np.ndarray:
        """Set every input in ``feed``, run the model and return output 0.

        Each value is converted to the dtype its input tensor declares; no
        reshaping is done here, so values are expected to already have the
        tensor's shape.

        Raises:
            KeyError: if a key of ``feed`` matches no model input.
        """
        for short, value in feed.items():
            detail = self._resolve(short)
            arr = np.asarray(value)
            if arr.dtype != detail["dtype"]:
                arr = arr.astype(detail["dtype"])
            self._interp.set_tensor(detail["index"], arr)
        self._interp.invoke()
        return self._interp.get_tensor(self._out["index"])
def _last_loud_window(audio: np.ndarray, thresh: float = 0.025) -> int:
    """Return the index of the last 50 ms window whose RMS exceeds ``thresh``.

    ``audio`` is cut into whole, non-overlapping 50 ms windows (a trailing
    partial window is ignored). Returns 0 when no window is loud, which is
    indistinguishable from "window 0 is the last loud one".
    """
    window = int(0.05 * SAMPLE_RATE)
    n_windows = len(audio) // window
    frames = audio[: n_windows * window].reshape(n_windows, window)
    levels = np.sqrt(np.mean(frames ** 2, axis=1))
    loud_idx = np.flatnonzero(levels > thresh)
    if loud_idx.size == 0:
        return 0
    return int(loud_idx[-1])
def trim_padded(unpad: np.ndarray, padded: np.ndarray, *,
                thresh: float = 0.025) -> np.ndarray:
    """Trim the filler suffix off an auto-pad (2-pass) synthesis.

    ``unpad`` is the waveform rendered from the text alone; ``padded`` is the
    waveform rendered from the text plus the filler suffix. Both are analysed
    in non-overlapping 50 ms windows ("quiet" means RMS < ``thresh``, "loud"
    means RMS > ``thresh``). The cut point is searched from the last loud
    window of ``unpad`` up to the last loud window of ``padded``. Within that
    range, the longest run of at least two consecutive quiet windows is taken
    as the pause before the filler (the earliest run wins a tie), and
    ``padded`` is cut at the start of that run.

    Args:
        unpad: Waveform synthesized from the original text.
        padded: Waveform synthesized from the text plus the filler suffix.
        thresh: RMS level used to classify windows as quiet / loud.

    Returns:
        The trimmed waveform with a 60 ms linear fade-out and 0.5 s of
        trailing silence appended. If no quiet run is found, the cut falls
        right after the last loud window of ``padded``. The filler is then
        not removed, but the same fade-out and tail are applied.
    """
    win = int(0.05 * SAMPLE_RATE)
    n = len(padded) // win
    rms = np.sqrt(np.mean(padded[: n * win].reshape(n, win) ** 2, axis=1))
    quiet = rms < thresh

    # Search range: from the last loud window of the unpadded rendering up to
    # (exclusive) one past the last loud window of the padded rendering. The
    # latter is derived from ``rms`` directly instead of recomputing it.
    floor = _last_loud_window(unpad, thresh)
    loud = np.flatnonzero(rms > thresh)
    ceil_ = (int(loud[-1]) if loud.size else 0) + 1

    best_start, best_len = None, 0
    j = floor
    while j < ceil_ - 1:
        if quiet[j] and quiet[j + 1]:
            start = j
            while j < ceil_ and quiet[j]:
                j += 1
            # Strict '>' keeps the earliest run when lengths tie.
            if j - start > best_len:
                best_start, best_len = start, j - start
        else:
            j += 1

    end_samples = ceil_ * win if best_start is None else best_start * win
    out = padded[:end_samples].copy()
    fade = min(int(0.06 * SAMPLE_RATE), len(out))
    if fade:
        # Linear fade-out avoids a click at the cut point.
        out[-fade:] *= np.linspace(1.0, 0.0, fade, dtype=np.float32)
    return np.concatenate([out, np.zeros(int(0.5 * SAMPLE_RATE), dtype=np.float32)])
class Supertonic3LiteRT:
    """LiteRT TTS with an ONNX vector_estimator fallback.

    Pass a quantization per component; the defaults give the recommended
    configuration (int4 duration_predictor / text_encoder, int8 vocoder,
    INT8 ONNX vector_estimator). ``ve_fp32=True`` swaps in the
    full-precision ``vector_estimator.onnx`` (256 MB) instead of the default
    INT8 version (65 MB). The audio is audibly identical, so fp32 is useful
    only as a reference.

    The .tflite components are loaded from ``HERE/<quant>/<component>.tflite``;
    the vector estimator and the tokenizer index live directly under ``HERE``.
    """

    def __init__(self, dp_quant: str = "int4", te_quant: str = "int4",
                 voc_quant: str = "int8", ve_fp32: bool = False):
        self.dp = TFLiteRunner(HERE / dp_quant / "duration_predictor.tflite")
        self.te = TFLiteRunner(HERE / te_quant / "text_encoder.tflite")
        self.voc = TFLiteRunner(HERE / voc_quant / "vocoder.tflite")
        ve_name = "vector_estimator.onnx" if ve_fp32 else "vector_estimator_int8.onnx"
        self.ve = ort.InferenceSession(
            str(HERE / ve_name),
            providers=["CPUExecutionProvider"],
        )
        self.tok = _load_tokenizer(HERE / "unicode_indexer.json")

    def _synth(self, text: str, voice: str, lang: str, seed: int,
               total_steps: int, speed: float, full_bucket: bool) -> np.ndarray:
        """Run one complete text -> waveform pass.

        Args:
            text: Text to speak.
            voice: Voice preset name (``voice_styles/<voice>.json``).
            lang: Language code passed to the tokenizer.
            seed: Seed for the initial Gaussian noise latent.
            total_steps: Number of vector_estimator iterations.
            speed: Speaking-rate factor; the predicted duration is divided by it.
            full_bucket: If True, return the whole ``L_BUCKET``-frame vocoder
                output (its tail is rendered from zero-padded latents).
                Otherwise cut it to the predicted number of frames.

        Raises:
            ValueError: if ``total_steps < 1``, ``speed <= 0``, or the
                tokenized text is longer than ``T_BUCKET`` tokens.
        """
        if total_steps < 1:
            raise ValueError(f"total_steps must be >= 1, got {total_steps}")
        if speed <= 0:
            raise ValueError(f"speed must be > 0, got {speed}")

        text_ids, text_mask = self.tok([text], lang)
        text_ids = text_ids.astype(np.int64)
        text_mask = text_mask.astype(np.float32)
        T_real = text_ids.shape[1]
        if T_real > T_BUCKET:
            # _pad never truncates and the .tflite text inputs are padded to a
            # fixed T_BUCKET length, so fail early with an actionable message.
            raise ValueError(
                f"text is {T_real} tokens long but the .tflite models accept at "
                f"most T_BUCKET={T_BUCKET} tokens; split it into shorter chunks")
        style_ttl, style_dp = _load_voice(voice)

        # LiteRT graphs take fixed-length text inputs: pad to T_BUCKET.
        text_ids_p = _pad(text_ids, 1, T_BUCKET)
        text_mask_p = _pad(text_mask, 2, T_BUCKET)
        dur = float(self.dp.predict({"text_ids": text_ids_p, "style_dp": style_dp,
                                     "text_mask": text_mask_p})[0]) / speed
        text_emb_full = self.te.predict({"text_ids": text_ids_p, "style_ttl": style_ttl,
                                         "text_mask": text_mask_p})
        # ONNX VE accepts native shapes — trim text_emb back to T_real
        # (the text axis is the last one, per this slicing).
        text_emb_real = text_emb_full[:, :, :T_real]

        # Latent frames needed: ceil(duration in samples / samples per frame),
        # clamped to [1, L_BUCKET] because the vocoder input is fixed-size.
        frame_samples = BASE_CHUNK_SIZE * CHUNK_COMPRESS_FACTOR
        L_needed = (int(dur * SAMPLE_RATE) + frame_samples - 1) // frame_samples
        if L_needed > L_BUCKET:
            warnings.warn(
                f"predicted duration needs {L_needed} latent frames but the vocoder "
                f"accepts at most L_BUCKET={L_BUCKET}; audio will be truncated to "
                f"~{L_BUCKET * frame_samples / SAMPLE_RATE:.1f} s",
                stacklevel=2)
        L_real = max(1, min(L_BUCKET, L_needed))

        # A private RandomState yields exactly the stream that
        # np.random.seed(seed) + np.random.randn would, without clobbering
        # the caller's global NumPy RNG state.
        rng = np.random.RandomState(seed)
        xt = rng.randn(1, LATENT_DIM * CHUNK_COMPRESS_FACTOR, L_real).astype(np.float32)
        latent_mask = np.ones((1, 1, L_real), dtype=np.float32)
        total_step_arr = np.array([float(total_steps)], dtype=np.float32)
        for step in range(total_steps):
            xt = self.ve.run(None, {
                "noisy_latent": xt, "text_emb": text_emb_real, "style_ttl": style_ttl,
                "text_mask": text_mask, "latent_mask": latent_mask,
                "current_step": np.array([float(step)], dtype=np.float32),
                "total_step": total_step_arr,
            })[0]

        # The vocoder graph takes a fixed L_BUCKET-frame latent.
        xt_padded = _pad(xt, 2, L_BUCKET)
        wav = self.voc.predict({"latent": xt_padded})[0]
        if full_bucket:
            return wav
        return wav[: L_real * frame_samples]

    def synthesize(self, text: str, voice: str = "F1", lang: str = "en", seed: int = 0,
                   total_steps: int = DEFAULT_TOTAL_STEPS, speed: float = DEFAULT_SPEED,
                   auto_pad: str | None = DEFAULT_AUTO_PAD) -> np.ndarray:
        """Synthesize ``text`` and return a waveform at ``SAMPLE_RATE`` Hz.

        With ``auto_pad=None`` a single pass is run, and its output is cut to
        the predicted duration. Otherwise two passes are run with the same
        seed: one on ``text`` and one on ``text + auto_pad``. ``trim_padded``
        then cuts the second rendering at the pause before the filler.

        Raises:
            ValueError: see ``_synth``.
        """
        if auto_pad is None:
            return self._synth(text, voice, lang, seed, total_steps, speed, full_bucket=False)
        unpad = self._synth(text, voice, lang, seed, total_steps, speed, full_bucket=True)
        padded = self._synth(text + auto_pad, voice, lang, seed, total_steps, speed,
                             full_bucket=True)
        return trim_padded(unpad, padded)
def main() -> int:
    """Command-line entry point for single-utterance synthesis.

    Parses the CLI arguments, loads the requested component quantizations,
    synthesizes ``--text`` and writes it to ``--out`` at ``SAMPLE_RATE`` Hz.
    Prints load and synthesis timings along the way.

    Returns:
        Process exit status (0 on success; errors propagate as exceptions).
    """
    ap = argparse.ArgumentParser(
        description="Supertonic-3 TTS: LiteRT (.tflite) components + ONNX vector_estimator.")
    ap.add_argument("--text", required=True)
    ap.add_argument("--voice", default="F1",
                    choices=[f"F{i}" for i in range(1, 6)] + [f"M{i}" for i in range(1, 6)])
    ap.add_argument("--lang", default="en")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--total-steps", type=int, default=DEFAULT_TOTAL_STEPS)
    # Same default as Supertonic3LiteRT.synthesize(speed=...).
    ap.add_argument("--speed", type=float, default=DEFAULT_SPEED,
                    help="Speaking-rate factor; the predicted duration is divided by it "
                         f"(default: {DEFAULT_SPEED}).")
    ap.add_argument("--auto-pad", nargs="?", const=DEFAULT_AUTO_PAD, default=None,
                    help="2-pass synthesis with filler suffix + auto-trim (recommended for long prompts).")
    ap.add_argument("--dp-quant", default="int4", choices=["fp32", "int4"])
    ap.add_argument("--te-quant", default="int4", choices=["fp32", "int4"])
    ap.add_argument("--voc-quant", default="int8", choices=["fp32", "int8", "int4"],
                    help="INT4 vocoder is broken (cos ~0) — use int8 or fp32.")
    ap.add_argument("--ve-fp32", action="store_true",
                    help="Use the full-precision vector_estimator.onnx (256 MB) "
                         "instead of the default INT8 ONNX (65 MB).")
    ap.add_argument("--out", default="out.wav")
    args = ap.parse_args()

    # perf_counter is monotonic and high-resolution, suited to interval timing.
    t0 = time.perf_counter()
    tts = Supertonic3LiteRT(dp_quant=args.dp_quant, te_quant=args.te_quant,
                            voc_quant=args.voc_quant, ve_fp32=args.ve_fp32)
    ve_kind = "fp32" if args.ve_fp32 else "INT8"
    print(f"Loaded models in {time.perf_counter() - t0:.2f}s "
          f"(dp={args.dp_quant}, te={args.te_quant}, voc={args.voc_quant}, ve={ve_kind})")

    t0 = time.perf_counter()
    audio = tts.synthesize(args.text, voice=args.voice, lang=args.lang, seed=args.seed,
                           total_steps=args.total_steps, speed=args.speed,
                           auto_pad=args.auto_pad)
    sf.write(args.out, audio, SAMPLE_RATE)
    print(f"Synthesized {len(audio)/SAMPLE_RATE:.2f}s in {time.perf_counter() - t0:.2f}s -> {args.out}")
    return 0
if __name__ == "__main__":
    # Forward main()'s integer status to the shell as the process exit code.
    exit_status = main()
    sys.exit(exit_status)