"""ViL Tracker: full model combining backbone, FiLM modulation, and prediction heads.

Pipeline:
    1. Template (128x128) + Search (256x256) -> PatchEmbed -> tokens
    2. Concatenated tokens -> ViL backbone (24 mLSTM blocks, bidirectional)
    3. FiLM temporal modulation at intervals (conditioned on previous frame)
    4. Search features -> CenterHead -> heatmap + size + offset
    5. Optional: UncertaintyHead -> log variance for adaptive weighting
"""

import torch
import torch.nn as nn

from .backbone import ViLBackbone
from .film_temporal import TemporalModulationManager
from .heads import CenterHead, UncertaintyHead, decode_predictions


def get_default_config() -> dict:
    """Return a fresh default ViL-S tracker configuration.

    Design targets: <=50M params, <=30ms latency, <=20 GFLOPs, <=500MB.

    A new dict is built on every call, so callers may mutate the result freely.
    """
    return {
        # Backbone
        'dim': 384,
        'depth': 24,
        'patch_size': 16,
        'proj_factor': 2.0,
        'qkv_proj_blocksize': 4,
        'num_heads': 4,
        'conv_kernel': 4,
        'mlp_ratio': 4.0,
        'drop_path_rate': 0.1,
        'tmoe_blocks': 2,
        'num_experts': 4,
        # FiLM temporal modulation
        'film_interval': 6,
        # Heads (spatial size of the search feature map; 256 // 16 == 16)
        'feat_size': 16,
        # Inputs
        'template_size': 128,
        'search_size': 256,
        # Uncertainty
        'use_uncertainty': True,
    }


class ViLTracker(nn.Module):
    """Complete ViL-based single object tracker.

    Target specs (ViL-S, design estimates):
        - Parameters: ~35-40M (under the 50M limit)
        - GFLOPs: ~15-18 (under 20 GFLOPs)
        - Model size: ~140-160MB fp32, ~70-80MB fp16 (under 500MB)
        - Latency: ~20-25ms on GPU (under 30ms)
    """

    def __init__(self, config: dict = None):
        """Build the backbone, temporal modulation manager and heads.

        Args:
            config: optional (possibly partial) configuration dict. Keys that
                are present override the defaults from get_default_config().
                Missing keys fall back to those defaults. The caller's dict is
                copied, not stored, so later mutation by the caller does not
                affect the model.
        """
        super().__init__()
        user_config = dict(config) if config else {}
        config = {**get_default_config(), **user_config}
        if 'feat_size' not in user_config:
            # Keep the head grid in sync with the search token grid when the
            # caller overrides search_size/patch_size but not feat_size.
            # Assumes a non-overlapping patch embedding (stride == patch_size),
            # which matches the defaults (256 // 16 == 16). TODO confirm
            # against ViLBackbone.
            config['feat_size'] = config['search_size'] // config['patch_size']
        self.config = config

        dim = config['dim']
        depth = config['depth']

        # Backbone
        self.backbone = ViLBackbone(
            dim=dim,
            depth=depth,
            patch_size=config['patch_size'],
            proj_factor=config['proj_factor'],
            qkv_proj_blocksize=config['qkv_proj_blocksize'],
            num_heads=config['num_heads'],
            conv_kernel=config['conv_kernel'],
            mlp_ratio=config['mlp_ratio'],
            drop_path_rate=config['drop_path_rate'],
            tmoe_blocks=config['tmoe_blocks'],
            num_experts=config['num_experts'],
        )

        # FiLM temporal modulation
        self.temporal_mod = TemporalModulationManager(
            dim=dim,
            num_blocks=depth,
            modulation_interval=config['film_interval'],
        )

        # Prediction heads
        self.center_head = CenterHead(dim=dim, feat_size=config['feat_size'])
        if config.get('use_uncertainty', True):
            self.uncertainty_head = UncertaintyHead(dim=dim, feat_size=config['feat_size'])
        else:
            self.uncertainty_head = None

    def forward(
        self,
        template: torch.Tensor,
        search: torch.Tensor,
        use_temporal: bool = False,
    ) -> dict:
        """Run the tracker on one template/search pair.

        Args:
            template: (B, 3, 128, 128) template image. The spatial size
                presumably should match config['template_size'].
            search: (B, 3, 256, 256) search region. The spatial size
                presumably should match config['search_size'].
            use_temporal: whether to apply FiLM temporal modulation to the
                search features and update the stored temporal context.

        Returns:
            dict with keys:
                'heatmap', 'size', 'offset': raw CenterHead outputs;
                'boxes', 'scores': outputs of decode_predictions;
                'template_feat', 'search_feat': backbone features
                    (search_feat is the modulated version when use_temporal);
                'log_variance': only when the uncertainty head is enabled.
        """
        # Backbone forward
        template_feat, search_feat = self.backbone(template, search)

        if use_temporal:
            # NOTE(review): FiLM is applied to the backbone *output*, once
            # per block index selected by should_modulate(), not between
            # backbone blocks as the module docstring's pipeline suggests.
            # Confirm this is intended.
            for block_idx in range(self.backbone.depth):
                if self.temporal_mod.should_modulate(block_idx):
                    search_feat = self.temporal_mod.modulate(search_feat, block_idx)
            # Store this frame's (modulated) search features as conditioning
            # for the next frame.
            # NOTE(review): the original source was whitespace-collapsed, so
            # it is ambiguous whether this update was also meant to run when
            # use_temporal is False. Verify against the tracking loop.
            self.temporal_mod.update_temporal_context(search_feat)

        # Prediction heads
        preds = self.center_head(search_feat)

        # Decode to boxes in search-image coordinates
        boxes, scores = decode_predictions(
            preds['heatmap'],
            preds['size'],
            preds['offset'],
            search_size=self.config['search_size'],
            feat_size=self.config['feat_size'],
        )

        output = {
            'heatmap': preds['heatmap'],
            'size': preds['size'],
            'offset': preds['offset'],
            'boxes': boxes,
            'scores': scores,
            'template_feat': template_feat,
            'search_feat': search_feat,
        }

        # Uncertainty prediction
        if self.uncertainty_head is not None:
            output['log_variance'] = self.uncertainty_head(search_feat)

        return output

    def reset_temporal(self):
        """Reset temporal modulation state (call at the start of a new sequence)."""
        self.temporal_mod.reset()

    def freeze_backbone_shared_experts(self):
        """Freeze shared experts in TMoE blocks (used for Phase 2 training)."""
        self.backbone.freeze_shared_experts()


def build_tracker(config: dict = None) -> ViLTracker:
    """Build a ViL tracker from a (possibly partial) config, or the defaults.

    Default handling and merging are done by ViLTracker.__init__.
    """
    return ViLTracker(config)