# facedet/deploy/export_onnx.py
# Uploaded by cledouxluma via huggingface_hub (commit 8299fbe, verified)
"""
ONNX Export for SCRFD models.
Supports:
- Static and dynamic input shapes
- INT8/FP16 TensorRT optimization (post-export)
- Validation after export
"""
import os
import numpy as np
from typing import Optional, Tuple
import torch
import torch.nn as nn
class SCRFDExportWrapper(nn.Module):
    """Thin wrapper that flattens SCRFD head outputs for ONNX export.

    Instead of per-stride feature-map dicts, ``forward`` returns two
    tensors: sigmoid class scores of shape (B, N, 1) and raw 4-channel
    box regressions of shape (B, N, 4), where N is the total location
    count summed over all strides.
    """

    def __init__(self, model):
        super().__init__()
        # Re-use the trained sub-modules directly; no weights are copied.
        self.backbone = model.backbone
        self.neck = model.neck
        self.head = model.head
        self.strides = model.strides

    def forward(self, x: torch.Tensor):
        feats = self.neck(self.backbone(x))
        outputs = self.head(feats)

        flat_cls, flat_reg = [], []
        for level in range(len(self.strides)):
            score_map = outputs['cls_scores'][level]
            batch = score_map.shape[0]
            # (B, C, H, W) -> channels-last, then flatten to a fixed last dim.
            flat_cls.append(score_map.permute(0, 2, 3, 1).reshape(batch, -1, 1))
            flat_reg.append(
                outputs['bbox_preds'][level].permute(0, 2, 3, 1).reshape(batch, -1, 4)
            )

        scores = torch.cat(flat_cls, dim=1).sigmoid()
        boxes = torch.cat(flat_reg, dim=1)
        return scores, boxes
def export_to_onnx(
    model: nn.Module,
    output_path: str,
    input_size: int = 640,
    dynamic_batch: bool = False,
    opset_version: int = 12,
    simplify: bool = True,
    verify: bool = True,
) -> str:
    """
    Export SCRFD model to ONNX format.

    NOTE(review): this mutates the caller's ``model`` — ``eval()`` is called
    on it and the wrapper shares (and moves to CPU) its sub-modules.

    Args:
        model: Trained SCRFD model (must expose backbone/neck/head/strides)
        output_path: Output .onnx file path
        input_size: Model input resolution (square, pixels)
        dynamic_batch: Enable dynamic batch size on input and both outputs
        opset_version: ONNX opset version
        simplify: Run onnx-simplifier after export (skipped if not installed)
        verify: Compare ONNX Runtime vs PyTorch outputs (skipped if
            onnxruntime is not installed)
    Returns:
        Path to exported ONNX model
    """
    model.eval()
    wrapper = SCRFDExportWrapper(model).cpu()
    dummy_input = torch.randn(1, 3, input_size, input_size)

    # Mark the batch dimension as dynamic on the input and both outputs.
    dynamic_axes = None
    if dynamic_batch:
        dynamic_axes = {
            'input': {0: 'batch_size'},
            'scores': {0: 'batch_size'},
            'boxes': {0: 'batch_size'},
        }

    # dirname is '' for bare filenames; fall back to cwd.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    print(f"Exporting ONNX to {output_path}...")
    torch.onnx.export(
        wrapper,
        dummy_input,
        output_path,
        input_names=['input'],
        output_names=['scores', 'boxes'],
        dynamic_axes=dynamic_axes,
        opset_version=opset_version,
        do_constant_folding=True,
    )
    print(f" Export complete: {output_path}")

    # Optional graph simplification (best-effort; requires onnxsim).
    if simplify:
        try:
            import onnxsim
            import onnx
            model_onnx = onnx.load(output_path)
            model_onnx, check = onnxsim.simplify(model_onnx)
            if check:
                onnx.save(model_onnx, output_path)
                print(" Simplified ONNX model")
            else:
                print(" Warning: ONNX simplification check failed")
        except ImportError:
            print(" Skipping simplification (install onnxsim: pip install onnxsim)")

    # Optional numeric verification against the PyTorch wrapper.
    if verify:
        try:
            import onnxruntime as ort
            session = ort.InferenceSession(output_path)
            ort_inputs = {session.get_inputs()[0].name: dummy_input.numpy()}
            ort_outputs = session.run(None, ort_inputs)
            with torch.no_grad():
                pt_outputs = wrapper(dummy_input)
            # BUG FIX: previously "Verification passed" was printed
            # unconditionally, even after a large-diff WARNING. Track the
            # overall result and report pass/fail accordingly.
            all_close = True
            for i, (pt_out, ort_out) in enumerate(zip(pt_outputs, ort_outputs)):
                diff = np.abs(pt_out.numpy() - ort_out).max()
                print(f" Output {i} max diff: {diff:.6f}")
                if diff > 0.01:
                    all_close = False
                    print(f" WARNING: Large difference in output {i}!")
            if all_close:
                print(" Verification passed ✓")
            else:
                print(" Verification FAILED: outputs differ beyond tolerance")
        except ImportError:
            print(" Skipping verification (install onnxruntime)")

    # Report final on-disk size (MB, decimal).
    size_mb = os.path.getsize(output_path) / 1e6
    print(f" Model size: {size_mb:.1f} MB")
    return output_path