feat(api): expose calibration bin in /predict/bbb response
- Adds CalibrationContext schema (threshold/precision/support).
- BBBPredictResponse gains optional calibration field; populated by
_matching_calibration_bin helper that picks the highest-threshold
bin whose threshold <= confidence.
- Returns None for legacy models without _neurobridge_calibration or
when confidence < lowest threshold (< 0.50).
- Extends existing 200-happy-path test with calibration assertions.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- src/api/routes.py +23 -0
- src/api/schemas.py +11 -0
- tests/api/test_routes.py +13 -0
src/api/routes.py
CHANGED
|
@@ -20,6 +20,7 @@ from src.api.schemas import (
|
|
| 20 |
BBBPredictRequest,
|
| 21 |
BBBPredictResponse,
|
| 22 |
BBBRequest,
|
|
|
|
| 23 |
EEGRequest,
|
| 24 |
FeatureAttribution,
|
| 25 |
MRIRequest,
|
|
@@ -126,6 +127,26 @@ def _bbb_model_path() -> Path:
|
|
| 126 |
return Path(os.environ.get("BBB_MODEL_PATH", str(_DEFAULT_BBB_MODEL_PATH)))
|
| 127 |
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
@predict_router.post("/bbb", response_model=BBBPredictResponse)
|
| 130 |
def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
|
| 131 |
"""Predict BBB permeability + return SHAP attributions for one SMILES.
|
|
@@ -155,9 +176,11 @@ def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
|
|
| 155 |
raise HTTPException(status_code=400, detail=str(e))
|
| 156 |
|
| 157 |
label_text = "permeable" if pred["label"] == 1 else "non-permeable"
|
|
|
|
| 158 |
return BBBPredictResponse(
|
| 159 |
label=pred["label"],
|
| 160 |
label_text=label_text,
|
| 161 |
confidence=pred["confidence"],
|
| 162 |
top_features=[FeatureAttribution(**a) for a in attributions],
|
|
|
|
| 163 |
)
|
|
|
|
| 20 |
BBBPredictRequest,
|
| 21 |
BBBPredictResponse,
|
| 22 |
BBBRequest,
|
| 23 |
+
CalibrationContext,
|
| 24 |
EEGRequest,
|
| 25 |
FeatureAttribution,
|
| 26 |
MRIRequest,
|
|
|
|
| 127 |
return Path(os.environ.get("BBB_MODEL_PATH", str(_DEFAULT_BBB_MODEL_PATH)))
|
| 128 |
|
| 129 |
|
| 130 |
+
def _matching_calibration_bin(model, confidence: float) -> CalibrationContext | None:
    """Return the calibration bin that applies to *confidence*, or ``None``.

    Scans the ``_neurobridge_calibration`` metadata attached to *model* and
    selects the bin with the highest threshold that is still <= *confidence*.

    Returns ``None`` when the model carries no calibration metadata (legacy
    artifacts) or when *confidence* falls below every bin's threshold.
    """
    # Legacy models trained before calibration metadata existed lack this
    # attribute entirely; an empty list is treated the same way.
    bins = getattr(model, "_neurobridge_calibration", None)
    if not bins:
        return None
    # Keep every bin the confidence clears, then pick the tightest one.
    # Unlike a sorted-scan-with-break, this stays correct even if the
    # metadata list is not sorted ascending by threshold.
    eligible = [b for b in bins if b["threshold"] <= confidence]
    if not eligible:
        return None
    matched = max(eligible, key=lambda b: b["threshold"])
    return CalibrationContext(
        threshold=matched["threshold"],
        precision=matched["precision"],
        support=matched["support"],
    )
|
| 148 |
+
|
| 149 |
+
|
| 150 |
@predict_router.post("/bbb", response_model=BBBPredictResponse)
|
| 151 |
def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
|
| 152 |
"""Predict BBB permeability + return SHAP attributions for one SMILES.
|
|
|
|
| 176 |
raise HTTPException(status_code=400, detail=str(e))
|
| 177 |
|
| 178 |
label_text = "permeable" if pred["label"] == 1 else "non-permeable"
|
| 179 |
+
calibration = _matching_calibration_bin(model, pred["confidence"])
|
| 180 |
return BBBPredictResponse(
|
| 181 |
label=pred["label"],
|
| 182 |
label_text=label_text,
|
| 183 |
confidence=pred["confidence"],
|
| 184 |
top_features=[FeatureAttribution(**a) for a in attributions],
|
| 185 |
+
calibration=calibration,
|
| 186 |
)
|
src/api/schemas.py
CHANGED
|
@@ -63,9 +63,20 @@ class FeatureAttribution(BaseModel):
|
|
| 63 |
)
|
| 64 |
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
class BBBPredictResponse(BaseModel):
    """Decision-system payload: prediction + uncertainty + explanation."""
    # Binary class index: 1 = permeable, 0 = non-permeable.
    label: int
    label_text: str = Field(..., description="'permeable' or 'non-permeable'")
    # Model confidence for the predicted label; presumably in [0, 1] — confirm against the predictor.
    confidence: float
    # SHAP attributions for the most influential input features.
    top_features: list[FeatureAttribution]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
)
|
| 64 |
|
| 65 |
|
| 66 |
+
class CalibrationContext(BaseModel):
    """Precision-at-confidence-threshold bin matched to a single prediction."""

    # Lower edge of the bin: a prediction matches when confidence >= threshold.
    threshold: float = Field(
        ...,
        description="Lowest confidence threshold this bin covers (0.0-1.0)",
    )
    precision: float = Field(
        ...,
        description="Precision on the held-out test set among predictions ≥ threshold",
    )
    support: int = Field(
        ...,
        description="Number of held-out predictions falling in this bin",
    )
|
| 71 |
+
|
| 72 |
+
|
| 73 |
class BBBPredictResponse(BaseModel):
    """Decision-system payload: prediction + uncertainty + explanation."""

    label: int
    label_text: str = Field(..., description="'permeable' or 'non-permeable'")
    confidence: float
    top_features: list[FeatureAttribution]
    # Optional: absent (None) for legacy model artifacts without calibration metadata.
    calibration: CalibrationContext | None = Field(
        default=None,
        description=(
            "Statistical context: how often the model is right when this "
            "confident on held-out data."
        ),
    )
|
tests/api/test_routes.py
CHANGED
|
@@ -90,6 +90,7 @@ class TestBBBPredictRoute:
|
|
| 90 |
return artifact
|
| 91 |
|
| 92 |
def test_returns_200_with_prediction_and_attributions(self, tmp_path: Path, monkeypatch):
|
|
|
|
| 93 |
artifact = self._setup_model_artifact(tmp_path)
|
| 94 |
monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
|
| 95 |
|
|
@@ -106,6 +107,18 @@ class TestBBBPredictRoute:
|
|
| 106 |
for f in body["top_features"]:
|
| 107 |
assert f["feature"].startswith("fp_")
|
| 108 |
assert isinstance(f["shap_value"], float)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
def test_returns_400_on_invalid_smiles(self, tmp_path: Path, monkeypatch):
|
| 111 |
artifact = self._setup_model_artifact(tmp_path)
|
|
|
|
| 90 |
return artifact
|
| 91 |
|
| 92 |
def test_returns_200_with_prediction_and_attributions(self, tmp_path: Path, monkeypatch):
|
| 93 |
+
import pytest
|
| 94 |
artifact = self._setup_model_artifact(tmp_path)
|
| 95 |
monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
|
| 96 |
|
|
|
|
| 107 |
for f in body["top_features"]:
|
| 108 |
assert f["feature"].startswith("fp_")
|
| 109 |
assert isinstance(f["shap_value"], float)
|
| 110 |
+
# Day-6 calibration assertions: trained test fixture model has
|
| 111 |
+
# _neurobridge_calibration metadata, so calibration must be populated.
|
| 112 |
+
assert body["calibration"] is not None
|
| 113 |
+
cal = body["calibration"]
|
| 114 |
+
valid_thresholds = [0.50, 0.60, 0.70, 0.75, 0.80, 0.90]
|
| 115 |
+
assert any(
|
| 116 |
+
cal["threshold"] == pytest.approx(t) for t in valid_thresholds
|
| 117 |
+
), f"threshold {cal['threshold']} not in {valid_thresholds}"
|
| 118 |
+
assert cal["threshold"] <= body["confidence"]
|
| 119 |
+
assert 0.0 <= cal["precision"] <= 1.0
|
| 120 |
+
assert isinstance(cal["support"], int)
|
| 121 |
+
assert cal["support"] >= 0
|
| 122 |
|
| 123 |
def test_returns_400_on_invalid_smiles(self, tmp_path: Path, monkeypatch):
|
| 124 |
artifact = self._setup_model_artifact(tmp_path)
|