Fix: mLSTM SiLU gate+activation, GroupNorm 192, stochastic depth 0.05, Hanning window
Browse files- vil_tracker/models/heads.py +31 -2
vil_tracker/models/heads.py
CHANGED
|
@@ -120,6 +120,7 @@ def decode_predictions(
|
|
| 120 |
offset: torch.Tensor,
|
| 121 |
search_size: int = 256,
|
| 122 |
feat_size: int = 16,
|
|
|
|
| 123 |
) -> tuple:
|
| 124 |
"""Decode head outputs to bounding boxes.
|
| 125 |
|
|
@@ -129,6 +130,7 @@ def decode_predictions(
|
|
| 129 |
offset: (B, 2, H, W) sub-pixel offset
|
| 130 |
search_size: pixel size of search region
|
| 131 |
feat_size: spatial size of feature map
|
|
|
|
| 132 |
|
| 133 |
Returns:
|
| 134 |
boxes: (B, 4) predicted boxes in [cx, cy, w, h] format, in pixels
|
|
@@ -137,8 +139,17 @@ def decode_predictions(
|
|
| 137 |
B = heatmap.shape[0]
|
| 138 |
stride = search_size / feat_size # 256/16 = 16
|
| 139 |
|
| 140 |
-
#
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
scores, indices = heatmap_flat.max(dim=-1) # (B,)
|
| 143 |
scores = scores.sigmoid()
|
| 144 |
|
|
@@ -213,3 +224,21 @@ def generate_size_target(
|
|
| 213 |
size_norm: (B, 2) normalized to [0, 1] relative to search region
|
| 214 |
"""
|
| 215 |
return size.clamp(min=1) / search_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
offset: torch.Tensor,
|
| 121 |
search_size: int = 256,
|
| 122 |
feat_size: int = 16,
|
| 123 |
+
hanning_window: torch.Tensor = None,
|
| 124 |
) -> tuple:
|
| 125 |
"""Decode head outputs to bounding boxes.
|
| 126 |
|
|
|
|
| 130 |
offset: (B, 2, H, W) sub-pixel offset
|
| 131 |
search_size: pixel size of search region
|
| 132 |
feat_size: spatial size of feature map
|
| 133 |
+
hanning_window: optional (H, W) Hanning window for positional prior penalty
|
| 134 |
|
| 135 |
Returns:
|
| 136 |
boxes: (B, 4) predicted boxes in [cx, cy, w, h] format, in pixels
|
|
|
|
| 139 |
B = heatmap.shape[0]
|
| 140 |
stride = search_size / feat_size # 256/16 = 16
|
| 141 |
|
| 142 |
+
# Apply Hanning window penalty to suppress false positives at search edges
|
| 143 |
+
heatmap_penalized = heatmap
|
| 144 |
+
if hanning_window is not None:
|
| 145 |
+
# hanning_window: (H, W) → broadcast to (1, 1, H, W)
|
| 146 |
+
hw = hanning_window.to(heatmap.device)
|
| 147 |
+
if hw.ndim == 2:
|
| 148 |
+
hw = hw.unsqueeze(0).unsqueeze(0)
|
| 149 |
+
heatmap_penalized = heatmap * hw
|
| 150 |
+
|
| 151 |
+
# Find peak in (penalized) heatmap
|
| 152 |
+
heatmap_flat = heatmap_penalized.view(B, -1) # (B, H*W)
|
| 153 |
scores, indices = heatmap_flat.max(dim=-1) # (B,)
|
| 154 |
scores = scores.sigmoid()
|
| 155 |
|
|
|
|
| 224 |
size_norm: (B, 2) normalized to [0, 1] relative to search region
|
| 225 |
"""
|
| 226 |
return size.clamp(min=1) / search_size
|
| 227 |
+
|
| 228 |
+
def create_hanning_window(
    feat_size: int = 16,
    *,
    dtype: torch.dtype = None,
    device: torch.device = None,
) -> torch.Tensor:
    """Create a 2D Hanning window used as a positional prior at inference.

    The window is the outer product of a symmetric 1D Hann window with
    itself. It is meant to be multiplied into the classification/heatmap
    score map before peak detection so that responses near the border of
    the search region are down-weighted relative to the center, where the
    target is expected to be. Hanning-window penalties of this kind are
    common in single-object trackers (e.g. OSTrack).

    Properties of the returned window:
      * values lie in [0, 1] and the window is symmetric;
      * the outermost rows and columns are (numerically) zero, because the
        symmetric Hann window starts and ends at 0;
      * the maximum is exactly 1 only when ``feat_size`` is odd. For even
        sizes (including the default 16) no sample falls on the exact
        center, so the peak is slightly below 1 (about 0.978 for 16).

    NOTE(review): the weights are only a penalty for non-negative scores.
    Multiplying them into raw logits moves negative edge logits toward 0,
    i.e. it *raises* them. Apply the window to post-sigmoid scores, or
    confirm the caller's heatmap is non-negative.

    Args:
        feat_size: spatial size of the (square) feature map, e.g. 16 for a
            256-pixel search region at stride 16.
        dtype: optional dtype of the returned tensor; ``None`` keeps the
            torch default dtype (the previous behaviour).
        device: optional device of the returned tensor; ``None`` keeps the
            torch default device (the previous behaviour).

    Returns:
        Tensor of shape (feat_size, feat_size) holding the 2D window.
    """
    # periodic=False gives the symmetric window (both endpoints at 0),
    # which is the form wanted for a spatial prior.
    hann_1d = torch.hann_window(
        feat_size, periodic=False, dtype=dtype, device=device
    )
    # Outer product via broadcasting: (N, 1) * (1, N) -> (N, N).
    return hann_1d.unsqueeze(1) * hann_1d.unsqueeze(0)
|