File size: 1,458 Bytes
bd3b08e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 | """
Export the PriviGaze student model to ONNX for on-device deployment.
Usage:
python export_onnx.py --checkpoint ./checkpoints/student_best.pt --output privigaze.onnx
"""
import argparse
import torch
from models.student import PriviGazeStudent
def export(checkpoint_path, output_path, opset=11, input_size=224):
    """Export a PriviGaze student checkpoint to an ONNX file and validate it.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load``. If the
            loaded object has a ``'student_state_dict'`` entry, that entry is
            used as the state dict. Otherwise the loaded object itself is
            treated as the state dict.
        output_path: Destination path for the ``.onnx`` file.
        opset: ONNX opset version passed to ``torch.onnx.export``.
        input_size: Height and width of the random tracing input. The model
            is traced with a tensor of shape ``(1, 1, input_size, input_size)``.

    The exported graph has one input, ``face_gray``, and three declared
    outputs: ``pitch``, ``yaw`` and ``features``. Axis 0 of every one of
    them is marked dynamic as ``batch_size``. After writing the file, the
    model is reloaded and checked with ``onnx.checker.check_model``, which
    raises if the graph is malformed.
    """
    model = PriviGazeStudent()
    # NOTE(review): torch.load unpickles arbitrary Python objects, so only
    # load checkpoints from trusted sources. On torch versions that support
    # it, consider passing weights_only=True.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    # Accept either a training checkpoint wrapping the student weights or a
    # bare state dict.
    state_dict = ckpt.get('student_state_dict', ckpt)
    model.load_state_dict(state_dict)
    model.eval()

    dummy = torch.randn(1, 1, input_size, input_size)
    tensor_names = ('face_gray', 'pitch', 'yaw', 'features')
    # Mark only the leading (batch) axis as dynamic; each tensor gets its
    # own mapping dict.
    dynamic_axes = {name: {0: 'batch_size'} for name in tensor_names}

    # Tracing for export does not need gradients.
    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy,
            output_path,
            input_names=['face_gray'],
            output_names=['pitch', 'yaw', 'features'],
            dynamic_axes=dynamic_axes,
            opset_version=opset,
            do_constant_folding=True,
        )
    print(f"Exported to {output_path}")

    # Verify: reload the written file and run the ONNX structural checker.
    # onnx is only needed for this step, so it is imported locally.
    import onnx
    onnx.checker.check_model(onnx.load(output_path))
    print("ONNX model validated OK")
def main():
    """Command-line entry point: read export options and run the ONNX export.

    Options:
        --checkpoint  (required) path to the checkpoint to export.
        --output      destination .onnx path, default 'privigaze.onnx'.
        --opset       ONNX opset version, default 11.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--output', type=str, default='privigaze.onnx')
    parser.add_argument('--opset', type=int, default=11)
    cli = parser.parse_args()

    export(
        checkpoint_path=cli.checkpoint,
        output_path=cli.output,
        opset=cli.opset,
    )
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()
|