"""TACK Demo — PROTAC/degrader activity prediction via ensemble models."""
| import os |
| import logging |
| import tempfile |
| import warnings |
| from typing import Dict, List, Optional, Tuple |
|
|
| import pandas as pd |
| import gradio as gr |
|
|
# Install a blanket "ignore" filter: every warning category is silenced for
# the whole process (presumably to keep library warnings out of the demo logs).
warnings.filterwarnings("ignore")
# Configure the root logger at INFO level.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by the functions in this file.
logger = logging.getLogger(__name__)
|
|
| |
| |
| |
# Display label for each prediction task key.
TASK_LABELS = dict(
    bin="Binary Activity (prob.)",
    dc50="DC50 (nM)",
    dmax="Dmax (%)",
)

# Column names of the results table, in display order.
# NOTE(review): "Β±" looks like a mis-decoded "±". It is left untouched here
# because other code in this file uses the exact same text as a dict key.
RESULT_COLS = [
    "Task",
    "Prediction",
    "Uncertainty (Β±std)",
    "CI 95% Low",
    "CI 95% High",
    "n_models",
]

# Example degrader SMILES strings. This is a tuple of three entries, not a
# single string.
EXAMPLE_SMILES = (
    "Cn1c(=O)n(C2CCC(=O)NC2=O)c2cccc(C#CCCN3CCC4(CC3)CC(n3cc(NC(=O)c5cnn6ccc(N7C[C@H]8C[C@@H]7CO8)nc56)c(C(F)F)n3)C4)c21",
    "Cc1ncsc1C1=CCC([C@H](C)NC(=O)[C@@H]2C[C@@H](O)CN2C(=O)C(c2cc(N3CCC(CN4CCC(n5nc(NC(C)C)c6nnc(-c7cccc(F)c7O)cc65)CC4)CC3)no2)C(C)C)C=C1",
    "COc1ccccc1C(=O)NCc1ccc(-c2nn3c(c2C(N)=O)Nc2ccc(N4CCN(CC5CCN(c6ccc(N7CCC(=O)NC7=O)cc6)CC5)CC4)cc2CC3)c(OC)c1",
)
|
|
| |
| |
| |
def _load_predictors() -> Dict:
    """Fetch the model repositories from the HF Hub and build the predictors.

    Returns:
        A dict mapping task key ("bin", "dc50", "dmax") to the
        ``EnsemblePredictor`` loaded for it. A task whose download or loading
        raises is logged and left out. An empty dict is returned when
        ``huggingface_hub`` or ``tackai`` cannot be imported.
    """
    # Imported lazily so a missing dependency is logged and yields an empty
    # result instead of raising at module import.
    try:
        from huggingface_hub import snapshot_download
        from tackai.ensemble_predictor import EnsemblePredictor
    except ImportError as exc:
        logger.error("Missing dependency: %s", exc)
        return {}

    # The cache repository is best effort: any failure is only logged.
    try:
        cache_path = snapshot_download(repo_id="ailab-bio/tack-cache")
        # setdefault keeps a TACKAI_CACHE value already present in the
        # environment.
        os.environ.setdefault("TACKAI_CACHE", cache_path)
        logger.info("Cache downloaded to %s", cache_path)
    except Exception as exc:
        logger.warning("Cache repo unavailable: %s", exc)

    # (task key, HF repository id, sub-directory holding the ensemble)
    model_sources = (
        ("bin", "ailab-bio/TACK-Model-Bin", "bin_best_arch_ensemble"),
        ("dc50", "ailab-bio/TACK-Model-DC50", "dc50_best_arch_ensemble"),
        ("dmax", "ailab-bio/TACK-Model-Dmax", "dmax_best_arch_ensemble"),
    )

    predictors: Dict = {}
    for task, repo_id, subfolder in model_sources:
        try:
            model_dir = os.path.join(
                snapshot_download(repo_id=repo_id), subfolder
            )
            predictor = EnsemblePredictor.from_directory(
                model_dir, device="cpu"
            )
            # Registered before logging, so a failure while formatting the
            # log line does not discard an already-loaded predictor.
            predictors[task] = predictor
            logger.info(
                "Loaded '%s' predictor from %s (%d models).",
                task,
                model_dir,
                len(predictor.models),
            )
        except Exception as exc:
            logger.warning("Could not load predictor '%s': %s", task, exc)
    return predictors
|
|
|
|
# Predictors are loaded once, when the module is imported. The dict is empty
# if nothing could be loaded.
PREDICTORS: Dict = _load_predictors()
# Task keys that actually have a loaded predictor.
AVAILABLE_TASKS: List[str] = list(PREDICTORS)
|
|
|
|
| |
| |
| |
def _make_sample(
    smiles: str,
    poi_name: str,
    poi_sequence: str,
    ligase_name: str,
    ligase_sequence: str,
    cell_line: str,
    treatment_time: float,
    assay_type: str,
) -> "SampleInput":
    """Build a ``SampleInput`` from the single-compound form fields.

    Text fields are stripped. A field that is ``None``, empty, or made only
    of whitespace is passed as ``None`` (``"Unknown"`` for ``cell_line`` and
    ``assay_type``), which is the same normalisation ``_sample_from_row``
    applies to CSV cells. Previously a whitespace-only field was passed on
    as an empty string.

    Args:
        smiles: SMILES string of the compound.
        poi_name: Name of the protein of interest.
        poi_sequence: Amino-acid sequence of the protein of interest.
        ligase_name: Name of the E3 ligase.
        ligase_sequence: Amino-acid sequence of the E3 ligase.
        cell_line: Cell line; falls back to ``"Unknown"``.
        treatment_time: Treatment time; a falsy value (``None`` or ``0``)
            falls back to ``24.0``.
        assay_type: Assay type; falls back to ``"Unknown"``.

    Returns:
        The populated ``SampleInput``.
    """
    from tackai.ensemble_predictor import SampleInput

    def clean(text: Optional[str], default: Optional[str] = None) -> Optional[str]:
        # Strip, then treat an empty result the same as a missing value.
        stripped = text.strip() if text else ""
        return stripped or default

    return SampleInput(
        smiles=clean(smiles),
        poi_name=clean(poi_name),
        poi_sequence=clean(poi_sequence),
        ligase_name=clean(ligase_name),
        ligase_sequence=clean(ligase_sequence),
        cell_line=clean(cell_line, "Unknown"),
        assay_type=clean(assay_type, "Unknown"),
        # NOTE(review): an explicit 0 is also replaced by 24.0 (unchanged
        # behaviour); confirm whether a zero treatment time should be allowed.
        treatment_time=float(treatment_time) if treatment_time else 24.0,
    )
|
|
|
|
def _sample_from_row(row: pd.Series) -> "SampleInput":
    """Build a ``SampleInput`` from one row of an uploaded CSV.

    Each field is read from the first of its accepted column names that holds
    a usable value. Missing text fields become ``None`` (``"Unknown"`` for the
    cell line and assay type); a missing, unparsable or zero treatment time
    becomes ``24.0``.
    """
    from tackai.ensemble_predictor import SampleInput

    def first_text(*columns: str) -> Optional[str]:
        # First non-null cell whose stripped text is non-empty.
        for column in columns:
            raw = row.get(column)
            if raw is None or not pd.notna(raw):
                continue
            text = str(raw).strip()
            if text:
                return text
        return None

    def first_number(*columns: str) -> Optional[float]:
        # First non-null cell that converts to float.
        for column in columns:
            raw = row.get(column)
            if raw is None or not pd.notna(raw):
                continue
            try:
                return float(raw)
            except (ValueError, TypeError):
                continue
        return None

    hours = first_number("Assay_Time", "assay_time", "treatment_time")
    cell_line = first_text("Cell_Line", "cell_line")
    assay = first_text("Assay", "assay_type")

    return SampleInput(
        smiles=first_text("SMILES", "smiles"),
        poi_name=first_text("POI_Name", "poi_name"),
        poi_sequence=first_text("POI_Sequence", "poi_sequence"),
        ligase_name=first_text("Ligase_Name", "ligase_name"),
        ligase_sequence=first_text("Ligase_Sequence", "ligase_sequence"),
        cell_line=cell_line or "Unknown",
        assay_type=assay or "Unknown",
        treatment_time=hours or 24.0,
    )
|
|
|
|
def _result_row(result: "EnsemblePrediction", task: str) -> Dict:
    """Turn one ensemble prediction into a results-table row.

    Only the first element of each array-like field of ``result`` is used,
    rounded to four decimals. ``n_models`` is the number of entries in
    ``result.model_names``.
    """

    def first_rounded(values) -> float:
        return round(float(values[0]), 4)

    row: Dict = {"Task": TASK_LABELS.get(task, task.upper())}
    row["Prediction"] = first_rounded(result.weighted_mean)
    row["Uncertainty (Β±std)"] = first_rounded(result.uncertainty_std)
    row["CI 95% Low"] = first_rounded(result.ci_percentile_lower_95)
    row["CI 95% High"] = first_rounded(result.ci_percentile_upper_95)
    row["n_models"] = len(result.model_names)
    return row
|
|
|
|
| |
| |
| |
def run_single_prediction(
    smiles: str,
    poi_name: str,
    poi_sequence: str,
    ligase_name: str,
    ligase_sequence: str,
    cell_line: str,
    treatment_time: float,
    assay_type: str,
    selected_tasks: List[str],
) -> Tuple[pd.DataFrame, str]:
    """Predict the selected tasks for a single compound.

    Args:
        smiles: SMILES string of the compound (required).
        poi_name, poi_sequence, ligase_name, ligase_sequence, cell_line,
        treatment_time, assay_type: Form fields forwarded to ``_make_sample``.
        selected_tasks: Task keys chosen in the UI.

    Returns:
        ``(results, message)``. ``results`` has one row per task that was
        attempted; a task whose prediction raised or produced no result gets
        an ``"ERROR"`` row. ``message`` is empty when every selected task was
        predicted; otherwise it names the tasks that have no loaded model or
        that failed, so they are no longer dropped silently.
    """
    empty = pd.DataFrame(columns=RESULT_COLS)
    if not smiles or not smiles.strip():
        return empty, "Please enter a SMILES string."
    if not selected_tasks:
        return empty, "Please select at least one task."
    if not PREDICTORS:
        return empty, (
            "No models loaded — ensure the HF repositories are available."
        )

    sample = _make_sample(
        smiles, poi_name, poi_sequence,
        ligase_name, ligase_sequence,
        cell_line, treatment_time, assay_type,
    )

    rows: List[Dict] = []
    unavailable: List[str] = []
    failed: List[str] = []
    for task in selected_tasks:
        label = TASK_LABELS.get(task, task.upper())
        if task not in PREDICTORS:
            unavailable.append(label)
            continue
        try:
            task_dict = PREDICTORS[task].predict_batch(
                [sample], tasks=[task]
            )[0]
            if task_dict and task in task_dict:
                rows.append(_result_row(task_dict[task], task))
                continue
            logger.error(
                "Single prediction returned no result (task=%s)", task
            )
        except Exception as exc:
            logger.error("Single prediction error (task=%s): %s", task, exc)
        # Reached only when the task raised or returned nothing.
        failed.append(label)
        rows.append({
            # Start from the table's own columns so the error row always
            # lines up with the header, then fill in the known cells.
            **dict.fromkeys(RESULT_COLS, "—"),
            "Task": label,
            "Prediction": "ERROR",
            "n_models": 0,
        })

    notes: List[str] = []
    if unavailable:
        notes.append("No model loaded for: " + ", ".join(unavailable) + ".")
    if failed:
        notes.append("Prediction failed for: " + ", ".join(failed) + ".")

    if not rows:
        return empty, " ".join(["No predictions returned.", *notes])
    return pd.DataFrame(rows), " ".join(notes)
|
|
|
|
def load_csv_preview(filepath: Optional[str]) -> pd.DataFrame:
    """Return the first five rows of the uploaded CSV for the preview table.

    Args:
        filepath: Path of the uploaded file, or ``None``/empty when nothing
            is uploaded.

    Returns:
        A DataFrame with at most five rows. It is empty when no file is given
        or when the file cannot be parsed (the error is logged).
    """
    if not filepath:
        return pd.DataFrame()
    try:
        # nrows stops the parser after the preview rows, so a large upload is
        # not read in full just to show five lines.
        return pd.read_csv(filepath, nrows=5)
    except Exception as exc:
        logger.error("CSV preview error: %s", exc)
        return pd.DataFrame()
|
|
|
|
def run_batch_prediction(
    filepath: Optional[str],
    selected_tasks: List[str],
) -> Tuple[pd.DataFrame, Optional[str], str]:
    """Run batch predictions; returns (results_df, download_path, message).

    Args:
        filepath: Path of the uploaded CSV. It must contain a ``SMILES`` (or
            ``smiles``) column.
        selected_tasks: Task keys chosen in the UI. Tasks without a loaded
            predictor are skipped.

    Returns:
        On success, the results table (one row per sample and task), the path
        of a temporary CSV holding the same table, and an empty message. On
        any failure, an empty DataFrame, ``None`` and a message describing
        the problem.
    """
    if not filepath:
        return pd.DataFrame(), None, "Please upload a CSV file."
    if not selected_tasks:
        return pd.DataFrame(), None, "Please select at least one task."
    if not PREDICTORS:
        return pd.DataFrame(), None, (
            "No models loaded — ensure the HF repositories are available."
        )

    try:
        df = pd.read_csv(filepath)
    except Exception as exc:
        return pd.DataFrame(), None, f"Error reading CSV: {exc}"

    if df.empty:
        return pd.DataFrame(), None, "Uploaded CSV is empty."
    if "SMILES" not in df.columns and "smiles" not in df.columns:
        return pd.DataFrame(), None, "CSV must contain a 'SMILES' column."

    samples = [_sample_from_row(row) for _, row in df.iterrows()]
    result_rows = []

    for task in selected_tasks:
        if task not in PREDICTORS:
            continue
        try:
            batch = PREDICTORS[task].predict_batch(samples, tasks=[task])
            for i, task_dict in enumerate(batch):
                base = {"#": i + 1, "SMILES": samples[i].smiles or ""}
                if task_dict and task in task_dict:
                    result_rows.append({
                        **base,
                        **_result_row(task_dict[task], task),
                    })
                else:
                    # No result for this sample: keep the row and flag it.
                    result_rows.append({
                        **base,
                        "Task": TASK_LABELS.get(task, task.upper()),
                        "Prediction": "ERROR",
                    })
        except Exception as exc:
            logger.error(
                "Batch prediction error (task=%s): %s", task, exc
            )
            # One failing task aborts the whole batch.
            return pd.DataFrame(), None, f"Prediction failed: {exc}"

    if not result_rows:
        return pd.DataFrame(), None, "No predictions returned."

    results_df = pd.DataFrame(result_rows)
    # Only reserve the name here. Leaving the block closes the handle, so no
    # descriptor leaks and the file is not written while still held open.
    # delete=False keeps the file so it can be served for download.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".csv", prefix="tack_results_",
    ) as tmp:
        download_path = tmp.name
    results_df.to_csv(download_path, index=False)
    return results_df, download_path, ""
|
|
|
|
def get_template_csv() -> str:
    """Write a CSV template to a temp file and return its path.

    The template has a single example row with every column the batch tab
    reads: ``SMILES``, ``POI_Name``, ``POI_Sequence``, ``Ligase_Name``,
    ``Ligase_Sequence``, ``Cell_Line``, ``Assay_Time`` and ``Assay``.
    """
    # EXAMPLE_SMILES is a tuple of several examples. A CSV cell needs one
    # SMILES string, otherwise the tuple's repr is written to the file.
    example_smiles = (
        EXAMPLE_SMILES
        if isinstance(EXAMPLE_SMILES, str)
        else EXAMPLE_SMILES[0]
    )
    df = pd.DataFrame([{
        "SMILES": example_smiles,
        "POI_Name": "AR",
        "POI_Sequence": "MEVQLGLGRVYPRPPSKTYRGAFQNLFQSVREVIQNPGPRHPEAASAAPPGASLLLLQQQQQQQQQQQQQQQQQQQQQQQETSPRQQQQQQGEDGSPQAHRRGPTGYLVLDEEQQPSQPQSALECHPERGCVPEPGAAVAASKGLPQQLPAPPDEDDSAAPSTLSLLGPTFPGLSSCSADLKDILSEASTMQLLQQQQQEAVSEGSSSGRAREASGAPTSSKDNYLGGTSTISDNAKELCKAVSVSMGLGVEALEHLSPGEQLRGDCMYAPLLGVPPAVRPTPCAPLAECKGSLLDDSAGKSTEDTAEYSPFKGGYTKGLEGESLGCSGSAAAGSSGTLELPSTLSLYKSGALDEAAAYQSRDYYNFPLALAGPPPPPPPPHPHARIKLENPLDYGSAWAAAAAQCRYGDLASLHGAGAAGPGSGSPSAAASSSWHTLFTAEEGQLYGPCGGGGGGGGGGGGGGGGGGGGGGGEAGAVAPYGYTRPPQGLAGQESDFTAPDVWYPGGMVSRVPYPSPTCVKSEMGPWMDSYSGPYGDMRLETARDHVLPIDYYFPPQKTCLICGDEASGCHYGALTCGSCKVFFKRAAEGKQKYLCASRNDCTIDKFRRKNCPSCRLRKCYEAGMTLGARKLKKLGNLKLQEEGEASSTTSPTEETTQKLTVSHIEGYECQPIFLNVLEAIEPGVVCAGHDNNQPDSFAALLSSLNELGERQLVHVVKWAKALPGFRNLHVDDQMAVIQYSWMGLMVFAMGWRSFTNVNSRMLYFAPDLVFNEYRMHKSRMYSQCVRMRHLSQEFGWLQITPQEFLCMKALLLFSIIPVDGLKNQKFFDELRMNYIKELDRIIACKRKNPTSCSRRFYQLTKLLDSVQPIARELHQFTFDLLIKSHMVSVDFPEMMAEIISVQVPKILSGKVKPIYFHTQ",
        "Ligase_Name": "CRBN",
        "Ligase_Sequence": "MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL",
        "Cell_Line": "Unknown",
        "Assay_Time": 24.0,
        "Assay": "Unknown",
    }])
    # Only reserve the name here. Leaving the block closes the handle, so no
    # descriptor leaks and the file is not written while still held open.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".csv", prefix="tack_template_",
    ) as tmp:
        template_path = tmp.name
    df.to_csv(template_path, index=False)
    return template_path
|
|
|
|
| |
| |
| |
# (label, value) pairs offered by the task selector.
_TASK_CHOICES = [
    (TASK_LABELS["bin"], "bin"),
    (TASK_LABELS["dc50"], "dc50"),
    (TASK_LABELS["dmax"], "dmax"),
]
# Pre-select every task that has a loaded predictor; fall back to "bin" so
# the selector is never empty.
_DEFAULT_TASKS = AVAILABLE_TASKS if AVAILABLE_TASKS else ["bin"]
# Markdown warning shown at the top of the page when nothing was loaded.
# The repository pattern matches the `ailab-bio/TACK-Model-*` ids that the
# loader downloads.
_NO_MODELS_BANNER = (
    "> ⚠️ **No models loaded.** Ensure the HF repositories "
    "(`ailab-bio/TACK-Model-*`) are publicly available and try again."
    if not PREDICTORS
    else ""
)
|
|
# Page layout: a shared task selector above two tabs, one for a single
# compound and one for CSV batches.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue"),
    title="TACK — PROTAC Activity Predictor",
) as demo:

    gr.Markdown("# TACK — PROTAC Activity Predictor")
    gr.Markdown(
        "Predict PROTAC/degrader activity using an ensemble of XGBoost "
        "and MLP models trained on PROTAC-DB, TPD-DB, and PROTAC-Pedia. "
        "Outputs binary activity probability, DC50 (nM), and Dmax (%) "
        "with full uncertainty quantification."
    )
    if _NO_MODELS_BANNER:
        gr.Markdown(_NO_MODELS_BANNER)

    # Shared by both tabs.
    task_selector = gr.CheckboxGroup(
        choices=_TASK_CHOICES,
        value=_DEFAULT_TASKS,
        label="Prediction Task(s)",
        info="Select one or more properties to predict.",
    )

    with gr.Tabs():

        # Single-compound tab: inputs on the left, results on the right.
        with gr.Tab("Single Compound"):
            with gr.Row():
                with gr.Column(scale=1):
                    smiles_in = gr.Textbox(
                        label="SMILES *",
                        placeholder="Paste a SMILES string…",
                        lines=2,
                        # EXAMPLE_SMILES is a tuple of examples; a textbox
                        # needs one string, so pre-fill with the first.
                        value=(
                            EXAMPLE_SMILES
                            if isinstance(EXAMPLE_SMILES, str)
                            else EXAMPLE_SMILES[0]
                        ),
                    )
                    with gr.Accordion("POI (Protein of Interest)", open=False):
                        poi_name_in = gr.Textbox(
                            label="POI Name",
                            placeholder="e.g. AR, BRD4, SMARCA2",
                        )
                        poi_seq_in = gr.Textbox(
                            label="Amino Acid Sequence",
                            placeholder=(
                                "Paste full sequence (no FASTA header)…"
                            ),
                            lines=5,
                        )
                    with gr.Accordion("E3 Ligase", open=False):
                        ligase_name_in = gr.Textbox(
                            label="E3 Ligase Name",
                            placeholder="e.g. CRBN, VHL, MDM2",
                            value="CRBN",
                        )
                        ligase_seq_in = gr.Textbox(
                            label="Amino Acid Sequence",
                            placeholder=(
                                "Paste full sequence (no FASTA header)…"
                            ),
                            lines=5,
                        )
                    with gr.Row():
                        cell_line_in = gr.Textbox(
                            label="Cell Line",
                            value="Unknown",
                            placeholder="e.g. HEK293, Jurkat",
                        )
                        treatment_time_in = gr.Number(
                            label="Treatment Time (h)",
                            value=24.0,
                            minimum=0.0,
                        )
                        assay_type_in = gr.Textbox(
                            label="Assay Type",
                            value="Unknown",
                            placeholder="e.g. Western, FACS",
                        )
                    predict_single_btn = gr.Button(
                        "Predict", variant="primary", size="lg",
                    )

                with gr.Column(scale=1):
                    single_msg = gr.Markdown()
                    single_results = gr.Dataframe(
                        headers=RESULT_COLS,
                        label="Results",
                        interactive=False,
                        wrap=True,
                    )

            # Input order must match run_single_prediction's parameters.
            predict_single_btn.click(
                fn=run_single_prediction,
                inputs=[
                    smiles_in, poi_name_in, poi_seq_in,
                    ligase_name_in, ligase_seq_in,
                    cell_line_in, treatment_time_in, assay_type_in,
                    task_selector,
                ],
                outputs=[single_results, single_msg],
            )

        # Batch tab: upload a CSV, preview it, run, download the results.
        with gr.Tab("Batch (CSV)"):
            gr.Markdown(
                "Upload a CSV with a **SMILES** column (required). "
                "Optional columns: `POI_Name`, `POI_Sequence`, "
                "`Ligase_Name`, `Ligase_Sequence`, `Cell_Line`, "
                "`Assay_Time`, `Assay`."
            )
            with gr.Row():
                csv_upload = gr.File(
                    label="Upload CSV",
                    file_types=[".csv"],
                    type="filepath",
                )
                with gr.Column():
                    template_btn = gr.Button(
                        "Get Template CSV", size="sm",
                    )
                    template_out = gr.File(
                        label="Template",
                        interactive=False,
                        visible=True,
                    )
            template_btn.click(fn=get_template_csv, outputs=template_out)

            csv_preview = gr.Dataframe(
                label="CSV Preview (first 5 rows)",
                interactive=False,
            )
            # Refresh the preview whenever the uploaded file changes.
            csv_upload.change(
                fn=load_csv_preview,
                inputs=csv_upload,
                outputs=csv_preview,
            )

            batch_predict_btn = gr.Button(
                "Run Batch Prediction", variant="primary", size="lg",
            )
            batch_msg = gr.Markdown()
            batch_results = gr.Dataframe(
                label="Batch Results",
                interactive=False,
                wrap=True,
            )
            batch_download = gr.File(
                label="Download Results CSV",
                interactive=False,
            )

            # Output order must match run_batch_prediction's return tuple.
            batch_predict_btn.click(
                fn=run_batch_prediction,
                inputs=[csv_upload, task_selector],
                outputs=[batch_results, batch_download, batch_msg],
            )


if __name__ == "__main__":
    demo.launch()
|
|