BcantCode commited on
Commit
fdc4b3d
verified
1 Parent(s): 327e860

Upload models/distillation_loss.py

Browse files
Files changed (1) hide show
  1. models/distillation_loss.py +38 -181
models/distillation_loss.py CHANGED
@@ -1,14 +1,11 @@
1
  """
2
  PriviGaze Distillation Loss - Privileged Knowledge Distillation for Gaze Estimation
3
 
4
- Key components:
5
  1. Angular gaze loss (L1 on pitch/yaw in degrees)
6
  2. L2CS-Net style binned classification + regression loss
7
  3. Feature-level distillation (WCoRD-inspired contrastive + distribution matching)
8
  4. Logit-level distillation (KL on soft targets from teacher)
9
-
10
- The teacher has access to privileged information (RGB eye crops, high-res face)
11
- that the student does NOT have at inference time.
12
  """
13
 
14
  import torch
@@ -19,9 +16,6 @@ import torch.nn.functional as F
19
  class L2CSLoss(nn.Module):
20
  """L2CS-Net style combined classification + regression loss per angle.
21
 
22
- From "L2CS-Net: Fine-Grained Gaze Estimation in Unconstrained Environments"
23
- (Abdelrahman et al., 2022)
24
-
25
  Loss = CrossEntropy(binned_logits, binned_target) + beta * MSE(continuous_pred, continuous_target)
26
  """
27
 
@@ -29,29 +23,16 @@ class L2CSLoss(nn.Module):
29
  super().__init__()
30
  self.gaze_bins = gaze_bins
31
  self.beta = beta
32
- self.register_buffer(
33
- 'bin_centers',
34
- torch.linspace(-90.0, 90.0, gaze_bins)
35
- )
36
  self.ce_loss = nn.CrossEntropyLoss()
37
 
38
- def _angle_to_bins(self, angles: torch.Tensor) -> torch.Tensor:
39
- """Convert continuous angle to bin index."""
40
  angles_clamped = angles.clamp(-90.0, 90.0)
41
  bin_width = 180.0 / (self.gaze_bins - 1)
42
  bins = ((angles_clamped + 90.0) / bin_width).long()
43
  return bins.clamp(0, self.gaze_bins - 1)
44
 
45
  def forward(self, logits, continuous_pred, angle_target):
46
- """
47
- Args:
48
- logits: [B, gaze_bins] - classification logits
49
- continuous_pred: [B] - continuous angle prediction
50
- angle_target: [B] - ground truth angle in degrees
51
-
52
- Returns:
53
- loss: scalar
54
- """
55
  bin_targets = self._angle_to_bins(angle_target)
56
  ce = self.ce_loss(logits, bin_targets)
57
  mse = F.mse_loss(continuous_pred, angle_target)
@@ -59,27 +40,13 @@ class L2CSLoss(nn.Module):
59
 
60
 
61
  class AngularLoss(nn.Module):
62
- """Direct angular error loss in degrees.
63
-
64
- Computes L1 loss on pitch and yaw predictions.
65
- This is the standard metric for gaze estimation.
66
- """
67
 
68
  def __init__(self, reduction: str = 'mean'):
69
  super().__init__()
70
  self.reduction = reduction
71
 
72
  def forward(self, pitch_pred, yaw_pred, pitch_target, yaw_target):
73
- """
74
- Args:
75
- pitch_pred: [B]
76
- yaw_pred: [B]
77
- pitch_target: [B]
78
- yaw_target: [B]
79
-
80
- Returns:
81
- loss: scalar (mean angular error in degrees)
82
- """
83
  pitch_loss = F.l1_loss(pitch_pred, pitch_target, reduction=self.reduction)
84
  yaw_loss = F.l1_loss(yaw_pred, yaw_target, reduction=self.reduction)
85
  return pitch_loss + yaw_loss
@@ -88,208 +55,99 @@ class AngularLoss(nn.Module):
88
  class ContrastiveDistillationLoss(nn.Module):
89
  """WCoRD-inspired contrastive feature distillation.
90
 
91
- Maximizes mutual information between teacher and student feature
92
- representations using InfoNCE contrastive loss.
93
-
94
- From "Wasserstein Contrastive Representation Distillation" (Chen et al., 2020)
95
  """
96
 
97
- def __init__(self, feature_dim: int = 256, proj_dim: int = 128, temperature: float = 0.1):
98
  super().__init__()
99
- # Project both teacher and student features to shared space
100
  self.teacher_proj = nn.Sequential(
101
- nn.Linear(feature_dim, proj_dim),
102
- nn.GELU(),
103
- nn.Linear(proj_dim, proj_dim),
104
- )
105
-
106
  self.student_proj = nn.Sequential(
107
- nn.Linear(128, proj_dim), # student has smaller feature dim
108
- nn.GELU(),
109
- nn.Linear(proj_dim, proj_dim),
110
- )
111
-
112
  self.temperature = temperature
113
 
114
- def forward(self, teacher_feat: torch.Tensor, student_feat: torch.Tensor) -> torch.Tensor:
115
- """
116
- Args:
117
- teacher_feat: [B, feature_dim] - teacher's penultimate features
118
- student_feat: [B, 128] - student's penultimate features
119
-
120
- Returns:
121
- contrastive_loss: scalar
122
- """
123
- # Project to shared space
124
- t = F.normalize(self.teacher_proj(teacher_feat), dim=-1) # [B, proj_dim]
125
- s = F.normalize(self.student_proj(student_feat), dim=-1) # [B, proj_dim]
126
-
127
- # Compute similarity matrix
128
- # Positive pairs: (t_i, s_i) for all i
129
- # Negative pairs: (t_i, s_j) for i != j
130
- logits = torch.matmul(t, s.T) / self.temperature # [B, B]
131
-
132
- # InfoNCE loss: each teacher feature should match its corresponding student
133
  labels = torch.arange(logits.shape[0], device=logits.device)
134
-
135
- # Symmetric loss: teacher -> student and student -> teacher
136
  loss_t2s = F.cross_entropy(logits, labels)
137
  loss_s2t = F.cross_entropy(logits.T, labels)
138
-
139
  return (loss_t2s + loss_s2t) / 2.0
140
 
141
 
142
  class DistributionMatchingLoss(nn.Module):
143
- """Distribution matching loss for feature-level knowledge transfer.
144
-
145
- Uses Maximum Mean Discrepancy (MMD) to match feature distributions
146
- between teacher and student. This is a simpler alternative to
147
- Wasserstein/Sinkhorn while still effective.
148
- """
149
 
150
  def __init__(self, kernel: str = 'rbf'):
151
  super().__init__()
152
  self.kernel = kernel
153
 
154
- def _rbf_kernel(self, x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
155
- """RBF kernel between two sets of features."""
156
  xx = torch.matmul(x, x.T)
157
  yy = torch.matmul(y, y.T)
158
  xy = torch.matmul(x, y.T)
159
-
160
  rx = xx.diag().unsqueeze(0)
161
  ry = yy.diag().unsqueeze(0)
162
-
163
- k_xx = torch.exp(- (rx + rx.T - 2 * xx) / (2 * sigma ** 2))
164
- k_yy = torch.exp(- (ry + ry.T - 2 * yy) / (2 * sigma ** 2))
165
- k_xy = torch.exp(- (rx + ry.T - 2 * xy) / (2 * sigma ** 2))
166
-
167
- return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()
168
 
169
- def forward(self, teacher_feat: torch.Tensor, student_feat: torch.Tensor) -> torch.Tensor:
170
- """Compute MMD between teacher and student feature distributions."""
171
  t = F.normalize(teacher_feat, dim=-1)
172
  s = F.normalize(student_feat, dim=-1)
173
  return self._rbf_kernel(t, s)
174
 
175
 
176
  class LogitDistillationLoss(nn.Module):
177
- """KL divergence distillation on output gaze predictions.
178
-
179
- Standard knowledge distillation: student learns to mimic teacher's
180
- soft probability distribution over gaze bins.
181
- """
182
 
183
  def __init__(self, temperature: float = 3.0):
184
  super().__init__()
185
  self.temperature = temperature
186
 
187
  def forward(self, student_logits, teacher_logits):
188
- """
189
- Args:
190
- student_logits: [B, gaze_bins]
191
- teacher_logits: [B, gaze_bins] (detached)
192
-
193
- Returns:
194
- kl_loss: scalar
195
- """
196
  student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
197
  teacher_soft = F.softmax(teacher_logits / self.temperature, dim=-1)
198
- return F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (self.temperature ** 2)
199
 
200
 
201
  class PriviGazeDistillationLoss(nn.Module):
202
  """Complete privileged distillation loss for gaze estimation.
203
 
204
- Total loss = alpha_task * L_task
205
- + alpha_angular * L_angular
206
- + alpha_contrastive * L_contrastive
207
- + alpha_mmd * L_mmd
208
- + alpha_logit * L_logit
209
-
210
- Task losses: L2CS-Net binned regression on student predictions
211
- Angular losses: Direct L1 on pitch/yaw
212
- Contrastive: Feature-level mutual information maximization
213
- MMD: Distribution matching
214
- Logit: Soft target distillation
215
  """
216
 
217
- def __init__(
218
- self,
219
- gaze_bins: int = 90,
220
- teacher_feature_dim: int = 256,
221
- student_feature_dim: int = 128,
222
- alpha_angular: float = 1.0,
223
- alpha_contrastive: float = 0.5,
224
- alpha_mmd: float = 0.1,
225
- alpha_logit: float = 0.5,
226
- ):
227
  super().__init__()
228
-
229
  self.angular_loss = AngularLoss()
230
  self.pitch_l2cs = L2CSLoss(gaze_bins)
231
  self.yaw_l2cs = L2CSLoss(gaze_bins)
232
- self.contrastive_loss = ContrastiveDistillationLoss(
233
- teacher_feature_dim, student_feature_dim
234
- )
235
  self.mmd_loss = DistributionMatchingLoss()
236
  self.logit_loss = LogitDistillationLoss()
237
-
238
  self.alpha_angular = alpha_angular
239
  self.alpha_contrastive = alpha_contrastive
240
  self.alpha_mmd = alpha_mmd
241
  self.alpha_logit = alpha_logit
242
 
243
- def forward(
244
- self,
245
- student_pitch: torch.Tensor,
246
- student_yaw: torch.Tensor,
247
- student_pitch_logits: torch.Tensor,
248
- student_yaw_logits: torch.Tensor,
249
- student_features: torch.Tensor,
250
- teacher_pitch: torch.Tensor,
251
- teacher_yaw: torch.Tensor,
252
- teacher_pitch_logits: torch.Tensor,
253
- teacher_yaw_logits: torch.Tensor,
254
- teacher_features: torch.Tensor,
255
- pitch_target: torch.Tensor,
256
- yaw_target: torch.Tensor,
257
- ):
258
- """
259
- Returns:
260
- total_loss: scalar
261
- loss_dict: dict of individual losses for logging
262
- """
263
- # 1. Task losses (student predictions vs ground truth)
264
- task_pitch = self.pitch_l2cs(student_pitch_logits, student_pitch, pitch_target)
265
- task_yaw = self.yaw_l2cs(student_yaw_logits, student_yaw, yaw_target)
266
  loss_task = task_pitch + task_yaw
267
 
268
- # 2. Angular loss (direct L1 in degrees)
269
- loss_angular = self.alpha_angular * self.angular_loss(
270
- student_pitch, student_yaw, pitch_target, yaw_target
271
- )
272
-
273
- # 3. Contrastive feature distillation
274
- loss_contrastive = self.alpha_contrastive * self.contrastive_loss(
275
- teacher_features.detach(), student_features
276
- )
277
 
278
- # 4. Distribution matching (MMD)
279
- loss_mmd = self.alpha_mmd * self.mmd_loss(
280
- teacher_features.detach(), student_features
281
- )
282
-
283
- # 5. Logit distillation (teacher soft targets)
284
- loss_logit_pitch = self.alpha_logit * self.logit_loss(
285
- student_pitch_logits, teacher_pitch_logits.detach()
286
- )
287
- loss_logit_yaw = self.alpha_logit * self.logit_loss(
288
- student_yaw_logits, teacher_yaw_logits.detach()
289
- )
290
- loss_logit = loss_logit_pitch + loss_logit_yaw
291
-
292
- # Total
293
  total_loss = loss_task + loss_angular + loss_contrastive + loss_mmd + loss_logit
294
 
295
  loss_dict = {
@@ -300,5 +158,4 @@ class PriviGazeDistillationLoss(nn.Module):
300
  'loss_mmd': loss_mmd.item(),
301
  'loss_logit': loss_logit.item(),
302
  }
303
-
304
- return total_loss, loss_dict
 
1
  """
2
  PriviGaze Distillation Loss - Privileged Knowledge Distillation for Gaze Estimation
3
 
4
+ Components:
5
  1. Angular gaze loss (L1 on pitch/yaw in degrees)
6
  2. L2CS-Net style binned classification + regression loss
7
  3. Feature-level distillation (WCoRD-inspired contrastive + distribution matching)
8
  4. Logit-level distillation (KL on soft targets from teacher)
 
 
 
9
  """
10
 
11
  import torch
 
16
  class L2CSLoss(nn.Module):
17
  """L2CS-Net style combined classification + regression loss per angle.
18
 
 
 
 
19
  Loss = CrossEntropy(binned_logits, binned_target) + beta * MSE(continuous_pred, continuous_target)
20
  """
21
 
 
23
  super().__init__()
24
  self.gaze_bins = gaze_bins
25
  self.beta = beta
26
+ self.register_buffer('bin_centers', torch.linspace(-90.0, 90.0, gaze_bins))
 
 
 
27
  self.ce_loss = nn.CrossEntropyLoss()
28
 
29
+ def _angle_to_bins(self, angles):
 
30
  angles_clamped = angles.clamp(-90.0, 90.0)
31
  bin_width = 180.0 / (self.gaze_bins - 1)
32
  bins = ((angles_clamped + 90.0) / bin_width).long()
33
  return bins.clamp(0, self.gaze_bins - 1)
34
 
35
  def forward(self, logits, continuous_pred, angle_target):
 
 
 
 
 
 
 
 
 
36
  bin_targets = self._angle_to_bins(angle_target)
37
  ce = self.ce_loss(logits, bin_targets)
38
  mse = F.mse_loss(continuous_pred, angle_target)
 
40
 
41
 
42
  class AngularLoss(nn.Module):
43
+ """Direct angular error loss in degrees (L1 on pitch and yaw)."""
 
 
 
 
44
 
45
  def __init__(self, reduction: str = 'mean'):
46
  super().__init__()
47
  self.reduction = reduction
48
 
49
  def forward(self, pitch_pred, yaw_pred, pitch_target, yaw_target):
 
 
 
 
 
 
 
 
 
 
50
  pitch_loss = F.l1_loss(pitch_pred, pitch_target, reduction=self.reduction)
51
  yaw_loss = F.l1_loss(yaw_pred, yaw_target, reduction=self.reduction)
52
  return pitch_loss + yaw_loss
 
55
  class ContrastiveDistillationLoss(nn.Module):
56
  """WCoRD-inspired contrastive feature distillation.
57
 
58
+ InfoNCE loss maximizing mutual information between teacher and student features.
59
+ Projects both to a shared space before computing similarity.
 
 
60
  """
61
 
62
+ def __init__(self, teacher_dim: int = 256, student_dim: int = 128, proj_dim: int = 128, temperature: float = 0.1):
63
  super().__init__()
 
64
  self.teacher_proj = nn.Sequential(
65
+ nn.Linear(teacher_dim, proj_dim), nn.GELU(), nn.Linear(proj_dim, proj_dim))
 
 
 
 
66
  self.student_proj = nn.Sequential(
67
+ nn.Linear(student_dim, proj_dim), nn.GELU(), nn.Linear(proj_dim, proj_dim))
 
 
 
 
68
  self.temperature = temperature
69
 
70
+ def forward(self, teacher_feat, student_feat):
71
+ t = F.normalize(self.teacher_proj(teacher_feat), dim=-1)
72
+ s = F.normalize(self.student_proj(student_feat), dim=-1)
73
+ logits = torch.matmul(t, s.T) / self.temperature
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  labels = torch.arange(logits.shape[0], device=logits.device)
 
 
75
  loss_t2s = F.cross_entropy(logits, labels)
76
  loss_s2t = F.cross_entropy(logits.T, labels)
 
77
  return (loss_t2s + loss_s2t) / 2.0
78
 
79
 
80
  class DistributionMatchingLoss(nn.Module):
81
+ """MMD-based distribution matching between teacher and student features."""
 
 
 
 
 
82
 
83
  def __init__(self, kernel: str = 'rbf'):
84
  super().__init__()
85
  self.kernel = kernel
86
 
87
+ def _rbf_kernel(self, x, y, sigma=1.0):
 
88
  xx = torch.matmul(x, x.T)
89
  yy = torch.matmul(y, y.T)
90
  xy = torch.matmul(x, y.T)
 
91
  rx = xx.diag().unsqueeze(0)
92
  ry = yy.diag().unsqueeze(0)
93
+ k_xx = torch.exp(-(rx + rx.T - 2*xx) / (2*sigma**2))
94
+ k_yy = torch.exp(-(ry + ry.T - 2*yy) / (2*sigma**2))
95
+ k_xy = torch.exp(-(rx + ry.T - 2*xy) / (2*sigma**2))
96
+ return k_xx.mean() + k_yy.mean() - 2*k_xy.mean()
 
 
97
 
98
+ def forward(self, teacher_feat, student_feat):
 
99
  t = F.normalize(teacher_feat, dim=-1)
100
  s = F.normalize(student_feat, dim=-1)
101
  return self._rbf_kernel(t, s)
102
 
103
 
104
  class LogitDistillationLoss(nn.Module):
105
+ """KL divergence distillation on soft gaze bin probabilities."""
 
 
 
 
106
 
107
  def __init__(self, temperature: float = 3.0):
108
  super().__init__()
109
  self.temperature = temperature
110
 
111
  def forward(self, student_logits, teacher_logits):
 
 
 
 
 
 
 
 
112
  student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
113
  teacher_soft = F.softmax(teacher_logits / self.temperature, dim=-1)
114
+ return F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (self.temperature**2)
115
 
116
 
117
  class PriviGazeDistillationLoss(nn.Module):
118
  """Complete privileged distillation loss for gaze estimation.
119
 
120
+ L_total = L_task + 伪_angular路L_angular + 伪_contrastive路L_contrastive
121
+ + 伪_mmd路L_mmd + 伪_logit路L_logit
 
 
 
 
 
 
 
 
 
122
  """
123
 
124
+ def __init__(self, gaze_bins=90, teacher_feature_dim=256, student_feature_dim=128,
125
+ alpha_angular=1.0, alpha_contrastive=0.5, alpha_mmd=0.1, alpha_logit=0.5):
 
 
 
 
 
 
 
 
126
  super().__init__()
 
127
  self.angular_loss = AngularLoss()
128
  self.pitch_l2cs = L2CSLoss(gaze_bins)
129
  self.yaw_l2cs = L2CSLoss(gaze_bins)
130
+ self.contrastive_loss = ContrastiveDistillationLoss(teacher_feature_dim, student_feature_dim)
 
 
131
  self.mmd_loss = DistributionMatchingLoss()
132
  self.logit_loss = LogitDistillationLoss()
 
133
  self.alpha_angular = alpha_angular
134
  self.alpha_contrastive = alpha_contrastive
135
  self.alpha_mmd = alpha_mmd
136
  self.alpha_logit = alpha_logit
137
 
138
+ def forward(self, s_pitch, s_yaw, sp_logits, sy_logits, s_features,
139
+ t_pitch, t_yaw, tp_logits, ty_logits, t_features,
140
+ pitch_target, yaw_target):
141
+ task_pitch = self.pitch_l2cs(sp_logits, s_pitch, pitch_target)
142
+ task_yaw = self.yaw_l2cs(sy_logits, s_yaw, yaw_target)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  loss_task = task_pitch + task_yaw
144
 
145
+ loss_angular = self.alpha_angular * self.angular_loss(s_pitch, s_yaw, pitch_target, yaw_target)
146
+ loss_contrastive = self.alpha_contrastive * self.contrastive_loss(t_features.detach(), s_features)
147
+ loss_mmd = self.alpha_mmd * self.mmd_loss(t_features.detach(), s_features)
148
+ loss_logit = (self.alpha_logit * self.logit_loss(sp_logits, tp_logits.detach()) +
149
+ self.alpha_logit * self.logit_loss(sy_logits, ty_logits.detach()))
 
 
 
 
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  total_loss = loss_task + loss_angular + loss_contrastive + loss_mmd + loss_logit
152
 
153
  loss_dict = {
 
158
  'loss_mmd': loss_mmd.item(),
159
  'loss_logit': loss_logit.item(),
160
  }
161
+ return total_loss, loss_dict