| """ |
| Comprehensive test suite for ViL Tracker. |
| |
| 16 tests covering all components: |
| 1. mLSTM Cell (LinearHeadwiseExpand correctness + param count) |
| 2. mLSTM Block (full block without MLP) |
| 3. TMoE MLP |
| 4. Backbone (standard, small depth) |
| 5. Backbone (with TMoE + integrated FiLM, medium depth) |
| 6. Prediction Heads |
| 7. FiLM Temporal Modulation |
| 8. Full Tracker (small depth for speed) |
| 9. Loss Functions (all 6) |
| 10. Kalman Filter (8-state, adaptive) |
| 11. Dataset (synthetic) |
| 12. Training Step (mini forward + backward with temporal) |
| 13. Model Summary (FULL depth=24, constraint check) |
| 14. Online Tracker (full inference pipeline) |
| 15. Augmentation pipeline |
| 16. ACL curriculum integration |
| """ |
|
|
| import sys |
| import time |
| import torch |
| import numpy as np |
|
|
# Deterministic seeds so model init and synthetic data are reproducible
# across runs of the suite.
torch.manual_seed(42)
np.random.seed(42)


# Global pass/fail counters, updated only by test() below.
PASS = 0
FAIL = 0


def test(name, fn):
    """Run a single test callable, record the result, and print a status line.

    Args:
        name: Human-readable test name shown in the status output.
        fn: Zero-argument callable containing the test's assertions.

    Any exception raised by ``fn`` is caught and counted as a failure (with a
    traceback), so one failing test never aborts the rest of the suite.
    """
    global PASS, FAIL
    print(f"\nTest {PASS + FAIL + 1}: {name}...", flush=True)
    try:
        fn()
        PASS += 1
        print(" ✅ PASSED")  # plain string: no placeholders, so no f-prefix needed
    except Exception as e:
        FAIL += 1
        print(f" ❌ FAILED: {e}")
        import traceback
        traceback.print_exc()


# Prevent pytest from mis-collecting the harness itself as a test function.
test.__test__ = False
|
|
|
|
def count_params(model):
    """Return the total number of parameters in *model* (trainable and frozen)."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
|
|
|
| |
| |
| |
def test_mlstm_cell():
    """mLSTM cell: headwise projection params/shapes, cell budget, GroupNorm groups, bidirectional scan."""
    from vil_tracker.models.mlstm import mLSTMCell, LinearHeadwiseExpand

    # Headwise projection: 768 dims over 192 heads -> 4 dims/head, so each head
    # owns a 4x4 weight and the total is 192 * 4 * 4 (bias disabled).
    lhe = LinearHeadwiseExpand(768, num_heads=192, bias=False)
    lhe_params = count_params(lhe)
    assert lhe_params == 192 * 4 * 4, f"LHE params: {lhe_params} != {192*4*4}"

    # Projection preserves the (batch, seq, dim) shape.
    x = torch.randn(2, 10, 768)
    y = lhe(x)
    assert y.shape == (2, 10, 768), f"LHE output shape: {y.shape}"

    # Full cell at the tracker's working dimension.
    cell = mLSTMCell(dim=384, proj_factor=2.0, qkv_proj_blocksize=4, num_heads=4)
    cell_params = count_params(cell)
    print(f" mLSTMCell params: {cell_params:,} ({cell_params/1e6:.3f}M)")

    # Budget window: the cell must land between 0.8M and 1.0M parameters.
    assert cell_params < 1_000_000, f"Cell has {cell_params:,} params (should be <1M)"
    assert cell_params > 800_000, f"Cell has {cell_params:,} params (should be >800K)"

    # Output GroupNorm must use 192 groups (= 384 * proj_factor 2.0 / blocksize 4),
    # i.e. one group per projection sub-head rather than per attention head.
    assert cell.outnorm.num_groups == 192, f"GroupNorm should have 192 groups, got {cell.outnorm.num_groups}"
    print(f" GroupNorm groups: {cell.outnorm.num_groups} (correct: per-projection-head)")

    # Forward scan preserves shape.
    x = torch.randn(2, 20, 384)
    y = cell(x)
    assert y.shape == (2, 20, 384), f"Cell output shape: {y.shape}"

    # Reverse-direction scan: same shape...
    y_rev = cell(x, reverse=True)
    assert y_rev.shape == (2, 20, 384), f"Reverse output shape: {y_rev.shape}"

    # ...but must actually produce different values than the forward scan.
    assert not torch.allclose(y, y_rev, atol=1e-3), "Forward and reverse should differ"


test("mLSTM Cell (LinearHeadwiseExpand)", test_mlstm_cell)
|
|
|
|
| |
| |
| |
def test_mlstm_block():
    """mLSTM block: parameter budget (no separate MLP) and residual behavior."""
    from vil_tracker.models.mlstm import mLSTMBlock

    blk = mLSTMBlock(
        dim=384,
        proj_factor=2.0,
        qkv_proj_blocksize=4,
        num_heads=4,
        mlp_ratio=4.0,
    )
    params = count_params(blk)
    print(f" mLSTMBlock params: {params:,} ({params/1e6:.3f}M)")

    # Per the design, the block carries no separate MLP, so it must stay
    # under 1.05M parameters.
    assert params < 1_050_000, f"Block has {params:,} params (should be <1.05M without MLP)"

    # Shape must be preserved through the block.
    x = torch.randn(2, 20, 384)
    y = blk(x)
    assert y.shape == (2, 20, 384), f"Block output shape: {y.shape}"

    # Report how far the (residual) output drifts from the input.
    diff = (y - x).abs().mean().item()
    print(f" Residual diff from input: {diff:.4f}")


test("mLSTM Block (no separate MLP)", test_mlstm_block)
|
|
|
|
| |
| |
| |
def test_tmoe():
    """TMoE MLP: shape preservation and shared-expert freezing."""
    from vil_tracker.models.backbone import TMoEMLP

    moe = TMoEMLP(dim=384, mlp_ratio=4.0, num_experts=4)
    params = count_params(moe)
    print(f" TMoEMLP params: {params:,} ({params/1e6:.3f}M)")

    # Mixture forward must keep the (batch, seq, dim) shape.
    x = torch.randn(2, 20, 384)
    y = moe(x)
    assert y.shape == (2, 20, 384), f"TMoE output shape: {y.shape}"

    # After freezing, no shared-expert parameter may still require grad.
    moe.freeze_shared_expert()
    shared = list(moe.shared_expert.parameters())
    frozen = [p for p in shared if not p.requires_grad]
    assert len(frozen) == len(shared), "Shared expert should be fully frozen"


test("TMoE MLP", test_tmoe)
|
|
|
|
| |
| |
| |
def test_backbone_small():
    """Plain ViL backbone at reduced depth: parameter report plus feature shapes."""
    from vil_tracker.models.backbone import ViLBackbone

    net = ViLBackbone(dim=384, depth=4, patch_size=16, tmoe_blocks=0)
    params = count_params(net)
    print(f" Backbone (depth=4, no TMoE) params: {params:,} ({params/1e6:.3f}M)")

    # patch_size=16: 128px template -> 8*8 = 64 tokens; 256px search -> 16*16 = 256.
    template = torch.randn(2, 3, 128, 128)
    search = torch.randn(2, 3, 256, 256)
    t_feat, s_feat = net(template, search)

    assert t_feat.shape == (2, 64, 384), f"Template feat shape: {t_feat.shape}"
    assert s_feat.shape == (2, 256, 384), f"Search feat shape: {s_feat.shape}"


test("Backbone (standard, depth=4)", test_backbone_small)
|
|
|
|
| |
| |
| |
def test_backbone_tmoe_film():
    """Backbone with TMoE MLPs enabled and a FiLM temporal-modulation manager attached."""
    from vil_tracker.models.backbone import ViLBackbone
    from vil_tracker.models.film_temporal import TemporalModulationManager

    # depth=6 with 2 TMoE blocks; film_interval must match the manager below.
    backbone = ViLBackbone(dim=384, depth=6, patch_size=16, tmoe_blocks=2,
                           num_experts=4, film_interval=3)
    params = count_params(backbone)
    print(f" Backbone (depth=6, TMoE=2) params: {params:,} ({params/1e6:.3f}M)")

    # Manager configured for the same block count and modulation interval.
    temporal_mod = TemporalModulationManager(dim=384, num_blocks=6, modulation_interval=3)

    template = torch.randn(1, 3, 128, 128)
    search = torch.randn(1, 3, 256, 256)

    # First pass with the manager attached: shapes must match the plain backbone.
    t_feat, s_feat = backbone(template, search, temporal_mod_manager=temporal_mod)
    assert t_feat.shape == (1, 64, 384), f"Template feat shape: {t_feat.shape}"
    assert s_feat.shape == (1, 256, 384), f"Search feat shape: {s_feat.shape}"

    # Second pass on identical inputs: if the backbone fed temporal context into
    # the manager, FiLM should make the features differ (reported, not asserted
    # — NOTE(review): presumably context is updated inside the backbone; confirm).
    t_feat2, s_feat2 = backbone(template, search, temporal_mod_manager=temporal_mod)

    assert t_feat2.shape == (1, 64, 384)
    print(f" FiLM modulation active: features differ = {not torch.allclose(t_feat, t_feat2, atol=1e-5)}")


test("Backbone (TMoE + integrated FiLM)", test_backbone_tmoe_film)
|
|
|
|
| |
| |
| |
def test_heads():
    """Prediction heads: output map shapes, box decoding, and Hanning-window scoring."""
    from vil_tracker.models.heads import CenterHead, UncertaintyHead, decode_predictions, create_hanning_window

    # 256 search tokens correspond to a 16x16 feature map.
    center_head = CenterHead(dim=384, feat_size=16)
    unc_head = UncertaintyHead(dim=384, feat_size=16)

    print(f" CenterHead params: {count_params(center_head):,}")
    print(f" UncertaintyHead params: {count_params(unc_head):,}")

    # Center head emits heatmap (1ch), size (2ch) and offset (2ch) maps.
    search_feat = torch.randn(2, 256, 384)
    preds = center_head(search_feat)

    assert preds['heatmap'].shape == (2, 1, 16, 16), f"Heatmap shape: {preds['heatmap'].shape}"
    assert preds['size'].shape == (2, 2, 16, 16), f"Size shape: {preds['size'].shape}"
    assert preds['offset'].shape == (2, 2, 16, 16), f"Offset shape: {preds['offset'].shape}"

    # Decoding yields one 4-element box and one confidence score per batch item.
    boxes, scores = decode_predictions(preds['heatmap'], preds['size'], preds['offset'])
    assert boxes.shape == (2, 4), f"Boxes shape: {boxes.shape}"
    assert scores.shape == (2,), f"Scores shape: {scores.shape}"

    # Hanning window: ~1.0 at the center, ~0 in the corners.
    hann = create_hanning_window(16)
    assert hann.shape == (16, 16), f"Hanning shape: {hann.shape}"
    assert abs(hann[8, 8].item() - 1.0) < 0.05, f"Hanning center should be ~1.0, got {hann[8, 8]}"
    assert hann[0, 0].item() < 0.01, f"Hanning corner should be ~0, got {hann[0, 0]}"

    # Decoding with the window applied must still produce valid boxes/scores.
    boxes_h, scores_h = decode_predictions(preds['heatmap'], preds['size'], preds['offset'],
                                           hanning_window=hann)
    assert boxes_h.shape == (2, 4), f"Hanning boxes shape: {boxes_h.shape}"
    print(f" Hanning window: center={hann[8,8]:.3f}, corner={hann[0,0]:.6f}")
    print(f" Without Hanning: box={boxes[0].tolist()}, score={scores[0].item():.4f}")
    print(f" With Hanning: box={boxes_h[0].tolist()}, score={scores_h[0].item():.4f}")

    # Uncertainty head emits a single log-variance map.
    log_var = unc_head(search_feat)
    assert log_var.shape == (2, 1, 16, 16), f"Log variance shape: {log_var.shape}"


test("Prediction Heads", test_heads)
|
|
|
|
| |
| |
| |
def test_film():
    """FiLM temporal stack: reliability calibration, one modulation step, and the manager lifecycle."""
    from vil_tracker.models.film_temporal import (
        TemporalReliabilityCalibrator,
        FiLMTemporalModulation,
        TemporalModulationManager,
    )

    calib = TemporalReliabilityCalibrator(384)
    film = FiLMTemporalModulation(384)

    x = torch.randn(2, 20, 384)   # current features
    tc = torch.randn(2, 20, 384)  # temporal context

    # Reliability is a per-token scalar gate constrained to [0, 1].
    rel = calib(tc)
    assert rel.shape == (2, 20, 1), f"Reliability shape: {rel.shape}"
    assert (rel >= 0).all() and (rel <= 1).all(), "Reliability not in [0,1]"

    # Modulation keeps the feature shape.
    modulated = film(x, tc, rel)
    assert modulated.shape == (2, 20, 384), f"Modulated shape: {modulated.shape}"

    # Manager variant used by the full-depth backbone (24 blocks, every 6th).
    manager = TemporalModulationManager(dim=384, num_blocks=24, modulation_interval=6)
    print(f" TemporalModulationManager params: {count_params(manager):,}")

    # Without any stored temporal context, modulate() must be the identity.
    y = manager.modulate(x, block_idx=5)
    assert torch.allclose(y, x), "Should return unchanged without temporal context"

    # After storing context, modulate() must still preserve the shape.
    manager.update_temporal_context(x)
    y = manager.modulate(x, block_idx=5)
    assert y.shape == (2, 20, 384)

    # reset() drops the stored context, restoring identity behavior.
    manager.reset()
    y = manager.modulate(x, block_idx=5)
    assert torch.allclose(y, x), "After reset, should return unchanged"


test("FiLM Temporal Modulation", test_film)
|
|
|
|
| |
| |
| |
def test_full_tracker_small():
    """End-to-end tracker forward in single-frame (4D) and multi-frame (5D, K=3) modes."""
    from vil_tracker.models.tracker import ViLTracker, get_default_config

    # Shrink depth for speed but keep TMoE and FiLM enabled so both paths run.
    config = get_default_config()
    config['depth'] = 4
    config['tmoe_blocks'] = 1
    config['film_interval'] = 2

    tracker = ViLTracker(config)
    params = count_params(tracker)
    print(f" Tracker (depth=4) params: {params:,} ({params/1e6:.3f}M)")

    B, K = 2, 3
    template = torch.randn(B, 3, 128, 128)

    # Single-frame path: 4D search tensor, temporal modulation disabled.
    search_single = torch.randn(B, 3, 256, 256)
    output_s = tracker(template, search_single, use_temporal=False)
    assert output_s['heatmap'].shape == (B, 1, 16, 16), f"Single heatmap: {output_s['heatmap'].shape}"
    assert output_s['boxes'].shape == (B, 4), f"Single boxes: {output_s['boxes'].shape}"
    assert output_s['scores'].shape == (B,), f"Single scores: {output_s['scores'].shape}"
    print(f" Single-frame: boxes={output_s['boxes'][0].tolist()}")

    # Multi-frame path: 5D (B, K, C, H, W) input with temporal modulation on;
    # every output gains a K dimension and per-frame search features are returned.
    searches = torch.randn(B, K, 3, 256, 256)
    output_m = tracker(template, searches, use_temporal=True)
    assert output_m['heatmap'].shape == (B, K, 1, 16, 16), f"Multi heatmap: {output_m['heatmap'].shape}"
    assert output_m['boxes'].shape == (B, K, 4), f"Multi boxes: {output_m['boxes'].shape}"
    assert output_m['scores'].shape == (B, K), f"Multi scores: {output_m['scores'].shape}"
    assert output_m['search_feats'].shape == (B, K, 256, 384), f"Multi feats: {output_m['search_feats'].shape}"
    print(f" Multi-frame (K={K}): frame 0 box={output_m['boxes'][0,0].tolist()}")
    print(f" frame 2 box={output_m['boxes'][0,2].tolist()}")

    # Clear accumulated temporal state so later tests start fresh.
    tracker.reset_temporal()


test("Full Tracker (single + multi-frame)", test_full_tracker_small)
|
|
|
|
| |
| |
| |
def test_losses():
    """Smoke-test every loss: focal, GIoU, uncertainty NLL, contrastive, AFKD, ADW, and the combined wrapper."""
    from vil_tracker.training.losses import (
        FocalLoss, GIoULoss, UncertaintyNLLLoss,
        MemoryContrastiveLoss, AFKDDistillationLoss,
        ADWLoss, CombinedTrackingLoss,
    )

    B = 4

    # Focal loss on a heatmap with a single positive cell at (8, 8).
    focal = FocalLoss()
    pred_hm = torch.randn(B, 1, 16, 16)
    gt_hm = torch.zeros(B, 1, 16, 16)
    gt_hm[:, :, 8, 8] = 1.0
    fl = focal(pred_hm, gt_hm)
    print(f" Focal loss: {fl.item():.4f}")
    assert fl.item() > 0, "Focal loss should be positive"

    # GIoU on nearly-overlapping boxes; GIoU loss must stay within [0, 2].
    giou = GIoULoss()
    pred_box = torch.tensor([[128.0, 128.0, 50.0, 50.0]] * B)
    gt_box = torch.tensor([[130.0, 130.0, 48.0, 48.0]] * B)
    gl = giou(pred_box, gt_box)
    print(f" GIoU loss: {gl.item():.4f}")
    assert 0 <= gl.item() <= 2, f"GIoU loss out of range: {gl.item()}"

    # Uncertainty NLL with log-variance fixed at 0 (unit variance).
    unc = UncertaintyNLLLoss()
    pred_v = torch.randn(B, 4)
    target_v = torch.randn(B, 4)
    log_var = torch.zeros(B, 4)
    ul = unc(pred_v, target_v, log_var)
    print(f" Uncertainty NLL loss: {ul.item():.4f}")
    assert ul.item() > 0

    # Contrastive loss on a pair of near-duplicate feature batches.
    contrastive = MemoryContrastiveLoss()
    feat_a = torch.randn(B, 384)
    feat_b = feat_a + torch.randn(B, 384) * 0.1
    cl = contrastive(feat_a, feat_b)
    print(f" Contrastive loss: {cl.item():.4f}")

    # AFKD distillation between 384-dim student and 768-dim teacher features.
    afkd = AFKDDistillationLoss(student_dim=384, teacher_dim=768)
    student_feat = torch.randn(B, 256, 384)
    teacher_feat = torch.randn(B, 256, 768)
    dl = afkd(student_feat, teacher_feat)
    print(f" AFKD distillation loss: {dl.item():.4f}")
    assert dl.item() > 0

    # Adaptive dynamic weighting over three scalar task losses.
    adw = ADWLoss(num_tasks=3)
    losses = [torch.tensor(1.0), torch.tensor(0.5), torch.tensor(2.0)]
    al = adw(losses)
    print(f" ADW loss: {al.item():.4f}")

    # Combined wrapper consumes the prediction dict plus heatmap/size/box targets.
    combined = CombinedTrackingLoss()
    pred = {
        'heatmap': pred_hm,
        'size': torch.rand(B, 2, 16, 16),
        'boxes': pred_box,
        'log_variance': torch.randn(B, 1, 16, 16),
    }
    loss_dict = combined(pred, gt_hm, torch.tensor([[0.2, 0.2]] * B), gt_box)
    print(f" Combined loss: {loss_dict['total'].item():.4f}")
    assert loss_dict['total'].item() > 0


test("Loss Functions (all 6)", test_losses)
|
|
|
|
| |
| |
| |
def test_kalman():
    """Kalman filter: init flag, predict/update cycle on a drifting box, outlier rejection."""
    from vil_tracker.inference.kalman import KalmanFilter

    kf = KalmanFilter()
    assert not kf.initialized

    # Initialize from a 4-element box (cx, cy, w, h per the f-string below).
    init_box = np.array([100.0, 100.0, 50.0, 50.0])
    kf.initialize(init_box)
    assert kf.initialized

    # Simulate a target drifting +2px/frame in x and +1px/frame in y with
    # Gaussian measurement noise (sigma ~2px).
    for i in range(10):
        pred = kf.predict()
        assert len(pred) == 4, f"Prediction length: {len(pred)}"

        noise = np.random.randn(4) * 2
        meas = init_box + np.array([i * 2, i * 1, 0, 0]) + noise
        kf.update(meas, uncertainty=1.0)

    state = kf.get_state()
    print(f" Final state: cx={state[0]:.1f}, cy={state[1]:.1f}, w={state[2]:.1f}, h={state[3]:.1f}")
    assert state[2] > 0 and state[3] > 0, "Width/height should be positive"

    # A wildly inconsistent measurement (jump to ~500,500) must not drag the
    # filtered state along with it.
    kf.update(np.array([500.0, 500.0, 50.0, 50.0]), uncertainty=1.0)
    state_after = kf.get_state()

    assert state_after[0] < 200, f"Outlier should be rejected, cx={state_after[0]}"


test("Kalman Filter (8-state, adaptive)", test_kalman)
|
|
|
|
| |
| |
| |
def test_dataset():
    """Synthetic dataset: sample shapes, target motion, the ACL knob, and the compat wrapper."""
    from vil_tracker.data.dataset import SyntheticTrackingDataset, TrackingDataset

    ds = SyntheticTrackingDataset(length=100, clip_length=3)
    assert len(ds) == 100

    # Every clip sample must carry these exact tensor shapes.
    sample = ds[0]
    expected_shapes = {
        'template': (3, 128, 128),
        'searches': (3, 3, 256, 256),
        'heatmaps': (3, 1, 16, 16),
        'sizes': (3, 2),
        'boxes': (3, 4),
    }
    for key, shape in expected_shapes.items():
        assert sample[key].shape == shape, f"{key.capitalize()} shape: {sample[key].shape}"

    # The target center should move across the clip's frames.
    cx_f0 = sample['boxes'][0, 0].item()
    cx_f2 = sample['boxes'][2, 0].item()
    print(f" Frame 0 cx: {cx_f0:.1f}, Frame 2 cx: {cx_f2:.1f} (moving target)")

    # Resample the same index at both ends of the ACL difficulty range.
    ds.set_acl_difficulty(0.0)
    easy_sample = ds[42]
    ds.set_acl_difficulty(1.0)
    hard_sample = ds[42]
    print(f" Easy frame spread: {(easy_sample['boxes'][:, 0].max() - easy_sample['boxes'][:, 0].min()).item():.1f} px")
    print(f" Hard frame spread: {(hard_sample['boxes'][:, 0].max() - hard_sample['boxes'][:, 0].min()).item():.1f} px")

    # Backward-compatible wrapper around the synthetic dataset.
    ds2 = TrackingDataset(synthetic=True, synthetic_length=50, clip_length=3)
    assert len(ds2) == 50
    sample2 = ds2[0]
    assert sample2['searches'].shape[0] == 3, "Clip length should be 3"


test("Dataset (synthetic + backward compat)", test_dataset)
|
|
|
|
| |
| |
| |
def test_training_step():
    """One optimization step over a K=3 clip: forward, per-frame combined loss,
    template/search contrastive term, backward, and an optimizer step.

    Fix: dropped the unused ``generate_heatmap`` import (ground truth is built
    manually below).
    """
    from vil_tracker.models.tracker import ViLTracker, get_default_config
    from vil_tracker.training.losses import CombinedTrackingLoss, MemoryContrastiveLoss

    # Tiny configuration (depth=2, no TMoE) keeps the step fast on CPU.
    config = get_default_config()
    config['depth'] = 2
    config['tmoe_blocks'] = 0
    config['film_interval'] = 2

    model = ViLTracker(config)
    model.train()
    loss_fn = CombinedTrackingLoss()
    contrastive_loss = MemoryContrastiveLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    B, K = 2, 3
    template = torch.randn(B, 3, 128, 128)
    searches = torch.randn(B, K, 3, 256, 256)

    # Ground truth: target centered at heatmap cell (8, 8) in every frame.
    gt_heatmaps = torch.zeros(B, K, 1, 16, 16)
    gt_heatmaps[:, :, :, 8, 8] = 1.0
    gt_sizes = torch.tensor([[[0.2, 0.3]] * K] * B)
    gt_boxes = torch.tensor([[[128.0, 128.0, 51.2, 76.8]] * K] * B)

    # Forward over the whole clip with temporal modulation enabled.
    pred = model(template, searches, use_temporal=True)

    assert pred['heatmap'].shape == (B, K, 1, 16, 16), f"Heatmap shape: {pred['heatmap'].shape}"
    assert pred['boxes'].shape == (B, K, 4), f"Boxes shape: {pred['boxes'].shape}"
    assert pred['scores'].shape == (B, K), f"Scores shape: {pred['scores'].shape}"
    assert pred['search_feats'].shape == (B, K, 256, 384), f"Search feats: {pred['search_feats'].shape}"

    # Average the combined loss over the K frames.
    total_loss = torch.tensor(0.0)
    for k in range(K):
        pred_k = {
            'heatmap': pred['heatmap'][:, k],
            'size': pred['size'][:, k],
            'boxes': pred['boxes'][:, k],
        }
        if 'log_variance' in pred:
            pred_k['log_variance'] = pred['log_variance'][:, k]
        loss_dict = loss_fn(pred_k, gt_heatmaps[:, k], gt_sizes[:, k], gt_boxes[:, k])
        total_loss = total_loss + loss_dict['total']
    total_loss = total_loss / K

    # Contrastive term between pooled template features and the last frame's
    # pooled search features.
    t_pooled = pred['template_feat'].mean(dim=1)
    s_pooled = pred['search_feats'][:, -1].mean(dim=1)
    c_loss = contrastive_loss(t_pooled, s_pooled)
    total_loss = total_loss + 0.1 * c_loss

    # Backward: at least some parameters must receive gradients (asserted below).
    total_loss.backward()

    has_grads = sum(1 for p in model.parameters() if p.grad is not None)
    total_params_count = sum(1 for p in model.parameters())
    print(f" Total loss: {total_loss.item():.4f} (K={K} frames, contr={c_loss.item():.4f})")
    print(f" Params with gradients: {has_grads}/{total_params_count}")

    optimizer.step()
    optimizer.zero_grad()

    assert total_loss.item() > 0
    assert has_grads > 0


test("Training Step (K=3 sequence + contrastive)", test_training_step)
|
|
|
|
| |
| |
| |
def test_model_summary():
    """Full-depth model summary: parameter and fp16-size constraints must hold."""
    from vil_tracker.models.tracker import ViLTracker, get_default_config
    from vil_tracker.utils.helpers import print_model_summary

    config = get_default_config()
    net = ViLTracker(config)

    summary = print_model_summary(net, config)
    total_m = summary['total_params'] / 1e6

    # Hard constraints for the full (depth=24) model.
    assert summary['param_ok'], f"FAIL: {total_m:.2f}M params exceeds 50M limit"
    assert summary['size_ok'], f"FAIL: {summary['size_fp16_mb']:.1f}MB exceeds 500MB limit"

    # GFLOPs is only an estimate, so a miss is reported rather than fatal.
    if not summary['flop_ok']:
        print(f" ⚠️ GFLOPs estimate ({summary['gflops']:.2f}) exceeds 20, but this is approximate")


test("Model Summary (full depth=24)", test_model_summary)
|
|
|
|
| |
| |
| |
def test_online_tracker():
    """Online inference wrapper: initialize on frame 0, then track a moving target for 5 frames."""
    from vil_tracker.models.tracker import ViLTracker, get_default_config
    from vil_tracker.inference.online_tracker import OnlineTracker

    # Tiny model: this exercises the pipeline plumbing, not tracking quality.
    config = get_default_config()
    config['depth'] = 2
    config['tmoe_blocks'] = 0
    config['film_interval'] = 2

    model = ViLTracker(config)
    model.eval()

    tracker = OnlineTracker(model, device='cpu', template_size=128, search_size=256)

    # 640x480 frame with an [x, y, w, h] init box (see the slicing below).
    H, W = 480, 640
    init_bbox = [200, 200, 60, 80]

    # Frame 0: random noise plus a solid red patch at the init box.
    frame0 = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)

    x, y, w, h = init_bbox
    frame0[y:y+h, x:x+w] = [255, 0, 0]

    tracker.initialize(frame0, init_bbox)

    # Subsequent frames: move the patch +5px/frame in x and +3px/frame in y.
    for i in range(1, 6):
        frame = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)

        nx = x + i * 5
        ny = y + i * 3
        frame[ny:ny+h, nx:nx+w] = [255, 0, 0]

        # Each track() call must return a numeric 4-element bbox.
        bbox = tracker.track(frame)
        assert len(bbox) == 4, f"Bbox should have 4 elements, got {len(bbox)}"
        assert all(isinstance(v, (int, float, np.floating)) for v in bbox), f"Bbox values: {bbox}"
        print(f" Frame {i}: predicted [{bbox[0]:.1f}, {bbox[1]:.1f}, {bbox[2]:.1f}, {bbox[3]:.1f}]")

    print(f" Online tracker completed 5-frame sequence")


test("Online Tracker (inference pipeline)", test_online_tracker)
|
|
|
|
| |
| |
| |
def test_augmentation():
    """Augmentation: shapes preserved, and a forced horizontal flip mirrors the bbox cx."""
    from vil_tracker.data.dataset import TrackingAugmentation

    # Flip probability forced to 1.0 so the geometric effect is deterministic;
    # grayscale and blur are switched off.
    pipeline = TrackingAugmentation(
        brightness=0.2,
        contrast=0.2,
        horizontal_flip_prob=1.0,
        grayscale_prob=0.0,
        blur_prob=0.0,
    )

    template = torch.rand(3, 128, 128)
    search = torch.rand(3, 256, 256)
    bbox = torch.tensor([128.0, 128.0, 50.0, 50.0])

    t_aug, s_aug, b_aug = pipeline(template, search, bbox)

    assert t_aug.shape == (3, 128, 128), f"Aug template shape: {t_aug.shape}"
    assert s_aug.shape == (3, 256, 256), f"Aug search shape: {s_aug.shape}"
    assert b_aug.shape == (4,), f"Aug bbox shape: {b_aug.shape}"

    print(f" Original bbox: {bbox.tolist()}")
    print(f" Augmented bbox: {b_aug.tolist()}")
    # cx=128 in a 256-wide search maps to 256 - 128 = 128 under a mirror flip.
    assert abs(b_aug[0].item() - (256 - 128)) < 1.0, f"Flipped cx should be ~128, got {b_aug[0]}"


test("Augmentation pipeline", test_augmentation)
|
|
|
|
| |
| |
| |
def test_acl_curriculum():
    """ACL curriculum: raising the difficulty must widen the per-clip motion spread.

    Fix: the original only printed the hard-vs-easy comparison; now it is
    asserted so a broken curriculum actually fails the suite.
    """
    from vil_tracker.data.dataset import SyntheticTrackingDataset

    ds = SyntheticTrackingDataset(length=100, acl_difficulty=0.0, clip_length=3)

    def clip_spread(sample):
        # Horizontal travel of the target across the clip (max cx - min cx).
        return (sample['boxes'][:, 0].max() - sample['boxes'][:, 0].min()).item()

    # Sample 20 clips at the easiest setting...
    easy_spreads = [clip_spread(ds[i]) for i in range(20)]

    # ...then the same indices at the hardest setting.
    ds.set_acl_difficulty(1.0)
    hard_spreads = [clip_spread(ds[i]) for i in range(20)]

    avg_easy = np.mean(easy_spreads)
    avg_hard = np.mean(hard_spreads)

    print(f" Avg cx spread (easy, d=0.0): {avg_easy:.1f} px")
    print(f" Avg cx spread (hard, d=1.0): {avg_hard:.1f} px")
    print(f" Hard > Easy: {avg_hard > avg_easy}")
    assert avg_hard > avg_easy, (
        f"ACL difficulty had no effect: hard {avg_hard:.1f} <= easy {avg_easy:.1f}"
    )


test("ACL curriculum integration", test_acl_curriculum)
|
|
|
|
| |
| |
| |
| print("\n" + "=" * 60) |
| print(f"Results: {PASS}/{PASS + FAIL} tests passed") |
| if FAIL > 0: |
| print(f" ❌ {FAIL} test(s) FAILED") |
| sys.exit(1) |
| else: |
| print(" ✅ All tests passed!") |
| sys.exit(0) |
|
|