dreamlessx commited on
Commit
afc1ddc
·
verified ·
1 Parent(s): f790aa8

Update landmarkdiff/ensemble.py to v0.3.2

Browse files
Files changed (1) hide show
  1. landmarkdiff/ensemble.py +24 -37
landmarkdiff/ensemble.py CHANGED
@@ -153,16 +153,14 @@ class EnsembleInference:
153
  # Copy metadata from best result
154
  best_idx = selected_idx if selected_idx >= 0 else 0
155
  ensemble_result = dict(results[best_idx])
156
- ensemble_result.update(
157
- {
158
- "output": final,
159
- "outputs": outputs,
160
- "scores": scores,
161
- "selected_idx": selected_idx,
162
- "strategy": self.strategy,
163
- "n_samples": self.n_samples,
164
- }
165
- )
166
 
167
  return ensemble_result
168
 
@@ -190,13 +188,14 @@ class EnsembleInference:
190
  ssim = compute_ssim(output, reference)
191
  scores.append(float(ssim))
192
 
193
- # Normalize to weights (higher SSIM = higher weight)
194
- total = sum(scores) or 1.0
195
- weights = [s / total for s in scores]
 
196
 
197
  # Weighted average
198
  result = np.zeros_like(outputs[0], dtype=np.float32)
199
- for output, weight in zip(outputs, weights, strict=False):
200
  result += output.astype(np.float32) * weight
201
 
202
  return np.clip(result, 0, 255).astype(np.uint8), scores
@@ -269,10 +268,8 @@ def ensemble_inference(
269
  for i, output in enumerate(result["outputs"]):
270
  cv2.imwrite(str(out / f"sample_{i:02d}.png"), output)
271
  score = result["scores"][i]
272
- print(
273
- f" Sample {i}: score={score:.4f}"
274
- + (" <-- selected" if i == result.get("selected_idx") else "")
275
- )
276
 
277
  # Comparison grid
278
  panels = [image] + result["outputs"] + [result["output"]]
@@ -283,10 +280,8 @@ def ensemble_inference(
283
 
284
  print(f"\nEnsemble output saved: {out / 'ensemble_output.png'}")
285
  if result.get("selected_idx", -1) >= 0:
286
- print(
287
- f"Selected sample: {result['selected_idx']} "
288
- f"(score={result['scores'][result['selected_idx']]:.4f})"
289
- )
290
 
291
 
292
  if __name__ == "__main__":
@@ -298,26 +293,18 @@ if __name__ == "__main__":
298
  parser.add_argument("--intensity", type=float, default=65.0)
299
  parser.add_argument("--output", default="ensemble_output")
300
  parser.add_argument("--n_samples", type=int, default=5)
301
- parser.add_argument(
302
- "--strategy",
303
- default="best_of_n",
304
- choices=["pixel_average", "weighted_average", "best_of_n", "median"],
305
- )
306
- parser.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
307
  parser.add_argument("--checkpoint", default=None)
308
  parser.add_argument("--displacement-model", default=None)
309
  parser.add_argument("--seed", type=int, default=42)
310
  args = parser.parse_args()
311
 
312
  ensemble_inference(
313
- args.image,
314
- args.procedure,
315
- args.intensity,
316
- args.output,
317
- args.n_samples,
318
- args.strategy,
319
- args.mode,
320
- args.checkpoint,
321
- args.displacement_model,
322
  args.seed,
323
  )
 
153
  # Copy metadata from best result
154
  best_idx = selected_idx if selected_idx >= 0 else 0
155
  ensemble_result = dict(results[best_idx])
156
+ ensemble_result.update({
157
+ "output": final,
158
+ "outputs": outputs,
159
+ "scores": scores,
160
+ "selected_idx": selected_idx,
161
+ "strategy": self.strategy,
162
+ "n_samples": self.n_samples,
163
+ })
 
 
164
 
165
  return ensemble_result
166
 
 
188
  ssim = compute_ssim(output, reference)
189
  scores.append(float(ssim))
190
 
191
+ # Normalize to weights (higher SSIM = higher weight, clamp negatives)
192
+ clamped = [max(0.0, s) for s in scores]
193
+ total = sum(clamped) or 1.0
194
+ weights = [s / total for s in clamped]
195
 
196
  # Weighted average
197
  result = np.zeros_like(outputs[0], dtype=np.float32)
198
+ for output, weight in zip(outputs, weights):
199
  result += output.astype(np.float32) * weight
200
 
201
  return np.clip(result, 0, 255).astype(np.uint8), scores
 
268
  for i, output in enumerate(result["outputs"]):
269
  cv2.imwrite(str(out / f"sample_{i:02d}.png"), output)
270
  score = result["scores"][i]
271
+ print(f" Sample {i}: score={score:.4f}"
272
+ + (" <-- selected" if i == result.get("selected_idx") else ""))
 
 
273
 
274
  # Comparison grid
275
  panels = [image] + result["outputs"] + [result["output"]]
 
280
 
281
  print(f"\nEnsemble output saved: {out / 'ensemble_output.png'}")
282
  if result.get("selected_idx", -1) >= 0:
283
+ print(f"Selected sample: {result['selected_idx']} "
284
+ f"(score={result['scores'][result['selected_idx']]:.4f})")
 
 
285
 
286
 
287
  if __name__ == "__main__":
 
293
  parser.add_argument("--intensity", type=float, default=65.0)
294
  parser.add_argument("--output", default="ensemble_output")
295
  parser.add_argument("--n_samples", type=int, default=5)
296
+ parser.add_argument("--strategy", default="best_of_n",
297
+ choices=["pixel_average", "weighted_average", "best_of_n", "median"])
298
+ parser.add_argument("--mode", default="tps",
299
+ choices=["controlnet", "img2img", "tps"])
 
 
300
  parser.add_argument("--checkpoint", default=None)
301
  parser.add_argument("--displacement-model", default=None)
302
  parser.add_argument("--seed", type=int, default=42)
303
  args = parser.parse_args()
304
 
305
  ensemble_inference(
306
+ args.image, args.procedure, args.intensity,
307
+ args.output, args.n_samples, args.strategy,
308
+ args.mode, args.checkpoint, args.displacement_model,
 
 
 
 
 
 
309
  args.seed,
310
  )