mekosotto commited on
Commit
10ed38c
·
1 Parent(s): 621cb25

feat(api): dispatch /predict/mri via MRI_MODEL_KIND env var

Browse files
src/api/routes.py CHANGED
@@ -319,25 +319,44 @@ def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
319
 
320
  @predict_router.post("/mri", response_model=MRIPredictResponse)
321
  def predict_mri(req: MRIPredictRequest) -> MRIPredictResponse:
322
- """Predict from one MRI NIfTI image using an externally-trained ONNX model."""
323
- artifact = _mri_model_path()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  if not artifact.exists():
325
  raise HTTPException(
326
  status_code=503,
327
  detail=(
328
- f"MRI model artifact not available at {artifact}. "
329
- "Export the trained volumetric model to ONNX and place it there, "
330
- "or set MRI_MODEL_PATH."
331
  ),
332
  )
333
  try:
334
- model = mri_model.load(artifact)
335
- pred = mri_model.predict_nifti(
336
- model,
337
- Path(req.input_path),
338
- target_shape=req.target_shape,
339
- label_names=req.label_names,
340
- )
 
 
 
 
 
341
  except FileNotFoundError as e:
342
  raise HTTPException(status_code=404, detail=str(e))
343
  except ValueError as e:
 
319
 
320
  @predict_router.post("/mri", response_model=MRIPredictResponse)
321
  def predict_mri(req: MRIPredictRequest) -> MRIPredictResponse:
322
+ """Predict from one MRI image. Backend selected by MRI_MODEL_KIND env.
323
+
324
+ - `volumetric_onnx` (default): NIfTI volume + externally-trained ONNX.
325
+ - `resnet18_2d`: 2D image (.png/.jpg) + PyTorch state_dict, 4-class
326
+ Alzheimer's classifier (MildDemented/ModerateDemented/NonDemented/VeryMildDemented).
327
+ """
328
+ from src.models import mri_selector
329
+
330
+ kind = mri_selector.current_kind()
331
+ if kind == "resnet18_2d":
332
+ artifact = Path(os.environ.get(
333
+ "MRI_MODEL_PATH_2D", "data/processed/mri_dl_2d/best_model.pt",
334
+ ))
335
+ else:
336
+ artifact = _mri_model_path()
337
+
338
  if not artifact.exists():
339
  raise HTTPException(
340
  status_code=503,
341
  detail=(
342
+ f"MRI model artifact not available at {artifact} (kind={kind}). "
343
+ "Drop the trained checkpoint at this path, or override the path "
344
+ "via MRI_MODEL_PATH (3D ONNX) or MRI_MODEL_PATH_2D (2D resnet18)."
345
  ),
346
  )
347
  try:
348
+ if kind == "resnet18_2d":
349
+ pred = mri_selector.predict(
350
+ input_path=Path(req.input_path),
351
+ checkpoint_path=artifact,
352
+ )
353
+ else:
354
+ pred = mri_selector.predict(
355
+ input_path=Path(req.input_path),
356
+ checkpoint_path=artifact,
357
+ target_shape=tuple(req.target_shape),
358
+ label_names=tuple(req.label_names) if req.label_names else None,
359
+ )
360
  except FileNotFoundError as e:
361
  raise HTTPException(status_code=404, detail=str(e))
362
  except ValueError as e:
src/api/schemas.py CHANGED
@@ -115,7 +115,14 @@ class BBBPredictResponse(BaseModel):
115
 
116
  class MRIPredictRequest(BaseModel):
117
  """Single-subject MRI image prediction request."""
118
- input_path: str = Field(..., description="Path to one .nii or .nii.gz MRI volume")
 
 
 
 
 
 
 
119
  target_shape: tuple[int, int, int] = Field(
120
  (64, 64, 64),
121
  description="Model preprocessing resize target as (D, H, W)",
 
115
 
116
  class MRIPredictRequest(BaseModel):
117
  """Single-subject MRI image prediction request."""
118
+ input_path: str = Field(
119
+ ...,
120
+ description=(
121
+ "Path to MRI input. With MRI_MODEL_KIND=volumetric_onnx (default), "
122
+ "expects a .nii/.nii.gz volume. With MRI_MODEL_KIND=resnet18_2d, "
123
+ "expects a 2D image (.png/.jpg)."
124
+ ),
125
+ )
126
  target_shape: tuple[int, int, int] = Field(
127
  (64, 64, 64),
128
  description="Model preprocessing resize target as (D, H, W)",
tests/api/test_mri_2d_route.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Integration: POST /predict/mri with MRI_MODEL_KIND=resnet18_2d."""
2
+ from __future__ import annotations
3
+
4
+ import numpy as np
5
+ import pytest
6
+ from fastapi.testclient import TestClient
7
+ from PIL import Image
8
+
9
+ from src.api.main import app
10
+ from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d
11
+
12
+
13
+ @pytest.fixture()
14
+ def client_2d(monkeypatch, tmp_path):
15
+ monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
16
+ ckpt = build_dummy_2d(tmp_path / "best.pt")
17
+ monkeypatch.setenv("MRI_MODEL_PATH_2D", str(ckpt))
18
+ return TestClient(app)
19
+
20
+
21
+ def test_predict_mri_2d_happy_path(client_2d, tmp_path):
22
+ img_path = tmp_path / "scan.png"
23
+ Image.fromarray((np.random.RandomState(0).rand(170, 170, 3) * 255).astype("uint8")).save(str(img_path))
24
+
25
+ r = client_2d.post("/predict/mri", json={"input_path": str(img_path)})
26
+ assert r.status_code == 200, r.text
27
+ data = r.json()
28
+ assert data["label_text"] in {
29
+ "MildDemented", "ModerateDemented", "NonDemented", "VeryMildDemented",
30
+ }
31
+ assert 0.0 <= data["confidence"] <= 1.0
32
+ assert len(data["probabilities"]) == 4
33
+
34
+
35
+ def test_predict_mri_2d_missing_artifact_returns_503(monkeypatch, tmp_path):
36
+ monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
37
+ monkeypatch.setenv("MRI_MODEL_PATH_2D", str(tmp_path / "missing.pt"))
38
+ img_path = tmp_path / "scan.png"
39
+ Image.fromarray((np.random.RandomState(0).rand(170, 170, 3) * 255).astype("uint8")).save(str(img_path))
40
+ client = TestClient(app)
41
+ r = client.post("/predict/mri", json={"input_path": str(img_path)})
42
+ assert r.status_code == 503
43
+ assert "kind=resnet18_2d" in r.text