feat(api): POST /explain/bbb — natural-language rationale endpoint
Browse files- New explain_router with /explain prefix; symmetric with /predict/bbb
and reserves /explain/eeg, /explain/mri for future expansion.
- BBBExplainRequest carries the prediction snapshot + optional
user_question. top_features is required and must be non-empty
(Pydantic min_length=1 → 422 on empty).
- BBBExplainResponse: {rationale, source, model}. Always 200 because
the explainer's template fallback never raises.
- 1 new test: 200 + source='template' under NEUROBRIDGE_DISABLE_LLM=1
with full SHAP + calibration + drift payload.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- src/api/main.py +2 -1
- src/api/routes.py +42 -0
- src/api/schemas.py +28 -0
- tests/api/test_routes.py +31 -0
src/api/main.py
CHANGED
|
@@ -6,7 +6,7 @@ from __future__ import annotations
|
|
| 6 |
|
| 7 |
from fastapi import FastAPI
|
| 8 |
|
| 9 |
-
from src.api.routes import router as pipeline_router, predict_router
|
| 10 |
from src.api.schemas import HealthResponse
|
| 11 |
|
| 12 |
app = FastAPI(
|
|
@@ -17,6 +17,7 @@ app = FastAPI(
|
|
| 17 |
|
| 18 |
app.include_router(pipeline_router)
|
| 19 |
app.include_router(predict_router)
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
@app.get("/health", response_model=HealthResponse)
|
|
|
|
| 6 |
|
| 7 |
from fastapi import FastAPI
|
| 8 |
|
| 9 |
+
from src.api.routes import router as pipeline_router, predict_router, explain_router
|
| 10 |
from src.api.schemas import HealthResponse
|
| 11 |
|
| 12 |
app = FastAPI(
|
|
|
|
| 17 |
|
| 18 |
app.include_router(pipeline_router)
|
| 19 |
app.include_router(predict_router)
|
| 20 |
+
app.include_router(explain_router)
|
| 21 |
|
| 22 |
|
| 23 |
@app.get("/health", response_model=HealthResponse)
|
src/api/routes.py
CHANGED
|
@@ -18,6 +18,8 @@ import pandas as pd
|
|
| 18 |
from fastapi import APIRouter, HTTPException
|
| 19 |
|
| 20 |
from src.api.schemas import (
|
|
|
|
|
|
|
| 21 |
BBBPredictRequest,
|
| 22 |
BBBPredictResponse,
|
| 23 |
BBBRequest,
|
|
@@ -32,12 +34,14 @@ from src.api.schemas import (
|
|
| 32 |
PipelineResponse,
|
| 33 |
)
|
| 34 |
from src.core.logger import get_logger
|
|
|
|
| 35 |
from src.models import bbb_model
|
| 36 |
from src.pipelines import bbb_pipeline, eeg_pipeline, mri_pipeline
|
| 37 |
|
| 38 |
logger = get_logger(__name__)
|
| 39 |
router = APIRouter(prefix="/pipeline")
|
| 40 |
predict_router = APIRouter(prefix="/predict")
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
def _wrap(
|
|
@@ -320,3 +324,41 @@ def mri_diagnostics(req: MRIDiagnosticsRequest) -> MRIDiagnosticsResponse:
|
|
| 320 |
site_gap_post=site_gap_post,
|
| 321 |
reduction_factor=reduction_factor,
|
| 322 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
from fastapi import APIRouter, HTTPException
|
| 19 |
|
| 20 |
from src.api.schemas import (
|
| 21 |
+
BBBExplainRequest,
|
| 22 |
+
BBBExplainResponse,
|
| 23 |
BBBPredictRequest,
|
| 24 |
BBBPredictResponse,
|
| 25 |
BBBRequest,
|
|
|
|
| 34 |
PipelineResponse,
|
| 35 |
)
|
| 36 |
from src.core.logger import get_logger
|
| 37 |
+
from src.llm import explainer as llm_explainer
|
| 38 |
from src.models import bbb_model
|
| 39 |
from src.pipelines import bbb_pipeline, eeg_pipeline, mri_pipeline
|
| 40 |
|
| 41 |
logger = get_logger(__name__)
|
| 42 |
router = APIRouter(prefix="/pipeline")
|
| 43 |
predict_router = APIRouter(prefix="/predict")
|
| 44 |
+
explain_router = APIRouter(prefix="/explain")
|
| 45 |
|
| 46 |
|
| 47 |
def _wrap(
|
|
|
|
| 324 |
site_gap_post=site_gap_post,
|
| 325 |
reduction_factor=reduction_factor,
|
| 326 |
)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
@explain_router.post("/bbb", response_model=BBBExplainResponse)
|
| 330 |
+
def explain_bbb(req: BBBExplainRequest) -> BBBExplainResponse:
|
| 331 |
+
"""Natural-language rationale for a single BBB prediction.
|
| 332 |
+
|
| 333 |
+
Always returns 200 — the explainer is guaranteed to produce a
|
| 334 |
+
rationale via deterministic-template fallback. Pydantic enforces
|
| 335 |
+
a non-empty top_features list; an empty list returns 422 from
|
| 336 |
+
FastAPI before this handler runs.
|
| 337 |
+
"""
|
| 338 |
+
payload: llm_explainer.ExplainPayload = {
|
| 339 |
+
"smiles": req.smiles,
|
| 340 |
+
"label": req.label,
|
| 341 |
+
"label_text": req.label_text,
|
| 342 |
+
"confidence": req.confidence,
|
| 343 |
+
"top_features": [
|
| 344 |
+
{"feature": f.feature, "shap_value": f.shap_value}
|
| 345 |
+
for f in req.top_features
|
| 346 |
+
],
|
| 347 |
+
"calibration": (
|
| 348 |
+
None
|
| 349 |
+
if req.calibration is None
|
| 350 |
+
else {
|
| 351 |
+
"threshold": req.calibration.threshold,
|
| 352 |
+
"precision": req.calibration.precision,
|
| 353 |
+
"support": req.calibration.support,
|
| 354 |
+
}
|
| 355 |
+
),
|
| 356 |
+
"drift_z": req.drift_z,
|
| 357 |
+
"user_question": req.user_question or "",
|
| 358 |
+
}
|
| 359 |
+
result = llm_explainer.explain(payload)
|
| 360 |
+
return BBBExplainResponse(
|
| 361 |
+
rationale=result["rationale"],
|
| 362 |
+
source=result["source"],
|
| 363 |
+
model=result["model"],
|
| 364 |
+
)
|
src/api/schemas.py
CHANGED
|
@@ -133,3 +133,31 @@ class MRIDiagnosticsResponse(BaseModel):
|
|
| 133 |
site_gap_pre: float = Field(..., description="Range of per-site means before ComBat")
|
| 134 |
site_gap_post: float = Field(..., description="Range of per-site means after ComBat")
|
| 135 |
reduction_factor: float = Field(..., description="site_gap_pre / max(site_gap_post, eps)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
site_gap_pre: float = Field(..., description="Range of per-site means before ComBat")
|
| 134 |
site_gap_post: float = Field(..., description="Range of per-site means after ComBat")
|
| 135 |
reduction_factor: float = Field(..., description="site_gap_pre / max(site_gap_post, eps)")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class BBBExplainRequest(BaseModel):
|
| 139 |
+
"""Day-7 T3B: payload for POST /explain/bbb (chat-style explainer)."""
|
| 140 |
+
smiles: str = Field(..., description="SMILES string of the molecule")
|
| 141 |
+
label: int = Field(..., description="Predicted label (0 = non-permeable, 1 = permeable)")
|
| 142 |
+
label_text: str = Field(..., description="'permeable' or 'non-permeable'")
|
| 143 |
+
confidence: float = Field(..., ge=0.0, le=1.0)
|
| 144 |
+
top_features: list[FeatureAttribution] = Field(
|
| 145 |
+
..., min_length=1,
|
| 146 |
+
description="Non-empty list of SHAP attributions; an empty list returns 400.",
|
| 147 |
+
)
|
| 148 |
+
calibration: CalibrationContext | None = None
|
| 149 |
+
drift_z: float | None = None
|
| 150 |
+
user_question: str | None = Field(
|
| 151 |
+
None,
|
| 152 |
+
description="Optional question from the user; passed to the LLM prompt only.",
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class BBBExplainResponse(BaseModel):
|
| 157 |
+
"""Day-7 T3B: response from POST /explain/bbb."""
|
| 158 |
+
rationale: str = Field(..., description="2-4 sentence natural-language explanation")
|
| 159 |
+
source: str = Field(..., description="'llm' or 'template'")
|
| 160 |
+
model: str | None = Field(
|
| 161 |
+
None,
|
| 162 |
+
description="LLM model name when source='llm'; None when source='template'",
|
| 163 |
+
)
|
tests/api/test_routes.py
CHANGED
|
@@ -228,3 +228,34 @@ class TestMRIDiagnosticsRoute:
|
|
| 228 |
},
|
| 229 |
)
|
| 230 |
assert resp.status_code == 404
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
},
|
| 229 |
)
|
| 230 |
assert resp.status_code == 404
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class TestExplainBBBRoute:
|
| 234 |
+
"""Day-7 T3B: POST /explain/bbb."""
|
| 235 |
+
|
| 236 |
+
def test_returns_200_with_template_source(self, monkeypatch):
|
| 237 |
+
"""Kill-switch on → /explain/bbb returns rationale with source=template."""
|
| 238 |
+
monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
|
| 239 |
+
body = {
|
| 240 |
+
"smiles": "CCO",
|
| 241 |
+
"label": 1,
|
| 242 |
+
"label_text": "permeable",
|
| 243 |
+
"confidence": 0.82,
|
| 244 |
+
"top_features": [
|
| 245 |
+
{"feature": "fp_341", "shap_value": 0.045},
|
| 246 |
+
{"feature": "fp_902", "shap_value": -0.031},
|
| 247 |
+
{"feature": "fp_77", "shap_value": 0.022},
|
| 248 |
+
],
|
| 249 |
+
"calibration": {"threshold": 0.80, "precision": 0.92, "support": 18},
|
| 250 |
+
"drift_z": 0.42,
|
| 251 |
+
"user_question": "Why permeable?",
|
| 252 |
+
}
|
| 253 |
+
resp = client.post("/explain/bbb", json=body)
|
| 254 |
+
assert resp.status_code == 200, resp.text
|
| 255 |
+
out = resp.json()
|
| 256 |
+
assert out["source"] == "template"
|
| 257 |
+
assert out["model"] is None
|
| 258 |
+
# Template must mention all three features
|
| 259 |
+
for feat in ("fp_341", "fp_902", "fp_77"):
|
| 260 |
+
assert feat in out["rationale"]
|
| 261 |
+
assert "permeable" in out["rationale"]
|