Spaces:
Running
Running
Update landmarkdiff/ensemble.py to v0.3.2
Browse files- landmarkdiff/ensemble.py +24 -37
landmarkdiff/ensemble.py
CHANGED
|
@@ -153,16 +153,14 @@ class EnsembleInference:
|
|
| 153 |
# Copy metadata from best result
|
| 154 |
best_idx = selected_idx if selected_idx >= 0 else 0
|
| 155 |
ensemble_result = dict(results[best_idx])
|
| 156 |
-
ensemble_result.update(
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
}
|
| 165 |
-
)
|
| 166 |
|
| 167 |
return ensemble_result
|
| 168 |
|
|
@@ -190,13 +188,14 @@ class EnsembleInference:
|
|
| 190 |
ssim = compute_ssim(output, reference)
|
| 191 |
scores.append(float(ssim))
|
| 192 |
|
| 193 |
-
# Normalize to weights (higher SSIM = higher weight)
|
| 194 |
-
|
| 195 |
-
|
|
|
|
| 196 |
|
| 197 |
# Weighted average
|
| 198 |
result = np.zeros_like(outputs[0], dtype=np.float32)
|
| 199 |
-
for output, weight in zip(outputs, weights
|
| 200 |
result += output.astype(np.float32) * weight
|
| 201 |
|
| 202 |
return np.clip(result, 0, 255).astype(np.uint8), scores
|
|
@@ -269,10 +268,8 @@ def ensemble_inference(
|
|
| 269 |
for i, output in enumerate(result["outputs"]):
|
| 270 |
cv2.imwrite(str(out / f"sample_{i:02d}.png"), output)
|
| 271 |
score = result["scores"][i]
|
| 272 |
-
print(
|
| 273 |
-
|
| 274 |
-
+ (" <-- selected" if i == result.get("selected_idx") else "")
|
| 275 |
-
)
|
| 276 |
|
| 277 |
# Comparison grid
|
| 278 |
panels = [image] + result["outputs"] + [result["output"]]
|
|
@@ -283,10 +280,8 @@ def ensemble_inference(
|
|
| 283 |
|
| 284 |
print(f"\nEnsemble output saved: {out / 'ensemble_output.png'}")
|
| 285 |
if result.get("selected_idx", -1) >= 0:
|
| 286 |
-
print(
|
| 287 |
-
|
| 288 |
-
f"(score={result['scores'][result['selected_idx']]:.4f})"
|
| 289 |
-
)
|
| 290 |
|
| 291 |
|
| 292 |
if __name__ == "__main__":
|
|
@@ -298,26 +293,18 @@ if __name__ == "__main__":
|
|
| 298 |
parser.add_argument("--intensity", type=float, default=65.0)
|
| 299 |
parser.add_argument("--output", default="ensemble_output")
|
| 300 |
parser.add_argument("--n_samples", type=int, default=5)
|
| 301 |
-
parser.add_argument(
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
)
|
| 306 |
-
parser.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
|
| 307 |
parser.add_argument("--checkpoint", default=None)
|
| 308 |
parser.add_argument("--displacement-model", default=None)
|
| 309 |
parser.add_argument("--seed", type=int, default=42)
|
| 310 |
args = parser.parse_args()
|
| 311 |
|
| 312 |
ensemble_inference(
|
| 313 |
-
args.image,
|
| 314 |
-
args.
|
| 315 |
-
args.
|
| 316 |
-
args.output,
|
| 317 |
-
args.n_samples,
|
| 318 |
-
args.strategy,
|
| 319 |
-
args.mode,
|
| 320 |
-
args.checkpoint,
|
| 321 |
-
args.displacement_model,
|
| 322 |
args.seed,
|
| 323 |
)
|
|
|
|
| 153 |
# Copy metadata from best result
|
| 154 |
best_idx = selected_idx if selected_idx >= 0 else 0
|
| 155 |
ensemble_result = dict(results[best_idx])
|
| 156 |
+
ensemble_result.update({
|
| 157 |
+
"output": final,
|
| 158 |
+
"outputs": outputs,
|
| 159 |
+
"scores": scores,
|
| 160 |
+
"selected_idx": selected_idx,
|
| 161 |
+
"strategy": self.strategy,
|
| 162 |
+
"n_samples": self.n_samples,
|
| 163 |
+
})
|
|
|
|
|
|
|
| 164 |
|
| 165 |
return ensemble_result
|
| 166 |
|
|
|
|
| 188 |
ssim = compute_ssim(output, reference)
|
| 189 |
scores.append(float(ssim))
|
| 190 |
|
| 191 |
+
# Normalize to weights (higher SSIM = higher weight, clamp negatives)
|
| 192 |
+
clamped = [max(0.0, s) for s in scores]
|
| 193 |
+
total = sum(clamped) or 1.0
|
| 194 |
+
weights = [s / total for s in clamped]
|
| 195 |
|
| 196 |
# Weighted average
|
| 197 |
result = np.zeros_like(outputs[0], dtype=np.float32)
|
| 198 |
+
for output, weight in zip(outputs, weights):
|
| 199 |
result += output.astype(np.float32) * weight
|
| 200 |
|
| 201 |
return np.clip(result, 0, 255).astype(np.uint8), scores
|
|
|
|
| 268 |
for i, output in enumerate(result["outputs"]):
|
| 269 |
cv2.imwrite(str(out / f"sample_{i:02d}.png"), output)
|
| 270 |
score = result["scores"][i]
|
| 271 |
+
print(f" Sample {i}: score={score:.4f}"
|
| 272 |
+
+ (" <-- selected" if i == result.get("selected_idx") else ""))
|
|
|
|
|
|
|
| 273 |
|
| 274 |
# Comparison grid
|
| 275 |
panels = [image] + result["outputs"] + [result["output"]]
|
|
|
|
| 280 |
|
| 281 |
print(f"\nEnsemble output saved: {out / 'ensemble_output.png'}")
|
| 282 |
if result.get("selected_idx", -1) >= 0:
|
| 283 |
+
print(f"Selected sample: {result['selected_idx']} "
|
| 284 |
+
f"(score={result['scores'][result['selected_idx']]:.4f})")
|
|
|
|
|
|
|
| 285 |
|
| 286 |
|
| 287 |
if __name__ == "__main__":
|
|
|
|
| 293 |
parser.add_argument("--intensity", type=float, default=65.0)
|
| 294 |
parser.add_argument("--output", default="ensemble_output")
|
| 295 |
parser.add_argument("--n_samples", type=int, default=5)
|
| 296 |
+
parser.add_argument("--strategy", default="best_of_n",
|
| 297 |
+
choices=["pixel_average", "weighted_average", "best_of_n", "median"])
|
| 298 |
+
parser.add_argument("--mode", default="tps",
|
| 299 |
+
choices=["controlnet", "img2img", "tps"])
|
|
|
|
|
|
|
| 300 |
parser.add_argument("--checkpoint", default=None)
|
| 301 |
parser.add_argument("--displacement-model", default=None)
|
| 302 |
parser.add_argument("--seed", type=int, default=42)
|
| 303 |
args = parser.parse_args()
|
| 304 |
|
| 305 |
ensemble_inference(
|
| 306 |
+
args.image, args.procedure, args.intensity,
|
| 307 |
+
args.output, args.n_samples, args.strategy,
|
| 308 |
+
args.mode, args.checkpoint, args.displacement_model,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
args.seed,
|
| 310 |
)
|