# hackathon/tests/models/test_mri_dl_2d.py
# Author: mekosotto
# Commit 11f62d8: feat(models): add 2D resnet18 4-class Alzheimer's MRI inference module
"""Tests for src.models.mri_dl_2d — pretrained 4-class Alzheimer's resnet18."""
from __future__ import annotations
from pathlib import Path
import numpy as np
import pytest
from PIL import Image
from src.models import mri_dl_2d
from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d
def _png(path: Path, size: tuple[int, int] = (200, 200)) -> Path:
    """Write a deterministic random RGB image to *path* and return *path*.

    Args:
        path: Destination file; Pillow infers the file format from its suffix.
        size: Image size as ``(width, height)``, following PIL's convention.

    Returns:
        The same *path*, so callers can build and use it in one expression.
    """
    width, height = size
    # Fixed seed so every run writes byte-identical pixels.
    rng = np.random.RandomState(0)
    # NumPy images are laid out (rows, cols, channels) == (height, width, 3).
    arr = (rng.rand(height, width, 3) * 255).astype(np.uint8)
    # A uint8 (H, W, 3) array is inferred as mode "RGB"; passing ``mode=``
    # explicitly is redundant and deprecated as of Pillow 11.3.
    Image.fromarray(arr).save(path)
    return path
class TestMRIDL2D:
    """Tests for ``src.models.mri_dl_2d``.

    Inference tests load a dummy checkpoint built by
    ``tests.fixtures.build_dummy_resnet18_2d.build``. They therefore check the
    shape of ``predict_image``'s output (keys, ranges, normalisation), not
    prediction quality.
    """

    @staticmethod
    def _load_dummy_model(tmp_path: Path):
        """Build a dummy checkpoint at ``tmp_path / "best.pt"`` and load it."""
        ckpt = build_dummy_2d(tmp_path / "best.pt")
        return mri_dl_2d.load(ckpt)

    @staticmethod
    def _assert_valid_prediction(result: dict) -> None:
        """Assert *result* satisfies the ``predict_image`` output contract."""
        assert set(result) == {"label", "label_text", "confidence", "probabilities"}
        assert result["label"] in set(mri_dl_2d.CLASS_TO_IDX.values())
        assert result["label_text"] in mri_dl_2d.CLASS_TO_IDX
        # The numeric label and its text name must describe the same class.
        assert result["label_text"] == mri_dl_2d.IDX_TO_CLASS[result["label"]]
        assert 0.0 <= result["confidence"] <= 1.0

        probs = result["probabilities"]
        assert len(probs) == len(mri_dl_2d.CLASS_TO_IDX)
        assert all(0.0 <= p["probability"] <= 1.0 for p in probs)
        assert sum(p["probability"] for p in probs) == pytest.approx(1.0, abs=1e-5)
        # Together with the length check, this means every class appears exactly once.
        assert {p["label_text"] for p in probs} == set(mri_dl_2d.CLASS_TO_IDX)

    def test_class_to_idx_matches_trainer(self) -> None:
        """The class-to-index mapping is pinned to the trainer's ordering."""
        assert mri_dl_2d.CLASS_TO_IDX == {
            "MildDemented": 0,
            "ModerateDemented": 1,
            "NonDemented": 2,
            "VeryMildDemented": 3,
        }

    def test_idx_to_class_is_consistent(self) -> None:
        """``IDX_TO_CLASS`` is the exact inverse of ``CLASS_TO_IDX``."""
        assert len(mri_dl_2d.IDX_TO_CLASS) == len(mri_dl_2d.CLASS_TO_IDX)
        for name, idx in mri_dl_2d.CLASS_TO_IDX.items():
            assert mri_dl_2d.IDX_TO_CLASS[idx] == name

    def test_load_missing_artifact_raises(self, tmp_path: Path) -> None:
        """Loading a non-existent checkpoint raises a descriptive error."""
        with pytest.raises(FileNotFoundError, match="MRI 2D checkpoint not found"):
            mri_dl_2d.load(tmp_path / "nope.pt")

    def test_predict_image_returns_full_probs(self, tmp_path: Path) -> None:
        """An RGB image yields a label plus a normalised per-class distribution."""
        model = self._load_dummy_model(tmp_path)
        img = _png(tmp_path / "scan.png")
        result = mri_dl_2d.predict_image(model, img)
        self._assert_valid_prediction(result)

    def test_predict_works_for_grayscale_input(self, tmp_path: Path) -> None:
        """A single-channel image satisfies the same output contract as RGB."""
        model = self._load_dummy_model(tmp_path)
        gray = (np.random.RandomState(1).rand(180, 180) * 255).astype(np.uint8)
        path = tmp_path / "gray.png"
        # A 2-D uint8 array is inferred as mode "L"; an explicit ``mode=`` is
        # redundant and deprecated as of Pillow 11.3.
        Image.fromarray(gray).save(path)
        result = mri_dl_2d.predict_image(model, path)
        self._assert_valid_prediction(result)