"""
FiLM (Feature-wise Linear Modulation) Temporal Module.
Replaces DTPTrack's temporal prompt tokens (which are broken for bidirectional mLSTM
scanning) with channel-wise affine modulation conditioned on temporal context.
Architecture:
1. TemporalReliabilityCalibrator: learns reliability weights for temporal features
2. FiLMTemporalModulation: γ(t)·x + β(t) modulation per block
3. TemporalModulationManager: manages FiLM layers across all blocks
Reference: Perez et al., "FiLM: Visual Reasoning with a General Conditioning Layer"
"""

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class TemporalReliabilityCalibrator(nn.Module):
"""Learns a reliability score for temporal context.
Takes temporal features (e.g., from previous frame's mLSTM states)
and produces a scalar reliability weight in [0, 1] for each token.
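
    Example (illustrative sketch; the shapes below are arbitrary assumptions):
        >>> calib = TemporalReliabilityCalibrator(dim=8)
        >>> r = calib(torch.randn(2, 4, 8))
        >>> tuple(r.shape)
        (2, 4, 1)
        >>> bool(((r >= 0) & (r <= 1)).all())
        True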
"""
def __init__(self, dim: int = 384):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim // 4),
nn.GELU(),
nn.Linear(dim // 4, 1),
nn.Sigmoid(),
)

    def forward(self, temporal_feat: torch.Tensor) -> torch.Tensor:
"""
Args:
temporal_feat: (B, S, D) temporal context features
Returns:
reliability: (B, S, 1) reliability weights in [0, 1]
"""
return self.net(temporal_feat)


class FiLMTemporalModulation(nn.Module):
"""Feature-wise Linear Modulation conditioned on temporal context.
Computes: output = γ(temporal) · x + β(temporal)
where γ, β are learned from temporal features via small networks.
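
    Example (illustrative sketch; shapes are arbitrary). The final layers are
    zero-initialized, so at initialization γ = 1, β = 0 and the module acts as
    the identity:
        >>> film = FiLMTemporalModulation(dim=8)
        >>> x = torch.randn(2, 4, 8)
        >>> out = film(x, temporal_context=torch.randn(2, 4, 8))
        >>> torch.allclose(out, x)
        True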
"""
def __init__(self, dim: int = 384):
super().__init__()
# Generate scale (γ) and shift (β) from temporal context
self.gamma_net = nn.Sequential(
nn.Linear(dim, dim // 4),
nn.GELU(),
nn.Linear(dim // 4, dim),
)
self.beta_net = nn.Sequential(
nn.Linear(dim, dim // 4),
nn.GELU(),
nn.Linear(dim // 4, dim),
)
        # Zero-init the final layers so that γ = 1 and β = 0 exactly at init (identity modulation)
nn.init.zeros_(self.gamma_net[-1].weight)
nn.init.ones_(self.gamma_net[-1].bias)
nn.init.zeros_(self.beta_net[-1].weight)
nn.init.zeros_(self.beta_net[-1].bias)

    def forward(
        self,
        x: torch.Tensor,
        temporal_context: torch.Tensor,
        reliability: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
"""
Args:
x: (B, S, D) input features from current frame
temporal_context: (B, S, D) temporal features (prev frame, pooled states, etc.)
reliability: (B, S, 1) optional reliability weights
Returns:
(B, S, D) modulated features
"""
gamma = self.gamma_net(temporal_context) # (B, S, D)
beta = self.beta_net(temporal_context) # (B, S, D)
if reliability is not None:
# Blend between identity (no modulation) and full modulation based on reliability
gamma = reliability * gamma + (1 - reliability) * torch.ones_like(gamma)
beta = reliability * beta
return gamma * x + beta


class TemporalModulationManager(nn.Module):
"""Manages FiLM modulation across multiple backbone blocks.
Applies FiLM modulation after every N-th block, using temporal context
from the previous frame's features (or running average).
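
    Example (illustrative sketch; the block count and interval below are arbitrary
    assumptions): with 12 blocks and an interval of 4, modulation fires after
    blocks 3, 7 and 11, served by three FiLM layers:
        >>> mgr = TemporalModulationManager(dim=8, num_blocks=12, modulation_interval=4)
        >>> [i for i in range(12) if mgr.should_modulate(i)]
        [3, 7, 11]
        >>> len(mgr.film_layers)
        3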
"""
def __init__(
self,
dim: int = 384,
num_blocks: int = 24,
modulation_interval: int = 6,
):
super().__init__()
self.dim = dim
self.num_blocks = num_blocks
self.modulation_interval = modulation_interval
# FiLM layers at intervals
num_film = num_blocks // modulation_interval
self.film_layers = nn.ModuleList([
FiLMTemporalModulation(dim=dim)
for _ in range(num_film)
])
# Reliability calibrator
self.reliability = TemporalReliabilityCalibrator(dim=dim)
# Temporal context projection (map prev features to context)
self.context_proj = nn.Linear(dim, dim)
        # Running temporal context (transient per-sequence state, not a parameter;
        # non-persistent so it stays out of state_dict, since its shape depends on the batch)
        self.register_buffer('_temporal_context', None, persistent=False)

    def should_modulate(self, block_idx: int) -> bool:
"""Check if this block index should apply FiLM modulation."""
return (block_idx + 1) % self.modulation_interval == 0

    def get_film_layer(self, block_idx: int) -> FiLMTemporalModulation:
"""Get the FiLM layer for a given block index."""
film_idx = (block_idx + 1) // self.modulation_interval - 1
return self.film_layers[film_idx]

    def update_temporal_context(self, features: torch.Tensor):
"""Update temporal context from current frame features.
Args:
features: (B, S, D) features from current frame processing
"""
context = self.context_proj(features.detach())
if self._temporal_context is None:
self._temporal_context = context.detach()
else:
# EMA update — detach to prevent cross-step gradient leakage
self._temporal_context = (0.7 * self._temporal_context + 0.3 * context).detach()

    def modulate(
self,
x: torch.Tensor,
block_idx: int,
) -> torch.Tensor:
"""Apply FiLM modulation at the appropriate block.
Args:
x: (B, S, D) features at block_idx
block_idx: current block index
Returns:
(B, S, D) modulated features (or unchanged if not a modulation block)
"""
if not self.should_modulate(block_idx):
return x
if self._temporal_context is None:
return x # No temporal context yet (first frame)
film = self.get_film_layer(block_idx)
# Ensure temporal context matches spatial dimension
tc = self._temporal_context
if tc.shape[1] != x.shape[1]:
            # Linearly resample temporal context along the sequence dimension
tc = F.interpolate(
tc.transpose(1, 2),
size=x.shape[1],
mode='linear',
align_corners=False,
).transpose(1, 2)
reliability = self.reliability(tc)
return film(x, tc, reliability)

    def reset(self):
"""Reset temporal context (e.g., for new tracking sequence)."""
self._temporal_context = None
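

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not the actual ViL/DTPTrack integration): the toy
# per-block loop, frame loop, and tensor shapes below are illustrative
# assumptions. It only exercises the intended call pattern: modulate() inside
# the backbone's block loop, update_temporal_context() once per frame, and
# reset() at the start of each new tracking sequence.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, S, D = 2, 16, 384
    manager = TemporalModulationManager(dim=D, num_blocks=24, modulation_interval=6)

    manager.reset()  # start of a new tracking sequence
    for frame_idx in range(3):
        x = torch.randn(B, S, D)  # stand-in for the current frame's token features
        for block_idx in range(24):
            # Here the real backbone block would run first: x = block(x)
            x = manager.modulate(x, block_idx)
        manager.update_temporal_context(x)
        print(f"frame {frame_idx}: out {tuple(x.shape)}, "
              f"context cached: {manager._temporal_context is not None}")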