SynLayers committed on
Commit
26369cf
·
verified ·
1 Parent(s): 17b46bf

Upload dataset/scaleup_dataset.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dataset/scaleup_dataset.py +637 -0
dataset/scaleup_dataset.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import random
5
+ import logging
6
+ import shutil
7
+ from typing import Dict, List, Optional, Tuple
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+ from multiprocessing import Pool, cpu_count
11
+ from functools import partial
12
+
13
+ from scaleup_utils import (
14
+ load_jsonl,
15
+ save_jsonl,
16
+ load_blended_sample,
17
+ get_blended_sample_dirs,
18
+ compute_non_overlapping_box_xyxy,
19
+ compute_total_overlap,
20
+ create_layer_on_canvas,
21
+ build_spatial_aware_caption,
22
+ get_position_description,
23
+ get_box_size,
24
+ get_content_bbox,
25
+ load_caption_list,
26
+ get_laion_images_with_captions,
27
+ get_caption_images_with_text,
28
+ select_random_layers_from_samples,
29
+ )
30
+
31
# Module-wide logging: INFO level with timestamped records.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Default canvas size (pixels, square); used when a base sample's metadata
# does not provide a 'width' field.
CANVAS_SIZE = 1024
36
+
37
+
38
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options for the dataset scale-up script.

    Args:
        argv: Optional explicit argument list. When ``None`` (the default),
            ``argparse`` falls back to ``sys.argv[1:]``, preserving the
            original no-argument call behavior. Passing a list makes the
            parser testable without touching process state.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Scale up PrismLayersPro-blended dataset')

    # --- Required input/output paths ---
    parser.add_argument('--blended_dir', type=str, required=True,
                        help='Path to PrismLayersPro-blended directory')
    parser.add_argument('--laion_dir', type=str, required=True,
                        help='Path to LAION aesthetic images directory')
    parser.add_argument('--caption_dir', type=str, required=True,
                        help='Path to caption images directory')
    parser.add_argument('--caption_meta', type=str, required=True,
                        help='Path to captions.jsonl with caption text')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Output directory for scaled-up dataset')

    # --- Generation size / reproducibility ---
    parser.add_argument('--num_samples', type=int, default=100000,
                        help='Number of new samples to generate')
    parser.add_argument('--start_index', type=int, default=0,
                        help='Starting sample index for output naming')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')

    # --- Layer selection parameters ---
    parser.add_argument('--min_donor_samples', type=int, default=2,
                        help='Minimum number of samples to pick layers from')
    parser.add_argument('--max_donor_samples', type=int, default=3,
                        help='Maximum number of samples to pick layers from')
    parser.add_argument('--min_layers_per_donor', type=int, default=1,
                        help='Minimum layers to pick from each donor sample')
    parser.add_argument('--max_layers_per_donor', type=int, default=2,
                        help='Maximum layers to pick from each donor sample')
    parser.add_argument('--added_layer_min_size', type=float, default=0.8,
                        help='Minimum size ratio for added layers')
    parser.add_argument('--added_layer_max_size', type=float, default=1.2,
                        help='Maximum size ratio for added layers')
    parser.add_argument('--laion_prob', type=float, default=0.1,
                        help='Probability of including LAION image layer')
    parser.add_argument('--caption_prob', type=float, default=0.2,
                        help='Probability of including caption text image layer')
    parser.add_argument('--laion_min_size', type=float, default=0.2,
                        help='Minimum size ratio for LAION layer')
    parser.add_argument('--laion_max_size', type=float, default=0.4,
                        help='Maximum size ratio for LAION layer')
    parser.add_argument('--caption_min_size', type=float, default=1.0,
                        help='Minimum size ratio for caption layer')
    parser.add_argument('--caption_max_size', type=float, default=1.2,
                        help='Maximum size ratio for caption layer')

    # --- Base layer removal parameters ---
    parser.add_argument('--min_layers_to_remove', type=int, default=2,
                        help='Minimum number of layers to remove from base sample')
    parser.add_argument('--max_layers_to_remove', type=int, default=4,
                        help='Maximum number of layers to remove from base sample')

    # --- AlphaVAE layer parameters (optional feature; off by default) ---
    parser.add_argument('--alphavae_dir', type=str, default=None,
                        help='Path to AlphaVAE_frontview images directory')
    parser.add_argument('--alphavae_prompts', type=str, default=None,
                        help='Path to prompts.txt for AlphaVAE captions')
    parser.add_argument('--alphavae_min_layers', type=int, default=0,
                        help='Minimum number of AlphaVAE layers to add per sample')
    parser.add_argument('--alphavae_max_layers', type=int, default=0,
                        help='Maximum number of AlphaVAE layers to add per sample')
    parser.add_argument('--alphavae_min_size', type=float, default=0.15,
                        help='Minimum size ratio for AlphaVAE layers')
    parser.add_argument('--alphavae_max_size', type=float, default=0.35,
                        help='Maximum size ratio for AlphaVAE layers')

    # --- Input limiting / resuming / multiprocessing ---
    parser.add_argument('--max_base_samples', type=int, default=None,
                        help='Max number of base samples to use (sorted by name). '
                             'E.g. 18000 to use sample_000000..sample_017999')
    parser.add_argument('--skip_existing', action='store_true',
                        help='Skip samples whose output directory already exists (for resuming)')
    parser.add_argument('--num_workers', type=int, default=64,
                        help='Number of parallel workers (0 = single-process)')
    return parser.parse_args(argv)
108
+
109
+
110
def create_scaled_up_sample(
    base_sample_dir: str,
    all_sample_dirs: List[str],
    laion_images: List[Tuple[str, str]],
    caption_images: List[Tuple[str, str]],
    alphavae_images: List[Tuple[str, str]],
    output_dir: str,
    sample_idx: int,
    args: argparse.Namespace,
) -> Optional[Dict]:
    """
    Create a new scaled-up sample by combining a base sample with layers from other samples.

    Pipeline (each step composites onto a running canvas and records metadata):
      1. Keep the base sample's "prism" layers (entries with no 'type' key),
         randomly drop a few, and re-place the survivors at random
         non-overlapping positions.
      2. Add randomly selected layers donated by other samples.
      2.5 Optionally add AlphaVAE image layers.
      3. Optionally add one LAION photo layer (probability args.laion_prob).
      4. Optionally add one rendered caption/text layer (probability args.caption_prob).
      5-7. Save the composite, build a spatial-aware caption, and write metadata.json.

    Side effects: writes base_image.png, layer_XX.png files, whole_image.png and
    metadata.json into output_dir/sample_{sample_idx:06d}/.

    Args:
        base_sample_dir: Directory of the base blended sample to start from.
        all_sample_dirs: All candidate sample dirs used as layer donors.
        laion_images: (image_path, caption) pairs for LAION foreground layers.
        caption_images: (image_path, text) pairs for rendered-text layers.
        alphavae_images: (image_path, caption) pairs for AlphaVAE layers.
        output_dir: Root output directory for generated samples.
        sample_idx: Index used for the output directory name and metadata id.
        args: Parsed CLI namespace with all generation knobs.

    Returns:
        Metadata dict describing the generated sample, or None if the base
        sample could not be loaded.
    """
    # Load base sample
    base_meta = load_blended_sample(base_sample_dir)
    if base_meta is None:
        logger.warning(f"Failed to load base sample: {base_sample_dir}")
        return None

    # Canvas is assumed square: 'width' is reused for both dimensions below.
    canvas_size = base_meta.get('width', CANVAS_SIZE)

    # Create output sample directory
    sample_name = f"sample_{sample_idx:06d}"
    sample_output_dir = os.path.join(output_dir, sample_name)
    os.makedirs(sample_output_dir, exist_ok=True)

    # Copy base_image (or synthesize a fully transparent one if missing)
    base_image = base_meta.get('base_image')
    if base_image:
        base_image.save(os.path.join(sample_output_dir, 'base_image.png'))
    else:
        base_image = Image.new('RGBA', (canvas_size, canvas_size), (0, 0, 0, 0))
        base_image.save(os.path.join(sample_output_dir, 'base_image.png'))

    # Start with base composite
    composite = base_image.copy()

    # Collect occupied boxes (xyxy) so later placements can avoid overlap
    occupied_boxes = []

    # New layers list
    new_layers = []
    current_layer_idx = 0

    # === Step 1: Copy layers from base sample (excluding laion_foreground and caption types) ===
    # Also randomly remove 2-3 layers to keep total layer count reasonable
    base_layers = base_meta.get('layers', [])
    # "Prism" layers are the untyped entries; typed ones (laion/caption) are dropped.
    base_prism_layers = [l for l in base_layers if l.get('type') is None]

    # Randomly remove some layers from base
    num_to_remove = random.randint(args.min_layers_to_remove, args.max_layers_to_remove)
    num_to_remove = min(num_to_remove, max(0, len(base_prism_layers) - 1))  # Keep at least 1 layer

    removed_layer_indices = set()
    if num_to_remove > 0 and len(base_prism_layers) > 1:
        layers_to_remove = random.sample(base_prism_layers, num_to_remove)
        removed_layer_indices = {l['layer_idx'] for l in layers_to_remove}

    # Filter out removed layers
    base_prism_layers_filtered = [l for l in base_prism_layers if l['layer_idx'] not in removed_layer_indices]

    for layer in base_prism_layers_filtered:
        orig_layer_idx = layer['layer_idx']
        layer_img = base_meta.get('layer_images', {}).get(orig_layer_idx)

        if layer_img is None:
            continue

        orig_box = layer.get('box', [0, 0, canvas_size, canvas_size])
        caption = layer.get('caption', '')
        # Size is kept from the original box; only the position is re-sampled.
        orig_w = orig_box[2] - orig_box[0]
        orig_h = orig_box[3] - orig_box[1]

        # Randomly reposition (layout-agnostic): find a non-overlapping spot.
        # Rejection-sample up to 300 candidates; accept the first with zero
        # overlap, else fall back to the least-overlapping candidate seen.
        best_box = None
        best_overlap_ratio = float('inf')
        new_box = None
        for _ in range(300):
            x0 = random.randint(0, max(0, canvas_size - orig_w))
            y0 = random.randint(0, max(0, canvas_size - orig_h))
            candidate = [x0, y0, x0 + orig_w, y0 + orig_h]
            box_area = orig_w * orig_h
            if box_area <= 0:
                # Degenerate (zero-area) box: no candidate is ever recorded,
                # so the fallback below places it at the top-left corner.
                continue
            overlap = compute_total_overlap(candidate, occupied_boxes)
            overlap_ratio = overlap / box_area
            if overlap == 0:
                new_box = candidate
                break
            if overlap_ratio < best_overlap_ratio:
                best_overlap_ratio = overlap_ratio
                best_box = candidate
        if new_box is None:
            new_box = best_box if best_box else [0, 0, orig_w, orig_h]

        # Place cropped layer onto full canvas at the new random position
        layer_canvas = create_layer_on_canvas(layer_img, new_box, canvas_size)

        # Save layer with new index
        layer_filename = f'layer_{current_layer_idx:02d}.png'
        layer_canvas.save(os.path.join(sample_output_dir, layer_filename))

        # Composite
        composite = Image.alpha_composite(composite, layer_canvas)

        # Record
        w, h = get_box_size(new_box)
        new_layers.append({
            'layer_idx': current_layer_idx,
            'caption': caption,
            'box': new_box,
            'width_dst': w,
            'height_dst': h,
            'image_path': layer_filename,
            'source': 'base',
            'source_sample': base_sample_dir,
        })

        occupied_boxes.append(new_box)
        current_layer_idx += 1

    # === Step 2: Add layers from other samples ===
    num_donors = random.randint(args.min_donor_samples, args.max_donor_samples)
    donor_layers = select_random_layers_from_samples(
        all_sample_dirs,
        exclude_sample=base_sample_dir,
        num_samples_to_pick=num_donors,
        num_layers_per_sample=(args.min_layers_per_donor, args.max_layers_per_donor)
    )

    for layer_img, layer_info, source_sample in donor_layers:
        caption = layer_info.get('caption', '')

        # Use the donor layer's original bounding box dimensions
        orig_box = layer_info.get('box', [0, 0, canvas_size, canvas_size])
        orig_w = orig_box[2] - orig_box[0]
        orig_h = orig_box[3] - orig_box[1]

        # Find a non-overlapping position for a box of this exact size
        # (same rejection-sampling scheme as Step 1).
        best_box = None
        best_overlap_ratio = float('inf')
        new_box = None
        for _ in range(300):
            x0 = random.randint(0, max(0, canvas_size - orig_w))
            y0 = random.randint(0, max(0, canvas_size - orig_h))
            candidate = [x0, y0, x0 + orig_w, y0 + orig_h]
            box_area = orig_w * orig_h
            if box_area <= 0:
                continue
            overlap = compute_total_overlap(candidate, occupied_boxes)
            overlap_ratio = overlap / box_area
            if overlap == 0:
                new_box = candidate
                break
            if overlap_ratio < best_overlap_ratio:
                best_overlap_ratio = overlap_ratio
                best_box = candidate
        if new_box is None:
            new_box = best_box if best_box else [0, 0, orig_w, orig_h]

        # Create layer on canvas at new position
        layer_canvas = create_layer_on_canvas(layer_img, new_box, canvas_size)

        # Save layer
        layer_filename = f'layer_{current_layer_idx:02d}.png'
        layer_canvas.save(os.path.join(sample_output_dir, layer_filename))

        # Composite
        composite = Image.alpha_composite(composite, layer_canvas)

        # Record
        w, h = get_box_size(new_box)
        new_layers.append({
            'layer_idx': current_layer_idx,
            'caption': caption,
            'box': new_box,
            'width_dst': w,
            'height_dst': h,
            'image_path': layer_filename,
            'source': 'donor',
            'source_sample': source_sample,
            'original_layer_idx': layer_info.get('layer_idx'),
        })

        occupied_boxes.append(new_box)
        current_layer_idx += 1

    # === Step 2.5: Optionally add AlphaVAE layers (0 to max) ===
    if alphavae_images and args.alphavae_max_layers > 0:
        num_alpha = random.randint(args.alphavae_min_layers, args.alphavae_max_layers)
        if num_alpha > 0:
            selected_alpha = random.sample(alphavae_images, min(num_alpha, len(alphavae_images)))
            for alpha_path, alpha_caption in selected_alpha:
                try:
                    alpha_img = Image.open(alpha_path).convert('RGBA')
                except Exception as e:
                    logger.warning(f"Failed to load AlphaVAE image: {alpha_path}, {e}")
                    continue

                # Placement (size + position) is delegated to the shared helper.
                alpha_box = compute_non_overlapping_box_xyxy(
                    canvas_size, occupied_boxes,
                    min_size_ratio=args.alphavae_min_size,
                    max_size_ratio=args.alphavae_max_size,
                    max_attempts=300,
                    max_overlap_ratio=0.10,
                    center_margin=32
                )

                alpha_layer = create_layer_on_canvas(alpha_img, alpha_box, canvas_size)

                layer_filename = f'layer_{current_layer_idx:02d}.png'
                alpha_layer.save(os.path.join(sample_output_dir, layer_filename))

                composite = Image.alpha_composite(composite, alpha_layer)

                w, h = get_box_size(alpha_box)
                new_layers.append({
                    'layer_idx': current_layer_idx,
                    'caption': alpha_caption,
                    'box': alpha_box,
                    'width_dst': w,
                    'height_dst': h,
                    'image_path': layer_filename,
                    'type': 'alphavae',
                    'source_path': alpha_path,
                })

                occupied_boxes.append(alpha_box)
                current_layer_idx += 1

    # === Step 3: Optionally add LAION image ===
    laion_caption = None
    laion_path = None
    if random.random() < args.laion_prob and laion_images:
        laion_path, laion_caption = random.choice(laion_images)
        try:
            laion_img = Image.open(laion_path).convert('RGBA')
            laion_orig_size = laion_img.size
        except Exception as e:
            logger.warning(f"Failed to load LAION image: {laion_path}, {e}")
            laion_img = None

        if laion_img is not None:
            laion_box = compute_non_overlapping_box_xyxy(
                canvas_size, occupied_boxes,
                min_size_ratio=args.laion_min_size,
                max_size_ratio=args.laion_max_size,
                max_attempts=300,
                max_overlap_ratio=0.10,
                center_margin=32
            )

            # Create layer
            laion_layer = create_layer_on_canvas(laion_img, laion_box, canvas_size)

            # Save
            layer_filename = f'layer_{current_layer_idx:02d}.png'
            laion_layer.save(os.path.join(sample_output_dir, layer_filename))

            # Composite
            composite = Image.alpha_composite(composite, laion_layer)

            # Record
            w, h = get_box_size(laion_box)
            new_layers.append({
                'layer_idx': current_layer_idx,
                'caption': laion_caption,
                'box': laion_box,
                'width_dst': w,
                'height_dst': h,
                'image_path': layer_filename,
                'type': 'laion_foreground',
                'original_size': list(laion_orig_size),
            })

            occupied_boxes.append(laion_box)
            current_layer_idx += 1

    # === Step 4: Optionally add caption image ===
    caption_text = None
    caption_path = None
    if random.random() < args.caption_prob and caption_images:
        caption_path, caption_text = random.choice(caption_images)
        try:
            caption_img = Image.open(caption_path).convert('RGBA')
            caption_orig_size = caption_img.size
        except Exception as e:
            logger.warning(f"Failed to load caption image: {caption_path}, {e}")
            caption_img = None

        if caption_img is not None:
            caption_box = compute_non_overlapping_box_xyxy(
                canvas_size, occupied_boxes,
                min_size_ratio=args.caption_min_size,
                max_size_ratio=args.caption_max_size,
                max_attempts=300,
                max_overlap_ratio=0.10,
                center_margin=32
            )

            # Create layer
            caption_layer = create_layer_on_canvas(caption_img, caption_box, canvas_size)

            # Compute tight bbox from actual non-transparent content
            # Caption images have rectangular colored areas surrounded by transparency,
            # so we use the actual content bounds instead of the placement box.
            tight_box = get_content_bbox(caption_layer)
            if tight_box is not None:
                caption_box = tight_box

            # Save
            layer_filename = f'layer_{current_layer_idx:02d}.png'
            caption_layer.save(os.path.join(sample_output_dir, layer_filename))

            # Composite
            composite = Image.alpha_composite(composite, caption_layer)

            # Record (box may be the tightened content bbox, not the placement box)
            w, h = get_box_size(caption_box)
            new_layers.append({
                'layer_idx': current_layer_idx,
                'caption': f"Text: {caption_text}" if caption_text else "Text",
                'box': caption_box,
                'width_dst': w,
                'height_dst': h,
                'image_path': layer_filename,
                'type': 'caption',
                'original_size': list(caption_orig_size),
            })

            occupied_boxes.append(caption_box)
            current_layer_idx += 1

    # === Step 5: Save whole_image ===
    composite.save(os.path.join(sample_output_dir, 'whole_image.png'))

    # === Step 6: Build spatial-aware whole caption ===
    base_caption = base_meta.get('base_caption', '')
    whole_caption = build_spatial_aware_caption(new_layers, canvas_size, base_caption)

    # === Step 7: Create metadata ===
    metadata = {
        'id': f'{sample_idx:09d}',
        'style_category': base_meta.get('style_category', ''),
        'whole_caption': whole_caption,
        'base_caption': base_caption,
        'layer_count': len(new_layers),
        'layers': new_layers,
        # Extra fields
        'sample_dir': sample_name,
        'width': canvas_size,
        'height': canvas_size,
        'base_sample': base_sample_dir,
        'num_base_layers_removed': len(removed_layer_indices),
        'num_donor_samples': num_donors,
        'num_donated_layers': len(donor_layers),
    }

    # NOTE(review): laion_path/caption_path are recorded even when the image
    # failed to load and no layer was added — confirm downstream tolerates this.
    if laion_path:
        metadata['laion_path'] = laion_path
        metadata['laion_caption'] = laion_caption
    if caption_path:
        metadata['caption_path'] = caption_path
        metadata['caption_text'] = caption_text

    # Save metadata
    with open(os.path.join(sample_output_dir, 'metadata.json'), 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    return metadata
482
+
483
+
484
def _worker_fn(task):
    """Worker function for multiprocessing. Each task is a dict with all needed info.

    Expected task keys: sample_idx, base_sample_dir, all_sample_dirs,
    laion_images, caption_images, alphavae_images, output_dir, args, seed.

    Returns:
        The sample's metadata dict (freshly generated, or reloaded from an
        existing metadata.json when --skip_existing is set), or None on failure.
    """
    sample_idx = task['sample_idx']
    base_sample_dir = task['base_sample_dir']
    all_sample_dirs = task['all_sample_dirs']
    laion_images = task['laion_images']
    caption_images = task['caption_images']
    alphavae_images = task['alphavae_images']
    output_dir = task['output_dir']
    args = task['args']
    seed = task['seed']

    # Resume support: if this sample already has metadata on disk, reuse it.
    if args.skip_existing:
        meta_path = os.path.join(output_dir, f"sample_{sample_idx:06d}", 'metadata.json')
        if os.path.exists(meta_path):
            try:
                with open(meta_path, 'r') as f:
                    return json.load(f)
            except Exception:
                # Corrupt/partial metadata: fall through and regenerate the sample.
                pass

    # Deterministic per-sample seed: results are reproducible regardless of
    # which worker process picks up the task.
    random.seed(seed)

    try:
        metadata = create_scaled_up_sample(
            base_sample_dir=base_sample_dir,
            all_sample_dirs=all_sample_dirs,
            laion_images=laion_images,
            caption_images=caption_images,
            alphavae_images=alphavae_images,
            output_dir=output_dir,
            sample_idx=sample_idx,
            args=args,
        )
        return metadata
    except Exception as e:
        # logger.exception records the full traceback (logger.error silently
        # dropped it, making per-sample failures hard to diagnose).
        logger.exception(f"Error processing sample {sample_idx}: {e}")
        return None
522
+
523
+
524
def main():
    """Entry point: load source assets, fan out sample generation, write index.

    Reads CLI options, loads base samples plus optional LAION / caption /
    AlphaVAE assets, generates args.num_samples new samples (optionally with a
    multiprocessing pool), and writes a scaleup_meta.jsonl index to output_dir.
    """
    args = parse_args()
    random.seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)

    # Load existing blended samples
    logger.info("Loading existing blended samples...")
    all_sample_dirs = get_blended_sample_dirs(args.blended_dir, max_samples=args.max_base_samples)
    logger.info(f"Found {len(all_sample_dirs)} existing samples"
                + (f" (limited to first {args.max_base_samples})" if args.max_base_samples else ""))

    # Hard minimum: donor-layer selection needs a pool of distinct samples.
    if len(all_sample_dirs) < 10:
        logger.error("Not enough existing samples to create scaled-up dataset!")
        return

    # Load caption list
    logger.info("Loading caption list from captions.jsonl...")
    caption_list = load_caption_list(args.caption_meta)
    logger.info(f"Loaded {len(caption_list)} caption entries")

    # Load LAION images (cap at 20000 for balanced diversity)
    logger.info("Loading LAION images with captions...")
    laion_images = get_laion_images_with_captions(args.laion_dir)
    if len(laion_images) > 20000:
        random.shuffle(laion_images)
        laion_images = laion_images[:20000]
    logger.info(f"Using {len(laion_images)} LAION images")

    # Load caption images
    logger.info("Loading caption images...")
    caption_images = get_caption_images_with_text(args.caption_dir, caption_list)
    logger.info(f"Found {len(caption_images)} caption images")

    # Load AlphaVAE images (optional; requires both --alphavae_dir and --alphavae_prompts)
    alphavae_images = []  # list of (image_path, caption)
    if args.alphavae_dir and args.alphavae_prompts:
        logger.info("Loading AlphaVAE images with prompts...")
        with open(args.alphavae_prompts, 'r') as f:
            alphavae_prompts = [l.strip() for l in f.readlines() if l.strip()]
        alpha_files = sorted([
            f for f in os.listdir(args.alphavae_dir)
            if f.endswith('.png')
        ])
        for fname in alpha_files:
            # Assumes filenames are purely numeric (e.g. "0007.png") with 5
            # generated images per prompt — TODO confirm; a non-numeric .png
            # name would raise ValueError here.
            idx = int(fname.replace('.png', ''))
            prompt_idx = idx // 5
            if prompt_idx < len(alphavae_prompts):
                caption = alphavae_prompts[prompt_idx]
            else:
                caption = ""
            alphavae_images.append((os.path.join(args.alphavae_dir, fname), caption))
        logger.info(f"Found {len(alphavae_images)} AlphaVAE images")

    # Pre-generate tasks with deterministic per-sample seeds
    rng = random.Random(args.seed)
    tasks = []
    for i in range(args.num_samples):
        sample_idx = args.start_index + i
        tasks.append({
            'sample_idx': sample_idx,
            'base_sample_dir': rng.choice(all_sample_dirs),
            'all_sample_dirs': all_sample_dirs,
            'laion_images': laion_images,
            'caption_images': caption_images,
            'alphavae_images': alphavae_images,
            'output_dir': args.output_dir,
            'args': args,
            'seed': rng.randint(0, 2**31),
        })

    # Generate samples
    all_metadata = []
    failed_count = 0
    num_workers = args.num_workers

    if num_workers > 0:
        logger.info(f"Using multiprocessing with {num_workers} workers")
        with Pool(processes=num_workers) as pool:
            # imap_unordered yields results as workers finish; order is
            # restored by the sort below.
            for metadata in tqdm(
                pool.imap_unordered(_worker_fn, tasks),
                total=len(tasks),
                desc="Generating samples"
            ):
                if metadata:
                    all_metadata.append(metadata)
                else:
                    failed_count += 1
    else:
        logger.info("Using single-process mode")
        for task in tqdm(tasks, desc="Generating samples"):
            metadata = _worker_fn(task)
            if metadata:
                all_metadata.append(metadata)
            else:
                failed_count += 1

    # Sort by sample index for deterministic output order
    all_metadata.sort(key=lambda m: int(m['id']))

    # Save index (scaleup_meta.jsonl)
    index_path = os.path.join(args.output_dir, 'scaleup_meta.jsonl')
    with open(index_path, 'w', encoding='utf-8') as f:
        for meta in all_metadata:
            f.write(json.dumps(meta, ensure_ascii=False) + '\n')

    logger.info(f"Generated {len(all_metadata)} samples ({failed_count} failed)")
    logger.info(f"Output saved to {args.output_dir}")
    logger.info(f"Index saved to {index_path}")
633
+
634
+
635
# Script entry point: run the full generation pipeline when executed directly.
if __name__ == '__main__':
    main()
637
+