"""
PriviGaze Teacher Model - Siamese Multi-Input Gaze Estimation Network
Architecture:
- Takes 3 inputs: left eye RGB, right eye RGB, blurred grayscale face
- Uses ConvNeXtV2-Atto as shared backbone for eye streams
- Uses ConvNeXtV2-Nano for face stream
- Fuses multi-modal features via cross-attention
- Outputs: pitch and yaw gaze angles (degrees)
This teacher has access to privileged information (RGB eye crops) that the
student does NOT have at inference time.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ConvNextV2Model


class ConvNextV2FeatureExtractor(nn.Module):
"""Wrapper around ConvNeXtV2 for feature extraction (no classification head)."""
def __init__(self, model_name: str, output_dim: int = 256):
super().__init__()
self.backbone = ConvNextV2Model.from_pretrained(model_name)
self.backbone.gradient_checkpointing_enable()
hidden_size = self.backbone.config.hidden_sizes[-1]
self.projection = nn.Sequential(
nn.LayerNorm(hidden_size),
nn.Linear(hidden_size, output_dim),
nn.GELU(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
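        # pooler_output is ConvNeXtV2's layer-normed global average pool of
        # the final-stage feature map.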
outputs = self.backbone(x)
pooled = outputs.pooler_output
return self.projection(pooled)
class CrossAttentionFusion(nn.Module):
"""Cross-attention fusion module for multi-modal features."""
def __init__(self, dim: int = 256, num_heads: int = 4):
super().__init__()
self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.ffn = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim),
)
def forward(self, face_feat: torch.Tensor, eye_feats: torch.Tensor) -> torch.Tensor:
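        # The fused face feature acts as a single-token query; the two eye
        # features serve as keys/values, so the face representation attends
        # over the eyes.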
face_seq = face_feat.unsqueeze(1)
attn_out, _ = self.cross_attn(face_seq, eye_feats, eye_feats)
out = self.norm1(face_seq + attn_out)
out = self.norm2(out + self.ffn(out))
return out.squeeze(1)
class PriviGazeTeacher(nn.Module):
"""Siamese teacher model with privileged multi-modal inputs.
Inputs:
- left_eye: [B, 3, 112, 112] RGB left eye crop
- right_eye: [B, 3, 112, 112] RGB right eye crop
- face_blurred_gray: [B, 1, 224, 224] Blurred grayscale face
Outputs:
- pitch_pred: [B] gaze pitch angle in degrees
- yaw_pred: [B] gaze yaw angle in degrees
- pitch_logits: [B, gaze_bins] for logit distillation
- yaw_logits: [B, gaze_bins] for logit distillation
- features: [B, 256] fused feature representation for distillation
"""
def __init__(
self,
eye_backbone: str = "facebook/convnextv2-atto-1k-224",
face_backbone: str = "facebook/convnextv2-nano-22k-384",
feature_dim: int = 256,
gaze_bins: int = 90,
):
super().__init__()
self.eye_extractor = ConvNextV2FeatureExtractor(eye_backbone, feature_dim)
self.face_extractor = ConvNextV2FeatureExtractor(face_backbone, feature_dim)
self.eye_fusion = nn.Sequential(
nn.Linear(feature_dim * 2, feature_dim),
nn.GELU(),
nn.LayerNorm(feature_dim),
)
self.cross_fusion = CrossAttentionFusion(feature_dim, num_heads=4)
self.pitch_head = nn.Sequential(
nn.Linear(feature_dim, feature_dim // 2),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(feature_dim // 2, gaze_bins),
)
self.yaw_head = nn.Sequential(
nn.Linear(feature_dim, feature_dim // 2),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(feature_dim // 2, gaze_bins),
)
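        # Bin centers (degrees) for decoding bin probabilities into a
        # continuous angle; gaze angles are assumed to lie in [-90, 90] deg.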
self.register_buffer('bin_centers', torch.linspace(-90.0, 90.0, gaze_bins))
self.feature_dim = feature_dim
self.gaze_bins = gaze_bins
def _adapt_face_input(self, x: torch.Tensor) -> torch.Tensor:
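        """Replicate a single-channel (grayscale) face to 3 channels so it
        matches the RGB stem of the ConvNeXtV2 backbone."""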
if x.shape[1] == 1:
x = x.repeat(1, 3, 1, 1)
return x
def forward(self, left_eye, right_eye, face_blurred_gray):
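        # Siamese eye streams: the same extractor (shared weights) encodes
        # both eye crops.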
left_feat = self.eye_extractor(left_eye)
right_feat = self.eye_extractor(right_eye)
face_input = self._adapt_face_input(face_blurred_gray)
face_feat = self.face_extractor(face_input)
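        # Concatenate the two eye embeddings and project back to feature_dim.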
eye_combined = torch.cat([left_feat, right_feat], dim=-1)
eye_fused = self.eye_fusion(eye_combined)
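        # Also keep the per-eye features as a length-2 sequence so the face
        # feature can cross-attend to each eye individually.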
eye_stacked = torch.stack([left_feat, right_feat], dim=1)
fused = self.cross_fusion(face_feat, eye_stacked)
fused = fused + eye_fused
pitch_logits = self.pitch_head(fused)
yaw_logits = self.yaw_head(fused)
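        # Soft-argmax decoding: the expected value of the bin centers under
        # the softmax distribution yields a continuous angle in degrees.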
pitch_probs = F.softmax(pitch_logits, dim=-1)
yaw_probs = F.softmax(yaw_logits, dim=-1)
pitch_pred = (pitch_probs * self.bin_centers).sum(dim=-1)
yaw_pred = (yaw_probs * self.bin_centers).sum(dim=-1)
return pitch_pred, yaw_pred, pitch_logits, yaw_logits, fused
def get_penultimate_features(self, left_eye, right_eye, face_blurred_gray):
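        """Convenience accessor for the fused [B, feature_dim] representation
        used as the feature-distillation target."""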
_, _, _, _, fused = self.forward(left_eye, right_eye, face_blurred_gray)
return fused
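

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the training pipeline: assumes
    # the pretrained checkpoints named above are available locally or can be
    # downloaded. Input shapes follow the class docstring; batch size 2 is
    # arbitrary.
    model = PriviGazeTeacher().eval()
    left_eye = torch.randn(2, 3, 112, 112)
    right_eye = torch.randn(2, 3, 112, 112)
    face_gray = torch.randn(2, 1, 224, 224)
    with torch.no_grad():
        pitch, yaw, pitch_logits, yaw_logits, feats = model(
            left_eye, right_eye, face_gray
        )
    print(pitch.shape, yaw.shape)           # torch.Size([2]) torch.Size([2])
    print(pitch_logits.shape, feats.shape)  # torch.Size([2, 90]) torch.Size([2, 256])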