| """Tests for src.models.mri_selector — env-var-driven 2D / 3D dispatch.""" |
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import numpy as np |
| import pytest |
| from PIL import Image |
|
|
| from src.models import mri_selector |
| from tests.fixtures.build_dummy_mri_onnx import build as build_dummy_3d |
| from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d |
|
|
|
|
| _FIXTURE_MRI = Path(__file__).resolve().parents[1] / "fixtures" / "mri_sample" / "subject_0.nii.gz" |
|
|
|
|
class TestSelector:
    """Exercise the ``MRI_MODEL_KIND`` environment switch in ``mri_selector``.

    The first three tests cover ``current_kind()``: the default, an explicit
    choice, and rejection of an unknown value. The last two call ``predict()``
    end-to-end against dummy checkpoints written into ``tmp_path``.
    """

    def test_default_kind_is_volumetric(self, monkeypatch) -> None:
        """With MRI_MODEL_KIND unset, the selector reports the volumetric ONNX kind."""
        monkeypatch.delenv("MRI_MODEL_KIND", raising=False)
        selected = mri_selector.current_kind()
        assert selected == "volumetric_onnx"

    def test_explicit_2d_selection(self, monkeypatch) -> None:
        """An explicitly requested 2D kind is reported back unchanged."""
        expected_kind = "resnet18_2d"
        monkeypatch.setenv("MRI_MODEL_KIND", expected_kind)
        assert mri_selector.current_kind() == expected_kind

    def test_unknown_kind_raises(self, monkeypatch) -> None:
        """An unrecognised MRI_MODEL_KIND value is rejected with ValueError."""
        monkeypatch.setenv("MRI_MODEL_KIND", "neural_net_supreme")
        expect_error = pytest.raises(ValueError, match="unknown MRI_MODEL_KIND")
        with expect_error:
            mri_selector.current_kind()

    def test_predict_routes_to_volumetric(self, monkeypatch, tmp_path) -> None:
        """predict() on the NIfTI fixture with the volumetric kind returns one of the given labels."""
        monkeypatch.setenv("MRI_MODEL_KIND", "volumetric_onnx")
        onnx_model = build_dummy_3d(tmp_path / "vol.onnx")
        labels = ("control", "abnormal")
        prediction = mri_selector.predict(
            input_path=_FIXTURE_MRI,
            checkpoint_path=onnx_model,
            target_shape=(8, 8, 8),
            label_names=labels,
        )
        # Compare against a set, matching the original set-literal membership check.
        assert prediction["label_text"] in set(labels)

    def test_predict_routes_to_2d(self, monkeypatch, tmp_path) -> None:
        """predict() on a PNG image with the 2D kind returns a label from that kind's label set."""
        monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
        weights = build_dummy_2d(tmp_path / "best.pt")

        # Seeded 160x160x3 uint8 noise so the input is reproducible across runs.
        scan_png = tmp_path / "scan.png"
        rng = np.random.RandomState(0)
        pixels = (rng.rand(160, 160, 3) * 255).astype("uint8")
        Image.fromarray(pixels).save(str(scan_png))

        prediction = mri_selector.predict(input_path=scan_png, checkpoint_path=weights)
        allowed = mri_selector.label_names_for_kind("resnet18_2d")
        assert prediction["label_text"] in allowed
|
|