import os
from abc import abstractmethod
from typing import Dict
from typing import Tuple

import kornia
import lpips
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger

from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel
from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
from HRNet.hrnet import HighResolutionNet
from arcface_torch.backbones.iresnet import iresnet100
from models.discriminator import Discriminator
from models.gan_loss import GANLoss
from models.generator import Generator
from models.init_weight import init_net

# NOTE: TrainConfig is used throughout this module but was never imported;
# the import path below is assumed from the project's configs package.
from configs.train_config import TrainConfig
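

# HifiFace: wrapper around the face-swapping generator. In training mode it also
# builds the discriminator, the frozen identity/shape/landmark extractors, and
# every loss used by `optimize`; in inference mode only the generator is set up.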
class HifiFace:
    def __init__(
        self,
        identity_extractor_config,
        generator_path,
        is_training=False,
        device="cpu",
    ):
        super().__init__()
        self.d_optimizer = None
        self.g_optimizer = None
        self.generator = Generator(identity_extractor_config)
        self.is_training = is_training
        self.device = device
        self.generator_path = generator_path

        if self.is_training:
            self.lr = TrainConfig().lr
            self.use_ddp = TrainConfig().use_ddp
            self.grad_clip = TrainConfig().grad_clip if TrainConfig().grad_clip is not None else 100.0

            self.discriminator = init_net(Discriminator(3))

            self.l1_loss = nn.L1Loss()
            if TrainConfig().eye_hm_loss or TrainConfig().mouth_hm_loss:
                self.mse_loss = nn.MSELoss()
            self.loss_fn_vgg = lpips.LPIPS(net="vgg")
            self.adv_loss = GANLoss()

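            # Frozen 3D face reconstruction network (Deep3DFaceRecon) plus the BFM
            # parametric face model, used to regress 3DMM coefficients for the
            # shape loss.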
            self.f_3d = ReconNetWrapper(net_recon="resnet50", use_last_fc=False)
            self.f_3d.load_state_dict(
                torch.load(identity_extractor_config["f_3d_checkpoint_path"], map_location="cpu")["net_recon"]
            )
            self.f_3d.eval()
            self.face_model = ParametricFaceModel(bfm_folder=identity_extractor_config["bfm_folder"])
            self.face_model.to("cpu")

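            # Frozen ArcFace identity encoder (iresnet100) providing the identity
            # embeddings for the identity loss.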
            self.f_id = iresnet100(pretrained=False, fp16=False)
            self.f_id.load_state_dict(torch.load(identity_extractor_config["f_id_checkpoint_path"], map_location="cpu"))
            self.f_id.eval()

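            # Frozen HRNet facial-landmark network, only needed when the optional
            # eye/mouth heatmap losses are enabled.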
            if TrainConfig().mouth_hm_loss or TrainConfig().eye_hm_loss:
                self.model_mouth = HighResolutionNet()
                checkpoint = torch.load(identity_extractor_config["hrnet_path"], map_location="cpu")
                self.model_mouth.load_state_dict(checkpoint)
                self.model_mouth.eval()

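            # Loss weights: realism terms (adversarial, segmentation, reconstruction,
            # cycle, LPIPS) and SID terms (shape, identity), plus the optional
            # heatmap losses.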
            self.lambda_adv = 1
            self.lambda_seg = 100
            self.lambda_rec = 20
            self.lambda_cyc = 1
            self.lambda_lpips = 5

            self.lambda_shape = 0.5
            self.lambda_id = 5
            self.lambda_eye_hm = 10000.0
            self.lambda_mouth_hm = 10000.0

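            # 5x5 structuring element used to dilate the target face mask before
            # the segmentation loss.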
            self.dilation_kernel = torch.ones(5, 5)

        self.load_checkpoint()
        self.setup(self.device)

    def save(self, path, idx=None):
        os.makedirs(path, exist_ok=True)
        if idx is None:
            g_path = os.path.join(path, "generator.pth")
            d_path = os.path.join(path, "discriminator.pth")
        else:
            g_path = os.path.join(path, f"generator_{idx}.pth")
            d_path = os.path.join(path, f"discriminator_{idx}.pth")
        if self.use_ddp:
            torch.save(self.generator.module.state_dict(), g_path)
            torch.save(self.discriminator.module.state_dict(), d_path)
        else:
            torch.save(self.generator.state_dict(), g_path)
            torch.save(self.discriminator.state_dict(), d_path)

    @abstractmethod
    def load_checkpoint(self):
        pass

    def setup(self, device):
        self.generator.to(device)

        if self.is_training:
            self.discriminator.to(device)
            self.l1_loss.to(device)
            if TrainConfig().eye_hm_loss or TrainConfig().mouth_hm_loss:
                self.mse_loss.to(device)
            self.f_3d.to(device)
            self.f_id.to(device)

            self.loss_fn_vgg.to(device)
            self.face_model.to(device)
            self.adv_loss.to(device)

            if TrainConfig().mouth_hm_loss or TrainConfig().eye_hm_loss:
                self.model_mouth.to(device)
            self.f_3d.requires_grad_(False)
            self.f_id.requires_grad_(False)
            self.loss_fn_vgg.requires_grad_(False)
            if TrainConfig().mouth_hm_loss or TrainConfig().eye_hm_loss:
                self.model_mouth.requires_grad_(False)
            self.dilation_kernel = self.dilation_kernel.to(device)
            if self.use_ddp:
                from torch.nn.parallel import DistributedDataParallel as DDP
                import torch.distributed as dist

                self.generator = DDP(self.generator, device_ids=[device])
                self.discriminator = DDP(self.discriminator, device_ids=[device])

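                # Keep all DDP ranks in sync at startup: rank 0 dumps its weights
                # to disk and every rank reloads the same state dict after the barrier.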
                if dist.get_rank() == 0:
                    torch.save(self.generator.state_dict(), "/tmp/generator.pth")
                    torch.save(self.discriminator.state_dict(), "/tmp/discriminator.pth")

                dist.barrier()
                self.generator.load_state_dict(torch.load("/tmp/generator.pth", map_location=device))
                self.discriminator.load_state_dict(torch.load("/tmp/discriminator.pth", map_location=device))

            self.g_optimizer = torch.optim.AdamW(self.generator.parameters(), lr=self.lr, betas=(0, 0.999))
            self.d_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr=self.lr, betas=(0, 0.999))

    def train(self):
        self.generator.train()
        self.discriminator.train()
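        # The identity extractor inside the generator is kept in eval mode even
        # while the rest of the generator trains.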
        if self.use_ddp:
            self.generator.module.id_extractor.eval()
        else:
            self.generator.id_extractor.eval()

    def eval(self):
        self.generator.eval()
        if self.is_training:
            self.discriminator.eval()

    def train_forward_generator(self, source_img, target_img, target_mask, same_id_mask):
        """
        Compute the generator losses for one training step.

        Parameters:
        -----------
        source_img: torch.Tensor
        target_img: torch.Tensor
        target_mask: torch.Tensor, [B, 1, H, W]
        same_id_mask: torch.Tensor, [B, 1]

        Returns:
        --------
        source_img: torch.Tensor
        target_img: torch.Tensor
        i_cycle: torch.Tensor, cycle-reconstructed image
        i_r: torch.Tensor, swapped result
        m_r: torch.Tensor, predicted face mask
        loss: Dict[str, torch.Tensor], pairs of loss name and loss value
        """
        same = same_id_mask.unsqueeze(-1).unsqueeze(-1)
        i_r, i_low, m_r, m_low = self.generator(source_img, target_img, need_id_grad=False)
        i_cycle, _, _, _ = self.generator(target_img, i_r, need_id_grad=True)
        d_r = self.discriminator(i_r)

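        # Shape-aware identity (SID) supervision: 3DMM coefficients from the frozen
        # reconstruction network and identity embeddings from the frozen ArcFace
        # encoder. Source/target features are computed without gradients; features
        # of the generated images keep gradients so the loss reaches the generator.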
        with torch.no_grad():
            c_s = self.f_3d(F.interpolate(source_img, size=224, mode="bilinear"))
            c_t = self.f_3d(F.interpolate(target_img, size=224, mode="bilinear"))
        c_r = self.f_3d(F.interpolate(i_r, size=224, mode="bilinear"))
        c_low = self.f_3d(F.interpolate(i_low, size=224, mode="bilinear"))
        with torch.no_grad():
            c_fuse = torch.cat((c_s[:, :80], c_t[:, 80:]), dim=1)
            _, _, _, q_fuse = self.face_model.compute_for_render(c_fuse)
        _, _, _, q_r = self.face_model.compute_for_render(c_r)
        _, _, _, q_low = self.face_model.compute_for_render(c_low)
        with torch.no_grad():
            v_id_i_s = F.normalize(
                self.f_id(F.interpolate((source_img - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2
            )
        v_id_i_r = F.normalize(self.f_id(F.interpolate((i_r - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
        v_id_i_low = F.normalize(self.f_id(F.interpolate((i_low - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
        loss_shape = self.l1_loss(q_fuse, q_r) + self.l1_loss(q_fuse, q_low)
        loss_shape = torch.clamp(loss_shape, min=0.0, max=10.0)

        inner_product_r = torch.bmm(v_id_i_s.unsqueeze(1), v_id_i_r.unsqueeze(2)).squeeze()
        inner_product_low = torch.bmm(v_id_i_s.unsqueeze(1), v_id_i_low.unsqueeze(2)).squeeze()
        loss_id = self.l1_loss(torch.ones_like(inner_product_r), inner_product_r) + self.l1_loss(
            torch.ones_like(inner_product_low), inner_product_low
        )
        loss_sid = self.lambda_shape * loss_shape + self.lambda_id * loss_id

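        # Realism losses: cycle consistency, mask segmentation, same-identity
        # reconstruction, LPIPS perceptual, and adversarial terms.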
        loss_cycle = self.l1_loss(target_img, i_cycle)

        target_mask = kornia.morphology.dilation(target_mask, self.dilation_kernel)

        loss_segmentation = self.l1_loss(
            F.interpolate(target_mask, scale_factor=0.25, mode="bilinear"), m_low
        ) + self.l1_loss(target_mask, m_r)

        loss_reconstruction = self.l1_loss(i_r * same, target_img * same) + self.l1_loss(
            i_low * same, F.interpolate(target_img, scale_factor=0.25, mode="bilinear") * same
        )

        loss_perceptual = self.loss_fn_vgg(target_img * same, i_r * same).mean()

        loss_adversarial = self.adv_loss(d_r, True, for_discriminator=False)

        loss_realism = (
            self.lambda_adv * loss_adversarial
            + self.lambda_seg * loss_segmentation
            + self.lambda_rec * loss_reconstruction
            + self.lambda_cyc * loss_cycle
            + self.lambda_lpips * loss_perceptual
        )

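        # Optional eye/mouth landmark-heatmap losses computed with the frozen
        # HRNet landmark detector; the channel slices below (96:98 for the eyes,
        # 76:96 for the mouth) follow the code's own indexing.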
        loss_eye_hm = 0
        loss_mouth_hm = 0
        if TrainConfig().eye_hm_loss or TrainConfig().mouth_hm_loss:
            target_hm = self.model_mouth(target_img)
            r_hm = self.model_mouth(i_r)

            if TrainConfig().eye_hm_loss:
                target_eye_hm = target_hm[:, 96:98, :, :]
                r_eye_hm = r_hm[:, 96:98, :, :]
                loss_eye_hm = self.mse_loss(r_eye_hm, target_eye_hm)
                loss_realism = loss_realism + self.lambda_eye_hm * loss_eye_hm

            if TrainConfig().mouth_hm_loss:
                target_mouth_hm = target_hm[:, 76:96, :, :]
                r_mouth_hm = r_hm[:, 76:96, :, :]
                loss_mouth_hm = self.mse_loss(r_mouth_hm, target_mouth_hm)
                loss_realism = loss_realism + self.lambda_mouth_hm * loss_mouth_hm

        loss_generator = loss_sid + loss_realism

        loss_dict = {
            "loss_shape": loss_shape,
            "loss_id": loss_id,
            "loss_sid": loss_sid,
            "loss_cycle": loss_cycle,
            "loss_segmentation": loss_segmentation,
            "loss_reconstruction": loss_reconstruction,
            "loss_perceptual": loss_perceptual,
            "loss_adversarial": loss_adversarial,
            "loss_realism": loss_realism,
            "loss_generator": loss_generator,
        }
        if TrainConfig().eye_hm_loss:
            loss_dict.update({"loss_eye_hm": loss_eye_hm})
        if TrainConfig().mouth_hm_loss:
            loss_dict.update({"loss_mouth_hm": loss_mouth_hm})
        return (
            source_img,
            target_img,
            i_cycle.detach(),
            i_r.detach(),
            m_r.detach(),
            loss_dict,
        )

    def train_forward_discriminator(self, target_img, i_r):
        """
        Compute the discriminator losses for one training step.

        Parameters:
        -----------
        target_img: torch.Tensor, target face image
        i_r: torch.Tensor, swapped (generated) face image

        Returns:
        --------
        Dict[str, torch.Tensor]: pairs of loss name and loss value
        """
        d_gt = self.discriminator(target_img)
        d_fake = self.discriminator(i_r.detach())
        loss_real = self.adv_loss(d_gt, True)
        loss_fake = self.adv_loss(d_fake, False)

        loss_discriminator = loss_real + loss_fake
        return {
            "loss_real": loss_real,
            "loss_fake": loss_fake,
            "loss_discriminator": loss_discriminator,
        }

    def forward(
        self, source_img: torch.Tensor, target_img: torch.Tensor, shape_rate=None, id_rate=None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Parameters:
        -----------
        source_img: torch.Tensor, source face image
        target_img: torch.Tensor, target face image
        *_rate: interpolation coefficients for shape / identity mixing

        Returns:
        --------
        i_r: torch.Tensor, swapped result
        m_r: torch.Tensor, predicted face mask
        """
        if shape_rate is None and id_rate is None:
            i_r, _, m_r, _ = self.generator(source_img, target_img)
        else:
            if shape_rate is None:
                shape_rate = 1.0
            if id_rate is None:
                id_rate = 1.0
            i_r, _, m_r, _ = self.generator.interp(source_img, target_img, shape_rate, id_rate)
        return i_r, m_r

    def optimize(
        self,
        source_img: torch.Tensor,
        target_img: torch.Tensor,
        target_mask: torch.Tensor,
        same_id_mask: torch.Tensor,
    ) -> Tuple[Dict, Dict[str, torch.Tensor]]:
        """
        Run one optimization step (generator update, then discriminator update)
        in training mode and return the losses and intermediate results.

        Parameters:
        -----------
        source_img: torch.Tensor, source face image
        target_img: torch.Tensor, target face image
        target_mask: torch.Tensor, target face mask
        same_id_mask: torch.Tensor, marks whether source and target are the same person

        Returns:
        --------
        Tuple[Dict, Dict[str, torch.Tensor]]:
            loss dict, and a dict of images (source face, target face,
            swapped face, predicted face mask, cycle face)
        """
        src_img, tgt_img, i_cycle, i_r, m_r, loss_G_dict = self.train_forward_generator(
            source_img, target_img, target_mask, same_id_mask
        )
        loss_G = loss_G_dict["loss_generator"]
        self.g_optimizer.zero_grad()
        loss_G.backward()
        global_norm_G = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.grad_clip)
        self.g_optimizer.step()

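        # Discriminator step: i_r is detached inside train_forward_discriminator,
        # so this update does not touch the generator.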
        loss_D_dict = self.train_forward_discriminator(tgt_img, i_r)
        loss_D = loss_D_dict["loss_discriminator"]
        self.d_optimizer.zero_grad()
        loss_D.backward()
        global_norm_D = torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.grad_clip)
        self.d_optimizer.step()

        total_loss_dict = {"global_norm_G": global_norm_G, "global_norm_D": global_norm_D}
        total_loss_dict.update(loss_G_dict)
        total_loss_dict.update(loss_D_dict)

        return total_loss_dict, {
            "source face": src_img,
            "target face": tgt_img,
            "swapped face": torch.clamp(i_r, min=0.0, max=1.0),
            "pred face mask": m_r,
            "cycle face": i_cycle,
        }


class HifiFaceST(HifiFace):
    def __init__(self, identity_extractor_config, device, generator_path):
        super().__init__(identity_extractor_config, device=device, generator_path=generator_path)

    def load_checkpoint(self):
        self.generator.load_state_dict(torch.load(self.generator_path, map_location=self.device))
        logger.info(f"Loaded generator from {self.generator_path}")


class HifiFaceWGM(HifiFace):
    def __init__(self, identity_extractor_config, device, generator_path):
        super().__init__(identity_extractor_config, device=device, generator_path=generator_path)

    def load_checkpoint(self):
        self.generator.load_state_dict(torch.load(self.generator_path, map_location=self.device))
        logger.info(f"Loaded generator from {self.generator_path}")
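

# Minimal inference sketch (not from the original repo): the checkpoint paths
# below are placeholders and the 256x256 input size is an assumption; inputs are
# expected as RGB tensors in [0, 1], matching the clamping and (x - 0.5) / 0.5
# normalization used above.
if __name__ == "__main__":
    identity_extractor_config = {
        "f_3d_checkpoint_path": "/path/to/deep3dface_checkpoint.pth",  # placeholder
        "bfm_folder": "/path/to/BFM",  # placeholder
        "f_id_checkpoint_path": "/path/to/arcface_backbone.pth",  # placeholder
        "hrnet_path": "/path/to/hrnet_landmark.pth",  # placeholder, training only
    }
    model = HifiFaceST(identity_extractor_config, device="cpu", generator_path="/path/to/generator.pth")
    model.eval()

    with torch.no_grad():
        source = torch.rand(1, 3, 256, 256)  # stand-in for a preprocessed source face
        target = torch.rand(1, 3, 256, 256)  # stand-in for a preprocessed target face
        swapped, mask = model.forward(source, target)
    print(swapped.shape, mask.shape)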