File size: 5,419 Bytes
327b23d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | """BBB Permeability Map / Score from MRI input.
Two modes:
- `heuristic_proxy` (default, demo-ready): reuses the 2D resnet18 4-class
Alzheimer's classifier output. Score = 1 - P(NonDemented). Anchored in the
published correlation between disease severity and BBB breakdown.
- `dce_onnx` (real-DCE artifact, future): loads an ONNX model that takes
4D DCE-MRI inputs and emits per-voxel Ktrans maps. Stub for now; works
once the artifact is dropped in.
Researcher-persona module. Does NOT feed into the fusion engine — fusion's
'BBB is NOT a modality' rule is preserved. This module is the only legitimate
place where BBB and MRI couple in this platform.
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Literal
import numpy as np
from src.core.logger import get_logger
# Module-level logger, obtained from the project's logging helper.
logger = get_logger(__name__)
# Closed set of scoring-mode names accepted by compute_permeability.
PermeabilityMode = Literal["heuristic_proxy", "dce_onnx"]
# Mode used by compute_permeability when the caller does not choose one.
DEFAULT_MODE: PermeabilityMode = "heuristic_proxy"
def _interpret(score: float) -> str:
    """Translate a permeability score into a human-readable leakage band.

    Bands are checked in ascending order with a strict ``<`` against each
    upper bound: below 0.2, below 0.4, below 0.7. Any score that is not
    below one of those bounds (including NaN, for which every comparison
    is false) gets the final "severe leakage" label.
    """
    bands = (
        (0.2, "BBB intact"),
        (0.4, "mild leakage"),
        (0.7, "moderate leakage"),
    )
    # First band whose upper bound exceeds the score wins.
    return next(
        (label for upper_bound, label in bands if score < upper_bound),
        "severe leakage",
    )
def compute_from_classifier_probs(probabilities: list[dict[str, Any]]) -> float:
    """Return the heuristic permeability score ``1 - P(NonDemented)``.

    Accepts the standard prediction-dict shape used across this repo
    (`probabilities=[{"label_text", "probability"}, ...]`).

    The NonDemented entry is located by comparing labels on their
    alphanumeric characters only, case-insensitively, so spellings such as
    ``NonDemented``, ``Non_Demented``, ``Non Demented`` and ``non-demented``
    all match. The first matching entry wins.

    Args:
        probabilities: Per-class prediction entries. Each entry is read with
            ``.get("label_text")`` and ``.get("probability")``.

    Returns:
        ``1 - P(NonDemented)`` clamped to ``[0.0, 1.0]``. When no entry
        carries a NonDemented label, the probability is taken as 0.0 and the
        returned score is therefore 1.0.
    """
    p_nondemented = 0.0
    for entry in probabilities:
        raw_label = str(entry.get("label_text", ""))
        # Drop separators/whitespace so label-style differences between
        # datasets do not silently turn every score into 1.0.
        canonical = "".join(ch for ch in raw_label.lower() if ch.isalnum())
        if canonical == "nondemented":
            p_nondemented = float(entry.get("probability", 0.0))
            break
    # Clamp in case the upstream probability is slightly outside [0, 1].
    score = max(0.0, min(1.0, 1.0 - p_nondemented))
    return float(score)
def heuristic_proxy_score(input_path: Path, checkpoint_path: Path) -> dict[str, Any]:
    """Derive a permeability score from the 2D classifier's prediction.

    Loads the checkpoint with ``mri_dl_2d.load``, runs
    ``mri_dl_2d.predict_image`` on the input, and converts the returned
    ``"probabilities"`` entry into a score via
    ``compute_from_classifier_probs``.

    Returns:
        A payload with the score, its textual interpretation, the method
        name, a flag saying no voxel map is available, and the class
        probabilities the score was derived from.
    """
    # Function-scope import: the classifier module is only loaded when this
    # mode actually runs.
    from src.models import mri_dl_2d

    classifier = mri_dl_2d.load(checkpoint_path)
    prediction = mri_dl_2d.predict_image(classifier, input_path)
    class_probs = prediction["probabilities"]
    proxy = compute_from_classifier_probs(class_probs)

    payload: dict[str, Any] = {
        "permeability_score": proxy,
        "interpretation": _interpret(proxy),
        "method": "heuristic_proxy",
        "voxel_map_available": False,
        "source_class_probabilities": class_probs,
    }
    return payload
def dce_onnx_score(input_path: Path, checkpoint_path: Path) -> dict[str, Any]:
    """Run a DCE-MRI ONNX model and reduce its Ktrans map to a scalar score.

    Contract for the future artifact:
    - Input: 4D NIfTI `(X, Y, Z, T)`. The volume is fed to the model's first
      input as float32 with a leading axis added, i.e. `(1, X, Y, Z, T)`.
    - Output: the model's first output is treated as the Ktrans map
      (mL/min/100g). It is normalised to `[0, 1]` by dividing by the
      clinically-conservative cap `_DCE_KTRANS_CAP` and clipping.
    - Score: the mean of the normalised map over every non-NaN voxel. No
      brain mask is applied here; NaN voxels are excluded so that a single
      undefined voxel cannot turn the whole score into NaN.

    No real artifact ships with the repo — this raises a clear error
    explaining the contract until one lands at `data/processed/bbb_permeability_dce.onnx`
    (override via `BBB_PERMEABILITY_DCE_PATH`).

    Args:
        input_path: Path to the 4D NIfTI volume.
        checkpoint_path: Path to the ONNX artifact.

    Returns:
        A payload with the score, its textual interpretation, the method
        name, a voxel-map-available flag, and the normalised map's shape.

    Raises:
        FileNotFoundError: If the ONNX artifact does not exist.
        ValueError: If the input volume is not 4D, or if the model output
            contains no non-NaN voxel to average.
    """
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(
            f"DCE-MRI BBB permeability artifact not found at {checkpoint_path}. "
            "Train and export an ONNX model that consumes a 4D DCE volume "
            "(X, Y, Z, T) and emits a 3D Ktrans map; drop it at this path or "
            "set BBB_PERMEABILITY_DCE_PATH."
        )

    # Heavy optional dependencies are imported only once the artifact exists.
    import nibabel as nib
    import onnxruntime as ort

    img = nib.load(str(input_path))
    arr = np.asarray(img.get_fdata(dtype=np.float32), dtype=np.float32)
    if arr.ndim != 4:
        raise ValueError(
            f"DCE-MRI mode expects a 4D NIfTI (X,Y,Z,T); got shape {arr.shape}."
        )

    session = ort.InferenceSession(str(checkpoint_path), providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    ktrans = session.run(None, {input_name: arr[np.newaxis, ...]})[0]

    # Normalise to [0, 1]; clinically-conservative cap of 0.5 mL/min/100g.
    _DCE_KTRANS_CAP = 0.5
    normalised = np.clip(ktrans / _DCE_KTRANS_CAP, 0.0, 1.0).astype(np.float32)

    # np.clip maps +/-inf onto the bounds but lets NaN through; averaging a
    # map that contains NaN would give a NaN score, which the threshold
    # ladder then reports as the most severe band.
    defined = normalised[~np.isnan(normalised)]
    if defined.size == 0:
        raise ValueError(
            "DCE-MRI model output contains no non-NaN Ktrans voxels; "
            "cannot compute a permeability score."
        )
    score = float(np.mean(defined))

    return {
        "permeability_score": score,
        "interpretation": _interpret(score),
        "method": "dce_onnx",
        "voxel_map_available": True,
        "voxel_map_shape": list(normalised.shape),
    }
def compute_permeability(
    input_path: Path,
    mode: PermeabilityMode = DEFAULT_MODE,
    checkpoint_path: Path | None = None,
) -> dict[str, Any]:
    """Validate the input path, pick the scorer for ``mode``, and run it.

    When ``checkpoint_path`` is not supplied, the checkpoint location comes
    from the mode's environment variable (``MRI_MODEL_PATH_2D`` or
    ``BBB_PERMEABILITY_DCE_PATH``), falling back to the mode's default path
    under ``data/processed``.

    Raises:
        FileNotFoundError: If ``input_path`` does not exist.
        ValueError: If ``mode`` is not one of the supported mode names.
    """
    source = Path(input_path)
    if not source.exists():
        raise FileNotFoundError(f"MRI input not found: {source}")

    # (mode name, env var overriding the checkpoint, default checkpoint, scorer)
    routes = (
        (
            "heuristic_proxy",
            "MRI_MODEL_PATH_2D",
            "data/processed/mri_dl_2d/best_model.pt",
            heuristic_proxy_score,
        ),
        (
            "dce_onnx",
            "BBB_PERMEABILITY_DCE_PATH",
            "data/processed/bbb_permeability_dce.onnx",
            dce_onnx_score,
        ),
    )
    for mode_name, env_var, fallback, scorer in routes:
        if mode == mode_name:
            # The environment is only consulted when no checkpoint was given.
            resolved = checkpoint_path or Path(os.environ.get(env_var, fallback))
            return scorer(source, resolved)

    raise ValueError(
        f"unknown BBB permeability mode={mode!r}; expected one of "
        "('heuristic_proxy', 'dce_onnx')"
    )
|