| import os |
| |
| import random |
| import hydra |
| import numpy as np |
| import librosa |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| import pytorch_lightning as pl |
| from vq import CodecEncoder, CodecDecoderVocos |
| from module import HiFiGANMultiPeriodDiscriminator, SpecDiscriminator |
| from criterions import GANLoss, MultiResolutionMelSpectrogramLoss, MultiResolutionSTFTLoss |
| from common.schedulers import WarmupLR |
| from transformers import AutoModel |
| from vq.module import SemanticDecoder,SemanticEncoder |
| from transformers import AutoFeatureExtractor, Wav2Vec2BertModel |
| import sys |
| |
| |
| |
|
|
|
|
| from transformers import AutoModel, AutoFeatureExtractor |
|
|
|
|
class CodecLightningModule(pl.LightningModule):
    """Neural-codec training module: GAN vocoder + VQ + semantic distillation.

    Uses manual optimization (two optimizers, stepped by hand in
    ``training_step``): one for the generator side, one for the
    discriminators.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: hydra config providing ``model``, ``train``, ``preprocess``
                and ``dataset`` sections (consumed by ``construct_model`` /
                ``construct_criteria`` and the step methods).
        """
        super().__init__()
        self.cfg = cfg
        # Hydra changes the working directory per run; remember the original.
        self.ocwd = hydra.utils.get_original_cwd()
        self.construct_model()
        self.construct_criteria()
        self.save_hyperparameters()
        # GAN training: discriminator and generator are updated separately,
        # so Lightning's automatic optimization is disabled.
        self.automatic_optimization = False
| def construct_model(self): |
| |
| |
| enccfg = self.cfg.model.codec_encoder |
|
|
| |
| self.CodecEnc = CodecEncoder( |
| |
| ngf=enccfg.ngf, |
| up_ratios=enccfg.up_ratios, |
| dilations=enccfg.dilations, |
| hidden_dim=enccfg['hidden_dim'], |
| depth=enccfg['depth'], |
| heads=enccfg['heads'], |
| pos_meb_dim=enccfg['pos_meb_dim'], |
| ) |
|
|
| |
| deccfg = self.cfg.model.codec_decoder |
|
|
| self.generator = CodecDecoderVocos( |
| hidden_dim=deccfg.hidden_dim, |
| depth=deccfg.depth, |
| heads=deccfg.heads, |
| pos_meb_dim=deccfg.pos_meb_dim, |
| hop_length=960, |
| vq_num_quantizers=deccfg.vq_num_quantizers, |
| vq_dim=deccfg.vq_dim, |
| vq_commit_weight=deccfg.vq_commit_weight, |
| vq_weight_init=deccfg.vq_weight_init, |
| vq_full_commit_loss=deccfg.vq_full_commit_loss, |
| codebook_size=deccfg.codebook_size, |
| codebook_dim=deccfg.codebook_dim , |
| |
| ) |
| |
| |
|
|
| |
| mpdcfg = self.cfg.model.mpd |
| self.discriminator = HiFiGANMultiPeriodDiscriminator( |
| periods=mpdcfg.periods, |
| max_downsample_channels=mpdcfg.max_downsample_channels, |
| channels=mpdcfg.channels, |
| channel_increasing_factor=mpdcfg.channel_increasing_factor, |
| ) |
|
|
| |
| mstftcfg = self.cfg.model.mstft |
| self.spec_discriminator = SpecDiscriminator( |
| stft_params=mstftcfg.stft_params, |
| in_channels=mstftcfg.in_channels, |
| out_channels=mstftcfg.out_channels, |
| kernel_sizes=mstftcfg.kernel_sizes, |
| channels=mstftcfg.channels, |
| max_downsample_channels=mstftcfg.max_downsample_channels, |
| downsample_scales=mstftcfg.downsample_scales, |
| use_weight_norm=mstftcfg.use_weight_norm, |
| ) |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| self.speaker_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus-sv") |
| self.speaker_feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sv") |
| self.speaker_model.eval() |
| self.speaker_model.requires_grad_(False) |
|
|
| self.fc_prior = nn.Linear(1024 + 1024, deccfg.vq_dim, ) |
| self.fc_post_a = nn.Linear(deccfg.vq_dim, deccfg.hidden_dim ) |
| self.fc_post_s = nn.Linear(deccfg.vq_dim, 1024) |
|
|
| self.SemanticDecoder_module = SemanticDecoder(1024, 1024, 1024) |
| self.SemanticEncoder_module = SemanticEncoder(1024, 1024, 1024) |
| self.semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", output_hidden_states=True) |
| self.semantic_model.eval() |
| self.semantic_model.requires_grad_(False) |
| |
|
|
| |
| |
| |
|
|
| def construct_criteria(self): |
| cfg = self.cfg.train |
| self.criteria = nn.ModuleDict() |
| if cfg.use_mel_loss: |
| self.criteria['mel_loss'] = MultiResolutionMelSpectrogramLoss(sample_rate=self.cfg.preprocess.audio.sr) |
| if cfg.use_stft_loss: |
| self.criteria['stft_loss'] = MultiResolutionSTFTLoss( |
| fft_sizes=cfg.stft_loss_params.fft_sizes, |
| hop_sizes=cfg.stft_loss_params.hop_sizes, |
| win_sizes=cfg.stft_loss_params.win_lengths |
| ) |
| if cfg.use_feat_match_loss: |
| self.criteria['fm_loss'] = nn.L1Loss() |
| self.criteria['gan_loss'] = GANLoss() |
| self.criteria['l1_loss'] = nn.L1Loss() |
| self.criteria['l2_loss'] = nn.MSELoss() |
| print(self.criteria) |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
    def forward(self, batch):
        """Full training forward pass.

        Encodes the waveform, fuses the encoder output with features from
        the frozen w2v-BERT semantic teacher, vector-quantizes the fused
        representation, then decodes both an audio reconstruction and a
        semantic reconstruction (for the distillation loss).

        Args:
            batch: dict with
                'wav'  : waveform tensor — assumed (B, T_wav); confirm against dataloader
                'feats': w2v-BERT input features for the semantic teacher

        Returns:
            dict with ground-truth / generated audio, VQ code indices, VQ
            losses, and the semantic reconstruction MSE.
        """
        wav = batch['wav']
        feats = batch['feats']
        # (B, 1, T_wav) -> encoder embedding; shape[1] is indexed as time
        # below, so the encoder output is treated as (B, T_codec, C).
        vq_emb = self.CodecEnc(wav.unsqueeze(1))

        # Frozen teacher: layer-16 hidden states are the distillation target.
        with torch.no_grad():
            semantic_target = self.semantic_model(feats)
            semantic_target = semantic_target.hidden_states[16].detach()

        T_codec = vq_emb.shape[1]
        T_semantic = semantic_target.shape[1]

        # Untouched copy at the teacher's native frame rate for the loss.
        semantic_target_for_loss = semantic_target.clone()

        # Align the teacher sequence length to the codec frame rate.
        if T_codec != T_semantic:
            semantic_target = F.interpolate(
                semantic_target.transpose(1, 2),
                size=T_codec,
                mode='linear',
                align_corners=False
            ).transpose(1, 2)

        # SemanticEncoder_module operates channel-first: (B, C, T) in/out.
        semantic_target_transposed = semantic_target.transpose(1, 2)
        semantic_target_processed = self.SemanticEncoder_module(semantic_target_transposed)
        semantic_target_processed = semantic_target_processed.transpose(1, 2)

        # Fuse semantic + acoustic streams along channels, project to VQ dim.
        vq_emb = torch.cat([semantic_target_processed, vq_emb], dim=2)
        vq_emb = self.fc_prior(vq_emb)

        # Quantize; the generator consumes channel-first input when vq=True.
        vq_emb = vq_emb.transpose(1, 2)
        vq_post_emb, vq_code, vq_loss = self.generator(vq_emb, vq=True)

        vq_post_emb_t = vq_post_emb.transpose(1, 2)

        # Distillation branch: project quantized features back to the
        # semantic space and decode them.
        semantic_recon = self.fc_post_s(vq_post_emb_t)

        semantic_recon_transposed = semantic_recon.transpose(1, 2)
        semantic_recon = self.SemanticDecoder_module(semantic_recon_transposed)
        semantic_recon = semantic_recon.transpose(1, 2)

        # Map the reconstruction back to the teacher's frame rate for MSE.
        if T_codec != T_semantic:
            semantic_recon_for_loss = F.interpolate(
                semantic_recon.transpose(1, 2),
                size=T_semantic,
                mode='linear',
                align_corners=False
            ).transpose(1, 2)
        else:
            semantic_recon_for_loss = semantic_recon

        # Acoustic branch: project then vocode (vq=False returns a tuple).
        gen_input = self.fc_post_a(vq_post_emb_t)
        y_, _ = self.generator(gen_input.transpose(1, 2), vq=False)
        y = wav.unsqueeze(1)

        output = {
            'gt_wav': y,
            'gen_wav': y_,
            'vq_loss': vq_loss,
            'vq_code': vq_code,
            'semantic_recon_loss': F.mse_loss(semantic_recon_for_loss, semantic_target_for_loss),
        }
        return output
| @torch.inference_mode() |
| def inference(self, wav): |
| vq_emb = self.CodecEnc(wav.unsqueeze(1)) |
| vq_post_emb, vq_code, vq_loss = self.generator(vq_emb, vq=True) |
| y_ = self.generator(vq_post_emb, vq=False).squeeze(1) |
| return y_ |
|
|
| def compute_disc_loss(self, batch, output): |
| y, y_ = output['gt_wav'], output['gen_wav'] |
| y_ = y_.detach() |
| p = self.discriminator(y) |
| p_ = self.discriminator(y_) |
|
|
| real_loss_list, fake_loss_list = [], [] |
| for i in range(len(p)): |
| real_loss, fake_loss = self.criteria['gan_loss'].disc_loss(p[i][-1], p_[i][-1]) |
| real_loss_list.append(real_loss) |
| fake_loss_list.append(fake_loss) |
|
|
| if hasattr(self, 'spec_discriminator'): |
| sd_p = self.spec_discriminator(y) |
| sd_p_ = self.spec_discriminator(y_) |
|
|
| for i in range(len(sd_p)): |
| real_loss, fake_loss = self.criteria['gan_loss'].disc_loss(sd_p[i][-1], sd_p_[i][-1]) |
| real_loss_list.append(real_loss) |
| fake_loss_list.append(fake_loss) |
|
|
| real_loss = sum(real_loss_list) |
| fake_loss = sum(fake_loss_list) |
|
|
| disc_loss = real_loss + fake_loss |
| disc_loss = self.cfg.train.lambdas.lambda_disc * disc_loss |
|
|
| output = { |
| 'real_loss': real_loss, |
| 'fake_loss': fake_loss, |
| 'disc_loss': disc_loss, |
| } |
| return output |
|
|
| def compute_gen_loss(self, batch, output): |
| y, y_ = output['gt_wav'], output['gen_wav'] |
| vq_loss, vq_code = output['vq_loss'], output['vq_code'] |
| semantic_recon_loss = output['semantic_recon_loss'] |
| |
| |
| gen_loss = 0.0 |
| self.set_discriminator_gradients(False) |
| output_dict = {} |
| cfg = self.cfg.train |
|
|
| |
| if cfg.use_mel_loss: |
| mel_loss = self.criteria['mel_loss'](y_.squeeze(1), y.squeeze(1)) |
| gen_loss += mel_loss * cfg.lambdas.lambda_mel_loss |
| output_dict['mel_loss'] = mel_loss |
|
|
| |
| p_ = self.discriminator(y_) |
| adv_loss_list = [] |
| for i in range(len(p_)): |
| adv_loss_list.append(self.criteria['gan_loss'].gen_loss(p_[i][-1])) |
| if hasattr(self, 'spec_discriminator'): |
| sd_p_ = self.spec_discriminator(y_) |
| for i in range(len(sd_p_)): |
| adv_loss_list.append(self.criteria['gan_loss'].gen_loss(sd_p_[i][-1])) |
| adv_loss = sum(adv_loss_list) |
| gen_loss += adv_loss * cfg.lambdas.lambda_adv |
| output_dict['adv_loss'] = adv_loss |
|
|
| |
| if cfg.use_feat_match_loss: |
| fm_loss = 0.0 |
| with torch.no_grad(): |
| p = self.discriminator(y) |
| for i in range(len(p_)): |
| for j in range(len(p_[i]) - 1): |
| fm_loss += self.criteria['fm_loss'](p_[i][j], p[i][j].detach()) |
| gen_loss += fm_loss * cfg.lambdas.lambda_feat_match_loss |
| output_dict['fm_loss'] = fm_loss |
| if hasattr(self, 'spec_discriminator'): |
| spec_fm_loss = 0.0 |
| with torch.no_grad(): |
| sd_p = self.spec_discriminator(y) |
| for i in range(len(sd_p_)): |
| for j in range(len(sd_p_[i]) - 1): |
| spec_fm_loss += self.criteria['fm_loss'](sd_p_[i][j], sd_p[i][j].detach()) |
| gen_loss += spec_fm_loss * cfg.lambdas.lambda_feat_match_loss |
| output_dict['spec_fm_loss'] = spec_fm_loss |
|
|
| |
| if vq_loss is not None: |
| vq_loss = sum(vq_loss) |
| gen_loss += vq_loss |
| output_dict['vq_loss'] = vq_loss |
|
|
| |
| output_dict['semantic_recon_loss'] = semantic_recon_loss |
| gen_loss += output_dict['semantic_recon_loss'] * cfg.lambdas.lambda_semantic_loss |
|
|
| |
| |
| |
| |
| self.set_discriminator_gradients(True) |
| output_dict['gen_loss'] = gen_loss |
| return output_dict |
|
|
| def training_step(self, batch, batch_idx): |
| output = self(batch) |
|
|
| gen_opt, disc_opt = self.optimizers() |
| gen_sche, disc_sche = self.lr_schedulers() |
|
|
| |
| disc_losses = self.compute_disc_loss(batch, output) |
| disc_loss = disc_losses['disc_loss'] |
| disc_opt.zero_grad() |
| self.manual_backward(disc_loss) |
| self.clip_gradients( |
| disc_opt, |
| gradient_clip_val=self.cfg.train.disc_grad_clip, |
| gradient_clip_algorithm='norm' |
| ) |
| disc_opt.step() |
| disc_sche.step() |
|
|
| |
| gen_losses = self.compute_gen_loss(batch, output) |
| gen_loss = gen_losses['gen_loss'] |
| gen_opt.zero_grad() |
| self.manual_backward(gen_loss) |
| self.clip_gradients( |
| gen_opt, |
| gradient_clip_val=self.cfg.train.gen_grad_clip, |
| gradient_clip_algorithm='norm' |
| ) |
| gen_opt.step() |
| gen_sche.step() |
|
|
| |
| self.log_dict( |
| disc_losses, |
| on_step=True, |
| on_epoch=True, |
| prog_bar=True, |
| logger=True, |
| batch_size=self.cfg.dataset.train.batch_size, |
| sync_dist=True |
| ) |
| self.log_dict( |
| gen_losses, |
| on_step=True, |
| on_epoch=True, |
| prog_bar=True, |
| logger=True, |
| batch_size=self.cfg.dataset.train.batch_size, |
| sync_dist=True |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| def validation_step(self, batch, batch_idx): |
| output = self(batch) |
| y = output['gt_wav'] |
| y_ = output['gen_wav'] |
| |
| |
| y_audio = y.squeeze(1).cpu().numpy() |
| y_recon_audio = y_.squeeze(1).cpu().numpy() |
| |
| embeddings1_list = [] |
| embeddings2_list = [] |
| |
| |
| for i in range(y_audio.shape[0]): |
| |
| y_16k = librosa.resample(y_audio[i], orig_sr=self.cfg.preprocess.audio.sr, target_sr=16000) |
| y_recon_16k = librosa.resample(y_recon_audio[i], orig_sr=self.cfg.preprocess.audio.sr, target_sr=16000) |
| |
| |
| inputs1 = self.speaker_feature_extractor( |
| y_16k, |
| sampling_rate=16000, |
| return_tensors="pt" |
| ).to(self.device) |
| |
| inputs2 = self.speaker_feature_extractor( |
| y_recon_16k, |
| sampling_rate=16000, |
| return_tensors="pt" |
| ).to(self.device) |
| |
| |
| with torch.no_grad(): |
| outputs1 = self.speaker_model(**inputs1) |
| outputs2 = self.speaker_model(**inputs2) |
| |
| |
| embedding1 = torch.mean(outputs1.last_hidden_state, dim=1) |
| embedding2 = torch.mean(outputs2.last_hidden_state, dim=1) |
| |
| |
| embedding1 = F.normalize(embedding1, p=2, dim=1) |
| embedding2 = F.normalize(embedding2, p=2, dim=1) |
| |
| embeddings1_list.append(embedding1) |
| embeddings2_list.append(embedding2) |
| |
| |
| embeddings1 = torch.cat(embeddings1_list, dim=0) |
| embeddings2 = torch.cat(embeddings2_list, dim=0) |
| |
| |
| sim = F.cosine_similarity(embeddings1, embeddings2) |
| sim = sim.mean() |
| |
| self.log('val/sim', sim, on_step=False, on_epoch=True, prog_bar=True, logger=True) |
| |
| return {'sim': sim} |
|
|
| |
|
|
| def test_step(self, batch, batch_idx): |
| |
| pass |
|
|
| def configure_optimizers(self): |
| from itertools import chain |
|
|
| |
| disc_params = self.discriminator.parameters() |
| |
| disc_params = chain(disc_params, self.spec_discriminator.parameters()) |
|
|
| |
| gen_params = chain( |
| self.CodecEnc.parameters(), |
| self.generator.parameters(), |
| |
| self.fc_prior.parameters(), |
| self.fc_post_a.parameters(), |
| self.fc_post_s.parameters(), |
| self.SemanticDecoder_module.parameters(), |
| self.SemanticEncoder_module.parameters() |
| ) |
|
|
| |
| gen_opt = optim.AdamW(gen_params, **self.cfg.train.gen_optim_params) |
| disc_opt = optim.AdamW(disc_params, **self.cfg.train.disc_optim_params) |
|
|
| |
| gen_sche = WarmupLR(gen_opt, **self.cfg.train.gen_schedule_params) |
| disc_sche = WarmupLR(disc_opt, **self.cfg.train.disc_schedule_params) |
|
|
| print(f'Generator optim: {gen_opt}') |
| print(f'Discriminator optim: {disc_opt}') |
|
|
| return [gen_opt, disc_opt], [gen_sche, disc_sche] |
|
|
| def set_discriminator_gradients(self, flag=True): |
| for p in self.discriminator.parameters(): |
| p.requires_grad = flag |
|
|
| if hasattr(self, 'spec_discriminator'): |
| for p in self.spec_discriminator.parameters(): |
| p.requires_grad = flag |
|
|