cledouxluma committed
Commit 550b1d5 · verified · 1 Parent(s): 6953619

Upload data/widerface.py with huggingface_hub

Files changed (1)
  1. data/widerface.py +205 -0
data/widerface.py ADDED
@@ -0,0 +1,205 @@
"""
WiderFace Dataset Loader.

WIDER FACE (Yang et al., 2016):
- 32,203 images, 393,703 annotated face bounding boxes
- Split: 40% train (12,880), 10% val (3,226), 50% test (labels not public)
- 3 difficulty levels: Easy, Medium, Hard
- Annotations include: bbox, blur, expression, illumination, occlusion, pose, invalid

Expected directory structure:

wider_face/
├── WIDER_train/
│   └── images/
│       ├── 0--Parade/
│       ├── 1--Handshaking/
│       └── ...
├── WIDER_val/
│   └── images/
│       └── ...
├── wider_face_split/
│   ├── wider_face_train_bbx_gt.txt
│   ├── wider_face_val_bbx_gt.txt
│   └── ...
└── retinaface_gt/  (optional, for landmarks)
    ├── train/
    │   └── label.txt
    └── val/
        └── label.txt
"""
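
# Illustrative excerpt from wider_face_split/wider_face_*_bbx_gt.txt (layout
# per the official split readme; the example values are for illustration
# only). Each image contributes a relative path, a face count, and one line
# per face:
#
#   0--Parade/0_Parade_marchingband_1_849.jpg
#   1
#   449 330 122 149 0 0 0 0 0 0
#
# The 10 per-face values are, in order:
#   x1 y1 w h blur expression illumination invalid occlusion pose
# Images with 0 faces are followed by a single all-zero placeholder line.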

import os
import numpy as np
import cv2
from typing import List, Dict, Optional, Tuple, Callable
import torch
from torch.utils.data import Dataset


class WiderFaceDataset(Dataset):
    """
    WIDER FACE dataset with support for:
    - Standard WiderFace bbox annotations
    - RetinaFace-format 5-point landmark annotations
    - Filtering invalid/tiny faces
    - On-the-fly augmentation
    """

    def __init__(self,
                 root_dir: str,
                 split: str = 'train',
                 transform: Optional[Callable] = None,
                 min_face_size: int = 2,
                 use_landmarks: bool = False,
                 annotation_format: str = 'widerface'):
        """
        Args:
            root_dir: Path to wider_face/ directory
            split: 'train' or 'val'
            transform: Augmentation callable
            min_face_size: Minimum face width/height to keep (pixels)
            use_landmarks: Load 5-point landmarks (requires retinaface_gt/)
            annotation_format: 'widerface' (standard) or 'retinaface' (with landmarks)
        """
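        # NOTE: the transform contract is inferred from __getitem__ below, not
        # stated elsewhere: transform(image, boxes, landmarks) is assumed to
        # return a dict with 'image', 'boxes', and 'landmarks' keys, where
        # image is an HxWx3 RGB array and boxes/landmarks use pixel coordinates.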
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.min_face_size = min_face_size
        self.use_landmarks = use_landmarks

        if annotation_format == 'retinaface' and use_landmarks:
            self.samples = self._load_retinaface_annotations()
        else:
            self.samples = self._load_widerface_annotations()

        print(f"[WiderFace {split}] Loaded {len(self.samples)} images")

    def _load_widerface_annotations(self) -> List[Dict]:
        """Load standard WiderFace bbox annotations."""
        ann_file = os.path.join(
            self.root_dir, 'wider_face_split',
            f'wider_face_{self.split}_bbx_gt.txt'
        )
        img_dir = os.path.join(self.root_dir, f'WIDER_{self.split}', 'images')

        samples = []
        with open(ann_file, 'r') as f:
            while True:
                filename = f.readline().strip()
                if not filename:
                    break

                num_faces = int(f.readline().strip())
                boxes = []
                # 0-face images still carry one placeholder line, so always
                # consume at least one annotation line.
                for _ in range(max(num_faces, 1)):
                    line = f.readline().strip()
                    if num_faces == 0:
                        continue  # Skip placeholder line for 0-face images
                    parts = list(map(float, line.split()))
                    x, y, w, h = parts[0], parts[1], parts[2], parts[3]
                    # Filter tiny faces
                    if w < self.min_face_size or h < self.min_face_size:
                        continue
                    # Filter faces flagged invalid (8th annotation field per
                    # the split readme), when the attribute flags are present
                    if len(parts) >= 8 and parts[7] == 1:
                        continue
                    # Convert (x, y, w, h) to (x1, y1, x2, y2)
                    boxes.append([x, y, x + w, y + h])

                if boxes:
                    samples.append({
                        'image_path': os.path.join(img_dir, filename),
                        'boxes': np.array(boxes, dtype=np.float32),
                        'filename': filename,
                    })

        return samples

    def _load_retinaface_annotations(self) -> List[Dict]:
        """Load RetinaFace-format annotations with 5-point landmarks."""
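        # Illustrative label.txt layout (insightface RetinaFace ground truth;
        # the values shown are illustrative): a '# <path>' line per image,
        # then one line per face with 4 bbox values followed, when annotated,
        # by 5 landmarks stored as (x, y, flag) triplets and a trailing score;
        # -1 marks missing landmarks:
        #
        #   # 0--Parade/0_Parade_marchingband_1_849.jpg
        #   449 330 122 149 488.9 373.6 0.0 542.1 376.4 0.0 515.0 412.8 0.0 485.9 423.1 0.0 538.4 431.4 0.0 0.82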
        ann_file = os.path.join(
            self.root_dir, 'retinaface_gt', self.split, 'label.txt'
        )
        img_dir = os.path.join(self.root_dir, f'WIDER_{self.split}', 'images')

        samples = []
        current_file = None
        current_boxes = []
        current_lmks = []

        with open(ann_file, 'r') as f:
            for line in f:
                line = line.strip()
                if line.startswith('#'):
                    # Save the previous image before starting a new one
                    if current_file and current_boxes:
                        samples.append({
                            'image_path': os.path.join(img_dir, current_file),
                            'boxes': np.array(current_boxes, dtype=np.float32),
                            'landmarks': np.array(current_lmks, dtype=np.float32),
                            'filename': current_file,
                        })
                    current_file = line[1:].strip()
                    current_boxes = []
                    current_lmks = []
                else:
                    parts = list(map(float, line.split()))
                    if len(parts) >= 4:
                        x, y, w, h = parts[0], parts[1], parts[2], parts[3]
                        if w < self.min_face_size or h < self.min_face_size:
                            continue
                        current_boxes.append([x, y, x + w, y + h])
                        if len(parts) >= 19 and parts[4] >= 0:
                            # 5 landmarks stored as (x, y, flag) triplets;
                            # keep only the (x, y) pairs
                            lmk = [parts[4 + 3 * i + j] for i in range(5) for j in range(2)]
                            current_lmks.append(lmk)
                        else:
                            current_lmks.append([-1.0] * 10)  # Missing landmarks

        # Save the last image
        if current_file and current_boxes:
            samples.append({
                'image_path': os.path.join(img_dir, current_file),
                'boxes': np.array(current_boxes, dtype=np.float32),
                'landmarks': np.array(current_lmks, dtype=np.float32),
                'filename': current_file,
            })

        return samples

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict]:
        sample = self.samples[idx]

        # Load image and convert BGR -> RGB
        img = cv2.imread(sample['image_path'])
        if img is None:
            raise IOError(f"Failed to load image: {sample['image_path']}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        boxes = sample['boxes'].copy()
        landmarks = sample.get(
            'landmarks', np.zeros((boxes.shape[0], 10), dtype=np.float32)
        ).copy()

        # Apply augmentation
        if self.transform:
            result = self.transform(img, boxes, landmarks)
            img, boxes, landmarks = result['image'], result['boxes'], result['landmarks']

        # Convert to tensors (HWC -> CHW)
        img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float()
        boxes_tensor = torch.from_numpy(boxes).float()

        target = {
            'boxes': boxes_tensor,
            # Single foreground class: every box is a face
            'labels': torch.ones(boxes_tensor.shape[0], dtype=torch.long),
        }
        if self.use_landmarks:
            target['landmarks'] = torch.from_numpy(landmarks).float()

        return img_tensor, target

    @staticmethod
    def collate_fn(batch):
        """Custom collate for variable-length targets.

        torch.stack requires every image in a batch to share one shape, so the
        transform must resize/pad to a fixed size (or use batch_size=1).
        """
        images = torch.stack([item[0] for item in batch])
        targets = [item[1] for item in batch]
        return images, targets
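

# ---------------------------------------------------------------------------
# Minimal usage sketch: the root_dir below is a placeholder path, and since no
# resizing transform is passed, batch_size=1 is used because collate_fn stacks
# images and raw WIDER FACE images vary in size.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = WiderFaceDataset(root_dir='data/wider_face', split='val')
    loader = DataLoader(
        dataset,
        batch_size=1,  # raw images are not resized, so keep batches size-1
        shuffle=False,
        collate_fn=WiderFaceDataset.collate_fn,
    )

    images, targets = next(iter(loader))
    print(images.shape)                # e.g. torch.Size([1, 3, H, W])
    print(targets[0]['boxes'].shape)   # (num_faces, 4) in x1, y1, x2, y2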