feat(models): calibration metadata — precision-at-threshold bins on _neurobridge_calibration
- train() does 80/20 stratified split, fits on train, computes 6
precision-at-confidence-threshold bins (0.50/0.60/0.70/0.75/0.80/0.90)
on the held-out test set, stashes the result on
model._neurobridge_calibration for the API to surface.
- Tiny-fixture fallback: ValueError on stratified split → train on full
data + emit zero-support bins.
- 3 new tests covering attribute presence, sorted thresholds, and
joblib roundtrip preservation.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- src/models/bbb_model.py +62 -2
- tests/models/test_bbb_model.py +34 -0
src/models/bbb_model.py
CHANGED
|
@@ -15,6 +15,7 @@ import joblib
|
|
| 15 |
import numpy as np
|
| 16 |
import pandas as pd
|
| 17 |
from sklearn.ensemble import RandomForestClassifier
|
|
|
|
| 18 |
|
| 19 |
from src.core.logger import get_logger
|
| 20 |
from src.pipelines.bbb_pipeline import (
|
|
@@ -45,6 +46,44 @@ def _split_features_and_label(
|
|
| 45 |
return X, y, fp_cols
|
| 46 |
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
def train(
|
| 49 |
df: pd.DataFrame,
|
| 50 |
label_col: str = "p_np",
|
|
@@ -65,20 +104,41 @@ def train(
|
|
| 65 |
downstream callers can map SHAP values back to fp_<bit> indices.
|
| 66 |
"""
|
| 67 |
X, y, fp_cols = _split_features_and_label(df, label_col)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
model = RandomForestClassifier(
|
| 69 |
n_estimators=n_estimators,
|
| 70 |
random_state=random_state,
|
| 71 |
n_jobs=1,
|
| 72 |
)
|
| 73 |
-
model.fit(
|
| 74 |
# Stash the column names under a project-owned attribute so SHAP (Task 2)
|
| 75 |
# can map values back to fp_<bit> indices. Sklearn's own feature_names_in_
|
| 76 |
# is only set automatically when fit receives a DataFrame; setting it
|
| 77 |
# manually fires UserWarning on every predict call.
|
| 78 |
model._neurobridge_fp_cols = list(fp_cols)
|
|
|
|
|
|
|
|
|
|
| 79 |
logger.info(
|
| 80 |
-
"Trained BBB classifier: n=%d, n_features=%d, classes=%s",
|
|
|
|
| 81 |
len(y), X.shape[1], model.classes_.tolist(),
|
|
|
|
| 82 |
)
|
| 83 |
return model
|
| 84 |
|
|
|
|
| 15 |
import numpy as np
|
| 16 |
import pandas as pd
|
| 17 |
from sklearn.ensemble import RandomForestClassifier
|
| 18 |
+
from sklearn.model_selection import train_test_split
|
| 19 |
|
| 20 |
from src.core.logger import get_logger
|
| 21 |
from src.pipelines.bbb_pipeline import (
|
|
|
|
| 46 |
return X, y, fp_cols
|
| 47 |
|
| 48 |
|
| 49 |
+
_CALIBRATION_THRESHOLDS: tuple[float, ...] = (0.50, 0.60, 0.70, 0.75, 0.80, 0.90)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _compute_calibration_bins(
|
| 53 |
+
model: RandomForestClassifier,
|
| 54 |
+
X_test: np.ndarray,
|
| 55 |
+
y_test: np.ndarray,
|
| 56 |
+
) -> list[dict[str, float]]:
|
| 57 |
+
"""Compute precision-at-confidence-threshold bins on a held-out test set.
|
| 58 |
+
|
| 59 |
+
For each threshold T in `_CALIBRATION_THRESHOLDS`, picks the predictions
|
| 60 |
+
whose max class probability >= T, computes precision and support, and
|
| 61 |
+
returns one bin per threshold. Bins with zero support are still emitted
|
| 62 |
+
(precision = 0.0, support = 0) so the API can always find a match.
|
| 63 |
+
"""
|
| 64 |
+
if len(y_test) == 0:
|
| 65 |
+
return [
|
| 66 |
+
{"threshold": float(t), "precision": 0.0, "support": 0}
|
| 67 |
+
for t in _CALIBRATION_THRESHOLDS
|
| 68 |
+
]
|
| 69 |
+
proba = model.predict_proba(X_test)
|
| 70 |
+
pred = model.predict(X_test)
|
| 71 |
+
confidence = proba.max(axis=1)
|
| 72 |
+
correct = (pred == y_test).astype(int)
|
| 73 |
+
bins: list[dict[str, float]] = []
|
| 74 |
+
for t in _CALIBRATION_THRESHOLDS:
|
| 75 |
+
mask = confidence >= t
|
| 76 |
+
support = int(mask.sum())
|
| 77 |
+
if support == 0:
|
| 78 |
+
precision = 0.0
|
| 79 |
+
else:
|
| 80 |
+
precision = float(correct[mask].mean())
|
| 81 |
+
bins.append({
|
| 82 |
+
"threshold": float(t), "precision": precision, "support": support,
|
| 83 |
+
})
|
| 84 |
+
return bins
|
| 85 |
+
|
| 86 |
+
|
| 87 |
def train(
|
| 88 |
df: pd.DataFrame,
|
| 89 |
label_col: str = "p_np",
|
|
|
|
| 104 |
downstream callers can map SHAP values back to fp_<bit> indices.
|
| 105 |
"""
|
| 106 |
X, y, fp_cols = _split_features_and_label(df, label_col)
|
| 107 |
+
# Stratified 80/20 split for honest calibration metrics. Falls back to
|
| 108 |
+
# train-on-all if the dataset is too tiny for a stratified split (test
|
| 109 |
+
# fixtures with 3-4 rows hit this branch).
|
| 110 |
+
try:
|
| 111 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
| 112 |
+
X, y, test_size=0.2, random_state=random_state, stratify=y,
|
| 113 |
+
)
|
| 114 |
+
except ValueError as e:
|
| 115 |
+
logger.warning(
|
| 116 |
+
"Stratified split failed (%s); training on full data; "
|
| 117 |
+
"calibration bins will be zero-support.",
|
| 118 |
+
e,
|
| 119 |
+
)
|
| 120 |
+
X_train, X_test = X, np.empty((0, X.shape[1]))
|
| 121 |
+
y_train, y_test = y, np.empty((0,))
|
| 122 |
+
|
| 123 |
model = RandomForestClassifier(
|
| 124 |
n_estimators=n_estimators,
|
| 125 |
random_state=random_state,
|
| 126 |
n_jobs=1,
|
| 127 |
)
|
| 128 |
+
model.fit(X_train, y_train)
|
| 129 |
# Stash the column names under a project-owned attribute so SHAP (Task 2)
|
| 130 |
# can map values back to fp_<bit> indices. Sklearn's own feature_names_in_
|
| 131 |
# is only set automatically when fit receives a DataFrame; setting it
|
| 132 |
# manually fires UserWarning on every predict call.
|
| 133 |
model._neurobridge_fp_cols = list(fp_cols)
|
| 134 |
+
model._neurobridge_calibration = _compute_calibration_bins(
|
| 135 |
+
model, X_test, y_test,
|
| 136 |
+
)
|
| 137 |
logger.info(
|
| 138 |
+
"Trained BBB classifier: n=%d, n_features=%d, classes=%s, "
|
| 139 |
+
"calibration_bins=%d",
|
| 140 |
len(y), X.shape[1], model.classes_.tolist(),
|
| 141 |
+
len(model._neurobridge_calibration),
|
| 142 |
)
|
| 143 |
return model
|
| 144 |
|
tests/models/test_bbb_model.py
CHANGED
|
@@ -127,3 +127,37 @@ class TestExplainPrediction:
|
|
| 127 |
r1 = bbb_model.explain_prediction(model, "CCO", top_k=5)
|
| 128 |
r2 = bbb_model.explain_prediction(model, "CCO", top_k=5)
|
| 129 |
assert r1 == r2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
r1 = bbb_model.explain_prediction(model, "CCO", top_k=5)
|
| 128 |
r2 = bbb_model.explain_prediction(model, "CCO", top_k=5)
|
| 129 |
assert r1 == r2
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class TestCalibrationMetadata:
|
| 133 |
+
def test_train_attaches_calibration_attribute(self, trained_model_and_features):
|
| 134 |
+
model, _ = trained_model_and_features
|
| 135 |
+
assert hasattr(model, "_neurobridge_calibration")
|
| 136 |
+
bins = model._neurobridge_calibration
|
| 137 |
+
assert isinstance(bins, list)
|
| 138 |
+
# Always at least one bin (the lowest-threshold one)
|
| 139 |
+
assert len(bins) >= 1
|
| 140 |
+
for b in bins:
|
| 141 |
+
assert "threshold" in b
|
| 142 |
+
assert "precision" in b
|
| 143 |
+
assert "support" in b
|
| 144 |
+
assert 0.0 <= b["threshold"] <= 1.0
|
| 145 |
+
assert 0.0 <= b["precision"] <= 1.0
|
| 146 |
+
assert b["support"] >= 0
|
| 147 |
+
|
| 148 |
+
def test_calibration_thresholds_are_sorted_ascending(
|
| 149 |
+
self, trained_model_and_features,
|
| 150 |
+
):
|
| 151 |
+
model, _ = trained_model_and_features
|
| 152 |
+
thresholds = [b["threshold"] for b in model._neurobridge_calibration]
|
| 153 |
+
assert thresholds == sorted(thresholds)
|
| 154 |
+
|
| 155 |
+
def test_calibration_survives_save_load_roundtrip(
|
| 156 |
+
self, trained_model_and_features, tmp_path: Path,
|
| 157 |
+
):
|
| 158 |
+
model, _ = trained_model_and_features
|
| 159 |
+
artifact = tmp_path / "calibrated.joblib"
|
| 160 |
+
bbb_model.save(model, artifact)
|
| 161 |
+
reloaded = bbb_model.load(artifact)
|
| 162 |
+
assert hasattr(reloaded, "_neurobridge_calibration")
|
| 163 |
+
assert reloaded._neurobridge_calibration == model._neurobridge_calibration
|