cledouxluma committed on
Commit
8299fbe
·
verified ·
1 Parent(s): f7417f1

Upload deploy/export_onnx.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. deploy/export_onnx.py +144 -0
deploy/export_onnx.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ONNX Export for SCRFD models.
3
+
4
+ Supports:
5
+ - Static and dynamic input shapes
6
+ - INT8/FP16 TensorRT optimization (post-export)
7
+ - Validation after export
8
+ """
9
+
10
+ import os
11
+ import numpy as np
12
+ from typing import Optional, Tuple
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+
18
class SCRFDExportWrapper(nn.Module):
    """Wrap an SCRFD model so that ONNX export sees two flat output tensors.

    The wrapped model must expose ``backbone``, ``neck``, ``head`` and
    ``strides``. The head is expected to return a dict whose ``'cls_scores'``
    and ``'bbox_preds'`` entries are sequences of 4-D ``(B, C, H, W)`` tensors
    indexable by pyramid level; only the first ``len(strides)`` levels are
    consumed.
    """

    def __init__(self, model):
        super().__init__()
        # The submodules are shared with ``model`` (not copied), so moving or
        # switching the mode of this wrapper also affects the original model.
        self.backbone = model.backbone
        self.neck = model.neck
        self.head = model.head
        self.strides = model.strides

    @staticmethod
    def _flatten_level(pred: torch.Tensor, channels: int) -> torch.Tensor:
        """Turn one ``(B, C, H, W)`` map into a ``(B, -1, channels)`` tensor."""
        # Channels-last first, so each group of ``channels`` consecutive
        # values in the flattened tensor comes from a single spatial location.
        return pred.permute(0, 2, 3, 1).reshape(pred.shape[0], -1, channels)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the detector and concatenate the per-level predictions.

        Args:
            x: Input batch passed straight to the backbone.

        Returns:
            ``(scores, boxes)`` where ``scores`` has shape ``(B, N, 1)`` with
            a sigmoid already applied, and ``boxes`` has shape ``(B, N, 4)``
            holding the raw head regression outputs. ``N`` is the total
            number of predictions over all levels.
        """
        features = self.neck(self.backbone(x))
        head_out = self.head(features)

        # Index by stride count: extra head outputs are ignored and a missing
        # level raises IndexError.
        levels = range(len(self.strides))
        cls_scores = [
            self._flatten_level(head_out['cls_scores'][i], 1) for i in levels
        ]
        bbox_preds = [
            self._flatten_level(head_out['bbox_preds'][i], 4) for i in levels
        ]

        all_cls = torch.cat(cls_scores, dim=1).sigmoid()
        all_reg = torch.cat(bbox_preds, dim=1)

        return all_cls, all_reg
48
+
49
+
50
def export_to_onnx(
    model: nn.Module,
    output_path: str,
    input_size: int = 640,
    dynamic_batch: bool = False,
    opset_version: int = 12,
    simplify: bool = True,
    verify: bool = True,
) -> str:
    """
    Export SCRFD model to ONNX format.

    The model is wrapped in ``SCRFDExportWrapper`` and traced on CPU with a
    random ``(1, 3, input_size, input_size)`` input. The wrapper shares the
    model's submodules, so ``model`` is left in eval mode with those
    submodules on CPU.

    Args:
        model: Trained SCRFD model
        output_path: Output .onnx file path (parent directory is created)
        input_size: Model input resolution (square input)
        dynamic_batch: Enable dynamic batch size
        opset_version: ONNX opset version
        simplify: Run onnx-simplifier after export (skipped with a message
            if ``onnxsim``/``onnx`` is not installed)
        verify: Compare ONNX Runtime outputs with PyTorch outputs after
            export (skipped with a message if ``onnxruntime`` is not
            installed). Mismatches are reported on stdout, not raised.

    Returns:
        Path to exported ONNX model
    """
    model.eval()
    wrapper = SCRFDExportWrapper(model).cpu().eval()

    dummy_input = torch.randn(1, 3, input_size, input_size)

    # Only the batch axis is ever marked dynamic; spatial dims stay fixed.
    dynamic_axes = None
    if dynamic_batch:
        dynamic_axes = {
            'input': {0: 'batch_size'},
            'scores': {0: 'batch_size'},
            'boxes': {0: 'batch_size'},
        }

    # ``dirname`` is '' for a bare filename; fall back to the current dir.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

    print(f"Exporting ONNX to {output_path}...")
    torch.onnx.export(
        wrapper,
        dummy_input,
        output_path,
        input_names=['input'],
        output_names=['scores', 'boxes'],
        dynamic_axes=dynamic_axes,
        opset_version=opset_version,
        do_constant_folding=True,
    )
    print(f" Export complete: {output_path}")

    # Simplify
    if simplify:
        # Keep the try body to the optional imports so that an unrelated
        # ImportError raised later is not misreported as "not installed".
        try:
            import onnx
            import onnxsim
        except ImportError:
            print(" Skipping simplification (install onnxsim: pip install onnxsim)")
        else:
            model_onnx, check = onnxsim.simplify(onnx.load(output_path))
            if check:
                onnx.save(model_onnx, output_path)
                print(" Simplified ONNX model")
            else:
                # Leave the un-simplified export in place.
                print(" Warning: ONNX simplification check failed")

    # Verify
    if verify:
        try:
            import onnxruntime as ort
        except ImportError:
            print(" Skipping verification (install onnxruntime)")
        else:
            # Pin the CPU provider: the PyTorch reference runs on CPU, and
            # onnxruntime builds with extra execution providers require an
            # explicit ``providers`` argument.
            session = ort.InferenceSession(
                output_path, providers=['CPUExecutionProvider']
            )
            ort_inputs = {session.get_inputs()[0].name: dummy_input.numpy()}
            ort_outputs = session.run(None, ort_inputs)

            # Compare with PyTorch output
            with torch.no_grad():
                pt_outputs = wrapper(dummy_input)

            tolerance = 0.01
            all_within_tolerance = True
            for i, (pt_out, ort_out) in enumerate(zip(pt_outputs, ort_outputs)):
                diff = np.abs(pt_out.numpy() - ort_out).max()
                print(f" Output {i} max diff: {diff:.6f}")
                if diff > tolerance:
                    all_within_tolerance = False
                    print(f" WARNING: Large difference in output {i}!")

            # Only claim success when no output exceeded the tolerance.
            if all_within_tolerance:
                print(" Verification passed ✓")
            else:
                print(" WARNING: Verification found outputs exceeding tolerance")

    # File size
    size_mb = os.path.getsize(output_path) / 1e6
    print(f" Model size: {size_mb:.1f} MB")

    return output_path