"""
WiderFace Dataset Loader.

WIDER FACE (Yang et al., 2016):
- 32,203 images, 393,703 annotated face bounding boxes
- Split: 40% train (12,880), 10% val (3,226), 50% test (labels not public)
- 3 difficulty levels: Easy, Medium, Hard
- Annotations include: bbox, blur, expression, illumination, occlusion, pose, invalid

Directory structure expected:
    wider_face/
    β”œβ”€β”€ WIDER_train/
    β”‚   └── images/
    β”‚       β”œβ”€β”€ 0--Parade/
    β”‚       β”œβ”€β”€ 1--Handshaking/
    β”‚       └── ...
    β”œβ”€β”€ WIDER_val/
    β”‚   └── images/
    β”‚       └── ...
    β”œβ”€β”€ wider_face_split/
    β”‚   β”œβ”€β”€ wider_face_train_bbx_gt.txt
    β”‚   β”œβ”€β”€ wider_face_val_bbx_gt.txt
    β”‚   └── ...
    └── retinaface_gt/  (optional, for landmarks)
        β”œβ”€β”€ train/
        β”‚   └── label.txt
        └── val/
            └── label.txt
"""

import os
import numpy as np
import cv2
from typing import List, Dict, Optional, Tuple, Callable
import torch
from torch.utils.data import Dataset


class WiderFaceDataset(Dataset):
    """
    WIDER FACE dataset with support for:
    - Standard WiderFace bbox annotations
    - RetinaFace-format 5-point landmark annotations
    - Filtering invalid/tiny faces
    - On-the-fly augmentation
    """

    def __init__(self,
                 root_dir: str,
                 split: str = 'train',
                 transform: Optional[Callable] = None,
                 min_face_size: int = 2,
                 use_landmarks: bool = False,
                 annotation_format: str = 'widerface'):
        """
        Args:
            root_dir: Path to wider_face/ directory
            split: 'train' or 'val'
            transform: Augmentation callable
            min_face_size: Minimum face size to keep (pixels)
            use_landmarks: Load 5-point landmarks (requires retinaface_gt/)
            annotation_format: 'widerface' (standard) or 'retinaface' (with landmarks)
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.min_face_size = min_face_size
        self.use_landmarks = use_landmarks

        if annotation_format == 'retinaface' and use_landmarks:
            self.samples = self._load_retinaface_annotations()
        else:
            self.samples = self._load_widerface_annotations()

        print(f"[WiderFace {split}] Loaded {len(self.samples)} images")

    def _load_widerface_annotations(self) -> List[Dict]:
        """Load standard WiderFace bbox annotations."""
        ann_file = os.path.join(
            self.root_dir, 'wider_face_split',
            f'wider_face_{self.split}_bbx_gt.txt'
        )
        img_dir = os.path.join(self.root_dir, f'WIDER_{self.split}', 'images')

        samples = []
        with open(ann_file, 'r') as f:
            while True:
                filename = f.readline().strip()
                if not filename:
                    break

                num_faces = int(f.readline().strip())
                boxes = []
                for _ in range(max(num_faces, 1)):
                    line = f.readline().strip()
                    if num_faces == 0:
                        continue  # Skip placeholder line for 0-face images
                    parts = list(map(float, line.split()))
                    x, y, w, h = parts[0], parts[1], parts[2], parts[3]
                    # Filter tiny faces and faces whose 'invalid' flag
                    # (8th annotation column) is set
                    if w < self.min_face_size or h < self.min_face_size:
                        continue
                    if len(parts) >= 8 and parts[7] == 1:
                        continue
                    # Convert to x1, y1, x2, y2
                    boxes.append([x, y, x + w, y + h])

                if boxes:
                    samples.append({
                        'image_path': os.path.join(img_dir, filename),
                        'boxes': np.array(boxes, dtype=np.float32),
                        'filename': filename,
                    })

        return samples

    def _load_retinaface_annotations(self) -> List[Dict]:
        """Load RetinaFace-format annotations with 5-point landmarks."""
        ann_file = os.path.join(
            self.root_dir, 'retinaface_gt', self.split, 'label.txt'
        )
        img_dir = os.path.join(self.root_dir, f'WIDER_{self.split}', 'images')

        samples = []
        current_file = None
        current_boxes = []
        current_lmks = []

        with open(ann_file, 'r') as f:
            for line in f:
                line = line.strip()
                if line.startswith('#'):
                    # Save previous image
                    if current_file and current_boxes:
                        samples.append({
                            'image_path': os.path.join(img_dir, current_file),
                            'boxes': np.array(current_boxes, dtype=np.float32),
                            'landmarks': np.array(current_lmks, dtype=np.float32),
                            'filename': current_file,
                        })
                    current_file = line[1:].strip()  # Drop leading '#'
                    current_boxes = []
                    current_lmks = []
                else:
                    parts = list(map(float, line.split()))
                    if len(parts) >= 4:
                        x, y, w, h = parts[0], parts[1], parts[2], parts[3]
                        if w < self.min_face_size or h < self.min_face_size:
                            continue
                        current_boxes.append([x, y, x + w, y + h])
                        if len(parts) >= 19:
                            # label.txt stores 5 landmarks as (x, y, flag)
                            # triplets: keep each x, y and drop the flag
                            lmk = [parts[4 + 3 * i + j]
                                   for i in range(5) for j in range(2)]
                            current_lmks.append(lmk)
                        elif len(parts) >= 14:
                            # Plain layout: (x1,y1, x2,y2, ..., x5,y5)
                            current_lmks.append(parts[4:14])
                        else:
                            current_lmks.append([-1.0] * 10)  # No landmarks

        # Save last image
        if current_file and current_boxes:
            samples.append({
                'image_path': os.path.join(img_dir, current_file),
                'boxes': np.array(current_boxes, dtype=np.float32),
                'landmarks': np.array(current_lmks, dtype=np.float32),
                'filename': current_file,
            })

        return samples

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict]:
        sample = self.samples[idx]

        # Load image
        img = cv2.imread(sample['image_path'])
        if img is None:
            raise IOError(f"Failed to load image: {sample['image_path']}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        boxes = sample['boxes'].copy()
        landmarks = sample.get('landmarks', np.zeros((boxes.shape[0], 10), dtype=np.float32)).copy()

        # Apply augmentation; the transform is expected to return a dict
        # with 'image', 'boxes', and 'landmarks' keys
        if self.transform:
            result = self.transform(img, boxes, landmarks)
            img, boxes, landmarks = result['image'], result['boxes'], result['landmarks']

        # Convert to tensors
        img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float()
        boxes_tensor = torch.from_numpy(boxes).float()

        target = {
            'boxes': boxes_tensor,
            # Single foreground class: every box is a face (label 1)
            'labels': torch.ones(boxes_tensor.shape[0], dtype=torch.long),
        }
        if self.use_landmarks:
            target['landmarks'] = torch.from_numpy(landmarks).float()

        return img_tensor, target

    @staticmethod
    def collate_fn(batch):
        """Custom collate for variable-length targets."""
        images = torch.stack([item[0] for item in batch])
        targets = [item[1] for item in batch]
        return images, targets
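

if __name__ == '__main__':
    # Minimal smoke test: a sketch assuming the dataset is unpacked at
    # ./wider_face/ as in the module docstring. collate_fn stacks images,
    # so a bare-bones 640x640 resize transform is defined for illustration.
    from torch.utils.data import DataLoader

    def resize_640(img, boxes, landmarks):
        h, w = img.shape[:2]
        img = cv2.resize(img, (640, 640))
        scale = np.array([640.0 / w, 640.0 / h], dtype=np.float32)
        boxes = boxes * np.tile(scale, 2)          # scale (x1, y1, x2, y2)
        landmarks = landmarks * np.tile(scale, 5)  # scale five (x, y) pairs
        return {'image': img, 'boxes': boxes, 'landmarks': landmarks}

    dataset = WiderFaceDataset('wider_face', split='val', transform=resize_640)
    loader = DataLoader(dataset, batch_size=4,
                        collate_fn=WiderFaceDataset.collate_fn)
    images, targets = next(iter(loader))
    print(images.shape, [t['boxes'].shape[0] for t in targets])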