| |
| |
| |
|
|
import io
import os
from typing import Any

import onnx
import torch
import torch.nn as nn
from onnxsim import simplify as simplify_func
|
|
# Public names re-exported by ``from <this module> import *``.
__all__ = [
    "export_onnx",
]
|
|
|
|
def export_onnx(
    model: nn.Module,
    export_path: str,
    sample_inputs: Any,
    simplify: bool = True,
    opset: int = 11,
) -> None:
    """Export a model to ONNX and write it to ``export_path``.

    The model is exported by ``torch.onnx.export`` into an in-memory buffer
    and optionally passed through onnx-simplifier. Nothing is written to disk
    until the export (and simplification, if requested) has succeeded.

    Note: ``model.eval()`` is called on the given module, and the module is
    left in eval mode when this function returns.

    Args:
        model: the ``torch.nn.Module`` to export.
        export_path: destination file path. Missing parent directories are
            created; a bare filename is written to the current directory.
        sample_inputs: example inputs, passed unchanged as the ``args``
            argument of ``torch.onnx.export``.
        simplify: if True, run the exported graph through onnx-simplifier.
        opset: ONNX opset version passed to ``torch.onnx.export``.

    Raises:
        AssertionError: if onnx-simplifier reports that its check of the
            simplified model failed.
    """
    model.eval()

    buffer = io.BytesIO()
    with torch.no_grad():
        torch.onnx.export(model, sample_inputs, buffer, opset_version=opset)
    buffer.seek(0)

    if simplify:
        onnx_model = onnx.load_model(buffer)
        onnx_model, success = simplify_func(onnx_model)
        # Raise explicitly rather than using ``assert`` so the check still
        # runs under ``python -O``. AssertionError keeps the exception type
        # that existing callers may already handle.
        if not success:
            raise AssertionError(
                "onnx-simplifier failed to validate the simplified model"
            )
        buffer = io.BytesIO()
        onnx.save(onnx_model, buffer)

    data = buffer.getvalue()
    # As before, nothing is written when the export produced no bytes.
    if not data:
        return

    save_dir = os.path.dirname(export_path)
    # dirname() returns "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    with open(export_path, "wb") as f:
        f.write(data)
|
|