feat(models): BBB classifier with predict_with_proba uncertainty
Browse files- requirements.txt +4 -0
- src/models/__init__.py +0 -0
- src/models/bbb_model.py +127 -0
- tests/fixtures/bbbp_sample.csv +1 -1
- tests/models/__init__.py +0 -0
- tests/models/test_bbb_model.py +91 -0
requirements.txt
CHANGED
|
@@ -28,6 +28,10 @@ statsmodels==0.14.6 # transitive dep of neuroharmonize; pinned for reproducibil
|
|
| 28 |
# --- Experiment tracking ---
|
| 29 |
mlflow==2.16.0
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
# --- Tooling / tests ---
|
| 32 |
pytest==8.3.3
|
| 33 |
pytest-cov==5.0.0
|
|
|
|
| 28 |
# --- Experiment tracking ---
|
| 29 |
mlflow==2.16.0
|
| 30 |
|
| 31 |
+
# --- Downstream ML / XAI (Day 5 decision layer) ---
|
| 32 |
+
shap==0.46.0
|
| 33 |
+
joblib==1.4.2
|
| 34 |
+
|
| 35 |
# --- Tooling / tests ---
|
| 36 |
pytest==8.3.3
|
| 37 |
pytest-cov==5.0.0
|
src/models/__init__.py
ADDED
|
File without changes
|
src/models/bbb_model.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""BBB-permeability downstream classifier — train / save / load / predict.
|
| 2 |
+
|
| 3 |
+
Built on top of `data/processed/bbbp_features.parquet` produced by
|
| 4 |
+
`src.pipelines.bbb_pipeline`. Uses scikit-learn's `RandomForestClassifier`
|
| 5 |
+
(no XGBoost — saves a heavy dep without losing accuracy at this scale).
|
| 6 |
+
|
| 7 |
+
The model takes a 2,048-bit Morgan fingerprint as input. SHAP-based
|
| 8 |
+
explanation is added in Task 2 (`explain_prediction`).
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import joblib
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pandas as pd
|
| 17 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 18 |
+
|
| 19 |
+
from src.core.logger import get_logger
|
| 20 |
+
from src.pipelines.bbb_pipeline import (
|
| 21 |
+
compute_morgan_fingerprint,
|
| 22 |
+
is_valid_smiles,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
logger = get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
_FP_COL_PREFIX = "fp_"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _split_features_and_label(
|
| 32 |
+
df: pd.DataFrame, label_col: str,
|
| 33 |
+
) -> tuple[np.ndarray, np.ndarray, list[str]]:
|
| 34 |
+
"""Pull out fp_* columns as X and `label_col` as y. Returns (X, y, fp_col_names)."""
|
| 35 |
+
if label_col not in df.columns:
|
| 36 |
+
raise KeyError(f"Label column {label_col!r} not in DataFrame")
|
| 37 |
+
fp_cols = [c for c in df.columns if c.startswith(_FP_COL_PREFIX)]
|
| 38 |
+
if not fp_cols:
|
| 39 |
+
raise KeyError(
|
| 40 |
+
f"No {_FP_COL_PREFIX}* columns found — was this DataFrame produced "
|
| 41 |
+
f"by bbb_pipeline.run_pipeline?"
|
| 42 |
+
)
|
| 43 |
+
X = df[fp_cols].to_numpy()
|
| 44 |
+
y = df[label_col].to_numpy()
|
| 45 |
+
return X, y, fp_cols
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def train(
    df: pd.DataFrame,
    label_col: str = "p_np",
    n_estimators: int = 100,
    random_state: int = 42,
) -> RandomForestClassifier:
    """Fit a Random Forest BBB classifier on Morgan-fingerprint features.

    Args:
        df: DataFrame from `bbb_pipeline.run_pipeline`, containing the
            `fp_0..fp_N-1` bit columns plus a binary `label_col`.
        label_col: Name of the binary target column. Defaults to "p_np".
        n_estimators: Tree count; 100 mirrors the sklearn default.
        random_state: Seed for split + tree construction (determinism).

    Returns:
        A fitted `RandomForestClassifier` carrying the fingerprint column
        names so downstream SHAP code can map values back to fp_<bit> indices.
    """
    features, labels, bit_columns = _split_features_and_label(df, label_col)
    clf = RandomForestClassifier(
        n_estimators=n_estimators,
        random_state=random_state,
        n_jobs=1,
    )
    clf.fit(features, labels)
    # Column names live on a project-owned attribute so SHAP (Task 2) can map
    # values back to fp_<bit> indices. We avoid sklearn's feature_names_in_:
    # it is only populated when fit() receives a DataFrame, and assigning it
    # by hand triggers a UserWarning on every subsequent predict call.
    clf._neurobridge_fp_cols = list(bit_columns)
    logger.info(
        "Trained BBB classifier: n=%d, n_features=%d, classes=%s",
        len(labels), features.shape[1], clf.classes_.tolist(),
    )
    return clf
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def save(model: RandomForestClassifier, path: Path) -> None:
    """Write the fitted model to `path`, creating parent directories as needed."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(model, target)
    logger.info("Saved BBB model to %s", target)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def load(path: Path) -> RandomForestClassifier:
|
| 95 |
+
"""Load a previously-saved model. Raises FileNotFoundError on missing artifact."""
|
| 96 |
+
path = Path(path)
|
| 97 |
+
if not path.exists():
|
| 98 |
+
raise FileNotFoundError(f"BBB model artifact not found: {path}")
|
| 99 |
+
return joblib.load(path)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def predict_with_proba(
    model: RandomForestClassifier,
    smiles: str,
    n_bits: int = 2048,
    radius: int = 2,
) -> dict[str, object]:
    """Classify a single SMILES and report the model's certainty.

    Returns:
        `{"label": int, "confidence": float}` — confidence is the probability
        the model assigns to its own predicted class (max class probability,
        i.e. the model's self-rated certainty).

    Raises:
        ValueError: if `smiles` cannot be parsed by RDKit.
    """
    if not is_valid_smiles(smiles):
        raise ValueError(f"invalid SMILES: {smiles!r}")
    bits = compute_morgan_fingerprint(smiles, n_bits=n_bits, radius=radius)
    class_probs = model.predict_proba(bits.reshape(1, -1))[0]
    winner = int(np.argmax(class_probs))
    return {
        "label": int(model.classes_[winner]),
        "confidence": float(class_probs[winner]),
    }
|
tests/fixtures/bbbp_sample.csv
CHANGED
|
@@ -3,5 +3,5 @@ num,name,p_np,smiles
|
|
| 3 |
2,Benzene,1,c1ccccc1
|
| 4 |
3,Aspirin,1,CC(=O)OC1=CC=CC=C1C(=O)O
|
| 5 |
4,InvalidMol,0,this_is_not_a_smiles
|
| 6 |
-
5,Caffeine,
|
| 7 |
6,EmptyMol,0,
|
|
|
|
| 3 |
2,Benzene,1,c1ccccc1
|
| 4 |
3,Aspirin,1,CC(=O)OC1=CC=CC=C1C(=O)O
|
| 5 |
4,InvalidMol,0,this_is_not_a_smiles
|
| 6 |
+
5,Caffeine,0,CN1C=NC2=C1C(=O)N(C(=O)N2C)C
|
| 7 |
6,EmptyMol,0,
|
tests/models/__init__.py
ADDED
|
File without changes
|
tests/models/test_bbb_model.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for src.models.bbb_model — train, save/load, predict, uncertainty."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import pytest
|
| 9 |
+
|
| 10 |
+
from src.models import bbb_model
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
_FIXTURES = Path(__file__).resolve().parents[1] / "fixtures"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@pytest.fixture(scope="module")
def trained_model_and_features(tmp_path_factory: pytest.TempPathFactory):
    """Train one tiny model from the committed BBBP fixture; cache for the module.

    Uses pytest's `tmp_path_factory` (module-scope compatible) instead of
    `tempfile.mkdtemp`, which leaked one scratch directory per test run —
    pytest now owns and cleans up the parquet output location.
    """
    from src.pipelines import bbb_pipeline

    out = tmp_path_factory.mktemp("bbb_model_test") / "features.parquet"
    bbb_pipeline.run_pipeline(
        input_path=_FIXTURES / "bbbp_sample.csv",
        output_path=out,
    )
    df = pd.read_parquet(out)
    # Tiny n_estimators for test speed; real training uses default 100.
    model = bbb_model.train(df, label_col="p_np", n_estimators=10, random_state=42)
    return model, df
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class TestTrain:
    """Behavior of bbb_model.train on the tiny fixture dataset."""

    def test_returns_fitted_classifier(self, trained_model_and_features):
        clf, _ = trained_model_and_features
        assert hasattr(clf, "classes_")
        assert len(clf.classes_) == 2

    def test_raises_on_missing_label_column(self, trained_model_and_features):
        _, features_df = trained_model_and_features
        unlabeled = features_df.drop(columns=["p_np"])
        with pytest.raises(KeyError):
            bbb_model.train(unlabeled, label_col="p_np")

    def test_deterministic_with_random_state(self, trained_model_and_features):
        _, features_df = trained_model_and_features
        first = bbb_model.train(features_df, label_col="p_np", n_estimators=10, random_state=42)
        second = bbb_model.train(features_df, label_col="p_np", n_estimators=10, random_state=42)
        bits = features_df.filter(regex=r"^fp_").to_numpy()
        np.testing.assert_array_equal(first.predict_proba(bits), second.predict_proba(bits))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class TestSaveLoad:
    """Round-trip persistence via bbb_model.save / bbb_model.load."""

    def test_save_then_load_roundtrip(self, trained_model_and_features, tmp_path: Path):
        clf, features_df = trained_model_and_features
        artifact = tmp_path / "bbb_model.joblib"
        bbb_model.save(clf, artifact)
        assert artifact.exists()

        restored = bbb_model.load(artifact)
        bits = features_df.filter(regex=r"^fp_").to_numpy()
        np.testing.assert_array_equal(clf.predict(bits), restored.predict(bits))

    def test_load_raises_on_missing_path(self, tmp_path: Path):
        with pytest.raises(FileNotFoundError):
            bbb_model.load(tmp_path / "does_not_exist.joblib")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class TestPredictWithProba:
    """predict_with_proba: output contract, input validation, confidence semantics."""

    def test_returns_label_and_confidence(self, trained_model_and_features):
        clf, _ = trained_model_and_features
        outcome = bbb_model.predict_with_proba(clf, "CCO")
        assert "label" in outcome
        assert "confidence" in outcome
        assert outcome["label"] in (0, 1)
        assert 0.0 <= outcome["confidence"] <= 1.0

    def test_raises_on_invalid_smiles(self, trained_model_and_features):
        clf, _ = trained_model_and_features
        with pytest.raises(ValueError):
            bbb_model.predict_with_proba(clf, "this_is_not_a_smiles_AT_ALL")

    def test_confidence_equals_max_class_probability(self, trained_model_and_features):
        """confidence is the max class probability — verify against raw predict_proba."""
        clf, _ = trained_model_and_features
        from src.pipelines.bbb_pipeline import compute_morgan_fingerprint

        bits = compute_morgan_fingerprint("CCO").reshape(1, -1)
        expected = float(max(clf.predict_proba(bits)[0]))
        outcome = bbb_model.predict_with_proba(clf, "CCO")
        assert abs(outcome["confidence"] - expected) < 1e-9
|