SynLayers commited on
Commit
5837a1e
·
verified ·
1 Parent(s): 02f6539

Upload demo/real_world_pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo/real_world_pipeline.py +469 -0
demo/real_world_pipeline.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import gc
5
+ import json
6
+ import os
7
+ import re
8
+ import sys
9
+ import time
10
+ import zipfile
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import torch
15
+ from PIL import Image, ImageOps
16
+
17
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
18
+ if str(PROJECT_ROOT) not in sys.path:
19
+ sys.path.insert(0, str(PROJECT_ROOT))
20
+
21
+ from demo.infer.run_caption_bbox_infer import ( # noqa: E402
22
+ CAPTION_BBOX_PROMPT_TOP_LEFT,
23
+ DEFAULT_BBOX_MODEL,
24
+ draw_boxes,
25
+ infer_caption_bbox,
26
+ )
27
+ from demo.hf_repo_assets import build_repo_asset_overrides, get_stage2_model_repo_id # noqa: E402
28
+ from demo.infer.vlm_bbox_inference import get_model_and_processor # noqa: E402
29
+ from infer.common_infer import initialize_pipeline # noqa: E402
30
+ from infer.infer import build_run_save_dir, get_real_boxes, load_adapter_image # noqa: E402
31
+ from tools.tools import load_config, seed_everything # noqa: E402
32
+
33
+
34
+ DEFAULT_REAL_CONFIG_PATH = PROJECT_ROOT / "infer" / "infer.yaml"
35
+ DEFAULT_WORK_DIR = PROJECT_ROOT / "demo" / "outputs" / "real_world_demo"
36
+ DEFAULT_RUN_NAME = "step_120000"
37
+ DEFAULT_TARGET_SIZE = 1024
38
+
39
+ _BBOX_CACHE: dict[str, object] = {"model_path": None, "model": None, "processor": None}
40
+ _REAL_CACHE: dict[str, object] = {"key": None, "pipeline": None, "transp_vae": None}
41
+
42
+
43
+ def slugify(text: str) -> str:
44
+ value = re.sub(r"[^A-Za-z0-9._-]+", "_", text).strip("._-")
45
+ return value or "sample"
46
+
47
+
48
+ def resolve_existing_path(*candidates) -> str | None:
49
+ for candidate in candidates:
50
+ if not candidate:
51
+ continue
52
+ path = Path(candidate)
53
+ if path.exists():
54
+ return str(path)
55
+ return None
56
+
57
+
58
+ DEFAULT_DECOMP_CKPT_ROOT = Path(
59
+ resolve_existing_path(
60
+ os.environ.get("SYNLAYERS_DECOMP_CKPT_ROOT"),
61
+ PROJECT_ROOT / "SynLayers_ckpt" / "step_120000",
62
+ "/project/llmsvgen/share/data/kmw_layered_checkpoint/SynLayers_ckpt/step_120000",
63
+ )
64
+ or "/project/llmsvgen/share/data/kmw_layered_checkpoint/SynLayers_ckpt/step_120000"
65
+ )
66
+
67
+
68
+ def prepare_input_image(input_path: str | Path, output_path: Path, size: int) -> Path:
69
+ image = Image.open(input_path).convert("RGB")
70
+
71
+ if image.size != (size, size):
72
+ resized = ImageOps.contain(image, (size, size), Image.LANCZOS)
73
+ canvas = Image.new("RGB", (size, size), (255, 255, 255))
74
+ offset = ((size - resized.width) // 2, (size - resized.height) // 2)
75
+ canvas.paste(resized, offset)
76
+ image = canvas
77
+
78
+ output_path.parent.mkdir(parents=True, exist_ok=True)
79
+ image.save(output_path)
80
+ return output_path
81
+
82
+
83
+ def load_bbox_bundle(model_path: str):
84
+ cached_model_path = _BBOX_CACHE["model_path"]
85
+ if cached_model_path == model_path and _BBOX_CACHE["model"] is not None:
86
+ return _BBOX_CACHE["model"], _BBOX_CACHE["processor"]
87
+
88
+ model, processor = get_model_and_processor(model_path)
89
+ _BBOX_CACHE.update(
90
+ {
91
+ "model_path": model_path,
92
+ "model": model,
93
+ "processor": processor,
94
+ }
95
+ )
96
+ return model, processor
97
+
98
+
99
+ def release_bbox_bundle():
100
+ model = _BBOX_CACHE.get("model")
101
+ processor = _BBOX_CACHE.get("processor")
102
+ if model is not None:
103
+ del model
104
+ if processor is not None:
105
+ del processor
106
+ _BBOX_CACHE.update({"model_path": None, "model": None, "processor": None})
107
+ gc.collect()
108
+ if torch.cuda.is_available():
109
+ torch.cuda.empty_cache()
110
+
111
+
112
+ def load_real_bundle(config: dict):
113
+ key = (
114
+ config.get("pretrained_model_name_or_path"),
115
+ config.get("pretrained_adapter_path"),
116
+ config.get("transp_vae_path"),
117
+ config.get("pretrained_lora_dir"),
118
+ config.get("artplus_lora_dir"),
119
+ config.get("lora_ckpt"),
120
+ config.get("layer_ckpt"),
121
+ config.get("adapter_lora_dir"),
122
+ config.get("max_layer_num"),
123
+ )
124
+
125
+ if _REAL_CACHE["key"] == key and _REAL_CACHE["pipeline"] is not None:
126
+ return _REAL_CACHE["pipeline"], _REAL_CACHE["transp_vae"]
127
+
128
+ if _REAL_CACHE["pipeline"] is not None:
129
+ del _REAL_CACHE["pipeline"]
130
+ del _REAL_CACHE["transp_vae"]
131
+ _REAL_CACHE["pipeline"] = None
132
+ _REAL_CACHE["transp_vae"] = None
133
+ if torch.cuda.is_available():
134
+ torch.cuda.empty_cache()
135
+
136
+ pipeline, transp_vae = initialize_pipeline(config)
137
+ _REAL_CACHE.update({"key": key, "pipeline": pipeline, "transp_vae": transp_vae})
138
+ return pipeline, transp_vae
139
+
140
+
141
+ def build_runtime_config(
142
+ *,
143
+ config_path: str | Path,
144
+ image_dir: Path,
145
+ bbox_jsonl: Path,
146
+ results_root: Path,
147
+ run_name: str,
148
+ seed: int | None = None,
149
+ ) -> dict:
150
+ config = load_config(str(config_path))
151
+ repo_overrides = build_repo_asset_overrides(get_stage2_model_repo_id())
152
+ decomp_ckpt_root = Path(
153
+ os.environ.get("SYNLAYERS_DECOMP_CKPT_ROOT")
154
+ or repo_overrides.get("decomp_ckpt_root")
155
+ or DEFAULT_DECOMP_CKPT_ROOT
156
+ )
157
+ config["data_dir"] = str(image_dir.parent)
158
+ config["image_dir"] = str(image_dir)
159
+ config["test_jsonl"] = str(bbox_jsonl)
160
+ config["save_dir"] = str(results_root)
161
+ config["run_name"] = run_name
162
+ config["lora_ckpt"] = str(decomp_ckpt_root / "transformer")
163
+ config["layer_ckpt"] = str(decomp_ckpt_root)
164
+ config["adapter_lora_dir"] = str(decomp_ckpt_root / "adapter")
165
+
166
+ env_overrides = {
167
+ "pretrained_model_name_or_path": (
168
+ repo_overrides.get("pretrained_model_name_or_path")
169
+ or resolve_existing_path(PROJECT_ROOT / "SynLayers_checkpoints" / "FLUX.1-dev")
170
+ ),
171
+ "pretrained_adapter_path": (
172
+ os.environ.get("SYNLAYERS_ADAPTER_MODEL")
173
+ or repo_overrides.get("pretrained_adapter_path")
174
+ or resolve_existing_path(
175
+ PROJECT_ROOT / "SynLayers_checkpoints" / "FLUX.1-dev-Controlnet-Inpainting-Alpha"
176
+ )
177
+ ),
178
+ "transp_vae_path": (
179
+ os.environ.get("SYNLAYERS_TRANSP_VAE")
180
+ or repo_overrides.get("transp_vae_path")
181
+ or resolve_existing_path(PROJECT_ROOT / "ckpt" / "trans_vae" / "0008000.pt")
182
+ ),
183
+ "pretrained_lora_dir": (
184
+ os.environ.get("SYNLAYERS_PRETRAINED_LORA")
185
+ or repo_overrides.get("pretrained_lora_dir")
186
+ or resolve_existing_path(PROJECT_ROOT / "ckpt" / "pre_trained_LoRA")
187
+ ),
188
+ "artplus_lora_dir": (
189
+ os.environ.get("SYNLAYERS_ARTPLUS_LORA")
190
+ or repo_overrides.get("artplus_lora_dir")
191
+ or resolve_existing_path(PROJECT_ROOT / "ckpt" / "prism_ft_LoRA")
192
+ ),
193
+ }
194
+ for key, value in env_overrides.items():
195
+ if value:
196
+ config[key] = value
197
+
198
+ if seed is not None:
199
+ config["seed"] = seed
200
+
201
+ return config
202
+
203
+
204
+ def write_bbox_jsonl(record: dict, output_path: Path) -> Path:
205
+ output_path.parent.mkdir(parents=True, exist_ok=True)
206
+ with output_path.open("w", encoding="utf-8") as handle:
207
+ handle.write(json.dumps(record, ensure_ascii=False) + "\n")
208
+ return output_path
209
+
210
+
211
+ def format_source_image_path(image_path: str, image_dir: Path) -> str:
212
+ path = Path(image_path)
213
+ try:
214
+ return path.relative_to(image_dir).as_posix()
215
+ except ValueError:
216
+ return path.name
217
+
218
+
219
+ def save_real_case(
220
+ *,
221
+ sample: dict,
222
+ config: dict,
223
+ pipeline,
224
+ transp_vae,
225
+ ) -> dict:
226
+ if config.get("seed") is not None:
227
+ seed_everything(config["seed"])
228
+
229
+ source_size = config.get("source_size", DEFAULT_TARGET_SIZE)
230
+ target_size = config.get("target_size", DEFAULT_TARGET_SIZE)
231
+ max_layer_num = config.get("max_layer_num", 52)
232
+ sample_name = sample["sample_or_stem"]
233
+
234
+ layer_boxes = get_real_boxes(sample, source_size, target_size)
235
+ adapter_img, resolved_image_path = load_adapter_image(sample, target_size, config)
236
+
237
+ whole_box = (0, 0, target_size, target_size)
238
+ bg_box = (0, 0, target_size, target_size)
239
+ all_boxes = [whole_box, bg_box] + layer_boxes
240
+ if len(all_boxes) > max_layer_num:
241
+ raise ValueError(
242
+ f"num_layers={len(all_boxes)} exceeds max_layer_num={max_layer_num} for {sample_name}"
243
+ )
244
+
245
+ generator = torch.Generator(device=torch.device("cuda")).manual_seed(config.get("seed", 42))
246
+ caption = sample.get("whole_caption", "")
247
+
248
+ x_hat, image, _ = pipeline(
249
+ prompt=caption,
250
+ adapter_image=adapter_img,
251
+ adapter_conditioning_scale=config.get("adapter_scale", 0.9),
252
+ validation_box=all_boxes,
253
+ generator=generator,
254
+ height=target_size,
255
+ width=target_size,
256
+ guidance_scale=config.get("cfg", 4.0),
257
+ num_layers=len(all_boxes),
258
+ sdxl_vae=transp_vae,
259
+ )
260
+
261
+ x_hat = (x_hat + 1) / 2
262
+ x_hat = x_hat.squeeze(0).permute(1, 0, 2, 3).to(torch.float32)
263
+
264
+ save_dir, resolved_run_name = build_run_save_dir(config)
265
+ save_dir_path = Path(save_dir)
266
+ case_dir = save_dir_path / sample_name
267
+ merged_dir = save_dir_path / "merged"
268
+ merged_rgba_dir = save_dir_path / "merged_rgba"
269
+ case_dir.mkdir(parents=True, exist_ok=True)
270
+ merged_dir.mkdir(parents=True, exist_ok=True)
271
+ merged_rgba_dir.mkdir(parents=True, exist_ok=True)
272
+
273
+ whole_rgba_path = case_dir / "whole_image_rgba.png"
274
+ background_rgba_path = case_dir / "background_rgba.png"
275
+ origin_path = case_dir / "origin.png"
276
+ merged_case_path = case_dir / "merged.png"
277
+ merged_global_path = merged_dir / f"{sample_name}.png"
278
+ merged_rgba_path = merged_rgba_dir / f"{sample_name}.png"
279
+
280
+ whole_image_layer = (x_hat[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
281
+ Image.fromarray(whole_image_layer, "RGBA").save(whole_rgba_path)
282
+
283
+ background_layer = (x_hat[1].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
284
+ Image.fromarray(background_layer, "RGBA").save(background_rgba_path)
285
+
286
+ adapter_img.save(origin_path)
287
+
288
+ merged_image = image[1]
289
+ layer_paths: list[str] = []
290
+ for layer_idx in range(2, x_hat.shape[0]):
291
+ rgba_layer = (x_hat[layer_idx].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
292
+ rgba_image = Image.fromarray(rgba_layer, "RGBA")
293
+ layer_path = case_dir / f"layer_{layer_idx - 2}_rgba.png"
294
+ rgba_image.save(layer_path)
295
+ layer_paths.append(str(layer_path))
296
+ merged_image = Image.alpha_composite(merged_image.convert("RGBA"), rgba_image)
297
+
298
+ merged_image.convert("RGB").save(merged_global_path)
299
+ merged_image.convert("RGB").save(merged_case_path)
300
+ merged_image.save(merged_rgba_path)
301
+
302
+ case_meta = {
303
+ "sample_name": sample_name,
304
+ "source_image_path": format_source_image_path(
305
+ resolved_image_path,
306
+ Path(config["image_dir"]),
307
+ ),
308
+ "target_size": target_size,
309
+ "source_size": source_size,
310
+ "raw_num_layers": sample.get("num_layers"),
311
+ "num_layers": len(all_boxes),
312
+ "raw_boxes": sample.get("bboxes", []),
313
+ "boxes": all_boxes,
314
+ "caption": caption,
315
+ "run_name": resolved_run_name,
316
+ }
317
+ meta_path = case_dir / "inference_meta.json"
318
+ with meta_path.open("w", encoding="utf-8") as handle:
319
+ json.dump(case_meta, handle, indent=2)
320
+
321
+ return {
322
+ "run_name": resolved_run_name,
323
+ "save_dir": str(save_dir_path),
324
+ "case_dir": str(case_dir),
325
+ "merged_image": str(merged_case_path),
326
+ "merged_global_image": str(merged_global_path),
327
+ "merged_rgba_image": str(merged_rgba_path),
328
+ "whole_image_rgba": str(whole_rgba_path),
329
+ "background_rgba": str(background_rgba_path),
330
+ "origin_image": str(origin_path),
331
+ "layer_images": layer_paths,
332
+ "metadata_path": str(meta_path),
333
+ "metadata": case_meta,
334
+ }
335
+
336
+
337
+ def create_archive(run_dir: Path) -> Path:
338
+ archive_path = run_dir / "synlayers_result_bundle.zip"
339
+ with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
340
+ for path in run_dir.rglob("*"):
341
+ if path == archive_path or path.is_dir():
342
+ continue
343
+ zf.write(path, arcname=path.relative_to(run_dir))
344
+ return archive_path
345
+
346
+
347
+ def run_real_world_pipeline(
348
+ image_path: str | Path,
349
+ *,
350
+ sample_name: str | None = None,
351
+ work_dir: str | Path | None = None,
352
+ bbox_model: str | None = None,
353
+ config_path: str | Path | None = None,
354
+ max_new_tokens: int = 1024,
355
+ seed: int | None = None,
356
+ run_name: str = DEFAULT_RUN_NAME,
357
+ ) -> dict:
358
+ if not torch.cuda.is_available():
359
+ raise RuntimeError(
360
+ "CUDA GPU is required for the unified SynLayers real-world pipeline. "
361
+ "On Hugging Face Spaces, assign GPU hardware such as A100 and rebuild the Space."
362
+ )
363
+
364
+ image_path = Path(image_path)
365
+ if not image_path.exists():
366
+ raise FileNotFoundError(f"Input image not found: {image_path}")
367
+
368
+ bbox_model = (
369
+ bbox_model
370
+ or os.environ.get("SYNLAYERS_BBOX_MODEL")
371
+ or os.environ.get("SYNLAYERS_BBOX_MODEL_REPO")
372
+ or DEFAULT_BBOX_MODEL
373
+ )
374
+ config_path = Path(config_path or os.environ.get("SYNLAYERS_REAL_CONFIG", str(DEFAULT_REAL_CONFIG_PATH)))
375
+ work_dir = Path(work_dir or os.environ.get("SYNLAYERS_DEMO_WORK_DIR", str(DEFAULT_WORK_DIR)))
376
+
377
+ normalized_sample_name = slugify(sample_name or image_path.stem)
378
+ timestamp = f"{time.strftime('%Y%m%d_%H%M%S')}_{int((time.time() % 1) * 1000):03d}"
379
+ run_dir = work_dir / f"{timestamp}_{normalized_sample_name}"
380
+ image_dir = run_dir / "layers_real_test_1024"
381
+ prepared_image_path = prepare_input_image(
382
+ image_path,
383
+ image_dir / f"{normalized_sample_name}.png",
384
+ DEFAULT_TARGET_SIZE,
385
+ )
386
+
387
+ bbox_model_bundle, bbox_processor = load_bbox_bundle(bbox_model)
388
+ whole_caption, bboxes = infer_caption_bbox(
389
+ prepared_image_path,
390
+ bbox_model_bundle,
391
+ bbox_processor,
392
+ prompt=CAPTION_BBOX_PROMPT_TOP_LEFT,
393
+ max_new_tokens=max_new_tokens,
394
+ )
395
+
396
+ record = {
397
+ "sample_or_stem": normalized_sample_name,
398
+ "image": prepared_image_path.name,
399
+ "whole_caption": whole_caption,
400
+ "bboxes": bboxes,
401
+ "num_layers": len(bboxes),
402
+ "coord": "top_left",
403
+ }
404
+
405
+ bbox_jsonl = write_bbox_jsonl(record, run_dir / "caption_bbox_infer.jsonl")
406
+ bbox_vis_path = run_dir / "bbox_vis" / f"{normalized_sample_name}_vis.png"
407
+ draw_boxes(prepared_image_path, bboxes, bbox_vis_path)
408
+ release_bbox_bundle()
409
+
410
+ config = build_runtime_config(
411
+ config_path=config_path,
412
+ image_dir=image_dir,
413
+ bbox_jsonl=bbox_jsonl,
414
+ results_root=run_dir / "results",
415
+ run_name=run_name,
416
+ seed=seed,
417
+ )
418
+ pipeline, transp_vae = load_real_bundle(config)
419
+ decomposition_result = save_real_case(
420
+ sample=record,
421
+ config=config,
422
+ pipeline=pipeline,
423
+ transp_vae=transp_vae,
424
+ )
425
+
426
+ archive_path = create_archive(run_dir)
427
+ decomposition_result.update(
428
+ {
429
+ "input_image": str(prepared_image_path),
430
+ "bbox_visualization": str(bbox_vis_path),
431
+ "bbox_jsonl": str(bbox_jsonl),
432
+ "bbox_record": record,
433
+ "archive_path": str(archive_path),
434
+ "config_path": str(config_path),
435
+ "bbox_model": bbox_model,
436
+ }
437
+ )
438
+ return decomposition_result
439
+
440
+
441
+ def main():
442
+ parser = argparse.ArgumentParser(
443
+ description="Run the unified real-world SynLayers pipeline on one image."
444
+ )
445
+ parser.add_argument("--image", type=str, required=True, help="Input image path")
446
+ parser.add_argument("--sample-name", type=str, default=None)
447
+ parser.add_argument("--work-dir", type=str, default=str(DEFAULT_WORK_DIR))
448
+ parser.add_argument("--bbox-model", type=str, default=DEFAULT_BBOX_MODEL)
449
+ parser.add_argument("--config", type=str, default=str(DEFAULT_REAL_CONFIG_PATH))
450
+ parser.add_argument("--max-new-tokens", type=int, default=1024)
451
+ parser.add_argument("--seed", type=int, default=None)
452
+ parser.add_argument("--run-name", type=str, default=DEFAULT_RUN_NAME)
453
+ args = parser.parse_args()
454
+
455
+ result = run_real_world_pipeline(
456
+ args.image,
457
+ sample_name=args.sample_name,
458
+ work_dir=args.work_dir,
459
+ bbox_model=args.bbox_model,
460
+ config_path=args.config,
461
+ max_new_tokens=args.max_new_tokens,
462
+ seed=args.seed,
463
+ run_name=args.run_name,
464
+ )
465
+ print(json.dumps(result, indent=2, ensure_ascii=False))
466
+
467
+
468
+ if __name__ == "__main__":
469
+ main()