| import torch |
| import torch.nn as nn |
| import torch.nn.parallel |
| import torch.optim as optim |
| from torch.autograd import Variable |
|
|
|
|
def weights_init(m):
    """Initialise one module's parameters; intended for ``net.apply(weights_init)``.

    Modules are matched by class name, so ``'Conv'`` covers both ``Conv2d`` and
    ``ConvTranspose2d``, and ``'BatchNorm'`` covers ``BatchNorm1d/2d/3d``:

    * Conv layers: ``weight`` ~ N(0, 0.02). A conv ``bias``, if present, is
      left at PyTorch's default initialisation.
    * BatchNorm layers: ``weight`` ~ N(1, 0.02) and ``bias`` = 0.

    Any other module is left unchanged. Parameters that do not exist (e.g.
    ``BatchNorm2d(..., affine=False)`` has ``weight is None``) are skipped
    rather than raising ``AttributeError``.

    Args:
        m: the ``nn.Module`` currently being visited.
    """
    classname = m.__class__.__name__
    # getattr with a default so modules lacking these parameters (or having
    # them set to None) are handled gracefully.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.zeros_(bias)
|
|
|
|
class G(nn.Module):
    """Generator network for 128x128 RGB images.

    Maps an image batch to an image batch of the same shape through an
    encoder/decoder stack stored in ``self.main``:

    * Encoder: six ``Conv2d(kernel=4, stride=2, padding=1)`` layers, each
      halving the spatial size, take ``(N, 3, 128, 128)`` to
      ``(N, 512, 2, 2)``; ``MaxPool2d(2)`` then reduces that to
      ``(N, 512, 1, 1)``.
    * Decoder: ``ConvTranspose2d(512, 256, 4, 1, 0)`` expands 1x1 to 4x4,
      then five stride-2 transposed convolutions double the size each time
      up to ``(N, 3, 128, 128)``. A final ``Tanh`` bounds outputs to [-1, 1].

    The layer arithmetic is built around 128x128 inputs: inputs smaller than
    128 shrink below the 2x2 pooling window and fail, and other sizes give an
    output whose spatial size differs from the input.
    """

    def __init__(self):
        super().__init__()

        # Layer order must stay fixed: nn.Sequential indices form the
        # state_dict keys (main.0.weight, ...), so reordering would break
        # loading existing checkpoints.
        self.main = nn.Sequential(
            # ---- encoder ----
            nn.Conv2d(3, 16, 4, 2, 1),      # 128 -> 64
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 4, 2, 1),     # 64 -> 32
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),     # 32 -> 16
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 4, 2, 1),    # 16 -> 8
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 4, 2, 1),   # 8 -> 4
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 512, 4, 2, 1),   # 4 -> 2 (no BatchNorm/ReLU here)
            nn.MaxPool2d((2, 2)),           # 2 -> 1: (N, 512, 1, 1) bottleneck

            # ---- decoder ----
            nn.ConvTranspose2d(512, 256, 4, 1, 0, bias=False),  # 1 -> 4
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),  # 4 -> 8
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),   # 8 -> 16
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),    # 16 -> 32
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False),    # 32 -> 64
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False),     # 64 -> 128
            nn.Tanh(),
        )

    def forward(self, input):
        """Run the encoder/decoder on ``input`` (expected ``(N, 3, 128, 128)``).

        Returns a tensor of shape ``(N, 3, 128, 128)`` with values in [-1, 1].
        """
        output = self.main(input)
        return output
|
|
|
|
| ''' Discriminator network for 128x128 RGB images ''' |
| class D(nn.Module): |
| |
| def __init__(self): |
| super(D, self).__init__() |
| self.main = nn.Sequential( |
| nn.Conv2d(3, 16, 4, 2, 1), |
| nn.LeakyReLU(0.2, inplace = True), |
| nn.Conv2d(16, 32, 4, 2, 1), |
| nn.BatchNorm2d(32), |
| nn.LeakyReLU(0.2, inplace = True), |
| nn.Conv2d(32, 64, 4, 2, 1), |
| nn.BatchNorm2d(64), |
| nn.LeakyReLU(0.2, inplace = True), |
| nn.Conv2d(64, 128, 4, 2, 1), |
| nn.BatchNorm2d(128), |
| nn.LeakyReLU(0.2, inplace = True), |
| nn.Conv2d(128, 256, 4, 2, 1), |
| nn.BatchNorm2d(256), |
| nn.LeakyReLU(0.2, inplace = True), |
| nn.Conv2d(256, 512, 4, 2, 1), |
| nn.BatchNorm2d(512), |
| nn.LeakyReLU(0.2, inplace = True), |
| nn.Conv2d(512, 1, 4, 2, 1, bias = False), |
| nn.Sigmoid() |
| ) |
| |
| |
| def forward(self, input): |
| output = self.main(input) |
| return output.view(-1) |
|
|