Fix vil_tracker/models/backbone.py: audit corrections
vil_tracker/models/backbone.py
```diff
@@ -4,6 +4,7 @@ ViL (Vision-LSTM) Backbone for single object tracking.
 Architecture:
 - Patch embedding (Conv2d) for template + search region
 - Stack of mLSTM blocks with bidirectional scanning (even=L→R, odd=R→L)
+- FiLM temporal modulation integrated BETWEEN blocks (at interval=6)
 - Optional TMoE-MLP in last N blocks (dense routing, frozen shared expert)
 - Outputs concatenated template+search features for head processing
 
```
```diff
@@ -133,10 +134,11 @@ class mLSTMBlockWithTMoE(nn.Module):
 
 
 class ViLBackbone(nn.Module):
-    """Vision-LSTM backbone for tracking.
+    """Vision-LSTM backbone for tracking with integrated FiLM temporal modulation.
 
     Concatenates template + search patches into a single sequence,
-    processes through bidirectional mLSTM blocks
+    processes through bidirectional mLSTM blocks with FiLM modulation
+    injected between blocks at regular intervals, then separates outputs.
 
     Template: 128x128 → 8x8 = 64 tokens
     Search: 256x256 → 16x16 = 256 tokens
```
```diff
@@ -144,6 +146,7 @@ class ViLBackbone(nn.Module):
 
     Bidirectional scanning: even blocks L→R, odd blocks R→L.
     Last `tmoe_blocks` blocks use TMoE MLP for temporal specialization.
+    FiLM modulation: applied after every `film_interval`-th block.
     """
     def __init__(
         self,
```
```diff
@@ -160,11 +163,13 @@ class ViLBackbone(nn.Module):
         tmoe_blocks: int = 2,
         num_experts: int = 4,
         bias: bool = False,
+        film_interval: int = 6,
     ):
         super().__init__()
         self.dim = dim
         self.depth = depth
         self.patch_size = patch_size
+        self.film_interval = film_interval
 
         # Patch embedding
         self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, dim=dim)
```
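For reference, constructing the backbone with the new argument could look like the sketch below. Only `film_interval=6` is taken from this diff; the `dim`/`depth` values, and the assumption that the constructor arguments not shown in this hunk have usable defaults, are illustrative.

```python
import torch
from vil_tracker.models.backbone import ViLBackbone

# dim/depth are assumed values for illustration; film_interval=6 matches the new default.
backbone = ViLBackbone(
    dim=384,
    depth=12,
    tmoe_blocks=2,
    num_experts=4,
    bias=False,
    film_interval=6,  # FiLM applied after every 6th block
)
```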
```diff
@@ -209,11 +214,13 @@ class ViLBackbone(nn.Module):
         self,
         template: torch.Tensor,
         search: torch.Tensor,
+        temporal_mod_manager=None,
     ) -> tuple:
         """
         Args:
             template: (B, 3, 128, 128) template image
             search: (B, 3, 256, 256) search region image
+            temporal_mod_manager: optional TemporalModulationManager for FiLM
         Returns:
             template_feat: (B, 64, D) template features
             search_feat: (B, 256, D) search features
```
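Continuing that sketch, here is a forward call matching the updated signature and the docstring shapes. `temporal_mod_manager` defaults to `None`, in which case the FiLM branches in the next hunk are skipped and the output is unchanged from the pre-FiLM code path.

```python
template = torch.randn(2, 3, 128, 128)  # (B, 3, 128, 128) template crop
search = torch.randn(2, 3, 256, 256)    # (B, 3, 256, 256) search region

# No manager passed: FiLM is bypassed entirely.
template_feat, search_feat = backbone(template, search)
assert template_feat.shape == (2, 64, backbone.dim)   # 64 template tokens
assert search_feat.shape == (2, 256, backbone.dim)    # 256 search tokens
```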
```diff
@@ -230,16 +237,24 @@
 
         # Concatenate: [template | search]
         tokens = torch.cat([t_tokens, s_tokens], dim=1)  # (B, 320, D)
+        n_template = t_tokens.shape[1]
 
-        # Process through bidirectional mLSTM blocks
+        # Process through bidirectional mLSTM blocks with optional FiLM
         for i, block in enumerate(self.blocks):
             reverse = (i % 2 == 1)  # odd blocks: R→L
             tokens = block(tokens, reverse=reverse)
+
+            # Apply FiLM temporal modulation between blocks
+            if temporal_mod_manager is not None:
+                tokens = temporal_mod_manager.modulate(tokens, i)
 
         tokens = self.norm(tokens)
 
+        # Update temporal context after full forward pass
+        if temporal_mod_manager is not None:
+            temporal_mod_manager.update_temporal_context(tokens)
+
         # Split back
-        n_template = t_tokens.shape[1]
         template_feat = tokens[:, :n_template]
         search_feat = tokens[:, n_template:]
 
```
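The commit shows only the manager's call sites, `modulate(tokens, i)` and `update_temporal_context(tokens)`, not its implementation. The sketch below is a hypothetical minimal TemporalModulationManager consistent with those call sites and the `film_interval` docstring; the EMA context, the linear gamma/beta head, and placing the interval check inside `modulate` are all assumptions, not the project's actual code.

```python
import torch
import torch.nn as nn

class TemporalModulationManager(nn.Module):
    """Hypothetical minimal FiLM manager; only the interface matches the diff above."""

    def __init__(self, dim: int, film_interval: int = 6, momentum: float = 0.9):
        super().__init__()
        self.film_interval = film_interval
        self.momentum = momentum
        self.to_film = nn.Linear(dim, 2 * dim)  # predicts gamma and beta (assumed head)
        self.register_buffer("context", torch.zeros(dim))  # running temporal context

    def modulate(self, tokens: torch.Tensor, block_idx: int) -> torch.Tensor:
        # Fire only after every `film_interval`-th block: i = 5, 11, ... for interval=6.
        if (block_idx + 1) % self.film_interval != 0:
            return tokens
        gamma, beta = self.to_film(self.context).chunk(2)  # each (D,)
        return tokens * (1 + gamma) + beta  # FiLM, broadcast over (B, N, D)

    @torch.no_grad()
    def update_temporal_context(self, tokens: torch.Tensor) -> None:
        # EMA over the mean-pooled features of the current frame (assumed policy).
        frame_feat = tokens.mean(dim=(0, 1))  # (D,)
        self.context.mul_(self.momentum).add_(frame_feat, alpha=1 - self.momentum)

# Usage with the forward pass sketched earlier:
manager = TemporalModulationManager(dim=backbone.dim, film_interval=backbone.film_interval)
template_feat, search_feat = backbone(template, search, temporal_mod_manager=manager)
```

Note that the diff calls `update_temporal_context` on the normalized output after the block loop, so whatever the real manager stores, the next frame's modulation is conditioned on the previous frame's final features rather than on intermediate activations.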