mekosotto Claude Sonnet 4.6 commited on
Commit
b08a67c
·
1 Parent(s): 80528e7

feat(bbb): expand SMILES → Morgan FP into model-ready DataFrame with drift logging

Browse files
src/pipelines/bbb_pipeline.py CHANGED
@@ -13,6 +13,7 @@ from __future__ import annotations
13
  import math
14
 
15
  import numpy as np
 
16
  from rdkit import Chem, RDLogger
17
  from rdkit.Chem import AllChem
18
  from rdkit.DataStructs import ConvertToNumpyArray
@@ -78,3 +79,68 @@ def compute_morgan_fingerprint(
78
  arr = np.zeros((n_bits,), dtype=np.uint8)
79
  ConvertToNumpyArray(bit_vect, arr)
80
  return arr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  import math
14
 
15
  import numpy as np
16
+ import pandas as pd
17
  from rdkit import Chem, RDLogger
18
  from rdkit.Chem import AllChem
19
  from rdkit.DataStructs import ConvertToNumpyArray
 
79
  arr = np.zeros((n_bits,), dtype=np.uint8)
80
  ConvertToNumpyArray(bit_vect, arr)
81
  return arr
82
+
83
+
84
+ def extract_features_from_dataframe(
85
+ df: pd.DataFrame,
86
+ smiles_col: str = "smiles",
87
+ n_bits: int = 2048,
88
+ radius: int = 2,
89
+ ) -> pd.DataFrame:
90
+ """Convert a DataFrame of (SMILES + metadata) into model-ready features.
91
+
92
+ Steps:
93
+ 1. Validate every SMILES with `is_valid_smiles`. Invalid rows are
94
+ logged at WARNING with their original index and dropped.
95
+ 2. Compute the Morgan fingerprint for each remaining SMILES.
96
+ 3. Expand the bit vector into `n_bits` integer columns named
97
+ `fp_0 ... fp_{n_bits - 1}` and concatenate with the surviving
98
+ non-SMILES metadata.
99
+
100
+ Args:
101
+ df: Raw DataFrame; must contain `smiles_col`.
102
+ smiles_col: Name of the SMILES column (default `"smiles"`).
103
+ n_bits: Fingerprint length.
104
+ radius: Morgan radius.
105
+
106
+ Returns:
107
+ A new DataFrame with the SMILES column dropped and `n_bits` new
108
+ `fp_*` columns appended. Index is reset to 0..N-1.
109
+
110
+ Raises:
111
+ KeyError: if `smiles_col` is missing from `df`.
112
+ """
113
+ if smiles_col not in df.columns:
114
+ raise KeyError(f"DataFrame is missing required column {smiles_col!r}")
115
+
116
+ n_total = len(df)
117
+ valid_mask = df[smiles_col].apply(is_valid_smiles)
118
+ n_invalid = int((~valid_mask).sum())
119
+
120
+ if n_invalid:
121
+ invalid_indices = df.index[~valid_mask].tolist()
122
+ logger.warning(
123
+ "Dropping %d/%d rows with invalid SMILES (indices=%s)",
124
+ n_invalid, n_total, invalid_indices,
125
+ )
126
+
127
+ valid_df = df.loc[valid_mask].reset_index(drop=True)
128
+
129
+ fingerprints = np.stack(
130
+ [
131
+ compute_morgan_fingerprint(s, n_bits=n_bits, radius=radius)
132
+ for s in valid_df[smiles_col].tolist()
133
+ ],
134
+ axis=0,
135
+ )
136
+ fp_columns = [f"fp_{i}" for i in range(n_bits)]
137
+ fp_df = pd.DataFrame(fingerprints, columns=fp_columns, dtype=np.uint8)
138
+
139
+ metadata = valid_df.drop(columns=[smiles_col]).reset_index(drop=True)
140
+ out = pd.concat([metadata, fp_df], axis=1)
141
+
142
+ logger.info(
143
+ "Feature extraction complete: in=%d, out=%d, dropped=%d (%.2f%%)",
144
+ n_total, len(out), n_invalid, 100.0 * n_invalid / max(n_total, 1),
145
+ )
146
+ return out
tests/pipelines/test_bbb_pipeline.py CHANGED
@@ -9,6 +9,7 @@ import pytest
9
 
10
  from src.pipelines.bbb_pipeline import (
11
  compute_morgan_fingerprint,
 
12
  is_valid_smiles,
13
  )
14
 
@@ -56,3 +57,39 @@ class TestComputeMorganFingerprint:
56
  def test_invalid_smiles_raises_value_error(self) -> None:
57
  with pytest.raises(ValueError, match="invalid SMILES"):
58
  compute_morgan_fingerprint("not_a_smiles", n_bits=2048, radius=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  from src.pipelines.bbb_pipeline import (
11
  compute_morgan_fingerprint,
12
+ extract_features_from_dataframe,
13
  is_valid_smiles,
14
  )
15
 
 
57
  def test_invalid_smiles_raises_value_error(self) -> None:
58
  with pytest.raises(ValueError, match="invalid SMILES"):
59
  compute_morgan_fingerprint("not_a_smiles", n_bits=2048, radius=2)
60
+
61
+
62
+ class TestExtractFeaturesFromDataFrame:
63
+ def test_filters_invalid_smiles(self) -> None:
64
+ raw = pd.read_csv(FIXTURE)
65
+ # Sanity: fixture contains 6 rows total, 2 are invalid by construction.
66
+ assert len(raw) == 6
67
+
68
+ features = extract_features_from_dataframe(raw, smiles_col="smiles", n_bits=128, radius=2)
69
+
70
+ # Only the 4 chemically valid rows should remain.
71
+ assert len(features) == 4
72
+
73
+ def test_preserves_label_column(self) -> None:
74
+ raw = pd.read_csv(FIXTURE)
75
+ features = extract_features_from_dataframe(raw, smiles_col="smiles", n_bits=128, radius=2)
76
+ assert "p_np" in features.columns
77
+
78
+ def test_expands_fingerprint_into_named_columns(self) -> None:
79
+ raw = pd.read_csv(FIXTURE)
80
+ features = extract_features_from_dataframe(raw, smiles_col="smiles", n_bits=128, radius=2)
81
+ fp_cols = [c for c in features.columns if c.startswith("fp_")]
82
+ assert len(fp_cols) == 128
83
+ # All FP columns must be 0/1 integers.
84
+ assert features[fp_cols].isin([0, 1]).all().all()
85
+
86
+ def test_drops_smiles_string_after_expansion(self) -> None:
87
+ """Once expanded to bits, the original SMILES string adds no signal."""
88
+ raw = pd.read_csv(FIXTURE)
89
+ features = extract_features_from_dataframe(raw, smiles_col="smiles", n_bits=128, radius=2)
90
+ assert "smiles" not in features.columns
91
+
92
+ def test_resets_index(self) -> None:
93
+ raw = pd.read_csv(FIXTURE)
94
+ features = extract_features_from_dataframe(raw, smiles_col="smiles", n_bits=128, radius=2)
95
+ assert list(features.index) == list(range(len(features)))