Upload test_models.py
Browse files- test_models.py +104 -0
test_models.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sanity check: verify teacher and student forward passes work correctly.
|
| 3 |
+
Run on CPU to test before GPU training.
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
from models.teacher import PriviGazeTeacher
|
| 7 |
+
from models.student import PriviGazeStudent, count_parameters
|
| 8 |
+
from models.distillation_loss import PriviGazeDistillationLoss
|
| 9 |
+
from models.dataset import SyntheticGazeDataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test_teacher():
    """Smoke-test the teacher forward pass on a random CPU batch of 2.

    Verifies the five outputs (pitch, yaw, the two 90-bin logit heads,
    and the 256-d feature vector) come back with the expected shapes.
    """
    print("Testing Teacher...")
    teacher = PriviGazeTeacher()
    # Teacher takes both eye crops (RGB 112x112) plus a blurred gray face (224x224).
    left_eye = torch.randn(2, 3, 112, 112)
    right_eye = torch.randn(2, 3, 112, 112)
    blurred_face = torch.randn(2, 1, 224, 224)
    pitch, yaw, pitch_logits, yaw_logits, features = teacher(left_eye, right_eye, blurred_face)
    assert pitch.shape == (2,), f"pitch shape {pitch.shape} != (2,)"
    assert yaw.shape == (2,), f"yaw shape {yaw.shape} != (2,)"
    assert pitch_logits.shape == (2, 90), f"pitch_logits {pitch_logits.shape} != (2,90)"
    assert yaw_logits.shape == (2, 90), f"yaw_logits {yaw_logits.shape} != (2,90)"
    assert features.shape == (2, 256), f"features {features.shape} != (2,256)"
    print(f" OK! Params: {count_parameters(teacher):,}")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def test_student():
    """Smoke-test the student forward pass on a random CPU batch of 2.

    The student consumes only the unblurred gray face and must emit
    pitch, yaw, and a 128-d feature vector.
    """
    print("Testing Student...")
    student = PriviGazeStudent()
    gray_face = torch.randn(2, 1, 224, 224)
    pitch, yaw, features = student(gray_face)
    assert pitch.shape == (2,), f"pitch shape {pitch.shape}"
    assert yaw.shape == (2,), f"yaw shape {yaw.shape}"
    assert features.shape == (2, 128), f"features {features.shape}"
    print(f" OK! Params: {count_parameters(student):,}")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def test_distillation_loss():
    """Feed random student/teacher outputs through the distillation loss.

    Checks only that a positive scalar comes back and that the returned
    dict contains the 'loss_total' entry — not the numeric value itself,
    since the inputs are random.
    """
    print("Testing Distillation Loss...")
    criterion = PriviGazeDistillationLoss()
    batch = 4
    # Student-side predictions: angles, 90-bin logits, 128-d features.
    student_pitch = torch.randn(batch)
    student_yaw = torch.randn(batch)
    student_pitch_logits = torch.randn(batch, 90)
    student_yaw_logits = torch.randn(batch, 90)
    student_features = torch.randn(batch, 128)
    # Teacher-side targets: angles, 90-bin logits, 256-d features.
    teacher_pitch = torch.randn(batch)
    teacher_yaw = torch.randn(batch)
    teacher_pitch_logits = torch.randn(batch, 90)
    teacher_yaw_logits = torch.randn(batch, 90)
    teacher_features = torch.randn(batch, 256)
    # Ground-truth gaze angles.
    pitch_target = torch.randn(batch)
    yaw_target = torch.randn(batch)
    loss, loss_dict = criterion(
        student_pitch, student_yaw, student_pitch_logits, student_yaw_logits,
        student_features,
        teacher_pitch, teacher_yaw, teacher_pitch_logits, teacher_yaw_logits,
        teacher_features,
        pitch_target, yaw_target,
    )
    assert loss.item() > 0, "loss should be positive"
    assert 'loss_total' in loss_dict
    print(f" OK! Loss: {loss.item():.4f}")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def test_dataset():
    """Verify one synthetic sample has the expected tensor shapes and
    that the gaze labels fall inside the [-90, 90] degree range.
    """
    print("Testing Dataset...")
    dataset = SyntheticGazeDataset(num_samples=10)
    item = dataset[0]
    expected_shapes = {
        'left_eye': (3, 112, 112),
        'right_eye': (3, 112, 112),
        'face_blurred_gray': (1, 224, 224),
        'face_gray': (1, 224, 224),
    }
    for key, shape in expected_shapes.items():
        assert item[key].shape == shape
    for angle_key in ('pitch', 'yaw'):
        assert -90 <= item[angle_key].item() <= 90
    print(f" OK! {len(dataset)} samples")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def test_end_to_end():
    """Run one synthetic batch through teacher and student, compute the
    distillation loss, and confirm gradients flow through the student.

    The teacher forward runs under torch.no_grad() (it acts as a frozen
    target during distillation); the student forward must run OUTSIDE
    no_grad, otherwise the loss has no grad_fn and loss.backward()
    raises RuntimeError.
    """
    print("Testing End-to-End (1 batch)...")
    teacher = PriviGazeTeacher()
    student = PriviGazeStudent()
    loss_fn = PriviGazeDistillationLoss()
    ds = SyntheticGazeDataset(num_samples=4)
    # Collate 4 samples into a batch dict of stacked tensors.
    batch = {k: torch.stack([ds[i][k] for i in range(4)]) for k in ds[0].keys()}

    le = batch['left_eye']
    re = batch['right_eye']
    fb = batch['face_blurred_gray']
    fg = batch['face_gray']
    pt = batch['pitch']
    yt = batch['yaw']

    # Teacher targets: no gradient tracking needed (and none wanted).
    with torch.no_grad():
        tp, ty, tpl, tyl, tf = teacher(le, re, fb)

    # Student forward stays outside no_grad so backward() has a graph.
    sp, sy, sf = student(fg)
    # NOTE(review): logits are recomputed from the feature vector here;
    # assumes pitch_head/yaw_head take the 128-d features — confirm
    # against PriviGazeStudent's forward.
    spl = student.pitch_head(sf)
    syl = student.yaw_head(sf)

    loss, ld = loss_fn(sp, sy, spl, syl, sf, tp, ty, tpl, tyl, tf, pt, yt)
    loss.backward()
    print(f" OK! Backward pass succeeded. Loss: {loss.item():.4f}")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if __name__ == "__main__":
    # Run every sanity check in order; any assertion failure aborts the run.
    for check in (
        test_teacher,
        test_student,
        test_distillation_loss,
        test_dataset,
        test_end_to_end,
    ):
        check()
    print("\nAll tests passed!")
|