dreamlessx committed on
Commit
16a86a4
·
verified ·
1 Parent(s): 0c568c7

Update landmarkdiff/inference.py to v0.3.2

Browse files
Files changed (1) hide show
  1. landmarkdiff/inference.py +162 -71
landmarkdiff/inference.py CHANGED
@@ -4,13 +4,15 @@ Four modes:
4
  1. ControlNet: CrucibleAI/ControlNetMediaPipeFace + SD1.5 (requires HF auth + GPU)
5
  2. ControlNet + IP-Adapter: ControlNet with identity preservation via face embeddings
6
  3. Img2Img: SD1.5 img2img with mask compositing (runs on MPS, no auth needed)
7
- 4. TPS-only: Pure geometric warp no diffusion model, instant results
8
 
9
  Supports MPS (Apple Silicon), CUDA, and CPU backends.
10
  """
11
 
12
  from __future__ import annotations
13
 
 
 
14
  import sys
15
  from pathlib import Path
16
  from typing import TYPE_CHECKING
@@ -28,6 +30,8 @@ from landmarkdiff.synthetic.tps_warp import warp_image_tps
28
  if TYPE_CHECKING:
29
  from landmarkdiff.clinical import ClinicalFlags
30
 
 
 
31
 
32
  def get_device() -> torch.device:
33
  if torch.backends.mps.is_available():
@@ -71,6 +75,16 @@ PROCEDURE_PROMPTS: dict[str, str] = {
71
  "realistic skin pores and texture, sharp focus, studio lighting, "
72
  "DSLR quality, natural skin color"
73
  ),
 
 
 
 
 
 
 
 
 
 
74
  }
75
 
76
  NEGATIVE_PROMPT = (
@@ -81,6 +95,21 @@ NEGATIVE_PROMPT = (
81
  "plastic skin, waxy, smooth skin, airbrushed, oversaturated"
82
  )
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  def mask_composite(
86
  warped: np.ndarray,
@@ -107,7 +136,7 @@ def mask_composite(
107
 
108
  return laplacian_pyramid_blend(corrected, original, mask_f)
109
  except Exception:
110
- pass
111
 
112
  # Fallback: simple alpha blend
113
  mask_3ch = mask_to_3channel(mask_f)
@@ -124,7 +153,7 @@ def _match_skin_tone(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -
124
  Works in LAB space: transfers L (luminance) and AB (color) statistics
125
  from the original to the warped image so skin tone is preserved exactly.
126
  """
127
- mask_bool = mask > 0.3
128
  if not np.any(mask_bool):
129
  return source
130
 
@@ -136,8 +165,8 @@ def _match_skin_tone(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -
136
  src_vals = src_lab[:, :, ch][mask_bool]
137
  tgt_vals = tgt_lab[:, :, ch][mask_bool]
138
 
139
- src_mean, src_std = np.mean(src_vals), np.std(src_vals) + 1e-6
140
- tgt_mean, tgt_std = np.mean(tgt_vals), np.std(tgt_vals) + 1e-6
141
 
142
  # Normalize source to match target's distribution
143
  src_lab[:, :, ch] = np.where(
@@ -154,7 +183,8 @@ class LandmarkDiffPipeline:
154
  """End-to-end pipeline: image -> landmarks -> manipulate -> generate.
155
 
156
  Modes:
157
- - 'controlnet': CrucibleAI/ControlNetMediaPipeFace + SD1.5
 
158
  - 'controlnet_ip': ControlNet + IP-Adapter for identity preservation
159
  - 'img2img': SD1.5 img2img with mask compositing
160
  - 'tps': Pure geometric TPS warp (no diffusion, instant)
@@ -166,6 +196,9 @@ class LandmarkDiffPipeline:
166
  IP_ADAPTER_WEIGHT_NAME = "ip-adapter-plus-face_sd15.bin"
167
  IP_ADAPTER_SCALE_DEFAULT = 0.6
168
 
 
 
 
169
  def __init__(
170
  self,
171
  mode: str = "img2img",
@@ -191,9 +224,9 @@ class LandmarkDiffPipeline:
191
  from landmarkdiff.displacement_model import DisplacementModel
192
 
193
  self._displacement_model = DisplacementModel.load(displacement_model_path)
194
- print(f"Displacement model loaded: {self._displacement_model.procedures}")
195
  except Exception as e:
196
- print(f"WARNING: Failed to load displacement model: {e}")
197
 
198
  if self.device.type == "mps":
199
  self.dtype = torch.float32
@@ -204,22 +237,23 @@ class LandmarkDiffPipeline:
204
 
205
  if base_model_id:
206
  self.base_model_id = base_model_id
207
- elif mode in ("controlnet", "controlnet_ip"):
208
- self.base_model_id = "runwayml/stable-diffusion-v1-5"
209
  else:
210
  self.base_model_id = "runwayml/stable-diffusion-v1-5"
211
 
212
  self.controlnet_id = controlnet_id
213
  self._pipe = None
214
  self._ip_adapter_loaded = False
 
215
 
216
  def load(self) -> None:
217
  if self.mode == "tps":
218
- print("TPS mode no model to load")
219
  return
220
- if self.mode in ("controlnet", "controlnet_ip"):
221
  self._load_controlnet()
222
- if self.mode == "controlnet_ip":
 
 
223
  self._load_ip_adapter()
224
  else:
225
  self._load_img2img()
@@ -231,43 +265,72 @@ class LandmarkDiffPipeline:
231
  StableDiffusionControlNetPipeline,
232
  )
233
 
 
 
 
234
  if self.controlnet_checkpoint:
235
  # Load fine-tuned ControlNet from local checkpoint
236
  ckpt_path = Path(self.controlnet_checkpoint)
237
  # Support both direct path and training checkpoint structure
238
  if (ckpt_path / "controlnet_ema").exists():
239
  ckpt_path = ckpt_path / "controlnet_ema"
240
- print(f"Loading fine-tuned ControlNet from {ckpt_path}...")
241
  controlnet = ControlNetModel.from_pretrained(
242
  str(ckpt_path),
243
  torch_dtype=self.dtype,
244
  )
245
  else:
246
- print(f"Loading ControlNet from {self.controlnet_id}...")
247
  controlnet = ControlNetModel.from_pretrained(
248
  self.controlnet_id,
249
  subfolder="diffusion_sd15",
250
  torch_dtype=self.dtype,
 
251
  )
252
- print(f"Loading base model from {self.base_model_id}...")
253
  self._pipe = StableDiffusionControlNetPipeline.from_pretrained(
254
  self.base_model_id,
255
  controlnet=controlnet,
256
  torch_dtype=self.dtype,
257
  safety_checker=None,
258
  requires_safety_checker=False,
 
259
  )
260
- # DPM++ 2M Karras produces more photorealistic output than UniPC
261
  self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(
262
  self._pipe.scheduler.config,
263
  algorithm_type="dpmsolver++",
264
  use_karras_sigmas=True,
265
  )
266
- # FP32 VAE decode prevents color banding artifacts on skin tones
267
  if hasattr(self._pipe, "vae") and self._pipe.vae is not None:
268
  self._pipe.vae.config.force_upcast = True
269
  self._apply_device_optimizations()
270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  def _load_ip_adapter(self) -> None:
272
  """Load IP-Adapter for identity-preserving generation.
273
 
@@ -277,7 +340,7 @@ class LandmarkDiffPipeline:
277
  if self._pipe is None:
278
  raise RuntimeError("Base pipeline must be loaded before IP-Adapter")
279
  try:
280
- print(f"Loading IP-Adapter ({self.IP_ADAPTER_WEIGHT_NAME})...")
281
  self._pipe.load_ip_adapter(
282
  self.IP_ADAPTER_REPO,
283
  subfolder=self.IP_ADAPTER_SUBFOLDER,
@@ -285,10 +348,10 @@ class LandmarkDiffPipeline:
285
  )
286
  self._pipe.set_ip_adapter_scale(self.ip_adapter_scale)
287
  self._ip_adapter_loaded = True
288
- print(f"IP-Adapter loaded (scale={self.ip_adapter_scale})")
289
  except Exception as e:
290
- print(f"WARNING: IP-Adapter load failed: {e}")
291
- print("Falling back to ControlNet-only mode")
292
  self._ip_adapter_loaded = False
293
 
294
  def _load_img2img(self) -> None:
@@ -297,12 +360,16 @@ class LandmarkDiffPipeline:
297
  StableDiffusionImg2ImgPipeline,
298
  )
299
 
300
- print(f"Loading SD1.5 img2img from {self.base_model_id}...")
 
 
 
301
  self._pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
302
  self.base_model_id,
303
  torch_dtype=self.dtype,
304
  safety_checker=None,
305
  requires_safety_checker=False,
 
306
  )
307
  self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(self._pipe.scheduler.config)
308
  self._apply_device_optimizations()
@@ -318,7 +385,7 @@ class LandmarkDiffPipeline:
318
  self._pipe = self._pipe.to(self.device)
319
  else:
320
  self._pipe.enable_sequential_cpu_offload()
321
- print(f"Pipeline loaded on {self.device} ({self.dtype})")
322
 
323
  @property
324
  def is_loaded(self) -> bool:
@@ -342,7 +409,8 @@ class LandmarkDiffPipeline:
342
  raise RuntimeError("Pipeline not loaded. Call .load() first.")
343
 
344
  flags = clinical_flags or self.clinical_flags
345
- image_512 = cv2.resize(image, (512, 512))
 
346
 
347
  face = extract_landmarks(image_512)
348
  if face is None:
@@ -357,7 +425,7 @@ class LandmarkDiffPipeline:
357
  try:
358
  rng = np.random.default_rng(seed) if seed is not None else np.random.default_rng()
359
  # Map UI intensity (0-100) to displacement model intensity (0-2)
360
- dm_intensity = intensity / 50.0 # 50 -> 1.0x mean displacement
361
  displacement = self._displacement_model.get_displacement_field(
362
  procedure,
363
  intensity=dm_intensity,
@@ -373,17 +441,18 @@ class LandmarkDiffPipeline:
373
  new_lm[:, 1] = np.clip(new_lm[:, 1], 0.01, 0.99)
374
  manipulated = FaceLandmarks(
375
  landmarks=new_lm,
376
- image_width=512,
377
- image_height=512,
378
  confidence=face.confidence,
379
  )
380
  manipulation_mode = "displacement_model"
381
- except Exception:
 
382
  manipulated = apply_procedure_preset(
383
  face,
384
  procedure,
385
  intensity,
386
- image_size=512,
387
  clinical_flags=flags,
388
  )
389
  else:
@@ -391,15 +460,15 @@ class LandmarkDiffPipeline:
391
  face,
392
  procedure,
393
  intensity,
394
- image_size=512,
395
  clinical_flags=flags,
396
  )
397
- landmark_img = render_landmark_image(manipulated, 512, 512)
398
  mask = generate_surgical_mask(
399
  face,
400
  procedure,
401
- 512,
402
- 512,
403
  clinical_flags=flags,
404
  )
405
 
@@ -409,33 +478,51 @@ class LandmarkDiffPipeline:
409
 
410
  prompt = PROCEDURE_PROMPTS.get(procedure, "a photo of a person's face")
411
 
412
- # Step 1: TPS geometric warp (always computed the geometric baseline)
413
  tps_warped = warp_image_tps(image_512, face.pixel_coords, manipulated.pixel_coords)
414
 
415
  if self.mode == "tps":
416
  raw_output = tps_warped
417
- elif self.mode in ("controlnet", "controlnet_ip"):
 
 
 
 
418
  ip_image = numpy_to_pil(image_512) if self._ip_adapter_loaded else None
419
- raw_output = self._generate_controlnet(
420
- image_512,
421
- landmark_img,
422
- prompt,
423
- num_inference_steps,
424
- guidance_scale,
425
- controlnet_conditioning_scale,
426
- generator,
427
- ip_adapter_image=ip_image,
428
- )
 
 
 
 
 
 
 
429
  else:
430
- raw_output = self._generate_img2img(
431
- tps_warped,
432
- mask,
433
- prompt,
434
- num_inference_steps,
435
- guidance_scale,
436
- strength,
437
- generator,
438
- )
 
 
 
 
 
 
 
439
 
440
  # Step 2: Post-processing for photorealism (neural + classical pipeline)
441
  identity_check = None
@@ -474,6 +561,7 @@ class LandmarkDiffPipeline:
474
  "mode": self.mode,
475
  "view_info": view_info,
476
  "ip_adapter_active": self._ip_adapter_loaded,
 
477
  "identity_check": identity_check,
478
  "restore_used": restore_used,
479
  "manipulation_mode": manipulation_mode,
@@ -535,11 +623,12 @@ def estimate_face_view(face: FaceLandmarks) -> dict:
535
  Returns dict with yaw, pitch (degrees), and view classification.
536
  """
537
  coords = face.pixel_coords
538
- nose_tip = coords[1]
539
- left_ear = coords[234]
540
- right_ear = coords[454]
541
- forehead = coords[10]
542
- chin = coords[152]
 
543
 
544
  # Yaw: ratio of nose-to-ear distances (symmetric = 0 degrees)
545
  left_dist = np.linalg.norm(nose_tip - left_ear)
@@ -558,13 +647,13 @@ def estimate_face_view(face: FaceLandmarks) -> dict:
558
  pitch = 0.0
559
  else:
560
  pitch_ratio = (lower - upper) / (upper + lower)
561
- pitch = float(pitch_ratio * 45)
562
 
563
  # Classify view
564
  abs_yaw = abs(yaw)
565
- if abs_yaw < 15:
566
  view = "frontal"
567
- elif abs_yaw < 45:
568
  view = "three_quarter"
569
  else:
570
  view = "profile"
@@ -573,8 +662,10 @@ def estimate_face_view(face: FaceLandmarks) -> dict:
573
  "yaw": round(yaw, 1),
574
  "pitch": round(pitch, 1),
575
  "view": view,
576
- "is_frontal": abs_yaw < 15,
577
- "warning": "Side-view detected: results may be less accurate" if abs_yaw > 30 else None,
 
 
578
  }
579
 
580
 
@@ -594,7 +685,7 @@ def run_inference(
594
 
595
  image = cv2.imread(image_path)
596
  if image is None:
597
- print(f"ERROR: Could not load {image_path}")
598
  sys.exit(1)
599
 
600
  pipe = LandmarkDiffPipeline(
@@ -605,7 +696,7 @@ def run_inference(
605
  )
606
  pipe.load()
607
 
608
- print(f"\nGenerating {procedure} prediction (intensity={intensity}, mode={mode})...")
609
  result = pipe.generate(image, procedure=procedure, intensity=intensity, seed=seed)
610
 
611
  cv2.imwrite(str(out / "input.png"), result["input"])
@@ -620,9 +711,9 @@ def run_inference(
620
 
621
  view = result.get("view_info", {})
622
  if view.get("warning"):
623
- print(f"WARNING: {view['warning']}")
624
- print(f"Face view: {view.get('view', 'unknown')} (yaw={view.get('yaw', 0)})")
625
- print(f"Results saved to {out}/")
626
 
627
 
628
  if __name__ == "__main__":
@@ -637,7 +728,7 @@ if __name__ == "__main__":
637
  parser.add_argument(
638
  "--mode",
639
  default="img2img",
640
- choices=["img2img", "controlnet", "controlnet_ip", "tps"],
641
  )
642
  parser.add_argument("--ip-adapter-scale", type=float, default=0.6)
643
  parser.add_argument(
 
4
  1. ControlNet: CrucibleAI/ControlNetMediaPipeFace + SD1.5 (requires HF auth + GPU)
5
  2. ControlNet + IP-Adapter: ControlNet with identity preservation via face embeddings
6
  3. Img2Img: SD1.5 img2img with mask compositing (runs on MPS, no auth needed)
7
+ 4. TPS-only: Pure geometric warp -- no diffusion model, instant results
8
 
9
  Supports MPS (Apple Silicon), CUDA, and CPU backends.
10
  """
11
 
12
  from __future__ import annotations
13
 
14
+ import logging
15
+ import os
16
  import sys
17
  from pathlib import Path
18
  from typing import TYPE_CHECKING
 
30
  if TYPE_CHECKING:
31
  from landmarkdiff.clinical import ClinicalFlags
32
 
33
+ logger = logging.getLogger(__name__)
34
+
35
 
36
  def get_device() -> torch.device:
37
  if torch.backends.mps.is_available():
 
75
  "realistic skin pores and texture, sharp focus, studio lighting, "
76
  "DSLR quality, natural skin color"
77
  ),
78
+ "brow_lift": (
79
+ "clinical photograph, patient face, elevated brow position, smooth forehead, "
80
+ "realistic skin pores and texture, sharp focus, studio lighting, "
81
+ "DSLR quality, natural skin color"
82
+ ),
83
+ "mentoplasty": (
84
+ "clinical photograph, patient face, refined chin contour, balanced lower face, "
85
+ "realistic skin pores and texture, sharp focus, studio lighting, "
86
+ "DSLR quality, natural skin color"
87
+ ),
88
  }
89
 
90
  NEGATIVE_PROMPT = (
 
95
  "plastic skin, waxy, smooth skin, airbrushed, oversaturated"
96
  )
97
 
98
+ # Skin tone matching: minimum mask alpha to include in LAB stats transfer
99
+ _SKIN_TONE_MASK_THRESHOLD = 0.3
100
+ # Epsilon to avoid division by zero in std normalization
101
+ _STD_EPSILON = 1e-6
102
+ # Default SD1.5 resolution (all pipelines resize to this)
103
+ _SD15_RESOLUTION = 512
104
+ # Intensity mapping: UI scale (0-100) to displacement model scale (0-2)
105
+ _INTENSITY_UI_TO_MODEL = 50.0
106
+ # Face view classification thresholds (degrees)
107
+ _YAW_FRONTAL_MAX = 15
108
+ _YAW_THREE_QUARTER_MAX = 45
109
+ _YAW_WARNING_THRESHOLD = 30
110
+ # Max pitch scale factor (maps pitch ratio to degrees)
111
+ _PITCH_SCALE = 45
112
+
113
 
114
  def mask_composite(
115
  warped: np.ndarray,
 
136
 
137
  return laplacian_pyramid_blend(corrected, original, mask_f)
138
  except Exception:
139
+ logger.debug("Laplacian blend failed, using alpha blend", exc_info=True)
140
 
141
  # Fallback: simple alpha blend
142
  mask_3ch = mask_to_3channel(mask_f)
 
153
  Works in LAB space: transfers L (luminance) and AB (color) statistics
154
  from the original to the warped image so skin tone is preserved exactly.
155
  """
156
+ mask_bool = mask > _SKIN_TONE_MASK_THRESHOLD
157
  if not np.any(mask_bool):
158
  return source
159
 
 
165
  src_vals = src_lab[:, :, ch][mask_bool]
166
  tgt_vals = tgt_lab[:, :, ch][mask_bool]
167
 
168
+ src_mean, src_std = np.mean(src_vals), np.std(src_vals) + _STD_EPSILON
169
+ tgt_mean, tgt_std = np.mean(tgt_vals), np.std(tgt_vals) + _STD_EPSILON
170
 
171
  # Normalize source to match target's distribution
172
  src_lab[:, :, ch] = np.where(
 
183
  """End-to-end pipeline: image -> landmarks -> manipulate -> generate.
184
 
185
  Modes:
186
+ - 'controlnet': CrucibleAI/ControlNetMediaPipeFace + SD1.5 (30 steps)
187
+ - 'controlnet_fast': ControlNet + LCM-LoRA (4 steps, CPU-viable)
188
  - 'controlnet_ip': ControlNet + IP-Adapter for identity preservation
189
  - 'img2img': SD1.5 img2img with mask compositing
190
  - 'tps': Pure geometric TPS warp (no diffusion, instant)
 
196
  IP_ADAPTER_WEIGHT_NAME = "ip-adapter-plus-face_sd15.bin"
197
  IP_ADAPTER_SCALE_DEFAULT = 0.6
198
 
199
+ # LCM-LoRA for fast inference (2-4 steps instead of 30)
200
+ LCM_LORA_REPO = "latent-consistency/lcm-lora-sdv1-5"
201
+
202
  def __init__(
203
  self,
204
  mode: str = "img2img",
 
224
  from landmarkdiff.displacement_model import DisplacementModel
225
 
226
  self._displacement_model = DisplacementModel.load(displacement_model_path)
227
+ logger.info("Displacement model loaded: %s", self._displacement_model.procedures)
228
  except Exception as e:
229
+ logger.warning("Failed to load displacement model: %s", e)
230
 
231
  if self.device.type == "mps":
232
  self.dtype = torch.float32
 
237
 
238
  if base_model_id:
239
  self.base_model_id = base_model_id
 
 
240
  else:
241
  self.base_model_id = "runwayml/stable-diffusion-v1-5"
242
 
243
  self.controlnet_id = controlnet_id
244
  self._pipe = None
245
  self._ip_adapter_loaded = False
246
+ self._lcm_loaded = False
247
 
248
  def load(self) -> None:
249
  if self.mode == "tps":
250
+ logger.info("TPS mode -- no model to load")
251
  return
252
+ if self.mode in ("controlnet", "controlnet_ip", "controlnet_fast"):
253
  self._load_controlnet()
254
+ if self.mode == "controlnet_fast":
255
+ self._load_lcm_lora()
256
+ elif self.mode == "controlnet_ip":
257
  self._load_ip_adapter()
258
  else:
259
  self._load_img2img()
 
265
  StableDiffusionControlNetPipeline,
266
  )
267
 
268
+ _local_only = os.environ.get("HF_HUB_OFFLINE", "0") == "1"
269
+ _kw: dict = {"local_files_only": True} if _local_only else {}
270
+
271
  if self.controlnet_checkpoint:
272
  # Load fine-tuned ControlNet from local checkpoint
273
  ckpt_path = Path(self.controlnet_checkpoint)
274
  # Support both direct path and training checkpoint structure
275
  if (ckpt_path / "controlnet_ema").exists():
276
  ckpt_path = ckpt_path / "controlnet_ema"
277
+ logger.info("Loading fine-tuned ControlNet from %s", ckpt_path)
278
  controlnet = ControlNetModel.from_pretrained(
279
  str(ckpt_path),
280
  torch_dtype=self.dtype,
281
  )
282
  else:
283
+ logger.info("Loading ControlNet from %s", self.controlnet_id)
284
  controlnet = ControlNetModel.from_pretrained(
285
  self.controlnet_id,
286
  subfolder="diffusion_sd15",
287
  torch_dtype=self.dtype,
288
+ **_kw,
289
  )
290
+ logger.info("Loading base model from %s", self.base_model_id)
291
  self._pipe = StableDiffusionControlNetPipeline.from_pretrained(
292
  self.base_model_id,
293
  controlnet=controlnet,
294
  torch_dtype=self.dtype,
295
  safety_checker=None,
296
  requires_safety_checker=False,
297
+ **_kw,
298
  )
299
+ # DPM++ 2M Karras -- produces more photorealistic output than UniPC
300
  self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(
301
  self._pipe.scheduler.config,
302
  algorithm_type="dpmsolver++",
303
  use_karras_sigmas=True,
304
  )
305
+ # FP32 VAE decode -- prevents color banding artifacts on skin tones
306
  if hasattr(self._pipe, "vae") and self._pipe.vae is not None:
307
  self._pipe.vae.config.force_upcast = True
308
  self._apply_device_optimizations()
309
 
310
+ def _load_lcm_lora(self) -> None:
311
+ """Load LCM-LoRA for fast 4-step inference.
312
+
313
+ LCM-LoRA (Latent Consistency Model) distills the denoising process
314
+ into 2-4 steps, making CPU inference viable (~3-8s vs ~60s+).
315
+ Replaces the scheduler with LCMScheduler for consistency sampling.
316
+ """
317
+ if self._pipe is None:
318
+ raise RuntimeError("Base pipeline must be loaded before LCM-LoRA")
319
+ try:
320
+ from diffusers import LCMScheduler
321
+
322
+ logger.info("Loading LCM-LoRA from %s", self.LCM_LORA_REPO)
323
+ _local_only = os.environ.get("HF_HUB_OFFLINE", "0") == "1"
324
+ _kw: dict = {"local_files_only": True} if _local_only else {}
325
+ self._pipe.load_lora_weights(self.LCM_LORA_REPO, **_kw)
326
+ self._pipe.scheduler = LCMScheduler.from_config(self._pipe.scheduler.config)
327
+ self._lcm_loaded = True
328
+ logger.info("LCM-LoRA loaded -- 4-step inference enabled")
329
+ except Exception as e:
330
+ logger.warning("LCM-LoRA load failed: %s", e)
331
+ logger.warning("Falling back to standard scheduler (30 steps)")
332
+ self._lcm_loaded = False
333
+
334
  def _load_ip_adapter(self) -> None:
335
  """Load IP-Adapter for identity-preserving generation.
336
 
 
340
  if self._pipe is None:
341
  raise RuntimeError("Base pipeline must be loaded before IP-Adapter")
342
  try:
343
+ logger.info("Loading IP-Adapter (%s)", self.IP_ADAPTER_WEIGHT_NAME)
344
  self._pipe.load_ip_adapter(
345
  self.IP_ADAPTER_REPO,
346
  subfolder=self.IP_ADAPTER_SUBFOLDER,
 
348
  )
349
  self._pipe.set_ip_adapter_scale(self.ip_adapter_scale)
350
  self._ip_adapter_loaded = True
351
+ logger.info("IP-Adapter loaded (scale=%s)", self.ip_adapter_scale)
352
  except Exception as e:
353
+ logger.warning("IP-Adapter load failed: %s", e)
354
+ logger.warning("Falling back to ControlNet-only mode")
355
  self._ip_adapter_loaded = False
356
 
357
  def _load_img2img(self) -> None:
 
360
  StableDiffusionImg2ImgPipeline,
361
  )
362
 
363
+ _local_only = os.environ.get("HF_HUB_OFFLINE", "0") == "1"
364
+ _kw: dict = {"local_files_only": True} if _local_only else {}
365
+
366
+ logger.info("Loading SD1.5 img2img from %s", self.base_model_id)
367
  self._pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
368
  self.base_model_id,
369
  torch_dtype=self.dtype,
370
  safety_checker=None,
371
  requires_safety_checker=False,
372
+ **_kw,
373
  )
374
  self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(self._pipe.scheduler.config)
375
  self._apply_device_optimizations()
 
385
  self._pipe = self._pipe.to(self.device)
386
  else:
387
  self._pipe.enable_sequential_cpu_offload()
388
+ logger.info("Pipeline loaded on %s (%s)", self.device, self.dtype)
389
 
390
  @property
391
  def is_loaded(self) -> bool:
 
409
  raise RuntimeError("Pipeline not loaded. Call .load() first.")
410
 
411
  flags = clinical_flags or self.clinical_flags
412
+ res = _SD15_RESOLUTION
413
+ image_512 = cv2.resize(image, (res, res))
414
 
415
  face = extract_landmarks(image_512)
416
  if face is None:
 
425
  try:
426
  rng = np.random.default_rng(seed) if seed is not None else np.random.default_rng()
427
  # Map UI intensity (0-100) to displacement model intensity (0-2)
428
+ dm_intensity = intensity / _INTENSITY_UI_TO_MODEL # 50 -> 1.0x mean displacement
429
  displacement = self._displacement_model.get_displacement_field(
430
  procedure,
431
  intensity=dm_intensity,
 
441
  new_lm[:, 1] = np.clip(new_lm[:, 1], 0.01, 0.99)
442
  manipulated = FaceLandmarks(
443
  landmarks=new_lm,
444
+ image_width=res,
445
+ image_height=res,
446
  confidence=face.confidence,
447
  )
448
  manipulation_mode = "displacement_model"
449
+ except Exception as exc:
450
+ logger.warning("Displacement model failed, falling back to preset: %s", exc)
451
  manipulated = apply_procedure_preset(
452
  face,
453
  procedure,
454
  intensity,
455
+ image_size=res,
456
  clinical_flags=flags,
457
  )
458
  else:
 
460
  face,
461
  procedure,
462
  intensity,
463
+ image_size=res,
464
  clinical_flags=flags,
465
  )
466
+ landmark_img = render_landmark_image(manipulated, res, res)
467
  mask = generate_surgical_mask(
468
  face,
469
  procedure,
470
+ res,
471
+ res,
472
  clinical_flags=flags,
473
  )
474
 
 
478
 
479
  prompt = PROCEDURE_PROMPTS.get(procedure, "a photo of a person's face")
480
 
481
+ # Step 1: TPS geometric warp (always computed -- the geometric baseline)
482
  tps_warped = warp_image_tps(image_512, face.pixel_coords, manipulated.pixel_coords)
483
 
484
  if self.mode == "tps":
485
  raw_output = tps_warped
486
+ elif self.mode in ("controlnet", "controlnet_ip", "controlnet_fast"):
487
+ # LCM mode: override to 4 steps, low guidance (LCM works best with cfg=1-2)
488
+ if self._lcm_loaded:
489
+ num_inference_steps = min(num_inference_steps, 4)
490
+ guidance_scale = min(guidance_scale, 1.5)
491
  ip_image = numpy_to_pil(image_512) if self._ip_adapter_loaded else None
492
+ try:
493
+ raw_output = self._generate_controlnet(
494
+ image_512,
495
+ landmark_img,
496
+ prompt,
497
+ num_inference_steps,
498
+ guidance_scale,
499
+ controlnet_conditioning_scale,
500
+ generator,
501
+ ip_adapter_image=ip_image,
502
+ )
503
+ except torch.cuda.OutOfMemoryError as exc:
504
+ torch.cuda.empty_cache()
505
+ raise RuntimeError(
506
+ "GPU out of memory during inference. Try reducing "
507
+ "num_inference_steps or switching to mode='tps' for CPU-only."
508
+ ) from exc
509
  else:
510
+ try:
511
+ raw_output = self._generate_img2img(
512
+ tps_warped,
513
+ mask,
514
+ prompt,
515
+ num_inference_steps,
516
+ guidance_scale,
517
+ strength,
518
+ generator,
519
+ )
520
+ except torch.cuda.OutOfMemoryError as exc:
521
+ torch.cuda.empty_cache()
522
+ raise RuntimeError(
523
+ "GPU out of memory during inference. Try reducing "
524
+ "num_inference_steps or switching to mode='tps' for CPU-only."
525
+ ) from exc
526
 
527
  # Step 2: Post-processing for photorealism (neural + classical pipeline)
528
  identity_check = None
 
561
  "mode": self.mode,
562
  "view_info": view_info,
563
  "ip_adapter_active": self._ip_adapter_loaded,
564
+ "lcm_active": self._lcm_loaded,
565
  "identity_check": identity_check,
566
  "restore_used": restore_used,
567
  "manipulation_mode": manipulation_mode,
 
623
  Returns dict with yaw, pitch (degrees), and view classification.
624
  """
625
  coords = face.pixel_coords
626
+ # MediaPipe landmark indices for key anatomical points
627
+ nose_tip = coords[1] # nose tip
628
+ left_ear = coords[234] # left tragion (ear)
629
+ right_ear = coords[454] # right tragion (ear)
630
+ forehead = coords[10] # forehead center
631
+ chin = coords[152] # chin center
632
 
633
  # Yaw: ratio of nose-to-ear distances (symmetric = 0 degrees)
634
  left_dist = np.linalg.norm(nose_tip - left_ear)
 
647
  pitch = 0.0
648
  else:
649
  pitch_ratio = (lower - upper) / (upper + lower)
650
+ pitch = float(pitch_ratio * _PITCH_SCALE)
651
 
652
  # Classify view
653
  abs_yaw = abs(yaw)
654
+ if abs_yaw < _YAW_FRONTAL_MAX:
655
  view = "frontal"
656
+ elif abs_yaw < _YAW_THREE_QUARTER_MAX:
657
  view = "three_quarter"
658
  else:
659
  view = "profile"
 
662
  "yaw": round(yaw, 1),
663
  "pitch": round(pitch, 1),
664
  "view": view,
665
+ "is_frontal": abs_yaw < _YAW_FRONTAL_MAX,
666
+ "warning": "Side-view detected: results may be less accurate"
667
+ if abs_yaw > _YAW_WARNING_THRESHOLD
668
+ else None,
669
  }
670
 
671
 
 
685
 
686
  image = cv2.imread(image_path)
687
  if image is None:
688
+ logger.error("Could not load %s", image_path)
689
  sys.exit(1)
690
 
691
  pipe = LandmarkDiffPipeline(
 
696
  )
697
  pipe.load()
698
 
699
+ logger.info("Generating %s prediction (intensity=%s, mode=%s)", procedure, intensity, mode)
700
  result = pipe.generate(image, procedure=procedure, intensity=intensity, seed=seed)
701
 
702
  cv2.imwrite(str(out / "input.png"), result["input"])
 
711
 
712
  view = result.get("view_info", {})
713
  if view.get("warning"):
714
+ logger.warning("%s", view["warning"])
715
+ logger.info("Face view: %s (yaw=%s)", view.get("view", "unknown"), view.get("yaw", 0))
716
+ logger.info("Results saved to %s/", out)
717
 
718
 
719
  if __name__ == "__main__":
 
728
  parser.add_argument(
729
  "--mode",
730
  default="img2img",
731
+ choices=["img2img", "controlnet", "controlnet_ip", "controlnet_fast", "tps"],
732
  )
733
  parser.add_argument("--ip-adapter-scale", type=float, default=0.6)
734
  parser.add_argument(