File size: 6,805 Bytes
94cb2c0
 
 
fdc4b3d
94cb2c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc4b3d
94cb2c0
 
fdc4b3d
94cb2c0
 
 
 
 
 
 
 
 
 
 
 
 
fdc4b3d
94cb2c0
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc4b3d
 
94cb2c0
 
fdc4b3d
94cb2c0
 
fdc4b3d
94cb2c0
fdc4b3d
94cb2c0
 
fdc4b3d
 
 
 
94cb2c0
 
 
 
 
 
 
fdc4b3d
94cb2c0
 
 
 
 
fdc4b3d
94cb2c0
 
 
 
 
fdc4b3d
 
 
 
94cb2c0
fdc4b3d
94cb2c0
 
 
 
 
 
fdc4b3d
94cb2c0
 
 
 
 
 
 
 
fdc4b3d
94cb2c0
 
 
 
 
fdc4b3d
 
94cb2c0
 
fdc4b3d
 
94cb2c0
 
 
 
fdc4b3d
94cb2c0
 
 
 
 
 
 
fdc4b3d
 
 
 
 
94cb2c0
 
fdc4b3d
 
 
 
 
94cb2c0
 
 
 
 
 
 
 
 
 
 
fdc4b3d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""
PriviGaze Distillation Loss - Privileged Knowledge Distillation for Gaze Estimation

Components:
1. Angular gaze loss (L1 on pitch/yaw in degrees)
2. L2CS-Net style binned classification + regression loss
3. Feature-level distillation (WCoRD-inspired contrastive + distribution matching)
4. Logit-level distillation (KL on soft targets from teacher)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class L2CSLoss(nn.Module):
    """L2CS-Net style combined classification + regression loss per angle.

    Loss = CrossEntropy(binned_logits, binned_target) + beta * MSE(continuous_pred, continuous_target)

    Bin i is centered at ``bin_centers[i] = -90 + i * 180 / (gaze_bins - 1)``
    degrees. Targets are assigned to the *nearest* bin center so the encoding
    is consistent with decoding by expectation over ``bin_centers``.
    """

    def __init__(self, gaze_bins: int = 90, beta: float = 1.0):
        super().__init__()
        self.gaze_bins = gaze_bins
        self.beta = beta  # weight of the regression (MSE) term
        # Degree value at the center of each classification bin.
        self.register_buffer('bin_centers', torch.linspace(-90.0, 90.0, gaze_bins))
        self.ce_loss = nn.CrossEntropyLoss()

    def _angle_to_bins(self, angles):
        """Map angles (degrees) to the index of the nearest bin center."""
        angles_clamped = angles.clamp(-90.0, 90.0)
        bin_width = 180.0 / (self.gaze_bins - 1)
        # BUG FIX: round to the nearest bin center instead of flooring.
        # Flooring (`.long()` truncation) biased every target label downward
        # by up to a full bin width relative to the `bin_centers` grid.
        bins = torch.round((angles_clamped + 90.0) / bin_width).long()
        return bins.clamp(0, self.gaze_bins - 1)

    def forward(self, logits, continuous_pred, angle_target):
        """Return CE on binned targets plus beta-weighted MSE on raw angles.

        Args:
            logits: (B, gaze_bins) unnormalized bin scores.
            continuous_pred: (B,) predicted angle in degrees.
            angle_target: (B,) ground-truth angle in degrees.
        """
        bin_targets = self._angle_to_bins(angle_target)
        ce = self.ce_loss(logits, bin_targets)
        mse = F.mse_loss(continuous_pred, angle_target)
        return ce + self.beta * mse


class AngularLoss(nn.Module):
    """Direct angular error in degrees: L1 on pitch plus L1 on yaw."""

    def __init__(self, reduction: str = 'mean'):
        super().__init__()
        # Passed straight through to F.l1_loss for each angle term.
        self.reduction = reduction

    def forward(self, pitch_pred, yaw_pred, pitch_target, yaw_target):
        """Return l1(pitch) + l1(yaw) under the configured reduction."""
        pairs = ((pitch_pred, pitch_target), (yaw_pred, yaw_target))
        return sum(F.l1_loss(pred, target, reduction=self.reduction)
                   for pred, target in pairs)


class ContrastiveDistillationLoss(nn.Module):
    """WCoRD-inspired contrastive feature distillation.

    Symmetric InfoNCE objective between teacher and student features after
    both are projected into a shared embedding space; matching pairs in a
    batch are positives, all other pairs are negatives.
    """

    def __init__(self, teacher_dim: int = 256, student_dim: int = 128, proj_dim: int = 128, temperature: float = 0.1):
        super().__init__()
        self.teacher_proj = nn.Sequential(
            nn.Linear(teacher_dim, proj_dim), nn.GELU(), nn.Linear(proj_dim, proj_dim))
        self.student_proj = nn.Sequential(
            nn.Linear(student_dim, proj_dim), nn.GELU(), nn.Linear(proj_dim, proj_dim))
        self.temperature = temperature

    def forward(self, teacher_feat, student_feat):
        # Project, then unit-normalize so the dot products are cosine similarities.
        teacher_emb = F.normalize(self.teacher_proj(teacher_feat), dim=-1)
        student_emb = F.normalize(self.student_proj(student_feat), dim=-1)
        sim = torch.matmul(teacher_emb, student_emb.T) / self.temperature
        # Matching teacher/student pairs sit on the diagonal.
        targets = torch.arange(sim.shape[0], device=sim.device)
        loss_teacher_to_student = F.cross_entropy(sim, targets)
        loss_student_to_teacher = F.cross_entropy(sim.T, targets)
        return (loss_teacher_to_student + loss_student_to_teacher) / 2.0


class DistributionMatchingLoss(nn.Module):
    """MMD-based distribution matching between teacher and student features.

    Returns the (biased) squared Maximum Mean Discrepancy with an RBF kernel
    between the two L2-normalized feature batches.
    """

    def __init__(self, kernel: str = 'rbf'):
        super().__init__()
        self.kernel = kernel  # kernel family tag; only 'rbf' is implemented

    def _rbf_kernel(self, x, y, sigma=1.0):
        """Biased MMD^2 estimate between samples x (n, d) and y (m, d).

        Squared distances are expanded as ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
        using Gram matrices.
        """
        xx = torch.matmul(x, x.T)
        yy = torch.matmul(y, y.T)
        xy = torch.matmul(x, y.T)
        rx = xx.diag().unsqueeze(0)  # (1, n) squared norms of x rows
        ry = yy.diag().unsqueeze(0)  # (1, m) squared norms of y rows
        k_xx = torch.exp(-(rx + rx.T - 2*xx) / (2*sigma**2))
        k_yy = torch.exp(-(ry + ry.T - 2*yy) / (2*sigma**2))
        # BUG FIX: the cross term must pair ||x_i||^2 along rows with ||y_j||^2
        # along columns to match xy[i, j] = x_i . y_j. The previous
        # `rx + ry.T - 2*xy` had both transposed; it only appeared to work
        # because forward() feeds unit-normalized features (all norms equal 1),
        # and it broadcast to the wrong shape whenever n != m.
        k_xy = torch.exp(-(rx.T + ry - 2*xy) / (2*sigma**2))
        return k_xx.mean() + k_yy.mean() - 2*k_xy.mean()

    def forward(self, teacher_feat, student_feat):
        # Normalize both sets so the MMD compares directions, not magnitudes.
        t = F.normalize(teacher_feat, dim=-1)
        s = F.normalize(student_feat, dim=-1)
        return self._rbf_kernel(t, s)


class LogitDistillationLoss(nn.Module):
    """KL-divergence distillation on temperature-softened gaze bin logits."""

    def __init__(self, temperature: float = 3.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, student_logits, teacher_logits):
        tau = self.temperature
        log_p_student = F.log_softmax(student_logits / tau, dim=-1)
        p_teacher = F.softmax(teacher_logits / tau, dim=-1)
        kl = F.kl_div(log_p_student, p_teacher, reduction='batchmean')
        # Scale by tau^2 so gradient magnitudes stay comparable across temperatures.
        return kl * tau**2


class PriviGazeDistillationLoss(nn.Module):
    """Complete privileged distillation loss for gaze estimation.

    L_total = L_task + alpha_angular * L_angular + alpha_contrastive * L_contrastive
            + alpha_mmd * L_mmd + alpha_logit * L_logit
    """

    def __init__(self, gaze_bins=90, teacher_feature_dim=256, student_feature_dim=128,
                 alpha_angular=1.0, alpha_contrastive=0.5, alpha_mmd=0.1, alpha_logit=0.5):
        super().__init__()
        self.angular_loss = AngularLoss()
        self.pitch_l2cs = L2CSLoss(gaze_bins)
        self.yaw_l2cs = L2CSLoss(gaze_bins)
        self.contrastive_loss = ContrastiveDistillationLoss(teacher_feature_dim, student_feature_dim)
        self.mmd_loss = DistributionMatchingLoss()
        self.logit_loss = LogitDistillationLoss()
        # Weights for each auxiliary distillation term.
        self.alpha_angular = alpha_angular
        self.alpha_contrastive = alpha_contrastive
        self.alpha_mmd = alpha_mmd
        self.alpha_logit = alpha_logit

    def forward(self, s_pitch, s_yaw, sp_logits, sy_logits, s_features,
                t_pitch, t_yaw, tp_logits, ty_logits, t_features,
                pitch_target, yaw_target):
        """Return (total_loss, dict of detached per-term scalars).

        NOTE(review): t_pitch and t_yaw are accepted for interface
        compatibility but are not used by any loss term here.
        """
        # Supervised task loss on the student's binned + continuous predictions.
        loss_task = (self.pitch_l2cs(sp_logits, s_pitch, pitch_target)
                     + self.yaw_l2cs(sy_logits, s_yaw, yaw_target))

        # Teacher signals are detached: only the student receives gradients.
        teacher_feat = t_features.detach()
        loss_angular = self.alpha_angular * self.angular_loss(s_pitch, s_yaw, pitch_target, yaw_target)
        loss_contrastive = self.alpha_contrastive * self.contrastive_loss(teacher_feat, s_features)
        loss_mmd = self.alpha_mmd * self.mmd_loss(teacher_feat, s_features)
        loss_logit = (self.alpha_logit * self.logit_loss(sp_logits, tp_logits.detach())
                      + self.alpha_logit * self.logit_loss(sy_logits, ty_logits.detach()))

        total_loss = loss_task + loss_angular + loss_contrastive + loss_mmd + loss_logit

        components = {
            'loss_total': total_loss,
            'loss_task': loss_task,
            'loss_angular': loss_angular,
            'loss_contrastive': loss_contrastive,
            'loss_mmd': loss_mmd,
            'loss_logit': loss_logit,
        }
        return total_loss, {name: value.item() for name, value in components.items()}