# synlayers / infer/infer.py
# Commit c933092 (verified) by SynLayers: "Upload infer/infer.py with huggingface_hub".
import argparse
import json
import logging
import os
import re
import sys
from pathlib import Path
import numpy as np
import torch
from PIL import Image
# Repository root: the parent of the directory that contains this file.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Prepend it to sys.path so the `infer.*` / `tools.*` imports below resolve
# when this file is launched directly as a script.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
# Drop sub-ERROR records from transformers' tokenization_utils_base logger.
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
# Default to device "0" unless the caller already chose visible devices.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
from infer.common_infer import initialize_pipeline, quantize_box_16, scale_box_xyxy
from tools.tools import load_config, seed_everything
def load_real_metadata(jsonl_path: str):
    """Parse a JSON-Lines file into a list of decoded records.

    Args:
        jsonl_path: Path to a UTF-8 encoded JSONL file.

    Returns:
        One decoded JSON value per non-blank line, in file order. Lines that
        are empty after stripping whitespace are ignored.
    """
    with open(jsonl_path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]
def extract_checkpoint_tag(path: str):
    """Return the text after ``ckpt_prism_`` up to the next ``/``, or None.

    Example: ``/ckpts/ckpt_prism_scaleup_1024_20k/lora`` -> ``scaleup_1024_20k``.
    Empty or None paths, and paths without the marker, yield None.
    """
    found = re.search(r"ckpt_prism_([^/]+)", path) if path else None
    return found.group(1) if found else None
def derive_run_name(config: dict) -> str:
    """Pick the result subfolder name for this run.

    Checkpoint tags are pulled from ``lora_ckpt``, ``layer_ckpt`` and
    ``adapter_lora_dir`` via ``extract_checkpoint_tag``. Any paths that carry
    a tag must agree. A truthy ``run_name`` in ``config`` overrides the
    inferred name. Without tags, the name falls back to ``real_infer``.

    Raises:
        ValueError: If the tagged checkpoint paths disagree. This is checked
            even when ``run_name`` is set.
    """
    tags_by_key = {}
    for key in ("lora_ckpt", "layer_ckpt", "adapter_lora_dir"):
        tag = extract_checkpoint_tag(config.get(key, ""))
        if tag:
            tags_by_key[key] = tag

    distinct_tags = sorted(set(tags_by_key.values()))
    if len(distinct_tags) > 1:
        details = ", ".join(f"{key}={value}" for key, value in tags_by_key.items())
        raise ValueError(
            "Checkpoint paths are inconsistent. "
            "Please switch lora_ckpt, layer_ckpt, and adapter_lora_dir together. "
            f"Current tags: {details}"
        )

    if config.get("run_name"):
        return config["run_name"]
    return distinct_tags[0] if distinct_tags else "real_infer"
def build_run_save_dir(config: dict):
    """Return ``(<save_dir>/<run_name>, run_name)`` for the current run.

    ``save_dir`` defaults to ``./real_inference_output``. ``run_name`` comes
    from ``derive_run_name``, whose ValueError propagates unchanged.
    """
    run_name = derive_run_name(config)
    run_dir = os.path.join(config.get("save_dir", "./real_inference_output"), run_name)
    return run_dir, run_name
def resolve_image_path(sample: dict, data_dir: str, image_dir: str = None) -> str:
    """Locate the input image file for a real-test sample.

    Candidates are tried in this order, and the first existing regular file wins:

    1. ``<image_dir>/<sample_or_stem>.png`` / ``.jpg`` / ``.jpeg``
    2. ``<image_dir>/<basename of sample["image"]>``
    3. ``sample["image"]`` as given
    4. ``<data_dir>/<sample["image"]>`` when that path is relative

    Args:
        sample: Metadata record. Reads the optional ``sample_or_stem`` and
            ``image`` keys.
        data_dir: Dataset root. When ``image_dir`` is None, it also supplies
            the default image directory ``<data_dir>/layers_real_test_1024``.
        image_dir: Explicit local image directory that overrides the default.

    Returns:
        Path of the first candidate that is an existing file.

    Raises:
        FileNotFoundError: If no candidate is an existing file.
    """
    sample_name = sample.get("sample_or_stem", "")
    image_path = sample.get("image", "")
    if image_dir is None and data_dir:
        image_dir = os.path.join(data_dir, "layers_real_test_1024")

    candidates = []
    if image_dir:
        if sample_name:
            candidates.extend(
                os.path.join(image_dir, f"{sample_name}{ext}")
                for ext in (".png", ".jpg", ".jpeg")
            )
        if image_path:
            candidates.append(os.path.join(image_dir, os.path.basename(image_path)))
    if image_path:
        candidates.append(image_path)
        if data_dir and not os.path.isabs(image_path):
            candidates.append(os.path.join(data_dir, image_path))

    # dict.fromkeys de-duplicates while keeping the priority order. Use isfile
    # (not exists): a directory that happens to match a candidate name must be
    # skipped rather than returned and later failing inside Image.open.
    for candidate in dict.fromkeys(candidates):
        if candidate and os.path.isfile(candidate):
            return candidate

    raise FileNotFoundError(
        f"Could not resolve image for sample '{sample_name}'. "
        f"Tried local image_dir='{image_dir}' and json path '{image_path}'."
    )
def quantize_box_16_safe(box: tuple, target_size: int) -> tuple:
    """Snap ``box`` to the 16-pixel grid and never return an empty span.

    The box is first passed through ``quantize_box_16``. If that collapses an
    axis (high <= low), the axis is widened to one 16-px cell. It grows forward
    from the low edge when that stays within ``target_size``. Otherwise it is
    pinned to ``[max(0, target_size - 16), target_size]``.
    """

    def _ensure_cell(low, high):
        # Leave non-degenerate spans untouched.
        if high > low:
            return low, high
        if low + 16 <= target_size:
            return low, low + 16
        return max(0, target_size - 16), target_size

    x0_q, y0_q, x1_q, y1_q = quantize_box_16(box, target_size)
    x0_q, x1_q = _ensure_cell(x0_q, x1_q)
    y0_q, y1_q = _ensure_cell(y0_q, y1_q)
    return (x0_q, y0_q, x1_q, y1_q)
def get_real_boxes(sample: dict, source_size: int, target_size: int) -> list:
    """Rescale ``sample["bboxes"]`` from source to target resolution on the 16-px grid.

    Entries that are not a 4-element list/tuple are silently dropped. The
    rest are rescaled with ``scale_box_xyxy`` and then snapped with
    ``quantize_box_16_safe``, keeping their original order.
    """
    well_formed = (
        raw
        for raw in sample.get("bboxes", [])
        if isinstance(raw, (list, tuple)) and len(raw) == 4
    )
    return [
        quantize_box_16_safe(scale_box_xyxy(raw, source_size, target_size), target_size)
        for raw in well_formed
    ]
def load_adapter_image(sample: dict, target_size: int, config: dict):
    """Load the sample's source image as RGB, resized to a ``target_size`` square.

    The resize goes straight to ``(target_size, target_size)``, so the aspect
    ratio of non-square inputs is not preserved.

    Args:
        sample: Metadata record forwarded to ``resolve_image_path``.
        target_size: Output width and height in pixels.
        config: Reads the optional ``data_dir`` and ``image_dir`` keys.

    Returns:
        ``(image, image_path)``: the RGB PIL image and the resolved file path.

    Raises:
        FileNotFoundError: Propagated from ``resolve_image_path``.
    """
    image_path = resolve_image_path(
        sample,
        data_dir=config.get("data_dir", ""),
        image_dir=config.get("image_dir"),
    )
    # Image.open is lazy and keeps the file open. convert() returns a new,
    # fully loaded copy, so the source handle can be closed right away instead
    # of lingering until garbage collection (one leaked fd per sample otherwise).
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    if img.size != (target_size, target_size):
        img = img.resize((target_size, target_size), Image.LANCZOS)
    return img, image_path
def format_source_image_path(image_path: str, config: dict) -> str:
    """Render ``image_path`` for metadata, relative to a configured root when possible.

    ``image_dir`` is tried before ``data_dir``, and unset or empty roots are
    skipped. The first root that lexically contains the path (via
    ``Path.relative_to``) yields a POSIX-style relative path. If no root
    applies, only the file name is returned.
    """
    target = Path(image_path)
    roots = [config.get(key) for key in ("image_dir", "data_dir")]
    for root in filter(None, roots):
        try:
            relative = target.relative_to(Path(root))
        except ValueError:
            continue
        return relative.as_posix()
    return target.name
def _select_sample_range(config: dict, total_available: int) -> tuple:
    """Return the clamped, 1-based, inclusive ``(start_idx, end_idx)`` window.

    Reads ``start_idx`` (default 1), ``end_idx`` (default: last entry) and
    ``max_samples`` from ``config``. ``max_samples`` only applies when no
    truthy ``end_idx`` is configured. An explicit ``null`` for either index
    is treated like an absent key rather than crashing ``min``.
    """
    start_idx = config.get("start_idx")
    if start_idx is None:
        start_idx = 1
    end_idx = config.get("end_idx")
    if end_idx is None:
        end_idx = total_available
    max_samples = config.get("max_samples")
    if max_samples and not config.get("end_idx"):
        end_idx = min(start_idx + max_samples - 1, total_available)
    start_idx = max(1, min(start_idx, total_available))
    end_idx = max(start_idx, min(end_idx, total_available))
    return start_idx, end_idx


def _to_rgba_image(layer: torch.Tensor) -> "Image.Image":
    """Convert one channels-first layer with values nominally in [0, 1] to an RGBA image.

    Assumes 4 channels, as required by the ``"RGBA"`` mode passed to
    ``Image.fromarray``.
    """
    # Nothing here guarantees the pipeline output stays in range, and
    # astype(np.uint8) does not saturate out-of-range floats (the result is
    # platform-dependent and commonly wraps, e.g. 256.0 -> 0), so clamp first.
    pixels = (layer.clamp(0.0, 1.0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(pixels, "RGBA")


def _save_case_outputs(x_hat, image, adapter_img, case_dir, save_dir, sample_name):
    """Write all image artifacts for one sample.

    Writes the following into ``case_dir``:
    - ``whole_image_rgba.png`` (layer 0)
    - ``background_rgba.png`` (layer 1)
    - ``origin.png`` (the adapter input)
    - ``layer_<k>_rgba.png`` for each remaining layer
    - ``merged.png``

    The composite starts from ``image[1]`` and alpha-composites layers
    2..N-1 over it in order. It is also written to ``<save_dir>/merged``
    (RGB) and ``<save_dir>/merged_rgba``.
    """
    os.makedirs(case_dir, exist_ok=True)
    _to_rgba_image(x_hat[0]).save(os.path.join(case_dir, "whole_image_rgba.png"))
    _to_rgba_image(x_hat[1]).save(os.path.join(case_dir, "background_rgba.png"))
    adapter_img.save(os.path.join(case_dir, "origin.png"))

    merged_image = image[1]
    for layer_idx in range(2, x_hat.shape[0]):
        rgba_image = _to_rgba_image(x_hat[layer_idx])
        rgba_image.save(os.path.join(case_dir, f"layer_{layer_idx - 2}_rgba.png"))
        merged_image = Image.alpha_composite(merged_image.convert("RGBA"), rgba_image)

    merged_rgb = merged_image.convert("RGB")  # convert once, save twice
    merged_rgb.save(os.path.join(save_dir, "merged", f"{sample_name}.png"))
    merged_rgb.save(os.path.join(case_dir, "merged.png"))
    merged_image.save(os.path.join(save_dir, "merged_rgba", f"{sample_name}.png"))


@torch.no_grad()
def inference_real(config):
    """Run layered generation over a slice of the real-test JSONL and save the results.

    Config keys read here, with defaults in parentheses:
    - ``seed``: passed to seed_everything when set; the generator falls back to 42
    - ``source_size`` (1024), ``target_size`` (1024), ``max_layer_num`` (52)
    - ``test_jsonl`` (required)
    - ``start_idx`` / ``end_idx`` / ``max_samples``
    - ``cfg`` (4.0), ``adapter_scale`` (0.9)
    - keys consumed by ``initialize_pipeline``, ``build_run_save_dir`` and
      ``load_adapter_image``

    A sample is printed and skipped when preparing it or running the pipeline
    raises, or when its box count (whole + background + layers) exceeds
    ``max_layer_num``.

    Raises:
        ValueError: If ``test_jsonl`` is unset or does not exist, or if the
            checkpoint tags are inconsistent (raised via ``build_run_save_dir``).
    """
    seed = config.get("seed")
    if seed is not None:
        seed_everything(seed)

    source_size = config.get("source_size", 1024)
    target_size = config.get("target_size", 1024)
    max_layer_num = config.get("max_layer_num", 52)
    print(f"[INFO] Source size: {source_size}, Target size: {target_size}", flush=True)

    save_dir, run_name = build_run_save_dir(config)
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(os.path.join(save_dir, "merged"), exist_ok=True)
    os.makedirs(os.path.join(save_dir, "merged_rgba"), exist_ok=True)
    print(f"[INFO] Run name: {run_name}", flush=True)
    print(f"[INFO] Results will be saved to: {save_dir}", flush=True)

    # Validate and load the metadata before building the pipeline, so a bad
    # path fails fast instead of after model initialization.
    test_jsonl = config.get("test_jsonl", "")
    if not test_jsonl or not os.path.exists(test_jsonl):
        raise ValueError(f"Test JSONL not found: {test_jsonl}")
    all_samples = load_real_metadata(test_jsonl)

    pipeline, transp_vae = initialize_pipeline(config)

    total_available = len(all_samples)
    start_idx, end_idx = _select_sample_range(config, total_available)
    samples = all_samples[start_idx - 1 : end_idx]
    print(f"[INFO] Total samples in dataset: {total_available}", flush=True)
    print(
        f"[INFO] Processing samples {start_idx} to {end_idx} ({len(samples)} samples)",
        flush=True,
    )

    # Fall back to 42 both when `seed` is absent and when it is explicitly
    # null. manual_seed requires an int, so passing None through would crash.
    generator = torch.Generator(device=torch.device("cuda")).manual_seed(
        seed if seed is not None else 42
    )

    for local_idx, sample in enumerate(samples):
        idx_zero_based = start_idx - 1 + local_idx
        sample_name = sample.get("sample_or_stem", f"real_{idx_zero_based:06d}")
        print(
            f"Processing [{local_idx + 1}/{len(samples)}] idx={idx_zero_based} ({sample_name})...",
            flush=True,
        )

        # Best-effort batch: a bad sample is reported and skipped, not fatal.
        try:
            layer_boxes = get_real_boxes(sample, source_size, target_size)
            adapter_img, image_path = load_adapter_image(sample, target_size, config)
        except Exception as e:
            print(f" Error preparing sample: {e}", flush=True)
            continue

        # The first two boxes cover the full canvas: the whole image, then the background.
        whole_box = (0, 0, target_size, target_size)
        bg_box = (0, 0, target_size, target_size)
        all_boxes = [whole_box, bg_box] + layer_boxes
        if len(all_boxes) > max_layer_num:
            print(
                f" Skipping sample because num_layers={len(all_boxes)} exceeds max_layer_num={max_layer_num}",
                flush=True,
            )
            continue

        caption = sample.get("whole_caption", "")
        print(f" Size: {target_size}x{target_size}, Layers: {len(all_boxes)}", flush=True)
        try:
            x_hat, image, _ = pipeline(
                prompt=caption,
                adapter_image=adapter_img,
                adapter_conditioning_scale=config.get("adapter_scale", 0.9),
                validation_box=all_boxes,
                generator=generator,
                height=target_size,
                width=target_size,
                guidance_scale=config.get("cfg", 4.0),
                num_layers=len(all_boxes),
                sdxl_vae=transp_vae,
            )
        except Exception as e:
            print(f" Error during inference: {e}", flush=True)
            continue

        # (x + 1) / 2 maps [-1, 1] to [0, 1]. After squeeze + permute, dim 0
        # indexes layers and each x_hat[i] is channels-first, per how it is
        # indexed in _save_case_outputs.
        x_hat = (x_hat + 1) / 2
        x_hat = x_hat.squeeze(0).permute(1, 0, 2, 3).to(torch.float32)

        case_dir = os.path.join(save_dir, sample_name)
        _save_case_outputs(x_hat, image, adapter_img, case_dir, save_dir, sample_name)

        case_meta = {
            "sample_idx_zero_based": idx_zero_based,
            "sample_idx_one_based": idx_zero_based + 1,
            "sample_name": sample_name,
            "source_image_path": format_source_image_path(image_path, config),
            "target_size": target_size,
            "source_size": source_size,
            "raw_num_layers": sample.get("num_layers"),
            "num_layers": len(all_boxes),
            "raw_boxes": sample.get("bboxes", []),
            "boxes": all_boxes,
            "caption": caption,
            "run_name": run_name,
        }
        with open(os.path.join(case_dir, "inference_meta.json"), "w", encoding="utf-8") as f:
            json.dump(case_meta, f, indent=2)

        if idx_zero_based % 10 == 0:
            torch.cuda.empty_cache()

    print(f"[INFO] Inference complete. Results saved to {save_dir}", flush=True)
    # Drop both model references so empty_cache can actually release their memory.
    del pipeline, transp_vae
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def main():
    """CLI entry point: load the config file, apply index overrides, run inference."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_path", "-c", type=str, required=True,
        help="Path to the YAML configuration file.",
    )
    parser.add_argument(
        "--start_idx", type=int, default=None,
        help="1-based start index for the JSONL entries.",
    )
    parser.add_argument(
        "--end_idx", type=int, default=None,
        help="1-based end index for the JSONL entries (inclusive).",
    )
    args = parser.parse_args()

    config = load_config(args.config_path)
    # Command-line indices win over the config file, but only when given.
    for override in ("start_idx", "end_idx"):
        value = getattr(args, override)
        if value is not None:
            config[override] = value
    inference_real(config)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()