SynLayers committed on
Commit
1e1d6df
·
verified ·
1 Parent(s): 0203586

Upload demo/real_world_pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo/real_world_pipeline.py +436 -0
demo/real_world_pipeline.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import os
6
+ import re
7
+ import sys
8
+ import time
9
+ import zipfile
10
+ from pathlib import Path
11
+
12
+ import numpy as np
13
+ import torch
14
+ from PIL import Image, ImageOps
15
+
16
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
17
+ if str(PROJECT_ROOT) not in sys.path:
18
+ sys.path.insert(0, str(PROJECT_ROOT))
19
+
20
+ from demo.infer.run_caption_bbox_infer import ( # noqa: E402
21
+ CAPTION_BBOX_PROMPT_TOP_LEFT,
22
+ DEFAULT_BBOX_MODEL,
23
+ draw_boxes,
24
+ infer_caption_bbox,
25
+ )
26
+ from demo.infer.vlm_bbox_inference import get_model_and_processor # noqa: E402
27
+ from infer.common_infer import initialize_pipeline # noqa: E402
28
+ from infer.infer import build_run_save_dir, get_real_boxes, load_adapter_image # noqa: E402
29
+ from tools.tools import load_config, seed_everything # noqa: E402
30
+
31
+
32
# Built-in defaults for the real-world demo; every path is anchored at the
# repository root.
DEFAULT_REAL_CONFIG_PATH = PROJECT_ROOT.joinpath("infer", "infer.yaml")
DEFAULT_WORK_DIR = PROJECT_ROOT.joinpath("demo", "outputs", "real_world_demo")
DEFAULT_RUN_NAME = "step_120000"
DEFAULT_TARGET_SIZE = 1024

# Single-slot, process-wide caches so repeated calls reuse the loaded models.
# Each holds the identity of what is loaded plus the loaded objects.
_BBOX_CACHE: dict[str, object] = dict(model_path=None, model=None, processor=None)
_REAL_CACHE: dict[str, object] = dict(key=None, pipeline=None, transp_vae=None)
39
+
40
+
41
def slugify(text: str) -> str:
    """Turn arbitrary text into a filesystem-friendly name.

    Every run of characters outside ``A-Z a-z 0-9 . _ -`` is collapsed into a
    single underscore, then leading and trailing ``.``, ``_`` and ``-`` are
    trimmed. Returns ``"sample"`` when nothing is left.
    """
    collapsed = re.sub(r"[^A-Za-z0-9._-]+", "_", text)
    trimmed = collapsed.strip("._-")
    if trimmed:
        return trimmed
    return "sample"
44
+
45
+
46
def resolve_existing_path(*candidates) -> str | None:
    """Return the first candidate that exists on disk, as a string.

    Candidates are checked in the order given. Falsy entries (``None``,
    ``""``) are skipped. Returns ``None`` when no candidate exists.
    """
    existing = (
        Path(option) for option in candidates if option and Path(option).exists()
    )
    first_hit = next(existing, None)
    return None if first_hit is None else str(first_hit)
54
+
55
+
56
# Shared-storage checkpoint location; used both as the last candidate to probe
# and as the value of last resort when nothing exists on disk.
_SHARED_DECOMP_CKPT_ROOT = (
    "/project/llmsvgen/share/data/kmw_layered_checkpoint/SynLayers_ckpt/step_120000"
)

# Decomposition checkpoint root, resolved once at import time. Precedence:
# $SYNLAYERS_DECOMP_CKPT_ROOT, then the in-repo checkpoint directory, then the
# shared-storage path (kept even if it does not exist).
DEFAULT_DECOMP_CKPT_ROOT = Path(
    resolve_existing_path(
        os.environ.get("SYNLAYERS_DECOMP_CKPT_ROOT"),
        PROJECT_ROOT.joinpath("SynLayers_ckpt", "step_120000"),
        _SHARED_DECOMP_CKPT_ROOT,
    )
    or _SHARED_DECOMP_CKPT_ROOT
)
64
+
65
+
66
def prepare_input_image(input_path: str | Path, output_path: Path, size: int) -> Path:
    """Normalize an input image to a ``size`` x ``size`` RGB file.

    The image is rotated according to its EXIF orientation tag (if any) and
    converted to RGB. Unless it is already exactly ``size`` x ``size``, it is
    scaled to fit inside the square (aspect ratio preserved) and centred on a
    white canvas.

    Args:
        input_path: Image file to read.
        output_path: Destination file; parent directories are created.
        size: Edge length, in pixels, of the square output.

    Returns:
        ``output_path``.
    """
    # Use a context manager so the source file handle is released
    # deterministically instead of whenever the image object is collected.
    with Image.open(input_path) as source:
        # Photos straight from cameras/phones often store rotation only in
        # EXIF; apply it so the pipeline sees the image upright.
        image = ImageOps.exif_transpose(source).convert("RGB")

    if image.size != (size, size):
        resized = ImageOps.contain(image, (size, size), Image.LANCZOS)
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        # Centre the resized image; any leftover border stays white.
        offset = ((size - resized.width) // 2, (size - resized.height) // 2)
        canvas.paste(resized, offset)
        image = canvas

    output_path.parent.mkdir(parents=True, exist_ok=True)
    image.save(output_path)
    return output_path
79
+
80
+
81
def load_bbox_bundle(model_path: str):
    """Return the ``(model, processor)`` pair for ``model_path``.

    The most recently loaded pair is kept in the module-level ``_BBOX_CACHE``;
    a call with the same ``model_path`` reuses it, any other path triggers a
    fresh ``get_model_and_processor`` load that replaces the cached entry.
    """
    cache = _BBOX_CACHE
    cache_hit = cache["model"] is not None and cache["model_path"] == model_path
    if not cache_hit:
        model, processor = get_model_and_processor(model_path)
        cache["model_path"] = model_path
        cache["model"] = model
        cache["processor"] = processor
    return cache["model"], cache["processor"]
95
+
96
+
97
def load_real_bundle(config: dict):
    """Return the ``(pipeline, transp_vae)`` pair for ``config``, with caching.

    The cache identity is the tuple of model/checkpoint-related config values
    listed below. When it matches the cached identity the cached pair is
    returned. Otherwise the cached pair is dropped, the CUDA allocator cache
    is emptied when CUDA is available, and ``initialize_pipeline`` builds a
    new pair.
    """
    identity_fields = (
        "pretrained_model_name_or_path",
        "pretrained_adapter_path",
        "transp_vae_path",
        "pretrained_lora_dir",
        "artplus_lora_dir",
        "lora_ckpt",
        "layer_ckpt",
        "adapter_lora_dir",
        "max_layer_num",
    )
    key = tuple(config.get(field) for field in identity_fields)

    # No local reference to the cached objects is taken on purpose, so that
    # clearing the cache entries below actually releases them.
    has_cached = _REAL_CACHE["pipeline"] is not None
    if has_cached and _REAL_CACHE["key"] == key:
        return _REAL_CACHE["pipeline"], _REAL_CACHE["transp_vae"]

    if has_cached:
        _REAL_CACHE["pipeline"] = None
        _REAL_CACHE["transp_vae"] = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    pipeline, transp_vae = initialize_pipeline(config)
    _REAL_CACHE["key"] = key
    _REAL_CACHE["pipeline"] = pipeline
    _REAL_CACHE["transp_vae"] = transp_vae
    return pipeline, transp_vae
124
+
125
+
126
def build_runtime_config(
    *,
    config_path: str | Path,
    image_dir: Path,
    bbox_jsonl: Path,
    results_root: Path,
    run_name: str,
    seed: int | None = None,
) -> dict:
    """Load the base config and specialise it for a single demo run.

    Args:
        config_path: File handed to ``load_config``.
        image_dir: Directory with the prepared input image; its parent
            becomes ``data_dir``.
        bbox_jsonl: JSONL file written by the bbox stage (``test_jsonl``).
        results_root: Directory that becomes ``save_dir``.
        run_name: Stored as ``run_name``.
        seed: Stored as ``seed`` only when not ``None``.

    Returns:
        The loaded config with run paths, checkpoint paths and any resolved
        model-location overrides applied.
    """

    def _env_or_local(env_var: str, *relative_parts: str):
        # The environment variable wins; otherwise use the in-repo location,
        # but only if it actually exists.
        return os.environ.get(env_var) or resolve_existing_path(
            PROJECT_ROOT.joinpath(*relative_parts)
        )

    config = load_config(str(config_path))

    run_settings = (
        ("data_dir", str(image_dir.parent)),
        ("image_dir", str(image_dir)),
        ("test_jsonl", str(bbox_jsonl)),
        ("save_dir", str(results_root)),
        ("run_name", run_name),
        ("lora_ckpt", str(DEFAULT_DECOMP_CKPT_ROOT / "transformer")),
        ("layer_ckpt", str(DEFAULT_DECOMP_CKPT_ROOT)),
        ("adapter_lora_dir", str(DEFAULT_DECOMP_CKPT_ROOT / "adapter")),
    )
    for option, setting in run_settings:
        config[option] = setting

    model_overrides = (
        (
            "pretrained_model_name_or_path",
            # Only the base model has a final fallback: the hub identifier.
            _env_or_local("SYNLAYERS_BASE_MODEL", "SynLayers_checkpoints", "FLUX.1-dev")
            or "black-forest-labs/FLUX.1-dev",
        ),
        (
            "pretrained_adapter_path",
            _env_or_local(
                "SYNLAYERS_ADAPTER_MODEL",
                "SynLayers_checkpoints",
                "FLUX.1-dev-Controlnet-Inpainting-Alpha",
            ),
        ),
        (
            "transp_vae_path",
            _env_or_local("SYNLAYERS_TRANSP_VAE", "ckpt", "trans_vae", "0008000.pt"),
        ),
        (
            "pretrained_lora_dir",
            _env_or_local("SYNLAYERS_PRETRAINED_LORA", "ckpt", "pre_trained_LoRA"),
        ),
        (
            "artplus_lora_dir",
            _env_or_local("SYNLAYERS_ARTPLUS_LORA", "ckpt", "prism_ft_LoRA"),
        ),
    )
    # An unresolved override leaves the value from the config file in place.
    for option, resolved in model_overrides:
        if resolved:
            config[option] = resolved

    if seed is not None:
        config["seed"] = seed

    return config
178
+
179
+
180
def write_bbox_jsonl(record: dict, output_path: Path) -> Path:
    """Write ``record`` as a one-line UTF-8 JSONL file and return its path.

    Parent directories are created as needed, an existing file is
    overwritten, and non-ASCII characters are written unescaped.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, mode="w", encoding="utf-8") as sink:
        serialized = json.dumps(record, ensure_ascii=False)
        sink.write(f"{serialized}\n")
    return output_path
185
+
186
+
187
def format_source_image_path(image_path: str, image_dir: Path) -> str:
    """Return ``image_path`` relative to ``image_dir`` in POSIX form.

    When ``image_path`` does not lie under ``image_dir``, only its file name
    is returned.
    """
    candidate = Path(image_path)
    try:
        relative = candidate.relative_to(image_dir)
    except ValueError:
        return candidate.name
    return relative.as_posix()
193
+
194
+
195
def _layer_to_rgba_image(layer):
    """Convert one channel-first, 4-channel float layer in [0, 1] to an RGBA image."""
    # Clamp before quantising: values slightly outside [0, 1] would otherwise
    # wrap around when cast to uint8 and show up as speckle artifacts.
    pixels = layer.clamp(0.0, 1.0).permute(1, 2, 0).cpu().numpy() * 255
    return Image.fromarray(pixels.astype(np.uint8), "RGBA")


def save_real_case(
    *,
    sample: dict,
    config: dict,
    pipeline,
    transp_vae,
) -> dict:
    """Run layer decomposition for one sample and write all artifacts to disk.

    Files written under the run directory from ``build_run_save_dir``:
    ``<sample>/whole_image_rgba.png``, ``<sample>/background_rgba.png``,
    ``<sample>/origin.png``, ``<sample>/layer_<i>_rgba.png``,
    ``<sample>/merged.png``, ``<sample>/inference_meta.json``,
    ``merged/<sample>.png`` and ``merged_rgba/<sample>.png``.

    Args:
        sample: Record with at least ``sample_or_stem``; ``whole_caption``,
            ``bboxes`` and ``num_layers`` are read when present. Also passed
            to ``get_real_boxes`` and ``load_adapter_image``.
        config: Runtime configuration. Keys read here: ``seed``,
            ``source_size``, ``target_size``, ``max_layer_num``,
            ``adapter_scale``, ``cfg`` and ``image_dir``.
        pipeline: Callable decomposition pipeline.
        transp_vae: Forwarded to the pipeline as ``sdxl_vae``.

    Returns:
        Dict with the run name, output paths (as strings), the list of
        per-layer image paths and the metadata that was written.

    Raises:
        ValueError: If whole image + background + layers exceeds
            ``max_layer_num``.
    """
    # A seed that is missing or explicitly null skips global seeding and
    # falls back to 42 for the generator. Previously an explicit null reached
    # manual_seed() and raised.
    seed = config.get("seed")
    if seed is not None:
        seed_everything(seed)

    source_size = config.get("source_size", DEFAULT_TARGET_SIZE)
    target_size = config.get("target_size", DEFAULT_TARGET_SIZE)
    max_layer_num = config.get("max_layer_num", 52)
    sample_name = sample["sample_or_stem"]

    layer_boxes = get_real_boxes(sample, source_size, target_size)
    adapter_img, resolved_image_path = load_adapter_image(sample, target_size, config)

    # Slot 0 is the whole image and slot 1 the background; both span the full
    # canvas. The detected layers follow.
    whole_box = (0, 0, target_size, target_size)
    bg_box = (0, 0, target_size, target_size)
    all_boxes = [whole_box, bg_box] + layer_boxes
    if len(all_boxes) > max_layer_num:
        raise ValueError(
            f"num_layers={len(all_boxes)} exceeds max_layer_num={max_layer_num} for {sample_name}"
        )

    generator = torch.Generator(device=torch.device("cuda")).manual_seed(
        42 if seed is None else seed
    )
    caption = sample.get("whole_caption", "")

    x_hat, image, _ = pipeline(
        prompt=caption,
        adapter_image=adapter_img,
        adapter_conditioning_scale=config.get("adapter_scale", 0.9),
        validation_box=all_boxes,
        generator=generator,
        height=target_size,
        width=target_size,
        guidance_scale=config.get("cfg", 4.0),
        num_layers=len(all_boxes),
        sdxl_vae=transp_vae,
    )

    # Rescale to [0, 1] and reorder so that the first axis indexes layers.
    x_hat = (x_hat + 1) / 2
    x_hat = x_hat.squeeze(0).permute(1, 0, 2, 3).to(torch.float32)

    save_dir, resolved_run_name = build_run_save_dir(config)
    save_dir_path = Path(save_dir)
    case_dir = save_dir_path / sample_name
    merged_dir = save_dir_path / "merged"
    merged_rgba_dir = save_dir_path / "merged_rgba"
    for directory in (case_dir, merged_dir, merged_rgba_dir):
        directory.mkdir(parents=True, exist_ok=True)

    whole_rgba_path = case_dir / "whole_image_rgba.png"
    background_rgba_path = case_dir / "background_rgba.png"
    origin_path = case_dir / "origin.png"
    merged_case_path = case_dir / "merged.png"
    merged_global_path = merged_dir / f"{sample_name}.png"
    merged_rgba_path = merged_rgba_dir / f"{sample_name}.png"

    _layer_to_rgba_image(x_hat[0]).save(whole_rgba_path)
    _layer_to_rgba_image(x_hat[1]).save(background_rgba_path)
    adapter_img.save(origin_path)

    # image[1] is the compositing base. Index 1 matches the background slot
    # in all_boxes, so this is presumably the decoded background — confirm
    # against the pipeline. Converting once up front also guarantees the
    # merged_rgba output is RGBA when there are no foreground layers.
    merged_image = image[1].convert("RGBA")
    layer_paths: list[str] = []
    for layer_idx in range(2, x_hat.shape[0]):
        rgba_image = _layer_to_rgba_image(x_hat[layer_idx])
        layer_path = case_dir / f"layer_{layer_idx - 2}_rgba.png"
        rgba_image.save(layer_path)
        layer_paths.append(str(layer_path))
        merged_image = Image.alpha_composite(merged_image, rgba_image)

    merged_rgb = merged_image.convert("RGB")
    merged_rgb.save(merged_global_path)
    merged_rgb.save(merged_case_path)
    merged_image.save(merged_rgba_path)

    case_meta = {
        "sample_name": sample_name,
        "source_image_path": format_source_image_path(
            resolved_image_path,
            Path(config["image_dir"]),
        ),
        "target_size": target_size,
        "source_size": source_size,
        "raw_num_layers": sample.get("num_layers"),
        "num_layers": len(all_boxes),
        "raw_boxes": sample.get("bboxes", []),
        "boxes": all_boxes,
        "caption": caption,
        "run_name": resolved_run_name,
    }
    meta_path = case_dir / "inference_meta.json"
    with meta_path.open("w", encoding="utf-8") as handle:
        # Keep non-ASCII captions readable, matching the other JSON writers
        # in this file.
        json.dump(case_meta, handle, indent=2, ensure_ascii=False)

    return {
        "run_name": resolved_run_name,
        "save_dir": str(save_dir_path),
        "case_dir": str(case_dir),
        "merged_image": str(merged_case_path),
        "merged_global_image": str(merged_global_path),
        "merged_rgba_image": str(merged_rgba_path),
        "whole_image_rgba": str(whole_rgba_path),
        "background_rgba": str(background_rgba_path),
        "origin_image": str(origin_path),
        "layer_images": layer_paths,
        "metadata_path": str(meta_path),
        "metadata": case_meta,
    }
311
+
312
+
313
def create_archive(run_dir: Path) -> Path:
    """Zip every file under ``run_dir`` into ``synlayers_result_bundle.zip``.

    The archive is created inside ``run_dir`` and excludes itself. Members
    are stored with paths relative to ``run_dir`` using DEFLATE compression;
    directories get no entries of their own.

    Returns:
        Path of the created archive.
    """
    archive_path = run_dir / "synlayers_result_bundle.zip"
    with zipfile.ZipFile(archive_path, mode="w", compression=zipfile.ZIP_DEFLATED) as bundle:
        members = (
            entry
            for entry in run_dir.rglob("*")
            if entry != archive_path and not entry.is_dir()
        )
        for member in members:
            bundle.write(member, arcname=member.relative_to(run_dir))
    return archive_path
321
+
322
+
323
def run_real_world_pipeline(
    image_path: str | Path,
    *,
    sample_name: str | None = None,
    work_dir: str | Path | None = None,
    bbox_model: str | None = None,
    config_path: str | Path | None = None,
    max_new_tokens: int = 1024,
    seed: int | None = None,
    run_name: str = DEFAULT_RUN_NAME,
) -> dict:
    """Run caption/bbox inference and layer decomposition on one image.

    Steps: normalise the input to a ``DEFAULT_TARGET_SIZE`` square, infer a
    caption and bounding boxes, write the bbox JSONL and a visualisation,
    build the runtime config, run the decomposition, and zip the run
    directory.

    Args:
        image_path: Input image file.
        sample_name: Name for the outputs; defaults to the image file stem.
            Always passed through ``slugify``.
        work_dir: Parent of the per-run directory. Falls back to
            ``$SYNLAYERS_DEMO_WORK_DIR``, then ``DEFAULT_WORK_DIR``.
        bbox_model: Caption/bbox model. Falls back to
            ``$SYNLAYERS_BBOX_MODEL``, then ``DEFAULT_BBOX_MODEL``.
        config_path: Decomposition config file. Falls back to
            ``$SYNLAYERS_REAL_CONFIG``, then ``DEFAULT_REAL_CONFIG_PATH``.
        max_new_tokens: Forwarded to ``infer_caption_bbox``.
        seed: Optional seed stored in the runtime config.
        run_name: Run name stored in the runtime config.

    Returns:
        The dict from ``save_real_case`` extended with the prepared input
        path, bbox outputs, archive path, config path and bbox model.

    Raises:
        RuntimeError: If CUDA is not available.
        FileNotFoundError: If ``image_path`` does not exist.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA GPU is required for the unified SynLayers real-world pipeline.")

    image_path = Path(image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"Input image not found: {image_path}")

    bbox_model = bbox_model or os.environ.get("SYNLAYERS_BBOX_MODEL", DEFAULT_BBOX_MODEL)
    config_path = Path(config_path or os.environ.get("SYNLAYERS_REAL_CONFIG", str(DEFAULT_REAL_CONFIG_PATH)))
    work_dir = Path(work_dir or os.environ.get("SYNLAYERS_DEMO_WORK_DIR", str(DEFAULT_WORK_DIR)))

    normalized_sample_name = slugify(sample_name or image_path.stem)
    # Read the clock once so the seconds and milliseconds parts always refer
    # to the same instant, even across a second boundary.
    now = time.time()
    timestamp = (
        f"{time.strftime('%Y%m%d_%H%M%S', time.localtime(now))}"
        f"_{int((now % 1) * 1000):03d}"
    )
    run_dir = work_dir / f"{timestamp}_{normalized_sample_name}"
    image_dir = run_dir / "layers_real_test_1024"
    prepared_image_path = prepare_input_image(
        image_path,
        image_dir / f"{normalized_sample_name}.png",
        DEFAULT_TARGET_SIZE,
    )

    bbox_model_bundle, bbox_processor = load_bbox_bundle(bbox_model)
    whole_caption, bboxes = infer_caption_bbox(
        prepared_image_path,
        bbox_model_bundle,
        bbox_processor,
        prompt=CAPTION_BBOX_PROMPT_TOP_LEFT,
        max_new_tokens=max_new_tokens,
    )

    # This record is both written to the bbox JSONL and fed directly to the
    # decomposition stage.
    record = {
        "sample_or_stem": normalized_sample_name,
        "image": prepared_image_path.name,
        "whole_caption": whole_caption,
        "bboxes": bboxes,
        "num_layers": len(bboxes),
        "coord": "top_left",
    }

    bbox_jsonl = write_bbox_jsonl(record, run_dir / "caption_bbox_infer.jsonl")
    bbox_vis_path = run_dir / "bbox_vis" / f"{normalized_sample_name}_vis.png"
    draw_boxes(prepared_image_path, bboxes, bbox_vis_path)

    config = build_runtime_config(
        config_path=config_path,
        image_dir=image_dir,
        bbox_jsonl=bbox_jsonl,
        results_root=run_dir / "results",
        run_name=run_name,
        seed=seed,
    )
    pipeline, transp_vae = load_real_bundle(config)
    decomposition_result = save_real_case(
        sample=record,
        config=config,
        pipeline=pipeline,
        transp_vae=transp_vae,
    )

    archive_path = create_archive(run_dir)
    decomposition_result.update(
        {
            "input_image": str(prepared_image_path),
            "bbox_visualization": str(bbox_vis_path),
            "bbox_jsonl": str(bbox_jsonl),
            "bbox_record": record,
            "archive_path": str(archive_path),
            "config_path": str(config_path),
            "bbox_model": bbox_model,
        }
    )
    return decomposition_result
406
+
407
+
408
def main():
    """Command-line entry point: run the pipeline on one image and print the result as JSON."""
    parser = argparse.ArgumentParser(
        description="Run the unified real-world SynLayers pipeline on one image."
    )
    parser.add_argument("--image", type=str, required=True, help="Input image path")
    parser.add_argument("--sample-name", type=str, default=None)
    # These three default to None so that run_real_world_pipeline can apply
    # its own fallback chain (environment variable, then built-in default).
    # A concrete default here would always win and make the SYNLAYERS_*
    # environment overrides unreachable from the CLI.
    parser.add_argument(
        "--work-dir",
        type=str,
        default=None,
        help="Output root; defaults to $SYNLAYERS_DEMO_WORK_DIR, then the built-in demo output directory",
    )
    parser.add_argument(
        "--bbox-model",
        type=str,
        default=None,
        help="Caption/bbox model; defaults to $SYNLAYERS_BBOX_MODEL, then the built-in model",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Decomposition config; defaults to $SYNLAYERS_REAL_CONFIG, then the built-in infer.yaml",
    )
    parser.add_argument("--max-new-tokens", type=int, default=1024)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--run-name", type=str, default=DEFAULT_RUN_NAME)
    args = parser.parse_args()

    result = run_real_world_pipeline(
        args.image,
        sample_name=args.sample_name,
        work_dir=args.work_dir,
        bbox_model=args.bbox_model,
        config_path=args.config,
        max_new_tokens=args.max_new_tokens,
        seed=args.seed,
        run_name=args.run_name,
    )
    print(json.dumps(result, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()