feat(models): add 2D resnet18 4-class Alzheimer's MRI inference module
Browse files
src/models/mri_dl_2d.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pretrained 2D MRI Alzheimer's classifier (resnet18, 4 classes).
|
| 2 |
+
|
| 3 |
+
Decision-layer bridge for an externally-trained PyTorch checkpoint. Loads
|
| 4 |
+
either a state_dict (default) or a full pickled model, applies the trainer's
|
| 5 |
+
preprocessing (resize image_size=160, ImageNet normalisation), and emits the
|
| 6 |
+
same dict shape as src.models.mri_model.predict_with_proba so downstream
|
| 7 |
+
code paths don't care which backend produced the prediction.
|
| 8 |
+
"""
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from PIL import Image
|
| 18 |
+
from torchvision import models, transforms
|
| 19 |
+
|
| 20 |
+
from src.core.logger import get_logger
|
| 21 |
+
|
| 22 |
+
logger = get_logger(__name__)

# Class-name -> output-index mapping used by the external trainer.
# Order is alphabetical (presumably torchvision ImageFolder ordering —
# TODO confirm against the trainer); do not reorder, the checkpoint's
# fc-layer indices are bound to these positions.
CLASS_TO_IDX: dict[str, int] = {
    "MildDemented": 0,
    "ModerateDemented": 1,
    "NonDemented": 2,
    "VeryMildDemented": 3,
}
# Reverse lookup: model output index -> human-readable class name.
IDX_TO_CLASS: dict[int, str] = {v: k for k, v in CLASS_TO_IDX.items()}

# Square input side length; matches the trainer's image_size=160 (see module docstring).
DEFAULT_IMAGE_SIZE = 160
# Standard ImageNet channel statistics used for normalisation at train time.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Inference-time preprocessing pipeline mirroring the trainer:
# resize to 160x160 -> float tensor in [0, 1] -> ImageNet normalisation.
_TRANSFORM = transforms.Compose([
    transforms.Resize((DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _build_resnet18_4class() -> nn.Module:
    """Return an untrained resnet18 whose final layer emits one logit per class."""
    net = models.resnet18(weights=None)
    head_in = net.fc.in_features
    net.fc = nn.Linear(head_in, len(CLASS_TO_IDX))
    return net
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load(path: Path) -> nn.Module:
    """Load the 2D MRI checkpoint at *path* and return it in eval mode on CPU.

    Supported checkpoint layouts:
      * a plain state_dict (preferred),
      * a wrapper dict holding the weights under ``state_dict`` or
        ``model_state_dict`` (common trainer checkpoint shape),
      * a fully pickled ``nn.Module``.

    ``module.``-prefixed keys (DataParallel training) are stripped before
    loading into a fresh 4-class resnet18.

    Raises:
        FileNotFoundError: if *path* does not exist.
        TypeError: if the checkpoint is neither a module nor a dict.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"MRI 2D checkpoint not found: {path}")
    # SECURITY: weights_only=False is required to support fully pickled models,
    # but unpickling executes arbitrary code — only load trusted checkpoints.
    obj = torch.load(str(path), map_location="cpu", weights_only=False)
    if isinstance(obj, nn.Module):
        model = obj
    elif isinstance(obj, dict):
        # Unwrap common trainer checkpoint containers before treating the
        # dict as a raw state_dict.
        for wrapper_key in ("state_dict", "model_state_dict"):
            if wrapper_key in obj and isinstance(obj[wrapper_key], dict):
                obj = obj[wrapper_key]
                break
        model = _build_resnet18_4class()
        # Strip DataParallel's "module." prefix so keys match a bare resnet18.
        clean = {k.removeprefix("module."): v for k, v in obj.items()}
        model.load_state_dict(clean, strict=True)
    else:
        # Previously this fell through to obj.items() and died with an opaque
        # AttributeError; fail with a clear message instead.
        raise TypeError(f"Unsupported MRI 2D checkpoint type: {type(obj).__name__}")
    model.eval()
    return model
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def predict_image(model: nn.Module, image_path: Path) -> dict[str, Any]:
    """Classify one MRI image; result shape mirrors mri_model.predict_with_proba.

    Raises:
        FileNotFoundError: if *image_path* does not exist.
    """
    image_path = Path(image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"MRI image not found: {image_path}")

    rgb = Image.open(str(image_path)).convert("RGB")
    batch = _TRANSFORM(rgb).unsqueeze(0)

    with torch.inference_mode():
        scores = torch.softmax(model(batch), dim=1)
    probs = scores.squeeze(0).cpu().numpy()

    best = int(np.argmax(probs))
    per_class = [
        {"label": idx, "label_text": IDX_TO_CLASS[idx], "probability": float(prob)}
        for idx, prob in enumerate(probs)
    ]
    return {
        "label": best,
        "label_text": IDX_TO_CLASS[best],
        "confidence": float(probs[best]),
        "probabilities": per_class,
    }
|
tests/fixtures/build_dummy_resnet18_2d.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Build a randomly-initialised 4-class resnet18 state_dict for tests."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torchvision import models
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def build(path: Path) -> Path:
    """Persist a random 4-class resnet18 state_dict at `path`. Idempotent."""
    target = Path(path)
    if not target.exists():
        target.parent.mkdir(parents=True, exist_ok=True)
        net = models.resnet18(weights=None)
        net.fc = torch.nn.Linear(net.fc.in_features, 4)
        torch.save(net.state_dict(), str(target))
    return target
|
tests/models/test_mri_dl_2d.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for src.models.mri_dl_2d — pretrained 4-class Alzheimer's resnet18."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pytest
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from src.models import mri_dl_2d
|
| 11 |
+
from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _png(path: Path, size: tuple[int, int] = (200, 200)) -> Path:
    """Write a deterministic pseudo-random RGB PNG of `size` (w, h) at `path`."""
    width, height = size
    pixels = (np.random.RandomState(0).rand(height, width, 3) * 255).astype(np.uint8)
    Image.fromarray(pixels, mode="RGB").save(str(path))
    return path
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TestMRIDL2D:
    """Behavioural tests for the pretrained 4-class Alzheimer's resnet18 bridge."""

    def test_class_to_idx_matches_trainer(self) -> None:
        expected = {
            "MildDemented": 0,
            "ModerateDemented": 1,
            "NonDemented": 2,
            "VeryMildDemented": 3,
        }
        assert mri_dl_2d.CLASS_TO_IDX == expected

    def test_idx_to_class_is_consistent(self) -> None:
        # Round-trip: every class name maps back through its index.
        assert all(
            mri_dl_2d.IDX_TO_CLASS[idx] == name
            for name, idx in mri_dl_2d.CLASS_TO_IDX.items()
        )

    def test_load_missing_artifact_raises(self, tmp_path: Path) -> None:
        missing = tmp_path / "nope.pt"
        with pytest.raises(FileNotFoundError, match="MRI 2D checkpoint not found"):
            mri_dl_2d.load(missing)

    def test_predict_image_returns_full_probs(self, tmp_path: Path) -> None:
        model = mri_dl_2d.load(build_dummy_2d(tmp_path / "best.pt"))
        result = mri_dl_2d.predict_image(model, _png(tmp_path / "scan.png"))

        assert set(result) == {"label", "label_text", "confidence", "probabilities"}
        assert result["label"] in {0, 1, 2, 3}
        assert result["label_text"] in mri_dl_2d.CLASS_TO_IDX
        assert 0.0 <= result["confidence"] <= 1.0

        entries = result["probabilities"]
        assert len(entries) == 4
        total = sum(entry["probability"] for entry in entries)
        assert abs(total - 1.0) < 1e-5
        names = {entry["label_text"] for entry in entries}
        assert names == set(mri_dl_2d.CLASS_TO_IDX)

    def test_predict_works_for_grayscale_input(self, tmp_path: Path) -> None:
        model = mri_dl_2d.load(build_dummy_2d(tmp_path / "best.pt"))
        gray_path = tmp_path / "gray.png"
        pixels = (np.random.RandomState(1).rand(180, 180) * 255).astype(np.uint8)
        Image.fromarray(pixels, mode="L").save(str(gray_path))

        result = mri_dl_2d.predict_image(model, gray_path)
        assert 0.0 <= result["confidence"] <= 1.0
|