| """MRI image deep-learning inference utilities. |
| |
| This module is the decision-layer bridge for an externally-trained volumetric |
| MRI model. The training code can live outside this repo; production only needs |
| an ONNX artifact plus the preprocessing contract below. |
| """ |
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Any, Sequence |
|
|
| import nibabel as nib |
| import numpy as np |
| from scipy import ndimage as scipy_ndimage |
|
|
| from src.core.logger import get_logger |
| from src.pipelines.mri_pipeline import is_valid_volume |
|
|
| logger = get_logger(__name__) |
|
|
# Default ONNX artifact location. Not referenced elsewhere in this module, so
# presumably consumed by callers -- confirm before changing.
DEFAULT_MODEL_PATH = Path("data") / "processed" / "mri_model.onnx"

# Default grid that `preprocess_volume` resizes every input volume to.
DEFAULT_TARGET_SHAPE: tuple[int, int, int] = (64, 64, 64)

# Fallback class names used by `predict_with_proba` when none are supplied.
DEFAULT_LABEL_NAMES: tuple[str, ...] = ("class_0", "class_1")

# `_zscore_volume` treats a standard deviation below this as zero and returns
# an all-zero volume instead of dividing by it.
_MIN_STD = 1e-6
|
|
|
|
def load(path: Path) -> Any:
    """Open an ONNX MRI model artifact for CPU inference.

    Args:
        path: Location of the externally-trained `.onnx` file.

    Returns:
        An `onnxruntime.InferenceSession` created with only the
        `CPUExecutionProvider`.

    Raises:
        FileNotFoundError: if nothing exists at `path`.
    """
    artifact = Path(path)
    if artifact.exists():
        # Imported here rather than at module top, so onnxruntime is only
        # required once a model is actually loaded.
        import onnxruntime as ort

        return ort.InferenceSession(str(artifact), providers=["CPUExecutionProvider"])
    raise FileNotFoundError(f"MRI model artifact not found: {artifact}")
|
|
|
|
def load_nifti_volume(path: Path) -> np.ndarray:
    """Load a NIfTI file and return its voxel data as a float32 array.

    Args:
        path: Location of the NIfTI file on disk.

    Returns:
        The image data obtained from nibabel's `get_fdata`, as float32.

    Raises:
        FileNotFoundError: if nothing exists at `path`.
    """
    nifti_path = Path(path)
    if nifti_path.exists():
        image = nib.load(str(nifti_path))
        voxels = image.get_fdata(dtype=np.float32)
        return np.asarray(voxels, dtype=np.float32)
    raise FileNotFoundError(f"MRI input not found: {nifti_path}")
|
|
|
|
def preprocess_volume(
    volume: np.ndarray,
    target_shape: tuple[int, int, int] = DEFAULT_TARGET_SHAPE,
) -> np.ndarray:
    """Turn a 3-D MRI volume into a model input of shape `[1, 1, D, H, W]`.

    This is the preprocessing contract the external trainer has to mirror:
    resize to `target_shape` with linear interpolation, z-score using the
    non-zero voxels when there are any, then prepend batch and channel axes.

    Args:
        volume: Input volume; must pass `is_valid_volume`.
        target_shape: Three positive sizes for the resized grid.

    Returns:
        A float32 array with two leading singleton axes added.

    Raises:
        ValueError: if `is_valid_volume` rejects `volume`, or `target_shape`
            does not hold exactly three positive values.
    """
    if not is_valid_volume(volume):
        raise ValueError("MRI volume must be a finite numeric 3-D array")

    # The length test comes first so per-entry conversion only runs on a
    # three-element shape.
    shape_ok = len(target_shape) == 3 and all(int(dim) > 0 for dim in target_shape)
    if not shape_ok:
        raise ValueError(f"target_shape must contain three positive integers: {target_shape}")

    as_float = np.asarray(volume, dtype=np.float32)
    standardized = _zscore_volume(_resize_volume(as_float, target_shape))
    batched = standardized[np.newaxis, np.newaxis, :, :, :]
    return batched.astype(np.float32, copy=False)
|
|
|
|
def preprocess_nifti(
    input_path: Path,
    target_shape: tuple[int, int, int] = DEFAULT_TARGET_SHAPE,
) -> np.ndarray:
    """Load one NIfTI file and run it through `preprocess_volume`.

    Args:
        input_path: Location of the NIfTI file.
        target_shape: Grid size forwarded to `preprocess_volume`.

    Returns:
        The preprocessed array produced by `preprocess_volume`.
    """
    raw_volume = load_nifti_volume(input_path)
    return preprocess_volume(raw_volume, target_shape=target_shape)
|
|
|
|
def predict_with_proba(
    model: Any,
    model_input: np.ndarray,
    label_names: Sequence[str] | None = None,
) -> dict[str, object]:
    """Run the model on one preprocessed input and summarise the class scores.

    Args:
        model: Object exposing `get_inputs()` and `run(...)`, e.g. the session
            returned by `load`.
        model_input: 5-D array; it is cast to float32 before being fed to the
            model's first input.
        label_names: Display name per class. `DEFAULT_LABEL_NAMES` is used
            when this is omitted or empty.

    Returns:
        A dict with `"label"` (index of the highest probability),
        `"label_text"`, `"confidence"` (that highest probability) and
        `"probabilities"` (one dict per class).

    Raises:
        ValueError: if `model_input` is not 5-dimensional, or if
            `_as_probabilities` rejects the model output.
    """
    class_names = tuple(label_names or DEFAULT_LABEL_NAMES)
    if model_input.ndim != 5:
        raise ValueError(f"model_input must have shape [1, 1, D, H, W], got {model_input.shape}")

    # Only the first declared input and the first returned output are used.
    feed = {model.get_inputs()[0].name: model_input.astype(np.float32, copy=False)}
    first_output = model.run(None, feed)[0]
    scores = _as_probabilities(np.asarray(first_output, dtype=np.float32))

    n_classes = scores.shape[0]
    if len(class_names) != n_classes:
        # The supplied names cannot be aligned with the scores, so generic
        # names are substituted instead of failing the prediction.
        logger.warning(
            "label_names length (%d) does not match model output dim (%d); "
            "overriding with class_0..class_N. Provided labels: %r",
            len(class_names),
            n_classes,
            list(class_names),
        )
        class_names = tuple(f"class_{idx}" for idx in range(n_classes))

    per_class = [
        {"label": idx, "label_text": name, "probability": float(score)}
        for idx, (name, score) in enumerate(zip(class_names, scores))
    ]
    winner = int(np.argmax(scores))
    return {
        "label": winner,
        "label_text": class_names[winner],
        "confidence": float(scores[winner]),
        "probabilities": per_class,
    }
|
|
|
|
def predict_nifti(
    model: Any,
    input_path: Path,
    target_shape: tuple[int, int, int] = DEFAULT_TARGET_SHAPE,
    label_names: Sequence[str] | None = None,
) -> dict[str, object]:
    """Run MRI inference end to end for a single NIfTI file.

    Args:
        model: Model object handed on to `predict_with_proba`.
        input_path: Location of the NIfTI file.
        target_shape: Grid size forwarded to `preprocess_nifti`.
        label_names: Optional class names forwarded to `predict_with_proba`.

    Returns:
        The prediction dict produced by `predict_with_proba`.
    """
    return predict_with_proba(
        model,
        preprocess_nifti(input_path, target_shape=target_shape),
        label_names=label_names,
    )
|
|
|
|
| def _resize_volume(volume: np.ndarray, target_shape: tuple[int, int, int]) -> np.ndarray: |
| zoom = tuple(t / s for t, s in zip(target_shape, volume.shape, strict=True)) |
| return scipy_ndimage.zoom(volume, zoom=zoom, order=1).astype(np.float32, copy=False) |
|
|
|
|
def _zscore_volume(volume: np.ndarray) -> np.ndarray:
    """Standardise `volume` with the mean and std of its non-zero voxels.

    When every voxel is zero the statistics are taken over the whole volume.
    If the standard deviation is below `_MIN_STD`, an all-zero float32 array
    of the same shape is returned instead of dividing by it.
    """
    nonzero = volume != 0
    if np.any(nonzero):
        reference = volume[nonzero]
    else:
        reference = volume.reshape(-1)

    center = float(reference.mean())
    spread = float(reference.std())
    if spread < _MIN_STD:
        # Effectively constant reference voxels: skip the division.
        return np.zeros_like(volume, dtype=np.float32)

    standardized = (volume - center) / spread
    return standardized.astype(np.float32, copy=False)
|
|
|
|
def _as_probabilities(raw_output: np.ndarray) -> np.ndarray:
    """Convert one raw model output into a 1-D float32 probability vector.

    Singleton axes are squeezed away first. If the remaining scores already
    look like probabilities (all within [0, 1] and summing to 1 within 1e-4)
    they are returned unchanged; otherwise a softmax is applied.

    Args:
        raw_output: Model output holding a single class-score vector, with
            any number of singleton axes around it.

    Returns:
        A float32 vector with one probability per class.

    Raises:
        ValueError: if the squeezed output is not 1-D, has fewer than two
            entries, or its largest score is NaN or infinite.
    """
    scores = np.squeeze(raw_output)
    if scores.ndim != 1:
        raise ValueError(f"MRI model output must be one class vector, got shape {raw_output.shape}")
    if scores.size < 2:
        raise ValueError("MRI model output must contain at least two class scores")

    # np.max propagates NaN, so this one check rejects NaN anywhere, +inf, and
    # an all -inf vector: exactly the inputs for which the softmax below would
    # return NaN probabilities. A -inf entry next to a finite maximum is still
    # accepted and maps to probability 0.
    peak = np.max(scores)
    if not np.isfinite(peak):
        raise ValueError("MRI model output must have a finite maximum class score (got NaN or infinity)")

    already_normalized = (
        np.all(scores >= 0.0)
        and np.all(scores <= 1.0)
        and np.isclose(scores.sum(), 1.0, atol=1e-4)
    )
    if already_normalized:
        return scores.astype(np.float32, copy=False)

    # Subtracting the maximum keeps the exponentials from overflowing.
    exp = np.exp(scores - peak)
    return (exp / exp.sum()).astype(np.float32, copy=False)
|
|