File size: 1,626 Bytes
10ed38c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 | """Integration: POST /predict/mri with MRI_MODEL_KIND=resnet18_2d."""
from __future__ import annotations

from contextlib import closing

import numpy as np
import pytest
from fastapi.testclient import TestClient
from PIL import Image

from src.api.main import app
from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d
@pytest.fixture()
def client_2d(monkeypatch, tmp_path):
    """Yield a TestClient whose app is configured for the ``resnet18_2d`` MRI model.

    Sets ``MRI_MODEL_KIND=resnet18_2d`` and points ``MRI_MODEL_PATH_2D`` at the
    checkpoint produced by ``build_dummy_2d`` inside ``tmp_path``; ``monkeypatch``
    restores both variables after the test.

    The client is wrapped in ``contextlib.closing`` instead of being entered as a
    context manager, so it is closed on teardown without running the app's
    lifespan/startup handlers (the previous bare ``return TestClient(app)`` did
    not run them either, but also never closed the client).
    """
    monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
    ckpt = build_dummy_2d(tmp_path / "best.pt")
    monkeypatch.setenv("MRI_MODEL_PATH_2D", str(ckpt))
    # NOTE(review): this assumes the app reads these env vars at request time
    # and does not cache a loaded model across tests — confirm in src.api.main.
    with closing(TestClient(app)) as client:
        yield client
def test_predict_mri_2d_happy_path(client_2d, tmp_path):
    """A valid PNG posted to /predict/mri yields a well-formed 4-class prediction.

    The request goes through the ``client_2d`` fixture (``MRI_MODEL_KIND``
    set to ``resnet18_2d``); the input is a seeded 170x170 RGB noise image
    written under ``tmp_path``.
    """
    expected_labels = {
        "MildDemented",
        "ModerateDemented",
        "NonDemented",
        "VeryMildDemented",
    }

    # Fixed seed keeps the synthetic scan identical on every run.
    noise = np.random.RandomState(0).rand(170, 170, 3)
    scan = tmp_path / "scan.png"
    Image.fromarray((noise * 255).astype("uint8")).save(str(scan))

    response = client_2d.post("/predict/mri", json={"input_path": str(scan)})
    assert response.status_code == 200, response.text

    body = response.json()
    assert body["label_text"] in expected_labels
    assert 0.0 <= body["confidence"] <= 1.0
    assert len(body["probabilities"]) == 4
def test_predict_mri_2d_missing_artifact_returns_503(monkeypatch, tmp_path):
    """A missing 2D checkpoint makes /predict/mri answer 503 naming the model kind.

    ``MRI_MODEL_PATH_2D`` points at a file that is never created, while the
    input image is a valid PNG. The intent is that a 503 is attributable to the
    missing model artifact rather than to the image.
    """
    monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
    monkeypatch.setenv("MRI_MODEL_PATH_2D", str(tmp_path / "missing.pt"))
    img_path = tmp_path / "scan.png"
    Image.fromarray(
        (np.random.RandomState(0).rand(170, 170, 3) * 255).astype("uint8")
    ).save(str(img_path))
    # closing() releases the client afterwards without entering it, so no
    # lifespan handlers run — same as the original bare TestClient(app).
    with closing(TestClient(app)) as client:
        r = client.post("/predict/mri", json={"input_path": str(img_path)})
    # Show the response body on failure, matching the happy-path test.
    assert r.status_code == 503, r.text
    assert "kind=resnet18_2d" in r.text
|