| import gc |
| import logging |
| from typing import Generator, Union |
|
|
| import numpy as np |
| import torch |
|
|
| from modules import config, models |
| from modules.ChatTTS import ChatTTS |
| from modules.devices import devices |
| from modules.speaker import Speaker |
| from modules.utils.cache import conditional_cache |
| from modules.utils.SeedContext import SeedContext |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Sample rate paired with every waveform this module returns or yields
# (presumably in Hz, matching the ChatTTS output -- TODO confirm).
SAMPLE_RATE = 24000
|
|
|
|
def generate_audio(
    text: str,
    temperature: float = 0.3,
    top_P: float = 0.7,
    top_K: float = 20,
    spk: Union[int, Speaker] = -1,
    infer_seed: int = -1,
    use_decoder: bool = True,
    prompt1: str = "",
    prompt2: str = "",
    prefix: str = "",
):
    """Synthesize a single text by running it as a one-item batch.

    All options are forwarded unchanged (as keyword arguments) to
    ``generate_audio_batch``; see that function for their handling.

    Returns:
        The ``(sample_rate, wav)`` pair produced for ``text``, i.e. the
        first item of the batch result.
    """
    # Keep these as keyword arguments: the optional cache predicate installed
    # by setup_lru_cache() inspects kwargs only.
    options = dict(
        temperature=temperature,
        top_P=top_P,
        top_K=top_K,
        spk=spk,
        infer_seed=infer_seed,
        use_decoder=use_decoder,
        prompt1=prompt1,
        prompt2=prompt2,
        prefix=prefix,
    )
    first_item = generate_audio_batch([text], **options)[0]
    sample_rate, wav = first_item
    return (sample_rate, wav)
|
|
|
|
def _resolve_spk_emb(chat_tts, spk):
    """Return the speaker embedding to use for ``spk`` (see parse_infer_params)."""
    if isinstance(spk, int):
        # An int is used as the seed under which a random speaker is sampled.
        with SeedContext(spk, True):
            spk_emb = chat_tts.sample_random_speaker()
        logger.debug(("spk", spk))
        return spk_emb

    if isinstance(spk, Speaker):
        if not isinstance(spk.emb, torch.Tensor):
            raise ValueError("spk.pt is broken, please retrain the model.")
        logger.debug(("spk", spk.name))
        return spk.emb

    # Anything else: warn and fall back to the speaker sampled under seed 2.
    logger.warning(
        "spk must be int or Speaker, but: <%s> %s, will set to default voice",
        type(spk),
        spk,
    )
    with SeedContext(2, True):
        return chat_tts.sample_random_speaker()


def parse_infer_params(
    texts: list[str],
    chat_tts: ChatTTS.Chat,
    temperature: float = 0.3,
    top_P: float = 0.7,
    top_K: float = 20,
    spk: Union[int, Speaker] = -1,
    infer_seed: int = -1,
    prompt1: str = "",
    prompt2: str = "",
    prefix: str = "",
):
    """Build the ``params_infer_code`` dict used for a ChatTTS inference call.

    Args:
        texts: Input texts; only logged here.
        chat_tts: Loaded ChatTTS instance, used to sample a random speaker.
        temperature, top_P, top_K: Sampling options, copied into the result.
        spk: Either an int (seed under which a random speaker is sampled) or
            a ``Speaker`` whose ``emb`` tensor is used directly. Any other
            value logs a warning and falls back to the speaker for seed 2.
        infer_seed: Only logged here; the callers apply it themselves.
        prompt1, prompt2, prefix: Prompt strings; falsy values become ``""``.

    Returns:
        dict with keys ``spk_emb``, ``temperature``, ``top_P``, ``top_K``,
        ``prompt1``, ``prompt2``, ``prefix``, ``repetition_penalty`` (fixed
        at 1.0) and ``disable_tqdm`` (from ``config.runtime_env_vars``).

    Raises:
        ValueError: If ``spk`` is a ``Speaker`` whose ``emb`` is not a
            ``torch.Tensor``.
    """
    params_infer_code = {
        "spk_emb": None,
        "temperature": temperature,
        "top_P": top_P,
        "top_K": top_K,
        "prompt1": prompt1 or "",
        "prompt2": prompt2 or "",
        "prefix": prefix or "",
        "repetition_penalty": 1.0,
        "disable_tqdm": config.runtime_env_vars.off_tqdm,
    }

    params_infer_code["spk_emb"] = _resolve_spk_emb(chat_tts, spk)

    logger.debug(
        {
            "text": texts,
            "infer_seed": infer_seed,
            "temperature": temperature,
            "top_P": top_P,
            "top_K": top_K,
            "prompt1": prompt1 or "",
            "prompt2": prompt2 or "",
            "prefix": prefix or "",
        }
    )

    return params_infer_code
|
|
|
|
@torch.inference_mode()
def generate_audio_batch(
    texts: list[str],
    temperature: float = 0.3,
    top_P: float = 0.7,
    top_K: float = 20,
    spk: Union[int, Speaker] = -1,
    infer_seed: int = -1,
    use_decoder: bool = True,
    prompt1: str = "",
    prompt2: str = "",
    prefix: str = "",
):
    """Synthesize audio for every text in ``texts`` with one model call.

    The model is obtained from ``models.load_chat_tts()``, the inference
    parameters from ``parse_infer_params``, and generation is started inside
    ``SeedContext(infer_seed, True)``. When ``config.auto_gc`` is set, torch
    and Python garbage collection run before the result is returned.

    Returns:
        A list with one ``(SAMPLE_RATE, wav)`` tuple per generated waveform,
        where ``wav`` is a flattened ``float32`` numpy array.
    """
    chat_tts = models.load_chat_tts()
    infer_params = parse_infer_params(
        texts=texts,
        chat_tts=chat_tts,
        temperature=temperature,
        top_P=top_P,
        top_K=top_K,
        spk=spk,
        infer_seed=infer_seed,
        prompt1=prompt1,
        prompt2=prompt2,
        prefix=prefix,
    )

    with SeedContext(infer_seed, True):
        wavs = chat_tts.generate_audio(
            prompt=texts,
            params_infer_code=infer_params,
            use_decoder=use_decoder,
        )

    if config.auto_gc:
        devices.torch_gc()
        gc.collect()

    results = []
    for wav in wavs:
        samples = np.array(wav).flatten().astype(np.float32)
        results.append((SAMPLE_RATE, samples))
    return results
|
|
|
|
| |
@torch.inference_mode()
def generate_audio_stream(
    text: str,
    temperature: float = 0.3,
    top_P: float = 0.7,
    top_K: float = 20,
    spk: Union[int, Speaker] = -1,
    infer_seed: int = -1,
    use_decoder: bool = True,
    prompt1: str = "",
    prompt2: str = "",
    prefix: str = "",
) -> Generator[tuple[int, np.ndarray], None, None]:
    """Synthesize ``text`` and yield the audio chunk by chunk.

    The model is obtained from ``models.load_chat_tts()``, the inference
    parameters from ``parse_infer_params``, and ``chat_tts.generate_audio``
    is called with ``stream=True``.

    Yields:
        ``(SAMPLE_RATE, wav)`` tuples, one per item produced by the model,
        where ``wav`` is a flattened ``float32`` numpy array. Tuples are
        used so the items match the declared return type and the items
        returned by ``generate_audio_batch``.

    When ``config.auto_gc`` is set, torch and Python garbage collection run
    once the stream ends, including when the consumer stops iterating early
    or generation raises.
    """
    chat_tts = models.load_chat_tts()
    texts = [text]
    params_infer_code = parse_infer_params(
        texts=texts,
        chat_tts=chat_tts,
        temperature=temperature,
        top_P=top_P,
        top_K=top_K,
        spk=spk,
        infer_seed=infer_seed,
        prompt1=prompt1,
        prompt2=prompt2,
        prefix=prefix,
    )

    try:
        with SeedContext(infer_seed, True):
            wavs_gen = chat_tts.generate_audio(
                prompt=texts,
                params_infer_code=params_infer_code,
                use_decoder=use_decoder,
                stream=True,
            )
            # Consume the stream inside the seed context: if the model
            # produces chunks lazily, leaving the context first would run
            # the actual generation outside of it.
            # NOTE(review): SeedContext is defined elsewhere -- confirm that
            # holding it open across yields is acceptable.
            for wav in wavs_gen:
                yield (SAMPLE_RATE, np.array(wav).flatten().astype(np.float32))
    finally:
        # Runs on normal exhaustion, on generator close() and on errors.
        if config.auto_gc:
            devices.torch_gc()
            gc.collect()
|
|
|
|
# Set to True by the first call to setup_lru_cache(); later calls are no-ops.
lru_cache_enabled = False


def setup_lru_cache():
    """Wrap ``generate_audio_batch`` in a conditional cache, at most once.

    The cache size is read from ``config.runtime_env_vars.lru_size``. If it
    is not an int, nothing is wrapped and a debug message is logged. The
    module-level flag is set on the first call either way, so subsequent
    calls return immediately.
    """
    global generate_audio_batch, lru_cache_enabled

    if lru_cache_enabled:
        return
    lru_cache_enabled = True

    def should_cache(*args, **kwargs):
        # Only keyword arguments are inspected; a call is cacheable when both
        # "spk" and "infer_seed" were given and neither equals -1.
        return kwargs.get("spk", -1) != -1 and kwargs.get("infer_seed", -1) != -1

    lru_size = config.runtime_env_vars.lru_size
    if not isinstance(lru_size, int):
        logger.debug(f"LRU cache failed to enable, invalid size {lru_size}")
        return

    cache_decorator = conditional_cache(lru_size, should_cache)
    generate_audio_batch = cache_decorator(generate_audio_batch)
    logger.info(f"LRU cache enabled with size {lru_size}")
|
|
|
|
if __name__ == "__main__":
    # Manual smoke script: synthesize the same texts once as a batch and once
    # one-by-one, writing each result to a WAV file in the working directory.
    import soundfile as sf

    inputs = ["你好[lbreak]", "再见[lbreak]", "长度不同的文本片段[lbreak]"]

    # Batched inference.
    outputs = generate_audio_batch(inputs, spk=5, infer_seed=42)
    for i, item in enumerate(outputs):
        sample_rate, wav = item
        print(i, sample_rate, wav.shape)
        sf.write(f"batch_{i}.wav", wav, sample_rate, format="wav")

    # Single-text inference with the same speaker and seed.
    for i, text in enumerate(inputs):
        item = generate_audio(text, spk=5, infer_seed=42)
        sample_rate, wav = item
        print(i, sample_rate, wav.shape)
        sf.write(f"one_{i}.wav", wav, sample_rate, format="wav")
|
|