| """ |
| Contains an implementation of the U-Net architecture. |
| U-Net paper by Ronneberger et al. (2015): https://arxiv.org/abs/1505.04597 |
| |
| This implementation is based on the original U-Net architecture, with options for |
| normalization (batch normalization or layer normalization), bilinear upsampling, |
| and padding in the convolution layers. |
| |
| Author: Ole-Christian Galbo Engstrøm |
| E-mail: ocge@foss.dk |
| """ |
|
|
| from typing import Iterable |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
|
|
def conv3x3(in_channels: int, out_channels: int, bias: bool, pad: bool) -> nn.Conv2d:
    """
    Build a 2D convolution layer with a 3x3 kernel.

    in_channels : int
        Number of channels of the input feature map.
    out_channels : int
        Number of channels produced by the convolution.
    bias : bool
        Whether the convolution has a learnable additive bias.
    pad : bool
        If True, pad one pixel on every side so height and width are preserved.
        If False, use "valid" padding (none), which shrinks height and width by 2.
    """
    # One pixel of padding on each side exactly offsets the 2-pixel shrink of a 3x3 kernel.
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1 if pad else "valid",
        bias=bias,
    )
|
|
|
|
def conv_block(
    in_channels: int,
    out_channels: int,
    non_linearity: nn.Module,
    normalization: None | str,
    bias: bool,
    pad: bool,
) -> nn.Sequential:
    """
    Build two stacked 3x3 convolutions. Each one is followed by the non-linearity
    and then by the layer returned by get_norm_layer for `normalization`.

    The first convolution maps in_channels -> out_channels; the second keeps
    out_channels. The same `non_linearity` module instance is placed after both
    convolutions.

    In the U-Net architecture illustration in the U-Net paper,
    this corresponds to two blue arrows.

    Returns an nn.Sequential of six modules:
    conv, non-linearity, norm, conv, non-linearity, norm.
    """
    layers: list[nn.Module] = []
    for conv_in in (in_channels, out_channels):
        # Tuple elements are evaluated left to right, so each conv is created before its norm layer.
        layers.extend(
            (
                conv3x3(
                    in_channels=conv_in, out_channels=out_channels, bias=bias, pad=pad
                ),
                non_linearity,
                get_norm_layer(normalization=normalization, in_channels=out_channels),
            )
        )
    return nn.Sequential(*layers)
|
|
|
|
def batch_norm(in_channels: int) -> nn.BatchNorm2d:
    """
    Build a Batch Normalization layer over the channel dimension.
    Batch Normalization paper by Ioffe and Szegedy (2015): https://arxiv.org/abs/1502.03167

    in_channels : int
        Number of channels C of the (N, C, H, W) input to normalize.

    Returns an nn.BatchNorm2d layer. The return annotation now reflects this;
    it previously claimed nn.Sequential.
    """
    # momentum=0.01 instead of PyTorch's default 0.1: each new batch statistic is
    # blended into the running mean/variance with weight 0.01, so they change slowly.
    return nn.BatchNorm2d(in_channels, momentum=0.01)
|
|
|
|
class Permute(nn.Module):
    """
    Permute the dimensions of a tensor.

    Wraps torch.Tensor.permute in a module so it can be placed inside nn.Sequential.
    """

    def __init__(self, dims: Iterable[int]):
        """
        dims : Iterable[int]
            The desired ordering of dimensions, as accepted by torch.Tensor.permute.
            It is materialized into a tuple, so one-shot iterables (e.g. generators)
            remain usable on every forward call and in __repr__.
        """
        super().__init__()
        self.dims = tuple(dims)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.permute(self.dims)

    def __repr__(self) -> str:
        return f'{type(self).__name__}({", ".join(map(str, self.dims))})'
|
|
|
|
def layer_norm(in_channels: int) -> nn.Sequential:
    """
    Build a Layer Normalization block that normalizes over the channel dimension.
    Layer Normalization paper by Ba et al. (2016): https://arxiv.org/abs/1607.06450

    nn.LayerNorm normalizes over the trailing dimension, so the channel axis is
    moved last before normalizing and moved back afterwards:
    (N, C, H, W) -> (N, H, W, C) -> LayerNorm(C) -> (N, C, H, W).
    """
    to_channels_last = Permute((0, 2, 3, 1))
    norm = nn.LayerNorm(in_channels)
    to_channels_first = Permute((0, 3, 1, 2))
    return nn.Sequential(to_channels_last, norm, to_channels_first)
|
|
|
|
| def get_norm_layer(normalization: None | str, in_channels: int) -> nn.Module: |
| """ |
| Get the normalization layer based on the specified type. |
| Either 'bn' for batch normalization, 'ln' for layer normalization, |
| or None for no normalization layer. |
| """ |
| if normalization == "bn": |
| return batch_norm(in_channels) |
| if normalization == "ln": |
| return layer_norm(in_channels) |
| return nn.Identity() |
|
|
|
|
def copy_and_crop(large: torch.Tensor, small: torch.Tensor) -> torch.Tensor:
    """
    Implementation of a copy-and-crop block in the U-Net architecture.

    Center-crop the last two (height, width) dimensions of `large` to the spatial
    size of `small`. Then concatenate the crop and `small` along dimension -3
    (the channel dimension of (..., C, H, W) tensors), with the crop first.
    When a size difference is odd, the extra row/column is dropped from the
    bottom/right of `large`.

    In the U-Net architecture illustration in the U-Net paper,
    this corresponds to a gray arrow.
    """
    big_h, big_w = large.shape[-2:]
    target_h, target_w = small.shape[-2:]
    # Floor division puts the smaller half of the excess at the top/left.
    top = (big_h - target_h) // 2
    left = (big_w - target_w) // 2
    crop = large[..., top : top + target_h, left : left + target_w]
    return torch.cat((crop, small), dim=-3)
|
|
|
|
class ContractionBlock(nn.Module):
    """
    Implementation of a contraction block in the U-Net architecture.

    The input is downsampled by 2x2 max pooling with stride 2, then passed through
    a conv_block (two 3x3 convolutions, each with non-linearity and normalization).

    In the U-Net architecture illustration in the U-Net paper, this corresponds to
    one red arrow followed by the subsequent two blue arrows.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        non_linearity: nn.Module,
        nonormalization: None | str,
        bias: bool,
        pad: bool,
    ):
        """
        nonormalization : None | str
            Normalization type forwarded to conv_block as `normalization`.
            NOTE: the parameter name looks like a misspelling of `normalization`;
            it is kept because UNet passes it by keyword.
        """
        super().__init__()
        # Registration order (max_pool, then conv_block) matches the original module layout.
        self.max_pool = self._max_pool()
        self.conv_block = conv_block(
            in_channels=in_channels,
            out_channels=out_channels,
            non_linearity=non_linearity,
            normalization=nonormalization,
            bias=bias,
            pad=pad,
        )

    def _max_pool(self) -> nn.MaxPool2d:
        """2x2 max pooling with stride 2, halving height and width."""
        return nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv_block(self.max_pool(x))
|
|
|
|
class Upsample(nn.Module):
    """
    Implementation of an upsampling block in the U-Net architecture.

    Doubles the spatial size of the input and maps in_channels to out_channels.
    It uses either a 2x2 transposed convolution with stride 2 (bilinear=False) or
    bilinear upsampling followed by a 1x1 convolution (bilinear=True). In both
    cases the result then passes through the non-linearity and the layer
    returned by get_norm_layer.

    In the U-Net architecture illustration in the U-Net paper, this corresponds to
    one green arrow.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        non_linearity: nn.Module,
        normalization: None | str,
        bias: bool,
        bilinear: bool,
    ):
        super().__init__()
        # The _up_* builders read these attributes, so they must be set before self.up is built.
        # When non_linearity is an nn.Module, assigning it also registers it as a child module.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.non_linearity = non_linearity
        self.normalization = normalization
        self.bias = bias
        self.bilinear = bilinear
        self.up = self._upsample(in_channels, out_channels)

    def _upsample(self, in_channels: int, out_channels: int) -> nn.Sequential:
        """Dispatch to the bilinear or transposed-convolution builder."""
        builder = self._up_bilinear if self.bilinear else self._up_trans_conv2x2
        return builder(in_channels, out_channels)

    def _up_trans_conv2x2(self, in_channels: int, out_channels: int) -> nn.Sequential:
        """Learned upsampling: 2x2 transposed convolution with stride 2."""
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=2, stride=2, bias=self.bias
            ),
            self.non_linearity,
            get_norm_layer(self.normalization, out_channels),
        )

    def _up_bilinear(self, in_channels: int, out_channels: int) -> nn.Sequential:
        """Fixed bilinear 2x upsampling followed by a 1x1 channel-mapping convolution."""
        # NOTE(review): unlike the transposed-convolution path, this 1x1 conv does not
        # receive self.bias and therefore always has a bias term; confirm this is intended.
        return nn.Sequential(
            nn.Upsample(mode="bilinear", scale_factor=2, align_corners=True),
            nn.Conv2d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=1
            ),
            self.non_linearity,
            get_norm_layer(self.normalization, out_channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(x)
|
|
|
|
class ExpansionBlock(nn.Module):
    """
    Implementation of an expansion block in the U-Net architecture.
    This block consists of an upsampling block followed by a copy-and-crop block and
    a convolution block.

    In the U-Net architecture illustration in the U-Net paper, this corresponds to
    one green arrow followed by a gray arrow and then two blue arrows.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        non_linearity: nn.Module,
        normalization: None | str,
        bias: bool,
        bilinear: bool,
        pad: bool,
    ):
        """
        in_channels : int
            Channels of the deeper feature map that gets upsampled. The conv block
            also expects in_channels after concatenation. This assumes the skip
            connection contributes in_channels - out_channels channels; UNet uses
            in_channels == 2 * out_channels.
        out_channels : int
            Channels produced by the upsampling and by the conv block.
        pad : bool
            If True, the upsampled map is zero-padded to the skip connection's
            spatial size before concatenation, and the convolutions are padded.
        """
        super().__init__()
        self.pad = pad
        self.upsample = Upsample(
            in_channels=in_channels,
            out_channels=out_channels,
            non_linearity=non_linearity,
            normalization=normalization,
            bias=bias,
            bilinear=bilinear,
        )
        self.conv_block = conv_block(
            in_channels=in_channels,
            out_channels=out_channels,
            non_linearity=non_linearity,
            normalization=normalization,
            bias=bias,
            pad=pad,
        )

    def forward(self, large: torch.Tensor, small: torch.Tensor) -> torch.Tensor:
        """
        large : torch.Tensor
            Skip-connection feature map from the contracting path.
        small : torch.Tensor
            Deeper feature map to upsample.
        """
        x = self.upsample(small)
        if self.pad:
            # Match the skip connection's height/width (they can differ e.g. when a
            # pooled map had an odd size). F.pad takes (left, right, top, bottom)
            # for the last two dims; any odd remainder goes to the right/bottom.
            diff_h = large.shape[-2] - x.shape[-2]
            diff_w = large.shape[-1] - x.shape[-1]
            pad_left = diff_w // 2
            pad_right = diff_w - pad_left
            pad_top = diff_h // 2
            pad_bottom = diff_h - pad_top
            x = F.pad(
                x,
                (pad_left, pad_right, pad_top, pad_bottom),
                mode="constant",
                value=0.0,
            )
        x = copy_and_crop(large, x)
        x = self.conv_block(x)
        return x
|
|
|
|
class UNet(nn.Module):
    """
    U-Net following Ronneberger et al. (2015).

    Four contraction blocks and four expansion blocks with 64, 128, 256, 512 and
    1024 feature channels, followed by a 1x1 convolution to out_channels.

    in_channels : int\\
        Number of input channels.

    out_channels : int\\
        Number of output channels

    pad : bool, default=True\\
        If True use padding in the convolution layers, preserving the input size.
        If False, the output size will be reduced compared to the input size.

    bilinear : bool, default=True\\
        If True use bilinear upsampling.
        If False use transposed convolution.

    normalization: None | str, default=None\\
        If None use no normalization.
        If 'bn' use batch normalization.
        If 'ln' use layer normalization.

    Raises ValueError if normalization is not one of None, 'bn' or 'ln'.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        pad: bool = True,
        bilinear: bool = True,
        normalization: None | str = None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pad = pad
        self.bilinear = bilinear
        self.normalization = normalization
        if self.normalization not in (None, "bn", "ln"):
            # The adjacent literals previously lacked a separating space ("normalization,or").
            raise ValueError(
                "Normalization must be None, 'bn' for batch normalization, "
                f"or 'ln' for layer normalization, got {normalization!r}"
            )

        # Convolutions get a bias term only when no normalization layer is used.
        self.bias_conv = normalization is None
        # A single ReLU instance is shared by every block.
        self.non_linearity = nn.ReLU(inplace=True)
        # Feature channels per level: [64, 128, 256, 512, 1024].
        self.intermediate_channels = [64 * 2**i for i in range(5)]
        self.first_convs = conv_block(
            in_channels=in_channels,
            out_channels=self.intermediate_channels[0],
            non_linearity=self.non_linearity,
            normalization=self.normalization,
            bias=self.bias_conv,
            pad=self.pad,
        )
        self.last_conv = nn.Conv2d(
            self.intermediate_channels[0], out_channels, kernel_size=1
        )

        # Submodules are registered in a fixed order. The initialization loop below
        # draws random weights in that order, so reordering would change the
        # initial weights under a fixed seed.
        self.contraction1 = self._get_contraction_block(
            in_channels=self.intermediate_channels[0],
            out_channels=self.intermediate_channels[1],
        )
        self.contraction2 = self._get_contraction_block(
            in_channels=self.intermediate_channels[1],
            out_channels=self.intermediate_channels[2],
        )
        self.contraction3 = self._get_contraction_block(
            in_channels=self.intermediate_channels[2],
            out_channels=self.intermediate_channels[3],
        )
        self.contraction4 = self._get_contraction_block(
            in_channels=self.intermediate_channels[3],
            out_channels=self.intermediate_channels[4],
        )
        self.expansion4 = self._get_expansion_block(
            in_channels=self.intermediate_channels[4],
            out_channels=self.intermediate_channels[3],
        )
        self.expansion3 = self._get_expansion_block(
            in_channels=self.intermediate_channels[3],
            out_channels=self.intermediate_channels[2],
        )
        self.expansion2 = self._get_expansion_block(
            in_channels=self.intermediate_channels[2],
            out_channels=self.intermediate_channels[1],
        )
        self.expansion1 = self._get_expansion_block(
            in_channels=self.intermediate_channels[1],
            out_channels=self.intermediate_channels[0],
        )

        # Kaiming-normal init for all (transposed) convolution weights with zero biases;
        # normalization layers start as the identity transform (weight 1, bias 0).
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _get_contraction_block(
        self, in_channels: int, out_channels: int
    ) -> ContractionBlock:
        """Build a contraction block sharing this network's configuration."""
        return ContractionBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            non_linearity=self.non_linearity,
            nonormalization=self.normalization,
            bias=self.bias_conv,
            pad=self.pad,
        )

    def _get_expansion_block(
        self, in_channels: int, out_channels: int
    ) -> ExpansionBlock:
        """Build an expansion block sharing this network's configuration."""
        return ExpansionBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            non_linearity=self.non_linearity,
            normalization=self.normalization,
            bias=self.bias_conv,
            bilinear=self.bilinear,
            pad=self.pad,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x : torch.Tensor
            Input batch, presumably (N, in_channels, H, W) — the 'bn' and 'ln'
            normalization layers require 4D input.
        """
        # Contracting path; x1..x4 are kept as skip connections.
        x1 = self.first_convs(x)
        x2 = self.contraction1(x1)
        x3 = self.contraction2(x2)
        x4 = self.contraction3(x3)
        x5 = self.contraction4(x4)
        # Expanding path: each block upsamples the deeper map and fuses the matching skip.
        x = self.expansion4(x4, x5)
        x = self.expansion3(x3, x)
        x = self.expansion2(x2, x)
        x = self.expansion1(x1, x)
        x = self.last_conv(x)
        return x
|
|