| import torch
|
| from torch import nn
|
| from torch.nn import functional as F
|
|
|
| from math import sqrt
|
|
|
|
|
class EqualLR:
    """Forward-pre-hook implementing the equalized learning-rate trick.

    The raw parameter is stored on the module as ``<name>_orig``.  Before
    every forward pass the hook rebuilds ``module.<name>`` as
    ``<name>_orig * sqrt(2 / fan_in)``, where ``fan_in`` is
    ``weight.size(1)`` times the number of elements in one kernel slice
    (``weight[0][0]``).
    """

    def __init__(self, name):
        # Name of the parameter being rescaled (e.g. 'weight').
        self.name = name

    def compute_weight(self, module):
        """Return the rescaled parameter computed from ``<name>_orig``."""
        raw = getattr(module, self.name + '_orig')
        fan_in = raw.data.size(1) * raw.data[0][0].numel()
        scale = sqrt(2 / fan_in)
        return raw * scale

    @staticmethod
    def apply(module, name):
        """Move ``module.<name>`` to ``<name>_orig`` and register the hook.

        Returns:
            The ``EqualLR`` hook instance that was registered on ``module``.
        """
        hook = EqualLR(name)

        source = getattr(module, name)
        # Drop the original Parameter so ``module.<name>`` can later be
        # assigned as a plain (computed) tensor by the hook.
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(source.data))
        module.register_forward_pre_hook(hook)

        return hook

    def __call__(self, module, input):
        # Invoked by PyTorch before each forward; ``input`` is unused.
        setattr(module, self.name, self.compute_weight(module))
|
|
|
|
|
def equal_lr(module, name='weight'):
    """Enable equalized learning rate for ``module.<name>``, in place.

    The parameter is re-registered as ``<name>_orig`` and an ``EqualLR``
    forward-pre-hook recomputes ``module.<name>`` from it on every call.

    Args:
        module: the layer to modify.
        name: name of the parameter to rescale (default ``'weight'``).

    Returns:
        The same ``module`` object, so the call can be used inline.
    """
    # The module's hook registry holds the EqualLR instance, so the value
    # returned by ``apply`` need not be kept here.
    EqualLR.apply(module, name)
    return module
|
|
|
|
|
class PixelNorm(nn.Module):
    """Normalize each position's vector along dim 1 to unit root-mean-square.

    Computes ``x / sqrt(mean(x ** 2, dim=1, keepdim=True) + eps)``.  In this
    file it is applied to conv feature maps, so dim 1 is the channel axis.

    Args:
        eps: small constant added under the square root to avoid division
            by zero (default ``1e-8``, the previously hard-coded value).
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, input):
        return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True)
                                  + self.eps)

    def extra_repr(self):
        return f'eps={self.eps}'
|
|
|
|
|
class EqualConv2d(nn.Module):
    """``nn.Conv2d`` with N(0, 1) weights, zero bias and equalized learning rate.

    Accepts exactly the positional/keyword arguments of ``nn.Conv2d``
    (including ``bias=False``) and wraps the layer with ``equal_lr``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        # With ``bias=False`` nn.Conv2d sets ``bias`` to None; zeroing it
        # unconditionally raised AttributeError.
        if conv.bias is not None:
            conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        return self.conv(input)
|
|
|
|
|
class EqualConvTranspose2d(nn.Module):
    """``nn.ConvTranspose2d`` with N(0, 1) weights, zero bias and equalized LR.

    Accepts exactly the positional/keyword arguments of
    ``nn.ConvTranspose2d`` (including ``bias=False``) and wraps the layer
    with ``equal_lr``.

    NOTE(review): EqualLR derives fan-in from ``weight.size(1)`` times the
    kernel area.  ``nn.ConvTranspose2d`` stores its weight as
    ``(in_channels, out_channels // groups, kH, kW)``, so the scale here is
    based on out_channels.  This is identical when in/out channels match;
    confirm it is intended otherwise.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        conv = nn.ConvTranspose2d(*args, **kwargs)
        conv.weight.data.normal_()
        # ``bias=False`` leaves ``conv.bias`` as None; guard the zero-init.
        if conv.bias is not None:
            conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        return self.conv(input)
|
|
|
class EqualLinear(nn.Module):
    """Fully connected layer with equalized learning rate.

    Builds ``nn.Linear(in_dim, out_dim)``, draws its weight from N(0, 1),
    zeroes its bias, and wraps it with ``equal_lr`` so the weight is
    rescaled on every forward.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()

        layer = nn.Linear(in_dim, out_dim)
        # In-place, no-grad initializers; same draws as ``.data.normal_()``.
        nn.init.normal_(layer.weight)
        nn.init.zeros_(layer.bias)

        self.linear = equal_lr(layer)

    def forward(self, input):
        result = self.linear(input)
        return result
|
|
|
|
|
class ConvBlock(nn.Module):
    """Two stacked ``EqualConv2d -> [PixelNorm] -> LeakyReLU(0.1)`` stages.

    The first conv maps ``in_channel`` to ``out_channel``; the second maps
    ``out_channel`` to ``out_channel``.  ``kernel_size2``/``padding2``
    override the second conv's kernel/padding and default to the first's.
    ``PixelNorm`` is inserted after each conv only when ``pixel_norm`` is
    truthy.
    """

    def __init__(self, in_channel, out_channel, kernel_size, padding, kernel_size2=None, padding2=None, pixel_norm=True):
        super().__init__()

        second_kernel = kernel_size if kernel_size2 is None else kernel_size2
        second_pad = padding if padding2 is None else padding2
        stage_specs = (
            (in_channel, kernel_size, padding),
            (out_channel, second_kernel, second_pad),
        )

        layers = []
        for channels_in, kernel, pad in stage_specs:
            layers.append(EqualConv2d(channels_in, out_channel, kernel, padding=pad))
            if pixel_norm:
                layers.append(PixelNorm())
            layers.append(nn.LeakyReLU(0.1))

        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        return self.conv(input)
|
|
|
|
|
def upscale(feat):
    """Return ``feat`` bilinearly upsampled 2x in both spatial dimensions.

    Uses ``align_corners=False``; expects a 4-D (N, C, H, W) tensor, as
    required by bilinear ``F.interpolate``.
    """
    return F.interpolate(
        input=feat,
        scale_factor=2,
        mode='bilinear',
        align_corners=False,
    )
|
|
|
class Generator(nn.Module):
    """Progressively grown image generator.

    A latent code is projected to a 4x4 feature map, then doubled in
    resolution once per growth stage.  ``forward(input, step, alpha)``
    returns a 3-channel image of side ``4 * 2 ** step`` for ``step`` in
    ``1..max_step`` (8x8 up to 256x256).
    """

    def __init__(self, input_code_dim=128, in_channel=128, pixel_norm=True, tanh=True):
        super().__init__()
        self.input_dim = input_code_dim
        self.tanh = tanh
        # (N, input_code_dim, 1, 1) -> 4x4 map via a kernel-4, stride-1,
        # padding-0 transposed conv.
        self.input_layer = nn.Sequential(
            EqualConvTranspose2d(input_code_dim, in_channel, 4, 1, 0),
            PixelNorm(),
            nn.LeakyReLU(0.1))

        # One ConvBlock per resolution; width drops to in_channel//2 at 64x64
        # and to in_channel//4 from 128x128 on.
        self.progression_4 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
        self.progression_8 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
        self.progression_16 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
        self.progression_32 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
        self.progression_64 = ConvBlock(in_channel, in_channel//2, 3, 1, pixel_norm=pixel_norm)
        self.progression_128 = ConvBlock(in_channel//2, in_channel//4, 3, 1, pixel_norm=pixel_norm)
        self.progression_256 = ConvBlock(in_channel//4, in_channel//4, 3, 1, pixel_norm=pixel_norm)

        # 1x1 convs mapping each resolution's features to 3 output channels.
        # There is no 4x4 head, so the smallest image produced is 8x8.
        self.to_rgb_8 = EqualConv2d(in_channel, 3, 1)
        self.to_rgb_16 = EqualConv2d(in_channel, 3, 1)
        self.to_rgb_32 = EqualConv2d(in_channel, 3, 1)
        self.to_rgb_64 = EqualConv2d(in_channel//2, 3, 1)
        self.to_rgb_128 = EqualConv2d(in_channel//4, 3, 1)
        self.to_rgb_256 = EqualConv2d(in_channel//4, 3, 1)

        self.max_step = 6

    def progress(self, feat, module):
        """Upsample ``feat`` 2x (bilinear) and run it through ``module``."""
        out = upscale(feat)
        out = module(out)
        return out

    def output(self, feat1, feat2, module1, module2, alpha):
        """Turn ``feat2`` into an image, optionally fading in from ``feat1``.

        When ``0 <= alpha < 1`` the result is
        ``(1 - alpha) * upscale(module1(feat1)) + alpha * module2(feat2)``;
        otherwise it is ``module2(feat2)``.  ``tanh`` is applied last when
        ``self.tanh`` is set.
        """
        if 0 <= alpha < 1:
            skip_rgb = upscale(module1(feat1))
            out = (1-alpha)*skip_rgb + alpha*module2(feat2)
        else:
            out = module2(feat2)
        if self.tanh:
            return torch.tanh(out)
        return out

    def forward(self, input, step=0, alpha=-1):
        """Generate images for growth stage ``step``.

        Args:
            input: latent codes, reshaped to ``(-1, input_code_dim, 1, 1)``.
            step: growth stage in ``1..max_step``.  Larger values are clamped
                to ``max_step``.
            alpha: fade-in weight.  Values in ``[0, 1)`` blend with the
                upsampled previous-stage image; any other value uses only the
                current stage.  Ignored at step 1, which has nothing to fade
                from.

        Returns:
            Image tensor with 3 channels and spatial side ``4 * 2 ** step``.

        Raises:
            ValueError: if ``step`` is below 1 (no 4x4 RGB head exists) or
                does not correspond to a stage.  Previously such steps,
                including the default ``step=0``, silently returned None.
        """
        if step > self.max_step:
            step = self.max_step
        if step < 1:
            raise ValueError(
                f'step must be in 1..{self.max_step} (no 4x4 output exists), got {step!r}')

        out_4 = self.input_layer(input.view(-1, self.input_dim, 1, 1))
        out_4 = self.progression_4(out_4)
        out_8 = self.progress(out_4, self.progression_8)
        if step == 1:
            rgb = self.to_rgb_8(out_8)
            return torch.tanh(rgb) if self.tanh else rgb

        # (block, rgb head) for stages 2..max_step, i.e. 16x16 .. 256x256.
        stages = (
            (self.progression_16, self.to_rgb_16),
            (self.progression_32, self.to_rgb_32),
            (self.progression_64, self.to_rgb_64),
            (self.progression_128, self.to_rgb_128),
            (self.progression_256, self.to_rgb_256),
        )
        prev_feat, prev_to_rgb = out_8, self.to_rgb_8
        for stage, (block, to_rgb) in enumerate(stages, start=2):
            feat = self.progress(prev_feat, block)
            if stage == step:
                return self.output(prev_feat, feat, prev_to_rgb, to_rgb, alpha)
            prev_feat, prev_to_rgb = feat, to_rgb

        raise ValueError(f'step must be an integer in 1..{self.max_step}, got {step!r}')
|
|
|
|
|
class Discriminator(nn.Module):
    """Progressively grown discriminator producing one score per image.

    ``progression[i]`` and ``from_rgb[i]`` are ordered from the highest
    resolution (index 0) down to the final 4x4 stage (index ``n_layer - 1``).
    At growth stage ``step`` the input is read by
    ``from_rgb[n_layer - 1 - step]`` and halved in resolution after every
    block except the last.  The last block ends in a 4x4, padding-0 conv, so
    the input side must be ``4 * 2 ** step`` for the final squeeze to yield
    ``(N, feat_dim)`` before the linear head.
    """

    def __init__(self, feat_dim=128):
        super().__init__()

        self.progression = nn.ModuleList([ConvBlock(feat_dim//4, feat_dim//4, 3, 1),
                                          ConvBlock(feat_dim//4, feat_dim//2, 3, 1),
                                          ConvBlock(feat_dim//2, feat_dim, 3, 1),
                                          ConvBlock(feat_dim, feat_dim, 3, 1),
                                          ConvBlock(feat_dim, feat_dim, 3, 1),
                                          ConvBlock(feat_dim, feat_dim, 3, 1),
                                          # +1 input channel for the batch-stddev feature;
                                          # second conv is 4x4 / padding 0.
                                          ConvBlock(feat_dim+1, feat_dim, 3, 1, 4, 0)])

        self.from_rgb = nn.ModuleList([EqualConv2d(3, feat_dim//4, 1),
                                       EqualConv2d(3, feat_dim//4, 1),
                                       EqualConv2d(3, feat_dim//2, 1),
                                       EqualConv2d(3, feat_dim, 1),
                                       EqualConv2d(3, feat_dim, 1),
                                       EqualConv2d(3, feat_dim, 1),
                                       EqualConv2d(3, feat_dim, 1)])

        self.n_layer = len(self.progression)

        self.linear = EqualLinear(feat_dim, 1)

    def forward(self, input, step=0, alpha=-1):
        """Score a batch of 3-channel images at growth stage ``step``.

        Args:
            input: image batch with spatial side ``4 * 2 ** step``.
            step: growth stage; values above ``n_layer - 1`` are clamped,
                mirroring ``Generator.forward``.
            alpha: fade-in weight.  For ``step > 0`` and ``0 <= alpha < 1``
                the first block's output is blended with the downsampled
                input read by the next-lower-resolution ``from_rgb``.

        Returns:
            Tensor of shape ``(N, 1)``.

        Raises:
            ValueError: if ``step`` is negative.  Previously this left
                ``out`` unbound and raised UnboundLocalError; a too-large
                ``step`` silently indexed modules from the wrong end.
        """
        max_step = self.n_layer - 1
        if step > max_step:
            step = max_step
        if step < 0:
            raise ValueError(f'step must be in 0..{max_step}, got {step!r}')

        for i in range(step, -1, -1):
            index = self.n_layer - i - 1

            if i == step:
                out = self.from_rgb[index](input)

            if i == 0:
                # Append one channel holding the mean (over features and
                # positions) of the per-element std across the batch.
                out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8)
                mean_std = out_std.mean()
                mean_std = mean_std.expand(out.size(0), 1, out.size(2), out.size(3))
                out = torch.cat([out, mean_std], 1)

            out = self.progression[index](out)

            if i > 0:
                out = F.interpolate(out, scale_factor=0.5, mode='bilinear', align_corners=False)

                if i == step and 0 <= alpha < 1:
                    skip_rgb = F.interpolate(input, scale_factor=0.5, mode='bilinear', align_corners=False)
                    skip_rgb = self.from_rgb[index + 1](skip_rgb)
                    out = (1 - alpha) * skip_rgb + alpha * out

        # (N, feat_dim, 1, 1) -> (N, feat_dim)
        out = out.squeeze(2).squeeze(2)

        out = self.linear(out)

        return out