File size: 2,342 Bytes
11f62d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Tests for src.models.mri_dl_2d — pretrained 4-class Alzheimer's resnet18."""
from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest
from PIL import Image

from src.models import mri_dl_2d
from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d


def _png(path: Path, size: tuple[int, int] = (200, 200)) -> Path:
    """Write a reproducible random RGB image to *path* and return *path*.

    Args:
        path: Destination file. The file format is whatever Pillow derives
            from the path's suffix; callers in this module pass ``.png``.
        size: ``(width, height)`` of the image in pixels.

    Returns:
        The same *path* that was passed in, so the call can be used inline.
    """
    width, height = size[0], size[1]
    # Fixed seed keeps the pixel data identical from run to run.
    rng = np.random.RandomState(0)
    # Array layout is rows (height) first, then columns (width), then channels.
    pixels = (rng.rand(height, width, 3) * 255).astype(np.uint8)
    image = Image.fromarray(pixels, mode="RGB")
    image.save(str(path))
    return path


class TestMRIDL2D:
    """Checks for the label maps, loader and predictor exposed by ``mri_dl_2d``."""

    @staticmethod
    def _load_dummy_model(tmp_path: Path):
        """Build the dummy fixture at ``tmp_path / "best.pt"`` and load it.

        The value returned by ``build_dummy_2d`` is handed straight to
        ``mri_dl_2d.load`` (presumably the checkpoint path — confirm against
        the fixture module).
        """
        checkpoint = build_dummy_2d(tmp_path / "best.pt")
        return mri_dl_2d.load(checkpoint)

    def test_class_to_idx_matches_trainer(self) -> None:
        """``CLASS_TO_IDX`` must equal the expected four-entry mapping."""
        expected = {
            "MildDemented": 0,
            "ModerateDemented": 1,
            "NonDemented": 2,
            "VeryMildDemented": 3,
        }
        assert mri_dl_2d.CLASS_TO_IDX == expected

    def test_idx_to_class_is_consistent(self) -> None:
        """Every ``CLASS_TO_IDX`` entry must round-trip through ``IDX_TO_CLASS``."""
        for label_text, label_idx in mri_dl_2d.CLASS_TO_IDX.items():
            assert mri_dl_2d.IDX_TO_CLASS[label_idx] == label_text

    def test_load_missing_artifact_raises(self, tmp_path: Path) -> None:
        """Loading a path that does not exist raises ``FileNotFoundError``."""
        missing = tmp_path / "nope.pt"
        with pytest.raises(FileNotFoundError, match="MRI 2D checkpoint not found"):
            mri_dl_2d.load(missing)

    def test_predict_image_returns_full_probs(self, tmp_path: Path) -> None:
        """``predict_image`` returns the four expected keys and a full distribution."""
        model = self._load_dummy_model(tmp_path)
        scan_path = _png(tmp_path / "scan.png")

        prediction = mri_dl_2d.predict_image(model, scan_path)

        # Top-level shape of the result dict.
        expected_keys = {"label", "label_text", "confidence", "probabilities"}
        assert set(prediction) == expected_keys
        assert prediction["label"] in {0, 1, 2, 3}
        assert prediction["label_text"] in mri_dl_2d.CLASS_TO_IDX
        assert 0.0 <= prediction["confidence"] <= 1.0

        # Per-class entries: one per known class, summing to 1 within tolerance.
        per_class = prediction["probabilities"]
        assert len(per_class) == 4
        total = sum(entry["probability"] for entry in per_class)
        assert abs(total - 1.0) < 1e-5
        seen_labels = {entry["label_text"] for entry in per_class}
        assert seen_labels == set(mri_dl_2d.CLASS_TO_IDX)

    def test_predict_works_for_grayscale_input(self, tmp_path: Path) -> None:
        """A single-channel (mode ``L``) image is accepted by ``predict_image``."""
        model = self._load_dummy_model(tmp_path)

        # Seeded 180x180 single-channel image written to disk.
        rng = np.random.RandomState(1)
        gray_pixels = (rng.rand(180, 180) * 255).astype(np.uint8)
        gray_path = tmp_path / "gray.png"
        Image.fromarray(gray_pixels, mode="L").save(str(gray_path))

        prediction = mri_dl_2d.predict_image(model, gray_path)
        assert 0.0 <= prediction["confidence"] <= 1.0