"""Tests for src.models.mri_dl_2d — pretrained 4-class Alzheimer's resnet18.""" from __future__ import annotations from pathlib import Path import numpy as np import pytest from PIL import Image from src.models import mri_dl_2d from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d def _png(path: Path, size: tuple[int, int] = (200, 200)) -> Path: arr = (np.random.RandomState(0).rand(size[1], size[0], 3) * 255).astype(np.uint8) Image.fromarray(arr, mode="RGB").save(str(path)) return path class TestMRIDL2D: def test_class_to_idx_matches_trainer(self) -> None: assert mri_dl_2d.CLASS_TO_IDX == { "MildDemented": 0, "ModerateDemented": 1, "NonDemented": 2, "VeryMildDemented": 3, } def test_idx_to_class_is_consistent(self) -> None: for name, idx in mri_dl_2d.CLASS_TO_IDX.items(): assert mri_dl_2d.IDX_TO_CLASS[idx] == name def test_load_missing_artifact_raises(self, tmp_path: Path) -> None: with pytest.raises(FileNotFoundError, match="MRI 2D checkpoint not found"): mri_dl_2d.load(tmp_path / "nope.pt") def test_predict_image_returns_full_probs(self, tmp_path: Path) -> None: ckpt = build_dummy_2d(tmp_path / "best.pt") model = mri_dl_2d.load(ckpt) img = _png(tmp_path / "scan.png") result = mri_dl_2d.predict_image(model, img) assert set(result) == {"label", "label_text", "confidence", "probabilities"} assert result["label"] in {0, 1, 2, 3} assert result["label_text"] in mri_dl_2d.CLASS_TO_IDX assert 0.0 <= result["confidence"] <= 1.0 probs = result["probabilities"] assert len(probs) == 4 assert abs(sum(p["probability"] for p in probs) - 1.0) < 1e-5 assert {p["label_text"] for p in probs} == set(mri_dl_2d.CLASS_TO_IDX) def test_predict_works_for_grayscale_input(self, tmp_path: Path) -> None: ckpt = build_dummy_2d(tmp_path / "best.pt") model = mri_dl_2d.load(ckpt) gray = (np.random.RandomState(1).rand(180, 180) * 255).astype(np.uint8) path = tmp_path / "gray.png" Image.fromarray(gray, mode="L").save(str(path)) result = mri_dl_2d.predict_image(model, path) assert 0.0 <= result["confidence"] <= 1.0