| """ |
| PriviGaze Training Script - Privileged Distillation for Gaze Estimation |
| """ |
| import os, sys, argparse, time |
| from pathlib import Path |
| from collections import defaultdict |
| import torch |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
| import numpy as np |
|
|
| sys.path.insert(0, str(Path(__file__).parent)) |
| from models.teacher import PriviGazeTeacher |
| from models.student import PriviGazeStudent, count_parameters |
| from models.distillation_loss import PriviGazeDistillationLoss, L2CSLoss, AngularLoss |
| from models.dataset import create_dataloaders |
|
|
# Optional experiment-tracking backend. All trackio calls later in this
# script are gated on HAS_TRACKIO, so training degrades gracefully
# (console logging only) when the package is absent.
try:
    import trackio
    HAS_TRACKIO = True
except ImportError:
    HAS_TRACKIO = False
    print("Warning: trackio not installed.")
|
|
|
|
class DistillationTrainer:
    """Distills a frozen privileged teacher into a compact student model.

    The teacher consumes privileged inputs (left/right eye crops plus a
    blurred grayscale face) while the student consumes only the plain
    grayscale face; ``dist_loss`` compares their predictions, logits and
    features.  The teacher is frozen (no grads, eval mode) for the whole
    run, so only student parameters are optimized.
    """

    def __init__(self, teacher, student, dist_loss, train_loader, val_loader,
                 device, lr=1e-4, wd=1e-4, epochs=100, tproj="privi-gaze", trun="distill"):
        # Move models and the loss module (which may itself hold learnable
        # projection layers) onto the target device up front.
        self.teacher = teacher.to(device)
        self.student = student.to(device)
        self.dist_loss = dist_loss.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.epochs = epochs
        # Freeze the teacher: it only provides distillation targets.
        for p in self.teacher.parameters(): p.requires_grad = False
        self.teacher.eval()
        # Cosine-anneal the LR down to 1% of its initial value over `epochs`.
        self.opt = AdamW(self.student.parameters(), lr=lr, weight_decay=wd)
        self.sched = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=lr*0.01)
        self.best_val = float('inf')
        self.best_epoch = 0
        # History buffer: metric name -> list of per-epoch values.
        self.metrics = defaultdict(list)
        if HAS_TRACKIO:
            trackio.init(project=tproj, run_name=trun,
                config={'student_params': count_parameters(student),
                'teacher_params': count_parameters(teacher), 'lr': lr, 'epochs': epochs})

    def train_epoch(self, epoch):
        """Run one training epoch; return each loss term averaged over batches."""
        self.student.train()
        losses = defaultdict(float)
        n = 0  # batch counter used for averaging at the end
        for bi, batch in enumerate(self.train_loader):
            le = batch['left_eye'].to(self.device)
            re = batch['right_eye'].to(self.device)
            fb = batch['face_blurred_gray'].to(self.device)
            fg = batch['face_gray'].to(self.device)
            pt = batch['pitch'].to(self.device)
            yt = batch['yaw'].to(self.device)
            # Teacher forward under no_grad: frozen, targets only.
            with torch.no_grad():
                tp, ty, tplog, tylog, tf = self.teacher(le, re, fb)
            # Student sees only the non-privileged grayscale face.
            sp, sy, sf = self.student(fg)
            splog = self.student.pitch_head(sf)
            sylog = self.student.yaw_head(sf)
            loss, ld = self.dist_loss(sp, sy, splog, sylog, sf,
                tp, ty, tplog, tylog, tf, pt, yt)
            self.opt.zero_grad()
            loss.backward()
            # Gradient clipping to stabilize distillation updates.
            torch.nn.utils.clip_grad_norm_(self.student.parameters(), 1.0)
            self.opt.step()
            # Accumulate every component reported by the loss for averaging.
            for k, v in ld.items(): losses[k] += v
            n += 1
            if bi % 100 == 0:
                print(f"Epoch {epoch} | Batch {bi} | " + " | ".join(f"{k}={v:.4f}" for k, v in ld.items()))
                if HAS_TRACKIO:
                    for k2, v2 in ld.items(): trackio.log({f"train/{k2}": v2})
        return {k: v/n for k, v in losses.items()}

    @torch.no_grad()
    def validate(self, epoch):
        """Evaluate on the validation set.

        Returns a dict with each batch-averaged loss term plus error
        statistics over all validation samples (angular_mean/std,
        pitch_mean, yaw_mean).
        """
        self.student.eval()
        self.teacher.eval()
        losses = defaultdict(float)
        # Per-sample error accumulators: combined, pitch-only, yaw-only.
        ae, pe, ye = [], [], []
        n = 0
        for batch in self.val_loader:
            le = batch['left_eye'].to(self.device)
            re = batch['right_eye'].to(self.device)
            fb = batch['face_blurred_gray'].to(self.device)
            fg = batch['face_gray'].to(self.device)
            pt = batch['pitch'].to(self.device)
            yt = batch['yaw'].to(self.device)
            tp, ty, tplog, tylog, tf = self.teacher(le, re, fb)
            sp, sy, sf = self.student(fg)
            splog = self.student.pitch_head(sf)
            sylog = self.student.yaw_head(sf)
            loss, ld = self.dist_loss(sp, sy, splog, sylog, sf,
                tp, ty, tplog, tylog, tf, pt, yt)
            for k, v in ld.items(): losses[k] += v
            n += 1
            # Euclidean error in (pitch, yaw) space, used as an angular-error
            # proxy throughout this script.
            aerr = torch.sqrt((sp-pt)**2 + (sy-yt)**2)
            ae.extend(aerr.cpu().tolist())
            pe.extend((sp-pt).abs().cpu().tolist())
            ye.extend((sy-yt).abs().cpu().tolist())
        for k in losses: losses[k] /= n
        losses['angular_mean'] = np.mean(ae)
        losses['angular_std'] = np.std(ae)
        losses['pitch_mean'] = np.mean(pe)
        losses['yaw_mean'] = np.mean(ye)
        return dict(losses)

    def train(self, save_dir="./checkpoints"):
        """Full distillation loop.

        Trains for ``self.epochs`` epochs, validating each epoch; saves the
        best checkpoint (by total validation loss, falling back to mean
        angular error) plus a snapshot every 10 epochs.  Returns the best
        validation score.
        """
        os.makedirs(save_dir, exist_ok=True)
        print(f"Distillation: {self.epochs} epochs | Student: {count_parameters(self.student):,} params")
        t0 = time.time()
        for epoch in range(self.epochs):
            te = time.time()
            tl = self.train_epoch(epoch)
            vl = self.validate(epoch)
            self.sched.step()
            lr = self.opt.param_groups[0]['lr']
            print(f"\n{'='*60}")
            print(f"Epoch {epoch}: train={tl.get('loss_total',0):.4f} val={vl.get('loss_total',0):.4f} angular={vl.get('angular_mean',0):.2f}deg")
            print(f"{'='*60}\n")
            for k, v in tl.items(): self.metrics[f'train_{k}'].append(v)
            for k, v in vl.items(): self.metrics[f'val_{k}'].append(v)
            # Model-selection criterion: total loss if reported, else angular error.
            vt = vl.get('loss_total', vl.get('angular_mean', float('inf')))
            if vt < self.best_val:
                self.best_val = vt
                self.best_epoch = epoch
                torch.save({'epoch': epoch, 'student_state_dict': self.student.state_dict(),
                    'opt_state_dict': self.opt.state_dict(), 'best_val': self.best_val,
                    'metrics': dict(self.metrics)}, os.path.join(save_dir, 'student_best.pt'))
                if HAS_TRACKIO: trackio.alert("New Best", f"Val {vt:.4f} @ epoch {epoch}", level="INFO")
            # Periodic snapshot (also fires at epoch 0).
            if epoch % 10 == 0:
                torch.save({'epoch': epoch, 'student_state_dict': self.student.state_dict(),
                    'opt_state_dict': self.opt.state_dict()},
                    os.path.join(save_dir, f'student_epoch_{epoch}.pt'))
            print(f"Epoch {epoch} took {time.time()-te:.1f}s, LR={lr:.2e}")
        print(f"\nDone! Best val: {self.best_val:.4f} @ epoch {self.best_epoch}")
        return self.best_val
|
|
|
|
def pretrain_teacher(teacher, train_loader, val_loader, device, lr=1e-4, epochs=50, save_dir="./checkpoints"):
    """Supervised pre-training of the privileged teacher network.

    Each epoch runs a training pass (per-angle L2CS losses plus an angular
    penalty) followed by a validation pass (L2CS losses only for model
    selection, with a mean angular-error readout).  The best-validation
    weights are written to ``save_dir/teacher_best.pt``; that path is
    returned.
    """
    teacher = teacher.to(device)
    optimizer = AdamW(teacher.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=lr * 0.01)
    pitch_criterion = L2CSLoss(gaze_bins=90)
    yaw_criterion = L2CSLoss(gaze_bins=90)
    angular_criterion = AngularLoss()
    best_val = float('inf')
    os.makedirs(save_dir, exist_ok=True)
    ckpt_path = os.path.join(save_dir, 'teacher_best.pt')

    for epoch in range(epochs):
        # ---- training pass ----
        teacher.train()
        running = 0.0
        for batch in train_loader:
            left = batch['left_eye'].to(device)
            right = batch['right_eye'].to(device)
            blurred = batch['face_blurred_gray'].to(device)
            pitch_gt = batch['pitch'].to(device)
            yaw_gt = batch['yaw'].to(device)
            pitch_pred, yaw_pred, pitch_logits, yaw_logits, _ = teacher(left, right, blurred)
            loss = (pitch_criterion(pitch_logits, pitch_pred, pitch_gt)
                    + yaw_criterion(yaw_logits, yaw_pred, yaw_gt)
                    + angular_criterion(pitch_pred, yaw_pred, pitch_gt, yaw_gt))
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(teacher.parameters(), 1.0)
            optimizer.step()
            running += loss.item()
        train_avg = running / len(train_loader)

        # ---- validation pass ----
        teacher.eval()
        val_sum = 0.0
        ang_sum = 0.0
        with torch.no_grad():
            for batch in val_loader:
                left = batch['left_eye'].to(device)
                right = batch['right_eye'].to(device)
                blurred = batch['face_blurred_gray'].to(device)
                pitch_gt = batch['pitch'].to(device)
                yaw_gt = batch['yaw'].to(device)
                pitch_pred, yaw_pred, pitch_logits, yaw_logits, _ = teacher(left, right, blurred)
                val_sum += (pitch_criterion(pitch_logits, pitch_pred, pitch_gt)
                            + yaw_criterion(yaw_logits, yaw_pred, yaw_gt)).item()
                # Euclidean error in (pitch, yaw) space as an angular-error proxy.
                ang_sum += torch.sqrt((pitch_pred - pitch_gt)**2 + (yaw_pred - yaw_gt)**2).mean().item()
        val_avg = val_sum / len(val_loader)
        ang_avg = ang_sum / len(val_loader)

        scheduler.step()
        print(f"Teacher Epoch {epoch}: train={train_avg:.4f} val={val_avg:.4f} angular={ang_avg:.2f}deg")
        if val_avg < best_val:
            best_val = val_avg
            torch.save(teacher.state_dict(), ckpt_path)
    return ckpt_path
|
|
|
|
def main():
    """Command-line entry point for PriviGaze training.

    Three modes are supported: supervised teacher pre-training, privileged
    distillation into the student, or both phases back to back.  The script
    always finishes with a test-set evaluation of the student and can
    optionally push the final weights to the Hugging Face Hub.
    """
    parser = argparse.ArgumentParser(description="PriviGaze Training")
    parser.add_argument('--mode', type=str, default='distill',
                        choices=['pretrain_teacher', 'distill', 'both'])
    parser.add_argument('--teacher-path', type=str, default=None)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--teacher-epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--num-train', type=int, default=40000)
    parser.add_argument('--num-val', type=int, default=5000)
    parser.add_argument('--save-dir', type=str, default='./checkpoints')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--trackio-project', type=str, default='privi-gaze')
    parser.add_argument('--trackio-run', type=str, default='distill-run')
    parser.add_argument('--push-to-hub', action='store_true')
    parser.add_argument('--hub-model-id', type=str, default=None)
    parser.add_argument('--alpha-contrastive', type=float, default=0.5)
    parser.add_argument('--alpha-mmd', type=float, default=0.1)
    parser.add_argument('--alpha-logit', type=float, default=0.5)
    args = parser.parse_args()

    # Fall back to CPU when CUDA is unavailable, regardless of the request.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    train_loader, val_loader, test_loader = create_dataloaders(
        num_train=args.num_train, num_val=args.num_val, batch_size=args.batch_size)

    teacher = PriviGazeTeacher()
    student = PriviGazeStudent()
    print(f"Teacher: {count_parameters(teacher):,} params")
    print(f"Student: {count_parameters(student):,} params")

    # Phase 1 (optional): supervised pre-training of the teacher.
    if args.mode in ['pretrain_teacher', 'both']:
        print("\n=== Phase 1: Teacher Pre-training ===")
        best_teacher_path = pretrain_teacher(teacher, train_loader, val_loader, device,
                                             lr=args.lr, epochs=args.teacher_epochs,
                                             save_dir=args.save_dir)
        args.teacher_path = best_teacher_path

    # Load teacher weights — either pre-trained above or supplied via CLI.
    if args.teacher_path:
        print(f"\nLoading teacher: {args.teacher_path}")
        teacher.load_state_dict(torch.load(args.teacher_path, map_location=device))

    # Phase 2 (optional): distill the teacher into the student.
    if args.mode in ['distill', 'both']:
        print("\n=== Phase 2: Distillation ===")
        criterion = PriviGazeDistillationLoss(
            gaze_bins=90, teacher_feature_dim=256, student_feature_dim=128,
            alpha_contrastive=args.alpha_contrastive, alpha_mmd=args.alpha_mmd,
            alpha_logit=args.alpha_logit)
        trainer = DistillationTrainer(teacher, student, criterion,
                                      train_loader, val_loader, device,
                                      lr=args.lr, wd=args.weight_decay,
                                      epochs=args.epochs,
                                      tproj=args.trackio_project,
                                      trun=args.trackio_run)
        trainer.train(save_dir=args.save_dir)

    # Final evaluation of the student on the held-out test set.
    print("\n=== Test ===")
    student.eval().to(device)
    errors = []
    with torch.no_grad():
        for batch in test_loader:
            faces = batch['face_gray'].to(device)
            pitch_gt = batch['pitch'].to(device)
            yaw_gt = batch['yaw'].to(device)
            pitch_pred, yaw_pred, _ = student(faces)
            batch_err = torch.sqrt((pitch_pred - pitch_gt)**2 + (yaw_pred - yaw_gt)**2)
            errors.extend(batch_err.cpu().tolist())
    me = np.mean(errors)
    se = np.std(errors)
    print(f"Test Angular Error: {me:.2f}deg +- {se:.2f}deg")

    # Optional upload of the final student weights to the Hub.
    if args.push_to_hub and args.hub_model_id:
        from huggingface_hub import HfApi
        final_path = os.path.join(args.save_dir, 'student_final.pt')
        torch.save({'student_state_dict': student.state_dict(),
                    'config': {'params': count_parameters(student), 'test_err': me}}, final_path)
        HfApi().upload_file(path_or_fileobj=final_path, path_in_repo="student_model.pt",
                            repo_id=args.hub_model_id)
        print(f"Pushed to: https://huggingface.co/{args.hub_model_id}")


if __name__ == "__main__":
    main()
|
|