SynLayers committed on
Commit
6cbd779
·
verified ·
1 Parent(s): e1b8cea

Upload tools/dataset.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. tools/dataset.py +447 -0
tools/dataset.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ from torch.utils.data import Dataset
7
+ from datasets import load_dataset, concatenate_datasets
8
+ import torchvision.transforms as T
9
+ from collections import defaultdict
10
+
11
+
12
def collate_fn(batch):
    """Merge a list of dataset samples into a single batch dictionary.

    The per-layer tensor lists under ``pixel_RGBA`` and ``pixel_RGB`` are
    stacked per sample into ``[L, C, H, W]`` and then across samples into
    ``[B, L, C, H, W]``. Every other field is gathered into a plain Python
    list that follows the order of ``batch``.
    """
    tensor_fields = ("pixel_RGBA", "pixel_RGB")
    list_fields = ("whole_img", "caption", "height", "width", "layout")

    # First pass: one [L, C, H, W] tensor per sample, for both pixel fields.
    per_sample = {
        name: [torch.stack(sample[name]) for sample in batch]
        for name in tensor_fields
    }

    # Second pass: join the samples along a new leading batch axis.
    collated = {name: torch.stack(per_sample[name]) for name in tensor_fields}
    for name in list_fields:
        collated[name] = [sample[name] for sample in batch]
    return collated
27
+
28
class LayoutTrainDataset(Dataset):
    """PrismLayersPro from the HuggingFace hub with a per-style split.

    All hub splits are concatenated, rows are grouped by ``style_category``
    and every group is cut, in dataset order, into train (first 90%), test
    (next 5%) and val (the remainder).

    Each item is a dict with:
        pixel_RGBA / pixel_RGB: lists of per-layer tensors, in the order
            whole image, base image, then every layer placed on a canvas of
            the whole image's size.
        whole_img: the whole image as an RGB PIL image on a gray background.
        caption: the value of ``whole_caption``.
        height / width: size of the whole image.
        layout: one ``[x0, y0, x1, y1]`` box per entry of the pixel lists.
    """

    def __init__(self, data_dir, split="train"):
        """Load the dataset and keep only the rows belonging to ``split``.

        Args:
            data_dir: Cache directory handed to ``load_dataset``.
            split: ``"train"``, ``"test"`` or ``"val"``.

        Raises:
            ValueError: If ``split`` is not one of the three names, or the
                dataset has no ``style_category`` column.
        """
        # Reject a bad split before the (expensive) hub load; the original
        # check lived inside the per-category loop and so ran late, and not
        # at all when there were no categories.
        if split not in ("train", "test", "val"):
            raise ValueError("split must be 'train', 'val', or 'test'")

        full_dataset = load_dataset(
            "artplus/PrismLayersPro",
            cache_dir=data_dir,
        )
        full_dataset = concatenate_datasets(list(full_dataset.values()))

        if "style_category" not in full_dataset.column_names:
            raise ValueError("Dataset must contain a 'style_category' field to split by class.")

        # Group row indices by style so each style is split independently.
        categories = np.array(full_dataset["style_category"])
        category_to_indices = defaultdict(list)
        for i, cat in enumerate(categories):
            category_to_indices[cat].append(i)

        subsets = [
            full_dataset.select(self._split_indices(indices, split))
            for indices in category_to_indices.values()
        ]

        self.dataset = concatenate_datasets(subsets)
        self.to_tensor = T.ToTensor()

    @staticmethod
    def _split_indices(indices, split):
        """Return the slice of ``indices`` that belongs to ``split``.

        The cut points are ``int(len * 0.9)`` and ``int(len * 0.95)``, so very
        small groups can yield an empty test or val slice.

        Raises:
            ValueError: If ``split`` is not ``"train"``, ``"test"`` or ``"val"``.
        """
        total_len = len(indices)
        idx_90 = int(total_len * 0.9)
        idx_95 = int(total_len * 0.95)

        if split == "train":
            return indices[:idx_90]
        if split == "test":
            return indices[idx_90:idx_95]
        if split == "val":
            return indices[idx_95:]
        raise ValueError("split must be 'train', 'val', or 'test'")

    @staticmethod
    def _rgba2rgb(img_RGBA):
        """Composite an RGBA image over a gray (128, 128, 128) background."""
        img_RGB = Image.new("RGB", img_RGBA.size, (128, 128, 128))
        img_RGB.paste(img_RGBA, mask=img_RGBA.split()[3])
        return img_RGB

    @classmethod
    def _get_img(cls, x):
        """Return ``(RGBA, RGB-on-gray)`` for a file path or an image object."""
        source = Image.open(x) if isinstance(x, str) else x
        img_RGBA = source.convert("RGBA")
        return img_RGBA, cls._rgba2rgb(img_RGBA)

    def __len__(self):
        """Number of rows in the selected split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build the layer tensors and layout boxes for row ``idx``."""
        item = self.dataset[idx]

        whole_img_RGBA, whole_img_RGB = self._get_img(item["whole_image"])
        whole_cap = item["whole_caption"]
        W, H = whole_img_RGBA.size
        base_layout = [0, 0, W, H]  # xyxy with exclusive end coordinates

        # Entry 0 is the whole image, entry 1 the base image; both span the
        # full canvas and therefore share the same layout box.
        layer_image_RGBA = [self.to_tensor(whole_img_RGBA)]
        layer_image_RGB = [self.to_tensor(whole_img_RGB)]
        layout = [base_layout]

        base_img_RGBA, base_img_RGB = self._get_img(item["base_image"])
        layer_image_RGBA.append(self.to_tensor(base_img_RGBA))
        layer_image_RGB.append(self.to_tensor(base_img_RGB))
        layout.append(base_layout)

        layer_count = item["layer_count"]
        for i in range(layer_count):
            key = f"layer_{i:02d}"
            img_RGBA, img_RGB = self._get_img(item[key])

            w0, h0, w1, h1 = item[f"{key}_box"]

            canvas_RGBA = Image.new("RGBA", (W, H), (0, 0, 0, 0))
            canvas_RGB = Image.new("RGB", (W, H), (128, 128, 128))

            # Make the layer match its box before placing it on the canvas.
            W_img, H_img = w1 - w0, h1 - h0
            if img_RGBA.size != (W_img, H_img):
                img_RGBA = img_RGBA.resize((W_img, H_img), Image.BILINEAR)
                img_RGB = img_RGB.resize((W_img, H_img), Image.BILINEAR)

            # NOTE(review): pasting an RGBA image with itself as the mask
            # blends every band (alpha included) by the source alpha, so
            # partially transparent pixels land on the empty canvas with
            # attenuated colour and alpha. Kept as-is to preserve existing
            # outputs; confirm whether Image.alpha_composite was intended.
            canvas_RGBA.paste(img_RGBA, (w0, h0), img_RGBA)
            canvas_RGB.paste(img_RGB, (w0, h0))

            layer_image_RGBA.append(self.to_tensor(canvas_RGBA))
            layer_image_RGB.append(self.to_tensor(canvas_RGB))
            layout.append([w0, h0, w1, h1])

        return {
            "pixel_RGBA": layer_image_RGBA,
            "pixel_RGB": layer_image_RGB,
            "whole_img": whole_img_RGB,
            "caption": whole_cap,
            "height": H,
            "width": W,
            "layout": layout,
        }
129
+
130
+
131
class LayoutDatasetFixedSplit(Dataset):
    """
    HuggingFace PrismLayersPro with a fixed index-based split.
    Total 20,000 samples: train = [0, 19500), test = [19500, 20000).

    Use start_index and max_samples to select a sub-range of the split:
        start_index=200, max_samples=100  -> test samples 019700-019799
        start_index=0,   max_samples=100  -> test samples 019500-019599
        start_index=200, max_samples=None -> test samples 019700-019999

    ``global_offset`` holds the index, in the concatenated dataset, of the
    first sample that is kept.
    """

    TRAIN_END = 19500
    TOTAL = 20000

    def __init__(self, data_dir, split="train", start_index=0, max_samples=None):
        """Load the dataset and keep the requested window of ``split``.

        Args:
            data_dir: Cache directory handed to ``load_dataset``.
            split: ``"train"`` or ``"test"``.
            start_index: Offset of the first kept sample inside the split.
            max_samples: Maximum number of samples to keep; ``None`` keeps
                everything from ``start_index`` to the end of the split.

        Raises:
            ValueError: If ``split`` is unknown, or ``start_index`` /
                ``max_samples`` is negative.
        """
        # Fail before the hub load when the split name is wrong.
        if split not in ("train", "test"):
            raise ValueError("split must be 'train' or 'test'")

        full_dataset = load_dataset(
            "artplus/PrismLayersPro",
            cache_dir=data_dir,
        )
        full_dataset = concatenate_datasets(list(full_dataset.values()))

        if split == "train":
            self.dataset = full_dataset.select(range(self.TRAIN_END))
            self.global_offset = 0
        else:
            self.dataset = full_dataset.select(range(self.TRAIN_END, self.TOTAL))
            self.global_offset = self.TRAIN_END

        # start_index used to be honoured only when max_samples was given;
        # the window is now applied in every case.
        split_len = len(self.dataset)
        start, end = self._sub_range(split_len, start_index, max_samples)
        if (start, end) != (0, split_len):
            self.dataset = self.dataset.select(range(start, end))
        self.global_offset += start

        self.to_tensor = T.ToTensor()
        print(f"[INFO] LayoutDatasetFixedSplit: split={split}, "
              f"global range=[{self.global_offset}, {self.global_offset + len(self.dataset)}), "
              f"samples={len(self.dataset)}")

    @staticmethod
    def _sub_range(total, start_index, max_samples):
        """Return the ``(start, end)`` window inside a split of ``total`` rows.

        Both bounds are clamped to ``total`` so an oversized request yields a
        shorter (possibly empty) window instead of an out-of-range index.

        Raises:
            ValueError: If ``start_index`` or ``max_samples`` is negative.
        """
        if start_index < 0:
            raise ValueError("start_index must be non-negative")
        if max_samples is not None and max_samples < 0:
            raise ValueError("max_samples must be non-negative")

        start = min(start_index, total)
        end = total if max_samples is None else min(start + max_samples, total)
        return start, end

    @staticmethod
    def _rgba2rgb(img_RGBA):
        """Composite an RGBA image over a gray (128, 128, 128) background."""
        img_RGB = Image.new("RGB", img_RGBA.size, (128, 128, 128))
        img_RGB.paste(img_RGBA, mask=img_RGBA.split()[3])
        return img_RGB

    @classmethod
    def _get_img(cls, x):
        """Return ``(RGBA, RGB-on-gray)`` for a file path or an image object."""
        source = Image.open(x) if isinstance(x, str) else x
        img_RGBA = source.convert("RGBA")
        return img_RGBA, cls._rgba2rgb(img_RGBA)

    def __len__(self):
        """Number of samples in the selected window."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build the layer tensors and layout boxes for sample ``idx``."""
        item = self.dataset[idx]

        whole_img_RGBA, whole_img_RGB = self._get_img(item["whole_image"])
        whole_cap = item["whole_caption"]
        W, H = whole_img_RGBA.size
        base_layout = [0, 0, W, H]

        # Entry 0 is the whole image, entry 1 the base image; both span the
        # full canvas and therefore share the same layout box.
        layer_image_RGBA = [self.to_tensor(whole_img_RGBA)]
        layer_image_RGB = [self.to_tensor(whole_img_RGB)]
        layout = [base_layout]

        base_img_RGBA, base_img_RGB = self._get_img(item["base_image"])
        layer_image_RGBA.append(self.to_tensor(base_img_RGBA))
        layer_image_RGB.append(self.to_tensor(base_img_RGB))
        layout.append(base_layout)

        layer_count = item["layer_count"]
        for i in range(layer_count):
            key = f"layer_{i:02d}"
            img_RGBA, img_RGB = self._get_img(item[key])

            w0, h0, w1, h1 = item[f"{key}_box"]

            canvas_RGBA = Image.new("RGBA", (W, H), (0, 0, 0, 0))
            canvas_RGB = Image.new("RGB", (W, H), (128, 128, 128))

            # Make the layer match its box before placing it on the canvas.
            W_img, H_img = w1 - w0, h1 - h0
            if img_RGBA.size != (W_img, H_img):
                img_RGBA = img_RGBA.resize((W_img, H_img), Image.BILINEAR)
                img_RGB = img_RGB.resize((W_img, H_img), Image.BILINEAR)

            # NOTE(review): pasting an RGBA image with itself as the mask
            # blends every band (alpha included) by the source alpha, so
            # partially transparent pixels land on the empty canvas with
            # attenuated colour and alpha. Kept as-is to preserve existing
            # outputs; confirm whether Image.alpha_composite was intended.
            canvas_RGBA.paste(img_RGBA, (w0, h0), img_RGBA)
            canvas_RGB.paste(img_RGB, (w0, h0))

            layer_image_RGBA.append(self.to_tensor(canvas_RGBA))
            layer_image_RGB.append(self.to_tensor(canvas_RGB))
            layout.append([w0, h0, w1, h1])

        return {
            "pixel_RGBA": layer_image_RGBA,
            "pixel_RGB": layer_image_RGB,
            "whole_img": whole_img_RGB,
            "caption": whole_cap,
            "height": H,
            "width": W,
            "layout": layout,
        }
234
+
235
+
236
def prism_collate_fn(batch):
    """Collate function for PrismBlendDataset.

    Stacks each sample's layer tensors, then stacks the samples, for both
    ``pixel_RGBA`` and ``pixel_RGB``; the remaining fields are returned as
    lists in batch order.
    """
    # Per-sample layer stacks first (RGBA, then RGB), batch stacks afterwards.
    rgba_stacks = list(map(torch.stack, (entry["pixel_RGBA"] for entry in batch)))
    rgb_stacks = list(map(torch.stack, (entry["pixel_RGB"] for entry in batch)))

    result = dict(
        pixel_RGBA=torch.stack(rgba_stacks),
        pixel_RGB=torch.stack(rgb_stacks),
    )

    # Non-tensor metadata is passed through untouched, one entry per sample.
    passthrough = ("whole_img", "caption", "height", "width", "layout")
    result.update({key: [entry[key] for entry in batch] for key in passthrough})
    return result
252
+
253
+
254
class PrismBlendDataset(Dataset):
    """
    Dataset for PrismLayersPro blended data.

    Loads from local directory structure (following PrismLayersPro convention):
    - data_dir/sample_XXXXXX/metadata.json
    - data_dir/sample_XXXXXX/whole_image.png
    - data_dir/sample_XXXXXX/base_image.png
    - data_dir/sample_XXXXXX/layer_00.png, layer_01.png, ...

    Boxes are in xyxy format: [x0, y0, x1, y1]
    All layer images have transparent backgrounds.

    Every image is brought to a square ``target_size`` canvas; boxes are
    scaled by ``target_size / width`` where ``width`` comes from the sample
    metadata.
    """

    def __init__(self, data_dir: str, jsonl_path: str = None, target_size: int = 512, split: str = "all", max_layer_num: int = None):
        """Collect the sample metadata and apply the optional filters.

        Args:
            data_dir: Root directory that contains the ``sample_*`` folders.
            jsonl_path: Optional JSONL file with one metadata object per
                line. Used when given and present on disk; otherwise
                ``data_dir`` is scanned.
            target_size: Side length of the square output canvas.
            split: ``"train_split"`` / ``"test_split"`` / ``"val_split"``
                keep the first 90% / next 5% / last 5% of the samples; any
                other value keeps all of them.
            max_layer_num: When set, drop samples whose ``2 + layer_count``
                exceeds it.
        """
        self.data_dir = data_dir
        self.target_size = target_size
        self.max_layer_num = max_layer_num
        self.to_tensor = T.ToTensor()

        # Load samples
        if jsonl_path and os.path.exists(jsonl_path):
            self.samples = self._load_from_jsonl(jsonl_path)
        else:
            self.samples = self._load_from_directory(data_dir)

        # Filter samples exceeding max_layer_num (if specified)
        # Total layers = 2 (whole_image + base_image) + layer_count
        if max_layer_num is not None:
            original_count = len(self.samples)
            self.samples = [
                s for s in self.samples
                if (2 + s.get('layer_count', 0)) <= max_layer_num
            ]
            filtered_count = original_count - len(self.samples)
            if filtered_count > 0:
                print(f"[INFO] Filtered {filtered_count} samples exceeding max_layer_num={max_layer_num}")

        # Split dataset (only if explicitly requested, default is "all" = use all samples)
        # Usually you have separate train/test datasets, so no splitting needed
        total = len(self.samples)
        cut_90 = int(total * 0.9)
        cut_95 = int(total * 0.95)
        if split == "train_split":
            self.samples = self.samples[:cut_90]
        elif split == "test_split":
            self.samples = self.samples[cut_90:cut_95]
        elif split == "val_split":
            self.samples = self.samples[cut_95:]
        # "all", "train", "test" -> use all samples from the provided jsonl/directory

    def _load_from_jsonl(self, jsonl_path: str):
        """Load samples from JSONL file, skipping blank lines."""
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            stripped = (line.strip() for line in f)
            return [json.loads(line) for line in stripped if line]

    def _load_from_directory(self, data_dir: str):
        """Load samples from directory structure.

        Reads ``metadata.json`` from every ``sample_*`` sub-directory, in
        sorted name order. When the metadata has no ``sample_dir`` entry the
        folder name is recorded, so the sample can be located again later.
        """
        samples = []
        for name in sorted(os.listdir(data_dir)):
            sample_dir = os.path.join(data_dir, name)
            if not (os.path.isdir(sample_dir) and name.startswith('sample_')):
                continue
            metadata_path = os.path.join(sample_dir, 'metadata.json')
            if not os.path.exists(metadata_path):
                continue
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
            # Without this, metadata lacking 'sample_dir' would be listed
            # but fail in __getitem__ because its folder cannot be resolved.
            if isinstance(metadata, dict):
                metadata.setdefault('sample_dir', name)
            samples.append(metadata)
        return samples

    def __len__(self):
        """Number of samples after filtering and splitting."""
        return len(self.samples)

    def _rgba2rgb(self, img_RGBA):
        """Convert RGBA to RGB with gray background."""
        img_RGB = Image.new("RGB", img_RGBA.size, (128, 128, 128))
        img_RGB.paste(img_RGBA, mask=img_RGBA.split()[3])
        return img_RGB

    def _get_sample_dir(self, sample):
        """Return the sample's directory under ``data_dir``, or None.

        None is returned when the sample has no ``sample_dir`` entry or the
        resulting path does not exist.
        """
        sample_dir = sample.get('sample_dir', '')
        if sample_dir:
            full_path = os.path.join(self.data_dir, sample_dir)
            if os.path.exists(full_path):
                return full_path

        return None

    def __getitem__(self, idx):
        """Build the layer tensors and scaled layout boxes for sample ``idx``.

        Returns a dict with ``pixel_RGBA`` / ``pixel_RGB`` (lists of tensors
        ordered whole image, base image, layers), ``whole_img`` (RGB PIL
        image), ``caption``, ``height``, ``width`` and ``layout``.

        Raises:
            ValueError: If the sample's directory cannot be resolved.
        """
        sample = self.samples[idx]
        sample_dir = self._get_sample_dir(sample)

        if not sample_dir:
            raise ValueError(f"Could not find sample directory for index {idx}")

        size = self.target_size
        canvas_size = (size, size)

        # A missing, null or zero width would break the scale computation,
        # so fall back to the target size (scale 1.0) in those cases.
        source_size = sample.get('width') or size
        caption = sample.get('whole_caption', '')

        # Scale factor (source -> target)
        scale = size / source_size

        # Load whole_image (composite); gray placeholder when the file is absent
        whole_img_path = os.path.join(sample_dir, 'whole_image.png')
        if os.path.exists(whole_img_path):
            whole_img = Image.open(whole_img_path).convert('RGBA')
        else:
            whole_img = Image.new('RGBA', (source_size, source_size), (128, 128, 128, 255))

        # Resize if needed
        if whole_img.size != canvas_size:
            whole_img = whole_img.resize(canvas_size, Image.LANCZOS)

        whole_img_RGB = self._rgba2rgb(whole_img)

        # Initialize layer lists with whole_image first
        layer_image_RGBA = [self.to_tensor(whole_img)]
        layer_image_RGB = [self.to_tensor(whole_img_RGB)]

        # Base layout (whole image) in xyxy format [x0, y0, x1, y1]
        W, H = size, size
        base_layout = [0, 0, W, H]  # xyxy with exclusive end coordinates
        layout = [base_layout]

        # Load base_image (background) as second layer; transparent when absent
        base_img_path = os.path.join(sample_dir, 'base_image.png')
        if os.path.exists(base_img_path):
            base_img = Image.open(base_img_path).convert('RGBA')
            if base_img.size != canvas_size:
                base_img = base_img.resize(canvas_size, Image.LANCZOS)
        else:
            base_img = Image.new('RGBA', canvas_size, (0, 0, 0, 0))

        base_img_RGB = self._rgba2rgb(base_img)
        layer_image_RGBA.append(self.to_tensor(base_img))
        layer_image_RGB.append(self.to_tensor(base_img_RGB))
        layout.append(base_layout)  # background covers whole image

        # Load layers from metadata
        for layer_info in sample.get('layers', []):
            image_path = layer_info.get('image_path', '')
            box = layer_info.get('box', [0, 0, source_size, source_size])

            # Scale box (xyxy format)
            x0, y0, x1, y1 = box
            scaled_box = [
                int(x0 * scale),
                int(y0 * scale),
                int(x1 * scale),
                int(y1 * scale),
            ]

            # Load layer image
            # Handles two formats:
            # 1. Full-canvas (target_size x target_size) - use as-is
            # 2. Cropped (smaller than canvas) - place at bbox position on transparent canvas
            layer_path = os.path.join(sample_dir, image_path)
            # isfile rather than exists: an empty image_path resolves to the
            # sample directory itself, which exists but cannot be opened.
            if os.path.isfile(layer_path):
                layer_img = Image.open(layer_path).convert('RGBA')
                if layer_img.size == canvas_size:
                    # Already full-canvas, use directly
                    pass
                elif layer_img.size == (source_size, source_size) and source_size != size:
                    # Full-canvas at source resolution, just resize
                    layer_img = layer_img.resize(canvas_size, Image.LANCZOS)
                else:
                    # Cropped layer - resize to fit the scaled bbox and place on canvas
                    bw = max(1, scaled_box[2] - scaled_box[0])
                    bh = max(1, scaled_box[3] - scaled_box[1])
                    layer_resized = layer_img.resize((bw, bh), Image.LANCZOS)
                    layer_img = Image.new('RGBA', canvas_size, (0, 0, 0, 0))
                    # NOTE(review): pasting an RGBA image with itself as the
                    # mask blends every band (alpha included) by the source
                    # alpha, so partially transparent pixels of cropped
                    # layers are attenuated while full-canvas layers are
                    # not. Kept as-is to preserve existing outputs; confirm
                    # whether Image.alpha_composite was intended.
                    layer_img.paste(layer_resized, (scaled_box[0], scaled_box[1]), layer_resized)
            else:
                layer_img = Image.new('RGBA', canvas_size, (0, 0, 0, 0))

            layer_img_RGB = self._rgba2rgb(layer_img)

            layer_image_RGBA.append(self.to_tensor(layer_img))
            layer_image_RGB.append(self.to_tensor(layer_img_RGB))
            layout.append(scaled_box)

        return {
            "pixel_RGBA": layer_image_RGBA,
            "pixel_RGB": layer_image_RGB,
            "whole_img": whole_img_RGB,
            "caption": caption,
            "height": H,
            "width": W,
            "layout": layout,
        }