import torch
from torch import nn, optim
|
|
class Unet(nn.Module):
    """Generator: a sequential encoder-decoder (no skip connections despite the
    name) that maps a 1-channel L input to a 2-channel ab prediction."""

    def __init__(self, input_c=1, output_c=2, num_filters=128):
        super().__init__()
        # Note: the channel widths below are hardcoded; num_filters is kept for
        # API compatibility but is not used.
        self.model = nn.Sequential(
            # Encoder: one stride-1 conv followed by five stride-2 downsampling convs.
            nn.Conv2d(input_c, 64, kernel_size=4, stride=1, padding="same"),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),

            # Decoder: five stride-2 transposed convs restore the input resolution.
            nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 1x1 projection to the ab channels, squashed to [-1, 1].
            nn.Conv2d(64, output_c, kernel_size=1, stride=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)
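

# Shape sketch (illustrative, not from the original file): the generator keeps
# the spatial resolution, so for a 256x256 single-channel L input,
#   Unet(input_c=1, output_c=2)(torch.randn(1, 1, 256, 256)).shape == (1, 2, 256, 256)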


class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator: classifies overlapping patches of the input as
    real or fake, producing a 2D map of logits instead of a single score."""

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        # First block without normalization, then n_down blocks that double the
        # channel count; the last of them uses stride 1 to keep more resolution.
        model = [self.get_layers(input_c, num_filters, norm=False)]
        model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1),
                                  s=1 if i == (n_down - 1) else 2)
                  for i in range(n_down)]
        # Final stride-1 conv maps to one logit per patch (no norm, no activation).
        model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)]

        self.model = nn.Sequential(*model)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm: layers += [nn.BatchNorm2d(nf)]
        if act: layers += [nn.LeakyReLU(0.2, True)]
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
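

# Output sketch (illustrative): with input_c=3, num_filters=64, n_down=3, a
# 256x256 input yields a 30x30 map of per-patch logits:
#   PatchDiscriminator(3)(torch.randn(1, 3, 256, 256)).shape == (1, 1, 30, 30)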
|
|
class GANLoss(nn.Module):
    """Wraps the adversarial criterion and builds matching real/fake label
    tensors for the discriminator's patch outputs."""

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            raise ValueError(f"unsupported gan_mode: {gan_mode}")

    def get_labels(self, preds, target_is_real):
        # Broadcast the scalar label to the shape of the prediction map.
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)

    def __call__(self, preds, target_is_real):
        labels = self.get_labels(preds, target_is_real)
        loss = self.loss(preds, labels)
        return loss
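

# Usage sketch (illustrative): the criterion expands the scalar label to the
# prediction's shape, so it applies directly to patch maps, e.g.
#   GANLoss('vanilla')(torch.zeros(1, 1, 30, 30), target_is_real=True)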
|
|
def init_weights(net, init='norm', gain=0.02):
    """Initialize conv and batchnorm weights in place and return the network."""

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')

            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm scale starts around 1, shift at 0.
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net
|
|
def init_model(model, device):
    model = model.to(device)
    model = init_weights(model)
    return model
|
|
class MainModel(nn.Module):
    """Conditional GAN wrapper: pix2pix-style training of the generator (L -> ab)
    against a PatchGAN discriminator, with an L1 reconstruction term."""

    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        if net_G is None:
            self.net_G = init_model(Unet(input_c=1, output_c=2, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        self.fake_color = self.net_G(self.L)

    def backward_D(self, epoch):
        # Fake pair: the generated ab channels are detached so only the
        # discriminator receives gradients here.
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        # Real pair: ground-truth ab channels.
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        # The discriminator is only updated on even epochs.
        if epoch % 2 == 0:
            self.loss_D.backward()

    def backward_G(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self, epoch):
        self.forward()
        # Discriminator step (skipped on odd epochs, matching backward_D).
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D(epoch)
        if epoch % 2 == 0:
            self.opt_D.step()

        # Generator step: freeze the discriminator's parameters while
        # backpropagating the adversarial + L1 losses through the generator.
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()
|
|
|