Upload vil_tracker/training/losses.py with huggingface_hub
Browse files- vil_tracker/training/losses.py +290 -0
vil_tracker/training/losses.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Loss functions for ViL Tracker training.
|
| 3 |
+
|
| 4 |
+
Includes:
|
| 5 |
+
- FocalLoss: for center heatmap prediction (handles class imbalance)
|
| 6 |
+
- GIoULoss: for bounding box regression
|
| 7 |
+
- UncertaintyNLLLoss: uncertainty-aware NLL loss
|
| 8 |
+
- MemoryContrastiveLoss: contrastive loss for mLSTM memory states
|
| 9 |
+
- AFKDDistillationLoss: attention-free knowledge distillation
|
| 10 |
+
- ADWLoss: adaptive dynamic weighting for multi-task loss
|
| 11 |
+
- CombinedTrackingLoss: combines all losses with learned weighting
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class FocalLoss(nn.Module):
    """Focal loss for heatmap prediction (CornerNet-style).

    Handles extreme foreground/background imbalance in center heatmaps
    where only ~1/256 positions are positive.

    Args:
        alpha: focusing exponent on the prediction confidence.
        beta: exponent down-weighting negatives near the Gaussian peak.
    """
    def __init__(self, alpha: float = 2.0, beta: float = 4.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: (B, 1, H, W) predicted heatmap (logits)
            target: (B, 1, H, W) ground truth Gaussian heatmap in [0, 1]

        Returns:
            Scalar loss, normalized by the number of positive positions.
        """
        # Compute log(sigmoid(x)) and log(1 - sigmoid(x)) directly with
        # logsigmoid: numerically stable for large-magnitude logits,
        # unlike sigmoid -> clamp -> log, which also zeroes gradients
        # once the clamp saturates.
        log_p = F.logsigmoid(pred)
        log_one_minus_p = F.logsigmoid(-pred)
        pred_sig = torch.sigmoid(pred)

        # Only exact-peak positions (target == 1) are positives; a Gaussian
        # heatmap has values < 1 everywhere else.
        pos_mask = target.eq(1).float()
        neg_mask = 1.0 - pos_mask

        # Positive loss: down-weight positives the model is already
        # confident about.
        pos_loss = -((1 - pred_sig) ** self.alpha) * log_p * pos_mask

        # Negative loss, additionally weighted by distance from the GT peak
        # so pixels near the Gaussian center are penalized less.
        neg_weight = (1 - target) ** self.beta
        neg_loss = -(pred_sig ** self.alpha) * log_one_minus_p * neg_weight * neg_mask

        # Normalize by positive count; clamp avoids division by zero on
        # frames with no annotated center.
        num_pos = pos_mask.sum().clamp(min=1)
        return (pos_loss.sum() + neg_loss.sum()) / num_pos
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class GIoULoss(nn.Module):
    """Generalized IoU loss for bounding box regression.

    Better gradient signal than L1 for box prediction, especially
    for non-overlapping boxes.
    """
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: (B, 4) predicted [cx, cy, w, h]
            target: (B, 4) ground truth [cx, cy, w, h]
        """
        eps = 1e-6

        def to_corners(boxes: torch.Tensor):
            # [cx, cy, w, h] -> (x1, y1, x2, y2)
            cx, cy, w, h = boxes.unbind(dim=1)
            return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2

        px1, py1, px2, py2 = to_corners(pred)
        gx1, gy1, gx2, gy2 = to_corners(target)

        # Overlap region; clamping makes disjoint boxes contribute zero.
        iw = (torch.min(px2, gx2) - torch.max(px1, gx1)).clamp(min=0)
        ih = (torch.min(py2, gy2) - torch.max(py1, gy1)).clamp(min=0)
        intersection = iw * ih

        area_p = (px2 - px1).clamp(min=0) * (py2 - py1).clamp(min=0)
        area_g = (gx2 - gx1).clamp(min=0) * (gy2 - gy1).clamp(min=0)
        union = area_p + area_g - intersection
        iou = intersection / union.clamp(min=eps)

        # Smallest axis-aligned box enclosing both boxes; the GIoU penalty
        # is the fraction of it not covered by the union.
        ew = (torch.max(px2, gx2) - torch.min(px1, gx1)).clamp(min=0)
        eh = (torch.max(py2, gy2) - torch.min(py1, gy1)).clamp(min=0)
        enclosure = ew * eh

        giou = iou - (enclosure - union) / enclosure.clamp(min=eps)
        return (1 - giou).mean()
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class UncertaintyNLLLoss(nn.Module):
    """Uncertainty-aware negative log-likelihood loss.

    Weighs the regression loss by predicted uncertainty:
        L = 0.5 * exp(-s) * |pred - target|^2 + 0.5 * s
    where s = log(variance).
    """
    def forward(self, pred: torch.Tensor, target: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: (B, ...) predictions
            target: (B, ...) targets
            log_var: (B, ...) predicted log variance
        """
        # exp(-s) = 1/sigma^2 scales the squared error; the additive s term
        # keeps the network from inflating variance to zero out the loss.
        residual_sq = torch.square(pred - target)
        per_element = 0.5 * (torch.exp(-log_var) * residual_sq + log_var)
        return per_element.mean()
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class MemoryContrastiveLoss(nn.Module):
    """Contrastive loss for mLSTM memory states.

    Encourages similar memory states for the same target across frames
    and dissimilar states for different targets.
    """
    def __init__(self, temperature: float = 0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, feat_a: torch.Tensor, feat_b: torch.Tensor) -> torch.Tensor:
        """
        Args:
            feat_a: (B, D) features from frame A
            feat_b: (B, D) features from frame B (same target)
        """
        # Cosine similarity via dot products of unit-normalized features.
        a = F.normalize(feat_a, dim=-1)
        b = F.normalize(feat_b, dim=-1)
        logits = a @ b.transpose(0, 1) / self.temperature  # (B, B)

        # Row i's positive is column i; every other sample in the batch
        # serves as a negative (InfoNCE via cross-entropy).
        targets = torch.arange(logits.shape[0], device=logits.device)
        return F.cross_entropy(logits, targets)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class AFKDDistillationLoss(nn.Module):
    """Attention-Free Knowledge Distillation loss.

    For distilling from MCITrack-B256 teacher to ViL-S student.
    Uses feature matching + response-based distillation.
    """
    def __init__(self, student_dim: int = 384, teacher_dim: int = 768, temperature: float = 4.0):
        super().__init__()
        self.temperature = temperature
        # Projector to match dimensions
        self.projector = nn.Sequential(
            nn.Linear(student_dim, teacher_dim),
            nn.GELU(),
            nn.Linear(teacher_dim, teacher_dim),
        )

    def forward(
        self,
        student_feat: torch.Tensor,
        teacher_feat: torch.Tensor,
        student_logits: torch.Tensor = None,
        teacher_logits: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Args:
            student_feat: (B, S, D_s) student features
            teacher_feat: (B, S, D_t) teacher features
            student_logits: optional (B, ...) student predictions
            teacher_logits: optional (B, ...) teacher predictions
        """
        # Feature-level term: project student features into teacher space
        # and match them; teacher is detached so no gradient flows into it.
        feature_term = F.mse_loss(self.projector(student_feat), teacher_feat.detach())

        # Without both logit maps there is nothing to soften — feature
        # matching alone is the loss.
        if student_logits is None or teacher_logits is None:
            return feature_term

        # Response-level term: temperature-softened KL between flattened
        # prediction maps, rescaled by T^2 (standard KD convention).
        tau = self.temperature
        log_student = F.log_softmax(student_logits.view(student_logits.shape[0], -1) / tau, dim=-1)
        soft_teacher = F.softmax(teacher_logits.view(teacher_logits.shape[0], -1) / tau, dim=-1)
        response_term = F.kl_div(log_student, soft_teacher.detach(), reduction='batchmean') * (tau * tau)
        return feature_term + response_term
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class ADWLoss(nn.Module):
    """Adaptive Dynamic Weighting for multi-task loss.

    Learns task weights based on loss magnitudes using homoscedastic uncertainty.
    w_k = 1/(2*sigma_k^2), regularizer = log(sigma_k)
    """
    def __init__(self, num_tasks: int = 4):
        super().__init__()
        # One learnable log-variance per task; zeros give every task an
        # initial weight of exp(0) = 1 (equal weighting).
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses: list) -> torch.Tensor:
        """
        Args:
            losses: list of scalar loss tensors (one per task)
        Returns:
            weighted sum of losses
        """
        # Each task contributes exp(-s_k) * L_k + s_k; the additive s_k
        # regularizer stops the optimizer from driving weights to zero.
        terms = (
            torch.exp(-self.log_vars[k]) * task_loss + self.log_vars[k]
            for k, task_loss in enumerate(losses)
        )
        return sum(terms)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class CombinedTrackingLoss(nn.Module):
    """Combined loss for tracker training.

    Combines:
    - Focal loss on center heatmap
    - GIoU loss on predicted boxes
    - L1 loss on size regression
    - Optional: uncertainty-weighted GIoU term

    Terms are combined either by learned adaptive weights (ADWLoss)
    or by fixed weights.
    """
    def __init__(self, use_uncertainty: bool = True, use_adw: bool = True):
        super().__init__()
        self.focal = FocalLoss()
        self.giou = GIoULoss()
        self.l1 = nn.L1Loss()
        self.use_uncertainty = use_uncertainty

        if use_uncertainty:
            self.uncertainty_loss = UncertaintyNLLLoss()

        # One ADW task slot per loss term (3 base + 1 optional uncertainty).
        num_tasks = 4 if use_uncertainty else 3
        self.adw = ADWLoss(num_tasks=num_tasks) if use_adw else None

    def forward(
        self,
        pred: dict,
        gt_heatmap: torch.Tensor,
        gt_size: torch.Tensor,
        gt_boxes: torch.Tensor,
    ) -> dict:
        """
        Args:
            pred: model output dict with 'heatmap', 'size', 'boxes', optionally 'log_variance'
            gt_heatmap: (B, 1, H, W) ground truth heatmap
            gt_size: (B, 2) ground truth normalized size [w, h]
            gt_boxes: (B, 4) ground truth boxes [cx, cy, w, h] in pixels

        Returns:
            Dict with 'total' (differentiable) plus detached per-term values
            for logging: 'heatmap', 'size', 'giou', and 'uncertainty' when
            that term was computed.
        """
        # Heatmap loss
        heatmap_loss = self.focal(pred['heatmap'], gt_heatmap)

        # Size loss: pool the dense size map down to one (B, 2) estimate.
        # NOTE(review): this is a global average pool, not a readout at the
        # heatmap peak — confirm this matches the size head's intent.
        B = gt_size.shape[0]
        pred_size = pred['size'].view(B, 2, -1).mean(dim=-1)
        size_loss = self.l1(pred_size, gt_size)

        # GIoU box loss
        giou_loss = self.giou(pred['boxes'], gt_boxes)

        losses = [heatmap_loss, size_loss, giou_loss]

        # Uncertainty term: scale the GIoU loss by predicted confidence
        # (heteroscedastic weighting); skipped when the model does not
        # emit a 'log_variance' map.
        unc_loss = None
        if self.use_uncertainty and 'log_variance' in pred:
            log_var = pred['log_variance'].mean(dim=[1, 2, 3])  # (B,)
            unc_loss = (0.5 * torch.exp(-log_var) * giou_loss + 0.5 * log_var).mean()
            losses.append(unc_loss)

        # Combine with learned adaptive weights, or fall back to fixed ones.
        if self.adw is not None:
            total_loss = self.adw(losses)
        else:
            weights = [1.0, 1.0, 2.0, 0.5] if len(losses) == 4 else [1.0, 1.0, 2.0]
            total_loss = sum(w * l for w, l in zip(weights, losses))

        out = {
            'total': total_loss,
            'heatmap': heatmap_loss.detach(),
            'size': size_loss.detach(),
            'giou': giou_loss.detach(),
        }
        # Expose the uncertainty term for logging; previously it was folded
        # into 'total' but impossible to monitor on its own.
        if unc_loss is not None:
            out['uncertainty'] = unc_loss.detach()
        return out
|