File size: 2,765 Bytes
d7bdcbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""Run the merged Kokoro decoder/vocoder LiteRT artifact with its custom op.

The input NPZ must contain these arrays:

  asr, f0_curve, noise, style, valid_frames, initial_phase, sine_noise

This example is intentionally decoder-only. It does not perform Kokoro text
normalization, phonemization, duration prediction, or frontend inference.
"""

from __future__ import annotations

import argparse
import ctypes
from pathlib import Path

import numpy as np
from ai_edge_litert import interpreter as litert_interpreter

# Arrays that must be present in the input NPZ. main() pairs these names
# positionally with interpreter.get_input_details(), so this order has to
# match the model's input tensor order.
INPUT_NAMES = (
    "asr", "f0_curve", "noise", "style",
    "valid_frames", "initial_phase", "sine_noise",
)


def register_kokoro_source_stft(shared_object: Path):
    """Load the native custom-op library and build a LiteRT op registerer.

    The shared object is opened with ``RTLD_GLOBAL`` and must export a
    ``RegisterKokoroSourceStft`` symbol. The returned callable accepts the
    resolver address supplied by the interpreter (as an integer) and passes
    it, as an unsigned 64-bit value, to that native function.

    Raises:
        OSError: if the shared object cannot be loaded.
        AttributeError: if the library does not export
            ``RegisterKokoroSourceStft``.
    """
    native_lib = ctypes.CDLL(str(shared_object), mode=ctypes.RTLD_GLOBAL)

    native_register = native_lib.RegisterKokoroSourceStft
    native_register.restype = None
    native_register.argtypes = [ctypes.c_uint64]

    def register_with_resolver(resolver_address: int) -> None:
        native_register(int(resolver_address))

    return register_with_resolver


def parse_args() -> argparse.Namespace:
    artifact_root = Path(__file__).resolve().parents[1]
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=Path,
        default=artifact_root / "kokoro_decoder_source_stft_merged.tflite",
    )
    parser.add_argument(
        "--custom-op",
        type=Path,
        default=artifact_root
        / "custom_ops"
        / "linux-x86_64"
        / "kokoro_source_stft_custom_op_native.so",
    )
    parser.add_argument("--inputs", type=Path, required=True)
    parser.add_argument("--output", type=Path, default=Path("waveform.npy"))
    return parser.parse_args()


def main() -> int:
    """Run the decoder on the NPZ inputs and save the trimmed waveform.

    Reads every array named in ``INPUT_NAMES`` from ``--inputs``, feeds the
    arrays to the interpreter in that order, and invokes it. The first output
    is truncated to the sample count reported by the second output and saved
    to ``--output``; missing parent directories are created.

    Returns:
        Process exit code (0 on success).

    Raises:
        KeyError: if the NPZ lacks any of the required arrays.
    """
    args = parse_args()

    # For an .npz, np.load returns an NpzFile that keeps the archive open.
    # Read the required arrays eagerly, then let the context manager close it.
    with np.load(args.inputs) as archive:
        missing = [name for name in INPUT_NAMES if name not in archive]
        if missing:
            raise KeyError(f"input NPZ is missing required arrays: {missing}")
        arrays = {name: archive[name] for name in INPUT_NAMES}

    interpreter = litert_interpreter.InterpreterWithCustomOps(
        model_path=str(args.model),
        custom_op_registerers=[register_kokoro_source_stft(args.custom_op)],
    )
    interpreter.allocate_tensors()

    # Inputs are bound by position, so INPUT_NAMES must follow the model's
    # input order; strict=True fails fast if the counts differ.
    for detail, name in zip(interpreter.get_input_details(), INPUT_NAMES, strict=True):
        interpreter.set_tensor(detail["index"], arrays[name])

    interpreter.invoke()
    outputs = interpreter.get_output_details()
    waveform = interpreter.get_tensor(outputs[0]["index"])
    # The second output is used as the valid-sample count. Flattening accepts
    # either a scalar or a one-element tensor (assumes at least one element).
    valid_samples = int(np.ravel(interpreter.get_tensor(outputs[1]["index"]))[0])
    args.output.parent.mkdir(parents=True, exist_ok=True)
    np.save(args.output, waveform[..., :valid_samples])
    print(f"wrote {args.output} with {valid_samples} valid samples at 24000 Hz")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)