| from typing import Optional |
|
|
| import cv2 |
| import numpy as np |
| import torch |
|
|
| from configs.train_config import TrainConfig |
| from models.model import HifiFace |
|
|
|
|
def _load_face_tensor(image_path: str, device: str, size: int = 256) -> torch.Tensor:
    """Read an image file and convert it into a normalized model input tensor.

    Args:
        image_path: Path of the image to read with OpenCV.
        device: Torch device the resulting tensor is moved to.
        size: Side length, in pixels, the image is resized to.

    Returns:
        A float tensor of shape ``(1, 3, size, size)`` in RGB channel order
        with values scaled to ``[0, 1]``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread reports a missing or undecodable file by returning None
        # instead of raising; fail here with the offending path rather than
        # with an opaque assertion inside cvtColor.
        raise FileNotFoundError(f"Could not read image (missing or undecodable): {image_path}")

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    rgb = cv2.resize(rgb, (size, size))
    # HWC -> CHW, then add the leading batch dimension.
    chw = rgb.transpose(2, 0, 1)
    tensor = torch.from_numpy(chw).unsqueeze(0).to(device).float()
    return tensor / 255.0


def inference(source_face: str, target_face: str, model_path: str, model_idx: Optional[int], swapped_face: str):
    """Run HifiFace on one source/target image pair and save the result.

    Both images are resized to 256x256, passed through the model on the CPU,
    and the output is written to ``swapped_face``.

    Args:
        source_face: Path of the image passed as the first model input.
        target_face: Path of the image passed as the second model input.
        model_path: Checkpoint location handed to ``HifiFace``.
        model_idx: Checkpoint index handed to ``HifiFace`` together with
            ``model_path`` (presumably the training step to load; confirm
            against ``HifiFace``).
        swapped_face: Path the resulting image is written to.

    Raises:
        FileNotFoundError: If either input image cannot be read.
        OSError: If the result image cannot be written.
    """
    opt = TrainConfig()
    opt.use_ddp = False

    device = "cpu"
    checkpoint = (model_path, model_idx)
    model = HifiFace(opt.identity_extractor_config, is_training=False, device=device, load_checkpoint=checkpoint)
    model.eval()

    src = _load_face_tensor(source_face, device)
    tgt = _load_face_tensor(target_face, device)

    with torch.no_grad():
        result_face = model.forward(src, tgt).cpu()
        result_face = torch.clamp(result_face, 0, 1) * 255

    # Drop the batch dimension and go back to HWC uint8 for OpenCV.
    result_face = result_face.numpy()[0].astype(np.uint8)
    result_face = result_face.transpose(1, 2, 0)

    # The tensors are RGB; OpenCV writes BGR.
    result_face = cv2.cvtColor(result_face, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(swapped_face, result_face):
        # cv2.imwrite returns False (rather than raising) when the write fails.
        raise OSError(f"Could not write image: {swapped_face}")
|
|
|
|
if __name__ == "__main__":
    # Demo run with fixed local paths: swap one face pair using the
    # checkpoint saved at index 80000 and write the result next to the inputs.
    inference(
        source_face="/home/xuehongyang/data/female_1.jpg",
        target_face="/home/xuehongyang/data/female_2.jpg",
        model_path="/data/checkpoints/hififace/baseline_1k_ddp_with_cyc_1681278017147",
        model_idx=80000,
        swapped_face="/home/xuehongyang/data/male_1_to_male_2.jpg",
    )
|
|