mekosotto commited on
Commit
621cb25
·
1 Parent(s): 11f62d8

feat(models): selector dispatch for volumetric vs 2D MRI models

Browse files
src/models/mri_selector.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Env-var-driven dispatch between volumetric ONNX and 2D resnet18 MRI models."""
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ from src.core.logger import get_logger
9
+ from src.models import mri_dl_2d, mri_model
10
+
11
+ logger = get_logger(__name__)
12
+
13
+ VALID_KINDS = ("volumetric_onnx", "resnet18_2d")
14
+ _DEFAULT_KIND = "volumetric_onnx"
15
+
16
+
17
+ def current_kind() -> str:
18
+ kind = os.environ.get("MRI_MODEL_KIND", _DEFAULT_KIND)
19
+ if kind not in VALID_KINDS:
20
+ raise ValueError(f"unknown MRI_MODEL_KIND={kind!r}; expected one of {VALID_KINDS}")
21
+ return kind
22
+
23
+
24
+ def label_names_for_kind(kind: str) -> tuple[str, ...]:
25
+ if kind == "resnet18_2d":
26
+ return tuple(mri_dl_2d.IDX_TO_CLASS[i] for i in range(len(mri_dl_2d.CLASS_TO_IDX)))
27
+ return mri_model.DEFAULT_LABEL_NAMES
28
+
29
+
30
+ def predict(
31
+ input_path: Path,
32
+ checkpoint_path: Path,
33
+ target_shape: tuple[int, int, int] | None = None,
34
+ label_names: tuple[str, ...] | None = None,
35
+ ) -> dict[str, Any]:
36
+ """Run the active MRI model on one input. Returns the unified prediction dict."""
37
+ kind = current_kind()
38
+ logger.info("dispatching MRI prediction kind=%s input=%s", kind, input_path)
39
+ if kind == "resnet18_2d":
40
+ model = mri_dl_2d.load(checkpoint_path)
41
+ return mri_dl_2d.predict_image(model, input_path)
42
+ model = mri_model.load(checkpoint_path)
43
+ return mri_model.predict_nifti(
44
+ model,
45
+ input_path,
46
+ target_shape=target_shape or mri_model.DEFAULT_TARGET_SHAPE,
47
+ label_names=label_names,
48
+ )
tests/models/test_mri_selector.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.models.mri_selector — env-var-driven 2D / 3D dispatch."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pytest
8
+ from PIL import Image
9
+
10
+ from src.models import mri_selector
11
+ from tests.fixtures.build_dummy_mri_onnx import build as build_dummy_3d
12
+ from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d
13
+
14
+
15
+ _FIXTURE_MRI = Path(__file__).resolve().parents[1] / "fixtures" / "mri_sample" / "subject_0.nii.gz"
16
+
17
+
18
+ class TestSelector:
19
+ def test_default_kind_is_volumetric(self, monkeypatch) -> None:
20
+ monkeypatch.delenv("MRI_MODEL_KIND", raising=False)
21
+ assert mri_selector.current_kind() == "volumetric_onnx"
22
+
23
+ def test_explicit_2d_selection(self, monkeypatch) -> None:
24
+ monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
25
+ assert mri_selector.current_kind() == "resnet18_2d"
26
+
27
+ def test_unknown_kind_raises(self, monkeypatch) -> None:
28
+ monkeypatch.setenv("MRI_MODEL_KIND", "neural_net_supreme")
29
+ with pytest.raises(ValueError, match="unknown MRI_MODEL_KIND"):
30
+ mri_selector.current_kind()
31
+
32
+ def test_predict_routes_to_volumetric(self, monkeypatch, tmp_path) -> None:
33
+ monkeypatch.setenv("MRI_MODEL_KIND", "volumetric_onnx")
34
+ artifact = build_dummy_3d(tmp_path / "vol.onnx")
35
+ result = mri_selector.predict(
36
+ input_path=_FIXTURE_MRI,
37
+ checkpoint_path=artifact,
38
+ target_shape=(8, 8, 8),
39
+ label_names=("control", "abnormal"),
40
+ )
41
+ assert result["label_text"] in {"control", "abnormal"}
42
+
43
+ def test_predict_routes_to_2d(self, monkeypatch, tmp_path) -> None:
44
+ monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
45
+ artifact = build_dummy_2d(tmp_path / "best.pt")
46
+ img_path = tmp_path / "scan.png"
47
+ Image.fromarray((np.random.RandomState(0).rand(160, 160, 3) * 255).astype("uint8")).save(str(img_path))
48
+ result = mri_selector.predict(
49
+ input_path=img_path,
50
+ checkpoint_path=artifact,
51
+ )
52
+ assert result["label_text"] in mri_selector.label_names_for_kind("resnet18_2d")