omar-ah commited on
Commit
7e7f067
·
verified ·
1 Parent(s): 0bd347a

Fix test_all.py: audit corrections

Browse files
Files changed (1) hide show
  1. test_all.py +226 -35
test_all.py CHANGED
@@ -1,20 +1,23 @@
1
  """
2
  Comprehensive test suite for ViL Tracker.
3
 
4
- 13 tests covering all components:
5
  1. mLSTM Cell (LinearHeadwiseExpand correctness + param count)
6
- 2. mLSTM Block (full block with MLP)
7
  3. TMoE MLP
8
  4. Backbone (standard, small depth)
9
- 5. Backbone (with TMoE, medium depth)
10
  6. Prediction Heads
11
  7. FiLM Temporal Modulation
12
  8. Full Tracker (small depth for speed)
13
- 9. Loss Functions
14
- 10. Kalman Filter
15
  11. Dataset (synthetic)
16
- 12. Training Step (mini forward + backward)
17
  13. Model Summary (FULL depth=24, constraint check)
 
 
 
18
  """
19
 
20
  import sys
@@ -94,6 +97,9 @@ def test_mlstm_block():
94
  params = count_params(block)
95
  print(f" mLSTMBlock params: {params:,} ({params/1e6:.3f}M)")
96
 
 
 
 
97
  x = torch.randn(2, 20, 384)
98
  y = block(x)
99
  assert y.shape == (2, 20, 384), f"Block output shape: {y.shape}"
@@ -102,7 +108,7 @@ def test_mlstm_block():
102
  diff = (y - x).abs().mean().item()
103
  print(f" Residual diff from input: {diff:.4f}")
104
 
105
- test("mLSTM Block", test_mlstm_block)
106
 
107
 
108
  # ============================================================
@@ -149,23 +155,35 @@ test("Backbone (standard, depth=4)", test_backbone_small)
149
 
150
 
151
  # ============================================================
152
- # Test 5: Backbone (with TMoE, depth=6)
153
  # ============================================================
154
- def test_backbone_tmoe():
155
  from vil_tracker.models.backbone import ViLBackbone
 
156
 
157
- backbone = ViLBackbone(dim=384, depth=6, patch_size=16, tmoe_blocks=2, num_experts=4)
 
158
  params = count_params(backbone)
159
  print(f" Backbone (depth=6, TMoE=2) params: {params:,} ({params/1e6:.3f}M)")
160
 
 
 
 
161
  template = torch.randn(1, 3, 128, 128)
162
  search = torch.randn(1, 3, 256, 256)
163
 
164
- t_feat, s_feat = backbone(template, search)
 
165
  assert t_feat.shape == (1, 64, 384), f"Template feat shape: {t_feat.shape}"
166
  assert s_feat.shape == (1, 256, 384), f"Search feat shape: {s_feat.shape}"
 
 
 
 
 
 
167
 
168
- test("Backbone (with TMoE, depth=6)", test_backbone_tmoe)
169
 
170
 
171
  # ============================================================
@@ -234,8 +252,12 @@ def test_film():
234
  # Update context and try again
235
  manager.update_temporal_context(x)
236
  y = manager.modulate(x, block_idx=5) # block 5 → (5+1)%6==0, should modulate
237
- # With temporal context, output should differ
238
  assert y.shape == (2, 20, 384)
 
 
 
 
 
239
 
240
  test("FiLM Temporal Modulation", test_film)
241
 
@@ -258,27 +280,38 @@ def test_full_tracker_small():
258
  template = torch.randn(2, 3, 128, 128)
259
  search = torch.randn(2, 3, 256, 256)
260
 
261
- output = tracker(template, search)
262
-
263
  assert output['heatmap'].shape == (2, 1, 16, 16)
264
- assert output['size'].shape == (2, 2, 16, 16)
265
  assert output['boxes'].shape == (2, 4)
266
  assert output['scores'].shape == (2,)
267
  assert 'log_variance' in output
268
 
 
 
 
 
 
 
 
 
 
 
 
269
  print(f" Predicted boxes: {output['boxes'][0].tolist()}")
270
  print(f" Scores: {output['scores'].tolist()}")
271
 
272
- test("Full Tracker (depth=4)", test_full_tracker_small)
273
 
274
 
275
  # ============================================================
276
- # Test 9: Loss Functions
277
  # ============================================================
278
  def test_losses():
279
  from vil_tracker.training.losses import (
280
  FocalLoss, GIoULoss, UncertaintyNLLLoss,
281
- MemoryContrastiveLoss, CombinedTrackingLoss,
 
282
  )
283
 
284
  B = 4
@@ -300,13 +333,36 @@ def test_losses():
300
  print(f" GIoU loss: {gl.item():.4f}")
301
  assert 0 <= gl.item() <= 2, f"GIoU loss out of range: {gl.item()}"
302
 
 
 
 
 
 
 
 
 
 
303
  # Contrastive loss
304
  contrastive = MemoryContrastiveLoss()
305
  feat_a = torch.randn(B, 384)
306
- feat_b = feat_a + torch.randn(B, 384) * 0.1 # slightly perturbed
307
  cl = contrastive(feat_a, feat_b)
308
  print(f" Contrastive loss: {cl.item():.4f}")
309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  # Combined loss
311
  combined = CombinedTrackingLoss()
312
  pred = {
@@ -319,7 +375,7 @@ def test_losses():
319
  print(f" Combined loss: {loss_dict['total'].item():.4f}")
320
  assert loss_dict['total'].item() > 0
321
 
322
- test("Loss Functions", test_losses)
323
 
324
 
325
  # ============================================================
@@ -336,12 +392,12 @@ def test_kalman():
336
  kf.initialize(init_box)
337
  assert kf.initialized
338
 
339
- # Predict + update cycle
340
  for i in range(10):
341
  pred = kf.predict()
342
  assert len(pred) == 4, f"Prediction length: {len(pred)}"
343
 
344
- # Simulate noisy measurement
345
  noise = np.random.randn(4) * 2
346
  meas = init_box + np.array([i * 2, i * 1, 0, 0]) + noise
347
  kf.update(meas, uncertainty=1.0)
@@ -349,17 +405,23 @@ def test_kalman():
349
  state = kf.get_state()
350
  print(f" Final state: cx={state[0]:.1f}, cy={state[1]:.1f}, w={state[2]:.1f}, h={state[3]:.1f}")
351
  assert state[2] > 0 and state[3] > 0, "Width/height should be positive"
 
 
 
 
 
 
352
 
353
- test("Kalman Filter", test_kalman)
354
 
355
 
356
  # ============================================================
357
  # Test 11: Dataset (synthetic)
358
  # ============================================================
359
  def test_dataset():
360
- from vil_tracker.data.dataset import TrackingDataset
361
 
362
- ds = TrackingDataset(synthetic=True, synthetic_length=100)
363
  assert len(ds) == 100
364
 
365
  sample = ds[0]
@@ -376,16 +438,22 @@ def test_dataset():
376
  hard_sample = ds[42]
377
  print(f" Easy center: {easy_sample['boxes'][:2].tolist()}")
378
  print(f" Hard center: {hard_sample['boxes'][:2].tolist()}")
 
 
 
 
 
 
379
 
380
- test("Dataset (synthetic)", test_dataset)
381
 
382
 
383
  # ============================================================
384
- # Test 12: Training Step (mini forward + backward)
385
  # ============================================================
386
  def test_training_step():
387
  from vil_tracker.models.tracker import ViLTracker, get_default_config
388
- from vil_tracker.training.losses import CombinedTrackingLoss
389
  from vil_tracker.models.heads import generate_heatmap
390
 
391
  config = get_default_config()
@@ -396,6 +464,7 @@ def test_training_step():
396
  model = ViLTracker(config)
397
  model.train()
398
  loss_fn = CombinedTrackingLoss()
 
399
  optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
400
 
401
  B = 2
@@ -408,27 +477,33 @@ def test_training_step():
408
  gt_size = torch.tensor([[0.2, 0.3], [0.15, 0.25]])
409
  gt_boxes = torch.tensor([[128.0, 128.0, 51.2, 76.8], [100.0, 150.0, 38.4, 64.0]])
410
 
411
- # Forward
412
- pred = model(template, search)
413
  loss_dict = loss_fn(pred, gt_heatmap, gt_size, gt_boxes)
414
 
 
 
 
 
 
 
415
  # Backward
416
- loss_dict['total'].backward()
417
 
418
  # Check gradients exist
419
  has_grads = sum(1 for p in model.parameters() if p.grad is not None)
420
  total_params_count = sum(1 for p in model.parameters())
421
- print(f" Loss: {loss_dict['total'].item():.4f}")
422
  print(f" Params with gradients: {has_grads}/{total_params_count}")
423
 
424
  # Optimizer step
425
  optimizer.step()
426
  optimizer.zero_grad()
427
 
428
- assert loss_dict['total'].item() > 0
429
  assert has_grads > 0
430
 
431
- test("Training Step (depth=2)", test_training_step)
432
 
433
 
434
  # ============================================================
@@ -455,6 +530,122 @@ def test_model_summary():
455
  test("Model Summary (full depth=24)", test_model_summary)
456
 
457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  # ============================================================
459
  # Summary
460
  # ============================================================
 
1
  """
2
  Comprehensive test suite for ViL Tracker.
3
 
4
+ 16 tests covering all components:
5
  1. mLSTM Cell (LinearHeadwiseExpand correctness + param count)
6
+ 2. mLSTM Block (full block without MLP)
7
  3. TMoE MLP
8
  4. Backbone (standard, small depth)
9
+ 5. Backbone (with TMoE + integrated FiLM, medium depth)
10
  6. Prediction Heads
11
  7. FiLM Temporal Modulation
12
  8. Full Tracker (small depth for speed)
13
+ 9. Loss Functions (all 6)
14
+ 10. Kalman Filter (8-state, adaptive)
15
  11. Dataset (synthetic)
16
+ 12. Training Step (mini forward + backward with temporal)
17
  13. Model Summary (FULL depth=24, constraint check)
18
+ 14. Online Tracker (full inference pipeline)
19
+ 15. Augmentation pipeline
20
+ 16. ACL curriculum integration
21
  """
22
 
23
  import sys
 
97
  params = count_params(block)
98
  print(f" mLSTMBlock params: {params:,} ({params/1e6:.3f}M)")
99
 
100
+ # No separate MLP — should be ~920K, same as cell + LayerNorm
101
+ assert params < 1_050_000, f"Block has {params:,} params (should be <1.05M without MLP)"
102
+
103
  x = torch.randn(2, 20, 384)
104
  y = block(x)
105
  assert y.shape == (2, 20, 384), f"Block output shape: {y.shape}"
 
108
  diff = (y - x).abs().mean().item()
109
  print(f" Residual diff from input: {diff:.4f}")
110
 
111
+ test("mLSTM Block (no separate MLP)", test_mlstm_block)
112
 
113
 
114
  # ============================================================
 
155
 
156
 
157
  # ============================================================
158
+ # Test 5: Backbone with TMoE + integrated FiLM
159
  # ============================================================
160
+ def test_backbone_tmoe_film():
161
  from vil_tracker.models.backbone import ViLBackbone
162
+ from vil_tracker.models.film_temporal import TemporalModulationManager
163
 
164
+ backbone = ViLBackbone(dim=384, depth=6, patch_size=16, tmoe_blocks=2,
165
+ num_experts=4, film_interval=3)
166
  params = count_params(backbone)
167
  print(f" Backbone (depth=6, TMoE=2) params: {params:,} ({params/1e6:.3f}M)")
168
 
169
+ # Create temporal modulation manager
170
+ temporal_mod = TemporalModulationManager(dim=384, num_blocks=6, modulation_interval=3)
171
+
172
  template = torch.randn(1, 3, 128, 128)
173
  search = torch.randn(1, 3, 256, 256)
174
 
175
+ # First pass: no temporal context yet
176
+ t_feat, s_feat = backbone(template, search, temporal_mod_manager=temporal_mod)
177
  assert t_feat.shape == (1, 64, 384), f"Template feat shape: {t_feat.shape}"
178
  assert s_feat.shape == (1, 256, 384), f"Search feat shape: {s_feat.shape}"
179
+
180
+ # Second pass: temporal context should be active now
181
+ t_feat2, s_feat2 = backbone(template, search, temporal_mod_manager=temporal_mod)
182
+ # Output should differ when temporal modulation is active
183
+ assert t_feat2.shape == (1, 64, 384)
184
+ print(f" FiLM modulation active: features differ = {not torch.allclose(t_feat, t_feat2, atol=1e-5)}")
185
 
186
+ test("Backbone (TMoE + integrated FiLM)", test_backbone_tmoe_film)
187
 
188
 
189
  # ============================================================
 
252
  # Update context and try again
253
  manager.update_temporal_context(x)
254
  y = manager.modulate(x, block_idx=5) # block 5 → (5+1)%6==0, should modulate
 
255
  assert y.shape == (2, 20, 384)
256
+
257
+ # Test reset
258
+ manager.reset()
259
+ y = manager.modulate(x, block_idx=5)
260
+ assert torch.allclose(y, x), "After reset, should return unchanged"
261
 
262
  test("FiLM Temporal Modulation", test_film)
263
 
 
280
  template = torch.randn(2, 3, 128, 128)
281
  search = torch.randn(2, 3, 256, 256)
282
 
283
+ # Test without temporal
284
+ output = tracker(template, search, use_temporal=False)
285
  assert output['heatmap'].shape == (2, 1, 16, 16)
 
286
  assert output['boxes'].shape == (2, 4)
287
  assert output['scores'].shape == (2,)
288
  assert 'log_variance' in output
289
 
290
+ # Test with temporal (first frame: no context)
291
+ output_t1 = tracker(template, search, use_temporal=True)
292
+ assert output_t1['boxes'].shape == (2, 4)
293
+
294
+ # Second frame: temporal context available
295
+ output_t2 = tracker(template, search, use_temporal=True)
296
+ assert output_t2['boxes'].shape == (2, 4)
297
+
298
+ # Reset temporal
299
+ tracker.reset_temporal()
300
+
301
  print(f" Predicted boxes: {output['boxes'][0].tolist()}")
302
  print(f" Scores: {output['scores'].tolist()}")
303
 
304
+ test("Full Tracker (depth=4, with temporal)", test_full_tracker_small)
305
 
306
 
307
  # ============================================================
308
+ # Test 9: Loss Functions (all 6)
309
  # ============================================================
310
  def test_losses():
311
  from vil_tracker.training.losses import (
312
  FocalLoss, GIoULoss, UncertaintyNLLLoss,
313
+ MemoryContrastiveLoss, AFKDDistillationLoss,
314
+ ADWLoss, CombinedTrackingLoss,
315
  )
316
 
317
  B = 4
 
333
  print(f" GIoU loss: {gl.item():.4f}")
334
  assert 0 <= gl.item() <= 2, f"GIoU loss out of range: {gl.item()}"
335
 
336
+ # Uncertainty NLL loss
337
+ unc = UncertaintyNLLLoss()
338
+ pred_v = torch.randn(B, 4)
339
+ target_v = torch.randn(B, 4)
340
+ log_var = torch.zeros(B, 4) # unit variance
341
+ ul = unc(pred_v, target_v, log_var)
342
+ print(f" Uncertainty NLL loss: {ul.item():.4f}")
343
+ assert ul.item() > 0
344
+
345
  # Contrastive loss
346
  contrastive = MemoryContrastiveLoss()
347
  feat_a = torch.randn(B, 384)
348
+ feat_b = feat_a + torch.randn(B, 384) * 0.1
349
  cl = contrastive(feat_a, feat_b)
350
  print(f" Contrastive loss: {cl.item():.4f}")
351
 
352
+ # AFKD distillation loss
353
+ afkd = AFKDDistillationLoss(student_dim=384, teacher_dim=768)
354
+ student_feat = torch.randn(B, 256, 384)
355
+ teacher_feat = torch.randn(B, 256, 768)
356
+ dl = afkd(student_feat, teacher_feat)
357
+ print(f" AFKD distillation loss: {dl.item():.4f}")
358
+ assert dl.item() > 0
359
+
360
+ # ADW loss
361
+ adw = ADWLoss(num_tasks=3)
362
+ losses = [torch.tensor(1.0), torch.tensor(0.5), torch.tensor(2.0)]
363
+ al = adw(losses)
364
+ print(f" ADW loss: {al.item():.4f}")
365
+
366
  # Combined loss
367
  combined = CombinedTrackingLoss()
368
  pred = {
 
375
  print(f" Combined loss: {loss_dict['total'].item():.4f}")
376
  assert loss_dict['total'].item() > 0
377
 
378
+ test("Loss Functions (all 6)", test_losses)
379
 
380
 
381
  # ============================================================
 
392
  kf.initialize(init_box)
393
  assert kf.initialized
394
 
395
+ # Predict + update cycle with moving target
396
  for i in range(10):
397
  pred = kf.predict()
398
  assert len(pred) == 4, f"Prediction length: {len(pred)}"
399
 
400
+ # Simulate noisy measurement of linearly moving target
401
  noise = np.random.randn(4) * 2
402
  meas = init_box + np.array([i * 2, i * 1, 0, 0]) + noise
403
  kf.update(meas, uncertainty=1.0)
 
405
  state = kf.get_state()
406
  print(f" Final state: cx={state[0]:.1f}, cy={state[1]:.1f}, w={state[2]:.1f}, h={state[3]:.1f}")
407
  assert state[2] > 0 and state[3] > 0, "Width/height should be positive"
408
+
409
+ # Test outlier rejection (chi-squared gating)
410
+ kf.update(np.array([500.0, 500.0, 50.0, 50.0]), uncertainty=1.0) # Far outlier
411
+ state_after = kf.get_state()
412
+ # State should NOT have jumped to 500,500
413
+ assert state_after[0] < 200, f"Outlier should be rejected, cx={state_after[0]}"
414
 
415
+ test("Kalman Filter (8-state, adaptive)", test_kalman)
416
 
417
 
418
  # ============================================================
419
  # Test 11: Dataset (synthetic)
420
  # ============================================================
421
  def test_dataset():
422
+ from vil_tracker.data.dataset import SyntheticTrackingDataset, TrackingDataset
423
 
424
+ ds = SyntheticTrackingDataset(length=100)
425
  assert len(ds) == 100
426
 
427
  sample = ds[0]
 
438
  hard_sample = ds[42]
439
  print(f" Easy center: {easy_sample['boxes'][:2].tolist()}")
440
  print(f" Hard center: {hard_sample['boxes'][:2].tolist()}")
441
+
442
+ # Test backward-compatible alias
443
+ ds2 = TrackingDataset(synthetic=True, synthetic_length=50)
444
+ assert len(ds2) == 50
445
+ sample2 = ds2[0]
446
+ assert sample2['template'].shape == (3, 128, 128)
447
 
448
+ test("Dataset (synthetic + backward compat)", test_dataset)
449
 
450
 
451
  # ============================================================
452
+ # Test 12: Training Step (with temporal modulation)
453
  # ============================================================
454
  def test_training_step():
455
  from vil_tracker.models.tracker import ViLTracker, get_default_config
456
+ from vil_tracker.training.losses import CombinedTrackingLoss, MemoryContrastiveLoss
457
  from vil_tracker.models.heads import generate_heatmap
458
 
459
  config = get_default_config()
 
464
  model = ViLTracker(config)
465
  model.train()
466
  loss_fn = CombinedTrackingLoss()
467
+ contrastive_loss = MemoryContrastiveLoss()
468
  optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
469
 
470
  B = 2
 
477
  gt_size = torch.tensor([[0.2, 0.3], [0.15, 0.25]])
478
  gt_boxes = torch.tensor([[128.0, 128.0, 51.2, 76.8], [100.0, 150.0, 38.4, 64.0]])
479
 
480
+ # Forward WITH temporal modulation
481
+ pred = model(template, search, use_temporal=True)
482
  loss_dict = loss_fn(pred, gt_heatmap, gt_size, gt_boxes)
483
 
484
+ # Add contrastive loss
485
+ t_pooled = pred['template_feat'].mean(dim=1)
486
+ s_pooled = pred['search_feat'].mean(dim=1)
487
+ c_loss = contrastive_loss(t_pooled, s_pooled)
488
+ total_loss = loss_dict['total'] + 0.1 * c_loss
489
+
490
  # Backward
491
+ total_loss.backward()
492
 
493
  # Check gradients exist
494
  has_grads = sum(1 for p in model.parameters() if p.grad is not None)
495
  total_params_count = sum(1 for p in model.parameters())
496
+ print(f" Total loss: {total_loss.item():.4f} (tracking={loss_dict['total'].item():.4f}, contr={c_loss.item():.4f})")
497
  print(f" Params with gradients: {has_grads}/{total_params_count}")
498
 
499
  # Optimizer step
500
  optimizer.step()
501
  optimizer.zero_grad()
502
 
503
+ assert total_loss.item() > 0
504
  assert has_grads > 0
505
 
506
+ test("Training Step (with temporal + contrastive)", test_training_step)
507
 
508
 
509
  # ============================================================
 
530
  test("Model Summary (full depth=24)", test_model_summary)
531
 
532
 
533
+ # ============================================================
534
+ # Test 14: Online Tracker (inference pipeline)
535
+ # ============================================================
536
+ def test_online_tracker():
537
+ from vil_tracker.models.tracker import ViLTracker, get_default_config
538
+ from vil_tracker.inference.online_tracker import OnlineTracker
539
+
540
+ config = get_default_config()
541
+ config['depth'] = 2
542
+ config['tmoe_blocks'] = 0
543
+ config['film_interval'] = 2
544
+
545
+ model = ViLTracker(config)
546
+ model.eval()
547
+
548
+ tracker = OnlineTracker(model, device='cpu', template_size=128, search_size=256)
549
+
550
+ # Simulate a sequence: 480x640 frames with a moving rectangle
551
+ H, W = 480, 640
552
+ init_bbox = [200, 200, 60, 80] # [x, y, w, h]
553
+
554
+ # First frame
555
+ frame0 = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)
556
+ # Draw target
557
+ x, y, w, h = init_bbox
558
+ frame0[y:y+h, x:x+w] = [255, 0, 0] # Red rectangle
559
+
560
+ tracker.initialize(frame0, init_bbox)
561
+
562
+ # Track for 5 frames
563
+ for i in range(1, 6):
564
+ frame = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)
565
+ # Move target
566
+ nx = x + i * 5
567
+ ny = y + i * 3
568
+ frame[ny:ny+h, nx:nx+w] = [255, 0, 0]
569
+
570
+ bbox = tracker.track(frame)
571
+ assert len(bbox) == 4, f"Bbox should have 4 elements, got {len(bbox)}"
572
+ assert all(isinstance(v, (int, float, np.floating)) for v in bbox), f"Bbox values: {bbox}"
573
+ print(f" Frame {i}: predicted [{bbox[0]:.1f}, {bbox[1]:.1f}, {bbox[2]:.1f}, {bbox[3]:.1f}]")
574
+
575
+ print(f" Online tracker completed 5-frame sequence")
576
+
577
+ test("Online Tracker (inference pipeline)", test_online_tracker)
578
+
579
+
580
+ # ============================================================
581
+ # Test 15: Augmentation pipeline
582
+ # ============================================================
583
+ def test_augmentation():
584
+ from vil_tracker.data.dataset import TrackingAugmentation
585
+
586
+ aug = TrackingAugmentation(
587
+ brightness=0.2,
588
+ contrast=0.2,
589
+ horizontal_flip_prob=1.0, # Force flip to test bbox update
590
+ grayscale_prob=0.0,
591
+ blur_prob=0.0,
592
+ )
593
+
594
+ template = torch.rand(3, 128, 128)
595
+ search = torch.rand(3, 256, 256)
596
+ bbox = torch.tensor([128.0, 128.0, 50.0, 50.0]) # [cx, cy, w, h]
597
+
598
+ t_aug, s_aug, b_aug = aug(template, search, bbox)
599
+
600
+ assert t_aug.shape == (3, 128, 128), f"Aug template shape: {t_aug.shape}"
601
+ assert s_aug.shape == (3, 256, 256), f"Aug search shape: {s_aug.shape}"
602
+ assert b_aug.shape == (4,), f"Aug bbox shape: {b_aug.shape}"
603
+
604
+ # With flip_prob=1.0, cx should be flipped: new_cx = W - old_cx = 256 - 128 = 128
605
+ print(f" Original bbox: {bbox.tolist()}")
606
+ print(f" Augmented bbox: {b_aug.tolist()}")
607
+ assert abs(b_aug[0].item() - (256 - 128)) < 1.0, f"Flipped cx should be ~128, got {b_aug[0]}"
608
+
609
+ test("Augmentation pipeline", test_augmentation)
610
+
611
+
612
+ # ============================================================
613
+ # Test 16: ACL curriculum integration
614
+ # ============================================================
615
+ def test_acl_curriculum():
616
+ from vil_tracker.data.dataset import SyntheticTrackingDataset
617
+
618
+ ds = SyntheticTrackingDataset(length=100, acl_difficulty=0.0)
619
+
620
+ # Easy: targets near center
621
+ easy_offsets = []
622
+ for i in range(20):
623
+ sample = ds[i]
624
+ cx, cy = sample['boxes'][:2].tolist()
625
+ offset = ((cx - 128) ** 2 + (cy - 128) ** 2) ** 0.5
626
+ easy_offsets.append(offset)
627
+
628
+ ds.set_acl_difficulty(1.0)
629
+
630
+ hard_offsets = []
631
+ for i in range(20):
632
+ sample = ds[i]
633
+ cx, cy = sample['boxes'][:2].tolist()
634
+ offset = ((cx - 128) ** 2 + (cy - 128) ** 2) ** 0.5
635
+ hard_offsets.append(offset)
636
+
637
+ avg_easy = np.mean(easy_offsets)
638
+ avg_hard = np.mean(hard_offsets)
639
+
640
+ print(f" Avg offset (easy, d=0.0): {avg_easy:.1f} px")
641
+ print(f" Avg offset (hard, d=1.0): {avg_hard:.1f} px")
642
+ # Hard samples should have larger offsets from center on average
643
+ # (this is stochastic, so we allow some tolerance)
644
+ print(f" Hard > Easy: {avg_hard > avg_easy * 0.5}")
645
+
646
+ test("ACL curriculum integration", test_acl_curriculum)
647
+
648
+
649
  # ============================================================
650
  # Summary
651
  # ============================================================