# synlayers/tools/dataset.py
import json
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from datasets import load_dataset, concatenate_datasets
import torchvision.transforms as T
from collections import defaultdict
def collate_fn(batch):
    """Stack each sample's list of layer tensors and batch them into [B, L, C, H, W]."""
    pixels_RGBA = [torch.stack(item["pixel_RGBA"]) for item in batch]  # [L, C, H, W]
    pixels_RGB = [torch.stack(item["pixel_RGB"]) for item in batch]  # [L, C, H, W]
    pixels_RGBA = torch.stack(pixels_RGBA)  # [B, L, C, H, W]
    pixels_RGB = torch.stack(pixels_RGB)  # [B, L, C, H, W]
return {
"pixel_RGBA": pixels_RGBA,
"pixel_RGB": pixels_RGB,
"whole_img": [item["whole_img"] for item in batch],
"caption": [item["caption"] for item in batch],
"height": [item["height"] for item in batch],
"width": [item["width"] for item in batch],
"layout": [item["layout"] for item in batch],
}
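
# A minimal usage sketch (an assumption about how this collate function is wired
# up; the actual training script may differ). Note that torch.stack requires every
# sample in a batch to have the same number of layers and the same spatial size,
# so variable layer counts call for batch_size=1 or padding before collating.
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
#   batch = next(iter(loader))
#   batch["pixel_RGBA"].shape  # [1, L, 4, H, W]
#   batch["pixel_RGB"].shape   # [1, L, 3, H, W]
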
class LayoutTrainDataset(Dataset):
    """PrismLayersPro layers with a per-style-category 90/5/5 train/test/val split."""

    def __init__(self, data_dir, split="train"):
full_dataset = load_dataset(
"artplus/PrismLayersPro",
cache_dir=data_dir,
)
full_dataset = concatenate_datasets(list(full_dataset.values()))
if "style_category" not in full_dataset.column_names:
raise ValueError("Dataset must contain a 'style_category' field to split by class.")
categories = np.array(full_dataset["style_category"])
category_to_indices = defaultdict(list)
for i, cat in enumerate(categories):
category_to_indices[cat].append(i)
subsets = []
for cat, indices in category_to_indices.items():
total_len = len(indices)
idx_90 = int(total_len * 0.9)
idx_95 = int(total_len * 0.95)
if split == "train":
selected_idx = indices[:idx_90]
elif split == "test":
selected_idx = indices[idx_90:idx_95]
elif split == "val":
selected_idx = indices[idx_95:]
else:
raise ValueError("split must be 'train', 'val', or 'test'")
subsets.append(full_dataset.select(selected_idx))
self.dataset = concatenate_datasets(subsets)
self.to_tensor = T.ToTensor()
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
item = self.dataset[idx]
def rgba2rgb(img_RGBA):
img_RGB = Image.new("RGB", img_RGBA.size, (128, 128, 128))
img_RGB.paste(img_RGBA, mask=img_RGBA.split()[3])
return img_RGB
def get_img(x):
if isinstance(x, str):
img_RGBA = Image.open(x).convert("RGBA")
img_RGB = rgba2rgb(img_RGBA)
else:
img_RGBA = x.convert("RGBA")
img_RGB = rgba2rgb(img_RGBA)
return img_RGBA, img_RGB
whole_img_RGBA, whole_img_RGB = get_img(item["whole_image"])
whole_cap = item["whole_caption"]
W, H = whole_img_RGBA.size
base_layout = [0, 0, W, H] # xyxy with exclusive end coordinates
layer_image_RGBA = [self.to_tensor(whole_img_RGBA)]
layer_image_RGB = [self.to_tensor(whole_img_RGB)]
layout = [base_layout]
base_img_RGBA, base_img_RGB = get_img(item["base_image"])
layer_image_RGBA.append(self.to_tensor(base_img_RGBA))
layer_image_RGB.append(self.to_tensor(base_img_RGB))
layout.append(base_layout)
        # Remaining layers: paste each layer crop onto a full-size canvas at its xyxy box.
        layer_count = item["layer_count"]
        for i in range(layer_count):
key = f"layer_{i:02d}"
img_RGBA, img_RGB = get_img(item[key])
w0, h0, w1, h1 = item[f"{key}_box"]
canvas_RGBA = Image.new("RGBA", (W, H), (0, 0, 0, 0))
canvas_RGB = Image.new("RGB", (W, H), (128, 128, 128))
W_img, H_img = w1 - w0, h1 - h0
if img_RGBA.size != (W_img, H_img):
img_RGBA = img_RGBA.resize((W_img, H_img), Image.BILINEAR)
img_RGB = img_RGB.resize((W_img, H_img), Image.BILINEAR)
canvas_RGBA.paste(img_RGBA, (w0, h0), img_RGBA)
canvas_RGB.paste(img_RGB, (w0, h0))
layer_image_RGBA.append(self.to_tensor(canvas_RGBA))
layer_image_RGB.append(self.to_tensor(canvas_RGB))
layout.append([w0, h0, w1, h1])
return {
"pixel_RGBA": layer_image_RGBA,
"pixel_RGB": layer_image_RGB,
"whole_img": whole_img_RGB,
"caption": whole_cap,
"height": H,
"width": W,
"layout": layout,
}
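
# Usage sketch for the per-category split above (the cache path and batch size are
# illustrative, not taken from the original training code):
#
#   train_set = LayoutTrainDataset(data_dir="./hf_cache", split="train")
#   val_set = LayoutTrainDataset(data_dir="./hf_cache", split="val")
#   loader = torch.utils.data.DataLoader(
#       train_set, batch_size=1, shuffle=True, collate_fn=collate_fn
#   )
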
class LayoutDatasetFixedSplit(Dataset):
"""
HuggingFace PrismLayersPro with a fixed index-based split.
Total 20,000 samples: train = [0, 19500), test = [19500, 20000).
    For the test split, use start_index and max_samples to select a sub-range:
start_index=200, max_samples=100 -> samples 019700-019799
start_index=0, max_samples=100 -> samples 019500-019599
"""
TRAIN_END = 19500
TOTAL = 20000
def __init__(self, data_dir, split="train", start_index=0, max_samples=None):
full_dataset = load_dataset(
"artplus/PrismLayersPro",
cache_dir=data_dir,
)
full_dataset = concatenate_datasets(list(full_dataset.values()))
if split == "train":
self.dataset = full_dataset.select(range(self.TRAIN_END))
self.global_offset = 0
elif split == "test":
self.dataset = full_dataset.select(range(self.TRAIN_END, self.TOTAL))
self.global_offset = self.TRAIN_END
else:
raise ValueError("split must be 'train' or 'test'")
end_index = len(self.dataset)
if max_samples is not None:
end_index = min(start_index + max_samples, len(self.dataset))
self.dataset = self.dataset.select(range(start_index, end_index))
self.global_offset += start_index
self.to_tensor = T.ToTensor()
print(f"[INFO] LayoutDatasetFixedSplit: split={split}, "
f"global range=[{self.global_offset}, {self.global_offset + len(self.dataset)}), "
f"samples={len(self.dataset)}")
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
item = self.dataset[idx]
def rgba2rgb(img_RGBA):
img_RGB = Image.new("RGB", img_RGBA.size, (128, 128, 128))
img_RGB.paste(img_RGBA, mask=img_RGBA.split()[3])
return img_RGB
def get_img(x):
if isinstance(x, str):
img_RGBA = Image.open(x).convert("RGBA")
else:
img_RGBA = x.convert("RGBA")
return img_RGBA, rgba2rgb(img_RGBA)
whole_img_RGBA, whole_img_RGB = get_img(item["whole_image"])
whole_cap = item["whole_caption"]
W, H = whole_img_RGBA.size
base_layout = [0, 0, W, H]
layer_image_RGBA = [self.to_tensor(whole_img_RGBA)]
layer_image_RGB = [self.to_tensor(whole_img_RGB)]
layout = [base_layout]
base_img_RGBA, base_img_RGB = get_img(item["base_image"])
layer_image_RGBA.append(self.to_tensor(base_img_RGBA))
layer_image_RGB.append(self.to_tensor(base_img_RGB))
layout.append(base_layout)
layer_count = item["layer_count"]
for i in range(layer_count):
key = f"layer_{i:02d}"
img_RGBA, img_RGB = get_img(item[key])
w0, h0, w1, h1 = item[f"{key}_box"]
canvas_RGBA = Image.new("RGBA", (W, H), (0, 0, 0, 0))
canvas_RGB = Image.new("RGB", (W, H), (128, 128, 128))
W_img, H_img = w1 - w0, h1 - h0
if img_RGBA.size != (W_img, H_img):
img_RGBA = img_RGBA.resize((W_img, H_img), Image.BILINEAR)
img_RGB = img_RGB.resize((W_img, H_img), Image.BILINEAR)
canvas_RGBA.paste(img_RGBA, (w0, h0), img_RGBA)
canvas_RGB.paste(img_RGB, (w0, h0))
layer_image_RGBA.append(self.to_tensor(canvas_RGBA))
layer_image_RGB.append(self.to_tensor(canvas_RGB))
layout.append([w0, h0, w1, h1])
return {
"pixel_RGBA": layer_image_RGBA,
"pixel_RGB": layer_image_RGB,
"whole_img": whole_img_RGB,
"caption": whole_cap,
"height": H,
"width": W,
"layout": layout,
}
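
# Usage sketch for the fixed index split (mirrors the docstring examples; the
# data_dir value is a placeholder):
#
#   # global samples 019500-019599, i.e. the first 100 test samples
#   test_set = LayoutDatasetFixedSplit("./hf_cache", split="test", start_index=0, max_samples=100)
#   # global samples 019700-019799
#   test_set = LayoutDatasetFixedSplit("./hf_cache", split="test", start_index=200, max_samples=100)
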
def prism_collate_fn(batch):
"""Collate function for PrismBlendDataset."""
pixels_RGBA = [torch.stack(item["pixel_RGBA"]) for item in batch]
pixels_RGB = [torch.stack(item["pixel_RGB"]) for item in batch]
pixels_RGBA = torch.stack(pixels_RGBA)
pixels_RGB = torch.stack(pixels_RGB)
return {
"pixel_RGBA": pixels_RGBA,
"pixel_RGB": pixels_RGB,
"whole_img": [item["whole_img"] for item in batch],
"caption": [item["caption"] for item in batch],
"height": [item["height"] for item in batch],
"width": [item["width"] for item in batch],
"layout": [item["layout"] for item in batch],
}
class PrismBlendDataset(Dataset):
"""
Dataset for PrismLayersPro blended data.
Loads from local directory structure (following PrismLayersPro convention):
- data_dir/sample_XXXXXX/metadata.json
- data_dir/sample_XXXXXX/whole_image.png
- data_dir/sample_XXXXXX/base_image.png
- data_dir/sample_XXXXXX/layer_00.png, layer_01.png, ...
Boxes are in xyxy format: [x0, y0, x1, y1]
All layer images have transparent backgrounds.
"""
def __init__(self, data_dir: str, jsonl_path: str = None, target_size: int = 512, split: str = "all", max_layer_num: int = None):
self.data_dir = data_dir
self.target_size = target_size
self.max_layer_num = max_layer_num
self.to_tensor = T.ToTensor()
# Load samples
if jsonl_path and os.path.exists(jsonl_path):
self.samples = self._load_from_jsonl(jsonl_path)
else:
self.samples = self._load_from_directory(data_dir)
# Filter samples exceeding max_layer_num (if specified)
# Total layers = 2 (whole_image + base_image) + layer_count
if max_layer_num is not None:
original_count = len(self.samples)
self.samples = [
s for s in self.samples
if (2 + s.get('layer_count', 0)) <= max_layer_num
]
filtered_count = original_count - len(self.samples)
if filtered_count > 0:
print(f"[INFO] Filtered {filtered_count} samples exceeding max_layer_num={max_layer_num}")
# Split dataset (only if explicitly requested, default is "all" = use all samples)
# Usually you have separate train/test datasets, so no splitting needed
if split == "train_split":
self.samples = self.samples[:int(len(self.samples) * 0.9)]
elif split == "test_split":
self.samples = self.samples[int(len(self.samples) * 0.9):int(len(self.samples) * 0.95)]
elif split == "val_split":
self.samples = self.samples[int(len(self.samples) * 0.95):]
# "all", "train", "test" -> use all samples from the provided jsonl/directory
def _load_from_jsonl(self, jsonl_path: str):
"""Load samples from JSONL file."""
samples = []
with open(jsonl_path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
samples.append(json.loads(line))
return samples
def _load_from_directory(self, data_dir: str):
"""Load samples from directory structure."""
samples = []
for name in sorted(os.listdir(data_dir)):
sample_dir = os.path.join(data_dir, name)
if os.path.isdir(sample_dir) and name.startswith('sample_'):
metadata_path = os.path.join(sample_dir, 'metadata.json')
                # metadata_path = os.path.join(sample_dir, 'metadata_old.json')  # old metadata for the original_1024 version
if os.path.exists(metadata_path):
with open(metadata_path, 'r', encoding='utf-8') as f:
samples.append(json.load(f))
return samples
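
    # Illustrative metadata.json contents, inferred from the fields read in
    # __getitem__ and in the max_layer_num filter; all values below are made up:
    #
    #   {
    #     "sample_dir": "sample_000042",
    #     "width": 1024,
    #     "layer_count": 2,
    #     "whole_caption": "a poster with two text layers",
    #     "layers": [
    #       {"image_path": "layer_00.png", "box": [100, 120, 400, 300]},
    #       {"image_path": "layer_01.png", "box": [150, 500, 800, 700]}
    #     ]
    #   }
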
def __len__(self):
return len(self.samples)
def _rgba2rgb(self, img_RGBA):
"""Convert RGBA to RGB with gray background."""
img_RGB = Image.new("RGB", img_RGBA.size, (128, 128, 128))
img_RGB.paste(img_RGBA, mask=img_RGBA.split()[3])
return img_RGB
def _get_sample_dir(self, sample):
"""Get the directory for a sample."""
        # Resolve the directory from the metadata's 'sample_dir' field; return None if it is missing or absent on disk
sample_dir = sample.get('sample_dir', '')
if sample_dir:
full_path = os.path.join(self.data_dir, sample_dir)
if os.path.exists(full_path):
return full_path
return None
def __getitem__(self, idx):
sample = self.samples[idx]
sample_dir = self._get_sample_dir(sample)
if not sample_dir:
raise ValueError(f"Could not find sample directory for index {idx}")
source_size = sample.get('width', self.target_size)
caption = sample.get('whole_caption', '')
# Scale factor (source -> target)
scale = self.target_size / source_size
# Load whole_image (composite)
whole_img_path = os.path.join(sample_dir, 'whole_image.png')
if os.path.exists(whole_img_path):
whole_img = Image.open(whole_img_path).convert('RGBA')
else:
whole_img = Image.new('RGBA', (source_size, source_size), (128, 128, 128, 255))
# Resize if needed
if whole_img.size != (self.target_size, self.target_size):
whole_img = whole_img.resize((self.target_size, self.target_size), Image.LANCZOS)
whole_img_RGB = self._rgba2rgb(whole_img)
# Initialize layer lists with whole_image first
layer_image_RGBA = [self.to_tensor(whole_img)]
layer_image_RGB = [self.to_tensor(whole_img_RGB)]
# Base layout (whole image) in xyxy format [x0, y0, x1, y1]
W, H = self.target_size, self.target_size
base_layout = [0, 0, W, H] # xyxy with exclusive end coordinates
layout = [base_layout]
# Load base_image (background) as second layer
base_img_path = os.path.join(sample_dir, 'base_image.png')
if os.path.exists(base_img_path):
base_img = Image.open(base_img_path).convert('RGBA')
if base_img.size != (self.target_size, self.target_size):
base_img = base_img.resize((self.target_size, self.target_size), Image.LANCZOS)
else:
base_img = Image.new('RGBA', (self.target_size, self.target_size), (0, 0, 0, 0))
base_img_RGB = self._rgba2rgb(base_img)
layer_image_RGBA.append(self.to_tensor(base_img))
layer_image_RGB.append(self.to_tensor(base_img_RGB))
layout.append(base_layout) # background covers whole image
# Load layers from metadata
layers = sample.get('layers', [])
for layer_info in layers:
image_path = layer_info.get('image_path', '')
box = layer_info.get('box', [0, 0, source_size, source_size])
# Scale box (xyxy format)
x0, y0, x1, y1 = box
scaled_box = [
int(x0 * scale),
int(y0 * scale),
int(x1 * scale),
int(y1 * scale)
]
# Load layer image
# Handles two formats:
# 1. Full-canvas (target_size x target_size) — use as-is
# 2. Cropped (smaller than canvas) — place at bbox position on transparent canvas
layer_path = os.path.join(sample_dir, image_path)
if os.path.exists(layer_path):
layer_img = Image.open(layer_path).convert('RGBA')
if layer_img.size == (self.target_size, self.target_size):
# Already full-canvas, use directly
pass
elif layer_img.size == (source_size, source_size) and source_size != self.target_size:
# Full-canvas at source resolution, just resize
layer_img = layer_img.resize((self.target_size, self.target_size), Image.LANCZOS)
else:
# Cropped layer — resize to fit the scaled bbox and place on canvas
bw = max(1, scaled_box[2] - scaled_box[0])
bh = max(1, scaled_box[3] - scaled_box[1])
layer_resized = layer_img.resize((bw, bh), Image.LANCZOS)
layer_img = Image.new('RGBA', (self.target_size, self.target_size), (0, 0, 0, 0))
layer_img.paste(layer_resized, (scaled_box[0], scaled_box[1]), layer_resized)
else:
layer_img = Image.new('RGBA', (self.target_size, self.target_size), (0, 0, 0, 0))
layer_img_RGB = self._rgba2rgb(layer_img)
layer_image_RGBA.append(self.to_tensor(layer_img))
layer_image_RGB.append(self.to_tensor(layer_img_RGB))
layout.append(scaled_box)
return {
"pixel_RGBA": layer_image_RGBA,
"pixel_RGB": layer_image_RGB,
"whole_img": whole_img_RGB,
"caption": caption,
"height": H,
"width": W,
"layout": layout,
}
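

if __name__ == "__main__":
    # Minimal smoke test, assuming a local directory laid out as described in the
    # PrismBlendDataset docstring; the path below is a placeholder.
    from torch.utils.data import DataLoader

    dataset = PrismBlendDataset(data_dir="./prism_blend_data", target_size=512)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=prism_collate_fn)
    batch = next(iter(loader))
    print(batch["pixel_RGBA"].shape)  # [1, 2 + layer_count, 4, 512, 512]
    print(batch["pixel_RGB"].shape)   # [1, 2 + layer_count, 3, 512, 512]
    print(batch["layout"][0])         # per-layer xyxy boxes at target resolution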