omar-ah commited on
Commit
01f95f3
·
verified ·
1 Parent(s): b3b0529

Upload vil_tracker/training/losses.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vil_tracker/training/losses.py +290 -0
vil_tracker/training/losses.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Loss functions for ViL Tracker training.
3
+
4
+ Includes:
5
+ - FocalLoss: for center heatmap prediction (handles class imbalance)
6
+ - GIoULoss: for bounding box regression
7
+ - UncertaintyNLLLoss: uncertainty-aware NLL loss
8
+ - MemoryContrastiveLoss: contrastive loss for mLSTM memory states
9
+ - AFKDDistillationLoss: attention-free knowledge distillation
10
+ - ADWLoss: adaptive dynamic weighting for multi-task loss
11
+ - CombinedTrackingLoss: combines all losses with learned weighting
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+
19
class FocalLoss(nn.Module):
    """Penalty-reduced pixel-wise focal loss for Gaussian center heatmaps.

    CornerNet/CenterNet formulation: pixels where the target is exactly 1
    are positives; every other pixel is a negative whose contribution is
    suppressed by ``(1 - target) ** beta`` near Gaussian peaks. The summed
    loss is normalized by the number of positive pixels (clamped to >= 1).

    Args:
        alpha: focusing exponent on the prediction-error term.
        beta: exponent that down-weights negatives close to a GT peak.
    """

    def __init__(self, alpha: float = 2.0, beta: float = 4.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss.

        Args:
            pred: (B, 1, H, W) raw heatmap logits.
            target: (B, 1, H, W) ground-truth Gaussian heatmap in [0, 1].
        """
        # Clamp after the sigmoid so both logs below stay finite.
        p = torch.sigmoid(pred).clamp(1e-6, 1.0 - 1e-6)

        is_pos = (target == 1).float()
        is_neg = (target < 1).float()

        pos_term = is_pos * ((1.0 - p) ** self.alpha) * torch.log(p)
        # Negatives are weighted by their distance from the GT peak.
        neg_term = is_neg * ((1.0 - target) ** self.beta) * (p ** self.alpha) * torch.log(1.0 - p)

        n_pos = is_pos.sum().clamp(min=1)
        return -(pos_term.sum() + neg_term.sum()) / n_pos
52
+
53
+
54
class GIoULoss(nn.Module):
    """Generalized IoU loss for [cx, cy, w, h] box regression.

    GIoU extends IoU with a penalty based on the smallest enclosing box,
    which gives a useful gradient even when the boxes do not overlap.
    Returns the batch mean of ``1 - GIoU``.
    """

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute mean (1 - GIoU) over the batch.

        Args:
            pred: (B, 4) predicted boxes as [cx, cy, w, h].
            target: (B, 4) ground-truth boxes as [cx, cy, w, h].
        """

        def corners(boxes: torch.Tensor):
            # [cx, cy, w, h] -> (x1, y1, x2, y2)
            half_w = boxes[:, 2] / 2
            half_h = boxes[:, 3] / 2
            return (
                boxes[:, 0] - half_w,
                boxes[:, 1] - half_h,
                boxes[:, 0] + half_w,
                boxes[:, 1] + half_h,
            )

        px1, py1, px2, py2 = corners(pred)
        gx1, gy1, gx2, gy2 = corners(target)

        # Intersection (zero when the boxes are disjoint).
        iw = (torch.min(px2, gx2) - torch.max(px1, gx1)).clamp(min=0)
        ih = (torch.min(py2, gy2) - torch.max(py1, gy1)).clamp(min=0)
        inter = iw * ih

        area_p = (px2 - px1).clamp(min=0) * (py2 - py1).clamp(min=0)
        area_g = (gx2 - gx1).clamp(min=0) * (gy2 - gy1).clamp(min=0)
        union = area_p + area_g - inter
        iou = inter / union.clamp(min=1e-6)

        # Smallest axis-aligned box enclosing both.
        ew = (torch.max(px2, gx2) - torch.min(px1, gx1)).clamp(min=0)
        eh = (torch.max(py2, gy2) - torch.min(py1, gy1)).clamp(min=0)
        enclose = ew * eh

        giou = iou - (enclose - union) / enclose.clamp(min=1e-6)
        return (1.0 - giou).mean()
100
+
101
+
102
class UncertaintyNLLLoss(nn.Module):
    """Gaussian negative log-likelihood with a predicted log-variance.

    Implements ``0.5 * (exp(-s) * (pred - target)^2 + s)`` per element with
    ``s = log(variance)``, averaged over all elements. High predicted
    variance damps the squared error but pays the ``+s`` regularizer, so
    the network cannot zero the loss by inflating uncertainty.
    """

    def forward(self, pred: torch.Tensor, target: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Compute the mean uncertainty-weighted NLL.

        Args:
            pred: (B, ...) predictions.
            target: (B, ...) targets, same shape as ``pred``.
            log_var: (B, ...) predicted log-variance, same shape as ``pred``.
        """
        residual_sq = (pred - target).pow(2)
        per_elem = residual_sq * torch.exp(-log_var) + log_var
        return 0.5 * per_elem.mean()
120
+
121
+
122
class MemoryContrastiveLoss(nn.Module):
    """InfoNCE-style contrastive loss between two feature batches.

    Row ``i`` of ``feat_a`` is treated as a positive pair with row ``i`` of
    ``feat_b``; all other rows in the batch act as negatives. Used to keep
    mLSTM memory states consistent for the same target across frames.

    Args:
        temperature: softmax temperature for the similarity logits.
    """

    def __init__(self, temperature: float = 0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, feat_a: torch.Tensor, feat_b: torch.Tensor) -> torch.Tensor:
        """Compute the contrastive loss.

        Args:
            feat_a: (B, D) features from frame A.
            feat_b: (B, D) features from frame B (same targets, same order).
        """
        # Cosine similarity via L2-normalized dot products.
        unit_a = F.normalize(feat_a, dim=-1)
        unit_b = F.normalize(feat_b, dim=-1)

        logits = torch.mm(unit_a, unit_b.t()) / self.temperature  # (B, B)

        # The matching index on the diagonal is the "correct class".
        targets = torch.arange(logits.shape[0], device=feat_a.device)
        return F.cross_entropy(logits, targets)
151
+
152
+
153
class AFKDDistillationLoss(nn.Module):
    """Attention-Free Knowledge Distillation loss.

    Distills from an MCITrack-B256 teacher into a ViL-S student. A small
    MLP projects student features up to the teacher's width for an MSE
    feature-matching term; when logits are supplied, a temperature-scaled
    KL response term is added on top.

    Args:
        student_dim: channel width of the student features.
        teacher_dim: channel width of the teacher features.
        temperature: softening temperature for the response KL term.
    """

    def __init__(self, student_dim: int = 384, teacher_dim: int = 768, temperature: float = 4.0):
        super().__init__()
        self.temperature = temperature
        # Two-layer MLP lifting student features to the teacher dimension.
        # NOTE: layer order fixes the parameter-init RNG draw order.
        up = nn.Linear(student_dim, teacher_dim)
        mix = nn.Linear(teacher_dim, teacher_dim)
        self.projector = nn.Sequential(up, nn.GELU(), mix)

    def forward(
        self,
        student_feat: torch.Tensor,
        teacher_feat: torch.Tensor,
        student_logits: torch.Tensor = None,
        teacher_logits: torch.Tensor = None,
    ) -> torch.Tensor:
        """Compute the distillation loss.

        Args:
            student_feat: (B, S, D_s) student features.
            teacher_feat: (B, S, D_t) teacher features (treated as constant).
            student_logits: optional (B, ...) student predictions.
            teacher_logits: optional (B, ...) teacher predictions.
        """
        # Feature matching against a detached (non-trainable) teacher.
        feat_loss = F.mse_loss(self.projector(student_feat), teacher_feat.detach())

        # Without both logit tensors there is no response term to add.
        if student_logits is None or teacher_logits is None:
            return feat_loss

        # Hinton-style response distillation, rescaled by T^2 so the
        # gradient magnitude is independent of the temperature.
        T = self.temperature
        batch = student_logits.shape[0]
        log_p_student = F.log_softmax(student_logits.view(batch, -1) / T, dim=-1)
        p_teacher = F.softmax(teacher_logits.view(teacher_logits.shape[0], -1) / T, dim=-1)
        resp_loss = F.kl_div(log_p_student, p_teacher.detach(), reduction='batchmean') * (T ** 2)
        return feat_loss + resp_loss
196
+
197
+
198
class ADWLoss(nn.Module):
    """Adaptive Dynamic Weighting for multi-task losses.

    Learns one homoscedastic-uncertainty weight per task (Kendall et al.):
    each task loss is scaled by ``exp(-s_k)`` and regularized by ``+ s_k``,
    where ``s_k`` is a learned log-variance initialized to 0 (equal weights).

    Args:
        num_tasks: number of task losses this module can weight.
    """

    def __init__(self, num_tasks: int = 4):
        super().__init__()
        # s_k = 0 gives exp(-s_k) = 1: all tasks start equally weighted.
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses: list) -> torch.Tensor:
        """Weight and sum the task losses.

        Args:
            losses: list of scalar loss tensors, at most ``num_tasks`` long.

        Returns:
            Weighted sum ``sum_k exp(-s_k) * L_k + s_k`` (0 for an empty list).
        """
        terms = (
            torch.exp(-self.log_vars[i]) * task_loss + self.log_vars[i]
            for i, task_loss in enumerate(losses)
        )
        return sum(terms)
221
+
222
+
223
class CombinedTrackingLoss(nn.Module):
    """Combined multi-task loss for tracker training.

    Combines:
    - focal loss on the predicted center heatmap,
    - L1 loss on the (average-pooled) size regression,
    - GIoU loss on the predicted boxes,
    - optionally an uncertainty term driven by the predicted log-variance.

    The per-task losses are merged either by :class:`ADWLoss` (learned
    homoscedastic weighting) or, when ``use_adw=False``, by a fixed weight
    vector.
    """

    def __init__(self, use_uncertainty: bool = True, use_adw: bool = True):
        super().__init__()
        self.focal = FocalLoss()
        self.giou = GIoULoss()
        self.l1 = nn.L1Loss()
        self.use_uncertainty = use_uncertainty

        if use_uncertainty:
            # NOTE(review): this module is instantiated but forward() computes
            # its uncertainty term inline from the GIoU loss instead of
            # calling it — confirm which formulation is intended.
            self.uncertainty_loss = UncertaintyNLLLoss()

        num_tasks = 4 if use_uncertainty else 3
        self.adw = ADWLoss(num_tasks=num_tasks) if use_adw else None

    def forward(
        self,
        pred: dict,
        gt_heatmap: torch.Tensor,
        gt_size: torch.Tensor,
        gt_boxes: torch.Tensor,
    ) -> dict:
        """Compute the combined tracking loss.

        Args:
            pred: model output dict with 'heatmap', 'size', 'boxes' and
                optionally 'log_variance'.
            gt_heatmap: (B, 1, H, W) ground-truth heatmap.
            gt_size: (B, 2) ground-truth normalized size [w, h].
            gt_boxes: (B, 4) ground-truth boxes [cx, cy, w, h] in pixels.

        Returns:
            dict with the differentiable 'total' loss and detached scalars
            'heatmap', 'size', 'giou', plus 'uncertainty' when computed.
        """
        # Center heatmap loss.
        heatmap_loss = self.focal(pred['heatmap'], gt_heatmap)

        # Size loss: global-average-pool the size map down to (B, 2).
        B = gt_size.shape[0]
        pred_size = pred['size'].view(B, 2, -1).mean(dim=-1)
        size_loss = self.l1(pred_size, gt_size)

        # Box regression loss.
        giou_loss = self.giou(pred['boxes'], gt_boxes)

        losses = [heatmap_loss, size_loss, giou_loss]

        # Optional uncertainty term: per-sample mean log-variance weights the
        # scalar GIoU loss; the +0.5*log_var regularizer keeps the network
        # from inflating variance to suppress the loss.
        unc_loss = None
        if self.use_uncertainty and 'log_variance' in pred:
            log_var = pred['log_variance'].mean(dim=[1, 2, 3])  # (B,)
            unc_loss = (0.5 * torch.exp(-log_var) * giou_loss + 0.5 * log_var).mean()
            losses.append(unc_loss)

        # Combine with learned (ADW) or fixed weights.
        if self.adw is not None:
            total_loss = self.adw(losses)
        else:
            weights = [1.0, 1.0, 2.0, 0.5] if len(losses) == 4 else [1.0, 1.0, 2.0]
            total_loss = sum(w * l for w, l in zip(weights, losses))

        out = {
            'total': total_loss,
            'heatmap': heatmap_loss.detach(),
            'size': size_loss.detach(),
            'giou': giou_loss.detach(),
        }
        # Fix: the uncertainty term was previously folded into 'total' but
        # never reported, so it could not be monitored during training.
        if unc_loss is not None:
            out['uncertainty'] = unc_loss.detach()
        return out