BcantCode commited on
Commit
e92f7bc
·
verified ·
1 Parent(s): ee7a26f

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +54 -0
inference.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PriviGaze Inference Script - Run the student model on-device
3
+
4
+ Usage:
5
+ python inference.py --model ./checkpoints/student_best.pt --image face.jpg
6
+ """
7
+ import argparse
8
+ import torch
9
+ import numpy as np
10
+ from PIL import Image, ImageOps
11
+ from models.student import PriviGazeStudent
12
+
13
+
14
+ def load_model(checkpoint_path, device='cpu'):
15
+ model = PriviGazeStudent()
16
+ ckpt = torch.load(checkpoint_path, map_location=device)
17
+ model.load_state_dict(ckpt.get('student_state_dict', ckpt))
18
+ model.eval().to(device)
19
+ return model
20
+
21
+
22
+ def preprocess(image_path, size=224):
23
+ img = Image.open(image_path).convert('L')
24
+ img = img.resize((size, size))
25
+ arr = np.array(img).astype(np.float32) / 255.0
26
+ tensor = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0)
27
+ return tensor * 2 - 1
28
+
29
+
30
+ def predict(model, tensor, device='cpu'):
31
+ tensor = tensor.to(device)
32
+ with torch.no_grad():
33
+ pitch, yaw, _ = model(tensor)
34
+ return pitch.item(), yaw.item()
35
+
36
+
37
+ def main():
38
+ p = argparse.ArgumentParser()
39
+ p.add_argument('--model', type=str, required=True, help='Path to student checkpoint')
40
+ p.add_argument('--image', type=str, required=True, help='Path to grayscale face image')
41
+ p.add_argument('--device', type=str, default='cpu')
42
+ args = p.parse_args()
43
+
44
+ device = torch.device(args.device)
45
+ model = load_model(args.model, device)
46
+ tensor = preprocess(args.image)
47
+ pitch, yaw = predict(model, tensor, device)
48
+ print(f"Gaze: pitch={pitch:+.2f}deg, yaw={yaw:+.2f}deg")
49
+ print(f"Direction: {'up' if pitch>0 else 'down'} {abs(pitch):.1f}deg, "
50
+ f"{'right' if yaw>0 else 'left'} {abs(yaw):.1f}deg")
51
+
52
+
53
+ if __name__ == "__main__":
54
+ main()