feat(api): drift z-score in /predict/bbb response
- WORKER_CONFIDENCE_DEQUE: collections.deque(maxlen=100), per-worker
rolling window of confidences; drift_z computed against train-time
median when ≥10 samples buffered AND model has _neurobridge_train_stats.
- BBBPredictResponse gains drift_z (float | None) and rolling_n (int).
- 2 new tests: drift_z/rolling_n always present in body; deque rolls
at 100 after 105 predictions.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- src/api/routes.py +33 -0
- src/api/schemas.py +16 -1
- tests/api/test_routes.py +41 -1
src/api/routes.py
CHANGED
|
@@ -9,6 +9,7 @@ from __future__ import annotations
|
|
| 9 |
|
| 10 |
import os
|
| 11 |
import time
|
|
|
|
| 12 |
from pathlib import Path
|
| 13 |
from typing import Callable
|
| 14 |
|
|
@@ -130,6 +131,35 @@ def _bbb_model_path() -> Path:
|
|
| 130 |
return Path(os.environ.get("BBB_MODEL_PATH", str(_DEFAULT_BBB_MODEL_PATH)))
|
| 131 |
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
def _matching_calibration_bin(model, confidence: float) -> CalibrationContext | None:
|
| 134 |
"""Pick the highest-threshold bin whose threshold <= confidence. None if no match or no metadata."""
|
| 135 |
bins = getattr(model, "_neurobridge_calibration", None)
|
|
@@ -180,12 +210,15 @@ def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
|
|
| 180 |
|
| 181 |
label_text = "permeable" if pred["label"] == 1 else "non-permeable"
|
| 182 |
calibration = _matching_calibration_bin(model, pred["confidence"])
|
|
|
|
| 183 |
return BBBPredictResponse(
|
| 184 |
label=pred["label"],
|
| 185 |
label_text=label_text,
|
| 186 |
confidence=pred["confidence"],
|
| 187 |
top_features=[FeatureAttribution(**a) for a in attributions],
|
| 188 |
calibration=calibration,
|
|
|
|
|
|
|
| 189 |
)
|
| 190 |
|
| 191 |
|
|
|
|
| 9 |
|
| 10 |
import os
|
| 11 |
import time
|
| 12 |
+
from collections import deque
|
| 13 |
from pathlib import Path
|
| 14 |
from typing import Callable
|
| 15 |
|
|
|
|
| 131 |
return Path(os.environ.get("BBB_MODEL_PATH", str(_DEFAULT_BBB_MODEL_PATH)))
|
| 132 |
|
| 133 |
|
| 134 |
+
# Per-worker rolling window of recent prediction confidences.
|
| 135 |
+
# Cleared on worker restart; multi-worker setups have independent windows.
|
| 136 |
+
WORKER_CONFIDENCE_DEQUE: deque[float] = deque(maxlen=100)
|
| 137 |
+
_DRIFT_MIN_SAMPLES = 10
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _compute_drift_z(model, confidence: float) -> tuple[float | None, int]:
|
| 141 |
+
"""Append `confidence` to the worker deque and compute the drift z-score.
|
| 142 |
+
|
| 143 |
+
Returns (drift_z, rolling_n). drift_z is None until both:
|
| 144 |
+
(1) the deque has at least `_DRIFT_MIN_SAMPLES` samples, AND
|
| 145 |
+
(2) the model has `_neurobridge_train_stats` attached.
|
| 146 |
+
|
| 147 |
+
z = (rolling_median - train_median) / max(train_std, 1e-9)
|
| 148 |
+
"""
|
| 149 |
+
import statistics
|
| 150 |
+
|
| 151 |
+
WORKER_CONFIDENCE_DEQUE.append(float(confidence))
|
| 152 |
+
rolling_n = len(WORKER_CONFIDENCE_DEQUE)
|
| 153 |
+
stats = getattr(model, "_neurobridge_train_stats", None)
|
| 154 |
+
if rolling_n < _DRIFT_MIN_SAMPLES or stats is None:
|
| 155 |
+
return None, rolling_n
|
| 156 |
+
rolling_median = statistics.median(WORKER_CONFIDENCE_DEQUE)
|
| 157 |
+
train_median = float(stats["median"])
|
| 158 |
+
train_std = max(float(stats["std"]), 1e-9)
|
| 159 |
+
drift_z = (rolling_median - train_median) / train_std
|
| 160 |
+
return float(drift_z), rolling_n
|
| 161 |
+
|
| 162 |
+
|
| 163 |
def _matching_calibration_bin(model, confidence: float) -> CalibrationContext | None:
|
| 164 |
"""Pick the highest-threshold bin whose threshold <= confidence. None if no match or no metadata."""
|
| 165 |
bins = getattr(model, "_neurobridge_calibration", None)
|
|
|
|
| 210 |
|
| 211 |
label_text = "permeable" if pred["label"] == 1 else "non-permeable"
|
| 212 |
calibration = _matching_calibration_bin(model, pred["confidence"])
|
| 213 |
+
drift_z, rolling_n = _compute_drift_z(model, pred["confidence"])
|
| 214 |
return BBBPredictResponse(
|
| 215 |
label=pred["label"],
|
| 216 |
label_text=label_text,
|
| 217 |
confidence=pred["confidence"],
|
| 218 |
top_features=[FeatureAttribution(**a) for a in attributions],
|
| 219 |
calibration=calibration,
|
| 220 |
+
drift_z=drift_z,
|
| 221 |
+
rolling_n=rolling_n,
|
| 222 |
)
|
| 223 |
|
| 224 |
|
src/api/schemas.py
CHANGED
|
@@ -71,7 +71,7 @@ class CalibrationContext(BaseModel):
|
|
| 71 |
|
| 72 |
|
| 73 |
class BBBPredictResponse(BaseModel):
|
| 74 |
-
"""Decision-system payload: prediction + uncertainty + explanation."""
|
| 75 |
label: int
|
| 76 |
label_text: str = Field(..., description="'permeable' or 'non-permeable'")
|
| 77 |
confidence: float
|
|
@@ -80,6 +80,21 @@ class BBBPredictResponse(BaseModel):
|
|
| 80 |
None,
|
| 81 |
description="Statistical context: how often the model is right when this confident on held-out data.",
|
| 82 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
|
| 85 |
class MRIDiagnosticsRequest(BaseModel):
|
|
|
|
| 71 |
|
| 72 |
|
| 73 |
class BBBPredictResponse(BaseModel):
|
| 74 |
+
"""Decision-system payload: prediction + uncertainty + explanation + drift."""
|
| 75 |
label: int
|
| 76 |
label_text: str = Field(..., description="'permeable' or 'non-permeable'")
|
| 77 |
confidence: float
|
|
|
|
| 80 |
None,
|
| 81 |
description="Statistical context: how often the model is right when this confident on held-out data.",
|
| 82 |
)
|
| 83 |
+
drift_z: float | None = Field(
|
| 84 |
+
None,
|
| 85 |
+
description=(
|
| 86 |
+
"Z-score of the trailing-100 confidence median against the "
|
| 87 |
+
"train-time median; None when warming up (<10 samples) or "
|
| 88 |
+
"when the model lacks _neurobridge_train_stats."
|
| 89 |
+
),
|
| 90 |
+
)
|
| 91 |
+
rolling_n: int = Field(
|
| 92 |
+
0,
|
| 93 |
+
description=(
|
| 94 |
+
"Number of confidence samples currently buffered in the worker's "
|
| 95 |
+
"rolling window (max 100). Zero on a fresh worker."
|
| 96 |
+
),
|
| 97 |
+
)
|
| 98 |
|
| 99 |
|
| 100 |
class MRIDiagnosticsRequest(BaseModel):
|
tests/api/test_routes.py
CHANGED
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
| 3 |
|
| 4 |
from pathlib import Path
|
| 5 |
|
|
|
|
| 6 |
from fastapi.testclient import TestClient
|
| 7 |
|
| 8 |
from src.api.main import app
|
|
@@ -89,8 +90,14 @@ class TestBBBPredictRoute:
|
|
| 89 |
bbb_model.save(model, artifact)
|
| 90 |
return artifact
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
def test_returns_200_with_prediction_and_attributions(self, tmp_path: Path, monkeypatch):
|
| 93 |
-
import pytest
|
| 94 |
artifact = self._setup_model_artifact(tmp_path)
|
| 95 |
monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
|
| 96 |
|
|
@@ -120,6 +127,39 @@ class TestBBBPredictRoute:
|
|
| 120 |
assert isinstance(cal["support"], int)
|
| 121 |
assert cal["support"] >= 0
|
| 122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
def test_returns_400_on_invalid_smiles(self, tmp_path: Path, monkeypatch):
|
| 124 |
artifact = self._setup_model_artifact(tmp_path)
|
| 125 |
monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
|
|
|
|
| 3 |
|
| 4 |
from pathlib import Path
|
| 5 |
|
| 6 |
+
import pytest
|
| 7 |
from fastapi.testclient import TestClient
|
| 8 |
|
| 9 |
from src.api.main import app
|
|
|
|
| 90 |
bbb_model.save(model, artifact)
|
| 91 |
return artifact
|
| 92 |
|
| 93 |
+
@pytest.fixture
|
| 94 |
+
def _set_bbb_model_path(self, tmp_path: Path, monkeypatch):
|
| 95 |
+
"""Build a model artifact and point BBB_MODEL_PATH at it for the test."""
|
| 96 |
+
artifact = self._setup_model_artifact(tmp_path)
|
| 97 |
+
monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
|
| 98 |
+
return artifact
|
| 99 |
+
|
| 100 |
def test_returns_200_with_prediction_and_attributions(self, tmp_path: Path, monkeypatch):
|
|
|
|
| 101 |
artifact = self._setup_model_artifact(tmp_path)
|
| 102 |
monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
|
| 103 |
|
|
|
|
| 127 |
assert isinstance(cal["support"], int)
|
| 128 |
assert cal["support"] >= 0
|
| 129 |
|
| 130 |
+
def test_predict_response_includes_drift_z_and_rolling_n(
|
| 131 |
+
self, _set_bbb_model_path,
|
| 132 |
+
):
|
| 133 |
+
"""T1B: drift_z and rolling_n keys must always appear in the body."""
|
| 134 |
+
# Reset deque before this test so rolling_n starts deterministic.
|
| 135 |
+
from src.api import routes
|
| 136 |
+
routes.WORKER_CONFIDENCE_DEQUE.clear()
|
| 137 |
+
|
| 138 |
+
resp = client.post("/predict/bbb", json={"smiles": "CCO", "top_k": 5})
|
| 139 |
+
assert resp.status_code == 200, resp.text
|
| 140 |
+
body = resp.json()
|
| 141 |
+
assert "drift_z" in body
|
| 142 |
+
assert "rolling_n" in body
|
| 143 |
+
# First request: buffer has 1 sample (just appended), so warming up.
|
| 144 |
+
assert body["rolling_n"] == 1
|
| 145 |
+
assert body["drift_z"] is None # <10 samples = warming up
|
| 146 |
+
|
| 147 |
+
def test_predict_deque_rolls_at_100(self, _set_bbb_model_path):
|
| 148 |
+
"""T1B: after 100 predictions, deque caps at maxlen=100 (rolls)."""
|
| 149 |
+
from src.api import routes
|
| 150 |
+
routes.WORKER_CONFIDENCE_DEQUE.clear()
|
| 151 |
+
# Fire 105 calls; final rolling_n must be 100, not 105.
|
| 152 |
+
last_body = None
|
| 153 |
+
for _ in range(105):
|
| 154 |
+
resp = client.post(
|
| 155 |
+
"/predict/bbb", json={"smiles": "CCO", "top_k": 3},
|
| 156 |
+
)
|
| 157 |
+
assert resp.status_code == 200
|
| 158 |
+
last_body = resp.json()
|
| 159 |
+
assert last_body["rolling_n"] == 100
|
| 160 |
+
# By call 105, drift_z is computable (≥10 samples) — assert numeric.
|
| 161 |
+
assert isinstance(last_body["drift_z"], float)
|
| 162 |
+
|
| 163 |
def test_returns_400_on_invalid_smiles(self, tmp_path: Path, monkeypatch):
|
| 164 |
artifact = self._setup_model_artifact(tmp_path)
|
| 165 |
monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
|