hackathon / src /models /mri_model.py
mekosotto's picture
fix(mri/model): warn when label_names length != model output dim (was silent override)
9ae5b40
"""MRI image deep-learning inference utilities.
This module is the decision-layer bridge for an externally-trained volumetric
MRI model. The training code can live outside this repo; production only needs
an ONNX artifact plus the preprocessing contract below.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any, Sequence
import nibabel as nib
import numpy as np
from scipy import ndimage as scipy_ndimage
from src.core.logger import get_logger
from src.pipelines.mri_pipeline import is_valid_volume
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
# Default location of the ONNX artifact; a relative path, so it resolves
# against the process working directory.
DEFAULT_MODEL_PATH = Path("data/processed/mri_model.onnx")
# Default (D, H, W) size volumes are resized to before inference.
DEFAULT_TARGET_SHAPE: tuple[int, int, int] = (64, 64, 64)
# Fallback class names used when the caller supplies no label_names.
DEFAULT_LABEL_NAMES: tuple[str, ...] = ("class_0", "class_1")
# Standard deviations below this threshold are treated as zero by
# _zscore_volume, which then returns an all-zero volume instead of dividing.
_MIN_STD = 1e-6
def load(path: Path) -> Any:
    """Open an ONNX MRI model artifact as a CPU inference session.

    Args:
        path: Location of an externally-trained `.onnx` artifact.

    Returns:
        An `onnxruntime.InferenceSession` restricted to the CPU provider.

    Raises:
        FileNotFoundError: if nothing exists at `path`.
    """
    artifact = Path(path)
    if artifact.exists():
        # Imported lazily so the module can be imported (and the missing-file
        # check exercised) without onnxruntime being installed.
        import onnxruntime as ort

        return ort.InferenceSession(
            str(artifact),
            providers=["CPUExecutionProvider"],
        )
    raise FileNotFoundError(f"MRI model artifact not found: {artifact}")
def load_nifti_volume(path: Path) -> np.ndarray:
    """Load the voxel data of a NIfTI file as a float32 array.

    Args:
        path: Location of the NIfTI image on disk.

    Returns:
        The image data as a float32 `numpy.ndarray`.

    Raises:
        FileNotFoundError: if nothing exists at `path`.
    """
    source = Path(path)
    if source.exists():
        image = nib.load(str(source))
        voxels = image.get_fdata(dtype=np.float32)
        return np.asarray(voxels, dtype=np.float32)
    raise FileNotFoundError(f"MRI input not found: {source}")
def preprocess_volume(
    volume: np.ndarray,
    target_shape: tuple[int, int, int] = DEFAULT_TARGET_SHAPE,
) -> np.ndarray:
    """Turn a 3-D MRI volume into a model input of shape `[1, 1, D, H, W]`.

    This is the preprocessing contract the external trainer must mirror:
    linear resize to `target_shape`, z-score using the non-zero voxels when
    any exist, then prepend batch and channel axes.

    Args:
        volume: The raw 3-D volume.
        target_shape: Desired `(D, H, W)` after resizing.

    Returns:
        A float32 array with leading batch and channel axes of size one.

    Raises:
        ValueError: if `volume` fails `is_valid_volume`, or `target_shape`
            is not three positive integers.
    """
    if not is_valid_volume(volume):
        raise ValueError("MRI volume must be a finite numeric 3-D array")

    shape_is_valid = len(target_shape) == 3 and all(int(dim) > 0 for dim in target_shape)
    if not shape_is_valid:
        raise ValueError(f"target_shape must contain three positive integers: {target_shape}")

    as_float = np.asarray(volume, dtype=np.float32)
    standardized = _zscore_volume(_resize_volume(as_float, target_shape))
    # None is np.newaxis: add the batch and channel dimensions up front.
    batched = standardized[None, None, :, :, :]
    return batched.astype(np.float32, copy=False)
def preprocess_nifti(
    input_path: Path,
    target_shape: tuple[int, int, int] = DEFAULT_TARGET_SHAPE,
) -> np.ndarray:
    """Load a single NIfTI file and prepare it for ONNX inference.

    Args:
        input_path: Location of the NIfTI image on disk.
        target_shape: Desired `(D, H, W)` after resizing.

    Returns:
        The result of `preprocess_volume` applied to the loaded volume.
    """
    raw_volume = load_nifti_volume(input_path)
    return preprocess_volume(raw_volume, target_shape=target_shape)
def predict_with_proba(
    model: Any,
    model_input: np.ndarray,
    label_names: Sequence[str] | None = None,
) -> dict[str, object]:
    """Run an ONNX model and return label, confidence, and per-class probabilities.

    Args:
        model: Object exposing `get_inputs()` and `run()` in the style of an
            `onnxruntime.InferenceSession`.
        model_input: 5-D array, expected as `[1, 1, D, H, W]`; only the number
            of dimensions is checked here.
        label_names: Optional class names. When omitted or empty,
            `DEFAULT_LABEL_NAMES` is used. Whenever the resulting names do not
            match the model output size they are replaced by generated
            `class_<i>` names.

    Returns:
        A dict with `label` (argmax index), `label_text`, `confidence`
        (probability of the argmax class) and `probabilities` (one entry per
        class with `label`, `label_text` and `probability`).

    Raises:
        ValueError: if `model_input` is not 5-D, or the model output cannot be
            interpreted as a single class vector.
    """
    # Explicit None test rather than truthiness so array-like label
    # containers do not raise on ambiguous boolean evaluation.
    caller_labels = tuple(label_names) if label_names is not None else ()
    labels = caller_labels or tuple(DEFAULT_LABEL_NAMES)

    if model_input.ndim != 5:
        raise ValueError(f"model_input must have shape [1, 1, D, H, W], got {model_input.shape}")

    input_name = model.get_inputs()[0].name
    output = model.run(None, {input_name: model_input.astype(np.float32, copy=False)})[0]
    proba = _as_probabilities(np.asarray(output, dtype=np.float32))
    n_classes = proba.shape[0]

    if len(labels) != n_classes:
        # Only warn when the caller actually supplied labels: with the
        # built-in defaults there is nothing "provided" to report, and the
        # generated names follow the same class_<i> scheme anyway.
        if caller_labels:
            logger.warning(
                "label_names length (%d) does not match model output dim (%d); "
                "overriding with class_0..class_N. Provided labels: %r",
                len(labels),
                n_classes,
                list(labels),
            )
        labels = tuple(f"class_{i}" for i in range(n_classes))

    label_idx = int(np.argmax(proba))
    return {
        "label": label_idx,
        "label_text": labels[label_idx],
        "confidence": float(proba[label_idx]),
        "probabilities": [
            {"label": i, "label_text": labels[i], "probability": float(p)}
            for i, p in enumerate(proba)
        ],
    }
def predict_nifti(
    model: Any,
    input_path: Path,
    target_shape: tuple[int, int, int] = DEFAULT_TARGET_SHAPE,
    label_names: Sequence[str] | None = None,
) -> dict[str, object]:
    """Run MRI model inference end-to-end on one NIfTI image.

    Args:
        model: Inference session passed through to `predict_with_proba`.
        input_path: Location of the NIfTI image on disk.
        target_shape: Desired `(D, H, W)` after resizing.
        label_names: Optional class names forwarded to `predict_with_proba`.

    Returns:
        The prediction dict produced by `predict_with_proba`.
    """
    return predict_with_proba(
        model,
        preprocess_nifti(input_path, target_shape=target_shape),
        label_names=label_names,
    )
def _resize_volume(volume: np.ndarray, target_shape: tuple[int, int, int]) -> np.ndarray:
zoom = tuple(t / s for t, s in zip(target_shape, volume.shape, strict=True))
return scipy_ndimage.zoom(volume, zoom=zoom, order=1).astype(np.float32, copy=False)
def _zscore_volume(volume: np.ndarray) -> np.ndarray:
    """Z-score `volume` using statistics of its non-zero voxels.

    The mean and standard deviation come from the non-zero voxels when any
    exist, otherwise from every voxel. If the standard deviation is below
    `_MIN_STD`, an all-zero float32 volume is returned instead of dividing.
    """
    nonzero = volume != 0
    if np.any(nonzero):
        reference = volume[nonzero]
    else:
        reference = volume.reshape(-1)

    center = float(reference.mean())
    spread = float(reference.std())

    # Near-constant reference: avoid dividing by (almost) zero.
    if spread < _MIN_STD:
        return np.zeros_like(volume, dtype=np.float32)

    centered = volume - center
    return (centered / spread).astype(np.float32, copy=False)
def _as_probabilities(raw_output: np.ndarray) -> np.ndarray:
    """Convert raw model output into a 1-D float32 probability vector.

    Size-one axes are squeezed away. Output that already looks like a
    probability distribution (all values in [0, 1], summing to 1 within 1e-4)
    is returned as-is; anything else is treated as logits and passed through
    a numerically stable softmax.

    Args:
        raw_output: The model's first output tensor.

    Returns:
        A 1-D float32 array of class probabilities.

    Raises:
        ValueError: if the squeezed output is not 1-D, has fewer than two
            entries, or contains NaN / infinite values.
    """
    scores = np.squeeze(raw_output)
    if scores.ndim != 1:
        raise ValueError(f"MRI model output must be one class vector, got shape {raw_output.shape}")
    if scores.size < 2:
        raise ValueError("MRI model output must contain at least two class scores")
    # NaN/inf would silently propagate through the softmax (inf - inf is NaN)
    # and yield a meaningless argmax and confidence downstream.
    if not np.all(np.isfinite(scores)):
        raise ValueError("MRI model output must contain only finite class scores")

    in_unit_interval = np.all(scores >= 0.0) and np.all(scores <= 1.0)
    if in_unit_interval and np.isclose(scores.sum(), 1.0, atol=1e-4):
        return scores.astype(np.float32, copy=False)

    # Subtract the max before exponentiating so large logits cannot overflow.
    exp = np.exp(scores - np.max(scores))
    return (exp / exp.sum()).astype(np.float32, copy=False)