| """ |
| ONNX Export for SCRFD models. |
| |
| Supports: |
| - Static and dynamic input shapes |
| - INT8/FP16 TensorRT optimization (post-export) |
| - Validation after export |
| """ |
|
|
| import os |
| import numpy as np |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class SCRFDExportWrapper(nn.Module):
    """Wrap an SCRFD model so it exports to ONNX with two flat outputs.

    The wrapped model must expose ``backbone``, ``neck``, ``head`` and
    ``strides`` attributes. The head is expected to return a mapping with
    ``'cls_scores'`` and ``'bbox_preds'`` entries, each an indexable sequence
    holding one 4-D ``(B, C, H, W)`` tensor per entry in ``strides``.

    The sub-modules are referenced, not copied, so moving this wrapper to
    another device or changing its mode also affects the original model.
    """

    def __init__(self, model):
        """Keep references to the sub-modules and strides of ``model``."""
        super().__init__()
        self.backbone = model.backbone
        self.neck = model.neck
        self.head = model.head
        self.strides = model.strides

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the detector and flatten the per-level head outputs.

        Args:
            x: Input batch fed straight to the backbone.

        Returns:
            A ``(scores, boxes)`` tuple. ``scores`` has shape ``(B, N, 1)``
            and has had a sigmoid applied; ``boxes`` has shape ``(B, N, 4)``
            and holds the head's regression outputs unchanged (no decoding
            or stride scaling happens here). ``N`` is the total number of
            predictions over all levels.
        """
        features = self.neck(self.backbone(x))
        head_out = self.head(features)

        cls_scores = []
        bbox_preds = []

        # The number of exported levels is dictated by ``strides``; the head
        # outputs are indexed per level so a shorter head output still raises.
        for level in range(len(self.strides)):
            cls_map = head_out['cls_scores'][level]
            reg_map = head_out['bbox_preds'][level]
            batch = cls_map.shape[0]

            # Channels-last first, so that each spatial location's channels
            # stay contiguous when flattened into rows of width 1 / 4.
            cls_scores.append(cls_map.permute(0, 2, 3, 1).reshape(batch, -1, 1))
            bbox_preds.append(reg_map.permute(0, 2, 3, 1).reshape(batch, -1, 4))

        all_cls = torch.cat(cls_scores, dim=1).sigmoid()
        all_reg = torch.cat(bbox_preds, dim=1)

        return all_cls, all_reg
|
|
|
|
def export_to_onnx(
    model: nn.Module,
    output_path: str,
    input_size: int = 640,
    dynamic_batch: bool = False,
    opset_version: int = 12,
    simplify: bool = True,
    verify: bool = True,
) -> str:
    """
    Export SCRFD model to ONNX format.

    The model is wrapped in ``SCRFDExportWrapper`` and traced with a random
    ``(1, 3, input_size, input_size)`` input. The exported graph has one
    input named ``input`` and two outputs named ``scores`` and ``boxes``.

    Side effects: ``model`` is switched to eval mode and, because the wrapper
    shares its sub-modules, its backbone/neck/head are moved to the CPU in
    place. Progress is reported with ``print``.

    Args:
        model: Trained SCRFD model
        output_path: Output .onnx file path (parent directory is created)
        input_size: Model input resolution (square input)
        dynamic_batch: Enable dynamic batch size on the input and both outputs
        opset_version: ONNX opset version
        simplify: Run onnx-simplifier after export (skipped if not installed)
        verify: Compare ONNX Runtime output against PyTorch after export
            (skipped if onnxruntime is not installed). A mismatch is
            reported on stdout; it does not raise.

    Returns:
        Path to exported ONNX model
    """
    model.eval()
    # The wrapper is a fresh module (training=True by default), so put it in
    # eval mode explicitly as well before tracing and verification.
    wrapper = SCRFDExportWrapper(model).cpu().eval()

    dummy_input = torch.randn(1, 3, input_size, input_size)

    dynamic_axes = None
    if dynamic_batch:
        dynamic_axes = {
            'input': {0: 'batch_size'},
            'scores': {0: 'batch_size'},
            'boxes': {0: 'batch_size'},
        }

    # dirname is '' for a bare file name; fall back to the current directory.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

    print(f"Exporting ONNX to {output_path}...")
    torch.onnx.export(
        wrapper,
        dummy_input,
        output_path,
        input_names=['input'],
        output_names=['scores', 'boxes'],
        dynamic_axes=dynamic_axes,
        opset_version=opset_version,
        do_constant_folding=True,
    )
    print(f" Export complete: {output_path}")

    if simplify:
        # Only the optional imports are guarded, so an ImportError raised
        # from inside the simplifier is not mistaken for a missing package.
        try:
            import onnxsim
            import onnx
        except ImportError:
            print(" Skipping simplification (install onnxsim: pip install onnxsim)")
        else:
            model_onnx = onnx.load(output_path)
            model_onnx, check = onnxsim.simplify(model_onnx)
            if check:
                onnx.save(model_onnx, output_path)
                print(" Simplified ONNX model")
            else:
                # Keep the unsimplified export on disk.
                print(" Warning: ONNX simplification check failed")

    if verify:
        try:
            import onnxruntime as ort
        except ImportError:
            print(" Skipping verification (install onnxruntime)")
        else:
            # Pin the CPU provider: the reference outputs are computed on CPU,
            # and GPU-enabled onnxruntime builds require an explicit list.
            session = ort.InferenceSession(
                output_path, providers=['CPUExecutionProvider']
            )
            ort_inputs = {session.get_inputs()[0].name: dummy_input.numpy()}
            ort_outputs = session.run(None, ort_inputs)

            with torch.no_grad():
                pt_outputs = wrapper(dummy_input)

            tolerance = 0.01
            all_close = True
            for i, (pt_out, ort_out) in enumerate(zip(pt_outputs, ort_outputs)):
                diff = float(np.abs(pt_out.numpy() - ort_out).max())
                print(f" Output {i} max diff: {diff:.6f}")
                # ``not (diff <= tol)`` also treats a NaN difference as failure.
                if not diff <= tolerance:
                    print(f" WARNING: Large difference in output {i}!")
                    all_close = False

            # Only claim success when every output stayed within tolerance.
            if all_close:
                print(" Verification passed ✓")
            else:
                print(" Verification FAILED: ONNX and PyTorch outputs differ")

    size_mb = os.path.getsize(output_path) / 1e6
    print(f" Model size: {size_mb:.1f} MB")

    return output_path
|
|