omar-ah commited on
Commit
51d2470
·
verified ·
1 Parent(s): cde1dbf

Fix: mLSTM SiLU gate+activation, GroupNorm 192, stochastic depth 0.05, Hanning window

Browse files
Files changed (1) hide show
  1. test_all.py +19 -2
test_all.py CHANGED
@@ -73,6 +73,10 @@ def test_mlstm_cell():
73
  assert cell_params < 1_000_000, f"Cell has {cell_params:,} params (should be <1M)"
74
  assert cell_params > 800_000, f"Cell has {cell_params:,} params (should be >800K)"
75
 
 
 
 
 
76
  x = torch.randn(2, 20, 384)
77
  y = cell(x)
78
  assert y.shape == (2, 20, 384), f"Cell output shape: {y.shape}"
@@ -190,7 +194,7 @@ test("Backbone (TMoE + integrated FiLM)", test_backbone_tmoe_film)
190
  # Test 6: Prediction Heads
191
  # ============================================================
192
  def test_heads():
193
- from vil_tracker.models.heads import CenterHead, UncertaintyHead, decode_predictions
194
 
195
  center_head = CenterHead(dim=384, feat_size=16)
196
  unc_head = UncertaintyHead(dim=384, feat_size=16)
@@ -205,11 +209,24 @@ def test_heads():
205
  assert preds['size'].shape == (2, 2, 16, 16), f"Size shape: {preds['size'].shape}"
206
  assert preds['offset'].shape == (2, 2, 16, 16), f"Offset shape: {preds['offset'].shape}"
207
 
208
- # Decode
209
  boxes, scores = decode_predictions(preds['heatmap'], preds['size'], preds['offset'])
210
  assert boxes.shape == (2, 4), f"Boxes shape: {boxes.shape}"
211
  assert scores.shape == (2,), f"Scores shape: {scores.shape}"
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  # Uncertainty
214
  log_var = unc_head(search_feat)
215
  assert log_var.shape == (2, 1, 16, 16), f"Log variance shape: {log_var.shape}"
 
73
  assert cell_params < 1_000_000, f"Cell has {cell_params:,} params (should be <1M)"
74
  assert cell_params > 800_000, f"Cell has {cell_params:,} params (should be >800K)"
75
 
76
+ # Verify GroupNorm uses 192 groups (num_proj_heads), not 4 (num_heads)
77
+ assert cell.outnorm.num_groups == 192, f"GroupNorm should have 192 groups, got {cell.outnorm.num_groups}"
78
+ print(f" GroupNorm groups: {cell.outnorm.num_groups} (correct: per-projection-head)")
79
+
80
  x = torch.randn(2, 20, 384)
81
  y = cell(x)
82
  assert y.shape == (2, 20, 384), f"Cell output shape: {y.shape}"
 
194
  # Test 6: Prediction Heads
195
  # ============================================================
196
  def test_heads():
197
+ from vil_tracker.models.heads import CenterHead, UncertaintyHead, decode_predictions, create_hanning_window
198
 
199
  center_head = CenterHead(dim=384, feat_size=16)
200
  unc_head = UncertaintyHead(dim=384, feat_size=16)
 
209
  assert preds['size'].shape == (2, 2, 16, 16), f"Size shape: {preds['size'].shape}"
210
  assert preds['offset'].shape == (2, 2, 16, 16), f"Offset shape: {preds['offset'].shape}"
211
 
212
+ # Decode without Hanning
213
  boxes, scores = decode_predictions(preds['heatmap'], preds['size'], preds['offset'])
214
  assert boxes.shape == (2, 4), f"Boxes shape: {boxes.shape}"
215
  assert scores.shape == (2,), f"Scores shape: {scores.shape}"
216
 
217
+ # Decode WITH Hanning window
218
+ hann = create_hanning_window(16)
219
+ assert hann.shape == (16, 16), f"Hanning shape: {hann.shape}"
220
+ assert abs(hann[8, 8].item() - 1.0) < 0.05, f"Hanning center should be ~1.0, got {hann[8, 8]}"
221
+ assert hann[0, 0].item() < 0.01, f"Hanning corner should be ~0, got {hann[0, 0]}"
222
+
223
+ boxes_h, scores_h = decode_predictions(preds['heatmap'], preds['size'], preds['offset'],
224
+ hanning_window=hann)
225
+ assert boxes_h.shape == (2, 4), f"Hanning boxes shape: {boxes_h.shape}"
226
+ print(f" Hanning window: center={hann[8,8]:.3f}, corner={hann[0,0]:.6f}")
227
+ print(f" Without Hanning: box={boxes[0].tolist()}, score={scores[0].item():.4f}")
228
+ print(f" With Hanning: box={boxes_h[0].tolist()}, score={scores_h[0].item():.4f}")
229
+
230
  # Uncertainty
231
  log_var = unc_head(search_feat)
232
  assert log_var.shape == (2, 1, 16, 16), f"Log variance shape: {log_var.shape}"