| """Integration: POST /predict/mri with MRI_MODEL_KIND=resnet18_2d.""" |
| from __future__ import annotations |
|
|
| import numpy as np |
| import pytest |
| from fastapi.testclient import TestClient |
| from PIL import Image |
|
|
| from src.api.main import app |
| from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d |
|
|
|
|
@pytest.fixture()
def client_2d(monkeypatch, tmp_path):
    """Provide a TestClient whose MRI endpoint uses the 2D ResNet-18 model.

    Selects the ``resnet18_2d`` model kind via ``MRI_MODEL_KIND``, writes a
    dummy checkpoint to ``tmp_path / "best.pt"`` with ``build_dummy_2d``, and
    exposes that checkpoint to the API through ``MRI_MODEL_PATH_2D``. The
    environment changes are undone by ``monkeypatch`` after the test.
    """
    # Model kind is selected before the checkpoint is built, matching the
    # order the rest of the suite relies on.
    monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
    checkpoint_file = build_dummy_2d(tmp_path / "best.pt")
    monkeypatch.setenv("MRI_MODEL_PATH_2D", str(checkpoint_file))

    api_client = TestClient(app)
    return api_client
|
|
|
|
def test_predict_mri_2d_happy_path(client_2d, tmp_path):
    """POST a seeded random 170x170x3 PNG and validate the prediction payload.

    Expects HTTP 200, a label drawn from the four known classes, a confidence
    within [0, 1], and exactly four probability entries.
    """
    scan_file = tmp_path / "scan.png"
    seeded = np.random.RandomState(0)
    pixels = (seeded.rand(170, 170, 3) * 255).astype("uint8")
    Image.fromarray(pixels).save(str(scan_file))

    response = client_2d.post("/predict/mri", json={"input_path": str(scan_file)})
    assert response.status_code == 200, response.text

    payload = response.json()
    known_labels = {
        "MildDemented",
        "ModerateDemented",
        "NonDemented",
        "VeryMildDemented",
    }
    assert payload["label_text"] in known_labels
    assert 0.0 <= payload["confidence"] <= 1.0
    assert len(payload["probabilities"]) == 4
|
|
|
|
def test_predict_mri_2d_missing_artifact_returns_503(monkeypatch, tmp_path):
    """Point the 2D model at a non-existent checkpoint and expect HTTP 503.

    The error body is also expected to mention ``kind=resnet18_2d``.
    """
    absent_checkpoint = tmp_path / "missing.pt"
    monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
    monkeypatch.setenv("MRI_MODEL_PATH_2D", str(absent_checkpoint))

    scan_file = tmp_path / "scan.png"
    pixels = (np.random.RandomState(0).rand(170, 170, 3) * 255).astype("uint8")
    Image.fromarray(pixels).save(str(scan_file))

    # Client is created only after the environment points at the missing file.
    api_client = TestClient(app)
    response = api_client.post("/predict/mri", json={"input_path": str(scan_file)})

    assert response.status_code == 503
    assert "kind=resnet18_2d" in response.text
|
|