File size: 2,143 Bytes
621cb25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""Tests for src.models.mri_selector — env-var-driven 2D / 3D dispatch."""
from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest
from PIL import Image

from src.models import mri_selector
from tests.fixtures.build_dummy_mri_onnx import build as build_dummy_3d
from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d


# Absolute path to the sample NIfTI volume used by the volumetric routing test.
# parents[1] is the directory two levels above this file (presumably the
# `tests/` root, given the `tests.fixtures` imports — confirm if files move).
_FIXTURE_MRI = Path(__file__).resolve().parents[1].joinpath(
    "fixtures", "mri_sample", "subject_0.nii.gz"
)


class TestSelector:
    """Exercise the ``MRI_MODEL_KIND``-driven backend selection in ``mri_selector``.

    Covers the default kind, an explicit selection, rejection of an unknown
    value, and one smoke-level ``predict`` call per supported backend.
    """

    def test_default_kind_is_volumetric(self, monkeypatch) -> None:
        """With ``MRI_MODEL_KIND`` unset, ``current_kind`` reports the ONNX volumetric kind."""
        monkeypatch.delenv("MRI_MODEL_KIND", raising=False)
        selected = mri_selector.current_kind()
        assert selected == "volumetric_onnx"

    def test_explicit_2d_selection(self, monkeypatch) -> None:
        """Setting ``MRI_MODEL_KIND=resnet18_2d`` is reflected by ``current_kind``."""
        wanted = "resnet18_2d"
        monkeypatch.setenv("MRI_MODEL_KIND", wanted)
        assert mri_selector.current_kind() == wanted

    def test_unknown_kind_raises(self, monkeypatch) -> None:
        """An unrecognised kind makes ``current_kind`` raise ``ValueError``."""
        monkeypatch.setenv("MRI_MODEL_KIND", "neural_net_supreme")
        expectation = pytest.raises(ValueError, match="unknown MRI_MODEL_KIND")
        with expectation:
            mri_selector.current_kind()

    def test_predict_routes_to_volumetric(self, monkeypatch, tmp_path: Path) -> None:
        """``predict`` completes on the sample volume with a dummy 3D ONNX model.

        NOTE(review): only the returned label is checked, so this test does not by
        itself prove the volumetric backend (rather than another) served the call.
        """
        monkeypatch.setenv("MRI_MODEL_KIND", "volumetric_onnx")
        labels = ("control", "abnormal")
        onnx_model = build_dummy_3d(tmp_path / "vol.onnx")
        prediction = mri_selector.predict(
            input_path=_FIXTURE_MRI,
            checkpoint_path=onnx_model,
            target_shape=(8, 8, 8),
            label_names=labels,
        )
        assert prediction["label_text"] in set(labels)

    def test_predict_routes_to_2d(self, monkeypatch, tmp_path: Path) -> None:
        """``predict`` completes on a PNG input with a dummy 2D ResNet-18 checkpoint.

        NOTE(review): as above, only label membership is asserted, not which
        backend produced it.
        """
        monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
        checkpoint = build_dummy_2d(tmp_path / "best.pt")

        # Deterministic (seed 0) 160x160 RGB noise image used as the 2D input.
        noise = np.random.RandomState(0).rand(160, 160, 3) * 255
        scan = tmp_path / "scan.png"
        Image.fromarray(noise.astype("uint8")).save(str(scan))

        prediction = mri_selector.predict(
            input_path=scan,
            checkpoint_path=checkpoint,
        )
        allowed = mri_selector.label_names_for_kind("resnet18_2d")
        assert prediction["label_text"] in allowed