mekosotto committed on
Commit
7dad1a9
·
1 Parent(s): 3cc6a7d

feat(models): SHAP top-k explainer for BBB predictions

Browse files
src/models/bbb_model.py CHANGED
@@ -125,3 +125,83 @@ def predict_with_proba(
125
  "label": label,
126
  "confidence": float(proba[label_idx]),
127
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  "label": label,
126
  "confidence": float(proba[label_idx]),
127
  }
128
+
129
+
130
def explain_prediction(
    model: RandomForestClassifier,
    smiles: str,
    top_k: int = 5,
    n_bits: int = 2048,
    radius: int = 2,
) -> list[dict[str, object]]:
    """Compute the `top_k` strongest SHAP attributions for one molecule.

    `shap.TreeExplainer` is used (exact for tree ensembles — no sampling),
    and the attribution vector is taken for the class the model actually
    predicts, i.e. the values that pushed the model toward the label that
    `predict_with_proba` would report.

    Feature names come from `model._neurobridge_fp_cols` when present
    (written by `train()`); otherwise generic `fp_<index>` names are
    synthesized, which keeps joblib-loaded models without that attribute
    usable.

    Args:
        model: Fitted classifier from `train()` or `load()`.
        smiles: A SMILES string (validated via `is_valid_smiles`).
        top_k: Number of attributions to return. Default 5 — the
            jury-demo budget (more bars make the waterfall chart noisy).
        n_bits / radius: Must match the fingerprint settings used at
            training time.

    Returns:
        List of `{"feature": "fp_<bit_idx>", "shap_value": float}` dicts
        ordered by descending `abs(shap_value)`.

    Raises:
        ValueError: if `smiles` cannot be parsed by RDKit.
    """
    import shap  # heavy dependency — deferred until actually needed

    if not is_valid_smiles(smiles):
        raise ValueError(f"invalid SMILES: {smiles!r}")
    X = compute_morgan_fingerprint(smiles, n_bits=n_bits, radius=radius).reshape(1, -1)

    explainer = shap.TreeExplainer(model)
    # Additivity check disabled: quantized uint8 fingerprints make SHAP's
    # base + sum == output reconstruction miss its tolerance on tree
    # ensembles, producing false-positive errors despite correct values.
    raw = explainer.shap_values(X, check_additivity=False)

    def _predicted_class() -> int:
        # Index of the label the model assigns to this sample.
        return int(np.argmax(model.predict_proba(X)[0]))

    # The return layout of `shap_values` differs across shap / sklearn
    # versions; normalize all three known shapes to a 1-D per-feature vector.
    if isinstance(raw, list):
        # legacy shap: one (1, n_features) array per class
        contributions = raw[_predicted_class()][0]
    else:
        arr = np.asarray(raw)
        # newer shap: (1, n_features, n_classes), or already-condensed
        # (1, n_features)
        contributions = arr[0, :, _predicted_class()] if arr.ndim == 3 else arr[0]

    if hasattr(model, "_neurobridge_fp_cols"):
        names = list(model._neurobridge_fp_cols)
    else:
        names = [f"fp_{i}" for i in range(len(contributions))]

    # strict=True surfaces any name/feature length mismatch loudly.
    ranked = sorted(
        zip(names, contributions, strict=True),
        key=lambda pair: abs(pair[1]),
        reverse=True,
    )
    return [
        {"feature": str(feature), "shap_value": float(score)}
        for feature, score in ranked[:top_k]
    ]
tests/models/test_bbb_model.py CHANGED
@@ -89,3 +89,41 @@ class TestPredictWithProba:
89
  raw_proba = model.predict_proba(fp)[0]
90
  result = bbb_model.predict_with_proba(model, "CCO")
91
  assert abs(result["confidence"] - float(max(raw_proba))) < 1e-9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  raw_proba = model.predict_proba(fp)[0]
90
  result = bbb_model.predict_with_proba(model, "CCO")
91
  assert abs(result["confidence"] - float(max(raw_proba))) < 1e-9
92
+
93
+
94
class TestExplainPrediction:
    def test_returns_top_k_features(self, trained_model_and_features):
        model, _ = trained_model_and_features
        result = bbb_model.explain_prediction(model, "CCO", top_k=5)
        assert len(result) == 5
        for entry in result:
            assert "feature" in entry
            assert "shap_value" in entry
            assert isinstance(entry["shap_value"], float)

    def test_features_sorted_by_absolute_shap_value_descending(
        self, trained_model_and_features,
    ):
        model, _ = trained_model_and_features
        result = bbb_model.explain_prediction(model, "CCO", top_k=10)
        magnitudes = [abs(entry["shap_value"]) for entry in result]
        assert magnitudes == sorted(magnitudes, reverse=True)

    def test_features_named_fp_INDEX(self, trained_model_and_features):
        model, _ = trained_model_and_features
        result = bbb_model.explain_prediction(model, "CCO", top_k=3)
        for entry in result:
            assert entry["feature"].startswith("fp_")
            int(entry["feature"].split("_")[1])  # parses cleanly

    def test_raises_on_invalid_smiles(self, trained_model_and_features):
        model, _ = trained_model_and_features
        with pytest.raises(ValueError):
            bbb_model.explain_prediction(model, "still_not_a_smiles", top_k=5)

    def test_deterministic_output(self, trained_model_and_features):
        """AGENTS.md §4 rule 3: identical input → identical SHAP attributions."""
        model, _ = trained_model_and_features
        first = bbb_model.explain_prediction(model, "CCO", top_k=5)
        second = bbb_model.explain_prediction(model, "CCO", top_k=5)
        assert first == second