Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| from torch.autograd import Variable | |
| import numpy as np | |
| import pdb | |
| from torch.nn import functional as F | |
| from torch.nn import init | |
| ''' | |
| ''' | |
class Concat_embed4(nn.Module):
    """Project an embedding and append it to a feature map as extra channels.

    The embedding goes through a small MLP (Linear -> BatchNorm1d -> LeakyReLU,
    twice, followed by Linear -> LeakyReLU), is broadcast across the spatial
    grid of the feature map, and is concatenated along the channel axis.

    Args:
        embed_dim: Width of the incoming embedding vectors.
        projected_embed_dim: Number of channels the projected embedding adds
            to the concatenated output.
    """

    def __init__(self, embed_dim, projected_embed_dim):
        super(Concat_embed4, self).__init__()
        self.projection = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(num_features=embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(num_features=embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=embed_dim, out_features=projected_embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, inp, embed):
        """Concatenate the projected embedding onto ``inp`` channel-wise.

        Args:
            inp: Feature map of shape (batch, channels, height, width).
            embed: Embeddings of shape (batch, embed_dim).

        Returns:
            Tensor of shape
            (batch, channels + projected_embed_dim, height, width) whose first
            ``channels`` channels are ``inp`` and whose remaining channels hold
            the projected embedding repeated at every spatial position.
        """
        projected_embed = self.projection(embed)
        # Follow the spatial size of the incoming feature map instead of
        # assuming a fixed 4x4 grid; for 4x4 inputs the result is unchanged.
        height, width = inp.size(2), inp.size(3)
        replicated_embed = projected_embed[:, :, None, None].expand(
            -1, -1, height, width
        )
        return torch.cat([inp, replicated_embed], 1)
class generator(nn.Module):
    """Transposed-convolution image generator conditioned on an embedding.

    The embedding is reduced from ``embed_dim`` (768) to
    ``projected_embed_dim`` (128) by a three-stage MLP, joined with the noise
    tensor along the channel axis, and upsampled by a stack of transposed
    convolutions ending in ``Tanh``.
    """

    def __init__(self):
        super().__init__()
        self.image_size = 64
        self.num_channels = 3
        self.noise_dim = 100
        self.embed_dim = 768
        self.projected_embed_dim = 128
        self.latent_dim = self.noise_dim + self.projected_embed_dim
        self.ngf = 64

        # Embedding MLP: three Linear -> BatchNorm1d -> LeakyReLU stages, the
        # last of which narrows to projected_embed_dim.
        mlp_shapes = (
            (self.embed_dim, self.embed_dim),
            (self.embed_dim, self.embed_dim),
            (self.embed_dim, self.projected_embed_dim),
        )
        mlp_layers = []
        for fan_in, fan_out in mlp_shapes:
            mlp_layers.append(nn.Linear(in_features=fan_in, out_features=fan_out))
            mlp_layers.append(nn.BatchNorm1d(num_features=fan_out))
            mlp_layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
        self.projection = nn.Sequential(*mlp_layers)

        # Channel widths after each upsampling stage:
        # (ngf*8) x 4 x 4 -> (ngf*4) x 8 x 8 -> (ngf*2) x 16 x 16 -> (ngf) x 32 x 32
        widths = [self.ngf * 8, self.ngf * 4, self.ngf * 2, self.ngf]

        # First stage expands the 1x1 latent to 4x4 (stride 1, no padding).
        stages = [
            nn.ConvTranspose2d(self.latent_dim, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Each following stage doubles the spatial size (stride 2, padding 1).
        for c_in, c_out in zip(widths, widths[1:]):
            stages.append(nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(True))
        # Output stage: (num_channels) x 64 x 64, squashed by Tanh.
        stages.append(
            nn.ConvTranspose2d(self.ngf, self.num_channels, 4, 2, 1, bias=False)
        )
        stages.append(nn.Tanh())
        self.netG = nn.ModuleList(stages)

    def forward(self, embed_vector, z):
        """Generate images from embeddings and noise.

        Args:
            embed_vector: Embeddings fed to the projection MLP; its first
                Linear layer expects ``embed_dim`` input features.
            z: Noise tensor concatenated on dim 1 with the projected embedding
                reshaped to (batch, projected_embed_dim, 1, 1); together they
                must supply ``latent_dim`` channels.

        Returns:
            Output of the final ``Tanh`` layer.
        """
        conditioning = self.projection(embed_vector)[:, :, None, None]
        hidden = torch.cat([conditioning, z], 1)
        for stage in self.netG:
            hidden = stage(hidden)
        return hidden
class discriminator(nn.Module):
    """Convolutional discriminator that scores an image against an embedding.

    ``netD_1`` downsamples the image with four stride-2 convolutions,
    ``projector`` appends the projected embedding as
    ``projected_embed_dim`` extra channels, and ``netD_2`` reduces the fused
    features to a single sigmoid score per sample.
    """

    def __init__(self):
        super().__init__()
        self.image_size = 64
        self.num_channels = 3
        self.embed_dim = 768
        self.projected_embed_dim = 128
        self.ndf = 64
        self.B_dim = 128
        self.C_dim = 16

        # Image trunk. Input is (nc) x 64 x 64; the first stage has no
        # batch norm and yields (ndf) x 32 x 32.
        trunk = [
            nn.Conv2d(self.num_channels, self.ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining stages: (ndf*2) x 16 x 16 -> (ndf*4) x 8 x 8 -> (ndf*8) x 4 x 4
        widths = [self.ndf, self.ndf * 2, self.ndf * 4, self.ndf * 8]
        for c_in, c_out in zip(widths, widths[1:]):
            trunk.append(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False))
            trunk.append(nn.BatchNorm2d(c_out))
            trunk.append(nn.LeakyReLU(0.2, inplace=True))
        self.netD_1 = nn.Sequential(*trunk)

        self.projector = Concat_embed4(self.embed_dim, self.projected_embed_dim)

        # Head. A 1x1 convolution mixes image and embedding channels back to
        # ndf*8, then a 4x4 convolution without padding produces one value.
        fused_channels = self.ndf * 8 + self.projected_embed_dim
        self.netD_2 = nn.Sequential(
            nn.Conv2d(fused_channels, self.ndf * 8, 1, 1, 0, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, inp, embed):
        """Score images against embeddings.

        Args:
            inp: Image batch passed to ``netD_1``.
            embed: Embeddings passed to ``projector`` together with the
                ``netD_1`` features.

        Returns:
            A tuple ``(scores, x_intermediate)``: the ``netD_2`` output
            flattened to one dimension, and the ``netD_1`` feature map.
        """
        x_intermediate = self.netD_1(inp)
        fused = self.projector(x_intermediate, embed)
        verdict = self.netD_2(fused)
        scores = verdict.view(-1, 1).squeeze(1)
        return scores, x_intermediate