SynLayers committed on
Commit
33eede6
·
verified ·
1 Parent(s): 81edc25

Upload demo/real_world_pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo/real_world_pipeline.py +453 -0
demo/real_world_pipeline.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import os
6
+ import re
7
+ import sys
8
+ import time
9
+ import zipfile
10
+ from pathlib import Path
11
+
12
+ import numpy as np
13
+ import torch
14
+ from PIL import Image, ImageOps
15
+
16
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
17
+ if str(PROJECT_ROOT) not in sys.path:
18
+ sys.path.insert(0, str(PROJECT_ROOT))
19
+
20
+ from demo.infer.run_caption_bbox_infer import ( # noqa: E402
21
+ CAPTION_BBOX_PROMPT_TOP_LEFT,
22
+ DEFAULT_BBOX_MODEL,
23
+ draw_boxes,
24
+ infer_caption_bbox,
25
+ )
26
+ from demo.hf_repo_assets import build_repo_asset_overrides # noqa: E402
27
+ from demo.infer.vlm_bbox_inference import get_model_and_processor # noqa: E402
28
+ from infer.common_infer import initialize_pipeline # noqa: E402
29
+ from infer.infer import build_run_save_dir, get_real_boxes, load_adapter_image # noqa: E402
30
+ from tools.tools import load_config, seed_everything # noqa: E402
31
+
32
+
33
# Default decomposition config shipped with the repository checkout.
DEFAULT_REAL_CONFIG_PATH = PROJECT_ROOT / "infer" / "infer.yaml"
# Root under which each pipeline invocation creates its own run directory.
DEFAULT_WORK_DIR = PROJECT_ROOT / "demo" / "outputs" / "real_world_demo"
# Default run identifier (matches the default checkpoint step below).
DEFAULT_RUN_NAME = "step_120000"
# Side length (px) of the square canvas the input image is normalized to.
DEFAULT_TARGET_SIZE = 1024

# Process-wide caches so repeated invocations reuse already-loaded models.
# _BBOX_CACHE keys: model_path -> str | None, model/processor -> loaded bundle.
_BBOX_CACHE: dict[str, object] = {"model_path": None, "model": None, "processor": None}
# _REAL_CACHE keys: key -> tuple of config paths, pipeline/transp_vae -> loaded objects.
_REAL_CACHE: dict[str, object] = {"key": None, "pipeline": None, "transp_vae": None}
40
+
41
+
42
def slugify(text: str) -> str:
    """Sanitize *text* into a filesystem-friendly identifier.

    Every run of characters outside ``[A-Za-z0-9._-]`` collapses into a
    single underscore; leading/trailing separators are trimmed.  When
    nothing usable remains, the fallback ``"sample"`` is returned.
    """
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", text)
    cleaned = cleaned.strip("._-")
    if cleaned:
        return cleaned
    return "sample"
45
+
46
+
47
def resolve_existing_path(*candidates) -> str | None:
    """Return the first candidate that names an existing filesystem path.

    Falsy entries (``None``, empty string) are skipped.  Returns ``None``
    when no candidate exists on disk.
    """
    existing = (
        str(Path(candidate))
        for candidate in candidates
        if candidate and Path(candidate).exists()
    )
    return next(existing, None)
55
+
56
+
57
# Root of the decomposition checkpoint (transformer / layer / adapter weights).
# Resolution order: env var override, a checkout-local copy, then a shared
# cluster path.  The shared cluster path also serves as the last-resort
# default even when it does not exist on this machine (resolve_existing_path
# returns None in that case).
DEFAULT_DECOMP_CKPT_ROOT = Path(
    resolve_existing_path(
        os.environ.get("SYNLAYERS_DECOMP_CKPT_ROOT"),
        PROJECT_ROOT / "SynLayers_ckpt" / "step_120000",
        "/project/llmsvgen/share/data/kmw_layered_checkpoint/SynLayers_ckpt/step_120000",
    )
    or "/project/llmsvgen/share/data/kmw_layered_checkpoint/SynLayers_ckpt/step_120000"
)
65
+
66
+
67
def prepare_input_image(input_path: str | Path, output_path: Path, size: int) -> Path:
    """Normalize an input image to a ``size``×``size`` RGB canvas and save it.

    An image that is not already exactly square at the target size is scaled
    to fit (aspect ratio preserved) and centered on a white background.
    Parent directories of *output_path* are created as needed.
    """
    image = Image.open(input_path).convert("RGB")

    if image.size != (size, size):
        fitted = ImageOps.contain(image, (size, size), Image.LANCZOS)
        background = Image.new("RGB", (size, size), (255, 255, 255))
        left = (size - fitted.width) // 2
        top = (size - fitted.height) // 2
        background.paste(fitted, (left, top))
        image = background

    output_path.parent.mkdir(parents=True, exist_ok=True)
    image.save(output_path)
    return output_path
80
+
81
+
82
def load_bbox_bundle(model_path: str):
    """Return ``(model, processor)`` for *model_path*, reusing the module cache.

    A fresh load via ``get_model_and_processor`` happens only when the cache
    is empty or was populated from a different model path.
    """
    if _BBOX_CACHE["model_path"] == model_path and _BBOX_CACHE["model"] is not None:
        return _BBOX_CACHE["model"], _BBOX_CACHE["processor"]

    model, processor = get_model_and_processor(model_path)
    _BBOX_CACHE["model_path"] = model_path
    _BBOX_CACHE["model"] = model
    _BBOX_CACHE["processor"] = processor
    return model, processor
96
+
97
+
98
def load_real_bundle(config: dict):
    """Return ``(pipeline, transp_vae)`` for *config*, reusing the module cache.

    The cache key is the tuple of model/checkpoint settings that determine
    the pipeline.  On a key change the stale pipeline references are dropped
    (and the CUDA allocator cache cleared, when available) before a new
    pipeline is initialized.
    """
    cache_key = tuple(
        config.get(field)
        for field in (
            "pretrained_model_name_or_path",
            "pretrained_adapter_path",
            "transp_vae_path",
            "pretrained_lora_dir",
            "artplus_lora_dir",
            "lora_ckpt",
            "layer_ckpt",
            "adapter_lora_dir",
            "max_layer_num",
        )
    )

    if _REAL_CACHE["key"] == cache_key and _REAL_CACHE["pipeline"] is not None:
        return _REAL_CACHE["pipeline"], _REAL_CACHE["transp_vae"]

    if _REAL_CACHE["pipeline"] is not None:
        # Drop stale references so the old pipeline can be garbage-collected
        # before the (potentially large) replacement is built.
        _REAL_CACHE["pipeline"] = None
        _REAL_CACHE["transp_vae"] = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    pipeline, transp_vae = initialize_pipeline(config)
    _REAL_CACHE["key"] = cache_key
    _REAL_CACHE["pipeline"] = pipeline
    _REAL_CACHE["transp_vae"] = transp_vae
    return pipeline, transp_vae
125
+
126
+
127
def build_runtime_config(
    *,
    config_path: str | Path,
    image_dir: Path,
    bbox_jsonl: Path,
    results_root: Path,
    run_name: str,
    seed: int | None = None,
) -> dict:
    """Load the base YAML config and overlay run-specific and env overrides.

    Precedence for each model/checkpoint path is: environment variable,
    Hugging Face repo asset override, a checkout-local path (when it exists
    on disk), then a hard-coded hub/default fallback where one is defined.

    Args:
        config_path: Base YAML config file to load.
        image_dir: Directory holding the prepared input image(s).
        bbox_jsonl: JSONL file with the caption/bbox record for this run.
        results_root: Root directory for decomposition outputs.
        run_name: Run identifier recorded in the config.
        seed: Optional seed written into the config when given.

    Returns:
        The fully-populated config dict consumed by the decomposition pipeline.
    """
    config = load_config(str(config_path))
    # Optional asset overrides resolved from a Hugging Face model repo, if
    # SYNLAYERS_MODEL_REPO is set.
    repo_overrides = build_repo_asset_overrides(os.environ.get("SYNLAYERS_MODEL_REPO"))
    decomp_ckpt_root = Path(
        os.environ.get("SYNLAYERS_DECOMP_CKPT_ROOT")
        or repo_overrides.get("decomp_ckpt_root")
        or DEFAULT_DECOMP_CKPT_ROOT
    )
    # Point the pipeline at this run's working directories and inputs.
    config["data_dir"] = str(image_dir.parent)
    config["image_dir"] = str(image_dir)
    config["test_jsonl"] = str(bbox_jsonl)
    config["save_dir"] = str(results_root)
    config["run_name"] = run_name
    # Decomposition checkpoint sub-paths under the resolved checkpoint root.
    config["lora_ckpt"] = str(decomp_ckpt_root / "transformer")
    config["layer_ckpt"] = str(decomp_ckpt_root)
    config["adapter_lora_dir"] = str(decomp_ckpt_root / "adapter")

    env_overrides = {
        "pretrained_model_name_or_path": (
            os.environ.get("SYNLAYERS_BASE_MODEL")
            or repo_overrides.get("pretrained_model_name_or_path")
            or resolve_existing_path(PROJECT_ROOT / "SynLayers_checkpoints" / "FLUX.1-dev")
            or "black-forest-labs/FLUX.1-dev"
        ),
        "pretrained_adapter_path": (
            os.environ.get("SYNLAYERS_ADAPTER_MODEL")
            or repo_overrides.get("pretrained_adapter_path")
            or resolve_existing_path(
                PROJECT_ROOT / "SynLayers_checkpoints" / "FLUX.1-dev-Controlnet-Inpainting-Alpha"
            )
        ),
        "transp_vae_path": (
            os.environ.get("SYNLAYERS_TRANSP_VAE")
            or repo_overrides.get("transp_vae_path")
            or resolve_existing_path(PROJECT_ROOT / "ckpt" / "trans_vae" / "0008000.pt")
        ),
        "pretrained_lora_dir": (
            os.environ.get("SYNLAYERS_PRETRAINED_LORA")
            or repo_overrides.get("pretrained_lora_dir")
            or resolve_existing_path(PROJECT_ROOT / "ckpt" / "pre_trained_LoRA")
        ),
        "artplus_lora_dir": (
            os.environ.get("SYNLAYERS_ARTPLUS_LORA")
            or repo_overrides.get("artplus_lora_dir")
            or resolve_existing_path(PROJECT_ROOT / "ckpt" / "prism_ft_LoRA")
        ),
    }
    # Apply only overrides that resolved to something truthy; unresolved
    # entries leave the base config value untouched.
    for key, value in env_overrides.items():
        if value:
            config[key] = value

    if seed is not None:
        config["seed"] = seed

    return config
190
+
191
+
192
def write_bbox_jsonl(record: dict, output_path: Path) -> Path:
    """Serialize *record* as a single-line JSONL file at *output_path*.

    Parent directories are created as needed; the file is overwritten.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(record, ensure_ascii=False)
    output_path.write_text(payload + "\n", encoding="utf-8")
    return output_path
197
+
198
+
199
def format_source_image_path(image_path: str, image_dir: Path) -> str:
    """Express *image_path* relative to *image_dir* in POSIX form.

    Falls back to the bare filename when the path is not located inside
    *image_dir*.
    """
    source = Path(image_path)
    try:
        relative = source.relative_to(image_dir)
    except ValueError:
        return source.name
    return relative.as_posix()
205
+
206
+
207
def save_real_case(
    *,
    sample: dict,
    config: dict,
    pipeline,
    transp_vae,
) -> dict:
    """Run layer decomposition for one sample and write all per-case outputs.

    Writes the whole-image/background/per-layer RGBA PNGs, the origin image,
    merged composites (RGB and RGBA), and an ``inference_meta.json`` file
    under the run's save directory.

    Args:
        sample: Caption/bbox record (keys used: ``sample_or_stem``,
            ``whole_caption``, ``bboxes``, ``num_layers``).
        config: Runtime config produced by ``build_runtime_config``.
        pipeline: Loaded decomposition pipeline (callable).
        transp_vae: Transparency VAE passed to the pipeline as ``sdxl_vae``.

    Returns:
        Dict of output paths plus the case metadata.

    Raises:
        ValueError: If the total layer count exceeds ``max_layer_num``.
    """
    if config.get("seed") is not None:
        seed_everything(config["seed"])

    source_size = config.get("source_size", DEFAULT_TARGET_SIZE)
    target_size = config.get("target_size", DEFAULT_TARGET_SIZE)
    max_layer_num = config.get("max_layer_num", 52)
    sample_name = sample["sample_or_stem"]

    layer_boxes = get_real_boxes(sample, source_size, target_size)
    adapter_img, resolved_image_path = load_adapter_image(sample, target_size, config)

    # Two full-canvas boxes are prepended: one for the whole image, one for
    # the background layer; the remaining boxes are the detected layers.
    whole_box = (0, 0, target_size, target_size)
    bg_box = (0, 0, target_size, target_size)
    all_boxes = [whole_box, bg_box] + layer_boxes
    if len(all_boxes) > max_layer_num:
        raise ValueError(
            f"num_layers={len(all_boxes)} exceeds max_layer_num={max_layer_num} for {sample_name}"
        )

    # Seeded CUDA generator for reproducible sampling (default seed 42).
    generator = torch.Generator(device=torch.device("cuda")).manual_seed(config.get("seed", 42))
    caption = sample.get("whole_caption", "")

    x_hat, image, _ = pipeline(
        prompt=caption,
        adapter_image=adapter_img,
        adapter_conditioning_scale=config.get("adapter_scale", 0.9),
        validation_box=all_boxes,
        generator=generator,
        height=target_size,
        width=target_size,
        guidance_scale=config.get("cfg", 4.0),
        num_layers=len(all_boxes),
        sdxl_vae=transp_vae,
    )

    # Map model output from [-1, 1] to [0, 1] and reorder so layers come
    # first.  NOTE(review): assumes x_hat is (1, C, num_layers, H, W) before
    # the squeeze/permute — confirm against the pipeline's return contract.
    x_hat = (x_hat + 1) / 2
    x_hat = x_hat.squeeze(0).permute(1, 0, 2, 3).to(torch.float32)

    save_dir, resolved_run_name = build_run_save_dir(config)
    save_dir_path = Path(save_dir)
    case_dir = save_dir_path / sample_name
    merged_dir = save_dir_path / "merged"
    merged_rgba_dir = save_dir_path / "merged_rgba"
    case_dir.mkdir(parents=True, exist_ok=True)
    merged_dir.mkdir(parents=True, exist_ok=True)
    merged_rgba_dir.mkdir(parents=True, exist_ok=True)

    whole_rgba_path = case_dir / "whole_image_rgba.png"
    background_rgba_path = case_dir / "background_rgba.png"
    origin_path = case_dir / "origin.png"
    merged_case_path = case_dir / "merged.png"
    merged_global_path = merged_dir / f"{sample_name}.png"
    merged_rgba_path = merged_rgba_dir / f"{sample_name}.png"

    # Layer 0 is the whole-image reconstruction, layer 1 the background.
    whole_image_layer = (x_hat[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    Image.fromarray(whole_image_layer, "RGBA").save(whole_rgba_path)

    background_layer = (x_hat[1].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    Image.fromarray(background_layer, "RGBA").save(background_rgba_path)

    adapter_img.save(origin_path)

    # Composite the remaining layers (index 2+) on top of the background
    # image returned by the pipeline, saving each layer's RGBA on the way.
    merged_image = image[1]
    layer_paths: list[str] = []
    for layer_idx in range(2, x_hat.shape[0]):
        rgba_layer = (x_hat[layer_idx].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        rgba_image = Image.fromarray(rgba_layer, "RGBA")
        layer_path = case_dir / f"layer_{layer_idx - 2}_rgba.png"
        rgba_image.save(layer_path)
        layer_paths.append(str(layer_path))
        merged_image = Image.alpha_composite(merged_image.convert("RGBA"), rgba_image)

    merged_image.convert("RGB").save(merged_global_path)
    merged_image.convert("RGB").save(merged_case_path)
    merged_image.save(merged_rgba_path)

    case_meta = {
        "sample_name": sample_name,
        "source_image_path": format_source_image_path(
            resolved_image_path,
            Path(config["image_dir"]),
        ),
        "target_size": target_size,
        "source_size": source_size,
        "raw_num_layers": sample.get("num_layers"),
        "num_layers": len(all_boxes),
        "raw_boxes": sample.get("bboxes", []),
        "boxes": all_boxes,
        "caption": caption,
        "run_name": resolved_run_name,
    }
    meta_path = case_dir / "inference_meta.json"
    with meta_path.open("w", encoding="utf-8") as handle:
        json.dump(case_meta, handle, indent=2)

    return {
        "run_name": resolved_run_name,
        "save_dir": str(save_dir_path),
        "case_dir": str(case_dir),
        "merged_image": str(merged_case_path),
        "merged_global_image": str(merged_global_path),
        "merged_rgba_image": str(merged_rgba_path),
        "whole_image_rgba": str(whole_rgba_path),
        "background_rgba": str(background_rgba_path),
        "origin_image": str(origin_path),
        "layer_images": layer_paths,
        "metadata_path": str(meta_path),
        "metadata": case_meta,
    }
323
+
324
+
325
def create_archive(run_dir: Path) -> Path:
    """Zip every file under *run_dir* into ``synlayers_result_bundle.zip``.

    Directories and the archive file itself are skipped; entries are stored
    with paths relative to *run_dir*.  Returns the archive path.
    """
    archive_path = run_dir / "synlayers_result_bundle.zip"
    with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for entry in run_dir.rglob("*"):
            if entry.is_dir() or entry == archive_path:
                continue
            archive.write(entry, arcname=entry.relative_to(run_dir))
    return archive_path
333
+
334
+
335
def run_real_world_pipeline(
    image_path: str | Path,
    *,
    sample_name: str | None = None,
    work_dir: str | Path | None = None,
    bbox_model: str | None = None,
    config_path: str | Path | None = None,
    max_new_tokens: int = 1024,
    seed: int | None = None,
    run_name: str = DEFAULT_RUN_NAME,
) -> dict:
    """Run the full real-world pipeline on one image: caption/bbox inference,
    then layer decomposition, then archiving of all outputs.

    Args:
        image_path: Path to the input image.
        sample_name: Optional sample identifier; defaults to the image stem.
        work_dir: Working directory root; env/default fallbacks apply.
        bbox_model: Caption/bbox model path or repo id; env/default fallbacks apply.
        config_path: Decomposition YAML config; env/default fallbacks apply.
        max_new_tokens: Token budget for the caption/bbox model.
        seed: Optional seed forwarded into the decomposition config.
        run_name: Run identifier for the decomposition stage.

    Returns:
        Dict with all output paths, the bbox record, and the archive path.

    Raises:
        RuntimeError: If no CUDA GPU is available.
        FileNotFoundError: If *image_path* does not exist.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA GPU is required for the unified SynLayers real-world pipeline.")

    image_path = Path(image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"Input image not found: {image_path}")

    # Argument > env var > model-repo env var > hard default.
    bbox_model = (
        bbox_model
        or os.environ.get("SYNLAYERS_BBOX_MODEL")
        or os.environ.get("SYNLAYERS_MODEL_REPO")
        or DEFAULT_BBOX_MODEL
    )
    config_path = Path(config_path or os.environ.get("SYNLAYERS_REAL_CONFIG", str(DEFAULT_REAL_CONFIG_PATH)))
    work_dir = Path(work_dir or os.environ.get("SYNLAYERS_DEMO_WORK_DIR", str(DEFAULT_WORK_DIR)))

    # Unique per-invocation run directory: timestamp (with milliseconds) +
    # sanitized sample name.
    normalized_sample_name = slugify(sample_name or image_path.stem)
    timestamp = f"{time.strftime('%Y%m%d_%H%M%S')}_{int((time.time() % 1) * 1000):03d}"
    run_dir = work_dir / f"{timestamp}_{normalized_sample_name}"
    image_dir = run_dir / "layers_real_test_1024"
    prepared_image_path = prepare_input_image(
        image_path,
        image_dir / f"{normalized_sample_name}.png",
        DEFAULT_TARGET_SIZE,
    )

    # Stage 1: caption + bounding-box inference on the prepared image.
    bbox_model_bundle, bbox_processor = load_bbox_bundle(bbox_model)
    whole_caption, bboxes = infer_caption_bbox(
        prepared_image_path,
        bbox_model_bundle,
        bbox_processor,
        prompt=CAPTION_BBOX_PROMPT_TOP_LEFT,
        max_new_tokens=max_new_tokens,
    )

    # Record in the JSONL shape consumed by the decomposition stage.
    record = {
        "sample_or_stem": normalized_sample_name,
        "image": prepared_image_path.name,
        "whole_caption": whole_caption,
        "bboxes": bboxes,
        "num_layers": len(bboxes),
        "coord": "top_left",
    }

    bbox_jsonl = write_bbox_jsonl(record, run_dir / "caption_bbox_infer.jsonl")
    bbox_vis_path = run_dir / "bbox_vis" / f"{normalized_sample_name}_vis.png"
    draw_boxes(prepared_image_path, bboxes, bbox_vis_path)

    # Stage 2: layer decomposition driven by the generated bbox record.
    config = build_runtime_config(
        config_path=config_path,
        image_dir=image_dir,
        bbox_jsonl=bbox_jsonl,
        results_root=run_dir / "results",
        run_name=run_name,
        seed=seed,
    )
    pipeline, transp_vae = load_real_bundle(config)
    decomposition_result = save_real_case(
        sample=record,
        config=config,
        pipeline=pipeline,
        transp_vae=transp_vae,
    )

    # Stage 3: bundle the whole run directory into one downloadable archive.
    archive_path = create_archive(run_dir)
    decomposition_result.update(
        {
            "input_image": str(prepared_image_path),
            "bbox_visualization": str(bbox_vis_path),
            "bbox_jsonl": str(bbox_jsonl),
            "bbox_record": record,
            "archive_path": str(archive_path),
            "config_path": str(config_path),
            "bbox_model": bbox_model,
        }
    )
    return decomposition_result
423
+
424
+
425
def main():
    """Command-line wrapper around ``run_real_world_pipeline``.

    Parses arguments, runs the pipeline on one image, and prints the
    result dict as pretty-printed JSON.
    """
    parser = argparse.ArgumentParser(
        description="Run the unified real-world SynLayers pipeline on one image."
    )
    parser.add_argument("--image", type=str, required=True, help="Input image path")
    optional_specs = [
        ("--sample-name", str, None),
        ("--work-dir", str, str(DEFAULT_WORK_DIR)),
        ("--bbox-model", str, DEFAULT_BBOX_MODEL),
        ("--config", str, str(DEFAULT_REAL_CONFIG_PATH)),
        ("--max-new-tokens", int, 1024),
        ("--seed", int, None),
        ("--run-name", str, DEFAULT_RUN_NAME),
    ]
    for flag, flag_type, default in optional_specs:
        parser.add_argument(flag, type=flag_type, default=default)
    args = parser.parse_args()

    outcome = run_real_world_pipeline(
        args.image,
        sample_name=args.sample_name,
        work_dir=args.work_dir,
        bbox_model=args.bbox_model,
        config_path=args.config,
        max_new_tokens=args.max_new_tokens,
        seed=args.seed,
        run_name=args.run_name,
    )
    print(json.dumps(outcome, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()