# Unified single-image SynLayers real-world demo pipeline.
#
# Stage 1: a caption/bbox model (loaded via demo.infer.vlm_bbox_inference) predicts a
#          whole-image caption and a list of layer bounding boxes for the input image.
# Stage 2: the SynLayers decomposition pipeline renders one RGBA image per box plus a
#          whole-image and a background layer; everything is written to a per-run
#          directory that is finally zipped into a single result bundle.
from __future__ import annotations

import argparse
import gc
import json
import os
import re
import sys
import time
import zipfile
from pathlib import Path

import numpy as np
import torch
from PIL import Image, ImageOps

PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Make the project packages (demo/, infer/, tools/) importable when run as a script.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from demo.infer.run_caption_bbox_infer import (  # noqa: E402
    CAPTION_BBOX_PROMPT_TOP_LEFT,
    DEFAULT_BBOX_MODEL,
    draw_boxes,
    infer_caption_bbox,
)
from demo.hf_repo_assets import build_repo_asset_overrides, get_stage2_model_repo_id  # noqa: E402
from demo.infer.vlm_bbox_inference import get_model_and_processor  # noqa: E402
from infer.common_infer import initialize_pipeline  # noqa: E402
from infer.infer import build_run_save_dir, get_real_boxes, load_adapter_image  # noqa: E402
from tools.tools import load_config, seed_everything  # noqa: E402

DEFAULT_REAL_CONFIG_PATH = PROJECT_ROOT / "infer" / "infer.yaml"
DEFAULT_WORK_DIR = PROJECT_ROOT / "demo" / "outputs" / "real_world_demo"
DEFAULT_RUN_NAME = "step_120000"
DEFAULT_TARGET_SIZE = 1024
DEFAULT_STAGE2_MODEL_REPO_ID = "SynLayers/synlayers"
# Seed used for the CUDA generator when the config does not provide one.
DEFAULT_GENERATOR_SEED = 42

# Process-wide caches so repeated calls in the same process reuse already-loaded models.
_BBOX_CACHE: dict[str, object] = {"model_path": None, "model": None, "processor": None}
_REAL_CACHE: dict[str, object] = {"key": None, "pipeline": None, "transp_vae": None}

# SYNLAYERS_RELEASE_BBOX_AFTER_CAPTION=1 frees the bbox model before stage 2 is loaded.
RELEASE_BBOX_AFTER_CAPTION = os.environ.get("SYNLAYERS_RELEASE_BBOX_AFTER_CAPTION", "0") == "1"


def slugify(text: str) -> str:
    """Return a filesystem-safe version of ``text``.

    Runs of characters outside ``[A-Za-z0-9._-]`` become a single ``_`` and
    leading/trailing ``.``, ``_`` and ``-`` are stripped. Returns ``"sample"``
    when nothing usable remains (which also rules out names like ``".."``).
    """
    value = re.sub(r"[^A-Za-z0-9._-]+", "_", text).strip("._-")
    return value or "sample"


def resolve_existing_path(*candidates) -> str | None:
    """Return the first candidate that exists on disk, as a string.

    Falsy candidates (``None``, ``""``) are skipped. Returns ``None`` when no
    candidate exists.
    """
    for candidate in candidates:
        if not candidate:
            continue
        path = Path(candidate)
        if path.exists():
            return str(path)
    return None


# Decomposition checkpoint root: $SYNLAYERS_DECOMP_CKPT_ROOT if it exists on disk,
# otherwise the in-repo SynLayers_ckpt/step_120000 path (used even if missing).
DEFAULT_DECOMP_CKPT_ROOT = Path(
    resolve_existing_path(
        os.environ.get("SYNLAYERS_DECOMP_CKPT_ROOT"),
        PROJECT_ROOT / "SynLayers_ckpt" / "step_120000",
    )
    or PROJECT_ROOT / "SynLayers_ckpt" / "step_120000"
)


def prepare_input_image(input_path: str | Path, output_path: Path, size: int) -> Path:
    """Write a ``size`` x ``size`` RGB copy of ``input_path`` to ``output_path``.

    The EXIF orientation tag is applied first so camera photos are processed
    upright. Images that are not already ``size`` x ``size`` are scaled to fit
    (aspect ratio preserved, LANCZOS) and centered on a white canvas.

    Returns:
        ``output_path``.
    """
    # Use a context manager so the source file handle is closed; convert() loads the pixels.
    with Image.open(input_path) as source:
        image = ImageOps.exif_transpose(source).convert("RGB")
    if image.size != (size, size):
        resized = ImageOps.contain(image, (size, size), Image.LANCZOS)
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        offset = ((size - resized.width) // 2, (size - resized.height) // 2)
        canvas.paste(resized, offset)
        image = canvas
    output_path.parent.mkdir(parents=True, exist_ok=True)
    image.save(output_path)
    return output_path


def load_bbox_bundle(model_path: str):
    """Return ``(model, processor)`` for the caption/bbox model at ``model_path``.

    The pair is cached in ``_BBOX_CACHE`` and reused while ``model_path`` is unchanged.
    """
    if _BBOX_CACHE["model_path"] == model_path and _BBOX_CACHE["model"] is not None:
        return _BBOX_CACHE["model"], _BBOX_CACHE["processor"]
    model, processor = get_model_and_processor(model_path)
    _BBOX_CACHE.update({"model_path": model_path, "model": model, "processor": processor})
    return model, processor


def release_bbox_bundle() -> None:
    """Drop the cached caption/bbox model and processor and free their memory.

    Safe to call when nothing is cached.
    """
    _BBOX_CACHE.update({"model_path": None, "model": None, "processor": None})
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _release_real_bundle() -> None:
    """Drop the cached decomposition pipeline / transparency VAE and free their memory."""
    _REAL_CACHE.update({"key": None, "pipeline": None, "transp_vae": None})
    # Collect first so tensors only reachable through reference cycles are freed
    # before empty_cache() returns unused cached blocks to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def load_real_bundle(config: dict):
    """Return ``(pipeline, transp_vae)`` built by ``initialize_pipeline(config)``.

    The pair is cached in ``_REAL_CACHE`` keyed on the config entries that select
    weights (model/adapter/VAE/LoRA paths and ``max_layer_num``). When the key
    changes, the previous pair is released before the new one is built so both
    are never resident at the same time.
    """
    key = (
        config.get("pretrained_model_name_or_path"),
        config.get("pretrained_adapter_path"),
        config.get("transp_vae_path"),
        config.get("pretrained_lora_dir"),
        config.get("artplus_lora_dir"),
        config.get("lora_ckpt"),
        config.get("layer_ckpt"),
        config.get("adapter_lora_dir"),
        config.get("max_layer_num"),
    )
    if _REAL_CACHE["key"] == key and _REAL_CACHE["pipeline"] is not None:
        return _REAL_CACHE["pipeline"], _REAL_CACHE["transp_vae"]
    if _REAL_CACHE["pipeline"] is not None:
        _release_real_bundle()
    pipeline, transp_vae = initialize_pipeline(config)
    _REAL_CACHE.update({"key": key, "pipeline": pipeline, "transp_vae": transp_vae})
    return pipeline, transp_vae


def build_runtime_config(
    *,
    config_path: str | Path,
    image_dir: Path,
    bbox_jsonl: Path,
    results_root: Path,
    run_name: str,
    seed: int | None = None,
) -> dict:
    """Load the config at ``config_path`` and point it at this run's inputs and weights.

    Checkpoint root precedence: ``$SYNLAYERS_DECOMP_CKPT_ROOT`` > Hugging Face repo
    override > ``DEFAULT_DECOMP_CKPT_ROOT``. Each asset path is taken from its
    environment variable (where one exists), then the repo override, then a local
    checkout path if it exists. If none of these yields a value, the config file's
    own value is kept.

    Args:
        config_path: Config file passed to ``load_config``.
        image_dir: Directory holding the prepared input image.
        bbox_jsonl: JSONL file with the stage-1 caption/bbox record.
        results_root: Root directory for decomposition outputs.
        run_name: Stored as ``config["run_name"]``.
        seed: Overrides ``config["seed"]`` when not ``None``.
    """
    config = load_config(str(config_path))
    stage2_model_repo = get_stage2_model_repo_id()
    repo_overrides = build_repo_asset_overrides(stage2_model_repo)
    decomp_ckpt_root = Path(
        os.environ.get("SYNLAYERS_DECOMP_CKPT_ROOT")
        or repo_overrides.get("decomp_ckpt_root")
        or DEFAULT_DECOMP_CKPT_ROOT
    )

    config["data_dir"] = str(image_dir.parent)
    config["image_dir"] = str(image_dir)
    config["test_jsonl"] = str(bbox_jsonl)
    config["save_dir"] = str(results_root)
    config["run_name"] = run_name
    config["lora_ckpt"] = str(decomp_ckpt_root / "transformer")
    config["layer_ckpt"] = str(decomp_ckpt_root)
    config["adapter_lora_dir"] = str(decomp_ckpt_root / "adapter")

    asset_overrides = {
        "pretrained_model_name_or_path": (
            repo_overrides.get("pretrained_model_name_or_path")
            or resolve_existing_path(PROJECT_ROOT / "SynLayers_checkpoints" / "FLUX.1-dev")
        ),
        "pretrained_adapter_path": (
            os.environ.get("SYNLAYERS_ADAPTER_MODEL")
            or repo_overrides.get("pretrained_adapter_path")
            or resolve_existing_path(
                PROJECT_ROOT / "SynLayers_checkpoints" / "FLUX.1-dev-Controlnet-Inpainting-Alpha"
            )
        ),
        "transp_vae_path": (
            os.environ.get("SYNLAYERS_TRANSP_VAE")
            or repo_overrides.get("transp_vae_path")
            or resolve_existing_path(PROJECT_ROOT / "ckpt" / "trans_vae" / "0008000.pt")
        ),
        "pretrained_lora_dir": (
            os.environ.get("SYNLAYERS_PRETRAINED_LORA")
            or repo_overrides.get("pretrained_lora_dir")
            or resolve_existing_path(PROJECT_ROOT / "ckpt" / "pre_trained_LoRA")
        ),
        "artplus_lora_dir": (
            os.environ.get("SYNLAYERS_ARTPLUS_LORA")
            or repo_overrides.get("artplus_lora_dir")
            or resolve_existing_path(PROJECT_ROOT / "ckpt" / "prism_ft_LoRA")
        ),
    }
    for key, value in asset_overrides.items():
        if value:
            config[key] = value
    if seed is not None:
        config["seed"] = seed
    return config


def write_bbox_jsonl(record: dict, output_path: Path) -> Path:
    """Write ``record`` as a single UTF-8 JSON line to ``output_path`` (overwriting it)."""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as handle:
        handle.write(json.dumps(record, ensure_ascii=False) + "\n")
    return output_path


def format_source_image_path(image_path: str, image_dir: Path) -> str:
    """Return ``image_path`` relative to ``image_dir`` (POSIX style), or just its file name."""
    path = Path(image_path)
    try:
        return path.relative_to(image_dir).as_posix()
    except ValueError:
        return path.name


def _tensor_to_rgba_image(layer: torch.Tensor) -> Image.Image:
    """Convert one channels-first 4-channel float layer with values in ``[0, 1]`` to an RGBA image."""
    array = (layer.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(array, "RGBA")


def save_real_case(
    *,
    sample: dict,
    config: dict,
    pipeline,
    transp_vae,
) -> dict:
    """Run the decomposition pipeline on one sample and save every output layer.

    Box list order is ``[whole image, background, *predicted layer boxes]``. Each
    entry produces one RGBA PNG. The merged preview is the pipeline's image at
    index 1 with every predicted layer alpha-composited on top, in box order.

    Args:
        sample: Stage-1 record; reads ``sample_or_stem``, ``whole_caption``,
            ``num_layers`` and ``bboxes`` (the latter via ``get_real_boxes``).
        config: Runtime config from ``build_runtime_config``.
        pipeline: Decomposition pipeline from ``load_real_bundle``.
        transp_vae: Transparency VAE from ``load_real_bundle``.

    Returns:
        Dict of output paths plus the ``metadata`` that is also written to
        ``inference_meta.json``.

    Raises:
        ValueError: If the total number of boxes exceeds ``max_layer_num``.
    """
    seed = config.get("seed")
    if seed is not None:
        seed_everything(seed)
    source_size = config.get("source_size", DEFAULT_TARGET_SIZE)
    target_size = config.get("target_size", DEFAULT_TARGET_SIZE)
    max_layer_num = config.get("max_layer_num", 52)
    sample_name = sample["sample_or_stem"]

    layer_boxes = get_real_boxes(sample, source_size, target_size)
    adapter_img, resolved_image_path = load_adapter_image(sample, target_size, config)
    whole_box = (0, 0, target_size, target_size)
    bg_box = (0, 0, target_size, target_size)
    all_boxes = [whole_box, bg_box] + layer_boxes
    if len(all_boxes) > max_layer_num:
        raise ValueError(
            f"num_layers={len(all_boxes)} exceeds max_layer_num={max_layer_num} for {sample_name}"
        )

    # A config may carry "seed": None; fall back instead of passing None to manual_seed().
    generator_seed = DEFAULT_GENERATOR_SEED if seed is None else seed
    generator = torch.Generator(device=torch.device("cuda")).manual_seed(generator_seed)
    caption = sample.get("whole_caption", "")
    x_hat, image, _ = pipeline(
        prompt=caption,
        adapter_image=adapter_img,
        adapter_conditioning_scale=config.get("adapter_scale", 0.9),
        validation_box=all_boxes,
        generator=generator,
        height=target_size,
        width=target_size,
        guidance_scale=config.get("cfg", 4.0),
        num_layers=len(all_boxes),
        sdxl_vae=transp_vae,
    )
    # Map [-1, 1] to [0, 1] and clamp: out-of-range decoder values would otherwise
    # become garbage pixels in the uint8 cast. After squeeze/permute the first axis
    # indexes layers (assumed pipeline output layout; index 0 whole, 1 background).
    x_hat = ((x_hat + 1) / 2).clamp(0.0, 1.0)
    x_hat = x_hat.squeeze(0).permute(1, 0, 2, 3).to(torch.float32)

    save_dir, resolved_run_name = build_run_save_dir(config)
    save_dir_path = Path(save_dir)
    case_dir = save_dir_path / sample_name
    merged_dir = save_dir_path / "merged"
    merged_rgba_dir = save_dir_path / "merged_rgba"
    for directory in (case_dir, merged_dir, merged_rgba_dir):
        directory.mkdir(parents=True, exist_ok=True)

    whole_rgba_path = case_dir / "whole_image_rgba.png"
    background_rgba_path = case_dir / "background_rgba.png"
    origin_path = case_dir / "origin.png"
    merged_case_path = case_dir / "merged.png"
    merged_global_path = merged_dir / f"{sample_name}.png"
    merged_rgba_path = merged_rgba_dir / f"{sample_name}.png"

    _tensor_to_rgba_image(x_hat[0]).save(whole_rgba_path)
    _tensor_to_rgba_image(x_hat[1]).save(background_rgba_path)
    adapter_img.save(origin_path)

    # Convert once up front so the "merged_rgba" output is RGBA even with zero layers.
    merged_image = image[1].convert("RGBA")
    layer_paths: list[str] = []
    for layer_idx in range(2, x_hat.shape[0]):
        rgba_image = _tensor_to_rgba_image(x_hat[layer_idx])
        layer_path = case_dir / f"layer_{layer_idx - 2}_rgba.png"
        rgba_image.save(layer_path)
        layer_paths.append(str(layer_path))
        merged_image = Image.alpha_composite(merged_image, rgba_image)

    merged_rgb = merged_image.convert("RGB")
    merged_rgb.save(merged_global_path)
    merged_rgb.save(merged_case_path)
    merged_image.save(merged_rgba_path)

    case_meta = {
        "sample_name": sample_name,
        "source_image_path": format_source_image_path(
            resolved_image_path,
            Path(config["image_dir"]),
        ),
        "target_size": target_size,
        "source_size": source_size,
        "raw_num_layers": sample.get("num_layers"),
        "num_layers": len(all_boxes),
        "raw_boxes": sample.get("bboxes", []),
        "boxes": all_boxes,
        "caption": caption,
        "run_name": resolved_run_name,
    }
    meta_path = case_dir / "inference_meta.json"
    with meta_path.open("w", encoding="utf-8") as handle:
        json.dump(case_meta, handle, indent=2, ensure_ascii=False)

    return {
        "run_name": resolved_run_name,
        "save_dir": str(save_dir_path),
        "case_dir": str(case_dir),
        "merged_image": str(merged_case_path),
        "merged_global_image": str(merged_global_path),
        "merged_rgba_image": str(merged_rgba_path),
        "whole_image_rgba": str(whole_rgba_path),
        "background_rgba": str(background_rgba_path),
        "origin_image": str(origin_path),
        "layer_images": layer_paths,
        "metadata_path": str(meta_path),
        "metadata": case_meta,
    }


def create_archive(run_dir: Path) -> Path:
    """Zip every file under ``run_dir`` into ``run_dir/synlayers_result_bundle.zip``.

    Members are stored relative to ``run_dir`` in sorted order, so archive
    contents are deterministic. The archive itself is excluded.
    """
    archive_path = run_dir / "synlayers_result_bundle.zip"
    with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for path in sorted(run_dir.rglob("*")):
            if path == archive_path or path.is_dir():
                continue
            zf.write(path, arcname=path.relative_to(run_dir))
    return archive_path


def _make_run_timestamp() -> str:
    """Return a local-time ``YYYYmmdd_HHMMSS_mmm`` stamp derived from a single clock read."""
    # One clock read keeps the seconds and millisecond parts consistent with each other.
    now = time.time()
    return f"{time.strftime('%Y%m%d_%H%M%S', time.localtime(now))}_{int((now % 1) * 1000):03d}"


def run_real_world_pipeline(
    image_path: str | Path,
    *,
    sample_name: str | None = None,
    work_dir: str | Path | None = None,
    bbox_model: str | None = None,
    config_path: str | Path | None = None,
    max_new_tokens: int = 1024,
    seed: int | None = None,
    run_name: str = DEFAULT_RUN_NAME,
) -> dict:
    """Run caption/bbox prediction and layer decomposition on one image.

    Unset arguments fall back to environment variables, then module defaults:
    ``bbox_model`` -> ``$SYNLAYERS_BBOX_MODEL`` / ``$SYNLAYERS_BBOX_MODEL_REPO`` /
    ``DEFAULT_BBOX_MODEL``; ``config_path`` -> ``$SYNLAYERS_REAL_CONFIG`` /
    ``DEFAULT_REAL_CONFIG_PATH``; ``work_dir`` -> ``$SYNLAYERS_DEMO_WORK_DIR`` /
    ``DEFAULT_WORK_DIR``. Outputs go to a fresh ``<work_dir>/<timestamp>_<name>``
    directory that is zipped at the end.

    Returns:
        The ``save_real_case`` result extended with ``input_image``,
        ``bbox_visualization``, ``bbox_jsonl``, ``bbox_record``, ``archive_path``,
        ``config_path`` and ``bbox_model``.

    Raises:
        RuntimeError: If CUDA is not available.
        FileNotFoundError: If ``image_path`` does not exist.
    """
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA GPU is required for the unified SynLayers real-world pipeline. "
            "On Hugging Face Spaces, assign GPU hardware such as A100 and rebuild the Space."
        )

    image_path = Path(image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"Input image not found: {image_path}")

    bbox_model = (
        bbox_model
        or os.environ.get("SYNLAYERS_BBOX_MODEL")
        or os.environ.get("SYNLAYERS_BBOX_MODEL_REPO")
        or DEFAULT_BBOX_MODEL
    )
    config_path = Path(config_path or os.environ.get("SYNLAYERS_REAL_CONFIG", str(DEFAULT_REAL_CONFIG_PATH)))
    work_dir = Path(work_dir or os.environ.get("SYNLAYERS_DEMO_WORK_DIR", str(DEFAULT_WORK_DIR)))

    normalized_sample_name = slugify(sample_name or image_path.stem)
    run_dir = work_dir / f"{_make_run_timestamp()}_{normalized_sample_name}"
    image_dir = run_dir / "layers_real_test_1024"
    prepared_image_path = prepare_input_image(
        image_path,
        image_dir / f"{normalized_sample_name}.png",
        DEFAULT_TARGET_SIZE,
    )

    # Stage 1: caption + layer boxes (top-left coordinate prompt).
    bbox_model_bundle, bbox_processor = load_bbox_bundle(bbox_model)
    whole_caption, bboxes = infer_caption_bbox(
        prepared_image_path,
        bbox_model_bundle,
        bbox_processor,
        prompt=CAPTION_BBOX_PROMPT_TOP_LEFT,
        max_new_tokens=max_new_tokens,
    )
    record = {
        "sample_or_stem": normalized_sample_name,
        "image": prepared_image_path.name,
        "whole_caption": whole_caption,
        "bboxes": bboxes,
        "num_layers": len(bboxes),
        "coord": "top_left",
    }
    bbox_jsonl = write_bbox_jsonl(record, run_dir / "caption_bbox_infer.jsonl")
    bbox_vis_path = run_dir / "bbox_vis" / f"{normalized_sample_name}_vis.png"
    draw_boxes(prepared_image_path, bboxes, bbox_vis_path)
    if RELEASE_BBOX_AFTER_CAPTION:
        release_bbox_bundle()

    # Stage 2: layer decomposition.
    config = build_runtime_config(
        config_path=config_path,
        image_dir=image_dir,
        bbox_jsonl=bbox_jsonl,
        results_root=run_dir / "results",
        run_name=run_name,
        seed=seed,
    )
    pipeline, transp_vae = load_real_bundle(config)
    decomposition_result = save_real_case(
        sample=record,
        config=config,
        pipeline=pipeline,
        transp_vae=transp_vae,
    )

    archive_path = create_archive(run_dir)
    decomposition_result.update(
        {
            "input_image": str(prepared_image_path),
            "bbox_visualization": str(bbox_vis_path),
            "bbox_jsonl": str(bbox_jsonl),
            "bbox_record": record,
            "archive_path": str(archive_path),
            "config_path": str(config_path),
            "bbox_model": bbox_model,
        }
    )
    return decomposition_result


def main():
    """CLI entry point: run the pipeline on ``--image`` and print the result dict as JSON."""
    parser = argparse.ArgumentParser(
        description="Run the unified real-world SynLayers pipeline on one image."
    )
    parser.add_argument("--image", type=str, required=True, help="Input image path")
    parser.add_argument("--sample-name", type=str, default=None)
    # Path/model options default to None so run_real_world_pipeline can apply its
    # environment-variable fallbacks before the built-in defaults.
    parser.add_argument(
        "--work-dir",
        type=str,
        default=None,
        help=f"Output root (default: $SYNLAYERS_DEMO_WORK_DIR or {DEFAULT_WORK_DIR})",
    )
    parser.add_argument(
        "--bbox-model",
        type=str,
        default=None,
        help=(
            "Caption/bbox model (default: $SYNLAYERS_BBOX_MODEL, "
            f"$SYNLAYERS_BBOX_MODEL_REPO or {DEFAULT_BBOX_MODEL})"
        ),
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help=f"Decomposition config (default: $SYNLAYERS_REAL_CONFIG or {DEFAULT_REAL_CONFIG_PATH})",
    )
    parser.add_argument("--max-new-tokens", type=int, default=1024)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--run-name", type=str, default=DEFAULT_RUN_NAME)
    args = parser.parse_args()

    result = run_real_world_pipeline(
        args.image,
        sample_name=args.sample_name,
        work_dir=args.work_dir,
        bbox_model=args.bbox_model,
        config_path=args.config,
        max_new_tokens=args.max_new_tokens,
        seed=args.seed,
        run_name=args.run_name,
    )
    print(json.dumps(result, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()