| """ |
| PriviGaze Teacher Model - Siamese Multi-Input Gaze Estimation Network |
| |
| Architecture: |
| - Takes 3 inputs: left eye RGB, right eye RGB, blurred grayscale face |
| - Uses ConvNeXtV2-Atto as shared backbone for eye streams |
| - Uses ConvNeXtV2-Nano for face stream |
| - Fuses multi-modal features via cross-attention |
| - Outputs: pitch and yaw gaze angles (degrees) |
| |
| This teacher has access to privileged information (RGB eye crops, high-res face) |
| that the student does NOT have at inference time. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import ConvNextV2Model |
|
|
|
|
class ConvNextV2FeatureExtractor(nn.Module):
    """Produce a fixed-size embedding from an image batch via ConvNeXtV2.

    Wraps a pretrained ``ConvNextV2Model`` (encoder only, no classifier
    head) and projects its pooled output down to ``output_dim`` features.
    """

    def __init__(self, model_name: str, output_dim: int = 256):
        """
        Args:
            model_name: Hugging Face hub identifier of the ConvNeXtV2 checkpoint.
            output_dim: Dimensionality of the projected feature vector.
        """
        super().__init__()
        self.backbone = ConvNextV2Model.from_pretrained(model_name)
        # Trade recompute for activation memory during training.
        self.backbone.gradient_checkpointing_enable()

        last_stage_width = self.backbone.config.hidden_sizes[-1]
        self.projection = nn.Sequential(
            nn.LayerNorm(last_stage_width),
            nn.Linear(last_stage_width, output_dim),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map images ``[B, 3, H, W]`` to features ``[B, output_dim]``."""
        pooled = self.backbone(x).pooler_output
        return self.projection(pooled)
|
|
|
|
| class CrossAttentionFusion(nn.Module): |
| """Cross-attention fusion module for multi-modal features.""" |
| |
| def __init__(self, dim: int = 256, num_heads: int = 4): |
| super().__init__() |
| self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True) |
| self.norm1 = nn.LayerNorm(dim) |
| self.norm2 = nn.LayerNorm(dim) |
| self.ffn = nn.Sequential( |
| nn.Linear(dim, dim * 4), |
| nn.GELU(), |
| nn.Linear(dim * 4, dim), |
| ) |
| |
| def forward(self, face_feat: torch.Tensor, eye_feats: torch.Tensor) -> torch.Tensor: |
| face_seq = face_feat.unsqueeze(1) |
| attn_out, _ = self.cross_attn(face_seq, eye_feats, eye_feats) |
| out = self.norm1(face_seq + attn_out) |
| out = self.norm2(out + self.ffn(out)) |
| return out.squeeze(1) |
|
|
|
|
class PriviGazeTeacher(nn.Module):
    """Siamese teacher network consuming privileged multi-modal inputs.

    Both eye crops pass through one shared (siamese) backbone; the blurred
    face passes through a second backbone. Eye and face features are fused
    by cross-attention plus a concatenation path, and gaze is predicted as
    the expected value over a discretized angle distribution (soft-argmax).

    Inputs:
        - left_eye: [B, 3, 112, 112] RGB left eye crop
        - right_eye: [B, 3, 112, 112] RGB right eye crop
        - face_blurred_gray: [B, 1, 224, 224] Blurred grayscale face

    Outputs:
        - pitch_pred: [B] gaze pitch angle in degrees
        - yaw_pred: [B] gaze yaw angle in degrees
        - pitch_logits: [B, gaze_bins] for logit distillation
        - yaw_logits: [B, gaze_bins] for logit distillation
        - features: [B, feature_dim] fused representation for distillation
    """

    def __init__(
        self,
        eye_backbone: str = "facebook/convnextv2-atto-1k-224",
        face_backbone: str = "facebook/convnextv2-nano-22k-384",
        feature_dim: int = 256,
        gaze_bins: int = 90,
    ):
        """
        Args:
            eye_backbone: Hub checkpoint for the shared eye-stream backbone.
            face_backbone: Hub checkpoint for the face-stream backbone.
            feature_dim: Width of all intermediate feature vectors.
            gaze_bins: Number of discrete angle bins per gaze axis.
        """
        super().__init__()

        # One extractor shared by both eyes (siamese), a separate one for the face.
        self.eye_extractor = ConvNextV2FeatureExtractor(eye_backbone, feature_dim)
        self.face_extractor = ConvNextV2FeatureExtractor(face_backbone, feature_dim)

        # Binocular fusion: concatenated eye features -> feature_dim.
        self.eye_fusion = nn.Sequential(
            nn.Linear(feature_dim * 2, feature_dim),
            nn.GELU(),
            nn.LayerNorm(feature_dim),
        )

        self.cross_fusion = CrossAttentionFusion(feature_dim, num_heads=4)

        def make_head() -> nn.Sequential:
            # Classification head over the discretized gaze angles.
            return nn.Sequential(
                nn.Linear(feature_dim, feature_dim // 2),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(feature_dim // 2, gaze_bins),
            )

        self.pitch_head = make_head()
        self.yaw_head = make_head()

        # Bin centers cover [-90, 90] degrees; a buffer so it follows .to(device)
        # and state_dict without being trained.
        self.register_buffer('bin_centers', torch.linspace(-90.0, 90.0, gaze_bins))
        self.feature_dim = feature_dim
        self.gaze_bins = gaze_bins

    def _adapt_face_input(self, x: torch.Tensor) -> torch.Tensor:
        """Tile a single-channel face image to 3 channels for the RGB backbone."""
        return x.repeat(1, 3, 1, 1) if x.shape[1] == 1 else x

    def forward(self, left_eye, right_eye, face_blurred_gray):
        """Run the full teacher and return predictions, logits, and features."""
        # Siamese eye streams: the same backbone processes both crops.
        feats_left = self.eye_extractor(left_eye)
        feats_right = self.eye_extractor(right_eye)

        feats_face = self.face_extractor(self._adapt_face_input(face_blurred_gray))

        # Concatenation path over both eyes (added residually below).
        binocular = self.eye_fusion(torch.cat((feats_left, feats_right), dim=-1))

        # Cross-attention path: face query attends over the two eye tokens.
        eye_tokens = torch.stack((feats_left, feats_right), dim=1)
        fused = self.cross_fusion(feats_face, eye_tokens) + binocular

        pitch_logits = self.pitch_head(fused)
        yaw_logits = self.yaw_head(fused)

        # Soft-argmax: expected angle under the softmax distribution over bins.
        pitch_pred = (F.softmax(pitch_logits, dim=-1) * self.bin_centers).sum(dim=-1)
        yaw_pred = (F.softmax(yaw_logits, dim=-1) * self.bin_centers).sum(dim=-1)

        return pitch_pred, yaw_pred, pitch_logits, yaw_logits, fused

    def get_penultimate_features(self, left_eye, right_eye, face_blurred_gray):
        """Return only the fused [B, feature_dim] representation (for distillation)."""
        *_, fused = self.forward(left_eye, right_eye, face_blurred_gray)
        return fused
|
|