omar-ah commited on
Commit
be1f14e
·
verified ·
1 Parent(s): a4d3af5

Sequence training: pairs→K-frame clips, mLSTM memory carries across frames

Browse files
Files changed (1) hide show
  1. vil_tracker/models/tracker.py +70 -27
vil_tracker/models/tracker.py CHANGED
@@ -102,47 +102,90 @@ class ViLTracker(nn.Module):
102
  def forward(
103
  self,
104
  template: torch.Tensor,
105
- search: torch.Tensor,
106
  use_temporal: bool = False,
107
  ) -> dict:
108
  """
 
 
109
  Args:
110
  template: (B, 3, 128, 128) template image
111
- search: (B, 3, 256, 256) search region
 
112
  use_temporal: whether to apply FiLM temporal modulation
113
  Returns:
114
- dict with predictions: heatmap, size, offset, boxes, scores,
115
- and optionally uncertainty
 
 
 
 
 
 
116
  """
117
- # Backbone forward with optional integrated FiLM modulation
 
118
  temporal_mgr = self.temporal_mod if use_temporal else None
119
- template_feat, search_feat = self.backbone(template, search, temporal_mod_manager=temporal_mgr)
120
 
121
- # Prediction heads
122
- preds = self.center_head(search_feat)
123
-
124
- # Decode to boxes
125
- boxes, scores = decode_predictions(
126
- preds['heatmap'],
127
- preds['size'],
128
- preds['offset'],
129
- search_size=self.config['search_size'],
130
- feat_size=self.config['feat_size'],
131
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  output = {
134
- 'heatmap': preds['heatmap'],
135
- 'size': preds['size'],
136
- 'offset': preds['offset'],
137
- 'boxes': boxes,
138
- 'scores': scores,
139
- 'template_feat': template_feat,
140
- 'search_feat': search_feat,
141
  }
142
 
143
- # Uncertainty prediction
144
- if self.uncertainty_head is not None:
145
- output['log_variance'] = self.uncertainty_head(search_feat)
146
 
147
  return output
148
 
 
102
  def forward(
103
  self,
104
  template: torch.Tensor,
105
+ searches: torch.Tensor,
106
  use_temporal: bool = False,
107
  ) -> dict:
108
  """
109
+ Process template + K search frames through the full tracker.
110
+
111
  Args:
112
  template: (B, 3, 128, 128) template image
113
+ searches: (B, K, 3, 256, 256) K consecutive search frames
114
+ OR (B, 3, 256, 256) single search frame (backward compat)
115
  use_temporal: whether to apply FiLM temporal modulation
116
  Returns:
117
+ dict with per-frame predictions:
118
+ heatmap: (B, K, 1, 16, 16) or (B, 1, 16, 16) if single
119
+ size: (B, K, 2, 16, 16) or (B, 2, 16, 16)
120
+ offset: (B, K, 2, 16, 16) or (B, 2, 16, 16)
121
+ boxes: (B, K, 4) or (B, 4)
122
+ scores: (B, K) or (B,)
123
+ template_feat: (B, 64, D)
124
+ search_feats: (B, K, 256, D) or (B, 256, D)
125
  """
126
+ single_frame = (searches.ndim == 4)
127
+
128
  temporal_mgr = self.temporal_mod if use_temporal else None
129
+ template_feat, search_feats = self.backbone(template, searches, temporal_mod_manager=temporal_mgr)
130
 
131
+ # search_feats: (B, K, 256, D) for multi-frame, (B, 256, D) for single
132
+ if single_frame:
133
+ # Single frame path — same as before
134
+ preds = self.center_head(search_feats)
135
+ boxes, scores = decode_predictions(
136
+ preds['heatmap'], preds['size'], preds['offset'],
137
+ search_size=self.config['search_size'],
138
+ feat_size=self.config['feat_size'],
139
+ )
140
+ output = {
141
+ 'heatmap': preds['heatmap'],
142
+ 'size': preds['size'],
143
+ 'offset': preds['offset'],
144
+ 'boxes': boxes,
145
+ 'scores': scores,
146
+ 'template_feat': template_feat,
147
+ 'search_feat': search_feats,
148
+ }
149
+ if self.uncertainty_head is not None:
150
+ output['log_variance'] = self.uncertainty_head(search_feats)
151
+ return output
152
+
153
+ # Multi-frame path: run head on each frame's search features
154
+ B, K = search_feats.shape[:2]
155
+
156
+ all_heatmaps, all_sizes, all_offsets = [], [], []
157
+ all_boxes, all_scores = [], []
158
+ all_log_var = []
159
+
160
+ for k in range(K):
161
+ s_feat_k = search_feats[:, k] # (B, 256, D)
162
+ preds_k = self.center_head(s_feat_k)
163
+ boxes_k, scores_k = decode_predictions(
164
+ preds_k['heatmap'], preds_k['size'], preds_k['offset'],
165
+ search_size=self.config['search_size'],
166
+ feat_size=self.config['feat_size'],
167
+ )
168
+ all_heatmaps.append(preds_k['heatmap'])
169
+ all_sizes.append(preds_k['size'])
170
+ all_offsets.append(preds_k['offset'])
171
+ all_boxes.append(boxes_k)
172
+ all_scores.append(scores_k)
173
+
174
+ if self.uncertainty_head is not None:
175
+ all_log_var.append(self.uncertainty_head(s_feat_k))
176
 
177
  output = {
178
+ 'heatmap': torch.stack(all_heatmaps, dim=1), # (B, K, 1, 16, 16)
179
+ 'size': torch.stack(all_sizes, dim=1), # (B, K, 2, 16, 16)
180
+ 'offset': torch.stack(all_offsets, dim=1), # (B, K, 2, 16, 16)
181
+ 'boxes': torch.stack(all_boxes, dim=1), # (B, K, 4)
182
+ 'scores': torch.stack(all_scores, dim=1), # (B, K)
183
+ 'template_feat': template_feat, # (B, 64, D)
184
+ 'search_feats': search_feats, # (B, K, 256, D)
185
  }
186
 
187
+ if self.uncertainty_head is not None and all_log_var:
188
+ output['log_variance'] = torch.stack(all_log_var, dim=1) # (B, K, 1, 16, 16)
 
189
 
190
  return output
191