import os
from functools import partial

import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.vocoder.commons.stft_loss import MultiResolutionSTFTLoss
from modules.vocoder.hifigan.hifigan import MultiPeriodDiscriminator, MultiScaleDiscriminator, \
    generator_loss, discriminator_loss
from modules.vocoder.hifigan.mel_utils import mel_spectrogram
from modules.vocoder.univnet.mrd import MultiResolutionDiscriminator
from modules.tts.wavvae.decoder.wavvae_v3 import WavVAE_V3
from tasks.tts.utils.audio import torch_wav2spec
from tasks.tts.utils.audio.align import mel2token_to_dur
from utils.commons.ckpt_utils import load_ckpt
from utils.commons.hparams import hparams

from attrdict import AttrDict
from tasks.tts.dataset_mixin import TTSDatasetMixin
from utils.commons.base_task import BaseTask
from utils.commons.import_utils import import_module_bystr


class WavVAETask(TTSDatasetMixin, BaseTask):
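    """GAN-style training task for the WavVAE_V3 waveform autoencoder.

    The generator reconstructs waveforms through a VAE bottleneck; training
    combines a mel reconstruction loss, a KL term, optional multi-resolution
    STFT losses, and adversarial losses from HiFi-GAN style discriminators
    (multi-period, multi-scale, and optionally multi-resolution).
    """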
    def __init__(self):
        super().__init__()
        self.dataset_cls = import_module_bystr(hparams['dataset_cls'])
        self.val_dataset_cls = import_module_bystr(hparams['val_dataset_cls'])
        self.processer_fn = import_module_bystr(hparams['processer_fn'])
        self.build_fast_dataloader = import_module_bystr(hparams['build_fast_dataloader'])
        self.hparams = hparams
        self.config = AttrDict(hparams)

        # Mel filterbank and STFT settings for waveform -> spectrogram conversion.
        # Note: n_fft is read from `win_size`, i.e. fft_size == win_size in this setup.
        sample_rate = hparams["audio_sample_rate"]
        fft_size = hparams["win_size"]
        win_size = hparams["win_size"]
        hop_size = hparams["hop_size"]
        num_mels = hparams["audio_num_mel_bins"]
        fmin = hparams["fmin"]
        fmax = hparams["fmax"]
        mel_basis = librosa.filters.mel(
            sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        self.torch_wav2spec_ = partial(
            torch_wav2spec, mel_basis=mel_basis, fft_size=fft_size, hop_size=hop_size, win_length=win_size,
        )
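        # `self.torch_wav2spec_` is assumed to map a waveform tensor to a mel
        # spectrogram using the filterbank above; the exact signature is
        # defined by tasks.tts.utils.audio.torch_wav2spec.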

    def build_model(self):
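        """Build the generator and the discriminator ensemble.

        Returns a dict of module groups; the surrounding trainer is assumed
        to optimize everything listed under 'trainable'.
        """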
        self.model_gen = WavVAE_V3(hparams=hparams)

        self.model_disc = torch.nn.ModuleDict()
        self.model_disc['mpd'] = MultiPeriodDiscriminator(hparams['mpd'], use_cond=hparams['use_cond_disc'])
        self.model_disc['msd'] = MultiScaleDiscriminator(use_cond=hparams['use_cond_disc'])
        if hparams['use_mrd']:
            self.model_disc['mrd'] = MultiResolutionDiscriminator(hparams)
        self.stft_loss = MultiResolutionSTFTLoss()

        # Only list the MRD as trainable when it was actually built; indexing
        # self.model_disc['mrd'] unconditionally raises a KeyError when
        # `use_mrd` is disabled.
        trainable = [self.model_gen, self.model_disc['mpd'], self.model_disc['msd']]
        if hparams['use_mrd']:
            trainable.append(self.model_disc['mrd'])
        return {'trainable': trainable, 'others': []}

    def load_model(self):
        if hparams.get('load_ckpt', '') != '':
            # Load pretrained weights into the generator module.
            load_ckpt(self.model_gen, hparams['load_ckpt'], 'model', strict=False)

    def build_optimizer(self):
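        """Create separate AdamW optimizers for generator and discriminators.

        The list order matches `optimizer_idx` in `_training_step`: index 0
        updates the generator, index 1 the discriminators.
        """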
        optimizer_gen = torch.optim.AdamW(self.model_gen.parameters(), lr=hparams['lr'],
                                          betas=(hparams['adam_b1'], hparams['adam_b2']))
        optimizer_disc = torch.optim.AdamW(self.model_disc.parameters(),
                                           lr=hparams.get('disc_lr', hparams['lr']),
                                           betas=(hparams['adam_b1'], hparams['adam_b2']))
        return [optimizer_gen, optimizer_disc]

    def build_scheduler(self, optimizer):
        # No LR schedule: train at a constant learning rate.
        return None

    def _training_step(self, sample, batch_idx, optimizer_idx):
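        """Run one GAN training step for the given optimizer.

        optimizer_idx == 0: generator pass (reconstruction, KL, STFT and
        adversarial losses). optimizer_idx == 1: discriminator pass against
        the generator output cached in `self.y_`.
        """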
        sample['wavs'] = sample['wavs'].float()  # ensure float32 waveforms

        # Housekeeping: periodically kill any "voidgpu<i>" placeholder
        # processes holding the visible GPUs.
        if self.global_step % 100 == 0:
            devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
            for d in devices:
                os.system(f'pkill -f "voidgpu{d}"')

        y = sample['wavs']
        loss_output = {}
        if optimizer_idx == 0:
            # Generator pass: reconstruct the waveform through the VAE.
            y_, posterior = self.model_gen(y)
            y = y.unsqueeze(1)  # add channel dim for the discriminators
            y_mel = mel_spectrogram(y.squeeze(1), hparams).transpose(1, 2)
            y_hat_mel = mel_spectrogram(y_.squeeze(1), hparams).transpose(1, 2)
            loss_output['mel'] = F.l1_loss(y_hat_mel, y_mel) * hparams['lambda_mel']
            if self.training:
                # Adversarial losses from MPD / MSD (feature maps are unused here).
                _, y_p_hat_g, _, _ = self.model_disc['mpd'](y, y_, None)
                _, y_s_hat_g, _, _ = self.model_disc['msd'](y, y_, None)
                loss_output['a_p'] = generator_loss(y_p_hat_g) * hparams['lambda_adv'] * hparams.get('lambda_mpd', 1.0)
                loss_output['a_s'] = generator_loss(y_s_hat_g) * hparams['lambda_adv'] * hparams.get('lambda_msd', 1.0)
                if hparams['use_mrd']:
                    y_r_hat_g = [x[1] for x in self.model_disc['mrd'](y_)]
                    loss_output['a_r'] = generator_loss(y_r_hat_g) \
                        * hparams['lambda_adv'] * hparams.get('lambda_mrd', 1.0)
            if hparams['use_ms_stft']:
                loss_output['sc'], loss_output['mag'] = self.stft_loss(y.squeeze(1), y_.squeeze(1))
            loss_output['kl_loss'] = posterior.kl().mean() * hparams.get('lambda_kl', 1.0)
            # Cache the generator output for the discriminator step.
            self.y_ = y_.detach()
        else:
            # Discriminator pass; skipped at validation time.
            if not self.training:
                return None
            y = y.unsqueeze(1)
            y_ = self.y_

            y_p_hat_r, y_p_hat_g, _, _ = self.model_disc['mpd'](y, y_.detach(), None)
            loss_output['r_p'], loss_output['f_p'] = discriminator_loss(y_p_hat_r, y_p_hat_g)

            y_s_hat_r, y_s_hat_g, _, _ = self.model_disc['msd'](y, y_.detach(), None)
            loss_output['r_s'], loss_output['f_s'] = discriminator_loss(y_s_hat_r, y_s_hat_g)

            if hparams['use_mrd']:
                y_r_hat_r = [x[1] for x in self.model_disc['mrd'](y)]
                y_r_hat_g = [x[1] for x in self.model_disc['mrd'](y_.detach())]
                loss_output['r_r'], loss_output['f_r'] = discriminator_loss(y_r_hat_r, y_r_hat_g)
        total_loss = sum(loss_output.values())
        loss_output['bs'] = sample['wavs'].shape[0]  # batch size, logged only (added after the sum)
        return total_loss, loss_output

    def save_valid_result(self, sample, batch_idx, model_out):
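        """Plot validation spectrograms and log vocoded audio.

        This path relies on `self.vocoder`, `self.predict_f0` and mel-based
        model outputs, which are assumed to be provided by the surrounding
        framework (e.g. the dataset mixin or a parent task).
        """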
        sr = hparams['audio_sample_rate']
        mel_out = model_out.get('mel_out')
        f0_gt = sample.get('f0')
        if f0_gt is not None:
            f0_gt = f0_gt.cpu()[-1]
        if mel_out is not None:
            f0_pred = self.predict_f0(sample['mels'])
            self.plot_mel(batch_idx, sample['mels'], mel_out, f0s={'f0': f0_pred, 'f0g': f0_gt})

        # Log ground-truth audio once, early in training.
        if self.global_step <= hparams['valid_infer_interval']:
            mel_gt = sample['mels'][-1].cpu()
            f0 = self.predict_f0(sample['mels'][-1:])
            wav_gt = self.vocoder.spec2wav(mel_gt, f0=f0)
            self.logger.add_audio(f'wav_gt_{batch_idx}', wav_gt, self.global_step, sr)

        if self.global_step >= 0:
            # Inference with ground-truth durations.
            model_out = self.run_model(sample, infer=True, infer_use_gt_dur=True)
            dur_info = None
            f0 = self.predict_f0(model_out['mel_out'])
            wav_pred = self.vocoder.spec2wav(model_out['mel_out'][-1].cpu(), f0=f0)
            self.logger.add_audio(f'wav_gdur_{batch_idx}', wav_pred, self.global_step, sr)
            self.plot_mel(batch_idx, sample['mels'][-1:], model_out['mel_out'][-1], f'mel_gdur_{batch_idx}',
                          dur_info=dur_info, f0s={'f0': f0, 'f0g': f0_gt})

            # Inference with predicted durations.
            if not hparams['use_gt_dur'] and not hparams['use_gt_latent']:
                model_out = self.run_model(sample, infer=True, infer_use_gt_dur=False)
                dur_info = None
                f0 = self.predict_f0(model_out['mel_out'])
                self.plot_mel(
                    batch_idx, sample['mels'], model_out['mel_out'][-1], f'mel_pdur_{batch_idx}',
                    dur_info=dur_info, f0s={'f0': f0, 'f0g': f0_gt})
                wav_pred = self.vocoder.spec2wav(model_out['mel_out'][-1].cpu(), f0=f0)
                self.logger.add_audio(f'wav_pdur_{batch_idx}', wav_pred, self.global_step, sr)

    def get_plot_dur_info(self, sample, model_out):
        T_txt = sample['txt_tokens'].shape[1]
        dur_gt = mel2token_to_dur(sample['mel2ph'], T_txt)[-1]
        dur_pred = model_out['dur'] if 'dur' in model_out else dur_gt
        txt = self.token_encoder.decode(sample['txt_tokens'][-1].cpu().numpy())
        txt = txt.split(" ")
        return {'dur_gt': dur_gt, 'dur_pred': dur_pred, 'txt': txt}

    def on_before_optimization(self, opt_idx):
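        """Clip gradients before each optimizer step: the generator for
        opt_idx 0, the discriminators otherwise."""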
        if opt_idx == 0:
            nn.utils.clip_grad_norm_(self.model_gen.parameters(), hparams['generator_grad_norm'])
        else:
            nn.utils.clip_grad_norm_(self.model_disc.parameters(), hparams["discriminator_grad_norm"])

    def to(self, device=None, dtype=None):
        super().to(device=device, dtype=dtype)
        # Keep EMA weights on the same device/dtype as the model.
        if hparams.get('use_ema', False):
            self.ema.to(device=device, dtype=dtype)

    def cuda(self, device):
        super().cuda(device)
        if hparams.get('use_ema', False):
            self.ema.to(device=device)

    @torch.no_grad()
    def validation_step(self, sample, batch_idx):
        infer_steps = self.hparams.get('infer_steps', 12)
        outputs = self._validation_step(sample, batch_idx, infer_steps)
        return outputs

    def _validation_step(self, sample, batch_idx, infer_steps):
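        """Validation body, executed on rank 0 only; per-sample inference
        and logging are elided here."""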
        outputs = {}
        if self.trainer.proc_rank == 0:
            pass
        return outputs

    @torch.no_grad()
    def test_step(self, sample, batch_idx):
        infer_steps = hparams['infer_steps']
        return self._validation_step(sample, batch_idx, infer_steps)