omar-ah committed on
Commit
c08a3b0
·
verified ·
1 Parent(s): 3547636

Fix vil_tracker/models/backbone.py: audit corrections

Browse files
Files changed (1) hide show
  1. vil_tracker/models/backbone.py +19 -4
vil_tracker/models/backbone.py CHANGED
@@ -4,6 +4,7 @@ ViL (Vision-LSTM) Backbone for single object tracking.
4
  Architecture:
5
  - Patch embedding (Conv2d) for template + search region
6
  - Stack of mLSTM blocks with bidirectional scanning (even=L→R, odd=R→L)
 
7
  - Optional TMoE-MLP in last N blocks (dense routing, frozen shared expert)
8
  - Outputs concatenated template+search features for head processing
9
 
@@ -133,10 +134,11 @@ class mLSTMBlockWithTMoE(nn.Module):
133
 
134
 
135
  class ViLBackbone(nn.Module):
136
- """Vision-LSTM backbone for tracking.
137
 
138
  Concatenates template + search patches into a single sequence,
139
- processes through bidirectional mLSTM blocks, then separates outputs.
 
140
 
141
  Template: 128x128 → 8x8 = 64 tokens
142
  Search: 256x256 → 16x16 = 256 tokens
@@ -144,6 +146,7 @@ class ViLBackbone(nn.Module):
144
 
145
  Bidirectional scanning: even blocks L→R, odd blocks R→L.
146
  Last `tmoe_blocks` blocks use TMoE MLP for temporal specialization.
 
147
  """
148
  def __init__(
149
  self,
@@ -160,11 +163,13 @@ class ViLBackbone(nn.Module):
160
  tmoe_blocks: int = 2,
161
  num_experts: int = 4,
162
  bias: bool = False,
 
163
  ):
164
  super().__init__()
165
  self.dim = dim
166
  self.depth = depth
167
  self.patch_size = patch_size
 
168
 
169
  # Patch embedding
170
  self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, dim=dim)
@@ -209,11 +214,13 @@ class ViLBackbone(nn.Module):
209
  self,
210
  template: torch.Tensor,
211
  search: torch.Tensor,
 
212
  ) -> tuple:
213
  """
214
  Args:
215
  template: (B, 3, 128, 128) template image
216
  search: (B, 3, 256, 256) search region image
 
217
  Returns:
218
  template_feat: (B, 64, D) template features
219
  search_feat: (B, 256, D) search features
@@ -230,16 +237,24 @@ class ViLBackbone(nn.Module):
230
 
231
  # Concatenate: [template | search]
232
  tokens = torch.cat([t_tokens, s_tokens], dim=1) # (B, 320, D)
 
233
 
234
- # Process through bidirectional mLSTM blocks
235
  for i, block in enumerate(self.blocks):
236
  reverse = (i % 2 == 1) # odd blocks: R→L
237
  tokens = block(tokens, reverse=reverse)
 
 
 
 
238
 
239
  tokens = self.norm(tokens)
240
 
 
 
 
 
241
  # Split back
242
- n_template = t_tokens.shape[1]
243
  template_feat = tokens[:, :n_template]
244
  search_feat = tokens[:, n_template:]
245
 
 
4
  Architecture:
5
  - Patch embedding (Conv2d) for template + search region
6
  - Stack of mLSTM blocks with bidirectional scanning (even=L→R, odd=R→L)
7
+ - FiLM temporal modulation integrated BETWEEN blocks (at interval=6)
8
  - Optional TMoE-MLP in last N blocks (dense routing, frozen shared expert)
9
  - Outputs concatenated template+search features for head processing
10
 
 
134
 
135
 
136
  class ViLBackbone(nn.Module):
137
+ """Vision-LSTM backbone for tracking with integrated FiLM temporal modulation.
138
 
139
  Concatenates template + search patches into a single sequence,
140
+ processes through bidirectional mLSTM blocks with FiLM modulation
141
+ injected between blocks at regular intervals, then separates outputs.
142
 
143
  Template: 128x128 → 8x8 = 64 tokens
144
  Search: 256x256 → 16x16 = 256 tokens
 
146
 
147
  Bidirectional scanning: even blocks L→R, odd blocks R→L.
148
  Last `tmoe_blocks` blocks use TMoE MLP for temporal specialization.
149
+ FiLM modulation: applied after every `film_interval`-th block.
150
  """
151
  def __init__(
152
  self,
 
163
  tmoe_blocks: int = 2,
164
  num_experts: int = 4,
165
  bias: bool = False,
166
+ film_interval: int = 6,
167
  ):
168
  super().__init__()
169
  self.dim = dim
170
  self.depth = depth
171
  self.patch_size = patch_size
172
+ self.film_interval = film_interval
173
 
174
  # Patch embedding
175
  self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, dim=dim)
 
214
  self,
215
  template: torch.Tensor,
216
  search: torch.Tensor,
217
+ temporal_mod_manager=None,
218
  ) -> tuple:
219
  """
220
  Args:
221
  template: (B, 3, 128, 128) template image
222
  search: (B, 3, 256, 256) search region image
223
+ temporal_mod_manager: optional TemporalModulationManager for FiLM
224
  Returns:
225
  template_feat: (B, 64, D) template features
226
  search_feat: (B, 256, D) search features
 
237
 
238
  # Concatenate: [template | search]
239
  tokens = torch.cat([t_tokens, s_tokens], dim=1) # (B, 320, D)
240
+ n_template = t_tokens.shape[1]
241
 
242
+ # Process through bidirectional mLSTM blocks with optional FiLM
243
  for i, block in enumerate(self.blocks):
244
  reverse = (i % 2 == 1) # odd blocks: R→L
245
  tokens = block(tokens, reverse=reverse)
246
+
247
+ # Apply FiLM temporal modulation between blocks
248
+ if temporal_mod_manager is not None:
249
+ tokens = temporal_mod_manager.modulate(tokens, i)
250
 
251
  tokens = self.norm(tokens)
252
 
253
+ # Update temporal context after full forward pass
254
+ if temporal_mod_manager is not None:
255
+ temporal_mod_manager.update_temporal_context(tokens)
256
+
257
  # Split back
 
258
  template_feat = tokens[:, :n_template]
259
  search_feat = tokens[:, n_template:]
260