facedet / scripts /export.py
cledouxluma's picture
Upload scripts/export.py with huggingface_hub
b734cb2 verified
raw
history blame
1.93 kB
#!/usr/bin/env python3
"""
Export SCRFD model to ONNX for deployment.
Usage:
python scripts/export.py \\
--model scrfd_34g \\
--checkpoint checkpoints/scrfd_34g_best.pth \\
--output deploy/scrfd_34g.onnx \\
--input-size 640
"""
import os
import sys
import argparse
from pathlib import Path
import torch
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from models.detector import build_detector
from deploy.export_onnx import export_to_onnx
from deploy.optimize import benchmark_deployment
def parse_args(argv=None):
    """Parse command-line options for the ONNX export script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets the parser be driven programmatically.

    Returns:
        argparse.Namespace with ``model``, ``checkpoint``, ``output``,
        ``input_size``, ``dynamic_batch``, ``simplify`` and ``benchmark``.
    """
    parser = argparse.ArgumentParser(description='Export SCRFD to ONNX')
    parser.add_argument('--model', type=str, default='scrfd_34g')
    parser.add_argument('--checkpoint', type=str, required=True)
    # The output default is derived from --model after parsing, so exporting
    # another variant does not silently overwrite deploy/scrfd_34g.onnx.
    parser.add_argument('--output', type=str, default=None,
                        help='ONNX output path (default: deploy/<model>.onnx)')
    parser.add_argument('--input-size', type=int, default=640)
    parser.add_argument('--dynamic-batch', action='store_true')
    # These were `store_true` with default=True, i.e. impossible to turn off.
    # The positive flags are kept for backward compatibility; the --no-*
    # switches write False to the same destination (last flag given wins).
    parser.add_argument('--simplify', dest='simplify', action='store_true',
                        default=True)
    parser.add_argument('--no-simplify', dest='simplify', action='store_false',
                        help='skip ONNX graph simplification')
    parser.add_argument('--benchmark', dest='benchmark', action='store_true',
                        default=True)
    parser.add_argument('--no-benchmark', dest='benchmark',
                        action='store_false',
                        help='skip benchmarking the exported model')
    args = parser.parse_args(argv)
    if args.output is None:
        args.output = f'deploy/{args.model}.onnx'
    return args
def main():
    """Load a trained SCRFD checkpoint, export it to ONNX, optionally benchmark it.

    Options come from ``parse_args()``. The detector named by ``--model`` is
    built with ``build_detector``, weights are loaded from ``--checkpoint``,
    the graph is written to ``--output`` via ``export_to_onnx`` and, when
    benchmarking is enabled, every entry returned by ``benchmark_deployment``
    is printed.

    Raises:
        TypeError: if the checkpoint file contains neither an ``nn.Module``
            nor a dict-like state dict / training checkpoint.
    """
    args = parse_args()

    # Load model
    model = build_detector(args.model)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    if isinstance(checkpoint, torch.nn.Module):
        # A whole pickled model was saved rather than a state dict.
        state_dict = checkpoint.state_dict()
    elif isinstance(checkpoint, dict):
        # Training checkpoints nest the weights under 'model_state_dict';
        # otherwise the mapping itself is treated as the state dict.
        state_dict = checkpoint.get('model_state_dict', checkpoint)
    else:
        raise TypeError(
            f"Unsupported checkpoint type {type(checkpoint).__name__!r} "
            f"in {args.checkpoint}"
        )

    # strict=False tolerates key mismatches, but silently exporting a model
    # whose weights were not loaded is easy to miss, so report any mismatch.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(
            f"Warning: {len(incompatible.missing_keys)} model keys missing "
            f"from checkpoint (e.g. {incompatible.missing_keys[:5]})",
            file=sys.stderr,
        )
    if incompatible.unexpected_keys:
        print(
            f"Warning: {len(incompatible.unexpected_keys)} unexpected keys "
            f"in checkpoint ignored (e.g. {incompatible.unexpected_keys[:5]})",
            file=sys.stderr,
        )
    model.eval()

    # Ensure the destination directory exists before the exporter writes.
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)

    # Export
    export_to_onnx(
        model=model,
        output_path=args.output,
        input_size=args.input_size,
        dynamic_batch=args.dynamic_batch,
        simplify=args.simplify,
    )

    # Benchmark
    if args.benchmark:
        print("\nBenchmarking ONNX model...")
        results = benchmark_deployment(args.output, input_size=args.input_size)
        for k, v in results.items():
            print(f" {k}: {v}")


if __name__ == '__main__':
    main()