feat(models): SHAP top-k explainer for BBB predictions
- src/models/bbb_model.py +80 -0
- tests/models/test_bbb_model.py +38 -0
src/models/bbb_model.py
CHANGED
|
@@ -125,3 +125,83 @@ def predict_with_proba(
|
|
| 125 |
"label": label,
|
| 126 |
"confidence": float(proba[label_idx]),
|
| 127 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
"label": label,
|
| 126 |
"confidence": float(proba[label_idx]),
|
| 127 |
}
|
def _predicted_class_index(model: RandomForestClassifier, X) -> int:
    """Return the class index with the highest predicted probability for the single row in `X`."""
    return int(np.argmax(model.predict_proba(X)[0]))


def explain_prediction(
    model: RandomForestClassifier,
    smiles: str,
    top_k: int = 5,
    n_bits: int = 2048,
    radius: int = 2,
) -> list[dict[str, object]]:
    """Return the top-`top_k` feature attributions (SHAP values) for `smiles`.

    Uses `shap.TreeExplainer` (exact for tree ensembles, no sampling). The
    explanation is for the *predicted* class — i.e. SHAP values that pushed
    the model toward whichever label was returned by `predict_with_proba`.

    Reads fingerprint column names from `model._neurobridge_fp_cols` (set by
    `train()`). Falls back to `fp_<index>` if the attribute is missing — useful
    for models loaded from a joblib without the project-owned attribute.

    Args:
        model: Fitted classifier from `train()` or `load()`.
        smiles: A SMILES string (validated via `is_valid_smiles`).
        top_k: How many top features to return. Default 5 — matches the
            jury-demo budget (more bars = noisier waterfall chart).
        n_bits / radius: Must match training-time fingerprint settings.

    Returns:
        A list of `{"feature": "fp_<bit_idx>", "shap_value": float}` dicts,
        sorted by `abs(shap_value)` descending.

    Raises:
        ValueError: if `smiles` cannot be parsed by RDKit.
    """
    import shap  # local import — heavy module, only loaded when needed

    if not is_valid_smiles(smiles):
        raise ValueError(f"invalid SMILES: {smiles!r}")
    fp = compute_morgan_fingerprint(smiles, n_bits=n_bits, radius=radius)
    X = fp.reshape(1, -1)

    explainer = shap.TreeExplainer(model)
    # uint8 fingerprints cause benign additivity violations in SHAP's
    # reconstruction (base + sum != model output within tolerance); the
    # default check produces false-positive errors on tree ensembles
    # over quantized inputs, so we skip it.
    shap_values = explainer.shap_values(X, check_additivity=False)

    # `shap_values` shape varies by sklearn / shap versions:
    # - older: list of (1, n_features) arrays, one per class
    # - newer: ndarray of shape (1, n_features, n_classes) for binary RF
    # - or (1, n_features) when output already condensed
    if isinstance(shap_values, list):
        per_feature = shap_values[_predicted_class_index(model, X)][0]
    else:
        arr = np.asarray(shap_values)
        if arr.ndim == 3:
            # (1, n_features, n_classes)
            per_feature = arr[0, :, _predicted_class_index(model, X)]
        else:
            # (1, n_features)
            per_feature = arr[0]

    fp_cols = (
        list(model._neurobridge_fp_cols)
        if hasattr(model, "_neurobridge_fp_cols")
        else [f"fp_{i}" for i in range(len(per_feature))]
    )

    # strict=True surfaces a column-count mismatch (wrong n_bits) loudly
    # rather than silently truncating the attribution list.
    pairs = sorted(
        zip(fp_cols, per_feature, strict=True),
        key=lambda p: abs(p[1]),
        reverse=True,
    )
    return [
        {"feature": str(name), "shap_value": float(value)}
        for name, value in pairs[:top_k]
    ]
tests/models/test_bbb_model.py
CHANGED
|
@@ -89,3 +89,41 @@ class TestPredictWithProba:
|
|
| 89 |
raw_proba = model.predict_proba(fp)[0]
|
| 90 |
result = bbb_model.predict_with_proba(model, "CCO")
|
| 91 |
assert abs(result["confidence"] - float(max(raw_proba))) < 1e-9
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
raw_proba = model.predict_proba(fp)[0]
|
| 90 |
result = bbb_model.predict_with_proba(model, "CCO")
|
| 91 |
assert abs(result["confidence"] - float(max(raw_proba))) < 1e-9
|
class TestExplainPrediction:
    def test_returns_top_k_features(self, trained_model_and_features):
        """Exactly k entries come back, each a {feature, shap_value: float} dict."""
        model, _ = trained_model_and_features
        result = bbb_model.explain_prediction(model, "CCO", top_k=5)
        assert len(result) == 5
        for entry in result:
            assert "feature" in entry
            assert "shap_value" in entry
            assert isinstance(entry["shap_value"], float)

    def test_features_sorted_by_absolute_shap_value_descending(
        self, trained_model_and_features,
    ):
        model, _ = trained_model_and_features
        result = bbb_model.explain_prediction(model, "CCO", top_k=10)
        magnitudes = [abs(entry["shap_value"]) for entry in result]
        # every adjacent pair must be non-increasing in magnitude
        assert all(a >= b for a, b in zip(magnitudes, magnitudes[1:]))

    def test_features_named_fp_INDEX(self, trained_model_and_features):
        model, _ = trained_model_and_features
        result = bbb_model.explain_prediction(model, "CCO", top_k=3)
        for entry in result:
            assert entry["feature"].startswith("fp_")
            int(entry["feature"].split("_")[1])  # parses cleanly

    def test_raises_on_invalid_smiles(self, trained_model_and_features):
        model, _ = trained_model_and_features
        with pytest.raises(ValueError):
            bbb_model.explain_prediction(model, "still_not_a_smiles", top_k=5)

    def test_deterministic_output(self, trained_model_and_features):
        """AGENTS.md §4 rule 3: identical input → identical SHAP attributions."""
        model, _ = trained_model_and_features
        first = bbb_model.explain_prediction(model, "CCO", top_k=5)
        second = bbb_model.explain_prediction(model, "CCO", top_k=5)
        assert first == second