omar-ah committed on
Commit
4ba026e
·
verified ·
1 Parent(s): 1bf192e

Sequence training: pairs→K-frame clips, mLSTM memory carries across frames

Browse files
Files changed (1) hide show
  1. test_all.py +84 -68
test_all.py CHANGED
@@ -294,31 +294,30 @@ def test_full_tracker_small():
294
  params = count_params(tracker)
295
  print(f" Tracker (depth=4) params: {params:,} ({params/1e6:.3f}M)")
296
 
297
- template = torch.randn(2, 3, 128, 128)
298
- search = torch.randn(2, 3, 256, 256)
299
-
300
- # Test without temporal
301
- output = tracker(template, search, use_temporal=False)
302
- assert output['heatmap'].shape == (2, 1, 16, 16)
303
- assert output['boxes'].shape == (2, 4)
304
- assert output['scores'].shape == (2,)
305
- assert 'log_variance' in output
306
-
307
- # Test with temporal (first frame: no context)
308
- output_t1 = tracker(template, search, use_temporal=True)
309
- assert output_t1['boxes'].shape == (2, 4)
310
 
311
- # Second frame: temporal context available
312
- output_t2 = tracker(template, search, use_temporal=True)
313
- assert output_t2['boxes'].shape == (2, 4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
- # Reset temporal
316
  tracker.reset_temporal()
317
-
318
- print(f" Predicted boxes: {output['boxes'][0].tolist()}")
319
- print(f" Scores: {output['scores'].tolist()}")
320
 
321
- test("Full Tracker (depth=4, with temporal)", test_full_tracker_small)
322
 
323
 
324
  # ============================================================
@@ -438,29 +437,34 @@ test("Kalman Filter (8-state, adaptive)", test_kalman)
438
  def test_dataset():
439
  from vil_tracker.data.dataset import SyntheticTrackingDataset, TrackingDataset
440
 
441
- ds = SyntheticTrackingDataset(length=100)
442
  assert len(ds) == 100
443
 
444
  sample = ds[0]
445
  assert sample['template'].shape == (3, 128, 128), f"Template shape: {sample['template'].shape}"
446
- assert sample['search'].shape == (3, 256, 256), f"Search shape: {sample['search'].shape}"
447
- assert sample['heatmap'].shape == (1, 16, 16), f"Heatmap shape: {sample['heatmap'].shape}"
448
- assert sample['size'].shape == (2,), f"Size shape: {sample['size'].shape}"
449
- assert sample['boxes'].shape == (4,), f"Boxes shape: {sample['boxes'].shape}"
450
 
451
- # Check ACL difficulty changes output
 
 
 
 
 
452
  ds.set_acl_difficulty(0.0)
453
  easy_sample = ds[42]
454
  ds.set_acl_difficulty(1.0)
455
  hard_sample = ds[42]
456
- print(f" Easy center: {easy_sample['boxes'][:2].tolist()}")
457
- print(f" Hard center: {hard_sample['boxes'][:2].tolist()}")
458
 
459
  # Test backward-compatible alias
460
- ds2 = TrackingDataset(synthetic=True, synthetic_length=50)
461
  assert len(ds2) == 50
462
  sample2 = ds2[0]
463
- assert sample2['template'].shape == (3, 128, 128)
464
 
465
  test("Dataset (synthetic + backward compat)", test_dataset)
466
 
@@ -484,43 +488,59 @@ def test_training_step():
484
  contrastive_loss = MemoryContrastiveLoss()
485
  optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
486
 
487
- B = 2
488
  template = torch.randn(B, 3, 128, 128)
489
- search = torch.randn(B, 3, 256, 256)
490
-
491
- # GT targets
492
- gt_center = torch.tensor([[128.0, 128.0], [100.0, 150.0]])
493
- gt_heatmap = generate_heatmap(gt_center, feat_size=16, search_size=256)
494
- gt_size = torch.tensor([[0.2, 0.3], [0.15, 0.25]])
495
- gt_boxes = torch.tensor([[128.0, 128.0, 51.2, 76.8], [100.0, 150.0, 38.4, 64.0]])
496
-
497
- # Forward WITH temporal modulation
498
- pred = model(template, search, use_temporal=True)
499
- loss_dict = loss_fn(pred, gt_heatmap, gt_size, gt_boxes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
 
501
  # Add contrastive loss
502
  t_pooled = pred['template_feat'].mean(dim=1)
503
- s_pooled = pred['search_feat'].mean(dim=1)
504
  c_loss = contrastive_loss(t_pooled, s_pooled)
505
- total_loss = loss_dict['total'] + 0.1 * c_loss
506
 
507
  # Backward
508
  total_loss.backward()
509
 
510
- # Check gradients exist
511
  has_grads = sum(1 for p in model.parameters() if p.grad is not None)
512
  total_params_count = sum(1 for p in model.parameters())
513
- print(f" Total loss: {total_loss.item():.4f} (tracking={loss_dict['total'].item():.4f}, contr={c_loss.item():.4f})")
514
  print(f" Params with gradients: {has_grads}/{total_params_count}")
515
 
516
- # Optimizer step
517
  optimizer.step()
518
  optimizer.zero_grad()
519
 
520
  assert total_loss.item() > 0
521
  assert has_grads > 0
522
 
523
- test("Training Step (with temporal + contrastive)", test_training_step)
524
 
525
 
526
  # ============================================================
@@ -632,33 +652,29 @@ test("Augmentation pipeline", test_augmentation)
632
  def test_acl_curriculum():
633
  from vil_tracker.data.dataset import SyntheticTrackingDataset
634
 
635
- ds = SyntheticTrackingDataset(length=100, acl_difficulty=0.0)
636
 
637
- # Easy: targets near center
638
- easy_offsets = []
639
  for i in range(20):
640
  sample = ds[i]
641
- cx, cy = sample['boxes'][:2].tolist()
642
- offset = ((cx - 128) ** 2 + (cy - 128) ** 2) ** 0.5
643
- easy_offsets.append(offset)
644
 
645
  ds.set_acl_difficulty(1.0)
646
 
647
- hard_offsets = []
648
  for i in range(20):
649
  sample = ds[i]
650
- cx, cy = sample['boxes'][:2].tolist()
651
- offset = ((cx - 128) ** 2 + (cy - 128) ** 2) ** 0.5
652
- hard_offsets.append(offset)
653
-
654
- avg_easy = np.mean(easy_offsets)
655
- avg_hard = np.mean(hard_offsets)
656
-
657
- print(f" Avg offset (easy, d=0.0): {avg_easy:.1f} px")
658
- print(f" Avg offset (hard, d=1.0): {avg_hard:.1f} px")
659
- # Hard samples should have larger offsets from center on average
660
- # (this is stochastic, so we allow some tolerance)
661
- print(f" Hard > Easy: {avg_hard > avg_easy * 0.5}")
662
 
663
  test("ACL curriculum integration", test_acl_curriculum)
664
 
 
294
  params = count_params(tracker)
295
  print(f" Tracker (depth=4) params: {params:,} ({params/1e6:.3f}M)")
296
 
297
+ B, K = 2, 3
298
+ template = torch.randn(B, 3, 128, 128)
 
 
 
 
 
 
 
 
 
 
 
299
 
300
+ # Test single-frame (backward compat)
301
+ search_single = torch.randn(B, 3, 256, 256)
302
+ output_s = tracker(template, search_single, use_temporal=False)
303
+ assert output_s['heatmap'].shape == (B, 1, 16, 16), f"Single heatmap: {output_s['heatmap'].shape}"
304
+ assert output_s['boxes'].shape == (B, 4), f"Single boxes: {output_s['boxes'].shape}"
305
+ assert output_s['scores'].shape == (B,), f"Single scores: {output_s['scores'].shape}"
306
+ print(f" Single-frame: boxes={output_s['boxes'][0].tolist()}")
307
+
308
+ # Test multi-frame sequence
309
+ searches = torch.randn(B, K, 3, 256, 256)
310
+ output_m = tracker(template, searches, use_temporal=True)
311
+ assert output_m['heatmap'].shape == (B, K, 1, 16, 16), f"Multi heatmap: {output_m['heatmap'].shape}"
312
+ assert output_m['boxes'].shape == (B, K, 4), f"Multi boxes: {output_m['boxes'].shape}"
313
+ assert output_m['scores'].shape == (B, K), f"Multi scores: {output_m['scores'].shape}"
314
+ assert output_m['search_feats'].shape == (B, K, 256, 384), f"Multi feats: {output_m['search_feats'].shape}"
315
+ print(f" Multi-frame (K={K}): frame 0 box={output_m['boxes'][0,0].tolist()}")
316
+ print(f" frame 2 box={output_m['boxes'][0,2].tolist()}")
317
 
 
318
  tracker.reset_temporal()
 
 
 
319
 
320
+ test("Full Tracker (single + multi-frame)", test_full_tracker_small)
321
 
322
 
323
  # ============================================================
 
437
  def test_dataset():
438
  from vil_tracker.data.dataset import SyntheticTrackingDataset, TrackingDataset
439
 
440
+ ds = SyntheticTrackingDataset(length=100, clip_length=3)
441
  assert len(ds) == 100
442
 
443
  sample = ds[0]
444
  assert sample['template'].shape == (3, 128, 128), f"Template shape: {sample['template'].shape}"
445
+ assert sample['searches'].shape == (3, 3, 256, 256), f"Searches shape: {sample['searches'].shape}"
446
+ assert sample['heatmaps'].shape == (3, 1, 16, 16), f"Heatmaps shape: {sample['heatmaps'].shape}"
447
+ assert sample['sizes'].shape == (3, 2), f"Sizes shape: {sample['sizes'].shape}"
448
+ assert sample['boxes'].shape == (3, 4), f"Boxes shape: {sample['boxes'].shape}"
449
 
450
+ # Verify target moves across frames (not static)
451
+ cx_f0 = sample['boxes'][0, 0].item()
452
+ cx_f2 = sample['boxes'][2, 0].item()
453
+ print(f" Frame 0 cx: {cx_f0:.1f}, Frame 2 cx: {cx_f2:.1f} (moving target)")
454
+
455
+ # Check ACL difficulty changes motion magnitude
456
  ds.set_acl_difficulty(0.0)
457
  easy_sample = ds[42]
458
  ds.set_acl_difficulty(1.0)
459
  hard_sample = ds[42]
460
+ print(f" Easy frame spread: {(easy_sample['boxes'][:, 0].max() - easy_sample['boxes'][:, 0].min()).item():.1f} px")
461
+ print(f" Hard frame spread: {(hard_sample['boxes'][:, 0].max() - hard_sample['boxes'][:, 0].min()).item():.1f} px")
462
 
463
  # Test backward-compatible alias
464
+ ds2 = TrackingDataset(synthetic=True, synthetic_length=50, clip_length=3)
465
  assert len(ds2) == 50
466
  sample2 = ds2[0]
467
+ assert sample2['searches'].shape[0] == 3, "Clip length should be 3"
468
 
469
  test("Dataset (synthetic + backward compat)", test_dataset)
470
 
 
488
  contrastive_loss = MemoryContrastiveLoss()
489
  optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
490
 
491
+ B, K = 2, 3
492
  template = torch.randn(B, 3, 128, 128)
493
+ searches = torch.randn(B, K, 3, 256, 256)
494
+
495
+ # GT targets for K frames
496
+ gt_heatmaps = torch.zeros(B, K, 1, 16, 16)
497
+ gt_heatmaps[:, :, :, 8, 8] = 1.0 # center
498
+ gt_sizes = torch.tensor([[[0.2, 0.3]] * K] * B)
499
+ gt_boxes = torch.tensor([[[128.0, 128.0, 51.2, 76.8]] * K] * B)
500
+
501
+ # Forward WITH temporal modulation, multi-frame
502
+ pred = model(template, searches, use_temporal=True)
503
+
504
+ assert pred['heatmap'].shape == (B, K, 1, 16, 16), f"Heatmap shape: {pred['heatmap'].shape}"
505
+ assert pred['boxes'].shape == (B, K, 4), f"Boxes shape: {pred['boxes'].shape}"
506
+ assert pred['scores'].shape == (B, K), f"Scores shape: {pred['scores'].shape}"
507
+ assert pred['search_feats'].shape == (B, K, 256, 384), f"Search feats: {pred['search_feats'].shape}"
508
+
509
+ # Accumulate loss over K frames
510
+ total_loss = torch.tensor(0.0)
511
+ for k in range(K):
512
+ pred_k = {
513
+ 'heatmap': pred['heatmap'][:, k],
514
+ 'size': pred['size'][:, k],
515
+ 'boxes': pred['boxes'][:, k],
516
+ }
517
+ if 'log_variance' in pred:
518
+ pred_k['log_variance'] = pred['log_variance'][:, k]
519
+ loss_dict = loss_fn(pred_k, gt_heatmaps[:, k], gt_sizes[:, k], gt_boxes[:, k])
520
+ total_loss = total_loss + loss_dict['total']
521
+ total_loss = total_loss / K
522
 
523
  # Add contrastive loss
524
  t_pooled = pred['template_feat'].mean(dim=1)
525
+ s_pooled = pred['search_feats'][:, -1].mean(dim=1)
526
  c_loss = contrastive_loss(t_pooled, s_pooled)
527
+ total_loss = total_loss + 0.1 * c_loss
528
 
529
  # Backward
530
  total_loss.backward()
531
 
 
532
  has_grads = sum(1 for p in model.parameters() if p.grad is not None)
533
  total_params_count = sum(1 for p in model.parameters())
534
+ print(f" Total loss: {total_loss.item():.4f} (K={K} frames, contr={c_loss.item():.4f})")
535
  print(f" Params with gradients: {has_grads}/{total_params_count}")
536
 
 
537
  optimizer.step()
538
  optimizer.zero_grad()
539
 
540
  assert total_loss.item() > 0
541
  assert has_grads > 0
542
 
543
+ test("Training Step (K=3 sequence + contrastive)", test_training_step)
544
 
545
 
546
  # ============================================================
 
652
  def test_acl_curriculum():
653
  from vil_tracker.data.dataset import SyntheticTrackingDataset
654
 
655
+ ds = SyntheticTrackingDataset(length=100, acl_difficulty=0.0, clip_length=3)
656
 
657
+ # Easy: targets barely move
658
+ easy_spreads = []
659
  for i in range(20):
660
  sample = ds[i]
661
+ spread = (sample['boxes'][:, 0].max() - sample['boxes'][:, 0].min()).item()
662
+ easy_spreads.append(spread)
 
663
 
664
  ds.set_acl_difficulty(1.0)
665
 
666
+ hard_spreads = []
667
  for i in range(20):
668
  sample = ds[i]
669
+ spread = (sample['boxes'][:, 0].max() - sample['boxes'][:, 0].min()).item()
670
+ hard_spreads.append(spread)
671
+
672
+ avg_easy = np.mean(easy_spreads)
673
+ avg_hard = np.mean(hard_spreads)
674
+
675
+ print(f" Avg cx spread (easy, d=0.0): {avg_easy:.1f} px")
676
+ print(f" Avg cx spread (hard, d=1.0): {avg_hard:.1f} px")
677
+ print(f" Hard > Easy: {avg_hard > avg_easy}")
 
 
 
678
 
679
  test("ACL curriculum integration", test_acl_curriculum)
680