from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F

from VietTTS.utils.mask import make_pad_mask
|
|
|
|
|
class InterpolateRegulator(nn.Module):
    """Length regulator based on linear interpolation plus a small conv stack.

    Inputs are laid out as ``(batch, time, channels)``: they are transposed to
    ``(batch, channels, time)``, linearly interpolated along time to the
    requested length, passed through ``self.model`` and transposed back.

    ``self.model`` holds one ``Conv1d(kernel=3) -> GroupNorm -> Mish`` group
    per entry of ``sampling_ratios`` followed by a 1x1 ``Conv1d`` projecting
    to ``out_channels``.  Only the *number* of entries in ``sampling_ratios``
    is used by this class; the values themselves are just stored.
    """

    def __init__(
        self,
        channels: int,
        sampling_ratios: Tuple,
        out_channels: Optional[int] = None,
        groups: int = 1,
    ):
        """Build the post-interpolation convolution stack.

        Args:
            channels: Channel width of the input features and hidden layers.
            sampling_ratios: One conv group is created per element.
            out_channels: Output channel width; a falsy value (``None`` or
                ``0``) means "same as ``channels``".
            groups: Number of groups for every ``GroupNorm`` layer.
        """
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels

        # Layer order is kept exactly as before so the indices inside the
        # Sequential (and therefore state_dict keys) do not change.
        layers = []
        for _ in sampling_ratios:
            layers.extend([
                nn.Conv1d(channels, channels, 3, 1, 1),
                nn.GroupNorm(groups, channels),
                nn.Mish(),
            ])
        layers.append(nn.Conv1d(channels, out_channels, 1, 1))
        self.model = nn.Sequential(*layers)

    @staticmethod
    def _resample(x, size):
        """Linearly interpolate ``(B, T, C)`` to ``(B, C, size)`` (channels-first)."""
        return F.interpolate(x.transpose(1, 2).contiguous(), size=size, mode='linear')

    def forward(self, x, ylens=None):
        """Resample ``x`` to ``ylens.max()`` frames and run the conv stack.

        Args:
            x: Features of shape ``(batch, time, channels)``.
            ylens: Per-item target lengths (tensor).  Required despite the
                ``None`` default, which is kept only for signature
                compatibility.

        Returns:
            ``(out, ylens)`` where ``out`` has shape
            ``(batch, ylens.max(), out_channels)`` multiplied by a mask built
            from ``ylens``.

        Raises:
            ValueError: If ``ylens`` is ``None``.
        """
        if ylens is None:
            # Previously this failed obscurely inside make_pad_mask / .max().
            raise ValueError("ylens must be provided: it defines the target lengths")

        # NOTE(review): make_pad_mask is assumed to return True at padded
        # positions, so its negation marks valid frames -- confirm in
        # VietTTS.utils.mask.  ``.to(x)`` matches x's dtype and device.
        mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
        x = self._resample(x, ylens.max())
        out = self.model(x).transpose(1, 2).contiguous()
        return out * mask, ylens

    def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
        """Resample ``x1`` and ``x2`` separately, concatenate, run the conv stack.

        When ``x2`` has more than 40 frames, its first and last 20 frames are
        each interpolated to a fixed length that depends only on
        ``input_frame_rate``, and the middle part absorbs the remainder of
        ``mel_len2``.  If that remainder (or the fixed edge length) would not
        be positive, ``x2`` is interpolated as a whole instead of failing
        inside ``F.interpolate``.

        Args:
            x1: ``(batch, time1, channels)``; may have ``time1 == 0``, in
                which case it is skipped.
            x2: ``(batch, time2, channels)``.
            mel_len1: Target length for ``x1``.
            mel_len2: Target length for ``x2``.
            input_frame_rate: Frame rate of the inputs used to size the fixed
                head/tail segments.

        Returns:
            ``(out, mel_len1 + mel_len2)`` with ``out`` shaped
            ``(batch, total_time, out_channels)``.
        """
        split_edges = False
        if x2.shape[1] > 40:
            # Output length of a 20-frame edge segment.  The 22050 / 256
            # factor is presumably the output sample rate over the hop size
            # -- TODO confirm against the feature extractor configuration.
            edge_len = int(20 / input_frame_rate * 22050 / 256)
            mid_len = mel_len2 - edge_len * 2
            # F.interpolate rejects non-positive sizes, so only split when
            # every segment gets at least one output frame.
            split_edges = edge_len > 0 and mid_len > 0

        if split_edges:
            x2_head = self._resample(x2[:, :20], edge_len)
            x2_mid = self._resample(x2[:, 20:-20], mid_len)
            x2_tail = self._resample(x2[:, -20:], edge_len)
            x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
        else:
            x2 = self._resample(x2, mel_len2)

        if x1.shape[1] != 0:
            x1 = self._resample(x1, mel_len1)
            x = torch.concat([x1, x2], dim=2)
        else:
            x = x2

        out = self.model(x).transpose(1, 2).contiguous()
        return out, mel_len1 + mel_len2
|
|
|