| """ |
| Export PriviGaze student to ONNX for on-device deployment. |
| |
| Usage: |
| python export_onnx.py --checkpoint ./checkpoints/student_best.pt --output privigaze.onnx |
| """ |
| import argparse |
| import torch |
| from models.student import PriviGazeStudent |
|
|
|
|
def export(checkpoint_path, output_path, opset=11, *, image_size=224):
    """Export a trained PriviGaze student checkpoint to an ONNX file.

    The checkpoint is loaded on CPU. If it contains a
    ``'student_state_dict'`` entry, that entry is used as the model's state
    dict; otherwise the loaded object itself is treated as the state dict.
    The model is traced with a random single-channel dummy input of shape
    ``(1, 1, image_size, image_size)``. Only the batch dimension of every
    input/output is exported as dynamic, so the resulting graph expects
    exactly ``image_size`` x ``image_size`` inputs. After export, the file
    is validated with ``onnx.checker``.

    Args:
        checkpoint_path: Path to a checkpoint written with ``torch.save``.
        output_path: Destination path of the ``.onnx`` file.
        opset: ONNX opset version passed to ``torch.onnx.export``.
        image_size: Height and width of the dummy tracing input. Defaults to
            224, the value that was previously hard-coded.

    Returns:
        ``output_path``, for callers that want to post-process the file.

    Raises:
        ImportError: If the ``onnx`` package (used for validation) is missing.
        onnx.checker.ValidationError: If the exported model fails validation.
    """
    model = PriviGazeStudent()
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code -- only pass checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    # Training checkpoints may bundle the student weights under
    # 'student_state_dict'; a bare state dict is accepted unchanged.
    model.load_state_dict(ckpt.get('student_state_dict', ckpt))
    # Inference mode, so layers such as dropout / batch-norm (if present)
    # are traced with their evaluation behaviour.
    model.eval()

    dummy = torch.randn(1, 1, image_size, image_size)
    batch_axis = {0: 'batch_size'}
    torch.onnx.export(
        model,
        dummy,
        output_path,
        input_names=['face_gray'],
        output_names=['pitch', 'yaw', 'features'],
        dynamic_axes={
            'face_gray': batch_axis,
            'pitch': batch_axis,
            'yaw': batch_axis,
            'features': batch_axis,
        },
        opset_version=opset,
        do_constant_folding=True,
    )
    print(f"Exported to {output_path}")

    import onnx  # local import: only the validation step needs it

    # Checking by path avoids loading the whole model into memory a second
    # time and also works for models above the 2 GB protobuf limit.
    onnx.checker.check_model(output_path)
    print("ONNX model validated OK")
    return output_path
|
|
|
|
def main():
    """Parse command-line options and export the student model to ONNX."""
    p = argparse.ArgumentParser(
        description="Export the PriviGaze student model to ONNX.",
    )
    p.add_argument('--checkpoint', type=str, required=True,
                   help="path to the trained student checkpoint (.pt)")
    p.add_argument('--output', type=str, default='privigaze.onnx',
                   help="destination ONNX file (default: %(default)s)")
    p.add_argument('--opset', type=int, default=11,
                   help="ONNX opset version (default: %(default)s)")
    args = p.parse_args()
    export(args.checkpoint, args.output, args.opset)
|
|
|
|
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script,
    # not when this module is imported.
    main()
|
|