mekosotto Claude Opus 4.7 (1M context) committed on
Commit e5c1c61 · 1 Parent(s): d69f171

feat(llm): explainer with deterministic template + OpenRouter fallback


- New module src/llm/explainer.py — single public entry point
explain(payload). Returns {rationale, source, model}. Never raises.
- Deterministic template (4 sentences: verdict, calibration if any,
top-3 SHAP, drift) is the source of truth for tests.
- LLM path: OpenRouter chat completions via openai==1.51.0 SDK,
model meta-llama/llama-3.2-3b-instruct:free, 8s timeout, 256 max
tokens, temperature 0.3. Gated by OPENROUTER_API_KEY presence and
NEUROBRIDGE_DISABLE_LLM=1 kill-switch.
- Fallback chain: env-disabled or missing key skips the LLM silently;
  SDK ImportError, API errors, and empty/malformed responses log a
  WARNING. Every failure path degrades to the template, source="template".
- 4 new tests: deterministic, top features included, label text
included, kill-switch overrides key.
- New pip dep: openai==1.51.0 (~600KB, transitive deps already present).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
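
A minimal usage sketch of the new entry point described above; the import path, result keys, and env gates come from the module added in this commit, while the payload values are illustrative:

from src.llm.explainer import explain

result = explain(
    {
        "smiles": "CCO",
        "label_text": "permeable",
        "confidence": 0.82,
        "top_features": [{"feature": "fp_341", "shap_value": 0.045}],
    }
)
# source is "llm" only when OPENROUTER_API_KEY is set and
# NEUROBRIDGE_DISABLE_LLM != "1"; every other path yields the
# deterministic template with model=None.
print(result["rationale"], result["source"], result["model"])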

requirements.txt CHANGED
@@ -39,3 +39,6 @@ httpx==0.27.2 # FastAPI test client
 
 # --- Frontend (B2B dashboard) ---
 streamlit==1.39.0
+
+# --- LLM provider (Day 7 explainer) ---
+openai==1.51.0  # OpenRouter SDK (Day-7 LLM explainer; deterministic-template fallback always available)
src/llm/__init__.py ADDED
@@ -0,0 +1,8 @@
+"""LLM-backed natural-language explainers (Day 7).
+
+`explain()` is the ONLY public entry point. It guarantees a non-empty
+rationale every call: tries OpenRouter when available, falls back to a
+deterministic template otherwise. The deterministic path is the source
+of truth for tests; the LLM path is gated behind env config.
+"""
+from src.llm.explainer import ExplainPayload, ExplainResult, explain  # noqa: F401
src/llm/explainer.py ADDED
@@ -0,0 +1,211 @@
+"""Natural-language rationale for a single BBB prediction.
+
+Public entry point: `explain(payload)`. Always returns a usable
+ExplainResult — never raises. Tries OpenRouter first when a key is set
+and the kill-switch is off; falls back to a deterministic template on
+any failure (network, auth, rate limit, malformed response).
+
+Test discipline: deterministic template path is the source of truth.
+LLM path is env-gated and exercised by integration tests only.
+"""
+from __future__ import annotations
+
+import os
+from typing import Any, TypedDict
+
+from src.core.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class FeatureRow(TypedDict):
+    feature: str
+    shap_value: float
+
+
+class CalibrationDict(TypedDict):
+    threshold: float
+    precision: float
+    support: int
+
+
+class ExplainPayload(TypedDict, total=False):
+    smiles: str
+    label: int
+    label_text: str
+    confidence: float
+    top_features: list[FeatureRow]
+    calibration: CalibrationDict | None
+    drift_z: float | None
+    user_question: str
+
+
+class ExplainResult(TypedDict):
+    rationale: str
+    source: str  # "llm" | "template"
+    model: str | None  # llm model name when source="llm", else None
+
+
+_OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
+_DEFAULT_MODEL = "meta-llama/llama-3.2-3b-instruct:free"
+_LLM_TIMEOUT_SECONDS = 8.0
+_LLM_MAX_TOKENS = 256
+_LLM_TEMPERATURE = 0.3
+
+
+def _should_use_llm() -> bool:
+    """Gate: env kill-switch off AND key present."""
+    if os.environ.get("NEUROBRIDGE_DISABLE_LLM") == "1":
+        return False
+    if not os.environ.get("OPENROUTER_API_KEY"):
+        return False
+    return True
+
+
+def _drift_interpretation(drift_z: float | None) -> str:
+    if drift_z is None:
+        return "drift unavailable"
+    mag = abs(drift_z)
+    if mag < 1.0:
+        return "within expected range"
+    if mag < 2.0:
+        return "mild distribution shift"
+    return "significant shift, retrain recommended"
+
+
+def _template_explain(payload: ExplainPayload) -> str:
+    """Deterministic, jury-friendly rationale. Never raises."""
+    label_text = payload.get("label_text", "unknown")
+    confidence = float(payload.get("confidence", 0.0))
+    top_features = payload.get("top_features") or []
+
+    # Sentence 1
+    sentences = [
+        f"Predicted **{label_text}** with {confidence * 100:.0f}% confidence."
+    ]
+
+    # Sentence 2 (calibration, optional)
+    cal = payload.get("calibration")
+    if cal is not None:
+        thr_pct = float(cal["threshold"]) * 100
+        prec_pct = float(cal["precision"]) * 100
+        support = int(cal["support"])
+        if support > 0:
+            sentences.append(
+                f"Calibration: predictions in the ≥{thr_pct:.0f}% bin are "
+                f"correct {prec_pct:.0f}% of the time on held-out data "
+                f"(n={support})."
+            )
+
+    # Sentence 3 (top-3 SHAP features)
+    if top_features:
+        feat_strs = [
+            f"{row['feature']} (Δ{float(row['shap_value']):+.3f})"
+            for row in top_features[:3]
+        ]
+        sentences.append(
+            f"Top SHAP attributions toward this label: {', '.join(feat_strs)}."
+        )
+
+    # Sentence 4 (drift, optional)
+    drift_z = payload.get("drift_z")
+    if drift_z is not None:
+        interp = _drift_interpretation(drift_z)
+        sentences.append(
+            f"Drift signal: trailing-100 confidence median is "
+            f"{float(drift_z):+.2f}σ from training distribution ({interp})."
+        )
+
+    return " ".join(sentences)
+
+
+def _build_llm_prompt(payload: ExplainPayload) -> str:
+    """Format the payload + user question into a single LLM prompt."""
+    top_features = payload.get("top_features") or []
+    top_lines = "\n".join(
+        f" - {row['feature']}: Δ{float(row['shap_value']):+.3f}"
+        for row in top_features[:5]
+    ) or " - (none)"
+    drift_z = payload.get("drift_z")
+    drift_str = "n/a" if drift_z is None else f"{float(drift_z):+.2f}"
+    user_q = payload.get("user_question") or (
+        "Explain the prediction in 2-4 sentences."
+    )
+    return (
+        "You are a clinical-ML explainer for a B2B blood-brain-barrier "
+        "permeability tool. Given the prediction details below, write a "
+        "2-4 sentence rationale a researcher could paste into a paper. "
+        "Use the SHAP attributions to justify the verdict. Mention drift "
+        "if abnormal. Avoid hedging; be specific about the numbers.\n\n"
+        f"Prediction:\n"
+        f"- SMILES: {payload.get('smiles', '?')}\n"
+        f"- Verdict: {payload.get('label_text', '?')} "
+        f"({float(payload.get('confidence', 0.0)) * 100:.0f}% confident)\n"
+        f"- Top SHAP features (positive = pushed toward verdict):\n"
+        f"{top_lines}\n"
+        f"- Drift z-score: {drift_str}\n"
+        f"\nUser question: {user_q}\n"
+        f"\nRespond with the rationale only, no preamble."
+    )
+
+
+def _llm_explain(payload: ExplainPayload) -> tuple[str, str] | None:
+    """Try the OpenRouter chat completion. Return (rationale, model) or None."""
+    try:
+        # Local import — keeps this dep optional at module load time.
+        from openai import OpenAI
+    except ImportError as e:
+        logger.warning("openai SDK not importable: %s", e)
+        return None
+
+    api_key = os.environ.get("OPENROUTER_API_KEY")
+    if not api_key:
+        return None
+
+    client = OpenAI(
+        base_url=_OPENROUTER_BASE_URL,
+        api_key=api_key,
+        timeout=_LLM_TIMEOUT_SECONDS,
+    )
+    prompt = _build_llm_prompt(payload)
+    try:
+        completion = client.chat.completions.create(
+            model=_DEFAULT_MODEL,
+            messages=[{"role": "user", "content": prompt}],
+            max_tokens=_LLM_MAX_TOKENS,
+            temperature=_LLM_TEMPERATURE,
+        )
+    except Exception as e:  # broad: APITimeoutError, APIConnectionError, RateLimitError, ...
+        logger.warning("LLM call failed (%s); falling back to template.", type(e).__name__)
+        return None
+
+    try:
+        text = completion.choices[0].message.content
+    except (AttributeError, IndexError, TypeError) as e:
+        logger.warning("LLM response malformed (%s); falling back to template.", e)
+        return None
+
+    if not text or not text.strip():
+        logger.warning("LLM returned empty rationale; falling back to template.")
+        return None
+
+    return text.strip(), _DEFAULT_MODEL
+
+
+def explain(payload: ExplainPayload) -> ExplainResult:
+    """Return a natural-language rationale for a BBB prediction.
+
+    Tries the LLM first when env-permitted; falls back to a deterministic
+    template on any failure. Never raises.
+    """
+    if _should_use_llm():
+        llm_out: Any = _llm_explain(payload)
+        if llm_out is not None:
+            rationale, model = llm_out
+            return ExplainResult(rationale=rationale, source="llm", model=model)
+        # else: fall through to template
+    return ExplainResult(
+        rationale=_template_explain(payload),
+        source="template",
+        model=None,
+    )
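
For reference, with the kill-switch set explain() takes the deterministic path; a quick smoke-check sketch using the same payload as the tests below, where the expected string in the comment is derived from _template_explain as committed above:

import os
os.environ["NEUROBRIDGE_DISABLE_LLM"] = "1"  # force the deterministic path

from src.llm.explainer import explain

payload = {
    "label_text": "permeable",
    "confidence": 0.82,
    "top_features": [
        {"feature": "fp_341", "shap_value": 0.045},
        {"feature": "fp_902", "shap_value": -0.031},
        {"feature": "fp_77", "shap_value": 0.022},
    ],
    "calibration": {"threshold": 0.80, "precision": 0.92, "support": 18},
    "drift_z": 0.42,
}
print(explain(payload)["rationale"])
# Predicted **permeable** with 82% confidence. Calibration: predictions in
# the ≥80% bin are correct 92% of the time on held-out data (n=18). Top SHAP
# attributions toward this label: fp_341 (Δ+0.045), fp_902 (Δ-0.031), fp_77
# (Δ+0.022). Drift signal: trailing-100 confidence median is +0.42σ from
# training distribution (within expected range).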
tests/llm/__init__.py ADDED
(empty file)
tests/llm/test_explainer.py ADDED
@@ -0,0 +1,70 @@
+"""Tests for src.llm.explainer.
+
+The deterministic template path is exhaustively tested here. The LLM
+path is exercised only by env-gated integration tests in
+test_explainer_integration.py (NOT run in CI by default).
+"""
+from __future__ import annotations
+
+import os
+
+import pytest
+
+from src.llm.explainer import ExplainPayload, explain
+
+
+def _payload(**overrides) -> ExplainPayload:
+    """Build a representative ExplainPayload; overrides win."""
+    base: ExplainPayload = {
+        "smiles": "CCO",
+        "label": 1,
+        "label_text": "permeable",
+        "confidence": 0.82,
+        "top_features": [
+            {"feature": "fp_341", "shap_value": 0.045},
+            {"feature": "fp_902", "shap_value": -0.031},
+            {"feature": "fp_77", "shap_value": 0.022},
+        ],
+        "calibration": {"threshold": 0.80, "precision": 0.92, "support": 18},
+        "drift_z": 0.42,
+        "user_question": "Why was this molecule predicted as permeable?",
+    }
+    base.update(overrides)
+    return base
+
+
+class TestTemplateExplain:
+    """Day-7 T3A: deterministic-template path of the explainer."""
+
+    def test_template_path_is_deterministic(self, monkeypatch):
+        """Same input → byte-identical rationale string. No randomness."""
+        monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
+        out_a = explain(_payload())
+        out_b = explain(_payload())
+        assert out_a["rationale"] == out_b["rationale"]
+        assert out_a["source"] == "template"
+        assert out_b["source"] == "template"
+        assert out_a["model"] is None
+
+    def test_template_includes_top_feature_names(self, monkeypatch):
+        """Rationale must mention the SHAP features so jurors see attribution."""
+        monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
+        result = explain(_payload())
+        for feat in ("fp_341", "fp_902", "fp_77"):
+            assert feat in result["rationale"], (
+                f"expected feature {feat!r} in rationale, got {result['rationale']!r}"
+            )
+
+    def test_template_includes_label_text(self, monkeypatch):
+        """The verdict word ('permeable' / 'non-permeable') must appear."""
+        monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
+        result = explain(_payload(label=0, label_text="non-permeable"))
+        assert "non-permeable" in result["rationale"]
+
+    def test_disable_flag_forces_template_even_with_key_set(self, monkeypatch):
+        """NEUROBRIDGE_DISABLE_LLM=1 wins over OPENROUTER_API_KEY presence."""
+        monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
+        monkeypatch.setenv("OPENROUTER_API_KEY", "sk-fake-not-used")
+        result = explain(_payload())
+        assert result["source"] == "template"
+        assert result["model"] is None
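
The env-gated integration test referenced in the docstring (test_explainer_integration.py) is not part of this commit; a rough sketch of what it could look like, assuming a pytest skipif guard on the same two env vars:

import os

import pytest

from src.llm.explainer import explain

requires_live_llm = pytest.mark.skipif(
    not os.environ.get("OPENROUTER_API_KEY")
    or os.environ.get("NEUROBRIDGE_DISABLE_LLM") == "1",
    reason="OpenRouter key missing or LLM disabled; template path covered by unit tests",
)


@requires_live_llm
def test_llm_path_returns_llm_source():
    # Hits the real OpenRouter endpoint; intentionally not run in CI by default.
    result = explain({"smiles": "CCO", "label_text": "permeable", "confidence": 0.82})
    assert result["source"] == "llm"
    assert result["model"] is not None
    assert result["rationale"].strip()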