# privi-gaze-distill / train.py
# BcantCode's picture
# Upload train.py
# ee7a26f verified
"""
PriviGaze Training Script - Privileged Distillation for Gaze Estimation
"""
import os, sys, argparse, time
from pathlib import Path
from collections import defaultdict
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np
sys.path.insert(0, str(Path(__file__).parent))
from models.teacher import PriviGazeTeacher
from models.student import PriviGazeStudent, count_parameters
from models.distillation_loss import PriviGazeDistillationLoss, L2CSLoss, AngularLoss
from models.dataset import create_dataloaders
try:
import trackio; HAS_TRACKIO = True
except ImportError:
HAS_TRACKIO = False; print("Warning: trackio not installed.")
class DistillationTrainer:
    """Train a student model against a frozen teacher with a distillation loss.

    Every batch yielded by the loaders is expected to be a dict providing the
    keys ``'left_eye'``, ``'right_eye'``, ``'face_blurred_gray'``,
    ``'face_gray'``, ``'pitch'`` and ``'yaw'``. The teacher is called with the
    two eye crops and the blurred face; the student only sees ``'face_gray'``.

    Args:
        teacher: Model called as ``teacher(left_eye, right_eye, face_blurred)``
            and returning a 5-tuple. It is frozen and kept in eval mode.
        student: Model called as ``student(face_gray)`` and returning a
            3-tuple; must expose ``pitch_head`` and ``yaw_head`` callables.
        dist_loss: Callable returning ``(loss, loss_dict)``; ``loss`` is
            back-propagated, ``loss_dict`` holds scalar values for logging.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        device: Device the models and batches are moved to.
        lr: Initial AdamW learning rate (annealed down to ``lr * 0.01``).
        wd: AdamW weight decay.
        epochs: Number of epochs; also the cosine schedule period.
        tproj: trackio project name (used only when trackio is importable).
        trun: trackio run name (used only when trackio is importable).
    """

    def __init__(self, teacher, student, dist_loss, train_loader, val_loader,
                 device, lr=1e-4, wd=1e-4, epochs=100, tproj="privi-gaze", trun="distill"):
        self.teacher = teacher.to(device)
        self.student = student.to(device)
        self.dist_loss = dist_loss.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.epochs = epochs

        # The teacher is a fixed target: no gradients, eval mode.
        for p in self.teacher.parameters():
            p.requires_grad = False
        self.teacher.eval()

        # NOTE(review): only the student's parameters are optimised. If
        # dist_loss owns learnable parameters (e.g. feature projections) they
        # are never updated -- confirm this is intended.
        self.opt = AdamW(self.student.parameters(), lr=lr, weight_decay=wd)
        self.sched = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=lr * 0.01)

        self.best_val = float('inf')
        self.best_epoch = 0
        self.metrics = defaultdict(list)

        if HAS_TRACKIO:
            # NOTE(review): confirm the installed trackio.init accepts
            # ``run_name=``; wandb-style APIs usually call this argument ``name=``.
            trackio.init(project=tproj, run_name=trun,
                         config={'student_params': count_parameters(student),
                                 'teacher_params': count_parameters(teacher),
                                 'lr': lr, 'epochs': epochs})

    def _forward_batch(self, batch):
        """Run teacher and student on one batch and evaluate the distillation loss.

        Returns:
            ``(loss, loss_dict, student_pitch, student_yaw, pitch_target, yaw_target)``.
        """
        le = batch['left_eye'].to(self.device)
        re = batch['right_eye'].to(self.device)
        fb = batch['face_blurred_gray'].to(self.device)
        fg = batch['face_gray'].to(self.device)
        pt = batch['pitch'].to(self.device)
        yt = batch['yaw'].to(self.device)

        # Teacher outputs are targets only, so no graph is built for them.
        # Presumably (pitch, yaw, pitch logits, yaw logits, features) -- naming
        # inferred from the variable names; verify against the teacher model.
        with torch.no_grad():
            tp, ty, tplog, tylog, tf = self.teacher(le, re, fb)

        sp, sy, sf = self.student(fg)
        # The student's forward returns no logits; they are obtained by
        # applying its heads to the third output again.
        splog = self.student.pitch_head(sf)
        sylog = self.student.yaw_head(sf)

        loss, ld = self.dist_loss(sp, sy, splog, sylog, sf,
                                  tp, ty, tplog, tylog, tf, pt, yt)
        return loss, ld, sp, sy, pt, yt

    def train_epoch(self, epoch):
        """Run one optimisation pass over ``train_loader``.

        Args:
            epoch: Epoch index, used only in progress output.

        Returns:
            Dict mapping each key of the loss dict to its mean over batches.
        """
        self.student.train()
        totals = defaultdict(float)
        n = 0
        for bi, batch in enumerate(self.train_loader):
            loss, ld, *_ = self._forward_batch(batch)

            self.opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.student.parameters(), 1.0)
            self.opt.step()

            # Convert to plain floats so the running sums never hold tensors.
            ld = {k: float(v) for k, v in ld.items()}
            for k, v in ld.items():
                totals[k] += v
            n += 1

            if bi % 100 == 0:
                print(f"Epoch {epoch} | Batch {bi} | " + " | ".join(f"{k}={v:.4f}" for k, v in ld.items()))
                if HAS_TRACKIO:
                    # One call per logging event keeps all metrics together.
                    trackio.log({f"train/{k}": v for k, v in ld.items()})
        return {k: v / n for k, v in totals.items()}

    @torch.no_grad()
    def validate(self, epoch):
        """Evaluate the student on ``val_loader`` without gradients.

        Args:
            epoch: Epoch index (currently unused; kept for interface symmetry).

        Returns:
            Dict with the batch-mean of every loss-dict entry plus
            ``'angular_mean'``, ``'angular_std'``, ``'pitch_mean'`` and
            ``'yaw_mean'`` computed over all samples. The angular figure is
            ``sqrt(dpitch**2 + dyaw**2)`` per sample. All values are plain
            Python floats; the four sample statistics are NaN when the loader
            yields nothing.
        """
        self.student.eval()
        self.teacher.eval()
        totals = defaultdict(float)
        ae, pe, ye = [], [], []
        n = 0
        for batch in self.val_loader:
            _, ld, sp, sy, pt, yt = self._forward_batch(batch)
            for k, v in ld.items():
                totals[k] += float(v)
            n += 1

            dp = sp - pt
            dy = sy - yt
            ae.extend(torch.sqrt(dp ** 2 + dy ** 2).cpu().tolist())
            pe.extend(dp.abs().cpu().tolist())
            ye.extend(dy.abs().cpu().tolist())

        out = {k: v / n for k, v in totals.items()}
        nan = float('nan')
        # Guard the empty case explicitly instead of letting NumPy warn.
        out['angular_mean'] = float(np.mean(ae)) if ae else nan
        out['angular_std'] = float(np.std(ae)) if ae else nan
        out['pitch_mean'] = float(np.mean(pe)) if pe else nan
        out['yaw_mean'] = float(np.mean(ye)) if ye else nan
        return out

    def train(self, save_dir="./checkpoints"):
        """Run the full training loop and write checkpoints to ``save_dir``.

        After each epoch the student is validated and the scheduler stepped.
        ``student_best.pt`` is rewritten whenever the validation
        ``'loss_total'`` (falling back to ``'angular_mean'``) improves, and
        ``student_epoch_{epoch}.pt`` is written every 10th epoch.

        Args:
            save_dir: Directory for checkpoints; created if missing.

        Returns:
            The best validation value seen.
        """
        os.makedirs(save_dir, exist_ok=True)
        print(f"Distillation: {self.epochs} epochs | Student: {count_parameters(self.student):,} params")
        for epoch in range(self.epochs):
            te = time.time()
            tl = self.train_epoch(epoch)
            vl = self.validate(epoch)
            self.sched.step()
            lr = self.opt.param_groups[0]['lr']

            print(f"\n{'='*60}")
            print(f"Epoch {epoch}: train={tl.get('loss_total',0):.4f} val={vl.get('loss_total',0):.4f} angular={vl.get('angular_mean',0):.2f}deg")
            print(f"{'='*60}\n")

            for k, v in tl.items():
                self.metrics[f'train_{k}'].append(v)
            for k, v in vl.items():
                self.metrics[f'val_{k}'].append(v)
            if HAS_TRACKIO:
                trackio.log({f"val/{k}": v for k, v in vl.items()})

            vt = vl.get('loss_total', vl.get('angular_mean', float('inf')))
            if vt < self.best_val:
                self.best_val = vt
                self.best_epoch = epoch
                torch.save({'epoch': epoch,
                            'student_state_dict': self.student.state_dict(),
                            'opt_state_dict': self.opt.state_dict(),
                            'sched_state_dict': self.sched.state_dict(),
                            'best_val': self.best_val,
                            'metrics': dict(self.metrics)},
                           os.path.join(save_dir, 'student_best.pt'))
                if HAS_TRACKIO:
                    trackio.alert("New Best", f"Val {vt:.4f} @ epoch {epoch}", level="INFO")

            if epoch % 10 == 0:
                torch.save({'epoch': epoch,
                            'student_state_dict': self.student.state_dict(),
                            'opt_state_dict': self.opt.state_dict(),
                            'sched_state_dict': self.sched.state_dict()},
                           os.path.join(save_dir, f'student_epoch_{epoch}.pt'))
            print(f"Epoch {epoch} took {time.time()-te:.1f}s, LR={lr:.2e}")

        print(f"\nDone! Best val: {self.best_val:.4f} @ epoch {self.best_epoch}")
        return self.best_val
def pretrain_teacher(teacher, train_loader, val_loader, device, lr=1e-4, epochs=50,
                     save_dir="./checkpoints", wd=1e-4):
    """Train the teacher on ground-truth gaze and save its best weights.

    Each epoch optimises the sum of two ``L2CSLoss`` terms (pitch and yaw)
    and an ``AngularLoss`` term, then evaluates on ``val_loader``. The model
    with the lowest validation loss (pitch + yaw terms only, as in the
    training objective minus the angular term) is written to
    ``<save_dir>/teacher_best.pt`` as a bare ``state_dict``.

    Args:
        teacher: Model called as ``teacher(left_eye, right_eye, face_blurred)``
            and returning a 5-tuple whose last element is ignored here.
        train_loader: Iterable of dict batches with keys ``'left_eye'``,
            ``'right_eye'``, ``'face_blurred_gray'``, ``'pitch'``, ``'yaw'``.
        val_loader: Same batch format, used for model selection.
        device: Device the model and batches are moved to.
        lr: Initial AdamW learning rate (annealed down to ``lr * 0.01``).
        epochs: Number of training epochs.
        save_dir: Directory for the checkpoint; created if missing.
        wd: AdamW weight decay.

    Returns:
        Path of the saved teacher checkpoint. The file is guaranteed to
        exist: if no epoch ever improved on the initial best (for example a
        NaN validation loss, or ``epochs == 0``) the final weights are saved.
    """
    teacher = teacher.to(device)
    opt = AdamW(teacher.parameters(), lr=lr, weight_decay=wd)
    sched = CosineAnnealingLR(opt, T_max=epochs, eta_min=lr * 0.01)
    ploss = L2CSLoss(gaze_bins=90)
    yloss = L2CSLoss(gaze_bins=90)
    aloss = AngularLoss()
    # Mirror DistillationTrainer, which moves its loss to the device; this is
    # a no-op for losses that hold no parameters or buffers.
    for criterion in (ploss, yloss, aloss):
        if isinstance(criterion, torch.nn.Module):
            criterion.to(device)

    os.makedirs(save_dir, exist_ok=True)
    ckpt_path = os.path.join(save_dir, 'teacher_best.pt')
    best = float('inf')
    saved = False

    for epoch in range(epochs):
        teacher.train()
        tloss = 0.0
        for batch in train_loader:
            le = batch['left_eye'].to(device)
            re = batch['right_eye'].to(device)
            fb = batch['face_blurred_gray'].to(device)
            pt = batch['pitch'].to(device)
            yt = batch['yaw'].to(device)
            pp, yp, pl, yl, _ = teacher(le, re, fb)
            loss = ploss(pl, pp, pt) + yloss(yl, yp, yt) + aloss(pp, yp, pt, yt)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(teacher.parameters(), 1.0)
            opt.step()
            tloss += loss.item()
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        tloss /= max(len(train_loader), 1)

        teacher.eval()
        vloss = 0.0
        err_sum = 0.0
        err_count = 0
        with torch.no_grad():
            for batch in val_loader:
                le = batch['left_eye'].to(device)
                re = batch['right_eye'].to(device)
                fb = batch['face_blurred_gray'].to(device)
                pt = batch['pitch'].to(device)
                yt = batch['yaw'].to(device)
                pp, yp, pl, yl, _ = teacher(le, re, fb)
                vloss += (ploss(pl, pp, pt) + yloss(yl, yp, yt)).item()
                # Accumulate per sample so a short final batch is not
                # over-weighted in the reported mean.
                err = torch.sqrt((pp - pt) ** 2 + (yp - yt) ** 2)
                err_sum += err.sum().item()
                err_count += err.numel()
        vloss /= max(len(val_loader), 1)
        va = err_sum / err_count if err_count else float('nan')

        sched.step()
        print(f"Teacher Epoch {epoch}: train={tloss:.4f} val={vloss:.4f} angular={va:.2f}deg")
        if vloss < best:
            best = vloss
            torch.save(teacher.state_dict(), ckpt_path)
            saved = True

    if not saved:
        # Callers load the returned path unconditionally, so make sure it exists.
        torch.save(teacher.state_dict(), ckpt_path)
    return ckpt_path
def main():
    """Command-line entry point for teacher pre-training and distillation.

    ``--mode`` selects the phases to run:

    * ``pretrain_teacher`` -- train the teacher and save its checkpoint.
    * ``distill`` -- distil a teacher (loaded from ``--teacher-path`` when
      given) into the student, then evaluate the student on the test split.
    * ``both`` -- run the two phases back to back.

    The test evaluation and the optional Hub upload run only after a
    distillation phase, so an untrained student is never evaluated or pushed.
    """
    p = argparse.ArgumentParser(description="PriviGaze Training")
    p.add_argument('--mode', type=str, default='distill', choices=['pretrain_teacher','distill','both'])
    p.add_argument('--teacher-path', type=str, default=None)
    p.add_argument('--batch-size', type=int, default=32)
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--teacher-epochs', type=int, default=50)
    p.add_argument('--lr', type=float, default=1e-4)
    p.add_argument('--weight-decay', type=float, default=1e-4)
    p.add_argument('--num-train', type=int, default=40000)
    p.add_argument('--num-val', type=int, default=5000)
    p.add_argument('--save-dir', type=str, default='./checkpoints')
    p.add_argument('--device', type=str, default='cuda')
    p.add_argument('--trackio-project', type=str, default='privi-gaze')
    p.add_argument('--trackio-run', type=str, default='distill-run')
    p.add_argument('--push-to-hub', action='store_true')
    p.add_argument('--hub-model-id', type=str, default=None)
    p.add_argument('--alpha-contrastive', type=float, default=0.5)
    p.add_argument('--alpha-mmd', type=float, default=0.1)
    p.add_argument('--alpha-logit', type=float, default=0.5)
    args = p.parse_args()

    # Fall back to CPU only when a CUDA device was requested but none is
    # available; any other explicit choice (e.g. "mps", "cpu") is honoured.
    requested = args.device
    if requested.startswith('cuda') and not torch.cuda.is_available():
        requested = 'cpu'
    device = torch.device(requested)
    print(f"Device: {device}")

    train_loader, val_loader, test_loader = create_dataloaders(
        num_train=args.num_train, num_val=args.num_val, batch_size=args.batch_size)
    teacher = PriviGazeTeacher()
    student = PriviGazeStudent()
    print(f"Teacher: {count_parameters(teacher):,} params")
    print(f"Student: {count_parameters(student):,} params")

    if args.mode in ['pretrain_teacher', 'both']:
        print("\n=== Phase 1: Teacher Pre-training ===")
        tp = pretrain_teacher(teacher, train_loader, val_loader, device,
                              lr=args.lr, epochs=args.teacher_epochs, save_dir=args.save_dir)
        # The freshly trained checkpoint takes precedence over --teacher-path.
        args.teacher_path = tp

    if args.teacher_path:
        print(f"\nLoading teacher: {args.teacher_path}")
        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source (or pass weights_only=True where PyTorch supports it).
        teacher.load_state_dict(torch.load(args.teacher_path, map_location=device))
    elif args.mode == 'distill':
        print("Warning: no --teacher-path given; distilling from a randomly initialised teacher.")

    if args.mode not in ['distill', 'both']:
        return

    print("\n=== Phase 2: Distillation ===")
    dloss = PriviGazeDistillationLoss(
        gaze_bins=90, teacher_feature_dim=256, student_feature_dim=128,
        alpha_contrastive=args.alpha_contrastive, alpha_mmd=args.alpha_mmd,
        alpha_logit=args.alpha_logit)
    trainer = DistillationTrainer(teacher, student, dloss, train_loader, val_loader,
                                  device, lr=args.lr, wd=args.weight_decay, epochs=args.epochs,
                                  tproj=args.trackio_project, trun=args.trackio_run)
    trainer.train(save_dir=args.save_dir)

    # NOTE(review): this evaluates the student as it stands after the last
    # epoch, not the weights stored in student_best.pt -- confirm intended.
    print("\n=== Test ===")
    student.eval().to(device)
    terr = []
    with torch.no_grad():
        for batch in test_loader:
            fg = batch['face_gray'].to(device)
            pt = batch['pitch'].to(device)
            yt = batch['yaw'].to(device)
            sp, sy, _ = student(fg)
            terr.extend(torch.sqrt((sp-pt)**2 + (sy-yt)**2).cpu().tolist())
    me = np.mean(terr)
    se = np.std(terr)
    print(f"Test Angular Error: {me:.2f}deg +- {se:.2f}deg")

    if args.push_to_hub and args.hub_model_id:
        from huggingface_hub import HfApi
        mp = os.path.join(args.save_dir, 'student_final.pt')
        # Store the error as a plain float so the file holds no NumPy objects.
        torch.save({'student_state_dict': student.state_dict(),
                    'config': {'params': count_parameters(student), 'test_err': float(me)}}, mp)
        HfApi().upload_file(path_or_fileobj=mp, path_in_repo="student_model.pt", repo_id=args.hub_model_id)
        print(f"Pushed to: https://huggingface.co/{args.hub_model_id}")
    elif args.push_to_hub:
        print("Warning: --push-to-hub ignored because --hub-model-id was not given.")


if __name__ == "__main__":
    main()