# hackathon/src/llm/explainer.py
"""Natural-language rationale for a single BBB prediction.
Public entry point: `explain(payload)`. Always returns a usable
ExplainResult — never raises. Tries OpenRouter first when a key is set
and the kill-switch is off; falls back to a deterministic template on
any failure (network, auth, rate limit, malformed response).
Test discipline: deterministic template path is the source of truth.
LLM path is env-gated and exercised by integration tests only.
"""
from __future__ import annotations
import os
from typing import Any, TypedDict
from src.core.logger import get_logger
logger = get_logger(__name__)
# Load .env (project root) so OPENROUTER_API_KEY etc. are available without
# the caller having to export them. Safe no-op if python-dotenv isn't
# installed or .env is missing. Existing env vars are NOT overridden.
try:
from dotenv import load_dotenv as _load_dotenv
_load_dotenv(override=False)
except ImportError:
pass
class FeatureRow(TypedDict):
feature: str
shap_value: float
class CalibrationDict(TypedDict):
threshold: float
precision: float
support: int
ExplainPayload = dict[str, Any] # Heterogeneous: BBB / EEG / MRI shapes differ.
class ExplainResult(TypedDict):
rationale: str
source: str # "llm" | "template"
model: str | None # llm model name when source="llm", else None
_OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
_LLM_TIMEOUT_SECONDS = 8.0
_LLM_MAX_TOKENS = 256
_LLM_TEMPERATURE = 0.3
# Free-tier fallback chain, smartest → smallest. When a model returns 429
# (rate-limit / daily-quota exhausted), 402 (credits), 404 (id retired) or
# 5xx (upstream), we advance to the next model. Network/timeout errors fall
# straight to the deterministic template — switching models won't help.
# Override at runtime via OPENROUTER_FREE_MODELS (comma-separated). Model
# availability on OpenRouter churns; verify with scripts/diagnose_openrouter.py.
# Last verified: 2026-05-02 via scripts/diagnose_openrouter.py.
# Entries marked "currently 429" have valid IDs but were quota-exhausted at
# probe time; kept because OpenRouter rate-limits are per-window and recover.
_DEFAULT_FREE_MODEL_CHAIN: tuple[str, ...] = (
"openai/gpt-oss-20b:free", # 20B — verified OK 2026-05-02
"inclusionai/ling-2.6-1t:free", # ~1T flagship — verified OK, returns content
"nvidia/nemotron-3-super-120b-a12b:free", # 120B — verified OK, returns content
"minimax/minimax-m2.5:free", # MoE — verified OK, returns content
"qwen/qwen3-next-80b-a3b-instruct:free", # 80B — currently 429 but valid id
"google/gemma-4-31b-it:free", # 31B — currently 429 but valid id
"google/gemma-4-26b-a4b-it:free", # 26B MoE — currently 429 but valid id
"tencent/hy3-preview:free", # MoE preview — verified OK
"nvidia/nemotron-3-nano-omni-30b-a3b-reasoning:free", # 30B reasoning — verified OK
"nvidia/nemotron-3-nano-30b-a3b:free", # 30B — verified OK
"poolside/laguna-xs.2:free", # smallest — verified OK
)
def _free_model_chain() -> tuple[str, ...]:
raw = os.environ.get("OPENROUTER_FREE_MODELS")
if raw:
ids = tuple(m.strip() for m in raw.split(",") if m.strip())
if ids:
return ids
return _DEFAULT_FREE_MODEL_CHAIN
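# Example override (hedged; the ids below are drawn from the default chain
# above and are illustrative, not a statement of current availability):
#   OPENROUTER_FREE_MODELS="openai/gpt-oss-20b:free,poolside/laguna-xs.2:free"
# yields:
#   _free_model_chain() == ("openai/gpt-oss-20b:free", "poolside/laguna-xs.2:free")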
def _should_use_llm() -> bool:
"""Gate: env kill-switch off AND key present."""
if os.environ.get("NEUROBRIDGE_DISABLE_LLM") == "1":
return False
if not os.environ.get("OPENROUTER_API_KEY"):
return False
return True
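# Gate truth table (restates _should_use_llm above; env values illustrative):
#   NEUROBRIDGE_DISABLE_LLM=1, any key          -> False (kill-switch wins)
#   kill-switch unset, OPENROUTER_API_KEY unset -> False (no key)
#   kill-switch unset, OPENROUTER_API_KEY set   -> True  (LLM path eligible)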
def _drift_interpretation(drift_z: float | None) -> str:
if drift_z is None:
return "drift unavailable"
mag = abs(drift_z)
if mag < 1.0:
return "within expected range"
if mag < 2.0:
return "mild distribution shift"
return "significant shift, retrain recommended"
def _template_explain_bbb(payload: ExplainPayload) -> str:
"""Deterministic, jury-friendly rationale for a single BBB prediction."""
label_text = payload.get("label_text", "unknown")
confidence = float(payload.get("confidence", 0.0))
top_features = payload.get("top_features") or []
# Sentence 1
sentences = [
f"Predicted **{label_text}** with {confidence * 100:.0f}% confidence."
]
# Sentence 2 (calibration, optional)
cal = payload.get("calibration")
if cal is not None:
thr_pct = float(cal["threshold"]) * 100
prec_pct = float(cal["precision"]) * 100
support = int(cal["support"])
if support > 0:
sentences.append(
f"Calibration: predictions in the ≥{thr_pct:.0f}% bin are "
f"correct {prec_pct:.0f}% of the time on held-out data "
f"(n={support})."
)
# Sentence 3 (top-3 SHAP features)
if top_features:
feat_strs = [
f"{row['feature']}{float(row['shap_value']):+.3f})"
for row in top_features[:3]
]
sentences.append(
f"Top SHAP attributions toward this label: {', '.join(feat_strs)}."
)
# Sentence 4 (drift, optional)
drift_z = payload.get("drift_z")
if drift_z is not None:
interp = _drift_interpretation(drift_z)
sentences.append(
f"Drift signal: trailing-100 confidence median is "
f"{float(drift_z):+.2f}σ from training distribution ({interp})."
)
return " ".join(sentences)
def _template_explain_eeg(payload: ExplainPayload) -> str:
"""Deterministic rationale for an EEG pipeline run."""
rows = payload.get("rows", 0)
columns = payload.get("columns", 0)
duration = float(payload.get("duration_sec", 0.0))
run_id = payload.get("mlflow_run_id") or "—"
sentences = [
f"EEG pipeline produced **{rows}** epochs × **{columns}** features "
f"in {duration:.1f}s.",
"ICA decomposed the signal and dropped components whose absolute "
"EOG correlation exceeded 0.5 (eye-blink artifacts).",
"Bandpass filter 0.5-40 Hz removed line noise and DC drift before ICA.",
f"Run id: `{run_id}` (use the Experiments tab to compare against "
"previous runs).",
]
return " ".join(sentences)
def _template_explain_mri(payload: ExplainPayload) -> str:
"""Deterministic rationale for an MRI ComBat-harmonization diagnostic."""
pre = float(payload.get("site_gap_pre", 0.0))
post = float(payload.get("site_gap_post", 0.0))
factor = float(payload.get("reduction_factor", 0.0))
n_subjects = int(payload.get("n_subjects", 0))
sentences = [
f"ComBat harmonization reduced the per-site mean gap from "
f"**{pre:.4f}** to **{post:.4f}** — a **{factor:.0f}×** collapse "
f"across **{n_subjects}** subjects on the first feature.",
"This is the quantified proof that scanner / acquisition-site bias "
"was removed: predictions trained on the harmonized features "
"generalize across hospitals instead of memorizing site identity.",
"The visual evidence is the per-site KDE convergence in the "
"Pre-ComBat → Post-ComBat panels (Streamlit MRI tab).",
]
return " ".join(sentences)
_TEMPLATE_DISPATCH = {
"bbb": _template_explain_bbb,
"eeg": _template_explain_eeg,
"mri": _template_explain_mri,
}
def _build_llm_prompt(payload: ExplainPayload, modality: str = "bbb") -> str:
"""Format the payload + user question into a single LLM prompt."""
headers = {
"bbb": (
"You are a clinical-ML explainer for a B2B blood-brain-barrier "
"permeability tool."
),
"eeg": (
"You are a clinical-ML explainer for an EEG signal-processing "
"pipeline (MNE-Python + ICA artifact removal)."
),
"mri": (
"You are a clinical-ML explainer for a multi-site MRI "
"harmonization pipeline (neuroHarmonize / ComBat)."
),
}
header = headers.get(modality, headers["bbb"])
raw_q = (payload.get("user_question") or "").strip()
# When the caller did not supply a question, default to the paper-style
# rationale prompt; this preserves the original behavior for /explain
# callers that just want a one-shot summary.
user_q = raw_q or "Explain the result in 2-4 sentences."
has_explicit_question = bool(raw_q)
body_lines: list[str] = []
if modality == "bbb":
top_features = payload.get("top_features") or []
top_lines = "\n".join(
f" - {row['feature']}: Δ{float(row['shap_value']):+.3f}"
for row in top_features[:5]
) or " - (none)"
drift_z = payload.get("drift_z")
drift_str = "n/a" if drift_z is None else f"{float(drift_z):+.2f}"
body_lines.append(
f"Prediction:\n"
f"- SMILES: {payload.get('smiles', '?')}\n"
f"- Verdict: {payload.get('label_text', '?')} "
f"({float(payload.get('confidence', 0.0)) * 100:.0f}% confident)\n"
f"- Top SHAP features (positive = pushed toward verdict):\n"
f"{top_lines}\n"
f"- Drift z-score: {drift_str}"
)
elif modality == "eeg":
body_lines.append(
f"EEG Pipeline Run:\n"
f"- Epochs produced: {payload.get('rows', 0)}\n"
f"- Features per epoch: {payload.get('columns', 0)}\n"
f"- Wall-clock: {float(payload.get('duration_sec', 0.0)):.2f}s\n"
f"- MLflow run id: {payload.get('mlflow_run_id') or 'n/a'}"
)
elif modality == "mri":
body_lines.append(
f"MRI ComBat Diagnostics:\n"
f"- Site-gap pre-ComBat: {float(payload.get('site_gap_pre', 0)):.4f}\n"
f"- Site-gap post-ComBat: {float(payload.get('site_gap_post', 0)):.4f}\n"
f"- Reduction factor: {float(payload.get('reduction_factor', 0)):.0f}×\n"
f"- Subjects: {int(payload.get('n_subjects', 0))}"
)
else:
# fallback uses BBB-shape prompt
body_lines.append(f"Payload: {payload!r}")
if has_explicit_question:
instructions = (
"Instructions:\n"
"- Respond in the SAME LANGUAGE as the user's question above "
"(Turkish question → Turkish answer, English → English, etc.).\n"
"- Directly answer the user's question using the data below; "
"do not default to a generic paper-style summary unless they "
"asked for one.\n"
"- If the question is conversational or off-topic (e.g. a "
"greeting), reply briefly and conversationally — do not force "
"a clinical rationale.\n"
"- Cite specific numbers from the data when relevant.\n"
"- No preamble, no apologies, just the answer."
)
else:
instructions = (
"Write a 2-4 sentence rationale a researcher could paste into "
"a paper. Avoid hedging; be specific about the numbers. "
"Respond with the rationale only, no preamble."
)
    body = "\n".join(body_lines)
    return (
        f"{header}\n\n"
        f"User question: {user_q}\n\n"
        f"Data for this prediction:\n{body}\n\n"
        f"{instructions}"
    )
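# Assembled prompt skeleton, for reference (mirrors the return above):
#   <header>
#
#   User question: <user_q>
#
#   Data for this prediction:
#   <body>
#
#   <instructions>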
def _llm_explain(payload: ExplainPayload, modality: str = "bbb") -> tuple[str, str] | None:
"""Try the OpenRouter chat completion across the free-tier fallback chain.
Returns (rationale, model_id) on first success, or None if every model
is exhausted / unreachable (caller falls back to the template).
"""
try:
# Local imports — keeps this dep optional at module load time.
from openai import (
OpenAI,
APIConnectionError,
APIStatusError,
APITimeoutError,
RateLimitError,
)
except ImportError as e:
logger.warning("openai SDK not importable: %s", e)
return None
api_key = os.environ.get("OPENROUTER_API_KEY")
if not api_key:
return None
client = OpenAI(
base_url=_OPENROUTER_BASE_URL,
api_key=api_key,
timeout=_LLM_TIMEOUT_SECONDS,
)
prompt = _build_llm_prompt(payload, modality)
chain = _free_model_chain()
for model in chain:
try:
completion = client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": prompt}],
max_tokens=_LLM_MAX_TOKENS,
temperature=_LLM_TEMPERATURE,
)
except RateLimitError:
logger.info("OpenRouter 429 on %s; advancing to next free model.", model)
continue
except APIStatusError as e:
status = getattr(e, "status_code", None)
# 401 = unauthorized — the key is bad, no model in this chain
# will succeed. Surface a loud, actionable hint and bail.
if status == 401:
logger.warning(
"OpenRouter 401 unauthorized on %s. The OPENROUTER_API_KEY "
"is rejected — verify it is current at "
"https://openrouter.ai/keys and that free-model data-sharing "
"is enabled at https://openrouter.ai/settings/privacy. "
"Falling back to deterministic template.",
model,
)
return None
            # 400 = request malformed for this specific model (e.g. it
            # rejected the message shape). Skip this model, try the next.
if status == 400:
logger.info(
"OpenRouter 400 on %s (likely prompt-shape mismatch); "
"advancing to next free model.", model,
)
continue
# 402 credits / 403 access / 404 retired-id / 5xx upstream → next.
if status in (402, 403, 404) or (status is not None and 500 <= status < 600):
logger.info("OpenRouter %s on %s; advancing to next free model.", status, model)
continue
logger.warning("LLM call failed on %s (%s); falling back to template.", model, e)
return None
except (APIConnectionError, APITimeoutError) as e:
# Network is global — switching models won't help.
logger.warning("LLM connection error (%s); falling back to template.", type(e).__name__)
return None
except Exception as e:
logger.warning("LLM unexpected error on %s (%s); falling back to template.", model, type(e).__name__)
return None
try:
text = completion.choices[0].message.content
except (AttributeError, IndexError, TypeError) as e:
logger.info("LLM response malformed on %s (%s); advancing to next model.", model, e)
continue
if not text or not text.strip():
logger.info("LLM returned empty rationale on %s; advancing to next model.", model)
continue
return text.strip(), model
logger.warning("All free models exhausted; falling back to template.")
return None
def explain(
payload: ExplainPayload, modality: str = "bbb",
) -> ExplainResult:
"""Return a natural-language rationale for a prediction or pipeline run.
`modality` selects the template family ('bbb' | 'eeg' | 'mri'). Unknown
values degrade to the BBB template with a warning log; the function
never raises.
"""
if modality not in _TEMPLATE_DISPATCH:
logger.warning(
"Unknown explain modality %r; falling back to bbb template.",
modality,
)
modality = "bbb"
if _should_use_llm():
        llm_out: tuple[str, str] | None = _llm_explain(payload, modality=modality)
if llm_out is not None:
rationale, model = llm_out
return ExplainResult(rationale=rationale, source="llm", model=model)
# else: fall through to template
template_fn = _TEMPLATE_DISPATCH[modality]
return ExplainResult(
rationale=template_fn(payload),
source="template",
model=None,
)
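if __name__ == "__main__":  # pragma: no cover
    # Minimal smoke sketch (hedged): exercises the deterministic template
    # path only; the payload values are illustrative. Run from the project
    # root so `src.core.logger` resolves, e.g. `python -m src.llm.explainer`.
    os.environ["NEUROBRIDGE_DISABLE_LLM"] = "1"  # force template path
    demo = explain(
        {
            "label_text": "BBB-permeable",
            "confidence": 0.87,
            "top_features": [
                {"feature": "TPSA", "shap_value": -0.31},
                {"feature": "LogP", "shap_value": 0.24},
            ],
            "drift_z": 0.4,
        },
        modality="bbb",
    )
    print(f"[{demo['source']}] {demo['rationale']}")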