| """ |
| PriviGaze Student Model - Ultra-Compact Gaze Estimation CNN |
| |
| Design Philosophy: |
| - ~50-100K parameters total (target: fit on microcontrollers with TinyML) |
| - Inception blocks with factorized convolutions (exploits eye biology) |
- Only takes a grayscale face as input (light correction is learned in-model; no eye crops)
| - Designed for on-device inference with <1ms latency |
| |
| Architecture inspired by: |
| - "One Eye is All You Need" (Athavale et al., 2022) - Inception for gaze |
| - DFT Gaze from GazeGen (281K params, distillation from 10x larger teacher) |
| - Eye biology: horizontal/vertical edge detectors mimic ocular muscle structure |
| |
| Key design choices for disability support: |
| - Grayscale only: robust to varied lighting/occlusion |
| - Large receptive field: handles droopy eyes, head roll |
| - Factorized convolutions: 1x3 + 3x1 instead of 3x3 for efficiency |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
def conv_bn_relu(in_ch, out_ch, kernel_size, stride=1, padding=0, groups=1):
    """Build a Conv2d -> BatchNorm2d -> ReLU stack.

    The convolution is bias-free: the following BatchNorm layer supplies
    its own learnable shift, which makes a conv bias redundant.
    """
    layers = [
        nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False,
        ),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
|
|
|
|
def depthwise_separable_conv(in_ch, out_ch, kernel_size, stride=1, padding=0):
    """Depthwise conv followed by a 1x1 pointwise conv.

    The depthwise stage (groups == in_ch) filters each channel spatially
    on its own; the pointwise stage then mixes channels. Together they
    approximate a dense convolution at a fraction of the parameter cost.
    Both convolutions are bias-free since each is followed by BatchNorm.
    """
    depthwise = nn.Conv2d(
        in_ch, in_ch, kernel_size, stride, padding, groups=in_ch, bias=False
    )
    pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)
    return nn.Sequential(
        depthwise,
        nn.BatchNorm2d(in_ch),
        nn.ReLU(inplace=True),
        pointwise,
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )
|
|
|
|
class FactorizedConvBlock(nn.Module):
    """Parallel 1x3 and 3x1 convolutions replacing a single 3x3.

    The horizontal (1x3) branch produces ``out_ch // 2`` channels and the
    vertical (3x1) branch produces the remainder, so concatenating them
    along the channel axis always yields exactly ``out_ch`` channels —
    including when ``out_ch`` is odd.

    This mimics the ocular muscle structure: horizontal rectus (1x3)
    and vertical rectus (3x1) for detecting gaze direction.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        h_ch = out_ch // 2
        v_ch = out_ch - h_ch  # absorbs the extra channel when out_ch is odd

        self.horizontal = nn.Sequential(
            nn.Conv2d(in_ch, h_ch, (1, 3), stride, (0, 1), bias=False),
            nn.BatchNorm2d(h_ch),
            nn.ReLU(inplace=True),
        )
        self.vertical = nn.Sequential(
            nn.Conv2d(in_ch, v_ch, (3, 1), stride, (1, 0), bias=False),
            nn.BatchNorm2d(v_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Both branches see the same input; outputs are channel-concatenated.
        return torch.cat([self.horizontal(x), self.vertical(x)], dim=1)
|
|
|
|
class InceptionBlock(nn.Module):
    """Four-branch lightweight inception block fused back to ``out_ch``.

    Each branch emits the same channel count (``max(out_ch // 4, 8)``):
      1. 1x1 pointwise conv (cheap channel mixing)
      2. factorized 1x3 + 3x1 pair (H/V edge detection)
      3. 3x3 max-pool followed by a 1x1 conv (spatial context)
      4. two stacked depthwise-separable 3x3 convs (feature extraction)
    The concatenated branch outputs are mixed by a final 1x1 conv-BN-ReLU
    that projects back to ``out_ch`` channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Per-branch channel budget, floored at 8 so tiny blocks stay useful.
        width = max(out_ch // 4, 8)

        self.branch1 = nn.Sequential(
            nn.Conv2d(in_ch, width, 1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
        )

        self.branch2 = FactorizedConvBlock(in_ch, width)

        self.branch3 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_ch, width, 1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
        )

        self.branch4 = nn.Sequential(
            depthwise_separable_conv(in_ch, width, 3, padding=1),
            depthwise_separable_conv(width, width, 3, padding=1),
        )

        self.fusion = nn.Sequential(
            nn.Conv2d(4 * width, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        stacked = torch.cat([branch(x) for branch in branches], dim=1)
        return self.fusion(stacked)
|
|
|
|
class LightCorrection(nn.Module):
    """Learnable gamma + affine lighting correction.

    Computes ``alpha * clamp(x, 1e-6) ** gamma + beta`` with three scalar
    learnable parameters, all initialised to the identity transform
    (gamma=1, alpha=1, beta=0). The clamp keeps the power well-defined
    for zero or negative pixel values. This normalizes lighting variation,
    which is critical for disability support where users may be in varied
    lighting conditions.
    """

    def __init__(self):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1))   # gamma-curve exponent
        self.alpha = nn.Parameter(torch.ones(1))   # gain / contrast
        self.beta = nn.Parameter(torch.zeros(1))   # offset / brightness

    def forward(self, x):
        corrected = x.clamp(min=1e-6) ** self.gamma
        return self.alpha * corrected + self.beta
|
|
|
|
class PriviGazeStudent(nn.Module):
    """Ultra-compact student model for on-device gaze estimation.

    Pipeline: learnable light correction -> conv stem (4x downsample) ->
    four inception stages with progressive downsampling -> global average
    pool -> 128-d feature projection -> two soft-binned regression heads.

    Input:
        face_gray: [B, 1, 224, 224] grayscale face image

    Output (from ``forward``):
        pitch_pred: [B] gaze pitch in degrees
        yaw_pred: [B] gaze yaw in degrees
        features: [B, feature_dim] representation for distillation matching

    Target: ~80K parameters, <1ms inference on mobile.
    """

    def __init__(
        self,
        input_channels: int = 1,
        feature_dim: int = 128,
        gaze_bins: int = 90,
    ):
        super().__init__()

        # Learnable lighting normalization applied before any convolution.
        self.light_correction = LightCorrection()

        # Stem: two stride-2 stages -> 4x spatial reduction at 32 channels.
        self.stem = nn.Sequential(
            nn.Conv2d(input_channels, 32, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            depthwise_separable_conv(32, 32, 3, stride=2, padding=1),
        )

        # Stages 1-3: inception block then stride-2 downsample;
        # stage 4: inception block then global average pool.
        self.stage1 = nn.Sequential(
            InceptionBlock(32, 64),
            depthwise_separable_conv(64, 64, 3, stride=2, padding=1),
        )
        self.stage2 = nn.Sequential(
            InceptionBlock(64, 96),
            depthwise_separable_conv(96, 96, 3, stride=2, padding=1),
        )
        self.stage3 = nn.Sequential(
            InceptionBlock(96, 128),
            depthwise_separable_conv(128, 128, 3, stride=2, padding=1),
        )
        self.stage4 = nn.Sequential(
            InceptionBlock(128, 160),
            nn.AdaptiveAvgPool2d(1),
        )

        # Project pooled backbone output (160-d) to the distillation space.
        self.feature_projection = nn.Sequential(
            nn.Flatten(),
            nn.Linear(160, feature_dim),
            nn.GELU(),
            nn.LayerNorm(feature_dim),
        )

        # Each head predicts a distribution over angular bins.
        self.pitch_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Linear(feature_dim // 2, gaze_bins),
        )
        self.yaw_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Linear(feature_dim // 2, gaze_bins),
        )

        # Bin centers in degrees; a buffer so it moves with .to(device)
        # but is not a learnable parameter.
        self.register_buffer(
            'bin_centers',
            torch.linspace(-90.0, 90.0, gaze_bins)
        )

        self.feature_dim = feature_dim
        self.gaze_bins = gaze_bins

    def forward(self, face_gray):
        """Run the full model.

        Args:
            face_gray: [B, 1, 224, 224] grayscale face image

        Returns:
            Tuple of (pitch_pred [B], yaw_pred [B], features [B, feature_dim]).
        """
        x = self.light_correction(face_gray)
        x = self.stem(x)
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            x = stage(x)

        features = self.feature_projection(x)

        # Soft-argmax decoding: softmax over bins gives a distribution,
        # its expectation over bin centers gives a continuous angle.
        pitch_probs = F.softmax(self.pitch_head(features), dim=-1)
        yaw_probs = F.softmax(self.yaw_head(features), dim=-1)
        pitch_pred = (pitch_probs * self.bin_centers).sum(dim=-1)
        yaw_pred = (yaw_probs * self.bin_centers).sum(dim=-1)

        return pitch_pred, yaw_pred, features

    def get_penultimate_features(self, face_gray):
        """Return features before regression heads for distillation."""
        _, _, features = self.forward(face_gray)
        return features
|
|
|
|
def count_parameters(model):
    """Return the total element count of all trainable parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
|
|
|
if __name__ == "__main__":
    # Quick sanity check: report parameter count and smoke-test one
    # forward pass with a dummy grayscale face.
    model = PriviGazeStudent()
    params = count_parameters(model)
    print(f"Student model parameters: {params:,}")

    pitch, yaw, features = model(torch.randn(1, 1, 224, 224))
    print(f"Pitch: {pitch.shape}, Yaw: {yaw.shape}, Features: {features.shape}")
|
|