| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch |
|
|
|
|
| class Discriminator(nn.Module): |
| def __init__(self, in_channels=3): |
| super(Discriminator, self).__init__() |
|
|
| def discriminator_block(in_filters, out_filters, normalization=True): |
| """Returns downsampling layers of each discriminator block""" |
| layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] |
| if normalization: |
| layers.append(nn.InstanceNorm2d(out_filters)) |
| layers.append(nn.LeakyReLU(0.2, inplace=True)) |
| return layers |
|
|
| self.model = nn.Sequential( |
| *discriminator_block(in_channels * 3, 64, normalization=False), |
| *discriminator_block(64, 128), |
| *discriminator_block(128, 256), |
| *discriminator_block(256, 512), |
| nn.ZeroPad2d((1, 0, 1, 0)), |
| nn.Conv2d(512, 1, 4, padding=1, bias=False) |
| ) |
|
|
| def forward(self, img_out, img_l, img_ref ): |
| |
| img_input = torch.cat((img_out, img_l, img_ref), 1) |
| return self.model(img_input) |
|
|
|
|