# Source snapshot: Hugging Face commit bb51611 (verified), author omar-ah.
# Commit message: "Fix: mLSTM SiLU gate+activation, GroupNorm 192,
# stochastic depth 0.05, Hanning window".
"""
Prediction Heads for ViL Tracker.
CenterHead: Predicts center heatmap + bounding box size from search features
UncertaintyHead: Predicts aleatoric uncertainty for each prediction
decode_predictions: Converts heatmaps + sizes to bounding boxes
Architecture follows SUTrack/OSTrack corner-free head design:
- Search features (B, 256, D) → reshape to (B, D, 16, 16)
- Conv layers predict heatmap (B, 1, 16, 16) and size (B, 2, 16, 16)
- Peak detection gives center, size gives w/h relative to search region
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class CenterHead(nn.Module):
    """Center-based prediction head.

    Produces, per spatial cell of the search-region feature grid:
    - Center heatmap: (B, 1, H, W) logits for the target-center location
    - Size map: (B, 2, H, W) predicted width/height in [0, 1] relative to
      the search region
    - Offset map: (B, 2, H, W) sub-pixel center offset in [-0.5, 0.5]
    """

    def __init__(self, dim: int = 384, feat_size: int = 16):
        """
        Args:
            dim: channel dimension of the incoming search features
            feat_size: side length of the square feature grid (N = feat_size**2)
        """
        super().__init__()
        self.feat_size = feat_size
        # Shared conv stem feeding all three prediction branches
        self.stem = nn.Sequential(
            nn.Conv2d(dim, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.GELU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.GELU(),
        )
        # Center heatmap head (raw logits; sigmoid is applied downstream)
        self.heatmap = nn.Sequential(
            nn.Conv2d(256, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 1, 1),
        )
        # Size head (w, h)
        self.size = nn.Sequential(
            nn.Conv2d(256, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 2, 1),
            nn.Sigmoid(),  # size in [0, 1] relative to search region
        )
        # Sub-pixel offset head
        self.offset = nn.Sequential(
            nn.Conv2d(256, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 2, 1),
            nn.Tanh(),  # offset in [-1, 1] (sub-pixel correction)
        )

    def forward(self, search_feat: torch.Tensor) -> dict:
        """
        Args:
            search_feat: (B, N, D) search region features, N = feat_size**2

        Returns:
            dict with 'heatmap' (B, 1, H, W), 'size' (B, 2, H, W) and
            'offset' (B, 2, H, W) tensors, H = W = feat_size
        """
        B, _, D = search_feat.shape
        # (B, N, D) -> (B, D, H, W): torch-native equivalent of
        # rearrange(search_feat, 'b (h w) d -> b d h w')
        x = search_feat.transpose(1, 2).reshape(B, D, self.feat_size, self.feat_size)
        feat = self.stem(x)
        return {
            'heatmap': self.heatmap(feat),          # (B, 1, 16, 16)
            'size': self.size(feat),                # (B, 2, 16, 16)
            'offset': self.offset(feat) * 0.5,      # tanh output scaled to [-0.5, 0.5]
        }
class UncertaintyHead(nn.Module):
    """Predicts aleatoric uncertainty (log variance) for predictions.

    Used for:
    1. Weighting loss contributions (uncertain predictions get lower weight)
    2. Online tracking confidence (skip update when uncertain)
    3. Kalman filter measurement noise adaptation
    """

    def __init__(self, dim: int = 384, feat_size: int = 16):
        """
        Args:
            dim: channel dimension of the incoming search features
            feat_size: side length of the square feature grid (N = feat_size**2)
        """
        super().__init__()
        self.feat_size = feat_size
        self.net = nn.Sequential(
            nn.Conv2d(dim, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.GELU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 1, 1),
        )

    def forward(self, search_feat: torch.Tensor) -> torch.Tensor:
        """
        Args:
            search_feat: (B, N, D) search features, N = feat_size**2

        Returns:
            log_variance: (B, 1, H, W) predicted log variance (unbounded)
        """
        B, _, D = search_feat.shape
        # (B, N, D) -> (B, D, H, W): torch-native equivalent of
        # rearrange(search_feat, 'b (h w) d -> b d h w')
        x = search_feat.transpose(1, 2).reshape(B, D, self.feat_size, self.feat_size)
        return self.net(x)
def decode_predictions(
    heatmap: torch.Tensor,
    size: torch.Tensor,
    offset: torch.Tensor,
    search_size: int = 256,
    feat_size: int = 16,
    hanning_window: torch.Tensor = None,
) -> tuple:
    """Decode head outputs to bounding boxes.

    Args:
        heatmap: (B, 1, H, W) center heatmap logits (pre-sigmoid)
        size: (B, 2, H, W) predicted w/h relative to search region
        offset: (B, 2, H, W) sub-pixel offset
        search_size: pixel size of search region
        feat_size: spatial size of feature map
        hanning_window: optional (H, W) Hanning window for positional prior penalty

    Returns:
        boxes: (B, 4) predicted boxes in [cx, cy, w, h] format, in pixels
        scores: (B,) confidence scores (sigmoid probability at the peak)
    """
    B = heatmap.shape[0]
    stride = search_size / feat_size  # 256/16 = 16

    # Convert logits to probabilities BEFORE applying the window. Windowing
    # raw logits is a bug: logits can be negative, so a zero window value at
    # the border would RAISE the score there (to 0) instead of suppressing it.
    probs = heatmap.sigmoid()
    probs_penalized = probs
    if hanning_window is not None:
        hw = hanning_window.to(heatmap.device)
        if hw.ndim == 2:
            hw = hw.unsqueeze(0).unsqueeze(0)  # (H, W) -> (1, 1, H, W)
        probs_penalized = probs * hw

    # Find peak in (penalized) probability map
    indices = probs_penalized.view(B, -1).argmax(dim=-1)  # (B,)
    # Report the unpenalized probability at the chosen location as confidence
    idx = indices.unsqueeze(1)  # (B, 1) for gather
    scores = probs.view(B, -1).gather(1, idx).squeeze(1)

    # Convert flat index to 2D coordinates
    cy_idx = indices // feat_size  # row
    cx_idx = indices % feat_size   # col

    # Get size and offset at peak location
    pred_w = size[:, 0].view(B, -1).gather(1, idx).squeeze(1)  # (B,)
    pred_h = size[:, 1].view(B, -1).gather(1, idx).squeeze(1)
    off_x = offset[:, 0].view(B, -1).gather(1, idx).squeeze(1)
    off_y = offset[:, 1].view(B, -1).gather(1, idx).squeeze(1)

    # Cell i covers pixels [i*stride, (i+1)*stride); its center is (i+0.5)*stride
    cx = (cx_idx.float() + 0.5 + off_x) * stride
    cy = (cy_idx.float() + 0.5 + off_y) * stride
    w = pred_w * search_size
    h = pred_h * search_size

    boxes = torch.stack([cx, cy, w, h], dim=-1)  # (B, 4)
    return boxes, scores
def generate_heatmap(
    center: torch.Tensor,
    feat_size: int = 16,
    search_size: int = 256,
    sigma: float = 2.0,
) -> torch.Tensor:
    """Generate ground truth Gaussian heatmap for center supervision.

    Args:
        center: (B, 2) target center in pixel coords (cx, cy) in search region
        feat_size: spatial size of feature map
        search_size: pixel size of search region
        sigma: Gaussian standard deviation in feature map units

    Returns:
        heatmap: (B, 1, feat_size, feat_size) ground truth heatmap, peak = 1
        when the center falls exactly on a cell center
    """
    B = center.shape[0]
    stride = search_size / feat_size
    # Convert pixel center to feature-grid coordinates. decode_predictions
    # maps grid index i -> pixel (i + 0.5) * stride (cell centers), so the
    # inverse must subtract 0.5; without it the GT peak is biased by half a
    # cell relative to the decoder.
    center_feat = center / stride - 0.5  # (B, 2) in feature map coords
    # Create coordinate grid of cell indices
    y = torch.arange(feat_size, device=center.device, dtype=center.dtype)
    x = torch.arange(feat_size, device=center.device, dtype=center.dtype)
    yy, xx = torch.meshgrid(y, x, indexing='ij')
    grid = torch.stack([xx, yy], dim=-1)  # (H, W, 2)
    # Gaussian around center
    center_feat = center_feat.view(B, 1, 1, 2)
    grid = grid.unsqueeze(0)  # (1, H, W, 2)
    dist_sq = ((grid - center_feat) ** 2).sum(dim=-1)  # (B, H, W)
    heatmap = torch.exp(-dist_sq / (2 * sigma ** 2))
    return heatmap.unsqueeze(1)  # (B, 1, H, W)
def generate_size_target(
    size: torch.Tensor,
    search_size: int = 256,
) -> torch.Tensor:
    """Normalize ground-truth box sizes by the search-region resolution.

    Args:
        size: (B, 2) target [width, height] in pixels
        search_size: pixel size of the (square) search region

    Returns:
        size_norm: (B, 2) sizes in [0, 1] relative to the search region;
        inputs are floored at 1 pixel before normalizing to avoid degenerate
        zero-size targets
    """
    floored = size.clamp(min=1)
    return floored / search_size
def create_hanning_window(feat_size: int = 16) -> torch.Tensor:
    """Build a 2D Hanning window used as a positional prior at inference.

    The window is multiplied into the classification/heatmap score map
    before peak detection, damping responses near the edges of the search
    region, where the target is unlikely to be (it should be near center).
    Used by every SOTA tracker (OSTrack, SUTrack, SGLATrack, UETrack,
    DTPTrack).

    Args:
        feat_size: spatial size of feature map (16 for 256/16 stride)

    Returns:
        (feat_size, feat_size) Hanning window in [0, 1], peak=1 at center
    """
    one_d = torch.hann_window(feat_size, periodic=False)
    # Lift the 1D window to 2D via the outer product
    return torch.outer(one_d, one_d)