omar-ah committed on
Commit
bcd8770
·
verified ·
1 Parent(s): 9556ef9

Upload vil_tracker/models/film_temporal.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vil_tracker/models/film_temporal.py +188 -0
vil_tracker/models/film_temporal.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FiLM (Feature-wise Linear Modulation) Temporal Module.
3
+
4
+ Replaces DTPTrack's temporal prompt tokens (which are broken for bidirectional mLSTM
5
+ scanning) with channel-wise affine modulation conditioned on temporal context.
6
+
7
+ Architecture:
8
+ 1. TemporalReliabilityCalibrator: learns reliability weights for temporal features
9
+ 2. FiLMTemporalModulation: γ(t)·x + β(t) modulation per block
10
+ 3. TemporalModulationManager: manages FiLM layers across all blocks
11
+
12
+ Reference: Perez et al., "FiLM: Visual Reasoning with a General Conditioning Layer"
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+
20
class TemporalReliabilityCalibrator(nn.Module):
    """Score how trustworthy each temporal-context token is.

    A two-layer bottleneck MLP (``dim -> dim // 4 -> 1``) followed by a
    sigmoid, applied along the last axis of the input. The sigmoid keeps
    every score inside the closed interval [0, 1].
    """

    def __init__(self, dim: int = 384):
        super().__init__()
        bottleneck = dim // 4
        # Layer order is part of the checkpoint layout (net.0, net.2 hold weights).
        stages = [
            nn.Linear(dim, bottleneck),
            nn.GELU(),
            nn.Linear(bottleneck, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, temporal_feat: torch.Tensor) -> torch.Tensor:
        """Map temporal features to per-token reliability weights.

        Args:
            temporal_feat: (B, S, D) temporal context features.
        Returns:
            (B, S, 1) reliability weights in [0, 1].
        """
        reliability = self.net(temporal_feat)
        return reliability
43
+
44
+
45
class FiLMTemporalModulation(nn.Module):
    """Channel-wise affine modulation driven by temporal context.

    Produces ``γ(temporal) · x + β(temporal)``, where the scale γ and the
    shift β each come from their own bottleneck MLP (``dim -> dim // 4 -> dim``)
    applied to the temporal context.
    """

    def __init__(self, dim: int = 384):
        super().__init__()
        hidden = dim // 4

        # Scale (γ) branch, then shift (β) branch; the creation order is kept
        # so seeded initialisation matches across runs.
        self.gamma_net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )
        self.beta_net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )

        # Zeroed output weights with bias 1 (γ) / bias 0 (β) make the freshly
        # constructed module an identity mapping: γ == 1 and β == 0.
        gamma_out = self.gamma_net[-1]
        beta_out = self.beta_net[-1]
        nn.init.zeros_(gamma_out.weight)
        nn.init.ones_(gamma_out.bias)
        nn.init.zeros_(beta_out.weight)
        nn.init.zeros_(beta_out.bias)

    def forward(
        self,
        x: torch.Tensor,
        temporal_context: torch.Tensor,
        reliability: torch.Tensor = None,
    ) -> torch.Tensor:
        """Modulate ``x`` with scale/shift predicted from ``temporal_context``.

        Args:
            x: (B, S, D) input features from the current frame.
            temporal_context: (B, S, D) temporal features.
            reliability: optional (B, S, 1) weights; when given, the
                modulation is blended toward the identity (γ=1, β=0) as the
                weight approaches 0.
        Returns:
            (B, S, D) modulated features.
        """
        scale = self.gamma_net(temporal_context)  # (B, S, D)
        shift = self.beta_net(temporal_context)  # (B, S, D)

        if reliability is None:
            return scale * x + shift

        # reliability == 1 -> full modulation, reliability == 0 -> identity.
        identity_scale = torch.ones_like(scale)
        scale = reliability * scale + (1 - reliability) * identity_scale
        shift = reliability * shift
        return scale * x + shift
94
+
95
+
96
class TemporalModulationManager(nn.Module):
    """Manages FiLM modulation across multiple backbone blocks.

    One FiLM layer is created per ``modulation_interval`` blocks
    (``num_blocks // modulation_interval`` layers in total) and is applied
    after blocks ``interval - 1``, ``2 * interval - 1``, ... The temporal
    context is an exponential moving average of projected features from
    previously processed frames; it is runtime state, not a learned weight.
    """

    def __init__(
        self,
        dim: int = 384,
        num_blocks: int = 24,
        modulation_interval: int = 6,
    ):
        """
        Args:
            dim: channel dimension of the features and the temporal context.
            num_blocks: number of backbone blocks the manager is attached to.
            modulation_interval: apply FiLM after every N-th block (N > 0).
        Raises:
            ValueError: if ``modulation_interval`` is not positive.
        """
        super().__init__()
        if modulation_interval <= 0:
            raise ValueError(
                f"modulation_interval must be a positive integer, got {modulation_interval}"
            )
        self.dim = dim
        self.num_blocks = num_blocks
        self.modulation_interval = modulation_interval

        # FiLM layers at intervals
        num_film = num_blocks // modulation_interval
        self.film_layers = nn.ModuleList(
            [FiLMTemporalModulation(dim=dim) for _ in range(num_film)]
        )

        # Reliability calibrator
        self.reliability = TemporalReliabilityCalibrator(dim=dim)

        # Temporal context projection (map prev features to context)
        self.context_proj = nn.Linear(dim, dim)

        # Running temporal context. Non-persistent: it is per-sequence state,
        # and a persistent buffer that starts as None would show up in
        # state_dict() once set and then be rejected as an unexpected key by
        # a strict load_state_dict() into a freshly built module.
        self.register_buffer('_temporal_context', None, persistent=False)

    def should_modulate(self, block_idx: int) -> bool:
        """Check if this block index should apply FiLM modulation.

        Returns True only when ``block_idx`` closes a modulation interval AND
        a FiLM layer exists for it, so negative indices or indices past the
        last FiLM layer never reach ``get_film_layer``.
        """
        if (block_idx + 1) % self.modulation_interval != 0:
            return False
        film_idx = (block_idx + 1) // self.modulation_interval - 1
        return 0 <= film_idx < len(self.film_layers)

    def get_film_layer(self, block_idx: int) -> "FiLMTemporalModulation":
        """Get the FiLM layer for a given block index.

        Only meaningful for indices where ``should_modulate`` is True.
        """
        film_idx = (block_idx + 1) // self.modulation_interval - 1
        return self.film_layers[film_idx]

    def update_temporal_context(self, features: torch.Tensor):
        """Update temporal context from current frame features.

        The first update (or any update whose projected shape differs from
        the stored context, e.g. a new batch size or token count) seeds the
        context; later updates blend it as ``0.7 * old + 0.3 * new``.

        Args:
            features: (B, S, D) features from current frame processing
        """
        # The input is detached, but the projection itself stays in the graph.
        # NOTE(review): the EMA therefore chains autograd graphs across frames;
        # run under torch.no_grad() at inference and reset() between backward
        # passes -- confirm against the training loop.
        context = self.context_proj(features.detach())
        previous = self._temporal_context
        if previous is None or previous.shape != context.shape:
            # Nothing compatible to average with: (re)start the running context.
            self._temporal_context = context
        else:
            # EMA update
            self._temporal_context = 0.7 * previous + 0.3 * context

    def modulate(
        self,
        x: torch.Tensor,
        block_idx: int,
    ) -> torch.Tensor:
        """Apply FiLM modulation at the appropriate block.

        Args:
            x: (B, S, D) features at block_idx
            block_idx: current block index
        Returns:
            (B, S, D) modulated features (or unchanged if not a modulation block)
        """
        if not self.should_modulate(block_idx):
            return x

        if self._temporal_context is None:
            return x  # No temporal context yet (first frame)

        film = self.get_film_layer(block_idx)

        # Ensure temporal context matches the token dimension of x
        tc = self._temporal_context
        if tc.shape[1] != x.shape[1]:
            # Linearly resample the context along the token axis;
            # F.interpolate expects (B, C, L), hence the transposes.
            tc = F.interpolate(
                tc.transpose(1, 2),
                size=x.shape[1],
                mode='linear',
                align_corners=False,
            ).transpose(1, 2)

        reliability = self.reliability(tc)
        return film(x, tc, reliability)

    def reset(self):
        """Reset temporal context (e.g., for new tracking sequence)."""
        self._temporal_context = None