Spaces:
Runtime error
Runtime error
| from typing import Any, List, Callable | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.nn.utils.spectral_norm as SpectralNorm | |
| import threading | |
| from torchvision.ops import roi_align | |
| from math import sqrt | |
| from torchvision.transforms.functional import normalize | |
| from roop.typing import Face, Frame, FaceSet | |
# Held by Enhance_DMDNet.enhance_face while the model forward pass runs, so
# only one thread at a time executes the DMDNet model.
THREAD_LOCK_DMDNET = threading.Lock()


class Enhance_DMDNet:
    """Face-enhancement processor that runs a DMDNet model on a face crop."""

    plugin_options: dict = None
    model_dmdnet = None
    torchdevice = None

    processorname = "dmdnet"
    type = "enhance"

    # For each of the 68 output landmarks, the index of the 106-point
    # landmark to copy. Table from https://stackoverflow.com/a/67174339
    _MAP_106_TO_68 = (
        1, 10, 12, 14, 16, 3, 5, 7, 0, 23, 21, 19, 32, 30, 28, 26, 17,
        43, 48, 49, 51, 50,
        102, 103, 104, 105, 101,
        72, 73, 74, 86,
        78, 79, 80, 85, 84,
        35, 41, 42, 39, 37, 36,
        89, 95, 96, 93, 91, 90,
        52, 64, 63, 71, 67, 68, 61, 58, 59, 53, 56, 55,
        65, 66, 62, 70, 69, 57, 60, 54,
    )

    def Initialize(self, plugin_options: dict):
        """Store the options and make sure a model is loaded.

        If options were set before and their ``"devicename"`` differs from
        the new one, the current model is released first so it is rebuilt
        on the new device.
        """
        if self.plugin_options is not None:
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.model_dmdnet is None:
            self.model_dmdnet = self.create(self.plugin_options["devicename"])

    # temp_frame already cropped+aligned, bbox not
    def Run(
        self, source_faceset: "FaceSet", target_face: "Face", temp_frame: "Frame"
    ) -> "tuple[Frame, int]":
        """Enhance ``temp_frame`` and report how much larger the result is.

        Returns:
            ``(frame, scale_factor)`` where ``frame`` is the enhanced image
            cast to ``uint8`` and ``scale_factor`` is the integer ratio of
            output width to input width.
        """
        input_size = temp_frame.shape[1]
        result = self.enhance_face(source_faceset, temp_frame, target_face)
        scale_factor = int(result.shape[1] / input_size)
        return result.astype(np.uint8), scale_factor

    def Release(self):
        """Drop the reference to the loaded model."""
        self.model_dmdnet = None

    def landmarks106_to_68(self, pt106):
        """Pick the 68-point landmark subset out of a 106-point sequence."""
        return [pt106[index] for index in self._MAP_106_TO_68]

    def check_bbox(self, imgs, boxes):
        """Debug helper: draw the four component boxes on each image.

        One file ``dmdnet_NN.png`` is written to the current directory per
        image. ``boxes`` is reshaped to ``(-1, 4, 4)``; images are rescaled
        from [-1, 1] to [0, 255] before drawing.
        """
        boxes = boxes.view(-1, 4, 4)
        colors = [(0, 255, 0), (0, 255, 0), (255, 255, 0), (255, 0, 0)]
        for i, (img, box) in enumerate(zip(imgs, boxes)):
            img = (img + 1) / 2 * 255
            img2 = img.permute(1, 2, 0).float().cpu().flip(2).numpy().copy()
            for idx, point in enumerate(box):
                cv2.rectangle(
                    img2,
                    (int(point[0]), int(point[1])),
                    (int(point[2]), int(point[3])),
                    color=colors[idx],
                    thickness=2,
                )
            cv2.imwrite("dmdnet_{:02d}.png".format(i), img2)

    def trans_points2d(self, pts, M):
        """Apply the affine matrix ``M`` to every 2-D point in ``pts``.

        Each point is extended to ``[x, y, 1]`` and multiplied by ``M``; the
        first two components of the product are kept. The result is a
        float32 array with the same shape as ``pts``.
        """
        new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
        for i, pt in enumerate(pts):
            homogeneous = np.array([pt[0], pt[1], 1.0], dtype=np.float32)
            new_pts[i] = np.dot(M, homogeneous)[0:2]
        return new_pts

    def _prepare_face_input(self, image, face):
        """Bring one face image into the form the model consumes.

        The 106-point landmarks of ``face`` are reduced to 68 points. If the
        image is not 512x512 it is resized, and the landmarks are mapped
        through ``face.matrix`` scaled by ``512 / width``. A 2-D (single
        channel) image is expanded to three channels.

        Returns:
            ``(image, tensor, locations)``: the possibly resized image, the
            result of ``read_img_tensor`` on it, and the component boxes from
            ``get_component_location``.
        """
        landmarks = np.asarray(self.landmarks106_to_68(face.landmark_2d_106))

        if image.shape[0] != 512 or image.shape[1] != 512:
            # scale to 512x512
            scale_factor = 512 / image.shape[1]
            M = face.matrix * scale_factor
            landmarks = self.trans_points2d(landmarks, M)
            image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_AREA)

        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)  # GGG

        return image, read_img_tensor(image), get_component_location(landmarks)

    @staticmethod
    def _tensor_to_frame(result):
        """Convert a 1xCxHxW model output in [-1, 1] to an HxWxC uint8 image.

        NOTE(review): ``flip(2)`` and ``COLOR_RGB2BGR`` each reverse the
        channel order, so for 3-channel data the two cancel out. Both are
        kept to preserve the existing behaviour.
        """
        image = (result * 0.5 + 0.5).squeeze(0).permute(1, 2, 0).flip(2)
        image = np.clip(image.float().cpu().numpy(), 0, 1) * 255.0
        return cv2.cvtColor(image.astype("uint8"), cv2.COLOR_RGB2BGR)

    def enhance_face(self, ref_faceset: "FaceSet", temp_frame, face: "Face"):
        """Run DMDNet on ``temp_frame`` and return the enhanced image.

        With more than one face in ``ref_faceset`` a specific dictionary is
        built from the reference images and handed to the model; otherwise
        only the generic path is used. The specific result is preferred
        when the model returns one.

        If the model call raises, the error is printed and the (possibly
        resized) input frame is returned unchanged.
        """
        temp_frame, lq, LQLocs = self._prepare_face_input(temp_frame, face)

        if len(ref_faceset.faces) > 1:
            # specific: build the per-identity memory from the references
            ref_tensors = []
            ref_locs = []
            for i, ref_face in enumerate(ref_faceset.faces):
                _, ref_tensor, locs = self._prepare_face_input(
                    ref_faceset.ref_images[i], ref_face
                )
                ref_tensors.append(ref_tensor)
                ref_locs.append(locs.unsqueeze(0))

            # Inference only: no autograd state is needed for the dictionary.
            with torch.no_grad():
                SpMem256, SpMem128, SpMem64 = (
                    self.model_dmdnet.generate_specific_dictionary(
                        sp_imgs=torch.cat(ref_tensors, dim=0).to(self.torchdevice),
                        sp_locs=torch.cat(ref_locs, dim=0),
                    )
                )
            SpMem256Para = dict(SpMem256)
            SpMem128Para = dict(SpMem128)
            SpMem64Para = dict(SpMem64)
        else:
            # generic
            SpMem256Para, SpMem128Para, SpMem64Para = None, None, None

        with torch.no_grad(), THREAD_LOCK_DMDNET:
            try:
                GenericResult, SpecificResult = self.model_dmdnet(
                    lq=lq.to(self.torchdevice),
                    loc=LQLocs.unsqueeze(0),
                    sp_256=SpMem256Para,
                    sp_128=SpMem128Para,
                    sp_64=SpMem64Para,
                )
            except Exception as e:
                # Deliberate best effort: fall back to the unenhanced frame.
                print(
                    f"Error {e} there may be something wrong with the detected component locations."
                )
                return temp_frame

        chosen = SpecificResult if SpecificResult is not None else GenericResult
        return self._tensor_to_frame(chosen)

    def create(self, devicename):
        """Build a DMDNet on ``devicename`` and load ``./models/DMDNet.pth``.

        Also records the device in ``self.torchdevice``. The model is put in
        eval mode before being returned.

        NOTE(review): ``strict=False`` means keys missing from or unexpected
        in the checkpoint are ignored rather than reported.
        NOTE(review): ``torch.load`` unpickles the file; only load
        checkpoints from a trusted source.
        """
        self.torchdevice = torch.device(devicename)
        model_dmdnet = DMDNet().to(self.torchdevice)
        weights = torch.load("./models/DMDNet.pth", map_location=self.torchdevice)
        model_dmdnet.load_state_dict(weights, strict=False)
        model_dmdnet.eval()
        return model_dmdnet
def read_img_tensor(Img=None):  # rgb -1~1
    """Convert an HxWxC image array into a 1xCxHxW float tensor.

    Values are divided by 255 and then normalised with mean 0.5 and std 0.5
    per channel, so 0-255 input ends up in [-1, 1]. Three channels are
    expected (three mean/std entries are passed to ``normalize``).
    """
    channels_first = Img.transpose((2, 0, 1)) / 255.0
    tensor = torch.from_numpy(channels_first).float()
    half = [0.5, 0.5, 0.5]
    normalize(tensor, half, half, inplace=True)
    return tensor.unsqueeze(0)
def _eye_location(landmarks, eye_idx, eye_brow_idx):
    """Return the integer [x1, y1, x2, y2] box around one eye.

    The box is centred on the mean of the eye points. Its extent above the
    centre is 1.3x the distance to the topmost eye/brow point, its extent
    below is that value divided by 1.9, and its width equals its height.
    """
    centre = np.mean(landmarks[eye_idx], 0)
    above = (centre[1] - np.min(landmarks[eye_brow_idx, 1])) * 1.3
    below = above / 1.9
    half_width = (above + below) / 2
    return np.hstack(
        (centre - [half_width, above] + 1, centre + [half_width, below])
    ).astype(int)


def get_component_location(Landmarks, re_read=False):
    """Compute boxes for left eye, right eye, nose and mouth from landmarks.

    Args:
        Landmarks: 68 (x, y) landmark points as an array-like, or, when
            ``re_read`` is true, the path of a text file holding the
            whitespace-separated coordinates.
        re_read: read the landmarks from the file named by ``Landmarks``.

    Returns:
        A float tensor of shape (4, 4) whose rows are ``[x1, y1, x2, y2]``
        for left eye, right eye, nose and mouth, in that order. Coordinates
        are truncated to integers before the conversion to float.

    The caller's array is not modified; clamping is done on a copy.
    """
    if re_read:
        with open(Landmarks, "r") as f:
            # split() tolerates repeated blanks, tabs and trailing newlines.
            values = [float(token) for line in f for token in line.split()]
        Landmarks = np.reshape(np.array(values), [-1, 2])  # 68*2

    # Keep every coordinate inside [8, 504]; np.clip returns a new array.
    Landmarks = np.clip(np.asarray(Landmarks), 8, 504)

    Map_LE_B = list(np.hstack((range(17, 22), range(36, 42))))
    Map_RE_B = list(np.hstack((range(22, 27), range(42, 48))))
    Map_LE = list(range(36, 42))
    Map_RE = list(range(42, 48))
    Map_NO = list(range(29, 36))
    Map_MO = list(range(48, 68))

    # eyes
    Location_LE = _eye_location(Landmarks, Map_LE, Map_LE_B)
    Location_RE = _eye_location(Landmarks, Map_RE, Map_RE_B)

    # nose
    Mean_NO = np.mean(Landmarks[Map_NO], 0)
    L_NO1 = (
        np.max([Mean_NO[0] - Landmarks[31][0], Landmarks[35][0] - Mean_NO[0]])
    ) * 1.25
    L_NO2 = (Landmarks[33][1] - Mean_NO[1]) * 1.1
    L_NO_xy = L_NO1 * 2
    L_NO_lt = [L_NO_xy / 2, L_NO_xy - L_NO2]
    L_NO_rb = [L_NO_xy / 2, L_NO2]
    Location_NO = np.hstack((Mean_NO - L_NO_lt + 1, Mean_NO + L_NO_rb)).astype(int)

    # mouth: square box, half-size at least 16 before the 1.1 margin
    Mean_MO = np.mean(Landmarks[Map_MO], 0)
    mouth_points = Landmarks[Map_MO]
    largest_span = np.max(np.max(mouth_points, 0) - np.min(mouth_points, 0))
    L_MO = np.max((largest_span / 2, 16)) * 1.1
    MO_O = Mean_MO - L_MO + 1
    MO_T = np.minimum(Mean_MO + L_MO, 510)
    Location_MO = np.hstack((MO_O, MO_T)).astype(int)

    boxes = np.stack([Location_LE, Location_RE, Location_NO, Location_MO])
    return torch.from_numpy(boxes).float()
def calc_mean_std_4D(feat, eps=1e-5):
    """Return per-sample, per-channel mean and std of a 4-D tensor.

    The last two dimensions are flattened and reduced. ``eps`` is added to
    the variance before the square root to avoid a zero std. Both results
    have shape (N, C, 1, 1). Raises ``AssertionError`` for non-4-D input.
    """
    shape = feat.size()
    assert len(shape) == 4
    n, c = shape[:2]
    flat = feat.view(n, c, -1)
    mean = flat.mean(dim=2).view(n, c, 1, 1)
    std = (flat.var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    return mean, std
def adaptive_instance_normalization_4D(content_feat, style_feat):
    """Re-normalise ``content_feat`` to the channel statistics of ``style_feat``.

    content_feat is ref feature, style is degradate feature. The content is
    standardised with its own per-channel mean/std and then scaled and
    shifted with the style's per-channel std/mean.
    """
    shape = content_feat.size()
    content_mean, content_std = calc_mean_std_4D(content_feat)
    style_mean, style_std = calc_mean_std_4D(style_feat)
    standardised = (content_feat - content_mean.expand(shape)) / content_std.expand(
        shape
    )
    return standardised * style_std.expand(shape) + style_mean.expand(shape)
def convU(
    in_channels,
    out_channels,
    conv_layer,
    norm_layer,
    kernel_size=3,
    stride=1,
    dilation=1,
    bias=True,
):
    """Build conv -> LeakyReLU(0.2) -> conv, both convs spectral-normalised.

    Both convolutions share kernel size, stride, dilation and bias setting;
    padding is ``((kernel_size - 1) // 2) * dilation``. ``norm_layer`` is
    accepted but not used.
    """
    shared = dict(
        kernel_size=kernel_size,
        stride=stride,
        dilation=dilation,
        padding=((kernel_size - 1) // 2) * dilation,
        bias=bias,
    )
    first = SpectralNorm(conv_layer(in_channels, out_channels, **shared))
    second = SpectralNorm(conv_layer(out_channels, out_channels, **shared))
    return nn.Sequential(first, nn.LeakyReLU(0.2), second)
class MSDilateBlock(nn.Module):
    """Four parallel ``convU`` branches with separate dilations, fused residually.

    The branch outputs are concatenated on the channel axis, reduced back to
    ``in_channels`` by one spectral-normalised convolution, and added to the
    input.
    """

    def __init__(
        self,
        in_channels,
        conv_layer=nn.Conv2d,
        norm_layer=nn.BatchNorm2d,
        kernel_size=3,
        dilation=(1, 1, 1, 1),
        bias=True,
    ):
        """Create the branches.

        Args:
            in_channels: channel count of input and output.
            conv_layer: convolution class used for every layer.
            norm_layer: forwarded to ``convU``.
            kernel_size: kernel size of every convolution.
            dilation: sequence of four dilation rates, one per branch. The
                default is a tuple so it cannot be mutated between calls.
            bias: whether the convolutions carry a bias.
        """
        super().__init__()

        def branch(rate):
            return convU(
                in_channels,
                in_channels,
                conv_layer,
                norm_layer,
                kernel_size,
                dilation=rate,
                bias=bias,
            )

        # Attribute names are state-dict keys; keep them as they are.
        self.conv1 = branch(dilation[0])
        self.conv2 = branch(dilation[1])
        self.conv3 = branch(dilation[2])
        self.conv4 = branch(dilation[3])
        self.convi = SpectralNorm(
            conv_layer(
                in_channels * 4,
                in_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
                bias=bias,
            )
        )

    def forward(self, x):
        """Return ``convi(cat(branches)) + x``."""
        branches = [self.conv1(x), self.conv2(x), self.conv3(x), self.conv4(x)]
        return self.convi(torch.cat(branches, 1)) + x
class AdaptiveInstanceNorm(nn.Module):
    """Instance-normalise the input, then apply the style's channel statistics."""

    def __init__(self, in_channel):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        """Return ``std(style) * norm(input) + mean(style)`` per channel."""
        target_shape = input.size()
        mean, std = calc_mean_std_4D(style)
        normalised = self.norm(input)
        return std.expand(target_shape) * normalised + mean.expand(target_shape)
class NoiseInjection(nn.Module):
    """Add noise to an image, scaled by a learned per-channel weight."""

    def __init__(self, channel):
        super().__init__()
        # One weight per channel, initialised to zero.
        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        """Return ``image + weight * noise``.

        When ``noise`` is None, single-channel standard-normal noise with
        the image's batch and spatial size is drawn.
        """
        if noise is None:
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()
        return image + self.weight * noise
class StyledUpBlock(nn.Module):
    """Convolve, modulate with a style feature map, then upsample by two."""

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=3,
        padding=1,
        upsample=False,
        noise_inject=False,
    ):
        """Create the layers.

        With ``upsample`` the first stage is upsample -> conv -> LeakyReLU;
        otherwise it is conv -> LeakyReLU -> conv. The scale and shift
        branches map the style from ``in_channel`` to ``out_channel``
        channels with fixed 3x3 convolutions.
        """
        super().__init__()
        self.noise_inject = noise_inject

        def sn_conv(cin, cout):
            return SpectralNorm(nn.Conv2d(cin, cout, kernel_size, padding=padding))

        def upsample2x():
            return nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

        def style_branch():
            return nn.Sequential(
                SpectralNorm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)),
                nn.LeakyReLU(0.2),
                SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
            )

        # Attribute names and their order are kept: they are state-dict keys.
        if upsample:
            self.conv1 = nn.Sequential(
                upsample2x(),
                sn_conv(in_channel, out_channel),
                nn.LeakyReLU(0.2),
            )
        else:
            self.conv1 = nn.Sequential(
                sn_conv(in_channel, out_channel),
                nn.LeakyReLU(0.2),
                sn_conv(out_channel, out_channel),
            )
        self.convup = nn.Sequential(
            upsample2x(),
            sn_conv(out_channel, out_channel),
            nn.LeakyReLU(0.2),
            sn_conv(out_channel, out_channel),
        )
        if self.noise_inject:
            self.noise1 = NoiseInjection(out_channel)
        self.lrelu1 = nn.LeakyReLU(0.2)
        self.ScaleModel1 = style_branch()
        self.ShiftModel1 = style_branch()

    def forward(self, input, style):
        """Return ``convup(lrelu(conv1(input)) * Scale(style) + Shift(style))``.

        Noise is injected before ``convup`` when ``noise_inject`` is set.
        """
        features = self.lrelu1(self.conv1(input))
        shift = self.ShiftModel1(style)
        scale = self.ScaleModel1(style)
        modulated = features * scale + shift
        if self.noise_inject:
            modulated = self.noise1(modulated, noise=None)
        return self.convup(modulated)
| #################################################################### | |
| ###############Face Dictionary Generator | |
| #################################################################### | |
def AttentionBlock(in_channel):
    """Build conv3x3 -> LeakyReLU(0.2) -> conv3x3 that keeps ``in_channel`` channels.

    Both convolutions are spectral-normalised, with stride 1 and padding 1.
    """
    layers = [
        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
        nn.LeakyReLU(0.2),
        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
    ]
    return nn.Sequential(*layers)
class DilateResBlock(nn.Module):
    """Residual block of two dilated, spectral-normalised 3x3 convolutions."""

    def __init__(self, dim, dilation=(5, 3)):
        """Create the block.

        Args:
            dim: channel count of input and output.
            dilation: two dilation rates, one per convolution. Padding
                equals the rate, so the spatial size is preserved. The
                default is a tuple so it cannot be mutated between calls.
        """
        super().__init__()
        first, second = dilation[0], dilation[1]
        self.Res = nn.Sequential(
            SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3 - 1) // 2) * first, first)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3 - 1) // 2) * second, second)),
        )

    def forward(self, x):
        """Return ``x + Res(x)``."""
        return x + self.Res(x)
class KeyValue(nn.Module):
    """Two parallel conv heads producing a key and a value from one input."""

    def __init__(self, indim, keydim, valdim):
        super().__init__()

        def head(outdim):
            # conv3x3 -> LeakyReLU(0.2) -> conv3x3, spectral-normalised
            return nn.Sequential(
                SpectralNorm(
                    nn.Conv2d(
                        indim, outdim, kernel_size=(3, 3), padding=(1, 1), stride=1
                    )
                ),
                nn.LeakyReLU(0.2),
                SpectralNorm(
                    nn.Conv2d(
                        outdim, outdim, kernel_size=(3, 3), padding=(1, 1), stride=1
                    )
                ),
            )

        self.Key = head(keydim)
        self.Value = head(valdim)

    def forward(self, x):
        """Return ``(Key(x), Value(x))``."""
        key = self.Key(x)
        value = self.Value(x)
        return key, value
class MaskAttention(nn.Module):
    """Fuse three same-sized feature maps into one map with ``indim`` channels.

    Each input goes through its own conv pair that reduces it to
    ``indim // 3`` channels; the three results are concatenated and mapped
    back to ``indim`` channels.
    """

    def __init__(self, indim):
        super().__init__()
        third = indim // 3

        def conv3x3(cin, cout):
            return SpectralNorm(
                nn.Conv2d(cin, cout, kernel_size=(3, 3), padding=(1, 1), stride=1)
            )

        def conv_pair(cin, cout):
            return nn.Sequential(
                conv3x3(cin, cout), nn.LeakyReLU(0.2), conv3x3(cout, cout)
            )

        # Attribute names are state-dict keys; keep them as they are.
        self.conv1 = conv_pair(indim, third)
        self.conv2 = conv_pair(indim, third)
        self.conv3 = conv_pair(indim, third)
        self.convCat = conv_pair(third * 3, indim)

    def forward(self, x, y, z):
        """Return ``convCat(cat(conv1(x), conv2(y), conv3(z)))``."""
        reduced = [self.conv1(x), self.conv2(y), self.conv3(z)]
        return self.convCat(torch.cat(reduced, dim=1))
class Query(nn.Module):
    """Conv head mapping ``indim`` input channels to ``quedim`` query channels."""

    def __init__(self, indim, quedim):
        super().__init__()
        first = nn.Conv2d(indim, quedim, kernel_size=(3, 3), padding=(1, 1), stride=1)
        second = nn.Conv2d(quedim, quedim, kernel_size=(3, 3), padding=(1, 1), stride=1)
        self.Query = nn.Sequential(
            SpectralNorm(first),
            nn.LeakyReLU(0.2),
            SpectralNorm(second),
        )

    def forward(self, x):
        """Return the query features for ``x``."""
        query = self.Query(x)
        return query
def roi_align_self(input, location, target_size):
    """Crop one box per batch item and resize every crop to a square.

    Args:
        input: feature tensor indexed as (batch, channel, row, column).
        location: per-item boxes; row ``i`` is read as ``[x1, y1, x2, y2]``.
        target_size: scalar exposing ``.item()``; side length of the output.

    Returns:
        The bilinearly resized crops concatenated along the batch axis.
    """
    side = target_size.item()
    out_hw = (side, side)
    crops = []
    for i in range(input.size(0)):
        x1, y1, x2, y2 = (
            location[i, 0],
            location[i, 1],
            location[i, 2],
            location[i, 3],
        )
        crop = input[i : i + 1, :, y1:y2, x1:x2]
        crops.append(
            F.interpolate(crop, out_hw, mode="bilinear", align_corners=False)
        )
    return torch.cat(crops, 0)
class FeatureExtractor(nn.Module):
    """Three-level convolutional encoder with per-component crops and queries.

    ``forward`` returns the feature maps after one, two and three stride-2
    convolutions (keys ``f256``, ``f128``, ``f64``), the left-eye, right-eye
    and mouth crops taken from each of them, and the query features computed
    from those crops.
    """

    def __init__(self, ngf=64, key_scale=4):
        """Create the encoder.

        Args:
            ngf: channel count of the first level; the next two levels use
                ``2 * ngf`` and ``4 * ngf``.
            key_scale: each query head outputs ``channels // key_scale``
                channels.
        """
        super().__init__()
        self.key_scale = key_scale
        self.part_sizes = np.array([80, 80, 50, 110])
        self.feature_sizes = np.array([256, 128, 64])

        def conv_pair(cin, cout, stride):
            # conv3x3(stride) -> LeakyReLU(0.2) -> conv3x3, spectral-normalised
            return nn.Sequential(
                SpectralNorm(nn.Conv2d(cin, cout, 3, stride, 1)),
                nn.LeakyReLU(0.2),
                SpectralNorm(nn.Conv2d(cout, cout, 3, 1, 1)),
            )

        # Attribute names are state-dict keys; keep them as they are.
        self.conv1 = conv_pair(3, ngf, 2)
        self.conv2 = conv_pair(ngf, ngf, 1)
        self.res1 = DilateResBlock(ngf, [5, 3])
        self.res2 = DilateResBlock(ngf, [5, 3])

        self.conv3 = conv_pair(ngf, ngf * 2, 2)
        self.conv4 = conv_pair(ngf * 2, ngf * 2, 1)
        self.res3 = DilateResBlock(ngf * 2, [3, 1])
        self.res4 = DilateResBlock(ngf * 2, [3, 1])

        self.conv5 = conv_pair(ngf * 2, ngf * 4, 2)
        self.conv6 = conv_pair(ngf * 4, ngf * 4, 1)
        self.res5 = DilateResBlock(ngf * 4, [1, 1])
        self.res6 = DilateResBlock(ngf * 4, [1, 1])

        # Registers LE_256_Q, RE_256_Q, MO_256_Q, LE_128_Q, ... MO_64_Q.
        for scale, channels in ((256, ngf), (128, ngf * 2), (64, ngf * 4)):
            for part in ("LE", "RE", "MO"):
                setattr(
                    self,
                    f"{part}_{scale}_Q",
                    Query(channels, channels // self.key_scale),
                )

    def forward(self, img, locs):
        """Encode ``img`` and extract component crops and queries.

        Args:
            img: input image batch with 3 channels.
            locs: boxes indexed as ``locs[:, component, :]``; components 0,
                1 and 3 (left eye, right eye, mouth) are used, component 2
                is not.

        Returns:
            A dict with keys ``f256/f128/f64``, ``{le,re,mo}{256,128,64}``
            and ``{le,re,mo}_{256,128,64}_q``.
        """
        boxes = {
            "le": locs[:, 0, :].int().cpu().numpy(),
            "re": locs[:, 1, :].int().cpu().numpy(),
            "mo": locs[:, 3, :].int().cpu().numpy(),
        }
        part_size = {
            "le": self.part_sizes[0],
            "re": self.part_sizes[1],
            "mo": self.part_sizes[3],
        }

        f256 = self.res2(self.conv2(self.res1(self.conv1(img))))
        f128 = self.res4(self.conv4(self.res3(self.conv3(f256))))
        f64 = self.res6(self.conv6(self.res5(self.conv5(f128))))

        out = {"f256": f256, "f128": f128, "f64": f64}

        # Boxes and crop sizes are divided by each level's downsampling
        # factor. roi_align_self only slices and resamples, so the feature
        # maps can be passed without cloning.
        pyramid = ((256, f256, 2), (128, f128, 4), (64, f64, 8))
        parts = ("le", "re", "mo")
        for scale, feat, factor in pyramid:
            for part in parts:
                out[f"{part}{scale}"] = roi_align_self(
                    feat, boxes[part] // factor, part_size[part] // factor
                )

        for scale, _, _ in pyramid:
            for part in parts:
                query = getattr(self, f"{part.upper()}_{scale}_Q")
                out[f"{part}_{scale}_q"] = query(out[f"{part}{scale}"])

        return out
| class DMDNet(nn.Module): | |
def __init__(self, ngf=64, banks_num=128):
    """Create the sub-networks and the generic memory buffers.

    Args:
        ngf: base channel count used for the key/value heads and the
            decoder blocks.
        banks_num: number of entries in every generic memory buffer.
    """
    super().__init__()
    self.part_sizes = np.array([80, 80, 50, 110])  # size for 512
    self.feature_sizes = np.array([256, 128, 64])  # size for 512
    self.banks_num = banks_num
    self.key_scale = 4

    self.E_lq = FeatureExtractor(key_scale=self.key_scale)
    self.E_hq = FeatureExtractor(key_scale=self.key_scale)

    # Attribute names are state-dict keys. The loops register, in this
    # order: LE/RE/MO_{256,128,64}_KV, then *_Attention, then *_Mask.
    parts = ("LE", "RE", "MO")
    for res, channels in ((256, ngf), (128, ngf * 2), (64, ngf * 4)):
        for part in parts:
            setattr(
                self,
                f"{part}_{res}_KV",
                KeyValue(channels, channels // self.key_scale, channels),
            )
    for res, channels in ((256, 64), (128, 128), (64, 256)):
        for part in parts:
            setattr(self, f"{part}_{res}_Attention", AttentionBlock(channels))
    for res, channels in ((256, 64), (128, 128), (64, 256)):
        for part in parts:
            setattr(self, f"{part}_{res}_Mask", MaskAttention(channels))

    self.MSDilate = MSDilateBlock(ngf * 4, dilation=[4, 3, 2, 1])

    self.up1 = StyledUpBlock(ngf * 4, ngf * 2, noise_inject=False)
    self.up2 = StyledUpBlock(ngf * 2, ngf, noise_inject=False)
    self.up3 = StyledUpBlock(ngf, ngf, noise_inject=False)
    self.up4 = nn.Sequential(
        SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
        nn.LeakyReLU(0.2),
        UpResBlock(ngf),
        UpResBlock(ngf),
        SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)),
        nn.Tanh(),
    )

    # define generic memory, revise register_buffer to register_parameter
    # for backward update.
    # Registers {le,re,mo}_{256,128,64}_mem_{key,value}; per level the
    # three key buffers come first, then the three value buffers.
    # Columns: level, key channels, value channels, eye side, mouth side.
    memory_layout = (
        ("256", 16, 64, 40, 55),
        ("128", 32, 128, 20, 27),
        ("64", 64, 256, 10, 13),
    )
    for res, key_channels, value_channels, eye, mouth in memory_layout:
        for kind, channels in (("key", key_channels), ("value", value_channels)):
            for part, side in (("le", eye), ("re", eye), ("mo", mouth)):
                self.register_buffer(
                    f"{part}_{res}_mem_{kind}",
                    torch.randn(banks_num, channels, side, side),
                )
def readMem(self, k, v, q):
    """Read from a memory bank with soft attention.

    ``q`` is correlated with every key in ``k`` via ``F.conv2d``; the
    similarities are divided by ``sqrt(number of keys)`` and soft-maxed over
    the keys. The output is the weighted sum of the values ``v``.

    Returns:
        ``(mem_out, max_inds)``: the read-out with shape
        ``(batch, *v.shape[1:])`` and the squeezed index of the
        highest-scoring key per batch item.
    """
    sim = F.conv2d(q, k)
    score = F.softmax(sim / sqrt(sim.size(1)), dim=1)  # B * S * 1 * 1
    batch = score.size(0)
    slots, channels, height, width = v.size()
    # One (B x S) @ (S x C*H*W) product; avoids copying the whole value
    # bank once per batch item.
    weights = score.view(batch, -1)
    mem_out = torch.matmul(weights, v.view(slots, -1)).view(
        batch, channels, height, width
    )
    max_inds = torch.argmax(score, dim=1).squeeze()
    return mem_out, max_inds
def memorize(self, img, locs):
    """Build key/value memories from an image batch.

    ``img`` and ``locs`` are passed to ``self.E_hq``; each left-eye,
    right-eye and mouth crop it returns is fed to the matching
    ``*_KV`` head.

    Returns:
        ``(Mem256, Mem128, Mem64)``, three dicts with the keys
        ``{LE,RE,MO}<res>Key`` and ``{LE,RE,MO}<res>Value``.
    """
    fs = self.E_hq(img, locs)

    memories = []
    for res in ("256", "128", "64"):
        bank = {}
        for part in ("LE", "RE", "MO"):
            head = getattr(self, f"{part}_{res}_KV")
            key, value = head(fs[f"{part.lower()}{res}"])
            bank[f"{part}{res}Key"] = key
            bank[f"{part}{res}Value"] = value
        memories.append(bank)

    Mem256, Mem128, Mem64 = memories
    return Mem256, Mem128, Mem64
def enhancer(self, fs_in, sp_256=None, sp_128=None, sp_64=None):
    """Read memory features for each part/scale query and normalise them.

    For every part tag (``le``, ``re``, ``mo``) at every scale (256, 128,
    64) the query ``fs_in["<part>_<scale>_q"]`` is looked up in the memory
    stored on ``self`` (``self.<part>_<scale>_mem_key`` /
    ``self.<part>_<scale>_mem_value``) through ``self.readMem``.  When all
    three ``sp_*`` dictionaries are supplied, the same query is also read
    from them and the two results are mixed with the output of the matching
    ``self.<PART>_<scale>_Mask`` module.  The result is finally passed,
    together with ``fs_in["<part><scale>"]``, through
    ``adaptive_instance_normalization_4D``.

    NOTE(review): the tags presumably mean left eye / right eye / mouth;
    that is inferred from the names only.

    Args:
        fs_in: mapping holding the ``<part>_<scale>_q`` queries and the
            ``<part><scale>`` features.
        sp_256, sp_128, sp_64: optional mappings with ``<PART><scale>Key``
            and ``<PART><scale>Value`` entries.  They are used only when
            none of the three is ``None``.

    Returns:
        A 6-tuple ``(EnMem256, EnMem128, EnMem64, Ind256, Ind128, Ind64)``.
        Each ``EnMem*`` dict maps ``"<PART><scale>Norm"`` to the normalised
        feature; each ``Ind*`` dict maps ``"LE"``/``"RE"``/``"MO"`` to the
        second value returned by ``self.readMem`` for the built-in memory.
    """
    scales = (256, 128, 64)
    parts = ("le", "re", "mo")
    # Scale-major, part-minor: every stage below walks the nine slots in
    # this order.
    slots = [(part, scale) for scale in scales for part in parts]

    queries = {slot: fs_in[f"{slot[0]}_{slot[1]}_q"] for slot in slots}

    # Pass 1: read the memory stored on the module itself.
    generic = {}
    indices = {}
    for part, scale in slots:
        generic[part, scale], indices[part, scale] = self.readMem(
            getattr(self, f"{part}_{scale}_mem_key"),
            getattr(self, f"{part}_{scale}_mem_value"),
            queries[part, scale],
        )

    if sp_256 is None or sp_128 is None or sp_64 is None:
        # No complete set of external memories: use the built-in reads as-is.
        fused = dict(generic)
    else:
        # Pass 2: read the supplied memories and gate them against pass 1.
        banks = {256: sp_256, 128: sp_128, 64: sp_64}
        fused = {}
        for scale in scales:
            bank = banks[scale]
            recalled = {}
            for part in parts:
                tag = f"{part.upper()}{scale}"
                recalled[part], _ = self.readMem(
                    bank[f"{tag}Key"], bank[f"{tag}Value"], queries[part, scale]
                )
            for part in parts:
                gate_net = getattr(self, f"{part.upper()}_{scale}_Mask")
                gate = gate_net(
                    fs_in[f"{part}{scale}"], recalled[part], generic[part, scale]
                )
                fused[part, scale] = (
                    gate * recalled[part] + (1 - gate) * generic[part, scale]
                )

    normed = {
        slot: adaptive_instance_normalization_4D(
            fused[slot], fs_in[f"{slot[0]}{slot[1]}"]
        )
        for slot in slots
    }

    def pack_norm(scale):
        return {f"{part.upper()}{scale}Norm": normed[part, scale] for part in parts}

    def pack_inds(scale):
        return {part.upper(): indices[part, scale] for part in parts}

    return (
        pack_norm(256),
        pack_norm(128),
        pack_norm(64),
        pack_inds(256),
        pack_inds(128),
        pack_inds(64),
    )
def reconstruct(self, fs_in, locs, memstar):
    """Paste refined part features back into the feature maps and decode.

    For each part tag (``le``, ``re``, ``mo``) and scale (256, 128, 64) the
    refined feature is
    ``Attention(mem_norm - part_feat) * mem_norm + part_feat`` where
    ``Attention`` is ``self.<PART>_<scale>_Attention``, ``mem_norm`` comes
    from ``memstar`` and ``part_feat`` is ``fs_in["<part><scale>"]``.  Each
    refined feature is bilinearly resized to its box and written into a
    clone of ``fs_in["f<scale>"]``.  The clones are then fed through
    ``self.MSDilate``, ``self.up1`` .. ``self.up4``.

    Args:
        fs_in: mapping with the ``f256``/``f128``/``f64`` feature maps and
            the nine ``<part><scale>`` part features.
        locs: tensor indexed as ``locs[:, row, :]``.  Rows 0, 1 and 3 are
            used for ``le``, ``re`` and ``mo``; row 2 is not read here.
            Columns 1 and 3 bound the slice on axis 2 of the feature map,
            columns 0 and 2 bound it on axis 3 (presumably
            ``x0, y0, x1, y1`` -- TODO confirm with the caller).
        memstar: sequence of three dicts (index 0 -> 256, 1 -> 128,
            2 -> 64) holding the ``"<PART><scale>Norm"`` entries.

    Returns:
        Whatever ``self.up4`` returns.
    """
    scales = (256, 128, 64)
    parts = ("le", "re", "mo")

    # Fetch all nine normalised memory features up front, scale-major.
    normed = {
        (part, scale): memstar[level][f"{part.upper()}{scale}Norm"]
        for level, scale in enumerate(scales)
        for part in parts
    }

    refined = {}
    for (part, scale), mem_norm in normed.items():
        gate_net = getattr(self, f"{part.upper()}_{scale}_Attention")
        refined[part, scale] = (
            gate_net(mem_norm - fs_in[f"{part}{scale}"]) * mem_norm
            + fs_in[f"{part}{scale}"]
        )

    # Boxes stay integer torch tensors (no numpy conversion); per the
    # original note, newer Torch had trouble with numpy-wrapped values here.
    boxes = {
        part: locs[:, row, :].cpu().int()
        for part, row in (("le", 0), ("re", 1), ("mo", 3))
    }

    canvases = {scale: fs_in[f"f{scale}"].clone() for scale in scales}
    # Box coordinates are divided by this factor to index each feature map.
    shrink = {256: 2, 128: 4, 64: 8}

    for i in range(fs_in["f256"].size(0)):
        for scale in scales:
            step = shrink[scale]
            canvas = canvases[scale]
            # Paste order le -> re -> mo matters if boxes overlap.
            for part in parts:
                x0, y0, x1, y1 = (boxes[part][i, k] // step for k in range(4))
                canvas[i : i + 1, :, y0:y1, x0:x1] = F.interpolate(
                    refined[part, scale][i : i + 1, :, :, :].clone(),
                    (y1 - y0, x1 - x0),
                    mode="bilinear",
                    align_corners=False,
                )

    ms_in_64 = self.MSDilate(fs_in["f64"].clone())
    decoded = self.up1(ms_in_64, canvases[64])
    decoded = self.up2(decoded, canvases[128])
    decoded = self.up3(decoded, canvases[256])
    return self.up4(decoded)
def generate_specific_dictionary(self, sp_imgs=None, sp_locs=None):
    """Return ``self.memorize(sp_imgs, sp_locs)`` unchanged.

    This is a thin alias: both arguments are forwarded positionally and the
    result of ``memorize`` is handed straight back to the caller.
    """
    specific_memory = self.memorize(sp_imgs, sp_locs)
    return specific_memory
def forward(self, lq=None, loc=None, sp_256=None, sp_128=None, sp_64=None):
    """Run the encoder, memory lookup and reconstruction.

    ``self.E_lq(lq, loc)`` produces the feature dict.  It is always enhanced
    once without external memories and reconstructed (``GeOut``).  When all
    three ``sp_*`` memories are given, it is enhanced and reconstructed a
    second time with them (``GSOut``); otherwise ``GSOut`` is ``None``.

    Args:
        lq: input passed to ``self.E_lq``.
        loc: locations passed to ``self.E_lq`` and ``self.reconstruct``.
        sp_256, sp_128, sp_64: optional memories forwarded to
            ``self.enhancer``; used only when none of them is ``None``.

    Returns:
        Tuple ``(GeOut, GSOut)``.

    Raises:
        Exception: whatever ``self.E_lq`` raises.  The message is printed
            and the exception is re-raised.
    """
    try:
        fs_in = self.E_lq(lq, loc)  # low quality images
    except Exception as e:
        print(e)
        # Without this re-raise the code below fails with an
        # UnboundLocalError on ``fs_in``, which hides the real error.
        raise

    GeMemNorm256, GeMemNorm128, GeMemNorm64, Ind256, Ind128, Ind64 = self.enhancer(
        fs_in
    )
    GeOut = self.reconstruct(
        fs_in, loc, memstar=[GeMemNorm256, GeMemNorm128, GeMemNorm64]
    )

    if sp_256 is not None and sp_128 is not None and sp_64 is not None:
        GSMemNorm256, GSMemNorm128, GSMemNorm64, _, _, _ = self.enhancer(
            fs_in, sp_256, sp_128, sp_64
        )
        GSOut = self.reconstruct(
            fs_in, loc, memstar=[GSMemNorm256, GSMemNorm128, GSMemNorm64]
        )
    else:
        GSOut = None
    return GeOut, GSOut
class UpResBlock(nn.Module):
    """Residual block: ``x + Model(x)``.

    ``Model`` is conv -> LeakyReLU(0.2) -> conv.  Each conv is built as
    ``conv_layer(dim, dim, 3, 1, 1)`` and wrapped in ``SpectralNorm``.
    With the default ``nn.Conv2d`` those positional arguments are a 3x3
    kernel, stride 1 and padding 1.

    Args:
        dim: channel count passed as the first two arguments of both convs.
        conv_layer: factory used to build the two convolutions.
        norm_layer: accepted for signature compatibility; not used.
    """

    def __init__(self, dim, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d):
        super().__init__()

        def sn_conv():
            # One spectrally normalised dim -> dim convolution.
            return SpectralNorm(conv_layer(dim, dim, 3, 1, 1))

        stages = [sn_conv(), nn.LeakyReLU(0.2), sn_conv()]
        self.Model = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.Model(x)
        return x + residual