"""BBB-permeability downstream classifier — train / save / load / predict.
Built on top of `data/processed/bbbp_features.parquet` produced by
`src.pipelines.bbb_pipeline`. Uses scikit-learn's `RandomForestClassifier`
(no XGBoost — saves a heavy dep without losing accuracy at this scale).
The model takes a 2,048-bit Morgan fingerprint as input. SHAP-based
explanation lives in `explain_prediction` (Task 2).
"""
from __future__ import annotations
from pathlib import Path
import joblib
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from src.core.logger import get_logger
from src.pipelines.bbb_pipeline import (
compute_morgan_fingerprint,
is_valid_smiles,
)
logger = get_logger(__name__)
_FP_COL_PREFIX = "fp_"
def _split_features_and_label(
df: pd.DataFrame, label_col: str,
) -> tuple[np.ndarray, np.ndarray, list[str]]:
"""Pull out fp_* columns as X and `label_col` as y. Returns (X, y, fp_col_names)."""
if label_col not in df.columns:
raise KeyError(f"Label column {label_col!r} not in DataFrame")
fp_cols = [c for c in df.columns if c.startswith(_FP_COL_PREFIX)]
if not fp_cols:
raise KeyError(
f"No {_FP_COL_PREFIX}* columns found — was this DataFrame produced "
f"by bbb_pipeline.run_pipeline?"
)
X = df[fp_cols].to_numpy()
y = df[label_col].to_numpy()
return X, y, fp_cols
_CALIBRATION_THRESHOLDS: tuple[float, ...] = (0.50, 0.60, 0.70, 0.75, 0.80, 0.90)
def _compute_calibration_bins(
model: RandomForestClassifier,
X_test: np.ndarray,
y_test: np.ndarray,
) -> list[dict[str, float]]:
"""Compute precision-at-confidence-threshold bins on a held-out test set.
For each threshold T in `_CALIBRATION_THRESHOLDS`, picks the predictions
whose max class probability >= T, computes precision and support, and
returns one bin per threshold. Bins with zero support are still emitted
(precision = 0.0, support = 0) so the API can always find a match.
"""
if len(y_test) == 0:
return [
{"threshold": float(t), "precision": 0.0, "support": 0}
for t in _CALIBRATION_THRESHOLDS
]
proba = model.predict_proba(X_test)
pred = model.predict(X_test)
confidence = proba.max(axis=1)
correct = (pred == y_test).astype(int)
bins: list[dict[str, float]] = []
for t in _CALIBRATION_THRESHOLDS:
mask = confidence >= t
support = int(mask.sum())
if support == 0:
precision = 0.0
else:
precision = float(correct[mask].mean())
bins.append({
"threshold": float(t), "precision": precision, "support": support,
})
return bins
def _compute_train_stats(
model: RandomForestClassifier,
X_train: np.ndarray,
) -> dict[str, float]:
"""Compute median + std of the model's own confidence on the training set.
    Used as the reference distribution for runtime drift detection. All values
    are plain Python scalars (`n_train` is an int, the rest floats) so the
    dict is joblib-roundtrip-safe and JSON-serializable.
"""
if len(X_train) == 0:
return {"median": 0.0, "std": 0.0, "n_train": 0}
proba = model.predict_proba(X_train)
confidence = proba.max(axis=1)
return {
"median": float(np.median(confidence)),
"std": float(np.std(confidence)),
"n_train": int(len(X_train)),
}
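# A runtime caller might flag drift when live confidence strays from this
# reference, e.g. (a sketch; the 3-sigma cutoff is an illustrative assumption,
# not project policy):
#   stats = model._neurobridge_train_stats
#   drifted = abs(live_confidence - stats["median"]) > 3 * stats["std"]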
def train(
df: pd.DataFrame,
label_col: str = "p_np",
n_estimators: int = 100,
random_state: int = 42,
) -> RandomForestClassifier:
"""Train a Random Forest classifier on Morgan fingerprints.
Args:
df: Output of `bbb_pipeline.run_pipeline` — has `fp_0..fp_N-1` cols
plus a binary `label_col`.
label_col: Name of the binary target column. Defaults to "p_np".
n_estimators: Number of trees. 100 is the sklearn default.
random_state: Seed for split + tree construction (determinism).
Returns:
        Fitted `RandomForestClassifier` with `_neurobridge_fp_cols` set so
        downstream callers can map SHAP values back to fp_<bit> indices.
"""
X, y, fp_cols = _split_features_and_label(df, label_col)
# Stratified 80/20 split for honest calibration metrics. Falls back to
# train-on-all if the dataset is too tiny for a stratified split (test
# fixtures with 3-4 rows hit this branch).
try:
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=random_state, stratify=y,
)
except ValueError as e:
logger.warning(
"Stratified split failed (%s); training on full data; "
"calibration bins will be zero-support.",
e,
)
X_train, X_test = X, np.empty((0, X.shape[1]))
y_train, y_test = y, np.empty((0,))
model = RandomForestClassifier(
n_estimators=n_estimators,
random_state=random_state,
n_jobs=1,
)
model.fit(X_train, y_train)
# Stash the column names under a project-owned attribute so SHAP (Task 2)
# can map values back to fp_<bit> indices. Sklearn's own feature_names_in_
# is only set automatically when fit receives a DataFrame; setting it
# manually fires UserWarning on every predict call.
model._neurobridge_fp_cols = list(fp_cols)
model._neurobridge_calibration = _compute_calibration_bins(
model, X_test, y_test,
)
model._neurobridge_train_stats = _compute_train_stats(model, X_train)
logger.info(
"Trained BBB classifier: n=%d, n_features=%d, classes=%s, "
"calibration_bins=%d, train_confidence_median=%.3f",
len(y), X.shape[1], model.classes_.tolist(),
len(model._neurobridge_calibration),
model._neurobridge_train_stats["median"],
)
return model
def save(model: RandomForestClassifier, path: Path) -> None:
"""Persist a fitted model to `path` (parent dirs auto-created)."""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(model, path)
logger.info("Saved BBB model to %s", path)
def load(path: Path) -> RandomForestClassifier:
"""Load a previously-saved model. Raises FileNotFoundError on missing artifact."""
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"BBB model artifact not found: {path}")
return joblib.load(path)
def predict_with_proba(
model: RandomForestClassifier,
smiles: str,
n_bits: int = 2048,
radius: int = 2,
) -> dict[str, object]:
"""Predict BBB permeability for a single SMILES.
Returns:
`{"label": int, "confidence": float}` where confidence is the
predicted class's probability (max class probability — model's
self-rated certainty).
Raises:
ValueError: if `smiles` cannot be parsed by RDKit.
"""
if not is_valid_smiles(smiles):
raise ValueError(f"invalid SMILES: {smiles!r}")
fp = compute_morgan_fingerprint(smiles, n_bits=n_bits, radius=radius)
proba = model.predict_proba(fp.reshape(1, -1))[0]
label_idx = int(np.argmax(proba))
label = int(model.classes_[label_idx])
return {
"label": label,
"confidence": float(proba[label_idx]),
}
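# Example call (output values are illustrative; the actual confidence depends
# on the trained model):
#   predict_with_proba(model, "CC(=O)Oc1ccccc1C(=O)O")  # aspirin
#   # -> {"label": 1, "confidence": 0.87}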
def explain_prediction(
model: RandomForestClassifier,
smiles: str,
top_k: int = 5,
n_bits: int = 2048,
radius: int = 2,
) -> list[dict[str, object]]:
"""Return the top-`top_k` feature attributions (SHAP values) for `smiles`.
Uses `shap.TreeExplainer` (exact for tree ensembles, no sampling). The
explanation is for the *predicted* class — i.e. SHAP values that pushed
the model toward whichever label was returned by `predict_with_proba`.
Reads fingerprint column names from `model._neurobridge_fp_cols` (set by
`train()`). Falls back to `fp_<index>` if the attribute is missing — useful
for models loaded from a joblib without the project-owned attribute.
Args:
model: Fitted classifier from `train()` or `load()`.
smiles: A SMILES string (validated via `is_valid_smiles`).
top_k: How many top features to return. Default 5 — matches the
jury-demo budget (more bars = noisier waterfall chart).
n_bits / radius: Must match training-time fingerprint settings.
Returns:
A list of `{"feature": "fp_<bit_idx>", "shap_value": float}` dicts,
sorted by `abs(shap_value)` descending.
Raises:
ValueError: if `smiles` cannot be parsed by RDKit.
"""
import shap # local import — heavy module, only loaded when needed
if not is_valid_smiles(smiles):
raise ValueError(f"invalid SMILES: {smiles!r}")
fp = compute_morgan_fingerprint(smiles, n_bits=n_bits, radius=radius)
X = fp.reshape(1, -1)
explainer = shap.TreeExplainer(model)
# uint8 fingerprints cause benign additivity violations in SHAP's
# reconstruction (base + sum != model output within tolerance); the
# default check produces false-positive errors on tree ensembles
# over quantized inputs, so we skip it.
shap_values = explainer.shap_values(X, check_additivity=False)
# `shap_values` shape varies by sklearn / shap versions:
# - older: list of (1, n_features) arrays, one per class
# - newer: ndarray of shape (1, n_features, n_classes) for binary RF
# - or (1, n_features) when output already condensed
if isinstance(shap_values, list):
proba = model.predict_proba(X)[0]
label_idx = int(np.argmax(proba))
per_feature = shap_values[label_idx][0]
else:
arr = np.asarray(shap_values)
if arr.ndim == 3:
# (1, n_features, n_classes)
proba = model.predict_proba(X)[0]
label_idx = int(np.argmax(proba))
per_feature = arr[0, :, label_idx]
else:
# (1, n_features)
per_feature = arr[0]
fp_cols = (
list(model._neurobridge_fp_cols)
if hasattr(model, "_neurobridge_fp_cols")
else [f"fp_{i}" for i in range(len(per_feature))]
)
pairs = sorted(
zip(fp_cols, per_feature, strict=True),
key=lambda p: abs(p[1]),
reverse=True,
)
return [
{"feature": str(name), "shap_value": float(value)}
for name, value in pairs[:top_k]
]
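# Illustrative return value (bit indices and magnitudes are made up):
#   [{"feature": "fp_1380", "shap_value": 0.042},
#    {"feature": "fp_2", "shap_value": -0.031}, ...]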
DEFAULT_FEATURES_PATH = Path("data/processed/bbbp_features.parquet")
DEFAULT_MODEL_PATH = Path("data/processed/bbb_model.joblib")
def main() -> None:
"""Train and persist the production BBB model from the Day-4 features Parquet.
Reads from `DEFAULT_FEATURES_PATH`, trains with default hyperparameters,
and writes the artifact to `DEFAULT_MODEL_PATH`. Re-runs are idempotent
(same random_state).
"""
if not DEFAULT_FEATURES_PATH.exists():
raise FileNotFoundError(
f"Features Parquet not found at {DEFAULT_FEATURES_PATH}. "
f"Run `python -m src.pipelines.bbb_pipeline` first."
)
df = pd.read_parquet(DEFAULT_FEATURES_PATH)
model = train(df, label_col="p_np")
save(model, DEFAULT_MODEL_PATH)
logger.info("BBB model artifact ready at %s", DEFAULT_MODEL_PATH)
if __name__ == "__main__":
main()