# Source: vil_tracker/models/tracker.py
# Uploaded by omar-ah with huggingface_hub (commit b3b0529, verified; 5.37 kB).
"""
ViL Tracker: Full model combining backbone, FiLM modulation, and prediction heads.
Pipeline:
1. Template (128x128) + Search (256x256) → PatchEmbed → tokens
2. Concatenated tokens → ViL backbone (24 mLSTM blocks, bidirectional)
3. FiLM temporal modulation at intervals (conditioned on prev frame)
4. Search features → CenterHead → heatmap + size + offset
5. Optional: UncertaintyHead → log variance for adaptive weighting
"""
import torch
import torch.nn as nn
from .backbone import ViLBackbone
from .film_temporal import TemporalModulationManager
from .heads import CenterHead, UncertaintyHead, decode_predictions
def get_default_config() -> dict:
    """Return a freshly built dict with the default ViL-S tracker settings.

    The values are grouped by the component they configure and then merged
    into one flat dictionary (key order: backbone, FiLM, heads, inputs,
    uncertainty). A new dict is created on every call, so callers may
    mutate the result freely.

    Stated design budget for this configuration: ≤50M params, ≤30ms
    latency, ≤20 GFLOPs, ≤500MB.
    """
    backbone_settings = {
        'dim': 384,
        'depth': 24,
        'patch_size': 16,
        'proj_factor': 2.0,
        'qkv_proj_blocksize': 4,
        'num_heads': 4,
        'conv_kernel': 4,
        'mlp_ratio': 4.0,
        'drop_path_rate': 0.1,
        'tmoe_blocks': 2,
        'num_experts': 4,
    }
    # FiLM temporal modulation
    film_settings = {'film_interval': 6}
    # Prediction heads
    head_settings = {'feat_size': 16}
    # Input resolutions
    input_settings = {'template_size': 128, 'search_size': 256}
    # Uncertainty head toggle
    uncertainty_settings = {'use_uncertainty': True}

    return {
        **backbone_settings,
        **film_settings,
        **head_settings,
        **input_settings,
        **uncertainty_settings,
    }
class ViLTracker(nn.Module):
    """Complete ViL-based single object tracker.

    Wires together a ``ViLBackbone``, a ``TemporalModulationManager``
    (FiLM), a ``CenterHead`` and, optionally, an ``UncertaintyHead``.

    Target specs (ViL-S), as stated for the default configuration:
        - Parameters: ~35-40M (well under 50M limit)
        - GFLOPs: ~15-18 (under 20 GFLOPs)
        - Model size: ~140-160MB fp32, ~70-80MB fp16 (under 500MB)
        - Latency: ~20-25ms on GPU (under 30ms)
    """

    def __init__(self, config: dict = None):
        """Construct the backbone, temporal modulation manager and heads.

        Args:
            config: Optional settings. ``None`` or an empty mapping selects
                ``get_default_config()``. Any key that is absent falls back
                to its default value, so partial overrides such as
                ``{'depth': 12}`` are accepted. The merged result is stored
                as a new dict in ``self.config``; the caller's mapping is
                not modified.
        """
        super().__init__()
        # Overlay caller-provided values on the defaults. Previously a
        # partial config raised KeyError on the first missing key even
        # though 'use_uncertainty' was already treated as optional.
        config = {**get_default_config(), **(config or {})}
        self.config = config

        dim = config['dim']
        depth = config['depth']
        feat_size = config['feat_size']

        # Backbone
        self.backbone = ViLBackbone(
            dim=dim,
            depth=depth,
            patch_size=config['patch_size'],
            proj_factor=config['proj_factor'],
            qkv_proj_blocksize=config['qkv_proj_blocksize'],
            num_heads=config['num_heads'],
            conv_kernel=config['conv_kernel'],
            mlp_ratio=config['mlp_ratio'],
            drop_path_rate=config['drop_path_rate'],
            tmoe_blocks=config['tmoe_blocks'],
            num_experts=config['num_experts'],
        )

        # FiLM temporal modulation
        self.temporal_mod = TemporalModulationManager(
            dim=dim,
            num_blocks=depth,
            modulation_interval=config['film_interval'],
        )

        # Prediction heads
        self.center_head = CenterHead(dim=dim, feat_size=feat_size)
        if config['use_uncertainty']:
            self.uncertainty_head = UncertaintyHead(dim=dim, feat_size=feat_size)
        else:
            # Kept as an explicit None so forward() can test for it.
            self.uncertainty_head = None

    def forward(
        self,
        template: torch.Tensor,
        search: torch.Tensor,
        use_temporal: bool = False,
    ) -> dict:
        """Run the tracker on one template/search pair.

        Args:
            template: (B, 3, 128, 128) template image.
            search: (B, 3, 256, 256) search region.
            use_temporal: Whether to apply FiLM temporal modulation to the
                search features and refresh the temporal context.

        Returns:
            dict with keys ``heatmap``, ``size``, ``offset``, ``boxes``,
            ``scores``, ``template_feat`` and ``search_feat``; additionally
            ``log_variance`` when an uncertainty head is configured.
        """
        # Backbone forward
        template_feat, search_feat = self.backbone(template, search)

        # Optional FiLM temporal modulation on search features
        if use_temporal:
            for block_idx in range(self.backbone.depth):
                if self.temporal_mod.should_modulate(block_idx):
                    search_feat = self.temporal_mod.modulate(search_feat, block_idx)
            # Update temporal context for next frame.
            # NOTE(review): the original indentation was lost; this call is
            # assumed to belong to the ``use_temporal`` branch - confirm.
            self.temporal_mod.update_temporal_context(search_feat)

        # Prediction heads
        preds = self.center_head(search_feat)

        # Decode to boxes
        boxes, scores = decode_predictions(
            preds['heatmap'],
            preds['size'],
            preds['offset'],
            search_size=self.config['search_size'],
            feat_size=self.config['feat_size'],
        )

        output = {
            'heatmap': preds['heatmap'],
            'size': preds['size'],
            'offset': preds['offset'],
            'boxes': boxes,
            'scores': scores,
            'template_feat': template_feat,
            'search_feat': search_feat,
        }

        # Uncertainty prediction
        if self.uncertainty_head is not None:
            output['log_variance'] = self.uncertainty_head(search_feat)

        return output

    def reset_temporal(self):
        """Reset temporal modulation state (for new tracking sequence)."""
        self.temporal_mod.reset()

    def freeze_backbone_shared_experts(self):
        """Freeze shared experts in TMoE blocks for Phase 2."""
        self.backbone.freeze_shared_experts()
def build_tracker(config: dict = None) -> ViLTracker:
    """Create a ``ViLTracker`` from ``config``.

    A missing or empty ``config`` is replaced by ``get_default_config()``
    before the tracker is constructed.
    """
    if not config:
        config = get_default_config()
    return ViLTracker(config)