"""
PriviGaze Teacher Model - Siamese Multi-Input Gaze Estimation Network

Architecture:
- Takes 3 inputs: left eye RGB, right eye RGB, blurred grayscale face
- Uses ConvNeXtV2-Atto as shared backbone for eye streams
- Uses ConvNeXtV2-Nano for face stream
- Fuses multi-modal features via cross-attention
- Outputs: pitch and yaw gaze angles (degrees)

This teacher has access to privileged information (RGB eye crops, high-res face)
that the student does NOT have at inference time.
"""

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ConvNextV2Model


class ConvNextV2FeatureExtractor(nn.Module):
    """Wrapper around ConvNeXtV2 for feature extraction (no classification head).

    The backbone's ``pooler_output`` is projected to ``output_dim`` through
    LayerNorm -> Linear -> GELU.

    Args:
        model_name: Model id or path passed to ``ConvNextV2Model.from_pretrained``.
        output_dim: Size of the projected feature vector.
        gradient_checkpointing: Request gradient checkpointing on the backbone.
            It is enabled only when the loaded backbone reports
            ``supports_gradient_checkpointing``; otherwise a warning is emitted
            and the backbone runs without it.
    """

    def __init__(
        self,
        model_name: str,
        output_dim: int = 256,
        gradient_checkpointing: bool = True,
    ):
        super().__init__()
        self.backbone = ConvNextV2Model.from_pretrained(model_name)
        if gradient_checkpointing:
            # `gradient_checkpointing_enable` raises ValueError on architectures
            # that do not declare support for it; degrade gracefully instead of
            # failing model construction.
            if getattr(self.backbone, "supports_gradient_checkpointing", False):
                self.backbone.gradient_checkpointing_enable()
            else:
                warnings.warn(
                    f"{type(self.backbone).__name__} does not support gradient "
                    "checkpointing; continuing without it.",
                    stacklevel=2,
                )

        # Width of the last backbone stage, used as the pooled feature size.
        hidden_size = self.backbone.config.hidden_sizes[-1]
        self.projection = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, output_dim),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return projected features (last dimension ``output_dim``) for ``x``.

        ``x`` is passed straight to the backbone; presumably ``[B, 3, H, W]``
        pixel values — TODO confirm against the preprocessing pipeline.
        """
        outputs = self.backbone(x)
        pooled = outputs.pooler_output
        return self.projection(pooled)


class CrossAttentionFusion(nn.Module):
    """Cross-attention fusion module for multi-modal features.

    A single face query attends over a sequence of eye tokens. This is
    followed by a post-norm residual connection and a feed-forward network
    (4x expansion) with its own post-norm residual.

    Args:
        dim: Feature size of the query and key/value tokens. Must be divisible
            by ``num_heads`` (required by ``nn.MultiheadAttention``).
        num_heads: Number of attention heads.
    """

    def __init__(self, dim: int = 256, num_heads: int = 4):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, face_feat: torch.Tensor, eye_feats: torch.Tensor) -> torch.Tensor:
        """Fuse a face vector with a sequence of eye tokens.

        Args:
            face_feat: ``[B, dim]`` face feature, used as the single query.
            eye_feats: ``[B, N, dim]`` eye tokens, used as keys and values.

        Returns:
            ``[B, dim]`` fused feature.
        """
        face_seq = face_feat.unsqueeze(1)  # [B, 1, dim]: one query token per sample
        attn_out, _ = self.cross_attn(face_seq, eye_feats, eye_feats)
        out = self.norm1(face_seq + attn_out)
        out = self.norm2(out + self.ffn(out))
        return out.squeeze(1)


class PriviGazeTeacher(nn.Module):
    """Siamese teacher model with privileged multi-modal inputs.

    Inputs:
        - left_eye: [B, 3, 112, 112] RGB left eye crop
        - right_eye: [B, 3, 112, 112] RGB right eye crop
        - face_blurred_gray: [B, 1, 224, 224] Blurred grayscale face

    Outputs:
        - pitch_pred: [B] gaze pitch angle in degrees
        - yaw_pred: [B] gaze yaw angle in degrees
        - pitch_logits: [B, gaze_bins] for logit distillation
        - yaw_logits: [B, gaze_bins] for logit distillation
        - features: [B, feature_dim] fused feature representation for distillation

    Each angle is the softmax-weighted mean of ``gaze_bins`` uniformly spaced
    bin centres covering ``[-max_angle, max_angle]`` (inclusive).

    Args:
        eye_backbone: ConvNeXtV2 checkpoint shared by both eye streams.
        face_backbone: ConvNeXtV2 checkpoint for the face stream.
        feature_dim: Width of per-stream and fused features.
        gaze_bins: Number of angle bins per head; must be at least 2.
        num_heads: Attention heads in the face/eye cross-attention fusion.
        head_dropout: Dropout probability inside the pitch/yaw heads.
        max_angle: Half-range in degrees spanned by the bin centres.
        gradient_checkpointing: Request gradient checkpointing on both
            backbones (skipped with a warning if unsupported).

    Raises:
        ValueError: If ``gaze_bins < 2``.
    """

    def __init__(
        self,
        eye_backbone: str = "facebook/convnextv2-atto-1k-224",
        face_backbone: str = "facebook/convnextv2-nano-22k-384",
        feature_dim: int = 256,
        gaze_bins: int = 90,
        num_heads: int = 4,
        head_dropout: float = 0.1,
        max_angle: float = 90.0,
        gradient_checkpointing: bool = True,
    ):
        super().__init__()
        if gaze_bins < 2:
            # A single bin would make the soft-argmax a constant angle.
            raise ValueError(f"gaze_bins must be >= 2, got {gaze_bins}")

        # One extractor applied to both eyes (Siamese weight sharing).
        self.eye_extractor = ConvNextV2FeatureExtractor(
            eye_backbone, feature_dim, gradient_checkpointing=gradient_checkpointing
        )
        self.face_extractor = ConvNextV2FeatureExtractor(
            face_backbone, feature_dim, gradient_checkpointing=gradient_checkpointing
        )
        self.eye_fusion = nn.Sequential(
            nn.Linear(feature_dim * 2, feature_dim),
            nn.GELU(),
            nn.LayerNorm(feature_dim),
        )
        self.cross_fusion = CrossAttentionFusion(feature_dim, num_heads=num_heads)
        self.pitch_head = self._make_head(feature_dim, gaze_bins, head_dropout)
        self.yaw_head = self._make_head(feature_dim, gaze_bins, head_dropout)
        self.register_buffer(
            "bin_centers", torch.linspace(-max_angle, max_angle, gaze_bins)
        )
        self.feature_dim = feature_dim
        self.gaze_bins = gaze_bins
        self.max_angle = max_angle

    @staticmethod
    def _make_head(feature_dim: int, gaze_bins: int, dropout: float) -> nn.Sequential:
        """Build an MLP head mapping fused features to ``gaze_bins`` logits."""
        # Layer order is fixed so state_dict keys (e.g. ``pitch_head.3.weight``)
        # stay compatible with existing checkpoints.
        return nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(feature_dim // 2, gaze_bins),
        )

    def _adapt_face_input(self, x: torch.Tensor) -> torch.Tensor:
        """Tile a single-channel face image to 3 channels for the RGB backbone."""
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        return x

    def _fuse(
        self,
        left_eye: torch.Tensor,
        right_eye: torch.Tensor,
        face_blurred_gray: torch.Tensor,
    ) -> torch.Tensor:
        """Encode the three inputs and return the fused ``[B, feature_dim]`` feature."""
        left_feat = self.eye_extractor(left_eye)
        right_feat = self.eye_extractor(right_eye)
        face_feat = self.face_extractor(self._adapt_face_input(face_blurred_gray))

        # Concatenation path: joint MLP over both eye features.
        eye_fused = self.eye_fusion(torch.cat([left_feat, right_feat], dim=-1))
        # Attention path: the face feature queries the two eye tokens.
        eye_tokens = torch.stack([left_feat, right_feat], dim=1)  # [B, 2, feature_dim]
        fused = self.cross_fusion(face_feat, eye_tokens)
        # Residual sum of the two fusion paths.
        return fused + eye_fused

    def _expected_angle(self, logits: torch.Tensor) -> torch.Tensor:
        """Soft-argmax: probability-weighted mean of the bin centres (degrees)."""
        probs = F.softmax(logits, dim=-1)
        return (probs * self.bin_centers).sum(dim=-1)

    def forward(
        self,
        left_eye: torch.Tensor,
        right_eye: torch.Tensor,
        face_blurred_gray: torch.Tensor,
    ):
        """Run the teacher.

        Returns:
            Tuple ``(pitch_pred, yaw_pred, pitch_logits, yaw_logits, fused)``
            with shapes ``[B]``, ``[B]``, ``[B, gaze_bins]``,
            ``[B, gaze_bins]``, ``[B, feature_dim]``.
        """
        fused = self._fuse(left_eye, right_eye, face_blurred_gray)
        pitch_logits = self.pitch_head(fused)
        yaw_logits = self.yaw_head(fused)
        pitch_pred = self._expected_angle(pitch_logits)
        yaw_pred = self._expected_angle(yaw_logits)
        return pitch_pred, yaw_pred, pitch_logits, yaw_logits, fused

    def get_penultimate_features(
        self,
        left_eye: torch.Tensor,
        right_eye: torch.Tensor,
        face_blurred_gray: torch.Tensor,
    ) -> torch.Tensor:
        """Return the fused ``[B, feature_dim]`` feature used for distillation.

        Skips the pitch/yaw heads, which ``forward`` would compute and discard.
        """
        return self._fuse(left_eye, right_eye, face_blurred_gray)