caliby / design.py
Justine Yuan
Caliby HuggingFace example
3beba17
"""Core design pipeline: context building, execution, output formatting."""
import re
import tempfile
from pathlib import Path
import gradio as gr
import pandas as pd
import spaces
import torch
from app_config import WEIGHTS_DIR
from constraints import _build_pos_constraint_df, _validate_design_inputs
from ensemble import _generate_protpardelle_ensemble, _setup_user_ensemble_dir
from file_utils import _copy_uploaded_files, _get_file_path, _sanitize_download_stem, _write_zip_from_paths
from models import get_model
from self_consistency import _run_self_consistency
# ZeroGPU quota-aware retry: request the max duration first, and if the
# scheduler returns a quota error (which is free — no GPU time consumed),
# parse the remaining seconds from the error message and retry with a
# duration derived from them (design_sequences requests one second less
# than the reported remainder).

# Duration in seconds requested when no override is active.
_MAX_GPU_DURATION = 120 # Per-call max; daily quota is 210s but per-call cap is lower

# Duration in seconds to request instead of _MAX_GPU_DURATION, or None when no
# override is active. Written by design_sequences() around its retry and read
# by _dynamic_gpu_duration().
_gpu_duration_override: int | None = None
def _dynamic_gpu_duration(*args, **kwargs) -> int:
    """Pick the GPU duration (seconds) to request for the next GPU call.

    Used as the ``duration`` callable of ``@spaces.GPU``. Whatever arguments
    it is called with are accepted and ignored; the result depends only on
    module state.

    Returns:
        ``_gpu_duration_override`` when one is set, otherwise
        ``_MAX_GPU_DURATION``.
    """
    if _gpu_duration_override is None:
        return _MAX_GPU_DURATION
    return _gpu_duration_override
def _parse_quota_left(error: Exception) -> int | None:
"""Extract remaining GPU seconds from a ZeroGPU quota error message.
Returns the number of seconds left, or None if not a recoverable quota error.
"""
message = getattr(error, 'message', None)
if not isinstance(message, str):
return None
match = re.search(r'(\d+)s left\)', message)
return int(match.group(1)) if match else None
def _build_design_context(
    pdb_paths: list[str],
    ensemble_mode: str,
    tmpdir: Path,
    num_protpardelle_conformers: int,
    fixed_pos_seq: str,
    fixed_pos_scn: str,
    fixed_pos_override_seq: str,
    pos_restrict_aatype: str,
    symmetry_pos: str,
) -> tuple[list[str] | dict[str, list[str]], pd.DataFrame | None]:
    """Assemble the sampler inputs and positional constraints for one run.

    Constraints are always built first, keyed by the stem of the first PDB
    path. What happens next depends on ``ensemble_mode``:

    * ``"none"``: the PDB paths are returned unchanged with the constraints.
    * ``"synthetic"``: a Protpardelle ensemble is generated from the first
      PDB into ``tmpdir``.
    * anything else: the PDB paths are handed to the user-ensemble setup.

    In the two ensemble modes, existing constraints are rebuilt through
    ``caliby.make_ensemble_constraints`` from the first row of the constraint
    frame (every column except ``pdb_key``).

    Returns:
        ``(design_inputs, pos_constraint_df)``; the constraint frame may be
        None when the constraint builder returned None.
    """
    primary_key = Path(pdb_paths[0]).stem
    constraints = _build_pos_constraint_df(
        pdb_key=primary_key,
        fixed_pos_seq=fixed_pos_seq,
        fixed_pos_scn=fixed_pos_scn,
        fixed_pos_override_seq=fixed_pos_override_seq,
        pos_restrict_aatype=pos_restrict_aatype,
        symmetry_pos=symmetry_pos,
    )

    # Single-structure design: inputs and constraints are used as they are.
    if ensemble_mode == "none":
        return pdb_paths, constraints

    if ensemble_mode != "synthetic":
        design_inputs = _setup_user_ensemble_dir(pdb_paths=pdb_paths)
    else:
        design_inputs = _generate_protpardelle_ensemble(
            pdb_path=pdb_paths[0],
            num_conformers=num_protpardelle_conformers,
            out_dir=tmpdir,
            weights_dir=WEIGHTS_DIR,
        )

    if constraints is None:
        return design_inputs, constraints

    # Imported lazily: only needed when an ensemble run carries constraints.
    from caliby import make_ensemble_constraints

    # Only the first constraint row is consulted; "pdb_key" is dropped because
    # the key is supplied as the outer dictionary key instead.
    first_row = constraints.iloc[0]
    per_key_constraints = {}
    for column in constraints.columns:
        if column != "pdb_key":
            per_key_constraints[column] = first_row[column]

    ensemble_constraints = make_ensemble_constraints({primary_key: per_key_constraints}, design_inputs)
    return design_inputs, ensemble_constraints
def _format_outputs(outputs: dict) -> tuple[pd.DataFrame, str, list[str]]:
    """Turn raw sampler output into a results table, FASTA text and PDB list.

    Args:
        outputs: Mapping with the keys ``"out_pdb"``, ``"seq"``, ``"U"`` and
            ``"example_id"``.

    Returns:
        A 3-tuple of:

        * a DataFrame with the columns ``Sample`` (file stem of each output
          PDB), ``Sequence`` and ``Energy (U)``;
        * FASTA text with one ``>{example_id}_sample{i}`` header per
          sequence, where ``i`` is the sequence's position in the output;
        * the ``outputs["out_pdb"]`` list itself.
    """
    pdb_files = outputs["out_pdb"]
    sequences = outputs["seq"]

    table = pd.DataFrame(
        {
            "Sample": [Path(pdb_file).stem for pdb_file in pdb_files],
            "Sequence": sequences,
            "Energy (U)": outputs["U"],
        }
    )

    # Each record contributes two lines: the header, then the sequence.
    records = [
        line
        for index, (example_id, sequence) in enumerate(zip(outputs["example_id"], sequences))
        for line in (f">{example_id}_sample{index}", sequence)
    ]
    return table, "\n".join(records), pdb_files
@spaces.GPU(duration=_dynamic_gpu_duration)
def _design_sequences_gpu(
    pdb_files: list | None,
    ensemble_mode: str,
    model_variant: str,
    num_seqs: int,
    omit_aas: list[str] | None,
    temperature: float,
    fixed_pos_seq: str,
    fixed_pos_scn: str,
    fixed_pos_override_seq: str,
    pos_restrict_aatype: str,
    symmetry_pos: str,
    num_protpardelle_conformers: int,
    run_af2_eval: bool = False,
):
    """Run one full design pass inside a ZeroGPU allocation.

    Steps, in order: validate the inputs, load the model, copy the uploads
    into a temporary directory, build the design context, sample sequences,
    format the results, optionally run the AF2 self-consistency evaluation,
    and zip the designed PDB files.

    Returns:
        A 6-tuple ``(table, fasta_text, designs_zip, sc_zip, af2_pdb_data,
        input_pdb_data)``. ``sc_zip`` is None and ``af2_pdb_data`` is empty
        unless ``run_af2_eval`` is set. ``input_pdb_data`` maps each copied
        input file's stem to its text. When validation fails the tuple is
        ``(empty DataFrame, error message, None, None, {}, {})``.
    """
    problem = _validate_design_inputs(pdb_files, ensemble_mode)
    if problem:
        return pd.DataFrame(), problem, None, None, {}, {}

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Disables autograd for the whole process, not only for this call.
    torch.set_grad_enabled(False)
    download_stem = _sanitize_download_stem(_get_file_path(pdb_files[0]).stem)

    gr.Info("Loading model...")
    model = get_model(model_variant, device)

    with tempfile.TemporaryDirectory() as tmp_name:
        workdir = Path(tmp_name)
        pdb_paths = _copy_uploaded_files(pdb_files, workdir)
        input_pdb_data = {}
        for pdb_path in pdb_paths:
            pdb_file = Path(pdb_path)
            input_pdb_data[pdb_file.stem] = pdb_file.read_text()

        out_dir = workdir / "outputs"
        out_dir.mkdir(parents=True, exist_ok=True)

        # Only the two ensemble modes announce a preparation step.
        prep_notice = {
            "synthetic": "Generating conformer ensemble...",
            "user": "Preparing uploaded ensemble...",
        }.get(ensemble_mode)
        if prep_notice is not None:
            gr.Info(prep_notice)

        design_inputs, pos_constraint_df = _build_design_context(
            pdb_paths=pdb_paths,
            ensemble_mode=ensemble_mode,
            tmpdir=workdir,
            num_protpardelle_conformers=num_protpardelle_conformers,
            fixed_pos_seq=fixed_pos_seq,
            fixed_pos_scn=fixed_pos_scn,
            fixed_pos_override_seq=fixed_pos_override_seq,
            pos_restrict_aatype=pos_restrict_aatype,
            symmetry_pos=symmetry_pos,
        )

        gr.Info("Designing sequences...")
        # Single-structure runs use sample(); every other mode samples over
        # the ensemble. Both receive the same keyword arguments.
        sampler = model.sample if ensemble_mode == "none" else model.ensemble_sample
        outputs = sampler(
            design_inputs,
            out_dir=str(out_dir),
            num_seqs_per_pdb=num_seqs,
            omit_aas=omit_aas or None,  # an empty selection is passed as None
            temperature=temperature,
            num_workers=0,
            pos_constraint_df=pos_constraint_df,
        )
        df, fasta_text, out_pdb_list = _format_outputs(outputs)

        if run_af2_eval:
            gr.Info("Running AF2 self-consistency evaluation...")
            sc_zip_path, af2_pdb_data = _run_self_consistency(model, df, out_pdb_list, out_dir, download_stem)
        else:
            sc_zip_path, af2_pdb_data = None, {}

        # The designed PDB files are zipped before the temporary directory
        # is removed.
        out_zip_path = _write_zip_from_paths(out_pdb_list, download_stem, "_designs.zip")

    return df, fasta_text, out_zip_path, sc_zip_path, af2_pdb_data, input_pdb_data
def design_sequences(
    pdb_files: list | None,
    ensemble_mode: str,
    model_variant: str,
    num_seqs: int,
    omit_aas: list[str] | None,
    temperature: float,
    fixed_pos_seq: str,
    fixed_pos_scn: str,
    fixed_pos_override_seq: str,
    pos_restrict_aatype: str,
    symmetry_pos: str,
    num_protpardelle_conformers: int,
    run_af2_eval: bool = False,
):
    """Run sequence design with ZeroGPU quota-aware retry.

    The first attempt requests the default GPU duration. If it fails with a
    ``gr.Error`` whose message reports the remaining quota, the call is
    retried once, requesting one second less than the reported remainder.

    All parameters are forwarded unchanged to ``_design_sequences_gpu``, and
    its return value is returned as is.

    Raises:
        gr.Error: The original error is re-raised when it is not a parseable
            quota error, or when the remaining quota leaves no positive
            duration to request. An error raised by the retry propagates.
    """
    global _gpu_duration_override

    # Built once so the first attempt and the retry cannot drift apart.
    call_kwargs = dict(
        pdb_files=pdb_files,
        ensemble_mode=ensemble_mode,
        model_variant=model_variant,
        num_seqs=num_seqs,
        omit_aas=omit_aas,
        temperature=temperature,
        fixed_pos_seq=fixed_pos_seq,
        fixed_pos_scn=fixed_pos_scn,
        fixed_pos_override_seq=fixed_pos_override_seq,
        pos_restrict_aatype=pos_restrict_aatype,
        symmetry_pos=symmetry_pos,
        num_protpardelle_conformers=num_protpardelle_conformers,
        run_af2_eval=run_af2_eval,
    )

    # Clear any stale override so the first attempt uses the default duration.
    _gpu_duration_override = None
    try:
        return _design_sequences_gpu(**call_kwargs)
    except gr.Error as e:
        remaining = _parse_quota_left(e)
        print(f"[ZeroGPU retry] Caught gr.Error, parsed remaining={remaining}, message={getattr(e, 'message', str(e))}")

        # Keep a one-second margin below the reported remainder. With 1s (or
        # less) left the margin leaves nothing to request, so surface the
        # original quota error rather than asking for a zero-second slot.
        retry_duration = None if remaining is None else remaining - 1
        if retry_duration is None or retry_duration <= 0:
            raise

        gr.Info(f"GPU quota: {remaining}s remaining, retrying with exact quota")
        _gpu_duration_override = retry_duration
        try:
            return _design_sequences_gpu(**call_kwargs)
        finally:
            # Always restore the default so later calls are not affected.
            _gpu_duration_override = None