"""
Prediction Heads for ViL Tracker.

CenterHead: Predicts center heatmap + bounding box size from search features
UncertaintyHead: Predicts aleatoric uncertainty for each prediction
decode_predictions: Converts heatmaps + sizes to bounding boxes

Architecture follows SUTrack/OSTrack corner-free head design:
- Search features (B, 256, D) → reshape to (B, D, 16, 16)
- Conv layers predict heatmap (B, 1, 16, 16) and size (B, 2, 16, 16)
- Peak detection gives center, size gives w/h relative to search region

Coordinate convention shared by ``decode_predictions`` and
``generate_heatmap``: feature cell ``i`` is centred at pixel
``(i + 0.5) * stride`` and the offset head corrects by at most half a cell.
"""

from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class CenterHead(nn.Module):
    """Center-based prediction head.

    Produces:
    - Center heatmap: (B, 1, H, W) raw logits for "target center is here".
      No sigmoid is applied in this module; ``decode_predictions`` applies it.
    - Size map: (B, 2, H, W) predicted width/height at each location,
      squashed to [0, 1] by a sigmoid (relative to the search region)
    - Offset map: (B, 2, H, W) sub-cell offset refinement in [-0.5, 0.5]
    """

    def __init__(self, dim: int = 384, feat_size: int = 16):
        super().__init__()
        self.feat_size = feat_size

        # Shared stem
        self.stem = nn.Sequential(
            nn.Conv2d(dim, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.GELU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.GELU(),
        )

        # Center heatmap head (outputs logits)
        self.heatmap = nn.Sequential(
            nn.Conv2d(256, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 1, 1),
        )

        # Size head (w, h)
        self.size = nn.Sequential(
            nn.Conv2d(256, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 2, 1),
            nn.Sigmoid(),  # size in [0, 1] relative to search region
        )

        # Sub-pixel offset head
        self.offset = nn.Sequential(
            nn.Conv2d(256, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 2, 1),
            nn.Tanh(),  # offset in [-1, 1]; halved in forward()
        )

    def forward(self, search_feat: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            search_feat: (B, N, D) search region features with
                N == feat_size * feat_size (row-major token order)

        Returns:
            dict with 'heatmap', 'size', 'offset' tensors
        """
        # Reshape the token sequence to a spatial grid
        x = rearrange(
            search_feat,
            'b (h w) d -> b d h w',
            h=self.feat_size,
            w=self.feat_size,
        )
        feat = self.stem(x)

        return {
            'heatmap': self.heatmap(feat),      # (B, 1, H, W) logits
            'size': self.size(feat),            # (B, 2, H, W) in [0, 1]
            'offset': self.offset(feat) * 0.5,  # (B, 2, H, W) scaled to [-0.5, 0.5]
        }


class UncertaintyHead(nn.Module):
    """Predicts aleatoric uncertainty (log variance) for predictions.

    Used for:
    1. Weighting loss contributions (uncertain predictions get lower weight)
    2. Online tracking confidence (skip update when uncertain)
    3. Kalman filter measurement noise adaptation
    """

    def __init__(self, dim: int = 384, feat_size: int = 16):
        super().__init__()
        self.feat_size = feat_size

        self.net = nn.Sequential(
            nn.Conv2d(dim, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.GELU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 1, 1),
        )

    def forward(self, search_feat: torch.Tensor) -> torch.Tensor:
        """
        Args:
            search_feat: (B, N, D) search features with
                N == feat_size * feat_size (row-major token order)

        Returns:
            log_variance: (B, 1, H, W) predicted log variance (unbounded)
        """
        x = rearrange(
            search_feat,
            'b (h w) d -> b d h w',
            h=self.feat_size,
            w=self.feat_size,
        )
        return self.net(x)


def decode_predictions(
    heatmap: torch.Tensor,
    size: torch.Tensor,
    offset: torch.Tensor,
    search_size: int = 256,
    feat_size: int = 16,
    hanning_window: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Decode head outputs to bounding boxes.

    Args:
        heatmap: (B, 1, H, W) center heatmap **logits**
        size: (B, 2, H, W) predicted w/h relative to search region
        offset: (B, 2, H, W) sub-pixel offset (in feature-cell units)
        search_size: pixel size of search region
        feat_size: spatial size of feature map (must match H and W)
        hanning_window: optional (H, W) Hanning window for positional prior
            penalty; anything broadcastable to (B, 1, H, W) is accepted

    Returns:
        boxes: (B, 4) predicted boxes in [cx, cy, w, h] format, in pixels
        scores: (B,) confidence scores in [0, 1]. Without a window this is
            the sigmoid of the peak logit; with a window it is the peak of
            the window-weighted probability map.
    """
    stride = search_size / feat_size  # 256/16 = 16

    if hanning_window is None:
        # Peak of the raw logits; sigmoid afterwards keeps full precision
        # for the argmax even when logits are large enough to saturate.
        peak_logits, indices = heatmap.flatten(1).max(dim=-1)  # (B,), (B,)
        scores = peak_logits.sigmoid()
    else:
        # The window must weight probabilities, not logits: scaling a
        # negative logit by a weight < 1 would *raise* it toward 0, so a
        # zero-weight border cell would beat every low-confidence cell.
        window = hanning_window.to(heatmap.device)
        if window.ndim == 2:
            window = window.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
        penalized = heatmap.sigmoid() * window
        scores, indices = penalized.flatten(1).max(dim=-1)  # (B,), (B,)

    # Convert flat index to 2D coordinates
    cy_idx = indices // feat_size  # row
    cx_idx = indices % feat_size   # col

    # Read size and offset at the peak location
    peak = indices.unsqueeze(1)  # (B, 1)

    def _at_peak(channel_map: torch.Tensor) -> torch.Tensor:
        # channel_map: (B, H, W) -> value at each sample's peak, shape (B,)
        return channel_map.flatten(1).gather(1, peak).squeeze(1)

    pred_w = _at_peak(size[:, 0])
    pred_h = _at_peak(size[:, 1])
    off_x = _at_peak(offset[:, 0])
    off_y = _at_peak(offset[:, 1])

    # Convert to pixel coordinates (cell centre + learned sub-cell offset)
    cx = (cx_idx.float() + 0.5 + off_x) * stride
    cy = (cy_idx.float() + 0.5 + off_y) * stride
    w = pred_w * search_size
    h = pred_h * search_size

    boxes = torch.stack([cx, cy, w, h], dim=-1)  # (B, 4)
    return boxes, scores


def generate_heatmap(
    center: torch.Tensor,
    feat_size: int = 16,
    search_size: int = 256,
    sigma: float = 2.0,
) -> torch.Tensor:
    """Generate ground truth Gaussian heatmap for center supervision.

    The Gaussian is evaluated at cell centres ``(i + 0.5)``, matching
    ``decode_predictions``, which maps cell ``i`` back to pixel
    ``(i + 0.5 + offset) * stride``. The brightest cell is therefore the one
    that contains the target center, and the residual always fits in the
    offset head's [-0.5, 0.5] range.

    Args:
        center: (B, 2) target center in pixel coords (cx, cy) in search region
        feat_size: spatial size of feature map
        search_size: pixel size of search region
        sigma: Gaussian standard deviation in feature map units

    Returns:
        heatmap: (B, 1, feat_size, feat_size) ground truth heatmap in (0, 1]
    """
    B = center.shape[0]
    stride = search_size / feat_size

    # Convert pixel center to (continuous) feature map coordinates
    center_feat = (center / stride).reshape(B, 1, 1, 2)

    # Coordinates of every cell centre, in feature map units
    coords = torch.arange(
        feat_size, device=center.device, dtype=center_feat.dtype
    ) + 0.5
    yy, xx = torch.meshgrid(coords, coords, indexing='ij')
    grid = torch.stack([xx, yy], dim=-1).unsqueeze(0)  # (1, H, W, 2) as (x, y)

    # Gaussian around center
    dist_sq = ((grid - center_feat) ** 2).sum(dim=-1)  # (B, H, W)
    heatmap = torch.exp(-dist_sq / (2 * sigma ** 2))

    return heatmap.unsqueeze(1)  # (B, 1, H, W)


def generate_size_target(
    size: torch.Tensor,
    search_size: int = 256,
) -> torch.Tensor:
    """Generate ground truth size target.

    Args:
        size: (B, 2) target [width, height] in pixels
        search_size: pixel size of search region

    Returns:
        size_norm: (B, 2) size relative to the search region. Sizes are
            clamped to at least 1 px before dividing; there is no upper
            clamp, so a target larger than the search region yields > 1.
    """
    return size.clamp(min=1) / search_size


def create_hanning_window(feat_size: int = 16) -> torch.Tensor:
    """Create a 2D Hanning window for positional prior penalty.

    Applied to the classification/heatmap score map before peak detection
    during inference. Suppresses false positives near the edges of the search
    region, where the target is unlikely to be (it should be near center).
    This is the usual positional prior in OSTrack-style trackers.

    Args:
        feat_size: spatial size of feature map (16 for 256/16 stride)

    Returns:
        (feat_size, feat_size) symmetric Hanning window with values in [0, 1].
        For feat_size >= 2 the outermost rows/columns are (numerically) zero.
        The maximum is exactly 1 only for odd feat_size; for even sizes the
        four central cells are slightly below 1.
    """
    hann_1d = torch.hann_window(feat_size, periodic=False)
    hann_2d = hann_1d.unsqueeze(1) * hann_1d.unsqueeze(0)  # outer product
    return hann_2d