# PrismLayersPro layer-composition datasets and collate functions
# (source recovered from a Hugging Face Space page export).
| import json | |
| import os | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from datasets import load_dataset, concatenate_datasets | |
| import torchvision.transforms as T | |
| from collections import defaultdict | |
def collate_fn(batch):
    """Batch LayoutTrainDataset samples.

    Stacks each sample's per-layer image tensors into [L, C, H, W], then
    stacks across the batch into [B, L, C, H, W]; every sample in a batch
    must therefore share the same layer count and spatial size. All other
    fields are passed through as plain Python lists.
    """
    collated = {
        "pixel_RGBA": torch.stack(
            [torch.stack(sample["pixel_RGBA"]) for sample in batch]
        ),
        "pixel_RGB": torch.stack(
            [torch.stack(sample["pixel_RGB"]) for sample in batch]
        ),
    }
    for field in ("whole_img", "caption", "height", "width", "layout"):
        collated[field] = [sample[field] for sample in batch]
    return collated
class LayoutTrainDataset(Dataset):
    """PrismLayersPro samples with a per-style-category 90/5/5 split.

    Each item is a list of layer images: the full composite first, then the
    base (background) image, then every individual layer pasted at its
    bounding-box position onto a full-size canvas. RGB variants composite
    the alpha channel over a mid-gray (128, 128, 128) background.
    """

    def __init__(self, data_dir, split="train"):
        merged = concatenate_datasets(
            list(load_dataset("artplus/PrismLayersPro", cache_dir=data_dir).values())
        )
        if "style_category" not in merged.column_names:
            raise ValueError("Dataset must contain a 'style_category' field to split by class.")
        if split not in ("train", "test", "val"):
            raise ValueError("split must be 'train', 'val', or 'test'")

        # Group row indices by style category so each category is split with
        # the same 90% train / 5% test / 5% val proportions.
        by_category = defaultdict(list)
        for row_idx, category in enumerate(merged["style_category"]):
            by_category[category].append(row_idx)

        parts = []
        for indices in by_category.values():
            cut_90 = int(len(indices) * 0.9)
            cut_95 = int(len(indices) * 0.95)
            if split == "train":
                chosen = indices[:cut_90]
            elif split == "test":
                chosen = indices[cut_90:cut_95]
            else:  # "val"
                chosen = indices[cut_95:]
            parts.append(merged.select(chosen))
        self.dataset = concatenate_datasets(parts)
        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]

        def as_pair(src):
            # Accept a file path or a PIL image; return (RGBA, RGB-on-gray).
            rgba = (Image.open(src) if isinstance(src, str) else src).convert("RGBA")
            rgb = Image.new("RGB", rgba.size, (128, 128, 128))
            rgb.paste(rgba, mask=rgba.split()[3])
            return rgba, rgb

        whole_rgba, whole_rgb = as_pair(record["whole_image"])
        caption = record["whole_caption"]
        W, H = whole_rgba.size
        full_box = [0, 0, W, H]  # xyxy with exclusive end coordinates

        rgba_layers = [self.to_tensor(whole_rgba)]
        rgb_layers = [self.to_tensor(whole_rgb)]
        boxes = [full_box]

        # Base (background) image covers the whole canvas.
        base_rgba, base_rgb = as_pair(record["base_image"])
        rgba_layers.append(self.to_tensor(base_rgba))
        rgb_layers.append(self.to_tensor(base_rgb))
        boxes.append(full_box)

        for layer_idx in range(record["layer_count"]):
            key = f"layer_{layer_idx:02d}"
            layer_rgba, layer_rgb = as_pair(record[key])
            x0, y0, x1, y1 = record[f"{key}_box"]
            box_w, box_h = x1 - x0, y1 - y0
            # Fit the crop to its box before pasting it onto the canvas.
            if layer_rgba.size != (box_w, box_h):
                layer_rgba = layer_rgba.resize((box_w, box_h), Image.BILINEAR)
                layer_rgb = layer_rgb.resize((box_w, box_h), Image.BILINEAR)
            canvas_rgba = Image.new("RGBA", (W, H), (0, 0, 0, 0))
            canvas_rgb = Image.new("RGB", (W, H), (128, 128, 128))
            canvas_rgba.paste(layer_rgba, (x0, y0), layer_rgba)
            canvas_rgb.paste(layer_rgb, (x0, y0))
            rgba_layers.append(self.to_tensor(canvas_rgba))
            rgb_layers.append(self.to_tensor(canvas_rgb))
            boxes.append([x0, y0, x1, y1])

        return {
            "pixel_RGBA": rgba_layers,
            "pixel_RGB": rgb_layers,
            "whole_img": whole_rgb,
            "caption": caption,
            "height": H,
            "width": W,
            "layout": boxes,
        }
class LayoutDatasetFixedSplit(Dataset):
    """
    HuggingFace PrismLayersPro with a fixed index-based split.
    Total 20,000 samples: train = [0, 19500), test = [19500, 20000).
    For the test split, start_index and max_samples select a sub-range:
        start_index=200, max_samples=100 -> samples 019700-019799
        start_index=0, max_samples=100 -> samples 019500-019599
    """

    TRAIN_END = 19500
    TOTAL = 20000

    def __init__(self, data_dir, split="train", start_index=0, max_samples=None):
        merged = concatenate_datasets(
            list(load_dataset("artplus/PrismLayersPro", cache_dir=data_dir).values())
        )
        if split == "train":
            base = merged.select(range(self.TRAIN_END))
            self.global_offset = 0
        elif split == "test":
            base = merged.select(range(self.TRAIN_END, self.TOTAL))
            self.global_offset = self.TRAIN_END
        else:
            raise ValueError("split must be 'train' or 'test'")

        # Optionally narrow to [start_index, start_index + max_samples).
        stop = len(base)
        if max_samples is not None:
            stop = min(start_index + max_samples, len(base))
        self.dataset = base.select(range(start_index, stop))
        self.global_offset += start_index
        self.to_tensor = T.ToTensor()
        print(f"[INFO] LayoutDatasetFixedSplit: split={split}, "
              f"global range=[{self.global_offset}, {self.global_offset + len(self.dataset)}), "
              f"samples={len(self.dataset)}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]

        def as_pair(src):
            # Accept a file path or a PIL image; return (RGBA, RGB-on-gray).
            rgba = (Image.open(src) if isinstance(src, str) else src).convert("RGBA")
            rgb = Image.new("RGB", rgba.size, (128, 128, 128))
            rgb.paste(rgba, mask=rgba.split()[3])
            return rgba, rgb

        whole_rgba, whole_rgb = as_pair(record["whole_image"])
        caption = record["whole_caption"]
        W, H = whole_rgba.size
        full_box = [0, 0, W, H]  # xyxy with exclusive end coordinates

        rgba_layers = [self.to_tensor(whole_rgba)]
        rgb_layers = [self.to_tensor(whole_rgb)]
        boxes = [full_box]

        # Base (background) image covers the whole canvas.
        base_rgba, base_rgb = as_pair(record["base_image"])
        rgba_layers.append(self.to_tensor(base_rgba))
        rgb_layers.append(self.to_tensor(base_rgb))
        boxes.append(full_box)

        for layer_idx in range(record["layer_count"]):
            key = f"layer_{layer_idx:02d}"
            layer_rgba, layer_rgb = as_pair(record[key])
            x0, y0, x1, y1 = record[f"{key}_box"]
            box_w, box_h = x1 - x0, y1 - y0
            # Fit the crop to its box before pasting it onto the canvas.
            if layer_rgba.size != (box_w, box_h):
                layer_rgba = layer_rgba.resize((box_w, box_h), Image.BILINEAR)
                layer_rgb = layer_rgb.resize((box_w, box_h), Image.BILINEAR)
            canvas_rgba = Image.new("RGBA", (W, H), (0, 0, 0, 0))
            canvas_rgb = Image.new("RGB", (W, H), (128, 128, 128))
            canvas_rgba.paste(layer_rgba, (x0, y0), layer_rgba)
            canvas_rgb.paste(layer_rgb, (x0, y0))
            rgba_layers.append(self.to_tensor(canvas_rgba))
            rgb_layers.append(self.to_tensor(canvas_rgb))
            boxes.append([x0, y0, x1, y1])

        return {
            "pixel_RGBA": rgba_layers,
            "pixel_RGB": rgb_layers,
            "whole_img": whole_rgb,
            "caption": caption,
            "height": H,
            "width": W,
            "layout": boxes,
        }
def prism_collate_fn(batch):
    """Collate function for PrismBlendDataset.

    Per-layer image tensors are stacked into [L, C, H, W] per sample and
    [B, L, C, H, W] across the batch (so layer counts and sizes must match
    within a batch); the remaining fields become per-sample lists.
    """
    rgba = [torch.stack(sample["pixel_RGBA"]) for sample in batch]
    rgb = [torch.stack(sample["pixel_RGB"]) for sample in batch]
    out = {
        "pixel_RGBA": torch.stack(rgba),
        "pixel_RGB": torch.stack(rgb),
    }
    out.update(
        (key, [sample[key] for sample in batch])
        for key in ("whole_img", "caption", "height", "width", "layout")
    )
    return out
class PrismBlendDataset(Dataset):
    """
    Dataset for PrismLayersPro blended data.
    Loads from a local directory tree (PrismLayersPro convention):
        - data_dir/sample_XXXXXX/metadata.json
        - data_dir/sample_XXXXXX/whole_image.png
        - data_dir/sample_XXXXXX/base_image.png
        - data_dir/sample_XXXXXX/layer_00.png, layer_01.png, ...
    Boxes are in xyxy format: [x0, y0, x1, y1].
    Layer images have transparent backgrounds.
    """

    def __init__(self, data_dir: str, jsonl_path: str = None, target_size: int = 512, split: str = "all", max_layer_num: int = None):
        self.data_dir = data_dir
        self.target_size = target_size
        self.max_layer_num = max_layer_num
        self.to_tensor = T.ToTensor()

        # Prefer the JSONL index when one is supplied and exists; otherwise
        # scan the directory tree for sample_* folders.
        if jsonl_path and os.path.exists(jsonl_path):
            self.samples = self._load_from_jsonl(jsonl_path)
        else:
            self.samples = self._load_from_directory(data_dir)

        # Drop samples whose total layer count exceeds the cap.
        # Total layers = 2 (whole_image + base_image) + layer_count.
        if max_layer_num is not None:
            before = len(self.samples)
            self.samples = [
                s for s in self.samples
                if 2 + s.get('layer_count', 0) <= max_layer_num
            ]
            dropped = before - len(self.samples)
            if dropped > 0:
                print(f"[INFO] Filtered {dropped} samples exceeding max_layer_num={max_layer_num}")

        # Optional in-place 90/5/5 split; the default ("all") keeps every
        # sample, since train/test usually come from separate jsonl files.
        total = len(self.samples)
        if split == "train_split":
            self.samples = self.samples[:int(total * 0.9)]
        elif split == "test_split":
            self.samples = self.samples[int(total * 0.9):int(total * 0.95)]
        elif split == "val_split":
            self.samples = self.samples[int(total * 0.95):]
        # "all", "train", "test" -> keep everything from the jsonl/directory.

    def _load_from_jsonl(self, jsonl_path: str):
        """Parse one JSON object per non-blank line of the JSONL file."""
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f if line.strip()]

    def _load_from_directory(self, data_dir: str):
        """Collect metadata.json from every sample_* subdirectory (sorted by name)."""
        found = []
        for entry in sorted(os.listdir(data_dir)):
            entry_dir = os.path.join(data_dir, entry)
            if not (entry.startswith('sample_') and os.path.isdir(entry_dir)):
                continue
            meta_path = os.path.join(entry_dir, 'metadata.json')
            if os.path.exists(meta_path):
                with open(meta_path, 'r', encoding='utf-8') as f:
                    found.append(json.load(f))
        return found

    def __len__(self):
        return len(self.samples)

    def _rgba2rgb(self, img_RGBA):
        """Composite an RGBA image over a mid-gray background."""
        flattened = Image.new("RGB", img_RGBA.size, (128, 128, 128))
        flattened.paste(img_RGBA, mask=img_RGBA.split()[3])
        return flattened

    def _get_sample_dir(self, sample):
        """Resolve a sample's on-disk directory; None when it cannot be found."""
        rel = sample.get('sample_dir', '')
        if not rel:
            return None
        candidate = os.path.join(self.data_dir, rel)
        return candidate if os.path.exists(candidate) else None

    def __getitem__(self, idx):
        meta = self.samples[idx]
        sample_dir = self._get_sample_dir(meta)
        if not sample_dir:
            raise ValueError(f"Could not find sample directory for index {idx}")

        tgt = self.target_size
        src = meta.get('width', tgt)  # source canvas edge (assumed square)
        scale = tgt / src             # source -> target coordinate scale
        caption = meta.get('whole_caption', '')

        # Composite (whole) image; opaque gray placeholder when missing.
        whole_path = os.path.join(sample_dir, 'whole_image.png')
        if os.path.exists(whole_path):
            whole_rgba = Image.open(whole_path).convert('RGBA')
        else:
            whole_rgba = Image.new('RGBA', (src, src), (128, 128, 128, 255))
        if whole_rgba.size != (tgt, tgt):
            whole_rgba = whole_rgba.resize((tgt, tgt), Image.LANCZOS)
        whole_rgb = self._rgba2rgb(whole_rgba)

        W, H = tgt, tgt
        full_box = [0, 0, W, H]  # xyxy with exclusive end coordinates
        rgba_layers = [self.to_tensor(whole_rgba)]
        rgb_layers = [self.to_tensor(whole_rgb)]
        boxes = [full_box]

        # Background layer; transparent placeholder when missing.
        base_path = os.path.join(sample_dir, 'base_image.png')
        if os.path.exists(base_path):
            base_rgba = Image.open(base_path).convert('RGBA')
            if base_rgba.size != (tgt, tgt):
                base_rgba = base_rgba.resize((tgt, tgt), Image.LANCZOS)
        else:
            base_rgba = Image.new('RGBA', (tgt, tgt), (0, 0, 0, 0))
        rgba_layers.append(self.to_tensor(base_rgba))
        rgb_layers.append(self.to_tensor(self._rgba2rgb(base_rgba)))
        boxes.append(full_box)  # background covers the whole image

        for layer_info in meta.get('layers', []):
            x0, y0, x1, y1 = layer_info.get('box', [0, 0, src, src])
            scaled_box = [int(x0 * scale), int(y0 * scale),
                          int(x1 * scale), int(y1 * scale)]

            # Layer image may come in two on-disk formats:
            #   1. full-canvas (target or source resolution) -> use / rescale;
            #   2. cropped to its box -> paste onto a transparent canvas.
            layer_path = os.path.join(sample_dir, layer_info.get('image_path', ''))
            if os.path.exists(layer_path):
                layer_rgba = Image.open(layer_path).convert('RGBA')
                if layer_rgba.size == (tgt, tgt):
                    pass  # already full-canvas at target resolution
                elif layer_rgba.size == (src, src) and src != tgt:
                    layer_rgba = layer_rgba.resize((tgt, tgt), Image.LANCZOS)
                else:
                    box_w = max(1, scaled_box[2] - scaled_box[0])
                    box_h = max(1, scaled_box[3] - scaled_box[1])
                    patch = layer_rgba.resize((box_w, box_h), Image.LANCZOS)
                    layer_rgba = Image.new('RGBA', (tgt, tgt), (0, 0, 0, 0))
                    layer_rgba.paste(patch, (scaled_box[0], scaled_box[1]), patch)
            else:
                layer_rgba = Image.new('RGBA', (tgt, tgt), (0, 0, 0, 0))

            rgba_layers.append(self.to_tensor(layer_rgba))
            rgb_layers.append(self.to_tensor(self._rgba2rgb(layer_rgba)))
            boxes.append(scaled_box)

        return {
            "pixel_RGBA": rgba_layers,
            "pixel_RGB": rgb_layers,
            "whole_img": whole_rgb,
            "caption": caption,
            "height": H,
            "width": W,
            "layout": boxes,
        }