"""
PriviGaze Teacher Model - Siamese Multi-Input Gaze Estimation Network

Architecture:
- Takes 3 inputs: left eye RGB, right eye RGB, blurred grayscale face
- Uses ConvNeXtV2-Atto as shared backbone for eye streams
- Uses ConvNeXtV2-Nano for face stream
- Fuses multi-modal features via cross-attention
- Outputs: pitch and yaw gaze angles (degrees)

This teacher has access to privileged information (RGB eye crops, high-res face)
that the student does NOT have at inference time.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ConvNextV2Model


class ConvNextV2FeatureExtractor(nn.Module):
    """Wrapper around ConvNeXtV2 for feature extraction (no classification head)."""

    def __init__(self, model_name: str, output_dim: int = 256):
        super().__init__()
        self.backbone = ConvNextV2Model.from_pretrained(model_name)
        # Trade compute for memory: recompute activations during the backward pass.
        self.backbone.gradient_checkpointing_enable()

        hidden_size = self.backbone.config.hidden_sizes[-1]
        # Project the backbone's pooled features to a common dimension so the
        # eye and face streams can be fused downstream.
        self.projection = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, output_dim),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs = self.backbone(x)
        pooled = outputs.pooler_output  # [B, hidden_size] globally pooled features
        return self.projection(pooled)  # [B, output_dim]


class CrossAttentionFusion(nn.Module):
    """Cross-attention fusion module for multi-modal features."""

    def __init__(self, dim: int = 256, num_heads: int = 4):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, face_feat: torch.Tensor, eye_feats: torch.Tensor) -> torch.Tensor:
        # The face feature acts as the query; the eye features act as keys and
        # values, so the face representation attends over the two eye streams.
        face_seq = face_feat.unsqueeze(1)  # [B, dim] -> [B, 1, dim]
        attn_out, _ = self.cross_attn(face_seq, eye_feats, eye_feats)
        out = self.norm1(face_seq + attn_out)  # residual connection + norm
        out = self.norm2(out + self.ffn(out))  # transformer-style FFN block
        return out.squeeze(1)                  # back to [B, dim]
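
# Illustrative shape check for CrossAttentionFusion (a sketch, assuming the
# default dim=256 and a batch of 2; shown as a comment so it is not executed
# on import):
#
#   fusion = CrossAttentionFusion(dim=256, num_heads=4)
#   face = torch.randn(2, 256)       # pooled face features
#   eyes = torch.randn(2, 2, 256)    # stacked [left, right] eye features
#   fused = fusion(face, eyes)       # -> torch.Size([2, 256])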


class PriviGazeTeacher(nn.Module):
    """Siamese teacher model with privileged multi-modal inputs.
    
    Inputs:
        - left_eye: [B, 3, 112, 112] RGB left eye crop
        - right_eye: [B, 3, 112, 112] RGB right eye crop  
        - face_blurred_gray: [B, 1, 224, 224] Blurred grayscale face
    
    Outputs:
        - pitch_pred: [B] gaze pitch angle in degrees
        - yaw_pred: [B] gaze yaw angle in degrees
        - pitch_logits: [B, gaze_bins] for logit distillation
        - yaw_logits: [B, gaze_bins] for logit distillation
        - features: [B, 256] fused feature representation for distillation
    """
    
    def __init__(
        self,
        eye_backbone: str = "facebook/convnextv2-atto-1k-224",
        face_backbone: str = "facebook/convnextv2-nano-22k-384",
        feature_dim: int = 256,
        gaze_bins: int = 90,
    ):
        super().__init__()
        
        self.eye_extractor = ConvNextV2FeatureExtractor(eye_backbone, feature_dim)
        self.face_extractor = ConvNextV2FeatureExtractor(face_backbone, feature_dim)
        
        self.eye_fusion = nn.Sequential(
            nn.Linear(feature_dim * 2, feature_dim),
            nn.GELU(),
            nn.LayerNorm(feature_dim),
        )
        
        self.cross_fusion = CrossAttentionFusion(feature_dim, num_heads=4)
        
        self.pitch_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(feature_dim // 2, gaze_bins),
        )
        
        self.yaw_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(feature_dim // 2, gaze_bins),
        )
        
        # Fixed bin centers spanning [-90, 90] degrees; registered as a buffer
        # so they move with the model across devices but are never trained.
        self.register_buffer('bin_centers', torch.linspace(-90.0, 90.0, gaze_bins))
        self.feature_dim = feature_dim
        self.gaze_bins = gaze_bins
    
    def _adapt_face_input(self, x: torch.Tensor) -> torch.Tensor:
        # Repeat a single grayscale channel to 3 channels so the
        # RGB-pretrained face backbone can consume it.
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        return x
    
    def forward(self, left_eye, right_eye, face_blurred_gray):
        # Siamese eye streams: one extractor (shared weights) encodes both crops.
        left_feat = self.eye_extractor(left_eye)
        right_feat = self.eye_extractor(right_eye)
        
        face_input = self._adapt_face_input(face_blurred_gray)
        face_feat = self.face_extractor(face_input)
        
        # Concatenation-based summary of both eyes.
        eye_combined = torch.cat([left_feat, right_feat], dim=-1)
        eye_fused = self.eye_fusion(eye_combined)
        
        # Cross-attention fusion: the face feature attends over the two eye
        # features; the concatenation summary is then added residually.
        eye_stacked = torch.stack([left_feat, right_feat], dim=1)  # [B, 2, D]
        fused = self.cross_fusion(face_feat, eye_stacked)
        fused = fused + eye_fused
        
        pitch_logits = self.pitch_head(fused)
        yaw_logits = self.yaw_head(fused)
        
        # Soft-argmax decoding: the continuous angle is the expectation of the
        # bin centers under the softmax distribution over bins.
        pitch_probs = F.softmax(pitch_logits, dim=-1)
        yaw_probs = F.softmax(yaw_logits, dim=-1)
        
        pitch_pred = (pitch_probs * self.bin_centers).sum(dim=-1)
        yaw_pred = (yaw_probs * self.bin_centers).sum(dim=-1)
        
        return pitch_pred, yaw_pred, pitch_logits, yaw_logits, fused
    
    def get_penultimate_features(self, left_eye, right_eye, face_blurred_gray):
        """Return only the fused feature vector, e.g. for feature distillation."""
        _, _, _, _, fused = self.forward(left_eye, right_eye, face_blurred_gray)
        return fused
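

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the training
    # pipeline): run a forward pass on random tensors shaped like the
    # documented inputs and print the output shapes. Instantiating the model
    # downloads the pretrained ConvNeXtV2 weights, so this requires network
    # access or a populated Hugging Face cache.
    model = PriviGazeTeacher()
    model.eval()

    left_eye = torch.randn(2, 3, 112, 112)
    right_eye = torch.randn(2, 3, 112, 112)
    face = torch.randn(2, 1, 224, 224)

    with torch.no_grad():
        pitch, yaw, pitch_logits, yaw_logits, feats = model(
            left_eye, right_eye, face
        )

    print("pitch:", tuple(pitch.shape))                # (2,)
    print("yaw:", tuple(yaw.shape))                    # (2,)
    print("pitch_logits:", tuple(pitch_logits.shape))  # (2, 90)
    print("features:", tuple(feats.shape))             # (2, 256)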