BcantCode committed on
Commit
d945fba
·
verified ·
1 Parent(s): b39769f

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +621 -0
train.py ADDED
@@ -0,0 +1,621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PriviGaze Training Script - Privileged Distillation for Gaze Estimation
3
+
4
+ Two-phase training:
5
+ 1. Teacher pre-training: Train teacher on privileged data (RGB eyes + blurred face)
6
+ 2. Student distillation: Train student with privileged distillation loss
7
+
8
+ This script implements Phase 2 (distillation). Phase 1 (teacher pre-training)
9
+ should be run first to produce a strong teacher model.
10
+
11
+ Usage:
12
+ python train.py --mode distill --teacher-path ./teacher_best.pt --epochs 100
13
+ """
14
+
15
+ import os
16
+ import sys
17
+ import argparse
18
+ import time
19
+ from pathlib import Path
20
+ from collections import defaultdict
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ from torch.optim import AdamW
25
+ from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
26
+ import numpy as np
27
+
28
+ # Add parent directory to path
29
+ sys.path.insert(0, str(Path(__file__).parent))
30
+
31
+ from models.teacher import PriviGazeTeacher
32
+ from models.student import PriviGazeStudent, count_parameters
33
+ from models.distillation_loss import PriviGazeDistillationLoss
34
+ from models.dataset import create_dataloaders, SyntheticGazeDataset
35
+
36
+ # Trackio for experiment monitoring
37
+ try:
38
+ import trackio
39
+ HAS_TRACKIO = True
40
+ except ImportError:
41
+ HAS_TRACKIO = False
42
+ print("Warning: trackio not installed. Logging to stdout only.")
43
+
44
+
45
class DistillationTrainer:
    """Trains student model via privileged distillation from teacher.

    The teacher consumes privileged inputs (both eye crops plus a blurred
    grayscale face) while the student sees only the plain grayscale face;
    the distillation loss pulls the student's predictions, logits, and
    features toward the teacher's. Only the student is optimized.
    """

    def __init__(
        self,
        teacher: PriviGazeTeacher,
        student: PriviGazeStudent,
        distillation_loss: PriviGazeDistillationLoss,
        train_loader,
        val_loader,
        device: torch.device,
        lr: float = 1e-4,
        weight_decay: float = 1e-4,
        epochs: int = 100,
        teacher_frozen: bool = True,
        trackio_project: str = "privi-gaze",
        trackio_run_name: str = "distill",
    ):
        self.teacher = teacher.to(device)
        self.student = student.to(device)
        self.distillation_loss = distillation_loss.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.epochs = epochs
        self.trackio_project = trackio_project
        self.trackio_run_name = trackio_run_name

        # Freeze the teacher so gradients never flow into it; eval() also
        # fixes dropout/batch-norm behavior for stable targets.
        if teacher_frozen:
            for param in self.teacher.parameters():
                param.requires_grad = False
            self.teacher.eval()

        # Optimizer: only student parameters
        self.optimizer = AdamW(
            self.student.parameters(),
            lr=lr,
            weight_decay=weight_decay,
        )

        # Cosine decay over the full run, ending at 1% of the initial lr.
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=epochs,
            eta_min=lr * 0.01,
        )

        # Track best model
        self.best_val_loss = float('inf')
        self.best_epoch = 0

        # Metrics tracking (per-epoch series keyed by metric name)
        self.metrics_history = defaultdict(list)

        # Initialize trackio
        if HAS_TRACKIO:
            trackio.init(
                project=trackio_project,
                run_name=trackio_run_name,
                config={
                    'student_params': count_parameters(self.student),
                    'teacher_params': count_parameters(self.teacher),
                    'lr': lr,
                    'weight_decay': weight_decay,
                    'epochs': epochs,
                    'batch_size': train_loader.batch_size,
                }
            )

    def train_epoch(self, epoch: int) -> dict:
        """Train for one epoch; returns batch-averaged loss components."""
        self.student.train()
        epoch_losses = defaultdict(float)
        num_batches = 0

        for batch_idx, batch in enumerate(self.train_loader):
            # Move to device
            left_eye = batch['left_eye'].to(self.device)
            right_eye = batch['right_eye'].to(self.device)
            face_blurred = batch['face_blurred_gray'].to(self.device)
            face_gray = batch['face_gray'].to(self.device)
            pitch_target = batch['pitch'].to(self.device)
            yaw_target = batch['yaw'].to(self.device)

            # Teacher forward (no grad)
            with torch.no_grad():
                t_pitch, t_yaw, t_features = self.teacher(
                    left_eye, right_eye, face_blurred
                )
                # Get teacher logits by running forward through heads
                # (We need these for logit distillation)
                # We extract them from the teacher's internal state
                t_pitch_logits = self.teacher.pitch_head(t_features)
                t_yaw_logits = self.teacher.yaw_head(t_features)

            # Student forward (student sees only the grayscale face)
            s_pitch, s_yaw, s_features = self.student(face_gray)
            s_pitch_logits = self.student.pitch_head(s_features)
            s_yaw_logits = self.student.yaw_head(s_features)

            # Compute distillation loss
            loss, loss_dict = self.distillation_loss(
                s_pitch, s_yaw,
                s_pitch_logits, s_yaw_logits,
                s_features,
                t_pitch, t_yaw,
                t_pitch_logits, t_yaw_logits,
                t_features,
                pitch_target, yaw_target,
            )

            # Backward
            self.optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.student.parameters(), max_norm=1.0)

            self.optimizer.step()

            # Accumulate losses
            for k, v in loss_dict.items():
                epoch_losses[k] += v
            num_batches += 1

            # Log every 100 batches
            if batch_idx % 100 == 0:
                self._log_step(epoch, batch_idx, loss_dict)

        # Average losses
        for k in epoch_losses:
            epoch_losses[k] /= num_batches

        return dict(epoch_losses)

    @torch.no_grad()
    def validate(self, epoch: int) -> dict:
        """Validate the student model.

        Returns batch-averaged loss components plus angular/pitch/yaw
        error statistics over the whole validation set.
        """
        self.student.eval()
        self.teacher.eval()

        val_losses = defaultdict(float)
        angular_errors = []
        pitch_errors = []
        yaw_errors = []
        num_batches = 0

        for batch in self.val_loader:
            left_eye = batch['left_eye'].to(self.device)
            right_eye = batch['right_eye'].to(self.device)
            face_blurred = batch['face_blurred_gray'].to(self.device)
            face_gray = batch['face_gray'].to(self.device)
            pitch_target = batch['pitch'].to(self.device)
            yaw_target = batch['yaw'].to(self.device)

            # Teacher forward
            t_pitch, t_yaw, t_features = self.teacher(
                left_eye, right_eye, face_blurred
            )
            t_pitch_logits = self.teacher.pitch_head(t_features)
            t_yaw_logits = self.teacher.yaw_head(t_features)

            # Student forward
            s_pitch, s_yaw, s_features = self.student(face_gray)
            s_pitch_logits = self.student.pitch_head(s_features)
            s_yaw_logits = self.student.yaw_head(s_features)

            # Compute loss
            loss, loss_dict = self.distillation_loss(
                s_pitch, s_yaw,
                s_pitch_logits, s_yaw_logits,
                s_features,
                t_pitch, t_yaw,
                t_pitch_logits, t_yaw_logits,
                t_features,
                pitch_target, yaw_target,
            )

            for k, v in loss_dict.items():
                val_losses[k] += v
            num_batches += 1

            # Compute angular error — Euclidean distance in (pitch, yaw)
            # space, not a true 3D gaze angle; presumably the units match
            # the alert messages' degrees — TODO confirm against dataset.
            angular_err = torch.sqrt(
                (s_pitch - pitch_target) ** 2 + (s_yaw - yaw_target) ** 2
            )
            angular_errors.extend(angular_err.cpu().tolist())
            pitch_errors.extend((s_pitch - pitch_target).abs().cpu().tolist())
            yaw_errors.extend((s_yaw - yaw_target).abs().cpu().tolist())

        for k in val_losses:
            val_losses[k] /= num_batches

        val_losses['angular_error_mean'] = np.mean(angular_errors)
        val_losses['angular_error_std'] = np.std(angular_errors)
        val_losses['pitch_error_mean'] = np.mean(pitch_errors)
        val_losses['yaw_error_mean'] = np.mean(yaw_errors)

        return dict(val_losses)

    def _log_step(self, epoch: int, batch_idx: int, loss_dict: dict) -> None:
        """Log training step metrics to stdout and (if present) trackio."""
        msg = f"Epoch {epoch} | Batch {batch_idx} | "
        msg += " | ".join(f"{k}={v:.4f}" for k, v in loss_dict.items())
        print(msg)

        if HAS_TRACKIO:
            for k, v in loss_dict.items():
                trackio.log({f"train/{k}": v})

    def _log_epoch(self, epoch: int, train_losses: dict, val_losses: dict) -> None:
        """Log epoch metrics and alert on an apparent overfitting regression."""
        print(f"\n{'='*60}")
        print(f"Epoch {epoch} Summary:")
        print(f" Train: ", " | ".join(f"{k}={v:.4f}" for k, v in train_losses.items()))
        print(f" Val: ", " | ".join(f"{k}={v:.4f}" for k, v in val_losses.items()))
        print(f"{'='*60}\n")

        if HAS_TRACKIO:
            for k, v in train_losses.items():
                trackio.log({f"epoch/train_{k}": v}, step=epoch)
            for k, v in val_losses.items():
                trackio.log({f"epoch/val_{k}": v}, step=epoch)

            # Alert on overfitting: val loss drifting >30% above the best
            # seen so far, after a 10-epoch warm-up.
            if epoch > 10 and val_losses.get('loss_total', 0) > self.best_val_loss * 1.3:
                trackio.alert(
                    "Possible Overfitting",
                    f"Val loss {val_losses['loss_total']:.4f} >> best {self.best_val_loss:.4f} at epoch {epoch}",
                    level="WARN",
                )

    def train(self, save_dir: str = "./checkpoints") -> float:
        """Full training loop.

        Saves the best model (by validation loss) to
        ``save_dir/student_best.pt`` and a periodic checkpoint every 10
        epochs. Returns the best validation loss reached.
        """
        os.makedirs(save_dir, exist_ok=True)

        print(f"Starting distillation training for {self.epochs} epochs")
        print(f"Student parameters: {count_parameters(self.student):,}")
        print(f"Device: {self.device}")

        start_time = time.time()

        for epoch in range(self.epochs):
            epoch_start = time.time()

            # Train
            train_losses = self.train_epoch(epoch)

            # Validate
            val_losses = self.validate(epoch)

            # Step scheduler (once per epoch, matching T_max=epochs)
            self.scheduler.step()
            current_lr = self.optimizer.param_groups[0]['lr']

            # Log
            self._log_epoch(epoch, train_losses, val_losses)

            # Track metrics
            for k, v in train_losses.items():
                self.metrics_history[f'train_{k}'].append(v)
            for k, v in val_losses.items():
                self.metrics_history[f'val_{k}'].append(v)

            # Save best model; fall back to angular error if the loss dict
            # does not expose a 'loss_total' key.
            val_total = val_losses.get('loss_total', val_losses.get('angular_error_mean', float('inf')))
            if val_total < self.best_val_loss:
                self.best_val_loss = val_total
                self.best_epoch = epoch

                torch.save({
                    'epoch': epoch,
                    'student_state_dict': self.student.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'best_val_loss': self.best_val_loss,
                    'metrics_history': dict(self.metrics_history),
                }, os.path.join(save_dir, 'student_best.pt'))

                if HAS_TRACKIO:
                    trackio.alert(
                        "New Best Model",
                        f"Val loss: {val_total:.4f} at epoch {epoch} (angular: {val_losses.get('angular_error_mean', 0):.2f}°)",
                        level="INFO",
                    )

            # Save checkpoint every 10 epochs
            if epoch % 10 == 0:
                torch.save({
                    'epoch': epoch,
                    'student_state_dict': self.student.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                }, os.path.join(save_dir, f'student_epoch_{epoch}.pt'))

            epoch_time = time.time() - epoch_start
            print(f"Epoch {epoch} took {epoch_time:.1f}s, LR: {current_lr:.2e}")

        total_time = time.time() - start_time
        print(f"\nTraining complete! Total time: {total_time/3600:.1f}h")
        print(f"Best validation loss: {self.best_val_loss:.4f} at epoch {self.best_epoch}")

        if HAS_TRACKIO:
            trackio.alert(
                "Training Complete",
                f"Best val loss: {self.best_val_loss:.4f} at epoch {self.best_epoch}. "
                f"Student params: {count_parameters(self.student):,}",
                level="INFO",
            )

        return self.best_val_loss
+
355
+
356
def pretrain_teacher(
    teacher: PriviGazeTeacher,
    train_loader,
    val_loader,
    device: torch.device,
    lr: float = 1e-4,
    epochs: int = 50,
    save_dir: str = "./checkpoints",
) -> str:
    """Pre-train the teacher model on privileged data.

    Optimizes the teacher with the L2CS bin losses on pitch and yaw plus an
    angular regression term, validates each epoch (validation omits the
    angular term), and keeps the lowest-validation-loss checkpoint.

    Returns:
        Path to the best teacher checkpoint on disk.
    """
    from models.distillation_loss import L2CSLoss, AngularLoss

    teacher = teacher.to(device)
    teacher.train()

    opt = AdamW(teacher.parameters(), lr=lr, weight_decay=1e-4)
    lr_sched = CosineAnnealingLR(opt, T_max=epochs, eta_min=lr * 0.01)

    loss_pitch = L2CSLoss(gaze_bins=90)
    loss_yaw = L2CSLoss(gaze_bins=90)
    loss_angular = AngularLoss()

    best_val = float('inf')
    os.makedirs(save_dir, exist_ok=True)
    ckpt_path = os.path.join(save_dir, 'teacher_best.pt')

    def unpack(batch):
        # Move one batch's privileged inputs and gaze targets onto `device`.
        return (
            batch['left_eye'].to(device),
            batch['right_eye'].to(device),
            batch['face_blurred_gray'].to(device),
            batch['pitch'].to(device),
            batch['yaw'].to(device),
        )

    for epoch in range(epochs):
        # ---- training pass ----
        teacher.train()
        running = 0.0
        for batch in train_loader:
            l_eye, r_eye, face_blur, gt_pitch, gt_yaw = unpack(batch)

            pred_pitch, pred_yaw, feats = teacher(l_eye, r_eye, face_blur)
            logits_pitch = teacher.pitch_head(feats)
            logits_yaw = teacher.yaw_head(feats)

            total = (
                loss_pitch(logits_pitch, pred_pitch, gt_pitch)
                + loss_yaw(logits_yaw, pred_yaw, gt_yaw)
                + loss_angular(pred_pitch, pred_yaw, gt_pitch, gt_yaw)
            )

            opt.zero_grad()
            total.backward()
            torch.nn.utils.clip_grad_norm_(teacher.parameters(), max_norm=1.0)
            opt.step()

            running += total.item()

        train_avg = running / len(train_loader)

        # ---- validation pass ----
        teacher.eval()
        val_running = 0.0
        ang_running = 0.0
        with torch.no_grad():
            for batch in val_loader:
                l_eye, r_eye, face_blur, gt_pitch, gt_yaw = unpack(batch)

                pred_pitch, pred_yaw, feats = teacher(l_eye, r_eye, face_blur)
                logits_pitch = teacher.pitch_head(feats)
                logits_yaw = teacher.yaw_head(feats)

                val_running += (
                    loss_pitch(logits_pitch, pred_pitch, gt_pitch)
                    + loss_yaw(logits_yaw, pred_yaw, gt_yaw)
                ).item()

                err = torch.sqrt(
                    (pred_pitch - gt_pitch) ** 2 + (pred_yaw - gt_yaw) ** 2
                )
                ang_running += err.mean().item()

        val_avg = val_running / len(val_loader)
        ang_avg = ang_running / len(val_loader)

        lr_sched.step()

        print(f"Teacher Epoch {epoch}: train_loss={train_avg:.4f}, "
              f"val_loss={val_avg:.4f}, val_angular={ang_avg:.2f}°")

        # Keep only the best-so-far weights on disk.
        if val_avg < best_val:
            best_val = val_avg
            torch.save(teacher.state_dict(), ckpt_path)

    return ckpt_path
445
+
446
+
447
def main():
    """CLI entry point: pre-train the teacher and/or distill the student.

    Modes:
        pretrain_teacher -- Phase 1 only: train the privileged teacher.
        distill          -- Phase 2 only: distill into the student
                            (pass --teacher-path for a trained teacher).
        both             -- Phase 1 followed by Phase 2.

    Returns:
        Best distillation validation loss, or None when only the teacher
        was pre-trained.
    """
    parser = argparse.ArgumentParser(description="PriviGaze Distillation Training")
    parser.add_argument('--mode', type=str, default='distill',
                        choices=['pretrain_teacher', 'distill', 'both'],
                        help='Training mode')
    parser.add_argument('--teacher-path', type=str, default=None,
                        help='Path to pre-trained teacher checkpoint')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of distillation epochs')
    parser.add_argument('--teacher-epochs', type=int, default=50,
                        help='Number of teacher pre-training epochs')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='Learning rate')
    parser.add_argument('--weight-decay', type=float, default=1e-4,
                        help='Weight decay')
    parser.add_argument('--num-train', type=int, default=40000,
                        help='Number of synthetic training samples')
    parser.add_argument('--num-val', type=int, default=5000,
                        help='Number of synthetic val samples')
    parser.add_argument('--save-dir', type=str, default='./checkpoints',
                        help='Directory to save checkpoints')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device to train on')
    parser.add_argument('--trackio-project', type=str, default='privi-gaze',
                        help='Trackio project name')
    parser.add_argument('--trackio-run', type=str, default='distill-run',
                        help='Trackio run name')
    parser.add_argument('--push-to-hub', action='store_true',
                        help='Push trained model to HF Hub')
    parser.add_argument('--hub-model-id', type=str, default=None,
                        help='HF Hub model ID for pushing')
    parser.add_argument('--alpha-contrastive', type=float, default=0.5,
                        help='Weight for contrastive distillation loss')
    parser.add_argument('--alpha-mmd', type=float, default=0.1,
                        help='Weight for MMD distribution matching loss')
    parser.add_argument('--alpha-logit', type=float, default=0.5,
                        help='Weight for logit distillation loss')

    args = parser.parse_args()

    # Device setup. Fall back to CPU only when a CUDA device was requested
    # but is unavailable; previously ANY request (e.g. --device mps or
    # --device cpu) was forced through the CUDA-availability check.
    if args.device.startswith('cuda') and not torch.cuda.is_available():
        print("Warning: CUDA requested but not available; falling back to CPU.")
        device = torch.device('cpu')
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    # Create dataloaders
    train_loader, val_loader, test_loader = create_dataloaders(
        num_train=args.num_train,
        num_val=args.num_val,
        batch_size=args.batch_size,
    )

    # Initialize models
    teacher = PriviGazeTeacher(
        eye_backbone="facebook/convnextv2-atto-1k-224",
        face_backbone="facebook/convnextv2-nano-22k-384",
    )

    student = PriviGazeStudent()

    print(f"Teacher parameters: {count_parameters(teacher):,}")
    print(f"Student parameters: {count_parameters(student):,}")

    # Phase 1: pre-train teacher if requested
    if args.mode in ['pretrain_teacher', 'both']:
        print("\n=== Phase 1: Pre-training Teacher ===")
        teacher_path = pretrain_teacher(
            teacher, train_loader, val_loader, device,
            lr=args.lr, epochs=args.teacher_epochs,
            save_dir=args.save_dir,
        )
        print(f"Teacher saved to: {teacher_path}")
        args.teacher_path = teacher_path

    # Load teacher checkpoint
    if args.teacher_path:
        print(f"\nLoading teacher from: {args.teacher_path}")
        teacher.load_state_dict(torch.load(args.teacher_path, map_location=device))
    elif args.mode == 'distill':
        # Distilling from a randomly initialized teacher is almost certainly
        # a mistake; warn loudly instead of proceeding silently.
        print("Warning: no --teacher-path given; distilling from an "
              "untrained teacher will produce meaningless targets.")

    # Phase 2: privileged distillation
    if args.mode in ['distill', 'both']:
        print("\n=== Phase 2: Privileged Distillation ===")

        # Create distillation loss
        dist_loss = PriviGazeDistillationLoss(
            gaze_bins=90,
            teacher_feature_dim=256,
            student_feature_dim=128,
            alpha_contrastive=args.alpha_contrastive,
            alpha_mmd=args.alpha_mmd,
            alpha_logit=args.alpha_logit,
        )

        # Create trainer
        trainer = DistillationTrainer(
            teacher=teacher,
            student=student,
            distillation_loss=dist_loss,
            train_loader=train_loader,
            val_loader=val_loader,
            device=device,
            lr=args.lr,
            weight_decay=args.weight_decay,
            epochs=args.epochs,
            trackio_project=args.trackio_project,
            trackio_run_name=args.trackio_run,
        )

        # Train
        best_loss = trainer.train(save_dir=args.save_dir)

        # Final held-out evaluation of the distilled student
        print("\n=== Final Test Evaluation ===")
        student.eval()
        student.to(device)

        test_angular_errors = []
        with torch.no_grad():
            for batch in test_loader:
                face_gray = batch['face_gray'].to(device)
                pitch_target = batch['pitch'].to(device)
                yaw_target = batch['yaw'].to(device)

                pitch_pred, yaw_pred, _ = student(face_gray)

                angular_err = torch.sqrt(
                    (pitch_pred - pitch_target) ** 2 + (yaw_pred - yaw_target) ** 2
                )
                test_angular_errors.extend(angular_err.cpu().tolist())

        mean_error = np.mean(test_angular_errors)
        std_error = np.std(test_angular_errors)
        print(f"Test Angular Error: {mean_error:.2f}° ± {std_error:.2f}°")

        if HAS_TRACKIO:
            trackio.log({
                'test/angular_error_mean': mean_error,
                'test/angular_error_std': std_error,
            })
            trackio.alert(
                "Test Results",
                f"Angular error: {mean_error:.2f}° ± {std_error:.2f}°. "
                f"Student params: {count_parameters(student):,}",
                level="INFO",
            )

        # Optionally publish the final student to the Hugging Face Hub
        if args.push_to_hub and args.hub_model_id:
            from huggingface_hub import HfApi
            api = HfApi()

            # Ensure the target repo exists (no-op when it already does)
            # so the first push does not fail with a 404.
            api.create_repo(repo_id=args.hub_model_id, exist_ok=True)

            # Save final model
            model_path = os.path.join(args.save_dir, 'student_final.pt')
            torch.save({
                'student_state_dict': student.state_dict(),
                'config': {
                    'params': count_parameters(student),
                    'test_angular_error': mean_error,
                }
            }, model_path)

            # Upload
            api.upload_file(
                path_or_fileobj=model_path,
                path_in_repo="student_model.pt",
                repo_id=args.hub_model_id,
            )
            print(f"Model pushed to: https://huggingface.co/{args.hub_model_id}")

    return best_loss if args.mode in ['distill', 'both'] else None
+
619
+
620
+ if __name__ == "__main__":
621
+ main()