cledouxluma committed on
Commit
242cb85
·
verified ·
1 Parent(s): 8499cad

Upload scripts/train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train.py +292 -0
scripts/train.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SCRFD Training Script — Full training pipeline with:
4
+ - Multi-GPU support via DDP
5
+ - Cosine/step LR scheduling with warmup
6
+ - Gradient clipping, mixed precision
7
+ - Checkpoint saving & resuming
8
+ - WiderFace evaluation hooks
9
+ - Trackio experiment tracking
10
+
11
+ Training recipe (from SCRFD paper):
12
+ - SGD: lr=0.01, momentum=0.9, weight_decay=5e-4
13
+ - Warmup: 3 epochs linear from 1e-5
14
+ - LR decay: ×0.1 at epoch 440, 544
15
+ - Total epochs: 640 (from scratch)
16
+ - Batch: 8 per GPU × 4 GPUs
17
+ - Input: 640×640 random crops with scale [0.3, 2.0]
18
+
19
+ Usage:
20
+ # Single GPU
21
+ python scripts/train.py --config configs/scrfd_34g.yaml
22
+
23
+ # Multi-GPU
24
+ torchrun --nproc_per_node=4 scripts/train.py --config configs/scrfd_34g.yaml
25
+ """
26
+
27
+ import os
28
+ import sys
29
+ import argparse
30
+ import time
31
+ import math
32
+ import json
33
+ import yaml
34
+ from pathlib import Path
35
+
36
+ import torch
37
+ import torch.nn as nn
38
+ import torch.optim as optim
39
+ import torch.distributed as dist
40
+ from torch.nn.parallel import DistributedDataParallel as DDP
41
+ from torch.cuda.amp import autocast, GradScaler
42
+
43
+ # Add project root to path
44
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
45
+
46
+ from models.detector import build_detector
47
+ from data.dataloader import build_train_loader, build_val_loader
48
+
49
+
50
def parse_args(argv=None):
    """Parse command-line options for the training script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``argparse`` reads ``sys.argv[1:]``, which is the behaviour of a
            plain ``parse_args()`` call.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Train SCRFD Face Detector')
    parser.add_argument('--config', type=str, default='configs/scrfd_34g.yaml',
                        help='Path to config file')
    parser.add_argument('--data-root', type=str, default='data/wider_face',
                        help='Path to WiderFace dataset root')
    parser.add_argument('--output-dir', type=str, default='checkpoints',
                        help='Output directory for checkpoints')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    parser.add_argument('--model', type=str, default='scrfd_34g',
                        choices=['scrfd_34g', 'scrfd_10g', 'scrfd_2.5g', 'scrfd_0.5g'],
                        help='Model variant')
    parser.add_argument('--epochs', type=int, default=640)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--warmup-epochs', type=int, default=3)
    parser.add_argument('--lr-steps', nargs='+', type=int, default=[440, 544])
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--input-size', type=int, default=640)
    parser.add_argument('--use-landmarks', action='store_true')

    # A ``store_true`` flag whose default is already True can never be turned
    # off, so each default-on switch gets an explicit ``--no-*`` counterpart
    # writing to the same destination.
    parser.add_argument('--enable-robustness', dest='enable_robustness',
                        action='store_true', default=True)
    parser.add_argument('--no-enable-robustness', dest='enable_robustness',
                        action='store_false',
                        help='Disable the robustness option passed to the train loader')
    parser.add_argument('--amp', dest='amp', action='store_true', default=True,
                        help='Use automatic mixed precision')
    parser.add_argument('--no-amp', dest='amp', action='store_false',
                        help='Disable automatic mixed precision')

    parser.add_argument('--grad-clip', type=float, default=35.0)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--save-freq', type=int, default=20)
    parser.add_argument('--log-freq', type=int, default=50)
    parser.add_argument('--eval-freq', type=int, default=50)
    parser.add_argument('--local_rank', type=int, default=0)
    return parser.parse_args(argv)
82
+
83
+
84
def setup_distributed():
    """Initialize DDP when the launcher has exported rank information.

    Distributed mode is detected by the presence of ``RANK`` in the
    environment; ``WORLD_SIZE`` and ``LOCAL_RANK`` are then read as well.

    Returns:
        Tuple ``(distributed, rank, world_size, local_rank)``. Without
        ``RANK`` in the environment this is ``(False, 0, 1, 0)`` and no
        process group is created.

    Raises:
        KeyError: if ``RANK`` is set but ``WORLD_SIZE`` or ``LOCAL_RANK``
            is missing.
    """
    if 'RANK' not in os.environ:
        return False, 0, 1, 0

    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])

    # Bind this process to its GPU *before* creating the NCCL process group,
    # so the communicator is set up on the intended device instead of the
    # default device 0.
    torch.cuda.set_device(local_rank)
    dist.init_process_group('nccl')
    return True, rank, world_size, local_rank
94
+
95
+
96
def build_optimizer(model, lr, momentum, weight_decay):
    """Build an SGD optimizer that applies weight decay to weight tensors only.

    Parameters are split by dimensionality rather than by name: biases and
    normalization scale/shift parameters are at most 1-D and receive no
    decay, while everything with two or more dimensions (conv / linear
    weights) is decayed. A substring test such as ``'bn' in name`` is
    fragile: it also matches unrelated names (e.g. ``subnet``) and misses
    normalization layers that are not literally called ``bn``/``gn``.

    NOTE: group membership can differ from a name-based split, so optimizer
    state saved with a different grouping may not load into this optimizer.

    Args:
        model: Module whose trainable parameters are optimized.
        lr: Learning rate shared by both parameter groups.
        momentum: SGD momentum.
        weight_decay: Decay applied to the first (weights) group.

    Returns:
        ``torch.optim.SGD`` with two parameter groups, in this order:
        decayed parameters, then non-decayed parameters.
    """
    params_with_decay = []
    params_no_decay = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # ``endswith('bias')`` keeps any bias out of the decay group even if
        # it is not 1-D.
        if param.ndim <= 1 or name.endswith('bias'):
            params_no_decay.append(param)
        else:
            params_with_decay.append(param)

    return optim.SGD([
        {'params': params_with_decay, 'weight_decay': weight_decay},
        {'params': params_no_decay, 'weight_decay': 0.0},
    ], lr=lr, momentum=momentum)
113
+
114
+
115
def warmup_lr(optimizer, epoch, step, steps_per_epoch, warmup_epochs, base_lr):
    """Set every param group's LR on a linear ramp from 1e-5 towards ``base_lr``.

    The ramp position is the global step ``epoch * steps_per_epoch + step``
    measured against ``warmup_epochs * steps_per_epoch``. Once the global
    step reaches that total the optimizer is left untouched.
    """
    start_lr = 1e-5
    total_warmup = warmup_epochs * steps_per_epoch
    global_step = epoch * steps_per_epoch + step

    # Past the warmup window (or warmup disabled): nothing to adjust.
    if global_step >= total_warmup:
        return

    ramped = start_lr + (base_lr - start_lr) * global_step / total_warmup
    for group in optimizer.param_groups:
        group['lr'] = ramped
123
+
124
+
125
def train_one_epoch(model, loader, optimizer, scaler, epoch, args, is_main):
    """Run one pass over ``loader`` and return the per-batch mean of each loss.

    For every batch the images and target tensors are moved to the GPU, the
    warmup LR is applied while ``epoch < args.warmup_epochs``, and one
    optimizer step is taken (through ``scaler`` when ``args.amp`` is set),
    with gradient-norm clipping when ``args.grad_clip > 0``. The model is
    called as ``model(images, targets)`` and must return a mapping with
    tensor entries ``cls_loss``, ``reg_loss``, ``total_loss`` and ``num_pos``.

    Returns:
        Dict with those four keys averaged over the batches seen (all zeros
        for an empty loader).
    """
    model.train()
    running = dict.fromkeys(('cls_loss', 'reg_loss', 'total_loss', 'num_pos'), 0)
    batches_seen = 0
    epoch_start = time.time()

    for step, (images, targets) in enumerate(loader):
        images = images.cuda(non_blocking=True)
        targets = [
            {key: value.cuda(non_blocking=True) for key, value in target.items()}
            for target in targets
        ]

        # Warmup LR
        if epoch < args.warmup_epochs:
            warmup_lr(optimizer, epoch, step, len(loader),
                      args.warmup_epochs, args.lr)

        optimizer.zero_grad()

        if not args.amp:
            losses = model(images, targets)
            losses['total_loss'].backward()
            if args.grad_clip > 0:
                nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
        else:
            with autocast():
                losses = model(images, targets)
            scaler.scale(losses['total_loss']).backward()
            if args.grad_clip > 0:
                # Gradients must be unscaled before their norm is measured.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()

        for key in running:
            running[key] += losses[key].item()
        batches_seen += 1

        # Logging (main rank only, every ``log_freq`` steps)
        if not is_main or step % args.log_freq != 0:
            continue

        elapsed = time.time() - epoch_start
        if elapsed > 0:
            fps = (step + 1) * args.batch_size / elapsed
        else:
            fps = 0
        print(f" [Epoch {epoch}][{step}/{len(loader)}] "
              f"cls={losses['cls_loss'].item():.4f} "
              f"reg={losses['reg_loss'].item():.4f} "
              f"total={losses['total_loss'].item():.4f} "
              f"pos={losses['num_pos'].item():.0f} "
              f"lr={optimizer.param_groups[0]['lr']:.6f} "
              f"fps={fps:.1f}")

    divisor = max(batches_seen, 1)
    return {key: total / divisor for key, total in running.items()}
177
+
178
+
179
def main():
    """Train the detector end to end: setup, optional resume, epochs, checkpoints.

    Steps, in order: parse CLI options, initialise distributed mode if the
    launcher requested it, build the model (wrapped in DDP when distributed),
    the train loader, the SGD optimizer, a ``MultiStepLR`` scheduler and an
    optional ``GradScaler``; optionally restore state from ``--resume``; run
    the epoch loop; write periodic, best and final checkpoints on rank 0.

    Checkpoints additionally carry ``scaler_state_dict`` and ``best_loss``;
    both are optional on load, so checkpoints without them still resume.
    """
    args = parse_args()
    distributed, rank, world_size, local_rank = setup_distributed()
    is_main = rank == 0

    if is_main:
        os.makedirs(args.output_dir, exist_ok=True)
        print(f"Training {args.model} for {args.epochs} epochs")
        print(f" Distributed: {distributed} (world_size={world_size})")
        print(f" Batch size: {args.batch_size} × {world_size} = {args.batch_size * world_size}")
        print(f" LR: {args.lr}, steps: {args.lr_steps}")
        print(f" Input size: {args.input_size}")

    # Build model
    model = build_detector(
        args.model,
        use_landmarks=args.use_landmarks,
    ).cuda()

    if is_main:
        num_params = sum(p.numel() for p in model.parameters()) / 1e6
        print(f" Model parameters: {num_params:.2f}M")

    if distributed:
        model = DDP(model, device_ids=[local_rank], find_unused_parameters=False)

    # Build data loaders
    train_loader = build_train_loader(
        args.data_root,
        batch_size=args.batch_size,
        target_size=args.input_size,
        num_workers=args.num_workers,
        use_landmarks=args.use_landmarks,
        enable_robustness=args.enable_robustness,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )

    # Optimizer & scheduler
    optimizer = build_optimizer(model, args.lr, args.momentum, args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, args.lr_steps, gamma=0.1)
    scaler = GradScaler() if args.amp else None

    # Resume
    start_epoch = 0
    best_loss = float('inf')
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        raw_model = model.module if distributed else model
        raw_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Optional entries: absent in checkpoints that predate them.
        scaler_state = checkpoint.get('scaler_state_dict')
        if scaler is not None and scaler_state:
            scaler.load_state_dict(scaler_state)
        best_loss = checkpoint.get('best_loss', best_loss)

        start_epoch = checkpoint['epoch'] + 1
        if is_main:
            print(f" Resumed from epoch {start_epoch}")

    # Training loop
    for epoch in range(start_epoch, args.epochs):
        if distributed:
            # Reseed the distributed sampler so each epoch shuffles differently.
            train_loader.sampler.set_epoch(epoch)

        avg_losses = train_one_epoch(model, train_loader, optimizer, scaler,
                                     epoch, args, is_main)

        # Step the scheduler every epoch, warmup included. MultiStepLR counts
        # its milestones in step() calls, so skipping warmup epochs would
        # delay each decay by ``warmup_epochs``. Between milestones step()
        # keeps the current LR, so it does not disturb the per-step warmup.
        scheduler.step()

        # Logging
        if is_main:
            print(f"Epoch {epoch} avg: cls={avg_losses['cls_loss']:.4f} "
                  f"reg={avg_losses['reg_loss']:.4f} "
                  f"total={avg_losses['total_loss']:.4f}")

        # Save checkpoint
        if is_main and (epoch + 1) % args.save_freq == 0:
            is_best = avg_losses['total_loss'] < best_loss
            if is_best:
                best_loss = avg_losses['total_loss']

            state = {
                'epoch': epoch,
                'model_state_dict': (model.module if distributed else model).state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict() if scaler is not None else None,
                'best_loss': best_loss,
                'avg_losses': avg_losses,
                'config': vars(args),
            }
            path = os.path.join(args.output_dir, f'{args.model}_epoch{epoch}.pth')
            torch.save(state, path)
            print(f" Saved checkpoint: {path}")

            if is_best:
                best_path = os.path.join(args.output_dir, f'{args.model}_best.pth')
                torch.save(state, best_path)
                print(f" New best model: {best_path}")

    # Save final model
    if is_main:
        final_state = {
            'epoch': args.epochs - 1,
            'model_state_dict': (model.module if distributed else model).state_dict(),
            'config': vars(args),
        }
        torch.save(final_state, os.path.join(args.output_dir, f'{args.model}_final.pth'))
        print("Training complete!")

    if distributed:
        dist.destroy_process_group()
289
+
290
+
291
# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()