feat(bbb): expand SMILES → Morgan FP into model-ready DataFrame with drift logging
Browse files
src/pipelines/bbb_pipeline.py
CHANGED
|
@@ -13,6 +13,7 @@ from __future__ import annotations
|
|
| 13 |
import math
|
| 14 |
|
| 15 |
import numpy as np
|
|
|
|
| 16 |
from rdkit import Chem, RDLogger
|
| 17 |
from rdkit.Chem import AllChem
|
| 18 |
from rdkit.DataStructs import ConvertToNumpyArray
|
|
@@ -78,3 +79,68 @@ def compute_morgan_fingerprint(
|
|
| 78 |
arr = np.zeros((n_bits,), dtype=np.uint8)
|
| 79 |
ConvertToNumpyArray(bit_vect, arr)
|
| 80 |
return arr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
import math
|
| 14 |
|
| 15 |
import numpy as np
|
| 16 |
+
import pandas as pd
|
| 17 |
from rdkit import Chem, RDLogger
|
| 18 |
from rdkit.Chem import AllChem
|
| 19 |
from rdkit.DataStructs import ConvertToNumpyArray
|
|
|
|
| 79 |
arr = np.zeros((n_bits,), dtype=np.uint8)
|
| 80 |
ConvertToNumpyArray(bit_vect, arr)
|
| 81 |
return arr
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def extract_features_from_dataframe(
    df: pd.DataFrame,
    smiles_col: str = "smiles",
    n_bits: int = 2048,
    radius: int = 2,
) -> pd.DataFrame:
    """Convert a DataFrame of (SMILES + metadata) into model-ready features.

    Steps:
        1. Validate every SMILES with `is_valid_smiles`. Invalid rows are
           logged at WARNING with their original index and dropped.
        2. Compute the Morgan fingerprint for each remaining SMILES.
        3. Expand the bit vector into `n_bits` integer columns named
           `fp_0 ... fp_{n_bits - 1}` and concatenate with the surviving
           non-SMILES metadata.

    If no row survives validation (empty input, or every SMILES invalid),
    the result is an empty DataFrame that still carries the metadata columns
    and all `n_bits` `fp_*` columns, so downstream code sees a stable schema.

    Args:
        df: Raw DataFrame; must contain `smiles_col`. It is not modified.
        smiles_col: Name of the SMILES column (default `"smiles"`).
        n_bits: Fingerprint length.
        radius: Morgan radius.

    Returns:
        A new DataFrame with the SMILES column dropped and `n_bits` new
        `fp_*` columns (dtype uint8) appended. Index is reset to 0..N-1.

    Raises:
        KeyError: if `smiles_col` is missing from `df`.
    """
    if smiles_col not in df.columns:
        raise KeyError(f"DataFrame is missing required column {smiles_col!r}")

    n_total = len(df)
    # On an empty column `Series.apply` may not infer a boolean dtype; cast
    # explicitly so that `~` and `.loc[...]` always treat this as a mask.
    valid_mask = df[smiles_col].apply(is_valid_smiles).astype(bool)
    n_invalid = int((~valid_mask).sum())

    if n_invalid:
        invalid_indices = df.index[~valid_mask].tolist()
        logger.warning(
            "Dropping %d/%d rows with invalid SMILES (indices=%s)",
            n_invalid, n_total, invalid_indices,
        )

    # Resetting here gives both halves of the final concat the same 0..N-1
    # index, so the column-wise concat aligns row-for-row.
    valid_df = df.loc[valid_mask].reset_index(drop=True)

    smiles_values = valid_df[smiles_col].tolist()
    if smiles_values:
        fingerprints = np.stack(
            [
                compute_morgan_fingerprint(s, n_bits=n_bits, radius=radius)
                for s in smiles_values
            ],
            axis=0,
        )
    else:
        # np.stack raises ValueError on an empty sequence; build the
        # zero-row matrix directly so the output keeps its fp_* columns.
        fingerprints = np.zeros((0, n_bits), dtype=np.uint8)

    fp_columns = [f"fp_{i}" for i in range(n_bits)]
    fp_df = pd.DataFrame(fingerprints, columns=fp_columns, dtype=np.uint8)

    metadata = valid_df.drop(columns=[smiles_col])
    out = pd.concat([metadata, fp_df], axis=1)

    logger.info(
        "Feature extraction complete: in=%d, out=%d, dropped=%d (%.2f%%)",
        # max(..., 1) avoids a ZeroDivisionError when the input is empty.
        n_total, len(out), n_invalid, 100.0 * n_invalid / max(n_total, 1),
    )
    return out
|
tests/pipelines/test_bbb_pipeline.py
CHANGED
|
@@ -9,6 +9,7 @@ import pytest
|
|
| 9 |
|
| 10 |
from src.pipelines.bbb_pipeline import (
|
| 11 |
compute_morgan_fingerprint,
|
|
|
|
| 12 |
is_valid_smiles,
|
| 13 |
)
|
| 14 |
|
|
@@ -56,3 +57,39 @@ class TestComputeMorganFingerprint:
|
|
| 56 |
def test_invalid_smiles_raises_value_error(self) -> None:
|
| 57 |
with pytest.raises(ValueError, match="invalid SMILES"):
|
| 58 |
compute_morgan_fingerprint("not_a_smiles", n_bits=2048, radius=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
from src.pipelines.bbb_pipeline import (
|
| 11 |
compute_morgan_fingerprint,
|
| 12 |
+
extract_features_from_dataframe,
|
| 13 |
is_valid_smiles,
|
| 14 |
)
|
| 15 |
|
|
|
|
| 57 |
def test_invalid_smiles_raises_value_error(self) -> None:
|
| 58 |
with pytest.raises(ValueError, match="invalid SMILES"):
|
| 59 |
compute_morgan_fingerprint("not_a_smiles", n_bits=2048, radius=2)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class TestExtractFeaturesFromDataFrame:
    """Checks `extract_features_from_dataframe` against the CSV fixture."""

    @staticmethod
    def _extract() -> pd.DataFrame:
        """Load the fixture and run extraction with the shared test settings."""
        frame = pd.read_csv(FIXTURE)
        return extract_features_from_dataframe(
            frame, smiles_col="smiles", n_bits=128, radius=2
        )

    def test_filters_invalid_smiles(self) -> None:
        """Rows whose SMILES fail validation are removed from the output."""
        frame = pd.read_csv(FIXTURE)
        # Sanity: fixture contains 6 rows total, 2 are invalid by construction.
        assert len(frame) == 6

        result = extract_features_from_dataframe(
            frame, smiles_col="smiles", n_bits=128, radius=2
        )

        # Only the 4 chemically valid rows should remain.
        assert len(result) == 4

    def test_preserves_label_column(self) -> None:
        """Non-SMILES metadata such as the label survives extraction."""
        result = self._extract()
        assert "p_np" in result.columns

    def test_expands_fingerprint_into_named_columns(self) -> None:
        """The fingerprint becomes one `fp_*` column per bit."""
        result = self._extract()
        bit_columns = [name for name in result.columns if name.startswith("fp_")]
        assert len(bit_columns) == 128
        # All FP columns must be 0/1 integers.
        assert result[bit_columns].isin([0, 1]).to_numpy().all()

    def test_drops_smiles_string_after_expansion(self) -> None:
        """Once expanded to bits, the original SMILES string adds no signal."""
        result = self._extract()
        assert "smiles" not in result.columns

    def test_resets_index(self) -> None:
        """The output index is a fresh 0..N-1 range after invalid rows drop."""
        result = self._extract()
        assert result.index.tolist() == list(range(len(result)))
|