dreamlessx committed on
Commit
83b71db
·
verified ·
1 Parent(s): 57d9540

Upload landmarkdiff/postprocess.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/postprocess.py +452 -0
landmarkdiff/postprocess.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Post-processing: CodeFormer/GFPGAN face restore, Real-ESRGAN bg,
2
+ Laplacian blend, sharpening, histogram matching, ArcFace identity gate.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import cv2
8
+ import numpy as np
9
+
10
+
11
def laplacian_pyramid_blend(
    source: np.ndarray,
    target: np.ndarray,
    mask: np.ndarray,
    levels: int = 6,
) -> np.ndarray:
    """Blend ``source`` into ``target`` with a Laplacian pyramid (multi-band) blend.

    Every frequency band of the two images is mixed with a progressively
    blurrier copy of the mask, which kills the 'pasted on' look of a single
    alpha blend.

    Args:
        source: Image used where the mask is high. Resized to the target's
            height/width when they differ.
        target: Image used where the mask is low; defines the output size.
        mask: Blend weights, 2-D or with a channel axis. A mask whose maximum
            exceeds 1.0 is treated as 0-255 and scaled to 0-1. Resized to the
            target's height/width when they differ.
        levels: Number of pyramid levels; must be >= 0.

    Returns:
        uint8 array with the target's height and width.

    Raises:
        ValueError: If ``levels`` is negative.
    """
    if levels < 0:
        raise ValueError(f"levels must be >= 0, got {levels}")

    h, w = target.shape[:2]
    if source.shape[:2] != (h, w):
        source = cv2.resize(source, (w, h))

    # Normalize the mask to float weights in [0, 1].
    mask_f = mask.astype(np.float32)
    if mask_f.max() > 1.0:
        mask_f = mask_f / 255.0
    # cv2 drops a trailing singleton channel on resize, so collapse it up front
    # and rebuild the channel axis explicitly below.
    if mask_f.ndim == 3 and mask_f.shape[2] == 1:
        mask_f = mask_f[:, :, 0]
    if mask_f.shape[:2] != (h, w):
        mask_f = cv2.resize(mask_f, (w, h))
    if mask_f.ndim == 2 and target.ndim == 3:
        mask_f = np.repeat(mask_f[:, :, None], target.shape[2], axis=2)

    # Pad (never stretch) so both dimensions are divisible by 2^levels.
    # Stretching and then cropping back would shift and scale the content
    # relative to the target; padding keeps pixel (y, x) at (y, x).
    factor = 2 ** levels
    pad_h = -h % factor
    pad_w = -w % factor
    if pad_h or pad_w:
        source = cv2.copyMakeBorder(
            source, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101
        )
        target = cv2.copyMakeBorder(
            target, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101
        )
        mask_f = cv2.copyMakeBorder(
            mask_f, 0, pad_h, 0, pad_w, cv2.BORDER_REPLICATE
        )

    # Gaussian pyramid of the mask: one weight map per band.
    mask_pyr = [mask_f]
    for _ in range(levels):
        mask_pyr.append(cv2.pyrDown(mask_pyr[-1]))

    # Laplacian pyramids of both images.
    src_lap = _build_laplacian_pyramid(source.astype(np.float32), levels)
    tgt_lap = _build_laplacian_pyramid(target.astype(np.float32), levels)

    # Blend each band with the mask at the matching resolution.
    blended_lap = []
    for src_band, tgt_band, weight in zip(src_lap, tgt_lap, mask_pyr):
        if weight.shape[:2] != src_band.shape[:2]:
            weight = cv2.resize(weight, (src_band.shape[1], src_band.shape[0]))
        blended_lap.append(src_band * weight + tgt_band * (1.0 - weight))

    result = _reconstruct_from_laplacian(blended_lap)

    # Drop the padding, then round (a bare uint8 cast would truncate).
    result = result[:h, :w]
    return np.clip(np.rint(result), 0, 255).astype(np.uint8)
71
+
72
+
73
+ def _build_laplacian_pyramid(
74
+ image: np.ndarray,
75
+ levels: int,
76
+ ) -> list[np.ndarray]:
77
+ """Build Laplacian pyramid from an image."""
78
+ gaussian = [image.copy()]
79
+ for _ in range(levels):
80
+ gaussian.append(cv2.pyrDown(gaussian[-1]))
81
+
82
+ laplacian = []
83
+ for i in range(levels):
84
+ upsampled = cv2.pyrUp(gaussian[i + 1])
85
+ # Match sizes (pyrUp can add a pixel)
86
+ gh, gw = gaussian[i].shape[:2]
87
+ upsampled = upsampled[:gh, :gw]
88
+ laplacian.append(gaussian[i] - upsampled)
89
+
90
+ laplacian.append(gaussian[-1]) # coarsest level
91
+ return laplacian
92
+
93
+
94
+ def _reconstruct_from_laplacian(pyramid: list[np.ndarray]) -> np.ndarray:
95
+ """Reconstruct image from Laplacian pyramid."""
96
+ image = pyramid[-1].copy()
97
+ for i in range(len(pyramid) - 2, -1, -1):
98
+ image = cv2.pyrUp(image)
99
+ lh, lw = pyramid[i].shape[:2]
100
+ image = image[:lh, :lw]
101
+ image = image + pyramid[i]
102
+ return image
103
+
104
+
105
def frequency_aware_sharpen(
    image: np.ndarray,
    strength: float = 0.3,
    radius: int = 3,
) -> np.ndarray:
    """Unsharp-mask the LAB lightness channel only.

    Sharpening lightness alone adds texture contrast without introducing
    colour fringes, because the a/b channels are left untouched.

    Args:
        image: BGR image (converted with ``cv2.COLOR_BGR2LAB``). Assumed to be
            8-bit, since the result is clipped to 0-255 - TODO confirm callers.
        strength: Multiplier applied to the (lightness - blurred) detail layer.
        radius: Gaussian blur radius; the kernel size is ``2 * radius + 1``.
            Negative values are treated as 0 (no sharpening).

    Returns:
        Sharpened uint8 BGR image.
    """
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)
    lightness = lab[:, :, 0]

    # GaussianBlur needs a positive odd kernel size.
    ksize = max(int(radius), 0) * 2 + 1
    blurred = cv2.GaussianBlur(lightness, (ksize, ksize), 0)
    lab[:, :, 0] = lightness + strength * (lightness - blurred)

    # Round before the cast: a bare astype(uint8) truncates, which would bias
    # the sharpened lightness downwards.
    lab_u8 = np.clip(np.rint(lab), 0, 255).astype(np.uint8)
    return cv2.cvtColor(lab_u8, cv2.COLOR_LAB2BGR)
121
+
122
+
123
def restore_face_gfpgan(
    image: np.ndarray,
    upscale: int = 1,
) -> np.ndarray:
    """Restore the centre face with GFPGAN (best effort).

    Args:
        image: Image handed to ``GFPGANer.enhance``.
        upscale: Upscale factor passed to ``GFPGANer``.

    Returns:
        The restored image, or the *same* ``image`` object when GFPGAN is not
        installed, fails, or produces no output. Callers rely on that identity
        to detect that no restoration happened.
    """
    try:
        from gfpgan import GFPGANer
    except ImportError:
        return image

    import logging

    try:
        restorer = GFPGANer(
            model_path="https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
            upscale=upscale,
            arch="clean",
            channel_multiplier=2,
            bg_upsampler=None,
        )
        _, _, restored = restorer.enhance(
            image,
            has_aligned=False,
            only_center_face=True,
            paste_back=True,
        )
    except Exception:
        # Restoration is optional, so never crash the pipeline - but do not
        # hide the failure either (bad weights, OOM, download errors, ...).
        logging.getLogger(__name__).warning(
            "GFPGAN face restoration failed; returning input unchanged",
            exc_info=True,
        )
        return image

    return restored if restored is not None else image
153
+
154
+
155
def restore_face_codeformer(
    image: np.ndarray,
    fidelity: float = 0.7,
    upscale: int = 1,
) -> np.ndarray:
    """Restore the centre face with CodeFormer (best effort).

    Args:
        image: Image handed to ``FaceRestoreHelper.read_image``.
        fidelity: Passed to the model as ``w``; 0 favours quality, 1 favours
            identity.
        upscale: Upscale factor passed to ``FaceRestoreHelper``.

    Returns:
        The restored image, or the *same* ``image`` object when CodeFormer is
        not installed, fails, or produces no output. Callers rely on that
        identity to detect that no restoration happened.
    """
    try:
        from codeformer.basicsr.utils import img2tensor, tensor2img
        from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
        from codeformer.basicsr.utils.download_util import load_file_from_url
        import torch
        from torchvision.transforms.functional import normalize as tv_normalize
    except ImportError:
        return image

    import logging

    try:
        # Unused import kept from the original - presumably an availability
        # probe for the packaged inference module (TODO confirm).
        from codeformer.inference_codeformer import set_realesrgan as _unused  # noqa: F401
        from codeformer.basicsr.archs.codeformer_arch import CodeFormer as CodeFormerArch

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model = CodeFormerArch(
            dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
            connect_list=["32", "64", "128", "256"],
        ).to(device)

        ckpt_path = load_file_from_url(
            url="https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
            model_dir="weights/CodeFormer",
            progress=True,
        )
        # NOTE(security): weights_only=False unpickles arbitrary objects from
        # the downloaded checkpoint. Only safe while the URL above is trusted;
        # consider weights_only=True if the checkpoint loads that way.
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["params_ema"])
        model.eval()

        face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model="retinaface_resnet50",
            save_ext="png",
            device=device,
        )
        face_helper.read_image(image)
        face_helper.get_face_landmarks_5(only_center_face=True)
        face_helper.align_warp_face()

        for cropped_face in face_helper.cropped_faces:
            # Scale to 0-1, then normalize to [-1, 1] with mean/std 0.5.
            face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
            tv_normalize(face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            face_t = face_t.unsqueeze(0).to(device)

            with torch.no_grad():
                output = model(face_t, w=fidelity, adain=True)[0]
                restored = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
            restored = restored.astype(np.uint8)
            face_helper.add_restored_face(restored)

        face_helper.get_inverse_affine(None)
        restored_img = face_helper.paste_faces_to_image()
    except Exception:
        # Restoration is optional, so never crash the pipeline - but do not
        # hide the failure either (bad weights, OOM, download errors, ...).
        logging.getLogger(__name__).warning(
            "CodeFormer face restoration failed; returning input unchanged",
            exc_info=True,
        )
        return image

    return restored_img if restored_img is not None else image
221
+
222
+
223
def enhance_background_realesrgan(
    image: np.ndarray,
    mask: np.ndarray,
    outscale: int = 2,
) -> np.ndarray:
    """Enhance only the background (outside ``mask``) with Real-ESRGAN.

    The whole image is run through Real-ESRGAN, resized back to the input
    size, and mixed with the input so that the masked region keeps the
    original pixels.

    Args:
        image: Image handed to ``RealESRGANer.enhance``.
        mask: Region to keep from ``image``, 2-D or with a channel axis. A
            mask whose maximum exceeds 1.0 is treated as 0-255 and scaled to
            0-1. Resized to the image's height/width when they differ.
        outscale: Output scale passed to ``RealESRGANer.enhance``.

    Returns:
        uint8 image of the input size, or the *same* ``image`` object when
        Real-ESRGAN is not installed or fails.
    """
    try:
        from realesrgan import RealESRGANer
        from basicsr.archs.rrdbnet_arch import RRDBNet
        import torch
    except ImportError:
        return image

    import logging

    try:
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        upsampler = RealESRGANer(
            scale=4,
            model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
            model=model,
            tile=400,
            tile_pad=10,
            pre_pad=0,
            half=torch.cuda.is_available(),
        )
        enhanced, _ = upsampler.enhance(image, outscale=outscale)

        # Bring the enhanced image back to the input resolution.
        h, w = image.shape[:2]
        enhanced = cv2.resize(enhanced, (w, h), interpolation=cv2.INTER_LANCZOS4)

        # Normalize the mask to (h, w, 1) float weights in [0, 1].
        mask_f = mask.astype(np.float32)
        if mask_f.max() > 1.0:
            mask_f /= 255.0
        if mask_f.shape[:2] != (h, w):
            # Without this a size mismatch raised below and silently disabled
            # the enhancement.
            mask_f = cv2.resize(mask_f, (w, h))
        if mask_f.ndim == 2:
            mask_f = mask_f[:, :, None]

        # Masked region from the original, everything else from the enhanced image.
        mixed = (
            image.astype(np.float32) * mask_f
            + enhanced.astype(np.float32) * (1.0 - mask_f)
        )
        # Round before the cast; a bare astype(uint8) truncates.
        return np.clip(np.rint(mixed), 0, 255).astype(np.uint8)
    except Exception:
        # Enhancement is optional, so never crash the pipeline - but do not
        # hide the failure either.
        logging.getLogger(__name__).warning(
            "Real-ESRGAN background enhancement failed; returning input unchanged",
            exc_info=True,
        )

    return image
272
+
273
+
274
def verify_identity_arcface(
    original: np.ndarray,
    result: np.ndarray,
    threshold: float = 0.6,
) -> dict:
    """Compare face embeddings of ``original`` and ``result`` via cosine similarity.

    Returns a dict with keys ``similarity`` (clipped to 0-1, or -1.0 when the
    check could not run), ``passed`` and ``message``. The check is fail-open:
    a missing InsightFace install, no detected face, or any runtime error all
    report ``passed=True``.
    """

    def _report(similarity: float, passed: bool, message: str) -> dict:
        return {"similarity": similarity, "passed": passed, "message": message}

    try:
        from insightface.app import FaceAnalysis
    except ImportError:
        return _report(
            -1.0, True, "InsightFace not installed - identity check skipped"
        )

    try:
        analyzer = FaceAnalysis(
            name="buffalo_l",
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        ctx = 0 if _has_cuda() else -1
        analyzer.prepare(ctx_id=ctx, det_size=(320, 320))

        faces_before = analyzer.get(original)
        faces_after = analyzer.get(result)
        if not (faces_before and faces_after):
            return _report(
                -1.0,
                True,
                "Could not detect face in one/both images - check skipped",
            )

        # Only the first detected face of each image is compared.
        emb_before = faces_before[0].embedding
        emb_after = faces_after[0].embedding

        # Cosine similarity; the epsilon guards against zero-norm embeddings.
        denom = np.linalg.norm(emb_before) * np.linalg.norm(emb_after) + 1e-8
        cosine = float(np.dot(emb_before, emb_after) / denom)
        cosine = float(np.clip(cosine, 0, 1))

        ok = cosine >= threshold
        verdict = (
            f"Identity preserved (similarity={cosine:.3f})"
            if ok
            else f"WARNING: Identity drift detected (similarity={cosine:.3f} < {threshold})"
        )
        return _report(cosine, ok, verdict)
    except Exception as e:
        return _report(-1.0, True, f"Identity check failed: {e}")
327
+
328
+
329
+ def _has_cuda() -> bool:
330
+ try:
331
+ import torch
332
+ return torch.cuda.is_available()
333
+ except ImportError:
334
+ return False
335
+
336
+
337
def histogram_match_skin(
    source: np.ndarray,
    reference: np.ndarray,
    mask: np.ndarray,
) -> np.ndarray:
    """Match ``source`` to ``reference`` inside ``mask`` via CDF matching in LAB.

    For each LAB channel, the sorted source values inside the mask are mapped
    onto the sorted reference values inside the mask, and that mapping is
    applied to the masked source pixels only.

    Args:
        source: BGR image to adjust (converted with ``cv2.COLOR_BGR2LAB``).
        reference: BGR image supplying the target distribution. Resized to the
            source's height/width when they differ.
        mask: Region to match, 2-D or with a channel axis, any numeric dtype.
            A mask whose maximum exceeds 1.0 is treated as 0-255 and scaled to
            0-1; pixels with weight above 0.3 are matched.

    Returns:
        Adjusted uint8 BGR image, or the *same* ``source`` object when the
        mask selects no pixels.
    """
    h, w = source.shape[:2]

    # Decide the mask scale from its values, not its dtype: a float64 0-1 mask
    # or a 0/1 integer mask must not be compared against a 0-255 threshold.
    weights = np.asarray(mask, dtype=np.float32)
    if weights.ndim == 3:
        # Collapse the channel axis so the mask can index a single LAB plane.
        weights = weights.max(axis=2)
    if weights.size and weights.max() > 1.0:
        weights = weights / 255.0
    if weights.shape != (h, w):
        weights = cv2.resize(weights, (w, h))
    mask_bool = weights > 0.3

    if not np.any(mask_bool):
        return source

    if reference.shape[:2] != (h, w):
        reference = cv2.resize(reference, (w, h))

    src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
    ref_lab = cv2.cvtColor(reference, cv2.COLOR_BGR2LAB).astype(np.float32)

    for ch in range(3):
        src_plane = src_lab[:, :, ch]
        src_vals = src_plane[mask_bool]
        src_sorted = np.sort(src_vals)
        ref_sorted = np.sort(ref_lab[:, :, ch][mask_bool])

        # Quantile positions of the sorted samples (empirical CDFs).
        src_cdf = np.linspace(0, 1, len(src_sorted))
        ref_cdf = np.linspace(0, 1, len(ref_sorted))

        # Reference value at each source quantile.
        mapping = np.interp(src_cdf, ref_cdf, ref_sorted)

        # Look up every masked source pixel in the source -> reference mapping.
        # Pixels outside the mask keep their value.
        src_plane[mask_bool] = np.interp(src_vals, src_sorted, mapping)

    # Round before the cast; a bare astype(uint8) truncates.
    result_lab = np.clip(np.rint(src_lab), 0, 255).astype(np.uint8)
    return cv2.cvtColor(result_lab, cv2.COLOR_LAB2BGR)
380
+
381
+
382
def full_postprocess(
    generated: np.ndarray,
    original: np.ndarray,
    mask: np.ndarray,
    restore_mode: str = "codeformer",
    codeformer_fidelity: float = 0.7,
    use_realesrgan: bool = True,
    use_laplacian_blend: bool = True,
    sharpen_strength: float = 0.25,
    verify_identity: bool = True,
    identity_threshold: float = 0.6,
) -> dict:
    """Run the full pipeline on a generated image.

    Order: restore -> bg enhance -> histogram match -> sharpen -> blend ->
    identity check.

    Args:
        generated: Generated image to post-process; it is copied, not modified.
        original: Original image, used as histogram reference, blend target
            and identity reference.
        mask: Blend/skin mask handed to the individual steps.
        restore_mode: ``"codeformer"`` (falls back to GFPGAN when CodeFormer
            does nothing), ``"gfpgan"``, or any other value to skip restoration.
        codeformer_fidelity: Fidelity passed to ``restore_face_codeformer``.
        use_realesrgan: Run ``enhance_background_realesrgan``.
        use_laplacian_blend: Blend with ``laplacian_pyramid_blend``; otherwise
            use a plain alpha blend.
        sharpen_strength: Strength for ``frequency_aware_sharpen``; values
            <= 0 skip sharpening.
        verify_identity: Run ``verify_identity_arcface`` on the composite.
        identity_threshold: Threshold passed to ``verify_identity_arcface``.

    Returns:
        Dict with ``image`` (the composite), ``identity_check`` (dict from the
        identity step, or a "skipped" placeholder) and ``restore_used``
        (``"codeformer"``, ``"gfpgan"`` or ``"none"``).
    """
    result = generated.copy()
    restore_used = "none"

    # Step 1: Neural face restoration (CodeFormer > GFPGAN > skip).
    # The restorers return the very object they were given when they did
    # nothing, so identity against the *input of that call* tells us whether
    # a restorer actually ran.
    if restore_mode == "codeformer":
        restored = restore_face_codeformer(result, fidelity=codeformer_fidelity)
        if restored is not result:
            result = restored
            restore_used = "codeformer"
        else:
            # CodeFormer unavailable, fall back to GFPGAN. Compare against
            # `result`, not `generated`: `result` is a copy, so a comparison
            # with `generated` is always "different" and would report
            # "gfpgan" even when GFPGAN did not run.
            restored = restore_face_gfpgan(result)
            if restored is not result:
                result = restored
                restore_used = "gfpgan"
    elif restore_mode == "gfpgan":
        restored = restore_face_gfpgan(result)
        if restored is not result:
            result = restored
            restore_used = "gfpgan"

    # Step 2: Neural background enhancement
    if use_realesrgan:
        result = enhance_background_realesrgan(result, mask)

    # Step 3: Skin tone histogram matching (classical)
    result = histogram_match_skin(result, original, mask)

    # Step 4: Sharpen texture (classical)
    if sharpen_strength > 0:
        result = frequency_aware_sharpen(result, strength=sharpen_strength)

    # Step 5: Blend into original (classical)
    if use_laplacian_blend:
        composited = laplacian_pyramid_blend(result, original, mask)
    else:
        mask_f = mask.astype(np.float32)
        if mask_f.max() > 1.0:
            mask_f /= 255.0
        if mask_f.ndim == 2:
            mask_3ch = np.stack([mask_f] * 3, axis=-1)
        else:
            mask_3ch = mask_f
        blended = (
            result.astype(np.float32) * mask_3ch
            + original.astype(np.float32) * (1.0 - mask_3ch)
        )
        # Round before the cast; a bare astype(uint8) truncates.
        composited = np.clip(np.rint(blended), 0, 255).astype(np.uint8)

    # Step 6: Neural identity verification
    identity_check = {"similarity": -1.0, "passed": True, "message": "skipped"}
    if verify_identity:
        identity_check = verify_identity_arcface(
            original, composited, threshold=identity_threshold,
        )

    return {
        "image": composited,
        "identity_check": identity_check,
        "restore_used": restore_used,
    }