"""Prompt-conditioned audio diffusion with a conditional-flow-matching (CFM) sampler.

Pipeline (inference): raw stereo audio -> MuQ encoder embeddings -> residual
vector quantization codes -> conditioning features -> Euler ODE solve through a
Transformer2D estimator -> normalized 2D latents.
"""
import yaml
import random
import inspect
import numpy as np
from tqdm import tqdm
import typing as tp
from abc import ABC

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from einops import repeat
from tools.torch_tools import wav_to_fbank

import os
import diffusers
from diffusers.utils.torch_utils import randn_tensor
from diffusers import DDPMScheduler
from models.transformer_2d_flow import Transformer2DModel
from libs.rvq.descript_quantize3 import ResidualVectorQuantize

from torch.cuda.amp import autocast
from muq_dev.test import load_model


class SampleProcessor(torch.nn.Module):
    """Identity pre/post-processor; subclasses map samples to/from diffusion space."""

    def project_sample(self, x: torch.Tensor):
        """Project the original sample to the 'space' where the diffusion will happen."""
        return x

    def return_sample(self, z: torch.Tensor):
        """Project back from diffusion space to the actual sample space."""
        return z


class Feature2DProcessor(SampleProcessor):
    """Running per-(channel, feature) standardization for 4D latents.

    Accumulates first/second moments over the batch and time axes (up to
    ``num_samples`` observed items) and uses them to whiten latents before
    diffusion and un-whiten them afterwards. The last latent axis is assumed
    to have size 32 (hard-coded in the buffers and ``view`` calls below).

    Args:
        dim: number of channels (axis 1 of the 4D input).
        power_std: exponent applied to the std rescale factor (1.0 = plain
            standardization).
        num_samples: stop updating the running statistics once this many
            samples have been counted.
    """

    def __init__(self, dim: int = 8, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.,
                 num_samples: int = 100_000):
        super().__init__()
        self.num_samples = num_samples
        self.dim = dim
        self.power_std = power_std
        # Buffers so the statistics travel with state_dict / .to(device).
        self.register_buffer('counts', torch.zeros(1))
        self.register_buffer('sum_x', torch.zeros(dim, 32))
        self.register_buffer('sum_x2', torch.zeros(dim, 32))
        self.register_buffer('sum_target_x2', torch.zeros(dim, 32))
        self.counts: torch.Tensor
        self.sum_x: torch.Tensor
        self.sum_x2: torch.Tensor

    @property
    def mean(self):
        # Running mean over all counted samples, shape (dim, 32).
        mean = self.sum_x / self.counts
        return mean

    @property
    def std(self):
        # Running std; clamp(min=0) guards against negative variance from
        # floating-point cancellation.
        std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
        return std

    @property
    def target_std(self):
        # Unit target standard deviation.
        return 1

    def project_sample(self, x: torch.Tensor):
        """Whiten ``x`` (shape (B, dim, T, 32)) using the running statistics.

        NOTE(review): this also UPDATES the running statistics until
        ``num_samples`` is reached, even when called at inference time.
        """
        assert x.dim() == 4
        if self.counts.item() < self.num_samples:
            self.counts += len(x)
            # Average over time (axis 2), sum over batch (axis 0) -> (dim, 32).
            self.sum_x += x.mean(dim=(2,)).sum(dim=0)
            self.sum_x2 += x.pow(2).mean(dim=(2,)).sum(dim=0)
        rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std  # same output size
        x = (x - self.mean.view(1, -1, 1, 32).contiguous()) * rescale.view(1, -1, 1, 32).contiguous()
        return x

    def return_sample(self, x: torch.Tensor):
        """Inverse of :meth:`project_sample`: un-whiten back to latent space."""
        assert x.dim() == 4
        rescale = (self.std / self.target_std) ** self.power_std
        x = x * rescale.view(1, -1, 1, 32).contiguous() + self.mean.view(1, -1, 1, 32).contiguous()
        return x


class BASECFM(torch.nn.Module, ABC):
    """Conditional-flow-matching wrapper around an estimator network.

    Args:
        estimator: network mapping (concatenated latent input, timestep,
            added conditioning) to a velocity field with a ``.sample`` output.
    """

    def __init__(
        self,
        estimator,
    ):
        super().__init__()
        self.sigma_min = 1e-4  # minimum noise level of the CFM interpolation path
        self.estimator = estimator

    @torch.inference_mode()
    def forward(self, mu, n_timesteps, temperature=1.0):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_channels, mel_timesteps, n_feats)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_channels, mel_timesteps, n_feats)

        NOTE(review): this call is stale — ``solve_euler`` now requires
        ``incontext_x``, ``incontext_length``, ``mu``, ``added_cond_kwargs`` and
        ``guidance_scale``, so calling ``forward`` raises TypeError. Kept as-is
        because the intended in-context arguments cannot be inferred here;
        inference goes through ``PromptCondAudioDiffusion.inference_codes``.
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span)

    def solve_euler(self, x, incontext_x, incontext_length, t_span, mu, added_cond_kwargs, guidance_scale):
        """
        Fixed euler solver for ODEs.

        Args:
            x (torch.Tensor): random noise
            incontext_x (torch.Tensor): in-context (known) latents; the first
                ``incontext_length`` time frames are re-noised to the current t
                and clamped into ``x`` at every step (inpainting-style guidance).
            incontext_length (int or 0-d tensor): number of leading frames taken
                from ``incontext_x``.
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_channels, mel_timesteps, n_feats)
            added_cond_kwargs (dict): extra conditioning tensors forwarded to the
                estimator (duplicated along batch for classifier-free guidance).
            guidance_scale (float): classifier-free guidance weight; > 1.0
                enables the two-pass (unconditional/conditional) evaluation.
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        noise = x.clone()

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        for step in tqdm(range(1, len(t_span))):
            # Clamp the in-context prefix to its CFM interpolation at time t:
            # (1 - (1 - sigma_min) t) * noise + t * data.
            x[:, :, 0:incontext_length, :] = (1 - (1 - self.sigma_min) * t) * noise[:, :, 0:incontext_length, :] \
                + t * incontext_x[:, :, 0:incontext_length, :]
            if (guidance_scale > 1.0):
                # Single batched pass: [unconditional (mu zeroed); conditional].
                dphi_dt = self.estimator(
                    torch.cat([
                        torch.cat([x, x], 0),
                        torch.cat([incontext_x, incontext_x], 0),
                        torch.cat([torch.zeros_like(mu), mu], 0),
                    ], 1),
                    timestep=t.unsqueeze(-1).repeat(2),
                    added_cond_kwargs={k: torch.cat([v, v], 0) for k, v in added_cond_kwargs.items()}).sample
                dphi_dt_uncond, dphi_dt_cond = dphi_dt.chunk(2, 0)
                dphi_dt = dphi_dt_uncond + guidance_scale * (dphi_dt_cond - dphi_dt_uncond)
            else:
                dphi_dt = self.estimator(torch.cat([x, incontext_x, mu], 1),
                                         timestep=t.unsqueeze(-1), added_cond_kwargs=added_cond_kwargs).sample
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]


class PromptCondAudioDiffusion(nn.Module):
    """Prompt-conditioned latent diffusion model over 48 kHz audio.

    Wraps a frozen MuQ encoder + RVQ codebook for conditioning and a
    Transformer2D estimator driven by the CFM Euler solver.

    Args:
        num_channels: number of latent channels (axis 1 of the latents).
        unet_model_name: optional pretrained model name (one of this or
            ``unet_model_config_path`` is required).
        unet_model_config_path: config used to instantiate the Transformer2D.
        snr_gamma: Min-SNR weighting gamma (stored; used by training code).
        uncondition: whether unconditional dropout is used (stored only here).
        out_paint: unused flag kept for interface compatibility.
    """

    def __init__(
        self,
        num_channels,
        unet_model_name=None,
        unet_model_config_path=None,
        snr_gamma=None,
        uncondition=True,
        out_paint=False,
    ):
        super().__init__()
        assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"

        self.unet_model_name = unet_model_name
        self.unet_model_config_path = unet_model_config_path
        self.snr_gamma = snr_gamma
        self.uncondition = uncondition
        self.num_channels = num_channels

        # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
        self.normfeat = Feature2DProcessor(dim=num_channels)

        self.sample_rate = 48000
        self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
        self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)

        # Frozen MuQ encoder used only as a feature extractor.
        muencoder_dir = "muq_dev/muq_fairseq"
        muencoder_ckpt = "muq_dev/muq.pt"
        self.muencoder = load_model(
            model_dir=os.path.abspath(muencoder_dir),
            checkpoint_dir=os.path.abspath(muencoder_ckpt),
        )
        self.rsq48tomuencoder = torchaudio.transforms.Resample(48000, 24000)
        for v in self.muencoder.parameters():
            v.requires_grad = False

        # RVQ over 1024-d MuQ embeddings -> discrete conditioning codes.
        self.rvq_muencoder_emb = ResidualVectorQuantize(input_dim=1024, n_codebooks=1, codebook_size=16_384,
                                                        codebook_dim=32, quantizer_dropout=0.0, stale_tolerance=200)
        # Project quantized embeddings to the (16*32) conditioning layout.
        self.cond_muencoder_emb = nn.Linear(1024, 16 * 32)
        # Learned embedding substituted where the conditioning mask is off.
        self.zero_cond_embedding1 = nn.Parameter(torch.randn(32 * 32,))

        unet = Transformer2DModel.from_config(
            unet_model_config_path,
        )
        self.set_from = "random"
        self.cfm_wrapper = BASECFM(unet)
        print("Transformer initialized from pretrain.")

    def compute_snr(self, timesteps):
        """
        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

        NOTE(review): relies on ``self.noise_scheduler`` which is never set in
        this file — presumably assigned by training code; verify before use.
        """
        alphas_cumprod = self.noise_scheduler.alphas_cumprod
        sqrt_alphas_cumprod = alphas_cumprod**0.5
        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

        # Expand the tensors.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

        # Compute SNR.
        snr = (alpha / sigma) ** 2
        return snr

    def preprocess_audio(self, input_audios, threshold=0.9):
        """Peak-normalize (B, T) waveforms whose peak exceeds ``threshold``.

        Tracks below the threshold are left untouched; louder tracks are scaled
        so their peak lands exactly at ``threshold``.
        """
        assert len(input_audios.shape) == 2, input_audios.shape
        norm_value = torch.ones_like(input_audios[:, 0])
        max_volume = input_audios.abs().max(dim=-1)[0]
        norm_value[max_volume > threshold] = max_volume[max_volume > threshold] / threshold
        return input_audios / norm_value.unsqueeze(-1)

    def extract_muencoder_embeds(self, input_audio_0, input_audio_1, layer):
        """Return MuQ embeddings (B, 1024, T') of the mono mix of two tracks.

        The two inputs are averaged (stereo/stem downmix), resampled 48k->24k
        and encoded; ``layer`` selects which transformer layer's features to use.
        """
        input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
        input_wav_mean = self.muencoder(self.rsq48tomuencoder(input_wav_mean), features_only=True)
        layer_results = input_wav_mean['layer_results']
        muencoder_emb = layer_results[layer]
        muencoder_emb = muencoder_emb.permute(0, 2, 1).contiguous()  # (B, T, C) -> (B, C, T)
        return muencoder_emb

    def init_device_dtype(self, device, dtype):
        """Record the device/dtype used by ``inference_codes`` (call before inference)."""
        self.device = device
        self.dtype = dtype

    @torch.no_grad()
    def fetch_codes(self, input_audios, additional_feats, layer):
        """Quantize a single stereo item (rows 0 and 1 of ``input_audios``) to RVQ codes.

        Returns ([codes], [continuous embeddings], None).
        """
        input_audio_0 = input_audios[[0], :]
        input_audio_1 = input_audios[[1], :]
        input_audio_0 = self.preprocess_audio(input_audio_0)
        input_audio_1 = self.preprocess_audio(input_audio_1)

        self.muencoder.eval()
        muencoder_emb = self.extract_muencoder_embeds(input_audio_0, input_audio_1, layer)
        muencoder_emb = muencoder_emb.detach()

        self.rvq_muencoder_emb.eval()
        quantized_muencoder_emb, codes_muencoder_emb, *_ = self.rvq_muencoder_emb(muencoder_emb)

        spk_embeds = None
        return [codes_muencoder_emb], [muencoder_emb], spk_embeds

    @torch.no_grad()
    def fetch_codes_batch(self, input_audios, additional_feats, layer):
        """Batched variant of :meth:`fetch_codes` for (B, 2, T) input audio."""
        input_audio_0 = input_audios[:, 0, :]
        input_audio_1 = input_audios[:, 1, :]
        input_audio_0 = self.preprocess_audio(input_audio_0)
        input_audio_1 = self.preprocess_audio(input_audio_1)

        self.muencoder.eval()
        muencoder_emb = self.extract_muencoder_embeds(input_audio_0, input_audio_1, layer)
        muencoder_emb = muencoder_emb.detach()

        self.rvq_muencoder_emb.eval()
        quantized_muencoder_emb, codes_muencoder_emb, *_ = self.rvq_muencoder_emb(muencoder_emb)  # b,d,t

        spk_embeds = None
        return [codes_muencoder_emb], [muencoder_emb], spk_embeds

    @torch.no_grad()
    def inference_codes(self, codes, spk_embeds, true_latents, latent_length, incontext_length, additional_feats,
                        guidance_scale=2, num_steps=20, disable_progress=True, scenario='start_seg'):
        """Decode RVQ codes to latents with the CFM Euler solver.

        Args:
            codes: list whose first element is the RVQ code tensor (B, n_q, T).
            spk_embeds: unused (kept for interface compatibility).
            true_latents: ground-truth latents supplying the in-context prefix.
            latent_length: number of valid latent frames (rest is masked).
            incontext_length: length of the in-context prefix; only used when
                ``scenario == 'other_seg'``.
            additional_feats: unused here (kept for interface compatibility).
            guidance_scale: classifier-free guidance weight (> 1 enables CFG).
            num_steps: number of Euler steps.
            disable_progress: unused here (kept for interface compatibility).
            scenario: 'start_seg' (no prefix) or 'other_seg' (continuation).

        Requires :meth:`init_device_dtype` to have been called.
        """
        classifier_free_guidance = guidance_scale > 1.0
        device = self.device
        dtype = self.dtype

        codes_muencoder_emb = codes[0]
        batch_size = codes_muencoder_emb.shape[0]

        quantized_muencoder_emb, _, _ = self.rvq_muencoder_emb.from_codes(codes_muencoder_emb)
        quantized_muencoder_emb = self.cond_muencoder_emb(quantized_muencoder_emb.permute(0, 2, 1))  # b t 16*32
        # Fold pairs of time steps into the channel axis: (b, t, 512) ->
        # (b, t//2, 2, 16, 32) -> (b, t//2, 32, 32) -> (b, 32, t//2, 32).
        quantized_muencoder_emb = quantized_muencoder_emb.reshape(
            quantized_muencoder_emb.shape[0], quantized_muencoder_emb.shape[1] // 2, 2, 16, 32
        ).reshape(
            quantized_muencoder_emb.shape[0], quantized_muencoder_emb.shape[1] // 2, 2 * 16, 32
        ).permute(0, 2, 1, 3).contiguous()  # b 32 t f

        num_frames = quantized_muencoder_emb.shape[-2]
        num_channels_latents = self.num_channels
        latents = self.prepare_latents(batch_size, num_frames, num_channels_latents, dtype, device)

        bsz, _, height, width = latents.shape
        resolution = torch.tensor([height, width]).repeat(bsz, 1)
        aspect_ratio = torch.tensor([float(height / width)]).repeat(bsz, 1)
        resolution = resolution.to(dtype=quantized_muencoder_emb.dtype, device=device)
        aspect_ratio = aspect_ratio.to(dtype=quantized_muencoder_emb.dtype, device=device)
        if classifier_free_guidance:
            # Duplicate conditioning for the [uncond; cond] CFG batch.
            resolution = torch.cat([resolution, resolution], 0)
            aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], 0)
        added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

        # Mask per frame: 0 = padding, 1 = in-context prefix, 2 = to generate.
        latent_masks = torch.zeros(latents.shape[0], latents.shape[2], dtype=torch.int64, device=latents.device)
        latent_masks[:, 0:latent_length] = 2
        if (scenario == 'other_seg'):
            latent_masks[:, 0:incontext_length] = 1

        # Masked-out frames get the learned "zero" conditioning embedding.
        quantized_muencoder_emb = (latent_masks > 0.5).unsqueeze(1).unsqueeze(-1) * quantized_muencoder_emb \
            + (latent_masks < 0.5).unsqueeze(1).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1, 32, 1, 32)

        true_latents = self.normfeat.project_sample(true_latents)
        incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(1).unsqueeze(-1).float()
        incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]

        additional_model_input = torch.cat([quantized_muencoder_emb], 1)

        temperature = 1.0
        t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_muencoder_emb.device)
        latents = self.cfm_wrapper.solve_euler(latents * temperature, incontext_latents, incontext_length,
                                               t_span, additional_model_input, added_cond_kwargs, guidance_scale)
        # Re-clamp the prefix with the (noise-free) in-context latents.
        latents[:, :, 0:incontext_length, :] = incontext_latents[:, :, 0:incontext_length, :]
        latents = self.normfeat.return_sample(latents)
        return latents

    @torch.no_grad()
    def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats,
                  guidance_scale=2, num_steps=20, disable_progress=True, layer=5, scenario='start_seg',
                  incontext_length=0):
        """End-to-end inference: audio -> codes -> latents.

        BUGFIX: the original call passed ``additional_feats`` into the
        ``incontext_length`` positional slot of ``inference_codes`` and omitted
        the required ``additional_feats`` argument, raising TypeError on every
        call. ``incontext_length`` (default 0, backward-compatible) is now an
        explicit parameter and forwarded correctly; it only matters when
        ``scenario == 'other_seg'``.
        """
        codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats, layer)

        latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length,
                                       incontext_length, additional_feats,
                                       guidance_scale=guidance_scale, num_steps=num_steps,
                                       disable_progress=disable_progress, scenario=scenario)
        return latents

    def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
        """Draw initial Gaussian latents (B, C, num_frames, 32).

        ``num_frames`` is rounded to a multiple of 4 (model stride requirement)
        when not already divisible.
        """
        divisor = 4
        shape = (batch_size, num_channels_latents, num_frames, 32)
        if (num_frames % divisor > 0):
            num_frames = round(num_frames / float(divisor)) * divisor
            shape = (batch_size, num_channels_latents, num_frames, 32)
        latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
        return latents