| ''' |
| Codes are from: |
| https://github.com/jaxony/unet-pytorch/blob/master/model.py |
| ''' |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.autograd import Variable |
| from collections import OrderedDict |
| from torch.nn import init |
| import numpy as np |
|
|
|
|
def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """Build a 3x3 2D convolution; padding=1 preserves spatial size at stride 1."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=padding,
                     bias=bias, groups=groups)
|
|
|
|
def upconv2x2(in_channels, out_channels, mode='transpose'):
    """Return a module that doubles spatial resolution.

    'transpose' uses a learned 2x2 transposed convolution with stride 2;
    any other mode falls back to bilinear upsampling followed by a 1x1
    convolution to adjust the channel count.
    """
    if mode != 'transpose':
        return nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=2),
            conv1x1(in_channels, out_channels))
    return nn.ConvTranspose2d(in_channels, out_channels,
                              kernel_size=2, stride=2)
|
|
|
|
def conv1x1(in_channels, out_channels, groups=1):
    """Build a 1x1 2D convolution (pure channel mixing, no spatial context)."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=1,
                     stride=1, groups=groups)
|
|
class RollOut_Conv(nn.Module):
    """Cross-plane "roll-out" convolution for stacked tri-plane features.

    `row_features` is expected to be three feature planes (xz, xy, yz)
    stacked along the height axis, i.e. shape (B, C, 3*H_per, W).  Each
    plane is concatenated with pooled summaries of the other two planes,
    and a shared 3x3 convolution mixes the 3*C channels back down to
    `out_channels`.
    NOTE(review): the transposed pools only concatenate cleanly if each
    sub-plane is square (H_per == W) — confirm callers always pass square
    planes (the __main__ smoke test uses 128x128 planes).
    """
    def __init__(self,in_channels,out_channels):
        super(RollOut_Conv,self).__init__()

        self.in_channels=in_channels
        self.out_channels=out_channels
        # 3*in_channels: each plane is stacked with one row-pooled and one
        # column-pooled summary plane of the same channel width.
        self.conv = conv3x3(self.in_channels*3, self.out_channels)


    def forward(self,row_features):
        # Recover the three H_per-tall planes from the stacked map.
        H,W=row_features.shape[2],row_features.shape[3]
        H_per=H//3
        xz_feature,xy_feature,yz_feature=torch.split(row_features,dim=2,split_size_or_sections=H_per)
        # xz plane: augment with xy pooled over rows and yz pooled over columns.
        xy_row_pool=torch.mean(xy_feature,dim=2,keepdim=True).expand(-1,-1,H_per,-1)
        yz_col_pool=torch.mean(yz_feature,dim=3,keepdim=True).expand(-1,-1,-1,W)
        cat_xz_feat=torch.cat([xz_feature,xy_row_pool,yz_col_pool],dim=1)


        # xy plane: augment with the xz row-pool and the yz column-pool
        # taken after swapping yz's spatial axes so the shared axis aligns.
        xz_row_pool=torch.mean(xz_feature,dim=2,keepdim=True).expand(-1,-1,H_per,-1)
        zy_feature=yz_feature.transpose(2,3)
        zy_col_pool=torch.mean(zy_feature,dim=3,keepdim=True).expand(-1,-1,-1,W)
        cat_xy_feat=torch.cat([xy_feature,xz_row_pool,zy_col_pool],dim=1)


        # yz plane: augment with the transposed xy row-pool and the xz
        # column-pool.
        xz_col_pool=torch.mean(xz_feature,dim=3,keepdim=True).expand(-1,-1,-1,W)
        yx_feature=xy_feature.transpose(2,3)
        yx_row_pool=torch.mean(yx_feature,dim=2,keepdim=True).expand(-1,-1,H_per,-1)
        cat_yz_feat=torch.cat([yz_feature,yx_row_pool,xz_col_pool],dim=1)


        # Re-stack the augmented planes along height: (B, 3*C, 3*H_per, W).
        fuse_row_feat=torch.cat([cat_xz_feat,cat_xy_feat,cat_yz_feat],dim=2)


        # Shared 3x3 conv fuses the concatenated channels back to out_channels.
        x = self.conv(fuse_row_feat)


        return x
|
|
|
|
class DownConv(nn.Module):
    """Encoder block: two 3x3 convolutions with a RollOut_Conv in between,
    each followed by ReLU, and an optional 2x2 max-pool at the end.

    forward() returns (output, pre_pool_features); the pre-pool tensor is
    consumed by the decoder's skip connection.
    """

    def __init__(self, in_channels, out_channels, pooling=True):
        super(DownConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        self.conv1 = conv3x3(in_channels, out_channels)
        self.Rollout_conv = RollOut_Conv(out_channels, out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)

        if pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        # Three conv stages, ReLU after each.
        for stage in (self.conv1, self.Rollout_conv, self.conv2):
            x = F.relu(stage(x))
        before_pool = x
        out = self.pool(x) if self.pooling else x
        return out, before_pool
|
|
|
|
class UpConv(nn.Module):
    """Decoder block: upsample the deeper feature map, merge it with the
    encoder skip tensor ('concat' or 'add'), then apply two 3x3
    convolutions with a RollOut_Conv in between, ReLU after each.
    """

    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', up_mode='transpose'):
        super(UpConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(in_channels, out_channels, mode=up_mode)

        # Concatenation merging doubles the channel count entering conv1.
        first_in = 2 * out_channels if merge_mode == 'concat' else out_channels
        self.conv1 = conv3x3(first_in, out_channels)
        self.Rollout_conv = RollOut_Conv(out_channels, out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)

    def forward(self, from_down, from_up):
        """ Forward pass
        Arguments:
            from_down: tensor from the encoder pathway (skip connection)
            from_up: upconv'd tensor from the decoder pathway
        """
        from_up = self.upconv(from_up)
        if self.merge_mode == 'concat':
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down
        # Three conv stages, ReLU after each.
        for stage in (self.conv1, self.Rollout_conv, self.conv2):
            x = F.relu(stage(x))
        return x
|
|
|
|
class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597

    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).

    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by upmode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        This channel halving happens with the convolution in
        the tranpose convolution (specified by upmode='transpose')
    """

    def __init__(self, num_classes, in_channels=3, depth=5,
                 start_filts=64, up_mode='transpose',
                 merge_mode='concat', **kwargs):
        """
        Arguments:
            num_classes: int, number of output channels of the final
                1x1 convolution.
            in_channels: int, number of channels in the input tensor.
                Default is 3 for RGB images.
            depth: int, number of MaxPools in the U-Net.
            start_filts: int, number of convolutional filters for the
                first conv.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for bilinear
                upsampling followed by a 1x1 convolution (see upconv2x2).

        Raises:
            ValueError: if up_mode or merge_mode is invalid, or if the
                incompatible combination up_mode='upsample' with
                merge_mode='add' is requested.
        """
        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))

        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            # BUG FIX: this message previously interpolated `up_mode`,
            # reporting the wrong value for an invalid `merge_mode`, and
            # was missing a space ("formerging").
            raise ValueError("\"{}\" is not a valid mode for "
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(merge_mode))

        # 'add' merging requires matching channel counts, which 'upsample'
        # only achieves via its trailing 1x1 conv; the combination is
        # rejected by design.
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # Encoder pathway: channel width doubles at each stage; the final
        # (bottleneck) stage skips pooling to keep its resolution.
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.start_filts * (2 ** i)
            pooling = True if i < depth - 1 else False

            down_conv = DownConv(ins, outs, pooling=pooling)
            self.down_convs.append(down_conv)

        # Decoder pathway: channel width halves at each stage, mirroring
        # the encoder.
        for i in range(depth - 1):
            ins = outs
            outs = ins // 2
            up_conv = UpConv(ins, outs, up_mode=up_mode,
                             merge_mode=merge_mode)
            self.up_convs.append(up_conv)

        # Register the plain lists as ModuleLists so their parameters are
        # tracked by nn.Module.
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)
        self.conv_final = conv1x1(outs, self.num_classes)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        # Xavier init for every Conv2d. All convs in this file are created
        # with bias=True, so m.bias is never None here.
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            init.constant_(m.bias, 0)

    def reset_params(self):
        """Re-initialize the weights of every submodule."""
        for i, m in enumerate(self.modules()):
            self.weight_init(m)

    def forward(self, feature_plane):
        """Run the encoder-decoder over `feature_plane` of shape (B, C, H, W).

        Returns the raw per-class score map from the final 1x1 conv
        (no activation applied).
        """
        x = feature_plane
        encoder_outs = []
        # Encoder: keep each stage's pre-pool activation for the skips.
        for i, module in enumerate(self.down_convs):
            x, before_pool = module(x)
            encoder_outs.append(before_pool)
        # Decoder: -(i + 2) skips the bottleneck output (the last entry).
        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i + 2)]
            x = module(before_pool, x)

        x = self.conv_final(x)
        return x
|
|
|
|
if __name__ == "__main__":
    # Smoke test: tri-plane input -- three 128x128 planes stacked along
    # height (H = 3 * 128), batch of 10, 32 channels.
    net = UNet(32, depth=5, merge_mode='concat',
               in_channels=32, start_filts=32).cuda().float()
    planes = torch.randn((10, 32, 128 * 3, 128)).cuda().float()
    out = net(planes)
| |
| |