# synlayers/demo/real_world_pipeline.py
# Uploaded to the Hugging Face Hub with huggingface_hub (commit b1fab61, verified).
from __future__ import annotations
import argparse
import gc
import json
import os
import re
import sys
import time
import zipfile
from pathlib import Path
import numpy as np
import torch
from PIL import Image, ImageOps
# Repository root: two levels up from this file's resolved location.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Put the repository root at the front of sys.path so the `demo`, `infer` and
# `tools` imports below can be resolved (e.g. when this file is run as a script).
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from demo.infer.run_caption_bbox_infer import ( # noqa: E402
CAPTION_BBOX_PROMPT_TOP_LEFT,
DEFAULT_BBOX_MODEL,
draw_boxes,
infer_caption_bbox,
)
from demo.hf_repo_assets import build_repo_asset_overrides, get_stage2_model_repo_id # noqa: E402
from demo.infer.vlm_bbox_inference import get_model_and_processor # noqa: E402
from infer.common_infer import initialize_pipeline # noqa: E402
from infer.infer import build_run_save_dir, get_real_boxes, load_adapter_image # noqa: E402
from tools.tools import load_config, seed_everything # noqa: E402
# Config file used by run_real_world_pipeline when neither the `config_path`
# argument nor SYNLAYERS_REAL_CONFIG supplies one.
DEFAULT_REAL_CONFIG_PATH = PROJECT_ROOT / "infer" / "infer.yaml"
# Parent directory for per-run output folders when neither the `work_dir`
# argument nor SYNLAYERS_DEMO_WORK_DIR supplies one.
DEFAULT_WORK_DIR = PROJECT_ROOT / "demo" / "outputs" / "real_world_demo"
# Default `run_name` written into the runtime config.
DEFAULT_RUN_NAME = "step_120000"
# Side length (pixels) of the square canvas the input image is fitted onto; also
# the fallback for `source_size` / `target_size` in save_real_case.
DEFAULT_TARGET_SIZE = 1024
# NOTE(review): not referenced in the visible part of this file; presumably the
# default Hub repo for stage-2 assets — confirm against demo.hf_repo_assets.
DEFAULT_STAGE2_MODEL_REPO_ID = "SynLayers/synlayers"
# Module-level cache for the caption/bbox model, managed by load_bbox_bundle and
# release_bbox_bundle. `model_path` records which model the entries belong to.
_BBOX_CACHE: dict[str, object] = {"model_path": None, "model": None, "processor": None}
# Module-level cache for the decomposition pipeline, managed by load_real_bundle.
# `key` is the tuple of config values the cached objects were built from.
_REAL_CACHE: dict[str, object] = {"key": None, "pipeline": None, "transp_vae": None}
# True only when SYNLAYERS_RELEASE_BBOX_AFTER_CAPTION is exactly "1"; when set,
# run_real_world_pipeline releases the bbox model before the decomposition stage.
RELEASE_BBOX_AFTER_CAPTION = os.environ.get("SYNLAYERS_RELEASE_BBOX_AFTER_CAPTION", "0") == "1"
def slugify(text: str) -> str:
    """Turn arbitrary text into a filesystem-friendly name.

    Each run of characters other than ASCII letters, digits, ``.``, ``_`` and
    ``-`` collapses into a single ``_``; leading and trailing ``.``, ``_`` and
    ``-`` are then trimmed. Returns ``"sample"`` when nothing is left.
    """
    collapsed = re.sub(r"[^A-Za-z0-9._-]+", "_", text)
    trimmed = collapsed.strip("._-")
    if not trimmed:
        return "sample"
    return trimmed
def resolve_existing_path(*candidates) -> str | None:
    """Return the first candidate that exists on disk, as a string.

    Falsy candidates (such as ``None`` or ``""``) are ignored. Returns ``None``
    when no candidate exists.
    """
    existing = (
        str(Path(option))
        for option in candidates
        if option and Path(option).exists()
    )
    return next(existing, None)
# Root directory of the decomposition checkpoints, resolved once at import time.
# SYNLAYERS_DECOMP_CKPT_ROOT is used only if that path exists; otherwise the
# bundled PROJECT_ROOT/SynLayers_ckpt/step_120000 location is used, and it is
# also the final fallback when neither path exists on disk.
DEFAULT_DECOMP_CKPT_ROOT = Path(
    resolve_existing_path(
        os.environ.get("SYNLAYERS_DECOMP_CKPT_ROOT"),
        PROJECT_ROOT / "SynLayers_ckpt" / "step_120000",
    )
    or PROJECT_ROOT / "SynLayers_ckpt" / "step_120000"
)
def prepare_input_image(input_path: str | Path, output_path: Path, size: int) -> Path:
    """Normalize an input image to a ``size`` x ``size`` RGB file.

    The image is loaded, rotated according to its EXIF orientation tag (if
    any), and converted to RGB. When it is not already exactly ``size`` x
    ``size`` it is resized to fit inside that square while keeping its aspect
    ratio, then centered on a white canvas.

    Args:
        input_path: Image file to read.
        output_path: Destination file; parent directories are created.
        size: Side length of the square output, in pixels.

    Returns:
        ``output_path``, after the image has been written to it.
    """
    # The context manager releases the underlying file handle deterministically.
    with Image.open(input_path) as source:
        # Camera photos often store rotation only as EXIF metadata; apply it so
        # the pixels (and therefore the predicted boxes) are upright.
        image = ImageOps.exif_transpose(source).convert("RGB")

    if image.size != (size, size):
        resized = ImageOps.contain(image, (size, size), Image.LANCZOS)
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        # Center the resized image; integer division keeps the offset on-pixel.
        offset = ((size - resized.width) // 2, (size - resized.height) // 2)
        canvas.paste(resized, offset)
        image = canvas

    output_path.parent.mkdir(parents=True, exist_ok=True)
    image.save(output_path)
    return output_path
def load_bbox_bundle(model_path: str):
    """Return the caption/bbox model and processor for ``model_path``, cached.

    A cached bundle is reused when it was loaded from the same ``model_path``.
    When a different model is cached, it is released before the new one is
    loaded so that both are not held at the same time (this mirrors what
    ``load_real_bundle`` does for the decomposition pipeline).

    Args:
        model_path: Identifier passed to ``get_model_and_processor``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    cached_model = _BBOX_CACHE["model"]
    if cached_model is not None and _BBOX_CACHE["model_path"] == model_path:
        return cached_model, _BBOX_CACHE["processor"]

    if cached_model is not None:
        # Drop the local reference too, otherwise the stale model stays alive
        # until this function returns.
        del cached_model
        release_bbox_bundle()

    model, processor = get_model_and_processor(model_path)
    _BBOX_CACHE.update(
        {
            "model_path": model_path,
            "model": model,
            "processor": processor,
        }
    )
    return model, processor
def release_bbox_bundle():
    """Forget the cached bbox model and processor, then reclaim memory.

    All three cache entries are reset to ``None``, a garbage collection pass is
    run, and the CUDA allocator cache is emptied when CUDA is available.
    """
    # Clearing the cache slots removes this module's references to the objects.
    for slot in ("model_path", "model", "processor"):
        _BBOX_CACHE[slot] = None

    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def load_real_bundle(config: dict):
    """Return the decomposition pipeline and transparent VAE for ``config``, cached.

    The cache key is the tuple of config values that select model weights. A
    cached bundle is reused when the key matches; otherwise any cached bundle
    is released before ``initialize_pipeline(config)`` builds a new one.

    Args:
        config: Runtime configuration mapping.

    Returns:
        A ``(pipeline, transp_vae)`` tuple.
    """
    key_fields = (
        "pretrained_model_name_or_path",
        "pretrained_adapter_path",
        "transp_vae_path",
        "pretrained_lora_dir",
        "artplus_lora_dir",
        "lora_ckpt",
        "layer_ckpt",
        "adapter_lora_dir",
        "max_layer_num",
    )
    key = tuple(config.get(field) for field in key_fields)

    if _REAL_CACHE["key"] == key and _REAL_CACHE["pipeline"] is not None:
        return _REAL_CACHE["pipeline"], _REAL_CACHE["transp_vae"]

    if _REAL_CACHE["pipeline"] is not None:
        # Clear the key together with the objects so the cache never describes
        # a bundle that is no longer loaded (e.g. if the reload below fails).
        _REAL_CACHE.update({"key": None, "pipeline": None, "transp_vae": None})
        # Collect before emptying the CUDA cache, as release_bbox_bundle does,
        # so objects kept alive only by reference cycles are freed first.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    pipeline, transp_vae = initialize_pipeline(config)
    _REAL_CACHE.update({"key": key, "pipeline": pipeline, "transp_vae": transp_vae})
    return pipeline, transp_vae
def build_runtime_config(
    *,
    config_path: str | Path,
    image_dir: Path,
    bbox_jsonl: Path,
    results_root: Path,
    run_name: str,
    seed: int | None = None,
) -> dict:
    """Load the base config and point it at one run's inputs, outputs and weights.

    Data locations and checkpoint paths are always overwritten. Model asset
    paths are overwritten only when a value can be found, looking in order at
    the environment variable (where one exists), the repo asset overrides, and
    a local path that exists on disk. ``seed`` is written only when given.

    Args:
        config_path: Config file handed to ``load_config``.
        image_dir: Directory holding the prepared input image.
        bbox_jsonl: JSONL file with the caption/bbox record.
        results_root: Directory under which results are saved.
        run_name: Value stored as ``config["run_name"]``.
        seed: Optional seed stored as ``config["seed"]``.

    Returns:
        The updated configuration.
    """
    config = load_config(str(config_path))
    repo_overrides = build_repo_asset_overrides(get_stage2_model_repo_id())

    ckpt_root = Path(
        os.environ.get("SYNLAYERS_DECOMP_CKPT_ROOT")
        or repo_overrides.get("decomp_ckpt_root")
        or DEFAULT_DECOMP_CKPT_ROOT
    )

    fixed_settings = {
        "data_dir": str(image_dir.parent),
        "image_dir": str(image_dir),
        "test_jsonl": str(bbox_jsonl),
        "save_dir": str(results_root),
        "run_name": run_name,
        "lora_ckpt": str(ckpt_root / "transformer"),
        "layer_ckpt": str(ckpt_root),
        "adapter_lora_dir": str(ckpt_root / "adapter"),
    }
    for name, setting in fixed_settings.items():
        config[name] = setting

    checkpoints_dir = PROJECT_ROOT / "SynLayers_checkpoints"
    local_ckpt_dir = PROJECT_ROOT / "ckpt"
    # (config key, environment variable or None, local fallback path)
    asset_sources = (
        ("pretrained_model_name_or_path", None, checkpoints_dir / "FLUX.1-dev"),
        (
            "pretrained_adapter_path",
            "SYNLAYERS_ADAPTER_MODEL",
            checkpoints_dir / "FLUX.1-dev-Controlnet-Inpainting-Alpha",
        ),
        ("transp_vae_path", "SYNLAYERS_TRANSP_VAE", local_ckpt_dir / "trans_vae" / "0008000.pt"),
        ("pretrained_lora_dir", "SYNLAYERS_PRETRAINED_LORA", local_ckpt_dir / "pre_trained_LoRA"),
        ("artplus_lora_dir", "SYNLAYERS_ARTPLUS_LORA", local_ckpt_dir / "prism_ft_LoRA"),
    )
    for name, env_var, local_path in asset_sources:
        from_env = os.environ.get(env_var) if env_var is not None else None
        chosen = (
            from_env
            or repo_overrides.get(name)
            or resolve_existing_path(local_path)
        )
        # Leave the value from the config file alone when nothing was found.
        if chosen:
            config[name] = chosen

    if seed is not None:
        config["seed"] = seed
    return config
def write_bbox_jsonl(record: dict, output_path: Path) -> Path:
    """Write ``record`` as a single-line JSONL file and return its path.

    Parent directories are created as needed and any existing file is
    overwritten. Non-ASCII characters are written as-is in UTF-8.
    """
    target_dir = output_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as sink:
        serialized = json.dumps(record, ensure_ascii=False)
        sink.write(serialized)
        sink.write("\n")
    return output_path
def format_source_image_path(image_path: str, image_dir: Path) -> str:
    """Describe ``image_path`` relative to ``image_dir`` for metadata output.

    Returns the POSIX-style relative path when ``image_path`` lies under
    ``image_dir``; otherwise returns just the file name.
    """
    source = Path(image_path)
    try:
        relative = source.relative_to(image_dir)
    except ValueError:
        # Not located under image_dir: fall back to the bare file name.
        return source.name
    return relative.as_posix()
def _layer_to_rgba_image(layer: torch.Tensor) -> Image.Image:
    """Convert one channels-first layer tensor with values in [0, 1] to an RGBA image."""
    pixels = (layer.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(pixels, "RGBA")


def save_real_case(
    *,
    sample: dict,
    config: dict,
    pipeline,
    transp_vae,
) -> dict:
    """Run layer decomposition for one sample and write all outputs to disk.

    Written under the run's save directory: per-layer RGBA PNGs, the
    whole-image and background RGBA PNGs, a copy of the adapter input, merged
    composites (RGB and RGBA) and an ``inference_meta.json`` file.

    Args:
        sample: Caption/bbox record; reads ``sample_or_stem`` and, optionally,
            ``whole_caption``, ``bboxes`` and ``num_layers``.
        config: Runtime configuration (sizes, seed, scales, output locations).
        pipeline: Callable decomposition pipeline.
        transp_vae: Transparent VAE passed to the pipeline as ``sdxl_vae``.

    Returns:
        A dict of output paths plus the metadata that was written.

    Raises:
        ValueError: If the number of boxes (including the whole-image and
            background boxes) exceeds ``max_layer_num``.
    """
    # Read the seed once: the key may be present but set to None, in which case
    # `config.get("seed", 42)` would hand None to `manual_seed`.
    seed = config.get("seed")
    if seed is not None:
        seed_everything(seed)

    source_size = config.get("source_size", DEFAULT_TARGET_SIZE)
    target_size = config.get("target_size", DEFAULT_TARGET_SIZE)
    max_layer_num = config.get("max_layer_num", 52)
    sample_name = sample["sample_or_stem"]

    layer_boxes = get_real_boxes(sample, source_size, target_size)
    adapter_img, resolved_image_path = load_adapter_image(sample, target_size, config)

    # The first two entries cover the full canvas: whole image, then background.
    whole_box = (0, 0, target_size, target_size)
    bg_box = (0, 0, target_size, target_size)
    all_boxes = [whole_box, bg_box] + layer_boxes
    if len(all_boxes) > max_layer_num:
        raise ValueError(
            f"num_layers={len(all_boxes)} exceeds max_layer_num={max_layer_num} for {sample_name}"
        )

    generator = torch.Generator(device=torch.device("cuda")).manual_seed(
        42 if seed is None else seed
    )
    caption = sample.get("whole_caption", "")
    x_hat, image, _ = pipeline(
        prompt=caption,
        adapter_image=adapter_img,
        adapter_conditioning_scale=config.get("adapter_scale", 0.9),
        validation_box=all_boxes,
        generator=generator,
        height=target_size,
        width=target_size,
        guidance_scale=config.get("cfg", 4.0),
        num_layers=len(all_boxes),
        sdxl_vae=transp_vae,
    )

    # Map from [-1, 1] to [0, 1] and put the layer axis first.
    # NOTE(review): assumes x_hat is (1, channels, layers, H, W) — confirm.
    x_hat = (x_hat + 1) / 2
    x_hat = x_hat.squeeze(0).permute(1, 0, 2, 3).to(torch.float32)
    # Clamp before the uint8 cast: out-of-range values would otherwise wrap
    # around and show up as speckle artifacts.
    x_hat = x_hat.clamp(0.0, 1.0)

    save_dir, resolved_run_name = build_run_save_dir(config)
    save_dir_path = Path(save_dir)
    case_dir = save_dir_path / sample_name
    merged_dir = save_dir_path / "merged"
    merged_rgba_dir = save_dir_path / "merged_rgba"
    for directory in (case_dir, merged_dir, merged_rgba_dir):
        directory.mkdir(parents=True, exist_ok=True)

    whole_rgba_path = case_dir / "whole_image_rgba.png"
    background_rgba_path = case_dir / "background_rgba.png"
    origin_path = case_dir / "origin.png"
    merged_case_path = case_dir / "merged.png"
    merged_global_path = merged_dir / f"{sample_name}.png"
    merged_rgba_path = merged_rgba_dir / f"{sample_name}.png"

    _layer_to_rgba_image(x_hat[0]).save(whole_rgba_path)
    _layer_to_rgba_image(x_hat[1]).save(background_rgba_path)
    adapter_img.save(origin_path)

    # Composite the foreground layers, in order, on top of the second image
    # returned by the pipeline.
    # NOTE(review): presumably image[1] is the background as a PIL image — confirm.
    merged_image = image[1]
    layer_paths: list[str] = []
    for layer_idx in range(2, x_hat.shape[0]):
        rgba_image = _layer_to_rgba_image(x_hat[layer_idx])
        # Foreground layers are numbered from 0 in the file names.
        layer_path = case_dir / f"layer_{layer_idx - 2}_rgba.png"
        rgba_image.save(layer_path)
        layer_paths.append(str(layer_path))
        merged_image = Image.alpha_composite(merged_image.convert("RGBA"), rgba_image)

    merged_rgb = merged_image.convert("RGB")
    merged_rgb.save(merged_global_path)
    merged_rgb.save(merged_case_path)
    merged_image.save(merged_rgba_path)

    case_meta = {
        "sample_name": sample_name,
        "source_image_path": format_source_image_path(
            resolved_image_path,
            Path(config["image_dir"]),
        ),
        "target_size": target_size,
        "source_size": source_size,
        "raw_num_layers": sample.get("num_layers"),
        "num_layers": len(all_boxes),
        "raw_boxes": sample.get("bboxes", []),
        "boxes": all_boxes,
        "caption": caption,
        "run_name": resolved_run_name,
    }
    meta_path = case_dir / "inference_meta.json"
    with meta_path.open("w", encoding="utf-8") as handle:
        # Keep non-ASCII captions readable, matching write_bbox_jsonl.
        json.dump(case_meta, handle, indent=2, ensure_ascii=False)

    return {
        "run_name": resolved_run_name,
        "save_dir": str(save_dir_path),
        "case_dir": str(case_dir),
        "merged_image": str(merged_case_path),
        "merged_global_image": str(merged_global_path),
        "merged_rgba_image": str(merged_rgba_path),
        "whole_image_rgba": str(whole_rgba_path),
        "background_rgba": str(background_rgba_path),
        "origin_image": str(origin_path),
        "layer_images": layer_paths,
        "metadata_path": str(meta_path),
        "metadata": case_meta,
    }
def create_archive(run_dir: Path) -> Path:
    """Zip every file under ``run_dir`` into ``synlayers_result_bundle.zip``.

    The archive is created inside ``run_dir`` and excludes itself. Entries are
    stored with paths relative to ``run_dir``; directories are not added as
    separate entries.

    Returns:
        Path of the archive that was written.
    """
    archive_path = run_dir / "synlayers_result_bundle.zip"
    with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as bundle:
        for entry in run_dir.rglob("*"):
            if entry != archive_path and not entry.is_dir():
                bundle.write(entry, arcname=entry.relative_to(run_dir))
    return archive_path
def run_real_world_pipeline(
    image_path: str | Path,
    *,
    sample_name: str | None = None,
    work_dir: str | Path | None = None,
    bbox_model: str | None = None,
    config_path: str | Path | None = None,
    max_new_tokens: int = 1024,
    seed: int | None = None,
    run_name: str = DEFAULT_RUN_NAME,
) -> dict:
    """Caption an image, predict layer boxes, decompose it, and bundle the results.

    Steps: prepare a square copy of the input, run caption/bbox inference,
    write the record and a box visualization, build the runtime config, run
    the decomposition, and zip the run directory.

    Args:
        image_path: Input image file.
        sample_name: Optional name for the sample; defaults to the image stem.
        work_dir: Parent directory for the run folder. Falls back to
            ``SYNLAYERS_DEMO_WORK_DIR`` and then ``DEFAULT_WORK_DIR``.
        bbox_model: Caption/bbox model. Falls back to ``SYNLAYERS_BBOX_MODEL``,
            ``SYNLAYERS_BBOX_MODEL_REPO`` and then ``DEFAULT_BBOX_MODEL``.
        config_path: Decomposition config. Falls back to
            ``SYNLAYERS_REAL_CONFIG`` and then ``DEFAULT_REAL_CONFIG_PATH``.
        max_new_tokens: Generation budget for caption/bbox inference.
        seed: Optional seed written into the runtime config.
        run_name: Run name written into the runtime config.

    Returns:
        The dict produced by ``save_real_case`` extended with the input image,
        bbox artifacts, archive path, config path and bbox model.

    Raises:
        RuntimeError: If CUDA is not available.
        FileNotFoundError: If ``image_path`` does not exist.
    """
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA GPU is required for the unified SynLayers real-world pipeline. "
            "On Hugging Face Spaces, assign GPU hardware such as A100 and rebuild the Space."
        )

    source_image = Path(image_path)
    if not source_image.exists():
        raise FileNotFoundError(f"Input image not found: {source_image}")

    # Explicit arguments win over environment variables, which win over defaults.
    env = os.environ
    resolved_bbox_model = (
        bbox_model
        or env.get("SYNLAYERS_BBOX_MODEL")
        or env.get("SYNLAYERS_BBOX_MODEL_REPO")
        or DEFAULT_BBOX_MODEL
    )
    resolved_config_path = Path(
        config_path or env.get("SYNLAYERS_REAL_CONFIG", str(DEFAULT_REAL_CONFIG_PATH))
    )
    resolved_work_dir = Path(
        work_dir or env.get("SYNLAYERS_DEMO_WORK_DIR", str(DEFAULT_WORK_DIR))
    )

    stem = slugify(sample_name or source_image.stem)
    # Second-resolution timestamp plus milliseconds, followed by the sample name.
    clock = time.strftime("%Y%m%d_%H%M%S")
    millis = int((time.time() % 1) * 1000)
    run_dir = resolved_work_dir / f"{clock}_{millis:03d}_{stem}"
    image_dir = run_dir / "layers_real_test_1024"

    prepared_image = prepare_input_image(
        source_image,
        image_dir / f"{stem}.png",
        DEFAULT_TARGET_SIZE,
    )

    # Stage 1: caption and bounding boxes.
    vlm_model, vlm_processor = load_bbox_bundle(resolved_bbox_model)
    whole_caption, bboxes = infer_caption_bbox(
        prepared_image,
        vlm_model,
        vlm_processor,
        prompt=CAPTION_BBOX_PROMPT_TOP_LEFT,
        max_new_tokens=max_new_tokens,
    )
    record = {
        "sample_or_stem": stem,
        "image": prepared_image.name,
        "whole_caption": whole_caption,
        "bboxes": bboxes,
        "num_layers": len(bboxes),
        "coord": "top_left",
    }
    bbox_jsonl = write_bbox_jsonl(record, run_dir / "caption_bbox_infer.jsonl")
    bbox_vis_path = run_dir / "bbox_vis" / f"{stem}_vis.png"
    draw_boxes(prepared_image, bboxes, bbox_vis_path)

    if RELEASE_BBOX_AFTER_CAPTION:
        release_bbox_bundle()

    # Stage 2: layer decomposition.
    runtime_config = build_runtime_config(
        config_path=resolved_config_path,
        image_dir=image_dir,
        bbox_jsonl=bbox_jsonl,
        results_root=run_dir / "results",
        run_name=run_name,
        seed=seed,
    )
    pipeline, transp_vae = load_real_bundle(runtime_config)
    outcome = save_real_case(
        sample=record,
        config=runtime_config,
        pipeline=pipeline,
        transp_vae=transp_vae,
    )

    archive_path = create_archive(run_dir)
    outcome.update(
        input_image=str(prepared_image),
        bbox_visualization=str(bbox_vis_path),
        bbox_jsonl=str(bbox_jsonl),
        bbox_record=record,
        archive_path=str(archive_path),
        config_path=str(resolved_config_path),
        bbox_model=resolved_bbox_model,
    )
    return outcome
def main():
    """Command-line entry point: run the pipeline on one image and print the result.

    ``--work-dir``, ``--bbox-model`` and ``--config`` default to ``None`` so
    that, when they are omitted, ``run_real_world_pipeline`` can apply its own
    fallbacks (the ``SYNLAYERS_*`` environment variables, then the built-in
    defaults). Defaulting them to the built-in values here would always
    shadow the environment variables.
    """
    parser = argparse.ArgumentParser(
        description="Run the unified real-world SynLayers pipeline on one image."
    )
    parser.add_argument("--image", type=str, required=True, help="Input image path")
    parser.add_argument("--sample-name", type=str, default=None)
    parser.add_argument(
        "--work-dir",
        type=str,
        default=None,
        help="Output parent directory (default: SYNLAYERS_DEMO_WORK_DIR, then the built-in path)",
    )
    parser.add_argument(
        "--bbox-model",
        type=str,
        default=None,
        help=(
            "Caption/bbox model (default: SYNLAYERS_BBOX_MODEL, "
            "SYNLAYERS_BBOX_MODEL_REPO, then the built-in model)"
        ),
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Decomposition config file (default: SYNLAYERS_REAL_CONFIG, then the built-in path)",
    )
    parser.add_argument("--max-new-tokens", type=int, default=1024)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--run-name", type=str, default=DEFAULT_RUN_NAME)
    args = parser.parse_args()

    result = run_real_world_pipeline(
        args.image,
        sample_name=args.sample_name,
        work_dir=args.work_dir,
        bbox_model=args.bbox_model,
        config_path=args.config,
        max_new_tokens=args.max_new_tokens,
        seed=args.seed,
        run_name=args.run_name,
    )
    print(json.dumps(result, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()