# privi-gaze-distill / models / distillation_loss.py
"""
PriviGaze Distillation Loss - Privileged Knowledge Distillation for Gaze Estimation
Components:
1. Angular gaze loss (L1 on pitch/yaw in degrees)
2. L2CS-Net style binned classification + regression loss
3. Feature-level distillation (WCoRD-inspired contrastive + distribution matching)
4. Logit-level distillation (KL on soft targets from teacher)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class L2CSLoss(nn.Module):
    """L2CS-Net style combined classification + regression loss for one gaze angle.

    Loss = CrossEntropy(binned_logits, binned_target) + beta * MSE(continuous_pred, continuous_target)
    """
    def __init__(self, gaze_bins: int = 90, beta: float = 1.0):
        super().__init__()
        self.gaze_bins = gaze_bins
        self.beta = beta
        # Bin centers kept as a buffer so they move with .to(device) / state_dict.
        self.register_buffer('bin_centers', torch.linspace(-90.0, 90.0, gaze_bins))
        self.ce_loss = nn.CrossEntropyLoss()

    def _angle_to_bins(self, angles):
        # Map continuous degrees in [-90, 90] to integer bin indices [0, gaze_bins).
        width = 180.0 / (self.gaze_bins - 1)
        indices = ((angles.clamp(-90.0, 90.0) + 90.0) / width).long()
        return indices.clamp(0, self.gaze_bins - 1)

    def forward(self, logits, continuous_pred, angle_target):
        """Return CE on the binned target plus beta-weighted MSE on the raw angle."""
        cls_term = self.ce_loss(logits, self._angle_to_bins(angle_target))
        reg_term = F.mse_loss(continuous_pred, angle_target)
        return cls_term + self.beta * reg_term
class AngularLoss(nn.Module):
    """Direct angular error in degrees: L1 on pitch plus L1 on yaw."""
    def __init__(self, reduction: str = 'mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, pitch_pred, yaw_pred, pitch_target, yaw_target):
        """Return L1(pitch) + L1(yaw), each reduced with the configured mode."""
        pitch_err = F.l1_loss(pitch_pred, pitch_target, reduction=self.reduction)
        return pitch_err + F.l1_loss(yaw_pred, yaw_target, reduction=self.reduction)
class ContrastiveDistillationLoss(nn.Module):
    """WCoRD-inspired contrastive feature distillation.

    Symmetric InfoNCE over a batch: matching (teacher_i, student_i) pairs are
    positives, all other in-batch pairs are negatives. Both feature sets are
    projected into a shared space before computing cosine similarities.
    """
    def __init__(self, teacher_dim: int = 256, student_dim: int = 128, proj_dim: int = 128, temperature: float = 0.1):
        super().__init__()
        # NOTE: projection heads are built teacher-first, student-second, matching
        # the original construction order (keeps seeded initialization identical).
        self.teacher_proj = nn.Sequential(
            nn.Linear(teacher_dim, proj_dim), nn.GELU(), nn.Linear(proj_dim, proj_dim))
        self.student_proj = nn.Sequential(
            nn.Linear(student_dim, proj_dim), nn.GELU(), nn.Linear(proj_dim, proj_dim))
        self.temperature = temperature

    def forward(self, teacher_feat, student_feat):
        """Return the symmetric (teacher→student + student→teacher) InfoNCE loss."""
        proj_t = F.normalize(self.teacher_proj(teacher_feat), dim=-1)
        proj_s = F.normalize(self.student_proj(student_feat), dim=-1)
        sim = torch.matmul(proj_t, proj_s.T) / self.temperature
        # Diagonal entries are the positive pairs.
        targets = torch.arange(sim.shape[0], device=sim.device)
        forward_nce = F.cross_entropy(sim, targets)
        backward_nce = F.cross_entropy(sim.T, targets)
        return (forward_nce + backward_nce) / 2.0
class DistributionMatchingLoss(nn.Module):
    """MMD-based distribution matching between teacher and student features.

    Computes a biased MMD^2 estimate with an RBF kernel over L2-normalized
    feature batches. Teacher and student batches may have different sizes.
    """
    def __init__(self, kernel: str = 'rbf'):
        super().__init__()
        self.kernel = kernel  # only 'rbf' is implemented; kept for interface parity

    def _rbf_kernel(self, x, y, sigma=1.0):
        """Return MMD^2(x, y) = E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)] for an RBF kernel.

        x: (N, D), y: (M, D); sigma is the RBF bandwidth.
        """
        xx = torch.matmul(x, x.T)
        yy = torch.matmul(y, y.T)
        xy = torch.matmul(x, y.T)
        rx = xx.diag().unsqueeze(0)  # (1, N) squared norms of x rows
        ry = yy.diag().unsqueeze(0)  # (1, M) squared norms of y rows
        denom = 2 * sigma ** 2
        k_xx = torch.exp(-(rx + rx.T - 2 * xx) / denom)
        k_yy = torch.exp(-(ry + ry.T - 2 * yy) / denom)
        # BUG FIX: the cross distance is ||x_i||^2 + ||y_j||^2 - 2 x_i.y_j,
        # i.e. rx.T + ry with shape (N, M). The original used rx + ry.T, which
        # transposes the norm terms (wrong values for non-unit-norm inputs and a
        # shape error whenever N != M). Forward output is unchanged because it
        # normalizes inputs first (all norms == 1).
        k_xy = torch.exp(-(rx.T + ry - 2 * xy) / denom)
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def forward(self, teacher_feat, student_feat):
        """Return MMD^2 between L2-normalized teacher and student feature batches."""
        t = F.normalize(teacher_feat, dim=-1)
        s = F.normalize(student_feat, dim=-1)
        return self._rbf_kernel(t, s)
class LogitDistillationLoss(nn.Module):
    """Temperature-scaled KL divergence between student and teacher bin logits."""
    def __init__(self, temperature: float = 3.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, student_logits, teacher_logits):
        """Return KL(teacher || student) on temperature-softened distributions.

        Scaled by T^2 so gradient magnitudes stay comparable across temperatures
        (standard Hinton-style distillation).
        """
        tau = self.temperature
        log_p_student = F.log_softmax(student_logits / tau, dim=-1)
        p_teacher = F.softmax(teacher_logits / tau, dim=-1)
        kd = F.kl_div(log_p_student, p_teacher, reduction='batchmean')
        return kd * (tau * tau)
class PriviGazeDistillationLoss(nn.Module):
    """Complete privileged distillation loss for gaze estimation.

    L_total = L_task + alpha_angular * L_angular + alpha_contrastive * L_contrastive
              + alpha_mmd * L_mmd + alpha_logit * L_logit
    """
    def __init__(self, gaze_bins=90, teacher_feature_dim=256, student_feature_dim=128,
                 alpha_angular=1.0, alpha_contrastive=0.5, alpha_mmd=0.1, alpha_logit=0.5):
        super().__init__()
        # Criteria are built in the same order as before so any seeded parameter
        # initialization (the contrastive projection heads) stays identical.
        self.angular_loss = AngularLoss()
        self.pitch_l2cs = L2CSLoss(gaze_bins)
        self.yaw_l2cs = L2CSLoss(gaze_bins)
        self.contrastive_loss = ContrastiveDistillationLoss(teacher_feature_dim, student_feature_dim)
        self.mmd_loss = DistributionMatchingLoss()
        self.logit_loss = LogitDistillationLoss()
        self.alpha_angular = alpha_angular
        self.alpha_contrastive = alpha_contrastive
        self.alpha_mmd = alpha_mmd
        self.alpha_logit = alpha_logit

    def forward(self, s_pitch, s_yaw, sp_logits, sy_logits, s_features,
                t_pitch, t_yaw, tp_logits, ty_logits, t_features,
                pitch_target, yaw_target):
        """Combine supervised and distillation terms.

        Returns (total_loss, dict of per-term scalar floats for logging).
        Teacher outputs are detached everywhere: no gradient flows to the teacher.
        """
        # Supervised task loss: L2CS (CE + MSE) on each angle against ground truth.
        loss_task = (self.pitch_l2cs(sp_logits, s_pitch, pitch_target)
                     + self.yaw_l2cs(sy_logits, s_yaw, yaw_target))
        loss_angular = self.alpha_angular * self.angular_loss(
            s_pitch, s_yaw, pitch_target, yaw_target)
        # Feature-level distillation against the frozen teacher representation.
        teacher_feat = t_features.detach()
        loss_contrastive = self.alpha_contrastive * self.contrastive_loss(teacher_feat, s_features)
        loss_mmd = self.alpha_mmd * self.mmd_loss(teacher_feat, s_features)
        # Logit-level distillation on both angle heads.
        loss_logit = (self.alpha_logit * self.logit_loss(sp_logits, tp_logits.detach())
                      + self.alpha_logit * self.logit_loss(sy_logits, ty_logits.detach()))
        total_loss = loss_task + loss_angular + loss_contrastive + loss_mmd + loss_logit
        loss_dict = {
            'loss_total': total_loss.item(),
            'loss_task': loss_task.item(),
            'loss_angular': loss_angular.item(),
            'loss_contrastive': loss_contrastive.item(),
            'loss_mmd': loss_mmd.item(),
            'loss_logit': loss_logit.item(),
        }
        return total_loss, loss_dict