wdga's picture
Upload Kokoro LiteRT runtime preview
d7bdcbf verified
"""Run the merged Kokoro decoder/vocoder LiteRT artifact with custom op.

The input NPZ must contain these arrays:
    asr, f0_curve, noise, style, valid_frames, initial_phase, sine_noise

This example is intentionally decoder-only. It does not perform Kokoro text
normalization, phonemization, duration prediction, or frontend inference.
"""
from __future__ import annotations
import argparse
import ctypes
from pathlib import Path
import numpy as np
from ai_edge_litert import interpreter as litert_interpreter
# Names of the arrays read from the input NPZ. main() pairs these names
# positionally with the interpreter's input details, so the order of this
# tuple is significant and must not be changed casually.
INPUT_NAMES = (
    "asr",
    "f0_curve",
    "noise",
    "style",
    "valid_frames",
    "initial_phase",
    "sine_noise",
)
def register_kokoro_source_stft(shared_object: Path):
    """Load the custom-op shared library and build a registerer callback.

    The library is opened immediately (in ``RTLD_GLOBAL`` mode) and its
    ``RegisterKokoroSourceStft`` symbol is looked up right away, so a missing
    file or missing symbol fails here rather than when the callback runs.

    Args:
        shared_object: Path to the native library that exports
            ``RegisterKokoroSourceStft``.

    Returns:
        A callable that accepts the op-resolver address as an integer and
        forwards it to the native registration function. It returns ``None``.

    Raises:
        OSError: If ``ctypes`` cannot load the shared library.
        AttributeError: If the library does not export the expected symbol.
    """
    native_lib = ctypes.CDLL(str(shared_object), mode=ctypes.RTLD_GLOBAL)
    native_entry = native_lib.RegisterKokoroSourceStft

    # The native entry point is declared as taking one unsigned 64-bit
    # integer (the resolver address) and returning nothing.
    native_entry.restype = None
    native_entry.argtypes = [ctypes.c_uint64]

    def registerer(resolver_pointer: int) -> None:
        # Coerce to a plain int so ctypes can marshal the address value.
        native_entry(int(resolver_pointer))

    return registerer
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for this example.

    Options:
        --model: Path to the ``.tflite`` model file. Defaults to
            ``kokoro_decoder_source_stft_merged.tflite`` in the artifact root.
        --custom-op: Path to the custom-op shared library. Defaults to the
            ``linux-x86_64`` build under ``custom_ops`` in the artifact root.
        --inputs: Path to the input NPZ file (required).
        --output: Where to save the waveform. Defaults to ``waveform.npy``.

    Returns:
        The parsed namespace with ``model``, ``custom_op``, ``inputs`` and
        ``output`` attributes, each a ``Path``.
    """
    # The artifact root is the parent of the directory holding this script.
    bundle_dir = Path(__file__).resolve().parents[1]
    default_model = bundle_dir.joinpath("kokoro_decoder_source_stft_merged.tflite")
    default_custom_op = bundle_dir.joinpath(
        "custom_ops",
        "linux-x86_64",
        "kokoro_source_stft_custom_op_native.so",
    )

    cli = argparse.ArgumentParser()
    cli.add_argument("--model", type=Path, default=default_model)
    cli.add_argument("--custom-op", type=Path, default=default_custom_op)
    cli.add_argument("--inputs", type=Path, required=True)
    cli.add_argument("--output", type=Path, default=Path("waveform.npy"))
    return cli.parse_args()
def main() -> int:
    """Run the decoder model on the arrays in ``--inputs`` and save the result.

    The required arrays are read from the input NPZ, bound to the model's
    inputs in ``INPUT_NAMES`` order, and the interpreter is invoked once. The
    first output is treated as the waveform and the second as the number of
    valid samples; the waveform is truncated to that length along its last
    axis and written with ``np.save``.

    Returns:
        0 on success, used as the process exit status.

    Raises:
        KeyError: If the input NPZ lacks any array named in ``INPUT_NAMES``.
        ValueError: If the model's input count differs from
            ``len(INPUT_NAMES)`` (raised by ``zip(..., strict=True)``).
    """
    args = parse_args()

    # An .npz archive keeps its file handle open until closed, so read every
    # required array inside the context manager and release the handle
    # before the (potentially long) inference step.
    with np.load(args.inputs) as archive:
        missing = [name for name in INPUT_NAMES if name not in archive]
        if missing:
            raise KeyError(f"input NPZ is missing required arrays: {missing}")
        arrays = {name: archive[name] for name in INPUT_NAMES}

    interpreter = litert_interpreter.InterpreterWithCustomOps(
        model_path=str(args.model),
        custom_op_registerers=[register_kokoro_source_stft(args.custom_op)],
    )
    interpreter.allocate_tensors()

    # Inputs are bound positionally: the i-th model input receives the array
    # named by the i-th entry of INPUT_NAMES.
    # NOTE(review): this assumes the model's input order matches INPUT_NAMES.
    for detail, name in zip(interpreter.get_input_details(), INPUT_NAMES, strict=True):
        interpreter.set_tensor(detail["index"], arrays[name])
    interpreter.invoke()

    outputs = interpreter.get_output_details()
    waveform = interpreter.get_tensor(outputs[0]["index"])
    # Flatten before indexing so the count is read correctly whether the
    # tensor is 0-d, 1-d, or carries extra singleton dimensions.
    valid_samples = int(interpreter.get_tensor(outputs[1]["index"]).reshape(-1)[0])

    # np.save appends ".npy" when the name lacks it; apply the same rule up
    # front so the reported path is the file that was actually written.
    output_path = args.output
    if not str(output_path).endswith(".npy"):
        output_path = Path(str(output_path) + ".npy")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    np.save(output_path, waveform[..., :valid_samples])
    print(f"wrote {output_path} with {valid_samples} valid samples at 24000 Hz")
    return 0
if __name__ == "__main__":
raise SystemExit(main())