| import numpy as np |
| import torch.nn as nn |
|
|
| from models.model_blocks import ResBlock |
|
|
|
|
class Discriminator(nn.Module):
    """Convolutional discriminator built from down-sampling residual blocks.

    Layer layout, in ``self.sequence`` order:

    1. 3x3 stem convolution, ``input_nc -> ndf`` (stride 1, padding 1).
    2. ``n_layers`` ``ResBlock``s constructed with ``down_sample=True`` and
       ``norm=False``.
       - The first three blocks double the channel count
         (``ndf -> 2*ndf -> 4*ndf -> 8*ndf``).
       - Every later block keeps it at ``8 * ndf``.
    3. 4x4 convolution with no padding, then LeakyReLU(0.2).
    4. 1x1 convolution to 2 channels, then LeakyReLU(0.2).

    The channel count is tracked as the layers are built, so the head always
    matches the last block's output width. With the defaults (``ndf=64``,
    ``n_layers >= 3``) this yields exactly the original 512-channel layers.

    Args:
        input_nc: Number of channels of the input tensor.
        ndf: Channel width of the stem convolution.
        n_layers: Number of residual down-sampling blocks.
    """

    def __init__(self, input_nc, ndf=64, n_layers=6):
        super().__init__()
        # Upper bound on the channel width; 512 for the default ndf=64.
        max_nc = ndf * 8
        nc = ndf  # channel count produced by the most recent layer

        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=3, stride=1, padding=1)]
        for i in range(n_layers):
            # Double the width for the first three blocks, then hold it at max_nc.
            out_nc = nc * 2 if i < 3 else max_nc
            sequence.append(ResBlock(nc, out_nc, down_sample=True, norm=False))
            nc = out_nc

        sequence += [
            # kernel_size=4 with padding=0 needs spatial size >= 4 at this point.
            # The downsampling factor of ResBlock is defined elsewhere; confirm
            # the expected input resolution against it.
            nn.Conv2d(nc, nc, kernel_size=4, stride=1, padding=0),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nc, 2, kernel_size=1, stride=1, padding=0),
            # NOTE(review): the 2-channel output also goes through LeakyReLU,
            # so forward() does not return raw logits. Kept as-is; confirm
            # this is intended by the loss that consumes it.
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.sequence = nn.Sequential(*sequence)

    def forward(self, input):
        """Run the input through the discriminator stack.

        Args:
            input: Tensor accepted by ``nn.Conv2d`` with ``input_nc``
                channels, i.e. shaped ``(N, input_nc, H, W)``.

        Returns:
            Tensor with 2 channels. Its spatial size depends on ``H``, ``W``
            and on ResBlock's downsampling (defined elsewhere).
        """
        return self.sequence(input)
|
|