| """ |
| Prediction Heads for ViL Tracker. |
| |
| CenterHead: Predicts center heatmap + bounding box size from search features |
| UncertaintyHead: Predicts aleatoric uncertainty for each prediction |
| decode_predictions: Converts heatmaps + sizes to bounding boxes |
| |
| Architecture follows SUTrack/OSTrack corner-free head design: |
| - Search features (B, 256, D) → reshape to (B, D, 16, 16) |
| - Conv layers predict heatmap (B, 1, 16, 16) and size (B, 2, 16, 16) |
| - Peak detection gives center, size gives w/h relative to search region |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
|
|
|
|
class CenterHead(nn.Module):
    """Center-based prediction head.

    Takes flattened search-region tokens and predicts, per feature cell:
        - Center heatmap: (B, 1, H, W) raw logits for the target center
          (apply sigmoid downstream, e.g. in decode_predictions)
        - Size map: (B, 2, H, W) predicted width/height in [0, 1],
          relative to the search region (Sigmoid-bounded)
        - Offset map: (B, 2, H, W) sub-pixel offset refinement in
          [-0.5, 0.5] feature cells (Tanh scaled by 0.5)
    """
    def __init__(self, dim: int = 384, feat_size: int = 16):
        """
        Args:
            dim: channel dimension of the incoming search tokens
            feat_size: side length of the square feature map (N = feat_size**2)
        """
        super().__init__()
        self.feat_size = feat_size

        # Shared convolutional trunk before the three task branches.
        self.stem = nn.Sequential(
            nn.Conv2d(dim, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.GELU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.GELU(),
        )

        # Center-logit branch; no final activation (logits).
        self.heatmap = nn.Sequential(
            nn.Conv2d(256, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 1, 1),
        )

        # Size branch; Sigmoid keeps w/h normalized to [0, 1].
        self.size = nn.Sequential(
            nn.Conv2d(256, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 2, 1),
            nn.Sigmoid(),
        )

        # Offset branch; Tanh gives (-1, 1), scaled to half a cell in forward().
        self.offset = nn.Sequential(
            nn.Conv2d(256, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 2, 1),
            nn.Tanh(),
        )

    def forward(self, search_feat: torch.Tensor) -> dict:
        """
        Args:
            search_feat: (B, N, D) search region features, N = feat_size**2
        Returns:
            dict with 'heatmap' (B, 1, H, W), 'size' (B, 2, H, W),
            'offset' (B, 2, H, W) tensors
        """
        B, _, D = search_feat.shape

        # (B, N, D) -> (B, D, H, W); plain transpose+reshape avoids the
        # einops dependency and does exactly what rearrange did here.
        x = search_feat.transpose(1, 2).reshape(B, D, self.feat_size, self.feat_size)

        feat = self.stem(x)

        return {
            'heatmap': self.heatmap(feat),
            'size': self.size(feat),
            # Scale Tanh output to at most half a feature cell of refinement.
            'offset': self.offset(feat) * 0.5,
        }
|
|
|
|
class UncertaintyHead(nn.Module):
    """Predicts aleatoric uncertainty (log variance) for predictions.

    Used for:
    1. Weighting loss contributions (uncertain predictions get lower weight)
    2. Online tracking confidence (skip update when uncertain)
    3. Kalman filter measurement noise adaptation
    """
    def __init__(self, dim: int = 384, feat_size: int = 16):
        """
        Args:
            dim: channel dimension of the incoming search tokens
            feat_size: side length of the square feature map (N = feat_size**2)
        """
        super().__init__()
        self.feat_size = feat_size
        # Output is an unbounded log-variance, so the last conv has no activation.
        self.net = nn.Sequential(
            nn.Conv2d(dim, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.GELU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 1, 1),
        )

    def forward(self, search_feat: torch.Tensor) -> torch.Tensor:
        """
        Args:
            search_feat: (B, N, D) search features, N = feat_size**2
        Returns:
            log_variance: (B, 1, H, W) predicted log variance
        """
        B, _, D = search_feat.shape
        # (B, N, D) -> (B, D, H, W); native ops replace einops.rearrange.
        x = search_feat.transpose(1, 2).reshape(B, D, self.feat_size, self.feat_size)
        return self.net(x)
|
|
|
|
def decode_predictions(
    heatmap: torch.Tensor,
    size: torch.Tensor,
    offset: torch.Tensor,
    search_size: int = 256,
    feat_size: int = 16,
    hanning_window: torch.Tensor = None,
) -> tuple:
    """Decode head outputs to bounding boxes.

    Args:
        heatmap: (B, 1, H, W) center heatmap logits
        size: (B, 2, H, W) predicted w/h relative to search region
        offset: (B, 2, H, W) sub-pixel offset in feature cells
        search_size: pixel size of search region
        feat_size: spatial size of feature map
        hanning_window: optional (H, W) Hanning window for positional prior penalty

    Returns:
        boxes: (B, 4) predicted boxes in [cx, cy, w, h] format, in pixels
        scores: (B,) confidence scores in (0, 1)
    """
    B = heatmap.shape[0]
    stride = search_size / feat_size

    # Convert logits to probabilities BEFORE applying the window penalty.
    # Multiplying raw logits by a window in [0, 1] is sign-sensitive: negative
    # logits near the border would be pushed TOWARD zero, making edge cells
    # MORE attractive and inverting the intended center prior.
    prob = heatmap.sigmoid()

    prob_penalized = prob
    if hanning_window is not None:
        hw = hanning_window.to(heatmap.device)
        if hw.ndim == 2:
            hw = hw.unsqueeze(0).unsqueeze(0)  # (H, W) -> (1, 1, H, W) broadcast
        prob_penalized = prob * hw

    # Peak location comes from the penalized map; the confidence is read from
    # the unpenalized map so the reported score is not deflated by the prior.
    indices = prob_penalized.view(B, -1).argmax(dim=-1)
    scores = prob.view(B, -1).gather(1, indices.unsqueeze(1)).squeeze(1)

    # Flat index -> 2D grid coordinates (row-major layout).
    cy_idx = indices // feat_size
    cx_idx = indices % feat_size

    # Gather per-location size/offset predictions at the peak cell.
    pred_w = size[:, 0].view(B, -1).gather(1, indices.unsqueeze(1)).squeeze(1)
    pred_h = size[:, 1].view(B, -1).gather(1, indices.unsqueeze(1)).squeeze(1)
    off_x = offset[:, 0].view(B, -1).gather(1, indices.unsqueeze(1)).squeeze(1)
    off_y = offset[:, 1].view(B, -1).gather(1, indices.unsqueeze(1)).squeeze(1)

    # Cell center (+0.5) plus sub-pixel offset, scaled to pixels.
    cx = (cx_idx.float() + 0.5 + off_x) * stride
    cy = (cy_idx.float() + 0.5 + off_y) * stride
    w = pred_w * search_size
    h = pred_h * search_size

    boxes = torch.stack([cx, cy, w, h], dim=-1)
    return boxes, scores
|
|
|
|
def generate_heatmap(
    center: torch.Tensor,
    feat_size: int = 16,
    search_size: int = 256,
    sigma: float = 2.0,
) -> torch.Tensor:
    """Generate ground truth Gaussian heatmap for center supervision.

    Args:
        center: (B, 2) target center in pixel coords (cx, cy) in search region
        feat_size: spatial size of feature map
        search_size: pixel size of search region
        sigma: Gaussian standard deviation in feature map units
    Returns:
        heatmap: (B, 1, feat_size, feat_size) ground truth heatmap
    """
    B = center.shape[0]
    stride = search_size / feat_size

    # decode_predictions maps grid index g to pixel coordinate (g + 0.5)*stride
    # (cell centers). Invert that mapping here so the Gaussian peak lands on
    # the cell whose decoded center equals the ground truth; without the -0.5
    # shift, training targets are misaligned with inference by half a cell
    # (stride / 2 pixels of systematic bias).
    center_feat = center / stride - 0.5

    # Coordinate grid in feature units, (x, y) order to match `center`.
    y = torch.arange(feat_size, device=center.device, dtype=center.dtype)
    x = torch.arange(feat_size, device=center.device, dtype=center.dtype)
    yy, xx = torch.meshgrid(y, x, indexing='ij')
    grid = torch.stack([xx, yy], dim=-1)  # (H, W, 2)

    # Broadcast: (B, 1, 1, 2) against (1, H, W, 2).
    center_feat = center_feat.view(B, 1, 1, 2)
    grid = grid.unsqueeze(0)

    dist_sq = ((grid - center_feat) ** 2).sum(dim=-1)
    heatmap = torch.exp(-dist_sq / (2 * sigma ** 2))

    return heatmap.unsqueeze(1)
|
|
|
|
def generate_size_target(
    size: torch.Tensor,
    search_size: int = 256,
) -> torch.Tensor:
    """Build the normalized ground-truth size target.

    Args:
        size: (B, 2) target [width, height] in pixels
        search_size: pixel size of search region
    Returns:
        size_norm: (B, 2) sizes scaled into [0, 1] relative to the search region
    """
    # Guard against degenerate (zero-sized) boxes before normalizing.
    clamped = torch.clamp(size, min=1)
    return clamped / search_size
|
|
def create_hanning_window(feat_size: int = 16) -> torch.Tensor:
    """Create a 2D Hanning window for positional prior penalty.

    Applied to the classification/heatmap score map before peak detection
    during inference. Suppresses false positives near the edges of the
    search region, where the target is unlikely to be (it should be near center).

    Used by every SOTA tracker (OSTrack, SUTrack, SGLATrack, UETrack, DTPTrack).

    Args:
        feat_size: spatial size of feature map (16 for 256/16 stride)
    Returns:
        (feat_size, feat_size) Hanning window in (0, 1], peak=1 at center
    """
    # torch.hann_window(n, periodic=False) is exactly zero at both endpoints,
    # so a multiplicative penalty would make border cells permanently
    # unselectable. Take the interior of a (feat_size + 2)-point window
    # instead, keeping every cell's weight strictly positive.
    hann_1d = torch.hann_window(feat_size + 2, periodic=False)[1:-1]
    hann_2d = hann_1d.unsqueeze(1) * hann_1d.unsqueeze(0)
    # Normalize so the central peak is exactly 1, matching the documented range.
    return hann_2d / hann_2d.max()
|
|