"""
PriviGaze Student Model - Ultra-Compact Gaze Estimation CNN

Design Philosophy:
- ~50-100K parameters total (target: fit on microcontrollers with TinyML)
- Inception blocks with factorized convolutions (exploits eye biology)
- Single grayscale face crop as input (no eye crops); lighting is normalized
  by a learnable in-model correction
- Designed for on-device inference with <1ms latency

Architecture inspired by:
- "One Eye is All You Need" (Athavale et al., 2022) - Inception for gaze
- DFT Gaze from GazeGen (281K params, distillation from 10x larger teacher)
- Eye biology: horizontal/vertical edge detectors mimic ocular muscle structure

Key design choices for disability support:
- Grayscale only: robust to varied lighting/occlusion
- Large receptive field: handles droopy eyes, head roll
- Factorized convolutions: 1x3 + 3x1 instead of 3x3 for efficiency
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_bn_relu(in_ch, out_ch, kernel_size, stride=1, padding=0, groups=1):
    """Standard Conv-BN-ReLU block."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, groups=groups, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


def depthwise_separable_conv(in_ch, out_ch, kernel_size, stride=1, padding=0):
    """Depthwise separable convolution for extreme parameter efficiency."""
    return nn.Sequential(
        # Depthwise
        nn.Conv2d(in_ch, in_ch, kernel_size, stride, padding, groups=in_ch, bias=False),
        nn.BatchNorm2d(in_ch),
        nn.ReLU(inplace=True),
        # Pointwise
        nn.Conv2d(in_ch, out_ch, 1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )
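
# Worked example (illustrative): a standard 3x3 conv mapping 64 -> 128
# channels needs 64 * 128 * 3 * 3 = 73,728 weights, whereas the depthwise
# separable version needs 64 * 3 * 3 + 64 * 128 = 8,768 (BatchNorm aside),
# roughly an 8.4x reduction.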


class FactorizedConvBlock(nn.Module):
    """Factorized 3x3 into 1x3 + 3x1 for efficiency.
    
    This mimics the ocular muscle structure: horizontal rectus (1x3) 
    and vertical rectus (3x1) for detecting gaze direction.
    """
    
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        mid_ch = out_ch // 2
        
        self.horizontal = nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, (1, 3), stride, (0, 1), bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
        )
        
        self.vertical = nn.Sequential(
            nn.Conv2d(in_ch, out_ch - mid_ch, (3, 1), stride, (1, 0), bias=False),
            nn.BatchNorm2d(out_ch - mid_ch),
            nn.ReLU(inplace=True),
        )
    
    def forward(self, x):
        h = self.horizontal(x)
        v = self.vertical(x)
        return torch.cat([h, v], dim=1)
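
# Parameter sketch (illustrative): a standard 3x3 conv from in_ch to out_ch
# uses 9 * in_ch * out_ch weights; the two branches above use about
# 3 * in_ch * (out_ch // 2) each, ~3 * in_ch * out_ch in total, roughly a
# 3x saving at the same stride and output size.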


class InceptionBlock(nn.Module):
    """Lightweight inception block with factorized convolutions.
    
    Branches:
    1. 1x1 conv (pointwise - captures color/illumination)
    2. Factorized 1x3 + 3x1 (edge detection in H/V directions)
    3. MaxPool + 1x1 (spatial context)
    4. 2x 3x3 (standard feature extraction)
    
    All branches use depthwise separable convolutions for efficiency.
    """
    
    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Each branch outputs out_ch // 4 channels, floored at 8 for narrow blocks
        branch_ch = max(out_ch // 4, 8)
        
        # Branch 1: 1x1 pointwise
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_ch, branch_ch, 1, bias=False),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )
        
        # Branch 2: Factorized 1x3 + 3x1
        self.branch2 = FactorizedConvBlock(in_ch, branch_ch)
        
        # Branch 3: MaxPool + 1x1
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_ch, branch_ch, 1, bias=False),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )
        
        # Branch 4: Two stacked 3x3 depthwise separable
        self.branch4 = nn.Sequential(
            depthwise_separable_conv(in_ch, branch_ch, 3, padding=1),
            depthwise_separable_conv(branch_ch, branch_ch, 3, padding=1),
        )
        
        # Final fusion
        total_ch = branch_ch * 4
        self.fusion = nn.Sequential(
            nn.Conv2d(total_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    
    def forward(self, x):
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        b4 = self.branch4(x)
        out = torch.cat([b1, b2, b3, b4], dim=1)
        return self.fusion(out)
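
# Quick shape check (illustrative):
#   >>> blk = InceptionBlock(32, 64)
#   >>> blk(torch.randn(1, 32, 56, 56)).shape
#   torch.Size([1, 64, 56, 56])
# Spatial size is preserved; only the channel count changes.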


class LightCorrection(nn.Module):
    """Learnable light correction for grayscale face input.
    
    Applies per-channel (single channel for grayscale) affine transform
    to normalize lighting variations. This is critical for disability support
    where users may be in varied lighting conditions.
    """
    
    def __init__(self):
        super().__init__()
        # Learnable gamma correction parameter
        self.gamma = nn.Parameter(torch.ones(1))
        self.alpha = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1))
    
    def forward(self, x):
        # x: [B, 1, H, W], intensities expected to be roughly in [0, 1]
        # Gamma correction; clamping keeps the base positive for pow
        x = torch.pow(x.clamp(min=1e-6), self.gamma)
        # Learnable affine transform (alpha ~ contrast, beta ~ brightness)
        x = self.alpha * x + self.beta
        return x
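
# Example (illustrative): with gamma = 2.0, alpha = 1.0, beta = 0.0, a
# mid-gray intensity of 0.5 maps to 0.5 ** 2 = 0.25, darkening midtones;
# gamma < 1 brightens them. All three parameters are learned end to end.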


class PriviGazeStudent(nn.Module):
    """Ultra-compact student model for on-device gaze estimation.
    
    Input:
        - face_gray: [B, 1, 224, 224] grayscale face crop (light correction
          is applied internally by the LightCorrection module)
    
    Output:
        - pitch_pred: [B] gaze pitch in degrees
        - yaw_pred: [B] gaze yaw in degrees
        - features: [B, 128] feature representation (for distillation matching)
    
    Target: ~80K parameters, <1ms inference on mobile
    """
    
    def __init__(
        self,
        input_channels: int = 1,  # Grayscale
        feature_dim: int = 128,
        gaze_bins: int = 90,
    ):
        super().__init__()
        
        # Light correction
        self.light_correction = LightCorrection()
        
        # Stem: initial feature extraction
        self.stem = nn.Sequential(
            conv_bn_relu(input_channels, 32, 3, stride=2, padding=1),
            # 32 x 112 x 112
            depthwise_separable_conv(32, 32, 3, stride=2, padding=1),
            # 32 x 56 x 56
        )
        
        # Stage 1: Inception blocks at high resolution
        self.stage1 = nn.Sequential(
            InceptionBlock(32, 64),
            # 64 x 56 x 56
            depthwise_separable_conv(64, 64, 3, stride=2, padding=1),
            # 64 x 28 x 28
        )
        
        # Stage 2: Deeper features
        self.stage2 = nn.Sequential(
            InceptionBlock(64, 96),
            # 96 x 28 x 28
            depthwise_separable_conv(96, 96, 3, stride=2, padding=1),
            # 96 x 14 x 14
        )
        
        # Stage 3: Abstract features
        self.stage3 = nn.Sequential(
            InceptionBlock(96, 128),
            # 128 x 14 x 14
            depthwise_separable_conv(128, 128, 3, stride=2, padding=1),
            # 128 x 7 x 7
        )
        
        # Stage 4: Global context
        self.stage4 = nn.Sequential(
            InceptionBlock(128, 160),
            # 160 x 7 x 7
            nn.AdaptiveAvgPool2d(1),
            # 160 x 1 x 1
        )
        
        # Feature projection
        self.feature_projection = nn.Sequential(
            nn.Flatten(),
            nn.Linear(160, feature_dim),
            nn.GELU(),
            nn.LayerNorm(feature_dim),
        )
        
        # Gaze regression heads (L2CS-Net style: per-angle binned regression)
        self.pitch_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Linear(feature_dim // 2, gaze_bins),
        )
        
        self.yaw_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Linear(feature_dim // 2, gaze_bins),
        )
        
        # Bin centers for expectation-based regression
        self.register_buffer(
            'bin_centers',
            torch.linspace(-90.0, 90.0, gaze_bins)
        )
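        # Decoding (see forward): pred = sum_i p_i * c_i, the softmax-weighted
        # mean of bin centers. With 90 bins spanning [-90, 90] degrees,
        # adjacent centers are 180 / 89 ~ 2.02 degrees apart.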
        
        self.feature_dim = feature_dim
        self.gaze_bins = gaze_bins
    
    def forward(self, face_gray):
        """
        Args:
            face_gray: [B, 1, 224, 224] grayscale face image
        
        Returns:
            pitch_pred: [B]
            yaw_pred: [B]
            features: [B, feature_dim]
        """
        # Light correction
        x = self.light_correction(face_gray)
        
        # Stem
        x = self.stem(x)
        
        # Stages
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        
        # Feature projection
        features = self.feature_projection(x)
        
        # Gaze prediction
        pitch_logits = self.pitch_head(features)
        yaw_logits = self.yaw_head(features)
        
        pitch_probs = F.softmax(pitch_logits, dim=-1)
        yaw_probs = F.softmax(yaw_logits, dim=-1)
        
        pitch_pred = (pitch_probs * self.bin_centers).sum(dim=-1)
        yaw_pred = (yaw_probs * self.bin_centers).sum(dim=-1)
        
        return pitch_pred, yaw_pred, features
    
    def get_penultimate_features(self, face_gray):
        """Return features before regression heads for distillation."""
        _, _, features = self.forward(face_gray)
        return features
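

# A minimal sketch of how the returned features might be used for
# teacher-student feature distillation (illustrative only; the actual
# training objective is defined elsewhere). Assumes the teacher also
# yields (pitch, yaw, features) with a matching feature_dim.
def example_distillation_loss(student_out, teacher_out, feat_weight=0.5):
    s_pitch, s_yaw, s_feat = student_out
    t_pitch, t_yaw, t_feat = teacher_out
    # Match predicted angles directly (L1) and penultimate features (MSE)
    gaze_loss = F.l1_loss(s_pitch, t_pitch) + F.l1_loss(s_yaw, t_yaw)
    feat_loss = F.mse_loss(s_feat, t_feat)
    return gaze_loss + feat_weight * feat_loss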


def count_parameters(model):
    """Count trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == "__main__":
    model = PriviGazeStudent()
    params = count_parameters(model)
    print(f"Student model parameters: {params:,}")
    
    # Test forward pass (rand, not randn: intensities in [0, 1] as expected)
    dummy = torch.rand(1, 1, 224, 224)
    pitch, yaw, features = model(dummy)
    print(f"Pitch: {pitch.shape}, Yaw: {yaw.shape}, Features: {features.shape}")