#!/usr/bin/env python3
"""
SCRFD Training Script — Full training pipeline with:
- Multi-GPU support via DDP
- Epoch-indexed step LR decay with per-step linear warmup
- Gradient clipping, mixed precision
- Checkpoint saving & resuming (model, optimizer, AMP scaler, best loss)

Not yet wired up in this script (flags kept for CLI compatibility):
- ``--config`` is parsed but never loaded
- ``--eval-freq`` / ``build_val_loader`` (WiderFace evaluation) are unused
- No experiment tracker (e.g. Trackio) is integrated

Training recipe (from SCRFD paper):
- SGD: lr=0.01, momentum=0.9, weight_decay=5e-4
- Warmup: 3 epochs linear from 1e-5
- LR decay: ×0.1 at epoch 440, 544
- Total epochs: 640 (from scratch)
- Batch: 8 per GPU × 4 GPUs
- Input: 640×640 random crops with scale [0.3, 2.0]

Usage:
    # Single GPU
    python scripts/train.py --config configs/scrfd_34g.yaml

    # Multi-GPU
    torchrun --nproc_per_node=4 scripts/train.py --config configs/scrfd_34g.yaml

    # AMP and robustness augmentation are on by default; turn them off with:
    python scripts/train.py --no-amp --no-robustness
"""

import os
import sys
import argparse
import bisect
import time
import math
import json
import yaml
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler

# Add project root to path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from models.detector import build_detector
from data.dataloader import build_train_loader, build_val_loader


def parse_args():
    """Parse the command-line options for a training run."""
    parser = argparse.ArgumentParser(description='Train SCRFD Face Detector')
    # NOTE: --config is accepted but not consumed anywhere in this script;
    # every hyper-parameter below comes from its CLI flag/default.
    parser.add_argument('--config', type=str, default='configs/scrfd_34g.yaml',
                        help='Path to config file')
    parser.add_argument('--data-root', type=str, default='data/wider_face',
                        help='Path to WiderFace dataset root')
    parser.add_argument('--output-dir', type=str, default='checkpoints',
                        help='Output directory for checkpoints')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    parser.add_argument('--model', type=str, default='scrfd_34g',
                        choices=['scrfd_34g', 'scrfd_10g', 'scrfd_2.5g', 'scrfd_0.5g'],
                        help='Model variant')
    parser.add_argument('--epochs', type=int, default=640)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--warmup-epochs', type=int, default=3)
    parser.add_argument('--lr-steps', nargs='+', type=int, default=[440, 544])
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--input-size', type=int, default=640)
    parser.add_argument('--use-landmarks', action='store_true')
    # `store_true` combined with default=True can never be switched off, so each
    # such option gets an explicit negative flag writing to the same dest.
    parser.add_argument('--enable-robustness', action='store_true', default=True)
    parser.add_argument('--no-robustness', dest='enable_robustness', action='store_false',
                        help='Disable robustness augmentation (enabled by default)')
    parser.add_argument('--amp', action='store_true', default=True,
                        help='Use automatic mixed precision')
    parser.add_argument('--no-amp', dest='amp', action='store_false',
                        help='Disable automatic mixed precision (enabled by default)')
    parser.add_argument('--grad-clip', type=float, default=35.0)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--save-freq', type=int, default=20)
    parser.add_argument('--log-freq', type=int, default=50)
    # NOTE: no evaluation loop exists yet, so --eval-freq is currently unused.
    parser.add_argument('--eval-freq', type=int, default=50)
    # Accepted for launchers that pass it; setup_distributed() reads
    # LOCAL_RANK from the environment instead.
    parser.add_argument('--local_rank', type=int, default=0)
    return parser.parse_args()


def setup_distributed():
    """Initialize DDP if launched with torchrun-style env vars.

    Returns:
        (distributed, rank, world_size, local_rank). Without a ``RANK`` env
        var this is ``(False, 0, 1, 0)`` and no process group is created.
    """
    if 'RANK' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        local_rank = int(os.environ['LOCAL_RANK'])
        dist.init_process_group('nccl')
        torch.cuda.set_device(local_rank)
        return True, rank, world_size, local_rank
    return False, 0, 1, 0


def build_optimizer(model, lr, momentum, weight_decay):
    """Build an SGD optimizer with two parameter groups.

    Parameters whose qualified name contains ``'bn'``, ``'gn'`` or ``'bias'``
    get no weight decay; every other trainable parameter gets
    ``weight_decay``. Frozen parameters (``requires_grad=False``) are skipped.
    NOTE(review): this is a name heuristic — norm layers not named bn/gn
    (e.g. ``norm``) would still be decayed; confirm against the model's naming.
    """
    params_with_decay = []
    params_no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'bn' in name or 'gn' in name or 'bias' in name:
            params_no_decay.append(param)
        else:
            params_with_decay.append(param)

    return optim.SGD([
        {'params': params_with_decay, 'weight_decay': weight_decay},
        {'params': params_no_decay, 'weight_decay': 0.0},
    ], lr=lr, momentum=momentum)


def warmup_lr(optimizer, epoch, step, steps_per_epoch, warmup_epochs, base_lr):
    """Linear per-step warmup from 1e-5 towards ``base_lr``.

    Sets every param group's ``lr`` while the global step
    (``epoch * steps_per_epoch + step``) is inside the warmup window; steps
    at or beyond the window leave the optimizer untouched.
    """
    warmup_steps = warmup_epochs * steps_per_epoch
    current_step = epoch * steps_per_epoch + step
    if current_step < warmup_steps:
        lr = 1e-5 + (base_lr - 1e-5) * current_step / warmup_steps
        for pg in optimizer.param_groups:
            pg['lr'] = lr


def lr_at_epoch(epoch, base_lr, lr_steps, gamma=0.1):
    """Return the post-warmup learning rate for ``epoch`` (MultiStep schedule).

    ``base_lr`` is multiplied by ``gamma`` once per milestone in ``lr_steps``
    that is ``<= epoch``: with ``lr_steps=[440, 544]`` epoch 439 uses
    ``base_lr``, epoch 440 uses ``base_lr * gamma`` and epoch 544 uses
    ``base_lr * gamma**2``. Being closed-form, it does not depend on how many
    times a scheduler was stepped, so warmup and resumes cannot shift it.
    """
    return base_lr * gamma ** bisect.bisect_right(sorted(lr_steps), epoch)


def _set_lr(optimizer, lr):
    """Assign ``lr`` to every param group of ``optimizer``."""
    for pg in optimizer.param_groups:
        pg['lr'] = lr


def train_one_epoch(model, loader, optimizer, scaler, epoch, args, is_main):
    """Train for one epoch and return the mean of each loss term.

    ``model(images, targets)`` is expected to return a dict with tensor
    entries ``cls_loss``, ``reg_loss``, ``total_loss`` and ``num_pos``.
    When a torch.distributed process group is initialized, the per-rank sums
    are all-reduced so every rank returns the global average — all ranks must
    therefore call this function each epoch.
    """
    model.train()
    loss_keys = ('cls_loss', 'reg_loss', 'total_loss', 'num_pos')
    total_losses = {k: 0.0 for k in loss_keys}
    num_batches = 0
    start_time = time.time()
    steps_per_epoch = len(loader)

    for step, (images, targets) in enumerate(loader):
        images = images.cuda(non_blocking=True)
        targets = [{k: v.cuda(non_blocking=True) for k, v in t.items()} for t in targets]

        # Per-step warmup; after warmup main() sets the LR once per epoch.
        if epoch < args.warmup_epochs:
            warmup_lr(optimizer, epoch, step, steps_per_epoch, args.warmup_epochs, args.lr)

        optimizer.zero_grad()

        if args.amp:
            with autocast():
                losses = model(images, targets)
            scaler.scale(losses['total_loss']).backward()
            if args.grad_clip > 0:
                # Unscale first so the clip threshold applies to the true grad norm.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            losses = model(images, targets)
            losses['total_loss'].backward()
            if args.grad_clip > 0:
                nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()

        for k in loss_keys:
            total_losses[k] += losses[k].item()
        num_batches += 1

        # Logging (log_freq <= 0 disables it instead of dividing by zero)
        if is_main and args.log_freq > 0 and step % args.log_freq == 0:
            elapsed = time.time() - start_time
            # Images/s seen by this process only.
            fps = (step + 1) * args.batch_size / elapsed if elapsed > 0 else 0
            print(f" [Epoch {epoch}][{step}/{len(loader)}] "
                  f"cls={losses['cls_loss'].item():.4f} "
                  f"reg={losses['reg_loss'].item():.4f} "
                  f"total={losses['total_loss'].item():.4f} "
                  f"pos={losses['num_pos'].item():.0f} "
                  f"lr={optimizer.param_groups[0]['lr']:.6f} "
                  f"fps={fps:.1f}")

    # Reduce sums and batch counts together so the average is over all ranks.
    sums = [total_losses[k] for k in loss_keys] + [float(num_batches)]
    if dist.is_available() and dist.is_initialized():
        buf = torch.tensor(sums, dtype=torch.float64, device='cuda')
        dist.all_reduce(buf)
        sums = buf.tolist()
    count = max(sums[-1], 1)
    return {k: s / count for k, s in zip(loss_keys, sums)}


def main():
    """Entry point: build model/data/optimizer, optionally resume, and train."""
    args = parse_args()
    distributed, rank, world_size, local_rank = setup_distributed()
    is_main = rank == 0

    if is_main:
        os.makedirs(args.output_dir, exist_ok=True)
        print(f"Training {args.model} for {args.epochs} epochs")
        print(f" Distributed: {distributed} (world_size={world_size})")
        print(f" Batch size: {args.batch_size} × {world_size} = {args.batch_size * world_size}")
        print(f" LR: {args.lr}, steps: {args.lr_steps}")
        print(f" Input size: {args.input_size}")

    # Build model
    model = build_detector(
        args.model,
        use_landmarks=args.use_landmarks,
    ).cuda()

    if is_main:
        num_params = sum(p.numel() for p in model.parameters()) / 1e6
        print(f" Model parameters: {num_params:.2f}M")

    if distributed:
        model = DDP(model, device_ids=[local_rank], find_unused_parameters=False)

    # Build data loaders
    train_loader = build_train_loader(
        args.data_root,
        batch_size=args.batch_size,
        target_size=args.input_size,
        num_workers=args.num_workers,
        use_landmarks=args.use_landmarks,
        enable_robustness=args.enable_robustness,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )

    # Optimizer & AMP scaler. The LR schedule is computed per epoch by
    # lr_at_epoch(), so no stateful scheduler has to be kept in sync.
    optimizer = build_optimizer(model, args.lr, args.momentum, args.weight_decay)
    scaler = GradScaler() if args.amp else None

    # Resume
    start_epoch = 0
    best_loss = float('inf')
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_state = checkpoint['model_state_dict']
        if distributed:
            model.module.load_state_dict(model_state)
        else:
            model.load_state_dict(model_state)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Older checkpoints lack these keys; fall back to fresh state.
        if scaler is not None and 'scaler_state_dict' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        best_loss = checkpoint.get('best_loss', best_loss)
        start_epoch = checkpoint['epoch'] + 1
        if is_main:
            print(f" Resumed from epoch {start_epoch}")

    # Training loop
    for epoch in range(start_epoch, args.epochs):
        if distributed:
            train_loader.sampler.set_epoch(epoch)

        # Post-warmup LR is a pure function of the epoch index, so milestones
        # land exactly on args.lr_steps regardless of warmup length or resume.
        if epoch >= args.warmup_epochs:
            _set_lr(optimizer, lr_at_epoch(epoch, args.lr, args.lr_steps))

        avg_losses = train_one_epoch(model, train_loader, optimizer, scaler,
                                     epoch, args, is_main)

        # Logging
        if is_main:
            print(f"Epoch {epoch} avg: cls={avg_losses['cls_loss']:.4f} "
                  f"reg={avg_losses['reg_loss']:.4f} "
                  f"total={avg_losses['total_loss']:.4f}")

        # Save checkpoint. "Best" is judged on training loss and only at
        # save epochs (save_freq <= 0 disables periodic saving).
        if is_main and args.save_freq > 0 and (epoch + 1) % args.save_freq == 0:
            is_best = avg_losses['total_loss'] < best_loss
            if is_best:
                best_loss = avg_losses['total_loss']

            state = {
                'epoch': epoch,
                'model_state_dict': (model.module if distributed else model).state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'avg_losses': avg_losses,
                'best_loss': best_loss,
                'config': vars(args),
            }
            if scaler is not None:
                state['scaler_state_dict'] = scaler.state_dict()

            path = os.path.join(args.output_dir, f'{args.model}_epoch{epoch}.pth')
            torch.save(state, path)
            print(f" Saved checkpoint: {path}")

            if is_best:
                best_path = os.path.join(args.output_dir, f'{args.model}_best.pth')
                torch.save(state, best_path)
                print(f" New best model: {best_path}")

    # Save final model
    if is_main:
        final_state = {
            'epoch': args.epochs - 1,
            'model_state_dict': (model.module if distributed else model).state_dict(),
            'config': vars(args),
        }
        torch.save(final_state, os.path.join(args.output_dir, f'{args.model}_final.pth'))
        print("Training complete!")

    if distributed:
        dist.destroy_process_group()


if __name__ == '__main__':
    main()