Fix: mLSTM SiLU gate+activation, GroupNorm 192, stochastic depth 0.05, Hanning window
Files changed: test_all.py (+19 −2)
|
@@ -73,6 +73,10 @@ def test_mlstm_cell():
 73     assert cell_params < 1_000_000, f"Cell has {cell_params:,} params (should be <1M)"
 74     assert cell_params > 800_000, f"Cell has {cell_params:,} params (should be >800K)"
 75
 76     x = torch.randn(2, 20, 384)
 77     y = cell(x)
 78     assert y.shape == (2, 20, 384), f"Cell output shape: {y.shape}"
|
@@ -190,7 +194,7 @@ test("Backbone (TMoE + integrated FiLM)", test_backbone_tmoe_film)
 190    # Test 6: Prediction Heads
 191    # ============================================================
 192    def test_heads():
-193        from vil_tracker.models.heads import CenterHead, UncertaintyHead, decode_predictions
 194
 195        center_head = CenterHead(dim=384, feat_size=16)
 196        unc_head = UncertaintyHead(dim=384, feat_size=16)
|
|
@@ -205,11 +209,24 @@ def test_heads():
 205        assert preds['size'].shape == (2, 2, 16, 16), f"Size shape: {preds['size'].shape}"
 206        assert preds['offset'].shape == (2, 2, 16, 16), f"Offset shape: {preds['offset'].shape}"
 207
-208        # Decode
 209        boxes, scores = decode_predictions(preds['heatmap'], preds['size'], preds['offset'])
 210        assert boxes.shape == (2, 4), f"Boxes shape: {boxes.shape}"
 211        assert scores.shape == (2,), f"Scores shape: {scores.shape}"
 212
 213        # Uncertainty
 214        log_var = unc_head(search_feat)
 215        assert log_var.shape == (2, 1, 16, 16), f"Log variance shape: {log_var.shape}"
|
|
|
 73     assert cell_params < 1_000_000, f"Cell has {cell_params:,} params (should be <1M)"
 74     assert cell_params > 800_000, f"Cell has {cell_params:,} params (should be >800K)"
 75
+76     # Verify GroupNorm uses 192 groups (num_proj_heads), not 4 (num_heads)
+77     assert cell.outnorm.num_groups == 192, f"GroupNorm should have 192 groups, got {cell.outnorm.num_groups}"
+78     print(f"  GroupNorm groups: {cell.outnorm.num_groups} (correct: per-projection-head)")
+79
 80     x = torch.randn(2, 20, 384)
 81     y = cell(x)
 82     assert y.shape == (2, 20, 384), f"Cell output shape: {y.shape}"
|
|
|
 194    # Test 6: Prediction Heads
 195    # ============================================================
 196    def test_heads():
+197        from vil_tracker.models.heads import CenterHead, UncertaintyHead, decode_predictions, create_hanning_window
 198
 199        center_head = CenterHead(dim=384, feat_size=16)
 200        unc_head = UncertaintyHead(dim=384, feat_size=16)
|
|
|
 209        assert preds['size'].shape == (2, 2, 16, 16), f"Size shape: {preds['size'].shape}"
 210        assert preds['offset'].shape == (2, 2, 16, 16), f"Offset shape: {preds['offset'].shape}"
 211
+212        # Decode without Hanning
 213        boxes, scores = decode_predictions(preds['heatmap'], preds['size'], preds['offset'])
 214        assert boxes.shape == (2, 4), f"Boxes shape: {boxes.shape}"
 215        assert scores.shape == (2,), f"Scores shape: {scores.shape}"
 216
+217        # Decode WITH Hanning window
+218        hann = create_hanning_window(16)
+219        assert hann.shape == (16, 16), f"Hanning shape: {hann.shape}"
+220        assert abs(hann[8, 8].item() - 1.0) < 0.05, f"Hanning center should be ~1.0, got {hann[8, 8]}"
+221        assert hann[0, 0].item() < 0.01, f"Hanning corner should be ~0, got {hann[0, 0]}"
+222
+223        boxes_h, scores_h = decode_predictions(preds['heatmap'], preds['size'], preds['offset'],
+224                                               hanning_window=hann)
+225        assert boxes_h.shape == (2, 4), f"Hanning boxes shape: {boxes_h.shape}"
+226        print(f"  Hanning window: center={hann[8,8]:.3f}, corner={hann[0,0]:.6f}")
+227        print(f"  Without Hanning: box={boxes[0].tolist()}, score={scores[0].item():.4f}")
+228        print(f"  With Hanning: box={boxes_h[0].tolist()}, score={scores_h[0].item():.4f}")
+229
 230        # Uncertainty
 231        log_var = unc_head(search_feat)
 232        assert log_var.shape == (2, 1, 16, 16), f"Log variance shape: {log_var.shape}"
|