omar-ah's picture
Upload vil_tracker/training/losses.py with huggingface_hub
01f95f3 verified
"""
Loss functions for ViL Tracker training.
Includes:
- FocalLoss: for center heatmap prediction (handles class imbalance)
- GIoULoss: for bounding box regression
- UncertaintyNLLLoss: uncertainty-aware NLL loss
- MemoryContrastiveLoss: contrastive loss for mLSTM memory states
- AFKDDistillationLoss: attention-free knowledge distillation
- ADWLoss: adaptive dynamic weighting for multi-task loss
- CombinedTrackingLoss: combines the heatmap, size, GIoU and optional uncertainty losses, with optional learned (ADW) weighting
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    """Penalty-reduced focal loss for center-heatmap prediction (CornerNet-style).

    Handles the extreme foreground/background imbalance of center heatmaps,
    where only a tiny fraction of positions (e.g. ~1/256) are positive.

    Args:
        alpha: focusing exponent applied to the prediction error term.
        beta: exponent of the ``(1 - target)`` factor that down-weights
            negatives lying close to a ground-truth peak.
    """

    def __init__(self, alpha: float = 2.0, beta: float = 4.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss.

        Args:
            pred: (B, 1, H, W) predicted heatmap logits (pre-sigmoid).
            target: (B, 1, H, W) ground-truth Gaussian heatmap. Positions equal
                to exactly 1 are positives; every position below 1 is a negative.

        Returns:
            Scalar tensor: summed positive and negative terms divided by the
            number of positives (clamped to at least 1).
        """
        prob = torch.sigmoid(pred)
        # log(p) and log(1 - p) are evaluated in logit space. Clamping the
        # probability before taking the log would zero the gradient for
        # saturated logits, so confidently wrong predictions could never recover.
        log_p = F.logsigmoid(pred)
        log_one_minus_p = F.logsigmoid(-pred)

        pos_mask = target.eq(1).float()
        neg_mask = target.lt(1).float()

        # Positives: penalize low confidence, focused by (1 - p)^alpha.
        pos_loss = -((1 - prob) ** self.alpha) * log_p * pos_mask

        # Negatives: penalize high confidence, down-weighted near the GT peak.
        neg_weight = (1 - target) ** self.beta
        neg_loss = -(prob ** self.alpha) * log_one_minus_p * neg_weight * neg_mask

        num_pos = pos_mask.sum().clamp(min=1)
        return (pos_loss.sum() + neg_loss.sum()) / num_pos
class GIoULoss(nn.Module):
    """Generalized-IoU loss for bounding-box regression.

    Unlike plain IoU, GIoU still provides a gradient when the predicted and
    ground-truth boxes do not overlap. This makes it a stronger signal than L1
    for box prediction.
    """

    @staticmethod
    def _to_corners(boxes: torch.Tensor):
        """Split (B, 4) [cx, cy, w, h] boxes into their x1, y1, x2, y2 columns."""
        cx, cy = boxes[:, 0], boxes[:, 1]
        half_w, half_h = boxes[:, 2] / 2, boxes[:, 3] / 2
        return cx - half_w, cy - half_h, cx + half_w, cy + half_h

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the batch-mean of ``1 - GIoU``.

        Args:
            pred: (B, 4) predicted boxes as [cx, cy, w, h].
            target: (B, 4) ground-truth boxes as [cx, cy, w, h].
        """
        px1, py1, px2, py2 = self._to_corners(pred)
        gx1, gy1, gx2, gy2 = self._to_corners(target)

        def _area(x1, y1, x2, y2):
            # Degenerate (negative-extent) rectangles contribute zero area.
            return (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)

        overlap = _area(
            torch.max(px1, gx1), torch.max(py1, gy1),
            torch.min(px2, gx2), torch.min(py2, gy2),
        )
        union = _area(px1, py1, px2, py2) + _area(gx1, gy1, gx2, gy2) - overlap
        iou = overlap / union.clamp(min=1e-6)

        # Smallest axis-aligned box enclosing both inputs.
        hull = _area(
            torch.min(px1, gx1), torch.min(py1, gy1),
            torch.max(px2, gx2), torch.max(py2, gy2),
        )
        giou = iou - (hull - union) / hull.clamp(min=1e-6)
        return (1 - giou).mean()
class UncertaintyNLLLoss(nn.Module):
    """Heteroscedastic (uncertainty-aware) negative log-likelihood loss.

    Each squared residual is scaled by the predicted precision, and the
    log-variance itself is added as a penalty:

        L = 0.5 * exp(-s) * |pred - target|^2 + 0.5 * s,   with s = log(variance)

    The result is averaged over all elements.
    """

    def forward(self, pred: torch.Tensor, target: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: (B, ...) predictions.
            target: (B, ...) regression targets.
            log_var: (B, ...) predicted log-variance ``s``.

        Returns:
            Scalar mean loss.
        """
        residual_sq = (pred - target).pow(2)
        scaled_residual = torch.exp(-log_var) * residual_sq
        per_element = 0.5 * (scaled_residual + log_var)
        return per_element.mean()
class MemoryContrastiveLoss(nn.Module):
    """InfoNCE-style contrastive loss over mLSTM memory features.

    Row ``i`` of ``feat_a`` and row ``i`` of ``feat_b`` form a positive pair.
    All other rows of ``feat_b`` act as negatives for row ``i`` of ``feat_a``.
    This pulls same-target states together across frames and pushes the
    other batch entries apart.

    Args:
        temperature: divisor applied to the cosine-similarity logits.
    """

    def __init__(self, temperature: float = 0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, feat_a: torch.Tensor, feat_b: torch.Tensor) -> torch.Tensor:
        """
        Args:
            feat_a: (B, D) features from frame A.
            feat_b: (B, D) features from frame B (same targets, same order).

        Returns:
            Scalar cross-entropy where the diagonal entries are the correct class.
        """
        unit_a = F.normalize(feat_a, dim=-1)
        unit_b = F.normalize(feat_b, dim=-1)

        # (B, B) matrix of temperature-scaled cosine similarities.
        logits = (unit_a @ unit_b.T) / self.temperature
        diagonal_targets = torch.arange(unit_a.size(0), device=unit_a.device)
        return F.cross_entropy(logits, diagonal_targets)
class AFKDDistillationLoss(nn.Module):
    """Attention-Free Knowledge Distillation loss.

    Intended for distilling an MCITrack-B256 teacher into a ViL-S student.
    It combines two terms:

    * feature distillation: MSE between projected student features and the
      (detached) teacher features;
    * response distillation (optional): temperature-scaled KL divergence
      between flattened student and teacher logits.

    Args:
        student_dim: channel size of the student features.
        teacher_dim: channel size of the teacher features.
        temperature: softmax temperature ``T`` for response distillation. The
            KL term is multiplied by ``T**2`` to keep its gradient scale
            comparable across temperatures.
    """

    def __init__(self, student_dim: int = 384, teacher_dim: int = 768, temperature: float = 4.0):
        super().__init__()
        self.temperature = temperature
        # Learnable projector mapping student features into the teacher's width.
        self.projector = nn.Sequential(
            nn.Linear(student_dim, teacher_dim),
            nn.GELU(),
            nn.Linear(teacher_dim, teacher_dim),
        )

    def forward(
        self,
        student_feat: torch.Tensor,
        teacher_feat: torch.Tensor,
        student_logits: torch.Tensor = None,
        teacher_logits: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Args:
            student_feat: (B, S, D_s) student features.
            teacher_feat: (B, S, D_t) teacher features (no gradient flows into them).
            student_logits: optional (B, ...) student predictions.
            teacher_logits: optional (B, ...) teacher predictions.

        Returns:
            Scalar loss: the feature term alone, or the feature term plus the
            response term when both logits tensors are given.
        """
        student_proj = self.projector(student_feat)
        feat_loss = F.mse_loss(student_proj, teacher_feat.detach())

        if student_logits is None or teacher_logits is None:
            return feat_loss

        T = self.temperature
        # reshape (not view): head outputs are often permuted or sliced and
        # therefore non-contiguous, which .view() rejects.
        s_flat = student_logits.reshape(student_logits.shape[0], -1)
        t_flat = teacher_logits.detach().reshape(teacher_logits.shape[0], -1)
        s_log_prob = F.log_softmax(s_flat / T, dim=-1)
        t_prob = F.softmax(t_flat / T, dim=-1)
        resp_loss = F.kl_div(s_log_prob, t_prob, reduction='batchmean') * (T ** 2)
        return feat_loss + resp_loss
class ADWLoss(nn.Module):
    """Adaptive Dynamic Weighting for multi-task loss (homoscedastic uncertainty).

    Learns one log-variance ``s_k = log(sigma_k^2)`` per task and combines the
    task losses as::

        total = sum_k exp(-s_k) * L_k + s_k

    That is, task weight ``w_k = 1 / sigma_k^2`` with regularizer
    ``log(sigma_k^2)``. This is exactly twice the
    ``L_k / (2 sigma_k^2) + log(sigma_k)`` form, so it has the same optimum
    for ``sigma_k``.

    Args:
        num_tasks: number of learnable log-variances. ``forward`` may receive
            fewer losses than this; the unused trailing parameters then get
            no gradient.
    """

    def __init__(self, num_tasks: int = 4):
        super().__init__()
        # s_k = log(sigma_k^2); zeros mean every task starts with weight 1.
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses: list) -> torch.Tensor:
        """Combine per-task losses with learned weights.

        Args:
            losses: list of scalar loss tensors, one per task, in a fixed order.

        Returns:
            Weighted sum of the losses plus the log-variance regularizers, as a
            scalar tensor (zero when ``losses`` is empty).

        Raises:
            ValueError: if more losses are given than there are task weights.
        """
        if len(losses) > self.log_vars.numel():
            raise ValueError(
                f"ADWLoss got {len(losses)} losses but was built with "
                f"num_tasks={self.log_vars.numel()}"
            )
        total = self.log_vars.new_zeros(())
        for log_var, loss in zip(self.log_vars, losses):
            total = total + torch.exp(-log_var) * loss + log_var
        return total
class CombinedTrackingLoss(nn.Module):
    """Combined loss for tracker training.

    Combines:
        - Focal loss on the center heatmap
        - L1 loss on size regression
        - GIoU loss on predicted boxes
        - Optional: an uncertainty-weighted GIoU term, used when
          ``use_uncertainty`` is set and the model output contains
          ``'log_variance'``

    The task losses are merged either by learned ADW weighting
    (``use_adw=True``) or by a fixed weighted sum.

    Args:
        use_uncertainty: enable the uncertainty term (and size ADW for 4 tasks).
        use_adw: use learned ADW weighting instead of fixed weights.
    """

    def __init__(self, use_uncertainty: bool = True, use_adw: bool = True):
        super().__init__()
        self.focal = FocalLoss()
        self.giou = GIoULoss()
        self.l1 = nn.L1Loss()
        self.use_uncertainty = use_uncertainty
        if use_uncertainty:
            # NOTE(review): instantiated but not called by forward(); the
            # uncertainty term is computed inline below. Kept for attribute
            # compatibility.
            self.uncertainty_loss = UncertaintyNLLLoss()
        num_tasks = 4 if use_uncertainty else 3
        self.adw = ADWLoss(num_tasks=num_tasks) if use_adw else None

    def forward(
        self,
        pred: dict,
        gt_heatmap: torch.Tensor,
        gt_size: torch.Tensor,
        gt_boxes: torch.Tensor,
    ) -> dict:
        """
        Args:
            pred: model output dict with 'heatmap', 'size', 'boxes', and
                optionally 'log_variance'.
            gt_heatmap: (B, 1, H, W) ground-truth heatmap.
            gt_size: (B, 2) ground-truth normalized size [w, h].
            gt_boxes: (B, 4) ground-truth boxes [cx, cy, w, h] in pixels.

        Returns:
            Dict with 'total' (differentiable combined loss) and detached
            'heatmap', 'size' and 'giou' components. It also contains
            'uncertainty' when that term was computed.
        """
        heatmap_loss = self.focal(pred['heatmap'], gt_heatmap)

        # Size loss: the predicted size map is averaged over ALL spatial
        # positions, not sampled at the peak. reshape (not view) tolerates a
        # non-contiguous size map.
        # NOTE(review): sampling at the GT peak may be the intended behavior.
        B = gt_size.shape[0]
        pred_size = pred['size'].reshape(B, 2, -1).mean(dim=-1)
        size_loss = self.l1(pred_size, gt_size)

        giou_loss = self.giou(pred['boxes'], gt_boxes)

        losses = [heatmap_loss, size_loss, giou_loss]

        unc_loss = None
        if self.use_uncertainty and 'log_variance' in pred:
            # Assumes 'log_variance' is a 4-D (B, C, H, W) map, reduced to one
            # value per sample.
            log_var = pred['log_variance'].mean(dim=[1, 2, 3])  # (B,)
            # NOTE(review): giou_loss is already a batch-mean scalar, so each
            # sample's log-variance is fitted to the batch-average error rather
            # than its own box error.
            unc_loss = (0.5 * torch.exp(-log_var) * giou_loss + 0.5 * log_var).mean()
            losses.append(unc_loss)

        if self.adw is not None:
            total_loss = self.adw(losses)
        else:
            # Fixed weights: heatmap, size, giou[, uncertainty].
            weights = [1.0, 1.0, 2.0, 0.5] if len(losses) == 4 else [1.0, 1.0, 2.0]
            total_loss = sum(w * l for w, l in zip(weights, losses))

        result = {
            'total': total_loss,
            'heatmap': heatmap_loss.detach(),
            'size': size_loss.detach(),
            'giou': giou_loss.detach(),
        }
        if unc_loss is not None:
            result['uncertainty'] = unc_loss.detach()
        return result