""" PriviGaze Training Script - Privileged Distillation for Gaze Estimation """ import os, sys, argparse, time from pathlib import Path from collections import defaultdict import torch from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR import numpy as np sys.path.insert(0, str(Path(__file__).parent)) from models.teacher import PriviGazeTeacher from models.student import PriviGazeStudent, count_parameters from models.distillation_loss import PriviGazeDistillationLoss, L2CSLoss, AngularLoss from models.dataset import create_dataloaders try: import trackio; HAS_TRACKIO = True except ImportError: HAS_TRACKIO = False; print("Warning: trackio not installed.") class DistillationTrainer: def __init__(self, teacher, student, dist_loss, train_loader, val_loader, device, lr=1e-4, wd=1e-4, epochs=100, tproj="privi-gaze", trun="distill"): self.teacher = teacher.to(device) self.student = student.to(device) self.dist_loss = dist_loss.to(device) self.train_loader = train_loader self.val_loader = val_loader self.device = device self.epochs = epochs for p in self.teacher.parameters(): p.requires_grad = False self.teacher.eval() self.opt = AdamW(self.student.parameters(), lr=lr, weight_decay=wd) self.sched = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=lr*0.01) self.best_val = float('inf') self.best_epoch = 0 self.metrics = defaultdict(list) if HAS_TRACKIO: trackio.init(project=tproj, run_name=trun, config={'student_params': count_parameters(student), 'teacher_params': count_parameters(teacher), 'lr': lr, 'epochs': epochs}) def train_epoch(self, epoch): self.student.train() losses = defaultdict(float) n = 0 for bi, batch in enumerate(self.train_loader): le = batch['left_eye'].to(self.device) re = batch['right_eye'].to(self.device) fb = batch['face_blurred_gray'].to(self.device) fg = batch['face_gray'].to(self.device) pt = batch['pitch'].to(self.device) yt = batch['yaw'].to(self.device) with torch.no_grad(): tp, ty, tplog, tylog, tf = 
self.teacher(le, re, fb) sp, sy, sf = self.student(fg) splog = self.student.pitch_head(sf) sylog = self.student.yaw_head(sf) loss, ld = self.dist_loss(sp, sy, splog, sylog, sf, tp, ty, tplog, tylog, tf, pt, yt) self.opt.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(self.student.parameters(), 1.0) self.opt.step() for k, v in ld.items(): losses[k] += v n += 1 if bi % 100 == 0: print(f"Epoch {epoch} | Batch {bi} | " + " | ".join(f"{k}={v:.4f}" for k, v in ld.items())) if HAS_TRACKIO: for k2, v2 in ld.items(): trackio.log({f"train/{k2}": v2}) return {k: v/n for k, v in losses.items()} @torch.no_grad() def validate(self, epoch): self.student.eval() self.teacher.eval() losses = defaultdict(float) ae, pe, ye = [], [], [] n = 0 for batch in self.val_loader: le = batch['left_eye'].to(self.device) re = batch['right_eye'].to(self.device) fb = batch['face_blurred_gray'].to(self.device) fg = batch['face_gray'].to(self.device) pt = batch['pitch'].to(self.device) yt = batch['yaw'].to(self.device) tp, ty, tplog, tylog, tf = self.teacher(le, re, fb) sp, sy, sf = self.student(fg) splog = self.student.pitch_head(sf) sylog = self.student.yaw_head(sf) loss, ld = self.dist_loss(sp, sy, splog, sylog, sf, tp, ty, tplog, tylog, tf, pt, yt) for k, v in ld.items(): losses[k] += v n += 1 aerr = torch.sqrt((sp-pt)**2 + (sy-yt)**2) ae.extend(aerr.cpu().tolist()) pe.extend((sp-pt).abs().cpu().tolist()) ye.extend((sy-yt).abs().cpu().tolist()) for k in losses: losses[k] /= n losses['angular_mean'] = np.mean(ae) losses['angular_std'] = np.std(ae) losses['pitch_mean'] = np.mean(pe) losses['yaw_mean'] = np.mean(ye) return dict(losses) def train(self, save_dir="./checkpoints"): os.makedirs(save_dir, exist_ok=True) print(f"Distillation: {self.epochs} epochs | Student: {count_parameters(self.student):,} params") t0 = time.time() for epoch in range(self.epochs): te = time.time() tl = self.train_epoch(epoch) vl = self.validate(epoch) self.sched.step() lr = self.opt.param_groups[0]['lr'] 
print(f"\n{'='*60}") print(f"Epoch {epoch}: train={tl.get('loss_total',0):.4f} val={vl.get('loss_total',0):.4f} angular={vl.get('angular_mean',0):.2f}deg") print(f"{'='*60}\n") for k, v in tl.items(): self.metrics[f'train_{k}'].append(v) for k, v in vl.items(): self.metrics[f'val_{k}'].append(v) vt = vl.get('loss_total', vl.get('angular_mean', float('inf'))) if vt < self.best_val: self.best_val = vt self.best_epoch = epoch torch.save({'epoch': epoch, 'student_state_dict': self.student.state_dict(), 'opt_state_dict': self.opt.state_dict(), 'best_val': self.best_val, 'metrics': dict(self.metrics)}, os.path.join(save_dir, 'student_best.pt')) if HAS_TRACKIO: trackio.alert("New Best", f"Val {vt:.4f} @ epoch {epoch}", level="INFO") if epoch % 10 == 0: torch.save({'epoch': epoch, 'student_state_dict': self.student.state_dict(), 'opt_state_dict': self.opt.state_dict()}, os.path.join(save_dir, f'student_epoch_{epoch}.pt')) print(f"Epoch {epoch} took {time.time()-te:.1f}s, LR={lr:.2e}") print(f"\nDone! 
def pretrain_teacher(teacher, train_loader, val_loader, device,
                     lr=1e-4, epochs=50, save_dir="./checkpoints"):
    """Pre-train the privileged teacher on its full inputs.

    Optimizes the teacher with two L2CS (binned classification+regression)
    losses for pitch and yaw plus an angular loss, saving the best weights
    (by validation loss) to `teacher_best.pt`.

    Returns:
        Path to the best-teacher checkpoint inside `save_dir`.
    """
    teacher = teacher.to(device)
    opt = AdamW(teacher.parameters(), lr=lr, weight_decay=1e-4)
    sched = CosineAnnealingLR(opt, T_max=epochs, eta_min=lr * 0.01)

    # Independent binned losses for the two gaze angles, plus angular term.
    ploss = L2CSLoss(gaze_bins=90)
    yloss = L2CSLoss(gaze_bins=90)
    aloss = AngularLoss()

    best = float('inf')
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(epochs):
        # ---- training pass ----
        teacher.train()
        tloss = 0.0
        for batch in train_loader:
            le = batch['left_eye'].to(device)
            re = batch['right_eye'].to(device)
            fb = batch['face_blurred_gray'].to(device)
            pt = batch['pitch'].to(device)
            yt = batch['yaw'].to(device)

            pp, yp, pl, yl, _ = teacher(le, re, fb)
            loss = ploss(pl, pp, pt) + yloss(yl, yp, yt) + aloss(pp, yp, pt, yt)

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(teacher.parameters(), 1.0)
            opt.step()
            tloss += loss.item()
        tloss /= len(train_loader)

        # ---- validation pass ----
        teacher.eval()
        vloss = 0.0
        va = 0.0
        with torch.no_grad():
            for batch in val_loader:
                le = batch['left_eye'].to(device)
                re = batch['right_eye'].to(device)
                fb = batch['face_blurred_gray'].to(device)
                pt = batch['pitch'].to(device)
                yt = batch['yaw'].to(device)

                pp, yp, pl, yl, _ = teacher(le, re, fb)
                vloss += (ploss(pl, pp, pt) + yloss(yl, yp, yt)).item()
                # Mean-of-batch-means angular error (approximate if the last
                # batch is smaller).
                va += torch.sqrt((pp - pt) ** 2 + (yp - yt) ** 2).mean().item()
        vloss /= len(val_loader)
        va /= len(val_loader)

        sched.step()
        print(f"Teacher Epoch {epoch}: train={tloss:.4f} val={vloss:.4f} angular={va:.2f}deg")

        if vloss < best:
            best = vloss
            torch.save(teacher.state_dict(), os.path.join(save_dir, 'teacher_best.pt'))
    return os.path.join(save_dir, 'teacher_best.pt')


def main():
    """CLI entry point: teacher pre-training and/or student distillation."""
    p = argparse.ArgumentParser(description="PriviGaze Training")
    p.add_argument('--mode', type=str, default='distill',
                   choices=['pretrain_teacher', 'distill', 'both'])
    p.add_argument('--teacher-path', type=str, default=None)
    p.add_argument('--batch-size', type=int, default=32)
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--teacher-epochs', type=int, default=50)
    p.add_argument('--lr', type=float, default=1e-4)
    p.add_argument('--weight-decay', type=float, default=1e-4)
    p.add_argument('--num-train', type=int, default=40000)
    p.add_argument('--num-val', type=int, default=5000)
    p.add_argument('--save-dir', type=str, default='./checkpoints')
    p.add_argument('--device', type=str, default='cuda')
    p.add_argument('--trackio-project', type=str, default='privi-gaze')
    p.add_argument('--trackio-run', type=str, default='distill-run')
    p.add_argument('--push-to-hub', action='store_true')
    p.add_argument('--hub-model-id', type=str, default=None)
    p.add_argument('--alpha-contrastive', type=float, default=0.5)
    p.add_argument('--alpha-mmd', type=float, default=0.1)
    p.add_argument('--alpha-logit', type=float, default=0.5)
    args = p.parse_args()

    # Fall back to CPU when CUDA is unavailable, whatever was requested.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    train_loader, val_loader, test_loader = create_dataloaders(
        num_train=args.num_train, num_val=args.num_val, batch_size=args.batch_size)

    teacher = PriviGazeTeacher()
    student = PriviGazeStudent()
    print(f"Teacher: {count_parameters(teacher):,} params")
    print(f"Student: {count_parameters(student):,} params")

    if args.mode in ['pretrain_teacher', 'both']:
        print("\n=== Phase 1: Teacher Pre-training ===")
        tp = pretrain_teacher(teacher, train_loader, val_loader, device,
                              lr=args.lr, epochs=args.teacher_epochs,
                              save_dir=args.save_dir)
        args.teacher_path = tp

    if args.teacher_path:
        print(f"\nLoading teacher: {args.teacher_path}")
        # NOTE(review): torch.load on an arbitrary path deserializes pickled
        # data — only load checkpoints from trusted sources.
        teacher.load_state_dict(torch.load(args.teacher_path, map_location=device))

    if args.mode in ['distill', 'both']:
        print("\n=== Phase 2: Distillation ===")
        dloss = PriviGazeDistillationLoss(
            gaze_bins=90, teacher_feature_dim=256, student_feature_dim=128,
            alpha_contrastive=args.alpha_contrastive, alpha_mmd=args.alpha_mmd,
            alpha_logit=args.alpha_logit)
        trainer = DistillationTrainer(teacher, student, dloss,
                                      train_loader, val_loader, device,
                                      lr=args.lr, wd=args.weight_decay,
                                      epochs=args.epochs,
                                      tproj=args.trackio_project,
                                      trun=args.trackio_run)
        trainer.train(save_dir=args.save_dir)

        # ---- final test evaluation ----
        # NOTE(review): evaluates the last-epoch student, not the best
        # checkpoint saved by the trainer — confirm this is intended.
        print("\n=== Test ===")
        student.eval().to(device)
        terr = []
        with torch.no_grad():
            for batch in test_loader:
                fg = batch['face_gray'].to(device)
                pt = batch['pitch'].to(device)
                yt = batch['yaw'].to(device)
                sp, sy, _ = student(fg)
                terr.extend(torch.sqrt((sp - pt) ** 2 + (sy - yt) ** 2).cpu().tolist())
        me = np.mean(terr)
        se = np.std(terr)
        print(f"Test Angular Error: {me:.2f}deg +- {se:.2f}deg")

        # Optional upload of the final student weights to the HF Hub.
        if args.push_to_hub and args.hub_model_id:
            from huggingface_hub import HfApi
            mp = os.path.join(args.save_dir, 'student_final.pt')
            torch.save({'student_state_dict': student.state_dict(),
                        'config': {'params': count_parameters(student),
                                   'test_err': me}}, mp)
            HfApi().upload_file(path_or_fileobj=mp,
                                path_in_repo="student_model.pt",
                                repo_id=args.hub_model_id)
            print(f"Pushed to: https://huggingface.co/{args.hub_model_id}")


if __name__ == "__main__":
    main()