| import os |
| from pathlib import Path |
| from kokoro_onnx import Kokoro |
| from misaki import espeak, en, zh |
| from misaki.espeak import EspeakG2P |
| from logging import getLogger |
| import onnxruntime |
|
|
| from lib.utils import Timer, write_audio |
|
|
|
|
logger = getLogger(__name__)
# Execution providers ONNX Runtime reports as available in this install.
# NOTE(review): kept as a public module name; create_session queries the
# providers again itself rather than reading this value.
providers = onnxruntime.get_available_providers()
# Directory expected to contain kokoro-quant.onnx, voices-v1.0.bin and
# zh_config.json (the files KokoroTTS.from_language loads). It can be
# overridden with the KOKORO_MODEL_DIR environment variable; the fallback is
# the original developer-machine path with the stray doubled leading slash
# removed.
MODEL_DIR = Path(
    os.environ.get(
        "KOKORO_MODEL_DIR",
        "/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models/kokoro",
    )
)
|
|
def create_session(model_path, providers=None, num_threads=None):
    """Create an ONNX Runtime inference session for ``model_path``.

    Args:
        model_path: Path to the ``.onnx`` model file.
        providers: Execution providers to use. ``None`` (the default) pins the
            session to ``["CPUExecutionProvider"]``, as this function always did,
            regardless of which providers are available.
        num_threads: Value for ``SessionOptions.intra_op_num_threads``. ``None``
            (the default) uses half of ``os.cpu_count()``, but never less than 1.

    Returns:
        The constructed ``onnxruntime.InferenceSession``.
    """
    available_providers = onnxruntime.get_available_providers()
    print(f"Available onnx runtime providers: {available_providers}")

    if providers is None:
        # Default to CPU only, even when accelerated providers are listed above.
        providers = ["CPUExecutionProvider"]

    if num_threads is None:
        # os.cpu_count() may return None (count undeterminable); the previous
        # `os.cpu_count() // 2` then raised TypeError. Also clamp to >= 1.
        num_threads = max(1, (os.cpu_count() or 2) // 2)

    sess_options = onnxruntime.SessionOptions()
    print(f"Setting threads to CPU cores count: {num_threads}")
    sess_options.intra_op_num_threads = num_threads
    return onnxruntime.InferenceSession(
        model_path, providers=providers, sess_options=sess_options
    )
|
|
|
|
class KokoroTTS:
    """Text-to-speech wrapper around a ``kokoro_onnx.Kokoro`` model.

    Input text is converted to phonemes by ``self.g2p`` and then synthesised
    with ``self.voice``. The model runs on a session built by ``create_session``.
    """

    # Upper-case language code -> Kokoro voice name.
    # NOTE(review): PT and ES reuse the Italian voice "im_nicola"; confirm this
    # is intended rather than a placeholder for dedicated voices.
    language_voice_mapping = {
        "JP": "jf_alpha",
        "JA": "jf_alpha",
        "ZH": "zf_xiaoyi",
        "EN": "af_heart",
        "FR": "ff_siwis",
        "IT": "im_nicola",
        "HI": "hf_alpha",
        "PT": "im_nicola",
        "ES": "im_nicola",
    }
    # Upper-case language code -> short phrase synthesised once after loading
    # to warm the model up. JA/JP were missing here, which made from_language
    # call generate(None) and crash for Japanese.
    language_word_mapping = {
        "JP": "こんにちは",
        "JA": "こんにちは",
        "ZH": "你好",
        "EN": "hello",
        "FR": "Bonjour",
        "IT": "Ciao",
        "HI": "हेलो",
        "PT": "Olá",
        "ES": "Hola",
    }
    # Upper-case language code -> espeak-ng language passed to EspeakG2P.
    # "JP" must map to "ja"; the old fallback produced the invalid code "jp".
    _espeak_language_mapping = {
        "HI": "hi",
        "IT": "it",
        "PT": "pt-br",
        "ES": "es",
        "FR": "fr-fr",
        "JP": "ja",
        "JA": "ja",
    }

    def __init__(self, model_path: str, voice_model_path: str, vocab_config=None, gcp=None, voice=None):
        """Load the ONNX model and voices and store the G2P front end.

        Args:
            model_path: Path to the Kokoro ``.onnx`` model.
            voice_model_path: Path to the voices file.
            vocab_config: Optional vocabulary config forwarded to
                ``Kokoro.from_session``.
            gcp: Grapheme-to-phoneme callable. It is called as
                ``phonemes, _ = gcp(text)``. The parameter name is kept for
                compatibility.
            voice: Kokoro voice name used for synthesis.
        """
        self._session = create_session(model_path)
        self.model = Kokoro.from_session(self._session, voice_model_path, vocab_config=vocab_config)
        self.g2p = gcp
        self.voice = voice

    @classmethod
    def _build_g2p(cls, code: str, model_dir: Path):
        """Return ``(g2p, vocab_config)`` for the upper-case language ``code``."""
        if code == "ZH":
            # Pass a str path, like model_path/voice_model_path, because
            # kokoro-onnx's vocab loader dispatches on str/dict.
            return zh.ZHG2P(), str(model_dir / "zh_config.json")
        if code == "EN":
            fallback = espeak.EspeakFallback(british=False)
            return en.G2P(trf=False, british=False, fallback=fallback), None
        espeak_language = cls._espeak_language_mapping.get(code, code.lower())
        return EspeakG2P(language=espeak_language), None

    @classmethod
    def from_language(cls, language: str, model_dir: Path = MODEL_DIR):
        """Build a ``KokoroTTS`` for ``language`` and warm it up once.

        Args:
            language: Language code, case-insensitive (e.g. ``"zh"``, ``"EN"``).
            model_dir: Directory containing ``kokoro-quant.onnx``,
                ``voices-v1.0.bin`` and, for Chinese, ``zh_config.json``.
                A ``str`` is accepted too.

        Returns:
            A ready-to-use ``KokoroTTS`` instance.

        Raises:
            ValueError: If no voice is configured for ``language``.
        """
        code = language.upper()
        model_dir = Path(model_dir)
        model_path = str(model_dir / "kokoro-quant.onnx")
        voice_model_path = str(model_dir / "voices-v1.0.bin")
        voice = cls.language_voice_mapping.get(code)
        warm_up_text = cls.language_word_mapping.get(code)
        logger.info("[TTS] language: %s", language)
        if not voice:
            raise ValueError(f"Unsupported language: {language}, voice: {voice}")
        g2p, vocab_config = cls._build_g2p(code, model_dir)
        with Timer("load tts"):
            tts = cls(model_path, voice_model_path, vocab_config=vocab_config, gcp=g2p, voice=voice)
            # Skip warm-up rather than calling g2p(None) if no phrase is configured.
            if warm_up_text:
                tts.generate(warm_up_text)
        return tts

    def generate(self, text, speed=1.2):
        """Synthesise ``text`` in one shot.

        Returns:
            ``(samples, sample_rate, duration)``. ``duration`` is the value
            ``Timer`` reports for the G2P and synthesis step.
        """
        with Timer("tts inference") as t:
            phonemes, _ = self.g2p(text)
            samples, sample_rate = self.model.create(phonemes, self.voice, is_phonemes=True, speed=speed)
        return samples, sample_rate, t.duration

    async def stream(self, text, speed=1.2):
        """Asynchronously yield ``(samples, sample_rate)`` chunks for ``text``.

        The chunks come from ``self.model.create_stream``.
        """
        phonemes, _ = self.g2p(text)
        stream = self.model.create_stream(phonemes, self.voice, is_phonemes=True, speed=speed)
        async for samples, sample_rate in stream:
            yield samples, sample_rate
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: synthesise one Chinese sentence, save it as a WAV
    # file, then print the inference time reported by generate().
    demo_tts = KokoroTTS.from_language(language="ZH")
    audio, rate, elapsed = demo_tts.generate("今天天气怎么样?")
    write_audio("tts_out.wav", audio, rate)
    print(elapsed)
|
|