BcantCode commited on
Commit
327e860
·
verified ·
1 Parent(s): bd3b08e

Upload test_models.py

Browse files
Files changed (1) hide show
  1. test_models.py +104 -0
test_models.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sanity check: verify teacher and student forward passes work correctly.
3
+ Run on CPU to test before GPU training.
4
+ """
5
+ import torch
6
+ from models.teacher import PriviGazeTeacher
7
+ from models.student import PriviGazeStudent, count_parameters
8
+ from models.distillation_loss import PriviGazeDistillationLoss
9
+ from models.dataset import SyntheticGazeDataset
10
+
11
+
12
+ def test_teacher():
13
+ print("Testing Teacher...")
14
+ model = PriviGazeTeacher()
15
+ le = torch.randn(2, 3, 112, 112)
16
+ re = torch.randn(2, 3, 112, 112)
17
+ fb = torch.randn(2, 1, 224, 224)
18
+ pp, yp, pl, yl, feat = model(le, re, fb)
19
+ assert pp.shape == (2,), f"pitch shape {pp.shape} != (2,)"
20
+ assert yp.shape == (2,), f"yaw shape {yp.shape} != (2,)"
21
+ assert pl.shape == (2, 90), f"pitch_logits {pl.shape} != (2,90)"
22
+ assert yl.shape == (2, 90), f"yaw_logits {yl.shape} != (2,90)"
23
+ assert feat.shape == (2, 256), f"features {feat.shape} != (2,256)"
24
+ print(f" OK! Params: {count_parameters(model):,}")
25
+
26
+
27
+ def test_student():
28
+ print("Testing Student...")
29
+ model = PriviGazeStudent()
30
+ fg = torch.randn(2, 1, 224, 224)
31
+ pp, yp, feat = model(fg)
32
+ assert pp.shape == (2,), f"pitch shape {pp.shape}"
33
+ assert yp.shape == (2,), f"yaw shape {yp.shape}"
34
+ assert feat.shape == (2, 128), f"features {feat.shape}"
35
+ print(f" OK! Params: {count_parameters(model):,}")
36
+
37
+
38
+ def test_distillation_loss():
39
+ print("Testing Distillation Loss...")
40
+ loss_fn = PriviGazeDistillationLoss()
41
+ sp = torch.randn(4)
42
+ sy = torch.randn(4)
43
+ spl = torch.randn(4, 90)
44
+ syl = torch.randn(4, 90)
45
+ sf = torch.randn(4, 128)
46
+ tp = torch.randn(4)
47
+ ty = torch.randn(4)
48
+ tpl = torch.randn(4, 90)
49
+ tyl = torch.randn(4, 90)
50
+ tf = torch.randn(4, 256)
51
+ pt = torch.randn(4)
52
+ yt = torch.randn(4)
53
+ loss, ld = loss_fn(sp, sy, spl, syl, sf, tp, ty, tpl, tyl, tf, pt, yt)
54
+ assert loss.item() > 0, "loss should be positive"
55
+ assert 'loss_total' in ld
56
+ print(f" OK! Loss: {loss.item():.4f}")
57
+
58
+
59
+ def test_dataset():
60
+ print("Testing Dataset...")
61
+ ds = SyntheticGazeDataset(num_samples=10)
62
+ sample = ds[0]
63
+ assert sample['left_eye'].shape == (3, 112, 112)
64
+ assert sample['right_eye'].shape == (3, 112, 112)
65
+ assert sample['face_blurred_gray'].shape == (1, 224, 224)
66
+ assert sample['face_gray'].shape == (1, 224, 224)
67
+ assert -90 <= sample['pitch'].item() <= 90
68
+ assert -90 <= sample['yaw'].item() <= 90
69
+ print(f" OK! {len(ds)} samples")
70
+
71
+
72
+ def test_end_to_end():
73
+ print("Testing End-to-End (1 batch)...")
74
+ teacher = PriviGazeTeacher()
75
+ student = PriviGazeStudent()
76
+ loss_fn = PriviGazeDistillationLoss()
77
+ ds = SyntheticGazeDataset(num_samples=4)
78
+ batch = {k: torch.stack([ds[i][k] for i in range(4)]) for k in ds[0].keys()}
79
+
80
+ le = batch['left_eye']
81
+ re = batch['right_eye']
82
+ fb = batch['face_blurred_gray']
83
+ fg = batch['face_gray']
84
+ pt = batch['pitch']
85
+ yt = batch['yaw']
86
+
87
+ with torch.no_grad():
88
+ tp, ty, tpl, tyl, tf = teacher(le, re, fb)
89
+ sp, sy, sf = student(fg)
90
+ spl = student.pitch_head(sf)
91
+ syl = student.yaw_head(sf)
92
+
93
+ loss, ld = loss_fn(sp, sy, spl, syl, sf, tp, ty, tpl, tyl, tf, pt, yt)
94
+ loss.backward()
95
+ print(f" OK! Backward pass succeeded. Loss: {loss.item():.4f}")
96
+
97
+
98
+ if __name__ == "__main__":
99
+ test_teacher()
100
+ test_student()
101
+ test_distillation_loss()
102
+ test_dataset()
103
+ test_end_to_end()
104
+ print("\nAll tests passed!")