File size: 1,682 Bytes
e92f7bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""
PriviGaze Inference Script - Run the student model on-device

Usage:
    python inference.py --model ./checkpoints/student_best.pt --image face.jpg
"""
import argparse
import torch
import numpy as np
from PIL import Image, ImageOps
from models.student import PriviGazeStudent


def load_model(checkpoint_path, device='cpu'):
    """Load the student gaze model from a checkpoint file.

    Accepts either a raw state_dict or a training checkpoint dict that
    stores the weights under the 'student_state_dict' key.

    Args:
        checkpoint_path: path to a .pt file saved by the training script.
        device: torch device string or object to place the model on.

    Returns:
        The model in eval mode on the requested device.
    """
    net = PriviGazeStudent()
    payload = torch.load(checkpoint_path, map_location=device)
    # Training checkpoints wrap the weights; fall back to the payload itself.
    if isinstance(payload, dict) and 'student_state_dict' in payload:
        payload = payload['student_state_dict']
    net.load_state_dict(payload)
    net.to(device)
    net.eval()
    return net


def preprocess(image_path, size=224):
    """Load an image as a normalized single-channel tensor.

    Converts to grayscale, resizes to (size, size), scales pixel values
    from [0, 255] to [-1, 1], and adds batch and channel dimensions.

    Args:
        image_path: path to any image file PIL can open.
        size: target square side length in pixels (default 224).

    Returns:
        A float32 tensor of shape (1, 1, size, size) in [-1, 1].
    """
    gray = Image.open(image_path).convert('L')
    gray = gray.resize((size, size))
    pixels = np.asarray(gray, dtype=np.float32) / 255.0
    # Shift [0, 1] -> [-1, 1] to match the training-time normalization.
    scaled = pixels * 2.0 - 1.0
    return torch.from_numpy(scaled)[None, None, :, :]


def predict(model, tensor, device='cpu'):
    """Run one forward pass and return (pitch, yaw) as Python floats.

    Args:
        model: callable returning a (pitch, yaw, extra) tuple of tensors;
            the third element is ignored here.
        tensor: preprocessed input tensor (see ``preprocess``).
        device: device to move the input to before inference.

    Returns:
        (pitch, yaw) scalar floats, presumably in degrees — matches the
        units printed by the CLI.
    """
    inputs = tensor.to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        pitch_t, yaw_t, _ = model(inputs)
    return pitch_t.item(), yaw_t.item()


def main():
    """CLI entry point: parse arguments, run inference, print the gaze."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True, help='Path to student checkpoint')
    parser.add_argument('--image', type=str, required=True, help='Path to grayscale face image')
    parser.add_argument('--device', type=str, default='cpu')
    opts = parser.parse_args()

    dev = torch.device(opts.device)
    net = load_model(opts.model, dev)
    inp = preprocess(opts.image)
    pitch, yaw = predict(net, inp, dev)

    # Positive pitch is reported as 'up', positive yaw as 'right';
    # exact zero falls into 'down'/'left', same as the sign test implies.
    vertical = 'up' if pitch > 0 else 'down'
    horizontal = 'right' if yaw > 0 else 'left'
    print(f"Gaze: pitch={pitch:+.2f}deg, yaw={yaw:+.2f}deg")
    print(f"Direction: {vertical} {abs(pitch):.1f}deg, "
          f"{horizontal} {abs(yaw):.1f}deg")


if __name__ == "__main__":
    main()