# dawn-yolo-wbf-ensemble / prepare_data.py
# AmeenAktharT's picture
# Upload standalone data preparation script
# 99f86f1 verified
"""
DAWN Dataset Preparation Script
Downloads from HuggingFace Hub, converts to YOLO format, applies augmentation
for minority classes, and creates train/val/test splits.
"""
import os
import json
import random
import shutil
import numpy as np
from pathlib import Path
from datasets import load_dataset
from PIL import Image, ImageOps
# ─── Configuration ───────────────────────────────────────────────────
# Output directory for the YOLO-format dataset and the seed for both RNGs.
DATASET_ROOT = "/app/dawn_dataset"
SEED = 42

# Fractions of the data assigned to the train / val / test splits.
TRAIN_RATIO, VAL_RATIO, TEST_RATIO = 0.60, 0.20, 0.20

# Output class names; a name's position in the list is its class index.
CLASS_NAMES = ['Bicycle', 'Bus', 'Car', 'Motorcycle', 'Pedestrian', 'Truck']

# Dataset class_name -> class index.
# 'Person' and 'Cyclist' are folded into the Pedestrian index (4).
CLASS_MAP = dict(
    Bicycle=0,
    Bus=1,
    Car=2,
    Motorcycle=3,
    Pedestrian=4, Person=4, Cyclist=4,
    Truck=5,
)

# Seed the stdlib and NumPy generators as soon as the module is loaded.
random.seed(SEED)
np.random.seed(SEED)
def setup_dirs():
    """Ensure the images/ and labels/ folders exist for every split.

    The folders are created under DATASET_ROOT for 'train', 'val' and
    'test'; folders that already exist are left as they are.
    """
    for split in ('train', 'val', 'test'):
        for kind in ('images', 'labels'):
            os.makedirs(f"{DATASET_ROOT}/{kind}/{split}", exist_ok=True)
def convert_to_yolo(objects, img_w, img_h, class_map=None):
    """Convert absolute-pixel box annotations into YOLO label lines.

    Each box is first clipped to the image rectangle; the normalized centre
    and size are then derived from the clipped edges, so a box that crosses
    the image border still yields a label lying entirely inside [0, 1].

    Args:
        objects: Iterable of dicts with the keys 'class_name', 'x_min',
            'y_min', 'width' and 'height' (top-left corner plus size).
        img_w: Image width, in the same units as the box coordinates.
        img_h: Image height, in the same units as the box coordinates.
        class_map: Optional mapping from class name to class index.
            Defaults to the module-level CLASS_MAP.

    Returns:
        A list of strings "cls_id cx cy w h" with six-decimal normalized
        values. Objects with an unknown class, and boxes whose clipped
        width or height is at most 0.001 of the image, are omitted. An
        empty list is returned when the image size is not positive.
    """
    if class_map is None:
        class_map = CLASS_MAP

    # Normalizing by a zero/negative size is meaningless (and would divide
    # by zero), so report it and produce no labels for this image.
    if img_w <= 0 or img_h <= 0:
        print(f" WARNING: Invalid image size {img_w}x{img_h}, skipping annotations")
        return []

    labels = []
    for obj in objects:
        cls_name = obj['class_name']
        if cls_name not in class_map:
            print(f" WARNING: Unknown class '{cls_name}', skipping")
            continue
        cls_id = class_map[cls_name]

        # Clip the box edges (not the centre/size) to the image bounds.
        x1 = min(max(obj['x_min'], 0), img_w)
        y1 = min(max(obj['y_min'], 0), img_h)
        x2 = min(max(obj['x_min'] + obj['width'], 0), img_w)
        y2 = min(max(obj['y_min'] + obj['height'], 0), img_h)

        nw = (x2 - x1) / img_w
        nh = (y2 - y1) / img_h
        if nw > 0.001 and nh > 0.001:  # skip degenerate boxes
            cx = (x1 + x2) / 2 / img_w
            cy = (y1 + y2) / 2 / img_h
            labels.append(f"{cls_id} {cx:.6f} {cy:.6f} {nw:.6f} {nh:.6f}")
    return labels
def save_image_and_label(image, labels, img_name, split):
    """Write one sample to disk as a JPEG image plus a YOLO label file.

    Args:
        image: A PIL image, or any other object exposing ``save(path)``.
        labels: List of YOLO label lines; written newline-separated.
        img_name: File stem shared by the image and the label file.
        split: Split sub-directory name (e.g. 'train', 'val', 'test').
    """
    img_path = f"{DATASET_ROOT}/images/{split}/{img_name}.jpg"
    lbl_path = f"{DATASET_ROOT}/labels/{split}/{img_name}.txt"

    if isinstance(image, Image.Image):
        # The JPEG writer rejects modes with alpha or a palette (RGBA, LA,
        # P, ...) with an OSError, so convert those to RGB before saving.
        if image.mode not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
            image = image.convert('RGB')
        image.save(img_path, quality=95)
    else:
        image.save(img_path)

    with open(lbl_path, 'w') as f:
        f.write('\n'.join(labels))
def augment_mirror(image, labels_raw, img_w, img_h):
    """Flip an image left-to-right and mirror its YOLO labels to match.

    Only the normalized x-centre changes (cx becomes 1 - cx); the class
    token is passed through and cy, w, h are re-emitted with six decimals.
    img_w and img_h are accepted but not used.

    Returns:
        A (flipped_image, new_label_lines) tuple.
    """
    mirrored_image = ImageOps.mirror(image)

    def _flip_line(label_line):
        fields = label_line.split()
        class_token = fields[0]
        cx, cy, bw, bh = (float(fields[k]) for k in (1, 2, 3, 4))
        return f"{class_token} {1.0 - cx:.6f} {cy:.6f} {bw:.6f} {bh:.6f}"

    return mirrored_image, [_flip_line(line) for line in labels_raw]
def augment_rotate(image, labels_raw, img_w, img_h, angle=90):
    """Rotate an image by 90, 180 or 270 degrees and remap its YOLO labels.

    The image is transposed with the matching Image.ROTATE_* operation and
    every label's normalized centre/size is remapped accordingly (width and
    height swap for 90 and 270). Any other angle returns the inputs
    unchanged. img_w and img_h are accepted but not used.

    Returns:
        A (rotated_image, new_label_lines) tuple.
    """
    if angle not in (90, 180, 270):
        return image, labels_raw

    # angle -> (transpose operation, remap of (cx, cy, w, h))
    transforms = {
        90: (Image.ROTATE_90, lambda cx, cy, bw, bh: (cy, 1.0 - cx, bh, bw)),
        180: (Image.ROTATE_180, lambda cx, cy, bw, bh: (1.0 - cx, 1.0 - cy, bw, bh)),
        270: (Image.ROTATE_270, lambda cx, cy, bw, bh: (1.0 - cy, cx, bh, bw)),
    }
    transpose_op, remap = transforms[angle]

    turned = image.transpose(transpose_op)
    remapped = []
    for line in labels_raw:
        fields = line.split()
        class_token = fields[0]
        box = (float(fields[k]) for k in (1, 2, 3, 4))
        out_cx, out_cy, out_w, out_h = remap(*box)
        remapped.append(f"{class_token} {out_cx:.6f} {out_cy:.6f} {out_w:.6f} {out_h:.6f}")
    return turned, remapped
def _augment_sample(sample):
    """Return the mirrored and 90°/180°-rotated variants of one converted sample."""
    img = sample['image']
    labels = sample['labels']
    img_w = sample['img_w']
    img_h = sample['img_h']
    base_id = sample['image_id']

    mir_img, mir_labels = augment_mirror(img, labels, img_w, img_h)
    variants = [{
        'image': mir_img,
        'labels': mir_labels,
        'image_id': f"{base_id}_mirror",
        'img_classes': sample['img_classes'],
    }]
    for angle in (90, 180):
        rot_img, rot_labels = augment_rotate(img, labels, img_w, img_h, angle)
        variants.append({
            'image': rot_img,
            'labels': rot_labels,
            'image_id': f"{base_id}_rot{angle}",
            'img_classes': sample['img_classes'],
        })
    return variants


def main():
    """Run the DAWN preparation pipeline and return its metadata dict.

    Steps:
        1. Load the 'train' and 'val' splits of "Maxim37/dawn-dataset" and
           pool their samples.
        2. Convert every sample's annotations to YOLO label lines, dropping
           samples without a PIL image or without any usable label.
        3. Shuffle the original images and split them 60/20/20.
        4. Augment (mirror, 90° and 180° rotation) the training images that
           contain a minority class. Augmentation happens after the split
           and only on the train split, so transformed copies of an image
           can never end up in val/test while the original is in train.
        5. Save images and labels, then write dataset.yaml and
           metadata.json under DATASET_ROOT.

    Returns:
        The metadata dict that is also written to metadata.json.
    """
    print("=" * 60)
    print("DAWN Dataset Preparation Pipeline")
    print("=" * 60)
    setup_dirs()

    # Load dataset from HF Hub
    print("\n[1/5] Loading DAWN dataset from HuggingFace Hub...")
    ds = load_dataset("Maxim37/dawn-dataset")
    print(f" Train split: {len(ds['train'])} images")
    print(f" Val split: {len(ds['val'])} images")

    # Combine all data for re-splitting
    all_samples = []
    for split_name in ['train', 'val']:
        all_samples.extend(ds[split_name])
    print(f" Total images: {len(all_samples)}")

    # ─── Phase 1: Convert all images to YOLO format ──────────────────
    print("\n[2/5] Converting annotations to YOLO format...")
    converted = []
    class_counts = {name: 0 for name in CLASS_NAMES}
    for i, sample in enumerate(all_samples):
        img = sample['image']
        if not isinstance(img, Image.Image):
            continue
        img_w = sample['width']
        img_h = sample['height']
        labels = convert_to_yolo(sample['objects'], img_w, img_h)
        if not labels:
            continue

        # Count instances per class and remember which classes appear
        img_classes = set()
        for lbl in labels:
            cls_id = int(lbl.split()[0])
            class_counts[CLASS_NAMES[cls_id]] += 1
            img_classes.add(cls_id)

        converted.append({
            'image': img,
            'labels': labels,
            'image_id': sample['image_id'],
            'img_classes': img_classes,
            'img_w': img_w,
            'img_h': img_h,
        })
        if (i + 1) % 100 == 0:
            print(f" Processed {i + 1}/{len(all_samples)} images...")

    print(f" Successfully converted: {len(converted)} images")
    print(f"\n Class distribution (before augmentation):")
    for name, count in class_counts.items():
        print(f" {name}: {count} instances")

    # ─── Phase 2: Split the ORIGINAL images into train/val/test ──────
    # Splitting before augmenting keeps every augmented copy in the same
    # split as its source image and leaves val/test free of synthetic data.
    print("\n[3/5] Splitting original images into train/val/test (60/20/20)...")
    random.shuffle(converted)
    n = len(converted)
    n_train = int(n * TRAIN_RATIO)
    n_val = int(n * VAL_RATIO)
    splits = {
        'train': converted[:n_train],
        'val': converted[n_train:n_train + n_val],
        'test': converted[n_train + n_val:],
    }

    # ─── Phase 3: Identify minority classes & augment train only ─────
    print("\n[4/5] Augmenting minority classes (train split only)...")
    total_instances = sum(class_counts.values())
    mean_count = total_instances / len(CLASS_NAMES)

    # A class is "minority" when it has fewer than 50% of the mean
    # per-class instance count (counted over all converted images).
    minority_classes = set()
    for name, count in class_counts.items():
        if count < mean_count * 0.5:
            minority_classes.add(CLASS_NAMES.index(name))
            print(f" Minority class: {name} ({count} instances)")

    augmented_samples = []
    for sample in splits['train']:
        if sample['img_classes'] & minority_classes:
            augmented_samples.extend(_augment_sample(sample))
    splits['train'] = splits['train'] + augmented_samples

    total_images = sum(len(split_data) for split_data in splits.values())
    print(f" Original images: {len(converted)}")
    print(f" Augmented images: {len(augmented_samples)}")
    print(f" Total images: {total_images}")
    for split_name, split_data in splits.items():
        print(f" {split_name}: {len(split_data)} images")

    # ─── Phase 4: Save everything ─────────────────────────────────────
    print("\n[5/5] Saving images and labels...")
    split_class_counts = {s: {n: 0 for n in CLASS_NAMES} for s in ['train', 'val', 'test']}
    for split_name, split_data in splits.items():
        for i, sample in enumerate(split_data):
            img_name = f"{split_name}_{i:05d}"
            save_image_and_label(sample['image'], sample['labels'], img_name, split_name)
            for lbl in sample['labels']:
                cls_id = int(lbl.split()[0])
                split_class_counts[split_name][CLASS_NAMES[cls_id]] += 1
            if (i + 1) % 200 == 0:
                print(f" [{split_name}] Saved {i + 1}/{len(split_data)}")

    # Print final statistics
    print("\n" + "=" * 60)
    print("FINAL DATASET STATISTICS")
    print("=" * 60)
    for split_name in ['train', 'val', 'test']:
        print(f"\n {split_name.upper()}:")
        for cls_name, count in split_class_counts[split_name].items():
            print(f" {cls_name}: {count} instances")

    # ─── Create dataset YAML ─────────────────────────────────────────
    yaml_content = (
        "# DAWN Dataset - Vehicle Detection in Adverse Weather\n"
        f"path: {DATASET_ROOT}\n"
        "train: images/train\n"
        "val: images/val\n"
        "test: images/test\n"
        f"nc: {len(CLASS_NAMES)}\n"
        f"names: {CLASS_NAMES}\n"
    )
    yaml_path = f"{DATASET_ROOT}/dataset.yaml"
    with open(yaml_path, 'w') as f:
        f.write(yaml_content)
    print(f"\n Dataset YAML saved to: {yaml_path}")

    # Save metadata
    metadata = {
        'total_images': total_images,
        'original_images': len(converted),
        'augmented_images': len(augmented_samples),
        'splits': {s: len(d) for s, d in splits.items()},
        'class_names': CLASS_NAMES,
        'class_counts': {s: split_class_counts[s] for s in ['train', 'val', 'test']},
    }
    with open(f"{DATASET_ROOT}/metadata.json", 'w') as f:
        json.dump(metadata, f, indent=2)

    print("\n✅ Dataset preparation complete!")
    print(f" Root: {DATASET_ROOT}")
    return metadata
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()