"""PriviGaze Student Model - Ultra-Compact Gaze Estimation CNN.

Design philosophy:
- Parameter budget: the design target is ~50-100K parameters (small enough
  for TinyML microcontrollers). NOTE: the default ``PriviGazeStudent()``
  configuration has ~211K trainable parameters (check with
  ``count_parameters``); narrower stages or a smaller ``feature_dim`` are
  needed to reach the target.
- Inception blocks with parallel horizontal (1x3) and vertical (3x1)
  convolution branches, motivated by eye biology.
- Takes only a single-channel grayscale face image as input (no eye crops);
  a learnable light-correction layer is applied inside the model.
- Intended for on-device inference (design goal: <1ms latency; not measured
  in this file).

Architecture inspired by:
- "One Eye is All You Need" (Athavale et al., 2022) - Inception for gaze
- DFT Gaze from GazeGen (281K params, distillation from 10x larger teacher)
- Eye biology: horizontal/vertical edge detectors mimic ocular muscle structure

Key design choices for disability support:
- Grayscale only: intended to be robust to varied lighting/occlusion
- Large receptive field: intended to handle droopy eyes, head roll
- Cheap directional convolutions: 1x3 and 3x1 kernels instead of full 3x3
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_bn_relu(in_ch, out_ch, kernel_size, stride=1, padding=0, groups=1):
    """Return a Conv2d (no bias) -> BatchNorm2d -> ReLU block.

    Not used by the model classes in this module; kept as a public helper.
    """
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, groups=groups, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


def depthwise_separable_conv(in_ch, out_ch, kernel_size, stride=1, padding=0):
    """Return a depthwise separable convolution block.

    A per-channel (``groups=in_ch``) ``kernel_size`` conv followed by a 1x1
    pointwise conv, each followed by BatchNorm2d + ReLU. ``stride`` and
    ``padding`` apply to the depthwise conv only.
    """
    return nn.Sequential(
        # Depthwise: spatial filtering, one filter per input channel.
        nn.Conv2d(in_ch, in_ch, kernel_size, stride, padding, groups=in_ch, bias=False),
        nn.BatchNorm2d(in_ch),
        nn.ReLU(inplace=True),
        # Pointwise: channel mixing.
        nn.Conv2d(in_ch, out_ch, 1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


class FactorizedConvBlock(nn.Module):
    """Parallel horizontal (1x3) and vertical (3x1) convolutions, concatenated.

    The two branches are applied to the SAME input in parallel (not stacked),
    so each output channel sees only a 1x3 or a 3x1 neighbourhood, never a
    full 3x3 one. The design is motivated by the ocular muscles: horizontal
    rectus (1x3) and vertical rectus (3x1).

    Output channels: the first ``out_ch // 2`` come from the horizontal branch
    and the remaining ``out_ch - out_ch // 2`` from the vertical branch.
    Padding keeps the spatial size unchanged when ``stride=1``.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        mid_ch = out_ch // 2
        self.horizontal = nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, (1, 3), stride, (0, 1), bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
        )
        self.vertical = nn.Sequential(
            nn.Conv2d(in_ch, out_ch - mid_ch, (3, 1), stride, (1, 0), bias=False),
            nn.BatchNorm2d(out_ch - mid_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        h = self.horizontal(x)
        v = self.vertical(x)
        return torch.cat([h, v], dim=1)


class InceptionBlock(nn.Module):
    """Lightweight inception block.

    Four parallel branches, each producing ``branch_ch = max(out_ch // 4, 8)``
    channels, are concatenated and fused back to ``out_ch`` channels by a
    1x1 Conv-BN-ReLU:

    1. 1x1 conv (pointwise channel mixing)
    2. FactorizedConvBlock: parallel 1x3 / 3x1 convs (horizontal and
       vertical edge responses)
    3. 3x3 max-pool (stride 1) + 1x1 conv (local spatial context)
    4. Two stacked 3x3 depthwise separable convs (5x5 effective receptive
       field)

    Only branch 4 uses depthwise separable convolutions. Every branch keeps
    the spatial size, so the output is ``[B, out_ch, H, W]``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Floor of 8 channels per branch so narrow blocks keep some capacity;
        # the fusion conv maps 4 * branch_ch back to out_ch either way.
        branch_ch = max(out_ch // 4, 8)

        # Branch 1: 1x1 pointwise
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_ch, branch_ch, 1, bias=False),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        # Branch 2: parallel 1x3 / 3x1
        self.branch2 = FactorizedConvBlock(in_ch, branch_ch)

        # Branch 3: MaxPool + 1x1
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_ch, branch_ch, 1, bias=False),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        # Branch 4: Two stacked 3x3 depthwise separable
        self.branch4 = nn.Sequential(
            depthwise_separable_conv(in_ch, branch_ch, 3, padding=1),
            depthwise_separable_conv(branch_ch, branch_ch, 3, padding=1),
        )

        # Final fusion
        total_ch = branch_ch * 4
        self.fusion = nn.Sequential(
            nn.Conv2d(total_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        b4 = self.branch4(x)
        out = torch.cat([b1, b2, b3, b4], dim=1)
        return self.fusion(out)


class LightCorrection(nn.Module):
    """Learnable global gamma + affine light correction for grayscale input.

    Computes ``alpha * clamp(x, min=1e-6) ** gamma + beta``, where ``gamma``,
    ``alpha`` and ``beta`` are single learnable scalars shared by every
    channel and pixel. They are initialised to the identity transform
    (gamma=1, alpha=1, beta=0). The intent is to normalise lighting
    variations, which matters for users in varied lighting conditions.

    Inputs are expected to be non-negative intensities (e.g. in [0, 1]).
    Values below 1e-6 are clamped, so negative (e.g. zero-mean normalised)
    inputs lose their information.

    NOTE(review): gamma is unconstrained. If training drives it to <= 0,
    clamped pixels map to very large values; consider a positive
    parameterisation.
    """

    def __init__(self):
        super().__init__()
        # Learnable gamma correction parameter
        self.gamma = nn.Parameter(torch.ones(1))
        self.alpha = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # Gamma correction; the clamp keeps pow() defined for non-integer gamma.
        x = torch.pow(x.clamp(min=1e-6), self.gamma)
        # Global affine transform.
        x = self.alpha * x + self.beta
        return x


class PriviGazeStudent(nn.Module):
    """Compact student model for on-device gaze estimation.

    Input:
        face_gray: [B, input_channels, H, W] grayscale face, nominally
            [B, 1, 224, 224]. Other spatial sizes also work because stage4
            ends in AdaptiveAvgPool2d(1). Values are expected to be
            non-negative (e.g. in [0, 1]); see LightCorrection.

    Output:
        pitch_pred: [B] gaze pitch in degrees, in [-angle_range, angle_range]
        yaw_pred: [B] gaze yaw in degrees, in [-angle_range, angle_range]
        features: [B, feature_dim] feature representation (for distillation
            matching)

    Angles are regressed L2CS-Net style: each head outputs ``gaze_bins``
    logits. The prediction is the softmax-weighted mean of ``gaze_bins``
    evenly spaced bin centres spanning [-angle_range, angle_range].

    With the default arguments the model has ~211K trainable parameters
    (design target: ~80K; latency target: <1ms on mobile, not measured here).
    """

    def __init__(
        self,
        input_channels: int = 1,  # Grayscale
        feature_dim: int = 128,
        gaze_bins: int = 90,
        angle_range: float = 90.0,
    ):
        """Build the network.

        Args:
            input_channels: channels of the input image (1 for grayscale).
            feature_dim: width of the projected feature vector.
            gaze_bins: number of angle bins per regression head.
            angle_range: half-width in degrees of the bin-centre span;
                predictions lie in [-angle_range, angle_range].

        Raises:
            ValueError: if ``angle_range`` is not positive.
        """
        super().__init__()
        if angle_range <= 0:
            raise ValueError(f"angle_range must be positive, got {angle_range}")

        # Light correction
        self.light_correction = LightCorrection()

        # Stem: initial feature extraction
        self.stem = nn.Sequential(
            nn.Conv2d(input_channels, 32, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            # 32 x 112 x 112
            depthwise_separable_conv(32, 32, 3, stride=2, padding=1),
            # 32 x 56 x 56
        )

        # Stage 1: Inception blocks at high resolution
        self.stage1 = nn.Sequential(
            InceptionBlock(32, 64),  # 64 x 56 x 56
            depthwise_separable_conv(64, 64, 3, stride=2, padding=1),  # 64 x 28 x 28
        )

        # Stage 2: Deeper features
        self.stage2 = nn.Sequential(
            InceptionBlock(64, 96),  # 96 x 28 x 28
            depthwise_separable_conv(96, 96, 3, stride=2, padding=1),  # 96 x 14 x 14
        )

        # Stage 3: Abstract features
        self.stage3 = nn.Sequential(
            InceptionBlock(96, 128),  # 128 x 14 x 14
            depthwise_separable_conv(128, 128, 3, stride=2, padding=1),  # 128 x 7 x 7
        )

        # Stage 4: Global context
        self.stage4 = nn.Sequential(
            InceptionBlock(128, 160),  # 160 x 7 x 7
            nn.AdaptiveAvgPool2d(1),  # 160 x 1 x 1
        )

        # Feature projection
        self.feature_projection = nn.Sequential(
            nn.Flatten(),
            nn.Linear(160, feature_dim),
            nn.GELU(),
            nn.LayerNorm(feature_dim),
        )

        # Gaze regression heads (L2CS-Net style: per-angle binned regression)
        self.pitch_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Linear(feature_dim // 2, gaze_bins),
        )
        self.yaw_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Linear(feature_dim // 2, gaze_bins),
        )

        # Bin centres (degrees) for expectation-based regression.
        angle_range = float(angle_range)
        self.register_buffer(
            'bin_centers',
            torch.linspace(-angle_range, angle_range, gaze_bins),
        )

        self.feature_dim = feature_dim
        self.gaze_bins = gaze_bins
        self.angle_range = angle_range

    def _extract_features(self, face_gray):
        """Run light correction, backbone and projection; return [B, feature_dim]."""
        x = self.light_correction(face_gray)
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        return self.feature_projection(x)

    def forward(self, face_gray):
        """Predict gaze angles.

        Args:
            face_gray: [B, 1, 224, 224] grayscale face image

        Returns:
            pitch_pred: [B] degrees
            yaw_pred: [B] degrees
            features: [B, feature_dim]
        """
        features = self._extract_features(face_gray)

        # Softmax over bins, then expected angle = sum(p_i * centre_i).
        pitch_probs = F.softmax(self.pitch_head(features), dim=-1)
        yaw_probs = F.softmax(self.yaw_head(features), dim=-1)
        pitch_pred = (pitch_probs * self.bin_centers).sum(dim=-1)
        yaw_pred = (yaw_probs * self.bin_centers).sum(dim=-1)

        return pitch_pred, yaw_pred, features

    def get_penultimate_features(self, face_gray):
        """Return features before the regression heads, for distillation.

        Skips the pitch/yaw heads entirely instead of computing and
        discarding them. Result is identical to ``forward(face_gray)[2]``.
        """
        return self._extract_features(face_gray)


def count_parameters(model):
    """Count trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == "__main__":
    model = PriviGazeStudent()
    params = count_parameters(model)
    print(f"Student model parameters: {params:,}")

    # Smoke-test a forward pass. eval() + no_grad() so the demo does not
    # update BatchNorm running statistics or build an autograd graph;
    # torch.rand gives non-negative intensities like a real [0, 1] grayscale
    # image (LightCorrection clamps negative inputs away).
    model.eval()
    dummy = torch.rand(1, 1, 224, 224)
    with torch.no_grad():
        pitch, yaw, features = model(dummy)
    print(f"Pitch: {pitch.shape}, Yaw: {yaw.shape}, Features: {features.shape}")