| """ |
| ViL Tracker: Full model combining backbone, FiLM modulation, and prediction heads. |
| |
| Pipeline: |
| 1. Template (128x128) + Search (256x256) → PatchEmbed → tokens |
| 2. Concatenated tokens → ViL backbone (24 mLSTM blocks, bidirectional) |
| 3. FiLM temporal modulation integrated BETWEEN backbone blocks |
| 4. Search features → CenterHead → heatmap + size + offset |
| 5. Optional: UncertaintyHead → log variance for adaptive weighting |
| """ |
|
|
from typing import Optional

import torch
import torch.nn as nn

from .backbone import ViLBackbone
from .film_temporal import TemporalModulationManager
from .heads import CenterHead, UncertaintyHead, decode_predictions
|
|
|
|
def get_default_config() -> dict:
    """Default ViL-S tracker configuration meeting all constraints.

    Constraints: ≤50M params, ≤30ms latency, ≤20 GFLOPs, ≤500MB
    """
    return dict(
        # --- Backbone (ViL-S) ---
        dim=384,
        depth=24,
        patch_size=16,
        proj_factor=2.0,
        qkv_proj_blocksize=4,
        num_heads=4,
        conv_kernel=4,
        mlp_ratio=4.0,
        drop_path_rate=0.05,
        tmoe_blocks=2,
        num_experts=4,
        # --- FiLM temporal modulation ---
        film_interval=6,
        # --- Head / feature map ---
        feat_size=16,
        # --- Input resolutions ---
        template_size=128,
        search_size=256,
        # --- Optional uncertainty head ---
        use_uncertainty=True,
    )
|
|
|
|
class ViLTracker(nn.Module):
    """Complete ViL-based single object tracker.

    Pipeline: template + search frames -> shared ViL backbone (with optional
    FiLM temporal modulation between blocks) -> center head per search frame
    -> decoded boxes/scores, plus an optional uncertainty (log-variance) head.

    Target specs (ViL-S):
    - Parameters: ~35-40M (well under 50M limit)
    - GFLOPs: ~15-18 (under 20 GFLOPs)
    - Model size: ~140-160MB fp32, ~70-80MB fp16 (under 500MB)
    - Latency: ~20-25ms on GPU (under 30ms)
    """
    def __init__(self, config: Optional[dict] = None):
        """Build backbone, temporal modulation manager, and heads.

        Args:
            config: tracker configuration dict; falls back to
                get_default_config() when None (or empty).
        """
        super().__init__()
        config = config or get_default_config()
        self.config = config

        dim = config['dim']
        depth = config['depth']
        # Read once with a default so partial configs that omit the key work;
        # backbone and modulation manager must agree on the same interval.
        # (Previously one read used .get and the other a raw [] lookup, so a
        # missing 'film_interval' raised KeyError only in __init__'s second use.)
        film_interval = config.get('film_interval', 6)

        # Shared backbone over template + search tokens (24 mLSTM blocks).
        self.backbone = ViLBackbone(
            dim=dim,
            depth=depth,
            patch_size=config['patch_size'],
            proj_factor=config['proj_factor'],
            qkv_proj_blocksize=config['qkv_proj_blocksize'],
            num_heads=config['num_heads'],
            conv_kernel=config['conv_kernel'],
            mlp_ratio=config['mlp_ratio'],
            drop_path_rate=config['drop_path_rate'],
            tmoe_blocks=config['tmoe_blocks'],
            num_experts=config['num_experts'],
            film_interval=film_interval,
        )

        # FiLM temporal modulation applied between backbone blocks.
        self.temporal_mod = TemporalModulationManager(
            dim=dim,
            num_blocks=depth,
            modulation_interval=film_interval,
        )

        # Center-based prediction head: heatmap + size + offset maps.
        self.center_head = CenterHead(dim=dim, feat_size=config['feat_size'])

        # Optional log-variance head for adaptive loss weighting.
        if config.get('use_uncertainty', True):
            self.uncertainty_head = UncertaintyHead(dim=dim, feat_size=config['feat_size'])
        else:
            self.uncertainty_head = None

    def _predict_frame(self, search_feat: torch.Tensor) -> dict:
        """Run heads + box decoding for a single search frame's features.

        Args:
            search_feat: backbone features for one search frame
                (presumably (B, 256, D) tokens — TODO confirm against CenterHead).
        Returns:
            dict with 'heatmap', 'size', 'offset', 'boxes', 'scores' and,
            when the uncertainty head is enabled, 'log_variance'.
        """
        preds = self.center_head(search_feat)
        boxes, scores = decode_predictions(
            preds['heatmap'], preds['size'], preds['offset'],
            search_size=self.config['search_size'],
            feat_size=self.config['feat_size'],
        )
        out = {
            'heatmap': preds['heatmap'],
            'size': preds['size'],
            'offset': preds['offset'],
            'boxes': boxes,
            'scores': scores,
        }
        if self.uncertainty_head is not None:
            out['log_variance'] = self.uncertainty_head(search_feat)
        return out

    def forward(
        self,
        template: torch.Tensor,
        searches: torch.Tensor,
        use_temporal: bool = False,
    ) -> dict:
        """
        Process template + K search frames through the full tracker.

        Args:
            template: (B, 3, 128, 128) template image
            searches: (B, K, 3, 256, 256) K consecutive search frames
                OR (B, 3, 256, 256) single search frame (backward compat)
            use_temporal: whether to apply FiLM temporal modulation
        Returns:
            dict with per-frame predictions:
                heatmap: (B, K, 1, 16, 16) or (B, 1, 16, 16) if single
                size: (B, K, 2, 16, 16) or (B, 2, 16, 16)
                offset: (B, K, 2, 16, 16) or (B, 2, 16, 16)
                boxes: (B, K, 4) or (B, 4)
                scores: (B, K) or (B,)
                template_feat: (B, 64, D)
                search_feats: (B, K, 256, D) or (B, 256, D)
                    (single-frame output also keeps the legacy key
                    'search_feat' for backward compatibility)
                log_variance: present only when the uncertainty head is enabled
        """
        # A 4-D search tensor means one frame without the K axis.
        single_frame = (searches.ndim == 4)

        temporal_mgr = self.temporal_mod if use_temporal else None
        template_feat, search_feats = self.backbone(
            template, searches, temporal_mod_manager=temporal_mgr
        )

        if single_frame:
            output = self._predict_frame(search_feats)
            output['template_feat'] = template_feat
            # Fix: previously only the legacy 'search_feat' key was emitted
            # here, contradicting the docstring and the multi-frame path.
            # Emit both so old and new callers keep working.
            output['search_feat'] = search_feats
            output['search_feats'] = search_feats
            return output

        # Multi-frame: run heads independently per frame, then stack on dim=1.
        K = search_feats.shape[1]
        per_frame = [self._predict_frame(search_feats[:, k]) for k in range(K)]

        output = {
            key: torch.stack([f[key] for f in per_frame], dim=1)
            for key in ('heatmap', 'size', 'offset', 'boxes', 'scores')
        }
        output['template_feat'] = template_feat
        output['search_feats'] = search_feats

        if self.uncertainty_head is not None and per_frame:
            output['log_variance'] = torch.stack(
                [f['log_variance'] for f in per_frame], dim=1
            )

        return output

    def reset_temporal(self):
        """Reset temporal modulation state (for new tracking sequence)."""
        self.temporal_mod.reset()

    def freeze_backbone_shared_experts(self):
        """Freeze shared experts in TMoE blocks for Phase 2."""
        self.backbone.freeze_shared_experts()
|
|
|
|
def build_tracker(config: dict = None) -> ViLTracker:
    """Factory: construct a ViLTracker, falling back to the default config."""
    cfg = config or get_default_config()
    return ViLTracker(cfg)
|
|