"""
PriviGaze Inference Script - Run the student model on-device
Usage:
python inference.py --model ./checkpoints/student_best.pt --image face.jpg
"""
import argparse
import torch
import numpy as np
from PIL import Image, ImageOps
from models.student import PriviGazeStudent
def load_model(checkpoint_path, device='cpu'):
    """Instantiate the student network and restore its weights.

    Accepts either a full training checkpoint (with a 'student_state_dict'
    key) or a bare state dict. Returns the model in eval mode on `device`.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    net = PriviGazeStudent()
    state = torch.load(checkpoint_path, map_location=device)
    # Fall back to treating the whole object as a state dict.
    state = state.get('student_state_dict', state)
    net.load_state_dict(state)
    net.to(device)
    net.eval()
    return net
def preprocess(image_path, size=224):
    """Load an image as grayscale, resize to (size, size), scale to [-1, 1].

    Returns a float32 tensor of shape (1, 1, size, size) ready for the model.
    """
    gray = Image.open(image_path).convert('L').resize((size, size))
    # Map uint8 pixel values [0, 255] -> [0.0, 1.0].
    pixels = np.array(gray).astype(np.float32) / 255.0
    # Add batch and channel dims: (H, W) -> (1, 1, H, W).
    batch = torch.from_numpy(pixels)[None, None, :, :]
    # Shift to the [-1, 1] range the student model was trained on.
    return batch * 2 - 1
def predict(model, tensor, device='cpu'):
    """Run one forward pass; return (pitch, yaw) as Python floats (degrees)."""
    inputs = tensor.to(device)
    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        pitch_t, yaw_t, _ = model(inputs)
    return pitch_t.item(), yaw_t.item()
def main():
    """CLI entry point: load checkpoint, preprocess one image, print gaze."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True, help='Path to student checkpoint')
    parser.add_argument('--image', type=str, required=True, help='Path to grayscale face image')
    parser.add_argument('--device', type=str, default='cpu')
    opts = parser.parse_args()

    dev = torch.device(opts.device)
    net = load_model(opts.model, dev)
    batch = preprocess(opts.image)
    pitch, yaw = predict(net, batch, dev)

    print(f"Gaze: pitch={pitch:+.2f}deg, yaw={yaw:+.2f}deg")
    print(f"Direction: {'up' if pitch>0 else 'down'} {abs(pitch):.1f}deg, "
          f"{'right' if yaw>0 else 'left'} {abs(yaw):.1f}deg")


if __name__ == "__main__":
    main()