File size: 7,539 Bytes
b3b0529 8237685 b3b0529 3e094bb b3b0529 8237685 b3b0529 8237685 b3b0529 8237685 b3b0529 be1f14e b3b0529 be1f14e b3b0529 be1f14e b3b0529 be1f14e b3b0529 be1f14e 8237685 be1f14e b3b0529 be1f14e b3b0529 be1f14e b3b0529 be1f14e b3b0529 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 | """
ViL Tracker: Full model combining backbone, FiLM modulation, and prediction heads.
Pipeline:
1. Template (128x128) + Search (256x256) -> PatchEmbed -> tokens
2. Concatenated tokens -> ViL backbone (24 mLSTM blocks, bidirectional)
3. FiLM temporal modulation integrated BETWEEN backbone blocks
4. Search features -> CenterHead -> heatmap + size + offset
5. Optional: UncertaintyHead -> log variance for adaptive weighting
"""
import torch
import torch.nn as nn
from .backbone import ViLBackbone
from .film_temporal import TemporalModulationManager
from .heads import CenterHead, UncertaintyHead, decode_predictions
def get_default_config() -> dict:
    """Return the default ViL-S tracker configuration.

    The values are intended to satisfy the project budgets:
    <=50M params, <=30ms latency, <=20 GFLOPs, <=500MB.

    A brand-new dict is assembled on every call, so callers may mutate the
    result without affecting later calls.
    """
    backbone_cfg = dict(
        dim=384,
        depth=24,
        patch_size=16,
        proj_factor=2.0,
        qkv_proj_blocksize=4,
        num_heads=4,
        conv_kernel=4,
        mlp_ratio=4.0,
        drop_path_rate=0.05,
        tmoe_blocks=2,
        num_experts=4,
    )
    film_cfg = dict(film_interval=6)          # FiLM temporal modulation
    head_cfg = dict(feat_size=16)             # prediction heads
    input_cfg = dict(template_size=128, search_size=256)
    uncertainty_cfg = dict(use_uncertainty=True)
    # Merge in the same order as the sections above so key order is stable.
    return {**backbone_cfg, **film_cfg, **head_cfg, **input_cfg, **uncertainty_cfg}
class ViLTracker(nn.Module):
    """Complete ViL-based single object tracker.

    Wires together a ``ViLBackbone``, a ``TemporalModulationManager`` (FiLM)
    and the ``CenterHead`` / optional ``UncertaintyHead`` built from ``config``.

    Target specs (ViL-S):
    - Parameters: ~35-40M (well under 50M limit)
    - GFLOPs: ~15-18 (under 20 GFLOPs)
    - Model size: ~140-160MB fp32, ~70-80MB fp16 (under 500MB)
    - Latency: ~20-25ms on GPU (under 30ms)
    """

    def __init__(self, config: dict = None):
        """Build the backbone, temporal modulation and prediction heads.

        Args:
            config: tracker configuration dict. A falsy value (``None`` or an
                empty dict) selects ``get_default_config()``. The dict is stored
                unchanged on ``self.config`` and read again in ``forward``
                (``search_size``, ``feat_size``).
        """
        super().__init__()
        config = config or get_default_config()
        self.config = config
        dim = config['dim']
        depth = config['depth']
        # Resolve the FiLM interval once so the backbone and the modulation
        # manager always agree. Previously the backbone fell back to 6 while
        # the manager indexed config['film_interval'] directly and raised
        # KeyError when the key was absent.
        film_interval = config.get('film_interval', 6)

        # Backbone (accepts the temporal modulation manager as a forward arg)
        self.backbone = ViLBackbone(
            dim=dim,
            depth=depth,
            patch_size=config['patch_size'],
            proj_factor=config['proj_factor'],
            qkv_proj_blocksize=config['qkv_proj_blocksize'],
            num_heads=config['num_heads'],
            conv_kernel=config['conv_kernel'],
            mlp_ratio=config['mlp_ratio'],
            drop_path_rate=config['drop_path_rate'],
            tmoe_blocks=config['tmoe_blocks'],
            num_experts=config['num_experts'],
            film_interval=film_interval,
        )
        # FiLM temporal modulation (applied BETWEEN backbone blocks)
        self.temporal_mod = TemporalModulationManager(
            dim=dim,
            num_blocks=depth,
            modulation_interval=film_interval,
        )
        # Prediction heads
        self.center_head = CenterHead(dim=dim, feat_size=config['feat_size'])
        if config.get('use_uncertainty', True):
            self.uncertainty_head = UncertaintyHead(dim=dim, feat_size=config['feat_size'])
        else:
            self.uncertainty_head = None

    def forward(
        self,
        template: torch.Tensor,
        searches: torch.Tensor,
        use_temporal: bool = False,
    ) -> dict:
        """
        Process template + K search frames through the full tracker.

        Args:
            template: (B, 3, 128, 128) template image
            searches: (B, K, 3, 256, 256) K consecutive search frames
                      OR (B, 3, 256, 256) single search frame (backward compat)
            use_temporal: whether to apply FiLM temporal modulation

        Returns:
            dict with per-frame predictions:
                heatmap: (B, K, 1, 16, 16) or (B, 1, 16, 16) if single
                size: (B, K, 2, 16, 16) or (B, 2, 16, 16)
                offset: (B, K, 2, 16, 16) or (B, 2, 16, 16)
                boxes: (B, K, 4) or (B, 4)
                scores: (B, K) or (B,)
                template_feat: (B, 64, D)
                search_feats: (B, K, 256, D) or (B, 256, D)
                search_feat: single-frame only, same tensor as search_feats
                    (legacy key kept for existing callers)
                log_variance: only when the uncertainty head is enabled
        """
        # A 4-D input is treated as a single search frame; anything else goes
        # down the multi-frame path.
        single_frame = (searches.ndim == 4)
        temporal_mgr = self.temporal_mod if use_temporal else None
        template_feat, search_feats = self.backbone(
            template, searches, temporal_mod_manager=temporal_mgr
        )
        # search_feats: (B, K, 256, D) for multi-frame, (B, 256, D) for single

        if single_frame:
            # Single frame path -- same computation as before.
            frame_preds, log_var = self._predict_frame(search_feats)
            output = {
                **frame_preds,
                'template_feat': template_feat,
                # Legacy key returned by earlier versions of this path.
                'search_feat': search_feats,
                # Documented key, consistent with the multi-frame path.
                'search_feats': search_feats,
            }
            if log_var is not None:
                output['log_variance'] = log_var
            return output

        # Multi-frame path: run the heads on each frame's search features.
        num_frames = search_feats.shape[1]
        per_frame = [self._predict_frame(search_feats[:, k]) for k in range(num_frames)]
        frame_dicts = [preds for preds, _ in per_frame]

        # Stack each per-frame prediction along a new frame axis (dim=1).
        output = {
            key: torch.stack([preds[key] for preds in frame_dicts], dim=1)
            for key in ('heatmap', 'size', 'offset', 'boxes', 'scores')
        }
        output['template_feat'] = template_feat   # (B, 64, D)
        output['search_feats'] = search_feats     # (B, K, 256, D)
        if self.uncertainty_head is not None:
            output['log_variance'] = torch.stack(
                [log_var for _, log_var in per_frame], dim=1
            )
        return output

    def _predict_frame(self, feat: torch.Tensor):
        """Apply the heads to one frame's search tokens and decode boxes.

        Returns:
            (preds, log_var): ``preds`` holds heatmap/size/offset from
            ``center_head`` plus boxes/scores from ``decode_predictions``;
            ``log_var`` is the uncertainty-head output, or ``None`` when the
            uncertainty head is disabled.
        """
        head_out = self.center_head(feat)
        boxes, scores = decode_predictions(
            head_out['heatmap'], head_out['size'], head_out['offset'],
            search_size=self.config['search_size'],
            feat_size=self.config['feat_size'],
        )
        preds = {
            'heatmap': head_out['heatmap'],
            'size': head_out['size'],
            'offset': head_out['offset'],
            'boxes': boxes,
            'scores': scores,
        }
        log_var = self.uncertainty_head(feat) if self.uncertainty_head is not None else None
        return preds, log_var

    def reset_temporal(self):
        """Reset temporal modulation state (for new tracking sequence)."""
        self.temporal_mod.reset()

    def freeze_backbone_shared_experts(self):
        """Freeze shared experts in TMoE blocks for Phase 2."""
        self.backbone.freeze_shared_experts()
def build_tracker(config: dict = None) -> ViLTracker:
    """Construct a ``ViLTracker``.

    Args:
        config: tracker configuration; when falsy (``None`` or ``{}``) the
            result of ``get_default_config()`` is used instead.
    """
    if not config:
        config = get_default_config()
    return ViLTracker(config)
|