| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from arcface_torch.backbones.iresnet import iresnet100 |
| from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper |
|
|
|
|
class ShapeAwareIdentityExtractor(nn.Module):
    """Shape Aware Identity Extractor.

    Combines the coefficient vector of a 3D face-reconstruction network
    (``f_3d``, a Deep3DFaceRecon ``ReconNetWrapper``) with an L2-normalised
    ArcFace identity embedding (``f_id``, an ``iresnet100``).  The leading
    identity (shape) coefficients come from the source face and all remaining
    coefficients come from the target face.

    Both sub-networks are loaded from pretrained checkpoints and are kept in
    eval mode even when a parent module calls ``.train()``.
    """

    # Number of leading 3D coefficients treated as the identity (shape) part.
    # NOTE(review): assumes Deep3DFaceRecon's coefficient layout, in which the
    # first 80 entries are the identity shape coefficients -- confirm if the
    # reconstruction network is swapped out.
    ID_COEFF_DIM = 80
    # Spatial size the ArcFace backbone is fed at.
    ARCFACE_INPUT_SIZE = 112

    def __init__(self, identity_extractor_config):
        """
        Parameters:
        ----------
        identity_extractor_config: Dict[str, str]
            Must contain the following keys:
            f_3d_checkpoint_path: str
                Path to the 3D face reconstruction checkpoint, e.g.
                "model/Deep3DFaceRecon_pytorch/checkpoints/epoch_20.pth".
                Weights are read from its "net_recon" entry.
            f_id_checkpoint_path: str
                Path to the ArcFace face-recognition checkpoint (a plain
                state dict). The unofficial implementation uses
                https://onedrive.live.com/?authkey=%21AFZjr283nwZHqbA&id=4A83B6B633B029CC%215585&cid=4A83B6B633B029CC/backbone.pth
        """
        super().__init__()
        f_3d_checkpoint_path = identity_extractor_config["f_3d_checkpoint_path"]
        f_id_checkpoint_path = identity_extractor_config["f_id_checkpoint_path"]

        # NOTE(review): torch.load unpickles the checkpoint files; only load
        # checkpoints from trusted sources.
        self.f_3d = ReconNetWrapper(net_recon="resnet50", use_last_fc=False)
        self.f_3d.load_state_dict(torch.load(f_3d_checkpoint_path, map_location="cpu")["net_recon"])
        self.f_3d.eval()

        self.f_id = iresnet100(pretrained=False, fp16=False)
        self.f_id.load_state_dict(torch.load(f_id_checkpoint_path, map_location="cpu"))
        self.f_id.eval()

    def train(self, mode=True):
        """Set this module's training mode while keeping both extractors in eval mode.

        ``nn.Module.train`` recursively switches every child module, so
        without this override a parent model calling ``.train()`` would put
        the pretrained ``f_3d`` / ``f_id`` networks back into training mode.
        For example, their BatchNorm layers would start updating running
        statistics.

        Returns:
        --------
        self
        """
        super().train(mode)
        self.f_3d.eval()
        self.f_id.eval()
        return self

    def _identity_embedding(self, image):
        """Return the L2-normalised ArcFace embedding of ``image``.

        The image is mapped from [0, 1] to [-1, 1] and bicubically resized to
        ``ARCFACE_INPUT_SIZE`` before being passed to ``f_id``.
        """
        image = (image - 0.5) / 0.5
        image = F.interpolate(image, size=self.ARCFACE_INPUT_SIZE, mode="bicubic")
        return F.normalize(self.f_id(image), dim=-1, p=2)

    def _fuse_coeffs(self, c_shape, c_rest):
        """Take the identity coefficients from ``c_shape`` and the rest from ``c_rest``."""
        k = self.ID_COEFF_DIM
        return torch.cat((c_shape[:, :k], c_rest[:, k:]), dim=1)

    @torch.no_grad()
    def interp(self, i_source, i_target, shape_rate=1.0, id_rate=1.0):
        """Interpolate shape and identity information between source and target.

        Parameters:
        -----------
        i_source, i_target: torch.Tensor
            Face images in the same format as for ``forward``.
        shape_rate: float
            Weight of the source 3D coefficients; ``1 - shape_rate`` goes to
            the target. Only the identity part of the blend is kept; all
            other coefficients always come from the target.
        id_rate: float
            Weight of the source ArcFace embedding; ``1 - id_rate`` goes to
            the target.

        Returns:
        --------
        v_sid: torch.Tensor
            Same layout as ``forward``, computed without gradients. The
            blended identity embedding is not re-normalised, so it is
            generally not unit-norm when ``id_rate`` is neither 0 nor 1.
        """
        c_s = self.f_3d(i_source)
        c_t = self.f_3d(i_target)
        c_interp = shape_rate * c_s + (1 - shape_rate) * c_t
        c_fuse = self._fuse_coeffs(c_interp, c_t)

        v_s = self._identity_embedding(i_source)
        v_t = self._identity_embedding(i_target)
        v_id = id_rate * v_s + (1 - id_rate) * v_t

        return torch.cat((c_fuse, v_id), dim=1)

    def forward(self, i_source, i_target):
        """
        Parameters:
        -----------
        i_source: torch.Tensor, shape (B, 3, H, W), in range [0, 1], source face image
        i_target: torch.Tensor, shape (B, 3, H, W), in range [0, 1], target face image

        Returns:
        --------
        v_sid: torch.Tensor
            Fused shape and identity features. This is the concatenation
            along dim 1 of:
            - the 3D coefficients, with the identity part from the source
              and the remainder from the target;
            - the L2-normalised ArcFace embedding of the source.
            Unlike ``interp``, gradients are not disabled here.
        """
        c_s = self.f_3d(i_source)
        c_t = self.f_3d(i_target)
        c_fuse = self._fuse_coeffs(c_s, c_t)

        v_id = self._identity_embedding(i_source)

        return torch.cat((c_fuse, v_id), dim=1)
|
|