mekosotto committed on
Commit
e2239d7
·
1 Parent(s): 29faa9d

feat(frontend): expose MRI_MODEL_KIND in MRI predict tab; 2D upload path

Browse files
Files changed (1) hide show
  1. src/frontend/app.py +90 -50
src/frontend/app.py CHANGED
@@ -1319,60 +1319,100 @@ def _render_mri_tab() -> None:
1319
  st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}")
1320
 
1321
  st.markdown("#### MRI Image Model")
1322
- mri_image = st.text_input(
1323
- "NIfTI image",
1324
- "tests/fixtures/mri_sample/subject_0.nii.gz",
1325
- key="mri_predict_image",
1326
- )
1327
- mri_labels = st.text_input(
1328
- "Class labels",
1329
- "control,abnormal",
1330
- key="mri_predict_labels",
1331
- )
1332
- shape_cols = st.columns(3)
1333
- target_d = shape_cols[0].number_input(
1334
- "Resize D", min_value=1, max_value=256, value=64, step=1, key="mri_predict_d"
1335
- )
1336
- target_h = shape_cols[1].number_input(
1337
- "Resize H", min_value=1, max_value=256, value=64, step=1, key="mri_predict_h"
1338
- )
1339
- target_w = shape_cols[2].number_input(
1340
- "Resize W", min_value=1, max_value=256, value=64, step=1, key="mri_predict_w"
1341
- )
1342
  st.caption(
1343
- "Defaults to 64³ for production exports. Use 8³ when testing with the "
1344
- "dummy ONNX fixture from `tests/fixtures/build_dummy_mri_onnx.py`."
1345
  )
1346
- if st.button("Predict MRI image", key="mri_predict"):
1347
- labels = [x.strip() for x in mri_labels.split(",") if x.strip()]
1348
- payload: dict = {
1349
- "input_path": mri_image,
1350
- "target_shape": [int(target_d), int(target_h), int(target_w)],
1351
- }
1352
- if labels:
1353
- payload["label_names"] = labels
1354
- with st.spinner("Running MRI image model..."):
1355
- try:
1356
- result = _post("/predict/mri", payload, timeout=120.0)
1357
- except httpx.HTTPStatusError as e:
1358
- detail = e.response.text
1359
- if e.response.status_code == 503:
1360
- st.warning(
1361
- "MRI model artifact is not available yet. Export the trained "
1362
- "ONNX model to `data/processed/mri_model.onnx` or set `MRI_MODEL_PATH`."
 
 
 
 
 
 
 
 
 
 
 
 
 
1363
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1364
  else:
1365
- st.error(f"MRI prediction failed (HTTP {e.response.status_code}): {detail}")
1366
- except httpx.RequestError as e:
1367
- st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}")
1368
- else:
1369
- st.metric(
1370
- label=result.get("label_text", "prediction"),
1371
- value=f"{float(result.get('confidence', 0.0)) * 100:.1f}%",
1372
- )
1373
- probs = result.get("probabilities", [])
1374
- if probs:
1375
- st.dataframe(probs, use_container_width=True, hide_index=True)
1376
 
1377
 
1378
  def _render_prediction_card(result: dict) -> None:
 
1319
  st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}")
1320
 
1321
  st.markdown("#### MRI Image Model")
1322
+ mri_kind = os.environ.get("MRI_MODEL_KIND", "volumetric_onnx")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1323
  st.caption(
1324
+ f"Active backend: `{mri_kind}` set `MRI_MODEL_KIND=resnet18_2d` "
1325
+ "to switch to the 2D 4-class Alzheimer's classifier."
1326
  )
1327
+
1328
+ if mri_kind == "resnet18_2d":
1329
+ mri_image = st.text_input(
1330
+ "2D MRI image (.png/.jpg)",
1331
+ "tests/fixtures/mri_sample/subject_0_axial.png",
1332
+ key="mri_predict_image",
1333
+ )
1334
+ st.caption(
1335
+ "Resnet18 4-class labels: MildDemented, ModerateDemented, "
1336
+ "NonDemented, VeryMildDemented. Resize/labels are baked into the model."
1337
+ )
1338
+ if st.button("Predict MRI image", key="mri_predict"):
1339
+ payload = {"input_path": mri_image}
1340
+ with st.spinner("Running 2D MRI model..."):
1341
+ try:
1342
+ result = _post("/predict/mri", payload, timeout=120.0)
1343
+ except httpx.HTTPStatusError as e:
1344
+ if e.response.status_code == 503:
1345
+ st.warning(
1346
+ "MRI 2D model artifact missing. Drop the trained checkpoint at "
1347
+ "`data/processed/mri_dl_2d/best_model.pt` or set `MRI_MODEL_PATH_2D`."
1348
+ )
1349
+ else:
1350
+ st.error(f"MRI prediction failed (HTTP {e.response.status_code}): {e.response.text}")
1351
+ except httpx.RequestError as e:
1352
+ st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}")
1353
+ else:
1354
+ st.metric(
1355
+ label=result.get("label_text", "prediction"),
1356
+ value=f"{float(result.get('confidence', 0.0)) * 100:.1f}%",
1357
  )
1358
+ probs = result.get("probabilities", [])
1359
+ if probs:
1360
+ st.dataframe(probs, use_container_width=True, hide_index=True)
1361
+ else:
1362
+ mri_image = st.text_input(
1363
+ "NIfTI image",
1364
+ "tests/fixtures/mri_sample/subject_0.nii.gz",
1365
+ key="mri_predict_image",
1366
+ )
1367
+ mri_labels = st.text_input(
1368
+ "Class labels",
1369
+ "control,abnormal",
1370
+ key="mri_predict_labels",
1371
+ )
1372
+ shape_cols = st.columns(3)
1373
+ target_d = shape_cols[0].number_input(
1374
+ "Resize D", min_value=1, max_value=256, value=64, step=1, key="mri_predict_d"
1375
+ )
1376
+ target_h = shape_cols[1].number_input(
1377
+ "Resize H", min_value=1, max_value=256, value=64, step=1, key="mri_predict_h"
1378
+ )
1379
+ target_w = shape_cols[2].number_input(
1380
+ "Resize W", min_value=1, max_value=256, value=64, step=1, key="mri_predict_w"
1381
+ )
1382
+ st.caption(
1383
+ "Defaults to 64³ for production exports. Use 8³ when testing with the "
1384
+ "dummy ONNX fixture from `tests/fixtures/build_dummy_mri_onnx.py`."
1385
+ )
1386
+ if st.button("Predict MRI image", key="mri_predict"):
1387
+ labels = [x.strip() for x in mri_labels.split(",") if x.strip()]
1388
+ payload: dict = {
1389
+ "input_path": mri_image,
1390
+ "target_shape": [int(target_d), int(target_h), int(target_w)],
1391
+ }
1392
+ if labels:
1393
+ payload["label_names"] = labels
1394
+ with st.spinner("Running MRI image model..."):
1395
+ try:
1396
+ result = _post("/predict/mri", payload, timeout=120.0)
1397
+ except httpx.HTTPStatusError as e:
1398
+ detail = e.response.text
1399
+ if e.response.status_code == 503:
1400
+ st.warning(
1401
+ "MRI model artifact is not available yet. Export the trained "
1402
+ "ONNX model to `data/processed/mri_model.onnx` or set `MRI_MODEL_PATH`."
1403
+ )
1404
+ else:
1405
+ st.error(f"MRI prediction failed (HTTP {e.response.status_code}): {detail}")
1406
+ except httpx.RequestError as e:
1407
+ st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}")
1408
  else:
1409
+ st.metric(
1410
+ label=result.get("label_text", "prediction"),
1411
+ value=f"{float(result.get('confidence', 0.0)) * 100:.1f}%",
1412
+ )
1413
+ probs = result.get("probabilities", [])
1414
+ if probs:
1415
+ st.dataframe(probs, use_container_width=True, hide_index=True)
 
 
 
 
1416
 
1417
 
1418
  def _render_prediction_card(result: dict) -> None: