feat(frontend): expose MRI_MODEL_KIND in MRI predict tab; 2D upload path
Browse files- src/frontend/app.py +90 -50
src/frontend/app.py
CHANGED
|
@@ -1319,60 +1319,100 @@ def _render_mri_tab() -> None:
|
|
| 1319 |
st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}")
|
| 1320 |
|
| 1321 |
st.markdown("#### MRI Image Model")
|
| 1322 |
-
|
| 1323 |
-
"NIfTI image",
|
| 1324 |
-
"tests/fixtures/mri_sample/subject_0.nii.gz",
|
| 1325 |
-
key="mri_predict_image",
|
| 1326 |
-
)
|
| 1327 |
-
mri_labels = st.text_input(
|
| 1328 |
-
"Class labels",
|
| 1329 |
-
"control,abnormal",
|
| 1330 |
-
key="mri_predict_labels",
|
| 1331 |
-
)
|
| 1332 |
-
shape_cols = st.columns(3)
|
| 1333 |
-
target_d = shape_cols[0].number_input(
|
| 1334 |
-
"Resize D", min_value=1, max_value=256, value=64, step=1, key="mri_predict_d"
|
| 1335 |
-
)
|
| 1336 |
-
target_h = shape_cols[1].number_input(
|
| 1337 |
-
"Resize H", min_value=1, max_value=256, value=64, step=1, key="mri_predict_h"
|
| 1338 |
-
)
|
| 1339 |
-
target_w = shape_cols[2].number_input(
|
| 1340 |
-
"Resize W", min_value=1, max_value=256, value=64, step=1, key="mri_predict_w"
|
| 1341 |
-
)
|
| 1342 |
st.caption(
|
| 1343 |
-
"
|
| 1344 |
-
"
|
| 1345 |
)
|
| 1346 |
-
|
| 1347 |
-
|
| 1348 |
-
|
| 1349 |
-
"
|
| 1350 |
-
"
|
| 1351 |
-
|
| 1352 |
-
|
| 1353 |
-
|
| 1354 |
-
|
| 1355 |
-
|
| 1356 |
-
|
| 1357 |
-
|
| 1358 |
-
|
| 1359 |
-
|
| 1360 |
-
|
| 1361 |
-
|
| 1362 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1363 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1364 |
else:
|
| 1365 |
-
st.
|
| 1366 |
-
|
| 1367 |
-
|
| 1368 |
-
|
| 1369 |
-
|
| 1370 |
-
|
| 1371 |
-
|
| 1372 |
-
)
|
| 1373 |
-
probs = result.get("probabilities", [])
|
| 1374 |
-
if probs:
|
| 1375 |
-
st.dataframe(probs, use_container_width=True, hide_index=True)
|
| 1376 |
|
| 1377 |
|
| 1378 |
def _render_prediction_card(result: dict) -> None:
|
|
|
|
| 1319 |
st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}")
|
| 1320 |
|
| 1321 |
st.markdown("#### MRI Image Model")
|
| 1322 |
+
mri_kind = os.environ.get("MRI_MODEL_KIND", "volumetric_onnx")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1323 |
st.caption(
|
| 1324 |
+
f"Active backend: `{mri_kind}` — set `MRI_MODEL_KIND=resnet18_2d` "
|
| 1325 |
+
"to switch to the 2D 4-class Alzheimer's classifier."
|
| 1326 |
)
|
| 1327 |
+
|
| 1328 |
+
if mri_kind == "resnet18_2d":
|
| 1329 |
+
mri_image = st.text_input(
|
| 1330 |
+
"2D MRI image (.png/.jpg)",
|
| 1331 |
+
"tests/fixtures/mri_sample/subject_0_axial.png",
|
| 1332 |
+
key="mri_predict_image",
|
| 1333 |
+
)
|
| 1334 |
+
st.caption(
|
| 1335 |
+
"Resnet18 4-class — labels: MildDemented, ModerateDemented, "
|
| 1336 |
+
"NonDemented, VeryMildDemented. Resize/labels are baked into the model."
|
| 1337 |
+
)
|
| 1338 |
+
if st.button("Predict MRI image", key="mri_predict"):
|
| 1339 |
+
payload = {"input_path": mri_image}
|
| 1340 |
+
with st.spinner("Running 2D MRI model..."):
|
| 1341 |
+
try:
|
| 1342 |
+
result = _post("/predict/mri", payload, timeout=120.0)
|
| 1343 |
+
except httpx.HTTPStatusError as e:
|
| 1344 |
+
if e.response.status_code == 503:
|
| 1345 |
+
st.warning(
|
| 1346 |
+
"MRI 2D model artifact missing. Drop the trained checkpoint at "
|
| 1347 |
+
"`data/processed/mri_dl_2d/best_model.pt` or set `MRI_MODEL_PATH_2D`."
|
| 1348 |
+
)
|
| 1349 |
+
else:
|
| 1350 |
+
st.error(f"MRI prediction failed (HTTP {e.response.status_code}): {e.response.text}")
|
| 1351 |
+
except httpx.RequestError as e:
|
| 1352 |
+
st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}")
|
| 1353 |
+
else:
|
| 1354 |
+
st.metric(
|
| 1355 |
+
label=result.get("label_text", "prediction"),
|
| 1356 |
+
value=f"{float(result.get('confidence', 0.0)) * 100:.1f}%",
|
| 1357 |
)
|
| 1358 |
+
probs = result.get("probabilities", [])
|
| 1359 |
+
if probs:
|
| 1360 |
+
st.dataframe(probs, use_container_width=True, hide_index=True)
|
| 1361 |
+
else:
|
| 1362 |
+
mri_image = st.text_input(
|
| 1363 |
+
"NIfTI image",
|
| 1364 |
+
"tests/fixtures/mri_sample/subject_0.nii.gz",
|
| 1365 |
+
key="mri_predict_image",
|
| 1366 |
+
)
|
| 1367 |
+
mri_labels = st.text_input(
|
| 1368 |
+
"Class labels",
|
| 1369 |
+
"control,abnormal",
|
| 1370 |
+
key="mri_predict_labels",
|
| 1371 |
+
)
|
| 1372 |
+
shape_cols = st.columns(3)
|
| 1373 |
+
target_d = shape_cols[0].number_input(
|
| 1374 |
+
"Resize D", min_value=1, max_value=256, value=64, step=1, key="mri_predict_d"
|
| 1375 |
+
)
|
| 1376 |
+
target_h = shape_cols[1].number_input(
|
| 1377 |
+
"Resize H", min_value=1, max_value=256, value=64, step=1, key="mri_predict_h"
|
| 1378 |
+
)
|
| 1379 |
+
target_w = shape_cols[2].number_input(
|
| 1380 |
+
"Resize W", min_value=1, max_value=256, value=64, step=1, key="mri_predict_w"
|
| 1381 |
+
)
|
| 1382 |
+
st.caption(
|
| 1383 |
+
"Defaults to 64³ for production exports. Use 8³ when testing with the "
|
| 1384 |
+
"dummy ONNX fixture from `tests/fixtures/build_dummy_mri_onnx.py`."
|
| 1385 |
+
)
|
| 1386 |
+
if st.button("Predict MRI image", key="mri_predict"):
|
| 1387 |
+
labels = [x.strip() for x in mri_labels.split(",") if x.strip()]
|
| 1388 |
+
payload: dict = {
|
| 1389 |
+
"input_path": mri_image,
|
| 1390 |
+
"target_shape": [int(target_d), int(target_h), int(target_w)],
|
| 1391 |
+
}
|
| 1392 |
+
if labels:
|
| 1393 |
+
payload["label_names"] = labels
|
| 1394 |
+
with st.spinner("Running MRI image model..."):
|
| 1395 |
+
try:
|
| 1396 |
+
result = _post("/predict/mri", payload, timeout=120.0)
|
| 1397 |
+
except httpx.HTTPStatusError as e:
|
| 1398 |
+
detail = e.response.text
|
| 1399 |
+
if e.response.status_code == 503:
|
| 1400 |
+
st.warning(
|
| 1401 |
+
"MRI model artifact is not available yet. Export the trained "
|
| 1402 |
+
"ONNX model to `data/processed/mri_model.onnx` or set `MRI_MODEL_PATH`."
|
| 1403 |
+
)
|
| 1404 |
+
else:
|
| 1405 |
+
st.error(f"MRI prediction failed (HTTP {e.response.status_code}): {detail}")
|
| 1406 |
+
except httpx.RequestError as e:
|
| 1407 |
+
st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}")
|
| 1408 |
else:
|
| 1409 |
+
st.metric(
|
| 1410 |
+
label=result.get("label_text", "prediction"),
|
| 1411 |
+
value=f"{float(result.get('confidence', 0.0)) * 100:.1f}%",
|
| 1412 |
+
)
|
| 1413 |
+
probs = result.get("probabilities", [])
|
| 1414 |
+
if probs:
|
| 1415 |
+
st.dataframe(probs, use_container_width=True, hide_index=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1416 |
|
| 1417 |
|
| 1418 |
def _render_prediction_card(result: dict) -> None:
|