BcantCode committed
Commit b39769f · verified · 1 Parent(s): 94cb2c0

Upload models/dataset.py

Files changed (1)
  1. models/dataset.py +374 -0
models/dataset.py ADDED
@@ -0,0 +1,374 @@
"""
PriviGaze Dataset - Synthetic Gaze Dataset Generator and MPIIGaze Loader

Since gaze datasets are not readily available on the HF Hub, this module provides:
1. A synthetic gaze dataset generator using simple UnityEyes-style rendering
2. An MPIIGaze dataset loader (if the dataset is available locally)

The synthetic generator creates realistic face/eye crops with known gaze vectors,
enabling the teacher-student distillation pipeline to be tested end-to-end.
"""

import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageFilter, ImageOps, ImageEnhance
import json
from pathlib import Path
from typing import Optional, Tuple, Dict, List


class SyntheticGazeDataset(Dataset):
    """Generates synthetic eye/face crops with known gaze vectors.

    Creates simple but realistic eye and face patterns where the gaze direction
    is encoded in the relative positions of pupil and iris within the eye crop.

    This allows end-to-end testing and training of the gaze estimation pipeline
    when real gaze datasets are not available.

    Each sample includes:
    - left_eye: [3, 112, 112] simulated eye with pupil position encoding gaze
    - right_eye: [3, 112, 112]
    - face_blurred_gray: [1, 224, 224] blurred grayscale face (teacher input)
    - face_gray: [1, 224, 224] light-corrected grayscale face (for student)
    - pitch: float (degrees, sampled uniformly in [-60, 60])
    - yaw: float (degrees, sampled uniformly in [-60, 60])
    """

    def __init__(
        self,
        num_samples: int = 50000,
        img_size_eye: int = 112,
        img_size_face: int = 224,
        seed: int = 42,
        noise_level: float = 0.1,
    ):
        self.num_samples = num_samples
        self.img_size_eye = img_size_eye
        self.img_size_face = img_size_face
        self.noise_level = noise_level

        # Generate all gaze angles upfront
        rng = np.random.RandomState(seed)
        self.pitch_angles = rng.uniform(-60, 60, num_samples).astype(np.float32)
        self.yaw_angles = rng.uniform(-60, 60, num_samples).astype(np.float32)

        # Generate random iris and skin colors
        self.iris_colors = rng.uniform(0.3, 0.9, (num_samples, 3)).astype(np.float32)
        self.skin_colors = rng.uniform(0.4, 0.9, (num_samples, 3)).astype(np.float32)

    def __len__(self):
        return self.num_samples

    def _generate_eye(self, pitch: float, yaw: float, iris_color: np.ndarray,
                      eye_idx: int = 0) -> Image.Image:
        """Generate a synthetic eye image with pupil position encoding gaze.

        Args:
            pitch: gaze pitch angle in degrees
            yaw: gaze yaw angle in degrees
            iris_color: [3] RGB iris color
            eye_idx: 0 for left eye, 1 for right eye (currently unused)

        Returns:
            PIL Image of size (img_size_eye, img_size_eye)
        """
        size = self.img_size_eye
        img = np.ones((size, size, 3), dtype=np.float32) * 0.95  # White background (sclera)

        # Eye oval (sclera boundary)
        center_y, center_x = size // 2, size // 2
        y_grid, x_grid = np.ogrid[:size, :size]

        # Eye shape: oval
        eye_mask = ((x_grid - center_x) ** 2 / (size * 0.35) ** 2 +
                    (y_grid - center_y) ** 2 / (size * 0.25) ** 2) <= 1.0

        # Add slight skin around eye
        skin_mask = ~eye_mask
        skin_color = np.array([0.85, 0.7, 0.6])  # Default skin tone
        img[skin_mask] = skin_color * 0.9 + np.random.randn(size, size)[..., None][skin_mask] * 0.02

        # Iris circle
        iris_radius = size * 0.18

        # Pupil position: yaw moves left/right, pitch moves up/down
        # Scale: max displacement keeps the iris within the eye oval
        max_displacement = size * 0.12
        pupil_dx = yaw / 90.0 * max_displacement     # Positive yaw = looking right = pupil right
        pupil_dy = -pitch / 90.0 * max_displacement  # Positive pitch = looking up = pupil up

        iris_cy = center_y + int(pupil_dy)
        iris_cx = center_x + int(pupil_dx)

        # Create iris mask
        iris_mask = (x_grid - iris_cx) ** 2 + (y_grid - iris_cy) ** 2 <= iris_radius ** 2
        iris_mask = iris_mask & eye_mask  # Clip to eye boundary

        # Fill iris with color
        img[iris_mask] = iris_color

        # Pupil (black circle in center of iris)
        pupil_radius = iris_radius * 0.4
        pupil_mask = (x_grid - iris_cx) ** 2 + (y_grid - iris_cy) ** 2 <= pupil_radius ** 2
        img[pupil_mask] = np.array([0.05, 0.05, 0.05])

        # Specular highlight (reflection)
        highlight_radius = iris_radius * 0.15
        highlight_cy = iris_cy - int(iris_radius * 0.3)
        highlight_cx = iris_cx - int(iris_radius * 0.2)
        highlight_mask = (x_grid - highlight_cx) ** 2 + (y_grid - highlight_cy) ** 2 <= highlight_radius ** 2
        img[highlight_mask] = np.clip(img[highlight_mask] + 0.3, 0, 1.0)

        # Eyelids (top and bottom)
        eyelid_thickness = 0.15
        top_lid_mask = (y_grid - center_y) / (size * 0.25) < -0.7 + eyelid_thickness
        bottom_lid_mask = (y_grid - center_y) / (size * 0.25) > 0.7 - eyelid_thickness
        eyelid_color = skin_color * 0.85
        img[top_lid_mask & eye_mask] = eyelid_color
        img[bottom_lid_mask & eye_mask] = eyelid_color

        # Add noise
        noise = np.random.randn(size, size, 3) * self.noise_level
        img = np.clip(img + noise, 0, 1.0)

        # Convert to PIL
        img_uint8 = (img * 255).astype(np.uint8)
        return Image.fromarray(img_uint8)

    def _generate_face(self, pitch: float, yaw: float, skin_color: np.ndarray) -> Image.Image:
        """Generate a simple face-like pattern.

        The face contains both eyes positioned according to gaze direction,
        providing the geometric information that the teacher model uses
        (via blurred version) and the student must learn from directly.
        """
        size = self.img_size_face
        img = np.ones((size, size, 3), dtype=np.float32) * skin_color

        center_y, center_x = size // 2, size // 2

        # Simple oval face shape
        y_grid, x_grid = np.ogrid[:size, :size]
        face_mask = ((x_grid - center_x) ** 2 / (size * 0.38) ** 2 +
                     (y_grid - center_y) ** 2 / (size * 0.45) ** 2) <= 1.0

        # Background
        img[~face_mask] = np.array([0.3, 0.3, 0.35])

        # Eye positions on face (further apart, higher up)
        left_eye_cx = center_x - int(size * 0.12)
        right_eye_cx = center_x + int(size * 0.12)
        eye_cy = center_y - int(size * 0.08)

        # Gaze-displaced pupil positions on each eye
        displacement = size * 0.02
        pupil_dx = yaw / 90.0 * displacement
        pupil_dy = -pitch / 90.0 * displacement

        # Draw eyes on face
        eye_size = size * 0.06
        for eye_cx in [left_eye_cx, right_eye_cx]:
            # Eye white
            eye_white = (x_grid - eye_cx) ** 2 + (y_grid - eye_cy) ** 2 <= eye_size ** 2
            img[eye_white] = np.array([0.95, 0.95, 0.95])

            # Iris
            iris_radius = eye_size * 0.5
            iris_cy = eye_cy + int(pupil_dy)
            iris_cx = eye_cx + int(pupil_dx)
            iris = (x_grid - iris_cx) ** 2 + (y_grid - iris_cy) ** 2 <= iris_radius ** 2
            img[iris] = np.array([0.3, 0.5, 0.7])

            # Pupil
            pupil_r = iris_radius * 0.4
            pupil = (x_grid - iris_cx) ** 2 + (y_grid - iris_cy) ** 2 <= pupil_r ** 2
            img[pupil] = np.array([0.05, 0.05, 0.05])

        # Nose hint
        nose_cx, nose_cy = center_x, center_y + int(size * 0.1)
        nose = (x_grid - nose_cx) ** 2 + (y_grid - nose_cy) ** 2 <= (size * 0.03) ** 2
        img[nose] = skin_color * 0.85

        # Add noise
        noise = np.random.randn(size, size, 3) * self.noise_level
        img = np.clip(img + noise, 0, 1.0)

        img_uint8 = (img * 255).astype(np.uint8)
        return Image.fromarray(img_uint8)

    def __getitem__(self, idx):
        pitch = float(self.pitch_angles[idx])
        yaw = float(self.yaw_angles[idx])
        iris_color = self.iris_colors[idx]
        skin_color = self.skin_colors[idx]

        # Generate left and right eyes
        # (the right eye gets a slightly darker iris color for realism)
        left_eye = self._generate_eye(pitch, yaw, iris_color, eye_idx=0)
        right_eye = self._generate_eye(pitch, yaw, iris_color * 0.95, eye_idx=1)

        # Generate face
        face_rgb = self._generate_face(pitch, yaw, skin_color)

        # Create blurred grayscale face (teacher input - only geometric info)
        face_gray = ImageOps.grayscale(face_rgb)
        face_blurred = face_gray.filter(ImageFilter.GaussianBlur(radius=8.0))

        # Create light-corrected grayscale face (student input)
        # Simulate varied lighting by adjusting brightness/contrast
        enhancer = ImageEnhance.Brightness(face_gray)
        face_light_corrected = enhancer.enhance(0.8 + 0.4 * np.random.random())
        enhancer = ImageEnhance.Contrast(face_light_corrected)
        face_light_corrected = enhancer.enhance(0.9 + 0.2 * np.random.random())

        # Convert to tensors
        left_eye_tensor = torch.from_numpy(np.array(left_eye)).permute(2, 0, 1).float() / 255.0
        right_eye_tensor = torch.from_numpy(np.array(right_eye)).permute(2, 0, 1).float() / 255.0
        face_blurred_tensor = torch.from_numpy(np.array(face_blurred)).unsqueeze(0).float() / 255.0
        face_light_tensor = torch.from_numpy(np.array(face_light_corrected)).unsqueeze(0).float() / 255.0

        # Normalize to [-1, 1]
        left_eye_tensor = left_eye_tensor * 2 - 1
        right_eye_tensor = right_eye_tensor * 2 - 1
        face_blurred_tensor = face_blurred_tensor * 2 - 1
        face_light_tensor = face_light_tensor * 2 - 1

        return {
            'left_eye': left_eye_tensor,               # [3, 112, 112]
            'right_eye': right_eye_tensor,             # [3, 112, 112]
            'face_blurred_gray': face_blurred_tensor,  # [1, 224, 224]
            'face_gray': face_light_tensor,            # [1, 224, 224]
            'pitch': torch.tensor(pitch),
            'yaw': torch.tensor(yaw),
        }


class MPIIGazeDataset(Dataset):
    """Loader for the MPIIGaze/MPIIFaceGaze dataset.

    MPIIFaceGaze contains:
    - Face images normalized to 224x224
    - Left and right eye patches extracted from face images
    - 3D gaze direction vectors

    Dataset format: HDF5 files with keys:
    - 'image': face image [224, 224, 3]
    - 'left_eye': left eye patch [varies, varies, 3]
    - 'right_eye': right eye patch [varies, varies, 3]
    - 'gaze': gaze vector [3] (unit vector in camera coordinate system)
    - 'head_pose': head rotation vector [3]
    """

    def __init__(
        self,
        data_dir: str,
        split: str = 'train',
        img_size_eye: int = 112,
        img_size_face: int = 224,
        transform=None,
    ):
        self.data_dir = Path(data_dir)
        self.split = split
        self.img_size_eye = img_size_eye
        self.img_size_face = img_size_face
        self.transform = transform

        # Load data indices
        self.samples = self._load_samples()

    def _load_samples(self) -> List[Dict]:
        """Load sample metadata from the dataset."""
        samples = []
        # Implementation depends on actual dataset format
        # For MPIIGaze: scans .mat or .h5 files
        # This is a placeholder - fill in based on actual data
        data_path = self.data_dir / self.split
        if not data_path.exists():
            raise FileNotFoundError(f"Data directory not found: {data_path}")

        # TODO: Implement actual MPIIGaze loading
        # See: https://github.com/hysts/pytorch_mpiigaze for reference
        return samples

    def _gaze_to_angles(self, gaze_vector: np.ndarray) -> Tuple[float, float]:
        """Convert 3D gaze direction vector to pitch/yaw angles."""
        # Gaze vector is [x, y, z] in camera coordinates
        # Z points forward, X right, Y down
        x, y, z = gaze_vector

        # Yaw: rotation around Y axis (left-right)
        yaw = np.arctan2(x, z) * 180.0 / np.pi

        # Pitch: rotation around X axis (up-down)
        pitch = np.arctan2(-y, np.sqrt(x**2 + z**2)) * 180.0 / np.pi

        return float(pitch), float(yaw)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Placeholder - implement based on actual data format
        raise NotImplementedError(
            "MPIIGaze dataset loader requires the actual dataset files. "
            "Use SyntheticGazeDataset for development and testing."
        )


def create_dataloaders(
    num_train: int = 40000,
    num_val: int = 5000,
    num_test: int = 5000,
    batch_size: int = 64,
    num_workers: int = 4,
    seed: int = 42,
):
    """Create train/val/test dataloaders with synthetic data."""

    train_dataset = SyntheticGazeDataset(
        num_samples=num_train,
        seed=seed,
        noise_level=0.08,
    )

    val_dataset = SyntheticGazeDataset(
        num_samples=num_val,
        seed=seed + 1,
        noise_level=0.05,
    )

    test_dataset = SyntheticGazeDataset(
        num_samples=num_test,
        seed=seed + 2,
        noise_level=0.05,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader, test_loader
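

if __name__ == "__main__":
    # Minimal usage sketch: build small synthetic loaders and inspect one batch.
    # The tiny sample counts and num_workers=0 are arbitrary choices for a quick
    # CPU smoke test, not the defaults used by the training pipeline.
    train_loader, val_loader, test_loader = create_dataloaders(
        num_train=64,
        num_val=16,
        num_test=16,
        batch_size=8,
        num_workers=0,
    )
    batch = next(iter(train_loader))
    for key, value in batch.items():
        # Expected per the dataset docstring: left_eye/right_eye -> (8, 3, 112, 112),
        # face tensors -> (8, 1, 224, 224), pitch/yaw -> (8,)
        print(f"{key}: {tuple(value.shape)}")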