# app/models/onnx_utils.py # ONNX export and inference utilities from pathlib import Path import numpy as np from app.observability.logging import get_logger logger = get_logger(__name__) def export_to_onnx( model, sample_input: dict, output_path: Path, input_names: list[str] | None = None, output_names: list[str] | None = None, dynamic_axes: dict | None = None, opset_version: int = 14, ) -> Path: """ Export a PyTorch model to ONNX format. Args: model: PyTorch model (eval mode). sample_input: Dict of tensor inputs for tracing. output_path: Where to save the .onnx file. input_names: Names for input tensors. output_names: Names for output tensors. dynamic_axes: Dynamic axes specification. opset_version: ONNX opset version. Returns: Path to the exported ONNX model. """ import torch output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) if input_names is None: input_names = list(sample_input.keys()) if output_names is None: output_names = ["logits"] if dynamic_axes is None: dynamic_axes = {name: {0: "batch_size"} for name in input_names + output_names} # Prepare ordered tuple of inputs input_tuple = tuple(sample_input[name] for name in input_names) model.eval() with torch.no_grad(): torch.onnx.export( model, input_tuple, str(output_path), input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, opset_version=opset_version, do_constant_folding=True, ) logger.info("onnx_export_complete", path=str(output_path), size_mb=round(output_path.stat().st_size / 1e6, 1)) return output_path def load_onnx_session(model_path: Path, providers: list[str] | None = None): """ Load an ONNX model as an InferenceSession. Args: model_path: Path to .onnx file. providers: ONNX Runtime execution providers (defaults to CPU). Returns: ort.InferenceSession instance. """ import onnxruntime as ort if providers is None: available = ort.get_available_providers() # Prefer CUDA if available, else CPU if "CUDAExecutionProvider" in available: providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] else: providers = ["CPUExecutionProvider"] session = ort.InferenceSession(str(model_path), providers=providers) logger.info( "onnx_session_loaded", path=str(model_path), providers=providers, ) return session def onnx_inference(session, inputs: dict[str, np.ndarray]) -> list[np.ndarray]: """ Run inference on an ONNX session. Args: session: ONNX InferenceSession. inputs: Dict mapping input names to numpy arrays. Returns: List of output numpy arrays. """ # Ensure proper dtypes feed = {} for inp in session.get_inputs(): if inp.name in inputs: arr = inputs[inp.name] # Match expected dtype if "int" in inp.type: arr = arr.astype(np.int64) else: arr = arr.astype(np.float32) feed[inp.name] = arr return session.run(None, feed)