import os
import os.path as osp
import time
import shutil
import yaml
import click
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from utils import *
from optimizers import build_optimizer
from model import *
from meldataset import build_dataloader

from accelerate import Accelerator, DistributedDataParallelKwargs

import logging
from accelerate.logging import get_logger

# accelerate's logger wraps the stdlib logger so that messages are emitted
# only on the main process in distributed runs; attach a console handler to
# the underlying logger so they still reach stdout.
logger = get_logger(__name__, log_level="DEBUG")
logger.logger.addHandler(logging.StreamHandler())

torch.backends.cudnn.benchmark = True


def log_print(message, logger):
    """Log a message and echo it to stdout."""
    logger.info(message)
    print(message)


@click.command()
@click.option('-p', '--config_path', default='./Configs/config.yml', type=str)
def main(config_path):
    with open(config_path) as f:
        config = yaml.safe_load(f)
    log_dir = config['log_dir']
    os.makedirs(log_dir, exist_ok=True)
    shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))

    ddp_kwargs = DistributedDataParallelKwargs()
    accelerator = Accelerator(project_dir=log_dir, split_batches=True, kwargs_handlers=[ddp_kwargs])
    # Only the main process writes TensorBoard summaries.
    writer = SummaryWriter(osp.join(log_dir, "tensorboard")) if accelerator.is_main_process else None

    file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
    logger.logger.addHandler(file_handler)

    save_every = 1       # save a checkpoint every N epochs
    batch_size = config.get('batch_size', 4)
    log_interval = 10    # log training loss every N steps
    device = accelerator.device
    train_path = config.get('train_data', None)
    val_path = config.get('val_data', None)
    epochs = config.get('epochs', 1000)

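    # get_data_path_list comes from utils; it is assumed to return the train
    # and validation metadata as lists of sample entries.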
    train_list, val_list = get_data_path_list(train_path, val_path)

    train_dataloader = build_dataloader(train_list,
                                        batch_size=batch_size,
                                        num_workers=8,
                                        dataset_config=config.get('dataset_params', {}),
                                        device=device)

    val_dataloader = build_dataloader(val_list,
                                      batch_size=batch_size,
                                      validation=True,
                                      num_workers=2,
                                      device=device,
                                      dataset_config=config.get('dataset_params', {}))

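    # AlignerModel and ForwardSumLoss come from model.py; ForwardSumLoss is
    # presumably the CTC-style forward-sum alignment objective used by
    # RAD-TTS-style text-to-speech aligners.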
    aligner = AlignerModel()
    forward_sum_loss = ForwardSumLoss()
    best_val_loss = float('inf')

    scheduler_params = {
        "max_lr": float(config['optimizer_params'].get('lr', 5e-4)),
        "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
        "epochs": epochs,
        "steps_per_epoch": len(train_dataloader),
    }

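    # These keys match torch.optim.lr_scheduler.OneCycleLR's signature, so
    # build_optimizer is assumed to construct a one-cycle schedule from them.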
    optimizer, scheduler = build_optimizer(
        {"params": aligner.parameters(), "optimizer_params": {}, "scheduler_params": scheduler_params})

    # accelerator.prepare wraps the model for DDP, shards the dataloaders
    # across processes, and places batches on the correct device.
    aligner, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
        aligner, optimizer, train_dataloader, val_dataloader, scheduler
    )

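    # load_checkpoint (from utils) is assumed to restore the model and
    # optimizer state and return the epoch and iteration to resume from.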
    with accelerator.main_process_first():
        if config.get('pretrained_model', '') != '':
            aligner, optimizer, start_epoch, iters = load_checkpoint(
                aligner, optimizer, config['pretrained_model'],
                load_only_params=config.get('load_only_params', True))
        else:
            start_epoch = 0
            iters = 0

    for epoch in range(start_epoch + 1, epochs + 1):
        aligner.train()
        train_losses = []
        start_time = time.time()

        pbar = tqdm(train_dataloader, desc=f"Epoch {epoch}/{epochs} [Train]")
        for i, batch in enumerate(pbar):
            # accelerator.prepare already places batches on the device, so
            # this is only a safeguard.
            batch = [b.to(device) for b in batch]

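            # Assumed shapes: text_input (B, T_text), mel_input
            # (B, n_mels, T_mel); attn_prior is the dataset's prior over the
            # (T_text, T_mel) alignment grid.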
            text_input, text_input_length, mel_input, mel_input_length, attn_prior = batch

            attn_soft, attn_logprob = aligner(spec=mel_input,
                                              spec_len=mel_input_length,
                                              text=text_input,
                                              text_len=text_input_length,
                                              attn_prior=attn_prior)

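            # The forward-sum loss scores all monotonic text-to-mel
            # alignments with a CTC-style forward pass over attn_logprob.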
            loss = forward_sum_loss(attn_logprob=attn_logprob,
                                    in_lens=text_input_length,
                                    out_lens=mel_input_length)

            optimizer.zero_grad()
            accelerator.backward(loss)

            grad_norm = accelerator.clip_grad_norm_(aligner.parameters(), 5.0)

            optimizer.step()
            iters += 1

            if scheduler is not None:
                scheduler.step()

            if (i + 1) % log_interval == 0 and accelerator.is_main_process:
                log_print('Epoch [%d/%d], Step [%d/%d], Forward Sum Loss: %.5f'
                          % (epoch, epochs, i + 1, len(train_dataloader), loss.item()), logger)
                writer.add_scalar('train/Forward Sum Loss', loss.item(), iters)

            train_losses.append(loss.item())

        accelerator.print('Time elapsed:', time.time() - start_time)

        avg_train_loss = sum(train_losses) / max(len(train_losses), 1)

        aligner.eval()
        val_losses = []

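        # Validation batches carry no attention prior (four elements instead
        # of five), so the aligner is evaluated on its own alignment.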
        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc=f"Epoch {epoch}/{epochs} [Val]"):
                batch = [b.to(device) for b in batch]
                text_input, text_input_length, mel_input, mel_input_length = batch

                attn_soft, attn_logprob = aligner(spec=mel_input,
                                                  spec_len=mel_input_length,
                                                  text=text_input,
                                                  text_len=text_input_length,
                                                  attn_prior=None)

                val_loss = forward_sum_loss(attn_logprob=attn_logprob,
                                            in_lens=text_input_length,
                                            out_lens=mel_input_length)

                val_losses.append(val_loss.item())

        avg_val_loss = sum(val_losses) / max(len(val_losses), 1)

        if accelerator.is_main_process:
            writer.add_scalar('epoch/train_loss', avg_train_loss, epoch)
            writer.add_scalar('epoch/val_loss', avg_val_loss, epoch)

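        # unwrap_model strips the DDP wrapper so the saved state_dict loads
        # into a bare AlignerModel.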
        if epoch % save_every == 0 and accelerator.is_main_process:
            print(f'Saving checkpoint at epoch {epoch} (iteration {iters})...')
            state = {
                'net': accelerator.unwrap_model(aligner).state_dict(),
                'optimizer': optimizer.state_dict(),
                'iters': iters,
                'epoch': epoch,
            }
            os.makedirs(osp.join(log_dir, 'checkpoints'), exist_ok=True)
            save_path = osp.join(log_dir, 'checkpoints', f'TextAligner_checkpoint_epoch_{epoch}.pt')
            torch.save(state, save_path)

            # best_val_loss was otherwise unused; keep a copy of the best
            # model under an assumed filename.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(state, osp.join(log_dir, 'checkpoints', 'TextAligner_best.pt'))

        epoch_time = time.time() - start_time
        accelerator.print(f"Epoch {epoch}/{epochs} completed in {epoch_time:.2f}s | "
                          f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    if accelerator.is_main_process:
        writer.close()


if __name__ == "__main__":
    main()