dreamlessx committed on
Commit
433e26f
·
1 Parent(s): c951530

Add error handling, sync package from public repo, auto-trigger on upload

Browse files

- Wrap all processing functions in try/except with informative error messages
- Add show_error=True for visible error reporting
- Auto-trigger processing when image is uploaded (not just on button click)
- Sync bundled landmarkdiff package to match public repo v0.2.0
- Add .gitignore for __pycache__

.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pyc
app.py CHANGED
@@ -2,6 +2,9 @@
2
 
3
  from __future__ import annotations
4
 
 
 
 
5
  import cv2
6
  import gradio as gr
7
  import numpy as np
@@ -11,6 +14,9 @@ from landmarkdiff.landmarks import extract_landmarks
11
  from landmarkdiff.manipulation import PROCEDURE_LANDMARKS, apply_procedure_preset
12
  from landmarkdiff.masking import generate_surgical_mask
13
 
 
 
 
14
  VERSION = "v0.2.1"
15
 
16
  GITHUB_URL = "https://github.com/dreamlessx/LandmarkDiff-public"
@@ -44,17 +50,31 @@ def mask_composite(warped, original, mask):
44
  PROCEDURES = list(PROCEDURE_LANDMARKS.keys())
45
 
46
 
 
 
 
 
 
 
47
  def process_image(image_rgb, procedure, intensity):
48
  """Process a single image through the TPS pipeline."""
49
  if image_rgb is None:
50
- blank = np.zeros((512, 512, 3), dtype=np.uint8)
51
- return blank, blank, blank, blank, "Upload a face photo to begin."
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
54
- image_bgr = cv2.resize(image_bgr, (512, 512))
55
- image_rgb_512 = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
56
-
57
- face = extract_landmarks(image_bgr)
58
  if face is None:
59
  return (
60
  image_rgb_512,
@@ -64,38 +84,38 @@ def process_image(image_rgb, procedure, intensity):
64
  "No face detected. Try a clearer photo with good lighting.",
65
  )
66
 
67
- # Manipulate landmarks
68
- manipulated = apply_procedure_preset(face, procedure, float(intensity), image_size=512)
69
 
70
- # Generate wireframe (pass width and height as separate keyword args)
71
- wireframe = render_wireframe(manipulated, width=512, height=512)
72
- wireframe_rgb = cv2.cvtColor(wireframe, cv2.COLOR_GRAY2RGB)
73
 
74
- # Generate mask
75
- mask = generate_surgical_mask(face, procedure, 512, 512)
76
- mask_vis = (mask * 255).astype(np.uint8)
77
 
78
- # TPS warp + composite
79
- warped = warp_image_tps(image_bgr, face.pixel_coords, manipulated.pixel_coords)
80
- composited = mask_composite(warped, image_bgr, mask)
81
- composited_rgb = cv2.cvtColor(composited, cv2.COLOR_BGR2RGB)
82
 
83
- # Side by side
84
- side_by_side = np.hstack([image_rgb_512, composited_rgb])
85
 
86
- # Displacement stats
87
- displacement = np.mean(np.linalg.norm(manipulated.pixel_coords - face.pixel_coords, axis=1))
 
88
 
89
- info = (
90
- f"Procedure: {procedure}\n"
91
- f"Intensity: {intensity:.0f}%\n"
92
- f"Landmarks: {len(face.landmarks)}\n"
93
- f"Avg displacement: {displacement:.1f} px\n"
94
- f"Confidence: {face.confidence:.2f}\n"
95
- f"Mode: TPS (CPU)"
96
- )
 
97
 
98
- return wireframe_rgb, mask_vis, composited_rgb, side_by_side, info
 
 
99
 
100
 
101
  def compare_procedures(image_rgb, intensity):
@@ -104,23 +124,28 @@ def compare_procedures(image_rgb, intensity):
104
  blank = np.zeros((512, 512, 3), dtype=np.uint8)
105
  return [blank] * len(PROCEDURES)
106
 
107
- image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
108
- image_bgr = cv2.resize(image_bgr, (512, 512))
109
-
110
- face = extract_landmarks(image_bgr)
111
- if face is None:
112
- rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
113
- return [rgb] * len(PROCEDURES)
114
-
115
- results = []
116
- for proc in PROCEDURES:
117
- manip = apply_procedure_preset(face, proc, float(intensity), image_size=512)
118
- mask = generate_surgical_mask(face, proc, 512, 512)
119
- warped = warp_image_tps(image_bgr, face.pixel_coords, manip.pixel_coords)
120
- comp = mask_composite(warped, image_bgr, mask)
121
- results.append(cv2.cvtColor(comp, cv2.COLOR_BGR2RGB))
122
-
123
- return results
 
 
 
 
 
124
 
125
 
126
  def intensity_sweep(image_rgb, procedure):
@@ -128,27 +153,31 @@ def intensity_sweep(image_rgb, procedure):
128
  if image_rgb is None:
129
  return []
130
 
131
- image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
132
- image_bgr = cv2.resize(image_bgr, (512, 512))
133
-
134
- face = extract_landmarks(image_bgr)
135
- if face is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  return []
137
 
138
- steps = [0, 20, 40, 60, 80, 100]
139
- results = []
140
- for val in steps:
141
- if val == 0:
142
- results.append((cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB), "0%"))
143
- continue
144
- manip = apply_procedure_preset(face, procedure, float(val), image_size=512)
145
- mask = generate_surgical_mask(face, procedure, 512, 512)
146
- warped = warp_image_tps(image_bgr, face.pixel_coords, manip.pixel_coords)
147
- comp = mask_composite(warped, image_bgr, mask)
148
- results.append((cv2.cvtColor(comp, cv2.COLOR_BGR2RGB), f"{val}%"))
149
-
150
- return results
151
-
152
 
153
  # -- Build the procedure table for the description --
154
  _proc_rows = "\n".join(
@@ -248,7 +277,7 @@ with gr.Blocks(
248
  inputs=[input_image, procedure, intensity],
249
  outputs=[out_wireframe, out_mask, out_result, out_sidebyside, info_box],
250
  )
251
- for trigger in [procedure, intensity]:
252
  trigger.change(
253
  fn=process_image,
254
  inputs=[input_image, procedure, intensity],
@@ -310,4 +339,4 @@ with gr.Blocks(
310
  gr.Markdown(FOOTER_MD)
311
 
312
  if __name__ == "__main__":
313
- demo.launch()
 
2
 
3
  from __future__ import annotations
4
 
5
+ import logging
6
+ import traceback
7
+
8
  import cv2
9
  import gradio as gr
10
  import numpy as np
 
14
  from landmarkdiff.manipulation import PROCEDURE_LANDMARKS, apply_procedure_preset
15
  from landmarkdiff.masking import generate_surgical_mask
16
 
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
  VERSION = "v0.2.1"
21
 
22
  GITHUB_URL = "https://github.com/dreamlessx/LandmarkDiff-public"
 
50
  PROCEDURES = list(PROCEDURE_LANDMARKS.keys())
51
 
52
 
53
+ def _error_result(msg):
54
+ """Return a 5-tuple of blanks + error message for the UI."""
55
+ blank = np.zeros((512, 512, 3), dtype=np.uint8)
56
+ return blank, blank, blank, blank, msg
57
+
58
+
59
  def process_image(image_rgb, procedure, intensity):
60
  """Process a single image through the TPS pipeline."""
61
  if image_rgb is None:
62
+ return _error_result("Upload a face photo to begin.")
63
+
64
+ try:
65
+ image_bgr = cv2.cvtColor(np.asarray(image_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
66
+ image_bgr = cv2.resize(image_bgr, (512, 512))
67
+ image_rgb_512 = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
68
+ except Exception as exc:
69
+ logger.error("Image conversion failed: %s", exc)
70
+ return _error_result(f"Image conversion failed: {exc}")
71
+
72
+ try:
73
+ face = extract_landmarks(image_bgr)
74
+ except Exception as exc:
75
+ logger.error("Landmark extraction failed: %s\n%s", exc, traceback.format_exc())
76
+ return _error_result(f"Landmark extraction error: {exc}")
77
 
 
 
 
 
 
78
  if face is None:
79
  return (
80
  image_rgb_512,
 
84
  "No face detected. Try a clearer photo with good lighting.",
85
  )
86
 
87
+ try:
88
+ manipulated = apply_procedure_preset(face, procedure, float(intensity), image_size=512)
89
 
90
+ wireframe = render_wireframe(manipulated, width=512, height=512)
91
+ wireframe_rgb = cv2.cvtColor(wireframe, cv2.COLOR_GRAY2RGB)
 
92
 
93
+ mask = generate_surgical_mask(face, procedure, 512, 512)
94
+ mask_vis = (mask * 255).astype(np.uint8)
 
95
 
96
+ warped = warp_image_tps(image_bgr, face.pixel_coords, manipulated.pixel_coords)
97
+ composited = mask_composite(warped, image_bgr, mask)
98
+ composited_rgb = cv2.cvtColor(composited, cv2.COLOR_BGR2RGB)
 
99
 
100
+ side_by_side = np.hstack([image_rgb_512, composited_rgb])
 
101
 
102
+ displacement = np.mean(
103
+ np.linalg.norm(manipulated.pixel_coords - face.pixel_coords, axis=1)
104
+ )
105
 
106
+ info = (
107
+ f"Procedure: {procedure}\n"
108
+ f"Intensity: {intensity:.0f}%\n"
109
+ f"Landmarks: {len(face.landmarks)}\n"
110
+ f"Avg displacement: {displacement:.1f} px\n"
111
+ f"Confidence: {face.confidence:.2f}\n"
112
+ f"Mode: TPS (CPU)"
113
+ )
114
+ return wireframe_rgb, mask_vis, composited_rgb, side_by_side, info
115
 
116
+ except Exception as exc:
117
+ logger.error("Processing failed: %s\n%s", exc, traceback.format_exc())
118
+ return _error_result(f"Processing error: {exc}")
119
 
120
 
121
  def compare_procedures(image_rgb, intensity):
 
124
  blank = np.zeros((512, 512, 3), dtype=np.uint8)
125
  return [blank] * len(PROCEDURES)
126
 
127
+ try:
128
+ image_bgr = cv2.cvtColor(np.asarray(image_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
129
+ image_bgr = cv2.resize(image_bgr, (512, 512))
130
+
131
+ face = extract_landmarks(image_bgr)
132
+ if face is None:
133
+ rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
134
+ return [rgb] * len(PROCEDURES)
135
+
136
+ results = []
137
+ for proc in PROCEDURES:
138
+ manip = apply_procedure_preset(face, proc, float(intensity), image_size=512)
139
+ mask = generate_surgical_mask(face, proc, 512, 512)
140
+ warped = warp_image_tps(image_bgr, face.pixel_coords, manip.pixel_coords)
141
+ comp = mask_composite(warped, image_bgr, mask)
142
+ results.append(cv2.cvtColor(comp, cv2.COLOR_BGR2RGB))
143
+
144
+ return results
145
+ except Exception as exc:
146
+ logger.error("Compare procedures failed: %s\n%s", exc, traceback.format_exc())
147
+ blank = np.zeros((512, 512, 3), dtype=np.uint8)
148
+ return [blank] * len(PROCEDURES)
149
 
150
 
151
  def intensity_sweep(image_rgb, procedure):
 
153
  if image_rgb is None:
154
  return []
155
 
156
+ try:
157
+ image_bgr = cv2.cvtColor(np.asarray(image_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
158
+ image_bgr = cv2.resize(image_bgr, (512, 512))
159
+
160
+ face = extract_landmarks(image_bgr)
161
+ if face is None:
162
+ return []
163
+
164
+ steps = [0, 20, 40, 60, 80, 100]
165
+ results = []
166
+ for val in steps:
167
+ if val == 0:
168
+ results.append((cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB), "0%"))
169
+ continue
170
+ manip = apply_procedure_preset(face, procedure, float(val), image_size=512)
171
+ mask = generate_surgical_mask(face, procedure, 512, 512)
172
+ warped = warp_image_tps(image_bgr, face.pixel_coords, manip.pixel_coords)
173
+ comp = mask_composite(warped, image_bgr, mask)
174
+ results.append((cv2.cvtColor(comp, cv2.COLOR_BGR2RGB), f"{val}%"))
175
+
176
+ return results
177
+ except Exception as exc:
178
+ logger.error("Intensity sweep failed: %s\n%s", exc, traceback.format_exc())
179
  return []
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  # -- Build the procedure table for the description --
183
  _proc_rows = "\n".join(
 
277
  inputs=[input_image, procedure, intensity],
278
  outputs=[out_wireframe, out_mask, out_result, out_sidebyside, info_box],
279
  )
280
+ for trigger in [input_image, procedure, intensity]:
281
  trigger.change(
282
  fn=process_image,
283
  inputs=[input_image, procedure, intensity],
 
339
  gr.Markdown(FOOTER_MD)
340
 
341
  if __name__ == "__main__":
342
+ demo.launch(show_error=True)
landmarkdiff/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
  """LandmarkDiff: Anatomically-conditioned latent diffusion for facial surgery simulation."""
2
 
3
- __version__ = "0.3.0"
4
 
5
  __all__ = [
6
  "api_client",
 
1
  """LandmarkDiff: Anatomically-conditioned latent diffusion for facial surgery simulation."""
2
 
3
+ __version__ = "0.2.0"
4
 
5
  __all__ = [
6
  "api_client",
landmarkdiff/__main__.py CHANGED
@@ -1,80 +1,148 @@
1
  """CLI entry point for python -m landmarkdiff."""
2
 
 
 
3
  import argparse
4
  import sys
 
 
 
 
 
 
 
 
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- def main():
8
  parser = argparse.ArgumentParser(
9
  prog="landmarkdiff",
10
  description="Facial surgery outcome prediction from clinical photography",
11
  )
12
- parser.add_argument("--version", action="store_true", help="Print version and exit")
13
 
14
  subparsers = parser.add_subparsers(dest="command")
15
 
16
  # inference
17
  infer = subparsers.add_parser("infer", help="Run inference on an image")
18
  infer.add_argument("image", type=str, help="Path to input face image")
19
- infer.add_argument("--procedure", type=str, default="rhinoplasty",
20
- choices=["rhinoplasty", "blepharoplasty", "rhytidectomy", "orthognathic", "brow_lift", "mentoplasty"])
21
- infer.add_argument("--intensity", type=float, default=60.0,
22
- help="Deformation intensity (0-100)")
23
- infer.add_argument("--mode", type=str, default="tps",
24
- choices=["tps", "controlnet", "img2img", "controlnet_ip"])
25
- infer.add_argument("--output", type=str, default="output/")
26
- infer.add_argument("--steps", type=int, default=30)
27
- infer.add_argument("--seed", type=int, default=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # landmarks
30
  lm = subparsers.add_parser("landmarks", help="Extract and visualize landmarks")
31
  lm.add_argument("image", type=str, help="Path to input face image")
32
- lm.add_argument("--output", type=str, default="output/landmarks.png")
 
 
 
 
 
33
 
34
  # demo
35
  subparsers.add_parser("demo", help="Launch Gradio web demo")
36
 
37
  args = parser.parse_args()
38
 
39
- if args.version:
40
- from landmarkdiff import __version__
41
- print(f"landmarkdiff {__version__}")
42
- return
43
-
44
  if args.command is None:
45
  parser.print_help()
46
  return
47
 
48
- if args.command == "infer":
49
- _run_inference(args)
50
- elif args.command == "landmarks":
51
- _run_landmarks(args)
52
- elif args.command == "demo":
53
- _run_demo()
54
-
55
-
56
- def _run_inference(args):
57
- from pathlib import Path
 
 
 
 
58
  import numpy as np
59
  from PIL import Image
 
60
  from landmarkdiff.landmarks import extract_landmarks
61
  from landmarkdiff.manipulation import apply_procedure_preset
62
 
 
 
 
 
 
63
  output_dir = Path(args.output)
64
  output_dir.mkdir(parents=True, exist_ok=True)
65
 
66
- img = Image.open(args.image).convert("RGB").resize((512, 512))
67
  img_array = np.array(img)
68
 
69
  landmarks = extract_landmarks(img_array)
70
  if landmarks is None:
71
- print("no face detected")
72
- sys.exit(1)
73
 
74
  deformed = apply_procedure_preset(landmarks, args.procedure, intensity=args.intensity)
75
 
76
  if args.mode == "tps":
77
  from landmarkdiff.synthetic.tps_warp import warp_image_tps
 
78
  src = landmarks.pixel_coords[:, :2].copy()
79
  dst = deformed.pixel_coords[:, :2].copy()
80
  src[:, 0] *= 512 / landmarks.image_width
@@ -85,8 +153,11 @@ def _run_inference(args):
85
  Image.fromarray(warped).save(str(output_dir / "prediction.png"))
86
  print(f"saved tps result to {output_dir / 'prediction.png'}")
87
  else:
 
 
88
  from landmarkdiff.inference import LandmarkDiffPipeline
89
- pipeline = LandmarkDiffPipeline(mode=args.mode, device="cuda")
 
90
  pipeline.load()
91
  result = pipeline.generate(
92
  img_array,
@@ -99,37 +170,37 @@ def _run_inference(args):
99
  print(f"saved result to {output_dir / 'prediction.png'}")
100
 
101
 
102
- def _run_landmarks(args):
103
- from pathlib import Path
104
  import numpy as np
105
  from PIL import Image
 
106
  from landmarkdiff.landmarks import extract_landmarks, render_landmark_image
107
 
108
- img = np.array(Image.open(args.image).convert("RGB").resize((512, 512)))
 
 
109
  landmarks = extract_landmarks(img)
110
  if landmarks is None:
111
- print("no face detected")
112
- sys.exit(1)
113
 
114
  mesh = render_landmark_image(landmarks, 512, 512)
115
 
116
  output_path = Path(args.output)
117
  output_path.parent.mkdir(parents=True, exist_ok=True)
118
 
119
- from PIL import Image
120
  Image.fromarray(mesh).save(str(output_path))
121
  print(f"saved landmark mesh to {output_path}")
122
  print(f"detected {len(landmarks.landmarks)} landmarks, confidence {landmarks.confidence:.2f}")
123
 
124
 
125
- def _run_demo():
126
  try:
127
  from scripts.app import build_app
 
128
  demo = build_app()
129
  demo.launch()
130
  except ImportError:
131
- print("gradio not installed - run: pip install landmarkdiff[app]")
132
- sys.exit(1)
133
 
134
 
135
  if __name__ == "__main__":
 
1
  """CLI entry point for python -m landmarkdiff."""
2
 
3
+ from __future__ import annotations
4
+
5
  import argparse
6
  import sys
7
+ from pathlib import Path
8
+ from typing import NoReturn
9
+
10
+
11
+ def _error(msg: str) -> NoReturn:
12
+ """Print error to stderr and exit."""
13
+ print(f"error: {msg}", file=sys.stderr)
14
+ sys.exit(1)
15
+
16
 
17
+ def _validate_image_path(path_str: str) -> Path:
18
+ """Validate that the image path exists and looks like an image file."""
19
+ p = Path(path_str)
20
+ if not p.exists():
21
+ _error(f"file not found: {path_str}")
22
+ if not p.is_file():
23
+ _error(f"not a file: {path_str}")
24
+ return p
25
+
26
+
27
+ def main() -> None:
28
+ from landmarkdiff import __version__
29
 
 
30
  parser = argparse.ArgumentParser(
31
  prog="landmarkdiff",
32
  description="Facial surgery outcome prediction from clinical photography",
33
  )
34
+ parser.add_argument("--version", action="version", version=f"landmarkdiff {__version__}")
35
 
36
  subparsers = parser.add_subparsers(dest="command")
37
 
38
  # inference
39
  infer = subparsers.add_parser("infer", help="Run inference on an image")
40
  infer.add_argument("image", type=str, help="Path to input face image")
41
+ infer.add_argument(
42
+ "--procedure",
43
+ type=str,
44
+ default="rhinoplasty",
45
+ choices=[
46
+ "rhinoplasty",
47
+ "blepharoplasty",
48
+ "rhytidectomy",
49
+ "orthognathic",
50
+ "brow_lift",
51
+ "mentoplasty",
52
+ ],
53
+ help="Surgical procedure to simulate (default: rhinoplasty)",
54
+ )
55
+ infer.add_argument(
56
+ "--intensity",
57
+ type=float,
58
+ default=60.0,
59
+ help="Deformation intensity, 0-100 (default: 60)",
60
+ )
61
+ infer.add_argument(
62
+ "--mode",
63
+ type=str,
64
+ default="tps",
65
+ choices=["tps", "controlnet", "img2img", "controlnet_ip"],
66
+ help="Inference mode (default: tps, others require GPU)",
67
+ )
68
+ infer.add_argument(
69
+ "--output",
70
+ type=str,
71
+ default="output/",
72
+ help="Output directory (default: output/)",
73
+ )
74
+ infer.add_argument(
75
+ "--steps",
76
+ type=int,
77
+ default=30,
78
+ help="Number of diffusion steps (default: 30)",
79
+ )
80
+ infer.add_argument(
81
+ "--seed",
82
+ type=int,
83
+ default=None,
84
+ help="Random seed for reproducibility",
85
+ )
86
 
87
  # landmarks
88
  lm = subparsers.add_parser("landmarks", help="Extract and visualize landmarks")
89
  lm.add_argument("image", type=str, help="Path to input face image")
90
+ lm.add_argument(
91
+ "--output",
92
+ type=str,
93
+ default="output/landmarks.png",
94
+ help="Output path for landmark visualization (default: output/landmarks.png)",
95
+ )
96
 
97
  # demo
98
  subparsers.add_parser("demo", help="Launch Gradio web demo")
99
 
100
  args = parser.parse_args()
101
 
 
 
 
 
 
102
  if args.command is None:
103
  parser.print_help()
104
  return
105
 
106
+ try:
107
+ if args.command == "infer":
108
+ _run_inference(args)
109
+ elif args.command == "landmarks":
110
+ _run_landmarks(args)
111
+ elif args.command == "demo":
112
+ _run_demo()
113
+ except KeyboardInterrupt:
114
+ sys.exit(130)
115
+ except Exception as exc:
116
+ _error(str(exc))
117
+
118
+
119
+ def _run_inference(args: argparse.Namespace) -> None:
120
  import numpy as np
121
  from PIL import Image
122
+
123
  from landmarkdiff.landmarks import extract_landmarks
124
  from landmarkdiff.manipulation import apply_procedure_preset
125
 
126
+ if not (0 <= args.intensity <= 100):
127
+ _error(f"intensity must be between 0 and 100, got {args.intensity}")
128
+
129
+ image_path = _validate_image_path(args.image)
130
+
131
  output_dir = Path(args.output)
132
  output_dir.mkdir(parents=True, exist_ok=True)
133
 
134
+ img = Image.open(image_path).convert("RGB").resize((512, 512))
135
  img_array = np.array(img)
136
 
137
  landmarks = extract_landmarks(img_array)
138
  if landmarks is None:
139
+ _error("no face detected in image")
 
140
 
141
  deformed = apply_procedure_preset(landmarks, args.procedure, intensity=args.intensity)
142
 
143
  if args.mode == "tps":
144
  from landmarkdiff.synthetic.tps_warp import warp_image_tps
145
+
146
  src = landmarks.pixel_coords[:, :2].copy()
147
  dst = deformed.pixel_coords[:, :2].copy()
148
  src[:, 0] *= 512 / landmarks.image_width
 
153
  Image.fromarray(warped).save(str(output_dir / "prediction.png"))
154
  print(f"saved tps result to {output_dir / 'prediction.png'}")
155
  else:
156
+ import torch
157
+
158
  from landmarkdiff.inference import LandmarkDiffPipeline
159
+
160
+ pipeline = LandmarkDiffPipeline(mode=args.mode, device=torch.device("cuda"))
161
  pipeline.load()
162
  result = pipeline.generate(
163
  img_array,
 
170
  print(f"saved result to {output_dir / 'prediction.png'}")
171
 
172
 
173
+ def _run_landmarks(args: argparse.Namespace) -> None:
 
174
  import numpy as np
175
  from PIL import Image
176
+
177
  from landmarkdiff.landmarks import extract_landmarks, render_landmark_image
178
 
179
+ image_path = _validate_image_path(args.image)
180
+
181
+ img = np.array(Image.open(image_path).convert("RGB").resize((512, 512)))
182
  landmarks = extract_landmarks(img)
183
  if landmarks is None:
184
+ _error("no face detected in image")
 
185
 
186
  mesh = render_landmark_image(landmarks, 512, 512)
187
 
188
  output_path = Path(args.output)
189
  output_path.parent.mkdir(parents=True, exist_ok=True)
190
 
 
191
  Image.fromarray(mesh).save(str(output_path))
192
  print(f"saved landmark mesh to {output_path}")
193
  print(f"detected {len(landmarks.landmarks)} landmarks, confidence {landmarks.confidence:.2f}")
194
 
195
 
196
+ def _run_demo() -> None:
197
  try:
198
  from scripts.app import build_app
199
+
200
  demo = build_app()
201
  demo.launch()
202
  except ImportError:
203
+ _error("gradio not installed - run: pip install landmarkdiff[app]")
 
204
 
205
 
206
  if __name__ == "__main__":
landmarkdiff/api_client.py CHANGED
@@ -26,7 +26,6 @@ Usage:
26
  from __future__ import annotations
27
 
28
  import base64
29
- import io
30
  from dataclasses import dataclass, field
31
  from pathlib import Path
32
  from typing import Any
@@ -43,8 +42,8 @@ class PredictionResult:
43
  procedure: str
44
  intensity: float
45
  confidence: float = 0.0
46
- landmarks_before: list | None = None
47
- landmarks_after: list | None = None
48
  metrics: dict[str, float] = field(default_factory=dict)
49
  metadata: dict[str, Any] = field(default_factory=dict)
50
 
@@ -70,15 +69,15 @@ class LandmarkDiffClient:
70
  def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 60.0) -> None:
71
  self.base_url = base_url.rstrip("/")
72
  self.timeout = timeout
73
- self._session = None
74
 
75
- def _get_session(self):
76
  """Lazy-initialize requests session."""
77
  if self._session is None:
78
  try:
79
  import requests
80
  except ImportError:
81
- raise ImportError("requests required. Install with: pip install requests")
82
  self._session = requests.Session()
83
  self._session.timeout = self.timeout
84
  return self._session
@@ -218,12 +217,14 @@ class LandmarkDiffClient:
218
  results.append(result)
219
  except Exception as e:
220
  # Create a failed result
221
- results.append(PredictionResult(
222
- output_image=np.zeros((512, 512, 3), dtype=np.uint8),
223
- procedure=procedure,
224
- intensity=intensity,
225
- metadata={"error": str(e), "path": str(path)},
226
- ))
 
 
227
  return results
228
 
229
  def close(self) -> None:
@@ -232,10 +233,10 @@ class LandmarkDiffClient:
232
  self._session.close()
233
  self._session = None
234
 
235
- def __enter__(self):
236
  return self
237
 
238
- def __exit__(self, *args):
239
  self.close()
240
 
241
  def __repr__(self) -> str:
 
26
  from __future__ import annotations
27
 
28
  import base64
 
29
  from dataclasses import dataclass, field
30
  from pathlib import Path
31
  from typing import Any
 
42
  procedure: str
43
  intensity: float
44
  confidence: float = 0.0
45
+ landmarks_before: list[Any] | None = None
46
+ landmarks_after: list[Any] | None = None
47
  metrics: dict[str, float] = field(default_factory=dict)
48
  metadata: dict[str, Any] = field(default_factory=dict)
49
 
 
69
  def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 60.0) -> None:
70
  self.base_url = base_url.rstrip("/")
71
  self.timeout = timeout
72
+ self._session: Any = None
73
 
74
+ def _get_session(self) -> Any:
75
  """Lazy-initialize requests session."""
76
  if self._session is None:
77
  try:
78
  import requests
79
  except ImportError:
80
+ raise ImportError("requests required. Install with: pip install requests") from None
81
  self._session = requests.Session()
82
  self._session.timeout = self.timeout
83
  return self._session
 
217
  results.append(result)
218
  except Exception as e:
219
  # Create a failed result
220
+ results.append(
221
+ PredictionResult(
222
+ output_image=np.zeros((512, 512, 3), dtype=np.uint8),
223
+ procedure=procedure,
224
+ intensity=intensity,
225
+ metadata={"error": str(e), "path": str(path)},
226
+ )
227
+ )
228
  return results
229
 
230
  def close(self) -> None:
 
233
  self._session.close()
234
  self._session = None
235
 
236
+ def __enter__(self) -> LandmarkDiffClient:
237
  return self
238
 
239
+ def __exit__(self, *args: Any) -> None:
240
  self.close()
241
 
242
  def __repr__(self) -> str:
landmarkdiff/arcface_torch.py CHANGED
@@ -29,7 +29,6 @@ from __future__ import annotations
29
  import logging
30
  import warnings
31
  from pathlib import Path
32
- from typing import Optional
33
 
34
  import torch
35
  import torch.nn as nn
@@ -42,6 +41,7 @@ logger = logging.getLogger(__name__)
42
  # Building blocks
43
  # ---------------------------------------------------------------------------
44
 
 
45
  class SEModule(nn.Module):
46
  """Squeeze-and-Excitation channel attention (Hu et al., 2018).
47
 
@@ -79,18 +79,28 @@ class IBasicBlock(nn.Module):
79
  inplanes: int,
80
  planes: int,
81
  stride: int = 1,
82
- downsample: Optional[nn.Module] = None,
83
  use_se: bool = True,
84
  ):
85
  super().__init__()
86
  self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-5)
87
  self.conv1 = nn.Conv2d(
88
- inplanes, planes, kernel_size=3, stride=1, padding=1, bias=False,
 
 
 
 
 
89
  )
90
  self.bn2 = nn.BatchNorm2d(planes, eps=1e-5)
91
  self.prelu = nn.PReLU(planes)
92
  self.conv2 = nn.Conv2d(
93
- planes, planes, kernel_size=3, stride=stride, padding=1, bias=False,
 
 
 
 
 
94
  )
95
  self.bn3 = nn.BatchNorm2d(planes, eps=1e-5)
96
 
@@ -120,6 +130,7 @@ class IBasicBlock(nn.Module):
120
  # Backbone
121
  # ---------------------------------------------------------------------------
122
 
 
123
  class ArcFaceBackbone(nn.Module):
124
  """IResNet-50 backbone for ArcFace identity embeddings.
125
 
@@ -257,7 +268,7 @@ _WEIGHT_URL = (
257
  )
258
 
259
 
260
- def _find_pretrained_weights() -> Optional[Path]:
261
  """Search known locations for pretrained IResNet-50 weights."""
262
  for p in _KNOWN_WEIGHT_PATHS:
263
  if p.exists() and p.suffix == ".pth":
@@ -269,6 +280,7 @@ def _try_download_weights(dest: Path) -> bool:
269
  """Attempt to download pretrained weights from the InsightFace release."""
270
  try:
271
  import urllib.request
 
272
  dest.parent.mkdir(parents=True, exist_ok=True)
273
  logger.info("Downloading ArcFace IResNet-50 weights from %s ...", _WEIGHT_URL)
274
  urllib.request.urlretrieve(_WEIGHT_URL, str(dest))
@@ -281,7 +293,7 @@ def _try_download_weights(dest: Path) -> bool:
281
 
282
  def load_pretrained_weights(
283
  model: ArcFaceBackbone,
284
- weights_path: Optional[str] = None,
285
  download: bool = True,
286
  ) -> bool:
287
  """Load pretrained InsightFace IResNet-50 weights into the model.
@@ -300,7 +312,7 @@ def load_pretrained_weights(
300
  ``True`` if weights were loaded successfully, ``False`` otherwise
301
  (model keeps random initialization).
302
  """
303
- path: Optional[Path] = None
304
 
305
  if weights_path is not None:
306
  path = Path(weights_path)
@@ -368,8 +380,7 @@ def load_pretrained_weights(
368
  return True
369
  except Exception as e:
370
  warnings.warn(
371
- f"Failed to load ArcFace weights from {path}: {e}. "
372
- "Using random initialization.",
373
  UserWarning,
374
  stacklevel=2,
375
  )
@@ -380,6 +391,7 @@ def load_pretrained_weights(
380
  # Differentiable face alignment
381
  # ---------------------------------------------------------------------------
382
 
 
383
  def align_face(
384
  images: torch.Tensor,
385
  size: int = 112,
@@ -402,7 +414,7 @@ def align_face(
402
  """
403
  B, C, H, W = images.shape
404
 
405
- if H == size and W == size:
406
  return images
407
 
408
  # Crop fraction: keep central 80% to remove background padding
@@ -414,13 +426,17 @@ def align_face(
414
  # grid_sample expects coordinates in [-1, 1] where -1 is top-left, +1 is bottom-right
415
  # Center crop: map [-1, 1] output range to [-crop_frac, +crop_frac] input range
416
  theta = torch.zeros(B, 2, 3, device=images.device, dtype=images.dtype)
417
- theta[:, 0, 0] = half_crop # x scale
418
- theta[:, 1, 1] = half_crop # y scale
419
  # translation stays 0 (centered)
420
 
421
  grid = F.affine_grid(theta, [B, C, size, size], align_corners=False)
422
  aligned = F.grid_sample(
423
- images, grid, mode="bilinear", padding_mode="border", align_corners=False,
 
 
 
 
424
  )
425
  return aligned
426
 
@@ -444,7 +460,10 @@ def align_face_no_crop(
444
  if images.shape[-2] == size and images.shape[-1] == size:
445
  return images
446
  return F.interpolate(
447
- images, size=(size, size), mode="bilinear", align_corners=False,
 
 
 
448
  )
449
 
450
 
@@ -452,6 +471,7 @@ def align_face_no_crop(
452
  # ArcFaceLoss: differentiable identity preservation loss
453
  # ---------------------------------------------------------------------------
454
 
 
455
  class ArcFaceLoss(nn.Module):
456
  """Differentiable identity loss using PyTorch-native ArcFace.
457
 
@@ -474,8 +494,8 @@ class ArcFaceLoss(nn.Module):
474
 
475
  def __init__(
476
  self,
477
- device: Optional[torch.device] = None,
478
- weights_path: Optional[str] = None,
479
  crop_face: bool = True,
480
  ):
481
  """
@@ -527,10 +547,7 @@ class ArcFaceLoss(nn.Module):
527
  Returns:
528
  (B, 3, 112, 112) in [-1, 1].
529
  """
530
- if self.crop_face:
531
- x = align_face(images, size=112)
532
- else:
533
- x = align_face_no_crop(images, size=112)
534
 
535
  # Normalize from [0, 1] to [-1, 1]
536
  x = x * 2.0 - 1.0
@@ -662,9 +679,10 @@ class ArcFaceLoss(nn.Module):
662
  # Convenience: create a pre-configured loss instance
663
  # ---------------------------------------------------------------------------
664
 
 
665
  def create_arcface_loss(
666
- device: Optional[torch.device] = None,
667
- weights_path: Optional[str] = None,
668
  ) -> ArcFaceLoss:
669
  """Factory function for creating an ArcFaceLoss with sensible defaults.
670
 
 
29
  import logging
30
  import warnings
31
  from pathlib import Path
 
32
 
33
  import torch
34
  import torch.nn as nn
 
41
  # Building blocks
42
  # ---------------------------------------------------------------------------
43
 
44
+
45
  class SEModule(nn.Module):
46
  """Squeeze-and-Excitation channel attention (Hu et al., 2018).
47
 
 
79
  inplanes: int,
80
  planes: int,
81
  stride: int = 1,
82
+ downsample: nn.Module | None = None,
83
  use_se: bool = True,
84
  ):
85
  super().__init__()
86
  self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-5)
87
  self.conv1 = nn.Conv2d(
88
+ inplanes,
89
+ planes,
90
+ kernel_size=3,
91
+ stride=1,
92
+ padding=1,
93
+ bias=False,
94
  )
95
  self.bn2 = nn.BatchNorm2d(planes, eps=1e-5)
96
  self.prelu = nn.PReLU(planes)
97
  self.conv2 = nn.Conv2d(
98
+ planes,
99
+ planes,
100
+ kernel_size=3,
101
+ stride=stride,
102
+ padding=1,
103
+ bias=False,
104
  )
105
  self.bn3 = nn.BatchNorm2d(planes, eps=1e-5)
106
 
 
130
  # Backbone
131
  # ---------------------------------------------------------------------------
132
 
133
+
134
  class ArcFaceBackbone(nn.Module):
135
  """IResNet-50 backbone for ArcFace identity embeddings.
136
 
 
268
  )
269
 
270
 
271
+ def _find_pretrained_weights() -> Path | None:
272
  """Search known locations for pretrained IResNet-50 weights."""
273
  for p in _KNOWN_WEIGHT_PATHS:
274
  if p.exists() and p.suffix == ".pth":
 
280
  """Attempt to download pretrained weights from the InsightFace release."""
281
  try:
282
  import urllib.request
283
+
284
  dest.parent.mkdir(parents=True, exist_ok=True)
285
  logger.info("Downloading ArcFace IResNet-50 weights from %s ...", _WEIGHT_URL)
286
  urllib.request.urlretrieve(_WEIGHT_URL, str(dest))
 
293
 
294
  def load_pretrained_weights(
295
  model: ArcFaceBackbone,
296
+ weights_path: str | None = None,
297
  download: bool = True,
298
  ) -> bool:
299
  """Load pretrained InsightFace IResNet-50 weights into the model.
 
312
  ``True`` if weights were loaded successfully, ``False`` otherwise
313
  (model keeps random initialization).
314
  """
315
+ path: Path | None = None
316
 
317
  if weights_path is not None:
318
  path = Path(weights_path)
 
380
  return True
381
  except Exception as e:
382
  warnings.warn(
383
+ f"Failed to load ArcFace weights from {path}: {e}. Using random initialization.",
 
384
  UserWarning,
385
  stacklevel=2,
386
  )
 
391
  # Differentiable face alignment
392
  # ---------------------------------------------------------------------------
393
 
394
+
395
  def align_face(
396
  images: torch.Tensor,
397
  size: int = 112,
 
414
  """
415
  B, C, H, W = images.shape
416
 
417
+ if size == H and size == W:
418
  return images
419
 
420
  # Crop fraction: keep central 80% to remove background padding
 
426
  # grid_sample expects coordinates in [-1, 1] where -1 is top-left, +1 is bottom-right
427
  # Center crop: map [-1, 1] output range to [-crop_frac, +crop_frac] input range
428
  theta = torch.zeros(B, 2, 3, device=images.device, dtype=images.dtype)
429
+ theta[:, 0, 0] = half_crop # x scale
430
+ theta[:, 1, 1] = half_crop # y scale
431
  # translation stays 0 (centered)
432
 
433
  grid = F.affine_grid(theta, [B, C, size, size], align_corners=False)
434
  aligned = F.grid_sample(
435
+ images,
436
+ grid,
437
+ mode="bilinear",
438
+ padding_mode="border",
439
+ align_corners=False,
440
  )
441
  return aligned
442
 
 
460
  if images.shape[-2] == size and images.shape[-1] == size:
461
  return images
462
  return F.interpolate(
463
+ images,
464
+ size=(size, size),
465
+ mode="bilinear",
466
+ align_corners=False,
467
  )
468
 
469
 
 
471
  # ArcFaceLoss: differentiable identity preservation loss
472
  # ---------------------------------------------------------------------------
473
 
474
+
475
  class ArcFaceLoss(nn.Module):
476
  """Differentiable identity loss using PyTorch-native ArcFace.
477
 
 
494
 
495
  def __init__(
496
  self,
497
+ device: torch.device | None = None,
498
+ weights_path: str | None = None,
499
  crop_face: bool = True,
500
  ):
501
  """
 
547
  Returns:
548
  (B, 3, 112, 112) in [-1, 1].
549
  """
550
+ x = align_face(images, size=112) if self.crop_face else align_face_no_crop(images, size=112)
 
 
 
551
 
552
  # Normalize from [0, 1] to [-1, 1]
553
  x = x * 2.0 - 1.0
 
679
  # Convenience: create a pre-configured loss instance
680
  # ---------------------------------------------------------------------------
681
 
682
+
683
  def create_arcface_loss(
684
+ device: torch.device | None = None,
685
+ weights_path: str | None = None,
686
  ) -> ArcFaceLoss:
687
  """Factory function for creating an ArcFaceLoss with sensible defaults.
688
 
landmarkdiff/audit.py CHANGED
@@ -116,12 +116,10 @@ class AuditReporter:
116
  if case.identity_sim > 0:
117
  by_proc[proc]["id_sims"].append(case.identity_sim)
118
 
119
- for proc, stats in by_proc.items():
120
  stats["pass_rate"] = stats["passed"] / max(stats["total"], 1)
121
  stats["mean_identity_sim"] = (
122
- sum(stats["id_sims"]) / len(stats["id_sims"])
123
- if stats["id_sims"]
124
- else 0.0
125
  )
126
  del stats["id_sims"]
127
 
@@ -137,12 +135,10 @@ class AuditReporter:
137
  if case.identity_sim > 0:
138
  by_fitz[ft]["id_sims"].append(case.identity_sim)
139
 
140
- for ft, stats in by_fitz.items():
141
  stats["pass_rate"] = stats["passed"] / max(stats["total"], 1)
142
  stats["mean_identity_sim"] = (
143
- sum(stats["id_sims"]) / len(stats["id_sims"])
144
- if stats["id_sims"]
145
- else 0.0
146
  )
147
  del stats["id_sims"]
148
 
@@ -268,7 +264,7 @@ class AuditReporter:
268
  f"<td>{c.procedure.title()}</td>"
269
  f"<td>{c.fitzpatrick_type}</td>"
270
  f"<td>{c.identity_sim:.4f}</td>"
271
- f'<td>{"WARN" if c.safety_passed else "FAIL"}</td>'
272
  f"<td>{issues}</td>"
273
  f"</tr>\n"
274
  )
 
116
  if case.identity_sim > 0:
117
  by_proc[proc]["id_sims"].append(case.identity_sim)
118
 
119
+ for _proc, stats in by_proc.items():
120
  stats["pass_rate"] = stats["passed"] / max(stats["total"], 1)
121
  stats["mean_identity_sim"] = (
122
+ sum(stats["id_sims"]) / len(stats["id_sims"]) if stats["id_sims"] else 0.0
 
 
123
  )
124
  del stats["id_sims"]
125
 
 
135
  if case.identity_sim > 0:
136
  by_fitz[ft]["id_sims"].append(case.identity_sim)
137
 
138
+ for _ft, stats in by_fitz.items():
139
  stats["pass_rate"] = stats["passed"] / max(stats["total"], 1)
140
  stats["mean_identity_sim"] = (
141
+ sum(stats["id_sims"]) / len(stats["id_sims"]) if stats["id_sims"] else 0.0
 
 
142
  )
143
  del stats["id_sims"]
144
 
 
264
  f"<td>{c.procedure.title()}</td>"
265
  f"<td>{c.fitzpatrick_type}</td>"
266
  f"<td>{c.identity_sim:.4f}</td>"
267
+ f"<td>{'WARN' if c.safety_passed else 'FAIL'}</td>"
268
  f"<td>{issues}</td>"
269
  f"</tr>\n"
270
  )
landmarkdiff/augmentation.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training data augmentation pipeline for LandmarkDiff.
2
+
3
+ Provides domain-specific augmentations that maintain landmark consistency:
4
+ - Geometric: flip, rotation, affine (landmarks co-transformed)
5
+ - Photometric: color jitter, brightness, contrast (applied to images only)
6
+ - Skin-tone augmentation: ITA-space perturbation for Fitzpatrick balance
7
+ - Conditioning augmentation: noise injection, dropout for robustness
8
+
9
+ All augmentations preserve the correspondence between:
10
+ input_image ↔ conditioning_image ↔ target_image ↔ mask
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from dataclasses import dataclass
16
+
17
+ import cv2
18
+ import numpy as np
19
+
20
+
21
@dataclass
class AugmentationConfig:
    """Tunable knobs for the training augmentation pipeline.

    Grouped by the stage of the pipeline that consumes them; all defaults
    are deliberately mild so augmented samples stay close to the source
    distribution.
    """

    # Geometric transforms — applied to images AND landmarks together.
    random_flip: bool = True  # 50% chance of horizontal mirror
    random_rotation_deg: float = 5.0  # max |rotation| in degrees
    random_scale: tuple[float, float] = (0.95, 1.05)  # uniform zoom range
    random_translate: float = 0.02  # fraction of image size

    # Photometric transforms — images only, conditioning/mask untouched.
    brightness_range: tuple[float, float] = (0.9, 1.1)
    contrast_range: tuple[float, float] = (0.9, 1.1)
    saturation_range: tuple[float, float] = (0.9, 1.1)
    hue_shift_range: float = 5.0  # degrees

    # Conditioning-image augmentation (robustness to imperfect conditioning).
    conditioning_dropout_prob: float = 0.1
    conditioning_noise_std: float = 0.02

    # Skin-tone augmentation.
    ita_perturbation_std: float = 3.0  # ITA angle noise

    # RNG seed; None means non-deterministic draws.
    seed: int | None = None
45
+
46
+
47
def augment_training_sample(
    input_image: np.ndarray,
    target_image: np.ndarray,
    conditioning: np.ndarray,
    mask: np.ndarray,
    landmarks_src: np.ndarray | None = None,
    landmarks_dst: np.ndarray | None = None,
    config: AugmentationConfig | None = None,
    rng: np.random.Generator | None = None,
) -> dict[str, np.ndarray]:
    """Apply consistent augmentations to a training sample.

    All spatial transforms are applied to images AND landmarks together
    so correspondence is preserved.

    Args:
        input_image: (H, W, 3) original face image (uint8 BGR).
        target_image: (H, W, 3) target face image (uint8 BGR).
        conditioning: (H, W, 3) conditioning image (uint8).
        mask: (H, W) or (H, W, 1) float32 mask.
        landmarks_src: (N, 2) normalized [0,1] source landmark coords.
        landmarks_dst: (N, 2) normalized [0,1] target landmark coords.
        config: Augmentation parameters.
        rng: Random generator for reproducibility.

    Returns:
        Dict with augmented versions of all inputs.
    """
    if config is None:
        config = AugmentationConfig()
    if rng is None:
        rng = np.random.default_rng(config.seed)

    # Work on copies so the caller's arrays are never mutated.
    h, w = input_image.shape[:2]
    out_input = input_image.copy()
    out_target = target_image.copy()
    out_cond = conditioning.copy()
    out_mask = mask.copy()
    out_lm_src = landmarks_src.copy() if landmarks_src is not None else None
    out_lm_dst = landmarks_dst.copy() if landmarks_dst is not None else None

    # --- Geometric augmentations (applied to all) ---

    # Random horizontal flip
    if config.random_flip and rng.random() < 0.5:
        # ascontiguousarray: reversed views are non-contiguous, which cv2 rejects.
        out_input = np.ascontiguousarray(out_input[:, ::-1])
        out_target = np.ascontiguousarray(out_target[:, ::-1])
        out_cond = np.ascontiguousarray(out_cond[:, ::-1])
        out_mask = np.ascontiguousarray(
            out_mask[:, ::-1] if out_mask.ndim == 2 else out_mask[:, ::-1, :]
        )
        # Mirror landmarks' x coordinate in normalized space.
        if out_lm_src is not None:
            out_lm_src[:, 0] = 1.0 - out_lm_src[:, 0]
        if out_lm_dst is not None:
            out_lm_dst[:, 0] = 1.0 - out_lm_dst[:, 0]

    # Random rotation + scale + translate
    # NOTE(review): translation is only applied when this branch fires, i.e.
    # when rotation or scale is enabled; `random_translate` alone is ignored
    # by this condition — confirm this is intentional.
    if config.random_rotation_deg > 0 or config.random_scale != (1.0, 1.0):
        angle = rng.uniform(-config.random_rotation_deg, config.random_rotation_deg)
        scale = rng.uniform(config.random_scale[0], config.random_scale[1])
        tx = rng.uniform(-config.random_translate, config.random_translate) * w
        ty = rng.uniform(-config.random_translate, config.random_translate) * h

        center = (w / 2, h / 2)
        M = cv2.getRotationMatrix2D(center, angle, scale)
        M[0, 2] += tx
        M[1, 2] += ty

        # Faces reflect at the border; conditioning/mask pad with zeros so
        # warped-in regions read as "no signal" rather than mirrored content.
        out_input = cv2.warpAffine(out_input, M, (w, h), borderMode=cv2.BORDER_REFLECT_101)
        out_target = cv2.warpAffine(out_target, M, (w, h), borderMode=cv2.BORDER_REFLECT_101)
        out_cond = cv2.warpAffine(
            out_cond, M, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=0
        )
        mask_2d = out_mask if out_mask.ndim == 2 else out_mask[:, :, 0]
        mask_2d = cv2.warpAffine(mask_2d, M, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=0)
        out_mask = mask_2d if out_mask.ndim == 2 else mask_2d[:, :, np.newaxis]

        # Transform landmarks with the same affine so correspondence holds.
        if out_lm_src is not None:
            out_lm_src = _transform_landmarks(out_lm_src, M, w, h)
        if out_lm_dst is not None:
            out_lm_dst = _transform_landmarks(out_lm_dst, M, w, h)

    # --- Photometric augmentations (images only, not conditioning/mask) ---
    # The same factor is applied to input and target so their relative
    # appearance stays consistent.

    # Brightness
    b_factor = rng.uniform(config.brightness_range[0], config.brightness_range[1])
    out_input = np.clip(out_input.astype(np.float32) * b_factor, 0, 255).astype(np.uint8)
    out_target = np.clip(out_target.astype(np.float32) * b_factor, 0, 255).astype(np.uint8)

    # Contrast: scale deviation around each image's own mean.
    c_factor = rng.uniform(config.contrast_range[0], config.contrast_range[1])
    mean_in = out_input.mean()
    mean_tgt = out_target.mean()
    out_input = np.clip(
        (out_input.astype(np.float32) - mean_in) * c_factor + mean_in, 0, 255
    ).astype(np.uint8)
    out_target = np.clip(
        (out_target.astype(np.float32) - mean_tgt) * c_factor + mean_tgt, 0, 255
    ).astype(np.uint8)

    # Saturation (in HSV space); skip the no-op case to avoid two conversions.
    s_factor = rng.uniform(config.saturation_range[0], config.saturation_range[1])
    if abs(s_factor - 1.0) > 1e-4:
        out_input = _adjust_saturation(out_input, s_factor)
        out_target = _adjust_saturation(out_target, s_factor)

    # Hue shift
    if config.hue_shift_range > 0:
        hue_delta = rng.uniform(-config.hue_shift_range, config.hue_shift_range)
        if abs(hue_delta) > 0.1:
            out_input = _shift_hue(out_input, hue_delta)
            out_target = _shift_hue(out_target, hue_delta)

    # --- Conditioning augmentation ---

    # Conditioning dropout (replace with zeros to learn unconditional)
    if config.conditioning_dropout_prob > 0 and rng.random() < config.conditioning_dropout_prob:
        out_cond = np.zeros_like(out_cond)

    # Conditioning noise; std is expressed as a fraction of the 0-255 range.
    if config.conditioning_noise_std > 0:
        noise = rng.normal(0, config.conditioning_noise_std * 255, out_cond.shape)
        out_cond = np.clip(out_cond.astype(np.float32) + noise, 0, 255).astype(np.uint8)

    result = {
        "input_image": out_input,
        "target_image": out_target,
        "conditioning": out_cond,
        "mask": out_mask,
    }
    # Landmark keys are only present when the corresponding input was given.
    if out_lm_src is not None:
        result["landmarks_src"] = out_lm_src
    if out_lm_dst is not None:
        result["landmarks_dst"] = out_lm_dst

    return result
184
+
185
+
186
+ def _transform_landmarks(landmarks: np.ndarray, M: np.ndarray, w: int, h: int) -> np.ndarray:
187
+ """Transform normalized landmarks with an affine matrix."""
188
+ # Convert to pixel coords
189
+ px = landmarks.copy()
190
+ px[:, 0] *= w
191
+ px[:, 1] *= h
192
+
193
+ # Apply affine transform
194
+ ones = np.ones((px.shape[0], 1))
195
+ px_h = np.hstack([px, ones]) # (N, 3)
196
+ transformed = (M @ px_h.T).T # (N, 2)
197
+
198
+ # Back to normalized
199
+ transformed[:, 0] /= w
200
+ transformed[:, 1] /= h
201
+ return np.clip(transformed, 0.0, 1.0)
202
+
203
+
204
def _adjust_saturation(img: np.ndarray, factor: float) -> np.ndarray:
    """Scale the saturation channel of a BGR uint8 image by *factor*."""
    # Float32 workspace so scaling does not wrap around in uint8.
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
    scaled = hsv[:, :, 1] * factor
    hsv[:, :, 1] = np.clip(scaled, 0, 255)
    return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)
209
+
210
+
211
def _shift_hue(img: np.ndarray, delta_deg: float) -> np.ndarray:
    """Rotate the hue of a BGR uint8 image by *delta_deg* degrees."""
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
    # OpenCV stores hue in [0, 180] (half-degrees), hence the /2 and mod 180.
    shifted = hsv[:, :, 0] + delta_deg / 2
    hsv[:, :, 0] = shifted % 180
    return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)
217
+
218
+
219
def augment_skin_tone(
    image: np.ndarray,
    ita_delta: float = 0.0,
) -> np.ndarray:
    """Augment skin tone by shifting in L*a*b* space.

    This helps balance Fitzpatrick representation in training by
    simulating different skin tones from existing samples.

    Args:
        image: (H, W, 3) BGR uint8 image.
        ita_delta: ITA angle shift (positive = lighter, negative = darker).

    Returns:
        Augmented image with shifted skin tone.
    """
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)

    # ITA = arctan((L-50)/b): a positive ITA delta maps (approximately)
    # to a positive lightness shift.
    lightness_shift = ita_delta * 0.5
    lab[:, :, 0] = np.clip(lab[:, :, 0] + lightness_shift, 0, 255)

    # Small opposing b-channel shift keeps the tone change looking natural.
    yellow_shift = -ita_delta * 0.15
    lab[:, :, 2] = np.clip(lab[:, :, 2] + yellow_shift, 0, 255)

    return cv2.cvtColor(lab.astype(np.uint8), cv2.COLOR_LAB2BGR)
247
+
248
+
249
class FitzpatrickBalancer:
    """Oversample underrepresented Fitzpatrick types during training.

    Maintains per-type counts and generates sampling weights to ensure
    equitable training across all skin types.
    """

    def __init__(self, target_distribution: dict[str, float] | None = None):
        """Initialize balancer.

        Args:
            target_distribution: Target fraction per type. Defaults to
                uniform over the six Fitzpatrick types I-VI.
        """
        self.target = target_distribution or {
            "I": 1 / 6,
            "II": 1 / 6,
            "III": 1 / 6,
            "IV": 1 / 6,
            "V": 1 / 6,
            "VI": 1 / 6,
        }
        # Observed sample count per Fitzpatrick type.
        self._counts: dict[str, int] = {}

    def register_sample(self, fitz_type: str) -> None:
        """Register a sample's Fitzpatrick type."""
        self._counts[fitz_type] = self._counts.get(fitz_type, 0) + 1

    def get_sampling_weights(self, fitz_types: list[str]) -> np.ndarray:
        """Compute sampling weights for a list of samples.

        Returns weights inversely proportional to type frequency,
        so underrepresented types get upsampled.

        Args:
            fitz_types: Fitzpatrick type label per sample.

        Returns:
            (len(fitz_types),) float64 probability vector summing to 1.
            Empty input returns an empty array.
        """
        # Fix: an empty input previously divided by w.sum() == 0,
        # yielding an invalid (NaN) distribution.
        if not fitz_types:
            return np.zeros(0, dtype=np.float64)

        total = sum(self._counts.values()) or 1
        weights = []
        for ft in fitz_types:
            count = self._counts.get(ft, 1)
            freq = count / total
            target_freq = self.target.get(ft, 1 / 6)
            # Weight = target / actual (capped at 5x for stability)
            w = min(target_freq / max(freq, 1e-6), 5.0)
            weights.append(w)

        w = np.array(weights, dtype=np.float64)
        return w / w.sum()  # normalize to probability distribution
landmarkdiff/benchmark.py CHANGED
@@ -63,17 +63,19 @@ class InferenceBenchmark:
63
  if throughput_fps == 0.0 and latency_ms > 0:
64
  throughput_fps = 1000.0 / latency_ms * batch_size
65
 
66
- self.results.append(BenchmarkResult(
67
- config_name=config_name,
68
- latency_ms=latency_ms,
69
- throughput_fps=throughput_fps,
70
- vram_gb=vram_gb,
71
- batch_size=batch_size,
72
- resolution=resolution,
73
- num_inference_steps=num_inference_steps,
74
- device=device,
75
- metadata=metadata,
76
- ))
 
 
77
 
78
  def mean_latency(self, config_name: str | None = None) -> float:
79
  """Mean latency in ms, optionally filtered by config."""
@@ -124,7 +126,10 @@ class InferenceBenchmark:
124
  if not configs:
125
  return "No benchmark results."
126
 
127
- header = f"{'Config':>20s} | {'Mean(ms)':>10s} | {'P99(ms)':>10s} | {'FPS':>8s} | {'VRAM(GB)':>8s} | {'N':>4s}"
 
 
 
128
  lines = [
129
  f"Inference Benchmark: {self.model_name}",
130
  header,
 
63
  if throughput_fps == 0.0 and latency_ms > 0:
64
  throughput_fps = 1000.0 / latency_ms * batch_size
65
 
66
+ self.results.append(
67
+ BenchmarkResult(
68
+ config_name=config_name,
69
+ latency_ms=latency_ms,
70
+ throughput_fps=throughput_fps,
71
+ vram_gb=vram_gb,
72
+ batch_size=batch_size,
73
+ resolution=resolution,
74
+ num_inference_steps=num_inference_steps,
75
+ device=device,
76
+ metadata=metadata,
77
+ )
78
+ )
79
 
80
  def mean_latency(self, config_name: str | None = None) -> float:
81
  """Mean latency in ms, optionally filtered by config."""
 
126
  if not configs:
127
  return "No benchmark results."
128
 
129
+ header = (
130
+ f"{'Config':>20s} | {'Mean(ms)':>10s} | {'P99(ms)':>10s}"
131
+ f" | {'FPS':>8s} | {'VRAM(GB)':>8s} | {'N':>4s}"
132
+ )
133
  lines = [
134
  f"Inference Benchmark: {self.model_name}",
135
  header,
landmarkdiff/checkpoint_manager.py CHANGED
@@ -166,9 +166,7 @@ class CheckpointManager:
166
  torch.save(state, ckpt_dir / "training_state.pt")
167
 
168
  # Compute checkpoint size
169
- size_mb = sum(
170
- f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()
171
- ) / (1024 * 1024)
172
 
173
  # Create metadata
174
  meta = CheckpointMetadata(
@@ -216,7 +214,7 @@ class CheckpointManager:
216
  entries.sort(key=lambda x: x[1], reverse=not self.lower_is_better)
217
 
218
  # Mark best
219
- best_names = {e[0] for e in entries[:self.keep_best]}
220
  for name, meta in self._index["checkpoints"].items():
221
  meta["is_best"] = name in best_names
222
 
@@ -245,11 +243,11 @@ class CheckpointManager:
245
  val = meta.get("metrics", {}).get(self.metric)
246
  if val is None:
247
  continue
248
- if best_val is None:
249
- best, best_val = name, val
250
- elif self.lower_is_better and val < best_val:
251
- best, best_val = name, val
252
- elif not self.lower_is_better and val > best_val:
253
  best, best_val = name, val
254
  return best
255
 
@@ -280,7 +278,7 @@ class CheckpointManager:
280
  keep = set()
281
 
282
  # Keep latest
283
- for name in all_names[-self.keep_latest:]:
284
  keep.add(name)
285
 
286
  # Keep best
@@ -323,10 +321,7 @@ class CheckpointManager:
323
 
324
  def total_size_mb(self) -> float:
325
  """Return total disk size of all tracked checkpoints."""
326
- return sum(
327
- meta.get("size_mb", 0.0)
328
- for meta in self._index["checkpoints"].values()
329
- )
330
 
331
  def summary(self) -> str:
332
  """Return a human-readable summary of checkpoint state."""
@@ -351,6 +346,7 @@ class CheckpointManager:
351
  # Helpers
352
  # ------------------------------------------------------------------
353
 
 
354
  def _get_state_dict(module: torch.nn.Module) -> dict:
355
  """Extract state dict, handling DDP wrapper."""
356
  if hasattr(module, "module"):
 
166
  torch.save(state, ckpt_dir / "training_state.pt")
167
 
168
  # Compute checkpoint size
169
+ size_mb = sum(f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()) / (1024 * 1024)
 
 
170
 
171
  # Create metadata
172
  meta = CheckpointMetadata(
 
214
  entries.sort(key=lambda x: x[1], reverse=not self.lower_is_better)
215
 
216
  # Mark best
217
+ best_names = {e[0] for e in entries[: self.keep_best]}
218
  for name, meta in self._index["checkpoints"].items():
219
  meta["is_best"] = name in best_names
220
 
 
243
  val = meta.get("metrics", {}).get(self.metric)
244
  if val is None:
245
  continue
246
+ if (
247
+ best_val is None
248
+ or (self.lower_is_better and val < best_val)
249
+ or (not self.lower_is_better and val > best_val)
250
+ ):
251
  best, best_val = name, val
252
  return best
253
 
 
278
  keep = set()
279
 
280
  # Keep latest
281
+ for name in all_names[-self.keep_latest :]:
282
  keep.add(name)
283
 
284
  # Keep best
 
321
 
322
  def total_size_mb(self) -> float:
323
  """Return total disk size of all tracked checkpoints."""
324
+ return sum(meta.get("size_mb", 0.0) for meta in self._index["checkpoints"].values())
 
 
 
325
 
326
  def summary(self) -> str:
327
  """Return a human-readable summary of checkpoint state."""
 
346
  # Helpers
347
  # ------------------------------------------------------------------
348
 
349
+
350
  def _get_state_dict(module: torch.nn.Module) -> dict:
351
  """Extract state dict, handling DDP wrapper."""
352
  if hasattr(module, "module"):
landmarkdiff/cli.py CHANGED
@@ -17,10 +17,10 @@ import sys
17
 
18
  def cmd_infer(args: argparse.Namespace) -> None:
19
  """Run single-image inference."""
20
- import cv2
21
- import numpy as np
22
  from pathlib import Path
23
 
 
 
24
  from landmarkdiff.inference import LandmarkDiffPipeline
25
 
26
  image = cv2.imread(args.image)
@@ -51,6 +51,7 @@ def cmd_infer(args: argparse.Namespace) -> None:
51
 
52
  if args.watermark:
53
  from landmarkdiff.safety import SafetyValidator
 
54
  validator = SafetyValidator()
55
  watermarked = validator.apply_watermark(result["output"])
56
  wm_path = out_path.with_stem(out_path.stem + "_watermarked")
@@ -87,9 +88,7 @@ def cmd_evaluate(args: argparse.Namespace) -> None:
87
  run_evaluation(
88
  test_dir=args.test_dir,
89
  output_dir=args.output,
90
- mode=args.mode,
91
  checkpoint=args.checkpoint,
92
- displacement_model=args.displacement_model,
93
  max_samples=args.max_samples,
94
  )
95
 
@@ -98,10 +97,7 @@ def cmd_config(args: argparse.Namespace) -> None:
98
  """Show or validate configuration."""
99
  from landmarkdiff.config import ExperimentConfig, load_config, validate_config
100
 
101
- if args.file:
102
- config = load_config(args.file)
103
- else:
104
- config = ExperimentConfig()
105
 
106
  if args.validate:
107
  warnings = validate_config(config)
@@ -112,14 +108,17 @@ def cmd_config(args: argparse.Namespace) -> None:
112
  else:
113
  print("Configuration valid (no warnings).")
114
  else:
115
- import yaml
116
  from dataclasses import asdict
 
 
 
117
  print(yaml.dump(asdict(config), default_flow_style=False, sort_keys=False))
118
 
119
 
120
  def cmd_validate(args: argparse.Namespace) -> None:
121
  """Run safety validation on an output image."""
122
  import cv2
 
123
  from landmarkdiff.safety import SafetyValidator
124
 
125
  input_img = cv2.imread(args.input)
@@ -148,6 +147,7 @@ def cmd_validate(args: argparse.Namespace) -> None:
148
  def cmd_version(args: argparse.Namespace) -> None:
149
  """Print version info."""
150
  from landmarkdiff import __version__
 
151
  print(f"LandmarkDiff v{__version__}")
152
 
153
 
@@ -162,8 +162,11 @@ def main(argv: list[str] | None = None) -> None:
162
  # --- infer ---
163
  p_infer = subparsers.add_parser("infer", help="Run single-image inference")
164
  p_infer.add_argument("image", help="Input face image path")
165
- p_infer.add_argument("--procedure", default="rhinoplasty",
166
- choices=["rhinoplasty", "blepharoplasty", "rhytidectomy", "orthognathic"])
 
 
 
167
  p_infer.add_argument("--intensity", type=float, default=65.0)
168
  p_infer.add_argument("--output", default="output.png")
169
  p_infer.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
@@ -180,8 +183,11 @@ def main(argv: list[str] | None = None) -> None:
180
  p_ensemble.add_argument("--intensity", type=float, default=65.0)
181
  p_ensemble.add_argument("--output", default="ensemble_output")
182
  p_ensemble.add_argument("--n-samples", type=int, default=5)
183
- p_ensemble.add_argument("--strategy", default="best_of_n",
184
- choices=["pixel_average", "weighted_average", "best_of_n", "median"])
 
 
 
185
  p_ensemble.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
186
  p_ensemble.add_argument("--checkpoint", default=None)
187
  p_ensemble.add_argument("--displacement-model", default=None)
 
17
 
18
  def cmd_infer(args: argparse.Namespace) -> None:
19
  """Run single-image inference."""
 
 
20
  from pathlib import Path
21
 
22
+ import cv2
23
+
24
  from landmarkdiff.inference import LandmarkDiffPipeline
25
 
26
  image = cv2.imread(args.image)
 
51
 
52
  if args.watermark:
53
  from landmarkdiff.safety import SafetyValidator
54
+
55
  validator = SafetyValidator()
56
  watermarked = validator.apply_watermark(result["output"])
57
  wm_path = out_path.with_stem(out_path.stem + "_watermarked")
 
88
  run_evaluation(
89
  test_dir=args.test_dir,
90
  output_dir=args.output,
 
91
  checkpoint=args.checkpoint,
 
92
  max_samples=args.max_samples,
93
  )
94
 
 
97
  """Show or validate configuration."""
98
  from landmarkdiff.config import ExperimentConfig, load_config, validate_config
99
 
100
+ config = load_config(args.file) if args.file else ExperimentConfig()
 
 
 
101
 
102
  if args.validate:
103
  warnings = validate_config(config)
 
108
  else:
109
  print("Configuration valid (no warnings).")
110
  else:
 
111
  from dataclasses import asdict
112
+
113
+ import yaml
114
+
115
  print(yaml.dump(asdict(config), default_flow_style=False, sort_keys=False))
116
 
117
 
118
  def cmd_validate(args: argparse.Namespace) -> None:
119
  """Run safety validation on an output image."""
120
  import cv2
121
+
122
  from landmarkdiff.safety import SafetyValidator
123
 
124
  input_img = cv2.imread(args.input)
 
147
  def cmd_version(args: argparse.Namespace) -> None:
148
  """Print version info."""
149
  from landmarkdiff import __version__
150
+
151
  print(f"LandmarkDiff v{__version__}")
152
 
153
 
 
162
  # --- infer ---
163
  p_infer = subparsers.add_parser("infer", help="Run single-image inference")
164
  p_infer.add_argument("image", help="Input face image path")
165
+ p_infer.add_argument(
166
+ "--procedure",
167
+ default="rhinoplasty",
168
+ choices=["rhinoplasty", "blepharoplasty", "rhytidectomy", "orthognathic"],
169
+ )
170
  p_infer.add_argument("--intensity", type=float, default=65.0)
171
  p_infer.add_argument("--output", default="output.png")
172
  p_infer.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
 
183
  p_ensemble.add_argument("--intensity", type=float, default=65.0)
184
  p_ensemble.add_argument("--output", default="ensemble_output")
185
  p_ensemble.add_argument("--n-samples", type=int, default=5)
186
+ p_ensemble.add_argument(
187
+ "--strategy",
188
+ default="best_of_n",
189
+ choices=["pixel_average", "weighted_average", "best_of_n", "median"],
190
+ )
191
  p_ensemble.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
192
  p_ensemble.add_argument("--checkpoint", default=None)
193
  p_ensemble.add_argument("--displacement-model", default=None)
landmarkdiff/clinical.py CHANGED
@@ -80,9 +80,9 @@ def detect_vitiligo_patches(
80
  # Also check for low saturation (a,b channels close to 128)
81
  a_channel = lab[:, :, 1]
82
  b_channel = lab[:, :, 2]
83
- low_sat = (
84
- (np.abs(a_channel - 128) < 15) & (np.abs(b_channel - 128) < 15)
85
- ).astype(np.uint8) * 255
86
 
87
  # Combined: bright AND low-saturation within face
88
  vitiligo_raw = cv2.bitwise_and(bright_mask, low_sat)
 
80
  # Also check for low saturation (a,b channels close to 128)
81
  a_channel = lab[:, :, 1]
82
  b_channel = lab[:, :, 2]
83
+ low_sat = ((np.abs(a_channel - 128) < 15) & (np.abs(b_channel - 128) < 15)).astype(
84
+ np.uint8
85
+ ) * 255
86
 
87
  # Combined: bright AND low-saturation within face
88
  vitiligo_raw = cv2.bitwise_and(bright_mask, low_sat)
landmarkdiff/conditioning.py CHANGED
@@ -18,17 +18,83 @@ from landmarkdiff.landmarks import FaceLandmarks
18
  # This is invariant to landmark displacement (unlike Delaunay).
19
 
20
  JAWLINE_CONTOUR = [
21
- 10, 338, 297, 332, 284, 251, 389, 356, 454, 323, 361, 288,
22
- 397, 365, 379, 378, 400, 377, 152, 148, 176, 149, 150, 136,
23
- 172, 58, 132, 93, 234, 127, 162, 21, 54, 103, 67, 109, 10,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  ]
25
 
26
  LEFT_EYE_CONTOUR = [
27
- 33, 7, 163, 144, 145, 153, 154, 155, 133, 173, 157, 158, 159, 160, 161, 246, 33,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  ]
29
 
30
  RIGHT_EYE_CONTOUR = [
31
- 362, 382, 381, 380, 374, 373, 390, 249, 263, 466, 388, 387, 386, 385, 384, 398, 362,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  ]
33
 
34
  LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
@@ -39,13 +105,53 @@ NOSE_TIP = [94, 2, 326, 327, 294, 278, 279, 275, 274, 460, 456, 363, 370]
39
  NOSE_BOTTOM = [19, 1, 274, 275, 440, 344, 278, 294, 460, 305, 289, 392]
40
 
41
  OUTER_LIPS = [
42
- 61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291,
43
- 308, 324, 318, 402, 317, 14, 87, 178, 88, 95, 78, 61,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  ]
45
 
46
  INNER_LIPS = [
47
- 78, 191, 80, 81, 82, 13, 312, 311, 310, 415, 308,
48
- 324, 318, 402, 317, 14, 87, 178, 88, 95, 78,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  ]
50
 
51
  ALL_CONTOURS = [
 
18
  # This is invariant to landmark displacement (unlike Delaunay).
19
 
20
  JAWLINE_CONTOUR = [
21
+ 10,
22
+ 338,
23
+ 297,
24
+ 332,
25
+ 284,
26
+ 251,
27
+ 389,
28
+ 356,
29
+ 454,
30
+ 323,
31
+ 361,
32
+ 288,
33
+ 397,
34
+ 365,
35
+ 379,
36
+ 378,
37
+ 400,
38
+ 377,
39
+ 152,
40
+ 148,
41
+ 176,
42
+ 149,
43
+ 150,
44
+ 136,
45
+ 172,
46
+ 58,
47
+ 132,
48
+ 93,
49
+ 234,
50
+ 127,
51
+ 162,
52
+ 21,
53
+ 54,
54
+ 103,
55
+ 67,
56
+ 109,
57
+ 10,
58
  ]
59
 
60
  LEFT_EYE_CONTOUR = [
61
+ 33,
62
+ 7,
63
+ 163,
64
+ 144,
65
+ 145,
66
+ 153,
67
+ 154,
68
+ 155,
69
+ 133,
70
+ 173,
71
+ 157,
72
+ 158,
73
+ 159,
74
+ 160,
75
+ 161,
76
+ 246,
77
+ 33,
78
  ]
79
 
80
  RIGHT_EYE_CONTOUR = [
81
+ 362,
82
+ 382,
83
+ 381,
84
+ 380,
85
+ 374,
86
+ 373,
87
+ 390,
88
+ 249,
89
+ 263,
90
+ 466,
91
+ 388,
92
+ 387,
93
+ 386,
94
+ 385,
95
+ 384,
96
+ 398,
97
+ 362,
98
  ]
99
 
100
  LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
 
105
  NOSE_BOTTOM = [19, 1, 274, 275, 440, 344, 278, 294, 460, 305, 289, 392]
106
 
107
  OUTER_LIPS = [
108
+ 61,
109
+ 146,
110
+ 91,
111
+ 181,
112
+ 84,
113
+ 17,
114
+ 314,
115
+ 405,
116
+ 321,
117
+ 375,
118
+ 291,
119
+ 308,
120
+ 324,
121
+ 318,
122
+ 402,
123
+ 317,
124
+ 14,
125
+ 87,
126
+ 178,
127
+ 88,
128
+ 95,
129
+ 78,
130
+ 61,
131
  ]
132
 
133
  INNER_LIPS = [
134
+ 78,
135
+ 191,
136
+ 80,
137
+ 81,
138
+ 82,
139
+ 13,
140
+ 312,
141
+ 311,
142
+ 310,
143
+ 415,
144
+ 308,
145
+ 324,
146
+ 318,
147
+ 402,
148
+ 317,
149
+ 14,
150
+ 87,
151
+ 178,
152
+ 88,
153
+ 95,
154
+ 78,
155
  ]
156
 
157
  ALL_CONTOURS = [
landmarkdiff/config.py CHANGED
@@ -18,7 +18,8 @@ Usage:
18
 
19
  from __future__ import annotations
20
 
21
- from dataclasses import dataclass, field, asdict
 
22
  from pathlib import Path
23
  from typing import Any
24
 
@@ -28,6 +29,7 @@ import yaml
28
  @dataclass
29
  class ModelConfig:
30
  """ControlNet and base model configuration."""
 
31
  base_model: str = "runwayml/stable-diffusion-v1-5"
32
  controlnet_conditioning_channels: int = 3
33
  controlnet_conditioning_scale: float = 1.0
@@ -39,6 +41,7 @@ class ModelConfig:
39
  @dataclass
40
  class TrainingConfig:
41
  """Training hyperparameters."""
 
42
  phase: str = "A" # "A" or "B"
43
  learning_rate: float = 1e-5
44
  batch_size: int = 4
@@ -77,6 +80,7 @@ class TrainingConfig:
77
  @dataclass
78
  class DataConfig:
79
  """Dataset configuration."""
 
80
  train_dir: str = "data/training"
81
  val_dir: str = "data/validation"
82
  test_dir: str = "data/test"
@@ -90,9 +94,14 @@ class DataConfig:
90
  color_jitter: float = 0.1
91
 
92
  # Procedure filtering
93
- procedures: list[str] = field(default_factory=lambda: [
94
- "rhinoplasty", "blepharoplasty", "rhytidectomy", "orthognathic",
95
- ])
 
 
 
 
 
96
  intensity_range: tuple[float, float] = (30.0, 100.0)
97
 
98
  # Data-driven displacement
@@ -103,6 +112,7 @@ class DataConfig:
103
  @dataclass
104
  class InferenceConfig:
105
  """Inference / generation configuration."""
 
106
  num_inference_steps: int = 30
107
  guidance_scale: float = 7.5
108
  scheduler: str = "dpmsolver++" # "ddpm", "ddim", "dpmsolver++"
@@ -124,6 +134,7 @@ class InferenceConfig:
124
  @dataclass
125
  class EvaluationConfig:
126
  """Evaluation configuration."""
 
127
  compute_fid: bool = True
128
  compute_lpips: bool = True
129
  compute_nme: bool = True
@@ -137,6 +148,7 @@ class EvaluationConfig:
137
  @dataclass
138
  class WandbConfig:
139
  """Weights & Biases logging configuration."""
 
140
  enabled: bool = True
141
  project: str = "landmarkdiff"
142
  entity: str | None = None
@@ -147,8 +159,9 @@ class WandbConfig:
147
  @dataclass
148
  class SlurmConfig:
149
  """SLURM job submission parameters."""
 
150
  partition: str = "batch_gpu"
151
- account: str = "csb_gpu_acc"
152
  gpu_type: str = "nvidia_rtx_a6000"
153
  num_gpus: int = 1
154
  mem: str = "48G"
@@ -160,6 +173,7 @@ class SlurmConfig:
160
  @dataclass
161
  class SafetyConfig:
162
  """Clinical safety and responsible AI parameters."""
 
163
  identity_threshold: float = 0.6
164
  max_displacement_fraction: float = 0.05
165
  watermark_enabled: bool = True
@@ -173,6 +187,7 @@ class SafetyConfig:
173
  @dataclass
174
  class ExperimentConfig:
175
  """Top-level experiment configuration."""
 
176
  experiment_name: str = "default"
177
  description: str = ""
178
  version: str = "0.3.0"
@@ -227,9 +242,10 @@ class ExperimentConfig:
227
  return asdict(self)
228
 
229
 
230
- def _from_dict(cls, d: dict):
231
  """Create a dataclass from a dict, ignoring unknown keys."""
232
  import dataclasses
 
233
  field_map = {f.name: f for f in dataclasses.fields(cls)}
234
  filtered = {}
235
  for k, v in d.items():
@@ -243,7 +259,7 @@ def _from_dict(cls, d: dict):
243
  return cls(**filtered)
244
 
245
 
246
- def _convert_tuples(obj):
247
  """Recursively convert tuples to lists for YAML serialization."""
248
  if isinstance(obj, dict):
249
  return {k: _convert_tuples(v) for k, v in obj.items()}
@@ -266,10 +282,7 @@ def load_config(
266
  Returns:
267
  ExperimentConfig with overrides applied.
268
  """
269
- if config_path:
270
- config = ExperimentConfig.from_yaml(config_path)
271
- else:
272
- config = ExperimentConfig()
273
 
274
  if overrides:
275
  for key, value in overrides.items():
 
18
 
19
  from __future__ import annotations
20
 
21
+ import os
22
+ from dataclasses import asdict, dataclass, field
23
  from pathlib import Path
24
  from typing import Any
25
 
 
29
  @dataclass
30
  class ModelConfig:
31
  """ControlNet and base model configuration."""
32
+
33
  base_model: str = "runwayml/stable-diffusion-v1-5"
34
  controlnet_conditioning_channels: int = 3
35
  controlnet_conditioning_scale: float = 1.0
 
41
  @dataclass
42
  class TrainingConfig:
43
  """Training hyperparameters."""
44
+
45
  phase: str = "A" # "A" or "B"
46
  learning_rate: float = 1e-5
47
  batch_size: int = 4
 
80
  @dataclass
81
  class DataConfig:
82
  """Dataset configuration."""
83
+
84
  train_dir: str = "data/training"
85
  val_dir: str = "data/validation"
86
  test_dir: str = "data/test"
 
94
  color_jitter: float = 0.1
95
 
96
  # Procedure filtering
97
+ procedures: list[str] = field(
98
+ default_factory=lambda: [
99
+ "rhinoplasty",
100
+ "blepharoplasty",
101
+ "rhytidectomy",
102
+ "orthognathic",
103
+ ]
104
+ )
105
  intensity_range: tuple[float, float] = (30.0, 100.0)
106
 
107
  # Data-driven displacement
 
112
  @dataclass
113
  class InferenceConfig:
114
  """Inference / generation configuration."""
115
+
116
  num_inference_steps: int = 30
117
  guidance_scale: float = 7.5
118
  scheduler: str = "dpmsolver++" # "ddpm", "ddim", "dpmsolver++"
 
134
  @dataclass
135
  class EvaluationConfig:
136
  """Evaluation configuration."""
137
+
138
  compute_fid: bool = True
139
  compute_lpips: bool = True
140
  compute_nme: bool = True
 
148
  @dataclass
149
  class WandbConfig:
150
  """Weights & Biases logging configuration."""
151
+
152
  enabled: bool = True
153
  project: str = "landmarkdiff"
154
  entity: str | None = None
 
159
  @dataclass
160
  class SlurmConfig:
161
  """SLURM job submission parameters."""
162
+
163
  partition: str = "batch_gpu"
164
+ account: str = os.environ.get("SLURM_ACCOUNT", "default_gpu")
165
  gpu_type: str = "nvidia_rtx_a6000"
166
  num_gpus: int = 1
167
  mem: str = "48G"
 
173
  @dataclass
174
  class SafetyConfig:
175
  """Clinical safety and responsible AI parameters."""
176
+
177
  identity_threshold: float = 0.6
178
  max_displacement_fraction: float = 0.05
179
  watermark_enabled: bool = True
 
187
  @dataclass
188
  class ExperimentConfig:
189
  """Top-level experiment configuration."""
190
+
191
  experiment_name: str = "default"
192
  description: str = ""
193
  version: str = "0.3.0"
 
242
  return asdict(self)
243
 
244
 
245
+ def _from_dict(cls: type, d: dict) -> Any:
246
  """Create a dataclass from a dict, ignoring unknown keys."""
247
  import dataclasses
248
+
249
  field_map = {f.name: f for f in dataclasses.fields(cls)}
250
  filtered = {}
251
  for k, v in d.items():
 
259
  return cls(**filtered)
260
 
261
 
262
+ def _convert_tuples(obj: Any) -> Any:
263
  """Recursively convert tuples to lists for YAML serialization."""
264
  if isinstance(obj, dict):
265
  return {k: _convert_tuples(v) for k, v in obj.items()}
 
282
  Returns:
283
  ExperimentConfig with overrides applied.
284
  """
285
+ config = ExperimentConfig.from_yaml(config_path) if config_path else ExperimentConfig()
 
 
 
286
 
287
  if overrides:
288
  for key, value in overrides.items():
landmarkdiff/curriculum.py CHANGED
@@ -104,10 +104,10 @@ class ProcedureCurriculum:
104
 
105
  # Difficulty ranking (0=easiest, 1=hardest)
106
  DEFAULT_PROCEDURE_DIFFICULTY = {
107
- "blepharoplasty": 0.3, # small, localized changes
108
- "rhinoplasty": 0.5, # moderate, central face
109
- "rhytidectomy": 0.7, # large, affects face shape
110
- "orthognathic": 0.9, # largest deformations
111
  }
112
 
113
  def __init__(
@@ -137,10 +137,7 @@ class ProcedureCurriculum:
137
 
138
  def get_procedure_weights(self, step: int) -> dict[str, float]:
139
  """Get all procedure weights at the given step."""
140
- return {
141
- proc: self.get_weight(step, proc)
142
- for proc in self.proc_difficulty
143
- }
144
 
145
 
146
  def compute_sample_difficulty(
@@ -174,7 +171,7 @@ def compute_sample_difficulty(
174
  source_bonus = {
175
  "synthetic": 0.0,
176
  "synthetic_v3": 0.1, # realistic displacements slightly harder
177
- "real": 0.2, # real data hardest
178
  "augmented": 0.0,
179
  }
180
 
 
104
 
105
  # Difficulty ranking (0=easiest, 1=hardest)
106
  DEFAULT_PROCEDURE_DIFFICULTY = {
107
+ "blepharoplasty": 0.3, # small, localized changes
108
+ "rhinoplasty": 0.5, # moderate, central face
109
+ "rhytidectomy": 0.7, # large, affects face shape
110
+ "orthognathic": 0.9, # largest deformations
111
  }
112
 
113
  def __init__(
 
137
 
138
  def get_procedure_weights(self, step: int) -> dict[str, float]:
139
  """Get all procedure weights at the given step."""
140
+ return {proc: self.get_weight(step, proc) for proc in self.proc_difficulty}
 
 
 
141
 
142
 
143
  def compute_sample_difficulty(
 
171
  source_bonus = {
172
  "synthetic": 0.0,
173
  "synthetic_v3": 0.1, # realistic displacements slightly harder
174
+ "real": 0.2, # real data hardest
175
  "augmented": 0.0,
176
  }
177
 
landmarkdiff/data.py CHANGED
@@ -23,8 +23,8 @@ from __future__ import annotations
23
  import csv
24
  import json
25
  import logging
 
26
  from pathlib import Path
27
- from typing import Callable
28
 
29
  import cv2
30
  import numpy as np
@@ -38,6 +38,7 @@ logger = logging.getLogger(__name__)
38
  # Core dataset
39
  # ---------------------------------------------------------------------------
40
 
 
41
  class SurgicalPairDataset(Dataset):
42
  """Dataset for loading surgical before/after training pairs.
43
 
@@ -162,9 +163,7 @@ class SurgicalPairDataset(Dataset):
162
  img = cv2.imread(str(path))
163
  if img is None:
164
  logger.warning("Failed to load %s, using blank", path)
165
- return np.zeros(
166
- (self.resolution, self.resolution, 3), dtype=np.uint8
167
- )
168
  if img.shape[:2] != (self.resolution, self.resolution):
169
  img = cv2.resize(img, (self.resolution, self.resolution))
170
  return img
@@ -173,14 +172,10 @@ class SurgicalPairDataset(Dataset):
173
  """Load a mask as float32 [0,1], resized to resolution."""
174
  path = self.data_dir / filename
175
  if not path.exists():
176
- return np.ones(
177
- (self.resolution, self.resolution), dtype=np.float32
178
- )
179
  mask = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
180
  if mask is None:
181
- return np.ones(
182
- (self.resolution, self.resolution), dtype=np.float32
183
- )
184
  mask = cv2.resize(mask, (self.resolution, self.resolution))
185
  return mask.astype(np.float32) / 255.0
186
 
@@ -189,6 +184,7 @@ class SurgicalPairDataset(Dataset):
189
  # Evaluation dataset (input + ground truth)
190
  # ---------------------------------------------------------------------------
191
 
 
192
  class EvalPairDataset(Dataset):
193
  """Dataset for evaluation: loads input/target pairs with procedure labels.
194
 
@@ -235,9 +231,7 @@ class EvalPairDataset(Dataset):
235
  path = self.data_dir / filename
236
  img = cv2.imread(str(path))
237
  if img is None:
238
- return np.zeros(
239
- (self.resolution, self.resolution, 3), dtype=np.uint8
240
- )
241
  if img.shape[:2] != (self.resolution, self.resolution):
242
  img = cv2.resize(img, (self.resolution, self.resolution))
243
  return img
@@ -247,6 +241,7 @@ class EvalPairDataset(Dataset):
247
  # Conversion utilities
248
  # ---------------------------------------------------------------------------
249
 
 
250
  def bgr_to_tensor(bgr: np.ndarray) -> torch.Tensor:
251
  """Convert BGR uint8 image to RGB [0,1] tensor (C, H, W)."""
252
  rgb = bgr[:, :, ::-1].astype(np.float32) / 255.0
@@ -271,6 +266,7 @@ def mask_to_tensor(mask: np.ndarray) -> torch.Tensor:
271
  # Samplers
272
  # ---------------------------------------------------------------------------
273
 
 
274
  def create_procedure_sampler(
275
  dataset: SurgicalPairDataset,
276
  balance_procedures: bool = True,
@@ -309,6 +305,7 @@ def create_procedure_sampler(
309
  # DataLoader factory
310
  # ---------------------------------------------------------------------------
311
 
 
312
  def create_dataloader(
313
  dataset: Dataset,
314
  batch_size: int = 4,
@@ -353,6 +350,7 @@ def create_dataloader(
353
  # Multi-directory dataset
354
  # ---------------------------------------------------------------------------
355
 
 
356
  class CombinedDataset(Dataset):
357
  """Combine multiple SurgicalPairDatasets into one.
358
 
 
23
  import csv
24
  import json
25
  import logging
26
+ from collections.abc import Callable
27
  from pathlib import Path
 
28
 
29
  import cv2
30
  import numpy as np
 
38
  # Core dataset
39
  # ---------------------------------------------------------------------------
40
 
41
+
42
  class SurgicalPairDataset(Dataset):
43
  """Dataset for loading surgical before/after training pairs.
44
 
 
163
  img = cv2.imread(str(path))
164
  if img is None:
165
  logger.warning("Failed to load %s, using blank", path)
166
+ return np.zeros((self.resolution, self.resolution, 3), dtype=np.uint8)
 
 
167
  if img.shape[:2] != (self.resolution, self.resolution):
168
  img = cv2.resize(img, (self.resolution, self.resolution))
169
  return img
 
172
  """Load a mask as float32 [0,1], resized to resolution."""
173
  path = self.data_dir / filename
174
  if not path.exists():
175
+ return np.ones((self.resolution, self.resolution), dtype=np.float32)
 
 
176
  mask = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
177
  if mask is None:
178
+ return np.ones((self.resolution, self.resolution), dtype=np.float32)
 
 
179
  mask = cv2.resize(mask, (self.resolution, self.resolution))
180
  return mask.astype(np.float32) / 255.0
181
 
 
184
  # Evaluation dataset (input + ground truth)
185
  # ---------------------------------------------------------------------------
186
 
187
+
188
  class EvalPairDataset(Dataset):
189
  """Dataset for evaluation: loads input/target pairs with procedure labels.
190
 
 
231
  path = self.data_dir / filename
232
  img = cv2.imread(str(path))
233
  if img is None:
234
+ return np.zeros((self.resolution, self.resolution, 3), dtype=np.uint8)
 
 
235
  if img.shape[:2] != (self.resolution, self.resolution):
236
  img = cv2.resize(img, (self.resolution, self.resolution))
237
  return img
 
241
  # Conversion utilities
242
  # ---------------------------------------------------------------------------
243
 
244
+
245
  def bgr_to_tensor(bgr: np.ndarray) -> torch.Tensor:
246
  """Convert BGR uint8 image to RGB [0,1] tensor (C, H, W)."""
247
  rgb = bgr[:, :, ::-1].astype(np.float32) / 255.0
 
266
  # Samplers
267
  # ---------------------------------------------------------------------------
268
 
269
+
270
  def create_procedure_sampler(
271
  dataset: SurgicalPairDataset,
272
  balance_procedures: bool = True,
 
305
  # DataLoader factory
306
  # ---------------------------------------------------------------------------
307
 
308
+
309
  def create_dataloader(
310
  dataset: Dataset,
311
  batch_size: int = 4,
 
350
  # Multi-directory dataset
351
  # ---------------------------------------------------------------------------
352
 
353
+
354
  class CombinedDataset(Dataset):
355
  """Combine multiple SurgicalPairDatasets into one.
356
 
landmarkdiff/data_version.py CHANGED
@@ -22,7 +22,7 @@ import json
22
  from dataclasses import dataclass, field
23
  from datetime import datetime, timezone
24
  from pathlib import Path
25
- from typing import Any, Iterator
26
 
27
 
28
  @dataclass
@@ -68,9 +68,7 @@ class DataManifest:
68
  """
69
 
70
  version: str = "1.0"
71
- created_at: str = field(
72
- default_factory=lambda: datetime.now(timezone.utc).isoformat()
73
- )
74
  root_dir: str = ""
75
  files: list[FileEntry] = field(default_factory=list)
76
  metadata: dict[str, Any] = field(default_factory=dict)
@@ -237,8 +235,7 @@ class DataManifest:
237
  actual_size = fp.stat().st_size
238
  if actual_size != entry.size_bytes:
239
  issues.append(
240
- f"Size mismatch: {entry.path} "
241
- f"(expected {entry.size_bytes}, got {actual_size})"
242
  )
243
 
244
  # Check checksum
@@ -265,8 +262,7 @@ class DataManifest:
265
  added = sorted(other_paths - self_paths)
266
  removed = sorted(self_paths - other_paths)
267
  modified = sorted(
268
- p for p in self_paths & other_paths
269
- if self_files[p].checksum != other_files[p].checksum
270
  )
271
 
272
  return {"added": added, "removed": removed, "modified": modified}
@@ -292,6 +288,7 @@ def _get_hostname() -> str:
292
  """Get hostname safely."""
293
  try:
294
  import socket
 
295
  return socket.gethostname()
296
  except Exception:
297
  return "unknown"
 
22
  from dataclasses import dataclass, field
23
  from datetime import datetime, timezone
24
  from pathlib import Path
25
+ from typing import Any
26
 
27
 
28
  @dataclass
 
68
  """
69
 
70
  version: str = "1.0"
71
+ created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
 
 
72
  root_dir: str = ""
73
  files: list[FileEntry] = field(default_factory=list)
74
  metadata: dict[str, Any] = field(default_factory=dict)
 
235
  actual_size = fp.stat().st_size
236
  if actual_size != entry.size_bytes:
237
  issues.append(
238
+ f"Size mismatch: {entry.path} (expected {entry.size_bytes}, got {actual_size})"
 
239
  )
240
 
241
  # Check checksum
 
262
  added = sorted(other_paths - self_paths)
263
  removed = sorted(self_paths - other_paths)
264
  modified = sorted(
265
+ p for p in self_paths & other_paths if self_files[p].checksum != other_files[p].checksum
 
266
  )
267
 
268
  return {"added": added, "removed": removed, "modified": modified}
 
288
  """Get hostname safely."""
289
  try:
290
  import socket
291
+
292
  return socket.gethostname()
293
  except Exception:
294
  return "unknown"
landmarkdiff/displacement_model.py CHANGED
@@ -33,12 +33,11 @@ from __future__ import annotations
33
  import json
34
  import logging
35
  from pathlib import Path
36
- from typing import Optional, Union
37
 
38
  import cv2
39
  import numpy as np
40
 
41
- from landmarkdiff.landmarks import extract_landmarks, FaceLandmarks
42
  from landmarkdiff.manipulation import PROCEDURE_LANDMARKS
43
 
44
  logger = logging.getLogger(__name__)
@@ -54,6 +53,7 @@ PROCEDURES = list(PROCEDURE_LANDMARKS.keys())
54
  # Helpers
55
  # ---------------------------------------------------------------------------
56
 
 
57
  def _normalized_coords_2d(face: FaceLandmarks) -> np.ndarray:
58
  """Extract (478, 2) normalized [0, 1] coordinates from a FaceLandmarks object.
59
 
@@ -84,9 +84,24 @@ def _compute_alignment_quality(
84
  # Stable landmarks: forehead, temple region, outer face oval
85
  # These should exhibit near-zero displacement after surgery.
86
  stable_indices = [
87
- 10, 109, 67, 103, 54, 21, 162, 127, # left forehead/temple
88
- 338, 297, 332, 284, 251, 389, 356, 454, # right forehead/temple
89
- 234, 93, # outer cheek anchors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  ]
91
  stable_indices = [i for i in stable_indices if i < NUM_LANDMARKS]
92
 
@@ -95,7 +110,7 @@ def _compute_alignment_quality(
95
 
96
  # RMS displacement on stable points
97
  diffs = after_stable - before_stable
98
- rms = np.sqrt(np.mean(np.sum(diffs ** 2, axis=1)))
99
 
100
  # Map RMS to quality: 0 displacement -> 1.0, rms >= 0.05 (5% of image) -> 0.0
101
  quality = float(np.clip(1.0 - rms / 0.05, 0.0, 1.0))
@@ -106,6 +121,7 @@ def _compute_alignment_quality(
106
  # Procedure classification
107
  # ---------------------------------------------------------------------------
108
 
 
109
  def classify_procedure(displacements: np.ndarray) -> str:
110
  """Classify which surgical procedure was performed from displacement vectors.
111
 
@@ -143,8 +159,7 @@ def classify_procedure(displacements: np.ndarray) -> str:
143
  # Threshold: mean displacement < 0.002 (~1 pixel at 512x512)
144
  if best_score < 0.002:
145
  logger.debug(
146
- "No significant displacement detected (best=%.5f). "
147
- "Classified as 'unknown'.",
148
  best_score,
149
  )
150
  return "unknown"
@@ -156,11 +171,12 @@ def classify_procedure(displacements: np.ndarray) -> str:
156
  # Single-pair extraction
157
  # ---------------------------------------------------------------------------
158
 
 
159
  def extract_displacements(
160
  before_img: np.ndarray,
161
  after_img: np.ndarray,
162
  min_detection_confidence: float = 0.5,
163
- ) -> Optional[dict]:
164
  """Extract landmark displacements from a before/after surgery image pair.
165
 
166
  Runs MediaPipe Face Mesh on both images, computes per-landmark
@@ -185,16 +201,12 @@ def extract_displacements(
185
  Returns ``None`` if face detection fails on either image.
186
  """
187
  # Extract landmarks from both images
188
- face_before = extract_landmarks(
189
- before_img, min_detection_confidence=min_detection_confidence
190
- )
191
  if face_before is None:
192
  logger.warning("Face detection failed on before image.")
193
  return None
194
 
195
- face_after = extract_landmarks(
196
- after_img, min_detection_confidence=min_detection_confidence
197
- )
198
  if face_after is None:
199
  logger.warning("Face detection failed on after image.")
200
  return None
@@ -227,8 +239,9 @@ def extract_displacements(
227
  # Batch extraction from directory
228
  # ---------------------------------------------------------------------------
229
 
 
230
  def extract_from_directory(
231
- pairs_dir: Union[str, Path],
232
  min_detection_confidence: float = 0.5,
233
  min_quality: float = 0.0,
234
  ) -> list[dict]:
@@ -338,6 +351,7 @@ def extract_from_directory(
338
  # Displacement model
339
  # ---------------------------------------------------------------------------
340
 
 
341
  class DisplacementModel:
342
  """Statistical model of per-procedure surgical displacements.
343
 
@@ -418,12 +432,12 @@ class DisplacementModel:
418
  n = stacked.shape[0]
419
 
420
  self.stats[proc] = {
421
- "mean": np.mean(stacked, axis=0), # (478, 2)
422
- "std": np.std(stacked, axis=0), # (478, 2)
423
- "min": np.min(stacked, axis=0), # (478, 2)
424
- "max": np.max(stacked, axis=0), # (478, 2)
425
- "median": np.median(stacked, axis=0), # (478, 2)
426
- "mean_magnitude": np.mean( # (478,)
427
  np.linalg.norm(stacked, axis=2), axis=0
428
  ),
429
  }
@@ -442,7 +456,7 @@ class DisplacementModel:
442
  procedure: str,
443
  intensity: float = 1.0,
444
  noise_scale: float = 0.0,
445
- rng: Optional[np.random.Generator] = None,
446
  ) -> np.ndarray:
447
  """Generate a displacement field for a given procedure and intensity.
448
 
@@ -470,10 +484,7 @@ class DisplacementModel:
470
 
471
  if procedure not in self.stats:
472
  available = ", ".join(self.procedures)
473
- raise KeyError(
474
- f"Procedure '{procedure}' not in model. "
475
- f"Available: {available}"
476
- )
477
 
478
  proc_stats = self.stats[procedure]
479
  field = proc_stats["mean"].copy() * intensity
@@ -489,7 +500,7 @@ class DisplacementModel:
489
 
490
  return field.astype(np.float32)
491
 
492
- def get_summary(self, procedure: Optional[str] = None) -> dict:
493
  """Get a human-readable summary of the model statistics.
494
 
495
  Args:
@@ -518,7 +529,7 @@ class DisplacementModel:
518
 
519
  return summary
520
 
521
- def save(self, path: Union[str, Path]) -> None:
522
  """Save the fitted model to disk as a ``.npz`` file.
523
 
524
  The file contains:
@@ -550,15 +561,13 @@ class DisplacementModel:
550
  "n_samples": self.n_samples,
551
  "num_landmarks": NUM_LANDMARKS,
552
  }
553
- arrays["__metadata__"] = np.frombuffer(
554
- json.dumps(metadata).encode("utf-8"), dtype=np.uint8
555
- )
556
 
557
  np.savez_compressed(str(path), **arrays)
558
  logger.info("Saved displacement model to %s", path)
559
 
560
  @classmethod
561
- def load(cls, path: Union[str, Path]) -> "DisplacementModel":
562
  """Load a fitted model from a ``.npz`` file.
563
 
564
  Supports two formats:
@@ -592,7 +601,7 @@ class DisplacementModel:
592
  model.stats[proc] = {}
593
  for key in data.files:
594
  if key.startswith(f"{proc}__"):
595
- stat_name = key[len(f"{proc}__"):]
596
  model.stats[proc][stat_name] = data[key]
597
 
598
  # Format 2: extract_displacements.py format with procedures array
@@ -625,10 +634,7 @@ class DisplacementModel:
625
  model.n_samples[proc] = 0
626
 
627
  else:
628
- raise ValueError(
629
- f"Unrecognized displacement model format. "
630
- f"Keys: {data.files[:10]}"
631
- )
632
 
633
  model._fitted = True
634
  logger.info(
@@ -644,6 +650,7 @@ class DisplacementModel:
644
  # Utilities
645
  # ---------------------------------------------------------------------------
646
 
 
647
  def _top_k_landmarks(
648
  magnitudes: np.ndarray,
649
  k: int = 10,
@@ -659,10 +666,7 @@ def _top_k_landmarks(
659
  descending by magnitude.
660
  """
661
  top_indices = np.argsort(magnitudes)[::-1][:k]
662
- return [
663
- {"index": int(idx), "magnitude": float(magnitudes[idx])}
664
- for idx in top_indices
665
- ]
666
 
667
 
668
  def visualize_displacements(
@@ -698,7 +702,7 @@ def visualize_displacements(
698
  dy = int(displacements[i, 1] * h * scale)
699
 
700
  # Only draw if displacement is above noise floor
701
- mag = np.sqrt(dx ** 2 + dy ** 2)
702
  if mag < 1.0:
703
  continue
704
 
 
33
  import json
34
  import logging
35
  from pathlib import Path
 
36
 
37
  import cv2
38
  import numpy as np
39
 
40
+ from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks
41
  from landmarkdiff.manipulation import PROCEDURE_LANDMARKS
42
 
43
  logger = logging.getLogger(__name__)
 
53
  # Helpers
54
  # ---------------------------------------------------------------------------
55
 
56
+
57
  def _normalized_coords_2d(face: FaceLandmarks) -> np.ndarray:
58
  """Extract (478, 2) normalized [0, 1] coordinates from a FaceLandmarks object.
59
 
 
84
  # Stable landmarks: forehead, temple region, outer face oval
85
  # These should exhibit near-zero displacement after surgery.
86
  stable_indices = [
87
+ 10,
88
+ 109,
89
+ 67,
90
+ 103,
91
+ 54,
92
+ 21,
93
+ 162,
94
+ 127, # left forehead/temple
95
+ 338,
96
+ 297,
97
+ 332,
98
+ 284,
99
+ 251,
100
+ 389,
101
+ 356,
102
+ 454, # right forehead/temple
103
+ 234,
104
+ 93, # outer cheek anchors
105
  ]
106
  stable_indices = [i for i in stable_indices if i < NUM_LANDMARKS]
107
 
 
110
 
111
  # RMS displacement on stable points
112
  diffs = after_stable - before_stable
113
+ rms = np.sqrt(np.mean(np.sum(diffs**2, axis=1)))
114
 
115
  # Map RMS to quality: 0 displacement -> 1.0, rms >= 0.05 (5% of image) -> 0.0
116
  quality = float(np.clip(1.0 - rms / 0.05, 0.0, 1.0))
 
121
  # Procedure classification
122
  # ---------------------------------------------------------------------------
123
 
124
+
125
  def classify_procedure(displacements: np.ndarray) -> str:
126
  """Classify which surgical procedure was performed from displacement vectors.
127
 
 
159
  # Threshold: mean displacement < 0.002 (~1 pixel at 512x512)
160
  if best_score < 0.002:
161
  logger.debug(
162
+ "No significant displacement detected (best=%.5f). Classified as 'unknown'.",
 
163
  best_score,
164
  )
165
  return "unknown"
 
171
  # Single-pair extraction
172
  # ---------------------------------------------------------------------------
173
 
174
+
175
  def extract_displacements(
176
  before_img: np.ndarray,
177
  after_img: np.ndarray,
178
  min_detection_confidence: float = 0.5,
179
+ ) -> dict | None:
180
  """Extract landmark displacements from a before/after surgery image pair.
181
 
182
  Runs MediaPipe Face Mesh on both images, computes per-landmark
 
201
  Returns ``None`` if face detection fails on either image.
202
  """
203
  # Extract landmarks from both images
204
+ face_before = extract_landmarks(before_img, min_detection_confidence=min_detection_confidence)
 
 
205
  if face_before is None:
206
  logger.warning("Face detection failed on before image.")
207
  return None
208
 
209
+ face_after = extract_landmarks(after_img, min_detection_confidence=min_detection_confidence)
 
 
210
  if face_after is None:
211
  logger.warning("Face detection failed on after image.")
212
  return None
 
239
  # Batch extraction from directory
240
  # ---------------------------------------------------------------------------
241
 
242
+
243
  def extract_from_directory(
244
+ pairs_dir: str | Path,
245
  min_detection_confidence: float = 0.5,
246
  min_quality: float = 0.0,
247
  ) -> list[dict]:
 
351
  # Displacement model
352
  # ---------------------------------------------------------------------------
353
 
354
+
355
  class DisplacementModel:
356
  """Statistical model of per-procedure surgical displacements.
357
 
 
432
  n = stacked.shape[0]
433
 
434
  self.stats[proc] = {
435
+ "mean": np.mean(stacked, axis=0), # (478, 2)
436
+ "std": np.std(stacked, axis=0), # (478, 2)
437
+ "min": np.min(stacked, axis=0), # (478, 2)
438
+ "max": np.max(stacked, axis=0), # (478, 2)
439
+ "median": np.median(stacked, axis=0), # (478, 2)
440
+ "mean_magnitude": np.mean( # (478,)
441
  np.linalg.norm(stacked, axis=2), axis=0
442
  ),
443
  }
 
456
  procedure: str,
457
  intensity: float = 1.0,
458
  noise_scale: float = 0.0,
459
+ rng: np.random.Generator | None = None,
460
  ) -> np.ndarray:
461
  """Generate a displacement field for a given procedure and intensity.
462
 
 
484
 
485
  if procedure not in self.stats:
486
  available = ", ".join(self.procedures)
487
+ raise KeyError(f"Procedure '{procedure}' not in model. Available: {available}")
 
 
 
488
 
489
  proc_stats = self.stats[procedure]
490
  field = proc_stats["mean"].copy() * intensity
 
500
 
501
  return field.astype(np.float32)
502
 
503
+ def get_summary(self, procedure: str | None = None) -> dict:
504
  """Get a human-readable summary of the model statistics.
505
 
506
  Args:
 
529
 
530
  return summary
531
 
532
+ def save(self, path: str | Path) -> None:
533
  """Save the fitted model to disk as a ``.npz`` file.
534
 
535
  The file contains:
 
561
  "n_samples": self.n_samples,
562
  "num_landmarks": NUM_LANDMARKS,
563
  }
564
+ arrays["__metadata__"] = np.frombuffer(json.dumps(metadata).encode("utf-8"), dtype=np.uint8)
 
 
565
 
566
  np.savez_compressed(str(path), **arrays)
567
  logger.info("Saved displacement model to %s", path)
568
 
569
  @classmethod
570
+ def load(cls, path: str | Path) -> DisplacementModel:
571
  """Load a fitted model from a ``.npz`` file.
572
 
573
  Supports two formats:
 
601
  model.stats[proc] = {}
602
  for key in data.files:
603
  if key.startswith(f"{proc}__"):
604
+ stat_name = key[len(f"{proc}__") :]
605
  model.stats[proc][stat_name] = data[key]
606
 
607
  # Format 2: extract_displacements.py format with procedures array
 
634
  model.n_samples[proc] = 0
635
 
636
  else:
637
+ raise ValueError(f"Unrecognized displacement model format. Keys: {data.files[:10]}")
 
 
 
638
 
639
  model._fitted = True
640
  logger.info(
 
650
  # Utilities
651
  # ---------------------------------------------------------------------------
652
 
653
+
654
  def _top_k_landmarks(
655
  magnitudes: np.ndarray,
656
  k: int = 10,
 
666
  descending by magnitude.
667
  """
668
  top_indices = np.argsort(magnitudes)[::-1][:k]
669
+ return [{"index": int(idx), "magnitude": float(magnitudes[idx])} for idx in top_indices]
 
 
 
670
 
671
 
672
  def visualize_displacements(
 
702
  dy = int(displacements[i, 1] * h * scale)
703
 
704
  # Only draw if displacement is above noise floor
705
+ mag = np.sqrt(dx**2 + dy**2)
706
  if mag < 1.0:
707
  continue
708
 
landmarkdiff/ensemble.py CHANGED
@@ -21,8 +21,6 @@ Usage:
21
 
22
  from __future__ import annotations
23
 
24
- from typing import Optional
25
-
26
  import cv2
27
  import numpy as np
28
 
@@ -93,7 +91,7 @@ class EnsembleInference:
93
  guidance_scale: float = 9.0,
94
  controlnet_conditioning_scale: float = 0.9,
95
  strength: float = 0.5,
96
- seed: Optional[int] = None,
97
  **kwargs,
98
  ) -> dict:
99
  """Generate ensemble output.
@@ -155,14 +153,16 @@ class EnsembleInference:
155
  # Copy metadata from best result
156
  best_idx = selected_idx if selected_idx >= 0 else 0
157
  ensemble_result = dict(results[best_idx])
158
- ensemble_result.update({
159
- "output": final,
160
- "outputs": outputs,
161
- "scores": scores,
162
- "selected_idx": selected_idx,
163
- "strategy": self.strategy,
164
- "n_samples": self.n_samples,
165
- })
 
 
166
 
167
  return ensemble_result
168
 
@@ -196,7 +196,7 @@ class EnsembleInference:
196
 
197
  # Weighted average
198
  result = np.zeros_like(outputs[0], dtype=np.float32)
199
- for output, weight in zip(outputs, weights):
200
  result += output.astype(np.float32) * weight
201
 
202
  return np.clip(result, 0, 255).astype(np.uint8), scores
@@ -269,8 +269,10 @@ def ensemble_inference(
269
  for i, output in enumerate(result["outputs"]):
270
  cv2.imwrite(str(out / f"sample_{i:02d}.png"), output)
271
  score = result["scores"][i]
272
- print(f" Sample {i}: score={score:.4f}"
273
- + (" <-- selected" if i == result.get("selected_idx") else ""))
 
 
274
 
275
  # Comparison grid
276
  panels = [image] + result["outputs"] + [result["output"]]
@@ -281,8 +283,10 @@ def ensemble_inference(
281
 
282
  print(f"\nEnsemble output saved: {out / 'ensemble_output.png'}")
283
  if result.get("selected_idx", -1) >= 0:
284
- print(f"Selected sample: {result['selected_idx']} "
285
- f"(score={result['scores'][result['selected_idx']]:.4f})")
 
 
286
 
287
 
288
  if __name__ == "__main__":
@@ -294,18 +298,26 @@ if __name__ == "__main__":
294
  parser.add_argument("--intensity", type=float, default=65.0)
295
  parser.add_argument("--output", default="ensemble_output")
296
  parser.add_argument("--n_samples", type=int, default=5)
297
- parser.add_argument("--strategy", default="best_of_n",
298
- choices=["pixel_average", "weighted_average", "best_of_n", "median"])
299
- parser.add_argument("--mode", default="tps",
300
- choices=["controlnet", "img2img", "tps"])
 
 
301
  parser.add_argument("--checkpoint", default=None)
302
  parser.add_argument("--displacement-model", default=None)
303
  parser.add_argument("--seed", type=int, default=42)
304
  args = parser.parse_args()
305
 
306
  ensemble_inference(
307
- args.image, args.procedure, args.intensity,
308
- args.output, args.n_samples, args.strategy,
309
- args.mode, args.checkpoint, args.displacement_model,
 
 
 
 
 
 
310
  args.seed,
311
  )
 
21
 
22
  from __future__ import annotations
23
 
 
 
24
  import cv2
25
  import numpy as np
26
 
 
91
  guidance_scale: float = 9.0,
92
  controlnet_conditioning_scale: float = 0.9,
93
  strength: float = 0.5,
94
+ seed: int | None = None,
95
  **kwargs,
96
  ) -> dict:
97
  """Generate ensemble output.
 
153
  # Copy metadata from best result
154
  best_idx = selected_idx if selected_idx >= 0 else 0
155
  ensemble_result = dict(results[best_idx])
156
+ ensemble_result.update(
157
+ {
158
+ "output": final,
159
+ "outputs": outputs,
160
+ "scores": scores,
161
+ "selected_idx": selected_idx,
162
+ "strategy": self.strategy,
163
+ "n_samples": self.n_samples,
164
+ }
165
+ )
166
 
167
  return ensemble_result
168
 
 
196
 
197
  # Weighted average
198
  result = np.zeros_like(outputs[0], dtype=np.float32)
199
+ for output, weight in zip(outputs, weights, strict=False):
200
  result += output.astype(np.float32) * weight
201
 
202
  return np.clip(result, 0, 255).astype(np.uint8), scores
 
269
  for i, output in enumerate(result["outputs"]):
270
  cv2.imwrite(str(out / f"sample_{i:02d}.png"), output)
271
  score = result["scores"][i]
272
+ print(
273
+ f" Sample {i}: score={score:.4f}"
274
+ + (" <-- selected" if i == result.get("selected_idx") else "")
275
+ )
276
 
277
  # Comparison grid
278
  panels = [image] + result["outputs"] + [result["output"]]
 
283
 
284
  print(f"\nEnsemble output saved: {out / 'ensemble_output.png'}")
285
  if result.get("selected_idx", -1) >= 0:
286
+ print(
287
+ f"Selected sample: {result['selected_idx']} "
288
+ f"(score={result['scores'][result['selected_idx']]:.4f})"
289
+ )
290
 
291
 
292
  if __name__ == "__main__":
 
298
  parser.add_argument("--intensity", type=float, default=65.0)
299
  parser.add_argument("--output", default="ensemble_output")
300
  parser.add_argument("--n_samples", type=int, default=5)
301
+ parser.add_argument(
302
+ "--strategy",
303
+ default="best_of_n",
304
+ choices=["pixel_average", "weighted_average", "best_of_n", "median"],
305
+ )
306
+ parser.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
307
  parser.add_argument("--checkpoint", default=None)
308
  parser.add_argument("--displacement-model", default=None)
309
  parser.add_argument("--seed", type=int, default=42)
310
  args = parser.parse_args()
311
 
312
  ensemble_inference(
313
+ args.image,
314
+ args.procedure,
315
+ args.intensity,
316
+ args.output,
317
+ args.n_samples,
318
+ args.strategy,
319
+ args.mode,
320
+ args.checkpoint,
321
+ args.displacement_model,
322
  args.seed,
323
  )
landmarkdiff/evaluation.py CHANGED
@@ -8,6 +8,7 @@ Secondary: SSIM (relaxed target >0.80).
8
  from __future__ import annotations
9
 
10
  from dataclasses import dataclass, field
 
11
 
12
  import numpy as np
13
 
@@ -23,7 +24,7 @@ class EvalMetrics:
23
 
24
  fid: float = 0.0
25
  lpips: float = 0.0
26
- nme: float = 0.0 # Normalized Mean landmark Error
27
  identity_sim: float = 0.0 # ArcFace cosine similarity
28
  ssim: float = 0.0
29
 
@@ -154,9 +155,7 @@ def compute_nme(
154
  Returns:
155
  NME value (lower is better).
156
  """
157
- iod = np.linalg.norm(
158
- target_landmarks[left_eye_idx] - target_landmarks[right_eye_idx]
159
- )
160
  if iod < 1.0:
161
  iod = 1.0
162
 
@@ -175,6 +174,7 @@ def compute_ssim(
175
  """
176
  try:
177
  from skimage.metrics import structural_similarity
 
178
  # Convert to grayscale if color, or compute per-channel
179
  if pred.ndim == 3 and pred.shape[2] == 3:
180
  return float(structural_similarity(pred, target, channel_axis=2, data_range=255))
@@ -194,10 +194,8 @@ def compute_ssim(
194
  C1 = (0.01 * 255) ** 2
195
  C2 = (0.03 * 255) ** 2
196
 
197
- ssim_val = (
198
- (2 * mu_p * mu_t + C1) * (2 * sigma_pt + C2)
199
- ) / (
200
- (mu_p ** 2 + mu_t ** 2 + C1) * (sigma_p ** 2 + sigma_t ** 2 + C2)
201
  )
202
  return float(ssim_val)
203
 
@@ -206,11 +204,12 @@ _LPIPS_FN = None
206
  _ARCFACE_APP = None
207
 
208
 
209
- def _get_lpips_fn():
210
  """Get or create singleton LPIPS model."""
211
  global _LPIPS_FN
212
  if _LPIPS_FN is None:
213
  import lpips
 
214
  _LPIPS_FN = lpips.LPIPS(net="alex", verbose=False)
215
  _LPIPS_FN.eval()
216
  return _LPIPS_FN
@@ -225,7 +224,7 @@ def compute_lpips(
225
  Returns LPIPS score (lower = more similar).
226
  """
227
  try:
228
- import lpips
229
  import torch
230
  except ImportError:
231
  return float("nan")
@@ -261,9 +260,10 @@ def compute_fid(
261
  except ImportError:
262
  raise ImportError(
263
  "torch-fidelity is required for FID. Install with: pip install torch-fidelity"
264
- )
265
 
266
  import torch
 
267
  metrics = calculate_metrics(
268
  input1=generated_dir,
269
  input2=real_dir,
@@ -285,6 +285,7 @@ def compute_identity_similarity(
285
  """
286
  try:
287
  from insightface.app import FaceAnalysis
 
288
  global _ARCFACE_APP
289
  if _ARCFACE_APP is None:
290
  _ARCFACE_APP = FaceAnalysis(
 
8
  from __future__ import annotations
9
 
10
  from dataclasses import dataclass, field
11
+ from typing import Any
12
 
13
  import numpy as np
14
 
 
24
 
25
  fid: float = 0.0
26
  lpips: float = 0.0
27
+ nme: float = 0.0 # Normalized Mean landmark Error
28
  identity_sim: float = 0.0 # ArcFace cosine similarity
29
  ssim: float = 0.0
30
 
 
155
  Returns:
156
  NME value (lower is better).
157
  """
158
+ iod = np.linalg.norm(target_landmarks[left_eye_idx] - target_landmarks[right_eye_idx])
 
 
159
  if iod < 1.0:
160
  iod = 1.0
161
 
 
174
  """
175
  try:
176
  from skimage.metrics import structural_similarity
177
+
178
  # Convert to grayscale if color, or compute per-channel
179
  if pred.ndim == 3 and pred.shape[2] == 3:
180
  return float(structural_similarity(pred, target, channel_axis=2, data_range=255))
 
194
  C1 = (0.01 * 255) ** 2
195
  C2 = (0.03 * 255) ** 2
196
 
197
+ ssim_val = ((2 * mu_p * mu_t + C1) * (2 * sigma_pt + C2)) / (
198
+ (mu_p**2 + mu_t**2 + C1) * (sigma_p**2 + sigma_t**2 + C2)
 
 
199
  )
200
  return float(ssim_val)
201
 
 
204
  _ARCFACE_APP = None
205
 
206
 
207
+ def _get_lpips_fn() -> Any:
208
  """Get or create singleton LPIPS model."""
209
  global _LPIPS_FN
210
  if _LPIPS_FN is None:
211
  import lpips
212
+
213
  _LPIPS_FN = lpips.LPIPS(net="alex", verbose=False)
214
  _LPIPS_FN.eval()
215
  return _LPIPS_FN
 
224
  Returns LPIPS score (lower = more similar).
225
  """
226
  try:
227
+ import lpips # noqa: F401
228
  import torch
229
  except ImportError:
230
  return float("nan")
 
260
  except ImportError:
261
  raise ImportError(
262
  "torch-fidelity is required for FID. Install with: pip install torch-fidelity"
263
+ ) from None
264
 
265
  import torch
266
+
267
  metrics = calculate_metrics(
268
  input1=generated_dir,
269
  input2=real_dir,
 
285
  """
286
  try:
287
  from insightface.app import FaceAnalysis
288
+
289
  global _ARCFACE_APP
290
  if _ARCFACE_APP is None:
291
  _ARCFACE_APP = FaceAnalysis(
landmarkdiff/experiment_tracker.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Local experiment tracker for training reproducibility.
2
+
3
+ Tracks all training runs with their configs, metrics, and results.
4
+ Each experiment gets a unique ID and timestamp.
5
+
6
+ Usage::
7
+
8
+ tracker = ExperimentTracker("experiments/")
9
+
10
+ # Start a new experiment
11
+ exp_id = tracker.start(
12
+ name="phaseA_v2",
13
+ config={
14
+ "phase": "A", "lr": 1e-5, "batch": 4,
15
+ "steps": 100000, "data": "training_combined",
16
+ },
17
+ )
18
+
19
+ # Log metrics during training
20
+ tracker.log_metric(exp_id, step=1000, loss=0.045, ssim=0.82)
21
+
22
+ # Record final results
23
+ tracker.finish(exp_id, results={"fid": 42.3, "ssim": 0.87})
24
+
25
+ # List all experiments
26
+ tracker.list_experiments()
27
+
28
+ # Compare experiments
29
+ tracker.compare(["exp_001", "exp_002"])
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ import json
35
+ import os
36
+ import socket
37
+ import time
38
+ from datetime import datetime
39
+ from pathlib import Path
40
+
41
+
42
+ class ExperimentTracker:
43
+ """Simple file-based experiment tracker."""
44
+
45
+ def __init__(self, experiments_dir: str = "experiments"):
46
+ self.dir = Path(experiments_dir)
47
+ self.dir.mkdir(parents=True, exist_ok=True)
48
+ self._index_path = self.dir / "index.json"
49
+ self._index = self._load_index()
50
+
51
+ def _load_index(self) -> dict:
52
+ if self._index_path.exists():
53
+ with open(self._index_path) as f:
54
+ return json.load(f)
55
+ return {"experiments": {}, "counter": 0}
56
+
57
+ def _save_index(self) -> None:
58
+ with open(self._index_path, "w") as f:
59
+ json.dump(self._index, f, indent=2)
60
+
61
+ def start(
62
+ self,
63
+ name: str,
64
+ config: dict,
65
+ tags: list[str] | None = None,
66
+ ) -> str:
67
+ """Start a new experiment. Returns experiment ID."""
68
+ self._index["counter"] += 1
69
+ exp_id = f"exp_{self._index['counter']:03d}"
70
+
71
+ exp = {
72
+ "id": exp_id,
73
+ "name": name,
74
+ "config": config,
75
+ "tags": tags or [],
76
+ "status": "running",
77
+ "started_at": datetime.now().isoformat(),
78
+ "finished_at": None,
79
+ "hostname": socket.gethostname(),
80
+ "slurm_job_id": os.environ.get("SLURM_JOB_ID"),
81
+ "gpu": os.environ.get("CUDA_VISIBLE_DEVICES"),
82
+ "results": {},
83
+ "metrics_file": f"{exp_id}_metrics.jsonl",
84
+ }
85
+
86
+ self._index["experiments"][exp_id] = exp
87
+ self._save_index()
88
+
89
+ # Create metrics log file
90
+ metrics_path = self.dir / str(exp["metrics_file"])
91
+ metrics_path.touch()
92
+
93
+ print(f"Experiment started: {exp_id} ({name})")
94
+ return exp_id
95
+
96
+ def log_metric(self, exp_id: str, step: int | None = None, **metrics) -> None:
97
+ """Log metrics for a training step."""
98
+ exp = self._index["experiments"].get(exp_id)
99
+ if not exp:
100
+ return
101
+
102
+ entry = {
103
+ "timestamp": time.time(),
104
+ "step": step,
105
+ **metrics,
106
+ }
107
+
108
+ metrics_path = self.dir / str(exp["metrics_file"])
109
+ with open(metrics_path, "a") as f:
110
+ f.write(json.dumps(entry) + "\n")
111
+
112
+ def finish(
113
+ self,
114
+ exp_id: str,
115
+ results: dict | None = None,
116
+ status: str = "completed",
117
+ ) -> None:
118
+ """Mark experiment as finished."""
119
+ exp = self._index["experiments"].get(exp_id)
120
+ if not exp:
121
+ return
122
+
123
+ exp["status"] = status
124
+ exp["finished_at"] = datetime.now().isoformat()
125
+ if results:
126
+ exp["results"] = results
127
+
128
+ self._save_index()
129
+ print(f"Experiment {exp_id} {status}")
130
+
131
+ def get_metrics(self, exp_id: str) -> list[dict]:
132
+ """Load all logged metrics for an experiment."""
133
+ exp = self._index["experiments"].get(exp_id)
134
+ if not exp:
135
+ return []
136
+
137
+ metrics_path = self.dir / str(exp["metrics_file"])
138
+ if not metrics_path.exists():
139
+ return []
140
+
141
+ entries = []
142
+ with open(metrics_path) as f:
143
+ for line in f:
144
+ line = line.strip()
145
+ if line:
146
+ entries.append(json.loads(line))
147
+ return entries
148
+
149
+ def list_experiments(self) -> list[dict]:
150
+ """List all experiments with summary info."""
151
+ experiments = []
152
+ for exp_id, exp in sorted(self._index["experiments"].items()):
153
+ summary = {
154
+ "id": exp_id,
155
+ "name": exp["name"],
156
+ "status": exp["status"],
157
+ "started": exp["started_at"][:19],
158
+ "tags": exp.get("tags", []),
159
+ }
160
+ if exp["results"]:
161
+ for key in ["fid", "ssim", "lpips", "nme"]:
162
+ if key in exp["results"]:
163
+ summary[key] = exp["results"][key]
164
+ experiments.append(summary)
165
+ return experiments
166
+
167
+ def compare(self, exp_ids: list[str]) -> dict:
168
+ """Compare multiple experiments by their results."""
169
+ comparison = {}
170
+ for exp_id in exp_ids:
171
+ exp = self._index["experiments"].get(exp_id)
172
+ if exp:
173
+ comparison[exp_id] = {
174
+ "name": exp["name"],
175
+ "config": exp["config"],
176
+ "results": exp["results"],
177
+ }
178
+ return comparison
179
+
180
+ def print_summary(self) -> None:
181
+ """Print a summary table of all experiments."""
182
+ experiments = self.list_experiments()
183
+ if not experiments:
184
+ print("No experiments found.")
185
+ return
186
+
187
+ # Header
188
+ print(f"{'ID':<10} {'Name':<20} {'Status':<12} {'FID':>6} {'SSIM':>6} {'LPIPS':>6}")
189
+ print("-" * 70)
190
+
191
+ for exp in experiments:
192
+ fid = f"{exp.get('fid', '')}" if "fid" in exp else "--"
193
+ ssim = f"{exp.get('ssim', ''):.4f}" if "ssim" in exp else "--"
194
+ lpips = f"{exp.get('lpips', ''):.4f}" if "lpips" in exp else "--"
195
+ print(
196
+ f"{exp['id']:<10} {exp['name']:<20}"
197
+ f" {exp['status']:<12} {fid:>6} {ssim:>6} {lpips:>6}"
198
+ )
199
+
200
+ def get_best(self, metric: str = "fid", lower_is_better: bool = True) -> str | None:
201
+ """Get the experiment ID with the best value for a given metric."""
202
+ best_id = None
203
+ best_val = float("inf") if lower_is_better else float("-inf")
204
+
205
+ for exp_id, exp in self._index["experiments"].items():
206
+ if exp["status"] != "completed":
207
+ continue
208
+ val = exp["results"].get(metric)
209
+ if val is None:
210
+ continue
211
+ if (lower_is_better and val < best_val) or (not lower_is_better and val > best_val):
212
+ best_val = val
213
+ best_id = exp_id
214
+
215
+ return best_id
landmarkdiff/face_verifier.py CHANGED
@@ -15,19 +15,18 @@ Designed for:
15
 
16
  from __future__ import annotations
17
 
18
- import os
19
  from dataclasses import dataclass, field
20
  from pathlib import Path
21
- from typing import Optional
22
 
23
  import cv2
24
  import numpy as np
25
 
26
-
27
  # ---------------------------------------------------------------------------
28
  # Data structures
29
  # ---------------------------------------------------------------------------
30
 
 
31
  @dataclass
32
  class DistortionReport:
33
  """Analysis of detected distortions in a face image."""
@@ -36,13 +35,13 @@ class DistortionReport:
36
  quality_score: float = 0.0
37
 
38
  # Individual distortion scores (0-1, higher = more distorted)
39
- blur_score: float = 0.0 # Laplacian variance-based
40
- noise_score: float = 0.0 # High-freq energy ratio
41
- compression_score: float = 0.0 # JPEG block artifact detection
42
- oversmooth_score: float = 0.0 # Beauty filter / airbrushed detection
43
- color_cast_score: float = 0.0 # Unnatural color shift
44
- geometric_distort: float = 0.0 # Face proportion anomalies
45
- lighting_score: float = 0.0 # Over/under exposure
46
 
47
  # Classification
48
  primary_distortion: str = "none"
@@ -74,14 +73,14 @@ class DistortionReport:
74
  class RestorationResult:
75
  """Result of neural face restoration pipeline."""
76
 
77
- restored: np.ndarray # Restored BGR image
78
- original: np.ndarray # Original BGR image
79
- distortion_report: DistortionReport # Pre-restoration analysis
80
- post_quality_score: float = 0.0 # Quality after restoration
81
- identity_similarity: float = 0.0 # ArcFace cosine sim (original vs restored)
82
- identity_preserved: bool = True # Whether identity check passed
83
  restoration_stages: list[str] = field(default_factory=list) # Which nets ran
84
- improvement: float = 0.0 # quality_after - quality_before
85
 
86
  def summary(self) -> str:
87
  lines = [
@@ -100,9 +99,9 @@ class BatchVerificationReport:
100
  """Summary of batch face verification/restoration."""
101
 
102
  total: int = 0
103
- passed: int = 0 # Good quality, no fix needed
104
- restored: int = 0 # Fixed and now usable
105
- rejected: int = 0 # Too distorted to salvage
106
  identity_failures: int = 0 # Restoration changed identity
107
  avg_quality_before: float = 0.0
108
  avg_quality_after: float = 0.0
@@ -123,7 +122,8 @@ class BatchVerificationReport:
123
  "Distortion Breakdown:",
124
  ]
125
  for dist_type, count in sorted(
126
- self.distortion_counts.items(), key=lambda x: -x[1],
 
127
  ):
128
  lines.append(f" {dist_type}: {count}")
129
  return "\n".join(lines)
@@ -133,6 +133,7 @@ class BatchVerificationReport:
133
  # Distortion Detection (classical + neural)
134
  # ---------------------------------------------------------------------------
135
 
 
136
  def detect_blur(image: np.ndarray) -> float:
137
  """Detect blur using Laplacian variance.
138
 
@@ -147,7 +148,7 @@ def detect_blur(image: np.ndarray) -> float:
147
  # Gradient magnitude (secondary)
148
  gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
149
  gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
150
- grad_mag = np.sqrt(gx ** 2 + gy ** 2).mean()
151
 
152
  # Normalize: typical sharp face has lap_var > 500, grad_mag > 30
153
  blur_lap = 1.0 - min(lap_var / 800.0, 1.0)
@@ -221,7 +222,7 @@ def detect_oversmoothing(image: np.ndarray) -> float:
221
  # Focus on face center region (avoid background)
222
  if h < 8 or w < 8:
223
  return 0.0 # Too small to analyze
224
- roi = gray[h // 4:3 * h // 4, w // 4:3 * w // 4]
225
 
226
  # Texture energy: variance of high-pass filtered image
227
  blurred = cv2.GaussianBlur(roi.astype(np.float64), (0, 0), 2.0)
@@ -254,7 +255,7 @@ def detect_color_cast(image: np.ndarray) -> float:
254
  h, w = image.shape[:2]
255
 
256
  # Sample face center region
257
- roi = lab[h // 4:3 * h // 4, w // 4:3 * w // 4]
258
 
259
  # A channel: green-red axis (neutral ~128)
260
  # B channel: blue-yellow axis (neutral ~128)
@@ -337,7 +338,7 @@ def detect_lighting_issues(image: np.ndarray) -> float:
337
 
338
  # Check for clipping
339
  overexposed = np.mean(l_channel > 245) * 5 # Fraction near white
340
- underexposed = np.mean(l_channel < 10) * 5 # Fraction near black
341
 
342
  # Check for bimodal distribution (harsh shadows)
343
  hist = cv2.calcHist([l_channel], [0], None, [256], [0, 256]).flatten()
@@ -429,7 +430,7 @@ def analyze_distortions(image: np.ndarray) -> DistortionReport:
429
  _FACE_QUALITY_NET = None
430
 
431
 
432
- def _get_face_quality_scorer():
433
  """Get or create singleton face quality assessment model.
434
 
435
  Uses FaceXLib's quality scorer or falls back to BRISQUE-style features.
@@ -440,6 +441,7 @@ def _get_face_quality_scorer():
440
 
441
  try:
442
  from facexlib.assessment import init_assessment_model
 
443
  _FACE_QUALITY_NET = init_assessment_model("hypernet")
444
  return _FACE_QUALITY_NET
445
  except Exception:
@@ -461,6 +463,7 @@ def neural_quality_score(image: np.ndarray) -> float:
461
  try:
462
  import torch
463
  from facexlib.utils import img2tensor
 
464
  img_t = img2tensor(image / 255.0, bgr2rgb=True, float32=True)
465
  img_t = img_t.unsqueeze(0)
466
  if torch.cuda.is_available():
@@ -481,6 +484,7 @@ def neural_quality_score(image: np.ndarray) -> float:
481
  # Neural Face Restoration (cascaded)
482
  # ---------------------------------------------------------------------------
483
 
 
484
  def restore_face(
485
  image: np.ndarray,
486
  distortion: DistortionReport | None = None,
@@ -563,6 +567,7 @@ def restore_face(
563
  post_blur = detect_blur(result)
564
  if post_blur > 0.3:
565
  from landmarkdiff.postprocess import frequency_aware_sharpen
 
566
  result = frequency_aware_sharpen(result, strength=0.3)
567
  stages.append("sharpen")
568
 
@@ -573,6 +578,7 @@ def _try_codeformer(image: np.ndarray, fidelity: float = 0.7) -> np.ndarray | No
573
  """Try CodeFormer restoration. Returns None if unavailable."""
574
  try:
575
  from landmarkdiff.postprocess import restore_face_codeformer
 
576
  restored = restore_face_codeformer(image, fidelity=fidelity)
577
  if restored is not image:
578
  return restored
@@ -585,6 +591,7 @@ def _try_gfpgan(image: np.ndarray) -> np.ndarray | None:
585
  """Try GFPGAN restoration. Returns None if unavailable."""
586
  try:
587
  from landmarkdiff.postprocess import restore_face_gfpgan
 
588
  restored = restore_face_gfpgan(image)
589
  if restored is not image:
590
  return restored
@@ -599,15 +606,19 @@ _FV_REALESRGAN = None
599
  def _try_realesrgan(image: np.ndarray) -> np.ndarray | None:
600
  """Try Real-ESRGAN 2x upscale + downsample. Returns None if unavailable."""
601
  try:
602
- from realesrgan import RealESRGANer
603
- from basicsr.archs.rrdbnet_arch import RRDBNet
604
  import torch
 
 
605
 
606
  global _FV_REALESRGAN
607
  if _FV_REALESRGAN is None:
608
  model = RRDBNet(
609
- num_in_ch=3, num_out_ch=3, num_feat=64,
610
- num_block=23, num_grow_ch=32, scale=4,
 
 
 
 
611
  )
612
  _FV_REALESRGAN = RealESRGANer(
613
  scale=4,
@@ -661,15 +672,15 @@ def _fix_lighting(image: np.ndarray) -> np.ndarray:
661
  _ARCFACE_APP = None
662
 
663
 
664
- def _get_arcface():
665
  """Get or create singleton ArcFace model."""
666
  global _ARCFACE_APP
667
  if _ARCFACE_APP is not None:
668
  return _ARCFACE_APP
669
 
670
  try:
671
- from insightface.app import FaceAnalysis
672
  import torch
 
673
 
674
  app = FaceAnalysis(
675
  name="buffalo_l",
@@ -717,9 +728,9 @@ def verify_identity(
717
  if emb_orig is None or emb_rest is None:
718
  return -1.0, True # Can't verify — assume OK
719
 
720
- sim = float(np.dot(emb_orig, emb_rest) / (
721
- np.linalg.norm(emb_orig) * np.linalg.norm(emb_rest) + 1e-8
722
- ))
723
  sim = float(np.clip(sim, -1, 1))
724
  return sim, sim >= threshold
725
 
@@ -728,6 +739,7 @@ def verify_identity(
728
  # Full Verification + Restoration Pipeline
729
  # ---------------------------------------------------------------------------
730
 
 
731
  def verify_and_restore(
732
  image: np.ndarray,
733
  quality_threshold: float = 60.0,
@@ -813,6 +825,7 @@ def verify_and_restore(
813
  # Batch Processing
814
  # ---------------------------------------------------------------------------
815
 
 
816
  def verify_batch(
817
  image_dir: str,
818
  output_dir: str | None = None,
@@ -858,10 +871,9 @@ def verify_batch(
858
  rejected_dir.mkdir(parents=True, exist_ok=True)
859
 
860
  # Find all images
861
- image_files = sorted([
862
- f for f in image_path.iterdir()
863
- if f.suffix.lower() in extensions and f.is_file()
864
- ])
865
 
866
  report = BatchVerificationReport(total=len(image_files))
867
  quality_before = []
 
15
 
16
  from __future__ import annotations
17
 
 
18
  from dataclasses import dataclass, field
19
  from pathlib import Path
20
+ from typing import Any
21
 
22
  import cv2
23
  import numpy as np
24
 
 
25
  # ---------------------------------------------------------------------------
26
  # Data structures
27
  # ---------------------------------------------------------------------------
28
 
29
+
30
  @dataclass
31
  class DistortionReport:
32
  """Analysis of detected distortions in a face image."""
 
35
  quality_score: float = 0.0
36
 
37
  # Individual distortion scores (0-1, higher = more distorted)
38
+ blur_score: float = 0.0 # Laplacian variance-based
39
+ noise_score: float = 0.0 # High-freq energy ratio
40
+ compression_score: float = 0.0 # JPEG block artifact detection
41
+ oversmooth_score: float = 0.0 # Beauty filter / airbrushed detection
42
+ color_cast_score: float = 0.0 # Unnatural color shift
43
+ geometric_distort: float = 0.0 # Face proportion anomalies
44
+ lighting_score: float = 0.0 # Over/under exposure
45
 
46
  # Classification
47
  primary_distortion: str = "none"
 
73
  class RestorationResult:
74
  """Result of neural face restoration pipeline."""
75
 
76
+ restored: np.ndarray # Restored BGR image
77
+ original: np.ndarray # Original BGR image
78
+ distortion_report: DistortionReport # Pre-restoration analysis
79
+ post_quality_score: float = 0.0 # Quality after restoration
80
+ identity_similarity: float = 0.0 # ArcFace cosine sim (original vs restored)
81
+ identity_preserved: bool = True # Whether identity check passed
82
  restoration_stages: list[str] = field(default_factory=list) # Which nets ran
83
+ improvement: float = 0.0 # quality_after - quality_before
84
 
85
  def summary(self) -> str:
86
  lines = [
 
99
  """Summary of batch face verification/restoration."""
100
 
101
  total: int = 0
102
+ passed: int = 0 # Good quality, no fix needed
103
+ restored: int = 0 # Fixed and now usable
104
+ rejected: int = 0 # Too distorted to salvage
105
  identity_failures: int = 0 # Restoration changed identity
106
  avg_quality_before: float = 0.0
107
  avg_quality_after: float = 0.0
 
122
  "Distortion Breakdown:",
123
  ]
124
  for dist_type, count in sorted(
125
+ self.distortion_counts.items(),
126
+ key=lambda x: -x[1],
127
  ):
128
  lines.append(f" {dist_type}: {count}")
129
  return "\n".join(lines)
 
133
  # Distortion Detection (classical + neural)
134
  # ---------------------------------------------------------------------------
135
 
136
+
137
  def detect_blur(image: np.ndarray) -> float:
138
  """Detect blur using Laplacian variance.
139
 
 
148
  # Gradient magnitude (secondary)
149
  gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
150
  gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
151
+ grad_mag = np.sqrt(gx**2 + gy**2).mean()
152
 
153
  # Normalize: typical sharp face has lap_var > 500, grad_mag > 30
154
  blur_lap = 1.0 - min(lap_var / 800.0, 1.0)
 
222
  # Focus on face center region (avoid background)
223
  if h < 8 or w < 8:
224
  return 0.0 # Too small to analyze
225
+ roi = gray[h // 4 : 3 * h // 4, w // 4 : 3 * w // 4]
226
 
227
  # Texture energy: variance of high-pass filtered image
228
  blurred = cv2.GaussianBlur(roi.astype(np.float64), (0, 0), 2.0)
 
255
  h, w = image.shape[:2]
256
 
257
  # Sample face center region
258
+ roi = lab[h // 4 : 3 * h // 4, w // 4 : 3 * w // 4]
259
 
260
  # A channel: green-red axis (neutral ~128)
261
  # B channel: blue-yellow axis (neutral ~128)
 
338
 
339
  # Check for clipping
340
  overexposed = np.mean(l_channel > 245) * 5 # Fraction near white
341
+ underexposed = np.mean(l_channel < 10) * 5 # Fraction near black
342
 
343
  # Check for bimodal distribution (harsh shadows)
344
  hist = cv2.calcHist([l_channel], [0], None, [256], [0, 256]).flatten()
 
430
  _FACE_QUALITY_NET = None
431
 
432
 
433
+ def _get_face_quality_scorer() -> Any:
434
  """Get or create singleton face quality assessment model.
435
 
436
  Uses FaceXLib's quality scorer or falls back to BRISQUE-style features.
 
441
 
442
  try:
443
  from facexlib.assessment import init_assessment_model
444
+
445
  _FACE_QUALITY_NET = init_assessment_model("hypernet")
446
  return _FACE_QUALITY_NET
447
  except Exception:
 
463
  try:
464
  import torch
465
  from facexlib.utils import img2tensor
466
+
467
  img_t = img2tensor(image / 255.0, bgr2rgb=True, float32=True)
468
  img_t = img_t.unsqueeze(0)
469
  if torch.cuda.is_available():
 
484
  # Neural Face Restoration (cascaded)
485
  # ---------------------------------------------------------------------------
486
 
487
+
488
  def restore_face(
489
  image: np.ndarray,
490
  distortion: DistortionReport | None = None,
 
567
  post_blur = detect_blur(result)
568
  if post_blur > 0.3:
569
  from landmarkdiff.postprocess import frequency_aware_sharpen
570
+
571
  result = frequency_aware_sharpen(result, strength=0.3)
572
  stages.append("sharpen")
573
 
 
578
  """Try CodeFormer restoration. Returns None if unavailable."""
579
  try:
580
  from landmarkdiff.postprocess import restore_face_codeformer
581
+
582
  restored = restore_face_codeformer(image, fidelity=fidelity)
583
  if restored is not image:
584
  return restored
 
591
  """Try GFPGAN restoration. Returns None if unavailable."""
592
  try:
593
  from landmarkdiff.postprocess import restore_face_gfpgan
594
+
595
  restored = restore_face_gfpgan(image)
596
  if restored is not image:
597
  return restored
 
606
  def _try_realesrgan(image: np.ndarray) -> np.ndarray | None:
607
  """Try Real-ESRGAN 2x upscale + downsample. Returns None if unavailable."""
608
  try:
 
 
609
  import torch
610
+ from basicsr.archs.rrdbnet_arch import RRDBNet
611
+ from realesrgan import RealESRGANer
612
 
613
  global _FV_REALESRGAN
614
  if _FV_REALESRGAN is None:
615
  model = RRDBNet(
616
+ num_in_ch=3,
617
+ num_out_ch=3,
618
+ num_feat=64,
619
+ num_block=23,
620
+ num_grow_ch=32,
621
+ scale=4,
622
  )
623
  _FV_REALESRGAN = RealESRGANer(
624
  scale=4,
 
672
  _ARCFACE_APP = None
673
 
674
 
675
+ def _get_arcface() -> Any:
676
  """Get or create singleton ArcFace model."""
677
  global _ARCFACE_APP
678
  if _ARCFACE_APP is not None:
679
  return _ARCFACE_APP
680
 
681
  try:
 
682
  import torch
683
+ from insightface.app import FaceAnalysis
684
 
685
  app = FaceAnalysis(
686
  name="buffalo_l",
 
728
  if emb_orig is None or emb_rest is None:
729
  return -1.0, True # Can't verify — assume OK
730
 
731
+ sim = float(
732
+ np.dot(emb_orig, emb_rest) / (np.linalg.norm(emb_orig) * np.linalg.norm(emb_rest) + 1e-8)
733
+ )
734
  sim = float(np.clip(sim, -1, 1))
735
  return sim, sim >= threshold
736
 
 
739
  # Full Verification + Restoration Pipeline
740
  # ---------------------------------------------------------------------------
741
 
742
+
743
  def verify_and_restore(
744
  image: np.ndarray,
745
  quality_threshold: float = 60.0,
 
825
  # Batch Processing
826
  # ---------------------------------------------------------------------------
827
 
828
+
829
  def verify_batch(
830
  image_dir: str,
831
  output_dir: str | None = None,
 
871
  rejected_dir.mkdir(parents=True, exist_ok=True)
872
 
873
  # Find all images
874
+ image_files = sorted(
875
+ [f for f in image_path.iterdir() if f.suffix.lower() in extensions and f.is_file()]
876
+ )
 
877
 
878
  report = BatchVerificationReport(total=len(image_files))
879
  quality_before = []
landmarkdiff/fid.py CHANGED
@@ -16,6 +16,7 @@ Usage:
16
  from __future__ import annotations
17
 
18
  from pathlib import Path
 
19
 
20
  import numpy as np
21
 
@@ -23,14 +24,15 @@ try:
23
  import torch
24
  import torch.nn as nn
25
  from torch.utils.data import DataLoader, Dataset
 
26
  HAS_TORCH = True
27
  except ImportError:
28
  HAS_TORCH = False
29
 
30
 
31
- def _load_inception_v3():
32
  """Load InceptionV3 with pool3 features (2048-dim)."""
33
- from torchvision.models import inception_v3, Inception_V3_Weights
34
 
35
  model = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
36
  # We want features from the avg pool layer (2048-dim)
@@ -40,79 +42,81 @@ def _load_inception_v3():
40
  return model
41
 
42
 
43
- class ImageFolderDataset(Dataset):
44
- """Simple dataset that loads images from a directory."""
45
-
46
- def __init__(self, directory: str | Path, image_size: int = 299):
47
- self.directory = Path(directory)
48
- exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
49
- self.files = sorted(
50
- f for f in self.directory.iterdir()
51
- if f.suffix.lower() in exts and f.is_file()
52
- )
53
- self.image_size = image_size
54
-
55
- def __len__(self):
56
- return len(self.files)
57
-
58
- def __getitem__(self, idx):
59
- import cv2
60
- img = cv2.imread(str(self.files[idx]))
61
- if img is None:
62
- # Return zeros if image can't be loaded
63
- return torch.zeros(3, self.image_size, self.image_size)
64
- img = cv2.resize(img, (self.image_size, self.image_size))
65
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
66
- # Normalize to [0, 1] then ImageNet normalize
67
- t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
68
- t = _imagenet_normalize(t)
69
- return t
70
-
71
-
72
- class NumpyArrayDataset(Dataset):
73
- """Dataset wrapping a list of numpy arrays."""
74
-
75
- def __init__(self, images: list[np.ndarray], image_size: int = 299):
76
- self.images = images
77
- self.image_size = image_size
78
-
79
- def __len__(self):
80
- return len(self.images)
81
-
82
- def __getitem__(self, idx):
83
- import cv2
84
- img = self.images[idx]
85
- if img.shape[:2] != (self.image_size, self.image_size):
86
  img = cv2.resize(img, (self.image_size, self.image_size))
87
- if img.shape[2] == 3:
88
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
89
- t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
90
- t = _imagenet_normalize(t)
91
- return t
92
-
93
-
94
- def _imagenet_normalize(t: "torch.Tensor") -> "torch.Tensor":
95
- """Apply ImageNet normalization."""
96
- mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
97
- std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
98
- return (t - mean) / std
99
-
100
-
101
- @torch.no_grad()
102
- def _extract_features(
103
- model: nn.Module,
104
- dataloader: DataLoader,
105
- device: torch.device,
106
- ) -> np.ndarray:
107
- """Extract InceptionV3 pool3 features from a dataloader."""
108
- features = []
109
- for batch in dataloader:
110
- batch = batch.to(device)
111
- feat = model(batch)
112
- if isinstance(feat, tuple):
113
- feat = feat[0]
114
- features.append(feat.cpu().numpy())
115
- return np.concatenate(features, axis=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
 
118
  def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
@@ -123,8 +127,10 @@ def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
123
 
124
 
125
  def _calculate_fid(
126
- mu1: np.ndarray, sigma1: np.ndarray,
127
- mu2: np.ndarray, sigma2: np.ndarray,
 
 
128
  ) -> float:
129
  """Calculate FID given two sets of statistics.
130
 
@@ -177,10 +183,10 @@ def compute_fid_from_dirs(
177
  if len(real_ds) == 0 or len(gen_ds) == 0:
178
  raise ValueError("Need at least 1 image in each directory")
179
 
180
- real_loader = DataLoader(real_ds, batch_size=batch_size,
181
- num_workers=num_workers, pin_memory=True)
182
- gen_loader = DataLoader(gen_ds, batch_size=batch_size,
183
- num_workers=num_workers, pin_memory=True)
184
 
185
  real_features = _extract_features(model, real_loader, dev)
186
  gen_features = _extract_features(model, gen_loader, dev)
 
16
  from __future__ import annotations
17
 
18
  from pathlib import Path
19
+ from typing import Any
20
 
21
  import numpy as np
22
 
 
24
  import torch
25
  import torch.nn as nn
26
  from torch.utils.data import DataLoader, Dataset
27
+
28
  HAS_TORCH = True
29
  except ImportError:
30
  HAS_TORCH = False
31
 
32
 
33
+ def _load_inception_v3() -> Any:
34
  """Load InceptionV3 with pool3 features (2048-dim)."""
35
+ from torchvision.models import Inception_V3_Weights, inception_v3
36
 
37
  model = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
38
  # We want features from the avg pool layer (2048-dim)
 
42
  return model
43
 
44
 
45
+ # Guard torch-dependent class and function definitions so the module
46
+ # can be imported safely when torch is not installed.
47
+ if HAS_TORCH:
48
+
49
+ class ImageFolderDataset(Dataset): # type: ignore[misc]
50
+ """Simple dataset that loads images from a directory."""
51
+
52
+ def __init__(self, directory: str | Path, image_size: int = 299):
53
+ self.directory = Path(directory)
54
+ exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
55
+ self.files = sorted(
56
+ f for f in self.directory.iterdir() if f.suffix.lower() in exts and f.is_file()
57
+ )
58
+ self.image_size = image_size
59
+
60
+ def __len__(self) -> int:
61
+ return len(self.files)
62
+
63
+ def __getitem__(self, idx: int) -> Any:
64
+ import cv2
65
+
66
+ img = cv2.imread(str(self.files[idx]))
67
+ if img is None:
68
+ # Return zeros if image can't be loaded
69
+ return torch.zeros(3, self.image_size, self.image_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  img = cv2.resize(img, (self.image_size, self.image_size))
 
71
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
72
+ # Normalize to [0, 1] then ImageNet normalize
73
+ t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
74
+ t = _imagenet_normalize(t)
75
+ return t
76
+
77
+ class NumpyArrayDataset(Dataset): # type: ignore[misc]
78
+ """Dataset wrapping a list of numpy arrays."""
79
+
80
+ def __init__(self, images: list[np.ndarray], image_size: int = 299):
81
+ self.images = images
82
+ self.image_size = image_size
83
+
84
+ def __len__(self) -> int:
85
+ return len(self.images)
86
+
87
+ def __getitem__(self, idx: int) -> Any:
88
+ import cv2
89
+
90
+ img = self.images[idx]
91
+ if img.shape[:2] != (self.image_size, self.image_size):
92
+ img = cv2.resize(img, (self.image_size, self.image_size))
93
+ if img.shape[2] == 3:
94
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
95
+ t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
96
+ t = _imagenet_normalize(t)
97
+ return t
98
+
99
+ def _imagenet_normalize(t: Any) -> Any:
100
+ """Apply ImageNet normalization."""
101
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
102
+ std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
103
+ return (t - mean) / std
104
+
105
+ @torch.no_grad()
106
+ def _extract_features(
107
+ model: Any,
108
+ dataloader: Any,
109
+ device: Any,
110
+ ) -> np.ndarray:
111
+ """Extract InceptionV3 pool3 features from a dataloader."""
112
+ features = []
113
+ for batch in dataloader:
114
+ batch = batch.to(device)
115
+ feat = model(batch)
116
+ if isinstance(feat, tuple):
117
+ feat = feat[0]
118
+ features.append(feat.cpu().numpy())
119
+ return np.concatenate(features, axis=0)
120
 
121
 
122
  def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
 
127
 
128
 
129
  def _calculate_fid(
130
+ mu1: np.ndarray,
131
+ sigma1: np.ndarray,
132
+ mu2: np.ndarray,
133
+ sigma2: np.ndarray,
134
  ) -> float:
135
  """Calculate FID given two sets of statistics.
136
 
 
183
  if len(real_ds) == 0 or len(gen_ds) == 0:
184
  raise ValueError("Need at least 1 image in each directory")
185
 
186
+ real_loader = DataLoader(
187
+ real_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True
188
+ )
189
+ gen_loader = DataLoader(gen_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True)
190
 
191
  real_features = _extract_features(model, real_loader, dev)
192
  gen_features = _extract_features(model, gen_loader, dev)
landmarkdiff/hyperparam.py CHANGED
@@ -24,7 +24,7 @@ import json
24
  import math
25
  from dataclasses import dataclass, field
26
  from pathlib import Path
27
- from typing import Any, Iterator
28
 
29
 
30
  def _to_native(val: Any) -> Any:
@@ -99,27 +99,45 @@ class SearchSpace:
99
  self.params: dict[str, ParamSpec] = {}
100
 
101
  def add_float(
102
- self, name: str, low: float, high: float, log_scale: bool = False,
 
 
 
 
103
  ) -> SearchSpace:
104
  """Add a continuous float parameter."""
105
  self.params[name] = ParamSpec(
106
- name=name, param_type="float", low=low, high=high, log_scale=log_scale,
 
 
 
 
107
  )
108
  return self
109
 
110
  def add_int(
111
- self, name: str, low: int, high: int, step: int = 1,
 
 
 
 
112
  ) -> SearchSpace:
113
  """Add an integer parameter."""
114
  self.params[name] = ParamSpec(
115
- name=name, param_type="int", low=low, high=high, step=step,
 
 
 
 
116
  )
117
  return self
118
 
119
  def add_choice(self, name: str, choices: list[Any]) -> SearchSpace:
120
  """Add a categorical parameter."""
121
  self.params[name] = ParamSpec(
122
- name=name, param_type="choice", choices=choices,
 
 
123
  )
124
  return self
125
 
@@ -204,10 +222,7 @@ class HyperparamSearch:
204
  attempts = 0
205
  while len(trials) < n_trials and attempts < max_attempts:
206
  attempts += 1
207
- config = {
208
- name: spec.sample(rng)
209
- for name, spec in self.space.params.items()
210
- }
211
  trial = Trial(
212
  trial_id=f"trial_{len(trials):04d}",
213
  config=config,
@@ -223,14 +238,11 @@ class HyperparamSearch:
223
  import itertools
224
 
225
  param_names = list(self.space.params.keys())
226
- param_values = [
227
- self.space.params[name].grid_values(grid_points)
228
- for name in param_names
229
- ]
230
 
231
  trials = []
232
  for combo in itertools.product(*param_values):
233
- config = dict(zip(param_names, combo))
234
  trial = Trial(
235
  trial_id=f"trial_{len(trials):04d}",
236
  config=config,
@@ -240,7 +252,9 @@ class HyperparamSearch:
240
  return trials
241
 
242
  def record_result(
243
- self, trial_id: str, metrics: dict[str, float],
 
 
244
  ) -> None:
245
  """Record results for a trial."""
246
  for trial in self.trials:
@@ -251,7 +265,9 @@ class HyperparamSearch:
251
  raise KeyError(f"Trial {trial_id} not found")
252
 
253
  def best_trial(
254
- self, metric: str = "loss", lower_is_better: bool = True,
 
 
255
  ) -> Trial | None:
256
  """Get the best completed trial by a metric."""
257
  completed = [t for t in self.trials if t.status == "completed" and metric in t.result]
@@ -275,7 +291,8 @@ class HyperparamSearch:
275
  with open(cfg_path, "w") as f:
276
  yaml.safe_dump(
277
  {"trial_id": trial.trial_id, **native_config},
278
- f, default_flow_style=False,
 
279
  )
280
 
281
  # Save summary index
@@ -321,7 +338,7 @@ class HyperparamSearch:
321
  if isinstance(val, float):
322
  parts.append(f"{val:>12.6f}")
323
  else:
324
- parts.append(f"{str(val):>12s}")
325
  for m in metric_names:
326
  val = trial.result.get(m, float("nan"))
327
  parts.append(f"{val:>12.4f}")
 
24
  import math
25
  from dataclasses import dataclass, field
26
  from pathlib import Path
27
+ from typing import Any
28
 
29
 
30
  def _to_native(val: Any) -> Any:
 
99
  self.params: dict[str, ParamSpec] = {}
100
 
101
  def add_float(
102
+ self,
103
+ name: str,
104
+ low: float,
105
+ high: float,
106
+ log_scale: bool = False,
107
  ) -> SearchSpace:
108
  """Add a continuous float parameter."""
109
  self.params[name] = ParamSpec(
110
+ name=name,
111
+ param_type="float",
112
+ low=low,
113
+ high=high,
114
+ log_scale=log_scale,
115
  )
116
  return self
117
 
118
  def add_int(
119
+ self,
120
+ name: str,
121
+ low: int,
122
+ high: int,
123
+ step: int = 1,
124
  ) -> SearchSpace:
125
  """Add an integer parameter."""
126
  self.params[name] = ParamSpec(
127
+ name=name,
128
+ param_type="int",
129
+ low=low,
130
+ high=high,
131
+ step=step,
132
  )
133
  return self
134
 
135
  def add_choice(self, name: str, choices: list[Any]) -> SearchSpace:
136
  """Add a categorical parameter."""
137
  self.params[name] = ParamSpec(
138
+ name=name,
139
+ param_type="choice",
140
+ choices=choices,
141
  )
142
  return self
143
 
 
222
  attempts = 0
223
  while len(trials) < n_trials and attempts < max_attempts:
224
  attempts += 1
225
+ config = {name: spec.sample(rng) for name, spec in self.space.params.items()}
 
 
 
226
  trial = Trial(
227
  trial_id=f"trial_{len(trials):04d}",
228
  config=config,
 
238
  import itertools
239
 
240
  param_names = list(self.space.params.keys())
241
+ param_values = [self.space.params[name].grid_values(grid_points) for name in param_names]
 
 
 
242
 
243
  trials = []
244
  for combo in itertools.product(*param_values):
245
+ config = dict(zip(param_names, combo, strict=False))
246
  trial = Trial(
247
  trial_id=f"trial_{len(trials):04d}",
248
  config=config,
 
252
  return trials
253
 
254
  def record_result(
255
+ self,
256
+ trial_id: str,
257
+ metrics: dict[str, float],
258
  ) -> None:
259
  """Record results for a trial."""
260
  for trial in self.trials:
 
265
  raise KeyError(f"Trial {trial_id} not found")
266
 
267
  def best_trial(
268
+ self,
269
+ metric: str = "loss",
270
+ lower_is_better: bool = True,
271
  ) -> Trial | None:
272
  """Get the best completed trial by a metric."""
273
  completed = [t for t in self.trials if t.status == "completed" and metric in t.result]
 
291
  with open(cfg_path, "w") as f:
292
  yaml.safe_dump(
293
  {"trial_id": trial.trial_id, **native_config},
294
+ f,
295
+ default_flow_style=False,
296
  )
297
 
298
  # Save summary index
 
338
  if isinstance(val, float):
339
  parts.append(f"{val:>12.6f}")
340
  else:
341
+ parts.append(f"{val!s:>12s}")
342
  for m in metric_names:
343
  val = trial.result.get(m, float("nan"))
344
  parts.append(f"{val:>12.4f}")
landmarkdiff/inference.py CHANGED
@@ -13,7 +13,7 @@ from __future__ import annotations
13
 
14
  import sys
15
  from pathlib import Path
16
- from typing import Optional
17
 
18
  import cv2
19
  import numpy as np
@@ -21,11 +21,13 @@ import torch
21
  from PIL import Image
22
 
23
  from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks, render_landmark_image
24
- from landmarkdiff.conditioning import generate_conditioning
25
  from landmarkdiff.manipulation import apply_procedure_preset
26
  from landmarkdiff.masking import generate_surgical_mask, mask_to_3channel
27
  from landmarkdiff.synthetic.tps_warp import warp_image_tps
28
 
 
 
 
29
 
30
  def get_device() -> torch.device:
31
  if torch.backends.mps.is_available():
@@ -102,6 +104,7 @@ def mask_composite(
102
  if use_laplacian:
103
  try:
104
  from landmarkdiff.postprocess import laplacian_pyramid_blend
 
105
  return laplacian_pyramid_blend(corrected, original, mask_f)
106
  except Exception:
107
  pass
@@ -109,8 +112,7 @@ def mask_composite(
109
  # Fallback: simple alpha blend
110
  mask_3ch = mask_to_3channel(mask_f)
111
  result = (
112
- corrected.astype(np.float32) * mask_3ch
113
- + original.astype(np.float32) * (1.0 - mask_3ch)
114
  ).astype(np.uint8)
115
 
116
  return result
@@ -170,10 +172,10 @@ class LandmarkDiffPipeline:
170
  controlnet_id: str = "CrucibleAI/ControlNetMediaPipeFace",
171
  controlnet_checkpoint: str | None = None,
172
  base_model_id: str | None = None,
173
- device: Optional[torch.device] = None,
174
- dtype: Optional[torch.dtype] = None,
175
  ip_adapter_scale: float = 0.6,
176
- clinical_flags: Optional["ClinicalFlags"] = None,
177
  displacement_model_path: str | None = None,
178
  ):
179
  self.mode = mode
@@ -187,6 +189,7 @@ class LandmarkDiffPipeline:
187
  if displacement_model_path:
188
  try:
189
  from landmarkdiff.displacement_model import DisplacementModel
 
190
  self._displacement_model = DisplacementModel.load(displacement_model_path)
191
  print(f"Displacement model loaded: {self._displacement_model.procedures}")
192
  except Exception as e:
@@ -224,8 +227,8 @@ class LandmarkDiffPipeline:
224
  def _load_controlnet(self) -> None:
225
  from diffusers import (
226
  ControlNetModel,
227
- StableDiffusionControlNetPipeline,
228
  DPMSolverMultistepScheduler,
 
229
  )
230
 
231
  if self.controlnet_checkpoint:
@@ -236,12 +239,15 @@ class LandmarkDiffPipeline:
236
  ckpt_path = ckpt_path / "controlnet_ema"
237
  print(f"Loading fine-tuned ControlNet from {ckpt_path}...")
238
  controlnet = ControlNetModel.from_pretrained(
239
- str(ckpt_path), torch_dtype=self.dtype,
 
240
  )
241
  else:
242
  print(f"Loading ControlNet from {self.controlnet_id}...")
243
  controlnet = ControlNetModel.from_pretrained(
244
- self.controlnet_id, subfolder="diffusion_sd15", torch_dtype=self.dtype,
 
 
245
  )
246
  print(f"Loading base model from {self.base_model_id}...")
247
  self._pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -287,8 +293,8 @@ class LandmarkDiffPipeline:
287
 
288
  def _load_img2img(self) -> None:
289
  from diffusers import (
290
- StableDiffusionImg2ImgPipeline,
291
  DPMSolverMultistepScheduler,
 
292
  )
293
 
294
  print(f"Loading SD1.5 img2img from {self.base_model_id}...")
@@ -298,9 +304,7 @@ class LandmarkDiffPipeline:
298
  safety_checker=None,
299
  requires_safety_checker=False,
300
  )
301
- self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(
302
- self._pipe.scheduler.config
303
- )
304
  self._apply_device_optimizations()
305
 
306
  def _apply_device_optimizations(self) -> None:
@@ -329,8 +333,8 @@ class LandmarkDiffPipeline:
329
  guidance_scale: float = 9.0,
330
  controlnet_conditioning_scale: float = 0.9,
331
  strength: float = 0.5,
332
- seed: Optional[int] = None,
333
- clinical_flags: Optional["ClinicalFlags"] = None,
334
  postprocess: bool = True,
335
  use_gfpgan: bool = False,
336
  ) -> dict:
@@ -351,12 +355,14 @@ class LandmarkDiffPipeline:
351
  manipulation_mode = "preset"
352
  if self._displacement_model and procedure in self._displacement_model.procedures:
353
  try:
354
- from landmarkdiff.displacement_model import DisplacementModel
355
  rng = np.random.default_rng(seed) if seed is not None else np.random.default_rng()
356
  # Map UI intensity (0-100) to displacement model intensity (0-2)
357
  dm_intensity = intensity / 50.0 # 50 -> 1.0x mean displacement
358
  displacement = self._displacement_model.get_displacement_field(
359
- procedure, intensity=dm_intensity, noise_scale=0.3, rng=rng,
 
 
 
360
  )
361
  # Apply displacement to landmarks
362
  new_lm = face.landmarks.copy()
@@ -367,21 +373,34 @@ class LandmarkDiffPipeline:
367
  new_lm[:, 1] = np.clip(new_lm[:, 1], 0.01, 0.99)
368
  manipulated = FaceLandmarks(
369
  landmarks=new_lm,
370
- image_width=512, image_height=512,
 
371
  confidence=face.confidence,
372
  )
373
  manipulation_mode = "displacement_model"
374
  except Exception:
375
  manipulated = apply_procedure_preset(
376
- face, procedure, intensity, image_size=512, clinical_flags=flags,
 
 
 
 
377
  )
378
  else:
379
  manipulated = apply_procedure_preset(
380
- face, procedure, intensity, image_size=512, clinical_flags=flags,
 
 
 
 
381
  )
382
  landmark_img = render_landmark_image(manipulated, 512, 512)
383
  mask = generate_surgical_mask(
384
- face, procedure, 512, 512, clinical_flags=flags,
 
 
 
 
385
  )
386
 
387
  generator = None
@@ -398,14 +417,24 @@ class LandmarkDiffPipeline:
398
  elif self.mode in ("controlnet", "controlnet_ip"):
399
  ip_image = numpy_to_pil(image_512) if self._ip_adapter_loaded else None
400
  raw_output = self._generate_controlnet(
401
- image_512, landmark_img, prompt, num_inference_steps,
402
- guidance_scale, controlnet_conditioning_scale, generator,
 
 
 
 
 
403
  ip_adapter_image=ip_image,
404
  )
405
  else:
406
  raw_output = self._generate_img2img(
407
- tps_warped, mask, prompt, num_inference_steps,
408
- guidance_scale, strength, generator,
 
 
 
 
 
409
  )
410
 
411
  # Step 2: Post-processing for photorealism (neural + classical pipeline)
@@ -413,6 +442,7 @@ class LandmarkDiffPipeline:
413
  restore_used = "none"
414
  if postprocess and self.mode != "tps":
415
  from landmarkdiff.postprocess import full_postprocess
 
416
  pp_result = full_postprocess(
417
  generated=raw_output,
418
  original=image_512,
@@ -450,8 +480,13 @@ class LandmarkDiffPipeline:
450
  }
451
 
452
  def _generate_controlnet(
453
- self, image: np.ndarray, conditioning: np.ndarray,
454
- prompt: str, steps: int, cfg: float, cn_scale: float,
 
 
 
 
 
455
  generator: torch.Generator | None,
456
  ip_adapter_image: Image.Image | None = None,
457
  ) -> np.ndarray:
@@ -470,8 +505,13 @@ class LandmarkDiffPipeline:
470
  return pil_to_numpy(result.images[0])
471
 
472
  def _generate_img2img(
473
- self, image: np.ndarray, mask: np.ndarray,
474
- prompt: str, steps: int, cfg: float, strength: float,
 
 
 
 
 
475
  generator: torch.Generator | None,
476
  ) -> np.ndarray:
477
  result = self._pipe(
@@ -558,7 +598,8 @@ def run_inference(
558
  sys.exit(1)
559
 
560
  pipe = LandmarkDiffPipeline(
561
- mode=mode, ip_adapter_scale=ip_adapter_scale,
 
562
  controlnet_checkpoint=controlnet_checkpoint,
563
  displacement_model_path=displacement_model_path,
564
  )
@@ -594,18 +635,29 @@ if __name__ == "__main__":
594
  parser.add_argument("--output", default="scripts/inference_output")
595
  parser.add_argument("--seed", type=int, default=42)
596
  parser.add_argument(
597
- "--mode", default="img2img",
 
598
  choices=["img2img", "controlnet", "controlnet_ip", "tps"],
599
  )
600
  parser.add_argument("--ip-adapter-scale", type=float, default=0.6)
601
- parser.add_argument("--checkpoint", default=None,
602
- help="Path to fine-tuned ControlNet checkpoint")
603
- parser.add_argument("--displacement-model", default=None,
604
- help="Path to displacement_model.npz for data-driven manipulation")
 
 
 
 
605
  args = parser.parse_args()
606
 
607
  run_inference(
608
- args.image, args.procedure, args.intensity, args.output,
609
- args.seed, args.mode, args.ip_adapter_scale, args.checkpoint,
 
 
 
 
 
 
610
  args.displacement_model,
611
  )
 
13
 
14
  import sys
15
  from pathlib import Path
16
+ from typing import TYPE_CHECKING
17
 
18
  import cv2
19
  import numpy as np
 
21
  from PIL import Image
22
 
23
  from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks, render_landmark_image
 
24
  from landmarkdiff.manipulation import apply_procedure_preset
25
  from landmarkdiff.masking import generate_surgical_mask, mask_to_3channel
26
  from landmarkdiff.synthetic.tps_warp import warp_image_tps
27
 
28
+ if TYPE_CHECKING:
29
+ from landmarkdiff.clinical import ClinicalFlags
30
+
31
 
32
  def get_device() -> torch.device:
33
  if torch.backends.mps.is_available():
 
104
  if use_laplacian:
105
  try:
106
  from landmarkdiff.postprocess import laplacian_pyramid_blend
107
+
108
  return laplacian_pyramid_blend(corrected, original, mask_f)
109
  except Exception:
110
  pass
 
112
  # Fallback: simple alpha blend
113
  mask_3ch = mask_to_3channel(mask_f)
114
  result = (
115
+ corrected.astype(np.float32) * mask_3ch + original.astype(np.float32) * (1.0 - mask_3ch)
 
116
  ).astype(np.uint8)
117
 
118
  return result
 
172
  controlnet_id: str = "CrucibleAI/ControlNetMediaPipeFace",
173
  controlnet_checkpoint: str | None = None,
174
  base_model_id: str | None = None,
175
+ device: torch.device | None = None,
176
+ dtype: torch.dtype | None = None,
177
  ip_adapter_scale: float = 0.6,
178
+ clinical_flags: ClinicalFlags | None = None,
179
  displacement_model_path: str | None = None,
180
  ):
181
  self.mode = mode
 
189
  if displacement_model_path:
190
  try:
191
  from landmarkdiff.displacement_model import DisplacementModel
192
+
193
  self._displacement_model = DisplacementModel.load(displacement_model_path)
194
  print(f"Displacement model loaded: {self._displacement_model.procedures}")
195
  except Exception as e:
 
227
  def _load_controlnet(self) -> None:
228
  from diffusers import (
229
  ControlNetModel,
 
230
  DPMSolverMultistepScheduler,
231
+ StableDiffusionControlNetPipeline,
232
  )
233
 
234
  if self.controlnet_checkpoint:
 
239
  ckpt_path = ckpt_path / "controlnet_ema"
240
  print(f"Loading fine-tuned ControlNet from {ckpt_path}...")
241
  controlnet = ControlNetModel.from_pretrained(
242
+ str(ckpt_path),
243
+ torch_dtype=self.dtype,
244
  )
245
  else:
246
  print(f"Loading ControlNet from {self.controlnet_id}...")
247
  controlnet = ControlNetModel.from_pretrained(
248
+ self.controlnet_id,
249
+ subfolder="diffusion_sd15",
250
+ torch_dtype=self.dtype,
251
  )
252
  print(f"Loading base model from {self.base_model_id}...")
253
  self._pipe = StableDiffusionControlNetPipeline.from_pretrained(
 
293
 
294
  def _load_img2img(self) -> None:
295
  from diffusers import (
 
296
  DPMSolverMultistepScheduler,
297
+ StableDiffusionImg2ImgPipeline,
298
  )
299
 
300
  print(f"Loading SD1.5 img2img from {self.base_model_id}...")
 
304
  safety_checker=None,
305
  requires_safety_checker=False,
306
  )
307
+ self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(self._pipe.scheduler.config)
 
 
308
  self._apply_device_optimizations()
309
 
310
  def _apply_device_optimizations(self) -> None:
 
333
  guidance_scale: float = 9.0,
334
  controlnet_conditioning_scale: float = 0.9,
335
  strength: float = 0.5,
336
+ seed: int | None = None,
337
+ clinical_flags: ClinicalFlags | None = None,
338
  postprocess: bool = True,
339
  use_gfpgan: bool = False,
340
  ) -> dict:
 
355
  manipulation_mode = "preset"
356
  if self._displacement_model and procedure in self._displacement_model.procedures:
357
  try:
 
358
  rng = np.random.default_rng(seed) if seed is not None else np.random.default_rng()
359
  # Map UI intensity (0-100) to displacement model intensity (0-2)
360
  dm_intensity = intensity / 50.0 # 50 -> 1.0x mean displacement
361
  displacement = self._displacement_model.get_displacement_field(
362
+ procedure,
363
+ intensity=dm_intensity,
364
+ noise_scale=0.3,
365
+ rng=rng,
366
  )
367
  # Apply displacement to landmarks
368
  new_lm = face.landmarks.copy()
 
373
  new_lm[:, 1] = np.clip(new_lm[:, 1], 0.01, 0.99)
374
  manipulated = FaceLandmarks(
375
  landmarks=new_lm,
376
+ image_width=512,
377
+ image_height=512,
378
  confidence=face.confidence,
379
  )
380
  manipulation_mode = "displacement_model"
381
  except Exception:
382
  manipulated = apply_procedure_preset(
383
+ face,
384
+ procedure,
385
+ intensity,
386
+ image_size=512,
387
+ clinical_flags=flags,
388
  )
389
  else:
390
  manipulated = apply_procedure_preset(
391
+ face,
392
+ procedure,
393
+ intensity,
394
+ image_size=512,
395
+ clinical_flags=flags,
396
  )
397
  landmark_img = render_landmark_image(manipulated, 512, 512)
398
  mask = generate_surgical_mask(
399
+ face,
400
+ procedure,
401
+ 512,
402
+ 512,
403
+ clinical_flags=flags,
404
  )
405
 
406
  generator = None
 
417
  elif self.mode in ("controlnet", "controlnet_ip"):
418
  ip_image = numpy_to_pil(image_512) if self._ip_adapter_loaded else None
419
  raw_output = self._generate_controlnet(
420
+ image_512,
421
+ landmark_img,
422
+ prompt,
423
+ num_inference_steps,
424
+ guidance_scale,
425
+ controlnet_conditioning_scale,
426
+ generator,
427
  ip_adapter_image=ip_image,
428
  )
429
  else:
430
  raw_output = self._generate_img2img(
431
+ tps_warped,
432
+ mask,
433
+ prompt,
434
+ num_inference_steps,
435
+ guidance_scale,
436
+ strength,
437
+ generator,
438
  )
439
 
440
  # Step 2: Post-processing for photorealism (neural + classical pipeline)
 
442
  restore_used = "none"
443
  if postprocess and self.mode != "tps":
444
  from landmarkdiff.postprocess import full_postprocess
445
+
446
  pp_result = full_postprocess(
447
  generated=raw_output,
448
  original=image_512,
 
480
  }
481
 
482
  def _generate_controlnet(
483
+ self,
484
+ image: np.ndarray,
485
+ conditioning: np.ndarray,
486
+ prompt: str,
487
+ steps: int,
488
+ cfg: float,
489
+ cn_scale: float,
490
  generator: torch.Generator | None,
491
  ip_adapter_image: Image.Image | None = None,
492
  ) -> np.ndarray:
 
505
  return pil_to_numpy(result.images[0])
506
 
507
  def _generate_img2img(
508
+ self,
509
+ image: np.ndarray,
510
+ mask: np.ndarray,
511
+ prompt: str,
512
+ steps: int,
513
+ cfg: float,
514
+ strength: float,
515
  generator: torch.Generator | None,
516
  ) -> np.ndarray:
517
  result = self._pipe(
 
598
  sys.exit(1)
599
 
600
  pipe = LandmarkDiffPipeline(
601
+ mode=mode,
602
+ ip_adapter_scale=ip_adapter_scale,
603
  controlnet_checkpoint=controlnet_checkpoint,
604
  displacement_model_path=displacement_model_path,
605
  )
 
635
  parser.add_argument("--output", default="scripts/inference_output")
636
  parser.add_argument("--seed", type=int, default=42)
637
  parser.add_argument(
638
+ "--mode",
639
+ default="img2img",
640
  choices=["img2img", "controlnet", "controlnet_ip", "tps"],
641
  )
642
  parser.add_argument("--ip-adapter-scale", type=float, default=0.6)
643
+ parser.add_argument(
644
+ "--checkpoint", default=None, help="Path to fine-tuned ControlNet checkpoint"
645
+ )
646
+ parser.add_argument(
647
+ "--displacement-model",
648
+ default=None,
649
+ help="Path to displacement_model.npz for data-driven manipulation",
650
+ )
651
  args = parser.parse_args()
652
 
653
  run_inference(
654
+ args.image,
655
+ args.procedure,
656
+ args.intensity,
657
+ args.output,
658
+ args.seed,
659
+ args.mode,
660
+ args.ip_adapter_scale,
661
+ args.checkpoint,
662
  args.displacement_model,
663
  )
landmarkdiff/landmarks.py CHANGED
@@ -4,7 +4,6 @@ from __future__ import annotations
4
 
5
  from dataclasses import dataclass
6
  from pathlib import Path
7
- from typing import Optional
8
 
9
  import cv2
10
  import mediapipe as mp
@@ -12,39 +11,145 @@ import numpy as np
12
 
13
  # Region color map for visualization (BGR)
14
  REGION_COLORS: dict[str, tuple[int, int, int]] = {
15
- "jawline": (255, 255, 255), # white
16
- "eyebrow_left": (0, 255, 0), # green
17
  "eyebrow_right": (0, 255, 0),
18
- "eye_left": (255, 255, 0), # cyan
19
  "eye_right": (255, 255, 0),
20
- "nose": (0, 255, 255), # yellow
21
- "lips": (0, 0, 255), # red
22
- "iris_left": (255, 0, 255), # magenta
23
  "iris_right": (255, 0, 255),
24
  }
25
 
26
  # MediaPipe landmark index groups by anatomical region
27
  LANDMARK_REGIONS: dict[str, list[int]] = {
28
  "jawline": [
29
- 10, 338, 297, 332, 284, 251, 389, 356, 454, 323, 361, 288,
30
- 397, 365, 379, 378, 400, 377, 152, 148, 176, 149, 150, 136,
31
- 172, 58, 132, 93, 234, 127, 162, 21, 54, 103, 67, 109,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  ],
33
  "eye_left": [
34
- 33, 7, 163, 144, 145, 153, 154, 155, 133, 173, 157, 158, 159, 160, 161, 246,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  ],
36
  "eye_right": [
37
- 362, 382, 381, 380, 374, 373, 390, 249, 263, 466, 388, 387, 386, 385, 384, 398,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  ],
39
  "eyebrow_left": [70, 63, 105, 66, 107, 55, 65, 52, 53, 46],
40
  "eyebrow_right": [300, 293, 334, 296, 336, 285, 295, 282, 283, 276],
41
  "nose": [
42
- 1, 2, 4, 5, 6, 19, 94, 141, 168, 195, 197, 236, 240,
43
- 274, 275, 278, 279, 294, 326, 327, 360, 363, 370, 456, 460,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  ],
45
  "lips": [
46
- 61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291,
47
- 308, 324, 318, 402, 317, 14, 87, 178, 88, 95, 78,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  ],
49
  "iris_left": [468, 469, 470, 471, 472],
50
  "iris_right": [473, 474, 475, 476, 477],
@@ -78,7 +183,7 @@ def extract_landmarks(
78
  image: np.ndarray,
79
  min_detection_confidence: float = 0.5,
80
  min_tracking_confidence: float = 0.5,
81
- ) -> Optional[FaceLandmarks]:
82
  """Extract 478 facial landmarks from an image using MediaPipe Face Mesh.
83
 
84
  Args:
@@ -97,7 +202,9 @@ def extract_landmarks(
97
  landmarks, confidence = _extract_tasks_api(rgb, min_detection_confidence)
98
  except Exception:
99
  try:
100
- landmarks, confidence = _extract_solutions_api(rgb, min_detection_confidence, min_tracking_confidence)
 
 
101
  except Exception:
102
  return None
103
 
@@ -115,14 +222,14 @@ def extract_landmarks(
115
  def _extract_tasks_api(
116
  rgb: np.ndarray,
117
  min_confidence: float,
118
- ) -> tuple[Optional[np.ndarray], float]:
119
  """Extract landmarks using MediaPipe Tasks API (>= 0.10.20)."""
120
  FaceLandmarker = mp.tasks.vision.FaceLandmarker
121
  FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
122
  RunningMode = mp.tasks.vision.RunningMode
123
  BaseOptions = mp.tasks.BaseOptions
124
- import urllib.request
125
  import tempfile
 
126
 
127
  # Download model if not cached
128
  model_path = Path(tempfile.gettempdir()) / "face_landmarker_v2_with_blendshapes.task"
@@ -161,7 +268,7 @@ def _extract_solutions_api(
161
  rgb: np.ndarray,
162
  min_detection_confidence: float,
163
  min_tracking_confidence: float,
164
- ) -> tuple[Optional[np.ndarray], float]:
165
  """Extract landmarks using legacy MediaPipe Solutions API."""
166
  with mp.solutions.face_mesh.FaceMesh(
167
  static_image_mode=True,
@@ -224,8 +331,8 @@ def visualize_landmarks(
224
 
225
  def render_landmark_image(
226
  face: FaceLandmarks,
227
- width: Optional[int] = None,
228
- height: Optional[int] = None,
229
  radius: int = 2,
230
  ) -> np.ndarray:
231
  """Render MediaPipe face mesh tessellation on black canvas.
@@ -257,6 +364,7 @@ def render_landmark_image(
257
  # Draw tessellation mesh (what CrucibleAI ControlNet expects)
258
  try:
259
  from mediapipe.tasks.python.vision.face_landmarker import FaceLandmarksConnections
 
260
  tessellation = FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION
261
  contours = FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS
262
 
 
4
 
5
  from dataclasses import dataclass
6
  from pathlib import Path
 
7
 
8
  import cv2
9
  import mediapipe as mp
 
11
 
12
  # Region color map for visualization (BGR)
13
  REGION_COLORS: dict[str, tuple[int, int, int]] = {
14
+ "jawline": (255, 255, 255), # white
15
+ "eyebrow_left": (0, 255, 0), # green
16
  "eyebrow_right": (0, 255, 0),
17
+ "eye_left": (255, 255, 0), # cyan
18
  "eye_right": (255, 255, 0),
19
+ "nose": (0, 255, 255), # yellow
20
+ "lips": (0, 0, 255), # red
21
+ "iris_left": (255, 0, 255), # magenta
22
  "iris_right": (255, 0, 255),
23
  }
24
 
25
  # MediaPipe landmark index groups by anatomical region
26
  LANDMARK_REGIONS: dict[str, list[int]] = {
27
  "jawline": [
28
+ 10,
29
+ 338,
30
+ 297,
31
+ 332,
32
+ 284,
33
+ 251,
34
+ 389,
35
+ 356,
36
+ 454,
37
+ 323,
38
+ 361,
39
+ 288,
40
+ 397,
41
+ 365,
42
+ 379,
43
+ 378,
44
+ 400,
45
+ 377,
46
+ 152,
47
+ 148,
48
+ 176,
49
+ 149,
50
+ 150,
51
+ 136,
52
+ 172,
53
+ 58,
54
+ 132,
55
+ 93,
56
+ 234,
57
+ 127,
58
+ 162,
59
+ 21,
60
+ 54,
61
+ 103,
62
+ 67,
63
+ 109,
64
  ],
65
  "eye_left": [
66
+ 33,
67
+ 7,
68
+ 163,
69
+ 144,
70
+ 145,
71
+ 153,
72
+ 154,
73
+ 155,
74
+ 133,
75
+ 173,
76
+ 157,
77
+ 158,
78
+ 159,
79
+ 160,
80
+ 161,
81
+ 246,
82
  ],
83
  "eye_right": [
84
+ 362,
85
+ 382,
86
+ 381,
87
+ 380,
88
+ 374,
89
+ 373,
90
+ 390,
91
+ 249,
92
+ 263,
93
+ 466,
94
+ 388,
95
+ 387,
96
+ 386,
97
+ 385,
98
+ 384,
99
+ 398,
100
  ],
101
  "eyebrow_left": [70, 63, 105, 66, 107, 55, 65, 52, 53, 46],
102
  "eyebrow_right": [300, 293, 334, 296, 336, 285, 295, 282, 283, 276],
103
  "nose": [
104
+ 1,
105
+ 2,
106
+ 4,
107
+ 5,
108
+ 6,
109
+ 19,
110
+ 94,
111
+ 141,
112
+ 168,
113
+ 195,
114
+ 197,
115
+ 236,
116
+ 240,
117
+ 274,
118
+ 275,
119
+ 278,
120
+ 279,
121
+ 294,
122
+ 326,
123
+ 327,
124
+ 360,
125
+ 363,
126
+ 370,
127
+ 456,
128
+ 460,
129
  ],
130
  "lips": [
131
+ 61,
132
+ 146,
133
+ 91,
134
+ 181,
135
+ 84,
136
+ 17,
137
+ 314,
138
+ 405,
139
+ 321,
140
+ 375,
141
+ 291,
142
+ 308,
143
+ 324,
144
+ 318,
145
+ 402,
146
+ 317,
147
+ 14,
148
+ 87,
149
+ 178,
150
+ 88,
151
+ 95,
152
+ 78,
153
  ],
154
  "iris_left": [468, 469, 470, 471, 472],
155
  "iris_right": [473, 474, 475, 476, 477],
 
183
  image: np.ndarray,
184
  min_detection_confidence: float = 0.5,
185
  min_tracking_confidence: float = 0.5,
186
+ ) -> FaceLandmarks | None:
187
  """Extract 478 facial landmarks from an image using MediaPipe Face Mesh.
188
 
189
  Args:
 
202
  landmarks, confidence = _extract_tasks_api(rgb, min_detection_confidence)
203
  except Exception:
204
  try:
205
+ landmarks, confidence = _extract_solutions_api(
206
+ rgb, min_detection_confidence, min_tracking_confidence
207
+ )
208
  except Exception:
209
  return None
210
 
 
222
  def _extract_tasks_api(
223
  rgb: np.ndarray,
224
  min_confidence: float,
225
+ ) -> tuple[np.ndarray | None, float]:
226
  """Extract landmarks using MediaPipe Tasks API (>= 0.10.20)."""
227
  FaceLandmarker = mp.tasks.vision.FaceLandmarker
228
  FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
229
  RunningMode = mp.tasks.vision.RunningMode
230
  BaseOptions = mp.tasks.BaseOptions
 
231
  import tempfile
232
+ import urllib.request
233
 
234
  # Download model if not cached
235
  model_path = Path(tempfile.gettempdir()) / "face_landmarker_v2_with_blendshapes.task"
 
268
  rgb: np.ndarray,
269
  min_detection_confidence: float,
270
  min_tracking_confidence: float,
271
+ ) -> tuple[np.ndarray | None, float]:
272
  """Extract landmarks using legacy MediaPipe Solutions API."""
273
  with mp.solutions.face_mesh.FaceMesh(
274
  static_image_mode=True,
 
331
 
332
  def render_landmark_image(
333
  face: FaceLandmarks,
334
+ width: int | None = None,
335
+ height: int | None = None,
336
  radius: int = 2,
337
  ) -> np.ndarray:
338
  """Render MediaPipe face mesh tessellation on black canvas.
 
364
  # Draw tessellation mesh (what CrucibleAI ControlNet expects)
365
  try:
366
  from mediapipe.tasks.python.vision.face_landmarker import FaceLandmarksConnections
367
+
368
  tessellation = FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION
369
  contours = FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS
370
 
landmarkdiff/log.py CHANGED
@@ -46,10 +46,12 @@ def setup_logging(
46
 
47
  if not _CONFIGURED:
48
  handler = logging.StreamHandler(stream or sys.stderr)
49
- handler.setFormatter(logging.Formatter(
50
- fmt or LOG_FORMAT,
51
- datefmt=LOG_DATE_FORMAT,
52
- ))
 
 
53
  root_logger.addHandler(handler)
54
  # Prevent propagation to root logger to avoid duplicate messages
55
  root_logger.propagate = False
 
46
 
47
  if not _CONFIGURED:
48
  handler = logging.StreamHandler(stream or sys.stderr)
49
+ handler.setFormatter(
50
+ logging.Formatter(
51
+ fmt or LOG_FORMAT,
52
+ datefmt=LOG_DATE_FORMAT,
53
+ )
54
+ )
55
  root_logger.addHandler(handler)
56
  # Prevent propagation to root logger to avoid duplicate messages
57
  root_logger.propagate = False
landmarkdiff/losses.py CHANGED
@@ -1,6 +1,7 @@
1
  """4-term loss function module for ControlNet fine-tuning.
2
 
3
- L_total = L_diffusion + w_landmark * L_landmark + w_identity * L_identity + w_perceptual * L_perceptual
 
4
 
5
  Phase A (synthetic TPS data): L_diffusion ONLY. No perceptual loss against
6
  rubbery TPS warps — it would penalize realism.
@@ -92,11 +93,16 @@ class IdentityLoss:
92
  return
93
  try:
94
  from insightface.app import FaceAnalysis
 
95
  self._app = FaceAnalysis(
96
  name="buffalo_l",
97
  providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
98
  )
99
- ctx_id = device.index if device.type == "cuda" and device.index is not None else (0 if device.type == "cuda" else -1)
 
 
 
 
100
  self._app.prepare(ctx_id=ctx_id, det_size=(320, 320))
101
  self._has_arcface = True
102
  except Exception:
@@ -114,6 +120,7 @@ class IdentityLoss:
114
  """
115
  if self._has_arcface:
116
  import numpy as np
 
117
  embeddings = []
118
  valid_mask = []
119
  for i in range(image_tensor.shape[0]):
@@ -152,7 +159,9 @@ class IdentityLoss:
152
 
153
  # Resize to 112x112 for ArcFace
154
  pred_112 = F.interpolate(pred_crop, size=(112, 112), mode="bilinear", align_corners=False)
155
- target_112 = F.interpolate(target_crop, size=(112, 112), mode="bilinear", align_corners=False)
 
 
156
 
157
  # Normalize to [-1, 1]
158
  pred_norm = pred_112 * 2 - 1
@@ -163,7 +172,7 @@ class IdentityLoss:
163
  target_emb, target_valid = self._extract_embedding(target_norm)
164
 
165
  # Only compute loss for samples where both faces were detected
166
- valid = [p and t for p, t in zip(pred_valid, target_valid)]
167
  if not any(valid):
168
  return torch.tensor(0.0, device=pred_image.device)
169
 
@@ -216,6 +225,7 @@ class PerceptualLoss:
216
  if self._lpips is None:
217
  try:
218
  import lpips
 
219
  self._lpips = lpips.LPIPS(net="alex").to(device)
220
  self._lpips.eval()
221
  for p in self._lpips.parameters():
@@ -225,9 +235,9 @@ class PerceptualLoss:
225
 
226
  def __call__(
227
  self,
228
- pred: torch.Tensor, # (B, 3, H, W) in [0, 1]
229
  target: torch.Tensor,
230
- mask: torch.Tensor, # (B, 1, H, W) surgical mask [0, 1]
231
  ) -> torch.Tensor:
232
  self._ensure_loaded(pred.device)
233
 
@@ -289,6 +299,7 @@ class CombinedLoss:
289
  # or ONNX-based fallback
290
  if use_differentiable_arcface:
291
  from landmarkdiff.arcface_torch import ArcFaceLoss
 
292
  self.identity_loss = ArcFaceLoss(weights_path=arcface_weights_path)
293
  else:
294
  self.identity_loss = IdentityLoss()
 
1
  """4-term loss function module for ControlNet fine-tuning.
2
 
3
+ L_total = L_diffusion + w_landmark * L_landmark
4
+ + w_identity * L_identity + w_perceptual * L_perceptual
5
 
6
  Phase A (synthetic TPS data): L_diffusion ONLY. No perceptual loss against
7
  rubbery TPS warps — it would penalize realism.
 
93
  return
94
  try:
95
  from insightface.app import FaceAnalysis
96
+
97
  self._app = FaceAnalysis(
98
  name="buffalo_l",
99
  providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
100
  )
101
+ ctx_id = (
102
+ device.index
103
+ if device.type == "cuda" and device.index is not None
104
+ else (0 if device.type == "cuda" else -1)
105
+ )
106
  self._app.prepare(ctx_id=ctx_id, det_size=(320, 320))
107
  self._has_arcface = True
108
  except Exception:
 
120
  """
121
  if self._has_arcface:
122
  import numpy as np
123
+
124
  embeddings = []
125
  valid_mask = []
126
  for i in range(image_tensor.shape[0]):
 
159
 
160
  # Resize to 112x112 for ArcFace
161
  pred_112 = F.interpolate(pred_crop, size=(112, 112), mode="bilinear", align_corners=False)
162
+ target_112 = F.interpolate(
163
+ target_crop, size=(112, 112), mode="bilinear", align_corners=False
164
+ )
165
 
166
  # Normalize to [-1, 1]
167
  pred_norm = pred_112 * 2 - 1
 
172
  target_emb, target_valid = self._extract_embedding(target_norm)
173
 
174
  # Only compute loss for samples where both faces were detected
175
+ valid = [p and t for p, t in zip(pred_valid, target_valid, strict=False)]
176
  if not any(valid):
177
  return torch.tensor(0.0, device=pred_image.device)
178
 
 
225
  if self._lpips is None:
226
  try:
227
  import lpips
228
+
229
  self._lpips = lpips.LPIPS(net="alex").to(device)
230
  self._lpips.eval()
231
  for p in self._lpips.parameters():
 
235
 
236
  def __call__(
237
  self,
238
+ pred: torch.Tensor, # (B, 3, H, W) in [0, 1]
239
  target: torch.Tensor,
240
+ mask: torch.Tensor, # (B, 1, H, W) surgical mask [0, 1]
241
  ) -> torch.Tensor:
242
  self._ensure_loaded(pred.device)
243
 
 
299
  # or ONNX-based fallback
300
  if use_differentiable_arcface:
301
  from landmarkdiff.arcface_torch import ArcFaceLoss
302
+
303
  self.identity_loss = ArcFaceLoss(weights_path=arcface_weights_path)
304
  else:
305
  self.identity_loss = IdentityLoss()
landmarkdiff/manipulation.py CHANGED
@@ -7,11 +7,11 @@ mm inputs only in v3+ with FLAME calibrated metric space.
7
  from __future__ import annotations
8
 
9
  from dataclasses import dataclass
10
- from typing import Optional, TYPE_CHECKING
11
 
12
  import numpy as np
13
 
14
- from landmarkdiff.landmarks import FaceLandmarks, LANDMARK_REGIONS
15
 
16
  if TYPE_CHECKING:
17
  from landmarkdiff.clinical import ClinicalFlags
@@ -23,38 +23,184 @@ class DeformationHandle:
23
 
24
  landmark_index: int
25
  displacement: np.ndarray # (2,) or (3,) pixel displacement
26
- influence_radius: float # Gaussian RBF radius in pixels
27
 
28
 
29
  # Procedure-specific landmark indices from the technical specification
30
  PROCEDURE_LANDMARKS: dict[str, list[int]] = {
31
  "rhinoplasty": [
32
- 1, 2, 4, 5, 6, 19, 94, 141, 168, 195, 197, 236, 240,
33
- 274, 275, 278, 279, 294, 326, 327, 360, 363, 370, 456, 460,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  ],
35
  "blepharoplasty": [
36
- 33, 7, 163, 144, 145, 153, 154, 155, 157, 158, 159, 160, 161, 246,
37
- 362, 382, 381, 380, 374, 373, 390, 249, 263, 466, 388, 387, 386,
38
- 385, 384, 398,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  ],
40
  "rhytidectomy": [
41
- 10, 21, 54, 58, 67, 93, 103, 109, 127, 132, 136, 150, 162, 172,
42
- 176, 187, 207, 213, 234, 284, 297, 323, 332, 338, 356, 361, 365,
43
- 379, 389, 397, 400, 427, 454,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  ],
45
  "orthognathic": [
46
- 0, 17, 18, 36, 37, 39, 40, 57, 61, 78, 80, 81, 82, 84, 87, 88,
47
- 91, 95, 146, 167, 169, 170, 175, 181, 191, 200, 201, 202, 204,
48
- 208, 211, 212, 214, 269, 270, 291, 311, 312, 317, 321, 324, 325,
49
- 375, 396, 405, 407, 415,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  ],
51
  "brow_lift": [
52
- 70, 63, 105, 66, 107, # left brow
53
- 300, 293, 334, 296, 336, # right brow
54
- 9, 8, 10, 109, 67, 103, 338, 297, 332, # forehead/upper face
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  ],
56
  "mentoplasty": [
57
- 148, 149, 150, 152, 171, 175, 176, 377,
 
 
 
 
 
 
 
58
  ],
59
  }
60
  # Default influence radii per procedure (in pixels at 512x512)
@@ -78,7 +224,7 @@ def gaussian_rbf_deform(
78
  displacement = handle.displacement[:2]
79
 
80
  distances_sq = np.sum((landmarks[:, :2] - center) ** 2, axis=1)
81
- weights = np.exp(-distances_sq / (2.0 * handle.influence_radius ** 2))
82
 
83
  result[:, 0] += displacement[0] * weights
84
  result[:, 1] += displacement[1] * weights
@@ -94,8 +240,8 @@ def apply_procedure_preset(
94
  procedure: str,
95
  intensity: float = 50.0,
96
  image_size: int = 512,
97
- clinical_flags: Optional["ClinicalFlags"] = None,
98
- displacement_model_path: Optional[str] = None,
99
  noise_scale: float = 0.0,
100
  ) -> FaceLandmarks:
101
  """Apply a surgical procedure preset to landmarks.
@@ -123,7 +269,11 @@ def apply_procedure_preset(
123
  # Data-driven displacement mode
124
  if displacement_model_path is not None:
125
  return _apply_data_driven(
126
- face, procedure, scale, displacement_model_path, noise_scale,
 
 
 
 
127
  )
128
 
129
  indices = PROCEDURE_LANDMARKS[procedure]
@@ -140,6 +290,7 @@ def apply_procedure_preset(
140
  # Bell's palsy: remove handles on the affected (paralyzed) side
141
  if clinical_flags and clinical_flags.bells_palsy:
142
  from landmarkdiff.clinical import get_bells_palsy_side_indices
 
143
  affected = get_bells_palsy_side_indices(clinical_flags.bells_palsy_side)
144
  affected_indices = set()
145
  for region_indices in affected.values():
@@ -219,48 +370,58 @@ def _get_procedure_handles(
219
  left_alar = [240, 236, 141, 363, 370]
220
  for idx in left_alar:
221
  if idx in indices:
222
- handles.append(DeformationHandle(
223
- landmark_index=idx,
224
- displacement=np.array([2.5 * scale, 0.0]),
225
- influence_radius=radius * 0.6,
226
- ))
 
 
227
  # right nostril -> move LEFT (-X)
228
  right_alar = [460, 456, 274, 275, 278, 279]
229
  for idx in right_alar:
230
  if idx in indices:
231
- handles.append(DeformationHandle(
232
- landmark_index=idx,
233
- displacement=np.array([-2.5 * scale, 0.0]),
234
- influence_radius=radius * 0.6,
235
- ))
 
 
236
 
237
  # --- Tip refinement: subtle upward rotation + narrowing ---
238
  tip_indices = [1, 2, 94, 19]
239
  for idx in tip_indices:
240
  if idx in indices:
241
- handles.append(DeformationHandle(
242
- landmark_index=idx,
243
- displacement=np.array([0.0, -2.0 * scale]),
244
- influence_radius=radius * 0.5,
245
- ))
 
 
246
 
247
  # --- Dorsum narrowing: bilateral squeeze of nasal bridge ---
248
  dorsum_left = [195, 197, 236]
249
  for idx in dorsum_left:
250
  if idx in indices:
251
- handles.append(DeformationHandle(
252
- landmark_index=idx,
253
- displacement=np.array([1.5 * scale, 0.0]),
254
- influence_radius=radius * 0.5,
255
- ))
 
 
256
  dorsum_right = [326, 327, 456]
257
  for idx in dorsum_right:
258
  if idx in indices:
259
- handles.append(DeformationHandle(
260
- landmark_index=idx,
261
- displacement=np.array([-1.5 * scale, 0.0]),
262
- influence_radius=radius * 0.5,
263
- ))
 
 
264
 
265
  elif procedure == "blepharoplasty":
266
  # --- Upper lid elevation (primary effect) ---
@@ -268,31 +429,37 @@ def _get_procedure_handles(
268
  upper_lid_right = [386, 385, 384]
269
  for idx in upper_lid_left + upper_lid_right:
270
  if idx in indices:
271
- handles.append(DeformationHandle(
272
- landmark_index=idx,
273
- displacement=np.array([0.0, -2.0 * scale]),
274
- influence_radius=radius,
275
- ))
 
 
276
  # --- Medial/lateral lid corners: less displacement (tapered) ---
277
  corner_left = [158, 157, 133, 33]
278
  corner_right = [387, 388, 362, 263]
279
  for idx in corner_left + corner_right:
280
  if idx in indices:
281
- handles.append(DeformationHandle(
282
- landmark_index=idx,
283
- displacement=np.array([0.0, -0.8 * scale]),
284
- influence_radius=radius * 0.7,
285
- ))
 
 
286
  # --- Subtle lower lid tightening ---
287
  lower_lid_left = [145, 153, 154]
288
  lower_lid_right = [374, 380, 381]
289
  for idx in lower_lid_left + lower_lid_right:
290
  if idx in indices:
291
- handles.append(DeformationHandle(
292
- landmark_index=idx,
293
- displacement=np.array([0.0, 0.5 * scale]),
294
- influence_radius=radius * 0.5,
295
- ))
 
 
296
 
297
  elif procedure == "rhytidectomy":
298
  # Different displacement vectors by anatomical sub-region.
@@ -300,82 +467,100 @@ def _get_procedure_handles(
300
  jowl_left = [132, 136, 172, 58, 150, 176]
301
  for idx in jowl_left:
302
  if idx in indices:
303
- handles.append(DeformationHandle(
304
- landmark_index=idx,
305
- displacement=np.array([-2.5 * scale, -3.0 * scale]),
306
- influence_radius=radius,
307
- ))
 
 
308
  jowl_right = [361, 365, 397, 288, 379, 400]
309
  for idx in jowl_right:
310
  if idx in indices:
311
- handles.append(DeformationHandle(
312
- landmark_index=idx,
313
- displacement=np.array([2.5 * scale, -3.0 * scale]),
314
- influence_radius=radius,
315
- ))
 
 
316
  # Chin/submental: upward only (no lateral)
317
  chin = [152, 148, 377, 378]
318
  for idx in chin:
319
  if idx in indices:
320
- handles.append(DeformationHandle(
321
- landmark_index=idx,
322
- displacement=np.array([0.0, -2.0 * scale]),
323
- influence_radius=radius * 0.8,
324
- ))
 
 
325
  # Temple/upper face: very mild lift
326
  temple_left = [10, 21, 54, 67, 103, 109, 162, 127]
327
  temple_right = [284, 297, 332, 338, 323, 356, 389, 454]
328
  for idx in temple_left:
329
  if idx in indices:
330
- handles.append(DeformationHandle(
331
- landmark_index=idx,
332
- displacement=np.array([-0.5 * scale, -1.0 * scale]),
333
- influence_radius=radius * 0.6,
334
- ))
 
 
335
  for idx in temple_right:
336
  if idx in indices:
337
- handles.append(DeformationHandle(
338
- landmark_index=idx,
339
- displacement=np.array([0.5 * scale, -1.0 * scale]),
340
- influence_radius=radius * 0.6,
341
- ))
 
 
342
 
343
  elif procedure == "orthognathic":
344
  # --- Mandible repositioning: move jaw up and forward (visible as upward in 2D) ---
345
  lower_jaw = [17, 18, 200, 201, 202, 204, 208, 211, 212, 214]
346
  for idx in lower_jaw:
347
  if idx in indices:
348
- handles.append(DeformationHandle(
349
- landmark_index=idx,
350
- displacement=np.array([0.0, -3.0 * scale]),
351
- influence_radius=radius,
352
- ))
 
 
353
  # --- Chin projection: move chin point forward/upward ---
354
  chin_pts = [175, 170, 169, 167, 396]
355
  for idx in chin_pts:
356
  if idx in indices:
357
- handles.append(DeformationHandle(
358
- landmark_index=idx,
359
- displacement=np.array([0.0, -2.0 * scale]),
360
- influence_radius=radius * 0.7,
361
- ))
 
 
362
  # --- Lateral jaw: bilateral symmetric inward pull for narrowing ---
363
  jaw_left = [57, 61, 78, 91, 95, 146, 181]
364
  for idx in jaw_left:
365
  if idx in indices:
366
- handles.append(DeformationHandle(
367
- landmark_index=idx,
368
- displacement=np.array([1.5 * scale, -1.0 * scale]),
369
- influence_radius=radius * 0.8,
370
- ))
 
 
371
  jaw_right = [291, 311, 312, 321, 324, 325, 375, 405]
372
  for idx in jaw_right:
373
  if idx in indices:
374
- handles.append(DeformationHandle(
375
- landmark_index=idx,
376
- displacement=np.array([-1.5 * scale, -1.0 * scale]),
377
- influence_radius=radius * 0.8,
378
- ))
 
 
379
 
380
  elif procedure == "brow_lift":
381
  # --- Brow elevation ---
@@ -386,56 +571,68 @@ def _get_procedure_handles(
386
  left_weights = [0.7, 0.8, 0.9, 1.0, 1.1]
387
  for i, idx in enumerate(brow_left):
388
  if idx in indices:
389
- handles.append(DeformationHandle(
390
- landmark_index=idx,
391
- displacement=np.array([0.0, -4.0 * left_weights[i] * scale]),
392
- influence_radius=radius,
393
- ))
 
 
394
 
395
  right_weights = [0.7, 0.8, 0.9, 1.0, 1.1]
396
  for i, idx in enumerate(brow_right):
397
  if idx in indices:
398
- handles.append(DeformationHandle(
399
- landmark_index=idx,
400
- displacement=np.array([0.0, -4.0 * right_weights[i] * scale]),
401
- influence_radius=radius,
402
- ))
 
 
403
 
404
  # --- Forehead smoothing / subtle lift ---
405
  forehead = [9, 8, 10, 109, 67, 103, 338, 297, 332]
406
  for idx in forehead:
407
  if idx in indices:
408
- handles.append(DeformationHandle(
409
- landmark_index=idx,
410
- displacement=np.array([0.0, -1.5 * scale]),
411
- influence_radius=radius * 1.2,
412
- ))
 
 
413
  elif procedure == "mentoplasty":
414
  # --- Chin tip advancement: move chin forward (upward in 2D) ---
415
  chin_tip = [152, 175]
416
  for idx in chin_tip:
417
  if idx in indices:
418
- handles.append(DeformationHandle(
419
- landmark_index=idx,
420
- displacement=np.array([0.0, -4.0 * scale]),
421
- influence_radius=radius,
422
- ))
 
 
423
  # --- Lower chin contour: follow tip with softer displacement ---
424
  lower_contour = [148, 149, 150, 176, 377]
425
  for idx in lower_contour:
426
  if idx in indices:
427
- handles.append(DeformationHandle(
428
- landmark_index=idx,
429
- displacement=np.array([0.0, -2.5 * scale]),
430
- influence_radius=radius * 0.8,
431
- ))
 
 
432
  # --- Jaw angles: minimal upward pull for natural transition ---
433
  jaw_angles = [171, 396]
434
  for idx in jaw_angles:
435
  if idx in indices:
436
- handles.append(DeformationHandle(
437
- landmark_index=idx,
438
- displacement=np.array([0.0, -1.0 * scale]),
439
- influence_radius=radius * 0.6,
440
- ))
 
 
441
  return handles
 
7
  from __future__ import annotations
8
 
9
  from dataclasses import dataclass
10
+ from typing import TYPE_CHECKING
11
 
12
  import numpy as np
13
 
14
+ from landmarkdiff.landmarks import FaceLandmarks
15
 
16
  if TYPE_CHECKING:
17
  from landmarkdiff.clinical import ClinicalFlags
 
23
 
24
  landmark_index: int
25
  displacement: np.ndarray # (2,) or (3,) pixel displacement
26
+ influence_radius: float # Gaussian RBF radius in pixels
27
 
28
 
29
  # Procedure-specific landmark indices from the technical specification
30
  PROCEDURE_LANDMARKS: dict[str, list[int]] = {
31
  "rhinoplasty": [
32
+ 1,
33
+ 2,
34
+ 4,
35
+ 5,
36
+ 6,
37
+ 19,
38
+ 94,
39
+ 141,
40
+ 168,
41
+ 195,
42
+ 197,
43
+ 236,
44
+ 240,
45
+ 274,
46
+ 275,
47
+ 278,
48
+ 279,
49
+ 294,
50
+ 326,
51
+ 327,
52
+ 360,
53
+ 363,
54
+ 370,
55
+ 456,
56
+ 460,
57
  ],
58
  "blepharoplasty": [
59
+ 33,
60
+ 7,
61
+ 163,
62
+ 144,
63
+ 145,
64
+ 153,
65
+ 154,
66
+ 155,
67
+ 157,
68
+ 158,
69
+ 159,
70
+ 160,
71
+ 161,
72
+ 246,
73
+ 362,
74
+ 382,
75
+ 381,
76
+ 380,
77
+ 374,
78
+ 373,
79
+ 390,
80
+ 249,
81
+ 263,
82
+ 466,
83
+ 388,
84
+ 387,
85
+ 386,
86
+ 385,
87
+ 384,
88
+ 398,
89
  ],
90
  "rhytidectomy": [
91
+ 10,
92
+ 21,
93
+ 54,
94
+ 58,
95
+ 67,
96
+ 93,
97
+ 103,
98
+ 109,
99
+ 127,
100
+ 132,
101
+ 136,
102
+ 150,
103
+ 162,
104
+ 172,
105
+ 176,
106
+ 187,
107
+ 207,
108
+ 213,
109
+ 234,
110
+ 284,
111
+ 297,
112
+ 323,
113
+ 332,
114
+ 338,
115
+ 356,
116
+ 361,
117
+ 365,
118
+ 379,
119
+ 389,
120
+ 397,
121
+ 400,
122
+ 427,
123
+ 454,
124
  ],
125
  "orthognathic": [
126
+ 0,
127
+ 17,
128
+ 18,
129
+ 36,
130
+ 37,
131
+ 39,
132
+ 40,
133
+ 57,
134
+ 61,
135
+ 78,
136
+ 80,
137
+ 81,
138
+ 82,
139
+ 84,
140
+ 87,
141
+ 88,
142
+ 91,
143
+ 95,
144
+ 146,
145
+ 167,
146
+ 169,
147
+ 170,
148
+ 175,
149
+ 181,
150
+ 191,
151
+ 200,
152
+ 201,
153
+ 202,
154
+ 204,
155
+ 208,
156
+ 211,
157
+ 212,
158
+ 214,
159
+ 269,
160
+ 270,
161
+ 291,
162
+ 311,
163
+ 312,
164
+ 317,
165
+ 321,
166
+ 324,
167
+ 325,
168
+ 375,
169
+ 396,
170
+ 405,
171
+ 407,
172
+ 415,
173
  ],
174
  "brow_lift": [
175
+ 70,
176
+ 63,
177
+ 105,
178
+ 66,
179
+ 107, # left brow
180
+ 300,
181
+ 293,
182
+ 334,
183
+ 296,
184
+ 336, # right brow
185
+ 9,
186
+ 8,
187
+ 10,
188
+ 109,
189
+ 67,
190
+ 103,
191
+ 338,
192
+ 297,
193
+ 332, # forehead/upper face
194
  ],
195
  "mentoplasty": [
196
+ 148,
197
+ 149,
198
+ 150,
199
+ 152,
200
+ 171,
201
+ 175,
202
+ 176,
203
+ 377,
204
  ],
205
  }
206
  # Default influence radii per procedure (in pixels at 512x512)
 
224
  displacement = handle.displacement[:2]
225
 
226
  distances_sq = np.sum((landmarks[:, :2] - center) ** 2, axis=1)
227
+ weights = np.exp(-distances_sq / (2.0 * handle.influence_radius**2))
228
 
229
  result[:, 0] += displacement[0] * weights
230
  result[:, 1] += displacement[1] * weights
 
240
  procedure: str,
241
  intensity: float = 50.0,
242
  image_size: int = 512,
243
+ clinical_flags: ClinicalFlags | None = None,
244
+ displacement_model_path: str | None = None,
245
  noise_scale: float = 0.0,
246
  ) -> FaceLandmarks:
247
  """Apply a surgical procedure preset to landmarks.
 
269
  # Data-driven displacement mode
270
  if displacement_model_path is not None:
271
  return _apply_data_driven(
272
+ face,
273
+ procedure,
274
+ scale,
275
+ displacement_model_path,
276
+ noise_scale,
277
  )
278
 
279
  indices = PROCEDURE_LANDMARKS[procedure]
 
290
  # Bell's palsy: remove handles on the affected (paralyzed) side
291
  if clinical_flags and clinical_flags.bells_palsy:
292
  from landmarkdiff.clinical import get_bells_palsy_side_indices
293
+
294
  affected = get_bells_palsy_side_indices(clinical_flags.bells_palsy_side)
295
  affected_indices = set()
296
  for region_indices in affected.values():
 
370
  left_alar = [240, 236, 141, 363, 370]
371
  for idx in left_alar:
372
  if idx in indices:
373
+ handles.append(
374
+ DeformationHandle(
375
+ landmark_index=idx,
376
+ displacement=np.array([2.5 * scale, 0.0]),
377
+ influence_radius=radius * 0.6,
378
+ )
379
+ )
380
  # right nostril -> move LEFT (-X)
381
  right_alar = [460, 456, 274, 275, 278, 279]
382
  for idx in right_alar:
383
  if idx in indices:
384
+ handles.append(
385
+ DeformationHandle(
386
+ landmark_index=idx,
387
+ displacement=np.array([-2.5 * scale, 0.0]),
388
+ influence_radius=radius * 0.6,
389
+ )
390
+ )
391
 
392
  # --- Tip refinement: subtle upward rotation + narrowing ---
393
  tip_indices = [1, 2, 94, 19]
394
  for idx in tip_indices:
395
  if idx in indices:
396
+ handles.append(
397
+ DeformationHandle(
398
+ landmark_index=idx,
399
+ displacement=np.array([0.0, -2.0 * scale]),
400
+ influence_radius=radius * 0.5,
401
+ )
402
+ )
403
 
404
  # --- Dorsum narrowing: bilateral squeeze of nasal bridge ---
405
  dorsum_left = [195, 197, 236]
406
  for idx in dorsum_left:
407
  if idx in indices:
408
+ handles.append(
409
+ DeformationHandle(
410
+ landmark_index=idx,
411
+ displacement=np.array([1.5 * scale, 0.0]),
412
+ influence_radius=radius * 0.5,
413
+ )
414
+ )
415
  dorsum_right = [326, 327, 456]
416
  for idx in dorsum_right:
417
  if idx in indices:
418
+ handles.append(
419
+ DeformationHandle(
420
+ landmark_index=idx,
421
+ displacement=np.array([-1.5 * scale, 0.0]),
422
+ influence_radius=radius * 0.5,
423
+ )
424
+ )
425
 
426
  elif procedure == "blepharoplasty":
427
  # --- Upper lid elevation (primary effect) ---
 
429
  upper_lid_right = [386, 385, 384]
430
  for idx in upper_lid_left + upper_lid_right:
431
  if idx in indices:
432
+ handles.append(
433
+ DeformationHandle(
434
+ landmark_index=idx,
435
+ displacement=np.array([0.0, -2.0 * scale]),
436
+ influence_radius=radius,
437
+ )
438
+ )
439
  # --- Medial/lateral lid corners: less displacement (tapered) ---
440
  corner_left = [158, 157, 133, 33]
441
  corner_right = [387, 388, 362, 263]
442
  for idx in corner_left + corner_right:
443
  if idx in indices:
444
+ handles.append(
445
+ DeformationHandle(
446
+ landmark_index=idx,
447
+ displacement=np.array([0.0, -0.8 * scale]),
448
+ influence_radius=radius * 0.7,
449
+ )
450
+ )
451
  # --- Subtle lower lid tightening ---
452
  lower_lid_left = [145, 153, 154]
453
  lower_lid_right = [374, 380, 381]
454
  for idx in lower_lid_left + lower_lid_right:
455
  if idx in indices:
456
+ handles.append(
457
+ DeformationHandle(
458
+ landmark_index=idx,
459
+ displacement=np.array([0.0, 0.5 * scale]),
460
+ influence_radius=radius * 0.5,
461
+ )
462
+ )
463
 
464
  elif procedure == "rhytidectomy":
465
  # Different displacement vectors by anatomical sub-region.
 
467
  jowl_left = [132, 136, 172, 58, 150, 176]
468
  for idx in jowl_left:
469
  if idx in indices:
470
+ handles.append(
471
+ DeformationHandle(
472
+ landmark_index=idx,
473
+ displacement=np.array([-2.5 * scale, -3.0 * scale]),
474
+ influence_radius=radius,
475
+ )
476
+ )
477
  jowl_right = [361, 365, 397, 288, 379, 400]
478
  for idx in jowl_right:
479
  if idx in indices:
480
+ handles.append(
481
+ DeformationHandle(
482
+ landmark_index=idx,
483
+ displacement=np.array([2.5 * scale, -3.0 * scale]),
484
+ influence_radius=radius,
485
+ )
486
+ )
487
  # Chin/submental: upward only (no lateral)
488
  chin = [152, 148, 377, 378]
489
  for idx in chin:
490
  if idx in indices:
491
+ handles.append(
492
+ DeformationHandle(
493
+ landmark_index=idx,
494
+ displacement=np.array([0.0, -2.0 * scale]),
495
+ influence_radius=radius * 0.8,
496
+ )
497
+ )
498
  # Temple/upper face: very mild lift
499
  temple_left = [10, 21, 54, 67, 103, 109, 162, 127]
500
  temple_right = [284, 297, 332, 338, 323, 356, 389, 454]
501
  for idx in temple_left:
502
  if idx in indices:
503
+ handles.append(
504
+ DeformationHandle(
505
+ landmark_index=idx,
506
+ displacement=np.array([-0.5 * scale, -1.0 * scale]),
507
+ influence_radius=radius * 0.6,
508
+ )
509
+ )
510
  for idx in temple_right:
511
  if idx in indices:
512
+ handles.append(
513
+ DeformationHandle(
514
+ landmark_index=idx,
515
+ displacement=np.array([0.5 * scale, -1.0 * scale]),
516
+ influence_radius=radius * 0.6,
517
+ )
518
+ )
519
 
520
  elif procedure == "orthognathic":
521
  # --- Mandible repositioning: move jaw up and forward (visible as upward in 2D) ---
522
  lower_jaw = [17, 18, 200, 201, 202, 204, 208, 211, 212, 214]
523
  for idx in lower_jaw:
524
  if idx in indices:
525
+ handles.append(
526
+ DeformationHandle(
527
+ landmark_index=idx,
528
+ displacement=np.array([0.0, -3.0 * scale]),
529
+ influence_radius=radius,
530
+ )
531
+ )
532
  # --- Chin projection: move chin point forward/upward ---
533
  chin_pts = [175, 170, 169, 167, 396]
534
  for idx in chin_pts:
535
  if idx in indices:
536
+ handles.append(
537
+ DeformationHandle(
538
+ landmark_index=idx,
539
+ displacement=np.array([0.0, -2.0 * scale]),
540
+ influence_radius=radius * 0.7,
541
+ )
542
+ )
543
  # --- Lateral jaw: bilateral symmetric inward pull for narrowing ---
544
  jaw_left = [57, 61, 78, 91, 95, 146, 181]
545
  for idx in jaw_left:
546
  if idx in indices:
547
+ handles.append(
548
+ DeformationHandle(
549
+ landmark_index=idx,
550
+ displacement=np.array([1.5 * scale, -1.0 * scale]),
551
+ influence_radius=radius * 0.8,
552
+ )
553
+ )
554
  jaw_right = [291, 311, 312, 321, 324, 325, 375, 405]
555
  for idx in jaw_right:
556
  if idx in indices:
557
+ handles.append(
558
+ DeformationHandle(
559
+ landmark_index=idx,
560
+ displacement=np.array([-1.5 * scale, -1.0 * scale]),
561
+ influence_radius=radius * 0.8,
562
+ )
563
+ )
564
 
565
  elif procedure == "brow_lift":
566
  # --- Brow elevation ---
 
571
  left_weights = [0.7, 0.8, 0.9, 1.0, 1.1]
572
  for i, idx in enumerate(brow_left):
573
  if idx in indices:
574
+ handles.append(
575
+ DeformationHandle(
576
+ landmark_index=idx,
577
+ displacement=np.array([0.0, -4.0 * left_weights[i] * scale]),
578
+ influence_radius=radius,
579
+ )
580
+ )
581
 
582
  right_weights = [0.7, 0.8, 0.9, 1.0, 1.1]
583
  for i, idx in enumerate(brow_right):
584
  if idx in indices:
585
+ handles.append(
586
+ DeformationHandle(
587
+ landmark_index=idx,
588
+ displacement=np.array([0.0, -4.0 * right_weights[i] * scale]),
589
+ influence_radius=radius,
590
+ )
591
+ )
592
 
593
  # --- Forehead smoothing / subtle lift ---
594
  forehead = [9, 8, 10, 109, 67, 103, 338, 297, 332]
595
  for idx in forehead:
596
  if idx in indices:
597
+ handles.append(
598
+ DeformationHandle(
599
+ landmark_index=idx,
600
+ displacement=np.array([0.0, -1.5 * scale]),
601
+ influence_radius=radius * 1.2,
602
+ )
603
+ )
604
  elif procedure == "mentoplasty":
605
  # --- Chin tip advancement: move chin forward (upward in 2D) ---
606
  chin_tip = [152, 175]
607
  for idx in chin_tip:
608
  if idx in indices:
609
+ handles.append(
610
+ DeformationHandle(
611
+ landmark_index=idx,
612
+ displacement=np.array([0.0, -4.0 * scale]),
613
+ influence_radius=radius,
614
+ )
615
+ )
616
  # --- Lower chin contour: follow tip with softer displacement ---
617
  lower_contour = [148, 149, 150, 176, 377]
618
  for idx in lower_contour:
619
  if idx in indices:
620
+ handles.append(
621
+ DeformationHandle(
622
+ landmark_index=idx,
623
+ displacement=np.array([0.0, -2.5 * scale]),
624
+ influence_radius=radius * 0.8,
625
+ )
626
+ )
627
  # --- Jaw angles: minimal upward pull for natural transition ---
628
  jaw_angles = [171, 396]
629
  for idx in jaw_angles:
630
  if idx in indices:
631
+ handles.append(
632
+ DeformationHandle(
633
+ landmark_index=idx,
634
+ displacement=np.array([0.0, -1.0 * scale]),
635
+ influence_radius=radius * 0.6,
636
+ )
637
+ )
638
  return handles
landmarkdiff/metrics_agg.py CHANGED
@@ -41,8 +41,12 @@ class MetricsAggregator:
41
  """
42
 
43
  HIGHER_BETTER = {
44
- "ssim": True, "psnr": True, "identity_sim": True,
45
- "lpips": False, "fid": False, "nme": False,
 
 
 
 
46
  }
47
 
48
  def __init__(self) -> None:
@@ -57,13 +61,15 @@ class MetricsAggregator:
57
  **metadata: Any,
58
  ) -> None:
59
  """Add a single evaluation record."""
60
- self.records.append(MetricRecord(
61
- experiment=experiment,
62
- procedure=procedure,
63
- metrics=metrics,
64
- checkpoint_step=checkpoint_step,
65
- metadata=metadata,
66
- ))
 
 
67
 
68
  def add_batch(
69
  self,
@@ -76,7 +82,9 @@ class MetricsAggregator:
76
  """
77
  for rec in records:
78
  proc = rec.get("procedure", "all")
79
- metrics = {k: v for k, v in rec.items() if k != "procedure" and isinstance(v, (int, float))}
 
 
80
  self.add(experiment, proc, metrics)
81
 
82
  @property
@@ -211,10 +219,7 @@ class MetricsAggregator:
211
  val = self.mean(exp, metric, procedure)
212
  if math.isnan(val):
213
  continue
214
- if higher_better and val > best_val:
215
- best_val = val
216
- best_exp = exp
217
- elif not higher_better and val < best_val:
218
  best_val = val
219
  best_exp = exp
220
 
 
41
  """
42
 
43
  HIGHER_BETTER = {
44
+ "ssim": True,
45
+ "psnr": True,
46
+ "identity_sim": True,
47
+ "lpips": False,
48
+ "fid": False,
49
+ "nme": False,
50
  }
51
 
52
  def __init__(self) -> None:
 
61
  **metadata: Any,
62
  ) -> None:
63
  """Add a single evaluation record."""
64
+ self.records.append(
65
+ MetricRecord(
66
+ experiment=experiment,
67
+ procedure=procedure,
68
+ metrics=metrics,
69
+ checkpoint_step=checkpoint_step,
70
+ metadata=metadata,
71
+ )
72
+ )
73
 
74
  def add_batch(
75
  self,
 
82
  """
83
  for rec in records:
84
  proc = rec.get("procedure", "all")
85
+ metrics = {
86
+ k: v for k, v in rec.items() if k != "procedure" and isinstance(v, (int, float))
87
+ }
88
  self.add(experiment, proc, metrics)
89
 
90
  @property
 
219
  val = self.mean(exp, metric, procedure)
220
  if math.isnan(val):
221
  continue
222
+ if (higher_better and val > best_val) or (not higher_better and val < best_val):
 
 
 
223
  best_val = val
224
  best_exp = exp
225
 
landmarkdiff/metrics_viz.py CHANGED
@@ -24,7 +24,6 @@ Usage:
24
 
25
  from __future__ import annotations
26
 
27
- import json
28
  from pathlib import Path
29
  from typing import Any
30
 
@@ -79,25 +78,29 @@ class MetricsVisualizer:
79
  self.dpi = dpi
80
  self.style = style
81
 
82
- def _get_plt(self):
83
  """Import matplotlib with configuration."""
84
  import matplotlib
 
85
  matplotlib.use("Agg")
86
  import matplotlib.pyplot as plt
 
87
  try:
88
  plt.style.use(self.style)
89
  except OSError:
90
  plt.style.use("seaborn-v0_8")
91
  # Publication font sizes
92
- plt.rcParams.update({
93
- "font.size": 10,
94
- "axes.titlesize": 12,
95
- "axes.labelsize": 11,
96
- "xtick.labelsize": 9,
97
- "ytick.labelsize": 9,
98
- "legend.fontsize": 9,
99
- "figure.titlesize": 13,
100
- })
 
 
101
  return plt
102
 
103
  # ------------------------------------------------------------------
@@ -138,7 +141,7 @@ class MetricsVisualizer:
138
  if n_metrics == 1:
139
  axes = [axes]
140
 
141
- for ax, metric in zip(axes, metrics):
142
  values = [metrics_by_procedure[p].get(metric, 0) for p in procedures]
143
  colors = [self.COLORS.get(p, "#999999") for p in procedures]
144
 
@@ -146,16 +149,21 @@ class MetricsVisualizer:
146
  ax.set_xticks(range(n_procs))
147
  ax.set_xticklabels(
148
  [p[:5].title() for p in procedures],
149
- rotation=30, ha="right",
 
150
  )
151
  ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
152
  ax.set_title(self.METRIC_LABELS.get(metric, metric))
153
 
154
  # Add value labels on bars
155
- for bar, val in zip(bars, values):
156
  ax.text(
157
- bar.get_x() + bar.get_width() / 2, bar.get_height(),
158
- f"{val:.3f}", ha="center", va="bottom", fontsize=8,
 
 
 
 
159
  )
160
 
161
  fig.suptitle(title, fontweight="bold")
@@ -192,9 +200,8 @@ class MetricsVisualizer:
192
 
193
  if metrics is None:
194
  metrics = sorted(
195
- set.intersection(
196
- *(set(v.keys()) for v in experiments.values())
197
- ) & set(self.METRIC_LABELS.keys())
198
  )
199
 
200
  n_metrics = len(metrics)
@@ -258,9 +265,7 @@ class MetricsVisualizer:
258
  plt = self._get_plt()
259
 
260
  fitz_types = sorted(metrics_by_type.keys())
261
- procedures = sorted(
262
- set.union(*(set(v.keys()) for v in metrics_by_type.values()))
263
- )
264
 
265
  # Build matrix
266
  matrix = np.zeros((len(fitz_types), len(procedures)))
@@ -268,7 +273,9 @@ class MetricsVisualizer:
268
  for j, proc in enumerate(procedures):
269
  matrix[i, j] = metrics_by_type[ft].get(proc, 0)
270
 
271
- fig, ax = plt.subplots(figsize=(max(6, len(procedures) * 1.5), max(4, len(fitz_types) * 0.8)))
 
 
272
 
273
  cmap = "RdYlGn" if self.METRIC_HIGHER_BETTER.get(metric, True) else "RdYlGn_r"
274
  im = ax.imshow(matrix, cmap=cmap, aspect="auto")
@@ -282,9 +289,15 @@ class MetricsVisualizer:
282
  # Annotate cells
283
  for i in range(len(fitz_types)):
284
  for j in range(len(procedures)):
285
- ax.text(j, i, f"{matrix[i, j]:.3f}",
286
- ha="center", va="center", fontsize=9,
287
- color="white" if matrix[i, j] < np.median(matrix) else "black")
 
 
 
 
 
 
288
 
289
  fig.colorbar(im, ax=ax, label=self.METRIC_LABELS.get(metric, metric))
290
 
@@ -328,18 +341,21 @@ class MetricsVisualizer:
328
  fig, ax = plt.subplots(figsize=(max(6, len(groups) * 1.2), 5))
329
 
330
  bp = ax.boxplot(
331
- data, patch_artist=True, widths=0.6,
 
 
332
  medianprops={"color": "black", "linewidth": 1.5},
333
  )
334
 
335
  colors = [self.COLORS.get(g, "#4C72B0") for g in groups]
336
- for patch, color in zip(bp["boxes"], colors):
337
  patch.set_facecolor(color)
338
  patch.set_alpha(0.7)
339
 
340
  ax.set_xticklabels(
341
  [g.title() for g in groups],
342
- rotation=30, ha="right",
 
343
  )
344
  ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
345
 
@@ -348,9 +364,16 @@ class MetricsVisualizer:
348
  ax.set_title(title, fontweight="bold")
349
 
350
  # Add sample count annotations
351
- for i, (g, vals) in enumerate(zip(groups, data)):
352
- ax.text(i + 1, ax.get_ylim()[0], f"n={len(vals)}",
353
- ha="center", va="bottom", fontsize=8, color="gray")
 
 
 
 
 
 
 
354
 
355
  fig.tight_layout()
356
  out_path = self.output_dir / filename
@@ -430,10 +453,12 @@ class MetricsVisualizer:
430
  parts.append(val_str)
431
  lines.append(" & ".join(parts) + " \\\\")
432
 
433
- lines.extend([
434
- "\\bottomrule",
435
- "\\end{tabular}",
436
- "\\end{table}",
437
- ])
 
 
438
 
439
  return "\n".join(lines)
 
24
 
25
  from __future__ import annotations
26
 
 
27
  from pathlib import Path
28
  from typing import Any
29
 
 
78
  self.dpi = dpi
79
  self.style = style
80
 
81
+ def _get_plt(self) -> Any:
82
  """Import matplotlib with configuration."""
83
  import matplotlib
84
+
85
  matplotlib.use("Agg")
86
  import matplotlib.pyplot as plt
87
+
88
  try:
89
  plt.style.use(self.style)
90
  except OSError:
91
  plt.style.use("seaborn-v0_8")
92
  # Publication font sizes
93
+ plt.rcParams.update(
94
+ {
95
+ "font.size": 10,
96
+ "axes.titlesize": 12,
97
+ "axes.labelsize": 11,
98
+ "xtick.labelsize": 9,
99
+ "ytick.labelsize": 9,
100
+ "legend.fontsize": 9,
101
+ "figure.titlesize": 13,
102
+ }
103
+ )
104
  return plt
105
 
106
  # ------------------------------------------------------------------
 
141
  if n_metrics == 1:
142
  axes = [axes]
143
 
144
+ for ax, metric in zip(axes, metrics, strict=False):
145
  values = [metrics_by_procedure[p].get(metric, 0) for p in procedures]
146
  colors = [self.COLORS.get(p, "#999999") for p in procedures]
147
 
 
149
  ax.set_xticks(range(n_procs))
150
  ax.set_xticklabels(
151
  [p[:5].title() for p in procedures],
152
+ rotation=30,
153
+ ha="right",
154
  )
155
  ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
156
  ax.set_title(self.METRIC_LABELS.get(metric, metric))
157
 
158
  # Add value labels on bars
159
+ for bar, val in zip(bars, values, strict=False):
160
  ax.text(
161
+ bar.get_x() + bar.get_width() / 2,
162
+ bar.get_height(),
163
+ f"{val:.3f}",
164
+ ha="center",
165
+ va="bottom",
166
+ fontsize=8,
167
  )
168
 
169
  fig.suptitle(title, fontweight="bold")
 
200
 
201
  if metrics is None:
202
  metrics = sorted(
203
+ set.intersection(*(set(v.keys()) for v in experiments.values()))
204
+ & set(self.METRIC_LABELS.keys())
 
205
  )
206
 
207
  n_metrics = len(metrics)
 
265
  plt = self._get_plt()
266
 
267
  fitz_types = sorted(metrics_by_type.keys())
268
+ procedures = sorted(set.union(*(set(v.keys()) for v in metrics_by_type.values())))
 
 
269
 
270
  # Build matrix
271
  matrix = np.zeros((len(fitz_types), len(procedures)))
 
273
  for j, proc in enumerate(procedures):
274
  matrix[i, j] = metrics_by_type[ft].get(proc, 0)
275
 
276
+ fig, ax = plt.subplots(
277
+ figsize=(max(6, len(procedures) * 1.5), max(4, len(fitz_types) * 0.8))
278
+ )
279
 
280
  cmap = "RdYlGn" if self.METRIC_HIGHER_BETTER.get(metric, True) else "RdYlGn_r"
281
  im = ax.imshow(matrix, cmap=cmap, aspect="auto")
 
289
  # Annotate cells
290
  for i in range(len(fitz_types)):
291
  for j in range(len(procedures)):
292
+ ax.text(
293
+ j,
294
+ i,
295
+ f"{matrix[i, j]:.3f}",
296
+ ha="center",
297
+ va="center",
298
+ fontsize=9,
299
+ color="white" if matrix[i, j] < np.median(matrix) else "black",
300
+ )
301
 
302
  fig.colorbar(im, ax=ax, label=self.METRIC_LABELS.get(metric, metric))
303
 
 
341
  fig, ax = plt.subplots(figsize=(max(6, len(groups) * 1.2), 5))
342
 
343
  bp = ax.boxplot(
344
+ data,
345
+ patch_artist=True,
346
+ widths=0.6,
347
  medianprops={"color": "black", "linewidth": 1.5},
348
  )
349
 
350
  colors = [self.COLORS.get(g, "#4C72B0") for g in groups]
351
+ for patch, color in zip(bp["boxes"], colors, strict=False):
352
  patch.set_facecolor(color)
353
  patch.set_alpha(0.7)
354
 
355
  ax.set_xticklabels(
356
  [g.title() for g in groups],
357
+ rotation=30,
358
+ ha="right",
359
  )
360
  ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
361
 
 
364
  ax.set_title(title, fontweight="bold")
365
 
366
  # Add sample count annotations
367
+ for i, (_g, vals) in enumerate(zip(groups, data, strict=False)):
368
+ ax.text(
369
+ i + 1,
370
+ ax.get_ylim()[0],
371
+ f"n={len(vals)}",
372
+ ha="center",
373
+ va="bottom",
374
+ fontsize=8,
375
+ color="gray",
376
+ )
377
 
378
  fig.tight_layout()
379
  out_path = self.output_dir / filename
 
453
  parts.append(val_str)
454
  lines.append(" & ".join(parts) + " \\\\")
455
 
456
+ lines.extend(
457
+ [
458
+ "\\bottomrule",
459
+ "\\end{tabular}",
460
+ "\\end{table}",
461
+ ]
462
+ )
463
 
464
  return "\n".join(lines)
landmarkdiff/model_registry.py CHANGED
@@ -139,9 +139,7 @@ class ModelRegistry:
139
  step = int(parts[-1])
140
 
141
  # Compute size
142
- size_mb = sum(
143
- f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()
144
- ) / (1024 * 1024)
145
 
146
  return ModelEntry(
147
  name=ckpt_dir.name,
@@ -195,16 +193,15 @@ class ModelRegistry:
195
  Returns:
196
  Best ModelEntry, or None if no models have the metric.
197
  """
198
- candidates = [
199
- m for m in self._models.values()
200
- if metric in m.metrics
201
- ]
202
  if not candidates:
203
  return None
204
 
205
- return min(candidates, key=lambda m: m.metrics[metric]) \
206
- if lower_is_better else \
207
- max(candidates, key=lambda m: m.metrics[metric])
 
 
208
 
209
  def get_by_step(self, step: int) -> ModelEntry | None:
210
  """Get a model by its training step."""
@@ -266,9 +263,7 @@ class ModelRegistry:
266
  raise KeyError(f"Checkpoint '{name}' not found in registry")
267
 
268
  if use_ema and entry.has_ema:
269
- return ControlNetModel.from_pretrained(
270
- str(entry.path / "controlnet_ema")
271
- )
272
 
273
  # Fallback: load from training state
274
  state = self.load(name)
@@ -356,9 +351,7 @@ class ModelRegistry:
356
  for metric in sorted(all_metrics):
357
  values = [m.metrics[metric] for m in models if metric in m.metrics]
358
  if values:
359
- lines.append(
360
- f" {metric}: {min(values):.4f} — {max(values):.4f}"
361
- )
362
 
363
  return "\n".join(lines)
364
 
 
139
  step = int(parts[-1])
140
 
141
  # Compute size
142
+ size_mb = sum(f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()) / (1024 * 1024)
 
 
143
 
144
  return ModelEntry(
145
  name=ckpt_dir.name,
 
193
  Returns:
194
  Best ModelEntry, or None if no models have the metric.
195
  """
196
+ candidates = [m for m in self._models.values() if metric in m.metrics]
 
 
 
197
  if not candidates:
198
  return None
199
 
200
+ return (
201
+ min(candidates, key=lambda m: m.metrics[metric])
202
+ if lower_is_better
203
+ else max(candidates, key=lambda m: m.metrics[metric])
204
+ )
205
 
206
  def get_by_step(self, step: int) -> ModelEntry | None:
207
  """Get a model by its training step."""
 
263
  raise KeyError(f"Checkpoint '{name}' not found in registry")
264
 
265
  if use_ema and entry.has_ema:
266
+ return ControlNetModel.from_pretrained(str(entry.path / "controlnet_ema"))
 
 
267
 
268
  # Fallback: load from training state
269
  state = self.load(name)
 
351
  for metric in sorted(all_metrics):
352
  values = [m.metrics[metric] for m in models if metric in m.metrics]
353
  if values:
354
+ lines.append(f" {metric}: {min(values):.4f} — {max(values):.4f}")
 
 
355
 
356
  return "\n".join(lines)
357
 
landmarkdiff/postprocess.py CHANGED
@@ -54,13 +54,10 @@ def laplacian_pyramid_blend(
54
  mask_f = mask.astype(np.float32)
55
  if mask_f.max() > 1.0:
56
  mask_f = mask_f / 255.0
57
- if mask_f.ndim == 2:
58
- mask_3ch = np.stack([mask_f] * 3, axis=-1)
59
- else:
60
- mask_3ch = mask_f
61
 
62
  # Make dimensions divisible by 2^levels
63
- factor = 2 ** levels
64
  new_h = (h + factor - 1) // factor * factor
65
  new_w = (w + factor - 1) // factor * factor
66
 
@@ -232,24 +229,27 @@ def restore_face_codeformer(
232
  Restored BGR image, or original if CodeFormer unavailable.
233
  """
234
  try:
 
235
  from codeformer.basicsr.utils import img2tensor, tensor2img
236
- from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
237
  from codeformer.basicsr.utils.download_util import load_file_from_url
238
- import torch
239
  from torchvision.transforms.functional import normalize as tv_normalize
240
  except ImportError:
241
  return image
242
 
243
  try:
244
  global _CODEFORMER_MODEL, _CODEFORMER_HELPER
245
- from codeformer.inference_codeformer import set_realesrgan as _unused # noqa: F401
246
  from codeformer.basicsr.archs.codeformer_arch import CodeFormer as CodeFormerArch
 
247
 
248
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
249
 
250
  if _CODEFORMER_MODEL is None:
251
  model = CodeFormerArch(
252
- dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
 
 
 
253
  connect_list=["32", "64", "128", "256"],
254
  ).to(device)
255
 
@@ -316,16 +316,18 @@ def enhance_background_realesrgan(
316
  Enhanced BGR image at original resolution.
317
  """
318
  try:
319
- from realesrgan import RealESRGANer
320
- from basicsr.archs.rrdbnet_arch import RRDBNet
321
  import torch
 
 
322
  except ImportError:
323
  return image
324
 
325
  try:
326
  global _REALESRGAN_UPSAMPLER
327
  if _REALESRGAN_UPSAMPLER is None:
328
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
 
 
329
  _REALESRGAN_UPSAMPLER = RealESRGANer(
330
  scale=4,
331
  model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
@@ -345,15 +347,11 @@ def enhance_background_realesrgan(
345
  mask_f = mask.astype(np.float32)
346
  if mask_f.max() > 1.0:
347
  mask_f /= 255.0
348
- if mask_f.ndim == 2:
349
- mask_3ch = np.stack([mask_f] * 3, axis=-1)
350
- else:
351
- mask_3ch = mask_f
352
 
353
  # Keep face region from original, use enhanced for background
354
  result = (
355
- image.astype(np.float32) * mask_3ch
356
- + enhanced.astype(np.float32) * (1.0 - mask_3ch)
357
  ).astype(np.uint8)
358
  return result
359
  except Exception:
@@ -414,9 +412,10 @@ def verify_identity_arcface(
414
  orig_emb = orig_faces[0].embedding
415
  result_emb = result_faces[0].embedding
416
 
417
- sim = float(np.dot(orig_emb, result_emb) / (
418
- np.linalg.norm(orig_emb) * np.linalg.norm(result_emb) + 1e-8
419
- ))
 
420
  sim = float(np.clip(sim, 0, 1))
421
 
422
  passed = sim >= threshold
@@ -437,6 +436,7 @@ def verify_identity_arcface(
437
  def _has_cuda() -> bool:
438
  try:
439
  import torch
 
440
  return torch.cuda.is_available()
441
  except ImportError:
442
  return False
@@ -465,7 +465,7 @@ def histogram_match_skin(
465
  if not np.any(mask_bool):
466
  return source
467
 
468
- result = source.copy()
469
  src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
470
  ref_lab = cv2.cvtColor(reference, cv2.COLOR_BGR2LAB).astype(np.float32)
471
 
@@ -574,20 +574,18 @@ def full_postprocess(
574
  mask_f = mask.astype(np.float32)
575
  if mask_f.max() > 1.0:
576
  mask_f /= 255.0
577
- if mask_f.ndim == 2:
578
- mask_3ch = np.stack([mask_f] * 3, axis=-1)
579
- else:
580
- mask_3ch = mask_f
581
  composited = (
582
- result.astype(np.float32) * mask_3ch
583
- + original.astype(np.float32) * (1.0 - mask_3ch)
584
  ).astype(np.uint8)
585
 
586
  # Step 6: Neural identity verification
587
  identity_check = {"similarity": -1.0, "passed": True, "message": "skipped"}
588
  if verify_identity:
589
  identity_check = verify_identity_arcface(
590
- original, composited, threshold=identity_threshold,
 
 
591
  )
592
 
593
  return {
 
54
  mask_f = mask.astype(np.float32)
55
  if mask_f.max() > 1.0:
56
  mask_f = mask_f / 255.0
57
+ mask_3ch = np.stack([mask_f] * 3, axis=-1) if mask_f.ndim == 2 else mask_f
 
 
 
58
 
59
  # Make dimensions divisible by 2^levels
60
+ factor = 2**levels
61
  new_h = (h + factor - 1) // factor * factor
62
  new_w = (w + factor - 1) // factor * factor
63
 
 
229
  Restored BGR image, or original if CodeFormer unavailable.
230
  """
231
  try:
232
+ import torch
233
  from codeformer.basicsr.utils import img2tensor, tensor2img
 
234
  from codeformer.basicsr.utils.download_util import load_file_from_url
235
+ from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
236
  from torchvision.transforms.functional import normalize as tv_normalize
237
  except ImportError:
238
  return image
239
 
240
  try:
241
  global _CODEFORMER_MODEL, _CODEFORMER_HELPER
 
242
  from codeformer.basicsr.archs.codeformer_arch import CodeFormer as CodeFormerArch
243
+ from codeformer.inference_codeformer import set_realesrgan as _unused # noqa: F401
244
 
245
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
246
 
247
  if _CODEFORMER_MODEL is None:
248
  model = CodeFormerArch(
249
+ dim_embd=512,
250
+ codebook_size=1024,
251
+ n_head=8,
252
+ n_layers=9,
253
  connect_list=["32", "64", "128", "256"],
254
  ).to(device)
255
 
 
316
  Enhanced BGR image at original resolution.
317
  """
318
  try:
 
 
319
  import torch
320
+ from basicsr.archs.rrdbnet_arch import RRDBNet
321
+ from realesrgan import RealESRGANer
322
  except ImportError:
323
  return image
324
 
325
  try:
326
  global _REALESRGAN_UPSAMPLER
327
  if _REALESRGAN_UPSAMPLER is None:
328
+ model = RRDBNet(
329
+ num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
330
+ )
331
  _REALESRGAN_UPSAMPLER = RealESRGANer(
332
  scale=4,
333
  model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
 
347
  mask_f = mask.astype(np.float32)
348
  if mask_f.max() > 1.0:
349
  mask_f /= 255.0
350
+ mask_3ch = np.stack([mask_f] * 3, axis=-1) if mask_f.ndim == 2 else mask_f
 
 
 
351
 
352
  # Keep face region from original, use enhanced for background
353
  result = (
354
+ image.astype(np.float32) * mask_3ch + enhanced.astype(np.float32) * (1.0 - mask_3ch)
 
355
  ).astype(np.uint8)
356
  return result
357
  except Exception:
 
412
  orig_emb = orig_faces[0].embedding
413
  result_emb = result_faces[0].embedding
414
 
415
+ sim = float(
416
+ np.dot(orig_emb, result_emb)
417
+ / (np.linalg.norm(orig_emb) * np.linalg.norm(result_emb) + 1e-8)
418
+ )
419
  sim = float(np.clip(sim, 0, 1))
420
 
421
  passed = sim >= threshold
 
436
  def _has_cuda() -> bool:
437
  try:
438
  import torch
439
+
440
  return torch.cuda.is_available()
441
  except ImportError:
442
  return False
 
465
  if not np.any(mask_bool):
466
  return source
467
 
468
+ source.copy()
469
  src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
470
  ref_lab = cv2.cvtColor(reference, cv2.COLOR_BGR2LAB).astype(np.float32)
471
 
 
574
  mask_f = mask.astype(np.float32)
575
  if mask_f.max() > 1.0:
576
  mask_f /= 255.0
577
+ mask_3ch = np.stack([mask_f] * 3, axis=-1) if mask_f.ndim == 2 else mask_f
 
 
 
578
  composited = (
579
+ result.astype(np.float32) * mask_3ch + original.astype(np.float32) * (1.0 - mask_3ch)
 
580
  ).astype(np.uint8)
581
 
582
  # Step 6: Neural identity verification
583
  identity_check = {"similarity": -1.0, "passed": True, "message": "skipped"}
584
  if verify_identity:
585
  identity_check = verify_identity_arcface(
586
+ original,
587
+ composited,
588
+ threshold=identity_threshold,
589
  )
590
 
591
  return {
landmarkdiff/py.typed ADDED
File without changes
landmarkdiff/safety.py CHANGED
@@ -26,7 +26,6 @@ Usage:
26
  from __future__ import annotations
27
 
28
  from dataclasses import dataclass, field
29
- from typing import Optional
30
 
31
  import cv2
32
  import numpy as np
@@ -35,6 +34,7 @@ import numpy as np
35
  @dataclass
36
  class SafetyResult:
37
  """Result of safety validation checks."""
 
38
  passed: bool = True
39
  failures: list[str] = field(default_factory=list)
40
  warnings: list[str] = field(default_factory=list)
@@ -124,9 +124,7 @@ class SafetyValidator:
124
 
125
  return result
126
 
127
- def _check_face_confidence(
128
- self, result: SafetyResult, confidence: float
129
- ) -> None:
130
  """Check face detection confidence."""
131
  if confidence < self.min_face_confidence:
132
  result.add_failure(
@@ -147,14 +145,14 @@ class SafetyValidator:
147
  """Check identity preservation using ArcFace similarity."""
148
  try:
149
  from landmarkdiff.evaluation import compute_identity_similarity
 
150
  sim = compute_identity_similarity(output_image, input_image)
151
  result.details["identity_similarity"] = float(sim)
152
 
153
  if sim < self.identity_threshold:
154
  result.add_failure(
155
  "identity",
156
- f"Identity similarity {sim:.3f} below threshold "
157
- f"{self.identity_threshold}",
158
  )
159
  else:
160
  result.add_pass("identity")
@@ -257,9 +255,7 @@ class SafetyValidator:
257
  else:
258
  result.add_pass("procedure_region")
259
 
260
- def _check_output_quality(
261
- self, result: SafetyResult, output: np.ndarray
262
- ) -> None:
263
  """Check output image quality (not blank, not corrupted)."""
264
  if output is None or output.size == 0:
265
  result.add_failure("output_quality", "Output image is empty")
@@ -346,8 +342,7 @@ class SafetyValidator:
346
  cv2.addWeighted(overlay, opacity, result, 1 - opacity, 0, result)
347
 
348
  # White text
349
- cv2.putText(result, text, (x, y), font, font_scale,
350
- (255, 255, 255), thickness, cv2.LINE_AA)
351
 
352
  return result
353
 
@@ -371,7 +366,7 @@ class SafetyValidator:
371
  "procedure": procedure,
372
  "intensity": intensity,
373
  "disclaimer": "AI-generated surgical prediction for visualization only. "
374
- "Not a guarantee of surgical outcome.",
375
  }
376
 
377
  # Save as sidecar JSON (PNG doesn't have easy EXIF support)
 
26
  from __future__ import annotations
27
 
28
  from dataclasses import dataclass, field
 
29
 
30
  import cv2
31
  import numpy as np
 
34
  @dataclass
35
  class SafetyResult:
36
  """Result of safety validation checks."""
37
+
38
  passed: bool = True
39
  failures: list[str] = field(default_factory=list)
40
  warnings: list[str] = field(default_factory=list)
 
124
 
125
  return result
126
 
127
+ def _check_face_confidence(self, result: SafetyResult, confidence: float) -> None:
 
 
128
  """Check face detection confidence."""
129
  if confidence < self.min_face_confidence:
130
  result.add_failure(
 
145
  """Check identity preservation using ArcFace similarity."""
146
  try:
147
  from landmarkdiff.evaluation import compute_identity_similarity
148
+
149
  sim = compute_identity_similarity(output_image, input_image)
150
  result.details["identity_similarity"] = float(sim)
151
 
152
  if sim < self.identity_threshold:
153
  result.add_failure(
154
  "identity",
155
+ f"Identity similarity {sim:.3f} below threshold {self.identity_threshold}",
 
156
  )
157
  else:
158
  result.add_pass("identity")
 
255
  else:
256
  result.add_pass("procedure_region")
257
 
258
+ def _check_output_quality(self, result: SafetyResult, output: np.ndarray) -> None:
 
 
259
  """Check output image quality (not blank, not corrupted)."""
260
  if output is None or output.size == 0:
261
  result.add_failure("output_quality", "Output image is empty")
 
342
  cv2.addWeighted(overlay, opacity, result, 1 - opacity, 0, result)
343
 
344
  # White text
345
+ cv2.putText(result, text, (x, y), font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
 
346
 
347
  return result
348
 
 
366
  "procedure": procedure,
367
  "intensity": intensity,
368
  "disclaimer": "AI-generated surgical prediction for visualization only. "
369
+ "Not a guarantee of surgical outcome.",
370
  }
371
 
372
  # Save as sidecar JSON (PNG doesn't have easy EXIF support)
landmarkdiff/synthetic/__init__.py CHANGED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Synthetic data generation for ControlNet fine-tuning.
2
+
3
+ Modules:
4
+ - pair_generator: Generate training pairs from face images
5
+ - augmentation: Clinical degradation augmentations
6
+ - tps_warp: TPS warping with rigid region preservation
7
+ """
8
+
9
+ from landmarkdiff.synthetic.augmentation import apply_clinical_augmentation
10
+ from landmarkdiff.synthetic.pair_generator import (
11
+ TrainingPair,
12
+ generate_pair,
13
+ generate_pairs_from_directory,
14
+ )
15
+ from landmarkdiff.synthetic.tps_warp import warp_image_tps
16
+
17
+ __all__ = [
18
+ "TrainingPair",
19
+ "apply_clinical_augmentation",
20
+ "generate_pair",
21
+ "generate_pairs_from_directory",
22
+ "warp_image_tps",
23
+ ]
landmarkdiff/synthetic/augmentation.py CHANGED
@@ -7,8 +7,8 @@ Applied from day 1 - domain gap prevention, not afterthought.
7
 
8
  from __future__ import annotations
9
 
 
10
  from dataclasses import dataclass
11
- from typing import Callable
12
 
13
  import cv2
14
  import numpy as np
@@ -35,7 +35,7 @@ def point_source_lighting(image: np.ndarray, rng: np.random.Generator) -> np.nda
35
  # Distance-based falloff
36
  y_grid, x_grid = np.mgrid[0:h, 0:w].astype(np.float32)
37
  dist = np.sqrt((x_grid - lx) ** 2 + (y_grid - ly) ** 2)
38
- max_dist = np.sqrt(w ** 2 + h ** 2)
39
  light_map = 1.0 - (dist / max_dist) * intensity
40
 
41
  light_map = np.clip(light_map, 0.3, 1.0)
@@ -132,7 +132,7 @@ def vignette(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
132
  y, x = np.mgrid[0:h, 0:w].astype(np.float32)
133
  cx, cy = w / 2, h / 2
134
  dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
135
- max_dist = np.sqrt(cx ** 2 + cy ** 2)
136
 
137
  mask = 1 - strength * (dist / max_dist) ** 2
138
  mask = np.clip(mask, 0.3, 1.0)
 
7
 
8
  from __future__ import annotations
9
 
10
+ from collections.abc import Callable
11
  from dataclasses import dataclass
 
12
 
13
  import cv2
14
  import numpy as np
 
35
  # Distance-based falloff
36
  y_grid, x_grid = np.mgrid[0:h, 0:w].astype(np.float32)
37
  dist = np.sqrt((x_grid - lx) ** 2 + (y_grid - ly) ** 2)
38
+ max_dist = np.sqrt(w**2 + h**2)
39
  light_map = 1.0 - (dist / max_dist) * intensity
40
 
41
  light_map = np.clip(light_map, 0.3, 1.0)
 
132
  y, x = np.mgrid[0:h, 0:w].astype(np.float32)
133
  cx, cy = w / 2, h / 2
134
  dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
135
+ max_dist = np.sqrt(cx**2 + cy**2)
136
 
137
  mask = 1 - strength * (dist / max_dist) ** 2
138
  mask = np.clip(mask, 0.3, 1.0)
landmarkdiff/synthetic/pair_generator.py CHANGED
@@ -6,33 +6,32 @@ Augmentations on INPUT only, never target.
6
 
7
  from __future__ import annotations
8
 
 
9
  from dataclasses import dataclass
10
  from pathlib import Path
11
- from typing import Iterator
12
 
13
  import cv2
14
  import numpy as np
15
 
16
- from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks, render_landmark_image
17
  from landmarkdiff.conditioning import generate_conditioning
 
18
  from landmarkdiff.manipulation import (
19
- PROCEDURE_LANDMARKS,
20
  apply_procedure_preset,
21
  )
22
  from landmarkdiff.masking import generate_surgical_mask
23
  from landmarkdiff.synthetic.augmentation import apply_clinical_augmentation
24
- from landmarkdiff.synthetic.tps_warp import warp_image_tps, generate_random_warp
25
 
26
 
27
  @dataclass(frozen=True)
28
  class TrainingPair:
29
  """A single training sample for ControlNet fine-tuning."""
30
 
31
- input_image: np.ndarray # augmented input (512x512 BGR)
32
- target_image: np.ndarray # clean target (512x512 BGR) - TPS-warped original
33
- conditioning: np.ndarray # landmark rendering (512x512 BGR)
34
- canny: np.ndarray # canny edge map (512x512 grayscale)
35
- mask: np.ndarray # feathered surgical mask (512x512 float32)
36
  procedure: str
37
  intensity: float
38
 
@@ -104,10 +103,7 @@ def generate_pairs_from_directory(
104
  image_dir = Path(image_dir)
105
 
106
  extensions = {".jpg", ".jpeg", ".png", ".webp"}
107
- image_files = sorted(
108
- f for f in image_dir.iterdir()
109
- if f.suffix.lower() in extensions
110
- )
111
 
112
  if not image_files:
113
  raise FileNotFoundError(f"No images found in {image_dir}")
 
6
 
7
  from __future__ import annotations
8
 
9
+ from collections.abc import Iterator
10
  from dataclasses import dataclass
11
  from pathlib import Path
 
12
 
13
  import cv2
14
  import numpy as np
15
 
 
16
  from landmarkdiff.conditioning import generate_conditioning
17
+ from landmarkdiff.landmarks import extract_landmarks, render_landmark_image
18
  from landmarkdiff.manipulation import (
 
19
  apply_procedure_preset,
20
  )
21
  from landmarkdiff.masking import generate_surgical_mask
22
  from landmarkdiff.synthetic.augmentation import apply_clinical_augmentation
23
+ from landmarkdiff.synthetic.tps_warp import warp_image_tps
24
 
25
 
26
  @dataclass(frozen=True)
27
  class TrainingPair:
28
  """A single training sample for ControlNet fine-tuning."""
29
 
30
+ input_image: np.ndarray # augmented input (512x512 BGR)
31
+ target_image: np.ndarray # clean target (512x512 BGR) - TPS-warped original
32
+ conditioning: np.ndarray # landmark rendering (512x512 BGR)
33
+ canny: np.ndarray # canny edge map (512x512 grayscale)
34
+ mask: np.ndarray # feathered surgical mask (512x512 float32)
35
  procedure: str
36
  intensity: float
37
 
 
103
  image_dir = Path(image_dir)
104
 
105
  extensions = {".jpg", ".jpeg", ".png", ".webp"}
106
+ image_files = sorted(f for f in image_dir.iterdir() if f.suffix.lower() in extensions)
 
 
 
107
 
108
  if not image_files:
109
  raise FileNotFoundError(f"No images found in {image_dir}")
landmarkdiff/synthetic/tps_warp.py CHANGED
@@ -156,7 +156,7 @@ def _solve_tps_weights(
156
 
157
  # Build kernel matrix K (vectorized)
158
  diff = control_pts[:, np.newaxis, :] - control_pts[np.newaxis, :, :] # (n, n, 2)
159
- r_mat = np.sqrt((diff ** 2).sum(axis=2)) # (n, n)
160
  K = np.zeros((n, n))
161
  nz = r_mat > 0
162
  K[nz] = r_mat[nz] ** 2 * np.log(r_mat[nz])
@@ -205,7 +205,7 @@ def _evaluate_tps(
205
  # Compute all distances at once: (M, n)
206
  dx = batch[:, 0:1] - control_pts[:, 0] # (M, n) via broadcasting
207
  dy = batch[:, 1:2] - control_pts[:, 1] # (M, n)
208
- r = np.sqrt(dx ** 2 + dy ** 2)
209
 
210
  # TPS kernel: r^2 * log(r), with r=0 -> 0
211
  kernel = np.zeros_like(r)
@@ -230,9 +230,8 @@ def _compute_rigid_translation(
230
  inside = []
231
  for i, (x, y) in enumerate(src):
232
  ix, iy = int(x), int(y)
233
- if 0 <= ix < width and 0 <= iy < height:
234
- if mask[iy, ix] > 0:
235
- inside.append(i)
236
 
237
  if not inside:
238
  return np.array([0.0, 0.0])
 
156
 
157
  # Build kernel matrix K (vectorized)
158
  diff = control_pts[:, np.newaxis, :] - control_pts[np.newaxis, :, :] # (n, n, 2)
159
+ r_mat = np.sqrt((diff**2).sum(axis=2)) # (n, n)
160
  K = np.zeros((n, n))
161
  nz = r_mat > 0
162
  K[nz] = r_mat[nz] ** 2 * np.log(r_mat[nz])
 
205
  # Compute all distances at once: (M, n)
206
  dx = batch[:, 0:1] - control_pts[:, 0] # (M, n) via broadcasting
207
  dy = batch[:, 1:2] - control_pts[:, 1] # (M, n)
208
+ r = np.sqrt(dx**2 + dy**2)
209
 
210
  # TPS kernel: r^2 * log(r), with r=0 -> 0
211
  kernel = np.zeros_like(r)
 
230
  inside = []
231
  for i, (x, y) in enumerate(src):
232
  ix, iy = int(x), int(y)
233
+ if 0 <= ix < width and 0 <= iy < height and mask[iy, ix] > 0:
234
+ inside.append(i)
 
235
 
236
  if not inside:
237
  return np.array([0.0, 0.0])
landmarkdiff/validation.py CHANGED
@@ -14,13 +14,11 @@ import json
14
  import time
15
  from pathlib import Path
16
 
17
- import cv2
18
  import numpy as np
19
  import torch
20
- import torch.nn.functional as F
21
  from PIL import Image
22
 
23
- from landmarkdiff.evaluation import compute_ssim, compute_lpips, compute_nme
24
 
25
 
26
  class ValidationCallback:
@@ -116,13 +114,18 @@ class ValidationCallback:
116
 
117
  # ControlNet
118
  down_samples, mid_sample = controlnet(
119
- scaled, t, encoder_hidden_states=encoder_hidden_states,
120
- controlnet_cond=conditioning, return_dict=False,
 
 
 
121
  )
122
 
123
  # UNet with ControlNet residuals
124
  noise_pred = unet(
125
- scaled, t, encoder_hidden_states=encoder_hidden_states,
 
 
126
  down_block_additional_residuals=down_samples,
127
  mid_block_additional_residual=mid_sample,
128
  ).sample
@@ -136,7 +139,9 @@ class ValidationCallback:
136
  # Convert to numpy for metrics
137
  gen_np = (decoded[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
138
  tgt_np = (target[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
139
- cond_np = (conditioning[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
 
 
140
 
141
  # BGR for metrics (our metrics expect BGR)
142
  gen_bgr = gen_np[:, :, ::-1].copy()
@@ -177,7 +182,7 @@ class ValidationCallback:
177
  if generated_images:
178
  grid_rows = []
179
  for i in range(0, len(generated_images), 4):
180
- row_imgs = generated_images[i:i+4]
181
  while len(row_imgs) < 4:
182
  row_imgs.append(np.zeros_like(generated_images[0]))
183
  grid_rows.append(np.hstack(row_imgs))
@@ -202,6 +207,7 @@ class ValidationCallback:
202
 
203
  try:
204
  import matplotlib
 
205
  matplotlib.use("Agg")
206
  import matplotlib.pyplot as plt
207
  except ImportError:
 
14
  import time
15
  from pathlib import Path
16
 
 
17
  import numpy as np
18
  import torch
 
19
  from PIL import Image
20
 
21
+ from landmarkdiff.evaluation import compute_lpips, compute_ssim
22
 
23
 
24
  class ValidationCallback:
 
114
 
115
  # ControlNet
116
  down_samples, mid_sample = controlnet(
117
+ scaled,
118
+ t,
119
+ encoder_hidden_states=encoder_hidden_states,
120
+ controlnet_cond=conditioning,
121
+ return_dict=False,
122
  )
123
 
124
  # UNet with ControlNet residuals
125
  noise_pred = unet(
126
+ scaled,
127
+ t,
128
+ encoder_hidden_states=encoder_hidden_states,
129
  down_block_additional_residuals=down_samples,
130
  mid_block_additional_residual=mid_sample,
131
  ).sample
 
139
  # Convert to numpy for metrics
140
  gen_np = (decoded[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
141
  tgt_np = (target[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
142
+ cond_np = (conditioning[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(
143
+ np.uint8
144
+ )
145
 
146
  # BGR for metrics (our metrics expect BGR)
147
  gen_bgr = gen_np[:, :, ::-1].copy()
 
182
  if generated_images:
183
  grid_rows = []
184
  for i in range(0, len(generated_images), 4):
185
+ row_imgs = generated_images[i : i + 4]
186
  while len(row_imgs) < 4:
187
  row_imgs.append(np.zeros_like(generated_images[0]))
188
  grid_rows.append(np.hstack(row_imgs))
 
207
 
208
  try:
209
  import matplotlib
210
+
211
  matplotlib.use("Agg")
212
  import matplotlib.pyplot as plt
213
  except ImportError: