Fix vil_tracker/models/tracker.py: audit corrections
Browse files
vil_tracker/models/tracker.py
CHANGED
|
@@ -4,7 +4,7 @@ ViL Tracker: Full model combining backbone, FiLM modulation, and prediction head
|
|
| 4 |
Pipeline:
|
| 5 |
1. Template (128x128) + Search (256x256) → PatchEmbed → tokens
|
| 6 |
2. Concatenated tokens → ViL backbone (24 mLSTM blocks, bidirectional)
|
| 7 |
-
3. FiLM temporal modulation
|
| 8 |
4. Search features → CenterHead → heatmap + size + offset
|
| 9 |
5. Optional: UncertaintyHead → log variance for adaptive weighting
|
| 10 |
"""
|
|
@@ -68,7 +68,7 @@ class ViLTracker(nn.Module):
|
|
| 68 |
dim = config['dim']
|
| 69 |
depth = config['depth']
|
| 70 |
|
| 71 |
-
# Backbone
|
| 72 |
self.backbone = ViLBackbone(
|
| 73 |
dim=dim,
|
| 74 |
depth=depth,
|
|
@@ -81,9 +81,10 @@ class ViLTracker(nn.Module):
|
|
| 81 |
drop_path_rate=config['drop_path_rate'],
|
| 82 |
tmoe_blocks=config['tmoe_blocks'],
|
| 83 |
num_experts=config['num_experts'],
|
|
|
|
| 84 |
)
|
| 85 |
|
| 86 |
-
# FiLM temporal modulation
|
| 87 |
self.temporal_mod = TemporalModulationManager(
|
| 88 |
dim=dim,
|
| 89 |
num_blocks=depth,
|
|
@@ -113,16 +114,9 @@ class ViLTracker(nn.Module):
|
|
| 113 |
dict with predictions: heatmap, size, offset, boxes, scores,
|
| 114 |
and optionally uncertainty
|
| 115 |
"""
|
| 116 |
-
# Backbone forward
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
# Optional FiLM temporal modulation on search features
|
| 120 |
-
if use_temporal:
|
| 121 |
-
for i in range(self.backbone.depth):
|
| 122 |
-
if self.temporal_mod.should_modulate(i):
|
| 123 |
-
search_feat = self.temporal_mod.modulate(search_feat, i)
|
| 124 |
-
# Update temporal context for next frame
|
| 125 |
-
self.temporal_mod.update_temporal_context(search_feat)
|
| 126 |
|
| 127 |
# Prediction heads
|
| 128 |
preds = self.center_head(search_feat)
|
|
|
|
| 4 |
Pipeline:
|
| 5 |
1. Template (128x128) + Search (256x256) → PatchEmbed → tokens
|
| 6 |
2. Concatenated tokens → ViL backbone (24 mLSTM blocks, bidirectional)
|
| 7 |
+
3. FiLM temporal modulation integrated BETWEEN backbone blocks
|
| 8 |
4. Search features → CenterHead → heatmap + size + offset
|
| 9 |
5. Optional: UncertaintyHead → log variance for adaptive weighting
|
| 10 |
"""
|
|
|
|
| 68 |
dim = config['dim']
|
| 69 |
depth = config['depth']
|
| 70 |
|
| 71 |
+
# Backbone (now accepts temporal_mod_manager as forward arg)
|
| 72 |
self.backbone = ViLBackbone(
|
| 73 |
dim=dim,
|
| 74 |
depth=depth,
|
|
|
|
| 81 |
drop_path_rate=config['drop_path_rate'],
|
| 82 |
tmoe_blocks=config['tmoe_blocks'],
|
| 83 |
num_experts=config['num_experts'],
|
| 84 |
+
film_interval=config.get('film_interval', 6),
|
| 85 |
)
|
| 86 |
|
| 87 |
+
# FiLM temporal modulation (applied BETWEEN backbone blocks)
|
| 88 |
self.temporal_mod = TemporalModulationManager(
|
| 89 |
dim=dim,
|
| 90 |
num_blocks=depth,
|
|
|
|
| 114 |
dict with predictions: heatmap, size, offset, boxes, scores,
|
| 115 |
and optionally uncertainty
|
| 116 |
"""
|
| 117 |
+
# Backbone forward with optional integrated FiLM modulation
|
| 118 |
+
temporal_mgr = self.temporal_mod if use_temporal else None
|
| 119 |
+
template_feat, search_feat = self.backbone(template, search, temporal_mod_manager=temporal_mgr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
# Prediction heads
|
| 122 |
preds = self.center_head(search_feat)
|