""" Sanity check: verify teacher and student forward passes work correctly. Run on CPU to test before GPU training. """ import torch from models.teacher import PriviGazeTeacher from models.student import PriviGazeStudent, count_parameters from models.distillation_loss import PriviGazeDistillationLoss from models.dataset import SyntheticGazeDataset def test_teacher(): print("Testing Teacher...") model = PriviGazeTeacher() le = torch.randn(2, 3, 112, 112) re = torch.randn(2, 3, 112, 112) fb = torch.randn(2, 1, 224, 224) pp, yp, pl, yl, feat = model(le, re, fb) assert pp.shape == (2,), f"pitch shape {pp.shape} != (2,)" assert yp.shape == (2,), f"yaw shape {yp.shape} != (2,)" assert pl.shape == (2, 90), f"pitch_logits {pl.shape} != (2,90)" assert yl.shape == (2, 90), f"yaw_logits {yl.shape} != (2,90)" assert feat.shape == (2, 256), f"features {feat.shape} != (2,256)" print(f" OK! Params: {count_parameters(model):,}") def test_student(): print("Testing Student...") model = PriviGazeStudent() fg = torch.randn(2, 1, 224, 224) pp, yp, feat = model(fg) assert pp.shape == (2,), f"pitch shape {pp.shape}" assert yp.shape == (2,), f"yaw shape {yp.shape}" assert feat.shape == (2, 128), f"features {feat.shape}" print(f" OK! Params: {count_parameters(model):,}") def test_distillation_loss(): print("Testing Distillation Loss...") loss_fn = PriviGazeDistillationLoss() sp = torch.randn(4) sy = torch.randn(4) spl = torch.randn(4, 90) syl = torch.randn(4, 90) sf = torch.randn(4, 128) tp = torch.randn(4) ty = torch.randn(4) tpl = torch.randn(4, 90) tyl = torch.randn(4, 90) tf = torch.randn(4, 256) pt = torch.randn(4) yt = torch.randn(4) loss, ld = loss_fn(sp, sy, spl, syl, sf, tp, ty, tpl, tyl, tf, pt, yt) assert loss.item() > 0, "loss should be positive" assert 'loss_total' in ld print(f" OK! Loss: {loss.item():.4f}") def test_dataset(): print("Testing Dataset...") ds = SyntheticGazeDataset(num_samples=10) sample = ds[0] assert sample['left_eye'].shape == (3, 112, 112) assert sample['right_eye'].shape == (3, 112, 112) assert sample['face_blurred_gray'].shape == (1, 224, 224) assert sample['face_gray'].shape == (1, 224, 224) assert -90 <= sample['pitch'].item() <= 90 assert -90 <= sample['yaw'].item() <= 90 print(f" OK! {len(ds)} samples") def test_end_to_end(): print("Testing End-to-End (1 batch)...") teacher = PriviGazeTeacher() student = PriviGazeStudent() loss_fn = PriviGazeDistillationLoss() ds = SyntheticGazeDataset(num_samples=4) batch = {k: torch.stack([ds[i][k] for i in range(4)]) for k in ds[0].keys()} le = batch['left_eye'] re = batch['right_eye'] fb = batch['face_blurred_gray'] fg = batch['face_gray'] pt = batch['pitch'] yt = batch['yaw'] with torch.no_grad(): tp, ty, tpl, tyl, tf = teacher(le, re, fb) sp, sy, sf = student(fg) spl = student.pitch_head(sf) syl = student.yaw_head(sf) loss, ld = loss_fn(sp, sy, spl, syl, sf, tp, ty, tpl, tyl, tf, pt, yt) loss.backward() print(f" OK! Backward pass succeeded. Loss: {loss.item():.4f}") if __name__ == "__main__": test_teacher() test_student() test_distillation_loss() test_dataset() test_end_to_end() print("\nAll tests passed!")