mekosotto commited on
Commit
11f62d8
·
1 Parent(s): e82971e

feat(models): add 2D resnet18 4-class Alzheimer's MRI inference module

Browse files
src/models/mri_dl_2d.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pretrained 2D MRI Alzheimer's classifier (resnet18, 4 classes).
2
+
3
+ Decision-layer bridge for an externally-trained PyTorch checkpoint. Loads
4
+ either a state_dict (default) or a full pickled model, applies the trainer's
5
+ preprocessing (resize image_size=160, ImageNet normalisation), and emits the
6
+ same dict shape as src.models.mri_model.predict_with_proba so downstream
7
+ code paths don't care which backend produced the prediction.
8
+ """
9
+ from __future__ import annotations
10
+
11
+ from pathlib import Path
12
+ from typing import Any
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ from PIL import Image
18
+ from torchvision import models, transforms
19
+
20
+ from src.core.logger import get_logger
21
+
22
+ logger = get_logger(__name__)
23
+
24
+ CLASS_TO_IDX: dict[str, int] = {
25
+ "MildDemented": 0,
26
+ "ModerateDemented": 1,
27
+ "NonDemented": 2,
28
+ "VeryMildDemented": 3,
29
+ }
30
+ IDX_TO_CLASS: dict[int, str] = {v: k for k, v in CLASS_TO_IDX.items()}
31
+
32
+ DEFAULT_IMAGE_SIZE = 160
33
+ _IMAGENET_MEAN = (0.485, 0.456, 0.406)
34
+ _IMAGENET_STD = (0.229, 0.224, 0.225)
35
+
36
+ _TRANSFORM = transforms.Compose([
37
+ transforms.Resize((DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE)),
38
+ transforms.ToTensor(),
39
+ transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
40
+ ])
41
+
42
+
43
+ def _build_resnet18_4class() -> nn.Module:
44
+ model = models.resnet18(weights=None)
45
+ model.fc = nn.Linear(model.fc.in_features, len(CLASS_TO_IDX))
46
+ return model
47
+
48
+
49
+ def load(path: Path) -> nn.Module:
50
+ """Load checkpoint. Supports state_dict (preferred) or full pickled model."""
51
+ path = Path(path)
52
+ if not path.exists():
53
+ raise FileNotFoundError(f"MRI 2D checkpoint not found: {path}")
54
+ obj = torch.load(str(path), map_location="cpu", weights_only=False)
55
+ if isinstance(obj, nn.Module):
56
+ model = obj
57
+ else:
58
+ model = _build_resnet18_4class()
59
+ clean = {k.removeprefix("module."): v for k, v in obj.items()}
60
+ model.load_state_dict(clean, strict=True)
61
+ model.eval()
62
+ return model
63
+
64
+
65
+ def predict_image(model: nn.Module, image_path: Path) -> dict[str, Any]:
66
+ """Run inference on one image. Output mirrors mri_model.predict_with_proba."""
67
+ image_path = Path(image_path)
68
+ if not image_path.exists():
69
+ raise FileNotFoundError(f"MRI image not found: {image_path}")
70
+ img = Image.open(str(image_path)).convert("RGB")
71
+ tensor = _TRANSFORM(img).unsqueeze(0)
72
+
73
+ with torch.inference_mode():
74
+ logits = model(tensor)
75
+ probs = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()
76
+
77
+ label_idx = int(np.argmax(probs))
78
+ return {
79
+ "label": label_idx,
80
+ "label_text": IDX_TO_CLASS[label_idx],
81
+ "confidence": float(probs[label_idx]),
82
+ "probabilities": [
83
+ {"label": i, "label_text": IDX_TO_CLASS[i], "probability": float(p)}
84
+ for i, p in enumerate(probs)
85
+ ],
86
+ }
tests/fixtures/build_dummy_resnet18_2d.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Build a randomly-initialised 4-class resnet18 state_dict for tests."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ from torchvision import models
8
+
9
+
10
+ def build(path: Path) -> Path:
11
+ """Save a state_dict at `path` and return the path. Idempotent."""
12
+ path = Path(path)
13
+ if path.exists():
14
+ return path
15
+ path.parent.mkdir(parents=True, exist_ok=True)
16
+ model = models.resnet18(weights=None)
17
+ model.fc = torch.nn.Linear(model.fc.in_features, 4)
18
+ torch.save(model.state_dict(), str(path))
19
+ return path
tests/models/test_mri_dl_2d.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.models.mri_dl_2d — pretrained 4-class Alzheimer's resnet18."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pytest
8
+ from PIL import Image
9
+
10
+ from src.models import mri_dl_2d
11
+ from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d
12
+
13
+
14
+ def _png(path: Path, size: tuple[int, int] = (200, 200)) -> Path:
15
+ arr = (np.random.RandomState(0).rand(size[1], size[0], 3) * 255).astype(np.uint8)
16
+ Image.fromarray(arr, mode="RGB").save(str(path))
17
+ return path
18
+
19
+
20
+ class TestMRIDL2D:
21
+ def test_class_to_idx_matches_trainer(self) -> None:
22
+ assert mri_dl_2d.CLASS_TO_IDX == {
23
+ "MildDemented": 0,
24
+ "ModerateDemented": 1,
25
+ "NonDemented": 2,
26
+ "VeryMildDemented": 3,
27
+ }
28
+
29
+ def test_idx_to_class_is_consistent(self) -> None:
30
+ for name, idx in mri_dl_2d.CLASS_TO_IDX.items():
31
+ assert mri_dl_2d.IDX_TO_CLASS[idx] == name
32
+
33
+ def test_load_missing_artifact_raises(self, tmp_path: Path) -> None:
34
+ with pytest.raises(FileNotFoundError, match="MRI 2D checkpoint not found"):
35
+ mri_dl_2d.load(tmp_path / "nope.pt")
36
+
37
+ def test_predict_image_returns_full_probs(self, tmp_path: Path) -> None:
38
+ ckpt = build_dummy_2d(tmp_path / "best.pt")
39
+ model = mri_dl_2d.load(ckpt)
40
+ img = _png(tmp_path / "scan.png")
41
+
42
+ result = mri_dl_2d.predict_image(model, img)
43
+
44
+ assert set(result) == {"label", "label_text", "confidence", "probabilities"}
45
+ assert result["label"] in {0, 1, 2, 3}
46
+ assert result["label_text"] in mri_dl_2d.CLASS_TO_IDX
47
+ assert 0.0 <= result["confidence"] <= 1.0
48
+ probs = result["probabilities"]
49
+ assert len(probs) == 4
50
+ assert abs(sum(p["probability"] for p in probs) - 1.0) < 1e-5
51
+ assert {p["label_text"] for p in probs} == set(mri_dl_2d.CLASS_TO_IDX)
52
+
53
+ def test_predict_works_for_grayscale_input(self, tmp_path: Path) -> None:
54
+ ckpt = build_dummy_2d(tmp_path / "best.pt")
55
+ model = mri_dl_2d.load(ckpt)
56
+ gray = (np.random.RandomState(1).rand(180, 180) * 255).astype(np.uint8)
57
+ path = tmp_path / "gray.png"
58
+ Image.fromarray(gray, mode="L").save(str(path))
59
+
60
+ result = mri_dl_2d.predict_image(model, path)
61
+ assert 0.0 <= result["confidence"] <= 1.0