feat(models): train-time confidence stats stashed on _neurobridge_train_stats
Browse files
- _compute_train_stats() captures the median and std of the model's own
top-class predict_proba confidence on X_train, plus n_train. Joblib-roundtrip-safe.
- train() persists stats alongside _neurobridge_fp_cols and
_neurobridge_calibration. INFO log line now surfaces the median.
- Foundation for Day-7 T1B drift z-score in /predict/bbb.
- 2 new tests (TestTrainStatsMetadata): attribute presence + roundtrip.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- src/models/bbb_model.py +23 -1
- tests/models/test_bbb_model.py +26 -0
src/models/bbb_model.py
CHANGED
|
@@ -84,6 +84,26 @@ def _compute_calibration_bins(
|
|
| 84 |
return bins
|
| 85 |
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
def train(
|
| 88 |
df: pd.DataFrame,
|
| 89 |
label_col: str = "p_np",
|
|
@@ -134,11 +154,13 @@ def train(
|
|
| 134 |
model._neurobridge_calibration = _compute_calibration_bins(
|
| 135 |
model, X_test, y_test,
|
| 136 |
)
|
|
|
|
| 137 |
logger.info(
|
| 138 |
"Trained BBB classifier: n=%d, n_features=%d, classes=%s, "
|
| 139 |
-
"calibration_bins=%d",
|
| 140 |
len(y), X.shape[1], model.classes_.tolist(),
|
| 141 |
len(model._neurobridge_calibration),
|
|
|
|
| 142 |
)
|
| 143 |
return model
|
| 144 |
|
|
|
|
| 84 |
return bins
|
| 85 |
|
| 86 |
|
| 87 |
+
def _compute_train_stats(
    model: "RandomForestClassifier",
    X_train: np.ndarray,
) -> dict[str, float]:
    """Summarize the model's own confidence on the training set.

    Confidence is the highest class probability that ``model.predict_proba``
    assigns to each row. The summary is used as the reference distribution
    for runtime drift detection.

    Args:
        model: Fitted classifier exposing ``predict_proba``.
        X_train: Feature matrix to score, one row per sample.

    Returns:
        A dict of plain Python scalars, so it is joblib-roundtrip-safe and
        JSON-serializable, with keys:

        * ``"median"``: median per-row confidence (float).
        * ``"std"``: population standard deviation of the confidence (float).
        * ``"n_train"``: number of rows in ``X_train`` (int, not float).

        An empty ``X_train`` yields zeros for all three keys without calling
        the model.
    """
    n_train = len(X_train)
    if n_train == 0:
        # Nothing to score: return a zeroed summary instead of calling the model.
        return {"median": 0.0, "std": 0.0, "n_train": 0}

    # Coerce to an ndarray so the row-wise max works even if predict_proba
    # hands back a nested sequence rather than an array.
    proba = np.asarray(model.predict_proba(X_train), dtype=float)
    confidence = proba.max(axis=1)

    # Cast NumPy scalars to builtin types so the dict compares equal after a
    # save/load roundtrip and can be dumped to JSON.
    return {
        "median": float(np.median(confidence)),
        "std": float(np.std(confidence)),
        "n_train": int(n_train),
    }
|
| 105 |
+
|
| 106 |
+
|
| 107 |
def train(
|
| 108 |
df: pd.DataFrame,
|
| 109 |
label_col: str = "p_np",
|
|
|
|
| 154 |
model._neurobridge_calibration = _compute_calibration_bins(
|
| 155 |
model, X_test, y_test,
|
| 156 |
)
|
| 157 |
+
model._neurobridge_train_stats = _compute_train_stats(model, X_train)
|
| 158 |
logger.info(
|
| 159 |
"Trained BBB classifier: n=%d, n_features=%d, classes=%s, "
|
| 160 |
+
"calibration_bins=%d, train_confidence_median=%.3f",
|
| 161 |
len(y), X.shape[1], model.classes_.tolist(),
|
| 162 |
len(model._neurobridge_calibration),
|
| 163 |
+
model._neurobridge_train_stats["median"],
|
| 164 |
)
|
| 165 |
return model
|
| 166 |
|
tests/models/test_bbb_model.py
CHANGED
|
@@ -161,3 +161,29 @@ class TestCalibrationMetadata:
|
|
| 161 |
reloaded = bbb_model.load(artifact)
|
| 162 |
assert hasattr(reloaded, "_neurobridge_calibration")
|
| 163 |
assert reloaded._neurobridge_calibration == model._neurobridge_calibration
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
reloaded = bbb_model.load(artifact)
|
| 162 |
assert hasattr(reloaded, "_neurobridge_calibration")
|
| 163 |
assert reloaded._neurobridge_calibration == model._neurobridge_calibration
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class TestTrainStatsMetadata:
    """Day 7 — T1A: the confidence summary that train() stashes on the model."""

    def test_train_attaches_train_stats_attribute(self, trained_model_and_features):
        """A trained model carries a well-formed ``_neurobridge_train_stats`` dict."""
        fitted, _features = trained_model_and_features
        assert hasattr(fitted, "_neurobridge_train_stats")

        stats = fitted._neurobridge_train_stats
        assert isinstance(stats, dict)

        # Check keys one at a time so a failure names the missing key.
        required_keys = ("median", "std", "n_train")
        for key in required_keys:
            assert key in stats, f"missing key {key!r} in train stats"

        # Range checks: median within [0, 1], non-negative spread, at least one row.
        median = stats["median"]
        assert 0.0 <= median <= 1.0
        spread = stats["std"]
        assert spread >= 0.0
        row_count = stats["n_train"]
        assert row_count >= 1

    def test_train_stats_survives_save_load_roundtrip(
        self, trained_model_and_features, tmp_path: Path,
    ):
        """The stats dict is present and unchanged after save() followed by load()."""
        from src.models import bbb_model

        fitted, _features = trained_model_and_features
        artifact_path = tmp_path / "m.joblib"

        bbb_model.save(fitted, artifact_path)
        restored = bbb_model.load(artifact_path)

        assert hasattr(restored, "_neurobridge_train_stats")
        expected = fitted._neurobridge_train_stats
        assert restored._neurobridge_train_stats == expected
|