File size: 5,419 Bytes
327b23d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""BBB Permeability Map / Score from MRI input.

Two modes:
- `heuristic_proxy` (default, demo-ready): reuses the 2D resnet18 4-class
  Alzheimer's classifier output. Score = 1 - P(NonDemented). Anchored in the
  published correlation between disease severity and BBB breakdown.
- `dce_onnx` (real-DCE artifact, future): loads an ONNX model, trained on
  4D DCE-MRI inputs, that emits per-voxel Ktrans maps. No artifact ships
  yet, so this mode raises a descriptive error until one is dropped in place.

Researcher-persona module. It does NOT feed into the fusion engine, so
fusion's 'BBB is NOT a modality' rule is preserved. This module is the only
legitimate place where BBB and MRI are coupled in this platform.
"""
from __future__ import annotations

import os
from pathlib import Path
from typing import Any, Literal

import numpy as np

from src.core.logger import get_logger

logger = get_logger(__name__)

# The computation modes this module supports; `compute_permeability`
# dispatches on exactly these two values.
PermeabilityMode = Literal["heuristic_proxy", "dce_onnx"]
# Mode used by `compute_permeability` when the caller does not pass one.
DEFAULT_MODE: PermeabilityMode = "heuristic_proxy"


def _interpret(score: float) -> str:
    if score < 0.2:
        return "BBB intact"
    if score < 0.4:
        return "mild leakage"
    if score < 0.7:
        return "moderate leakage"
    return "severe leakage"


def compute_from_classifier_probs(probabilities: list[dict[str, Any]]) -> float:
    """Heuristic: 1 - P(NonDemented) from a 4-class probability list.

    Accepts the standard prediction-dict shape used across this repo
    (`probabilities=[{"label_text", "probability"}, ...]`).

    The NonDemented entry is located by label after normalising it:
    casefolded, with everything except letters and digits dropped. As a
    result, "NonDemented", "nondemented", "Non_Demented", "Non Demented"
    and "non-demented" all match. Only the first matching entry is used.

    If no entry matches, P(NonDemented) is taken as 0.0 (score 1.0) and a
    warning is logged, since that usually means the label vocabulary does
    not match what this heuristic expects.

    Returns:
        The score clamped to ``[0.0, 1.0]``.
    """

    def _normalise(label: Any) -> str:
        # Strip separators/case so spelling variants of a label compare equal.
        return "".join(ch for ch in str(label).casefold() if ch.isalnum())

    p_nondemented = 0.0
    for entry in probabilities:
        if _normalise(entry.get("label_text", "")) == "nondemented":
            p_nondemented = float(entry.get("probability", 0.0))
            break
    else:
        # No break: the class was absent. Keep the historical 0.0 default, but
        # make it visible instead of silently reporting maximal leakage.
        logger.warning(
            "No 'NonDemented' entry among classifier labels %s; treating "
            "P(NonDemented) as 0.0 (permeability score 1.0).",
            [entry.get("label_text") for entry in probabilities],
        )
    score = max(0.0, min(1.0, 1.0 - p_nondemented))
    return float(score)


def heuristic_proxy_score(input_path: Path, checkpoint_path: Path) -> dict[str, Any]:
    """Score BBB permeability from the 2D resnet18 Alzheimer's classifier.

    Loads the classifier from ``checkpoint_path`` and predicts on
    ``input_path``. The predicted class distribution is then reduced to a
    scalar by ``compute_from_classifier_probs``.

    The returned payload carries:
    - the score;
    - its interpretation band;
    - the method tag;
    - ``voxel_map_available=False``, because this mode produces no
      per-voxel map;
    - the raw class probabilities the score was derived from.
    """
    from src.models import mri_dl_2d  # imported lazily, only when this mode runs

    classifier = mri_dl_2d.load(checkpoint_path)
    prediction = mri_dl_2d.predict_image(classifier, input_path)
    class_probs = prediction["probabilities"]
    permeability = compute_from_classifier_probs(class_probs)

    return dict(
        permeability_score=permeability,
        interpretation=_interpret(permeability),
        method="heuristic_proxy",
        voxel_map_available=False,
        source_class_probabilities=class_probs,
    )


def dce_onnx_score(input_path: Path, checkpoint_path: Path) -> dict[str, Any]:
    """Load a DCE-MRI ONNX model and compute per-voxel Ktrans → scalar score.

    Contract for the future artifact:
    - Input: 4D NIfTI `(X, Y, Z, T)` with at least one timepoint. It is fed
      to the model's first input as float32 with a leading batch axis
      added, i.e. `(1, X, Y, Z, T)`.
    - Output: the model's first output is taken as the Ktrans map
      (mL/min/100g). It is normalised to `[0, 1]` by dividing by a
      clinically-conservative cap of 0.5 mL/min/100g and clipping.
    - Score: the mean of the normalised map over all voxels that are not
      NaN. No brain mask is applied yet. Background voxels therefore count
      toward the mean (TODO: restrict to a brain mask once one is
      available).

    No real artifact ships with the repo — this raises a clear error
    explaining the contract until one lands at
    `data/processed/bbb_permeability_dce.onnx`, the default used by
    `compute_permeability` (override via `BBB_PERMEABILITY_DCE_PATH`).

    Raises:
        FileNotFoundError: if `checkpoint_path` does not exist.
        ValueError: if the input is not 4D, has no timepoints, or the model
            emits an empty or all-NaN map.
    """
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(
            f"DCE-MRI BBB permeability artifact not found at {checkpoint_path}. "
            "Train and export an ONNX model that consumes a 4D DCE volume "
            "(X, Y, Z, T) and emits a 3D Ktrans map; drop it at this path or "
            "set BBB_PERMEABILITY_DCE_PATH."
        )
    # Imported lazily so the rest of the module works without these packages.
    import nibabel as nib
    import onnxruntime as ort

    ktrans_cap = 0.5  # mL/min/100g; clinically-conservative normalisation cap

    img = nib.load(str(input_path))
    volume = np.asarray(img.get_fdata(dtype=np.float32), dtype=np.float32)
    if volume.ndim != 4:
        raise ValueError(
            f"DCE-MRI mode expects a 4D NIfTI (X,Y,Z,T); got shape {volume.shape}."
        )
    if volume.shape[-1] < 1:
        raise ValueError(
            f"DCE-MRI mode expects at least one timepoint; got shape {volume.shape}."
        )

    session = ort.InferenceSession(str(checkpoint_path), providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: volume[np.newaxis, ...]})
    # NOTE(review): the map likely keeps the batch axis added above, so
    # `voxel_map_shape` may be (1, X, Y, Z) — confirm against the artifact.
    ktrans = np.asarray(outputs[0], dtype=np.float32)

    # np.clip maps ±inf onto the bounds but leaves NaN untouched. Excluding
    # NaN voxels keeps a few undefined fits from turning the score into NaN.
    normalised = np.clip(ktrans / ktrans_cap, 0.0, 1.0).astype(np.float32)
    valid = ~np.isnan(normalised)
    if not valid.any():
        raise ValueError(
            f"DCE-MRI model produced no usable Ktrans voxels (output shape {ktrans.shape})."
        )
    score = float(normalised[valid].mean())
    return {
        "permeability_score": score,
        "interpretation": _interpret(score),
        "method": "dce_onnx",
        "voxel_map_available": True,
        "voxel_map_shape": list(normalised.shape),
    }


def compute_permeability(
    input_path: Path,
    mode: PermeabilityMode = DEFAULT_MODE,
    checkpoint_path: Path | None = None,
) -> dict[str, Any]:
    """Dispatch to the requested mode and return a unified payload.

    Args:
        input_path: MRI input to score. It must exist on disk.
        mode: Either "heuristic_proxy" or "dce_onnx".
        checkpoint_path: Model artifact to use. When omitted (None or an
            empty value), the mode's environment variable is used; an empty
            variable counts as unset. Otherwise the repo default path is
            used:
            - heuristic_proxy: `MRI_MODEL_PATH_2D`, default
              `data/processed/mri_dl_2d/best_model.pt`;
            - dce_onnx: `BBB_PERMEABILITY_DCE_PATH`, default
              `data/processed/bbb_permeability_dce.onnx`.

    Raises:
        ValueError: if `mode` is not a supported mode. This is checked
            before touching the filesystem.
        FileNotFoundError: if `input_path` does not exist.
    """
    # mode -> (env var overriding the checkpoint, repo default checkpoint)
    checkpoint_defaults = {
        "heuristic_proxy": ("MRI_MODEL_PATH_2D", "data/processed/mri_dl_2d/best_model.pt"),
        "dce_onnx": ("BBB_PERMEABILITY_DCE_PATH", "data/processed/bbb_permeability_dce.onnx"),
    }
    # Tuple membership compares with ==, so unhashable junk still gets ValueError.
    if mode not in tuple(checkpoint_defaults):
        raise ValueError(
            f"unknown BBB permeability mode={mode!r}; expected one of "
            "('heuristic_proxy', 'dce_onnx')"
        )

    input_path = Path(input_path)
    if not input_path.exists():
        raise FileNotFoundError(f"MRI input not found: {input_path}")

    if checkpoint_path:
        ckpt = Path(checkpoint_path)
    else:
        env_var, fallback = checkpoint_defaults[mode]
        # `or` so an exported-but-empty variable doesn't resolve to the cwd.
        ckpt = Path(os.environ.get(env_var) or fallback)

    if mode == "heuristic_proxy":
        return heuristic_proxy_score(input_path, ckpt)
    return dce_onnx_score(input_path, ckpt)