""" PriviGaze Distillation Loss - Privileged Knowledge Distillation for Gaze Estimation Components: 1. Angular gaze loss (L1 on pitch/yaw in degrees) 2. L2CS-Net style binned classification + regression loss 3. Feature-level distillation (WCoRD-inspired contrastive + distribution matching) 4. Logit-level distillation (KL on soft targets from teacher) """ import torch import torch.nn as nn import torch.nn.functional as F class L2CSLoss(nn.Module): """L2CS-Net style combined classification + regression loss per angle. Loss = CrossEntropy(binned_logits, binned_target) + beta * MSE(continuous_pred, continuous_target) """ def __init__(self, gaze_bins: int = 90, beta: float = 1.0): super().__init__() self.gaze_bins = gaze_bins self.beta = beta self.register_buffer('bin_centers', torch.linspace(-90.0, 90.0, gaze_bins)) self.ce_loss = nn.CrossEntropyLoss() def _angle_to_bins(self, angles): angles_clamped = angles.clamp(-90.0, 90.0) bin_width = 180.0 / (self.gaze_bins - 1) bins = ((angles_clamped + 90.0) / bin_width).long() return bins.clamp(0, self.gaze_bins - 1) def forward(self, logits, continuous_pred, angle_target): bin_targets = self._angle_to_bins(angle_target) ce = self.ce_loss(logits, bin_targets) mse = F.mse_loss(continuous_pred, angle_target) return ce + self.beta * mse class AngularLoss(nn.Module): """Direct angular error loss in degrees (L1 on pitch and yaw).""" def __init__(self, reduction: str = 'mean'): super().__init__() self.reduction = reduction def forward(self, pitch_pred, yaw_pred, pitch_target, yaw_target): pitch_loss = F.l1_loss(pitch_pred, pitch_target, reduction=self.reduction) yaw_loss = F.l1_loss(yaw_pred, yaw_target, reduction=self.reduction) return pitch_loss + yaw_loss class ContrastiveDistillationLoss(nn.Module): """WCoRD-inspired contrastive feature distillation. InfoNCE loss maximizing mutual information between teacher and student features. Projects both to a shared space before computing similarity. """ def __init__(self, teacher_dim: int = 256, student_dim: int = 128, proj_dim: int = 128, temperature: float = 0.1): super().__init__() self.teacher_proj = nn.Sequential( nn.Linear(teacher_dim, proj_dim), nn.GELU(), nn.Linear(proj_dim, proj_dim)) self.student_proj = nn.Sequential( nn.Linear(student_dim, proj_dim), nn.GELU(), nn.Linear(proj_dim, proj_dim)) self.temperature = temperature def forward(self, teacher_feat, student_feat): t = F.normalize(self.teacher_proj(teacher_feat), dim=-1) s = F.normalize(self.student_proj(student_feat), dim=-1) logits = torch.matmul(t, s.T) / self.temperature labels = torch.arange(logits.shape[0], device=logits.device) loss_t2s = F.cross_entropy(logits, labels) loss_s2t = F.cross_entropy(logits.T, labels) return (loss_t2s + loss_s2t) / 2.0 class DistributionMatchingLoss(nn.Module): """MMD-based distribution matching between teacher and student features.""" def __init__(self, kernel: str = 'rbf'): super().__init__() self.kernel = kernel def _rbf_kernel(self, x, y, sigma=1.0): xx = torch.matmul(x, x.T) yy = torch.matmul(y, y.T) xy = torch.matmul(x, y.T) rx = xx.diag().unsqueeze(0) ry = yy.diag().unsqueeze(0) k_xx = torch.exp(-(rx + rx.T - 2*xx) / (2*sigma**2)) k_yy = torch.exp(-(ry + ry.T - 2*yy) / (2*sigma**2)) k_xy = torch.exp(-(rx + ry.T - 2*xy) / (2*sigma**2)) return k_xx.mean() + k_yy.mean() - 2*k_xy.mean() def forward(self, teacher_feat, student_feat): t = F.normalize(teacher_feat, dim=-1) s = F.normalize(student_feat, dim=-1) return self._rbf_kernel(t, s) class LogitDistillationLoss(nn.Module): """KL divergence distillation on soft gaze bin probabilities.""" def __init__(self, temperature: float = 3.0): super().__init__() self.temperature = temperature def forward(self, student_logits, teacher_logits): student_soft = F.log_softmax(student_logits / self.temperature, dim=-1) teacher_soft = F.softmax(teacher_logits / self.temperature, dim=-1) return F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (self.temperature**2) class PriviGazeDistillationLoss(nn.Module): """Complete privileged distillation loss for gaze estimation. L_total = L_task + α_angular·L_angular + α_contrastive·L_contrastive + α_mmd·L_mmd + α_logit·L_logit """ def __init__(self, gaze_bins=90, teacher_feature_dim=256, student_feature_dim=128, alpha_angular=1.0, alpha_contrastive=0.5, alpha_mmd=0.1, alpha_logit=0.5): super().__init__() self.angular_loss = AngularLoss() self.pitch_l2cs = L2CSLoss(gaze_bins) self.yaw_l2cs = L2CSLoss(gaze_bins) self.contrastive_loss = ContrastiveDistillationLoss(teacher_feature_dim, student_feature_dim) self.mmd_loss = DistributionMatchingLoss() self.logit_loss = LogitDistillationLoss() self.alpha_angular = alpha_angular self.alpha_contrastive = alpha_contrastive self.alpha_mmd = alpha_mmd self.alpha_logit = alpha_logit def forward(self, s_pitch, s_yaw, sp_logits, sy_logits, s_features, t_pitch, t_yaw, tp_logits, ty_logits, t_features, pitch_target, yaw_target): task_pitch = self.pitch_l2cs(sp_logits, s_pitch, pitch_target) task_yaw = self.yaw_l2cs(sy_logits, s_yaw, yaw_target) loss_task = task_pitch + task_yaw loss_angular = self.alpha_angular * self.angular_loss(s_pitch, s_yaw, pitch_target, yaw_target) loss_contrastive = self.alpha_contrastive * self.contrastive_loss(t_features.detach(), s_features) loss_mmd = self.alpha_mmd * self.mmd_loss(t_features.detach(), s_features) loss_logit = (self.alpha_logit * self.logit_loss(sp_logits, tp_logits.detach()) + self.alpha_logit * self.logit_loss(sy_logits, ty_logits.detach())) total_loss = loss_task + loss_angular + loss_contrastive + loss_mmd + loss_logit loss_dict = { 'loss_total': total_loss.item(), 'loss_task': loss_task.item(), 'loss_angular': loss_angular.item(), 'loss_contrastive': loss_contrastive.item(), 'loss_mmd': loss_mmd.item(), 'loss_logit': loss_logit.item(), } return total_loss, loss_dict