from typing import Any, List, Callable import cv2 import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.utils.spectral_norm as SpectralNorm import threading from torchvision.ops import roi_align from math import sqrt from torchvision.transforms.functional import normalize from roop.typing import Face, Frame, FaceSet THREAD_LOCK_DMDNET = threading.Lock() class Enhance_DMDNet: plugin_options: dict = None model_dmdnet = None torchdevice = None processorname = "dmdnet" type = "enhance" def Initialize(self, plugin_options: dict): if self.plugin_options is not None: if self.plugin_options["devicename"] != plugin_options["devicename"]: self.Release() self.plugin_options = plugin_options if self.model_dmdnet is None: self.model_dmdnet = self.create(self.plugin_options["devicename"]) # temp_frame already cropped+aligned, bbox not def Run( self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame ) -> Frame: input_size = temp_frame.shape[1] result = self.enhance_face(source_faceset, temp_frame, target_face) scale_factor = int(result.shape[1] / input_size) return result.astype(np.uint8), scale_factor def Release(self): self.model_dmdnet = None # https://stackoverflow.com/a/67174339 def landmarks106_to_68(self, pt106): map106to68 = [ 1, 10, 12, 14, 16, 3, 5, 7, 0, 23, 21, 19, 32, 30, 28, 26, 17, 43, 48, 49, 51, 50, 102, 103, 104, 105, 101, 72, 73, 74, 86, 78, 79, 80, 85, 84, 35, 41, 42, 39, 37, 36, 89, 95, 96, 93, 91, 90, 52, 64, 63, 71, 67, 68, 61, 58, 59, 53, 56, 55, 65, 66, 62, 70, 69, 57, 60, 54, ] pt68 = [] for i in range(68): index = map106to68[i] pt68.append(pt106[index]) return pt68 def check_bbox(self, imgs, boxes): boxes = boxes.view(-1, 4, 4) colors = [(0, 255, 0), (0, 255, 0), (255, 255, 0), (255, 0, 0)] i = 0 for img, box in zip(imgs, boxes): img = (img + 1) / 2 * 255 img2 = img.permute(1, 2, 0).float().cpu().flip(2).numpy().copy() for idx, point in enumerate(box): cv2.rectangle( img2, (int(point[0]), int(point[1])), (int(point[2]), int(point[3])), color=colors[idx], thickness=2, ) cv2.imwrite("dmdnet_{:02d}.png".format(i), img2) i += 1 def trans_points2d(self, pts, M): new_pts = np.zeros(shape=pts.shape, dtype=np.float32) for i in range(pts.shape[0]): pt = pts[i] new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32) new_pt = np.dot(M, new_pt) new_pts[i] = new_pt[0:2] return new_pts def enhance_face(self, ref_faceset: FaceSet, temp_frame, face: Face): # preprocess start_x, start_y, end_x, end_y = map(int, face["bbox"]) lm106 = face.landmark_2d_106 lq_landmarks = np.asarray(self.landmarks106_to_68(lm106)) if temp_frame.shape[0] != 512 or temp_frame.shape[1] != 512: # scale to 512x512 scale_factor = 512 / temp_frame.shape[1] M = face.matrix * scale_factor lq_landmarks = self.trans_points2d(lq_landmarks, M) temp_frame = cv2.resize( temp_frame, (512, 512), interpolation=cv2.INTER_AREA ) if temp_frame.ndim == 2: temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG # else: # temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB lq = read_img_tensor(temp_frame) LQLocs = get_component_location(lq_landmarks) # self.check_bbox(lq, LQLocs.unsqueeze(0)) # specific, change 1000 to 1 to activate if len(ref_faceset.faces) > 1: SpecificImgs = [] SpecificLocs = [] for i, face in enumerate(ref_faceset.faces): lm106 = face.landmark_2d_106 lq_landmarks = np.asarray(self.landmarks106_to_68(lm106)) ref_image = ref_faceset.ref_images[i] if ref_image.shape[0] != 512 or ref_image.shape[1] != 512: # scale to 512x512 scale_factor = 512 / ref_image.shape[1] M = face.matrix * scale_factor lq_landmarks = self.trans_points2d(lq_landmarks, M) ref_image = cv2.resize( ref_image, (512, 512), interpolation=cv2.INTER_AREA ) if ref_image.ndim == 2: temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG # else: # temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB ref_tensor = read_img_tensor(ref_image) ref_locs = get_component_location(lq_landmarks) # self.check_bbox(ref_tensor, ref_locs.unsqueeze(0)) SpecificImgs.append(ref_tensor) SpecificLocs.append(ref_locs.unsqueeze(0)) SpecificImgs = torch.cat(SpecificImgs, dim=0) SpecificLocs = torch.cat(SpecificLocs, dim=0) # check_bbox(SpecificImgs, SpecificLocs) SpMem256, SpMem128, SpMem64 = ( self.model_dmdnet.generate_specific_dictionary( sp_imgs=SpecificImgs.to(self.torchdevice), sp_locs=SpecificLocs ) ) SpMem256Para = {} SpMem128Para = {} SpMem64Para = {} for k, v in SpMem256.items(): SpMem256Para[k] = v for k, v in SpMem128.items(): SpMem128Para[k] = v for k, v in SpMem64.items(): SpMem64Para[k] = v else: # generic SpMem256Para, SpMem128Para, SpMem64Para = None, None, None with torch.no_grad(): with THREAD_LOCK_DMDNET: try: GenericResult, SpecificResult = self.model_dmdnet( lq=lq.to(self.torchdevice), loc=LQLocs.unsqueeze(0), sp_256=SpMem256Para, sp_128=SpMem128Para, sp_64=SpMem64Para, ) except Exception as e: print( f"Error {e} there may be something wrong with the detected component locations." ) return temp_frame if SpecificResult is not None: save_specific = SpecificResult * 0.5 + 0.5 save_specific = ( save_specific.squeeze(0).permute(1, 2, 0).flip(2) ) # RGB->BGR save_specific = np.clip(save_specific.float().cpu().numpy(), 0, 1) * 255.0 temp_frame = save_specific.astype("uint8") if False: save_generic = GenericResult * 0.5 + 0.5 save_generic = ( save_generic.squeeze(0).permute(1, 2, 0).flip(2) ) # RGB->BGR save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0 check_lq = lq * 0.5 + 0.5 check_lq = check_lq.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR check_lq = np.clip(check_lq.float().cpu().numpy(), 0, 1) * 255.0 cv2.imwrite( "dmdnet_comparison.png", cv2.cvtColor( np.hstack((check_lq, save_generic, save_specific)), cv2.COLOR_RGB2BGR, ), ) else: save_generic = GenericResult * 0.5 + 0.5 save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0 temp_frame = save_generic.astype("uint8") temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_RGB2BGR) # RGB return temp_frame def create(self, devicename): self.torchdevice = torch.device(devicename) model_dmdnet = DMDNet().to(self.torchdevice) weights = torch.load("./models/DMDNet.pth", map_location=self.torchdevice) model_dmdnet.load_state_dict(weights, strict=False) model_dmdnet.eval() num_params = 0 for param in model_dmdnet.parameters(): num_params += param.numel() return model_dmdnet # print('{:>8s} : {}'.format('Using device', device)) # print('{:>8s} : {:.2f}M'.format('Model params', num_params/1e6)) def read_img_tensor(Img=None): # rgb -1~1 Img = Img.transpose((2, 0, 1)) / 255.0 Img = torch.from_numpy(Img).float() normalize(Img, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True) ImgTensor = Img.unsqueeze(0) return ImgTensor def get_component_location(Landmarks, re_read=False): if re_read: ReadLandmark = [] with open(Landmarks, "r") as f: for line in f: tmp = [float(i) for i in line.split(" ") if i != "\n"] ReadLandmark.append(tmp) ReadLandmark = np.array(ReadLandmark) # Landmarks = np.reshape(ReadLandmark, [-1, 2]) # 68*2 Map_LE_B = list(np.hstack((range(17, 22), range(36, 42)))) Map_RE_B = list(np.hstack((range(22, 27), range(42, 48)))) Map_LE = list(range(36, 42)) Map_RE = list(range(42, 48)) Map_NO = list(range(29, 36)) Map_MO = list(range(48, 68)) Landmarks[Landmarks > 504] = 504 Landmarks[Landmarks < 8] = 8 # left eye Mean_LE = np.mean(Landmarks[Map_LE], 0) L_LE1 = Mean_LE[1] - np.min(Landmarks[Map_LE_B, 1]) L_LE1 = L_LE1 * 1.3 L_LE2 = L_LE1 / 1.9 L_LE_xy = L_LE1 + L_LE2 L_LE_lt = [L_LE_xy / 2, L_LE1] L_LE_rb = [L_LE_xy / 2, L_LE2] Location_LE = np.hstack((Mean_LE - L_LE_lt + 1, Mean_LE + L_LE_rb)).astype(int) # right eye Mean_RE = np.mean(Landmarks[Map_RE], 0) L_RE1 = Mean_RE[1] - np.min(Landmarks[Map_RE_B, 1]) L_RE1 = L_RE1 * 1.3 L_RE2 = L_RE1 / 1.9 L_RE_xy = L_RE1 + L_RE2 L_RE_lt = [L_RE_xy / 2, L_RE1] L_RE_rb = [L_RE_xy / 2, L_RE2] Location_RE = np.hstack((Mean_RE - L_RE_lt + 1, Mean_RE + L_RE_rb)).astype(int) # nose Mean_NO = np.mean(Landmarks[Map_NO], 0) L_NO1 = ( np.max([Mean_NO[0] - Landmarks[31][0], Landmarks[35][0] - Mean_NO[0]]) ) * 1.25 L_NO2 = (Landmarks[33][1] - Mean_NO[1]) * 1.1 L_NO_xy = L_NO1 * 2 L_NO_lt = [L_NO_xy / 2, L_NO_xy - L_NO2] L_NO_rb = [L_NO_xy / 2, L_NO2] Location_NO = np.hstack((Mean_NO - L_NO_lt + 1, Mean_NO + L_NO_rb)).astype(int) # mouth Mean_MO = np.mean(Landmarks[Map_MO], 0) L_MO = ( np.max( ( np.max(np.max(Landmarks[Map_MO], 0) - np.min(Landmarks[Map_MO], 0)) / 2, 16, ) ) * 1.1 ) MO_O = Mean_MO - L_MO + 1 MO_T = Mean_MO + L_MO MO_T[MO_T > 510] = 510 Location_MO = np.hstack((MO_O, MO_T)).astype(int) return torch.cat( [ torch.FloatTensor(Location_LE).unsqueeze(0), torch.FloatTensor(Location_RE).unsqueeze(0), torch.FloatTensor(Location_NO).unsqueeze(0), torch.FloatTensor(Location_MO).unsqueeze(0), ], dim=0, ) def calc_mean_std_4D(feat, eps=1e-5): # eps is a small value added to the variance to avoid divide-by-zero. size = feat.size() assert len(size) == 4 N, C = size[:2] feat_var = feat.view(N, C, -1).var(dim=2) + eps feat_std = feat_var.sqrt().view(N, C, 1, 1) feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) return feat_mean, feat_std def adaptive_instance_normalization_4D( content_feat, style_feat ): # content_feat is ref feature, style is degradate feature size = content_feat.size() style_mean, style_std = calc_mean_std_4D(style_feat) content_mean, content_std = calc_mean_std_4D(content_feat) normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand( size ) return normalized_feat * style_std.expand(size) + style_mean.expand(size) def convU( in_channels, out_channels, conv_layer, norm_layer, kernel_size=3, stride=1, dilation=1, bias=True, ): return nn.Sequential( SpectralNorm( conv_layer( in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size - 1) // 2) * dilation, bias=bias, ) ), nn.LeakyReLU(0.2), SpectralNorm( conv_layer( out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size - 1) // 2) * dilation, bias=bias, ) ), ) class MSDilateBlock(nn.Module): def __init__( self, in_channels, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1, 1, 1, 1], bias=True, ): super(MSDilateBlock, self).__init__() self.conv1 = convU( in_channels, in_channels, conv_layer, norm_layer, kernel_size, dilation=dilation[0], bias=bias, ) self.conv2 = convU( in_channels, in_channels, conv_layer, norm_layer, kernel_size, dilation=dilation[1], bias=bias, ) self.conv3 = convU( in_channels, in_channels, conv_layer, norm_layer, kernel_size, dilation=dilation[2], bias=bias, ) self.conv4 = convU( in_channels, in_channels, conv_layer, norm_layer, kernel_size, dilation=dilation[3], bias=bias, ) self.convi = SpectralNorm( conv_layer( in_channels * 4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias, ) ) def forward(self, x): conv1 = self.conv1(x) conv2 = self.conv2(x) conv3 = self.conv3(x) conv4 = self.conv4(x) cat = torch.cat([conv1, conv2, conv3, conv4], 1) out = self.convi(cat) + x return out class AdaptiveInstanceNorm(nn.Module): def __init__(self, in_channel): super().__init__() self.norm = nn.InstanceNorm2d(in_channel) def forward(self, input, style): style_mean, style_std = calc_mean_std_4D(style) out = self.norm(input) size = input.size() out = style_std.expand(size) * out + style_mean.expand(size) return out class NoiseInjection(nn.Module): def __init__(self, channel): super().__init__() self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1)) def forward(self, image, noise): if noise is None: b, c, h, w = image.shape noise = image.new_empty(b, 1, h, w).normal_() return image + self.weight * noise class StyledUpBlock(nn.Module): def __init__( self, in_channel, out_channel, kernel_size=3, padding=1, upsample=False, noise_inject=False, ): super().__init__() self.noise_inject = noise_inject if upsample: self.conv1 = nn.Sequential( nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), SpectralNorm( nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding) ), nn.LeakyReLU(0.2), ) else: self.conv1 = nn.Sequential( SpectralNorm( nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding) ), nn.LeakyReLU(0.2), SpectralNorm( nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding) ), ) self.convup = nn.Sequential( nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), SpectralNorm( nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding) ), nn.LeakyReLU(0.2), SpectralNorm( nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding) ), ) if self.noise_inject: self.noise1 = NoiseInjection(out_channel) self.lrelu1 = nn.LeakyReLU(0.2) self.ScaleModel1 = nn.Sequential( SpectralNorm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2), SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), ) self.ShiftModel1 = nn.Sequential( SpectralNorm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2), SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), ) def forward(self, input, style): out = self.conv1(input) out = self.lrelu1(out) Shift1 = self.ShiftModel1(style) Scale1 = self.ScaleModel1(style) out = out * Scale1 + Shift1 if self.noise_inject: out = self.noise1(out, noise=None) outup = self.convup(out) return outup #################################################################### ###############Face Dictionary Generator #################################################################### def AttentionBlock(in_channel): return nn.Sequential( SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2), SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), ) class DilateResBlock(nn.Module): def __init__(self, dim, dilation=[5, 3]): super(DilateResBlock, self).__init__() self.Res = nn.Sequential( SpectralNorm( nn.Conv2d(dim, dim, 3, 1, ((3 - 1) // 2) * dilation[0], dilation[0]) ), nn.LeakyReLU(0.2), SpectralNorm( nn.Conv2d(dim, dim, 3, 1, ((3 - 1) // 2) * dilation[1], dilation[1]) ), ) def forward(self, x): out = x + self.Res(x) return out class KeyValue(nn.Module): def __init__(self, indim, keydim, valdim): super(KeyValue, self).__init__() self.Key = nn.Sequential( SpectralNorm( nn.Conv2d(indim, keydim, kernel_size=(3, 3), padding=(1, 1), stride=1) ), nn.LeakyReLU(0.2), SpectralNorm( nn.Conv2d(keydim, keydim, kernel_size=(3, 3), padding=(1, 1), stride=1) ), ) self.Value = nn.Sequential( SpectralNorm( nn.Conv2d(indim, valdim, kernel_size=(3, 3), padding=(1, 1), stride=1) ), nn.LeakyReLU(0.2), SpectralNorm( nn.Conv2d(valdim, valdim, kernel_size=(3, 3), padding=(1, 1), stride=1) ), ) def forward(self, x): return self.Key(x), self.Value(x) class MaskAttention(nn.Module): def __init__(self, indim): super(MaskAttention, self).__init__() self.conv1 = nn.Sequential( SpectralNorm( nn.Conv2d( indim, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1 ) ), nn.LeakyReLU(0.2), SpectralNorm( nn.Conv2d( indim // 3, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1 ) ), ) self.conv2 = nn.Sequential( SpectralNorm( nn.Conv2d( indim, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1 ) ), nn.LeakyReLU(0.2), SpectralNorm( nn.Conv2d( indim // 3, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1 ) ), ) self.conv3 = nn.Sequential( SpectralNorm( nn.Conv2d( indim, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1 ) ), nn.LeakyReLU(0.2), SpectralNorm( nn.Conv2d( indim // 3, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1 ) ), ) self.convCat = nn.Sequential( SpectralNorm( nn.Conv2d( indim // 3 * 3, indim, kernel_size=(3, 3), padding=(1, 1), stride=1 ) ), nn.LeakyReLU(0.2), SpectralNorm( nn.Conv2d(indim, indim, kernel_size=(3, 3), padding=(1, 1), stride=1) ), ) def forward(self, x, y, z): c1 = self.conv1(x) c2 = self.conv2(y) c3 = self.conv3(z) return self.convCat(torch.cat([c1, c2, c3], dim=1)) class Query(nn.Module): def __init__(self, indim, quedim): super(Query, self).__init__() self.Query = nn.Sequential( SpectralNorm( nn.Conv2d(indim, quedim, kernel_size=(3, 3), padding=(1, 1), stride=1) ), nn.LeakyReLU(0.2), SpectralNorm( nn.Conv2d(quedim, quedim, kernel_size=(3, 3), padding=(1, 1), stride=1) ), ) def forward(self, x): return self.Query(x) def roi_align_self(input, location, target_size): test = (target_size.item(), target_size.item()) return torch.cat( [ F.interpolate( input[ i : i + 1, :, location[i, 1] : location[i, 3], location[i, 0] : location[i, 2], ], test, mode="bilinear", align_corners=False, ) for i in range(input.size(0)) ], 0, ) class FeatureExtractor(nn.Module): def __init__(self, ngf=64, key_scale=4): # super().__init__() self.key_scale = 4 self.part_sizes = np.array([80, 80, 50, 110]) # self.feature_sizes = np.array([256, 128, 64]) # self.conv1 = nn.Sequential( SpectralNorm(nn.Conv2d(3, ngf, 3, 2, 1)), nn.LeakyReLU(0.2), SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)), ) self.conv2 = nn.Sequential( SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)), nn.LeakyReLU(0.2), SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)), ) self.res1 = DilateResBlock(ngf, [5, 3]) self.res2 = DilateResBlock(ngf, [5, 3]) self.conv3 = nn.Sequential( SpectralNorm(nn.Conv2d(ngf, ngf * 2, 3, 2, 1)), nn.LeakyReLU(0.2), SpectralNorm(nn.Conv2d(ngf * 2, ngf * 2, 3, 1, 1)), ) self.conv4 = nn.Sequential( SpectralNorm(nn.Conv2d(ngf * 2, ngf * 2, 3, 1, 1)), nn.LeakyReLU(0.2), SpectralNorm(nn.Conv2d(ngf * 2, ngf * 2, 3, 1, 1)), ) self.res3 = DilateResBlock(ngf * 2, [3, 1]) self.res4 = DilateResBlock(ngf * 2, [3, 1]) self.conv5 = nn.Sequential( SpectralNorm(nn.Conv2d(ngf * 2, ngf * 4, 3, 2, 1)), nn.LeakyReLU(0.2), SpectralNorm(nn.Conv2d(ngf * 4, ngf * 4, 3, 1, 1)), ) self.conv6 = nn.Sequential( SpectralNorm(nn.Conv2d(ngf * 4, ngf * 4, 3, 1, 1)), nn.LeakyReLU(0.2), SpectralNorm(nn.Conv2d(ngf * 4, ngf * 4, 3, 1, 1)), ) self.res5 = DilateResBlock(ngf * 4, [1, 1]) self.res6 = DilateResBlock(ngf * 4, [1, 1]) self.LE_256_Q = Query(ngf, ngf // self.key_scale) self.RE_256_Q = Query(ngf, ngf // self.key_scale) self.MO_256_Q = Query(ngf, ngf // self.key_scale) self.LE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale) self.RE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale) self.MO_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale) self.LE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale) self.RE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale) self.MO_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale) def forward(self, img, locs): le_location = locs[:, 0, :].int().cpu().numpy() re_location = locs[:, 1, :].int().cpu().numpy() no_location = locs[:, 2, :].int().cpu().numpy() mo_location = locs[:, 3, :].int().cpu().numpy() f1_0 = self.conv1(img) f1_1 = self.res1(f1_0) f2_0 = self.conv2(f1_1) f2_1 = self.res2(f2_0) f3_0 = self.conv3(f2_1) f3_1 = self.res3(f3_0) f4_0 = self.conv4(f3_1) f4_1 = self.res4(f4_0) f5_0 = self.conv5(f4_1) f5_1 = self.res5(f5_0) f6_0 = self.conv6(f5_1) f6_1 = self.res6(f6_0) ####ROI Align le_part_256 = roi_align_self( f2_1.clone(), le_location // 2, self.part_sizes[0] // 2 ) re_part_256 = roi_align_self( f2_1.clone(), re_location // 2, self.part_sizes[1] // 2 ) mo_part_256 = roi_align_self( f2_1.clone(), mo_location // 2, self.part_sizes[3] // 2 ) le_part_128 = roi_align_self( f4_1.clone(), le_location // 4, self.part_sizes[0] // 4 ) re_part_128 = roi_align_self( f4_1.clone(), re_location // 4, self.part_sizes[1] // 4 ) mo_part_128 = roi_align_self( f4_1.clone(), mo_location // 4, self.part_sizes[3] // 4 ) le_part_64 = roi_align_self( f6_1.clone(), le_location // 8, self.part_sizes[0] // 8 ) re_part_64 = roi_align_self( f6_1.clone(), re_location // 8, self.part_sizes[1] // 8 ) mo_part_64 = roi_align_self( f6_1.clone(), mo_location // 8, self.part_sizes[3] // 8 ) le_256_q = self.LE_256_Q(le_part_256) re_256_q = self.RE_256_Q(re_part_256) mo_256_q = self.MO_256_Q(mo_part_256) le_128_q = self.LE_128_Q(le_part_128) re_128_q = self.RE_128_Q(re_part_128) mo_128_q = self.MO_128_Q(mo_part_128) le_64_q = self.LE_64_Q(le_part_64) re_64_q = self.RE_64_Q(re_part_64) mo_64_q = self.MO_64_Q(mo_part_64) return { "f256": f2_1, "f128": f4_1, "f64": f6_1, "le256": le_part_256, "re256": re_part_256, "mo256": mo_part_256, "le128": le_part_128, "re128": re_part_128, "mo128": mo_part_128, "le64": le_part_64, "re64": re_part_64, "mo64": mo_part_64, "le_256_q": le_256_q, "re_256_q": re_256_q, "mo_256_q": mo_256_q, "le_128_q": le_128_q, "re_128_q": re_128_q, "mo_128_q": mo_128_q, "le_64_q": le_64_q, "re_64_q": re_64_q, "mo_64_q": mo_64_q, } class DMDNet(nn.Module): def __init__(self, ngf=64, banks_num=128): super().__init__() self.part_sizes = np.array([80, 80, 50, 110]) # size for 512 self.feature_sizes = np.array([256, 128, 64]) # size for 512 self.banks_num = banks_num self.key_scale = 4 self.E_lq = FeatureExtractor(key_scale=self.key_scale) self.E_hq = FeatureExtractor(key_scale=self.key_scale) self.LE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf) self.RE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf) self.MO_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf) self.LE_128_KV = KeyValue(ngf * 2, ngf * 2 // self.key_scale, ngf * 2) self.RE_128_KV = KeyValue(ngf * 2, ngf * 2 // self.key_scale, ngf * 2) self.MO_128_KV = KeyValue(ngf * 2, ngf * 2 // self.key_scale, ngf * 2) self.LE_64_KV = KeyValue(ngf * 4, ngf * 4 // self.key_scale, ngf * 4) self.RE_64_KV = KeyValue(ngf * 4, ngf * 4 // self.key_scale, ngf * 4) self.MO_64_KV = KeyValue(ngf * 4, ngf * 4 // self.key_scale, ngf * 4) self.LE_256_Attention = AttentionBlock(64) self.RE_256_Attention = AttentionBlock(64) self.MO_256_Attention = AttentionBlock(64) self.LE_128_Attention = AttentionBlock(128) self.RE_128_Attention = AttentionBlock(128) self.MO_128_Attention = AttentionBlock(128) self.LE_64_Attention = AttentionBlock(256) self.RE_64_Attention = AttentionBlock(256) self.MO_64_Attention = AttentionBlock(256) self.LE_256_Mask = MaskAttention(64) self.RE_256_Mask = MaskAttention(64) self.MO_256_Mask = MaskAttention(64) self.LE_128_Mask = MaskAttention(128) self.RE_128_Mask = MaskAttention(128) self.MO_128_Mask = MaskAttention(128) self.LE_64_Mask = MaskAttention(256) self.RE_64_Mask = MaskAttention(256) self.MO_64_Mask = MaskAttention(256) self.MSDilate = MSDilateBlock(ngf * 4, dilation=[4, 3, 2, 1]) self.up1 = StyledUpBlock(ngf * 4, ngf * 2, noise_inject=False) # self.up2 = StyledUpBlock(ngf * 2, ngf, noise_inject=False) # self.up3 = StyledUpBlock(ngf, ngf, noise_inject=False) # self.up4 = nn.Sequential( SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)), nn.LeakyReLU(0.2), UpResBlock(ngf), UpResBlock(ngf), SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)), nn.Tanh(), ) # define generic memory, revise register_buffer to register_parameter for backward update self.register_buffer("le_256_mem_key", torch.randn(128, 16, 40, 40)) self.register_buffer("re_256_mem_key", torch.randn(128, 16, 40, 40)) self.register_buffer("mo_256_mem_key", torch.randn(128, 16, 55, 55)) self.register_buffer("le_256_mem_value", torch.randn(128, 64, 40, 40)) self.register_buffer("re_256_mem_value", torch.randn(128, 64, 40, 40)) self.register_buffer("mo_256_mem_value", torch.randn(128, 64, 55, 55)) self.register_buffer("le_128_mem_key", torch.randn(128, 32, 20, 20)) self.register_buffer("re_128_mem_key", torch.randn(128, 32, 20, 20)) self.register_buffer("mo_128_mem_key", torch.randn(128, 32, 27, 27)) self.register_buffer("le_128_mem_value", torch.randn(128, 128, 20, 20)) self.register_buffer("re_128_mem_value", torch.randn(128, 128, 20, 20)) self.register_buffer("mo_128_mem_value", torch.randn(128, 128, 27, 27)) self.register_buffer("le_64_mem_key", torch.randn(128, 64, 10, 10)) self.register_buffer("re_64_mem_key", torch.randn(128, 64, 10, 10)) self.register_buffer("mo_64_mem_key", torch.randn(128, 64, 13, 13)) self.register_buffer("le_64_mem_value", torch.randn(128, 256, 10, 10)) self.register_buffer("re_64_mem_value", torch.randn(128, 256, 10, 10)) self.register_buffer("mo_64_mem_value", torch.randn(128, 256, 13, 13)) def readMem(self, k, v, q): sim = F.conv2d(q, k) score = F.softmax(sim / sqrt(sim.size(1)), dim=1) # B * S * 1 * 1 6*128 sb, sn, sw, sh = score.size() s_m = score.view(sb, -1).unsqueeze(1) # 2*1*M vb, vn, vw, vh = v.size() v_in = v.view(vb, -1).repeat(sb, 1, 1) # 2*M*(c*w*h) mem_out = torch.bmm(s_m, v_in).squeeze(1).view(sb, vn, vw, vh) max_inds = torch.argmax(score, dim=1).squeeze() return mem_out, max_inds def memorize(self, img, locs): fs = self.E_hq(img, locs) LE256_key, LE256_value = self.LE_256_KV(fs["le256"]) RE256_key, RE256_value = self.RE_256_KV(fs["re256"]) MO256_key, MO256_value = self.MO_256_KV(fs["mo256"]) LE128_key, LE128_value = self.LE_128_KV(fs["le128"]) RE128_key, RE128_value = self.RE_128_KV(fs["re128"]) MO128_key, MO128_value = self.MO_128_KV(fs["mo128"]) LE64_key, LE64_value = self.LE_64_KV(fs["le64"]) RE64_key, RE64_value = self.RE_64_KV(fs["re64"]) MO64_key, MO64_value = self.MO_64_KV(fs["mo64"]) Mem256 = { "LE256Key": LE256_key, "LE256Value": LE256_value, "RE256Key": RE256_key, "RE256Value": RE256_value, "MO256Key": MO256_key, "MO256Value": MO256_value, } Mem128 = { "LE128Key": LE128_key, "LE128Value": LE128_value, "RE128Key": RE128_key, "RE128Value": RE128_value, "MO128Key": MO128_key, "MO128Value": MO128_value, } Mem64 = { "LE64Key": LE64_key, "LE64Value": LE64_value, "RE64Key": RE64_key, "RE64Value": RE64_value, "MO64Key": MO64_key, "MO64Value": MO64_value, } FS256 = {"LE256F": fs["le256"], "RE256F": fs["re256"], "MO256F": fs["mo256"]} FS128 = {"LE128F": fs["le128"], "RE128F": fs["re128"], "MO128F": fs["mo128"]} FS64 = {"LE64F": fs["le64"], "RE64F": fs["re64"], "MO64F": fs["mo64"]} return Mem256, Mem128, Mem64 def enhancer(self, fs_in, sp_256=None, sp_128=None, sp_64=None): le_256_q = fs_in["le_256_q"] re_256_q = fs_in["re_256_q"] mo_256_q = fs_in["mo_256_q"] le_128_q = fs_in["le_128_q"] re_128_q = fs_in["re_128_q"] mo_128_q = fs_in["mo_128_q"] le_64_q = fs_in["le_64_q"] re_64_q = fs_in["re_64_q"] mo_64_q = fs_in["mo_64_q"] ####for 256 le_256_mem_g, le_256_inds = self.readMem( self.le_256_mem_key, self.le_256_mem_value, le_256_q ) re_256_mem_g, re_256_inds = self.readMem( self.re_256_mem_key, self.re_256_mem_value, re_256_q ) mo_256_mem_g, mo_256_inds = self.readMem( self.mo_256_mem_key, self.mo_256_mem_value, mo_256_q ) le_128_mem_g, le_128_inds = self.readMem( self.le_128_mem_key, self.le_128_mem_value, le_128_q ) re_128_mem_g, re_128_inds = self.readMem( self.re_128_mem_key, self.re_128_mem_value, re_128_q ) mo_128_mem_g, mo_128_inds = self.readMem( self.mo_128_mem_key, self.mo_128_mem_value, mo_128_q ) le_64_mem_g, le_64_inds = self.readMem( self.le_64_mem_key, self.le_64_mem_value, le_64_q ) re_64_mem_g, re_64_inds = self.readMem( self.re_64_mem_key, self.re_64_mem_value, re_64_q ) mo_64_mem_g, mo_64_inds = self.readMem( self.mo_64_mem_key, self.mo_64_mem_value, mo_64_q ) if sp_256 is not None and sp_128 is not None and sp_64 is not None: le_256_mem_s, _ = self.readMem( sp_256["LE256Key"], sp_256["LE256Value"], le_256_q ) re_256_mem_s, _ = self.readMem( sp_256["RE256Key"], sp_256["RE256Value"], re_256_q ) mo_256_mem_s, _ = self.readMem( sp_256["MO256Key"], sp_256["MO256Value"], mo_256_q ) le_256_mask = self.LE_256_Mask(fs_in["le256"], le_256_mem_s, le_256_mem_g) le_256_mem = le_256_mask * le_256_mem_s + (1 - le_256_mask) * le_256_mem_g re_256_mask = self.RE_256_Mask(fs_in["re256"], re_256_mem_s, re_256_mem_g) re_256_mem = re_256_mask * re_256_mem_s + (1 - re_256_mask) * re_256_mem_g mo_256_mask = self.MO_256_Mask(fs_in["mo256"], mo_256_mem_s, mo_256_mem_g) mo_256_mem = mo_256_mask * mo_256_mem_s + (1 - mo_256_mask) * mo_256_mem_g le_128_mem_s, _ = self.readMem( sp_128["LE128Key"], sp_128["LE128Value"], le_128_q ) re_128_mem_s, _ = self.readMem( sp_128["RE128Key"], sp_128["RE128Value"], re_128_q ) mo_128_mem_s, _ = self.readMem( sp_128["MO128Key"], sp_128["MO128Value"], mo_128_q ) le_128_mask = self.LE_128_Mask(fs_in["le128"], le_128_mem_s, le_128_mem_g) le_128_mem = le_128_mask * le_128_mem_s + (1 - le_128_mask) * le_128_mem_g re_128_mask = self.RE_128_Mask(fs_in["re128"], re_128_mem_s, re_128_mem_g) re_128_mem = re_128_mask * re_128_mem_s + (1 - re_128_mask) * re_128_mem_g mo_128_mask = self.MO_128_Mask(fs_in["mo128"], mo_128_mem_s, mo_128_mem_g) mo_128_mem = mo_128_mask * mo_128_mem_s + (1 - mo_128_mask) * mo_128_mem_g le_64_mem_s, _ = self.readMem(sp_64["LE64Key"], sp_64["LE64Value"], le_64_q) re_64_mem_s, _ = self.readMem(sp_64["RE64Key"], sp_64["RE64Value"], re_64_q) mo_64_mem_s, _ = self.readMem(sp_64["MO64Key"], sp_64["MO64Value"], mo_64_q) le_64_mask = self.LE_64_Mask(fs_in["le64"], le_64_mem_s, le_64_mem_g) le_64_mem = le_64_mask * le_64_mem_s + (1 - le_64_mask) * le_64_mem_g re_64_mask = self.RE_64_Mask(fs_in["re64"], re_64_mem_s, re_64_mem_g) re_64_mem = re_64_mask * re_64_mem_s + (1 - re_64_mask) * re_64_mem_g mo_64_mask = self.MO_64_Mask(fs_in["mo64"], mo_64_mem_s, mo_64_mem_g) mo_64_mem = mo_64_mask * mo_64_mem_s + (1 - mo_64_mask) * mo_64_mem_g else: le_256_mem = le_256_mem_g re_256_mem = re_256_mem_g mo_256_mem = mo_256_mem_g le_128_mem = le_128_mem_g re_128_mem = re_128_mem_g mo_128_mem = mo_128_mem_g le_64_mem = le_64_mem_g re_64_mem = re_64_mem_g mo_64_mem = mo_64_mem_g le_256_mem_norm = adaptive_instance_normalization_4D(le_256_mem, fs_in["le256"]) re_256_mem_norm = adaptive_instance_normalization_4D(re_256_mem, fs_in["re256"]) mo_256_mem_norm = adaptive_instance_normalization_4D(mo_256_mem, fs_in["mo256"]) ####for 128 le_128_mem_norm = adaptive_instance_normalization_4D(le_128_mem, fs_in["le128"]) re_128_mem_norm = adaptive_instance_normalization_4D(re_128_mem, fs_in["re128"]) mo_128_mem_norm = adaptive_instance_normalization_4D(mo_128_mem, fs_in["mo128"]) ####for 64 le_64_mem_norm = adaptive_instance_normalization_4D(le_64_mem, fs_in["le64"]) re_64_mem_norm = adaptive_instance_normalization_4D(re_64_mem, fs_in["re64"]) mo_64_mem_norm = adaptive_instance_normalization_4D(mo_64_mem, fs_in["mo64"]) EnMem256 = { "LE256Norm": le_256_mem_norm, "RE256Norm": re_256_mem_norm, "MO256Norm": mo_256_mem_norm, } EnMem128 = { "LE128Norm": le_128_mem_norm, "RE128Norm": re_128_mem_norm, "MO128Norm": mo_128_mem_norm, } EnMem64 = { "LE64Norm": le_64_mem_norm, "RE64Norm": re_64_mem_norm, "MO64Norm": mo_64_mem_norm, } Ind256 = {"LE": le_256_inds, "RE": re_256_inds, "MO": mo_256_inds} Ind128 = {"LE": le_128_inds, "RE": re_128_inds, "MO": mo_128_inds} Ind64 = {"LE": le_64_inds, "RE": re_64_inds, "MO": mo_64_inds} return EnMem256, EnMem128, EnMem64, Ind256, Ind128, Ind64 def reconstruct(self, fs_in, locs, memstar): le_256_mem_norm, re_256_mem_norm, mo_256_mem_norm = ( memstar[0]["LE256Norm"], memstar[0]["RE256Norm"], memstar[0]["MO256Norm"], ) le_128_mem_norm, re_128_mem_norm, mo_128_mem_norm = ( memstar[1]["LE128Norm"], memstar[1]["RE128Norm"], memstar[1]["MO128Norm"], ) le_64_mem_norm, re_64_mem_norm, mo_64_mem_norm = ( memstar[2]["LE64Norm"], memstar[2]["RE64Norm"], memstar[2]["MO64Norm"], ) le_256_final = ( self.LE_256_Attention(le_256_mem_norm - fs_in["le256"]) * le_256_mem_norm + fs_in["le256"] ) re_256_final = ( self.RE_256_Attention(re_256_mem_norm - fs_in["re256"]) * re_256_mem_norm + fs_in["re256"] ) mo_256_final = ( self.MO_256_Attention(mo_256_mem_norm - fs_in["mo256"]) * mo_256_mem_norm + fs_in["mo256"] ) le_128_final = ( self.LE_128_Attention(le_128_mem_norm - fs_in["le128"]) * le_128_mem_norm + fs_in["le128"] ) re_128_final = ( self.RE_128_Attention(re_128_mem_norm - fs_in["re128"]) * re_128_mem_norm + fs_in["re128"] ) mo_128_final = ( self.MO_128_Attention(mo_128_mem_norm - fs_in["mo128"]) * mo_128_mem_norm + fs_in["mo128"] ) le_64_final = ( self.LE_64_Attention(le_64_mem_norm - fs_in["le64"]) * le_64_mem_norm + fs_in["le64"] ) re_64_final = ( self.RE_64_Attention(re_64_mem_norm - fs_in["re64"]) * re_64_mem_norm + fs_in["re64"] ) mo_64_final = ( self.MO_64_Attention(mo_64_mem_norm - fs_in["mo64"]) * mo_64_mem_norm + fs_in["mo64"] ) le_location = locs[:, 0, :] re_location = locs[:, 1, :] mo_location = locs[:, 3, :] # Somehow with latest Torch it doesn't like numpy wrappers anymore # le_location = le_location.cpu().int().numpy() # re_location = re_location.cpu().int().numpy() # mo_location = mo_location.cpu().int().numpy() le_location = le_location.cpu().int() re_location = re_location.cpu().int() mo_location = mo_location.cpu().int() up_in_256 = fs_in["f256"].clone() # * 0 up_in_128 = fs_in["f128"].clone() # * 0 up_in_64 = fs_in["f64"].clone() # * 0 for i in range(fs_in["f256"].size(0)): up_in_256[ i : i + 1, :, le_location[i, 1] // 2 : le_location[i, 3] // 2, le_location[i, 0] // 2 : le_location[i, 2] // 2, ] = F.interpolate( le_256_final[i : i + 1, :, :, :].clone(), ( le_location[i, 3] // 2 - le_location[i, 1] // 2, le_location[i, 2] // 2 - le_location[i, 0] // 2, ), mode="bilinear", align_corners=False, ) up_in_256[ i : i + 1, :, re_location[i, 1] // 2 : re_location[i, 3] // 2, re_location[i, 0] // 2 : re_location[i, 2] // 2, ] = F.interpolate( re_256_final[i : i + 1, :, :, :].clone(), ( re_location[i, 3] // 2 - re_location[i, 1] // 2, re_location[i, 2] // 2 - re_location[i, 0] // 2, ), mode="bilinear", align_corners=False, ) up_in_256[ i : i + 1, :, mo_location[i, 1] // 2 : mo_location[i, 3] // 2, mo_location[i, 0] // 2 : mo_location[i, 2] // 2, ] = F.interpolate( mo_256_final[i : i + 1, :, :, :].clone(), ( mo_location[i, 3] // 2 - mo_location[i, 1] // 2, mo_location[i, 2] // 2 - mo_location[i, 0] // 2, ), mode="bilinear", align_corners=False, ) up_in_128[ i : i + 1, :, le_location[i, 1] // 4 : le_location[i, 3] // 4, le_location[i, 0] // 4 : le_location[i, 2] // 4, ] = F.interpolate( le_128_final[i : i + 1, :, :, :].clone(), ( le_location[i, 3] // 4 - le_location[i, 1] // 4, le_location[i, 2] // 4 - le_location[i, 0] // 4, ), mode="bilinear", align_corners=False, ) up_in_128[ i : i + 1, :, re_location[i, 1] // 4 : re_location[i, 3] // 4, re_location[i, 0] // 4 : re_location[i, 2] // 4, ] = F.interpolate( re_128_final[i : i + 1, :, :, :].clone(), ( re_location[i, 3] // 4 - re_location[i, 1] // 4, re_location[i, 2] // 4 - re_location[i, 0] // 4, ), mode="bilinear", align_corners=False, ) up_in_128[ i : i + 1, :, mo_location[i, 1] // 4 : mo_location[i, 3] // 4, mo_location[i, 0] // 4 : mo_location[i, 2] // 4, ] = F.interpolate( mo_128_final[i : i + 1, :, :, :].clone(), ( mo_location[i, 3] // 4 - mo_location[i, 1] // 4, mo_location[i, 2] // 4 - mo_location[i, 0] // 4, ), mode="bilinear", align_corners=False, ) up_in_64[ i : i + 1, :, le_location[i, 1] // 8 : le_location[i, 3] // 8, le_location[i, 0] // 8 : le_location[i, 2] // 8, ] = F.interpolate( le_64_final[i : i + 1, :, :, :].clone(), ( le_location[i, 3] // 8 - le_location[i, 1] // 8, le_location[i, 2] // 8 - le_location[i, 0] // 8, ), mode="bilinear", align_corners=False, ) up_in_64[ i : i + 1, :, re_location[i, 1] // 8 : re_location[i, 3] // 8, re_location[i, 0] // 8 : re_location[i, 2] // 8, ] = F.interpolate( re_64_final[i : i + 1, :, :, :].clone(), ( re_location[i, 3] // 8 - re_location[i, 1] // 8, re_location[i, 2] // 8 - re_location[i, 0] // 8, ), mode="bilinear", align_corners=False, ) up_in_64[ i : i + 1, :, mo_location[i, 1] // 8 : mo_location[i, 3] // 8, mo_location[i, 0] // 8 : mo_location[i, 2] // 8, ] = F.interpolate( mo_64_final[i : i + 1, :, :, :].clone(), ( mo_location[i, 3] // 8 - mo_location[i, 1] // 8, mo_location[i, 2] // 8 - mo_location[i, 0] // 8, ), mode="bilinear", align_corners=False, ) ms_in_64 = self.MSDilate(fs_in["f64"].clone()) fea_up1 = self.up1(ms_in_64, up_in_64) fea_up2 = self.up2(fea_up1, up_in_128) # fea_up3 = self.up3(fea_up2, up_in_256) # output = self.up4(fea_up3) # return output def generate_specific_dictionary(self, sp_imgs=None, sp_locs=None): return self.memorize(sp_imgs, sp_locs) def forward(self, lq=None, loc=None, sp_256=None, sp_128=None, sp_64=None): try: fs_in = self.E_lq(lq, loc) # low quality images except Exception as e: print(e) GeMemNorm256, GeMemNorm128, GeMemNorm64, Ind256, Ind128, Ind64 = self.enhancer( fs_in ) GeOut = self.reconstruct( fs_in, loc, memstar=[GeMemNorm256, GeMemNorm128, GeMemNorm64] ) if sp_256 is not None and sp_128 is not None and sp_64 is not None: GSMemNorm256, GSMemNorm128, GSMemNorm64, _, _, _ = self.enhancer( fs_in, sp_256, sp_128, sp_64 ) GSOut = self.reconstruct( fs_in, loc, memstar=[GSMemNorm256, GSMemNorm128, GSMemNorm64] ) else: GSOut = None return GeOut, GSOut class UpResBlock(nn.Module): def __init__(self, dim, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d): super(UpResBlock, self).__init__() self.Model = nn.Sequential( SpectralNorm(conv_layer(dim, dim, 3, 1, 1)), nn.LeakyReLU(0.2), SpectralNorm(conv_layer(dim, dim, 3, 1, 1)), ) def forward(self, x): out = x + self.Model(x) return out