| """ |
Sanity check: verify that the teacher and student forward passes, the distillation
loss, the synthetic dataset, and a single backward pass all work correctly.
Run on CPU before starting GPU training.
| """ |
| import torch |
| from models.teacher import PriviGazeTeacher |
| from models.student import PriviGazeStudent, count_parameters |
| from models.distillation_loss import PriviGazeDistillationLoss |
| from models.dataset import SyntheticGazeDataset |
|
|
|
|
def test_teacher():
    """Check PriviGazeTeacher output shapes on a random batch of 2.

    The teacher is fed two RGB eye crops (3x112x112) and one single-channel
    blurred face (1x224x224). It must return pitch/yaw regressions of shape
    (2,), pitch/yaw logits of shape (2, 90) and a (2, 256) feature tensor.
    """
    print("Testing Teacher...")
    teacher = PriviGazeTeacher()
    left_eye = torch.randn(2, 3, 112, 112)
    right_eye = torch.randn(2, 3, 112, 112)
    face_blurred = torch.randn(2, 1, 224, 224)
    pitch, yaw, pitch_logits, yaw_logits, features = teacher(
        left_eye, right_eye, face_blurred
    )

    # (message label, tensor, expected shape, expected shape as printed)
    expectations = (
        ("pitch shape", pitch, (2,), "(2,)"),
        ("yaw shape", yaw, (2,), "(2,)"),
        ("pitch_logits", pitch_logits, (2, 90), "(2,90)"),
        ("yaw_logits", yaw_logits, (2, 90), "(2,90)"),
        ("features", features, (2, 256), "(2,256)"),
    )
    for label, tensor, expected, expected_text in expectations:
        assert tensor.shape == expected, f"{label} {tensor.shape} != {expected_text}"

    print(f" OK! Params: {count_parameters(teacher):,}")
|
|
|
|
def test_student():
    """Check PriviGazeStudent output shapes on a random batch of 2.

    The student sees only a single-channel face image (1x224x224). It must
    return pitch/yaw of shape (2,) and a (2, 128) feature tensor.
    """
    print("Testing Student...")
    student = PriviGazeStudent()
    face_gray = torch.randn(2, 1, 224, 224)
    pitch, yaw, features = student(face_gray)

    for label, tensor, expected in (
        ("pitch shape", pitch, (2,)),
        ("yaw shape", yaw, (2,)),
        ("features", features, (2, 128)),
    ):
        assert tensor.shape == expected, f"{label} {tensor.shape}"

    print(f" OK! Params: {count_parameters(student):,}")
|
|
|
|
def test_distillation_loss():
    """Run PriviGazeDistillationLoss on random inputs for a batch of 4.

    Checks that the total loss is positive and that the returned breakdown
    dict contains a 'loss_total' entry.
    """
    print("Testing Distillation Loss...")
    criterion = PriviGazeDistillationLoss()
    batch = 4

    # Tuple elements are evaluated left to right, so the randn draw order
    # matches the positional argument order the loss expects.
    student_side = (
        torch.randn(batch),           # pitch
        torch.randn(batch),           # yaw
        torch.randn(batch, 90),       # pitch logits
        torch.randn(batch, 90),       # yaw logits
        torch.randn(batch, 128),      # features
    )
    teacher_side = (
        torch.randn(batch),           # pitch
        torch.randn(batch),           # yaw
        torch.randn(batch, 90),       # pitch logits
        torch.randn(batch, 90),       # yaw logits
        torch.randn(batch, 256),      # features
    )
    targets = (
        torch.randn(batch),           # pitch target
        torch.randn(batch),           # yaw target
    )

    total, breakdown = criterion(*student_side, *teacher_side, *targets)
    assert total.item() > 0, "loss should be positive"
    assert 'loss_total' in breakdown
    print(f" OK! Loss: {total.item():.4f}")
|
|
|
|
def test_dataset():
    """Check tensor shapes and gaze-angle ranges of one SyntheticGazeDataset sample."""
    print("Testing Dataset...")
    dataset = SyntheticGazeDataset(num_samples=10)
    first = dataset[0]

    expected_shapes = {
        'left_eye': (3, 112, 112),
        'right_eye': (3, 112, 112),
        'face_blurred_gray': (1, 224, 224),
        'face_gray': (1, 224, 224),
    }
    for field, shape in expected_shapes.items():
        assert first[field].shape == shape

    # Gaze angles are checked against the inclusive range [-90, 90].
    for angle in ('pitch', 'yaw'):
        assert -90 <= first[angle].item() <= 90

    print(f" OK! {len(dataset)} samples")
|
|
|
|
def test_end_to_end():
    """Run one distillation step on a tiny synthetic batch and back-propagate.

    The teacher is a frozen target. It runs in eval mode under
    ``torch.no_grad()``. The student runs with autograd enabled, so
    ``loss.backward()`` has a graph to traverse. The test then asserts that
    at least one student parameter actually received a gradient.
    """
    print("Testing End-to-End (1 batch)...")
    batch_size = 4
    teacher = PriviGazeTeacher()
    student = PriviGazeStudent()
    loss_fn = PriviGazeDistillationLoss()
    ds = SyntheticGazeDataset(num_samples=batch_size)

    # Fetch every sample exactly once. Indexing ds[i] once per key could mix
    # fields from different draws if __getitem__ is not deterministic per index.
    samples = [ds[i] for i in range(batch_size)]
    batch = {key: torch.stack([s[key] for s in samples]) for key in samples[0]}

    left_eye = batch['left_eye']
    right_eye = batch['right_eye']
    face_blurred = batch['face_blurred_gray']
    face_gray = batch['face_gray']
    pitch_target = batch['pitch']
    yaw_target = batch['yaw']

    # Only the teacher is excluded from autograd: it provides fixed targets.
    teacher.eval()
    with torch.no_grad():
        tp, ty, tpl, tyl, tf = teacher(left_eye, right_eye, face_blurred)

    # The student forward must stay outside no_grad, otherwise the loss has no
    # graph and backward() raises.
    sp, sy, sf = student(face_gray)
    # NOTE(review): assumes pitch_head / yaw_head map the student features to
    # the per-bin logits the loss expects -- confirm against PriviGazeStudent.
    spl = student.pitch_head(sf)
    syl = student.yaw_head(sf)

    loss, ld = loss_fn(sp, sy, spl, syl, sf, tp, ty, tpl, tyl, tf,
                       pitch_target, yaw_target)
    assert 'loss_total' in ld
    loss.backward()

    assert any(p.grad is not None for p in student.parameters()), \
        "backward() produced no gradients for the student"
    print(f" OK! Backward pass succeeded. Loss: {loss.item():.4f}")
|
|
|
|
if __name__ == "__main__":
    # Run each sanity check in order; any failing assert aborts the run.
    for check in (
        test_teacher,
        test_student,
        test_distillation_loss,
        test_dataset,
        test_end_to_end,
    ):
        check()
    print("\nAll tests passed!")
|
|