dreamlessx committed on
Commit
871693c
·
verified ·
1 Parent(s): cfc00ce

Update landmarkdiff/losses.py to v0.3.2

Browse files
Files changed (1) hide show
  1. landmarkdiff/losses.py +24 -37
landmarkdiff/losses.py CHANGED
@@ -1,7 +1,6 @@
1
  """4-term loss function module for ControlNet fine-tuning.
2
 
3
- L_total = L_diffusion + w_landmark * L_landmark
4
- + w_identity * L_identity + w_perceptual * L_perceptual
5
 
6
  Phase A (synthetic TPS data): L_diffusion ONLY. No perceptual loss against
7
  rubbery TPS warps — it would penalize realism.
@@ -23,8 +22,8 @@ class LossWeights:
23
 
24
  diffusion: float = 1.0
25
  landmark: float = 0.1
26
- identity: float = 0.05
27
- perceptual: float = 0.1
28
 
29
 
30
  class DiffusionLoss:
@@ -93,16 +92,11 @@ class IdentityLoss:
93
  return
94
  try:
95
  from insightface.app import FaceAnalysis
96
-
97
  self._app = FaceAnalysis(
98
  name="buffalo_l",
99
  providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
100
  )
101
- ctx_id = (
102
- device.index
103
- if device.type == "cuda" and device.index is not None
104
- else (0 if device.type == "cuda" else -1)
105
- )
106
  self._app.prepare(ctx_id=ctx_id, det_size=(320, 320))
107
  self._has_arcface = True
108
  except Exception:
@@ -120,7 +114,6 @@ class IdentityLoss:
120
  """
121
  if self._has_arcface:
122
  import numpy as np
123
-
124
  embeddings = []
125
  valid_mask = []
126
  for i in range(image_tensor.shape[0]):
@@ -159,9 +152,7 @@ class IdentityLoss:
159
 
160
  # Resize to 112x112 for ArcFace
161
  pred_112 = F.interpolate(pred_crop, size=(112, 112), mode="bilinear", align_corners=False)
162
- target_112 = F.interpolate(
163
- target_crop, size=(112, 112), mode="bilinear", align_corners=False
164
- )
165
 
166
  # Normalize to [-1, 1]
167
  pred_norm = pred_112 * 2 - 1
@@ -172,7 +163,7 @@ class IdentityLoss:
172
  target_emb, target_valid = self._extract_embedding(target_norm)
173
 
174
  # Only compute loss for samples where both faces were detected
175
- valid = [p and t for p, t in zip(pred_valid, target_valid, strict=False)]
176
  if not any(valid):
177
  return torch.tensor(0.0, device=pred_image.device)
178
 
@@ -225,7 +216,6 @@ class PerceptualLoss:
225
  if self._lpips is None:
226
  try:
227
  import lpips
228
-
229
  self._lpips = lpips.LPIPS(net="alex").to(device)
230
  self._lpips.eval()
231
  for p in self._lpips.parameters():
@@ -235,33 +225,31 @@ class PerceptualLoss:
235
 
236
  def __call__(
237
  self,
238
- pred: torch.Tensor, # (B, 3, H, W) in [0, 1]
239
  target: torch.Tensor,
240
- mask: torch.Tensor, # (B, 1, H, W) surgical mask [0, 1]
241
  ) -> torch.Tensor:
242
  self._ensure_loaded(pred.device)
243
 
244
- # Invert mask: we want loss OUTSIDE surgical region
245
- outside_mask = 1 - mask
246
-
247
- # Erode outside_mask to exclude boundary pixels — avoids artificial
248
- # edge features where masked (0) meets unmasked (non-zero) values
249
- erode_kernel = 5
250
- if outside_mask.shape[-1] >= erode_kernel and outside_mask.shape[-2] >= erode_kernel:
251
- outside_mask = -F.max_pool2d(
252
- -outside_mask,
253
- kernel_size=erode_kernel,
254
- stride=1,
255
- padding=erode_kernel // 2,
256
- )
257
-
258
- # Normalize to [-1, 1] for LPIPS FIRST, then mask
259
  pred_norm = pred * 2 - 1
260
  target_norm = target * 2 - 1
261
 
262
- # Apply mask after normalization (masked regions become 0, not -1)
263
- pred_norm = pred_norm * outside_mask
264
- target_norm = target_norm * outside_mask
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
  if self._lpips == "unavailable":
267
  # Fallback: simple L1 loss
@@ -299,7 +287,6 @@ class CombinedLoss:
299
  # or ONNX-based fallback
300
  if use_differentiable_arcface:
301
  from landmarkdiff.arcface_torch import ArcFaceLoss
302
-
303
  self.identity_loss = ArcFaceLoss(weights_path=arcface_weights_path)
304
  else:
305
  self.identity_loss = IdentityLoss()
 
1
  """4-term loss function module for ControlNet fine-tuning.
2
 
3
+ L_total = L_diffusion + w_landmark * L_landmark + w_identity * L_identity + w_perceptual * L_perceptual
 
4
 
5
  Phase A (synthetic TPS data): L_diffusion ONLY. No perceptual loss against
6
  rubbery TPS warps — it would penalize realism.
 
22
 
23
  diffusion: float = 1.0
24
  landmark: float = 0.1
25
+ identity: float = 0.1
26
+ perceptual: float = 0.05
27
 
28
 
29
  class DiffusionLoss:
 
92
  return
93
  try:
94
  from insightface.app import FaceAnalysis
 
95
  self._app = FaceAnalysis(
96
  name="buffalo_l",
97
  providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
98
  )
99
+ ctx_id = device.index if device.type == "cuda" and device.index is not None else (0 if device.type == "cuda" else -1)
 
 
 
 
100
  self._app.prepare(ctx_id=ctx_id, det_size=(320, 320))
101
  self._has_arcface = True
102
  except Exception:
 
114
  """
115
  if self._has_arcface:
116
  import numpy as np
 
117
  embeddings = []
118
  valid_mask = []
119
  for i in range(image_tensor.shape[0]):
 
152
 
153
  # Resize to 112x112 for ArcFace
154
  pred_112 = F.interpolate(pred_crop, size=(112, 112), mode="bilinear", align_corners=False)
155
+ target_112 = F.interpolate(target_crop, size=(112, 112), mode="bilinear", align_corners=False)
 
 
156
 
157
  # Normalize to [-1, 1]
158
  pred_norm = pred_112 * 2 - 1
 
163
  target_emb, target_valid = self._extract_embedding(target_norm)
164
 
165
  # Only compute loss for samples where both faces were detected
166
+ valid = [p and t for p, t in zip(pred_valid, target_valid)]
167
  if not any(valid):
168
  return torch.tensor(0.0, device=pred_image.device)
169
 
 
216
  if self._lpips is None:
217
  try:
218
  import lpips
 
219
  self._lpips = lpips.LPIPS(net="alex").to(device)
220
  self._lpips.eval()
221
  for p in self._lpips.parameters():
 
225
 
226
  def __call__(
227
  self,
228
+ pred: torch.Tensor, # (B, 3, H, W) in [0, 1]
229
  target: torch.Tensor,
230
+ mask: torch.Tensor, # (B, 1, H, W) surgical mask [0, 1]
231
  ) -> torch.Tensor:
232
  self._ensure_loaded(pred.device)
233
 
234
+ # Normalize to [-1, 1] for LPIPS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  pred_norm = pred * 2 - 1
236
  target_norm = target * 2 - 1
237
 
238
+ # When mask is all-ones (no mask file available), compute on full image.
239
+ # Otherwise invert mask to get loss OUTSIDE the surgical region only.
240
+ has_mask = mask.sum() < mask.numel() * 0.99
241
+ if has_mask:
242
+ outside_mask = 1 - mask
243
+ erode_kernel = 5
244
+ if outside_mask.shape[-1] >= erode_kernel and outside_mask.shape[-2] >= erode_kernel:
245
+ outside_mask = -F.max_pool2d(
246
+ -outside_mask,
247
+ kernel_size=erode_kernel,
248
+ stride=1,
249
+ padding=erode_kernel // 2,
250
+ )
251
+ pred_norm = pred_norm * outside_mask
252
+ target_norm = target_norm * outside_mask
253
 
254
  if self._lpips == "unavailable":
255
  # Fallback: simple L1 loss
 
287
  # or ONNX-based fallback
288
  if use_differentiable_arcface:
289
  from landmarkdiff.arcface_torch import ArcFaceLoss
 
290
  self.identity_loss = ArcFaceLoss(weights_path=arcface_weights_path)
291
  else:
292
  self.identity_loss = IdentityLoss()