File size: 11,410 Bytes
99f86f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
"""
DAWN Dataset Preparation Script
Downloads from HuggingFace Hub, converts to YOLO format, applies augmentation
for minority classes, and creates train/val/test splits.
"""
import os
import json
import random
import shutil
import numpy as np
from pathlib import Path
from datasets import load_dataset
from PIL import Image, ImageOps

# ─── Configuration ───────────────────────────────────────────────────
# Root folder under which the YOLO images/, labels/ and metadata are written.
DATASET_ROOT = "/app/dawn_dataset"

# Seed applied to both the stdlib and NumPy random generators below.
SEED = 42

# Fractions of the data assigned to each split.
TRAIN_RATIO, VAL_RATIO, TEST_RATIO = 0.60, 0.20, 0.20

# Output class list; a name's position in this list is its YOLO class id.
CLASS_NAMES = ['Bicycle', 'Bus', 'Car', 'Motorcycle', 'Pedestrian', 'Truck']

# Annotation class_name -> YOLO class id.
# 'Person' and 'Cyclist' are folded into the 'Pedestrian' id.
CLASS_MAP = dict(
    Bicycle=0,
    Bus=1,
    Car=2,
    Motorcycle=3,
    Pedestrian=4, Person=4, Cyclist=4,
    Truck=5,
)

# Seed both generators so shuffling is reproducible across runs.
np.random.seed(SEED)
random.seed(SEED)


def setup_dirs():
    """Ensure the images/ and labels/ folders exist for every split.

    Creates ``images/<split>`` and ``labels/<split>`` under DATASET_ROOT for
    the train, val and test splits. Existing directories are left untouched.
    """
    for split in ('train', 'val', 'test'):
        image_dir = f"{DATASET_ROOT}/images/{split}"
        label_dir = f"{DATASET_ROOT}/labels/{split}"
        for directory in (image_dir, label_dir):
            os.makedirs(directory, exist_ok=True)


def convert_to_yolo(objects, img_w, img_h, class_map=None):
    """Convert absolute-pixel bbox annotations to YOLO label lines.

    Each box is first clamped to the image rectangle in pixel space and only
    then converted to a normalized centre/size. Clamping the centre and size
    independently (as an earlier version did) leaves boxes that overhang the
    image edge partially outside [0, 1].

    Args:
        objects: Iterable of mappings with the keys 'class_name', 'x_min',
            'y_min', 'width' and 'height' (top-left corner plus size, in
            pixels).
        img_w: Image width in pixels.
        img_h: Image height in pixels.
        class_map: Optional mapping of class name -> class id. Defaults to
            the module-level CLASS_MAP.

    Returns:
        A list of strings ``"<cls_id> <cx> <cy> <w> <h>"`` with coordinates
        normalized to [0, 1] and formatted to six decimals. Objects with an
        unknown class or a (post-clamping) degenerate box are omitted. An
        empty list is returned when the image size is not positive.
    """
    if img_w <= 0 or img_h <= 0:
        # Nothing can be normalized against an empty image.
        print(f"  WARNING: Invalid image size {img_w}x{img_h}, skipping")
        return []

    if class_map is None:
        class_map = CLASS_MAP

    labels = []
    for obj in objects:
        cls_name = obj['class_name']
        if cls_name not in class_map:
            print(f"  WARNING: Unknown class '{cls_name}', skipping")
            continue
        cls_id = class_map[cls_name]

        # Clamp both corners to the image so the resulting box lies fully
        # inside the [0, 1] range once normalized.
        x1 = min(max(obj['x_min'], 0), img_w)
        y1 = min(max(obj['y_min'], 0), img_h)
        x2 = min(max(obj['x_min'] + obj['width'], 0), img_w)
        y2 = min(max(obj['y_min'] + obj['height'], 0), img_h)

        nw = (x2 - x1) / img_w
        nh = (y2 - y1) / img_h

        if nw > 0.001 and nh > 0.001:  # skip degenerate boxes
            cx = (x1 + x2) / 2 / img_w
            cy = (y1 + y2) / 2 / img_h
            labels.append(f"{cls_id} {cx:.6f} {cy:.6f} {nw:.6f} {nh:.6f}")

    return labels


def save_image_and_label(image, labels, img_name, split):
    """Write one image as JPEG and its YOLO labels as a text file.

    Args:
        image: The image to store. PIL images are saved with quality 95;
            any other object is saved through its own ``save(path)`` method.
        labels: List of YOLO label lines; written newline-separated.
        img_name: File stem shared by the image and label file.
        split: Split sub-folder name ('train', 'val' or 'test').

    The files go to ``images/<split>/<img_name>.jpg`` and
    ``labels/<split>/<img_name>.txt`` under DATASET_ROOT.
    """
    img_path = f"{DATASET_ROOT}/images/{split}/{img_name}.jpg"
    lbl_path = f"{DATASET_ROOT}/labels/{split}/{img_name}.txt"

    if isinstance(image, Image.Image):
        # Pillow's JPEG writer rejects modes such as RGBA or P with an
        # OSError, so convert anything that is not plain L/RGB/CMYK first.
        if image.mode not in ("L", "RGB", "CMYK"):
            image = image.convert("RGB")
        image.save(img_path, quality=95)
    else:
        image.save(img_path)

    with open(lbl_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(labels))


def augment_mirror(image, labels_raw, img_w, img_h):
    """Mirror an image left-to-right and adjust its YOLO labels to match.

    Args:
        image: Image passed to ``ImageOps.mirror``.
        labels_raw: YOLO label lines ``"<cls> <cx> <cy> <w> <h>"``.
        img_w: Unused; labels are already normalized.
        img_h: Unused; labels are already normalized.

    Returns:
        ``(mirrored_image, new_labels)`` where each label keeps its class,
        cy, w and h, and has cx replaced by ``1 - cx``.
    """
    def _flip_line(line):
        # Only the horizontal centre changes under a left-right flip.
        fields = line.split()
        cx, cy, box_w, box_h = (float(fields[k]) for k in range(1, 5))
        return f"{fields[0]} {1.0 - cx:.6f} {cy:.6f} {box_w:.6f} {box_h:.6f}"

    mirrored = ImageOps.mirror(image)
    return mirrored, [_flip_line(line) for line in labels_raw]


def augment_rotate(image, labels_raw, img_w, img_h, angle=90):
    """Rotate an image by 90, 180 or 270 degrees and remap its YOLO labels.

    Args:
        image: Image to rotate via ``image.transpose``.
        labels_raw: YOLO label lines ``"<cls> <cx> <cy> <w> <h>"``.
        img_w: Unused; labels are already normalized.
        img_h: Unused; labels are already normalized.
        angle: Rotation in degrees. Only 90, 180 and 270 are handled.

    Returns:
        ``(rotated_image, new_labels)``. For any other angle the inputs are
        handed back unchanged.
    """
    # Unsupported angles: pass everything through untouched.
    if angle not in (90, 180, 270):
        return image, labels_raw

    def _remap(cx, cy, box_w, box_h):
        # Quarter turns swap the box's width and height; a half turn only
        # reflects the centre through the middle of the image.
        if angle == 90:
            return cy, 1.0 - cx, box_h, box_w
        if angle == 180:
            return 1.0 - cx, 1.0 - cy, box_w, box_h
        return 1.0 - cy, cx, box_h, box_w

    transpose_ops = {
        90: Image.ROTATE_90,
        180: Image.ROTATE_180,
        270: Image.ROTATE_270,
    }
    rotated = image.transpose(transpose_ops[angle])

    new_labels = []
    for line in labels_raw:
        fields = line.split()
        cls_id = fields[0]
        values = (float(fields[k]) for k in range(1, 5))
        new_cx, new_cy, new_w, new_h = _remap(*values)
        new_labels.append(f"{cls_id} {new_cx:.6f} {new_cy:.6f} {new_w:.6f} {new_h:.6f}")

    return rotated, new_labels


def main():
    """Run the DAWN preparation pipeline and return summary metadata.

    Steps:
      1. Load the dataset from the HuggingFace Hub and pool its 'train' and
         'val' splits.
      2. Convert each sample's annotations to YOLO label lines; samples whose
         image is not a PIL image or that end up with no labels are dropped.
      3. Shuffle the original images and split them 60/20/20.
      4. Augment (mirror, 90° and 180° rotation) the *training* images that
         contain a minority class. Augmenting only after the split keeps
         derived copies of an image out of val/test, which would otherwise
         leak training content into the evaluation sets.
      5. Save images and labels, then write dataset.yaml and metadata.json.

    Returns:
        dict with the keys 'total_images', 'original_images',
        'augmented_images', 'splits', 'class_names' and 'class_counts'.
    """
    print("=" * 60)
    print("DAWN Dataset Preparation Pipeline")
    print("=" * 60)

    setup_dirs()

    # Load dataset from HF Hub
    print("\n[1/5] Loading DAWN dataset from HuggingFace Hub...")
    ds = load_dataset("Maxim37/dawn-dataset")
    print(f"  Train split: {len(ds['train'])} images")
    print(f"  Val split: {len(ds['val'])} images")

    # Pool the upstream splits; they are re-split below.
    all_samples = []
    for split_name in ['train', 'val']:
        all_samples.extend(ds[split_name])

    print(f"  Total images: {len(all_samples)}")

    # ─── Phase 1: Convert all images to YOLO format ──────────────────
    print("\n[2/5] Converting annotations to YOLO format...")
    converted = []
    class_counts = {name: 0 for name in CLASS_NAMES}

    for i, sample in enumerate(all_samples):
        img = sample['image']
        if not isinstance(img, Image.Image):
            continue

        img_w = sample['width']
        img_h = sample['height']
        labels = convert_to_yolo(sample['objects'], img_w, img_h)

        # Images without any usable annotation are left out entirely.
        if not labels:
            continue

        # Count instances per class and remember which classes appear.
        img_classes = set()
        for lbl in labels:
            cls_id = int(lbl.split()[0])
            class_counts[CLASS_NAMES[cls_id]] += 1
            img_classes.add(cls_id)

        converted.append({
            'image': img,
            'labels': labels,
            'image_id': sample['image_id'],
            'img_classes': img_classes,
            'img_w': img_w,
            'img_h': img_h,
        })

        if (i + 1) % 100 == 0:
            print(f"  Processed {i + 1}/{len(all_samples)} images...")

    print(f"  Successfully converted: {len(converted)} images")
    print(f"\n  Class distribution (before augmentation):")
    for name, count in class_counts.items():
        print(f"    {name}: {count} instances")

    # ─── Phase 2: Split the original images into train/val/test ──────
    print("\n[3/5] Splitting into train/val/test (60/20/20)...")
    random.shuffle(converted)

    n = len(converted)
    n_train = int(n * TRAIN_RATIO)
    n_val = int(n * VAL_RATIO)

    # The test split takes whatever remains after train and val.
    splits = {
        'train': converted[:n_train],
        'val': converted[n_train:n_train + n_val],
        'test': converted[n_train + n_val:],
    }

    # ─── Phase 3: Identify minority classes & augment train only ─────
    print("\n[4/5] Augmenting minority classes (train split only)...")
    mean_count = sum(class_counts.values()) / len(CLASS_NAMES)

    # A class is a minority when it has fewer than 50% of the mean
    # per-class instance count.
    minority_classes = set()
    for cls_id, name in enumerate(CLASS_NAMES):
        count = class_counts[name]
        if count < mean_count * 0.5:
            minority_classes.add(cls_id)
            print(f"  Minority class: {name} ({count} instances)")

    augmented_samples = []
    for sample in splits['train']:
        if not (sample['img_classes'] & minority_classes):
            continue

        img = sample['image']
        labels = sample['labels']
        img_w = sample['img_w']
        img_h = sample['img_h']
        base_id = sample['image_id']

        # Mirror augmentation
        mir_img, mir_labels = augment_mirror(img, labels, img_w, img_h)
        augmented_samples.append({
            'image': mir_img,
            'labels': mir_labels,
            'image_id': f"{base_id}_mirror",
            'img_classes': sample['img_classes'],
        })

        # Rotation augmentations (90° and 180°)
        for angle in [90, 180]:
            rot_img, rot_labels = augment_rotate(img, labels, img_w, img_h, angle)
            augmented_samples.append({
                'image': rot_img,
                'labels': rot_labels,
                'image_id': f"{base_id}_rot{angle}",
                'img_classes': sample['img_classes'],
            })

    splits['train'] = splits['train'] + augmented_samples
    total_images = len(converted) + len(augmented_samples)

    print(f"  Original images: {len(converted)}")
    print(f"  Augmented images: {len(augmented_samples)}")
    print(f"  Total images: {total_images}")

    for split_name, split_data in splits.items():
        print(f"  {split_name}: {len(split_data)} images")

    # ─── Phase 4: Save everything ─────────────────────────────────────
    print("\n[5/5] Saving images and labels...")
    split_class_counts = {s: {n: 0 for n in CLASS_NAMES} for s in ['train', 'val', 'test']}

    for split_name, split_data in splits.items():
        for i, sample in enumerate(split_data):
            img_name = f"{split_name}_{i:05d}"
            save_image_and_label(sample['image'], sample['labels'], img_name, split_name)

            for lbl in sample['labels']:
                cls_id = int(lbl.split()[0])
                split_class_counts[split_name][CLASS_NAMES[cls_id]] += 1

            if (i + 1) % 200 == 0:
                print(f"  [{split_name}] Saved {i + 1}/{len(split_data)}")

    # Print final statistics
    print("\n" + "=" * 60)
    print("FINAL DATASET STATISTICS")
    print("=" * 60)
    for split_name in ['train', 'val', 'test']:
        print(f"\n  {split_name.upper()}:")
        for cls_name, count in split_class_counts[split_name].items():
            print(f"    {cls_name}: {count} instances")

    # ─── Create dataset YAML ─────────────────────────────────────────
    yaml_content = f"""# DAWN Dataset - Vehicle Detection in Adverse Weather
path: {DATASET_ROOT}
train: images/train
val: images/val
test: images/test

nc: {len(CLASS_NAMES)}
names: {CLASS_NAMES}
"""
    yaml_path = f"{DATASET_ROOT}/dataset.yaml"
    with open(yaml_path, 'w', encoding='utf-8') as f:
        f.write(yaml_content)

    print(f"\n  Dataset YAML saved to: {yaml_path}")

    # Save metadata
    metadata = {
        'total_images': total_images,
        'original_images': len(converted),
        'augmented_images': len(augmented_samples),
        'splits': {s: len(d) for s, d in splits.items()},
        'class_names': CLASS_NAMES,
        'class_counts': {s: split_class_counts[s] for s in ['train', 'val', 'test']},
    }
    with open(f"{DATASET_ROOT}/metadata.json", 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print("\nβœ… Dataset preparation complete!")
    print(f"  Root: {DATASET_ROOT}")
    return metadata


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()