File size: 5,677 Bytes
c0a7163 9ae5b40 c0a7163 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 | """MRI image deep-learning inference utilities.
This module is the decision-layer bridge for an externally-trained volumetric
MRI model. The training code can live outside this repo; production only needs
an ONNX artifact plus the preprocessing contract below.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any, Sequence
import nibabel as nib
import numpy as np
from scipy import ndimage as scipy_ndimage
from src.core.logger import get_logger
from src.pipelines.mri_pipeline import is_valid_volume
logger = get_logger(__name__)

# Default (relative, cwd-resolved) location of the exported ONNX artifact.
DEFAULT_MODEL_PATH = Path("data/processed/mri_model.onnx")
# Spatial (D, H, W) size every volume is resampled to before inference.
DEFAULT_TARGET_SHAPE: tuple[int, int, int] = (64,) * 3
# Fallback class names used when callers do not supply ``label_names``.
DEFAULT_LABEL_NAMES: tuple[str, ...] = tuple(f"class_{i}" for i in range(2))
# Reference-region standard deviations below this are treated as degenerate
# during z-scoring (the volume is zeroed instead of dividing by ~0).
_MIN_STD = 1e-6
def load(path: Path, *, providers: Sequence[str] | None = None) -> Any:
    """Load an ONNX MRI model artifact.

    Args:
        path: Path to an externally-trained ``.onnx`` artifact.
        providers: onnxruntime execution providers in priority order.
            ``None`` (the default) means ``["CPUExecutionProvider"]``.

    Returns:
        An ``onnxruntime.InferenceSession``.

    Raises:
        FileNotFoundError: if the artifact does not exist or is not a regular file.
    """
    path = Path(path)
    # is_file() rather than exists(): a directory would otherwise pass the check
    # and fail later inside onnxruntime with a far less helpful error.
    if not path.is_file():
        raise FileNotFoundError(f"MRI model artifact not found: {path}")
    # Imported lazily: the module does not need onnxruntime at import time.
    import onnxruntime as ort

    chosen_providers = (
        list(providers) if providers is not None else ["CPUExecutionProvider"]
    )
    return ort.InferenceSession(str(path), providers=chosen_providers)
def load_nifti_volume(path: Path) -> np.ndarray:
    """Load the NIfTI image at *path* and return its voxel data as float32.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    nifti_path = Path(path)
    if nifti_path.exists():
        image = nib.load(str(nifti_path))
        voxels = image.get_fdata(dtype=np.float32)
        return np.asarray(voxels, dtype=np.float32)
    raise FileNotFoundError(f"MRI input not found: {nifti_path}")
def preprocess_volume(
    volume: np.ndarray,
    target_shape: tuple[int, int, int] = DEFAULT_TARGET_SHAPE,
) -> np.ndarray:
    """Turn a 3-D MRI volume into a model-ready ``[1, 1, D, H, W]`` float32 array.

    Preprocessing contract (the external trainer must apply identical steps):
    trilinear resize to ``target_shape``, z-score normalisation using the
    statistics of the non-zero voxels when any exist, then prepend batch and
    channel axes.

    Raises:
        ValueError: if ``volume`` fails ``is_valid_volume`` or ``target_shape``
            is not three positive integers.
    """
    if not is_valid_volume(volume):
        raise ValueError("MRI volume must be a finite numeric 3-D array")
    shape_ok = len(target_shape) == 3 and all(int(dim) > 0 for dim in target_shape)
    if not shape_ok:
        raise ValueError(f"target_shape must contain three positive integers: {target_shape}")
    as_float = np.asarray(volume, dtype=np.float32)
    normalized = _zscore_volume(_resize_volume(as_float, target_shape))
    # Leading None, None == np.newaxis twice: adds the batch and channel axes.
    batched = normalized[None, None]
    return batched.astype(np.float32, copy=False)
def preprocess_nifti(
    input_path: Path,
    target_shape: tuple[int, int, int] = DEFAULT_TARGET_SHAPE,
) -> np.ndarray:
    """Load one NIfTI file and apply the standard MRI preprocessing to it.

    Returns the ``[1, 1, D, H, W]`` float32 array built by ``preprocess_volume``.
    """
    raw_volume = load_nifti_volume(input_path)
    return preprocess_volume(raw_volume, target_shape=target_shape)
def predict_with_proba(
    model: Any,
    model_input: np.ndarray,
    label_names: Sequence[str] | None = None,
) -> dict[str, object]:
    """Run an ONNX model and return label, confidence, and per-class probabilities.

    Args:
        model: Session object exposing ``get_inputs()`` and ``run()`` (for
            example the ``InferenceSession`` returned by ``load``).
        model_input: Rank-5 array, expected ``[1, 1, D, H, W]``. Any array-like
            is accepted and cast to float32.
        label_names: Optional class names. ``None`` or an empty sequence selects
            ``DEFAULT_LABEL_NAMES``. If the count does not match the model's
            output dimension, a warning is logged and generic ``class_<i>``
            names are used instead.

    Returns:
        A dict with ``label`` (argmax index), ``label_text``, ``confidence``
        (probability of the argmax class) and ``probabilities`` (one entry
        per class with ``label``, ``label_text`` and ``probability``).

    Raises:
        ValueError: if ``model_input`` is not rank 5, or the model output cannot
            be converted into a class-probability vector.
    """
    # Explicit None/empty check instead of ``label_names or ...``: truth-testing
    # a NumPy array of names raises "truth value ... is ambiguous".
    labels = tuple(label_names) if label_names is not None else ()
    if not labels:
        labels = DEFAULT_LABEL_NAMES
    batch = np.asarray(model_input, dtype=np.float32)
    if batch.ndim != 5:
        raise ValueError(f"model_input must have shape [1, 1, D, H, W], got {batch.shape}")
    input_name = model.get_inputs()[0].name
    output = model.run(None, {input_name: batch})[0]
    proba = _as_probabilities(np.asarray(output, dtype=np.float32))
    if len(labels) != proba.shape[0]:
        logger.warning(
            "label_names length (%d) does not match model output dim (%d); "
            "overriding with class_0..class_N. Provided labels: %r",
            len(labels),
            proba.shape[0],
            list(labels),
        )
        labels = tuple(f"class_{i}" for i in range(proba.shape[0]))
    label_idx = int(np.argmax(proba))
    return {
        "label": label_idx,
        "label_text": labels[label_idx],
        "confidence": float(proba[label_idx]),
        "probabilities": [
            {"label": i, "label_text": labels[i], "probability": float(p)}
            for i, p in enumerate(proba)
        ],
    }
def predict_nifti(
    model: Any,
    input_path: Path,
    target_shape: tuple[int, int, int] = DEFAULT_TARGET_SHAPE,
    label_names: Sequence[str] | None = None,
) -> dict[str, object]:
    """Run end-to-end MRI inference on a single NIfTI file.

    The file is preprocessed with ``preprocess_nifti`` and the resulting batch
    is scored by ``predict_with_proba``, whose result dict is returned as-is.
    """
    batch = preprocess_nifti(input_path, target_shape=target_shape)
    result = predict_with_proba(model, batch, label_names=label_names)
    return result
def _resize_volume(volume: np.ndarray, target_shape: tuple[int, int, int]) -> np.ndarray:
zoom = tuple(t / s for t, s in zip(target_shape, volume.shape, strict=True))
return scipy_ndimage.zoom(volume, zoom=zoom, order=1).astype(np.float32, copy=False)
def _zscore_volume(volume: np.ndarray) -> np.ndarray:
    """Standardise *volume* using the mean and std of its non-zero voxels.

    If no voxel is non-zero, the statistics come from every voxel instead.
    The whole volume (zeros included) is shifted and scaled by those
    statistics. When the reference std is below ``_MIN_STD``, an all-zero
    float32 array of the same shape is returned instead of dividing by ~0.
    """
    foreground = volume[volume != 0]
    reference = foreground if foreground.size > 0 else volume.reshape(-1)
    centre = float(reference.mean())
    spread = float(reference.std())
    # Keep the "< _MIN_STD" form: a NaN spread must fall through to the division.
    if spread < _MIN_STD:
        return np.zeros_like(volume, dtype=np.float32)
    standardized = (volume - centre) / spread
    return standardized.astype(np.float32, copy=False)
def _as_probabilities(raw_output: np.ndarray) -> np.ndarray:
logits = np.squeeze(raw_output)
if logits.ndim != 1:
raise ValueError(f"MRI model output must be one class vector, got shape {raw_output.shape}")
if logits.size < 2:
raise ValueError("MRI model output must contain at least two class scores")
if np.all(logits >= 0.0) and np.all(logits <= 1.0) and np.isclose(logits.sum(), 1.0, atol=1e-4):
return logits.astype(np.float32, copy=False)
shifted = logits - np.max(logits)
exp = np.exp(shifted)
return (exp / exp.sum()).astype(np.float32, copy=False)
|