# BigGAN-deep models: bottleneck residual blocks, architecture dictionaries,
# Generator, Discriminator, and the parallelized G_D wrapper.
import numpy as np
import math
import functools

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P

import layers
from sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d
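# `layers` (spectral-norm convs/linears/embeddings, class-conditional BN,
# self-attention) and `sync_batchnorm` are local modules from this repo.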
|
|
# BigGAN-deep: uses a different resblock and pattern than standard BigGAN.


# Generator residual block: a 1x1 / 3x3 / 3x3 / 1x1 bottleneck, with
# class-conditional BN before every activation and an optional upsample.
class GBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        which_conv=nn.Conv2d,
        which_bn=layers.bn,
        activation=None,
        upsample=None,
        channel_ratio=4,
    ):
        super(GBlock, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.hidden_channels = self.in_channels // channel_ratio
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        # Conv layers: 1x1 bottleneck, two 3x3 convs, then 1x1 back up
        self.conv1 = self.which_conv(
            self.in_channels, self.hidden_channels, kernel_size=1, padding=0
        )
        self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv4 = self.which_conv(
            self.hidden_channels, self.out_channels, kernel_size=1, padding=0
        )
        # Batchnorm layers
        self.bn1 = self.which_bn(self.in_channels)
        self.bn2 = self.which_bn(self.hidden_channels)
        self.bn3 = self.which_bn(self.hidden_channels)
        self.bn4 = self.which_bn(self.hidden_channels)
        # Upsample function (None for no upsampling)
        self.upsample = upsample
|
    def forward(self, x, y):
        # Project down to channel ratio
        h = self.conv1(self.activation(self.bn1(x, y)))
        # Apply next BN-ReLU
        h = self.activation(self.bn2(h, y))
        # Drop extra channels in the skip path if narrowing
        if self.in_channels != self.out_channels:
            x = x[:, : self.out_channels]
        # Upsample both h and x at this point
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        # 3x3 convs
        h = self.conv2(h)
        h = self.conv3(self.activation(self.bn3(h, y)))
        # Final 1x1 conv
        h = self.conv4(self.activation(self.bn4(h, y)))
        return h + x
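
    # Shape sketch of the skip path above (illustrative only, plain torch;
    # not part of the original file):
    #
    #   x = torch.randn(2, 64, 8, 8)          # in_channels = 64
    #   x = x[:, :32]                         # out_channels = 32: drop channels
    #   x = F.interpolate(x, scale_factor=2)  # -> (2, 32, 16, 16)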
|
|
|
|
# Architectures for G
# Attention is passed in the format '32_64', meaning attention is applied
# at both 32x32 and 64x64; '64' alone applies it at 64x64 only.
def G_arch(ch=64, attention="64", ksize="333333", dilation="111111"):
    arch = {}
    arch[256] = {
        "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2]],
        "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1]],
        "upsample": [True] * 6,
        "resolution": [8, 16, 32, 64, 128, 256],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 9)
        },
    }
    arch[128] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4, 2]],
        "out_channels": [ch * item for item in [16, 8, 4, 2, 1]],
        "upsample": [True] * 5,
        "resolution": [8, 16, 32, 64, 128],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 8)
        },
    }
    arch[64] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4]],
        "out_channels": [ch * item for item in [16, 8, 4, 2]],
        "upsample": [True] * 4,
        "resolution": [8, 16, 32, 64],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 7)
        },
    }
    arch[32] = {
        "in_channels": [ch * item for item in [4, 4, 4]],
        "out_channels": [ch * item for item in [4, 4, 4]],
        "upsample": [True] * 3,
        "resolution": [8, 16, 32],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 6)
        },
    }
    return arch
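
# Example: the 128x128 schedule returned above (pure-Python sketch; runs
# with no extra dependencies):
#
#   arch = G_arch(ch=64, attention="64")[128]
#   arch["in_channels"]    # [1024, 1024, 512, 256, 128]
#   arch["attention"][64]  # True -> attention block inserted at 64x64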
|
|
|
|
class Generator(nn.Module):
    def __init__(
        self,
        G_ch=64,
        G_depth=2,
        dim_z=128,
        bottom_width=4,
        resolution=128,
        G_kernel_size=3,
        G_attn="64",
        n_classes=1000,
        num_G_SVs=1,
        num_G_SV_itrs=1,
        G_shared=True,
        shared_dim=0,
        hier=False,
        cross_replica=False,
        mybn=False,
        G_activation=nn.ReLU(inplace=False),
        G_lr=5e-5,
        G_B1=0.0,
        G_B2=0.999,
        adam_eps=1e-8,
        BN_eps=1e-5,
        SN_eps=1e-12,
        G_mixed_precision=False,
        G_fp16=False,
        G_init="ortho",
        skip_init=False,
        no_optim=False,
        G_param="SN",
        norm_style="bn",
        **kwargs
    ):
        super(Generator, self).__init__()
        # Channel width multiplier
        self.ch = G_ch
        # Number of resblocks per stage
        self.G_depth = G_depth
        # Dimensionality of the latent space
        self.dim_z = dim_z
        # Spatial size of the initial feature map
        self.bottom_width = bottom_width
        # Resolution of the output
        self.resolution = resolution
        # Kernel size
        self.kernel_size = G_kernel_size
        # Attention spec
        self.attention = G_attn
        # Number of classes for conditional generation
        self.n_classes = n_classes
        # Use a shared class embedding?
        self.G_shared = G_shared
        # Dimensionality of the shared embedding (unused if not G_shared)
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        # Hierarchical latent space?
        self.hier = hier
        # Cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use the repo's custom batchnorm?
        self.mybn = mybn
        # Nonlinearity for residual blocks
        self.activation = G_activation
        # Initialization style
        self.init = G_init
        # Parameterization style (spectral norm or not)
        self.G_param = G_param
        # Normalization style
        self.norm_style = norm_style
        # Epsilon for BatchNorm
        self.BN_eps = BN_eps
        # Epsilon for spectral norm
        self.SN_eps = SN_eps
        # fp16?
        self.fp16 = G_fp16
        # Architecture dict
        self.arch = G_arch(self.ch, self.attention)[resolution]
|
|
        # Which convs and linear layers to use
        if self.G_param == "SN":
            self.which_conv = functools.partial(
                layers.SNConv2d,
                kernel_size=3,
                padding=1,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_linear = functools.partial(
                layers.SNLinear,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear
|
|
        # We use a non-spectral-norm embedding here. The class-conditional BN
        # is conditioned on (shared embedding, z) if using shared embeddings,
        # else on a per-class embedding acting as a one-hot lookup.
        self.which_embedding = nn.Embedding
        bn_linear = (
            functools.partial(self.which_linear, bias=False)
            if self.G_shared
            else self.which_embedding
        )
        self.which_bn = functools.partial(
            layers.ccbn,
            which_linear=bn_linear,
            cross_replica=self.cross_replica,
            mybn=self.mybn,
            input_size=(
                self.shared_dim + self.dim_z if self.G_shared else self.n_classes
            ),
            norm_style=self.norm_style,
            eps=self.BN_eps,
        )
|
|
        # Prepare model
        # Shared class embedding (identity if not used)
        self.shared = (
            self.which_embedding(n_classes, self.shared_dim)
            if G_shared
            else layers.identity()
        )
        # First linear layer, mapping the latent to the initial feature map
        self.linear = self.which_linear(
            self.dim_z + self.shared_dim,
            self.arch["in_channels"][0] * (self.bottom_width ** 2),
        )
|
|
        # self.blocks is a doubly-nested list of modules; the outer loop is
        # over blocks at a given resolution (resblocks and/or self-attention).
        self.blocks = []
        for index in range(len(self.arch["out_channels"])):
            self.blocks += [
                [
                    GBlock(
                        in_channels=self.arch["in_channels"][index],
                        out_channels=self.arch["in_channels"][index]
                        if g_index == 0
                        else self.arch["out_channels"][index],
                        which_conv=self.which_conv,
                        which_bn=self.which_bn,
                        activation=self.activation,
                        upsample=(
                            functools.partial(F.interpolate, scale_factor=2)
                            if self.arch["upsample"][index]
                            and g_index == (self.G_depth - 1)
                            else None
                        ),
                    )
                ]
                for g_index in range(self.G_depth)
            ]

            # If attention on this block, attach it to the end
            if self.arch["attention"][self.arch["resolution"][index]]:
                print(
                    "Adding attention layer in G at resolution %d"
                    % self.arch["resolution"][index]
                )
                self.blocks[-1] += [
                    layers.Attention(self.arch["out_channels"][index], self.which_conv)
                ]

        # Turn self.blocks into a ModuleList so everything is registered
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
|
|
        # Output layer: batchnorm-relu-conv
        self.output_layer = nn.Sequential(
            layers.bn(
                self.arch["out_channels"][-1],
                cross_replica=self.cross_replica,
                mybn=self.mybn,
            ),
            self.activation,
            self.which_conv(self.arch["out_channels"][-1], 3),
        )
|
|
        # Initialize weights
        if not skip_init:
            self.init_weights()
|
|
        # Set up optimizer.
        # If this is an EMA copy, no optimizer is needed, so just return now.
        if no_optim:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print("Using fp16 adam in G...")
            import utils  # for the fp16-safe Adam16 optimizer

            self.optim = utils.Adam16(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )
        else:
            self.optim = optim.Adam(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )
|
|
    # Initialize weights using ortho, N(0, 0.02), or xavier init
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                if self.init == "ortho":
                    init.orthogonal_(module.weight)
                elif self.init == "N02":
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ["glorot", "xavier"]:
                    init.xavier_uniform_(module.weight)
                else:
                    print("Init style not recognized...")
                self.param_count += sum(
                    p.data.nelement() for p in module.parameters()
                )
        print("Param count for G's initialized parameters: %d" % self.param_count)
|
|
    # Note on this forward function: we pass in a y vector that has already
    # been passed through G.shared, to make class-wise interpolation easy.
    def forward(self, z, y):
        # If hierarchical, concatenate zs and ys
        if self.hier:
            z = torch.cat([y, z], 1)
            y = z
        # First linear layer
        h = self.linear(z)
        # Reshape to the initial bottom_width x bottom_width feature map
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            # Second inner loop in case block has multiple layers
            for block in blocklist:
                h = block(h, y)
        # Apply batchnorm-relu-conv-tanh at output
        return torch.tanh(self.output_layer(h))
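
    # Minimal usage sketch (hedged; assumes the repo's layers.py is importable
    # and the default shared/hierarchical setup; not part of the original file):
    #
    #   G = Generator(G_ch=64, dim_z=128, resolution=128, n_classes=1000, hier=True)
    #   z = torch.randn(4, G.dim_z)
    #   y = G.shared(torch.randint(0, G.n_classes, (4,)))
    #   img = G(z, y)  # -> (4, 3, 128, 128), values in [-1, 1]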
|
|
|
|
# Discriminator residual block: a pre-activation 1x1 / 3x3 / 3x3 / 1x1
# bottleneck with an optional downsample and a channel-growing shortcut.
class DBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        which_conv=layers.SNConv2d,
        wide=True,
        preactivation=True,
        activation=None,
        downsample=None,
        channel_ratio=4,
    ):
        super(DBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        # Bottleneck width is derived from the output channels
        self.hidden_channels = self.out_channels // channel_ratio
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample

        # Conv layers: 1x1 bottleneck, two 3x3 convs, then 1x1 back up
        self.conv1 = self.which_conv(
            self.in_channels, self.hidden_channels, kernel_size=1, padding=0
        )
        self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv4 = self.which_conv(
            self.hidden_channels, self.out_channels, kernel_size=1, padding=0
        )

        # Learnable shortcut is only needed when growing the channel count
        self.learnable_sc = in_channels != out_channels
        if self.learnable_sc:
            self.conv_sc = self.which_conv(
                in_channels, out_channels - in_channels, kernel_size=1, padding=0
            )
|
|
    def shortcut(self, x):
        if self.downsample:
            x = self.downsample(x)
        if self.learnable_sc:
            # Grow channels by concatenating a learned 1x1 projection of x
            x = torch.cat([x, self.conv_sc(x)], 1)
        return x
|
|
    def forward(self, x):
        # 1x1 bottleneck conv (plain ReLU pre-activation, as in the original)
        h = self.conv1(F.relu(x))
        # 3x3 convs
        h = self.conv2(self.activation(h))
        h = self.conv3(self.activation(h))
        # relu before downsample
        h = self.activation(h)
        # downsample
        if self.downsample:
            h = self.downsample(h)
        # final 1x1 conv
        h = self.conv4(h)
        return h + self.shortcut(x)
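
    # Shape sketch of the channel-growing shortcut (illustrative only,
    # plain torch; not part of the original file):
    #
    #   x = torch.randn(2, 64, 16, 16)        # in_channels = 64
    #   conv_sc = nn.Conv2d(64, 128 - 64, 1)  # out_channels = 128
    #   torch.cat([x, conv_sc(x)], 1).shape   # -> (2, 128, 16, 16)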
|
|
|
|
# Discriminator architecture, same paradigm as G's above
def D_arch(ch=64, attention="64", ksize="333333", dilation="111111"):
    arch = {}
    arch[256] = {
        "in_channels": [item * ch for item in [1, 2, 4, 8, 8, 16]],
        "out_channels": [item * ch for item in [2, 4, 8, 8, 16, 16]],
        "downsample": [True] * 6 + [False],
        "resolution": [128, 64, 32, 16, 8, 4, 4],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 8)
        },
    }
    arch[128] = {
        "in_channels": [item * ch for item in [1, 2, 4, 8, 16]],
        "out_channels": [item * ch for item in [2, 4, 8, 16, 16]],
        "downsample": [True] * 5 + [False],
        "resolution": [64, 32, 16, 8, 4, 4],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 8)
        },
    }
    arch[64] = {
        "in_channels": [item * ch for item in [1, 2, 4, 8]],
        "out_channels": [item * ch for item in [2, 4, 8, 16]],
        "downsample": [True] * 4 + [False],
        "resolution": [32, 16, 8, 4, 4],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 7)
        },
    }
    arch[32] = {
        "in_channels": [item * ch for item in [4, 4, 4]],
        "out_channels": [item * ch for item in [4, 4, 4]],
        "downsample": [True, True, False, False],
        "resolution": [16, 16, 16, 16],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 6)
        },
    }
    return arch
|
|
|
|
class Discriminator(nn.Module):
    def __init__(
        self,
        D_ch=64,
        D_wide=True,
        D_depth=2,
        resolution=128,
        D_kernel_size=3,
        D_attn="64",
        n_classes=1000,
        num_D_SVs=1,
        num_D_SV_itrs=1,
        D_activation=nn.ReLU(inplace=False),
        D_lr=2e-4,
        D_B1=0.0,
        D_B2=0.999,
        adam_eps=1e-8,
        SN_eps=1e-12,
        output_dim=1,
        D_mixed_precision=False,
        D_fp16=False,
        D_init="ortho",
        skip_init=False,
        D_param="SN",
        **kwargs
    ):
        super(Discriminator, self).__init__()
        # Channel width multiplier
        self.ch = D_ch
        # Use a wide D (as in SA-GAN and BigGAN) or a skinny one (as in SN-GAN)?
        self.D_wide = D_wide
        # Number of resblocks per stage
        self.D_depth = D_depth
        # Resolution of the input
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention spec
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for spectral norm
        self.SN_eps = SN_eps
        # fp16?
        self.fp16 = D_fp16
        # Architecture dict
        self.arch = D_arch(self.ch, self.attention)[resolution]
|
|
        # Which convs, linear layers, and embeddings to use.
        # No option to turn off SN in D right now.
        if self.D_param == "SN":
            self.which_conv = functools.partial(
                layers.SNConv2d,
                kernel_size=3,
                padding=1,
                num_svs=num_D_SVs,
                num_itrs=num_D_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_linear = functools.partial(
                layers.SNLinear,
                num_svs=num_D_SVs,
                num_itrs=num_D_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_embedding = functools.partial(
                layers.SNEmbedding,
                num_svs=num_D_SVs,
                num_itrs=num_D_SV_itrs,
                eps=self.SN_eps,
            )
|
|
        # Prepare model
        # Stem convolution
        self.input_conv = self.which_conv(3, self.arch["in_channels"][0])
        # self.blocks is a doubly-nested list of modules; the outer loop is
        # over blocks at a given resolution (resblocks and/or self-attention).
        self.blocks = []
        for index in range(len(self.arch["out_channels"])):
            self.blocks += [
                [
                    DBlock(
                        in_channels=self.arch["in_channels"][index]
                        if d_index == 0
                        else self.arch["out_channels"][index],
                        out_channels=self.arch["out_channels"][index],
                        which_conv=self.which_conv,
                        wide=self.D_wide,
                        activation=self.activation,
                        preactivation=True,
                        downsample=(
                            nn.AvgPool2d(2)
                            if self.arch["downsample"][index] and d_index == 0
                            else None
                        ),
                    )
                    for d_index in range(self.D_depth)
                ]
            ]
            # If attention on this block, attach it to the end
            if self.arch["attention"][self.arch["resolution"][index]]:
                print(
                    "Adding attention layer in D at resolution %d"
                    % self.arch["resolution"][index]
                )
                self.blocks[-1] += [
                    layers.Attention(self.arch["out_channels"][index], self.which_conv)
                ]
        # Turn self.blocks into a ModuleList so everything is registered
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
        # Linear output layer. The output dimension is typically 1, but may be
        # larger, e.g. for an auxiliary inference output.
        self.linear = self.which_linear(self.arch["out_channels"][-1], output_dim)
        # Embedding for projection discrimination
        self.embed = self.which_embedding(self.n_classes, self.arch["out_channels"][-1])
|
|
        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print("Using fp16 adam in D...")
            import utils  # for the fp16-safe Adam16 optimizer

            self.optim = utils.Adam16(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )
        else:
            self.optim = optim.Adam(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )

    # Initialize weights using ortho, N(0, 0.02), or xavier init
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                if self.init == "ortho":
                    init.orthogonal_(module.weight)
                elif self.init == "N02":
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ["glorot", "xavier"]:
                    init.xavier_uniform_(module.weight)
                else:
                    print("Init style not recognized...")
                self.param_count += sum(
                    p.data.nelement() for p in module.parameters()
                )
        print("Param count for D's initialized parameters: %d" % self.param_count)
|
|
    def forward(self, x, y=None):
        # Run input conv
        h = self.input_conv(x)
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
        # Apply global sum pooling as in SN-GAN
        h = torch.sum(self.activation(h), [2, 3])
        # Initial class-unconditional output
        out = self.linear(h)
        # Add the projection of the final feature set onto the class vectors
        out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
        return out
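
    # The conditional term above is the projection-discriminator score,
    # out = linear(h) + <embed(y), h>. Plain-tensor sketch (illustrative
    # only; not part of the original file):
    #
    #   h = torch.randn(4, 1024)                 # pooled features
    #   e = torch.randn(4, 1024)                 # embed(y) rows
    #   torch.sum(e * h, 1, keepdim=True).shape  # -> (4, 1)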
|
|
|
|
# Parallelized G_D to minimize cross-GPU communication.
# Without this, Generator outputs would get all-gathered and then rebroadcast.
class G_D(nn.Module):
    def __init__(self, G, D):
        super(G_D, self).__init__()
        self.G = G
        self.D = D
|
|
    def forward(
        self, z, gy, x=None, dy=None, train_G=False, return_G_z=False, split_D=False
    ):
        # If training G, enable grad tape
        with torch.set_grad_enabled(train_G):
            # Get Generator output given noise
            G_z = self.G(z, self.G.shared(gy))
            # Cast as necessary
            if self.G.fp16 and not self.D.fp16:
                G_z = G_z.float()
            if self.D.fp16 and not self.G.fp16:
                G_z = G_z.half()

        # split_D means running D once with real data and once with fake,
        # rather than concatenating along the batch dimension.
        if split_D:
            D_fake = self.D(G_z, gy)
            if x is not None:
                D_real = self.D(x, dy)
                return D_fake, D_real
            else:
                if return_G_z:
                    return D_fake, G_z
                else:
                    return D_fake
        # If real data is provided, concatenate it with the Generator's output
        # along the batch dimension for improved efficiency.
        else:
            D_input = torch.cat([G_z, x], 0) if x is not None else G_z
            D_class = torch.cat([gy, dy], 0) if dy is not None else gy
            # Get Discriminator output
            D_out = self.D(D_input, D_class)
            if x is not None:
                return torch.split(D_out, [G_z.shape[0], x.shape[0]])  # D_fake, D_real
            else:
                if return_G_z:
                    return D_out, G_z
                else:
                    return D_out
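
    # Typical training-loop usage (hedged sketch; G, D, z, gy, real images x,
    # and labels dy are assumed to be built elsewhere):
    #
    #   GD = G_D(G, D)
    #   D_fake, D_real = GD(z, gy, x=x, dy=dy)  # discriminator step
    #   D_fake = GD(z, gy, train_G=True)        # generator step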
|
|