# privi-gaze-distill / models / student.py
"""
PriviGaze Student Model - Ultra-Compact Gaze Estimation CNN
Design Philosophy:
- ~50-100K parameters total (target: fit on microcontrollers with TinyML)
- Inception blocks with factorized convolutions (exploits eye biology)
- Only takes light-corrected grayscale face as input (no eye crops)
- Designed for on-device inference with <1ms latency
Architecture inspired by:
- "One Eye is All You Need" (Athavale et al., 2022) - Inception for gaze
- DFT Gaze from GazeGen (281K params, distillation from 10x larger teacher)
- Eye biology: horizontal/vertical edge detectors mimic ocular muscle structure
Key design choices for disability support:
- Grayscale only: robust to varied lighting/occlusion
- Large receptive field: handles droopy eyes, head roll
- Factorized convolutions: 1x3 + 3x1 instead of 3x3 for efficiency
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def conv_bn_relu(in_ch, out_ch, kernel_size, stride=1, padding=0, groups=1):
    """Conv2d -> BatchNorm2d -> ReLU block (conv bias disabled, folded into BN)."""
    layers = [
        nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, groups=groups, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
def depthwise_separable_conv(in_ch, out_ch, kernel_size, stride=1, padding=0):
    """Depthwise conv (groups=in_ch) followed by a 1x1 pointwise projection.

    Splits a dense KxK convolution into a per-channel spatial filter plus a
    channel-mixing 1x1 conv; each stage gets its own BN + ReLU.
    """
    depthwise = nn.Conv2d(in_ch, in_ch, kernel_size, stride, padding, groups=in_ch, bias=False)
    pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)
    return nn.Sequential(
        depthwise,
        nn.BatchNorm2d(in_ch),
        nn.ReLU(inplace=True),
        pointwise,
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )
class FactorizedConvBlock(nn.Module):
    """3x3 convolution factorized into parallel 1x3 and 3x1 branches.

    The split echoes the ocular muscle layout: a horizontal-rectus-like
    1x3 filter and a vertical-rectus-like 3x1 filter. The branch outputs
    are concatenated on the channel axis, so the block emits exactly
    ``out_ch`` channels.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        h_ch = out_ch // 2
        v_ch = out_ch - h_ch  # vertical branch takes the remainder (handles odd out_ch)
        self.horizontal = nn.Sequential(
            nn.Conv2d(in_ch, h_ch, (1, 3), stride, (0, 1), bias=False),
            nn.BatchNorm2d(h_ch),
            nn.ReLU(inplace=True),
        )
        self.vertical = nn.Sequential(
            nn.Conv2d(in_ch, v_ch, (3, 1), stride, (1, 0), bias=False),
            nn.BatchNorm2d(v_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Horizontal branch first, matching the channel layout of checkpoints.
        return torch.cat([self.horizontal(x), self.vertical(x)], dim=1)
class InceptionBlock(nn.Module):
    """Lightweight inception block with four parallel branches.

    Branches (each emitting ``max(out_ch // 4, 8)`` channels):
      1. 1x1 pointwise conv (channel mixing / illumination)
      2. Factorized 1x3 + 3x1 convs (edge detection in H/V directions)
      3. 3x3 max-pool followed by a 1x1 conv (spatial context)
      4. Two stacked 3x3 depthwise separable convs (standard features)

    The concatenated branch outputs are fused back to ``out_ch`` channels
    by a final 1x1 conv. Note that only branch 4 uses depthwise separable
    convolutions; the other branches are already parameter-cheap.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Floor each branch at 8 channels so narrow stages keep some capacity.
        branch_ch = max(out_ch // 4, 8)
        # Branch 1: 1x1 pointwise.
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_ch, branch_ch, 1, bias=False),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )
        # Branch 2: factorized 1x3 + 3x1.
        self.branch2 = FactorizedConvBlock(in_ch, branch_ch)
        # Branch 3: max-pool then 1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_ch, branch_ch, 1, bias=False),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )
        # Branch 4: two stacked 3x3 depthwise separable convs.
        self.branch4 = nn.Sequential(
            depthwise_separable_conv(in_ch, branch_ch, 3, padding=1),
            depthwise_separable_conv(branch_ch, branch_ch, 3, padding=1),
        )
        # Fuse the concatenated branches (branch_ch * 4) back to out_ch.
        self.fusion = nn.Sequential(
            nn.Conv2d(branch_ch * 4, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        stacked = torch.cat([branch(x) for branch in branches], dim=1)
        return self.fusion(stacked)
class LightCorrection(nn.Module):
    """Learnable lighting normalization for a single-channel face image.

    Computes ``alpha * x^gamma + beta`` with all three scalars learned.
    Initialized to the identity mapping (gamma=1, alpha=1, beta=0), so
    training starts from a no-op and adapts to the user's lighting.
    """

    def __init__(self):
        super().__init__()
        # Scalar gamma-curve exponent plus affine scale/shift.
        self.gamma = nn.Parameter(torch.ones(1))
        self.alpha = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # x: [B, 1, H, W]. Clamp keeps pow() off non-positive inputs,
        # where the gradient w.r.t. gamma is undefined.
        corrected = x.clamp(min=1e-6).pow(self.gamma)
        return self.alpha * corrected + self.beta
class PriviGazeStudent(nn.Module):
    """Ultra-compact student model for on-device gaze estimation.
    Input:
        - face_gray: [B, 1, 224, 224] Light-corrected grayscale face
    Output:
        - pitch_pred: [B] gaze pitch in degrees
        - yaw_pred: [B] gaze yaw in degrees
        - features: [B, 128] feature representation (for distillation matching)
    Target: ~80K parameters, <1ms inference on mobile
    """
    def __init__(
        self,
        input_channels: int = 1,  # Grayscale
        feature_dim: int = 128,
        gaze_bins: int = 90,
    ):
        super().__init__()
        # Light correction (learnable gamma + affine; initialized to identity)
        self.light_correction = LightCorrection()
        # Stem: initial feature extraction; two stride-2 convs reduce
        # 224x224 input to 56x56.
        self.stem = nn.Sequential(
            nn.Conv2d(input_channels, 32, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            # 32 x 112 x 112
            depthwise_separable_conv(32, 32, 3, stride=2, padding=1),
            # 32 x 56 x 56
        )
        # Stage 1: Inception blocks at high resolution
        self.stage1 = nn.Sequential(
            InceptionBlock(32, 64),
            # 64 x 56 x 56
            depthwise_separable_conv(64, 64, 3, stride=2, padding=1),
            # 64 x 28 x 28
        )
        # Stage 2: Deeper features
        self.stage2 = nn.Sequential(
            InceptionBlock(64, 96),
            # 96 x 28 x 28
            depthwise_separable_conv(96, 96, 3, stride=2, padding=1),
            # 96 x 14 x 14
        )
        # Stage 3: Abstract features
        self.stage3 = nn.Sequential(
            InceptionBlock(96, 128),
            # 128 x 14 x 14
            depthwise_separable_conv(128, 128, 3, stride=2, padding=1),
            # 128 x 7 x 7
        )
        # Stage 4: Global context (global average pool collapses spatial dims)
        self.stage4 = nn.Sequential(
            InceptionBlock(128, 160),
            # 160 x 7 x 7
            nn.AdaptiveAvgPool2d(1),
            # 160 x 1 x 1
        )
        # Feature projection: pooled 160-d vector -> feature_dim embedding.
        # This embedding is what get_penultimate_features() exposes for
        # distillation matching against the teacher.
        self.feature_projection = nn.Sequential(
            nn.Flatten(),
            nn.Linear(160, feature_dim),
            nn.GELU(),
            nn.LayerNorm(feature_dim),
        )
        # Gaze regression heads (L2CS-Net style: per-angle binned regression)
        self.pitch_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Linear(feature_dim // 2, gaze_bins),
        )
        self.yaw_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Linear(feature_dim // 2, gaze_bins),
        )
        # Bin centers for expectation-based regression: gaze_bins evenly
        # spaced values over [-90, +90] degrees (~2.02 deg apart at the
        # default 90 bins). Registered as a buffer so it follows .to(device)
        # and is saved in state_dict, but is not a trainable parameter.
        self.register_buffer(
            'bin_centers',
            torch.linspace(-90.0, 90.0, gaze_bins)
        )
        self.feature_dim = feature_dim
        self.gaze_bins = gaze_bins
    def forward(self, face_gray):
        """
        Args:
            face_gray: [B, 1, 224, 224] grayscale face image
        Returns:
            pitch_pred: [B] expected pitch in degrees (softmax-weighted bin centers)
            yaw_pred: [B] expected yaw in degrees
            features: [B, feature_dim] penultimate embedding (for distillation)
        """
        # Light correction
        x = self.light_correction(face_gray)
        # Stem
        x = self.stem(x)
        # Stages
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        # Feature projection
        features = self.feature_projection(x)
        # Gaze prediction: classify into angle bins, then take the softmax
        # expectation over bin centers to recover a continuous angle.
        pitch_logits = self.pitch_head(features)
        yaw_logits = self.yaw_head(features)
        pitch_probs = F.softmax(pitch_logits, dim=-1)
        yaw_probs = F.softmax(yaw_logits, dim=-1)
        pitch_pred = (pitch_probs * self.bin_centers).sum(dim=-1)
        yaw_pred = (yaw_probs * self.bin_centers).sum(dim=-1)
        return pitch_pred, yaw_pred, features
    def get_penultimate_features(self, face_gray):
        """Return features before regression heads for distillation.

        NOTE(review): runs the full forward pass (both gaze heads included)
        and discards the angle predictions.
        """
        _, _, features = self.forward(face_gray)
        return features
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
if __name__ == "__main__":
    # Quick smoke test: report parameter count and run one forward pass.
    net = PriviGazeStudent()
    n_params = count_parameters(net)
    print(f"Student model parameters: {n_params:,}")
    # Single grayscale face at the expected 224x224 resolution.
    sample = torch.randn(1, 1, 224, 224)
    p, y, feats = net(sample)
    print(f"Pitch: {p.shape}, Yaw: {y.shape}, Features: {feats.shape}")