Spaces:
Running
Running
dreamlessx commited on
Commit ·
433e26f
1
Parent(s): c951530
Add error handling, sync package from public repo, auto-trigger on upload
Browse files- Wrap all processing functions in try/except with informative error messages
- Add show_error=True for visible error reporting
- Auto-trigger processing when image is uploaded (not just on button click)
- Sync bundled landmarkdiff package to match public repo v0.2.0
- Add .gitignore for __pycache__
- .gitignore +1 -0
- app.py +99 -70
- landmarkdiff/__init__.py +1 -1
- landmarkdiff/__main__.py +111 -40
- landmarkdiff/api_client.py +15 -14
- landmarkdiff/arcface_torch.py +40 -22
- landmarkdiff/audit.py +5 -9
- landmarkdiff/augmentation.py +293 -0
- landmarkdiff/benchmark.py +17 -12
- landmarkdiff/checkpoint_manager.py +10 -14
- landmarkdiff/cli.py +19 -13
- landmarkdiff/clinical.py +3 -3
- landmarkdiff/conditioning.py +115 -9
- landmarkdiff/config.py +24 -11
- landmarkdiff/curriculum.py +6 -9
- landmarkdiff/data.py +11 -13
- landmarkdiff/data_version.py +5 -8
- landmarkdiff/displacement_model.py +47 -43
- landmarkdiff/ensemble.py +35 -23
- landmarkdiff/evaluation.py +12 -11
- landmarkdiff/experiment_tracker.py +215 -0
- landmarkdiff/face_verifier.py +51 -39
- landmarkdiff/fid.py +85 -79
- landmarkdiff/hyperparam.py +36 -19
- landmarkdiff/inference.py +90 -38
- landmarkdiff/landmarks.py +131 -23
- landmarkdiff/log.py +6 -4
- landmarkdiff/losses.py +17 -6
- landmarkdiff/manipulation.py +335 -138
- landmarkdiff/metrics_agg.py +19 -14
- landmarkdiff/metrics_viz.py +62 -37
- landmarkdiff/model_registry.py +9 -16
- landmarkdiff/postprocess.py +27 -29
- landmarkdiff/py.typed +0 -0
- landmarkdiff/safety.py +7 -12
- landmarkdiff/synthetic/__init__.py +23 -0
- landmarkdiff/synthetic/augmentation.py +3 -3
- landmarkdiff/synthetic/pair_generator.py +9 -13
- landmarkdiff/synthetic/tps_warp.py +4 -5
- landmarkdiff/validation.py +14 -8
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.pyc
|
app.py
CHANGED
|
@@ -2,6 +2,9 @@
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
|
|
|
|
|
|
| 5 |
import cv2
|
| 6 |
import gradio as gr
|
| 7 |
import numpy as np
|
|
@@ -11,6 +14,9 @@ from landmarkdiff.landmarks import extract_landmarks
|
|
| 11 |
from landmarkdiff.manipulation import PROCEDURE_LANDMARKS, apply_procedure_preset
|
| 12 |
from landmarkdiff.masking import generate_surgical_mask
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
VERSION = "v0.2.1"
|
| 15 |
|
| 16 |
GITHUB_URL = "https://github.com/dreamlessx/LandmarkDiff-public"
|
|
@@ -44,17 +50,31 @@ def mask_composite(warped, original, mask):
|
|
| 44 |
PROCEDURES = list(PROCEDURE_LANDMARKS.keys())
|
| 45 |
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
def process_image(image_rgb, procedure, intensity):
|
| 48 |
"""Process a single image through the TPS pipeline."""
|
| 49 |
if image_rgb is None:
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
-
image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
|
| 54 |
-
image_bgr = cv2.resize(image_bgr, (512, 512))
|
| 55 |
-
image_rgb_512 = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
|
| 56 |
-
|
| 57 |
-
face = extract_landmarks(image_bgr)
|
| 58 |
if face is None:
|
| 59 |
return (
|
| 60 |
image_rgb_512,
|
|
@@ -64,38 +84,38 @@ def process_image(image_rgb, procedure, intensity):
|
|
| 64 |
"No face detected. Try a clearer photo with good lighting.",
|
| 65 |
)
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
wireframe_rgb = cv2.cvtColor(wireframe, cv2.COLOR_GRAY2RGB)
|
| 73 |
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
mask_vis = (mask * 255).astype(np.uint8)
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
composited_rgb = cv2.cvtColor(composited, cv2.COLOR_BGR2RGB)
|
| 82 |
|
| 83 |
-
|
| 84 |
-
side_by_side = np.hstack([image_rgb_512, composited_rgb])
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
|
|
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
| 97 |
|
| 98 |
-
|
|
|
|
|
|
|
| 99 |
|
| 100 |
|
| 101 |
def compare_procedures(image_rgb, intensity):
|
|
@@ -104,23 +124,28 @@ def compare_procedures(image_rgb, intensity):
|
|
| 104 |
blank = np.zeros((512, 512, 3), dtype=np.uint8)
|
| 105 |
return [blank] * len(PROCEDURES)
|
| 106 |
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
|
| 126 |
def intensity_sweep(image_rgb, procedure):
|
|
@@ -128,27 +153,31 @@ def intensity_sweep(image_rgb, procedure):
|
|
| 128 |
if image_rgb is None:
|
| 129 |
return []
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
return []
|
| 137 |
|
| 138 |
-
steps = [0, 20, 40, 60, 80, 100]
|
| 139 |
-
results = []
|
| 140 |
-
for val in steps:
|
| 141 |
-
if val == 0:
|
| 142 |
-
results.append((cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB), "0%"))
|
| 143 |
-
continue
|
| 144 |
-
manip = apply_procedure_preset(face, procedure, float(val), image_size=512)
|
| 145 |
-
mask = generate_surgical_mask(face, procedure, 512, 512)
|
| 146 |
-
warped = warp_image_tps(image_bgr, face.pixel_coords, manip.pixel_coords)
|
| 147 |
-
comp = mask_composite(warped, image_bgr, mask)
|
| 148 |
-
results.append((cv2.cvtColor(comp, cv2.COLOR_BGR2RGB), f"{val}%"))
|
| 149 |
-
|
| 150 |
-
return results
|
| 151 |
-
|
| 152 |
|
| 153 |
# -- Build the procedure table for the description --
|
| 154 |
_proc_rows = "\n".join(
|
|
@@ -248,7 +277,7 @@ with gr.Blocks(
|
|
| 248 |
inputs=[input_image, procedure, intensity],
|
| 249 |
outputs=[out_wireframe, out_mask, out_result, out_sidebyside, info_box],
|
| 250 |
)
|
| 251 |
-
for trigger in [procedure, intensity]:
|
| 252 |
trigger.change(
|
| 253 |
fn=process_image,
|
| 254 |
inputs=[input_image, procedure, intensity],
|
|
@@ -310,4 +339,4 @@ with gr.Blocks(
|
|
| 310 |
gr.Markdown(FOOTER_MD)
|
| 311 |
|
| 312 |
if __name__ == "__main__":
|
| 313 |
-
demo.launch()
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
+
import logging
|
| 6 |
+
import traceback
|
| 7 |
+
|
| 8 |
import cv2
|
| 9 |
import gradio as gr
|
| 10 |
import numpy as np
|
|
|
|
| 14 |
from landmarkdiff.manipulation import PROCEDURE_LANDMARKS, apply_procedure_preset
|
| 15 |
from landmarkdiff.masking import generate_surgical_mask
|
| 16 |
|
| 17 |
+
logging.basicConfig(level=logging.INFO)
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
VERSION = "v0.2.1"
|
| 21 |
|
| 22 |
GITHUB_URL = "https://github.com/dreamlessx/LandmarkDiff-public"
|
|
|
|
| 50 |
PROCEDURES = list(PROCEDURE_LANDMARKS.keys())
|
| 51 |
|
| 52 |
|
| 53 |
+
def _error_result(msg):
|
| 54 |
+
"""Return a 5-tuple of blanks + error message for the UI."""
|
| 55 |
+
blank = np.zeros((512, 512, 3), dtype=np.uint8)
|
| 56 |
+
return blank, blank, blank, blank, msg
|
| 57 |
+
|
| 58 |
+
|
| 59 |
def process_image(image_rgb, procedure, intensity):
|
| 60 |
"""Process a single image through the TPS pipeline."""
|
| 61 |
if image_rgb is None:
|
| 62 |
+
return _error_result("Upload a face photo to begin.")
|
| 63 |
+
|
| 64 |
+
try:
|
| 65 |
+
image_bgr = cv2.cvtColor(np.asarray(image_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
|
| 66 |
+
image_bgr = cv2.resize(image_bgr, (512, 512))
|
| 67 |
+
image_rgb_512 = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
|
| 68 |
+
except Exception as exc:
|
| 69 |
+
logger.error("Image conversion failed: %s", exc)
|
| 70 |
+
return _error_result(f"Image conversion failed: {exc}")
|
| 71 |
+
|
| 72 |
+
try:
|
| 73 |
+
face = extract_landmarks(image_bgr)
|
| 74 |
+
except Exception as exc:
|
| 75 |
+
logger.error("Landmark extraction failed: %s\n%s", exc, traceback.format_exc())
|
| 76 |
+
return _error_result(f"Landmark extraction error: {exc}")
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
if face is None:
|
| 79 |
return (
|
| 80 |
image_rgb_512,
|
|
|
|
| 84 |
"No face detected. Try a clearer photo with good lighting.",
|
| 85 |
)
|
| 86 |
|
| 87 |
+
try:
|
| 88 |
+
manipulated = apply_procedure_preset(face, procedure, float(intensity), image_size=512)
|
| 89 |
|
| 90 |
+
wireframe = render_wireframe(manipulated, width=512, height=512)
|
| 91 |
+
wireframe_rgb = cv2.cvtColor(wireframe, cv2.COLOR_GRAY2RGB)
|
|
|
|
| 92 |
|
| 93 |
+
mask = generate_surgical_mask(face, procedure, 512, 512)
|
| 94 |
+
mask_vis = (mask * 255).astype(np.uint8)
|
|
|
|
| 95 |
|
| 96 |
+
warped = warp_image_tps(image_bgr, face.pixel_coords, manipulated.pixel_coords)
|
| 97 |
+
composited = mask_composite(warped, image_bgr, mask)
|
| 98 |
+
composited_rgb = cv2.cvtColor(composited, cv2.COLOR_BGR2RGB)
|
|
|
|
| 99 |
|
| 100 |
+
side_by_side = np.hstack([image_rgb_512, composited_rgb])
|
|
|
|
| 101 |
|
| 102 |
+
displacement = np.mean(
|
| 103 |
+
np.linalg.norm(manipulated.pixel_coords - face.pixel_coords, axis=1)
|
| 104 |
+
)
|
| 105 |
|
| 106 |
+
info = (
|
| 107 |
+
f"Procedure: {procedure}\n"
|
| 108 |
+
f"Intensity: {intensity:.0f}%\n"
|
| 109 |
+
f"Landmarks: {len(face.landmarks)}\n"
|
| 110 |
+
f"Avg displacement: {displacement:.1f} px\n"
|
| 111 |
+
f"Confidence: {face.confidence:.2f}\n"
|
| 112 |
+
f"Mode: TPS (CPU)"
|
| 113 |
+
)
|
| 114 |
+
return wireframe_rgb, mask_vis, composited_rgb, side_by_side, info
|
| 115 |
|
| 116 |
+
except Exception as exc:
|
| 117 |
+
logger.error("Processing failed: %s\n%s", exc, traceback.format_exc())
|
| 118 |
+
return _error_result(f"Processing error: {exc}")
|
| 119 |
|
| 120 |
|
| 121 |
def compare_procedures(image_rgb, intensity):
|
|
|
|
| 124 |
blank = np.zeros((512, 512, 3), dtype=np.uint8)
|
| 125 |
return [blank] * len(PROCEDURES)
|
| 126 |
|
| 127 |
+
try:
|
| 128 |
+
image_bgr = cv2.cvtColor(np.asarray(image_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
|
| 129 |
+
image_bgr = cv2.resize(image_bgr, (512, 512))
|
| 130 |
+
|
| 131 |
+
face = extract_landmarks(image_bgr)
|
| 132 |
+
if face is None:
|
| 133 |
+
rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
|
| 134 |
+
return [rgb] * len(PROCEDURES)
|
| 135 |
+
|
| 136 |
+
results = []
|
| 137 |
+
for proc in PROCEDURES:
|
| 138 |
+
manip = apply_procedure_preset(face, proc, float(intensity), image_size=512)
|
| 139 |
+
mask = generate_surgical_mask(face, proc, 512, 512)
|
| 140 |
+
warped = warp_image_tps(image_bgr, face.pixel_coords, manip.pixel_coords)
|
| 141 |
+
comp = mask_composite(warped, image_bgr, mask)
|
| 142 |
+
results.append(cv2.cvtColor(comp, cv2.COLOR_BGR2RGB))
|
| 143 |
+
|
| 144 |
+
return results
|
| 145 |
+
except Exception as exc:
|
| 146 |
+
logger.error("Compare procedures failed: %s\n%s", exc, traceback.format_exc())
|
| 147 |
+
blank = np.zeros((512, 512, 3), dtype=np.uint8)
|
| 148 |
+
return [blank] * len(PROCEDURES)
|
| 149 |
|
| 150 |
|
| 151 |
def intensity_sweep(image_rgb, procedure):
|
|
|
|
| 153 |
if image_rgb is None:
|
| 154 |
return []
|
| 155 |
|
| 156 |
+
try:
|
| 157 |
+
image_bgr = cv2.cvtColor(np.asarray(image_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
|
| 158 |
+
image_bgr = cv2.resize(image_bgr, (512, 512))
|
| 159 |
+
|
| 160 |
+
face = extract_landmarks(image_bgr)
|
| 161 |
+
if face is None:
|
| 162 |
+
return []
|
| 163 |
+
|
| 164 |
+
steps = [0, 20, 40, 60, 80, 100]
|
| 165 |
+
results = []
|
| 166 |
+
for val in steps:
|
| 167 |
+
if val == 0:
|
| 168 |
+
results.append((cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB), "0%"))
|
| 169 |
+
continue
|
| 170 |
+
manip = apply_procedure_preset(face, procedure, float(val), image_size=512)
|
| 171 |
+
mask = generate_surgical_mask(face, procedure, 512, 512)
|
| 172 |
+
warped = warp_image_tps(image_bgr, face.pixel_coords, manip.pixel_coords)
|
| 173 |
+
comp = mask_composite(warped, image_bgr, mask)
|
| 174 |
+
results.append((cv2.cvtColor(comp, cv2.COLOR_BGR2RGB), f"{val}%"))
|
| 175 |
+
|
| 176 |
+
return results
|
| 177 |
+
except Exception as exc:
|
| 178 |
+
logger.error("Intensity sweep failed: %s\n%s", exc, traceback.format_exc())
|
| 179 |
return []
|
| 180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
# -- Build the procedure table for the description --
|
| 183 |
_proc_rows = "\n".join(
|
|
|
|
| 277 |
inputs=[input_image, procedure, intensity],
|
| 278 |
outputs=[out_wireframe, out_mask, out_result, out_sidebyside, info_box],
|
| 279 |
)
|
| 280 |
+
for trigger in [input_image, procedure, intensity]:
|
| 281 |
trigger.change(
|
| 282 |
fn=process_image,
|
| 283 |
inputs=[input_image, procedure, intensity],
|
|
|
|
| 339 |
gr.Markdown(FOOTER_MD)
|
| 340 |
|
| 341 |
if __name__ == "__main__":
|
| 342 |
+
demo.launch(show_error=True)
|
landmarkdiff/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""LandmarkDiff: Anatomically-conditioned latent diffusion for facial surgery simulation."""
|
| 2 |
|
| 3 |
-
__version__ = "0.
|
| 4 |
|
| 5 |
__all__ = [
|
| 6 |
"api_client",
|
|
|
|
| 1 |
"""LandmarkDiff: Anatomically-conditioned latent diffusion for facial surgery simulation."""
|
| 2 |
|
| 3 |
+
__version__ = "0.2.0"
|
| 4 |
|
| 5 |
__all__ = [
|
| 6 |
"api_client",
|
landmarkdiff/__main__.py
CHANGED
|
@@ -1,80 +1,148 @@
|
|
| 1 |
"""CLI entry point for python -m landmarkdiff."""
|
| 2 |
|
|
|
|
|
|
|
| 3 |
import argparse
|
| 4 |
import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
def main():
|
| 8 |
parser = argparse.ArgumentParser(
|
| 9 |
prog="landmarkdiff",
|
| 10 |
description="Facial surgery outcome prediction from clinical photography",
|
| 11 |
)
|
| 12 |
-
parser.add_argument("--version", action="
|
| 13 |
|
| 14 |
subparsers = parser.add_subparsers(dest="command")
|
| 15 |
|
| 16 |
# inference
|
| 17 |
infer = subparsers.add_parser("infer", help="Run inference on an image")
|
| 18 |
infer.add_argument("image", type=str, help="Path to input face image")
|
| 19 |
-
infer.add_argument(
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# landmarks
|
| 30 |
lm = subparsers.add_parser("landmarks", help="Extract and visualize landmarks")
|
| 31 |
lm.add_argument("image", type=str, help="Path to input face image")
|
| 32 |
-
lm.add_argument(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
# demo
|
| 35 |
subparsers.add_parser("demo", help="Launch Gradio web demo")
|
| 36 |
|
| 37 |
args = parser.parse_args()
|
| 38 |
|
| 39 |
-
if args.version:
|
| 40 |
-
from landmarkdiff import __version__
|
| 41 |
-
print(f"landmarkdiff {__version__}")
|
| 42 |
-
return
|
| 43 |
-
|
| 44 |
if args.command is None:
|
| 45 |
parser.print_help()
|
| 46 |
return
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
import numpy as np
|
| 59 |
from PIL import Image
|
|
|
|
| 60 |
from landmarkdiff.landmarks import extract_landmarks
|
| 61 |
from landmarkdiff.manipulation import apply_procedure_preset
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
output_dir = Path(args.output)
|
| 64 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 65 |
|
| 66 |
-
img = Image.open(
|
| 67 |
img_array = np.array(img)
|
| 68 |
|
| 69 |
landmarks = extract_landmarks(img_array)
|
| 70 |
if landmarks is None:
|
| 71 |
-
|
| 72 |
-
sys.exit(1)
|
| 73 |
|
| 74 |
deformed = apply_procedure_preset(landmarks, args.procedure, intensity=args.intensity)
|
| 75 |
|
| 76 |
if args.mode == "tps":
|
| 77 |
from landmarkdiff.synthetic.tps_warp import warp_image_tps
|
|
|
|
| 78 |
src = landmarks.pixel_coords[:, :2].copy()
|
| 79 |
dst = deformed.pixel_coords[:, :2].copy()
|
| 80 |
src[:, 0] *= 512 / landmarks.image_width
|
|
@@ -85,8 +153,11 @@ def _run_inference(args):
|
|
| 85 |
Image.fromarray(warped).save(str(output_dir / "prediction.png"))
|
| 86 |
print(f"saved tps result to {output_dir / 'prediction.png'}")
|
| 87 |
else:
|
|
|
|
|
|
|
| 88 |
from landmarkdiff.inference import LandmarkDiffPipeline
|
| 89 |
-
|
|
|
|
| 90 |
pipeline.load()
|
| 91 |
result = pipeline.generate(
|
| 92 |
img_array,
|
|
@@ -99,37 +170,37 @@ def _run_inference(args):
|
|
| 99 |
print(f"saved result to {output_dir / 'prediction.png'}")
|
| 100 |
|
| 101 |
|
| 102 |
-
def _run_landmarks(args):
|
| 103 |
-
from pathlib import Path
|
| 104 |
import numpy as np
|
| 105 |
from PIL import Image
|
|
|
|
| 106 |
from landmarkdiff.landmarks import extract_landmarks, render_landmark_image
|
| 107 |
|
| 108 |
-
|
|
|
|
|
|
|
| 109 |
landmarks = extract_landmarks(img)
|
| 110 |
if landmarks is None:
|
| 111 |
-
|
| 112 |
-
sys.exit(1)
|
| 113 |
|
| 114 |
mesh = render_landmark_image(landmarks, 512, 512)
|
| 115 |
|
| 116 |
output_path = Path(args.output)
|
| 117 |
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 118 |
|
| 119 |
-
from PIL import Image
|
| 120 |
Image.fromarray(mesh).save(str(output_path))
|
| 121 |
print(f"saved landmark mesh to {output_path}")
|
| 122 |
print(f"detected {len(landmarks.landmarks)} landmarks, confidence {landmarks.confidence:.2f}")
|
| 123 |
|
| 124 |
|
| 125 |
-
def _run_demo():
|
| 126 |
try:
|
| 127 |
from scripts.app import build_app
|
|
|
|
| 128 |
demo = build_app()
|
| 129 |
demo.launch()
|
| 130 |
except ImportError:
|
| 131 |
-
|
| 132 |
-
sys.exit(1)
|
| 133 |
|
| 134 |
|
| 135 |
if __name__ == "__main__":
|
|
|
|
| 1 |
"""CLI entry point for python -m landmarkdiff."""
|
| 2 |
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
import argparse
|
| 6 |
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import NoReturn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _error(msg: str) -> NoReturn:
|
| 12 |
+
"""Print error to stderr and exit."""
|
| 13 |
+
print(f"error: {msg}", file=sys.stderr)
|
| 14 |
+
sys.exit(1)
|
| 15 |
+
|
| 16 |
|
| 17 |
+
def _validate_image_path(path_str: str) -> Path:
|
| 18 |
+
"""Validate that the image path exists and looks like an image file."""
|
| 19 |
+
p = Path(path_str)
|
| 20 |
+
if not p.exists():
|
| 21 |
+
_error(f"file not found: {path_str}")
|
| 22 |
+
if not p.is_file():
|
| 23 |
+
_error(f"not a file: {path_str}")
|
| 24 |
+
return p
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main() -> None:
|
| 28 |
+
from landmarkdiff import __version__
|
| 29 |
|
|
|
|
| 30 |
parser = argparse.ArgumentParser(
|
| 31 |
prog="landmarkdiff",
|
| 32 |
description="Facial surgery outcome prediction from clinical photography",
|
| 33 |
)
|
| 34 |
+
parser.add_argument("--version", action="version", version=f"landmarkdiff {__version__}")
|
| 35 |
|
| 36 |
subparsers = parser.add_subparsers(dest="command")
|
| 37 |
|
| 38 |
# inference
|
| 39 |
infer = subparsers.add_parser("infer", help="Run inference on an image")
|
| 40 |
infer.add_argument("image", type=str, help="Path to input face image")
|
| 41 |
+
infer.add_argument(
|
| 42 |
+
"--procedure",
|
| 43 |
+
type=str,
|
| 44 |
+
default="rhinoplasty",
|
| 45 |
+
choices=[
|
| 46 |
+
"rhinoplasty",
|
| 47 |
+
"blepharoplasty",
|
| 48 |
+
"rhytidectomy",
|
| 49 |
+
"orthognathic",
|
| 50 |
+
"brow_lift",
|
| 51 |
+
"mentoplasty",
|
| 52 |
+
],
|
| 53 |
+
help="Surgical procedure to simulate (default: rhinoplasty)",
|
| 54 |
+
)
|
| 55 |
+
infer.add_argument(
|
| 56 |
+
"--intensity",
|
| 57 |
+
type=float,
|
| 58 |
+
default=60.0,
|
| 59 |
+
help="Deformation intensity, 0-100 (default: 60)",
|
| 60 |
+
)
|
| 61 |
+
infer.add_argument(
|
| 62 |
+
"--mode",
|
| 63 |
+
type=str,
|
| 64 |
+
default="tps",
|
| 65 |
+
choices=["tps", "controlnet", "img2img", "controlnet_ip"],
|
| 66 |
+
help="Inference mode (default: tps, others require GPU)",
|
| 67 |
+
)
|
| 68 |
+
infer.add_argument(
|
| 69 |
+
"--output",
|
| 70 |
+
type=str,
|
| 71 |
+
default="output/",
|
| 72 |
+
help="Output directory (default: output/)",
|
| 73 |
+
)
|
| 74 |
+
infer.add_argument(
|
| 75 |
+
"--steps",
|
| 76 |
+
type=int,
|
| 77 |
+
default=30,
|
| 78 |
+
help="Number of diffusion steps (default: 30)",
|
| 79 |
+
)
|
| 80 |
+
infer.add_argument(
|
| 81 |
+
"--seed",
|
| 82 |
+
type=int,
|
| 83 |
+
default=None,
|
| 84 |
+
help="Random seed for reproducibility",
|
| 85 |
+
)
|
| 86 |
|
| 87 |
# landmarks
|
| 88 |
lm = subparsers.add_parser("landmarks", help="Extract and visualize landmarks")
|
| 89 |
lm.add_argument("image", type=str, help="Path to input face image")
|
| 90 |
+
lm.add_argument(
|
| 91 |
+
"--output",
|
| 92 |
+
type=str,
|
| 93 |
+
default="output/landmarks.png",
|
| 94 |
+
help="Output path for landmark visualization (default: output/landmarks.png)",
|
| 95 |
+
)
|
| 96 |
|
| 97 |
# demo
|
| 98 |
subparsers.add_parser("demo", help="Launch Gradio web demo")
|
| 99 |
|
| 100 |
args = parser.parse_args()
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
if args.command is None:
|
| 103 |
parser.print_help()
|
| 104 |
return
|
| 105 |
|
| 106 |
+
try:
|
| 107 |
+
if args.command == "infer":
|
| 108 |
+
_run_inference(args)
|
| 109 |
+
elif args.command == "landmarks":
|
| 110 |
+
_run_landmarks(args)
|
| 111 |
+
elif args.command == "demo":
|
| 112 |
+
_run_demo()
|
| 113 |
+
except KeyboardInterrupt:
|
| 114 |
+
sys.exit(130)
|
| 115 |
+
except Exception as exc:
|
| 116 |
+
_error(str(exc))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _run_inference(args: argparse.Namespace) -> None:
|
| 120 |
import numpy as np
|
| 121 |
from PIL import Image
|
| 122 |
+
|
| 123 |
from landmarkdiff.landmarks import extract_landmarks
|
| 124 |
from landmarkdiff.manipulation import apply_procedure_preset
|
| 125 |
|
| 126 |
+
if not (0 <= args.intensity <= 100):
|
| 127 |
+
_error(f"intensity must be between 0 and 100, got {args.intensity}")
|
| 128 |
+
|
| 129 |
+
image_path = _validate_image_path(args.image)
|
| 130 |
+
|
| 131 |
output_dir = Path(args.output)
|
| 132 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 133 |
|
| 134 |
+
img = Image.open(image_path).convert("RGB").resize((512, 512))
|
| 135 |
img_array = np.array(img)
|
| 136 |
|
| 137 |
landmarks = extract_landmarks(img_array)
|
| 138 |
if landmarks is None:
|
| 139 |
+
_error("no face detected in image")
|
|
|
|
| 140 |
|
| 141 |
deformed = apply_procedure_preset(landmarks, args.procedure, intensity=args.intensity)
|
| 142 |
|
| 143 |
if args.mode == "tps":
|
| 144 |
from landmarkdiff.synthetic.tps_warp import warp_image_tps
|
| 145 |
+
|
| 146 |
src = landmarks.pixel_coords[:, :2].copy()
|
| 147 |
dst = deformed.pixel_coords[:, :2].copy()
|
| 148 |
src[:, 0] *= 512 / landmarks.image_width
|
|
|
|
| 153 |
Image.fromarray(warped).save(str(output_dir / "prediction.png"))
|
| 154 |
print(f"saved tps result to {output_dir / 'prediction.png'}")
|
| 155 |
else:
|
| 156 |
+
import torch
|
| 157 |
+
|
| 158 |
from landmarkdiff.inference import LandmarkDiffPipeline
|
| 159 |
+
|
| 160 |
+
pipeline = LandmarkDiffPipeline(mode=args.mode, device=torch.device("cuda"))
|
| 161 |
pipeline.load()
|
| 162 |
result = pipeline.generate(
|
| 163 |
img_array,
|
|
|
|
| 170 |
print(f"saved result to {output_dir / 'prediction.png'}")
|
| 171 |
|
| 172 |
|
| 173 |
+
def _run_landmarks(args: argparse.Namespace) -> None:
|
|
|
|
| 174 |
import numpy as np
|
| 175 |
from PIL import Image
|
| 176 |
+
|
| 177 |
from landmarkdiff.landmarks import extract_landmarks, render_landmark_image
|
| 178 |
|
| 179 |
+
image_path = _validate_image_path(args.image)
|
| 180 |
+
|
| 181 |
+
img = np.array(Image.open(image_path).convert("RGB").resize((512, 512)))
|
| 182 |
landmarks = extract_landmarks(img)
|
| 183 |
if landmarks is None:
|
| 184 |
+
_error("no face detected in image")
|
|
|
|
| 185 |
|
| 186 |
mesh = render_landmark_image(landmarks, 512, 512)
|
| 187 |
|
| 188 |
output_path = Path(args.output)
|
| 189 |
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 190 |
|
|
|
|
| 191 |
Image.fromarray(mesh).save(str(output_path))
|
| 192 |
print(f"saved landmark mesh to {output_path}")
|
| 193 |
print(f"detected {len(landmarks.landmarks)} landmarks, confidence {landmarks.confidence:.2f}")
|
| 194 |
|
| 195 |
|
| 196 |
+
def _run_demo() -> None:
|
| 197 |
try:
|
| 198 |
from scripts.app import build_app
|
| 199 |
+
|
| 200 |
demo = build_app()
|
| 201 |
demo.launch()
|
| 202 |
except ImportError:
|
| 203 |
+
_error("gradio not installed - run: pip install landmarkdiff[app]")
|
|
|
|
| 204 |
|
| 205 |
|
| 206 |
if __name__ == "__main__":
|
landmarkdiff/api_client.py
CHANGED
|
@@ -26,7 +26,6 @@ Usage:
|
|
| 26 |
from __future__ import annotations
|
| 27 |
|
| 28 |
import base64
|
| 29 |
-
import io
|
| 30 |
from dataclasses import dataclass, field
|
| 31 |
from pathlib import Path
|
| 32 |
from typing import Any
|
|
@@ -43,8 +42,8 @@ class PredictionResult:
|
|
| 43 |
procedure: str
|
| 44 |
intensity: float
|
| 45 |
confidence: float = 0.0
|
| 46 |
-
landmarks_before: list | None = None
|
| 47 |
-
landmarks_after: list | None = None
|
| 48 |
metrics: dict[str, float] = field(default_factory=dict)
|
| 49 |
metadata: dict[str, Any] = field(default_factory=dict)
|
| 50 |
|
|
@@ -70,15 +69,15 @@ class LandmarkDiffClient:
|
|
| 70 |
def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 60.0) -> None:
|
| 71 |
self.base_url = base_url.rstrip("/")
|
| 72 |
self.timeout = timeout
|
| 73 |
-
self._session = None
|
| 74 |
|
| 75 |
-
def _get_session(self):
|
| 76 |
"""Lazy-initialize requests session."""
|
| 77 |
if self._session is None:
|
| 78 |
try:
|
| 79 |
import requests
|
| 80 |
except ImportError:
|
| 81 |
-
raise ImportError("requests required. Install with: pip install requests")
|
| 82 |
self._session = requests.Session()
|
| 83 |
self._session.timeout = self.timeout
|
| 84 |
return self._session
|
|
@@ -218,12 +217,14 @@ class LandmarkDiffClient:
|
|
| 218 |
results.append(result)
|
| 219 |
except Exception as e:
|
| 220 |
# Create a failed result
|
| 221 |
-
results.append(
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
| 227 |
return results
|
| 228 |
|
| 229 |
def close(self) -> None:
|
|
@@ -232,10 +233,10 @@ class LandmarkDiffClient:
|
|
| 232 |
self._session.close()
|
| 233 |
self._session = None
|
| 234 |
|
| 235 |
-
def __enter__(self):
|
| 236 |
return self
|
| 237 |
|
| 238 |
-
def __exit__(self, *args):
|
| 239 |
self.close()
|
| 240 |
|
| 241 |
def __repr__(self) -> str:
|
|
|
|
| 26 |
from __future__ import annotations
|
| 27 |
|
| 28 |
import base64
|
|
|
|
| 29 |
from dataclasses import dataclass, field
|
| 30 |
from pathlib import Path
|
| 31 |
from typing import Any
|
|
|
|
| 42 |
procedure: str
|
| 43 |
intensity: float
|
| 44 |
confidence: float = 0.0
|
| 45 |
+
landmarks_before: list[Any] | None = None
|
| 46 |
+
landmarks_after: list[Any] | None = None
|
| 47 |
metrics: dict[str, float] = field(default_factory=dict)
|
| 48 |
metadata: dict[str, Any] = field(default_factory=dict)
|
| 49 |
|
|
|
|
| 69 |
def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 60.0) -> None:
|
| 70 |
self.base_url = base_url.rstrip("/")
|
| 71 |
self.timeout = timeout
|
| 72 |
+
self._session: Any = None
|
| 73 |
|
| 74 |
+
def _get_session(self) -> Any:
|
| 75 |
"""Lazy-initialize requests session."""
|
| 76 |
if self._session is None:
|
| 77 |
try:
|
| 78 |
import requests
|
| 79 |
except ImportError:
|
| 80 |
+
raise ImportError("requests required. Install with: pip install requests") from None
|
| 81 |
self._session = requests.Session()
|
| 82 |
self._session.timeout = self.timeout
|
| 83 |
return self._session
|
|
|
|
| 217 |
results.append(result)
|
| 218 |
except Exception as e:
|
| 219 |
# Create a failed result
|
| 220 |
+
results.append(
|
| 221 |
+
PredictionResult(
|
| 222 |
+
output_image=np.zeros((512, 512, 3), dtype=np.uint8),
|
| 223 |
+
procedure=procedure,
|
| 224 |
+
intensity=intensity,
|
| 225 |
+
metadata={"error": str(e), "path": str(path)},
|
| 226 |
+
)
|
| 227 |
+
)
|
| 228 |
return results
|
| 229 |
|
| 230 |
def close(self) -> None:
|
|
|
|
| 233 |
self._session.close()
|
| 234 |
self._session = None
|
| 235 |
|
| 236 |
+
def __enter__(self) -> LandmarkDiffClient:
|
| 237 |
return self
|
| 238 |
|
| 239 |
+
def __exit__(self, *args: Any) -> None:
|
| 240 |
self.close()
|
| 241 |
|
| 242 |
def __repr__(self) -> str:
|
landmarkdiff/arcface_torch.py
CHANGED
|
@@ -29,7 +29,6 @@ from __future__ import annotations
|
|
| 29 |
import logging
|
| 30 |
import warnings
|
| 31 |
from pathlib import Path
|
| 32 |
-
from typing import Optional
|
| 33 |
|
| 34 |
import torch
|
| 35 |
import torch.nn as nn
|
|
@@ -42,6 +41,7 @@ logger = logging.getLogger(__name__)
|
|
| 42 |
# Building blocks
|
| 43 |
# ---------------------------------------------------------------------------
|
| 44 |
|
|
|
|
| 45 |
class SEModule(nn.Module):
|
| 46 |
"""Squeeze-and-Excitation channel attention (Hu et al., 2018).
|
| 47 |
|
|
@@ -79,18 +79,28 @@ class IBasicBlock(nn.Module):
|
|
| 79 |
inplanes: int,
|
| 80 |
planes: int,
|
| 81 |
stride: int = 1,
|
| 82 |
-
downsample:
|
| 83 |
use_se: bool = True,
|
| 84 |
):
|
| 85 |
super().__init__()
|
| 86 |
self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-5)
|
| 87 |
self.conv1 = nn.Conv2d(
|
| 88 |
-
inplanes,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
)
|
| 90 |
self.bn2 = nn.BatchNorm2d(planes, eps=1e-5)
|
| 91 |
self.prelu = nn.PReLU(planes)
|
| 92 |
self.conv2 = nn.Conv2d(
|
| 93 |
-
planes,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
)
|
| 95 |
self.bn3 = nn.BatchNorm2d(planes, eps=1e-5)
|
| 96 |
|
|
@@ -120,6 +130,7 @@ class IBasicBlock(nn.Module):
|
|
| 120 |
# Backbone
|
| 121 |
# ---------------------------------------------------------------------------
|
| 122 |
|
|
|
|
| 123 |
class ArcFaceBackbone(nn.Module):
|
| 124 |
"""IResNet-50 backbone for ArcFace identity embeddings.
|
| 125 |
|
|
@@ -257,7 +268,7 @@ _WEIGHT_URL = (
|
|
| 257 |
)
|
| 258 |
|
| 259 |
|
| 260 |
-
def _find_pretrained_weights() ->
|
| 261 |
"""Search known locations for pretrained IResNet-50 weights."""
|
| 262 |
for p in _KNOWN_WEIGHT_PATHS:
|
| 263 |
if p.exists() and p.suffix == ".pth":
|
|
@@ -269,6 +280,7 @@ def _try_download_weights(dest: Path) -> bool:
|
|
| 269 |
"""Attempt to download pretrained weights from the InsightFace release."""
|
| 270 |
try:
|
| 271 |
import urllib.request
|
|
|
|
| 272 |
dest.parent.mkdir(parents=True, exist_ok=True)
|
| 273 |
logger.info("Downloading ArcFace IResNet-50 weights from %s ...", _WEIGHT_URL)
|
| 274 |
urllib.request.urlretrieve(_WEIGHT_URL, str(dest))
|
|
@@ -281,7 +293,7 @@ def _try_download_weights(dest: Path) -> bool:
|
|
| 281 |
|
| 282 |
def load_pretrained_weights(
|
| 283 |
model: ArcFaceBackbone,
|
| 284 |
-
weights_path:
|
| 285 |
download: bool = True,
|
| 286 |
) -> bool:
|
| 287 |
"""Load pretrained InsightFace IResNet-50 weights into the model.
|
|
@@ -300,7 +312,7 @@ def load_pretrained_weights(
|
|
| 300 |
``True`` if weights were loaded successfully, ``False`` otherwise
|
| 301 |
(model keeps random initialization).
|
| 302 |
"""
|
| 303 |
-
path:
|
| 304 |
|
| 305 |
if weights_path is not None:
|
| 306 |
path = Path(weights_path)
|
|
@@ -368,8 +380,7 @@ def load_pretrained_weights(
|
|
| 368 |
return True
|
| 369 |
except Exception as e:
|
| 370 |
warnings.warn(
|
| 371 |
-
f"Failed to load ArcFace weights from {path}: {e}. "
|
| 372 |
-
"Using random initialization.",
|
| 373 |
UserWarning,
|
| 374 |
stacklevel=2,
|
| 375 |
)
|
|
@@ -380,6 +391,7 @@ def load_pretrained_weights(
|
|
| 380 |
# Differentiable face alignment
|
| 381 |
# ---------------------------------------------------------------------------
|
| 382 |
|
|
|
|
| 383 |
def align_face(
|
| 384 |
images: torch.Tensor,
|
| 385 |
size: int = 112,
|
|
@@ -402,7 +414,7 @@ def align_face(
|
|
| 402 |
"""
|
| 403 |
B, C, H, W = images.shape
|
| 404 |
|
| 405 |
-
if
|
| 406 |
return images
|
| 407 |
|
| 408 |
# Crop fraction: keep central 80% to remove background padding
|
|
@@ -414,13 +426,17 @@ def align_face(
|
|
| 414 |
# grid_sample expects coordinates in [-1, 1] where -1 is top-left, +1 is bottom-right
|
| 415 |
# Center crop: map [-1, 1] output range to [-crop_frac, +crop_frac] input range
|
| 416 |
theta = torch.zeros(B, 2, 3, device=images.device, dtype=images.dtype)
|
| 417 |
-
theta[:, 0, 0] = half_crop
|
| 418 |
-
theta[:, 1, 1] = half_crop
|
| 419 |
# translation stays 0 (centered)
|
| 420 |
|
| 421 |
grid = F.affine_grid(theta, [B, C, size, size], align_corners=False)
|
| 422 |
aligned = F.grid_sample(
|
| 423 |
-
images,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
)
|
| 425 |
return aligned
|
| 426 |
|
|
@@ -444,7 +460,10 @@ def align_face_no_crop(
|
|
| 444 |
if images.shape[-2] == size and images.shape[-1] == size:
|
| 445 |
return images
|
| 446 |
return F.interpolate(
|
| 447 |
-
images,
|
|
|
|
|
|
|
|
|
|
| 448 |
)
|
| 449 |
|
| 450 |
|
|
@@ -452,6 +471,7 @@ def align_face_no_crop(
|
|
| 452 |
# ArcFaceLoss: differentiable identity preservation loss
|
| 453 |
# ---------------------------------------------------------------------------
|
| 454 |
|
|
|
|
| 455 |
class ArcFaceLoss(nn.Module):
|
| 456 |
"""Differentiable identity loss using PyTorch-native ArcFace.
|
| 457 |
|
|
@@ -474,8 +494,8 @@ class ArcFaceLoss(nn.Module):
|
|
| 474 |
|
| 475 |
def __init__(
|
| 476 |
self,
|
| 477 |
-
device:
|
| 478 |
-
weights_path:
|
| 479 |
crop_face: bool = True,
|
| 480 |
):
|
| 481 |
"""
|
|
@@ -527,10 +547,7 @@ class ArcFaceLoss(nn.Module):
|
|
| 527 |
Returns:
|
| 528 |
(B, 3, 112, 112) in [-1, 1].
|
| 529 |
"""
|
| 530 |
-
if self.crop_face
|
| 531 |
-
x = align_face(images, size=112)
|
| 532 |
-
else:
|
| 533 |
-
x = align_face_no_crop(images, size=112)
|
| 534 |
|
| 535 |
# Normalize from [0, 1] to [-1, 1]
|
| 536 |
x = x * 2.0 - 1.0
|
|
@@ -662,9 +679,10 @@ class ArcFaceLoss(nn.Module):
|
|
| 662 |
# Convenience: create a pre-configured loss instance
|
| 663 |
# ---------------------------------------------------------------------------
|
| 664 |
|
|
|
|
| 665 |
def create_arcface_loss(
|
| 666 |
-
device:
|
| 667 |
-
weights_path:
|
| 668 |
) -> ArcFaceLoss:
|
| 669 |
"""Factory function for creating an ArcFaceLoss with sensible defaults.
|
| 670 |
|
|
|
|
| 29 |
import logging
|
| 30 |
import warnings
|
| 31 |
from pathlib import Path
|
|
|
|
| 32 |
|
| 33 |
import torch
|
| 34 |
import torch.nn as nn
|
|
|
|
| 41 |
# Building blocks
|
| 42 |
# ---------------------------------------------------------------------------
|
| 43 |
|
| 44 |
+
|
| 45 |
class SEModule(nn.Module):
|
| 46 |
"""Squeeze-and-Excitation channel attention (Hu et al., 2018).
|
| 47 |
|
|
|
|
| 79 |
inplanes: int,
|
| 80 |
planes: int,
|
| 81 |
stride: int = 1,
|
| 82 |
+
downsample: nn.Module | None = None,
|
| 83 |
use_se: bool = True,
|
| 84 |
):
|
| 85 |
super().__init__()
|
| 86 |
self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-5)
|
| 87 |
self.conv1 = nn.Conv2d(
|
| 88 |
+
inplanes,
|
| 89 |
+
planes,
|
| 90 |
+
kernel_size=3,
|
| 91 |
+
stride=1,
|
| 92 |
+
padding=1,
|
| 93 |
+
bias=False,
|
| 94 |
)
|
| 95 |
self.bn2 = nn.BatchNorm2d(planes, eps=1e-5)
|
| 96 |
self.prelu = nn.PReLU(planes)
|
| 97 |
self.conv2 = nn.Conv2d(
|
| 98 |
+
planes,
|
| 99 |
+
planes,
|
| 100 |
+
kernel_size=3,
|
| 101 |
+
stride=stride,
|
| 102 |
+
padding=1,
|
| 103 |
+
bias=False,
|
| 104 |
)
|
| 105 |
self.bn3 = nn.BatchNorm2d(planes, eps=1e-5)
|
| 106 |
|
|
|
|
| 130 |
# Backbone
|
| 131 |
# ---------------------------------------------------------------------------
|
| 132 |
|
| 133 |
+
|
| 134 |
class ArcFaceBackbone(nn.Module):
|
| 135 |
"""IResNet-50 backbone for ArcFace identity embeddings.
|
| 136 |
|
|
|
|
| 268 |
)
|
| 269 |
|
| 270 |
|
| 271 |
+
def _find_pretrained_weights() -> Path | None:
|
| 272 |
"""Search known locations for pretrained IResNet-50 weights."""
|
| 273 |
for p in _KNOWN_WEIGHT_PATHS:
|
| 274 |
if p.exists() and p.suffix == ".pth":
|
|
|
|
| 280 |
"""Attempt to download pretrained weights from the InsightFace release."""
|
| 281 |
try:
|
| 282 |
import urllib.request
|
| 283 |
+
|
| 284 |
dest.parent.mkdir(parents=True, exist_ok=True)
|
| 285 |
logger.info("Downloading ArcFace IResNet-50 weights from %s ...", _WEIGHT_URL)
|
| 286 |
urllib.request.urlretrieve(_WEIGHT_URL, str(dest))
|
|
|
|
| 293 |
|
| 294 |
def load_pretrained_weights(
|
| 295 |
model: ArcFaceBackbone,
|
| 296 |
+
weights_path: str | None = None,
|
| 297 |
download: bool = True,
|
| 298 |
) -> bool:
|
| 299 |
"""Load pretrained InsightFace IResNet-50 weights into the model.
|
|
|
|
| 312 |
``True`` if weights were loaded successfully, ``False`` otherwise
|
| 313 |
(model keeps random initialization).
|
| 314 |
"""
|
| 315 |
+
path: Path | None = None
|
| 316 |
|
| 317 |
if weights_path is not None:
|
| 318 |
path = Path(weights_path)
|
|
|
|
| 380 |
return True
|
| 381 |
except Exception as e:
|
| 382 |
warnings.warn(
|
| 383 |
+
f"Failed to load ArcFace weights from {path}: {e}. Using random initialization.",
|
|
|
|
| 384 |
UserWarning,
|
| 385 |
stacklevel=2,
|
| 386 |
)
|
|
|
|
| 391 |
# Differentiable face alignment
|
| 392 |
# ---------------------------------------------------------------------------
|
| 393 |
|
| 394 |
+
|
| 395 |
def align_face(
|
| 396 |
images: torch.Tensor,
|
| 397 |
size: int = 112,
|
|
|
|
| 414 |
"""
|
| 415 |
B, C, H, W = images.shape
|
| 416 |
|
| 417 |
+
if size == H and size == W:
|
| 418 |
return images
|
| 419 |
|
| 420 |
# Crop fraction: keep central 80% to remove background padding
|
|
|
|
| 426 |
# grid_sample expects coordinates in [-1, 1] where -1 is top-left, +1 is bottom-right
|
| 427 |
# Center crop: map [-1, 1] output range to [-crop_frac, +crop_frac] input range
|
| 428 |
theta = torch.zeros(B, 2, 3, device=images.device, dtype=images.dtype)
|
| 429 |
+
theta[:, 0, 0] = half_crop # x scale
|
| 430 |
+
theta[:, 1, 1] = half_crop # y scale
|
| 431 |
# translation stays 0 (centered)
|
| 432 |
|
| 433 |
grid = F.affine_grid(theta, [B, C, size, size], align_corners=False)
|
| 434 |
aligned = F.grid_sample(
|
| 435 |
+
images,
|
| 436 |
+
grid,
|
| 437 |
+
mode="bilinear",
|
| 438 |
+
padding_mode="border",
|
| 439 |
+
align_corners=False,
|
| 440 |
)
|
| 441 |
return aligned
|
| 442 |
|
|
|
|
| 460 |
if images.shape[-2] == size and images.shape[-1] == size:
|
| 461 |
return images
|
| 462 |
return F.interpolate(
|
| 463 |
+
images,
|
| 464 |
+
size=(size, size),
|
| 465 |
+
mode="bilinear",
|
| 466 |
+
align_corners=False,
|
| 467 |
)
|
| 468 |
|
| 469 |
|
|
|
|
| 471 |
# ArcFaceLoss: differentiable identity preservation loss
|
| 472 |
# ---------------------------------------------------------------------------
|
| 473 |
|
| 474 |
+
|
| 475 |
class ArcFaceLoss(nn.Module):
|
| 476 |
"""Differentiable identity loss using PyTorch-native ArcFace.
|
| 477 |
|
|
|
|
| 494 |
|
| 495 |
def __init__(
|
| 496 |
self,
|
| 497 |
+
device: torch.device | None = None,
|
| 498 |
+
weights_path: str | None = None,
|
| 499 |
crop_face: bool = True,
|
| 500 |
):
|
| 501 |
"""
|
|
|
|
| 547 |
Returns:
|
| 548 |
(B, 3, 112, 112) in [-1, 1].
|
| 549 |
"""
|
| 550 |
+
x = align_face(images, size=112) if self.crop_face else align_face_no_crop(images, size=112)
|
|
|
|
|
|
|
|
|
|
| 551 |
|
| 552 |
# Normalize from [0, 1] to [-1, 1]
|
| 553 |
x = x * 2.0 - 1.0
|
|
|
|
| 679 |
# Convenience: create a pre-configured loss instance
|
| 680 |
# ---------------------------------------------------------------------------
|
| 681 |
|
| 682 |
+
|
| 683 |
def create_arcface_loss(
|
| 684 |
+
device: torch.device | None = None,
|
| 685 |
+
weights_path: str | None = None,
|
| 686 |
) -> ArcFaceLoss:
|
| 687 |
"""Factory function for creating an ArcFaceLoss with sensible defaults.
|
| 688 |
|
landmarkdiff/audit.py
CHANGED
|
@@ -116,12 +116,10 @@ class AuditReporter:
|
|
| 116 |
if case.identity_sim > 0:
|
| 117 |
by_proc[proc]["id_sims"].append(case.identity_sim)
|
| 118 |
|
| 119 |
-
for
|
| 120 |
stats["pass_rate"] = stats["passed"] / max(stats["total"], 1)
|
| 121 |
stats["mean_identity_sim"] = (
|
| 122 |
-
sum(stats["id_sims"]) / len(stats["id_sims"])
|
| 123 |
-
if stats["id_sims"]
|
| 124 |
-
else 0.0
|
| 125 |
)
|
| 126 |
del stats["id_sims"]
|
| 127 |
|
|
@@ -137,12 +135,10 @@ class AuditReporter:
|
|
| 137 |
if case.identity_sim > 0:
|
| 138 |
by_fitz[ft]["id_sims"].append(case.identity_sim)
|
| 139 |
|
| 140 |
-
for
|
| 141 |
stats["pass_rate"] = stats["passed"] / max(stats["total"], 1)
|
| 142 |
stats["mean_identity_sim"] = (
|
| 143 |
-
sum(stats["id_sims"]) / len(stats["id_sims"])
|
| 144 |
-
if stats["id_sims"]
|
| 145 |
-
else 0.0
|
| 146 |
)
|
| 147 |
del stats["id_sims"]
|
| 148 |
|
|
@@ -268,7 +264,7 @@ class AuditReporter:
|
|
| 268 |
f"<td>{c.procedure.title()}</td>"
|
| 269 |
f"<td>{c.fitzpatrick_type}</td>"
|
| 270 |
f"<td>{c.identity_sim:.4f}</td>"
|
| 271 |
-
f
|
| 272 |
f"<td>{issues}</td>"
|
| 273 |
f"</tr>\n"
|
| 274 |
)
|
|
|
|
| 116 |
if case.identity_sim > 0:
|
| 117 |
by_proc[proc]["id_sims"].append(case.identity_sim)
|
| 118 |
|
| 119 |
+
for _proc, stats in by_proc.items():
|
| 120 |
stats["pass_rate"] = stats["passed"] / max(stats["total"], 1)
|
| 121 |
stats["mean_identity_sim"] = (
|
| 122 |
+
sum(stats["id_sims"]) / len(stats["id_sims"]) if stats["id_sims"] else 0.0
|
|
|
|
|
|
|
| 123 |
)
|
| 124 |
del stats["id_sims"]
|
| 125 |
|
|
|
|
| 135 |
if case.identity_sim > 0:
|
| 136 |
by_fitz[ft]["id_sims"].append(case.identity_sim)
|
| 137 |
|
| 138 |
+
for _ft, stats in by_fitz.items():
|
| 139 |
stats["pass_rate"] = stats["passed"] / max(stats["total"], 1)
|
| 140 |
stats["mean_identity_sim"] = (
|
| 141 |
+
sum(stats["id_sims"]) / len(stats["id_sims"]) if stats["id_sims"] else 0.0
|
|
|
|
|
|
|
| 142 |
)
|
| 143 |
del stats["id_sims"]
|
| 144 |
|
|
|
|
| 264 |
f"<td>{c.procedure.title()}</td>"
|
| 265 |
f"<td>{c.fitzpatrick_type}</td>"
|
| 266 |
f"<td>{c.identity_sim:.4f}</td>"
|
| 267 |
+
f"<td>{'WARN' if c.safety_passed else 'FAIL'}</td>"
|
| 268 |
f"<td>{issues}</td>"
|
| 269 |
f"</tr>\n"
|
| 270 |
)
|
landmarkdiff/augmentation.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training data augmentation pipeline for LandmarkDiff.
|
| 2 |
+
|
| 3 |
+
Provides domain-specific augmentations that maintain landmark consistency:
|
| 4 |
+
- Geometric: flip, rotation, affine (landmarks co-transformed)
|
| 5 |
+
- Photometric: color jitter, brightness, contrast (applied to images only)
|
| 6 |
+
- Skin-tone augmentation: ITA-space perturbation for Fitzpatrick balance
|
| 7 |
+
- Conditioning augmentation: noise injection, dropout for robustness
|
| 8 |
+
|
| 9 |
+
All augmentations preserve the correspondence between:
|
| 10 |
+
input_image ↔ conditioning_image ↔ target_image ↔ mask
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
|
| 17 |
+
import cv2
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class AugmentationConfig:
|
| 23 |
+
"""Augmentation parameters."""
|
| 24 |
+
|
| 25 |
+
# Geometric
|
| 26 |
+
random_flip: bool = True
|
| 27 |
+
random_rotation_deg: float = 5.0
|
| 28 |
+
random_scale: tuple[float, float] = (0.95, 1.05)
|
| 29 |
+
random_translate: float = 0.02 # fraction of image size
|
| 30 |
+
|
| 31 |
+
# Photometric (images only, not conditioning)
|
| 32 |
+
brightness_range: tuple[float, float] = (0.9, 1.1)
|
| 33 |
+
contrast_range: tuple[float, float] = (0.9, 1.1)
|
| 34 |
+
saturation_range: tuple[float, float] = (0.9, 1.1)
|
| 35 |
+
hue_shift_range: float = 5.0 # degrees
|
| 36 |
+
|
| 37 |
+
# Conditioning augmentation
|
| 38 |
+
conditioning_dropout_prob: float = 0.1
|
| 39 |
+
conditioning_noise_std: float = 0.02
|
| 40 |
+
|
| 41 |
+
# Skin-tone augmentation
|
| 42 |
+
ita_perturbation_std: float = 3.0 # ITA angle noise
|
| 43 |
+
|
| 44 |
+
seed: int | None = None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def augment_training_sample(
|
| 48 |
+
input_image: np.ndarray,
|
| 49 |
+
target_image: np.ndarray,
|
| 50 |
+
conditioning: np.ndarray,
|
| 51 |
+
mask: np.ndarray,
|
| 52 |
+
landmarks_src: np.ndarray | None = None,
|
| 53 |
+
landmarks_dst: np.ndarray | None = None,
|
| 54 |
+
config: AugmentationConfig | None = None,
|
| 55 |
+
rng: np.random.Generator | None = None,
|
| 56 |
+
) -> dict[str, np.ndarray]:
|
| 57 |
+
"""Apply consistent augmentations to a training sample.
|
| 58 |
+
|
| 59 |
+
All spatial transforms are applied to images AND landmarks together
|
| 60 |
+
so correspondence is preserved.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
input_image: (H, W, 3) original face image (uint8 BGR).
|
| 64 |
+
target_image: (H, W, 3) target face image (uint8 BGR).
|
| 65 |
+
conditioning: (H, W, 3) conditioning image (uint8).
|
| 66 |
+
mask: (H, W) or (H, W, 1) float32 mask.
|
| 67 |
+
landmarks_src: (N, 2) normalized [0,1] source landmark coords.
|
| 68 |
+
landmarks_dst: (N, 2) normalized [0,1] target landmark coords.
|
| 69 |
+
config: Augmentation parameters.
|
| 70 |
+
rng: Random generator for reproducibility.
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
Dict with augmented versions of all inputs.
|
| 74 |
+
"""
|
| 75 |
+
if config is None:
|
| 76 |
+
config = AugmentationConfig()
|
| 77 |
+
if rng is None:
|
| 78 |
+
rng = np.random.default_rng(config.seed)
|
| 79 |
+
|
| 80 |
+
h, w = input_image.shape[:2]
|
| 81 |
+
out_input = input_image.copy()
|
| 82 |
+
out_target = target_image.copy()
|
| 83 |
+
out_cond = conditioning.copy()
|
| 84 |
+
out_mask = mask.copy()
|
| 85 |
+
out_lm_src = landmarks_src.copy() if landmarks_src is not None else None
|
| 86 |
+
out_lm_dst = landmarks_dst.copy() if landmarks_dst is not None else None
|
| 87 |
+
|
| 88 |
+
# --- Geometric augmentations (applied to all) ---
|
| 89 |
+
|
| 90 |
+
# Random horizontal flip
|
| 91 |
+
if config.random_flip and rng.random() < 0.5:
|
| 92 |
+
out_input = np.ascontiguousarray(out_input[:, ::-1])
|
| 93 |
+
out_target = np.ascontiguousarray(out_target[:, ::-1])
|
| 94 |
+
out_cond = np.ascontiguousarray(out_cond[:, ::-1])
|
| 95 |
+
out_mask = np.ascontiguousarray(
|
| 96 |
+
out_mask[:, ::-1] if out_mask.ndim == 2 else out_mask[:, ::-1, :]
|
| 97 |
+
)
|
| 98 |
+
if out_lm_src is not None:
|
| 99 |
+
out_lm_src[:, 0] = 1.0 - out_lm_src[:, 0]
|
| 100 |
+
if out_lm_dst is not None:
|
| 101 |
+
out_lm_dst[:, 0] = 1.0 - out_lm_dst[:, 0]
|
| 102 |
+
|
| 103 |
+
# Random rotation + scale + translate
|
| 104 |
+
if config.random_rotation_deg > 0 or config.random_scale != (1.0, 1.0):
|
| 105 |
+
angle = rng.uniform(-config.random_rotation_deg, config.random_rotation_deg)
|
| 106 |
+
scale = rng.uniform(config.random_scale[0], config.random_scale[1])
|
| 107 |
+
tx = rng.uniform(-config.random_translate, config.random_translate) * w
|
| 108 |
+
ty = rng.uniform(-config.random_translate, config.random_translate) * h
|
| 109 |
+
|
| 110 |
+
center = (w / 2, h / 2)
|
| 111 |
+
M = cv2.getRotationMatrix2D(center, angle, scale)
|
| 112 |
+
M[0, 2] += tx
|
| 113 |
+
M[1, 2] += ty
|
| 114 |
+
|
| 115 |
+
out_input = cv2.warpAffine(out_input, M, (w, h), borderMode=cv2.BORDER_REFLECT_101)
|
| 116 |
+
out_target = cv2.warpAffine(out_target, M, (w, h), borderMode=cv2.BORDER_REFLECT_101)
|
| 117 |
+
out_cond = cv2.warpAffine(
|
| 118 |
+
out_cond, M, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=0
|
| 119 |
+
)
|
| 120 |
+
mask_2d = out_mask if out_mask.ndim == 2 else out_mask[:, :, 0]
|
| 121 |
+
mask_2d = cv2.warpAffine(mask_2d, M, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=0)
|
| 122 |
+
out_mask = mask_2d if out_mask.ndim == 2 else mask_2d[:, :, np.newaxis]
|
| 123 |
+
|
| 124 |
+
# Transform landmarks
|
| 125 |
+
if out_lm_src is not None:
|
| 126 |
+
out_lm_src = _transform_landmarks(out_lm_src, M, w, h)
|
| 127 |
+
if out_lm_dst is not None:
|
| 128 |
+
out_lm_dst = _transform_landmarks(out_lm_dst, M, w, h)
|
| 129 |
+
|
| 130 |
+
# --- Photometric augmentations (images only, not conditioning/mask) ---
|
| 131 |
+
|
| 132 |
+
# Brightness
|
| 133 |
+
b_factor = rng.uniform(config.brightness_range[0], config.brightness_range[1])
|
| 134 |
+
out_input = np.clip(out_input.astype(np.float32) * b_factor, 0, 255).astype(np.uint8)
|
| 135 |
+
out_target = np.clip(out_target.astype(np.float32) * b_factor, 0, 255).astype(np.uint8)
|
| 136 |
+
|
| 137 |
+
# Contrast
|
| 138 |
+
c_factor = rng.uniform(config.contrast_range[0], config.contrast_range[1])
|
| 139 |
+
mean_in = out_input.mean()
|
| 140 |
+
mean_tgt = out_target.mean()
|
| 141 |
+
out_input = np.clip(
|
| 142 |
+
(out_input.astype(np.float32) - mean_in) * c_factor + mean_in, 0, 255
|
| 143 |
+
).astype(np.uint8)
|
| 144 |
+
out_target = np.clip(
|
| 145 |
+
(out_target.astype(np.float32) - mean_tgt) * c_factor + mean_tgt, 0, 255
|
| 146 |
+
).astype(np.uint8)
|
| 147 |
+
|
| 148 |
+
# Saturation (in HSV space)
|
| 149 |
+
s_factor = rng.uniform(config.saturation_range[0], config.saturation_range[1])
|
| 150 |
+
if abs(s_factor - 1.0) > 1e-4:
|
| 151 |
+
out_input = _adjust_saturation(out_input, s_factor)
|
| 152 |
+
out_target = _adjust_saturation(out_target, s_factor)
|
| 153 |
+
|
| 154 |
+
# Hue shift
|
| 155 |
+
if config.hue_shift_range > 0:
|
| 156 |
+
hue_delta = rng.uniform(-config.hue_shift_range, config.hue_shift_range)
|
| 157 |
+
if abs(hue_delta) > 0.1:
|
| 158 |
+
out_input = _shift_hue(out_input, hue_delta)
|
| 159 |
+
out_target = _shift_hue(out_target, hue_delta)
|
| 160 |
+
|
| 161 |
+
# --- Conditioning augmentation ---
|
| 162 |
+
|
| 163 |
+
# Conditioning dropout (replace with zeros to learn unconditional)
|
| 164 |
+
if config.conditioning_dropout_prob > 0 and rng.random() < config.conditioning_dropout_prob:
|
| 165 |
+
out_cond = np.zeros_like(out_cond)
|
| 166 |
+
|
| 167 |
+
# Conditioning noise
|
| 168 |
+
if config.conditioning_noise_std > 0:
|
| 169 |
+
noise = rng.normal(0, config.conditioning_noise_std * 255, out_cond.shape)
|
| 170 |
+
out_cond = np.clip(out_cond.astype(np.float32) + noise, 0, 255).astype(np.uint8)
|
| 171 |
+
|
| 172 |
+
result = {
|
| 173 |
+
"input_image": out_input,
|
| 174 |
+
"target_image": out_target,
|
| 175 |
+
"conditioning": out_cond,
|
| 176 |
+
"mask": out_mask,
|
| 177 |
+
}
|
| 178 |
+
if out_lm_src is not None:
|
| 179 |
+
result["landmarks_src"] = out_lm_src
|
| 180 |
+
if out_lm_dst is not None:
|
| 181 |
+
result["landmarks_dst"] = out_lm_dst
|
| 182 |
+
|
| 183 |
+
return result
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def _transform_landmarks(landmarks: np.ndarray, M: np.ndarray, w: int, h: int) -> np.ndarray:
|
| 187 |
+
"""Transform normalized landmarks with an affine matrix."""
|
| 188 |
+
# Convert to pixel coords
|
| 189 |
+
px = landmarks.copy()
|
| 190 |
+
px[:, 0] *= w
|
| 191 |
+
px[:, 1] *= h
|
| 192 |
+
|
| 193 |
+
# Apply affine transform
|
| 194 |
+
ones = np.ones((px.shape[0], 1))
|
| 195 |
+
px_h = np.hstack([px, ones]) # (N, 3)
|
| 196 |
+
transformed = (M @ px_h.T).T # (N, 2)
|
| 197 |
+
|
| 198 |
+
# Back to normalized
|
| 199 |
+
transformed[:, 0] /= w
|
| 200 |
+
transformed[:, 1] /= h
|
| 201 |
+
return np.clip(transformed, 0.0, 1.0)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def _adjust_saturation(img: np.ndarray, factor: float) -> np.ndarray:
|
| 205 |
+
"""Adjust saturation of a BGR image."""
|
| 206 |
+
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
|
| 207 |
+
hsv[:, :, 1] = np.clip(hsv[:, :, 1] * factor, 0, 255)
|
| 208 |
+
return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def _shift_hue(img: np.ndarray, delta_deg: float) -> np.ndarray:
|
| 212 |
+
"""Shift hue of a BGR image by delta degrees."""
|
| 213 |
+
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
|
| 214 |
+
# OpenCV hue range is [0, 180]
|
| 215 |
+
hsv[:, :, 0] = (hsv[:, :, 0] + delta_deg / 2) % 180
|
| 216 |
+
return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def augment_skin_tone(
|
| 220 |
+
image: np.ndarray,
|
| 221 |
+
ita_delta: float = 0.0,
|
| 222 |
+
) -> np.ndarray:
|
| 223 |
+
"""Augment skin tone by shifting in L*a*b* space.
|
| 224 |
+
|
| 225 |
+
This helps balance Fitzpatrick representation in training by
|
| 226 |
+
simulating different skin tones from existing samples.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
image: (H, W, 3) BGR uint8 image.
|
| 230 |
+
ita_delta: ITA angle shift (positive = lighter, negative = darker).
|
| 231 |
+
|
| 232 |
+
Returns:
|
| 233 |
+
Augmented image with shifted skin tone.
|
| 234 |
+
"""
|
| 235 |
+
lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)
|
| 236 |
+
|
| 237 |
+
# Shift L channel (lightness) based on ITA delta
|
| 238 |
+
# ITA = arctan((L-50)/b), so shifting ITA shifts L
|
| 239 |
+
l_shift = ita_delta * 0.5 # approximate mapping
|
| 240 |
+
lab[:, :, 0] = np.clip(lab[:, :, 0] + l_shift, 0, 255)
|
| 241 |
+
|
| 242 |
+
# Slightly shift b channel too for more natural tone changes
|
| 243 |
+
b_shift = -ita_delta * 0.15
|
| 244 |
+
lab[:, :, 2] = np.clip(lab[:, :, 2] + b_shift, 0, 255)
|
| 245 |
+
|
| 246 |
+
return cv2.cvtColor(lab.astype(np.uint8), cv2.COLOR_LAB2BGR)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class FitzpatrickBalancer:
|
| 250 |
+
"""Oversample underrepresented Fitzpatrick types during training.
|
| 251 |
+
|
| 252 |
+
Maintains per-type counts and generates sampling weights to ensure
|
| 253 |
+
equitable training across all skin types.
|
| 254 |
+
"""
|
| 255 |
+
|
| 256 |
+
def __init__(self, target_distribution: dict[str, float] | None = None):
|
| 257 |
+
"""Initialize balancer.
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
target_distribution: Target fraction per type. Defaults to uniform.
|
| 261 |
+
"""
|
| 262 |
+
self.target = target_distribution or {
|
| 263 |
+
"I": 1 / 6,
|
| 264 |
+
"II": 1 / 6,
|
| 265 |
+
"III": 1 / 6,
|
| 266 |
+
"IV": 1 / 6,
|
| 267 |
+
"V": 1 / 6,
|
| 268 |
+
"VI": 1 / 6,
|
| 269 |
+
}
|
| 270 |
+
self._counts: dict[str, int] = {}
|
| 271 |
+
|
| 272 |
+
def register_sample(self, fitz_type: str) -> None:
|
| 273 |
+
"""Register a sample's Fitzpatrick type."""
|
| 274 |
+
self._counts[fitz_type] = self._counts.get(fitz_type, 0) + 1
|
| 275 |
+
|
| 276 |
+
def get_sampling_weights(self, fitz_types: list[str]) -> np.ndarray:
|
| 277 |
+
"""Compute sampling weights for a list of samples.
|
| 278 |
+
|
| 279 |
+
Returns weights inversely proportional to type frequency,
|
| 280 |
+
so underrepresented types get upsampled.
|
| 281 |
+
"""
|
| 282 |
+
total = sum(self._counts.values()) or 1
|
| 283 |
+
weights = []
|
| 284 |
+
for ft in fitz_types:
|
| 285 |
+
count = self._counts.get(ft, 1)
|
| 286 |
+
freq = count / total
|
| 287 |
+
target_freq = self.target.get(ft, 1 / 6)
|
| 288 |
+
# Weight = target / actual (capped for stability)
|
| 289 |
+
w = min(target_freq / max(freq, 1e-6), 5.0)
|
| 290 |
+
weights.append(w)
|
| 291 |
+
|
| 292 |
+
w = np.array(weights, dtype=np.float64)
|
| 293 |
+
return w / w.sum() # normalize to probability distribution
|
landmarkdiff/benchmark.py
CHANGED
|
@@ -63,17 +63,19 @@ class InferenceBenchmark:
|
|
| 63 |
if throughput_fps == 0.0 and latency_ms > 0:
|
| 64 |
throughput_fps = 1000.0 / latency_ms * batch_size
|
| 65 |
|
| 66 |
-
self.results.append(
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
| 77 |
|
| 78 |
def mean_latency(self, config_name: str | None = None) -> float:
|
| 79 |
"""Mean latency in ms, optionally filtered by config."""
|
|
@@ -124,7 +126,10 @@ class InferenceBenchmark:
|
|
| 124 |
if not configs:
|
| 125 |
return "No benchmark results."
|
| 126 |
|
| 127 |
-
header =
|
|
|
|
|
|
|
|
|
|
| 128 |
lines = [
|
| 129 |
f"Inference Benchmark: {self.model_name}",
|
| 130 |
header,
|
|
|
|
| 63 |
if throughput_fps == 0.0 and latency_ms > 0:
|
| 64 |
throughput_fps = 1000.0 / latency_ms * batch_size
|
| 65 |
|
| 66 |
+
self.results.append(
|
| 67 |
+
BenchmarkResult(
|
| 68 |
+
config_name=config_name,
|
| 69 |
+
latency_ms=latency_ms,
|
| 70 |
+
throughput_fps=throughput_fps,
|
| 71 |
+
vram_gb=vram_gb,
|
| 72 |
+
batch_size=batch_size,
|
| 73 |
+
resolution=resolution,
|
| 74 |
+
num_inference_steps=num_inference_steps,
|
| 75 |
+
device=device,
|
| 76 |
+
metadata=metadata,
|
| 77 |
+
)
|
| 78 |
+
)
|
| 79 |
|
| 80 |
def mean_latency(self, config_name: str | None = None) -> float:
|
| 81 |
"""Mean latency in ms, optionally filtered by config."""
|
|
|
|
| 126 |
if not configs:
|
| 127 |
return "No benchmark results."
|
| 128 |
|
| 129 |
+
header = (
|
| 130 |
+
f"{'Config':>20s} | {'Mean(ms)':>10s} | {'P99(ms)':>10s}"
|
| 131 |
+
f" | {'FPS':>8s} | {'VRAM(GB)':>8s} | {'N':>4s}"
|
| 132 |
+
)
|
| 133 |
lines = [
|
| 134 |
f"Inference Benchmark: {self.model_name}",
|
| 135 |
header,
|
landmarkdiff/checkpoint_manager.py
CHANGED
|
@@ -166,9 +166,7 @@ class CheckpointManager:
|
|
| 166 |
torch.save(state, ckpt_dir / "training_state.pt")
|
| 167 |
|
| 168 |
# Compute checkpoint size
|
| 169 |
-
size_mb = sum(
|
| 170 |
-
f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()
|
| 171 |
-
) / (1024 * 1024)
|
| 172 |
|
| 173 |
# Create metadata
|
| 174 |
meta = CheckpointMetadata(
|
|
@@ -216,7 +214,7 @@ class CheckpointManager:
|
|
| 216 |
entries.sort(key=lambda x: x[1], reverse=not self.lower_is_better)
|
| 217 |
|
| 218 |
# Mark best
|
| 219 |
-
best_names = {e[0] for e in entries[:self.keep_best]}
|
| 220 |
for name, meta in self._index["checkpoints"].items():
|
| 221 |
meta["is_best"] = name in best_names
|
| 222 |
|
|
@@ -245,11 +243,11 @@ class CheckpointManager:
|
|
| 245 |
val = meta.get("metrics", {}).get(self.metric)
|
| 246 |
if val is None:
|
| 247 |
continue
|
| 248 |
-
if
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
best, best_val = name, val
|
| 254 |
return best
|
| 255 |
|
|
@@ -280,7 +278,7 @@ class CheckpointManager:
|
|
| 280 |
keep = set()
|
| 281 |
|
| 282 |
# Keep latest
|
| 283 |
-
for name in all_names[-self.keep_latest:]:
|
| 284 |
keep.add(name)
|
| 285 |
|
| 286 |
# Keep best
|
|
@@ -323,10 +321,7 @@ class CheckpointManager:
|
|
| 323 |
|
| 324 |
def total_size_mb(self) -> float:
|
| 325 |
"""Return total disk size of all tracked checkpoints."""
|
| 326 |
-
return sum(
|
| 327 |
-
meta.get("size_mb", 0.0)
|
| 328 |
-
for meta in self._index["checkpoints"].values()
|
| 329 |
-
)
|
| 330 |
|
| 331 |
def summary(self) -> str:
|
| 332 |
"""Return a human-readable summary of checkpoint state."""
|
|
@@ -351,6 +346,7 @@ class CheckpointManager:
|
|
| 351 |
# Helpers
|
| 352 |
# ------------------------------------------------------------------
|
| 353 |
|
|
|
|
| 354 |
def _get_state_dict(module: torch.nn.Module) -> dict:
|
| 355 |
"""Extract state dict, handling DDP wrapper."""
|
| 356 |
if hasattr(module, "module"):
|
|
|
|
| 166 |
torch.save(state, ckpt_dir / "training_state.pt")
|
| 167 |
|
| 168 |
# Compute checkpoint size
|
| 169 |
+
size_mb = sum(f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()) / (1024 * 1024)
|
|
|
|
|
|
|
| 170 |
|
| 171 |
# Create metadata
|
| 172 |
meta = CheckpointMetadata(
|
|
|
|
| 214 |
entries.sort(key=lambda x: x[1], reverse=not self.lower_is_better)
|
| 215 |
|
| 216 |
# Mark best
|
| 217 |
+
best_names = {e[0] for e in entries[: self.keep_best]}
|
| 218 |
for name, meta in self._index["checkpoints"].items():
|
| 219 |
meta["is_best"] = name in best_names
|
| 220 |
|
|
|
|
| 243 |
val = meta.get("metrics", {}).get(self.metric)
|
| 244 |
if val is None:
|
| 245 |
continue
|
| 246 |
+
if (
|
| 247 |
+
best_val is None
|
| 248 |
+
or (self.lower_is_better and val < best_val)
|
| 249 |
+
or (not self.lower_is_better and val > best_val)
|
| 250 |
+
):
|
| 251 |
best, best_val = name, val
|
| 252 |
return best
|
| 253 |
|
|
|
|
| 278 |
keep = set()
|
| 279 |
|
| 280 |
# Keep latest
|
| 281 |
+
for name in all_names[-self.keep_latest :]:
|
| 282 |
keep.add(name)
|
| 283 |
|
| 284 |
# Keep best
|
|
|
|
| 321 |
|
| 322 |
def total_size_mb(self) -> float:
|
| 323 |
"""Return total disk size of all tracked checkpoints."""
|
| 324 |
+
return sum(meta.get("size_mb", 0.0) for meta in self._index["checkpoints"].values())
|
|
|
|
|
|
|
|
|
|
| 325 |
|
| 326 |
def summary(self) -> str:
|
| 327 |
"""Return a human-readable summary of checkpoint state."""
|
|
|
|
| 346 |
# Helpers
|
| 347 |
# ------------------------------------------------------------------
|
| 348 |
|
| 349 |
+
|
| 350 |
def _get_state_dict(module: torch.nn.Module) -> dict:
|
| 351 |
"""Extract state dict, handling DDP wrapper."""
|
| 352 |
if hasattr(module, "module"):
|
landmarkdiff/cli.py
CHANGED
|
@@ -17,10 +17,10 @@ import sys
|
|
| 17 |
|
| 18 |
def cmd_infer(args: argparse.Namespace) -> None:
|
| 19 |
"""Run single-image inference."""
|
| 20 |
-
import cv2
|
| 21 |
-
import numpy as np
|
| 22 |
from pathlib import Path
|
| 23 |
|
|
|
|
|
|
|
| 24 |
from landmarkdiff.inference import LandmarkDiffPipeline
|
| 25 |
|
| 26 |
image = cv2.imread(args.image)
|
|
@@ -51,6 +51,7 @@ def cmd_infer(args: argparse.Namespace) -> None:
|
|
| 51 |
|
| 52 |
if args.watermark:
|
| 53 |
from landmarkdiff.safety import SafetyValidator
|
|
|
|
| 54 |
validator = SafetyValidator()
|
| 55 |
watermarked = validator.apply_watermark(result["output"])
|
| 56 |
wm_path = out_path.with_stem(out_path.stem + "_watermarked")
|
|
@@ -87,9 +88,7 @@ def cmd_evaluate(args: argparse.Namespace) -> None:
|
|
| 87 |
run_evaluation(
|
| 88 |
test_dir=args.test_dir,
|
| 89 |
output_dir=args.output,
|
| 90 |
-
mode=args.mode,
|
| 91 |
checkpoint=args.checkpoint,
|
| 92 |
-
displacement_model=args.displacement_model,
|
| 93 |
max_samples=args.max_samples,
|
| 94 |
)
|
| 95 |
|
|
@@ -98,10 +97,7 @@ def cmd_config(args: argparse.Namespace) -> None:
|
|
| 98 |
"""Show or validate configuration."""
|
| 99 |
from landmarkdiff.config import ExperimentConfig, load_config, validate_config
|
| 100 |
|
| 101 |
-
if args.file
|
| 102 |
-
config = load_config(args.file)
|
| 103 |
-
else:
|
| 104 |
-
config = ExperimentConfig()
|
| 105 |
|
| 106 |
if args.validate:
|
| 107 |
warnings = validate_config(config)
|
|
@@ -112,14 +108,17 @@ def cmd_config(args: argparse.Namespace) -> None:
|
|
| 112 |
else:
|
| 113 |
print("Configuration valid (no warnings).")
|
| 114 |
else:
|
| 115 |
-
import yaml
|
| 116 |
from dataclasses import asdict
|
|
|
|
|
|
|
|
|
|
| 117 |
print(yaml.dump(asdict(config), default_flow_style=False, sort_keys=False))
|
| 118 |
|
| 119 |
|
| 120 |
def cmd_validate(args: argparse.Namespace) -> None:
|
| 121 |
"""Run safety validation on an output image."""
|
| 122 |
import cv2
|
|
|
|
| 123 |
from landmarkdiff.safety import SafetyValidator
|
| 124 |
|
| 125 |
input_img = cv2.imread(args.input)
|
|
@@ -148,6 +147,7 @@ def cmd_validate(args: argparse.Namespace) -> None:
|
|
| 148 |
def cmd_version(args: argparse.Namespace) -> None:
|
| 149 |
"""Print version info."""
|
| 150 |
from landmarkdiff import __version__
|
|
|
|
| 151 |
print(f"LandmarkDiff v{__version__}")
|
| 152 |
|
| 153 |
|
|
@@ -162,8 +162,11 @@ def main(argv: list[str] | None = None) -> None:
|
|
| 162 |
# --- infer ---
|
| 163 |
p_infer = subparsers.add_parser("infer", help="Run single-image inference")
|
| 164 |
p_infer.add_argument("image", help="Input face image path")
|
| 165 |
-
p_infer.add_argument(
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
| 167 |
p_infer.add_argument("--intensity", type=float, default=65.0)
|
| 168 |
p_infer.add_argument("--output", default="output.png")
|
| 169 |
p_infer.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
|
|
@@ -180,8 +183,11 @@ def main(argv: list[str] | None = None) -> None:
|
|
| 180 |
p_ensemble.add_argument("--intensity", type=float, default=65.0)
|
| 181 |
p_ensemble.add_argument("--output", default="ensemble_output")
|
| 182 |
p_ensemble.add_argument("--n-samples", type=int, default=5)
|
| 183 |
-
p_ensemble.add_argument(
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
| 185 |
p_ensemble.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
|
| 186 |
p_ensemble.add_argument("--checkpoint", default=None)
|
| 187 |
p_ensemble.add_argument("--displacement-model", default=None)
|
|
|
|
| 17 |
|
| 18 |
def cmd_infer(args: argparse.Namespace) -> None:
|
| 19 |
"""Run single-image inference."""
|
|
|
|
|
|
|
| 20 |
from pathlib import Path
|
| 21 |
|
| 22 |
+
import cv2
|
| 23 |
+
|
| 24 |
from landmarkdiff.inference import LandmarkDiffPipeline
|
| 25 |
|
| 26 |
image = cv2.imread(args.image)
|
|
|
|
| 51 |
|
| 52 |
if args.watermark:
|
| 53 |
from landmarkdiff.safety import SafetyValidator
|
| 54 |
+
|
| 55 |
validator = SafetyValidator()
|
| 56 |
watermarked = validator.apply_watermark(result["output"])
|
| 57 |
wm_path = out_path.with_stem(out_path.stem + "_watermarked")
|
|
|
|
| 88 |
run_evaluation(
|
| 89 |
test_dir=args.test_dir,
|
| 90 |
output_dir=args.output,
|
|
|
|
| 91 |
checkpoint=args.checkpoint,
|
|
|
|
| 92 |
max_samples=args.max_samples,
|
| 93 |
)
|
| 94 |
|
|
|
|
| 97 |
"""Show or validate configuration."""
|
| 98 |
from landmarkdiff.config import ExperimentConfig, load_config, validate_config
|
| 99 |
|
| 100 |
+
config = load_config(args.file) if args.file else ExperimentConfig()
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
if args.validate:
|
| 103 |
warnings = validate_config(config)
|
|
|
|
| 108 |
else:
|
| 109 |
print("Configuration valid (no warnings).")
|
| 110 |
else:
|
|
|
|
| 111 |
from dataclasses import asdict
|
| 112 |
+
|
| 113 |
+
import yaml
|
| 114 |
+
|
| 115 |
print(yaml.dump(asdict(config), default_flow_style=False, sort_keys=False))
|
| 116 |
|
| 117 |
|
| 118 |
def cmd_validate(args: argparse.Namespace) -> None:
|
| 119 |
"""Run safety validation on an output image."""
|
| 120 |
import cv2
|
| 121 |
+
|
| 122 |
from landmarkdiff.safety import SafetyValidator
|
| 123 |
|
| 124 |
input_img = cv2.imread(args.input)
|
|
|
|
| 147 |
def cmd_version(args: argparse.Namespace) -> None:
|
| 148 |
"""Print version info."""
|
| 149 |
from landmarkdiff import __version__
|
| 150 |
+
|
| 151 |
print(f"LandmarkDiff v{__version__}")
|
| 152 |
|
| 153 |
|
|
|
|
| 162 |
# --- infer ---
|
| 163 |
p_infer = subparsers.add_parser("infer", help="Run single-image inference")
|
| 164 |
p_infer.add_argument("image", help="Input face image path")
|
| 165 |
+
p_infer.add_argument(
|
| 166 |
+
"--procedure",
|
| 167 |
+
default="rhinoplasty",
|
| 168 |
+
choices=["rhinoplasty", "blepharoplasty", "rhytidectomy", "orthognathic"],
|
| 169 |
+
)
|
| 170 |
p_infer.add_argument("--intensity", type=float, default=65.0)
|
| 171 |
p_infer.add_argument("--output", default="output.png")
|
| 172 |
p_infer.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
|
|
|
|
| 183 |
p_ensemble.add_argument("--intensity", type=float, default=65.0)
|
| 184 |
p_ensemble.add_argument("--output", default="ensemble_output")
|
| 185 |
p_ensemble.add_argument("--n-samples", type=int, default=5)
|
| 186 |
+
p_ensemble.add_argument(
|
| 187 |
+
"--strategy",
|
| 188 |
+
default="best_of_n",
|
| 189 |
+
choices=["pixel_average", "weighted_average", "best_of_n", "median"],
|
| 190 |
+
)
|
| 191 |
p_ensemble.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
|
| 192 |
p_ensemble.add_argument("--checkpoint", default=None)
|
| 193 |
p_ensemble.add_argument("--displacement-model", default=None)
|
landmarkdiff/clinical.py
CHANGED
|
@@ -80,9 +80,9 @@ def detect_vitiligo_patches(
|
|
| 80 |
# Also check for low saturation (a,b channels close to 128)
|
| 81 |
a_channel = lab[:, :, 1]
|
| 82 |
b_channel = lab[:, :, 2]
|
| 83 |
-
low_sat = (
|
| 84 |
-
|
| 85 |
-
)
|
| 86 |
|
| 87 |
# Combined: bright AND low-saturation within face
|
| 88 |
vitiligo_raw = cv2.bitwise_and(bright_mask, low_sat)
|
|
|
|
| 80 |
# Also check for low saturation (a,b channels close to 128)
|
| 81 |
a_channel = lab[:, :, 1]
|
| 82 |
b_channel = lab[:, :, 2]
|
| 83 |
+
low_sat = ((np.abs(a_channel - 128) < 15) & (np.abs(b_channel - 128) < 15)).astype(
|
| 84 |
+
np.uint8
|
| 85 |
+
) * 255
|
| 86 |
|
| 87 |
# Combined: bright AND low-saturation within face
|
| 88 |
vitiligo_raw = cv2.bitwise_and(bright_mask, low_sat)
|
landmarkdiff/conditioning.py
CHANGED
|
@@ -18,17 +18,83 @@ from landmarkdiff.landmarks import FaceLandmarks
|
|
| 18 |
# This is invariant to landmark displacement (unlike Delaunay).
|
| 19 |
|
| 20 |
JAWLINE_CONTOUR = [
|
| 21 |
-
10,
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
]
|
| 25 |
|
| 26 |
LEFT_EYE_CONTOUR = [
|
| 27 |
-
33,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
]
|
| 29 |
|
| 30 |
RIGHT_EYE_CONTOUR = [
|
| 31 |
-
362,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
]
|
| 33 |
|
| 34 |
LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
|
|
@@ -39,13 +105,53 @@ NOSE_TIP = [94, 2, 326, 327, 294, 278, 279, 275, 274, 460, 456, 363, 370]
|
|
| 39 |
NOSE_BOTTOM = [19, 1, 274, 275, 440, 344, 278, 294, 460, 305, 289, 392]
|
| 40 |
|
| 41 |
OUTER_LIPS = [
|
| 42 |
-
61,
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
]
|
| 45 |
|
| 46 |
INNER_LIPS = [
|
| 47 |
-
78,
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
]
|
| 50 |
|
| 51 |
ALL_CONTOURS = [
|
|
|
|
| 18 |
# This is invariant to landmark displacement (unlike Delaunay).
|
| 19 |
|
| 20 |
JAWLINE_CONTOUR = [
|
| 21 |
+
10,
|
| 22 |
+
338,
|
| 23 |
+
297,
|
| 24 |
+
332,
|
| 25 |
+
284,
|
| 26 |
+
251,
|
| 27 |
+
389,
|
| 28 |
+
356,
|
| 29 |
+
454,
|
| 30 |
+
323,
|
| 31 |
+
361,
|
| 32 |
+
288,
|
| 33 |
+
397,
|
| 34 |
+
365,
|
| 35 |
+
379,
|
| 36 |
+
378,
|
| 37 |
+
400,
|
| 38 |
+
377,
|
| 39 |
+
152,
|
| 40 |
+
148,
|
| 41 |
+
176,
|
| 42 |
+
149,
|
| 43 |
+
150,
|
| 44 |
+
136,
|
| 45 |
+
172,
|
| 46 |
+
58,
|
| 47 |
+
132,
|
| 48 |
+
93,
|
| 49 |
+
234,
|
| 50 |
+
127,
|
| 51 |
+
162,
|
| 52 |
+
21,
|
| 53 |
+
54,
|
| 54 |
+
103,
|
| 55 |
+
67,
|
| 56 |
+
109,
|
| 57 |
+
10,
|
| 58 |
]
|
| 59 |
|
| 60 |
LEFT_EYE_CONTOUR = [
|
| 61 |
+
33,
|
| 62 |
+
7,
|
| 63 |
+
163,
|
| 64 |
+
144,
|
| 65 |
+
145,
|
| 66 |
+
153,
|
| 67 |
+
154,
|
| 68 |
+
155,
|
| 69 |
+
133,
|
| 70 |
+
173,
|
| 71 |
+
157,
|
| 72 |
+
158,
|
| 73 |
+
159,
|
| 74 |
+
160,
|
| 75 |
+
161,
|
| 76 |
+
246,
|
| 77 |
+
33,
|
| 78 |
]
|
| 79 |
|
| 80 |
RIGHT_EYE_CONTOUR = [
|
| 81 |
+
362,
|
| 82 |
+
382,
|
| 83 |
+
381,
|
| 84 |
+
380,
|
| 85 |
+
374,
|
| 86 |
+
373,
|
| 87 |
+
390,
|
| 88 |
+
249,
|
| 89 |
+
263,
|
| 90 |
+
466,
|
| 91 |
+
388,
|
| 92 |
+
387,
|
| 93 |
+
386,
|
| 94 |
+
385,
|
| 95 |
+
384,
|
| 96 |
+
398,
|
| 97 |
+
362,
|
| 98 |
]
|
| 99 |
|
| 100 |
LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
|
|
|
|
| 105 |
NOSE_BOTTOM = [19, 1, 274, 275, 440, 344, 278, 294, 460, 305, 289, 392]
|
| 106 |
|
| 107 |
OUTER_LIPS = [
|
| 108 |
+
61,
|
| 109 |
+
146,
|
| 110 |
+
91,
|
| 111 |
+
181,
|
| 112 |
+
84,
|
| 113 |
+
17,
|
| 114 |
+
314,
|
| 115 |
+
405,
|
| 116 |
+
321,
|
| 117 |
+
375,
|
| 118 |
+
291,
|
| 119 |
+
308,
|
| 120 |
+
324,
|
| 121 |
+
318,
|
| 122 |
+
402,
|
| 123 |
+
317,
|
| 124 |
+
14,
|
| 125 |
+
87,
|
| 126 |
+
178,
|
| 127 |
+
88,
|
| 128 |
+
95,
|
| 129 |
+
78,
|
| 130 |
+
61,
|
| 131 |
]
|
| 132 |
|
| 133 |
INNER_LIPS = [
|
| 134 |
+
78,
|
| 135 |
+
191,
|
| 136 |
+
80,
|
| 137 |
+
81,
|
| 138 |
+
82,
|
| 139 |
+
13,
|
| 140 |
+
312,
|
| 141 |
+
311,
|
| 142 |
+
310,
|
| 143 |
+
415,
|
| 144 |
+
308,
|
| 145 |
+
324,
|
| 146 |
+
318,
|
| 147 |
+
402,
|
| 148 |
+
317,
|
| 149 |
+
14,
|
| 150 |
+
87,
|
| 151 |
+
178,
|
| 152 |
+
88,
|
| 153 |
+
95,
|
| 154 |
+
78,
|
| 155 |
]
|
| 156 |
|
| 157 |
ALL_CONTOURS = [
|
landmarkdiff/config.py
CHANGED
|
@@ -18,7 +18,8 @@ Usage:
|
|
| 18 |
|
| 19 |
from __future__ import annotations
|
| 20 |
|
| 21 |
-
|
|
|
|
| 22 |
from pathlib import Path
|
| 23 |
from typing import Any
|
| 24 |
|
|
@@ -28,6 +29,7 @@ import yaml
|
|
| 28 |
@dataclass
|
| 29 |
class ModelConfig:
|
| 30 |
"""ControlNet and base model configuration."""
|
|
|
|
| 31 |
base_model: str = "runwayml/stable-diffusion-v1-5"
|
| 32 |
controlnet_conditioning_channels: int = 3
|
| 33 |
controlnet_conditioning_scale: float = 1.0
|
|
@@ -39,6 +41,7 @@ class ModelConfig:
|
|
| 39 |
@dataclass
|
| 40 |
class TrainingConfig:
|
| 41 |
"""Training hyperparameters."""
|
|
|
|
| 42 |
phase: str = "A" # "A" or "B"
|
| 43 |
learning_rate: float = 1e-5
|
| 44 |
batch_size: int = 4
|
|
@@ -77,6 +80,7 @@ class TrainingConfig:
|
|
| 77 |
@dataclass
|
| 78 |
class DataConfig:
|
| 79 |
"""Dataset configuration."""
|
|
|
|
| 80 |
train_dir: str = "data/training"
|
| 81 |
val_dir: str = "data/validation"
|
| 82 |
test_dir: str = "data/test"
|
|
@@ -90,9 +94,14 @@ class DataConfig:
|
|
| 90 |
color_jitter: float = 0.1
|
| 91 |
|
| 92 |
# Procedure filtering
|
| 93 |
-
procedures: list[str] = field(
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
intensity_range: tuple[float, float] = (30.0, 100.0)
|
| 97 |
|
| 98 |
# Data-driven displacement
|
|
@@ -103,6 +112,7 @@ class DataConfig:
|
|
| 103 |
@dataclass
|
| 104 |
class InferenceConfig:
|
| 105 |
"""Inference / generation configuration."""
|
|
|
|
| 106 |
num_inference_steps: int = 30
|
| 107 |
guidance_scale: float = 7.5
|
| 108 |
scheduler: str = "dpmsolver++" # "ddpm", "ddim", "dpmsolver++"
|
|
@@ -124,6 +134,7 @@ class InferenceConfig:
|
|
| 124 |
@dataclass
|
| 125 |
class EvaluationConfig:
|
| 126 |
"""Evaluation configuration."""
|
|
|
|
| 127 |
compute_fid: bool = True
|
| 128 |
compute_lpips: bool = True
|
| 129 |
compute_nme: bool = True
|
|
@@ -137,6 +148,7 @@ class EvaluationConfig:
|
|
| 137 |
@dataclass
|
| 138 |
class WandbConfig:
|
| 139 |
"""Weights & Biases logging configuration."""
|
|
|
|
| 140 |
enabled: bool = True
|
| 141 |
project: str = "landmarkdiff"
|
| 142 |
entity: str | None = None
|
|
@@ -147,8 +159,9 @@ class WandbConfig:
|
|
| 147 |
@dataclass
|
| 148 |
class SlurmConfig:
|
| 149 |
"""SLURM job submission parameters."""
|
|
|
|
| 150 |
partition: str = "batch_gpu"
|
| 151 |
-
account: str = "
|
| 152 |
gpu_type: str = "nvidia_rtx_a6000"
|
| 153 |
num_gpus: int = 1
|
| 154 |
mem: str = "48G"
|
|
@@ -160,6 +173,7 @@ class SlurmConfig:
|
|
| 160 |
@dataclass
|
| 161 |
class SafetyConfig:
|
| 162 |
"""Clinical safety and responsible AI parameters."""
|
|
|
|
| 163 |
identity_threshold: float = 0.6
|
| 164 |
max_displacement_fraction: float = 0.05
|
| 165 |
watermark_enabled: bool = True
|
|
@@ -173,6 +187,7 @@ class SafetyConfig:
|
|
| 173 |
@dataclass
|
| 174 |
class ExperimentConfig:
|
| 175 |
"""Top-level experiment configuration."""
|
|
|
|
| 176 |
experiment_name: str = "default"
|
| 177 |
description: str = ""
|
| 178 |
version: str = "0.3.0"
|
|
@@ -227,9 +242,10 @@ class ExperimentConfig:
|
|
| 227 |
return asdict(self)
|
| 228 |
|
| 229 |
|
| 230 |
-
def _from_dict(cls, d: dict):
|
| 231 |
"""Create a dataclass from a dict, ignoring unknown keys."""
|
| 232 |
import dataclasses
|
|
|
|
| 233 |
field_map = {f.name: f for f in dataclasses.fields(cls)}
|
| 234 |
filtered = {}
|
| 235 |
for k, v in d.items():
|
|
@@ -243,7 +259,7 @@ def _from_dict(cls, d: dict):
|
|
| 243 |
return cls(**filtered)
|
| 244 |
|
| 245 |
|
| 246 |
-
def _convert_tuples(obj):
|
| 247 |
"""Recursively convert tuples to lists for YAML serialization."""
|
| 248 |
if isinstance(obj, dict):
|
| 249 |
return {k: _convert_tuples(v) for k, v in obj.items()}
|
|
@@ -266,10 +282,7 @@ def load_config(
|
|
| 266 |
Returns:
|
| 267 |
ExperimentConfig with overrides applied.
|
| 268 |
"""
|
| 269 |
-
if config_path
|
| 270 |
-
config = ExperimentConfig.from_yaml(config_path)
|
| 271 |
-
else:
|
| 272 |
-
config = ExperimentConfig()
|
| 273 |
|
| 274 |
if overrides:
|
| 275 |
for key, value in overrides.items():
|
|
|
|
| 18 |
|
| 19 |
from __future__ import annotations
|
| 20 |
|
| 21 |
+
import os
|
| 22 |
+
from dataclasses import asdict, dataclass, field
|
| 23 |
from pathlib import Path
|
| 24 |
from typing import Any
|
| 25 |
|
|
|
|
| 29 |
@dataclass
|
| 30 |
class ModelConfig:
|
| 31 |
"""ControlNet and base model configuration."""
|
| 32 |
+
|
| 33 |
base_model: str = "runwayml/stable-diffusion-v1-5"
|
| 34 |
controlnet_conditioning_channels: int = 3
|
| 35 |
controlnet_conditioning_scale: float = 1.0
|
|
|
|
| 41 |
@dataclass
|
| 42 |
class TrainingConfig:
|
| 43 |
"""Training hyperparameters."""
|
| 44 |
+
|
| 45 |
phase: str = "A" # "A" or "B"
|
| 46 |
learning_rate: float = 1e-5
|
| 47 |
batch_size: int = 4
|
|
|
|
| 80 |
@dataclass
|
| 81 |
class DataConfig:
|
| 82 |
"""Dataset configuration."""
|
| 83 |
+
|
| 84 |
train_dir: str = "data/training"
|
| 85 |
val_dir: str = "data/validation"
|
| 86 |
test_dir: str = "data/test"
|
|
|
|
| 94 |
color_jitter: float = 0.1
|
| 95 |
|
| 96 |
# Procedure filtering
|
| 97 |
+
procedures: list[str] = field(
|
| 98 |
+
default_factory=lambda: [
|
| 99 |
+
"rhinoplasty",
|
| 100 |
+
"blepharoplasty",
|
| 101 |
+
"rhytidectomy",
|
| 102 |
+
"orthognathic",
|
| 103 |
+
]
|
| 104 |
+
)
|
| 105 |
intensity_range: tuple[float, float] = (30.0, 100.0)
|
| 106 |
|
| 107 |
# Data-driven displacement
|
|
|
|
| 112 |
@dataclass
|
| 113 |
class InferenceConfig:
|
| 114 |
"""Inference / generation configuration."""
|
| 115 |
+
|
| 116 |
num_inference_steps: int = 30
|
| 117 |
guidance_scale: float = 7.5
|
| 118 |
scheduler: str = "dpmsolver++" # "ddpm", "ddim", "dpmsolver++"
|
|
|
|
| 134 |
@dataclass
|
| 135 |
class EvaluationConfig:
|
| 136 |
"""Evaluation configuration."""
|
| 137 |
+
|
| 138 |
compute_fid: bool = True
|
| 139 |
compute_lpips: bool = True
|
| 140 |
compute_nme: bool = True
|
|
|
|
| 148 |
@dataclass
|
| 149 |
class WandbConfig:
|
| 150 |
"""Weights & Biases logging configuration."""
|
| 151 |
+
|
| 152 |
enabled: bool = True
|
| 153 |
project: str = "landmarkdiff"
|
| 154 |
entity: str | None = None
|
|
|
|
| 159 |
@dataclass
|
| 160 |
class SlurmConfig:
|
| 161 |
"""SLURM job submission parameters."""
|
| 162 |
+
|
| 163 |
partition: str = "batch_gpu"
|
| 164 |
+
account: str = os.environ.get("SLURM_ACCOUNT", "default_gpu")
|
| 165 |
gpu_type: str = "nvidia_rtx_a6000"
|
| 166 |
num_gpus: int = 1
|
| 167 |
mem: str = "48G"
|
|
|
|
| 173 |
@dataclass
|
| 174 |
class SafetyConfig:
|
| 175 |
"""Clinical safety and responsible AI parameters."""
|
| 176 |
+
|
| 177 |
identity_threshold: float = 0.6
|
| 178 |
max_displacement_fraction: float = 0.05
|
| 179 |
watermark_enabled: bool = True
|
|
|
|
| 187 |
@dataclass
|
| 188 |
class ExperimentConfig:
|
| 189 |
"""Top-level experiment configuration."""
|
| 190 |
+
|
| 191 |
experiment_name: str = "default"
|
| 192 |
description: str = ""
|
| 193 |
version: str = "0.3.0"
|
|
|
|
| 242 |
return asdict(self)
|
| 243 |
|
| 244 |
|
| 245 |
+
def _from_dict(cls: type, d: dict) -> Any:
|
| 246 |
"""Create a dataclass from a dict, ignoring unknown keys."""
|
| 247 |
import dataclasses
|
| 248 |
+
|
| 249 |
field_map = {f.name: f for f in dataclasses.fields(cls)}
|
| 250 |
filtered = {}
|
| 251 |
for k, v in d.items():
|
|
|
|
| 259 |
return cls(**filtered)
|
| 260 |
|
| 261 |
|
| 262 |
+
def _convert_tuples(obj: Any) -> Any:
|
| 263 |
"""Recursively convert tuples to lists for YAML serialization."""
|
| 264 |
if isinstance(obj, dict):
|
| 265 |
return {k: _convert_tuples(v) for k, v in obj.items()}
|
|
|
|
| 282 |
Returns:
|
| 283 |
ExperimentConfig with overrides applied.
|
| 284 |
"""
|
| 285 |
+
config = ExperimentConfig.from_yaml(config_path) if config_path else ExperimentConfig()
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
if overrides:
|
| 288 |
for key, value in overrides.items():
|
landmarkdiff/curriculum.py
CHANGED
|
@@ -104,10 +104,10 @@ class ProcedureCurriculum:
|
|
| 104 |
|
| 105 |
# Difficulty ranking (0=easiest, 1=hardest)
|
| 106 |
DEFAULT_PROCEDURE_DIFFICULTY = {
|
| 107 |
-
"blepharoplasty": 0.3,
|
| 108 |
-
"rhinoplasty": 0.5,
|
| 109 |
-
"rhytidectomy": 0.7,
|
| 110 |
-
"orthognathic": 0.9,
|
| 111 |
}
|
| 112 |
|
| 113 |
def __init__(
|
|
@@ -137,10 +137,7 @@ class ProcedureCurriculum:
|
|
| 137 |
|
| 138 |
def get_procedure_weights(self, step: int) -> dict[str, float]:
|
| 139 |
"""Get all procedure weights at the given step."""
|
| 140 |
-
return {
|
| 141 |
-
proc: self.get_weight(step, proc)
|
| 142 |
-
for proc in self.proc_difficulty
|
| 143 |
-
}
|
| 144 |
|
| 145 |
|
| 146 |
def compute_sample_difficulty(
|
|
@@ -174,7 +171,7 @@ def compute_sample_difficulty(
|
|
| 174 |
source_bonus = {
|
| 175 |
"synthetic": 0.0,
|
| 176 |
"synthetic_v3": 0.1, # realistic displacements slightly harder
|
| 177 |
-
"real": 0.2,
|
| 178 |
"augmented": 0.0,
|
| 179 |
}
|
| 180 |
|
|
|
|
| 104 |
|
| 105 |
# Difficulty ranking (0=easiest, 1=hardest)
|
| 106 |
DEFAULT_PROCEDURE_DIFFICULTY = {
|
| 107 |
+
"blepharoplasty": 0.3, # small, localized changes
|
| 108 |
+
"rhinoplasty": 0.5, # moderate, central face
|
| 109 |
+
"rhytidectomy": 0.7, # large, affects face shape
|
| 110 |
+
"orthognathic": 0.9, # largest deformations
|
| 111 |
}
|
| 112 |
|
| 113 |
def __init__(
|
|
|
|
| 137 |
|
| 138 |
def get_procedure_weights(self, step: int) -> dict[str, float]:
|
| 139 |
"""Get all procedure weights at the given step."""
|
| 140 |
+
return {proc: self.get_weight(step, proc) for proc in self.proc_difficulty}
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
|
| 143 |
def compute_sample_difficulty(
|
|
|
|
| 171 |
source_bonus = {
|
| 172 |
"synthetic": 0.0,
|
| 173 |
"synthetic_v3": 0.1, # realistic displacements slightly harder
|
| 174 |
+
"real": 0.2, # real data hardest
|
| 175 |
"augmented": 0.0,
|
| 176 |
}
|
| 177 |
|
landmarkdiff/data.py
CHANGED
|
@@ -23,8 +23,8 @@ from __future__ import annotations
|
|
| 23 |
import csv
|
| 24 |
import json
|
| 25 |
import logging
|
|
|
|
| 26 |
from pathlib import Path
|
| 27 |
-
from typing import Callable
|
| 28 |
|
| 29 |
import cv2
|
| 30 |
import numpy as np
|
|
@@ -38,6 +38,7 @@ logger = logging.getLogger(__name__)
|
|
| 38 |
# Core dataset
|
| 39 |
# ---------------------------------------------------------------------------
|
| 40 |
|
|
|
|
| 41 |
class SurgicalPairDataset(Dataset):
|
| 42 |
"""Dataset for loading surgical before/after training pairs.
|
| 43 |
|
|
@@ -162,9 +163,7 @@ class SurgicalPairDataset(Dataset):
|
|
| 162 |
img = cv2.imread(str(path))
|
| 163 |
if img is None:
|
| 164 |
logger.warning("Failed to load %s, using blank", path)
|
| 165 |
-
return np.zeros(
|
| 166 |
-
(self.resolution, self.resolution, 3), dtype=np.uint8
|
| 167 |
-
)
|
| 168 |
if img.shape[:2] != (self.resolution, self.resolution):
|
| 169 |
img = cv2.resize(img, (self.resolution, self.resolution))
|
| 170 |
return img
|
|
@@ -173,14 +172,10 @@ class SurgicalPairDataset(Dataset):
|
|
| 173 |
"""Load a mask as float32 [0,1], resized to resolution."""
|
| 174 |
path = self.data_dir / filename
|
| 175 |
if not path.exists():
|
| 176 |
-
return np.ones(
|
| 177 |
-
(self.resolution, self.resolution), dtype=np.float32
|
| 178 |
-
)
|
| 179 |
mask = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
|
| 180 |
if mask is None:
|
| 181 |
-
return np.ones(
|
| 182 |
-
(self.resolution, self.resolution), dtype=np.float32
|
| 183 |
-
)
|
| 184 |
mask = cv2.resize(mask, (self.resolution, self.resolution))
|
| 185 |
return mask.astype(np.float32) / 255.0
|
| 186 |
|
|
@@ -189,6 +184,7 @@ class SurgicalPairDataset(Dataset):
|
|
| 189 |
# Evaluation dataset (input + ground truth)
|
| 190 |
# ---------------------------------------------------------------------------
|
| 191 |
|
|
|
|
| 192 |
class EvalPairDataset(Dataset):
|
| 193 |
"""Dataset for evaluation: loads input/target pairs with procedure labels.
|
| 194 |
|
|
@@ -235,9 +231,7 @@ class EvalPairDataset(Dataset):
|
|
| 235 |
path = self.data_dir / filename
|
| 236 |
img = cv2.imread(str(path))
|
| 237 |
if img is None:
|
| 238 |
-
return np.zeros(
|
| 239 |
-
(self.resolution, self.resolution, 3), dtype=np.uint8
|
| 240 |
-
)
|
| 241 |
if img.shape[:2] != (self.resolution, self.resolution):
|
| 242 |
img = cv2.resize(img, (self.resolution, self.resolution))
|
| 243 |
return img
|
|
@@ -247,6 +241,7 @@ class EvalPairDataset(Dataset):
|
|
| 247 |
# Conversion utilities
|
| 248 |
# ---------------------------------------------------------------------------
|
| 249 |
|
|
|
|
| 250 |
def bgr_to_tensor(bgr: np.ndarray) -> torch.Tensor:
|
| 251 |
"""Convert BGR uint8 image to RGB [0,1] tensor (C, H, W)."""
|
| 252 |
rgb = bgr[:, :, ::-1].astype(np.float32) / 255.0
|
|
@@ -271,6 +266,7 @@ def mask_to_tensor(mask: np.ndarray) -> torch.Tensor:
|
|
| 271 |
# Samplers
|
| 272 |
# ---------------------------------------------------------------------------
|
| 273 |
|
|
|
|
| 274 |
def create_procedure_sampler(
|
| 275 |
dataset: SurgicalPairDataset,
|
| 276 |
balance_procedures: bool = True,
|
|
@@ -309,6 +305,7 @@ def create_procedure_sampler(
|
|
| 309 |
# DataLoader factory
|
| 310 |
# ---------------------------------------------------------------------------
|
| 311 |
|
|
|
|
| 312 |
def create_dataloader(
|
| 313 |
dataset: Dataset,
|
| 314 |
batch_size: int = 4,
|
|
@@ -353,6 +350,7 @@ def create_dataloader(
|
|
| 353 |
# Multi-directory dataset
|
| 354 |
# ---------------------------------------------------------------------------
|
| 355 |
|
|
|
|
| 356 |
class CombinedDataset(Dataset):
|
| 357 |
"""Combine multiple SurgicalPairDatasets into one.
|
| 358 |
|
|
|
|
| 23 |
import csv
|
| 24 |
import json
|
| 25 |
import logging
|
| 26 |
+
from collections.abc import Callable
|
| 27 |
from pathlib import Path
|
|
|
|
| 28 |
|
| 29 |
import cv2
|
| 30 |
import numpy as np
|
|
|
|
| 38 |
# Core dataset
|
| 39 |
# ---------------------------------------------------------------------------
|
| 40 |
|
| 41 |
+
|
| 42 |
class SurgicalPairDataset(Dataset):
|
| 43 |
"""Dataset for loading surgical before/after training pairs.
|
| 44 |
|
|
|
|
| 163 |
img = cv2.imread(str(path))
|
| 164 |
if img is None:
|
| 165 |
logger.warning("Failed to load %s, using blank", path)
|
| 166 |
+
return np.zeros((self.resolution, self.resolution, 3), dtype=np.uint8)
|
|
|
|
|
|
|
| 167 |
if img.shape[:2] != (self.resolution, self.resolution):
|
| 168 |
img = cv2.resize(img, (self.resolution, self.resolution))
|
| 169 |
return img
|
|
|
|
| 172 |
"""Load a mask as float32 [0,1], resized to resolution."""
|
| 173 |
path = self.data_dir / filename
|
| 174 |
if not path.exists():
|
| 175 |
+
return np.ones((self.resolution, self.resolution), dtype=np.float32)
|
|
|
|
|
|
|
| 176 |
mask = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
|
| 177 |
if mask is None:
|
| 178 |
+
return np.ones((self.resolution, self.resolution), dtype=np.float32)
|
|
|
|
|
|
|
| 179 |
mask = cv2.resize(mask, (self.resolution, self.resolution))
|
| 180 |
return mask.astype(np.float32) / 255.0
|
| 181 |
|
|
|
|
| 184 |
# Evaluation dataset (input + ground truth)
|
| 185 |
# ---------------------------------------------------------------------------
|
| 186 |
|
| 187 |
+
|
| 188 |
class EvalPairDataset(Dataset):
|
| 189 |
"""Dataset for evaluation: loads input/target pairs with procedure labels.
|
| 190 |
|
|
|
|
| 231 |
path = self.data_dir / filename
|
| 232 |
img = cv2.imread(str(path))
|
| 233 |
if img is None:
|
| 234 |
+
return np.zeros((self.resolution, self.resolution, 3), dtype=np.uint8)
|
|
|
|
|
|
|
| 235 |
if img.shape[:2] != (self.resolution, self.resolution):
|
| 236 |
img = cv2.resize(img, (self.resolution, self.resolution))
|
| 237 |
return img
|
|
|
|
| 241 |
# Conversion utilities
|
| 242 |
# ---------------------------------------------------------------------------
|
| 243 |
|
| 244 |
+
|
| 245 |
def bgr_to_tensor(bgr: np.ndarray) -> torch.Tensor:
|
| 246 |
"""Convert BGR uint8 image to RGB [0,1] tensor (C, H, W)."""
|
| 247 |
rgb = bgr[:, :, ::-1].astype(np.float32) / 255.0
|
|
|
|
| 266 |
# Samplers
|
| 267 |
# ---------------------------------------------------------------------------
|
| 268 |
|
| 269 |
+
|
| 270 |
def create_procedure_sampler(
|
| 271 |
dataset: SurgicalPairDataset,
|
| 272 |
balance_procedures: bool = True,
|
|
|
|
| 305 |
# DataLoader factory
|
| 306 |
# ---------------------------------------------------------------------------
|
| 307 |
|
| 308 |
+
|
| 309 |
def create_dataloader(
|
| 310 |
dataset: Dataset,
|
| 311 |
batch_size: int = 4,
|
|
|
|
| 350 |
# Multi-directory dataset
|
| 351 |
# ---------------------------------------------------------------------------
|
| 352 |
|
| 353 |
+
|
| 354 |
class CombinedDataset(Dataset):
|
| 355 |
"""Combine multiple SurgicalPairDatasets into one.
|
| 356 |
|
landmarkdiff/data_version.py
CHANGED
|
@@ -22,7 +22,7 @@ import json
|
|
| 22 |
from dataclasses import dataclass, field
|
| 23 |
from datetime import datetime, timezone
|
| 24 |
from pathlib import Path
|
| 25 |
-
from typing import Any
|
| 26 |
|
| 27 |
|
| 28 |
@dataclass
|
|
@@ -68,9 +68,7 @@ class DataManifest:
|
|
| 68 |
"""
|
| 69 |
|
| 70 |
version: str = "1.0"
|
| 71 |
-
created_at: str = field(
|
| 72 |
-
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
| 73 |
-
)
|
| 74 |
root_dir: str = ""
|
| 75 |
files: list[FileEntry] = field(default_factory=list)
|
| 76 |
metadata: dict[str, Any] = field(default_factory=dict)
|
|
@@ -237,8 +235,7 @@ class DataManifest:
|
|
| 237 |
actual_size = fp.stat().st_size
|
| 238 |
if actual_size != entry.size_bytes:
|
| 239 |
issues.append(
|
| 240 |
-
f"Size mismatch: {entry.path} "
|
| 241 |
-
f"(expected {entry.size_bytes}, got {actual_size})"
|
| 242 |
)
|
| 243 |
|
| 244 |
# Check checksum
|
|
@@ -265,8 +262,7 @@ class DataManifest:
|
|
| 265 |
added = sorted(other_paths - self_paths)
|
| 266 |
removed = sorted(self_paths - other_paths)
|
| 267 |
modified = sorted(
|
| 268 |
-
p for p in self_paths & other_paths
|
| 269 |
-
if self_files[p].checksum != other_files[p].checksum
|
| 270 |
)
|
| 271 |
|
| 272 |
return {"added": added, "removed": removed, "modified": modified}
|
|
@@ -292,6 +288,7 @@ def _get_hostname() -> str:
|
|
| 292 |
"""Get hostname safely."""
|
| 293 |
try:
|
| 294 |
import socket
|
|
|
|
| 295 |
return socket.gethostname()
|
| 296 |
except Exception:
|
| 297 |
return "unknown"
|
|
|
|
| 22 |
from dataclasses import dataclass, field
|
| 23 |
from datetime import datetime, timezone
|
| 24 |
from pathlib import Path
|
| 25 |
+
from typing import Any
|
| 26 |
|
| 27 |
|
| 28 |
@dataclass
|
|
|
|
| 68 |
"""
|
| 69 |
|
| 70 |
version: str = "1.0"
|
| 71 |
+
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
|
|
|
|
|
|
| 72 |
root_dir: str = ""
|
| 73 |
files: list[FileEntry] = field(default_factory=list)
|
| 74 |
metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
| 235 |
actual_size = fp.stat().st_size
|
| 236 |
if actual_size != entry.size_bytes:
|
| 237 |
issues.append(
|
| 238 |
+
f"Size mismatch: {entry.path} (expected {entry.size_bytes}, got {actual_size})"
|
|
|
|
| 239 |
)
|
| 240 |
|
| 241 |
# Check checksum
|
|
|
|
| 262 |
added = sorted(other_paths - self_paths)
|
| 263 |
removed = sorted(self_paths - other_paths)
|
| 264 |
modified = sorted(
|
| 265 |
+
p for p in self_paths & other_paths if self_files[p].checksum != other_files[p].checksum
|
|
|
|
| 266 |
)
|
| 267 |
|
| 268 |
return {"added": added, "removed": removed, "modified": modified}
|
|
|
|
| 288 |
"""Get hostname safely."""
|
| 289 |
try:
|
| 290 |
import socket
|
| 291 |
+
|
| 292 |
return socket.gethostname()
|
| 293 |
except Exception:
|
| 294 |
return "unknown"
|
landmarkdiff/displacement_model.py
CHANGED
|
@@ -33,12 +33,11 @@ from __future__ import annotations
|
|
| 33 |
import json
|
| 34 |
import logging
|
| 35 |
from pathlib import Path
|
| 36 |
-
from typing import Optional, Union
|
| 37 |
|
| 38 |
import cv2
|
| 39 |
import numpy as np
|
| 40 |
|
| 41 |
-
from landmarkdiff.landmarks import
|
| 42 |
from landmarkdiff.manipulation import PROCEDURE_LANDMARKS
|
| 43 |
|
| 44 |
logger = logging.getLogger(__name__)
|
|
@@ -54,6 +53,7 @@ PROCEDURES = list(PROCEDURE_LANDMARKS.keys())
|
|
| 54 |
# Helpers
|
| 55 |
# ---------------------------------------------------------------------------
|
| 56 |
|
|
|
|
| 57 |
def _normalized_coords_2d(face: FaceLandmarks) -> np.ndarray:
|
| 58 |
"""Extract (478, 2) normalized [0, 1] coordinates from a FaceLandmarks object.
|
| 59 |
|
|
@@ -84,9 +84,24 @@ def _compute_alignment_quality(
|
|
| 84 |
# Stable landmarks: forehead, temple region, outer face oval
|
| 85 |
# These should exhibit near-zero displacement after surgery.
|
| 86 |
stable_indices = [
|
| 87 |
-
10,
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
]
|
| 91 |
stable_indices = [i for i in stable_indices if i < NUM_LANDMARKS]
|
| 92 |
|
|
@@ -95,7 +110,7 @@ def _compute_alignment_quality(
|
|
| 95 |
|
| 96 |
# RMS displacement on stable points
|
| 97 |
diffs = after_stable - before_stable
|
| 98 |
-
rms = np.sqrt(np.mean(np.sum(diffs
|
| 99 |
|
| 100 |
# Map RMS to quality: 0 displacement -> 1.0, rms >= 0.05 (5% of image) -> 0.0
|
| 101 |
quality = float(np.clip(1.0 - rms / 0.05, 0.0, 1.0))
|
|
@@ -106,6 +121,7 @@ def _compute_alignment_quality(
|
|
| 106 |
# Procedure classification
|
| 107 |
# ---------------------------------------------------------------------------
|
| 108 |
|
|
|
|
| 109 |
def classify_procedure(displacements: np.ndarray) -> str:
|
| 110 |
"""Classify which surgical procedure was performed from displacement vectors.
|
| 111 |
|
|
@@ -143,8 +159,7 @@ def classify_procedure(displacements: np.ndarray) -> str:
|
|
| 143 |
# Threshold: mean displacement < 0.002 (~1 pixel at 512x512)
|
| 144 |
if best_score < 0.002:
|
| 145 |
logger.debug(
|
| 146 |
-
"No significant displacement detected (best=%.5f). "
|
| 147 |
-
"Classified as 'unknown'.",
|
| 148 |
best_score,
|
| 149 |
)
|
| 150 |
return "unknown"
|
|
@@ -156,11 +171,12 @@ def classify_procedure(displacements: np.ndarray) -> str:
|
|
| 156 |
# Single-pair extraction
|
| 157 |
# ---------------------------------------------------------------------------
|
| 158 |
|
|
|
|
| 159 |
def extract_displacements(
|
| 160 |
before_img: np.ndarray,
|
| 161 |
after_img: np.ndarray,
|
| 162 |
min_detection_confidence: float = 0.5,
|
| 163 |
-
) ->
|
| 164 |
"""Extract landmark displacements from a before/after surgery image pair.
|
| 165 |
|
| 166 |
Runs MediaPipe Face Mesh on both images, computes per-landmark
|
|
@@ -185,16 +201,12 @@ def extract_displacements(
|
|
| 185 |
Returns ``None`` if face detection fails on either image.
|
| 186 |
"""
|
| 187 |
# Extract landmarks from both images
|
| 188 |
-
face_before = extract_landmarks(
|
| 189 |
-
before_img, min_detection_confidence=min_detection_confidence
|
| 190 |
-
)
|
| 191 |
if face_before is None:
|
| 192 |
logger.warning("Face detection failed on before image.")
|
| 193 |
return None
|
| 194 |
|
| 195 |
-
face_after = extract_landmarks(
|
| 196 |
-
after_img, min_detection_confidence=min_detection_confidence
|
| 197 |
-
)
|
| 198 |
if face_after is None:
|
| 199 |
logger.warning("Face detection failed on after image.")
|
| 200 |
return None
|
|
@@ -227,8 +239,9 @@ def extract_displacements(
|
|
| 227 |
# Batch extraction from directory
|
| 228 |
# ---------------------------------------------------------------------------
|
| 229 |
|
|
|
|
| 230 |
def extract_from_directory(
|
| 231 |
-
pairs_dir:
|
| 232 |
min_detection_confidence: float = 0.5,
|
| 233 |
min_quality: float = 0.0,
|
| 234 |
) -> list[dict]:
|
|
@@ -338,6 +351,7 @@ def extract_from_directory(
|
|
| 338 |
# Displacement model
|
| 339 |
# ---------------------------------------------------------------------------
|
| 340 |
|
|
|
|
| 341 |
class DisplacementModel:
|
| 342 |
"""Statistical model of per-procedure surgical displacements.
|
| 343 |
|
|
@@ -418,12 +432,12 @@ class DisplacementModel:
|
|
| 418 |
n = stacked.shape[0]
|
| 419 |
|
| 420 |
self.stats[proc] = {
|
| 421 |
-
"mean": np.mean(stacked, axis=0),
|
| 422 |
-
"std": np.std(stacked, axis=0),
|
| 423 |
-
"min": np.min(stacked, axis=0),
|
| 424 |
-
"max": np.max(stacked, axis=0),
|
| 425 |
-
"median": np.median(stacked, axis=0),
|
| 426 |
-
"mean_magnitude": np.mean(
|
| 427 |
np.linalg.norm(stacked, axis=2), axis=0
|
| 428 |
),
|
| 429 |
}
|
|
@@ -442,7 +456,7 @@ class DisplacementModel:
|
|
| 442 |
procedure: str,
|
| 443 |
intensity: float = 1.0,
|
| 444 |
noise_scale: float = 0.0,
|
| 445 |
-
rng:
|
| 446 |
) -> np.ndarray:
|
| 447 |
"""Generate a displacement field for a given procedure and intensity.
|
| 448 |
|
|
@@ -470,10 +484,7 @@ class DisplacementModel:
|
|
| 470 |
|
| 471 |
if procedure not in self.stats:
|
| 472 |
available = ", ".join(self.procedures)
|
| 473 |
-
raise KeyError(
|
| 474 |
-
f"Procedure '{procedure}' not in model. "
|
| 475 |
-
f"Available: {available}"
|
| 476 |
-
)
|
| 477 |
|
| 478 |
proc_stats = self.stats[procedure]
|
| 479 |
field = proc_stats["mean"].copy() * intensity
|
|
@@ -489,7 +500,7 @@ class DisplacementModel:
|
|
| 489 |
|
| 490 |
return field.astype(np.float32)
|
| 491 |
|
| 492 |
-
def get_summary(self, procedure:
|
| 493 |
"""Get a human-readable summary of the model statistics.
|
| 494 |
|
| 495 |
Args:
|
|
@@ -518,7 +529,7 @@ class DisplacementModel:
|
|
| 518 |
|
| 519 |
return summary
|
| 520 |
|
| 521 |
-
def save(self, path:
|
| 522 |
"""Save the fitted model to disk as a ``.npz`` file.
|
| 523 |
|
| 524 |
The file contains:
|
|
@@ -550,15 +561,13 @@ class DisplacementModel:
|
|
| 550 |
"n_samples": self.n_samples,
|
| 551 |
"num_landmarks": NUM_LANDMARKS,
|
| 552 |
}
|
| 553 |
-
arrays["__metadata__"] = np.frombuffer(
|
| 554 |
-
json.dumps(metadata).encode("utf-8"), dtype=np.uint8
|
| 555 |
-
)
|
| 556 |
|
| 557 |
np.savez_compressed(str(path), **arrays)
|
| 558 |
logger.info("Saved displacement model to %s", path)
|
| 559 |
|
| 560 |
@classmethod
|
| 561 |
-
def load(cls, path:
|
| 562 |
"""Load a fitted model from a ``.npz`` file.
|
| 563 |
|
| 564 |
Supports two formats:
|
|
@@ -592,7 +601,7 @@ class DisplacementModel:
|
|
| 592 |
model.stats[proc] = {}
|
| 593 |
for key in data.files:
|
| 594 |
if key.startswith(f"{proc}__"):
|
| 595 |
-
stat_name = key[len(f"{proc}__"):]
|
| 596 |
model.stats[proc][stat_name] = data[key]
|
| 597 |
|
| 598 |
# Format 2: extract_displacements.py format with procedures array
|
|
@@ -625,10 +634,7 @@ class DisplacementModel:
|
|
| 625 |
model.n_samples[proc] = 0
|
| 626 |
|
| 627 |
else:
|
| 628 |
-
raise ValueError(
|
| 629 |
-
f"Unrecognized displacement model format. "
|
| 630 |
-
f"Keys: {data.files[:10]}"
|
| 631 |
-
)
|
| 632 |
|
| 633 |
model._fitted = True
|
| 634 |
logger.info(
|
|
@@ -644,6 +650,7 @@ class DisplacementModel:
|
|
| 644 |
# Utilities
|
| 645 |
# ---------------------------------------------------------------------------
|
| 646 |
|
|
|
|
| 647 |
def _top_k_landmarks(
|
| 648 |
magnitudes: np.ndarray,
|
| 649 |
k: int = 10,
|
|
@@ -659,10 +666,7 @@ def _top_k_landmarks(
|
|
| 659 |
descending by magnitude.
|
| 660 |
"""
|
| 661 |
top_indices = np.argsort(magnitudes)[::-1][:k]
|
| 662 |
-
return [
|
| 663 |
-
{"index": int(idx), "magnitude": float(magnitudes[idx])}
|
| 664 |
-
for idx in top_indices
|
| 665 |
-
]
|
| 666 |
|
| 667 |
|
| 668 |
def visualize_displacements(
|
|
@@ -698,7 +702,7 @@ def visualize_displacements(
|
|
| 698 |
dy = int(displacements[i, 1] * h * scale)
|
| 699 |
|
| 700 |
# Only draw if displacement is above noise floor
|
| 701 |
-
mag = np.sqrt(dx
|
| 702 |
if mag < 1.0:
|
| 703 |
continue
|
| 704 |
|
|
|
|
| 33 |
import json
|
| 34 |
import logging
|
| 35 |
from pathlib import Path
|
|
|
|
| 36 |
|
| 37 |
import cv2
|
| 38 |
import numpy as np
|
| 39 |
|
| 40 |
+
from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks
|
| 41 |
from landmarkdiff.manipulation import PROCEDURE_LANDMARKS
|
| 42 |
|
| 43 |
logger = logging.getLogger(__name__)
|
|
|
|
| 53 |
# Helpers
|
| 54 |
# ---------------------------------------------------------------------------
|
| 55 |
|
| 56 |
+
|
| 57 |
def _normalized_coords_2d(face: FaceLandmarks) -> np.ndarray:
|
| 58 |
"""Extract (478, 2) normalized [0, 1] coordinates from a FaceLandmarks object.
|
| 59 |
|
|
|
|
| 84 |
# Stable landmarks: forehead, temple region, outer face oval
|
| 85 |
# These should exhibit near-zero displacement after surgery.
|
| 86 |
stable_indices = [
|
| 87 |
+
10,
|
| 88 |
+
109,
|
| 89 |
+
67,
|
| 90 |
+
103,
|
| 91 |
+
54,
|
| 92 |
+
21,
|
| 93 |
+
162,
|
| 94 |
+
127, # left forehead/temple
|
| 95 |
+
338,
|
| 96 |
+
297,
|
| 97 |
+
332,
|
| 98 |
+
284,
|
| 99 |
+
251,
|
| 100 |
+
389,
|
| 101 |
+
356,
|
| 102 |
+
454, # right forehead/temple
|
| 103 |
+
234,
|
| 104 |
+
93, # outer cheek anchors
|
| 105 |
]
|
| 106 |
stable_indices = [i for i in stable_indices if i < NUM_LANDMARKS]
|
| 107 |
|
|
|
|
| 110 |
|
| 111 |
# RMS displacement on stable points
|
| 112 |
diffs = after_stable - before_stable
|
| 113 |
+
rms = np.sqrt(np.mean(np.sum(diffs**2, axis=1)))
|
| 114 |
|
| 115 |
# Map RMS to quality: 0 displacement -> 1.0, rms >= 0.05 (5% of image) -> 0.0
|
| 116 |
quality = float(np.clip(1.0 - rms / 0.05, 0.0, 1.0))
|
|
|
|
| 121 |
# Procedure classification
|
| 122 |
# ---------------------------------------------------------------------------
|
| 123 |
|
| 124 |
+
|
| 125 |
def classify_procedure(displacements: np.ndarray) -> str:
|
| 126 |
"""Classify which surgical procedure was performed from displacement vectors.
|
| 127 |
|
|
|
|
| 159 |
# Threshold: mean displacement < 0.002 (~1 pixel at 512x512)
|
| 160 |
if best_score < 0.002:
|
| 161 |
logger.debug(
|
| 162 |
+
"No significant displacement detected (best=%.5f). Classified as 'unknown'.",
|
|
|
|
| 163 |
best_score,
|
| 164 |
)
|
| 165 |
return "unknown"
|
|
|
|
| 171 |
# Single-pair extraction
|
| 172 |
# ---------------------------------------------------------------------------
|
| 173 |
|
| 174 |
+
|
| 175 |
def extract_displacements(
|
| 176 |
before_img: np.ndarray,
|
| 177 |
after_img: np.ndarray,
|
| 178 |
min_detection_confidence: float = 0.5,
|
| 179 |
+
) -> dict | None:
|
| 180 |
"""Extract landmark displacements from a before/after surgery image pair.
|
| 181 |
|
| 182 |
Runs MediaPipe Face Mesh on both images, computes per-landmark
|
|
|
|
| 201 |
Returns ``None`` if face detection fails on either image.
|
| 202 |
"""
|
| 203 |
# Extract landmarks from both images
|
| 204 |
+
face_before = extract_landmarks(before_img, min_detection_confidence=min_detection_confidence)
|
|
|
|
|
|
|
| 205 |
if face_before is None:
|
| 206 |
logger.warning("Face detection failed on before image.")
|
| 207 |
return None
|
| 208 |
|
| 209 |
+
face_after = extract_landmarks(after_img, min_detection_confidence=min_detection_confidence)
|
|
|
|
|
|
|
| 210 |
if face_after is None:
|
| 211 |
logger.warning("Face detection failed on after image.")
|
| 212 |
return None
|
|
|
|
| 239 |
# Batch extraction from directory
|
| 240 |
# ---------------------------------------------------------------------------
|
| 241 |
|
| 242 |
+
|
| 243 |
def extract_from_directory(
|
| 244 |
+
pairs_dir: str | Path,
|
| 245 |
min_detection_confidence: float = 0.5,
|
| 246 |
min_quality: float = 0.0,
|
| 247 |
) -> list[dict]:
|
|
|
|
| 351 |
# Displacement model
|
| 352 |
# ---------------------------------------------------------------------------
|
| 353 |
|
| 354 |
+
|
| 355 |
class DisplacementModel:
|
| 356 |
"""Statistical model of per-procedure surgical displacements.
|
| 357 |
|
|
|
|
| 432 |
n = stacked.shape[0]
|
| 433 |
|
| 434 |
self.stats[proc] = {
|
| 435 |
+
"mean": np.mean(stacked, axis=0), # (478, 2)
|
| 436 |
+
"std": np.std(stacked, axis=0), # (478, 2)
|
| 437 |
+
"min": np.min(stacked, axis=0), # (478, 2)
|
| 438 |
+
"max": np.max(stacked, axis=0), # (478, 2)
|
| 439 |
+
"median": np.median(stacked, axis=0), # (478, 2)
|
| 440 |
+
"mean_magnitude": np.mean( # (478,)
|
| 441 |
np.linalg.norm(stacked, axis=2), axis=0
|
| 442 |
),
|
| 443 |
}
|
|
|
|
| 456 |
procedure: str,
|
| 457 |
intensity: float = 1.0,
|
| 458 |
noise_scale: float = 0.0,
|
| 459 |
+
rng: np.random.Generator | None = None,
|
| 460 |
) -> np.ndarray:
|
| 461 |
"""Generate a displacement field for a given procedure and intensity.
|
| 462 |
|
|
|
|
| 484 |
|
| 485 |
if procedure not in self.stats:
|
| 486 |
available = ", ".join(self.procedures)
|
| 487 |
+
raise KeyError(f"Procedure '{procedure}' not in model. Available: {available}")
|
|
|
|
|
|
|
|
|
|
| 488 |
|
| 489 |
proc_stats = self.stats[procedure]
|
| 490 |
field = proc_stats["mean"].copy() * intensity
|
|
|
|
| 500 |
|
| 501 |
return field.astype(np.float32)
|
| 502 |
|
| 503 |
+
def get_summary(self, procedure: str | None = None) -> dict:
|
| 504 |
"""Get a human-readable summary of the model statistics.
|
| 505 |
|
| 506 |
Args:
|
|
|
|
| 529 |
|
| 530 |
return summary
|
| 531 |
|
| 532 |
+
def save(self, path: str | Path) -> None:
|
| 533 |
"""Save the fitted model to disk as a ``.npz`` file.
|
| 534 |
|
| 535 |
The file contains:
|
|
|
|
| 561 |
"n_samples": self.n_samples,
|
| 562 |
"num_landmarks": NUM_LANDMARKS,
|
| 563 |
}
|
| 564 |
+
arrays["__metadata__"] = np.frombuffer(json.dumps(metadata).encode("utf-8"), dtype=np.uint8)
|
|
|
|
|
|
|
| 565 |
|
| 566 |
np.savez_compressed(str(path), **arrays)
|
| 567 |
logger.info("Saved displacement model to %s", path)
|
| 568 |
|
| 569 |
@classmethod
|
| 570 |
+
def load(cls, path: str | Path) -> DisplacementModel:
|
| 571 |
"""Load a fitted model from a ``.npz`` file.
|
| 572 |
|
| 573 |
Supports two formats:
|
|
|
|
| 601 |
model.stats[proc] = {}
|
| 602 |
for key in data.files:
|
| 603 |
if key.startswith(f"{proc}__"):
|
| 604 |
+
stat_name = key[len(f"{proc}__") :]
|
| 605 |
model.stats[proc][stat_name] = data[key]
|
| 606 |
|
| 607 |
# Format 2: extract_displacements.py format with procedures array
|
|
|
|
| 634 |
model.n_samples[proc] = 0
|
| 635 |
|
| 636 |
else:
|
| 637 |
+
raise ValueError(f"Unrecognized displacement model format. Keys: {data.files[:10]}")
|
|
|
|
|
|
|
|
|
|
| 638 |
|
| 639 |
model._fitted = True
|
| 640 |
logger.info(
|
|
|
|
| 650 |
# Utilities
|
| 651 |
# ---------------------------------------------------------------------------
|
| 652 |
|
| 653 |
+
|
| 654 |
def _top_k_landmarks(
|
| 655 |
magnitudes: np.ndarray,
|
| 656 |
k: int = 10,
|
|
|
|
| 666 |
descending by magnitude.
|
| 667 |
"""
|
| 668 |
top_indices = np.argsort(magnitudes)[::-1][:k]
|
| 669 |
+
return [{"index": int(idx), "magnitude": float(magnitudes[idx])} for idx in top_indices]
|
|
|
|
|
|
|
|
|
|
| 670 |
|
| 671 |
|
| 672 |
def visualize_displacements(
|
|
|
|
| 702 |
dy = int(displacements[i, 1] * h * scale)
|
| 703 |
|
| 704 |
# Only draw if displacement is above noise floor
|
| 705 |
+
mag = np.sqrt(dx**2 + dy**2)
|
| 706 |
if mag < 1.0:
|
| 707 |
continue
|
| 708 |
|
landmarkdiff/ensemble.py
CHANGED
|
@@ -21,8 +21,6 @@ Usage:
|
|
| 21 |
|
| 22 |
from __future__ import annotations
|
| 23 |
|
| 24 |
-
from typing import Optional
|
| 25 |
-
|
| 26 |
import cv2
|
| 27 |
import numpy as np
|
| 28 |
|
|
@@ -93,7 +91,7 @@ class EnsembleInference:
|
|
| 93 |
guidance_scale: float = 9.0,
|
| 94 |
controlnet_conditioning_scale: float = 0.9,
|
| 95 |
strength: float = 0.5,
|
| 96 |
-
seed:
|
| 97 |
**kwargs,
|
| 98 |
) -> dict:
|
| 99 |
"""Generate ensemble output.
|
|
@@ -155,14 +153,16 @@ class EnsembleInference:
|
|
| 155 |
# Copy metadata from best result
|
| 156 |
best_idx = selected_idx if selected_idx >= 0 else 0
|
| 157 |
ensemble_result = dict(results[best_idx])
|
| 158 |
-
ensemble_result.update(
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
| 166 |
|
| 167 |
return ensemble_result
|
| 168 |
|
|
@@ -196,7 +196,7 @@ class EnsembleInference:
|
|
| 196 |
|
| 197 |
# Weighted average
|
| 198 |
result = np.zeros_like(outputs[0], dtype=np.float32)
|
| 199 |
-
for output, weight in zip(outputs, weights):
|
| 200 |
result += output.astype(np.float32) * weight
|
| 201 |
|
| 202 |
return np.clip(result, 0, 255).astype(np.uint8), scores
|
|
@@ -269,8 +269,10 @@ def ensemble_inference(
|
|
| 269 |
for i, output in enumerate(result["outputs"]):
|
| 270 |
cv2.imwrite(str(out / f"sample_{i:02d}.png"), output)
|
| 271 |
score = result["scores"][i]
|
| 272 |
-
print(
|
| 273 |
-
|
|
|
|
|
|
|
| 274 |
|
| 275 |
# Comparison grid
|
| 276 |
panels = [image] + result["outputs"] + [result["output"]]
|
|
@@ -281,8 +283,10 @@ def ensemble_inference(
|
|
| 281 |
|
| 282 |
print(f"\nEnsemble output saved: {out / 'ensemble_output.png'}")
|
| 283 |
if result.get("selected_idx", -1) >= 0:
|
| 284 |
-
print(
|
| 285 |
-
|
|
|
|
|
|
|
| 286 |
|
| 287 |
|
| 288 |
if __name__ == "__main__":
|
|
@@ -294,18 +298,26 @@ if __name__ == "__main__":
|
|
| 294 |
parser.add_argument("--intensity", type=float, default=65.0)
|
| 295 |
parser.add_argument("--output", default="ensemble_output")
|
| 296 |
parser.add_argument("--n_samples", type=int, default=5)
|
| 297 |
-
parser.add_argument(
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
|
|
|
|
|
|
| 301 |
parser.add_argument("--checkpoint", default=None)
|
| 302 |
parser.add_argument("--displacement-model", default=None)
|
| 303 |
parser.add_argument("--seed", type=int, default=42)
|
| 304 |
args = parser.parse_args()
|
| 305 |
|
| 306 |
ensemble_inference(
|
| 307 |
-
args.image,
|
| 308 |
-
args.
|
| 309 |
-
args.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
args.seed,
|
| 311 |
)
|
|
|
|
| 21 |
|
| 22 |
from __future__ import annotations
|
| 23 |
|
|
|
|
|
|
|
| 24 |
import cv2
|
| 25 |
import numpy as np
|
| 26 |
|
|
|
|
| 91 |
guidance_scale: float = 9.0,
|
| 92 |
controlnet_conditioning_scale: float = 0.9,
|
| 93 |
strength: float = 0.5,
|
| 94 |
+
seed: int | None = None,
|
| 95 |
**kwargs,
|
| 96 |
) -> dict:
|
| 97 |
"""Generate ensemble output.
|
|
|
|
| 153 |
# Copy metadata from best result
|
| 154 |
best_idx = selected_idx if selected_idx >= 0 else 0
|
| 155 |
ensemble_result = dict(results[best_idx])
|
| 156 |
+
ensemble_result.update(
|
| 157 |
+
{
|
| 158 |
+
"output": final,
|
| 159 |
+
"outputs": outputs,
|
| 160 |
+
"scores": scores,
|
| 161 |
+
"selected_idx": selected_idx,
|
| 162 |
+
"strategy": self.strategy,
|
| 163 |
+
"n_samples": self.n_samples,
|
| 164 |
+
}
|
| 165 |
+
)
|
| 166 |
|
| 167 |
return ensemble_result
|
| 168 |
|
|
|
|
| 196 |
|
| 197 |
# Weighted average
|
| 198 |
result = np.zeros_like(outputs[0], dtype=np.float32)
|
| 199 |
+
for output, weight in zip(outputs, weights, strict=False):
|
| 200 |
result += output.astype(np.float32) * weight
|
| 201 |
|
| 202 |
return np.clip(result, 0, 255).astype(np.uint8), scores
|
|
|
|
| 269 |
for i, output in enumerate(result["outputs"]):
|
| 270 |
cv2.imwrite(str(out / f"sample_{i:02d}.png"), output)
|
| 271 |
score = result["scores"][i]
|
| 272 |
+
print(
|
| 273 |
+
f" Sample {i}: score={score:.4f}"
|
| 274 |
+
+ (" <-- selected" if i == result.get("selected_idx") else "")
|
| 275 |
+
)
|
| 276 |
|
| 277 |
# Comparison grid
|
| 278 |
panels = [image] + result["outputs"] + [result["output"]]
|
|
|
|
| 283 |
|
| 284 |
print(f"\nEnsemble output saved: {out / 'ensemble_output.png'}")
|
| 285 |
if result.get("selected_idx", -1) >= 0:
|
| 286 |
+
print(
|
| 287 |
+
f"Selected sample: {result['selected_idx']} "
|
| 288 |
+
f"(score={result['scores'][result['selected_idx']]:.4f})"
|
| 289 |
+
)
|
| 290 |
|
| 291 |
|
| 292 |
if __name__ == "__main__":
|
|
|
|
| 298 |
parser.add_argument("--intensity", type=float, default=65.0)
|
| 299 |
parser.add_argument("--output", default="ensemble_output")
|
| 300 |
parser.add_argument("--n_samples", type=int, default=5)
|
| 301 |
+
parser.add_argument(
|
| 302 |
+
"--strategy",
|
| 303 |
+
default="best_of_n",
|
| 304 |
+
choices=["pixel_average", "weighted_average", "best_of_n", "median"],
|
| 305 |
+
)
|
| 306 |
+
parser.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
|
| 307 |
parser.add_argument("--checkpoint", default=None)
|
| 308 |
parser.add_argument("--displacement-model", default=None)
|
| 309 |
parser.add_argument("--seed", type=int, default=42)
|
| 310 |
args = parser.parse_args()
|
| 311 |
|
| 312 |
ensemble_inference(
|
| 313 |
+
args.image,
|
| 314 |
+
args.procedure,
|
| 315 |
+
args.intensity,
|
| 316 |
+
args.output,
|
| 317 |
+
args.n_samples,
|
| 318 |
+
args.strategy,
|
| 319 |
+
args.mode,
|
| 320 |
+
args.checkpoint,
|
| 321 |
+
args.displacement_model,
|
| 322 |
args.seed,
|
| 323 |
)
|
landmarkdiff/evaluation.py
CHANGED
|
@@ -8,6 +8,7 @@ Secondary: SSIM (relaxed target >0.80).
|
|
| 8 |
from __future__ import annotations
|
| 9 |
|
| 10 |
from dataclasses import dataclass, field
|
|
|
|
| 11 |
|
| 12 |
import numpy as np
|
| 13 |
|
|
@@ -23,7 +24,7 @@ class EvalMetrics:
|
|
| 23 |
|
| 24 |
fid: float = 0.0
|
| 25 |
lpips: float = 0.0
|
| 26 |
-
nme: float = 0.0
|
| 27 |
identity_sim: float = 0.0 # ArcFace cosine similarity
|
| 28 |
ssim: float = 0.0
|
| 29 |
|
|
@@ -154,9 +155,7 @@ def compute_nme(
|
|
| 154 |
Returns:
|
| 155 |
NME value (lower is better).
|
| 156 |
"""
|
| 157 |
-
iod = np.linalg.norm(
|
| 158 |
-
target_landmarks[left_eye_idx] - target_landmarks[right_eye_idx]
|
| 159 |
-
)
|
| 160 |
if iod < 1.0:
|
| 161 |
iod = 1.0
|
| 162 |
|
|
@@ -175,6 +174,7 @@ def compute_ssim(
|
|
| 175 |
"""
|
| 176 |
try:
|
| 177 |
from skimage.metrics import structural_similarity
|
|
|
|
| 178 |
# Convert to grayscale if color, or compute per-channel
|
| 179 |
if pred.ndim == 3 and pred.shape[2] == 3:
|
| 180 |
return float(structural_similarity(pred, target, channel_axis=2, data_range=255))
|
|
@@ -194,10 +194,8 @@ def compute_ssim(
|
|
| 194 |
C1 = (0.01 * 255) ** 2
|
| 195 |
C2 = (0.03 * 255) ** 2
|
| 196 |
|
| 197 |
-
ssim_val = (
|
| 198 |
-
(2
|
| 199 |
-
) / (
|
| 200 |
-
(mu_p ** 2 + mu_t ** 2 + C1) * (sigma_p ** 2 + sigma_t ** 2 + C2)
|
| 201 |
)
|
| 202 |
return float(ssim_val)
|
| 203 |
|
|
@@ -206,11 +204,12 @@ _LPIPS_FN = None
|
|
| 206 |
_ARCFACE_APP = None
|
| 207 |
|
| 208 |
|
| 209 |
-
def _get_lpips_fn():
|
| 210 |
"""Get or create singleton LPIPS model."""
|
| 211 |
global _LPIPS_FN
|
| 212 |
if _LPIPS_FN is None:
|
| 213 |
import lpips
|
|
|
|
| 214 |
_LPIPS_FN = lpips.LPIPS(net="alex", verbose=False)
|
| 215 |
_LPIPS_FN.eval()
|
| 216 |
return _LPIPS_FN
|
|
@@ -225,7 +224,7 @@ def compute_lpips(
|
|
| 225 |
Returns LPIPS score (lower = more similar).
|
| 226 |
"""
|
| 227 |
try:
|
| 228 |
-
import lpips
|
| 229 |
import torch
|
| 230 |
except ImportError:
|
| 231 |
return float("nan")
|
|
@@ -261,9 +260,10 @@ def compute_fid(
|
|
| 261 |
except ImportError:
|
| 262 |
raise ImportError(
|
| 263 |
"torch-fidelity is required for FID. Install with: pip install torch-fidelity"
|
| 264 |
-
)
|
| 265 |
|
| 266 |
import torch
|
|
|
|
| 267 |
metrics = calculate_metrics(
|
| 268 |
input1=generated_dir,
|
| 269 |
input2=real_dir,
|
|
@@ -285,6 +285,7 @@ def compute_identity_similarity(
|
|
| 285 |
"""
|
| 286 |
try:
|
| 287 |
from insightface.app import FaceAnalysis
|
|
|
|
| 288 |
global _ARCFACE_APP
|
| 289 |
if _ARCFACE_APP is None:
|
| 290 |
_ARCFACE_APP = FaceAnalysis(
|
|
|
|
| 8 |
from __future__ import annotations
|
| 9 |
|
| 10 |
from dataclasses import dataclass, field
|
| 11 |
+
from typing import Any
|
| 12 |
|
| 13 |
import numpy as np
|
| 14 |
|
|
|
|
| 24 |
|
| 25 |
fid: float = 0.0
|
| 26 |
lpips: float = 0.0
|
| 27 |
+
nme: float = 0.0 # Normalized Mean landmark Error
|
| 28 |
identity_sim: float = 0.0 # ArcFace cosine similarity
|
| 29 |
ssim: float = 0.0
|
| 30 |
|
|
|
|
| 155 |
Returns:
|
| 156 |
NME value (lower is better).
|
| 157 |
"""
|
| 158 |
+
iod = np.linalg.norm(target_landmarks[left_eye_idx] - target_landmarks[right_eye_idx])
|
|
|
|
|
|
|
| 159 |
if iod < 1.0:
|
| 160 |
iod = 1.0
|
| 161 |
|
|
|
|
| 174 |
"""
|
| 175 |
try:
|
| 176 |
from skimage.metrics import structural_similarity
|
| 177 |
+
|
| 178 |
# Convert to grayscale if color, or compute per-channel
|
| 179 |
if pred.ndim == 3 and pred.shape[2] == 3:
|
| 180 |
return float(structural_similarity(pred, target, channel_axis=2, data_range=255))
|
|
|
|
| 194 |
C1 = (0.01 * 255) ** 2
|
| 195 |
C2 = (0.03 * 255) ** 2
|
| 196 |
|
| 197 |
+
ssim_val = ((2 * mu_p * mu_t + C1) * (2 * sigma_pt + C2)) / (
|
| 198 |
+
(mu_p**2 + mu_t**2 + C1) * (sigma_p**2 + sigma_t**2 + C2)
|
|
|
|
|
|
|
| 199 |
)
|
| 200 |
return float(ssim_val)
|
| 201 |
|
|
|
|
| 204 |
_ARCFACE_APP = None
|
| 205 |
|
| 206 |
|
| 207 |
+
def _get_lpips_fn() -> Any:
|
| 208 |
"""Get or create singleton LPIPS model."""
|
| 209 |
global _LPIPS_FN
|
| 210 |
if _LPIPS_FN is None:
|
| 211 |
import lpips
|
| 212 |
+
|
| 213 |
_LPIPS_FN = lpips.LPIPS(net="alex", verbose=False)
|
| 214 |
_LPIPS_FN.eval()
|
| 215 |
return _LPIPS_FN
|
|
|
|
| 224 |
Returns LPIPS score (lower = more similar).
|
| 225 |
"""
|
| 226 |
try:
|
| 227 |
+
import lpips # noqa: F401
|
| 228 |
import torch
|
| 229 |
except ImportError:
|
| 230 |
return float("nan")
|
|
|
|
| 260 |
except ImportError:
|
| 261 |
raise ImportError(
|
| 262 |
"torch-fidelity is required for FID. Install with: pip install torch-fidelity"
|
| 263 |
+
) from None
|
| 264 |
|
| 265 |
import torch
|
| 266 |
+
|
| 267 |
metrics = calculate_metrics(
|
| 268 |
input1=generated_dir,
|
| 269 |
input2=real_dir,
|
|
|
|
| 285 |
"""
|
| 286 |
try:
|
| 287 |
from insightface.app import FaceAnalysis
|
| 288 |
+
|
| 289 |
global _ARCFACE_APP
|
| 290 |
if _ARCFACE_APP is None:
|
| 291 |
_ARCFACE_APP = FaceAnalysis(
|
landmarkdiff/experiment_tracker.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Local experiment tracker for training reproducibility.
|
| 2 |
+
|
| 3 |
+
Tracks all training runs with their configs, metrics, and results.
|
| 4 |
+
Each experiment gets a unique ID and timestamp.
|
| 5 |
+
|
| 6 |
+
Usage::
|
| 7 |
+
|
| 8 |
+
tracker = ExperimentTracker("experiments/")
|
| 9 |
+
|
| 10 |
+
# Start a new experiment
|
| 11 |
+
exp_id = tracker.start(
|
| 12 |
+
name="phaseA_v2",
|
| 13 |
+
config={
|
| 14 |
+
"phase": "A", "lr": 1e-5, "batch": 4,
|
| 15 |
+
"steps": 100000, "data": "training_combined",
|
| 16 |
+
},
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
# Log metrics during training
|
| 20 |
+
tracker.log_metric(exp_id, step=1000, loss=0.045, ssim=0.82)
|
| 21 |
+
|
| 22 |
+
# Record final results
|
| 23 |
+
tracker.finish(exp_id, results={"fid": 42.3, "ssim": 0.87})
|
| 24 |
+
|
| 25 |
+
# List all experiments
|
| 26 |
+
tracker.list_experiments()
|
| 27 |
+
|
| 28 |
+
# Compare experiments
|
| 29 |
+
tracker.compare(["exp_001", "exp_002"])
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
import json
|
| 35 |
+
import os
|
| 36 |
+
import socket
|
| 37 |
+
import time
|
| 38 |
+
from datetime import datetime
|
| 39 |
+
from pathlib import Path
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ExperimentTracker:
|
| 43 |
+
"""Simple file-based experiment tracker."""
|
| 44 |
+
|
| 45 |
+
def __init__(self, experiments_dir: str = "experiments"):
|
| 46 |
+
self.dir = Path(experiments_dir)
|
| 47 |
+
self.dir.mkdir(parents=True, exist_ok=True)
|
| 48 |
+
self._index_path = self.dir / "index.json"
|
| 49 |
+
self._index = self._load_index()
|
| 50 |
+
|
| 51 |
+
def _load_index(self) -> dict:
|
| 52 |
+
if self._index_path.exists():
|
| 53 |
+
with open(self._index_path) as f:
|
| 54 |
+
return json.load(f)
|
| 55 |
+
return {"experiments": {}, "counter": 0}
|
| 56 |
+
|
| 57 |
+
def _save_index(self) -> None:
|
| 58 |
+
with open(self._index_path, "w") as f:
|
| 59 |
+
json.dump(self._index, f, indent=2)
|
| 60 |
+
|
| 61 |
+
def start(
|
| 62 |
+
self,
|
| 63 |
+
name: str,
|
| 64 |
+
config: dict,
|
| 65 |
+
tags: list[str] | None = None,
|
| 66 |
+
) -> str:
|
| 67 |
+
"""Start a new experiment. Returns experiment ID."""
|
| 68 |
+
self._index["counter"] += 1
|
| 69 |
+
exp_id = f"exp_{self._index['counter']:03d}"
|
| 70 |
+
|
| 71 |
+
exp = {
|
| 72 |
+
"id": exp_id,
|
| 73 |
+
"name": name,
|
| 74 |
+
"config": config,
|
| 75 |
+
"tags": tags or [],
|
| 76 |
+
"status": "running",
|
| 77 |
+
"started_at": datetime.now().isoformat(),
|
| 78 |
+
"finished_at": None,
|
| 79 |
+
"hostname": socket.gethostname(),
|
| 80 |
+
"slurm_job_id": os.environ.get("SLURM_JOB_ID"),
|
| 81 |
+
"gpu": os.environ.get("CUDA_VISIBLE_DEVICES"),
|
| 82 |
+
"results": {},
|
| 83 |
+
"metrics_file": f"{exp_id}_metrics.jsonl",
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
self._index["experiments"][exp_id] = exp
|
| 87 |
+
self._save_index()
|
| 88 |
+
|
| 89 |
+
# Create metrics log file
|
| 90 |
+
metrics_path = self.dir / str(exp["metrics_file"])
|
| 91 |
+
metrics_path.touch()
|
| 92 |
+
|
| 93 |
+
print(f"Experiment started: {exp_id} ({name})")
|
| 94 |
+
return exp_id
|
| 95 |
+
|
| 96 |
+
def log_metric(self, exp_id: str, step: int | None = None, **metrics) -> None:
|
| 97 |
+
"""Log metrics for a training step."""
|
| 98 |
+
exp = self._index["experiments"].get(exp_id)
|
| 99 |
+
if not exp:
|
| 100 |
+
return
|
| 101 |
+
|
| 102 |
+
entry = {
|
| 103 |
+
"timestamp": time.time(),
|
| 104 |
+
"step": step,
|
| 105 |
+
**metrics,
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
metrics_path = self.dir / str(exp["metrics_file"])
|
| 109 |
+
with open(metrics_path, "a") as f:
|
| 110 |
+
f.write(json.dumps(entry) + "\n")
|
| 111 |
+
|
| 112 |
+
def finish(
|
| 113 |
+
self,
|
| 114 |
+
exp_id: str,
|
| 115 |
+
results: dict | None = None,
|
| 116 |
+
status: str = "completed",
|
| 117 |
+
) -> None:
|
| 118 |
+
"""Mark experiment as finished."""
|
| 119 |
+
exp = self._index["experiments"].get(exp_id)
|
| 120 |
+
if not exp:
|
| 121 |
+
return
|
| 122 |
+
|
| 123 |
+
exp["status"] = status
|
| 124 |
+
exp["finished_at"] = datetime.now().isoformat()
|
| 125 |
+
if results:
|
| 126 |
+
exp["results"] = results
|
| 127 |
+
|
| 128 |
+
self._save_index()
|
| 129 |
+
print(f"Experiment {exp_id} {status}")
|
| 130 |
+
|
| 131 |
+
def get_metrics(self, exp_id: str) -> list[dict]:
|
| 132 |
+
"""Load all logged metrics for an experiment."""
|
| 133 |
+
exp = self._index["experiments"].get(exp_id)
|
| 134 |
+
if not exp:
|
| 135 |
+
return []
|
| 136 |
+
|
| 137 |
+
metrics_path = self.dir / str(exp["metrics_file"])
|
| 138 |
+
if not metrics_path.exists():
|
| 139 |
+
return []
|
| 140 |
+
|
| 141 |
+
entries = []
|
| 142 |
+
with open(metrics_path) as f:
|
| 143 |
+
for line in f:
|
| 144 |
+
line = line.strip()
|
| 145 |
+
if line:
|
| 146 |
+
entries.append(json.loads(line))
|
| 147 |
+
return entries
|
| 148 |
+
|
| 149 |
+
def list_experiments(self) -> list[dict]:
|
| 150 |
+
"""List all experiments with summary info."""
|
| 151 |
+
experiments = []
|
| 152 |
+
for exp_id, exp in sorted(self._index["experiments"].items()):
|
| 153 |
+
summary = {
|
| 154 |
+
"id": exp_id,
|
| 155 |
+
"name": exp["name"],
|
| 156 |
+
"status": exp["status"],
|
| 157 |
+
"started": exp["started_at"][:19],
|
| 158 |
+
"tags": exp.get("tags", []),
|
| 159 |
+
}
|
| 160 |
+
if exp["results"]:
|
| 161 |
+
for key in ["fid", "ssim", "lpips", "nme"]:
|
| 162 |
+
if key in exp["results"]:
|
| 163 |
+
summary[key] = exp["results"][key]
|
| 164 |
+
experiments.append(summary)
|
| 165 |
+
return experiments
|
| 166 |
+
|
| 167 |
+
def compare(self, exp_ids: list[str]) -> dict:
|
| 168 |
+
"""Compare multiple experiments by their results."""
|
| 169 |
+
comparison = {}
|
| 170 |
+
for exp_id in exp_ids:
|
| 171 |
+
exp = self._index["experiments"].get(exp_id)
|
| 172 |
+
if exp:
|
| 173 |
+
comparison[exp_id] = {
|
| 174 |
+
"name": exp["name"],
|
| 175 |
+
"config": exp["config"],
|
| 176 |
+
"results": exp["results"],
|
| 177 |
+
}
|
| 178 |
+
return comparison
|
| 179 |
+
|
| 180 |
+
def print_summary(self) -> None:
|
| 181 |
+
"""Print a summary table of all experiments."""
|
| 182 |
+
experiments = self.list_experiments()
|
| 183 |
+
if not experiments:
|
| 184 |
+
print("No experiments found.")
|
| 185 |
+
return
|
| 186 |
+
|
| 187 |
+
# Header
|
| 188 |
+
print(f"{'ID':<10} {'Name':<20} {'Status':<12} {'FID':>6} {'SSIM':>6} {'LPIPS':>6}")
|
| 189 |
+
print("-" * 70)
|
| 190 |
+
|
| 191 |
+
for exp in experiments:
|
| 192 |
+
fid = f"{exp.get('fid', '')}" if "fid" in exp else "--"
|
| 193 |
+
ssim = f"{exp.get('ssim', ''):.4f}" if "ssim" in exp else "--"
|
| 194 |
+
lpips = f"{exp.get('lpips', ''):.4f}" if "lpips" in exp else "--"
|
| 195 |
+
print(
|
| 196 |
+
f"{exp['id']:<10} {exp['name']:<20}"
|
| 197 |
+
f" {exp['status']:<12} {fid:>6} {ssim:>6} {lpips:>6}"
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
def get_best(self, metric: str = "fid", lower_is_better: bool = True) -> str | None:
|
| 201 |
+
"""Get the experiment ID with the best value for a given metric."""
|
| 202 |
+
best_id = None
|
| 203 |
+
best_val = float("inf") if lower_is_better else float("-inf")
|
| 204 |
+
|
| 205 |
+
for exp_id, exp in self._index["experiments"].items():
|
| 206 |
+
if exp["status"] != "completed":
|
| 207 |
+
continue
|
| 208 |
+
val = exp["results"].get(metric)
|
| 209 |
+
if val is None:
|
| 210 |
+
continue
|
| 211 |
+
if (lower_is_better and val < best_val) or (not lower_is_better and val > best_val):
|
| 212 |
+
best_val = val
|
| 213 |
+
best_id = exp_id
|
| 214 |
+
|
| 215 |
+
return best_id
|
landmarkdiff/face_verifier.py
CHANGED
|
@@ -15,19 +15,18 @@ Designed for:
|
|
| 15 |
|
| 16 |
from __future__ import annotations
|
| 17 |
|
| 18 |
-
import os
|
| 19 |
from dataclasses import dataclass, field
|
| 20 |
from pathlib import Path
|
| 21 |
-
from typing import
|
| 22 |
|
| 23 |
import cv2
|
| 24 |
import numpy as np
|
| 25 |
|
| 26 |
-
|
| 27 |
# ---------------------------------------------------------------------------
|
| 28 |
# Data structures
|
| 29 |
# ---------------------------------------------------------------------------
|
| 30 |
|
|
|
|
| 31 |
@dataclass
|
| 32 |
class DistortionReport:
|
| 33 |
"""Analysis of detected distortions in a face image."""
|
|
@@ -36,13 +35,13 @@ class DistortionReport:
|
|
| 36 |
quality_score: float = 0.0
|
| 37 |
|
| 38 |
# Individual distortion scores (0-1, higher = more distorted)
|
| 39 |
-
blur_score: float = 0.0
|
| 40 |
-
noise_score: float = 0.0
|
| 41 |
-
compression_score: float = 0.0
|
| 42 |
-
oversmooth_score: float = 0.0
|
| 43 |
-
color_cast_score: float = 0.0
|
| 44 |
-
geometric_distort: float = 0.0
|
| 45 |
-
lighting_score: float = 0.0
|
| 46 |
|
| 47 |
# Classification
|
| 48 |
primary_distortion: str = "none"
|
|
@@ -74,14 +73,14 @@ class DistortionReport:
|
|
| 74 |
class RestorationResult:
|
| 75 |
"""Result of neural face restoration pipeline."""
|
| 76 |
|
| 77 |
-
restored: np.ndarray
|
| 78 |
-
original: np.ndarray
|
| 79 |
-
distortion_report: DistortionReport
|
| 80 |
-
post_quality_score: float = 0.0
|
| 81 |
-
identity_similarity: float = 0.0
|
| 82 |
-
identity_preserved: bool = True
|
| 83 |
restoration_stages: list[str] = field(default_factory=list) # Which nets ran
|
| 84 |
-
improvement: float = 0.0
|
| 85 |
|
| 86 |
def summary(self) -> str:
|
| 87 |
lines = [
|
|
@@ -100,9 +99,9 @@ class BatchVerificationReport:
|
|
| 100 |
"""Summary of batch face verification/restoration."""
|
| 101 |
|
| 102 |
total: int = 0
|
| 103 |
-
passed: int = 0
|
| 104 |
-
restored: int = 0
|
| 105 |
-
rejected: int = 0
|
| 106 |
identity_failures: int = 0 # Restoration changed identity
|
| 107 |
avg_quality_before: float = 0.0
|
| 108 |
avg_quality_after: float = 0.0
|
|
@@ -123,7 +122,8 @@ class BatchVerificationReport:
|
|
| 123 |
"Distortion Breakdown:",
|
| 124 |
]
|
| 125 |
for dist_type, count in sorted(
|
| 126 |
-
self.distortion_counts.items(),
|
|
|
|
| 127 |
):
|
| 128 |
lines.append(f" {dist_type}: {count}")
|
| 129 |
return "\n".join(lines)
|
|
@@ -133,6 +133,7 @@ class BatchVerificationReport:
|
|
| 133 |
# Distortion Detection (classical + neural)
|
| 134 |
# ---------------------------------------------------------------------------
|
| 135 |
|
|
|
|
| 136 |
def detect_blur(image: np.ndarray) -> float:
|
| 137 |
"""Detect blur using Laplacian variance.
|
| 138 |
|
|
@@ -147,7 +148,7 @@ def detect_blur(image: np.ndarray) -> float:
|
|
| 147 |
# Gradient magnitude (secondary)
|
| 148 |
gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
|
| 149 |
gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
|
| 150 |
-
grad_mag = np.sqrt(gx
|
| 151 |
|
| 152 |
# Normalize: typical sharp face has lap_var > 500, grad_mag > 30
|
| 153 |
blur_lap = 1.0 - min(lap_var / 800.0, 1.0)
|
|
@@ -221,7 +222,7 @@ def detect_oversmoothing(image: np.ndarray) -> float:
|
|
| 221 |
# Focus on face center region (avoid background)
|
| 222 |
if h < 8 or w < 8:
|
| 223 |
return 0.0 # Too small to analyze
|
| 224 |
-
roi = gray[h // 4:3 * h // 4, w // 4:3 * w // 4]
|
| 225 |
|
| 226 |
# Texture energy: variance of high-pass filtered image
|
| 227 |
blurred = cv2.GaussianBlur(roi.astype(np.float64), (0, 0), 2.0)
|
|
@@ -254,7 +255,7 @@ def detect_color_cast(image: np.ndarray) -> float:
|
|
| 254 |
h, w = image.shape[:2]
|
| 255 |
|
| 256 |
# Sample face center region
|
| 257 |
-
roi = lab[h // 4:3 * h // 4, w // 4:3 * w // 4]
|
| 258 |
|
| 259 |
# A channel: green-red axis (neutral ~128)
|
| 260 |
# B channel: blue-yellow axis (neutral ~128)
|
|
@@ -337,7 +338,7 @@ def detect_lighting_issues(image: np.ndarray) -> float:
|
|
| 337 |
|
| 338 |
# Check for clipping
|
| 339 |
overexposed = np.mean(l_channel > 245) * 5 # Fraction near white
|
| 340 |
-
underexposed = np.mean(l_channel < 10) * 5
|
| 341 |
|
| 342 |
# Check for bimodal distribution (harsh shadows)
|
| 343 |
hist = cv2.calcHist([l_channel], [0], None, [256], [0, 256]).flatten()
|
|
@@ -429,7 +430,7 @@ def analyze_distortions(image: np.ndarray) -> DistortionReport:
|
|
| 429 |
_FACE_QUALITY_NET = None
|
| 430 |
|
| 431 |
|
| 432 |
-
def _get_face_quality_scorer():
|
| 433 |
"""Get or create singleton face quality assessment model.
|
| 434 |
|
| 435 |
Uses FaceXLib's quality scorer or falls back to BRISQUE-style features.
|
|
@@ -440,6 +441,7 @@ def _get_face_quality_scorer():
|
|
| 440 |
|
| 441 |
try:
|
| 442 |
from facexlib.assessment import init_assessment_model
|
|
|
|
| 443 |
_FACE_QUALITY_NET = init_assessment_model("hypernet")
|
| 444 |
return _FACE_QUALITY_NET
|
| 445 |
except Exception:
|
|
@@ -461,6 +463,7 @@ def neural_quality_score(image: np.ndarray) -> float:
|
|
| 461 |
try:
|
| 462 |
import torch
|
| 463 |
from facexlib.utils import img2tensor
|
|
|
|
| 464 |
img_t = img2tensor(image / 255.0, bgr2rgb=True, float32=True)
|
| 465 |
img_t = img_t.unsqueeze(0)
|
| 466 |
if torch.cuda.is_available():
|
|
@@ -481,6 +484,7 @@ def neural_quality_score(image: np.ndarray) -> float:
|
|
| 481 |
# Neural Face Restoration (cascaded)
|
| 482 |
# ---------------------------------------------------------------------------
|
| 483 |
|
|
|
|
| 484 |
def restore_face(
|
| 485 |
image: np.ndarray,
|
| 486 |
distortion: DistortionReport | None = None,
|
|
@@ -563,6 +567,7 @@ def restore_face(
|
|
| 563 |
post_blur = detect_blur(result)
|
| 564 |
if post_blur > 0.3:
|
| 565 |
from landmarkdiff.postprocess import frequency_aware_sharpen
|
|
|
|
| 566 |
result = frequency_aware_sharpen(result, strength=0.3)
|
| 567 |
stages.append("sharpen")
|
| 568 |
|
|
@@ -573,6 +578,7 @@ def _try_codeformer(image: np.ndarray, fidelity: float = 0.7) -> np.ndarray | No
|
|
| 573 |
"""Try CodeFormer restoration. Returns None if unavailable."""
|
| 574 |
try:
|
| 575 |
from landmarkdiff.postprocess import restore_face_codeformer
|
|
|
|
| 576 |
restored = restore_face_codeformer(image, fidelity=fidelity)
|
| 577 |
if restored is not image:
|
| 578 |
return restored
|
|
@@ -585,6 +591,7 @@ def _try_gfpgan(image: np.ndarray) -> np.ndarray | None:
|
|
| 585 |
"""Try GFPGAN restoration. Returns None if unavailable."""
|
| 586 |
try:
|
| 587 |
from landmarkdiff.postprocess import restore_face_gfpgan
|
|
|
|
| 588 |
restored = restore_face_gfpgan(image)
|
| 589 |
if restored is not image:
|
| 590 |
return restored
|
|
@@ -599,15 +606,19 @@ _FV_REALESRGAN = None
|
|
| 599 |
def _try_realesrgan(image: np.ndarray) -> np.ndarray | None:
|
| 600 |
"""Try Real-ESRGAN 2x upscale + downsample. Returns None if unavailable."""
|
| 601 |
try:
|
| 602 |
-
from realesrgan import RealESRGANer
|
| 603 |
-
from basicsr.archs.rrdbnet_arch import RRDBNet
|
| 604 |
import torch
|
|
|
|
|
|
|
| 605 |
|
| 606 |
global _FV_REALESRGAN
|
| 607 |
if _FV_REALESRGAN is None:
|
| 608 |
model = RRDBNet(
|
| 609 |
-
num_in_ch=3,
|
| 610 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 611 |
)
|
| 612 |
_FV_REALESRGAN = RealESRGANer(
|
| 613 |
scale=4,
|
|
@@ -661,15 +672,15 @@ def _fix_lighting(image: np.ndarray) -> np.ndarray:
|
|
| 661 |
_ARCFACE_APP = None
|
| 662 |
|
| 663 |
|
| 664 |
-
def _get_arcface():
|
| 665 |
"""Get or create singleton ArcFace model."""
|
| 666 |
global _ARCFACE_APP
|
| 667 |
if _ARCFACE_APP is not None:
|
| 668 |
return _ARCFACE_APP
|
| 669 |
|
| 670 |
try:
|
| 671 |
-
from insightface.app import FaceAnalysis
|
| 672 |
import torch
|
|
|
|
| 673 |
|
| 674 |
app = FaceAnalysis(
|
| 675 |
name="buffalo_l",
|
|
@@ -717,9 +728,9 @@ def verify_identity(
|
|
| 717 |
if emb_orig is None or emb_rest is None:
|
| 718 |
return -1.0, True # Can't verify — assume OK
|
| 719 |
|
| 720 |
-
sim = float(
|
| 721 |
-
np.linalg.norm(emb_orig) * np.linalg.norm(emb_rest) + 1e-8
|
| 722 |
-
)
|
| 723 |
sim = float(np.clip(sim, -1, 1))
|
| 724 |
return sim, sim >= threshold
|
| 725 |
|
|
@@ -728,6 +739,7 @@ def verify_identity(
|
|
| 728 |
# Full Verification + Restoration Pipeline
|
| 729 |
# ---------------------------------------------------------------------------
|
| 730 |
|
|
|
|
| 731 |
def verify_and_restore(
|
| 732 |
image: np.ndarray,
|
| 733 |
quality_threshold: float = 60.0,
|
|
@@ -813,6 +825,7 @@ def verify_and_restore(
|
|
| 813 |
# Batch Processing
|
| 814 |
# ---------------------------------------------------------------------------
|
| 815 |
|
|
|
|
| 816 |
def verify_batch(
|
| 817 |
image_dir: str,
|
| 818 |
output_dir: str | None = None,
|
|
@@ -858,10 +871,9 @@ def verify_batch(
|
|
| 858 |
rejected_dir.mkdir(parents=True, exist_ok=True)
|
| 859 |
|
| 860 |
# Find all images
|
| 861 |
-
image_files = sorted(
|
| 862 |
-
f for f in image_path.iterdir()
|
| 863 |
-
|
| 864 |
-
])
|
| 865 |
|
| 866 |
report = BatchVerificationReport(total=len(image_files))
|
| 867 |
quality_before = []
|
|
|
|
| 15 |
|
| 16 |
from __future__ import annotations
|
| 17 |
|
|
|
|
| 18 |
from dataclasses import dataclass, field
|
| 19 |
from pathlib import Path
|
| 20 |
+
from typing import Any
|
| 21 |
|
| 22 |
import cv2
|
| 23 |
import numpy as np
|
| 24 |
|
|
|
|
| 25 |
# ---------------------------------------------------------------------------
|
| 26 |
# Data structures
|
| 27 |
# ---------------------------------------------------------------------------
|
| 28 |
|
| 29 |
+
|
| 30 |
@dataclass
|
| 31 |
class DistortionReport:
|
| 32 |
"""Analysis of detected distortions in a face image."""
|
|
|
|
| 35 |
quality_score: float = 0.0
|
| 36 |
|
| 37 |
# Individual distortion scores (0-1, higher = more distorted)
|
| 38 |
+
blur_score: float = 0.0 # Laplacian variance-based
|
| 39 |
+
noise_score: float = 0.0 # High-freq energy ratio
|
| 40 |
+
compression_score: float = 0.0 # JPEG block artifact detection
|
| 41 |
+
oversmooth_score: float = 0.0 # Beauty filter / airbrushed detection
|
| 42 |
+
color_cast_score: float = 0.0 # Unnatural color shift
|
| 43 |
+
geometric_distort: float = 0.0 # Face proportion anomalies
|
| 44 |
+
lighting_score: float = 0.0 # Over/under exposure
|
| 45 |
|
| 46 |
# Classification
|
| 47 |
primary_distortion: str = "none"
|
|
|
|
| 73 |
class RestorationResult:
|
| 74 |
"""Result of neural face restoration pipeline."""
|
| 75 |
|
| 76 |
+
restored: np.ndarray # Restored BGR image
|
| 77 |
+
original: np.ndarray # Original BGR image
|
| 78 |
+
distortion_report: DistortionReport # Pre-restoration analysis
|
| 79 |
+
post_quality_score: float = 0.0 # Quality after restoration
|
| 80 |
+
identity_similarity: float = 0.0 # ArcFace cosine sim (original vs restored)
|
| 81 |
+
identity_preserved: bool = True # Whether identity check passed
|
| 82 |
restoration_stages: list[str] = field(default_factory=list) # Which nets ran
|
| 83 |
+
improvement: float = 0.0 # quality_after - quality_before
|
| 84 |
|
| 85 |
def summary(self) -> str:
|
| 86 |
lines = [
|
|
|
|
| 99 |
"""Summary of batch face verification/restoration."""
|
| 100 |
|
| 101 |
total: int = 0
|
| 102 |
+
passed: int = 0 # Good quality, no fix needed
|
| 103 |
+
restored: int = 0 # Fixed and now usable
|
| 104 |
+
rejected: int = 0 # Too distorted to salvage
|
| 105 |
identity_failures: int = 0 # Restoration changed identity
|
| 106 |
avg_quality_before: float = 0.0
|
| 107 |
avg_quality_after: float = 0.0
|
|
|
|
| 122 |
"Distortion Breakdown:",
|
| 123 |
]
|
| 124 |
for dist_type, count in sorted(
|
| 125 |
+
self.distortion_counts.items(),
|
| 126 |
+
key=lambda x: -x[1],
|
| 127 |
):
|
| 128 |
lines.append(f" {dist_type}: {count}")
|
| 129 |
return "\n".join(lines)
|
|
|
|
| 133 |
# Distortion Detection (classical + neural)
|
| 134 |
# ---------------------------------------------------------------------------
|
| 135 |
|
| 136 |
+
|
| 137 |
def detect_blur(image: np.ndarray) -> float:
|
| 138 |
"""Detect blur using Laplacian variance.
|
| 139 |
|
|
|
|
| 148 |
# Gradient magnitude (secondary)
|
| 149 |
gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
|
| 150 |
gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
|
| 151 |
+
grad_mag = np.sqrt(gx**2 + gy**2).mean()
|
| 152 |
|
| 153 |
# Normalize: typical sharp face has lap_var > 500, grad_mag > 30
|
| 154 |
blur_lap = 1.0 - min(lap_var / 800.0, 1.0)
|
|
|
|
| 222 |
# Focus on face center region (avoid background)
|
| 223 |
if h < 8 or w < 8:
|
| 224 |
return 0.0 # Too small to analyze
|
| 225 |
+
roi = gray[h // 4 : 3 * h // 4, w // 4 : 3 * w // 4]
|
| 226 |
|
| 227 |
# Texture energy: variance of high-pass filtered image
|
| 228 |
blurred = cv2.GaussianBlur(roi.astype(np.float64), (0, 0), 2.0)
|
|
|
|
| 255 |
h, w = image.shape[:2]
|
| 256 |
|
| 257 |
# Sample face center region
|
| 258 |
+
roi = lab[h // 4 : 3 * h // 4, w // 4 : 3 * w // 4]
|
| 259 |
|
| 260 |
# A channel: green-red axis (neutral ~128)
|
| 261 |
# B channel: blue-yellow axis (neutral ~128)
|
|
|
|
| 338 |
|
| 339 |
# Check for clipping
|
| 340 |
overexposed = np.mean(l_channel > 245) * 5 # Fraction near white
|
| 341 |
+
underexposed = np.mean(l_channel < 10) * 5 # Fraction near black
|
| 342 |
|
| 343 |
# Check for bimodal distribution (harsh shadows)
|
| 344 |
hist = cv2.calcHist([l_channel], [0], None, [256], [0, 256]).flatten()
|
|
|
|
| 430 |
_FACE_QUALITY_NET = None
|
| 431 |
|
| 432 |
|
| 433 |
+
def _get_face_quality_scorer() -> Any:
|
| 434 |
"""Get or create singleton face quality assessment model.
|
| 435 |
|
| 436 |
Uses FaceXLib's quality scorer or falls back to BRISQUE-style features.
|
|
|
|
| 441 |
|
| 442 |
try:
|
| 443 |
from facexlib.assessment import init_assessment_model
|
| 444 |
+
|
| 445 |
_FACE_QUALITY_NET = init_assessment_model("hypernet")
|
| 446 |
return _FACE_QUALITY_NET
|
| 447 |
except Exception:
|
|
|
|
| 463 |
try:
|
| 464 |
import torch
|
| 465 |
from facexlib.utils import img2tensor
|
| 466 |
+
|
| 467 |
img_t = img2tensor(image / 255.0, bgr2rgb=True, float32=True)
|
| 468 |
img_t = img_t.unsqueeze(0)
|
| 469 |
if torch.cuda.is_available():
|
|
|
|
| 484 |
# Neural Face Restoration (cascaded)
|
| 485 |
# ---------------------------------------------------------------------------
|
| 486 |
|
| 487 |
+
|
| 488 |
def restore_face(
|
| 489 |
image: np.ndarray,
|
| 490 |
distortion: DistortionReport | None = None,
|
|
|
|
| 567 |
post_blur = detect_blur(result)
|
| 568 |
if post_blur > 0.3:
|
| 569 |
from landmarkdiff.postprocess import frequency_aware_sharpen
|
| 570 |
+
|
| 571 |
result = frequency_aware_sharpen(result, strength=0.3)
|
| 572 |
stages.append("sharpen")
|
| 573 |
|
|
|
|
| 578 |
"""Try CodeFormer restoration. Returns None if unavailable."""
|
| 579 |
try:
|
| 580 |
from landmarkdiff.postprocess import restore_face_codeformer
|
| 581 |
+
|
| 582 |
restored = restore_face_codeformer(image, fidelity=fidelity)
|
| 583 |
if restored is not image:
|
| 584 |
return restored
|
|
|
|
| 591 |
"""Try GFPGAN restoration. Returns None if unavailable."""
|
| 592 |
try:
|
| 593 |
from landmarkdiff.postprocess import restore_face_gfpgan
|
| 594 |
+
|
| 595 |
restored = restore_face_gfpgan(image)
|
| 596 |
if restored is not image:
|
| 597 |
return restored
|
|
|
|
| 606 |
def _try_realesrgan(image: np.ndarray) -> np.ndarray | None:
|
| 607 |
"""Try Real-ESRGAN 2x upscale + downsample. Returns None if unavailable."""
|
| 608 |
try:
|
|
|
|
|
|
|
| 609 |
import torch
|
| 610 |
+
from basicsr.archs.rrdbnet_arch import RRDBNet
|
| 611 |
+
from realesrgan import RealESRGANer
|
| 612 |
|
| 613 |
global _FV_REALESRGAN
|
| 614 |
if _FV_REALESRGAN is None:
|
| 615 |
model = RRDBNet(
|
| 616 |
+
num_in_ch=3,
|
| 617 |
+
num_out_ch=3,
|
| 618 |
+
num_feat=64,
|
| 619 |
+
num_block=23,
|
| 620 |
+
num_grow_ch=32,
|
| 621 |
+
scale=4,
|
| 622 |
)
|
| 623 |
_FV_REALESRGAN = RealESRGANer(
|
| 624 |
scale=4,
|
|
|
|
| 672 |
_ARCFACE_APP = None
|
| 673 |
|
| 674 |
|
| 675 |
+
def _get_arcface() -> Any:
|
| 676 |
"""Get or create singleton ArcFace model."""
|
| 677 |
global _ARCFACE_APP
|
| 678 |
if _ARCFACE_APP is not None:
|
| 679 |
return _ARCFACE_APP
|
| 680 |
|
| 681 |
try:
|
|
|
|
| 682 |
import torch
|
| 683 |
+
from insightface.app import FaceAnalysis
|
| 684 |
|
| 685 |
app = FaceAnalysis(
|
| 686 |
name="buffalo_l",
|
|
|
|
| 728 |
if emb_orig is None or emb_rest is None:
|
| 729 |
return -1.0, True # Can't verify — assume OK
|
| 730 |
|
| 731 |
+
sim = float(
|
| 732 |
+
np.dot(emb_orig, emb_rest) / (np.linalg.norm(emb_orig) * np.linalg.norm(emb_rest) + 1e-8)
|
| 733 |
+
)
|
| 734 |
sim = float(np.clip(sim, -1, 1))
|
| 735 |
return sim, sim >= threshold
|
| 736 |
|
|
|
|
| 739 |
# Full Verification + Restoration Pipeline
|
| 740 |
# ---------------------------------------------------------------------------
|
| 741 |
|
| 742 |
+
|
| 743 |
def verify_and_restore(
|
| 744 |
image: np.ndarray,
|
| 745 |
quality_threshold: float = 60.0,
|
|
|
|
| 825 |
# Batch Processing
|
| 826 |
# ---------------------------------------------------------------------------
|
| 827 |
|
| 828 |
+
|
| 829 |
def verify_batch(
|
| 830 |
image_dir: str,
|
| 831 |
output_dir: str | None = None,
|
|
|
|
| 871 |
rejected_dir.mkdir(parents=True, exist_ok=True)
|
| 872 |
|
| 873 |
# Find all images
|
| 874 |
+
image_files = sorted(
|
| 875 |
+
[f for f in image_path.iterdir() if f.suffix.lower() in extensions and f.is_file()]
|
| 876 |
+
)
|
|
|
|
| 877 |
|
| 878 |
report = BatchVerificationReport(total=len(image_files))
|
| 879 |
quality_before = []
|
landmarkdiff/fid.py
CHANGED
|
@@ -16,6 +16,7 @@ Usage:
|
|
| 16 |
from __future__ import annotations
|
| 17 |
|
| 18 |
from pathlib import Path
|
|
|
|
| 19 |
|
| 20 |
import numpy as np
|
| 21 |
|
|
@@ -23,14 +24,15 @@ try:
|
|
| 23 |
import torch
|
| 24 |
import torch.nn as nn
|
| 25 |
from torch.utils.data import DataLoader, Dataset
|
|
|
|
| 26 |
HAS_TORCH = True
|
| 27 |
except ImportError:
|
| 28 |
HAS_TORCH = False
|
| 29 |
|
| 30 |
|
| 31 |
-
def _load_inception_v3():
|
| 32 |
"""Load InceptionV3 with pool3 features (2048-dim)."""
|
| 33 |
-
from torchvision.models import
|
| 34 |
|
| 35 |
model = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
|
| 36 |
# We want features from the avg pool layer (2048-dim)
|
|
@@ -40,79 +42,81 @@ def _load_inception_v3():
|
|
| 40 |
return model
|
| 41 |
|
| 42 |
|
| 43 |
-
class
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
t = _imagenet_normalize(t)
|
| 69 |
-
return t
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
class NumpyArrayDataset(Dataset):
|
| 73 |
-
"""Dataset wrapping a list of numpy arrays."""
|
| 74 |
-
|
| 75 |
-
def __init__(self, images: list[np.ndarray], image_size: int = 299):
|
| 76 |
-
self.images = images
|
| 77 |
-
self.image_size = image_size
|
| 78 |
-
|
| 79 |
-
def __len__(self):
|
| 80 |
-
return len(self.images)
|
| 81 |
-
|
| 82 |
-
def __getitem__(self, idx):
|
| 83 |
-
import cv2
|
| 84 |
-
img = self.images[idx]
|
| 85 |
-
if img.shape[:2] != (self.image_size, self.image_size):
|
| 86 |
img = cv2.resize(img, (self.image_size, self.image_size))
|
| 87 |
-
if img.shape[2] == 3:
|
| 88 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
|
| 118 |
def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
|
@@ -123,8 +127,10 @@ def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
|
| 123 |
|
| 124 |
|
| 125 |
def _calculate_fid(
|
| 126 |
-
mu1: np.ndarray,
|
| 127 |
-
|
|
|
|
|
|
|
| 128 |
) -> float:
|
| 129 |
"""Calculate FID given two sets of statistics.
|
| 130 |
|
|
@@ -177,10 +183,10 @@ def compute_fid_from_dirs(
|
|
| 177 |
if len(real_ds) == 0 or len(gen_ds) == 0:
|
| 178 |
raise ValueError("Need at least 1 image in each directory")
|
| 179 |
|
| 180 |
-
real_loader = DataLoader(
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
|
| 185 |
real_features = _extract_features(model, real_loader, dev)
|
| 186 |
gen_features = _extract_features(model, gen_loader, dev)
|
|
|
|
| 16 |
from __future__ import annotations
|
| 17 |
|
| 18 |
from pathlib import Path
|
| 19 |
+
from typing import Any
|
| 20 |
|
| 21 |
import numpy as np
|
| 22 |
|
|
|
|
| 24 |
import torch
|
| 25 |
import torch.nn as nn
|
| 26 |
from torch.utils.data import DataLoader, Dataset
|
| 27 |
+
|
| 28 |
HAS_TORCH = True
|
| 29 |
except ImportError:
|
| 30 |
HAS_TORCH = False
|
| 31 |
|
| 32 |
|
| 33 |
+
def _load_inception_v3() -> Any:
|
| 34 |
"""Load InceptionV3 with pool3 features (2048-dim)."""
|
| 35 |
+
from torchvision.models import Inception_V3_Weights, inception_v3
|
| 36 |
|
| 37 |
model = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
|
| 38 |
# We want features from the avg pool layer (2048-dim)
|
|
|
|
| 42 |
return model
|
| 43 |
|
| 44 |
|
| 45 |
+
# Guard torch-dependent class and function definitions so the module
|
| 46 |
+
# can be imported safely when torch is not installed.
|
| 47 |
+
if HAS_TORCH:
|
| 48 |
+
|
| 49 |
+
class ImageFolderDataset(Dataset): # type: ignore[misc]
|
| 50 |
+
"""Simple dataset that loads images from a directory."""
|
| 51 |
+
|
| 52 |
+
def __init__(self, directory: str | Path, image_size: int = 299):
|
| 53 |
+
self.directory = Path(directory)
|
| 54 |
+
exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
|
| 55 |
+
self.files = sorted(
|
| 56 |
+
f for f in self.directory.iterdir() if f.suffix.lower() in exts and f.is_file()
|
| 57 |
+
)
|
| 58 |
+
self.image_size = image_size
|
| 59 |
+
|
| 60 |
+
def __len__(self) -> int:
|
| 61 |
+
return len(self.files)
|
| 62 |
+
|
| 63 |
+
def __getitem__(self, idx: int) -> Any:
|
| 64 |
+
import cv2
|
| 65 |
+
|
| 66 |
+
img = cv2.imread(str(self.files[idx]))
|
| 67 |
+
if img is None:
|
| 68 |
+
# Return zeros if image can't be loaded
|
| 69 |
+
return torch.zeros(3, self.image_size, self.image_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
img = cv2.resize(img, (self.image_size, self.image_size))
|
|
|
|
| 71 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 72 |
+
# Normalize to [0, 1] then ImageNet normalize
|
| 73 |
+
t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
|
| 74 |
+
t = _imagenet_normalize(t)
|
| 75 |
+
return t
|
| 76 |
+
|
| 77 |
+
class NumpyArrayDataset(Dataset): # type: ignore[misc]
|
| 78 |
+
"""Dataset wrapping a list of numpy arrays."""
|
| 79 |
+
|
| 80 |
+
def __init__(self, images: list[np.ndarray], image_size: int = 299):
|
| 81 |
+
self.images = images
|
| 82 |
+
self.image_size = image_size
|
| 83 |
+
|
| 84 |
+
def __len__(self) -> int:
|
| 85 |
+
return len(self.images)
|
| 86 |
+
|
| 87 |
+
def __getitem__(self, idx: int) -> Any:
|
| 88 |
+
import cv2
|
| 89 |
+
|
| 90 |
+
img = self.images[idx]
|
| 91 |
+
if img.shape[:2] != (self.image_size, self.image_size):
|
| 92 |
+
img = cv2.resize(img, (self.image_size, self.image_size))
|
| 93 |
+
if img.shape[2] == 3:
|
| 94 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 95 |
+
t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
|
| 96 |
+
t = _imagenet_normalize(t)
|
| 97 |
+
return t
|
| 98 |
+
|
| 99 |
+
def _imagenet_normalize(t: Any) -> Any:
|
| 100 |
+
"""Apply ImageNet normalization."""
|
| 101 |
+
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
| 102 |
+
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
| 103 |
+
return (t - mean) / std
|
| 104 |
+
|
| 105 |
+
@torch.no_grad()
|
| 106 |
+
def _extract_features(
|
| 107 |
+
model: Any,
|
| 108 |
+
dataloader: Any,
|
| 109 |
+
device: Any,
|
| 110 |
+
) -> np.ndarray:
|
| 111 |
+
"""Extract InceptionV3 pool3 features from a dataloader."""
|
| 112 |
+
features = []
|
| 113 |
+
for batch in dataloader:
|
| 114 |
+
batch = batch.to(device)
|
| 115 |
+
feat = model(batch)
|
| 116 |
+
if isinstance(feat, tuple):
|
| 117 |
+
feat = feat[0]
|
| 118 |
+
features.append(feat.cpu().numpy())
|
| 119 |
+
return np.concatenate(features, axis=0)
|
| 120 |
|
| 121 |
|
| 122 |
def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
|
|
|
| 127 |
|
| 128 |
|
| 129 |
def _calculate_fid(
|
| 130 |
+
mu1: np.ndarray,
|
| 131 |
+
sigma1: np.ndarray,
|
| 132 |
+
mu2: np.ndarray,
|
| 133 |
+
sigma2: np.ndarray,
|
| 134 |
) -> float:
|
| 135 |
"""Calculate FID given two sets of statistics.
|
| 136 |
|
|
|
|
| 183 |
if len(real_ds) == 0 or len(gen_ds) == 0:
|
| 184 |
raise ValueError("Need at least 1 image in each directory")
|
| 185 |
|
| 186 |
+
real_loader = DataLoader(
|
| 187 |
+
real_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True
|
| 188 |
+
)
|
| 189 |
+
gen_loader = DataLoader(gen_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True)
|
| 190 |
|
| 191 |
real_features = _extract_features(model, real_loader, dev)
|
| 192 |
gen_features = _extract_features(model, gen_loader, dev)
|
landmarkdiff/hyperparam.py
CHANGED
|
@@ -24,7 +24,7 @@ import json
|
|
| 24 |
import math
|
| 25 |
from dataclasses import dataclass, field
|
| 26 |
from pathlib import Path
|
| 27 |
-
from typing import Any
|
| 28 |
|
| 29 |
|
| 30 |
def _to_native(val: Any) -> Any:
|
|
@@ -99,27 +99,45 @@ class SearchSpace:
|
|
| 99 |
self.params: dict[str, ParamSpec] = {}
|
| 100 |
|
| 101 |
def add_float(
|
| 102 |
-
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
) -> SearchSpace:
|
| 104 |
"""Add a continuous float parameter."""
|
| 105 |
self.params[name] = ParamSpec(
|
| 106 |
-
name=name,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
)
|
| 108 |
return self
|
| 109 |
|
| 110 |
def add_int(
|
| 111 |
-
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
) -> SearchSpace:
|
| 113 |
"""Add an integer parameter."""
|
| 114 |
self.params[name] = ParamSpec(
|
| 115 |
-
name=name,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
)
|
| 117 |
return self
|
| 118 |
|
| 119 |
def add_choice(self, name: str, choices: list[Any]) -> SearchSpace:
|
| 120 |
"""Add a categorical parameter."""
|
| 121 |
self.params[name] = ParamSpec(
|
| 122 |
-
name=name,
|
|
|
|
|
|
|
| 123 |
)
|
| 124 |
return self
|
| 125 |
|
|
@@ -204,10 +222,7 @@ class HyperparamSearch:
|
|
| 204 |
attempts = 0
|
| 205 |
while len(trials) < n_trials and attempts < max_attempts:
|
| 206 |
attempts += 1
|
| 207 |
-
config = {
|
| 208 |
-
name: spec.sample(rng)
|
| 209 |
-
for name, spec in self.space.params.items()
|
| 210 |
-
}
|
| 211 |
trial = Trial(
|
| 212 |
trial_id=f"trial_{len(trials):04d}",
|
| 213 |
config=config,
|
|
@@ -223,14 +238,11 @@ class HyperparamSearch:
|
|
| 223 |
import itertools
|
| 224 |
|
| 225 |
param_names = list(self.space.params.keys())
|
| 226 |
-
param_values = [
|
| 227 |
-
self.space.params[name].grid_values(grid_points)
|
| 228 |
-
for name in param_names
|
| 229 |
-
]
|
| 230 |
|
| 231 |
trials = []
|
| 232 |
for combo in itertools.product(*param_values):
|
| 233 |
-
config = dict(zip(param_names, combo))
|
| 234 |
trial = Trial(
|
| 235 |
trial_id=f"trial_{len(trials):04d}",
|
| 236 |
config=config,
|
|
@@ -240,7 +252,9 @@ class HyperparamSearch:
|
|
| 240 |
return trials
|
| 241 |
|
| 242 |
def record_result(
|
| 243 |
-
self,
|
|
|
|
|
|
|
| 244 |
) -> None:
|
| 245 |
"""Record results for a trial."""
|
| 246 |
for trial in self.trials:
|
|
@@ -251,7 +265,9 @@ class HyperparamSearch:
|
|
| 251 |
raise KeyError(f"Trial {trial_id} not found")
|
| 252 |
|
| 253 |
def best_trial(
|
| 254 |
-
self,
|
|
|
|
|
|
|
| 255 |
) -> Trial | None:
|
| 256 |
"""Get the best completed trial by a metric."""
|
| 257 |
completed = [t for t in self.trials if t.status == "completed" and metric in t.result]
|
|
@@ -275,7 +291,8 @@ class HyperparamSearch:
|
|
| 275 |
with open(cfg_path, "w") as f:
|
| 276 |
yaml.safe_dump(
|
| 277 |
{"trial_id": trial.trial_id, **native_config},
|
| 278 |
-
f,
|
|
|
|
| 279 |
)
|
| 280 |
|
| 281 |
# Save summary index
|
|
@@ -321,7 +338,7 @@ class HyperparamSearch:
|
|
| 321 |
if isinstance(val, float):
|
| 322 |
parts.append(f"{val:>12.6f}")
|
| 323 |
else:
|
| 324 |
-
parts.append(f"{
|
| 325 |
for m in metric_names:
|
| 326 |
val = trial.result.get(m, float("nan"))
|
| 327 |
parts.append(f"{val:>12.4f}")
|
|
|
|
| 24 |
import math
|
| 25 |
from dataclasses import dataclass, field
|
| 26 |
from pathlib import Path
|
| 27 |
+
from typing import Any
|
| 28 |
|
| 29 |
|
| 30 |
def _to_native(val: Any) -> Any:
|
|
|
|
| 99 |
self.params: dict[str, ParamSpec] = {}
|
| 100 |
|
| 101 |
def add_float(
|
| 102 |
+
self,
|
| 103 |
+
name: str,
|
| 104 |
+
low: float,
|
| 105 |
+
high: float,
|
| 106 |
+
log_scale: bool = False,
|
| 107 |
) -> SearchSpace:
|
| 108 |
"""Add a continuous float parameter."""
|
| 109 |
self.params[name] = ParamSpec(
|
| 110 |
+
name=name,
|
| 111 |
+
param_type="float",
|
| 112 |
+
low=low,
|
| 113 |
+
high=high,
|
| 114 |
+
log_scale=log_scale,
|
| 115 |
)
|
| 116 |
return self
|
| 117 |
|
| 118 |
def add_int(
|
| 119 |
+
self,
|
| 120 |
+
name: str,
|
| 121 |
+
low: int,
|
| 122 |
+
high: int,
|
| 123 |
+
step: int = 1,
|
| 124 |
) -> SearchSpace:
|
| 125 |
"""Add an integer parameter."""
|
| 126 |
self.params[name] = ParamSpec(
|
| 127 |
+
name=name,
|
| 128 |
+
param_type="int",
|
| 129 |
+
low=low,
|
| 130 |
+
high=high,
|
| 131 |
+
step=step,
|
| 132 |
)
|
| 133 |
return self
|
| 134 |
|
| 135 |
def add_choice(self, name: str, choices: list[Any]) -> SearchSpace:
|
| 136 |
"""Add a categorical parameter."""
|
| 137 |
self.params[name] = ParamSpec(
|
| 138 |
+
name=name,
|
| 139 |
+
param_type="choice",
|
| 140 |
+
choices=choices,
|
| 141 |
)
|
| 142 |
return self
|
| 143 |
|
|
|
|
| 222 |
attempts = 0
|
| 223 |
while len(trials) < n_trials and attempts < max_attempts:
|
| 224 |
attempts += 1
|
| 225 |
+
config = {name: spec.sample(rng) for name, spec in self.space.params.items()}
|
|
|
|
|
|
|
|
|
|
| 226 |
trial = Trial(
|
| 227 |
trial_id=f"trial_{len(trials):04d}",
|
| 228 |
config=config,
|
|
|
|
| 238 |
import itertools
|
| 239 |
|
| 240 |
param_names = list(self.space.params.keys())
|
| 241 |
+
param_values = [self.space.params[name].grid_values(grid_points) for name in param_names]
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
trials = []
|
| 244 |
for combo in itertools.product(*param_values):
|
| 245 |
+
config = dict(zip(param_names, combo, strict=False))
|
| 246 |
trial = Trial(
|
| 247 |
trial_id=f"trial_{len(trials):04d}",
|
| 248 |
config=config,
|
|
|
|
| 252 |
return trials
|
| 253 |
|
| 254 |
def record_result(
|
| 255 |
+
self,
|
| 256 |
+
trial_id: str,
|
| 257 |
+
metrics: dict[str, float],
|
| 258 |
) -> None:
|
| 259 |
"""Record results for a trial."""
|
| 260 |
for trial in self.trials:
|
|
|
|
| 265 |
raise KeyError(f"Trial {trial_id} not found")
|
| 266 |
|
| 267 |
def best_trial(
|
| 268 |
+
self,
|
| 269 |
+
metric: str = "loss",
|
| 270 |
+
lower_is_better: bool = True,
|
| 271 |
) -> Trial | None:
|
| 272 |
"""Get the best completed trial by a metric."""
|
| 273 |
completed = [t for t in self.trials if t.status == "completed" and metric in t.result]
|
|
|
|
| 291 |
with open(cfg_path, "w") as f:
|
| 292 |
yaml.safe_dump(
|
| 293 |
{"trial_id": trial.trial_id, **native_config},
|
| 294 |
+
f,
|
| 295 |
+
default_flow_style=False,
|
| 296 |
)
|
| 297 |
|
| 298 |
# Save summary index
|
|
|
|
| 338 |
if isinstance(val, float):
|
| 339 |
parts.append(f"{val:>12.6f}")
|
| 340 |
else:
|
| 341 |
+
parts.append(f"{val!s:>12s}")
|
| 342 |
for m in metric_names:
|
| 343 |
val = trial.result.get(m, float("nan"))
|
| 344 |
parts.append(f"{val:>12.4f}")
|
landmarkdiff/inference.py
CHANGED
|
@@ -13,7 +13,7 @@ from __future__ import annotations
|
|
| 13 |
|
| 14 |
import sys
|
| 15 |
from pathlib import Path
|
| 16 |
-
from typing import
|
| 17 |
|
| 18 |
import cv2
|
| 19 |
import numpy as np
|
|
@@ -21,11 +21,13 @@ import torch
|
|
| 21 |
from PIL import Image
|
| 22 |
|
| 23 |
from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks, render_landmark_image
|
| 24 |
-
from landmarkdiff.conditioning import generate_conditioning
|
| 25 |
from landmarkdiff.manipulation import apply_procedure_preset
|
| 26 |
from landmarkdiff.masking import generate_surgical_mask, mask_to_3channel
|
| 27 |
from landmarkdiff.synthetic.tps_warp import warp_image_tps
|
| 28 |
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
def get_device() -> torch.device:
|
| 31 |
if torch.backends.mps.is_available():
|
|
@@ -102,6 +104,7 @@ def mask_composite(
|
|
| 102 |
if use_laplacian:
|
| 103 |
try:
|
| 104 |
from landmarkdiff.postprocess import laplacian_pyramid_blend
|
|
|
|
| 105 |
return laplacian_pyramid_blend(corrected, original, mask_f)
|
| 106 |
except Exception:
|
| 107 |
pass
|
|
@@ -109,8 +112,7 @@ def mask_composite(
|
|
| 109 |
# Fallback: simple alpha blend
|
| 110 |
mask_3ch = mask_to_3channel(mask_f)
|
| 111 |
result = (
|
| 112 |
-
corrected.astype(np.float32) * mask_3ch
|
| 113 |
-
+ original.astype(np.float32) * (1.0 - mask_3ch)
|
| 114 |
).astype(np.uint8)
|
| 115 |
|
| 116 |
return result
|
|
@@ -170,10 +172,10 @@ class LandmarkDiffPipeline:
|
|
| 170 |
controlnet_id: str = "CrucibleAI/ControlNetMediaPipeFace",
|
| 171 |
controlnet_checkpoint: str | None = None,
|
| 172 |
base_model_id: str | None = None,
|
| 173 |
-
device:
|
| 174 |
-
dtype:
|
| 175 |
ip_adapter_scale: float = 0.6,
|
| 176 |
-
clinical_flags:
|
| 177 |
displacement_model_path: str | None = None,
|
| 178 |
):
|
| 179 |
self.mode = mode
|
|
@@ -187,6 +189,7 @@ class LandmarkDiffPipeline:
|
|
| 187 |
if displacement_model_path:
|
| 188 |
try:
|
| 189 |
from landmarkdiff.displacement_model import DisplacementModel
|
|
|
|
| 190 |
self._displacement_model = DisplacementModel.load(displacement_model_path)
|
| 191 |
print(f"Displacement model loaded: {self._displacement_model.procedures}")
|
| 192 |
except Exception as e:
|
|
@@ -224,8 +227,8 @@ class LandmarkDiffPipeline:
|
|
| 224 |
def _load_controlnet(self) -> None:
|
| 225 |
from diffusers import (
|
| 226 |
ControlNetModel,
|
| 227 |
-
StableDiffusionControlNetPipeline,
|
| 228 |
DPMSolverMultistepScheduler,
|
|
|
|
| 229 |
)
|
| 230 |
|
| 231 |
if self.controlnet_checkpoint:
|
|
@@ -236,12 +239,15 @@ class LandmarkDiffPipeline:
|
|
| 236 |
ckpt_path = ckpt_path / "controlnet_ema"
|
| 237 |
print(f"Loading fine-tuned ControlNet from {ckpt_path}...")
|
| 238 |
controlnet = ControlNetModel.from_pretrained(
|
| 239 |
-
str(ckpt_path),
|
|
|
|
| 240 |
)
|
| 241 |
else:
|
| 242 |
print(f"Loading ControlNet from {self.controlnet_id}...")
|
| 243 |
controlnet = ControlNetModel.from_pretrained(
|
| 244 |
-
self.controlnet_id,
|
|
|
|
|
|
|
| 245 |
)
|
| 246 |
print(f"Loading base model from {self.base_model_id}...")
|
| 247 |
self._pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
|
@@ -287,8 +293,8 @@ class LandmarkDiffPipeline:
|
|
| 287 |
|
| 288 |
def _load_img2img(self) -> None:
|
| 289 |
from diffusers import (
|
| 290 |
-
StableDiffusionImg2ImgPipeline,
|
| 291 |
DPMSolverMultistepScheduler,
|
|
|
|
| 292 |
)
|
| 293 |
|
| 294 |
print(f"Loading SD1.5 img2img from {self.base_model_id}...")
|
|
@@ -298,9 +304,7 @@ class LandmarkDiffPipeline:
|
|
| 298 |
safety_checker=None,
|
| 299 |
requires_safety_checker=False,
|
| 300 |
)
|
| 301 |
-
self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(
|
| 302 |
-
self._pipe.scheduler.config
|
| 303 |
-
)
|
| 304 |
self._apply_device_optimizations()
|
| 305 |
|
| 306 |
def _apply_device_optimizations(self) -> None:
|
|
@@ -329,8 +333,8 @@ class LandmarkDiffPipeline:
|
|
| 329 |
guidance_scale: float = 9.0,
|
| 330 |
controlnet_conditioning_scale: float = 0.9,
|
| 331 |
strength: float = 0.5,
|
| 332 |
-
seed:
|
| 333 |
-
clinical_flags:
|
| 334 |
postprocess: bool = True,
|
| 335 |
use_gfpgan: bool = False,
|
| 336 |
) -> dict:
|
|
@@ -351,12 +355,14 @@ class LandmarkDiffPipeline:
|
|
| 351 |
manipulation_mode = "preset"
|
| 352 |
if self._displacement_model and procedure in self._displacement_model.procedures:
|
| 353 |
try:
|
| 354 |
-
from landmarkdiff.displacement_model import DisplacementModel
|
| 355 |
rng = np.random.default_rng(seed) if seed is not None else np.random.default_rng()
|
| 356 |
# Map UI intensity (0-100) to displacement model intensity (0-2)
|
| 357 |
dm_intensity = intensity / 50.0 # 50 -> 1.0x mean displacement
|
| 358 |
displacement = self._displacement_model.get_displacement_field(
|
| 359 |
-
procedure,
|
|
|
|
|
|
|
|
|
|
| 360 |
)
|
| 361 |
# Apply displacement to landmarks
|
| 362 |
new_lm = face.landmarks.copy()
|
|
@@ -367,21 +373,34 @@ class LandmarkDiffPipeline:
|
|
| 367 |
new_lm[:, 1] = np.clip(new_lm[:, 1], 0.01, 0.99)
|
| 368 |
manipulated = FaceLandmarks(
|
| 369 |
landmarks=new_lm,
|
| 370 |
-
image_width=512,
|
|
|
|
| 371 |
confidence=face.confidence,
|
| 372 |
)
|
| 373 |
manipulation_mode = "displacement_model"
|
| 374 |
except Exception:
|
| 375 |
manipulated = apply_procedure_preset(
|
| 376 |
-
face,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
)
|
| 378 |
else:
|
| 379 |
manipulated = apply_procedure_preset(
|
| 380 |
-
face,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
)
|
| 382 |
landmark_img = render_landmark_image(manipulated, 512, 512)
|
| 383 |
mask = generate_surgical_mask(
|
| 384 |
-
face,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
)
|
| 386 |
|
| 387 |
generator = None
|
|
@@ -398,14 +417,24 @@ class LandmarkDiffPipeline:
|
|
| 398 |
elif self.mode in ("controlnet", "controlnet_ip"):
|
| 399 |
ip_image = numpy_to_pil(image_512) if self._ip_adapter_loaded else None
|
| 400 |
raw_output = self._generate_controlnet(
|
| 401 |
-
image_512,
|
| 402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
ip_adapter_image=ip_image,
|
| 404 |
)
|
| 405 |
else:
|
| 406 |
raw_output = self._generate_img2img(
|
| 407 |
-
tps_warped,
|
| 408 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
)
|
| 410 |
|
| 411 |
# Step 2: Post-processing for photorealism (neural + classical pipeline)
|
|
@@ -413,6 +442,7 @@ class LandmarkDiffPipeline:
|
|
| 413 |
restore_used = "none"
|
| 414 |
if postprocess and self.mode != "tps":
|
| 415 |
from landmarkdiff.postprocess import full_postprocess
|
|
|
|
| 416 |
pp_result = full_postprocess(
|
| 417 |
generated=raw_output,
|
| 418 |
original=image_512,
|
|
@@ -450,8 +480,13 @@ class LandmarkDiffPipeline:
|
|
| 450 |
}
|
| 451 |
|
| 452 |
def _generate_controlnet(
|
| 453 |
-
self,
|
| 454 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
generator: torch.Generator | None,
|
| 456 |
ip_adapter_image: Image.Image | None = None,
|
| 457 |
) -> np.ndarray:
|
|
@@ -470,8 +505,13 @@ class LandmarkDiffPipeline:
|
|
| 470 |
return pil_to_numpy(result.images[0])
|
| 471 |
|
| 472 |
def _generate_img2img(
|
| 473 |
-
self,
|
| 474 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
generator: torch.Generator | None,
|
| 476 |
) -> np.ndarray:
|
| 477 |
result = self._pipe(
|
|
@@ -558,7 +598,8 @@ def run_inference(
|
|
| 558 |
sys.exit(1)
|
| 559 |
|
| 560 |
pipe = LandmarkDiffPipeline(
|
| 561 |
-
mode=mode,
|
|
|
|
| 562 |
controlnet_checkpoint=controlnet_checkpoint,
|
| 563 |
displacement_model_path=displacement_model_path,
|
| 564 |
)
|
|
@@ -594,18 +635,29 @@ if __name__ == "__main__":
|
|
| 594 |
parser.add_argument("--output", default="scripts/inference_output")
|
| 595 |
parser.add_argument("--seed", type=int, default=42)
|
| 596 |
parser.add_argument(
|
| 597 |
-
"--mode",
|
|
|
|
| 598 |
choices=["img2img", "controlnet", "controlnet_ip", "tps"],
|
| 599 |
)
|
| 600 |
parser.add_argument("--ip-adapter-scale", type=float, default=0.6)
|
| 601 |
-
parser.add_argument(
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
args = parser.parse_args()
|
| 606 |
|
| 607 |
run_inference(
|
| 608 |
-
args.image,
|
| 609 |
-
args.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 610 |
args.displacement_model,
|
| 611 |
)
|
|
|
|
| 13 |
|
| 14 |
import sys
|
| 15 |
from pathlib import Path
|
| 16 |
+
from typing import TYPE_CHECKING
|
| 17 |
|
| 18 |
import cv2
|
| 19 |
import numpy as np
|
|
|
|
| 21 |
from PIL import Image
|
| 22 |
|
| 23 |
from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks, render_landmark_image
|
|
|
|
| 24 |
from landmarkdiff.manipulation import apply_procedure_preset
|
| 25 |
from landmarkdiff.masking import generate_surgical_mask, mask_to_3channel
|
| 26 |
from landmarkdiff.synthetic.tps_warp import warp_image_tps
|
| 27 |
|
| 28 |
+
if TYPE_CHECKING:
|
| 29 |
+
from landmarkdiff.clinical import ClinicalFlags
|
| 30 |
+
|
| 31 |
|
| 32 |
def get_device() -> torch.device:
|
| 33 |
if torch.backends.mps.is_available():
|
|
|
|
| 104 |
if use_laplacian:
|
| 105 |
try:
|
| 106 |
from landmarkdiff.postprocess import laplacian_pyramid_blend
|
| 107 |
+
|
| 108 |
return laplacian_pyramid_blend(corrected, original, mask_f)
|
| 109 |
except Exception:
|
| 110 |
pass
|
|
|
|
| 112 |
# Fallback: simple alpha blend
|
| 113 |
mask_3ch = mask_to_3channel(mask_f)
|
| 114 |
result = (
|
| 115 |
+
corrected.astype(np.float32) * mask_3ch + original.astype(np.float32) * (1.0 - mask_3ch)
|
|
|
|
| 116 |
).astype(np.uint8)
|
| 117 |
|
| 118 |
return result
|
|
|
|
| 172 |
controlnet_id: str = "CrucibleAI/ControlNetMediaPipeFace",
|
| 173 |
controlnet_checkpoint: str | None = None,
|
| 174 |
base_model_id: str | None = None,
|
| 175 |
+
device: torch.device | None = None,
|
| 176 |
+
dtype: torch.dtype | None = None,
|
| 177 |
ip_adapter_scale: float = 0.6,
|
| 178 |
+
clinical_flags: ClinicalFlags | None = None,
|
| 179 |
displacement_model_path: str | None = None,
|
| 180 |
):
|
| 181 |
self.mode = mode
|
|
|
|
| 189 |
if displacement_model_path:
|
| 190 |
try:
|
| 191 |
from landmarkdiff.displacement_model import DisplacementModel
|
| 192 |
+
|
| 193 |
self._displacement_model = DisplacementModel.load(displacement_model_path)
|
| 194 |
print(f"Displacement model loaded: {self._displacement_model.procedures}")
|
| 195 |
except Exception as e:
|
|
|
|
| 227 |
def _load_controlnet(self) -> None:
|
| 228 |
from diffusers import (
|
| 229 |
ControlNetModel,
|
|
|
|
| 230 |
DPMSolverMultistepScheduler,
|
| 231 |
+
StableDiffusionControlNetPipeline,
|
| 232 |
)
|
| 233 |
|
| 234 |
if self.controlnet_checkpoint:
|
|
|
|
| 239 |
ckpt_path = ckpt_path / "controlnet_ema"
|
| 240 |
print(f"Loading fine-tuned ControlNet from {ckpt_path}...")
|
| 241 |
controlnet = ControlNetModel.from_pretrained(
|
| 242 |
+
str(ckpt_path),
|
| 243 |
+
torch_dtype=self.dtype,
|
| 244 |
)
|
| 245 |
else:
|
| 246 |
print(f"Loading ControlNet from {self.controlnet_id}...")
|
| 247 |
controlnet = ControlNetModel.from_pretrained(
|
| 248 |
+
self.controlnet_id,
|
| 249 |
+
subfolder="diffusion_sd15",
|
| 250 |
+
torch_dtype=self.dtype,
|
| 251 |
)
|
| 252 |
print(f"Loading base model from {self.base_model_id}...")
|
| 253 |
self._pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
|
|
|
| 293 |
|
| 294 |
def _load_img2img(self) -> None:
|
| 295 |
from diffusers import (
|
|
|
|
| 296 |
DPMSolverMultistepScheduler,
|
| 297 |
+
StableDiffusionImg2ImgPipeline,
|
| 298 |
)
|
| 299 |
|
| 300 |
print(f"Loading SD1.5 img2img from {self.base_model_id}...")
|
|
|
|
| 304 |
safety_checker=None,
|
| 305 |
requires_safety_checker=False,
|
| 306 |
)
|
| 307 |
+
self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(self._pipe.scheduler.config)
|
|
|
|
|
|
|
| 308 |
self._apply_device_optimizations()
|
| 309 |
|
| 310 |
def _apply_device_optimizations(self) -> None:
|
|
|
|
| 333 |
guidance_scale: float = 9.0,
|
| 334 |
controlnet_conditioning_scale: float = 0.9,
|
| 335 |
strength: float = 0.5,
|
| 336 |
+
seed: int | None = None,
|
| 337 |
+
clinical_flags: ClinicalFlags | None = None,
|
| 338 |
postprocess: bool = True,
|
| 339 |
use_gfpgan: bool = False,
|
| 340 |
) -> dict:
|
|
|
|
| 355 |
manipulation_mode = "preset"
|
| 356 |
if self._displacement_model and procedure in self._displacement_model.procedures:
|
| 357 |
try:
|
|
|
|
| 358 |
rng = np.random.default_rng(seed) if seed is not None else np.random.default_rng()
|
| 359 |
# Map UI intensity (0-100) to displacement model intensity (0-2)
|
| 360 |
dm_intensity = intensity / 50.0 # 50 -> 1.0x mean displacement
|
| 361 |
displacement = self._displacement_model.get_displacement_field(
|
| 362 |
+
procedure,
|
| 363 |
+
intensity=dm_intensity,
|
| 364 |
+
noise_scale=0.3,
|
| 365 |
+
rng=rng,
|
| 366 |
)
|
| 367 |
# Apply displacement to landmarks
|
| 368 |
new_lm = face.landmarks.copy()
|
|
|
|
| 373 |
new_lm[:, 1] = np.clip(new_lm[:, 1], 0.01, 0.99)
|
| 374 |
manipulated = FaceLandmarks(
|
| 375 |
landmarks=new_lm,
|
| 376 |
+
image_width=512,
|
| 377 |
+
image_height=512,
|
| 378 |
confidence=face.confidence,
|
| 379 |
)
|
| 380 |
manipulation_mode = "displacement_model"
|
| 381 |
except Exception:
|
| 382 |
manipulated = apply_procedure_preset(
|
| 383 |
+
face,
|
| 384 |
+
procedure,
|
| 385 |
+
intensity,
|
| 386 |
+
image_size=512,
|
| 387 |
+
clinical_flags=flags,
|
| 388 |
)
|
| 389 |
else:
|
| 390 |
manipulated = apply_procedure_preset(
|
| 391 |
+
face,
|
| 392 |
+
procedure,
|
| 393 |
+
intensity,
|
| 394 |
+
image_size=512,
|
| 395 |
+
clinical_flags=flags,
|
| 396 |
)
|
| 397 |
landmark_img = render_landmark_image(manipulated, 512, 512)
|
| 398 |
mask = generate_surgical_mask(
|
| 399 |
+
face,
|
| 400 |
+
procedure,
|
| 401 |
+
512,
|
| 402 |
+
512,
|
| 403 |
+
clinical_flags=flags,
|
| 404 |
)
|
| 405 |
|
| 406 |
generator = None
|
|
|
|
| 417 |
elif self.mode in ("controlnet", "controlnet_ip"):
|
| 418 |
ip_image = numpy_to_pil(image_512) if self._ip_adapter_loaded else None
|
| 419 |
raw_output = self._generate_controlnet(
|
| 420 |
+
image_512,
|
| 421 |
+
landmark_img,
|
| 422 |
+
prompt,
|
| 423 |
+
num_inference_steps,
|
| 424 |
+
guidance_scale,
|
| 425 |
+
controlnet_conditioning_scale,
|
| 426 |
+
generator,
|
| 427 |
ip_adapter_image=ip_image,
|
| 428 |
)
|
| 429 |
else:
|
| 430 |
raw_output = self._generate_img2img(
|
| 431 |
+
tps_warped,
|
| 432 |
+
mask,
|
| 433 |
+
prompt,
|
| 434 |
+
num_inference_steps,
|
| 435 |
+
guidance_scale,
|
| 436 |
+
strength,
|
| 437 |
+
generator,
|
| 438 |
)
|
| 439 |
|
| 440 |
# Step 2: Post-processing for photorealism (neural + classical pipeline)
|
|
|
|
| 442 |
restore_used = "none"
|
| 443 |
if postprocess and self.mode != "tps":
|
| 444 |
from landmarkdiff.postprocess import full_postprocess
|
| 445 |
+
|
| 446 |
pp_result = full_postprocess(
|
| 447 |
generated=raw_output,
|
| 448 |
original=image_512,
|
|
|
|
| 480 |
}
|
| 481 |
|
| 482 |
def _generate_controlnet(
|
| 483 |
+
self,
|
| 484 |
+
image: np.ndarray,
|
| 485 |
+
conditioning: np.ndarray,
|
| 486 |
+
prompt: str,
|
| 487 |
+
steps: int,
|
| 488 |
+
cfg: float,
|
| 489 |
+
cn_scale: float,
|
| 490 |
generator: torch.Generator | None,
|
| 491 |
ip_adapter_image: Image.Image | None = None,
|
| 492 |
) -> np.ndarray:
|
|
|
|
| 505 |
return pil_to_numpy(result.images[0])
|
| 506 |
|
| 507 |
def _generate_img2img(
|
| 508 |
+
self,
|
| 509 |
+
image: np.ndarray,
|
| 510 |
+
mask: np.ndarray,
|
| 511 |
+
prompt: str,
|
| 512 |
+
steps: int,
|
| 513 |
+
cfg: float,
|
| 514 |
+
strength: float,
|
| 515 |
generator: torch.Generator | None,
|
| 516 |
) -> np.ndarray:
|
| 517 |
result = self._pipe(
|
|
|
|
| 598 |
sys.exit(1)
|
| 599 |
|
| 600 |
pipe = LandmarkDiffPipeline(
|
| 601 |
+
mode=mode,
|
| 602 |
+
ip_adapter_scale=ip_adapter_scale,
|
| 603 |
controlnet_checkpoint=controlnet_checkpoint,
|
| 604 |
displacement_model_path=displacement_model_path,
|
| 605 |
)
|
|
|
|
| 635 |
parser.add_argument("--output", default="scripts/inference_output")
|
| 636 |
parser.add_argument("--seed", type=int, default=42)
|
| 637 |
parser.add_argument(
|
| 638 |
+
"--mode",
|
| 639 |
+
default="img2img",
|
| 640 |
choices=["img2img", "controlnet", "controlnet_ip", "tps"],
|
| 641 |
)
|
| 642 |
parser.add_argument("--ip-adapter-scale", type=float, default=0.6)
|
| 643 |
+
parser.add_argument(
|
| 644 |
+
"--checkpoint", default=None, help="Path to fine-tuned ControlNet checkpoint"
|
| 645 |
+
)
|
| 646 |
+
parser.add_argument(
|
| 647 |
+
"--displacement-model",
|
| 648 |
+
default=None,
|
| 649 |
+
help="Path to displacement_model.npz for data-driven manipulation",
|
| 650 |
+
)
|
| 651 |
args = parser.parse_args()
|
| 652 |
|
| 653 |
run_inference(
|
| 654 |
+
args.image,
|
| 655 |
+
args.procedure,
|
| 656 |
+
args.intensity,
|
| 657 |
+
args.output,
|
| 658 |
+
args.seed,
|
| 659 |
+
args.mode,
|
| 660 |
+
args.ip_adapter_scale,
|
| 661 |
+
args.checkpoint,
|
| 662 |
args.displacement_model,
|
| 663 |
)
|
landmarkdiff/landmarks.py
CHANGED
|
@@ -4,7 +4,6 @@ from __future__ import annotations
|
|
| 4 |
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from pathlib import Path
|
| 7 |
-
from typing import Optional
|
| 8 |
|
| 9 |
import cv2
|
| 10 |
import mediapipe as mp
|
|
@@ -12,39 +11,145 @@ import numpy as np
|
|
| 12 |
|
| 13 |
# Region color map for visualization (BGR)
|
| 14 |
REGION_COLORS: dict[str, tuple[int, int, int]] = {
|
| 15 |
-
"jawline": (255, 255, 255),
|
| 16 |
-
"eyebrow_left": (0, 255, 0),
|
| 17 |
"eyebrow_right": (0, 255, 0),
|
| 18 |
-
"eye_left": (255, 255, 0),
|
| 19 |
"eye_right": (255, 255, 0),
|
| 20 |
-
"nose": (0, 255, 255),
|
| 21 |
-
"lips": (0, 0, 255),
|
| 22 |
-
"iris_left": (255, 0, 255),
|
| 23 |
"iris_right": (255, 0, 255),
|
| 24 |
}
|
| 25 |
|
| 26 |
# MediaPipe landmark index groups by anatomical region
|
| 27 |
LANDMARK_REGIONS: dict[str, list[int]] = {
|
| 28 |
"jawline": [
|
| 29 |
-
10,
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
],
|
| 33 |
"eye_left": [
|
| 34 |
-
33,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
],
|
| 36 |
"eye_right": [
|
| 37 |
-
362,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
],
|
| 39 |
"eyebrow_left": [70, 63, 105, 66, 107, 55, 65, 52, 53, 46],
|
| 40 |
"eyebrow_right": [300, 293, 334, 296, 336, 285, 295, 282, 283, 276],
|
| 41 |
"nose": [
|
| 42 |
-
1,
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
],
|
| 45 |
"lips": [
|
| 46 |
-
61,
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
],
|
| 49 |
"iris_left": [468, 469, 470, 471, 472],
|
| 50 |
"iris_right": [473, 474, 475, 476, 477],
|
|
@@ -78,7 +183,7 @@ def extract_landmarks(
|
|
| 78 |
image: np.ndarray,
|
| 79 |
min_detection_confidence: float = 0.5,
|
| 80 |
min_tracking_confidence: float = 0.5,
|
| 81 |
-
) ->
|
| 82 |
"""Extract 478 facial landmarks from an image using MediaPipe Face Mesh.
|
| 83 |
|
| 84 |
Args:
|
|
@@ -97,7 +202,9 @@ def extract_landmarks(
|
|
| 97 |
landmarks, confidence = _extract_tasks_api(rgb, min_detection_confidence)
|
| 98 |
except Exception:
|
| 99 |
try:
|
| 100 |
-
landmarks, confidence = _extract_solutions_api(
|
|
|
|
|
|
|
| 101 |
except Exception:
|
| 102 |
return None
|
| 103 |
|
|
@@ -115,14 +222,14 @@ def extract_landmarks(
|
|
| 115 |
def _extract_tasks_api(
|
| 116 |
rgb: np.ndarray,
|
| 117 |
min_confidence: float,
|
| 118 |
-
) -> tuple[
|
| 119 |
"""Extract landmarks using MediaPipe Tasks API (>= 0.10.20)."""
|
| 120 |
FaceLandmarker = mp.tasks.vision.FaceLandmarker
|
| 121 |
FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
|
| 122 |
RunningMode = mp.tasks.vision.RunningMode
|
| 123 |
BaseOptions = mp.tasks.BaseOptions
|
| 124 |
-
import urllib.request
|
| 125 |
import tempfile
|
|
|
|
| 126 |
|
| 127 |
# Download model if not cached
|
| 128 |
model_path = Path(tempfile.gettempdir()) / "face_landmarker_v2_with_blendshapes.task"
|
|
@@ -161,7 +268,7 @@ def _extract_solutions_api(
|
|
| 161 |
rgb: np.ndarray,
|
| 162 |
min_detection_confidence: float,
|
| 163 |
min_tracking_confidence: float,
|
| 164 |
-
) -> tuple[
|
| 165 |
"""Extract landmarks using legacy MediaPipe Solutions API."""
|
| 166 |
with mp.solutions.face_mesh.FaceMesh(
|
| 167 |
static_image_mode=True,
|
|
@@ -224,8 +331,8 @@ def visualize_landmarks(
|
|
| 224 |
|
| 225 |
def render_landmark_image(
|
| 226 |
face: FaceLandmarks,
|
| 227 |
-
width:
|
| 228 |
-
height:
|
| 229 |
radius: int = 2,
|
| 230 |
) -> np.ndarray:
|
| 231 |
"""Render MediaPipe face mesh tessellation on black canvas.
|
|
@@ -257,6 +364,7 @@ def render_landmark_image(
|
|
| 257 |
# Draw tessellation mesh (what CrucibleAI ControlNet expects)
|
| 258 |
try:
|
| 259 |
from mediapipe.tasks.python.vision.face_landmarker import FaceLandmarksConnections
|
|
|
|
| 260 |
tessellation = FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION
|
| 261 |
contours = FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS
|
| 262 |
|
|
|
|
| 4 |
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from pathlib import Path
|
|
|
|
| 7 |
|
| 8 |
import cv2
|
| 9 |
import mediapipe as mp
|
|
|
|
| 11 |
|
| 12 |
# Region color map for visualization (BGR)
|
| 13 |
REGION_COLORS: dict[str, tuple[int, int, int]] = {
|
| 14 |
+
"jawline": (255, 255, 255), # white
|
| 15 |
+
"eyebrow_left": (0, 255, 0), # green
|
| 16 |
"eyebrow_right": (0, 255, 0),
|
| 17 |
+
"eye_left": (255, 255, 0), # cyan
|
| 18 |
"eye_right": (255, 255, 0),
|
| 19 |
+
"nose": (0, 255, 255), # yellow
|
| 20 |
+
"lips": (0, 0, 255), # red
|
| 21 |
+
"iris_left": (255, 0, 255), # magenta
|
| 22 |
"iris_right": (255, 0, 255),
|
| 23 |
}
|
| 24 |
|
| 25 |
# MediaPipe landmark index groups by anatomical region
|
| 26 |
LANDMARK_REGIONS: dict[str, list[int]] = {
|
| 27 |
"jawline": [
|
| 28 |
+
10,
|
| 29 |
+
338,
|
| 30 |
+
297,
|
| 31 |
+
332,
|
| 32 |
+
284,
|
| 33 |
+
251,
|
| 34 |
+
389,
|
| 35 |
+
356,
|
| 36 |
+
454,
|
| 37 |
+
323,
|
| 38 |
+
361,
|
| 39 |
+
288,
|
| 40 |
+
397,
|
| 41 |
+
365,
|
| 42 |
+
379,
|
| 43 |
+
378,
|
| 44 |
+
400,
|
| 45 |
+
377,
|
| 46 |
+
152,
|
| 47 |
+
148,
|
| 48 |
+
176,
|
| 49 |
+
149,
|
| 50 |
+
150,
|
| 51 |
+
136,
|
| 52 |
+
172,
|
| 53 |
+
58,
|
| 54 |
+
132,
|
| 55 |
+
93,
|
| 56 |
+
234,
|
| 57 |
+
127,
|
| 58 |
+
162,
|
| 59 |
+
21,
|
| 60 |
+
54,
|
| 61 |
+
103,
|
| 62 |
+
67,
|
| 63 |
+
109,
|
| 64 |
],
|
| 65 |
"eye_left": [
|
| 66 |
+
33,
|
| 67 |
+
7,
|
| 68 |
+
163,
|
| 69 |
+
144,
|
| 70 |
+
145,
|
| 71 |
+
153,
|
| 72 |
+
154,
|
| 73 |
+
155,
|
| 74 |
+
133,
|
| 75 |
+
173,
|
| 76 |
+
157,
|
| 77 |
+
158,
|
| 78 |
+
159,
|
| 79 |
+
160,
|
| 80 |
+
161,
|
| 81 |
+
246,
|
| 82 |
],
|
| 83 |
"eye_right": [
|
| 84 |
+
362,
|
| 85 |
+
382,
|
| 86 |
+
381,
|
| 87 |
+
380,
|
| 88 |
+
374,
|
| 89 |
+
373,
|
| 90 |
+
390,
|
| 91 |
+
249,
|
| 92 |
+
263,
|
| 93 |
+
466,
|
| 94 |
+
388,
|
| 95 |
+
387,
|
| 96 |
+
386,
|
| 97 |
+
385,
|
| 98 |
+
384,
|
| 99 |
+
398,
|
| 100 |
],
|
| 101 |
"eyebrow_left": [70, 63, 105, 66, 107, 55, 65, 52, 53, 46],
|
| 102 |
"eyebrow_right": [300, 293, 334, 296, 336, 285, 295, 282, 283, 276],
|
| 103 |
"nose": [
|
| 104 |
+
1,
|
| 105 |
+
2,
|
| 106 |
+
4,
|
| 107 |
+
5,
|
| 108 |
+
6,
|
| 109 |
+
19,
|
| 110 |
+
94,
|
| 111 |
+
141,
|
| 112 |
+
168,
|
| 113 |
+
195,
|
| 114 |
+
197,
|
| 115 |
+
236,
|
| 116 |
+
240,
|
| 117 |
+
274,
|
| 118 |
+
275,
|
| 119 |
+
278,
|
| 120 |
+
279,
|
| 121 |
+
294,
|
| 122 |
+
326,
|
| 123 |
+
327,
|
| 124 |
+
360,
|
| 125 |
+
363,
|
| 126 |
+
370,
|
| 127 |
+
456,
|
| 128 |
+
460,
|
| 129 |
],
|
| 130 |
"lips": [
|
| 131 |
+
61,
|
| 132 |
+
146,
|
| 133 |
+
91,
|
| 134 |
+
181,
|
| 135 |
+
84,
|
| 136 |
+
17,
|
| 137 |
+
314,
|
| 138 |
+
405,
|
| 139 |
+
321,
|
| 140 |
+
375,
|
| 141 |
+
291,
|
| 142 |
+
308,
|
| 143 |
+
324,
|
| 144 |
+
318,
|
| 145 |
+
402,
|
| 146 |
+
317,
|
| 147 |
+
14,
|
| 148 |
+
87,
|
| 149 |
+
178,
|
| 150 |
+
88,
|
| 151 |
+
95,
|
| 152 |
+
78,
|
| 153 |
],
|
| 154 |
"iris_left": [468, 469, 470, 471, 472],
|
| 155 |
"iris_right": [473, 474, 475, 476, 477],
|
|
|
|
| 183 |
image: np.ndarray,
|
| 184 |
min_detection_confidence: float = 0.5,
|
| 185 |
min_tracking_confidence: float = 0.5,
|
| 186 |
+
) -> FaceLandmarks | None:
|
| 187 |
"""Extract 478 facial landmarks from an image using MediaPipe Face Mesh.
|
| 188 |
|
| 189 |
Args:
|
|
|
|
| 202 |
landmarks, confidence = _extract_tasks_api(rgb, min_detection_confidence)
|
| 203 |
except Exception:
|
| 204 |
try:
|
| 205 |
+
landmarks, confidence = _extract_solutions_api(
|
| 206 |
+
rgb, min_detection_confidence, min_tracking_confidence
|
| 207 |
+
)
|
| 208 |
except Exception:
|
| 209 |
return None
|
| 210 |
|
|
|
|
| 222 |
def _extract_tasks_api(
|
| 223 |
rgb: np.ndarray,
|
| 224 |
min_confidence: float,
|
| 225 |
+
) -> tuple[np.ndarray | None, float]:
|
| 226 |
"""Extract landmarks using MediaPipe Tasks API (>= 0.10.20)."""
|
| 227 |
FaceLandmarker = mp.tasks.vision.FaceLandmarker
|
| 228 |
FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
|
| 229 |
RunningMode = mp.tasks.vision.RunningMode
|
| 230 |
BaseOptions = mp.tasks.BaseOptions
|
|
|
|
| 231 |
import tempfile
|
| 232 |
+
import urllib.request
|
| 233 |
|
| 234 |
# Download model if not cached
|
| 235 |
model_path = Path(tempfile.gettempdir()) / "face_landmarker_v2_with_blendshapes.task"
|
|
|
|
| 268 |
rgb: np.ndarray,
|
| 269 |
min_detection_confidence: float,
|
| 270 |
min_tracking_confidence: float,
|
| 271 |
+
) -> tuple[np.ndarray | None, float]:
|
| 272 |
"""Extract landmarks using legacy MediaPipe Solutions API."""
|
| 273 |
with mp.solutions.face_mesh.FaceMesh(
|
| 274 |
static_image_mode=True,
|
|
|
|
| 331 |
|
| 332 |
def render_landmark_image(
|
| 333 |
face: FaceLandmarks,
|
| 334 |
+
width: int | None = None,
|
| 335 |
+
height: int | None = None,
|
| 336 |
radius: int = 2,
|
| 337 |
) -> np.ndarray:
|
| 338 |
"""Render MediaPipe face mesh tessellation on black canvas.
|
|
|
|
| 364 |
# Draw tessellation mesh (what CrucibleAI ControlNet expects)
|
| 365 |
try:
|
| 366 |
from mediapipe.tasks.python.vision.face_landmarker import FaceLandmarksConnections
|
| 367 |
+
|
| 368 |
tessellation = FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION
|
| 369 |
contours = FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS
|
| 370 |
|
landmarkdiff/log.py
CHANGED
|
@@ -46,10 +46,12 @@ def setup_logging(
|
|
| 46 |
|
| 47 |
if not _CONFIGURED:
|
| 48 |
handler = logging.StreamHandler(stream or sys.stderr)
|
| 49 |
-
handler.setFormatter(
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
| 53 |
root_logger.addHandler(handler)
|
| 54 |
# Prevent propagation to root logger to avoid duplicate messages
|
| 55 |
root_logger.propagate = False
|
|
|
|
| 46 |
|
| 47 |
if not _CONFIGURED:
|
| 48 |
handler = logging.StreamHandler(stream or sys.stderr)
|
| 49 |
+
handler.setFormatter(
|
| 50 |
+
logging.Formatter(
|
| 51 |
+
fmt or LOG_FORMAT,
|
| 52 |
+
datefmt=LOG_DATE_FORMAT,
|
| 53 |
+
)
|
| 54 |
+
)
|
| 55 |
root_logger.addHandler(handler)
|
| 56 |
# Prevent propagation to root logger to avoid duplicate messages
|
| 57 |
root_logger.propagate = False
|
landmarkdiff/losses.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
"""4-term loss function module for ControlNet fine-tuning.
|
| 2 |
|
| 3 |
-
L_total = L_diffusion + w_landmark * L_landmark
|
|
|
|
| 4 |
|
| 5 |
Phase A (synthetic TPS data): L_diffusion ONLY. No perceptual loss against
|
| 6 |
rubbery TPS warps — it would penalize realism.
|
|
@@ -92,11 +93,16 @@ class IdentityLoss:
|
|
| 92 |
return
|
| 93 |
try:
|
| 94 |
from insightface.app import FaceAnalysis
|
|
|
|
| 95 |
self._app = FaceAnalysis(
|
| 96 |
name="buffalo_l",
|
| 97 |
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
| 98 |
)
|
| 99 |
-
ctx_id =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
self._app.prepare(ctx_id=ctx_id, det_size=(320, 320))
|
| 101 |
self._has_arcface = True
|
| 102 |
except Exception:
|
|
@@ -114,6 +120,7 @@ class IdentityLoss:
|
|
| 114 |
"""
|
| 115 |
if self._has_arcface:
|
| 116 |
import numpy as np
|
|
|
|
| 117 |
embeddings = []
|
| 118 |
valid_mask = []
|
| 119 |
for i in range(image_tensor.shape[0]):
|
|
@@ -152,7 +159,9 @@ class IdentityLoss:
|
|
| 152 |
|
| 153 |
# Resize to 112x112 for ArcFace
|
| 154 |
pred_112 = F.interpolate(pred_crop, size=(112, 112), mode="bilinear", align_corners=False)
|
| 155 |
-
target_112 = F.interpolate(
|
|
|
|
|
|
|
| 156 |
|
| 157 |
# Normalize to [-1, 1]
|
| 158 |
pred_norm = pred_112 * 2 - 1
|
|
@@ -163,7 +172,7 @@ class IdentityLoss:
|
|
| 163 |
target_emb, target_valid = self._extract_embedding(target_norm)
|
| 164 |
|
| 165 |
# Only compute loss for samples where both faces were detected
|
| 166 |
-
valid = [p and t for p, t in zip(pred_valid, target_valid)]
|
| 167 |
if not any(valid):
|
| 168 |
return torch.tensor(0.0, device=pred_image.device)
|
| 169 |
|
|
@@ -216,6 +225,7 @@ class PerceptualLoss:
|
|
| 216 |
if self._lpips is None:
|
| 217 |
try:
|
| 218 |
import lpips
|
|
|
|
| 219 |
self._lpips = lpips.LPIPS(net="alex").to(device)
|
| 220 |
self._lpips.eval()
|
| 221 |
for p in self._lpips.parameters():
|
|
@@ -225,9 +235,9 @@ class PerceptualLoss:
|
|
| 225 |
|
| 226 |
def __call__(
|
| 227 |
self,
|
| 228 |
-
pred: torch.Tensor,
|
| 229 |
target: torch.Tensor,
|
| 230 |
-
mask: torch.Tensor,
|
| 231 |
) -> torch.Tensor:
|
| 232 |
self._ensure_loaded(pred.device)
|
| 233 |
|
|
@@ -289,6 +299,7 @@ class CombinedLoss:
|
|
| 289 |
# or ONNX-based fallback
|
| 290 |
if use_differentiable_arcface:
|
| 291 |
from landmarkdiff.arcface_torch import ArcFaceLoss
|
|
|
|
| 292 |
self.identity_loss = ArcFaceLoss(weights_path=arcface_weights_path)
|
| 293 |
else:
|
| 294 |
self.identity_loss = IdentityLoss()
|
|
|
|
| 1 |
"""4-term loss function module for ControlNet fine-tuning.
|
| 2 |
|
| 3 |
+
L_total = L_diffusion + w_landmark * L_landmark
|
| 4 |
+
+ w_identity * L_identity + w_perceptual * L_perceptual
|
| 5 |
|
| 6 |
Phase A (synthetic TPS data): L_diffusion ONLY. No perceptual loss against
|
| 7 |
rubbery TPS warps — it would penalize realism.
|
|
|
|
| 93 |
return
|
| 94 |
try:
|
| 95 |
from insightface.app import FaceAnalysis
|
| 96 |
+
|
| 97 |
self._app = FaceAnalysis(
|
| 98 |
name="buffalo_l",
|
| 99 |
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
| 100 |
)
|
| 101 |
+
ctx_id = (
|
| 102 |
+
device.index
|
| 103 |
+
if device.type == "cuda" and device.index is not None
|
| 104 |
+
else (0 if device.type == "cuda" else -1)
|
| 105 |
+
)
|
| 106 |
self._app.prepare(ctx_id=ctx_id, det_size=(320, 320))
|
| 107 |
self._has_arcface = True
|
| 108 |
except Exception:
|
|
|
|
| 120 |
"""
|
| 121 |
if self._has_arcface:
|
| 122 |
import numpy as np
|
| 123 |
+
|
| 124 |
embeddings = []
|
| 125 |
valid_mask = []
|
| 126 |
for i in range(image_tensor.shape[0]):
|
|
|
|
| 159 |
|
| 160 |
# Resize to 112x112 for ArcFace
|
| 161 |
pred_112 = F.interpolate(pred_crop, size=(112, 112), mode="bilinear", align_corners=False)
|
| 162 |
+
target_112 = F.interpolate(
|
| 163 |
+
target_crop, size=(112, 112), mode="bilinear", align_corners=False
|
| 164 |
+
)
|
| 165 |
|
| 166 |
# Normalize to [-1, 1]
|
| 167 |
pred_norm = pred_112 * 2 - 1
|
|
|
|
| 172 |
target_emb, target_valid = self._extract_embedding(target_norm)
|
| 173 |
|
| 174 |
# Only compute loss for samples where both faces were detected
|
| 175 |
+
valid = [p and t for p, t in zip(pred_valid, target_valid, strict=False)]
|
| 176 |
if not any(valid):
|
| 177 |
return torch.tensor(0.0, device=pred_image.device)
|
| 178 |
|
|
|
|
| 225 |
if self._lpips is None:
|
| 226 |
try:
|
| 227 |
import lpips
|
| 228 |
+
|
| 229 |
self._lpips = lpips.LPIPS(net="alex").to(device)
|
| 230 |
self._lpips.eval()
|
| 231 |
for p in self._lpips.parameters():
|
|
|
|
| 235 |
|
| 236 |
def __call__(
|
| 237 |
self,
|
| 238 |
+
pred: torch.Tensor, # (B, 3, H, W) in [0, 1]
|
| 239 |
target: torch.Tensor,
|
| 240 |
+
mask: torch.Tensor, # (B, 1, H, W) surgical mask [0, 1]
|
| 241 |
) -> torch.Tensor:
|
| 242 |
self._ensure_loaded(pred.device)
|
| 243 |
|
|
|
|
| 299 |
# or ONNX-based fallback
|
| 300 |
if use_differentiable_arcface:
|
| 301 |
from landmarkdiff.arcface_torch import ArcFaceLoss
|
| 302 |
+
|
| 303 |
self.identity_loss = ArcFaceLoss(weights_path=arcface_weights_path)
|
| 304 |
else:
|
| 305 |
self.identity_loss = IdentityLoss()
|
landmarkdiff/manipulation.py
CHANGED
|
@@ -7,11 +7,11 @@ mm inputs only in v3+ with FLAME calibrated metric space.
|
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
from dataclasses import dataclass
|
| 10 |
-
from typing import
|
| 11 |
|
| 12 |
import numpy as np
|
| 13 |
|
| 14 |
-
from landmarkdiff.landmarks import FaceLandmarks
|
| 15 |
|
| 16 |
if TYPE_CHECKING:
|
| 17 |
from landmarkdiff.clinical import ClinicalFlags
|
|
@@ -23,38 +23,184 @@ class DeformationHandle:
|
|
| 23 |
|
| 24 |
landmark_index: int
|
| 25 |
displacement: np.ndarray # (2,) or (3,) pixel displacement
|
| 26 |
-
influence_radius: float
|
| 27 |
|
| 28 |
|
| 29 |
# Procedure-specific landmark indices from the technical specification
|
| 30 |
PROCEDURE_LANDMARKS: dict[str, list[int]] = {
|
| 31 |
"rhinoplasty": [
|
| 32 |
-
1,
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
],
|
| 35 |
"blepharoplasty": [
|
| 36 |
-
33,
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
],
|
| 40 |
"rhytidectomy": [
|
| 41 |
-
10,
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
],
|
| 45 |
"orthognathic": [
|
| 46 |
-
0,
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
],
|
| 51 |
"brow_lift": [
|
| 52 |
-
70,
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
],
|
| 56 |
"mentoplasty": [
|
| 57 |
-
148,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
],
|
| 59 |
}
|
| 60 |
# Default influence radii per procedure (in pixels at 512x512)
|
|
@@ -78,7 +224,7 @@ def gaussian_rbf_deform(
|
|
| 78 |
displacement = handle.displacement[:2]
|
| 79 |
|
| 80 |
distances_sq = np.sum((landmarks[:, :2] - center) ** 2, axis=1)
|
| 81 |
-
weights = np.exp(-distances_sq / (2.0 * handle.influence_radius
|
| 82 |
|
| 83 |
result[:, 0] += displacement[0] * weights
|
| 84 |
result[:, 1] += displacement[1] * weights
|
|
@@ -94,8 +240,8 @@ def apply_procedure_preset(
|
|
| 94 |
procedure: str,
|
| 95 |
intensity: float = 50.0,
|
| 96 |
image_size: int = 512,
|
| 97 |
-
clinical_flags:
|
| 98 |
-
displacement_model_path:
|
| 99 |
noise_scale: float = 0.0,
|
| 100 |
) -> FaceLandmarks:
|
| 101 |
"""Apply a surgical procedure preset to landmarks.
|
|
@@ -123,7 +269,11 @@ def apply_procedure_preset(
|
|
| 123 |
# Data-driven displacement mode
|
| 124 |
if displacement_model_path is not None:
|
| 125 |
return _apply_data_driven(
|
| 126 |
-
face,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
)
|
| 128 |
|
| 129 |
indices = PROCEDURE_LANDMARKS[procedure]
|
|
@@ -140,6 +290,7 @@ def apply_procedure_preset(
|
|
| 140 |
# Bell's palsy: remove handles on the affected (paralyzed) side
|
| 141 |
if clinical_flags and clinical_flags.bells_palsy:
|
| 142 |
from landmarkdiff.clinical import get_bells_palsy_side_indices
|
|
|
|
| 143 |
affected = get_bells_palsy_side_indices(clinical_flags.bells_palsy_side)
|
| 144 |
affected_indices = set()
|
| 145 |
for region_indices in affected.values():
|
|
@@ -219,48 +370,58 @@ def _get_procedure_handles(
|
|
| 219 |
left_alar = [240, 236, 141, 363, 370]
|
| 220 |
for idx in left_alar:
|
| 221 |
if idx in indices:
|
| 222 |
-
handles.append(
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
| 227 |
# right nostril -> move LEFT (-X)
|
| 228 |
right_alar = [460, 456, 274, 275, 278, 279]
|
| 229 |
for idx in right_alar:
|
| 230 |
if idx in indices:
|
| 231 |
-
handles.append(
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
|
|
|
|
|
|
| 236 |
|
| 237 |
# --- Tip refinement: subtle upward rotation + narrowing ---
|
| 238 |
tip_indices = [1, 2, 94, 19]
|
| 239 |
for idx in tip_indices:
|
| 240 |
if idx in indices:
|
| 241 |
-
handles.append(
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
|
|
|
|
|
|
| 246 |
|
| 247 |
# --- Dorsum narrowing: bilateral squeeze of nasal bridge ---
|
| 248 |
dorsum_left = [195, 197, 236]
|
| 249 |
for idx in dorsum_left:
|
| 250 |
if idx in indices:
|
| 251 |
-
handles.append(
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
|
|
|
|
|
|
| 256 |
dorsum_right = [326, 327, 456]
|
| 257 |
for idx in dorsum_right:
|
| 258 |
if idx in indices:
|
| 259 |
-
handles.append(
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
|
|
|
|
|
|
| 264 |
|
| 265 |
elif procedure == "blepharoplasty":
|
| 266 |
# --- Upper lid elevation (primary effect) ---
|
|
@@ -268,31 +429,37 @@ def _get_procedure_handles(
|
|
| 268 |
upper_lid_right = [386, 385, 384]
|
| 269 |
for idx in upper_lid_left + upper_lid_right:
|
| 270 |
if idx in indices:
|
| 271 |
-
handles.append(
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
|
|
|
|
|
|
| 276 |
# --- Medial/lateral lid corners: less displacement (tapered) ---
|
| 277 |
corner_left = [158, 157, 133, 33]
|
| 278 |
corner_right = [387, 388, 362, 263]
|
| 279 |
for idx in corner_left + corner_right:
|
| 280 |
if idx in indices:
|
| 281 |
-
handles.append(
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
|
|
|
|
|
|
| 286 |
# --- Subtle lower lid tightening ---
|
| 287 |
lower_lid_left = [145, 153, 154]
|
| 288 |
lower_lid_right = [374, 380, 381]
|
| 289 |
for idx in lower_lid_left + lower_lid_right:
|
| 290 |
if idx in indices:
|
| 291 |
-
handles.append(
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
|
|
|
|
|
|
| 296 |
|
| 297 |
elif procedure == "rhytidectomy":
|
| 298 |
# Different displacement vectors by anatomical sub-region.
|
|
@@ -300,82 +467,100 @@ def _get_procedure_handles(
|
|
| 300 |
jowl_left = [132, 136, 172, 58, 150, 176]
|
| 301 |
for idx in jowl_left:
|
| 302 |
if idx in indices:
|
| 303 |
-
handles.append(
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
|
|
|
|
|
|
| 308 |
jowl_right = [361, 365, 397, 288, 379, 400]
|
| 309 |
for idx in jowl_right:
|
| 310 |
if idx in indices:
|
| 311 |
-
handles.append(
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
|
|
|
|
|
|
| 316 |
# Chin/submental: upward only (no lateral)
|
| 317 |
chin = [152, 148, 377, 378]
|
| 318 |
for idx in chin:
|
| 319 |
if idx in indices:
|
| 320 |
-
handles.append(
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
|
|
|
|
|
|
| 325 |
# Temple/upper face: very mild lift
|
| 326 |
temple_left = [10, 21, 54, 67, 103, 109, 162, 127]
|
| 327 |
temple_right = [284, 297, 332, 338, 323, 356, 389, 454]
|
| 328 |
for idx in temple_left:
|
| 329 |
if idx in indices:
|
| 330 |
-
handles.append(
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
|
|
|
|
|
|
| 335 |
for idx in temple_right:
|
| 336 |
if idx in indices:
|
| 337 |
-
handles.append(
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
|
|
|
|
|
|
| 342 |
|
| 343 |
elif procedure == "orthognathic":
|
| 344 |
# --- Mandible repositioning: move jaw up and forward (visible as upward in 2D) ---
|
| 345 |
lower_jaw = [17, 18, 200, 201, 202, 204, 208, 211, 212, 214]
|
| 346 |
for idx in lower_jaw:
|
| 347 |
if idx in indices:
|
| 348 |
-
handles.append(
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
|
|
|
|
|
|
| 353 |
# --- Chin projection: move chin point forward/upward ---
|
| 354 |
chin_pts = [175, 170, 169, 167, 396]
|
| 355 |
for idx in chin_pts:
|
| 356 |
if idx in indices:
|
| 357 |
-
handles.append(
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
|
|
|
|
|
|
| 362 |
# --- Lateral jaw: bilateral symmetric inward pull for narrowing ---
|
| 363 |
jaw_left = [57, 61, 78, 91, 95, 146, 181]
|
| 364 |
for idx in jaw_left:
|
| 365 |
if idx in indices:
|
| 366 |
-
handles.append(
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
|
|
|
|
|
|
| 371 |
jaw_right = [291, 311, 312, 321, 324, 325, 375, 405]
|
| 372 |
for idx in jaw_right:
|
| 373 |
if idx in indices:
|
| 374 |
-
handles.append(
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
|
|
|
|
|
|
| 379 |
|
| 380 |
elif procedure == "brow_lift":
|
| 381 |
# --- Brow elevation ---
|
|
@@ -386,56 +571,68 @@ def _get_procedure_handles(
|
|
| 386 |
left_weights = [0.7, 0.8, 0.9, 1.0, 1.1]
|
| 387 |
for i, idx in enumerate(brow_left):
|
| 388 |
if idx in indices:
|
| 389 |
-
handles.append(
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
|
|
|
|
|
|
| 394 |
|
| 395 |
right_weights = [0.7, 0.8, 0.9, 1.0, 1.1]
|
| 396 |
for i, idx in enumerate(brow_right):
|
| 397 |
if idx in indices:
|
| 398 |
-
handles.append(
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
|
|
|
|
|
|
| 403 |
|
| 404 |
# --- Forehead smoothing / subtle lift ---
|
| 405 |
forehead = [9, 8, 10, 109, 67, 103, 338, 297, 332]
|
| 406 |
for idx in forehead:
|
| 407 |
if idx in indices:
|
| 408 |
-
handles.append(
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
|
|
|
|
|
|
| 413 |
elif procedure == "mentoplasty":
|
| 414 |
# --- Chin tip advancement: move chin forward (upward in 2D) ---
|
| 415 |
chin_tip = [152, 175]
|
| 416 |
for idx in chin_tip:
|
| 417 |
if idx in indices:
|
| 418 |
-
handles.append(
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
|
|
|
|
|
|
| 423 |
# --- Lower chin contour: follow tip with softer displacement ---
|
| 424 |
lower_contour = [148, 149, 150, 176, 377]
|
| 425 |
for idx in lower_contour:
|
| 426 |
if idx in indices:
|
| 427 |
-
handles.append(
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
|
|
|
|
|
|
| 432 |
# --- Jaw angles: minimal upward pull for natural transition ---
|
| 433 |
jaw_angles = [171, 396]
|
| 434 |
for idx in jaw_angles:
|
| 435 |
if idx in indices:
|
| 436 |
-
handles.append(
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
|
|
|
|
|
|
| 441 |
return handles
|
|
|
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
from dataclasses import dataclass
|
| 10 |
+
from typing import TYPE_CHECKING
|
| 11 |
|
| 12 |
import numpy as np
|
| 13 |
|
| 14 |
+
from landmarkdiff.landmarks import FaceLandmarks
|
| 15 |
|
| 16 |
if TYPE_CHECKING:
|
| 17 |
from landmarkdiff.clinical import ClinicalFlags
|
|
|
|
| 23 |
|
| 24 |
landmark_index: int
|
| 25 |
displacement: np.ndarray # (2,) or (3,) pixel displacement
|
| 26 |
+
influence_radius: float # Gaussian RBF radius in pixels
|
| 27 |
|
| 28 |
|
| 29 |
# Procedure-specific landmark indices from the technical specification
|
| 30 |
PROCEDURE_LANDMARKS: dict[str, list[int]] = {
|
| 31 |
"rhinoplasty": [
|
| 32 |
+
1,
|
| 33 |
+
2,
|
| 34 |
+
4,
|
| 35 |
+
5,
|
| 36 |
+
6,
|
| 37 |
+
19,
|
| 38 |
+
94,
|
| 39 |
+
141,
|
| 40 |
+
168,
|
| 41 |
+
195,
|
| 42 |
+
197,
|
| 43 |
+
236,
|
| 44 |
+
240,
|
| 45 |
+
274,
|
| 46 |
+
275,
|
| 47 |
+
278,
|
| 48 |
+
279,
|
| 49 |
+
294,
|
| 50 |
+
326,
|
| 51 |
+
327,
|
| 52 |
+
360,
|
| 53 |
+
363,
|
| 54 |
+
370,
|
| 55 |
+
456,
|
| 56 |
+
460,
|
| 57 |
],
|
| 58 |
"blepharoplasty": [
|
| 59 |
+
33,
|
| 60 |
+
7,
|
| 61 |
+
163,
|
| 62 |
+
144,
|
| 63 |
+
145,
|
| 64 |
+
153,
|
| 65 |
+
154,
|
| 66 |
+
155,
|
| 67 |
+
157,
|
| 68 |
+
158,
|
| 69 |
+
159,
|
| 70 |
+
160,
|
| 71 |
+
161,
|
| 72 |
+
246,
|
| 73 |
+
362,
|
| 74 |
+
382,
|
| 75 |
+
381,
|
| 76 |
+
380,
|
| 77 |
+
374,
|
| 78 |
+
373,
|
| 79 |
+
390,
|
| 80 |
+
249,
|
| 81 |
+
263,
|
| 82 |
+
466,
|
| 83 |
+
388,
|
| 84 |
+
387,
|
| 85 |
+
386,
|
| 86 |
+
385,
|
| 87 |
+
384,
|
| 88 |
+
398,
|
| 89 |
],
|
| 90 |
"rhytidectomy": [
|
| 91 |
+
10,
|
| 92 |
+
21,
|
| 93 |
+
54,
|
| 94 |
+
58,
|
| 95 |
+
67,
|
| 96 |
+
93,
|
| 97 |
+
103,
|
| 98 |
+
109,
|
| 99 |
+
127,
|
| 100 |
+
132,
|
| 101 |
+
136,
|
| 102 |
+
150,
|
| 103 |
+
162,
|
| 104 |
+
172,
|
| 105 |
+
176,
|
| 106 |
+
187,
|
| 107 |
+
207,
|
| 108 |
+
213,
|
| 109 |
+
234,
|
| 110 |
+
284,
|
| 111 |
+
297,
|
| 112 |
+
323,
|
| 113 |
+
332,
|
| 114 |
+
338,
|
| 115 |
+
356,
|
| 116 |
+
361,
|
| 117 |
+
365,
|
| 118 |
+
379,
|
| 119 |
+
389,
|
| 120 |
+
397,
|
| 121 |
+
400,
|
| 122 |
+
427,
|
| 123 |
+
454,
|
| 124 |
],
|
| 125 |
"orthognathic": [
|
| 126 |
+
0,
|
| 127 |
+
17,
|
| 128 |
+
18,
|
| 129 |
+
36,
|
| 130 |
+
37,
|
| 131 |
+
39,
|
| 132 |
+
40,
|
| 133 |
+
57,
|
| 134 |
+
61,
|
| 135 |
+
78,
|
| 136 |
+
80,
|
| 137 |
+
81,
|
| 138 |
+
82,
|
| 139 |
+
84,
|
| 140 |
+
87,
|
| 141 |
+
88,
|
| 142 |
+
91,
|
| 143 |
+
95,
|
| 144 |
+
146,
|
| 145 |
+
167,
|
| 146 |
+
169,
|
| 147 |
+
170,
|
| 148 |
+
175,
|
| 149 |
+
181,
|
| 150 |
+
191,
|
| 151 |
+
200,
|
| 152 |
+
201,
|
| 153 |
+
202,
|
| 154 |
+
204,
|
| 155 |
+
208,
|
| 156 |
+
211,
|
| 157 |
+
212,
|
| 158 |
+
214,
|
| 159 |
+
269,
|
| 160 |
+
270,
|
| 161 |
+
291,
|
| 162 |
+
311,
|
| 163 |
+
312,
|
| 164 |
+
317,
|
| 165 |
+
321,
|
| 166 |
+
324,
|
| 167 |
+
325,
|
| 168 |
+
375,
|
| 169 |
+
396,
|
| 170 |
+
405,
|
| 171 |
+
407,
|
| 172 |
+
415,
|
| 173 |
],
|
| 174 |
"brow_lift": [
|
| 175 |
+
70,
|
| 176 |
+
63,
|
| 177 |
+
105,
|
| 178 |
+
66,
|
| 179 |
+
107, # left brow
|
| 180 |
+
300,
|
| 181 |
+
293,
|
| 182 |
+
334,
|
| 183 |
+
296,
|
| 184 |
+
336, # right brow
|
| 185 |
+
9,
|
| 186 |
+
8,
|
| 187 |
+
10,
|
| 188 |
+
109,
|
| 189 |
+
67,
|
| 190 |
+
103,
|
| 191 |
+
338,
|
| 192 |
+
297,
|
| 193 |
+
332, # forehead/upper face
|
| 194 |
],
|
| 195 |
"mentoplasty": [
|
| 196 |
+
148,
|
| 197 |
+
149,
|
| 198 |
+
150,
|
| 199 |
+
152,
|
| 200 |
+
171,
|
| 201 |
+
175,
|
| 202 |
+
176,
|
| 203 |
+
377,
|
| 204 |
],
|
| 205 |
}
|
| 206 |
# Default influence radii per procedure (in pixels at 512x512)
|
|
|
|
| 224 |
displacement = handle.displacement[:2]
|
| 225 |
|
| 226 |
distances_sq = np.sum((landmarks[:, :2] - center) ** 2, axis=1)
|
| 227 |
+
weights = np.exp(-distances_sq / (2.0 * handle.influence_radius**2))
|
| 228 |
|
| 229 |
result[:, 0] += displacement[0] * weights
|
| 230 |
result[:, 1] += displacement[1] * weights
|
|
|
|
| 240 |
procedure: str,
|
| 241 |
intensity: float = 50.0,
|
| 242 |
image_size: int = 512,
|
| 243 |
+
clinical_flags: ClinicalFlags | None = None,
|
| 244 |
+
displacement_model_path: str | None = None,
|
| 245 |
noise_scale: float = 0.0,
|
| 246 |
) -> FaceLandmarks:
|
| 247 |
"""Apply a surgical procedure preset to landmarks.
|
|
|
|
| 269 |
# Data-driven displacement mode
|
| 270 |
if displacement_model_path is not None:
|
| 271 |
return _apply_data_driven(
|
| 272 |
+
face,
|
| 273 |
+
procedure,
|
| 274 |
+
scale,
|
| 275 |
+
displacement_model_path,
|
| 276 |
+
noise_scale,
|
| 277 |
)
|
| 278 |
|
| 279 |
indices = PROCEDURE_LANDMARKS[procedure]
|
|
|
|
| 290 |
# Bell's palsy: remove handles on the affected (paralyzed) side
|
| 291 |
if clinical_flags and clinical_flags.bells_palsy:
|
| 292 |
from landmarkdiff.clinical import get_bells_palsy_side_indices
|
| 293 |
+
|
| 294 |
affected = get_bells_palsy_side_indices(clinical_flags.bells_palsy_side)
|
| 295 |
affected_indices = set()
|
| 296 |
for region_indices in affected.values():
|
|
|
|
| 370 |
left_alar = [240, 236, 141, 363, 370]
|
| 371 |
for idx in left_alar:
|
| 372 |
if idx in indices:
|
| 373 |
+
handles.append(
|
| 374 |
+
DeformationHandle(
|
| 375 |
+
landmark_index=idx,
|
| 376 |
+
displacement=np.array([2.5 * scale, 0.0]),
|
| 377 |
+
influence_radius=radius * 0.6,
|
| 378 |
+
)
|
| 379 |
+
)
|
| 380 |
# right nostril -> move LEFT (-X)
|
| 381 |
right_alar = [460, 456, 274, 275, 278, 279]
|
| 382 |
for idx in right_alar:
|
| 383 |
if idx in indices:
|
| 384 |
+
handles.append(
|
| 385 |
+
DeformationHandle(
|
| 386 |
+
landmark_index=idx,
|
| 387 |
+
displacement=np.array([-2.5 * scale, 0.0]),
|
| 388 |
+
influence_radius=radius * 0.6,
|
| 389 |
+
)
|
| 390 |
+
)
|
| 391 |
|
| 392 |
# --- Tip refinement: subtle upward rotation + narrowing ---
|
| 393 |
tip_indices = [1, 2, 94, 19]
|
| 394 |
for idx in tip_indices:
|
| 395 |
if idx in indices:
|
| 396 |
+
handles.append(
|
| 397 |
+
DeformationHandle(
|
| 398 |
+
landmark_index=idx,
|
| 399 |
+
displacement=np.array([0.0, -2.0 * scale]),
|
| 400 |
+
influence_radius=radius * 0.5,
|
| 401 |
+
)
|
| 402 |
+
)
|
| 403 |
|
| 404 |
# --- Dorsum narrowing: bilateral squeeze of nasal bridge ---
|
| 405 |
dorsum_left = [195, 197, 236]
|
| 406 |
for idx in dorsum_left:
|
| 407 |
if idx in indices:
|
| 408 |
+
handles.append(
|
| 409 |
+
DeformationHandle(
|
| 410 |
+
landmark_index=idx,
|
| 411 |
+
displacement=np.array([1.5 * scale, 0.0]),
|
| 412 |
+
influence_radius=radius * 0.5,
|
| 413 |
+
)
|
| 414 |
+
)
|
| 415 |
dorsum_right = [326, 327, 456]
|
| 416 |
for idx in dorsum_right:
|
| 417 |
if idx in indices:
|
| 418 |
+
handles.append(
|
| 419 |
+
DeformationHandle(
|
| 420 |
+
landmark_index=idx,
|
| 421 |
+
displacement=np.array([-1.5 * scale, 0.0]),
|
| 422 |
+
influence_radius=radius * 0.5,
|
| 423 |
+
)
|
| 424 |
+
)
|
| 425 |
|
| 426 |
elif procedure == "blepharoplasty":
|
| 427 |
# --- Upper lid elevation (primary effect) ---
|
|
|
|
| 429 |
upper_lid_right = [386, 385, 384]
|
| 430 |
for idx in upper_lid_left + upper_lid_right:
|
| 431 |
if idx in indices:
|
| 432 |
+
handles.append(
|
| 433 |
+
DeformationHandle(
|
| 434 |
+
landmark_index=idx,
|
| 435 |
+
displacement=np.array([0.0, -2.0 * scale]),
|
| 436 |
+
influence_radius=radius,
|
| 437 |
+
)
|
| 438 |
+
)
|
| 439 |
# --- Medial/lateral lid corners: less displacement (tapered) ---
|
| 440 |
corner_left = [158, 157, 133, 33]
|
| 441 |
corner_right = [387, 388, 362, 263]
|
| 442 |
for idx in corner_left + corner_right:
|
| 443 |
if idx in indices:
|
| 444 |
+
handles.append(
|
| 445 |
+
DeformationHandle(
|
| 446 |
+
landmark_index=idx,
|
| 447 |
+
displacement=np.array([0.0, -0.8 * scale]),
|
| 448 |
+
influence_radius=radius * 0.7,
|
| 449 |
+
)
|
| 450 |
+
)
|
| 451 |
# --- Subtle lower lid tightening ---
|
| 452 |
lower_lid_left = [145, 153, 154]
|
| 453 |
lower_lid_right = [374, 380, 381]
|
| 454 |
for idx in lower_lid_left + lower_lid_right:
|
| 455 |
if idx in indices:
|
| 456 |
+
handles.append(
|
| 457 |
+
DeformationHandle(
|
| 458 |
+
landmark_index=idx,
|
| 459 |
+
displacement=np.array([0.0, 0.5 * scale]),
|
| 460 |
+
influence_radius=radius * 0.5,
|
| 461 |
+
)
|
| 462 |
+
)
|
| 463 |
|
| 464 |
elif procedure == "rhytidectomy":
|
| 465 |
# Different displacement vectors by anatomical sub-region.
|
|
|
|
| 467 |
jowl_left = [132, 136, 172, 58, 150, 176]
|
| 468 |
for idx in jowl_left:
|
| 469 |
if idx in indices:
|
| 470 |
+
handles.append(
|
| 471 |
+
DeformationHandle(
|
| 472 |
+
landmark_index=idx,
|
| 473 |
+
displacement=np.array([-2.5 * scale, -3.0 * scale]),
|
| 474 |
+
influence_radius=radius,
|
| 475 |
+
)
|
| 476 |
+
)
|
| 477 |
jowl_right = [361, 365, 397, 288, 379, 400]
|
| 478 |
for idx in jowl_right:
|
| 479 |
if idx in indices:
|
| 480 |
+
handles.append(
|
| 481 |
+
DeformationHandle(
|
| 482 |
+
landmark_index=idx,
|
| 483 |
+
displacement=np.array([2.5 * scale, -3.0 * scale]),
|
| 484 |
+
influence_radius=radius,
|
| 485 |
+
)
|
| 486 |
+
)
|
| 487 |
# Chin/submental: upward only (no lateral)
|
| 488 |
chin = [152, 148, 377, 378]
|
| 489 |
for idx in chin:
|
| 490 |
if idx in indices:
|
| 491 |
+
handles.append(
|
| 492 |
+
DeformationHandle(
|
| 493 |
+
landmark_index=idx,
|
| 494 |
+
displacement=np.array([0.0, -2.0 * scale]),
|
| 495 |
+
influence_radius=radius * 0.8,
|
| 496 |
+
)
|
| 497 |
+
)
|
| 498 |
# Temple/upper face: very mild lift
|
| 499 |
temple_left = [10, 21, 54, 67, 103, 109, 162, 127]
|
| 500 |
temple_right = [284, 297, 332, 338, 323, 356, 389, 454]
|
| 501 |
for idx in temple_left:
|
| 502 |
if idx in indices:
|
| 503 |
+
handles.append(
|
| 504 |
+
DeformationHandle(
|
| 505 |
+
landmark_index=idx,
|
| 506 |
+
displacement=np.array([-0.5 * scale, -1.0 * scale]),
|
| 507 |
+
influence_radius=radius * 0.6,
|
| 508 |
+
)
|
| 509 |
+
)
|
| 510 |
for idx in temple_right:
|
| 511 |
if idx in indices:
|
| 512 |
+
handles.append(
|
| 513 |
+
DeformationHandle(
|
| 514 |
+
landmark_index=idx,
|
| 515 |
+
displacement=np.array([0.5 * scale, -1.0 * scale]),
|
| 516 |
+
influence_radius=radius * 0.6,
|
| 517 |
+
)
|
| 518 |
+
)
|
| 519 |
|
| 520 |
elif procedure == "orthognathic":
|
| 521 |
# --- Mandible repositioning: move jaw up and forward (visible as upward in 2D) ---
|
| 522 |
lower_jaw = [17, 18, 200, 201, 202, 204, 208, 211, 212, 214]
|
| 523 |
for idx in lower_jaw:
|
| 524 |
if idx in indices:
|
| 525 |
+
handles.append(
|
| 526 |
+
DeformationHandle(
|
| 527 |
+
landmark_index=idx,
|
| 528 |
+
displacement=np.array([0.0, -3.0 * scale]),
|
| 529 |
+
influence_radius=radius,
|
| 530 |
+
)
|
| 531 |
+
)
|
| 532 |
# --- Chin projection: move chin point forward/upward ---
|
| 533 |
chin_pts = [175, 170, 169, 167, 396]
|
| 534 |
for idx in chin_pts:
|
| 535 |
if idx in indices:
|
| 536 |
+
handles.append(
|
| 537 |
+
DeformationHandle(
|
| 538 |
+
landmark_index=idx,
|
| 539 |
+
displacement=np.array([0.0, -2.0 * scale]),
|
| 540 |
+
influence_radius=radius * 0.7,
|
| 541 |
+
)
|
| 542 |
+
)
|
| 543 |
# --- Lateral jaw: bilateral symmetric inward pull for narrowing ---
|
| 544 |
jaw_left = [57, 61, 78, 91, 95, 146, 181]
|
| 545 |
for idx in jaw_left:
|
| 546 |
if idx in indices:
|
| 547 |
+
handles.append(
|
| 548 |
+
DeformationHandle(
|
| 549 |
+
landmark_index=idx,
|
| 550 |
+
displacement=np.array([1.5 * scale, -1.0 * scale]),
|
| 551 |
+
influence_radius=radius * 0.8,
|
| 552 |
+
)
|
| 553 |
+
)
|
| 554 |
jaw_right = [291, 311, 312, 321, 324, 325, 375, 405]
|
| 555 |
for idx in jaw_right:
|
| 556 |
if idx in indices:
|
| 557 |
+
handles.append(
|
| 558 |
+
DeformationHandle(
|
| 559 |
+
landmark_index=idx,
|
| 560 |
+
displacement=np.array([-1.5 * scale, -1.0 * scale]),
|
| 561 |
+
influence_radius=radius * 0.8,
|
| 562 |
+
)
|
| 563 |
+
)
|
| 564 |
|
| 565 |
elif procedure == "brow_lift":
|
| 566 |
# --- Brow elevation ---
|
|
|
|
| 571 |
left_weights = [0.7, 0.8, 0.9, 1.0, 1.1]
|
| 572 |
for i, idx in enumerate(brow_left):
|
| 573 |
if idx in indices:
|
| 574 |
+
handles.append(
|
| 575 |
+
DeformationHandle(
|
| 576 |
+
landmark_index=idx,
|
| 577 |
+
displacement=np.array([0.0, -4.0 * left_weights[i] * scale]),
|
| 578 |
+
influence_radius=radius,
|
| 579 |
+
)
|
| 580 |
+
)
|
| 581 |
|
| 582 |
right_weights = [0.7, 0.8, 0.9, 1.0, 1.1]
|
| 583 |
for i, idx in enumerate(brow_right):
|
| 584 |
if idx in indices:
|
| 585 |
+
handles.append(
|
| 586 |
+
DeformationHandle(
|
| 587 |
+
landmark_index=idx,
|
| 588 |
+
displacement=np.array([0.0, -4.0 * right_weights[i] * scale]),
|
| 589 |
+
influence_radius=radius,
|
| 590 |
+
)
|
| 591 |
+
)
|
| 592 |
|
| 593 |
# --- Forehead smoothing / subtle lift ---
|
| 594 |
forehead = [9, 8, 10, 109, 67, 103, 338, 297, 332]
|
| 595 |
for idx in forehead:
|
| 596 |
if idx in indices:
|
| 597 |
+
handles.append(
|
| 598 |
+
DeformationHandle(
|
| 599 |
+
landmark_index=idx,
|
| 600 |
+
displacement=np.array([0.0, -1.5 * scale]),
|
| 601 |
+
influence_radius=radius * 1.2,
|
| 602 |
+
)
|
| 603 |
+
)
|
| 604 |
elif procedure == "mentoplasty":
|
| 605 |
# --- Chin tip advancement: move chin forward (upward in 2D) ---
|
| 606 |
chin_tip = [152, 175]
|
| 607 |
for idx in chin_tip:
|
| 608 |
if idx in indices:
|
| 609 |
+
handles.append(
|
| 610 |
+
DeformationHandle(
|
| 611 |
+
landmark_index=idx,
|
| 612 |
+
displacement=np.array([0.0, -4.0 * scale]),
|
| 613 |
+
influence_radius=radius,
|
| 614 |
+
)
|
| 615 |
+
)
|
| 616 |
# --- Lower chin contour: follow tip with softer displacement ---
|
| 617 |
lower_contour = [148, 149, 150, 176, 377]
|
| 618 |
for idx in lower_contour:
|
| 619 |
if idx in indices:
|
| 620 |
+
handles.append(
|
| 621 |
+
DeformationHandle(
|
| 622 |
+
landmark_index=idx,
|
| 623 |
+
displacement=np.array([0.0, -2.5 * scale]),
|
| 624 |
+
influence_radius=radius * 0.8,
|
| 625 |
+
)
|
| 626 |
+
)
|
| 627 |
# --- Jaw angles: minimal upward pull for natural transition ---
|
| 628 |
jaw_angles = [171, 396]
|
| 629 |
for idx in jaw_angles:
|
| 630 |
if idx in indices:
|
| 631 |
+
handles.append(
|
| 632 |
+
DeformationHandle(
|
| 633 |
+
landmark_index=idx,
|
| 634 |
+
displacement=np.array([0.0, -1.0 * scale]),
|
| 635 |
+
influence_radius=radius * 0.6,
|
| 636 |
+
)
|
| 637 |
+
)
|
| 638 |
return handles
|
landmarkdiff/metrics_agg.py
CHANGED
|
@@ -41,8 +41,12 @@ class MetricsAggregator:
|
|
| 41 |
"""
|
| 42 |
|
| 43 |
HIGHER_BETTER = {
|
| 44 |
-
"ssim": True,
|
| 45 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
}
|
| 47 |
|
| 48 |
def __init__(self) -> None:
|
|
@@ -57,13 +61,15 @@ class MetricsAggregator:
|
|
| 57 |
**metadata: Any,
|
| 58 |
) -> None:
|
| 59 |
"""Add a single evaluation record."""
|
| 60 |
-
self.records.append(
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
| 67 |
|
| 68 |
def add_batch(
|
| 69 |
self,
|
|
@@ -76,7 +82,9 @@ class MetricsAggregator:
|
|
| 76 |
"""
|
| 77 |
for rec in records:
|
| 78 |
proc = rec.get("procedure", "all")
|
| 79 |
-
metrics = {
|
|
|
|
|
|
|
| 80 |
self.add(experiment, proc, metrics)
|
| 81 |
|
| 82 |
@property
|
|
@@ -211,10 +219,7 @@ class MetricsAggregator:
|
|
| 211 |
val = self.mean(exp, metric, procedure)
|
| 212 |
if math.isnan(val):
|
| 213 |
continue
|
| 214 |
-
if higher_better and val > best_val:
|
| 215 |
-
best_val = val
|
| 216 |
-
best_exp = exp
|
| 217 |
-
elif not higher_better and val < best_val:
|
| 218 |
best_val = val
|
| 219 |
best_exp = exp
|
| 220 |
|
|
|
|
| 41 |
"""
|
| 42 |
|
| 43 |
HIGHER_BETTER = {
|
| 44 |
+
"ssim": True,
|
| 45 |
+
"psnr": True,
|
| 46 |
+
"identity_sim": True,
|
| 47 |
+
"lpips": False,
|
| 48 |
+
"fid": False,
|
| 49 |
+
"nme": False,
|
| 50 |
}
|
| 51 |
|
| 52 |
def __init__(self) -> None:
|
|
|
|
| 61 |
**metadata: Any,
|
| 62 |
) -> None:
|
| 63 |
"""Add a single evaluation record."""
|
| 64 |
+
self.records.append(
|
| 65 |
+
MetricRecord(
|
| 66 |
+
experiment=experiment,
|
| 67 |
+
procedure=procedure,
|
| 68 |
+
metrics=metrics,
|
| 69 |
+
checkpoint_step=checkpoint_step,
|
| 70 |
+
metadata=metadata,
|
| 71 |
+
)
|
| 72 |
+
)
|
| 73 |
|
| 74 |
def add_batch(
|
| 75 |
self,
|
|
|
|
| 82 |
"""
|
| 83 |
for rec in records:
|
| 84 |
proc = rec.get("procedure", "all")
|
| 85 |
+
metrics = {
|
| 86 |
+
k: v for k, v in rec.items() if k != "procedure" and isinstance(v, (int, float))
|
| 87 |
+
}
|
| 88 |
self.add(experiment, proc, metrics)
|
| 89 |
|
| 90 |
@property
|
|
|
|
| 219 |
val = self.mean(exp, metric, procedure)
|
| 220 |
if math.isnan(val):
|
| 221 |
continue
|
| 222 |
+
if (higher_better and val > best_val) or (not higher_better and val < best_val):
|
|
|
|
|
|
|
|
|
|
| 223 |
best_val = val
|
| 224 |
best_exp = exp
|
| 225 |
|
landmarkdiff/metrics_viz.py
CHANGED
|
@@ -24,7 +24,6 @@ Usage:
|
|
| 24 |
|
| 25 |
from __future__ import annotations
|
| 26 |
|
| 27 |
-
import json
|
| 28 |
from pathlib import Path
|
| 29 |
from typing import Any
|
| 30 |
|
|
@@ -79,25 +78,29 @@ class MetricsVisualizer:
|
|
| 79 |
self.dpi = dpi
|
| 80 |
self.style = style
|
| 81 |
|
| 82 |
-
def _get_plt(self):
|
| 83 |
"""Import matplotlib with configuration."""
|
| 84 |
import matplotlib
|
|
|
|
| 85 |
matplotlib.use("Agg")
|
| 86 |
import matplotlib.pyplot as plt
|
|
|
|
| 87 |
try:
|
| 88 |
plt.style.use(self.style)
|
| 89 |
except OSError:
|
| 90 |
plt.style.use("seaborn-v0_8")
|
| 91 |
# Publication font sizes
|
| 92 |
-
plt.rcParams.update(
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
| 101 |
return plt
|
| 102 |
|
| 103 |
# ------------------------------------------------------------------
|
|
@@ -138,7 +141,7 @@ class MetricsVisualizer:
|
|
| 138 |
if n_metrics == 1:
|
| 139 |
axes = [axes]
|
| 140 |
|
| 141 |
-
for ax, metric in zip(axes, metrics):
|
| 142 |
values = [metrics_by_procedure[p].get(metric, 0) for p in procedures]
|
| 143 |
colors = [self.COLORS.get(p, "#999999") for p in procedures]
|
| 144 |
|
|
@@ -146,16 +149,21 @@ class MetricsVisualizer:
|
|
| 146 |
ax.set_xticks(range(n_procs))
|
| 147 |
ax.set_xticklabels(
|
| 148 |
[p[:5].title() for p in procedures],
|
| 149 |
-
rotation=30,
|
|
|
|
| 150 |
)
|
| 151 |
ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
|
| 152 |
ax.set_title(self.METRIC_LABELS.get(metric, metric))
|
| 153 |
|
| 154 |
# Add value labels on bars
|
| 155 |
-
for bar, val in zip(bars, values):
|
| 156 |
ax.text(
|
| 157 |
-
bar.get_x() + bar.get_width() / 2,
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
)
|
| 160 |
|
| 161 |
fig.suptitle(title, fontweight="bold")
|
|
@@ -192,9 +200,8 @@ class MetricsVisualizer:
|
|
| 192 |
|
| 193 |
if metrics is None:
|
| 194 |
metrics = sorted(
|
| 195 |
-
set.intersection(
|
| 196 |
-
|
| 197 |
-
) & set(self.METRIC_LABELS.keys())
|
| 198 |
)
|
| 199 |
|
| 200 |
n_metrics = len(metrics)
|
|
@@ -258,9 +265,7 @@ class MetricsVisualizer:
|
|
| 258 |
plt = self._get_plt()
|
| 259 |
|
| 260 |
fitz_types = sorted(metrics_by_type.keys())
|
| 261 |
-
procedures = sorted(
|
| 262 |
-
set.union(*(set(v.keys()) for v in metrics_by_type.values()))
|
| 263 |
-
)
|
| 264 |
|
| 265 |
# Build matrix
|
| 266 |
matrix = np.zeros((len(fitz_types), len(procedures)))
|
|
@@ -268,7 +273,9 @@ class MetricsVisualizer:
|
|
| 268 |
for j, proc in enumerate(procedures):
|
| 269 |
matrix[i, j] = metrics_by_type[ft].get(proc, 0)
|
| 270 |
|
| 271 |
-
fig, ax = plt.subplots(
|
|
|
|
|
|
|
| 272 |
|
| 273 |
cmap = "RdYlGn" if self.METRIC_HIGHER_BETTER.get(metric, True) else "RdYlGn_r"
|
| 274 |
im = ax.imshow(matrix, cmap=cmap, aspect="auto")
|
|
@@ -282,9 +289,15 @@ class MetricsVisualizer:
|
|
| 282 |
# Annotate cells
|
| 283 |
for i in range(len(fitz_types)):
|
| 284 |
for j in range(len(procedures)):
|
| 285 |
-
ax.text(
|
| 286 |
-
|
| 287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
fig.colorbar(im, ax=ax, label=self.METRIC_LABELS.get(metric, metric))
|
| 290 |
|
|
@@ -328,18 +341,21 @@ class MetricsVisualizer:
|
|
| 328 |
fig, ax = plt.subplots(figsize=(max(6, len(groups) * 1.2), 5))
|
| 329 |
|
| 330 |
bp = ax.boxplot(
|
| 331 |
-
data,
|
|
|
|
|
|
|
| 332 |
medianprops={"color": "black", "linewidth": 1.5},
|
| 333 |
)
|
| 334 |
|
| 335 |
colors = [self.COLORS.get(g, "#4C72B0") for g in groups]
|
| 336 |
-
for patch, color in zip(bp["boxes"], colors):
|
| 337 |
patch.set_facecolor(color)
|
| 338 |
patch.set_alpha(0.7)
|
| 339 |
|
| 340 |
ax.set_xticklabels(
|
| 341 |
[g.title() for g in groups],
|
| 342 |
-
rotation=30,
|
|
|
|
| 343 |
)
|
| 344 |
ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
|
| 345 |
|
|
@@ -348,9 +364,16 @@ class MetricsVisualizer:
|
|
| 348 |
ax.set_title(title, fontweight="bold")
|
| 349 |
|
| 350 |
# Add sample count annotations
|
| 351 |
-
for i, (
|
| 352 |
-
ax.text(
|
| 353 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
|
| 355 |
fig.tight_layout()
|
| 356 |
out_path = self.output_dir / filename
|
|
@@ -430,10 +453,12 @@ class MetricsVisualizer:
|
|
| 430 |
parts.append(val_str)
|
| 431 |
lines.append(" & ".join(parts) + " \\\\")
|
| 432 |
|
| 433 |
-
lines.extend(
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
|
|
|
|
|
|
| 438 |
|
| 439 |
return "\n".join(lines)
|
|
|
|
| 24 |
|
| 25 |
from __future__ import annotations
|
| 26 |
|
|
|
|
| 27 |
from pathlib import Path
|
| 28 |
from typing import Any
|
| 29 |
|
|
|
|
| 78 |
self.dpi = dpi
|
| 79 |
self.style = style
|
| 80 |
|
| 81 |
+
def _get_plt(self) -> Any:
|
| 82 |
"""Import matplotlib with configuration."""
|
| 83 |
import matplotlib
|
| 84 |
+
|
| 85 |
matplotlib.use("Agg")
|
| 86 |
import matplotlib.pyplot as plt
|
| 87 |
+
|
| 88 |
try:
|
| 89 |
plt.style.use(self.style)
|
| 90 |
except OSError:
|
| 91 |
plt.style.use("seaborn-v0_8")
|
| 92 |
# Publication font sizes
|
| 93 |
+
plt.rcParams.update(
|
| 94 |
+
{
|
| 95 |
+
"font.size": 10,
|
| 96 |
+
"axes.titlesize": 12,
|
| 97 |
+
"axes.labelsize": 11,
|
| 98 |
+
"xtick.labelsize": 9,
|
| 99 |
+
"ytick.labelsize": 9,
|
| 100 |
+
"legend.fontsize": 9,
|
| 101 |
+
"figure.titlesize": 13,
|
| 102 |
+
}
|
| 103 |
+
)
|
| 104 |
return plt
|
| 105 |
|
| 106 |
# ------------------------------------------------------------------
|
|
|
|
| 141 |
if n_metrics == 1:
|
| 142 |
axes = [axes]
|
| 143 |
|
| 144 |
+
for ax, metric in zip(axes, metrics, strict=False):
|
| 145 |
values = [metrics_by_procedure[p].get(metric, 0) for p in procedures]
|
| 146 |
colors = [self.COLORS.get(p, "#999999") for p in procedures]
|
| 147 |
|
|
|
|
| 149 |
ax.set_xticks(range(n_procs))
|
| 150 |
ax.set_xticklabels(
|
| 151 |
[p[:5].title() for p in procedures],
|
| 152 |
+
rotation=30,
|
| 153 |
+
ha="right",
|
| 154 |
)
|
| 155 |
ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
|
| 156 |
ax.set_title(self.METRIC_LABELS.get(metric, metric))
|
| 157 |
|
| 158 |
# Add value labels on bars
|
| 159 |
+
for bar, val in zip(bars, values, strict=False):
|
| 160 |
ax.text(
|
| 161 |
+
bar.get_x() + bar.get_width() / 2,
|
| 162 |
+
bar.get_height(),
|
| 163 |
+
f"{val:.3f}",
|
| 164 |
+
ha="center",
|
| 165 |
+
va="bottom",
|
| 166 |
+
fontsize=8,
|
| 167 |
)
|
| 168 |
|
| 169 |
fig.suptitle(title, fontweight="bold")
|
|
|
|
| 200 |
|
| 201 |
if metrics is None:
|
| 202 |
metrics = sorted(
|
| 203 |
+
set.intersection(*(set(v.keys()) for v in experiments.values()))
|
| 204 |
+
& set(self.METRIC_LABELS.keys())
|
|
|
|
| 205 |
)
|
| 206 |
|
| 207 |
n_metrics = len(metrics)
|
|
|
|
| 265 |
plt = self._get_plt()
|
| 266 |
|
| 267 |
fitz_types = sorted(metrics_by_type.keys())
|
| 268 |
+
procedures = sorted(set.union(*(set(v.keys()) for v in metrics_by_type.values())))
|
|
|
|
|
|
|
| 269 |
|
| 270 |
# Build matrix
|
| 271 |
matrix = np.zeros((len(fitz_types), len(procedures)))
|
|
|
|
| 273 |
for j, proc in enumerate(procedures):
|
| 274 |
matrix[i, j] = metrics_by_type[ft].get(proc, 0)
|
| 275 |
|
| 276 |
+
fig, ax = plt.subplots(
|
| 277 |
+
figsize=(max(6, len(procedures) * 1.5), max(4, len(fitz_types) * 0.8))
|
| 278 |
+
)
|
| 279 |
|
| 280 |
cmap = "RdYlGn" if self.METRIC_HIGHER_BETTER.get(metric, True) else "RdYlGn_r"
|
| 281 |
im = ax.imshow(matrix, cmap=cmap, aspect="auto")
|
|
|
|
| 289 |
# Annotate cells
|
| 290 |
for i in range(len(fitz_types)):
|
| 291 |
for j in range(len(procedures)):
|
| 292 |
+
ax.text(
|
| 293 |
+
j,
|
| 294 |
+
i,
|
| 295 |
+
f"{matrix[i, j]:.3f}",
|
| 296 |
+
ha="center",
|
| 297 |
+
va="center",
|
| 298 |
+
fontsize=9,
|
| 299 |
+
color="white" if matrix[i, j] < np.median(matrix) else "black",
|
| 300 |
+
)
|
| 301 |
|
| 302 |
fig.colorbar(im, ax=ax, label=self.METRIC_LABELS.get(metric, metric))
|
| 303 |
|
|
|
|
| 341 |
fig, ax = plt.subplots(figsize=(max(6, len(groups) * 1.2), 5))
|
| 342 |
|
| 343 |
bp = ax.boxplot(
|
| 344 |
+
data,
|
| 345 |
+
patch_artist=True,
|
| 346 |
+
widths=0.6,
|
| 347 |
medianprops={"color": "black", "linewidth": 1.5},
|
| 348 |
)
|
| 349 |
|
| 350 |
colors = [self.COLORS.get(g, "#4C72B0") for g in groups]
|
| 351 |
+
for patch, color in zip(bp["boxes"], colors, strict=False):
|
| 352 |
patch.set_facecolor(color)
|
| 353 |
patch.set_alpha(0.7)
|
| 354 |
|
| 355 |
ax.set_xticklabels(
|
| 356 |
[g.title() for g in groups],
|
| 357 |
+
rotation=30,
|
| 358 |
+
ha="right",
|
| 359 |
)
|
| 360 |
ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
|
| 361 |
|
|
|
|
| 364 |
ax.set_title(title, fontweight="bold")
|
| 365 |
|
| 366 |
# Add sample count annotations
|
| 367 |
+
for i, (_g, vals) in enumerate(zip(groups, data, strict=False)):
|
| 368 |
+
ax.text(
|
| 369 |
+
i + 1,
|
| 370 |
+
ax.get_ylim()[0],
|
| 371 |
+
f"n={len(vals)}",
|
| 372 |
+
ha="center",
|
| 373 |
+
va="bottom",
|
| 374 |
+
fontsize=8,
|
| 375 |
+
color="gray",
|
| 376 |
+
)
|
| 377 |
|
| 378 |
fig.tight_layout()
|
| 379 |
out_path = self.output_dir / filename
|
|
|
|
| 453 |
parts.append(val_str)
|
| 454 |
lines.append(" & ".join(parts) + " \\\\")
|
| 455 |
|
| 456 |
+
lines.extend(
|
| 457 |
+
[
|
| 458 |
+
"\\bottomrule",
|
| 459 |
+
"\\end{tabular}",
|
| 460 |
+
"\\end{table}",
|
| 461 |
+
]
|
| 462 |
+
)
|
| 463 |
|
| 464 |
return "\n".join(lines)
|
landmarkdiff/model_registry.py
CHANGED
|
@@ -139,9 +139,7 @@ class ModelRegistry:
|
|
| 139 |
step = int(parts[-1])
|
| 140 |
|
| 141 |
# Compute size
|
| 142 |
-
size_mb = sum(
|
| 143 |
-
f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()
|
| 144 |
-
) / (1024 * 1024)
|
| 145 |
|
| 146 |
return ModelEntry(
|
| 147 |
name=ckpt_dir.name,
|
|
@@ -195,16 +193,15 @@ class ModelRegistry:
|
|
| 195 |
Returns:
|
| 196 |
Best ModelEntry, or None if no models have the metric.
|
| 197 |
"""
|
| 198 |
-
candidates = [
|
| 199 |
-
m for m in self._models.values()
|
| 200 |
-
if metric in m.metrics
|
| 201 |
-
]
|
| 202 |
if not candidates:
|
| 203 |
return None
|
| 204 |
|
| 205 |
-
return
|
| 206 |
-
|
| 207 |
-
|
|
|
|
|
|
|
| 208 |
|
| 209 |
def get_by_step(self, step: int) -> ModelEntry | None:
|
| 210 |
"""Get a model by its training step."""
|
|
@@ -266,9 +263,7 @@ class ModelRegistry:
|
|
| 266 |
raise KeyError(f"Checkpoint '{name}' not found in registry")
|
| 267 |
|
| 268 |
if use_ema and entry.has_ema:
|
| 269 |
-
return ControlNetModel.from_pretrained(
|
| 270 |
-
str(entry.path / "controlnet_ema")
|
| 271 |
-
)
|
| 272 |
|
| 273 |
# Fallback: load from training state
|
| 274 |
state = self.load(name)
|
|
@@ -356,9 +351,7 @@ class ModelRegistry:
|
|
| 356 |
for metric in sorted(all_metrics):
|
| 357 |
values = [m.metrics[metric] for m in models if metric in m.metrics]
|
| 358 |
if values:
|
| 359 |
-
lines.append(
|
| 360 |
-
f" {metric}: {min(values):.4f} — {max(values):.4f}"
|
| 361 |
-
)
|
| 362 |
|
| 363 |
return "\n".join(lines)
|
| 364 |
|
|
|
|
| 139 |
step = int(parts[-1])
|
| 140 |
|
| 141 |
# Compute size
|
| 142 |
+
size_mb = sum(f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()) / (1024 * 1024)
|
|
|
|
|
|
|
| 143 |
|
| 144 |
return ModelEntry(
|
| 145 |
name=ckpt_dir.name,
|
|
|
|
| 193 |
Returns:
|
| 194 |
Best ModelEntry, or None if no models have the metric.
|
| 195 |
"""
|
| 196 |
+
candidates = [m for m in self._models.values() if metric in m.metrics]
|
|
|
|
|
|
|
|
|
|
| 197 |
if not candidates:
|
| 198 |
return None
|
| 199 |
|
| 200 |
+
return (
|
| 201 |
+
min(candidates, key=lambda m: m.metrics[metric])
|
| 202 |
+
if lower_is_better
|
| 203 |
+
else max(candidates, key=lambda m: m.metrics[metric])
|
| 204 |
+
)
|
| 205 |
|
| 206 |
def get_by_step(self, step: int) -> ModelEntry | None:
|
| 207 |
"""Get a model by its training step."""
|
|
|
|
| 263 |
raise KeyError(f"Checkpoint '{name}' not found in registry")
|
| 264 |
|
| 265 |
if use_ema and entry.has_ema:
|
| 266 |
+
return ControlNetModel.from_pretrained(str(entry.path / "controlnet_ema"))
|
|
|
|
|
|
|
| 267 |
|
| 268 |
# Fallback: load from training state
|
| 269 |
state = self.load(name)
|
|
|
|
| 351 |
for metric in sorted(all_metrics):
|
| 352 |
values = [m.metrics[metric] for m in models if metric in m.metrics]
|
| 353 |
if values:
|
| 354 |
+
lines.append(f" {metric}: {min(values):.4f} — {max(values):.4f}")
|
|
|
|
|
|
|
| 355 |
|
| 356 |
return "\n".join(lines)
|
| 357 |
|
landmarkdiff/postprocess.py
CHANGED
|
@@ -54,13 +54,10 @@ def laplacian_pyramid_blend(
|
|
| 54 |
mask_f = mask.astype(np.float32)
|
| 55 |
if mask_f.max() > 1.0:
|
| 56 |
mask_f = mask_f / 255.0
|
| 57 |
-
if mask_f.ndim == 2
|
| 58 |
-
mask_3ch = np.stack([mask_f] * 3, axis=-1)
|
| 59 |
-
else:
|
| 60 |
-
mask_3ch = mask_f
|
| 61 |
|
| 62 |
# Make dimensions divisible by 2^levels
|
| 63 |
-
factor = 2
|
| 64 |
new_h = (h + factor - 1) // factor * factor
|
| 65 |
new_w = (w + factor - 1) // factor * factor
|
| 66 |
|
|
@@ -232,24 +229,27 @@ def restore_face_codeformer(
|
|
| 232 |
Restored BGR image, or original if CodeFormer unavailable.
|
| 233 |
"""
|
| 234 |
try:
|
|
|
|
| 235 |
from codeformer.basicsr.utils import img2tensor, tensor2img
|
| 236 |
-
from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
|
| 237 |
from codeformer.basicsr.utils.download_util import load_file_from_url
|
| 238 |
-
import
|
| 239 |
from torchvision.transforms.functional import normalize as tv_normalize
|
| 240 |
except ImportError:
|
| 241 |
return image
|
| 242 |
|
| 243 |
try:
|
| 244 |
global _CODEFORMER_MODEL, _CODEFORMER_HELPER
|
| 245 |
-
from codeformer.inference_codeformer import set_realesrgan as _unused # noqa: F401
|
| 246 |
from codeformer.basicsr.archs.codeformer_arch import CodeFormer as CodeFormerArch
|
|
|
|
| 247 |
|
| 248 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 249 |
|
| 250 |
if _CODEFORMER_MODEL is None:
|
| 251 |
model = CodeFormerArch(
|
| 252 |
-
dim_embd=512,
|
|
|
|
|
|
|
|
|
|
| 253 |
connect_list=["32", "64", "128", "256"],
|
| 254 |
).to(device)
|
| 255 |
|
|
@@ -316,16 +316,18 @@ def enhance_background_realesrgan(
|
|
| 316 |
Enhanced BGR image at original resolution.
|
| 317 |
"""
|
| 318 |
try:
|
| 319 |
-
from realesrgan import RealESRGANer
|
| 320 |
-
from basicsr.archs.rrdbnet_arch import RRDBNet
|
| 321 |
import torch
|
|
|
|
|
|
|
| 322 |
except ImportError:
|
| 323 |
return image
|
| 324 |
|
| 325 |
try:
|
| 326 |
global _REALESRGAN_UPSAMPLER
|
| 327 |
if _REALESRGAN_UPSAMPLER is None:
|
| 328 |
-
model = RRDBNet(
|
|
|
|
|
|
|
| 329 |
_REALESRGAN_UPSAMPLER = RealESRGANer(
|
| 330 |
scale=4,
|
| 331 |
model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
|
@@ -345,15 +347,11 @@ def enhance_background_realesrgan(
|
|
| 345 |
mask_f = mask.astype(np.float32)
|
| 346 |
if mask_f.max() > 1.0:
|
| 347 |
mask_f /= 255.0
|
| 348 |
-
if mask_f.ndim == 2
|
| 349 |
-
mask_3ch = np.stack([mask_f] * 3, axis=-1)
|
| 350 |
-
else:
|
| 351 |
-
mask_3ch = mask_f
|
| 352 |
|
| 353 |
# Keep face region from original, use enhanced for background
|
| 354 |
result = (
|
| 355 |
-
image.astype(np.float32) * mask_3ch
|
| 356 |
-
+ enhanced.astype(np.float32) * (1.0 - mask_3ch)
|
| 357 |
).astype(np.uint8)
|
| 358 |
return result
|
| 359 |
except Exception:
|
|
@@ -414,9 +412,10 @@ def verify_identity_arcface(
|
|
| 414 |
orig_emb = orig_faces[0].embedding
|
| 415 |
result_emb = result_faces[0].embedding
|
| 416 |
|
| 417 |
-
sim = float(
|
| 418 |
-
np.
|
| 419 |
-
|
|
|
|
| 420 |
sim = float(np.clip(sim, 0, 1))
|
| 421 |
|
| 422 |
passed = sim >= threshold
|
|
@@ -437,6 +436,7 @@ def verify_identity_arcface(
|
|
| 437 |
def _has_cuda() -> bool:
|
| 438 |
try:
|
| 439 |
import torch
|
|
|
|
| 440 |
return torch.cuda.is_available()
|
| 441 |
except ImportError:
|
| 442 |
return False
|
|
@@ -465,7 +465,7 @@ def histogram_match_skin(
|
|
| 465 |
if not np.any(mask_bool):
|
| 466 |
return source
|
| 467 |
|
| 468 |
-
|
| 469 |
src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
|
| 470 |
ref_lab = cv2.cvtColor(reference, cv2.COLOR_BGR2LAB).astype(np.float32)
|
| 471 |
|
|
@@ -574,20 +574,18 @@ def full_postprocess(
|
|
| 574 |
mask_f = mask.astype(np.float32)
|
| 575 |
if mask_f.max() > 1.0:
|
| 576 |
mask_f /= 255.0
|
| 577 |
-
if mask_f.ndim == 2
|
| 578 |
-
mask_3ch = np.stack([mask_f] * 3, axis=-1)
|
| 579 |
-
else:
|
| 580 |
-
mask_3ch = mask_f
|
| 581 |
composited = (
|
| 582 |
-
result.astype(np.float32) * mask_3ch
|
| 583 |
-
+ original.astype(np.float32) * (1.0 - mask_3ch)
|
| 584 |
).astype(np.uint8)
|
| 585 |
|
| 586 |
# Step 6: Neural identity verification
|
| 587 |
identity_check = {"similarity": -1.0, "passed": True, "message": "skipped"}
|
| 588 |
if verify_identity:
|
| 589 |
identity_check = verify_identity_arcface(
|
| 590 |
-
original,
|
|
|
|
|
|
|
| 591 |
)
|
| 592 |
|
| 593 |
return {
|
|
|
|
| 54 |
mask_f = mask.astype(np.float32)
|
| 55 |
if mask_f.max() > 1.0:
|
| 56 |
mask_f = mask_f / 255.0
|
| 57 |
+
mask_3ch = np.stack([mask_f] * 3, axis=-1) if mask_f.ndim == 2 else mask_f
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
# Make dimensions divisible by 2^levels
|
| 60 |
+
factor = 2**levels
|
| 61 |
new_h = (h + factor - 1) // factor * factor
|
| 62 |
new_w = (w + factor - 1) // factor * factor
|
| 63 |
|
|
|
|
| 229 |
Restored BGR image, or original if CodeFormer unavailable.
|
| 230 |
"""
|
| 231 |
try:
|
| 232 |
+
import torch
|
| 233 |
from codeformer.basicsr.utils import img2tensor, tensor2img
|
|
|
|
| 234 |
from codeformer.basicsr.utils.download_util import load_file_from_url
|
| 235 |
+
from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
|
| 236 |
from torchvision.transforms.functional import normalize as tv_normalize
|
| 237 |
except ImportError:
|
| 238 |
return image
|
| 239 |
|
| 240 |
try:
|
| 241 |
global _CODEFORMER_MODEL, _CODEFORMER_HELPER
|
|
|
|
| 242 |
from codeformer.basicsr.archs.codeformer_arch import CodeFormer as CodeFormerArch
|
| 243 |
+
from codeformer.inference_codeformer import set_realesrgan as _unused # noqa: F401
|
| 244 |
|
| 245 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 246 |
|
| 247 |
if _CODEFORMER_MODEL is None:
|
| 248 |
model = CodeFormerArch(
|
| 249 |
+
dim_embd=512,
|
| 250 |
+
codebook_size=1024,
|
| 251 |
+
n_head=8,
|
| 252 |
+
n_layers=9,
|
| 253 |
connect_list=["32", "64", "128", "256"],
|
| 254 |
).to(device)
|
| 255 |
|
|
|
|
| 316 |
Enhanced BGR image at original resolution.
|
| 317 |
"""
|
| 318 |
try:
|
|
|
|
|
|
|
| 319 |
import torch
|
| 320 |
+
from basicsr.archs.rrdbnet_arch import RRDBNet
|
| 321 |
+
from realesrgan import RealESRGANer
|
| 322 |
except ImportError:
|
| 323 |
return image
|
| 324 |
|
| 325 |
try:
|
| 326 |
global _REALESRGAN_UPSAMPLER
|
| 327 |
if _REALESRGAN_UPSAMPLER is None:
|
| 328 |
+
model = RRDBNet(
|
| 329 |
+
num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
|
| 330 |
+
)
|
| 331 |
_REALESRGAN_UPSAMPLER = RealESRGANer(
|
| 332 |
scale=4,
|
| 333 |
model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
|
|
|
| 347 |
mask_f = mask.astype(np.float32)
|
| 348 |
if mask_f.max() > 1.0:
|
| 349 |
mask_f /= 255.0
|
| 350 |
+
mask_3ch = np.stack([mask_f] * 3, axis=-1) if mask_f.ndim == 2 else mask_f
|
|
|
|
|
|
|
|
|
|
| 351 |
|
| 352 |
# Keep face region from original, use enhanced for background
|
| 353 |
result = (
|
| 354 |
+
image.astype(np.float32) * mask_3ch + enhanced.astype(np.float32) * (1.0 - mask_3ch)
|
|
|
|
| 355 |
).astype(np.uint8)
|
| 356 |
return result
|
| 357 |
except Exception:
|
|
|
|
| 412 |
orig_emb = orig_faces[0].embedding
|
| 413 |
result_emb = result_faces[0].embedding
|
| 414 |
|
| 415 |
+
sim = float(
|
| 416 |
+
np.dot(orig_emb, result_emb)
|
| 417 |
+
/ (np.linalg.norm(orig_emb) * np.linalg.norm(result_emb) + 1e-8)
|
| 418 |
+
)
|
| 419 |
sim = float(np.clip(sim, 0, 1))
|
| 420 |
|
| 421 |
passed = sim >= threshold
|
|
|
|
| 436 |
def _has_cuda() -> bool:
|
| 437 |
try:
|
| 438 |
import torch
|
| 439 |
+
|
| 440 |
return torch.cuda.is_available()
|
| 441 |
except ImportError:
|
| 442 |
return False
|
|
|
|
| 465 |
if not np.any(mask_bool):
|
| 466 |
return source
|
| 467 |
|
| 468 |
+
source.copy()
|
| 469 |
src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
|
| 470 |
ref_lab = cv2.cvtColor(reference, cv2.COLOR_BGR2LAB).astype(np.float32)
|
| 471 |
|
|
|
|
| 574 |
mask_f = mask.astype(np.float32)
|
| 575 |
if mask_f.max() > 1.0:
|
| 576 |
mask_f /= 255.0
|
| 577 |
+
mask_3ch = np.stack([mask_f] * 3, axis=-1) if mask_f.ndim == 2 else mask_f
|
|
|
|
|
|
|
|
|
|
| 578 |
composited = (
|
| 579 |
+
result.astype(np.float32) * mask_3ch + original.astype(np.float32) * (1.0 - mask_3ch)
|
|
|
|
| 580 |
).astype(np.uint8)
|
| 581 |
|
| 582 |
# Step 6: Neural identity verification
|
| 583 |
identity_check = {"similarity": -1.0, "passed": True, "message": "skipped"}
|
| 584 |
if verify_identity:
|
| 585 |
identity_check = verify_identity_arcface(
|
| 586 |
+
original,
|
| 587 |
+
composited,
|
| 588 |
+
threshold=identity_threshold,
|
| 589 |
)
|
| 590 |
|
| 591 |
return {
|
landmarkdiff/py.typed
ADDED
|
File without changes
|
landmarkdiff/safety.py
CHANGED
|
@@ -26,7 +26,6 @@ Usage:
|
|
| 26 |
from __future__ import annotations
|
| 27 |
|
| 28 |
from dataclasses import dataclass, field
|
| 29 |
-
from typing import Optional
|
| 30 |
|
| 31 |
import cv2
|
| 32 |
import numpy as np
|
|
@@ -35,6 +34,7 @@ import numpy as np
|
|
| 35 |
@dataclass
|
| 36 |
class SafetyResult:
|
| 37 |
"""Result of safety validation checks."""
|
|
|
|
| 38 |
passed: bool = True
|
| 39 |
failures: list[str] = field(default_factory=list)
|
| 40 |
warnings: list[str] = field(default_factory=list)
|
|
@@ -124,9 +124,7 @@ class SafetyValidator:
|
|
| 124 |
|
| 125 |
return result
|
| 126 |
|
| 127 |
-
def _check_face_confidence(
|
| 128 |
-
self, result: SafetyResult, confidence: float
|
| 129 |
-
) -> None:
|
| 130 |
"""Check face detection confidence."""
|
| 131 |
if confidence < self.min_face_confidence:
|
| 132 |
result.add_failure(
|
|
@@ -147,14 +145,14 @@ class SafetyValidator:
|
|
| 147 |
"""Check identity preservation using ArcFace similarity."""
|
| 148 |
try:
|
| 149 |
from landmarkdiff.evaluation import compute_identity_similarity
|
|
|
|
| 150 |
sim = compute_identity_similarity(output_image, input_image)
|
| 151 |
result.details["identity_similarity"] = float(sim)
|
| 152 |
|
| 153 |
if sim < self.identity_threshold:
|
| 154 |
result.add_failure(
|
| 155 |
"identity",
|
| 156 |
-
f"Identity similarity {sim:.3f} below threshold "
|
| 157 |
-
f"{self.identity_threshold}",
|
| 158 |
)
|
| 159 |
else:
|
| 160 |
result.add_pass("identity")
|
|
@@ -257,9 +255,7 @@ class SafetyValidator:
|
|
| 257 |
else:
|
| 258 |
result.add_pass("procedure_region")
|
| 259 |
|
| 260 |
-
def _check_output_quality(
|
| 261 |
-
self, result: SafetyResult, output: np.ndarray
|
| 262 |
-
) -> None:
|
| 263 |
"""Check output image quality (not blank, not corrupted)."""
|
| 264 |
if output is None or output.size == 0:
|
| 265 |
result.add_failure("output_quality", "Output image is empty")
|
|
@@ -346,8 +342,7 @@ class SafetyValidator:
|
|
| 346 |
cv2.addWeighted(overlay, opacity, result, 1 - opacity, 0, result)
|
| 347 |
|
| 348 |
# White text
|
| 349 |
-
cv2.putText(result, text, (x, y), font, font_scale,
|
| 350 |
-
(255, 255, 255), thickness, cv2.LINE_AA)
|
| 351 |
|
| 352 |
return result
|
| 353 |
|
|
@@ -371,7 +366,7 @@ class SafetyValidator:
|
|
| 371 |
"procedure": procedure,
|
| 372 |
"intensity": intensity,
|
| 373 |
"disclaimer": "AI-generated surgical prediction for visualization only. "
|
| 374 |
-
|
| 375 |
}
|
| 376 |
|
| 377 |
# Save as sidecar JSON (PNG doesn't have easy EXIF support)
|
|
|
|
| 26 |
from __future__ import annotations
|
| 27 |
|
| 28 |
from dataclasses import dataclass, field
|
|
|
|
| 29 |
|
| 30 |
import cv2
|
| 31 |
import numpy as np
|
|
|
|
| 34 |
@dataclass
|
| 35 |
class SafetyResult:
|
| 36 |
"""Result of safety validation checks."""
|
| 37 |
+
|
| 38 |
passed: bool = True
|
| 39 |
failures: list[str] = field(default_factory=list)
|
| 40 |
warnings: list[str] = field(default_factory=list)
|
|
|
|
| 124 |
|
| 125 |
return result
|
| 126 |
|
| 127 |
+
def _check_face_confidence(self, result: SafetyResult, confidence: float) -> None:
|
|
|
|
|
|
|
| 128 |
"""Check face detection confidence."""
|
| 129 |
if confidence < self.min_face_confidence:
|
| 130 |
result.add_failure(
|
|
|
|
| 145 |
"""Check identity preservation using ArcFace similarity."""
|
| 146 |
try:
|
| 147 |
from landmarkdiff.evaluation import compute_identity_similarity
|
| 148 |
+
|
| 149 |
sim = compute_identity_similarity(output_image, input_image)
|
| 150 |
result.details["identity_similarity"] = float(sim)
|
| 151 |
|
| 152 |
if sim < self.identity_threshold:
|
| 153 |
result.add_failure(
|
| 154 |
"identity",
|
| 155 |
+
f"Identity similarity {sim:.3f} below threshold {self.identity_threshold}",
|
|
|
|
| 156 |
)
|
| 157 |
else:
|
| 158 |
result.add_pass("identity")
|
|
|
|
| 255 |
else:
|
| 256 |
result.add_pass("procedure_region")
|
| 257 |
|
| 258 |
+
def _check_output_quality(self, result: SafetyResult, output: np.ndarray) -> None:
|
|
|
|
|
|
|
| 259 |
"""Check output image quality (not blank, not corrupted)."""
|
| 260 |
if output is None or output.size == 0:
|
| 261 |
result.add_failure("output_quality", "Output image is empty")
|
|
|
|
| 342 |
cv2.addWeighted(overlay, opacity, result, 1 - opacity, 0, result)
|
| 343 |
|
| 344 |
# White text
|
| 345 |
+
cv2.putText(result, text, (x, y), font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
|
|
|
|
| 346 |
|
| 347 |
return result
|
| 348 |
|
|
|
|
| 366 |
"procedure": procedure,
|
| 367 |
"intensity": intensity,
|
| 368 |
"disclaimer": "AI-generated surgical prediction for visualization only. "
|
| 369 |
+
"Not a guarantee of surgical outcome.",
|
| 370 |
}
|
| 371 |
|
| 372 |
# Save as sidecar JSON (PNG doesn't have easy EXIF support)
|
landmarkdiff/synthetic/__init__.py
CHANGED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Synthetic data generation for ControlNet fine-tuning.
|
| 2 |
+
|
| 3 |
+
Modules:
|
| 4 |
+
- pair_generator: Generate training pairs from face images
|
| 5 |
+
- augmentation: Clinical degradation augmentations
|
| 6 |
+
- tps_warp: TPS warping with rigid region preservation
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from landmarkdiff.synthetic.augmentation import apply_clinical_augmentation
|
| 10 |
+
from landmarkdiff.synthetic.pair_generator import (
|
| 11 |
+
TrainingPair,
|
| 12 |
+
generate_pair,
|
| 13 |
+
generate_pairs_from_directory,
|
| 14 |
+
)
|
| 15 |
+
from landmarkdiff.synthetic.tps_warp import warp_image_tps
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"TrainingPair",
|
| 19 |
+
"apply_clinical_augmentation",
|
| 20 |
+
"generate_pair",
|
| 21 |
+
"generate_pairs_from_directory",
|
| 22 |
+
"warp_image_tps",
|
| 23 |
+
]
|
landmarkdiff/synthetic/augmentation.py
CHANGED
|
@@ -7,8 +7,8 @@ Applied from day 1 - domain gap prevention, not afterthought.
|
|
| 7 |
|
| 8 |
from __future__ import annotations
|
| 9 |
|
|
|
|
| 10 |
from dataclasses import dataclass
|
| 11 |
-
from typing import Callable
|
| 12 |
|
| 13 |
import cv2
|
| 14 |
import numpy as np
|
|
@@ -35,7 +35,7 @@ def point_source_lighting(image: np.ndarray, rng: np.random.Generator) -> np.nda
|
|
| 35 |
# Distance-based falloff
|
| 36 |
y_grid, x_grid = np.mgrid[0:h, 0:w].astype(np.float32)
|
| 37 |
dist = np.sqrt((x_grid - lx) ** 2 + (y_grid - ly) ** 2)
|
| 38 |
-
max_dist = np.sqrt(w
|
| 39 |
light_map = 1.0 - (dist / max_dist) * intensity
|
| 40 |
|
| 41 |
light_map = np.clip(light_map, 0.3, 1.0)
|
|
@@ -132,7 +132,7 @@ def vignette(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
|
|
| 132 |
y, x = np.mgrid[0:h, 0:w].astype(np.float32)
|
| 133 |
cx, cy = w / 2, h / 2
|
| 134 |
dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
|
| 135 |
-
max_dist = np.sqrt(cx
|
| 136 |
|
| 137 |
mask = 1 - strength * (dist / max_dist) ** 2
|
| 138 |
mask = np.clip(mask, 0.3, 1.0)
|
|
|
|
| 7 |
|
| 8 |
from __future__ import annotations
|
| 9 |
|
| 10 |
+
from collections.abc import Callable
|
| 11 |
from dataclasses import dataclass
|
|
|
|
| 12 |
|
| 13 |
import cv2
|
| 14 |
import numpy as np
|
|
|
|
| 35 |
# Distance-based falloff
|
| 36 |
y_grid, x_grid = np.mgrid[0:h, 0:w].astype(np.float32)
|
| 37 |
dist = np.sqrt((x_grid - lx) ** 2 + (y_grid - ly) ** 2)
|
| 38 |
+
max_dist = np.sqrt(w**2 + h**2)
|
| 39 |
light_map = 1.0 - (dist / max_dist) * intensity
|
| 40 |
|
| 41 |
light_map = np.clip(light_map, 0.3, 1.0)
|
|
|
|
| 132 |
y, x = np.mgrid[0:h, 0:w].astype(np.float32)
|
| 133 |
cx, cy = w / 2, h / 2
|
| 134 |
dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
|
| 135 |
+
max_dist = np.sqrt(cx**2 + cy**2)
|
| 136 |
|
| 137 |
mask = 1 - strength * (dist / max_dist) ** 2
|
| 138 |
mask = np.clip(mask, 0.3, 1.0)
|
landmarkdiff/synthetic/pair_generator.py
CHANGED
|
@@ -6,33 +6,32 @@ Augmentations on INPUT only, never target.
|
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
|
|
|
| 9 |
from dataclasses import dataclass
|
| 10 |
from pathlib import Path
|
| 11 |
-
from typing import Iterator
|
| 12 |
|
| 13 |
import cv2
|
| 14 |
import numpy as np
|
| 15 |
|
| 16 |
-
from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks, render_landmark_image
|
| 17 |
from landmarkdiff.conditioning import generate_conditioning
|
|
|
|
| 18 |
from landmarkdiff.manipulation import (
|
| 19 |
-
PROCEDURE_LANDMARKS,
|
| 20 |
apply_procedure_preset,
|
| 21 |
)
|
| 22 |
from landmarkdiff.masking import generate_surgical_mask
|
| 23 |
from landmarkdiff.synthetic.augmentation import apply_clinical_augmentation
|
| 24 |
-
from landmarkdiff.synthetic.tps_warp import warp_image_tps
|
| 25 |
|
| 26 |
|
| 27 |
@dataclass(frozen=True)
|
| 28 |
class TrainingPair:
|
| 29 |
"""A single training sample for ControlNet fine-tuning."""
|
| 30 |
|
| 31 |
-
input_image: np.ndarray
|
| 32 |
-
target_image: np.ndarray
|
| 33 |
-
conditioning: np.ndarray
|
| 34 |
-
canny: np.ndarray
|
| 35 |
-
mask: np.ndarray
|
| 36 |
procedure: str
|
| 37 |
intensity: float
|
| 38 |
|
|
@@ -104,10 +103,7 @@ def generate_pairs_from_directory(
|
|
| 104 |
image_dir = Path(image_dir)
|
| 105 |
|
| 106 |
extensions = {".jpg", ".jpeg", ".png", ".webp"}
|
| 107 |
-
image_files = sorted(
|
| 108 |
-
f for f in image_dir.iterdir()
|
| 109 |
-
if f.suffix.lower() in extensions
|
| 110 |
-
)
|
| 111 |
|
| 112 |
if not image_files:
|
| 113 |
raise FileNotFoundError(f"No images found in {image_dir}")
|
|
|
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
+
from collections.abc import Iterator
|
| 10 |
from dataclasses import dataclass
|
| 11 |
from pathlib import Path
|
|
|
|
| 12 |
|
| 13 |
import cv2
|
| 14 |
import numpy as np
|
| 15 |
|
|
|
|
| 16 |
from landmarkdiff.conditioning import generate_conditioning
|
| 17 |
+
from landmarkdiff.landmarks import extract_landmarks, render_landmark_image
|
| 18 |
from landmarkdiff.manipulation import (
|
|
|
|
| 19 |
apply_procedure_preset,
|
| 20 |
)
|
| 21 |
from landmarkdiff.masking import generate_surgical_mask
|
| 22 |
from landmarkdiff.synthetic.augmentation import apply_clinical_augmentation
|
| 23 |
+
from landmarkdiff.synthetic.tps_warp import warp_image_tps
|
| 24 |
|
| 25 |
|
| 26 |
@dataclass(frozen=True)
|
| 27 |
class TrainingPair:
|
| 28 |
"""A single training sample for ControlNet fine-tuning."""
|
| 29 |
|
| 30 |
+
input_image: np.ndarray # augmented input (512x512 BGR)
|
| 31 |
+
target_image: np.ndarray # clean target (512x512 BGR) - TPS-warped original
|
| 32 |
+
conditioning: np.ndarray # landmark rendering (512x512 BGR)
|
| 33 |
+
canny: np.ndarray # canny edge map (512x512 grayscale)
|
| 34 |
+
mask: np.ndarray # feathered surgical mask (512x512 float32)
|
| 35 |
procedure: str
|
| 36 |
intensity: float
|
| 37 |
|
|
|
|
| 103 |
image_dir = Path(image_dir)
|
| 104 |
|
| 105 |
extensions = {".jpg", ".jpeg", ".png", ".webp"}
|
| 106 |
+
image_files = sorted(f for f in image_dir.iterdir() if f.suffix.lower() in extensions)
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
if not image_files:
|
| 109 |
raise FileNotFoundError(f"No images found in {image_dir}")
|
landmarkdiff/synthetic/tps_warp.py
CHANGED
|
@@ -156,7 +156,7 @@ def _solve_tps_weights(
|
|
| 156 |
|
| 157 |
# Build kernel matrix K (vectorized)
|
| 158 |
diff = control_pts[:, np.newaxis, :] - control_pts[np.newaxis, :, :] # (n, n, 2)
|
| 159 |
-
r_mat = np.sqrt((diff
|
| 160 |
K = np.zeros((n, n))
|
| 161 |
nz = r_mat > 0
|
| 162 |
K[nz] = r_mat[nz] ** 2 * np.log(r_mat[nz])
|
|
@@ -205,7 +205,7 @@ def _evaluate_tps(
|
|
| 205 |
# Compute all distances at once: (M, n)
|
| 206 |
dx = batch[:, 0:1] - control_pts[:, 0] # (M, n) via broadcasting
|
| 207 |
dy = batch[:, 1:2] - control_pts[:, 1] # (M, n)
|
| 208 |
-
r = np.sqrt(dx
|
| 209 |
|
| 210 |
# TPS kernel: r^2 * log(r), with r=0 -> 0
|
| 211 |
kernel = np.zeros_like(r)
|
|
@@ -230,9 +230,8 @@ def _compute_rigid_translation(
|
|
| 230 |
inside = []
|
| 231 |
for i, (x, y) in enumerate(src):
|
| 232 |
ix, iy = int(x), int(y)
|
| 233 |
-
if 0 <= ix < width and 0 <= iy < height:
|
| 234 |
-
|
| 235 |
-
inside.append(i)
|
| 236 |
|
| 237 |
if not inside:
|
| 238 |
return np.array([0.0, 0.0])
|
|
|
|
| 156 |
|
| 157 |
# Build kernel matrix K (vectorized)
|
| 158 |
diff = control_pts[:, np.newaxis, :] - control_pts[np.newaxis, :, :] # (n, n, 2)
|
| 159 |
+
r_mat = np.sqrt((diff**2).sum(axis=2)) # (n, n)
|
| 160 |
K = np.zeros((n, n))
|
| 161 |
nz = r_mat > 0
|
| 162 |
K[nz] = r_mat[nz] ** 2 * np.log(r_mat[nz])
|
|
|
|
| 205 |
# Compute all distances at once: (M, n)
|
| 206 |
dx = batch[:, 0:1] - control_pts[:, 0] # (M, n) via broadcasting
|
| 207 |
dy = batch[:, 1:2] - control_pts[:, 1] # (M, n)
|
| 208 |
+
r = np.sqrt(dx**2 + dy**2)
|
| 209 |
|
| 210 |
# TPS kernel: r^2 * log(r), with r=0 -> 0
|
| 211 |
kernel = np.zeros_like(r)
|
|
|
|
| 230 |
inside = []
|
| 231 |
for i, (x, y) in enumerate(src):
|
| 232 |
ix, iy = int(x), int(y)
|
| 233 |
+
if 0 <= ix < width and 0 <= iy < height and mask[iy, ix] > 0:
|
| 234 |
+
inside.append(i)
|
|
|
|
| 235 |
|
| 236 |
if not inside:
|
| 237 |
return np.array([0.0, 0.0])
|
landmarkdiff/validation.py
CHANGED
|
@@ -14,13 +14,11 @@ import json
|
|
| 14 |
import time
|
| 15 |
from pathlib import Path
|
| 16 |
|
| 17 |
-
import cv2
|
| 18 |
import numpy as np
|
| 19 |
import torch
|
| 20 |
-
import torch.nn.functional as F
|
| 21 |
from PIL import Image
|
| 22 |
|
| 23 |
-
from landmarkdiff.evaluation import
|
| 24 |
|
| 25 |
|
| 26 |
class ValidationCallback:
|
|
@@ -116,13 +114,18 @@ class ValidationCallback:
|
|
| 116 |
|
| 117 |
# ControlNet
|
| 118 |
down_samples, mid_sample = controlnet(
|
| 119 |
-
scaled,
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
| 121 |
)
|
| 122 |
|
| 123 |
# UNet with ControlNet residuals
|
| 124 |
noise_pred = unet(
|
| 125 |
-
scaled,
|
|
|
|
|
|
|
| 126 |
down_block_additional_residuals=down_samples,
|
| 127 |
mid_block_additional_residual=mid_sample,
|
| 128 |
).sample
|
|
@@ -136,7 +139,9 @@ class ValidationCallback:
|
|
| 136 |
# Convert to numpy for metrics
|
| 137 |
gen_np = (decoded[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
| 138 |
tgt_np = (target[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
| 139 |
-
cond_np = (conditioning[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(
|
|
|
|
|
|
|
| 140 |
|
| 141 |
# BGR for metrics (our metrics expect BGR)
|
| 142 |
gen_bgr = gen_np[:, :, ::-1].copy()
|
|
@@ -177,7 +182,7 @@ class ValidationCallback:
|
|
| 177 |
if generated_images:
|
| 178 |
grid_rows = []
|
| 179 |
for i in range(0, len(generated_images), 4):
|
| 180 |
-
row_imgs = generated_images[i:i+4]
|
| 181 |
while len(row_imgs) < 4:
|
| 182 |
row_imgs.append(np.zeros_like(generated_images[0]))
|
| 183 |
grid_rows.append(np.hstack(row_imgs))
|
|
@@ -202,6 +207,7 @@ class ValidationCallback:
|
|
| 202 |
|
| 203 |
try:
|
| 204 |
import matplotlib
|
|
|
|
| 205 |
matplotlib.use("Agg")
|
| 206 |
import matplotlib.pyplot as plt
|
| 207 |
except ImportError:
|
|
|
|
| 14 |
import time
|
| 15 |
from pathlib import Path
|
| 16 |
|
|
|
|
| 17 |
import numpy as np
|
| 18 |
import torch
|
|
|
|
| 19 |
from PIL import Image
|
| 20 |
|
| 21 |
+
from landmarkdiff.evaluation import compute_lpips, compute_ssim
|
| 22 |
|
| 23 |
|
| 24 |
class ValidationCallback:
|
|
|
|
| 114 |
|
| 115 |
# ControlNet
|
| 116 |
down_samples, mid_sample = controlnet(
|
| 117 |
+
scaled,
|
| 118 |
+
t,
|
| 119 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 120 |
+
controlnet_cond=conditioning,
|
| 121 |
+
return_dict=False,
|
| 122 |
)
|
| 123 |
|
| 124 |
# UNet with ControlNet residuals
|
| 125 |
noise_pred = unet(
|
| 126 |
+
scaled,
|
| 127 |
+
t,
|
| 128 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 129 |
down_block_additional_residuals=down_samples,
|
| 130 |
mid_block_additional_residual=mid_sample,
|
| 131 |
).sample
|
|
|
|
| 139 |
# Convert to numpy for metrics
|
| 140 |
gen_np = (decoded[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
| 141 |
tgt_np = (target[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
| 142 |
+
cond_np = (conditioning[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(
|
| 143 |
+
np.uint8
|
| 144 |
+
)
|
| 145 |
|
| 146 |
# BGR for metrics (our metrics expect BGR)
|
| 147 |
gen_bgr = gen_np[:, :, ::-1].copy()
|
|
|
|
| 182 |
if generated_images:
|
| 183 |
grid_rows = []
|
| 184 |
for i in range(0, len(generated_images), 4):
|
| 185 |
+
row_imgs = generated_images[i : i + 4]
|
| 186 |
while len(row_imgs) < 4:
|
| 187 |
row_imgs.append(np.zeros_like(generated_images[0]))
|
| 188 |
grid_rows.append(np.hstack(row_imgs))
|
|
|
|
| 207 |
|
| 208 |
try:
|
| 209 |
import matplotlib
|
| 210 |
+
|
| 211 |
matplotlib.use("Agg")
|
| 212 |
import matplotlib.pyplot as plt
|
| 213 |
except ImportError:
|