"""
SCRFD Training Script: full training pipeline with
- Multi-GPU support via DDP
- Step LR schedule with linear warmup
- Gradient clipping, mixed precision
- Checkpoint saving & resuming
- WiderFace evaluation hooks
- Trackio experiment tracking

Training recipe (from the SCRFD paper):
- SGD: lr=0.01, momentum=0.9, weight_decay=5e-4
- Warmup: linear from 1e-5 over 3 epochs
- LR decay: ×0.1 at epochs 440 and 544
- Total epochs: 640 (from scratch)
- Batch: 8 per GPU × 4 GPUs
- Input: 640×640 random crops with scale [0.3, 2.0]

Usage:
    # Single GPU
    python scripts/train.py --config configs/scrfd_34g.yaml

    # Multi-GPU
    torchrun --nproc_per_node=4 scripts/train.py --config configs/scrfd_34g.yaml
"""
|
|
import os
import sys
import argparse
import time
from pathlib import Path

import yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
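
# Optional Trackio experiment tracking; assumes Trackio's wandb-compatible
# API (init / log / finish). Training works without it installed.
try:
    import trackio
except ImportError:
    trackio = None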
|
|
# Make the repo root importable when this file is run as a script.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from models.detector import build_detector
from data.dataloader import build_train_loader, build_val_loader
|
|
|
|
def parse_args():
    parser = argparse.ArgumentParser(description='Train SCRFD Face Detector')
    parser.add_argument('--config', type=str, default='configs/scrfd_34g.yaml',
                        help='Path to YAML config file')
    parser.add_argument('--data-root', type=str, default='data/wider_face',
                        help='Path to WiderFace dataset root')
    parser.add_argument('--output-dir', type=str, default='checkpoints',
                        help='Output directory for checkpoints')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    parser.add_argument('--model', type=str, default='scrfd_34g',
                        choices=['scrfd_34g', 'scrfd_10g', 'scrfd_2.5g', 'scrfd_0.5g'],
                        help='Model variant')
    parser.add_argument('--epochs', type=int, default=640)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--warmup-epochs', type=int, default=3)
    parser.add_argument('--lr-steps', nargs='+', type=int, default=[440, 544])
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--input-size', type=int, default=640)
    parser.add_argument('--use-landmarks', action='store_true')
    # BooleanOptionalAction (Python 3.9+) so the True defaults can actually
    # be disabled from the CLI (--no-enable-robustness, --no-amp); a plain
    # `store_true` flag with default=True can never be turned off.
    parser.add_argument('--enable-robustness', action=argparse.BooleanOptionalAction,
                        default=True)
    parser.add_argument('--amp', action=argparse.BooleanOptionalAction, default=True,
                        help='Use automatic mixed precision (disable with --no-amp)')
    parser.add_argument('--grad-clip', type=float, default=35.0)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--save-freq', type=int, default=20)
    parser.add_argument('--log-freq', type=int, default=50)
    parser.add_argument('--eval-freq', type=int, default=50)
    parser.add_argument('--local_rank', type=int, default=0,
                        help='Legacy torch.distributed.launch flag; torchrun '
                             'sets LOCAL_RANK in the environment instead')
    return parser.parse_args()
|
|
|
|
def setup_distributed():
    """Initialize DDP when launched via torchrun (RANK set in the environment)."""
    if 'RANK' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        local_rank = int(os.environ['LOCAL_RANK'])
        dist.init_process_group('nccl')
        torch.cuda.set_device(local_rank)
        return True, rank, world_size, local_rank
    return False, 0, 1, 0
|
|
|
|
def build_optimizer(model, lr, momentum, weight_decay):
    """Build SGD; skip weight decay for norm-layer parameters and biases."""
    params_with_decay = []
    params_no_decay = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Norm-layer weights and all biases are 1-D tensors, so keying off
        # ndim is more robust than substring-matching parameter names.
        if param.ndim <= 1 or name.endswith('.bias'):
            params_no_decay.append(param)
        else:
            params_with_decay.append(param)

    return optim.SGD([
        {'params': params_with_decay, 'weight_decay': weight_decay},
        {'params': params_no_decay, 'weight_decay': 0.0},
    ], lr=lr, momentum=momentum)
|
|
|
|
def warmup_lr(optimizer, epoch, step, steps_per_epoch, warmup_epochs, base_lr):
    """Linear warmup from 1e-5 to base_lr over the first warmup_epochs."""
    warmup_steps = warmup_epochs * steps_per_epoch
    current_step = epoch * steps_per_epoch + step
    if current_step < warmup_steps:
        # +1 so the final warmup step lands exactly on base_lr instead of
        # stopping one step short.
        lr = 1e-5 + (base_lr - 1e-5) * (current_step + 1) / warmup_steps
        for pg in optimizer.param_groups:
            pg['lr'] = lr
|
|
|
|
def train_one_epoch(model, loader, optimizer, scaler, epoch, args, is_main):
    """Train one epoch and return losses averaged over its batches."""
    model.train()
    total_losses = {'cls_loss': 0, 'reg_loss': 0, 'total_loss': 0, 'num_pos': 0}
    num_batches = 0
    start_time = time.time()

    for step, (images, targets) in enumerate(loader):
        images = images.cuda(non_blocking=True)
        targets = [{k: v.cuda(non_blocking=True) for k, v in t.items()} for t in targets]

        # Per-step linear warmup; after the warmup epochs the MultiStepLR
        # scheduler in main() owns the learning rate.
        if epoch < args.warmup_epochs:
            warmup_lr(optimizer, epoch, step, len(loader),
                      args.warmup_epochs, args.lr)

        optimizer.zero_grad()

        if args.amp:
            with autocast():
                losses = model(images, targets)
            scaler.scale(losses['total_loss']).backward()
            if args.grad_clip > 0:
                # Gradients must be unscaled before clipping by norm.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            losses = model(images, targets)
            losses['total_loss'].backward()
            if args.grad_clip > 0:
                nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()

        for k in total_losses:
            total_losses[k] += losses[k].item()
        num_batches += 1

        if is_main and step % args.log_freq == 0:
            elapsed = time.time() - start_time
            # Per-GPU throughput; multiply by world_size for the global rate.
            fps = (step + 1) * args.batch_size / elapsed if elapsed > 0 else 0
            print(f"  [Epoch {epoch}][{step}/{len(loader)}] "
                  f"cls={losses['cls_loss'].item():.4f} "
                  f"reg={losses['reg_loss'].item():.4f} "
                  f"total={losses['total_loss'].item():.4f} "
                  f"pos={losses['num_pos'].item():.0f} "
                  f"lr={optimizer.param_groups[0]['lr']:.6f} "
                  f"fps={fps:.1f}")

    avg_losses = {k: v / max(num_batches, 1) for k, v in total_losses.items()}
    return avg_losses
|
|
|
|
def main():
    args = parse_args()
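
    # Config handling sketch: the YAML file is assumed to be a flat mapping
    # of argparse destinations (e.g. `epochs: 640`). Keys present in the
    # file override the CLI/default values; everything else is untouched.
    if args.config and os.path.exists(args.config):
        with open(args.config) as f:
            cfg = yaml.safe_load(f) or {}
        for key, value in cfg.items():
            if hasattr(args, key):
                setattr(args, key, value)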

    distributed, rank, world_size, local_rank = setup_distributed()
    is_main = rank == 0

    if is_main:
        os.makedirs(args.output_dir, exist_ok=True)
        print(f"Training {args.model} for {args.epochs} epochs")
        print(f"  Distributed: {distributed} (world_size={world_size})")
        print(f"  Batch size: {args.batch_size} × {world_size} = {args.batch_size * world_size}")
        print(f"  LR: {args.lr}, steps: {args.lr_steps}")
        print(f"  Input size: {args.input_size}")
|
|
    model = build_detector(
        args.model,
        use_landmarks=args.use_landmarks,
    ).cuda()

    if is_main:
        num_params = sum(p.numel() for p in model.parameters()) / 1e6
        print(f"  Model parameters: {num_params:.2f}M")

    if distributed:
        model = DDP(model, device_ids=[local_rank], find_unused_parameters=False)

    train_loader = build_train_loader(
        args.data_root,
        batch_size=args.batch_size,
        target_size=args.input_size,
        num_workers=args.num_workers,
        use_landmarks=args.use_landmarks,
        enable_robustness=args.enable_robustness,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )

    optimizer = build_optimizer(model, args.lr, args.momentum, args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, args.lr_steps, gamma=0.1)
    scaler = GradScaler() if args.amp else None
|
|
    start_epoch = 0
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_state = checkpoint['model_state_dict']
        if distributed:
            model.module.load_state_dict(model_state)
        else:
            model.load_state_dict(model_state)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # Restore the AMP loss-scaler state as well; older checkpoints may
        # not contain it, hence the .get().
        if scaler is not None and checkpoint.get('scaler_state_dict'):
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        if is_main:
            print(f"  Resumed from epoch {start_epoch}")
|
|
    best_loss = float('inf')
    for epoch in range(start_epoch, args.epochs):
        if distributed:
            train_loader.sampler.set_epoch(epoch)

        avg_losses = train_one_epoch(model, train_loader, optimizer, scaler,
                                     epoch, args, is_main)

        # Step every epoch so the 440/544 milestones fall on actual epoch
        # indices; during warmup, warmup_lr() sets the LR per step, and
        # MultiStepLR leaves the LR unchanged between milestones, so the
        # unconditional step is safe.
        scheduler.step()
|
|
        if is_main:
            print(f"Epoch {epoch} avg: cls={avg_losses['cls_loss']:.4f} "
                  f"reg={avg_losses['reg_loss']:.4f} "
                  f"total={avg_losses['total_loss']:.4f}")
|
|
        if is_main:
            # Build the checkpoint once so the periodic save and the
            # best-model save below write the same complete snapshot
            # (model, optimizer, scheduler, and AMP scaler state).
            state = {
                'epoch': epoch,
                'model_state_dict': (model.module if distributed else model).state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict() if scaler is not None else None,
                'avg_losses': avg_losses,
                'config': vars(args),
            }

            if (epoch + 1) % args.save_freq == 0:
                path = os.path.join(args.output_dir, f'{args.model}_epoch{epoch}.pth')
                torch.save(state, path)
                print(f"  Saved checkpoint: {path}")

            if avg_losses['total_loss'] < best_loss:
                best_loss = avg_losses['total_loss']
                best_path = os.path.join(args.output_dir, f'{args.model}_best.pth')
                torch.save(state, best_path)
                print(f"  New best model: {best_path}")
|
|
    if is_main:
        final_state = {
            'epoch': args.epochs - 1,
            'model_state_dict': (model.module if distributed else model).state_dict(),
            'config': vars(args),
        }
        torch.save(final_state, os.path.join(args.output_dir, f'{args.model}_final.pth'))
        print("Training complete!")
|
|
    if distributed:
        dist.destroy_process_group()
|
|
|
|
if __name__ == '__main__':
    main()
|
|