feat(api): dispatch /predict/mri via MRI_MODEL_KIND env var
Browse files- src/api/routes.py +31 -12
- src/api/schemas.py +8 -1
- tests/api/test_mri_2d_route.py +43 -0
src/api/routes.py
CHANGED
|
@@ -319,25 +319,44 @@ def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
|
|
| 319 |
|
| 320 |
@predict_router.post("/mri", response_model=MRIPredictResponse)
|
| 321 |
def predict_mri(req: MRIPredictRequest) -> MRIPredictResponse:
|
| 322 |
-
"""Predict from one MRI
|
| 323 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
if not artifact.exists():
|
| 325 |
raise HTTPException(
|
| 326 |
status_code=503,
|
| 327 |
detail=(
|
| 328 |
-
f"MRI model artifact not available at {artifact}. "
|
| 329 |
-
"
|
| 330 |
-
"or
|
| 331 |
),
|
| 332 |
)
|
| 333 |
try:
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
except FileNotFoundError as e:
|
| 342 |
raise HTTPException(status_code=404, detail=str(e))
|
| 343 |
except ValueError as e:
|
|
|
|
| 319 |
|
| 320 |
@predict_router.post("/mri", response_model=MRIPredictResponse)
|
| 321 |
def predict_mri(req: MRIPredictRequest) -> MRIPredictResponse:
|
| 322 |
+
"""Predict from one MRI image. Backend selected by MRI_MODEL_KIND env.
|
| 323 |
+
|
| 324 |
+
- `volumetric_onnx` (default): NIfTI volume + externally-trained ONNX.
|
| 325 |
+
- `resnet18_2d`: 2D image (.png/.jpg) + PyTorch state_dict, 4-class
|
| 326 |
+
Alzheimer's classifier (MildDemented/ModerateDemented/NonDemented/VeryMildDemented).
|
| 327 |
+
"""
|
| 328 |
+
from src.models import mri_selector
|
| 329 |
+
|
| 330 |
+
kind = mri_selector.current_kind()
|
| 331 |
+
if kind == "resnet18_2d":
|
| 332 |
+
artifact = Path(os.environ.get(
|
| 333 |
+
"MRI_MODEL_PATH_2D", "data/processed/mri_dl_2d/best_model.pt",
|
| 334 |
+
))
|
| 335 |
+
else:
|
| 336 |
+
artifact = _mri_model_path()
|
| 337 |
+
|
| 338 |
if not artifact.exists():
|
| 339 |
raise HTTPException(
|
| 340 |
status_code=503,
|
| 341 |
detail=(
|
| 342 |
+
f"MRI model artifact not available at {artifact} (kind={kind}). "
|
| 343 |
+
"Drop the trained checkpoint at this path, or override the path "
|
| 344 |
+
"via MRI_MODEL_PATH (3D ONNX) or MRI_MODEL_PATH_2D (2D resnet18)."
|
| 345 |
),
|
| 346 |
)
|
| 347 |
try:
|
| 348 |
+
if kind == "resnet18_2d":
|
| 349 |
+
pred = mri_selector.predict(
|
| 350 |
+
input_path=Path(req.input_path),
|
| 351 |
+
checkpoint_path=artifact,
|
| 352 |
+
)
|
| 353 |
+
else:
|
| 354 |
+
pred = mri_selector.predict(
|
| 355 |
+
input_path=Path(req.input_path),
|
| 356 |
+
checkpoint_path=artifact,
|
| 357 |
+
target_shape=tuple(req.target_shape),
|
| 358 |
+
label_names=tuple(req.label_names) if req.label_names else None,
|
| 359 |
+
)
|
| 360 |
except FileNotFoundError as e:
|
| 361 |
raise HTTPException(status_code=404, detail=str(e))
|
| 362 |
except ValueError as e:
|
src/api/schemas.py
CHANGED
|
@@ -115,7 +115,14 @@ class BBBPredictResponse(BaseModel):
|
|
| 115 |
|
| 116 |
class MRIPredictRequest(BaseModel):
|
| 117 |
"""Single-subject MRI image prediction request."""
|
| 118 |
-
input_path: str = Field(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
target_shape: tuple[int, int, int] = Field(
|
| 120 |
(64, 64, 64),
|
| 121 |
description="Model preprocessing resize target as (D, H, W)",
|
|
|
|
| 115 |
|
| 116 |
class MRIPredictRequest(BaseModel):
|
| 117 |
"""Single-subject MRI image prediction request."""
|
| 118 |
+
input_path: str = Field(
|
| 119 |
+
...,
|
| 120 |
+
description=(
|
| 121 |
+
"Path to MRI input. With MRI_MODEL_KIND=volumetric_onnx (default), "
|
| 122 |
+
"expects a .nii/.nii.gz volume. With MRI_MODEL_KIND=resnet18_2d, "
|
| 123 |
+
"expects a 2D image (.png/.jpg)."
|
| 124 |
+
),
|
| 125 |
+
)
|
| 126 |
target_shape: tuple[int, int, int] = Field(
|
| 127 |
(64, 64, 64),
|
| 128 |
description="Model preprocessing resize target as (D, H, W)",
|
tests/api/test_mri_2d_route.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Integration: POST /predict/mri with MRI_MODEL_KIND=resnet18_2d."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pytest
|
| 6 |
+
from fastapi.testclient import TestClient
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
from src.api.main import app
|
| 10 |
+
from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@pytest.fixture()
|
| 14 |
+
def client_2d(monkeypatch, tmp_path):
|
| 15 |
+
monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
|
| 16 |
+
ckpt = build_dummy_2d(tmp_path / "best.pt")
|
| 17 |
+
monkeypatch.setenv("MRI_MODEL_PATH_2D", str(ckpt))
|
| 18 |
+
return TestClient(app)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def test_predict_mri_2d_happy_path(client_2d, tmp_path):
|
| 22 |
+
img_path = tmp_path / "scan.png"
|
| 23 |
+
Image.fromarray((np.random.RandomState(0).rand(170, 170, 3) * 255).astype("uint8")).save(str(img_path))
|
| 24 |
+
|
| 25 |
+
r = client_2d.post("/predict/mri", json={"input_path": str(img_path)})
|
| 26 |
+
assert r.status_code == 200, r.text
|
| 27 |
+
data = r.json()
|
| 28 |
+
assert data["label_text"] in {
|
| 29 |
+
"MildDemented", "ModerateDemented", "NonDemented", "VeryMildDemented",
|
| 30 |
+
}
|
| 31 |
+
assert 0.0 <= data["confidence"] <= 1.0
|
| 32 |
+
assert len(data["probabilities"]) == 4
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_predict_mri_2d_missing_artifact_returns_503(monkeypatch, tmp_path):
|
| 36 |
+
monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
|
| 37 |
+
monkeypatch.setenv("MRI_MODEL_PATH_2D", str(tmp_path / "missing.pt"))
|
| 38 |
+
img_path = tmp_path / "scan.png"
|
| 39 |
+
Image.fromarray((np.random.RandomState(0).rand(170, 170, 3) * 255).astype("uint8")).save(str(img_path))
|
| 40 |
+
client = TestClient(app)
|
| 41 |
+
r = client.post("/predict/mri", json={"input_path": str(img_path)})
|
| 42 |
+
assert r.status_code == 503
|
| 43 |
+
assert "kind=resnet18_2d" in r.text
|