|
|
| import soundfile as sf |
| import os |
| from librosa.filters import mel as librosa_mel_fn |
| import sys |
| sys.path.append(os.path.join(os.path.dirname(__file__), "..")) |
| import tools.torch_tools as torch_tools |
| import torch.nn as nn |
| import torch |
| import numpy as np |
| from einops import rearrange |
| from scipy.signal import get_window |
| from librosa.util import pad_center, tiny |
| import librosa.util as librosa_util |
|
|
class AttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    The instance ``__dict__`` is aliased to the mapping itself, so
    ``obj.key`` and ``obj["key"]`` address the same storage.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attribute storage *is* the dict contents.
        self.__dict__ = self
|
|
def init_weights(m, mean=0.0, std=0.01):
    """Redraw ``m.weight`` from N(mean, std) when ``m`` is convolution-like.

    Intended for ``nn.Module.apply``. A module counts as convolution-like
    when its class name contains "Conv". Every other module is left
    untouched.
    """
    is_conv = "Conv" in m.__class__.__name__
    if not is_conv:
        return
    m.weight.data.normal_(mean, std)
|
|
|
|
def get_padding(kernel_size, dilation=1):
    """Return the per-side padding for a stride-1 dilated convolution.

    A dilated kernel spans ``dilation * (kernel_size - 1)`` extra samples.
    Padding half of that on each side keeps the output length equal to
    the input length when ``kernel_size`` is odd. The half is truncated
    toward zero.
    """
    span = kernel_size * dilation - dilation
    return int(span / 2)
|
|
# Negative slope for the leaky ReLUs inside ResBlock and Generator_old's
# upsampling loop. Generator_old's final activation before conv_post uses
# torch's default slope instead.
LRELU_SLOPE = 0.1
|
|
class ResBlock(torch.nn.Module):
    """HiFi-GAN style residual block of weight-normalised 1-D convolutions.

    Branch ``i`` runs: leaky ReLU -> dilated conv (``dilation[i]``) ->
    leaky ReLU -> undilated conv. The branch result is added back onto its
    input, and the three branches are applied one after another. Every
    convolution keeps the channel count, and for odd kernels the length.

    Args:
        h: hyper-parameter object; stored on the instance, not read here.
        channels: number of input (and output) channels.
        kernel_size: kernel size shared by all six convolutions.
        dilation: its first three entries are the dilations of ``convs1``.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock, self).__init__()
        self.h = h

        def wn_conv(dil):
            # Weight-normalised, stride-1 conv with length-preserving padding.
            return torch.nn.utils.weight_norm(
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dil,
                    padding=get_padding(kernel_size, dil),
                )
            )

        first_dilations = (dilation[0], dilation[1], dilation[2])
        self.convs1 = nn.ModuleList([wn_conv(dil) for dil in first_dilations])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([wn_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x):
        lrelu = torch.nn.functional.leaky_relu
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            residual = plain_conv(
                lrelu(dilated_conv(lrelu(x, LRELU_SLOPE)), LRELU_SLOPE)
            )
            x = residual + x
        return x

    def remove_weight_norm(self):
        """Fold weight normalisation back into plain conv weights."""
        for conv in [*self.convs1, *self.convs2]:
            torch.nn.utils.remove_weight_norm(conv)
|
|
|
|
class Generator_old(torch.nn.Module):
    """HiFi-GAN generator: mel spectrogram (B, num_mels, T) -> waveform.

    ``h`` must provide ``num_mels``, ``upsample_rates``,
    ``upsample_kernel_sizes``, ``upsample_initial_channel``,
    ``resblock_kernel_sizes`` and ``resblock_dilation_sizes``.

    Each upsampling stage halves the channel count. It is followed by
    ``len(h.resblock_kernel_sizes)`` parallel ResBlocks whose outputs are
    averaged. The output passes through ``tanh``, so it lies in [-1, 1].
    """

    def __init__(self, h):
        super(Generator_old, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        base_ch = h.upsample_initial_channel
        wn = torch.nn.utils.weight_norm

        self.conv_pre = wn(nn.Conv1d(h.num_mels, base_ch, 7, 1, padding=3))

        self.ups = nn.ModuleList()
        rate_kernel_pairs = zip(h.upsample_rates, h.upsample_kernel_sizes)
        for stage, (rate, ksize) in enumerate(rate_kernel_pairs):
            transposed = nn.ConvTranspose1d(
                base_ch // (2**stage),
                base_ch // (2 ** (stage + 1)),
                ksize,
                rate,
                padding=(ksize - rate) // 2,
            )
            self.ups.append(wn(transposed))

        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = base_ch // (2 ** (stage + 1))
            for ksize, dil in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
                self.resblocks.append(ResBlock(h, ch, ksize, dil))

        # `ch` holds the channel count of the last upsampling stage.
        self.conv_post = wn(nn.Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        act = torch.nn.functional.leaky_relu
        x = self.conv_pre(x)
        for stage in range(self.num_upsamples):
            x = self.ups[stage](act(x, LRELU_SLOPE))
            first = stage * self.num_kernels
            summed = None
            for offset in range(self.num_kernels):
                branch = self.resblocks[first + offset](x)
                summed = branch if summed is None else summed + branch
            x = summed / self.num_kernels
        # This activation uses torch's default negative slope, not LRELU_SLOPE.
        x = act(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Fold weight normalisation into plain weights (for inference)."""
        for upsample in self.ups:
            torch.nn.utils.remove_weight_norm(upsample)
        for block in self.resblocks:
            block.remove_weight_norm()
        torch.nn.utils.remove_weight_norm(self.conv_pre)
        torch.nn.utils.remove_weight_norm(self.conv_post)
|
|
|
|
|
|
def nonlinearity(x):
    """Swish / SiLU activation, computed as ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
|
|
|
|
def Normalize(in_channels, num_groups=32):
    """Build an affine GroupNorm (eps=1e-6) over ``in_channels`` channels.

    GroupNorm itself rejects ``in_channels`` values that are not divisible
    by ``num_groups``.
    """
    return torch.nn.GroupNorm(num_groups, in_channels, eps=1e-6, affine=True)
|
|
class Downsample(nn.Module):
    """Downsample the last two dims of a 4-D tensor by 2.

    With ``with_conv``, one zero row/column is padded at the bottom/right,
    then a learned 3x3 stride-2 convolution is applied. Without it, 2x2
    average pooling is used.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # Padding is applied manually (asymmetrically) in forward().
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # (W-left, W-right, H-top, H-bottom) = (0, 1, 0, 1)
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
|
|
|
class DownsampleTimeStride4(nn.Module):
    """Downsample dim 2 by 4 and dim 3 by 2 of a 4-D tensor.

    Per the class name, dim 2 is presumably the time axis (TODO confirm).

    With ``with_conv``, the input gets one zero row/column at the
    bottom/right, then a 5x5 convolution with stride (4, 2) and padding 1.
    Otherwise (4, 2) average pooling is used.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
            )

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
| |
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling of the spatial dims.

    When ``with_conv`` is set, the upsampled tensor is followed by a 3x3
    same-size convolution.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
|
|
|
|
class UpsampleTimeStride4(nn.Module):
    """Nearest-neighbour upsampling by 4 along dim 2 and by 2 along dim 3.

    This is the inverse of ``DownsampleTimeStride4``'s resampling factors.
    When ``with_conv`` is set, it is followed by a 5x5 same-size
    convolution.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=5, stride=1, padding=2
            )

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(
            x, scale_factor=(4.0, 2.0), mode="nearest"
        )
        return self.conv(upsampled) if self.with_conv else upsampled
|
|
class AttnBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    Steps:
    - Group-normalise the input.
    - Project it to queries, keys and values with 1x1 convolutions.
    - Let every spatial position attend to every other one, using scaled
      dot-product attention with scale ``C ** -0.5``.
    - Project the attended values with ``proj_out`` and add them back onto
      the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise():
            return torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=1, stride=1, padding=0
            )

        self.norm = Normalize(in_channels)
        self.q = pointwise()
        self.k = pointwise()
        self.v = pointwise()
        self.proj_out = pointwise()

    def forward(self, x):
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        batch, channels, height, width = query.shape
        n_pos = height * width

        # scores[b, i, j] = <q_i, k_j> * C**-0.5, softmax over j
        q_rows = query.reshape(batch, channels, n_pos).contiguous()
        q_rows = q_rows.permute(0, 2, 1).contiguous()
        k_cols = key.reshape(batch, channels, n_pos).contiguous()
        scores = torch.bmm(q_rows, k_cols).contiguous()
        scores = scores * (int(channels) ** (-0.5))
        attn = torch.nn.functional.softmax(scores, dim=2)

        # out[b, :, i] = sum_j attn[b, i, j] * v[b, :, j]
        v_flat = value.reshape(batch, channels, n_pos).contiguous()
        attn_t = attn.permute(0, 2, 1).contiguous()
        out = torch.bmm(v_flat, attn_t).contiguous()
        out = out.reshape(batch, channels, height, width).contiguous()

        return x + self.proj_out(out)
|
|
|
|
def make_attn(in_channels, attn_type="vanilla"):
    """Return the attention layer selected by ``attn_type``.

    Behaviour by ``attn_type``:
    - ``"vanilla"``: returns an ``AttnBlock``.
    - ``"none"``: returns an ``nn.Identity``.
    - ``"linear"``: passes validation but is not implemented here, so it
      raises ``ValueError``.
    - Any other value fails the assertion.
    """
    assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
    if attn_type == "none":
        return nn.Identity(in_channels)
    if attn_type != "vanilla":
        raise ValueError(attn_type)
    return AttnBlock(in_channels)
|
|
|
|
class ResnetBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm -> swish -> 3x3 conv) twice.

    If ``temb`` is given to ``forward``, it is passed through swish,
    projected to ``out_channels`` by ``temb_proj`` and added after the
    first convolution. That projection only exists when
    ``temb_channels > 0``.

    When the channel count changes, the skip path is projected with a 3x3
    conv (``conv_shortcut=True``) or a 1x1 conv.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        if out_channels is None:
            out_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        if in_channels == out_channels:
            return  # identity skip path, no projection needed
        if conv_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=1, padding=1
            )
        else:
            self.nin_shortcut = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x, temb):
        h = self.conv1(nonlinearity(self.norm1(x)))
        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))

        if self.in_channels != self.out_channels:
            shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
            x = shortcut(x)
        return x + h
|
|
|
|
class Encoder(nn.Module):
    """Convolutional encoder mapping a 4-D input to latent moments.

    Layout:
    1. ``conv_in``.
    2. For each resolution level: ``num_res_blocks`` ResnetBlocks, each
       followed by attention when the tracked resolution is in
       ``attn_resolutions``. Every level except the last then downsamples.
    3. Middle: ResnetBlock -> attention -> ResnetBlock.
    4. GroupNorm -> swish -> ``conv_out``.

    Args:
        ch: base channel count; level ``i`` outputs ``ch * ch_mult[i]``
            channels.
        out_ch: accepted for config compatibility; not used by the encoder.
        ch_mult: per-level channel multipliers (one entry per level).
        num_res_blocks: ResnetBlocks per level.
        attn_resolutions: resolutions at which attention layers are
            inserted. The tracked resolution starts at ``resolution`` and
            is halved after each downsampling level.
        dropout: dropout probability inside the ResnetBlocks.
        resamp_with_conv: use learned (conv) downsampling instead of pooling.
        in_channels: channels of the input tensor.
        resolution: nominal input resolution used for ``attn_resolutions``.
        z_channels: latent channels. ``conv_out`` emits ``2 * z_channels``
            when ``double_z`` is set, else ``z_channels``.
        use_linear_attn: forces ``attn_type="linear"``, which ``make_attn``
            rejects with ValueError.
        attn_type: attention flavour passed to ``make_attn``.
        downsample_time_stride4_levels: levels whose downsampling uses
            ``DownsampleTimeStride4`` instead of ``Downsample``. ``None``
            (the default) means no such levels.
        **ignore_kwargs: extra config keys (e.g. ``mel_bins``), ignored.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        downsample_time_stride4_levels=None,
        **ignore_kwargs,
    ):
        super().__init__()
        if downsample_time_stride4_levels is None:
            # Fresh list per instance rather than a shared mutable default.
            downsample_time_stride4_levels = []
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in this autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.downsample_time_stride4_levels = downsample_time_stride4_levels

        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
                % str(self.num_resolutions)
            )

        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                if i_level in self.downsample_time_stride4_levels:
                    down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        temb = None  # no timestep conditioning

        # The encoder has no skip connections, so only the latest activation
        # is needed. Avoid keeping every intermediate tensor alive in a list.
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h, temb)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = level.downsample(h)

        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
|
|
|
|
class Decoder(nn.Module):
    """Convolutional decoder mapping a latent back to a 4-D output.

    Layout:
    1. ``conv_in``.
    2. Middle: ResnetBlock -> attention -> ResnetBlock.
    3. For each resolution level from coarsest to finest:
       ``num_res_blocks + 1`` ResnetBlocks, each optionally followed by
       attention. Every level except level 0 then upsamples.
    4. GroupNorm -> swish -> ``conv_out``, with an optional ``tanh``.

    Args:
        ch: base channel count. Decoding starts from
            ``ch * ch_mult[-1]`` channels.
        ch_mult, num_res_blocks, attn_resolutions, dropout,
        resamp_with_conv: as for ``Encoder``.
        resolution: nominal output resolution. The latent resolution is
            taken as ``resolution // 2 ** (len(ch_mult) - 1)``.
        z_channels: channels of the latent fed to ``conv_in``.
        out_ch: channels of the output tensor.
        in_channels: stored on the instance; not used to build layers.
        give_pre_end: return the activations before ``norm_out``.
        tanh_out: apply ``tanh`` to the output.
        use_linear_attn: forces ``attn_type="linear"``, which ``make_attn``
            rejects with ValueError.
        downsample_time_stride4_levels: encoder levels that used
            ``DownsampleTimeStride4``. For each listed level ``i``, the
            upsampling out of decoder level ``i + 1`` uses
            ``UpsampleTimeStride4``. ``None`` (the default) means none.
        attn_type: attention flavour passed to ``make_attn``.
        **ignorekwargs: extra config keys (e.g. ``mel_bins``), ignored.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        downsample_time_stride4_levels=None,
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if downsample_time_stride4_levels is None:
            # Fresh list per instance rather than a shared mutable default.
            downsample_time_stride4_levels = []
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in this autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        self.downsample_time_stride4_levels = downsample_time_stride4_levels

        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
                % str(self.num_resolutions)
            )

        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        # Nominal latent shape; not read elsewhere in this class.
        self.z_shape = (1, z_channels, curr_res, curr_res)

        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                # Undo the encoder's downsampling at level i_level - 1.
                if i_level - 1 in self.downsample_time_stride4_levels:
                    up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            # Prepend so that self.up[i] corresponds to level i.
            self.up.insert(0, up)

        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        self.last_z_shape = z.shape

        temb = None  # no timestep conditioning

        h = self.conv_in(z)

        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
|
|
|
|
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterised by concatenated mean and log-variance.

    ``parameters`` is split in two along dim 1: the first half is the mean
    and the second half the log-variance. The log-variance is clamped to
    [-30, 20].

    With ``deterministic=True`` the standard deviation and variance are
    zero, so ``sample()`` returns the mean and ``kl()`` / ``nll()`` return
    zero.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    def _zero(self):
        """Return a one-element zero on the parameters' device and dtype."""
        return torch.zeros(1, dtype=self.mean.dtype, device=self.parameters.device)

    def sample(self):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)``.

        The noise comes from the global CPU generator, so seeded runs draw
        the same values on any device. It is then cast to the parameters'
        device and dtype, so half-precision latents are not promoted to
        float32.
        """
        noise = torch.randn(self.mean.shape).to(
            device=self.parameters.device, dtype=self.mean.dtype
        )
        return self.mean + self.std * noise

    def kl(self, other=None):
        """KL divergence to ``other``, or to N(0, I) when ``other`` is None.

        The result is averaged over dims 1-3, one value per batch element.
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * torch.mean(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3],
            )
        return 0.5 * torch.mean(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2, 3],
        )

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample``, summed over ``dims``."""
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        """Return the distribution's mode (its mean)."""
        return self.mean
|
|
def get_vocoder_config_48k():
    """Return a new dict of HiFi-GAN hyper-parameters for 48 kHz audio.

    The keys are assembled from grouped sections, in a fixed order.
    ``upsample_rates`` multiply to ``hop_size`` (6*5*4*2*2 = 480).
    """
    training = {
        "resblock": "1",
        "num_gpus": 8,
        "batch_size": 128,
        "learning_rate": 0.0001,
        "adam_b1": 0.8,
        "adam_b2": 0.99,
        "lr_decay": 0.999,
        "seed": 1234,
    }
    generator = {
        "upsample_rates": [6, 5, 4, 2, 2],
        "upsample_kernel_sizes": [12, 10, 8, 4, 4],
        "upsample_initial_channel": 1536,
        "resblock_kernel_sizes": [3, 7, 11, 15],
        "resblock_dilation_sizes": [[1, 3, 5] for _ in range(4)],
    }
    features = {
        "segment_size": 15360,
        "num_mels": 256,
        "n_fft": 2048,
        "hop_size": 480,
        "win_size": 2048,
        "sampling_rate": 48000,
        "fmin": 20,
        "fmax": 24000,
        "fmax_for_loss": None,
    }
    runtime = {
        "num_workers": 8,
        "dist_config": {
            "dist_backend": "nccl",
            "dist_url": "tcp://localhost:18273",
            "world_size": 1,
        },
    }
    return {**training, **generator, **features, **runtime}
|
|
def get_vocoder(config, device, mel_bins):
    """Build the HiFi-GAN vocoder that turns mel spectrograms into audio.

    Only the 256-mel-bin / 48 kHz configuration is supported. The
    generator is built from ``get_vocoder_config_48k``, put in eval mode,
    has its weight normalisation removed, and is moved to ``device``.

    No checkpoint is loaded here. NOTE(review): trained weights are
    presumably loaded afterwards, e.g. through a parent module's state
    dict — confirm.

    Args:
        config: ignored; kept for signature compatibility. The
            configuration is selected from ``mel_bins``.
        device: device to move the vocoder to.
        mel_bins: number of mel bins; must be 256.

    Returns:
        A ``Generator_old`` instance.

    Raises:
        ValueError: if ``mel_bins`` is not 256.
    """
    if mel_bins != 256:
        raise ValueError(mel_bins)
    vocoder = Generator_old(AttrDict(get_vocoder_config_48k()))
    vocoder.eval()
    vocoder.remove_weight_norm()
    vocoder.to(device)
    return vocoder
|
|
def vocoder_infer(mels, vocoder, lengths=None):
    """Run ``vocoder`` on ``mels`` without gradients; return a numpy array.

    Dim 1 of the vocoder output is squeezed. When ``lengths`` is given
    (a single int applied to the whole batch), every waveform is truncated
    to its first ``lengths`` samples.
    """
    with torch.no_grad():
        audio = vocoder(mels).squeeze(1)
    audio = audio.cpu().numpy()
    return audio if lengths is None else audio[:, :lengths]
|
|
@torch.no_grad()
def vocoder_chunk_infer(mels, vocoder, lengths=None, chunk_size=1024, shift_size=256):
    """Vocode a long mel sequence in overlapping windows and cross-fade them.

    ``mels`` is indexed as (batch, mel_bins, frames). Windows of
    ``chunk_size`` frames start every ``shift_size`` frames. Consecutive
    windows overlap by ``chunk_size - shift_size`` frames, and their audio
    is blended with a linear cross-fade over the corresponding samples.

    Args:
        mels: mel tensor; frames along dim 2.
        vocoder: callable mapping a mel window to audio with a singleton
            dim 1 (as ``Generator_old`` does).
        lengths: optional int; truncate every waveform to this many samples.
        chunk_size: window length in frames.
        shift_size: hop between window starts in frames; must be in
            ``[1, chunk_size]``.

    Returns:
        numpy array of shape (batch, samples).

    Raises:
        ValueError: if ``shift_size`` is outside ``[1, chunk_size]``.
    """
    if shift_size <= 0 or shift_size > chunk_size:
        raise ValueError(
            f"shift_size must be in [1, chunk_size]; got shift_size={shift_size}, "
            f"chunk_size={chunk_size}"
        )
    n_frames = mels.shape[2]
    if n_frames == 0:
        # Nothing to vocode.
        return np.zeros((mels.shape[0], 0), dtype=np.float32)

    ov_size = chunk_size - shift_size
    wavs = None
    for cinx in range(0, n_frames, shift_size):
        chunk_wav = vocoder(mels[:, :, cinx:cinx + chunk_size]).squeeze(1).cpu()
        if wavs is None:
            # Derive samples-per-frame from the first window, which may be
            # shorter than chunk_size when the whole input is short.
            first_frames = min(chunk_size, n_frames)
            samples_per_frame = chunk_wav.shape[-1] // first_frames
            num_samples = samples_per_frame * chunk_size
            wavs = chunk_wav[:, :num_samples]
            ov_sample = int(float(num_samples) * ov_size / chunk_size)
            fade_in = torch.from_numpy(np.linspace(0, 1, ov_sample)[None, :])
            fade_out = 1 - fade_in
        else:
            cur_wav = chunk_wav[:, :num_samples]
            # Explicit start index: `-ov_sample` would select everything
            # when ov_sample == 0.
            tail = wavs.shape[-1] - ov_sample
            wavs[:, tail:] = wavs[:, tail:] * fade_out + cur_wav[:, :ov_sample] * fade_in
            wavs = torch.cat([wavs, cur_wav[:, ov_sample:]], -1)
        if cinx + chunk_size >= n_frames:
            break

    wavs = wavs.cpu().numpy()
    if lengths is not None:
        wavs = wavs[:, :lengths]
    return wavs
|
|
def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
    """Vocode a reference mel and a predicted mel with the same vocoder.

    Dims 1 and 2 of both mels are swapped before vocoding; ``labels`` is
    unused. Returns ``(wav_reconstruction, wav_prediction)``, or
    ``(None, None)`` when ``vocoder`` is None.
    """
    if vocoder is None:
        return None, None
    wav_reconstruction, wav_prediction = (
        vocoder_infer(mel.permute(0, 2, 1), vocoder)
        for mel in (mel_input, mel_prediction)
    )
    return wav_reconstruction, wav_prediction
|
|
|
|
class AutoencoderKL(nn.Module):
    """KL-regularised convolutional autoencoder over spectrogram "images".

    Holds an ``Encoder``/``Decoder`` pair, plus 1x1 ``quant_conv`` /
    ``post_quant_conv`` layers that map between the encoder output and an
    ``embed_dim``-channel latent. For ``image_key == "fbank"``, a HiFi-GAN
    vocoder is attached so decoded mels can be turned into waveforms; for
    other keys ``self.vocoder`` is None.

    NOTE(review): several helpers reference attributes this class never
    defines: ``logger``, ``device``, ``forward``, ``first_stage_model`` and
    ``wave_decoder``. They look like training-framework leftovers and only
    work if those attributes are supplied externally.
    """

    def __init__(
        self,
        ddconfig=None,
        lossconfig=None,
        batchsize=None,
        embed_dim=None,
        time_shuffle=1,
        subband=1,
        sampling_rate=16000,
        ckpt_path=None,
        reload_from_ckpt=None,
        ignore_keys=(),
        image_key="fbank",
        colorize_nlabels=None,
        monitor=None,
        base_learning_rate=1e-5,
        scale_factor=1,
    ):
        """Build the model.

        Args:
            ddconfig: keyword arguments for both ``Encoder`` and
                ``Decoder``. Must contain ``mel_bins``, ``z_channels`` and a
                truthy ``double_z``.
            lossconfig, batchsize: accepted but unused.
            embed_dim: latent channel count.
            time_shuffle, reload_from_ckpt, monitor, base_learning_rate:
                stored on the instance.
            subband: stored as int; a value above 1 only prints a message.
            sampling_rate: used when writing audio in ``save_wave``.
            ckpt_path: optional checkpoint to load via ``init_from_ckpt``.
            ignore_keys: state-dict key prefixes skipped when loading
                ``ckpt_path``.
            image_key: ``"fbank"`` attaches a vocoder.
            colorize_nlabels: if set (int), registers a random
                ``colorize`` buffer.
            scale_factor: latent scaling used by the ``*_first_stage``
                helpers.
        """
        super().__init__()
        self.automatic_optimization = False
        assert (
            "mel_bins" in ddconfig.keys()
        ), "mel_bins is not specified in the Autoencoder config"
        num_mel = ddconfig["mel_bins"]
        self.image_key = image_key
        self.sampling_rate = sampling_rate
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)

        self.loss = None
        self.subband = int(subband)

        if self.subband > 1:
            print("Use subband decomposition %s" % self.subband)

        # The encoder must emit 2 * z_channels (mean and log-variance).
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

        # Only spectrogram inputs get a vocoder. Other image keys store None,
        # so the attribute always exists.
        if self.image_key == "fbank":
            self.vocoder = get_vocoder(None, "cpu", num_mel)
        else:
            self.vocoder = None
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.learning_rate = float(base_learning_rate)

        self.time_shuffle = time_shuffle
        self.reload_from_ckpt = reload_from_ckpt
        self.reloaded = False
        self.mean, self.std = None, None

        self.feature_cache = None
        self.flag_first_run = True
        self.train_step = 0

        self.logger_save_dir = None
        self.logger_exp_name = None
        self.scale_factor = scale_factor

        print("Num parameters:")
        print("Encoder : ", sum(p.numel() for p in self.encoder.parameters()))
        print("Decoder : ", sum(p.numel() for p in self.decoder.parameters()))
        if self.vocoder is not None:
            print("Vocoder : ", sum(p.numel() for p in self.vocoder.parameters()))

    def get_log_dir(self):
        """Return the logging directory.

        Uses ``set_log_dir`` values when set. Otherwise falls back to the
        (externally provided) ``self.logger``.
        """
        if self.logger_save_dir is None and self.logger_exp_name is None:
            return os.path.join(self.logger.save_dir, self.logger._project)
        else:
            return os.path.join(self.logger_save_dir, self.logger_exp_name)

    def set_log_dir(self, save_dir, exp_name):
        """Record the directory and experiment name used by ``get_log_dir``."""
        self.logger_save_dir = save_dir
        self.logger_exp_name = exp_name

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load the checkpoint's ``state_dict`` non-strictly.

        Keys starting with any prefix in ``ignore_keys`` are dropped first.
        Missing or unexpected keys are tolerated (``strict=False``).
        """
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            # Delete each key at most once, even if several prefixes match
            # (a second `del` would raise KeyError).
            if any(k.startswith(ik) for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode ``x`` into a ``DiagonalGaussianDistribution`` posterior."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode an ``embed_dim``-channel latent ``z`` with the decoder."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def decode_to_waveform(self, dec):
        """Convert decoder output to audio according to ``image_key``.

        ``"fbank"`` is vocoded chunk-wise and returns a numpy array.
        ``"stft"`` requires an externally provided ``self.wave_decoder``.

        Raises:
            ValueError: for any other ``image_key``.
        """
        if self.image_key == "fbank":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = vocoder_chunk_infer(dec, self.vocoder)
        elif self.image_key == "stft":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = self.wave_decoder(dec)
        else:
            raise ValueError(
                f"Unsupported image_key for waveform decoding: {self.image_key}"
            )
        return wav_reconstruction

    def mel_spectrogram_to_waveform(
        self, mel, savepath=".", bs=None, name="outwav", save=True
    ):
        """Vocode ``mel`` in a single pass and return a numpy waveform.

        A 4-D ``mel`` has dim 1 squeezed; dims 1 and 2 are then swapped.
        ``savepath``, ``bs``, ``name`` and ``save`` are currently unused.
        """
        if len(mel.size()) == 4:
            mel = mel.squeeze(1)
        mel = mel.permute(0, 2, 1)
        waveform = self.vocoder(mel)
        waveform = waveform.cpu().detach().numpy()
        return waveform

    @torch.no_grad()
    def encode_first_stage(self, x):
        """``encode`` without gradient tracking."""
        return self.encode(x)

    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        """Undo ``scale_factor`` scaling and decode, without gradients.

        NOTE(review): ``predict_cids=True`` relies on
        ``self.first_stage_model``, which this class does not define.
        """
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, "b h w c -> b c h w").contiguous()

        z = 1.0 / self.scale_factor * z
        return self.decode(z)

    def decode_first_stage_withgrad(self, z):
        """Undo ``scale_factor`` scaling and decode, keeping gradients."""
        z = 1.0 / self.scale_factor * z
        return self.decode(z)

    def get_first_stage_encoding(self, encoder_posterior, use_mode=False):
        """Turn a posterior (or raw tensor) into a latent scaled by ``scale_factor``.

        A ``DiagonalGaussianDistribution`` is sampled, or its mode is taken
        when ``use_mode`` is set. A tensor is used as-is.
        """
        if isinstance(encoder_posterior, DiagonalGaussianDistribution) and not use_mode:
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, DiagonalGaussianDistribution) and use_mode:
            z = encoder_posterior.mode()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z

    def visualize_latent(self, input):
        """Debug helper: encode masked copies of ``input`` and dump results.

        Writes ``.npy`` arrays and ``.png`` plots into the current directory.

        - ``time_input``: the first 32 entries of dim 3 are set to -11.59.
          Its latent plots are saved as ``freq_*.png``.
        - ``freq_input``: the first 512 entries of dim 2 are set to -11.59.
          Its latent plots are saved as ``time_*.png``.

        NOTE(review): the variable names and file prefixes look swapped
        relative to each other; -11.59 is presumably a log-mel floor value.
        Confirm both.
        """
        import matplotlib.pyplot as plt

        np.save("input.npy", input.cpu().detach().numpy())

        time_input = input.clone()
        time_input[:, :, :, :32] *= 0
        time_input[:, :, :, :32] -= 11.59

        np.save("time_input.npy", time_input.cpu().detach().numpy())

        posterior = self.encode(time_input)
        latent = posterior.sample()
        np.save("time_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("freq_%s.png" % i)
            plt.close()

        freq_input = input.clone()
        freq_input[:, :, :512, :] *= 0
        freq_input[:, :, :512, :] -= 11.59

        np.save("freq_input.npy", freq_input.cpu().detach().numpy())

        posterior = self.encode(freq_input)
        latent = posterior.sample()
        np.save("freq_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("time_%s.png" % i)
            plt.close()

    def get_input(self, batch):
        """Extract model inputs from a batch dict.

        Returns a dict with ``fbank`` (from ``log_mel_spec``), ``stft`` and
        ``waveform``, each given a new dim 1, plus ``fname`` unchanged.
        The batch must also contain ``text`` and ``label_vector``; they are
        read but unused.
        """
        fname, text, label_indices, waveform, stft, fbank = (
            batch["fname"],
            batch["text"],
            batch["label_vector"],
            batch["waveform"],
            batch["stft"],
            batch["log_mel_spec"],
        )

        ret = {}

        ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
            fbank.unsqueeze(1),
            stft.unsqueeze(1),
            fname,
            waveform.unsqueeze(1),
        )

        return ret

    def save_wave(self, batch_wav, fname, save_dir):
        """Write each waveform to ``save_dir/basename(name)``.

        ``save_dir`` is created if needed. Files are written at
        ``self.sampling_rate``.
        """
        os.makedirs(save_dir, exist_ok=True)

        for wav, name in zip(batch_wav, fname):
            name = os.path.basename(name)

            sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)

    def get_last_layer(self):
        """Return the weight of the decoder's final convolution."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
        """Reconstruct/sample from ``batch`` and hand the results to ``_log_img``.

        NOTE(review): relies on ``self.device`` and ``self(x)``; neither is
        defined by this class.
        """
        log = dict()
        x = batch.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            log["samples"] = self.decode(posterior.sample())
            log["reconstructions"] = xrec

        log["inputs"] = x
        wavs = self._log_img(log, train=train, index=0, waveform=waveform)
        return wavs

    def _log_img(self, log, train=True, index=0, waveform=None):
        """Log spectrogram images/audio for one batch element and return the audio.

        Returns ``(wav_original, wav_prediction, wav_samples)``. Logging is
        skipped when no ``logger`` attribute is available.
        """
        images_input = self.tensor2numpy(log["inputs"][index, 0]).T
        images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
        images_samples = self.tensor2numpy(log["samples"][index, 0]).T

        if train:
            name = "train"
        else:
            name = "val"

        # `logger` is never set by this class; treat a missing one as None.
        logger = getattr(self, "logger", None)

        if logger is not None:
            logger.log_image(
                "img_%s" % name,
                [images_input, images_reconstruct, images_samples],
                caption=["input", "reconstruct", "samples"],
            )

        inputs, reconstructions, samples = (
            log["inputs"],
            log["reconstructions"],
            log["samples"],
        )

        if self.image_key == "fbank":
            wav_original, wav_prediction = synth_one_sample(
                inputs[index],
                reconstructions[index],
                labels="validation",
                vocoder=self.vocoder,
            )
            wav_original, wav_samples = synth_one_sample(
                inputs[index], samples[index], labels="validation", vocoder=self.vocoder
            )
            wav_original, wav_samples, wav_prediction = (
                wav_original[0],
                wav_samples[0],
                wav_prediction[0],
            )
        elif self.image_key == "stft":
            wav_prediction = (
                self.decode_to_waveform(reconstructions)[index, 0]
                .cpu()
                .detach()
                .numpy()
            )
            wav_samples = (
                self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
            )
            wav_original = waveform[index, 0].cpu().detach().numpy()

        if logger is not None:
            # NOTE(review): `wandb` is not imported at the top of this file;
            # this branch raises NameError unless it is provided elsewhere.
            logger.experiment.log(
                {
                    "original_%s"
                    % name: wandb.Audio(
                        wav_original, caption="original", sample_rate=self.sampling_rate
                    ),
                    "reconstruct_%s"
                    % name: wandb.Audio(
                        wav_prediction,
                        caption="reconstruct",
                        sample_rate=self.sampling_rate,
                    ),
                    "samples_%s"
                    % name: wandb.Audio(
                        wav_samples, caption="samples", sample_rate=self.sampling_rate
                    ),
                }
            )

        return wav_original, wav_prediction, wav_samples

    def tensor2numpy(self, tensor):
        """Detach ``tensor``, move it to CPU and convert it to numpy."""
        return tensor.cpu().detach().numpy()

    def to_rgb(self, x):
        """Project a segmentation map to 3 channels scaled to [-1, 1].

        A random ``colorize`` buffer is registered lazily on first use.
        """
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = torch.nn.functional.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x
|
|
|
|
class IdentityFirstStage(torch.nn.Module):
    """No-op first stage: ``encode``/``decode``/``quantize``/``forward`` return the input.

    With ``vq_interface=True``, ``quantize`` mimics a VQ model's return
    signature ``(x, None, [None, None, None])``. Extra positional and
    keyword arguments are accepted and ignored.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Initialise nn.Module state before assigning any attribute.
        super().__init__()
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
|
|
|
|
def window_sumsquare(
    window,
    n_frames,
    hop_length,
    win_length,
    n_fft,
    dtype=np.float32,
    norm=None,
):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    norm : passed to `librosa.util.normalize` before squaring the window

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Squared (optionally normalised) window, centre-padded to n_fft.
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
    # `size` is keyword-only in librosa >= 0.10; the keyword also works on
    # older releases.
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)

    # Overlap-add the squared window at every frame offset.
    for i in range(n_frames):
        sample = i * hop_length
        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
    return x
|
|
def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
    """Compress the dynamic range of ``x`` as ``normalize_fun(clamp(x, min=clip_val) * C)``.

    PARAMS
    ------
    x: input tensor; values below ``clip_val`` are raised to ``clip_val``
       (with the default ``torch.log`` this keeps the result finite)
    normalize_fun: elementwise compression function, ``torch.log`` by default
    C: compression factor, applied after clamping
    clip_val: lower bound applied to ``x``
    """
    floored = x.clamp(min=clip_val)
    return normalize_fun(floored * C)
|
|
|
|
def dynamic_range_decompression(x, C=1):
    """Undo log compression: return ``exp(x) / C``.

    PARAMS
    ------
    x: log-compressed tensor
    C: compression factor that was used when compressing
    """
    expanded = torch.exp(x)
    return expanded / C
|
|
|
|
class STFT(torch.nn.Module):
    """Convolution-based STFT and inverse STFT.

    Adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft.
    The forward transform is a strided ``conv1d`` against a (windowed) stack of
    real and imaginary Fourier basis rows. The inverse is a ``conv_transpose1d``
    against the pseudo-inverse basis, followed by window-envelope normalisation.
    """

    def __init__(self, filter_length, hop_length, win_length, window="hann"):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        # Keep the non-redundant bins 0..N/2 and stack real parts over
        # imaginary parts, giving 2 * cutoff basis rows.
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        )

        if window is not None:
            assert filter_length >= win_length
            # Window of length win_length, centre-padded with zeros to filter_length.
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # Fold the window into both bases.
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())

    def transform(self, input_data):
        """Return ``(magnitude, phase)`` of the STFT of ``input_data``.

        ``input_data`` is moved to the device of the basis buffers and viewed as
        ``(num_batches, 1, num_samples)`` (so it is expected to be 2-D
        ``(batch, samples)``). It is reflect-padded by ``filter_length // 2`` on
        each side before framing. Both outputs have shape
        ``(batch, filter_length // 2 + 1, n_frames)``. ``phase`` is computed from
        detached tensors and carries no gradient. Side effect: stores
        ``self.num_samples``.
        """
        device = self.forward_basis.device
        input_data = input_data.to(device)

        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # Reflect-pad so frames are centred on sample positions.
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = torch.nn.functional.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode="reflect",
        )
        input_data = input_data.squeeze(1)

        forward_transform = torch.nn.functional.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0,
        )

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.atan2(imag_part.detach(), real_part.detach())

        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Resynthesise a waveform of shape ``(batch, 1, samples)`` from magnitude and phase.

        Inputs are moved to the device of the basis buffers. When a window is
        configured, the overlap-added signal is divided by the window's
        sum-square envelope wherever that envelope is not tiny.
        """
        device = self.forward_basis.device
        magnitude, phase = magnitude.to(device), phase.to(device)

        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )

        inverse_transform = torch.nn.functional.conv_transpose1d(
            recombine_magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0,
        )

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window,
                magnitude.size(-1),
                hop_length=self.hop_length,
                win_length=self.win_length,
                n_fft=self.filter_length,
                dtype=np.float32,
            )
            # The envelope is built in NumPy (CPU). Move it and the index set to
            # the output's device so the division also works for CUDA inputs.
            out_device = inverse_transform.device
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]
            ).to(out_device)
            window_sum = torch.from_numpy(window_sum).to(out_device)
            # Only divide where the envelope is safely non-zero.
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                approx_nonzero_indices
            ]

        # Undo the 1/scale introduced by pinv(scale * fourier_basis) in __init__.
        inverse_transform *= float(self.filter_length) / self.hop_length

        # Trim the filter_length // 2 reflect-padding samples added in transform().
        half = int(self.filter_length / 2)
        inverse_transform = inverse_transform[:, :, half:]
        inverse_transform = inverse_transform[:, :, :-half]

        return inverse_transform

    def forward(self, input_data):
        """Analyse then resynthesise ``input_data``; caches ``self.magnitude`` / ``self.phase``."""
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
|
|
|
|
class TacotronSTFT(torch.nn.Module):
    """Mel-spectrogram extractor: an :class:`STFT` followed by a mel filterbank.

    Args:
        filter_length: FFT size, used both for the STFT and the mel filterbank.
        hop_length: STFT hop size in samples.
        win_length: STFT analysis window length in samples.
        n_mel_channels: number of mel bands.
        sampling_rate: sample rate in Hz used to build the filterbank.
        mel_fmin: lowest filterbank frequency in Hz.
        mel_fmax: highest filterbank frequency in Hz.
    """

    def __init__(
        self,
        filter_length,
        hop_length,
        win_length,
        n_mel_channels,
        sampling_rate,
        mel_fmin,
        mel_fmax,
    ):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        # NumPy mel filterbank from librosa, stored as a float32 buffer.
        filterbank = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        self.register_buffer("mel_basis", torch.from_numpy(filterbank).float())

    def spectral_normalize(self, magnitudes, normalize_fun):
        """Apply ``dynamic_range_compression`` with ``normalize_fun`` to ``magnitudes``."""
        return dynamic_range_compression(magnitudes, normalize_fun)

    def spectral_de_normalize(self, magnitudes):
        """Apply ``dynamic_range_decompression`` (default ``C``) to ``magnitudes``."""
        return dynamic_range_decompression(magnitudes)

    def mel_spectrogram(self, y, normalize_fun=torch.log):
        """Compute the compressed mel spectrogram, compressed magnitudes and frame energy.

        PARAMS
        ------
        y: torch.FloatTensor of shape (B, T) with values in [-1, 1]
        normalize_fun: compression function forwarded to spectral_normalize

        RETURNS
        -------
        mel_output: compressed mel spectrogram, shape (B, n_mel_channels, frames)
        log_magnitudes: compressed linear magnitudes, shape (B, filter_length // 2 + 1, frames)
        energy: L2 norm over frequency of the uncompressed magnitudes, shape (B, frames)
        """
        # Range checks use assert, so they are skipped under `python -O`.
        assert torch.min(y.data) >= -1, torch.min(y.data)
        assert torch.max(y.data) <= 1, torch.max(y.data)

        # Phase is unused; `.data` drops any autograd history from the magnitudes.
        magnitudes = self.stft_fn.transform(y)[0].data
        energy = torch.norm(magnitudes, dim=1)

        mel_linear = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_linear, normalize_fun)
        log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
        return mel_output, log_magnitudes, energy
|
|
|
|
def build_pretrained_models(ckpt):
    """Load the first-stage VAE from a checkpoint and build its matching STFT front end.

    Args:
        ckpt: anything accepted by ``torch.load`` (e.g. a file path). The loaded
            object must contain a ``"state_dict"`` with a ``"scale_factor"``
            tensor. The VAE weights must be stored under keys that start with
            ``"first_stage_model."``.

    Returns:
        ``(vae, fn_STFT)``: an ``AutoencoderKL`` with the checkpoint weights
        loaded, and a ``TacotronSTFT`` configured for 48 kHz audio with 256 mel
        bins. Both are in eval mode.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    scale_factor = state_dict["scale_factor"].item()
    print("scale_factor: ", scale_factor)

    # Keep only the VAE weights and strip their prefix. Matching on the key
    # start guarantees the slice removes exactly the prefix.
    prefix = "first_stage_model."
    vae_state_dict = {
        k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
    }

    config = {
        "preprocessing": {
            "audio": {
                "sampling_rate": 48000,
                "max_wav_value": 32768,
                "duration": 10.24,
            },
            "stft": {
                "filter_length": 2048,
                "hop_length": 480,
                "win_length": 2048,
            },
            "mel": {
                "n_mel_channels": 256,
                "mel_fmin": 20,
                "mel_fmax": 24000,
            },
        },
        "model": {
            "params": {
                "first_stage_config": {
                    "params": {
                        "sampling_rate": 48000,
                        "batchsize": 4,
                        "monitor": "val/rec_loss",
                        "image_key": "fbank",
                        "subband": 1,
                        "embed_dim": 16,
                        "time_shuffle": 1,
                        "lossconfig": {
                            "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator",
                            "params": {
                                "disc_start": 50001,
                                "kl_weight": 1000,
                                "disc_weight": 0.5,
                                "disc_in_channels": 1,
                            },
                        },
                        "ddconfig": {
                            "double_z": True,
                            "mel_bins": 256,
                            "z_channels": 16,
                            "resolution": 256,
                            "downsample_time": False,
                            "in_channels": 1,
                            "out_ch": 1,
                            "ch": 128,
                            "ch_mult": [1, 2, 4, 8],
                            "num_res_blocks": 2,
                            "attn_resolutions": [],
                            "dropout": 0,
                        },
                    }
                },
            }
        },
    }
    vae_config = config["model"]["params"]["first_stage_config"]["params"]
    vae_config["scale_factor"] = scale_factor

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(vae_state_dict)

    stft_cfg = config["preprocessing"]["stft"]
    mel_cfg = config["preprocessing"]["mel"]
    fn_STFT = TacotronSTFT(
        stft_cfg["filter_length"],
        stft_cfg["hop_length"],
        stft_cfg["win_length"],
        mel_cfg["n_mel_channels"],
        config["preprocessing"]["audio"]["sampling_rate"],
        mel_cfg["mel_fmin"],
        mel_cfg["mel_fmax"],
    )

    vae.eval()
    fn_STFT.eval()
    return vae, fn_STFT
|
|
|
|
|
|