"""Run the merged Kokoro decoder/vocoder LiteRT artifact with custom op. The input NPZ must contain these arrays: asr, f0_curve, noise, style, valid_frames, initial_phase, sine_noise This example is intentionally decoder-only. It does not perform Kokoro text normalization, phonemization, duration prediction, or frontend inference. """ from __future__ import annotations import argparse import ctypes from pathlib import Path import numpy as np from ai_edge_litert import interpreter as litert_interpreter INPUT_NAMES = ( "asr", "f0_curve", "noise", "style", "valid_frames", "initial_phase", "sine_noise", ) def register_kokoro_source_stft(shared_object: Path): library = ctypes.CDLL(str(shared_object), mode=ctypes.RTLD_GLOBAL) register_native = library.RegisterKokoroSourceStft register_native.argtypes = [ctypes.c_uint64] register_native.restype = None def registerer(resolver_pointer: int) -> None: register_native(int(resolver_pointer)) return registerer def parse_args() -> argparse.Namespace: artifact_root = Path(__file__).resolve().parents[1] parser = argparse.ArgumentParser() parser.add_argument( "--model", type=Path, default=artifact_root / "kokoro_decoder_source_stft_merged.tflite", ) parser.add_argument( "--custom-op", type=Path, default=artifact_root / "custom_ops" / "linux-x86_64" / "kokoro_source_stft_custom_op_native.so", ) parser.add_argument("--inputs", type=Path, required=True) parser.add_argument("--output", type=Path, default=Path("waveform.npy")) return parser.parse_args() def main() -> int: args = parse_args() inputs = np.load(args.inputs) missing = [name for name in INPUT_NAMES if name not in inputs] if missing: raise KeyError(f"input NPZ is missing required arrays: {missing}") interpreter = litert_interpreter.InterpreterWithCustomOps( model_path=str(args.model), custom_op_registerers=[register_kokoro_source_stft(args.custom_op)], ) interpreter.allocate_tensors() for detail, name in zip(interpreter.get_input_details(), INPUT_NAMES, strict=True): interpreter.set_tensor(detail["index"], inputs[name]) interpreter.invoke() outputs = interpreter.get_output_details() waveform = interpreter.get_tensor(outputs[0]["index"]) valid_samples = int(interpreter.get_tensor(outputs[1]["index"])[0]) args.output.parent.mkdir(parents=True, exist_ok=True) np.save(args.output, waveform[..., :valid_samples]) print(f"wrote {args.output} with {valid_samples} valid samples at 24000 Hz") return 0 if __name__ == "__main__": raise SystemExit(main())