# privi-gaze-distill / export_onnx.py
# BcantCode's picture
# Upload export_onnx.py
# bd3b08e verified
"""
Export PriviGaze student to ONNX for on-device deployment.

Usage:
    python export_onnx.py --checkpoint ./checkpoints/student_best.pt --output privigaze.onnx
"""
import argparse
import torch
from models.student import PriviGazeStudent
def export(checkpoint_path, output_path, opset=11):
    """Export the PriviGaze student network to ONNX and validate the result.

    Args:
        checkpoint_path: Path handed to ``torch.load``. The loaded object must
            provide ``.get``; its ``'student_state_dict'`` entry is used when
            present, otherwise the loaded object itself is treated as the
            state dict.
        output_path: Destination given to ``torch.onnx.export``; the same
            path is re-read afterwards for validation.
        opset: ONNX opset version passed to the exporter (default 11).

    Raises:
        Nothing is caught here: errors from ``torch.load``,
        ``load_state_dict``, ``torch.onnx.export``, importing ``onnx`` or
        ``onnx.checker.check_model`` propagate to the caller.
    """
    student = PriviGazeStudent()

    # NOTE(review): depending on the installed torch version's ``weights_only``
    # default, torch.load may unpickle arbitrary objects -- only feed it
    # checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint.get('student_state_dict', checkpoint)
    student.load_state_dict(state_dict)
    student.eval()

    # Tracing input: batch of one, single channel, 224x224.
    example_input = torch.randn(1, 1, 224, 224)

    input_names = ['face_gray']
    output_names = ['pitch', 'yaw', 'features']
    # Mark dimension 0 of every named tensor as a dynamic batch axis.
    batch_axes = {
        tensor_name: {0: 'batch_size'}
        for tensor_name in input_names + output_names
    }

    torch.onnx.export(
        student,
        example_input,
        output_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=batch_axes,
        opset_version=opset,
        do_constant_folding=True,
    )
    print(f"Exported to {output_path}")

    # Reload the written file and run the ONNX checker on it. The import is
    # deliberately local so ``onnx`` is only required at this point.
    import onnx

    exported = onnx.load(output_path)
    onnx.checker.check_model(exported)
    print("ONNX model validated OK")
def main():
    """Parse the command-line options and run the ONNX export.

    Options:
        --checkpoint: required string, forwarded as the checkpoint path.
        --output: string, defaults to ``'privigaze.onnx'``.
        --opset: integer, defaults to 11.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--checkpoint', {'type': str, 'required': True}),
        ('--output', {'type': str, 'default': 'privigaze.onnx'}),
        ('--opset', {'type': int, 'default': 11}),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    cli = parser.parse_args()
    export(cli.checkpoint, cli.output, cli.opset)


if __name__ == "__main__":
    main()