File size: 11,290 Bytes
3cc6a7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90167c7
3cc6a7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90167c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efb8713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cc6a7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90167c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cc6a7d
 
 
 
 
90167c7
3cc6a7d
 
 
 
 
90167c7
 
 
efb8713
3cc6a7d
90167c7
efb8713
3cc6a7d
90167c7
efb8713
3cc6a7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7dad1a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53256ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
"""BBB-permeability downstream classifier — train / save / load / predict.

Built on top of `data/processed/bbbp_features.parquet` produced by
`src.pipelines.bbb_pipeline`. Uses scikit-learn's `RandomForestClassifier`
(no XGBoost — saves a heavy dep without losing accuracy at this scale).

The model takes a 2,048-bit Morgan fingerprint as input. SHAP-based
explanation is added in Task 2 (`explain_prediction`).
"""
from __future__ import annotations

from pathlib import Path

import joblib
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from src.core.logger import get_logger
from src.pipelines.bbb_pipeline import (
    compute_morgan_fingerprint,
    is_valid_smiles,
)

logger = get_logger(__name__)


_FP_COL_PREFIX = "fp_"


def _split_features_and_label(
    df: pd.DataFrame, label_col: str,
) -> tuple[np.ndarray, np.ndarray, list[str]]:
    """Pull out fp_* columns as X and `label_col` as y. Returns (X, y, fp_col_names)."""
    if label_col not in df.columns:
        raise KeyError(f"Label column {label_col!r} not in DataFrame")
    fp_cols = [c for c in df.columns if c.startswith(_FP_COL_PREFIX)]
    if not fp_cols:
        raise KeyError(
            f"No {_FP_COL_PREFIX}* columns found — was this DataFrame produced "
            f"by bbb_pipeline.run_pipeline?"
        )
    X = df[fp_cols].to_numpy()
    y = df[label_col].to_numpy()
    return X, y, fp_cols


_CALIBRATION_THRESHOLDS: tuple[float, ...] = (0.50, 0.60, 0.70, 0.75, 0.80, 0.90)


def _compute_calibration_bins(
    model: RandomForestClassifier,
    X_test: np.ndarray,
    y_test: np.ndarray,
) -> list[dict[str, float]]:
    """Compute precision-at-confidence-threshold bins on a held-out test set.

    For each threshold T in `_CALIBRATION_THRESHOLDS`, picks the predictions
    whose max class probability >= T, computes precision and support, and
    returns one bin per threshold. Bins with zero support are still emitted
    (precision = 0.0, support = 0) so the API can always find a match.
    """
    if len(y_test) == 0:
        return [
            {"threshold": float(t), "precision": 0.0, "support": 0}
            for t in _CALIBRATION_THRESHOLDS
        ]
    proba = model.predict_proba(X_test)
    pred = model.predict(X_test)
    confidence = proba.max(axis=1)
    correct = (pred == y_test).astype(int)
    bins: list[dict[str, float]] = []
    for t in _CALIBRATION_THRESHOLDS:
        mask = confidence >= t
        support = int(mask.sum())
        if support == 0:
            precision = 0.0
        else:
            precision = float(correct[mask].mean())
        bins.append({
            "threshold": float(t), "precision": precision, "support": support,
        })
    return bins


def _compute_train_stats(
    model: RandomForestClassifier,
    X_train: np.ndarray,
) -> dict[str, float]:
    """Compute median + std of the model's own confidence on the training set.

    Used as the reference distribution for runtime drift detection. All values
    are floats so the dict is joblib-roundtrip-safe and JSON-serializable.
    """
    if len(X_train) == 0:
        return {"median": 0.0, "std": 0.0, "n_train": 0}
    proba = model.predict_proba(X_train)
    confidence = proba.max(axis=1)
    return {
        "median": float(np.median(confidence)),
        "std": float(np.std(confidence)),
        "n_train": int(len(X_train)),
    }


def train(
    df: pd.DataFrame,
    label_col: str = "p_np",
    n_estimators: int = 100,
    random_state: int = 42,
) -> RandomForestClassifier:
    """Train a Random Forest classifier on Morgan fingerprints.

    Args:
        df: Output of `bbb_pipeline.run_pipeline` — has `fp_0..fp_N-1` cols
            plus a binary `label_col`.
        label_col: Name of the binary target column. Defaults to "p_np".
        n_estimators: Number of trees. 100 is the sklearn default.
        random_state: Seed for split + tree construction (determinism).

    Returns:
        Fitted `RandomForestClassifier` carrying the project-owned
        attributes `_neurobridge_fp_cols`, `_neurobridge_calibration` and
        `_neurobridge_train_stats` used by SHAP explanation and the API.
    """
    X, y, fp_cols = _split_features_and_label(df, label_col)
    # A stratified 80/20 split gives honest calibration metrics. Datasets
    # too small to stratify (3-4 row test fixtures) fall back to training
    # on everything with an empty test set.
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=random_state, stratify=y,
        )
    except ValueError as e:
        logger.warning(
            "Stratified split failed (%s); training on full data; "
            "calibration bins will be zero-support.",
            e,
        )
        X_train, y_train = X, y
        X_test, y_test = np.empty((0, X.shape[1])), np.empty((0,))

    model = RandomForestClassifier(
        n_estimators=n_estimators,
        random_state=random_state,
        n_jobs=1,
    )
    model.fit(X_train, y_train)
    # Project-owned attributes instead of sklearn's feature_names_in_: that
    # attribute is only set automatically when fit receives a DataFrame, and
    # assigning it by hand fires a UserWarning on every predict call. SHAP
    # (Task 2) reads the column names back from here.
    model._neurobridge_fp_cols = list(fp_cols)
    model._neurobridge_calibration = _compute_calibration_bins(
        model, X_test, y_test,
    )
    model._neurobridge_train_stats = _compute_train_stats(model, X_train)
    logger.info(
        "Trained BBB classifier: n=%d, n_features=%d, classes=%s, "
        "calibration_bins=%d, train_confidence_median=%.3f",
        len(y), X.shape[1], model.classes_.tolist(),
        len(model._neurobridge_calibration),
        model._neurobridge_train_stats["median"],
    )
    return model


def save(model: RandomForestClassifier, path: Path) -> None:
    """Serialize a fitted model to `path`, creating parent dirs as needed."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(model, target)
    logger.info("Saved BBB model to %s", target)


def load(path: Path) -> RandomForestClassifier:
    """Deserialize a model saved by `save`.

    Raises:
        FileNotFoundError: if no artifact exists at `path`.
    """
    artifact = Path(path)
    if not artifact.exists():
        raise FileNotFoundError(f"BBB model artifact not found: {artifact}")
    return joblib.load(artifact)


def predict_with_proba(
    model: RandomForestClassifier,
    smiles: str,
    n_bits: int = 2048,
    radius: int = 2,
) -> dict[str, object]:
    """Predict BBB permeability for a single SMILES.

    Args:
        model: Fitted classifier from `train()` or `load()`.
        smiles: SMILES string, validated via `is_valid_smiles`.
        n_bits / radius: Must match training-time fingerprint settings.

    Returns:
        `{"label": int, "confidence": float}` — confidence is the predicted
        class's probability (max class probability, i.e. the model's
        self-rated certainty).

    Raises:
        ValueError: if `smiles` cannot be parsed by RDKit.
    """
    if not is_valid_smiles(smiles):
        raise ValueError(f"invalid SMILES: {smiles!r}")
    features = compute_morgan_fingerprint(smiles, n_bits=n_bits, radius=radius)
    class_probs = model.predict_proba(features.reshape(1, -1))[0]
    winner = int(np.argmax(class_probs))
    return {
        "label": int(model.classes_[winner]),
        "confidence": float(class_probs[winner]),
    }


def explain_prediction(
    model: RandomForestClassifier,
    smiles: str,
    top_k: int = 5,
    n_bits: int = 2048,
    radius: int = 2,
) -> list[dict[str, object]]:
    """Return the top-`top_k` SHAP feature attributions for `smiles`.

    `shap.TreeExplainer` is exact for tree ensembles (no sampling). The
    attributions explain the *predicted* class — the SHAP values that pushed
    the model toward the label `predict_with_proba` would return.

    Fingerprint column names come from `model._neurobridge_fp_cols` (set by
    `train()`); models loaded from a joblib without that attribute fall back
    to synthetic `fp_<index>` names.

    Args:
        model: Fitted classifier from `train()` or `load()`.
        smiles: A SMILES string (validated via `is_valid_smiles`).
        top_k: Number of features to return. Default 5 — matches the
            jury-demo budget (more bars = noisier waterfall chart).
        n_bits / radius: Must match training-time fingerprint settings.

    Returns:
        `{"feature": "fp_<bit_idx>", "shap_value": float}` dicts sorted by
        `abs(shap_value)` descending.

    Raises:
        ValueError: if `smiles` cannot be parsed by RDKit.
    """
    import shap  # local import — heavy module, only loaded when needed

    if not is_valid_smiles(smiles):
        raise ValueError(f"invalid SMILES: {smiles!r}")
    X = compute_morgan_fingerprint(
        smiles, n_bits=n_bits, radius=radius,
    ).reshape(1, -1)

    explainer = shap.TreeExplainer(model)
    # check_additivity=False: uint8 fingerprints cause benign additivity
    # violations (base + sum != model output within tolerance), so SHAP's
    # default reconstruction check false-positives on tree ensembles over
    # quantized inputs.
    raw = explainer.shap_values(X, check_additivity=False)

    def _predicted_class() -> int:
        # Index of the class the model itself would predict for X.
        return int(np.argmax(model.predict_proba(X)[0]))

    # Output layout varies across shap / sklearn versions:
    #   - older: list of (1, n_features) arrays, one per class
    #   - newer: ndarray (1, n_features, n_classes) for binary RF
    #   - or (1, n_features) when already condensed
    if isinstance(raw, list):
        per_feature = raw[_predicted_class()][0]
    else:
        arr = np.asarray(raw)
        if arr.ndim == 3:
            per_feature = arr[0, :, _predicted_class()]
        else:
            per_feature = arr[0]

    if hasattr(model, "_neurobridge_fp_cols"):
        fp_cols = list(model._neurobridge_fp_cols)
    else:
        # Artifact predates the project-owned attribute: synthesize names.
        fp_cols = [f"fp_{i}" for i in range(len(per_feature))]

    ranked = sorted(
        zip(fp_cols, per_feature, strict=True),
        key=lambda pair: abs(pair[1]),
        reverse=True,
    )
    return [
        {"feature": str(name), "shap_value": float(value)}
        for name, value in ranked[:top_k]
    ]


DEFAULT_FEATURES_PATH = Path("data/processed/bbbp_features.parquet")
DEFAULT_MODEL_PATH = Path("data/processed/bbb_model.joblib")


def main() -> None:
    """Train and persist the production BBB model from the Day-4 features Parquet.

    Reads `DEFAULT_FEATURES_PATH`, trains with default hyperparameters, and
    writes the artifact to `DEFAULT_MODEL_PATH`. Re-runs are idempotent
    (fixed random_state).
    """
    if not DEFAULT_FEATURES_PATH.exists():
        raise FileNotFoundError(
            f"Features Parquet not found at {DEFAULT_FEATURES_PATH}. "
            f"Run `python -m src.pipelines.bbb_pipeline` first."
        )
    features = pd.read_parquet(DEFAULT_FEATURES_PATH)
    save(train(features, label_col="p_np"), DEFAULT_MODEL_PATH)
    logger.info("BBB model artifact ready at %s", DEFAULT_MODEL_PATH)


if __name__ == "__main__":
    main()