BcantCode commited on
Commit
ee7a26f
·
verified ·
1 Parent(s): 01809ab

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +198 -554
train.py CHANGED
@@ -1,621 +1,265 @@
1
  """
2
  PriviGaze Training Script - Privileged Distillation for Gaze Estimation
3
-
4
- Two-phase training:
5
- 1. Teacher pre-training: Train teacher on privileged data (RGB eyes + blurred face)
6
- 2. Student distillation: Train student with privileged distillation loss
7
-
8
- This script implements Phase 2 (distillation). Phase 1 (teacher pre-training)
9
- should be run first to produce a strong teacher model.
10
-
11
- Usage:
12
- python train.py --mode distill --teacher-path ./teacher_best.pt --epochs 100
13
  """
14
-
15
- import os
16
- import sys
17
- import argparse
18
- import time
19
  from pathlib import Path
20
  from collections import defaultdict
21
-
22
  import torch
23
- import torch.nn as nn
24
  from torch.optim import AdamW
25
- from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
26
  import numpy as np
27
 
28
- # Add parent directory to path
29
  sys.path.insert(0, str(Path(__file__).parent))
30
-
31
  from models.teacher import PriviGazeTeacher
32
  from models.student import PriviGazeStudent, count_parameters
33
- from models.distillation_loss import PriviGazeDistillationLoss
34
- from models.dataset import create_dataloaders, SyntheticGazeDataset
35
 
36
- # Trackio for experiment monitoring
37
  try:
38
- import trackio
39
- HAS_TRACKIO = True
40
  except ImportError:
41
- HAS_TRACKIO = False
42
- print("Warning: trackio not installed. Logging to stdout only.")
43
 
44
 
45
  class DistillationTrainer:
46
- """Trains student model via privileged distillation from teacher."""
47
-
48
- def __init__(
49
- self,
50
- teacher: PriviGazeTeacher,
51
- student: PriviGazeStudent,
52
- distillation_loss: PriviGazeDistillationLoss,
53
- train_loader,
54
- val_loader,
55
- device: torch.device,
56
- lr: float = 1e-4,
57
- weight_decay: float = 1e-4,
58
- epochs: int = 100,
59
- teacher_frozen: bool = True,
60
- trackio_project: str = "privi-gaze",
61
- trackio_run_name: str = "distill",
62
- ):
63
  self.teacher = teacher.to(device)
64
  self.student = student.to(device)
65
- self.distillation_loss = distillation_loss.to(device)
66
  self.train_loader = train_loader
67
  self.val_loader = val_loader
68
  self.device = device
69
  self.epochs = epochs
70
- self.trackio_project = trackio_project
71
- self.trackio_run_name = trackio_run_name
72
-
73
- if teacher_frozen:
74
- for param in self.teacher.parameters():
75
- param.requires_grad = False
76
- self.teacher.eval()
77
-
78
- # Optimizer: only student parameters
79
- self.optimizer = AdamW(
80
- self.student.parameters(),
81
- lr=lr,
82
- weight_decay=weight_decay,
83
- )
84
-
85
- # Scheduler
86
- self.scheduler = CosineAnnealingLR(
87
- self.optimizer,
88
- T_max=epochs,
89
- eta_min=lr * 0.01,
90
- )
91
-
92
- # Track best model
93
- self.best_val_loss = float('inf')
94
  self.best_epoch = 0
95
-
96
- # Metrics tracking
97
- self.metrics_history = defaultdict(list)
98
-
99
- # Initialize trackio
100
  if HAS_TRACKIO:
101
- trackio.init(
102
- project=trackio_project,
103
- run_name=trackio_run_name,
104
- config={
105
- 'student_params': count_parameters(self.student),
106
- 'teacher_params': count_parameters(self.teacher),
107
- 'lr': lr,
108
- 'weight_decay': weight_decay,
109
- 'epochs': epochs,
110
- 'batch_size': train_loader.batch_size,
111
- }
112
- )
113
-
114
- def train_epoch(self, epoch: int) -> dict:
115
- """Train for one epoch."""
116
  self.student.train()
117
- epoch_losses = defaultdict(float)
118
- num_batches = 0
119
-
120
- for batch_idx, batch in enumerate(self.train_loader):
121
- # Move to device
122
- left_eye = batch['left_eye'].to(self.device)
123
- right_eye = batch['right_eye'].to(self.device)
124
- face_blurred = batch['face_blurred_gray'].to(self.device)
125
- face_gray = batch['face_gray'].to(self.device)
126
- pitch_target = batch['pitch'].to(self.device)
127
- yaw_target = batch['yaw'].to(self.device)
128
-
129
- # Teacher forward (no grad)
130
  with torch.no_grad():
131
- t_pitch, t_yaw, t_features = self.teacher(
132
- left_eye, right_eye, face_blurred
133
- )
134
- # Get teacher logits by running forward through heads
135
- # (We need these for logit distillation)
136
- # We extract them from the teacher's internal state
137
- t_pitch_logits = self.teacher.pitch_head(t_features)
138
- t_yaw_logits = self.teacher.yaw_head(t_features)
139
-
140
- # Student forward
141
- s_pitch, s_yaw, s_features = self.student(face_gray)
142
- s_pitch_logits = self.student.pitch_head(s_features)
143
- s_yaw_logits = self.student.yaw_head(s_features)
144
-
145
- # Compute distillation loss
146
- loss, loss_dict = self.distillation_loss(
147
- s_pitch, s_yaw,
148
- s_pitch_logits, s_yaw_logits,
149
- s_features,
150
- t_pitch, t_yaw,
151
- t_pitch_logits, t_yaw_logits,
152
- t_features,
153
- pitch_target, yaw_target,
154
- )
155
-
156
- # Backward
157
- self.optimizer.zero_grad()
158
  loss.backward()
159
-
160
- # Gradient clipping
161
- torch.nn.utils.clip_grad_norm_(self.student.parameters(), max_norm=1.0)
162
-
163
- self.optimizer.step()
164
-
165
- # Accumulate losses
166
- for k, v in loss_dict.items():
167
- epoch_losses[k] += v
168
- num_batches += 1
169
-
170
- # Log every 100 batches
171
- if batch_idx % 100 == 0:
172
- self._log_step(epoch, batch_idx, loss_dict)
173
-
174
- # Average losses
175
- for k in epoch_losses:
176
- epoch_losses[k] /= num_batches
177
-
178
- return dict(epoch_losses)
179
-
180
  @torch.no_grad()
181
- def validate(self, epoch: int) -> dict:
182
- """Validate the student model."""
183
  self.student.eval()
184
  self.teacher.eval()
185
-
186
- val_losses = defaultdict(float)
187
- angular_errors = []
188
- pitch_errors = []
189
- yaw_errors = []
190
- num_batches = 0
191
-
192
  for batch in self.val_loader:
193
- left_eye = batch['left_eye'].to(self.device)
194
- right_eye = batch['right_eye'].to(self.device)
195
- face_blurred = batch['face_blurred_gray'].to(self.device)
196
- face_gray = batch['face_gray'].to(self.device)
197
- pitch_target = batch['pitch'].to(self.device)
198
- yaw_target = batch['yaw'].to(self.device)
199
-
200
- # Teacher forward
201
- t_pitch, t_yaw, t_features = self.teacher(
202
- left_eye, right_eye, face_blurred
203
- )
204
- t_pitch_logits = self.teacher.pitch_head(t_features)
205
- t_yaw_logits = self.teacher.yaw_head(t_features)
206
-
207
- # Student forward
208
- s_pitch, s_yaw, s_features = self.student(face_gray)
209
- s_pitch_logits = self.student.pitch_head(s_features)
210
- s_yaw_logits = self.student.yaw_head(s_features)
211
-
212
- # Compute loss
213
- loss, loss_dict = self.distillation_loss(
214
- s_pitch, s_yaw,
215
- s_pitch_logits, s_yaw_logits,
216
- s_features,
217
- t_pitch, t_yaw,
218
- t_pitch_logits, t_yaw_logits,
219
- t_features,
220
- pitch_target, yaw_target,
221
- )
222
-
223
- for k, v in loss_dict.items():
224
- val_losses[k] += v
225
- num_batches += 1
226
-
227
- # Compute angular error
228
- angular_err = torch.sqrt(
229
- (s_pitch - pitch_target) ** 2 + (s_yaw - yaw_target) ** 2
230
- )
231
- angular_errors.extend(angular_err.cpu().tolist())
232
- pitch_errors.extend((s_pitch - pitch_target).abs().cpu().tolist())
233
- yaw_errors.extend((s_yaw - yaw_target).abs().cpu().tolist())
234
-
235
- for k in val_losses:
236
- val_losses[k] /= num_batches
237
-
238
- val_losses['angular_error_mean'] = np.mean(angular_errors)
239
- val_losses['angular_error_std'] = np.std(angular_errors)
240
- val_losses['pitch_error_mean'] = np.mean(pitch_errors)
241
- val_losses['yaw_error_mean'] = np.mean(yaw_errors)
242
-
243
- return dict(val_losses)
244
-
245
- def _log_step(self, epoch, batch_idx, loss_dict):
246
- """Log training step metrics."""
247
- msg = f"Epoch {epoch} | Batch {batch_idx} | "
248
- msg += " | ".join(f"{k}={v:.4f}" for k, v in loss_dict.items())
249
- print(msg)
250
-
251
- if HAS_TRACKIO:
252
- for k, v in loss_dict.items():
253
- trackio.log({f"train/{k}": v})
254
-
255
- def _log_epoch(self, epoch, train_losses, val_losses):
256
- """Log epoch metrics."""
257
- print(f"\n{'='*60}")
258
- print(f"Epoch {epoch} Summary:")
259
- print(f" Train: ", " | ".join(f"{k}={v:.4f}" for k, v in train_losses.items()))
260
- print(f" Val: ", " | ".join(f"{k}={v:.4f}" for k, v in val_losses.items()))
261
- print(f"{'='*60}\n")
262
-
263
- if HAS_TRACKIO:
264
- for k, v in train_losses.items():
265
- trackio.log({f"epoch/train_{k}": v}, step=epoch)
266
- for k, v in val_losses.items():
267
- trackio.log({f"epoch/val_{k}": v}, step=epoch)
268
-
269
- # Alert on overfitting
270
- if epoch > 10 and val_losses.get('loss_total', 0) > self.best_val_loss * 1.3:
271
- trackio.alert(
272
- "Possible Overfitting",
273
- f"Val loss {val_losses['loss_total']:.4f} >> best {self.best_val_loss:.4f} at epoch {epoch}",
274
- level="WARN",
275
- )
276
-
277
- def train(self, save_dir: str = "./checkpoints"):
278
- """Full training loop."""
279
  os.makedirs(save_dir, exist_ok=True)
280
-
281
- print(f"Starting distillation training for {self.epochs} epochs")
282
- print(f"Student parameters: {count_parameters(self.student):,}")
283
- print(f"Device: {self.device}")
284
-
285
- start_time = time.time()
286
-
287
  for epoch in range(self.epochs):
288
- epoch_start = time.time()
289
-
290
- # Train
291
- train_losses = self.train_epoch(epoch)
292
-
293
- # Validate
294
- val_losses = self.validate(epoch)
295
-
296
- # Step scheduler
297
- self.scheduler.step()
298
- current_lr = self.optimizer.param_groups[0]['lr']
299
-
300
- # Log
301
- self._log_epoch(epoch, train_losses, val_losses)
302
-
303
- # Track metrics
304
- for k, v in train_losses.items():
305
- self.metrics_history[f'train_{k}'].append(v)
306
- for k, v in val_losses.items():
307
- self.metrics_history[f'val_{k}'].append(v)
308
-
309
- # Save best model
310
- val_total = val_losses.get('loss_total', val_losses.get('angular_error_mean', float('inf')))
311
- if val_total < self.best_val_loss:
312
- self.best_val_loss = val_total
313
  self.best_epoch = epoch
314
-
315
- torch.save({
316
- 'epoch': epoch,
317
- 'student_state_dict': self.student.state_dict(),
318
- 'optimizer_state_dict': self.optimizer.state_dict(),
319
- 'best_val_loss': self.best_val_loss,
320
- 'metrics_history': dict(self.metrics_history),
321
- }, os.path.join(save_dir, 'student_best.pt'))
322
-
323
- if HAS_TRACKIO:
324
- trackio.alert(
325
- "New Best Model",
326
- f"Val loss: {val_total:.4f} at epoch {epoch} (angular: {val_losses.get('angular_error_mean', 0):.2f}°)",
327
- level="INFO",
328
- )
329
-
330
- # Save checkpoint every 10 epochs
331
  if epoch % 10 == 0:
332
- torch.save({
333
- 'epoch': epoch,
334
- 'student_state_dict': self.student.state_dict(),
335
- 'optimizer_state_dict': self.optimizer.state_dict(),
336
- }, os.path.join(save_dir, f'student_epoch_{epoch}.pt'))
337
-
338
- epoch_time = time.time() - epoch_start
339
- print(f"Epoch {epoch} took {epoch_time:.1f}s, LR: {current_lr:.2e}")
340
-
341
- total_time = time.time() - start_time
342
- print(f"\nTraining complete! Total time: {total_time/3600:.1f}h")
343
- print(f"Best validation loss: {self.best_val_loss:.4f} at epoch {self.best_epoch}")
344
-
345
- if HAS_TRACKIO:
346
- trackio.alert(
347
- "Training Complete",
348
- f"Best val loss: {self.best_val_loss:.4f} at epoch {self.best_epoch}. "
349
- f"Student params: {count_parameters(self.student):,}",
350
- level="INFO",
351
- )
352
-
353
- return self.best_val_loss
354
 
355
 
356
- def pretrain_teacher(
357
- teacher: PriviGazeTeacher,
358
- train_loader,
359
- val_loader,
360
- device: torch.device,
361
- lr: float = 1e-4,
362
- epochs: int = 50,
363
- save_dir: str = "./checkpoints",
364
- ) -> str:
365
- """Pre-train the teacher model on privileged data."""
366
- from models.distillation_loss import L2CSLoss, AngularLoss
367
-
368
  teacher = teacher.to(device)
369
- teacher.train()
370
-
371
- optimizer = AdamW(teacher.parameters(), lr=lr, weight_decay=1e-4)
372
- scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=lr * 0.01)
373
-
374
- pitch_loss_fn = L2CSLoss(gaze_bins=90)
375
- yaw_loss_fn = L2CSLoss(gaze_bins=90)
376
- angular_loss_fn = AngularLoss()
377
-
378
- best_val_loss = float('inf')
379
  os.makedirs(save_dir, exist_ok=True)
380
-
381
  for epoch in range(epochs):
382
- # Training
383
  teacher.train()
384
- train_loss_total = 0.0
385
  for batch in train_loader:
386
- left_eye = batch['left_eye'].to(device)
387
- right_eye = batch['right_eye'].to(device)
388
- face_blurred = batch['face_blurred_gray'].to(device)
389
- pitch_target = batch['pitch'].to(device)
390
- yaw_target = batch['yaw'].to(device)
391
-
392
- pitch_pred, yaw_pred, features = teacher(left_eye, right_eye, face_blurred)
393
- pitch_logits = teacher.pitch_head(features)
394
- yaw_logits = teacher.yaw_head(features)
395
-
396
- loss = (pitch_loss_fn(pitch_logits, pitch_pred, pitch_target) +
397
- yaw_loss_fn(yaw_logits, yaw_pred, yaw_target) +
398
- angular_loss_fn(pitch_pred, yaw_pred, pitch_target, yaw_target))
399
-
400
- optimizer.zero_grad()
401
  loss.backward()
402
- torch.nn.utils.clip_grad_norm_(teacher.parameters(), max_norm=1.0)
403
- optimizer.step()
404
-
405
- train_loss_total += loss.item()
406
-
407
- train_loss_total /= len(train_loader)
408
-
409
- # Validation
410
  teacher.eval()
411
- val_loss_total = 0.0
412
- val_angular = 0.0
413
  with torch.no_grad():
414
  for batch in val_loader:
415
- left_eye = batch['left_eye'].to(device)
416
- right_eye = batch['right_eye'].to(device)
417
- face_blurred = batch['face_blurred_gray'].to(device)
418
- pitch_target = batch['pitch'].to(device)
419
- yaw_target = batch['yaw'].to(device)
420
-
421
- pitch_pred, yaw_pred, features = teacher(left_eye, right_eye, face_blurred)
422
- pitch_logits = teacher.pitch_head(features)
423
- yaw_logits = teacher.yaw_head(features)
424
-
425
- loss = (pitch_loss_fn(pitch_logits, pitch_pred, pitch_target) +
426
- yaw_loss_fn(yaw_logits, yaw_pred, yaw_target))
427
- val_loss_total += loss.item()
428
-
429
- angular_err = torch.sqrt((pitch_pred - pitch_target)**2 + (yaw_pred - yaw_target)**2)
430
- val_angular += angular_err.mean().item()
431
-
432
- val_loss_total /= len(val_loader)
433
- val_angular /= len(val_loader)
434
-
435
- scheduler.step()
436
-
437
- print(f"Teacher Epoch {epoch}: train_loss={train_loss_total:.4f}, "
438
- f"val_loss={val_loss_total:.4f}, val_angular={val_angular:.2f}°")
439
-
440
- if val_loss_total < best_val_loss:
441
- best_val_loss = val_loss_total
442
  torch.save(teacher.state_dict(), os.path.join(save_dir, 'teacher_best.pt'))
443
-
444
  return os.path.join(save_dir, 'teacher_best.pt')
445
 
446
 
447
  def main():
448
- parser = argparse.ArgumentParser(description="PriviGaze Distillation Training")
449
- parser.add_argument('--mode', type=str, default='distill',
450
- choices=['pretrain_teacher', 'distill', 'both'],
451
- help='Training mode')
452
- parser.add_argument('--teacher-path', type=str, default=None,
453
- help='Path to pre-trained teacher checkpoint')
454
- parser.add_argument('--batch-size', type=int, default=32,
455
- help='Batch size')
456
- parser.add_argument('--epochs', type=int, default=100,
457
- help='Number of distillation epochs')
458
- parser.add_argument('--teacher-epochs', type=int, default=50,
459
- help='Number of teacher pre-training epochs')
460
- parser.add_argument('--lr', type=float, default=1e-4,
461
- help='Learning rate')
462
- parser.add_argument('--weight-decay', type=float, default=1e-4,
463
- help='Weight decay')
464
- parser.add_argument('--num-train', type=int, default=40000,
465
- help='Number of synthetic training samples')
466
- parser.add_argument('--num-val', type=int, default=5000,
467
- help='Number of synthetic val samples')
468
- parser.add_argument('--save-dir', type=str, default='./checkpoints',
469
- help='Directory to save checkpoints')
470
- parser.add_argument('--device', type=str, default='cuda',
471
- help='Device to train on')
472
- parser.add_argument('--trackio-project', type=str, default='privi-gaze',
473
- help='Trackio project name')
474
- parser.add_argument('--trackio-run', type=str, default='distill-run',
475
- help='Trackio run name')
476
- parser.add_argument('--push-to-hub', action='store_true',
477
- help='Push trained model to HF Hub')
478
- parser.add_argument('--hub-model-id', type=str, default=None,
479
- help='HF Hub model ID for pushing')
480
- parser.add_argument('--alpha-contrastive', type=float, default=0.5,
481
- help='Weight for contrastive distillation loss')
482
- parser.add_argument('--alpha-mmd', type=float, default=0.1,
483
- help='Weight for MMD distribution matching loss')
484
- parser.add_argument('--alpha-logit', type=float, default=0.5,
485
- help='Weight for logit distillation loss')
486
-
487
- args = parser.parse_args()
488
-
489
- # Device setup
490
  device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
491
- print(f"Using device: {device}")
492
-
493
- # Create dataloaders
494
  train_loader, val_loader, test_loader = create_dataloaders(
495
- num_train=args.num_train,
496
- num_val=args.num_val,
497
- batch_size=args.batch_size,
498
- )
499
-
500
- # Initialize models
501
- teacher = PriviGazeTeacher(
502
- eye_backbone="facebook/convnextv2-atto-1k-224",
503
- face_backbone="facebook/convnextv2-nano-22k-384",
504
- )
505
-
506
  student = PriviGazeStudent()
507
-
508
- print(f"Teacher parameters: {count_parameters(teacher):,}")
509
- print(f"Student parameters: {count_parameters(student):,}")
510
-
511
- # Pre-train teacher if needed
512
  if args.mode in ['pretrain_teacher', 'both']:
513
- print("\n=== Phase 1: Pre-training Teacher ===")
514
- teacher_path = pretrain_teacher(
515
- teacher, train_loader, val_loader, device,
516
- lr=args.lr, epochs=args.teacher_epochs,
517
- save_dir=args.save_dir,
518
- )
519
- print(f"Teacher saved to: {teacher_path}")
520
- args.teacher_path = teacher_path
521
-
522
- # Load teacher checkpoint
523
  if args.teacher_path:
524
- print(f"\nLoading teacher from: {args.teacher_path}")
525
  teacher.load_state_dict(torch.load(args.teacher_path, map_location=device))
526
-
527
- # Distill
528
  if args.mode in ['distill', 'both']:
529
- print("\n=== Phase 2: Privileged Distillation ===")
530
-
531
- # Create distillation loss
532
- dist_loss = PriviGazeDistillationLoss(
533
- gaze_bins=90,
534
- teacher_feature_dim=256,
535
- student_feature_dim=128,
536
- alpha_contrastive=args.alpha_contrastive,
537
- alpha_mmd=args.alpha_mmd,
538
- alpha_logit=args.alpha_logit,
539
- )
540
-
541
- # Create trainer
542
- trainer = DistillationTrainer(
543
- teacher=teacher,
544
- student=student,
545
- distillation_loss=dist_loss,
546
- train_loader=train_loader,
547
- val_loader=val_loader,
548
- device=device,
549
- lr=args.lr,
550
- weight_decay=args.weight_decay,
551
- epochs=args.epochs,
552
- trackio_project=args.trackio_project,
553
- trackio_run_name=args.trackio_run,
554
- )
555
-
556
- # Train
557
- best_loss = trainer.train(save_dir=args.save_dir)
558
-
559
- # Test evaluation
560
- print("\n=== Final Test Evaluation ===")
561
- student.eval()
562
- student.to(device)
563
-
564
- test_angular_errors = []
565
  with torch.no_grad():
566
  for batch in test_loader:
567
- face_gray = batch['face_gray'].to(device)
568
- pitch_target = batch['pitch'].to(device)
569
- yaw_target = batch['yaw'].to(device)
570
-
571
- pitch_pred, yaw_pred, _ = student(face_gray)
572
-
573
- angular_err = torch.sqrt(
574
- (pitch_pred - pitch_target) ** 2 + (yaw_pred - yaw_target) ** 2
575
- )
576
- test_angular_errors.extend(angular_err.cpu().tolist())
577
-
578
- mean_error = np.mean(test_angular_errors)
579
- std_error = np.std(test_angular_errors)
580
- print(f"Test Angular Error: {mean_error:.2f}° ± {std_error:.2f}°")
581
-
582
- if HAS_TRACKIO:
583
- trackio.log({
584
- 'test/angular_error_mean': mean_error,
585
- 'test/angular_error_std': std_error,
586
- })
587
- trackio.alert(
588
- "Test Results",
589
- f"Angular error: {mean_error:.2f}° ± {std_error:.2f}°. "
590
- f"Student params: {count_parameters(student):,}",
591
- level="INFO",
592
- )
593
-
594
- # Push to hub
595
  if args.push_to_hub and args.hub_model_id:
596
  from huggingface_hub import HfApi
597
- api = HfApi()
598
-
599
- # Save final model
600
- model_path = os.path.join(args.save_dir, 'student_final.pt')
601
- torch.save({
602
- 'student_state_dict': student.state_dict(),
603
- 'config': {
604
- 'params': count_parameters(student),
605
- 'test_angular_error': mean_error,
606
- }
607
- }, model_path)
608
-
609
- # Upload
610
- api.upload_file(
611
- path_or_fileobj=model_path,
612
- path_in_repo="student_model.pt",
613
- repo_id=args.hub_model_id,
614
- )
615
- print(f"Model pushed to: https://huggingface.co/{args.hub_model_id}")
616
-
617
- return best_loss if args.mode in ['distill', 'both'] else None
618
-
619
 
620
  if __name__ == "__main__":
621
  main()
 
1
  """
2
  PriviGaze Training Script - Privileged Distillation for Gaze Estimation
 
 
 
 
 
 
 
 
 
 
3
  """
4
+ import os, sys, argparse, time
 
 
 
 
5
  from pathlib import Path
6
  from collections import defaultdict
 
7
  import torch
 
8
  from torch.optim import AdamW
9
+ from torch.optim.lr_scheduler import CosineAnnealingLR
10
  import numpy as np
11
 
 
12
  sys.path.insert(0, str(Path(__file__).parent))
 
13
  from models.teacher import PriviGazeTeacher
14
  from models.student import PriviGazeStudent, count_parameters
15
+ from models.distillation_loss import PriviGazeDistillationLoss, L2CSLoss, AngularLoss
16
+ from models.dataset import create_dataloaders
17
 
 
18
  try:
19
+ import trackio; HAS_TRACKIO = True
 
20
  except ImportError:
21
+ HAS_TRACKIO = False; print("Warning: trackio not installed.")
 
22
 
23
 
24
  class DistillationTrainer:
25
+ def __init__(self, teacher, student, dist_loss, train_loader, val_loader,
26
+ device, lr=1e-4, wd=1e-4, epochs=100, tproj="privi-gaze", trun="distill"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  self.teacher = teacher.to(device)
28
  self.student = student.to(device)
29
+ self.dist_loss = dist_loss.to(device)
30
  self.train_loader = train_loader
31
  self.val_loader = val_loader
32
  self.device = device
33
  self.epochs = epochs
34
+ for p in self.teacher.parameters(): p.requires_grad = False
35
+ self.teacher.eval()
36
+ self.opt = AdamW(self.student.parameters(), lr=lr, weight_decay=wd)
37
+ self.sched = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=lr*0.01)
38
+ self.best_val = float('inf')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  self.best_epoch = 0
40
+ self.metrics = defaultdict(list)
 
 
 
 
41
  if HAS_TRACKIO:
42
+ trackio.init(project=tproj, run_name=trun,
43
+ config={'student_params': count_parameters(student),
44
+ 'teacher_params': count_parameters(teacher), 'lr': lr, 'epochs': epochs})
45
+
46
+ def train_epoch(self, epoch):
 
 
 
 
 
 
 
 
 
 
47
  self.student.train()
48
+ losses = defaultdict(float)
49
+ n = 0
50
+ for bi, batch in enumerate(self.train_loader):
51
+ le = batch['left_eye'].to(self.device)
52
+ re = batch['right_eye'].to(self.device)
53
+ fb = batch['face_blurred_gray'].to(self.device)
54
+ fg = batch['face_gray'].to(self.device)
55
+ pt = batch['pitch'].to(self.device)
56
+ yt = batch['yaw'].to(self.device)
 
 
 
 
57
  with torch.no_grad():
58
+ tp, ty, tplog, tylog, tf = self.teacher(le, re, fb)
59
+ sp, sy, sf = self.student(fg)
60
+ splog = self.student.pitch_head(sf)
61
+ sylog = self.student.yaw_head(sf)
62
+ loss, ld = self.dist_loss(sp, sy, splog, sylog, sf,
63
+ tp, ty, tplog, tylog, tf, pt, yt)
64
+ self.opt.zero_grad()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  loss.backward()
66
+ torch.nn.utils.clip_grad_norm_(self.student.parameters(), 1.0)
67
+ self.opt.step()
68
+ for k, v in ld.items(): losses[k] += v
69
+ n += 1
70
+ if bi % 100 == 0:
71
+ print(f"Epoch {epoch} | Batch {bi} | " + " | ".join(f"{k}={v:.4f}" for k, v in ld.items()))
72
+ if HAS_TRACKIO:
73
+ for k2, v2 in ld.items(): trackio.log({f"train/{k2}": v2})
74
+ return {k: v/n for k, v in losses.items()}
75
+
 
 
 
 
 
 
 
 
 
 
 
76
  @torch.no_grad()
77
+ def validate(self, epoch):
 
78
  self.student.eval()
79
  self.teacher.eval()
80
+ losses = defaultdict(float)
81
+ ae, pe, ye = [], [], []
82
+ n = 0
 
 
 
 
83
  for batch in self.val_loader:
84
+ le = batch['left_eye'].to(self.device)
85
+ re = batch['right_eye'].to(self.device)
86
+ fb = batch['face_blurred_gray'].to(self.device)
87
+ fg = batch['face_gray'].to(self.device)
88
+ pt = batch['pitch'].to(self.device)
89
+ yt = batch['yaw'].to(self.device)
90
+ tp, ty, tplog, tylog, tf = self.teacher(le, re, fb)
91
+ sp, sy, sf = self.student(fg)
92
+ splog = self.student.pitch_head(sf)
93
+ sylog = self.student.yaw_head(sf)
94
+ loss, ld = self.dist_loss(sp, sy, splog, sylog, sf,
95
+ tp, ty, tplog, tylog, tf, pt, yt)
96
+ for k, v in ld.items(): losses[k] += v
97
+ n += 1
98
+ aerr = torch.sqrt((sp-pt)**2 + (sy-yt)**2)
99
+ ae.extend(aerr.cpu().tolist())
100
+ pe.extend((sp-pt).abs().cpu().tolist())
101
+ ye.extend((sy-yt).abs().cpu().tolist())
102
+ for k in losses: losses[k] /= n
103
+ losses['angular_mean'] = np.mean(ae)
104
+ losses['angular_std'] = np.std(ae)
105
+ losses['pitch_mean'] = np.mean(pe)
106
+ losses['yaw_mean'] = np.mean(ye)
107
+ return dict(losses)
108
+
109
+ def train(self, save_dir="./checkpoints"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  os.makedirs(save_dir, exist_ok=True)
111
+ print(f"Distillation: {self.epochs} epochs | Student: {count_parameters(self.student):,} params")
112
+ t0 = time.time()
 
 
 
 
 
113
  for epoch in range(self.epochs):
114
+ te = time.time()
115
+ tl = self.train_epoch(epoch)
116
+ vl = self.validate(epoch)
117
+ self.sched.step()
118
+ lr = self.opt.param_groups[0]['lr']
119
+ print(f"\n{'='*60}")
120
+ print(f"Epoch {epoch}: train={tl.get('loss_total',0):.4f} val={vl.get('loss_total',0):.4f} angular={vl.get('angular_mean',0):.2f}deg")
121
+ print(f"{'='*60}\n")
122
+ for k, v in tl.items(): self.metrics[f'train_{k}'].append(v)
123
+ for k, v in vl.items(): self.metrics[f'val_{k}'].append(v)
124
+ vt = vl.get('loss_total', vl.get('angular_mean', float('inf')))
125
+ if vt < self.best_val:
126
+ self.best_val = vt
 
 
 
 
 
 
 
 
 
 
 
 
127
  self.best_epoch = epoch
128
+ torch.save({'epoch': epoch, 'student_state_dict': self.student.state_dict(),
129
+ 'opt_state_dict': self.opt.state_dict(), 'best_val': self.best_val,
130
+ 'metrics': dict(self.metrics)}, os.path.join(save_dir, 'student_best.pt'))
131
+ if HAS_TRACKIO: trackio.alert("New Best", f"Val {vt:.4f} @ epoch {epoch}", level="INFO")
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  if epoch % 10 == 0:
133
+ torch.save({'epoch': epoch, 'student_state_dict': self.student.state_dict(),
134
+ 'opt_state_dict': self.opt.state_dict()},
135
+ os.path.join(save_dir, f'student_epoch_{epoch}.pt'))
136
+ print(f"Epoch {epoch} took {time.time()-te:.1f}s, LR={lr:.2e}")
137
+ print(f"\nDone! Best val: {self.best_val:.4f} @ epoch {self.best_epoch}")
138
+ return self.best_val
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
 
141
+ def pretrain_teacher(teacher, train_loader, val_loader, device, lr=1e-4, epochs=50, save_dir="./checkpoints"):
 
 
 
 
 
 
 
 
 
 
 
142
  teacher = teacher.to(device)
143
+ opt = AdamW(teacher.parameters(), lr=lr, weight_decay=1e-4)
144
+ sched = CosineAnnealingLR(opt, T_max=epochs, eta_min=lr*0.01)
145
+ ploss = L2CSLoss(gaze_bins=90)
146
+ yloss = L2CSLoss(gaze_bins=90)
147
+ aloss = AngularLoss()
148
+ best = float('inf')
 
 
 
 
149
  os.makedirs(save_dir, exist_ok=True)
 
150
  for epoch in range(epochs):
 
151
  teacher.train()
152
+ tloss = 0.0
153
  for batch in train_loader:
154
+ le = batch['left_eye'].to(device)
155
+ re = batch['right_eye'].to(device)
156
+ fb = batch['face_blurred_gray'].to(device)
157
+ pt = batch['pitch'].to(device)
158
+ yt = batch['yaw'].to(device)
159
+ pp, yp, pl, yl, _ = teacher(le, re, fb)
160
+ loss = ploss(pl, pp, pt) + yloss(yl, yp, yt) + aloss(pp, yp, pt, yt)
161
+ opt.zero_grad()
 
 
 
 
 
 
 
162
  loss.backward()
163
+ torch.nn.utils.clip_grad_norm_(teacher.parameters(), 1.0)
164
+ opt.step()
165
+ tloss += loss.item()
166
+ tloss /= len(train_loader)
 
 
 
 
167
  teacher.eval()
168
+ vloss = 0.0
169
+ va = 0.0
170
  with torch.no_grad():
171
  for batch in val_loader:
172
+ le = batch['left_eye'].to(device)
173
+ re = batch['right_eye'].to(device)
174
+ fb = batch['face_blurred_gray'].to(device)
175
+ pt = batch['pitch'].to(device)
176
+ yt = batch['yaw'].to(device)
177
+ pp, yp, pl, yl, _ = teacher(le, re, fb)
178
+ vloss += (ploss(pl, pp, pt) + yloss(yl, yp, yt)).item()
179
+ va += torch.sqrt((pp-pt)**2 + (yp-yt)**2).mean().item()
180
+ vloss /= len(val_loader)
181
+ va /= len(val_loader)
182
+ sched.step()
183
+ print(f"Teacher Epoch {epoch}: train={tloss:.4f} val={vloss:.4f} angular={va:.2f}deg")
184
+ if vloss < best:
185
+ best = vloss
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  torch.save(teacher.state_dict(), os.path.join(save_dir, 'teacher_best.pt'))
 
187
  return os.path.join(save_dir, 'teacher_best.pt')
188
 
189
 
190
  def main():
191
+ p = argparse.ArgumentParser(description="PriviGaze Training")
192
+ p.add_argument('--mode', type=str, default='distill', choices=['pretrain_teacher','distill','both'])
193
+ p.add_argument('--teacher-path', type=str, default=None)
194
+ p.add_argument('--batch-size', type=int, default=32)
195
+ p.add_argument('--epochs', type=int, default=100)
196
+ p.add_argument('--teacher-epochs', type=int, default=50)
197
+ p.add_argument('--lr', type=float, default=1e-4)
198
+ p.add_argument('--weight-decay', type=float, default=1e-4)
199
+ p.add_argument('--num-train', type=int, default=40000)
200
+ p.add_argument('--num-val', type=int, default=5000)
201
+ p.add_argument('--save-dir', type=str, default='./checkpoints')
202
+ p.add_argument('--device', type=str, default='cuda')
203
+ p.add_argument('--trackio-project', type=str, default='privi-gaze')
204
+ p.add_argument('--trackio-run', type=str, default='distill-run')
205
+ p.add_argument('--push-to-hub', action='store_true')
206
+ p.add_argument('--hub-model-id', type=str, default=None)
207
+ p.add_argument('--alpha-contrastive', type=float, default=0.5)
208
+ p.add_argument('--alpha-mmd', type=float, default=0.1)
209
+ p.add_argument('--alpha-logit', type=float, default=0.5)
210
+ args = p.parse_args()
211
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
213
+ print(f"Device: {device}")
 
 
214
  train_loader, val_loader, test_loader = create_dataloaders(
215
+ num_train=args.num_train, num_val=args.num_val, batch_size=args.batch_size)
216
+
217
+ teacher = PriviGazeTeacher()
 
 
 
 
 
 
 
 
218
  student = PriviGazeStudent()
219
+ print(f"Teacher: {count_parameters(teacher):,} params")
220
+ print(f"Student: {count_parameters(student):,} params")
221
+
 
 
222
  if args.mode in ['pretrain_teacher', 'both']:
223
+ print("\n=== Phase 1: Teacher Pre-training ===")
224
+ tp = pretrain_teacher(teacher, train_loader, val_loader, device,
225
+ lr=args.lr, epochs=args.teacher_epochs, save_dir=args.save_dir)
226
+ args.teacher_path = tp
227
+
 
 
 
 
 
228
  if args.teacher_path:
229
+ print(f"\nLoading teacher: {args.teacher_path}")
230
  teacher.load_state_dict(torch.load(args.teacher_path, map_location=device))
231
+
 
232
  if args.mode in ['distill', 'both']:
233
+ print("\n=== Phase 2: Distillation ===")
234
+ dloss = PriviGazeDistillationLoss(
235
+ gaze_bins=90, teacher_feature_dim=256, student_feature_dim=128,
236
+ alpha_contrastive=args.alpha_contrastive, alpha_mmd=args.alpha_mmd,
237
+ alpha_logit=args.alpha_logit)
238
+ trainer = DistillationTrainer(teacher, student, dloss, train_loader, val_loader,
239
+ device, lr=args.lr, wd=args.weight_decay, epochs=args.epochs,
240
+ tproj=args.trackio_project, trun=args.trackio_run)
241
+ trainer.train(save_dir=args.save_dir)
242
+
243
+ print("\n=== Test ===")
244
+ student.eval().to(device)
245
+ terr = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  with torch.no_grad():
247
  for batch in test_loader:
248
+ fg = batch['face_gray'].to(device)
249
+ pt = batch['pitch'].to(device)
250
+ yt = batch['yaw'].to(device)
251
+ sp, sy, _ = student(fg)
252
+ terr.extend(torch.sqrt((sp-pt)**2 + (sy-yt)**2).cpu().tolist())
253
+ me = np.mean(terr); se = np.std(terr)
254
+ print(f"Test Angular Error: {me:.2f}deg +- {se:.2f}deg")
255
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  if args.push_to_hub and args.hub_model_id:
257
  from huggingface_hub import HfApi
258
+ mp = os.path.join(args.save_dir, 'student_final.pt')
259
+ torch.save({'student_state_dict': student.state_dict(),
260
+ 'config': {'params': count_parameters(student), 'test_err': me}}, mp)
261
+ HfApi().upload_file(path_or_fileobj=mp, path_in_repo="student_model.pt", repo_id=args.hub_model_id)
262
+ print(f"Pushed to: https://huggingface.co/{args.hub_model_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
  if __name__ == "__main__":
265
  main()