| import os |
| import inspect |
| from trainer import Trainer, TrainerArgs |
| from TTS.tts.configs.glow_tts_config import GlowTTSConfig |
| from TTS.tts.models.glow_tts import GlowTTS |
| from TTS.tts.configs.shared_configs import BaseDatasetConfig |
| from TTS.tts.datasets import load_tts_samples |
| from TTS.tts.utils.text.tokenizer import TTSTokenizer |
| from TTS.utils.audio import AudioProcessor |
|
|
def main():
    """Train a GlowTTS model on the LJSpeech-1.1 data found next to this script.

    The directory containing this file serves three roles: it is the parent of
    the ``LJSpeech-1.1/`` dataset folder, the parent of the ``phoneme_cache``
    folder, and the output path handed to both the model config and the
    trainer. The function builds the configs, audio processor, tokenizer,
    samples, model and trainer in that order, then calls ``trainer.fit()``.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Dataset description: LJSpeech-formatted metadata inside the script folder.
    ljspeech_dataset = BaseDatasetConfig(
        formatter="ljspeech",
        meta_file_train="metadata.csv",
        path=os.path.join(script_dir, "LJSpeech-1.1/"),
    )

    phoneme_cache_dir = os.path.join(script_dir, "phoneme_cache")
    config = GlowTTSConfig(
        # Loader / batching setup.
        batch_size=256,
        eval_batch_size=128,
        num_loader_workers=4,
        num_eval_loader_workers=2,
        # Schedule.
        run_eval=True,
        test_delay_epochs=-1,
        epochs=600,
        # Text front end: phoneme-based input with an on-disk cache.
        text_cleaner="phoneme_cleaners",
        use_phonemes=True,
        phoneme_language="en-us",
        phoneme_cache_path=phoneme_cache_dir,
        # Logging.
        print_step=25,
        print_eval=False,
        mixed_precision=True,
        output_path=script_dir,
        datasets=[ljspeech_dataset],
        # Audio length bounds, in samples.
        # NOTE(review): presumably 10 s / 1 s at 22050 Hz -- confirm against
        # the sample rate in config.audio.
        max_audio_len=22050 * 10,
        min_audio_len=22050 * 1,
    )

    # The audio processor is created from the audio section of the config
    # before the tokenizer step below rebinds `config`.
    audio_processor = AudioProcessor(config=config.audio)

    # init_from_config returns the tokenizer together with a config object;
    # that returned config is the one used from here on.
    tokenizer, config = TTSTokenizer.init_from_config(config)

    train_samples, eval_samples = load_tts_samples(
        config,
        eval_split=True,
        eval_split_max_size=20,
    )

    model = GlowTTS(
        config,
        audio_processor,
        tokenizer=tokenizer,
        speaker_manager=None,
    )

    trainer_args = TrainerArgs()
    assets = {"audio_processor": audio_processor}
    trainer = Trainer(
        trainer_args,
        config,
        script_dir,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
        training_assets=assets,
    )

    # Make sure the trainer starts with a numeric "train_loss" baseline:
    # a missing/None best_loss is replaced by a fresh dict, and a dict whose
    # "train_loss" is None gets that entry filled in. Any other value is left
    # untouched.
    # NOTE(review): presumably a workaround for the trainer comparing losses
    # against None -- confirm against the installed trainer version.
    no_loss_yet = float("inf")
    best_loss = getattr(trainer, "best_loss", None)
    if best_loss is None:
        trainer.best_loss = {"train_loss": no_loss_yet}
    elif isinstance(best_loss, dict) and best_loss.get("train_loss") is None:
        best_loss["train_loss"] = no_loss_yet

    trainer.fit()
|
|
# Run the training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
|