File size: 2,946 Bytes
c0a7163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ae5b40
 
 
 
 
 
 
 
 
 
 
a3b6bb6
 
 
 
 
 
9ae5b40
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Tests for src.models.mri_model — image-based MRI DL inference surface."""
from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest

from src.models import mri_model
from tests.fixtures.build_dummy_mri_onnx import build as build_dummy_mri_onnx


# Sample NIfTI volume under the shared fixtures directory (one level above this
# file's directory); used as the predict_nifti input in the inference tests.
_FIXTURE_MRI = Path(__file__).resolve().parent.parent.joinpath(
    "fixtures", "mri_sample", "subject_0.nii.gz"
)


class TestMRIDLModel:
    """Exercise preprocessing, artifact loading and NIfTI inference in mri_model."""

    # Spatial size every test asks preprocess_volume / predict_nifti to produce.
    _TARGET_SHAPE = (8, 8, 8)

    @staticmethod
    def _dummy_model(tmp_path: Path):
        """Build the dummy ONNX fixture inside *tmp_path* and load it with mri_model.load."""
        return mri_model.load(build_dummy_mri_onnx(tmp_path / "mri_model.onnx"))

    def test_preprocess_volume_returns_batch_channel_tensor(self) -> None:
        """A small 3-D volume comes back as a finite float32 tensor of shape (1, 1, *target)."""
        vol = np.ones((4, 5, 6), dtype=np.float32)
        # Brighter interior cube so the input volume is not constant.
        vol[1:3, 1:4, 2:5] = 5.0

        tensor = mri_model.preprocess_volume(vol, target_shape=self._TARGET_SHAPE)

        assert tensor.shape == (1, 1, *self._TARGET_SHAPE)
        assert tensor.dtype == np.float32
        assert np.isfinite(tensor).all()

    def test_preprocess_rejects_nan_volume(self) -> None:
        """A volume containing a NaN voxel is refused with a ValueError."""
        vol = np.zeros((4, 4, 4), dtype=np.float32)
        vol.flat[0] = np.nan  # poison exactly one voxel, at index (0, 0, 0)

        with pytest.raises(ValueError, match="finite numeric 3-D"):
            mri_model.preprocess_volume(vol, target_shape=self._TARGET_SHAPE)

    def test_load_missing_artifact_raises(self, tmp_path: Path) -> None:
        """Loading an artifact path that does not exist raises FileNotFoundError."""
        absent = tmp_path / "missing.onnx"
        with pytest.raises(FileNotFoundError, match="MRI model artifact not found"):
            mri_model.load(absent)

    def test_predict_nifti_with_dummy_onnx(self, tmp_path: Path) -> None:
        """The dummy model labels the fixture volume "abnormal" with probabilities summing to 1."""
        model = self._dummy_model(tmp_path)

        result = mri_model.predict_nifti(
            model,
            _FIXTURE_MRI,
            target_shape=self._TARGET_SHAPE,
            label_names=("control", "abnormal"),
        )

        assert result["label"] == 1
        assert result["label_text"] == "abnormal"
        assert result["confidence"] > 0.5
        probs = result["probabilities"]
        assert len(probs) == 2
        total = sum(entry["probability"] for entry in probs)
        assert total == pytest.approx(1.0, abs=1e-6)

    def test_predict_warns_on_label_count_mismatch(
        self, tmp_path: Path, caplog: pytest.LogCaptureFixture
    ) -> None:
        """A label_names tuple of the wrong length logs a warning and yields generic class_N labels."""
        model = self._dummy_model(tmp_path)
        module_logger = mri_model.logger
        capture = caplog.handler

        # Per src/core/logger.py, mri_model.logger does not propagate, so caplog's
        # root-level handler would miss its records; attach the capture handler
        # to that logger directly and always detach it afterwards.
        module_logger.addHandler(capture)
        try:
            result = mri_model.predict_nifti(
                model,
                _FIXTURE_MRI,
                target_shape=self._TARGET_SHAPE,
                label_names=("control", "abnormal", "extra"),
            )
        finally:
            module_logger.removeHandler(capture)

        assert result["label_text"] in {"class_0", "class_1"}
        messages = [rec.message for rec in caplog.records]
        assert any(
            "label_names length" in msg and "overriding" in msg for msg in messages
        ), messages