BcantCode commited on
Commit
bd3b08e
·
verified ·
1 Parent(s): e92f7bc

Upload export_onnx.py

Browse files
Files changed (1) hide show
  1. export_onnx.py +51 -0
export_onnx.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Export PriviGaze student to ONNX for on-device deployment.
3
+
4
+ Usage:
5
+ python export_onnx.py --checkpoint ./checkpoints/student_best.pt --output privigaze.onnx
6
+ """
7
+ import argparse
8
+ import torch
9
+ from models.student import PriviGazeStudent
10
+
11
+
12
+ def export(checkpoint_path, output_path, opset=11):
13
+ model = PriviGazeStudent()
14
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
15
+ model.load_state_dict(ckpt.get('student_state_dict', ckpt))
16
+ model.eval()
17
+
18
+ dummy = torch.randn(1, 1, 224, 224)
19
+ torch.onnx.export(
20
+ model,
21
+ dummy,
22
+ output_path,
23
+ input_names=['face_gray'],
24
+ output_names=['pitch', 'yaw', 'features'],
25
+ dynamic_axes={'face_gray': {0: 'batch_size'},
26
+ 'pitch': {0: 'batch_size'},
27
+ 'yaw': {0: 'batch_size'},
28
+ 'features': {0: 'batch_size'}},
29
+ opset_version=opset,
30
+ do_constant_folding=True,
31
+ )
32
+ print(f"Exported to {output_path}")
33
+
34
+ # Verify
35
+ import onnx
36
+ m = onnx.load(output_path)
37
+ onnx.checker.check_model(m)
38
+ print("ONNX model validated OK")
39
+
40
+
41
+ def main():
42
+ p = argparse.ArgumentParser()
43
+ p.add_argument('--checkpoint', type=str, required=True)
44
+ p.add_argument('--output', type=str, default='privigaze.onnx')
45
+ p.add_argument('--opset', type=int, default=11)
46
+ args = p.parse_args()
47
+ export(args.checkpoint, args.output, args.opset)
48
+
49
+
50
+ if __name__ == "__main__":
51
+ main()