| """ |
| PriviGaze Distillation Loss - Privileged Knowledge Distillation for Gaze Estimation |
| |
| Components: |
| 1. Angular gaze loss (L1 on pitch/yaw in degrees) |
| 2. L2CS-Net style binned classification + regression loss |
| 3. Feature-level distillation (WCoRD-inspired contrastive + distribution matching) |
| 4. Logit-level distillation (KL on soft targets from teacher) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class L2CSLoss(nn.Module):
    """L2CS-Net style combined classification + regression loss per angle.

    Loss = CrossEntropy(binned_logits, binned_target) + beta * MSE(continuous_pred, continuous_target)

    Args:
        gaze_bins: number of discrete angle bins covering [-90, 90] degrees.
        beta: weight on the continuous MSE regression term.
    """

    def __init__(self, gaze_bins: int = 90, beta: float = 1.0):
        super().__init__()
        self.gaze_bins = gaze_bins
        self.beta = beta
        # Bin centers spaced 180/(gaze_bins-1) degrees apart over [-90, 90].
        # Registered as a buffer so callers can use it to decode a soft-argmax
        # over the bin logits back into degrees.
        self.register_buffer('bin_centers', torch.linspace(-90.0, 90.0, gaze_bins))
        self.ce_loss = nn.CrossEntropyLoss()

    def _angle_to_bins(self, angles):
        """Map continuous angles (degrees) to the index of the nearest bin center.

        Fix: truncation via ``.long()`` assigned each angle to the bin center
        at-or-below it, a systematic half-bin-width bias relative to the
        ``bin_centers`` buffer; ``torch.round`` picks the nearest center.
        """
        angles_clamped = angles.clamp(-90.0, 90.0)
        bin_width = 180.0 / (self.gaze_bins - 1)
        bins = torch.round((angles_clamped + 90.0) / bin_width).long()
        return bins.clamp(0, self.gaze_bins - 1)

    def forward(self, logits, continuous_pred, angle_target):
        """Return CE on binned targets plus beta-weighted MSE on the raw angle."""
        bin_targets = self._angle_to_bins(angle_target)
        ce = self.ce_loss(logits, bin_targets)
        mse = F.mse_loss(continuous_pred, angle_target)
        return ce + self.beta * mse
|
|
|
|
class AngularLoss(nn.Module):
    """Sum of per-angle L1 errors (in degrees) over pitch and yaw.

    Args:
        reduction: reduction mode forwarded to ``F.l1_loss`` for each angle.
    """

    def __init__(self, reduction: str = 'mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, pitch_pred, yaw_pred, pitch_target, yaw_target):
        """Return L1(pitch) + L1(yaw), each reduced per ``self.reduction``."""
        total = F.l1_loss(pitch_pred, pitch_target, reduction=self.reduction)
        total = total + F.l1_loss(yaw_pred, yaw_target, reduction=self.reduction)
        return total
|
|
|
|
class ContrastiveDistillationLoss(nn.Module):
    """WCoRD-inspired contrastive feature distillation.

    Symmetric InfoNCE between projected teacher and student features:
    the matching (teacher_i, student_i) pair in a batch is the positive,
    every other pairing serves as a negative. Both feature sets are
    projected into a shared space and L2-normalized before comparison.
    """

    def __init__(self, teacher_dim: int = 256, student_dim: int = 128, proj_dim: int = 128, temperature: float = 0.1):
        super().__init__()
        self.teacher_proj = nn.Sequential(
            nn.Linear(teacher_dim, proj_dim), nn.GELU(), nn.Linear(proj_dim, proj_dim))
        self.student_proj = nn.Sequential(
            nn.Linear(student_dim, proj_dim), nn.GELU(), nn.Linear(proj_dim, proj_dim))
        self.temperature = temperature

    def forward(self, teacher_feat, student_feat):
        """Return the averaged two-directional InfoNCE loss for the batch."""
        # Project both feature sets onto the shared unit hypersphere.
        t_emb = F.normalize(self.teacher_proj(teacher_feat), dim=-1)
        s_emb = F.normalize(self.student_proj(student_feat), dim=-1)
        # Pairwise cosine similarities, sharpened by the temperature.
        sim = torch.matmul(t_emb, s_emb.T) / self.temperature
        # Diagonal entries are the positive pairs.
        targets = torch.arange(sim.shape[0], device=sim.device)
        loss_t2s = F.cross_entropy(sim, targets)
        loss_s2t = F.cross_entropy(sim.T, targets)
        return (loss_t2s + loss_s2t) / 2.0
|
|
|
|
class DistributionMatchingLoss(nn.Module):
    """MMD-based distribution matching between teacher and student features.

    Computes the (biased) empirical MMD^2 estimate under an RBF kernel
    between the L2-normalized teacher and student feature batches.
    """

    def __init__(self, kernel: str = 'rbf'):
        super().__init__()
        # Only the RBF kernel is implemented; the argument is kept for
        # interface compatibility / future extension.
        self.kernel = kernel

    def _rbf_kernel(self, x, y, sigma=1.0):
        """Return MMD^2 between samples x of shape (n, d) and y of shape (m, d)."""
        xx = torch.matmul(x, x.T)
        yy = torch.matmul(y, y.T)
        xy = torch.matmul(x, y.T)
        rx = xx.diag().unsqueeze(0)  # (1, n) squared norms of x
        ry = yy.diag().unsqueeze(0)  # (1, m) squared norms of y
        k_xx = torch.exp(-(rx + rx.T - 2*xx) / (2*sigma**2))
        k_yy = torch.exp(-(ry + ry.T - 2*yy) / (2*sigma**2))
        # Fix: the cross-term squared distance is ||x_i||^2 + ||y_j||^2 - 2 x_i.y_j,
        # i.e. rx.T + ry. The original `rx + ry.T` broadcasts (1,n)+(m,1) into an
        # (m,n) matrix that mismatches xy's (n,m) shape (crashing for n != m) and
        # pairs the wrong norms in general; it was only masked in forward()
        # because normalized features all have unit norm.
        k_xy = torch.exp(-(rx.T + ry - 2*xy) / (2*sigma**2))
        return k_xx.mean() + k_yy.mean() - 2*k_xy.mean()

    def forward(self, teacher_feat, student_feat):
        """Return MMD^2 between L2-normalized teacher and student features."""
        t = F.normalize(teacher_feat, dim=-1)
        s = F.normalize(student_feat, dim=-1)
        return self._rbf_kernel(t, s)
|
|
|
|
class LogitDistillationLoss(nn.Module):
    """Temperature-scaled KL-divergence distillation on gaze-bin logits.

    The student's softened log-probabilities are matched to the teacher's
    softened probabilities; multiplying by T^2 restores the gradient
    magnitude lost to temperature scaling (Hinton-style distillation).
    """

    def __init__(self, temperature: float = 3.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, student_logits, teacher_logits):
        """Return T^2-scaled batchmean KL(teacher || student) on softened bins."""
        tau = self.temperature
        log_p_student = F.log_softmax(student_logits / tau, dim=-1)
        p_teacher = F.softmax(teacher_logits / tau, dim=-1)
        kl = F.kl_div(log_p_student, p_teacher, reduction='batchmean')
        return kl * (tau**2)
|
|
|
|
class PriviGazeDistillationLoss(nn.Module):
    """Complete privileged distillation loss for gaze estimation.

    L_total = L_task + alpha_angular * L_angular
            + alpha_contrastive * L_contrastive
            + alpha_mmd * L_mmd
            + alpha_logit * L_logit
    """

    def __init__(self, gaze_bins=90, teacher_feature_dim=256, student_feature_dim=128,
                 alpha_angular=1.0, alpha_contrastive=0.5, alpha_mmd=0.1, alpha_logit=0.5):
        super().__init__()
        # Supervised task losses on the student's pitch/yaw heads.
        self.angular_loss = AngularLoss()
        self.pitch_l2cs = L2CSLoss(gaze_bins)
        self.yaw_l2cs = L2CSLoss(gaze_bins)
        # Distillation losses; the teacher is treated as fixed (detached).
        self.contrastive_loss = ContrastiveDistillationLoss(teacher_feature_dim, student_feature_dim)
        self.mmd_loss = DistributionMatchingLoss()
        self.logit_loss = LogitDistillationLoss()
        # Per-term weights.
        self.alpha_angular = alpha_angular
        self.alpha_contrastive = alpha_contrastive
        self.alpha_mmd = alpha_mmd
        self.alpha_logit = alpha_logit

    def forward(self, s_pitch, s_yaw, sp_logits, sy_logits, s_features,
                t_pitch, t_yaw, tp_logits, ty_logits, t_features,
                pitch_target, yaw_target):
        """Return (total_loss, dict of per-term python floats for logging).

        t_pitch / t_yaw are accepted for API symmetry with the student
        arguments but are not used by any term.
        """
        # Binned-CE + MSE task loss on both angle heads.
        loss_task = (self.pitch_l2cs(sp_logits, s_pitch, pitch_target)
                     + self.yaw_l2cs(sy_logits, s_yaw, yaw_target))

        # Weighted auxiliary terms; teacher tensors are detached so no
        # gradient flows back into the teacher.
        loss_angular = self.alpha_angular * self.angular_loss(s_pitch, s_yaw, pitch_target, yaw_target)
        loss_contrastive = self.alpha_contrastive * self.contrastive_loss(t_features.detach(), s_features)
        loss_mmd = self.alpha_mmd * self.mmd_loss(t_features.detach(), s_features)
        loss_logit = (self.alpha_logit * self.logit_loss(sp_logits, tp_logits.detach())
                      + self.alpha_logit * self.logit_loss(sy_logits, ty_logits.detach()))

        total_loss = loss_task + loss_angular + loss_contrastive + loss_mmd + loss_logit

        return total_loss, {
            'loss_total': total_loss.item(),
            'loss_task': loss_task.item(),
            'loss_angular': loss_angular.item(),
            'loss_contrastive': loss_contrastive.item(),
            'loss_mmd': loss_mmd.item(),
            'loss_logit': loss_logit.item(),
        }