fix(mri/model): warn when label_names length != model output dim (was silent override)
Browse files- src/models/mri_model.py +7 -0
- tests/models/test_mri_model.py +28 -0
src/models/mri_model.py
CHANGED
|
@@ -95,6 +95,13 @@ def predict_with_proba(
|
|
| 95 |
output = model.run(None, {input_name: model_input.astype(np.float32, copy=False)})[0]
|
| 96 |
proba = _as_probabilities(np.asarray(output, dtype=np.float32))
|
| 97 |
if len(labels) != proba.shape[0]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
labels = tuple(f"class_{i}" for i in range(proba.shape[0]))
|
| 99 |
|
| 100 |
label_idx = int(np.argmax(proba))
|
|
|
|
| 95 |
output = model.run(None, {input_name: model_input.astype(np.float32, copy=False)})[0]
|
| 96 |
proba = _as_probabilities(np.asarray(output, dtype=np.float32))
|
| 97 |
if len(labels) != proba.shape[0]:
|
| 98 |
+
logger.warning(
|
| 99 |
+
"label_names length (%d) does not match model output dim (%d); "
|
| 100 |
+
"overriding with class_0..class_N. Provided labels: %r",
|
| 101 |
+
len(labels),
|
| 102 |
+
proba.shape[0],
|
| 103 |
+
list(labels),
|
| 104 |
+
)
|
| 105 |
labels = tuple(f"class_{i}" for i in range(proba.shape[0]))
|
| 106 |
|
| 107 |
label_idx = int(np.argmax(proba))
|
tests/models/test_mri_model.py
CHANGED
|
@@ -52,3 +52,31 @@ class TestMRIDLModel:
|
|
| 52 |
probs = result["probabilities"]
|
| 53 |
assert len(probs) == 2
|
| 54 |
assert sum(p["probability"] for p in probs) == pytest.approx(1.0, abs=1e-6)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
probs = result["probabilities"]
|
| 53 |
assert len(probs) == 2
|
| 54 |
assert sum(p["probability"] for p in probs) == pytest.approx(1.0, abs=1e-6)
|
| 55 |
+
|
| 56 |
+
def test_predict_warns_on_label_count_mismatch(
|
| 57 |
+
self, tmp_path: Path, caplog: pytest.LogCaptureFixture
|
| 58 |
+
) -> None:
|
| 59 |
+
import logging
|
| 60 |
+
|
| 61 |
+
artifact = build_dummy_mri_onnx(tmp_path / "mri_model.onnx")
|
| 62 |
+
model = mri_model.load(artifact)
|
| 63 |
+
|
| 64 |
+
# mri_model.logger has propagate=False (src/core/logger.py), so pytest's
|
| 65 |
+
# caplog root handler never sees its records. Attach caplog.handler directly.
|
| 66 |
+
mri_model.logger.addHandler(caplog.handler)
|
| 67 |
+
try:
|
| 68 |
+
with caplog.at_level(logging.WARNING, logger="src.models.mri_model"):
|
| 69 |
+
result = mri_model.predict_nifti(
|
| 70 |
+
model,
|
| 71 |
+
_FIXTURE_MRI,
|
| 72 |
+
target_shape=(8, 8, 8),
|
| 73 |
+
label_names=("control", "abnormal", "extra"),
|
| 74 |
+
)
|
| 75 |
+
finally:
|
| 76 |
+
mri_model.logger.removeHandler(caplog.handler)
|
| 77 |
+
|
| 78 |
+
assert result["label_text"] in {"class_0", "class_1"}
|
| 79 |
+
assert any(
|
| 80 |
+
"label_names length" in rec.message and "overriding" in rec.message
|
| 81 |
+
for rec in caplog.records
|
| 82 |
+
), [rec.message for rec in caplog.records]
|