| import argparse |
| import os |
| from typing import List |
| from typing import Optional |
|
|
| import cv2 |
| import numpy as np |
| import torch |
|
|
| from configs.train_config import TrainConfig |
| from models.model import HifiFace |
|
|
|
|
| def test( |
| data_root: str, |
| result_path: str, |
| source_face: List[str], |
| target_face: List[str], |
| model_path: str, |
| model_idx: Optional[int], |
| ): |
| opt = TrainConfig() |
| opt.use_ddp = False |
|
|
| device = "cpu" |
| checkpoint = (model_path, model_idx) |
| model = HifiFace(opt.identity_extractor_config, is_training=False, device=device, load_checkpoint=checkpoint) |
| model.eval() |
|
|
| results = [] |
| for source, target in zip(source_face, target_face): |
| source = os.path.join(data_root, source) |
| target = os.path.join(data_root, target) |
|
|
| src_img = cv2.imread(source) |
| src_img = cv2.resize(src_img, (256, 256)) |
| src = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB) |
| src = src.transpose(2, 0, 1) |
| src = torch.from_numpy(src).unsqueeze(0).to(device).float() |
| src = src / 255.0 |
|
|
| tgt_img = cv2.imread(target) |
| tgt_img = cv2.resize(tgt_img, (256, 256)) |
| tgt = cv2.cvtColor(tgt_img, cv2.COLOR_BGR2RGB) |
| tgt = tgt.transpose(2, 0, 1) |
| tgt = torch.from_numpy(tgt).unsqueeze(0).to(device).float() |
| tgt = tgt / 255.0 |
|
|
| with torch.no_grad(): |
| result_face = model.forward(src, tgt).cpu() |
| result_face = torch.clamp(result_face, 0, 1) * 255 |
| result_face = result_face.numpy()[0].astype(np.uint8) |
| result_face = result_face.transpose(1, 2, 0) |
|
|
| result_face = cv2.cvtColor(result_face, cv2.COLOR_BGR2RGB) |
| one_result = np.concatenate((src_img, tgt_img, result_face), axis=0) |
| results.append(one_result) |
| result = np.concatenate(results, axis=1) |
| swapped_face = os.path.join(data_root, result_path) |
| cv2.imwrite(swapped_face, result) |
|
|
|
|
DEFAULT_DATA_ROOT = "/home/xuehongyang/data/face_swap_test"
DEFAULT_CHECKPOINT_ROOT = "/data/checkpoints/hififace/"

# Each suite is a list of (source_face, target_face) file names relative to the
# data root. Suite k (1-based) is written to
# "<data_root>/../<model_name>_<model_index>_<k>.jpg".
# Storing explicit pairs keeps each source tied to its target, which two
# parallel lists do not guarantee.
BENCHMARK_SUITES = [
    [
        ("male_1.jpg", "male_2.jpg"),
        ("male_2.jpg", "male_1.jpg"),
        ("female_1.jpg", "female_2.jpg"),
        ("female_2.jpg", "female_1.jpg"),
        ("male_1.jpg", "female_1.jpg"),
        ("male_2.jpg", "female_2.jpg"),
        ("female_1.jpg", "male_2.jpg"),
        ("female_2.jpg", "male_1.jpg"),
        ("female_1.jpg", "male_1.jpg"),
        ("female_2.jpg", "male_2.jpg"),
        ("test1.jpg", "female_1.jpg"),
        ("test1.jpg", "female_2.jpg"),
        ("test1.jpg", "male_1.jpg"),
    ],
    [
        ("male_2.jpg", "male_1.jpg"),
        ("male_1.jpg", "male_2.jpg"),
        ("male_1.jpg", "minlu_1.jpg"),
        ("male_2.jpg", "minlu_2.jpg"),
        ("male_1.jpg", "shizong_1.jpg"),
        ("male_2.jpg", "shizong_2.jpg"),
        ("male_1.jpg", "tianxin_1.jpg"),
        ("male_2.jpg", "tianxin_2.jpg"),
        ("male_1.jpg", "xiaohui_1.jpg"),
        ("male_2.jpg", "xiaohui_2.jpg"),
        ("female_2.jpg", "female_1.jpg"),
        ("female_1.jpg", "female_2.jpg"),
        ("female_2.jpg", "female_3.jpg"),
        ("female_1.jpg", "female_4.jpg"),
        ("female_2.jpg", "female_5.jpg"),
        ("female_1.jpg", "female_6.jpg"),
        ("female_2.jpg", "lixia_1.jpg"),
        ("female_1.jpg", "lixia_2.jpg"),
        ("female_2.jpg", "qq_1.jpg"),
        ("female_1.jpg", "qq_2.jpg"),
        ("female_2.jpg", "pink_1.jpg"),
        ("female_1.jpg", "pink_2.jpg"),
        ("female_2.jpg", "xulie_1.jpg"),
        ("female_1.jpg", "xulie_2.jpg"),
    ],
    [
        ("male_2.jpg", "male_2.jpg"),
        ("male_1.jpg", "male_1.jpg"),
        ("shizong_1.jpg", "minlu_1.jpg"),
        ("shizong_2.jpg", "minlu_2.jpg"),
        ("minlu_1.jpg", "shizong_1.jpg"),
        ("minlu_2.jpg", "shizong_2.jpg"),
        ("xiaohui_1.jpg", "tianxin_1.jpg"),
        ("xiaohui_2.jpg", "tianxin_2.jpg"),
        ("tianxin_1.jpg", "xiaohui_1.jpg"),
        ("tianxin_2.jpg", "xiaohui_2.jpg"),
        ("female_2.jpg", "female_1.jpg"),
        ("female_1.jpg", "female_2.jpg"),
        ("female_5.jpg", "female_3.jpg"),
        ("female_6.jpg", "female_4.jpg"),
        ("female_3.jpg", "female_5.jpg"),
        ("female_4.jpg", "female_6.jpg"),
        ("qq_1.jpg", "lixia_1.jpg"),
        ("qq_2.jpg", "lixia_2.jpg"),
        ("pink_1.jpg", "qq_1.jpg"),
        ("pink_2.jpg", "qq_2.jpg"),
        ("xulie_1.jpg", "pink_1.jpg"),
        ("xulie_2.jpg", "pink_2.jpg"),
        ("lixia_1.jpg", "xulie_1.jpg"),
        ("lixia_2.jpg", "xulie_2.jpg"),
    ],
]


def main() -> None:
    """Parse CLI arguments and render every benchmark suite for one checkpoint."""
    parser = argparse.ArgumentParser(
        prog="benchmark",
        description="Render HifiFace face-swap comparison grids for one trained checkpoint.",
    )
    # Both options were previously optional. Omitting either one crashed with
    # a TypeError from os.path.join / int(None) instead of a usage message.
    parser.add_argument(
        "-m", "--model_name", required=True,
        help="checkpoint directory name under --checkpoint_root",
    )
    parser.add_argument(
        "-i", "--model_index", required=True,
        help="integer index of the checkpoint to load",
    )
    parser.add_argument(
        "--data_root", default=DEFAULT_DATA_ROOT,
        help="directory containing the benchmark face images",
    )
    parser.add_argument(
        "--checkpoint_root", default=DEFAULT_CHECKPOINT_ROOT,
        help="directory containing checkpoint folders",
    )
    args = parser.parse_args()

    try:
        model_idx = int(args.model_index)
    except ValueError:
        parser.error(f"--model_index must be an integer, got {args.model_index!r}")

    model_path = os.path.join(args.checkpoint_root, args.model_name)
    # Use the raw index text so output file names match the CLI input exactly.
    name = f"{args.model_name}_{args.model_index}"

    for suite_no, pairs in enumerate(BENCHMARK_SUITES, start=1):
        sources = [src for src, _ in pairs]
        targets = [tgt for _, tgt in pairs]
        result_path = os.path.join(args.data_root, f"../{name}_{suite_no}.jpg")
        test(args.data_root, result_path, sources, targets, model_path, model_idx)


if __name__ == "__main__":
    main()
|
|