BcantCode commited on
Commit
94cb2c0
·
verified ·
1 Parent(s): 0607636

Upload models/distillation_loss.py

Browse files
Files changed (1) hide show
  1. models/distillation_loss.py +304 -0
models/distillation_loss.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PriviGaze Distillation Loss - Privileged Knowledge Distillation for Gaze Estimation
3
+
4
+ Key components:
5
+ 1. Angular gaze loss (L1 on pitch/yaw in degrees)
6
+ 2. L2CS-Net style binned classification + regression loss
7
+ 3. Feature-level distillation (WCoRD-inspired contrastive + distribution matching)
8
+ 4. Logit-level distillation (KL on soft targets from teacher)
9
+
10
+ The teacher has access to privileged information (RGB eye crops, high-res face)
11
+ that the student does NOT have at inference time.
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+
19
class L2CSLoss(nn.Module):
    """L2CS-Net style combined classification + regression loss per angle.

    From "L2CS-Net: Fine-Grained Gaze Estimation in Unconstrained Environments"
    (Abdelrahman et al., 2022)

    Loss = CrossEntropy(binned_logits, binned_target) + beta * MSE(continuous_pred, continuous_target)
    """

    def __init__(self, gaze_bins: int = 90, beta: float = 1.0):
        """
        Args:
            gaze_bins: number of discrete bins spanning [-90, 90] degrees.
            beta: weight of the MSE regression term relative to the CE term.
        """
        super().__init__()
        self.gaze_bins = gaze_bins
        self.beta = beta
        # Evenly spaced bin centers over [-90, 90]; not used by the loss itself
        # but registered so callers can decode expected angles from bin probs.
        self.register_buffer(
            'bin_centers',
            torch.linspace(-90.0, 90.0, gaze_bins)
        )
        self.ce_loss = nn.CrossEntropyLoss()

    def _angle_to_bins(self, angles: torch.Tensor) -> torch.Tensor:
        """Map continuous angles (degrees) to the index of the nearest bin center.

        Fix: round instead of truncating. Truncation biased every assignment
        toward the lower bin (e.g. -88.5 deg landed in bin 0 even though the
        bin-1 center at ~-87.98 deg is closer); rounding matches the
        `bin_centers` buffer defined in __init__.
        """
        angles_clamped = angles.clamp(-90.0, 90.0)
        bin_width = 180.0 / (self.gaze_bins - 1)
        bins = torch.round((angles_clamped + 90.0) / bin_width).long()
        return bins.clamp(0, self.gaze_bins - 1)

    def forward(self, logits, continuous_pred, angle_target):
        """
        Args:
            logits: [B, gaze_bins] - classification logits
            continuous_pred: [B] - continuous angle prediction
            angle_target: [B] - ground truth angle in degrees

        Returns:
            loss: scalar (CE on binned targets + beta-weighted MSE)
        """
        bin_targets = self._angle_to_bins(angle_target)
        ce = self.ce_loss(logits, bin_targets)
        mse = F.mse_loss(continuous_pred, angle_target)
        return ce + self.beta * mse
59
+
60
+
61
class AngularLoss(nn.Module):
    """Mean absolute angular error in degrees.

    Applies an L1 penalty to pitch and yaw separately and returns their sum.
    This mirrors the standard evaluation metric for gaze estimation.
    """

    def __init__(self, reduction: str = 'mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, pitch_pred, yaw_pred, pitch_target, yaw_target):
        """Sum of per-angle L1 losses.

        Args:
            pitch_pred: [B] predicted pitch in degrees.
            yaw_pred: [B] predicted yaw in degrees.
            pitch_target: [B] ground-truth pitch in degrees.
            yaw_target: [B] ground-truth yaw in degrees.

        Returns:
            loss: scalar (mean angular error in degrees when reduction='mean').
        """
        pairs = ((pitch_pred, pitch_target), (yaw_pred, yaw_target))
        errors = [
            F.l1_loss(pred, truth, reduction=self.reduction)
            for pred, truth in pairs
        ]
        return errors[0] + errors[1]
86
+
87
+
88
class ContrastiveDistillationLoss(nn.Module):
    """WCoRD-inspired contrastive feature distillation.

    Maximizes mutual information between teacher and student feature
    representations using a symmetric InfoNCE contrastive loss.

    From "Wasserstein Contrastive Representation Distillation" (Chen et al., 2020)
    """

    def __init__(
        self,
        feature_dim: int = 256,
        proj_dim: int = 128,
        temperature: float = 0.1,
        student_dim: int = 128,
    ):
        """
        Args:
            feature_dim: teacher feature dimensionality.
            proj_dim: dimensionality of the shared projection space.
            temperature: softmax temperature for the InfoNCE logits.
            student_dim: student feature dimensionality. Previously hard-coded
                to 128 inside the module; now a parameter (default preserves
                the old behavior) so backbones of any width work.
        """
        super().__init__()
        # Project both teacher and student features to a shared space.
        self.teacher_proj = nn.Sequential(
            nn.Linear(feature_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim),
        )

        self.student_proj = nn.Sequential(
            nn.Linear(student_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim),
        )

        self.temperature = temperature

    def forward(self, teacher_feat: torch.Tensor, student_feat: torch.Tensor) -> torch.Tensor:
        """
        Args:
            teacher_feat: [B, feature_dim] - teacher's penultimate features
            student_feat: [B, student_dim] - student's penultimate features

        Returns:
            contrastive_loss: scalar
        """
        # Project to shared space and L2-normalize so logits are cosine sims.
        t = F.normalize(self.teacher_proj(teacher_feat), dim=-1)  # [B, proj_dim]
        s = F.normalize(self.student_proj(student_feat), dim=-1)  # [B, proj_dim]

        # Similarity matrix: positives on the diagonal (t_i, s_i),
        # negatives off-diagonal (t_i, s_j), i != j.
        logits = torch.matmul(t, s.T) / self.temperature  # [B, B]

        # InfoNCE: each teacher feature should match its paired student feature.
        labels = torch.arange(logits.shape[0], device=logits.device)

        # Symmetric loss: teacher -> student and student -> teacher.
        loss_t2s = F.cross_entropy(logits, labels)
        loss_s2t = F.cross_entropy(logits.T, labels)

        return (loss_t2s + loss_s2t) / 2.0
140
+
141
+
142
class DistributionMatchingLoss(nn.Module):
    """Distribution matching loss for feature-level knowledge transfer.

    Uses a (biased) Maximum Mean Discrepancy (MMD) estimate with an RBF
    kernel to match feature distributions between teacher and student.
    This is a simpler alternative to Wasserstein/Sinkhorn while still
    effective.
    """

    def __init__(self, kernel: str = 'rbf'):
        super().__init__()
        # Only the RBF kernel is implemented; the argument is kept for
        # interface compatibility / future kernel choices.
        self.kernel = kernel

    def _rbf_kernel(self, x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
        """Biased MMD^2 between x [Bx, D] and y [By, D] under an RBF kernel.

        Fix: the cross term used `rx + ry.T`, which pairs ||x_j||^2 and
        ||y_i||^2 with the inner product x_i . y_j (and silently broadcasts
        to the wrong shape when Bx != By). The squared distance
        ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 x_i . y_j requires
        `rx.T + ry` instead. (The bug was masked in forward() because
        L2-normalized inputs make all squared norms equal to 1.)
        """
        xx = torch.matmul(x, x.T)
        yy = torch.matmul(y, y.T)
        xy = torch.matmul(x, y.T)

        rx = xx.diag().unsqueeze(0)  # [1, Bx] squared norms of x rows
        ry = yy.diag().unsqueeze(0)  # [1, By] squared norms of y rows

        k_xx = torch.exp(- (rx + rx.T - 2 * xx) / (2 * sigma ** 2))
        k_yy = torch.exp(- (ry + ry.T - 2 * yy) / (2 * sigma ** 2))
        # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 x_i . y_j
        k_xy = torch.exp(- (rx.T + ry - 2 * xy) / (2 * sigma ** 2))

        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def forward(self, teacher_feat: torch.Tensor, student_feat: torch.Tensor) -> torch.Tensor:
        """Compute MMD between L2-normalized teacher and student features.

        NOTE(review): the cross-kernel matmul requires teacher_feat and
        student_feat to share the same feature dimension — confirm callers
        project both to a common width before invoking this loss.
        """
        t = F.normalize(teacher_feat, dim=-1)
        s = F.normalize(student_feat, dim=-1)
        return self._rbf_kernel(t, s)
174
+
175
+
176
class LogitDistillationLoss(nn.Module):
    """Temperature-scaled KL distillation on gaze-bin logits.

    Classic soft-target knowledge distillation: the student's softened
    log-probabilities are pulled toward the teacher's softened distribution
    over gaze bins, with the usual T^2 rescaling so gradient magnitudes stay
    comparable across temperatures.
    """

    def __init__(self, temperature: float = 3.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, student_logits, teacher_logits):
        """
        Args:
            student_logits: [B, gaze_bins]
            teacher_logits: [B, gaze_bins] (detached)

        Returns:
            kl_loss: scalar
        """
        temp = self.temperature
        log_q = F.log_softmax(student_logits / temp, dim=-1)
        p = F.softmax(teacher_logits / temp, dim=-1)
        kl = F.kl_div(log_q, p, reduction='batchmean')
        return kl * temp * temp
199
+
200
+
201
class PriviGazeDistillationLoss(nn.Module):
    """Complete privileged distillation loss for gaze estimation.

    Total loss = L_task                      (unweighted; there is no alpha_task knob)
               + alpha_angular * L_angular
               + alpha_contrastive * L_contrastive
               + alpha_mmd * L_mmd
               + alpha_logit * L_logit

    Task losses: L2CS-Net binned regression on student predictions
    Angular losses: Direct L1 on pitch/yaw
    Contrastive: Feature-level mutual information maximization
    MMD: Distribution matching
    Logit: Soft target distillation
    """

    def __init__(
        self,
        gaze_bins: int = 90,
        teacher_feature_dim: int = 256,
        student_feature_dim: int = 128,
        alpha_angular: float = 1.0,
        alpha_contrastive: float = 0.5,
        alpha_mmd: float = 0.1,
        alpha_logit: float = 0.5,
    ):
        super().__init__()

        self.angular_loss = AngularLoss()
        # Separate binned losses per angle so each keeps its own bin buffer.
        self.pitch_l2cs = L2CSLoss(gaze_bins)
        self.yaw_l2cs = L2CSLoss(gaze_bins)
        # NOTE(review): student_feature_dim is bound positionally to
        # ContrastiveDistillationLoss's *proj_dim* parameter, while that
        # module's student projection input width is fixed at 128. This only
        # lines up when student_feature_dim == 128 — confirm intent.
        self.contrastive_loss = ContrastiveDistillationLoss(
            teacher_feature_dim, student_feature_dim
        )
        # NOTE(review): DistributionMatchingLoss computes t @ s.T, which
        # requires teacher and student features to share the same width;
        # with teacher_feature_dim != student_feature_dim the MMD term will
        # fail at runtime — verify feature shapes at the call site.
        self.mmd_loss = DistributionMatchingLoss()
        self.logit_loss = LogitDistillationLoss()

        self.alpha_angular = alpha_angular
        self.alpha_contrastive = alpha_contrastive
        self.alpha_mmd = alpha_mmd
        self.alpha_logit = alpha_logit

    def forward(
        self,
        student_pitch: torch.Tensor,
        student_yaw: torch.Tensor,
        student_pitch_logits: torch.Tensor,
        student_yaw_logits: torch.Tensor,
        student_features: torch.Tensor,
        # teacher_pitch / teacher_yaw are currently unused below; kept so
        # callers can pass the teacher's full output tuple unchanged.
        teacher_pitch: torch.Tensor,
        teacher_yaw: torch.Tensor,
        teacher_pitch_logits: torch.Tensor,
        teacher_yaw_logits: torch.Tensor,
        teacher_features: torch.Tensor,
        pitch_target: torch.Tensor,
        yaw_target: torch.Tensor,
    ):
        """
        Args:
            student_pitch / student_yaw: [B] continuous student predictions (degrees).
            student_pitch_logits / student_yaw_logits: [B, gaze_bins] student bin logits.
            student_features: [B, student_feature_dim] student penultimate features.
            teacher_pitch / teacher_yaw: [B] teacher continuous predictions (unused here).
            teacher_pitch_logits / teacher_yaw_logits: [B, gaze_bins] teacher bin logits.
            teacher_features: [B, teacher_feature_dim] teacher penultimate features.
            pitch_target / yaw_target: [B] ground-truth angles in degrees.

        Returns:
            total_loss: scalar (each distillation term already alpha-weighted)
            loss_dict: dict of individual losses for logging
        """
        # 1. Task losses (student predictions vs ground truth)
        task_pitch = self.pitch_l2cs(student_pitch_logits, student_pitch, pitch_target)
        task_yaw = self.yaw_l2cs(student_yaw_logits, student_yaw, yaw_target)
        loss_task = task_pitch + task_yaw

        # 2. Angular loss (direct L1 in degrees)
        loss_angular = self.alpha_angular * self.angular_loss(
            student_pitch, student_yaw, pitch_target, yaw_target
        )

        # 3. Contrastive feature distillation (teacher detached: student-only gradients)
        loss_contrastive = self.alpha_contrastive * self.contrastive_loss(
            teacher_features.detach(), student_features
        )

        # 4. Distribution matching (MMD)
        loss_mmd = self.alpha_mmd * self.mmd_loss(
            teacher_features.detach(), student_features
        )

        # 5. Logit distillation (teacher soft targets), one KL term per angle
        loss_logit_pitch = self.alpha_logit * self.logit_loss(
            student_pitch_logits, teacher_pitch_logits.detach()
        )
        loss_logit_yaw = self.alpha_logit * self.logit_loss(
            student_yaw_logits, teacher_yaw_logits.detach()
        )
        loss_logit = loss_logit_pitch + loss_logit_yaw

        # Total (alpha weights already folded into each term above)
        total_loss = loss_task + loss_angular + loss_contrastive + loss_mmd + loss_logit

        # .item() detaches scalars for logging; values are post-alpha-weighting.
        loss_dict = {
            'loss_total': total_loss.item(),
            'loss_task': loss_task.item(),
            'loss_angular': loss_angular.item(),
            'loss_contrastive': loss_contrastive.item(),
            'loss_mmd': loss_mmd.item(),
            'loss_logit': loss_logit.item(),
        }

        return total_loss, loss_dict