Spaces:
Runtime error
Runtime error
Upload 54 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- FaceSet.py +21 -0
- ProcessEntry.py +7 -0
- ProcessMgr.py +1058 -0
- ProcessOptions.py +35 -0
- StreamWriter.py +60 -0
- __init__.py +0 -0
- __pycache__/FaceSet.cpython-310.pyc +0 -0
- __pycache__/ProcessEntry.cpython-310.pyc +0 -0
- __pycache__/ProcessMgr.cpython-310.pyc +0 -0
- __pycache__/ProcessOptions.cpython-310.pyc +0 -0
- __pycache__/StreamWriter.cpython-310.pyc +0 -0
- __pycache__/__init__.cpython-310.pyc +0 -0
- __pycache__/capturer.cpython-310.pyc +0 -0
- __pycache__/core.cpython-310.pyc +0 -0
- __pycache__/face_util.cpython-310.pyc +0 -0
- __pycache__/ffmpeg_writer.cpython-310.pyc +0 -0
- __pycache__/globals.cpython-310.pyc +0 -0
- __pycache__/metadata.cpython-310.pyc +0 -0
- __pycache__/template_parser.cpython-310.pyc +0 -0
- __pycache__/typing.cpython-310.pyc +0 -0
- __pycache__/util_ffmpeg.cpython-310.pyc +0 -0
- __pycache__/utilities.cpython-310.pyc +0 -0
- __pycache__/vr_util.cpython-310.pyc +0 -0
- capturer.py +50 -0
- core.py +605 -0
- face_util.py +352 -0
- ffmpeg_writer.py +240 -0
- globals.py +54 -0
- metadata.py +2 -0
- processors/Enhance_CodeFormer.py +76 -0
- processors/Enhance_DMDNet.py +1425 -0
- processors/Enhance_GFPGAN.py +65 -0
- processors/Enhance_GPEN.py +65 -0
- processors/Enhance_RestoreFormerPPlus.py +68 -0
- processors/FaceSwapInsightFace.py +56 -0
- processors/Frame_Colorizer.py +83 -0
- processors/Frame_Filter.py +118 -0
- processors/Frame_Masking.py +74 -0
- processors/Frame_Upscale.py +151 -0
- processors/Mask_Clip2Seg.py +110 -0
- processors/Mask_XSeg.py +54 -0
- processors/__init__.py +0 -0
- processors/__pycache__/Enhance_CodeFormer.cpython-310.pyc +0 -0
- processors/__pycache__/Enhance_DMDNet.cpython-310.pyc +0 -0
- processors/__pycache__/FaceSwapInsightFace.cpython-310.pyc +0 -0
- processors/__pycache__/__init__.cpython-310.pyc +0 -0
- requirements.txt +20 -0
- run.py +11 -0
- template_parser.py +23 -0
- typing.py +9 -0
FaceSet.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class FaceSet:
|
| 5 |
+
faces = []
|
| 6 |
+
ref_images = []
|
| 7 |
+
embedding_average = "None"
|
| 8 |
+
embeddings_backup = None
|
| 9 |
+
|
| 10 |
+
def __init__(self):
|
| 11 |
+
self.faces = []
|
| 12 |
+
self.ref_images = []
|
| 13 |
+
self.embeddings_backup = None
|
| 14 |
+
|
| 15 |
+
def AverageEmbeddings(self):
|
| 16 |
+
if len(self.faces) > 1 and self.embeddings_backup is None:
|
| 17 |
+
self.embeddings_backup = self.faces[0]["embedding"]
|
| 18 |
+
embeddings = [face.embedding for face in self.faces]
|
| 19 |
+
|
| 20 |
+
self.faces[0]["embedding"] = np.mean(embeddings, axis=0)
|
| 21 |
+
# try median too?
|
ProcessEntry.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class ProcessEntry:
|
| 2 |
+
def __init__(self, filename: str, start: int, end: int, fps: float):
|
| 3 |
+
self.filename = filename
|
| 4 |
+
self.finalname = None
|
| 5 |
+
self.startframe = start
|
| 6 |
+
self.endframe = end
|
| 7 |
+
self.fps = fps
|
ProcessMgr.py
ADDED
|
@@ -0,0 +1,1058 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import psutil
|
| 5 |
+
|
| 6 |
+
from roop.ProcessOptions import ProcessOptions
|
| 7 |
+
|
| 8 |
+
from roop.face_util import (
|
| 9 |
+
get_first_face,
|
| 10 |
+
get_all_faces,
|
| 11 |
+
rotate_anticlockwise,
|
| 12 |
+
rotate_clockwise,
|
| 13 |
+
clamp_cut_values,
|
| 14 |
+
)
|
| 15 |
+
from roop.utilities import (
|
| 16 |
+
compute_cosine_distance,
|
| 17 |
+
get_device,
|
| 18 |
+
str_to_class,
|
| 19 |
+
shuffle_array,
|
| 20 |
+
)
|
| 21 |
+
import roop.vr_util as vr
|
| 22 |
+
|
| 23 |
+
from typing import Any, List, Callable
|
| 24 |
+
from roop.typing import Frame, Face
|
| 25 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 26 |
+
from threading import Thread, Lock
|
| 27 |
+
from queue import Queue
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
from roop.ffmpeg_writer import FFMPEG_VideoWriter
|
| 30 |
+
from roop.StreamWriter import StreamWriter
|
| 31 |
+
import roop.globals
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Poor man's enum to be able to compare to int
|
| 35 |
+
class eNoFaceAction:
|
| 36 |
+
USE_ORIGINAL_FRAME = 0
|
| 37 |
+
RETRY_ROTATED = 1
|
| 38 |
+
SKIP_FRAME = 2
|
| 39 |
+
SKIP_FRAME_IF_DISSIMILAR = (3,)
|
| 40 |
+
USE_LAST_SWAPPED = 4
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def create_queue(temp_frame_paths: List[str]) -> Queue[str]:
|
| 44 |
+
queue: Queue[str] = Queue()
|
| 45 |
+
for frame_path in temp_frame_paths:
|
| 46 |
+
queue.put(frame_path)
|
| 47 |
+
return queue
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def pick_queue(queue: Queue[str], queue_per_future: int) -> List[str]:
|
| 51 |
+
queues = []
|
| 52 |
+
for _ in range(queue_per_future):
|
| 53 |
+
if not queue.empty():
|
| 54 |
+
queues.append(queue.get())
|
| 55 |
+
return queues
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class ProcessMgr:
|
| 59 |
+
input_face_datas = []
|
| 60 |
+
target_face_datas = []
|
| 61 |
+
|
| 62 |
+
imagemask = None
|
| 63 |
+
|
| 64 |
+
processors = []
|
| 65 |
+
options: ProcessOptions = None
|
| 66 |
+
|
| 67 |
+
num_threads = 1
|
| 68 |
+
current_index = 0
|
| 69 |
+
processing_threads = 1
|
| 70 |
+
buffer_wait_time = 0.1
|
| 71 |
+
|
| 72 |
+
lock = Lock()
|
| 73 |
+
|
| 74 |
+
frames_queue = None
|
| 75 |
+
processed_queue = None
|
| 76 |
+
|
| 77 |
+
videowriter = None
|
| 78 |
+
streamwriter = None
|
| 79 |
+
|
| 80 |
+
progress_gradio = None
|
| 81 |
+
total_frames = 0
|
| 82 |
+
|
| 83 |
+
num_frames_no_face = 0
|
| 84 |
+
last_swapped_frame = None
|
| 85 |
+
|
| 86 |
+
output_to_file = None
|
| 87 |
+
output_to_cam = None
|
| 88 |
+
|
| 89 |
+
plugins = {
|
| 90 |
+
"faceswap": "FaceSwapInsightFace",
|
| 91 |
+
"mask_clip2seg": "Mask_Clip2Seg",
|
| 92 |
+
"mask_xseg": "Mask_XSeg",
|
| 93 |
+
"codeformer": "Enhance_CodeFormer",
|
| 94 |
+
"gfpgan": "Enhance_GFPGAN",
|
| 95 |
+
"dmdnet": "Enhance_DMDNet",
|
| 96 |
+
"gpen": "Enhance_GPEN",
|
| 97 |
+
"restoreformer++": "Enhance_RestoreFormerPPlus",
|
| 98 |
+
"colorizer": "Frame_Colorizer",
|
| 99 |
+
"filter_generic": "Frame_Filter",
|
| 100 |
+
"removebg": "Frame_Masking",
|
| 101 |
+
"upscale": "Frame_Upscale",
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
def __init__(self, progress):
|
| 105 |
+
if progress is not None:
|
| 106 |
+
self.progress_gradio = progress
|
| 107 |
+
|
| 108 |
+
def reuseOldProcessor(self, name: str):
|
| 109 |
+
for p in self.processors:
|
| 110 |
+
if p.processorname == name:
|
| 111 |
+
return p
|
| 112 |
+
|
| 113 |
+
return None
|
| 114 |
+
|
| 115 |
+
def initialize(self, input_faces, target_faces, options):
|
| 116 |
+
self.input_face_datas = input_faces
|
| 117 |
+
self.target_face_datas = target_faces
|
| 118 |
+
self.num_frames_no_face = 0
|
| 119 |
+
self.last_swapped_frame = None
|
| 120 |
+
self.options = options
|
| 121 |
+
devicename = get_device()
|
| 122 |
+
|
| 123 |
+
roop.globals.g_desired_face_analysis = [
|
| 124 |
+
"landmark_3d_68",
|
| 125 |
+
"landmark_2d_106",
|
| 126 |
+
"detection",
|
| 127 |
+
"recognition",
|
| 128 |
+
]
|
| 129 |
+
if options.swap_mode == "all_female" or options.swap_mode == "all_male":
|
| 130 |
+
roop.globals.g_desired_face_analysis.append("genderage")
|
| 131 |
+
elif options.swap_mode == "all_random":
|
| 132 |
+
# don't modify original list
|
| 133 |
+
self.input_face_datas = input_faces.copy()
|
| 134 |
+
shuffle_array(self.input_face_datas)
|
| 135 |
+
|
| 136 |
+
for p in self.processors:
|
| 137 |
+
newp = next(
|
| 138 |
+
(x for x in options.processors.keys() if x == p.processorname), None
|
| 139 |
+
)
|
| 140 |
+
if newp is None:
|
| 141 |
+
p.Release()
|
| 142 |
+
del p
|
| 143 |
+
|
| 144 |
+
newprocessors = []
|
| 145 |
+
for key, extoption in options.processors.items():
|
| 146 |
+
p = self.reuseOldProcessor(key)
|
| 147 |
+
if p is None:
|
| 148 |
+
classname = self.plugins[key]
|
| 149 |
+
module = "roop.processors." + classname
|
| 150 |
+
p = str_to_class(module, classname)
|
| 151 |
+
if p is not None:
|
| 152 |
+
extoption.update({"devicename": devicename})
|
| 153 |
+
if p.type == "swap":
|
| 154 |
+
if self.options.swap_modelname == "InSwapper 128":
|
| 155 |
+
extoption.update({"modelname": "inswapper_128.onnx"})
|
| 156 |
+
elif self.options.swap_modelname == "ReSwapper 128":
|
| 157 |
+
extoption.update({"modelname": "reswapper_128.onnx"})
|
| 158 |
+
elif self.options.swap_modelname == "ReSwapper 256":
|
| 159 |
+
extoption.update({"modelname": "reswapper_256.onnx"})
|
| 160 |
+
|
| 161 |
+
p.Initialize(extoption)
|
| 162 |
+
newprocessors.append(p)
|
| 163 |
+
else:
|
| 164 |
+
print(f"Not using {module}")
|
| 165 |
+
self.processors = newprocessors
|
| 166 |
+
|
| 167 |
+
if (
|
| 168 |
+
isinstance(self.options.imagemask, dict)
|
| 169 |
+
and self.options.imagemask.get("layers")
|
| 170 |
+
and len(self.options.imagemask["layers"]) > 0
|
| 171 |
+
):
|
| 172 |
+
self.options.imagemask = self.options.imagemask.get("layers")[0]
|
| 173 |
+
# Get rid of alpha
|
| 174 |
+
self.options.imagemask = cv2.cvtColor(
|
| 175 |
+
self.options.imagemask, cv2.COLOR_RGBA2GRAY
|
| 176 |
+
)
|
| 177 |
+
if np.any(self.options.imagemask):
|
| 178 |
+
mo = self.input_face_datas[0].faces[0].mask_offsets
|
| 179 |
+
self.options.imagemask = self.blur_area(
|
| 180 |
+
self.options.imagemask, mo[4], mo[5]
|
| 181 |
+
)
|
| 182 |
+
self.options.imagemask = self.options.imagemask.astype(np.float32) / 255
|
| 183 |
+
self.options.imagemask = cv2.cvtColor(
|
| 184 |
+
self.options.imagemask, cv2.COLOR_GRAY2RGB
|
| 185 |
+
)
|
| 186 |
+
else:
|
| 187 |
+
self.options.imagemask = None
|
| 188 |
+
|
| 189 |
+
self.options.frame_processing = False
|
| 190 |
+
for p in self.processors:
|
| 191 |
+
if p.type.startswith("frame_"):
|
| 192 |
+
self.options.frame_processing = True
|
| 193 |
+
|
| 194 |
+
def run_batch(self, source_files, target_files, threads: int = 1):
|
| 195 |
+
progress_bar_format = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
|
| 196 |
+
self.total_frames = len(source_files)
|
| 197 |
+
self.num_threads = threads
|
| 198 |
+
with tqdm(
|
| 199 |
+
total=self.total_frames,
|
| 200 |
+
desc="Processing",
|
| 201 |
+
unit="frame",
|
| 202 |
+
dynamic_ncols=True,
|
| 203 |
+
bar_format=progress_bar_format,
|
| 204 |
+
) as progress:
|
| 205 |
+
with ThreadPoolExecutor(max_workers=threads) as executor:
|
| 206 |
+
futures = []
|
| 207 |
+
queue = create_queue(source_files)
|
| 208 |
+
queue_per_future = max(len(source_files) // threads, 1)
|
| 209 |
+
while not queue.empty():
|
| 210 |
+
future = executor.submit(
|
| 211 |
+
self.process_frames,
|
| 212 |
+
source_files,
|
| 213 |
+
target_files,
|
| 214 |
+
pick_queue(queue, queue_per_future),
|
| 215 |
+
lambda: self.update_progress(progress),
|
| 216 |
+
)
|
| 217 |
+
futures.append(future)
|
| 218 |
+
for future in as_completed(futures):
|
| 219 |
+
future.result()
|
| 220 |
+
|
| 221 |
+
def process_frames(
|
| 222 |
+
self,
|
| 223 |
+
source_files: List[str],
|
| 224 |
+
target_files: List[str],
|
| 225 |
+
current_files,
|
| 226 |
+
update: Callable[[], None],
|
| 227 |
+
) -> None:
|
| 228 |
+
for f in current_files:
|
| 229 |
+
if not roop.globals.processing:
|
| 230 |
+
return
|
| 231 |
+
|
| 232 |
+
# Decode the byte array into an OpenCV image
|
| 233 |
+
temp_frame = cv2.imdecode(np.fromfile(f, dtype=np.uint8), cv2.IMREAD_COLOR)
|
| 234 |
+
if temp_frame is not None:
|
| 235 |
+
if self.options.frame_processing:
|
| 236 |
+
for p in self.processors:
|
| 237 |
+
frame = p.Run(temp_frame)
|
| 238 |
+
resimg = frame
|
| 239 |
+
else:
|
| 240 |
+
resimg = self.process_frame(temp_frame)
|
| 241 |
+
if resimg is not None:
|
| 242 |
+
i = source_files.index(f)
|
| 243 |
+
# Also let numpy write the file to support utf-8/16 filenames
|
| 244 |
+
cv2.imencode(f".{roop.globals.CFG.output_image_format}", resimg)[
|
| 245 |
+
1
|
| 246 |
+
].tofile(target_files[i])
|
| 247 |
+
if update:
|
| 248 |
+
update()
|
| 249 |
+
|
| 250 |
+
def read_frames_thread(self, cap, frame_start, frame_end, num_threads):
|
| 251 |
+
num_frame = 0
|
| 252 |
+
total_num = frame_end - frame_start
|
| 253 |
+
if frame_start > 0:
|
| 254 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_start)
|
| 255 |
+
|
| 256 |
+
while True and roop.globals.processing:
|
| 257 |
+
ret, frame = cap.read()
|
| 258 |
+
if not ret:
|
| 259 |
+
break
|
| 260 |
+
|
| 261 |
+
self.frames_queue[num_frame % num_threads].put(frame, block=True)
|
| 262 |
+
num_frame += 1
|
| 263 |
+
if num_frame == total_num:
|
| 264 |
+
break
|
| 265 |
+
|
| 266 |
+
for i in range(num_threads):
|
| 267 |
+
self.frames_queue[i].put(None)
|
| 268 |
+
|
| 269 |
+
def process_videoframes(self, threadindex, progress) -> None:
|
| 270 |
+
while True:
|
| 271 |
+
frame = self.frames_queue[threadindex].get()
|
| 272 |
+
if frame is None:
|
| 273 |
+
self.processing_threads -= 1
|
| 274 |
+
self.processed_queue[threadindex].put((False, None))
|
| 275 |
+
return
|
| 276 |
+
else:
|
| 277 |
+
if self.options.frame_processing:
|
| 278 |
+
for p in self.processors:
|
| 279 |
+
frame = p.Run(frame)
|
| 280 |
+
resimg = frame
|
| 281 |
+
else:
|
| 282 |
+
resimg = self.process_frame(frame)
|
| 283 |
+
self.processed_queue[threadindex].put((True, resimg))
|
| 284 |
+
del frame
|
| 285 |
+
progress()
|
| 286 |
+
|
| 287 |
+
def write_frames_thread(self):
|
| 288 |
+
nextindex = 0
|
| 289 |
+
num_producers = self.num_threads
|
| 290 |
+
|
| 291 |
+
while True:
|
| 292 |
+
process, frame = self.processed_queue[nextindex % self.num_threads].get()
|
| 293 |
+
nextindex += 1
|
| 294 |
+
if frame is not None:
|
| 295 |
+
if self.output_to_file:
|
| 296 |
+
self.videowriter.write_frame(frame)
|
| 297 |
+
if self.output_to_cam:
|
| 298 |
+
self.streamwriter.WriteToStream(frame)
|
| 299 |
+
del frame
|
| 300 |
+
elif process == False:
|
| 301 |
+
num_producers -= 1
|
| 302 |
+
if num_producers < 1:
|
| 303 |
+
return
|
| 304 |
+
|
| 305 |
+
def run_batch_inmem(
|
| 306 |
+
self,
|
| 307 |
+
output_method,
|
| 308 |
+
source_video,
|
| 309 |
+
target_video,
|
| 310 |
+
frame_start,
|
| 311 |
+
frame_end,
|
| 312 |
+
fps,
|
| 313 |
+
threads: int = 1,
|
| 314 |
+
):
|
| 315 |
+
if len(self.processors) < 1:
|
| 316 |
+
print("No processor defined!")
|
| 317 |
+
return
|
| 318 |
+
|
| 319 |
+
cap = cv2.VideoCapture(source_video)
|
| 320 |
+
# frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 321 |
+
frame_count = (frame_end - frame_start) + 1
|
| 322 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 323 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 324 |
+
|
| 325 |
+
processed_resolution = None
|
| 326 |
+
for p in self.processors:
|
| 327 |
+
if hasattr(p, "getProcessedResolution"):
|
| 328 |
+
processed_resolution = p.getProcessedResolution(width, height)
|
| 329 |
+
print(f"Processed resolution: {processed_resolution}")
|
| 330 |
+
if processed_resolution is not None:
|
| 331 |
+
width = processed_resolution[0]
|
| 332 |
+
height = processed_resolution[1]
|
| 333 |
+
|
| 334 |
+
self.total_frames = frame_count
|
| 335 |
+
self.num_threads = threads
|
| 336 |
+
|
| 337 |
+
self.processing_threads = self.num_threads
|
| 338 |
+
self.frames_queue = []
|
| 339 |
+
self.processed_queue = []
|
| 340 |
+
for _ in range(threads):
|
| 341 |
+
self.frames_queue.append(Queue(1))
|
| 342 |
+
self.processed_queue.append(Queue(1))
|
| 343 |
+
|
| 344 |
+
self.output_to_file = output_method != "Virtual Camera"
|
| 345 |
+
self.output_to_cam = (
|
| 346 |
+
output_method == "Virtual Camera" or output_method == "Both"
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
if self.output_to_file:
|
| 350 |
+
self.videowriter = FFMPEG_VideoWriter(
|
| 351 |
+
target_video,
|
| 352 |
+
(width, height),
|
| 353 |
+
fps,
|
| 354 |
+
codec=roop.globals.video_encoder,
|
| 355 |
+
crf=roop.globals.video_quality,
|
| 356 |
+
audiofile=None,
|
| 357 |
+
)
|
| 358 |
+
if self.output_to_cam:
|
| 359 |
+
self.streamwriter = StreamWriter((width, height), int(fps))
|
| 360 |
+
|
| 361 |
+
readthread = Thread(
|
| 362 |
+
target=self.read_frames_thread, args=(cap, frame_start, frame_end, threads)
|
| 363 |
+
)
|
| 364 |
+
readthread.start()
|
| 365 |
+
|
| 366 |
+
writethread = Thread(target=self.write_frames_thread)
|
| 367 |
+
writethread.start()
|
| 368 |
+
|
| 369 |
+
progress_bar_format = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
|
| 370 |
+
with tqdm(
|
| 371 |
+
total=self.total_frames,
|
| 372 |
+
desc="Processing",
|
| 373 |
+
unit="frames",
|
| 374 |
+
dynamic_ncols=True,
|
| 375 |
+
bar_format=progress_bar_format,
|
| 376 |
+
) as progress:
|
| 377 |
+
with ThreadPoolExecutor(
|
| 378 |
+
thread_name_prefix="swap_proc", max_workers=self.num_threads
|
| 379 |
+
) as executor:
|
| 380 |
+
futures = []
|
| 381 |
+
|
| 382 |
+
for threadindex in range(threads):
|
| 383 |
+
future = executor.submit(
|
| 384 |
+
self.process_videoframes,
|
| 385 |
+
threadindex,
|
| 386 |
+
lambda: self.update_progress(progress),
|
| 387 |
+
)
|
| 388 |
+
futures.append(future)
|
| 389 |
+
|
| 390 |
+
for future in as_completed(futures):
|
| 391 |
+
future.result()
|
| 392 |
+
# wait for the task to complete
|
| 393 |
+
readthread.join()
|
| 394 |
+
writethread.join()
|
| 395 |
+
cap.release()
|
| 396 |
+
if self.output_to_file:
|
| 397 |
+
self.videowriter.close()
|
| 398 |
+
if self.output_to_cam:
|
| 399 |
+
self.streamwriter.Close()
|
| 400 |
+
|
| 401 |
+
self.frames_queue.clear()
|
| 402 |
+
self.processed_queue.clear()
|
| 403 |
+
|
| 404 |
+
def update_progress(self, progress: Any = None) -> None:
|
| 405 |
+
process = psutil.Process(os.getpid())
|
| 406 |
+
memory_usage = process.memory_info().rss / 1024 / 1024 / 1024
|
| 407 |
+
progress.set_postfix(
|
| 408 |
+
{
|
| 409 |
+
"memory_usage": "{:.2f}".format(memory_usage).zfill(5) + "GB",
|
| 410 |
+
"execution_threads": self.num_threads,
|
| 411 |
+
}
|
| 412 |
+
)
|
| 413 |
+
progress.update(1)
|
| 414 |
+
if self.progress_gradio is not None:
|
| 415 |
+
self.progress_gradio(
|
| 416 |
+
(progress.n, self.total_frames),
|
| 417 |
+
desc="Processing",
|
| 418 |
+
total=self.total_frames,
|
| 419 |
+
unit="frames",
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
def process_frame(self, frame: Frame):
|
| 423 |
+
if len(self.input_face_datas) < 1 and not self.options.show_face_masking:
|
| 424 |
+
return frame
|
| 425 |
+
temp_frame = frame.copy()
|
| 426 |
+
num_swapped, temp_frame = self.swap_faces(frame, temp_frame)
|
| 427 |
+
if num_swapped > 0:
|
| 428 |
+
if roop.globals.no_face_action == eNoFaceAction.SKIP_FRAME_IF_DISSIMILAR:
|
| 429 |
+
if len(self.input_face_datas) > num_swapped:
|
| 430 |
+
return None
|
| 431 |
+
self.num_frames_no_face = 0
|
| 432 |
+
self.last_swapped_frame = temp_frame.copy()
|
| 433 |
+
return temp_frame
|
| 434 |
+
if roop.globals.no_face_action == eNoFaceAction.USE_LAST_SWAPPED:
|
| 435 |
+
if (
|
| 436 |
+
self.last_swapped_frame is not None
|
| 437 |
+
and self.num_frames_no_face < self.options.max_num_reuse_frame
|
| 438 |
+
):
|
| 439 |
+
self.num_frames_no_face += 1
|
| 440 |
+
return self.last_swapped_frame.copy()
|
| 441 |
+
return frame
|
| 442 |
+
|
| 443 |
+
elif roop.globals.no_face_action == eNoFaceAction.USE_ORIGINAL_FRAME:
|
| 444 |
+
return frame
|
| 445 |
+
if roop.globals.no_face_action == eNoFaceAction.SKIP_FRAME:
|
| 446 |
+
# This only works with in-mem processing, as it simply skips the frame.
|
| 447 |
+
# For 'extract frames' it simply leaves the unprocessed frame unprocessed and it gets used in the final output by ffmpeg.
|
| 448 |
+
# If we could delete that frame here, that'd work but that might cause ffmpeg to fail unless the frames are renamed, and I don't think we have the info on what frame it actually is?????
|
| 449 |
+
# alternatively, it could mark all the necessary frames for deletion, delete them at the end, then rename the remaining frames that might work?
|
| 450 |
+
return None
|
| 451 |
+
else:
|
| 452 |
+
return self.retry_rotated(frame)
|
| 453 |
+
|
| 454 |
+
def retry_rotated(self, frame):
|
| 455 |
+
copyframe = frame.copy()
|
| 456 |
+
copyframe = rotate_clockwise(copyframe)
|
| 457 |
+
temp_frame = copyframe.copy()
|
| 458 |
+
num_swapped, temp_frame = self.swap_faces(copyframe, temp_frame)
|
| 459 |
+
if num_swapped > 0:
|
| 460 |
+
return rotate_anticlockwise(temp_frame)
|
| 461 |
+
|
| 462 |
+
copyframe = frame.copy()
|
| 463 |
+
copyframe = rotate_anticlockwise(copyframe)
|
| 464 |
+
temp_frame = copyframe.copy()
|
| 465 |
+
num_swapped, temp_frame = self.swap_faces(copyframe, temp_frame)
|
| 466 |
+
if num_swapped > 0:
|
| 467 |
+
return rotate_clockwise(temp_frame)
|
| 468 |
+
del copyframe
|
| 469 |
+
return frame
|
| 470 |
+
|
| 471 |
+
def swap_faces(self, frame, temp_frame):
|
| 472 |
+
num_faces_found = 0
|
| 473 |
+
|
| 474 |
+
if self.options.swap_mode == "first":
|
| 475 |
+
face = get_first_face(frame)
|
| 476 |
+
|
| 477 |
+
if face is None:
|
| 478 |
+
return num_faces_found, frame
|
| 479 |
+
|
| 480 |
+
num_faces_found += 1
|
| 481 |
+
temp_frame = self.process_face(
|
| 482 |
+
self.options.selected_index, face, temp_frame
|
| 483 |
+
)
|
| 484 |
+
del face
|
| 485 |
+
|
| 486 |
+
else:
|
| 487 |
+
faces = get_all_faces(frame)
|
| 488 |
+
if faces is None:
|
| 489 |
+
return num_faces_found, frame
|
| 490 |
+
|
| 491 |
+
if self.options.swap_mode == "all":
|
| 492 |
+
for face in faces:
|
| 493 |
+
num_faces_found += 1
|
| 494 |
+
temp_frame = self.process_face(
|
| 495 |
+
self.options.selected_index, face, temp_frame
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
elif (
|
| 499 |
+
self.options.swap_mode == "all_input"
|
| 500 |
+
or self.options.swap_mode == "all_random"
|
| 501 |
+
):
|
| 502 |
+
for i, face in enumerate(faces):
|
| 503 |
+
num_faces_found += 1
|
| 504 |
+
if i < len(self.input_face_datas):
|
| 505 |
+
temp_frame = self.process_face(i, face, temp_frame)
|
| 506 |
+
else:
|
| 507 |
+
break
|
| 508 |
+
|
| 509 |
+
elif self.options.swap_mode == "selected":
|
| 510 |
+
num_targetfaces = len(self.target_face_datas)
|
| 511 |
+
use_index = num_targetfaces == 1
|
| 512 |
+
for i, tf in enumerate(self.target_face_datas):
|
| 513 |
+
for face in faces:
|
| 514 |
+
if (
|
| 515 |
+
compute_cosine_distance(tf.embedding, face.embedding)
|
| 516 |
+
<= self.options.face_distance_threshold
|
| 517 |
+
):
|
| 518 |
+
if i < len(self.input_face_datas):
|
| 519 |
+
if use_index:
|
| 520 |
+
temp_frame = self.process_face(
|
| 521 |
+
self.options.selected_index, face, temp_frame
|
| 522 |
+
)
|
| 523 |
+
else:
|
| 524 |
+
temp_frame = self.process_face(i, face, temp_frame)
|
| 525 |
+
num_faces_found += 1
|
| 526 |
+
if (
|
| 527 |
+
not roop.globals.vr_mode
|
| 528 |
+
and num_faces_found == num_targetfaces
|
| 529 |
+
):
|
| 530 |
+
break
|
| 531 |
+
elif (
|
| 532 |
+
self.options.swap_mode == "all_female"
|
| 533 |
+
or self.options.swap_mode == "all_male"
|
| 534 |
+
):
|
| 535 |
+
gender = "F" if self.options.swap_mode == "all_female" else "M"
|
| 536 |
+
for face in faces:
|
| 537 |
+
if face.sex == gender:
|
| 538 |
+
num_faces_found += 1
|
| 539 |
+
temp_frame = self.process_face(
|
| 540 |
+
self.options.selected_index, face, temp_frame
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
# might be slower but way more clean to release everything here
|
| 544 |
+
for face in faces:
|
| 545 |
+
del face
|
| 546 |
+
faces.clear()
|
| 547 |
+
|
| 548 |
+
if roop.globals.vr_mode and num_faces_found % 2 > 0:
|
| 549 |
+
# stereo image, there has to be an even number of faces
|
| 550 |
+
num_faces_found = 0
|
| 551 |
+
return num_faces_found, frame
|
| 552 |
+
if num_faces_found == 0:
|
| 553 |
+
return num_faces_found, frame
|
| 554 |
+
|
| 555 |
+
# maskprocessor = next((x for x in self.processors if x.type == 'mask'), None)
|
| 556 |
+
|
| 557 |
+
if (
|
| 558 |
+
self.options.imagemask is not None
|
| 559 |
+
and self.options.imagemask.shape == frame.shape
|
| 560 |
+
):
|
| 561 |
+
temp_frame = self.simple_blend_with_mask(
|
| 562 |
+
temp_frame, frame, self.options.imagemask
|
| 563 |
+
)
|
| 564 |
+
return num_faces_found, temp_frame
|
| 565 |
+
|
| 566 |
+
def rotation_action(self, original_face: Face, frame: Frame):
|
| 567 |
+
(height, width) = frame.shape[:2]
|
| 568 |
+
|
| 569 |
+
bounding_box_width = original_face.bbox[2] - original_face.bbox[0]
|
| 570 |
+
bounding_box_height = original_face.bbox[3] - original_face.bbox[1]
|
| 571 |
+
horizontal_face = bounding_box_width > bounding_box_height
|
| 572 |
+
|
| 573 |
+
center_x = width // 2.0
|
| 574 |
+
start_x = original_face.bbox[0]
|
| 575 |
+
end_x = original_face.bbox[2]
|
| 576 |
+
bbox_center_x = start_x + (bounding_box_width // 2.0)
|
| 577 |
+
|
| 578 |
+
# need to leverage the array of landmarks as decribed here:
|
| 579 |
+
# https://github.com/deepinsight/insightface/tree/master/alignment/coordinate_reg
|
| 580 |
+
# basically, we should be able to check for the relative position of eyes and nose
|
| 581 |
+
# then use that to determine which way the face is actually facing when in a horizontal position
|
| 582 |
+
# and use that to determine the correct rotation_action
|
| 583 |
+
|
| 584 |
+
forehead_x = original_face.landmark_2d_106[72][0]
|
| 585 |
+
chin_x = original_face.landmark_2d_106[0][0]
|
| 586 |
+
|
| 587 |
+
if horizontal_face:
|
| 588 |
+
if chin_x < forehead_x:
|
| 589 |
+
# this is someone lying down with their face like this (:
|
| 590 |
+
return "rotate_anticlockwise"
|
| 591 |
+
elif forehead_x < chin_x:
|
| 592 |
+
# this is someone lying down with their face like this :)
|
| 593 |
+
return "rotate_clockwise"
|
| 594 |
+
if bbox_center_x >= center_x:
|
| 595 |
+
# this is someone lying down with their face in the right hand side of the frame
|
| 596 |
+
return "rotate_anticlockwise"
|
| 597 |
+
if bbox_center_x < center_x:
|
| 598 |
+
# this is someone lying down with their face in the left hand side of the frame
|
| 599 |
+
return "rotate_clockwise"
|
| 600 |
+
|
| 601 |
+
return None
|
| 602 |
+
|
| 603 |
+
def auto_rotate_frame(self, original_face, frame: Frame):
|
| 604 |
+
target_face = original_face
|
| 605 |
+
original_frame = frame
|
| 606 |
+
|
| 607 |
+
rotation_action = self.rotation_action(original_face, frame)
|
| 608 |
+
|
| 609 |
+
if rotation_action == "rotate_anticlockwise":
|
| 610 |
+
# face is horizontal, rotating frame anti-clockwise and getting face bounding box from rotated frame
|
| 611 |
+
frame = rotate_anticlockwise(frame)
|
| 612 |
+
elif rotation_action == "rotate_clockwise":
|
| 613 |
+
# face is horizontal, rotating frame clockwise and getting face bounding box from rotated frame
|
| 614 |
+
frame = rotate_clockwise(frame)
|
| 615 |
+
|
| 616 |
+
return target_face, frame, rotation_action
|
| 617 |
+
|
| 618 |
+
def auto_unrotate_frame(self, frame: Frame, rotation_action):
|
| 619 |
+
if rotation_action == "rotate_anticlockwise":
|
| 620 |
+
return rotate_clockwise(frame)
|
| 621 |
+
elif rotation_action == "rotate_clockwise":
|
| 622 |
+
return rotate_anticlockwise(frame)
|
| 623 |
+
|
| 624 |
+
return frame
|
| 625 |
+
|
| 626 |
+
def process_face(self, face_index, target_face: Face, frame: Frame):
|
| 627 |
+
from roop.face_util import align_crop
|
| 628 |
+
|
| 629 |
+
enhanced_frame = None
|
| 630 |
+
if len(self.input_face_datas) > 0:
|
| 631 |
+
inputface = self.input_face_datas[face_index].faces[0]
|
| 632 |
+
else:
|
| 633 |
+
inputface = None
|
| 634 |
+
|
| 635 |
+
rotation_action = None
|
| 636 |
+
if roop.globals.autorotate_faces:
|
| 637 |
+
# check for sideways rotation of face
|
| 638 |
+
rotation_action = self.rotation_action(target_face, frame)
|
| 639 |
+
if rotation_action is not None:
|
| 640 |
+
(startX, startY, endX, endY) = target_face["bbox"].astype("int")
|
| 641 |
+
width = endX - startX
|
| 642 |
+
height = endY - startY
|
| 643 |
+
offs = int(max(width, height) * 0.25)
|
| 644 |
+
rotcutframe, startX, startY, endX, endY = self.cutout(
|
| 645 |
+
frame, startX - offs, startY - offs, endX + offs, endY + offs
|
| 646 |
+
)
|
| 647 |
+
if rotation_action == "rotate_anticlockwise":
|
| 648 |
+
rotcutframe = rotate_anticlockwise(rotcutframe)
|
| 649 |
+
elif rotation_action == "rotate_clockwise":
|
| 650 |
+
rotcutframe = rotate_clockwise(rotcutframe)
|
| 651 |
+
# rotate image and re-detect face to correct wonky landmarks
|
| 652 |
+
rotface = get_first_face(rotcutframe)
|
| 653 |
+
if rotface is None:
|
| 654 |
+
rotation_action = None
|
| 655 |
+
else:
|
| 656 |
+
saved_frame = frame.copy()
|
| 657 |
+
frame = rotcutframe
|
| 658 |
+
target_face = rotface
|
| 659 |
+
|
| 660 |
+
# if roop.globals.vr_mode:
|
| 661 |
+
# bbox = target_face.bbox
|
| 662 |
+
# [orig_width, orig_height, _] = frame.shape
|
| 663 |
+
|
| 664 |
+
# # Convert bounding box to ints
|
| 665 |
+
# x1, y1, x2, y2 = map(int, bbox)
|
| 666 |
+
|
| 667 |
+
# # Determine the center of the bounding box
|
| 668 |
+
# x_center = (x1 + x2) / 2
|
| 669 |
+
# y_center = (y1 + y2) / 2
|
| 670 |
+
|
| 671 |
+
# # Normalize coordinates to range [-1, 1]
|
| 672 |
+
# x_center_normalized = x_center / (orig_width / 2) - 1
|
| 673 |
+
# y_center_normalized = y_center / (orig_width / 2) - 1
|
| 674 |
+
|
| 675 |
+
# # Convert normalized coordinates to spherical (theta, phi)
|
| 676 |
+
# theta = x_center_normalized * 180 # Theta ranges from -180 to 180 degrees
|
| 677 |
+
# phi = -y_center_normalized * 90 # Phi ranges from -90 to 90 degrees
|
| 678 |
+
|
| 679 |
+
# img = vr.GetPerspective(frame, 90, theta, phi, 1280, 1280) # Generate perspective image
|
| 680 |
+
|
| 681 |
+
""" Code ported/adapted from Facefusion which borrowed the idea from Rope:
|
| 682 |
+
Kind of subsampling the cutout and aligned face image and faceswapping slices of it up to
|
| 683 |
+
the desired output resolution. This works around the current resolution limitations without using enhancers.
|
| 684 |
+
"""
|
| 685 |
+
model_output_size = self.options.swap_output_size
|
| 686 |
+
subsample_size = max(self.options.subsample_size, model_output_size)
|
| 687 |
+
subsample_total = subsample_size // model_output_size
|
| 688 |
+
aligned_img, M = align_crop(frame, target_face.kps, subsample_size)
|
| 689 |
+
|
| 690 |
+
fake_frame = aligned_img
|
| 691 |
+
target_face.matrix = M
|
| 692 |
+
|
| 693 |
+
for p in self.processors:
|
| 694 |
+
if p.type == "swap":
|
| 695 |
+
swap_result_frames = []
|
| 696 |
+
subsample_frames = self.implode_pixel_boost(
|
| 697 |
+
aligned_img, model_output_size, subsample_total
|
| 698 |
+
)
|
| 699 |
+
for sliced_frame in subsample_frames:
|
| 700 |
+
for _ in range(0, self.options.num_swap_steps):
|
| 701 |
+
sliced_frame = self.prepare_crop_frame(sliced_frame)
|
| 702 |
+
sliced_frame = p.Run(inputface, target_face, sliced_frame)
|
| 703 |
+
sliced_frame = self.normalize_swap_frame(sliced_frame)
|
| 704 |
+
swap_result_frames.append(sliced_frame)
|
| 705 |
+
fake_frame = self.explode_pixel_boost(
|
| 706 |
+
swap_result_frames,
|
| 707 |
+
model_output_size,
|
| 708 |
+
subsample_total,
|
| 709 |
+
subsample_size,
|
| 710 |
+
)
|
| 711 |
+
fake_frame = fake_frame.astype(np.uint8)
|
| 712 |
+
scale_factor = 0.0
|
| 713 |
+
elif p.type == "mask":
|
| 714 |
+
fake_frame = self.process_mask(p, aligned_img, fake_frame)
|
| 715 |
+
else:
|
| 716 |
+
enhanced_frame, scale_factor = p.Run(
|
| 717 |
+
self.input_face_datas[face_index], target_face, fake_frame
|
| 718 |
+
)
|
| 719 |
+
|
| 720 |
+
upscale = 512
|
| 721 |
+
orig_width = fake_frame.shape[1]
|
| 722 |
+
if orig_width != upscale:
|
| 723 |
+
fake_frame = cv2.resize(fake_frame, (upscale, upscale), cv2.INTER_CUBIC)
|
| 724 |
+
mask_offsets = (
|
| 725 |
+
(0, 0, 0, 0, 1, 20) if inputface is None else inputface.mask_offsets
|
| 726 |
+
)
|
| 727 |
+
|
| 728 |
+
if enhanced_frame is None:
|
| 729 |
+
scale_factor = int(upscale / orig_width)
|
| 730 |
+
result = self.paste_upscale(
|
| 731 |
+
fake_frame,
|
| 732 |
+
fake_frame,
|
| 733 |
+
target_face.matrix,
|
| 734 |
+
frame,
|
| 735 |
+
scale_factor,
|
| 736 |
+
mask_offsets,
|
| 737 |
+
)
|
| 738 |
+
else:
|
| 739 |
+
result = self.paste_upscale(
|
| 740 |
+
fake_frame,
|
| 741 |
+
enhanced_frame,
|
| 742 |
+
target_face.matrix,
|
| 743 |
+
frame,
|
| 744 |
+
scale_factor,
|
| 745 |
+
mask_offsets,
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
# Restore mouth before unrotating
|
| 749 |
+
if self.options.restore_original_mouth:
|
| 750 |
+
mouth_cutout, mouth_bb = self.create_mouth_mask(target_face, frame)
|
| 751 |
+
result = self.apply_mouth_area(result, mouth_cutout, mouth_bb)
|
| 752 |
+
|
| 753 |
+
if rotation_action is not None:
|
| 754 |
+
fake_frame = self.auto_unrotate_frame(result, rotation_action)
|
| 755 |
+
result = self.paste_simple(fake_frame, saved_frame, startX, startY)
|
| 756 |
+
|
| 757 |
+
return result
|
| 758 |
+
|
| 759 |
+
def cutout(self, frame: Frame, start_x, start_y, end_x, end_y):
|
| 760 |
+
if start_x < 0:
|
| 761 |
+
start_x = 0
|
| 762 |
+
if start_y < 0:
|
| 763 |
+
start_y = 0
|
| 764 |
+
if end_x > frame.shape[1]:
|
| 765 |
+
end_x = frame.shape[1]
|
| 766 |
+
if end_y > frame.shape[0]:
|
| 767 |
+
end_y = frame.shape[0]
|
| 768 |
+
return frame[start_y:end_y, start_x:end_x], start_x, start_y, end_x, end_y
|
| 769 |
+
|
| 770 |
+
def paste_simple(self, src: Frame, dest: Frame, start_x, start_y):
|
| 771 |
+
end_x = start_x + src.shape[1]
|
| 772 |
+
end_y = start_y + src.shape[0]
|
| 773 |
+
|
| 774 |
+
start_x, end_x, start_y, end_y = clamp_cut_values(
|
| 775 |
+
start_x, end_x, start_y, end_y, dest
|
| 776 |
+
)
|
| 777 |
+
dest[start_y:end_y, start_x:end_x] = src
|
| 778 |
+
return dest
|
| 779 |
+
|
| 780 |
+
def simple_blend_with_mask(self, image1, image2, mask):
|
| 781 |
+
# Blend the images
|
| 782 |
+
blended_image = (
|
| 783 |
+
image1.astype(np.float32) * (1.0 - mask) + image2.astype(np.float32) * mask
|
| 784 |
+
)
|
| 785 |
+
return blended_image.astype(np.uint8)
|
| 786 |
+
|
| 787 |
+
def paste_upscale(
|
| 788 |
+
self, fake_face, upsk_face, M, target_img, scale_factor, mask_offsets
|
| 789 |
+
):
|
| 790 |
+
M_scale = M * scale_factor
|
| 791 |
+
IM = cv2.invertAffineTransform(M_scale)
|
| 792 |
+
|
| 793 |
+
face_matte = np.full(
|
| 794 |
+
(target_img.shape[0], target_img.shape[1]), 255, dtype=np.uint8
|
| 795 |
+
)
|
| 796 |
+
# Generate white square sized as a upsk_face
|
| 797 |
+
img_matte = np.zeros((upsk_face.shape[0], upsk_face.shape[1]), dtype=np.uint8)
|
| 798 |
+
|
| 799 |
+
w = img_matte.shape[1]
|
| 800 |
+
h = img_matte.shape[0]
|
| 801 |
+
|
| 802 |
+
top = int(mask_offsets[0] * h)
|
| 803 |
+
bottom = int(h - (mask_offsets[1] * h))
|
| 804 |
+
left = int(mask_offsets[2] * w)
|
| 805 |
+
right = int(w - (mask_offsets[3] * w))
|
| 806 |
+
img_matte[top:bottom, left:right] = 255
|
| 807 |
+
|
| 808 |
+
# Transform white square back to target_img
|
| 809 |
+
img_matte = cv2.warpAffine(
|
| 810 |
+
img_matte,
|
| 811 |
+
IM,
|
| 812 |
+
(target_img.shape[1], target_img.shape[0]),
|
| 813 |
+
flags=cv2.INTER_NEAREST,
|
| 814 |
+
borderValue=0.0,
|
| 815 |
+
)
|
| 816 |
+
##Blacken the edges of face_matte by 1 pixels (so the mask in not expanded on the image edges)
|
| 817 |
+
img_matte[:1, :] = img_matte[-1:, :] = img_matte[:, :1] = img_matte[:, -1:] = 0
|
| 818 |
+
|
| 819 |
+
img_matte = self.blur_area(img_matte, mask_offsets[4], mask_offsets[5])
|
| 820 |
+
# Normalize images to float values and reshape
|
| 821 |
+
img_matte = img_matte.astype(np.float32) / 255
|
| 822 |
+
face_matte = face_matte.astype(np.float32) / 255
|
| 823 |
+
img_matte = np.minimum(face_matte, img_matte)
|
| 824 |
+
if self.options.show_face_area_overlay:
|
| 825 |
+
# Additional steps for green overlay
|
| 826 |
+
green_overlay = np.zeros_like(target_img)
|
| 827 |
+
green_color = [0, 255, 0] # RGB for green
|
| 828 |
+
for i in range(3): # Apply green color where img_matte is not zero
|
| 829 |
+
green_overlay[:, :, i] = np.where(
|
| 830 |
+
img_matte > 0, green_color[i], 0
|
| 831 |
+
) ##Transform upcaled face back to target_img
|
| 832 |
+
img_matte = np.reshape(img_matte, [img_matte.shape[0], img_matte.shape[1], 1])
|
| 833 |
+
paste_face = cv2.warpAffine(
|
| 834 |
+
upsk_face,
|
| 835 |
+
IM,
|
| 836 |
+
(target_img.shape[1], target_img.shape[0]),
|
| 837 |
+
borderMode=cv2.BORDER_REPLICATE,
|
| 838 |
+
)
|
| 839 |
+
if upsk_face is not fake_face:
|
| 840 |
+
fake_face = cv2.warpAffine(
|
| 841 |
+
fake_face,
|
| 842 |
+
IM,
|
| 843 |
+
(target_img.shape[1], target_img.shape[0]),
|
| 844 |
+
borderMode=cv2.BORDER_REPLICATE,
|
| 845 |
+
)
|
| 846 |
+
paste_face = cv2.addWeighted(
|
| 847 |
+
paste_face,
|
| 848 |
+
self.options.blend_ratio,
|
| 849 |
+
fake_face,
|
| 850 |
+
1.0 - self.options.blend_ratio,
|
| 851 |
+
0,
|
| 852 |
+
)
|
| 853 |
+
|
| 854 |
+
# Re-assemble image
|
| 855 |
+
paste_face = img_matte * paste_face
|
| 856 |
+
paste_face = paste_face + (1 - img_matte) * target_img.astype(np.float32)
|
| 857 |
+
if self.options.show_face_area_overlay:
|
| 858 |
+
# Overlay the green overlay on the final image
|
| 859 |
+
paste_face = cv2.addWeighted(
|
| 860 |
+
paste_face.astype(np.uint8), 1 - 0.5, green_overlay, 0.5, 0
|
| 861 |
+
)
|
| 862 |
+
return paste_face.astype(np.uint8)
|
| 863 |
+
|
| 864 |
+
def blur_area(self, img_matte, num_erosion_iterations, blur_amount):
|
| 865 |
+
# Detect the affine transformed white area
|
| 866 |
+
mask_h_inds, mask_w_inds = np.where(img_matte == 255)
|
| 867 |
+
# Calculate the size (and diagonal size) of transformed white area width and height boundaries
|
| 868 |
+
mask_h = np.max(mask_h_inds) - np.min(mask_h_inds)
|
| 869 |
+
mask_w = np.max(mask_w_inds) - np.min(mask_w_inds)
|
| 870 |
+
mask_size = int(np.sqrt(mask_h * mask_w))
|
| 871 |
+
# Calculate the kernel size for eroding img_matte by kernel (insightface empirical guess for best size was max(mask_size//10,10))
|
| 872 |
+
# k = max(mask_size//12, 8)
|
| 873 |
+
k = max(mask_size // (blur_amount // 2), blur_amount // 2)
|
| 874 |
+
kernel = np.ones((k, k), np.uint8)
|
| 875 |
+
img_matte = cv2.erode(img_matte, kernel, iterations=num_erosion_iterations)
|
| 876 |
+
# Calculate the kernel size for blurring img_matte by blur_size (insightface empirical guess for best size was max(mask_size//20, 5))
|
| 877 |
+
# k = max(mask_size//24, 4)
|
| 878 |
+
k = max(mask_size // blur_amount, blur_amount // 5)
|
| 879 |
+
kernel_size = (k, k)
|
| 880 |
+
blur_size = tuple(2 * i + 1 for i in kernel_size)
|
| 881 |
+
return cv2.GaussianBlur(img_matte, blur_size, 0)
|
| 882 |
+
|
| 883 |
+
def prepare_crop_frame(self, swap_frame):
|
| 884 |
+
model_type = "inswapper"
|
| 885 |
+
model_mean = [0.0, 0.0, 0.0]
|
| 886 |
+
model_standard_deviation = [1.0, 1.0, 1.0]
|
| 887 |
+
|
| 888 |
+
if model_type == "ghost":
|
| 889 |
+
swap_frame = swap_frame[:, :, ::-1] / 127.5 - 1
|
| 890 |
+
else:
|
| 891 |
+
swap_frame = swap_frame[:, :, ::-1] / 255.0
|
| 892 |
+
swap_frame = (swap_frame - model_mean) / model_standard_deviation
|
| 893 |
+
swap_frame = swap_frame.transpose(2, 0, 1)
|
| 894 |
+
swap_frame = np.expand_dims(swap_frame, axis=0).astype(np.float32)
|
| 895 |
+
return swap_frame
|
| 896 |
+
|
| 897 |
+
def normalize_swap_frame(self, swap_frame):
|
| 898 |
+
model_type = "inswapper"
|
| 899 |
+
swap_frame = swap_frame.transpose(1, 2, 0)
|
| 900 |
+
|
| 901 |
+
if model_type == "ghost":
|
| 902 |
+
swap_frame = (swap_frame * 127.5 + 127.5).round()
|
| 903 |
+
else:
|
| 904 |
+
swap_frame = (swap_frame * 255.0).round()
|
| 905 |
+
swap_frame = swap_frame[:, :, ::-1]
|
| 906 |
+
return swap_frame
|
| 907 |
+
|
| 908 |
+
def implode_pixel_boost(
|
| 909 |
+
self, aligned_face_frame, model_size, pixel_boost_total: int
|
| 910 |
+
):
|
| 911 |
+
subsample_frame = aligned_face_frame.reshape(
|
| 912 |
+
model_size, pixel_boost_total, model_size, pixel_boost_total, 3
|
| 913 |
+
)
|
| 914 |
+
subsample_frame = subsample_frame.transpose(1, 3, 0, 2, 4).reshape(
|
| 915 |
+
pixel_boost_total**2, model_size, model_size, 3
|
| 916 |
+
)
|
| 917 |
+
return subsample_frame
|
| 918 |
+
|
| 919 |
+
def explode_pixel_boost(
|
| 920 |
+
self, subsample_frame, model_size, pixel_boost_total, pixel_boost_size
|
| 921 |
+
):
|
| 922 |
+
final_frame = np.stack(subsample_frame, axis=0).reshape(
|
| 923 |
+
pixel_boost_total, pixel_boost_total, model_size, model_size, 3
|
| 924 |
+
)
|
| 925 |
+
final_frame = final_frame.transpose(2, 0, 3, 1, 4).reshape(
|
| 926 |
+
pixel_boost_size, pixel_boost_size, 3
|
| 927 |
+
)
|
| 928 |
+
return final_frame
|
| 929 |
+
|
| 930 |
+
def process_mask(self, processor, frame: Frame, target: Frame):
|
| 931 |
+
img_mask = processor.Run(frame, self.options.masking_text)
|
| 932 |
+
img_mask = cv2.resize(img_mask, (target.shape[1], target.shape[0]))
|
| 933 |
+
img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1])
|
| 934 |
+
|
| 935 |
+
if self.options.show_face_masking:
|
| 936 |
+
result = (1 - img_mask) * frame.astype(np.float32)
|
| 937 |
+
return np.uint8(result)
|
| 938 |
+
|
| 939 |
+
target = target.astype(np.float32)
|
| 940 |
+
result = (1 - img_mask) * target
|
| 941 |
+
result += img_mask * frame.astype(np.float32)
|
| 942 |
+
return np.uint8(result)
|
| 943 |
+
|
| 944 |
+
# Code for mouth restoration adapted from https://github.com/iVideoGameBoss/iRoopDeepFaceCam
|
| 945 |
+
|
| 946 |
+
def create_mouth_mask(self, face: Face, frame: Frame):
|
| 947 |
+
mouth_cutout = None
|
| 948 |
+
|
| 949 |
+
landmarks = face.landmark_2d_106
|
| 950 |
+
if landmarks is not None:
|
| 951 |
+
# Get mouth landmarks (indices 52 to 71 typically represent the outer mouth)
|
| 952 |
+
mouth_points = landmarks[52:71].astype(np.int32)
|
| 953 |
+
|
| 954 |
+
# Add padding to mouth area
|
| 955 |
+
min_x, min_y = np.min(mouth_points, axis=0)
|
| 956 |
+
max_x, max_y = np.max(mouth_points, axis=0)
|
| 957 |
+
min_x = max(0, min_x - (15 * 6))
|
| 958 |
+
min_y = max(0, min_y - 22)
|
| 959 |
+
max_x = min(frame.shape[1], max_x + (15 * 6))
|
| 960 |
+
max_y = min(frame.shape[0], max_y + (90 * 6))
|
| 961 |
+
|
| 962 |
+
# Extract the mouth area from the frame using the calculated bounding box
|
| 963 |
+
mouth_cutout = frame[min_y:max_y, min_x:max_x].copy()
|
| 964 |
+
|
| 965 |
+
return mouth_cutout, (min_x, min_y, max_x, max_y)
|
| 966 |
+
|
| 967 |
+
def create_feathered_mask(self, shape, feather_amount=30):
|
| 968 |
+
mask = np.zeros(shape[:2], dtype=np.float32)
|
| 969 |
+
center = (shape[1] // 2, shape[0] // 2)
|
| 970 |
+
cv2.ellipse(
|
| 971 |
+
mask,
|
| 972 |
+
center,
|
| 973 |
+
(shape[1] // 2 - feather_amount, shape[0] // 2 - feather_amount),
|
| 974 |
+
0,
|
| 975 |
+
0,
|
| 976 |
+
360,
|
| 977 |
+
1,
|
| 978 |
+
-1,
|
| 979 |
+
)
|
| 980 |
+
mask = cv2.GaussianBlur(
|
| 981 |
+
mask, (feather_amount * 2 + 1, feather_amount * 2 + 1), 0
|
| 982 |
+
)
|
| 983 |
+
return mask / np.max(mask)
|
| 984 |
+
|
| 985 |
+
def apply_mouth_area(
|
| 986 |
+
self, frame: np.ndarray, mouth_cutout: np.ndarray, mouth_box: tuple
|
| 987 |
+
) -> np.ndarray:
|
| 988 |
+
min_x, min_y, max_x, max_y = mouth_box
|
| 989 |
+
box_width = max_x - min_x
|
| 990 |
+
box_height = max_y - min_y
|
| 991 |
+
|
| 992 |
+
# Resize the mouth cutout to match the mouth box size
|
| 993 |
+
if mouth_cutout is None or box_width is None or box_height is None:
|
| 994 |
+
return frame
|
| 995 |
+
try:
|
| 996 |
+
resized_mouth_cutout = cv2.resize(mouth_cutout, (box_width, box_height))
|
| 997 |
+
|
| 998 |
+
# Extract the region of interest (ROI) from the target frame
|
| 999 |
+
roi = frame[min_y:max_y, min_x:max_x]
|
| 1000 |
+
|
| 1001 |
+
# Ensure the ROI and resized_mouth_cutout have the same shape
|
| 1002 |
+
if roi.shape != resized_mouth_cutout.shape:
|
| 1003 |
+
resized_mouth_cutout = cv2.resize(
|
| 1004 |
+
resized_mouth_cutout, (roi.shape[1], roi.shape[0])
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
# Apply color transfer from ROI to mouth cutout
|
| 1008 |
+
color_corrected_mouth = self.apply_color_transfer(resized_mouth_cutout, roi)
|
| 1009 |
+
|
| 1010 |
+
# Create a feathered mask with increased feather amount
|
| 1011 |
+
feather_amount = min(30, box_width // 15, box_height // 15)
|
| 1012 |
+
mask = self.create_feathered_mask(
|
| 1013 |
+
resized_mouth_cutout.shape, feather_amount
|
| 1014 |
+
)
|
| 1015 |
+
|
| 1016 |
+
# Blend the color-corrected mouth cutout with the ROI using the feathered mask
|
| 1017 |
+
mask = mask[:, :, np.newaxis] # Add channel dimension to mask
|
| 1018 |
+
blended = (color_corrected_mouth * mask + roi * (1 - mask)).astype(np.uint8)
|
| 1019 |
+
|
| 1020 |
+
# Place the blended result back into the frame
|
| 1021 |
+
frame[min_y:max_y, min_x:max_x] = blended
|
| 1022 |
+
except Exception as e:
|
| 1023 |
+
print(f"Error {e}")
|
| 1024 |
+
pass
|
| 1025 |
+
|
| 1026 |
+
return frame
|
| 1027 |
+
|
| 1028 |
+
def apply_color_transfer(self, source, target):
|
| 1029 |
+
"""
|
| 1030 |
+
Apply color transfer from target to source image
|
| 1031 |
+
"""
|
| 1032 |
+
source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype("float32")
|
| 1033 |
+
target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype("float32")
|
| 1034 |
+
|
| 1035 |
+
source_mean, source_std = cv2.meanStdDev(source)
|
| 1036 |
+
target_mean, target_std = cv2.meanStdDev(target)
|
| 1037 |
+
|
| 1038 |
+
# Reshape mean and std to be broadcastable
|
| 1039 |
+
source_mean = source_mean.reshape(1, 1, 3)
|
| 1040 |
+
source_std = source_std.reshape(1, 1, 3)
|
| 1041 |
+
target_mean = target_mean.reshape(1, 1, 3)
|
| 1042 |
+
target_std = target_std.reshape(1, 1, 3)
|
| 1043 |
+
|
| 1044 |
+
# Perform the color transfer
|
| 1045 |
+
source = (source - source_mean) * (target_std / source_std) + target_mean
|
| 1046 |
+
return cv2.cvtColor(np.clip(source, 0, 255).astype("uint8"), cv2.COLOR_LAB2BGR)
|
| 1047 |
+
|
| 1048 |
+
def unload_models():
|
| 1049 |
+
pass
|
| 1050 |
+
|
| 1051 |
+
def release_resources(self):
|
| 1052 |
+
for p in self.processors:
|
| 1053 |
+
p.Release()
|
| 1054 |
+
self.processors.clear()
|
| 1055 |
+
if self.videowriter is not None:
|
| 1056 |
+
self.videowriter.close()
|
| 1057 |
+
if self.streamwriter is not None:
|
| 1058 |
+
self.streamwriter.Close()
|
ProcessOptions.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class ProcessOptions:
|
| 2 |
+
def __init__(
|
| 3 |
+
self,
|
| 4 |
+
swap_model,
|
| 5 |
+
processordefines: dict,
|
| 6 |
+
face_distance,
|
| 7 |
+
blend_ratio,
|
| 8 |
+
swap_mode,
|
| 9 |
+
selected_index,
|
| 10 |
+
masking_text,
|
| 11 |
+
imagemask,
|
| 12 |
+
num_steps,
|
| 13 |
+
subsample_size,
|
| 14 |
+
show_face_area,
|
| 15 |
+
restore_original_mouth,
|
| 16 |
+
show_mask=False,
|
| 17 |
+
):
|
| 18 |
+
if swap_model is not None:
|
| 19 |
+
self.swap_modelname = swap_model
|
| 20 |
+
self.swap_output_size = int(swap_model.split()[-1])
|
| 21 |
+
else:
|
| 22 |
+
self.swap_output_size = 128
|
| 23 |
+
self.processors = processordefines
|
| 24 |
+
self.face_distance_threshold = face_distance
|
| 25 |
+
self.blend_ratio = blend_ratio
|
| 26 |
+
self.swap_mode = swap_mode
|
| 27 |
+
self.selected_index = selected_index
|
| 28 |
+
self.masking_text = masking_text
|
| 29 |
+
self.imagemask = imagemask
|
| 30 |
+
self.num_swap_steps = num_steps
|
| 31 |
+
self.show_face_area_overlay = show_face_area
|
| 32 |
+
self.show_face_masking = show_mask
|
| 33 |
+
self.subsample_size = subsample_size
|
| 34 |
+
self.restore_original_mouth = restore_original_mouth
|
| 35 |
+
self.max_num_reuse_frame = 15
|
StreamWriter.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
import time
|
| 3 |
+
import pyvirtualcam
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class StreamWriter:
|
| 7 |
+
FPS = 30
|
| 8 |
+
VCam = None
|
| 9 |
+
Active = False
|
| 10 |
+
THREAD_LOCK_STREAM = threading.Lock()
|
| 11 |
+
time_last_process = None
|
| 12 |
+
timespan_min = 0.0
|
| 13 |
+
|
| 14 |
+
def __enter__(self):
|
| 15 |
+
return self
|
| 16 |
+
|
| 17 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
| 18 |
+
self.Close()
|
| 19 |
+
|
| 20 |
+
def __init__(self, size, fps):
|
| 21 |
+
self.time_last_process = time.perf_counter()
|
| 22 |
+
self.FPS = fps
|
| 23 |
+
self.timespan_min = 1.0 / fps
|
| 24 |
+
print("Detecting virtual cam devices")
|
| 25 |
+
self.VCam = pyvirtualcam.Camera(
|
| 26 |
+
width=size[0],
|
| 27 |
+
height=size[1],
|
| 28 |
+
fps=fps,
|
| 29 |
+
fmt=pyvirtualcam.PixelFormat.BGR,
|
| 30 |
+
print_fps=False,
|
| 31 |
+
)
|
| 32 |
+
if self.VCam is None:
|
| 33 |
+
print("No virtual camera found!")
|
| 34 |
+
return
|
| 35 |
+
print(f"Using virtual camera: {self.VCam.device}")
|
| 36 |
+
print(f"Using {self.VCam.native_fmt}")
|
| 37 |
+
self.Active = True
|
| 38 |
+
|
| 39 |
+
def LimitFrames(self):
|
| 40 |
+
while True:
|
| 41 |
+
current_time = time.perf_counter()
|
| 42 |
+
time_passed = current_time - self.time_last_process
|
| 43 |
+
if time_passed >= self.timespan_min:
|
| 44 |
+
break
|
| 45 |
+
|
| 46 |
+
# First version used a queue and threading. Surprisingly this
|
| 47 |
+
# totally simple, blocking version is 10 times faster!
|
| 48 |
+
def WriteToStream(self, frame):
|
| 49 |
+
if self.VCam is None:
|
| 50 |
+
return
|
| 51 |
+
with self.THREAD_LOCK_STREAM:
|
| 52 |
+
self.LimitFrames()
|
| 53 |
+
self.VCam.send(frame)
|
| 54 |
+
self.time_last_process = time.perf_counter()
|
| 55 |
+
|
| 56 |
+
def Close(self):
|
| 57 |
+
self.Active = False
|
| 58 |
+
if self.VCam is None:
|
| 59 |
+
self.VCam.close()
|
| 60 |
+
self.VCam = None
|
__init__.py
ADDED
|
File without changes
|
__pycache__/FaceSet.cpython-310.pyc
ADDED
|
Binary file (1.09 kB). View file
|
|
|
__pycache__/ProcessEntry.cpython-310.pyc
ADDED
|
Binary file (662 Bytes). View file
|
|
|
__pycache__/ProcessMgr.cpython-310.pyc
ADDED
|
Binary file (22.7 kB). View file
|
|
|
__pycache__/ProcessOptions.cpython-310.pyc
ADDED
|
Binary file (1.09 kB). View file
|
|
|
__pycache__/StreamWriter.cpython-310.pyc
ADDED
|
Binary file (2.08 kB). View file
|
|
|
__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (220 Bytes). View file
|
|
|
__pycache__/capturer.cpython-310.pyc
ADDED
|
Binary file (1.48 kB). View file
|
|
|
__pycache__/core.cpython-310.pyc
ADDED
|
Binary file (15.2 kB). View file
|
|
|
__pycache__/face_util.cpython-310.pyc
ADDED
|
Binary file (8.15 kB). View file
|
|
|
__pycache__/ffmpeg_writer.cpython-310.pyc
ADDED
|
Binary file (5.76 kB). View file
|
|
|
__pycache__/globals.cpython-310.pyc
ADDED
|
Binary file (1.33 kB). View file
|
|
|
__pycache__/metadata.cpython-310.pyc
ADDED
|
Binary file (265 Bytes). View file
|
|
|
__pycache__/template_parser.cpython-310.pyc
ADDED
|
Binary file (1.18 kB). View file
|
|
|
__pycache__/typing.cpython-310.pyc
ADDED
|
Binary file (408 Bytes). View file
|
|
|
__pycache__/util_ffmpeg.cpython-310.pyc
ADDED
|
Binary file (5.05 kB). View file
|
|
|
__pycache__/utilities.cpython-310.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
__pycache__/vr_util.cpython-310.pyc
ADDED
|
Binary file (1.51 kB). View file
|
|
|
capturer.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from roop.typing import Frame
|
| 6 |
+
|
| 7 |
+
current_video_path = None
|
| 8 |
+
current_frame_total = 0
|
| 9 |
+
current_capture = None
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_image_frame(filename: str):
|
| 13 |
+
try:
|
| 14 |
+
return cv2.imdecode(np.fromfile(filename, dtype=np.uint8), cv2.IMREAD_COLOR)
|
| 15 |
+
except:
|
| 16 |
+
print(f"Exception reading {filename}")
|
| 17 |
+
return None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_video_frame(video_path: str, frame_number: int = 0) -> Optional[Frame]:
|
| 21 |
+
global current_video_path, current_capture, current_frame_total
|
| 22 |
+
|
| 23 |
+
if video_path != current_video_path:
|
| 24 |
+
release_video()
|
| 25 |
+
current_capture = cv2.VideoCapture(video_path)
|
| 26 |
+
current_video_path = video_path
|
| 27 |
+
current_frame_total = current_capture.get(cv2.CAP_PROP_FRAME_COUNT)
|
| 28 |
+
|
| 29 |
+
current_capture.set(
|
| 30 |
+
cv2.CAP_PROP_POS_FRAMES, min(current_frame_total, frame_number - 1)
|
| 31 |
+
)
|
| 32 |
+
has_frame, frame = current_capture.read()
|
| 33 |
+
if has_frame:
|
| 34 |
+
return frame
|
| 35 |
+
return None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def release_video():
|
| 39 |
+
global current_capture
|
| 40 |
+
|
| 41 |
+
if current_capture is not None:
|
| 42 |
+
current_capture.release()
|
| 43 |
+
current_capture = None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_video_frame_total(video_path: str) -> int:
|
| 47 |
+
capture = cv2.VideoCapture(video_path)
|
| 48 |
+
video_frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 49 |
+
capture.release()
|
| 50 |
+
return video_frame_total
|
core.py
ADDED
|
@@ -0,0 +1,605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import shutil
|
| 6 |
+
|
| 7 |
+
# single thread doubles cuda performance - needs to be set before torch import
|
| 8 |
+
if any(arg.startswith("--execution-provider") for arg in sys.argv):
|
| 9 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
| 10 |
+
|
| 11 |
+
import warnings
|
| 12 |
+
from typing import List
|
| 13 |
+
import platform
|
| 14 |
+
import signal
|
| 15 |
+
import torch
|
| 16 |
+
import onnxruntime
|
| 17 |
+
import pathlib
|
| 18 |
+
import argparse
|
| 19 |
+
|
| 20 |
+
from time import time
|
| 21 |
+
|
| 22 |
+
import roop.globals
|
| 23 |
+
import roop.metadata
|
| 24 |
+
import roop.utilities as util
|
| 25 |
+
import roop.util_ffmpeg as ffmpeg
|
| 26 |
+
import ui.main as main
|
| 27 |
+
from settings import Settings
|
| 28 |
+
from roop.face_util import extract_face_images
|
| 29 |
+
from roop.ProcessEntry import ProcessEntry
|
| 30 |
+
from roop.ProcessMgr import ProcessMgr
|
| 31 |
+
from roop.ProcessOptions import ProcessOptions
|
| 32 |
+
from roop.capturer import get_video_frame_total, release_video
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
clip_text = None
|
| 36 |
+
|
| 37 |
+
call_display_ui = None
|
| 38 |
+
|
| 39 |
+
process_mgr = None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
if "ROCMExecutionProvider" in roop.globals.execution_providers:
|
| 43 |
+
del torch
|
| 44 |
+
|
| 45 |
+
warnings.filterwarnings("ignore", category=FutureWarning, module="insightface")
|
| 46 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def parse_args() -> None:
|
| 50 |
+
signal.signal(signal.SIGINT, lambda signal_number, frame: destroy())
|
| 51 |
+
roop.globals.headless = False
|
| 52 |
+
|
| 53 |
+
program = argparse.ArgumentParser(
|
| 54 |
+
formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=100)
|
| 55 |
+
)
|
| 56 |
+
program.add_argument(
|
| 57 |
+
"--server_share",
|
| 58 |
+
help="Public server",
|
| 59 |
+
dest="server_share",
|
| 60 |
+
action="store_true",
|
| 61 |
+
default=False,
|
| 62 |
+
)
|
| 63 |
+
program.add_argument(
|
| 64 |
+
"--cuda_device_id",
|
| 65 |
+
help="Index of the cuda gpu to use",
|
| 66 |
+
dest="cuda_device_id",
|
| 67 |
+
type=int,
|
| 68 |
+
default=0,
|
| 69 |
+
)
|
| 70 |
+
roop.globals.startup_args = program.parse_args()
|
| 71 |
+
# Always enable all processors when using GUI
|
| 72 |
+
roop.globals.frame_processors = ["face_swapper", "face_enhancer"]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def encode_execution_providers(execution_providers: List[str]) -> List[str]:
|
| 76 |
+
return [
|
| 77 |
+
execution_provider.replace("ExecutionProvider", "").lower()
|
| 78 |
+
for execution_provider in execution_providers
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def decode_execution_providers(execution_providers: List[str]) -> List[str]:
|
| 83 |
+
list_providers = [
|
| 84 |
+
provider
|
| 85 |
+
for provider, encoded_execution_provider in zip(
|
| 86 |
+
onnxruntime.get_available_providers(),
|
| 87 |
+
encode_execution_providers(onnxruntime.get_available_providers()),
|
| 88 |
+
)
|
| 89 |
+
if any(
|
| 90 |
+
execution_provider in encoded_execution_provider
|
| 91 |
+
for execution_provider in execution_providers
|
| 92 |
+
)
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
try:
|
| 96 |
+
for i in range(len(list_providers)):
|
| 97 |
+
if list_providers[i] == "CUDAExecutionProvider":
|
| 98 |
+
list_providers[i] = (
|
| 99 |
+
"CUDAExecutionProvider",
|
| 100 |
+
{"device_id": roop.globals.cuda_device_id},
|
| 101 |
+
)
|
| 102 |
+
torch.cuda.set_device(roop.globals.cuda_device_id)
|
| 103 |
+
break
|
| 104 |
+
except:
|
| 105 |
+
pass
|
| 106 |
+
|
| 107 |
+
return list_providers
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def suggest_max_memory() -> int:
|
| 111 |
+
if platform.system().lower() == "darwin":
|
| 112 |
+
return 4
|
| 113 |
+
return 16
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def suggest_execution_providers() -> List[str]:
|
| 117 |
+
return encode_execution_providers(onnxruntime.get_available_providers())
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def suggest_execution_threads() -> int:
|
| 121 |
+
if "DmlExecutionProvider" in roop.globals.execution_providers:
|
| 122 |
+
return 1
|
| 123 |
+
if "ROCMExecutionProvider" in roop.globals.execution_providers:
|
| 124 |
+
return 1
|
| 125 |
+
return 8
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def limit_resources() -> None:
|
| 129 |
+
# limit memory usage
|
| 130 |
+
if roop.globals.max_memory:
|
| 131 |
+
memory = roop.globals.max_memory * 1024**3
|
| 132 |
+
if platform.system().lower() == "darwin":
|
| 133 |
+
memory = roop.globals.max_memory * 1024**6
|
| 134 |
+
if platform.system().lower() == "windows":
|
| 135 |
+
import ctypes
|
| 136 |
+
|
| 137 |
+
kernel32 = ctypes.windll.kernel32 # type: ignore[attr-defined]
|
| 138 |
+
kernel32.SetProcessWorkingSetSize(
|
| 139 |
+
-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory)
|
| 140 |
+
)
|
| 141 |
+
else:
|
| 142 |
+
import resource
|
| 143 |
+
|
| 144 |
+
resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def release_resources() -> None:
|
| 148 |
+
import gc
|
| 149 |
+
|
| 150 |
+
global process_mgr
|
| 151 |
+
|
| 152 |
+
if process_mgr is not None:
|
| 153 |
+
process_mgr.release_resources()
|
| 154 |
+
process_mgr = None
|
| 155 |
+
|
| 156 |
+
gc.collect()
|
| 157 |
+
if (
|
| 158 |
+
"CUDAExecutionProvider" in roop.globals.execution_providers
|
| 159 |
+
and torch.cuda.is_available()
|
| 160 |
+
):
|
| 161 |
+
with torch.cuda.device("cuda"):
|
| 162 |
+
torch.cuda.empty_cache()
|
| 163 |
+
torch.cuda.ipc_collect()
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def pre_check() -> bool:
|
| 167 |
+
if sys.version_info < (3, 9):
|
| 168 |
+
update_status(
|
| 169 |
+
"Python version is not supported - please upgrade to 3.9 or higher."
|
| 170 |
+
)
|
| 171 |
+
return False
|
| 172 |
+
|
| 173 |
+
download_directory_path = util.resolve_relative_path("../models")
|
| 174 |
+
util.conditional_download(
|
| 175 |
+
download_directory_path,
|
| 176 |
+
[
|
| 177 |
+
[
|
| 178 |
+
"https://huggingface.co/countfloyd/deepfake/resolve/main/inswapper_128.onnx",
|
| 179 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/InSwapper/inswapper_128.onnx",
|
| 180 |
+
],
|
| 181 |
+
[
|
| 182 |
+
"https://huggingface.co/countfloyd/deepfake/resolve/main/reswapper_128.onnx",
|
| 183 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/ReSwapper/reswapper_128.onnx",
|
| 184 |
+
],
|
| 185 |
+
[
|
| 186 |
+
"https://huggingface.co/countfloyd/deepfake/resolve/main/reswapper_256.onnx",
|
| 187 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/ReSwapper/reswapper_256.onnx",
|
| 188 |
+
],
|
| 189 |
+
[
|
| 190 |
+
"https://huggingface.co/countfloyd/deepfake/resolve/main/GFPGANv1.4.onnx",
|
| 191 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/GFPGAN/GFPGANv1.4.onnx",
|
| 192 |
+
],
|
| 193 |
+
[
|
| 194 |
+
"https://github.com/csxmli2016/DMDNet/releases/download/v1/DMDNet.pth",
|
| 195 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/DMDNet/DMDNet.pth",
|
| 196 |
+
],
|
| 197 |
+
[
|
| 198 |
+
"https://huggingface.co/countfloyd/deepfake/resolve/main/GPEN-BFR-512.onnx",
|
| 199 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/GPEN/GPEN-BFR-512.onnx",
|
| 200 |
+
],
|
| 201 |
+
[
|
| 202 |
+
"https://huggingface.co/countfloyd/deepfake/resolve/main/restoreformer_plus_plus.onnx",
|
| 203 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/RestoreFormer/restoreformer_plus_plus.onnx",
|
| 204 |
+
],
|
| 205 |
+
[
|
| 206 |
+
"https://huggingface.co/countfloyd/deepfake/resolve/main/xseg.onnx",
|
| 207 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/xseg.onnx",
|
| 208 |
+
],
|
| 209 |
+
],
|
| 210 |
+
)
|
| 211 |
+
download_directory_path = util.resolve_relative_path("../models/CLIP")
|
| 212 |
+
util.conditional_download(
|
| 213 |
+
download_directory_path,
|
| 214 |
+
[
|
| 215 |
+
[
|
| 216 |
+
"https://huggingface.co/countfloyd/deepfake/resolve/main/rd64-uni-refined.pth",
|
| 217 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/rd64-uni-refined.pth",
|
| 218 |
+
]
|
| 219 |
+
],
|
| 220 |
+
)
|
| 221 |
+
download_directory_path = util.resolve_relative_path("../models/buffalo_l")
|
| 222 |
+
util.conditional_download(
|
| 223 |
+
download_directory_path,
|
| 224 |
+
[
|
| 225 |
+
[
|
| 226 |
+
"https://huggingface.co/halllooo/buffalo_l/resolve/main/1k3d68.onnx",
|
| 227 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/buffalo_l/1k3d68.onnx",
|
| 228 |
+
],
|
| 229 |
+
[
|
| 230 |
+
"https://huggingface.co/halllooo/buffalo_l/resolve/main/2d106det.onnx",
|
| 231 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/buffalo_l/2d106det.onnx",
|
| 232 |
+
],
|
| 233 |
+
[
|
| 234 |
+
"https://huggingface.co/halllooo/buffalo_l/resolve/main/det_10g.onnx",
|
| 235 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/buffalo_l/det_10g.onnx",
|
| 236 |
+
],
|
| 237 |
+
[
|
| 238 |
+
"https://huggingface.co/halllooo/buffalo_l/resolve/main/genderage.onnx",
|
| 239 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/buffalo_l/genderage.onnx",
|
| 240 |
+
],
|
| 241 |
+
[
|
| 242 |
+
"https://huggingface.co/halllooo/buffalo_l/resolve/main/w600k_r50.onnx",
|
| 243 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/buffalo_l/w600k_r50.onnx",
|
| 244 |
+
],
|
| 245 |
+
],
|
| 246 |
+
)
|
| 247 |
+
download_directory_path = util.resolve_relative_path("../models/CodeFormer")
|
| 248 |
+
util.conditional_download(
|
| 249 |
+
download_directory_path,
|
| 250 |
+
[
|
| 251 |
+
[
|
| 252 |
+
"https://huggingface.co/countfloyd/deepfake/resolve/main/CodeFormerv0.1.onnx",
|
| 253 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/CodeFormer/CodeFormerv0.1.onnx",
|
| 254 |
+
]
|
| 255 |
+
],
|
| 256 |
+
)
|
| 257 |
+
download_directory_path = util.resolve_relative_path("../models/Frame")
|
| 258 |
+
util.conditional_download(
|
| 259 |
+
download_directory_path,
|
| 260 |
+
[
|
| 261 |
+
[
|
| 262 |
+
"https://huggingface.co/countfloyd/deepfake/resolve/main/deoldify_artistic.onnx",
|
| 263 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/DeOldify/deoldify_artistic.onnx",
|
| 264 |
+
],
|
| 265 |
+
[
|
| 266 |
+
"https://huggingface.co/countfloyd/deepfake/resolve/main/deoldify_stable.onnx",
|
| 267 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/DeOldify/deoldify_stable.onnx",
|
| 268 |
+
],
|
| 269 |
+
[
|
| 270 |
+
"https://huggingface.co/countfloyd/deepfake/resolve/main/isnet-general-use.onnx",
|
| 271 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/isnet-general-use.onnx",
|
| 272 |
+
],
|
| 273 |
+
[
|
| 274 |
+
"https://huggingface.co/countfloyd/deepfake/resolve/main/real_esrgan_x4.onnx",
|
| 275 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/real_esrgan_x4.onnx",
|
| 276 |
+
],
|
| 277 |
+
[
|
| 278 |
+
"https://huggingface.co/countfloyd/deepfake/resolve/main/real_esrgan_x2.onnx",
|
| 279 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/real_esrgan_x2.onnx",
|
| 280 |
+
],
|
| 281 |
+
[
|
| 282 |
+
"https://huggingface.co/countfloyd/deepfake/resolve/main/lsdir_x4.onnx",
|
| 283 |
+
"https://codeberg.org/roop-unleashed/models/media/branch/main/lsdir_x4.onnx",
|
| 284 |
+
],
|
| 285 |
+
],
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
if not shutil.which("ffmpeg"):
|
| 289 |
+
update_status("ffmpeg is not installed.")
|
| 290 |
+
return True
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def set_display_ui(function):
|
| 294 |
+
global call_display_ui
|
| 295 |
+
|
| 296 |
+
call_display_ui = function
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def update_status(message: str) -> None:
|
| 300 |
+
global call_display_ui
|
| 301 |
+
|
| 302 |
+
print(message)
|
| 303 |
+
if call_display_ui is not None:
|
| 304 |
+
call_display_ui(message)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def start() -> None:
|
| 308 |
+
if roop.globals.headless:
|
| 309 |
+
print("Headless mode currently unsupported - starting UI!")
|
| 310 |
+
# faces = extract_face_images(roop.globals.source_path, (False, 0))
|
| 311 |
+
# roop.globals.INPUT_FACES.append(faces[roop.globals.source_face_index])
|
| 312 |
+
# faces = extract_face_images(roop.globals.target_path, (False, util.has_image_extension(roop.globals.target_path)))
|
| 313 |
+
# roop.globals.TARGET_FACES.append(faces[roop.globals.target_face_index])
|
| 314 |
+
# if 'face_enhancer' in roop.globals.frame_processors:
|
| 315 |
+
# roop.globals.selected_enhancer = 'GFPGAN'
|
| 316 |
+
|
| 317 |
+
batch_process_regular(None, False, None)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def get_processing_plugins(masking_engine):
|
| 321 |
+
processors = {"faceswap": {}}
|
| 322 |
+
if masking_engine is not None:
|
| 323 |
+
processors.update({masking_engine: {}})
|
| 324 |
+
|
| 325 |
+
if roop.globals.selected_enhancer == "GFPGAN":
|
| 326 |
+
processors.update({"gfpgan": {}})
|
| 327 |
+
elif roop.globals.selected_enhancer == "Codeformer":
|
| 328 |
+
processors.update({"codeformer": {}})
|
| 329 |
+
elif roop.globals.selected_enhancer == "DMDNet":
|
| 330 |
+
processors.update({"dmdnet": {}})
|
| 331 |
+
elif roop.globals.selected_enhancer == "GPEN":
|
| 332 |
+
processors.update({"gpen": {}})
|
| 333 |
+
elif roop.globals.selected_enhancer == "Restoreformer++":
|
| 334 |
+
processors.update({"restoreformer++": {}})
|
| 335 |
+
return processors
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def live_swap(frame, options):
|
| 339 |
+
global process_mgr
|
| 340 |
+
|
| 341 |
+
if frame is None:
|
| 342 |
+
return frame
|
| 343 |
+
|
| 344 |
+
if process_mgr is None:
|
| 345 |
+
process_mgr = ProcessMgr(None)
|
| 346 |
+
|
| 347 |
+
# if len(roop.globals.INPUT_FACESETS) <= selected_index:
|
| 348 |
+
# selected_index = 0
|
| 349 |
+
process_mgr.initialize(
|
| 350 |
+
roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options
|
| 351 |
+
)
|
| 352 |
+
newframe = process_mgr.process_frame(frame)
|
| 353 |
+
if newframe is None:
|
| 354 |
+
return frame
|
| 355 |
+
return newframe
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def batch_process_regular(
|
| 359 |
+
swap_model,
|
| 360 |
+
output_method,
|
| 361 |
+
files: list[ProcessEntry],
|
| 362 |
+
masking_engine: str,
|
| 363 |
+
new_clip_text: str,
|
| 364 |
+
use_new_method,
|
| 365 |
+
imagemask,
|
| 366 |
+
restore_original_mouth,
|
| 367 |
+
num_swap_steps,
|
| 368 |
+
progress,
|
| 369 |
+
selected_index=0,
|
| 370 |
+
) -> None:
|
| 371 |
+
global clip_text, process_mgr
|
| 372 |
+
|
| 373 |
+
release_resources()
|
| 374 |
+
limit_resources()
|
| 375 |
+
if process_mgr is None:
|
| 376 |
+
process_mgr = ProcessMgr(progress)
|
| 377 |
+
mask = imagemask["layers"][0] if imagemask is not None else None
|
| 378 |
+
if len(roop.globals.INPUT_FACESETS) <= selected_index:
|
| 379 |
+
selected_index = 0
|
| 380 |
+
options = ProcessOptions(
|
| 381 |
+
swap_model,
|
| 382 |
+
get_processing_plugins(masking_engine),
|
| 383 |
+
roop.globals.distance_threshold,
|
| 384 |
+
roop.globals.blend_ratio,
|
| 385 |
+
roop.globals.face_swap_mode,
|
| 386 |
+
selected_index,
|
| 387 |
+
new_clip_text,
|
| 388 |
+
mask,
|
| 389 |
+
num_swap_steps,
|
| 390 |
+
roop.globals.subsample_size,
|
| 391 |
+
False,
|
| 392 |
+
restore_original_mouth,
|
| 393 |
+
)
|
| 394 |
+
process_mgr.initialize(
|
| 395 |
+
roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options
|
| 396 |
+
)
|
| 397 |
+
batch_process(output_method, files, use_new_method)
|
| 398 |
+
return
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def batch_process_with_options(files: list[ProcessEntry], options, progress):
|
| 402 |
+
global clip_text, process_mgr
|
| 403 |
+
|
| 404 |
+
release_resources()
|
| 405 |
+
limit_resources()
|
| 406 |
+
if process_mgr is None:
|
| 407 |
+
process_mgr = ProcessMgr(progress)
|
| 408 |
+
process_mgr.initialize(
|
| 409 |
+
roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options
|
| 410 |
+
)
|
| 411 |
+
roop.globals.keep_frames = False
|
| 412 |
+
roop.globals.wait_after_extraction = False
|
| 413 |
+
roop.globals.skip_audio = False
|
| 414 |
+
batch_process("Files", files, True)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def batch_process(output_method, files: list[ProcessEntry], use_new_method) -> None:
|
| 418 |
+
global clip_text, process_mgr
|
| 419 |
+
|
| 420 |
+
roop.globals.processing = True
|
| 421 |
+
|
| 422 |
+
# limit threads for some providers
|
| 423 |
+
max_threads = suggest_execution_threads()
|
| 424 |
+
if max_threads == 1:
|
| 425 |
+
roop.globals.execution_threads = 1
|
| 426 |
+
|
| 427 |
+
imagefiles: list[ProcessEntry] = []
|
| 428 |
+
videofiles: list[ProcessEntry] = []
|
| 429 |
+
|
| 430 |
+
update_status("Sorting videos/images")
|
| 431 |
+
|
| 432 |
+
for index, f in enumerate(files):
|
| 433 |
+
fullname = f.filename
|
| 434 |
+
if util.has_image_extension(fullname):
|
| 435 |
+
destination = util.get_destfilename_from_path(
|
| 436 |
+
fullname,
|
| 437 |
+
roop.globals.output_path,
|
| 438 |
+
f".{roop.globals.CFG.output_image_format}",
|
| 439 |
+
)
|
| 440 |
+
destination = util.replace_template(destination, index=index)
|
| 441 |
+
pathlib.Path(os.path.dirname(destination)).mkdir(
|
| 442 |
+
parents=True, exist_ok=True
|
| 443 |
+
)
|
| 444 |
+
f.finalname = destination
|
| 445 |
+
imagefiles.append(f)
|
| 446 |
+
|
| 447 |
+
elif util.is_video(fullname) or util.has_extension(fullname, ["gif"]):
|
| 448 |
+
destination = util.get_destfilename_from_path(
|
| 449 |
+
fullname,
|
| 450 |
+
roop.globals.output_path,
|
| 451 |
+
f"__temp.{roop.globals.CFG.output_video_format}",
|
| 452 |
+
)
|
| 453 |
+
f.finalname = destination
|
| 454 |
+
videofiles.append(f)
|
| 455 |
+
|
| 456 |
+
if len(imagefiles) > 0:
|
| 457 |
+
update_status("Processing image(s)")
|
| 458 |
+
origimages = []
|
| 459 |
+
fakeimages = []
|
| 460 |
+
for f in imagefiles:
|
| 461 |
+
origimages.append(f.filename)
|
| 462 |
+
fakeimages.append(f.finalname)
|
| 463 |
+
|
| 464 |
+
process_mgr.run_batch(origimages, fakeimages, roop.globals.execution_threads)
|
| 465 |
+
origimages.clear()
|
| 466 |
+
fakeimages.clear()
|
| 467 |
+
|
| 468 |
+
if len(videofiles) > 0:
|
| 469 |
+
for index, v in enumerate(videofiles):
|
| 470 |
+
if not roop.globals.processing:
|
| 471 |
+
end_processing("Processing stopped!")
|
| 472 |
+
return
|
| 473 |
+
fps = v.fps if v.fps > 0 else util.detect_fps(v.filename)
|
| 474 |
+
if v.endframe == 0:
|
| 475 |
+
v.endframe = get_video_frame_total(v.filename)
|
| 476 |
+
|
| 477 |
+
is_streaming_only = output_method == "Virtual Camera"
|
| 478 |
+
if is_streaming_only == False:
|
| 479 |
+
update_status(
|
| 480 |
+
f"Creating {os.path.basename(v.finalname)} with {fps} FPS..."
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
start_processing = time()
|
| 484 |
+
if (
|
| 485 |
+
is_streaming_only == False
|
| 486 |
+
and roop.globals.keep_frames
|
| 487 |
+
or not use_new_method
|
| 488 |
+
):
|
| 489 |
+
util.create_temp(v.filename)
|
| 490 |
+
update_status("Extracting frames...")
|
| 491 |
+
ffmpeg.extract_frames(v.filename, v.startframe, v.endframe, fps)
|
| 492 |
+
if not roop.globals.processing:
|
| 493 |
+
end_processing("Processing stopped!")
|
| 494 |
+
return
|
| 495 |
+
|
| 496 |
+
temp_frame_paths = util.get_temp_frame_paths(v.filename)
|
| 497 |
+
process_mgr.run_batch(
|
| 498 |
+
temp_frame_paths, temp_frame_paths, roop.globals.execution_threads
|
| 499 |
+
)
|
| 500 |
+
if not roop.globals.processing:
|
| 501 |
+
end_processing("Processing stopped!")
|
| 502 |
+
return
|
| 503 |
+
if roop.globals.wait_after_extraction:
|
| 504 |
+
extract_path = os.path.dirname(temp_frame_paths[0])
|
| 505 |
+
util.open_folder(extract_path)
|
| 506 |
+
input("Press any key to continue...")
|
| 507 |
+
print("Resorting frames to create video")
|
| 508 |
+
util.sort_rename_frames(extract_path)
|
| 509 |
+
|
| 510 |
+
ffmpeg.create_video(v.filename, v.finalname, fps)
|
| 511 |
+
if not roop.globals.keep_frames:
|
| 512 |
+
util.delete_temp_frames(temp_frame_paths[0])
|
| 513 |
+
else:
|
| 514 |
+
if util.has_extension(v.filename, ["gif"]):
|
| 515 |
+
skip_audio = True
|
| 516 |
+
else:
|
| 517 |
+
skip_audio = roop.globals.skip_audio
|
| 518 |
+
process_mgr.run_batch_inmem(
|
| 519 |
+
output_method,
|
| 520 |
+
v.filename,
|
| 521 |
+
v.finalname,
|
| 522 |
+
v.startframe,
|
| 523 |
+
v.endframe,
|
| 524 |
+
fps,
|
| 525 |
+
roop.globals.execution_threads,
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
if not roop.globals.processing:
|
| 529 |
+
end_processing("Processing stopped!")
|
| 530 |
+
return
|
| 531 |
+
|
| 532 |
+
video_file_name = v.finalname
|
| 533 |
+
if os.path.isfile(video_file_name):
|
| 534 |
+
destination = ""
|
| 535 |
+
if util.has_extension(v.filename, ["gif"]):
|
| 536 |
+
gifname = util.get_destfilename_from_path(
|
| 537 |
+
v.filename, roop.globals.output_path, ".gif"
|
| 538 |
+
)
|
| 539 |
+
destination = util.replace_template(gifname, index=index)
|
| 540 |
+
pathlib.Path(os.path.dirname(destination)).mkdir(
|
| 541 |
+
parents=True, exist_ok=True
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
update_status("Creating final GIF")
|
| 545 |
+
ffmpeg.create_gif_from_video(video_file_name, destination)
|
| 546 |
+
if os.path.isfile(destination):
|
| 547 |
+
os.remove(video_file_name)
|
| 548 |
+
else:
|
| 549 |
+
skip_audio = roop.globals.skip_audio
|
| 550 |
+
destination = util.replace_template(video_file_name, index=index)
|
| 551 |
+
pathlib.Path(os.path.dirname(destination)).mkdir(
|
| 552 |
+
parents=True, exist_ok=True
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
if not skip_audio:
|
| 556 |
+
ffmpeg.restore_audio(
|
| 557 |
+
video_file_name,
|
| 558 |
+
v.filename,
|
| 559 |
+
v.startframe,
|
| 560 |
+
v.endframe,
|
| 561 |
+
destination,
|
| 562 |
+
)
|
| 563 |
+
if os.path.isfile(destination):
|
| 564 |
+
os.remove(video_file_name)
|
| 565 |
+
else:
|
| 566 |
+
shutil.move(video_file_name, destination)
|
| 567 |
+
|
| 568 |
+
elif is_streaming_only == False:
|
| 569 |
+
update_status(f"Failed processing {os.path.basename(v.finalname)}!")
|
| 570 |
+
elapsed_time = time() - start_processing
|
| 571 |
+
average_fps = (v.endframe - v.startframe) / elapsed_time
|
| 572 |
+
update_status(
|
| 573 |
+
f"\nProcessing {os.path.basename(destination)} took {elapsed_time:.2f} secs, {average_fps:.2f} frames/s"
|
| 574 |
+
)
|
| 575 |
+
end_processing("Finished")
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
def end_processing(msg: str):
|
| 579 |
+
update_status(msg)
|
| 580 |
+
roop.globals.target_folder_path = None
|
| 581 |
+
release_resources()
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
def destroy() -> None:
|
| 585 |
+
if roop.globals.target_path:
|
| 586 |
+
util.clean_temp(roop.globals.target_path)
|
| 587 |
+
release_resources()
|
| 588 |
+
sys.exit()
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
def run() -> None:
|
| 592 |
+
parse_args()
|
| 593 |
+
if not pre_check():
|
| 594 |
+
return
|
| 595 |
+
roop.globals.CFG = Settings("config.yaml")
|
| 596 |
+
roop.globals.cuda_device_id = roop.globals.startup_args.cuda_device_id
|
| 597 |
+
roop.globals.execution_threads = roop.globals.CFG.max_threads
|
| 598 |
+
roop.globals.video_encoder = roop.globals.CFG.output_video_codec
|
| 599 |
+
roop.globals.video_quality = roop.globals.CFG.video_quality
|
| 600 |
+
roop.globals.max_memory = (
|
| 601 |
+
roop.globals.CFG.memory_limit if roop.globals.CFG.memory_limit > 0 else None
|
| 602 |
+
)
|
| 603 |
+
if roop.globals.startup_args.server_share:
|
| 604 |
+
roop.globals.CFG.server_share = True
|
| 605 |
+
main.run()
|
face_util.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
from typing import Any
|
| 3 |
+
import insightface
|
| 4 |
+
|
| 5 |
+
import roop.globals
|
| 6 |
+
from roop.typing import Frame, Face
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
from skimage import transform as trans
|
| 11 |
+
from roop.capturer import get_video_frame
|
| 12 |
+
from roop.utilities import resolve_relative_path, conditional_thread_semaphore
|
| 13 |
+
|
| 14 |
+
FACE_ANALYSER = None
|
| 15 |
+
# THREAD_LOCK_ANALYSER = threading.Lock()
|
| 16 |
+
# THREAD_LOCK_SWAPPER = threading.Lock()
|
| 17 |
+
FACE_SWAPPER = None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_face_analyser() -> Any:
|
| 21 |
+
global FACE_ANALYSER
|
| 22 |
+
|
| 23 |
+
with conditional_thread_semaphore():
|
| 24 |
+
if (
|
| 25 |
+
FACE_ANALYSER is None
|
| 26 |
+
or roop.globals.g_current_face_analysis
|
| 27 |
+
!= roop.globals.g_desired_face_analysis
|
| 28 |
+
):
|
| 29 |
+
model_path = resolve_relative_path("..")
|
| 30 |
+
# removed genderage
|
| 31 |
+
allowed_modules = roop.globals.g_desired_face_analysis
|
| 32 |
+
roop.globals.g_current_face_analysis = roop.globals.g_desired_face_analysis
|
| 33 |
+
if roop.globals.CFG.force_cpu:
|
| 34 |
+
print("Forcing CPU for Face Analysis")
|
| 35 |
+
FACE_ANALYSER = insightface.app.FaceAnalysis(
|
| 36 |
+
name="buffalo_l",
|
| 37 |
+
root=model_path,
|
| 38 |
+
providers=["CPUExecutionProvider"],
|
| 39 |
+
allowed_modules=allowed_modules,
|
| 40 |
+
)
|
| 41 |
+
else:
|
| 42 |
+
FACE_ANALYSER = insightface.app.FaceAnalysis(
|
| 43 |
+
name="buffalo_l",
|
| 44 |
+
root=model_path,
|
| 45 |
+
providers=roop.globals.execution_providers,
|
| 46 |
+
allowed_modules=allowed_modules,
|
| 47 |
+
)
|
| 48 |
+
FACE_ANALYSER.prepare(
|
| 49 |
+
ctx_id=0,
|
| 50 |
+
det_size=(640, 640) if roop.globals.default_det_size else (320, 320),
|
| 51 |
+
)
|
| 52 |
+
return FACE_ANALYSER
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_first_face(frame: Frame) -> Any:
|
| 56 |
+
try:
|
| 57 |
+
faces = get_face_analyser().get(frame)
|
| 58 |
+
return min(faces, key=lambda x: x.bbox[0])
|
| 59 |
+
# return sorted(faces, reverse=True, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[0]
|
| 60 |
+
except:
|
| 61 |
+
return None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_all_faces(frame: Frame) -> Any:
|
| 65 |
+
try:
|
| 66 |
+
faces = get_face_analyser().get(frame)
|
| 67 |
+
return sorted(faces, key=lambda x: x.bbox[0])
|
| 68 |
+
except:
|
| 69 |
+
return None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def extract_face_images(source_filename, video_info, extra_padding=-1.0):
|
| 73 |
+
face_data = []
|
| 74 |
+
source_image = None
|
| 75 |
+
|
| 76 |
+
if video_info[0]:
|
| 77 |
+
frame = get_video_frame(source_filename, video_info[1])
|
| 78 |
+
if frame is not None:
|
| 79 |
+
source_image = frame
|
| 80 |
+
else:
|
| 81 |
+
return face_data
|
| 82 |
+
else:
|
| 83 |
+
source_image = cv2.imdecode(
|
| 84 |
+
np.fromfile(source_filename, dtype=np.uint8), cv2.IMREAD_COLOR
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
faces = get_all_faces(source_image)
|
| 88 |
+
if faces is None:
|
| 89 |
+
return face_data
|
| 90 |
+
|
| 91 |
+
i = 0
|
| 92 |
+
for face in faces:
|
| 93 |
+
(startX, startY, endX, endY) = face["bbox"].astype("int")
|
| 94 |
+
startX, endX, startY, endY = clamp_cut_values(
|
| 95 |
+
startX, endX, startY, endY, source_image
|
| 96 |
+
)
|
| 97 |
+
if extra_padding > 0.0:
|
| 98 |
+
if source_image.shape[:2] == (512, 512):
|
| 99 |
+
i += 1
|
| 100 |
+
face_data.append([face, source_image])
|
| 101 |
+
continue
|
| 102 |
+
|
| 103 |
+
found = False
|
| 104 |
+
for i in range(1, 3):
|
| 105 |
+
(startX, startY, endX, endY) = face["bbox"].astype("int")
|
| 106 |
+
startX, endX, startY, endY = clamp_cut_values(
|
| 107 |
+
startX, endX, startY, endY, source_image
|
| 108 |
+
)
|
| 109 |
+
cutout_padding = extra_padding
|
| 110 |
+
# top needs extra room for detection
|
| 111 |
+
padding = int((endY - startY) * cutout_padding)
|
| 112 |
+
oldY = startY
|
| 113 |
+
startY -= padding
|
| 114 |
+
|
| 115 |
+
factor = 0.25 if i == 1 else 0.5
|
| 116 |
+
cutout_padding = factor
|
| 117 |
+
padding = int((endY - oldY) * cutout_padding)
|
| 118 |
+
endY += padding
|
| 119 |
+
padding = int((endX - startX) * cutout_padding)
|
| 120 |
+
startX -= padding
|
| 121 |
+
endX += padding
|
| 122 |
+
startX, endX, startY, endY = clamp_cut_values(
|
| 123 |
+
startX, endX, startY, endY, source_image
|
| 124 |
+
)
|
| 125 |
+
face_temp = source_image[startY:endY, startX:endX]
|
| 126 |
+
face_temp = resize_image_keep_content(face_temp)
|
| 127 |
+
testfaces = get_all_faces(face_temp)
|
| 128 |
+
if testfaces is not None and len(testfaces) > 0:
|
| 129 |
+
i += 1
|
| 130 |
+
face_data.append([testfaces[0], face_temp])
|
| 131 |
+
found = True
|
| 132 |
+
break
|
| 133 |
+
|
| 134 |
+
if not found:
|
| 135 |
+
print("No face found after resizing, this shouldn't happen!")
|
| 136 |
+
continue
|
| 137 |
+
|
| 138 |
+
face_temp = source_image[startY:endY, startX:endX]
|
| 139 |
+
if face_temp.size < 1:
|
| 140 |
+
continue
|
| 141 |
+
|
| 142 |
+
i += 1
|
| 143 |
+
face_data.append([face, face_temp])
|
| 144 |
+
return face_data
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def clamp_cut_values(startX, endX, startY, endY, image):
|
| 148 |
+
if startX < 0:
|
| 149 |
+
startX = 0
|
| 150 |
+
if endX > image.shape[1]:
|
| 151 |
+
endX = image.shape[1]
|
| 152 |
+
if startY < 0:
|
| 153 |
+
startY = 0
|
| 154 |
+
if endY > image.shape[0]:
|
| 155 |
+
endY = image.shape[0]
|
| 156 |
+
return startX, endX, startY, endY
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def face_offset_top(face: Face, offset):
|
| 160 |
+
face["bbox"][1] += offset
|
| 161 |
+
face["bbox"][3] += offset
|
| 162 |
+
lm106 = face.landmark_2d_106
|
| 163 |
+
add = np.full_like(lm106, [0, offset])
|
| 164 |
+
face["landmark_2d_106"] = lm106 + add
|
| 165 |
+
return face
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def resize_image_keep_content(image, new_width=512, new_height=512):
|
| 169 |
+
dim = None
|
| 170 |
+
(h, w) = image.shape[:2]
|
| 171 |
+
if h > w:
|
| 172 |
+
r = new_height / float(h)
|
| 173 |
+
dim = (int(w * r), new_height)
|
| 174 |
+
else:
|
| 175 |
+
# Calculate the ratio of the width and construct the dimensions
|
| 176 |
+
r = new_width / float(w)
|
| 177 |
+
dim = (new_width, int(h * r))
|
| 178 |
+
image = cv2.resize(image, dim, interpolation=cv2.INTER_AREA)
|
| 179 |
+
(h, w) = image.shape[:2]
|
| 180 |
+
if h == new_height and w == new_width:
|
| 181 |
+
return image
|
| 182 |
+
resize_img = np.zeros(shape=(new_height, new_width, 3), dtype=image.dtype)
|
| 183 |
+
offs = (new_width - w) if h == new_height else (new_height - h)
|
| 184 |
+
startoffs = int(offs // 2) if offs % 2 == 0 else int(offs // 2) + 1
|
| 185 |
+
offs = int(offs // 2)
|
| 186 |
+
|
| 187 |
+
if h == new_height:
|
| 188 |
+
resize_img[0:new_height, startoffs : new_width - offs] = image
|
| 189 |
+
else:
|
| 190 |
+
resize_img[startoffs : new_height - offs, 0:new_width] = image
|
| 191 |
+
return resize_img
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def rotate_image_90(image, rotate=True):
|
| 195 |
+
if rotate:
|
| 196 |
+
return np.rot90(image)
|
| 197 |
+
else:
|
| 198 |
+
return np.rot90(image, 1, (1, 0))
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def rotate_anticlockwise(frame):
|
| 202 |
+
return rotate_image_90(frame)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def rotate_clockwise(frame):
|
| 206 |
+
return rotate_image_90(frame, False)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def rotate_image_180(image):
|
| 210 |
+
return np.flip(image, 0)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# alignment code from insightface https://github.com/deepinsight/insightface/blob/master/python-package/insightface/utils/face_align.py
|
| 214 |
+
|
| 215 |
+
arcface_dst = np.array(
|
| 216 |
+
[
|
| 217 |
+
[38.2946, 51.6963],
|
| 218 |
+
[73.5318, 51.5014],
|
| 219 |
+
[56.0252, 71.7366],
|
| 220 |
+
[41.5493, 92.3655],
|
| 221 |
+
[70.7299, 92.2041],
|
| 222 |
+
],
|
| 223 |
+
dtype=np.float32,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
""" def estimate_norm(lmk, image_size=112):
|
| 228 |
+
assert lmk.shape == (5, 2)
|
| 229 |
+
if image_size % 112 == 0:
|
| 230 |
+
ratio = float(image_size) / 112.0
|
| 231 |
+
diff_x = 0
|
| 232 |
+
elif image_size % 128 == 0:
|
| 233 |
+
ratio = float(image_size) / 128.0
|
| 234 |
+
diff_x = 8.0 * ratio
|
| 235 |
+
elif image_size % 512 == 0:
|
| 236 |
+
ratio = float(image_size) / 512.0
|
| 237 |
+
diff_x = 32.0 * ratio
|
| 238 |
+
|
| 239 |
+
dst = arcface_dst * ratio
|
| 240 |
+
dst[:, 0] += diff_x
|
| 241 |
+
tform = trans.SimilarityTransform()
|
| 242 |
+
tform.estimate(lmk, dst)
|
| 243 |
+
M = tform.params[0:2, :]
|
| 244 |
+
return M
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def estimate_norm(lmk, image_size=112):
|
| 249 |
+
if image_size % 112 == 0:
|
| 250 |
+
ratio = float(image_size) / 112.0
|
| 251 |
+
diff_x = 0
|
| 252 |
+
else:
|
| 253 |
+
ratio = float(image_size) / 128.0
|
| 254 |
+
diff_x = 8.0 * ratio
|
| 255 |
+
dst = arcface_dst * ratio
|
| 256 |
+
dst[:, 0] += diff_x
|
| 257 |
+
|
| 258 |
+
if image_size == 160:
|
| 259 |
+
dst[:, 0] += 0.1
|
| 260 |
+
dst[:, 1] += 0.1
|
| 261 |
+
elif image_size == 256:
|
| 262 |
+
dst[:, 0] += 0.5
|
| 263 |
+
dst[:, 1] += 0.5
|
| 264 |
+
elif image_size == 320:
|
| 265 |
+
dst[:, 0] += 0.75
|
| 266 |
+
dst[:, 1] += 0.75
|
| 267 |
+
elif image_size == 512:
|
| 268 |
+
dst[:, 0] += 1.5
|
| 269 |
+
dst[:, 1] += 1.5
|
| 270 |
+
|
| 271 |
+
tform = trans.SimilarityTransform()
|
| 272 |
+
tform.estimate(lmk, dst)
|
| 273 |
+
M = tform.params[0:2, :]
|
| 274 |
+
return M
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
# aligned, M = norm_crop2(f[1], face.kps, 512)
|
| 278 |
+
def align_crop(img, landmark, image_size=112, mode="arcface"):
|
| 279 |
+
M = estimate_norm(landmark, image_size)
|
| 280 |
+
warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
|
| 281 |
+
return warped, M
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def square_crop(im, S):
|
| 285 |
+
if im.shape[0] > im.shape[1]:
|
| 286 |
+
height = S
|
| 287 |
+
width = int(float(im.shape[1]) / im.shape[0] * S)
|
| 288 |
+
scale = float(S) / im.shape[0]
|
| 289 |
+
else:
|
| 290 |
+
width = S
|
| 291 |
+
height = int(float(im.shape[0]) / im.shape[1] * S)
|
| 292 |
+
scale = float(S) / im.shape[1]
|
| 293 |
+
resized_im = cv2.resize(im, (width, height))
|
| 294 |
+
det_im = np.zeros((S, S, 3), dtype=np.uint8)
|
| 295 |
+
det_im[: resized_im.shape[0], : resized_im.shape[1], :] = resized_im
|
| 296 |
+
return det_im, scale
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def transform(data, center, output_size, scale, rotation):
|
| 300 |
+
scale_ratio = scale
|
| 301 |
+
rot = float(rotation) * np.pi / 180.0
|
| 302 |
+
# translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
|
| 303 |
+
t1 = trans.SimilarityTransform(scale=scale_ratio)
|
| 304 |
+
cx = center[0] * scale_ratio
|
| 305 |
+
cy = center[1] * scale_ratio
|
| 306 |
+
t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
|
| 307 |
+
t3 = trans.SimilarityTransform(rotation=rot)
|
| 308 |
+
t4 = trans.SimilarityTransform(translation=(output_size / 2, output_size / 2))
|
| 309 |
+
t = t1 + t2 + t3 + t4
|
| 310 |
+
M = t.params[0:2]
|
| 311 |
+
cropped = cv2.warpAffine(data, M, (output_size, output_size), borderValue=0.0)
|
| 312 |
+
return cropped, M
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def trans_points2d(pts, M):
|
| 316 |
+
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
|
| 317 |
+
for i in range(pts.shape[0]):
|
| 318 |
+
pt = pts[i]
|
| 319 |
+
new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32)
|
| 320 |
+
new_pt = np.dot(M, new_pt)
|
| 321 |
+
# print('new_pt', new_pt.shape, new_pt)
|
| 322 |
+
new_pts[i] = new_pt[0:2]
|
| 323 |
+
|
| 324 |
+
return new_pts
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def trans_points3d(pts, M):
|
| 328 |
+
scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
|
| 329 |
+
# print(scale)
|
| 330 |
+
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
|
| 331 |
+
for i in range(pts.shape[0]):
|
| 332 |
+
pt = pts[i]
|
| 333 |
+
new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32)
|
| 334 |
+
new_pt = np.dot(M, new_pt)
|
| 335 |
+
# print('new_pt', new_pt.shape, new_pt)
|
| 336 |
+
new_pts[i][0:2] = new_pt[0:2]
|
| 337 |
+
new_pts[i][2] = pts[i][2] * scale
|
| 338 |
+
|
| 339 |
+
return new_pts
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def trans_points(pts, M):
|
| 343 |
+
if pts.shape[1] == 2:
|
| 344 |
+
return trans_points2d(pts, M)
|
| 345 |
+
else:
|
| 346 |
+
return trans_points3d(pts, M)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def create_blank_image(width, height):
|
| 350 |
+
img = np.zeros((height, width, 4), dtype=np.uint8)
|
| 351 |
+
img[:] = [0, 0, 0, 0]
|
| 352 |
+
return img
|
ffmpeg_writer.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FFMPEG_Writer - write set of frames to video file
|
| 3 |
+
|
| 4 |
+
original from
|
| 5 |
+
https://github.com/Zulko/moviepy/blob/master/moviepy/video/io/ffmpeg_writer.py
|
| 6 |
+
|
| 7 |
+
removed unnecessary dependencies
|
| 8 |
+
|
| 9 |
+
The MIT License (MIT)
|
| 10 |
+
|
| 11 |
+
Copyright (c) 2015 Zulko
|
| 12 |
+
Copyright (c) 2023 Janvarev Vladislav
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import subprocess as sp
|
| 17 |
+
|
| 18 |
+
PIPE = -1
|
| 19 |
+
STDOUT = -2
|
| 20 |
+
DEVNULL = -3
|
| 21 |
+
|
| 22 |
+
FFMPEG_BINARY = "ffmpeg"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class FFMPEG_VideoWriter:
|
| 26 |
+
"""A class for FFMPEG-based video writing.
|
| 27 |
+
|
| 28 |
+
A class to write videos using ffmpeg. ffmpeg will write in a large
|
| 29 |
+
choice of formats.
|
| 30 |
+
|
| 31 |
+
Parameters
|
| 32 |
+
-----------
|
| 33 |
+
|
| 34 |
+
filename
|
| 35 |
+
Any filename like 'video.mp4' etc. but if you want to avoid
|
| 36 |
+
complications it is recommended to use the generic extension
|
| 37 |
+
'.avi' for all your videos.
|
| 38 |
+
|
| 39 |
+
size
|
| 40 |
+
Size (width,height) of the output video in pixels.
|
| 41 |
+
|
| 42 |
+
fps
|
| 43 |
+
Frames per second in the output video file.
|
| 44 |
+
|
| 45 |
+
codec
|
| 46 |
+
FFMPEG codec. It seems that in terms of quality the hierarchy is
|
| 47 |
+
'rawvideo' = 'png' > 'mpeg4' > 'libx264'
|
| 48 |
+
'png' manages the same lossless quality as 'rawvideo' but yields
|
| 49 |
+
smaller files. Type ``ffmpeg -codecs`` in a terminal to get a list
|
| 50 |
+
of accepted codecs.
|
| 51 |
+
|
| 52 |
+
Note for default 'libx264': by default the pixel format yuv420p
|
| 53 |
+
is used. If the video dimensions are not both even (e.g. 720x405)
|
| 54 |
+
another pixel format is used, and this can cause problem in some
|
| 55 |
+
video readers.
|
| 56 |
+
|
| 57 |
+
audiofile
|
| 58 |
+
Optional: The name of an audio file that will be incorporated
|
| 59 |
+
to the video.
|
| 60 |
+
|
| 61 |
+
preset
|
| 62 |
+
Sets the time that FFMPEG will take to compress the video. The slower,
|
| 63 |
+
the better the compression rate. Possibilities are: ultrafast,superfast,
|
| 64 |
+
veryfast, faster, fast, medium (default), slow, slower, veryslow,
|
| 65 |
+
placebo.
|
| 66 |
+
|
| 67 |
+
bitrate
|
| 68 |
+
Only relevant for codecs which accept a bitrate. "5000k" offers
|
| 69 |
+
nice results in general.
|
| 70 |
+
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
filename,
|
| 76 |
+
size,
|
| 77 |
+
fps,
|
| 78 |
+
codec="libx265",
|
| 79 |
+
crf=14,
|
| 80 |
+
audiofile=None,
|
| 81 |
+
preset="medium",
|
| 82 |
+
bitrate=None,
|
| 83 |
+
logfile=None,
|
| 84 |
+
threads=None,
|
| 85 |
+
ffmpeg_params=None,
|
| 86 |
+
):
|
| 87 |
+
if logfile is None:
|
| 88 |
+
logfile = sp.PIPE
|
| 89 |
+
|
| 90 |
+
self.filename = filename
|
| 91 |
+
self.codec = codec
|
| 92 |
+
self.ext = self.filename.split(".")[-1]
|
| 93 |
+
w = size[0] - 1 if size[0] % 2 != 0 else size[0]
|
| 94 |
+
h = size[1] - 1 if size[1] % 2 != 0 else size[1]
|
| 95 |
+
|
| 96 |
+
# order is important
|
| 97 |
+
cmd = [
|
| 98 |
+
FFMPEG_BINARY,
|
| 99 |
+
"-hide_banner",
|
| 100 |
+
"-hwaccel",
|
| 101 |
+
"auto",
|
| 102 |
+
"-y",
|
| 103 |
+
"-loglevel",
|
| 104 |
+
"error" if logfile == sp.PIPE else "info",
|
| 105 |
+
"-f",
|
| 106 |
+
"rawvideo",
|
| 107 |
+
"-vcodec",
|
| 108 |
+
"rawvideo",
|
| 109 |
+
"-s",
|
| 110 |
+
"%dx%d" % (size[0], size[1]),
|
| 111 |
+
#'-pix_fmt', 'rgba' if withmask else 'rgb24',
|
| 112 |
+
"-pix_fmt",
|
| 113 |
+
"bgr24",
|
| 114 |
+
"-r",
|
| 115 |
+
str(fps),
|
| 116 |
+
"-an",
|
| 117 |
+
"-i",
|
| 118 |
+
"-",
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
if audiofile is not None:
|
| 122 |
+
cmd.extend(["-i", audiofile, "-acodec", "copy"])
|
| 123 |
+
|
| 124 |
+
cmd.extend(
|
| 125 |
+
[
|
| 126 |
+
"-vcodec",
|
| 127 |
+
codec,
|
| 128 |
+
"-crf",
|
| 129 |
+
str(crf),
|
| 130 |
+
#'-preset', preset,
|
| 131 |
+
]
|
| 132 |
+
)
|
| 133 |
+
if ffmpeg_params is not None:
|
| 134 |
+
cmd.extend(ffmpeg_params)
|
| 135 |
+
if bitrate is not None:
|
| 136 |
+
cmd.extend(["-b", bitrate])
|
| 137 |
+
|
| 138 |
+
# scale to a resolution divisible by 2 if not even
|
| 139 |
+
cmd.extend(
|
| 140 |
+
[
|
| 141 |
+
"-vf",
|
| 142 |
+
f"scale={w}:{h}"
|
| 143 |
+
if w != size[0] or h != size[1]
|
| 144 |
+
else "colorspace=bt709:iall=bt601-6-625:fast=1",
|
| 145 |
+
]
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
if threads is not None:
|
| 149 |
+
cmd.extend(["-threads", str(threads)])
|
| 150 |
+
|
| 151 |
+
cmd.extend(
|
| 152 |
+
[
|
| 153 |
+
"-pix_fmt",
|
| 154 |
+
"yuv420p",
|
| 155 |
+
]
|
| 156 |
+
)
|
| 157 |
+
cmd.extend([filename])
|
| 158 |
+
|
| 159 |
+
test = str(cmd)
|
| 160 |
+
print(test)
|
| 161 |
+
|
| 162 |
+
popen_params = {"stdout": DEVNULL, "stderr": logfile, "stdin": sp.PIPE}
|
| 163 |
+
|
| 164 |
+
# This was added so that no extra unwanted window opens on windows
|
| 165 |
+
# when the child process is created
|
| 166 |
+
if os.name == "nt":
|
| 167 |
+
popen_params["creationflags"] = 0x08000000 # CREATE_NO_WINDOW
|
| 168 |
+
|
| 169 |
+
self.proc = sp.Popen(cmd, **popen_params)
|
| 170 |
+
|
| 171 |
+
def write_frame(self, img_array):
|
| 172 |
+
"""Writes one frame in the file."""
|
| 173 |
+
try:
|
| 174 |
+
# if PY3:
|
| 175 |
+
self.proc.stdin.write(img_array.tobytes())
|
| 176 |
+
# else:
|
| 177 |
+
# self.proc.stdin.write(img_array.tostring())
|
| 178 |
+
except IOError as err:
|
| 179 |
+
_, ffmpeg_error = self.proc.communicate()
|
| 180 |
+
error = str(err) + (
|
| 181 |
+
"\n\nroop unleashed error: FFMPEG encountered "
|
| 182 |
+
"the following error while writing file %s:"
|
| 183 |
+
"\n\n %s" % (self.filename, str(ffmpeg_error))
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
if b"Unknown encoder" in ffmpeg_error:
|
| 187 |
+
error = error + (
|
| 188 |
+
"\n\nThe video export "
|
| 189 |
+
"failed because FFMPEG didn't find the specified "
|
| 190 |
+
"codec for video encoding (%s). Please install "
|
| 191 |
+
"this codec or change the codec when calling "
|
| 192 |
+
"write_videofile. For instance:\n"
|
| 193 |
+
" >>> clip.write_videofile('myvid.webm', codec='libvpx')"
|
| 194 |
+
) % (self.codec)
|
| 195 |
+
|
| 196 |
+
elif b"incorrect codec parameters ?" in ffmpeg_error:
|
| 197 |
+
error = error + (
|
| 198 |
+
"\n\nThe video export "
|
| 199 |
+
"failed, possibly because the codec specified for "
|
| 200 |
+
"the video (%s) is not compatible with the given "
|
| 201 |
+
"extension (%s). Please specify a valid 'codec' "
|
| 202 |
+
"argument in write_videofile. This would be 'libx264' "
|
| 203 |
+
"or 'mpeg4' for mp4, 'libtheora' for ogv, 'libvpx for webm. "
|
| 204 |
+
"Another possible reason is that the audio codec was not "
|
| 205 |
+
"compatible with the video codec. For instance the video "
|
| 206 |
+
"extensions 'ogv' and 'webm' only allow 'libvorbis' (default) as a"
|
| 207 |
+
"video codec."
|
| 208 |
+
) % (self.codec, self.ext)
|
| 209 |
+
|
| 210 |
+
elif b"encoder setup failed" in ffmpeg_error:
|
| 211 |
+
error = error + (
|
| 212 |
+
"\n\nThe video export "
|
| 213 |
+
"failed, possibly because the bitrate you specified "
|
| 214 |
+
"was too high or too low for the video codec."
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
elif b"Invalid encoder type" in ffmpeg_error:
|
| 218 |
+
error = error + (
|
| 219 |
+
"\n\nThe video export failed because the codec "
|
| 220 |
+
"or file extension you provided is not a video"
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
raise IOError(error)
|
| 224 |
+
|
| 225 |
+
def close(self):
|
| 226 |
+
if self.proc:
|
| 227 |
+
self.proc.stdin.close()
|
| 228 |
+
if self.proc.stderr is not None:
|
| 229 |
+
self.proc.stderr.close()
|
| 230 |
+
self.proc.wait()
|
| 231 |
+
|
| 232 |
+
self.proc = None
|
| 233 |
+
|
| 234 |
+
# Support the Context Manager protocol, to ensure that resources are cleaned up.
|
| 235 |
+
|
| 236 |
+
def __enter__(self):
|
| 237 |
+
return self
|
| 238 |
+
|
| 239 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
| 240 |
+
self.close()
|
globals.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from settings import Settings
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
source_path = None
|
| 5 |
+
target_path = None
|
| 6 |
+
output_path = None
|
| 7 |
+
target_folder_path = None
|
| 8 |
+
startup_args = None
|
| 9 |
+
|
| 10 |
+
cuda_device_id = 0
|
| 11 |
+
frame_processors: List[str] = []
|
| 12 |
+
keep_fps = None
|
| 13 |
+
keep_frames = None
|
| 14 |
+
autorotate_faces = None
|
| 15 |
+
vr_mode = None
|
| 16 |
+
skip_audio = None
|
| 17 |
+
wait_after_extraction = None
|
| 18 |
+
many_faces = None
|
| 19 |
+
use_batch = None
|
| 20 |
+
source_face_index = 0
|
| 21 |
+
target_face_index = 0
|
| 22 |
+
face_position = None
|
| 23 |
+
video_encoder = None
|
| 24 |
+
video_quality = None
|
| 25 |
+
max_memory = None
|
| 26 |
+
execution_providers: List[str] = []
|
| 27 |
+
execution_threads = None
|
| 28 |
+
headless = None
|
| 29 |
+
log_level = "error"
|
| 30 |
+
selected_enhancer = None
|
| 31 |
+
subsample_size = 128
|
| 32 |
+
face_swap_mode = None
|
| 33 |
+
blend_ratio = 0.5
|
| 34 |
+
distance_threshold = 0.65
|
| 35 |
+
default_det_size = True
|
| 36 |
+
|
| 37 |
+
no_face_action = 0
|
| 38 |
+
|
| 39 |
+
processing = False
|
| 40 |
+
|
| 41 |
+
g_current_face_analysis = None
|
| 42 |
+
g_desired_face_analysis = None
|
| 43 |
+
|
| 44 |
+
FACE_ENHANCER = None
|
| 45 |
+
|
| 46 |
+
INPUT_FACESETS = []
|
| 47 |
+
TARGET_FACES = []
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
IMAGE_CHAIN_PROCESSOR = None
|
| 51 |
+
VIDEO_CHAIN_PROCESSOR = None
|
| 52 |
+
BATCH_IMAGE_CHAIN_PROCESSOR = None
|
| 53 |
+
|
| 54 |
+
CFG: Settings = None
|
metadata.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name = "roop unleashed"
|
| 2 |
+
version = "4.4.1"
|
processors/Enhance_CodeFormer.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, List, Callable
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import onnxruntime
|
| 5 |
+
import roop.globals
|
| 6 |
+
|
| 7 |
+
from roop.typing import Face, Frame, FaceSet
|
| 8 |
+
from roop.utilities import resolve_relative_path
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Enhance_CodeFormer:
|
| 12 |
+
model_codeformer = None
|
| 13 |
+
|
| 14 |
+
plugin_options: dict = None
|
| 15 |
+
|
| 16 |
+
processorname = "codeformer"
|
| 17 |
+
type = "enhance"
|
| 18 |
+
|
| 19 |
+
def Initialize(self, plugin_options: dict):
|
| 20 |
+
if self.plugin_options is not None:
|
| 21 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
| 22 |
+
self.Release()
|
| 23 |
+
|
| 24 |
+
self.plugin_options = plugin_options
|
| 25 |
+
if self.model_codeformer is None:
|
| 26 |
+
# replace Mac mps with cpu for the moment
|
| 27 |
+
self.devicename = self.plugin_options["devicename"].replace("mps", "cpu")
|
| 28 |
+
model_path = resolve_relative_path(
|
| 29 |
+
"../models/CodeFormer/CodeFormerv0.1.onnx"
|
| 30 |
+
)
|
| 31 |
+
self.model_codeformer = onnxruntime.InferenceSession(
|
| 32 |
+
model_path, None, providers=roop.globals.execution_providers
|
| 33 |
+
)
|
| 34 |
+
self.model_inputs = self.model_codeformer.get_inputs()
|
| 35 |
+
model_outputs = self.model_codeformer.get_outputs()
|
| 36 |
+
self.io_binding = self.model_codeformer.io_binding()
|
| 37 |
+
self.io_binding.bind_cpu_input(self.model_inputs[1].name, np.array([0.5]))
|
| 38 |
+
self.io_binding.bind_output(model_outputs[0].name, self.devicename)
|
| 39 |
+
|
| 40 |
+
def Run(
|
| 41 |
+
self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame
|
| 42 |
+
) -> Frame:
|
| 43 |
+
input_size = temp_frame.shape[1]
|
| 44 |
+
# preprocess
|
| 45 |
+
temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC)
|
| 46 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
|
| 47 |
+
temp_frame = temp_frame.astype("float32") / 255.0
|
| 48 |
+
temp_frame = (temp_frame - 0.5) / 0.5
|
| 49 |
+
temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)
|
| 50 |
+
|
| 51 |
+
self.io_binding.bind_cpu_input(
|
| 52 |
+
self.model_inputs[0].name, temp_frame.astype(np.float32)
|
| 53 |
+
)
|
| 54 |
+
self.model_codeformer.run_with_iobinding(self.io_binding)
|
| 55 |
+
ort_outs = self.io_binding.copy_outputs_to_cpu()
|
| 56 |
+
result = ort_outs[0][0]
|
| 57 |
+
del ort_outs
|
| 58 |
+
|
| 59 |
+
# post-process
|
| 60 |
+
result = result.transpose((1, 2, 0))
|
| 61 |
+
|
| 62 |
+
un_min = -1.0
|
| 63 |
+
un_max = 1.0
|
| 64 |
+
result = np.clip(result, un_min, un_max)
|
| 65 |
+
result = (result - un_min) / (un_max - un_min)
|
| 66 |
+
|
| 67 |
+
result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
|
| 68 |
+
result = (result * 255.0).round()
|
| 69 |
+
scale_factor = int(result.shape[1] / input_size)
|
| 70 |
+
return result.astype(np.uint8), scale_factor
|
| 71 |
+
|
| 72 |
+
def Release(self):
|
| 73 |
+
del self.model_codeformer
|
| 74 |
+
self.model_codeformer = None
|
| 75 |
+
del self.io_binding
|
| 76 |
+
self.io_binding = None
|
processors/Enhance_DMDNet.py
ADDED
|
@@ -0,0 +1,1425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, List, Callable
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch.nn.utils.spectral_norm as SpectralNorm
|
| 8 |
+
import threading
|
| 9 |
+
from torchvision.ops import roi_align
|
| 10 |
+
|
| 11 |
+
from math import sqrt
|
| 12 |
+
|
| 13 |
+
from torchvision.transforms.functional import normalize
|
| 14 |
+
|
| 15 |
+
from roop.typing import Face, Frame, FaceSet
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
THREAD_LOCK_DMDNET = threading.Lock()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Enhance_DMDNet:
|
| 22 |
+
plugin_options: dict = None
|
| 23 |
+
model_dmdnet = None
|
| 24 |
+
torchdevice = None
|
| 25 |
+
|
| 26 |
+
processorname = "dmdnet"
|
| 27 |
+
type = "enhance"
|
| 28 |
+
|
| 29 |
+
def Initialize(self, plugin_options: dict):
|
| 30 |
+
if self.plugin_options is not None:
|
| 31 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
| 32 |
+
self.Release()
|
| 33 |
+
|
| 34 |
+
self.plugin_options = plugin_options
|
| 35 |
+
if self.model_dmdnet is None:
|
| 36 |
+
self.model_dmdnet = self.create(self.plugin_options["devicename"])
|
| 37 |
+
|
| 38 |
+
# temp_frame already cropped+aligned, bbox not
|
| 39 |
+
def Run(
|
| 40 |
+
self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame
|
| 41 |
+
) -> Frame:
|
| 42 |
+
input_size = temp_frame.shape[1]
|
| 43 |
+
|
| 44 |
+
result = self.enhance_face(source_faceset, temp_frame, target_face)
|
| 45 |
+
scale_factor = int(result.shape[1] / input_size)
|
| 46 |
+
return result.astype(np.uint8), scale_factor
|
| 47 |
+
|
| 48 |
+
def Release(self):
|
| 49 |
+
self.model_dmdnet = None
|
| 50 |
+
|
| 51 |
+
# https://stackoverflow.com/a/67174339
|
| 52 |
+
def landmarks106_to_68(self, pt106):
|
| 53 |
+
map106to68 = [
|
| 54 |
+
1,
|
| 55 |
+
10,
|
| 56 |
+
12,
|
| 57 |
+
14,
|
| 58 |
+
16,
|
| 59 |
+
3,
|
| 60 |
+
5,
|
| 61 |
+
7,
|
| 62 |
+
0,
|
| 63 |
+
23,
|
| 64 |
+
21,
|
| 65 |
+
19,
|
| 66 |
+
32,
|
| 67 |
+
30,
|
| 68 |
+
28,
|
| 69 |
+
26,
|
| 70 |
+
17,
|
| 71 |
+
43,
|
| 72 |
+
48,
|
| 73 |
+
49,
|
| 74 |
+
51,
|
| 75 |
+
50,
|
| 76 |
+
102,
|
| 77 |
+
103,
|
| 78 |
+
104,
|
| 79 |
+
105,
|
| 80 |
+
101,
|
| 81 |
+
72,
|
| 82 |
+
73,
|
| 83 |
+
74,
|
| 84 |
+
86,
|
| 85 |
+
78,
|
| 86 |
+
79,
|
| 87 |
+
80,
|
| 88 |
+
85,
|
| 89 |
+
84,
|
| 90 |
+
35,
|
| 91 |
+
41,
|
| 92 |
+
42,
|
| 93 |
+
39,
|
| 94 |
+
37,
|
| 95 |
+
36,
|
| 96 |
+
89,
|
| 97 |
+
95,
|
| 98 |
+
96,
|
| 99 |
+
93,
|
| 100 |
+
91,
|
| 101 |
+
90,
|
| 102 |
+
52,
|
| 103 |
+
64,
|
| 104 |
+
63,
|
| 105 |
+
71,
|
| 106 |
+
67,
|
| 107 |
+
68,
|
| 108 |
+
61,
|
| 109 |
+
58,
|
| 110 |
+
59,
|
| 111 |
+
53,
|
| 112 |
+
56,
|
| 113 |
+
55,
|
| 114 |
+
65,
|
| 115 |
+
66,
|
| 116 |
+
62,
|
| 117 |
+
70,
|
| 118 |
+
69,
|
| 119 |
+
57,
|
| 120 |
+
60,
|
| 121 |
+
54,
|
| 122 |
+
]
|
| 123 |
+
|
| 124 |
+
pt68 = []
|
| 125 |
+
for i in range(68):
|
| 126 |
+
index = map106to68[i]
|
| 127 |
+
pt68.append(pt106[index])
|
| 128 |
+
return pt68
|
| 129 |
+
|
| 130 |
+
def check_bbox(self, imgs, boxes):
|
| 131 |
+
boxes = boxes.view(-1, 4, 4)
|
| 132 |
+
colors = [(0, 255, 0), (0, 255, 0), (255, 255, 0), (255, 0, 0)]
|
| 133 |
+
i = 0
|
| 134 |
+
for img, box in zip(imgs, boxes):
|
| 135 |
+
img = (img + 1) / 2 * 255
|
| 136 |
+
img2 = img.permute(1, 2, 0).float().cpu().flip(2).numpy().copy()
|
| 137 |
+
for idx, point in enumerate(box):
|
| 138 |
+
cv2.rectangle(
|
| 139 |
+
img2,
|
| 140 |
+
(int(point[0]), int(point[1])),
|
| 141 |
+
(int(point[2]), int(point[3])),
|
| 142 |
+
color=colors[idx],
|
| 143 |
+
thickness=2,
|
| 144 |
+
)
|
| 145 |
+
cv2.imwrite("dmdnet_{:02d}.png".format(i), img2)
|
| 146 |
+
i += 1
|
| 147 |
+
|
| 148 |
+
def trans_points2d(self, pts, M):
|
| 149 |
+
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
|
| 150 |
+
for i in range(pts.shape[0]):
|
| 151 |
+
pt = pts[i]
|
| 152 |
+
new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32)
|
| 153 |
+
new_pt = np.dot(M, new_pt)
|
| 154 |
+
new_pts[i] = new_pt[0:2]
|
| 155 |
+
|
| 156 |
+
return new_pts
|
| 157 |
+
|
| 158 |
+
def enhance_face(self, ref_faceset: FaceSet, temp_frame, face: Face):
|
| 159 |
+
# preprocess
|
| 160 |
+
start_x, start_y, end_x, end_y = map(int, face["bbox"])
|
| 161 |
+
lm106 = face.landmark_2d_106
|
| 162 |
+
lq_landmarks = np.asarray(self.landmarks106_to_68(lm106))
|
| 163 |
+
|
| 164 |
+
if temp_frame.shape[0] != 512 or temp_frame.shape[1] != 512:
|
| 165 |
+
# scale to 512x512
|
| 166 |
+
scale_factor = 512 / temp_frame.shape[1]
|
| 167 |
+
|
| 168 |
+
M = face.matrix * scale_factor
|
| 169 |
+
|
| 170 |
+
lq_landmarks = self.trans_points2d(lq_landmarks, M)
|
| 171 |
+
temp_frame = cv2.resize(
|
| 172 |
+
temp_frame, (512, 512), interpolation=cv2.INTER_AREA
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
if temp_frame.ndim == 2:
|
| 176 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG
|
| 177 |
+
# else:
|
| 178 |
+
# temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB
|
| 179 |
+
|
| 180 |
+
lq = read_img_tensor(temp_frame)
|
| 181 |
+
|
| 182 |
+
LQLocs = get_component_location(lq_landmarks)
|
| 183 |
+
# self.check_bbox(lq, LQLocs.unsqueeze(0))
|
| 184 |
+
|
| 185 |
+
# specific, change 1000 to 1 to activate
|
| 186 |
+
if len(ref_faceset.faces) > 1:
|
| 187 |
+
SpecificImgs = []
|
| 188 |
+
SpecificLocs = []
|
| 189 |
+
for i, face in enumerate(ref_faceset.faces):
|
| 190 |
+
lm106 = face.landmark_2d_106
|
| 191 |
+
lq_landmarks = np.asarray(self.landmarks106_to_68(lm106))
|
| 192 |
+
ref_image = ref_faceset.ref_images[i]
|
| 193 |
+
if ref_image.shape[0] != 512 or ref_image.shape[1] != 512:
|
| 194 |
+
# scale to 512x512
|
| 195 |
+
scale_factor = 512 / ref_image.shape[1]
|
| 196 |
+
|
| 197 |
+
M = face.matrix * scale_factor
|
| 198 |
+
|
| 199 |
+
lq_landmarks = self.trans_points2d(lq_landmarks, M)
|
| 200 |
+
ref_image = cv2.resize(
|
| 201 |
+
ref_image, (512, 512), interpolation=cv2.INTER_AREA
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
if ref_image.ndim == 2:
|
| 205 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG
|
| 206 |
+
# else:
|
| 207 |
+
# temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB
|
| 208 |
+
|
| 209 |
+
ref_tensor = read_img_tensor(ref_image)
|
| 210 |
+
ref_locs = get_component_location(lq_landmarks)
|
| 211 |
+
# self.check_bbox(ref_tensor, ref_locs.unsqueeze(0))
|
| 212 |
+
|
| 213 |
+
SpecificImgs.append(ref_tensor)
|
| 214 |
+
SpecificLocs.append(ref_locs.unsqueeze(0))
|
| 215 |
+
|
| 216 |
+
SpecificImgs = torch.cat(SpecificImgs, dim=0)
|
| 217 |
+
SpecificLocs = torch.cat(SpecificLocs, dim=0)
|
| 218 |
+
# check_bbox(SpecificImgs, SpecificLocs)
|
| 219 |
+
SpMem256, SpMem128, SpMem64 = (
|
| 220 |
+
self.model_dmdnet.generate_specific_dictionary(
|
| 221 |
+
sp_imgs=SpecificImgs.to(self.torchdevice), sp_locs=SpecificLocs
|
| 222 |
+
)
|
| 223 |
+
)
|
| 224 |
+
SpMem256Para = {}
|
| 225 |
+
SpMem128Para = {}
|
| 226 |
+
SpMem64Para = {}
|
| 227 |
+
for k, v in SpMem256.items():
|
| 228 |
+
SpMem256Para[k] = v
|
| 229 |
+
for k, v in SpMem128.items():
|
| 230 |
+
SpMem128Para[k] = v
|
| 231 |
+
for k, v in SpMem64.items():
|
| 232 |
+
SpMem64Para[k] = v
|
| 233 |
+
else:
|
| 234 |
+
# generic
|
| 235 |
+
SpMem256Para, SpMem128Para, SpMem64Para = None, None, None
|
| 236 |
+
|
| 237 |
+
with torch.no_grad():
|
| 238 |
+
with THREAD_LOCK_DMDNET:
|
| 239 |
+
try:
|
| 240 |
+
GenericResult, SpecificResult = self.model_dmdnet(
|
| 241 |
+
lq=lq.to(self.torchdevice),
|
| 242 |
+
loc=LQLocs.unsqueeze(0),
|
| 243 |
+
sp_256=SpMem256Para,
|
| 244 |
+
sp_128=SpMem128Para,
|
| 245 |
+
sp_64=SpMem64Para,
|
| 246 |
+
)
|
| 247 |
+
except Exception as e:
|
| 248 |
+
print(
|
| 249 |
+
f"Error {e} there may be something wrong with the detected component locations."
|
| 250 |
+
)
|
| 251 |
+
return temp_frame
|
| 252 |
+
|
| 253 |
+
if SpecificResult is not None:
|
| 254 |
+
save_specific = SpecificResult * 0.5 + 0.5
|
| 255 |
+
save_specific = (
|
| 256 |
+
save_specific.squeeze(0).permute(1, 2, 0).flip(2)
|
| 257 |
+
) # RGB->BGR
|
| 258 |
+
save_specific = np.clip(save_specific.float().cpu().numpy(), 0, 1) * 255.0
|
| 259 |
+
temp_frame = save_specific.astype("uint8")
|
| 260 |
+
if False:
|
| 261 |
+
save_generic = GenericResult * 0.5 + 0.5
|
| 262 |
+
save_generic = (
|
| 263 |
+
save_generic.squeeze(0).permute(1, 2, 0).flip(2)
|
| 264 |
+
) # RGB->BGR
|
| 265 |
+
save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0
|
| 266 |
+
check_lq = lq * 0.5 + 0.5
|
| 267 |
+
check_lq = check_lq.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
|
| 268 |
+
check_lq = np.clip(check_lq.float().cpu().numpy(), 0, 1) * 255.0
|
| 269 |
+
cv2.imwrite(
|
| 270 |
+
"dmdnet_comparison.png",
|
| 271 |
+
cv2.cvtColor(
|
| 272 |
+
np.hstack((check_lq, save_generic, save_specific)),
|
| 273 |
+
cv2.COLOR_RGB2BGR,
|
| 274 |
+
),
|
| 275 |
+
)
|
| 276 |
+
else:
|
| 277 |
+
save_generic = GenericResult * 0.5 + 0.5
|
| 278 |
+
save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
|
| 279 |
+
save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0
|
| 280 |
+
temp_frame = save_generic.astype("uint8")
|
| 281 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_RGB2BGR) # RGB
|
| 282 |
+
return temp_frame
|
| 283 |
+
|
| 284 |
+
def create(self, devicename):
|
| 285 |
+
self.torchdevice = torch.device(devicename)
|
| 286 |
+
model_dmdnet = DMDNet().to(self.torchdevice)
|
| 287 |
+
weights = torch.load("./models/DMDNet.pth", map_location=self.torchdevice)
|
| 288 |
+
model_dmdnet.load_state_dict(weights, strict=False)
|
| 289 |
+
|
| 290 |
+
model_dmdnet.eval()
|
| 291 |
+
num_params = 0
|
| 292 |
+
for param in model_dmdnet.parameters():
|
| 293 |
+
num_params += param.numel()
|
| 294 |
+
return model_dmdnet
|
| 295 |
+
|
| 296 |
+
# print('{:>8s} : {}'.format('Using device', device))
|
| 297 |
+
# print('{:>8s} : {:.2f}M'.format('Model params', num_params/1e6))
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def read_img_tensor(Img=None): # rgb -1~1
|
| 301 |
+
Img = Img.transpose((2, 0, 1)) / 255.0
|
| 302 |
+
Img = torch.from_numpy(Img).float()
|
| 303 |
+
normalize(Img, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
|
| 304 |
+
ImgTensor = Img.unsqueeze(0)
|
| 305 |
+
return ImgTensor
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def get_component_location(Landmarks, re_read=False):
|
| 309 |
+
if re_read:
|
| 310 |
+
ReadLandmark = []
|
| 311 |
+
with open(Landmarks, "r") as f:
|
| 312 |
+
for line in f:
|
| 313 |
+
tmp = [float(i) for i in line.split(" ") if i != "\n"]
|
| 314 |
+
ReadLandmark.append(tmp)
|
| 315 |
+
ReadLandmark = np.array(ReadLandmark) #
|
| 316 |
+
Landmarks = np.reshape(ReadLandmark, [-1, 2]) # 68*2
|
| 317 |
+
Map_LE_B = list(np.hstack((range(17, 22), range(36, 42))))
|
| 318 |
+
Map_RE_B = list(np.hstack((range(22, 27), range(42, 48))))
|
| 319 |
+
Map_LE = list(range(36, 42))
|
| 320 |
+
Map_RE = list(range(42, 48))
|
| 321 |
+
Map_NO = list(range(29, 36))
|
| 322 |
+
Map_MO = list(range(48, 68))
|
| 323 |
+
|
| 324 |
+
Landmarks[Landmarks > 504] = 504
|
| 325 |
+
Landmarks[Landmarks < 8] = 8
|
| 326 |
+
|
| 327 |
+
# left eye
|
| 328 |
+
Mean_LE = np.mean(Landmarks[Map_LE], 0)
|
| 329 |
+
L_LE1 = Mean_LE[1] - np.min(Landmarks[Map_LE_B, 1])
|
| 330 |
+
L_LE1 = L_LE1 * 1.3
|
| 331 |
+
L_LE2 = L_LE1 / 1.9
|
| 332 |
+
L_LE_xy = L_LE1 + L_LE2
|
| 333 |
+
L_LE_lt = [L_LE_xy / 2, L_LE1]
|
| 334 |
+
L_LE_rb = [L_LE_xy / 2, L_LE2]
|
| 335 |
+
Location_LE = np.hstack((Mean_LE - L_LE_lt + 1, Mean_LE + L_LE_rb)).astype(int)
|
| 336 |
+
|
| 337 |
+
# right eye
|
| 338 |
+
Mean_RE = np.mean(Landmarks[Map_RE], 0)
|
| 339 |
+
L_RE1 = Mean_RE[1] - np.min(Landmarks[Map_RE_B, 1])
|
| 340 |
+
L_RE1 = L_RE1 * 1.3
|
| 341 |
+
L_RE2 = L_RE1 / 1.9
|
| 342 |
+
L_RE_xy = L_RE1 + L_RE2
|
| 343 |
+
L_RE_lt = [L_RE_xy / 2, L_RE1]
|
| 344 |
+
L_RE_rb = [L_RE_xy / 2, L_RE2]
|
| 345 |
+
Location_RE = np.hstack((Mean_RE - L_RE_lt + 1, Mean_RE + L_RE_rb)).astype(int)
|
| 346 |
+
|
| 347 |
+
# nose
|
| 348 |
+
Mean_NO = np.mean(Landmarks[Map_NO], 0)
|
| 349 |
+
L_NO1 = (
|
| 350 |
+
np.max([Mean_NO[0] - Landmarks[31][0], Landmarks[35][0] - Mean_NO[0]])
|
| 351 |
+
) * 1.25
|
| 352 |
+
L_NO2 = (Landmarks[33][1] - Mean_NO[1]) * 1.1
|
| 353 |
+
L_NO_xy = L_NO1 * 2
|
| 354 |
+
L_NO_lt = [L_NO_xy / 2, L_NO_xy - L_NO2]
|
| 355 |
+
L_NO_rb = [L_NO_xy / 2, L_NO2]
|
| 356 |
+
Location_NO = np.hstack((Mean_NO - L_NO_lt + 1, Mean_NO + L_NO_rb)).astype(int)
|
| 357 |
+
|
| 358 |
+
# mouth
|
| 359 |
+
Mean_MO = np.mean(Landmarks[Map_MO], 0)
|
| 360 |
+
L_MO = (
|
| 361 |
+
np.max(
|
| 362 |
+
(
|
| 363 |
+
np.max(np.max(Landmarks[Map_MO], 0) - np.min(Landmarks[Map_MO], 0)) / 2,
|
| 364 |
+
16,
|
| 365 |
+
)
|
| 366 |
+
)
|
| 367 |
+
* 1.1
|
| 368 |
+
)
|
| 369 |
+
MO_O = Mean_MO - L_MO + 1
|
| 370 |
+
MO_T = Mean_MO + L_MO
|
| 371 |
+
MO_T[MO_T > 510] = 510
|
| 372 |
+
Location_MO = np.hstack((MO_O, MO_T)).astype(int)
|
| 373 |
+
return torch.cat(
|
| 374 |
+
[
|
| 375 |
+
torch.FloatTensor(Location_LE).unsqueeze(0),
|
| 376 |
+
torch.FloatTensor(Location_RE).unsqueeze(0),
|
| 377 |
+
torch.FloatTensor(Location_NO).unsqueeze(0),
|
| 378 |
+
torch.FloatTensor(Location_MO).unsqueeze(0),
|
| 379 |
+
],
|
| 380 |
+
dim=0,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def calc_mean_std_4D(feat, eps=1e-5):
|
| 385 |
+
# eps is a small value added to the variance to avoid divide-by-zero.
|
| 386 |
+
size = feat.size()
|
| 387 |
+
assert len(size) == 4
|
| 388 |
+
N, C = size[:2]
|
| 389 |
+
feat_var = feat.view(N, C, -1).var(dim=2) + eps
|
| 390 |
+
feat_std = feat_var.sqrt().view(N, C, 1, 1)
|
| 391 |
+
feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
|
| 392 |
+
return feat_mean, feat_std
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def adaptive_instance_normalization_4D(
|
| 396 |
+
content_feat, style_feat
|
| 397 |
+
): # content_feat is ref feature, style is degradate feature
|
| 398 |
+
size = content_feat.size()
|
| 399 |
+
style_mean, style_std = calc_mean_std_4D(style_feat)
|
| 400 |
+
content_mean, content_std = calc_mean_std_4D(content_feat)
|
| 401 |
+
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
|
| 402 |
+
size
|
| 403 |
+
)
|
| 404 |
+
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def convU(
|
| 408 |
+
in_channels,
|
| 409 |
+
out_channels,
|
| 410 |
+
conv_layer,
|
| 411 |
+
norm_layer,
|
| 412 |
+
kernel_size=3,
|
| 413 |
+
stride=1,
|
| 414 |
+
dilation=1,
|
| 415 |
+
bias=True,
|
| 416 |
+
):
|
| 417 |
+
return nn.Sequential(
|
| 418 |
+
SpectralNorm(
|
| 419 |
+
conv_layer(
|
| 420 |
+
in_channels,
|
| 421 |
+
out_channels,
|
| 422 |
+
kernel_size=kernel_size,
|
| 423 |
+
stride=stride,
|
| 424 |
+
dilation=dilation,
|
| 425 |
+
padding=((kernel_size - 1) // 2) * dilation,
|
| 426 |
+
bias=bias,
|
| 427 |
+
)
|
| 428 |
+
),
|
| 429 |
+
nn.LeakyReLU(0.2),
|
| 430 |
+
SpectralNorm(
|
| 431 |
+
conv_layer(
|
| 432 |
+
out_channels,
|
| 433 |
+
out_channels,
|
| 434 |
+
kernel_size=kernel_size,
|
| 435 |
+
stride=stride,
|
| 436 |
+
dilation=dilation,
|
| 437 |
+
padding=((kernel_size - 1) // 2) * dilation,
|
| 438 |
+
bias=bias,
|
| 439 |
+
)
|
| 440 |
+
),
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
class MSDilateBlock(nn.Module):
|
| 445 |
+
def __init__(
|
| 446 |
+
self,
|
| 447 |
+
in_channels,
|
| 448 |
+
conv_layer=nn.Conv2d,
|
| 449 |
+
norm_layer=nn.BatchNorm2d,
|
| 450 |
+
kernel_size=3,
|
| 451 |
+
dilation=[1, 1, 1, 1],
|
| 452 |
+
bias=True,
|
| 453 |
+
):
|
| 454 |
+
super(MSDilateBlock, self).__init__()
|
| 455 |
+
self.conv1 = convU(
|
| 456 |
+
in_channels,
|
| 457 |
+
in_channels,
|
| 458 |
+
conv_layer,
|
| 459 |
+
norm_layer,
|
| 460 |
+
kernel_size,
|
| 461 |
+
dilation=dilation[0],
|
| 462 |
+
bias=bias,
|
| 463 |
+
)
|
| 464 |
+
self.conv2 = convU(
|
| 465 |
+
in_channels,
|
| 466 |
+
in_channels,
|
| 467 |
+
conv_layer,
|
| 468 |
+
norm_layer,
|
| 469 |
+
kernel_size,
|
| 470 |
+
dilation=dilation[1],
|
| 471 |
+
bias=bias,
|
| 472 |
+
)
|
| 473 |
+
self.conv3 = convU(
|
| 474 |
+
in_channels,
|
| 475 |
+
in_channels,
|
| 476 |
+
conv_layer,
|
| 477 |
+
norm_layer,
|
| 478 |
+
kernel_size,
|
| 479 |
+
dilation=dilation[2],
|
| 480 |
+
bias=bias,
|
| 481 |
+
)
|
| 482 |
+
self.conv4 = convU(
|
| 483 |
+
in_channels,
|
| 484 |
+
in_channels,
|
| 485 |
+
conv_layer,
|
| 486 |
+
norm_layer,
|
| 487 |
+
kernel_size,
|
| 488 |
+
dilation=dilation[3],
|
| 489 |
+
bias=bias,
|
| 490 |
+
)
|
| 491 |
+
self.convi = SpectralNorm(
|
| 492 |
+
conv_layer(
|
| 493 |
+
in_channels * 4,
|
| 494 |
+
in_channels,
|
| 495 |
+
kernel_size=kernel_size,
|
| 496 |
+
stride=1,
|
| 497 |
+
padding=(kernel_size - 1) // 2,
|
| 498 |
+
bias=bias,
|
| 499 |
+
)
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
def forward(self, x):
|
| 503 |
+
conv1 = self.conv1(x)
|
| 504 |
+
conv2 = self.conv2(x)
|
| 505 |
+
conv3 = self.conv3(x)
|
| 506 |
+
conv4 = self.conv4(x)
|
| 507 |
+
cat = torch.cat([conv1, conv2, conv3, conv4], 1)
|
| 508 |
+
out = self.convi(cat) + x
|
| 509 |
+
return out
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
class AdaptiveInstanceNorm(nn.Module):
|
| 513 |
+
def __init__(self, in_channel):
|
| 514 |
+
super().__init__()
|
| 515 |
+
self.norm = nn.InstanceNorm2d(in_channel)
|
| 516 |
+
|
| 517 |
+
def forward(self, input, style):
|
| 518 |
+
style_mean, style_std = calc_mean_std_4D(style)
|
| 519 |
+
out = self.norm(input)
|
| 520 |
+
size = input.size()
|
| 521 |
+
out = style_std.expand(size) * out + style_mean.expand(size)
|
| 522 |
+
return out
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
class NoiseInjection(nn.Module):
|
| 526 |
+
def __init__(self, channel):
|
| 527 |
+
super().__init__()
|
| 528 |
+
self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
|
| 529 |
+
|
| 530 |
+
def forward(self, image, noise):
|
| 531 |
+
if noise is None:
|
| 532 |
+
b, c, h, w = image.shape
|
| 533 |
+
noise = image.new_empty(b, 1, h, w).normal_()
|
| 534 |
+
return image + self.weight * noise
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
class StyledUpBlock(nn.Module):
|
| 538 |
+
def __init__(
|
| 539 |
+
self,
|
| 540 |
+
in_channel,
|
| 541 |
+
out_channel,
|
| 542 |
+
kernel_size=3,
|
| 543 |
+
padding=1,
|
| 544 |
+
upsample=False,
|
| 545 |
+
noise_inject=False,
|
| 546 |
+
):
|
| 547 |
+
super().__init__()
|
| 548 |
+
|
| 549 |
+
self.noise_inject = noise_inject
|
| 550 |
+
if upsample:
|
| 551 |
+
self.conv1 = nn.Sequential(
|
| 552 |
+
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
|
| 553 |
+
SpectralNorm(
|
| 554 |
+
nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)
|
| 555 |
+
),
|
| 556 |
+
nn.LeakyReLU(0.2),
|
| 557 |
+
)
|
| 558 |
+
else:
|
| 559 |
+
self.conv1 = nn.Sequential(
|
| 560 |
+
SpectralNorm(
|
| 561 |
+
nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)
|
| 562 |
+
),
|
| 563 |
+
nn.LeakyReLU(0.2),
|
| 564 |
+
SpectralNorm(
|
| 565 |
+
nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)
|
| 566 |
+
),
|
| 567 |
+
)
|
| 568 |
+
self.convup = nn.Sequential(
|
| 569 |
+
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
|
| 570 |
+
SpectralNorm(
|
| 571 |
+
nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)
|
| 572 |
+
),
|
| 573 |
+
nn.LeakyReLU(0.2),
|
| 574 |
+
SpectralNorm(
|
| 575 |
+
nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)
|
| 576 |
+
),
|
| 577 |
+
)
|
| 578 |
+
if self.noise_inject:
|
| 579 |
+
self.noise1 = NoiseInjection(out_channel)
|
| 580 |
+
|
| 581 |
+
self.lrelu1 = nn.LeakyReLU(0.2)
|
| 582 |
+
|
| 583 |
+
self.ScaleModel1 = nn.Sequential(
|
| 584 |
+
SpectralNorm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)),
|
| 585 |
+
nn.LeakyReLU(0.2),
|
| 586 |
+
SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
|
| 587 |
+
)
|
| 588 |
+
self.ShiftModel1 = nn.Sequential(
|
| 589 |
+
SpectralNorm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)),
|
| 590 |
+
nn.LeakyReLU(0.2),
|
| 591 |
+
SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
def forward(self, input, style):
|
| 595 |
+
out = self.conv1(input)
|
| 596 |
+
out = self.lrelu1(out)
|
| 597 |
+
Shift1 = self.ShiftModel1(style)
|
| 598 |
+
Scale1 = self.ScaleModel1(style)
|
| 599 |
+
out = out * Scale1 + Shift1
|
| 600 |
+
if self.noise_inject:
|
| 601 |
+
out = self.noise1(out, noise=None)
|
| 602 |
+
outup = self.convup(out)
|
| 603 |
+
return outup
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
####################################################################
|
| 607 |
+
###############Face Dictionary Generator
|
| 608 |
+
####################################################################
|
| 609 |
+
def AttentionBlock(in_channel):
|
| 610 |
+
return nn.Sequential(
|
| 611 |
+
SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
|
| 612 |
+
nn.LeakyReLU(0.2),
|
| 613 |
+
SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
class DilateResBlock(nn.Module):
|
| 618 |
+
def __init__(self, dim, dilation=[5, 3]):
|
| 619 |
+
super(DilateResBlock, self).__init__()
|
| 620 |
+
self.Res = nn.Sequential(
|
| 621 |
+
SpectralNorm(
|
| 622 |
+
nn.Conv2d(dim, dim, 3, 1, ((3 - 1) // 2) * dilation[0], dilation[0])
|
| 623 |
+
),
|
| 624 |
+
nn.LeakyReLU(0.2),
|
| 625 |
+
SpectralNorm(
|
| 626 |
+
nn.Conv2d(dim, dim, 3, 1, ((3 - 1) // 2) * dilation[1], dilation[1])
|
| 627 |
+
),
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
def forward(self, x):
|
| 631 |
+
out = x + self.Res(x)
|
| 632 |
+
return out
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
class KeyValue(nn.Module):
|
| 636 |
+
def __init__(self, indim, keydim, valdim):
|
| 637 |
+
super(KeyValue, self).__init__()
|
| 638 |
+
self.Key = nn.Sequential(
|
| 639 |
+
SpectralNorm(
|
| 640 |
+
nn.Conv2d(indim, keydim, kernel_size=(3, 3), padding=(1, 1), stride=1)
|
| 641 |
+
),
|
| 642 |
+
nn.LeakyReLU(0.2),
|
| 643 |
+
SpectralNorm(
|
| 644 |
+
nn.Conv2d(keydim, keydim, kernel_size=(3, 3), padding=(1, 1), stride=1)
|
| 645 |
+
),
|
| 646 |
+
)
|
| 647 |
+
self.Value = nn.Sequential(
|
| 648 |
+
SpectralNorm(
|
| 649 |
+
nn.Conv2d(indim, valdim, kernel_size=(3, 3), padding=(1, 1), stride=1)
|
| 650 |
+
),
|
| 651 |
+
nn.LeakyReLU(0.2),
|
| 652 |
+
SpectralNorm(
|
| 653 |
+
nn.Conv2d(valdim, valdim, kernel_size=(3, 3), padding=(1, 1), stride=1)
|
| 654 |
+
),
|
| 655 |
+
)
|
| 656 |
+
|
| 657 |
+
def forward(self, x):
|
| 658 |
+
return self.Key(x), self.Value(x)
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
class MaskAttention(nn.Module):
|
| 662 |
+
def __init__(self, indim):
|
| 663 |
+
super(MaskAttention, self).__init__()
|
| 664 |
+
self.conv1 = nn.Sequential(
|
| 665 |
+
SpectralNorm(
|
| 666 |
+
nn.Conv2d(
|
| 667 |
+
indim, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1
|
| 668 |
+
)
|
| 669 |
+
),
|
| 670 |
+
nn.LeakyReLU(0.2),
|
| 671 |
+
SpectralNorm(
|
| 672 |
+
nn.Conv2d(
|
| 673 |
+
indim // 3, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1
|
| 674 |
+
)
|
| 675 |
+
),
|
| 676 |
+
)
|
| 677 |
+
self.conv2 = nn.Sequential(
|
| 678 |
+
SpectralNorm(
|
| 679 |
+
nn.Conv2d(
|
| 680 |
+
indim, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1
|
| 681 |
+
)
|
| 682 |
+
),
|
| 683 |
+
nn.LeakyReLU(0.2),
|
| 684 |
+
SpectralNorm(
|
| 685 |
+
nn.Conv2d(
|
| 686 |
+
indim // 3, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1
|
| 687 |
+
)
|
| 688 |
+
),
|
| 689 |
+
)
|
| 690 |
+
self.conv3 = nn.Sequential(
|
| 691 |
+
SpectralNorm(
|
| 692 |
+
nn.Conv2d(
|
| 693 |
+
indim, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1
|
| 694 |
+
)
|
| 695 |
+
),
|
| 696 |
+
nn.LeakyReLU(0.2),
|
| 697 |
+
SpectralNorm(
|
| 698 |
+
nn.Conv2d(
|
| 699 |
+
indim // 3, indim // 3, kernel_size=(3, 3), padding=(1, 1), stride=1
|
| 700 |
+
)
|
| 701 |
+
),
|
| 702 |
+
)
|
| 703 |
+
self.convCat = nn.Sequential(
|
| 704 |
+
SpectralNorm(
|
| 705 |
+
nn.Conv2d(
|
| 706 |
+
indim // 3 * 3, indim, kernel_size=(3, 3), padding=(1, 1), stride=1
|
| 707 |
+
)
|
| 708 |
+
),
|
| 709 |
+
nn.LeakyReLU(0.2),
|
| 710 |
+
SpectralNorm(
|
| 711 |
+
nn.Conv2d(indim, indim, kernel_size=(3, 3), padding=(1, 1), stride=1)
|
| 712 |
+
),
|
| 713 |
+
)
|
| 714 |
+
|
| 715 |
+
def forward(self, x, y, z):
|
| 716 |
+
c1 = self.conv1(x)
|
| 717 |
+
c2 = self.conv2(y)
|
| 718 |
+
c3 = self.conv3(z)
|
| 719 |
+
return self.convCat(torch.cat([c1, c2, c3], dim=1))
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
class Query(nn.Module):
|
| 723 |
+
def __init__(self, indim, quedim):
|
| 724 |
+
super(Query, self).__init__()
|
| 725 |
+
self.Query = nn.Sequential(
|
| 726 |
+
SpectralNorm(
|
| 727 |
+
nn.Conv2d(indim, quedim, kernel_size=(3, 3), padding=(1, 1), stride=1)
|
| 728 |
+
),
|
| 729 |
+
nn.LeakyReLU(0.2),
|
| 730 |
+
SpectralNorm(
|
| 731 |
+
nn.Conv2d(quedim, quedim, kernel_size=(3, 3), padding=(1, 1), stride=1)
|
| 732 |
+
),
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
def forward(self, x):
|
| 736 |
+
return self.Query(x)
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
def roi_align_self(input, location, target_size):
|
| 740 |
+
test = (target_size.item(), target_size.item())
|
| 741 |
+
return torch.cat(
|
| 742 |
+
[
|
| 743 |
+
F.interpolate(
|
| 744 |
+
input[
|
| 745 |
+
i : i + 1,
|
| 746 |
+
:,
|
| 747 |
+
location[i, 1] : location[i, 3],
|
| 748 |
+
location[i, 0] : location[i, 2],
|
| 749 |
+
],
|
| 750 |
+
test,
|
| 751 |
+
mode="bilinear",
|
| 752 |
+
align_corners=False,
|
| 753 |
+
)
|
| 754 |
+
for i in range(input.size(0))
|
| 755 |
+
],
|
| 756 |
+
0,
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
class FeatureExtractor(nn.Module):
|
| 761 |
+
def __init__(self, ngf=64, key_scale=4): #
|
| 762 |
+
super().__init__()
|
| 763 |
+
|
| 764 |
+
self.key_scale = 4
|
| 765 |
+
self.part_sizes = np.array([80, 80, 50, 110]) #
|
| 766 |
+
self.feature_sizes = np.array([256, 128, 64]) #
|
| 767 |
+
|
| 768 |
+
self.conv1 = nn.Sequential(
|
| 769 |
+
SpectralNorm(nn.Conv2d(3, ngf, 3, 2, 1)),
|
| 770 |
+
nn.LeakyReLU(0.2),
|
| 771 |
+
SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
|
| 772 |
+
)
|
| 773 |
+
self.conv2 = nn.Sequential(
|
| 774 |
+
SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
|
| 775 |
+
nn.LeakyReLU(0.2),
|
| 776 |
+
SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
|
| 777 |
+
)
|
| 778 |
+
self.res1 = DilateResBlock(ngf, [5, 3])
|
| 779 |
+
self.res2 = DilateResBlock(ngf, [5, 3])
|
| 780 |
+
|
| 781 |
+
self.conv3 = nn.Sequential(
|
| 782 |
+
SpectralNorm(nn.Conv2d(ngf, ngf * 2, 3, 2, 1)),
|
| 783 |
+
nn.LeakyReLU(0.2),
|
| 784 |
+
SpectralNorm(nn.Conv2d(ngf * 2, ngf * 2, 3, 1, 1)),
|
| 785 |
+
)
|
| 786 |
+
self.conv4 = nn.Sequential(
|
| 787 |
+
SpectralNorm(nn.Conv2d(ngf * 2, ngf * 2, 3, 1, 1)),
|
| 788 |
+
nn.LeakyReLU(0.2),
|
| 789 |
+
SpectralNorm(nn.Conv2d(ngf * 2, ngf * 2, 3, 1, 1)),
|
| 790 |
+
)
|
| 791 |
+
self.res3 = DilateResBlock(ngf * 2, [3, 1])
|
| 792 |
+
self.res4 = DilateResBlock(ngf * 2, [3, 1])
|
| 793 |
+
|
| 794 |
+
self.conv5 = nn.Sequential(
|
| 795 |
+
SpectralNorm(nn.Conv2d(ngf * 2, ngf * 4, 3, 2, 1)),
|
| 796 |
+
nn.LeakyReLU(0.2),
|
| 797 |
+
SpectralNorm(nn.Conv2d(ngf * 4, ngf * 4, 3, 1, 1)),
|
| 798 |
+
)
|
| 799 |
+
self.conv6 = nn.Sequential(
|
| 800 |
+
SpectralNorm(nn.Conv2d(ngf * 4, ngf * 4, 3, 1, 1)),
|
| 801 |
+
nn.LeakyReLU(0.2),
|
| 802 |
+
SpectralNorm(nn.Conv2d(ngf * 4, ngf * 4, 3, 1, 1)),
|
| 803 |
+
)
|
| 804 |
+
self.res5 = DilateResBlock(ngf * 4, [1, 1])
|
| 805 |
+
self.res6 = DilateResBlock(ngf * 4, [1, 1])
|
| 806 |
+
|
| 807 |
+
self.LE_256_Q = Query(ngf, ngf // self.key_scale)
|
| 808 |
+
self.RE_256_Q = Query(ngf, ngf // self.key_scale)
|
| 809 |
+
self.MO_256_Q = Query(ngf, ngf // self.key_scale)
|
| 810 |
+
self.LE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
|
| 811 |
+
self.RE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
|
| 812 |
+
self.MO_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
|
| 813 |
+
self.LE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
|
| 814 |
+
self.RE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
|
| 815 |
+
self.MO_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
|
| 816 |
+
|
| 817 |
+
def forward(self, img, locs):
|
| 818 |
+
le_location = locs[:, 0, :].int().cpu().numpy()
|
| 819 |
+
re_location = locs[:, 1, :].int().cpu().numpy()
|
| 820 |
+
no_location = locs[:, 2, :].int().cpu().numpy()
|
| 821 |
+
mo_location = locs[:, 3, :].int().cpu().numpy()
|
| 822 |
+
|
| 823 |
+
f1_0 = self.conv1(img)
|
| 824 |
+
f1_1 = self.res1(f1_0)
|
| 825 |
+
f2_0 = self.conv2(f1_1)
|
| 826 |
+
f2_1 = self.res2(f2_0)
|
| 827 |
+
|
| 828 |
+
f3_0 = self.conv3(f2_1)
|
| 829 |
+
f3_1 = self.res3(f3_0)
|
| 830 |
+
f4_0 = self.conv4(f3_1)
|
| 831 |
+
f4_1 = self.res4(f4_0)
|
| 832 |
+
|
| 833 |
+
f5_0 = self.conv5(f4_1)
|
| 834 |
+
f5_1 = self.res5(f5_0)
|
| 835 |
+
f6_0 = self.conv6(f5_1)
|
| 836 |
+
f6_1 = self.res6(f6_0)
|
| 837 |
+
|
| 838 |
+
####ROI Align
|
| 839 |
+
le_part_256 = roi_align_self(
|
| 840 |
+
f2_1.clone(), le_location // 2, self.part_sizes[0] // 2
|
| 841 |
+
)
|
| 842 |
+
re_part_256 = roi_align_self(
|
| 843 |
+
f2_1.clone(), re_location // 2, self.part_sizes[1] // 2
|
| 844 |
+
)
|
| 845 |
+
mo_part_256 = roi_align_self(
|
| 846 |
+
f2_1.clone(), mo_location // 2, self.part_sizes[3] // 2
|
| 847 |
+
)
|
| 848 |
+
|
| 849 |
+
le_part_128 = roi_align_self(
|
| 850 |
+
f4_1.clone(), le_location // 4, self.part_sizes[0] // 4
|
| 851 |
+
)
|
| 852 |
+
re_part_128 = roi_align_self(
|
| 853 |
+
f4_1.clone(), re_location // 4, self.part_sizes[1] // 4
|
| 854 |
+
)
|
| 855 |
+
mo_part_128 = roi_align_self(
|
| 856 |
+
f4_1.clone(), mo_location // 4, self.part_sizes[3] // 4
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
le_part_64 = roi_align_self(
|
| 860 |
+
f6_1.clone(), le_location // 8, self.part_sizes[0] // 8
|
| 861 |
+
)
|
| 862 |
+
re_part_64 = roi_align_self(
|
| 863 |
+
f6_1.clone(), re_location // 8, self.part_sizes[1] // 8
|
| 864 |
+
)
|
| 865 |
+
mo_part_64 = roi_align_self(
|
| 866 |
+
f6_1.clone(), mo_location // 8, self.part_sizes[3] // 8
|
| 867 |
+
)
|
| 868 |
+
|
| 869 |
+
le_256_q = self.LE_256_Q(le_part_256)
|
| 870 |
+
re_256_q = self.RE_256_Q(re_part_256)
|
| 871 |
+
mo_256_q = self.MO_256_Q(mo_part_256)
|
| 872 |
+
|
| 873 |
+
le_128_q = self.LE_128_Q(le_part_128)
|
| 874 |
+
re_128_q = self.RE_128_Q(re_part_128)
|
| 875 |
+
mo_128_q = self.MO_128_Q(mo_part_128)
|
| 876 |
+
|
| 877 |
+
le_64_q = self.LE_64_Q(le_part_64)
|
| 878 |
+
re_64_q = self.RE_64_Q(re_part_64)
|
| 879 |
+
mo_64_q = self.MO_64_Q(mo_part_64)
|
| 880 |
+
|
| 881 |
+
return {
|
| 882 |
+
"f256": f2_1,
|
| 883 |
+
"f128": f4_1,
|
| 884 |
+
"f64": f6_1,
|
| 885 |
+
"le256": le_part_256,
|
| 886 |
+
"re256": re_part_256,
|
| 887 |
+
"mo256": mo_part_256,
|
| 888 |
+
"le128": le_part_128,
|
| 889 |
+
"re128": re_part_128,
|
| 890 |
+
"mo128": mo_part_128,
|
| 891 |
+
"le64": le_part_64,
|
| 892 |
+
"re64": re_part_64,
|
| 893 |
+
"mo64": mo_part_64,
|
| 894 |
+
"le_256_q": le_256_q,
|
| 895 |
+
"re_256_q": re_256_q,
|
| 896 |
+
"mo_256_q": mo_256_q,
|
| 897 |
+
"le_128_q": le_128_q,
|
| 898 |
+
"re_128_q": re_128_q,
|
| 899 |
+
"mo_128_q": mo_128_q,
|
| 900 |
+
"le_64_q": le_64_q,
|
| 901 |
+
"re_64_q": re_64_q,
|
| 902 |
+
"mo_64_q": mo_64_q,
|
| 903 |
+
}
|
| 904 |
+
|
| 905 |
+
|
| 906 |
+
class DMDNet(nn.Module):
|
| 907 |
+
def __init__(self, ngf=64, banks_num=128):
|
| 908 |
+
super().__init__()
|
| 909 |
+
self.part_sizes = np.array([80, 80, 50, 110]) # size for 512
|
| 910 |
+
self.feature_sizes = np.array([256, 128, 64]) # size for 512
|
| 911 |
+
|
| 912 |
+
self.banks_num = banks_num
|
| 913 |
+
self.key_scale = 4
|
| 914 |
+
|
| 915 |
+
self.E_lq = FeatureExtractor(key_scale=self.key_scale)
|
| 916 |
+
self.E_hq = FeatureExtractor(key_scale=self.key_scale)
|
| 917 |
+
|
| 918 |
+
self.LE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
|
| 919 |
+
self.RE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
|
| 920 |
+
self.MO_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
|
| 921 |
+
|
| 922 |
+
self.LE_128_KV = KeyValue(ngf * 2, ngf * 2 // self.key_scale, ngf * 2)
|
| 923 |
+
self.RE_128_KV = KeyValue(ngf * 2, ngf * 2 // self.key_scale, ngf * 2)
|
| 924 |
+
self.MO_128_KV = KeyValue(ngf * 2, ngf * 2 // self.key_scale, ngf * 2)
|
| 925 |
+
|
| 926 |
+
self.LE_64_KV = KeyValue(ngf * 4, ngf * 4 // self.key_scale, ngf * 4)
|
| 927 |
+
self.RE_64_KV = KeyValue(ngf * 4, ngf * 4 // self.key_scale, ngf * 4)
|
| 928 |
+
self.MO_64_KV = KeyValue(ngf * 4, ngf * 4 // self.key_scale, ngf * 4)
|
| 929 |
+
|
| 930 |
+
self.LE_256_Attention = AttentionBlock(64)
|
| 931 |
+
self.RE_256_Attention = AttentionBlock(64)
|
| 932 |
+
self.MO_256_Attention = AttentionBlock(64)
|
| 933 |
+
|
| 934 |
+
self.LE_128_Attention = AttentionBlock(128)
|
| 935 |
+
self.RE_128_Attention = AttentionBlock(128)
|
| 936 |
+
self.MO_128_Attention = AttentionBlock(128)
|
| 937 |
+
|
| 938 |
+
self.LE_64_Attention = AttentionBlock(256)
|
| 939 |
+
self.RE_64_Attention = AttentionBlock(256)
|
| 940 |
+
self.MO_64_Attention = AttentionBlock(256)
|
| 941 |
+
|
| 942 |
+
self.LE_256_Mask = MaskAttention(64)
|
| 943 |
+
self.RE_256_Mask = MaskAttention(64)
|
| 944 |
+
self.MO_256_Mask = MaskAttention(64)
|
| 945 |
+
|
| 946 |
+
self.LE_128_Mask = MaskAttention(128)
|
| 947 |
+
self.RE_128_Mask = MaskAttention(128)
|
| 948 |
+
self.MO_128_Mask = MaskAttention(128)
|
| 949 |
+
|
| 950 |
+
self.LE_64_Mask = MaskAttention(256)
|
| 951 |
+
self.RE_64_Mask = MaskAttention(256)
|
| 952 |
+
self.MO_64_Mask = MaskAttention(256)
|
| 953 |
+
|
| 954 |
+
self.MSDilate = MSDilateBlock(ngf * 4, dilation=[4, 3, 2, 1])
|
| 955 |
+
|
| 956 |
+
self.up1 = StyledUpBlock(ngf * 4, ngf * 2, noise_inject=False) #
|
| 957 |
+
self.up2 = StyledUpBlock(ngf * 2, ngf, noise_inject=False) #
|
| 958 |
+
self.up3 = StyledUpBlock(ngf, ngf, noise_inject=False) #
|
| 959 |
+
self.up4 = nn.Sequential(
|
| 960 |
+
SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
|
| 961 |
+
nn.LeakyReLU(0.2),
|
| 962 |
+
UpResBlock(ngf),
|
| 963 |
+
UpResBlock(ngf),
|
| 964 |
+
SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)),
|
| 965 |
+
nn.Tanh(),
|
| 966 |
+
)
|
| 967 |
+
|
| 968 |
+
# define generic memory, revise register_buffer to register_parameter for backward update
|
| 969 |
+
self.register_buffer("le_256_mem_key", torch.randn(128, 16, 40, 40))
|
| 970 |
+
self.register_buffer("re_256_mem_key", torch.randn(128, 16, 40, 40))
|
| 971 |
+
self.register_buffer("mo_256_mem_key", torch.randn(128, 16, 55, 55))
|
| 972 |
+
self.register_buffer("le_256_mem_value", torch.randn(128, 64, 40, 40))
|
| 973 |
+
self.register_buffer("re_256_mem_value", torch.randn(128, 64, 40, 40))
|
| 974 |
+
self.register_buffer("mo_256_mem_value", torch.randn(128, 64, 55, 55))
|
| 975 |
+
|
| 976 |
+
self.register_buffer("le_128_mem_key", torch.randn(128, 32, 20, 20))
|
| 977 |
+
self.register_buffer("re_128_mem_key", torch.randn(128, 32, 20, 20))
|
| 978 |
+
self.register_buffer("mo_128_mem_key", torch.randn(128, 32, 27, 27))
|
| 979 |
+
self.register_buffer("le_128_mem_value", torch.randn(128, 128, 20, 20))
|
| 980 |
+
self.register_buffer("re_128_mem_value", torch.randn(128, 128, 20, 20))
|
| 981 |
+
self.register_buffer("mo_128_mem_value", torch.randn(128, 128, 27, 27))
|
| 982 |
+
|
| 983 |
+
self.register_buffer("le_64_mem_key", torch.randn(128, 64, 10, 10))
|
| 984 |
+
self.register_buffer("re_64_mem_key", torch.randn(128, 64, 10, 10))
|
| 985 |
+
self.register_buffer("mo_64_mem_key", torch.randn(128, 64, 13, 13))
|
| 986 |
+
self.register_buffer("le_64_mem_value", torch.randn(128, 256, 10, 10))
|
| 987 |
+
self.register_buffer("re_64_mem_value", torch.randn(128, 256, 10, 10))
|
| 988 |
+
self.register_buffer("mo_64_mem_value", torch.randn(128, 256, 13, 13))
|
| 989 |
+
|
| 990 |
+
def readMem(self, k, v, q):
|
| 991 |
+
sim = F.conv2d(q, k)
|
| 992 |
+
score = F.softmax(sim / sqrt(sim.size(1)), dim=1) # B * S * 1 * 1 6*128
|
| 993 |
+
sb, sn, sw, sh = score.size()
|
| 994 |
+
s_m = score.view(sb, -1).unsqueeze(1) # 2*1*M
|
| 995 |
+
vb, vn, vw, vh = v.size()
|
| 996 |
+
v_in = v.view(vb, -1).repeat(sb, 1, 1) # 2*M*(c*w*h)
|
| 997 |
+
mem_out = torch.bmm(s_m, v_in).squeeze(1).view(sb, vn, vw, vh)
|
| 998 |
+
max_inds = torch.argmax(score, dim=1).squeeze()
|
| 999 |
+
return mem_out, max_inds
|
| 1000 |
+
|
| 1001 |
+
def memorize(self, img, locs):
|
| 1002 |
+
fs = self.E_hq(img, locs)
|
| 1003 |
+
LE256_key, LE256_value = self.LE_256_KV(fs["le256"])
|
| 1004 |
+
RE256_key, RE256_value = self.RE_256_KV(fs["re256"])
|
| 1005 |
+
MO256_key, MO256_value = self.MO_256_KV(fs["mo256"])
|
| 1006 |
+
|
| 1007 |
+
LE128_key, LE128_value = self.LE_128_KV(fs["le128"])
|
| 1008 |
+
RE128_key, RE128_value = self.RE_128_KV(fs["re128"])
|
| 1009 |
+
MO128_key, MO128_value = self.MO_128_KV(fs["mo128"])
|
| 1010 |
+
|
| 1011 |
+
LE64_key, LE64_value = self.LE_64_KV(fs["le64"])
|
| 1012 |
+
RE64_key, RE64_value = self.RE_64_KV(fs["re64"])
|
| 1013 |
+
MO64_key, MO64_value = self.MO_64_KV(fs["mo64"])
|
| 1014 |
+
|
| 1015 |
+
Mem256 = {
|
| 1016 |
+
"LE256Key": LE256_key,
|
| 1017 |
+
"LE256Value": LE256_value,
|
| 1018 |
+
"RE256Key": RE256_key,
|
| 1019 |
+
"RE256Value": RE256_value,
|
| 1020 |
+
"MO256Key": MO256_key,
|
| 1021 |
+
"MO256Value": MO256_value,
|
| 1022 |
+
}
|
| 1023 |
+
Mem128 = {
|
| 1024 |
+
"LE128Key": LE128_key,
|
| 1025 |
+
"LE128Value": LE128_value,
|
| 1026 |
+
"RE128Key": RE128_key,
|
| 1027 |
+
"RE128Value": RE128_value,
|
| 1028 |
+
"MO128Key": MO128_key,
|
| 1029 |
+
"MO128Value": MO128_value,
|
| 1030 |
+
}
|
| 1031 |
+
Mem64 = {
|
| 1032 |
+
"LE64Key": LE64_key,
|
| 1033 |
+
"LE64Value": LE64_value,
|
| 1034 |
+
"RE64Key": RE64_key,
|
| 1035 |
+
"RE64Value": RE64_value,
|
| 1036 |
+
"MO64Key": MO64_key,
|
| 1037 |
+
"MO64Value": MO64_value,
|
| 1038 |
+
}
|
| 1039 |
+
|
| 1040 |
+
FS256 = {"LE256F": fs["le256"], "RE256F": fs["re256"], "MO256F": fs["mo256"]}
|
| 1041 |
+
FS128 = {"LE128F": fs["le128"], "RE128F": fs["re128"], "MO128F": fs["mo128"]}
|
| 1042 |
+
FS64 = {"LE64F": fs["le64"], "RE64F": fs["re64"], "MO64F": fs["mo64"]}
|
| 1043 |
+
|
| 1044 |
+
return Mem256, Mem128, Mem64
|
| 1045 |
+
|
| 1046 |
+
def enhancer(self, fs_in, sp_256=None, sp_128=None, sp_64=None):
|
| 1047 |
+
le_256_q = fs_in["le_256_q"]
|
| 1048 |
+
re_256_q = fs_in["re_256_q"]
|
| 1049 |
+
mo_256_q = fs_in["mo_256_q"]
|
| 1050 |
+
|
| 1051 |
+
le_128_q = fs_in["le_128_q"]
|
| 1052 |
+
re_128_q = fs_in["re_128_q"]
|
| 1053 |
+
mo_128_q = fs_in["mo_128_q"]
|
| 1054 |
+
|
| 1055 |
+
le_64_q = fs_in["le_64_q"]
|
| 1056 |
+
re_64_q = fs_in["re_64_q"]
|
| 1057 |
+
mo_64_q = fs_in["mo_64_q"]
|
| 1058 |
+
|
| 1059 |
+
####for 256
|
| 1060 |
+
le_256_mem_g, le_256_inds = self.readMem(
|
| 1061 |
+
self.le_256_mem_key, self.le_256_mem_value, le_256_q
|
| 1062 |
+
)
|
| 1063 |
+
re_256_mem_g, re_256_inds = self.readMem(
|
| 1064 |
+
self.re_256_mem_key, self.re_256_mem_value, re_256_q
|
| 1065 |
+
)
|
| 1066 |
+
mo_256_mem_g, mo_256_inds = self.readMem(
|
| 1067 |
+
self.mo_256_mem_key, self.mo_256_mem_value, mo_256_q
|
| 1068 |
+
)
|
| 1069 |
+
|
| 1070 |
+
le_128_mem_g, le_128_inds = self.readMem(
|
| 1071 |
+
self.le_128_mem_key, self.le_128_mem_value, le_128_q
|
| 1072 |
+
)
|
| 1073 |
+
re_128_mem_g, re_128_inds = self.readMem(
|
| 1074 |
+
self.re_128_mem_key, self.re_128_mem_value, re_128_q
|
| 1075 |
+
)
|
| 1076 |
+
mo_128_mem_g, mo_128_inds = self.readMem(
|
| 1077 |
+
self.mo_128_mem_key, self.mo_128_mem_value, mo_128_q
|
| 1078 |
+
)
|
| 1079 |
+
|
| 1080 |
+
le_64_mem_g, le_64_inds = self.readMem(
|
| 1081 |
+
self.le_64_mem_key, self.le_64_mem_value, le_64_q
|
| 1082 |
+
)
|
| 1083 |
+
re_64_mem_g, re_64_inds = self.readMem(
|
| 1084 |
+
self.re_64_mem_key, self.re_64_mem_value, re_64_q
|
| 1085 |
+
)
|
| 1086 |
+
mo_64_mem_g, mo_64_inds = self.readMem(
|
| 1087 |
+
self.mo_64_mem_key, self.mo_64_mem_value, mo_64_q
|
| 1088 |
+
)
|
| 1089 |
+
|
| 1090 |
+
if sp_256 is not None and sp_128 is not None and sp_64 is not None:
|
| 1091 |
+
le_256_mem_s, _ = self.readMem(
|
| 1092 |
+
sp_256["LE256Key"], sp_256["LE256Value"], le_256_q
|
| 1093 |
+
)
|
| 1094 |
+
re_256_mem_s, _ = self.readMem(
|
| 1095 |
+
sp_256["RE256Key"], sp_256["RE256Value"], re_256_q
|
| 1096 |
+
)
|
| 1097 |
+
mo_256_mem_s, _ = self.readMem(
|
| 1098 |
+
sp_256["MO256Key"], sp_256["MO256Value"], mo_256_q
|
| 1099 |
+
)
|
| 1100 |
+
le_256_mask = self.LE_256_Mask(fs_in["le256"], le_256_mem_s, le_256_mem_g)
|
| 1101 |
+
le_256_mem = le_256_mask * le_256_mem_s + (1 - le_256_mask) * le_256_mem_g
|
| 1102 |
+
re_256_mask = self.RE_256_Mask(fs_in["re256"], re_256_mem_s, re_256_mem_g)
|
| 1103 |
+
re_256_mem = re_256_mask * re_256_mem_s + (1 - re_256_mask) * re_256_mem_g
|
| 1104 |
+
mo_256_mask = self.MO_256_Mask(fs_in["mo256"], mo_256_mem_s, mo_256_mem_g)
|
| 1105 |
+
mo_256_mem = mo_256_mask * mo_256_mem_s + (1 - mo_256_mask) * mo_256_mem_g
|
| 1106 |
+
|
| 1107 |
+
le_128_mem_s, _ = self.readMem(
|
| 1108 |
+
sp_128["LE128Key"], sp_128["LE128Value"], le_128_q
|
| 1109 |
+
)
|
| 1110 |
+
re_128_mem_s, _ = self.readMem(
|
| 1111 |
+
sp_128["RE128Key"], sp_128["RE128Value"], re_128_q
|
| 1112 |
+
)
|
| 1113 |
+
mo_128_mem_s, _ = self.readMem(
|
| 1114 |
+
sp_128["MO128Key"], sp_128["MO128Value"], mo_128_q
|
| 1115 |
+
)
|
| 1116 |
+
le_128_mask = self.LE_128_Mask(fs_in["le128"], le_128_mem_s, le_128_mem_g)
|
| 1117 |
+
le_128_mem = le_128_mask * le_128_mem_s + (1 - le_128_mask) * le_128_mem_g
|
| 1118 |
+
re_128_mask = self.RE_128_Mask(fs_in["re128"], re_128_mem_s, re_128_mem_g)
|
| 1119 |
+
re_128_mem = re_128_mask * re_128_mem_s + (1 - re_128_mask) * re_128_mem_g
|
| 1120 |
+
mo_128_mask = self.MO_128_Mask(fs_in["mo128"], mo_128_mem_s, mo_128_mem_g)
|
| 1121 |
+
mo_128_mem = mo_128_mask * mo_128_mem_s + (1 - mo_128_mask) * mo_128_mem_g
|
| 1122 |
+
|
| 1123 |
+
le_64_mem_s, _ = self.readMem(sp_64["LE64Key"], sp_64["LE64Value"], le_64_q)
|
| 1124 |
+
re_64_mem_s, _ = self.readMem(sp_64["RE64Key"], sp_64["RE64Value"], re_64_q)
|
| 1125 |
+
mo_64_mem_s, _ = self.readMem(sp_64["MO64Key"], sp_64["MO64Value"], mo_64_q)
|
| 1126 |
+
le_64_mask = self.LE_64_Mask(fs_in["le64"], le_64_mem_s, le_64_mem_g)
|
| 1127 |
+
le_64_mem = le_64_mask * le_64_mem_s + (1 - le_64_mask) * le_64_mem_g
|
| 1128 |
+
re_64_mask = self.RE_64_Mask(fs_in["re64"], re_64_mem_s, re_64_mem_g)
|
| 1129 |
+
re_64_mem = re_64_mask * re_64_mem_s + (1 - re_64_mask) * re_64_mem_g
|
| 1130 |
+
mo_64_mask = self.MO_64_Mask(fs_in["mo64"], mo_64_mem_s, mo_64_mem_g)
|
| 1131 |
+
mo_64_mem = mo_64_mask * mo_64_mem_s + (1 - mo_64_mask) * mo_64_mem_g
|
| 1132 |
+
else:
|
| 1133 |
+
le_256_mem = le_256_mem_g
|
| 1134 |
+
re_256_mem = re_256_mem_g
|
| 1135 |
+
mo_256_mem = mo_256_mem_g
|
| 1136 |
+
le_128_mem = le_128_mem_g
|
| 1137 |
+
re_128_mem = re_128_mem_g
|
| 1138 |
+
mo_128_mem = mo_128_mem_g
|
| 1139 |
+
le_64_mem = le_64_mem_g
|
| 1140 |
+
re_64_mem = re_64_mem_g
|
| 1141 |
+
mo_64_mem = mo_64_mem_g
|
| 1142 |
+
|
| 1143 |
+
le_256_mem_norm = adaptive_instance_normalization_4D(le_256_mem, fs_in["le256"])
|
| 1144 |
+
re_256_mem_norm = adaptive_instance_normalization_4D(re_256_mem, fs_in["re256"])
|
| 1145 |
+
mo_256_mem_norm = adaptive_instance_normalization_4D(mo_256_mem, fs_in["mo256"])
|
| 1146 |
+
|
| 1147 |
+
####for 128
|
| 1148 |
+
le_128_mem_norm = adaptive_instance_normalization_4D(le_128_mem, fs_in["le128"])
|
| 1149 |
+
re_128_mem_norm = adaptive_instance_normalization_4D(re_128_mem, fs_in["re128"])
|
| 1150 |
+
mo_128_mem_norm = adaptive_instance_normalization_4D(mo_128_mem, fs_in["mo128"])
|
| 1151 |
+
|
| 1152 |
+
####for 64
|
| 1153 |
+
le_64_mem_norm = adaptive_instance_normalization_4D(le_64_mem, fs_in["le64"])
|
| 1154 |
+
re_64_mem_norm = adaptive_instance_normalization_4D(re_64_mem, fs_in["re64"])
|
| 1155 |
+
mo_64_mem_norm = adaptive_instance_normalization_4D(mo_64_mem, fs_in["mo64"])
|
| 1156 |
+
|
| 1157 |
+
EnMem256 = {
|
| 1158 |
+
"LE256Norm": le_256_mem_norm,
|
| 1159 |
+
"RE256Norm": re_256_mem_norm,
|
| 1160 |
+
"MO256Norm": mo_256_mem_norm,
|
| 1161 |
+
}
|
| 1162 |
+
EnMem128 = {
|
| 1163 |
+
"LE128Norm": le_128_mem_norm,
|
| 1164 |
+
"RE128Norm": re_128_mem_norm,
|
| 1165 |
+
"MO128Norm": mo_128_mem_norm,
|
| 1166 |
+
}
|
| 1167 |
+
EnMem64 = {
|
| 1168 |
+
"LE64Norm": le_64_mem_norm,
|
| 1169 |
+
"RE64Norm": re_64_mem_norm,
|
| 1170 |
+
"MO64Norm": mo_64_mem_norm,
|
| 1171 |
+
}
|
| 1172 |
+
Ind256 = {"LE": le_256_inds, "RE": re_256_inds, "MO": mo_256_inds}
|
| 1173 |
+
Ind128 = {"LE": le_128_inds, "RE": re_128_inds, "MO": mo_128_inds}
|
| 1174 |
+
Ind64 = {"LE": le_64_inds, "RE": re_64_inds, "MO": mo_64_inds}
|
| 1175 |
+
return EnMem256, EnMem128, EnMem64, Ind256, Ind128, Ind64
|
| 1176 |
+
|
| 1177 |
+
def reconstruct(self, fs_in, locs, memstar):
|
| 1178 |
+
le_256_mem_norm, re_256_mem_norm, mo_256_mem_norm = (
|
| 1179 |
+
memstar[0]["LE256Norm"],
|
| 1180 |
+
memstar[0]["RE256Norm"],
|
| 1181 |
+
memstar[0]["MO256Norm"],
|
| 1182 |
+
)
|
| 1183 |
+
le_128_mem_norm, re_128_mem_norm, mo_128_mem_norm = (
|
| 1184 |
+
memstar[1]["LE128Norm"],
|
| 1185 |
+
memstar[1]["RE128Norm"],
|
| 1186 |
+
memstar[1]["MO128Norm"],
|
| 1187 |
+
)
|
| 1188 |
+
le_64_mem_norm, re_64_mem_norm, mo_64_mem_norm = (
|
| 1189 |
+
memstar[2]["LE64Norm"],
|
| 1190 |
+
memstar[2]["RE64Norm"],
|
| 1191 |
+
memstar[2]["MO64Norm"],
|
| 1192 |
+
)
|
| 1193 |
+
|
| 1194 |
+
le_256_final = (
|
| 1195 |
+
self.LE_256_Attention(le_256_mem_norm - fs_in["le256"]) * le_256_mem_norm
|
| 1196 |
+
+ fs_in["le256"]
|
| 1197 |
+
)
|
| 1198 |
+
re_256_final = (
|
| 1199 |
+
self.RE_256_Attention(re_256_mem_norm - fs_in["re256"]) * re_256_mem_norm
|
| 1200 |
+
+ fs_in["re256"]
|
| 1201 |
+
)
|
| 1202 |
+
mo_256_final = (
|
| 1203 |
+
self.MO_256_Attention(mo_256_mem_norm - fs_in["mo256"]) * mo_256_mem_norm
|
| 1204 |
+
+ fs_in["mo256"]
|
| 1205 |
+
)
|
| 1206 |
+
|
| 1207 |
+
le_128_final = (
|
| 1208 |
+
self.LE_128_Attention(le_128_mem_norm - fs_in["le128"]) * le_128_mem_norm
|
| 1209 |
+
+ fs_in["le128"]
|
| 1210 |
+
)
|
| 1211 |
+
re_128_final = (
|
| 1212 |
+
self.RE_128_Attention(re_128_mem_norm - fs_in["re128"]) * re_128_mem_norm
|
| 1213 |
+
+ fs_in["re128"]
|
| 1214 |
+
)
|
| 1215 |
+
mo_128_final = (
|
| 1216 |
+
self.MO_128_Attention(mo_128_mem_norm - fs_in["mo128"]) * mo_128_mem_norm
|
| 1217 |
+
+ fs_in["mo128"]
|
| 1218 |
+
)
|
| 1219 |
+
|
| 1220 |
+
le_64_final = (
|
| 1221 |
+
self.LE_64_Attention(le_64_mem_norm - fs_in["le64"]) * le_64_mem_norm
|
| 1222 |
+
+ fs_in["le64"]
|
| 1223 |
+
)
|
| 1224 |
+
re_64_final = (
|
| 1225 |
+
self.RE_64_Attention(re_64_mem_norm - fs_in["re64"]) * re_64_mem_norm
|
| 1226 |
+
+ fs_in["re64"]
|
| 1227 |
+
)
|
| 1228 |
+
mo_64_final = (
|
| 1229 |
+
self.MO_64_Attention(mo_64_mem_norm - fs_in["mo64"]) * mo_64_mem_norm
|
| 1230 |
+
+ fs_in["mo64"]
|
| 1231 |
+
)
|
| 1232 |
+
|
| 1233 |
+
le_location = locs[:, 0, :]
|
| 1234 |
+
re_location = locs[:, 1, :]
|
| 1235 |
+
mo_location = locs[:, 3, :]
|
| 1236 |
+
|
| 1237 |
+
# Somehow with latest Torch it doesn't like numpy wrappers anymore
|
| 1238 |
+
|
| 1239 |
+
# le_location = le_location.cpu().int().numpy()
|
| 1240 |
+
# re_location = re_location.cpu().int().numpy()
|
| 1241 |
+
# mo_location = mo_location.cpu().int().numpy()
|
| 1242 |
+
le_location = le_location.cpu().int()
|
| 1243 |
+
re_location = re_location.cpu().int()
|
| 1244 |
+
mo_location = mo_location.cpu().int()
|
| 1245 |
+
|
| 1246 |
+
up_in_256 = fs_in["f256"].clone() # * 0
|
| 1247 |
+
up_in_128 = fs_in["f128"].clone() # * 0
|
| 1248 |
+
up_in_64 = fs_in["f64"].clone() # * 0
|
| 1249 |
+
|
| 1250 |
+
for i in range(fs_in["f256"].size(0)):
|
| 1251 |
+
up_in_256[
|
| 1252 |
+
i : i + 1,
|
| 1253 |
+
:,
|
| 1254 |
+
le_location[i, 1] // 2 : le_location[i, 3] // 2,
|
| 1255 |
+
le_location[i, 0] // 2 : le_location[i, 2] // 2,
|
| 1256 |
+
] = F.interpolate(
|
| 1257 |
+
le_256_final[i : i + 1, :, :, :].clone(),
|
| 1258 |
+
(
|
| 1259 |
+
le_location[i, 3] // 2 - le_location[i, 1] // 2,
|
| 1260 |
+
le_location[i, 2] // 2 - le_location[i, 0] // 2,
|
| 1261 |
+
),
|
| 1262 |
+
mode="bilinear",
|
| 1263 |
+
align_corners=False,
|
| 1264 |
+
)
|
| 1265 |
+
up_in_256[
|
| 1266 |
+
i : i + 1,
|
| 1267 |
+
:,
|
| 1268 |
+
re_location[i, 1] // 2 : re_location[i, 3] // 2,
|
| 1269 |
+
re_location[i, 0] // 2 : re_location[i, 2] // 2,
|
| 1270 |
+
] = F.interpolate(
|
| 1271 |
+
re_256_final[i : i + 1, :, :, :].clone(),
|
| 1272 |
+
(
|
| 1273 |
+
re_location[i, 3] // 2 - re_location[i, 1] // 2,
|
| 1274 |
+
re_location[i, 2] // 2 - re_location[i, 0] // 2,
|
| 1275 |
+
),
|
| 1276 |
+
mode="bilinear",
|
| 1277 |
+
align_corners=False,
|
| 1278 |
+
)
|
| 1279 |
+
up_in_256[
|
| 1280 |
+
i : i + 1,
|
| 1281 |
+
:,
|
| 1282 |
+
mo_location[i, 1] // 2 : mo_location[i, 3] // 2,
|
| 1283 |
+
mo_location[i, 0] // 2 : mo_location[i, 2] // 2,
|
| 1284 |
+
] = F.interpolate(
|
| 1285 |
+
mo_256_final[i : i + 1, :, :, :].clone(),
|
| 1286 |
+
(
|
| 1287 |
+
mo_location[i, 3] // 2 - mo_location[i, 1] // 2,
|
| 1288 |
+
mo_location[i, 2] // 2 - mo_location[i, 0] // 2,
|
| 1289 |
+
),
|
| 1290 |
+
mode="bilinear",
|
| 1291 |
+
align_corners=False,
|
| 1292 |
+
)
|
| 1293 |
+
|
| 1294 |
+
up_in_128[
|
| 1295 |
+
i : i + 1,
|
| 1296 |
+
:,
|
| 1297 |
+
le_location[i, 1] // 4 : le_location[i, 3] // 4,
|
| 1298 |
+
le_location[i, 0] // 4 : le_location[i, 2] // 4,
|
| 1299 |
+
] = F.interpolate(
|
| 1300 |
+
le_128_final[i : i + 1, :, :, :].clone(),
|
| 1301 |
+
(
|
| 1302 |
+
le_location[i, 3] // 4 - le_location[i, 1] // 4,
|
| 1303 |
+
le_location[i, 2] // 4 - le_location[i, 0] // 4,
|
| 1304 |
+
),
|
| 1305 |
+
mode="bilinear",
|
| 1306 |
+
align_corners=False,
|
| 1307 |
+
)
|
| 1308 |
+
up_in_128[
|
| 1309 |
+
i : i + 1,
|
| 1310 |
+
:,
|
| 1311 |
+
re_location[i, 1] // 4 : re_location[i, 3] // 4,
|
| 1312 |
+
re_location[i, 0] // 4 : re_location[i, 2] // 4,
|
| 1313 |
+
] = F.interpolate(
|
| 1314 |
+
re_128_final[i : i + 1, :, :, :].clone(),
|
| 1315 |
+
(
|
| 1316 |
+
re_location[i, 3] // 4 - re_location[i, 1] // 4,
|
| 1317 |
+
re_location[i, 2] // 4 - re_location[i, 0] // 4,
|
| 1318 |
+
),
|
| 1319 |
+
mode="bilinear",
|
| 1320 |
+
align_corners=False,
|
| 1321 |
+
)
|
| 1322 |
+
up_in_128[
|
| 1323 |
+
i : i + 1,
|
| 1324 |
+
:,
|
| 1325 |
+
mo_location[i, 1] // 4 : mo_location[i, 3] // 4,
|
| 1326 |
+
mo_location[i, 0] // 4 : mo_location[i, 2] // 4,
|
| 1327 |
+
] = F.interpolate(
|
| 1328 |
+
mo_128_final[i : i + 1, :, :, :].clone(),
|
| 1329 |
+
(
|
| 1330 |
+
mo_location[i, 3] // 4 - mo_location[i, 1] // 4,
|
| 1331 |
+
mo_location[i, 2] // 4 - mo_location[i, 0] // 4,
|
| 1332 |
+
),
|
| 1333 |
+
mode="bilinear",
|
| 1334 |
+
align_corners=False,
|
| 1335 |
+
)
|
| 1336 |
+
|
| 1337 |
+
up_in_64[
|
| 1338 |
+
i : i + 1,
|
| 1339 |
+
:,
|
| 1340 |
+
le_location[i, 1] // 8 : le_location[i, 3] // 8,
|
| 1341 |
+
le_location[i, 0] // 8 : le_location[i, 2] // 8,
|
| 1342 |
+
] = F.interpolate(
|
| 1343 |
+
le_64_final[i : i + 1, :, :, :].clone(),
|
| 1344 |
+
(
|
| 1345 |
+
le_location[i, 3] // 8 - le_location[i, 1] // 8,
|
| 1346 |
+
le_location[i, 2] // 8 - le_location[i, 0] // 8,
|
| 1347 |
+
),
|
| 1348 |
+
mode="bilinear",
|
| 1349 |
+
align_corners=False,
|
| 1350 |
+
)
|
| 1351 |
+
up_in_64[
|
| 1352 |
+
i : i + 1,
|
| 1353 |
+
:,
|
| 1354 |
+
re_location[i, 1] // 8 : re_location[i, 3] // 8,
|
| 1355 |
+
re_location[i, 0] // 8 : re_location[i, 2] // 8,
|
| 1356 |
+
] = F.interpolate(
|
| 1357 |
+
re_64_final[i : i + 1, :, :, :].clone(),
|
| 1358 |
+
(
|
| 1359 |
+
re_location[i, 3] // 8 - re_location[i, 1] // 8,
|
| 1360 |
+
re_location[i, 2] // 8 - re_location[i, 0] // 8,
|
| 1361 |
+
),
|
| 1362 |
+
mode="bilinear",
|
| 1363 |
+
align_corners=False,
|
| 1364 |
+
)
|
| 1365 |
+
up_in_64[
|
| 1366 |
+
i : i + 1,
|
| 1367 |
+
:,
|
| 1368 |
+
mo_location[i, 1] // 8 : mo_location[i, 3] // 8,
|
| 1369 |
+
mo_location[i, 0] // 8 : mo_location[i, 2] // 8,
|
| 1370 |
+
] = F.interpolate(
|
| 1371 |
+
mo_64_final[i : i + 1, :, :, :].clone(),
|
| 1372 |
+
(
|
| 1373 |
+
mo_location[i, 3] // 8 - mo_location[i, 1] // 8,
|
| 1374 |
+
mo_location[i, 2] // 8 - mo_location[i, 0] // 8,
|
| 1375 |
+
),
|
| 1376 |
+
mode="bilinear",
|
| 1377 |
+
align_corners=False,
|
| 1378 |
+
)
|
| 1379 |
+
|
| 1380 |
+
ms_in_64 = self.MSDilate(fs_in["f64"].clone())
|
| 1381 |
+
fea_up1 = self.up1(ms_in_64, up_in_64)
|
| 1382 |
+
fea_up2 = self.up2(fea_up1, up_in_128) #
|
| 1383 |
+
fea_up3 = self.up3(fea_up2, up_in_256) #
|
| 1384 |
+
output = self.up4(fea_up3) #
|
| 1385 |
+
return output
|
| 1386 |
+
|
| 1387 |
+
def generate_specific_dictionary(self, sp_imgs=None, sp_locs=None):
|
| 1388 |
+
return self.memorize(sp_imgs, sp_locs)
|
| 1389 |
+
|
| 1390 |
+
def forward(self, lq=None, loc=None, sp_256=None, sp_128=None, sp_64=None):
|
| 1391 |
+
try:
|
| 1392 |
+
fs_in = self.E_lq(lq, loc) # low quality images
|
| 1393 |
+
except Exception as e:
|
| 1394 |
+
print(e)
|
| 1395 |
+
|
| 1396 |
+
GeMemNorm256, GeMemNorm128, GeMemNorm64, Ind256, Ind128, Ind64 = self.enhancer(
|
| 1397 |
+
fs_in
|
| 1398 |
+
)
|
| 1399 |
+
GeOut = self.reconstruct(
|
| 1400 |
+
fs_in, loc, memstar=[GeMemNorm256, GeMemNorm128, GeMemNorm64]
|
| 1401 |
+
)
|
| 1402 |
+
if sp_256 is not None and sp_128 is not None and sp_64 is not None:
|
| 1403 |
+
GSMemNorm256, GSMemNorm128, GSMemNorm64, _, _, _ = self.enhancer(
|
| 1404 |
+
fs_in, sp_256, sp_128, sp_64
|
| 1405 |
+
)
|
| 1406 |
+
GSOut = self.reconstruct(
|
| 1407 |
+
fs_in, loc, memstar=[GSMemNorm256, GSMemNorm128, GSMemNorm64]
|
| 1408 |
+
)
|
| 1409 |
+
else:
|
| 1410 |
+
GSOut = None
|
| 1411 |
+
return GeOut, GSOut
|
| 1412 |
+
|
| 1413 |
+
|
| 1414 |
+
class UpResBlock(nn.Module):
|
| 1415 |
+
def __init__(self, dim, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d):
|
| 1416 |
+
super(UpResBlock, self).__init__()
|
| 1417 |
+
self.Model = nn.Sequential(
|
| 1418 |
+
SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
|
| 1419 |
+
nn.LeakyReLU(0.2),
|
| 1420 |
+
SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
|
| 1421 |
+
)
|
| 1422 |
+
|
| 1423 |
+
def forward(self, x):
|
| 1424 |
+
out = x + self.Model(x)
|
| 1425 |
+
return out
|
processors/Enhance_GFPGAN.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, List, Callable
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import onnxruntime
|
| 5 |
+
import roop.globals
|
| 6 |
+
|
| 7 |
+
from roop.typing import Face, Frame, FaceSet
|
| 8 |
+
from roop.utilities import resolve_relative_path
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Enhance_GFPGAN:
|
| 12 |
+
plugin_options: dict = None
|
| 13 |
+
|
| 14 |
+
model_gfpgan = None
|
| 15 |
+
name = None
|
| 16 |
+
devicename = None
|
| 17 |
+
|
| 18 |
+
processorname = "gfpgan"
|
| 19 |
+
type = "enhance"
|
| 20 |
+
|
| 21 |
+
def Initialize(self, plugin_options: dict):
|
| 22 |
+
if self.plugin_options is not None:
|
| 23 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
| 24 |
+
self.Release()
|
| 25 |
+
|
| 26 |
+
self.plugin_options = plugin_options
|
| 27 |
+
if self.model_gfpgan is None:
|
| 28 |
+
model_path = resolve_relative_path("../models/GFPGANv1.4.onnx")
|
| 29 |
+
self.model_gfpgan = onnxruntime.InferenceSession(
|
| 30 |
+
model_path, None, providers=roop.globals.execution_providers
|
| 31 |
+
)
|
| 32 |
+
# replace Mac mps with cpu for the moment
|
| 33 |
+
self.devicename = self.plugin_options["devicename"].replace("mps", "cpu")
|
| 34 |
+
|
| 35 |
+
self.name = self.model_gfpgan.get_inputs()[0].name
|
| 36 |
+
|
| 37 |
+
def Run(
|
| 38 |
+
self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame
|
| 39 |
+
) -> Frame:
|
| 40 |
+
# preprocess
|
| 41 |
+
input_size = temp_frame.shape[1]
|
| 42 |
+
temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC)
|
| 43 |
+
|
| 44 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
|
| 45 |
+
temp_frame = temp_frame.astype("float32") / 255.0
|
| 46 |
+
temp_frame = (temp_frame - 0.5) / 0.5
|
| 47 |
+
temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)
|
| 48 |
+
|
| 49 |
+
io_binding = self.model_gfpgan.io_binding()
|
| 50 |
+
io_binding.bind_cpu_input("input", temp_frame)
|
| 51 |
+
io_binding.bind_output("1288", self.devicename)
|
| 52 |
+
self.model_gfpgan.run_with_iobinding(io_binding)
|
| 53 |
+
ort_outs = io_binding.copy_outputs_to_cpu()
|
| 54 |
+
result = ort_outs[0][0]
|
| 55 |
+
|
| 56 |
+
# post-process
|
| 57 |
+
result = np.clip(result, -1, 1)
|
| 58 |
+
result = (result + 1) / 2
|
| 59 |
+
result = result.transpose(1, 2, 0) * 255.0
|
| 60 |
+
result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
|
| 61 |
+
scale_factor = int(result.shape[1] / input_size)
|
| 62 |
+
return result.astype(np.uint8), scale_factor
|
| 63 |
+
|
| 64 |
+
def Release(self):
|
| 65 |
+
self.model_gfpgan = None
|
processors/Enhance_GPEN.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, List, Callable
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import onnxruntime
|
| 5 |
+
import roop.globals
|
| 6 |
+
|
| 7 |
+
from roop.typing import Face, Frame, FaceSet
|
| 8 |
+
from roop.utilities import resolve_relative_path
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Enhance_GPEN:
|
| 12 |
+
plugin_options: dict = None
|
| 13 |
+
|
| 14 |
+
model_gpen = None
|
| 15 |
+
name = None
|
| 16 |
+
devicename = None
|
| 17 |
+
|
| 18 |
+
processorname = "gpen"
|
| 19 |
+
type = "enhance"
|
| 20 |
+
|
| 21 |
+
def Initialize(self, plugin_options: dict):
|
| 22 |
+
if self.plugin_options is not None:
|
| 23 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
| 24 |
+
self.Release()
|
| 25 |
+
|
| 26 |
+
self.plugin_options = plugin_options
|
| 27 |
+
if self.model_gpen is None:
|
| 28 |
+
model_path = resolve_relative_path("../models/GPEN-BFR-512.onnx")
|
| 29 |
+
self.model_gpen = onnxruntime.InferenceSession(
|
| 30 |
+
model_path, None, providers=roop.globals.execution_providers
|
| 31 |
+
)
|
| 32 |
+
# replace Mac mps with cpu for the moment
|
| 33 |
+
self.devicename = self.plugin_options["devicename"].replace("mps", "cpu")
|
| 34 |
+
|
| 35 |
+
self.name = self.model_gpen.get_inputs()[0].name
|
| 36 |
+
|
| 37 |
+
def Run(
|
| 38 |
+
self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame
|
| 39 |
+
) -> Frame:
|
| 40 |
+
# preprocess
|
| 41 |
+
input_size = temp_frame.shape[1]
|
| 42 |
+
temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC)
|
| 43 |
+
|
| 44 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
|
| 45 |
+
temp_frame = temp_frame.astype("float32") / 255.0
|
| 46 |
+
temp_frame = (temp_frame - 0.5) / 0.5
|
| 47 |
+
temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)
|
| 48 |
+
|
| 49 |
+
io_binding = self.model_gpen.io_binding()
|
| 50 |
+
io_binding.bind_cpu_input("input", temp_frame)
|
| 51 |
+
io_binding.bind_output("output", self.devicename)
|
| 52 |
+
self.model_gpen.run_with_iobinding(io_binding)
|
| 53 |
+
ort_outs = io_binding.copy_outputs_to_cpu()
|
| 54 |
+
result = ort_outs[0][0]
|
| 55 |
+
|
| 56 |
+
# post-process
|
| 57 |
+
result = np.clip(result, -1, 1)
|
| 58 |
+
result = (result + 1) / 2
|
| 59 |
+
result = result.transpose(1, 2, 0) * 255.0
|
| 60 |
+
result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
|
| 61 |
+
scale_factor = int(result.shape[1] / input_size)
|
| 62 |
+
return result.astype(np.uint8), scale_factor
|
| 63 |
+
|
| 64 |
+
def Release(self):
|
| 65 |
+
self.model_gpen = None
|
processors/Enhance_RestoreFormerPPlus.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, List, Callable
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import onnxruntime
|
| 5 |
+
import roop.globals
|
| 6 |
+
|
| 7 |
+
from roop.typing import Face, Frame, FaceSet
|
| 8 |
+
from roop.utilities import resolve_relative_path
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Enhance_RestoreFormerPPlus:
|
| 12 |
+
plugin_options: dict = None
|
| 13 |
+
model_restoreformerpplus = None
|
| 14 |
+
devicename = None
|
| 15 |
+
name = None
|
| 16 |
+
|
| 17 |
+
processorname = "restoreformer++"
|
| 18 |
+
type = "enhance"
|
| 19 |
+
|
| 20 |
+
def Initialize(self, plugin_options: dict):
|
| 21 |
+
if self.plugin_options is not None:
|
| 22 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
| 23 |
+
self.Release()
|
| 24 |
+
|
| 25 |
+
self.plugin_options = plugin_options
|
| 26 |
+
if self.model_restoreformerpplus is None:
|
| 27 |
+
# replace Mac mps with cpu for the moment
|
| 28 |
+
self.devicename = self.plugin_options["devicename"].replace("mps", "cpu")
|
| 29 |
+
model_path = resolve_relative_path("../models/restoreformer_plus_plus.onnx")
|
| 30 |
+
self.model_restoreformerpplus = onnxruntime.InferenceSession(
|
| 31 |
+
model_path, None, providers=roop.globals.execution_providers
|
| 32 |
+
)
|
| 33 |
+
self.model_inputs = self.model_restoreformerpplus.get_inputs()
|
| 34 |
+
model_outputs = self.model_restoreformerpplus.get_outputs()
|
| 35 |
+
self.io_binding = self.model_restoreformerpplus.io_binding()
|
| 36 |
+
self.io_binding.bind_output(model_outputs[0].name, self.devicename)
|
| 37 |
+
|
| 38 |
+
def Run(
|
| 39 |
+
self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame
|
| 40 |
+
) -> Frame:
|
| 41 |
+
# preprocess
|
| 42 |
+
input_size = temp_frame.shape[1]
|
| 43 |
+
temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC)
|
| 44 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
|
| 45 |
+
temp_frame = temp_frame.astype("float32") / 255.0
|
| 46 |
+
temp_frame = (temp_frame - 0.5) / 0.5
|
| 47 |
+
temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)
|
| 48 |
+
|
| 49 |
+
self.io_binding.bind_cpu_input(
|
| 50 |
+
self.model_inputs[0].name, temp_frame
|
| 51 |
+
) # .astype(np.float32)
|
| 52 |
+
self.model_restoreformerpplus.run_with_iobinding(self.io_binding)
|
| 53 |
+
ort_outs = self.io_binding.copy_outputs_to_cpu()
|
| 54 |
+
result = ort_outs[0][0]
|
| 55 |
+
del ort_outs
|
| 56 |
+
|
| 57 |
+
result = np.clip(result, -1, 1)
|
| 58 |
+
result = (result + 1) / 2
|
| 59 |
+
result = result.transpose(1, 2, 0) * 255.0
|
| 60 |
+
result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
|
| 61 |
+
scale_factor = int(result.shape[1] / input_size)
|
| 62 |
+
return result.astype(np.uint8), scale_factor
|
| 63 |
+
|
| 64 |
+
def Release(self):
|
| 65 |
+
del self.model_restoreformerpplus
|
| 66 |
+
self.model_restoreformerpplus = None
|
| 67 |
+
del self.io_binding
|
| 68 |
+
self.io_binding = None
|
processors/FaceSwapInsightFace.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import roop.globals
|
| 2 |
+
import numpy as np
|
| 3 |
+
import onnx
|
| 4 |
+
import onnxruntime
|
| 5 |
+
|
| 6 |
+
from roop.typing import Face, Frame
|
| 7 |
+
from roop.utilities import resolve_relative_path
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class FaceSwapInsightFace:
|
| 11 |
+
plugin_options: dict = None
|
| 12 |
+
model_swap_insightface = None
|
| 13 |
+
|
| 14 |
+
processorname = "faceswap"
|
| 15 |
+
type = "swap"
|
| 16 |
+
|
| 17 |
+
def Initialize(self, plugin_options: dict):
|
| 18 |
+
if self.plugin_options is not None:
|
| 19 |
+
if (
|
| 20 |
+
self.plugin_options["devicename"] != plugin_options["devicename"]
|
| 21 |
+
or self.plugin_options["modelname"] != plugin_options["modelname"]
|
| 22 |
+
):
|
| 23 |
+
self.Release()
|
| 24 |
+
|
| 25 |
+
self.plugin_options = plugin_options
|
| 26 |
+
if self.model_swap_insightface is None:
|
| 27 |
+
model_path = resolve_relative_path(
|
| 28 |
+
"../models/" + self.plugin_options["modelname"]
|
| 29 |
+
)
|
| 30 |
+
graph = onnx.load(model_path).graph
|
| 31 |
+
self.emap = onnx.numpy_helper.to_array(graph.initializer[-1])
|
| 32 |
+
self.devicename = self.plugin_options["devicename"].replace("mps", "cpu")
|
| 33 |
+
self.input_mean = 0.0
|
| 34 |
+
self.input_std = 255.0
|
| 35 |
+
# cuda_options = {"arena_extend_strategy": "kSameAsRequested", 'cudnn_conv_algo_search': 'DEFAULT'}
|
| 36 |
+
sess_options = onnxruntime.SessionOptions()
|
| 37 |
+
sess_options.enable_cpu_mem_arena = False
|
| 38 |
+
self.model_swap_insightface = onnxruntime.InferenceSession(
|
| 39 |
+
model_path, sess_options, providers=roop.globals.execution_providers
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
def Run(self, source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
|
| 43 |
+
latent = source_face.normed_embedding.reshape((1, -1))
|
| 44 |
+
latent = np.dot(latent, self.emap)
|
| 45 |
+
latent /= np.linalg.norm(latent)
|
| 46 |
+
io_binding = self.model_swap_insightface.io_binding()
|
| 47 |
+
io_binding.bind_cpu_input("target", temp_frame)
|
| 48 |
+
io_binding.bind_cpu_input("source", latent)
|
| 49 |
+
io_binding.bind_output("output", self.devicename)
|
| 50 |
+
self.model_swap_insightface.run_with_iobinding(io_binding)
|
| 51 |
+
ort_outs = io_binding.copy_outputs_to_cpu()[0]
|
| 52 |
+
return ort_outs[0]
|
| 53 |
+
|
| 54 |
+
def Release(self):
|
| 55 |
+
del self.model_swap_insightface
|
| 56 |
+
self.model_swap_insightface = None
|
processors/Frame_Colorizer.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import onnxruntime
|
| 4 |
+
import roop.globals
|
| 5 |
+
|
| 6 |
+
from roop.utilities import resolve_relative_path
|
| 7 |
+
from roop.typing import Frame
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Frame_Colorizer:
|
| 11 |
+
plugin_options: dict = None
|
| 12 |
+
model_colorizer = None
|
| 13 |
+
devicename = None
|
| 14 |
+
prev_type = None
|
| 15 |
+
|
| 16 |
+
processorname = "deoldify"
|
| 17 |
+
type = "frame_colorizer"
|
| 18 |
+
|
| 19 |
+
def Initialize(self, plugin_options: dict):
|
| 20 |
+
if self.plugin_options is not None:
|
| 21 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
| 22 |
+
self.Release()
|
| 23 |
+
|
| 24 |
+
self.plugin_options = plugin_options
|
| 25 |
+
if (
|
| 26 |
+
self.prev_type is not None
|
| 27 |
+
and self.prev_type != self.plugin_options["subtype"]
|
| 28 |
+
):
|
| 29 |
+
self.Release()
|
| 30 |
+
self.prev_type = self.plugin_options["subtype"]
|
| 31 |
+
if self.model_colorizer is None:
|
| 32 |
+
# replace Mac mps with cpu for the moment
|
| 33 |
+
self.devicename = self.plugin_options["devicename"].replace("mps", "cpu")
|
| 34 |
+
if self.prev_type == "deoldify_artistic":
|
| 35 |
+
model_path = resolve_relative_path(
|
| 36 |
+
"../models/Frame/deoldify_artistic.onnx"
|
| 37 |
+
)
|
| 38 |
+
elif self.prev_type == "deoldify_stable":
|
| 39 |
+
model_path = resolve_relative_path(
|
| 40 |
+
"../models/Frame/deoldify_stable.onnx"
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
onnxruntime.set_default_logger_severity(3)
|
| 44 |
+
self.model_colorizer = onnxruntime.InferenceSession(
|
| 45 |
+
model_path, None, providers=roop.globals.execution_providers
|
| 46 |
+
)
|
| 47 |
+
self.model_inputs = self.model_colorizer.get_inputs()
|
| 48 |
+
model_outputs = self.model_colorizer.get_outputs()
|
| 49 |
+
self.io_binding = self.model_colorizer.io_binding()
|
| 50 |
+
self.io_binding.bind_output(model_outputs[0].name, self.devicename)
|
| 51 |
+
|
| 52 |
+
def Run(self, input_frame: Frame) -> Frame:
|
| 53 |
+
temp_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2GRAY)
|
| 54 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB)
|
| 55 |
+
temp_frame = cv2.resize(temp_frame, (256, 256))
|
| 56 |
+
temp_frame = temp_frame.transpose((2, 0, 1))
|
| 57 |
+
temp_frame = np.expand_dims(temp_frame, axis=0).astype(np.float32)
|
| 58 |
+
self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame)
|
| 59 |
+
self.model_colorizer.run_with_iobinding(self.io_binding)
|
| 60 |
+
ort_outs = self.io_binding.copy_outputs_to_cpu()
|
| 61 |
+
result = ort_outs[0][0]
|
| 62 |
+
del ort_outs
|
| 63 |
+
colorized_frame = result.transpose(1, 2, 0)
|
| 64 |
+
colorized_frame = cv2.resize(
|
| 65 |
+
colorized_frame, (input_frame.shape[1], input_frame.shape[0])
|
| 66 |
+
)
|
| 67 |
+
temp_blue_channel, _, _ = cv2.split(input_frame)
|
| 68 |
+
colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_BGR2RGB).astype(
|
| 69 |
+
np.uint8
|
| 70 |
+
)
|
| 71 |
+
colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_BGR2LAB)
|
| 72 |
+
_, color_green_channel, color_red_channel = cv2.split(colorized_frame)
|
| 73 |
+
colorized_frame = cv2.merge(
|
| 74 |
+
(temp_blue_channel, color_green_channel, color_red_channel)
|
| 75 |
+
)
|
| 76 |
+
colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_LAB2BGR)
|
| 77 |
+
return colorized_frame.astype(np.uint8)
|
| 78 |
+
|
| 79 |
+
def Release(self):
|
| 80 |
+
del self.model_colorizer
|
| 81 |
+
self.model_colorizer = None
|
| 82 |
+
del self.io_binding
|
| 83 |
+
self.io_binding = None
|
processors/Frame_Filter.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from roop.typing import Frame
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Frame_Filter:
|
| 8 |
+
processorname = "generic_filter"
|
| 9 |
+
type = "frame_processor"
|
| 10 |
+
|
| 11 |
+
plugin_options: dict = None
|
| 12 |
+
|
| 13 |
+
c64_palette = np.array(
|
| 14 |
+
[
|
| 15 |
+
[0, 0, 0],
|
| 16 |
+
[255, 255, 255],
|
| 17 |
+
[0x81, 0x33, 0x38],
|
| 18 |
+
[0x75, 0xCE, 0xC8],
|
| 19 |
+
[0x8E, 0x3C, 0x97],
|
| 20 |
+
[0x56, 0xAC, 0x4D],
|
| 21 |
+
[0x2E, 0x2C, 0x9B],
|
| 22 |
+
[0xED, 0xF1, 0x71],
|
| 23 |
+
[0x8E, 0x50, 0x29],
|
| 24 |
+
[0x55, 0x38, 0x00],
|
| 25 |
+
[0xC4, 0x6C, 0x71],
|
| 26 |
+
[0x4A, 0x4A, 0x4A],
|
| 27 |
+
[0x7B, 0x7B, 0x7B],
|
| 28 |
+
[0xA9, 0xFF, 0x9F],
|
| 29 |
+
[0x70, 0x6D, 0xEB],
|
| 30 |
+
[0xB2, 0xB2, 0xB2],
|
| 31 |
+
]
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
def RenderC64Screen(self, image):
|
| 35 |
+
# Simply round the color values to the nearest color in the palette
|
| 36 |
+
image = cv2.resize(image, (320, 200))
|
| 37 |
+
palette = self.c64_palette / 255.0 # Normalize palette
|
| 38 |
+
img_normalized = image / 255.0 # Normalize image
|
| 39 |
+
|
| 40 |
+
# Calculate the index in the palette that is closest to each pixel in the image
|
| 41 |
+
indices = np.sqrt(
|
| 42 |
+
((img_normalized[:, :, None, :] - palette[None, None, :, :]) ** 2).sum(
|
| 43 |
+
axis=3
|
| 44 |
+
)
|
| 45 |
+
).argmin(axis=2)
|
| 46 |
+
# Map the image to the palette colors
|
| 47 |
+
mapped_image = palette[indices]
|
| 48 |
+
return (mapped_image * 255).astype(np.uint8) # Denormalize and return the image
|
| 49 |
+
|
| 50 |
+
def RenderDetailEnhance(self, image):
|
| 51 |
+
return cv2.detailEnhance(image)
|
| 52 |
+
|
| 53 |
+
def RenderStylize(self, image):
|
| 54 |
+
return cv2.stylization(image)
|
| 55 |
+
|
| 56 |
+
def RenderPencilSketch(self, image):
|
| 57 |
+
imgray, imout = cv2.pencilSketch(
|
| 58 |
+
image, sigma_s=60, sigma_r=0.07, shade_factor=0.05
|
| 59 |
+
)
|
| 60 |
+
return imout
|
| 61 |
+
|
| 62 |
+
def RenderCartoon(self, image):
|
| 63 |
+
numDownSamples = 2 # number of downscaling steps
|
| 64 |
+
numBilateralFilters = 7 # number of bilateral filtering steps
|
| 65 |
+
|
| 66 |
+
img_color = image
|
| 67 |
+
for _ in range(numDownSamples):
|
| 68 |
+
img_color = cv2.pyrDown(img_color)
|
| 69 |
+
for _ in range(numBilateralFilters):
|
| 70 |
+
img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
|
| 71 |
+
for _ in range(numDownSamples):
|
| 72 |
+
img_color = cv2.pyrUp(img_color)
|
| 73 |
+
img_gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
| 74 |
+
img_blur = cv2.medianBlur(img_gray, 7)
|
| 75 |
+
img_edge = cv2.adaptiveThreshold(
|
| 76 |
+
img_blur, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 9, 2
|
| 77 |
+
)
|
| 78 |
+
img_edge = cv2.cvtColor(img_edge, cv2.COLOR_GRAY2RGB)
|
| 79 |
+
if img_color.shape != image.shape:
|
| 80 |
+
img_color = cv2.resize(
|
| 81 |
+
img_color,
|
| 82 |
+
(image.shape[1], image.shape[0]),
|
| 83 |
+
interpolation=cv2.INTER_LINEAR,
|
| 84 |
+
)
|
| 85 |
+
if img_color.shape != img_edge.shape:
|
| 86 |
+
img_edge = cv2.resize(
|
| 87 |
+
img_edge,
|
| 88 |
+
(img_color.shape[1], img_color.shape[0]),
|
| 89 |
+
interpolation=cv2.INTER_LINEAR,
|
| 90 |
+
)
|
| 91 |
+
return cv2.bitwise_and(img_color, img_edge)
|
| 92 |
+
|
| 93 |
+
def Initialize(self, plugin_options: dict):
|
| 94 |
+
if self.plugin_options is not None:
|
| 95 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
| 96 |
+
self.Release()
|
| 97 |
+
self.plugin_options = plugin_options
|
| 98 |
+
|
| 99 |
+
def Run(self, temp_frame: Frame) -> Frame:
|
| 100 |
+
subtype = self.plugin_options["subtype"]
|
| 101 |
+
if subtype == "stylize":
|
| 102 |
+
return self.RenderStylize(temp_frame).astype(np.uint8)
|
| 103 |
+
if subtype == "detailenhance":
|
| 104 |
+
return self.RenderDetailEnhance(temp_frame).astype(np.uint8)
|
| 105 |
+
if subtype == "pencil":
|
| 106 |
+
return self.RenderPencilSketch(temp_frame).astype(np.uint8)
|
| 107 |
+
if subtype == "cartoon":
|
| 108 |
+
return self.RenderCartoon(temp_frame).astype(np.uint8)
|
| 109 |
+
if subtype == "C64":
|
| 110 |
+
return self.RenderC64Screen(temp_frame).astype(np.uint8)
|
| 111 |
+
|
| 112 |
+
def Release(self):
|
| 113 |
+
pass
|
| 114 |
+
|
| 115 |
+
def getProcessedResolution(self, width, height):
|
| 116 |
+
if self.plugin_options["subtype"] == "C64":
|
| 117 |
+
return (320, 200)
|
| 118 |
+
return None
|
processors/Frame_Masking.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import onnxruntime
|
| 4 |
+
import roop.globals
|
| 5 |
+
|
| 6 |
+
from roop.utilities import resolve_relative_path
|
| 7 |
+
from roop.typing import Frame
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Frame_Masking:
|
| 11 |
+
plugin_options: dict = None
|
| 12 |
+
model_masking = None
|
| 13 |
+
devicename = None
|
| 14 |
+
name = None
|
| 15 |
+
|
| 16 |
+
processorname = "removebg"
|
| 17 |
+
type = "frame_masking"
|
| 18 |
+
|
| 19 |
+
def Initialize(self, plugin_options: dict):
|
| 20 |
+
if self.plugin_options is not None:
|
| 21 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
| 22 |
+
self.Release()
|
| 23 |
+
|
| 24 |
+
self.plugin_options = plugin_options
|
| 25 |
+
if self.model_masking is None:
|
| 26 |
+
# replace Mac mps with cpu for the moment
|
| 27 |
+
self.devicename = self.plugin_options["devicename"]
|
| 28 |
+
self.devicename = self.devicename.replace("mps", "cpu")
|
| 29 |
+
model_path = resolve_relative_path("../models/Frame/isnet-general-use.onnx")
|
| 30 |
+
self.model_masking = onnxruntime.InferenceSession(
|
| 31 |
+
model_path, None, providers=roop.globals.execution_providers
|
| 32 |
+
)
|
| 33 |
+
self.model_inputs = self.model_masking.get_inputs()
|
| 34 |
+
model_outputs = self.model_masking.get_outputs()
|
| 35 |
+
self.io_binding = self.model_masking.io_binding()
|
| 36 |
+
self.io_binding.bind_output(model_outputs[0].name, self.devicename)
|
| 37 |
+
|
| 38 |
+
def Run(self, temp_frame: Frame) -> Frame:
|
| 39 |
+
# Pre process:Resize, BGR->RGB, float32 cast
|
| 40 |
+
input_image = cv2.resize(temp_frame, (1024, 1024))
|
| 41 |
+
input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
|
| 42 |
+
mean = [0.5, 0.5, 0.5]
|
| 43 |
+
std = [1.0, 1.0, 1.0]
|
| 44 |
+
input_image = (input_image / 255.0 - mean) / std
|
| 45 |
+
input_image = input_image.transpose(2, 0, 1)
|
| 46 |
+
input_image = np.expand_dims(input_image, axis=0)
|
| 47 |
+
input_image = input_image.astype("float32")
|
| 48 |
+
|
| 49 |
+
self.io_binding.bind_cpu_input(self.model_inputs[0].name, input_image)
|
| 50 |
+
self.model_masking.run_with_iobinding(self.io_binding)
|
| 51 |
+
ort_outs = self.io_binding.copy_outputs_to_cpu()
|
| 52 |
+
result = ort_outs[0][0]
|
| 53 |
+
del ort_outs
|
| 54 |
+
# Post process:squeeze, Sigmoid, Normarize, uint8 cast
|
| 55 |
+
mask = np.squeeze(result[0])
|
| 56 |
+
min_value = np.min(mask)
|
| 57 |
+
max_value = np.max(mask)
|
| 58 |
+
mask = (mask - min_value) / (max_value - min_value)
|
| 59 |
+
# mask = np.where(mask < score_th, 0, 1)
|
| 60 |
+
# mask *= 255
|
| 61 |
+
mask = cv2.resize(
|
| 62 |
+
mask,
|
| 63 |
+
(temp_frame.shape[1], temp_frame.shape[0]),
|
| 64 |
+
interpolation=cv2.INTER_LINEAR,
|
| 65 |
+
)
|
| 66 |
+
mask = np.reshape(mask, [mask.shape[0], mask.shape[1], 1])
|
| 67 |
+
result = mask * temp_frame.astype(np.float32)
|
| 68 |
+
return result.astype(np.uint8)
|
| 69 |
+
|
| 70 |
+
def Release(self):
|
| 71 |
+
del self.model_masking
|
| 72 |
+
self.model_masking = None
|
| 73 |
+
del self.io_binding
|
| 74 |
+
self.io_binding = None
|
processors/Frame_Upscale.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import onnxruntime
|
| 4 |
+
import roop.globals
|
| 5 |
+
|
| 6 |
+
from roop.utilities import resolve_relative_path, conditional_thread_semaphore
|
| 7 |
+
from roop.typing import Frame
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Frame_Upscale:
|
| 11 |
+
plugin_options: dict = None
|
| 12 |
+
model_upscale = None
|
| 13 |
+
devicename = None
|
| 14 |
+
prev_type = None
|
| 15 |
+
|
| 16 |
+
processorname = "upscale"
|
| 17 |
+
type = "frame_enhancer"
|
| 18 |
+
|
| 19 |
+
def Initialize(self, plugin_options: dict):
|
| 20 |
+
if self.plugin_options is not None:
|
| 21 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
| 22 |
+
self.Release()
|
| 23 |
+
|
| 24 |
+
self.plugin_options = plugin_options
|
| 25 |
+
if (
|
| 26 |
+
self.prev_type is not None
|
| 27 |
+
and self.prev_type != self.plugin_options["subtype"]
|
| 28 |
+
):
|
| 29 |
+
self.Release()
|
| 30 |
+
self.prev_type = self.plugin_options["subtype"]
|
| 31 |
+
if self.model_upscale is None:
|
| 32 |
+
# replace Mac mps with cpu for the moment
|
| 33 |
+
self.devicename = self.plugin_options["devicename"].replace("mps", "cpu")
|
| 34 |
+
if self.prev_type == "esrganx4":
|
| 35 |
+
model_path = resolve_relative_path(
|
| 36 |
+
"../models/Frame/real_esrgan_x4.onnx"
|
| 37 |
+
)
|
| 38 |
+
self.scale = 4
|
| 39 |
+
elif self.prev_type == "esrganx2":
|
| 40 |
+
model_path = resolve_relative_path(
|
| 41 |
+
"../models/Frame/real_esrgan_x2.onnx"
|
| 42 |
+
)
|
| 43 |
+
self.scale = 2
|
| 44 |
+
elif self.prev_type == "lsdirx4":
|
| 45 |
+
model_path = resolve_relative_path("../models/Frame/lsdir_x4.onnx")
|
| 46 |
+
self.scale = 4
|
| 47 |
+
onnxruntime.set_default_logger_severity(3)
|
| 48 |
+
self.model_upscale = onnxruntime.InferenceSession(
|
| 49 |
+
model_path, None, providers=roop.globals.execution_providers
|
| 50 |
+
)
|
| 51 |
+
self.model_inputs = self.model_upscale.get_inputs()
|
| 52 |
+
model_outputs = self.model_upscale.get_outputs()
|
| 53 |
+
self.io_binding = self.model_upscale.io_binding()
|
| 54 |
+
self.io_binding.bind_output(model_outputs[0].name, self.devicename)
|
| 55 |
+
|
| 56 |
+
def getProcessedResolution(self, width, height):
|
| 57 |
+
return (width * self.scale, height * self.scale)
|
| 58 |
+
|
| 59 |
+
# borrowed from facefusion -> https://github.com/facefusion/facefusion
|
| 60 |
+
def prepare_tile_frame(self, tile_frame: Frame) -> Frame:
|
| 61 |
+
tile_frame = np.expand_dims(tile_frame[:, :, ::-1], axis=0)
|
| 62 |
+
tile_frame = tile_frame.transpose(0, 3, 1, 2)
|
| 63 |
+
tile_frame = tile_frame.astype(np.float32) / 255
|
| 64 |
+
return tile_frame
|
| 65 |
+
|
| 66 |
+
def normalize_tile_frame(self, tile_frame: Frame) -> Frame:
|
| 67 |
+
tile_frame = tile_frame.transpose(0, 2, 3, 1).squeeze(0) * 255
|
| 68 |
+
tile_frame = tile_frame.clip(0, 255).astype(np.uint8)[:, :, ::-1]
|
| 69 |
+
return tile_frame
|
| 70 |
+
|
| 71 |
+
def create_tile_frames(self, input_frame: Frame, size):
|
| 72 |
+
input_frame = np.pad(
|
| 73 |
+
input_frame, ((size[1], size[1]), (size[1], size[1]), (0, 0))
|
| 74 |
+
)
|
| 75 |
+
tile_width = size[0] - 2 * size[2]
|
| 76 |
+
pad_size_bottom = size[2] + tile_width - input_frame.shape[0] % tile_width
|
| 77 |
+
pad_size_right = size[2] + tile_width - input_frame.shape[1] % tile_width
|
| 78 |
+
pad_vision_frame = np.pad(
|
| 79 |
+
input_frame, ((size[2], pad_size_bottom), (size[2], pad_size_right), (0, 0))
|
| 80 |
+
)
|
| 81 |
+
pad_height, pad_width = pad_vision_frame.shape[:2]
|
| 82 |
+
row_range = range(size[2], pad_height - size[2], tile_width)
|
| 83 |
+
col_range = range(size[2], pad_width - size[2], tile_width)
|
| 84 |
+
tile_frames = []
|
| 85 |
+
|
| 86 |
+
for row_frame in row_range:
|
| 87 |
+
top = row_frame - size[2]
|
| 88 |
+
bottom = row_frame + size[2] + tile_width
|
| 89 |
+
for column_vision_frame in col_range:
|
| 90 |
+
left = column_vision_frame - size[2]
|
| 91 |
+
right = column_vision_frame + size[2] + tile_width
|
| 92 |
+
tile_frames.append(pad_vision_frame[top:bottom, left:right, :])
|
| 93 |
+
return tile_frames, pad_width, pad_height
|
| 94 |
+
|
| 95 |
+
def merge_tile_frames(
|
| 96 |
+
self,
|
| 97 |
+
tile_frames,
|
| 98 |
+
temp_width: int,
|
| 99 |
+
temp_height: int,
|
| 100 |
+
pad_width: int,
|
| 101 |
+
pad_height: int,
|
| 102 |
+
size,
|
| 103 |
+
) -> Frame:
|
| 104 |
+
merge_frame = np.zeros((pad_height, pad_width, 3)).astype(np.uint8)
|
| 105 |
+
tile_width = tile_frames[0].shape[1] - 2 * size[2]
|
| 106 |
+
tiles_per_row = min(pad_width // tile_width, len(tile_frames))
|
| 107 |
+
|
| 108 |
+
for index, tile_frame in enumerate(tile_frames):
|
| 109 |
+
tile_frame = tile_frame[size[2] : -size[2], size[2] : -size[2]]
|
| 110 |
+
row_index = index // tiles_per_row
|
| 111 |
+
col_index = index % tiles_per_row
|
| 112 |
+
top = row_index * tile_frame.shape[0]
|
| 113 |
+
bottom = top + tile_frame.shape[0]
|
| 114 |
+
left = col_index * tile_frame.shape[1]
|
| 115 |
+
right = left + tile_frame.shape[1]
|
| 116 |
+
merge_frame[top:bottom, left:right, :] = tile_frame
|
| 117 |
+
merge_frame = merge_frame[
|
| 118 |
+
size[1] : size[1] + temp_height, size[1] : size[1] + temp_width, :
|
| 119 |
+
]
|
| 120 |
+
return merge_frame
|
| 121 |
+
|
| 122 |
+
def Run(self, temp_frame: Frame) -> Frame:
|
| 123 |
+
size = (128, 8, 2)
|
| 124 |
+
temp_height, temp_width = temp_frame.shape[:2]
|
| 125 |
+
upscale_tile_frames, pad_width, pad_height = self.create_tile_frames(
|
| 126 |
+
temp_frame, size
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
for index, tile_frame in enumerate(upscale_tile_frames):
|
| 130 |
+
tile_frame = self.prepare_tile_frame(tile_frame)
|
| 131 |
+
with conditional_thread_semaphore():
|
| 132 |
+
self.io_binding.bind_cpu_input(self.model_inputs[0].name, tile_frame)
|
| 133 |
+
self.model_upscale.run_with_iobinding(self.io_binding)
|
| 134 |
+
ort_outs = self.io_binding.copy_outputs_to_cpu()
|
| 135 |
+
result = ort_outs[0]
|
| 136 |
+
upscale_tile_frames[index] = self.normalize_tile_frame(result)
|
| 137 |
+
final_frame = self.merge_tile_frames(
|
| 138 |
+
upscale_tile_frames,
|
| 139 |
+
temp_width * self.scale,
|
| 140 |
+
temp_height * self.scale,
|
| 141 |
+
pad_width * self.scale,
|
| 142 |
+
pad_height * self.scale,
|
| 143 |
+
(size[0] * self.scale, size[1] * self.scale, size[2] * self.scale),
|
| 144 |
+
)
|
| 145 |
+
return final_frame.astype(np.uint8)
|
| 146 |
+
|
| 147 |
+
def Release(self):
|
| 148 |
+
del self.model_upscale
|
| 149 |
+
self.model_upscale = None
|
| 150 |
+
del self.io_binding
|
| 151 |
+
self.io_binding = None
|
processors/Mask_Clip2Seg.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import threading
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
from clip.clipseg import CLIPDensePredT
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from roop.typing import Frame
|
| 10 |
+
|
| 11 |
+
THREAD_LOCK_CLIP = threading.Lock()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Mask_Clip2Seg:
|
| 15 |
+
plugin_options: dict = None
|
| 16 |
+
model_clip = None
|
| 17 |
+
|
| 18 |
+
processorname = "clip2seg"
|
| 19 |
+
type = "mask"
|
| 20 |
+
|
| 21 |
+
def Initialize(self, plugin_options: dict):
|
| 22 |
+
if self.plugin_options is not None:
|
| 23 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
| 24 |
+
self.Release()
|
| 25 |
+
|
| 26 |
+
self.plugin_options = plugin_options
|
| 27 |
+
if self.model_clip is None:
|
| 28 |
+
self.model_clip = CLIPDensePredT(
|
| 29 |
+
version="ViT-B/16", reduce_dim=64, complex_trans_conv=True
|
| 30 |
+
)
|
| 31 |
+
self.model_clip.eval()
|
| 32 |
+
self.model_clip.load_state_dict(
|
| 33 |
+
torch.load(
|
| 34 |
+
"models/CLIP/rd64-uni-refined.pth", map_location=torch.device("cpu")
|
| 35 |
+
),
|
| 36 |
+
strict=False,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
device = torch.device(self.plugin_options["devicename"])
|
| 40 |
+
self.model_clip.to(device)
|
| 41 |
+
|
| 42 |
+
def Run(self, img1, keywords: str) -> Frame:
|
| 43 |
+
if keywords is None or len(keywords) < 1 or img1 is None:
|
| 44 |
+
return img1
|
| 45 |
+
|
| 46 |
+
source_image_small = cv2.resize(img1, (256, 256))
|
| 47 |
+
|
| 48 |
+
img_mask = np.full(
|
| 49 |
+
(source_image_small.shape[0], source_image_small.shape[1]),
|
| 50 |
+
0,
|
| 51 |
+
dtype=np.float32,
|
| 52 |
+
)
|
| 53 |
+
mask_border = 1
|
| 54 |
+
l = 0
|
| 55 |
+
t = 0
|
| 56 |
+
r = 1
|
| 57 |
+
b = 1
|
| 58 |
+
|
| 59 |
+
mask_blur = 5
|
| 60 |
+
clip_blur = 5
|
| 61 |
+
|
| 62 |
+
img_mask = cv2.rectangle(
|
| 63 |
+
img_mask,
|
| 64 |
+
(mask_border + int(l), mask_border + int(t)),
|
| 65 |
+
(256 - mask_border - int(r), 256 - mask_border - int(b)),
|
| 66 |
+
(255, 255, 255),
|
| 67 |
+
-1,
|
| 68 |
+
)
|
| 69 |
+
img_mask = cv2.GaussianBlur(img_mask, (mask_blur * 2 + 1, mask_blur * 2 + 1), 0)
|
| 70 |
+
img_mask /= 255
|
| 71 |
+
|
| 72 |
+
input_image = source_image_small
|
| 73 |
+
|
| 74 |
+
transform = transforms.Compose(
|
| 75 |
+
[
|
| 76 |
+
transforms.ToTensor(),
|
| 77 |
+
transforms.Normalize(
|
| 78 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
| 79 |
+
),
|
| 80 |
+
transforms.Resize((256, 256)),
|
| 81 |
+
]
|
| 82 |
+
)
|
| 83 |
+
img = transform(input_image).unsqueeze(0)
|
| 84 |
+
|
| 85 |
+
thresh = 0.5
|
| 86 |
+
prompts = keywords.split(",")
|
| 87 |
+
with THREAD_LOCK_CLIP:
|
| 88 |
+
with torch.no_grad():
|
| 89 |
+
preds = self.model_clip(img.repeat(len(prompts), 1, 1, 1), prompts)[0]
|
| 90 |
+
clip_mask = torch.sigmoid(preds[0][0])
|
| 91 |
+
for i in range(len(prompts) - 1):
|
| 92 |
+
clip_mask += torch.sigmoid(preds[i + 1][0])
|
| 93 |
+
|
| 94 |
+
clip_mask = clip_mask.data.cpu().numpy()
|
| 95 |
+
np.clip(clip_mask, 0, 1)
|
| 96 |
+
|
| 97 |
+
clip_mask[clip_mask > thresh] = 1.0
|
| 98 |
+
clip_mask[clip_mask <= thresh] = 0.0
|
| 99 |
+
kernel = np.ones((5, 5), np.float32)
|
| 100 |
+
clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
|
| 101 |
+
clip_mask = cv2.GaussianBlur(
|
| 102 |
+
clip_mask, (clip_blur * 2 + 1, clip_blur * 2 + 1), 0
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
img_mask *= clip_mask
|
| 106 |
+
img_mask[img_mask < 0.0] = 0.0
|
| 107 |
+
return img_mask
|
| 108 |
+
|
| 109 |
+
def Release(self):
|
| 110 |
+
self.model_clip = None
|
processors/Mask_XSeg.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
import onnxruntime
|
| 4 |
+
import roop.globals
|
| 5 |
+
|
| 6 |
+
from roop.typing import Frame
|
| 7 |
+
from roop.utilities import resolve_relative_path, conditional_thread_semaphore
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Mask_XSeg:
|
| 11 |
+
plugin_options: dict = None
|
| 12 |
+
|
| 13 |
+
model_xseg = None
|
| 14 |
+
|
| 15 |
+
processorname = "mask_xseg"
|
| 16 |
+
type = "mask"
|
| 17 |
+
|
| 18 |
+
def Initialize(self, plugin_options: dict):
|
| 19 |
+
if self.plugin_options is not None:
|
| 20 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
| 21 |
+
self.Release()
|
| 22 |
+
|
| 23 |
+
self.plugin_options = plugin_options
|
| 24 |
+
if self.model_xseg is None:
|
| 25 |
+
model_path = resolve_relative_path("../models/xseg.onnx")
|
| 26 |
+
onnxruntime.set_default_logger_severity(3)
|
| 27 |
+
self.model_xseg = onnxruntime.InferenceSession(
|
| 28 |
+
model_path, None, providers=roop.globals.execution_providers
|
| 29 |
+
)
|
| 30 |
+
self.model_inputs = self.model_xseg.get_inputs()
|
| 31 |
+
self.model_outputs = self.model_xseg.get_outputs()
|
| 32 |
+
|
| 33 |
+
# replace Mac mps with cpu for the moment
|
| 34 |
+
self.devicename = self.plugin_options["devicename"].replace("mps", "cpu")
|
| 35 |
+
|
| 36 |
+
def Run(self, img1, keywords: str) -> Frame:
|
| 37 |
+
temp_frame = cv2.resize(img1, (256, 256), cv2.INTER_CUBIC)
|
| 38 |
+
temp_frame = temp_frame.astype("float32") / 255.0
|
| 39 |
+
temp_frame = temp_frame[None, ...]
|
| 40 |
+
io_binding = self.model_xseg.io_binding()
|
| 41 |
+
io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame)
|
| 42 |
+
io_binding.bind_output(self.model_outputs[0].name, self.devicename)
|
| 43 |
+
self.model_xseg.run_with_iobinding(io_binding)
|
| 44 |
+
ort_outs = io_binding.copy_outputs_to_cpu()
|
| 45 |
+
result = ort_outs[0][0]
|
| 46 |
+
result = np.clip(result, 0, 1.0)
|
| 47 |
+
result[result < 0.1] = 0
|
| 48 |
+
# invert values to mask areas to keep
|
| 49 |
+
result = 1.0 - result
|
| 50 |
+
return result
|
| 51 |
+
|
| 52 |
+
def Release(self):
|
| 53 |
+
del self.model_xseg
|
| 54 |
+
self.model_xseg = None
|
processors/__init__.py
ADDED
|
File without changes
|
processors/__pycache__/Enhance_CodeFormer.cpython-310.pyc
ADDED
|
Binary file (2.58 kB). View file
|
|
|
processors/__pycache__/Enhance_DMDNet.cpython-310.pyc
ADDED
|
Binary file (30.8 kB). View file
|
|
|
processors/__pycache__/FaceSwapInsightFace.cpython-310.pyc
ADDED
|
Binary file (2.15 kB). View file
|
|
|
processors/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (231 Bytes). View file
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu124
|
| 2 |
+
numpy==1.26.4
|
| 3 |
+
gradio==5.9.1
|
| 4 |
+
opencv-python-headless==4.10.0.84
|
| 5 |
+
onnx==1.16.1
|
| 6 |
+
insightface==0.7.3
|
| 7 |
+
albucore==0.0.16
|
| 8 |
+
psutil==5.9.6
|
| 9 |
+
torch==2.5.1+cu124; sys_platform != 'darwin'
|
| 10 |
+
torch==2.5.1; sys_platform == 'darwin'
|
| 11 |
+
torchvision==0.20.1+cu124; sys_platform != 'darwin'
|
| 12 |
+
torchvision==0.20.1; sys_platform == 'darwin'
|
| 13 |
+
onnxruntime==1.20.1; sys_platform == 'darwin' and platform_machine != 'arm64'
|
| 14 |
+
onnxruntime-silicon==1.20.1; sys_platform == 'darwin' and platform_machine == 'arm64'
|
| 15 |
+
onnxruntime-gpu==1.20.1; sys_platform != 'darwin'
|
| 16 |
+
tqdm==4.66.4
|
| 17 |
+
ftfy
|
| 18 |
+
regex
|
| 19 |
+
pyvirtualcam
|
| 20 |
+
pydantic==2.10.4
|
run.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
from roop import core
|
| 4 |
+
share=True
|
| 5 |
+
def run():
|
| 6 |
+
args = parse_args()
|
| 7 |
+
roop.globals.CFG.server_share = args.share # <-- toggle share here
|
| 8 |
+
...
|
| 9 |
+
if __name__ == "__main__":
|
| 10 |
+
core.run()
|
| 11 |
+
|
template_parser.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from datetime import datetime
|
| 3 |
+
|
| 4 |
+
template_functions = {
|
| 5 |
+
"timestamp": lambda data: str(int(datetime.now().timestamp())),
|
| 6 |
+
"i": lambda data: data.get("index", False),
|
| 7 |
+
"file": lambda data: data.get("file", False),
|
| 8 |
+
"date": lambda data: datetime.now().strftime("%Y-%m-%d"),
|
| 9 |
+
"time": lambda data: datetime.now().strftime("%H-%M-%S"),
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def parse(text: str, data: dict):
|
| 14 |
+
pattern = r"\{([^}]+)\}"
|
| 15 |
+
|
| 16 |
+
matches = re.findall(pattern, text)
|
| 17 |
+
|
| 18 |
+
for match in matches:
|
| 19 |
+
replacement = template_functions[match](data)
|
| 20 |
+
if replacement is not False:
|
| 21 |
+
text = text.replace(f"{{{match}}}", replacement)
|
| 22 |
+
|
| 23 |
+
return text
|
typing.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
from insightface.app.common import Face
|
| 4 |
+
from roop.FaceSet import FaceSet
|
| 5 |
+
import numpy
|
| 6 |
+
|
| 7 |
+
Face = Face
|
| 8 |
+
FaceSet = FaceSet
|
| 9 |
+
Frame = numpy.ndarray[Any, Any]
|