ribesstefano commited on
Commit
dd628da
Β·
1 Parent(s): 3822c8b

Initial app version

Browse files
Files changed (2) hide show
  1. app.py +468 -0
  2. requirements.txt +13 -0
app.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TACK Demo β€” PROTAC/degrader activity prediction via ensemble models."""
2
+ import os
3
+ import logging
4
+ import tempfile
5
+ import warnings
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ import pandas as pd
9
+ import gradio as gr
10
+
11
+ warnings.filterwarnings("ignore")
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # Constants
17
+ # ---------------------------------------------------------------------------
18
+ TASK_LABELS = {
19
+ "bin": "Binary Activity (prob.)",
20
+ "dc50": "DC50 (nM)",
21
+ "dmax": "Dmax (%)",
22
+ }
23
+ RESULT_COLS = [
24
+ "Task", "Prediction", "Uncertainty (Β±std)",
25
+ "CI 95% Low", "CI 95% High", "n_models",
26
+ ]
27
+ EXAMPLE_SMILES = (
28
+ "CC1(C)[C@H](NC(=O)c2ccc(N3CCN(CCCOc4ccc(C(=O)NC5CCC(=O)NC5=O)"
29
+ "nc4)CC3)nc2)C(C)(C)[C@H]1Oc1ccc(C#N)c(Cl)c1"
30
+ )
31
+
32
+ # ---------------------------------------------------------------------------
33
+ # Model loading at startup
34
+ # ---------------------------------------------------------------------------
35
+ def _load_predictors() -> Dict:
36
+ """Download models from HF Hub and return EnsemblePredictor instances."""
37
+ try:
38
+ from huggingface_hub import snapshot_download
39
+ from tackai.ensemble_predictor import EnsemblePredictor
40
+ except ImportError as exc:
41
+ logger.error("Missing dependency: %s", exc)
42
+ return {}
43
+
44
+ try:
45
+ cache_dir = snapshot_download(repo_id="ailab-bio/tack-cache")
46
+ os.environ.setdefault("TACK_CACHE_DIR", cache_dir)
47
+ logger.info("Cache downloaded to %s", cache_dir)
48
+ except Exception as exc:
49
+ logger.warning("Cache repo unavailable: %s", exc)
50
+
51
+ # Each repo stores the ensemble under a task-named subfolder.
52
+ repo_subfolders = {
53
+ "bin": ("ailab-bio/TACK-Model-Bin", "bin_best_arch_ensemble"),
54
+ # "dc50": ("ailab-bio/TACK-Model-DC50", "dc50_best_arch_ensemble"),
55
+ # "dmax": ("ailab-bio/TACK-Model-Dmax", "dmax_best_arch_ensemble"),
56
+ }
57
+ loaded: Dict = {}
58
+ for task, (repo_id, subfolder) in repo_subfolders.items():
59
+ try:
60
+ repo_dir = snapshot_download(repo_id=repo_id)
61
+ model_dir = os.path.join(repo_dir, subfolder)
62
+ loaded[task] = EnsemblePredictor.from_directory(
63
+ model_dir, device="cpu"
64
+ )
65
+ logger.info(
66
+ "Loaded '%s' predictor from %s (%d models).",
67
+ task,
68
+ model_dir,
69
+ len(loaded[task].models),
70
+ )
71
+ except Exception as exc:
72
+ logger.warning("Could not load predictor '%s': %s", task, exc)
73
+ return loaded
74
+
75
+
76
+ PREDICTORS: Dict = _load_predictors()
77
+ AVAILABLE_TASKS: List[str] = list(PREDICTORS.keys())
78
+
79
+
80
+ # ---------------------------------------------------------------------------
81
+ # Sample construction helpers
82
+ # ---------------------------------------------------------------------------
83
+ def _make_sample(
84
+ smiles: str,
85
+ poi_name: str,
86
+ poi_sequence: str,
87
+ ligase_name: str,
88
+ ligase_sequence: str,
89
+ cell_line: str,
90
+ treatment_time: float,
91
+ assay_type: str,
92
+ ) -> "SampleInput":
93
+ from tackai.ensemble_predictor import SampleInput
94
+ return SampleInput(
95
+ smiles=smiles.strip() if smiles else None,
96
+ poi_name=poi_name.strip() if poi_name else None,
97
+ poi_sequence=poi_sequence.strip() if poi_sequence else None,
98
+ ligase_name=ligase_name.strip() if ligase_name else None,
99
+ ligase_sequence=ligase_sequence.strip() if ligase_sequence else None,
100
+ cell_line=(cell_line or "Unknown").strip(),
101
+ assay_type=(assay_type or "Unknown").strip(),
102
+ treatment_time=float(treatment_time) if treatment_time else 24.0,
103
+ )
104
+
105
+
106
+ def _sample_from_row(row: pd.Series) -> "SampleInput":
107
+ from tackai.ensemble_predictor import SampleInput
108
+
109
+ def get_str(*keys: str) -> Optional[str]:
110
+ for k in keys:
111
+ v = row.get(k)
112
+ if v is not None and pd.notna(v) and str(v).strip():
113
+ return str(v).strip()
114
+ return None
115
+
116
+ def get_float(*keys: str) -> Optional[float]:
117
+ for k in keys:
118
+ v = row.get(k)
119
+ if v is not None and pd.notna(v):
120
+ try:
121
+ return float(v)
122
+ except (ValueError, TypeError):
123
+ pass
124
+ return None
125
+
126
+ return SampleInput(
127
+ smiles=get_str("SMILES", "smiles"),
128
+ poi_name=get_str("POI_Name", "poi_name"),
129
+ poi_sequence=get_str("POI_Sequence", "poi_sequence"),
130
+ ligase_name=get_str("Ligase_Name", "ligase_name"),
131
+ ligase_sequence=get_str("Ligase_Sequence", "ligase_sequence"),
132
+ cell_line=get_str("Cell_Line", "cell_line") or "Unknown",
133
+ assay_type=get_str("Assay", "assay_type") or "Unknown",
134
+ treatment_time=get_float(
135
+ "Assay_Time", "assay_time", "treatment_time"
136
+ ) or 24.0,
137
+ )
138
+
139
+
140
+ def _result_row(result: "EnsemblePrediction", task: str) -> Dict:
141
+ return {
142
+ "Task": TASK_LABELS.get(task, task.upper()),
143
+ "Prediction": round(float(result.weighted_mean[0]), 4),
144
+ "Uncertainty (Β±std)": round(float(result.uncertainty_std[0]), 4),
145
+ "CI 95% Low": round(float(result.ci_percentile_lower_95[0]), 4),
146
+ "CI 95% High": round(float(result.ci_percentile_upper_95[0]), 4),
147
+ "n_models": len(result.model_names),
148
+ }
149
+
150
+
151
+ # ---------------------------------------------------------------------------
152
+ # Prediction callbacks
153
+ # ---------------------------------------------------------------------------
154
+ def run_single_prediction(
155
+ smiles: str,
156
+ poi_name: str,
157
+ poi_sequence: str,
158
+ ligase_name: str,
159
+ ligase_sequence: str,
160
+ cell_line: str,
161
+ treatment_time: float,
162
+ assay_type: str,
163
+ selected_tasks: List[str],
164
+ ) -> Tuple[pd.DataFrame, str]:
165
+ if not smiles or not smiles.strip():
166
+ return pd.DataFrame(columns=RESULT_COLS), "Please enter a SMILES string."
167
+ if not selected_tasks:
168
+ return pd.DataFrame(columns=RESULT_COLS), "Please select at least one task."
169
+ if not PREDICTORS:
170
+ return pd.DataFrame(columns=RESULT_COLS), (
171
+ "No models loaded β€” ensure the HF repositories are available."
172
+ )
173
+
174
+ sample = _make_sample(
175
+ smiles, poi_name, poi_sequence,
176
+ ligase_name, ligase_sequence,
177
+ cell_line, treatment_time, assay_type,
178
+ )
179
+ rows = []
180
+ for task in selected_tasks:
181
+ if task not in PREDICTORS:
182
+ continue
183
+ try:
184
+ task_dict = PREDICTORS[task].predict_batch(
185
+ [sample], tasks=[task]
186
+ )[0]
187
+ if task_dict and task in task_dict:
188
+ rows.append(_result_row(task_dict[task], task))
189
+ except Exception as exc:
190
+ logger.error("Single prediction error (task=%s): %s", task, exc)
191
+ rows.append({
192
+ "Task": TASK_LABELS.get(task, task.upper()),
193
+ "Prediction": "ERROR",
194
+ "Uncertainty (Β±std)": "β€”",
195
+ "CI 95% Low": "β€”",
196
+ "CI 95% High": "β€”",
197
+ "n_models": 0,
198
+ })
199
+
200
+ if not rows:
201
+ return pd.DataFrame(columns=RESULT_COLS), "No predictions returned."
202
+ return pd.DataFrame(rows), ""
203
+
204
+
205
+ def load_csv_preview(filepath: Optional[str]) -> pd.DataFrame:
206
+ if not filepath:
207
+ return pd.DataFrame()
208
+ try:
209
+ return pd.read_csv(filepath).head(5)
210
+ except Exception as exc:
211
+ logger.error("CSV preview error: %s", exc)
212
+ return pd.DataFrame()
213
+
214
+
215
+ def run_batch_prediction(
216
+ filepath: Optional[str],
217
+ selected_tasks: List[str],
218
+ ) -> Tuple[pd.DataFrame, Optional[str], str]:
219
+ """Run batch predictions; returns (results_df, download_path, message)."""
220
+ if not filepath:
221
+ return pd.DataFrame(), None, "Please upload a CSV file."
222
+ if not selected_tasks:
223
+ return pd.DataFrame(), None, "Please select at least one task."
224
+ if not PREDICTORS:
225
+ return pd.DataFrame(), None, (
226
+ "No models loaded β€” ensure the HF repositories are available."
227
+ )
228
+
229
+ try:
230
+ df = pd.read_csv(filepath)
231
+ except Exception as exc:
232
+ return pd.DataFrame(), None, f"Error reading CSV: {exc}"
233
+
234
+ if df.empty:
235
+ return pd.DataFrame(), None, "Uploaded CSV is empty."
236
+ if "SMILES" not in df.columns and "smiles" not in df.columns:
237
+ return pd.DataFrame(), None, "CSV must contain a 'SMILES' column."
238
+
239
+ samples = [_sample_from_row(row) for _, row in df.iterrows()]
240
+ result_rows = []
241
+
242
+ for task in selected_tasks:
243
+ if task not in PREDICTORS:
244
+ continue
245
+ try:
246
+ batch = PREDICTORS[task].predict_batch(samples, tasks=[task])
247
+ for i, task_dict in enumerate(batch):
248
+ base = {"#": i + 1, "SMILES": samples[i].smiles or ""}
249
+ if task_dict and task in task_dict:
250
+ result_rows.append({
251
+ **base,
252
+ **_result_row(task_dict[task], task),
253
+ })
254
+ else:
255
+ result_rows.append({
256
+ **base,
257
+ "Task": TASK_LABELS.get(task, task.upper()),
258
+ "Prediction": "ERROR",
259
+ })
260
+ except Exception as exc:
261
+ logger.error(
262
+ "Batch prediction error (task=%s): %s", task, exc
263
+ )
264
+ return pd.DataFrame(), None, f"Prediction failed: {exc}"
265
+
266
+ if not result_rows:
267
+ return pd.DataFrame(), None, "No predictions returned."
268
+
269
+ results_df = pd.DataFrame(result_rows)
270
+ tmp = tempfile.NamedTemporaryFile(
271
+ delete=False, suffix=".csv", prefix="tack_results_",
272
+ )
273
+ results_df.to_csv(tmp.name, index=False)
274
+ return results_df, tmp.name, ""
275
+
276
+
277
+ def get_template_csv() -> str:
278
+ """Write a CSV template to a temp file and return its path."""
279
+ df = pd.DataFrame([{
280
+ "SMILES": EXAMPLE_SMILES,
281
+ "POI_Name": "AR",
282
+ "POI_Sequence": "",
283
+ "Ligase_Name": "CRBN",
284
+ "Ligase_Sequence": "",
285
+ "Cell_Line": "Unknown",
286
+ "Assay_Time": 24.0,
287
+ "Assay": "Unknown",
288
+ }])
289
+ tmp = tempfile.NamedTemporaryFile(
290
+ delete=False, suffix=".csv", prefix="tack_template_",
291
+ )
292
+ df.to_csv(tmp.name, index=False)
293
+ return tmp.name
294
+
295
+
296
+ # ---------------------------------------------------------------------------
297
+ # Gradio UI
298
+ # ---------------------------------------------------------------------------
299
+ _TASK_CHOICES = [
300
+ (TASK_LABELS["bin"], "bin"),
301
+ (TASK_LABELS["dc50"], "dc50"),
302
+ (TASK_LABELS["dmax"], "dmax"),
303
+ ]
304
+ _DEFAULT_TASKS = AVAILABLE_TASKS if AVAILABLE_TASKS else ["bin"]
305
+ _NO_MODELS_BANNER = (
306
+ "> ⚠️ **No models loaded.** Ensure the HF repositories "
307
+ "(`ailab-bio/tack-model-*`) are publicly available and try again."
308
+ if not PREDICTORS
309
+ else ""
310
+ )
311
+
312
+ with gr.Blocks(
313
+ theme=gr.themes.Soft(primary_hue="blue"),
314
+ title="TACK β€” PROTAC Activity Predictor",
315
+ ) as demo:
316
+
317
+ gr.Markdown("# TACK β€” PROTAC Activity Predictor")
318
+ gr.Markdown(
319
+ "Predict PROTAC/degrader activity using an ensemble of XGBoost "
320
+ "and MLP models trained on PROTAC-DB, TPD-DB, and PROTAC-Pedia. "
321
+ "Outputs binary activity probability, DC50 (nM), and Dmax (%) "
322
+ "with full uncertainty quantification."
323
+ )
324
+ if _NO_MODELS_BANNER:
325
+ gr.Markdown(_NO_MODELS_BANNER)
326
+
327
+ task_selector = gr.CheckboxGroup(
328
+ choices=_TASK_CHOICES,
329
+ value=_DEFAULT_TASKS,
330
+ label="Prediction Task(s)",
331
+ info="Select one or more properties to predict.",
332
+ )
333
+
334
+ with gr.Tabs():
335
+
336
+ # ── Tab 1: Single compound ─────────────────────────────────────────
337
+ with gr.Tab("Single Compound"):
338
+ with gr.Row():
339
+ with gr.Column(scale=1):
340
+ smiles_in = gr.Textbox(
341
+ label="SMILES *",
342
+ placeholder="Paste a SMILES string…",
343
+ lines=2,
344
+ value=EXAMPLE_SMILES,
345
+ )
346
+ with gr.Accordion("POI (Protein of Interest)", open=False):
347
+ poi_name_in = gr.Textbox(
348
+ label="POI Name",
349
+ placeholder="e.g. AR, BRD4, SMARCA2",
350
+ )
351
+ poi_seq_in = gr.Textbox(
352
+ label="Amino Acid Sequence",
353
+ placeholder=(
354
+ "Paste full sequence (no FASTA header)…"
355
+ ),
356
+ lines=5,
357
+ )
358
+ with gr.Accordion("E3 Ligase", open=False):
359
+ ligase_name_in = gr.Textbox(
360
+ label="E3 Ligase Name",
361
+ placeholder="e.g. CRBN, VHL, MDM2",
362
+ value="CRBN",
363
+ )
364
+ ligase_seq_in = gr.Textbox(
365
+ label="Amino Acid Sequence",
366
+ placeholder=(
367
+ "Paste full sequence (no FASTA header)…"
368
+ ),
369
+ lines=5,
370
+ )
371
+ with gr.Row():
372
+ cell_line_in = gr.Textbox(
373
+ label="Cell Line",
374
+ value="Unknown",
375
+ placeholder="e.g. HEK293, Jurkat",
376
+ )
377
+ treatment_time_in = gr.Number(
378
+ label="Treatment Time (h)",
379
+ value=24.0,
380
+ minimum=0.0,
381
+ )
382
+ assay_type_in = gr.Textbox(
383
+ label="Assay Type",
384
+ value="Unknown",
385
+ placeholder="e.g. Western, FACS",
386
+ )
387
+ predict_single_btn = gr.Button(
388
+ "Predict", variant="primary", size="lg",
389
+ )
390
+
391
+ with gr.Column(scale=1):
392
+ single_msg = gr.Markdown()
393
+ single_results = gr.Dataframe(
394
+ headers=RESULT_COLS,
395
+ label="Results",
396
+ interactive=False,
397
+ wrap=True,
398
+ )
399
+
400
+ predict_single_btn.click(
401
+ fn=run_single_prediction,
402
+ inputs=[
403
+ smiles_in, poi_name_in, poi_seq_in,
404
+ ligase_name_in, ligase_seq_in,
405
+ cell_line_in, treatment_time_in, assay_type_in,
406
+ task_selector,
407
+ ],
408
+ outputs=[single_results, single_msg],
409
+ )
410
+
411
+ # ── Tab 2: Batch (CSV) ─────────────────────────────────────────────
412
+ with gr.Tab("Batch (CSV)"):
413
+ gr.Markdown(
414
+ "Upload a CSV with a **SMILES** column (required). "
415
+ "Optional columns: `POI_Name`, `POI_Sequence`, "
416
+ "`Ligase_Name`, `Ligase_Sequence`, `Cell_Line`, "
417
+ "`Assay_Time`, `Assay`."
418
+ )
419
+ with gr.Row():
420
+ csv_upload = gr.File(
421
+ label="Upload CSV",
422
+ file_types=[".csv"],
423
+ type="filepath",
424
+ )
425
+ with gr.Column():
426
+ template_btn = gr.Button(
427
+ "Get Template CSV", size="sm",
428
+ )
429
+ template_out = gr.File(
430
+ label="Template",
431
+ interactive=False,
432
+ visible=True,
433
+ )
434
+ template_btn.click(fn=get_template_csv, outputs=template_out)
435
+
436
+ csv_preview = gr.Dataframe(
437
+ label="CSV Preview (first 5 rows)",
438
+ interactive=False,
439
+ )
440
+ csv_upload.change(
441
+ fn=load_csv_preview,
442
+ inputs=csv_upload,
443
+ outputs=csv_preview,
444
+ )
445
+
446
+ batch_predict_btn = gr.Button(
447
+ "Run Batch Prediction", variant="primary", size="lg",
448
+ )
449
+ batch_msg = gr.Markdown()
450
+ batch_results = gr.Dataframe(
451
+ label="Batch Results",
452
+ interactive=False,
453
+ wrap=True,
454
+ )
455
+ batch_download = gr.File(
456
+ label="Download Results CSV",
457
+ interactive=False,
458
+ )
459
+
460
+ batch_predict_btn.click(
461
+ fn=run_batch_prediction,
462
+ inputs=[csv_upload, task_selector],
463
+ outputs=[batch_results, batch_download, batch_msg],
464
+ )
465
+
466
+
467
+ if __name__ == "__main__":
468
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # requirements.txt for the HF Space only
2
+ gradio
3
+ huggingface_hub
4
+ datasets
5
+ rdkit
6
+ scikit-learn
7
+ xgboost
8
+ numpy
9
+ pandas
10
+ joblib
11
+ torch
12
+ lightning
13
+ tackai # install from the TACK package once published