File size: 1,927 Bytes
b734cb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/usr/bin/env python3
"""
Export SCRFD model to ONNX for deployment.

Usage:
    python scripts/export.py \\
        --model scrfd_34g \\
        --checkpoint checkpoints/scrfd_34g_best.pth \\
        --output deploy/scrfd_34g.onnx \\
        --input-size 640
"""

import os
import sys
import argparse
from pathlib import Path

import torch

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from models.detector import build_detector
from deploy.export_onnx import export_to_onnx
from deploy.optimize import benchmark_deployment


def parse_args():
    """Parse the command-line options of the ONNX export script.

    ``--simplify`` and ``--benchmark`` are on by default. Because a
    ``store_true`` flag whose default is already ``True`` can never produce
    ``False``, each of them has a ``--no-*`` counterpart that writes ``False``
    to the same destination.

    Returns:
        argparse.Namespace: with attributes ``model``, ``checkpoint``,
        ``output``, ``input_size``, ``dynamic_batch``, ``simplify`` and
        ``benchmark``.
    """
    parser = argparse.ArgumentParser(description='Export SCRFD to ONNX')
    parser.add_argument('--model', type=str, default='scrfd_34g')
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--output', type=str, default='deploy/scrfd_34g.onnx')
    parser.add_argument('--input-size', type=int, default=640)
    parser.add_argument('--dynamic-batch', action='store_true')

    # The positive flags are kept so existing command lines keep working;
    # the negative ones are the only way to actually switch the option off.
    parser.add_argument('--simplify', dest='simplify', action='store_true',
                        default=True,
                        help='enable the simplify option (default: enabled)')
    parser.add_argument('--no-simplify', dest='simplify',
                        action='store_false',
                        help='disable the simplify option')
    parser.add_argument('--benchmark', dest='benchmark', action='store_true',
                        default=True,
                        help='benchmark after export (default: enabled)')
    parser.add_argument('--no-benchmark', dest='benchmark',
                        action='store_false',
                        help='skip the benchmark after export')
    return parser.parse_args()


def main():
    """Load a checkpoint into the detector, export it to ONNX, then benchmark.

    Steps, in order:
      1. Parse the command line with ``parse_args``.
      2. Build the detector named by ``--model`` and load the checkpoint
         weights on CPU (non-strict; mismatched keys are reported on stderr).
      3. Export the model to ``--output`` via ``export_to_onnx``.
      4. Unless benchmarking is disabled, run ``benchmark_deployment`` on the
         exported file and print each result entry.
    """
    args = parse_args()

    # Load model
    model = build_detector(args.model)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load trusted checkpoints; consider weights_only=True if the
    # checkpoints contain nothing but tensors and the torch version allows it.
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    # Accept both a training checkpoint (weights under 'model_state_dict')
    # and a bare state dict.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    load_result = model.load_state_dict(state_dict, strict=False)

    # strict=False tolerates key mismatches, so a wrong checkpoint would
    # otherwise export a partly or fully uninitialised model with no hint.
    # getattr keeps this safe if the model's load_state_dict returns something
    # other than torch's (missing_keys, unexpected_keys) result.
    missing_keys = getattr(load_result, 'missing_keys', None) or []
    unexpected_keys = getattr(load_result, 'unexpected_keys', None) or []
    if missing_keys:
        print(
            f"Warning: {len(missing_keys)} model parameter(s) not found in "
            f"checkpoint: {list(missing_keys)}",
            file=sys.stderr,
        )
    if unexpected_keys:
        print(
            f"Warning: {len(unexpected_keys)} checkpoint key(s) not used by "
            f"the model: {list(unexpected_keys)}",
            file=sys.stderr,
        )
    model.eval()

    # Make sure the destination directory exists before exporting.
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)

    # Export
    export_to_onnx(
        model=model,
        output_path=args.output,
        input_size=args.input_size,
        dynamic_batch=args.dynamic_batch,
        simplify=args.simplify,
    )

    # Benchmark
    if args.benchmark:
        print("\nBenchmarking ONNX model...")
        results = benchmark_deployment(args.output, input_size=args.input_size)
        for k, v in results.items():
            print(f"  {k}: {v}")


# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()