# privi-gaze-distill/models/teacher.py
"""
PriviGaze Teacher Model - Siamese Multi-Input Gaze Estimation Network
Architecture:
- Takes 3 inputs: left eye RGB, right eye RGB, blurred grayscale face
- Uses ConvNeXtV2-Atto as shared backbone for eye streams
- Uses ConvNeXtV2-Nano for face stream
- Fuses multi-modal features via cross-attention
- Outputs: pitch and yaw gaze angles (degrees)
This teacher has access to privileged information (RGB eye crops, high-res face)
that the student does NOT have at inference time.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ConvNextV2Model


class ConvNextV2FeatureExtractor(nn.Module):
    """Wrapper around ConvNeXtV2 for feature extraction (no classification head)."""

    def __init__(self, model_name: str, output_dim: int = 256):
        super().__init__()
        self.backbone = ConvNextV2Model.from_pretrained(model_name)
        self.backbone.gradient_checkpointing_enable()
        hidden_size = self.backbone.config.hidden_sizes[-1]
        self.projection = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, output_dim),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs = self.backbone(x)
        pooled = outputs.pooler_output
        return self.projection(pooled)
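
# Note: per the Hugging Face implementation, ConvNextV2Model.pooler_output is the
# globally average-pooled final feature map with a LayerNorm applied, so `pooled`
# above is a [B, hidden_size] vector before projection to [B, output_dim].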


class CrossAttentionFusion(nn.Module):
    """Cross-attention fusion module for multi-modal features."""

    def __init__(self, dim: int = 256, num_heads: int = 4):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, face_feat: torch.Tensor, eye_feats: torch.Tensor) -> torch.Tensor:
        # The face feature is the query; the eye features are the keys/values.
        face_seq = face_feat.unsqueeze(1)  # [B, D] -> [B, 1, D]
        attn_out, _ = self.cross_attn(face_seq, eye_feats, eye_feats)
        out = self.norm1(face_seq + attn_out)  # residual + post-norm
        out = self.norm2(out + self.ffn(out))  # feed-forward with residual
        return out.squeeze(1)  # [B, 1, D] -> [B, D]
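
# Design note: a single query token (the face feature) attends over a two-token
# key/value sequence (the per-eye features), followed by post-norm residual and
# feed-forward sublayers -- effectively one Transformer decoder-style layer. With
# only two key/value tokens, the attention cost is negligible next to the backbones.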


class PriviGazeTeacher(nn.Module):
    """Siamese teacher model with privileged multi-modal inputs.

    Inputs:
        - left_eye:          [B, 3, 112, 112] RGB left-eye crop
        - right_eye:         [B, 3, 112, 112] RGB right-eye crop
        - face_blurred_gray: [B, 1, 224, 224] blurred grayscale face

    Outputs:
        - pitch_pred:   [B] gaze pitch angle in degrees
        - yaw_pred:     [B] gaze yaw angle in degrees
        - pitch_logits: [B, gaze_bins] for logit distillation
        - yaw_logits:   [B, gaze_bins] for logit distillation
        - features:     [B, 256] fused feature representation for distillation
    """

    def __init__(
        self,
        eye_backbone: str = "facebook/convnextv2-atto-1k-224",
        face_backbone: str = "facebook/convnextv2-nano-22k-384",
        feature_dim: int = 256,
        gaze_bins: int = 90,
    ):
        super().__init__()
        # Shared (Siamese) extractor for both eyes; separate extractor for the face.
        self.eye_extractor = ConvNextV2FeatureExtractor(eye_backbone, feature_dim)
        self.face_extractor = ConvNextV2FeatureExtractor(face_backbone, feature_dim)
        self.eye_fusion = nn.Sequential(
            nn.Linear(feature_dim * 2, feature_dim),
            nn.GELU(),
            nn.LayerNorm(feature_dim),
        )
        self.cross_fusion = CrossAttentionFusion(feature_dim, num_heads=4)
        self.pitch_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(feature_dim // 2, gaze_bins),
        )
        self.yaw_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(feature_dim // 2, gaze_bins),
        )
        # Bin centers for soft-argmax decoding of the classification logits.
        self.register_buffer('bin_centers', torch.linspace(-90.0, 90.0, gaze_bins))
        self.feature_dim = feature_dim
        self.gaze_bins = gaze_bins

    def _adapt_face_input(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        return x
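
    # Note on _adapt_face_input: both ConvNeXtV2 backbones are pretrained on
    # 3-channel RGB images, so the single blurred-grayscale channel is replicated
    # across three channels -- a common way to feed grayscale input to an
    # RGB-pretrained backbone without modifying its stem convolution.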

    def forward(self, left_eye, right_eye, face_blurred_gray):
        # Per-stream feature extraction (the eye extractor weights are shared).
        left_feat = self.eye_extractor(left_eye)
        right_feat = self.eye_extractor(right_eye)
        face_input = self._adapt_face_input(face_blurred_gray)
        face_feat = self.face_extractor(face_input)

        # Fuse eyes by concatenation + MLP, and cross-attend face over eye tokens.
        eye_combined = torch.cat([left_feat, right_feat], dim=-1)  # [B, 2*D]
        eye_fused = self.eye_fusion(eye_combined)                  # [B, D]
        eye_stacked = torch.stack([left_feat, right_feat], dim=1)  # [B, 2, D]
        fused = self.cross_fusion(face_feat, eye_stacked)          # [B, D]
        fused = fused + eye_fused

        # Decode angles via soft-argmax: the expected bin center under the
        # softmax distribution, i.e. sum_k p_k * c_k.
        pitch_logits = self.pitch_head(fused)
        yaw_logits = self.yaw_head(fused)
        pitch_probs = F.softmax(pitch_logits, dim=-1)
        yaw_probs = F.softmax(yaw_logits, dim=-1)
        pitch_pred = (pitch_probs * self.bin_centers).sum(dim=-1)
        yaw_pred = (yaw_probs * self.bin_centers).sum(dim=-1)
        return pitch_pred, yaw_pred, pitch_logits, yaw_logits, fused

    def get_penultimate_features(self, left_eye, right_eye, face_blurred_gray):
        """Return only the fused [B, feature_dim] representation (for feature distillation)."""
        _, _, _, _, fused = self.forward(left_eye, right_eye, face_blurred_gray)
        return fused
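

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline): builds the
    # teacher with its default backbones and runs a forward pass on random tensors
    # shaped like the documented inputs. Downloading the pretrained ConvNeXtV2
    # weights requires network access on first run.
    model = PriviGazeTeacher()
    model.eval()

    batch_size = 2
    left_eye = torch.randn(batch_size, 3, 112, 112)
    right_eye = torch.randn(batch_size, 3, 112, 112)
    face_blurred_gray = torch.randn(batch_size, 1, 224, 224)

    with torch.no_grad():
        pitch, yaw, pitch_logits, yaw_logits, feats = model(
            left_eye, right_eye, face_blurred_gray
        )

    print(f"pitch:        {tuple(pitch.shape)}")         # (2,)
    print(f"yaw:          {tuple(yaw.shape)}")           # (2,)
    print(f"pitch_logits: {tuple(pitch_logits.shape)}")  # (2, 90)
    print(f"yaw_logits:   {tuple(yaw_logits.shape)}")    # (2, 90)
    print(f"features:     {tuple(feats.shape)}")         # (2, 256)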