| """BBB Permeability Map / Score from MRI input. |
| |
| Two modes: |
| - `heuristic_proxy` (default, demo-ready): reuses the 2D resnet18 4-class |
| Alzheimer's classifier output. Score = 1 - P(NonDemented). Anchored in the |
| published correlation between disease severity and BBB breakdown. |
- `dce_onnx` (real-DCE artifact, future): loads an ONNX model that takes
  4D DCE-MRI volumes and emits per-voxel Ktrans maps. Stub for now; works
  once the artifact is dropped in.
| |
Researcher-persona module. Does NOT feed into the fusion engine — fusion's
'BBB is NOT a modality' rule is preserved. This is the only legitimate place
where BBB and MRI are coupled in this platform.
| """ |
| from __future__ import annotations |
|
|
| import os |
| from pathlib import Path |
| from typing import Any, Literal |
|
|
| import numpy as np |
|
|
| from src.core.logger import get_logger |
|
|
# Module-level logger obtained from src.core.logger.get_logger.
# NOTE(review): not referenced by any code visible in this module — confirm it is needed.
logger = get_logger(__name__)


# Closed set of mode names that compute_permeability() dispatches on.
PermeabilityMode = Literal["heuristic_proxy", "dce_onnx"]
# Mode used when compute_permeability() is called without an explicit `mode`.
DEFAULT_MODE: PermeabilityMode = "heuristic_proxy"
|
|
|
|
| def _interpret(score: float) -> str: |
| if score < 0.2: |
| return "BBB intact" |
| if score < 0.4: |
| return "mild leakage" |
| if score < 0.7: |
| return "moderate leakage" |
| return "severe leakage" |
|
|
|
|
def compute_from_classifier_probs(probabilities: list[dict[str, Any]]) -> float:
    """Heuristic: 1 - P(NonDemented) from a 4-class probability list.

    Accepts the standard prediction-dict shape used across this repo
    (`probabilities=[{"label_text", "probability"}, ...]`). The NonDemented
    entry is matched case-insensitively and ignoring non-alphanumeric
    characters, so "NonDemented", "Non_Demented" and "non demented" all
    match. The first matching entry wins. The result is clamped to [0, 1].

    Raises:
        ValueError: If no NonDemented entry is present (including an empty
            list), or if its probability is not finite. Without this check a
            missing entry would be read as P = 0 and silently produce the
            maximum score ("severe leakage").
    """
    p_nondemented: float | None = None
    for entry in probabilities:
        raw_label = str(entry.get("label_text", ""))
        # Normalise away case and separators before comparing.
        label = "".join(ch for ch in raw_label.lower() if ch.isalnum())
        if label == "nondemented":
            p_nondemented = float(entry.get("probability", 0.0))
            break

    if p_nondemented is None:
        labels = [entry.get("label_text") for entry in probabilities]
        raise ValueError(
            f"No NonDemented probability found among class labels {labels!r}; "
            "cannot derive a BBB permeability score."
        )
    # max/min would silently turn NaN into 1.0, so reject it explicitly.
    if not np.isfinite(p_nondemented):
        raise ValueError(
            f"NonDemented probability is not finite: {p_nondemented!r}."
        )
    return float(max(0.0, min(1.0, 1.0 - p_nondemented)))
|
|
|
|
def heuristic_proxy_score(input_path: Path, checkpoint_path: Path) -> dict[str, Any]:
    """Score BBB permeability via the 2D resnet18 Alzheimer's classifier.

    Loads the classifier from ``checkpoint_path`` with ``mri_dl_2d.load``,
    runs ``mri_dl_2d.predict_image`` on ``input_path`` and converts the
    resulting class probabilities with ``compute_from_classifier_probs``
    (``1 - P(NonDemented)``).

    Returns:
        Dict with ``permeability_score``, its ``interpretation``,
        ``method="heuristic_proxy"``, ``voxel_map_available=False`` (this
        mode yields no voxel map) and ``source_class_probabilities`` (the
        classifier's raw probability list).
    """
    # Deferred import: the classifier module is only needed in this mode.
    from src.models import mri_dl_2d

    classifier = mri_dl_2d.load(checkpoint_path)
    prediction = mri_dl_2d.predict_image(classifier, input_path)
    class_probs = prediction["probabilities"]

    proxy = compute_from_classifier_probs(class_probs)
    payload: dict[str, Any] = {}
    payload["permeability_score"] = proxy
    payload["interpretation"] = _interpret(proxy)
    payload["method"] = "heuristic_proxy"
    payload["voxel_map_available"] = False
    payload["source_class_probabilities"] = class_probs
    return payload
|
|
|
|
def dce_onnx_score(
    input_path: Path,
    checkpoint_path: Path,
    *,
    ktrans_cap: float = 0.5,
) -> dict[str, Any]:
    """Load a DCE-MRI ONNX model and compute per-voxel Ktrans → scalar score.

    Contract for the future artifact:
    - Input: 4D NIfTI `(X, Y, Z, T)` with at least one timepoint. It is fed
      to the model with a leading batch axis, i.e. `(1, X, Y, Z, T)`.
    - Output: Ktrans map (mL/min/100g), expected to be 3D. Leading singleton
      (batch/channel) axes echoed by the model are dropped, so
      `voxel_map_shape` reports the voxel-space shape.

    The map is normalised to `[0, 1]` by dividing by `ktrans_cap` (a
    clinically-conservative cap, default 0.5) and clipping. The scalar score
    is the mean over all non-NaN voxels of the normalised map. No brain mask
    is applied here; presumably the model zeroes/NaNs background itself —
    TODO confirm against the artifact.

    No real artifact ships with the repo — this raises a clear error
    explaining the contract until one lands at `data/processed/bbb_permeability_dce.onnx`
    (override via `BBB_PERMEABILITY_DCE_PATH`).

    Raises:
        ValueError: If `ktrans_cap` is not positive, the input is not 4D, or
            the model output contains no non-NaN values.
        FileNotFoundError: If the ONNX artifact does not exist.
    """
    if not ktrans_cap > 0:
        raise ValueError(f"ktrans_cap must be positive; got {ktrans_cap!r}.")
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(
            f"DCE-MRI BBB permeability artifact not found at {checkpoint_path}. "
            "Train and export an ONNX model that consumes a 4D DCE volume "
            "(X, Y, Z, T) and emits a 3D Ktrans map; drop it at this path or "
            "set BBB_PERMEABILITY_DCE_PATH."
        )
    # Optional heavy dependencies are only needed for this mode.
    import nibabel as nib
    import onnxruntime as ort

    img = nib.load(str(input_path))
    arr = np.asarray(img.get_fdata(dtype=np.float32), dtype=np.float32)
    if arr.ndim != 4:
        raise ValueError(
            f"DCE-MRI mode expects a 4D NIfTI (X,Y,Z,T); got shape {arr.shape}."
        )
    session = ort.InferenceSession(str(checkpoint_path), providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    ktrans = np.asarray(session.run(None, {input_name: arr[np.newaxis, ...]})[0])
    # We fed a batch of one, so the model may echo that axis (and a channel
    # axis); strip leading singleton axes down to the 3D voxel map.
    while ktrans.ndim > 3 and ktrans.shape[0] == 1:
        ktrans = ktrans[0]

    normalised = np.clip(ktrans / ktrans_cap, 0.0, 1.0).astype(np.float32)
    # np.clip passes NaN through unchanged; one NaN voxel would otherwise
    # make the mean (and thus the score) NaN.
    valid = ~np.isnan(normalised)
    if not valid.any():
        raise ValueError(
            "DCE-MRI model produced no usable (non-NaN) Ktrans values; "
            "cannot compute a permeability score."
        )
    # Accumulate in float64 to avoid float32 round-off over large volumes.
    score = float(np.mean(normalised[valid], dtype=np.float64))
    return {
        "permeability_score": score,
        "interpretation": _interpret(score),
        "method": "dce_onnx",
        "voxel_map_available": True,
        "voxel_map_shape": list(normalised.shape),
    }
|
|
|
|
def compute_permeability(
    input_path: Path,
    mode: PermeabilityMode = DEFAULT_MODE,
    checkpoint_path: Path | None = None,
) -> dict[str, Any]:
    """Route an MRI input to the requested permeability mode.

    Args:
        input_path: MRI file to score; must exist on disk.
        mode: "heuristic_proxy" or "dce_onnx".
        checkpoint_path: Explicit artifact path. When falsy, the path comes
            from the mode's environment variable (``MRI_MODEL_PATH_2D`` or
            ``BBB_PERMEABILITY_DCE_PATH``), falling back to the mode's
            default under ``data/processed/``.

    Returns:
        The payload produced by the selected mode's scoring function.

    Raises:
        FileNotFoundError: If ``input_path`` does not exist (checked first).
        ValueError: If ``mode`` is not a supported mode.
    """
    mri_file = Path(input_path)
    if not mri_file.exists():
        raise FileNotFoundError(f"MRI input not found: {mri_file}")

    # (mode name, scorer, env var overriding the artifact path, fallback path)
    routes = (
        ("heuristic_proxy", heuristic_proxy_score,
         "MRI_MODEL_PATH_2D", "data/processed/mri_dl_2d/best_model.pt"),
        ("dce_onnx", dce_onnx_score,
         "BBB_PERMEABILITY_DCE_PATH", "data/processed/bbb_permeability_dce.onnx"),
    )
    for name, scorer, env_var, fallback in routes:
        if mode == name:
            # The environment is only consulted when no explicit path was given.
            artifact = checkpoint_path or Path(os.environ.get(env_var, fallback))
            return scorer(mri_file, artifact)

    raise ValueError(
        f"unknown BBB permeability mode={mode!r}; expected one of "
        "('heuristic_proxy', 'dce_onnx')"
    )
|
|