"""
ViL Tracker: Full model combining backbone, FiLM modulation, and prediction heads.
Pipeline:
1. Template (128x128) + Search (256x256) → PatchEmbed → tokens
2. Concatenated tokens → ViL backbone (24 mLSTM blocks, bidirectional)
3. FiLM temporal modulation integrated BETWEEN backbone blocks
4. Search features → CenterHead → heatmap + size + offset
5. Optional: UncertaintyHead → log variance for adaptive weighting
"""
from typing import Optional

import torch
import torch.nn as nn

from .backbone import ViLBackbone
from .film_temporal import TemporalModulationManager
from .heads import CenterHead, UncertaintyHead, decode_predictions


def get_default_config() -> dict:
    """Default ViL-S tracker configuration meeting all deployment constraints.

    Constraints: ≤50M params, ≤30ms latency, ≤20 GFLOPs, ≤500MB model size.
    """
return {
# Backbone
'dim': 384,
'depth': 24,
'patch_size': 16,
'proj_factor': 2.0,
'qkv_proj_blocksize': 4,
'num_heads': 4,
'conv_kernel': 4,
'mlp_ratio': 4.0,
'drop_path_rate': 0.05,
'tmoe_blocks': 2,
'num_experts': 4,
# FiLM temporal modulation
'film_interval': 6,
# Heads
'feat_size': 16,
# Inputs
'template_size': 128,
'search_size': 256,
# Uncertainty
'use_uncertainty': True,
}
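

# Illustrative sanity check (not used by the model): the default sizes above
# determine the token counts documented in ViLTracker.forward. With
# patch_size=16, a 128x128 template yields (128 // 16) ** 2 = 64 tokens and a
# 256x256 search yields (256 // 16) ** 2 = 256 tokens, matching the 16x16 head
# feature map (feat_size ** 2 == 256). The helper name below is hypothetical,
# added only to make the arithmetic explicit.
def _expected_token_counts(config: dict) -> tuple:
    """Sketch: (template_tokens, search_tokens) implied by a config dict."""
    patches = config['patch_size']
    template_tokens = (config['template_size'] // patches) ** 2
    search_tokens = (config['search_size'] // patches) ** 2
    return template_tokens, search_tokens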


class ViLTracker(nn.Module):
    """Complete ViL-based single-object tracker.

    Target specs (ViL-S):
        - Parameters: ~35-40M (well under the 50M limit)
        - GFLOPs: ~15-18 (under the 20 GFLOP budget)
        - Model size: ~140-160MB fp32, ~70-80MB fp16 (under the 500MB limit)
        - Latency: ~20-25ms on GPU (under the 30ms budget)
    """

    def __init__(self, config: Optional[dict] = None):
super().__init__()
config = config or get_default_config()
self.config = config
dim = config['dim']
depth = config['depth']
        # Backbone (accepts a TemporalModulationManager as a forward-pass argument)
self.backbone = ViLBackbone(
dim=dim,
depth=depth,
patch_size=config['patch_size'],
proj_factor=config['proj_factor'],
qkv_proj_blocksize=config['qkv_proj_blocksize'],
num_heads=config['num_heads'],
conv_kernel=config['conv_kernel'],
mlp_ratio=config['mlp_ratio'],
drop_path_rate=config['drop_path_rate'],
tmoe_blocks=config['tmoe_blocks'],
num_experts=config['num_experts'],
film_interval=config.get('film_interval', 6),
)
# FiLM temporal modulation (applied BETWEEN backbone blocks)
self.temporal_mod = TemporalModulationManager(
dim=dim,
num_blocks=depth,
            modulation_interval=config.get('film_interval', 6),
)
# Prediction heads
self.center_head = CenterHead(dim=dim, feat_size=config['feat_size'])
if config.get('use_uncertainty', True):
self.uncertainty_head = UncertaintyHead(dim=dim, feat_size=config['feat_size'])
else:
self.uncertainty_head = None

    def forward(
self,
template: torch.Tensor,
searches: torch.Tensor,
use_temporal: bool = False,
) -> dict:
"""
Process template + K search frames through the full tracker.
Args:
template: (B, 3, 128, 128) template image
searches: (B, K, 3, 256, 256) K consecutive search frames
OR (B, 3, 256, 256) single search frame (backward compat)
use_temporal: whether to apply FiLM temporal modulation
        Returns:
            dict with per-frame predictions:
                heatmap: (B, K, 1, 16, 16), or (B, 1, 16, 16) if single frame
                size: (B, K, 2, 16, 16) or (B, 2, 16, 16)
                offset: (B, K, 2, 16, 16) or (B, 2, 16, 16)
                boxes: (B, K, 4) or (B, 4)
                scores: (B, K) or (B,)
                template_feat: (B, 64, D)
                search_feats: (B, K, 256, D); in single-frame mode the key
                    is search_feat with shape (B, 256, D)
                log_variance: (B, K, 1, 16, 16) or (B, 1, 16, 16), present
                    only when the uncertainty head is enabled
        """
single_frame = (searches.ndim == 4)
temporal_mgr = self.temporal_mod if use_temporal else None
template_feat, search_feats = self.backbone(template, searches, temporal_mod_manager=temporal_mgr)
# search_feats: (B, K, 256, D) for multi-frame, (B, 256, D) for single
if single_frame:
            # Single-frame path (backward compatible with pair training)
preds = self.center_head(search_feats)
boxes, scores = decode_predictions(
preds['heatmap'], preds['size'], preds['offset'],
search_size=self.config['search_size'],
feat_size=self.config['feat_size'],
)
output = {
'heatmap': preds['heatmap'],
'size': preds['size'],
'offset': preds['offset'],
'boxes': boxes,
'scores': scores,
'template_feat': template_feat,
'search_feat': search_feats,
}
if self.uncertainty_head is not None:
output['log_variance'] = self.uncertainty_head(search_feats)
return output
# Multi-frame path: run head on each frame's search features
B, K = search_feats.shape[:2]
all_heatmaps, all_sizes, all_offsets = [], [], []
all_boxes, all_scores = [], []
all_log_var = []
for k in range(K):
s_feat_k = search_feats[:, k] # (B, 256, D)
preds_k = self.center_head(s_feat_k)
boxes_k, scores_k = decode_predictions(
preds_k['heatmap'], preds_k['size'], preds_k['offset'],
search_size=self.config['search_size'],
feat_size=self.config['feat_size'],
)
all_heatmaps.append(preds_k['heatmap'])
all_sizes.append(preds_k['size'])
all_offsets.append(preds_k['offset'])
all_boxes.append(boxes_k)
all_scores.append(scores_k)
if self.uncertainty_head is not None:
all_log_var.append(self.uncertainty_head(s_feat_k))
output = {
'heatmap': torch.stack(all_heatmaps, dim=1), # (B, K, 1, 16, 16)
'size': torch.stack(all_sizes, dim=1), # (B, K, 2, 16, 16)
'offset': torch.stack(all_offsets, dim=1), # (B, K, 2, 16, 16)
'boxes': torch.stack(all_boxes, dim=1), # (B, K, 4)
'scores': torch.stack(all_scores, dim=1), # (B, K)
'template_feat': template_feat, # (B, 64, D)
'search_feats': search_feats, # (B, K, 256, D)
}
if self.uncertainty_head is not None and all_log_var:
output['log_variance'] = torch.stack(all_log_var, dim=1) # (B, K, 1, 16, 16)
return output

    def reset_temporal(self):
"""Reset temporal modulation state (for new tracking sequence)."""
self.temporal_mod.reset()

    def freeze_backbone_shared_experts(self):
"""Freeze shared experts in TMoE blocks for Phase 2."""
self.backbone.freeze_shared_experts()


def build_tracker(config: Optional[dict] = None) -> ViLTracker:
"""Build a ViL tracker with given or default config."""
return ViLTracker(config or get_default_config())
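

if __name__ == '__main__':
    # Minimal smoke test: a sketch assuming the sibling backbone, film_temporal,
    # and heads modules are importable. Shapes follow the forward() docstring.
    model = build_tracker()
    model.eval()
    model.reset_temporal()  # start of a new tracking sequence
    template = torch.randn(2, 3, 128, 128)     # (B, 3, 128, 128)
    searches = torch.randn(2, 3, 3, 256, 256)  # (B, K, 3, 256, 256), K=3
    with torch.no_grad():
        out = model(template, searches, use_temporal=True)
    print(out['boxes'].shape)   # expected: torch.Size([2, 3, 4])
    print(out['scores'].shape)  # expected: torch.Size([2, 3])

    # Configs can be overridden by copying the defaults, e.g. to disable the
    # uncertainty head:
    light = build_tracker({**get_default_config(), 'use_uncertainty': False})
    assert light.uncertainty_head is None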