feat(api+frontend): MLflow provenance badge in decision card
Browse files
- ModelProvenance schema (mlflow_run_id, model_version, train_date,
n_examples). BBBPredictResponse.provenance is always populated; failed
MLflow lookup degrades to None fields without breaking the response.
- _build_provenance() module-level cache: one MLflow query per worker.
NEUROBRIDGE_DISABLE_MLFLOW=1 short-circuits to None fields. n_examples
pulled per-request from model._neurobridge_train_stats.
- Streamlit decision card renders a one-line audit badge above the
label: run id (first 8 chars), model version, train date, n_examples.
- 1 new test: provenance field present in /predict/bbb body with the
fixture model (n_examples ≥ 1 from train stats).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- src/api/routes.py +56 -0
- src/api/schemas.py +12 -0
- src/frontend/app.py +12 -0
- tests/api/test_routes.py +19 -0
src/api/routes.py
CHANGED
|
@@ -25,6 +25,7 @@ from src.api.schemas import (
|
|
| 25 |
EEGRequest,
|
| 26 |
FeatureAttribution,
|
| 27 |
HarmonizationRow,
|
|
|
|
| 28 |
MRIDiagnosticsRequest,
|
| 29 |
MRIDiagnosticsResponse,
|
| 30 |
MRIRequest,
|
|
@@ -160,6 +161,59 @@ def _compute_drift_z(model, confidence: float) -> tuple[float | None, int]:
|
|
| 160 |
return float(drift_z), rolling_n
|
| 161 |
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
def _matching_calibration_bin(model, confidence: float) -> CalibrationContext | None:
|
| 164 |
"""Pick the highest-threshold bin whose threshold <= confidence. None if no match or no metadata."""
|
| 165 |
bins = getattr(model, "_neurobridge_calibration", None)
|
|
@@ -211,6 +265,7 @@ def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
|
|
| 211 |
label_text = "permeable" if pred["label"] == 1 else "non-permeable"
|
| 212 |
calibration = _matching_calibration_bin(model, pred["confidence"])
|
| 213 |
drift_z, rolling_n = _compute_drift_z(model, pred["confidence"])
|
|
|
|
| 214 |
return BBBPredictResponse(
|
| 215 |
label=pred["label"],
|
| 216 |
label_text=label_text,
|
|
@@ -219,6 +274,7 @@ def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
|
|
| 219 |
calibration=calibration,
|
| 220 |
drift_z=drift_z,
|
| 221 |
rolling_n=rolling_n,
|
|
|
|
| 222 |
)
|
| 223 |
|
| 224 |
|
|
|
|
| 25 |
EEGRequest,
|
| 26 |
FeatureAttribution,
|
| 27 |
HarmonizationRow,
|
| 28 |
+
ModelProvenance,
|
| 29 |
MRIDiagnosticsRequest,
|
| 30 |
MRIDiagnosticsResponse,
|
| 31 |
MRIRequest,
|
|
|
|
| 161 |
return float(drift_z), rolling_n
|
| 162 |
|
| 163 |
|
| 164 |
+
_PROVENANCE_CACHE: ModelProvenance | None = None
|
| 165 |
+
_MODEL_VERSION = "v1" # bump manually per train cycle
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _build_provenance(model) -> ModelProvenance:
|
| 169 |
+
"""Look up the most recent BBB MLflow run; build a ModelProvenance.
|
| 170 |
+
|
| 171 |
+
Cached at module level so we hit MLflow once per worker. Failures (no
|
| 172 |
+
runs found, MLflow unreachable, NEUROBRIDGE_DISABLE_MLFLOW=1) all
|
| 173 |
+
degrade to a partial ModelProvenance with mlflow_run_id=None — the
|
| 174 |
+
badge still renders, just without a run id.
|
| 175 |
+
"""
|
| 176 |
+
global _PROVENANCE_CACHE
|
| 177 |
+
if _PROVENANCE_CACHE is not None:
|
| 178 |
+
# Refresh n_examples each call from the model (cheap lookup).
|
| 179 |
+
n_train = None
|
| 180 |
+
stats = getattr(model, "_neurobridge_train_stats", None)
|
| 181 |
+
if stats is not None:
|
| 182 |
+
n_train = int(stats.get("n_train", 0)) or None
|
| 183 |
+
return _PROVENANCE_CACHE.model_copy(update={"n_examples": n_train})
|
| 184 |
+
|
| 185 |
+
run_id: str | None = None
|
| 186 |
+
train_date: str | None = None
|
| 187 |
+
if os.environ.get("NEUROBRIDGE_DISABLE_MLFLOW") != "1":
|
| 188 |
+
try:
|
| 189 |
+
runs = mlflow.search_runs(
|
| 190 |
+
experiment_names=["bbb_pipeline"],
|
| 191 |
+
max_results=1,
|
| 192 |
+
order_by=["start_time DESC"],
|
| 193 |
+
)
|
| 194 |
+
if len(runs):
|
| 195 |
+
row = runs.iloc[0]
|
| 196 |
+
run_id = str(row["run_id"])
|
| 197 |
+
ts = row.get("start_time")
|
| 198 |
+
if ts is not None:
|
| 199 |
+
train_date = str(pd.Timestamp(ts).isoformat())
|
| 200 |
+
except Exception as e: # broad: MLflow store unreachable, schema mismatch, etc.
|
| 201 |
+
logger.warning("MLflow provenance lookup failed: %s", e)
|
| 202 |
+
|
| 203 |
+
n_train = None
|
| 204 |
+
stats = getattr(model, "_neurobridge_train_stats", None)
|
| 205 |
+
if stats is not None:
|
| 206 |
+
n_train = int(stats.get("n_train", 0)) or None
|
| 207 |
+
|
| 208 |
+
_PROVENANCE_CACHE = ModelProvenance(
|
| 209 |
+
mlflow_run_id=run_id,
|
| 210 |
+
model_version=_MODEL_VERSION,
|
| 211 |
+
train_date=train_date,
|
| 212 |
+
n_examples=n_train,
|
| 213 |
+
)
|
| 214 |
+
return _PROVENANCE_CACHE
|
| 215 |
+
|
| 216 |
+
|
| 217 |
def _matching_calibration_bin(model, confidence: float) -> CalibrationContext | None:
|
| 218 |
"""Pick the highest-threshold bin whose threshold <= confidence. None if no match or no metadata."""
|
| 219 |
bins = getattr(model, "_neurobridge_calibration", None)
|
|
|
|
| 265 |
label_text = "permeable" if pred["label"] == 1 else "non-permeable"
|
| 266 |
calibration = _matching_calibration_bin(model, pred["confidence"])
|
| 267 |
drift_z, rolling_n = _compute_drift_z(model, pred["confidence"])
|
| 268 |
+
provenance = _build_provenance(model)
|
| 269 |
return BBBPredictResponse(
|
| 270 |
label=pred["label"],
|
| 271 |
label_text=label_text,
|
|
|
|
| 274 |
calibration=calibration,
|
| 275 |
drift_z=drift_z,
|
| 276 |
rolling_n=rolling_n,
|
| 277 |
+
provenance=provenance,
|
| 278 |
)
|
| 279 |
|
| 280 |
|
src/api/schemas.py
CHANGED
|
@@ -70,6 +70,14 @@ class CalibrationContext(BaseModel):
|
|
| 70 |
support: int = Field(..., description="Number of held-out predictions falling in this bin")
|
| 71 |
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
class BBBPredictResponse(BaseModel):
|
| 74 |
"""Decision-system payload: prediction + uncertainty + explanation + drift."""
|
| 75 |
label: int
|
|
@@ -95,6 +103,10 @@ class BBBPredictResponse(BaseModel):
|
|
| 95 |
"rolling window (max 100). Zero on a fresh worker."
|
| 96 |
),
|
| 97 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
|
| 100 |
class MRIDiagnosticsRequest(BaseModel):
|
|
|
|
| 70 |
support: int = Field(..., description="Number of held-out predictions falling in this bin")
|
| 71 |
|
| 72 |
|
| 73 |
+
class ModelProvenance(BaseModel):
|
| 74 |
+
"""Auditable provenance of the BBB model that produced a prediction."""
|
| 75 |
+
mlflow_run_id: str | None = Field(None, description="MLflow run id of the most recent training run, if any")
|
| 76 |
+
model_version: str = Field("v1", description="Manually-bumped model version label")
|
| 77 |
+
train_date: str | None = Field(None, description="ISO 8601 train timestamp from MLflow run start_time")
|
| 78 |
+
n_examples: int | None = Field(None, description="Training set size (from model._neurobridge_train_stats[\"n_train\"])")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
class BBBPredictResponse(BaseModel):
|
| 82 |
"""Decision-system payload: prediction + uncertainty + explanation + drift."""
|
| 83 |
label: int
|
|
|
|
| 103 |
"rolling window (max 100). Zero on a fresh worker."
|
| 104 |
),
|
| 105 |
)
|
| 106 |
+
provenance: ModelProvenance | None = Field(
|
| 107 |
+
None,
|
| 108 |
+
description="Auditing metadata (MLflow run id, train date, n_examples).",
|
| 109 |
+
)
|
| 110 |
|
| 111 |
|
| 112 |
class MRIDiagnosticsRequest(BaseModel):
|
src/frontend/app.py
CHANGED
|
@@ -457,6 +457,18 @@ def _render_mri_tab() -> None:
|
|
| 457 |
def _render_prediction_card(result: dict) -> None:
|
| 458 |
"""Render a B2B-styled decision card: label badge + confidence + SHAP bars."""
|
| 459 |
st.session_state["last_bbb_prediction"] = result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
label_text = _html.escape(str(result["label_text"]))
|
| 461 |
badge_color = "#166534" if result["label"] == 1 else "#991B1B"
|
| 462 |
badge_bg = "#DCFCE7" if result["label"] == 1 else "#FEE2E2"
|
|
|
|
| 457 |
def _render_prediction_card(result: dict) -> None:
|
| 458 |
"""Render a B2B-styled decision card: label badge + confidence + SHAP bars."""
|
| 459 |
st.session_state["last_bbb_prediction"] = result
|
| 460 |
+
provenance = result.get("provenance")
|
| 461 |
+
if provenance is not None:
|
| 462 |
+
run_id = provenance.get("mlflow_run_id")
|
| 463 |
+
run_label = run_id[:8] if run_id else "—"
|
| 464 |
+
train_date = provenance.get("train_date") or "—"
|
| 465 |
+
n_examples = provenance.get("n_examples")
|
| 466 |
+
n_label = f"n={n_examples}" if n_examples else "n=—"
|
| 467 |
+
st.caption(
|
| 468 |
+
f"🔎 MLflow run **{run_label}** · "
|
| 469 |
+
f"Model **{provenance.get('model_version', 'v1')}** · "
|
| 470 |
+
f"trained {train_date} · {n_label}"
|
| 471 |
+
)
|
| 472 |
label_text = _html.escape(str(result["label_text"]))
|
| 473 |
badge_color = "#166534" if result["label"] == 1 else "#991B1B"
|
| 474 |
badge_bg = "#DCFCE7" if result["label"] == 1 else "#FEE2E2"
|
tests/api/test_routes.py
CHANGED
|
@@ -160,6 +160,25 @@ class TestBBBPredictRoute:
|
|
| 160 |
# By call 105, drift_z is computable (≥10 samples) — assert numeric.
|
| 161 |
assert isinstance(last_body["drift_z"], float)
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
def test_returns_400_on_invalid_smiles(self, tmp_path: Path, monkeypatch):
|
| 164 |
artifact = self._setup_model_artifact(tmp_path)
|
| 165 |
monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
|
|
|
|
| 160 |
# By call 105, drift_z is computable (≥10 samples) — assert numeric.
|
| 161 |
assert isinstance(last_body["drift_z"], float)
|
| 162 |
|
| 163 |
+
def test_predict_response_includes_provenance(self, _set_bbb_model_path):
|
| 164 |
+
"""T2: provenance field is present in body (fields may be None)."""
|
| 165 |
+
from src.api import routes
|
| 166 |
+
routes.WORKER_CONFIDENCE_DEQUE.clear()
|
| 167 |
+
|
| 168 |
+
resp = client.post("/predict/bbb", json={"smiles": "CCO", "top_k": 3})
|
| 169 |
+
assert resp.status_code == 200, resp.text
|
| 170 |
+
body = resp.json()
|
| 171 |
+
assert "provenance" in body
|
| 172 |
+
assert body["provenance"] is not None, "provenance should be populated even when MLflow is empty"
|
| 173 |
+
prov = body["provenance"]
|
| 174 |
+
assert "mlflow_run_id" in prov
|
| 175 |
+
assert "model_version" in prov
|
| 176 |
+
assert prov["model_version"] == "v1" # default until bumped manually
|
| 177 |
+
assert "train_date" in prov
|
| 178 |
+
assert "n_examples" in prov
|
| 179 |
+
# n_examples comes from train_stats — must be a positive int for the test fixture
|
| 180 |
+
assert isinstance(prov["n_examples"], int) and prov["n_examples"] >= 1
|
| 181 |
+
|
| 182 |
def test_returns_400_on_invalid_smiles(self, tmp_path: Path, monkeypatch):
|
| 183 |
artifact = self._setup_model_artifact(tmp_path)
|
| 184 |
monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
|