| import os |
| import sys |
| import cv2 |
| import yaml |
| import imageio |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import subprocess, platform |
| from mutagen.wave import WAVE |
| from datetime import timedelta |
|
|
| from face_vid2vid.sync_batchnorm.replicate import DataParallelWithCallback |
| from face_vid2vid.modules.generator import OcclusionAwareSPADEGenerator |
| from face_vid2vid.modules.keypoint_detector import KPDetector, HEEstimator |
| from face_vid2vid.animate import normalize_kp |
| from batch_face import RetinaFace |
|
|
|
|
# Refuse to run on a Python 2 interpreter.
if sys.version_info.major < 3:
    raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
|
|
|
|
def load_checkpoints(config_path, checkpoint_path):
    """Build the three face-vid2vid networks and restore their weights.

    Args:
        config_path: Path of the YAML model configuration. Its
            ``model_params`` section supplies the constructor arguments.
        checkpoint_path: Path of a checkpoint readable by ``torch.load`` that
            holds ``generator``, ``kp_detector`` and ``he_estimator`` state
            dicts.

    Returns:
        A ``(generator, kp_detector, he_estimator)`` tuple. Every network is
        moved to the GPU, wrapped in ``DataParallelWithCallback`` and switched
        to eval mode; only the generator is converted to half precision.
    """
    with open(config_path) as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    model_params = config["model_params"]
    common_params = model_params["common_params"]

    generator = OcclusionAwareSPADEGenerator(**model_params["generator_params"], **common_params)
    generator.cuda().half()

    kp_detector = KPDetector(**model_params["kp_detector_params"], **common_params)
    kp_detector.cuda()

    he_estimator = HEEstimator(**model_params["he_estimator_params"], **common_params)
    he_estimator.cuda()

    print("Loading checkpoints")
    checkpoint = torch.load(checkpoint_path)

    # Weights are restored on the bare modules, before the parallel wrappers
    # are applied.
    generator.load_state_dict(checkpoint["generator"])
    kp_detector.load_state_dict(checkpoint["kp_detector"])
    he_estimator.load_state_dict(checkpoint["he_estimator"])

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)
    he_estimator = DataParallelWithCallback(he_estimator)

    for network in (generator, kp_detector, he_estimator):
        network.eval()
    print("Model successfully loaded!")

    return generator, kp_detector, he_estimator
|
|
|
|
def headpose_pred_to_degree(pred):
    """Turn binned head-pose logits into an expected angle.

    The logits are soft-maxed over dimension 1, the expected bin index over
    the 66 bins is computed, and that index is mapped to an angle with
    ``index * 3 - 99``.

    Args:
        pred: Tensor of logits whose dimension 1 holds the 66 bins.

    Returns:
        Tensor with dimension 1 summed away: one angle per row of ``pred``.
    """
    bin_indices = torch.FloatTensor(list(range(66))).to(pred.device)
    probabilities = F.softmax(pred, dim=1)
    expected_bin = torch.sum(probabilities * bin_indices, axis=1)
    return expected_bin * 3 - 99
|
|
|
|
def get_rotation_matrix(yaw, pitch, roll):
    """Compose batched 3x3 rotation matrices from Euler angles in degrees.

    Args:
        yaw: 1-D tensor of yaw angles, one per batch element.
        pitch: 1-D tensor of pitch angles, one per batch element.
        roll: 1-D tensor of roll angles, one per batch element.

    Returns:
        Tensor of shape ``(batch, 3, 3)`` equal to ``pitch @ yaw @ roll`` for
        every batch element.
    """

    def _axis_terms(degrees):
        # Degrees -> radians using the literal 3.14 (not math.pi); the
        # constant is kept as-is so results stay bit-identical.
        radians = (degrees / 180 * 3.14).unsqueeze(1)
        return (
            torch.cos(radians),
            torch.sin(radians),
            torch.zeros_like(radians),
            torch.ones_like(radians),
        )

    def _as_matrices(entries):
        # Nine (batch, 1) columns in row-major order -> (batch, 3, 3).
        flat = torch.cat(entries, dim=1)
        return flat.view(flat.shape[0], 3, 3)

    cos_p, sin_p, zero_p, one_p = _axis_terms(pitch)
    pitch_mat = _as_matrices(
        [
            one_p, zero_p, zero_p,
            zero_p, cos_p, -sin_p,
            zero_p, sin_p, cos_p,
        ]
    )

    cos_y, sin_y, zero_y, one_y = _axis_terms(yaw)
    yaw_mat = _as_matrices(
        [
            cos_y, zero_y, sin_y,
            zero_y, one_y, zero_y,
            -sin_y, zero_y, cos_y,
        ]
    )

    cos_r, sin_r, zero_r, one_r = _axis_terms(roll)
    roll_mat = _as_matrices(
        [
            cos_r, -sin_r, zero_r,
            sin_r, cos_r, zero_r,
            zero_r, zero_r, one_r,
        ]
    )

    return torch.einsum("bij,bjk,bkm->bim", pitch_mat, yaw_mat, roll_mat)
|
|
|
|
def keypoint_transformation(kp_canonical, he, estimate_jacobian=False, free_view=False, yaw=0, pitch=0, roll=0, output_coord=False):
    """Rotate, translate and deform canonical keypoints with a head pose.

    The keypoints are rotated by the yaw/pitch/roll rotation matrix, shifted
    by the translation ``he["t"]`` and offset by the expression deformation
    ``he["exp"]``.

    Args:
        kp_canonical: Dict with the canonical keypoints under ``"value"``
            (indexed as ``(batch, keypoint, coordinate)`` by the einsum
            below) and, when ``estimate_jacobian`` is true, ``"jacobian"``.
        he: Head-pose estimate dict with ``"yaw"``, ``"pitch"`` and ``"roll"``
            logits plus ``"t"`` and ``"exp"``. It is not modified.
        estimate_jacobian: When true, also rotate ``kp_canonical["jacobian"]``.
        free_view: When true, each of ``yaw``/``pitch``/``roll`` that is not
            ``None`` replaces the corresponding estimated angle.
        yaw: Override yaw in degrees; only read when ``free_view`` is true.
        pitch: Override pitch in degrees; only read when ``free_view`` is true.
        roll: Override roll in degrees; only read when ``free_view`` is true.
        output_coord: When true, additionally return the angles that were used.

    Returns:
        ``{"value": ..., "jacobian": ...}`` (``"jacobian"`` is ``None`` unless
        ``estimate_jacobian`` is true). With ``output_coord`` the result is a
        pair of that dict and ``{"yaw", "pitch", "roll"}`` as Python floats,
        which requires each angle tensor to hold exactly one element.
    """
    kp = kp_canonical["value"]

    def _resolve_angle(name, override):
        # An explicit override is only honoured in free-view mode; it is
        # created on the keypoints' device instead of unconditionally on the
        # current CUDA device.
        if free_view and override is not None:
            return torch.tensor([override], device=kp.device)
        return headpose_pred_to_degree(he[name])

    yaw = _resolve_angle("yaw", yaw)
    pitch = _resolve_angle("pitch", pitch)
    roll = _resolve_angle("roll", roll)

    rot_mat = get_rotation_matrix(yaw, pitch, roll)

    # Rotate every keypoint.
    kp_rotated = torch.einsum("bmp,bkp->bkm", rot_mat, kp)

    # Translate. The out-of-place unsqueeze leaves the caller's he["t"]
    # untouched, so the same `he` dict can safely be passed in again.
    translation = he["t"].unsqueeze(1).repeat(1, kp.shape[1], 1)
    kp_translated = kp_rotated + translation

    # Add the expression deformation, one 3-vector per keypoint.
    expression = he["exp"]
    expression = expression.view(expression.shape[0], -1, 3)
    kp_transformed = kp_translated + expression

    if estimate_jacobian:
        jacobian_transformed = torch.einsum("bmp,bkps->bkms", rot_mat, kp_canonical["jacobian"])
    else:
        jacobian_transformed = None

    transformed = {"value": kp_transformed, "jacobian": jacobian_transformed}
    if output_coord:
        # Tensor.item() replaces float(tensor.cpu().numpy()), which relies on
        # NumPy's deprecated conversion of size-1 arrays to scalars.
        angles = {
            "yaw": float(yaw.item()),
            "pitch": float(pitch.item()),
            "roll": float(roll.item()),
        }
        return transformed, angles

    return transformed
|
|
|
|
def get_square_face(coords, image):
    """Crop an enlarged, centred box around a detected face.

    The box is first grown on every side by a quarter of its longer edge
    (floor-divided), then replaced by a square around its centre whose
    half-side is half of the grown box's longer edge, and finally clipped to
    the image bounds, so the crop can be non-square near the borders.

    Args:
        coords: ``(x1, y1, x2, y2)`` box corners.
        image: Array indexed as ``image[row, column, ...]``.

    Returns:
        The slice ``image[y1:y2, x1:x2]`` for the adjusted box.
    """
    left, top, right, bottom = coords

    # Enlarge the detection box symmetrically.
    margin = (max(right - left, bottom - top) // 2) * 0.5
    left, right = left - margin, right + margin
    top, bottom = top - margin, bottom + margin

    # Square the box around its centre.
    center_x = (left + right) // 2
    center_y = (top + bottom) // 2
    half_side = max(right - left, bottom - top) // 2

    # Clip to the image.
    col_start = max(int(round(center_x - half_side)), 0)
    col_stop = min(int(round(center_x + half_side)), image.shape[1])
    row_start = max(int(round(center_y - half_side)), 0)
    row_stop = min(int(round(center_y + half_side)), image.shape[0])
    return image[row_start:row_stop, col_start:col_stop]
|
|
|
|
def smooth_coord(last_coord, current_coord, smooth_factor=0.2):
    """Move the previous coordinates a fraction of the way to the new ones.

    Args:
        last_coord: Previously used coordinates.
        current_coord: Newly observed coordinates, same length.
        smooth_factor: Fraction of the change that is applied.

    Returns:
        List of ints: ``last + (current - last) * smooth_factor`` with each
        value truncated by ``astype(int)``.
    """
    previous = np.array(last_coord)
    step = (np.array(current_coord) - previous) * smooth_factor
    return (previous + np.array(step)).astype(int).tolist()
|
|
|
|
class FaceAnimationClass:
    """Animate a still source face with the head motion of driving frames.

    The source image is loaded once. Each driving frame is passed to
    ``inference``, which crops the face, estimates its head pose and asks the
    generator to re-render the source image with that pose.
    """

    def __init__(self, source_image_path=None, use_sr=False):
        """Load the models and prepare the source image.

        Args:
            source_image_path: Path of the image to animate; required.
            use_sr: When true, load GPEN's ``FaceEnhancement`` and apply it to
                the base frame and to every generated frame.
        """
        assert source_image_path is not None, "source_image_path is None, please set source_image_path"
        config_path = os.path.join(os.path.dirname(__file__), "face_vid2vid/config/vox-256-spade.yaml")

        # The checkpoint lives in the torch hub cache and is downloaded on
        # first use.
        checkpoint_path = os.path.join(os.path.expanduser("~"), ".cache/torch/hub/checkpoints/FaceMapping.pth.tar")
        if not os.path.exists(checkpoint_path):
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            from gdown import download
            file_id = "11ZgyjKI5OcB7klcsIdPpCCX38AIX8Soc"
            download(id=file_id, output=checkpoint_path, quiet=False)

        if use_sr:
            from face_vid2vid.GPEN.face_enhancement import FaceEnhancement

            self.faceenhancer = FaceEnhancement(
                size=256, model="GPEN-BFR-256", use_sr=False, sr_model="realesrnet_x2", channel_multiplier=1, narrow=0.5, use_facegan=True
            )

        self.generator, self.kp_detector, self.he_estimator = load_checkpoints(config_path=config_path, checkpoint_path=checkpoint_path)

        # Decode the source image once; it feeds both the model input and the
        # fallback base frame.
        source_bgr = cv2.imread(source_image_path)

        # Model input: channels swapped from OpenCV's BGR order, scaled to
        # [0, 1], resized to 256x256 and laid out as (1, C, H, W).
        # (COLOR_BGR2RGB and COLOR_RGB2BGR are the same conversion code; this
        # is the name that matches the direction.)
        source_image = cv2.cvtColor(source_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
        source_image = cv2.resize(source_image, (256, 256), interpolation=cv2.INTER_AREA)
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        self.source = source.cuda()

        # Face detection settings.
        self.face_detector = RetinaFace()
        self.detect_interval = 8  # run the detector every N-th frame
        self.smooth_factor = 0.2  # fraction of box movement applied per frame

        # Frames returned by `inference` when a driving frame cannot be
        # processed: the (optionally enhanced) source and a warning card.
        self.base_frame = source_bgr if not use_sr else self.faceenhancer.process(source_bgr)[0]
        self.base_frame = cv2.resize(self.base_frame, (256, 256))
        self.blank_frame = np.ones(self.base_frame.shape, dtype=np.uint8) * 255
        cv2.putText(self.blank_frame, "Face not", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        cv2.putText(self.blank_frame, "detected!", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

        # Per-stream state; `inference` resets it after a failure.
        self.n_frame = 0
        self.first_frame = True
        self.last_coords = None
        self.coords = None
        self.use_sr = use_sr
        self.kp_source = None
        self.kp_driving_initial = None

    def _conver_input_frame(self, frame):
        """Resize a face crop to 256x256 and return it as a (1, C, H, W) CUDA tensor in [0, 1]."""
        frame = cv2.resize(frame, (256, 256), interpolation=cv2.INTER_NEAREST).astype(np.float32) / 255.0
        return torch.tensor(frame[np.newaxis]).permute(0, 3, 1, 2).cuda()

    def _process_first_frame(self, frame):
        """Detect the face in the first driving frame and cache reference keypoints.

        Sets ``coords``/``last_coords``, ``kp_canonical``, ``he_source``,
        ``kp_driving_initial`` and ``kp_source``.

        Raises:
            ValueError: If the detector returns no face.
        """
        print("Processing first frame")

        faces = self.face_detector(frame, cv=True)
        if len(faces) == 0:
            raise ValueError("Face is not detected")
        # First element of the first detection, used as the (x1, y1, x2, y2) box.
        self.coords = faces[0][0]
        face = get_square_face(self.coords, frame)
        self.last_coords = self.coords

        with torch.no_grad():
            # Canonical keypoints and head pose of the source image.
            self.kp_canonical = self.kp_detector(self.source)
            self.he_source = self.he_estimator(self.source)

            # Head pose of the first driving frame is the reference pose.
            face_input = self._conver_input_frame(face)
            he_driving_initial = self.he_estimator(face_input)
            self.kp_driving_initial, coordinates = keypoint_transformation(self.kp_canonical, he_driving_initial, output_coord=True)
            # Source keypoints are posed with the driving frame's initial angles.
            self.kp_source = keypoint_transformation(
                self.kp_canonical, self.he_source, free_view=True, yaw=coordinates["yaw"], pitch=coordinates["pitch"], roll=coordinates["roll"]
            )

    def _inference(self, frame):
        """Render the source face with the pose found in ``frame``.

        Returns:
            ``(face, result)``: the face crop taken from ``frame`` and the
            generated image as a ``uint8`` array.

        Raises:
            ValueError: If a detection pass finds no face.
        """
        with torch.no_grad():
            self.n_frame += 1
            if self.first_frame:
                self._process_first_frame(frame)
                self.first_frame = False

            # The detector only runs every `detect_interval` frames; in
            # between, the previous box is reused.
            if self.n_frame % self.detect_interval == 0:
                faces = self.face_detector(frame, cv=True)
                if len(faces) == 0:
                    raise ValueError("Face is not detected")
                self.coords = faces[0][0]

            self.coords = smooth_coord(self.last_coords, self.coords, self.smooth_factor)
            face = get_square_face(self.coords, frame)
            self.last_coords = self.coords
            face_input = self._conver_input_frame(face)

            he_driving = self.he_estimator(face_input)
            kp_driving = keypoint_transformation(self.kp_canonical, he_driving)
            kp_norm = normalize_kp(
                kp_source=self.kp_source,
                kp_driving=kp_driving,
                kp_driving_initial=self.kp_driving_initial,
                use_relative_movement=True,
                adapt_movement_scale=True,
            )

            out = self.generator(self.source, kp_source=self.kp_source, kp_driving=kp_norm, fp16=True)
            # (1, C, H, W) -> (H, W, C), scaled to 0..255.
            image = np.transpose(out["prediction"].data.cpu().numpy(), [0, 2, 3, 1])[0]
            image = (np.array(image).astype(np.float32) * 255).astype(np.uint8)
            # Undo the channel swap applied to the source image in __init__.
            result = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            return face, result

    def inference(self, frame):
        """Animate one driving frame, falling back to placeholder frames on failure.

        Args:
            frame: Driving frame as an image array, or ``None``.

        Returns:
            ``(face, result)`` on success. If ``frame`` is ``None`` or any
            exception is raised while processing, ``(blank_frame,
            base_frame)`` is returned instead; after an exception the tracking
            state is also reset so the next frame is treated as a first frame.
        """
        if frame is None:
            # Always hand back a pair so callers can unpack the result.
            return self.blank_frame, self.base_frame

        try:
            face, result = self._inference(frame)
            if self.use_sr:
                result, _, _ = self.faceenhancer.process(result)
                result = cv2.resize(result, (256, 256))
            return face, result
        except Exception as e:
            # Best effort: report the problem and restart tracking from scratch.
            print(e)
            self.first_frame = True
            self.n_frame = 0
            return self.blank_frame, self.base_frame
|
|
|
|
def get_audio_duration(audioPath):
    """Return ``info.length`` of the WAVE file at ``audioPath`` as reported by mutagen."""
    return WAVE(audioPath).info.length
|
|
def seconds_to_hms(seconds):
    """Format a duration as ``HH:MM:SS``, padded by one extra second.

    The value is truncated to whole seconds and one second is added, so the
    result is always strictly longer than the input. Hours are not wrapped at
    24, so durations of a day or more stay in plain ``HH:MM:SS`` form (for
    example ``24:00:01``) instead of ``str(timedelta)``'s ``"1 day, ..."``.

    Args:
        seconds: Non-negative duration in seconds (int or float).

    Returns:
        The zero-padded ``HH:MM:SS`` string.
    """
    total_seconds = int(seconds) + 1
    total_minutes, secs = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
|
|
def animate_face(path_id, audiofile, driverfile, imgfile, animatedfile):
    """Animate a source image with a driving video trimmed to the audio length.

    The driving video is cut with ffmpeg to the audio duration (plus the
    padding added by ``seconds_to_hms``), every frame is passed through
    ``FaceAnimationClass.inference``, and the generated frames are written as
    an H.264 video at the driving clip's frame rate.

    Args:
        path_id: Name of the working directory under ``temp``.
        audiofile: WAVE file name inside ``temp/<path_id>``.
        driverfile: Path of the driving video, passed to ffmpeg as given.
        imgfile: Source image file name inside ``temp/<path_id>``.
        animatedfile: Output video file name inside ``temp/<path_id>``.
    """
    from tqdm import tqdm
    import time

    work_dir = os.path.join("temp", path_id)
    faceanimation = FaceAnimationClass(source_image_path=os.path.join(work_dir, imgfile), use_sr=False)

    tmpfile = f"temp/{path_id}/tmp.mp4"
    duration = get_audio_duration(os.path.join(work_dir, audiofile))
    print("duration of audio:", duration)
    hms = seconds_to_hms(duration)
    print("converted into hms:", hms)

    # Argument list without a shell: paths containing spaces or shell
    # metacharacters reach ffmpeg verbatim, on every platform.
    command = ["ffmpeg", "-ss", "00:00:00", "-i", str(driverfile), "-to", hms, "-c", "copy", tmpfile]
    return_code = subprocess.call(command)
    if return_code != 0:
        # Stay best-effort as before, but do not hide the failure.
        print(f"ffmpeg exited with status {return_code}; {tmpfile} may be missing or stale")

    # Read the whole trimmed clip into memory.
    capture = cv2.VideoCapture(tmpfile)
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        frames = []
        while True:
            _, frame = capture.read()
            if frame is None:
                break
            frames.append(frame)
    finally:
        capture.release()

    time_start = time.time()
    # Only the generated image is kept; the face crop is not needed here.
    output_frames = [faceanimation.inference(frame)[1] for frame in tqdm(frames)]
    elapsed = time.time() - time_start
    # Coarse clocks can report zero elapsed time (e.g. for an empty clip).
    frames_per_second = len(frames) / elapsed if elapsed > 0 else 0.0
    print("Time cost: %.2f" % elapsed, "FPS: %.2f" % frames_per_second)

    # The context manager finalises the file even if a frame fails to encode.
    with imageio.get_writer(
        os.path.join(work_dir, animatedfile),
        fps=fps,
        quality=9,
        macro_block_size=1,
        codec="libx264",
        pixelformat="yuv420p",
    ) as writer:
        for frame in output_frames:
            # Frames are in OpenCV's BGR order; imageio expects RGB.
            writer.append_data(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
| |
|
|
|
|