"""Loss functions for ViL Tracker training.

Includes:
    - FocalLoss: for center heatmap prediction (handles class imbalance)
    - GIoULoss: for bounding box regression
    - UncertaintyNLLLoss: uncertainty-aware NLL loss
    - MemoryContrastiveLoss: contrastive loss for mLSTM memory states
    - AFKDDistillationLoss: attention-free knowledge distillation
    - ADWLoss: adaptive dynamic weighting for multi-task loss
    - CombinedTrackingLoss: combines all losses with learned weighting
"""

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


# Dtypes in which `1 - 1e-6` is not representable (it rounds to exactly 1.0).
_LOW_PRECISION_DTYPES = (torch.float16, torch.bfloat16)


def _upcast_low_precision(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` as float32 if it is fp16/bf16, otherwise unchanged."""
    if tensor.dtype in _LOW_PRECISION_DTYPES:
        return tensor.float()
    return tensor


class FocalLoss(nn.Module):
    """Focal loss for heatmap prediction (CornerNet-style).

    Handles extreme foreground/background imbalance in center heatmaps
    where only ~1/256 positions are positive.

    Args:
        alpha: exponent of the focusing term applied to the prediction.
        beta: exponent of the ``(1 - target)`` down-weighting applied to
            negative positions.
    """

    def __init__(self, alpha: float = 2.0, beta: float = 4.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss, normalised by the number of positives.

        Args:
            pred: (B, 1, H, W) predicted heatmap (logits)
            target: (B, 1, H, W) ground truth Gaussian heatmap

        Returns:
            Scalar loss: (positive term + negative term) / max(num_pos, 1).
        """
        # In fp16/bf16 the upper clamp bound 1 - 1e-6 rounds to 1.0, which
        # makes log(1 - p) evaluate to -inf for saturated predictions.
        # Doing the math in float32 keeps the clamp effective.
        pred = _upcast_low_precision(pred)
        target = _upcast_low_precision(target)

        pred_sig = torch.sigmoid(pred).clamp(1e-6, 1 - 1e-6)

        # Positions equal to exactly 1 are positives, everything below is
        # treated as a (soft) negative.
        pos_mask = target.eq(1).float()
        neg_mask = target.lt(1).float()

        # Positive loss
        pos_loss = -((1 - pred_sig) ** self.alpha) * torch.log(pred_sig) * pos_mask

        # Negative loss (weighted by distance from GT peak)
        neg_weight = (1 - target) ** self.beta
        neg_loss = (
            -(pred_sig ** self.alpha)
            * torch.log(1 - pred_sig)
            * neg_weight
            * neg_mask
        )

        # Clamp so a batch without positives does not divide by zero.
        num_pos = pos_mask.sum().clamp(min=1)
        return (pos_loss.sum() + neg_loss.sum()) / num_pos


class GIoULoss(nn.Module):
    """Generalized IoU loss for bounding box regression.

    Better gradient signal than L1 for box prediction,
    especially for non-overlapping boxes.
    """

    @staticmethod
    def _to_corners(boxes: torch.Tensor):
        """Convert (B, 4) [cx, cy, w, h] boxes into x1, y1, x2, y2 tensors."""
        half_w = boxes[:, 2] / 2
        half_h = boxes[:, 3] / 2
        return (
            boxes[:, 0] - half_w,
            boxes[:, 1] - half_h,
            boxes[:, 0] + half_w,
            boxes[:, 1] + half_h,
        )

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the mean of ``1 - GIoU`` over the batch.

        Args:
            pred: (B, 4) predicted [cx, cy, w, h]
            target: (B, 4) ground truth [cx, cy, w, h]

        Returns:
            Scalar loss.
        """
        pred_x1, pred_y1, pred_x2, pred_y2 = self._to_corners(pred)
        gt_x1, gt_y1, gt_x2, gt_y2 = self._to_corners(target)

        # Intersection (clamped so disjoint boxes give zero area)
        inter_w = (torch.min(pred_x2, gt_x2) - torch.max(pred_x1, gt_x1)).clamp(min=0)
        inter_h = (torch.min(pred_y2, gt_y2) - torch.max(pred_y1, gt_y1)).clamp(min=0)
        inter_area = inter_w * inter_h

        # Union
        pred_area = (pred_x2 - pred_x1).clamp(min=0) * (pred_y2 - pred_y1).clamp(min=0)
        gt_area = (gt_x2 - gt_x1).clamp(min=0) * (gt_y2 - gt_y1).clamp(min=0)
        union_area = pred_area + gt_area - inter_area

        iou = inter_area / union_area.clamp(min=1e-6)

        # Smallest axis-aligned box enclosing both boxes
        enc_w = (torch.max(pred_x2, gt_x2) - torch.min(pred_x1, gt_x1)).clamp(min=0)
        enc_h = (torch.max(pred_y2, gt_y2) - torch.min(pred_y1, gt_y1)).clamp(min=0)
        enc_area = enc_w * enc_h

        giou = iou - (enc_area - union_area) / enc_area.clamp(min=1e-6)
        return (1 - giou).mean()


class UncertaintyNLLLoss(nn.Module):
    """Uncertainty-aware negative log-likelihood loss.

    Weighs the regression loss by predicted uncertainty:
        L = 0.5 * exp(-s) * |pred - target|^2 + 0.5 * s
    where s = log(variance).
    """

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        log_var: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the mean uncertainty-weighted squared error.

        Args:
            pred: (B, ...) predictions
            target: (B, ...) targets
            log_var: (B, ...) predicted log variance

        Returns:
            Scalar loss averaged over all elements.
        """
        precision = torch.exp(-log_var)
        sq_error = (pred - target) ** 2
        loss = 0.5 * (precision * sq_error + log_var)
        return loss.mean()


class MemoryContrastiveLoss(nn.Module):
    """Contrastive loss for mLSTM memory states.

    Encourages similar memory states for the same target across frames
    and dissimilar states for different targets.

    Args:
        temperature: divisor applied to the cosine-similarity logits.
    """

    def __init__(self, temperature: float = 0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, feat_a: torch.Tensor, feat_b: torch.Tensor) -> torch.Tensor:
        """Cross-entropy over the A-to-B similarity matrix.

        Row ``i`` of ``feat_a`` is the positive for row ``i`` of ``feat_b``;
        every other row in the batch acts as a negative.

        Args:
            feat_a: (B, D) features from frame A
            feat_b: (B, D) features from frame B (same target)

        Returns:
            Scalar loss.
        """
        # L2 normalize so the dot product is a cosine similarity
        feat_a = F.normalize(feat_a, dim=-1)
        feat_b = F.normalize(feat_b, dim=-1)

        batch_size = feat_a.shape[0]

        # Similarity matrix, (B, B)
        sim = torch.mm(feat_a, feat_b.t()) / self.temperature

        # Positive pairs along diagonal
        labels = torch.arange(batch_size, device=feat_a.device)
        return F.cross_entropy(sim, labels)


class AFKDDistillationLoss(nn.Module):
    """Attention-Free Knowledge Distillation loss.

    For distilling from MCITrack-B256 teacher to ViL-S student.
    Uses feature matching + response-based distillation.

    Args:
        student_dim: last-dimension size of the student features.
        teacher_dim: last-dimension size of the teacher features.
        temperature: softmax temperature for response distillation.
    """

    def __init__(
        self,
        student_dim: int = 384,
        teacher_dim: int = 768,
        temperature: float = 4.0,
    ):
        super().__init__()
        self.temperature = temperature

        # Projector to match dimensions
        self.projector = nn.Sequential(
            nn.Linear(student_dim, teacher_dim),
            nn.GELU(),
            nn.Linear(teacher_dim, teacher_dim),
        )

    def forward(
        self,
        student_feat: torch.Tensor,
        teacher_feat: torch.Tensor,
        student_logits: Optional[torch.Tensor] = None,
        teacher_logits: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute feature (and optionally response) distillation loss.

        Args:
            student_feat: (B, S, D_s) student features
            teacher_feat: (B, S, D_t) teacher features
            student_logits: optional (B, ...) student predictions
            teacher_logits: optional (B, ...) teacher predictions

        Returns:
            Scalar loss: feature MSE, plus the temperature-scaled KL term
            when both logit tensors are given.
        """
        # Feature distillation; the teacher never receives gradients.
        student_proj = self.projector(student_feat)
        feat_loss = F.mse_loss(student_proj, teacher_feat.detach())

        if student_logits is None or teacher_logits is None:
            return feat_loss

        # Response distillation. `reshape` (rather than `view`) so that
        # non-contiguous logits, e.g. after a permute, are accepted.
        temp = self.temperature
        s_flat = student_logits.reshape(student_logits.shape[0], -1)
        t_flat = teacher_logits.reshape(teacher_logits.shape[0], -1)
        s_soft = F.log_softmax(s_flat / temp, dim=-1)
        t_soft = F.softmax(t_flat / temp, dim=-1)
        # The T^2 factor rescales the soft-target term after dividing the
        # logits by T.
        resp_loss = F.kl_div(s_soft, t_soft.detach(), reduction='batchmean') * (temp ** 2)
        return feat_loss + resp_loss


class ADWLoss(nn.Module):
    """Adaptive Dynamic Weighting for multi-task loss.

    Learns one log-variance ``s_k`` per task (homoscedastic uncertainty)
    and combines the task losses as::

        total = sum_k exp(-s_k) * L_k + s_k

    With ``s_k = log(sigma_k^2)`` this is twice the usual
    ``L_k / (2 * sigma_k^2) + log(sigma_k)`` form.

    Args:
        num_tasks: number of learnable log-variance parameters.
    """

    def __init__(self, num_tasks: int = 4):
        super().__init__()
        # Log variance parameters (initialized to 0 = equal weighting)
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses: list) -> torch.Tensor:
        """Combine per-task losses with the learned weights.

        Args:
            losses: list of scalar loss tensors (one per task). Fewer
                entries than ``num_tasks`` are allowed; the remaining
                parameters are simply not used.

        Returns:
            Weighted sum of losses (a zero tensor for an empty list).

        Raises:
            IndexError: if more losses are passed than there are
                log-variance parameters.
        """
        losses = list(losses)
        num_params = self.log_vars.numel()
        if len(losses) > num_params:
            raise IndexError(
                f"ADWLoss got {len(losses)} losses but has only "
                f"{num_params} log-variance parameters"
            )

        # Start from a tensor so an empty list still returns a Tensor.
        total = self.log_vars.new_zeros(())
        for i, loss in enumerate(losses):
            log_var = self.log_vars[i]
            total = total + torch.exp(-log_var) * loss + log_var
        return total


class CombinedTrackingLoss(nn.Module):
    """Combined loss for tracker training.

    Combines:
        - Focal loss on center heatmap
        - L1 loss on the spatially averaged size map
        - GIoU loss on predicted boxes
        - Optional: uncertainty term driven by ``pred['log_variance']``

    Args:
        use_uncertainty: add the uncertainty term when the model output
            contains ``'log_variance'``.
        use_adw: combine the terms with :class:`ADWLoss`; otherwise fixed
            weights are used.
    """

    def __init__(self, use_uncertainty: bool = True, use_adw: bool = True):
        super().__init__()
        self.focal = FocalLoss()
        self.giou = GIoULoss()
        self.l1 = nn.L1Loss()

        self.use_uncertainty = use_uncertainty
        if use_uncertainty:
            # Kept as a public attribute; forward() computes its
            # uncertainty term inline and does not call this module.
            self.uncertainty_loss = UncertaintyNLLLoss()

        num_tasks = 4 if use_uncertainty else 3
        self.adw = ADWLoss(num_tasks=num_tasks) if use_adw else None

    def forward(
        self,
        pred: dict,
        gt_heatmap: torch.Tensor,
        gt_size: torch.Tensor,
        gt_boxes: torch.Tensor,
    ) -> dict:
        """Compute all loss terms and their combination.

        Args:
            pred: model output dict with 'heatmap', 'size', 'boxes',
                optionally 'log_variance'
            gt_heatmap: (B, 1, H, W) ground truth heatmap
            gt_size: (B, 2) ground truth normalized size [w, h]
            gt_boxes: (B, 4) ground truth boxes [cx, cy, w, h] in pixels

        Returns:
            Dict with 'total' (differentiable) and detached 'heatmap',
            'size' and 'giou' terms. When the uncertainty term is
            computed it is also reported, detached, as 'uncertainty'.
        """
        # Heatmap loss
        heatmap_loss = self.focal(pred['heatmap'], gt_heatmap)

        # Size loss on the mean over all spatial positions. `reshape`
        # (rather than `view`) also accepts non-contiguous size maps.
        batch_size = gt_size.shape[0]
        pred_size = pred['size'].reshape(batch_size, 2, -1).mean(dim=-1)
        size_loss = self.l1(pred_size, gt_size)

        # GIoU box loss
        giou_loss = self.giou(pred['boxes'], gt_boxes)

        losses = [heatmap_loss, size_loss, giou_loss]

        # Uncertainty loss: the batch-mean GIoU loss weighted by the
        # per-sample mean log variance.
        unc_loss = None
        if self.use_uncertainty and 'log_variance' in pred:
            log_var = pred['log_variance'].mean(dim=[1, 2, 3])  # (B,)
            unc_loss = (0.5 * torch.exp(-log_var) * giou_loss + 0.5 * log_var).mean()
            losses.append(unc_loss)

        # Combine with ADW or simple sum
        if self.adw is not None:
            total_loss = self.adw(losses)
        else:
            weights = [1.0, 1.0, 2.0, 0.5] if len(losses) == 4 else [1.0, 1.0, 2.0]
            total_loss = sum(w * term for w, term in zip(weights, losses))

        result = {
            'total': total_loss,
            'heatmap': heatmap_loss.detach(),
            'size': size_loss.detach(),
            'giou': giou_loss.detach(),
        }
        if unc_loss is not None:
            result['uncertainty'] = unc_loss.detach()
        return result