"""
ViL Tracker: Full model combining backbone, FiLM modulation, and prediction heads.

Pipeline:
1. Template (128x128) + Search (256x256) → PatchEmbed → tokens
2. Concatenated tokens → ViL backbone (24 mLSTM blocks, bidirectional)
3. FiLM temporal modulation integrated BETWEEN backbone blocks
4. Search features → CenterHead → heatmap + size + offset
5. Optional: UncertaintyHead → log variance for adaptive weighting
"""

from typing import Optional

import torch
import torch.nn as nn

from .backbone import ViLBackbone
from .film_temporal import TemporalModulationManager
from .heads import CenterHead, UncertaintyHead, decode_predictions

# Fallback FiLM interval used when a config omits 'film_interval'.
_DEFAULT_FILM_INTERVAL = 6

# Per-frame prediction entries that every forward pass returns.
_FRAME_KEYS = ('heatmap', 'size', 'offset', 'boxes', 'scores')


def get_default_config() -> dict:
    """Default ViL-S tracker configuration meeting all constraints.

    Constraints: ≤50M params, ≤30ms latency, ≤20 GFLOPs, ≤500MB

    Returns:
        A fresh dict on every call, so callers may mutate it freely.
    """
    return {
        # Backbone
        'dim': 384,
        'depth': 24,
        'patch_size': 16,
        'proj_factor': 2.0,
        'qkv_proj_blocksize': 4,
        'num_heads': 4,
        'conv_kernel': 4,
        'mlp_ratio': 4.0,
        'drop_path_rate': 0.05,
        'tmoe_blocks': 2,
        'num_experts': 4,
        # FiLM temporal modulation
        'film_interval': _DEFAULT_FILM_INTERVAL,
        # Heads
        'feat_size': 16,
        # Inputs
        'template_size': 128,
        'search_size': 256,
        # Uncertainty
        'use_uncertainty': True,
    }


class ViLTracker(nn.Module):
    """Complete ViL-based single object tracker.

    Target specs (ViL-S):
    - Parameters: ~35-40M (well under 50M limit)
    - GFLOPs: ~15-18 (under 20 GFLOPs)
    - Model size: ~140-160MB fp32, ~70-80MB fp16 (under 500MB)
    - Latency: ~20-25ms on GPU (under 30ms)
    """

    def __init__(self, config: Optional[dict] = None):
        """Build the backbone, temporal modulation manager and heads.

        Args:
            config: tracker configuration dict (see ``get_default_config``).
                ``None`` or an empty dict selects the default configuration.
                ``'film_interval'`` and ``'use_uncertainty'`` are optional
                and default to 6 and ``True`` respectively.
        """
        super().__init__()
        config = config or get_default_config()
        self.config = config

        dim = config['dim']
        depth = config['depth']
        feat_size = config['feat_size']
        # Resolve once so the backbone and the FiLM manager always agree and
        # a config without 'film_interval' does not raise KeyError half-way
        # through construction.
        film_interval = config.get('film_interval', _DEFAULT_FILM_INTERVAL)

        # Backbone (accepts temporal_mod_manager as forward arg)
        self.backbone = ViLBackbone(
            dim=dim,
            depth=depth,
            patch_size=config['patch_size'],
            proj_factor=config['proj_factor'],
            qkv_proj_blocksize=config['qkv_proj_blocksize'],
            num_heads=config['num_heads'],
            conv_kernel=config['conv_kernel'],
            mlp_ratio=config['mlp_ratio'],
            drop_path_rate=config['drop_path_rate'],
            tmoe_blocks=config['tmoe_blocks'],
            num_experts=config['num_experts'],
            film_interval=film_interval,
        )

        # FiLM temporal modulation (applied BETWEEN backbone blocks)
        self.temporal_mod = TemporalModulationManager(
            dim=dim,
            num_blocks=depth,
            modulation_interval=film_interval,
        )

        # Prediction heads
        self.center_head = CenterHead(dim=dim, feat_size=feat_size)
        if config.get('use_uncertainty', True):
            self.uncertainty_head = UncertaintyHead(dim=dim, feat_size=feat_size)
        else:
            self.uncertainty_head = None

    def _predict_frame(self, search_feat: torch.Tensor) -> dict:
        """Run the prediction heads on the search features of one frame.

        Args:
            search_feat: search-region features for a single frame, as
                produced by the backbone.

        Returns:
            dict with ``heatmap``, ``size``, ``offset``, ``boxes`` and
            ``scores``; additionally ``log_variance`` when the uncertainty
            head is enabled.
        """
        preds = self.center_head(search_feat)
        boxes, scores = decode_predictions(
            preds['heatmap'],
            preds['size'],
            preds['offset'],
            search_size=self.config['search_size'],
            feat_size=self.config['feat_size'],
        )
        frame_out = {
            'heatmap': preds['heatmap'],
            'size': preds['size'],
            'offset': preds['offset'],
            'boxes': boxes,
            'scores': scores,
        }
        if self.uncertainty_head is not None:
            frame_out['log_variance'] = self.uncertainty_head(search_feat)
        return frame_out

    def forward(
        self,
        template: torch.Tensor,
        searches: torch.Tensor,
        use_temporal: bool = False,
    ) -> dict:
        """Process template + K search frames through the full tracker.

        Args:
            template: (B, 3, 128, 128) template image
            searches: (B, K, 3, 256, 256) K consecutive search frames
                OR (B, 3, 256, 256) single search frame (backward compat)
            use_temporal: whether to apply FiLM temporal modulation

        Returns:
            dict with per-frame predictions:
                heatmap: (B, K, 1, 16, 16) or (B, 1, 16, 16) if single
                size: (B, K, 2, 16, 16) or (B, 2, 16, 16)
                offset: (B, K, 2, 16, 16) or (B, 2, 16, 16)
                boxes: (B, K, 4) or (B, 4)
                scores: (B, K) or (B,)
                template_feat: (B, 64, D)
                search_feats: (B, K, 256, D) or (B, 256, D)
                search_feat: single-frame input only; same tensor as
                    ``search_feats`` (legacy key kept for older callers)
                log_variance: only when the uncertainty head is enabled;
                    stacked along dim 1 for multi-frame input
        """
        # A 4-D `searches` tensor is the legacy single-frame call.
        single_frame = searches.ndim == 4
        temporal_mgr = self.temporal_mod if use_temporal else None

        template_feat, search_feats = self.backbone(
            template, searches, temporal_mod_manager=temporal_mgr
        )

        if single_frame:
            frame = self._predict_frame(search_feats)
            output = {key: frame[key] for key in _FRAME_KEYS}
            output['template_feat'] = template_feat
            # Legacy key first, then the documented key shared with the
            # multi-frame path so callers can use one name for both.
            output['search_feat'] = search_feats
            output['search_feats'] = search_feats
            if 'log_variance' in frame:
                output['log_variance'] = frame['log_variance']
            return output

        # Multi-frame path: run the heads on each frame's search features and
        # stack the per-frame results along a new frame dimension (dim=1).
        num_frames = search_feats.shape[1]
        frames = [self._predict_frame(search_feats[:, k]) for k in range(num_frames)]

        output = {
            key: torch.stack([frame[key] for frame in frames], dim=1)
            for key in _FRAME_KEYS
        }
        output['template_feat'] = template_feat
        output['search_feats'] = search_feats
        if self.uncertainty_head is not None and frames:
            output['log_variance'] = torch.stack(
                [frame['log_variance'] for frame in frames], dim=1
            )
        return output

    def reset_temporal(self):
        """Reset temporal modulation state (for new tracking sequence)."""
        self.temporal_mod.reset()

    def freeze_backbone_shared_experts(self):
        """Freeze shared experts in TMoE blocks for Phase 2."""
        self.backbone.freeze_shared_experts()


def build_tracker(config: Optional[dict] = None) -> ViLTracker:
    """Build a ViL tracker with given or default config.

    Args:
        config: tracker configuration dict; ``None`` or an empty dict
            selects ``get_default_config()``.
    """
    # ViLTracker applies the default-config fallback itself.
    return ViLTracker(config)