# hackathon/src/models/bbb_model.py
# Author: mekosotto
# feat(models): train-time confidence stats stashed on _neurobridge_train_stats
# (commit efb8713)
"""BBB-permeability downstream classifier — train / save / load / predict.
Built on top of `data/processed/bbbp_features.parquet` produced by
`src.pipelines.bbb_pipeline`. Uses scikit-learn's `RandomForestClassifier`
(no XGBoost — saves a heavy dep without losing accuracy at this scale).
The model takes a 2,048-bit Morgan fingerprint as input. SHAP-based
explanation is added in Task 2 (`explain_prediction`).
"""
from __future__ import annotations
from pathlib import Path
import joblib
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from src.core.logger import get_logger
from src.pipelines.bbb_pipeline import (
compute_morgan_fingerprint,
is_valid_smiles,
)
# Module-level logger named after this module so log output is filterable.
logger = get_logger(__name__)
# Prefix shared by every fingerprint column emitted by the pipeline (fp_0..fp_N-1).
_FP_COL_PREFIX = "fp_"


def _split_features_and_label(
    df: pd.DataFrame, label_col: str,
) -> tuple[np.ndarray, np.ndarray, list[str]]:
    """Split a features DataFrame into X (fp_* columns) and y (`label_col`).

    Returns:
        (X, y, fp_col_names) — X stacks every fp_* column in DataFrame order,
        y is the label column as a 1-D array.

    Raises:
        KeyError: if `label_col` is absent or no fp_* columns are present.
    """
    if label_col not in df.columns:
        raise KeyError(f"Label column {label_col!r} not in DataFrame")
    fingerprint_cols = [
        col for col in df.columns if col.startswith(_FP_COL_PREFIX)
    ]
    if not fingerprint_cols:
        raise KeyError(
            f"No {_FP_COL_PREFIX}* columns found — was this DataFrame produced "
            f"by bbb_pipeline.run_pipeline?"
        )
    features = df[fingerprint_cols].to_numpy()
    labels = df[label_col].to_numpy()
    return features, labels, fingerprint_cols
# Confidence cutoffs at which precision/support are reported to the API.
_CALIBRATION_THRESHOLDS: tuple[float, ...] = (0.50, 0.60, 0.70, 0.75, 0.80, 0.90)


def _compute_calibration_bins(
    model: RandomForestClassifier,
    X_test: np.ndarray,
    y_test: np.ndarray,
) -> list[dict[str, float | int]]:
    """Compute precision-at-confidence-threshold bins on a held-out test set.

    For each threshold T in `_CALIBRATION_THRESHOLDS`, selects the test
    predictions whose max class probability is >= T, then records the
    precision (fraction correct among selected) and the support (number
    selected). Bins with zero support are still emitted (precision = 0.0,
    support = 0) so the API can always find a matching bin.

    Args:
        model: Fitted classifier exposing `predict` / `predict_proba`.
        X_test: Held-out feature matrix, shape (n_samples, n_features).
        y_test: Held-out labels, shape (n_samples,).

    Returns:
        One dict per threshold, in `_CALIBRATION_THRESHOLDS` order:
        {"threshold": float, "precision": float, "support": int}.
        (Fixed annotation: `support` is an int, not a float.)
    """
    if len(y_test) == 0:
        # No held-out data (tiny-dataset fallback in train()): emit
        # zero-support bins so callers never hit a missing threshold.
        return [
            {"threshold": float(t), "precision": 0.0, "support": 0}
            for t in _CALIBRATION_THRESHOLDS
        ]
    proba = model.predict_proba(X_test)
    pred = model.predict(X_test)
    # Model's self-rated certainty per row = max class probability.
    confidence = proba.max(axis=1)
    correct = (pred == y_test).astype(int)
    bins: list[dict[str, float | int]] = []
    for t in _CALIBRATION_THRESHOLDS:
        mask = confidence >= t
        support = int(mask.sum())
        # Precision over the selected subset; 0.0 when nothing clears T.
        precision = float(correct[mask].mean()) if support else 0.0
        bins.append({
            "threshold": float(t), "precision": precision, "support": support,
        })
    return bins
def _compute_train_stats(
    model: RandomForestClassifier,
    X_train: np.ndarray,
) -> dict[str, float]:
    """Summarize the model's self-confidence distribution on its own train set.

    The median/std pair is the reference distribution for runtime drift
    detection. Every value is a plain Python number so the dict survives a
    joblib round-trip and JSON serialization unchanged.
    """
    n_train = int(len(X_train))
    if n_train == 0:
        return {"median": 0.0, "std": 0.0, "n_train": 0}
    # Per-row confidence = max class probability.
    per_row_confidence = model.predict_proba(X_train).max(axis=1)
    return {
        "median": float(np.median(per_row_confidence)),
        "std": float(np.std(per_row_confidence)),
        "n_train": n_train,
    }
def train(
    df: pd.DataFrame,
    label_col: str = "p_np",
    n_estimators: int = 100,
    random_state: int = 42,
) -> RandomForestClassifier:
    """Train a Random Forest classifier on Morgan fingerprints.

    Args:
        df: Output of `bbb_pipeline.run_pipeline` — has `fp_0..fp_N-1` cols
            plus a binary `label_col`.
        label_col: Name of the binary target column. Defaults to "p_np".
        n_estimators: Number of trees. 100 is the sklearn default.
        random_state: Seed for split + tree construction (determinism).

    Returns:
        Fitted `RandomForestClassifier` carrying three project-owned
        attributes: `_neurobridge_fp_cols` (fingerprint column names, so
        downstream callers can map SHAP values back to fp_<bit> indices),
        `_neurobridge_calibration` (precision-at-threshold bins from the
        held-out split), and `_neurobridge_train_stats` (train-set
        confidence median/std for drift detection). Sklearn's own
        `feature_names_in_` is deliberately NOT set — see comment below.

    Raises:
        KeyError: if `label_col` or the fp_* columns are missing from `df`
            (propagated from `_split_features_and_label`).
    """
    X, y, fp_cols = _split_features_and_label(df, label_col)
    # Stratified 80/20 split for honest calibration metrics. Falls back to
    # train-on-all if the dataset is too tiny for a stratified split (test
    # fixtures with 3-4 rows hit this branch).
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=random_state, stratify=y,
        )
    except ValueError as e:
        logger.warning(
            "Stratified split failed (%s); training on full data; "
            "calibration bins will be zero-support.",
            e,
        )
        # Empty test arrays keep downstream shapes consistent; the
        # calibration helper short-circuits on len(y_test) == 0.
        X_train, X_test = X, np.empty((0, X.shape[1]))
        y_train, y_test = y, np.empty((0,))
    model = RandomForestClassifier(
        n_estimators=n_estimators,
        random_state=random_state,
        n_jobs=1,
    )
    model.fit(X_train, y_train)
    # Stash the column names under a project-owned attribute so SHAP (Task 2)
    # can map values back to fp_<bit> indices. Sklearn's own feature_names_in_
    # is only set automatically when fit receives a DataFrame; setting it
    # manually fires UserWarning on every predict call.
    model._neurobridge_fp_cols = list(fp_cols)
    model._neurobridge_calibration = _compute_calibration_bins(
        model, X_test, y_test,
    )
    model._neurobridge_train_stats = _compute_train_stats(model, X_train)
    logger.info(
        "Trained BBB classifier: n=%d, n_features=%d, classes=%s, "
        "calibration_bins=%d, train_confidence_median=%.3f",
        len(y), X.shape[1], model.classes_.tolist(),
        len(model._neurobridge_calibration),
        model._neurobridge_train_stats["median"],
    )
    return model
def save(model: RandomForestClassifier, path: Path) -> None:
    """Write a fitted model to `path`, creating parent directories as needed."""
    artifact_path = Path(path)
    artifact_path.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(model, artifact_path)
    logger.info("Saved BBB model to %s", artifact_path)
def load(path: Path) -> RandomForestClassifier:
    """Load a previously-saved model.

    Raises:
        FileNotFoundError: when no artifact exists at `path`.
    """
    path = Path(path)
    if path.exists():
        return joblib.load(path)
    raise FileNotFoundError(f"BBB model artifact not found: {path}")
def predict_with_proba(
    model: RandomForestClassifier,
    smiles: str,
    n_bits: int = 2048,
    radius: int = 2,
) -> dict[str, object]:
    """Predict BBB permeability for one SMILES string.

    Returns:
        `{"label": int, "confidence": float}` — confidence is the probability
        assigned to the predicted class (the model's self-rated certainty).

    Raises:
        ValueError: if `smiles` cannot be parsed by RDKit.
    """
    if not is_valid_smiles(smiles):
        raise ValueError(f"invalid SMILES: {smiles!r}")
    fingerprint = compute_morgan_fingerprint(smiles, n_bits=n_bits, radius=radius)
    class_probs = model.predict_proba(fingerprint.reshape(1, -1))[0]
    winner = int(np.argmax(class_probs))
    return {
        "label": int(model.classes_[winner]),
        "confidence": float(class_probs[winner]),
    }
def explain_prediction(
    model: RandomForestClassifier,
    smiles: str,
    top_k: int = 5,
    n_bits: int = 2048,
    radius: int = 2,
) -> list[dict[str, object]]:
    """Return the top-`top_k` feature attributions (SHAP values) for `smiles`.

    Uses `shap.TreeExplainer` (exact for tree ensembles, no sampling). The
    explanation is for the *predicted* class — i.e. SHAP values that pushed
    the model toward whichever label was returned by `predict_with_proba`.

    Reads fingerprint column names from `model._neurobridge_fp_cols` (set by
    `train()`). Falls back to `fp_<index>` if the attribute is missing — useful
    for models loaded from a joblib without the project-owned attribute.

    Args:
        model: Fitted classifier from `train()` or `load()`.
        smiles: A SMILES string (validated via `is_valid_smiles`).
        top_k: How many top features to return. Default 5 — matches the
            jury-demo budget (more bars = noisier waterfall chart).
        n_bits / radius: Must match training-time fingerprint settings;
            a mismatch makes `zip(..., strict=True)` below raise ValueError.

    Returns:
        A list of `{"feature": "fp_<bit_idx>", "shap_value": float}` dicts,
        sorted by `abs(shap_value)` descending.

    Raises:
        ValueError: if `smiles` cannot be parsed by RDKit.
    """
    import shap  # local import — heavy module, only loaded when needed
    if not is_valid_smiles(smiles):
        raise ValueError(f"invalid SMILES: {smiles!r}")
    fp = compute_morgan_fingerprint(smiles, n_bits=n_bits, radius=radius)
    X = fp.reshape(1, -1)
    explainer = shap.TreeExplainer(model)
    # uint8 fingerprints cause benign additivity violations in SHAP's
    # reconstruction (base + sum != model output within tolerance); the
    # default check produces false-positive errors on tree ensembles
    # over quantized inputs, so we skip it.
    shap_values = explainer.shap_values(X, check_additivity=False)
    # `shap_values` shape varies by sklearn / shap versions:
    # - older: list of (1, n_features) arrays, one per class
    # - newer: ndarray of shape (1, n_features, n_classes) for binary RF
    # - or (1, n_features) when output already condensed
    if isinstance(shap_values, list):
        # Older shap: index the per-class list by the predicted class.
        proba = model.predict_proba(X)[0]
        label_idx = int(np.argmax(proba))
        per_feature = shap_values[label_idx][0]
    else:
        arr = np.asarray(shap_values)
        if arr.ndim == 3:
            # (1, n_features, n_classes) — slice out the predicted class.
            proba = model.predict_proba(X)[0]
            label_idx = int(np.argmax(proba))
            per_feature = arr[0, :, label_idx]
        else:
            # (1, n_features) — already condensed to a single output.
            per_feature = arr[0]
    # Prefer train()'s column names; fall back to synthetic fp_<i> names.
    fp_cols = (
        list(model._neurobridge_fp_cols)
        if hasattr(model, "_neurobridge_fp_cols")
        else [f"fp_{i}" for i in range(len(per_feature))]
    )
    # strict=True (3.10+) surfaces a column/value length mismatch loudly
    # instead of silently truncating the attribution list.
    pairs = sorted(
        zip(fp_cols, per_feature, strict=True),
        key=lambda p: abs(p[1]),
        reverse=True,
    )
    return [
        {"feature": str(name), "shap_value": float(value)}
        for name, value in pairs[:top_k]
    ]
# Canonical artifact locations used by the CLI entry point below:
# input features produced by `src.pipelines.bbb_pipeline`, and the
# trained-model output consumed by the serving layer.
DEFAULT_FEATURES_PATH = Path("data/processed/bbbp_features.parquet")
DEFAULT_MODEL_PATH = Path("data/processed/bbb_model.joblib")
def main() -> None:
    """Train and persist the production BBB model from the Day-4 features Parquet.

    Reads `DEFAULT_FEATURES_PATH`, trains with default hyperparameters, and
    writes the artifact to `DEFAULT_MODEL_PATH`. Re-running is idempotent
    because the random_state is fixed.

    Raises:
        FileNotFoundError: if the features Parquet has not been built yet.
    """
    if not DEFAULT_FEATURES_PATH.exists():
        raise FileNotFoundError(
            f"Features Parquet not found at {DEFAULT_FEATURES_PATH}. "
            f"Run `python -m src.pipelines.bbb_pipeline` first."
        )
    features_df = pd.read_parquet(DEFAULT_FEATURES_PATH)
    fitted_model = train(features_df, label_col="p_np")
    save(fitted_model, DEFAULT_MODEL_PATH)
    logger.info("BBB model artifact ready at %s", DEFAULT_MODEL_PATH)


if __name__ == "__main__":
    main()