""" Training script for ViL Tracker. Two-phase training: Phase 1: Standard supervised training on GOT-10k + LaSOT + TrackingNet - Full model training with focal + GIoU + size losses - ACL curriculum (progressive difficulty ramp-up on dataset AND loss weighting) - FiLM temporal modulation trained with temporal pairs - 300 epochs, lr=1e-4 with cosine decay, warmup=5 epochs Phase 2: Fine-tuning with TMoE and distillation - Freeze shared experts in TMoE blocks - Add contrastive loss on temporal features - Optional AFKD distillation from MCITrack-B256 teacher - FiLM temporal modulation active for all samples - 100 epochs, lr=1e-5 Hardware: Designed for A10G (24GB) or A100 (80GB) """ import os import json import math import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torch.cuda.amp import autocast, GradScaler def build_optimizer(model, lr=1e-4, weight_decay=0.05, backbone_lr_scale=0.1): """Build AdamW optimizer with component-wise learning rate scaling. Groups: - backbone: lr * backbone_lr_scale (pretrained or dominant, train slower) - heads: full lr (task-specific, need fast adaptation) - temporal_mod: lr * 0.5 (FiLM modulation, moderate learning) - loss params (ADW): lr * 0.1 (loss weighting, very slow adaptation) """ backbone_params = [] head_params = [] temporal_params = [] loss_params = [] other_params = [] for name, param in model.named_parameters(): if not param.requires_grad: continue if 'backbone' in name: backbone_params.append(param) elif 'center_head' in name or 'uncertainty_head' in name: head_params.append(param) elif 'temporal_mod' in name: temporal_params.append(param) else: other_params.append(param) param_groups = [ {'params': backbone_params, 'lr': lr * backbone_lr_scale, 'name': 'backbone'}, {'params': head_params, 'lr': lr, 'name': 'heads'}, {'params': temporal_params, 'lr': lr * 0.5, 'name': 'temporal'}, {'params': other_params, 'lr': lr * 0.5, 'name': 'other'}, ] # Filter empty groups param_groups = [g for g in param_groups if len(g['params']) > 0] return optim.AdamW(param_groups, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999)) def build_loss_optimizer(loss_fn, lr=1e-3): """Separate optimizer for ADW loss weights (if trainable).""" loss_params = [p for p in loss_fn.parameters() if p.requires_grad] if loss_params: return optim.Adam(loss_params, lr=lr) return None def build_scheduler(optimizer, total_epochs, warmup_epochs=5): """Cosine annealing with linear warmup.""" def lr_lambda(epoch): if epoch < warmup_epochs: return max(0.01, epoch / warmup_epochs) progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs) return 0.5 * (1 + math.cos(math.pi * progress)) return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) def train_one_epoch( model, dataloader, optimizer, loss_optimizer, scaler, loss_fn, device, epoch, total_epochs, acl_lambda=None, grad_clip=1.0, use_temporal=False, contrastive_loss=None, contrastive_weight=0.1, ): """Train for one epoch with AMP, gradient clipping, and optional temporal training. Args: model: ViLTracker instance dataloader: training data loader optimizer: model optimizer loss_optimizer: separate optimizer for ADW loss weights (can be None) scaler: GradScaler for AMP (None if cpu) loss_fn: CombinedTrackingLoss instance device: 'cuda' or 'cpu' epoch: current epoch number total_epochs: total number of epochs acl_lambda: ACL difficulty weight for loss scaling grad_clip: max gradient norm use_temporal: whether to use FiLM temporal modulation contrastive_loss: optional MemoryContrastiveLoss for Phase 2 contrastive_weight: weight for contrastive loss """ model.train() total_loss = 0 total_heatmap_loss = 0 total_giou_loss = 0 total_size_loss = 0 total_contrastive_loss = 0 num_batches = 0 for batch_idx, batch in enumerate(dataloader): template = batch['template'].to(device) searches = batch['searches'].to(device) # (B, K, 3, 256, 256) gt_heatmaps = batch['heatmaps'].to(device) # (B, K, 1, 16, 16) gt_sizes = batch['sizes'].to(device) # (B, K, 2) gt_boxes = batch['boxes'].to(device) # (B, K, 4) B, K = searches.shape[:2] optimizer.zero_grad() if loss_optimizer is not None: loss_optimizer.zero_grad() with autocast(enabled=scaler is not None): # Forward: template + K search frames as one sequence pred = model(template, searches, use_temporal=use_temporal) # Accumulate loss over K frames loss = torch.tensor(0.0, device=device) frame_heatmap = 0.0 frame_giou = 0.0 frame_size = 0.0 for k in range(K): pred_k = { 'heatmap': pred['heatmap'][:, k], # (B, 1, 16, 16) 'size': pred['size'][:, k], # (B, 2, 16, 16) 'boxes': pred['boxes'][:, k], # (B, 4) } if 'log_variance' in pred: pred_k['log_variance'] = pred['log_variance'][:, k] loss_dict_k = loss_fn(pred_k, gt_heatmaps[:, k], gt_sizes[:, k], gt_boxes[:, k]) loss = loss + loss_dict_k['total'] frame_heatmap += loss_dict_k['heatmap'].item() frame_giou += loss_dict_k['giou'].item() frame_size += loss_dict_k['size'].item() loss = loss / K # Average over frames # Contrastive loss on template/search features if contrastive_loss is not None and 'search_feats' in pred: t_pooled = pred['template_feat'].mean(dim=1) # (B, D) s_pooled = pred['search_feats'][:, -1].mean(dim=1) # (B, D) last frame c_loss = contrastive_loss(t_pooled, s_pooled) loss = loss + contrastive_weight * c_loss total_contrastive_loss += c_loss.item() # ACL difficulty weighting if acl_lambda is not None: loss = loss * acl_lambda if scaler is not None: scaler.scale(loss).backward() scaler.unscale_(optimizer) nn.utils.clip_grad_norm_(model.parameters(), grad_clip) scaler.step(optimizer) if loss_optimizer is not None: scaler.unscale_(loss_optimizer) scaler.step(loss_optimizer) scaler.update() else: loss.backward() nn.utils.clip_grad_norm_(model.parameters(), grad_clip) optimizer.step() if loss_optimizer is not None: loss_optimizer.step() total_loss += loss.item() total_heatmap_loss += frame_heatmap / K total_giou_loss += frame_giou / K total_size_loss += frame_size / K num_batches += 1 if batch_idx % 100 == 0: msg = (f" Epoch {epoch}/{total_epochs} | Batch {batch_idx} | " f"Loss: {loss.item():.4f} | " f"Heatmap: {frame_heatmap/K:.4f} | " f"GIoU: {frame_giou/K:.4f} | " f"Size: {frame_size/K:.4f}") if contrastive_loss is not None and total_contrastive_loss > 0: msg += f" | Contr: {total_contrastive_loss / max(1, num_batches):.4f}" print(msg) n = max(num_batches, 1) return { 'total': total_loss / n, 'heatmap': total_heatmap_loss / n, 'giou': total_giou_loss / n, 'size': total_size_loss / n, 'contrastive': total_contrastive_loss / n if total_contrastive_loss > 0 else 0, } def train_phase1( model, train_dataset, config, device='cuda', num_epochs=300, lr=1e-4, batch_size=32, num_workers=4, save_dir='./checkpoints', push_to_hub=False, hub_model_id=None, ): """Phase 1: Standard supervised training with ACL curriculum. ACL Curriculum: - Epoch 0-50: difficulty ramps from 0→1 (easy to hard samples) - Loss weighting: acl_lambda ramps from 0.5→1.0 - Dataset augmentation intensity increases with difficulty FiLM temporal modulation: - Starts training after epoch 30 (model needs basic features first) - Activated for 50% of batches initially, 100% after epoch 100 """ print(f"=== Phase 1 Training: {num_epochs} epochs ===") os.makedirs(save_dir, exist_ok=True) from .losses import CombinedTrackingLoss loss_fn = CombinedTrackingLoss(use_uncertainty=True, use_adw=True).to(device) model = model.to(device) optimizer = build_optimizer(model, lr=lr) loss_optimizer = build_loss_optimizer(loss_fn) scheduler = build_scheduler(optimizer, num_epochs) scaler = GradScaler() if device == 'cuda' else None dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True, ) best_loss = float('inf') for epoch in range(num_epochs): # ACL curriculum: progressive difficulty ramp-up acl_progress = min(1.0, (epoch + 1) / 50) # Linear ramp over 50 epochs acl_lambda = 0.5 + 0.5 * acl_progress # Loss weight: 0.5 → 1.0 # Update dataset difficulty (if supported) if hasattr(train_dataset, 'set_acl_difficulty'): train_dataset.set_acl_difficulty(acl_progress) elif hasattr(train_dataset, 'datasets'): # ConcatDataset: update all sub-datasets for ds in train_dataset.datasets: if hasattr(ds, 'set_acl_difficulty'): ds.set_acl_difficulty(acl_progress) # FiLM temporal modulation schedule use_temporal = epoch >= 30 # Start FiLM after 30 epochs loss_metrics = train_one_epoch( model, dataloader, optimizer, loss_optimizer, scaler, loss_fn, device, epoch, num_epochs, acl_lambda=acl_lambda, use_temporal=use_temporal, ) scheduler.step() # Reset temporal state between epochs (each epoch starts fresh sequences) model.reset_temporal() print(f"Epoch {epoch}/{num_epochs} | " f"Loss: {loss_metrics['total']:.4f} | " f"Heatmap: {loss_metrics['heatmap']:.4f} | " f"GIoU: {loss_metrics['giou']:.4f} | " f"Size: {loss_metrics['size']:.4f} | " f"LR: {scheduler.get_last_lr()[0]:.6f} | " f"ACL: {acl_progress:.2f} | " f"Temporal: {'ON' if use_temporal else 'OFF'}") # Save best if loss_metrics['total'] < best_loss: best_loss = loss_metrics['total'] torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': best_loss, 'config': config, }, os.path.join(save_dir, 'best_phase1.pth')) # Save periodic if (epoch + 1) % 50 == 0: torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss_metrics['total'], 'config': config, }, os.path.join(save_dir, f'phase1_epoch{epoch+1}.pth')) if push_to_hub and hub_model_id: _push_checkpoint_to_hub(model, save_dir, hub_model_id, 'phase1') return model def train_phase2( model, train_dataset, config, device='cuda', num_epochs=100, lr=1e-5, batch_size=32, num_workers=4, save_dir='./checkpoints', push_to_hub=False, hub_model_id=None, teacher_model=None, ): """Phase 2: Fine-tuning with frozen shared experts, contrastive loss, and distillation. Changes from Phase 1: 1. Shared experts in TMoE blocks are frozen 2. Contrastive loss on template/search features (temporal consistency) 3. FiLM temporal modulation always active 4. Optional AFKD knowledge distillation from teacher model 5. Lower learning rate, especially for backbone """ print(f"=== Phase 2 Training: {num_epochs} epochs ===") # Freeze shared experts in TMoE blocks model.freeze_backbone_shared_experts() frozen_count = sum(1 for p in model.parameters() if not p.requires_grad) total_count = sum(1 for p in model.parameters()) print(f" Frozen parameters: {frozen_count}/{total_count}") from .losses import CombinedTrackingLoss, MemoryContrastiveLoss, AFKDDistillationLoss loss_fn = CombinedTrackingLoss(use_uncertainty=True, use_adw=True).to(device) contrastive_loss = MemoryContrastiveLoss(temperature=0.1).to(device) # Optional distillation loss distill_loss = None if teacher_model is not None: teacher_model = teacher_model.to(device) teacher_model.eval() for p in teacher_model.parameters(): p.requires_grad = False distill_loss = AFKDDistillationLoss( student_dim=config['dim'], teacher_dim=768, temperature=4.0 ).to(device) print(" AFKD distillation enabled (teacher → student)") model = model.to(device) optimizer = build_optimizer(model, lr=lr, backbone_lr_scale=0.01) loss_optimizer = build_loss_optimizer(loss_fn) scheduler = build_scheduler(optimizer, num_epochs, warmup_epochs=2) scaler = GradScaler() if device == 'cuda' else None dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True, ) best_loss = float('inf') for epoch in range(num_epochs): model.train() total_loss = 0 num_batches = 0 for batch_idx, batch in enumerate(dataloader): template = batch['template'].to(device) searches = batch['searches'].to(device) gt_heatmaps = batch['heatmaps'].to(device) gt_sizes = batch['sizes'].to(device) gt_boxes = batch['boxes'].to(device) B, K = searches.shape[:2] optimizer.zero_grad() if loss_optimizer is not None: loss_optimizer.zero_grad() with autocast(enabled=scaler is not None): pred = model(template, searches, use_temporal=True) # Accumulate loss over K frames loss = torch.tensor(0.0, device=device) for k in range(K): pred_k = { 'heatmap': pred['heatmap'][:, k], 'size': pred['size'][:, k], 'boxes': pred['boxes'][:, k], } if 'log_variance' in pred: pred_k['log_variance'] = pred['log_variance'][:, k] loss_dict_k = loss_fn(pred_k, gt_heatmaps[:, k], gt_sizes[:, k], gt_boxes[:, k]) loss = loss + loss_dict_k['total'] loss = loss / K # Contrastive loss t_pooled = pred['template_feat'].mean(dim=1) s_pooled = pred['search_feats'][:, -1].mean(dim=1) c_loss = contrastive_loss(t_pooled, s_pooled) loss = loss + 0.1 * c_loss # AFKD distillation (if teacher available) if distill_loss is not None and teacher_model is not None: with torch.no_grad(): teacher_pred = teacher_model(template, searches) # Distill on last frame features d_loss = distill_loss( student_feat=pred['search_feats'][:, -1], teacher_feat=teacher_pred['search_feats'][:, -1] if teacher_pred['search_feats'].ndim == 4 else teacher_pred['search_feat'], student_logits=pred['heatmap'][:, -1], teacher_logits=teacher_pred['heatmap'][:, -1] if teacher_pred['heatmap'].ndim == 5 else teacher_pred['heatmap'], ) loss = loss + 0.5 * d_loss if scaler is not None: scaler.scale(loss).backward() scaler.unscale_(optimizer) nn.utils.clip_grad_norm_(model.parameters(), grad_clip=1.0) scaler.step(optimizer) if loss_optimizer is not None: scaler.unscale_(loss_optimizer) scaler.step(loss_optimizer) scaler.update() else: loss.backward() nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() if loss_optimizer is not None: loss_optimizer.step() total_loss += loss.item() num_batches += 1 if batch_idx % 100 == 0: msg = (f" Phase2 Epoch {epoch}/{num_epochs} | Batch {batch_idx} | " f"Loss: {loss.item():.4f} | " f"Contr: {c_loss.item():.4f}") if distill_loss is not None: msg += f" | Distill: {d_loss.item():.4f}" print(msg) scheduler.step() model.reset_temporal() # Reset between epochs avg_loss = total_loss / max(num_batches, 1) print(f"Phase2 Epoch {epoch}/{num_epochs} | Avg Loss: {avg_loss:.4f} | " f"LR: {scheduler.get_last_lr()[0]:.6f}") if avg_loss < best_loss: best_loss = avg_loss torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'loss': best_loss, 'config': config, }, os.path.join(save_dir, 'best_phase2.pth')) if (epoch + 1) % 25 == 0: torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'loss': avg_loss, 'config': config, }, os.path.join(save_dir, f'phase2_epoch{epoch+1}.pth')) if push_to_hub and hub_model_id: _push_checkpoint_to_hub(model, save_dir, hub_model_id, 'phase2') return model def _push_checkpoint_to_hub(model, save_dir, hub_model_id, phase): """Push checkpoint to HuggingFace Hub.""" try: from huggingface_hub import HfApi api = HfApi() api.upload_folder( folder_path=save_dir, repo_id=hub_model_id, path_in_repo=f'checkpoints/{phase}', ) print(f"Pushed {phase} checkpoint to {hub_model_id}") except Exception as e: print(f"Warning: Could not push to hub: {e}")