cledouxluma commited on
Commit
b734cb2
·
verified ·
1 Parent(s): 307c3fe

Upload scripts/export.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/export.py +67 -0
scripts/export.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export SCRFD model to ONNX for deployment.
4
+
5
+ Usage:
6
+ python scripts/export.py \\
7
+ --model scrfd_34g \\
8
+ --checkpoint checkpoints/scrfd_34g_best.pth \\
9
+ --output deploy/scrfd_34g.onnx \\
10
+ --input-size 640
11
+ """
12
+
13
+ import os
14
+ import sys
15
+ import argparse
16
+ from pathlib import Path
17
+
18
+ import torch
19
+
20
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
21
+
22
+ from models.detector import build_detector
23
+ from deploy.export_onnx import export_to_onnx
24
+ from deploy.optimize import benchmark_deployment
25
+
26
+
27
+ def parse_args():
28
+ parser = argparse.ArgumentParser(description='Export SCRFD to ONNX')
29
+ parser.add_argument('--model', type=str, default='scrfd_34g')
30
+ parser.add_argument('--checkpoint', type=str, required=True)
31
+ parser.add_argument('--output', type=str, default='deploy/scrfd_34g.onnx')
32
+ parser.add_argument('--input-size', type=int, default=640)
33
+ parser.add_argument('--dynamic-batch', action='store_true')
34
+ parser.add_argument('--simplify', action='store_true', default=True)
35
+ parser.add_argument('--benchmark', action='store_true', default=True)
36
+ return parser.parse_args()
37
+
38
+
39
+ def main():
40
+ args = parse_args()
41
+
42
+ # Load model
43
+ model = build_detector(args.model)
44
+ checkpoint = torch.load(args.checkpoint, map_location='cpu')
45
+ state_dict = checkpoint.get('model_state_dict', checkpoint)
46
+ model.load_state_dict(state_dict, strict=False)
47
+ model.eval()
48
+
49
+ # Export
50
+ export_to_onnx(
51
+ model=model,
52
+ output_path=args.output,
53
+ input_size=args.input_size,
54
+ dynamic_batch=args.dynamic_batch,
55
+ simplify=args.simplify,
56
+ )
57
+
58
+ # Benchmark
59
+ if args.benchmark:
60
+ print("\nBenchmarking ONNX model...")
61
+ results = benchmark_deployment(args.output, input_size=args.input_size)
62
+ for k, v in results.items():
63
+ print(f" {k}: {v}")
64
+
65
+
66
+ if __name__ == '__main__':
67
+ main()