File size: 3,404 Bytes
71c1ad2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# app/models/onnx_utils.py
# ONNX export and inference utilities

from pathlib import Path
import numpy as np
from app.observability.logging import get_logger

# Module-level logger; the functions in this file emit events through it using keyword fields.
logger = get_logger(__name__)


def export_to_onnx(
    model,
    sample_input: dict,
    output_path: Path,
    input_names: list[str] | None = None,
    output_names: list[str] | None = None,
    dynamic_axes: dict | None = None,
    opset_version: int = 14,
) -> Path:
    """
    Serialize a PyTorch model to an ONNX file.

    The model is switched to eval mode and exported under ``torch.no_grad()``.
    Parent directories of ``output_path`` are created when missing.

    Args:
        model: PyTorch model to export.
        sample_input: Dict of example inputs, keyed by input name.
        output_path: Destination of the .onnx file.
        input_names: Input names, in the order they are passed to the model.
            Defaults to the keys of ``sample_input``.
        output_names: Output names. Defaults to ``["logits"]``.
        dynamic_axes: Dynamic axes specification. Defaults to marking axis 0
            of every named input and output as ``"batch_size"``.
        opset_version: ONNX opset version.

    Returns:
        Path to the exported ONNX model.
    """
    import torch

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Fill in whichever naming arguments the caller left out.
    input_names = list(sample_input) if input_names is None else input_names
    output_names = ["logits"] if output_names is None else output_names
    if dynamic_axes is None:
        dynamic_axes = {}
        for tensor_name in input_names + output_names:
            dynamic_axes[tensor_name] = {0: "batch_size"}

    # The exporter receives positional inputs, so order them by input_names.
    positional_inputs = tuple(sample_input[key] for key in input_names)

    export_options = {
        "input_names": input_names,
        "output_names": output_names,
        "dynamic_axes": dynamic_axes,
        "opset_version": opset_version,
        "do_constant_folding": True,
    }

    model.eval()
    with torch.no_grad():
        torch.onnx.export(model, positional_inputs, str(output_path), **export_options)

    size_mb = round(output_path.stat().st_size / 1e6, 1)
    logger.info("onnx_export_complete", path=str(output_path), size_mb=size_mb)
    return output_path


def load_onnx_session(model_path: Path, providers: list[str] | None = None):
    """
    Create an ONNX Runtime InferenceSession for a model file.

    Args:
        model_path: Path to .onnx file.
        providers: ONNX Runtime execution providers. When omitted, CUDA is
            listed first (with CPU after it) if ONNX Runtime reports it as
            available; otherwise only the CPU provider is used.

    Returns:
        ort.InferenceSession instance.
    """
    import onnxruntime as ort

    if providers is None:
        cuda_provider = "CUDAExecutionProvider"
        cpu_provider = "CPUExecutionProvider"
        if cuda_provider in ort.get_available_providers():
            providers = [cuda_provider, cpu_provider]
        else:
            providers = [cpu_provider]

    session = ort.InferenceSession(str(model_path), providers=providers)
    logger.info("onnx_session_loaded", path=str(model_path), providers=providers)
    return session


# ONNX Runtime describes each input's element type with a string such as
# "tensor(int32)". This table gives the numpy dtype to cast to for each one.
_ORT_TYPE_TO_NUMPY_DTYPE = {
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
    "tensor(int8)": np.int8,
    "tensor(int16)": np.int16,
    "tensor(int32)": np.int32,
    "tensor(int64)": np.int64,
    "tensor(uint8)": np.uint8,
    "tensor(uint16)": np.uint16,
    "tensor(uint32)": np.uint32,
    "tensor(uint64)": np.uint64,
    "tensor(bool)": np.bool_,
}


def onnx_inference(session, inputs: dict[str, np.ndarray]) -> list[np.ndarray]:
    """
    Run inference on an ONNX session.

    Each supplied array is cast to the element type the session declares for
    that input before the session is run. Entries of ``inputs`` whose name the
    session does not declare are ignored, and declared inputs that are absent
    from ``inputs`` are left out of the feed.

    Args:
        session: ONNX InferenceSession.
        inputs: Dict mapping input names to numpy arrays (array-likes such as
            lists are converted with ``np.asarray``).

    Returns:
        List of output numpy arrays, as returned by ``session.run``.
    """
    feed = {}
    for inp in session.get_inputs():
        if inp.name not in inputs:
            continue

        # Cast to the exact declared type; forcing every integer input to
        # int64 and everything else to float32 produces the wrong dtype for
        # inputs declared as e.g. int32, float16, double or bool.
        target_dtype = _ORT_TYPE_TO_NUMPY_DTYPE.get(inp.type)
        if target_dtype is None:
            # Type string not in the table: keep the coarse int/float split.
            target_dtype = np.int64 if "int" in inp.type else np.float32

        feed[inp.name] = np.asarray(inputs[inp.name]).astype(target_dtype)

    return session.run(None, feed)