# Author: omar-ah
# Commit message: Sequence training: pairs→K-frame clips, mLSTM memory carries across frames
# Commit: 1bf192e (verified)
"""
Training script for ViL Tracker.
Two-phase training:
Phase 1: Standard supervised training on GOT-10k + LaSOT + TrackingNet
- Full model training with focal + GIoU + size losses
- ACL curriculum (progressive difficulty ramp-up on dataset AND loss weighting)
- FiLM temporal modulation trained on K-frame search clips (enabled from epoch 30)
- 300 epochs, lr=1e-4 with cosine decay, warmup=5 epochs
Phase 2: Fine-tuning with TMoE and distillation
- Freeze shared experts in TMoE blocks
- Add contrastive loss on temporal features
- Optional AFKD distillation from MCITrack-B256 teacher
- FiLM temporal modulation active for all samples
- 100 epochs, lr=1e-5
Hardware: Designed for A10G (24GB) or A100 (80GB)
"""
import os
import json
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
def build_optimizer(model, lr=1e-4, weight_decay=0.05, backbone_lr_scale=0.1):
    """Build an AdamW optimizer with component-wise learning-rate scaling.

    Only parameters with ``requires_grad=True`` are included. Each one is
    routed to a group by substring match on its qualified name, checked in
    this order:

    - ``backbone``: name contains ``'backbone'`` -> ``lr * backbone_lr_scale``
    - ``heads``: name contains ``'center_head'`` or ``'uncertainty_head'`` -> ``lr``
    - ``temporal``: name contains ``'temporal_mod'`` (FiLM modulation) -> ``lr * 0.5``
    - ``other``: everything else -> ``lr * 0.5``

    Loss-weighting (ADW) parameters belong to the loss module, not the model,
    and are optimized separately by ``build_loss_optimizer``.

    Empty groups are dropped. The group order, and therefore which group
    ``scheduler.get_last_lr()[0]`` reports, depends on which groups are
    populated.

    Args:
        model: module whose trainable parameters are optimized.
        lr: base learning rate.
        weight_decay: AdamW decoupled weight decay, applied to every group.
        backbone_lr_scale: multiplier on ``lr`` for backbone parameters.

    Returns:
        torch.optim.AdamW over the non-empty groups. Each group dict carries a
        ``'name'`` key.
    """
    # (group name, lr multiplier) in the order the groups are emitted.
    group_specs = (
        ('backbone', backbone_lr_scale),
        ('heads', 1.0),
        ('temporal', 0.5),
        ('other', 0.5),
    )
    grouped = {name: [] for name, _ in group_specs}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen (e.g. shared TMoE experts in Phase 2)
        if 'backbone' in name:
            grouped['backbone'].append(param)
        elif 'center_head' in name or 'uncertainty_head' in name:
            grouped['heads'].append(param)
        elif 'temporal_mod' in name:
            grouped['temporal'].append(param)
        else:
            grouped['other'].append(param)

    # AdamW rejects empty groups, so drop them.
    param_groups = [
        {'params': grouped[group], 'lr': lr * scale, 'name': group}
        for group, scale in group_specs
        if grouped[group]
    ]
    return optim.AdamW(param_groups, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))
def build_loss_optimizer(loss_fn, lr=1e-3):
    """Create a dedicated Adam optimizer for the loss module's trainable weights.

    Args:
        loss_fn: loss module that may own learnable parameters (e.g. ADW weights).
        lr: Adam learning rate.

    Returns:
        torch.optim.Adam over the parameters of ``loss_fn`` that require grad,
        or None when there are none, so callers can skip the extra step.
    """
    trainable = list(filter(lambda p: p.requires_grad, loss_fn.parameters()))
    return optim.Adam(trainable, lr=lr) if trainable else None
def build_scheduler(optimizer, total_epochs, warmup_epochs=5):
    """Return a LambdaLR with linear warmup followed by cosine annealing to zero.

    The multiplier is evaluated once per ``scheduler.step()`` call (once per
    epoch in this script):

    - step < warmup_epochs: ``step / warmup_epochs``, floored at 0.01
    - afterwards: ``0.5 * (1 + cos(pi * progress))``, where ``progress`` runs
      from 0 at the end of warmup to 1 at ``total_epochs``
    """
    decay_span = max(1, total_epochs - warmup_epochs)

    def multiplier(step):
        if step >= warmup_epochs:
            frac = (step - warmup_epochs) / decay_span
            return 0.5 * (1 + math.cos(math.pi * frac))
        return max(0.01, step / warmup_epochs)

    return optim.lr_scheduler.LambdaLR(optimizer, multiplier)
def train_one_epoch(
    model, dataloader, optimizer, loss_optimizer, scaler, loss_fn, device,
    epoch, total_epochs, acl_lambda=None, grad_clip=1.0,
    use_temporal=False, contrastive_loss=None, contrastive_weight=0.1,
):
    """Train for one epoch on K-frame clips with AMP, gradient clipping, and optional temporal training.

    Each batch supplies one template and K search frames per sample. The model
    processes the whole clip in a single forward call. The tracking loss is
    computed frame by frame and averaged over K.

    Args:
        model: ViLTracker instance
        dataloader: training data loader yielding dicts with keys 'template',
            'searches', 'heatmaps', 'sizes', 'boxes'
        optimizer: model optimizer
        loss_optimizer: separate optimizer for ADW loss weights (can be None)
        scaler: GradScaler for AMP. None disables autocast and uses plain
            fp32 backward/step.
        loss_fn: CombinedTrackingLoss instance. It is called once per frame and
            must return a dict with tensor entries 'total', 'heatmap', 'giou',
            'size'.
        device: 'cuda' or 'cpu'
        epoch: current epoch number (used only in log messages)
        total_epochs: total number of epochs (used only in log messages)
        acl_lambda: ACL difficulty weight; multiplies the whole batch loss
        grad_clip: max gradient norm, applied to model parameters only
        use_temporal: forwarded to the model to enable FiLM temporal modulation
        contrastive_loss: optional MemoryContrastiveLoss for Phase 2
        contrastive_weight: weight for contrastive loss

    Returns:
        dict of per-batch averages over the epoch:
        - 'total': the optimized loss, i.e. after adding the weighted
          contrastive term and multiplying by acl_lambda.
        - 'heatmap', 'giou', 'size': unscaled, frame-averaged components.
        - 'contrastive': the raw (unweighted) contrastive term, or 0 when its
          accumulated sum is not positive.
    """
    model.train()
    total_loss = 0
    total_heatmap_loss = 0
    total_giou_loss = 0
    total_size_loss = 0
    total_contrastive_loss = 0
    num_batches = 0
    # NOTE(review): callers reset temporal state only between epochs, but
    # batches here are shuffled, independent clips. If the model keeps temporal
    # state across forward calls, state from one clip leaks into the next.
    # Confirm whether a per-batch model.reset_temporal() is needed.
    for batch_idx, batch in enumerate(dataloader):
        template = batch['template'].to(device)
        searches = batch['searches'].to(device)  # (B, K, 3, 256, 256)
        gt_heatmaps = batch['heatmaps'].to(device)  # (B, K, 1, 16, 16)
        gt_sizes = batch['sizes'].to(device)  # (B, K, 2)
        gt_boxes = batch['boxes'].to(device)  # (B, K, 4)
        B, K = searches.shape[:2]
        optimizer.zero_grad()
        if loss_optimizer is not None:
            loss_optimizer.zero_grad()
        with autocast(enabled=scaler is not None):
            # Forward: template + K search frames as one sequence
            pred = model(template, searches, use_temporal=use_temporal)
            # Accumulate loss over K frames
            loss = torch.tensor(0.0, device=device)
            # Python-float running sums, used only for logging/metrics.
            frame_heatmap = 0.0
            frame_giou = 0.0
            frame_size = 0.0
            for k in range(K):
                pred_k = {
                    'heatmap': pred['heatmap'][:, k],  # (B, 1, 16, 16)
                    'size': pred['size'][:, k],  # (B, 2, 16, 16)
                    'boxes': pred['boxes'][:, k],  # (B, 4)
                }
                if 'log_variance' in pred:
                    pred_k['log_variance'] = pred['log_variance'][:, k]
                loss_dict_k = loss_fn(pred_k, gt_heatmaps[:, k],
                                      gt_sizes[:, k], gt_boxes[:, k])
                loss = loss + loss_dict_k['total']
                # .item() forces a device sync per frame; acceptable for logging.
                frame_heatmap += loss_dict_k['heatmap'].item()
                frame_giou += loss_dict_k['giou'].item()
                frame_size += loss_dict_k['size'].item()
            loss = loss / K  # Average over frames
            # Contrastive loss on template/search features
            if contrastive_loss is not None and 'search_feats' in pred:
                t_pooled = pred['template_feat'].mean(dim=1)  # (B, D)
                s_pooled = pred['search_feats'][:, -1].mean(dim=1)  # (B, D) last frame
                c_loss = contrastive_loss(t_pooled, s_pooled)
                loss = loss + contrastive_weight * c_loss
                total_contrastive_loss += c_loss.item()
            # ACL difficulty weighting
            # NOTE(review): with Adam-family optimizers a constant loss
            # multiplier largely cancels in the update. The main observable
            # effect is on gradient clipping and on the reported 'total'.
            if acl_lambda is not None:
                loss = loss * acl_lambda
        if scaler is not None:
            scaler.scale(loss).backward()
            # Unscale before clipping so grad_clip applies to true gradient norms.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            if loss_optimizer is not None:
                scaler.unscale_(loss_optimizer)
                scaler.step(loss_optimizer)
            # Single update after stepping every optimizer that used this scaler.
            scaler.update()
        else:
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            if loss_optimizer is not None:
                loss_optimizer.step()
        total_loss += loss.item()
        total_heatmap_loss += frame_heatmap / K
        total_giou_loss += frame_giou / K
        total_size_loss += frame_size / K
        num_batches += 1
        if batch_idx % 100 == 0:
            # Per-batch values, except 'Contr', which is the running epoch average.
            msg = (f"  Epoch {epoch}/{total_epochs} | Batch {batch_idx} | "
                   f"Loss: {loss.item():.4f} | "
                   f"Heatmap: {frame_heatmap/K:.4f} | "
                   f"GIoU: {frame_giou/K:.4f} | "
                   f"Size: {frame_size/K:.4f}")
            if contrastive_loss is not None and total_contrastive_loss > 0:
                msg += f" | Contr: {total_contrastive_loss / max(1, num_batches):.4f}"
            print(msg)
    n = max(num_batches, 1)
    return {
        'total': total_loss / n,
        'heatmap': total_heatmap_loss / n,
        'giou': total_giou_loss / n,
        'size': total_size_loss / n,
        'contrastive': total_contrastive_loss / n if total_contrastive_loss > 0 else 0,
    }
def train_phase1(
    model, train_dataset, config, device='cuda',
    num_epochs=300, lr=1e-4, batch_size=32, num_workers=4,
    save_dir='./checkpoints', push_to_hub=False, hub_model_id=None,
):
    """Phase 1: supervised training on K-frame clips with an ACL curriculum.

    ACL curriculum:
        - ``acl_progress = min(1, (epoch + 1) / 50)`` ramps linearly over the
          first 50 epochs. It is passed to ``set_acl_difficulty`` on the
          dataset, or on each sub-dataset of a wrapper exposing ``.datasets``
          (e.g. ConcatDataset).
        - The training loss is multiplied by
          ``acl_lambda = 0.5 + 0.5 * acl_progress`` (about 0.5 -> 1.0).

    FiLM temporal modulation:
        - Off for epochs 0-29, on for every batch from epoch 30 onwards.

    Checkpoints (written to ``save_dir``):
        - ``best_phase1.pth`` whenever the ACL-unscaled epoch loss improves.
        - ``phase1_epoch{N}.pth`` every 50 epochs.
        Each stores the model, optimizer, scheduler and loss-module state (the
        ADW loss weights are trained, so they are needed to resume), plus the
        unscaled loss, the epoch's acl_lambda and ``config``.

    Args:
        model: ViLTracker instance (moved to ``device``).
        train_dataset: dataset yielding K-frame clip dicts.
        config: model config dict stored in each checkpoint.
        device: torch device or string. AMP is enabled for any CUDA device.
        num_epochs, lr, batch_size, num_workers: training hyperparameters.
        save_dir: checkpoint directory (created if missing).
        push_to_hub, hub_model_id: optionally upload checkpoints when done.

    Returns:
        The trained model.
    """
    print(f"=== Phase 1 Training: {num_epochs} epochs ===")
    os.makedirs(save_dir, exist_ok=True)
    from .losses import CombinedTrackingLoss
    loss_fn = CombinedTrackingLoss(use_uncertainty=True, use_adw=True).to(device)
    model = model.to(device)
    optimizer = build_optimizer(model, lr=lr)
    loss_optimizer = build_loss_optimizer(loss_fn)
    scheduler = build_scheduler(optimizer, num_epochs)
    # Compare the device *type*: 'cuda:0' or torch.device('cuda', 1) must also get AMP.
    use_cuda = torch.device(device).type == 'cuda'
    scaler = GradScaler() if use_cuda else None
    dataloader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=use_cuda, drop_last=True,
    )

    def save_checkpoint(filename, ckpt_epoch, loss_value, ckpt_acl_lambda):
        """Write a resumable Phase 1 checkpoint to ``save_dir/filename``."""
        torch.save({
            'epoch': ckpt_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss_fn_state_dict': loss_fn.state_dict(),
            'loss': loss_value,
            'acl_lambda': ckpt_acl_lambda,
            'config': config,
        }, os.path.join(save_dir, filename))

    best_loss = float('inf')
    for epoch in range(num_epochs):
        # ACL curriculum: progressive difficulty ramp-up
        acl_progress = min(1.0, (epoch + 1) / 50)  # Linear ramp over 50 epochs
        acl_lambda = 0.5 + 0.5 * acl_progress  # Loss weight: 0.5 → 1.0
        # Update dataset difficulty (if supported)
        if hasattr(train_dataset, 'set_acl_difficulty'):
            train_dataset.set_acl_difficulty(acl_progress)
        elif hasattr(train_dataset, 'datasets'):
            # ConcatDataset: update all sub-datasets
            for ds in train_dataset.datasets:
                if hasattr(ds, 'set_acl_difficulty'):
                    ds.set_acl_difficulty(acl_progress)
        # FiLM temporal modulation schedule
        use_temporal = epoch >= 30  # Start FiLM after 30 epochs
        loss_metrics = train_one_epoch(
            model, dataloader, optimizer, loss_optimizer, scaler, loss_fn,
            device, epoch, num_epochs, acl_lambda=acl_lambda,
            use_temporal=use_temporal,
        )
        scheduler.step()
        # Reset temporal state between epochs (each epoch starts fresh sequences)
        model.reset_temporal()
        print(f"Epoch {epoch}/{num_epochs} | "
              f"Loss: {loss_metrics['total']:.4f} | "
              f"Heatmap: {loss_metrics['heatmap']:.4f} | "
              f"GIoU: {loss_metrics['giou']:.4f} | "
              f"Size: {loss_metrics['size']:.4f} | "
              f"LR: {scheduler.get_last_lr()[0]:.6f} | "
              f"ACL: {acl_progress:.2f} | "
              f"Temporal: {'ON' if use_temporal else 'OFF'}")
        # train_one_epoch multiplies every batch loss by acl_lambda, which ramps
        # from ~0.5 to 1.0. Undo that so epochs are compared on the same scale.
        # Otherwise early, down-weighted epochs look artificially "best".
        epoch_loss = loss_metrics['total'] / acl_lambda
        # Save best
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            save_checkpoint('best_phase1.pth', epoch, best_loss, acl_lambda)
        # Save periodic
        if (epoch + 1) % 50 == 0:
            save_checkpoint(f'phase1_epoch{epoch+1}.pth', epoch, epoch_loss, acl_lambda)
    if push_to_hub and hub_model_id:
        _push_checkpoint_to_hub(model, save_dir, hub_model_id, 'phase1')
    return model
def train_phase2(
    model, train_dataset, config, device='cuda',
    num_epochs=100, lr=1e-5, batch_size=32, num_workers=4,
    save_dir='./checkpoints', push_to_hub=False, hub_model_id=None,
    teacher_model=None,
):
    """Phase 2: fine-tuning with frozen shared experts, contrastive loss, and distillation.

    Changes from Phase 1:
        1. Shared experts in TMoE blocks are frozen
           (``model.freeze_backbone_shared_experts()``).
        2. Contrastive loss (weight 0.1) between the pooled template features
           and the pooled last-frame search features.
        3. FiLM temporal modulation always active.
        4. Optional AFKD knowledge distillation (weight 0.5) from
           ``teacher_model`` on the last frame.
        5. Lower learning rate, with backbone lr scaled by 0.01 and 2 warmup
           epochs.

    Trainable parameters owned by the auxiliary losses (contrastive /
    distillation modules, if they have any) are added to the main optimizer
    as an extra group at ``lr``. Otherwise they would never be updated.

    Checkpoints (written to ``save_dir``, created if missing):
        - ``best_phase2.pth`` whenever the epoch-average loss improves.
        - ``phase2_epoch{N}.pth`` every 25 epochs.

    Args:
        model: ViLTracker instance, typically Phase 1 trained.
        train_dataset: dataset yielding K-frame clip dicts.
        config: model config dict. ``config['dim']`` is the student feature
            width used by the distillation loss.
        device: torch device or string. AMP is enabled for any CUDA device.
        teacher_model: optional frozen teacher. Its output must provide
            'heatmap' and 'search_feats' or 'search_feat'.

    Returns:
        The fine-tuned model.
    """
    print(f"=== Phase 2 Training: {num_epochs} epochs ===")
    os.makedirs(save_dir, exist_ok=True)
    # Freeze shared experts in TMoE blocks
    model.freeze_backbone_shared_experts()
    frozen_count = sum(1 for p in model.parameters() if not p.requires_grad)
    total_count = sum(1 for p in model.parameters())
    print(f"  Frozen parameters: {frozen_count}/{total_count}")
    from .losses import CombinedTrackingLoss, MemoryContrastiveLoss, AFKDDistillationLoss
    loss_fn = CombinedTrackingLoss(use_uncertainty=True, use_adw=True).to(device)
    contrastive_loss = MemoryContrastiveLoss(temperature=0.1).to(device)
    # Optional distillation loss
    distill_loss = None
    if teacher_model is not None:
        teacher_model = teacher_model.to(device)
        teacher_model.eval()
        for p in teacher_model.parameters():
            p.requires_grad = False
        distill_loss = AFKDDistillationLoss(
            student_dim=config['dim'], teacher_dim=768, temperature=4.0
        ).to(device)
        print("  AFKD distillation enabled (teacher → student)")
    model = model.to(device)
    optimizer = build_optimizer(model, lr=lr, backbone_lr_scale=0.01)
    # Auxiliary losses may own learnable weights (e.g. a student->teacher
    # projection). Register them BEFORE building the scheduler, because
    # LambdaLR only schedules the groups that exist when it is constructed.
    aux_params = [
        p
        for module in (contrastive_loss, distill_loss) if module is not None
        for p in module.parameters() if p.requires_grad
    ]
    if aux_params:
        optimizer.add_param_group({'params': aux_params, 'lr': lr, 'name': 'aux_losses'})
    loss_optimizer = build_loss_optimizer(loss_fn)
    scheduler = build_scheduler(optimizer, num_epochs, warmup_epochs=2)
    # Compare the device *type*: 'cuda:0' or torch.device('cuda', 1) must also get AMP.
    use_cuda = torch.device(device).type == 'cuda'
    scaler = GradScaler() if use_cuda else None
    dataloader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=use_cuda, drop_last=True,
    )

    def teacher_targets(teacher_pred):
        """Return (last-frame features, last-frame heatmap logits) from the teacher output."""
        t_feats = teacher_pred.get('search_feats')
        if t_feats is not None and t_feats.ndim == 4:
            t_feat = t_feats[:, -1]  # per-frame sequence: take the last frame
        elif 'search_feat' in teacher_pred:
            t_feat = teacher_pred['search_feat']
        elif t_feats is not None:
            t_feat = t_feats  # already single-frame
        else:
            raise KeyError("teacher output has neither 'search_feats' nor 'search_feat'")
        t_heat = teacher_pred['heatmap']
        if t_heat.ndim == 5:
            t_heat = t_heat[:, -1]  # per-frame sequence: take the last frame
        return t_feat, t_heat

    def save_checkpoint(filename, ckpt_epoch, loss_value):
        """Write a Phase 2 checkpoint to ``save_dir/filename``."""
        ckpt = {
            'epoch': ckpt_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss_fn_state_dict': loss_fn.state_dict(),
            'loss': loss_value,
            'config': config,
        }
        if distill_loss is not None:
            ckpt['distill_loss_state_dict'] = distill_loss.state_dict()
        torch.save(ckpt, os.path.join(save_dir, filename))

    best_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        num_batches = 0
        for batch_idx, batch in enumerate(dataloader):
            template = batch['template'].to(device)
            searches = batch['searches'].to(device)
            gt_heatmaps = batch['heatmaps'].to(device)
            gt_sizes = batch['sizes'].to(device)
            gt_boxes = batch['boxes'].to(device)
            K = searches.shape[1]
            optimizer.zero_grad()
            if loss_optimizer is not None:
                loss_optimizer.zero_grad()
            with autocast(enabled=scaler is not None):
                pred = model(template, searches, use_temporal=True)
                # Accumulate loss over K frames
                loss = torch.tensor(0.0, device=device)
                for k in range(K):
                    pred_k = {
                        'heatmap': pred['heatmap'][:, k],
                        'size': pred['size'][:, k],
                        'boxes': pred['boxes'][:, k],
                    }
                    if 'log_variance' in pred:
                        pred_k['log_variance'] = pred['log_variance'][:, k]
                    loss_dict_k = loss_fn(pred_k, gt_heatmaps[:, k],
                                          gt_sizes[:, k], gt_boxes[:, k])
                    loss = loss + loss_dict_k['total']
                loss = loss / K
                # Contrastive loss
                t_pooled = pred['template_feat'].mean(dim=1)
                s_pooled = pred['search_feats'][:, -1].mean(dim=1)
                c_loss = contrastive_loss(t_pooled, s_pooled)
                loss = loss + 0.1 * c_loss
                # AFKD distillation (if teacher available)
                if distill_loss is not None and teacher_model is not None:
                    with torch.no_grad():
                        teacher_pred = teacher_model(template, searches)
                    teacher_feat, teacher_logits = teacher_targets(teacher_pred)
                    # Distill on last frame features
                    d_loss = distill_loss(
                        student_feat=pred['search_feats'][:, -1],
                        teacher_feat=teacher_feat,
                        student_logits=pred['heatmap'][:, -1],
                        teacher_logits=teacher_logits,
                    )
                    loss = loss + 0.5 * d_loss
            if scaler is not None:
                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to true gradient norms.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                if loss_optimizer is not None:
                    scaler.unscale_(loss_optimizer)
                    scaler.step(loss_optimizer)
                scaler.update()
            else:
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                if loss_optimizer is not None:
                    loss_optimizer.step()
            total_loss += loss.item()
            num_batches += 1
            if batch_idx % 100 == 0:
                msg = (f"  Phase2 Epoch {epoch}/{num_epochs} | Batch {batch_idx} | "
                       f"Loss: {loss.item():.4f} | "
                       f"Contr: {c_loss.item():.4f}")
                if distill_loss is not None:
                    msg += f" | Distill: {d_loss.item():.4f}"
                print(msg)
        scheduler.step()
        model.reset_temporal()  # Reset between epochs
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Phase2 Epoch {epoch}/{num_epochs} | Avg Loss: {avg_loss:.4f} | "
              f"LR: {scheduler.get_last_lr()[0]:.6f}")
        if avg_loss < best_loss:
            best_loss = avg_loss
            save_checkpoint('best_phase2.pth', epoch, best_loss)
        if (epoch + 1) % 25 == 0:
            save_checkpoint(f'phase2_epoch{epoch+1}.pth', epoch, avg_loss)
    if push_to_hub and hub_model_id:
        _push_checkpoint_to_hub(model, save_dir, hub_model_id, 'phase2')
    return model
def _push_checkpoint_to_hub(model, save_dir, hub_model_id, phase):
    """Upload this phase's checkpoint files from ``save_dir`` to the Hugging Face Hub.

    Only files whose names contain ``phase`` (e.g. ``best_phase1.pth``,
    ``phase1_epoch50.pth``) are uploaded, into ``checkpoints/{phase}`` of the
    repo ``hub_model_id``. Pushing phase 2 therefore does not re-upload the
    phase 1 checkpoints that share ``save_dir``.

    Best-effort: any failure (huggingface_hub not installed, auth, network) is
    printed as a warning and does not interrupt training. Local checkpoints
    are kept either way.

    Args:
        model: unused; kept for call-site compatibility.
        save_dir: local checkpoint directory.
        hub_model_id: target Hub repo id.
        phase: phase tag, e.g. 'phase1' or 'phase2'.
    """
    try:
        from huggingface_hub import HfApi
        HfApi().upload_folder(
            folder_path=save_dir,
            repo_id=hub_model_id,
            path_in_repo=f'checkpoints/{phase}',
            allow_patterns=[f'*{phase}*'],
        )
        print(f"Pushed {phase} checkpoint to {hub_model_id}")
    except Exception as e:  # deliberate best-effort: never fail training on upload
        print(f"Warning: Could not push to hub: {e}")