BcantCode committed on
Commit 0607636 · verified · 1 Parent(s): e97351b

Upload models/student.py

Files changed (1)
  1. models/student.py +313 -0
models/student.py ADDED
"""
PriviGaze Student Model - Ultra-Compact Gaze Estimation CNN

Design Philosophy:
- ~50-100K parameters total (target: fit on microcontrollers with TinyML)
- Inception blocks with factorized convolutions (exploits eye biology)
- Takes only a light-corrected grayscale face as input (no eye crops)
- Designed for on-device inference with <1 ms latency

Architecture inspired by:
- "One Eye is All You Need" (Athavale et al., 2022) - Inception for gaze
- DFT Gaze from GazeGen (281K params, distilled from a 10x larger teacher)
- Eye biology: horizontal/vertical edge detectors mimic ocular muscle structure

Key design choices for disability support:
- Grayscale only: robust to varied lighting/occlusion
- Large receptive field: handles droopy eyes, head roll
- Factorized convolutions: 1x3 + 3x1 instead of 3x3 for efficiency
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_bn_relu(in_ch, out_ch, kernel_size, stride=1, padding=0, groups=1):
    """Standard Conv-BN-ReLU block."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, groups=groups, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

def depthwise_separable_conv(in_ch, out_ch, kernel_size, stride=1, padding=0):
    """Depthwise separable convolution for extreme parameter efficiency."""
    return nn.Sequential(
        # Depthwise
        nn.Conv2d(in_ch, in_ch, kernel_size, stride, padding, groups=in_ch, bias=False),
        nn.BatchNorm2d(in_ch),
        nn.ReLU(inplace=True),
        # Pointwise
        nn.Conv2d(in_ch, out_ch, 1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )
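
# Rough cost check for depthwise_separable_conv (illustrative arithmetic, not
# part of the model): a standard 3x3 conv mapping 64 -> 64 channels needs
# 64 * 64 * 9 = 36,864 weights, while the depthwise (64 * 9 = 576) plus
# pointwise (64 * 64 = 4,096) pair needs 4,672, roughly 7.9x fewer, ignoring
# BatchNorm parameters. A quick empirical check:
#
#   block = depthwise_separable_conv(64, 64, 3, padding=1)
#   conv_weights = sum(p.numel() for p in block.parameters() if p.dim() == 4)
#   assert conv_weights == 64 * 9 + 64 * 64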


class FactorizedConvBlock(nn.Module):
    """Factorized 3x3 into 1x3 + 3x1 for efficiency.

    This mimics the ocular muscle structure: horizontal rectus (1x3)
    and vertical rectus (3x1) for detecting gaze direction.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        mid_ch = out_ch // 2

        self.horizontal = nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, (1, 3), stride, (0, 1), bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
        )

        self.vertical = nn.Sequential(
            nn.Conv2d(in_ch, out_ch - mid_ch, (3, 1), stride, (1, 0), bias=False),
            nn.BatchNorm2d(out_ch - mid_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        h = self.horizontal(x)
        v = self.vertical(x)
        return torch.cat([h, v], dim=1)
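
# Cost sketch for the factorization above (illustrative, not part of the
# model): with in_ch = out_ch = C, a full 3x3 conv needs 9 * C^2 weights,
# while the parallel 1x3 branch (C -> C/2, 3 * C * C/2 weights) plus the
# 3x1 branch (C -> C/2, 3 * C * C/2 weights) totals 3 * C^2, one third of
# the full conv for the same nominal extent along each axis.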


class InceptionBlock(nn.Module):
    """Lightweight inception block with factorized convolutions.

    Branches:
    1. 1x1 conv (pointwise channel mixing; cheap illumination cues)
    2. Factorized 1x3 + 3x1 (edge detection in H/V directions)
    3. MaxPool + 1x1 (spatial context)
    4. Two stacked 3x3 depthwise separable convs (standard feature extraction)

    Branch 4 uses depthwise separable convolutions; the other branches are
    already cheap (1x1, factorized, or pooling), keeping the whole block
    parameter-efficient.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Each branch outputs branch_ch = max(out_ch // 4, 8) channels
        branch_ch = max(out_ch // 4, 8)

        # Branch 1: 1x1 pointwise
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_ch, branch_ch, 1, bias=False),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        # Branch 2: Factorized 1x3 + 3x1
        self.branch2 = FactorizedConvBlock(in_ch, branch_ch)

        # Branch 3: MaxPool + 1x1
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_ch, branch_ch, 1, bias=False),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        # Branch 4: Two stacked 3x3 depthwise separable
        self.branch4 = nn.Sequential(
            depthwise_separable_conv(in_ch, branch_ch, 3, padding=1),
            depthwise_separable_conv(branch_ch, branch_ch, 3, padding=1),
        )

        # Final fusion
        total_ch = branch_ch * 4
        self.fusion = nn.Sequential(
            nn.Conv2d(total_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        b4 = self.branch4(x)
        out = torch.cat([b1, b2, b3, b4], dim=1)
        return self.fusion(out)
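
# Minimal shape check for the block above (stage-1 sizes from the student
# below; illustrative, not part of the model):
#
#   block = InceptionBlock(32, 64)
#   y = block(torch.randn(1, 32, 56, 56))
#   assert y.shape == (1, 64, 56, 56)  # every branch preserves H x W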


class LightCorrection(nn.Module):
    """Learnable light correction for grayscale face input.

    Applies a learnable gamma correction followed by an affine transform
    (a single channel for grayscale) to normalize lighting variations.
    This is critical for disability support, where users may be in varied
    lighting conditions.
    """

    def __init__(self):
        super().__init__()
        # Learnable gamma correction parameter plus affine gain/offset
        self.gamma = nn.Parameter(torch.ones(1))
        self.alpha = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # x: [B, 1, H, W]
        # Gamma correction; the clamp keeps pow well-behaved at zero
        x = torch.pow(x.clamp(min=1e-6), self.gamma)
        # Affine transform
        x = self.alpha * x + self.beta
        return x
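
# Behavior sketch (with the default init gamma = alpha = 1, beta = 0, the
# module is the identity on inputs already above 1e-6):
#
#   lc = LightCorrection()
#   x = torch.rand(2, 1, 224, 224)
#   assert torch.allclose(lc(x), x.clamp(min=1e-6))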


class PriviGazeStudent(nn.Module):
    """Ultra-compact student model for on-device gaze estimation.

    Input:
        - face_gray: [B, 1, 224, 224] light-corrected grayscale face

    Output:
        - pitch_pred: [B] gaze pitch in degrees
        - yaw_pred: [B] gaze yaw in degrees
        - features: [B, 128] feature representation (for distillation matching)

    Target: ~80K parameters, <1 ms inference on mobile
    """

    def __init__(
        self,
        input_channels: int = 1,  # Grayscale
        feature_dim: int = 128,
        gaze_bins: int = 90,
    ):
        super().__init__()

        # Light correction
        self.light_correction = LightCorrection()

        # Stem: initial feature extraction
        self.stem = nn.Sequential(
            nn.Conv2d(input_channels, 32, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            # 32 x 112 x 112
            depthwise_separable_conv(32, 32, 3, stride=2, padding=1),
            # 32 x 56 x 56
        )

        # Stage 1: Inception blocks at high resolution
        self.stage1 = nn.Sequential(
            InceptionBlock(32, 64),
            # 64 x 56 x 56
            depthwise_separable_conv(64, 64, 3, stride=2, padding=1),
            # 64 x 28 x 28
        )

        # Stage 2: Deeper features
        self.stage2 = nn.Sequential(
            InceptionBlock(64, 96),
            # 96 x 28 x 28
            depthwise_separable_conv(96, 96, 3, stride=2, padding=1),
            # 96 x 14 x 14
        )

        # Stage 3: Abstract features
        self.stage3 = nn.Sequential(
            InceptionBlock(96, 128),
            # 128 x 14 x 14
            depthwise_separable_conv(128, 128, 3, stride=2, padding=1),
            # 128 x 7 x 7
        )

        # Stage 4: Global context
        self.stage4 = nn.Sequential(
            InceptionBlock(128, 160),
            # 160 x 7 x 7
            nn.AdaptiveAvgPool2d(1),
            # 160 x 1 x 1
        )

        # Feature projection
        self.feature_projection = nn.Sequential(
            nn.Flatten(),
            nn.Linear(160, feature_dim),
            nn.GELU(),
            nn.LayerNorm(feature_dim),
        )

        # Gaze regression heads (L2CS-Net style: per-angle binned regression)
        self.pitch_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Linear(feature_dim // 2, gaze_bins),
        )

        self.yaw_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Linear(feature_dim // 2, gaze_bins),
        )

        # Bin centers for expectation-based regression
        self.register_buffer(
            'bin_centers',
            torch.linspace(-90.0, 90.0, gaze_bins),
        )
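
        # Note on resolution (derived from the values above): 90 bins spanning
        # [-90, 90] degrees place adjacent centers 180 / 89, about 2.02 degrees,
        # apart; the softmax expectation in forward() interpolates between
        # centers, so predictions are continuous rather than bin-quantized.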

        self.feature_dim = feature_dim
        self.gaze_bins = gaze_bins

    def forward(self, face_gray):
        """
        Args:
            face_gray: [B, 1, 224, 224] grayscale face image

        Returns:
            pitch_pred: [B]
            yaw_pred: [B]
            features: [B, feature_dim]
        """
        # Light correction
        x = self.light_correction(face_gray)

        # Stem
        x = self.stem(x)

        # Stages
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        # Feature projection
        features = self.feature_projection(x)

        # Gaze prediction: per-angle logits over bins
        pitch_logits = self.pitch_head(features)
        yaw_logits = self.yaw_head(features)

        pitch_probs = F.softmax(pitch_logits, dim=-1)
        yaw_probs = F.softmax(yaw_logits, dim=-1)

        # Expectation over bin centers yields a continuous angle in degrees
        pitch_pred = (pitch_probs * self.bin_centers).sum(dim=-1)
        yaw_pred = (yaw_probs * self.bin_centers).sum(dim=-1)

        return pitch_pred, yaw_pred, features

    def get_penultimate_features(self, face_gray):
        """Return features before regression heads for distillation."""
        _, _, features = self.forward(face_gray)
        return features
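
# Distillation usage sketch (the teacher and its features are hypothetical;
# only get_penultimate_features comes from this file):
#
#   student = PriviGazeStudent(feature_dim=128)
#   s_feat = student.get_penultimate_features(face_gray)  # [B, 128]
#   t_feat = teacher_features                             # [B, 128], from teacher
#   distill_loss = F.mse_loss(s_feat, t_feat)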


def count_parameters(model):
    """Count trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == "__main__":
    model = PriviGazeStudent()
    params = count_parameters(model)
    print(f"Student model parameters: {params:,}")

    # Test forward pass
    dummy = torch.randn(1, 1, 224, 224)
    pitch, yaw, features = model(dummy)
    print(f"Pitch: {pitch.shape}, Yaw: {yaw.shape}, Features: {features.shape}")