"""
PriviGaze Distillation Loss - Privileged Knowledge Distillation for Gaze Estimation
Components:
1. Angular gaze loss (L1 on pitch/yaw in degrees)
2. L2CS-Net style binned classification + regression loss
3. Feature-level distillation (WCoRD-inspired contrastive + distribution matching)
4. Logit-level distillation (KL on soft targets from teacher)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class L2CSLoss(nn.Module):
    """L2CS-Net style combined classification + regression loss per angle.

    Loss = CrossEntropy(binned_logits, binned_target) + beta * MSE(continuous_pred, continuous_target)

    Args:
        gaze_bins: number of discrete angle bins spanning [-90, 90] degrees.
        beta: weight on the continuous MSE regression term.
    """
    def __init__(self, gaze_bins: int = 90, beta: float = 1.0):
        super().__init__()
        self.gaze_bins = gaze_bins
        self.beta = beta
        # Bin centers are evenly spaced over [-90, 90]; adjacent centers are
        # 180 / (gaze_bins - 1) degrees apart.
        self.register_buffer('bin_centers', torch.linspace(-90.0, 90.0, gaze_bins))
        self.ce_loss = nn.CrossEntropyLoss()

    def _angle_to_bins(self, angles: torch.Tensor) -> torch.Tensor:
        """Map continuous angles (degrees) to the index of the nearest bin center.

        Fix: round to the nearest bin instead of truncating. Truncation biased
        every target toward the lower bin by up to one full bin width, which is
        inconsistent with the centers stored in `bin_centers`.
        """
        angles_clamped = angles.clamp(-90.0, 90.0)
        bin_width = 180.0 / (self.gaze_bins - 1)
        bins = ((angles_clamped + 90.0) / bin_width).round().long()
        return bins.clamp(0, self.gaze_bins - 1)

    def forward(self, logits, continuous_pred, angle_target):
        """Return binned CE plus beta-weighted MSE on the continuous angles."""
        bin_targets = self._angle_to_bins(angle_target)
        ce = self.ce_loss(logits, bin_targets)
        mse = F.mse_loss(continuous_pred, angle_target)
        return ce + self.beta * mse
class AngularLoss(nn.Module):
    """Direct angular error loss in degrees (L1 on pitch and yaw)."""

    def __init__(self, reduction: str = 'mean'):
        super().__init__()
        # Forwarded verbatim to F.l1_loss ('mean', 'sum', or 'none').
        self.reduction = reduction

    def forward(self, pitch_pred, yaw_pred, pitch_target, yaw_target):
        """Return the sum of per-angle L1 errors for pitch and yaw."""
        pairs = ((pitch_pred, pitch_target), (yaw_pred, yaw_target))
        return sum(F.l1_loss(pred, tgt, reduction=self.reduction) for pred, tgt in pairs)
class ContrastiveDistillationLoss(nn.Module):
    """WCoRD-inspired contrastive feature distillation.

    Symmetric InfoNCE over a batch: the matching teacher/student pair (same
    row index) is the positive, every other row is a negative. Both feature
    sets are first mapped into a shared projection space and L2-normalized.
    """

    def __init__(self, teacher_dim: int = 256, student_dim: int = 128, proj_dim: int = 128, temperature: float = 0.1):
        super().__init__()
        self.teacher_proj = nn.Sequential(
            nn.Linear(teacher_dim, proj_dim), nn.GELU(), nn.Linear(proj_dim, proj_dim))
        self.student_proj = nn.Sequential(
            nn.Linear(student_dim, proj_dim), nn.GELU(), nn.Linear(proj_dim, proj_dim))
        self.temperature = temperature

    def forward(self, teacher_feat, student_feat):
        """Average of teacher->student and student->teacher InfoNCE losses."""
        teacher_emb = F.normalize(self.teacher_proj(teacher_feat), dim=-1)
        student_emb = F.normalize(self.student_proj(student_feat), dim=-1)
        # Cosine-similarity matrix scaled by temperature; diagonal = positives.
        sim = (teacher_emb @ student_emb.T) / self.temperature
        targets = torch.arange(sim.size(0), device=sim.device)
        return 0.5 * (F.cross_entropy(sim, targets) + F.cross_entropy(sim.T, targets))
class DistributionMatchingLoss(nn.Module):
    """MMD-based distribution matching between teacher and student features.

    Computes the (biased) squared Maximum Mean Discrepancy with an RBF kernel
    between the two L2-normalized feature batches. Supports different batch
    sizes for teacher and student, but requires equal feature widths.
    """

    def __init__(self, kernel: str = 'rbf'):
        super().__init__()
        # Only 'rbf' is implemented; stored for potential future kernels.
        self.kernel = kernel

    def _rbf_kernel(self, x, y, sigma=1.0):
        """Biased MMD^2 estimate: E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]."""
        xx = torch.matmul(x, x.T)
        yy = torch.matmul(y, y.T)
        xy = torch.matmul(x, y.T)
        rx = xx.diag().unsqueeze(0)  # (1, N) squared row norms of x
        ry = yy.diag().unsqueeze(0)  # (1, M) squared row norms of y
        k_xx = torch.exp(-(rx + rx.T - 2*xx) / (2*sigma**2))
        k_yy = torch.exp(-(ry + ry.T - 2*yy) / (2*sigma**2))
        # Fix: the cross-term distance is ||x_i||^2 + ||y_j||^2 - 2 x_i.y_j,
        # i.e. rx.T + ry with shape (N, M). The previous rx + ry.T paired the
        # wrong norms with each dot product and broadcast to (M, N), crashing
        # whenever N != M.
        k_xy = torch.exp(-(rx.T + ry - 2*xy) / (2*sigma**2))
        return k_xx.mean() + k_yy.mean() - 2*k_xy.mean()

    def forward(self, teacher_feat, student_feat):
        """Return MMD^2 between L2-normalized teacher and student batches."""
        t = F.normalize(teacher_feat, dim=-1)
        s = F.normalize(student_feat, dim=-1)
        return self._rbf_kernel(t, s)
class LogitDistillationLoss(nn.Module):
"""KL divergence distillation on soft gaze bin probabilities."""
def __init__(self, temperature: float = 3.0):
super().__init__()
self.temperature = temperature
def forward(self, student_logits, teacher_logits):
student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
teacher_soft = F.softmax(teacher_logits / self.temperature, dim=-1)
return F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (self.temperature**2)
class PriviGazeDistillationLoss(nn.Module):
    """Complete privileged distillation loss for gaze estimation.

    L_total = L_task + alpha_angular * L_angular
            + alpha_contrastive * L_contrastive
            + alpha_mmd * L_mmd + alpha_logit * L_logit

    The teacher's features and logits are always detached: the teacher only
    supplies targets and receives no gradient through this loss.
    """

    def __init__(self, gaze_bins=90, teacher_feature_dim=256, student_feature_dim=128,
                 alpha_angular=1.0, alpha_contrastive=0.5, alpha_mmd=0.1, alpha_logit=0.5):
        super().__init__()
        self.angular_loss = AngularLoss()
        self.pitch_l2cs = L2CSLoss(gaze_bins)
        self.yaw_l2cs = L2CSLoss(gaze_bins)
        self.contrastive_loss = ContrastiveDistillationLoss(teacher_feature_dim, student_feature_dim)
        self.mmd_loss = DistributionMatchingLoss()
        self.logit_loss = LogitDistillationLoss()
        self.alpha_angular = alpha_angular
        self.alpha_contrastive = alpha_contrastive
        self.alpha_mmd = alpha_mmd
        self.alpha_logit = alpha_logit

    def forward(self, s_pitch, s_yaw, sp_logits, sy_logits, s_features,
                t_pitch, t_yaw, tp_logits, ty_logits, t_features,
                pitch_target, yaw_target):
        """Combine task, angular, contrastive, MMD and logit losses.

        Returns:
            (total_loss, loss_dict): a differentiable scalar and a dict of
            python floats for each (already alpha-weighted) component.
        """
        t_features = t_features.detach()  # teacher is a fixed target provider
        # Task loss: L2CS (binned CE + continuous MSE) for each angle head.
        task_pitch = self.pitch_l2cs(sp_logits, s_pitch, pitch_target)
        task_yaw = self.yaw_l2cs(sy_logits, s_yaw, yaw_target)
        loss_task = task_pitch + task_yaw
        # Direct L1 supervision on the continuous angles (degrees).
        loss_angular = self.alpha_angular * self.angular_loss(s_pitch, s_yaw, pitch_target, yaw_target)
        # Feature-level contrastive alignment of paired teacher/student rows.
        loss_contrastive = self.alpha_contrastive * self.contrastive_loss(t_features, s_features)
        # Fix: teacher (teacher_feature_dim) and student (student_feature_dim)
        # features generally differ in width, so MMD on the raw features would
        # fail inside the kernel matmul (e.g. 256 vs 128 with the defaults).
        # Match distributions in the shared projection space instead.
        t_proj = self.contrastive_loss.teacher_proj(t_features)
        s_proj = self.contrastive_loss.student_proj(s_features)
        loss_mmd = self.alpha_mmd * self.mmd_loss(t_proj, s_proj)
        # Logit-level distillation on both angle heads.
        loss_logit = self.alpha_logit * (self.logit_loss(sp_logits, tp_logits.detach()) +
                                         self.logit_loss(sy_logits, ty_logits.detach()))
        total_loss = loss_task + loss_angular + loss_contrastive + loss_mmd + loss_logit
        loss_dict = {
            'loss_total': total_loss.item(),
            'loss_task': loss_task.item(),
            'loss_angular': loss_angular.item(),
            'loss_contrastive': loss_contrastive.item(),
            'loss_mmd': loss_mmd.item(),
            'loss_logit': loss_logit.item(),
        }
        return total_loss, loss_dict