# test_models.py — privi-gaze-distill
"""
Sanity check: verify teacher and student forward passes work correctly.
Run on CPU to test before GPU training.
"""
import torch
from models.teacher import PriviGazeTeacher
from models.student import PriviGazeStudent, count_parameters
from models.distillation_loss import PriviGazeDistillationLoss
from models.dataset import SyntheticGazeDataset
def test_teacher():
    """Smoke-test the teacher model: forward a 2-sample batch, check all output shapes."""
    print("Testing Teacher...")
    model = PriviGazeTeacher()
    # Teacher takes two RGB eye crops plus a blurred grayscale face image.
    left_eye = torch.randn(2, 3, 112, 112)
    right_eye = torch.randn(2, 3, 112, 112)
    blurred_face = torch.randn(2, 1, 224, 224)
    pitch, yaw, pitch_logits, yaw_logits, features = model(left_eye, right_eye, blurred_face)
    assert pitch.shape == (2,), f"pitch shape {pitch.shape} != (2,)"
    assert yaw.shape == (2,), f"yaw shape {yaw.shape} != (2,)"
    assert pitch_logits.shape == (2, 90), f"pitch_logits {pitch_logits.shape} != (2,90)"
    assert yaw_logits.shape == (2, 90), f"yaw_logits {yaw_logits.shape} != (2,90)"
    assert features.shape == (2, 256), f"features {features.shape} != (2,256)"
    print(f" OK! Params: {count_parameters(model):,}")
def test_student():
    """Smoke-test the student model: forward a 2-sample batch, check output shapes."""
    print("Testing Student...")
    model = PriviGazeStudent()
    # Student only sees the unblurred grayscale face (privacy-preserving input path).
    face_gray = torch.randn(2, 1, 224, 224)
    pitch, yaw, features = model(face_gray)
    assert pitch.shape == (2,), f"pitch shape {pitch.shape}"
    assert yaw.shape == (2,), f"yaw shape {yaw.shape}"
    assert features.shape == (2, 128), f"features {features.shape}"
    print(f" OK! Params: {count_parameters(model):,}")
def test_distillation_loss():
    """Smoke-test the distillation loss on random student/teacher outputs and targets."""
    print("Testing Distillation Loss...")
    loss_fn = PriviGazeDistillationLoss()
    n = 4
    # Student outputs: pitch, yaw, pitch logits, yaw logits, 128-d features.
    student_out = (
        torch.randn(n),
        torch.randn(n),
        torch.randn(n, 90),
        torch.randn(n, 90),
        torch.randn(n, 128),
    )
    # Teacher outputs mirror the student's, but with 256-d features.
    teacher_out = (
        torch.randn(n),
        torch.randn(n),
        torch.randn(n, 90),
        torch.randn(n, 90),
        torch.randn(n, 256),
    )
    targets = (torch.randn(n), torch.randn(n))  # ground-truth pitch, yaw
    loss, loss_dict = loss_fn(*student_out, *teacher_out, *targets)
    assert loss.item() > 0, "loss should be positive"
    assert 'loss_total' in loss_dict
    print(f" OK! Loss: {loss.item():.4f}")
def test_dataset():
    """Smoke-test the synthetic dataset: sample keys, tensor shapes, label ranges."""
    print("Testing Dataset...")
    dataset = SyntheticGazeDataset(num_samples=10)
    first = dataset[0]
    # Table-driven shape checks for every image tensor in a sample.
    expected_shapes = {
        'left_eye': (3, 112, 112),
        'right_eye': (3, 112, 112),
        'face_blurred_gray': (1, 224, 224),
        'face_gray': (1, 224, 224),
    }
    for key, shape in expected_shapes.items():
        assert first[key].shape == shape
    # Gaze angles are bounded to [-90, 90] degrees.
    for angle in ('pitch', 'yaw'):
        assert -90 <= first[angle].item() <= 90
    print(f" OK! {len(dataset)} samples")
def test_end_to_end():
    """Run one full training step: teacher forward (no grad), student forward,
    distillation loss, and backward pass."""
    print("Testing End-to-End (1 batch)...")
    teacher = PriviGazeTeacher()
    student = PriviGazeStudent()
    loss_fn = PriviGazeDistillationLoss()
    dataset = SyntheticGazeDataset(num_samples=4)
    # Collate 4 samples into a batch dict of stacked tensors.
    batch = {key: torch.stack([dataset[i][key] for i in range(4)]) for key in dataset[0].keys()}
    # Teacher is a frozen reference during distillation — no gradients needed.
    with torch.no_grad():
        tp, ty, tpl, tyl, tf = teacher(
            batch['left_eye'], batch['right_eye'], batch['face_blurred_gray']
        )
    # Student forward must stay OUTSIDE no_grad so loss.backward() has a graph.
    sp, sy, sf = student(batch['face_gray'])
    spl = student.pitch_head(sf)
    syl = student.yaw_head(sf)
    loss, ld = loss_fn(sp, sy, spl, syl, sf, tp, ty, tpl, tyl, tf,
                       batch['pitch'], batch['yaw'])
    loss.backward()
    print(f" OK! Backward pass succeeded. Loss: {loss.item():.4f}")
if __name__ == "__main__":
    # Run every sanity check in order; any assertion failure aborts the run.
    for check in (
        test_teacher,
        test_student,
        test_distillation_loss,
        test_dataset,
        test_end_to_end,
    ):
        check()
    print("\nAll tests passed!")