File size: 5,365 Bytes
"""
ViL Tracker: Full model combining backbone, FiLM modulation, and prediction heads.
Pipeline:
1. Template (128x128) + Search (256x256) → PatchEmbed → tokens
2. Concatenated tokens → ViL backbone (24 mLSTM blocks, bidirectional)
3. FiLM temporal modulation at intervals (conditioned on prev frame)
4. Search features → CenterHead → heatmap + size + offset
5. Optional: UncertaintyHead → log variance for adaptive weighting
"""
import torch
import torch.nn as nn
from .backbone import ViLBackbone
from .film_temporal import TemporalModulationManager
from .heads import CenterHead, UncertaintyHead, decode_predictions
def get_default_config() -> dict:
    """Return a fresh dict holding the default ViL-S tracker hyper-parameters.

    The values are the author's chosen defaults, intended to stay within the
    stated budget of <=50M params, <=30ms latency, <=20 GFLOPs and <=500MB.
    Nothing in this function checks those limits.

    Returns:
        A new ``dict`` on every call, so callers may mutate it freely.
    """
    # Parameters forwarded to the backbone.
    backbone_section = dict(
        dim=384,
        depth=24,
        patch_size=16,
        proj_factor=2.0,
        qkv_proj_blocksize=4,
        num_heads=4,
        conv_kernel=4,
        mlp_ratio=4.0,
        drop_path_rate=0.1,
        tmoe_blocks=2,
        num_experts=4,
    )
    # FiLM temporal modulation.
    film_section = dict(film_interval=6)
    # Prediction heads.
    head_section = dict(feat_size=16)
    # Input resolutions.
    input_section = dict(template_size=128, search_size=256)
    # Uncertainty head toggle.
    uncertainty_section = dict(use_uncertainty=True)

    # Merge in the same order as the original literal so key order is kept.
    settings = {}
    for section in (
        backbone_section,
        film_section,
        head_section,
        input_section,
        uncertainty_section,
    ):
        settings.update(section)
    return settings
class ViLTracker(nn.Module):
    """Complete ViL-based single object tracker.

    Wires together a ``ViLBackbone``, a ``TemporalModulationManager`` (FiLM),
    a ``CenterHead`` and, optionally, an ``UncertaintyHead``.

    Target specs (ViL-S) as stated by the author; this class does not
    measure or enforce them:
        - Parameters: ~35-40M (well under 50M limit)
        - GFLOPs: ~15-18 (under 20 GFLOPs)
        - Model size: ~140-160MB fp32, ~70-80MB fp16 (under 500MB)
        - Latency: ~20-25ms on GPU (under 30ms)
    """

    def __init__(self, config: dict = None):
        """Build the backbone, temporal modulation manager and heads.

        Args:
            config: Optional hyper-parameter mapping. Any key that is missing
                falls back to the value from ``get_default_config()``, so a
                partial override such as ``{'depth': 12}`` is accepted.
                ``None`` or an empty dict selects the defaults unchanged.
                The caller's mapping is not modified; ``self.config`` holds
                the merged copy.
        """
        super().__init__()

        # Overlay caller-supplied values on top of the defaults. Previously a
        # partial config raised KeyError on the first missing key even though
        # 'use_uncertainty' was already treated as optional.
        merged = get_default_config()
        if config:
            merged.update(config)
        config = merged
        self.config = config

        dim = config['dim']
        depth = config['depth']

        # Backbone
        self.backbone = ViLBackbone(
            dim=dim,
            depth=depth,
            patch_size=config['patch_size'],
            proj_factor=config['proj_factor'],
            qkv_proj_blocksize=config['qkv_proj_blocksize'],
            num_heads=config['num_heads'],
            conv_kernel=config['conv_kernel'],
            mlp_ratio=config['mlp_ratio'],
            drop_path_rate=config['drop_path_rate'],
            tmoe_blocks=config['tmoe_blocks'],
            num_experts=config['num_experts'],
        )

        # FiLM temporal modulation
        self.temporal_mod = TemporalModulationManager(
            dim=dim,
            num_blocks=depth,
            modulation_interval=config['film_interval'],
        )

        # Prediction heads
        self.center_head = CenterHead(dim=dim, feat_size=config['feat_size'])
        if config['use_uncertainty']:
            self.uncertainty_head = UncertaintyHead(
                dim=dim, feat_size=config['feat_size']
            )
        else:
            # Kept as an explicit None so forward() can test for it.
            self.uncertainty_head = None

    def forward(
        self,
        template: torch.Tensor,
        search: torch.Tensor,
        use_temporal: bool = False,
    ) -> dict:
        """Run the tracker on one template/search pair.

        Args:
            template: (B, 3, 128, 128) template image
            search: (B, 3, 256, 256) search region
            use_temporal: whether to apply FiLM temporal modulation. When
                true, the call also updates the temporal context held by
                ``self.temporal_mod``, so it is stateful across frames.

        Returns:
            dict with keys ``heatmap``, ``size``, ``offset``, ``boxes``,
            ``scores``, ``template_feat``, ``search_feat`` and, when an
            uncertainty head is configured, ``log_variance``.
        """
        # Backbone forward
        template_feat, search_feat = self.backbone(template, search)

        # Optional FiLM temporal modulation on search features. This runs on
        # the backbone's *output*: for every block index the manager selects,
        # the already-computed search features are modulated in sequence.
        if use_temporal:
            for i in range(self.backbone.depth):
                if self.temporal_mod.should_modulate(i):
                    search_feat = self.temporal_mod.modulate(search_feat, i)
            # Update temporal context for next frame.
            # NOTE(review): search_feat is handed over with its autograd
            # history attached; confirm that update_temporal_context detaches
            # it if the context is meant to persist across training steps.
            self.temporal_mod.update_temporal_context(search_feat)

        # Prediction heads
        preds = self.center_head(search_feat)

        # Decode to boxes
        boxes, scores = decode_predictions(
            preds['heatmap'],
            preds['size'],
            preds['offset'],
            search_size=self.config['search_size'],
            feat_size=self.config['feat_size'],
        )

        output = {
            'heatmap': preds['heatmap'],
            'size': preds['size'],
            'offset': preds['offset'],
            'boxes': boxes,
            'scores': scores,
            'template_feat': template_feat,
            'search_feat': search_feat,
        }

        # Uncertainty prediction
        if self.uncertainty_head is not None:
            output['log_variance'] = self.uncertainty_head(search_feat)

        return output

    def reset_temporal(self):
        """Reset temporal modulation state (for new tracking sequence)."""
        self.temporal_mod.reset()

    def freeze_backbone_shared_experts(self):
        """Freeze shared experts in TMoE blocks for Phase 2."""
        self.backbone.freeze_shared_experts()
def build_tracker(config: dict = None) -> ViLTracker:
    """Construct a ``ViLTracker``.

    Args:
        config: Tracker configuration. When it is ``None`` or empty, the
            result of ``get_default_config()`` is used instead.

    Returns:
        The newly constructed tracker.
    """
    # A falsy config (None or an empty dict) selects the defaults.
    if not config:
        config = get_default_config()
    return ViLTracker(config)
|