omar-ah committed on
Commit
b3b0529
·
verified ·
1 Parent(s): ccfb718

Upload vil_tracker/models/tracker.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vil_tracker/models/tracker.py +166 -0
vil_tracker/models/tracker.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ViL Tracker: Full model combining backbone, FiLM modulation, and prediction heads.
3
+
4
+ Pipeline:
5
+ 1. Template (128x128) + Search (256x256) → PatchEmbed → tokens
6
+ 2. Concatenated tokens → ViL backbone (24 mLSTM blocks, bidirectional)
7
+ 3. FiLM temporal modulation at intervals (conditioned on prev frame)
8
+ 4. Search features → CenterHead → heatmap + size + offset
9
+ 5. Optional: UncertaintyHead → log variance for adaptive weighting
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ from .backbone import ViLBackbone
16
+ from .film_temporal import TemporalModulationManager
17
+ from .heads import CenterHead, UncertaintyHead, decode_predictions
18
+
19
+
20
def get_default_config() -> dict:
    """Return a fresh dict holding the default ViL-S tracker settings.

    The values are chosen to satisfy the stated budget:
    ≤50M params, ≤30ms latency, ≤20 GFLOPs, ≤500MB.

    A new dictionary is built on every call, so callers may mutate the
    result without affecting later calls.
    """
    # Settings forwarded to the backbone.
    backbone_settings = {
        'dim': 384,
        'depth': 24,
        'patch_size': 16,
        'proj_factor': 2.0,
        'qkv_proj_blocksize': 4,
        'num_heads': 4,
        'conv_kernel': 4,
        'mlp_ratio': 4.0,
        'drop_path_rate': 0.1,
        'tmoe_blocks': 2,
        'num_experts': 4,
    }
    # FiLM temporal modulation.
    film_settings = {'film_interval': 6}
    # Prediction heads.
    head_settings = {'feat_size': 16}
    # Input resolutions.
    input_settings = {'template_size': 128, 'search_size': 256}
    # Uncertainty head toggle.
    uncertainty_settings = {'use_uncertainty': True}

    # Merge in section order so key iteration order matches the layout above.
    return {
        **backbone_settings,
        **film_settings,
        **head_settings,
        **input_settings,
        **uncertainty_settings,
    }
52
+
53
+
54
class ViLTracker(nn.Module):
    """Complete ViL-based single object tracker.

    Wires together a ``ViLBackbone``, a ``TemporalModulationManager`` (FiLM),
    a ``CenterHead`` and, when enabled, an ``UncertaintyHead``.

    Target specs (ViL-S):
        - Parameters: ~35-40M (well under 50M limit)
        - GFLOPs: ~15-18 (under 20 GFLOPs)
        - Model size: ~140-160MB fp32, ~70-80MB fp16 (under 500MB)
        - Latency: ~20-25ms on GPU (under 30ms)

    Args:
        config: Optional configuration overrides. Keys that are not supplied
            fall back to the values from ``get_default_config()``; ``None``
            or an empty dict selects the defaults entirely. The merged
            result is stored on ``self.config``.
    """

    def __init__(self, config: dict = None):
        super().__init__()
        # Overlay caller-supplied values on the defaults so that a partial
        # config (e.g. ``{'depth': 12}``) does not raise ``KeyError`` on the
        # keys it omits. A full config or ``None`` behaves as before.
        config = {**get_default_config(), **(config or {})}
        self.config = config

        dim = config['dim']
        depth = config['depth']

        # Backbone
        self.backbone = ViLBackbone(
            dim=dim,
            depth=depth,
            patch_size=config['patch_size'],
            proj_factor=config['proj_factor'],
            qkv_proj_blocksize=config['qkv_proj_blocksize'],
            num_heads=config['num_heads'],
            conv_kernel=config['conv_kernel'],
            mlp_ratio=config['mlp_ratio'],
            drop_path_rate=config['drop_path_rate'],
            tmoe_blocks=config['tmoe_blocks'],
            num_experts=config['num_experts'],
        )

        # FiLM temporal modulation
        self.temporal_mod = TemporalModulationManager(
            dim=dim,
            num_blocks=depth,
            modulation_interval=config['film_interval'],
        )

        # Prediction heads
        self.center_head = CenterHead(dim=dim, feat_size=config['feat_size'])

        # The uncertainty head is optional; ``None`` marks it as disabled so
        # ``forward`` can skip it.
        if config.get('use_uncertainty', True):
            self.uncertainty_head = UncertaintyHead(dim=dim, feat_size=config['feat_size'])
        else:
            self.uncertainty_head = None

    def forward(
        self,
        template: torch.Tensor,
        search: torch.Tensor,
        use_temporal: bool = False,
    ) -> dict:
        """Run the tracker on one template/search pair.

        Args:
            template: (B, 3, 128, 128) template image
            search: (B, 3, 256, 256) search region
            use_temporal: whether to apply FiLM temporal modulation. When
                true, this call also updates the temporal context held by
                ``self.temporal_mod``, so it is stateful across calls.

        Returns:
            dict with keys ``heatmap``, ``size``, ``offset``, ``boxes``,
            ``scores``, ``template_feat``, ``search_feat`` and, when the
            uncertainty head is enabled, ``log_variance``.
        """
        # Backbone forward
        template_feat, search_feat = self.backbone(template, search)

        # Optional FiLM temporal modulation on search features.
        # NOTE(review): modulation is applied here to the backbone's *output*,
        # once for every block index the manager selects, rather than between
        # backbone blocks — confirm this matches the intended pipeline.
        if use_temporal:
            for i in range(self.backbone.depth):
                if self.temporal_mod.should_modulate(i):
                    search_feat = self.temporal_mod.modulate(search_feat, i)
            # Update temporal context for next frame
            self.temporal_mod.update_temporal_context(search_feat)

        # Prediction heads
        preds = self.center_head(search_feat)

        # Decode to boxes
        boxes, scores = decode_predictions(
            preds['heatmap'],
            preds['size'],
            preds['offset'],
            search_size=self.config['search_size'],
            feat_size=self.config['feat_size'],
        )

        output = {
            'heatmap': preds['heatmap'],
            'size': preds['size'],
            'offset': preds['offset'],
            'boxes': boxes,
            'scores': scores,
            'template_feat': template_feat,
            'search_feat': search_feat,
        }

        # Uncertainty prediction
        if self.uncertainty_head is not None:
            output['log_variance'] = self.uncertainty_head(search_feat)

        return output

    def reset_temporal(self):
        """Reset temporal modulation state (for new tracking sequence)."""
        self.temporal_mod.reset()

    def freeze_backbone_shared_experts(self):
        """Freeze shared experts in TMoE blocks for Phase 2."""
        self.backbone.freeze_shared_experts()
162
+
163
+
164
def build_tracker(config: dict = None) -> ViLTracker:
    """Construct a ``ViLTracker`` from ``config``.

    A falsy ``config`` (``None`` or an empty dict) is replaced by the result
    of ``get_default_config()`` before the tracker is built.
    """
    if not config:
        config = get_default_config()
    return ViLTracker(config)