Sequence training: frame pairs → K-frame clips, with mLSTM memory carried across frames
Browse files
- test_all.py +84 -68
test_all.py
CHANGED
|
@@ -294,31 +294,30 @@ def test_full_tracker_small():
|
|
| 294 |
params = count_params(tracker)
|
| 295 |
print(f" Tracker (depth=4) params: {params:,} ({params/1e6:.3f}M)")
|
| 296 |
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
# Test without temporal
|
| 301 |
-
output = tracker(template, search, use_temporal=False)
|
| 302 |
-
assert output['heatmap'].shape == (2, 1, 16, 16)
|
| 303 |
-
assert output['boxes'].shape == (2, 4)
|
| 304 |
-
assert output['scores'].shape == (2,)
|
| 305 |
-
assert 'log_variance' in output
|
| 306 |
-
|
| 307 |
-
# Test with temporal (first frame: no context)
|
| 308 |
-
output_t1 = tracker(template, search, use_temporal=True)
|
| 309 |
-
assert output_t1['boxes'].shape == (2, 4)
|
| 310 |
|
| 311 |
-
#
|
| 312 |
-
|
| 313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
-
# Reset temporal
|
| 316 |
tracker.reset_temporal()
|
| 317 |
-
|
| 318 |
-
print(f" Predicted boxes: {output['boxes'][0].tolist()}")
|
| 319 |
-
print(f" Scores: {output['scores'].tolist()}")
|
| 320 |
|
| 321 |
-
test("Full Tracker (
|
| 322 |
|
| 323 |
|
| 324 |
# ============================================================
|
|
@@ -438,29 +437,34 @@ test("Kalman Filter (8-state, adaptive)", test_kalman)
|
|
| 438 |
def test_dataset():
|
| 439 |
from vil_tracker.data.dataset import SyntheticTrackingDataset, TrackingDataset
|
| 440 |
|
| 441 |
-
ds = SyntheticTrackingDataset(length=100)
|
| 442 |
assert len(ds) == 100
|
| 443 |
|
| 444 |
sample = ds[0]
|
| 445 |
assert sample['template'].shape == (3, 128, 128), f"Template shape: {sample['template'].shape}"
|
| 446 |
-
assert sample['
|
| 447 |
-
assert sample['
|
| 448 |
-
assert sample['
|
| 449 |
-
assert sample['boxes'].shape == (
|
| 450 |
|
| 451 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
ds.set_acl_difficulty(0.0)
|
| 453 |
easy_sample = ds[42]
|
| 454 |
ds.set_acl_difficulty(1.0)
|
| 455 |
hard_sample = ds[42]
|
| 456 |
-
print(f" Easy
|
| 457 |
-
print(f" Hard
|
| 458 |
|
| 459 |
# Test backward-compatible alias
|
| 460 |
-
ds2 = TrackingDataset(synthetic=True, synthetic_length=50)
|
| 461 |
assert len(ds2) == 50
|
| 462 |
sample2 = ds2[0]
|
| 463 |
-
assert sample2['
|
| 464 |
|
| 465 |
test("Dataset (synthetic + backward compat)", test_dataset)
|
| 466 |
|
|
@@ -484,43 +488,59 @@ def test_training_step():
|
|
| 484 |
contrastive_loss = MemoryContrastiveLoss()
|
| 485 |
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
| 486 |
|
| 487 |
-
B = 2
|
| 488 |
template = torch.randn(B, 3, 128, 128)
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
# GT targets
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
gt_boxes = torch.tensor([[128.0, 128.0, 51.2, 76.8]
|
| 496 |
-
|
| 497 |
-
# Forward WITH temporal modulation
|
| 498 |
-
pred = model(template,
|
| 499 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
|
| 501 |
# Add contrastive loss
|
| 502 |
t_pooled = pred['template_feat'].mean(dim=1)
|
| 503 |
-
s_pooled = pred['
|
| 504 |
c_loss = contrastive_loss(t_pooled, s_pooled)
|
| 505 |
-
total_loss =
|
| 506 |
|
| 507 |
# Backward
|
| 508 |
total_loss.backward()
|
| 509 |
|
| 510 |
-
# Check gradients exist
|
| 511 |
has_grads = sum(1 for p in model.parameters() if p.grad is not None)
|
| 512 |
total_params_count = sum(1 for p in model.parameters())
|
| 513 |
-
print(f" Total loss: {total_loss.item():.4f} (
|
| 514 |
print(f" Params with gradients: {has_grads}/{total_params_count}")
|
| 515 |
|
| 516 |
-
# Optimizer step
|
| 517 |
optimizer.step()
|
| 518 |
optimizer.zero_grad()
|
| 519 |
|
| 520 |
assert total_loss.item() > 0
|
| 521 |
assert has_grads > 0
|
| 522 |
|
| 523 |
-
test("Training Step (
|
| 524 |
|
| 525 |
|
| 526 |
# ============================================================
|
|
@@ -632,33 +652,29 @@ test("Augmentation pipeline", test_augmentation)
|
|
| 632 |
def test_acl_curriculum():
|
| 633 |
from vil_tracker.data.dataset import SyntheticTrackingDataset
|
| 634 |
|
| 635 |
-
ds = SyntheticTrackingDataset(length=100, acl_difficulty=0.0)
|
| 636 |
|
| 637 |
-
# Easy: targets
|
| 638 |
-
|
| 639 |
for i in range(20):
|
| 640 |
sample = ds[i]
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
easy_offsets.append(offset)
|
| 644 |
|
| 645 |
ds.set_acl_difficulty(1.0)
|
| 646 |
|
| 647 |
-
|
| 648 |
for i in range(20):
|
| 649 |
sample = ds[i]
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
print(f" Avg
|
| 658 |
-
print(f"
|
| 659 |
-
# Hard samples should have larger offsets from center on average
|
| 660 |
-
# (this is stochastic, so we allow some tolerance)
|
| 661 |
-
print(f" Hard > Easy: {avg_hard > avg_easy * 0.5}")
|
| 662 |
|
| 663 |
test("ACL curriculum integration", test_acl_curriculum)
|
| 664 |
|
|
|
|
| 294 |
params = count_params(tracker)
|
| 295 |
print(f" Tracker (depth=4) params: {params:,} ({params/1e6:.3f}M)")
|
| 296 |
|
| 297 |
+
B, K = 2, 3
|
| 298 |
+
template = torch.randn(B, 3, 128, 128)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
+
# Test single-frame (backward compat)
|
| 301 |
+
search_single = torch.randn(B, 3, 256, 256)
|
| 302 |
+
output_s = tracker(template, search_single, use_temporal=False)
|
| 303 |
+
assert output_s['heatmap'].shape == (B, 1, 16, 16), f"Single heatmap: {output_s['heatmap'].shape}"
|
| 304 |
+
assert output_s['boxes'].shape == (B, 4), f"Single boxes: {output_s['boxes'].shape}"
|
| 305 |
+
assert output_s['scores'].shape == (B,), f"Single scores: {output_s['scores'].shape}"
|
| 306 |
+
print(f" Single-frame: boxes={output_s['boxes'][0].tolist()}")
|
| 307 |
+
|
| 308 |
+
# Test multi-frame sequence
|
| 309 |
+
searches = torch.randn(B, K, 3, 256, 256)
|
| 310 |
+
output_m = tracker(template, searches, use_temporal=True)
|
| 311 |
+
assert output_m['heatmap'].shape == (B, K, 1, 16, 16), f"Multi heatmap: {output_m['heatmap'].shape}"
|
| 312 |
+
assert output_m['boxes'].shape == (B, K, 4), f"Multi boxes: {output_m['boxes'].shape}"
|
| 313 |
+
assert output_m['scores'].shape == (B, K), f"Multi scores: {output_m['scores'].shape}"
|
| 314 |
+
assert output_m['search_feats'].shape == (B, K, 256, 384), f"Multi feats: {output_m['search_feats'].shape}"
|
| 315 |
+
print(f" Multi-frame (K={K}): frame 0 box={output_m['boxes'][0,0].tolist()}")
|
| 316 |
+
print(f" frame 2 box={output_m['boxes'][0,2].tolist()}")
|
| 317 |
|
|
|
|
| 318 |
tracker.reset_temporal()
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
+
test("Full Tracker (single + multi-frame)", test_full_tracker_small)
|
| 321 |
|
| 322 |
|
| 323 |
# ============================================================
|
|
|
|
| 437 |
def test_dataset():
|
| 438 |
from vil_tracker.data.dataset import SyntheticTrackingDataset, TrackingDataset
|
| 439 |
|
| 440 |
+
ds = SyntheticTrackingDataset(length=100, clip_length=3)
|
| 441 |
assert len(ds) == 100
|
| 442 |
|
| 443 |
sample = ds[0]
|
| 444 |
assert sample['template'].shape == (3, 128, 128), f"Template shape: {sample['template'].shape}"
|
| 445 |
+
assert sample['searches'].shape == (3, 3, 256, 256), f"Searches shape: {sample['searches'].shape}"
|
| 446 |
+
assert sample['heatmaps'].shape == (3, 1, 16, 16), f"Heatmaps shape: {sample['heatmaps'].shape}"
|
| 447 |
+
assert sample['sizes'].shape == (3, 2), f"Sizes shape: {sample['sizes'].shape}"
|
| 448 |
+
assert sample['boxes'].shape == (3, 4), f"Boxes shape: {sample['boxes'].shape}"
|
| 449 |
|
| 450 |
+
# Verify target moves across frames (not static)
|
| 451 |
+
cx_f0 = sample['boxes'][0, 0].item()
|
| 452 |
+
cx_f2 = sample['boxes'][2, 0].item()
|
| 453 |
+
print(f" Frame 0 cx: {cx_f0:.1f}, Frame 2 cx: {cx_f2:.1f} (moving target)")
|
| 454 |
+
|
| 455 |
+
# Check ACL difficulty changes motion magnitude
|
| 456 |
ds.set_acl_difficulty(0.0)
|
| 457 |
easy_sample = ds[42]
|
| 458 |
ds.set_acl_difficulty(1.0)
|
| 459 |
hard_sample = ds[42]
|
| 460 |
+
print(f" Easy frame spread: {(easy_sample['boxes'][:, 0].max() - easy_sample['boxes'][:, 0].min()).item():.1f} px")
|
| 461 |
+
print(f" Hard frame spread: {(hard_sample['boxes'][:, 0].max() - hard_sample['boxes'][:, 0].min()).item():.1f} px")
|
| 462 |
|
| 463 |
# Test backward-compatible alias
|
| 464 |
+
ds2 = TrackingDataset(synthetic=True, synthetic_length=50, clip_length=3)
|
| 465 |
assert len(ds2) == 50
|
| 466 |
sample2 = ds2[0]
|
| 467 |
+
assert sample2['searches'].shape[0] == 3, "Clip length should be 3"
|
| 468 |
|
| 469 |
test("Dataset (synthetic + backward compat)", test_dataset)
|
| 470 |
|
|
|
|
| 488 |
contrastive_loss = MemoryContrastiveLoss()
|
| 489 |
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
| 490 |
|
| 491 |
+
B, K = 2, 3
|
| 492 |
template = torch.randn(B, 3, 128, 128)
|
| 493 |
+
searches = torch.randn(B, K, 3, 256, 256)
|
| 494 |
+
|
| 495 |
+
# GT targets for K frames
|
| 496 |
+
gt_heatmaps = torch.zeros(B, K, 1, 16, 16)
|
| 497 |
+
gt_heatmaps[:, :, :, 8, 8] = 1.0 # center
|
| 498 |
+
gt_sizes = torch.tensor([[[0.2, 0.3]] * K] * B)
|
| 499 |
+
gt_boxes = torch.tensor([[[128.0, 128.0, 51.2, 76.8]] * K] * B)
|
| 500 |
+
|
| 501 |
+
# Forward WITH temporal modulation, multi-frame
|
| 502 |
+
pred = model(template, searches, use_temporal=True)
|
| 503 |
+
|
| 504 |
+
assert pred['heatmap'].shape == (B, K, 1, 16, 16), f"Heatmap shape: {pred['heatmap'].shape}"
|
| 505 |
+
assert pred['boxes'].shape == (B, K, 4), f"Boxes shape: {pred['boxes'].shape}"
|
| 506 |
+
assert pred['scores'].shape == (B, K), f"Scores shape: {pred['scores'].shape}"
|
| 507 |
+
assert pred['search_feats'].shape == (B, K, 256, 384), f"Search feats: {pred['search_feats'].shape}"
|
| 508 |
+
|
| 509 |
+
# Accumulate loss over K frames
|
| 510 |
+
total_loss = torch.tensor(0.0)
|
| 511 |
+
for k in range(K):
|
| 512 |
+
pred_k = {
|
| 513 |
+
'heatmap': pred['heatmap'][:, k],
|
| 514 |
+
'size': pred['size'][:, k],
|
| 515 |
+
'boxes': pred['boxes'][:, k],
|
| 516 |
+
}
|
| 517 |
+
if 'log_variance' in pred:
|
| 518 |
+
pred_k['log_variance'] = pred['log_variance'][:, k]
|
| 519 |
+
loss_dict = loss_fn(pred_k, gt_heatmaps[:, k], gt_sizes[:, k], gt_boxes[:, k])
|
| 520 |
+
total_loss = total_loss + loss_dict['total']
|
| 521 |
+
total_loss = total_loss / K
|
| 522 |
|
| 523 |
# Add contrastive loss
|
| 524 |
t_pooled = pred['template_feat'].mean(dim=1)
|
| 525 |
+
s_pooled = pred['search_feats'][:, -1].mean(dim=1)
|
| 526 |
c_loss = contrastive_loss(t_pooled, s_pooled)
|
| 527 |
+
total_loss = total_loss + 0.1 * c_loss
|
| 528 |
|
| 529 |
# Backward
|
| 530 |
total_loss.backward()
|
| 531 |
|
|
|
|
| 532 |
has_grads = sum(1 for p in model.parameters() if p.grad is not None)
|
| 533 |
total_params_count = sum(1 for p in model.parameters())
|
| 534 |
+
print(f" Total loss: {total_loss.item():.4f} (K={K} frames, contr={c_loss.item():.4f})")
|
| 535 |
print(f" Params with gradients: {has_grads}/{total_params_count}")
|
| 536 |
|
|
|
|
| 537 |
optimizer.step()
|
| 538 |
optimizer.zero_grad()
|
| 539 |
|
| 540 |
assert total_loss.item() > 0
|
| 541 |
assert has_grads > 0
|
| 542 |
|
| 543 |
+
test("Training Step (K=3 sequence + contrastive)", test_training_step)
|
| 544 |
|
| 545 |
|
| 546 |
# ============================================================
|
|
|
|
| 652 |
def test_acl_curriculum():
|
| 653 |
from vil_tracker.data.dataset import SyntheticTrackingDataset
|
| 654 |
|
| 655 |
+
ds = SyntheticTrackingDataset(length=100, acl_difficulty=0.0, clip_length=3)
|
| 656 |
|
| 657 |
+
# Easy: targets barely move
|
| 658 |
+
easy_spreads = []
|
| 659 |
for i in range(20):
|
| 660 |
sample = ds[i]
|
| 661 |
+
spread = (sample['boxes'][:, 0].max() - sample['boxes'][:, 0].min()).item()
|
| 662 |
+
easy_spreads.append(spread)
|
|
|
|
| 663 |
|
| 664 |
ds.set_acl_difficulty(1.0)
|
| 665 |
|
| 666 |
+
hard_spreads = []
|
| 667 |
for i in range(20):
|
| 668 |
sample = ds[i]
|
| 669 |
+
spread = (sample['boxes'][:, 0].max() - sample['boxes'][:, 0].min()).item()
|
| 670 |
+
hard_spreads.append(spread)
|
| 671 |
+
|
| 672 |
+
avg_easy = np.mean(easy_spreads)
|
| 673 |
+
avg_hard = np.mean(hard_spreads)
|
| 674 |
+
|
| 675 |
+
print(f" Avg cx spread (easy, d=0.0): {avg_easy:.1f} px")
|
| 676 |
+
print(f" Avg cx spread (hard, d=1.0): {avg_hard:.1f} px")
|
| 677 |
+
print(f" Hard > Easy: {avg_hard > avg_easy}")
|
|
|
|
|
|
|
|
|
|
| 678 |
|
| 679 |
test("ACL curriculum integration", test_acl_curriculum)
|
| 680 |
|