# Bbox-caption-8b / dataset/scaleup_utils.py
# Uploaded by SynLayers via huggingface_hub (commit 12e6363, verified)
"""
Utility functions for scaling up the PrismLayersPro-blended dataset.
This module provides utilities for:
- Loading existing blended samples
- Computing non-overlapping bounding boxes
- Generating spatial-aware captions with position words
- Layer combination and compositing
"""
import os
import json
import random
from typing import Dict, List, Tuple, Optional
from PIL import Image
import numpy as np
def load_jsonl(path: str) -> List[Dict]:
    """Read a JSONL file and return its records as a list of dicts (blank lines skipped)."""
    records: List[Dict] = []
    with open(path, 'r', encoding='utf-8') as fh:
        for raw in fh:
            stripped = raw.strip()
            if not stripped:
                continue
            records.append(json.loads(stripped))
    return records
def save_jsonl(items: List[Dict], path: str):
    """Write each dictionary in *items* as one JSON line in *path* (UTF-8, non-ASCII kept)."""
    with open(path, 'w', encoding='utf-8') as out:
        for entry in items:
            out.write(json.dumps(entry, ensure_ascii=False))
            out.write('\n')
def load_blended_sample(sample_dir: str) -> Optional[Dict]:
    """
    Read one blended sample's metadata and attach its images.

    Returns the metadata dict augmented with 'base_image' (RGBA or None),
    'layer_images' (layer_idx -> RGBA image, only for files present on
    disk) and 'sample_path'; returns None when metadata.json is missing.
    """
    meta_file = os.path.join(sample_dir, 'metadata.json')
    if not os.path.exists(meta_file):
        return None
    with open(meta_file, 'r', encoding='utf-8') as fh:
        meta = json.load(fh)
    # Background image is optional.
    bg_file = os.path.join(sample_dir, 'base_image.png')
    meta['base_image'] = (
        Image.open(bg_file).convert('RGBA') if os.path.exists(bg_file) else None
    )
    # Attach each layer image that actually exists, keyed by its index.
    meta['layer_images'] = {}
    for entry in meta.get('layers', []):
        layer_file = os.path.join(sample_dir, entry['image_path'])
        if os.path.exists(layer_file):
            meta['layer_images'][entry['layer_idx']] = Image.open(layer_file).convert('RGBA')
    # Remember where the sample came from.
    meta['sample_path'] = sample_dir
    return meta
def get_blended_sample_dirs(blended_dir: str, max_samples: Optional[int] = None) -> List[str]:
    """
    List the 'sample_*' subdirectories of *blended_dir* in sorted order,
    stopping early once *max_samples* entries have been collected.
    """
    found: List[str] = []
    for entry in sorted(os.listdir(blended_dir)):
        full = os.path.join(blended_dir, entry)
        if not (entry.startswith('sample_') and os.path.isdir(full)):
            continue
        found.append(full)
        if max_samples and len(found) >= max_samples:
            break
    return found
def compute_overlap_area(box1: List[int], box2: List[int]) -> int:
    """Return the intersection area of two xyxy boxes, or 0 when they are disjoint."""
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])
    width = right - left
    height = bottom - top
    # Non-positive extent in either axis means no overlap.
    if width <= 0 or height <= 0:
        return 0
    return width * height
def compute_total_overlap(box: List[int], existing_boxes: List[List[int]]) -> int:
    """Sum the pairwise overlap area of *box* against every box in *existing_boxes*."""
    return sum(compute_overlap_area(box, other) for other in existing_boxes)
def get_position_description(box: List[int], canvas_size: int) -> str:
    """
    Describe where a box sits on a square canvas, based on its center point.

    The canvas is split into a 3x3 grid (band boundaries at 0.33 and 0.67
    of the normalized coordinate) and one of nine phrases is returned:
    "On the top-left/top-right/bottom-left/bottom-right/left/right",
    "At the top/bottom", or "In the center".
    """
    x0, y0, x1, y1 = box
    nx = ((x0 + x1) / 2) / canvas_size
    ny = ((y0 + y1) / 2) / canvas_size

    def band(v: float) -> int:
        # 0 = low third, 1 = middle (inclusive of 0.33/0.67), 2 = high third
        if v < 0.33:
            return 0
        if v > 0.67:
            return 2
        return 1

    grid = {
        (0, 0): "On the top-left",
        (0, 2): "On the top-right",
        (0, 1): "At the top",
        (2, 0): "On the bottom-left",
        (2, 2): "On the bottom-right",
        (2, 1): "At the bottom",
        (1, 0): "On the left",
        (1, 2): "On the right",
        (1, 1): "In the center",
    }
    return grid[(band(ny), band(nx))]
def build_spatial_aware_caption(layers: List[Dict], canvas_size: int, base_caption: str = "") -> str:
    """
    Compose a whole-image caption that prefixes each layer caption with a
    position phrase, e.g.:
        "On the top-left, a red balloon. In the center, a clown character.
         At the bottom, Text: hello world."
    The explicit spatial wording helps diffusion models (notably Flux with
    T5) associate each description with its layer's location.
    """
    pieces: List[str] = []

    # Lead with a trimmed base caption: first sentence only, to stay concise.
    if base_caption:
        lead = base_caption.split('.')[0].strip()
        if lead:
            pieces.append(lead + ".")

    boilerplate = (
        "The picture features ",
        "The image features ",
        "Text ",
    )
    for entry in layers:
        text = entry.get('caption', '').strip()
        if not text:
            continue
        where = get_position_description(
            entry.get('box', [0, 0, canvas_size, canvas_size]), canvas_size
        )
        # Drop a single leading boilerplate phrase, if any matches.
        for stock in boilerplate:
            if text.startswith(stock):
                text = text[len(stock):]
                break
        # Capitalize the first letter (no-op for empty strings).
        if text:
            text = text[0].upper() + text[1:]
        text = text.rstrip('.')
        pieces.append(f"{where}, {text}.")
    return " ".join(pieces)
def compute_random_box_xyxy(
    canvas_size: int,
    min_size_ratio: float = 0.10,
    max_size_ratio: float = 0.25,
    aspect_ratio_range: Tuple[float, float] = (0.5, 2.0),
    center_margin: int = 16
) -> List[int]:
    """
    Sample a random bounding box [x0, y0, x1, y1] on a square canvas.

    Args:
        canvas_size: Side length of the square canvas (e.g. 512).
        min_size_ratio: Smallest box side as a fraction of canvas_size.
        max_size_ratio: Largest box side as a fraction of canvas_size.
        aspect_ratio_range: Allowed width/height range.
        center_margin: Keep-out margin from each edge for the box CENTER
            (e.g. 16 keeps the center within [16, canvas_size-16]; the
            box itself may still touch the edges after clamping).
    """
    lo = int(canvas_size * min_size_ratio)
    hi = int(canvas_size * max_size_ratio)

    # Draw the longer side at random, derive the other from the aspect ratio.
    ratio = random.uniform(*aspect_ratio_range)
    if ratio >= 1.0:
        width = random.randint(lo, hi)
        height = int(width / ratio)
    else:
        height = random.randint(lo, hi)
        width = int(height * ratio)
    # Re-clamp: the derived side may have fallen outside [lo, hi].
    width = min(max(width, lo), hi)
    height = min(max(height, lo), hi)

    # Sample the center inside the margin-restricted region.
    c_lo = center_margin
    c_hi = canvas_size - center_margin
    if c_hi <= c_lo:
        # Degenerate margin: fall back to the whole canvas.
        c_hi = canvas_size - 1
        c_lo = 0
    cx = random.randint(c_lo, c_hi)
    cy = random.randint(c_lo, c_hi)

    left = cx - width // 2
    top = cy - height // 2
    right = left + width
    bottom = top + height

    # Clamp the box (not the center) to the canvas bounds.
    return [max(0, left), max(0, top), min(canvas_size, right), min(canvas_size, bottom)]
def compute_non_overlapping_box_xyxy(
    canvas_size: int,
    existing_boxes: List[List[int]],
    min_size_ratio: float = 0.10,
    max_size_ratio: float = 0.25,
    max_attempts: int = 300,
    max_overlap_ratio: float = 0.20,
    center_margin: int = 16
) -> List[int]:
    """
    Sample an xyxy box that overlaps the existing boxes as little as possible.

    Rejection-sampling strategy:
      1. A fully overlap-free candidate is returned immediately.
      2. A candidate whose overlap ratio (overlap / own area) is below
         max_overlap_ratio is also accepted on the spot.
      3. Otherwise, after max_attempts draws, the lowest-overlap candidate
         seen so far is returned (or a fresh random box as a last resort).

    Args:
        canvas_size: Side length of the square canvas (e.g. 512).
        existing_boxes: Boxes the new one should avoid.
        min_size_ratio: Smallest box side as a fraction of canvas_size.
        max_size_ratio: Largest box side as a fraction of canvas_size.
        max_attempts: Number of random candidates to try.
        max_overlap_ratio: Acceptable overlap fraction (default 20%).
        center_margin: Edge keep-out for the candidate's center.
    """
    champion = None
    champion_ratio = float('inf')

    for _ in range(max_attempts):
        candidate = compute_random_box_xyxy(
            canvas_size, min_size_ratio, max_size_ratio,
            center_margin=center_margin
        )
        area = (candidate[2] - candidate[0]) * (candidate[3] - candidate[1])
        if area <= 0:
            continue
        overlap = compute_total_overlap(candidate, existing_boxes)
        ratio = overlap / area
        # Perfect placement: stop searching.
        if overlap == 0:
            return candidate
        # Remember the best imperfect candidate seen so far.
        if ratio < champion_ratio:
            champion_ratio = ratio
            champion = candidate
        # Small enough overlap is acceptable immediately.
        if ratio < max_overlap_ratio:
            return candidate

    if champion is not None:
        return champion
    # No candidate had positive area (degenerate parameters): last resort.
    return compute_random_box_xyxy(
        canvas_size, min_size_ratio, max_size_ratio,
        center_margin=center_margin
    )
def create_layer_on_canvas(
    layer_img: Image.Image,
    box: List[int],
    canvas_size: int
) -> Image.Image:
    """
    Place *layer_img* onto a transparent square canvas at the given box.

    Args:
        layer_img: Source layer image (any PIL mode; normalized to RGBA).
        box: Target region in xyxy format [x0, y0, x1, y1].
        canvas_size: Side length of the output RGBA canvas.

    Returns:
        An RGBA canvas with the layer resized to the box and pasted there,
        using its own alpha channel as the paste mask; a fully transparent
        canvas when the box is degenerate (zero or negative size).
    """
    x0, y0, x1, y1 = box
    w = x1 - x0
    h = y1 - y0
    canvas = Image.new('RGBA', (canvas_size, canvas_size), (0, 0, 0, 0))
    # Degenerate box: nothing to paste.
    if w <= 0 or h <= 0:
        return canvas
    layer_resized = layer_img.resize((w, h), Image.LANCZOS)
    # Normalize to RGBA so the image can serve as its own alpha mask.
    # (The original had two branches that performed the identical paste;
    # collapsed into one convert-if-needed path.)
    if layer_resized.mode != 'RGBA':
        layer_resized = layer_resized.convert('RGBA')
    canvas.paste(layer_resized, (x0, y0), layer_resized)
    return canvas
def get_content_bbox(img: Image.Image) -> Optional[List[int]]:
    """
    Tight xyxy bounding box of all pixels with alpha > 0 in an RGBA image.
    Returns [x0, y0, x1, y1] (exclusive right/bottom) or None when the
    image is fully transparent.
    """
    alpha = np.array(img.convert('RGBA'))[:, :, 3]
    ys, xs = np.nonzero(alpha > 0)
    if ys.size == 0:
        return None
    # +1 makes the right/bottom edges exclusive, matching xyxy convention.
    return [int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1]
def get_box_size(box: List[int]) -> Tuple[int, int]:
    """Return (width, height) of an xyxy box."""
    return (box[2] - box[0], box[3] - box[1])
def load_caption_list(caption_jsonl: str) -> List[Dict]:
    """Load captions.jsonl; the list order matches the file's line order."""
    return load_jsonl(caption_jsonl)
def get_laion_caption_from_json(image_path: str) -> str:
    """
    Look up the caption for a LAION image in its sidecar .json file.

    The sidecar path is the image path with its extension replaced by
    '.json'. Falls back to the image filename stem when the sidecar is
    missing, unreadable, or does not hold a JSON object.
    """
    json_path = image_path.rsplit('.', 1)[0] + '.json'
    if os.path.exists(json_path):
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Guard: a sidecar holding a non-object (list/string) also
            # falls back to the filename stem.
            if isinstance(data, dict):
                return data.get('caption', '')
        # Narrowed from a bare `except Exception`: only I/O and parse
        # failures (json.JSONDecodeError is a ValueError) should trigger
        # the best-effort fallback; programming errors now surface.
        except (OSError, ValueError):
            pass
    return os.path.basename(image_path).rsplit('.', 1)[0]
def get_laion_images_with_captions(laion_dir: str, laion_jsonl: Optional[str] = None) -> List[Tuple[str, str]]:
    """
    Collect (image_path, caption) pairs for every LAION image found in the
    immediate subdirectories of *laion_dir*, in sorted order.

    Note: *laion_jsonl* is not read by this implementation and appears to
    be kept for interface compatibility; captions come from each image's
    sidecar .json file.
    """
    pairs: List[Tuple[str, str]] = []
    for subdir in sorted(os.listdir(laion_dir)):
        subdir_path = os.path.join(laion_dir, subdir)
        if not os.path.isdir(subdir_path):
            continue
        image_names = [
            name for name in sorted(os.listdir(subdir_path))
            if name.endswith(('.jpg', '.jpeg', '.png'))
        ]
        for name in image_names:
            full = os.path.join(subdir_path, name)
            pairs.append((full, get_laion_caption_from_json(full)))
    return pairs
def get_caption_images_with_text(caption_dir: str, caption_list: List[Dict]) -> List[Tuple[str, str]]:
    """
    Pair each numbered caption PNG in *caption_dir* with its text.

    Filenames are expected to look like '<index>.png', where the index
    points into *caption_list*; an unparsable or out-of-range index
    yields an empty caption string.
    """
    results: List[Tuple[str, str]] = []
    for fname in sorted(os.listdir(caption_dir)):
        if not fname.endswith('.png'):
            continue
        stem = fname.split('.')[0]
        try:
            position = int(stem)
        except ValueError:
            position = -1
        text = ""
        if 0 <= position < len(caption_list):
            text = caption_list[position].get('caption', '')
        results.append((os.path.join(caption_dir, fname), text))
    return results
def extract_layer_from_sample(
    sample_metadata: Dict,
    layer_idx: int
) -> Optional[Tuple[Image.Image, Dict]]:
    """
    Pull one layer (image + shallow copy of its metadata) out of a loaded sample.
    Returns None when the index has no loaded image or no matching layer entry.
    """
    images = sample_metadata.get('layer_images', {})
    if layer_idx not in images:
        return None
    # Locate the metadata entry whose layer_idx matches.
    match = next(
        (entry for entry in sample_metadata.get('layers', [])
         if entry['layer_idx'] == layer_idx),
        None,
    )
    if match is None:
        return None
    return (images[layer_idx], match.copy())
def select_random_layers_from_samples(
    sample_dirs: List[str],
    exclude_sample: str,
    num_samples_to_pick: int = 2,
    num_layers_per_sample: Tuple[int, int] = (1, 2)
) -> List[Tuple[Image.Image, Dict, str]]:
    """
    Randomly harvest prism layers from samples other than the base one.

    Args:
        sample_dirs: All candidate sample directories.
        exclude_sample: The base sample's directory (never picked from).
        num_samples_to_pick: How many donor samples to draw from (capped
            by the number of available donors).
        num_layers_per_sample: (min, max) layers to take per donor.

    Returns:
        List of (layer_image, layer_info_copy, source_sample_dir) tuples.
    """
    donors = [d for d in sample_dirs if d != exclude_sample]
    pick_count = min(num_samples_to_pick, len(donors))

    harvested: List[Tuple[Image.Image, Dict, str]] = []
    for donor_dir in random.sample(donors, pick_count):
        meta = load_blended_sample(donor_dir)
        if meta is None:
            continue
        # Only original prism layers (entries without a 'type' tag) are
        # eligible; typed layers (laion_foreground / caption) are skipped
        # to avoid duplicates.
        candidates = [entry for entry in meta.get('layers', []) if entry.get('type') is None]
        if not candidates:
            continue
        lo, hi = num_layers_per_sample
        take = random.randint(lo, min(hi, len(candidates)))
        for info in random.sample(candidates, take):
            img = meta.get('layer_images', {}).get(info['layer_idx'])
            if img is not None:
                harvested.append((img, info.copy(), donor_dir))
    return harvested