"""
Prediction Heads for ViL Tracker.

CenterHead: Predicts center heatmap, bounding box size and sub-pixel offset from search features
UncertaintyHead: Predicts aleatoric uncertainty (log variance) for each prediction
decode_predictions: Converts heatmaps + sizes + offsets to bounding boxes

Architecture follows the SUTrack/OSTrack corner-free head design:
- Search features (B, 256, D) → reshape to (B, D, 16, 16)
- Conv layers predict heatmap (B, 1, 16, 16), size (B, 2, 16, 16) and offset (B, 2, 16, 16)
- Peak detection gives center, size gives w/h relative to search region
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class CenterHead(nn.Module):
    """Center-based prediction head.

    A shared convolutional stem feeds three parallel branches, each emitting
    one map on the ``feat_size x feat_size`` search grid (H = W = feat_size):

    - Center heatmap: (B, 1, H, W) unnormalised center scores (logits); no
      activation is applied by this module.
    - Size map: (B, 2, H, W) predicted width/height in [0, 1] (sigmoid),
      relative to the search region.
    - Offset map: (B, 2, H, W) sub-pixel offset refinement in [-0.5, 0.5]
      (tanh output scaled by 0.5 in ``forward``).
    """

    def __init__(self, dim: int = 384, feat_size: int = 16):
        """
        Args:
            dim: channel dimension D of the incoming search tokens.
            feat_size: side length of the square token grid, so the head
                expects N = feat_size * feat_size tokens.
        """
        super().__init__()
        self.feat_size = feat_size
        # Shared stem
        self.stem = nn.Sequential(
            nn.Conv2d(dim, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.GELU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.GELU(),
        )
        # Branches are built in the order heatmap -> size -> offset with the
        # same layer layout as before, so parameter initialisation order and
        # state-dict keys ("heatmap.0.weight", "size.2.bias", ...) are stable.
        # Center heatmap head (raw scores, no activation)
        self.heatmap = self._make_branch(1)
        # Size head (w, h): size in [0, 1] relative to search region
        self.size = self._make_branch(2, nn.Sigmoid())
        # Sub-pixel offset head: offset in [-1, 1] before scaling in forward
        self.offset = self._make_branch(2, nn.Tanh())

    @staticmethod
    def _make_branch(out_channels: int, activation=None) -> nn.Sequential:
        """Build a 3x3 conv -> GELU -> 1x1 conv branch with an optional final activation."""
        layers = [
            nn.Conv2d(256, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, out_channels, 1),
        ]
        if activation is not None:
            layers.append(activation)
        return nn.Sequential(*layers)

    def forward(self, search_feat: torch.Tensor) -> dict:
        """
        Args:
            search_feat: (B, N, D) search region features with
                N = feat_size * feat_size tokens in row-major order.
        Returns:
            dict with 'heatmap' (B, 1, H, W), 'size' (B, 2, H, W) and
            'offset' (B, 2, H, W) tensors.
        """
        # Reshape the token sequence to a spatial grid
        x = rearrange(search_feat, 'b (h w) d -> b d h w', h=self.feat_size, w=self.feat_size)
        feat = self.stem(x)
        return {
            'heatmap': self.heatmap(feat),      # (B, 1, H, W)
            'size': self.size(feat),            # (B, 2, H, W)
            'offset': self.offset(feat) * 0.5,  # (B, 2, H, W) scaled to [-0.5, 0.5]
        }
class UncertaintyHead(nn.Module):
    """Predicts aleatoric uncertainty (log variance) for predictions.

    A small convolutional stack maps the search-feature grid to a single
    unbounded channel, interpreted as a per-location log variance.

    Intended uses (per the original design notes):
    1. Weighting loss contributions (uncertain predictions get lower weight)
    2. Online tracking confidence (skip update when uncertain)
    3. Kalman filter measurement noise adaptation
    """

    def __init__(self, dim: int = 384, feat_size: int = 16):
        """
        Args:
            dim: channel dimension D of the incoming search tokens.
            feat_size: side length of the square token grid, so the head
                expects N = feat_size * feat_size tokens.
        """
        super().__init__()
        self.feat_size = feat_size
        self.net = nn.Sequential(
            nn.Conv2d(dim, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.GELU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 1, 1),  # no activation: the output is unbounded
        )

    def forward(self, search_feat: torch.Tensor) -> torch.Tensor:
        """
        Args:
            search_feat: (B, N, D) search features with
                N = feat_size * feat_size tokens in row-major order.
        Returns:
            log_variance: (B, 1, H, W) predicted log variance
        """
        # Reshape the token sequence to a spatial grid
        x = rearrange(search_feat, 'b (h w) d -> b d h w', h=self.feat_size, w=self.feat_size)
        return self.net(x)
def decode_predictions(
    heatmap: torch.Tensor,
    size: torch.Tensor,
    offset: torch.Tensor,
    search_size: int = 256,
    feat_size: int = 16,
    hanning_window: torch.Tensor = None,
) -> tuple:
    """Decode head outputs to bounding boxes.

    The peak of the (optionally window-penalised) center heatmap selects one
    grid cell per sample; size and offset are read at that cell and converted
    to pixel units of the search region.

    Args:
        heatmap: (B, 1, H, W) center heatmap logits (sigmoid is applied here)
        size: (B, 2, H, W) predicted w/h relative to search region
        offset: (B, 2, H, W) sub-pixel offset, in feature-map cells
        search_size: pixel size of search region
        feat_size: spatial size of feature map
        hanning_window: optional (H, W) Hanning window for positional prior
            penalty; a tensor that already broadcasts against the heatmap is
            used as-is
    Returns:
        boxes: (B, 4) predicted boxes in [cx, cy, w, h] format, in pixels
        scores: (B,) confidence scores in [0, 1]. Without a window this is the
            sigmoid of the peak logit; with a window it is the peak of the
            window-weighted probability map.
    """
    B = heatmap.shape[0]
    stride = search_size / feat_size  # 256/16 = 16

    if hanning_window is None:
        # Sigmoid is monotonic, so the peak can be taken directly on the logits.
        peak_logits, indices = heatmap.reshape(B, -1).max(dim=-1)  # (B,)
        scores = peak_logits.sigmoid()
    else:
        # The window must scale probabilities, not logits: multiplying a
        # negative logit by a weight in [0, 1] moves it towards 0 (i.e. raises
        # it), which would promote low-weight edge cells instead of
        # suppressing them.
        hw = hanning_window.to(heatmap.device)
        if hw.ndim == 2:
            # (H, W) -> (1, 1, H, W) so it broadcasts over batch and channel
            hw = hw.unsqueeze(0).unsqueeze(0)
        penalized = heatmap.sigmoid() * hw
        scores, indices = penalized.reshape(B, -1).max(dim=-1)  # (B,)

    # Convert flat index to 2D coordinates
    cy_idx = indices // feat_size  # row
    cx_idx = indices % feat_size   # col

    peak = indices.unsqueeze(1)  # (B, 1) gather index

    def _at_peak(channel_map: torch.Tensor) -> torch.Tensor:
        """Read a (B, H, W) map at each sample's peak cell -> (B,)."""
        return channel_map.reshape(B, -1).gather(1, peak).squeeze(1)

    # Get size and offset at peak location
    pred_w = _at_peak(size[:, 0])
    pred_h = _at_peak(size[:, 1])
    off_x = _at_peak(offset[:, 0])
    off_y = _at_peak(offset[:, 1])

    # Convert to pixel coordinates: a cell's centre sits at (idx + 0.5) cells,
    # and the predicted offset refines it within the cell.
    cx = (cx_idx.float() + 0.5 + off_x) * stride
    cy = (cy_idx.float() + 0.5 + off_y) * stride
    w = pred_w * search_size
    h = pred_h * search_size
    boxes = torch.stack([cx, cy, w, h], dim=-1)  # (B, 4)
    return boxes, scores
def generate_heatmap(
    center: torch.Tensor,
    feat_size: int = 16,
    search_size: int = 256,
    sigma: float = 2.0,
) -> torch.Tensor:
    """Generate ground truth Gaussian heatmap for center supervision.

    Grid cell ``i`` is evaluated at its centre ``i + 0.5`` (feature-map
    units), the same convention ``decode_predictions`` uses when it maps a
    peak back to pixels via ``(idx + 0.5 + offset) * stride``. The Gaussian
    therefore peaks in the cell that contains the target center, and the
    residual to that cell's centre stays within half a cell.

    Args:
        center: (B, 2) target center in pixel coords (cx, cy) in search region
        feat_size: spatial size of feature map
        search_size: pixel size of search region
        sigma: Gaussian standard deviation in feature map units
    Returns:
        heatmap: (B, 1, feat_size, feat_size) ground truth heatmap with
            values in (0, 1]; exactly 1 where the target sits on a cell centre
    """
    stride = search_size / feat_size
    # Convert pixel center to continuous feature map coordinates
    center_feat = center / stride  # (B, 2)

    # Cell centres along one axis: 0.5, 1.5, ..., feat_size - 0.5
    cell_centers = torch.arange(feat_size, device=center.device, dtype=center.dtype) + 0.5

    # Separable squared distances, broadcast to (B, H, W)
    dx = cell_centers.reshape(1, 1, -1) - center_feat[:, 0].reshape(-1, 1, 1)  # (B, 1, W)
    dy = cell_centers.reshape(1, -1, 1) - center_feat[:, 1].reshape(-1, 1, 1)  # (B, H, 1)
    dist_sq = dx ** 2 + dy ** 2

    # Gaussian around center
    heatmap = torch.exp(-dist_sq / (2 * sigma ** 2))
    return heatmap.unsqueeze(1)  # (B, 1, H, W)
def generate_size_target(
    size: torch.Tensor,
    search_size: int = 256,
) -> torch.Tensor:
    """Build the regression target for the size branch.

    Sizes are floored at one pixel and then expressed as a fraction of the
    search region side. Only the lower bound is enforced, so a size larger
    than ``search_size`` yields a value above 1.

    Args:
        size: (B, 2) target [width, height] in pixels
        search_size: pixel size of search region
    Returns:
        size_norm: (B, 2) size divided by ``search_size``
    """
    at_least_one_px = torch.clamp(size, min=1)
    return at_least_one_px / search_size
def create_hanning_window(feat_size: int = 16) -> torch.Tensor:
    """Create a 2D Hanning window for positional prior penalty.

    The window is the outer product of a symmetric (non-periodic) 1D Hann
    window with itself. It is meant to weight the center score map before
    peak detection at inference, favouring locations near the middle of the
    search region over its edges (the original notes cite OSTrack, SUTrack,
    SGLATrack, UETrack and DTPTrack as using this kind of prior).

    Args:
        feat_size: spatial size of feature map (16 for 256/16 stride)
    Returns:
        (feat_size, feat_size) window with values in [0, 1]. The symmetric
        Hann profile is 0 at both ends, so the border rows and columns are 0
        for feat_size > 1. The maximum is at the middle and is exactly 1 only
        for odd feat_size.
    """
    profile = torch.hann_window(feat_size, periodic=False)
    # Outer product via broadcasting: column vector times row vector
    return profile[:, None] * profile[None, :]
|