mekosotto commited on
Commit
3cc6a7d
·
1 Parent(s): 297ad76

feat(models): BBB classifier with predict_with_proba uncertainty

Browse files
requirements.txt CHANGED
@@ -28,6 +28,10 @@ statsmodels==0.14.6 # transitive dep of neuroharmonize; pinned for reproducibil
28
  # --- Experiment tracking ---
29
  mlflow==2.16.0
30
 
 
 
 
 
31
  # --- Tooling / tests ---
32
  pytest==8.3.3
33
  pytest-cov==5.0.0
 
28
  # --- Experiment tracking ---
29
  mlflow==2.16.0
30
 
31
+ # --- Downstream ML / XAI (Day 5 decision layer) ---
32
+ shap==0.46.0
33
+ joblib==1.4.2
34
+
35
  # --- Tooling / tests ---
36
  pytest==8.3.3
37
  pytest-cov==5.0.0
src/models/__init__.py ADDED
File without changes
src/models/bbb_model.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """BBB-permeability downstream classifier — train / save / load / predict.
2
+
3
+ Built on top of `data/processed/bbbp_features.parquet` produced by
4
+ `src.pipelines.bbb_pipeline`. Uses scikit-learn's `RandomForestClassifier`
5
+ (no XGBoost — saves a heavy dep without losing accuracy at this scale).
6
+
7
+ The model takes a 2,048-bit Morgan fingerprint as input. SHAP-based
8
+ explanation is added in Task 2 (`explain_prediction`).
9
+ """
10
+ from __future__ import annotations
11
+
12
+ from pathlib import Path
13
+
14
+ import joblib
15
+ import numpy as np
16
+ import pandas as pd
17
+ from sklearn.ensemble import RandomForestClassifier
18
+
19
+ from src.core.logger import get_logger
20
+ from src.pipelines.bbb_pipeline import (
21
+ compute_morgan_fingerprint,
22
+ is_valid_smiles,
23
+ )
24
+
25
+ logger = get_logger(__name__)
26
+
27
+
28
+ _FP_COL_PREFIX = "fp_"
29
+
30
+
31
+ def _split_features_and_label(
32
+ df: pd.DataFrame, label_col: str,
33
+ ) -> tuple[np.ndarray, np.ndarray, list[str]]:
34
+ """Pull out fp_* columns as X and `label_col` as y. Returns (X, y, fp_col_names)."""
35
+ if label_col not in df.columns:
36
+ raise KeyError(f"Label column {label_col!r} not in DataFrame")
37
+ fp_cols = [c for c in df.columns if c.startswith(_FP_COL_PREFIX)]
38
+ if not fp_cols:
39
+ raise KeyError(
40
+ f"No {_FP_COL_PREFIX}* columns found — was this DataFrame produced "
41
+ f"by bbb_pipeline.run_pipeline?"
42
+ )
43
+ X = df[fp_cols].to_numpy()
44
+ y = df[label_col].to_numpy()
45
+ return X, y, fp_cols
46
+
47
+
48
+ def train(
49
+ df: pd.DataFrame,
50
+ label_col: str = "p_np",
51
+ n_estimators: int = 100,
52
+ random_state: int = 42,
53
+ ) -> RandomForestClassifier:
54
+ """Train a Random Forest classifier on Morgan fingerprints.
55
+
56
+ Args:
57
+ df: Output of `bbb_pipeline.run_pipeline` — has `fp_0..fp_N-1` cols
58
+ plus a binary `label_col`.
59
+ label_col: Name of the binary target column. Defaults to "p_np".
60
+ n_estimators: Number of trees. 100 is the sklearn default.
61
+ random_state: Seed for split + tree construction (determinism).
62
+
63
+ Returns:
64
+ Fitted `RandomForestClassifier` with `feature_names_in_` set so
65
+ downstream callers can map SHAP values back to fp_<bit> indices.
66
+ """
67
+ X, y, fp_cols = _split_features_and_label(df, label_col)
68
+ model = RandomForestClassifier(
69
+ n_estimators=n_estimators,
70
+ random_state=random_state,
71
+ n_jobs=1,
72
+ )
73
+ model.fit(X, y)
74
+ # Stash the column names under a project-owned attribute so SHAP (Task 2)
75
+ # can map values back to fp_<bit> indices. Sklearn's own feature_names_in_
76
+ # is only set automatically when fit receives a DataFrame; setting it
77
+ # manually fires UserWarning on every predict call.
78
+ model._neurobridge_fp_cols = list(fp_cols)
79
+ logger.info(
80
+ "Trained BBB classifier: n=%d, n_features=%d, classes=%s",
81
+ len(y), X.shape[1], model.classes_.tolist(),
82
+ )
83
+ return model
84
+
85
+
86
+ def save(model: RandomForestClassifier, path: Path) -> None:
87
+ """Persist a fitted model to `path` (parent dirs auto-created)."""
88
+ path = Path(path)
89
+ path.parent.mkdir(parents=True, exist_ok=True)
90
+ joblib.dump(model, path)
91
+ logger.info("Saved BBB model to %s", path)
92
+
93
+
94
+ def load(path: Path) -> RandomForestClassifier:
95
+ """Load a previously-saved model. Raises FileNotFoundError on missing artifact."""
96
+ path = Path(path)
97
+ if not path.exists():
98
+ raise FileNotFoundError(f"BBB model artifact not found: {path}")
99
+ return joblib.load(path)
100
+
101
+
102
+ def predict_with_proba(
103
+ model: RandomForestClassifier,
104
+ smiles: str,
105
+ n_bits: int = 2048,
106
+ radius: int = 2,
107
+ ) -> dict[str, object]:
108
+ """Predict BBB permeability for a single SMILES.
109
+
110
+ Returns:
111
+ `{"label": int, "confidence": float}` where confidence is the
112
+ predicted class's probability (max class probability — model's
113
+ self-rated certainty).
114
+
115
+ Raises:
116
+ ValueError: if `smiles` cannot be parsed by RDKit.
117
+ """
118
+ if not is_valid_smiles(smiles):
119
+ raise ValueError(f"invalid SMILES: {smiles!r}")
120
+ fp = compute_morgan_fingerprint(smiles, n_bits=n_bits, radius=radius)
121
+ proba = model.predict_proba(fp.reshape(1, -1))[0]
122
+ label_idx = int(np.argmax(proba))
123
+ label = int(model.classes_[label_idx])
124
+ return {
125
+ "label": label,
126
+ "confidence": float(proba[label_idx]),
127
+ }
tests/fixtures/bbbp_sample.csv CHANGED
@@ -3,5 +3,5 @@ num,name,p_np,smiles
3
  2,Benzene,1,c1ccccc1
4
  3,Aspirin,1,CC(=O)OC1=CC=CC=C1C(=O)O
5
  4,InvalidMol,0,this_is_not_a_smiles
6
- 5,Caffeine,1,CN1C=NC2=C1C(=O)N(C(=O)N2C)C
7
  6,EmptyMol,0,
 
3
  2,Benzene,1,c1ccccc1
4
  3,Aspirin,1,CC(=O)OC1=CC=CC=C1C(=O)O
5
  4,InvalidMol,0,this_is_not_a_smiles
6
+ 5,Caffeine,0,CN1C=NC2=C1C(=O)N(C(=O)N2C)C
7
  6,EmptyMol,0,
tests/models/__init__.py ADDED
File without changes
tests/models/test_bbb_model.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.models.bbb_model — train, save/load, predict, uncertainty."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import pytest
9
+
10
+ from src.models import bbb_model
11
+
12
+
13
+ _FIXTURES = Path(__file__).resolve().parents[1] / "fixtures"
14
+
15
+
16
+ @pytest.fixture(scope="module")
17
+ def trained_model_and_features():
18
+ """Train one tiny model from the committed BBBP fixture; cache for the module."""
19
+ from src.pipelines import bbb_pipeline
20
+ import tempfile
21
+ tmp = Path(tempfile.mkdtemp(prefix="bbb_model_test_"))
22
+ out = tmp / "features.parquet"
23
+ bbb_pipeline.run_pipeline(
24
+ input_path=_FIXTURES / "bbbp_sample.csv",
25
+ output_path=out,
26
+ )
27
+ df = pd.read_parquet(out)
28
+ # Tiny n_estimators for test speed; real training uses default 100.
29
+ model = bbb_model.train(df, label_col="p_np", n_estimators=10, random_state=42)
30
+ return model, df
31
+
32
+
33
+ class TestTrain:
34
+ def test_returns_fitted_classifier(self, trained_model_and_features):
35
+ model, _ = trained_model_and_features
36
+ assert hasattr(model, "classes_")
37
+ assert len(model.classes_) == 2
38
+
39
+ def test_raises_on_missing_label_column(self, trained_model_and_features):
40
+ _, df = trained_model_and_features
41
+ with pytest.raises(KeyError):
42
+ bbb_model.train(df.drop(columns=["p_np"]), label_col="p_np")
43
+
44
+ def test_deterministic_with_random_state(self, trained_model_and_features):
45
+ _, df = trained_model_and_features
46
+ m1 = bbb_model.train(df, label_col="p_np", n_estimators=10, random_state=42)
47
+ m2 = bbb_model.train(df, label_col="p_np", n_estimators=10, random_state=42)
48
+ fp_cols = [c for c in df.columns if c.startswith("fp_")]
49
+ X = df[fp_cols].to_numpy()
50
+ np.testing.assert_array_equal(m1.predict_proba(X), m2.predict_proba(X))
51
+
52
+
53
+ class TestSaveLoad:
54
+ def test_save_then_load_roundtrip(self, trained_model_and_features, tmp_path: Path):
55
+ model, df = trained_model_and_features
56
+ artifact = tmp_path / "bbb_model.joblib"
57
+ bbb_model.save(model, artifact)
58
+ assert artifact.exists()
59
+
60
+ reloaded = bbb_model.load(artifact)
61
+ fp_cols = [c for c in df.columns if c.startswith("fp_")]
62
+ X = df[fp_cols].to_numpy()
63
+ np.testing.assert_array_equal(model.predict(X), reloaded.predict(X))
64
+
65
+ def test_load_raises_on_missing_path(self, tmp_path: Path):
66
+ with pytest.raises(FileNotFoundError):
67
+ bbb_model.load(tmp_path / "does_not_exist.joblib")
68
+
69
+
70
+ class TestPredictWithProba:
71
+ def test_returns_label_and_confidence(self, trained_model_and_features):
72
+ model, _ = trained_model_and_features
73
+ result = bbb_model.predict_with_proba(model, "CCO")
74
+ assert "label" in result
75
+ assert "confidence" in result
76
+ assert result["label"] in (0, 1)
77
+ assert 0.0 <= result["confidence"] <= 1.0
78
+
79
+ def test_raises_on_invalid_smiles(self, trained_model_and_features):
80
+ model, _ = trained_model_and_features
81
+ with pytest.raises(ValueError):
82
+ bbb_model.predict_with_proba(model, "this_is_not_a_smiles_AT_ALL")
83
+
84
+ def test_confidence_equals_max_class_probability(self, trained_model_and_features):
85
+ """confidence is the max class probability — verifies against raw predict_proba."""
86
+ model, _ = trained_model_and_features
87
+ from src.pipelines.bbb_pipeline import compute_morgan_fingerprint
88
+ fp = compute_morgan_fingerprint("CCO").reshape(1, -1)
89
+ raw_proba = model.predict_proba(fp)[0]
90
+ result = bbb_model.predict_with_proba(model, "CCO")
91
+ assert abs(result["confidence"] - float(max(raw_proba))) < 1e-9