mekosotto Claude Opus 4.7 (1M context) committed on
Commit
42366a8
·
1 Parent(s): 90167c7

feat(api): expose calibration bin in /predict/bbb response

Browse files

- Adds CalibrationContext schema (threshold/precision/support).
- BBBPredictResponse gains optional calibration field; populated by
_matching_calibration_bin helper that picks the highest-threshold
bin whose threshold <= confidence.
- Returns None for legacy models without _neurobridge_calibration or
when confidence < lowest threshold (< 0.50).
- Extends existing 200-happy-path test with calibration assertions.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Files changed (3) hide show
  1. src/api/routes.py +23 -0
  2. src/api/schemas.py +11 -0
  3. tests/api/test_routes.py +13 -0
src/api/routes.py CHANGED
@@ -20,6 +20,7 @@ from src.api.schemas import (
20
  BBBPredictRequest,
21
  BBBPredictResponse,
22
  BBBRequest,
 
23
  EEGRequest,
24
  FeatureAttribution,
25
  MRIRequest,
@@ -126,6 +127,26 @@ def _bbb_model_path() -> Path:
126
  return Path(os.environ.get("BBB_MODEL_PATH", str(_DEFAULT_BBB_MODEL_PATH)))
127
 
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  @predict_router.post("/bbb", response_model=BBBPredictResponse)
130
  def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
131
  """Predict BBB permeability + return SHAP attributions for one SMILES.
@@ -155,9 +176,11 @@ def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
155
  raise HTTPException(status_code=400, detail=str(e))
156
 
157
  label_text = "permeable" if pred["label"] == 1 else "non-permeable"
 
158
  return BBBPredictResponse(
159
  label=pred["label"],
160
  label_text=label_text,
161
  confidence=pred["confidence"],
162
  top_features=[FeatureAttribution(**a) for a in attributions],
 
163
  )
 
20
  BBBPredictRequest,
21
  BBBPredictResponse,
22
  BBBRequest,
23
+ CalibrationContext,
24
  EEGRequest,
25
  FeatureAttribution,
26
  MRIRequest,
 
127
  return Path(os.environ.get("BBB_MODEL_PATH", str(_DEFAULT_BBB_MODEL_PATH)))
128
 
129
 
130
+ def _matching_calibration_bin(model, confidence: float) -> CalibrationContext | None:
131
+ """Pick the highest-threshold bin whose threshold <= confidence. None if no match or no metadata."""
132
+ bins = getattr(model, "_neurobridge_calibration", None)
133
+ if not bins:
134
+ return None
135
+ matched = None
136
+ for bin_ in bins:
137
+ if bin_["threshold"] <= confidence:
138
+ matched = bin_
139
+ else:
140
+ break
141
+ if matched is None:
142
+ return None
143
+ return CalibrationContext(
144
+ threshold=matched["threshold"],
145
+ precision=matched["precision"],
146
+ support=matched["support"],
147
+ )
148
+
149
+
150
  @predict_router.post("/bbb", response_model=BBBPredictResponse)
151
  def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
152
  """Predict BBB permeability + return SHAP attributions for one SMILES.
 
176
  raise HTTPException(status_code=400, detail=str(e))
177
 
178
  label_text = "permeable" if pred["label"] == 1 else "non-permeable"
179
+ calibration = _matching_calibration_bin(model, pred["confidence"])
180
  return BBBPredictResponse(
181
  label=pred["label"],
182
  label_text=label_text,
183
  confidence=pred["confidence"],
184
  top_features=[FeatureAttribution(**a) for a in attributions],
185
+ calibration=calibration,
186
  )
src/api/schemas.py CHANGED
@@ -63,9 +63,20 @@ class FeatureAttribution(BaseModel):
63
  )
64
 
65
 
 
 
 
 
 
 
 
66
  class BBBPredictResponse(BaseModel):
67
  """Decision-system payload: prediction + uncertainty + explanation."""
68
  label: int
69
  label_text: str = Field(..., description="'permeable' or 'non-permeable'")
70
  confidence: float
71
  top_features: list[FeatureAttribution]
 
 
 
 
 
63
  )
64
 
65
 
66
class CalibrationContext(BaseModel):
    """Precision-at-confidence-threshold bin matched to a single prediction."""

    # Numeric bounds added so malformed calibration metadata fails loudly at
    # response-validation time instead of shipping impossible statistics.
    # The descriptions already promised these ranges; now they are enforced.
    threshold: float = Field(..., ge=0.0, le=1.0, description="Lowest confidence threshold this bin covers (0.0-1.0)")
    precision: float = Field(..., ge=0.0, le=1.0, description="Precision on the held-out test set among predictions ≥ threshold")
    support: int = Field(..., ge=0, description="Number of held-out predictions falling in this bin")
71
+
72
+
73
class BBBPredictResponse(BaseModel):
    """Decision-system payload: prediction + uncertainty + explanation."""
    # Binary class label; the /predict/bbb route maps 1 -> "permeable"
    # and anything else -> "non-permeable" when building label_text.
    label: int
    label_text: str = Field(..., description="'permeable' or 'non-permeable'")
    # Model confidence for the predicted label; compared against the
    # calibration-bin thresholds when the route populates `calibration`.
    confidence: float
    # Per-feature SHAP attributions for this single prediction.
    top_features: list[FeatureAttribution]
    # Optional precision-at-threshold context. None for legacy model
    # artifacts without calibration metadata, or when confidence falls
    # below the lowest calibration threshold.
    calibration: CalibrationContext | None = Field(
        None,
        description="Statistical context: how often the model is right when this confident on held-out data.",
    )
tests/api/test_routes.py CHANGED
@@ -90,6 +90,7 @@ class TestBBBPredictRoute:
90
  return artifact
91
 
92
  def test_returns_200_with_prediction_and_attributions(self, tmp_path: Path, monkeypatch):
 
93
  artifact = self._setup_model_artifact(tmp_path)
94
  monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
95
 
@@ -106,6 +107,18 @@ class TestBBBPredictRoute:
106
  for f in body["top_features"]:
107
  assert f["feature"].startswith("fp_")
108
  assert isinstance(f["shap_value"], float)
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  def test_returns_400_on_invalid_smiles(self, tmp_path: Path, monkeypatch):
111
  artifact = self._setup_model_artifact(tmp_path)
 
90
  return artifact
91
 
92
  def test_returns_200_with_prediction_and_attributions(self, tmp_path: Path, monkeypatch):
93
+ import pytest
94
  artifact = self._setup_model_artifact(tmp_path)
95
  monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
96
 
 
107
  for f in body["top_features"]:
108
  assert f["feature"].startswith("fp_")
109
  assert isinstance(f["shap_value"], float)
110
+ # Day-6 calibration assertions: trained test fixture model has
111
+ # _neurobridge_calibration metadata, so calibration must be populated.
112
+ assert body["calibration"] is not None
113
+ cal = body["calibration"]
114
+ valid_thresholds = [0.50, 0.60, 0.70, 0.75, 0.80, 0.90]
115
+ assert any(
116
+ cal["threshold"] == pytest.approx(t) for t in valid_thresholds
117
+ ), f"threshold {cal['threshold']} not in {valid_thresholds}"
118
+ assert cal["threshold"] <= body["confidence"]
119
+ assert 0.0 <= cal["precision"] <= 1.0
120
+ assert isinstance(cal["support"], int)
121
+ assert cal["support"] >= 0
122
 
123
  def test_returns_400_on_invalid_smiles(self, tmp_path: Path, monkeypatch):
124
  artifact = self._setup_model_artifact(tmp_path)