import os

import hydra
import pytorch_lightning as pl
import torch
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies import DDPStrategy

from data_module import DataModule
from lightning_module import CodecLightningModule

# Fix the global seed so weight init and data shuffling are reproducible.
seed = 1024
seed_everything(seed)

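# Assumed config shape (illustrative only): the real schema lives in
# config/default.yaml, which is not shown here. Based on the keys this script
# reads (cfg.log_dir, cfg.debug, cfg.train.trainer), a minimal file might look
# roughly like:
#
#   log_dir: ./logs/exp1      # checkpoints + TensorBoard events are written here
#   debug: false              # true -> train on ~0.1% of batches
#   train:
#     trainer:                # kwargs forwarded verbatim to pl.Trainer(...)
#       max_epochs: 100       # example value, not from this repo
#       devices: 1            # example value, not from this repo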
@hydra.main(config_path='config', config_name='default', version_base=None)
def train(cfg):
    # Keep the 5 best checkpoints by mel loss, evaluated every 5000 training
    # steps, and always write 'last.ckpt' so an interrupted run can resume.
    checkpoint_callback = ModelCheckpoint(
        dirpath=cfg.log_dir,
        save_top_k=5,
        save_last=True,
        every_n_train_steps=5000,
        monitor='mel_loss',
        mode='min'
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
    callbacks = [checkpoint_callback, lr_monitor]

    datamodule = DataModule(cfg)
    lightning_module = CodecLightningModule(cfg)

    # Empty name/version make TensorBoard write event files directly into
    # cfg.log_dir instead of a nested <name>/<version> subdirectory.
    tensorboard_logger = TensorBoardLogger(
        save_dir=cfg.log_dir,
        name="",
        version="",
        log_graph=False,
        default_hp_metric=True
    )

    # Auto-resume: ModelCheckpoint(save_last=True) keeps 'last.ckpt' in
    # cfg.log_dir, so restart from it whenever it exists.
    ckpt_path = None
    last_ckpt = os.path.join(cfg.log_dir, 'last.ckpt')
    if os.path.exists(last_ckpt):
        ckpt_path = last_ckpt
        print(f"Resuming from checkpoint: {ckpt_path}")
    else:
        print("No checkpoint found, starting training from scratch.")

    # cfg.train.trainer is unpacked into pl.Trainer, so it must not repeat the
    # keys set explicitly below (strategy, callbacks, logger, profiler,
    # limit_train_batches) or Python will raise a duplicate-keyword error.
    trainer = pl.Trainer(
        **cfg.train.trainer,
        strategy=DDPStrategy(find_unused_parameters=True),
        callbacks=callbacks,
        logger=tensorboard_logger,
        profiler="simple",
        limit_train_batches=0.001 if cfg.debug else 1.0  # tiny subset in debug mode
    )

    # Let cuDNN benchmark and pick the fastest kernels for fixed-size batches.
    torch.backends.cudnn.benchmark = True

    trainer.fit(lightning_module, datamodule=datamodule, ckpt_path=ckpt_path)

    print(f'Training finished. Best score: {checkpoint_callback.best_model_score}, '
          f'checkpoint: {checkpoint_callback.best_model_path}')


if __name__ == '__main__':
    # Use 'spawn' so CUDA can be initialised safely in subprocesses
    # (DDP workers, DataLoader workers).
    torch.multiprocessing.set_start_method('spawn', force=True)
    train()
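
# Example launch (illustrative; the script/config names are assumptions, and any
# config field can be overridden on the command line with Hydra's key=value syntax):
#   python train.py log_dir=./logs/exp1 debug=false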