from typing import Optional

import torch
from torch import Tensor, nn, no_grad
from torch.nn import functional as F

from .autoencoders import OobleckDecoder, OobleckEncoder
from .transformer import ContinuousTransformer

LRELU_SLOPE = 0.1
padding_mode = "zeros"
sample_eps = 1e-6  # numerical floor on the variance so its log stays finite


def vae_sample(mean: Tensor, scale: Tensor) -> tuple[Tensor, Tensor]:
    """Draw a reparameterized sample from a diagonal Gaussian posterior.

    `scale` is passed through softplus so the standard deviation is strictly
    positive; `sample_eps` keeps the variance bounded away from zero.
    """
    stdev = F.softplus(scale)
    var = stdev * stdev + sample_eps
    logvar = torch.log(var)
    latents = torch.randn_like(mean) * stdev + mean

    # KL divergence to a standard normal prior, summed over the channel
    # dimension and averaged over the batch. The conventional 1/2 factor is
    # omitted here, presumably absorbed into the KL weight of the training loss.
    kl = (mean * mean + var - logvar - 1).sum(1).mean()

    return latents, kl
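
# Illustrative usage (shapes are an assumption based on the default config
# below, where the encoder emits 2 * 64 channels at a 1024x downsampling ratio):
#
#     stats = torch.randn(4, 128, 256)     # encoder output, [B, 2 * latent_dim, T']
#     mean, scale = stats.chunk(2, dim=1)  # each [B, 64, T']
#     z, kl = vae_sample(mean, scale)      # z: [B, 64, T'], kl: scalar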


class EAR_VAE(nn.Module):

    def __init__(self, model_config: Optional[dict] = None):
        super().__init__()

        # Default configuration: a stereo Oobleck autoencoder with a 64-dim
        # latent. The encoder's latent_dim is doubled (128) because its output
        # is chunked into mean and scale halves before sampling.
        if model_config is None:
            model_config = {
                "encoder": {
                    "config": {
                        "in_channels": 2,
                        "channels": 128,
                        "c_mults": [1, 2, 4, 8, 16],
                        "strides": [2, 4, 4, 4, 8],
                        "latent_dim": 128,
                        "use_snake": True,
                    }
                },
                "decoder": {
                    "config": {
                        "out_channels": 2,
                        "channels": 128,
                        "c_mults": [1, 2, 4, 8, 16],
                        "strides": [2, 4, 4, 4, 8],
                        "latent_dim": 64,
                        "use_nearest_upsample": False,
                        "use_snake": True,
                        "final_tanh": False,
                    },
                },
                "latent_dim": 64,
                "downsampling_ratio": 1024,  # prod(strides) = 2 * 4 * 4 * 4 * 8
                "io_channels": 2,
            }

        # Optional transformer applied in latent space before decoding.
        if model_config.get("transformer") is not None:
            self.transformers = ContinuousTransformer(
                dim=model_config["decoder"]["config"]["latent_dim"],
                depth=model_config["transformer"]["depth"],
                **model_config["transformer"].get("config", {}),
            )
        else:
            self.transformers = None

        self.encoder = OobleckEncoder(**model_config["encoder"]["config"])
        self.decoder = OobleckDecoder(**model_config["decoder"]["config"])
    def forward(self, audio: Tensor) -> tuple[Tensor, Tensor]:
        """Encode, sample, and reconstruct.

        audio: input audio tensor of shape [B, C, T].
        Returns the reconstruction and the KL term of the sampled posterior.
        """
        stats = self.encoder(audio)
        mean, scale = stats.chunk(2, dim=1)
        z, kl = vae_sample(mean, scale)

        if self.transformers is not None:
            # ContinuousTransformer expects [B, T, C]; the latents are [B, C, T].
            z = z.permute(0, 2, 1)
            z = self.transformers(z)
            z = z.permute(0, 2, 1)

        x = self.decoder(z)
        return x, kl

    def encode(self, audio: Tensor, use_sample: bool = True) -> Tensor:
        """Map audio to latents; with use_sample=False, return the posterior mean."""
        stats = self.encoder(audio)
        mean, scale = stats.chunk(2, dim=1)
        if use_sample:
            z, _ = vae_sample(mean, scale)
        else:
            z = mean
        return z

    def decode(self, z: Tensor) -> Tensor:
        """Map latents [B, latent_dim, T'] back to audio [B, C, T]."""
        if self.transformers is not None:
            z = z.permute(0, 2, 1)
            z = self.transformers(z)
            z = z.permute(0, 2, 1)

        x = self.decoder(z)
        return x

    @no_grad()
    def inference(self, audio: Tensor) -> Tensor:
        """Gradient-free round trip: encode (with sampling), then decode."""
        z = self.encode(audio)
        recon_audio = self.decode(z)
        return recon_audio
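

# Minimal smoke test (a sketch, not part of the original module). It assumes
# OobleckEncoder/OobleckDecoder accept the default config above and that an
# input whose length is a multiple of the 1024x downsampling ratio
# reconstructs to its own shape.
if __name__ == "__main__":
    model = EAR_VAE().eval()
    audio = torch.randn(1, 2, 4096)  # [B, C, T] with T divisible by 1024
    recon = model.inference(audio)
    print(recon.shape)  # expected: torch.Size([1, 2, 4096])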