mekosotto Claude Opus 4.7 (1M context) committed on
Commit
efb8713
·
1 Parent(s): c4a01f0

feat(models): train-time confidence stats stashed on _neurobridge_train_stats

Browse files

- _compute_train_stats() captures median, std, n_train of the model's
own predict_proba on X_train. Joblib-roundtrip-safe.
- train() persists stats alongside _neurobridge_fp_cols and
_neurobridge_calibration. INFO log line now surfaces the median.
- Foundation for Day-7 T1B drift z-score in /predict/bbb.
- 2 new tests (TestTrainStatsMetadata): attribute presence + roundtrip.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

src/models/bbb_model.py CHANGED
@@ -84,6 +84,26 @@ def _compute_calibration_bins(
84
  return bins
85
 
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def train(
88
  df: pd.DataFrame,
89
  label_col: str = "p_np",
@@ -134,11 +154,13 @@ def train(
134
  model._neurobridge_calibration = _compute_calibration_bins(
135
  model, X_test, y_test,
136
  )
 
137
  logger.info(
138
  "Trained BBB classifier: n=%d, n_features=%d, classes=%s, "
139
- "calibration_bins=%d",
140
  len(y), X.shape[1], model.classes_.tolist(),
141
  len(model._neurobridge_calibration),
 
142
  )
143
  return model
144
 
 
84
  return bins
85
 
86
 
87
+ def _compute_train_stats(
88
+ model: RandomForestClassifier,
89
+ X_train: np.ndarray,
90
+ ) -> dict[str, float]:
91
+ """Compute median + std of the model's own confidence on the training set.
92
+
93
+ Used as the reference distribution for runtime drift detection. All values
94
+ are floats so the dict is joblib-roundtrip-safe and JSON-serializable.
95
+ """
96
+ if len(X_train) == 0:
97
+ return {"median": 0.0, "std": 0.0, "n_train": 0}
98
+ proba = model.predict_proba(X_train)
99
+ confidence = proba.max(axis=1)
100
+ return {
101
+ "median": float(np.median(confidence)),
102
+ "std": float(np.std(confidence)),
103
+ "n_train": int(len(X_train)),
104
+ }
105
+
106
+
107
  def train(
108
  df: pd.DataFrame,
109
  label_col: str = "p_np",
 
154
  model._neurobridge_calibration = _compute_calibration_bins(
155
  model, X_test, y_test,
156
  )
157
+ model._neurobridge_train_stats = _compute_train_stats(model, X_train)
158
  logger.info(
159
  "Trained BBB classifier: n=%d, n_features=%d, classes=%s, "
160
+ "calibration_bins=%d, train_confidence_median=%.3f",
161
  len(y), X.shape[1], model.classes_.tolist(),
162
  len(model._neurobridge_calibration),
163
+ model._neurobridge_train_stats["median"],
164
  )
165
  return model
166
 
tests/models/test_bbb_model.py CHANGED
@@ -161,3 +161,29 @@ class TestCalibrationMetadata:
161
  reloaded = bbb_model.load(artifact)
162
  assert hasattr(reloaded, "_neurobridge_calibration")
163
  assert reloaded._neurobridge_calibration == model._neurobridge_calibration
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  reloaded = bbb_model.load(artifact)
162
  assert hasattr(reloaded, "_neurobridge_calibration")
163
  assert reloaded._neurobridge_calibration == model._neurobridge_calibration
164
+
165
+
166
class TestTrainStatsMetadata:
    """Day 7 — T1A: train()-time confidence distribution stash."""

    def test_train_attaches_train_stats_attribute(self, trained_model_and_features):
        model, _ = trained_model_and_features
        assert hasattr(model, "_neurobridge_train_stats")
        stats = model._neurobridge_train_stats
        assert isinstance(stats, dict)
        required = ("median", "std", "n_train")
        for key in required:
            assert key in stats, f"missing key {key!r} in train stats"
        # Confidence is a probability, so its median must lie in [0, 1].
        assert stats["median"] >= 0.0
        assert stats["median"] <= 1.0
        assert stats["std"] >= 0.0
        assert stats["n_train"] >= 1

    def test_train_stats_survives_save_load_roundtrip(
        self, trained_model_and_features, tmp_path: Path,
    ):
        from src.models import bbb_model
        model, _ = trained_model_and_features
        artifact = tmp_path / "m.joblib"
        bbb_model.save(model, artifact)
        reloaded = bbb_model.load(artifact)
        assert hasattr(reloaded, "_neurobridge_train_stats")
        assert reloaded._neurobridge_train_stats == model._neurobridge_train_stats