omar-ah committed
Commit 1bf192e · verified · 1 Parent(s): 9bef6c8

Sequence training: pairs→K-frame clips, mLSTM memory carries across frames
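
Context: the dataloader now emits K-frame clips instead of (template, search) pairs, and the tracker's mLSTM memory is carried from frame to frame within each clip. A minimal sketch of the collated batch the new loop expects (shapes follow the annotations in the diff; the B=2, K=4 values and the 128-px template size are illustrative assumptions, not part of this commit):

    import torch

    # Hypothetical collated batch: B=2 clips, K=4 frames each.
    B, K = 2, 4
    batch = {
        'template':  torch.randn(B, 3, 128, 128),     # one template per clip (128 px assumed)
        'searches':  torch.randn(B, K, 3, 256, 256),  # K search frames
        'heatmaps':  torch.rand(B, K, 1, 16, 16),     # per-frame GT heatmaps
        'sizes':     torch.rand(B, K, 2),             # per-frame GT (w, h)
        'boxes':     torch.rand(B, K, 4),             # per-frame GT boxes
    }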

Files changed (1): vil_tracker/training/train.py (+72 −36)
vil_tracker/training/train.py CHANGED
@@ -120,26 +120,49 @@ def train_one_epoch(
 
     for batch_idx, batch in enumerate(dataloader):
         template = batch['template'].to(device)
-        search = batch['search'].to(device)
-        gt_heatmap = batch['heatmap'].to(device)
-        gt_size = batch['size'].to(device)
-        gt_boxes = batch['boxes'].to(device)
+        searches = batch['searches'].to(device)     # (B, K, 3, 256, 256)
+        gt_heatmaps = batch['heatmaps'].to(device)  # (B, K, 1, 16, 16)
+        gt_sizes = batch['sizes'].to(device)        # (B, K, 2)
+        gt_boxes = batch['boxes'].to(device)        # (B, K, 4)
+
+        B, K = searches.shape[:2]
 
         optimizer.zero_grad()
         if loss_optimizer is not None:
             loss_optimizer.zero_grad()
 
         with autocast(enabled=scaler is not None):
-            # Forward pass with optional temporal modulation
-            pred = model(template, search, use_temporal=use_temporal)
-            loss_dict = loss_fn(pred, gt_heatmap, gt_size, gt_boxes)
-            loss = loss_dict['total']
+            # Forward: template + K search frames as one sequence
+            pred = model(template, searches, use_temporal=use_temporal)
+
+            # Accumulate loss over K frames
+            loss = torch.tensor(0.0, device=device)
+            frame_heatmap = 0.0
+            frame_giou = 0.0
+            frame_size = 0.0
+
+            for k in range(K):
+                pred_k = {
+                    'heatmap': pred['heatmap'][:, k],  # (B, 1, 16, 16)
+                    'size': pred['size'][:, k],        # (B, 2, 16, 16)
+                    'boxes': pred['boxes'][:, k],      # (B, 4)
+                }
+                if 'log_variance' in pred:
+                    pred_k['log_variance'] = pred['log_variance'][:, k]
+
+                loss_dict_k = loss_fn(pred_k, gt_heatmaps[:, k],
+                                      gt_sizes[:, k], gt_boxes[:, k])
+                loss = loss + loss_dict_k['total']
+                frame_heatmap += loss_dict_k['heatmap'].item()
+                frame_giou += loss_dict_k['giou'].item()
+                frame_size += loss_dict_k['size'].item()
 
-            # Contrastive loss on template/search features (Phase 2)
-            if contrastive_loss is not None and 'template_feat' in pred and 'search_feat' in pred:
-                # Pool features to get sequence-level representations
-                t_pooled = pred['template_feat'].mean(dim=1)  # (B, D)
-                s_pooled = pred['search_feat'].mean(dim=1)  # (B, D)
+            loss = loss / K  # Average over frames
+
+            # Contrastive loss on template/search features
+            if contrastive_loss is not None and 'search_feats' in pred:
+                t_pooled = pred['template_feat'].mean(dim=1)        # (B, D)
+                s_pooled = pred['search_feats'][:, -1].mean(dim=1)  # (B, D), last frame
                 c_loss = contrastive_loss(t_pooled, s_pooled)
                 loss = loss + contrastive_weight * c_loss
                 total_contrastive_loss += c_loss.item()
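Note: the per-frame slicing above assumes the model's forward stacks its K per-frame outputs along dim 1. A rough sketch of that pattern with the mLSTM memory threaded across the clip (the module names self.backbone, self.mlstm, self.fuse, self.head and the (feat, state) return convention are assumptions about the model, not code in this commit):

    def forward_sequence(self, template, searches, use_temporal=True):
        # Sketch: process K search frames, carrying mLSTM memory across frames.
        B, K = searches.shape[:2]
        z = self.backbone(template)  # template encoded once per clip
        state = None                 # mLSTM memory; reset at clip start
        heatmaps, sizes, boxes = [], [], []
        for k in range(K):
            x = self.backbone(searches[:, k])
            if use_temporal:
                x, state = self.mlstm(x, state)  # memory carries to frame k+1
            out = self.head(self.fuse(z, x))
            heatmaps.append(out['heatmap'])
            sizes.append(out['size'])
            boxes.append(out['boxes'])
        # search_feats, log_variance, etc. omitted for brevity
        return {
            'heatmap': torch.stack(heatmaps, dim=1),  # (B, K, 1, 16, 16)
            'size':    torch.stack(sizes, dim=1),     # (B, K, 2, 16, 16)
            'boxes':   torch.stack(boxes, dim=1),     # (B, K, 4)
        }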
@@ -165,17 +188,17 @@ def train_one_epoch(
             loss_optimizer.step()
 
         total_loss += loss.item()
-        total_heatmap_loss += loss_dict['heatmap'].item()
-        total_giou_loss += loss_dict['giou'].item()
-        total_size_loss += loss_dict['size'].item()
+        total_heatmap_loss += frame_heatmap / K
+        total_giou_loss += frame_giou / K
+        total_size_loss += frame_size / K
        num_batches += 1
 
        if batch_idx % 100 == 0:
            msg = (f" Epoch {epoch}/{total_epochs} | Batch {batch_idx} | "
                   f"Loss: {loss.item():.4f} | "
-                  f"Heatmap: {loss_dict['heatmap']:.4f} | "
-                  f"GIoU: {loss_dict['giou']:.4f} | "
-                  f"Size: {loss_dict['size']:.4f}")
+                  f"Heatmap: {frame_heatmap/K:.4f} | "
+                  f"GIoU: {frame_giou/K:.4f} | "
+                  f"Size: {frame_size/K:.4f}")
            if contrastive_loss is not None and total_contrastive_loss > 0:
                msg += f" | Contr: {total_contrastive_loss / max(1, num_batches):.4f}"
            print(msg)
@@ -349,36 +372,51 @@ def train_phase2(
 
     for batch_idx, batch in enumerate(dataloader):
         template = batch['template'].to(device)
-        search = batch['search'].to(device)
-        gt_heatmap = batch['heatmap'].to(device)
-        gt_size = batch['size'].to(device)
+        searches = batch['searches'].to(device)
+        gt_heatmaps = batch['heatmaps'].to(device)
+        gt_sizes = batch['sizes'].to(device)
         gt_boxes = batch['boxes'].to(device)
 
+        B, K = searches.shape[:2]
+
         optimizer.zero_grad()
         if loss_optimizer is not None:
             loss_optimizer.zero_grad()
 
         with autocast(enabled=scaler is not None):
-            # Always use temporal modulation in Phase 2
-            pred = model(template, search, use_temporal=True)
-            loss_dict = loss_fn(pred, gt_heatmap, gt_size, gt_boxes)
-            loss = loss_dict['total']
+            pred = model(template, searches, use_temporal=True)
+
+            # Accumulate loss over K frames
+            loss = torch.tensor(0.0, device=device)
+            for k in range(K):
+                pred_k = {
+                    'heatmap': pred['heatmap'][:, k],
+                    'size': pred['size'][:, k],
+                    'boxes': pred['boxes'][:, k],
+                }
+                if 'log_variance' in pred:
+                    pred_k['log_variance'] = pred['log_variance'][:, k]
+                loss_dict_k = loss_fn(pred_k, gt_heatmaps[:, k],
+                                      gt_sizes[:, k], gt_boxes[:, k])
+                loss = loss + loss_dict_k['total']
+            loss = loss / K
 
-            # Contrastive loss on temporal features
+            # Contrastive loss
             t_pooled = pred['template_feat'].mean(dim=1)
-            s_pooled = pred['search_feat'].mean(dim=1)
+            s_pooled = pred['search_feats'][:, -1].mean(dim=1)
             c_loss = contrastive_loss(t_pooled, s_pooled)
             loss = loss + 0.1 * c_loss
 
-            # AFKD distillation loss (if teacher available)
+            # AFKD distillation (if teacher available)
             if distill_loss is not None and teacher_model is not None:
                 with torch.no_grad():
-                    teacher_pred = teacher_model(template, search)
+                    teacher_pred = teacher_model(template, searches)
+                # Distill on last frame features
                 d_loss = distill_loss(
-                    student_feat=pred['search_feat'],
-                    teacher_feat=teacher_pred['search_feat'],
-                    student_logits=pred['heatmap'],
-                    teacher_logits=teacher_pred['heatmap'],
+                    student_feat=pred['search_feats'][:, -1],
+                    teacher_feat=teacher_pred['search_feats'][:, -1] if 'search_feats' in teacher_pred else teacher_pred['search_feat'],
+                    student_logits=pred['heatmap'][:, -1],
+                    teacher_logits=teacher_pred['heatmap'][:, -1] if teacher_pred['heatmap'].ndim == 5 else teacher_pred['heatmap'],
                 )
                 loss = loss + 0.5 * d_loss
 
@@ -404,8 +442,6 @@ def train_phase2(
        if batch_idx % 100 == 0:
            msg = (f" Phase2 Epoch {epoch}/{num_epochs} | Batch {batch_idx} | "
                   f"Loss: {loss.item():.4f} | "
-                  f"Heatmap: {loss_dict['heatmap']:.4f} | "
-                  f"GIoU: {loss_dict['giou']:.4f} | "
                   f"Contr: {c_loss.item():.4f}")
            if distill_loss is not None:
                msg += f" | Distill: {d_loss.item():.4f}"
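Since the teacher checkpoint may predate this change (pair-style outputs with no K axis), the last-frame selection in the distillation branch above could live in a small helper; a sketch under that assumption (pick_last_frame is hypothetical, not in the repo, and the key names follow the diff):

    def pick_last_frame(pred, seq_key, pair_key):
        # Return the last-frame tensor from a sequence-style prediction dict
        # (seq_key present, K axis at dim 1) or a pair-style one (pair_key).
        if seq_key in pred:
            return pred[seq_key][:, -1]
        return pred[pair_key]

    # e.g. teacher_feat = pick_last_frame(teacher_pred, 'search_feats', 'search_feat')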
 
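A quick sanity check of the per-frame accumulation without a real dataloader (shapes follow the diff; the MSE stand-in for loss_fn and the dummy tensors are illustrative only):

    import torch
    import torch.nn.functional as F

    B, K = 2, 4
    pred_heatmap = torch.randn(B, K, 1, 16, 16, requires_grad=True)
    gt_heatmaps = torch.rand(B, K, 1, 16, 16)

    loss = torch.tensor(0.0)
    for k in range(K):
        # Toy per-frame loss standing in for loss_fn(...)['total']
        loss = loss + F.mse_loss(pred_heatmap[:, k], gt_heatmaps[:, k])
    loss = loss / K  # matches the per-frame averaging in the diff
    loss.backward()
    print(f"avg per-frame loss: {loss.item():.4f}")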