feat(llm): modality dispatch — explain(payload, modality) for BBB/EEG/MRI
- explain() gains modality kwarg ('bbb' | 'eeg' | 'mri'), default 'bbb'
for backward compat with Day-7 callers.
- _template_explain renamed to _template_explain_bbb; added
_template_explain_eeg (epochs, features, ICA story) and
_template_explain_mri (site-gap pre/post, reduction factor).
- _build_llm_prompt branches on modality with a domain-specific header
+ body. Unknown modality logs warning and falls back to BBB template.
- ExplainPayload loosened from strict TypedDict to dict[str, Any] since
shapes differ across modalities.
- 3 new tests (TestEEGTemplate, TestMRITemplate, TestModalityDispatch).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- src/llm/explainer.py +128 -44
- tests/llm/test_explainer.py +60 -0
src/llm/explainer.py
CHANGED
|
@@ -29,15 +29,7 @@ class CalibrationDict(TypedDict):
|
|
| 29 |
support: int
|
| 30 |
|
| 31 |
|
| 32 |
-
|
| 33 |
-
smiles: str
|
| 34 |
-
label: int
|
| 35 |
-
label_text: str
|
| 36 |
-
confidence: float
|
| 37 |
-
top_features: list[FeatureRow]
|
| 38 |
-
calibration: CalibrationDict | None
|
| 39 |
-
drift_z: float | None
|
| 40 |
-
user_question: str
|
| 41 |
|
| 42 |
|
| 43 |
class ExplainResult(TypedDict):
|
|
@@ -73,8 +65,8 @@ def _drift_interpretation(drift_z: float | None) -> str:
|
|
| 73 |
return "significant shift, retrain recommended"
|
| 74 |
|
| 75 |
|
| 76 |
-
def
|
| 77 |
-
"""Deterministic, jury-friendly rationale
|
| 78 |
label_text = payload.get("label_text", "unknown")
|
| 79 |
confidence = float(payload.get("confidence", 0.0))
|
| 80 |
top_features = payload.get("top_features") or []
|
|
@@ -119,37 +111,117 @@ def _template_explain(payload: ExplainPayload) -> str:
|
|
| 119 |
return " ".join(sentences)
|
| 120 |
|
| 121 |
|
| 122 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
"""Format the payload + user question into a single LLM prompt."""
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
return (
|
| 135 |
-
"
|
| 136 |
-
"
|
| 137 |
-
"
|
| 138 |
-
"
|
| 139 |
-
"
|
| 140 |
-
f"
|
| 141 |
-
f"- SMILES: {payload.get('smiles', '?')}\n"
|
| 142 |
-
f"- Verdict: {payload.get('label_text', '?')} "
|
| 143 |
-
f"({float(payload.get('confidence', 0.0)) * 100:.0f}% confident)\n"
|
| 144 |
-
f"- Top SHAP features (positive = pushed toward verdict):\n"
|
| 145 |
-
f"{top_lines}\n"
|
| 146 |
-
f"- Drift z-score: {drift_str}\n"
|
| 147 |
-
f"\nUser question: {user_q}\n"
|
| 148 |
-
f"\nRespond with the rationale only, no preamble."
|
| 149 |
)
|
| 150 |
|
| 151 |
|
| 152 |
-
def _llm_explain(payload: ExplainPayload) -> tuple[str, str] | None:
|
| 153 |
"""Try the OpenRouter chat completion. Return (rationale, model) or None."""
|
| 154 |
try:
|
| 155 |
# Local import — keeps this dep optional at module load time.
|
|
@@ -167,7 +239,7 @@ def _llm_explain(payload: ExplainPayload) -> tuple[str, str] | None:
|
|
| 167 |
api_key=api_key,
|
| 168 |
timeout=_LLM_TIMEOUT_SECONDS,
|
| 169 |
)
|
| 170 |
-
prompt = _build_llm_prompt(payload)
|
| 171 |
try:
|
| 172 |
completion = client.chat.completions.create(
|
| 173 |
model=_DEFAULT_MODEL,
|
|
@@ -192,20 +264,32 @@ def _llm_explain(payload: ExplainPayload) -> tuple[str, str] | None:
|
|
| 192 |
return text.strip(), _DEFAULT_MODEL
|
| 193 |
|
| 194 |
|
| 195 |
-
def explain(
|
| 196 |
-
|
|
|
|
|
|
|
| 197 |
|
| 198 |
-
|
| 199 |
-
template
|
|
|
|
| 200 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
if _should_use_llm():
|
| 202 |
-
llm_out: Any = _llm_explain(payload)
|
| 203 |
if llm_out is not None:
|
| 204 |
rationale, model = llm_out
|
| 205 |
return ExplainResult(rationale=rationale, source="llm", model=model)
|
| 206 |
# else: fall through to template
|
|
|
|
|
|
|
| 207 |
return ExplainResult(
|
| 208 |
-
rationale=
|
| 209 |
source="template",
|
| 210 |
model=None,
|
| 211 |
)
|
|
|
|
| 29 |
support: int
|
| 30 |
|
| 31 |
|
| 32 |
# Loosened from a strict TypedDict on purpose: BBB, EEG and MRI payloads
# carry disjoint key sets, so a heterogeneous mapping is the honest type.
ExplainPayload = dict[str, Any]  # Heterogeneous: BBB / EEG / MRI shapes differ.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
class ExplainResult(TypedDict):
|
|
|
|
| 65 |
return "significant shift, retrain recommended"
|
| 66 |
|
| 67 |
|
| 68 |
+
def _template_explain_bbb(payload: ExplainPayload) -> str:
|
| 69 |
+
"""Deterministic, jury-friendly rationale for a single BBB prediction."""
|
| 70 |
label_text = payload.get("label_text", "unknown")
|
| 71 |
confidence = float(payload.get("confidence", 0.0))
|
| 72 |
top_features = payload.get("top_features") or []
|
|
|
|
| 111 |
return " ".join(sentences)
|
| 112 |
|
| 113 |
|
| 114 |
+
def _template_explain_eeg(payload: ExplainPayload) -> str:
|
| 115 |
+
"""Deterministic rationale for an EEG pipeline run."""
|
| 116 |
+
rows = payload.get("rows", 0)
|
| 117 |
+
columns = payload.get("columns", 0)
|
| 118 |
+
duration = float(payload.get("duration_sec", 0.0))
|
| 119 |
+
run_id = payload.get("mlflow_run_id") or "—"
|
| 120 |
+
sentences = [
|
| 121 |
+
f"EEG pipeline produced **{rows}** epochs × **{columns}** features "
|
| 122 |
+
f"in {duration:.1f}s.",
|
| 123 |
+
"ICA decomposed the signal and dropped components whose absolute "
|
| 124 |
+
"EOG correlation exceeded 0.5 (eye-blink artifacts).",
|
| 125 |
+
"Bandpass filter 0.5-40 Hz removed line noise and DC drift before ICA.",
|
| 126 |
+
f"Run id: `{run_id}` (use the Experiments tab to compare against "
|
| 127 |
+
"previous runs).",
|
| 128 |
+
]
|
| 129 |
+
return " ".join(sentences)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _template_explain_mri(payload: ExplainPayload) -> str:
|
| 133 |
+
"""Deterministic rationale for an MRI ComBat-harmonization diagnostic."""
|
| 134 |
+
pre = float(payload.get("site_gap_pre", 0.0))
|
| 135 |
+
post = float(payload.get("site_gap_post", 0.0))
|
| 136 |
+
factor = float(payload.get("reduction_factor", 0.0))
|
| 137 |
+
n_subjects = int(payload.get("n_subjects", 0))
|
| 138 |
+
sentences = [
|
| 139 |
+
f"ComBat harmonization reduced the per-site mean gap from "
|
| 140 |
+
f"**{pre:.4f}** to **{post:.4f}** — a **{factor:.0f}×** collapse "
|
| 141 |
+
f"across **{n_subjects}** subjects on the first feature.",
|
| 142 |
+
"This is the quantified proof that scanner / acquisition-site bias "
|
| 143 |
+
"was removed: predictions trained on the harmonized features "
|
| 144 |
+
"generalize across hospitals instead of memorizing site identity.",
|
| 145 |
+
"The visual evidence is the per-site KDE convergence in the "
|
| 146 |
+
"Pre-ComBat → Post-ComBat panels (Streamlit MRI tab).",
|
| 147 |
+
]
|
| 148 |
+
return " ".join(sentences)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# Modality → deterministic template renderer. explain() checks membership in
# this dict before indexing, so an unknown key never reaches the lookup.
_TEMPLATE_DISPATCH = {
    "bbb": _template_explain_bbb,
    "eeg": _template_explain_eeg,
    "mri": _template_explain_mri,
}
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _build_llm_prompt(payload: ExplainPayload, modality: str = "bbb") -> str:
|
| 159 |
"""Format the payload + user question into a single LLM prompt."""
|
| 160 |
+
headers = {
|
| 161 |
+
"bbb": (
|
| 162 |
+
"You are a clinical-ML explainer for a B2B blood-brain-barrier "
|
| 163 |
+
"permeability tool."
|
| 164 |
+
),
|
| 165 |
+
"eeg": (
|
| 166 |
+
"You are a clinical-ML explainer for an EEG signal-processing "
|
| 167 |
+
"pipeline (MNE-Python + ICA artifact removal)."
|
| 168 |
+
),
|
| 169 |
+
"mri": (
|
| 170 |
+
"You are a clinical-ML explainer for a multi-site MRI "
|
| 171 |
+
"harmonization pipeline (neuroHarmonize / ComBat)."
|
| 172 |
+
),
|
| 173 |
+
}
|
| 174 |
+
header = headers.get(modality, headers["bbb"])
|
| 175 |
+
user_q = payload.get("user_question") or "Explain the result in 2-4 sentences."
|
| 176 |
+
body_lines: list[str] = []
|
| 177 |
+
if modality == "bbb":
|
| 178 |
+
top_features = payload.get("top_features") or []
|
| 179 |
+
top_lines = "\n".join(
|
| 180 |
+
f" - {row['feature']}: Δ{float(row['shap_value']):+.3f}"
|
| 181 |
+
for row in top_features[:5]
|
| 182 |
+
) or " - (none)"
|
| 183 |
+
drift_z = payload.get("drift_z")
|
| 184 |
+
drift_str = "n/a" if drift_z is None else f"{float(drift_z):+.2f}"
|
| 185 |
+
body_lines.append(
|
| 186 |
+
f"Prediction:\n"
|
| 187 |
+
f"- SMILES: {payload.get('smiles', '?')}\n"
|
| 188 |
+
f"- Verdict: {payload.get('label_text', '?')} "
|
| 189 |
+
f"({float(payload.get('confidence', 0.0)) * 100:.0f}% confident)\n"
|
| 190 |
+
f"- Top SHAP features (positive = pushed toward verdict):\n"
|
| 191 |
+
f"{top_lines}\n"
|
| 192 |
+
f"- Drift z-score: {drift_str}"
|
| 193 |
+
)
|
| 194 |
+
elif modality == "eeg":
|
| 195 |
+
body_lines.append(
|
| 196 |
+
f"EEG Pipeline Run:\n"
|
| 197 |
+
f"- Epochs produced: {payload.get('rows', 0)}\n"
|
| 198 |
+
f"- Features per epoch: {payload.get('columns', 0)}\n"
|
| 199 |
+
f"- Wall-clock: {float(payload.get('duration_sec', 0.0)):.2f}s\n"
|
| 200 |
+
f"- MLflow run id: {payload.get('mlflow_run_id') or 'n/a'}"
|
| 201 |
+
)
|
| 202 |
+
elif modality == "mri":
|
| 203 |
+
body_lines.append(
|
| 204 |
+
f"MRI ComBat Diagnostics:\n"
|
| 205 |
+
f"- Site-gap pre-ComBat: {float(payload.get('site_gap_pre', 0)):.4f}\n"
|
| 206 |
+
f"- Site-gap post-ComBat: {float(payload.get('site_gap_post', 0)):.4f}\n"
|
| 207 |
+
f"- Reduction factor: {float(payload.get('reduction_factor', 0)):.0f}×\n"
|
| 208 |
+
f"- Subjects: {int(payload.get('n_subjects', 0))}"
|
| 209 |
+
)
|
| 210 |
+
else:
|
| 211 |
+
# fallback uses BBB-shape prompt
|
| 212 |
+
body_lines.append(f"Payload: {payload!r}")
|
| 213 |
+
|
| 214 |
return (
|
| 215 |
+
f"{header} Given the details below, write a 2-4 sentence rationale a "
|
| 216 |
+
f"researcher could paste into a paper. Avoid hedging; be specific "
|
| 217 |
+
f"about the numbers.\n\n"
|
| 218 |
+
f"{body_lines[0]}\n\n"
|
| 219 |
+
f"User question: {user_q}\n\n"
|
| 220 |
+
f"Respond with the rationale only, no preamble."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
)
|
| 222 |
|
| 223 |
|
| 224 |
+
def _llm_explain(payload: ExplainPayload, modality: str = "bbb") -> tuple[str, str] | None:
|
| 225 |
"""Try the OpenRouter chat completion. Return (rationale, model) or None."""
|
| 226 |
try:
|
| 227 |
# Local import — keeps this dep optional at module load time.
|
|
|
|
| 239 |
api_key=api_key,
|
| 240 |
timeout=_LLM_TIMEOUT_SECONDS,
|
| 241 |
)
|
| 242 |
+
prompt = _build_llm_prompt(payload, modality)
|
| 243 |
try:
|
| 244 |
completion = client.chat.completions.create(
|
| 245 |
model=_DEFAULT_MODEL,
|
|
|
|
| 264 |
return text.strip(), _DEFAULT_MODEL
|
| 265 |
|
| 266 |
|
| 267 |
+
def explain(
    payload: ExplainPayload, modality: str = "bbb",
) -> ExplainResult:
    """Return a natural-language rationale for a prediction or pipeline run.

    ``modality`` selects the template family ('bbb' | 'eeg' | 'mri').
    Unrecognized values are logged and downgraded to the BBB template; this
    function never raises.
    """
    if modality not in _TEMPLATE_DISPATCH:
        logger.warning(
            "Unknown explain modality %r; falling back to bbb template.",
            modality,
        )
        modality = "bbb"

    if _should_use_llm():
        llm_result: Any = _llm_explain(payload, modality=modality)
        if llm_result is not None:
            text, model_name = llm_result
            return ExplainResult(rationale=text, source="llm", model=model_name)
        # LLM unavailable or failed — degrade to the deterministic template.

    return ExplainResult(
        rationale=_TEMPLATE_DISPATCH[modality](payload),
        source="template",
        model=None,
    )
|
tests/llm/test_explainer.py
CHANGED
|
@@ -68,3 +68,63 @@ class TestTemplateExplain:
|
|
| 68 |
result = explain(_payload())
|
| 69 |
assert result["source"] == "template"
|
| 70 |
assert result["model"] is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
result = explain(_payload())
|
| 69 |
assert result["source"] == "template"
|
| 70 |
assert result["model"] is None
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class TestEEGTemplate:
    """Day-8 T1A: deterministic EEG template path."""

    def test_eeg_template_uses_pipeline_metrics(self, monkeypatch):
        # Force the template path — no network, no LLM.
        monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
        sample = dict(
            rows=30,
            columns=95,
            duration_sec=4.32,
            mlflow_run_id="abc12345",
            user_question="Why were epochs dropped?",
        )
        outcome = explain(sample, modality="eeg")
        assert outcome["source"] == "template"
        assert outcome["model"] is None
        text = outcome["rationale"]
        assert "30" in text, "epoch count must appear"
        assert "95" in text, "feature count must appear"
        assert "4.3" in text, "duration must appear (1-decimal)"
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class TestMRITemplate:
    """Day-8 T1A: deterministic MRI template path."""

    def test_mri_template_uses_combat_metrics(self, monkeypatch):
        # Force the template path — no network, no LLM.
        monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
        sample = dict(
            site_gap_pre=5.0004,
            site_gap_post=0.0015,
            reduction_factor=3290.0,
            n_subjects=6,
            user_question="Why does ComBat matter?",
        )
        outcome = explain(sample, modality="mri")
        assert outcome["source"] == "template"
        text = outcome["rationale"]
        assert "5.00" in text or "5.0" in text, "pre-gap must appear"
        assert "3290" in text or "3290×" in text, "reduction factor must appear"
        assert "6" in text, "n_subjects must appear"
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class TestModalityDispatch:
    """Day-8 T1A: explain(modality=…) routes to the right template."""

    def test_unknown_modality_falls_back_to_bbb_template(self, monkeypatch):
        """Defensive: an unknown modality string degrades gracefully (warn + bbb-style template)."""
        monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
        sample = dict(
            smiles="CCO",
            label=1,
            label_text="permeable",
            confidence=0.82,
            top_features=[dict(feature="fp_1", shap_value=0.05)],
        )
        outcome = explain(sample, modality="unknown_xyz")
        # Must not raise and must yield a populated template rationale.
        assert outcome["source"] == "template"
        assert outcome["rationale"], "rationale must be non-empty"
|