BcantCode committed
Commit 01809ab · verified · 1 Parent(s): 10d67d3

Upload models/teacher.py

Files changed (1)
  1. models/teacher.py +29 -93
models/teacher.py CHANGED
@@ -15,7 +15,7 @@ that the student does NOT have at inference time.
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from transformers import ConvNextV2Model, ConvNextV2Config
+ from transformers import ConvNextV2Model
 
 
  class ConvNextV2FeatureExtractor(nn.Module):
@@ -24,10 +24,8 @@ class ConvNextV2FeatureExtractor(nn.Module):
      def __init__(self, model_name: str, output_dim: int = 256):
          super().__init__()
          self.backbone = ConvNextV2Model.from_pretrained(model_name)
-         # Freeze early layers, fine-tune later stages
-         self._setup_gradient_checkpointing()
+         self.backbone.gradient_checkpointing_enable()
 
-         # Determine feature dimension from backbone
          hidden_size = self.backbone.config.hidden_sizes[-1]
          self.projection = nn.Sequential(
              nn.LayerNorm(hidden_size),
@@ -35,21 +33,9 @@ class ConvNextV2FeatureExtractor(nn.Module):
              nn.GELU(),
          )
 
-     def _setup_gradient_checkpointing(self):
-         self.backbone.gradient_checkpointing_enable()
-
      def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """Extract features from input image.
-
-         Args:
-             x: [B, 3, H, W] RGB image tensor
-
-         Returns:
-             features: [B, output_dim]
-         """
          outputs = self.backbone(x)
-         # Use pooled output (avg pool over spatial dims)
-         pooled = outputs.pooler_output  # [B, hidden_size]
+         pooled = outputs.pooler_output
          return self.projection(pooled)
 
 
@@ -68,25 +54,11 @@ class CrossAttentionFusion(nn.Module):
          )
 
      def forward(self, face_feat: torch.Tensor, eye_feats: torch.Tensor) -> torch.Tensor:
-         """Fuse face and eye features via cross-attention.
-
-         Args:
-             face_feat: [B, dim]
-             eye_feats: [B, 2, dim] - left and right eye features concatenated
-
-         Returns:
-             fused: [B, dim]
-         """
-         # Reshape for attention: [B, 1, dim] for face, [B, 2, dim] for eyes
-         face_seq = face_feat.unsqueeze(1)  # [B, 1, dim]
-         eye_seq = eye_feats  # [B, 2, dim]
-
-         # Cross-attention: face attends to eye features
-         attn_out, _ = self.cross_attn(face_seq, eye_seq, eye_seq)
+         face_seq = face_feat.unsqueeze(1)
+         attn_out, _ = self.cross_attn(face_seq, eye_feats, eye_feats)
          out = self.norm1(face_seq + attn_out)
          out = self.norm2(out + self.ffn(out))
-
-         return out.squeeze(1)  # [B, dim]
+         return out.squeeze(1)
 
 
  class PriviGazeTeacher(nn.Module):
@@ -95,11 +67,13 @@ class PriviGazeTeacher(nn.Module):
      Inputs:
      - left_eye: [B, 3, 112, 112] RGB left eye crop
      - right_eye: [B, 3, 112, 112] RGB right eye crop
-     - face_blurred_gray: [B, 1, 224, 224] Blurred grayscale face (only geometric info)
+     - face_blurred_gray: [B, 1, 224, 224] Blurred grayscale face
 
      Outputs:
-     - pitch: [B] gaze pitch angle in degrees
-     - yaw: [B] gaze yaw angle in degrees
+     - pitch_pred: [B] gaze pitch angle in degrees
+     - yaw_pred: [B] gaze yaw angle in degrees
+     - pitch_logits: [B, gaze_bins] for logit distillation
+     - yaw_logits: [B, gaze_bins] for logit distillation
      - features: [B, 256] fused feature representation for distillation
      """
 
@@ -108,107 +82,69 @@ class PriviGazeTeacher(nn.Module):
          eye_backbone: str = "facebook/convnextv2-atto-1k-224",
          face_backbone: str = "facebook/convnextv2-nano-22k-384",
          feature_dim: int = 256,
-         gaze_bins: int = 90,  # -90 to +90 degrees, binned
+         gaze_bins: int = 90,
      ):
          super().__init__()
 
-         # Eye feature extractors (shared weights for left and right)
          self.eye_extractor = ConvNextV2FeatureExtractor(eye_backbone, feature_dim)
-
-         # Face feature extractor (takes 1-channel input, adapt first conv)
          self.face_extractor = ConvNextV2FeatureExtractor(face_backbone, feature_dim)
 
-         # Eye fusion via self-attention
          self.eye_fusion = nn.Sequential(
              nn.Linear(feature_dim * 2, feature_dim),
              nn.GELU(),
              nn.LayerNorm(feature_dim),
          )
 
-         # Cross-modal fusion
          self.cross_fusion = CrossAttentionFusion(feature_dim, num_heads=4)
 
-         # Gaze regression heads (one per angle - L2CS-Net style)
          self.pitch_head = nn.Sequential(
              nn.Linear(feature_dim, feature_dim // 2),
              nn.GELU(),
              nn.Dropout(0.1),
-             nn.Linear(feature_dim // 2, gaze_bins),  # Binned classification
+             nn.Linear(feature_dim // 2, gaze_bins),
          )
 
          self.yaw_head = nn.Sequential(
              nn.Linear(feature_dim, feature_dim // 2),
              nn.GELU(),
              nn.Dropout(0.1),
-             nn.Linear(feature_dim // 2, gaze_bins),  # Binned classification
-         )
-
-         # Bin centers for expectation-based regression
-         self.register_buffer(
-             'bin_centers',
-             torch.linspace(-90.0, 90.0, gaze_bins)
+             nn.Linear(feature_dim // 2, gaze_bins),
          )
 
+         self.register_buffer('bin_centers', torch.linspace(-90.0, 90.0, gaze_bins))
          self.feature_dim = feature_dim
          self.gaze_bins = gaze_bins
 
      def _adapt_face_input(self, x: torch.Tensor) -> torch.Tensor:
-         """Adapt 1-channel grayscale face input to 3-channel for ConvNeXtV2.
-
-         The first conv layer expects 3 channels. We replicate the grayscale
-         channel 3 times. The model learns to treat this as a geometric-only signal
-         since high-frequency texture/color is removed by blurring.
-         """
          if x.shape[1] == 1:
              x = x.repeat(1, 3, 1, 1)
          return x
 
      def forward(self, left_eye, right_eye, face_blurred_gray):
-         """
-         Args:
-             left_eye: [B, 3, 112, 112]
-             right_eye: [B, 3, 112, 112]
-             face_blurred_gray: [B, 1, 224, 224]
-
-         Returns:
-             pitch_pred: [B] gaze pitch in degrees
-             yaw_pred: [B] gaze yaw in degrees
-             fused_features: [B, feature_dim]
-         """
-         # Extract features from each modality
-         left_feat = self.eye_extractor(left_eye)  # [B, dim]
-         right_feat = self.eye_extractor(right_eye)  # [B, dim]
+         left_feat = self.eye_extractor(left_eye)
+         right_feat = self.eye_extractor(right_eye)
 
          face_input = self._adapt_face_input(face_blurred_gray)
-         face_feat = self.face_extractor(face_input)  # [B, dim]
-
-         # Fuse eye features
-         eye_combined = torch.cat([left_feat, right_feat], dim=-1)  # [B, 2*dim]
-         eye_fused = self.eye_fusion(eye_combined)  # [B, dim]
-
-         # Stack eye features for cross-attention
-         eye_stacked = torch.stack([left_feat, right_feat], dim=1)  # [B, 2, dim]
-
-         # Cross-modal fusion: face attends to eye features
-         fused = self.cross_fusion(face_feat, eye_stacked)  # [B, dim]
-
-         # Add eye fused features
-         fused = fused + eye_fused  # residual connection
+         face_feat = self.face_extractor(face_input)
 
-         # Predict gaze angles using L2CS-Net style binned regression
-         pitch_logits = self.pitch_head(fused)  # [B, gaze_bins]
-         yaw_logits = self.yaw_head(fused)  # [B, gaze_bins]
+         eye_combined = torch.cat([left_feat, right_feat], dim=-1)
+         eye_fused = self.eye_fusion(eye_combined)
 
-         # Softmax + expectation for fine-grained regression
+         eye_stacked = torch.stack([left_feat, right_feat], dim=1)
+         fused = self.cross_fusion(face_feat, eye_stacked)
+         fused = fused + eye_fused
+
+         pitch_logits = self.pitch_head(fused)
+         yaw_logits = self.yaw_head(fused)
+
          pitch_probs = F.softmax(pitch_logits, dim=-1)
          yaw_probs = F.softmax(yaw_logits, dim=-1)
 
-         pitch_pred = (pitch_probs * self.bin_centers).sum(dim=-1)  # [B]
-         yaw_pred = (yaw_probs * self.bin_centers).sum(dim=-1)  # [B]
+         pitch_pred = (pitch_probs * self.bin_centers).sum(dim=-1)
+         yaw_pred = (yaw_probs * self.bin_centers).sum(dim=-1)
 
-         return pitch_pred, yaw_pred, fused
+         return pitch_pred, yaw_pred, pitch_logits, yaw_logits, fused
 
      def get_penultimate_features(self, left_eye, right_eye, face_blurred_gray):
-         """Return features before the regression heads for distillation."""
-         _, _, fused = self.forward(left_eye, right_eye, face_blurred_gray)
+         _, _, _, _, fused = self.forward(left_eye, right_eye, face_blurred_gray)
          return fused
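
One caveat worth flagging in the simplified CrossAttentionFusion.forward: it feeds batch-first tensors (a [B, 1, dim] query and [B, 2, dim] key/value) directly to self.cross_attn. With nn.MultiheadAttention that layout is only correct when the module is constructed with batch_first=True; otherwise PyTorch treats the first dimension as sequence length. The constructor sits outside this diff, so the following is a minimal, self-contained sketch of the assumed setup — CrossAttnCheck and its defaults are illustrative names, not the repo's code:

import torch
import torch.nn as nn

class CrossAttnCheck(nn.Module):
    """Minimal repro of the layout assumption; not the repo's class."""
    def __init__(self, dim: int = 256, num_heads: int = 4):
        super().__init__()
        # batch_first=True is required for the [B, seq, dim] tensors used below.
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, face_feat, eye_feats):
        face_seq = face_feat.unsqueeze(1)                # [B, 1, dim] query
        attn_out, _ = self.cross_attn(face_seq, eye_feats, eye_feats)
        return attn_out.squeeze(1)                       # [B, dim]

out = CrossAttnCheck()(torch.randn(2, 256), torch.randn(2, 2, 256))
assert out.shape == (2, 256)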
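The gaze heads keep the L2CS-Net-style decoding: a softmax over gaze_bins followed by an expectation against the registered bin_centers buffer, which turns the binned classification into a continuous angle with sub-bin resolution. A standalone illustration, using the same 90 bins spanning -90 to +90 degrees as the model:

import torch
import torch.nn.functional as F

bin_centers = torch.linspace(-90.0, 90.0, 90)  # matches the model's buffer

logits = torch.zeros(90)
logits[47] = 20.0                    # near-one-hot peak at bin 47 (~ +5.1 deg)

probs = F.softmax(logits, dim=-1)
angle = (probs * bin_centers).sum()  # expectation over bin centers

print(angle)  # ~= 5.06 degrees: a soft-argmax rather than a hard bin index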
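The interface change itself — returning pitch_logits and yaw_logits alongside the scalar angles — exists so a student can be trained against the teacher's bin distributions rather than only its point predictions. How that loss is wired up is not part of this file; below is a minimal sketch of a standard temperature-scaled logit-distillation term, where the function name and the temperature T=4.0 are assumptions for illustration:

import torch
import torch.nn.functional as F

def gaze_logit_distill_loss(student_logits: torch.Tensor,
                            teacher_logits: torch.Tensor,
                            T: float = 4.0) -> torch.Tensor:
    """KL(teacher || student) over the gaze bins at temperature T.

    Both inputs are [B, gaze_bins]; the T*T factor keeps gradient
    magnitudes comparable across temperatures, as in standard
    knowledge distillation.
    """
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * (T * T)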
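Finally, a smoke test of the updated five-element return, using the input shapes from the class docstring (batch size 2 is arbitrary; instantiation downloads the pretrained ConvNeXt V2 backbones from the Hub):

import torch

model = PriviGazeTeacher().eval()

left_eye = torch.randn(2, 3, 112, 112)
right_eye = torch.randn(2, 3, 112, 112)
face = torch.randn(2, 1, 224, 224)  # grayscale; repeated to 3 channels internally

with torch.no_grad():
    pitch, yaw, pitch_logits, yaw_logits, feats = model(left_eye, right_eye, face)

print(pitch.shape, yaw.shape)  # torch.Size([2]) torch.Size([2])
print(pitch_logits.shape)      # torch.Size([2, 90])
print(feats.shape)             # torch.Size([2, 256])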