import inspect
import os

from trainer import Trainer, TrainerArgs

from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

# Bounds on training clip length, expressed in seconds. They are converted to a
# sample count with the configured audio sample rate (see main()) instead of a
# hard-coded rate, so changing the audio config cannot silently skew them.
MAX_AUDIO_SECONDS = 10
MIN_AUDIO_SECONDS = 1


def _ensure_best_loss_initialized(trainer):
    """Seed ``trainer.best_loss["train_loss"]`` with +inf when it is unset.

    Workaround applied before ``trainer.fit()``:

    * if ``trainer`` has no ``best_loss`` attribute, or it is ``None``, it is
      replaced with ``{"train_loss": float("inf")}``;
    * if it is a dict whose ``"train_loss"`` entry is missing or ``None``, that
      entry is set to ``float("inf")`` in place;
    * any other value (e.g. a float) is left untouched.

    NOTE(review): presumably this guards against a trainer version that starts
    with ``best_loss=None`` and then compares losses against it — confirm which
    ``trainer`` release needs this before removing it.
    """
    best_loss = getattr(trainer, "best_loss", None)
    if best_loss is None:
        trainer.best_loss = {"train_loss": float("inf")}
    elif isinstance(best_loss, dict) and best_loss.get("train_loss") is None:
        best_loss["train_loss"] = float("inf")


def main():
    """Train a GlowTTS model on an LJSpeech-format dataset.

    Paths are resolved relative to the directory containing this script:
    the dataset is read from ``LJSpeech-1.1/`` (metadata in ``metadata.csv``,
    parsed with the ``"ljspeech"`` formatter), the phoneme cache is kept in
    ``phoneme_cache/``, and the script directory itself is passed to the
    config and the trainer as ``output_path``.
    """
    output_path = os.path.dirname(os.path.abspath(__file__))

    dataset_config = BaseDatasetConfig(
        formatter="ljspeech",
        meta_file_train="metadata.csv",
        path=os.path.join(output_path, "LJSpeech-1.1/"),
    )

    config = GlowTTSConfig(
        batch_size=256,
        eval_batch_size=128,
        num_loader_workers=4,
        num_eval_loader_workers=2,
        run_eval=True,
        test_delay_epochs=-1,
        epochs=600,
        text_cleaner="phoneme_cleaners",
        use_phonemes=True,
        phoneme_language="en-us",
        phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
        print_step=25,
        print_eval=False,
        mixed_precision=True,
        output_path=output_path,
        datasets=[dataset_config],
    )
    # Clip-length limits are in samples; derive them from the configured rate
    # rather than hard-coding it.
    sample_rate = config.audio.sample_rate
    config.max_audio_len = MAX_AUDIO_SECONDS * sample_rate
    config.min_audio_len = MIN_AUDIO_SECONDS * sample_rate

    # AudioProcessor's constructor takes individual audio keyword arguments and
    # swallows unknown ones, so passing ``config=`` would not apply the audio
    # settings. The factory unpacks ``config.audio`` correctly, mirroring the
    # tokenizer factory below.
    ap = AudioProcessor.init_from_config(config)

    # init_from_config returns a (tokenizer, config) pair; keep the returned
    # config for everything that follows.
    tokenizer, config = TTSTokenizer.init_from_config(config)

    train_samples, eval_samples = load_tts_samples(
        config,
        eval_split=True,
        eval_split_max_size=20,
    )

    model = GlowTTS(config, ap, tokenizer=tokenizer, speaker_manager=None)

    trainer = Trainer(
        TrainerArgs(),
        config,
        output_path,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
        training_assets={"audio_processor": ap},
    )
    _ensure_best_loss_initialized(trainer)
    trainer.fit()


# The guard matters here: data loading uses worker processes
# (num_loader_workers > 0), which may re-import this module.
if __name__ == "__main__":
    main()