diff --git a/FaceSet.py b/FaceSet.py
new file mode 100644
index 0000000000000000000000000000000000000000..46e8bec4319a3bc2edab0d05cfda5520dd09d5ce
--- /dev/null
+++ b/FaceSet.py
@@ -0,0 +1,21 @@
+import numpy as np
+
+
+class FaceSet:
+    faces = []
+    ref_images = []
+    embedding_average = "None"
+    embeddings_backup = None
+
+    def __init__(self):
+        self.faces = []
+        self.ref_images = []
+        self.embeddings_backup = None
+
+    def AverageEmbeddings(self):
+        if len(self.faces) > 1 and self.embeddings_backup is None:
+            self.embeddings_backup = self.faces[0]["embedding"]
+            embeddings = [face.embedding for face in self.faces]
+
+            self.faces[0]["embedding"] = np.mean(embeddings, axis=0)
+            # try median too?
diff --git a/ProcessEntry.py b/ProcessEntry.py
new file mode 100644
index 0000000000000000000000000000000000000000..880d29b6483e54e23c9382864f9b2983668861a7
--- /dev/null
+++ b/ProcessEntry.py
@@ -0,0 +1,7 @@
+class ProcessEntry:
+    def __init__(self, filename: str, start: int, end: int, fps: float):
+        self.filename = filename
+        self.finalname = None
+        self.startframe = start
+        self.endframe = end
+        self.fps = fps
diff --git a/ProcessMgr.py b/ProcessMgr.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e0405459cd36338a64071093215b08a27a072f3
--- /dev/null
+++ b/ProcessMgr.py
@@ -0,0 +1,1058 @@
+import os
+import cv2
+import numpy as np
+import psutil
+
+from roop.ProcessOptions import ProcessOptions
+
+from roop.face_util import (
+    get_first_face,
+    get_all_faces,
+    rotate_anticlockwise,
+    rotate_clockwise,
+    clamp_cut_values,
+)
+from roop.utilities import (
+    compute_cosine_distance,
+    get_device,
+    str_to_class,
+    shuffle_array,
+)
+import roop.vr_util as vr
+
+from typing import Any, List, Callable
+from roop.typing import Frame, Face
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from threading import Thread, Lock
+from queue import Queue
+from tqdm import tqdm
+from roop.ffmpeg_writer import FFMPEG_VideoWriter
+from roop.StreamWriter import StreamWriter
+import roop.globals
+
+
+# Poor man's enum to be able to compare to int
+class eNoFaceAction:
+    USE_ORIGINAL_FRAME = 0
+    RETRY_ROTATED = 1
+    SKIP_FRAME = 2
+    SKIP_FRAME_IF_DISSIMILAR = 3
+    USE_LAST_SWAPPED = 4
+
+
+def create_queue(temp_frame_paths: List[str]) -> Queue[str]:
+    queue: Queue[str] = Queue()
+    for frame_path in temp_frame_paths:
+        queue.put(frame_path)
+    return queue
+
+
+def pick_queue(queue: Queue[str], queue_per_future: int) -> List[str]:
+    queues = []
+    for _ in range(queue_per_future):
+        if not queue.empty():
+            queues.append(queue.get())
+    return queues
+
+
+class ProcessMgr:
+    input_face_datas = []
+    target_face_datas = []
+
+    imagemask = None
+
+    processors = []
+    options: ProcessOptions = None
+
+    num_threads = 1
+    current_index = 0
+    processing_threads = 1
+    buffer_wait_time = 0.1
+
+    lock = Lock()
+
+    frames_queue = None
+    processed_queue = None
+
+    videowriter = None
+    streamwriter = None
+
+    progress_gradio = None
+    total_frames = 0
+
+    num_frames_no_face = 0
+    last_swapped_frame = None
+
+    output_to_file = None
+    output_to_cam = None
+
+    plugins = {
+        "faceswap": "FaceSwapInsightFace",
+        "mask_clip2seg": "Mask_Clip2Seg",
+        "mask_xseg": "Mask_XSeg",
+        "codeformer": "Enhance_CodeFormer",
+        "gfpgan": "Enhance_GFPGAN",
+        "dmdnet": "Enhance_DMDNet",
+        "gpen": "Enhance_GPEN",
+        "restoreformer++": "Enhance_RestoreFormerPPlus",
+        "colorizer": "Frame_Colorizer",
+        "filter_generic": "Frame_Filter",
+        "removebg": "Frame_Masking",
+        "upscale": "Frame_Upscale",
+    }
+
+    def __init__(self,
progress): + if progress is not None: + self.progress_gradio = progress + + def reuseOldProcessor(self, name: str): + for p in self.processors: + if p.processorname == name: + return p + + return None + + def initialize(self, input_faces, target_faces, options): + self.input_face_datas = input_faces + self.target_face_datas = target_faces + self.num_frames_no_face = 0 + self.last_swapped_frame = None + self.options = options + devicename = get_device() + + roop.globals.g_desired_face_analysis = [ + "landmark_3d_68", + "landmark_2d_106", + "detection", + "recognition", + ] + if options.swap_mode == "all_female" or options.swap_mode == "all_male": + roop.globals.g_desired_face_analysis.append("genderage") + elif options.swap_mode == "all_random": + # don't modify original list + self.input_face_datas = input_faces.copy() + shuffle_array(self.input_face_datas) + + for p in self.processors: + newp = next( + (x for x in options.processors.keys() if x == p.processorname), None + ) + if newp is None: + p.Release() + del p + + newprocessors = [] + for key, extoption in options.processors.items(): + p = self.reuseOldProcessor(key) + if p is None: + classname = self.plugins[key] + module = "roop.processors." + classname + p = str_to_class(module, classname) + if p is not None: + extoption.update({"devicename": devicename}) + if p.type == "swap": + if self.options.swap_modelname == "InSwapper 128": + extoption.update({"modelname": "inswapper_128.onnx"}) + elif self.options.swap_modelname == "ReSwapper 128": + extoption.update({"modelname": "reswapper_128.onnx"}) + elif self.options.swap_modelname == "ReSwapper 256": + extoption.update({"modelname": "reswapper_256.onnx"}) + + p.Initialize(extoption) + newprocessors.append(p) + else: + print(f"Not using {module}") + self.processors = newprocessors + + if ( + isinstance(self.options.imagemask, dict) + and self.options.imagemask.get("layers") + and len(self.options.imagemask["layers"]) > 0 + ): + self.options.imagemask = self.options.imagemask.get("layers")[0] + # Get rid of alpha + self.options.imagemask = cv2.cvtColor( + self.options.imagemask, cv2.COLOR_RGBA2GRAY + ) + if np.any(self.options.imagemask): + mo = self.input_face_datas[0].faces[0].mask_offsets + self.options.imagemask = self.blur_area( + self.options.imagemask, mo[4], mo[5] + ) + self.options.imagemask = self.options.imagemask.astype(np.float32) / 255 + self.options.imagemask = cv2.cvtColor( + self.options.imagemask, cv2.COLOR_GRAY2RGB + ) + else: + self.options.imagemask = None + + self.options.frame_processing = False + for p in self.processors: + if p.type.startswith("frame_"): + self.options.frame_processing = True + + def run_batch(self, source_files, target_files, threads: int = 1): + progress_bar_format = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]" + self.total_frames = len(source_files) + self.num_threads = threads + with tqdm( + total=self.total_frames, + desc="Processing", + unit="frame", + dynamic_ncols=True, + bar_format=progress_bar_format, + ) as progress: + with ThreadPoolExecutor(max_workers=threads) as executor: + futures = [] + queue = create_queue(source_files) + queue_per_future = max(len(source_files) // threads, 1) + while not queue.empty(): + future = executor.submit( + self.process_frames, + source_files, + target_files, + pick_queue(queue, queue_per_future), + lambda: self.update_progress(progress), + ) + futures.append(future) + for future in as_completed(futures): + future.result() + + def process_frames( + self, + 
source_files: List[str], + target_files: List[str], + current_files, + update: Callable[[], None], + ) -> None: + for f in current_files: + if not roop.globals.processing: + return + + # Decode the byte array into an OpenCV image + temp_frame = cv2.imdecode(np.fromfile(f, dtype=np.uint8), cv2.IMREAD_COLOR) + if temp_frame is not None: + if self.options.frame_processing: + for p in self.processors: + frame = p.Run(temp_frame) + resimg = frame + else: + resimg = self.process_frame(temp_frame) + if resimg is not None: + i = source_files.index(f) + # Also let numpy write the file to support utf-8/16 filenames + cv2.imencode(f".{roop.globals.CFG.output_image_format}", resimg)[ + 1 + ].tofile(target_files[i]) + if update: + update() + + def read_frames_thread(self, cap, frame_start, frame_end, num_threads): + num_frame = 0 + total_num = frame_end - frame_start + if frame_start > 0: + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_start) + + while True and roop.globals.processing: + ret, frame = cap.read() + if not ret: + break + + self.frames_queue[num_frame % num_threads].put(frame, block=True) + num_frame += 1 + if num_frame == total_num: + break + + for i in range(num_threads): + self.frames_queue[i].put(None) + + def process_videoframes(self, threadindex, progress) -> None: + while True: + frame = self.frames_queue[threadindex].get() + if frame is None: + self.processing_threads -= 1 + self.processed_queue[threadindex].put((False, None)) + return + else: + if self.options.frame_processing: + for p in self.processors: + frame = p.Run(frame) + resimg = frame + else: + resimg = self.process_frame(frame) + self.processed_queue[threadindex].put((True, resimg)) + del frame + progress() + + def write_frames_thread(self): + nextindex = 0 + num_producers = self.num_threads + + while True: + process, frame = self.processed_queue[nextindex % self.num_threads].get() + nextindex += 1 + if frame is not None: + if self.output_to_file: + self.videowriter.write_frame(frame) + if self.output_to_cam: + self.streamwriter.WriteToStream(frame) + del frame + elif process == False: + num_producers -= 1 + if num_producers < 1: + return + + def run_batch_inmem( + self, + output_method, + source_video, + target_video, + frame_start, + frame_end, + fps, + threads: int = 1, + ): + if len(self.processors) < 1: + print("No processor defined!") + return + + cap = cv2.VideoCapture(source_video) + # frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + frame_count = (frame_end - frame_start) + 1 + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + processed_resolution = None + for p in self.processors: + if hasattr(p, "getProcessedResolution"): + processed_resolution = p.getProcessedResolution(width, height) + print(f"Processed resolution: {processed_resolution}") + if processed_resolution is not None: + width = processed_resolution[0] + height = processed_resolution[1] + + self.total_frames = frame_count + self.num_threads = threads + + self.processing_threads = self.num_threads + self.frames_queue = [] + self.processed_queue = [] + for _ in range(threads): + self.frames_queue.append(Queue(1)) + self.processed_queue.append(Queue(1)) + + self.output_to_file = output_method != "Virtual Camera" + self.output_to_cam = ( + output_method == "Virtual Camera" or output_method == "Both" + ) + + if self.output_to_file: + self.videowriter = FFMPEG_VideoWriter( + target_video, + (width, height), + fps, + codec=roop.globals.video_encoder, + crf=roop.globals.video_quality, + audiofile=None, + ) 
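The threading scheme above is worth spelling out: read_frames_thread deals frames round-robin into one bounded Queue(1) per worker, and write_frames_thread drains the per-worker output queues in the same round-robin order, so frames leave in their original order without any sequence numbers, while the size-1 queues give backpressure in both directions. A minimal standalone sketch of that pattern, with illustrative names only (nothing below is part of roop):

from queue import Queue
from threading import Thread

NUM_WORKERS = 4
in_queues = [Queue(1) for _ in range(NUM_WORKERS)]
out_queues = [Queue(1) for _ in range(NUM_WORKERS)]

def reader(items):
    # Deal items round-robin; Queue(1) blocks the reader whenever a worker lags.
    for i, item in enumerate(items):
        in_queues[i % NUM_WORKERS].put(item)
    for q in in_queues:
        q.put(None)  # one poison pill per worker

def worker(idx):
    while True:
        item = in_queues[idx].get()
        if item is None:
            out_queues[idx].put(None)
            return
        out_queues[idx].put(item * 2)  # stand-in for process_frame()

def writer(collected):
    # Draining the outputs in the same round-robin order the reader used
    # restores the original ordering without sorting or renumbering.
    done = 0
    i = 0
    while done < NUM_WORKERS:
        item = out_queues[i % NUM_WORKERS].get()
        i += 1
        if item is None:
            done += 1
        else:
            collected.append(item)

results = []
threads = [Thread(target=reader, args=(range(10),))]
threads += [Thread(target=worker, args=(n,)) for n in range(NUM_WORKERS)]
threads += [Thread(target=writer, args=(results,))]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert results == [x * 2 for x in range(10)]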
+ if self.output_to_cam: + self.streamwriter = StreamWriter((width, height), int(fps)) + + readthread = Thread( + target=self.read_frames_thread, args=(cap, frame_start, frame_end, threads) + ) + readthread.start() + + writethread = Thread(target=self.write_frames_thread) + writethread.start() + + progress_bar_format = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]" + with tqdm( + total=self.total_frames, + desc="Processing", + unit="frames", + dynamic_ncols=True, + bar_format=progress_bar_format, + ) as progress: + with ThreadPoolExecutor( + thread_name_prefix="swap_proc", max_workers=self.num_threads + ) as executor: + futures = [] + + for threadindex in range(threads): + future = executor.submit( + self.process_videoframes, + threadindex, + lambda: self.update_progress(progress), + ) + futures.append(future) + + for future in as_completed(futures): + future.result() + # wait for the task to complete + readthread.join() + writethread.join() + cap.release() + if self.output_to_file: + self.videowriter.close() + if self.output_to_cam: + self.streamwriter.Close() + + self.frames_queue.clear() + self.processed_queue.clear() + + def update_progress(self, progress: Any = None) -> None: + process = psutil.Process(os.getpid()) + memory_usage = process.memory_info().rss / 1024 / 1024 / 1024 + progress.set_postfix( + { + "memory_usage": "{:.2f}".format(memory_usage).zfill(5) + "GB", + "execution_threads": self.num_threads, + } + ) + progress.update(1) + if self.progress_gradio is not None: + self.progress_gradio( + (progress.n, self.total_frames), + desc="Processing", + total=self.total_frames, + unit="frames", + ) + + def process_frame(self, frame: Frame): + if len(self.input_face_datas) < 1 and not self.options.show_face_masking: + return frame + temp_frame = frame.copy() + num_swapped, temp_frame = self.swap_faces(frame, temp_frame) + if num_swapped > 0: + if roop.globals.no_face_action == eNoFaceAction.SKIP_FRAME_IF_DISSIMILAR: + if len(self.input_face_datas) > num_swapped: + return None + self.num_frames_no_face = 0 + self.last_swapped_frame = temp_frame.copy() + return temp_frame + if roop.globals.no_face_action == eNoFaceAction.USE_LAST_SWAPPED: + if ( + self.last_swapped_frame is not None + and self.num_frames_no_face < self.options.max_num_reuse_frame + ): + self.num_frames_no_face += 1 + return self.last_swapped_frame.copy() + return frame + + elif roop.globals.no_face_action == eNoFaceAction.USE_ORIGINAL_FRAME: + return frame + if roop.globals.no_face_action == eNoFaceAction.SKIP_FRAME: + # This only works with in-mem processing, as it simply skips the frame. + # For 'extract frames' it simply leaves the unprocessed frame unprocessed and it gets used in the final output by ffmpeg. + # If we could delete that frame here, that'd work but that might cause ffmpeg to fail unless the frames are renamed, and I don't think we have the info on what frame it actually is????? + # alternatively, it could mark all the necessary frames for deletion, delete them at the end, then rename the remaining frames that might work? 
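For reference, the mark-and-rename idea from the comment above could look roughly like the following hypothetical sketch (none of these helpers exist in roop; util.sort_rename_frames, used in core.py after manual frame edits, performs a related resequencing):

import os

def drop_skipped_and_resequence(frame_dir: str, skipped_paths: list, ext: str = "png"):
    # Delete the frames that were marked for skipping...
    for path in skipped_paths:
        os.remove(path)
    # ...then renumber the survivors so ffmpeg's input pattern
    # (e.g. %06d.png) sees a gapless, ordered sequence again.
    remaining = sorted(
        name for name in os.listdir(frame_dir) if name.endswith("." + ext)
    )
    for new_index, name in enumerate(remaining, start=1):
        os.replace(
            os.path.join(frame_dir, name),
            os.path.join(frame_dir, f"{new_index:06d}.{ext}"),
        )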
+ return None + else: + return self.retry_rotated(frame) + + def retry_rotated(self, frame): + copyframe = frame.copy() + copyframe = rotate_clockwise(copyframe) + temp_frame = copyframe.copy() + num_swapped, temp_frame = self.swap_faces(copyframe, temp_frame) + if num_swapped > 0: + return rotate_anticlockwise(temp_frame) + + copyframe = frame.copy() + copyframe = rotate_anticlockwise(copyframe) + temp_frame = copyframe.copy() + num_swapped, temp_frame = self.swap_faces(copyframe, temp_frame) + if num_swapped > 0: + return rotate_clockwise(temp_frame) + del copyframe + return frame + + def swap_faces(self, frame, temp_frame): + num_faces_found = 0 + + if self.options.swap_mode == "first": + face = get_first_face(frame) + + if face is None: + return num_faces_found, frame + + num_faces_found += 1 + temp_frame = self.process_face( + self.options.selected_index, face, temp_frame + ) + del face + + else: + faces = get_all_faces(frame) + if faces is None: + return num_faces_found, frame + + if self.options.swap_mode == "all": + for face in faces: + num_faces_found += 1 + temp_frame = self.process_face( + self.options.selected_index, face, temp_frame + ) + + elif ( + self.options.swap_mode == "all_input" + or self.options.swap_mode == "all_random" + ): + for i, face in enumerate(faces): + num_faces_found += 1 + if i < len(self.input_face_datas): + temp_frame = self.process_face(i, face, temp_frame) + else: + break + + elif self.options.swap_mode == "selected": + num_targetfaces = len(self.target_face_datas) + use_index = num_targetfaces == 1 + for i, tf in enumerate(self.target_face_datas): + for face in faces: + if ( + compute_cosine_distance(tf.embedding, face.embedding) + <= self.options.face_distance_threshold + ): + if i < len(self.input_face_datas): + if use_index: + temp_frame = self.process_face( + self.options.selected_index, face, temp_frame + ) + else: + temp_frame = self.process_face(i, face, temp_frame) + num_faces_found += 1 + if ( + not roop.globals.vr_mode + and num_faces_found == num_targetfaces + ): + break + elif ( + self.options.swap_mode == "all_female" + or self.options.swap_mode == "all_male" + ): + gender = "F" if self.options.swap_mode == "all_female" else "M" + for face in faces: + if face.sex == gender: + num_faces_found += 1 + temp_frame = self.process_face( + self.options.selected_index, face, temp_frame + ) + + # might be slower but way more clean to release everything here + for face in faces: + del face + faces.clear() + + if roop.globals.vr_mode and num_faces_found % 2 > 0: + # stereo image, there has to be an even number of faces + num_faces_found = 0 + return num_faces_found, frame + if num_faces_found == 0: + return num_faces_found, frame + + # maskprocessor = next((x for x in self.processors if x.type == 'mask'), None) + + if ( + self.options.imagemask is not None + and self.options.imagemask.shape == frame.shape + ): + temp_frame = self.simple_blend_with_mask( + temp_frame, frame, self.options.imagemask + ) + return num_faces_found, temp_frame + + def rotation_action(self, original_face: Face, frame: Frame): + (height, width) = frame.shape[:2] + + bounding_box_width = original_face.bbox[2] - original_face.bbox[0] + bounding_box_height = original_face.bbox[3] - original_face.bbox[1] + horizontal_face = bounding_box_width > bounding_box_height + + center_x = width // 2.0 + start_x = original_face.bbox[0] + end_x = original_face.bbox[2] + bbox_center_x = start_x + (bounding_box_width // 2.0) + + # need to leverage the array of landmarks as decribed here: + # 
https://github.com/deepinsight/insightface/tree/master/alignment/coordinate_reg + # basically, we should be able to check for the relative position of eyes and nose + # then use that to determine which way the face is actually facing when in a horizontal position + # and use that to determine the correct rotation_action + + forehead_x = original_face.landmark_2d_106[72][0] + chin_x = original_face.landmark_2d_106[0][0] + + if horizontal_face: + if chin_x < forehead_x: + # this is someone lying down with their face like this (: + return "rotate_anticlockwise" + elif forehead_x < chin_x: + # this is someone lying down with their face like this :) + return "rotate_clockwise" + if bbox_center_x >= center_x: + # this is someone lying down with their face in the right hand side of the frame + return "rotate_anticlockwise" + if bbox_center_x < center_x: + # this is someone lying down with their face in the left hand side of the frame + return "rotate_clockwise" + + return None + + def auto_rotate_frame(self, original_face, frame: Frame): + target_face = original_face + original_frame = frame + + rotation_action = self.rotation_action(original_face, frame) + + if rotation_action == "rotate_anticlockwise": + # face is horizontal, rotating frame anti-clockwise and getting face bounding box from rotated frame + frame = rotate_anticlockwise(frame) + elif rotation_action == "rotate_clockwise": + # face is horizontal, rotating frame clockwise and getting face bounding box from rotated frame + frame = rotate_clockwise(frame) + + return target_face, frame, rotation_action + + def auto_unrotate_frame(self, frame: Frame, rotation_action): + if rotation_action == "rotate_anticlockwise": + return rotate_clockwise(frame) + elif rotation_action == "rotate_clockwise": + return rotate_anticlockwise(frame) + + return frame + + def process_face(self, face_index, target_face: Face, frame: Frame): + from roop.face_util import align_crop + + enhanced_frame = None + if len(self.input_face_datas) > 0: + inputface = self.input_face_datas[face_index].faces[0] + else: + inputface = None + + rotation_action = None + if roop.globals.autorotate_faces: + # check for sideways rotation of face + rotation_action = self.rotation_action(target_face, frame) + if rotation_action is not None: + (startX, startY, endX, endY) = target_face["bbox"].astype("int") + width = endX - startX + height = endY - startY + offs = int(max(width, height) * 0.25) + rotcutframe, startX, startY, endX, endY = self.cutout( + frame, startX - offs, startY - offs, endX + offs, endY + offs + ) + if rotation_action == "rotate_anticlockwise": + rotcutframe = rotate_anticlockwise(rotcutframe) + elif rotation_action == "rotate_clockwise": + rotcutframe = rotate_clockwise(rotcutframe) + # rotate image and re-detect face to correct wonky landmarks + rotface = get_first_face(rotcutframe) + if rotface is None: + rotation_action = None + else: + saved_frame = frame.copy() + frame = rotcutframe + target_face = rotface + + # if roop.globals.vr_mode: + # bbox = target_face.bbox + # [orig_width, orig_height, _] = frame.shape + + # # Convert bounding box to ints + # x1, y1, x2, y2 = map(int, bbox) + + # # Determine the center of the bounding box + # x_center = (x1 + x2) / 2 + # y_center = (y1 + y2) / 2 + + # # Normalize coordinates to range [-1, 1] + # x_center_normalized = x_center / (orig_width / 2) - 1 + # y_center_normalized = y_center / (orig_width / 2) - 1 + + # # Convert normalized coordinates to spherical (theta, phi) + # theta = x_center_normalized * 180 # Theta 
ranges from -180 to 180 degrees + # phi = -y_center_normalized * 90 # Phi ranges from -90 to 90 degrees + + # img = vr.GetPerspective(frame, 90, theta, phi, 1280, 1280) # Generate perspective image + + """ Code ported/adapted from Facefusion which borrowed the idea from Rope: + Kind of subsampling the cutout and aligned face image and faceswapping slices of it up to + the desired output resolution. This works around the current resolution limitations without using enhancers. + """ + model_output_size = self.options.swap_output_size + subsample_size = max(self.options.subsample_size, model_output_size) + subsample_total = subsample_size // model_output_size + aligned_img, M = align_crop(frame, target_face.kps, subsample_size) + + fake_frame = aligned_img + target_face.matrix = M + + for p in self.processors: + if p.type == "swap": + swap_result_frames = [] + subsample_frames = self.implode_pixel_boost( + aligned_img, model_output_size, subsample_total + ) + for sliced_frame in subsample_frames: + for _ in range(0, self.options.num_swap_steps): + sliced_frame = self.prepare_crop_frame(sliced_frame) + sliced_frame = p.Run(inputface, target_face, sliced_frame) + sliced_frame = self.normalize_swap_frame(sliced_frame) + swap_result_frames.append(sliced_frame) + fake_frame = self.explode_pixel_boost( + swap_result_frames, + model_output_size, + subsample_total, + subsample_size, + ) + fake_frame = fake_frame.astype(np.uint8) + scale_factor = 0.0 + elif p.type == "mask": + fake_frame = self.process_mask(p, aligned_img, fake_frame) + else: + enhanced_frame, scale_factor = p.Run( + self.input_face_datas[face_index], target_face, fake_frame + ) + + upscale = 512 + orig_width = fake_frame.shape[1] + if orig_width != upscale: + fake_frame = cv2.resize(fake_frame, (upscale, upscale), cv2.INTER_CUBIC) + mask_offsets = ( + (0, 0, 0, 0, 1, 20) if inputface is None else inputface.mask_offsets + ) + + if enhanced_frame is None: + scale_factor = int(upscale / orig_width) + result = self.paste_upscale( + fake_frame, + fake_frame, + target_face.matrix, + frame, + scale_factor, + mask_offsets, + ) + else: + result = self.paste_upscale( + fake_frame, + enhanced_frame, + target_face.matrix, + frame, + scale_factor, + mask_offsets, + ) + + # Restore mouth before unrotating + if self.options.restore_original_mouth: + mouth_cutout, mouth_bb = self.create_mouth_mask(target_face, frame) + result = self.apply_mouth_area(result, mouth_cutout, mouth_bb) + + if rotation_action is not None: + fake_frame = self.auto_unrotate_frame(result, rotation_action) + result = self.paste_simple(fake_frame, saved_frame, startX, startY) + + return result + + def cutout(self, frame: Frame, start_x, start_y, end_x, end_y): + if start_x < 0: + start_x = 0 + if start_y < 0: + start_y = 0 + if end_x > frame.shape[1]: + end_x = frame.shape[1] + if end_y > frame.shape[0]: + end_y = frame.shape[0] + return frame[start_y:end_y, start_x:end_x], start_x, start_y, end_x, end_y + + def paste_simple(self, src: Frame, dest: Frame, start_x, start_y): + end_x = start_x + src.shape[1] + end_y = start_y + src.shape[0] + + start_x, end_x, start_y, end_y = clamp_cut_values( + start_x, end_x, start_y, end_y, dest + ) + dest[start_y:end_y, start_x:end_x] = src + return dest + + def simple_blend_with_mask(self, image1, image2, mask): + # Blend the images + blended_image = ( + image1.astype(np.float32) * (1.0 - mask) + image2.astype(np.float32) * mask + ) + return blended_image.astype(np.uint8) + + def paste_upscale( + self, fake_face, upsk_face, M, target_img, 
scale_factor, mask_offsets + ): + M_scale = M * scale_factor + IM = cv2.invertAffineTransform(M_scale) + + face_matte = np.full( + (target_img.shape[0], target_img.shape[1]), 255, dtype=np.uint8 + ) + # Generate white square sized as a upsk_face + img_matte = np.zeros((upsk_face.shape[0], upsk_face.shape[1]), dtype=np.uint8) + + w = img_matte.shape[1] + h = img_matte.shape[0] + + top = int(mask_offsets[0] * h) + bottom = int(h - (mask_offsets[1] * h)) + left = int(mask_offsets[2] * w) + right = int(w - (mask_offsets[3] * w)) + img_matte[top:bottom, left:right] = 255 + + # Transform white square back to target_img + img_matte = cv2.warpAffine( + img_matte, + IM, + (target_img.shape[1], target_img.shape[0]), + flags=cv2.INTER_NEAREST, + borderValue=0.0, + ) + ##Blacken the edges of face_matte by 1 pixels (so the mask in not expanded on the image edges) + img_matte[:1, :] = img_matte[-1:, :] = img_matte[:, :1] = img_matte[:, -1:] = 0 + + img_matte = self.blur_area(img_matte, mask_offsets[4], mask_offsets[5]) + # Normalize images to float values and reshape + img_matte = img_matte.astype(np.float32) / 255 + face_matte = face_matte.astype(np.float32) / 255 + img_matte = np.minimum(face_matte, img_matte) + if self.options.show_face_area_overlay: + # Additional steps for green overlay + green_overlay = np.zeros_like(target_img) + green_color = [0, 255, 0] # RGB for green + for i in range(3): # Apply green color where img_matte is not zero + green_overlay[:, :, i] = np.where( + img_matte > 0, green_color[i], 0 + ) ##Transform upcaled face back to target_img + img_matte = np.reshape(img_matte, [img_matte.shape[0], img_matte.shape[1], 1]) + paste_face = cv2.warpAffine( + upsk_face, + IM, + (target_img.shape[1], target_img.shape[0]), + borderMode=cv2.BORDER_REPLICATE, + ) + if upsk_face is not fake_face: + fake_face = cv2.warpAffine( + fake_face, + IM, + (target_img.shape[1], target_img.shape[0]), + borderMode=cv2.BORDER_REPLICATE, + ) + paste_face = cv2.addWeighted( + paste_face, + self.options.blend_ratio, + fake_face, + 1.0 - self.options.blend_ratio, + 0, + ) + + # Re-assemble image + paste_face = img_matte * paste_face + paste_face = paste_face + (1 - img_matte) * target_img.astype(np.float32) + if self.options.show_face_area_overlay: + # Overlay the green overlay on the final image + paste_face = cv2.addWeighted( + paste_face.astype(np.uint8), 1 - 0.5, green_overlay, 0.5, 0 + ) + return paste_face.astype(np.uint8) + + def blur_area(self, img_matte, num_erosion_iterations, blur_amount): + # Detect the affine transformed white area + mask_h_inds, mask_w_inds = np.where(img_matte == 255) + # Calculate the size (and diagonal size) of transformed white area width and height boundaries + mask_h = np.max(mask_h_inds) - np.min(mask_h_inds) + mask_w = np.max(mask_w_inds) - np.min(mask_w_inds) + mask_size = int(np.sqrt(mask_h * mask_w)) + # Calculate the kernel size for eroding img_matte by kernel (insightface empirical guess for best size was max(mask_size//10,10)) + # k = max(mask_size//12, 8) + k = max(mask_size // (blur_amount // 2), blur_amount // 2) + kernel = np.ones((k, k), np.uint8) + img_matte = cv2.erode(img_matte, kernel, iterations=num_erosion_iterations) + # Calculate the kernel size for blurring img_matte by blur_size (insightface empirical guess for best size was max(mask_size//20, 5)) + # k = max(mask_size//24, 4) + k = max(mask_size // blur_amount, blur_amount // 5) + kernel_size = (k, k) + blur_size = tuple(2 * i + 1 for i in kernel_size) + return cv2.GaussianBlur(img_matte, blur_size, 
0) + + def prepare_crop_frame(self, swap_frame): + model_type = "inswapper" + model_mean = [0.0, 0.0, 0.0] + model_standard_deviation = [1.0, 1.0, 1.0] + + if model_type == "ghost": + swap_frame = swap_frame[:, :, ::-1] / 127.5 - 1 + else: + swap_frame = swap_frame[:, :, ::-1] / 255.0 + swap_frame = (swap_frame - model_mean) / model_standard_deviation + swap_frame = swap_frame.transpose(2, 0, 1) + swap_frame = np.expand_dims(swap_frame, axis=0).astype(np.float32) + return swap_frame + + def normalize_swap_frame(self, swap_frame): + model_type = "inswapper" + swap_frame = swap_frame.transpose(1, 2, 0) + + if model_type == "ghost": + swap_frame = (swap_frame * 127.5 + 127.5).round() + else: + swap_frame = (swap_frame * 255.0).round() + swap_frame = swap_frame[:, :, ::-1] + return swap_frame + + def implode_pixel_boost( + self, aligned_face_frame, model_size, pixel_boost_total: int + ): + subsample_frame = aligned_face_frame.reshape( + model_size, pixel_boost_total, model_size, pixel_boost_total, 3 + ) + subsample_frame = subsample_frame.transpose(1, 3, 0, 2, 4).reshape( + pixel_boost_total**2, model_size, model_size, 3 + ) + return subsample_frame + + def explode_pixel_boost( + self, subsample_frame, model_size, pixel_boost_total, pixel_boost_size + ): + final_frame = np.stack(subsample_frame, axis=0).reshape( + pixel_boost_total, pixel_boost_total, model_size, model_size, 3 + ) + final_frame = final_frame.transpose(2, 0, 3, 1, 4).reshape( + pixel_boost_size, pixel_boost_size, 3 + ) + return final_frame + + def process_mask(self, processor, frame: Frame, target: Frame): + img_mask = processor.Run(frame, self.options.masking_text) + img_mask = cv2.resize(img_mask, (target.shape[1], target.shape[0])) + img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1]) + + if self.options.show_face_masking: + result = (1 - img_mask) * frame.astype(np.float32) + return np.uint8(result) + + target = target.astype(np.float32) + result = (1 - img_mask) * target + result += img_mask * frame.astype(np.float32) + return np.uint8(result) + + # Code for mouth restoration adapted from https://github.com/iVideoGameBoss/iRoopDeepFaceCam + + def create_mouth_mask(self, face: Face, frame: Frame): + mouth_cutout = None + + landmarks = face.landmark_2d_106 + if landmarks is not None: + # Get mouth landmarks (indices 52 to 71 typically represent the outer mouth) + mouth_points = landmarks[52:71].astype(np.int32) + + # Add padding to mouth area + min_x, min_y = np.min(mouth_points, axis=0) + max_x, max_y = np.max(mouth_points, axis=0) + min_x = max(0, min_x - (15 * 6)) + min_y = max(0, min_y - 22) + max_x = min(frame.shape[1], max_x + (15 * 6)) + max_y = min(frame.shape[0], max_y + (90 * 6)) + + # Extract the mouth area from the frame using the calculated bounding box + mouth_cutout = frame[min_y:max_y, min_x:max_x].copy() + + return mouth_cutout, (min_x, min_y, max_x, max_y) + + def create_feathered_mask(self, shape, feather_amount=30): + mask = np.zeros(shape[:2], dtype=np.float32) + center = (shape[1] // 2, shape[0] // 2) + cv2.ellipse( + mask, + center, + (shape[1] // 2 - feather_amount, shape[0] // 2 - feather_amount), + 0, + 0, + 360, + 1, + -1, + ) + mask = cv2.GaussianBlur( + mask, (feather_amount * 2 + 1, feather_amount * 2 + 1), 0 + ) + return mask / np.max(mask) + + def apply_mouth_area( + self, frame: np.ndarray, mouth_cutout: np.ndarray, mouth_box: tuple + ) -> np.ndarray: + min_x, min_y, max_x, max_y = mouth_box + box_width = max_x - min_x + box_height = max_y - min_y + + # Resize the mouth 
cutout to match the mouth box size
+        if mouth_cutout is None or box_width is None or box_height is None:
+            return frame
+        try:
+            resized_mouth_cutout = cv2.resize(mouth_cutout, (box_width, box_height))
+
+            # Extract the region of interest (ROI) from the target frame
+            roi = frame[min_y:max_y, min_x:max_x]
+
+            # Ensure the ROI and resized_mouth_cutout have the same shape
+            if roi.shape != resized_mouth_cutout.shape:
+                resized_mouth_cutout = cv2.resize(
+                    resized_mouth_cutout, (roi.shape[1], roi.shape[0])
+                )
+
+            # Apply color transfer from ROI to mouth cutout
+            color_corrected_mouth = self.apply_color_transfer(resized_mouth_cutout, roi)
+
+            # Create a feathered mask with increased feather amount
+            feather_amount = min(30, box_width // 15, box_height // 15)
+            mask = self.create_feathered_mask(
+                resized_mouth_cutout.shape, feather_amount
+            )
+
+            # Blend the color-corrected mouth cutout with the ROI using the feathered mask
+            mask = mask[:, :, np.newaxis]  # Add channel dimension to mask
+            blended = (color_corrected_mouth * mask + roi * (1 - mask)).astype(np.uint8)
+
+            # Place the blended result back into the frame
+            frame[min_y:max_y, min_x:max_x] = blended
+        except Exception as e:
+            print(f"Error restoring mouth area: {e}")
+
+        return frame
+
+    def apply_color_transfer(self, source, target):
+        """
+        Apply color transfer from target to source image
+        """
+        source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype("float32")
+        target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype("float32")
+
+        source_mean, source_std = cv2.meanStdDev(source)
+        target_mean, target_std = cv2.meanStdDev(target)
+
+        # Reshape mean and std to be broadcastable
+        source_mean = source_mean.reshape(1, 1, 3)
+        source_std = source_std.reshape(1, 1, 3)
+        target_mean = target_mean.reshape(1, 1, 3)
+        target_std = target_std.reshape(1, 1, 3)
+
+        # Perform the color transfer
+        source = (source - source_mean) * (target_std / source_std) + target_mean
+        return cv2.cvtColor(np.clip(source, 0, 255).astype("uint8"), cv2.COLOR_LAB2BGR)
+
+    def unload_models(self):
+        pass
+
+    def release_resources(self):
+        for p in self.processors:
+            p.Release()
+        self.processors.clear()
+        if self.videowriter is not None:
+            self.videowriter.close()
+        if self.streamwriter is not None:
+            self.streamwriter.Close()
diff --git a/ProcessOptions.py b/ProcessOptions.py
new file mode 100644
index 0000000000000000000000000000000000000000..b308a40570d6b31e7874d4f6966cb2df972c9d38
--- /dev/null
+++ b/ProcessOptions.py
@@ -0,0 +1,36 @@
+class ProcessOptions:
+    def __init__(
+        self,
+        swap_model,
+        processordefines: dict,
+        face_distance,
+        blend_ratio,
+        swap_mode,
+        selected_index,
+        masking_text,
+        imagemask,
+        num_steps,
+        subsample_size,
+        show_face_area,
+        restore_original_mouth,
+        show_mask=False,
+    ):
+        if swap_model is not None:
+            self.swap_modelname = swap_model
+            self.swap_output_size = int(swap_model.split()[-1])
+        else:
+            self.swap_modelname = None
+            self.swap_output_size = 128
+        self.processors = processordefines
+        self.face_distance_threshold = face_distance
+        self.blend_ratio = blend_ratio
+        self.swap_mode = swap_mode
+        self.selected_index = selected_index
+        self.masking_text = masking_text
+        self.imagemask = imagemask
+        self.num_swap_steps = num_steps
+        self.show_face_area_overlay = show_face_area
+        self.show_face_masking = show_mask
+        self.subsample_size = subsample_size
+        self.restore_original_mouth = restore_original_mouth
+        self.max_num_reuse_frame = 15
diff --git a/StreamWriter.py b/StreamWriter.py
new file mode 100644
index 0000000000000000000000000000000000000000..60571f0998bf19f780633c14f206c88a08c00e13
--- /dev/null
+++ b/StreamWriter.py
@@ -0,0 +1,60 @@
+import threading
+import time
+import pyvirtualcam
+
+
+class StreamWriter:
+    FPS = 30
+    VCam = None
+    Active = False
+    THREAD_LOCK_STREAM = threading.Lock()
+    time_last_process = None
+    timespan_min = 0.0
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.Close()
+
+    def __init__(self, size, fps):
+        self.time_last_process = time.perf_counter()
+        self.FPS = fps
+        self.timespan_min = 1.0 / fps
+        print("Detecting virtual cam devices")
+        self.VCam = pyvirtualcam.Camera(
+            width=size[0],
+            height=size[1],
+            fps=fps,
+            fmt=pyvirtualcam.PixelFormat.BGR,
+            print_fps=False,
+        )
+        if self.VCam is None:
+            print("No virtual camera found!")
+            return
+        print(f"Using virtual camera: {self.VCam.device}")
+        print(f"Using {self.VCam.native_fmt}")
+        self.Active = True
+
+    def LimitFrames(self):
+        while True:
+            current_time = time.perf_counter()
+            time_passed = current_time - self.time_last_process
+            if time_passed >= self.timespan_min:
+                break
+
+    # First version used a queue and threading. Surprisingly this
+    # totally simple, blocking version is 10 times faster!
+    def WriteToStream(self, frame):
+        if self.VCam is None:
+            return
+        with self.THREAD_LOCK_STREAM:
+            self.LimitFrames()
+            self.VCam.send(frame)
+            self.time_last_process = time.perf_counter()
+
+    def Close(self):
+        self.Active = False
+        if self.VCam is not None:
+            self.VCam.close()
+            self.VCam = None
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/capturer.py b/capturer.py
new file mode 100644
index 0000000000000000000000000000000000000000..774e1d30cb0a000cebfc547273df96c5a70aac02
--- /dev/null
+++ b/capturer.py
@@ -0,0 +1,50 @@
+from typing import Optional
+import cv2
+import numpy as np
+
+from roop.typing import Frame
+
+current_video_path = None
+current_frame_total = 0
+current_capture = None
+
+
+def get_image_frame(filename: str):
+    try:
+        return cv2.imdecode(np.fromfile(filename, dtype=np.uint8), cv2.IMREAD_COLOR)
+    except:
+        print(f"Exception reading {filename}")
+    return None
+
+
+def get_video_frame(video_path: str, frame_number: int = 0) -> Optional[Frame]:
global current_video_path, current_capture, current_frame_total + + if video_path != current_video_path: + release_video() + current_capture = cv2.VideoCapture(video_path) + current_video_path = video_path + current_frame_total = current_capture.get(cv2.CAP_PROP_FRAME_COUNT) + + current_capture.set( + cv2.CAP_PROP_POS_FRAMES, min(current_frame_total, frame_number - 1) + ) + has_frame, frame = current_capture.read() + if has_frame: + return frame + return None + + +def release_video(): + global current_capture + + if current_capture is not None: + current_capture.release() + current_capture = None + + +def get_video_frame_total(video_path: str) -> int: + capture = cv2.VideoCapture(video_path) + video_frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) + capture.release() + return video_frame_total diff --git a/core.py b/core.py new file mode 100644 index 0000000000000000000000000000000000000000..ac3cceed31ccc405b4ebac8a2f0c6b271d37d5fc --- /dev/null +++ b/core.py @@ -0,0 +1,605 @@ +#!/usr/bin/env python3 + +import os +import sys +import shutil + +# single thread doubles cuda performance - needs to be set before torch import +if any(arg.startswith("--execution-provider") for arg in sys.argv): + os.environ["OMP_NUM_THREADS"] = "1" + +import warnings +from typing import List +import platform +import signal +import torch +import onnxruntime +import pathlib +import argparse + +from time import time + +import roop.globals +import roop.metadata +import roop.utilities as util +import roop.util_ffmpeg as ffmpeg +import ui.main as main +from settings import Settings +from roop.face_util import extract_face_images +from roop.ProcessEntry import ProcessEntry +from roop.ProcessMgr import ProcessMgr +from roop.ProcessOptions import ProcessOptions +from roop.capturer import get_video_frame_total, release_video + + +clip_text = None + +call_display_ui = None + +process_mgr = None + + +if "ROCMExecutionProvider" in roop.globals.execution_providers: + del torch + +warnings.filterwarnings("ignore", category=FutureWarning, module="insightface") +warnings.filterwarnings("ignore", category=UserWarning, module="torchvision") + + +def parse_args() -> None: + signal.signal(signal.SIGINT, lambda signal_number, frame: destroy()) + roop.globals.headless = False + + program = argparse.ArgumentParser( + formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=100) + ) + program.add_argument( + "--server_share", + help="Public server", + dest="server_share", + action="store_true", + default=False, + ) + program.add_argument( + "--cuda_device_id", + help="Index of the cuda gpu to use", + dest="cuda_device_id", + type=int, + default=0, + ) + roop.globals.startup_args = program.parse_args() + # Always enable all processors when using GUI + roop.globals.frame_processors = ["face_swapper", "face_enhancer"] + + +def encode_execution_providers(execution_providers: List[str]) -> List[str]: + return [ + execution_provider.replace("ExecutionProvider", "").lower() + for execution_provider in execution_providers + ] + + +def decode_execution_providers(execution_providers: List[str]) -> List[str]: + list_providers = [ + provider + for provider, encoded_execution_provider in zip( + onnxruntime.get_available_providers(), + encode_execution_providers(onnxruntime.get_available_providers()), + ) + if any( + execution_provider in encoded_execution_provider + for execution_provider in execution_providers + ) + ] + + try: + for i in range(len(list_providers)): + if list_providers[i] == "CUDAExecutionProvider": + 
list_providers[i] = ( + "CUDAExecutionProvider", + {"device_id": roop.globals.cuda_device_id}, + ) + torch.cuda.set_device(roop.globals.cuda_device_id) + break + except: + pass + + return list_providers + + +def suggest_max_memory() -> int: + if platform.system().lower() == "darwin": + return 4 + return 16 + + +def suggest_execution_providers() -> List[str]: + return encode_execution_providers(onnxruntime.get_available_providers()) + + +def suggest_execution_threads() -> int: + if "DmlExecutionProvider" in roop.globals.execution_providers: + return 1 + if "ROCMExecutionProvider" in roop.globals.execution_providers: + return 1 + return 8 + + +def limit_resources() -> None: + # limit memory usage + if roop.globals.max_memory: + memory = roop.globals.max_memory * 1024**3 + if platform.system().lower() == "darwin": + memory = roop.globals.max_memory * 1024**6 + if platform.system().lower() == "windows": + import ctypes + + kernel32 = ctypes.windll.kernel32 # type: ignore[attr-defined] + kernel32.SetProcessWorkingSetSize( + -1, ctypes.c_size_t(memory), ctypes.c_size_t(memory) + ) + else: + import resource + + resource.setrlimit(resource.RLIMIT_DATA, (memory, memory)) + + +def release_resources() -> None: + import gc + + global process_mgr + + if process_mgr is not None: + process_mgr.release_resources() + process_mgr = None + + gc.collect() + if ( + "CUDAExecutionProvider" in roop.globals.execution_providers + and torch.cuda.is_available() + ): + with torch.cuda.device("cuda"): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +def pre_check() -> bool: + if sys.version_info < (3, 9): + update_status( + "Python version is not supported - please upgrade to 3.9 or higher." + ) + return False + + download_directory_path = util.resolve_relative_path("../models") + util.conditional_download( + download_directory_path, + [ + [ + "https://huggingface.co/countfloyd/deepfake/resolve/main/inswapper_128.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/InSwapper/inswapper_128.onnx", + ], + [ + "https://huggingface.co/countfloyd/deepfake/resolve/main/reswapper_128.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/ReSwapper/reswapper_128.onnx", + ], + [ + "https://huggingface.co/countfloyd/deepfake/resolve/main/reswapper_256.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/ReSwapper/reswapper_256.onnx", + ], + [ + "https://huggingface.co/countfloyd/deepfake/resolve/main/GFPGANv1.4.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/GFPGAN/GFPGANv1.4.onnx", + ], + [ + "https://github.com/csxmli2016/DMDNet/releases/download/v1/DMDNet.pth", + "https://codeberg.org/roop-unleashed/models/media/branch/main/DMDNet/DMDNet.pth", + ], + [ + "https://huggingface.co/countfloyd/deepfake/resolve/main/GPEN-BFR-512.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/GPEN/GPEN-BFR-512.onnx", + ], + [ + "https://huggingface.co/countfloyd/deepfake/resolve/main/restoreformer_plus_plus.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/RestoreFormer/restoreformer_plus_plus.onnx", + ], + [ + "https://huggingface.co/countfloyd/deepfake/resolve/main/xseg.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/xseg.onnx", + ], + ], + ) + download_directory_path = util.resolve_relative_path("../models/CLIP") + util.conditional_download( + download_directory_path, + [ + [ + "https://huggingface.co/countfloyd/deepfake/resolve/main/rd64-uni-refined.pth", + 
"https://codeberg.org/roop-unleashed/models/media/branch/main/rd64-uni-refined.pth", + ] + ], + ) + download_directory_path = util.resolve_relative_path("../models/buffalo_l") + util.conditional_download( + download_directory_path, + [ + [ + "https://huggingface.co/halllooo/buffalo_l/resolve/main/1k3d68.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/buffalo_l/1k3d68.onnx", + ], + [ + "https://huggingface.co/halllooo/buffalo_l/resolve/main/2d106det.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/buffalo_l/2d106det.onnx", + ], + [ + "https://huggingface.co/halllooo/buffalo_l/resolve/main/det_10g.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/buffalo_l/det_10g.onnx", + ], + [ + "https://huggingface.co/halllooo/buffalo_l/resolve/main/genderage.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/buffalo_l/genderage.onnx", + ], + [ + "https://huggingface.co/halllooo/buffalo_l/resolve/main/w600k_r50.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/buffalo_l/w600k_r50.onnx", + ], + ], + ) + download_directory_path = util.resolve_relative_path("../models/CodeFormer") + util.conditional_download( + download_directory_path, + [ + [ + "https://huggingface.co/countfloyd/deepfake/resolve/main/CodeFormerv0.1.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/CodeFormer/CodeFormerv0.1.onnx", + ] + ], + ) + download_directory_path = util.resolve_relative_path("../models/Frame") + util.conditional_download( + download_directory_path, + [ + [ + "https://huggingface.co/countfloyd/deepfake/resolve/main/deoldify_artistic.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/DeOldify/deoldify_artistic.onnx", + ], + [ + "https://huggingface.co/countfloyd/deepfake/resolve/main/deoldify_stable.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/DeOldify/deoldify_stable.onnx", + ], + [ + "https://huggingface.co/countfloyd/deepfake/resolve/main/isnet-general-use.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/isnet-general-use.onnx", + ], + [ + "https://huggingface.co/countfloyd/deepfake/resolve/main/real_esrgan_x4.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/real_esrgan_x4.onnx", + ], + [ + "https://huggingface.co/countfloyd/deepfake/resolve/main/real_esrgan_x2.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/real_esrgan_x2.onnx", + ], + [ + "https://huggingface.co/countfloyd/deepfake/resolve/main/lsdir_x4.onnx", + "https://codeberg.org/roop-unleashed/models/media/branch/main/lsdir_x4.onnx", + ], + ], + ) + + if not shutil.which("ffmpeg"): + update_status("ffmpeg is not installed.") + return True + + +def set_display_ui(function): + global call_display_ui + + call_display_ui = function + + +def update_status(message: str) -> None: + global call_display_ui + + print(message) + if call_display_ui is not None: + call_display_ui(message) + + +def start() -> None: + if roop.globals.headless: + print("Headless mode currently unsupported - starting UI!") + # faces = extract_face_images(roop.globals.source_path, (False, 0)) + # roop.globals.INPUT_FACES.append(faces[roop.globals.source_face_index]) + # faces = extract_face_images(roop.globals.target_path, (False, util.has_image_extension(roop.globals.target_path))) + # roop.globals.TARGET_FACES.append(faces[roop.globals.target_face_index]) + # if 'face_enhancer' in roop.globals.frame_processors: + # roop.globals.selected_enhancer = 'GFPGAN' 
+ + batch_process_regular(None, False, None) + + +def get_processing_plugins(masking_engine): + processors = {"faceswap": {}} + if masking_engine is not None: + processors.update({masking_engine: {}}) + + if roop.globals.selected_enhancer == "GFPGAN": + processors.update({"gfpgan": {}}) + elif roop.globals.selected_enhancer == "Codeformer": + processors.update({"codeformer": {}}) + elif roop.globals.selected_enhancer == "DMDNet": + processors.update({"dmdnet": {}}) + elif roop.globals.selected_enhancer == "GPEN": + processors.update({"gpen": {}}) + elif roop.globals.selected_enhancer == "Restoreformer++": + processors.update({"restoreformer++": {}}) + return processors + + +def live_swap(frame, options): + global process_mgr + + if frame is None: + return frame + + if process_mgr is None: + process_mgr = ProcessMgr(None) + + # if len(roop.globals.INPUT_FACESETS) <= selected_index: + # selected_index = 0 + process_mgr.initialize( + roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options + ) + newframe = process_mgr.process_frame(frame) + if newframe is None: + return frame + return newframe + + +def batch_process_regular( + swap_model, + output_method, + files: list[ProcessEntry], + masking_engine: str, + new_clip_text: str, + use_new_method, + imagemask, + restore_original_mouth, + num_swap_steps, + progress, + selected_index=0, +) -> None: + global clip_text, process_mgr + + release_resources() + limit_resources() + if process_mgr is None: + process_mgr = ProcessMgr(progress) + mask = imagemask["layers"][0] if imagemask is not None else None + if len(roop.globals.INPUT_FACESETS) <= selected_index: + selected_index = 0 + options = ProcessOptions( + swap_model, + get_processing_plugins(masking_engine), + roop.globals.distance_threshold, + roop.globals.blend_ratio, + roop.globals.face_swap_mode, + selected_index, + new_clip_text, + mask, + num_swap_steps, + roop.globals.subsample_size, + False, + restore_original_mouth, + ) + process_mgr.initialize( + roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options + ) + batch_process(output_method, files, use_new_method) + return + + +def batch_process_with_options(files: list[ProcessEntry], options, progress): + global clip_text, process_mgr + + release_resources() + limit_resources() + if process_mgr is None: + process_mgr = ProcessMgr(progress) + process_mgr.initialize( + roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options + ) + roop.globals.keep_frames = False + roop.globals.wait_after_extraction = False + roop.globals.skip_audio = False + batch_process("Files", files, True) + + +def batch_process(output_method, files: list[ProcessEntry], use_new_method) -> None: + global clip_text, process_mgr + + roop.globals.processing = True + + # limit threads for some providers + max_threads = suggest_execution_threads() + if max_threads == 1: + roop.globals.execution_threads = 1 + + imagefiles: list[ProcessEntry] = [] + videofiles: list[ProcessEntry] = [] + + update_status("Sorting videos/images") + + for index, f in enumerate(files): + fullname = f.filename + if util.has_image_extension(fullname): + destination = util.get_destfilename_from_path( + fullname, + roop.globals.output_path, + f".{roop.globals.CFG.output_image_format}", + ) + destination = util.replace_template(destination, index=index) + pathlib.Path(os.path.dirname(destination)).mkdir( + parents=True, exist_ok=True + ) + f.finalname = destination + imagefiles.append(f) + + elif util.is_video(fullname) or util.has_extension(fullname, ["gif"]): + destination = 
util.get_destfilename_from_path( + fullname, + roop.globals.output_path, + f"__temp.{roop.globals.CFG.output_video_format}", + ) + f.finalname = destination + videofiles.append(f) + + if len(imagefiles) > 0: + update_status("Processing image(s)") + origimages = [] + fakeimages = [] + for f in imagefiles: + origimages.append(f.filename) + fakeimages.append(f.finalname) + + process_mgr.run_batch(origimages, fakeimages, roop.globals.execution_threads) + origimages.clear() + fakeimages.clear() + + if len(videofiles) > 0: + for index, v in enumerate(videofiles): + if not roop.globals.processing: + end_processing("Processing stopped!") + return + fps = v.fps if v.fps > 0 else util.detect_fps(v.filename) + if v.endframe == 0: + v.endframe = get_video_frame_total(v.filename) + + is_streaming_only = output_method == "Virtual Camera" + if is_streaming_only == False: + update_status( + f"Creating {os.path.basename(v.finalname)} with {fps} FPS..." + ) + + start_processing = time() + if ( + is_streaming_only == False + and roop.globals.keep_frames + or not use_new_method + ): + util.create_temp(v.filename) + update_status("Extracting frames...") + ffmpeg.extract_frames(v.filename, v.startframe, v.endframe, fps) + if not roop.globals.processing: + end_processing("Processing stopped!") + return + + temp_frame_paths = util.get_temp_frame_paths(v.filename) + process_mgr.run_batch( + temp_frame_paths, temp_frame_paths, roop.globals.execution_threads + ) + if not roop.globals.processing: + end_processing("Processing stopped!") + return + if roop.globals.wait_after_extraction: + extract_path = os.path.dirname(temp_frame_paths[0]) + util.open_folder(extract_path) + input("Press any key to continue...") + print("Resorting frames to create video") + util.sort_rename_frames(extract_path) + + ffmpeg.create_video(v.filename, v.finalname, fps) + if not roop.globals.keep_frames: + util.delete_temp_frames(temp_frame_paths[0]) + else: + if util.has_extension(v.filename, ["gif"]): + skip_audio = True + else: + skip_audio = roop.globals.skip_audio + process_mgr.run_batch_inmem( + output_method, + v.filename, + v.finalname, + v.startframe, + v.endframe, + fps, + roop.globals.execution_threads, + ) + + if not roop.globals.processing: + end_processing("Processing stopped!") + return + + video_file_name = v.finalname + if os.path.isfile(video_file_name): + destination = "" + if util.has_extension(v.filename, ["gif"]): + gifname = util.get_destfilename_from_path( + v.filename, roop.globals.output_path, ".gif" + ) + destination = util.replace_template(gifname, index=index) + pathlib.Path(os.path.dirname(destination)).mkdir( + parents=True, exist_ok=True + ) + + update_status("Creating final GIF") + ffmpeg.create_gif_from_video(video_file_name, destination) + if os.path.isfile(destination): + os.remove(video_file_name) + else: + skip_audio = roop.globals.skip_audio + destination = util.replace_template(video_file_name, index=index) + pathlib.Path(os.path.dirname(destination)).mkdir( + parents=True, exist_ok=True + ) + + if not skip_audio: + ffmpeg.restore_audio( + video_file_name, + v.filename, + v.startframe, + v.endframe, + destination, + ) + if os.path.isfile(destination): + os.remove(video_file_name) + else: + shutil.move(video_file_name, destination) + + elif is_streaming_only == False: + update_status(f"Failed processing {os.path.basename(v.finalname)}!") + elapsed_time = time() - start_processing + average_fps = (v.endframe - v.startframe) / elapsed_time + update_status( + f"\nProcessing {os.path.basename(destination)} 
took {elapsed_time:.2f} secs, {average_fps:.2f} frames/s" + ) + end_processing("Finished") + + +def end_processing(msg: str): + update_status(msg) + roop.globals.target_folder_path = None + release_resources() + + +def destroy() -> None: + if roop.globals.target_path: + util.clean_temp(roop.globals.target_path) + release_resources() + sys.exit() + + +def run() -> None: + parse_args() + if not pre_check(): + return + roop.globals.CFG = Settings("config.yaml") + roop.globals.cuda_device_id = roop.globals.startup_args.cuda_device_id + roop.globals.execution_threads = roop.globals.CFG.max_threads + roop.globals.video_encoder = roop.globals.CFG.output_video_codec + roop.globals.video_quality = roop.globals.CFG.video_quality + roop.globals.max_memory = ( + roop.globals.CFG.memory_limit if roop.globals.CFG.memory_limit > 0 else None + ) + if roop.globals.startup_args.server_share: + roop.globals.CFG.server_share = True + main.run() diff --git a/face_util.py b/face_util.py new file mode 100644 index 0000000000000000000000000000000000000000..38ca259f1abb4f292d7f15992960dd7cacac261a --- /dev/null +++ b/face_util.py @@ -0,0 +1,352 @@ +import threading +from typing import Any +import insightface + +import roop.globals +from roop.typing import Frame, Face + +import cv2 +import numpy as np +from skimage import transform as trans +from roop.capturer import get_video_frame +from roop.utilities import resolve_relative_path, conditional_thread_semaphore + +FACE_ANALYSER = None +# THREAD_LOCK_ANALYSER = threading.Lock() +# THREAD_LOCK_SWAPPER = threading.Lock() +FACE_SWAPPER = None + + +def get_face_analyser() -> Any: + global FACE_ANALYSER + + with conditional_thread_semaphore(): + if ( + FACE_ANALYSER is None + or roop.globals.g_current_face_analysis + != roop.globals.g_desired_face_analysis + ): + model_path = resolve_relative_path("..") + # removed genderage + allowed_modules = roop.globals.g_desired_face_analysis + roop.globals.g_current_face_analysis = roop.globals.g_desired_face_analysis + if roop.globals.CFG.force_cpu: + print("Forcing CPU for Face Analysis") + FACE_ANALYSER = insightface.app.FaceAnalysis( + name="buffalo_l", + root=model_path, + providers=["CPUExecutionProvider"], + allowed_modules=allowed_modules, + ) + else: + FACE_ANALYSER = insightface.app.FaceAnalysis( + name="buffalo_l", + root=model_path, + providers=roop.globals.execution_providers, + allowed_modules=allowed_modules, + ) + FACE_ANALYSER.prepare( + ctx_id=0, + det_size=(640, 640) if roop.globals.default_det_size else (320, 320), + ) + return FACE_ANALYSER + + +def get_first_face(frame: Frame) -> Any: + try: + faces = get_face_analyser().get(frame) + return min(faces, key=lambda x: x.bbox[0]) + # return sorted(faces, reverse=True, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[0] + except: + return None + + +def get_all_faces(frame: Frame) -> Any: + try: + faces = get_face_analyser().get(frame) + return sorted(faces, key=lambda x: x.bbox[0]) + except: + return None + + +def extract_face_images(source_filename, video_info, extra_padding=-1.0): + face_data = [] + source_image = None + + if video_info[0]: + frame = get_video_frame(source_filename, video_info[1]) + if frame is not None: + source_image = frame + else: + return face_data + else: + source_image = cv2.imdecode( + np.fromfile(source_filename, dtype=np.uint8), cv2.IMREAD_COLOR + ) + + faces = get_all_faces(source_image) + if faces is None: + return face_data + + i = 0 + for face in faces: + (startX, startY, endX, endY) = face["bbox"].astype("int") + 
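+# Usage sketch for the detector helpers above: both return None on failure,
+# so callers must guard. The image path is a placeholder; imdecode+fromfile
+# mirrors how extract_face_images() below reads files (it survives
+# non-ASCII paths on Windows, unlike cv2.imread).
+img = cv2.imdecode(np.fromfile("example.jpg", dtype=np.uint8), cv2.IMREAD_COLOR)
+faces = get_all_faces(img)
+if faces is not None:
+    for f in faces:                      # sorted left-to-right by bbox x
+        x1, y1, x2, y2 = f["bbox"].astype("int")
+        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)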
startX, endX, startY, endY = clamp_cut_values( + startX, endX, startY, endY, source_image + ) + if extra_padding > 0.0: + if source_image.shape[:2] == (512, 512): + i += 1 + face_data.append([face, source_image]) + continue + + found = False + for i in range(1, 3): + (startX, startY, endX, endY) = face["bbox"].astype("int") + startX, endX, startY, endY = clamp_cut_values( + startX, endX, startY, endY, source_image + ) + cutout_padding = extra_padding + # top needs extra room for detection + padding = int((endY - startY) * cutout_padding) + oldY = startY + startY -= padding + + factor = 0.25 if i == 1 else 0.5 + cutout_padding = factor + padding = int((endY - oldY) * cutout_padding) + endY += padding + padding = int((endX - startX) * cutout_padding) + startX -= padding + endX += padding + startX, endX, startY, endY = clamp_cut_values( + startX, endX, startY, endY, source_image + ) + face_temp = source_image[startY:endY, startX:endX] + face_temp = resize_image_keep_content(face_temp) + testfaces = get_all_faces(face_temp) + if testfaces is not None and len(testfaces) > 0: + i += 1 + face_data.append([testfaces[0], face_temp]) + found = True + break + + if not found: + print("No face found after resizing, this shouldn't happen!") + continue + + face_temp = source_image[startY:endY, startX:endX] + if face_temp.size < 1: + continue + + i += 1 + face_data.append([face, face_temp]) + return face_data + + +def clamp_cut_values(startX, endX, startY, endY, image): + if startX < 0: + startX = 0 + if endX > image.shape[1]: + endX = image.shape[1] + if startY < 0: + startY = 0 + if endY > image.shape[0]: + endY = image.shape[0] + return startX, endX, startY, endY + + +def face_offset_top(face: Face, offset): + face["bbox"][1] += offset + face["bbox"][3] += offset + lm106 = face.landmark_2d_106 + add = np.full_like(lm106, [0, offset]) + face["landmark_2d_106"] = lm106 + add + return face + + +def resize_image_keep_content(image, new_width=512, new_height=512): + dim = None + (h, w) = image.shape[:2] + if h > w: + r = new_height / float(h) + dim = (int(w * r), new_height) + else: + # Calculate the ratio of the width and construct the dimensions + r = new_width / float(w) + dim = (new_width, int(h * r)) + image = cv2.resize(image, dim, interpolation=cv2.INTER_AREA) + (h, w) = image.shape[:2] + if h == new_height and w == new_width: + return image + resize_img = np.zeros(shape=(new_height, new_width, 3), dtype=image.dtype) + offs = (new_width - w) if h == new_height else (new_height - h) + startoffs = int(offs // 2) if offs % 2 == 0 else int(offs // 2) + 1 + offs = int(offs // 2) + + if h == new_height: + resize_img[0:new_height, startoffs : new_width - offs] = image + else: + resize_img[startoffs : new_height - offs, 0:new_width] = image + return resize_img + + +def rotate_image_90(image, rotate=True): + if rotate: + return np.rot90(image) + else: + return np.rot90(image, 1, (1, 0)) + + +def rotate_anticlockwise(frame): + return rotate_image_90(frame) + + +def rotate_clockwise(frame): + return rotate_image_90(frame, False) + + +def rotate_image_180(image): + return np.flip(image, 0) + + +# alignment code from insightface https://github.com/deepinsight/insightface/blob/master/python-package/insightface/utils/face_align.py + +arcface_dst = np.array( + [ + [38.2946, 51.6963], + [73.5318, 51.5014], + [56.0252, 71.7366], + [41.5493, 92.3655], + [70.7299, 92.2041], + ], + dtype=np.float32, +) + + +""" def estimate_norm(lmk, image_size=112): + assert lmk.shape == (5, 2) + if image_size % 112 == 0: + ratio = 
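+# resize_image_keep_content() letterboxes rather than stretches: the longer
+# side is scaled to the target and the remainder is zero-padded, preserving
+# the aspect ratio. Quick check with a synthetic 200x400 image:
+probe = np.full((200, 400, 3), 255, dtype=np.uint8)
+padded = resize_image_keep_content(probe)
+assert padded.shape == (512, 512, 3)
+# the scaled 256x512 content sits centered in rows 128..384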
float(image_size) / 112.0 + diff_x = 0 + elif image_size % 128 == 0: + ratio = float(image_size) / 128.0 + diff_x = 8.0 * ratio + elif image_size % 512 == 0: + ratio = float(image_size) / 512.0 + diff_x = 32.0 * ratio + + dst = arcface_dst * ratio + dst[:, 0] += diff_x + tform = trans.SimilarityTransform() + tform.estimate(lmk, dst) + M = tform.params[0:2, :] + return M + """ + + +def estimate_norm(lmk, image_size=112): + if image_size % 112 == 0: + ratio = float(image_size) / 112.0 + diff_x = 0 + else: + ratio = float(image_size) / 128.0 + diff_x = 8.0 * ratio + dst = arcface_dst * ratio + dst[:, 0] += diff_x + + if image_size == 160: + dst[:, 0] += 0.1 + dst[:, 1] += 0.1 + elif image_size == 256: + dst[:, 0] += 0.5 + dst[:, 1] += 0.5 + elif image_size == 320: + dst[:, 0] += 0.75 + dst[:, 1] += 0.75 + elif image_size == 512: + dst[:, 0] += 1.5 + dst[:, 1] += 1.5 + + tform = trans.SimilarityTransform() + tform.estimate(lmk, dst) + M = tform.params[0:2, :] + return M + + +# aligned, M = norm_crop2(f[1], face.kps, 512) +def align_crop(img, landmark, image_size=112, mode="arcface"): + M = estimate_norm(landmark, image_size) + warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) + return warped, M + + +def square_crop(im, S): + if im.shape[0] > im.shape[1]: + height = S + width = int(float(im.shape[1]) / im.shape[0] * S) + scale = float(S) / im.shape[0] + else: + width = S + height = int(float(im.shape[0]) / im.shape[1] * S) + scale = float(S) / im.shape[1] + resized_im = cv2.resize(im, (width, height)) + det_im = np.zeros((S, S, 3), dtype=np.uint8) + det_im[: resized_im.shape[0], : resized_im.shape[1], :] = resized_im + return det_im, scale + + +def transform(data, center, output_size, scale, rotation): + scale_ratio = scale + rot = float(rotation) * np.pi / 180.0 + # translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) + t1 = trans.SimilarityTransform(scale=scale_ratio) + cx = center[0] * scale_ratio + cy = center[1] * scale_ratio + t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) + t3 = trans.SimilarityTransform(rotation=rot) + t4 = trans.SimilarityTransform(translation=(output_size / 2, output_size / 2)) + t = t1 + t2 + t3 + t4 + M = t.params[0:2] + cropped = cv2.warpAffine(data, M, (output_size, output_size), borderValue=0.0) + return cropped, M + + +def trans_points2d(pts, M): + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32) + new_pt = np.dot(M, new_pt) + # print('new_pt', new_pt.shape, new_pt) + new_pts[i] = new_pt[0:2] + + return new_pts + + +def trans_points3d(pts, M): + scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) + # print(scale) + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32) + new_pt = np.dot(M, new_pt) + # print('new_pt', new_pt.shape, new_pt) + new_pts[i][0:2] = new_pt[0:2] + new_pts[i][2] = pts[i][2] * scale + + return new_pts + + +def trans_points(pts, M): + if pts.shape[1] == 2: + return trans_points2d(pts, M) + else: + return trans_points3d(pts, M) + + +def create_blank_image(width, height): + img = np.zeros((height, width, 4), dtype=np.uint8) + img[:] = [0, 0, 0, 0] + return img diff --git a/ffmpeg_writer.py b/ffmpeg_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..b0fdb9d0a78daf6b78ea7baeb60f2252f539b159 --- /dev/null +++ b/ffmpeg_writer.py 
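+# Sketch: aligning a detected face to the ArcFace template at 512x512, as in
+# the norm_crop2 comment above. face.kps holds the five detector keypoints;
+# M maps source coordinates into the aligned crop (invert it to paste back).
+face = get_first_face(img)               # img: any BGR frame
+if face is not None:
+    aligned, M = align_crop(img, face.kps, 512)
+    assert aligned.shape[:2] == (512, 512)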
@@ -0,0 +1,240 @@
+"""
+FFMPEG_Writer - write a set of frames to a video file
+
+originally from
+https://github.com/Zulko/moviepy/blob/master/moviepy/video/io/ffmpeg_writer.py
+
+with unnecessary dependencies removed
+
+The MIT License (MIT)
+
+Copyright (c) 2015 Zulko
+Copyright (c) 2023 Janvarev Vladislav
+"""
+
+import os
+import subprocess as sp
+
+PIPE = -1
+STDOUT = -2
+DEVNULL = -3
+
+FFMPEG_BINARY = "ffmpeg"
+
+
+class FFMPEG_VideoWriter:
+    """A class for FFMPEG-based video writing.
+
+    Frames are piped to an ffmpeg subprocess, which can encode to any
+    format the local ffmpeg build supports.
+
+    Parameters
+    -----------
+
+    filename
+      Any filename like 'video.mp4'. If you want to avoid codec/container
+      complications, the generic '.avi' extension works with most codecs.
+
+    size
+      Size (width, height) of the output video in pixels.
+
+    fps
+      Frames per second of the output video file.
+
+    codec
+      FFMPEG codec. In terms of quality the rough hierarchy is
+      'rawvideo' = 'png' > 'mpeg4' > 'libx264';
+      'png' gives the same lossless quality as 'rawvideo' but yields
+      smaller files. Type ``ffmpeg -codecs`` in a terminal to get a list
+      of accepted codecs.
+
+      Note: this writer defaults to 'libx265' and always requests the
+      yuv420p pixel format; since yuv420p needs even dimensions, odd
+      widths/heights are scaled down to the nearest even value.
+
+    audiofile
+      Optional: the name of an audio file to incorporate into the video.
+
+    preset
+      Sets the time that FFMPEG will take to compress the video. Slower
+      presets give better compression. Possibilities are: ultrafast,
+      superfast, veryfast, faster, fast, medium (default), slow, slower,
+      veryslow, placebo. (Currently commented out of the ffmpeg command.)
+
+    bitrate
+      Only relevant for codecs which accept a bitrate. "5000k" gives
+      nice results in general.
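+
+    Example
+    --------
+    A minimal writing loop; frames are uint8 BGR arrays of shape
+    (height, width, 3), matching the '-pix_fmt bgr24' input set up below::
+
+        writer = FFMPEG_VideoWriter("out.mp4", (1280, 720), 30.0)
+        for frame in frames:
+            writer.write_frame(frame)
+        writer.close()
+
+    The class also supports the context-manager protocol, so a ``with``
+    block closes the pipe automatically.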
+ + """ + + def __init__( + self, + filename, + size, + fps, + codec="libx265", + crf=14, + audiofile=None, + preset="medium", + bitrate=None, + logfile=None, + threads=None, + ffmpeg_params=None, + ): + if logfile is None: + logfile = sp.PIPE + + self.filename = filename + self.codec = codec + self.ext = self.filename.split(".")[-1] + w = size[0] - 1 if size[0] % 2 != 0 else size[0] + h = size[1] - 1 if size[1] % 2 != 0 else size[1] + + # order is important + cmd = [ + FFMPEG_BINARY, + "-hide_banner", + "-hwaccel", + "auto", + "-y", + "-loglevel", + "error" if logfile == sp.PIPE else "info", + "-f", + "rawvideo", + "-vcodec", + "rawvideo", + "-s", + "%dx%d" % (size[0], size[1]), + #'-pix_fmt', 'rgba' if withmask else 'rgb24', + "-pix_fmt", + "bgr24", + "-r", + str(fps), + "-an", + "-i", + "-", + ] + + if audiofile is not None: + cmd.extend(["-i", audiofile, "-acodec", "copy"]) + + cmd.extend( + [ + "-vcodec", + codec, + "-crf", + str(crf), + #'-preset', preset, + ] + ) + if ffmpeg_params is not None: + cmd.extend(ffmpeg_params) + if bitrate is not None: + cmd.extend(["-b", bitrate]) + + # scale to a resolution divisible by 2 if not even + cmd.extend( + [ + "-vf", + f"scale={w}:{h}" + if w != size[0] or h != size[1] + else "colorspace=bt709:iall=bt601-6-625:fast=1", + ] + ) + + if threads is not None: + cmd.extend(["-threads", str(threads)]) + + cmd.extend( + [ + "-pix_fmt", + "yuv420p", + ] + ) + cmd.extend([filename]) + + test = str(cmd) + print(test) + + popen_params = {"stdout": DEVNULL, "stderr": logfile, "stdin": sp.PIPE} + + # This was added so that no extra unwanted window opens on windows + # when the child process is created + if os.name == "nt": + popen_params["creationflags"] = 0x08000000 # CREATE_NO_WINDOW + + self.proc = sp.Popen(cmd, **popen_params) + + def write_frame(self, img_array): + """Writes one frame in the file.""" + try: + # if PY3: + self.proc.stdin.write(img_array.tobytes()) + # else: + # self.proc.stdin.write(img_array.tostring()) + except IOError as err: + _, ffmpeg_error = self.proc.communicate() + error = str(err) + ( + "\n\nroop unleashed error: FFMPEG encountered " + "the following error while writing file %s:" + "\n\n %s" % (self.filename, str(ffmpeg_error)) + ) + + if b"Unknown encoder" in ffmpeg_error: + error = error + ( + "\n\nThe video export " + "failed because FFMPEG didn't find the specified " + "codec for video encoding (%s). Please install " + "this codec or change the codec when calling " + "write_videofile. For instance:\n" + " >>> clip.write_videofile('myvid.webm', codec='libvpx')" + ) % (self.codec) + + elif b"incorrect codec parameters ?" in ffmpeg_error: + error = error + ( + "\n\nThe video export " + "failed, possibly because the codec specified for " + "the video (%s) is not compatible with the given " + "extension (%s). Please specify a valid 'codec' " + "argument in write_videofile. This would be 'libx264' " + "or 'mpeg4' for mp4, 'libtheora' for ogv, 'libvpx for webm. " + "Another possible reason is that the audio codec was not " + "compatible with the video codec. For instance the video " + "extensions 'ogv' and 'webm' only allow 'libvorbis' (default) as a" + "video codec." + ) % (self.codec, self.ext) + + elif b"encoder setup failed" in ffmpeg_error: + error = error + ( + "\n\nThe video export " + "failed, possibly because the bitrate you specified " + "was too high or too low for the video codec." 
+ ) + + elif b"Invalid encoder type" in ffmpeg_error: + error = error + ( + "\n\nThe video export failed because the codec " + "or file extension you provided is not a video" + ) + + raise IOError(error) + + def close(self): + if self.proc: + self.proc.stdin.close() + if self.proc.stderr is not None: + self.proc.stderr.close() + self.proc.wait() + + self.proc = None + + # Support the Context Manager protocol, to ensure that resources are cleaned up. + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() diff --git a/globals.py b/globals.py new file mode 100644 index 0000000000000000000000000000000000000000..ca0bd50e4a8ea0951e3b2b368ed228f284aae2b3 --- /dev/null +++ b/globals.py @@ -0,0 +1,54 @@ +from settings import Settings +from typing import List + +source_path = None +target_path = None +output_path = None +target_folder_path = None +startup_args = None + +cuda_device_id = 0 +frame_processors: List[str] = [] +keep_fps = None +keep_frames = None +autorotate_faces = None +vr_mode = None +skip_audio = None +wait_after_extraction = None +many_faces = None +use_batch = None +source_face_index = 0 +target_face_index = 0 +face_position = None +video_encoder = None +video_quality = None +max_memory = None +execution_providers: List[str] = [] +execution_threads = None +headless = None +log_level = "error" +selected_enhancer = None +subsample_size = 128 +face_swap_mode = None +blend_ratio = 0.5 +distance_threshold = 0.65 +default_det_size = True + +no_face_action = 0 + +processing = False + +g_current_face_analysis = None +g_desired_face_analysis = None + +FACE_ENHANCER = None + +INPUT_FACESETS = [] +TARGET_FACES = [] + + +IMAGE_CHAIN_PROCESSOR = None +VIDEO_CHAIN_PROCESSOR = None +BATCH_IMAGE_CHAIN_PROCESSOR = None + +CFG: Settings = None diff --git a/metadata.py b/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..f334619ca2952c648bbe3b50571588974c63c3b0 --- /dev/null +++ b/metadata.py @@ -0,0 +1,2 @@ +name = "roop unleashed" +version = "4.4.1" diff --git a/processors/Enhance_CodeFormer.py b/processors/Enhance_CodeFormer.py new file mode 100644 index 0000000000000000000000000000000000000000..985272e14535368347c06da7db5e1c30010d4c82 --- /dev/null +++ b/processors/Enhance_CodeFormer.py @@ -0,0 +1,76 @@ +from typing import Any, List, Callable +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.typing import Face, Frame, FaceSet +from roop.utilities import resolve_relative_path + + +class Enhance_CodeFormer: + model_codeformer = None + + plugin_options: dict = None + + processorname = "codeformer" + type = "enhance" + + def Initialize(self, plugin_options: dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_codeformer is None: + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace("mps", "cpu") + model_path = resolve_relative_path( + "../models/CodeFormer/CodeFormerv0.1.onnx" + ) + self.model_codeformer = onnxruntime.InferenceSession( + model_path, None, providers=roop.globals.execution_providers + ) + self.model_inputs = self.model_codeformer.get_inputs() + model_outputs = self.model_codeformer.get_outputs() + self.io_binding = self.model_codeformer.io_binding() + self.io_binding.bind_cpu_input(self.model_inputs[1].name, np.array([0.5])) + self.io_binding.bind_output(model_outputs[0].name, 
self.devicename) + + def Run( + self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame + ) -> Frame: + input_size = temp_frame.shape[1] + # preprocess + temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC) + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) + temp_frame = temp_frame.astype("float32") / 255.0 + temp_frame = (temp_frame - 0.5) / 0.5 + temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2) + + self.io_binding.bind_cpu_input( + self.model_inputs[0].name, temp_frame.astype(np.float32) + ) + self.model_codeformer.run_with_iobinding(self.io_binding) + ort_outs = self.io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + del ort_outs + + # post-process + result = result.transpose((1, 2, 0)) + + un_min = -1.0 + un_max = 1.0 + result = np.clip(result, un_min, un_max) + result = (result - un_min) / (un_max - un_min) + + result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) + result = (result * 255.0).round() + scale_factor = int(result.shape[1] / input_size) + return result.astype(np.uint8), scale_factor + + def Release(self): + del self.model_codeformer + self.model_codeformer = None + del self.io_binding + self.io_binding = None diff --git a/processors/Enhance_DMDNet.py b/processors/Enhance_DMDNet.py new file mode 100644 index 0000000000000000000000000000000000000000..bfacae0c1ec81580887f85144f713512fbd92006 --- /dev/null +++ b/processors/Enhance_DMDNet.py @@ -0,0 +1,1425 @@ +from typing import Any, List, Callable +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.utils.spectral_norm as SpectralNorm +import threading +from torchvision.ops import roi_align + +from math import sqrt + +from torchvision.transforms.functional import normalize + +from roop.typing import Face, Frame, FaceSet + + +THREAD_LOCK_DMDNET = threading.Lock() + + +class Enhance_DMDNet: + plugin_options: dict = None + model_dmdnet = None + torchdevice = None + + processorname = "dmdnet" + type = "enhance" + + def Initialize(self, plugin_options: dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_dmdnet is None: + self.model_dmdnet = self.create(self.plugin_options["devicename"]) + + # temp_frame already cropped+aligned, bbox not + def Run( + self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame + ) -> Frame: + input_size = temp_frame.shape[1] + + result = self.enhance_face(source_faceset, temp_frame, target_face) + scale_factor = int(result.shape[1] / input_size) + return result.astype(np.uint8), scale_factor + + def Release(self): + self.model_dmdnet = None + + # https://stackoverflow.com/a/67174339 + def landmarks106_to_68(self, pt106): + map106to68 = [ + 1, + 10, + 12, + 14, + 16, + 3, + 5, + 7, + 0, + 23, + 21, + 19, + 32, + 30, + 28, + 26, + 17, + 43, + 48, + 49, + 51, + 50, + 102, + 103, + 104, + 105, + 101, + 72, + 73, + 74, + 86, + 78, + 79, + 80, + 85, + 84, + 35, + 41, + 42, + 39, + 37, + 36, + 89, + 95, + 96, + 93, + 91, + 90, + 52, + 64, + 63, + 71, + 67, + 68, + 61, + 58, + 59, + 53, + 56, + 55, + 65, + 66, + 62, + 70, + 69, + 57, + 60, + 54, + ] + + pt68 = [] + for i in range(68): + index = map106to68[i] + pt68.append(pt106[index]) + return pt68 + + def check_bbox(self, imgs, boxes): + boxes = boxes.view(-1, 4, 4) + colors = [(0, 255, 0), (0, 255, 0), (255, 255, 0), (255, 0, 0)] + i = 0 + for img, box in zip(imgs, boxes): + img = (img 
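+# Both enhancers share one tensor convention: RGB, NCHW, values scaled to
+# [-1, 1] via (x / 255 - 0.5) / 0.5; check_bbox's (img + 1) / 2 * 255 is the
+# inverse mapping. Round-trip sketch:
+x = np.random.randint(0, 256, (512, 512, 3)).astype("float32")
+t = (x / 255.0 - 0.5) / 0.5              # [0, 255] -> [-1, 1]
+assert np.allclose((t + 1.0) / 2.0 * 255.0, x)
+# Aside on Enhance_CodeFormer.Run above: cv2.resize's third positional
+# parameter is dst, not interpolation, so flags like cv2.INTER_CUBIC only
+# take effect when passed as interpolation=cv2.INTER_CUBIC.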
+ 1) / 2 * 255 + img2 = img.permute(1, 2, 0).float().cpu().flip(2).numpy().copy() + for idx, point in enumerate(box): + cv2.rectangle( + img2, + (int(point[0]), int(point[1])), + (int(point[2]), int(point[3])), + color=colors[idx], + thickness=2, + ) + cv2.imwrite("dmdnet_{:02d}.png".format(i), img2) + i += 1 + + def trans_points2d(self, pts, M): + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32) + new_pt = np.dot(M, new_pt) + new_pts[i] = new_pt[0:2] + + return new_pts + + def enhance_face(self, ref_faceset: FaceSet, temp_frame, face: Face): + # preprocess + start_x, start_y, end_x, end_y = map(int, face["bbox"]) + lm106 = face.landmark_2d_106 + lq_landmarks = np.asarray(self.landmarks106_to_68(lm106)) + + if temp_frame.shape[0] != 512 or temp_frame.shape[1] != 512: + # scale to 512x512 + scale_factor = 512 / temp_frame.shape[1] + + M = face.matrix * scale_factor + + lq_landmarks = self.trans_points2d(lq_landmarks, M) + temp_frame = cv2.resize( + temp_frame, (512, 512), interpolation=cv2.INTER_AREA + ) + + if temp_frame.ndim == 2: + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG + # else: + # temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB + + lq = read_img_tensor(temp_frame) + + LQLocs = get_component_location(lq_landmarks) + # self.check_bbox(lq, LQLocs.unsqueeze(0)) + + # specific, change 1000 to 1 to activate + if len(ref_faceset.faces) > 1: + SpecificImgs = [] + SpecificLocs = [] + for i, face in enumerate(ref_faceset.faces): + lm106 = face.landmark_2d_106 + lq_landmarks = np.asarray(self.landmarks106_to_68(lm106)) + ref_image = ref_faceset.ref_images[i] + if ref_image.shape[0] != 512 or ref_image.shape[1] != 512: + # scale to 512x512 + scale_factor = 512 / ref_image.shape[1] + + M = face.matrix * scale_factor + + lq_landmarks = self.trans_points2d(lq_landmarks, M) + ref_image = cv2.resize( + ref_image, (512, 512), interpolation=cv2.INTER_AREA + ) + + if ref_image.ndim == 2: + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG + # else: + # temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB + + ref_tensor = read_img_tensor(ref_image) + ref_locs = get_component_location(lq_landmarks) + # self.check_bbox(ref_tensor, ref_locs.unsqueeze(0)) + + SpecificImgs.append(ref_tensor) + SpecificLocs.append(ref_locs.unsqueeze(0)) + + SpecificImgs = torch.cat(SpecificImgs, dim=0) + SpecificLocs = torch.cat(SpecificLocs, dim=0) + # check_bbox(SpecificImgs, SpecificLocs) + SpMem256, SpMem128, SpMem64 = ( + self.model_dmdnet.generate_specific_dictionary( + sp_imgs=SpecificImgs.to(self.torchdevice), sp_locs=SpecificLocs + ) + ) + SpMem256Para = {} + SpMem128Para = {} + SpMem64Para = {} + for k, v in SpMem256.items(): + SpMem256Para[k] = v + for k, v in SpMem128.items(): + SpMem128Para[k] = v + for k, v in SpMem64.items(): + SpMem64Para[k] = v + else: + # generic + SpMem256Para, SpMem128Para, SpMem64Para = None, None, None + + with torch.no_grad(): + with THREAD_LOCK_DMDNET: + try: + GenericResult, SpecificResult = self.model_dmdnet( + lq=lq.to(self.torchdevice), + loc=LQLocs.unsqueeze(0), + sp_256=SpMem256Para, + sp_128=SpMem128Para, + sp_64=SpMem64Para, + ) + except Exception as e: + print( + f"Error {e} there may be something wrong with the detected component locations." 
+ ) + return temp_frame + + if SpecificResult is not None: + save_specific = SpecificResult * 0.5 + 0.5 + save_specific = ( + save_specific.squeeze(0).permute(1, 2, 0).flip(2) + ) # RGB->BGR + save_specific = np.clip(save_specific.float().cpu().numpy(), 0, 1) * 255.0 + temp_frame = save_specific.astype("uint8") + if False: + save_generic = GenericResult * 0.5 + 0.5 + save_generic = ( + save_generic.squeeze(0).permute(1, 2, 0).flip(2) + ) # RGB->BGR + save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0 + check_lq = lq * 0.5 + 0.5 + check_lq = check_lq.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR + check_lq = np.clip(check_lq.float().cpu().numpy(), 0, 1) * 255.0 + cv2.imwrite( + "dmdnet_comparison.png", + cv2.cvtColor( + np.hstack((check_lq, save_generic, save_specific)), + cv2.COLOR_RGB2BGR, + ), + ) + else: + save_generic = GenericResult * 0.5 + 0.5 + save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR + save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0 + temp_frame = save_generic.astype("uint8") + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_RGB2BGR) # RGB + return temp_frame + + def create(self, devicename): + self.torchdevice = torch.device(devicename) + model_dmdnet = DMDNet().to(self.torchdevice) + weights = torch.load("./models/DMDNet.pth", map_location=self.torchdevice) + model_dmdnet.load_state_dict(weights, strict=False) + + model_dmdnet.eval() + num_params = 0 + for param in model_dmdnet.parameters(): + num_params += param.numel() + return model_dmdnet + + # print('{:>8s} : {}'.format('Using device', device)) + # print('{:>8s} : {:.2f}M'.format('Model params', num_params/1e6)) + + +def read_img_tensor(Img=None): # rgb -1~1 + Img = Img.transpose((2, 0, 1)) / 255.0 + Img = torch.from_numpy(Img).float() + normalize(Img, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True) + ImgTensor = Img.unsqueeze(0) + return ImgTensor + + +def get_component_location(Landmarks, re_read=False): + if re_read: + ReadLandmark = [] + with open(Landmarks, "r") as f: + for line in f: + tmp = [float(i) for i in line.split(" ") if i != "\n"] + ReadLandmark.append(tmp) + ReadLandmark = np.array(ReadLandmark) # + Landmarks = np.reshape(ReadLandmark, [-1, 2]) # 68*2 + Map_LE_B = list(np.hstack((range(17, 22), range(36, 42)))) + Map_RE_B = list(np.hstack((range(22, 27), range(42, 48)))) + Map_LE = list(range(36, 42)) + Map_RE = list(range(42, 48)) + Map_NO = list(range(29, 36)) + Map_MO = list(range(48, 68)) + + Landmarks[Landmarks > 504] = 504 + Landmarks[Landmarks < 8] = 8 + + # left eye + Mean_LE = np.mean(Landmarks[Map_LE], 0) + L_LE1 = Mean_LE[1] - np.min(Landmarks[Map_LE_B, 1]) + L_LE1 = L_LE1 * 1.3 + L_LE2 = L_LE1 / 1.9 + L_LE_xy = L_LE1 + L_LE2 + L_LE_lt = [L_LE_xy / 2, L_LE1] + L_LE_rb = [L_LE_xy / 2, L_LE2] + Location_LE = np.hstack((Mean_LE - L_LE_lt + 1, Mean_LE + L_LE_rb)).astype(int) + + # right eye + Mean_RE = np.mean(Landmarks[Map_RE], 0) + L_RE1 = Mean_RE[1] - np.min(Landmarks[Map_RE_B, 1]) + L_RE1 = L_RE1 * 1.3 + L_RE2 = L_RE1 / 1.9 + L_RE_xy = L_RE1 + L_RE2 + L_RE_lt = [L_RE_xy / 2, L_RE1] + L_RE_rb = [L_RE_xy / 2, L_RE2] + Location_RE = np.hstack((Mean_RE - L_RE_lt + 1, Mean_RE + L_RE_rb)).astype(int) + + # nose + Mean_NO = np.mean(Landmarks[Map_NO], 0) + L_NO1 = ( + np.max([Mean_NO[0] - Landmarks[31][0], Landmarks[35][0] - Mean_NO[0]]) + ) * 1.25 + L_NO2 = (Landmarks[33][1] - Mean_NO[1]) * 1.1 + L_NO_xy = L_NO1 * 2 + L_NO_lt = [L_NO_xy / 2, L_NO_xy - L_NO2] + L_NO_rb = [L_NO_xy / 2, L_NO2] + Location_NO = np.hstack((Mean_NO 
- L_NO_lt + 1, Mean_NO + L_NO_rb)).astype(int) + + # mouth + Mean_MO = np.mean(Landmarks[Map_MO], 0) + L_MO = ( + np.max( + ( + np.max(np.max(Landmarks[Map_MO], 0) - np.min(Landmarks[Map_MO], 0)) / 2, + 16, + ) + ) + * 1.1 + ) + MO_O = Mean_MO - L_MO + 1 + MO_T = Mean_MO + L_MO + MO_T[MO_T > 510] = 510 + Location_MO = np.hstack((MO_O, MO_T)).astype(int) + return torch.cat( + [ + torch.FloatTensor(Location_LE).unsqueeze(0), + torch.FloatTensor(Location_RE).unsqueeze(0), + torch.FloatTensor(Location_NO).unsqueeze(0), + torch.FloatTensor(Location_MO).unsqueeze(0), + ], + dim=0, + ) + + +def calc_mean_std_4D(feat, eps=1e-5): + # eps is a small value added to the variance to avoid divide-by-zero. + size = feat.size() + assert len(size) == 4 + N, C = size[:2] + feat_var = feat.view(N, C, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(N, C, 1, 1) + feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization_4D( + content_feat, style_feat +): # content_feat is ref feature, style is degradate feature + size = content_feat.size() + style_mean, style_std = calc_mean_std_4D(style_feat) + content_mean, content_std = calc_mean_std_4D(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand( + size + ) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +def convU( + in_channels, + out_channels, + conv_layer, + norm_layer, + kernel_size=3, + stride=1, + dilation=1, + bias=True, +): + return nn.Sequential( + SpectralNorm( + conv_layer( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=((kernel_size - 1) // 2) * dilation, + bias=bias, + ) + ), + nn.LeakyReLU(0.2), + SpectralNorm( + conv_layer( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=((kernel_size - 1) // 2) * dilation, + bias=bias, + ) + ), + ) + + +class MSDilateBlock(nn.Module): + def __init__( + self, + in_channels, + conv_layer=nn.Conv2d, + norm_layer=nn.BatchNorm2d, + kernel_size=3, + dilation=[1, 1, 1, 1], + bias=True, + ): + super(MSDilateBlock, self).__init__() + self.conv1 = convU( + in_channels, + in_channels, + conv_layer, + norm_layer, + kernel_size, + dilation=dilation[0], + bias=bias, + ) + self.conv2 = convU( + in_channels, + in_channels, + conv_layer, + norm_layer, + kernel_size, + dilation=dilation[1], + bias=bias, + ) + self.conv3 = convU( + in_channels, + in_channels, + conv_layer, + norm_layer, + kernel_size, + dilation=dilation[2], + bias=bias, + ) + self.conv4 = convU( + in_channels, + in_channels, + conv_layer, + norm_layer, + kernel_size, + dilation=dilation[3], + bias=bias, + ) + self.convi = SpectralNorm( + conv_layer( + in_channels * 4, + in_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=bias, + ) + ) + + def forward(self, x): + conv1 = self.conv1(x) + conv2 = self.conv2(x) + conv3 = self.conv3(x) + conv4 = self.conv4(x) + cat = torch.cat([conv1, conv2, conv3, conv4], 1) + out = self.convi(cat) + x + return out + + +class AdaptiveInstanceNorm(nn.Module): + def __init__(self, in_channel): + super().__init__() + self.norm = nn.InstanceNorm2d(in_channel) + + def forward(self, input, style): + style_mean, style_std = calc_mean_std_4D(style) + out = self.norm(input) + size = input.size() + out = style_std.expand(size) * out + style_mean.expand(size) + return out + + +class NoiseInjection(nn.Module): + def __init__(self, 
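+# adaptive_instance_normalization_4D() is standard AdaIN: whiten the content
+# feature with its own per-channel statistics, then re-color it with the
+# style feature's statistics. Equivalence sketch:
+c = torch.randn(1, 8, 16, 16)            # content feature
+s = torch.randn(1, 8, 16, 16)            # style feature
+mean_c, std_c = calc_mean_std_4D(c)
+mean_s, std_s = calc_mean_std_4D(s)
+out = (c - mean_c) / std_c * std_s + mean_s
+assert torch.allclose(out, adaptive_instance_normalization_4D(c, s), atol=1e-5)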
channel): + super().__init__() + self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1)) + + def forward(self, image, noise): + if noise is None: + b, c, h, w = image.shape + noise = image.new_empty(b, 1, h, w).normal_() + return image + self.weight * noise + + +class StyledUpBlock(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size=3, + padding=1, + upsample=False, + noise_inject=False, + ): + super().__init__() + + self.noise_inject = noise_inject + if upsample: + self.conv1 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + SpectralNorm( + nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding) + ), + nn.LeakyReLU(0.2), + ) + else: + self.conv1 = nn.Sequential( + SpectralNorm( + nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding) + ), + nn.LeakyReLU(0.2), + SpectralNorm( + nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding) + ), + ) + self.convup = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + SpectralNorm( + nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding) + ), + nn.LeakyReLU(0.2), + SpectralNorm( + nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding) + ), + ) + if self.noise_inject: + self.noise1 = NoiseInjection(out_channel) + + self.lrelu1 = nn.LeakyReLU(0.2) + + self.ScaleModel1 = nn.Sequential( + SpectralNorm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), + ) + self.ShiftModel1 = nn.Sequential( + SpectralNorm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), + ) + + def forward(self, input, style): + out = self.conv1(input) + out = self.lrelu1(out) + Shift1 = self.ShiftModel1(style) + Scale1 = self.ScaleModel1(style) + out = out * Scale1 + Shift1 + if self.noise_inject: + out = self.noise1(out, noise=None) + outup = self.convup(out) + return outup + + +#################################################################### +###############Face Dictionary Generator +#################################################################### +def AttentionBlock(in_channel): + return nn.Sequential( + SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), + ) + + +class DilateResBlock(nn.Module): + def __init__(self, dim, dilation=[5, 3]): + super(DilateResBlock, self).__init__() + self.Res = nn.Sequential( + SpectralNorm( + nn.Conv2d(dim, dim, 3, 1, ((3 - 1) // 2) * dilation[0], dilation[0]) + ), + nn.LeakyReLU(0.2), + SpectralNorm( + nn.Conv2d(dim, dim, 3, 1, ((3 - 1) // 2) * dilation[1], dilation[1]) + ), + ) + + def forward(self, x): + out = x + self.Res(x) + return out + + +class KeyValue(nn.Module): + def __init__(self, indim, keydim, valdim): + super(KeyValue, self).__init__() + self.Key = nn.Sequential( + SpectralNorm( + nn.Conv2d(indim, keydim, kernel_size=(3, 3), padding=(1, 1), stride=1) + ), + nn.LeakyReLU(0.2), + SpectralNorm( + nn.Conv2d(keydim, keydim, kernel_size=(3, 3), padding=(1, 1), stride=1) + ), + ) + self.Value = nn.Sequential( + SpectralNorm( + nn.Conv2d(indim, valdim, kernel_size=(3, 3), padding=(1, 1), stride=1) + ), + nn.LeakyReLU(0.2), + SpectralNorm( + nn.Conv2d(valdim, valdim, kernel_size=(3, 3), padding=(1, 1), stride=1) + ), + ) + + def forward(self, x): + return self.Key(x), self.Value(x) + + +class MaskAttention(nn.Module): + def __init__(self, indim): 
+ super(MaskAttention, self).__init__() + self.conv1 = nn.Sequential( + SpectralNorm( + nn.Conv2d( + indim, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1 + ) + ), + nn.LeakyReLU(0.2), + SpectralNorm( + nn.Conv2d( + indim // 3, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1 + ) + ), + ) + self.conv2 = nn.Sequential( + SpectralNorm( + nn.Conv2d( + indim, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1 + ) + ), + nn.LeakyReLU(0.2), + SpectralNorm( + nn.Conv2d( + indim // 3, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1 + ) + ), + ) + self.conv3 = nn.Sequential( + SpectralNorm( + nn.Conv2d( + indim, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1 + ) + ), + nn.LeakyReLU(0.2), + SpectralNorm( + nn.Conv2d( + indim // 3, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1 + ) + ), + ) + self.convCat = nn.Sequential( + SpectralNorm( + nn.Conv2d( + indim // 3 * 3, indim, kernel_size=(3, 3), padding=(1, 1), stride=1 + ) + ), + nn.LeakyReLU(0.2), + SpectralNorm( + nn.Conv2d(indim, indim, kernel_size=(3, 3), padding=(1, 1), stride=1) + ), + ) + + def forward(self, x, y, z): + c1 = self.conv1(x) + c2 = self.conv2(y) + c3 = self.conv3(z) + return self.convCat(torch.cat([c1, c2, c3], dim=1)) + + +class Query(nn.Module): + def __init__(self, indim, quedim): + super(Query, self).__init__() + self.Query = nn.Sequential( + SpectralNorm( + nn.Conv2d(indim, quedim, kernel_size=(3, 3), padding=(1, 1), stride=1) + ), + nn.LeakyReLU(0.2), + SpectralNorm( + nn.Conv2d(quedim, quedim, kernel_size=(3, 3), padding=(1, 1), stride=1) + ), + ) + + def forward(self, x): + return self.Query(x) + + +def roi_align_self(input, location, target_size): + test = (target_size.item(), target_size.item()) + return torch.cat( + [ + F.interpolate( + input[ + i : i + 1, + :, + location[i, 1] : location[i, 3], + location[i, 0] : location[i, 2], + ], + test, + mode="bilinear", + align_corners=False, + ) + for i in range(input.size(0)) + ], + 0, + ) + + +class FeatureExtractor(nn.Module): + def __init__(self, ngf=64, key_scale=4): # + super().__init__() + + self.key_scale = 4 + self.part_sizes = np.array([80, 80, 50, 110]) # + self.feature_sizes = np.array([256, 128, 64]) # + + self.conv1 = nn.Sequential( + SpectralNorm(nn.Conv2d(3, ngf, 3, 2, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)), + ) + self.conv2 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)), + ) + self.res1 = DilateResBlock(ngf, [5, 3]) + self.res2 = DilateResBlock(ngf, [5, 3]) + + self.conv3 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf, ngf * 2, 3, 2, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf * 2, ngf * 2, 3, 1, 1)), + ) + self.conv4 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf * 2, ngf * 2, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf * 2, ngf * 2, 3, 1, 1)), + ) + self.res3 = DilateResBlock(ngf * 2, [3, 1]) + self.res4 = DilateResBlock(ngf * 2, [3, 1]) + + self.conv5 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf * 2, ngf * 4, 3, 2, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf * 4, ngf * 4, 3, 1, 1)), + ) + self.conv6 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf * 4, ngf * 4, 3, 1, 1)), + nn.LeakyReLU(0.2), + SpectralNorm(nn.Conv2d(ngf * 4, ngf * 4, 3, 1, 1)), + ) + self.res5 = DilateResBlock(ngf * 4, [1, 1]) + self.res6 = DilateResBlock(ngf * 4, [1, 1]) + + self.LE_256_Q = Query(ngf, ngf // self.key_scale) + self.RE_256_Q = Query(ngf, ngf // self.key_scale) + self.MO_256_Q 
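+# roi_align_self() crops one box per batch element and resizes every crop to
+# a common square so differently-sized parts can be batched. Shape sketch:
+feat = torch.randn(2, 64, 256, 256)
+boxes = torch.tensor([[10, 20, 90, 100], [30, 30, 110, 110]])  # x1, y1, x2, y2
+parts = roi_align_self(feat, boxes, torch.tensor(40))
+assert parts.shape == (2, 64, 40, 40)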
= Query(ngf, ngf // self.key_scale) + self.LE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale) + self.RE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale) + self.MO_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale) + self.LE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale) + self.RE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale) + self.MO_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale) + + def forward(self, img, locs): + le_location = locs[:, 0, :].int().cpu().numpy() + re_location = locs[:, 1, :].int().cpu().numpy() + no_location = locs[:, 2, :].int().cpu().numpy() + mo_location = locs[:, 3, :].int().cpu().numpy() + + f1_0 = self.conv1(img) + f1_1 = self.res1(f1_0) + f2_0 = self.conv2(f1_1) + f2_1 = self.res2(f2_0) + + f3_0 = self.conv3(f2_1) + f3_1 = self.res3(f3_0) + f4_0 = self.conv4(f3_1) + f4_1 = self.res4(f4_0) + + f5_0 = self.conv5(f4_1) + f5_1 = self.res5(f5_0) + f6_0 = self.conv6(f5_1) + f6_1 = self.res6(f6_0) + + ####ROI Align + le_part_256 = roi_align_self( + f2_1.clone(), le_location // 2, self.part_sizes[0] // 2 + ) + re_part_256 = roi_align_self( + f2_1.clone(), re_location // 2, self.part_sizes[1] // 2 + ) + mo_part_256 = roi_align_self( + f2_1.clone(), mo_location // 2, self.part_sizes[3] // 2 + ) + + le_part_128 = roi_align_self( + f4_1.clone(), le_location // 4, self.part_sizes[0] // 4 + ) + re_part_128 = roi_align_self( + f4_1.clone(), re_location // 4, self.part_sizes[1] // 4 + ) + mo_part_128 = roi_align_self( + f4_1.clone(), mo_location // 4, self.part_sizes[3] // 4 + ) + + le_part_64 = roi_align_self( + f6_1.clone(), le_location // 8, self.part_sizes[0] // 8 + ) + re_part_64 = roi_align_self( + f6_1.clone(), re_location // 8, self.part_sizes[1] // 8 + ) + mo_part_64 = roi_align_self( + f6_1.clone(), mo_location // 8, self.part_sizes[3] // 8 + ) + + le_256_q = self.LE_256_Q(le_part_256) + re_256_q = self.RE_256_Q(re_part_256) + mo_256_q = self.MO_256_Q(mo_part_256) + + le_128_q = self.LE_128_Q(le_part_128) + re_128_q = self.RE_128_Q(re_part_128) + mo_128_q = self.MO_128_Q(mo_part_128) + + le_64_q = self.LE_64_Q(le_part_64) + re_64_q = self.RE_64_Q(re_part_64) + mo_64_q = self.MO_64_Q(mo_part_64) + + return { + "f256": f2_1, + "f128": f4_1, + "f64": f6_1, + "le256": le_part_256, + "re256": re_part_256, + "mo256": mo_part_256, + "le128": le_part_128, + "re128": re_part_128, + "mo128": mo_part_128, + "le64": le_part_64, + "re64": re_part_64, + "mo64": mo_part_64, + "le_256_q": le_256_q, + "re_256_q": re_256_q, + "mo_256_q": mo_256_q, + "le_128_q": le_128_q, + "re_128_q": re_128_q, + "mo_128_q": mo_128_q, + "le_64_q": le_64_q, + "re_64_q": re_64_q, + "mo_64_q": mo_64_q, + } + + +class DMDNet(nn.Module): + def __init__(self, ngf=64, banks_num=128): + super().__init__() + self.part_sizes = np.array([80, 80, 50, 110]) # size for 512 + self.feature_sizes = np.array([256, 128, 64]) # size for 512 + + self.banks_num = banks_num + self.key_scale = 4 + + self.E_lq = FeatureExtractor(key_scale=self.key_scale) + self.E_hq = FeatureExtractor(key_scale=self.key_scale) + + self.LE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf) + self.RE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf) + self.MO_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf) + + self.LE_128_KV = KeyValue(ngf * 2, ngf * 2 // self.key_scale, ngf * 2) + self.RE_128_KV = KeyValue(ngf * 2, ngf * 2 // self.key_scale, ngf * 2) + self.MO_128_KV = KeyValue(ngf * 2, ngf * 2 // self.key_scale, ngf * 2) + + self.LE_64_KV = KeyValue(ngf * 4, ngf * 4 // self.key_scale, ngf * 4) + self.RE_64_KV = KeyValue(ngf * 4, ngf 
* 4 // self.key_scale, ngf * 4) + self.MO_64_KV = KeyValue(ngf * 4, ngf * 4 // self.key_scale, ngf * 4) + + self.LE_256_Attention = AttentionBlock(64) + self.RE_256_Attention = AttentionBlock(64) + self.MO_256_Attention = AttentionBlock(64) + + self.LE_128_Attention = AttentionBlock(128) + self.RE_128_Attention = AttentionBlock(128) + self.MO_128_Attention = AttentionBlock(128) + + self.LE_64_Attention = AttentionBlock(256) + self.RE_64_Attention = AttentionBlock(256) + self.MO_64_Attention = AttentionBlock(256) + + self.LE_256_Mask = MaskAttention(64) + self.RE_256_Mask = MaskAttention(64) + self.MO_256_Mask = MaskAttention(64) + + self.LE_128_Mask = MaskAttention(128) + self.RE_128_Mask = MaskAttention(128) + self.MO_128_Mask = MaskAttention(128) + + self.LE_64_Mask = MaskAttention(256) + self.RE_64_Mask = MaskAttention(256) + self.MO_64_Mask = MaskAttention(256) + + self.MSDilate = MSDilateBlock(ngf * 4, dilation=[4, 3, 2, 1]) + + self.up1 = StyledUpBlock(ngf * 4, ngf * 2, noise_inject=False) # + self.up2 = StyledUpBlock(ngf * 2, ngf, noise_inject=False) # + self.up3 = StyledUpBlock(ngf, ngf, noise_inject=False) # + self.up4 = nn.Sequential( + SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)), + nn.LeakyReLU(0.2), + UpResBlock(ngf), + UpResBlock(ngf), + SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)), + nn.Tanh(), + ) + + # define generic memory, revise register_buffer to register_parameter for backward update + self.register_buffer("le_256_mem_key", torch.randn(128, 16, 40, 40)) + self.register_buffer("re_256_mem_key", torch.randn(128, 16, 40, 40)) + self.register_buffer("mo_256_mem_key", torch.randn(128, 16, 55, 55)) + self.register_buffer("le_256_mem_value", torch.randn(128, 64, 40, 40)) + self.register_buffer("re_256_mem_value", torch.randn(128, 64, 40, 40)) + self.register_buffer("mo_256_mem_value", torch.randn(128, 64, 55, 55)) + + self.register_buffer("le_128_mem_key", torch.randn(128, 32, 20, 20)) + self.register_buffer("re_128_mem_key", torch.randn(128, 32, 20, 20)) + self.register_buffer("mo_128_mem_key", torch.randn(128, 32, 27, 27)) + self.register_buffer("le_128_mem_value", torch.randn(128, 128, 20, 20)) + self.register_buffer("re_128_mem_value", torch.randn(128, 128, 20, 20)) + self.register_buffer("mo_128_mem_value", torch.randn(128, 128, 27, 27)) + + self.register_buffer("le_64_mem_key", torch.randn(128, 64, 10, 10)) + self.register_buffer("re_64_mem_key", torch.randn(128, 64, 10, 10)) + self.register_buffer("mo_64_mem_key", torch.randn(128, 64, 13, 13)) + self.register_buffer("le_64_mem_value", torch.randn(128, 256, 10, 10)) + self.register_buffer("re_64_mem_value", torch.randn(128, 256, 10, 10)) + self.register_buffer("mo_64_mem_value", torch.randn(128, 256, 13, 13)) + + def readMem(self, k, v, q): + sim = F.conv2d(q, k) + score = F.softmax(sim / sqrt(sim.size(1)), dim=1) # B * S * 1 * 1 6*128 + sb, sn, sw, sh = score.size() + s_m = score.view(sb, -1).unsqueeze(1) # 2*1*M + vb, vn, vw, vh = v.size() + v_in = v.view(vb, -1).repeat(sb, 1, 1) # 2*M*(c*w*h) + mem_out = torch.bmm(s_m, v_in).squeeze(1).view(sb, vn, vw, vh) + max_inds = torch.argmax(score, dim=1).squeeze() + return mem_out, max_inds + + def memorize(self, img, locs): + fs = self.E_hq(img, locs) + LE256_key, LE256_value = self.LE_256_KV(fs["le256"]) + RE256_key, RE256_value = self.RE_256_KV(fs["re256"]) + MO256_key, MO256_value = self.MO_256_KV(fs["mo256"]) + + LE128_key, LE128_value = self.LE_128_KV(fs["le128"]) + RE128_key, RE128_value = self.RE_128_KV(fs["re128"]) + MO128_key, MO128_value = 
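+# readMem() is a soft dictionary lookup: each of the 128 stored keys acts as
+# a conv filter over the query, giving one similarity score per memory slot;
+# a scaled softmax turns the scores into attention weights, and the output is
+# the weighted sum of the stored values. Shape sketch for the 256-scale
+# left-eye bank (keys 128x16x40x40, values 128x64x40x40):
+q = torch.randn(1, 16, 40, 40)                   # query from the encoder
+sim = F.conv2d(q, torch.randn(128, 16, 40, 40))  # -> (1, 128, 1, 1)
+w = F.softmax(sim / sqrt(sim.size(1)), dim=1)    # attention over 128 slots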
self.MO_128_KV(fs["mo128"]) + + LE64_key, LE64_value = self.LE_64_KV(fs["le64"]) + RE64_key, RE64_value = self.RE_64_KV(fs["re64"]) + MO64_key, MO64_value = self.MO_64_KV(fs["mo64"]) + + Mem256 = { + "LE256Key": LE256_key, + "LE256Value": LE256_value, + "RE256Key": RE256_key, + "RE256Value": RE256_value, + "MO256Key": MO256_key, + "MO256Value": MO256_value, + } + Mem128 = { + "LE128Key": LE128_key, + "LE128Value": LE128_value, + "RE128Key": RE128_key, + "RE128Value": RE128_value, + "MO128Key": MO128_key, + "MO128Value": MO128_value, + } + Mem64 = { + "LE64Key": LE64_key, + "LE64Value": LE64_value, + "RE64Key": RE64_key, + "RE64Value": RE64_value, + "MO64Key": MO64_key, + "MO64Value": MO64_value, + } + + FS256 = {"LE256F": fs["le256"], "RE256F": fs["re256"], "MO256F": fs["mo256"]} + FS128 = {"LE128F": fs["le128"], "RE128F": fs["re128"], "MO128F": fs["mo128"]} + FS64 = {"LE64F": fs["le64"], "RE64F": fs["re64"], "MO64F": fs["mo64"]} + + return Mem256, Mem128, Mem64 + + def enhancer(self, fs_in, sp_256=None, sp_128=None, sp_64=None): + le_256_q = fs_in["le_256_q"] + re_256_q = fs_in["re_256_q"] + mo_256_q = fs_in["mo_256_q"] + + le_128_q = fs_in["le_128_q"] + re_128_q = fs_in["re_128_q"] + mo_128_q = fs_in["mo_128_q"] + + le_64_q = fs_in["le_64_q"] + re_64_q = fs_in["re_64_q"] + mo_64_q = fs_in["mo_64_q"] + + ####for 256 + le_256_mem_g, le_256_inds = self.readMem( + self.le_256_mem_key, self.le_256_mem_value, le_256_q + ) + re_256_mem_g, re_256_inds = self.readMem( + self.re_256_mem_key, self.re_256_mem_value, re_256_q + ) + mo_256_mem_g, mo_256_inds = self.readMem( + self.mo_256_mem_key, self.mo_256_mem_value, mo_256_q + ) + + le_128_mem_g, le_128_inds = self.readMem( + self.le_128_mem_key, self.le_128_mem_value, le_128_q + ) + re_128_mem_g, re_128_inds = self.readMem( + self.re_128_mem_key, self.re_128_mem_value, re_128_q + ) + mo_128_mem_g, mo_128_inds = self.readMem( + self.mo_128_mem_key, self.mo_128_mem_value, mo_128_q + ) + + le_64_mem_g, le_64_inds = self.readMem( + self.le_64_mem_key, self.le_64_mem_value, le_64_q + ) + re_64_mem_g, re_64_inds = self.readMem( + self.re_64_mem_key, self.re_64_mem_value, re_64_q + ) + mo_64_mem_g, mo_64_inds = self.readMem( + self.mo_64_mem_key, self.mo_64_mem_value, mo_64_q + ) + + if sp_256 is not None and sp_128 is not None and sp_64 is not None: + le_256_mem_s, _ = self.readMem( + sp_256["LE256Key"], sp_256["LE256Value"], le_256_q + ) + re_256_mem_s, _ = self.readMem( + sp_256["RE256Key"], sp_256["RE256Value"], re_256_q + ) + mo_256_mem_s, _ = self.readMem( + sp_256["MO256Key"], sp_256["MO256Value"], mo_256_q + ) + le_256_mask = self.LE_256_Mask(fs_in["le256"], le_256_mem_s, le_256_mem_g) + le_256_mem = le_256_mask * le_256_mem_s + (1 - le_256_mask) * le_256_mem_g + re_256_mask = self.RE_256_Mask(fs_in["re256"], re_256_mem_s, re_256_mem_g) + re_256_mem = re_256_mask * re_256_mem_s + (1 - re_256_mask) * re_256_mem_g + mo_256_mask = self.MO_256_Mask(fs_in["mo256"], mo_256_mem_s, mo_256_mem_g) + mo_256_mem = mo_256_mask * mo_256_mem_s + (1 - mo_256_mask) * mo_256_mem_g + + le_128_mem_s, _ = self.readMem( + sp_128["LE128Key"], sp_128["LE128Value"], le_128_q + ) + re_128_mem_s, _ = self.readMem( + sp_128["RE128Key"], sp_128["RE128Value"], re_128_q + ) + mo_128_mem_s, _ = self.readMem( + sp_128["MO128Key"], sp_128["MO128Value"], mo_128_q + ) + le_128_mask = self.LE_128_Mask(fs_in["le128"], le_128_mem_s, le_128_mem_g) + le_128_mem = le_128_mask * le_128_mem_s + (1 - le_128_mask) * le_128_mem_g + re_128_mask = self.RE_128_Mask(fs_in["re128"], re_128_mem_s, 
re_128_mem_g) + re_128_mem = re_128_mask * re_128_mem_s + (1 - re_128_mask) * re_128_mem_g + mo_128_mask = self.MO_128_Mask(fs_in["mo128"], mo_128_mem_s, mo_128_mem_g) + mo_128_mem = mo_128_mask * mo_128_mem_s + (1 - mo_128_mask) * mo_128_mem_g + + le_64_mem_s, _ = self.readMem(sp_64["LE64Key"], sp_64["LE64Value"], le_64_q) + re_64_mem_s, _ = self.readMem(sp_64["RE64Key"], sp_64["RE64Value"], re_64_q) + mo_64_mem_s, _ = self.readMem(sp_64["MO64Key"], sp_64["MO64Value"], mo_64_q) + le_64_mask = self.LE_64_Mask(fs_in["le64"], le_64_mem_s, le_64_mem_g) + le_64_mem = le_64_mask * le_64_mem_s + (1 - le_64_mask) * le_64_mem_g + re_64_mask = self.RE_64_Mask(fs_in["re64"], re_64_mem_s, re_64_mem_g) + re_64_mem = re_64_mask * re_64_mem_s + (1 - re_64_mask) * re_64_mem_g + mo_64_mask = self.MO_64_Mask(fs_in["mo64"], mo_64_mem_s, mo_64_mem_g) + mo_64_mem = mo_64_mask * mo_64_mem_s + (1 - mo_64_mask) * mo_64_mem_g + else: + le_256_mem = le_256_mem_g + re_256_mem = re_256_mem_g + mo_256_mem = mo_256_mem_g + le_128_mem = le_128_mem_g + re_128_mem = re_128_mem_g + mo_128_mem = mo_128_mem_g + le_64_mem = le_64_mem_g + re_64_mem = re_64_mem_g + mo_64_mem = mo_64_mem_g + + le_256_mem_norm = adaptive_instance_normalization_4D(le_256_mem, fs_in["le256"]) + re_256_mem_norm = adaptive_instance_normalization_4D(re_256_mem, fs_in["re256"]) + mo_256_mem_norm = adaptive_instance_normalization_4D(mo_256_mem, fs_in["mo256"]) + + ####for 128 + le_128_mem_norm = adaptive_instance_normalization_4D(le_128_mem, fs_in["le128"]) + re_128_mem_norm = adaptive_instance_normalization_4D(re_128_mem, fs_in["re128"]) + mo_128_mem_norm = adaptive_instance_normalization_4D(mo_128_mem, fs_in["mo128"]) + + ####for 64 + le_64_mem_norm = adaptive_instance_normalization_4D(le_64_mem, fs_in["le64"]) + re_64_mem_norm = adaptive_instance_normalization_4D(re_64_mem, fs_in["re64"]) + mo_64_mem_norm = adaptive_instance_normalization_4D(mo_64_mem, fs_in["mo64"]) + + EnMem256 = { + "LE256Norm": le_256_mem_norm, + "RE256Norm": re_256_mem_norm, + "MO256Norm": mo_256_mem_norm, + } + EnMem128 = { + "LE128Norm": le_128_mem_norm, + "RE128Norm": re_128_mem_norm, + "MO128Norm": mo_128_mem_norm, + } + EnMem64 = { + "LE64Norm": le_64_mem_norm, + "RE64Norm": re_64_mem_norm, + "MO64Norm": mo_64_mem_norm, + } + Ind256 = {"LE": le_256_inds, "RE": re_256_inds, "MO": mo_256_inds} + Ind128 = {"LE": le_128_inds, "RE": re_128_inds, "MO": mo_128_inds} + Ind64 = {"LE": le_64_inds, "RE": re_64_inds, "MO": mo_64_inds} + return EnMem256, EnMem128, EnMem64, Ind256, Ind128, Ind64 + + def reconstruct(self, fs_in, locs, memstar): + le_256_mem_norm, re_256_mem_norm, mo_256_mem_norm = ( + memstar[0]["LE256Norm"], + memstar[0]["RE256Norm"], + memstar[0]["MO256Norm"], + ) + le_128_mem_norm, re_128_mem_norm, mo_128_mem_norm = ( + memstar[1]["LE128Norm"], + memstar[1]["RE128Norm"], + memstar[1]["MO128Norm"], + ) + le_64_mem_norm, re_64_mem_norm, mo_64_mem_norm = ( + memstar[2]["LE64Norm"], + memstar[2]["RE64Norm"], + memstar[2]["MO64Norm"], + ) + + le_256_final = ( + self.LE_256_Attention(le_256_mem_norm - fs_in["le256"]) * le_256_mem_norm + + fs_in["le256"] + ) + re_256_final = ( + self.RE_256_Attention(re_256_mem_norm - fs_in["re256"]) * re_256_mem_norm + + fs_in["re256"] + ) + mo_256_final = ( + self.MO_256_Attention(mo_256_mem_norm - fs_in["mo256"]) * mo_256_mem_norm + + fs_in["mo256"] + ) + + le_128_final = ( + self.LE_128_Attention(le_128_mem_norm - fs_in["le128"]) * le_128_mem_norm + + fs_in["le128"] + ) + re_128_final = ( + self.RE_128_Attention(re_128_mem_norm - 
fs_in["re128"]) * re_128_mem_norm + + fs_in["re128"] + ) + mo_128_final = ( + self.MO_128_Attention(mo_128_mem_norm - fs_in["mo128"]) * mo_128_mem_norm + + fs_in["mo128"] + ) + + le_64_final = ( + self.LE_64_Attention(le_64_mem_norm - fs_in["le64"]) * le_64_mem_norm + + fs_in["le64"] + ) + re_64_final = ( + self.RE_64_Attention(re_64_mem_norm - fs_in["re64"]) * re_64_mem_norm + + fs_in["re64"] + ) + mo_64_final = ( + self.MO_64_Attention(mo_64_mem_norm - fs_in["mo64"]) * mo_64_mem_norm + + fs_in["mo64"] + ) + + le_location = locs[:, 0, :] + re_location = locs[:, 1, :] + mo_location = locs[:, 3, :] + + # Somehow with latest Torch it doesn't like numpy wrappers anymore + + # le_location = le_location.cpu().int().numpy() + # re_location = re_location.cpu().int().numpy() + # mo_location = mo_location.cpu().int().numpy() + le_location = le_location.cpu().int() + re_location = re_location.cpu().int() + mo_location = mo_location.cpu().int() + + up_in_256 = fs_in["f256"].clone() # * 0 + up_in_128 = fs_in["f128"].clone() # * 0 + up_in_64 = fs_in["f64"].clone() # * 0 + + for i in range(fs_in["f256"].size(0)): + up_in_256[ + i : i + 1, + :, + le_location[i, 1] // 2 : le_location[i, 3] // 2, + le_location[i, 0] // 2 : le_location[i, 2] // 2, + ] = F.interpolate( + le_256_final[i : i + 1, :, :, :].clone(), + ( + le_location[i, 3] // 2 - le_location[i, 1] // 2, + le_location[i, 2] // 2 - le_location[i, 0] // 2, + ), + mode="bilinear", + align_corners=False, + ) + up_in_256[ + i : i + 1, + :, + re_location[i, 1] // 2 : re_location[i, 3] // 2, + re_location[i, 0] // 2 : re_location[i, 2] // 2, + ] = F.interpolate( + re_256_final[i : i + 1, :, :, :].clone(), + ( + re_location[i, 3] // 2 - re_location[i, 1] // 2, + re_location[i, 2] // 2 - re_location[i, 0] // 2, + ), + mode="bilinear", + align_corners=False, + ) + up_in_256[ + i : i + 1, + :, + mo_location[i, 1] // 2 : mo_location[i, 3] // 2, + mo_location[i, 0] // 2 : mo_location[i, 2] // 2, + ] = F.interpolate( + mo_256_final[i : i + 1, :, :, :].clone(), + ( + mo_location[i, 3] // 2 - mo_location[i, 1] // 2, + mo_location[i, 2] // 2 - mo_location[i, 0] // 2, + ), + mode="bilinear", + align_corners=False, + ) + + up_in_128[ + i : i + 1, + :, + le_location[i, 1] // 4 : le_location[i, 3] // 4, + le_location[i, 0] // 4 : le_location[i, 2] // 4, + ] = F.interpolate( + le_128_final[i : i + 1, :, :, :].clone(), + ( + le_location[i, 3] // 4 - le_location[i, 1] // 4, + le_location[i, 2] // 4 - le_location[i, 0] // 4, + ), + mode="bilinear", + align_corners=False, + ) + up_in_128[ + i : i + 1, + :, + re_location[i, 1] // 4 : re_location[i, 3] // 4, + re_location[i, 0] // 4 : re_location[i, 2] // 4, + ] = F.interpolate( + re_128_final[i : i + 1, :, :, :].clone(), + ( + re_location[i, 3] // 4 - re_location[i, 1] // 4, + re_location[i, 2] // 4 - re_location[i, 0] // 4, + ), + mode="bilinear", + align_corners=False, + ) + up_in_128[ + i : i + 1, + :, + mo_location[i, 1] // 4 : mo_location[i, 3] // 4, + mo_location[i, 0] // 4 : mo_location[i, 2] // 4, + ] = F.interpolate( + mo_128_final[i : i + 1, :, :, :].clone(), + ( + mo_location[i, 3] // 4 - mo_location[i, 1] // 4, + mo_location[i, 2] // 4 - mo_location[i, 0] // 4, + ), + mode="bilinear", + align_corners=False, + ) + + up_in_64[ + i : i + 1, + :, + le_location[i, 1] // 8 : le_location[i, 3] // 8, + le_location[i, 0] // 8 : le_location[i, 2] // 8, + ] = F.interpolate( + le_64_final[i : i + 1, :, :, :].clone(), + ( + le_location[i, 3] // 8 - le_location[i, 1] // 8, + le_location[i, 2] // 8 - le_location[i, 0] // 8, + 
),
+                mode="bilinear",
+                align_corners=False,
+            )
+            up_in_64[
+                i : i + 1,
+                :,
+                re_location[i, 1] // 8 : re_location[i, 3] // 8,
+                re_location[i, 0] // 8 : re_location[i, 2] // 8,
+            ] = F.interpolate(
+                re_64_final[i : i + 1, :, :, :].clone(),
+                (
+                    re_location[i, 3] // 8 - re_location[i, 1] // 8,
+                    re_location[i, 2] // 8 - re_location[i, 0] // 8,
+                ),
+                mode="bilinear",
+                align_corners=False,
+            )
+            up_in_64[
+                i : i + 1,
+                :,
+                mo_location[i, 1] // 8 : mo_location[i, 3] // 8,
+                mo_location[i, 0] // 8 : mo_location[i, 2] // 8,
+            ] = F.interpolate(
+                mo_64_final[i : i + 1, :, :, :].clone(),
+                (
+                    mo_location[i, 3] // 8 - mo_location[i, 1] // 8,
+                    mo_location[i, 2] // 8 - mo_location[i, 0] // 8,
+                ),
+                mode="bilinear",
+                align_corners=False,
+            )
+
+        ms_in_64 = self.MSDilate(fs_in["f64"].clone())
+        fea_up1 = self.up1(ms_in_64, up_in_64)
+        fea_up2 = self.up2(fea_up1, up_in_128)
+        fea_up3 = self.up3(fea_up2, up_in_256)
+        output = self.up4(fea_up3)
+        return output
+
+    def generate_specific_dictionary(self, sp_imgs=None, sp_locs=None):
+        return self.memorize(sp_imgs, sp_locs)
+
+    def forward(self, lq=None, loc=None, sp_256=None, sp_128=None, sp_64=None):
+        # Encode the low-quality input. Let failures propagate: the previous
+        # try/except only printed the error, leaving fs_in undefined and
+        # crashing with a NameError on the next line anyway.
+        fs_in = self.E_lq(lq, loc)
+
+        GeMemNorm256, GeMemNorm128, GeMemNorm64, Ind256, Ind128, Ind64 = self.enhancer(
+            fs_in
+        )
+        GeOut = self.reconstruct(
+            fs_in, loc, memstar=[GeMemNorm256, GeMemNorm128, GeMemNorm64]
+        )
+        if sp_256 is not None and sp_128 is not None and sp_64 is not None:
+            GSMemNorm256, GSMemNorm128, GSMemNorm64, _, _, _ = self.enhancer(
+                fs_in, sp_256, sp_128, sp_64
+            )
+            GSOut = self.reconstruct(
+                fs_in, loc, memstar=[GSMemNorm256, GSMemNorm128, GSMemNorm64]
+            )
+        else:
+            GSOut = None
+        return GeOut, GSOut
+
+
+class UpResBlock(nn.Module):
+    def __init__(self, dim, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d):
+        super(UpResBlock, self).__init__()
+        self.Model = nn.Sequential(
+            SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
+        )
+
+    def forward(self, x):
+        out = x + self.Model(x)
+        return out
diff --git a/processors/Enhance_GFPGAN.py b/processors/Enhance_GFPGAN.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e6f4bab8b692eef31dc62b91b04d665f3d019e6
--- /dev/null
+++ b/processors/Enhance_GFPGAN.py
@@ -0,0 +1,65 @@
+from typing import Any, List, Callable
+import cv2
+import numpy as np
+import onnxruntime
+import roop.globals
+
+from roop.typing import Face, Frame, FaceSet
+from roop.utilities import resolve_relative_path
+
+
+class Enhance_GFPGAN:
+    plugin_options: dict = None
+
+    model_gfpgan = None
+    name = None
+    devicename = None
+
+    processorname = "gfpgan"
+    type = "enhance"
+
+    def Initialize(self, plugin_options: dict):
+        if self.plugin_options is not None:
+            if self.plugin_options["devicename"] != plugin_options["devicename"]:
+                self.Release()
+
+        self.plugin_options = plugin_options
+        if self.model_gfpgan is None:
+            model_path = resolve_relative_path("../models/GFPGANv1.4.onnx")
+            self.model_gfpgan = onnxruntime.InferenceSession(
+                model_path, None, providers=roop.globals.execution_providers
+            )
+            # replace Mac mps with cpu for the moment
+            self.devicename = self.plugin_options["devicename"].replace("mps", "cpu")
+
+        self.name = self.model_gfpgan.get_inputs()[0].name
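+
+    # The pre/post-processing below implies the ONNX export's conventions:
+    # a 512x512 RGB crop normalised to [-1, 1], with the restored face read
+    # back from the raw output node "1288" of the exported graph.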
+    def Run(
+        self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame
+    ) -> Frame:
+        # preprocess
+        input_size = temp_frame.shape[1]
+        # interpolation is a keyword argument of cv2.resize; passed
+        # positionally it would land in the unused dst parameter instead
+        temp_frame = cv2.resize(temp_frame, (512, 512), interpolation=cv2.INTER_CUBIC)
+
+        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
+        temp_frame = temp_frame.astype("float32") / 255.0
+        temp_frame = (temp_frame - 0.5) / 0.5
+        temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)
+
+        io_binding = self.model_gfpgan.io_binding()
+        io_binding.bind_cpu_input("input", temp_frame)
+        io_binding.bind_output("1288", self.devicename)
+        self.model_gfpgan.run_with_iobinding(io_binding)
+        ort_outs = io_binding.copy_outputs_to_cpu()
+        result = ort_outs[0][0]
+
+        # post-process
+        result = np.clip(result, -1, 1)
+        result = (result + 1) / 2
+        result = result.transpose(1, 2, 0) * 255.0
+        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
+        scale_factor = int(result.shape[1] / input_size)
+        return result.astype(np.uint8), scale_factor
+
+    def Release(self):
+        self.model_gfpgan = None
diff --git a/processors/Enhance_GPEN.py b/processors/Enhance_GPEN.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2853a851c8b6c31f20f349585bd23ba97e30198
--- /dev/null
+++ b/processors/Enhance_GPEN.py
@@ -0,0 +1,65 @@
+from typing import Any, List, Callable
+import cv2
+import numpy as np
+import onnxruntime
+import roop.globals
+
+from roop.typing import Face, Frame, FaceSet
+from roop.utilities import resolve_relative_path
+
+
+class Enhance_GPEN:
+    plugin_options: dict = None
+
+    model_gpen = None
+    name = None
+    devicename = None
+
+    processorname = "gpen"
+    type = "enhance"
+
+    def Initialize(self, plugin_options: dict):
+        if self.plugin_options is not None:
+            if self.plugin_options["devicename"] != plugin_options["devicename"]:
+                self.Release()
+
+        self.plugin_options = plugin_options
+        if self.model_gpen is None:
+            model_path = resolve_relative_path("../models/GPEN-BFR-512.onnx")
+            self.model_gpen = onnxruntime.InferenceSession(
+                model_path, None, providers=roop.globals.execution_providers
+            )
+            # replace Mac mps with cpu for the moment
+            self.devicename = self.plugin_options["devicename"].replace("mps", "cpu")
+
+        self.name = self.model_gpen.get_inputs()[0].name
+
+    def Run(
+        self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame
+    ) -> Frame:
+        # preprocess
+        input_size = temp_frame.shape[1]
+        temp_frame = cv2.resize(temp_frame, (512, 512), interpolation=cv2.INTER_CUBIC)
+
+        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
+        temp_frame = temp_frame.astype("float32") / 255.0
+        temp_frame = (temp_frame - 0.5) / 0.5
+        temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)
+
+        io_binding = self.model_gpen.io_binding()
+        io_binding.bind_cpu_input("input", temp_frame)
+        io_binding.bind_output("output", self.devicename)
+        self.model_gpen.run_with_iobinding(io_binding)
+        ort_outs = io_binding.copy_outputs_to_cpu()
+        result = ort_outs[0][0]
+
+        # post-process
+        result = np.clip(result, -1, 1)
+        result = (result + 1) / 2
+        result = result.transpose(1, 2, 0) * 255.0
+        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
+        scale_factor = int(result.shape[1] / input_size)
+        return result.astype(np.uint8), scale_factor
+
+    def Release(self):
+        self.model_gpen = None
diff --git a/processors/Enhance_RestoreFormerPPlus.py b/processors/Enhance_RestoreFormerPPlus.py
new file mode 100644
index 0000000000000000000000000000000000000000..0478c96f1f32326e2b5c3c87ada83e619f881502
--- /dev/null
+++ b/processors/Enhance_RestoreFormerPPlus.py
@@ -0,0 +1,68 @@
+from typing import Any, List, Callable
+import cv2
+import numpy as np
+import onnxruntime
+import roop.globals
+
+from roop.typing import Face, Frame, FaceSet
+from roop.utilities import resolve_relative_path
+
+
+class Enhance_RestoreFormerPPlus:
+    plugin_options: dict = None
+    model_restoreformerpplus = None
+    devicename = None
+    name = None
+
+    processorname = "restoreformer++"
+    type = "enhance"
+
+    def Initialize(self, plugin_options: dict):
+        if self.plugin_options is not None:
+            if self.plugin_options["devicename"] != plugin_options["devicename"]:
+                self.Release()
+
+        self.plugin_options = plugin_options
+        if self.model_restoreformerpplus is None:
+            # replace Mac mps with cpu for the moment
+            self.devicename = self.plugin_options["devicename"].replace("mps", "cpu")
+            model_path = resolve_relative_path("../models/restoreformer_plus_plus.onnx")
+            self.model_restoreformerpplus = onnxruntime.InferenceSession(
+                model_path, None, providers=roop.globals.execution_providers
+            )
+            self.model_inputs = self.model_restoreformerpplus.get_inputs()
+            model_outputs = self.model_restoreformerpplus.get_outputs()
+            self.io_binding = self.model_restoreformerpplus.io_binding()
+            self.io_binding.bind_output(model_outputs[0].name, self.devicename)
+
+    def Run(
+        self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame
+    ) -> Frame:
+        # preprocess
+        input_size = temp_frame.shape[1]
+        temp_frame = cv2.resize(temp_frame, (512, 512), interpolation=cv2.INTER_CUBIC)
+        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
+        temp_frame = temp_frame.astype("float32") / 255.0
+        temp_frame = (temp_frame - 0.5) / 0.5
+        temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)
+
+        self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame)
+        self.model_restoreformerpplus.run_with_iobinding(self.io_binding)
+        ort_outs = self.io_binding.copy_outputs_to_cpu()
+        result = ort_outs[0][0]
+        del ort_outs
+
+        result = np.clip(result, -1, 1)
+        result = (result + 1) / 2
+        result = result.transpose(1, 2, 0) * 255.0
+        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
+        scale_factor = int(result.shape[1] / input_size)
+        return result.astype(np.uint8), scale_factor
+
+    def Release(self):
+        del self.model_restoreformerpplus
+        self.model_restoreformerpplus = None
+        del self.io_binding
+        self.io_binding = None
diff --git a/processors/FaceSwapInsightFace.py b/processors/FaceSwapInsightFace.py
new file mode 100644
index 0000000000000000000000000000000000000000..4232c9277b272dad76d996950e197dbd3879858c
--- /dev/null
+++ b/processors/FaceSwapInsightFace.py
@@ -0,0 +1,56 @@
+import roop.globals
+import numpy as np
+import onnx
+import onnxruntime
+
+from roop.typing import Face, Frame
+from roop.utilities import resolve_relative_path
+
+
+class FaceSwapInsightFace:
+    plugin_options: dict = None
+    model_swap_insightface = None
+
+    processorname = "faceswap"
+    type = "swap"
+
+    def Initialize(self, plugin_options: dict):
+        if self.plugin_options is not None:
+            if (
+                self.plugin_options["devicename"] != plugin_options["devicename"]
+                or self.plugin_options["modelname"] != plugin_options["modelname"]
+            ):
+                self.Release()
+
+        self.plugin_options = plugin_options
+        if self.model_swap_insightface is None:
+            model_path = resolve_relative_path(
+                "../models/" + self.plugin_options["modelname"]
+            )
+            graph = onnx.load(model_path).graph
+            self.emap = onnx.numpy_helper.to_array(graph.initializer[-1])
+            self.devicename = self.plugin_options["devicename"].replace("mps", "cpu")
+            self.input_mean = 0.0
+            self.input_std = 255.0
+            # cuda_options = {"arena_extend_strategy": "kSameAsRequested", 'cudnn_conv_algo_search': 'DEFAULT'}
+            sess_options = onnxruntime.SessionOptions()
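+            # Disabling the CPU memory arena presumably keeps onnxruntime from
+            # holding on to large host-side allocations between swapped frames;
+            # inference results are unaffected. +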
sess_options.enable_cpu_mem_arena = False + self.model_swap_insightface = onnxruntime.InferenceSession( + model_path, sess_options, providers=roop.globals.execution_providers + ) + + def Run(self, source_face: Face, target_face: Face, temp_frame: Frame) -> Frame: + latent = source_face.normed_embedding.reshape((1, -1)) + latent = np.dot(latent, self.emap) + latent /= np.linalg.norm(latent) + io_binding = self.model_swap_insightface.io_binding() + io_binding.bind_cpu_input("target", temp_frame) + io_binding.bind_cpu_input("source", latent) + io_binding.bind_output("output", self.devicename) + self.model_swap_insightface.run_with_iobinding(io_binding) + ort_outs = io_binding.copy_outputs_to_cpu()[0] + return ort_outs[0] + + def Release(self): + del self.model_swap_insightface + self.model_swap_insightface = None diff --git a/processors/Frame_Colorizer.py b/processors/Frame_Colorizer.py new file mode 100644 index 0000000000000000000000000000000000000000..20848b7dc402728888408e4263a99639b87695ab --- /dev/null +++ b/processors/Frame_Colorizer.py @@ -0,0 +1,83 @@ +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.utilities import resolve_relative_path +from roop.typing import Frame + + +class Frame_Colorizer: + plugin_options: dict = None + model_colorizer = None + devicename = None + prev_type = None + + processorname = "deoldify" + type = "frame_colorizer" + + def Initialize(self, plugin_options: dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if ( + self.prev_type is not None + and self.prev_type != self.plugin_options["subtype"] + ): + self.Release() + self.prev_type = self.plugin_options["subtype"] + if self.model_colorizer is None: + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace("mps", "cpu") + if self.prev_type == "deoldify_artistic": + model_path = resolve_relative_path( + "../models/Frame/deoldify_artistic.onnx" + ) + elif self.prev_type == "deoldify_stable": + model_path = resolve_relative_path( + "../models/Frame/deoldify_stable.onnx" + ) + + onnxruntime.set_default_logger_severity(3) + self.model_colorizer = onnxruntime.InferenceSession( + model_path, None, providers=roop.globals.execution_providers + ) + self.model_inputs = self.model_colorizer.get_inputs() + model_outputs = self.model_colorizer.get_outputs() + self.io_binding = self.model_colorizer.io_binding() + self.io_binding.bind_output(model_outputs[0].name, self.devicename) + + def Run(self, input_frame: Frame) -> Frame: + temp_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2GRAY) + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) + temp_frame = cv2.resize(temp_frame, (256, 256)) + temp_frame = temp_frame.transpose((2, 0, 1)) + temp_frame = np.expand_dims(temp_frame, axis=0).astype(np.float32) + self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame) + self.model_colorizer.run_with_iobinding(self.io_binding) + ort_outs = self.io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + del ort_outs + colorized_frame = result.transpose(1, 2, 0) + colorized_frame = cv2.resize( + colorized_frame, (input_frame.shape[1], input_frame.shape[0]) + ) + temp_blue_channel, _, _ = cv2.split(input_frame) + colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_BGR2RGB).astype( + np.uint8 + ) + colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_BGR2LAB) + _, color_green_channel, 
color_red_channel = cv2.split(colorized_frame) + colorized_frame = cv2.merge( + (temp_blue_channel, color_green_channel, color_red_channel) + ) + colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_LAB2BGR) + return colorized_frame.astype(np.uint8) + + def Release(self): + del self.model_colorizer + self.model_colorizer = None + del self.io_binding + self.io_binding = None diff --git a/processors/Frame_Filter.py b/processors/Frame_Filter.py new file mode 100644 index 0000000000000000000000000000000000000000..1e507a117a28e37d2c7f1f14c66e39f4ab4c378a --- /dev/null +++ b/processors/Frame_Filter.py @@ -0,0 +1,118 @@ +import cv2 +import numpy as np + +from roop.typing import Frame + + +class Frame_Filter: + processorname = "generic_filter" + type = "frame_processor" + + plugin_options: dict = None + + c64_palette = np.array( + [ + [0, 0, 0], + [255, 255, 255], + [0x81, 0x33, 0x38], + [0x75, 0xCE, 0xC8], + [0x8E, 0x3C, 0x97], + [0x56, 0xAC, 0x4D], + [0x2E, 0x2C, 0x9B], + [0xED, 0xF1, 0x71], + [0x8E, 0x50, 0x29], + [0x55, 0x38, 0x00], + [0xC4, 0x6C, 0x71], + [0x4A, 0x4A, 0x4A], + [0x7B, 0x7B, 0x7B], + [0xA9, 0xFF, 0x9F], + [0x70, 0x6D, 0xEB], + [0xB2, 0xB2, 0xB2], + ] + ) + + def RenderC64Screen(self, image): + # Simply round the color values to the nearest color in the palette + image = cv2.resize(image, (320, 200)) + palette = self.c64_palette / 255.0 # Normalize palette + img_normalized = image / 255.0 # Normalize image + + # Calculate the index in the palette that is closest to each pixel in the image + indices = np.sqrt( + ((img_normalized[:, :, None, :] - palette[None, None, :, :]) ** 2).sum( + axis=3 + ) + ).argmin(axis=2) + # Map the image to the palette colors + mapped_image = palette[indices] + return (mapped_image * 255).astype(np.uint8) # Denormalize and return the image + + def RenderDetailEnhance(self, image): + return cv2.detailEnhance(image) + + def RenderStylize(self, image): + return cv2.stylization(image) + + def RenderPencilSketch(self, image): + imgray, imout = cv2.pencilSketch( + image, sigma_s=60, sigma_r=0.07, shade_factor=0.05 + ) + return imout + + def RenderCartoon(self, image): + numDownSamples = 2 # number of downscaling steps + numBilateralFilters = 7 # number of bilateral filtering steps + + img_color = image + for _ in range(numDownSamples): + img_color = cv2.pyrDown(img_color) + for _ in range(numBilateralFilters): + img_color = cv2.bilateralFilter(img_color, 9, 9, 7) + for _ in range(numDownSamples): + img_color = cv2.pyrUp(img_color) + img_gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + img_blur = cv2.medianBlur(img_gray, 7) + img_edge = cv2.adaptiveThreshold( + img_blur, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 9, 2 + ) + img_edge = cv2.cvtColor(img_edge, cv2.COLOR_GRAY2RGB) + if img_color.shape != image.shape: + img_color = cv2.resize( + img_color, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_LINEAR, + ) + if img_color.shape != img_edge.shape: + img_edge = cv2.resize( + img_edge, + (img_color.shape[1], img_color.shape[0]), + interpolation=cv2.INTER_LINEAR, + ) + return cv2.bitwise_and(img_color, img_edge) + + def Initialize(self, plugin_options: dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + self.plugin_options = plugin_options + + def Run(self, temp_frame: Frame) -> Frame: + subtype = self.plugin_options["subtype"] + if subtype == "stylize": + return self.RenderStylize(temp_frame).astype(np.uint8) + if subtype == "detailenhance": + 
return self.RenderDetailEnhance(temp_frame).astype(np.uint8) + if subtype == "pencil": + return self.RenderPencilSketch(temp_frame).astype(np.uint8) + if subtype == "cartoon": + return self.RenderCartoon(temp_frame).astype(np.uint8) + if subtype == "C64": + return self.RenderC64Screen(temp_frame).astype(np.uint8) + + def Release(self): + pass + + def getProcessedResolution(self, width, height): + if self.plugin_options["subtype"] == "C64": + return (320, 200) + return None diff --git a/processors/Frame_Masking.py b/processors/Frame_Masking.py new file mode 100644 index 0000000000000000000000000000000000000000..35bb3b9fe50bf9fdfe84b4ba14f5561ff21b3b40 --- /dev/null +++ b/processors/Frame_Masking.py @@ -0,0 +1,74 @@ +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.utilities import resolve_relative_path +from roop.typing import Frame + + +class Frame_Masking: + plugin_options: dict = None + model_masking = None + devicename = None + name = None + + processorname = "removebg" + type = "frame_masking" + + def Initialize(self, plugin_options: dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_masking is None: + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"] + self.devicename = self.devicename.replace("mps", "cpu") + model_path = resolve_relative_path("../models/Frame/isnet-general-use.onnx") + self.model_masking = onnxruntime.InferenceSession( + model_path, None, providers=roop.globals.execution_providers + ) + self.model_inputs = self.model_masking.get_inputs() + model_outputs = self.model_masking.get_outputs() + self.io_binding = self.model_masking.io_binding() + self.io_binding.bind_output(model_outputs[0].name, self.devicename) + + def Run(self, temp_frame: Frame) -> Frame: + # Pre process:Resize, BGR->RGB, float32 cast + input_image = cv2.resize(temp_frame, (1024, 1024)) + input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) + mean = [0.5, 0.5, 0.5] + std = [1.0, 1.0, 1.0] + input_image = (input_image / 255.0 - mean) / std + input_image = input_image.transpose(2, 0, 1) + input_image = np.expand_dims(input_image, axis=0) + input_image = input_image.astype("float32") + + self.io_binding.bind_cpu_input(self.model_inputs[0].name, input_image) + self.model_masking.run_with_iobinding(self.io_binding) + ort_outs = self.io_binding.copy_outputs_to_cpu() + result = ort_outs[0][0] + del ort_outs + # Post process:squeeze, Sigmoid, Normarize, uint8 cast + mask = np.squeeze(result[0]) + min_value = np.min(mask) + max_value = np.max(mask) + mask = (mask - min_value) / (max_value - min_value) + # mask = np.where(mask < score_th, 0, 1) + # mask *= 255 + mask = cv2.resize( + mask, + (temp_frame.shape[1], temp_frame.shape[0]), + interpolation=cv2.INTER_LINEAR, + ) + mask = np.reshape(mask, [mask.shape[0], mask.shape[1], 1]) + result = mask * temp_frame.astype(np.float32) + return result.astype(np.uint8) + + def Release(self): + del self.model_masking + self.model_masking = None + del self.io_binding + self.io_binding = None diff --git a/processors/Frame_Upscale.py b/processors/Frame_Upscale.py new file mode 100644 index 0000000000000000000000000000000000000000..f89430b72a5fa8bcfbb90c7a5716147406176e21 --- /dev/null +++ b/processors/Frame_Upscale.py @@ -0,0 +1,151 @@ +import cv2 +import numpy as np +import onnxruntime +import roop.globals + +from roop.utilities import 
resolve_relative_path, conditional_thread_semaphore +from roop.typing import Frame + + +class Frame_Upscale: + plugin_options: dict = None + model_upscale = None + devicename = None + prev_type = None + + processorname = "upscale" + type = "frame_enhancer" + + def Initialize(self, plugin_options: dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if ( + self.prev_type is not None + and self.prev_type != self.plugin_options["subtype"] + ): + self.Release() + self.prev_type = self.plugin_options["subtype"] + if self.model_upscale is None: + # replace Mac mps with cpu for the moment + self.devicename = self.plugin_options["devicename"].replace("mps", "cpu") + if self.prev_type == "esrganx4": + model_path = resolve_relative_path( + "../models/Frame/real_esrgan_x4.onnx" + ) + self.scale = 4 + elif self.prev_type == "esrganx2": + model_path = resolve_relative_path( + "../models/Frame/real_esrgan_x2.onnx" + ) + self.scale = 2 + elif self.prev_type == "lsdirx4": + model_path = resolve_relative_path("../models/Frame/lsdir_x4.onnx") + self.scale = 4 + onnxruntime.set_default_logger_severity(3) + self.model_upscale = onnxruntime.InferenceSession( + model_path, None, providers=roop.globals.execution_providers + ) + self.model_inputs = self.model_upscale.get_inputs() + model_outputs = self.model_upscale.get_outputs() + self.io_binding = self.model_upscale.io_binding() + self.io_binding.bind_output(model_outputs[0].name, self.devicename) + + def getProcessedResolution(self, width, height): + return (width * self.scale, height * self.scale) + + # borrowed from facefusion -> https://github.com/facefusion/facefusion + def prepare_tile_frame(self, tile_frame: Frame) -> Frame: + tile_frame = np.expand_dims(tile_frame[:, :, ::-1], axis=0) + tile_frame = tile_frame.transpose(0, 3, 1, 2) + tile_frame = tile_frame.astype(np.float32) / 255 + return tile_frame + + def normalize_tile_frame(self, tile_frame: Frame) -> Frame: + tile_frame = tile_frame.transpose(0, 2, 3, 1).squeeze(0) * 255 + tile_frame = tile_frame.clip(0, 255).astype(np.uint8)[:, :, ::-1] + return tile_frame + + def create_tile_frames(self, input_frame: Frame, size): + input_frame = np.pad( + input_frame, ((size[1], size[1]), (size[1], size[1]), (0, 0)) + ) + tile_width = size[0] - 2 * size[2] + pad_size_bottom = size[2] + tile_width - input_frame.shape[0] % tile_width + pad_size_right = size[2] + tile_width - input_frame.shape[1] % tile_width + pad_vision_frame = np.pad( + input_frame, ((size[2], pad_size_bottom), (size[2], pad_size_right), (0, 0)) + ) + pad_height, pad_width = pad_vision_frame.shape[:2] + row_range = range(size[2], pad_height - size[2], tile_width) + col_range = range(size[2], pad_width - size[2], tile_width) + tile_frames = [] + + for row_frame in row_range: + top = row_frame - size[2] + bottom = row_frame + size[2] + tile_width + for column_vision_frame in col_range: + left = column_vision_frame - size[2] + right = column_vision_frame + size[2] + tile_width + tile_frames.append(pad_vision_frame[top:bottom, left:right, :]) + return tile_frames, pad_width, pad_height + + def merge_tile_frames( + self, + tile_frames, + temp_width: int, + temp_height: int, + pad_width: int, + pad_height: int, + size, + ) -> Frame: + merge_frame = np.zeros((pad_height, pad_width, 3)).astype(np.uint8) + tile_width = tile_frames[0].shape[1] - 2 * size[2] + tiles_per_row = min(pad_width // tile_width, len(tile_frames)) + + for 
index, tile_frame in enumerate(tile_frames): + tile_frame = tile_frame[size[2] : -size[2], size[2] : -size[2]] + row_index = index // tiles_per_row + col_index = index % tiles_per_row + top = row_index * tile_frame.shape[0] + bottom = top + tile_frame.shape[0] + left = col_index * tile_frame.shape[1] + right = left + tile_frame.shape[1] + merge_frame[top:bottom, left:right, :] = tile_frame + merge_frame = merge_frame[ + size[1] : size[1] + temp_height, size[1] : size[1] + temp_width, : + ] + return merge_frame + + def Run(self, temp_frame: Frame) -> Frame: + size = (128, 8, 2) + temp_height, temp_width = temp_frame.shape[:2] + upscale_tile_frames, pad_width, pad_height = self.create_tile_frames( + temp_frame, size + ) + + for index, tile_frame in enumerate(upscale_tile_frames): + tile_frame = self.prepare_tile_frame(tile_frame) + with conditional_thread_semaphore(): + self.io_binding.bind_cpu_input(self.model_inputs[0].name, tile_frame) + self.model_upscale.run_with_iobinding(self.io_binding) + ort_outs = self.io_binding.copy_outputs_to_cpu() + result = ort_outs[0] + upscale_tile_frames[index] = self.normalize_tile_frame(result) + final_frame = self.merge_tile_frames( + upscale_tile_frames, + temp_width * self.scale, + temp_height * self.scale, + pad_width * self.scale, + pad_height * self.scale, + (size[0] * self.scale, size[1] * self.scale, size[2] * self.scale), + ) + return final_frame.astype(np.uint8) + + def Release(self): + del self.model_upscale + self.model_upscale = None + del self.io_binding + self.io_binding = None diff --git a/processors/Mask_Clip2Seg.py b/processors/Mask_Clip2Seg.py new file mode 100644 index 0000000000000000000000000000000000000000..0319d5b3187eefeda74e67eea00af3d67f4fff73 --- /dev/null +++ b/processors/Mask_Clip2Seg.py @@ -0,0 +1,110 @@ +import cv2 +import numpy as np +import torch +import threading +from torchvision import transforms +from clip.clipseg import CLIPDensePredT +import numpy as np + +from roop.typing import Frame + +THREAD_LOCK_CLIP = threading.Lock() + + +class Mask_Clip2Seg: + plugin_options: dict = None + model_clip = None + + processorname = "clip2seg" + type = "mask" + + def Initialize(self, plugin_options: dict): + if self.plugin_options is not None: + if self.plugin_options["devicename"] != plugin_options["devicename"]: + self.Release() + + self.plugin_options = plugin_options + if self.model_clip is None: + self.model_clip = CLIPDensePredT( + version="ViT-B/16", reduce_dim=64, complex_trans_conv=True + ) + self.model_clip.eval() + self.model_clip.load_state_dict( + torch.load( + "models/CLIP/rd64-uni-refined.pth", map_location=torch.device("cpu") + ), + strict=False, + ) + + device = torch.device(self.plugin_options["devicename"]) + self.model_clip.to(device) + + def Run(self, img1, keywords: str) -> Frame: + if keywords is None or len(keywords) < 1 or img1 is None: + return img1 + + source_image_small = cv2.resize(img1, (256, 256)) + + img_mask = np.full( + (source_image_small.shape[0], source_image_small.shape[1]), + 0, + dtype=np.float32, + ) + mask_border = 1 + l = 0 + t = 0 + r = 1 + b = 1 + + mask_blur = 5 + clip_blur = 5 + + img_mask = cv2.rectangle( + img_mask, + (mask_border + int(l), mask_border + int(t)), + (256 - mask_border - int(r), 256 - mask_border - int(b)), + (255, 255, 255), + -1, + ) + img_mask = cv2.GaussianBlur(img_mask, (mask_blur * 2 + 1, mask_blur * 2 + 1), 0) + img_mask /= 255 + + input_image = source_image_small + + transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize( + 
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+                transforms.Resize((256, 256)),
+            ]
+        )
+        img = transform(input_image).unsqueeze(0)
+
+        thresh = 0.5
+        prompts = keywords.split(",")
+        with THREAD_LOCK_CLIP:
+            with torch.no_grad():
+                preds = self.model_clip(img.repeat(len(prompts), 1, 1, 1), prompts)[0]
+        clip_mask = torch.sigmoid(preds[0][0])
+        for i in range(len(prompts) - 1):
+            clip_mask += torch.sigmoid(preds[i + 1][0])
+
+        clip_mask = clip_mask.data.cpu().numpy()
+        # np.clip returns a new array; the result has to be assigned back,
+        # otherwise the clamp is a no-op.
+        clip_mask = np.clip(clip_mask, 0, 1)
+
+        clip_mask[clip_mask > thresh] = 1.0
+        clip_mask[clip_mask <= thresh] = 0.0
+        kernel = np.ones((5, 5), np.float32)
+        clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
+        clip_mask = cv2.GaussianBlur(
+            clip_mask, (clip_blur * 2 + 1, clip_blur * 2 + 1), 0
+        )
+
+        img_mask *= clip_mask
+        img_mask[img_mask < 0.0] = 0.0
+        return img_mask
+
+    def Release(self):
+        self.model_clip = None
diff --git a/processors/Mask_XSeg.py b/processors/Mask_XSeg.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0dfeb48e7631aa3126cea233ee4929ad16c004d
--- /dev/null
+++ b/processors/Mask_XSeg.py
@@ -0,0 +1,54 @@
+import numpy as np
+import cv2
+import onnxruntime
+import roop.globals
+
+from roop.typing import Frame
+from roop.utilities import resolve_relative_path, conditional_thread_semaphore
+
+
+class Mask_XSeg:
+    plugin_options: dict = None
+
+    model_xseg = None
+
+    processorname = "mask_xseg"
+    type = "mask"
+
+    def Initialize(self, plugin_options: dict):
+        if self.plugin_options is not None:
+            if self.plugin_options["devicename"] != plugin_options["devicename"]:
+                self.Release()
+
+        self.plugin_options = plugin_options
+        if self.model_xseg is None:
+            model_path = resolve_relative_path("../models/xseg.onnx")
+            onnxruntime.set_default_logger_severity(3)
+            self.model_xseg = onnxruntime.InferenceSession(
+                model_path, None, providers=roop.globals.execution_providers
+            )
+            self.model_inputs = self.model_xseg.get_inputs()
+            self.model_outputs = self.model_xseg.get_outputs()
+
+            # replace Mac mps with cpu for the moment
+            self.devicename = self.plugin_options["devicename"].replace("mps", "cpu")
+
+    def Run(self, img1, keywords: str) -> Frame:
+        temp_frame = cv2.resize(img1, (256, 256), interpolation=cv2.INTER_CUBIC)
+        temp_frame = temp_frame.astype("float32") / 255.0
+        temp_frame = temp_frame[None, ...]
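+        # The XSeg graph is fed NHWC float input in [0, 1]; the line above
+        # adds the missing batch axis to the already-normalised HWC frame.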
+        io_binding = self.model_xseg.io_binding()
+        io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame)
+        io_binding.bind_output(self.model_outputs[0].name, self.devicename)
+        self.model_xseg.run_with_iobinding(io_binding)
+        ort_outs = io_binding.copy_outputs_to_cpu()
+        result = ort_outs[0][0]
+        result = np.clip(result, 0, 1.0)
+        result[result < 0.1] = 0
+        # invert values to mask areas to keep
+        result = 1.0 - result
+        return result
+
+    def Release(self):
+        del self.model_xseg
+        self.model_xseg = None
diff --git a/processors/__init__.py b/processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/processors/__pycache__/Enhance_CodeFormer.cpython-310.pyc b/processors/__pycache__/Enhance_CodeFormer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..efb0082a88cf6b0ee1d06b575d4208c4404e2181
Binary files /dev/null and b/processors/__pycache__/Enhance_CodeFormer.cpython-310.pyc differ
diff --git a/processors/__pycache__/Enhance_DMDNet.cpython-310.pyc b/processors/__pycache__/Enhance_DMDNet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb7328bff84992948542f60e3ba198785f7262bb
Binary files /dev/null and b/processors/__pycache__/Enhance_DMDNet.cpython-310.pyc differ
diff --git a/processors/__pycache__/FaceSwapInsightFace.cpython-310.pyc b/processors/__pycache__/FaceSwapInsightFace.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b67b4ef2ee1cfb93b4a846edf3b1d3656be2db9
Binary files /dev/null and b/processors/__pycache__/FaceSwapInsightFace.cpython-310.pyc differ
diff --git a/processors/__pycache__/__init__.cpython-310.pyc b/processors/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f32a76c2270399ea82f92053209bd59007adc5e0
Binary files /dev/null and b/processors/__pycache__/__init__.cpython-310.pyc differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..48a71f59c60ca1b86d58ebc09556bf85fde3e266
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,20 @@
+--extra-index-url https://download.pytorch.org/whl/cu124
+numpy==1.26.4
+gradio==5.9.1
+opencv-python-headless==4.10.0.84
+onnx==1.16.1
+insightface==0.7.3
+albucore==0.0.16
+psutil==5.9.6
+torch==2.5.1+cu124; sys_platform != 'darwin'
+torch==2.5.1; sys_platform == 'darwin'
+torchvision==0.20.1+cu124; sys_platform != 'darwin'
+torchvision==0.20.1; sys_platform == 'darwin'
+onnxruntime==1.20.1; sys_platform == 'darwin' and platform_machine != 'arm64'
+onnxruntime-silicon==1.20.1; sys_platform == 'darwin' and platform_machine == 'arm64'
+onnxruntime-gpu==1.20.1; sys_platform != 'darwin'
+tqdm==4.66.4
+ftfy
+regex
+pyvirtualcam
+pydantic==2.10.4
diff --git a/run.py b/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..604d9966627567f47f0969e4a7eba3d2e8cac88b
--- /dev/null
+++ b/run.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+
+from roop import core
+
+# Public Gradio sharing is controlled by the server_share setting that core
+# loads into roop.globals.CFG; toggle it there rather than in this stub.
+
+if __name__ == "__main__": + core.run() + diff --git a/template_parser.py b/template_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..a23f322779a8e7139c538948d2fc9c6fdd409919 --- /dev/null +++ b/template_parser.py @@ -0,0 +1,23 @@ +import re +from datetime import datetime + +template_functions = { + "timestamp": lambda data: str(int(datetime.now().timestamp())), + "i": lambda data: data.get("index", False), + "file": lambda data: data.get("file", False), + "date": lambda data: datetime.now().strftime("%Y-%m-%d"), + "time": lambda data: datetime.now().strftime("%H-%M-%S"), +} + + +def parse(text: str, data: dict): + pattern = r"\{([^}]+)\}" + + matches = re.findall(pattern, text) + + for match in matches: + replacement = template_functions[match](data) + if replacement is not False: + text = text.replace(f"{{{match}}}", replacement) + + return text diff --git a/typing.py b/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..1fe4183d325e2c2de827eb872ee60311723ad7ee --- /dev/null +++ b/typing.py @@ -0,0 +1,9 @@ +from typing import Any + +from insightface.app.common import Face +from roop.FaceSet import FaceSet +import numpy + +Face = Face +FaceSet = FaceSet +Frame = numpy.ndarray[Any, Any] diff --git a/util_ffmpeg.py b/util_ffmpeg.py new file mode 100644 index 0000000000000000000000000000000000000000..502cf0e1f5d2fc34098126149f81fc85928db00c --- /dev/null +++ b/util_ffmpeg.py @@ -0,0 +1,281 @@ +import os +import subprocess +import roop.globals +import roop.utilities as util + +from typing import List, Any + + +def run_ffmpeg(args: List[str]) -> bool: + commands = [ + "ffmpeg", + "-hide_banner", + "-hwaccel", + "auto", + "-y", + "-loglevel", + roop.globals.log_level, + ] + commands.extend(args) + print("Running ffmpeg") + try: + subprocess.check_output(commands, stderr=subprocess.STDOUT) + return True + except Exception as e: + print("Running ffmpeg failed! 
Commandline:") + print(" ".join(commands)) + return False + + +def cut_video( + original_video: str, + cut_video: str, + start_frame: int, + end_frame: int, + reencode: bool, +): + fps = util.detect_fps(original_video) + start_time = start_frame / fps + num_frames = end_frame - start_frame + + if reencode: + run_ffmpeg( + [ + "-ss", + format(start_time, ".2f"), + "-i", + original_video, + "-c:v", + roop.globals.video_encoder, + "-c:a", + "aac", + "-frames:v", + str(num_frames), + cut_video, + ] + ) + else: + run_ffmpeg( + [ + "-ss", + format(start_time, ".2f"), + "-i", + original_video, + "-frames:v", + str(num_frames), + "-c:v", + "copy", + "-c:a", + "copy", + cut_video, + ] + ) + + +def join_videos(videos: List[str], dest_filename: str, simple: bool): + if simple: + txtfilename = util.resolve_relative_path("../temp") + txtfilename = os.path.join(txtfilename, "joinvids.txt") + with open(txtfilename, "w", encoding="utf-8") as f: + for v in videos: + v = v.replace("\\", "/") + f.write(f"file {v}\n") + commands = [ + "-f", + "concat", + "-safe", + "0", + "-i", + f"{txtfilename}", + "-vcodec", + "copy", + f"{dest_filename}", + ] + run_ffmpeg(commands) + + else: + inputs = [] + filter = "" + for i, v in enumerate(videos): + inputs.append("-i") + inputs.append(v) + filter += f"[{i}:v:0][{i}:a:0]" + run_ffmpeg( + [ + " ".join(inputs), + "-filter_complex", + f'"{filter}concat=n={len(videos)}:v=1:a=1[outv][outa]"', + "-map", + '"[outv]"', + "-map", + '"[outa]"', + dest_filename, + ] + ) + + # filter += f'[{i}:v:0][{i}:a:0]' + # run_ffmpeg([" ".join(inputs), '-filter_complex', f'"{filter}concat=n={len(videos)}:v=1:a=1[outv][outa]"', '-map', '"[outv]"', '-map', '"[outa]"', dest_filename]) + + +def extract_frames( + target_path: str, trim_frame_start, trim_frame_end, fps: float +) -> bool: + util.create_temp(target_path) + temp_directory_path = util.get_temp_directory_path(target_path) + commands = [ + "-i", + target_path, + "-q:v", + "1", + "-pix_fmt", + "rgb24", + ] + if trim_frame_start is not None and trim_frame_end is not None: + commands.extend( + [ + "-vf", + "trim=start_frame=" + + str(trim_frame_start) + + ":end_frame=" + + str(trim_frame_end) + + ",fps=" + + str(fps), + ] + ) + commands.extend( + [ + "-vsync", + "0", + os.path.join( + temp_directory_path, "%06d." 
+ roop.globals.CFG.output_image_format + ), + ] + ) + return run_ffmpeg(commands) + + +def create_video( + target_path: str, + dest_filename: str, + fps: float = 24.0, + temp_directory_path: str = None, +) -> None: + if temp_directory_path is None: + temp_directory_path = util.get_temp_directory_path(target_path) + run_ffmpeg( + [ + "-r", + str(fps), + "-i", + os.path.join( + temp_directory_path, f"%06d.{roop.globals.CFG.output_image_format}" + ), + "-c:v", + roop.globals.video_encoder, + "-crf", + str(roop.globals.video_quality), + "-pix_fmt", + "yuv420p", + "-vf", + "colorspace=bt709:iall=bt601-6-625:fast=1", + "-y", + dest_filename, + ] + ) + return dest_filename + + +def create_gif_from_video(video_path: str, gif_path): + from roop.capturer import get_video_frame, release_video + + fps = util.detect_fps(video_path) + frame = get_video_frame(video_path) + release_video() + + scalex = frame.shape[0] + scaley = frame.shape[1] + + if scalex >= scaley: + scaley = -1 + else: + scalex = -1 + + run_ffmpeg( + [ + "-i", + video_path, + "-vf", + f"fps={fps},scale={int(scalex)}:{int(scaley)}:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse", + "-loop", + "0", + gif_path, + ] + ) + + +def create_video_from_gif(gif_path: str, output_path): + fps = util.detect_fps(gif_path) + filter = """scale='trunc(in_w/2)*2':'trunc(in_h/2)*2',format=yuv420p,fps=10""" + run_ffmpeg( + [ + "-i", + gif_path, + "-vf", + f'"{filter}"', + "-movflags", + "+faststart", + "-shortest", + output_path, + ] + ) + + +def repair_video(original_video: str, final_video: str): + run_ffmpeg( + [ + "-i", + original_video, + "-movflags", + "faststart", + "-acodec", + "copy", + "-vcodec", + "copy", + final_video, + ] + ) + + +def restore_audio( + intermediate_video: str, + original_video: str, + trim_frame_start, + trim_frame_end, + final_video: str, +) -> None: + fps = util.detect_fps(original_video) + commands = ["-i", intermediate_video] + if trim_frame_start is None and trim_frame_end is None: + commands.extend(["-c:a", "copy"]) + else: + # if trim_frame_start is not None: + # start_time = trim_frame_start / fps + # commands.extend([ '-ss', format(start_time, ".2f")]) + # else: + # commands.extend([ '-ss', '0' ]) + # if trim_frame_end is not None: + # end_time = trim_frame_end / fps + # commands.extend([ '-to', format(end_time, ".2f")]) + # commands.extend([ '-c:a', 'aac' ]) + if trim_frame_start is not None: + start_time = trim_frame_start / fps + commands.extend(["-ss", format(start_time, ".2f")]) + else: + commands.extend(["-ss", "0"]) + if trim_frame_end is not None: + end_time = trim_frame_end / fps + commands.extend(["-to", format(end_time, ".2f")]) + commands.extend(["-i", original_video, "-c", "copy"]) + + commands.extend(["-map", "0:v:0", "-map", "1:a:0?", "-shortest", final_video]) + run_ffmpeg(commands) diff --git a/utilities.py b/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..effbce0f3e747db8d9f0f2b4f689bc4206636195 --- /dev/null +++ b/utilities.py @@ -0,0 +1,442 @@ +import glob +import mimetypes +import os +import platform +import shutil +import ssl +import subprocess +import sys +import urllib +import torch +import gradio +import tempfile +import cv2 +import zipfile +import traceback +import threading +import threading +import random + +from typing import Union, Any +from contextlib import nullcontext + +from pathlib import Path +from typing import List, Any +from tqdm import tqdm +from scipy.spatial import distance +from urllib.parse import urlparse + +import 
roop.template_parser as template_parser + +import roop.globals + +TEMP_FILE = "temp.mp4" +TEMP_DIRECTORY = "temp" + +THREAD_SEMAPHORE = threading.Semaphore() +NULL_CONTEXT = nullcontext() + + +# monkey patch ssl for mac +if platform.system().lower() == "darwin": + ssl._create_default_https_context = ssl._create_unverified_context + + +# https://github.com/facefusion/facefusion/blob/master/facefusion +def detect_fps(target_path: str) -> float: + fps = 24.0 + cap = cv2.VideoCapture(target_path) + if cap.isOpened(): + fps = cap.get(cv2.CAP_PROP_FPS) + cap.release() + return fps + + +# Gradio wants Images in RGB +def convert_to_gradio(image): + if image is None: + return None + return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + +def sort_filenames_ignore_path(filenames): + """Sorts a list of filenames containing a complete path by their filename, + while retaining their original path. + + Args: + filenames: A list of filenames containing a complete path. + + Returns: + A sorted list of filenames containing a complete path. + """ + filename_path_tuples = [ + (os.path.split(filename)[1], filename) for filename in filenames + ] + sorted_filename_path_tuples = sorted(filename_path_tuples, key=lambda x: x[0]) + return [ + filename_path_tuple[1] for filename_path_tuple in sorted_filename_path_tuples + ] + + +def sort_rename_frames(path: str): + filenames = os.listdir(path) + filenames.sort() + for i in range(len(filenames)): + of = os.path.join(path, filenames[i]) + newidx = i + 1 + new_filename = os.path.join( + path, f"{newidx:06d}." + roop.globals.CFG.output_image_format + ) + os.rename(of, new_filename) + + +def get_temp_frame_paths(target_path: str) -> List[str]: + temp_directory_path = get_temp_directory_path(target_path) + return glob.glob( + ( + os.path.join( + glob.escape(temp_directory_path), + f"*.{roop.globals.CFG.output_image_format}", + ) + ) + ) + + +def get_temp_directory_path(target_path: str) -> str: + target_name, _ = os.path.splitext(os.path.basename(target_path)) + target_directory_path = os.path.dirname(target_path) + return os.path.join(target_directory_path, TEMP_DIRECTORY, target_name) + + +def get_temp_output_path(target_path: str) -> str: + temp_directory_path = get_temp_directory_path(target_path) + return os.path.join(temp_directory_path, TEMP_FILE) + + +def normalize_output_path(source_path: str, target_path: str, output_path: str) -> Any: + if source_path and target_path: + source_name, _ = os.path.splitext(os.path.basename(source_path)) + target_name, target_extension = os.path.splitext(os.path.basename(target_path)) + if os.path.isdir(output_path): + return os.path.join( + output_path, source_name + "-" + target_name + target_extension + ) + return output_path + + +def get_destfilename_from_path( + srcfilepath: str, destfilepath: str, extension: str +) -> str: + fn, ext = os.path.splitext(os.path.basename(srcfilepath)) + if "." 
in extension: + return os.path.join(destfilepath, f"{fn}{extension}") + return os.path.join(destfilepath, f"{fn}{extension}{ext}") + + +def replace_template(file_path: str, index: int = 0) -> str: + fn, ext = os.path.splitext(os.path.basename(file_path)) + + # Remove the "__temp" placeholder that was used as a temporary filename + fn = fn.replace("__temp", "") + + template = roop.globals.CFG.output_template + replaced_filename = template_parser.parse( + template, {"index": str(index), "file": fn} + ) + + return os.path.join(roop.globals.output_path, f"{replaced_filename}{ext}") + + +def create_temp(target_path: str) -> None: + temp_directory_path = get_temp_directory_path(target_path) + Path(temp_directory_path).mkdir(parents=True, exist_ok=True) + + +def move_temp(target_path: str, output_path: str) -> None: + temp_output_path = get_temp_output_path(target_path) + if os.path.isfile(temp_output_path): + if os.path.isfile(output_path): + os.remove(output_path) + shutil.move(temp_output_path, output_path) + + +def clean_temp(target_path: str) -> None: + temp_directory_path = get_temp_directory_path(target_path) + parent_directory_path = os.path.dirname(temp_directory_path) + if not roop.globals.keep_frames and os.path.isdir(temp_directory_path): + shutil.rmtree(temp_directory_path) + if os.path.exists(parent_directory_path) and not os.listdir(parent_directory_path): + os.rmdir(parent_directory_path) + + +def delete_temp_frames(filename: str) -> None: + dir = os.path.dirname(os.path.dirname(filename)) + shutil.rmtree(dir) + + +def has_image_extension(image_path: str) -> bool: + return image_path.lower().endswith(("png", "jpg", "jpeg", "webp")) + + +def has_extension(filepath: str, extensions: List[str]) -> bool: + return filepath.lower().endswith(tuple(extensions)) + + +def is_image(image_path: str) -> bool: + if image_path and os.path.isfile(image_path): + if image_path.endswith(".webp"): + return True + mimetype, _ = mimetypes.guess_type(image_path) + return bool(mimetype and mimetype.startswith("image/")) + return False + + +def is_video(video_path: str) -> bool: + if video_path and os.path.isfile(video_path): + mimetype, _ = mimetypes.guess_type(video_path) + return bool(mimetype and mimetype.startswith("video/")) + return False + + +def conditional_download( + download_directory_path: str, urls: List[str | List[str]] +) -> None: + if not os.path.exists(download_directory_path): + os.makedirs(download_directory_path) + for url in urls: + if isinstance(url, str): + download_file_path = os.path.join( + download_directory_path, os.path.basename(url) + ) + if os.path.exists(download_file_path): + continue + request = urllib.request.urlopen(url) # type: ignore[attr-defined] + total = int(request.headers.get("Content-Length", 0)) + with tqdm( + total=total, + desc=f"Downloading {os.path.basename(urlparse(url).path)}", + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as progress: + urllib.request.urlretrieve( + url, + download_file_path, + reporthook=lambda count, block_size, total_size: progress.update( + block_size + ), + ) # type: ignore[attr-defined] + elif isinstance(url, list): + for _url in url: + download_file_path = os.path.join( + download_directory_path, os.path.basename(_url) + ) + if os.path.exists(download_file_path): + break + request = urllib.request.urlopen(_url) # type: ignore[attr-defined] + if not request.status == 200: + continue + total = int(request.headers.get("Content-Length", 0)) + with tqdm( + total=total, + desc=f"Downloading 
{os.path.basename(urlparse(_url).path)}",
+                        unit="B",
+                        unit_scale=True,
+                        unit_divisor=1024,
+                    ) as progress:
+                        urllib.request.urlretrieve(
+                            _url,
+                            download_file_path,
+                            reporthook=lambda count,
+                            block_size,
+                            total_size: progress.update(block_size),
+                        )  # type: ignore[attr-defined]
+                    break
+
+
+def get_local_files_from_folder(folder: str) -> List[str]:
+    if not os.path.exists(folder) or not os.path.isdir(folder):
+        return None
+    files = [
+        os.path.join(folder, f)
+        for f in os.listdir(folder)
+        if os.path.isfile(os.path.join(folder, f))
+    ]
+    return files
+
+
+def resolve_relative_path(path: str) -> str:
+    return os.path.abspath(os.path.join(os.path.dirname(__file__), path))
+
+
+def get_device() -> str:
+    if len(roop.globals.execution_providers) < 1:
+        roop.globals.execution_providers = ["CPUExecutionProvider"]
+
+    prov = roop.globals.execution_providers[0]
+    if "CoreMLExecutionProvider" in prov:
+        return "mps"
+    if "CUDAExecutionProvider" in prov or "ROCMExecutionProvider" in prov:
+        return "cuda"
+    if "OpenVINOExecutionProvider" in prov:
+        return "mkl"
+    return "cpu"
+
+
+def str_to_class(module_name, class_name) -> Any:
+    from importlib import import_module
+
+    class_ = None
+    try:
+        module_ = import_module(module_name)
+        try:
+            class_ = getattr(module_, class_name)()
+        except AttributeError:
+            print(f"Class {class_name} does not exist")
+    except ImportError:
+        print(f"Module {module_name} does not exist")
+    return class_
+
+
+def is_installed(name: str) -> bool:
+    # shutil.which returns a path or None; coerce to the annotated bool
+    return shutil.which(name) is not None
+
+
+# Taken from https://stackoverflow.com/a/68842705
+def get_platform() -> str:
+    if sys.platform == "linux":
+        try:
+            proc_version = open("/proc/version").read()
+            if "Microsoft" in proc_version:
+                return "wsl"
+        except:
+            pass
+    return sys.platform
+
+
+def open_with_default_app(filename: str):
+    if filename is None:
+        return
+    platform = get_platform()
+    if platform == "darwin":
+        subprocess.call(("open", filename))
+    elif platform in ["win64", "win32"]:
+        os.startfile(filename.replace("/", "\\"))
+    elif platform == "wsl":
+        subprocess.call("cmd.exe /C start".split() + [filename])
+    else:  # linux variants
+        # subprocess.call expects the command as one sequence, not two
+        # separate positional arguments
+        subprocess.call(["xdg-open", filename])
+
+
+def prepare_for_batch(target_files) -> str:
+    print("Preparing temp files")
+    tempfolder = os.path.join(tempfile.gettempdir(), "rooptmp")
+    if os.path.exists(tempfolder):
+        shutil.rmtree(tempfolder)
+    Path(tempfolder).mkdir(parents=True, exist_ok=True)
+    for f in target_files:
+        newname = os.path.basename(f.name)
+        shutil.move(f.name, os.path.join(tempfolder, newname))
+    return tempfolder
+
+
+def zip(files, zipname):
+    with zipfile.ZipFile(zipname, "w") as zip_file:
+        for f in files:
+            zip_file.write(f, os.path.basename(f))
+
+
+def unzip(zipfilename: str, target_path: str):
+    with zipfile.ZipFile(zipfilename, "r") as zip_file:
+        zip_file.extractall(target_path)
+
+
+def mkdir_with_umask(directory):
+    oldmask = os.umask(0)
+    # mode needs octal
+    os.makedirs(directory, mode=0o775, exist_ok=True)
+    os.umask(oldmask)
+
+
+def open_folder(path: str):
+    platform = get_platform()
+    try:
+        if platform == "darwin":
+            subprocess.call(("open", path))
+        elif platform in ["win64", "win32"]:
+            open_with_default_app(path)
+        elif platform == "wsl":
+            subprocess.call("cmd.exe /C start".split() + [path])
+        else:  # linux variants
+            subprocess.Popen(["xdg-open", path])
+    except Exception:
+        traceback.print_exc()
+
+
+def create_version_html() -> str:
+    python_version = ".".join([str(x) for x in
sys.version_info[0:3]]) + versions_html = f""" +python: {python_version} +• +torch: {getattr(torch, "__long_version__", torch.__version__)} +• +gradio: {gradio.__version__} +""" + return versions_html + + +def compute_cosine_distance(emb1, emb2) -> float: + return distance.cosine(emb1, emb2) + + +def has_cuda_device(): + return torch.cuda is not None and torch.cuda.is_available() + + +def print_cuda_info(): + try: + print( + f"Number of CUDA devices: {torch.cuda.device_count()} Currently used Id: {torch.cuda.current_device()} Device Name: {torch.cuda.get_device_name(torch.cuda.current_device())}" + ) + except: + print("No CUDA device found!") + + +def clean_dir(path: str): + contents = os.listdir(path) + for item in contents: + item_path = os.path.join(path, item) + try: + if os.path.isfile(item_path): + os.remove(item_path) + elif os.path.isdir(item_path): + shutil.rmtree(item_path) + except Exception as e: + print(e) + + +def conditional_thread_semaphore() -> Union[Any, Any]: + if ( + "DmlExecutionProvider" in roop.globals.execution_providers + or "ROCMExecutionProvider" in roop.globals.execution_providers + ): + return THREAD_SEMAPHORE + return NULL_CONTEXT + + +def shuffle_array(arr): + """ + Shuffles the given array in place using the Fisher-Yates shuffle algorithm. + + Args: + arr: The array to be shuffled. + + Returns: + None. The array is shuffled in place. + """ + for i in range(len(arr) - 1, 0, -1): + j = random.randint(0, i) + arr[i], arr[j] = arr[j], arr[i] diff --git a/virtualcam.py b/virtualcam.py new file mode 100644 index 0000000000000000000000000000000000000000..89de40df8bc7df55b9f220c4cae67615e06ce1a4 --- /dev/null +++ b/virtualcam.py @@ -0,0 +1,120 @@ +import cv2 +import roop.globals +import ui.globals +import pyvirtualcam +import threading +import platform + + +cam_active = False +cam_thread = None +vcam = None + + +def virtualcamera( + swap_model, streamobs, use_xseg, use_mouthrestore, cam_num, width, height +): + from roop.ProcessOptions import ProcessOptions + from roop.core import live_swap, get_processing_plugins + + global cam_active + + # time.sleep(2) + print("Starting capture") + cap = cv2.VideoCapture( + cam_num, + cv2.CAP_DSHOW if platform.system() != "Darwin" else cv2.CAP_AVFOUNDATION, + ) + if not cap.isOpened(): + print("Cannot open camera") + cap.release() + del cap + return + + pref_width = width + pref_height = height + pref_fps_in = 30 + cap.set(cv2.CAP_PROP_FRAME_WIDTH, pref_width) + cap.set(cv2.CAP_PROP_FRAME_HEIGHT, pref_height) + cap.set(cv2.CAP_PROP_FPS, pref_fps_in) + cam_active = True + + # native format UYVY + + cam = None + if streamobs: + print("Detecting virtual cam devices") + cam = pyvirtualcam.Camera( + width=pref_width, + height=pref_height, + fps=pref_fps_in, + fmt=pyvirtualcam.PixelFormat.BGR, + print_fps=False, + ) + if cam: + print(f"Using virtual camera: {cam.device}") + print(f"Using {cam.native_fmt}") + else: + print(f"Not streaming to virtual camera!") + subsample_size = roop.globals.subsample_size + + options = ProcessOptions( + swap_model, + get_processing_plugins("mask_xseg" if use_xseg else None), + roop.globals.distance_threshold, + roop.globals.blend_ratio, + "all", + 0, + None, + None, + 1, + subsample_size, + False, + use_mouthrestore, + ) + while cam_active: + ret, frame = cap.read() + if not ret: + break + + if len(roop.globals.INPUT_FACESETS) > 0: + frame = live_swap(frame, options) + if cam: + cam.send(frame) + cam.sleep_until_next_frame() + ui.globals.ui_camera_frame = frame + + if cam: + cam.close() + cap.release() 
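+    # Both the capture and the virtual camera are released at this point, so
+    # a subsequent start_virtual_cam() call can reopen the devices cleanly.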
+ print("Camera stopped") + + +def start_virtual_cam( + swap_model, streamobs, use_xseg, use_mouthrestore, cam_number, resolution +): + global cam_thread, cam_active + + if not cam_active: + width, height = map(int, resolution.split("x")) + cam_thread = threading.Thread( + target=virtualcamera, + args=[ + swap_model, + streamobs, + use_xseg, + use_mouthrestore, + cam_number, + width, + height, + ], + ) + cam_thread.start() + + +def stop_virtual_cam(): + global cam_active, cam_thread + + if cam_active: + cam_active = False + cam_thread.join() diff --git a/vr_util.py b/vr_util.py new file mode 100644 index 0000000000000000000000000000000000000000..24d8c1e0c0f61837203559b86f2f47735b97646f --- /dev/null +++ b/vr_util.py @@ -0,0 +1,57 @@ +import cv2 +import numpy as np + +# VR Lense Distortion +# Taken from https://github.com/g0kuvonlange/vrswap + + +def get_perspective(img, FOV, THETA, PHI, height, width): + # + # THETA is left/right angle, PHI is up/down angle, both in degree + # + [orig_width, orig_height, _] = img.shape + equ_h = orig_height + equ_w = orig_width + equ_cx = (equ_w - 1) / 2.0 + equ_cy = (equ_h - 1) / 2.0 + + wFOV = FOV + hFOV = float(height) / width * wFOV + + w_len = np.tan(np.radians(wFOV / 2.0)) + h_len = np.tan(np.radians(hFOV / 2.0)) + + x_map = np.ones([height, width], np.float32) + y_map = np.tile(np.linspace(-w_len, w_len, width), [height, 1]) + z_map = -np.tile(np.linspace(-h_len, h_len, height), [width, 1]).T + + D = np.sqrt(x_map**2 + y_map**2 + z_map**2) + xyz = np.stack((x_map, y_map, z_map), axis=2) / np.repeat( + D[:, :, np.newaxis], 3, axis=2 + ) + + y_axis = np.array([0.0, 1.0, 0.0], np.float32) + z_axis = np.array([0.0, 0.0, 1.0], np.float32) + [R1, _] = cv2.Rodrigues(z_axis * np.radians(THETA)) + [R2, _] = cv2.Rodrigues(np.dot(R1, y_axis) * np.radians(-PHI)) + + xyz = xyz.reshape([height * width, 3]).T + xyz = np.dot(R1, xyz) + xyz = np.dot(R2, xyz).T + lat = np.arcsin(xyz[:, 2]) + lon = np.arctan2(xyz[:, 1], xyz[:, 0]) + + lon = lon.reshape([height, width]) / np.pi * 180 + lat = -lat.reshape([height, width]) / np.pi * 180 + + lon = lon / 180 * equ_cx + equ_cx + lat = lat / 90 * equ_cy + equ_cy + + persp = cv2.remap( + img, + lon.astype(np.float32), + lat.astype(np.float32), + cv2.INTER_CUBIC, + borderMode=cv2.BORDER_WRAP, + ) + return persp