| from typing import Callable, Sequence, Type, Union |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
# Factory for activation/normalization layers: either an nn.Module subclass
# or a zero-argument callable that returns a fresh nn.Module instance.
ModuleFactory = Union[Type[nn.Module], Callable[[], nn.Module]]
|
|
|
|
class FeedForwardModule(nn.Module):
    """Base class for modules whose forward pass is a single wrapped network.

    Subclasses must assign an ``nn.Module`` to ``self.net`` in their own
    ``__init__``; ``forward`` simply delegates to it.
    """

    def __init__(self) -> None:
        super().__init__()
        # Placeholder; calling forward() before a subclass assigns a real
        # module here will fail (None is not callable).
        self.net = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the subclass-provided network on ``x``."""
        return self.net(x)
|
|
|
|
| class Residual(nn.Module): |
|
|
| def __init__(self, module: nn.Module) -> None: |
| super().__init__() |
| self.module = module |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.module(x) + x |
|
|
|
|
class DilatedConvolutionalUnit(FeedForwardModule):
    """Dilated conv branch: activation -> dilated conv -> activation -> 1x1 conv.

    Channel count is ``hidden_dim`` throughout. The dilated convolution uses
    "same"-style padding (exact for odd kernel sizes), so the unit can be
    wrapped in a ``Residual`` connection.
    """

    def __init__(
            self,
            hidden_dim: int,
            dilation: int,
            kernel_size: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        # Padding that keeps the temporal length unchanged for odd kernels.
        same_padding = ((kernel_size - 1) * dilation) // 2
        # Build in the same order the layers are applied (also keeps any
        # RNG consumption by the factories in a fixed order).
        layers = []
        layers.append(activation())
        layers.append(
            normalization(
                nn.Conv1d(
                    in_channels=hidden_dim,
                    out_channels=hidden_dim,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=same_padding,
                )))
        layers.append(activation())
        layers.append(
            nn.Conv1d(in_channels=hidden_dim,
                      out_channels=hidden_dim,
                      kernel_size=1))
        self.net = nn.Sequential(*layers)
|
|
|
|
class UpsamplingUnit(FeedForwardModule):
    """Activation followed by a transposed conv that upsamples time by ``stride``.

    With kernel_size = 2*stride, padding = stride//2 + stride%2 and
    output_padding = stride%2, the transposed convolution multiplies the
    input length by exactly ``stride`` for both even and odd strides.
    """

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            stride: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        # Construct the activation first so factory side effects (e.g. RNG
        # use) happen in the same order as application.
        layers = [activation()]
        upsample = nn.ConvTranspose1d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=stride // 2 + stride % 2,
            # 1 for odd strides, 0 for even ones.
            output_padding=stride % 2,
        )
        layers.append(normalization(upsample))
        self.net = nn.Sequential(*layers)
|
|
|
|
class DownsamplingUnit(FeedForwardModule):
    """Activation followed by a strided conv that downsamples time by ``stride``.

    With kernel_size = 2*stride and padding = stride//2 + stride%2, the
    convolution divides the input length by exactly ``stride`` (for lengths
    divisible by ``stride``), mirroring ``UpsamplingUnit``.
    """

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            stride: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        # Construct the activation first so factory side effects (e.g. RNG
        # use) happen in the same order as application.
        layers = [activation()]
        downsample = nn.Conv1d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=stride // 2 + stride % 2,
        )
        layers.append(normalization(downsample))
        self.net = nn.Sequential(*layers)
|
|
|
|
class DilatedResidualEncoder(FeedForwardModule):
    """Encoder: pre-conv, then per-stage residual dilated units followed by a
    strided downsampling unit, then a post-conv.

    Channel width doubles at every stage, starting from ``capacity`` up to
    ``capacity * 2**len(ratios)``.

    Args:
        capacity: base channel count of the first stage.
        dilated_unit: factory ``(hidden_dim, dilation) -> nn.Module``.
        downsampling_unit: factory ``(in_dim, out_dim, stride) -> nn.Module``.
        ratios: downsampling stride per stage.
        dilations: either one flat list of dilations (reused for every stage)
            or one list of dilations per stage.
        pre_network_conv / post_network_conv: Conv1d factories with the
            remaining kwargs pre-bound (presumably via functools.partial,
            since only out_channels / in_channels are supplied here —
            TODO confirm against callers).
        normalization: wrapper applied to the pre-network convolution.
    """

    def __init__(
            self,
            capacity: int,
            dilated_unit: Type[DilatedConvolutionalUnit],
            downsampling_unit: Type[DownsamplingUnit],
            ratios: Sequence[int],
            dilations: Union[Sequence[int], Sequence[Sequence[int]]],
            pre_network_conv: Type[nn.Conv1d],
            post_network_conv: Type[nn.Conv1d],
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        # Channel progression: capacity, 2*capacity, ..., capacity * 2**len(ratios).
        channels = capacity * 2**np.arange(len(ratios) + 1)

        dilations_list = self.normalize_dilations(dilations, ratios)

        net = [normalization(pre_network_conv(out_channels=channels[0]))]

        # `stage_dilations` (not `dilations`) so the constructor argument is
        # not shadowed by the loop variable.
        for ratio, stage_dilations, input_dim, output_dim in zip(
                ratios, dilations_list, channels[:-1], channels[1:]):
            for dilation in stage_dilations:
                net.append(Residual(dilated_unit(input_dim, dilation)))
            net.append(downsampling_unit(input_dim, output_dim, ratio))

        # channels[-1] equals the last stage's output_dim, and stays
        # well-defined even if `ratios` is empty (the previous code relied on
        # the leaked loop variable `output_dim` and raised NameError then).
        net.append(post_network_conv(in_channels=channels[-1]))

        self.net = nn.Sequential(*net)

    @staticmethod
    def normalize_dilations(dilations: Union[Sequence[int],
                                             Sequence[Sequence[int]]],
                            ratios: Sequence[int]):
        """Broadcast a flat dilation list into one list per stage."""
        if isinstance(dilations[0], int):
            dilations = [dilations for _ in ratios]
        return dilations
|
|
|
|
class DilatedResidualDecoder(FeedForwardModule):
    """Decoder mirroring ``DilatedResidualEncoder``: pre-conv, then per-stage
    upsampling followed by residual dilated units, then a normalized post-conv.

    The channel progression is the encoder's reversed, so widths halve at
    every upsampling stage down to ``capacity``.

    Args:
        capacity: channel count of the final (narrowest) stage.
        dilated_unit: factory ``(hidden_dim, dilation) -> nn.Module``.
        upsampling_unit: factory ``(in_dim, out_dim, stride) -> nn.Module``.
        ratios: upsampling stride per stage.
        dilations: either one flat list of dilations (reused for every stage)
            or one list of dilations per stage.
        pre_network_conv / post_network_conv: Conv1d factories with the
            remaining kwargs pre-bound (presumably via functools.partial,
            since only out_channels / in_channels are supplied here —
            TODO confirm against callers).
        normalization: wrapper applied to the post-network convolution.
    """

    def __init__(
            self,
            capacity: int,
            dilated_unit: Type[DilatedConvolutionalUnit],
            upsampling_unit: Type[UpsamplingUnit],
            ratios: Sequence[int],
            dilations: Union[Sequence[int], Sequence[Sequence[int]]],
            pre_network_conv: Type[nn.Conv1d],
            post_network_conv: Type[nn.Conv1d],
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        # Encoder channel progression, reversed: widest first.
        channels = capacity * 2**np.arange(len(ratios) + 1)
        channels = channels[::-1]

        dilations_list = self.normalize_dilations(dilations, ratios)
        dilations_list = dilations_list[::-1]
        # NOTE(review): `channels` and `dilations_list` are reversed here but
        # `ratios` is not — confirm this asymmetry is intended for
        # non-uniform stride schedules.

        net = [pre_network_conv(out_channels=channels[0])]

        # `stage_dilations` (not `dilations`) so the constructor argument is
        # not shadowed by the loop variable.
        for ratio, stage_dilations, input_dim, output_dim in zip(
                ratios, dilations_list, channels[:-1], channels[1:]):
            net.append(upsampling_unit(input_dim, output_dim, ratio))
            for dilation in stage_dilations:
                net.append(Residual(dilated_unit(output_dim, dilation)))

        # channels[-1] equals the last stage's output_dim, and stays
        # well-defined even if `ratios` is empty (the previous code relied on
        # the leaked loop variable `output_dim` and raised NameError then).
        net.append(normalization(post_network_conv(in_channels=channels[-1])))

        self.net = nn.Sequential(*net)

    @staticmethod
    def normalize_dilations(dilations: Union[Sequence[int],
                                             Sequence[Sequence[int]]],
                            ratios: Sequence[int]):
        """Broadcast a flat dilation list into one list per stage."""
        if isinstance(dilations[0], int):
            dilations = [dilations for _ in ratios]
        return dilations