# cond_gen / MuCodec / model.py
# Uploaded by Leon299 ("Add files using upload-large-folder tool"), commit 8337fa0 (verified).
import yaml
import random
import inspect
import numpy as np
from tqdm import tqdm
import typing as tp
from abc import ABC
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from einops import repeat
from tools.torch_tools import wav_to_fbank
import os
import diffusers
from diffusers.utils.torch_utils import randn_tensor
from diffusers import DDPMScheduler
from models.transformer_2d_flow import Transformer2DModel
from libs.rvq.descript_quantize3 import ResidualVectorQuantize
from torch.cuda.amp import autocast
from muq_dev.test import load_model
class SampleProcessor(torch.nn.Module):
    """Base mapping between sample space and the space diffusion runs in.

    This base class is the identity in both directions; subclasses override
    the two methods to normalise samples before diffusion and to undo that
    normalisation afterwards.
    """

    def project_sample(self, x: torch.Tensor):
        """Map ``x`` from sample space into diffusion space (identity here)."""
        projected = x
        return projected

    def return_sample(self, z: torch.Tensor):
        """Map ``z`` from diffusion space back into sample space (identity here)."""
        restored = z
        return restored
class Feature2DProcessor(SampleProcessor):
    """Running per-(channel, feature) normalisation for 4-D latents.

    Inputs must be 4-D; they are broadcast against statistics of shape
    ``(dim, feat_dim)`` as ``(batch, dim, time, feat_dim)``. Until
    ``num_samples`` batch items have been seen, :meth:`project_sample`
    accumulates the per-item time-average of ``x`` and ``x ** 2``. The derived
    mean/std (pooled over items and time, each item weighted equally) are used
    to whiten samples entering diffusion space and to undo that whitening in
    :meth:`return_sample`.

    Args:
        dim: number of channels (size of axis 1).
        power_std: exponent applied to the std rescaling (``1.`` = full
            standardisation). Also accepts a list with one value per channel,
            or a tensor broadcastable against ``(dim, feat_dim)``.
        num_samples: number of batch items after which the statistics freeze.
        feat_dim: size of the last axis (defaults to 32, the original width).
    """

    def __init__(self, dim: int = 8, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.,
                 num_samples: int = 100_000, feat_dim: int = 32):
        super().__init__()
        self.num_samples = num_samples
        self.dim = dim
        if isinstance(power_std, (list, tuple)):
            # ``tensor ** list`` raises TypeError, so a per-channel list becomes a
            # ``(dim, 1)`` tensor that broadcasts across the feature axis.
            if len(power_std) != dim:
                raise ValueError(
                    f"power_std has {len(power_std)} entries, expected one per channel ({dim})")
            power_std = torch.tensor(power_std, dtype=torch.float32).view(dim, 1)
        self.power_std = power_std
        # Buffer names and shapes are part of the state_dict; keep them stable.
        self.register_buffer('counts', torch.zeros(1))
        self.register_buffer('sum_x', torch.zeros(dim, feat_dim))
        self.register_buffer('sum_x2', torch.zeros(dim, feat_dim))
        # Not read by this class; kept so existing checkpoints still load.
        self.register_buffer('sum_target_x2', torch.zeros(dim, feat_dim))
        self.counts: torch.Tensor
        self.sum_x: torch.Tensor
        self.sum_x2: torch.Tensor

    @property
    def mean(self):
        """Running mean, shape ``(dim, feat_dim)`` (NaN before any update)."""
        return self.sum_x / self.counts

    @property
    def std(self):
        """Running std, shape ``(dim, feat_dim)``; negative variance is clamped to 0."""
        return (self.sum_x2 / self.counts - self.mean ** 2).clamp(min=0).sqrt()

    @property
    def target_std(self):
        """Std that samples are rescaled to in diffusion space."""
        return 1

    def _power_for(self, ref: torch.Tensor):
        """Return ``power_std``, moved to ``ref``'s device when it is a tensor."""
        power = self.power_std
        if isinstance(power, torch.Tensor):
            power = power.to(ref.device)
        return power

    @staticmethod
    def _as_4d(stat: torch.Tensor):
        """Reshape a ``(dim, feat_dim)`` statistic to ``(1, dim, 1, feat_dim)`` for broadcasting."""
        return stat.reshape(1, -1, 1, stat.shape[-1])

    def project_sample(self, x: torch.Tensor):
        """Update the running statistics (if not frozen) and whiten ``x``.

        Args:
            x: 4-D tensor laid out as ``(batch, dim, time, feat_dim)``.
        Returns:
            ``(x - mean) * (target_std / std) ** power_std``, same shape as ``x``.
        """
        assert x.dim() == 4
        if self.counts.item() < self.num_samples:
            self.counts += len(x)
            # Average over time first, then sum over the batch: one entry per item.
            self.sum_x += x.mean(dim=2).sum(dim=0)
            self.sum_x2 += x.pow(2).mean(dim=2).sum(dim=0)
        std = self.std
        rescale = (self.target_std / std.clamp(min=1e-12)) ** self._power_for(std)
        return (x - self._as_4d(self.mean)) * self._as_4d(rescale)

    def return_sample(self, x: torch.Tensor):
        """Invert :meth:`project_sample` using the current statistics (no update).

        Args:
            x: 4-D tensor in diffusion space, ``(batch, dim, time, feat_dim)``.
        Returns:
            ``x * (std / target_std) ** power_std + mean``, same shape as ``x``.
        """
        assert x.dim() == 4
        std = self.std
        rescale = (std / self.target_std) ** self._power_for(std)
        return x * self._as_4d(rescale) + self._as_4d(self.mean)
class BASECFM(torch.nn.Module, ABC):
    """Fixed-step Euler sampler for a conditional flow-matching velocity estimator.

    ``estimator`` is called as ``estimator(inputs, timestep=..., added_cond_kwargs=...)``
    where ``inputs`` is the channel-wise concatenation of the current state,
    the in-context latents and the conditioning ``mu``. It must return an
    object whose ``.sample`` attribute is the predicted velocity ``dphi/dt``.
    """

    def __init__(
        self,
        estimator,
    ):
        super().__init__()
        # Noise floor of the interpolation (1 - (1 - sigma_min) * t) * noise + t * data
        # used to re-inject the in-context prefix at every step.
        self.sigma_min = 1e-4
        self.estimator = estimator

    @torch.inference_mode()
    def forward(self, mu, n_timesteps, temperature=1.0, incontext_x=None, incontext_length=0,
                added_cond_kwargs=None, guidance_scale=1.0):
        """Draw a sample by integrating from Gaussian noise with Euler steps.

        Args:
            mu (torch.Tensor): conditioning passed to the estimator; the initial
                noise is drawn with the same shape.
                shape: (batch_size, n_channels, mel_timesteps, n_feats)
            n_timesteps (int): number of Euler steps (must be >= 1).
            temperature (float, optional): scale of the initial noise. Defaults to 1.0.
            incontext_x (torch.Tensor, optional): known latents for the in-context
                prefix. Defaults to zeros shaped like the noise.
            incontext_length (int, optional): number of leading frames pinned to
                ``incontext_x``. Defaults to 0 (no prefix).
            added_cond_kwargs (dict, optional): extra, batch-sized estimator
                conditioning. Defaults to ``{}``. NOTE(review): the estimator may
                require specific keys (``inference_codes`` passes ``resolution``
                and ``aspect_ratio``); an empty dict only works if it does not.
            guidance_scale (float, optional): classifier-free guidance scale;
                values <= 1.0 disable guidance. Defaults to 1.0.
        Returns:
            torch.Tensor: the state after the last step, shaped like ``mu``.
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        if incontext_x is None:
            incontext_x = torch.zeros_like(z)
        if added_cond_kwargs is None:
            added_cond_kwargs = {}
        return self.solve_euler(z, incontext_x, incontext_length, t_span, mu,
                                added_cond_kwargs, guidance_scale)

    def solve_euler(self, x, incontext_x, incontext_length, t_span, mu, added_cond_kwargs, guidance_scale):
        """Integrate the flow ODE with a fixed explicit Euler scheme.

        Args:
            x (torch.Tensor): initial noise; not modified in place.
            incontext_x (torch.Tensor): known latents. Its first
                ``incontext_length`` frames (axis 2) are re-imposed before each step.
            incontext_length (int or 0-dim tensor): length of the in-context prefix.
            t_span (torch.Tensor): time grid, shape ``(n_timesteps + 1,)``; needs
                at least two entries.
            mu (torch.Tensor): conditioning concatenated to the estimator input.
                shape: (batch_size, n_channels, mel_timesteps, n_feats)
            added_cond_kwargs (dict): batch-sized extra conditioning. Every value
                is duplicated along dim 0 when guidance is active.
            guidance_scale (float): classifier-free guidance scale; > 1.0 enables
                a batched unconditional (``mu`` zeroed) + conditional pass.
        Returns:
            torch.Tensor: the state after the final Euler step, shaped like ``x``.
        """
        t, dt = t_span[0], t_span[1] - t_span[0]
        # Keep the caller's tensor untouched: the prefix is written into a copy.
        noise = x
        x = x.clone()
        for step in tqdm(range(1, len(t_span))):
            # Pin the prefix to the interpolation between the initial noise and
            # the known latents at the current time t.
            x[:, :, 0:incontext_length, :] = (
                (1 - (1 - self.sigma_min) * t) * noise[:, :, 0:incontext_length, :]
                + t * incontext_x[:, :, 0:incontext_length, :]
            )
            if guidance_scale > 1.0:
                dphi_dt = self.estimator(
                    torch.cat([
                        torch.cat([x, x], 0),
                        torch.cat([incontext_x, incontext_x], 0),
                        torch.cat([torch.zeros_like(mu), mu], 0),
                    ], 1),
                    timestep=t.unsqueeze(-1).repeat(2),
                    added_cond_kwargs={k: torch.cat([v, v], 0) for k, v in added_cond_kwargs.items()},
                ).sample
                dphi_dt_uncond, dphi_dt_cond = dphi_dt.chunk(2, 0)
                dphi_dt = dphi_dt_uncond + guidance_scale * (dphi_dt_cond - dphi_dt_uncond)
            else:
                dphi_dt = self.estimator(
                    torch.cat([x, incontext_x, mu], 1),
                    timestep=t.unsqueeze(-1),
                    added_cond_kwargs=added_cond_kwargs,
                ).sample
            x = x + dt * dphi_dt
            t = t + dt
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
        return x
class PromptCondAudioDiffusion(nn.Module):
    """Flow-matching latent generator conditioned on quantised audio-encoder codes.

    Audio is embedded by a frozen encoder loaded with ``load_model`` (the
    ``muq_dev`` checkpoint by default). One hidden layer is quantised with a
    single-codebook ``ResidualVectorQuantize``. The decoded codes, projected by
    ``cond_muencoder_emb``, condition a ``Transformer2DModel`` velocity
    estimator that is sampled through :class:`BASECFM`.
    """

    def __init__(
        self,
        num_channels,
        unet_model_name=None,
        unet_model_config_path=None,
        snr_gamma=None,
        uncondition=True,
        out_paint=False,
        muencoder_dir="muq_dev/muq_fairseq",
        muencoder_ckpt="muq_dev/muq.pt",
    ):
        """Build the encoder, quantiser, conditioning projection and estimator.

        Args:
            num_channels: latent channel count (axis 1 of the latents).
            unet_model_name: stored only; the estimator is always built from
                ``unet_model_config_path``.
            unet_model_config_path: config passed to ``Transformer2DModel.from_config``.
            snr_gamma, uncondition: stored as attributes.
            out_paint: accepted for compatibility; unused here.
            muencoder_dir, muencoder_ckpt: encoder model directory and checkpoint,
                resolved with ``os.path.abspath`` (i.e. relative to the CWD).
        """
        super().__init__()
        assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
        self.unet_model_name = unet_model_name
        self.unet_model_config_path = unet_model_config_path
        self.snr_gamma = snr_gamma
        self.uncondition = uncondition
        self.num_channels = num_channels
        # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
        self.normfeat = Feature2DProcessor(dim=num_channels)
        self.sample_rate = 48000
        # Not used by any method in this class; kept so the module structure
        # (and any state_dict entries) stay unchanged.
        self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
        self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
        self.muencoder = load_model(
            model_dir=os.path.abspath(muencoder_dir),
            checkpoint_dir=os.path.abspath(muencoder_ckpt),
        )
        self.rsq48tomuencoder = torchaudio.transforms.Resample(48000, 24000)
        # The encoder is frozen; it only provides features.
        for v in self.muencoder.parameters():
            v.requires_grad = False
        self.rvq_muencoder_emb = ResidualVectorQuantize(input_dim=1024, n_codebooks=1, codebook_size=16_384, codebook_dim=32, quantizer_dropout=0.0, stale_tolerance=200)
        self.cond_muencoder_emb = nn.Linear(1024, 16 * 32)
        # Learned conditioning used for frames outside ``latent_length``.
        self.zero_cond_embedding1 = nn.Parameter(torch.randn(32 * 32,))
        unet = Transformer2DModel.from_config(
            unet_model_config_path,
        )
        self.set_from = "random"
        self.cfm_wrapper = BASECFM(unet)
        # from_config builds fresh weights; any pretrained weights must be loaded afterwards.
        print("Transformer initialized from config.")

    def compute_snr(self, timesteps):
        """
        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

        NOTE(review): reads ``self.noise_scheduler``, which this class never
        sets; callers must assign a scheduler exposing ``alphas_cumprod`` first.
        """
        alphas_cumprod = self.noise_scheduler.alphas_cumprod
        sqrt_alphas_cumprod = alphas_cumprod ** 0.5
        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
        # Expand the tensors.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
        # Compute SNR.
        snr = (alpha / sigma) ** 2
        return snr

    def preprocess_audio(self, input_audios, threshold=0.9):
        """Scale down each row whose peak |amplitude| exceeds ``threshold`` to peak exactly at it.

        Args:
            input_audios: 2-D tensor; rows are normalised independently
                (presumably ``(channels_or_batch, samples)``).
            threshold: largest allowed peak absolute amplitude per row.
        Returns:
            Tensor of the same shape; rows already within ``threshold`` are unchanged.
        """
        assert len(input_audios.shape) == 2, input_audios.shape
        norm_value = torch.ones_like(input_audios[:, 0])
        max_volume = input_audios.abs().max(dim=-1)[0]
        too_loud = max_volume > threshold
        norm_value[too_loud] = max_volume[too_loud] / threshold
        return input_audios / norm_value.unsqueeze(-1)

    def extract_muencoder_embeds(self, input_audio_0, input_audio_1, layer):
        """Embed the mean of two channels with the frozen encoder and return one hidden layer.

        The 48 kHz mix is resampled to 24 kHz, run with ``features_only=True``,
        and ``layer_results[layer]`` is returned with axes 1 and 2 swapped
        (presumably ``(b, t, 1024) -> (b, 1024, t)``; 1024 matches the RVQ
        ``input_dim``).
        """
        input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
        input_wav_mean = self.muencoder(self.rsq48tomuencoder(input_wav_mean), features_only=True)
        layer_results = input_wav_mean['layer_results']
        muencoder_emb = layer_results[layer]
        muencoder_emb = muencoder_emb.permute(0, 2, 1).contiguous()
        return muencoder_emb

    def init_device_dtype(self, device, dtype):
        """Record the device/dtype used by ``inference_codes`` for the initial noise."""
        self.device = device
        self.dtype = dtype

    def _quantize_stereo(self, input_audio_0, input_audio_1, layer):
        """Shared body of ``fetch_codes``/``fetch_codes_batch``: normalise, embed, quantise."""
        input_audio_0 = self.preprocess_audio(input_audio_0)
        input_audio_1 = self.preprocess_audio(input_audio_1)
        self.muencoder.eval()
        muencoder_emb = self.extract_muencoder_embeds(input_audio_0, input_audio_1, layer).detach()
        self.rvq_muencoder_emb.eval()
        _, codes_muencoder_emb, *_ = self.rvq_muencoder_emb(muencoder_emb)  # b,d,t
        spk_embeds = None
        return [codes_muencoder_emb], [muencoder_emb], spk_embeds

    @torch.no_grad()
    def fetch_codes(self, input_audios, additional_feats, layer):
        """Quantise a single 2-D stereo clip (rows 0 and 1 are the two channels).

        ``additional_feats`` is unused. Returns ``([codes], [embeddings], None)``.
        """
        return self._quantize_stereo(input_audios[[0], :], input_audios[[1], :], layer)

    @torch.no_grad()
    def fetch_codes_batch(self, input_audios, additional_feats, layer):
        """Quantise a 3-D batch laid out as ``(batch, channel, samples)``.

        ``additional_feats`` is unused. Returns ``([codes], [embeddings], None)``.
        """
        return self._quantize_stereo(input_audios[:, 0, :], input_audios[:, 1, :], layer)

    @staticmethod
    def _fold_cond_emb(emb):
        """Fold ``(b, t, 2*16*32)`` conditioning into ``(b, 32, t // 2, 32)``.

        Each pair of consecutive frames is stacked along the channel axis
        (2 frames x 16 channels = 32 channels), halving the frame count.
        Requires an even ``t``.
        """
        b, t = emb.shape[0], emb.shape[1]
        return emb.reshape(b, t // 2, 2 * 16, 32).permute(0, 2, 1, 3).contiguous()

    @torch.no_grad()
    def inference_codes(self, codes, spk_embeds, true_latents, latent_length, incontext_length, additional_feats,
                        guidance_scale=2, num_steps=20,
                        disable_progress=True, scenario='start_seg'):
        """Generate latents from RVQ codes, optionally continuing an in-context prefix.

        Args:
            codes: list whose first entry holds the RVQ codes (as from ``fetch_codes``).
            spk_embeds, additional_feats, disable_progress: unused.
            true_latents: reference latents; with ``scenario == 'other_seg'`` their
                first ``incontext_length`` frames are kept as context.
            latent_length: number of frames that receive code conditioning; later
                frames get ``zero_cond_embedding1``.
            incontext_length: prefix length, only used for ``'other_seg'``.
            guidance_scale: classifier-free guidance scale (> 1.0 enables it).
            num_steps: number of Euler steps.
        Returns:
            Latents in sample space (after ``normfeat.return_sample``).
        """
        device = self.device
        dtype = self.dtype
        codes_muencoder_emb = codes[0]
        batch_size = codes_muencoder_emb.shape[0]

        quantized_muencoder_emb, _, _ = self.rvq_muencoder_emb.from_codes(codes_muencoder_emb)
        quantized_muencoder_emb = self.cond_muencoder_emb(quantized_muencoder_emb.permute(0, 2, 1))  # b t 16*32
        quantized_muencoder_emb = self._fold_cond_emb(quantized_muencoder_emb)  # b 32 t f

        num_frames = quantized_muencoder_emb.shape[-2]
        latents = self.prepare_latents(batch_size, num_frames, self.num_channels, dtype, device)

        bsz, _, height, width = latents.shape
        resolution = torch.tensor([height, width]).repeat(bsz, 1)
        aspect_ratio = torch.tensor([float(height / width)]).repeat(bsz, 1)
        resolution = resolution.to(dtype=quantized_muencoder_emb.dtype, device=device)
        aspect_ratio = aspect_ratio.to(dtype=quantized_muencoder_emb.dtype, device=device)
        # Kept batch-sized: BASECFM.solve_euler itself duplicates every entry when
        # guidance is active, so duplicating here as well produced 4x the batch.
        added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

        # Mask: 0 = beyond latent_length (learned zero conditioning),
        # 1 = in-context prefix (known latents + codes), 2 = codes only.
        latent_masks = torch.zeros(latents.shape[0], latents.shape[2], dtype=torch.int64, device=latents.device)
        latent_masks[:, 0:latent_length] = 2
        if scenario == 'other_seg':
            latent_masks[:, 0:incontext_length] = 1

        conditioned = (latent_masks > 0.5).unsqueeze(1).unsqueeze(-1)
        unconditioned = (latent_masks < 0.5).unsqueeze(1).unsqueeze(-1)
        quantized_muencoder_emb = conditioned * quantized_muencoder_emb \
            + unconditioned * self.zero_cond_embedding1.reshape(1, 32, 1, 32)

        true_latents = self.normfeat.project_sample(true_latents)
        incontext_mask = (latent_masks > 0.5) * (latent_masks < 1.5)
        incontext_latents = true_latents * incontext_mask.unsqueeze(1).unsqueeze(-1).float()
        # Recomputed from the mask: number of mask==1 frames in the first row
        # (0 unless scenario == 'other_seg').
        incontext_length = incontext_mask.sum(-1)[0]

        temperature = 1.0
        t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_muencoder_emb.device)
        latents = self.cfm_wrapper.solve_euler(latents * temperature, incontext_latents, incontext_length, t_span,
                                               quantized_muencoder_emb, added_cond_kwargs, guidance_scale)
        # Restore the exact context frames, which the last Euler step moved.
        latents[:, :, 0:incontext_length, :] = incontext_latents[:, :, 0:incontext_length, :]
        latents = self.normfeat.return_sample(latents)
        return latents

    @torch.no_grad()
    def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
                  disable_progress=True, layer=5, scenario='start_seg', incontext_length=0):
        """Encode ``input_audios`` to codes and decode them with ``inference_codes``.

        ``lyric`` is unused. ``incontext_length`` is forwarded to
        ``inference_codes`` and only matters for ``scenario == 'other_seg'``.
        """
        codes, _, spk_embeds = self.fetch_codes(input_audios, additional_feats, layer)
        latents = self.inference_codes(
            codes, spk_embeds, true_latents, latent_length, incontext_length, additional_feats,
            guidance_scale=guidance_scale, num_steps=num_steps,
            disable_progress=disable_progress, scenario=scenario,
        )
        return latents

    def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
        """Draw Gaussian latents of shape ``(batch, channels, frames, 32)``.

        A frame count that is not a multiple of 4 is rounded to the nearest
        multiple. NOTE(review): rounding down makes the latents shorter than
        the conditioning passed alongside them in ``inference_codes``.
        """
        divisor = 4
        if num_frames % divisor > 0:
            num_frames = round(num_frames / float(divisor)) * divisor
        shape = (batch_size, num_channels_latents, num_frames, 32)
        return randn_tensor(shape, generator=None, device=device, dtype=dtype)