File size: 5,677 Bytes
c0a7163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ae5b40
 
 
 
 
 
 
c0a7163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""MRI image deep-learning inference utilities.

This module is the decision-layer bridge for an externally-trained volumetric
MRI model. The training code can live outside this repo; production only needs
an ONNX artifact plus the preprocessing contract below.
"""
from __future__ import annotations

from pathlib import Path
from typing import Any, Sequence

import nibabel as nib
import numpy as np
from scipy import ndimage as scipy_ndimage

from src.core.logger import get_logger
from src.pipelines.mri_pipeline import is_valid_volume

logger = get_logger(__name__)

# NOTE(review): not referenced in this module; presumably the default artifact
# location for callers of `load`. Relative, so it resolves against the CWD.
DEFAULT_MODEL_PATH = Path("data/processed/mri_model.onnx")
# Default `target_shape` for preprocessing: every volume is resampled to this
# (D, H, W) size before inference.
DEFAULT_TARGET_SHAPE: tuple[int, int, int] = (64, 64, 64)
# Fallback class names used by `predict_with_proba` when no labels are given.
DEFAULT_LABEL_NAMES: tuple[str, ...] = ("class_0", "class_1")
# Reference std below this is treated as a constant volume in `_zscore_volume`
# (output is all zeros instead of dividing by ~0).
_MIN_STD = 1e-6


def load(path: Path) -> Any:
    """Open an externally-trained ONNX MRI model for CPU inference.

    ``onnxruntime`` is imported lazily, only once the artifact is known to
    exist, so importing this module does not require it.

    Args:
        path: Location of the `.onnx` file (anything `Path()` accepts).

    Returns:
        An `onnxruntime.InferenceSession` bound to the CPU execution provider.

    Raises:
        FileNotFoundError: if nothing exists at `path`.
    """
    artifact = Path(path)
    if artifact.exists():
        import onnxruntime as ort

        return ort.InferenceSession(str(artifact), providers=["CPUExecutionProvider"])
    raise FileNotFoundError(f"MRI model artifact not found: {artifact}")


def load_nifti_volume(path: Path) -> np.ndarray:
    """Load the voxel data of a NIfTI file as a float32 NumPy array.

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    nifti_path = Path(path)
    if not nifti_path.exists():
        raise FileNotFoundError(f"MRI input not found: {nifti_path}")
    voxels = nib.load(str(nifti_path)).get_fdata(dtype=np.float32)
    return np.asarray(voxels, dtype=np.float32)


def preprocess_volume(
    volume: np.ndarray,
    target_shape: tuple[int, int, int] = DEFAULT_TARGET_SHAPE,
) -> np.ndarray:
    """Convert a 3-D MRI volume into model input `[1, 1, D, H, W]`.

    The external trainer must use the same contract: trilinear resize to
    `target_shape`, z-score over non-zero voxels when present, then add batch
    and channel dimensions.

    Args:
        volume: Raw intensity array; must be accepted by `is_valid_volume`
            (a finite numeric 3-D array).
        target_shape: Output spatial shape `(D, H, W)`. Every entry must be a
            positive integer; integral floats such as `64.0` are accepted.

    Returns:
        float32 array of shape `(1, 1, *target_shape)`.

    Raises:
        ValueError: if `volume` is rejected by `is_valid_volume`, or
            `target_shape` is not exactly three positive integers.
    """
    if not is_valid_volume(volume):
        raise ValueError("MRI volume must be a finite numeric 3-D array")
    # `int(x) != x` rejects fractional sizes such as 64.7: they pass the
    # positivity check, but zoom rounds its output shape, so the model would
    # receive a spatial size it was not trained on.
    if len(target_shape) != 3 or any(int(x) <= 0 or int(x) != x for x in target_shape):
        raise ValueError(f"target_shape must contain three positive integers: {target_shape}")
    dims = tuple(int(x) for x in target_shape)

    resized = _resize_volume(np.asarray(volume, dtype=np.float32), dims)
    normalized = _zscore_volume(resized)
    # Prepend batch and channel axes -> [1, 1, D, H, W].
    return normalized[np.newaxis, np.newaxis, :, :, :].astype(np.float32, copy=False)


def preprocess_nifti(
    input_path: Path,
    target_shape: tuple[int, int, int] = DEFAULT_TARGET_SHAPE,
) -> np.ndarray:
    """Load one NIfTI file and turn it into an ONNX-ready `[1, 1, D, H, W]` batch.

    This is `load_nifti_volume` followed by `preprocess_volume` with the given
    `target_shape`.
    """
    raw_volume = load_nifti_volume(input_path)
    return preprocess_volume(raw_volume, target_shape=target_shape)


def predict_with_proba(
    model: Any,
    model_input: np.ndarray,
    label_names: Sequence[str] | None = None,
) -> dict[str, object]:
    """Run an ONNX model and return label, confidence, and per-class probabilities.

    Args:
        model: Session object exposing `get_inputs()` and
            `run(output_names, feed)` (e.g. an `onnxruntime.InferenceSession`).
            Only the first declared input and the first returned output are used.
        model_input: Preprocessed batch of shape `[1, 1, D, H, W]` (see
            `preprocess_volume`). Array-likes are accepted and cast to float32.
        label_names: Class names, one per model output. `None` or an empty
            sequence falls back to `DEFAULT_LABEL_NAMES`. If the count does not
            match the model's output size, a warning is logged and generic
            `class_<i>` names are used instead.

    Returns:
        Dict with `label` (argmax index), `label_text`, `confidence` (probability
        of the argmax class) and `probabilities`, a list with one
        `{label, label_text, probability}` entry per class.

    Raises:
        ValueError: if `model_input` is not 5-D or holds more than one volume,
            or if the model output cannot be read as one class vector.
    """
    # Explicit None/empty test rather than `label_names or ...`: the truth
    # value of a NumPy array of names is ambiguous and would raise.
    if label_names is None or len(label_names) == 0:
        labels = tuple(DEFAULT_LABEL_NAMES)
    else:
        labels = tuple(label_names)

    batch = np.asarray(model_input, dtype=np.float32)
    if batch.ndim != 5:
        raise ValueError(f"model_input must have shape [1, 1, D, H, W], got {batch.shape}")
    # A batch > 1 either fails opaquely in `_as_probabilities`, or, for a
    # single-logit head ([B, 1] output), is squeezed to [B] and silently
    # misread as B class scores.
    if batch.shape[0] != 1:
        raise ValueError(
            f"model_input must hold exactly one volume (batch size 1), got {batch.shape}"
        )

    input_name = model.get_inputs()[0].name
    output = model.run(None, {input_name: batch})[0]
    proba = _as_probabilities(np.asarray(output, dtype=np.float32))
    if len(labels) != proba.shape[0]:
        logger.warning(
            "label_names length (%d) does not match model output dim (%d); "
            "overriding with class_0..class_N. Provided labels: %r",
            len(labels),
            proba.shape[0],
            list(labels),
        )
        labels = tuple(f"class_{i}" for i in range(proba.shape[0]))

    label_idx = int(np.argmax(proba))
    return {
        "label": label_idx,
        "label_text": labels[label_idx],
        "confidence": float(proba[label_idx]),
        "probabilities": [
            {"label": i, "label_text": labels[i], "probability": float(p)}
            for i, p in enumerate(proba)
        ],
    }


def predict_nifti(
    model: Any,
    input_path: Path,
    target_shape: tuple[int, int, int] = DEFAULT_TARGET_SHAPE,
    label_names: Sequence[str] | None = None,
) -> dict[str, object]:
    """Run end-to-end MRI inference on a single NIfTI file.

    The file is preprocessed with `preprocess_nifti` (using `target_shape`).
    The resulting batch goes to `predict_with_proba` with `label_names`, and
    that function's result dict is returned unchanged.
    """
    batch = preprocess_nifti(input_path, target_shape=target_shape)
    prediction = predict_with_proba(model, batch, label_names=label_names)
    return prediction


def _resize_volume(volume: np.ndarray, target_shape: tuple[int, int, int]) -> np.ndarray:
    zoom = tuple(t / s for t, s in zip(target_shape, volume.shape, strict=True))
    return scipy_ndimage.zoom(volume, zoom=zoom, order=1).astype(np.float32, copy=False)


def _zscore_volume(volume: np.ndarray) -> np.ndarray:
    """Standardize `volume` to zero mean and unit standard deviation.

    The mean and std come from the non-zero voxels when any exist, otherwise
    from every voxel. They are applied to the whole array, so zero voxels are
    shifted too. If the reference std is below `_MIN_STD`, an all-zero float32
    array of the same shape is returned instead of dividing by ~0.
    """
    foreground = volume[volume != 0]
    reference = foreground if foreground.size else volume.ravel()
    mu, sigma = float(reference.mean()), float(reference.std())
    if sigma < _MIN_STD:
        return np.zeros_like(volume, dtype=np.float32)
    return ((volume - mu) / sigma).astype(np.float32, copy=False)


def _as_probabilities(raw_output: np.ndarray) -> np.ndarray:
    """Turn one raw model output into a per-class probability vector.

    Singleton axes (e.g. the batch axis of a `[1, C]` output) are squeezed
    away. If the remaining vector already looks like a distribution (every
    entry in [0, 1], summing to 1 within 1e-4), it is returned unchanged.
    Otherwise it is treated as logits and passed through a max-shifted
    (numerically stable) softmax.

    Logits that happen to pass the distribution test are returned as-is.
    Softmax is monotonic, so the argmax is the same either way; only the
    reported confidences differ.

    Raises:
        ValueError: if the output has fewer than two scores, is not a single
            vector after squeezing, or contains NaN/inf values.
    """
    logits = np.squeeze(raw_output)
    # Check the size first. A single-score head ([1, 1] squeezes to 0-d) should
    # get the "two class scores" message, not the generic shape message.
    if logits.size < 2:
        raise ValueError("MRI model output must contain at least two class scores")
    if logits.ndim != 1:
        raise ValueError(f"MRI model output must be one class vector, got shape {raw_output.shape}")
    # NaN/inf would silently yield NaN probabilities and a meaningless argmax.
    if not np.all(np.isfinite(logits)):
        raise ValueError("MRI model output contains non-finite values")

    if np.all(logits >= 0.0) and np.all(logits <= 1.0) and np.isclose(logits.sum(), 1.0, atol=1e-4):
        return logits.astype(np.float32, copy=False)
    shifted = logits - np.max(logits)
    exp = np.exp(shifted)
    return (exp / exp.sum()).astype(np.float32, copy=False)