import torch print(torch.__version__) from torchvision.transforms import ToTensor import numpy as np from networks.models import Colorizer from denoising.denoiser import FFDNetDenoiser from utils.utils import resize_pad class MangaColorizator: def __init__(self, device, generator_path = 'networks/generator.zip', extractor_path = 'networks/extractor.pth'): self.colorizer = Colorizer().to(device) self.colorizer.generator.load_state_dict(torch.load(generator_path, map_location = device)) self.colorizer = self.colorizer.eval() self.denoiser = FFDNetDenoiser(device) self.current_image = None self.current_hint = None self.current_pad = None self.device = device def set_image(self, image, size = 576, apply_denoise = True, denoise_sigma = 25, transform = ToTensor()): if (size % 32 != 0): raise RuntimeError("size is not divisible by 32") if apply_denoise: image = self.denoiser.get_denoised_image(image, sigma = denoise_sigma) image, self.current_pad = resize_pad(image, size) self.current_image = transform(image).unsqueeze(0).to(self.device) self.current_hint = torch.zeros(1, 4, self.current_image.shape[2], self.current_image.shape[3]).float().to(self.device) def update_hint(self, hint, mask): ''' Args: hint: numpy.ndarray with shape (self.current_image.shape[2], self.current_image.shape[3], 3) mask: numpy.ndarray with shape (self.current_image.shape[2], self.current_image.shape[3]) ''' if issubclass(hint.dtype.type, np.integer): hint = hint.astype('float32') / 255 hint = (hint - 0.5) / 0.5 hint = torch.FloatTensor(hint).permute(2, 0, 1) mask = torch.FloatTensor(np.expand_dims(mask, 0)) self.current_hint = torch.cat([hint * mask, mask], 0).unsqueeze(0).to(self.device) def colorize(self): with torch.no_grad(): fake_color, _ = self.colorizer(torch.cat([self.current_image, self.current_hint], 1)) fake_color = fake_color.detach() result = fake_color[0].detach().cpu().permute(1, 2, 0) * 0.5 + 0.5 if self.current_pad[0] != 0: result = result[:-self.current_pad[0]] if self.current_pad[1] != 0: result = result[:, :-self.current_pad[1]] return result.numpy()