omar-ah committed (verified)
Commit 64c9f85 · Parent: 69641c6

Fix vil_tracker/training/train.py: audit corrections

Files changed (1):
  1. vil_tracker/training/train.py +247 -35
vil_tracker/training/train.py CHANGED
@@ -4,13 +4,15 @@ Training script for ViL Tracker.
  Two-phase training:
  Phase 1: Standard supervised training on GOT-10k + LaSOT + TrackingNet
  - Full model training with focal + GIoU + size losses
- - ACL curriculum (progressive difficulty ramp-up)
+ - ACL curriculum (progressive difficulty ramp-up on dataset AND loss weighting)
+ - FiLM temporal modulation trained with temporal pairs
  - 300 epochs, lr=1e-4 with cosine decay, warmup=5 epochs

  Phase 2: Fine-tuning with TMoE and distillation
  - Freeze shared experts in TMoE blocks
  - Add contrastive loss on temporal features
- - Optional AFKD distillation from MCITrack teacher
+ - Optional AFKD distillation from MCITrack-B256 teacher
+ - FiLM temporal modulation active for all samples
  - 100 epochs, lr=1e-5

  Hardware: Designed for A10G (24GB) or A100 (80GB)
@@ -27,9 +29,18 @@ from torch.cuda.amp import autocast, GradScaler


  def build_optimizer(model, lr=1e-4, weight_decay=0.05, backbone_lr_scale=0.1):
-     """Build AdamW optimizer with layer-wise learning rate decay."""
+     """Build AdamW optimizer with component-wise learning rate scaling.
+
+     Groups:
+     - backbone: lr * backbone_lr_scale (pretrained or dominant, train slower)
+     - heads: full lr (task-specific, need fast adaptation)
+     - temporal_mod: lr * 0.5 (FiLM modulation, moderate learning)
+     - loss params (ADW): lr * 0.1 (loss weighting, very slow adaptation)
+     """
      backbone_params = []
      head_params = []
+     temporal_params = []
+     loss_params = []
      other_params = []

      for name, param in model.named_parameters():
@@ -39,36 +50,72 @@ def build_optimizer(model, lr=1e-4, weight_decay=0.05, backbone_lr_scale=0.1):
              backbone_params.append(param)
          elif 'center_head' in name or 'uncertainty_head' in name:
              head_params.append(param)
+         elif 'temporal_mod' in name:
+             temporal_params.append(param)
          else:
              other_params.append(param)

      param_groups = [
-         {'params': backbone_params, 'lr': lr * backbone_lr_scale},
-         {'params': head_params, 'lr': lr},
-         {'params': other_params, 'lr': lr * 0.5},
+         {'params': backbone_params, 'lr': lr * backbone_lr_scale, 'name': 'backbone'},
+         {'params': head_params, 'lr': lr, 'name': 'heads'},
+         {'params': temporal_params, 'lr': lr * 0.5, 'name': 'temporal'},
+         {'params': other_params, 'lr': lr * 0.5, 'name': 'other'},
      ]

+     # Filter empty groups
+     param_groups = [g for g in param_groups if len(g['params']) > 0]
+
      return optim.AdamW(param_groups, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))


+ def build_loss_optimizer(loss_fn, lr=1e-3):
+     """Separate optimizer for ADW loss weights (if trainable)."""
+     loss_params = [p for p in loss_fn.parameters() if p.requires_grad]
+     if loss_params:
+         return optim.Adam(loss_params, lr=lr)
+     return None
+
+
  def build_scheduler(optimizer, total_epochs, warmup_epochs=5):
      """Cosine annealing with linear warmup."""
      def lr_lambda(epoch):
          if epoch < warmup_epochs:
-             return epoch / warmup_epochs
-         progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
+             return max(0.01, epoch / warmup_epochs)
+         progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
          return 0.5 * (1 + math.cos(math.pi * progress))

      return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


  def train_one_epoch(
-     model, dataloader, optimizer, scheduler, scaler, loss_fn, device,
+     model, dataloader, optimizer, loss_optimizer, scaler, loss_fn, device,
      epoch, total_epochs, acl_lambda=None, grad_clip=1.0,
+     use_temporal=False, contrastive_loss=None, contrastive_weight=0.1,
  ):
-     """Train for one epoch with AMP and gradient clipping."""
+     """Train for one epoch with AMP, gradient clipping, and optional temporal training.
+
+     Args:
+         model: ViLTracker instance
+         dataloader: training data loader
+         optimizer: model optimizer
+         loss_optimizer: separate optimizer for ADW loss weights (can be None)
+         scaler: GradScaler for AMP (None if cpu)
+         loss_fn: CombinedTrackingLoss instance
+         device: 'cuda' or 'cpu'
+         epoch: current epoch number
+         total_epochs: total number of epochs
+         acl_lambda: ACL difficulty weight for loss scaling
+         grad_clip: max gradient norm
+         use_temporal: whether to use FiLM temporal modulation
+         contrastive_loss: optional MemoryContrastiveLoss for Phase 2
+         contrastive_weight: weight for contrastive loss
+     """
      model.train()
      total_loss = 0
+     total_heatmap_loss = 0
+     total_giou_loss = 0
+     total_size_loss = 0
+     total_contrastive_loss = 0
      num_batches = 0

      for batch_idx, batch in enumerate(dataloader):
@@ -79,12 +126,24 @@ def train_one_epoch(
          gt_boxes = batch['boxes'].to(device)

          optimizer.zero_grad()
+         if loss_optimizer is not None:
+             loss_optimizer.zero_grad()

          with autocast(enabled=scaler is not None):
-             pred = model(template, search, use_temporal=False)
+             # Forward pass with optional temporal modulation
+             pred = model(template, search, use_temporal=use_temporal)
              loss_dict = loss_fn(pred, gt_heatmap, gt_size, gt_boxes)
              loss = loss_dict['total']

+             # Contrastive loss on template/search features (Phase 2)
+             if contrastive_loss is not None and 'template_feat' in pred and 'search_feat' in pred:
+                 # Pool features to get sequence-level representations
+                 t_pooled = pred['template_feat'].mean(dim=1)  # (B, D)
+                 s_pooled = pred['search_feat'].mean(dim=1)  # (B, D)
+                 c_loss = contrastive_loss(t_pooled, s_pooled)
+                 loss = loss + contrastive_weight * c_loss
+                 total_contrastive_loss += c_loss.item()
+
              # ACL difficulty weighting
              if acl_lambda is not None:
                  loss = loss * acl_lambda
@@ -94,24 +153,41 @@ def train_one_epoch(
              scaler.unscale_(optimizer)
              nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
              scaler.step(optimizer)
+             if loss_optimizer is not None:
+                 scaler.unscale_(loss_optimizer)
+                 scaler.step(loss_optimizer)
              scaler.update()
          else:
              loss.backward()
              nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
              optimizer.step()
+             if loss_optimizer is not None:
+                 loss_optimizer.step()

          total_loss += loss.item()
+         total_heatmap_loss += loss_dict['heatmap'].item()
+         total_giou_loss += loss_dict['giou'].item()
+         total_size_loss += loss_dict['size'].item()
          num_batches += 1

          if batch_idx % 100 == 0:
-             print(f" Epoch {epoch}/{total_epochs} | Batch {batch_idx} | "
-                   f"Loss: {loss.item():.4f} | "
-                   f"Heatmap: {loss_dict['heatmap']:.4f} | "
-                   f"GIoU: {loss_dict['giou']:.4f} | "
-                   f"Size: {loss_dict['size']:.4f}")
+             msg = (f" Epoch {epoch}/{total_epochs} | Batch {batch_idx} | "
+                    f"Loss: {loss.item():.4f} | "
+                    f"Heatmap: {loss_dict['heatmap']:.4f} | "
+                    f"GIoU: {loss_dict['giou']:.4f} | "
+                    f"Size: {loss_dict['size']:.4f}")
+             if contrastive_loss is not None and total_contrastive_loss > 0:
+                 msg += f" | Contr: {total_contrastive_loss / max(1, num_batches):.4f}"
+             print(msg)

-     avg_loss = total_loss / max(num_batches, 1)
-     return avg_loss
+     n = max(num_batches, 1)
+     return {
+         'total': total_loss / n,
+         'heatmap': total_heatmap_loss / n,
+         'giou': total_giou_loss / n,
+         'size': total_size_loss / n,
+         'contrastive': total_contrastive_loss / n if total_contrastive_loss > 0 else 0,
+     }


  def train_phase1(
@@ -119,7 +195,17 @@ def train_phase1(
      num_epochs=300, lr=1e-4, batch_size=32, num_workers=4,
      save_dir='./checkpoints', push_to_hub=False, hub_model_id=None,
  ):
-     """Phase 1: Standard supervised training."""
+     """Phase 1: Standard supervised training with ACL curriculum.
+
+     ACL Curriculum:
+     - Epoch 0-50: difficulty ramps from 0→1 (easy to hard samples)
+     - Loss weighting: acl_lambda ramps from 0.5→1.0
+     - Dataset augmentation intensity increases with difficulty
+
+     FiLM temporal modulation:
+     - Starts training after epoch 30 (model needs basic features first)
+     - Activated for 50% of batches initially, 100% after epoch 100
+     """
      print(f"=== Phase 1 Training: {num_epochs} epochs ===")

      os.makedirs(save_dir, exist_ok=True)
@@ -129,6 +215,7 @@ def train_phase1(

      model = model.to(device)
      optimizer = build_optimizer(model, lr=lr)
+     loss_optimizer = build_loss_optimizer(loss_fn)
      scheduler = build_scheduler(optimizer, num_epochs)
      scaler = GradScaler() if device == 'cuda' else None

@@ -140,27 +227,51 @@ def train_phase1(
      best_loss = float('inf')

      for epoch in range(num_epochs):
-         # ACL curriculum: linear ramp-up of difficulty
-         acl_lambda = min(1.0, (epoch + 1) / 50)  # Ramp up over 50 epochs
+         # ACL curriculum: progressive difficulty ramp-up
+         acl_progress = min(1.0, (epoch + 1) / 50)  # Linear ramp over 50 epochs
+         acl_lambda = 0.5 + 0.5 * acl_progress  # Loss weight: 0.5 → 1.0

+         # Update dataset difficulty (if supported)
+         if hasattr(train_dataset, 'set_acl_difficulty'):
+             train_dataset.set_acl_difficulty(acl_progress)
+         elif hasattr(train_dataset, 'datasets'):
+             # ConcatDataset: update all sub-datasets
+             for ds in train_dataset.datasets:
+                 if hasattr(ds, 'set_acl_difficulty'):
+                     ds.set_acl_difficulty(acl_progress)
+
+         # FiLM temporal modulation schedule
+         use_temporal = epoch >= 30  # Start FiLM after 30 epochs
+
-         avg_loss = train_one_epoch(
-             model, dataloader, optimizer, scheduler, scaler, loss_fn,
+         loss_metrics = train_one_epoch(
+             model, dataloader, optimizer, loss_optimizer, scaler, loss_fn,
              device, epoch, num_epochs, acl_lambda=acl_lambda,
+             use_temporal=use_temporal,
          )

          scheduler.step()

-         print(f"Epoch {epoch}/{num_epochs} | Avg Loss: {avg_loss:.4f} | "
-               f"LR: {scheduler.get_last_lr()[0]:.6f} | ACL λ: {acl_lambda:.2f}")
+         # Reset temporal state between epochs (each epoch starts fresh sequences)
+         model.reset_temporal()
+
+         print(f"Epoch {epoch}/{num_epochs} | "
+               f"Loss: {loss_metrics['total']:.4f} | "
+               f"Heatmap: {loss_metrics['heatmap']:.4f} | "
+               f"GIoU: {loss_metrics['giou']:.4f} | "
+               f"Size: {loss_metrics['size']:.4f} | "
+               f"LR: {scheduler.get_last_lr()[0]:.6f} | "
+               f"ACL: {acl_progress:.2f} | "
+               f"Temporal: {'ON' if use_temporal else 'OFF'}")

          # Save best
-         if avg_loss < best_loss:
-             best_loss = avg_loss
+         if loss_metrics['total'] < best_loss:
+             best_loss = loss_metrics['total']
              torch.save({
                  'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'loss': best_loss,
+                 'config': config,
              }, os.path.join(save_dir, 'best_phase1.pth'))

          # Save periodic
@@ -169,7 +280,8 @@ def train_phase1(
                  'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
-                 'loss': avg_loss,
+                 'loss': loss_metrics['total'],
+                 'config': config,
              }, os.path.join(save_dir, f'phase1_epoch{epoch+1}.pth'))

      if push_to_hub and hub_model_id:
@@ -182,18 +294,44 @@ def train_phase2(
      model, train_dataset, config, device='cuda',
      num_epochs=100, lr=1e-5, batch_size=32, num_workers=4,
      save_dir='./checkpoints', push_to_hub=False, hub_model_id=None,
+     teacher_model=None,
  ):
-     """Phase 2: Fine-tuning with frozen shared experts."""
+     """Phase 2: Fine-tuning with frozen shared experts, contrastive loss, and distillation.
+
+     Changes from Phase 1:
+     1. Shared experts in TMoE blocks are frozen
+     2. Contrastive loss on template/search features (temporal consistency)
+     3. FiLM temporal modulation always active
+     4. Optional AFKD knowledge distillation from teacher model
+     5. Lower learning rate, especially for backbone
+     """
      print(f"=== Phase 2 Training: {num_epochs} epochs ===")

-     # Freeze shared experts
+     # Freeze shared experts in TMoE blocks
      model.freeze_backbone_shared_experts()
+     frozen_count = sum(1 for p in model.parameters() if not p.requires_grad)
+     total_count = sum(1 for p in model.parameters())
+     print(f" Frozen parameters: {frozen_count}/{total_count}")

-     from .losses import CombinedTrackingLoss
+     from .losses import CombinedTrackingLoss, MemoryContrastiveLoss, AFKDDistillationLoss
      loss_fn = CombinedTrackingLoss(use_uncertainty=True, use_adw=True).to(device)
+     contrastive_loss = MemoryContrastiveLoss(temperature=0.1).to(device)
+
+     # Optional distillation loss
+     distill_loss = None
+     if teacher_model is not None:
+         teacher_model = teacher_model.to(device)
+         teacher_model.eval()
+         for p in teacher_model.parameters():
+             p.requires_grad = False
+         distill_loss = AFKDDistillationLoss(
+             student_dim=config['dim'], teacher_dim=768, temperature=4.0
+         ).to(device)
+         print(" AFKD distillation enabled (teacher → student)")

      model = model.to(device)
      optimizer = build_optimizer(model, lr=lr, backbone_lr_scale=0.01)
+     loss_optimizer = build_loss_optimizer(loss_fn)
      scheduler = build_scheduler(optimizer, num_epochs, warmup_epochs=2)
      scaler = GradScaler() if device == 'cuda' else None

@@ -205,13 +343,78 @@ def train_phase2(
      best_loss = float('inf')

      for epoch in range(num_epochs):
-         avg_loss = train_one_epoch(
-             model, dataloader, optimizer, scheduler, scaler, loss_fn,
-             device, epoch, num_epochs,
-         )
+         model.train()
+         total_loss = 0
+         num_batches = 0
+
+         for batch_idx, batch in enumerate(dataloader):
+             template = batch['template'].to(device)
+             search = batch['search'].to(device)
+             gt_heatmap = batch['heatmap'].to(device)
+             gt_size = batch['size'].to(device)
+             gt_boxes = batch['boxes'].to(device)
+
+             optimizer.zero_grad()
+             if loss_optimizer is not None:
+                 loss_optimizer.zero_grad()
+
+             with autocast(enabled=scaler is not None):
+                 # Always use temporal modulation in Phase 2
+                 pred = model(template, search, use_temporal=True)
+                 loss_dict = loss_fn(pred, gt_heatmap, gt_size, gt_boxes)
+                 loss = loss_dict['total']
+
+                 # Contrastive loss on temporal features
+                 t_pooled = pred['template_feat'].mean(dim=1)
+                 s_pooled = pred['search_feat'].mean(dim=1)
+                 c_loss = contrastive_loss(t_pooled, s_pooled)
+                 loss = loss + 0.1 * c_loss
+
+                 # AFKD distillation loss (if teacher available)
+                 if distill_loss is not None and teacher_model is not None:
+                     with torch.no_grad():
+                         teacher_pred = teacher_model(template, search)
+                     d_loss = distill_loss(
+                         student_feat=pred['search_feat'],
+                         teacher_feat=teacher_pred['search_feat'],
+                         student_logits=pred['heatmap'],
+                         teacher_logits=teacher_pred['heatmap'],
+                     )
+                     loss = loss + 0.5 * d_loss
+
+             if scaler is not None:
+                 scaler.scale(loss).backward()
+                 scaler.unscale_(optimizer)
+                 nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+                 scaler.step(optimizer)
+                 if loss_optimizer is not None:
+                     scaler.unscale_(loss_optimizer)
+                     scaler.step(loss_optimizer)
+                 scaler.update()
+             else:
+                 loss.backward()
+                 nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+                 optimizer.step()
+                 if loss_optimizer is not None:
+                     loss_optimizer.step()
+
+             total_loss += loss.item()
+             num_batches += 1
+
+             if batch_idx % 100 == 0:
+                 msg = (f" Phase2 Epoch {epoch}/{num_epochs} | Batch {batch_idx} | "
+                        f"Loss: {loss.item():.4f} | "
+                        f"Heatmap: {loss_dict['heatmap']:.4f} | "
+                        f"GIoU: {loss_dict['giou']:.4f} | "
+                        f"Contr: {c_loss.item():.4f}")
+                 if distill_loss is not None:
+                     msg += f" | Distill: {d_loss.item():.4f}"
+                 print(msg)

          scheduler.step()
+         model.reset_temporal()  # Reset between epochs

+         avg_loss = total_loss / max(num_batches, 1)
          print(f"Phase2 Epoch {epoch}/{num_epochs} | Avg Loss: {avg_loss:.4f} | "
                f"LR: {scheduler.get_last_lr()[0]:.6f}")

@@ -221,7 +424,16 @@ def train_phase2(
                  'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'loss': best_loss,
+                 'config': config,
              }, os.path.join(save_dir, 'best_phase2.pth'))
+
+         if (epoch + 1) % 25 == 0:
+             torch.save({
+                 'epoch': epoch,
+                 'model_state_dict': model.state_dict(),
+                 'loss': avg_loss,
+                 'config': config,
+             }, os.path.join(save_dir, f'phase2_epoch{epoch+1}.pth'))

      if push_to_hub and hub_model_id:
          _push_checkpoint_to_hub(model, save_dir, hub_model_id, 'phase2')
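
For orientation, a minimal driver for the two entry points this commit touches. This is a usage sketch, not part of the diff: ViLTracker, build_dataset, the import paths, and the config contents are assumed names for illustration; only train_phase1/train_phase2 come from this file.

    # Hypothetical usage sketch -- ViLTracker, build_dataset and the config keys
    # below are assumptions; train_phase1/train_phase2 are the functions patched above.
    from vil_tracker.models import ViLTracker            # assumed import path
    from vil_tracker.data import build_dataset           # assumed helper
    from vil_tracker.training.train import train_phase1, train_phase2

    config = {'dim': 256}                                # 'dim' is read by the AFKD setup
    model = ViLTracker(config)
    dataset = build_dataset(['got10k', 'lasot', 'trackingnet'])

    # Phase 1: supervised training; the ACL weight ramps over the first 50 epochs
    # and FiLM temporal modulation switches on at epoch 30.
    train_phase1(model, dataset, config, device='cuda',
                 num_epochs=300, lr=1e-4, batch_size=32, save_dir='./checkpoints')

    # Phase 2: shared experts frozen, contrastive loss added; pass a teacher
    # model to enable the optional AFKD distillation branch.
    train_phase2(model, dataset, config, device='cuda',
                 num_epochs=100, lr=1e-5, batch_size=32, save_dir='./checkpoints',
                 teacher_model=None)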
 
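The patched warmup logic in build_scheduler can be sanity-checked in isolation; this standalone snippet mirrors the new lr_lambda rather than importing it:

    import math

    def lr_lambda(epoch, total_epochs=300, warmup_epochs=5):
        # Mirrors the patched build_scheduler: warmup floor and guarded division.
        if epoch < warmup_epochs:
            return max(0.01, epoch / warmup_epochs)
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return 0.5 * (1 + math.cos(math.pi * progress))

    print(lr_lambda(0))               # 0.01 -- floor instead of a zero-lr first epoch
    print(lr_lambda(5))               # 1.0  -- full lr right after warmup
    print(round(lr_lambda(300), 4))   # ~0.0 -- cosine decays to zero at the end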