| """ |
| Training script for ViL Tracker. |
| |
| Two-phase training: |
| Phase 1: Standard supervised training on GOT-10k + LaSOT + TrackingNet |
| - Full model training with focal + GIoU + size losses |
| - ACL curriculum (progressive difficulty ramp-up on dataset AND loss weighting) |
| - FiLM temporal modulation trained with temporal pairs |
| - 300 epochs, lr=1e-4 with cosine decay, warmup=5 epochs |
| |
| Phase 2: Fine-tuning with TMoE and distillation |
| - Freeze shared experts in TMoE blocks |
| - Add contrastive loss on temporal features |
| - Optional AFKD distillation from MCITrack-B256 teacher |
| - FiLM temporal modulation active for all samples |
| - 100 epochs, lr=1e-5 |
| |
| Hardware: Designed for A10G (24GB) or A100 (80GB) |
| """ |
|
|
| import os |
| import json |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader |
| from torch.cuda.amp import autocast, GradScaler |
|
|
|
|
def build_optimizer(model, lr=1e-4, weight_decay=0.05, backbone_lr_scale=0.1):
    """Build an AdamW optimizer with component-wise learning-rate scaling.

    Parameters are routed into groups by substring match on their names:
      - backbone: lr * backbone_lr_scale (pretrained or dominant, train slower)
      - heads:    full lr (task-specific center/uncertainty heads)
      - temporal: lr * 0.5 (FiLM temporal modulation, moderate learning)
      - other:    lr * 0.5 (everything not matched above)

    NOTE: trainable ADW loss weights are intentionally NOT handled here —
    they receive their own optimizer via ``build_loss_optimizer``.

    Args:
        model: module whose ``named_parameters()`` are partitioned by name.
        lr: base learning rate.
        weight_decay: AdamW weight decay, applied to all groups.
        backbone_lr_scale: multiplier on ``lr`` for backbone parameters.

    Returns:
        ``torch.optim.AdamW`` with one named param group per non-empty component.
    """
    backbone_params = []
    head_params = []
    temporal_params = []
    other_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters stay out of the optimizer entirely
        if 'backbone' in name:
            backbone_params.append(param)
        elif 'center_head' in name or 'uncertainty_head' in name:
            head_params.append(param)
        elif 'temporal_mod' in name:
            temporal_params.append(param)
        else:
            other_params.append(param)

    param_groups = [
        {'params': backbone_params, 'lr': lr * backbone_lr_scale, 'name': 'backbone'},
        {'params': head_params, 'lr': lr, 'name': 'heads'},
        {'params': temporal_params, 'lr': lr * 0.5, 'name': 'temporal'},
        {'params': other_params, 'lr': lr * 0.5, 'name': 'other'},
    ]

    # Drop empty groups: they add noise to scheduler state and checkpoints.
    param_groups = [g for g in param_groups if len(g['params']) > 0]

    return optim.AdamW(param_groups, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))
|
|
|
|
def build_loss_optimizer(loss_fn, lr=1e-3):
    """Create an Adam optimizer over the trainable ADW loss weights.

    Returns None when the loss module exposes no trainable parameters,
    so callers can skip the extra optimizer step entirely.
    """
    trainable = [p for p in loss_fn.parameters() if p.requires_grad]
    if not trainable:
        return None
    return optim.Adam(trainable, lr=lr)
|
|
|
|
def build_scheduler(optimizer, total_epochs, warmup_epochs=5):
    """Cosine-annealing schedule preceded by a linear warmup.

    The warmup multiplier has a 0.01 floor so the very first epoch does
    not start at zero learning rate.
    """
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # Linear ramp toward 1.0 with a small floor at epoch 0.
            return max(0.01, epoch / warmup_epochs)
        frac = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * frac))

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
|
|
|
def train_one_epoch(
    model, dataloader, optimizer, loss_optimizer, scaler, loss_fn, device,
    epoch, total_epochs, acl_lambda=None, grad_clip=1.0,
    use_temporal=False, contrastive_loss=None, contrastive_weight=0.1,
):
    """Train for one epoch with AMP, gradient clipping, and optional temporal training.

    Args:
        model: ViLTracker instance
        dataloader: training data loader; each batch is a dict with keys
            'template', 'searches', 'heatmaps', 'sizes', 'boxes'
        optimizer: model optimizer
        loss_optimizer: separate optimizer for ADW loss weights (can be None)
        scaler: GradScaler for AMP (None if cpu)
        loss_fn: CombinedTrackingLoss instance
        device: 'cuda' or 'cpu'
        epoch: current epoch number
        total_epochs: total number of epochs
        acl_lambda: ACL difficulty weight for loss scaling (None disables it)
        grad_clip: max gradient norm
        use_temporal: whether to use FiLM temporal modulation
        contrastive_loss: optional MemoryContrastiveLoss for Phase 2
        contrastive_weight: weight for contrastive loss

    Returns:
        Dict of per-epoch average losses: 'total', 'heatmap', 'giou',
        'size', 'contrastive' (0 when no contrastive loss was accumulated).
    """
    model.train()
    # Epoch-level running sums; divided by num_batches at the end.
    total_loss = 0
    total_heatmap_loss = 0
    total_giou_loss = 0
    total_size_loss = 0
    total_contrastive_loss = 0
    num_batches = 0

    for batch_idx, batch in enumerate(dataloader):
        template = batch['template'].to(device)
        # searches carries K search frames per sample: assumed (B, K, ...) —
        # TODO confirm layout against the dataset.
        searches = batch['searches'].to(device)
        gt_heatmaps = batch['heatmaps'].to(device)
        gt_sizes = batch['sizes'].to(device)
        gt_boxes = batch['boxes'].to(device)

        B, K = searches.shape[:2]

        optimizer.zero_grad()
        if loss_optimizer is not None:
            loss_optimizer.zero_grad()

        # autocast is disabled on CPU runs (scaler is None there).
        with autocast(enabled=scaler is not None):

            pred = model(template, searches, use_temporal=use_temporal)

            loss = torch.tensor(0.0, device=device)
            # Per-batch frame-summed components, used for logging only.
            frame_heatmap = 0.0
            frame_giou = 0.0
            frame_size = 0.0

            # Supervised loss is computed per search frame k, then averaged over K.
            for k in range(K):
                pred_k = {
                    'heatmap': pred['heatmap'][:, k],
                    'size': pred['size'][:, k],
                    'boxes': pred['boxes'][:, k],
                }
                if 'log_variance' in pred:
                    # Uncertainty-head output is optional in the model's dict.
                    pred_k['log_variance'] = pred['log_variance'][:, k]

                loss_dict_k = loss_fn(pred_k, gt_heatmaps[:, k],
                                      gt_sizes[:, k], gt_boxes[:, k])
                loss = loss + loss_dict_k['total']
                frame_heatmap += loss_dict_k['heatmap'].item()
                frame_giou += loss_dict_k['giou'].item()
                frame_size += loss_dict_k['size'].item()

            loss = loss / K

            # Phase-2 contrastive term on mean-pooled template features vs.
            # the last search frame's mean-pooled features.
            if contrastive_loss is not None and 'search_feats' in pred:
                t_pooled = pred['template_feat'].mean(dim=1)
                s_pooled = pred['search_feats'][:, -1].mean(dim=1)
                c_loss = contrastive_loss(t_pooled, s_pooled)
                loss = loss + contrastive_weight * c_loss
                total_contrastive_loss += c_loss.item()

            # ACL curriculum: scale the whole batch loss by the difficulty weight.
            if acl_lambda is not None:
                loss = loss * acl_lambda

        if scaler is not None:
            # AMP path: unscale BEFORE clipping so the norm is measured on
            # true gradients; both optimizers step under the same scaler,
            # and update() is called exactly once per iteration.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            if loss_optimizer is not None:
                scaler.unscale_(loss_optimizer)
                scaler.step(loss_optimizer)
            scaler.update()
        else:
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            if loss_optimizer is not None:
                loss_optimizer.step()

        total_loss += loss.item()
        # Frame sums are normalized by K to get per-frame averages.
        total_heatmap_loss += frame_heatmap / K
        total_giou_loss += frame_giou / K
        total_size_loss += frame_size / K
        num_batches += 1

        # Lightweight progress log every 100 batches.
        if batch_idx % 100 == 0:
            msg = (f" Epoch {epoch}/{total_epochs} | Batch {batch_idx} | "
                   f"Loss: {loss.item():.4f} | "
                   f"Heatmap: {frame_heatmap/K:.4f} | "
                   f"GIoU: {frame_giou/K:.4f} | "
                   f"Size: {frame_size/K:.4f}")
            if contrastive_loss is not None and total_contrastive_loss > 0:
                msg += f" | Contr: {total_contrastive_loss / max(1, num_batches):.4f}"
            print(msg)

    n = max(num_batches, 1)  # guard against an empty dataloader
    return {
        'total': total_loss / n,
        'heatmap': total_heatmap_loss / n,
        'giou': total_giou_loss / n,
        'size': total_size_loss / n,
        'contrastive': total_contrastive_loss / n if total_contrastive_loss > 0 else 0,
    }
|
|
|
|
def train_phase1(
    model, train_dataset, config, device='cuda',
    num_epochs=300, lr=1e-4, batch_size=32, num_workers=4,
    save_dir='./checkpoints', push_to_hub=False, hub_model_id=None,
):
    """Phase 1: Standard supervised training with ACL curriculum.

    ACL Curriculum:
    - Epoch 0-50: difficulty ramps linearly from 0 to 1 (easy to hard samples)
    - Loss weighting: acl_lambda ramps from 0.5 to 1.0 over the same window
    - Dataset difficulty is forwarded via set_acl_difficulty() when exposed

    FiLM temporal modulation:
    - Disabled for the first 30 epochs (model needs basic features first),
      then enabled for every batch from epoch 30 onward.

    Args:
        model: ViLTracker instance to train.
        train_dataset: training dataset; may expose set_acl_difficulty(),
            either directly or on each of its .datasets sub-datasets.
        config: configuration dict, saved verbatim into every checkpoint.
        device: 'cuda' or 'cpu'.
        num_epochs, lr, batch_size, num_workers: standard training knobs.
        save_dir: checkpoint directory (created if missing).
        push_to_hub: if True, upload checkpoints to the HuggingFace Hub.
        hub_model_id: Hub repo id; required when push_to_hub is True.

    Returns:
        The trained model (same object, moved to `device`).
    """
    print(f"=== Phase 1 Training: {num_epochs} epochs ===")

    os.makedirs(save_dir, exist_ok=True)

    # Local import keeps the dependency out of module import time.
    from .losses import CombinedTrackingLoss
    loss_fn = CombinedTrackingLoss(use_uncertainty=True, use_adw=True).to(device)

    model = model.to(device)
    optimizer = build_optimizer(model, lr=lr)
    loss_optimizer = build_loss_optimizer(loss_fn)  # None if ADW weights are frozen
    scheduler = build_scheduler(optimizer, num_epochs)
    # NOTE(review): an exact string compare — a 'cuda:0' device string would
    # silently skip AMP; confirm callers only pass 'cuda' or 'cpu'.
    scaler = GradScaler() if device == 'cuda' else None

    dataloader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True, drop_last=True,
    )

    best_loss = float('inf')

    for epoch in range(num_epochs):
        # Curriculum: linear ramp over the first 50 epochs, flat at 1.0 after.
        acl_progress = min(1.0, (epoch + 1) / 50)
        acl_lambda = 0.5 + 0.5 * acl_progress  # loss weight: 0.5 -> 1.0

        # Propagate curriculum difficulty to the dataset(s), if supported.
        if hasattr(train_dataset, 'set_acl_difficulty'):
            train_dataset.set_acl_difficulty(acl_progress)
        elif hasattr(train_dataset, 'datasets'):
            # ConcatDataset-style wrapper: forward to each sub-dataset.
            for ds in train_dataset.datasets:
                if hasattr(ds, 'set_acl_difficulty'):
                    ds.set_acl_difficulty(acl_progress)

        # FiLM temporal modulation kicks in once basic features are learned.
        use_temporal = epoch >= 30

        loss_metrics = train_one_epoch(
            model, dataloader, optimizer, loss_optimizer, scaler, loss_fn,
            device, epoch, num_epochs, acl_lambda=acl_lambda,
            use_temporal=use_temporal,
        )

        scheduler.step()

        # Presumably clears FiLM temporal state carried across batches —
        # confirm against the ViLTracker implementation.
        model.reset_temporal()

        print(f"Epoch {epoch}/{num_epochs} | "
              f"Loss: {loss_metrics['total']:.4f} | "
              f"Heatmap: {loss_metrics['heatmap']:.4f} | "
              f"GIoU: {loss_metrics['giou']:.4f} | "
              f"Size: {loss_metrics['size']:.4f} | "
              f"LR: {scheduler.get_last_lr()[0]:.6f} | "
              f"ACL: {acl_progress:.2f} | "
              f"Temporal: {'ON' if use_temporal else 'OFF'}")

        # Best checkpoint is selected on TRAINING loss (no validation split).
        if loss_metrics['total'] < best_loss:
            best_loss = loss_metrics['total']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss,
                'config': config,
            }, os.path.join(save_dir, 'best_phase1.pth'))

        # Periodic snapshot every 50 epochs.
        if (epoch + 1) % 50 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss_metrics['total'],
                'config': config,
            }, os.path.join(save_dir, f'phase1_epoch{epoch+1}.pth'))

    if push_to_hub and hub_model_id:
        _push_checkpoint_to_hub(model, save_dir, hub_model_id, 'phase1')

    return model
|
|
|
|
def train_phase2(
    model, train_dataset, config, device='cuda',
    num_epochs=100, lr=1e-5, batch_size=32, num_workers=4,
    save_dir='./checkpoints', push_to_hub=False, hub_model_id=None,
    teacher_model=None,
):
    """Phase 2: Fine-tuning with frozen shared experts, contrastive loss, and distillation.

    Changes from Phase 1:
    1. Shared experts in TMoE blocks are frozen
    2. Contrastive loss on template/search features (temporal consistency)
    3. FiLM temporal modulation always active
    4. Optional AFKD knowledge distillation from teacher model
    5. Lower learning rate, especially for backbone (scaled by 0.01)

    Args:
        model: ViLTracker instance (typically loaded from a Phase 1 checkpoint).
        train_dataset: training dataset.
        config: configuration dict; config['dim'] feeds the distillation head
            and the dict is saved verbatim into every checkpoint.
        device: 'cuda' or 'cpu'.
        num_epochs, lr, batch_size, num_workers: standard training knobs.
        save_dir: checkpoint directory (created if missing).
        push_to_hub: if True, upload checkpoints to the HuggingFace Hub.
        hub_model_id: Hub repo id; required when push_to_hub is True.
        teacher_model: optional frozen teacher for AFKD distillation.

    Returns:
        The fine-tuned model (same object, moved to `device`).
    """
    print(f"=== Phase 2 Training: {num_epochs} epochs ===")

    # Phase 2 can be launched standalone, so ensure the checkpoint directory
    # exists here too (previously only Phase 1 created it, and the first
    # torch.save would fail when Phase 2 ran on its own).
    os.makedirs(save_dir, exist_ok=True)

    # Freeze shared experts BEFORE building the optimizer so the frozen
    # parameters are excluded from the param groups.
    model.freeze_backbone_shared_experts()
    frozen_count = sum(1 for p in model.parameters() if not p.requires_grad)
    total_count = sum(1 for p in model.parameters())
    print(f" Frozen parameters: {frozen_count}/{total_count}")

    from .losses import CombinedTrackingLoss, MemoryContrastiveLoss, AFKDDistillationLoss
    loss_fn = CombinedTrackingLoss(use_uncertainty=True, use_adw=True).to(device)
    contrastive_loss = MemoryContrastiveLoss(temperature=0.1).to(device)

    # Optional AFKD distillation: the teacher runs frozen, in eval mode.
    distill_loss = None
    if teacher_model is not None:
        teacher_model = teacher_model.to(device)
        teacher_model.eval()
        for p in teacher_model.parameters():
            p.requires_grad = False
        distill_loss = AFKDDistillationLoss(
            student_dim=config['dim'], teacher_dim=768, temperature=4.0
        ).to(device)
        print(" AFKD distillation enabled (teacher → student)")

    model = model.to(device)
    optimizer = build_optimizer(model, lr=lr, backbone_lr_scale=0.01)
    loss_optimizer = build_loss_optimizer(loss_fn)
    scheduler = build_scheduler(optimizer, num_epochs, warmup_epochs=2)
    scaler = GradScaler() if device == 'cuda' else None

    dataloader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True, drop_last=True,
    )

    best_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        num_batches = 0

        for batch_idx, batch in enumerate(dataloader):
            template = batch['template'].to(device)
            searches = batch['searches'].to(device)
            gt_heatmaps = batch['heatmaps'].to(device)
            gt_sizes = batch['sizes'].to(device)
            gt_boxes = batch['boxes'].to(device)

            B, K = searches.shape[:2]

            optimizer.zero_grad()
            if loss_optimizer is not None:
                loss_optimizer.zero_grad()

            with autocast(enabled=scaler is not None):
                # FiLM temporal modulation is always active in Phase 2.
                pred = model(template, searches, use_temporal=True)

                # Per-frame supervised loss, averaged over the K search frames.
                loss = torch.tensor(0.0, device=device)
                for k in range(K):
                    pred_k = {
                        'heatmap': pred['heatmap'][:, k],
                        'size': pred['size'][:, k],
                        'boxes': pred['boxes'][:, k],
                    }
                    if 'log_variance' in pred:
                        pred_k['log_variance'] = pred['log_variance'][:, k]
                    loss_dict_k = loss_fn(pred_k, gt_heatmaps[:, k],
                                          gt_sizes[:, k], gt_boxes[:, k])
                    loss = loss + loss_dict_k['total']
                loss = loss / K

                # Contrastive consistency: pooled template features vs. the
                # last search frame's pooled features.
                t_pooled = pred['template_feat'].mean(dim=1)
                s_pooled = pred['search_feats'][:, -1].mean(dim=1)
                c_loss = contrastive_loss(t_pooled, s_pooled)
                loss = loss + 0.1 * c_loss

                # AFKD distillation from the frozen teacher.
                if distill_loss is not None and teacher_model is not None:
                    with torch.no_grad():
                        teacher_pred = teacher_model(template, searches)

                    # Teacher outputs may or may not carry a per-frame axis;
                    # fall back to the flat keys when they do not.
                    d_loss = distill_loss(
                        student_feat=pred['search_feats'][:, -1],
                        teacher_feat=teacher_pred['search_feats'][:, -1] if teacher_pred['search_feats'].ndim == 4 else teacher_pred['search_feat'],
                        student_logits=pred['heatmap'][:, -1],
                        teacher_logits=teacher_pred['heatmap'][:, -1] if teacher_pred['heatmap'].ndim == 5 else teacher_pred['heatmap'],
                    )
                    loss = loss + 0.5 * d_loss

            if scaler is not None:
                # Unscale before clipping so the norm is measured on true grads.
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                # BUG FIX: clip_grad_norm_'s second parameter is named
                # `max_norm`; the previous `grad_clip=1.0` keyword raised
                # TypeError on the first AMP training step.
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                if loss_optimizer is not None:
                    scaler.unscale_(loss_optimizer)
                    scaler.step(loss_optimizer)
                scaler.update()
            else:
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                if loss_optimizer is not None:
                    loss_optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            # Lightweight progress log every 100 batches.
            if batch_idx % 100 == 0:
                msg = (f" Phase2 Epoch {epoch}/{num_epochs} | Batch {batch_idx} | "
                       f"Loss: {loss.item():.4f} | "
                       f"Contr: {c_loss.item():.4f}")
                if distill_loss is not None:
                    msg += f" | Distill: {d_loss.item():.4f}"
                print(msg)

        scheduler.step()
        # Presumably clears FiLM temporal state between epochs — confirm
        # against the ViLTracker implementation.
        model.reset_temporal()

        avg_loss = total_loss / max(num_batches, 1)
        print(f"Phase2 Epoch {epoch}/{num_epochs} | Avg Loss: {avg_loss:.4f} | "
              f"LR: {scheduler.get_last_lr()[0]:.6f}")

        # Best checkpoint is selected on TRAINING loss (no validation split).
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'loss': best_loss,
                'config': config,
            }, os.path.join(save_dir, 'best_phase2.pth'))

        # Periodic snapshot every 25 epochs.
        if (epoch + 1) % 25 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'loss': avg_loss,
                'config': config,
            }, os.path.join(save_dir, f'phase2_epoch{epoch+1}.pth'))

    if push_to_hub and hub_model_id:
        _push_checkpoint_to_hub(model, save_dir, hub_model_id, 'phase2')

    return model
|
|
|
|
| def _push_checkpoint_to_hub(model, save_dir, hub_model_id, phase): |
| """Push checkpoint to HuggingFace Hub.""" |
| try: |
| from huggingface_hub import HfApi |
| api = HfApi() |
| api.upload_folder( |
| folder_path=save_dir, |
| repo_id=hub_model_id, |
| path_in_repo=f'checkpoints/{phase}', |
| ) |
| print(f"Pushed {phase} checkpoint to {hub_model_id}") |
| except Exception as e: |
| print(f"Warning: Could not push to hub: {e}") |
|
|