# hackathon/tests/models/test_mri_selector.py
# Author: mekosotto
# Commit 621cb25 — feat(models): selector dispatch for volumetric vs 2D MRI models
"""Tests for src.models.mri_selector — env-var-driven 2D / 3D dispatch."""
from __future__ import annotations
from pathlib import Path
import numpy as np
import pytest
from PIL import Image
from src.models import mri_selector
from tests.fixtures.build_dummy_mri_onnx import build as build_dummy_3d
from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d
# Sample NIfTI volume used by the volumetric routing test. The path is anchored
# to this file (parents[1] is the directory one level above this file's own
# directory) so it does not depend on the current working directory.
_FIXTURE_MRI = Path(__file__).resolve().parents[1].joinpath(
    "fixtures",
    "mri_sample",
    "subject_0.nii.gz",
)
class TestSelector:
    """Exercise ``mri_selector`` dispatch driven by the ``MRI_MODEL_KIND`` env var."""

    def test_default_kind_is_volumetric(self, monkeypatch) -> None:
        """With ``MRI_MODEL_KIND`` unset, the reported kind is ``volumetric_onnx``."""
        # raising=False: the variable may legitimately be absent already.
        monkeypatch.delenv("MRI_MODEL_KIND", raising=False)
        selected = mri_selector.current_kind()
        assert selected == "volumetric_onnx"

    def test_explicit_2d_selection(self, monkeypatch) -> None:
        """Setting ``MRI_MODEL_KIND=resnet18_2d`` is reflected by ``current_kind``."""
        monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
        selected = mri_selector.current_kind()
        assert selected == "resnet18_2d"

    def test_unknown_kind_raises(self, monkeypatch) -> None:
        """An unrecognised kind makes ``current_kind`` raise ``ValueError``."""
        monkeypatch.setenv("MRI_MODEL_KIND", "neural_net_supreme")
        expect_error = pytest.raises(ValueError, match="unknown MRI_MODEL_KIND")
        with expect_error:
            mri_selector.current_kind()

    def test_predict_routes_to_volumetric(self, monkeypatch, tmp_path) -> None:
        """``predict`` under ``volumetric_onnx`` yields one of the supplied labels.

        Uses the artifact returned by ``build_dummy_3d`` as the checkpoint and
        the bundled sample volume as input.
        """
        monkeypatch.setenv("MRI_MODEL_KIND", "volumetric_onnx")
        onnx_file = build_dummy_3d(tmp_path / "vol.onnx")

        prediction = mri_selector.predict(
            input_path=_FIXTURE_MRI,
            checkpoint_path=onnx_file,
            target_shape=(8, 8, 8),
            label_names=("control", "abnormal"),
        )

        predicted_label = prediction["label_text"]
        assert predicted_label in {"control", "abnormal"}

    def test_predict_routes_to_2d(self, monkeypatch, tmp_path) -> None:
        """``predict`` under ``resnet18_2d`` yields a label known for that kind.

        Uses the artifact returned by ``build_dummy_2d`` as the checkpoint and
        a synthetic 160x160 RGB PNG as input.
        """
        monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
        checkpoint_file = build_dummy_2d(tmp_path / "best.pt")

        # Fixed seed keeps the synthetic image identical from run to run.
        image_file = tmp_path / "scan.png"
        rng = np.random.RandomState(0)
        pixels = (rng.rand(160, 160, 3) * 255).astype("uint8")
        Image.fromarray(pixels).save(str(image_file))

        prediction = mri_selector.predict(
            input_path=image_file,
            checkpoint_path=checkpoint_file,
        )

        predicted_label = prediction["label_text"]
        allowed_labels = mri_selector.label_names_for_kind("resnet18_2d")
        assert predicted_label in allowed_labels