Fix test_all.py: audit corrections
Browse files- test_all.py +226 -35
test_all.py
CHANGED
|
@@ -1,20 +1,23 @@
|
|
| 1 |
"""
|
| 2 |
Comprehensive test suite for ViL Tracker.
|
| 3 |
|
| 4 |
-
|
| 5 |
1. mLSTM Cell (LinearHeadwiseExpand correctness + param count)
|
| 6 |
-
2. mLSTM Block (full block
|
| 7 |
3. TMoE MLP
|
| 8 |
4. Backbone (standard, small depth)
|
| 9 |
-
5. Backbone (with TMoE, medium depth)
|
| 10 |
6. Prediction Heads
|
| 11 |
7. FiLM Temporal Modulation
|
| 12 |
8. Full Tracker (small depth for speed)
|
| 13 |
-
9. Loss Functions
|
| 14 |
-
10. Kalman Filter
|
| 15 |
11. Dataset (synthetic)
|
| 16 |
-
12. Training Step (mini forward + backward)
|
| 17 |
13. Model Summary (FULL depth=24, constraint check)
|
|
|
|
|
|
|
|
|
|
| 18 |
"""
|
| 19 |
|
| 20 |
import sys
|
|
@@ -94,6 +97,9 @@ def test_mlstm_block():
|
|
| 94 |
params = count_params(block)
|
| 95 |
print(f" mLSTMBlock params: {params:,} ({params/1e6:.3f}M)")
|
| 96 |
|
|
|
|
|
|
|
|
|
|
| 97 |
x = torch.randn(2, 20, 384)
|
| 98 |
y = block(x)
|
| 99 |
assert y.shape == (2, 20, 384), f"Block output shape: {y.shape}"
|
|
@@ -102,7 +108,7 @@ def test_mlstm_block():
|
|
| 102 |
diff = (y - x).abs().mean().item()
|
| 103 |
print(f" Residual diff from input: {diff:.4f}")
|
| 104 |
|
| 105 |
-
test("mLSTM Block", test_mlstm_block)
|
| 106 |
|
| 107 |
|
| 108 |
# ============================================================
|
|
@@ -149,23 +155,35 @@ test("Backbone (standard, depth=4)", test_backbone_small)
|
|
| 149 |
|
| 150 |
|
| 151 |
# ============================================================
|
| 152 |
-
# Test 5: Backbone
|
| 153 |
# ============================================================
|
| 154 |
-
def
|
| 155 |
from vil_tracker.models.backbone import ViLBackbone
|
|
|
|
| 156 |
|
| 157 |
-
backbone = ViLBackbone(dim=384, depth=6, patch_size=16, tmoe_blocks=2,
|
|
|
|
| 158 |
params = count_params(backbone)
|
| 159 |
print(f" Backbone (depth=6, TMoE=2) params: {params:,} ({params/1e6:.3f}M)")
|
| 160 |
|
|
|
|
|
|
|
|
|
|
| 161 |
template = torch.randn(1, 3, 128, 128)
|
| 162 |
search = torch.randn(1, 3, 256, 256)
|
| 163 |
|
| 164 |
-
|
|
|
|
| 165 |
assert t_feat.shape == (1, 64, 384), f"Template feat shape: {t_feat.shape}"
|
| 166 |
assert s_feat.shape == (1, 256, 384), f"Search feat shape: {s_feat.shape}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
-
test("Backbone (
|
| 169 |
|
| 170 |
|
| 171 |
# ============================================================
|
|
@@ -234,8 +252,12 @@ def test_film():
|
|
| 234 |
# Update context and try again
|
| 235 |
manager.update_temporal_context(x)
|
| 236 |
y = manager.modulate(x, block_idx=5) # block 5 → (5+1)%6==0, should modulate
|
| 237 |
-
# With temporal context, output should differ
|
| 238 |
assert y.shape == (2, 20, 384)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
test("FiLM Temporal Modulation", test_film)
|
| 241 |
|
|
@@ -258,27 +280,38 @@ def test_full_tracker_small():
|
|
| 258 |
template = torch.randn(2, 3, 128, 128)
|
| 259 |
search = torch.randn(2, 3, 256, 256)
|
| 260 |
|
| 261 |
-
|
| 262 |
-
|
| 263 |
assert output['heatmap'].shape == (2, 1, 16, 16)
|
| 264 |
-
assert output['size'].shape == (2, 2, 16, 16)
|
| 265 |
assert output['boxes'].shape == (2, 4)
|
| 266 |
assert output['scores'].shape == (2,)
|
| 267 |
assert 'log_variance' in output
|
| 268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
print(f" Predicted boxes: {output['boxes'][0].tolist()}")
|
| 270 |
print(f" Scores: {output['scores'].tolist()}")
|
| 271 |
|
| 272 |
-
test("Full Tracker (depth=4)", test_full_tracker_small)
|
| 273 |
|
| 274 |
|
| 275 |
# ============================================================
|
| 276 |
-
# Test 9: Loss Functions
|
| 277 |
# ============================================================
|
| 278 |
def test_losses():
|
| 279 |
from vil_tracker.training.losses import (
|
| 280 |
FocalLoss, GIoULoss, UncertaintyNLLLoss,
|
| 281 |
-
MemoryContrastiveLoss,
|
|
|
|
| 282 |
)
|
| 283 |
|
| 284 |
B = 4
|
|
@@ -300,13 +333,36 @@ def test_losses():
|
|
| 300 |
print(f" GIoU loss: {gl.item():.4f}")
|
| 301 |
assert 0 <= gl.item() <= 2, f"GIoU loss out of range: {gl.item()}"
|
| 302 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
# Contrastive loss
|
| 304 |
contrastive = MemoryContrastiveLoss()
|
| 305 |
feat_a = torch.randn(B, 384)
|
| 306 |
-
feat_b = feat_a + torch.randn(B, 384) * 0.1
|
| 307 |
cl = contrastive(feat_a, feat_b)
|
| 308 |
print(f" Contrastive loss: {cl.item():.4f}")
|
| 309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
# Combined loss
|
| 311 |
combined = CombinedTrackingLoss()
|
| 312 |
pred = {
|
|
@@ -319,7 +375,7 @@ def test_losses():
|
|
| 319 |
print(f" Combined loss: {loss_dict['total'].item():.4f}")
|
| 320 |
assert loss_dict['total'].item() > 0
|
| 321 |
|
| 322 |
-
test("Loss Functions", test_losses)
|
| 323 |
|
| 324 |
|
| 325 |
# ============================================================
|
|
@@ -336,12 +392,12 @@ def test_kalman():
|
|
| 336 |
kf.initialize(init_box)
|
| 337 |
assert kf.initialized
|
| 338 |
|
| 339 |
-
# Predict + update cycle
|
| 340 |
for i in range(10):
|
| 341 |
pred = kf.predict()
|
| 342 |
assert len(pred) == 4, f"Prediction length: {len(pred)}"
|
| 343 |
|
| 344 |
-
# Simulate noisy measurement
|
| 345 |
noise = np.random.randn(4) * 2
|
| 346 |
meas = init_box + np.array([i * 2, i * 1, 0, 0]) + noise
|
| 347 |
kf.update(meas, uncertainty=1.0)
|
|
@@ -349,17 +405,23 @@ def test_kalman():
|
|
| 349 |
state = kf.get_state()
|
| 350 |
print(f" Final state: cx={state[0]:.1f}, cy={state[1]:.1f}, w={state[2]:.1f}, h={state[3]:.1f}")
|
| 351 |
assert state[2] > 0 and state[3] > 0, "Width/height should be positive"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
|
| 353 |
-
test("Kalman Filter", test_kalman)
|
| 354 |
|
| 355 |
|
| 356 |
# ============================================================
|
| 357 |
# Test 11: Dataset (synthetic)
|
| 358 |
# ============================================================
|
| 359 |
def test_dataset():
|
| 360 |
-
from vil_tracker.data.dataset import TrackingDataset
|
| 361 |
|
| 362 |
-
ds =
|
| 363 |
assert len(ds) == 100
|
| 364 |
|
| 365 |
sample = ds[0]
|
|
@@ -376,16 +438,22 @@ def test_dataset():
|
|
| 376 |
hard_sample = ds[42]
|
| 377 |
print(f" Easy center: {easy_sample['boxes'][:2].tolist()}")
|
| 378 |
print(f" Hard center: {hard_sample['boxes'][:2].tolist()}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
|
| 380 |
-
test("Dataset (synthetic)", test_dataset)
|
| 381 |
|
| 382 |
|
| 383 |
# ============================================================
|
| 384 |
-
# Test 12: Training Step (
|
| 385 |
# ============================================================
|
| 386 |
def test_training_step():
|
| 387 |
from vil_tracker.models.tracker import ViLTracker, get_default_config
|
| 388 |
-
from vil_tracker.training.losses import CombinedTrackingLoss
|
| 389 |
from vil_tracker.models.heads import generate_heatmap
|
| 390 |
|
| 391 |
config = get_default_config()
|
|
@@ -396,6 +464,7 @@ def test_training_step():
|
|
| 396 |
model = ViLTracker(config)
|
| 397 |
model.train()
|
| 398 |
loss_fn = CombinedTrackingLoss()
|
|
|
|
| 399 |
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
| 400 |
|
| 401 |
B = 2
|
|
@@ -408,27 +477,33 @@ def test_training_step():
|
|
| 408 |
gt_size = torch.tensor([[0.2, 0.3], [0.15, 0.25]])
|
| 409 |
gt_boxes = torch.tensor([[128.0, 128.0, 51.2, 76.8], [100.0, 150.0, 38.4, 64.0]])
|
| 410 |
|
| 411 |
-
# Forward
|
| 412 |
-
pred = model(template, search)
|
| 413 |
loss_dict = loss_fn(pred, gt_heatmap, gt_size, gt_boxes)
|
| 414 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
# Backward
|
| 416 |
-
|
| 417 |
|
| 418 |
# Check gradients exist
|
| 419 |
has_grads = sum(1 for p in model.parameters() if p.grad is not None)
|
| 420 |
total_params_count = sum(1 for p in model.parameters())
|
| 421 |
-
print(f"
|
| 422 |
print(f" Params with gradients: {has_grads}/{total_params_count}")
|
| 423 |
|
| 424 |
# Optimizer step
|
| 425 |
optimizer.step()
|
| 426 |
optimizer.zero_grad()
|
| 427 |
|
| 428 |
-
assert
|
| 429 |
assert has_grads > 0
|
| 430 |
|
| 431 |
-
test("Training Step (
|
| 432 |
|
| 433 |
|
| 434 |
# ============================================================
|
|
@@ -455,6 +530,122 @@ def test_model_summary():
|
|
| 455 |
test("Model Summary (full depth=24)", test_model_summary)
|
| 456 |
|
| 457 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
# ============================================================
|
| 459 |
# Summary
|
| 460 |
# ============================================================
|
|
|
|
| 1 |
"""
|
| 2 |
Comprehensive test suite for ViL Tracker.
|
| 3 |
|
| 4 |
+
16 tests covering all components:
|
| 5 |
1. mLSTM Cell (LinearHeadwiseExpand correctness + param count)
|
| 6 |
+
2. mLSTM Block (full block without MLP)
|
| 7 |
3. TMoE MLP
|
| 8 |
4. Backbone (standard, small depth)
|
| 9 |
+
5. Backbone (with TMoE + integrated FiLM, medium depth)
|
| 10 |
6. Prediction Heads
|
| 11 |
7. FiLM Temporal Modulation
|
| 12 |
8. Full Tracker (small depth for speed)
|
| 13 |
+
9. Loss Functions (all 6)
|
| 14 |
+
10. Kalman Filter (8-state, adaptive)
|
| 15 |
11. Dataset (synthetic)
|
| 16 |
+
12. Training Step (mini forward + backward with temporal)
|
| 17 |
13. Model Summary (FULL depth=24, constraint check)
|
| 18 |
+
14. Online Tracker (full inference pipeline)
|
| 19 |
+
15. Augmentation pipeline
|
| 20 |
+
16. ACL curriculum integration
|
| 21 |
"""
|
| 22 |
|
| 23 |
import sys
|
|
|
|
| 97 |
params = count_params(block)
|
| 98 |
print(f" mLSTMBlock params: {params:,} ({params/1e6:.3f}M)")
|
| 99 |
|
| 100 |
+
# No separate MLP — should be ~920K, same as cell + LayerNorm
|
| 101 |
+
assert params < 1_050_000, f"Block has {params:,} params (should be <1.05M without MLP)"
|
| 102 |
+
|
| 103 |
x = torch.randn(2, 20, 384)
|
| 104 |
y = block(x)
|
| 105 |
assert y.shape == (2, 20, 384), f"Block output shape: {y.shape}"
|
|
|
|
| 108 |
diff = (y - x).abs().mean().item()
|
| 109 |
print(f" Residual diff from input: {diff:.4f}")
|
| 110 |
|
| 111 |
+
test("mLSTM Block (no separate MLP)", test_mlstm_block)
|
| 112 |
|
| 113 |
|
| 114 |
# ============================================================
|
|
|
|
| 155 |
|
| 156 |
|
| 157 |
# ============================================================
|
| 158 |
+
# Test 5: Backbone with TMoE + integrated FiLM
|
| 159 |
# ============================================================
|
| 160 |
+
def test_backbone_tmoe_film():
    """Smoke-test ViLBackbone built with TMoE blocks and a FiLM temporal manager.

    Two forward passes are run with the same TemporalModulationManager.  The
    template/search feature shapes are asserted after the first pass and the
    template shape after the second.  Whether the template features changed
    between the passes is only printed, not asserted.
    """
    from vil_tracker.models.backbone import ViLBackbone
    from vil_tracker.models.film_temporal import TemporalModulationManager

    backbone_kwargs = dict(
        dim=384,
        depth=6,
        patch_size=16,
        tmoe_blocks=2,
        num_experts=4,
        film_interval=3,
    )
    net = ViLBackbone(**backbone_kwargs)
    params = count_params(net)
    print(f" Backbone (depth=6, TMoE=2) params: {params:,} ({params/1e6:.3f}M)")

    # Manager handed to the backbone on every forward call below.
    film_manager = TemporalModulationManager(dim=384, num_blocks=6, modulation_interval=3)

    z_img = torch.randn(1, 3, 128, 128)
    x_img = torch.randn(1, 3, 256, 256)

    expected_template_shape = (1, 64, 384)
    expected_search_shape = (1, 256, 384)

    # Pass 1: the manager has just been constructed, so no context was fed to it yet.
    t_feat, s_feat = net(z_img, x_img, temporal_mod_manager=film_manager)
    assert t_feat.shape == expected_template_shape, f"Template feat shape: {t_feat.shape}"
    assert s_feat.shape == expected_search_shape, f"Search feat shape: {s_feat.shape}"

    # Pass 2: same inputs, same manager.
    # NOTE(review): presumably pass 1 populated the manager's temporal context —
    # confirm in ViLBackbone; this test only reports the outcome.
    t_feat2, _ = net(z_img, x_img, temporal_mod_manager=film_manager)
    assert t_feat2.shape == expected_template_shape
    print(f" FiLM modulation active: features differ = {not torch.allclose(t_feat, t_feat2, atol=1e-5)}")
|
| 185 |
|
| 186 |
+
test("Backbone (TMoE + integrated FiLM)", test_backbone_tmoe_film)
|
| 187 |
|
| 188 |
|
| 189 |
# ============================================================
|
|
|
|
| 252 |
# Update context and try again
|
| 253 |
manager.update_temporal_context(x)
|
| 254 |
y = manager.modulate(x, block_idx=5) # block 5 → (5+1)%6==0, should modulate
|
|
|
|
| 255 |
assert y.shape == (2, 20, 384)
|
| 256 |
+
|
| 257 |
+
# Test reset
|
| 258 |
+
manager.reset()
|
| 259 |
+
y = manager.modulate(x, block_idx=5)
|
| 260 |
+
assert torch.allclose(y, x), "After reset, should return unchanged"
|
| 261 |
|
| 262 |
test("FiLM Temporal Modulation", test_film)
|
| 263 |
|
|
|
|
| 280 |
template = torch.randn(2, 3, 128, 128)
|
| 281 |
search = torch.randn(2, 3, 256, 256)
|
| 282 |
|
| 283 |
+
# Test without temporal
|
| 284 |
+
output = tracker(template, search, use_temporal=False)
|
| 285 |
assert output['heatmap'].shape == (2, 1, 16, 16)
|
|
|
|
| 286 |
assert output['boxes'].shape == (2, 4)
|
| 287 |
assert output['scores'].shape == (2,)
|
| 288 |
assert 'log_variance' in output
|
| 289 |
|
| 290 |
+
# Test with temporal (first frame: no context)
|
| 291 |
+
output_t1 = tracker(template, search, use_temporal=True)
|
| 292 |
+
assert output_t1['boxes'].shape == (2, 4)
|
| 293 |
+
|
| 294 |
+
# Second frame: temporal context available
|
| 295 |
+
output_t2 = tracker(template, search, use_temporal=True)
|
| 296 |
+
assert output_t2['boxes'].shape == (2, 4)
|
| 297 |
+
|
| 298 |
+
# Reset temporal
|
| 299 |
+
tracker.reset_temporal()
|
| 300 |
+
|
| 301 |
print(f" Predicted boxes: {output['boxes'][0].tolist()}")
|
| 302 |
print(f" Scores: {output['scores'].tolist()}")
|
| 303 |
|
| 304 |
+
test("Full Tracker (depth=4, with temporal)", test_full_tracker_small)
|
| 305 |
|
| 306 |
|
| 307 |
# ============================================================
|
| 308 |
+
# Test 9: Loss Functions (all 6)
|
| 309 |
# ============================================================
|
| 310 |
def test_losses():
|
| 311 |
from vil_tracker.training.losses import (
|
| 312 |
FocalLoss, GIoULoss, UncertaintyNLLLoss,
|
| 313 |
+
MemoryContrastiveLoss, AFKDDistillationLoss,
|
| 314 |
+
ADWLoss, CombinedTrackingLoss,
|
| 315 |
)
|
| 316 |
|
| 317 |
B = 4
|
|
|
|
| 333 |
print(f" GIoU loss: {gl.item():.4f}")
|
| 334 |
assert 0 <= gl.item() <= 2, f"GIoU loss out of range: {gl.item()}"
|
| 335 |
|
| 336 |
+
# Uncertainty NLL loss
|
| 337 |
+
unc = UncertaintyNLLLoss()
|
| 338 |
+
pred_v = torch.randn(B, 4)
|
| 339 |
+
target_v = torch.randn(B, 4)
|
| 340 |
+
log_var = torch.zeros(B, 4) # unit variance
|
| 341 |
+
ul = unc(pred_v, target_v, log_var)
|
| 342 |
+
print(f" Uncertainty NLL loss: {ul.item():.4f}")
|
| 343 |
+
assert ul.item() > 0
|
| 344 |
+
|
| 345 |
# Contrastive loss
|
| 346 |
contrastive = MemoryContrastiveLoss()
|
| 347 |
feat_a = torch.randn(B, 384)
|
| 348 |
+
feat_b = feat_a + torch.randn(B, 384) * 0.1
|
| 349 |
cl = contrastive(feat_a, feat_b)
|
| 350 |
print(f" Contrastive loss: {cl.item():.4f}")
|
| 351 |
|
| 352 |
+
# AFKD distillation loss
|
| 353 |
+
afkd = AFKDDistillationLoss(student_dim=384, teacher_dim=768)
|
| 354 |
+
student_feat = torch.randn(B, 256, 384)
|
| 355 |
+
teacher_feat = torch.randn(B, 256, 768)
|
| 356 |
+
dl = afkd(student_feat, teacher_feat)
|
| 357 |
+
print(f" AFKD distillation loss: {dl.item():.4f}")
|
| 358 |
+
assert dl.item() > 0
|
| 359 |
+
|
| 360 |
+
# ADW loss
|
| 361 |
+
adw = ADWLoss(num_tasks=3)
|
| 362 |
+
losses = [torch.tensor(1.0), torch.tensor(0.5), torch.tensor(2.0)]
|
| 363 |
+
al = adw(losses)
|
| 364 |
+
print(f" ADW loss: {al.item():.4f}")
|
| 365 |
+
|
| 366 |
# Combined loss
|
| 367 |
combined = CombinedTrackingLoss()
|
| 368 |
pred = {
|
|
|
|
| 375 |
print(f" Combined loss: {loss_dict['total'].item():.4f}")
|
| 376 |
assert loss_dict['total'].item() > 0
|
| 377 |
|
| 378 |
+
test("Loss Functions (all 6)", test_losses)
|
| 379 |
|
| 380 |
|
| 381 |
# ============================================================
|
|
|
|
| 392 |
kf.initialize(init_box)
|
| 393 |
assert kf.initialized
|
| 394 |
|
| 395 |
+
# Predict + update cycle with moving target
|
| 396 |
for i in range(10):
|
| 397 |
pred = kf.predict()
|
| 398 |
assert len(pred) == 4, f"Prediction length: {len(pred)}"
|
| 399 |
|
| 400 |
+
# Simulate noisy measurement of linearly moving target
|
| 401 |
noise = np.random.randn(4) * 2
|
| 402 |
meas = init_box + np.array([i * 2, i * 1, 0, 0]) + noise
|
| 403 |
kf.update(meas, uncertainty=1.0)
|
|
|
|
| 405 |
state = kf.get_state()
|
| 406 |
print(f" Final state: cx={state[0]:.1f}, cy={state[1]:.1f}, w={state[2]:.1f}, h={state[3]:.1f}")
|
| 407 |
assert state[2] > 0 and state[3] > 0, "Width/height should be positive"
|
| 408 |
+
|
| 409 |
+
# Test outlier rejection (chi-squared gating)
|
| 410 |
+
kf.update(np.array([500.0, 500.0, 50.0, 50.0]), uncertainty=1.0) # Far outlier
|
| 411 |
+
state_after = kf.get_state()
|
| 412 |
+
# State should NOT have jumped to 500,500
|
| 413 |
+
assert state_after[0] < 200, f"Outlier should be rejected, cx={state_after[0]}"
|
| 414 |
|
| 415 |
+
test("Kalman Filter (8-state, adaptive)", test_kalman)
|
| 416 |
|
| 417 |
|
| 418 |
# ============================================================
|
| 419 |
# Test 11: Dataset (synthetic)
|
| 420 |
# ============================================================
|
| 421 |
def test_dataset():
|
| 422 |
+
from vil_tracker.data.dataset import SyntheticTrackingDataset, TrackingDataset
|
| 423 |
|
| 424 |
+
ds = SyntheticTrackingDataset(length=100)
|
| 425 |
assert len(ds) == 100
|
| 426 |
|
| 427 |
sample = ds[0]
|
|
|
|
| 438 |
hard_sample = ds[42]
|
| 439 |
print(f" Easy center: {easy_sample['boxes'][:2].tolist()}")
|
| 440 |
print(f" Hard center: {hard_sample['boxes'][:2].tolist()}")
|
| 441 |
+
|
| 442 |
+
# Test backward-compatible alias
|
| 443 |
+
ds2 = TrackingDataset(synthetic=True, synthetic_length=50)
|
| 444 |
+
assert len(ds2) == 50
|
| 445 |
+
sample2 = ds2[0]
|
| 446 |
+
assert sample2['template'].shape == (3, 128, 128)
|
| 447 |
|
| 448 |
+
test("Dataset (synthetic + backward compat)", test_dataset)
|
| 449 |
|
| 450 |
|
| 451 |
# ============================================================
|
| 452 |
+
# Test 12: Training Step (with temporal modulation)
|
| 453 |
# ============================================================
|
| 454 |
def test_training_step():
|
| 455 |
from vil_tracker.models.tracker import ViLTracker, get_default_config
|
| 456 |
+
from vil_tracker.training.losses import CombinedTrackingLoss, MemoryContrastiveLoss
|
| 457 |
from vil_tracker.models.heads import generate_heatmap
|
| 458 |
|
| 459 |
config = get_default_config()
|
|
|
|
| 464 |
model = ViLTracker(config)
|
| 465 |
model.train()
|
| 466 |
loss_fn = CombinedTrackingLoss()
|
| 467 |
+
contrastive_loss = MemoryContrastiveLoss()
|
| 468 |
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
| 469 |
|
| 470 |
B = 2
|
|
|
|
| 477 |
gt_size = torch.tensor([[0.2, 0.3], [0.15, 0.25]])
|
| 478 |
gt_boxes = torch.tensor([[128.0, 128.0, 51.2, 76.8], [100.0, 150.0, 38.4, 64.0]])
|
| 479 |
|
| 480 |
+
# Forward WITH temporal modulation
|
| 481 |
+
pred = model(template, search, use_temporal=True)
|
| 482 |
loss_dict = loss_fn(pred, gt_heatmap, gt_size, gt_boxes)
|
| 483 |
|
| 484 |
+
# Add contrastive loss
|
| 485 |
+
t_pooled = pred['template_feat'].mean(dim=1)
|
| 486 |
+
s_pooled = pred['search_feat'].mean(dim=1)
|
| 487 |
+
c_loss = contrastive_loss(t_pooled, s_pooled)
|
| 488 |
+
total_loss = loss_dict['total'] + 0.1 * c_loss
|
| 489 |
+
|
| 490 |
# Backward
|
| 491 |
+
total_loss.backward()
|
| 492 |
|
| 493 |
# Check gradients exist
|
| 494 |
has_grads = sum(1 for p in model.parameters() if p.grad is not None)
|
| 495 |
total_params_count = sum(1 for p in model.parameters())
|
| 496 |
+
print(f" Total loss: {total_loss.item():.4f} (tracking={loss_dict['total'].item():.4f}, contr={c_loss.item():.4f})")
|
| 497 |
print(f" Params with gradients: {has_grads}/{total_params_count}")
|
| 498 |
|
| 499 |
# Optimizer step
|
| 500 |
optimizer.step()
|
| 501 |
optimizer.zero_grad()
|
| 502 |
|
| 503 |
+
assert total_loss.item() > 0
|
| 504 |
assert has_grads > 0
|
| 505 |
|
| 506 |
+
test("Training Step (with temporal + contrastive)", test_training_step)
|
| 507 |
|
| 508 |
|
| 509 |
# ============================================================
|
|
|
|
| 530 |
test("Model Summary (full depth=24)", test_model_summary)
|
| 531 |
|
| 532 |
|
| 533 |
+
# ============================================================
|
| 534 |
+
# Test 14: Online Tracker (inference pipeline)
|
| 535 |
+
# ============================================================
|
| 536 |
+
def test_online_tracker():
    """Drive OnlineTracker through a short synthetic sequence.

    A depth-2 ViLTracker without TMoE blocks is wrapped in an OnlineTracker on
    CPU.  The tracker is initialized on a random-noise 480x640 frame containing
    a red rectangle, then fed five more noise frames in which the rectangle is
    shifted by (5, 3) pixels per frame.  Only the structure of each returned
    box is checked: four entries, each an int/float/NumPy floating value.
    """
    from vil_tracker.inference.online_tracker import OnlineTracker
    from vil_tracker.models.tracker import ViLTracker, get_default_config

    config = get_default_config()
    for key, value in (("depth", 2), ("tmoe_blocks", 0), ("film_interval", 2)):
        config[key] = value

    model = ViLTracker(config)
    model.eval()

    tracker = OnlineTracker(model, device='cpu', template_size=128, search_size=256)

    # Frame geometry and the initial target box as [x, y, w, h].
    H, W = 480, 640
    init_bbox = [200, 200, 60, 80]
    x, y, w, h = init_bbox

    def _render(left, top):
        # Fresh noise frame with the solid red target pasted at (left, top).
        canvas = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)
        canvas[top:top + h, left:left + w] = [255, 0, 0]
        return canvas

    tracker.initialize(_render(x, y), init_bbox)

    # Five tracking steps; the target drifts right and down each frame.
    for i in range(1, 6):
        bbox = tracker.track(_render(x + i * 5, y + i * 3))
        assert len(bbox) == 4, f"Bbox should have 4 elements, got {len(bbox)}"
        assert all(isinstance(v, (int, float, np.floating)) for v in bbox), f"Bbox values: {bbox}"
        print(f" Frame {i}: predicted [{bbox[0]:.1f}, {bbox[1]:.1f}, {bbox[2]:.1f}, {bbox[3]:.1f}]")

    print(f" Online tracker completed 5-frame sequence")
|
| 576 |
+
|
| 577 |
+
test("Online Tracker (inference pipeline)", test_online_tracker)
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
# ============================================================
|
| 581 |
+
# Test 15: Augmentation pipeline
|
| 582 |
+
# ============================================================
|
| 583 |
+
def test_augmentation():
    """Check TrackingAugmentation output shapes and horizontal-flip box handling.

    The flip probability is forced to 1.0 so the box update is exercised on
    every call.  The box centre is deliberately placed off the vertical
    mid-line of the search image: with cx == W / 2 the mirrored coordinate
    (W - cx) equals the original one, so the check could never detect a flip
    that forgets to update the box.
    """
    from vil_tracker.data.dataset import TrackingAugmentation

    aug = TrackingAugmentation(
        brightness=0.2,
        contrast=0.2,
        horizontal_flip_prob=1.0,  # Force flip to test bbox update
        grayscale_prob=0.0,
        blur_prob=0.0,
    )

    template = torch.rand(3, 128, 128)
    search = torch.rand(3, 256, 256)

    # [cx, cy, w, h]; cx is off-centre so that a mirrored box is distinguishable.
    original_cx = 100.0
    bbox = torch.tensor([original_cx, 128.0, 50.0, 50.0])

    # Snapshot and expectation are taken before the call so they stay valid
    # even if the augmentation were to modify `bbox` in place.
    original_box = bbox.tolist()
    expected_cx = search.shape[-1] - original_cx  # new_cx = W - old_cx = 256 - 100 = 156

    t_aug, s_aug, b_aug = aug(template, search, bbox)

    assert t_aug.shape == (3, 128, 128), f"Aug template shape: {t_aug.shape}"
    assert s_aug.shape == (3, 256, 256), f"Aug search shape: {s_aug.shape}"
    assert b_aug.shape == (4,), f"Aug bbox shape: {b_aug.shape}"

    print(f" Original bbox: {original_box}")
    print(f" Augmented bbox: {b_aug.tolist()}")
    assert abs(b_aug[0].item() - expected_cx) < 1.0, (
        f"Flipped cx should be ~{expected_cx:.0f}, got {b_aug[0].item():.1f}"
    )
|
| 608 |
+
|
| 609 |
+
test("Augmentation pipeline", test_augmentation)
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
# ============================================================
|
| 613 |
+
# Test 16: ACL curriculum integration
|
| 614 |
+
# ============================================================
|
| 615 |
+
def test_acl_curriculum():
    """Check that raising the ACL difficulty does not shrink target offsets.

    The mean distance of the box centre from the search-image centre (128, 128)
    is measured over the first 20 samples at difficulty 0.0 and again at 1.0.
    Sampling is stochastic, so the comparison is deliberately loose: the hard
    average only has to exceed half of the easy average.  The comparison is
    asserted (not just printed) so the test can actually fail.
    """
    from vil_tracker.data.dataset import SyntheticTrackingDataset

    def _mean_center_offset(dataset, num_samples=20, center=128.0):
        # Average Euclidean distance of the box centre from `center`, in pixels.
        offsets = []
        for idx in range(num_samples):
            cx, cy = dataset[idx]['boxes'][:2].tolist()
            offsets.append(np.hypot(cx - center, cy - center))
        return np.mean(offsets)

    ds = SyntheticTrackingDataset(length=100, acl_difficulty=0.0)

    # Easy setting first, then switch the same dataset object to hard.
    avg_easy = _mean_center_offset(ds)
    ds.set_acl_difficulty(1.0)
    avg_hard = _mean_center_offset(ds)

    print(f" Avg offset (easy, d=0.0): {avg_easy:.1f} px")
    print(f" Avg offset (hard, d=1.0): {avg_hard:.1f} px")

    # Hard samples should have larger offsets from center on average
    # (this is stochastic, so we allow some tolerance)
    hard_exceeds_easy = avg_hard > avg_easy * 0.5
    print(f" Hard > Easy: {hard_exceeds_easy}")
    assert hard_exceeds_easy, (
        f"Hard offsets ({avg_hard:.1f} px) should exceed half of easy offsets "
        f"({avg_easy:.1f} px)"
    )
|
| 645 |
+
|
| 646 |
+
test("ACL curriculum integration", test_acl_curriculum)
|
| 647 |
+
|
| 648 |
+
|
| 649 |
# ============================================================
|
| 650 |
# Summary
|
| 651 |
# ============================================================
|