mekosotto Claude Opus 4.7 (1M context) committed on
Commit
c26a55c
·
1 Parent(s): efb8713

feat(api): drift z-score in /predict/bbb response

Browse files

- WORKER_CONFIDENCE_DEQUE: collections.deque(maxlen=100), per-worker
rolling window of confidences; drift_z computed against train-time
median when ≥10 samples buffered AND model has _neurobridge_train_stats.
- BBBPredictResponse gains drift_z (float | None) and rolling_n (int).
- 2 new tests: drift_z/rolling_n always present in body; deque rolls
at 100 after 105 predictions.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Files changed (3) hide show
  1. src/api/routes.py +33 -0
  2. src/api/schemas.py +16 -1
  3. tests/api/test_routes.py +41 -1
src/api/routes.py CHANGED
@@ -9,6 +9,7 @@ from __future__ import annotations
9
 
10
  import os
11
  import time
 
12
  from pathlib import Path
13
  from typing import Callable
14
 
@@ -130,6 +131,35 @@ def _bbb_model_path() -> Path:
130
  return Path(os.environ.get("BBB_MODEL_PATH", str(_DEFAULT_BBB_MODEL_PATH)))
131
 
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  def _matching_calibration_bin(model, confidence: float) -> CalibrationContext | None:
134
  """Pick the highest-threshold bin whose threshold <= confidence. None if no match or no metadata."""
135
  bins = getattr(model, "_neurobridge_calibration", None)
@@ -180,12 +210,15 @@ def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
180
 
181
  label_text = "permeable" if pred["label"] == 1 else "non-permeable"
182
  calibration = _matching_calibration_bin(model, pred["confidence"])
 
183
  return BBBPredictResponse(
184
  label=pred["label"],
185
  label_text=label_text,
186
  confidence=pred["confidence"],
187
  top_features=[FeatureAttribution(**a) for a in attributions],
188
  calibration=calibration,
 
 
189
  )
190
 
191
 
 
9
 
10
  import os
11
  import time
12
+ from collections import deque
13
  from pathlib import Path
14
  from typing import Callable
15
 
 
131
  return Path(os.environ.get("BBB_MODEL_PATH", str(_DEFAULT_BBB_MODEL_PATH)))
132
 
133
 
134
# Per-worker rolling window of recent prediction confidences.
# Cleared on worker restart; multi-worker setups have independent windows.
WORKER_CONFIDENCE_DEQUE: deque[float] = deque(maxlen=100)
# Minimum number of buffered confidences before a drift z-score is reported.
_DRIFT_MIN_SAMPLES = 10


def _compute_drift_z(model, confidence: float) -> tuple[float | None, int]:
    """Record `confidence` in the worker window and compute the drift z-score.

    The confidence is always appended to `WORKER_CONFIDENCE_DEQUE`, even when
    no z-score can be produced, so `rolling_n` reflects every call.

        z = (rolling_median - train_median) / max(train_std, 1e-9)

    Args:
        model: Any object; train-time statistics are read from its optional
            `_neurobridge_train_stats` attribute, indexed by the keys
            "median" and "std".
        confidence: Confidence of the prediction that was just made.

    Returns:
        `(drift_z, rolling_n)`, where `rolling_n` is the window size after the
        append. `drift_z` is None when any of the following holds:
          * fewer than `_DRIFT_MIN_SAMPLES` samples are buffered;
          * the model has no `_neurobridge_train_stats`;
          * the stats are malformed (missing key or non-numeric value);
          * the resulting z-score is not finite.
    """
    import math
    import statistics

    WORKER_CONFIDENCE_DEQUE.append(float(confidence))
    rolling_n = len(WORKER_CONFIDENCE_DEQUE)

    stats = getattr(model, "_neurobridge_train_stats", None)
    if rolling_n < _DRIFT_MIN_SAMPLES or stats is None:
        return None, rolling_n

    # Drift is advisory: malformed train stats must not fail the prediction.
    try:
        train_median = float(stats["median"])
        # Floor the spread so a degenerate (zero) std cannot divide by zero.
        train_std = max(float(stats["std"]), 1e-9)
    except (KeyError, TypeError, ValueError):
        return None, rolling_n

    rolling_median = statistics.median(WORKER_CONFIDENCE_DEQUE)
    drift_z = (rolling_median - train_median) / train_std

    # NaN/inf (e.g. from a NaN train std, which survives the max() floor) is
    # not representable in strict JSON; report "unavailable" instead.
    if not math.isfinite(drift_z):
        return None, rolling_n
    return float(drift_z), rolling_n
161
+
162
+
163
  def _matching_calibration_bin(model, confidence: float) -> CalibrationContext | None:
164
  """Pick the highest-threshold bin whose threshold <= confidence. None if no match or no metadata."""
165
  bins = getattr(model, "_neurobridge_calibration", None)
 
210
 
211
  label_text = "permeable" if pred["label"] == 1 else "non-permeable"
212
  calibration = _matching_calibration_bin(model, pred["confidence"])
213
+ drift_z, rolling_n = _compute_drift_z(model, pred["confidence"])
214
  return BBBPredictResponse(
215
  label=pred["label"],
216
  label_text=label_text,
217
  confidence=pred["confidence"],
218
  top_features=[FeatureAttribution(**a) for a in attributions],
219
  calibration=calibration,
220
+ drift_z=drift_z,
221
+ rolling_n=rolling_n,
222
  )
223
 
224
 
src/api/schemas.py CHANGED
@@ -71,7 +71,7 @@ class CalibrationContext(BaseModel):
71
 
72
 
73
  class BBBPredictResponse(BaseModel):
74
- """Decision-system payload: prediction + uncertainty + explanation."""
75
  label: int
76
  label_text: str = Field(..., description="'permeable' or 'non-permeable'")
77
  confidence: float
@@ -80,6 +80,21 @@ class BBBPredictResponse(BaseModel):
80
  None,
81
  description="Statistical context: how often the model is right when this confident on held-out data.",
82
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
 
85
  class MRIDiagnosticsRequest(BaseModel):
 
71
 
72
 
73
  class BBBPredictResponse(BaseModel):
74
+ """Decision-system payload: prediction + uncertainty + explanation + drift."""
75
  label: int
76
  label_text: str = Field(..., description="'permeable' or 'non-permeable'")
77
  confidence: float
 
80
  None,
81
  description="Statistical context: how often the model is right when this confident on held-out data.",
82
  )
83
+ drift_z: float | None = Field(
84
+ None,
85
+ description=(
86
+ "Z-score of the trailing-100 confidence median against the "
87
+ "train-time median; None when warming up (<10 samples) or "
88
+ "when the model lacks _neurobridge_train_stats."
89
+ ),
90
+ )
91
+ rolling_n: int = Field(
92
+ 0,
93
+ description=(
94
+ "Number of confidence samples currently buffered in the worker's "
95
+ "rolling window (max 100). Zero on a fresh worker."
96
+ ),
97
+ )
98
 
99
 
100
  class MRIDiagnosticsRequest(BaseModel):
tests/api/test_routes.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
 
4
  from pathlib import Path
5
 
 
6
  from fastapi.testclient import TestClient
7
 
8
  from src.api.main import app
@@ -89,8 +90,14 @@ class TestBBBPredictRoute:
89
  bbb_model.save(model, artifact)
90
  return artifact
91
 
 
 
 
 
 
 
 
92
  def test_returns_200_with_prediction_and_attributions(self, tmp_path: Path, monkeypatch):
93
- import pytest
94
  artifact = self._setup_model_artifact(tmp_path)
95
  monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
96
 
@@ -120,6 +127,39 @@ class TestBBBPredictRoute:
120
  assert isinstance(cal["support"], int)
121
  assert cal["support"] >= 0
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  def test_returns_400_on_invalid_smiles(self, tmp_path: Path, monkeypatch):
124
  artifact = self._setup_model_artifact(tmp_path)
125
  monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
 
3
 
4
  from pathlib import Path
5
 
6
+ import pytest
7
  from fastapi.testclient import TestClient
8
 
9
  from src.api.main import app
 
90
  bbb_model.save(model, artifact)
91
  return artifact
92
 
93
@pytest.fixture
def _set_bbb_model_path(self, tmp_path: Path, monkeypatch):
    """Create a model artifact under `tmp_path` and export its location.

    The artifact is built by `_setup_model_artifact`, and the
    `BBB_MODEL_PATH` environment variable is set to it through
    `monkeypatch`. Returns the artifact path.
    """
    model_file = self._setup_model_artifact(tmp_path)
    monkeypatch.setenv("BBB_MODEL_PATH", str(model_file))
    return model_file
99
+
100
  def test_returns_200_with_prediction_and_attributions(self, tmp_path: Path, monkeypatch):
 
101
  artifact = self._setup_model_artifact(tmp_path)
102
  monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
103
 
 
127
  assert isinstance(cal["support"], int)
128
  assert cal["support"] >= 0
129
 
130
def test_predict_response_includes_drift_z_and_rolling_n(
    self, _set_bbb_model_path,
):
    """T1B: the response body always carries `drift_z` and `rolling_n`.

    The worker window is emptied first, so the single request below is the
    only buffered sample: `rolling_n` must be 1 and `drift_z` must be None.
    """
    from src.api import routes

    # Start from an empty window so the expected count is deterministic.
    routes.WORKER_CONFIDENCE_DEQUE.clear()

    response = client.post("/predict/bbb", json={"smiles": "CCO", "top_k": 5})
    assert response.status_code == 200, response.text

    payload = response.json()
    assert "drift_z" in payload
    assert "rolling_n" in payload
    # One sample was appended by this request; below the 10-sample warm-up.
    assert payload["rolling_n"] == 1
    assert payload["drift_z"] is None
146
+
147
def test_predict_deque_rolls_at_100(self, _set_bbb_model_path):
    """T1B: after 105 predictions the window reports 100 samples, not 105.

    Also checks that the last response carries a float `drift_z`.
    NOTE(review): that presumes the fixture's model artifact has
    `_neurobridge_train_stats` attached — confirm against
    `_setup_model_artifact`.
    """
    from src.api import routes

    routes.WORKER_CONFIDENCE_DEQUE.clear()

    request_json = {"smiles": "CCO", "top_k": 3}
    final_body = None
    # Five calls more than the window capacity, to force it to roll over.
    for _ in range(105):
        response = client.post("/predict/bbb", json=request_json)
        assert response.status_code == 200
        final_body = response.json()

    assert final_body["rolling_n"] == 100
    # Well past the 10-sample warm-up, so a numeric z-score is expected.
    assert isinstance(final_body["drift_z"], float)
162
+
163
  def test_returns_400_on_invalid_smiles(self, tmp_path: Path, monkeypatch):
164
  artifact = self._setup_model_artifact(tmp_path)
165
  monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))