File size: 3,438 Bytes
327e860
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
Sanity check: verify teacher and student forward passes work correctly.
Run on CPU to test before GPU training.
"""
import torch
from models.teacher import PriviGazeTeacher
from models.student import PriviGazeStudent, count_parameters
from models.distillation_loss import PriviGazeDistillationLoss
from models.dataset import SyntheticGazeDataset


def test_teacher():
    """Shape-check the teacher's forward pass on a dummy batch of two."""
    print("Testing Teacher...")
    teacher = PriviGazeTeacher()
    left_eye = torch.randn(2, 3, 112, 112)
    right_eye = torch.randn(2, 3, 112, 112)
    face_blurred = torch.randn(2, 1, 224, 224)
    pitch, yaw, pitch_logits, yaw_logits, features = teacher(
        left_eye, right_eye, face_blurred
    )
    # One scalar pitch/yaw per sample, 90-bin logits, 256-dim feature vector.
    assert pitch.shape == (2,), f"pitch shape {pitch.shape} != (2,)"
    assert yaw.shape == (2,), f"yaw shape {yaw.shape} != (2,)"
    assert pitch_logits.shape == (2, 90), f"pitch_logits {pitch_logits.shape} != (2,90)"
    assert yaw_logits.shape == (2, 90), f"yaw_logits {yaw_logits.shape} != (2,90)"
    assert features.shape == (2, 256), f"features {features.shape} != (2,256)"
    print(f"  OK! Params: {count_parameters(teacher):,}")


def test_student():
    """Shape-check the student's forward pass on a dummy batch of two."""
    print("Testing Student...")
    student = PriviGazeStudent()
    face_gray = torch.randn(2, 1, 224, 224)
    pitch, yaw, features = student(face_gray)
    # Student predicts one pitch/yaw per sample and a 128-dim feature vector.
    assert pitch.shape == (2,), f"pitch shape {pitch.shape}"
    assert yaw.shape == (2,), f"yaw shape {yaw.shape}"
    assert features.shape == (2, 128), f"features {features.shape}"
    print(f"  OK! Params: {count_parameters(student):,}")


def test_distillation_loss():
    """Run the distillation loss on random student/teacher outputs and targets."""
    print("Testing Distillation Loss...")
    criterion = PriviGazeDistillationLoss()
    n = 4
    # Student outputs: pitch, yaw, pitch logits, yaw logits, 128-dim features.
    student_out = (
        torch.randn(n),
        torch.randn(n),
        torch.randn(n, 90),
        torch.randn(n, 90),
        torch.randn(n, 128),
    )
    # Teacher outputs: same layout but with 256-dim features.
    teacher_out = (
        torch.randn(n),
        torch.randn(n),
        torch.randn(n, 90),
        torch.randn(n, 90),
        torch.randn(n, 256),
    )
    # Ground-truth pitch/yaw targets.
    targets = (torch.randn(n), torch.randn(n))
    total, parts = criterion(*student_out, *teacher_out, *targets)
    assert total.item() > 0, "loss should be positive"
    assert 'loss_total' in parts
    print(f"  OK! Loss: {total.item():.4f}")


def test_dataset():
    """Verify sample tensor shapes and label ranges from the synthetic dataset."""
    print("Testing Dataset...")
    dataset = SyntheticGazeDataset(num_samples=10)
    item = dataset[0]
    # Each sample holds RGB eye crops and grayscale face images.
    expected_shapes = {
        'left_eye': (3, 112, 112),
        'right_eye': (3, 112, 112),
        'face_blurred_gray': (1, 224, 224),
        'face_gray': (1, 224, 224),
    }
    for key, shape in expected_shapes.items():
        assert item[key].shape == shape
    # Gaze angles are bounded to [-90, 90] degrees.
    for angle in ('pitch', 'yaw'):
        assert -90 <= item[angle].item() <= 90
    print(f"  OK! {len(dataset)} samples")


def test_end_to_end():
    """Run one full teacher->student distillation step on a synthetic batch of 4.

    The teacher is frozen (eval mode + no_grad); the student runs with
    gradients enabled. Computes the distillation loss and backpropagates
    to confirm the whole student graph is differentiable.
    """
    print("Testing End-to-End (1 batch)...")
    teacher = PriviGazeTeacher()
    student = PriviGazeStudent()
    loss_fn = PriviGazeDistillationLoss()

    # Fix: torch.no_grad() only disables autograd; it does NOT switch layer
    # modes. Put the teacher in eval mode so BatchNorm/Dropout (if present)
    # behave as at inference time while producing distillation targets.
    teacher.eval()

    ds = SyntheticGazeDataset(num_samples=4)
    # Collate the first 4 samples into a batch dict of stacked tensors.
    batch = {k: torch.stack([ds[i][k] for i in range(4)]) for k in ds[0].keys()}

    le = batch['left_eye']
    re = batch['right_eye']
    fb = batch['face_blurred_gray']
    fg = batch['face_gray']
    pt = batch['pitch']
    yt = batch['yaw']

    # Teacher targets are produced without building an autograd graph.
    with torch.no_grad():
        tp, ty, tpl, tyl, tf = teacher(le, re, fb)
    sp, sy, sf = student(fg)
    # Student classification logits come from the heads applied to features.
    spl = student.pitch_head(sf)
    syl = student.yaw_head(sf)

    loss, ld = loss_fn(sp, sy, spl, syl, sf, tp, ty, tpl, tyl, tf, pt, yt)
    loss.backward()
    print(f"  OK! Backward pass succeeded. Loss: {loss.item():.4f}")


if __name__ == "__main__":
    test_teacher()
    test_student()
    test_distillation_loss()
    test_dataset()
    test_end_to_end()
    print("\nAll tests passed!")