# Hugging Face Spaces page banner (extraction artifact): "Running on Zero"
# — this app runs on ZeroGPU hardware.
| """Core design pipeline: context building, execution, output formatting.""" | |
| import re | |
| import tempfile | |
| from pathlib import Path | |
| import gradio as gr | |
| import pandas as pd | |
| import spaces | |
| import torch | |
| from app_config import WEIGHTS_DIR | |
| from constraints import _build_pos_constraint_df, _validate_design_inputs | |
| from ensemble import _generate_protpardelle_ensemble, _setup_user_ensemble_dir | |
| from file_utils import _copy_uploaded_files, _get_file_path, _sanitize_download_stem, _write_zip_from_paths | |
| from models import get_model | |
| from self_consistency import _run_self_consistency | |
| # ZeroGPU quota-aware retry: request the max duration first, and if the | |
| # scheduler returns a quota error (which is free — no GPU time consumed), | |
| # parse the remaining seconds and retry with that exact amount. | |
# Per-call GPU duration cap in seconds. The daily ZeroGPU quota is 210s,
# but a single @spaces.GPU call is capped lower than the daily total.
_MAX_GPU_DURATION = 120
# When set, overrides the duration returned by _dynamic_gpu_duration for the
# next GPU call; design_sequences sets it to the parsed remaining quota on retry.
_gpu_duration_override: int | None = None
def _dynamic_gpu_duration(*args, **kwargs) -> int:
    """Return the GPU duration (seconds) to request for @spaces.GPU scheduling.

    Uses the module-level override when one is set (quota-aware retry path),
    otherwise falls back to the per-call maximum.
    """
    override = _gpu_duration_override
    if override is None:
        return _MAX_GPU_DURATION
    return override
| def _parse_quota_left(error: Exception) -> int | None: | |
| """Extract remaining GPU seconds from a ZeroGPU quota error message. | |
| Returns the number of seconds left, or None if not a recoverable quota error. | |
| """ | |
| message = getattr(error, 'message', None) | |
| if not isinstance(message, str): | |
| return None | |
| match = re.search(r'(\d+)s left\)', message) | |
| return int(match.group(1)) if match else None | |
def _build_design_context(
    pdb_paths: list[str],
    ensemble_mode: str,
    tmpdir: Path,
    num_protpardelle_conformers: int,
    fixed_pos_seq: str,
    fixed_pos_scn: str,
    fixed_pos_override_seq: str,
    pos_restrict_aatype: str,
    symmetry_pos: str,
) -> tuple[list[str] | dict[str, list[str]], pd.DataFrame | None]:
    """Assemble design inputs and positional constraints for one design run.

    Builds the per-position constraint table keyed by the first PDB's stem,
    then returns design inputs according to the ensemble mode: the raw PDB
    paths ("none"), a synthetic Protpardelle conformer ensemble ("synthetic"),
    or a directory prepared from the user's uploads (otherwise). For ensemble
    modes, single-structure constraints are broadcast across the ensemble.
    """
    pdb_key = Path(pdb_paths[0]).stem
    constraints = _build_pos_constraint_df(
        pdb_key=pdb_key,
        fixed_pos_seq=fixed_pos_seq,
        fixed_pos_scn=fixed_pos_scn,
        fixed_pos_override_seq=fixed_pos_override_seq,
        pos_restrict_aatype=pos_restrict_aatype,
        symmetry_pos=symmetry_pos,
    )

    # No ensemble: design directly on the uploaded structure(s).
    if ensemble_mode == "none":
        return pdb_paths, constraints

    if ensemble_mode == "synthetic":
        design_inputs = _generate_protpardelle_ensemble(
            pdb_path=pdb_paths[0],
            num_conformers=num_protpardelle_conformers,
            out_dir=tmpdir,
            weights_dir=WEIGHTS_DIR,
        )
    else:
        design_inputs = _setup_user_ensemble_dir(pdb_paths=pdb_paths)

    if constraints is not None:
        # Deferred import: caliby is only needed when constraints must be
        # expanded over an ensemble.
        from caliby import make_ensemble_constraints

        first_row = constraints.iloc[0]
        per_pdb = {
            name: first_row[name]
            for name in constraints.columns
            if name != "pdb_key"
        }
        constraints = make_ensemble_constraints({pdb_key: per_pdb}, design_inputs)
    return design_inputs, constraints
| def _format_outputs(outputs: dict) -> tuple[pd.DataFrame, str, list[str]]: | |
| out_pdb_list = outputs["out_pdb"] | |
| df = pd.DataFrame( | |
| { | |
| "Sample": [Path(out_pdb).stem for out_pdb in out_pdb_list], | |
| "Sequence": outputs["seq"], | |
| "Energy (U)": outputs["U"], | |
| } | |
| ) | |
| fasta_lines = [] | |
| for i, (eid, seq) in enumerate(zip(outputs["example_id"], outputs["seq"])): | |
| fasta_lines.append(f">{eid}_sample{i}") | |
| fasta_lines.append(seq) | |
| fasta_text = "\n".join(fasta_lines) | |
| return df, fasta_text, out_pdb_list | |
def _design_sequences_gpu(
    pdb_files: list | None,
    ensemble_mode: str,
    model_variant: str,
    num_seqs: int,
    omit_aas: list[str] | None,
    temperature: float,
    fixed_pos_seq: str,
    fixed_pos_scn: str,
    fixed_pos_override_seq: str,
    pos_restrict_aatype: str,
    symmetry_pos: str,
    num_protpardelle_conformers: int,
    run_af2_eval: bool = False,
):
    """Execute one full design run: validate, build context, sample, format.

    Returns a 6-tuple: (results DataFrame, FASTA text or validation-error
    message, designs zip path, self-consistency zip path, AF2 PDB data dict,
    input PDB data dict). On validation failure the last four slots are
    (None, None, {}, {}).

    NOTE(review): the module defines _dynamic_gpu_duration for @spaces.GPU
    scheduling, which suggests this function should carry a
    @spaces.GPU(duration=_dynamic_gpu_duration) decorator — confirm it was
    not lost during extraction; without it the quota-retry machinery in
    design_sequences has no effect on GPU scheduling.
    """
    # Fail fast on bad inputs before loading weights or touching the GPU.
    validation_error = _validate_design_inputs(pdb_files, ensemble_mode)
    if validation_error:
        return pd.DataFrame(), validation_error, None, None, {}, {}
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Inference only — disable autograd for the whole run.
    torch.set_grad_enabled(False)
    download_stem = _sanitize_download_stem(_get_file_path(pdb_files[0]).stem)
    gr.Info("Loading model...")
    model = get_model(model_variant, device)
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)
        pdb_paths = _copy_uploaded_files(pdb_files, tmpdir)
        # Keep raw PDB text in memory keyed by file stem; the tempdir (and the
        # copied files) is deleted when the context manager exits.
        input_pdb_data = {Path(p).stem: Path(p).read_text() for p in pdb_paths}
        out_dir = tmpdir / "outputs"
        out_dir.mkdir(parents=True, exist_ok=True)
        # Progress messages for the slower ensemble-preparation paths.
        if ensemble_mode == "synthetic":
            gr.Info("Generating conformer ensemble...")
        elif ensemble_mode == "user":
            gr.Info("Preparing uploaded ensemble...")
        design_inputs, pos_constraint_df = _build_design_context(
            pdb_paths=pdb_paths,
            ensemble_mode=ensemble_mode,
            tmpdir=tmpdir,
            num_protpardelle_conformers=num_protpardelle_conformers,
            fixed_pos_seq=fixed_pos_seq,
            fixed_pos_scn=fixed_pos_scn,
            fixed_pos_override_seq=fixed_pos_override_seq,
            pos_restrict_aatype=pos_restrict_aatype,
            symmetry_pos=symmetry_pos,
        )
        gr.Info("Designing sequences...")
        sample_kwargs = dict(
            out_dir=str(out_dir),
            num_seqs_per_pdb=num_seqs,
            omit_aas=omit_aas if omit_aas else None,  # empty selection -> no restriction
            temperature=temperature,
            num_workers=0,  # presumably avoids DataLoader workers in the GPU worker — confirm
            pos_constraint_df=pos_constraint_df,
        )
        # Single-structure vs ensemble inputs use different sampler entry points.
        if ensemble_mode == "none":
            outputs = model.sample(design_inputs, **sample_kwargs)
        else:
            outputs = model.ensemble_sample(design_inputs, **sample_kwargs)
        df, fasta_text, out_pdb_list = _format_outputs(outputs)
        sc_zip_path = None
        af2_pdb_data = {}
        if run_af2_eval:
            gr.Info("Running AF2 self-consistency evaluation...")
            sc_zip_path, af2_pdb_data = _run_self_consistency(model, df, out_pdb_list, out_dir, download_stem)
        # Zip the designed PDBs while the tempdir holding them still exists.
        out_zip_path = _write_zip_from_paths(out_pdb_list, download_stem, "_designs.zip")
        return df, fasta_text, out_zip_path, sc_zip_path, af2_pdb_data, input_pdb_data
def design_sequences(
    pdb_files: list | None,
    ensemble_mode: str,
    model_variant: str,
    num_seqs: int,
    omit_aas: list[str] | None,
    temperature: float,
    fixed_pos_seq: str,
    fixed_pos_scn: str,
    fixed_pos_override_seq: str,
    pos_restrict_aatype: str,
    symmetry_pos: str,
    num_protpardelle_conformers: int,
    run_af2_eval: bool = False,
):
    """Run sequence design with ZeroGPU quota-aware retry.

    Requests the max GPU duration first. If the scheduler returns a quota
    error (free — no GPU time consumed), parses the remaining seconds and
    retries with that exact amount to maximize GPU utilization.

    Raises:
        gr.Error: re-raised when the failure is not a recoverable quota
            error, or when the retry itself fails.
    """
    global _gpu_duration_override
    # Build the argument set once so the initial attempt and the retry cannot
    # drift apart (the original duplicated all 13 keywords at both call sites).
    call_kwargs = dict(
        pdb_files=pdb_files,
        ensemble_mode=ensemble_mode,
        model_variant=model_variant,
        num_seqs=num_seqs,
        omit_aas=omit_aas,
        temperature=temperature,
        fixed_pos_seq=fixed_pos_seq,
        fixed_pos_scn=fixed_pos_scn,
        fixed_pos_override_seq=fixed_pos_override_seq,
        pos_restrict_aatype=pos_restrict_aatype,
        symmetry_pos=symmetry_pos,
        num_protpardelle_conformers=num_protpardelle_conformers,
        run_af2_eval=run_af2_eval,
    )
    _gpu_duration_override = None  # first attempt requests the per-call max
    try:
        return _design_sequences_gpu(**call_kwargs)
    except gr.Error as e:
        remaining = _parse_quota_left(e)
        print(f"[ZeroGPU retry] Caught gr.Error, parsed remaining={remaining}, message={getattr(e, 'message', str(e))}")
        if remaining is None or remaining <= 0:
            raise  # not a recoverable quota error — surface it to the UI
        gr.Info(f"GPU quota: {remaining}s remaining, retrying with exact quota")
        # Leave a 1s safety margin below the reported quota, but never request
        # less than 1s (remaining == 1 would otherwise yield an invalid 0s).
        _gpu_duration_override = max(1, remaining - 1)
        try:
            return _design_sequences_gpu(**call_kwargs)
        finally:
            # Always restore the default so later calls request the max again.
            _gpu_duration_override = None