dreamlessx commited on
Commit
9f3acc0
·
verified ·
1 Parent(s): fedb187

Clean up demo: real faces, simplified UI, remove bloat

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/demo_face_1.png filter=lfs diff=lfs merge=lfs -text
37
+ examples/demo_face_2.png filter=lfs diff=lfs merge=lfs -text
38
+ examples/demo_face_3.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,14 +1,10 @@
1
- """LandmarkDiff Hugging Face Spaces Demo - TPS-only (CPU)."""
2
 
3
  from __future__ import annotations
4
 
5
- import json
6
  import logging
7
- import os
8
- import threading
9
  import time
10
  import traceback
11
- from datetime import datetime, timezone
12
  from pathlib import Path
13
 
14
  import cv2
@@ -23,137 +19,21 @@ from landmarkdiff.masking import generate_surgical_mask
23
  logging.basicConfig(level=logging.INFO)
24
  logger = logging.getLogger(__name__)
25
 
26
- VERSION = "v0.2.2"
27
-
28
  GITHUB_URL = "https://github.com/dreamlessx/LandmarkDiff-public"
29
- DOCS_URL = f"{GITHUB_URL}/tree/main/docs"
30
- WIKI_URL = f"{GITHUB_URL}/wiki"
31
- DISCUSSIONS_URL = f"{GITHUB_URL}/discussions"
32
-
33
- PROCEDURE_DESCRIPTIONS = {
34
- "rhinoplasty": "Nose reshaping -- adjusts nasal bridge, tip projection, and alar width",
35
- "blepharoplasty": "Eyelid surgery -- modifies upper/lower lid position and canthal tilt",
36
- "rhytidectomy": "Facelift -- tightens midface and jawline contours",
37
- "orthognathic": "Jaw surgery -- repositions maxilla and mandible for skeletal alignment",
38
- "brow_lift": "Brow lift -- elevates brow position and reduces forehead ptosis",
39
- "mentoplasty": "Chin surgery -- adjusts chin projection and vertical height",
40
- }
41
 
42
- # -- Detailed procedure info shown when user selects a procedure --
43
- PROCEDURE_DETAILS = {
44
- "rhinoplasty": (
45
- "**Rhinoplasty** (nose reshaping)\n\n"
46
- "Modifies the nasal bridge height, tip projection, tip rotation, and alar (nostril) "
47
- "width. The landmark displacement targets the nose dorsum, tip, columella, and alar "
48
- "base regions. At low intensity (10-30%) the effect is subtle refinement; at high "
49
- "intensity (70-100%) the reshaping is more dramatic.\n\n"
50
- "Affected landmarks: nasal bridge, tip, alar base, columella"
51
- ),
52
- "blepharoplasty": (
53
- "**Blepharoplasty** (eyelid surgery)\n\n"
54
- "Adjusts upper and lower eyelid position and canthal tilt. Targets the periorbital "
55
- "region including upper lid crease, lower lid margin, and lateral/medial canthi. "
56
- "Simulates both upper blepharoplasty (lid ptosis correction) and lower blepharoplasty "
57
- "(under-eye bag removal).\n\n"
58
- "Affected landmarks: upper/lower eyelid margins, canthi, periorbital region"
59
- ),
60
- "rhytidectomy": (
61
- "**Rhytidectomy** (facelift)\n\n"
62
- "Tightens the midface and jawline by displacing landmarks along vectors that simulate "
63
- "SMAS lift and skin redraping. Affects the cheek, jowl, and submental regions. The "
64
- "effect tightens nasolabial folds and redefines the jawline contour.\n\n"
65
- "Affected landmarks: cheek, jowl, jawline, submental region"
66
- ),
67
- "orthognathic": (
68
- "**Orthognathic surgery** (jaw repositioning)\n\n"
69
- "Simulates maxillary and mandibular osteotomy outcomes by repositioning the skeletal "
70
- "framework. Affects jaw position, chin projection, and overall facial proportion. "
71
- "Used for correcting class II/III malocclusion and facial asymmetry.\n\n"
72
- "Affected landmarks: maxilla, mandible, chin, lower face contour"
73
- ),
74
- "brow_lift": (
75
- "**Brow lift** (forehead rejuvenation)\n\n"
76
- "Elevates brow position and reduces forehead ptosis. Targets the eyebrow arch, "
77
- "lateral brow tail, and glabellar region. Simulates both endoscopic and coronal "
78
- "brow lift approaches. Higher intensities produce more visible brow elevation.\n\n"
79
- "Affected landmarks: brow arch, lateral brow, glabella, upper forehead"
80
- ),
81
- "mentoplasty": (
82
- "**Mentoplasty** (chin surgery)\n\n"
83
- "Adjusts chin projection (anteroposterior position) and vertical height. Simulates "
84
- "both augmentation (advancement) and reduction genioplasty. Affects the pogonion, "
85
- "menton, and lower border of the mandible.\n\n"
86
- "Affected landmarks: chin point, lower mandibular border, mentolabial fold"
87
- ),
88
  }
89
 
90
 
91
- # ---------------------------------------------------------------------------
92
- # Usage analytics -- simple thread-safe counter persisted to disk
93
- # ---------------------------------------------------------------------------
94
- class UsageTracker:
95
- """Track demo usage counts to a JSON file (thread-safe)."""
96
-
97
- def __init__(self, path: str = "usage_stats.json"):
98
- self._path = Path(path)
99
- self._lock = threading.Lock()
100
- self._stats: dict = self._load()
101
-
102
- def _load(self) -> dict:
103
- if self._path.exists():
104
- try:
105
- return json.loads(self._path.read_text())
106
- except (json.JSONDecodeError, OSError):
107
- pass
108
- return {
109
- "total_runs": 0,
110
- "procedures": {},
111
- "tabs": {},
112
- "first_run": None,
113
- "last_run": None,
114
- }
115
-
116
- def _save(self) -> None:
117
- try:
118
- self._path.write_text(json.dumps(self._stats, indent=2))
119
- except OSError:
120
- logger.warning("Could not persist usage stats")
121
-
122
- def record(self, tab: str, procedure: str | None = None) -> None:
123
- with self._lock:
124
- now = datetime.now(timezone.utc).isoformat()
125
- self._stats["total_runs"] = self._stats.get("total_runs", 0) + 1
126
- if self._stats.get("first_run") is None:
127
- self._stats["first_run"] = now
128
- self._stats["last_run"] = now
129
-
130
- tabs = self._stats.setdefault("tabs", {})
131
- tabs[tab] = tabs.get(tab, 0) + 1
132
-
133
- if procedure:
134
- procs = self._stats.setdefault("procedures", {})
135
- procs[procedure] = procs.get(procedure, 0) + 1
136
-
137
- self._save()
138
-
139
- @property
140
- def total_runs(self) -> int:
141
- return self._stats.get("total_runs", 0)
142
-
143
- @property
144
- def summary(self) -> str:
145
- total = self._stats.get("total_runs", 0)
146
- top_proc = ""
147
- procs = self._stats.get("procedures", {})
148
- if procs:
149
- top = max(procs, key=procs.get)
150
- top_proc = f" | Most popular: {top.replace('_', ' ').title()}"
151
- return f"Total runs: {total}{top_proc}"
152
-
153
-
154
- tracker = UsageTracker()
155
-
156
-
157
  def warp_image_tps(image, src_pts, dst_pts):
158
  """Thin-plate spline warp (CPU only)."""
159
  from landmarkdiff.synthetic.tps_warp import warp_image_tps as _warp
@@ -161,18 +41,12 @@ def warp_image_tps(image, src_pts, dst_pts):
161
  return _warp(image, src_pts, dst_pts)
162
 
163
 
164
- def mask_composite(warped, original, mask):
165
- """Alpha blend warped into original using mask."""
166
- mask_3 = np.stack([mask] * 3, axis=-1) if mask.ndim == 2 else mask
167
- return (warped * mask_3 + original * (1.0 - mask_3)).astype(np.uint8)
168
-
169
-
170
  def resize_preserve_aspect(image, size=512):
171
- """Resize image to size x size, padding to preserve aspect ratio."""
172
  h, w = image.shape[:2]
173
  scale = size / max(h, w)
174
  new_w, new_h = int(w * scale), int(h * scale)
175
- resized = cv2.resize(image, (new_w, new_h))
176
  canvas = np.zeros((size, size, 3), dtype=np.uint8)
177
  y_off = (size - new_h) // 2
178
  x_off = (size - new_w) // 2
@@ -180,26 +54,21 @@ def resize_preserve_aspect(image, size=512):
180
  return canvas
181
 
182
 
183
- PROCEDURES = list(PROCEDURE_LANDMARKS.keys())
184
-
185
-
186
- def _error_result(msg):
187
- """Return a 5-tuple of blanks + error message for the UI."""
188
- blank = np.zeros((512, 512, 3), dtype=np.uint8)
189
- return blank, blank, blank, blank, msg
190
 
191
 
192
- def _get_procedure_description(procedure: str) -> str:
193
- """Return the detailed Markdown description for a procedure."""
194
- return PROCEDURE_DETAILS.get(procedure, "Select a procedure to see details.")
195
 
196
 
197
  def process_image(image_rgb, procedure, intensity):
198
- """Process a single image through the TPS pipeline."""
199
- tracker.record("single", procedure)
200
-
201
  if image_rgb is None:
202
- return _error_result("Upload a face photo to begin.")
 
203
 
204
  t0 = time.monotonic()
205
 
@@ -209,31 +78,29 @@ def process_image(image_rgb, procedure, intensity):
209
  image_rgb_512 = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
210
  except Exception as exc:
211
  logger.error("Image conversion failed: %s", exc)
212
- return _error_result(f"Image conversion failed: {exc}")
 
213
 
214
  try:
215
  face = extract_landmarks(image_bgr)
216
  except Exception as exc:
217
  logger.error("Landmark extraction failed: %s\n%s", exc, traceback.format_exc())
218
- return _error_result(f"Landmark extraction error: {exc}")
 
219
 
220
  if face is None:
221
  return (
222
- image_rgb_512,
223
- image_rgb_512,
224
- image_rgb_512,
225
- image_rgb_512,
226
- "No face detected. Try a clearer photo with good lighting.",
227
  )
228
 
229
  try:
230
  manipulated = apply_procedure_preset(face, procedure, float(intensity), image_size=512)
231
-
232
  wireframe = render_wireframe(manipulated, width=512, height=512)
233
  wireframe_rgb = cv2.cvtColor(wireframe, cv2.COLOR_GRAY2RGB)
234
 
235
  mask = generate_surgical_mask(face, procedure, 512, 512)
236
- mask_vis = (mask * 255).astype(np.uint8)
237
 
238
  warped = warp_image_tps(image_bgr, face.pixel_coords, manipulated.pixel_coords)
239
  composited = mask_composite(warped, image_bgr, mask)
@@ -242,38 +109,32 @@ def process_image(image_rgb, procedure, intensity):
242
  displacement = np.mean(
243
  np.linalg.norm(manipulated.pixel_coords - face.pixel_coords, axis=1)
244
  )
245
-
246
  elapsed = time.monotonic() - t0
247
 
248
  info = (
249
- f"Procedure: {procedure}\n"
250
  f"Intensity: {intensity:.0f}%\n"
251
  f"Landmarks: {len(face.landmarks)}\n"
252
  f"Avg displacement: {displacement:.1f} px\n"
253
  f"Confidence: {face.confidence:.2f}\n"
254
- f"Processing time: {elapsed:.2f}s\n"
255
- f"Mode: TPS (CPU)"
256
  )
257
- # Return original as 4th output instead of stretched side-by-side
258
  return wireframe_rgb, mask_vis, composited_rgb, image_rgb_512, info
259
 
260
  except Exception as exc:
261
  logger.error("Processing failed: %s\n%s", exc, traceback.format_exc())
262
- return _error_result(f"Processing error: {exc}")
 
263
 
264
 
265
  def compare_procedures(image_rgb, intensity):
266
- """Compare all procedures at the same intensity."""
267
- tracker.record("compare")
268
-
269
  if image_rgb is None:
270
- blank = np.zeros((512, 512, 3), dtype=np.uint8)
271
- return [blank] * len(PROCEDURES)
272
 
273
  try:
274
  image_bgr = cv2.cvtColor(np.asarray(image_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
275
  image_bgr = resize_preserve_aspect(image_bgr, 512)
276
-
277
  face = extract_landmarks(image_bgr)
278
  if face is None:
279
  rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
@@ -286,32 +147,26 @@ def compare_procedures(image_rgb, intensity):
286
  warped = warp_image_tps(image_bgr, face.pixel_coords, manip.pixel_coords)
287
  comp = mask_composite(warped, image_bgr, mask)
288
  results.append(cv2.cvtColor(comp, cv2.COLOR_BGR2RGB))
289
-
290
  return results
291
  except Exception as exc:
292
- logger.error("Compare procedures failed: %s\n%s", exc, traceback.format_exc())
293
- blank = np.zeros((512, 512, 3), dtype=np.uint8)
294
- return [blank] * len(PROCEDURES)
295
 
296
 
297
  def intensity_sweep(image_rgb, procedure):
298
- """Generate intensity sweep from 0 to 100."""
299
- tracker.record("sweep", procedure)
300
-
301
  if image_rgb is None:
302
  return []
303
 
304
  try:
305
  image_bgr = cv2.cvtColor(np.asarray(image_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
306
  image_bgr = resize_preserve_aspect(image_bgr, 512)
307
-
308
  face = extract_landmarks(image_bgr)
309
  if face is None:
310
  return []
311
 
312
- steps = [0, 20, 40, 60, 80, 100]
313
  results = []
314
- for val in steps:
315
  if val == 0:
316
  results.append((cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB), "0%"))
317
  continue
@@ -320,161 +175,49 @@ def intensity_sweep(image_rgb, procedure):
320
  warped = warp_image_tps(image_bgr, face.pixel_coords, manip.pixel_coords)
321
  comp = mask_composite(warped, image_bgr, mask)
322
  results.append((cv2.cvtColor(comp, cv2.COLOR_BGR2RGB), f"{val}%"))
323
-
324
  return results
325
  except Exception as exc:
326
- logger.error("Intensity sweep failed: %s\n%s", exc, traceback.format_exc())
327
  return []
328
 
329
 
330
- # -- Example images --
331
- EXAMPLE_DIR = Path(__file__).parent / "examples"
332
- EXAMPLE_IMAGES = sorted(EXAMPLE_DIR.glob("*.png")) if EXAMPLE_DIR.exists() else []
333
 
334
- # -- Build the procedure table for the description --
335
- _proc_rows = "\n".join(
336
- f"| **{name.replace('_', ' ').title()}** | {desc} |"
337
- for name, desc in PROCEDURE_DESCRIPTIONS.items()
338
  )
339
 
340
- HEADER_MD = f"""
341
- # LandmarkDiff
342
-
343
- **Anatomically-conditioned facial surgery outcome prediction from standard clinical photography**
344
-
345
- Upload a face photo, select a procedure, and adjust intensity to see a predicted
346
- surgical outcome in real time.
347
- This demo runs TPS (thin-plate spline) warping on CPU. The full package also supports
348
- GPU-accelerated ControlNet and img2img inference modes.
349
-
350
- ---
351
-
352
- ### Supported Procedures
353
-
354
- | Procedure | Description |
355
- |-----------|-------------|
356
- {_proc_rows}
357
-
358
- ---
359
-
360
- ### How It Works
361
-
362
- 1. **Landmark detection** -- MediaPipe extracts a 478-point facial mesh from the input photo.
363
- 2. **Anatomical displacement** -- Procedure-specific presets shift landmark subsets by calibrated
364
- vectors (intensity 0-100 controls magnitude).
365
- 3. **TPS deformation** -- A thin-plate spline maps source landmarks to displaced targets, warping
366
- the image smoothly while preserving non-surgical regions.
367
- 4. **Masked compositing** -- A procedure-aware mask blends the warped region back into the
368
- original, keeping hair, background, and uninvolved anatomy intact.
369
-
370
- In GPU modes the deformed wireframe is passed to a ControlNet-conditioned Stable Diffusion
371
- pipeline for photorealistic rendering, followed by CodeFormer + Real-ESRGAN post-processing.
372
-
373
- ---
374
-
375
- [GitHub]({GITHUB_URL}) | \
376
- [Documentation]({DOCS_URL}) | \
377
- [Wiki]({WIKI_URL}) | \
378
- [Discussions]({DISCUSSIONS_URL})
379
- """
380
-
381
- FOOTER_MD = f"""
382
- ---
383
- <div style="text-align:center; color:#888; font-size:0.85em; padding: 12px 0;">
384
- <p>
385
- <strong>LandmarkDiff</strong> {VERSION} &middot;
386
- TPS warping on CPU &middot;
387
- MediaPipe 478-point mesh &middot;
388
- 6 surgical procedures
389
- </p>
390
- <p>
391
- <a href="{GITHUB_URL}">GitHub</a> &middot;
392
- <a href="{DOCS_URL}">Docs</a> &middot;
393
- <a href="{WIKI_URL}">Wiki</a> &middot;
394
- <a href="{DISCUSSIONS_URL}">Discussions</a> &middot;
395
- MIT License
396
- </p>
397
- <p style="font-size:0.75em; color:#aaa;">
398
- Built with Gradio &middot;
399
- Powered by MediaPipe + OpenCV &middot;
400
- <a href="{GITHUB_URL}/blob/main/CITATION.cff">Cite this work</a>
401
- </p>
402
- </div>
403
- """
404
-
405
-
406
  with gr.Blocks(
407
- title="LandmarkDiff - Surgical Outcome Prediction",
408
  theme=gr.themes.Soft(),
409
- css="""
410
- .status-processing {
411
- background: linear-gradient(90deg, #e3f2fd 0%, #bbdefb 50%, #e3f2fd 100%);
412
- background-size: 200% 100%;
413
- animation: shimmer 2s infinite;
414
- padding: 8px 16px;
415
- border-radius: 6px;
416
- text-align: center;
417
- font-weight: 500;
418
- }
419
- @keyframes shimmer {
420
- 0% { background-position: -200% 0; }
421
- 100% { background-position: 200% 0; }
422
- }
423
- .status-ready {
424
- background: #e8f5e9;
425
- padding: 8px 16px;
426
- border-radius: 6px;
427
- text-align: center;
428
- color: #2e7d32;
429
- font-weight: 500;
430
- }
431
- .status-error {
432
- background: #ffebee;
433
- padding: 8px 16px;
434
- border-radius: 6px;
435
- text-align: center;
436
- color: #c62828;
437
- font-weight: 500;
438
- }
439
- .proc-detail-box {
440
- background: #f5f5f5;
441
- border-left: 3px solid #1976d2;
442
- padding: 12px 16px;
443
- border-radius: 4px;
444
- margin-top: 8px;
445
- }
446
- """,
447
  ) as demo:
448
- gr.Markdown(HEADER_MD)
449
-
450
- # -- Single Procedure tab --
 
 
 
 
 
 
 
 
451
  with gr.Tab("Single Procedure"):
452
  with gr.Row():
453
  with gr.Column(scale=1):
454
- input_image = gr.Image(label="Upload Face Photo", type="numpy", height=350)
455
  procedure = gr.Radio(
456
- choices=PROCEDURES,
457
- value="rhinoplasty",
458
- label="Surgical Procedure",
459
- )
460
- proc_detail = gr.Markdown(
461
- value=_get_procedure_description("rhinoplasty"),
462
- elem_classes=["proc-detail-box"],
463
  )
464
  intensity = gr.Slider(
465
- minimum=0,
466
- maximum=100,
467
- value=50,
468
- step=1,
469
- label="Intensity (%)",
470
- info="0 = no change, 100 = maximum effect",
471
- )
472
- run_btn = gr.Button("Generate Preview", variant="primary", size="lg")
473
- status_box = gr.HTML(
474
- value='<div class="status-ready">Ready -- upload a photo or click an example below</div>',
475
- label="Status",
476
  )
477
- info_box = gr.Textbox(label="Info", lines=7, interactive=False)
 
478
 
479
  with gr.Column(scale=2):
480
  with gr.Row():
@@ -482,157 +225,74 @@ with gr.Blocks(
482
  out_mask = gr.Image(label="Surgical Mask", height=256)
483
  with gr.Row():
484
  out_result = gr.Image(label="Predicted Result", height=256)
485
- out_sidebyside = gr.Image(label="Original", height=256)
486
 
487
- # -- Example images --
488
  if EXAMPLE_IMAGES:
489
- gr.Markdown("### Try an Example")
490
  gr.Examples(
491
  examples=[[str(p)] for p in EXAMPLE_IMAGES],
492
  inputs=[input_image],
493
- label="Click an example face to load it (these are synthetic sketches "
494
- "-- for best results, upload a real photo)",
495
  )
496
 
497
- # -- Procedure description update --
498
- procedure.change(
499
- fn=_get_procedure_description,
500
- inputs=[procedure],
501
- outputs=[proc_detail],
502
- )
503
-
504
- # -- Processing with status indicator --
505
- def _process_with_status(image_rgb, proc, intens):
506
- results = process_image(image_rgb, proc, intens)
507
- # Last element is the info/error text
508
- info_text = results[-1]
509
- if "error" in info_text.lower() or "No face" in info_text:
510
- status_html = f'<div class="status-error">{info_text.split(chr(10))[0]}</div>'
511
- else:
512
- status_html = '<div class="status-ready">Done -- result ready</div>'
513
- return results + (status_html,)
514
-
515
- all_outputs = [out_wireframe, out_mask, out_result, out_sidebyside, info_box, status_box]
516
-
517
- run_btn.click(
518
- fn=lambda: '<div class="status-processing">Processing... extracting landmarks and warping</div>',
519
- inputs=None,
520
- outputs=[status_box],
521
- ).then(
522
- fn=_process_with_status,
523
- inputs=[input_image, procedure, intensity],
524
- outputs=all_outputs,
525
- )
526
-
527
- # Auto-trigger on input change (image upload, procedure change, intensity change)
528
  for trigger in [input_image, procedure, intensity]:
529
- trigger.change(
530
- fn=lambda: '<div class="status-processing">Processing...</div>',
531
- inputs=None,
532
- outputs=[status_box],
533
- ).then(
534
- fn=_process_with_status,
535
- inputs=[input_image, procedure, intensity],
536
- outputs=all_outputs,
537
- )
538
 
539
- # -- Compare Procedures tab --
540
- with gr.Tab("Compare Procedures"):
541
- gr.Markdown("Compare all six procedures side by side at the same intensity.")
542
  with gr.Row():
543
  with gr.Column(scale=1):
544
- cmp_image = gr.Image(label="Upload Face Photo", type="numpy", height=300)
545
  cmp_intensity = gr.Slider(0, 100, 50, step=1, label="Intensity (%)")
546
- cmp_btn = gr.Button("Compare All", variant="primary", size="lg")
547
- cmp_status = gr.HTML(
548
- value='<div class="status-ready">Ready</div>',
549
- )
550
  with gr.Column(scale=2):
551
  cmp_outputs = []
552
- rows_needed = (len(PROCEDURES) + 2) // 3
553
- for row_idx in range(rows_needed):
554
  with gr.Row():
555
  for col_idx in range(3):
556
- proc_idx = row_idx * 3 + col_idx
557
- if proc_idx < len(PROCEDURES):
558
  cmp_outputs.append(
559
  gr.Image(
560
- label=PROCEDURES[proc_idx].replace("_", " ").title(),
561
  height=200,
562
  )
563
  )
564
 
565
- # Example images for Compare tab
566
  if EXAMPLE_IMAGES:
567
  gr.Examples(
568
  examples=[[str(p)] for p in EXAMPLE_IMAGES],
569
- inputs=[cmp_image],
570
- label="Example faces",
571
  )
572
 
573
- def _compare_with_status(img, intens):
574
- results = compare_procedures(img, intens)
575
- return results + ['<div class="status-ready">Done -- 6 procedures compared</div>']
576
-
577
- cmp_btn.click(
578
- fn=lambda: '<div class="status-processing">Processing 6 procedures...</div>',
579
- inputs=None,
580
- outputs=[cmp_status],
581
- ).then(
582
- fn=_compare_with_status,
583
- inputs=[cmp_image, cmp_intensity],
584
- outputs=cmp_outputs + [cmp_status],
585
- )
586
 
587
- # -- Intensity Sweep tab --
588
  with gr.Tab("Intensity Sweep"):
589
- gr.Markdown(
590
- "See how a procedure looks across intensity levels (0% through 100% in 20% steps)."
591
- )
592
  with gr.Row():
593
  with gr.Column(scale=1):
594
- sweep_image = gr.Image(label="Upload Face Photo", type="numpy", height=300)
595
- sweep_procedure = gr.Radio(
596
- choices=PROCEDURES,
597
- value="rhinoplasty",
598
- label="Procedure",
599
- )
600
- sweep_btn = gr.Button("Generate Sweep", variant="primary", size="lg")
601
- sweep_status = gr.HTML(
602
- value='<div class="status-ready">Ready</div>',
603
- )
604
  with gr.Column(scale=2):
605
- sweep_gallery = gr.Gallery(
606
- label="Intensity Sweep (0% - 100%)", columns=3, height=400
607
- )
608
 
609
- # Example images for Sweep tab
610
  if EXAMPLE_IMAGES:
611
  gr.Examples(
612
  examples=[[str(p)] for p in EXAMPLE_IMAGES],
613
- inputs=[sweep_image],
614
- label="Example faces",
615
  )
616
 
617
- def _sweep_with_status(img, proc):
618
- results = intensity_sweep(img, proc)
619
- if results:
620
- status = '<div class="status-ready">Done -- 6 intensity levels generated</div>'
621
- else:
622
- status = '<div class="status-error">No face detected or processing failed</div>'
623
- return results, status
624
-
625
- sweep_btn.click(
626
- fn=lambda: '<div class="status-processing">Generating 6 intensity levels...</div>',
627
- inputs=None,
628
- outputs=[sweep_status],
629
- ).then(
630
- fn=_sweep_with_status,
631
- inputs=[sweep_image, sweep_procedure],
632
- outputs=[sweep_gallery, sweep_status],
633
- )
634
 
635
- gr.Markdown(FOOTER_MD)
 
 
 
 
636
 
637
  if __name__ == "__main__":
638
  demo.launch(show_error=True)
 
1
+ """LandmarkDiff -- Facial surgery outcome prediction demo (TPS on CPU)."""
2
 
3
  from __future__ import annotations
4
 
 
5
  import logging
 
 
6
  import time
7
  import traceback
 
8
  from pathlib import Path
9
 
10
  import cv2
 
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
 
 
22
  GITHUB_URL = "https://github.com/dreamlessx/LandmarkDiff-public"
23
+ PROCEDURES = list(PROCEDURE_LANDMARKS.keys())
24
+ EXAMPLE_DIR = Path(__file__).parent / "examples"
25
+ EXAMPLE_IMAGES = sorted(EXAMPLE_DIR.glob("*.png")) if EXAMPLE_DIR.exists() else []
 
 
 
 
 
 
 
 
 
26
 
27
+ PROCEDURE_INFO = {
28
+ "rhinoplasty": "Nose reshaping (bridge, tip, alar width)",
29
+ "blepharoplasty": "Eyelid surgery (lid position, canthal tilt)",
30
+ "rhytidectomy": "Facelift (midface, jawline tightening)",
31
+ "orthognathic": "Jaw surgery (maxilla/mandible repositioning)",
32
+ "brow_lift": "Brow elevation, forehead ptosis reduction",
33
+ "mentoplasty": "Chin surgery (projection, vertical height)",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  }
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def warp_image_tps(image, src_pts, dst_pts):
38
  """Thin-plate spline warp (CPU only)."""
39
  from landmarkdiff.synthetic.tps_warp import warp_image_tps as _warp
 
41
  return _warp(image, src_pts, dst_pts)
42
 
43
 
 
 
 
 
 
 
44
  def resize_preserve_aspect(image, size=512):
45
+ """Resize to square canvas, padding to preserve aspect ratio."""
46
  h, w = image.shape[:2]
47
  scale = size / max(h, w)
48
  new_w, new_h = int(w * scale), int(h * scale)
49
+ resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
50
  canvas = np.zeros((size, size, 3), dtype=np.uint8)
51
  y_off = (size - new_h) // 2
52
  x_off = (size - new_w) // 2
 
54
  return canvas
55
 
56
 
57
+ def mask_composite(warped, original, mask):
58
+ """Alpha-blend warped region into original using mask."""
59
+ mask_3 = np.stack([mask] * 3, axis=-1) if mask.ndim == 2 else mask
60
+ return (warped * mask_3 + original * (1.0 - mask_3)).astype(np.uint8)
 
 
 
61
 
62
 
63
+ def _blank():
64
+ return np.zeros((512, 512, 3), dtype=np.uint8)
 
65
 
66
 
67
  def process_image(image_rgb, procedure, intensity):
68
+ """Run the TPS pipeline on a single image."""
 
 
69
  if image_rgb is None:
70
+ b = _blank()
71
+ return b, b, b, b, "Upload a face photo to begin."
72
 
73
  t0 = time.monotonic()
74
 
 
78
  image_rgb_512 = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
79
  except Exception as exc:
80
  logger.error("Image conversion failed: %s", exc)
81
+ b = _blank()
82
+ return b, b, b, b, f"Image conversion failed: {exc}"
83
 
84
  try:
85
  face = extract_landmarks(image_bgr)
86
  except Exception as exc:
87
  logger.error("Landmark extraction failed: %s\n%s", exc, traceback.format_exc())
88
+ b = _blank()
89
+ return b, b, b, b, f"Landmark extraction error: {exc}"
90
 
91
  if face is None:
92
  return (
93
+ image_rgb_512, image_rgb_512, image_rgb_512, image_rgb_512,
94
+ "No face detected. Try a clearer, well-lit frontal photo.",
 
 
 
95
  )
96
 
97
  try:
98
  manipulated = apply_procedure_preset(face, procedure, float(intensity), image_size=512)
 
99
  wireframe = render_wireframe(manipulated, width=512, height=512)
100
  wireframe_rgb = cv2.cvtColor(wireframe, cv2.COLOR_GRAY2RGB)
101
 
102
  mask = generate_surgical_mask(face, procedure, 512, 512)
103
+ mask_vis = cv2.cvtColor((mask * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)
104
 
105
  warped = warp_image_tps(image_bgr, face.pixel_coords, manipulated.pixel_coords)
106
  composited = mask_composite(warped, image_bgr, mask)
 
109
  displacement = np.mean(
110
  np.linalg.norm(manipulated.pixel_coords - face.pixel_coords, axis=1)
111
  )
 
112
  elapsed = time.monotonic() - t0
113
 
114
  info = (
115
+ f"Procedure: {procedure.replace('_', ' ').title()}\n"
116
  f"Intensity: {intensity:.0f}%\n"
117
  f"Landmarks: {len(face.landmarks)}\n"
118
  f"Avg displacement: {displacement:.1f} px\n"
119
  f"Confidence: {face.confidence:.2f}\n"
120
+ f"Time: {elapsed:.2f}s | Mode: TPS (CPU)"
 
121
  )
 
122
  return wireframe_rgb, mask_vis, composited_rgb, image_rgb_512, info
123
 
124
  except Exception as exc:
125
  logger.error("Processing failed: %s\n%s", exc, traceback.format_exc())
126
+ b = _blank()
127
+ return b, b, b, b, f"Processing error: {exc}"
128
 
129
 
130
  def compare_procedures(image_rgb, intensity):
131
+ """Compare all six procedures at the same intensity."""
 
 
132
  if image_rgb is None:
133
+ return [_blank()] * len(PROCEDURES)
 
134
 
135
  try:
136
  image_bgr = cv2.cvtColor(np.asarray(image_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
137
  image_bgr = resize_preserve_aspect(image_bgr, 512)
 
138
  face = extract_landmarks(image_bgr)
139
  if face is None:
140
  rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
 
147
  warped = warp_image_tps(image_bgr, face.pixel_coords, manip.pixel_coords)
148
  comp = mask_composite(warped, image_bgr, mask)
149
  results.append(cv2.cvtColor(comp, cv2.COLOR_BGR2RGB))
 
150
  return results
151
  except Exception as exc:
152
+ logger.error("Compare failed: %s\n%s", exc, traceback.format_exc())
153
+ return [_blank()] * len(PROCEDURES)
 
154
 
155
 
156
  def intensity_sweep(image_rgb, procedure):
157
+ """Generate results at 0%, 20%, 40%, 60%, 80%, 100% intensity."""
 
 
158
  if image_rgb is None:
159
  return []
160
 
161
  try:
162
  image_bgr = cv2.cvtColor(np.asarray(image_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
163
  image_bgr = resize_preserve_aspect(image_bgr, 512)
 
164
  face = extract_landmarks(image_bgr)
165
  if face is None:
166
  return []
167
 
 
168
  results = []
169
+ for val in [0, 20, 40, 60, 80, 100]:
170
  if val == 0:
171
  results.append((cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB), "0%"))
172
  continue
 
175
  warped = warp_image_tps(image_bgr, face.pixel_coords, manip.pixel_coords)
176
  comp = mask_composite(warped, image_bgr, mask)
177
  results.append((cv2.cvtColor(comp, cv2.COLOR_BGR2RGB), f"{val}%"))
 
178
  return results
179
  except Exception as exc:
180
+ logger.error("Sweep failed: %s\n%s", exc, traceback.format_exc())
181
  return []
182
 
183
 
184
+ # ---------------------------------------------------------------------------
185
+ # Build the Gradio UI
186
+ # ---------------------------------------------------------------------------
187
 
188
+ _proc_table = "\n".join(
189
+ f"| {name.replace('_', ' ').title()} | {desc} |"
190
+ for name, desc in PROCEDURE_INFO.items()
 
191
  )
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  with gr.Blocks(
194
+ title="LandmarkDiff",
195
  theme=gr.themes.Soft(),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  ) as demo:
197
+ gr.Markdown(
198
+ f"# LandmarkDiff\n\n"
199
+ f"Facial surgery outcome prediction from clinical photography. "
200
+ f"Upload a face photo, pick a procedure, adjust intensity.\n\n"
201
+ f"| Procedure | Effect |\n|---|---|\n{_proc_table}\n\n"
202
+ f"[GitHub]({GITHUB_URL}) | "
203
+ f"[Docs]({GITHUB_URL}/tree/main/docs) | "
204
+ f"[Wiki]({GITHUB_URL}/wiki)"
205
+ )
206
+
207
+ # -- Tab 1: Single Procedure --
208
  with gr.Tab("Single Procedure"):
209
  with gr.Row():
210
  with gr.Column(scale=1):
211
+ input_image = gr.Image(label="Face Photo", type="numpy", height=350)
212
  procedure = gr.Radio(
213
+ choices=PROCEDURES, value="rhinoplasty", label="Procedure",
 
 
 
 
 
 
214
  )
215
  intensity = gr.Slider(
216
+ 0, 100, 50, step=1, label="Intensity (%)",
217
+ info="0 = no change, 100 = maximum",
 
 
 
 
 
 
 
 
 
218
  )
219
+ run_btn = gr.Button("Generate", variant="primary", size="lg")
220
+ info_box = gr.Textbox(label="Info", lines=6, interactive=False)
221
 
222
  with gr.Column(scale=2):
223
  with gr.Row():
 
225
  out_mask = gr.Image(label="Surgical Mask", height=256)
226
  with gr.Row():
227
  out_result = gr.Image(label="Predicted Result", height=256)
228
+ out_original = gr.Image(label="Original", height=256)
229
 
 
230
  if EXAMPLE_IMAGES:
 
231
  gr.Examples(
232
  examples=[[str(p)] for p in EXAMPLE_IMAGES],
233
  inputs=[input_image],
234
+ label="Example faces (click to load)",
 
235
  )
236
 
237
+ outputs = [out_wireframe, out_mask, out_result, out_original, info_box]
238
+ run_btn.click(fn=process_image, inputs=[input_image, procedure, intensity], outputs=outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  for trigger in [input_image, procedure, intensity]:
240
+ trigger.change(fn=process_image, inputs=[input_image, procedure, intensity], outputs=outputs)
 
 
 
 
 
 
 
 
241
 
242
+ # -- Tab 2: Compare Procedures --
243
+ with gr.Tab("Compare All"):
244
+ gr.Markdown("All six procedures at the same intensity, side by side.")
245
  with gr.Row():
246
  with gr.Column(scale=1):
247
+ cmp_image = gr.Image(label="Face Photo", type="numpy", height=300)
248
  cmp_intensity = gr.Slider(0, 100, 50, step=1, label="Intensity (%)")
249
+ cmp_btn = gr.Button("Compare", variant="primary", size="lg")
 
 
 
250
  with gr.Column(scale=2):
251
  cmp_outputs = []
252
+ for row_idx in range(2):
 
253
  with gr.Row():
254
  for col_idx in range(3):
255
+ idx = row_idx * 3 + col_idx
256
+ if idx < len(PROCEDURES):
257
  cmp_outputs.append(
258
  gr.Image(
259
+ label=PROCEDURES[idx].replace("_", " ").title(),
260
  height=200,
261
  )
262
  )
263
 
 
264
  if EXAMPLE_IMAGES:
265
  gr.Examples(
266
  examples=[[str(p)] for p in EXAMPLE_IMAGES],
267
+ inputs=[cmp_image], label="Examples",
 
268
  )
269
 
270
+ cmp_btn.click(fn=compare_procedures, inputs=[cmp_image, cmp_intensity], outputs=cmp_outputs)
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
+ # -- Tab 3: Intensity Sweep --
273
  with gr.Tab("Intensity Sweep"):
274
+ gr.Markdown("See a procedure at 0% through 100% in six steps.")
 
 
275
  with gr.Row():
276
  with gr.Column(scale=1):
277
+ sweep_image = gr.Image(label="Face Photo", type="numpy", height=300)
278
+ sweep_proc = gr.Radio(choices=PROCEDURES, value="rhinoplasty", label="Procedure")
279
+ sweep_btn = gr.Button("Sweep", variant="primary", size="lg")
 
 
 
 
 
 
 
280
  with gr.Column(scale=2):
281
+ sweep_gallery = gr.Gallery(label="0% to 100%", columns=3, height=400)
 
 
282
 
 
283
  if EXAMPLE_IMAGES:
284
  gr.Examples(
285
  examples=[[str(p)] for p in EXAMPLE_IMAGES],
286
+ inputs=[sweep_image], label="Examples",
 
287
  )
288
 
289
+ sweep_btn.click(fn=intensity_sweep, inputs=[sweep_image, sweep_proc], outputs=[sweep_gallery])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
 
291
+ gr.Markdown(
292
+ f"<div style='text-align:center;color:#999;font-size:0.8em;padding:8px'>"
293
+ f"LandmarkDiff v0.2.2 | TPS on CPU | MediaPipe 478-point mesh | "
294
+ f"<a href='{GITHUB_URL}'>GitHub</a> | MIT License</div>"
295
+ )
296
 
297
  if __name__ == "__main__":
298
  demo.launch(show_error=True)
examples/demo_face_1.png ADDED

Git LFS Details

  • SHA256: ed160a54ab0d022bfc75547e1be82a3ac677c1bd85cd537281bd88c937eee998
  • Pointer size: 131 Bytes
  • Size of remote file: 440 kB
examples/demo_face_2.png ADDED

Git LFS Details

  • SHA256: 0f54ca90ad94c4a1552f4342dbcec58a46512df10d3f74a20e425bd6e9fdefcb
  • Pointer size: 131 Bytes
  • Size of remote file: 449 kB
examples/demo_face_3.png ADDED

Git LFS Details

  • SHA256: 5a3bc6867c9626ed0521569b125ddc090623afdc64772dd86968eff2cbc821a9
  • Pointer size: 131 Bytes
  • Size of remote file: 449 kB