"""FiLM (Feature-wise Linear Modulation) temporal module.

Replaces DTPTrack's temporal prompt tokens (which are broken for
bidirectional mLSTM scanning) with channel-wise affine modulation
conditioned on temporal context.

Architecture:
    1. TemporalReliabilityCalibrator: learns reliability weights for
       temporal features.
    2. FiLMTemporalModulation: γ(t)·x + β(t) modulation per block.
    3. TemporalModulationManager: manages FiLM layers across all blocks.

Reference:
    Perez et al., "FiLM: Visual Reasoning with a General Conditioning Layer"
"""

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class TemporalReliabilityCalibrator(nn.Module):
    """Learn a per-token reliability score for temporal context.

    Takes temporal features (e.g. from the previous frame's mLSTM states)
    and produces a scalar reliability weight in [0, 1] for each token.
    """

    def __init__(self, dim: int = 384):
        super().__init__()
        # Keep at least one hidden unit: for dim < 4, dim // 4 == 0 would
        # build a zero-width layer whose output ignores the input entirely.
        hidden = max(dim // 4, 1)
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, temporal_feat: torch.Tensor) -> torch.Tensor:
        """
        Args:
            temporal_feat: (B, S, D) temporal context features.

        Returns:
            reliability: (B, S, 1) reliability weights in [0, 1].
        """
        return self.net(temporal_feat)


class FiLMTemporalModulation(nn.Module):
    """Feature-wise Linear Modulation conditioned on temporal context.

    Computes ``output = γ(temporal) · x + β(temporal)``, where γ and β are
    produced from the temporal features by small two-layer MLPs.
    """

    def __init__(self, dim: int = 384):
        super().__init__()
        hidden = max(dim // 4, 1)  # see TemporalReliabilityCalibrator
        # Generate scale (γ) and shift (β) from temporal context.
        self.gamma_net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )
        self.beta_net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )
        # Zero output weights with γ-bias 1 and β-bias 0 make the layer an
        # exact identity at init (γ = 1, β = 0 for any context).
        nn.init.zeros_(self.gamma_net[-1].weight)
        nn.init.ones_(self.gamma_net[-1].bias)
        nn.init.zeros_(self.beta_net[-1].weight)
        nn.init.zeros_(self.beta_net[-1].bias)

    def forward(
        self,
        x: torch.Tensor,
        temporal_context: torch.Tensor,
        reliability: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (B, S, D) input features from the current frame.
            temporal_context: (B, S, D) temporal features (previous frame,
                pooled states, etc.).
            reliability: optional (B, S, 1) reliability weights.

        Returns:
            (B, S, D) modulated features.
        """
        gamma = self.gamma_net(temporal_context)  # (B, S, D)
        beta = self.beta_net(temporal_context)  # (B, S, D)

        if reliability is not None:
            # Blend toward identity: reliability 1 keeps the full (γ, β);
            # reliability 0 gives (1, 0), i.e. x passes through unchanged.
            gamma = reliability * gamma + (1 - reliability)
            beta = reliability * beta

        return gamma * x + beta


class TemporalModulationManager(nn.Module):
    """Manage FiLM modulation across multiple backbone blocks.

    Applies FiLM modulation after every ``modulation_interval``-th block,
    conditioned on an exponential moving average of the features passed to
    :meth:`update_temporal_context` on previous frames.
    """

    def __init__(
        self,
        dim: int = 384,
        num_blocks: int = 24,
        modulation_interval: int = 6,
        context_momentum: float = 0.7,
    ):
        """
        Args:
            dim: feature (channel) dimension D.
            num_blocks: number of backbone blocks.
            modulation_interval: apply FiLM after every N-th block.
            context_momentum: weight of the previous running context in the
                EMA update (the new frame gets ``1 - context_momentum``).

        Raises:
            ValueError: if ``modulation_interval < 1`` or
                ``context_momentum`` is outside [0, 1].
        """
        super().__init__()
        if modulation_interval < 1:
            raise ValueError(
                f"modulation_interval must be >= 1, got {modulation_interval}"
            )
        if not 0.0 <= context_momentum <= 1.0:
            raise ValueError(
                f"context_momentum must be in [0, 1], got {context_momentum}"
            )

        self.dim = dim
        self.num_blocks = num_blocks
        self.modulation_interval = modulation_interval
        self.context_momentum = context_momentum

        # One FiLM layer per full modulation interval.
        num_film = num_blocks // modulation_interval
        self.film_layers = nn.ModuleList(
            FiLMTemporalModulation(dim=dim) for _ in range(num_film)
        )

        # Reliability calibrator.
        self.reliability = TemporalReliabilityCalibrator(dim=dim)

        # Maps the running (raw) features to the FiLM conditioning context.
        # Applied in modulate() so it receives gradients; applying it before
        # the detached EMA (as a stored output) would leave it untrained.
        self.context_proj = nn.Linear(dim, dim)

        # Detached EMA of un-projected features from previous frames.
        # Non-persistent: this is per-sequence runtime state and must not be
        # written to (or expected from) checkpoints. It still follows .to().
        self.register_buffer('_temporal_context', None, persistent=False)

    def should_modulate(self, block_idx: int) -> bool:
        """Return True if ``block_idx`` has an associated FiLM layer.

        Blocks ``interval - 1, 2 * interval - 1, ...`` are modulated, up to the
        number of FiLM layers; negative indices and blocks past the last
        full interval are not.
        """
        if block_idx < 0:
            return False
        film_count, remainder = divmod(block_idx + 1, self.modulation_interval)
        return remainder == 0 and film_count <= len(self.film_layers)

    def get_film_layer(self, block_idx: int) -> FiLMTemporalModulation:
        """Get the FiLM layer for a block index (check should_modulate first)."""
        film_idx = (block_idx + 1) // self.modulation_interval - 1
        return self.film_layers[film_idx]

    def update_temporal_context(self, features: torch.Tensor) -> None:
        """Fold the current frame's features into the running temporal context.

        The stored state is always detached, so no gradient flows across
        frames/steps. If the feature shape changes (e.g. a different batch
        size or token count), the running average restarts from this frame.

        Args:
            features: (B, S, D) features from current frame processing.
        """
        feats = features.detach()
        previous = self._temporal_context

        if previous is None or previous.shape != feats.shape:
            # Clone so later in-place edits of the caller's tensor cannot
            # silently alter the stored context.
            self._temporal_context = feats.clone()
        else:
            momentum = self.context_momentum
            self._temporal_context = (
                momentum * previous + (1.0 - momentum) * feats
            )

    def modulate(
        self,
        x: torch.Tensor,
        block_idx: int,
    ) -> torch.Tensor:
        """Apply FiLM modulation at the appropriate block.

        Args:
            x: (B, S, D) features at ``block_idx``.
            block_idx: current block index.

        Returns:
            (B, S, D) modulated features, or ``x`` itself if this block is not
            a modulation block or no temporal context exists yet.
        """
        if not self.should_modulate(block_idx):
            return x
        if self._temporal_context is None:
            return x  # No temporal context yet (first frame).

        film = self.get_film_layer(block_idx)

        # Project here (not at update time) so context_proj is trainable;
        # the stored EMA is detached, so the graph stays within this step.
        tc = self.context_proj(self._temporal_context)

        # Resample the context along the token axis if lengths differ.
        if tc.shape[1] != x.shape[1]:
            tc = F.interpolate(
                tc.transpose(1, 2),
                size=x.shape[1],
                mode='linear',
                align_corners=False,
            ).transpose(1, 2)

        reliability = self.reliability(tc)
        return film(x, tc, reliability)

    def reset(self) -> None:
        """Reset temporal context (e.g. for a new tracking sequence)."""
        self._temporal_context = None