mekosotto commited on
Commit
9ae5b40
·
1 Parent(s): c0a7163

fix(mri/model): warn when label_names length != model output dim (was silent override)

Browse files
src/models/mri_model.py CHANGED
@@ -95,6 +95,13 @@ def predict_with_proba(
95
  output = model.run(None, {input_name: model_input.astype(np.float32, copy=False)})[0]
96
  proba = _as_probabilities(np.asarray(output, dtype=np.float32))
97
  if len(labels) != proba.shape[0]:
 
 
 
 
 
 
 
98
  labels = tuple(f"class_{i}" for i in range(proba.shape[0]))
99
 
100
  label_idx = int(np.argmax(proba))
 
95
  output = model.run(None, {input_name: model_input.astype(np.float32, copy=False)})[0]
96
  proba = _as_probabilities(np.asarray(output, dtype=np.float32))
97
  if len(labels) != proba.shape[0]:
98
+ logger.warning(
99
+ "label_names length (%d) does not match model output dim (%d); "
100
+ "overriding with class_0..class_N. Provided labels: %r",
101
+ len(labels),
102
+ proba.shape[0],
103
+ list(labels),
104
+ )
105
  labels = tuple(f"class_{i}" for i in range(proba.shape[0]))
106
 
107
  label_idx = int(np.argmax(proba))
tests/models/test_mri_model.py CHANGED
@@ -52,3 +52,31 @@ class TestMRIDLModel:
52
  probs = result["probabilities"]
53
  assert len(probs) == 2
54
  assert sum(p["probability"] for p in probs) == pytest.approx(1.0, abs=1e-6)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  probs = result["probabilities"]
53
  assert len(probs) == 2
54
  assert sum(p["probability"] for p in probs) == pytest.approx(1.0, abs=1e-6)
55
+
56
+ def test_predict_warns_on_label_count_mismatch(
57
+ self, tmp_path: Path, caplog: pytest.LogCaptureFixture
58
+ ) -> None:
59
+ import logging
60
+
61
+ artifact = build_dummy_mri_onnx(tmp_path / "mri_model.onnx")
62
+ model = mri_model.load(artifact)
63
+
64
+ # mri_model.logger has propagate=False (src/core/logger.py), so pytest's
65
+ # caplog root handler never sees its records. Attach caplog.handler directly.
66
+ mri_model.logger.addHandler(caplog.handler)
67
+ try:
68
+ with caplog.at_level(logging.WARNING, logger="src.models.mri_model"):
69
+ result = mri_model.predict_nifti(
70
+ model,
71
+ _FIXTURE_MRI,
72
+ target_shape=(8, 8, 8),
73
+ label_names=("control", "abnormal", "extra"),
74
+ )
75
+ finally:
76
+ mri_model.logger.removeHandler(caplog.handler)
77
+
78
+ assert result["label_text"] in {"class_0", "class_1"}
79
+ assert any(
80
+ "label_names length" in rec.message and "overriding" in rec.message
81
+ for rec in caplog.records
82
+ ), [rec.message for rec in caplog.records]