# TACK-Demo / app.py
# Last change: commit f3b42a1 by ribesstefano — "Fixed cache name"
"""TACK Demo β€” PROTAC/degrader activity prediction via ensemble models."""
import os
import logging
import tempfile
import warnings
from typing import Dict, List, Optional, Tuple
import pandas as pd
import gradio as gr
# Silence all Python warnings for the demo process.
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Task key -> human-readable label used in the UI and in result tables.
TASK_LABELS = dict(
    bin="Binary Activity (prob.)",
    dc50="DC50 (nM)",
    dmax="Dmax (%)",
)
# Column order of the single-compound results table. These names must match
# the keys of the per-task result rows built elsewhere in this file.
RESULT_COLS = [
    "Task",
    "Prediction",
    "Uncertainty (Β±std)",
    "CI 95% Low",
    "CI 95% High",
    "n_models",
]
# Example PROTAC SMILES. NOTE: this is a tuple of three molecules, not a
# single string; code that needs one SMILES must pick an element.
EXAMPLE_SMILES = (
    "Cn1c(=O)n(C2CCC(=O)NC2=O)c2cccc(C#CCCN3CCC4(CC3)CC(n3cc(NC(=O)c5cnn6ccc(N7C[C@H]8C[C@@H]7CO8)nc56)c(C(F)F)n3)C4)c21",
    "Cc1ncsc1C1=CCC([C@H](C)NC(=O)[C@@H]2C[C@@H](O)CN2C(=O)C(c2cc(N3CCC(CN4CCC(n5nc(NC(C)C)c6nnc(-c7cccc(F)c7O)cc65)CC4)CC3)no2)C(C)C)C=C1",
    "COc1ccccc1C(=O)NCc1ccc(-c2nn3c(c2C(N)=O)Nc2ccc(N4CCN(CC5CCN(c6ccc(N7CCC(=O)NC7=O)cc6)CC5)CC4)cc2CC3)c(OC)c1",
)
# ---------------------------------------------------------------------------
# Model loading at startup
# ---------------------------------------------------------------------------
def _load_predictors() -> Dict:
    """Download the TACK ensembles from the HF Hub and load them on CPU.

    Every step is best-effort; failures are logged and never raised:

    1. Import ``huggingface_hub``. Without it nothing can be loaded.
    2. Download the shared ``ailab-bio/tack-cache`` repo and export its path
       as ``TACKAI_CACHE``. An already-set value is left untouched. This
       happens before ``tackai`` is imported, so the variable is already set
       if the package reads it at import time (NOTE(review): whether tackai
       reads it at import or at call time is not visible here).
    3. Import ``tackai`` and load one ``EnsemblePredictor`` per task from the
       task-named subfolder of that task's model repo.

    Returns:
        Dict mapping task key (``"bin"``, ``"dc50"``, ``"dmax"``) to its
        loaded predictor. Tasks that failed to load are absent. The dict is
        empty when a dependency is missing.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError as exc:
        logger.error("Missing dependency: %s", exc)
        return {}

    try:
        cache_dir = snapshot_download(repo_id="ailab-bio/tack-cache")
        os.environ.setdefault("TACKAI_CACHE", cache_dir)
        logger.info("Cache downloaded to %s", cache_dir)
    except Exception as exc:
        logger.warning("Cache repo unavailable: %s", exc)

    try:
        from tackai.ensemble_predictor import EnsemblePredictor
    except ImportError as exc:
        logger.error("Missing dependency: %s", exc)
        return {}

    # Each repo stores the ensemble under a task-named subfolder.
    repo_subfolders = {
        "bin": ("ailab-bio/TACK-Model-Bin", "bin_best_arch_ensemble"),
        "dc50": ("ailab-bio/TACK-Model-DC50", "dc50_best_arch_ensemble"),
        "dmax": ("ailab-bio/TACK-Model-Dmax", "dmax_best_arch_ensemble"),
    }
    loaded: Dict = {}
    for task, (repo_id, subfolder) in repo_subfolders.items():
        try:
            repo_dir = snapshot_download(repo_id=repo_id)
            model_dir = os.path.join(repo_dir, subfolder)
            predictor = EnsemblePredictor.from_directory(
                model_dir, device="cpu"
            )
        except Exception as exc:
            logger.warning("Could not load predictor '%s': %s", task, exc)
            continue
        # Register the predictor and log outside the try, so a problem while
        # logging cannot be misreported as a load failure.
        loaded[task] = predictor
        logger.info(
            "Loaded '%s' predictor from %s (%d models).",
            task,
            model_dir,
            len(getattr(predictor, "models", ())),
        )
    return loaded


# Loaded once at import time; the UI and callbacks below read these globals.
PREDICTORS: Dict = _load_predictors()
AVAILABLE_TASKS: List[str] = list(PREDICTORS)
# ---------------------------------------------------------------------------
# Sample construction helpers
# ---------------------------------------------------------------------------
def _make_sample(
    smiles: str,
    poi_name: str,
    poi_sequence: str,
    ligase_name: str,
    ligase_sequence: str,
    cell_line: str,
    treatment_time: float,
    assay_type: str,
) -> "SampleInput":
    """Build a ``SampleInput`` from the single-compound form fields.

    Text fields are stripped. Blank or whitespace-only values become None,
    except cell line and assay type, which fall back to ``"Unknown"``. These
    are the same rules ``_sample_from_row`` applies to CSV cells.

    A missing treatment time (None, ``""`` or NaN) defaults to 24.0 h. An
    explicit 0 is kept rather than being treated as "missing".
    """
    from tackai.ensemble_predictor import SampleInput

    def clean(value: Optional[str]) -> Optional[str]:
        # Strip whitespace and map blank input to None ("not provided").
        text = (value or "").strip()
        return text or None

    # Compare against None/""/NaN explicitly: a plain truthiness test would
    # also replace a legitimate 0 h with the 24 h default.
    if treatment_time is None or treatment_time == "" or pd.isna(treatment_time):
        hours = 24.0
    else:
        hours = float(treatment_time)

    return SampleInput(
        smiles=clean(smiles),
        poi_name=clean(poi_name),
        poi_sequence=clean(poi_sequence),
        ligase_name=clean(ligase_name),
        ligase_sequence=clean(ligase_sequence),
        cell_line=clean(cell_line) or "Unknown",
        assay_type=clean(assay_type) or "Unknown",
        treatment_time=hours,
    )
def _sample_from_row(row: pd.Series) -> "SampleInput":
    """Build a ``SampleInput`` from one row of an uploaded batch CSV.

    Each field accepts the documented column name (e.g. ``SMILES``,
    ``POI_Name``) or a lower-case alias. The first non-blank value wins.

    Blank cells become None, except cell line and assay type, which fall back
    to ``"Unknown"``. A missing or unparsable treatment time defaults to
    24.0 h; an explicit 0 is kept as-is.
    """
    from tackai.ensemble_predictor import SampleInput

    def get_str(*keys: str) -> Optional[str]:
        # First non-missing, non-blank value among the candidate columns.
        for k in keys:
            v = row.get(k)
            if v is not None and pd.notna(v) and str(v).strip():
                return str(v).strip()
        return None

    def get_float(*keys: str) -> Optional[float]:
        # First value among the candidate columns that converts to a non-NaN float.
        for k in keys:
            v = row.get(k)
            if v is None or pd.isna(v):
                continue
            try:
                number = float(v)
            except (ValueError, TypeError):
                continue
            # Reject values like the string "nan" that only become NaN after conversion.
            if pd.notna(number):
                return number
        return None

    treatment_time = get_float("Assay_Time", "assay_time", "treatment_time")
    return SampleInput(
        smiles=get_str("SMILES", "smiles"),
        poi_name=get_str("POI_Name", "poi_name"),
        poi_sequence=get_str("POI_Sequence", "poi_sequence"),
        ligase_name=get_str("Ligase_Name", "ligase_name"),
        ligase_sequence=get_str("Ligase_Sequence", "ligase_sequence"),
        cell_line=get_str("Cell_Line", "cell_line") or "Unknown",
        assay_type=get_str("Assay", "assay_type") or "Unknown",
        # Check for None explicitly so a legitimate 0.0 is not replaced by 24.
        treatment_time=24.0 if treatment_time is None else treatment_time,
    )
def _result_row(result: "EnsemblePrediction", task: str) -> Dict:
    """Flatten the first sample of an ensemble prediction into one table row.

    Numeric statistics are taken from index 0 of each result array and
    rounded to 4 decimals. ``n_models`` is the number of model names reported
    by the ensemble. The keys follow the results-table column names.
    """
    def first(values) -> float:
        # Only sample 0 is reported; round for display.
        return round(float(values[0]), 4)

    row: Dict = {"Task": TASK_LABELS.get(task, task.upper())}
    row["Prediction"] = first(result.weighted_mean)
    row["Uncertainty (Β±std)"] = first(result.uncertainty_std)
    row["CI 95% Low"] = first(result.ci_percentile_lower_95)
    row["CI 95% High"] = first(result.ci_percentile_upper_95)
    row["n_models"] = len(result.model_names)
    return row
# ---------------------------------------------------------------------------
# Prediction callbacks
# ---------------------------------------------------------------------------
def run_single_prediction(
    smiles: str,
    poi_name: str,
    poi_sequence: str,
    ligase_name: str,
    ligase_sequence: str,
    cell_line: str,
    treatment_time: float,
    assay_type: str,
    selected_tasks: List[str],
) -> Tuple[pd.DataFrame, str]:
    """Predict the selected tasks for one compound entered in the form.

    Returns:
        ``(results_df, message)``. ``results_df`` has one row per task that
        produced a result or raised an error. A failed task appears as an
        ``"ERROR"`` row.

        ``message`` is empty on full success. Otherwise it explains why
        nothing was predicted, or which tasks were skipped (model not loaded),
        returned nothing, or failed.
    """
    if not smiles or not smiles.strip():
        return pd.DataFrame(columns=RESULT_COLS), "Please enter a SMILES string."
    if not selected_tasks:
        return pd.DataFrame(columns=RESULT_COLS), "Please select at least one task."
    if not PREDICTORS:
        return pd.DataFrame(columns=RESULT_COLS), (
            "No models loaded — ensure the HF repositories are available."
        )
    sample = _make_sample(
        smiles, poi_name, poi_sequence,
        ligase_name, ligase_sequence,
        cell_line, treatment_time, assay_type,
    )
    rows: List[Dict] = []
    notes: List[str] = []  # user-facing explanations for missing/failed tasks
    for task in selected_tasks:
        label = TASK_LABELS.get(task, task.upper())
        if task not in PREDICTORS:
            notes.append(f"Model for {label} is not loaded; skipped.")
            continue
        try:
            task_dict = PREDICTORS[task].predict_batch(
                [sample], tasks=[task]
            )[0]
            if task_dict and task in task_dict:
                rows.append(_result_row(task_dict[task], task))
            else:
                notes.append(f"No result returned for {label}.")
        except Exception as exc:
            # logger.exception keeps the traceback for debugging the Space.
            logger.exception("Single prediction error (task=%s): %s", task, exc)
            notes.append(f"Prediction failed for {label}: {exc}")
            rows.append({
                "Task": label,
                "Prediction": "ERROR",
                "Uncertainty (Β±std)": "—",
                "CI 95% Low": "—",
                "CI 95% High": "—",
                "n_models": 0,
            })
    message = " ".join(notes)
    if not rows:
        return pd.DataFrame(columns=RESULT_COLS), message or "No predictions returned."
    return pd.DataFrame(rows), message
def load_csv_preview(filepath: Optional[str]) -> pd.DataFrame:
    """Return the first five rows of an uploaded CSV for display.

    Only those rows are parsed (``nrows=5``), so previewing a large upload
    does not read the whole file.

    Returns an empty DataFrame when no path is given or the file cannot be
    parsed. In the latter case the error is logged with its traceback.
    """
    if not filepath:
        return pd.DataFrame()
    try:
        return pd.read_csv(filepath, nrows=5)
    except Exception as exc:
        logger.exception("CSV preview error: %s", exc)
        return pd.DataFrame()
def run_batch_prediction(
    filepath: Optional[str],
    selected_tasks: List[str],
) -> Tuple[pd.DataFrame, Optional[str], str]:
    """Run batch predictions for every row of an uploaded CSV.

    Args:
        filepath: Path of the uploaded CSV, or None when nothing was uploaded.
        selected_tasks: Task keys chosen in the UI (e.g. ``"bin"``).

    Returns:
        ``(results_df, download_path, message)``:

        * ``results_df`` has one row per (task, input row).
        * ``download_path`` is a CSV copy of ``results_df``, or None when
          there is nothing to download.
        * ``message`` lists tasks that were skipped (model not loaded) or
          failed.

        A failure in one task is logged with its traceback and reported in
        ``message``; results of the remaining tasks are still returned.
    """
    if not filepath:
        return pd.DataFrame(), None, "Please upload a CSV file."
    if not selected_tasks:
        return pd.DataFrame(), None, "Please select at least one task."
    if not PREDICTORS:
        return pd.DataFrame(), None, (
            "No models loaded — ensure the HF repositories are available."
        )
    try:
        df = pd.read_csv(filepath)
    except Exception as exc:
        return pd.DataFrame(), None, f"Error reading CSV: {exc}"
    if df.empty:
        return pd.DataFrame(), None, "Uploaded CSV is empty."
    if "SMILES" not in df.columns and "smiles" not in df.columns:
        return pd.DataFrame(), None, "CSV must contain a 'SMILES' column."

    samples = [_sample_from_row(row) for _, row in df.iterrows()]
    result_rows: List[Dict] = []
    notes: List[str] = []
    for task in selected_tasks:
        label = TASK_LABELS.get(task, task.upper())
        if task not in PREDICTORS:
            notes.append(f"Model for {label} is not loaded; skipped.")
            continue
        try:
            batch = PREDICTORS[task].predict_batch(samples, tasks=[task])
            # Collect this task's rows separately so a failure half-way through
            # does not leave a partial task in the results.
            task_rows: List[Dict] = []
            for i, task_dict in enumerate(batch):
                base = {"#": i + 1, "SMILES": samples[i].smiles or ""}
                if task_dict and task in task_dict:
                    task_rows.append({**base, **_result_row(task_dict[task], task)})
                else:
                    task_rows.append({**base, "Task": label, "Prediction": "ERROR"})
        except Exception as exc:
            logger.exception("Batch prediction error (task=%s): %s", task, exc)
            notes.append(f"Prediction failed for {label}: {exc}")
            continue
        result_rows.extend(task_rows)

    message = " ".join(notes)
    if not result_rows:
        return pd.DataFrame(), None, message or "No predictions returned."
    results_df = pd.DataFrame(result_rows)
    # Close the handle before returning its path; delete=False keeps the file
    # on disk after the handle is closed.
    with tempfile.NamedTemporaryFile(
        mode="w", encoding="utf-8", newline="",
        delete=False, suffix=".csv", prefix="tack_results_",
    ) as tmp:
        results_df.to_csv(tmp, index=False)
    return results_df, tmp.name, message
def get_template_csv() -> str:
    """Write a one-row example CSV for the batch tab and return its path.

    The row pairs the first example SMILES with AR as the POI and CRBN as the
    E3 ligase (full amino-acid sequences included). It uses the same column
    names the batch parser reads.
    """
    # EXAMPLE_SMILES holds several molecules (a tuple); a CSV cell needs one
    # SMILES string, not the tuple's repr.
    smiles = (
        EXAMPLE_SMILES
        if isinstance(EXAMPLE_SMILES, str)
        else next(iter(EXAMPLE_SMILES), "")
    )
    df = pd.DataFrame([{
        "SMILES": smiles,
        "POI_Name": "AR",
        "POI_Sequence": "MEVQLGLGRVYPRPPSKTYRGAFQNLFQSVREVIQNPGPRHPEAASAAPPGASLLLLQQQQQQQQQQQQQQQQQQQQQQQETSPRQQQQQQGEDGSPQAHRRGPTGYLVLDEEQQPSQPQSALECHPERGCVPEPGAAVAASKGLPQQLPAPPDEDDSAAPSTLSLLGPTFPGLSSCSADLKDILSEASTMQLLQQQQQEAVSEGSSSGRAREASGAPTSSKDNYLGGTSTISDNAKELCKAVSVSMGLGVEALEHLSPGEQLRGDCMYAPLLGVPPAVRPTPCAPLAECKGSLLDDSAGKSTEDTAEYSPFKGGYTKGLEGESLGCSGSAAAGSSGTLELPSTLSLYKSGALDEAAAYQSRDYYNFPLALAGPPPPPPPPHPHARIKLENPLDYGSAWAAAAAQCRYGDLASLHGAGAAGPGSGSPSAAASSSWHTLFTAEEGQLYGPCGGGGGGGGGGGGGGGGGGGGGGGEAGAVAPYGYTRPPQGLAGQESDFTAPDVWYPGGMVSRVPYPSPTCVKSEMGPWMDSYSGPYGDMRLETARDHVLPIDYYFPPQKTCLICGDEASGCHYGALTCGSCKVFFKRAAEGKQKYLCASRNDCTIDKFRRKNCPSCRLRKCYEAGMTLGARKLKKLGNLKLQEEGEASSTTSPTEETTQKLTVSHIEGYECQPIFLNVLEAIEPGVVCAGHDNNQPDSFAALLSSLNELGERQLVHVVKWAKALPGFRNLHVDDQMAVIQYSWMGLMVFAMGWRSFTNVNSRMLYFAPDLVFNEYRMHKSRMYSQCVRMRHLSQEFGWLQITPQEFLCMKALLLFSIIPVDGLKNQKFFDELRMNYIKELDRIIACKRKNPTSCSRRFYQLTKLLDSVQPIARELHQFTFDLLIKSHMVSVDFPEMMAEIISVQVPKILSGKVKPIYFHTQ",
        "Ligase_Name": "CRBN",
        "Ligase_Sequence": "MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL",
        "Cell_Line": "Unknown",
        "Assay_Time": 24.0,
        "Assay": "Unknown",
    }])
    # Close the handle before returning its path; delete=False keeps the file.
    with tempfile.NamedTemporaryFile(
        mode="w", encoding="utf-8", newline="",
        delete=False, suffix=".csv", prefix="tack_template_",
    ) as tmp:
        df.to_csv(tmp, index=False)
    return tmp.name
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# (display label, task key) pairs for the task CheckboxGroup, in TASK_LABELS order.
_TASK_CHOICES = [(label, task) for task, label in TASK_LABELS.items()]
# Pre-select every loaded task; fall back to "bin" so the group is never empty.
_DEFAULT_TASKS = AVAILABLE_TASKS if AVAILABLE_TASKS else ["bin"]
# Warning shown at the top of the page when no predictor could be loaded.
# The repo pattern matches the model repo ids actually downloaded at startup.
_NO_MODELS_BANNER = (
    "> ⚠️ **No models loaded.** Ensure the HF repositories "
    "(`ailab-bio/TACK-Model-*`) are publicly available and try again."
    if not PREDICTORS
    else ""
)
# Default for the SMILES textbox. EXAMPLE_SMILES holds several molecules, but a
# Textbox value must be a single string (a tuple would be shown as its repr).
_DEFAULT_SMILES = (
    EXAMPLE_SMILES
    if isinstance(EXAMPLE_SMILES, str)
    else next(iter(EXAMPLE_SMILES), "")
)

with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue"),
    title="TACK — PROTAC Activity Predictor",
) as demo:
    gr.Markdown("# TACK — PROTAC Activity Predictor")
    gr.Markdown(
        "Predict PROTAC/degrader activity using an ensemble of XGBoost "
        "and MLP models trained on PROTAC-DB, TPD-DB, and PROTAC-Pedia. "
        "Outputs binary activity probability, DC50 (nM), and Dmax (%) "
        "with full uncertainty quantification."
    )
    if _NO_MODELS_BANNER:
        gr.Markdown(_NO_MODELS_BANNER)
    # Task selection is shared by both tabs below.
    task_selector = gr.CheckboxGroup(
        choices=_TASK_CHOICES,
        value=_DEFAULT_TASKS,
        label="Prediction Task(s)",
        info="Select one or more properties to predict.",
    )
    with gr.Tabs():
        # ── Tab 1: Single compound ─────────────────────────────────────────
        with gr.Tab("Single Compound"):
            with gr.Row():
                with gr.Column(scale=1):
                    smiles_in = gr.Textbox(
                        label="SMILES *",
                        placeholder="Paste a SMILES string…",
                        lines=2,
                        value=_DEFAULT_SMILES,
                    )
                    with gr.Accordion("POI (Protein of Interest)", open=False):
                        poi_name_in = gr.Textbox(
                            label="POI Name",
                            placeholder="e.g. AR, BRD4, SMARCA2",
                        )
                        poi_seq_in = gr.Textbox(
                            label="Amino Acid Sequence",
                            placeholder=(
                                "Paste full sequence (no FASTA header)…"
                            ),
                            lines=5,
                        )
                    with gr.Accordion("E3 Ligase", open=False):
                        ligase_name_in = gr.Textbox(
                            label="E3 Ligase Name",
                            placeholder="e.g. CRBN, VHL, MDM2",
                            value="CRBN",
                        )
                        ligase_seq_in = gr.Textbox(
                            label="Amino Acid Sequence",
                            placeholder=(
                                "Paste full sequence (no FASTA header)…"
                            ),
                            lines=5,
                        )
                    with gr.Row():
                        cell_line_in = gr.Textbox(
                            label="Cell Line",
                            value="Unknown",
                            placeholder="e.g. HEK293, Jurkat",
                        )
                        treatment_time_in = gr.Number(
                            label="Treatment Time (h)",
                            value=24.0,
                            minimum=0.0,
                        )
                        assay_type_in = gr.Textbox(
                            label="Assay Type",
                            value="Unknown",
                            placeholder="e.g. Western, FACS",
                        )
                    predict_single_btn = gr.Button(
                        "Predict", variant="primary", size="lg",
                    )
                with gr.Column(scale=1):
                    single_msg = gr.Markdown()
                    single_results = gr.Dataframe(
                        headers=RESULT_COLS,
                        label="Results",
                        interactive=False,
                        wrap=True,
                    )
            # Input order must match run_single_prediction's parameter order.
            predict_single_btn.click(
                fn=run_single_prediction,
                inputs=[
                    smiles_in, poi_name_in, poi_seq_in,
                    ligase_name_in, ligase_seq_in,
                    cell_line_in, treatment_time_in, assay_type_in,
                    task_selector,
                ],
                outputs=[single_results, single_msg],
            )
        # ── Tab 2: Batch (CSV) ─────────────────────────────────────────────
        with gr.Tab("Batch (CSV)"):
            gr.Markdown(
                "Upload a CSV with a **SMILES** column (required). "
                "Optional columns: `POI_Name`, `POI_Sequence`, "
                "`Ligase_Name`, `Ligase_Sequence`, `Cell_Line`, "
                "`Assay_Time`, `Assay`."
            )
            with gr.Row():
                csv_upload = gr.File(
                    label="Upload CSV",
                    file_types=[".csv"],
                    type="filepath",
                )
                with gr.Column():
                    template_btn = gr.Button(
                        "Get Template CSV", size="sm",
                    )
                    template_out = gr.File(
                        label="Template",
                        interactive=False,
                        visible=True,
                    )
            template_btn.click(fn=get_template_csv, outputs=template_out)
            csv_preview = gr.Dataframe(
                label="CSV Preview (first 5 rows)",
                interactive=False,
            )
            csv_upload.change(
                fn=load_csv_preview,
                inputs=csv_upload,
                outputs=csv_preview,
            )
            batch_predict_btn = gr.Button(
                "Run Batch Prediction", variant="primary", size="lg",
            )
            batch_msg = gr.Markdown()
            batch_results = gr.Dataframe(
                label="Batch Results",
                interactive=False,
                wrap=True,
            )
            batch_download = gr.File(
                label="Download Results CSV",
                interactive=False,
            )
            # Output order matches run_batch_prediction's
            # (results_df, download_path, message) return tuple.
            batch_predict_btn.click(
                fn=run_batch_prediction,
                inputs=[csv_upload, task_selector],
                outputs=[batch_results, batch_download, batch_msg],
            )

if __name__ == "__main__":
    demo.launch()