mekosotto Claude Sonnet 4.6 committed on
Commit
049a352
·
1 Parent(s): b08a67c

fix(bbb): guard empty input; truncate index log; add KeyError + log tests

Browse files
src/pipelines/bbb_pipeline.py CHANGED
@@ -81,6 +81,25 @@ def compute_morgan_fingerprint(
81
  return arr
82
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  def extract_features_from_dataframe(
85
  df: pd.DataFrame,
86
  smiles_col: str = "smiles",
@@ -97,6 +116,10 @@ def extract_features_from_dataframe(
97
  `fp_0 ... fp_{n_bits - 1}` and concatenate with the surviving
98
  non-SMILES metadata.
99
 
 
 
 
 
100
  Args:
101
  df: Raw DataFrame; must contain `smiles_col`.
102
  smiles_col: Name of the SMILES column (default `"smiles"`).
@@ -113,30 +136,40 @@ def extract_features_from_dataframe(
113
  if smiles_col not in df.columns:
114
  raise KeyError(f"DataFrame is missing required column {smiles_col!r}")
115
 
 
 
 
116
  n_total = len(df)
117
  valid_mask = df[smiles_col].apply(is_valid_smiles)
118
  n_invalid = int((~valid_mask).sum())
119
 
120
  if n_invalid:
121
  invalid_indices = df.index[~valid_mask].tolist()
 
 
 
 
 
 
122
  logger.warning(
123
- "Dropping %d/%d rows with invalid SMILES (indices=%s)",
124
- n_invalid, n_total, invalid_indices,
125
  )
126
 
127
  valid_df = df.loc[valid_mask].reset_index(drop=True)
128
 
129
- fingerprints = np.stack(
130
- [
131
- compute_morgan_fingerprint(s, n_bits=n_bits, radius=radius)
132
- for s in valid_df[smiles_col].tolist()
133
- ],
134
- axis=0,
 
 
 
135
  )
136
- fp_columns = [f"fp_{i}" for i in range(n_bits)]
137
  fp_df = pd.DataFrame(fingerprints, columns=fp_columns, dtype=np.uint8)
138
-
139
- metadata = valid_df.drop(columns=[smiles_col]).reset_index(drop=True)
140
  out = pd.concat([metadata, fp_df], axis=1)
141
 
142
  logger.info(
 
81
  return arr
82
 
83
 
84
+ def _compute_fingerprint_matrix(
85
+ valid_smiles: list[str],
86
+ n_bits: int,
87
+ radius: int,
88
+ ) -> np.ndarray:
89
+ """Stack Morgan fingerprints into a (N, n_bits) uint8 matrix.
90
+
91
+ Caller must guarantee `valid_smiles` is non-empty and every entry has
92
+ already passed `is_valid_smiles`.
93
+ """
94
+ return np.stack(
95
+ [
96
+ compute_morgan_fingerprint(s, n_bits=n_bits, radius=radius)
97
+ for s in valid_smiles
98
+ ],
99
+ axis=0,
100
+ )
101
+
102
+
103
  def extract_features_from_dataframe(
104
  df: pd.DataFrame,
105
  smiles_col: str = "smiles",
 
116
  `fp_0 ... fp_{n_bits - 1}` and concatenate with the surviving
117
  non-SMILES metadata.
118
 
119
+ On empty input or when every row is invalid, returns a DataFrame with
120
+ the expected columns and zero rows (rather than raising), so callers
121
+ downstream see a well-typed result instead of an exception.
122
+
123
  Args:
124
  df: Raw DataFrame; must contain `smiles_col`.
125
  smiles_col: Name of the SMILES column (default `"smiles"`).
 
136
  if smiles_col not in df.columns:
137
  raise KeyError(f"DataFrame is missing required column {smiles_col!r}")
138
 
139
+ fp_columns = [f"fp_{i}" for i in range(n_bits)]
140
+ metadata_columns = [c for c in df.columns if c != smiles_col]
141
+
142
  n_total = len(df)
143
  valid_mask = df[smiles_col].apply(is_valid_smiles)
144
  n_invalid = int((~valid_mask).sum())
145
 
146
  if n_invalid:
147
  invalid_indices = df.index[~valid_mask].tolist()
148
+ display = invalid_indices[:10]
149
+ suffix = (
150
+ f"... (+{len(invalid_indices) - 10} more)"
151
+ if len(invalid_indices) > 10
152
+ else ""
153
+ )
154
  logger.warning(
155
+ "Dropping %d/%d rows with invalid SMILES (indices=%s%s)",
156
+ n_invalid, n_total, display, suffix,
157
  )
158
 
159
  valid_df = df.loc[valid_mask].reset_index(drop=True)
160
 
161
+ if len(valid_df) == 0:
162
+ logger.info(
163
+ "Feature extraction complete: in=%d, out=0, dropped=%d (%.2f%%)",
164
+ n_total, n_invalid, 100.0 * n_invalid / max(n_total, 1),
165
+ )
166
+ return pd.DataFrame(columns=metadata_columns + fp_columns)
167
+
168
+ fingerprints = _compute_fingerprint_matrix(
169
+ valid_df[smiles_col].tolist(), n_bits=n_bits, radius=radius,
170
  )
 
171
  fp_df = pd.DataFrame(fingerprints, columns=fp_columns, dtype=np.uint8)
172
+ metadata = valid_df.drop(columns=[smiles_col])
 
173
  out = pd.concat([metadata, fp_df], axis=1)
174
 
175
  logger.info(
tests/pipelines/test_bbb_pipeline.py CHANGED
@@ -93,3 +93,46 @@ class TestExtractFeaturesFromDataFrame:
93
  raw = pd.read_csv(FIXTURE)
94
  features = extract_features_from_dataframe(raw, smiles_col="smiles", n_bits=128, radius=2)
95
  assert list(features.index) == list(range(len(features)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  raw = pd.read_csv(FIXTURE)
94
  features = extract_features_from_dataframe(raw, smiles_col="smiles", n_bits=128, radius=2)
95
  assert list(features.index) == list(range(len(features)))
96
+
97
+ def test_raises_key_error_on_missing_smiles_col(self) -> None:
98
+ df = pd.DataFrame({"foo": [1, 2, 3]})
99
+ with pytest.raises(KeyError, match="missing required column 'smiles'"):
100
+ extract_features_from_dataframe(df, smiles_col="smiles", n_bits=64)
101
+
102
+ def test_returns_empty_dataframe_when_all_invalid(self) -> None:
103
+ """All-invalid input must produce a typed empty result, not crash."""
104
+ df = pd.DataFrame(
105
+ {
106
+ "p_np": [0, 0],
107
+ "smiles": ["", "still_garbage"],
108
+ }
109
+ )
110
+ out = extract_features_from_dataframe(df, smiles_col="smiles", n_bits=32)
111
+ assert len(out) == 0
112
+ assert "p_np" in out.columns
113
+ assert sum(c.startswith("fp_") for c in out.columns) == 32
114
+ assert "smiles" not in out.columns
115
+
116
+ def test_emits_warning_and_info_logs(self) -> None:
117
+ """AGENTS.md §4 traceability: log invalid drops + in/out/dropped counts."""
118
+ import io
119
+ import logging
120
+
121
+ from src.core.logger import get_logger
122
+ from src.pipelines import bbb_pipeline as mod
123
+
124
+ # Swap the module logger's stream so we can capture output.
125
+ logger = get_logger(mod.__name__, level=logging.INFO)
126
+ handler = logger.handlers[0]
127
+ buf = io.StringIO()
128
+ original_stream = handler.stream
129
+ handler.stream = buf
130
+ try:
131
+ df = pd.read_csv(FIXTURE)
132
+ extract_features_from_dataframe(df, smiles_col="smiles", n_bits=32)
133
+ finally:
134
+ handler.stream = original_stream
135
+
136
+ output = buf.getvalue()
137
+ assert "Dropping 2/6 rows with invalid SMILES" in output
138
+ assert "Feature extraction complete: in=6, out=4, dropped=2" in output