| import os |
| import sys |
| import time |
| import glob |
| import random |
| import skimage |
| import skimage.io |
| import numpy as np |
| from skimage import io, color |
|
|
| import skimage |
| import skimage.io |
| from PIL import Image |
| import cv2 |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn import functional as F |
|
|
| import timm |
| import torchvision |
| from torchvision.models.feature_extraction import create_feature_extractor |
|
|
# Value ranges of the CIE-Lab channels, used to (de)normalise them to [0, 1].
L_range = 100  # lightness is mapped from [0, L_range]
ab_min, ab_max = -128, 127  # bounds assumed for the a/b chroma channels
ab_range = ab_max - ab_min  # width of the a/b interval (255)
| |
def extract_zip(input_zip):
    """Read every member of a zip archive into memory.

    Args:
        input_zip: Anything ``zipfile.ZipFile`` accepts: a path, or a binary
            file-like object. A passed-in file object is not closed.

    Returns:
        dict mapping each archive member name to its uncompressed bytes.
    """
    # ZipFile is not imported at module level, so bring it into scope here.
    from zipfile import ZipFile

    # The context manager guarantees the archive handle is released.
    with ZipFile(input_zip) as archive:
        return {name: archive.read(name) for name in archive.namelist()}
|
|
def normalize_lab_channels(x):
    """Rescale the three channels of a channel-last Lab array, in place.

    Channel 0 (L) is divided by ``L_range``. Channels 1 and 2 (a, b) are
    shifted by ``ab_min`` and divided by ``ab_range``. Values inside
    [0, L_range] and [ab_min, ab_max] therefore land in [0, 1].

    Assumes ``x`` is indexed as (H, W, channel); TODO confirm with callers.

    Returns:
        The same ``x`` object, modified in place.
    """
    # Plain assignment (not ``/=``) lets numpy cast back to x's dtype.
    x[:, :, 0] = x[:, :, 0] / L_range
    for channel in (1, 2):
        x[:, :, channel] = (x[:, :, channel] - ab_min) / ab_range
    return x
|
|
def normalized_lab_to_rgb(lab):
    """Undo the [0, 1] Lab normalisation and convert the image to RGB.

    This is the arithmetic inverse of ``normalize_lab_channels``:
    L is multiplied by ``L_range``, and a/b are multiplied by ``ab_range``
    and shifted by ``ab_min``. The de-normalised values are written into
    ``lab`` in place before ``skimage.color.lab2rgb`` is applied.

    Assumes ``lab`` is indexed as (H, W, channel); TODO confirm with callers.

    Returns:
        The RGB image produced by ``color.lab2rgb``.
    """
    lab[:, :, 0] = lab[:, :, 0] * L_range
    for channel in (1, 2):
        lab[:, :, channel] = lab[:, :, channel] * ab_range + ab_min
    return color.lab2rgb(lab)
|
|
def torch_normalized_lab_to_rgb(lab):
    """Convert a batch of normalised Lab tensors to RGB, in place.

    ``lab`` is indexed as (batch, channel, H, W).

    Each sample is processed in two steps:
    1. De-normalise and clip: L to [0, L_range], and a/b to [ab_min, ab_max].
    2. Convert with ``skimage.color.lab2rgb`` on a channel-last CPU numpy
       copy, then write the result back into the same slot of ``lab``. On
       assignment it is cast to ``lab``'s dtype and device.

    Returns:
        The same ``lab`` tensor, now holding RGB values.
    """
    # Element-wise de-normalisation over the whole batch at once.
    lab[:, 0, :, :] = torch.clip(lab[:, 0, :, :] * L_range, 0, L_range)
    lab[:, 1, :, :] = torch.clip((lab[:, 1, :, :] * ab_range) + ab_min, ab_min, ab_max)
    lab[:, 2, :, :] = torch.clip((lab[:, 2, :, :] * ab_range) + ab_min, ab_min, ab_max)

    # lab2rgb works on numpy (H, W, C) arrays, so convert sample by sample.
    for idx in range(lab.shape[0]):
        hwc_lab = lab[idx].permute(1, 2, 0).detach().cpu().numpy()
        hwc_rgb = color.lab2rgb(hwc_lab)
        lab[idx] = torch.from_numpy(hwc_rgb).permute(2, 0, 1)

    return lab
|
|
class Encoder(nn.Module):
    """Multi-scale feature extractor around timm's ``efficientnetv2_rw_s``.

    The pretrained backbone is kept as ``backend_model``. ``backend`` is a
    torchvision feature extractor that taps the nodes ``blocks.0``,
    ``blocks.1``, ``blocks.2``, ``blocks.3`` and ``act2`` of that backbone.
    """

    def __init__(self):
        super().__init__()
        # Registered first so it keeps its place in the state_dict.
        self.backend_model = timm.create_model('efficientnetv2_rw_s', pretrained=True)
        tapped_nodes = ['blocks.0', 'blocks.1', 'blocks.2', 'blocks.3', 'act2']
        self.backend = create_feature_extractor(self.backend_model, return_nodes=tapped_nodes)

    def forward(self, x):
        """Return the tapped feature maps as a list.

        The order is the extractor's output-dict order. ``Decoder.forward``
        unpacks it as (blocks0, blocks1, blocks2, blocks3, act2).
        """
        feature_dict = self.backend(x)
        return list(feature_dict.values())
|
|
class UpSample(nn.Sequential):
    """2x bilinear upsampling followed by two conv/activation stages.

    Each stage is a 3x3 reflect-padded, bias-free ``Conv2d`` followed by
    ``LeakyReLU(0.2)``. When ``forward`` receives ``concat_with``, that
    tensor is concatenated with the upsampled input along dim 1 before the
    first convolution. ``in_channels`` must therefore equal the upsampled
    channels plus the skip channels.

    Note: this subclasses ``nn.Sequential`` but registers its children as
    named attributes and overrides ``forward``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_opts = dict(kernel_size=3, stride=1, padding=1, padding_mode='reflect', bias=False)
        # Construction order is convA, leakyreluA, convB, leakyreluB, upsample.
        self.convA = nn.Conv2d(in_channels, out_channels, **conv_opts)
        self.leakyreluA = nn.LeakyReLU(0.2)
        self.convB = nn.Conv2d(out_channels, out_channels, **conv_opts)
        self.leakyreluB = nn.LeakyReLU(0.2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x, concat_with=None):
        """Upsample ``x`` 2x, optionally concatenate a skip tensor, then convolve twice."""
        merged = self.upsample(x)
        if concat_with is not None:
            merged = torch.cat([merged, concat_with], dim=1)
        hidden = self.leakyreluA(self.convA(merged))
        return self.leakyreluB(self.convB(hidden))
| |
class Decoder(nn.Module):
    """U-Net style decoder that turns encoder features into a 2-channel map.

    Args:
        num_features: Channels of the deepest feature map (``x`` in forward).
        decoder_width: Required multiplier. The internal width is
            ``int(num_features * decoder_width)``.
        skip_channels: Channel counts of the four skip tensors
            (blocks0, blocks1, blocks2, blocks3). The defaults reproduce the
            widths this decoder has always used.

    Raises:
        TypeError: if ``decoder_width`` is None, since no width can be derived.
    """

    def __init__(self, num_features=1792 * 1, decoder_width=None, skip_channels=(24, 48, 64, 128)):
        super().__init__()
        if decoder_width is None:
            raise TypeError("Decoder requires a numeric decoder_width (got None)")
        features = int(num_features * decoder_width)
        skip0, skip1, skip2, skip3 = skip_channels

        # 1x1 projection of the deepest features down to the decoder width.
        self.conv2 = nn.Sequential(
            nn.Conv2d(num_features, features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.LeakyReLU(0.2),
        )

        # Each stage doubles the resolution and fuses one skip connection,
        # from deepest (blocks3) to shallowest (blocks0).
        self.up1 = UpSample(in_channels=features // 1 + skip3, out_channels=features // 2)
        self.up2 = UpSample(in_channels=features // 2 + skip2, out_channels=features // 4)
        self.up3 = UpSample(in_channels=features // 4 + skip1, out_channels=features // 8)
        self.up4 = UpSample(in_channels=features // 8 + skip0, out_channels=features // 16)
        # Final 2x upsample without a skip connection.
        self.up5 = UpSample(in_channels=features // 16, out_channels=features // 16)

        # 1x1 head producing the two output channels.
        self.conv3 = nn.Conv2d(features // 16, 2, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, features):
        """Decode ``(blocks0, blocks1, blocks2, blocks3, x)`` into a 2-channel tensor.

        Five 2x upsamplings are applied, so the output is 32x the spatial
        size of ``x``. Each skip tensor must match the resolution of the
        stage it is concatenated into.
        """
        blocks0, blocks1, blocks2, blocks3, x = features

        x = self.conv2(x)

        x = self.up1(x, blocks3)
        x = self.up2(x, blocks2)
        x = self.up3(x, blocks1)
        x = self.up4(x, blocks0)

        x = self.up5(x)

        x_final = self.conv3(x)

        return x_final
| |
class ColorizeNet(nn.Module):
    """Encoder/decoder network: ``Encoder`` features fed into ``Decoder``.

    Args:
        decoder_width: Width multiplier forwarded to ``Decoder``.

    The output is whatever ``Decoder`` produces: a 2-channel tensor,
    presumably the a/b chroma channels (TODO confirm with the training code).
    """

    def __init__(self, decoder_width):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(decoder_width=decoder_width)

    def forward(self, x):
        """Run the encoder on ``x`` and decode its multi-scale features."""
        return self.decoder(self.encoder(x))
|
|