| """BBB-permeability downstream classifier — train / save / load / predict. |
| |
| Built on top of `data/processed/bbbp_features.parquet` produced by |
| `src.pipelines.bbb_pipeline`. Uses scikit-learn's `RandomForestClassifier` |
| (no XGBoost — saves a heavy dep without losing accuracy at this scale). |
| |
| The model takes a 2,048-bit Morgan fingerprint as input. SHAP-based |
| explanation is added in Task 2 (`explain_prediction`). |
| """ |
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import joblib |
| import numpy as np |
| import pandas as pd |
| from sklearn.ensemble import RandomForestClassifier |
| from sklearn.model_selection import train_test_split |
|
|
| from src.core.logger import get_logger |
| from src.pipelines.bbb_pipeline import ( |
| compute_morgan_fingerprint, |
| is_valid_smiles, |
| ) |
|
|
| logger = get_logger(__name__) |
|
|
|
|
| _FP_COL_PREFIX = "fp_" |
|
|
|
|
def _split_features_and_label(
    df: pd.DataFrame, label_col: str,
) -> tuple[np.ndarray, np.ndarray, list[str]]:
    """Pull out fp_* columns as X and `label_col` as y. Returns (X, y, fp_col_names)."""
    if label_col not in df.columns:
        raise KeyError(f"Label column {label_col!r} not in DataFrame")
    fp_cols = [c for c in df.columns if c.startswith(_FP_COL_PREFIX)]
    if not fp_cols:
        raise KeyError(
            f"No {_FP_COL_PREFIX}* columns found — was this DataFrame produced "
            f"by bbb_pipeline.run_pipeline?"
        )
    X = df[fp_cols].to_numpy()
    y = df[label_col].to_numpy()
    return X, y, fp_cols
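
# A toy sketch of the column contract this helper enforces (two fp_* bits
# here are illustrative; real feature frames carry 2,048 of them):
#
#     toy = pd.DataFrame({"fp_0": [0, 1], "fp_1": [1, 0], "p_np": [1, 0]})
#     X, y, fp_cols = _split_features_and_label(toy, "p_np")
#     # X.shape == (2, 2), y.tolist() == [1, 0], fp_cols == ["fp_0", "fp_1"]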
|
|
|
|
_CALIBRATION_THRESHOLDS: tuple[float, ...] = (0.50, 0.60, 0.70, 0.75, 0.80, 0.90)


def _compute_calibration_bins(
    model: RandomForestClassifier,
    X_test: np.ndarray,
    y_test: np.ndarray,
) -> list[dict[str, float]]:
    """Compute precision-at-confidence-threshold bins on a held-out test set.

    For each threshold T in `_CALIBRATION_THRESHOLDS`, picks the predictions
    whose max class probability >= T, computes precision and support, and
    returns one bin per threshold. Bins with zero support are still emitted
    (precision = 0.0, support = 0) so the API can always find a match.
    """
    if len(y_test) == 0:
        return [
            {"threshold": float(t), "precision": 0.0, "support": 0}
            for t in _CALIBRATION_THRESHOLDS
        ]
    proba = model.predict_proba(X_test)
    pred = model.predict(X_test)
    confidence = proba.max(axis=1)
    correct = (pred == y_test).astype(int)
    bins: list[dict[str, float]] = []
    for t in _CALIBRATION_THRESHOLDS:
        mask = confidence >= t
        support = int(mask.sum())
        if support == 0:
            precision = 0.0
        else:
            precision = float(correct[mask].mean())
        bins.append({
            "threshold": float(t), "precision": precision, "support": support,
        })
    return bins
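
# Shape of the returned bins (the numbers are purely illustrative):
#
#     [{"threshold": 0.50, "precision": 0.91, "support": 412},
#      {"threshold": 0.60, "precision": 0.93, "support": 380},
#      ...
#      {"threshold": 0.90, "precision": 0.98, "support": 55}]
#
# so an API layer can answer "how often is the model right when it is at
# least this confident?" with a simple lookup instead of recomputing.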
|
|
|
|
def _compute_train_stats(
    model: RandomForestClassifier,
    X_train: np.ndarray,
) -> dict[str, float]:
    """Compute median + std of the model's own confidence on the training set.

    Used as the reference distribution for runtime drift detection. All values
    are plain Python scalars (not numpy types) so the dict is
    joblib-roundtrip-safe and JSON-serializable.
    """
    if len(X_train) == 0:
        return {"median": 0.0, "std": 0.0, "n_train": 0}
    proba = model.predict_proba(X_train)
    confidence = proba.max(axis=1)
    return {
        "median": float(np.median(confidence)),
        "std": float(np.std(confidence)),
        "n_train": int(len(X_train)),
    }
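
# A sketch of how a serving layer might consume these stats for drift
# checks. The 3-sigma cutoff is an illustrative choice, not something this
# module enforces:
#
#     stats = model._neurobridge_train_stats
#     drifted = abs(confidence - stats["median"]) > 3 * stats["std"]
#
# where `confidence` is the max class probability of an incoming prediction.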
|
|
|
|
def train(
    df: pd.DataFrame,
    label_col: str = "p_np",
    n_estimators: int = 100,
    random_state: int = 42,
) -> RandomForestClassifier:
    """Train a Random Forest classifier on Morgan fingerprints.

    Args:
        df: Output of `bbb_pipeline.run_pipeline` — has `fp_0..fp_N-1` cols
            plus a binary `label_col`.
        label_col: Name of the binary target column. Defaults to "p_np".
        n_estimators: Number of trees. 100 is the sklearn default.
        random_state: Seed for split + tree construction (determinism).

    Returns:
        Fitted `RandomForestClassifier` with `_neurobridge_fp_cols` set so
        downstream callers can map SHAP values back to fp_<bit> indices.
    """
    X, y, fp_cols = _split_features_and_label(df, label_col)

    # Hold out a stratified 20% so calibration is measured on unseen rows.
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=random_state, stratify=y,
        )
    except ValueError as e:
        logger.warning(
            "Stratified split failed (%s); training on full data; "
            "calibration bins will be zero-support.",
            e,
        )
        X_train, X_test = X, np.empty((0, X.shape[1]))
        y_train, y_test = y, np.empty((0,))

    model = RandomForestClassifier(
        n_estimators=n_estimators,
        random_state=random_state,
        n_jobs=1,
    )
    model.fit(X_train, y_train)

    # Project-owned metadata rides along on the estimator so a single joblib
    # artifact carries everything the API needs.
    model._neurobridge_fp_cols = list(fp_cols)
    model._neurobridge_calibration = _compute_calibration_bins(
        model, X_test, y_test,
    )
    model._neurobridge_train_stats = _compute_train_stats(model, X_train)
    logger.info(
        "Trained BBB classifier: n=%d, n_features=%d, classes=%s, "
        "calibration_bins=%d, train_confidence_median=%.3f",
        len(y), X.shape[1], model.classes_.tolist(),
        len(model._neurobridge_calibration),
        model._neurobridge_train_stats["median"],
    )
    return model
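
# Typical training run (a sketch; assumes the pipeline has produced the
# features Parquet):
#
#     df = pd.read_parquet(DEFAULT_FEATURES_PATH)
#     model = train(df)            # label_col defaults to "p_np"
#     save(model, DEFAULT_MODEL_PATH)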
|
|
|
|
def save(model: RandomForestClassifier, path: Path) -> None:
    """Persist a fitted model to `path` (parent dirs auto-created)."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(model, path)
    logger.info("Saved BBB model to %s", path)


def load(path: Path) -> RandomForestClassifier:
    """Load a previously-saved model. Raises FileNotFoundError on missing artifact."""
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"BBB model artifact not found: {path}")
    return joblib.load(path)
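
# Round-trip note: joblib pickles instance attributes, so the _neurobridge_*
# metadata set by `train()` survives save/load:
#
#     model = load(DEFAULT_MODEL_PATH)
#     model._neurobridge_calibration   # same bins as at training time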
|
|
|
|
def predict_with_proba(
    model: RandomForestClassifier,
    smiles: str,
    n_bits: int = 2048,
    radius: int = 2,
) -> dict[str, object]:
    """Predict BBB permeability for a single SMILES.

    Returns:
        `{"label": int, "confidence": float}` where confidence is the
        predicted class's probability (max class probability — the model's
        self-rated certainty).

    Raises:
        ValueError: if `smiles` cannot be parsed by RDKit.
    """
    if not is_valid_smiles(smiles):
        raise ValueError(f"invalid SMILES: {smiles!r}")
    fp = compute_morgan_fingerprint(smiles, n_bits=n_bits, radius=radius)
    proba = model.predict_proba(fp.reshape(1, -1))[0]
    label_idx = int(np.argmax(proba))
    label = int(model.classes_[label_idx])
    return {
        "label": label,
        "confidence": float(proba[label_idx]),
    }
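
# Example call (the SMILES is caffeine; the output values are illustrative):
#
#     model = load(DEFAULT_MODEL_PATH)
#     predict_with_proba(model, "CN1C=NC2=C1C(=O)N(C(=O)N2C)C")
#     # -> {"label": 1, "confidence": 0.87}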
|
|
|
|
def explain_prediction(
    model: RandomForestClassifier,
    smiles: str,
    top_k: int = 5,
    n_bits: int = 2048,
    radius: int = 2,
) -> list[dict[str, object]]:
    """Return the top-`top_k` feature attributions (SHAP values) for `smiles`.

    Uses `shap.TreeExplainer` (exact for tree ensembles, no sampling). The
    explanation is for the *predicted* class — i.e. the SHAP values that
    pushed the model toward whichever label `predict_with_proba` would return.

    Reads fingerprint column names from `model._neurobridge_fp_cols` (set by
    `train()`). Falls back to `fp_<index>` if the attribute is missing — useful
    for models loaded from a joblib without the project-owned attribute.

    Args:
        model: Fitted classifier from `train()` or `load()`.
        smiles: A SMILES string (validated via `is_valid_smiles`).
        top_k: How many top features to return. Default 5 — matches the
            jury-demo budget (more bars = noisier waterfall chart).
        n_bits / radius: Must match training-time fingerprint settings.

    Returns:
        A list of `{"feature": "fp_<bit_idx>", "shap_value": float}` dicts,
        sorted by `abs(shap_value)` descending.

    Raises:
        ValueError: if `smiles` cannot be parsed by RDKit.
    """
    import shap

    if not is_valid_smiles(smiles):
        raise ValueError(f"invalid SMILES: {smiles!r}")
    fp = compute_morgan_fingerprint(smiles, n_bits=n_bits, radius=radius)
    X = fp.reshape(1, -1)

    explainer = shap.TreeExplainer(model)
    # check_additivity=False: the additivity assertion can trip on float
    # rounding with 2,048 features, and we only need the per-feature ranking.
    shap_values = explainer.shap_values(X, check_additivity=False)

    # SHAP's return shape varies by version: older releases return a list of
    # per-class arrays, newer ones a single ndarray. Normalize to the
    # per-feature vector for the predicted class.
    if isinstance(shap_values, list):
        proba = model.predict_proba(X)[0]
        label_idx = int(np.argmax(proba))
        per_feature = shap_values[label_idx][0]
    else:
        arr = np.asarray(shap_values)
        if arr.ndim == 3:
            # (n_samples, n_features, n_classes): slice out the predicted class.
            proba = model.predict_proba(X)[0]
            label_idx = int(np.argmax(proba))
            per_feature = arr[0, :, label_idx]
        else:
            # (n_samples, n_features): already one value per feature.
            per_feature = arr[0]

    fp_cols = (
        list(model._neurobridge_fp_cols)
        if hasattr(model, "_neurobridge_fp_cols")
        else [f"fp_{i}" for i in range(len(per_feature))]
    )

    pairs = sorted(
        zip(fp_cols, per_feature, strict=True),
        key=lambda p: abs(p[1]),
        reverse=True,
    )
    return [
        {"feature": str(name), "shap_value": float(value)}
        for name, value in pairs[:top_k]
    ]
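
# Example output for top_k=3 (bit indices and SHAP values illustrative):
#
#     explain_prediction(model, "CN1C=NC2=C1C(=O)N(C(=O)N2C)C", top_k=3)
#     # -> [{"feature": "fp_1380", "shap_value": 0.042},
#     #     {"feature": "fp_650", "shap_value": -0.031},
#     #     {"feature": "fp_807", "shap_value": 0.019}]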
|
|
|
|
DEFAULT_FEATURES_PATH = Path("data/processed/bbbp_features.parquet")
DEFAULT_MODEL_PATH = Path("data/processed/bbb_model.joblib")


def main() -> None:
    """Train and persist the production BBB model from the Day-4 features Parquet.

    Reads from `DEFAULT_FEATURES_PATH`, trains with default hyperparameters,
    and writes the artifact to `DEFAULT_MODEL_PATH`. Re-runs are idempotent
    (same random_state).
    """
    if not DEFAULT_FEATURES_PATH.exists():
        raise FileNotFoundError(
            f"Features Parquet not found at {DEFAULT_FEATURES_PATH}. "
            f"Run `python -m src.pipelines.bbb_pipeline` first."
        )
    df = pd.read_parquet(DEFAULT_FEATURES_PATH)
    model = train(df, label_col="p_np")
    save(model, DEFAULT_MODEL_PATH)
    logger.info("BBB model artifact ready at %s", DEFAULT_MODEL_PATH)


if __name__ == "__main__":
    main()
|
|