| import os |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.tensorboard import SummaryWriter |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import time |
| from tqdm import tqdm |
| from optimizers import build_optimizer |
|
|
def train_aligner(config, accelerator, train_dataloader, val_dataloader, device, log_dir, epochs=100):
    """Train an aligner model with the forward-sum alignment loss.

    Args:
        config: dict of hyperparameters. Keys used here: 'optimizer_params'
            (with 'lr', 'pct_start'), 'grad_clip', 'fwd_sum_loss_weight',
            'save_every', 'plot_every'.
        accelerator: NOTE(review) — accepted but never used in this function;
            confirm whether model/optimizer/dataloaders should be wrapped via
            accelerator.prepare(...) and loss.backward() replaced with
            accelerator.backward(loss).
        train_dataloader: yields (text, text_len, mel, mel_len, attn_prior)
            tensor batches for training.
        val_dataloader: same batch structure, used for validation.
        device: torch device to move the model and batches onto.
        log_dir: root directory for TensorBoard logs, checkpoints and
            attention plots.
        epochs: total number of training epochs (default 100).

    Returns:
        The trained aligner model (state of the final epoch; the best
        checkpoint by validation loss is saved separately to disk).
    """
    # AlignerModel / ForwardSumLoss / plot_attention_matrices are expected to
    # be in scope from elsewhere in the original file (not visible here).
    aligner = AlignerModel().to(device)
    forward_sum_loss = ForwardSumLoss()

    # One-cycle-style schedule sized to the full run (per-step stepping below).
    scheduler_params = {
        "max_lr": float(config['optimizer_params'].get('lr', 5e-4)),
        "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
        "epochs": epochs,
        "steps_per_epoch": len(train_dataloader),
    }
    optimizer, scheduler = build_optimizer(
        {"params": aligner.parameters(), "optimizer_params": {}, "scheduler_params": scheduler_params})

    writer = SummaryWriter(log_dir=log_dir)
    os.makedirs(os.path.join(log_dir, 'checkpoints'), exist_ok=True)

    best_val_loss = float('inf')

    # FIX: this weight was previously read from config but never applied to
    # the loss; it now scales the forward-sum loss in both train and val so
    # the config knob actually has an effect (default 1.0 preserves the old
    # numerical behavior).
    fwd_sum_loss_weight = config.get('fwd_sum_loss_weight', 1.0)

    for epoch in range(1, epochs + 1):
        # ---- training pass ----
        aligner.train()
        train_losses = []
        start_time = time.time()

        pbar = tqdm(train_dataloader, desc=f"Epoch {epoch}/{epochs} [Train]")
        for i, batch in enumerate(pbar):
            batch = [b.to(device) for b in batch]
            text_input, text_input_length, mel_input, mel_input_length, attn_prior = batch

            attn_soft, attn_logprob = aligner(spec=mel_input,
                                              spec_len=mel_input_length,
                                              text=text_input,
                                              text_len=text_input_length,
                                              attn_prior=attn_prior)

            loss = fwd_sum_loss_weight * forward_sum_loss(attn_logprob=attn_logprob,
                                                          in_lens=text_input_length,
                                                          out_lens=mel_input_length)

            optimizer.zero_grad()
            loss.backward()

            grad_norm = nn.utils.clip_grad_norm_(aligner.parameters(), config.get('grad_clip', 5.0))

            optimizer.step()
            if scheduler is not None:
                # Per-batch stepping to match steps_per_epoch in the schedule.
                scheduler.step()

            global_step = (epoch - 1) * len(train_dataloader) + i
            writer.add_scalar('train/total_loss', loss.item(), global_step)
            writer.add_scalar('train/grad_norm', grad_norm, global_step)

            train_losses.append(loss.item())
            pbar.set_description(f"Epoch {epoch}/{epochs} [Train] Loss: {loss.item():.4f}")

        # Guard against an empty dataloader to avoid ZeroDivisionError.
        avg_train_loss = sum(train_losses) / max(len(train_losses), 1)

        # ---- validation pass ----
        aligner.eval()
        val_losses = []

        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc=f"Epoch {epoch}/{epochs} [Val]"):
                batch = [b.to(device) for b in batch]
                text_input, text_input_length, mel_input, mel_input_length, attn_prior = batch

                attn_soft, attn_logprob = aligner(spec=mel_input,
                                                  spec_len=mel_input_length,
                                                  text=text_input,
                                                  text_len=text_input_length,
                                                  attn_prior=attn_prior)

                val_loss = fwd_sum_loss_weight * forward_sum_loss(attn_logprob=attn_logprob,
                                                                  in_lens=text_input_length,
                                                                  out_lens=mel_input_length)

                val_losses.append(val_loss.item())

        avg_val_loss = sum(val_losses) / max(len(val_losses), 1)

        writer.add_scalar('epoch/train_loss', avg_train_loss, epoch)
        writer.add_scalar('epoch/val_loss', avg_val_loss, epoch)

        # ---- checkpointing ----
        # Always track the best model by validation loss.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': aligner.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
            }, os.path.join(log_dir, 'checkpoints', 'best_model.pt'))

        # Periodic snapshot regardless of validation performance.
        if epoch % config.get('save_every', 10) == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': aligner.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
            }, os.path.join(log_dir, 'checkpoints', f'checkpoint_epoch_{epoch}.pt'))

        epoch_time = time.time() - start_time
        print(f"Epoch {epoch}/{epochs} completed in {epoch_time:.2f}s | "
              f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        # Periodically dump attention matrices for qualitative inspection.
        if epoch % config.get('plot_every', 10) == 0:
            plot_attention_matrices(aligner, val_dataloader, device,
                                    os.path.join(log_dir, 'attention_plots', f'epoch_{epoch}'),
                                    num_samples=4)

    writer.close()
    print(f"Training completed. Best validation loss: {best_val_loss:.4f}")
    return aligner
|
|
| |
if __name__ == "__main__":

    def length_to_mask(lengths):
        """Build a boolean padding mask from a 1-D tensor of lengths.

        Returns a (batch, max_len) tensor in which True marks padded
        positions, i.e. indices >= the sequence's length.
        """
        positions = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        return torch.gt(positions + 1, lengths.unsqueeze(1))

    # NOTE(review): `config`, `accelerator`, `train_dataloader`,
    # `val_dataloader`, `device`, `log_dir` and `epochs` must be defined
    # before this call — this file appears to be an excerpt. The original
    # call omitted the required `accelerator` argument (a guaranteed
    # TypeError) and passed the undefined name `epoch` instead of `epochs`;
    # both are fixed here.
    train_aligner(
        config=config,
        accelerator=accelerator,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        device=device,
        log_dir=log_dir,
        epochs=epochs,
    )