dreamlessx commited on
Commit
e316420
·
verified ·
1 Parent(s): e013072

Upload landmarkdiff/cli.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/cli.py +228 -0
landmarkdiff/cli.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unified CLI for LandmarkDiff.
2
+
3
+ Usage:
4
+ landmarkdiff infer IMAGE --procedure rhinoplasty --intensity 65
5
+ landmarkdiff evaluate --test-dir data/test --checkpoint checkpoints/latest
6
+ landmarkdiff train --config configs/phaseA.yaml
7
+ landmarkdiff demo IMAGE --output demo_report.png
8
+ landmarkdiff config --show
9
+ landmarkdiff validate IMAGE --output validated.png
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import sys
16
+
17
+
18
+ def cmd_infer(args: argparse.Namespace) -> None:
19
+ """Run single-image inference."""
20
+ import cv2
21
+ import numpy as np
22
+ from pathlib import Path
23
+
24
+ from landmarkdiff.inference import LandmarkDiffPipeline
25
+
26
+ image = cv2.imread(args.image)
27
+ if image is None:
28
+ print(f"ERROR: Cannot read image: {args.image}")
29
+ sys.exit(1)
30
+
31
+ image = cv2.resize(image, (512, 512))
32
+
33
+ pipeline = LandmarkDiffPipeline(
34
+ mode=args.mode,
35
+ controlnet_checkpoint=args.checkpoint,
36
+ displacement_model_path=args.displacement_model,
37
+ )
38
+ pipeline.load()
39
+
40
+ result = pipeline.generate(
41
+ image,
42
+ procedure=args.procedure,
43
+ intensity=args.intensity,
44
+ seed=args.seed,
45
+ )
46
+
47
+ out_path = Path(args.output)
48
+ out_path.parent.mkdir(parents=True, exist_ok=True)
49
+ cv2.imwrite(str(out_path), result["output"])
50
+ print(f"Output saved: {out_path}")
51
+
52
+ if args.watermark:
53
+ from landmarkdiff.safety import SafetyValidator
54
+ validator = SafetyValidator()
55
+ watermarked = validator.apply_watermark(result["output"])
56
+ wm_path = out_path.with_stem(out_path.stem + "_watermarked")
57
+ cv2.imwrite(str(wm_path), watermarked)
58
+ print(f"Watermarked: {wm_path}")
59
+
60
+
61
+ def cmd_ensemble(args: argparse.Namespace) -> None:
62
+ """Run ensemble inference."""
63
+ from landmarkdiff.ensemble import ensemble_inference
64
+
65
+ ensemble_inference(
66
+ image_path=args.image,
67
+ procedure=args.procedure,
68
+ intensity=args.intensity,
69
+ output_dir=args.output,
70
+ n_samples=args.n_samples,
71
+ strategy=args.strategy,
72
+ mode=args.mode,
73
+ controlnet_checkpoint=args.checkpoint,
74
+ displacement_model_path=args.displacement_model,
75
+ seed=args.seed,
76
+ )
77
+
78
+
79
+ def cmd_evaluate(args: argparse.Namespace) -> None:
80
+ """Run evaluation on test set."""
81
+ from pathlib import Path
82
+
83
+ # Import evaluation functions
84
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
85
+ from scripts.run_evaluation import run_evaluation
86
+
87
+ run_evaluation(
88
+ test_dir=args.test_dir,
89
+ output_dir=args.output,
90
+ mode=args.mode,
91
+ checkpoint=args.checkpoint,
92
+ displacement_model=args.displacement_model,
93
+ max_samples=args.max_samples,
94
+ )
95
+
96
+
97
+ def cmd_config(args: argparse.Namespace) -> None:
98
+ """Show or validate configuration."""
99
+ from landmarkdiff.config import ExperimentConfig, load_config, validate_config
100
+
101
+ if args.file:
102
+ config = load_config(args.file)
103
+ else:
104
+ config = ExperimentConfig()
105
+
106
+ if args.validate:
107
+ warnings = validate_config(config)
108
+ if warnings:
109
+ print("Validation warnings:")
110
+ for w in warnings:
111
+ print(f" - {w}")
112
+ else:
113
+ print("Configuration valid (no warnings).")
114
+ else:
115
+ import yaml
116
+ from dataclasses import asdict
117
+ print(yaml.dump(asdict(config), default_flow_style=False, sort_keys=False))
118
+
119
+
120
+ def cmd_validate(args: argparse.Namespace) -> None:
121
+ """Run safety validation on an output image."""
122
+ import cv2
123
+ from landmarkdiff.safety import SafetyValidator
124
+
125
+ input_img = cv2.imread(args.input)
126
+ output_img = cv2.imread(args.output_image)
127
+
128
+ if input_img is None or output_img is None:
129
+ print("ERROR: Cannot read input or output image.")
130
+ sys.exit(1)
131
+
132
+ validator = SafetyValidator(
133
+ watermark_enabled=args.watermark,
134
+ )
135
+
136
+ result = validator.validate(
137
+ input_image=input_img,
138
+ output_image=output_img,
139
+ face_confidence=args.face_confidence,
140
+ )
141
+
142
+ print(result.summary())
143
+
144
+ if not result.passed:
145
+ sys.exit(1)
146
+
147
+
148
+ def cmd_version(args: argparse.Namespace) -> None:
149
+ """Print version info."""
150
+ from landmarkdiff import __version__
151
+ print(f"LandmarkDiff v{__version__}")
152
+
153
+
154
+ def main(argv: list[str] | None = None) -> None:
155
+ """Main CLI entry point."""
156
+ parser = argparse.ArgumentParser(
157
+ prog="landmarkdiff",
158
+ description="LandmarkDiff: Facial surgery outcome prediction via latent diffusion",
159
+ )
160
+ subparsers = parser.add_subparsers(dest="command", help="Available commands")
161
+
162
+ # --- infer ---
163
+ p_infer = subparsers.add_parser("infer", help="Run single-image inference")
164
+ p_infer.add_argument("image", help="Input face image path")
165
+ p_infer.add_argument("--procedure", default="rhinoplasty",
166
+ choices=["rhinoplasty", "blepharoplasty", "rhytidectomy", "orthognathic"])
167
+ p_infer.add_argument("--intensity", type=float, default=65.0)
168
+ p_infer.add_argument("--output", default="output.png")
169
+ p_infer.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
170
+ p_infer.add_argument("--checkpoint", default=None)
171
+ p_infer.add_argument("--displacement-model", default=None)
172
+ p_infer.add_argument("--seed", type=int, default=42)
173
+ p_infer.add_argument("--watermark", action="store_true")
174
+ p_infer.set_defaults(func=cmd_infer)
175
+
176
+ # --- ensemble ---
177
+ p_ensemble = subparsers.add_parser("ensemble", help="Run ensemble inference")
178
+ p_ensemble.add_argument("image", help="Input face image path")
179
+ p_ensemble.add_argument("--procedure", default="rhinoplasty")
180
+ p_ensemble.add_argument("--intensity", type=float, default=65.0)
181
+ p_ensemble.add_argument("--output", default="ensemble_output")
182
+ p_ensemble.add_argument("--n-samples", type=int, default=5)
183
+ p_ensemble.add_argument("--strategy", default="best_of_n",
184
+ choices=["pixel_average", "weighted_average", "best_of_n", "median"])
185
+ p_ensemble.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
186
+ p_ensemble.add_argument("--checkpoint", default=None)
187
+ p_ensemble.add_argument("--displacement-model", default=None)
188
+ p_ensemble.add_argument("--seed", type=int, default=42)
189
+ p_ensemble.set_defaults(func=cmd_ensemble)
190
+
191
+ # --- evaluate ---
192
+ p_eval = subparsers.add_parser("evaluate", help="Evaluate on test set")
193
+ p_eval.add_argument("--test-dir", required=True)
194
+ p_eval.add_argument("--output", default="eval_results")
195
+ p_eval.add_argument("--mode", default="tps")
196
+ p_eval.add_argument("--checkpoint", default=None)
197
+ p_eval.add_argument("--displacement-model", default=None)
198
+ p_eval.add_argument("--max-samples", type=int, default=0)
199
+ p_eval.set_defaults(func=cmd_evaluate)
200
+
201
+ # --- config ---
202
+ p_config = subparsers.add_parser("config", help="Show or validate configuration")
203
+ p_config.add_argument("--file", default=None, help="YAML config file")
204
+ p_config.add_argument("--validate", action="store_true")
205
+ p_config.set_defaults(func=cmd_config)
206
+
207
+ # --- validate ---
208
+ p_validate = subparsers.add_parser("validate", help="Run safety validation")
209
+ p_validate.add_argument("input", help="Original input image")
210
+ p_validate.add_argument("output_image", help="Generated output image")
211
+ p_validate.add_argument("--watermark", action="store_true")
212
+ p_validate.add_argument("--face-confidence", type=float, default=1.0)
213
+ p_validate.set_defaults(func=cmd_validate)
214
+
215
+ # --- version ---
216
+ p_version = subparsers.add_parser("version", help="Print version")
217
+ p_version.set_defaults(func=cmd_version)
218
+
219
+ args = parser.parse_args(argv)
220
+ if not hasattr(args, "func"):
221
+ parser.print_help()
222
+ sys.exit(1)
223
+
224
+ args.func(args)
225
+
226
+
227
+ if __name__ == "__main__":
228
+ main()