| """ |
| Loss functions for ViL Tracker training. |
| |
| Includes: |
| - FocalLoss: for center heatmap prediction (handles class imbalance) |
| - GIoULoss: for bounding box regression |
| - UncertaintyNLLLoss: uncertainty-aware NLL loss |
| - MemoryContrastiveLoss: contrastive loss for mLSTM memory states |
| - AFKDDistillationLoss: attention-free knowledge distillation |
| - ADWLoss: adaptive dynamic weighting for multi-task loss |
| - CombinedTrackingLoss: combines all losses with learned weighting |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class FocalLoss(nn.Module):
    """Modified focal loss (CornerNet-style) for center-heatmap prediction.

    Only locations where the target equals exactly 1 count as positives;
    every other location is a negative, down-weighted by (1 - target)^beta
    so pixels near a Gaussian peak are penalized less. This handles the
    extreme foreground/background imbalance of center heatmaps.

    Args:
        alpha: focusing exponent applied to the prediction term.
        beta: exponent attenuating negatives close to Gaussian peaks.
    """
    def __init__(self, alpha: float = 2.0, beta: float = 4.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss.

        Args:
            pred: (B, 1, H, W) raw heatmap logits.
            target: (B, 1, H, W) Gaussian ground-truth heatmap in [0, 1].

        Returns:
            Scalar loss, normalized by the number of positive locations.
        """
        # Clamp keeps the logs finite at saturated sigmoid outputs.
        p = torch.sigmoid(pred).clamp(1e-6, 1 - 1e-6)

        is_pos = target.eq(1).float()
        is_neg = target.lt(1).float()

        pos_term = is_pos * torch.log(p) * (1.0 - p) ** self.alpha
        neg_term = (
            is_neg
            * torch.log(1.0 - p)
            * p ** self.alpha
            * (1.0 - target) ** self.beta
        )

        # Normalize by positive count; clamp guards frames with no positives.
        normalizer = is_pos.sum().clamp(min=1)
        return -(pos_term.sum() + neg_term.sum()) / normalizer
|
|
|
|
class GIoULoss(nn.Module):
    """Generalized IoU (GIoU) loss for bounding-box regression.

    Unlike L1, GIoU still provides a meaningful gradient when the predicted
    and ground-truth boxes do not overlap, via the smallest-enclosing-box
    penalty term.
    """

    @staticmethod
    def _to_corners(boxes: torch.Tensor):
        """Convert (B, 4) [cx, cy, w, h] boxes to (x1, y1, x2, y2) tensors."""
        cx, cy, w, h = boxes.unbind(dim=1)
        return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute mean (1 - GIoU) over the batch.

        Args:
            pred: (B, 4) predicted [cx, cy, w, h].
            target: (B, 4) ground truth [cx, cy, w, h].
        """
        px1, py1, px2, py2 = self._to_corners(pred)
        gx1, gy1, gx2, gy2 = self._to_corners(target)

        # Intersection rectangle; clamping makes disjoint boxes contribute 0.
        iw = (torch.min(px2, gx2) - torch.max(px1, gx1)).clamp(min=0)
        ih = (torch.min(py2, gy2) - torch.max(py1, gy1)).clamp(min=0)
        intersection = iw * ih

        # Individual areas; clamped in case a degenerate (negative) w/h is predicted.
        area_p = (px2 - px1).clamp(min=0) * (py2 - py1).clamp(min=0)
        area_g = (gx2 - gx1).clamp(min=0) * (gy2 - gy1).clamp(min=0)
        union = area_p + area_g - intersection
        iou = intersection / union.clamp(min=1e-6)

        # Smallest axis-aligned box enclosing both boxes.
        ew = (torch.max(px2, gx2) - torch.min(px1, gx1)).clamp(min=0)
        eh = (torch.max(py2, gy2) - torch.min(py1, gy1)).clamp(min=0)
        enclosure = ew * eh

        giou = iou - (enclosure - union) / enclosure.clamp(min=1e-6)
        return (1.0 - giou).mean()
|
|
|
|
class UncertaintyNLLLoss(nn.Module):
    """Heteroscedastic Gaussian negative log-likelihood loss.

    Implements L = 0.5 * exp(-s) * |pred - target|^2 + 0.5 * s with
    s = log(variance): confident (low-variance) predictions are penalized
    more for errors, while the 0.5 * s term keeps the network from
    inflating the variance everywhere to dodge the loss.
    """
    def forward(self, pred: torch.Tensor, target: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Compute the mean NLL.

        Args:
            pred: (B, ...) predictions.
            target: (B, ...) targets.
            log_var: (B, ...) predicted log variance s (same shape as pred).
        """
        residual_sq = (pred - target).pow(2)
        # exp(-s) is the precision 1/variance.
        nll = 0.5 * (residual_sq * torch.exp(-log_var) + log_var)
        return nll.mean()
|
|
|
|
class MemoryContrastiveLoss(nn.Module):
    """InfoNCE-style contrastive loss for mLSTM memory states.

    Row i of ``feat_a`` is the positive match for row i of ``feat_b``;
    every other row in the batch acts as an in-batch negative. Pulls
    same-target memory states together across frames and pushes different
    targets apart.

    Args:
        temperature: scaling applied to cosine-similarity logits; lower
            values sharpen the softmax.
    """
    def __init__(self, temperature: float = 0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, feat_a: torch.Tensor, feat_b: torch.Tensor) -> torch.Tensor:
        """Compute the contrastive loss.

        Args:
            feat_a: (B, D) features from frame A.
            feat_b: (B, D) features from frame B (same targets, same order).
        """
        a = F.normalize(feat_a, dim=-1)
        b = F.normalize(feat_b, dim=-1)

        # (B, B) cosine-similarity logits, sharpened by the temperature.
        logits = a @ b.t() / self.temperature

        # Diagonal entries are the positive pairs.
        targets = torch.arange(a.shape[0], device=a.device)
        return F.cross_entropy(logits, targets)
|
|
|
|
class AFKDDistillationLoss(nn.Module):
    """Attention-Free Knowledge Distillation loss.

    For distilling from an MCITrack-B256 teacher to a ViL-S student.
    Combines:
      - feature distillation: student features are projected up to the
        teacher width with a learned MLP and matched via MSE, and
      - optional response distillation: temperature-softened KL divergence
        between flattened student and teacher logits.

    Args:
        student_dim: channel width of the student features.
        teacher_dim: channel width of the teacher features.
        temperature: softmax temperature for response distillation.
    """
    def __init__(self, student_dim: int = 384, teacher_dim: int = 768, temperature: float = 4.0):
        super().__init__()
        self.temperature = temperature
        # Learned projector mapping student features into teacher space.
        self.projector = nn.Sequential(
            nn.Linear(student_dim, teacher_dim),
            nn.GELU(),
            nn.Linear(teacher_dim, teacher_dim),
        )

    def forward(
        self,
        student_feat: torch.Tensor,
        teacher_feat: torch.Tensor,
        # Fixed implicit-Optional annotations (PEP 484): defaults of None
        # require an explicit Optional/union type. String form avoids a
        # runtime dependency on Python 3.10+ union syntax.
        student_logits: "torch.Tensor | None" = None,
        teacher_logits: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Compute the distillation loss.

        Args:
            student_feat: (B, S, D_s) student features.
            teacher_feat: (B, S, D_t) teacher features.
            student_logits: optional (B, ...) student predictions.
            teacher_logits: optional (B, ...) teacher predictions.

        Returns:
            Scalar loss: feature MSE, plus the response KL term when both
            logits are provided.
        """
        # Feature matching in teacher space; detach freezes the teacher.
        student_proj = self.projector(student_feat)
        feat_loss = F.mse_loss(student_proj, teacher_feat.detach())

        # Response distillation only when both sets of logits are supplied.
        if student_logits is not None and teacher_logits is not None:
            T = self.temperature
            s_soft = F.log_softmax(student_logits.view(student_logits.shape[0], -1) / T, dim=-1)
            t_soft = F.softmax(teacher_logits.view(teacher_logits.shape[0], -1) / T, dim=-1)
            # T^2 rescaling keeps gradient magnitudes comparable across
            # temperatures (Hinton et al. distillation convention).
            resp_loss = F.kl_div(s_soft, t_soft.detach(), reduction='batchmean') * (T ** 2)
            return feat_loss + resp_loss

        return feat_loss
|
|
|
|
class ADWLoss(nn.Module):
    """Adaptive Dynamic Weighting for multi-task loss balancing.

    Learns one log-variance s_k per task (homoscedastic-uncertainty
    weighting, Kendall et al.) and combines per-task losses as:

        L_total = sum_k exp(-s_k) * L_k + s_k

    where exp(-s_k) = 1/sigma_k^2 is the learned task weight and the +s_k
    term regularizes the log-variances so they cannot grow without bound.
    (Docstring corrected: the previous formula, w_k = 1/(2*sigma_k^2) with
    a log(sigma_k) regularizer, did not match this implementation.)
    """
    def __init__(self, num_tasks: int = 4):
        super().__init__()
        # s_k initialized to 0 => every task starts with unit weight.
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses: list) -> torch.Tensor:
        """Weight and sum per-task losses.

        Args:
            losses: scalar loss tensors, one per task. May contain fewer
                entries than ``num_tasks``; trailing log-variance
                parameters are simply unused.

        Returns:
            Scalar weighted total loss.
        """
        total = 0
        for task_idx, task_loss in enumerate(losses):
            weight = torch.exp(-self.log_vars[task_idx])
            total = total + weight * task_loss + self.log_vars[task_idx]
        return total
|
|
|
|
class CombinedTrackingLoss(nn.Module):
    """Top-level multi-task loss for tracker training.

    Aggregates:
      - focal loss on the center heatmap,
      - L1 loss on the pooled size regression,
      - GIoU loss on the predicted boxes,
      - optionally, an uncertainty-weighted GIoU term when the model
        outputs a 'log_variance' map.

    Task weighting is either learned via ADWLoss or taken from a fixed
    schedule.

    Args:
        use_uncertainty: enable the uncertainty-weighted term (applied
            only when 'log_variance' is present in the prediction dict).
        use_adw: learn task weights with ADWLoss instead of fixed weights.
    """
    def __init__(self, use_uncertainty: bool = True, use_adw: bool = True):
        super().__init__()
        self.focal = FocalLoss()
        self.giou = GIoULoss()
        self.l1 = nn.L1Loss()
        self.use_uncertainty = use_uncertainty
        if use_uncertainty:
            self.uncertainty_loss = UncertaintyNLLLoss()
        self.adw = ADWLoss(num_tasks=4 if use_uncertainty else 3) if use_adw else None

    def forward(
        self,
        pred: dict,
        gt_heatmap: torch.Tensor,
        gt_size: torch.Tensor,
        gt_boxes: torch.Tensor,
    ) -> dict:
        """Compute the combined loss.

        Args:
            pred: model output dict with 'heatmap', 'size', 'boxes', and
                optionally 'log_variance'.
            gt_heatmap: (B, 1, H, W) ground truth heatmap.
            gt_size: (B, 2) ground truth normalized size [w, h].
            gt_boxes: (B, 4) ground truth boxes [cx, cy, w, h] in pixels.

        Returns:
            Dict with 'total' (differentiable) and detached per-task
            components 'heatmap', 'size', 'giou'.
        """
        heatmap_loss = self.focal(pred['heatmap'], gt_heatmap)

        # Pool the dense size map down to one (w, h) estimate per sample.
        batch = gt_size.shape[0]
        pooled_size = pred['size'].view(batch, 2, -1).mean(dim=-1)
        size_loss = self.l1(pooled_size, gt_size)

        giou_loss = self.giou(pred['boxes'], gt_boxes)

        task_losses = [heatmap_loss, size_loss, giou_loss]

        if self.use_uncertainty and 'log_variance' in pred:
            # One scalar log-variance per sample from the dense map.
            log_var = pred['log_variance'].mean(dim=[1, 2, 3])
            task_losses.append(
                (0.5 * torch.exp(-log_var) * giou_loss + 0.5 * log_var).mean()
            )

        if self.adw is None:
            # Fixed schedule; truncated when the uncertainty term is absent.
            fixed = [1.0, 1.0, 2.0, 0.5][:len(task_losses)]
            total_loss = sum(w * l for w, l in zip(fixed, task_losses))
        else:
            total_loss = self.adw(task_losses)

        return {
            'total': total_loss,
            'heatmap': heatmap_loss.detach(),
            'size': size_loss.detach(),
            'giou': giou_loss.detach(),
        }
|
|