| import torch |
|
|
class CNN2D(torch.nn.Module):
    """Plain 2-D CNN classifier: conv groups, optional max-pooling, global average pool, linear head.

    Args:
        channels: sequence of groups; ``channels[i]`` lists the ``out_channels`` of each
            Conv2d stacked in group ``i``.
        conv_kernels: per-group ``kernel_size`` shared by every Conv2d in that group.
        conv_strides: per-group ``stride`` shared by every Conv2d in that group.
        conv_padding: per-group ``padding`` shared by every Conv2d in that group.
        pool_padding: ``pool_padding[i]`` is the padding of the 2x2 / stride-2 MaxPool2d
            applied after group ``i``; groups with ``i >= len(pool_padding)`` are not pooled.
        num_classes: output size of the final linear layer.
        in_channels: channel count of the input tensor (default 1, the original behavior).

    Raises:
        AssertionError: if the four per-group lists differ in length (plain ``assert``,
            so it is skipped under ``python -O``).
    """

    def __init__(self, channels, conv_kernels, conv_strides, conv_padding, pool_padding,
                 num_classes=15, in_channels=1):
        assert len(conv_kernels) == len(channels) == len(conv_strides) == len(conv_padding)
        super(CNN2D, self).__init__()

        self.conv_blocks = torch.nn.ModuleList()
        prev_channel = in_channels
        for group, kernel, stride, pad in zip(channels, conv_kernels, conv_strides, conv_padding):
            layers = []
            for conv_channel in group:
                layers.append(torch.nn.Conv2d(in_channels=prev_channel, out_channels=conv_channel,
                                              kernel_size=kernel, stride=stride, padding=pad))
                prev_channel = conv_channel
            # NOTE(review): convs within one group have no activation between them; BN + ReLU
            # are applied once per group. Kept as-is to preserve parameter/state_dict layout.
            layers.append(torch.nn.BatchNorm2d(prev_channel))
            layers.append(torch.nn.ReLU())
            self.conv_blocks.append(torch.nn.Sequential(*layers))

        self.pool_blocks = torch.nn.ModuleList(
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=pad) for pad in pool_padding
        )

        self.global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        # After global pooling the feature size equals the last group's channel count.
        self.linear = torch.nn.Linear(prev_channel, num_classes)

    def forward(self, inwav):
        """Run the network on a 4-D batch ``(N, in_channels, H, W)``.

        BatchNorm2d rejects non-4-D input, so a batch dimension is required.
        Returns the raw linear-layer output of shape ``(N, num_classes)``; no softmax is applied.
        """
        for i, block in enumerate(self.conv_blocks):
            inwav = block(inwav)
            if i < len(self.pool_blocks):
                inwav = self.pool_blocks[i](inwav)
        # flatten(1) drops only the two 1x1 spatial dims. The previous .squeeze() also
        # removed the batch dim when N == 1.
        out = torch.flatten(self.global_pool(inwav), 1)
        return self.linear(out)
| |
class ResBlock2D(torch.nn.Module):
    """Residual block: (conv-BN-ReLU-conv-BN) + parameter-free shortcut, then BN + ReLU.

    The shortcut has no projection layer. When channel counts differ, the input channels are
    tiled up (``channel`` a multiple of ``prev_channel``) or summed down in groups
    (``prev_channel`` a multiple of ``channel``) so the residual sum has ``channel`` channels.

    NOTE(review): ``conv_stride`` is applied by both convs, and the shortcut is not resampled.
    So the residual addition generally only works when ``conv_stride``/``conv_pad`` preserve
    the spatial size (e.g. stride 1, pad = kernel // 2 for odd kernels).

    Args:
        prev_channel: input channel count.
        channel: output channel count.
        conv_kernel: ``kernel_size`` of both convs.
        conv_stride: ``stride`` of both convs.
        conv_pad: ``padding`` of both convs.
    """

    def __init__(self, prev_channel, channel, conv_kernel, conv_stride, conv_pad):
        super(ResBlock2D, self).__init__()
        self.res = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=prev_channel, out_channels=channel, kernel_size=conv_kernel,
                            stride=conv_stride, padding=conv_pad),
            torch.nn.BatchNorm2d(channel),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=conv_kernel,
                            stride=conv_stride, padding=conv_pad),
            torch.nn.BatchNorm2d(channel),
        )
        self.bn = torch.nn.BatchNorm2d(channel)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return ``relu(bn(res(x) + shortcut(x)))`` with ``channel`` output channels.

        Raises:
            RuntimeError: if the larger channel count is not a multiple of the smaller one.
        """
        identity = x
        out = self.res(x)
        out_ch, in_ch = out.shape[1], identity.shape[1]
        if out_ch == in_ch:
            shortcut = identity
        elif out_ch > in_ch:
            if out_ch % in_ch != 0:
                raise RuntimeError("Dims in ResBlock needs to be divisible on the previous dims!!")
            # Tile: output channel j receives input channel j % in_ch.
            shortcut = identity.repeat(1, out_ch // in_ch, 1, 1)
        else:
            if in_ch % out_ch != 0:
                raise RuntimeError("Dims in ResBlock needs to be divisible on the previous dims!!")
            # Fold (the transpose of the tiling above): output channel j receives the sum of
            # input channels j, j + out_ch, j + 2*out_ch, ... The old code tiled `out` up to
            # in_ch channels instead, which self.bn (sized for `channel`) could never accept.
            n, _, h, w = identity.shape
            shortcut = identity.reshape(n, in_ch // out_ch, out_ch, h, w).sum(dim=1)
        # Out-of-place add: never mutate the caller's tensor (it may be a ReLU output that
        # autograd still needs for its backward pass).
        out = out + shortcut
        return self.relu(self.bn(out))
| |
class CNNRes2D(torch.nn.Module):
    """Residual 2-D CNN classifier: conv stem, groups of ResBlock2D, global average pool, linear head.

    Args:
        channels: sequence of groups. Only ``channels[0][0]`` is used (as the stem conv's
            ``out_channels``); any other entries of ``channels[0]`` are ignored. For ``i >= 1``,
            ``channels[i]`` lists the output channel count of each ResBlock2D in group ``i``.
        conv_kernels: per-group ``kernel_size`` (index 0 is the stem).
        conv_strides: per-group ``stride`` (index 0 is the stem).
        conv_padding: per-group ``padding`` (index 0 is the stem).
        pool_padding: ``pool_padding[i]`` is the padding of the 2x2 / stride-2 MaxPool2d after
            group ``i``. ``pool_padding[0]`` is required (the stem always pools); residual groups
            with ``i >= len(pool_padding)`` are not pooled.
        num_classes: output size of the final linear layer.
        in_channels: channel count of the input tensor (default 1, the original behavior).

    Raises:
        AssertionError: if the four per-group lists differ in length (plain ``assert``,
            skipped under ``python -O``).
    """

    def __init__(self, channels, conv_kernels, conv_strides, conv_padding, pool_padding,
                 num_classes=15, in_channels=1):
        assert len(conv_kernels) == len(channels) == len(conv_strides) == len(conv_padding)
        super(CNNRes2D, self).__init__()

        stem_channel = channels[0][0]
        self.conv_block = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=in_channels, out_channels=stem_channel,
                            kernel_size=conv_kernels[0], stride=conv_strides[0], padding=conv_padding[0]),
            torch.nn.BatchNorm2d(stem_channel),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=pool_padding[0]),
        )

        prev_channel = stem_channel
        self.res_blocks = torch.nn.ModuleList()
        for group, kernel, stride, pad in zip(channels[1:], conv_kernels[1:], conv_strides[1:],
                                              conv_padding[1:]):
            blocks = []
            for conv_channel in group:
                blocks.append(ResBlock2D(prev_channel, conv_channel, kernel, stride, pad))
                prev_channel = conv_channel
            self.res_blocks.append(torch.nn.Sequential(*blocks))

        # pool_padding[0] was consumed by the stem; the rest follow the residual groups.
        self.pool_blocks = torch.nn.ModuleList(
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=pad) for pad in pool_padding[1:]
        )

        self.global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = torch.nn.Linear(prev_channel, num_classes)

    def forward(self, inwav):
        """Run the network on a 4-D batch ``(N, in_channels, H, W)``.

        Returns the raw linear-layer output of shape ``(N, num_classes)``; no softmax is applied.
        """
        inwav = self.conv_block(inwav)
        for i, block in enumerate(self.res_blocks):
            inwav = block(inwav)
            if i < len(self.pool_blocks):
                inwav = self.pool_blocks[i](inwav)
        # flatten(1) keeps the batch dim even when N == 1; the previous .squeeze() dropped it.
        out = torch.flatten(self.global_pool(inwav), 1)
        return self.linear(out)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|