BcantCode committed
Commit e97351b · verified · 1 Parent(s): c07e4db

Upload models/teacher.py

Files changed (1)
  1. models/teacher.py +214 -0
models/teacher.py ADDED
@@ -0,0 +1,214 @@
+ """
+ PriviGaze Teacher Model - Siamese Multi-Input Gaze Estimation Network
+
+ Architecture:
+ - Takes 3 inputs: left eye RGB, right eye RGB, blurred grayscale face
+ - Uses ConvNeXtV2-Atto as shared backbone for eye streams
+ - Uses ConvNeXtV2-Nano for face stream
+ - Fuses multi-modal features via cross-attention
+ - Outputs: pitch and yaw gaze angles (degrees)
+
+ This teacher has access to privileged information (RGB eye crops, high-res face)
+ that the student does NOT have at inference time.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import ConvNextV2Model, ConvNextV2Config
+
+
+ class ConvNextV2FeatureExtractor(nn.Module):
+     """Wrapper around ConvNeXtV2 for feature extraction (no classification head)."""
+
+     def __init__(self, model_name: str, output_dim: int = 256):
+         super().__init__()
+         self.backbone = ConvNextV2Model.from_pretrained(model_name)
+         # Enable gradient checkpointing to reduce activation memory during fine-tuning
+         self._setup_gradient_checkpointing()
+
+         # Determine feature dimension from backbone
+         hidden_size = self.backbone.config.hidden_sizes[-1]
+         self.projection = nn.Sequential(
+             nn.LayerNorm(hidden_size),
+             nn.Linear(hidden_size, output_dim),
+             nn.GELU(),
+         )
+
+     def _setup_gradient_checkpointing(self):
+         self.backbone.gradient_checkpointing_enable()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Extract features from input image.
+
+         Args:
+             x: [B, 3, H, W] RGB image tensor
+
+         Returns:
+             features: [B, output_dim]
+         """
+         outputs = self.backbone(x)
+         # Use pooled output (avg pool over spatial dims)
+         pooled = outputs.pooler_output  # [B, hidden_size]
+         return self.projection(pooled)
+
+
+ class CrossAttentionFusion(nn.Module):
+     """Cross-attention fusion module for multi-modal features."""
+
+     def __init__(self, dim: int = 256, num_heads: int = 4):
+         super().__init__()
+         self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
+         self.norm1 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim)
+         self.ffn = nn.Sequential(
+             nn.Linear(dim, dim * 4),
+             nn.GELU(),
+             nn.Linear(dim * 4, dim),
+         )
+
+     def forward(self, face_feat: torch.Tensor, eye_feats: torch.Tensor) -> torch.Tensor:
+         """Fuse face and eye features via cross-attention.
+
+         Args:
+             face_feat: [B, dim]
+             eye_feats: [B, 2, dim] - left and right eye features stacked along dim 1
+
+         Returns:
+             fused: [B, dim]
+         """
+         # Reshape for attention: [B, 1, dim] for face, [B, 2, dim] for eyes
+         face_seq = face_feat.unsqueeze(1)  # [B, 1, dim]
+         eye_seq = eye_feats  # [B, 2, dim]
+
+         # Cross-attention: face attends to eye features
+         attn_out, _ = self.cross_attn(face_seq, eye_seq, eye_seq)
+         out = self.norm1(face_seq + attn_out)
+         out = self.norm2(out + self.ffn(out))
+
+         return out.squeeze(1)  # [B, dim]
+
+
+ class PriviGazeTeacher(nn.Module):
+     """Siamese teacher model with privileged multi-modal inputs.
+
+     Inputs:
+     - left_eye: [B, 3, 112, 112] RGB left eye crop
+     - right_eye: [B, 3, 112, 112] RGB right eye crop
+     - face_blurred_gray: [B, 1, 224, 224] Blurred grayscale face (only geometric info)
+
+     Outputs:
+     - pitch: [B] gaze pitch angle in degrees
+     - yaw: [B] gaze yaw angle in degrees
+     - features: [B, 256] fused feature representation for distillation
+     """
+
+     def __init__(
+         self,
+         eye_backbone: str = "facebook/convnextv2-atto-1k-224",
+         face_backbone: str = "facebook/convnextv2-nano-22k-384",
+         feature_dim: int = 256,
+         gaze_bins: int = 90,  # -90 to +90 degrees, binned
+     ):
+         super().__init__()
+
+         # Eye feature extractor (shared weights for left and right eyes)
+         self.eye_extractor = ConvNextV2FeatureExtractor(eye_backbone, feature_dim)
+
+         # Face feature extractor (grayscale input is replicated to 3 channels in _adapt_face_input)
+         self.face_extractor = ConvNextV2FeatureExtractor(face_backbone, feature_dim)
+
+         # Eye fusion via an MLP over the concatenated left/right eye features
+         self.eye_fusion = nn.Sequential(
+             nn.Linear(feature_dim * 2, feature_dim),
+             nn.GELU(),
+             nn.LayerNorm(feature_dim),
+         )
+
+         # Cross-modal fusion
+         self.cross_fusion = CrossAttentionFusion(feature_dim, num_heads=4)
+
+         # Gaze regression heads (one per angle - L2CS-Net style)
+         self.pitch_head = nn.Sequential(
+             nn.Linear(feature_dim, feature_dim // 2),
+             nn.GELU(),
+             nn.Dropout(0.1),
+             nn.Linear(feature_dim // 2, gaze_bins),  # Binned classification
+         )
+
+         self.yaw_head = nn.Sequential(
+             nn.Linear(feature_dim, feature_dim // 2),
+             nn.GELU(),
+             nn.Dropout(0.1),
+             nn.Linear(feature_dim // 2, gaze_bins),  # Binned classification
+         )
+
+         # Bin centers for expectation-based regression
+         self.register_buffer(
+             'bin_centers',
+             torch.linspace(-90.0, 90.0, gaze_bins)
+         )
+
+         self.feature_dim = feature_dim
+         self.gaze_bins = gaze_bins
+
+     def _adapt_face_input(self, x: torch.Tensor) -> torch.Tensor:
+         """Adapt 1-channel grayscale face input to 3-channel for ConvNeXtV2.
+
+         The first conv layer expects 3 channels. We replicate the grayscale
+         channel 3 times. The model learns to treat this as a geometric-only signal
+         since high-frequency texture/color is removed by blurring.
+         """
+         if x.shape[1] == 1:
+             x = x.repeat(1, 3, 1, 1)
+         return x
+
+     def forward(self, left_eye, right_eye, face_blurred_gray):
+         """
+         Args:
+             left_eye: [B, 3, 112, 112]
+             right_eye: [B, 3, 112, 112]
+             face_blurred_gray: [B, 1, 224, 224]
+
+         Returns:
+             pitch_pred: [B] gaze pitch in degrees
+             yaw_pred: [B] gaze yaw in degrees
+             fused_features: [B, feature_dim]
+         """
+         # Extract features from each modality
+         left_feat = self.eye_extractor(left_eye)  # [B, dim]
+         right_feat = self.eye_extractor(right_eye)  # [B, dim]
+
+         face_input = self._adapt_face_input(face_blurred_gray)
+         face_feat = self.face_extractor(face_input)  # [B, dim]
+
+         # Fuse eye features
+         eye_combined = torch.cat([left_feat, right_feat], dim=-1)  # [B, 2*dim]
+         eye_fused = self.eye_fusion(eye_combined)  # [B, dim]
+
+         # Stack eye features for cross-attention
+         eye_stacked = torch.stack([left_feat, right_feat], dim=1)  # [B, 2, dim]
+
+         # Cross-modal fusion: face attends to eye features
+         fused = self.cross_fusion(face_feat, eye_stacked)  # [B, dim]
+
+         # Add eye fused features
+         fused = fused + eye_fused  # residual connection
+
+         # Predict gaze angles using L2CS-Net style binned regression
+         pitch_logits = self.pitch_head(fused)  # [B, gaze_bins]
+         yaw_logits = self.yaw_head(fused)  # [B, gaze_bins]
+
+         # Softmax + expectation for fine-grained regression
+         pitch_probs = F.softmax(pitch_logits, dim=-1)
+         yaw_probs = F.softmax(yaw_logits, dim=-1)
+
+         pitch_pred = (pitch_probs * self.bin_centers).sum(dim=-1)  # [B]
+         yaw_pred = (yaw_probs * self.bin_centers).sum(dim=-1)  # [B]
+
+         return pitch_pred, yaw_pred, fused
+
+     def get_penultimate_features(self, left_eye, right_eye, face_blurred_gray):
+         """Return features before the regression heads for distillation."""
+         _, _, fused = self.forward(left_eye, right_eye, face_blurred_gray)
+         return fused
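
For reference, a minimal usage sketch (not part of this commit): it assumes the file is importable as models.teacher, that the facebook/convnextv2-* checkpoints can be downloaded from the Hugging Face Hub, and uses dummy tensors with the shapes documented above. The student encoder stand-in and the distillation loss at the end are hypothetical, included only to illustrate how the fused features might be consumed.

# Illustrative sketch only -- not part of the committed file.
import torch
import torch.nn.functional as F

from models.teacher import PriviGazeTeacher  # assumed import path

teacher = PriviGazeTeacher()
teacher.eval()

B = 2
left_eye = torch.randn(B, 3, 112, 112)           # RGB left eye crops
right_eye = torch.randn(B, 3, 112, 112)          # RGB right eye crops
face_blurred_gray = torch.randn(B, 1, 224, 224)  # blurred grayscale faces

with torch.no_grad():
    pitch, yaw, feats = teacher(left_eye, right_eye, face_blurred_gray)

print(pitch.shape, yaw.shape, feats.shape)  # torch.Size([2]) torch.Size([2]) torch.Size([2, 256])

# Hypothetical feature-distillation use: match a student embedding to the
# teacher's fused features (the student would see only non-privileged inputs).
teacher_feats = teacher.get_penultimate_features(left_eye, right_eye, face_blurred_gray)
student_feats = torch.randn(B, 256, requires_grad=True)  # stand-in for a student encoder output
distill_loss = F.mse_loss(student_feats, teacher_feats.detach())

The expectation decode in forward (softmax over 90 bins multiplied by bin_centers) keeps predictions differentiable and bounded to roughly [-90, +90] degrees, which is why the sketch reads angles directly from the returned tensors rather than taking an argmax over bins.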