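"""Wrapper for the CoNR model: bundles the UDP parser, target pose encoder,
shader, and RGBA decoder sub-networks and handles device placement,
distributed (DDP) setup, checkpoint I/O, and the forward passes."""
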
import os
import torch

from model.backbone import ResEncUnet
from model.shader import CINN
from model.decoder_small import RGBADecoderNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
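# Clamp UDP predictions to the valid [0, 1] range.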
def UDPClip(x):
    return torch.clamp(x, min=0, max=1)
|
|
|
|
class CoNR:
    def __init__(self, args):
        self.args = args

        self.udpparsernet = ResEncUnet(
            backbone_name='resnet50_danbo',
            classes=4,
            pretrained=(args.local_rank == 0),
            parametric_upsampling=True,
            decoder_filters=(512, 384, 256, 128, 32),
            map_location=device
        )
        self.target_pose_encoder = ResEncUnet(
            backbone_name='resnet18_danbo-4',
            classes=1,
            pretrained=(args.local_rank == 0),
            parametric_upsampling=True,
            decoder_filters=(512, 384, 256, 128, 32),
            map_location=device
        )
        self.DIM_SHADER_REFERENCE = 4
        self.shader = CINN(self.DIM_SHADER_REFERENCE)
        self.rgbadecodernet = RGBADecoderNet()
        self.device()
        self.parser_ckpt = None
|
|
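    # Wrap each sub-network in DistributedDataParallel when running with
    # torch.distributed (one process per GPU, selected via args.local_rank).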
    def dist(self):
        args = self.args
        if args.distributed:
            self.udpparsernet = torch.nn.parallel.DistributedDataParallel(
                self.udpparsernet,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=False,
                find_unused_parameters=True
            )
            self.target_pose_encoder = torch.nn.parallel.DistributedDataParallel(
                self.target_pose_encoder,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=False,
                find_unused_parameters=True
            )
            self.shader = torch.nn.parallel.DistributedDataParallel(
                self.shader,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=True
            )
            self.rgbadecodernet = torch.nn.parallel.DistributedDataParallel(
                self.rgbadecodernet,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=True
            )
|
|
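    # Load the four sub-network checkpoints (udpparsernet.pth, target_pose_encoder.pth,
    # shader.pth, rgbadecodernet.pth) from the given directory onto the current device.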
    def load_model(self, path):
        self.udpparsernet.load_state_dict(
            torch.load('{}/udpparsernet.pth'.format(path), map_location=device))
        self.target_pose_encoder.load_state_dict(
            torch.load('{}/target_pose_encoder.pth'.format(path), map_location=device))
        self.shader.load_state_dict(
            torch.load('{}/shader.pth'.format(path), map_location=device))
        self.rgbadecodernet.load_state_dict(
            torch.load('{}/rgbadecodernet.pth'.format(path), map_location=device))
|
|
    def save_model(self, ite_num):
        self._save_pth(self.udpparsernet,
                       model_name="udpparsernet", ite_num=ite_num)
        self._save_pth(self.target_pose_encoder,
                       model_name="target_pose_encoder", ite_num=ite_num)
        self._save_pth(self.shader,
                       model_name="shader", ite_num=ite_num)
        self._save_pth(self.rgbadecodernet,
                       model_name="rgbadecodernet", ite_num=ite_num)
|
|
    def _save_pth(self, net, model_name, ite_num):
        args = self.args
        to_save = None
        if args.distributed:
            # Only rank 0 writes checkpoints; unwrap the DDP module first.
            if args.local_rank == 0:
                to_save = net.module.state_dict()
        else:
            to_save = net.state_dict()
        if to_save is not None:
            model_dir = os.path.join(
                os.getcwd(), 'saved_models', args.model_name,
                'checkpoints', 'itr_%d' % ite_num)
            os.makedirs(model_dir, exist_ok=True)
            torch.save(to_save, os.path.join(model_dir, model_name + ".pth"))
|
|
    def train(self):
        self.udpparsernet.train()
        self.target_pose_encoder.train()
        self.shader.train()
        self.rgbadecodernet.train()
|
|
    def eval(self):
        self.udpparsernet.eval()
        self.target_pose_encoder.eval()
        self.shader.eval()
        self.rgbadecodernet.eval()
|
|
    def device(self):
        self.udpparsernet.to(device)
        self.target_pose_encoder.to(device)
        self.shader.to(device)
        self.rgbadecodernet.to(device)
|
|
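    # Move input tensors to the device and scale 8-bit images/masks to [0, 1];
    # also record the number of samples and of pose/character images per sample.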
    def data_norm_image(self, data):
        with torch.cuda.amp.autocast(enabled=False):
            for name in ["character_labels", "pose_label"]:
                if name in data:
                    data[name] = data[name].to(
                        device, non_blocking=True).float()
            for name in ["pose_images", "pose_mask", "character_images", "character_masks"]:
                if name in data:
                    data[name] = data[name].to(
                        device, non_blocking=True).float() / 255.0
            if "pose_images" in data:
                data["num_pose_images"] = data["pose_images"].shape[1]
                data["num_samples"] = data["pose_images"].shape[0]
            if "character_images" in data:
                data["num_character_images"] = data["character_images"].shape[1]
                data["num_samples"] = data["character_images"].shape[0]
            if "pose_images" in data and "character_images" in data:
                assert (data["pose_images"].shape[0] ==
                        data["character_images"].shape[0])
        return data
|
|
    def reset_charactersheet(self):
        self.parser_ckpt = None
|
|
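    # Run one inference step: parse the character sheet (cached in self.parser_ckpt
    # after the first call; cleared via reset_charactersheet), parse and encode the
    # target pose, then run the shader and the RGBA decoder.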
    def model_step(self, data, training=False):
        self.eval()
        with torch.cuda.amp.autocast(enabled=False):
            pred = {}
            if self.parser_ckpt:
                pred["parser"] = self.parser_ckpt
            else:
                pred = self.character_parser_forward(data, pred)
                self.parser_ckpt = pred["parser"]
            pred = self.pose_parser_sc_forward(data, pred)
            pred = self.shader_pose_encoder_forward(data, pred)
            pred = self.shader_forward(data, pred)
        return pred
|
|
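    # Combine the character-sheet features with the encoded target pose in the
    # shader, then decode the result into a premultiplied RGBA image
    # ("y_weighted_warp_decoded_rgba").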
    def shader_forward(self, data, pred={}):
        assert ("num_character_images" in data), "ERROR: No Character Sheet input."

        character_images_rgb_nmchw, num_character_images = data[
            "character_images"], data["num_character_images"]

        shader_character_a_nmchw = data.get("character_masks")
        if shader_character_a_nmchw is None:
            # Fall back to the alpha channel predicted by the character parser.
            shader_character_a_nmchw = pred["parser"]["pred"][:, :, 3:4, :, :]
        assert not torch.any(
            torch.mean(shader_character_a_nmchw, (0, 2, 3, 4)) >= 0.95), (
            "ERROR: No transparent area found in the image, PLEASE separate the foreground "
            "of the input character sheets. The website waifucutout.com is recommended to "
            "automatically cut out the foreground.")

        x_reference_rgb_a = torch.cat([
            shader_character_a_nmchw * character_images_rgb_nmchw,
            shader_character_a_nmchw
        ], 2)
        assert (x_reference_rgb_a.shape[2] == self.DIM_SHADER_REFERENCE)

        x_reference_features = pred["parser"]["features"]

        retdic = self.shader(
            pred["shader"]["target_pose_features"], x_reference_rgb_a, x_reference_features)
        pred["shader"].update(retdic)

        dec_out = self.rgbadecodernet(retdic["y_last_remote_features"])
        y_weighted_x_reference_RGB = dec_out[:, 0:3, :, :]
        y_weighted_mask_A = dec_out[:, 3:4, :, :]
        y_weighted_warp_decoded_rgba = torch.cat(
            (y_weighted_x_reference_RGB * y_weighted_mask_A, y_weighted_mask_A), dim=1)
        assert (y_weighted_warp_decoded_rgba.shape[1] == 4)
        assert (y_weighted_warp_decoded_rgba.shape[-1] ==
                character_images_rgb_nmchw.shape[-1])

        pred["shader"]["y_weighted_warp_decoded_rgba"] = y_weighted_warp_decoded_rgba
        return pred
|
|
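    # Run the UDP parser on every character-sheet image, keeping both the predicted
    # UDP/alpha maps and the intermediate encoder features for the shader.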
    def character_parser_forward(self, data, pred={}):
        if not ("num_character_images" in data and "character_images" in data):
            return pred
        pred["parser"] = {"pred": None}

        inputs_rgb_nmchw, num_samples, num_character_images = data[
            "character_images"], data["num_samples"], data["num_character_images"]
        inputs_rgb_fchw = inputs_rgb_nmchw.view(
            (num_samples * num_character_images, inputs_rgb_nmchw.shape[2],
             inputs_rgb_nmchw.shape[3], inputs_rgb_nmchw.shape[4]))

        encoder_out, features = self.udpparsernet(
            (inputs_rgb_fchw - 0.6) / 0.2970)

        pred["parser"]["features"] = [features_out.view(
            (num_samples, num_character_images, features_out.shape[1],
             features_out.shape[2], features_out.shape[3])) for features_out in features]

        if encoder_out is not None:
            pred["parser"]["pred"] = UDPClip(encoder_out.view(
                (num_samples, num_character_images, encoder_out.shape[1],
                 encoder_out.shape[2], encoder_out.shape[3])))

        return pred
|
|
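    # Run the UDP parser on the pose images and keep the prediction for the first
    # pose image of each sample; the encoder features are discarded here.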
    def pose_parser_sc_forward(self, data, pred={}):
        if not ("num_pose_images" in data and "pose_images" in data):
            return pred
        inputs_aug_rgb_nmchw, num_samples, num_pose_images = data[
            "pose_images"], data["num_samples"], data["num_pose_images"]
        inputs_aug_rgb_fchw = inputs_aug_rgb_nmchw.view(
            (num_samples * num_pose_images, inputs_aug_rgb_nmchw.shape[2],
             inputs_aug_rgb_nmchw.shape[3], inputs_aug_rgb_nmchw.shape[4]))

        encoder_out, _ = self.udpparsernet(
            (inputs_aug_rgb_fchw - 0.6) / 0.2970)

        encoder_out = encoder_out.view(
            (num_samples, num_pose_images, encoder_out.shape[1],
             encoder_out.shape[2], encoder_out.shape[3]))

        pred["pose_parser"] = {"pred": UDPClip(encoder_out)[:, 0, :, :, :]}

        return pred
|
|
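    # Build the 4-channel target input (SUDP * alpha, alpha) from the provided pose
    # labels/masks, falling back to the pose parser output, and encode it into the
    # pose features consumed by the shader.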
    def shader_pose_encoder_forward(self, data, pred={}):
        pred["shader"] = {}
        if "pose_images" in data:
            pose_images_rgb_nmchw = data["pose_images"]
            target_gt_rgb = pose_images_rgb_nmchw[:, 0, :, :, :]
            pred["shader"]["target_gt_rgb"] = target_gt_rgb

        shader_target_a = None
        if "pose_mask" in data:
            pred["shader"]["target_gt_a"] = data["pose_mask"]
            shader_target_a = data["pose_mask"]

        shader_target_sudp = None
        if "pose_label" in data:
            shader_target_sudp = data["pose_label"][:, :3, :, :]

        if self.args.test_pose_use_parser_udp:
            shader_target_sudp = None
        if shader_target_sudp is None:
            shader_target_sudp = pred["pose_parser"]["pred"][:, 0:3, :, :]

        if shader_target_a is None:
            shader_target_a = pred["pose_parser"]["pred"][:, 3:4, :, :]

        x_target_sudp_a = torch.cat((
            shader_target_sudp * shader_target_a,
            shader_target_a
        ), 1)
        pred["shader"].update({
            "x_target_sudp_a": x_target_sudp_a
        })
        _, features = self.target_pose_encoder(
            (x_target_sudp_a - 0.6) / 0.2970, ret_parser_out=False)

        pred["shader"]["target_pose_features"] = features
        return pred