"""
ONNX Export for SCRFD models.

Supports:
    - Static and dynamic input shapes
    - INT8/FP16 TensorRT optimization (post-export)
    - Validation after export
"""

import os
import numpy as np
from typing import Optional, Tuple

import torch
import torch.nn as nn


class SCRFDExportWrapper(nn.Module):
    """Wrap SCRFD for clean ONNX export (flatten outputs).

    The wrapper reuses (does not copy) ``model.backbone``, ``model.neck`` and
    ``model.head``; moving the wrapper to a device or switching its mode
    therefore also affects those sub-modules of the original model.
    """

    def __init__(self, model):
        super().__init__()
        self.backbone = model.backbone
        self.neck = model.neck
        self.head = model.head
        self.strides = model.strides

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run backbone -> neck -> head and flatten the per-level outputs.

        Args:
            x: Input batch; presumably (B, 3, H, W) images -- TODO confirm
                against the backbone.

        Returns:
            ``(scores, boxes)``. ``scores`` has shape (B, N, 1) and is passed
            through a sigmoid; ``boxes`` has shape (B, N, 4) and is the raw
            regression output. N is the sum over levels of H * W * C for that
            level's maps; levels are concatenated in ``self.strides`` order.
        """
        head_out = self.head(self.neck(self.backbone(x)))
        cls_maps = head_out['cls_scores']
        reg_maps = head_out['bbox_preds']

        cls_scores = []
        bbox_preds = []
        # One entry per stride level; indexing (rather than zip) keeps the
        # original IndexError if the head returns fewer levels than strides.
        for level, _stride in enumerate(self.strides):
            cls_map = cls_maps[level]
            batch = cls_map.shape[0]
            # NCHW -> NHWC keeps each location's channels contiguous, then
            # (H, W, C) is folded into a single prediction axis.
            cls_scores.append(cls_map.permute(0, 2, 3, 1).reshape(batch, -1, 1))
            bbox_preds.append(
                reg_maps[level].permute(0, 2, 3, 1).reshape(batch, -1, 4)
            )

        all_cls = torch.cat(cls_scores, dim=1).sigmoid()
        all_reg = torch.cat(bbox_preds, dim=1)
        return all_cls, all_reg


def _simplify_onnx(output_path: str) -> None:
    """Best-effort in-place onnx-simplifier pass over the file at ``output_path``."""
    try:
        import onnx
        import onnxsim
    except ImportError:
        print(" Skipping simplification (install onnxsim: pip install onnxsim)")
        return

    try:
        model_onnx, check = onnxsim.simplify(onnx.load(output_path))
    except Exception as err:  # onnxsim raises assorted error types
        # The unsimplified export on disk is still valid; don't abort on an
        # optional optimisation step.
        print(f" Warning: ONNX simplification failed ({err}); keeping unsimplified model")
        return

    if check:
        onnx.save(model_onnx, output_path)
        print(" Simplified ONNX model")
    else:
        print(" Warning: ONNX simplification check failed")


def _verify_onnx(
    wrapper: nn.Module,
    output_path: str,
    dummy_input: torch.Tensor,
    atol: float = 0.01,
) -> None:
    """Compare ONNX Runtime outputs against the (CPU) PyTorch wrapper.

    Prints the max absolute difference per output and reports failure when
    any output differs by more than ``atol``. Skips if onnxruntime is absent.
    """
    try:
        import onnxruntime as ort
    except ImportError:
        print(" Skipping verification (install onnxruntime)")
        return

    # Since ORT 1.9, builds exposing several providers (e.g. onnxruntime-gpu)
    # require ``providers`` explicitly; CPU also matches the PyTorch reference.
    session = ort.InferenceSession(output_path, providers=['CPUExecutionProvider'])
    ort_inputs = {session.get_inputs()[0].name: dummy_input.numpy()}
    ort_outputs = session.run(None, ort_inputs)

    # Compare with PyTorch output
    with torch.no_grad():
        pt_outputs = wrapper(dummy_input)

    passed = True
    for i, (pt_out, ort_out) in enumerate(zip(pt_outputs, ort_outputs)):
        diff = np.abs(pt_out.numpy() - ort_out).max()
        print(f" Output {i} max diff: {diff:.6f}")
        if diff > atol:
            print(f" WARNING: Large difference in output {i}!")
            passed = False

    if passed:
        print(" Verification passed ✓")
    else:
        print(" Verification FAILED: PyTorch and ONNX Runtime outputs differ")


def export_to_onnx(
    model: nn.Module,
    output_path: str,
    input_size: int = 640,
    dynamic_batch: bool = False,
    opset_version: int = 12,
    simplify: bool = True,
    verify: bool = True,
) -> str:
    """
    Export SCRFD model to ONNX format.

    The export runs on CPU in eval mode. The sub-modules shared with the
    export wrapper (backbone/neck/head) are moved back to their original
    device, and ``model``'s training mode is restored, once export finishes
    (including on error).

    Args:
        model: Trained SCRFD model
        output_path: Output .onnx file path
        input_size: Model input resolution
        dynamic_batch: Enable dynamic batch size
        opset_version: ONNX opset version
        simplify: Run onnx-simplifier after export
        verify: Run verification after export

    Returns:
        Path to exported ONNX model
    """
    wrapper = SCRFDExportWrapper(model)

    # The wrapper shares sub-modules with ``model``, so the .cpu()/.eval()
    # below mutate the caller's model; remember the state to restore it.
    was_training = model.training
    first_param = next(wrapper.parameters(), None)
    original_device = None if first_param is None else first_param.device

    model.eval()
    wrapper.cpu().eval()
    try:
        dummy_input = torch.randn(1, 3, input_size, input_size)

        # Dynamic axes
        dynamic_axes: Optional[dict] = None
        if dynamic_batch:
            dynamic_axes = {
                'input': {0: 'batch_size'},
                'scores': {0: 'batch_size'},
                'boxes': {0: 'batch_size'},
            }

        os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

        print(f"Exporting ONNX to {output_path}...")
        torch.onnx.export(
            wrapper,
            dummy_input,
            output_path,
            input_names=['input'],
            output_names=['scores', 'boxes'],
            dynamic_axes=dynamic_axes,
            opset_version=opset_version,
            do_constant_folding=True,
        )
        print(f" Export complete: {output_path}")

        if simplify:
            _simplify_onnx(output_path)

        if verify:
            _verify_onnx(wrapper, output_path, dummy_input)
    finally:
        if original_device is not None:
            wrapper.to(original_device)
        model.train(was_training)

    # File size
    size_mb = os.path.getsize(output_path) / 1e6
    print(f" Model size: {size_mb:.1f} MB")

    return output_path