| """ |
| FiLM (Feature-wise Linear Modulation) Temporal Module. |
| |
| Replaces DTPTrack's temporal prompt tokens (which are broken for bidirectional mLSTM |
| scanning) with channel-wise affine modulation conditioned on temporal context. |
| |
| Architecture: |
| 1. TemporalReliabilityCalibrator: learns reliability weights for temporal features |
| 2. FiLMTemporalModulation: γ(t)·x + β(t) modulation per block |
| 3. TemporalModulationManager: manages FiLM layers across all blocks |
| |
| Reference: Perez et al., "FiLM: Visual Reasoning with a General Conditioning Layer" |
| """ |
|
|
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
class TemporalReliabilityCalibrator(nn.Module):
    """Predicts a per-token reliability score for temporal context.

    A small bottleneck MLP (dim -> dim/4 -> 1) followed by a sigmoid maps
    each temporal feature vector to a scalar in [0, 1] expressing how much
    that token's temporal context should be trusted.
    """

    def __init__(self, dim: int = 384):
        super().__init__()
        hidden = dim // 4
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, temporal_feat: torch.Tensor) -> torch.Tensor:
        """Map (B, S, D) temporal features to (B, S, 1) weights in [0, 1]."""
        return self.net(temporal_feat)
|
|
|
|
class FiLMTemporalModulation(nn.Module):
    """Feature-wise Linear Modulation conditioned on temporal context.

    Computes ``output = gamma(temporal) * x + beta(temporal)`` where gamma
    and beta are produced from the temporal features by small bottleneck
    MLPs (dim -> dim/4 -> dim).

    The final linear layers are zero-initialized (with bias 1 for gamma and
    0 for beta) so the module is an exact identity at the start of training;
    the modulation is learned gradually from there.
    """

    def __init__(self, dim: int = 384):
        super().__init__()

        self.gamma_net = nn.Sequential(
            nn.Linear(dim, dim // 4),
            nn.GELU(),
            nn.Linear(dim // 4, dim),
        )
        self.beta_net = nn.Sequential(
            nn.Linear(dim, dim // 4),
            nn.GELU(),
            nn.Linear(dim // 4, dim),
        )

        # Identity at init: gamma == 1 and beta == 0 regardless of input.
        nn.init.zeros_(self.gamma_net[-1].weight)
        nn.init.ones_(self.gamma_net[-1].bias)
        nn.init.zeros_(self.beta_net[-1].weight)
        nn.init.zeros_(self.beta_net[-1].bias)

    def forward(
        self,
        x: torch.Tensor,
        temporal_context: torch.Tensor,
        reliability: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (B, S, D) input features from current frame
            temporal_context: (B, S, D) temporal features (prev frame, pooled states, etc.)
            reliability: (B, S, 1) optional reliability weights in [0, 1]
        Returns:
            (B, S, D) modulated features
        """
        gamma = self.gamma_net(temporal_context)
        beta = self.beta_net(temporal_context)

        if reliability is not None:
            # Interpolate toward identity (gamma=1, beta=0) as reliability -> 0.
            # (1 - reliability) broadcasts over D; no ones_like needed.
            gamma = reliability * gamma + (1 - reliability)
            beta = reliability * beta

        return gamma * x + beta
|
|
|
|
class TemporalModulationManager(nn.Module):
    """Manages FiLM modulation across multiple backbone blocks.

    Applies FiLM modulation after every ``modulation_interval``-th block,
    conditioned on an exponential moving average of projected features from
    previously processed frames.

    NOTE(review): the stored context is detached before use, so
    ``context_proj`` never receives gradient through it — presumably the
    context is meant to be a fixed conditioning signal; confirm whether the
    projection is intended to train.

    Args:
        dim: feature dimension D
        num_blocks: number of backbone blocks
        modulation_interval: apply FiLM after every N-th block
        context_momentum: EMA weight on the previous context (new context
            gets ``1 - context_momentum``); default matches the original
            hard-coded 0.7 / 0.3 blend
    """

    def __init__(
        self,
        dim: int = 384,
        num_blocks: int = 24,
        modulation_interval: int = 6,
        context_momentum: float = 0.7,
    ):
        super().__init__()
        self.dim = dim
        self.num_blocks = num_blocks
        self.modulation_interval = modulation_interval
        self.context_momentum = context_momentum

        # One FiLM layer per modulation point (after blocks
        # interval-1, 2*interval-1, ...).
        num_film = num_blocks // modulation_interval
        self.film_layers = nn.ModuleList([
            FiLMTemporalModulation(dim=dim)
            for _ in range(num_film)
        ])

        # Shared reliability estimator for the temporal context.
        self.reliability = TemporalReliabilityCalibrator(dim=dim)

        # Projects raw frame features into the temporal-context space.
        self.context_proj = nn.Linear(dim, dim)

        # Per-sequence runtime state (EMA of projected features).
        # persistent=False: its shape depends on the batch, and transient
        # tracking state must not be written to / loaded from checkpoints
        # (a persistent buffer saved mid-sequence would break
        # load_state_dict on a fresh model whose buffer is still None).
        self.register_buffer('_temporal_context', None, persistent=False)

    def should_modulate(self, block_idx: int) -> bool:
        """Return True if FiLM modulation should run after ``block_idx``.

        Also guards against indices that would address a non-existent FiLM
        layer (e.g. when ``num_blocks`` is not a multiple of the interval).
        """
        if (block_idx + 1) % self.modulation_interval != 0:
            return False
        film_idx = (block_idx + 1) // self.modulation_interval - 1
        return 0 <= film_idx < len(self.film_layers)

    def get_film_layer(self, block_idx: int) -> FiLMTemporalModulation:
        """Get the FiLM layer for a given block index."""
        film_idx = (block_idx + 1) // self.modulation_interval - 1
        return self.film_layers[film_idx]

    def update_temporal_context(self, features: torch.Tensor):
        """Fold current-frame features into the temporal context EMA.

        Args:
            features: (B, S, D) features from current frame processing
        """
        # Detach the input so no autograd graph from previous frames is
        # retained across steps.
        context = self.context_proj(features.detach())
        if self._temporal_context is None:
            self._temporal_context = context.detach()
        else:
            m = self.context_momentum
            self._temporal_context = (
                m * self._temporal_context + (1.0 - m) * context
            ).detach()

    def modulate(
        self,
        x: torch.Tensor,
        block_idx: int,
    ) -> torch.Tensor:
        """Apply FiLM modulation at the appropriate block.

        Args:
            x: (B, S, D) features at block_idx
            block_idx: current block index
        Returns:
            (B, S, D) modulated features (or unchanged if not a modulation
            block, or if no temporal context has been accumulated yet)
        """
        if not self.should_modulate(block_idx):
            return x

        if self._temporal_context is None:
            return x

        film = self.get_film_layer(block_idx)

        # Resample the context along the sequence axis if the token count
        # differs from the current features (e.g. variable search regions).
        tc = self._temporal_context
        if tc.shape[1] != x.shape[1]:
            tc = F.interpolate(
                tc.transpose(1, 2),
                size=x.shape[1],
                mode='linear',
                align_corners=False,
            ).transpose(1, 2)

        reliability = self.reliability(tc)
        return film(x, tc, reliability)

    def reset(self):
        """Reset temporal context (e.g., for new tracking sequence)."""
        self._temporal_context = None
|
|