| from typing import Optional |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from .utils import activate_add_tanh_sigmoid_multiply |
|
|
|
|
class LayerNorm(nn.Module):
    """Layer normalization over dimension 1 of the input tensor.

    ``F.layer_norm`` normalizes the trailing dimension, so ``forward`` swaps
    dimension 1 with the last dimension, normalizes, and swaps back.

    Args:
        channels: Size of dimension 1 of the tensors passed to ``forward``.
        eps: Value added to the variance for numerical stability.
    """

    def __init__(self, channels: int, eps: float = 1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        # Learnable per-channel scale (initialized to 1) and shift (initialized to 0).
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor):
        """Return ``x`` normalized along dimension 1, with the same shape as ``x``."""
        # Move the channel axis to the end, where layer_norm operates.
        channels_last = x.transpose(1, -1)
        normalized = F.layer_norm(
            channels_last,
            (self.channels,),
            weight=self.gamma,
            bias=self.beta,
            eps=self.eps,
        )
        # Restore the caller's original axis order.
        return normalized.transpose(1, -1)
|
|
|
|
class WN(torch.nn.Module):
    """Stack of weight-normalized dilated 1-D convolutions with gated activations.

    Each layer runs a dilated ``Conv1d`` (``in_layers``), combines the result
    with an optional slice of the conditioning signal through
    ``activate_add_tanh_sigmoid_multiply``, applies dropout, and projects with
    a kernel-size-1 ``Conv1d`` (``res_skip_layers``). Every layer except the
    last splits that projection into a residual half (added to the running
    input) and a skip half (accumulated into the output); the last layer
    contributes to the output only.

    Args:
        hidden_channels: Channel count of the input and of the output.
        kernel_size: Kernel size of the dilated convolutions; must be odd.
        dilation_rate: Base of the per-layer dilation (``dilation_rate ** i``
            for layer ``i``).
        n_layers: Number of layers in the stack.
        gin_channels: Channel count of the conditioning input ``g``. ``0``
            means no conditioning layer is created.
        p_dropout: Dropout probability applied to the gated activations.
    """

    def __init__(
        self,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        gin_channels: int = 0,
        p_dropout: float = 0,
    ):
        super().__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        # NOTE: stored as a 1-tuple, exactly as before, so that any external
        # reader of this attribute keeps seeing the same value.
        self.kernel_size = (kernel_size,)
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = float(p_dropout)

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(self.p_dropout)

        if gin_channels != 0:
            # One 1x1 convolution produces the conditioning for every layer at
            # once; ``forward`` slices out 2 * hidden_channels per layer.
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

        for i in range(n_layers):
            dilation = dilation_rate**i
            # kernel_size is odd, so (kernel_size - 1) * dilation is even and
            # this padding keeps the sequence length unchanged.
            padding = (kernel_size - 1) * dilation // 2
            in_layer = torch.nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # The last layer has no residual branch, so it only needs the
            # skip half of the channels.
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def __call__(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Typed pass-through to ``nn.Module.__call__`` (which runs ``forward``)."""
        return super().__call__(x, x_mask, g=g)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the layer stack and return the masked sum of skip outputs.

        Args:
            x: Rank-3 input with ``hidden_channels`` on dimension 1.
            x_mask: Mask multiplied into the residual path and the final
                output; must broadcast against ``x``.
            g: Optional rank-3 conditioning input with ``gin_channels`` on
                dimension 1. Requires the module to have been built with
                ``gin_channels != 0``.

        Returns:
            Tensor shaped like ``x``.
        """
        output = torch.zeros_like(x)

        if g is not None:
            g = self.cond_layer(g)

        for i, (in_layer, res_skip_layer) in enumerate(
            zip(self.in_layers, self.res_skip_layers)
        ):
            x_in: torch.Tensor = in_layer(x)
            if g is not None:
                # Pick this layer's 2 * hidden_channels slice of the conditioning.
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = activate_add_tanh_sigmoid_multiply(x_in, g_l, self.hidden_channels)
            acts: torch.Tensor = self.drop(acts)

            res_skip_acts: torch.Tensor = res_skip_layer(acts)
            if i < self.n_layers - 1:
                # First half feeds the residual path, second half the skip sum.
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        """Strip weight normalization from every convolution in the stack.

        Raises:
            ValueError: Propagated from ``torch.nn.utils.remove_weight_norm``
                when a layer no longer has weight normalization attached.
        """
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for layer in self.in_layers:
            torch.nn.utils.remove_weight_norm(layer)
        for layer in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(layer)

    @staticmethod
    def _has_weight_norm(module: torch.nn.Module) -> bool:
        """Return True if ``module`` carries a ``WeightNorm`` forward pre-hook."""
        return any(
            getattr(hook, "__module__", None) == "torch.nn.utils.weight_norm"
            and hook.__class__.__name__ == "WeightNorm"
            for hook in module._forward_pre_hooks.values()
        )

    def __prepare_scriptable__(self):
        """Remove weight normalization where still present and return ``self``.

        Unlike ``remove_weight_norm`` this is safe to call repeatedly: layers
        without a ``WeightNorm`` hook are skipped.
        """
        layers = []
        if self.gin_channels != 0:
            layers.append(self.cond_layer)
        layers.extend(self.in_layers)
        layers.extend(self.res_skip_layers)

        for layer in layers:
            # Decide first, then remove: remove_weight_norm deletes from the
            # module's hook dict, which must not happen while iterating it.
            if self._has_weight_norm(layer):
                torch.nn.utils.remove_weight_norm(layer)
        return self
|
|