| |
| |
| |
| |
| |
|
|
|
|
|
|
| |
|
|
|
|
|
|
| import torch |
| import torch.nn.parallel |
| import numpy as np |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class Downsample(nn.Module):
    """Anti-aliased 2-D downsampling: binomial low-pass filter followed by striding.

    Every input channel is convolved independently (``groups=C``) with the
    normalized outer product of a binomial row (a row of Pascal's triangle,
    e.g. ``[1, 2, 1]`` for ``filt_size=3``). The input is padded first so the
    kernel is centred.

    Args:
        pad_type: Padding mode resolved by ``get_pad_layer``
            ('refl'/'reflect', 'repl'/'replicate' or 'zero').
        filt_size: Number of taps ``k >= 1`` of the 1-D binomial kernel.
            Sizes 1..7 give the same kernels as the former hard-coded table;
            larger sizes are also accepted.
        stride: Subsampling factor applied to both spatial axes.
        channels: Number of input channels ``C``. Required, because the
            kernel buffer is replicated once per channel.
        pad_off: Extra padding added on every side.

    Raises:
        ValueError: if ``filt_size`` is not a positive integer or
            ``channels`` is None.
    """

    def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
        super().__init__()
        taps = int(filt_size)
        if taps != filt_size or taps < 1:
            raise ValueError('filt_size must be a positive integer, got %r' % (filt_size,))
        if channels is None:
            raise ValueError('channels must be specified for Downsample')

        self.filt_size = filt_size
        self.pad_off = pad_off
        # Padding order is (left, right, top, bottom). For an even kernel the
        # extra pixel goes to the right/bottom so the kernel stays centred.
        before = (taps - 1) // 2
        after = taps // 2
        self.pad_sizes = [before + pad_off, after + pad_off, before + pad_off, after + pad_off]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)  # not used within this class; kept as public attribute
        self.channels = channels

        # Row (taps - 1) of Pascal's triangle via repeated convolution with [1, 1].
        a = np.array([1.])
        for _ in range(taps - 1):
            a = np.convolve(a, [1., 1.])

        filt = torch.tensor(a[:, None] * a[None, :], dtype=torch.get_default_dtype())
        filt = filt / torch.sum(filt)
        # Same (k, k) kernel for every channel -> depthwise weight of shape (C, 1, k, k).
        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, inp):
        """Low-pass filter ``inp`` and subsample it by ``stride``.

        Assumes ``inp`` is (N, C, H, W) with C == ``channels``; TODO confirm
        against callers.
        """
        if self.filt_size == 1:
            # The normalized 1-tap kernel is the identity, so skip the convolution.
            if self.pad_off == 0:
                return inp[:, :, ::self.stride, ::self.stride]
            return self.pad(inp)[:, :, ::self.stride, ::self.stride]
        return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
|
|
def get_pad_layer(pad_type):
    """Return the ``torch.nn`` 2-D padding-layer class named by ``pad_type``.

    Args:
        pad_type: 'refl' or 'reflect' -> ``nn.ReflectionPad2d``;
            'repl' or 'replicate' -> ``nn.ReplicationPad2d``;
            'zero' -> ``nn.ZeroPad2d``.

    Returns:
        The padding-layer class (not an instance).

    Raises:
        ValueError: if ``pad_type`` is not recognized. Previously this only
            printed a message and then failed with ``UnboundLocalError``.
    """
    if pad_type in ['refl', 'reflect']:
        return nn.ReflectionPad2d
    if pad_type in ['repl', 'replicate']:
        return nn.ReplicationPad2d
    if pad_type == 'zero':
        return nn.ZeroPad2d
    raise ValueError('Pad type [%s] not recognized' % pad_type)
|
|
|
|
class Downsample1D(nn.Module):
    """Anti-aliased 1-D downsampling: binomial low-pass filter followed by striding.

    Every input channel is convolved independently (``groups=C``) with a
    normalized binomial kernel (a row of Pascal's triangle, e.g. ``[1, 2, 1]``
    for ``filt_size=3``). The input is padded first so the kernel is centred.

    Args:
        pad_type: Padding mode resolved by ``get_pad_layer_1d``
            ('refl'/'reflect', 'repl'/'replicate' or 'zero').
        filt_size: Number of taps ``k >= 1`` of the binomial kernel.
            Sizes 1..7 give the same kernels as the former hard-coded table;
            larger sizes are also accepted.
        stride: Subsampling factor along the last axis.
        channels: Number of input channels ``C``. Required, because the
            kernel buffer is replicated once per channel.
        pad_off: Extra padding added on both sides.

    Raises:
        ValueError: if ``filt_size`` is not a positive integer or
            ``channels`` is None.
    """

    def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
        super().__init__()
        taps = int(filt_size)
        if taps != filt_size or taps < 1:
            raise ValueError('filt_size must be a positive integer, got %r' % (filt_size,))
        if channels is None:
            raise ValueError('channels must be specified for Downsample1D')

        self.filt_size = filt_size
        self.pad_off = pad_off
        # Padding order is (left, right). For an even kernel the extra sample
        # goes to the right so the kernel stays centred.
        self.pad_sizes = [(taps - 1) // 2 + pad_off, taps // 2 + pad_off]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)  # not used within this class; kept as public attribute
        self.channels = channels

        # Row (taps - 1) of Pascal's triangle via repeated convolution with [1, 1].
        a = np.array([1.])
        for _ in range(taps - 1):
            a = np.convolve(a, [1., 1.])

        filt = torch.tensor(a, dtype=torch.get_default_dtype())
        filt = filt / torch.sum(filt)
        # Same k-tap kernel for every channel -> depthwise weight of shape (C, 1, k).
        self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))

        self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)

    def forward(self, inp):
        """Low-pass filter ``inp`` and subsample it by ``stride``.

        Assumes ``inp`` is (N, C, L) with C == ``channels``; TODO confirm
        against callers.
        """
        if self.filt_size == 1:
            # The normalized 1-tap kernel is the identity, so skip the convolution.
            if self.pad_off == 0:
                return inp[:, :, ::self.stride]
            return self.pad(inp)[:, :, ::self.stride]
        return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
|
|
|
|
def get_pad_layer_1d(pad_type):
    """Return a constructor for the 1-D padding layer named by ``pad_type``.

    Args:
        pad_type: 'refl' or 'reflect' -> ``nn.ReflectionPad1d``;
            'repl' or 'replicate' -> ``nn.ReplicationPad1d``;
            'zero' -> ``nn.ZeroPad1d``, or an equivalent zero-valued
            ``nn.ConstantPad1d`` constructor when ``ZeroPad1d`` is unavailable.

    Returns:
        A callable that takes the padding spec and builds the layer.

    Raises:
        ValueError: if ``pad_type`` is not recognized. Previously this only
            printed a message and then failed with ``UnboundLocalError``.
    """
    if pad_type in ['refl', 'reflect']:
        return nn.ReflectionPad1d
    if pad_type in ['repl', 'replicate']:
        return nn.ReplicationPad1d
    if pad_type == 'zero':
        # nn.ZeroPad1d is missing from older PyTorch releases. Fall back to a
        # constant-0 pad, which is equivalent, rather than hit AttributeError.
        zero_pad = getattr(nn, 'ZeroPad1d', None)
        if zero_pad is not None:
            return zero_pad
        import functools
        return functools.partial(nn.ConstantPad1d, value=0.)
    raise ValueError('Pad type [%s] not recognized' % pad_type)