| """ |
| Definition of the FFDNet model and its custom layers |
| |
| Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr> |
| |
| This program is free software: you can use, modify and/or |
| redistribute it under the terms of the GNU General Public |
| License as published by the Free Software Foundation, either |
| version 3 of the License, or (at your option) any later |
| version. You should have received a copy of this license along |
| this program. If not, see <http://www.gnu.org/licenses/>. |
| """ |
| import torch.nn as nn |
| from torch.autograd import Variable |
| import denoising.functions as functions |
| |
class UpSampleFeatures(nn.Module):
    r"""Final stage of the FFDNet pipeline.

    This module holds no learnable parameters. ``forward`` simply hands its
    input to ``functions.upsamplefeatures`` (from ``denoising.functions``) and
    returns whatever that helper produces. Presumably it undoes the spatial
    downsampling applied before the intermediate DnCNN; confirm against
    ``denoising.functions``.
    """

    def __init__(self):
        super(UpSampleFeatures, self).__init__()

    def forward(self, x):
        upsampled = functions.upsamplefeatures(x)
        return upsampled
|
|
class IntermediateDnCNN(nn.Module):
    r"""Implements the middle part of the FFDNet architecture, which is
    basically a DnCNN net.

    The layer stack is:
      * Conv2d(input_features -> middle_features) followed by ReLU;
      * ``num_conv_layers - 2`` blocks of Conv2d -> BatchNorm2d -> ReLU,
        each with middle_features channels;
      * a final Conv2d(middle_features -> output_features).

    Every convolution uses a 3x3 kernel with padding 1 and stride 1, so the
    spatial size is preserved. No convolution has a bias term.

    Args:
        input_features: number of input channels; must be 5 or 15.
            5 maps to 4 output channels and 15 maps to 12.
        middle_features: number of channels in the hidden layers.
        num_conv_layers: total number of Conv2d layers; must be >= 2.

    Raises:
        ValueError: if ``input_features`` is neither 5 nor 15, or if
            ``num_conv_layers`` is smaller than 2.
    """
    def __init__(self, input_features, middle_features, num_conv_layers):
        super(IntermediateDnCNN, self).__init__()
        self.kernel_size = 3
        self.padding = 1
        self.input_features = input_features
        self.num_conv_layers = num_conv_layers
        self.middle_features = middle_features
        if self.input_features == 5:
            self.output_features = 4
        elif self.input_features == 15:
            self.output_features = 12
        else:
            # ValueError subclasses Exception, so existing `except Exception`
            # handlers keep working.
            raise ValueError('Invalid number of input features')
        if self.num_conv_layers < 2:
            # The first and last convolutions are always built. Fewer than 2
            # would otherwise be silently rounded up to 2 layers.
            raise ValueError(
                'num_conv_layers must be at least 2, got {!r}'.format(
                    num_conv_layers))

        # Module order matters: Sequential indices become state_dict keys.
        layers = [
            nn.Conv2d(in_channels=self.input_features,
                      out_channels=self.middle_features,
                      kernel_size=self.kernel_size,
                      padding=self.padding,
                      bias=False),
            nn.ReLU(inplace=True),
        ]
        for _ in range(self.num_conv_layers - 2):
            layers.append(nn.Conv2d(in_channels=self.middle_features,
                                    out_channels=self.middle_features,
                                    kernel_size=self.kernel_size,
                                    padding=self.padding,
                                    bias=False))
            layers.append(nn.BatchNorm2d(self.middle_features))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=self.middle_features,
                                out_channels=self.output_features,
                                kernel_size=self.kernel_size,
                                padding=self.padding,
                                bias=False))
        # NOTE: "itermediate" is misspelled, but the attribute name is part of
        # the state_dict keys. Renaming it would break loading of existing
        # checkpoints, so it is intentionally kept.
        self.itermediate_dncnn = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolutional stack to ``x``.

        The channel count goes from ``input_features`` to
        ``output_features``; the spatial size is unchanged.
        """
        return self.itermediate_dncnn(x)
|
|
class FFDNet(nn.Module):
    r"""Implements the FFDNet architecture.

    The pipeline is:
      1. combine the input and the noise map with
         ``functions.concatenate_input_noise_map``;
      2. run the result through an ``IntermediateDnCNN``;
      3. pass that output through ``UpSampleFeatures``.

    Hyper-parameters are fixed per channel count:
      * 1 channel: 64 feature maps, 15 conv layers, 5 downsampled channels.
      * 3 channels: 96 feature maps, 12 conv layers, 15 downsampled channels.

    Args:
        num_input_channels: number of channels of the input image; 1 or 3.

    Raises:
        ValueError: if ``num_input_channels`` is neither 1 nor 3.
    """
    def __init__(self, num_input_channels):
        super(FFDNet, self).__init__()
        self.num_input_channels = num_input_channels
        if self.num_input_channels == 1:
            self.num_feature_maps = 64
            self.num_conv_layers = 15
            self.downsampled_channels = 5
            # Mirrors the output channel count IntermediateDnCNN picks for
            # 5 input features.
            self.output_features = 4
        elif self.num_input_channels == 3:
            self.num_feature_maps = 96
            self.num_conv_layers = 12
            self.downsampled_channels = 15
            # Mirrors the output channel count IntermediateDnCNN picks for
            # 15 input features.
            self.output_features = 12
        else:
            # ValueError subclasses Exception, so existing `except Exception`
            # handlers keep working.
            raise ValueError(
                'Invalid number of input channels: {!r} (expected 1 or 3)'.format(
                    num_input_channels))

        self.intermediate_dncnn = IntermediateDnCNN(
            input_features=self.downsampled_channels,
            middle_features=self.num_feature_maps,
            num_conv_layers=self.num_conv_layers)
        self.upsamplefeatures = UpSampleFeatures()

    def forward(self, x, noise_sigma):
        """Run FFDNet on ``x`` conditioned on ``noise_sigma``.

        Args:
            x: input image batch. Its exact layout is defined by
                ``functions.concatenate_input_noise_map``; presumably
                (N, C, H, W), but confirm against that helper.
            noise_sigma: noise level(s) used to build the noise map. Its
                shape is defined by the same helper and is not shown here.

        Returns:
            The output of ``UpSampleFeatures``. It is named ``pred_noise``,
            so presumably it is the estimated noise residual; confirm with
            callers.
        """
        # ``.data`` hands raw tensors to the helper. Autograd does not track
        # this step, so no gradient flows back into x or noise_sigma from here.
        concat_noise_x = functions.concatenate_input_noise_map(
            x.data, noise_sigma.data)
        # Legacy (PyTorch < 0.4) wrap so the tensor can enter nn modules.
        # On newer versions Variable(t) just returns a tensor.
        concat_noise_x = Variable(concat_noise_x)
        h_dncnn = self.intermediate_dncnn(concat_noise_x)
        pred_noise = self.upsamplefeatures(h_dncnn)
        return pred_noise
|
|