omar-ah commited on
Commit
bb51611
·
verified ·
1 Parent(s): 3e094bb

Fix: mLSTM SiLU gate+activation, GroupNorm 192, stochastic depth 0.05, Hanning window

Browse files
Files changed (1) hide show
  1. vil_tracker/models/heads.py +31 -2
vil_tracker/models/heads.py CHANGED
@@ -120,6 +120,7 @@ def decode_predictions(
120
  offset: torch.Tensor,
121
  search_size: int = 256,
122
  feat_size: int = 16,
 
123
  ) -> tuple:
124
  """Decode head outputs to bounding boxes.
125
 
@@ -129,6 +130,7 @@ def decode_predictions(
129
  offset: (B, 2, H, W) sub-pixel offset
130
  search_size: pixel size of search region
131
  feat_size: spatial size of feature map
 
132
 
133
  Returns:
134
  boxes: (B, 4) predicted boxes in [cx, cy, w, h] format, in pixels
@@ -137,8 +139,17 @@ def decode_predictions(
137
  B = heatmap.shape[0]
138
  stride = search_size / feat_size # 256/16 = 16
139
 
140
- # Find peak in heatmap
141
- heatmap_flat = heatmap.view(B, -1) # (B, H*W)
 
 
 
 
 
 
 
 
 
142
  scores, indices = heatmap_flat.max(dim=-1) # (B,)
143
  scores = scores.sigmoid()
144
 
@@ -213,3 +224,21 @@ def generate_size_target(
213
  size_norm: (B, 2) normalized to [0, 1] relative to search region
214
  """
215
  return size.clamp(min=1) / search_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  offset: torch.Tensor,
121
  search_size: int = 256,
122
  feat_size: int = 16,
123
+ hanning_window: torch.Tensor = None,
124
  ) -> tuple:
125
  """Decode head outputs to bounding boxes.
126
 
 
130
  offset: (B, 2, H, W) sub-pixel offset
131
  search_size: pixel size of search region
132
  feat_size: spatial size of feature map
133
+ hanning_window: optional (H, W) Hanning window for positional prior penalty
134
 
135
  Returns:
136
  boxes: (B, 4) predicted boxes in [cx, cy, w, h] format, in pixels
 
139
  B = heatmap.shape[0]
140
  stride = search_size / feat_size # 256/16 = 16
141
 
142
+ # Apply Hanning window penalty to suppress false positives at search edges
143
+ heatmap_penalized = heatmap
144
+ if hanning_window is not None:
145
+ # hanning_window: (H, W) → broadcast to (1, 1, H, W)
146
+ hw = hanning_window.to(heatmap.device)
147
+ if hw.ndim == 2:
148
+ hw = hw.unsqueeze(0).unsqueeze(0)
149
+ heatmap_penalized = heatmap * hw
150
+
151
+ # Find peak in (penalized) heatmap
152
+ heatmap_flat = heatmap_penalized.view(B, -1) # (B, H*W)
153
  scores, indices = heatmap_flat.max(dim=-1) # (B,)
154
  scores = scores.sigmoid()
155
 
 
224
  size_norm: (B, 2) normalized to [0, 1] relative to search region
225
  """
226
  return size.clamp(min=1) / search_size
227
+
228
+ def create_hanning_window(feat_size: int = 16) -> torch.Tensor:
229
+ """Create a 2D Hanning window for positional prior penalty.
230
+
231
+ Applied to the classification/heatmap score map before peak detection
232
+ during inference. Suppresses false positives near the edges of the
233
+ search region, where the target is unlikely to be (it should be near center).
234
+
235
+ Used by every SOTA tracker (OSTrack, SUTrack, SGLATrack, UETrack, DTPTrack).
236
+
237
+ Args:
238
+ feat_size: spatial size of feature map (16 for 256/16 stride)
239
+ Returns:
240
+ (feat_size, feat_size) Hanning window in [0, 1], peak=1 at center
241
+ """
242
+ hann_1d = torch.hann_window(feat_size, periodic=False)
243
+ hann_2d = hann_1d.unsqueeze(1) * hann_1d.unsqueeze(0) # outer product
244
+ return hann_2d