"""Core design pipeline: context building, execution, output formatting.""" import re import tempfile from pathlib import Path import gradio as gr import pandas as pd import spaces import torch from app_config import WEIGHTS_DIR from constraints import _build_pos_constraint_df, _validate_design_inputs from ensemble import _generate_protpardelle_ensemble, _setup_user_ensemble_dir from file_utils import _copy_uploaded_files, _get_file_path, _sanitize_download_stem, _write_zip_from_paths from models import get_model from self_consistency import _run_self_consistency # ZeroGPU quota-aware retry: request the max duration first, and if the # scheduler returns a quota error (which is free — no GPU time consumed), # parse the remaining seconds and retry with that exact amount. _MAX_GPU_DURATION = 120 # Per-call max; daily quota is 210s but per-call cap is lower _gpu_duration_override: int | None = None def _dynamic_gpu_duration(*args, **kwargs) -> int: """Return the current GPU duration for @spaces.GPU scheduling.""" return _gpu_duration_override if _gpu_duration_override is not None else _MAX_GPU_DURATION def _parse_quota_left(error: Exception) -> int | None: """Extract remaining GPU seconds from a ZeroGPU quota error message. Returns the number of seconds left, or None if not a recoverable quota error. """ message = getattr(error, 'message', None) if not isinstance(message, str): return None match = re.search(r'(\d+)s left\)', message) return int(match.group(1)) if match else None def _build_design_context( pdb_paths: list[str], ensemble_mode: str, tmpdir: Path, num_protpardelle_conformers: int, fixed_pos_seq: str, fixed_pos_scn: str, fixed_pos_override_seq: str, pos_restrict_aatype: str, symmetry_pos: str, ) -> tuple[list[str] | dict[str, list[str]], pd.DataFrame | None]: pdb_key = Path(pdb_paths[0]).stem pos_constraint_df = _build_pos_constraint_df( pdb_key=pdb_key, fixed_pos_seq=fixed_pos_seq, fixed_pos_scn=fixed_pos_scn, fixed_pos_override_seq=fixed_pos_override_seq, pos_restrict_aatype=pos_restrict_aatype, symmetry_pos=symmetry_pos, ) if ensemble_mode == "none": return pdb_paths, pos_constraint_df if ensemble_mode == "synthetic": design_inputs = _generate_protpardelle_ensemble( pdb_path=pdb_paths[0], num_conformers=num_protpardelle_conformers, out_dir=tmpdir, weights_dir=WEIGHTS_DIR, ) else: design_inputs = _setup_user_ensemble_dir(pdb_paths=pdb_paths) if pos_constraint_df is not None: from caliby import make_ensemble_constraints row = pos_constraint_df.iloc[0] cols = {col: row[col] for col in pos_constraint_df.columns if col != "pdb_key"} pos_constraint_df = make_ensemble_constraints({pdb_key: cols}, design_inputs) return design_inputs, pos_constraint_df def _format_outputs(outputs: dict) -> tuple[pd.DataFrame, str, list[str]]: out_pdb_list = outputs["out_pdb"] df = pd.DataFrame( { "Sample": [Path(out_pdb).stem for out_pdb in out_pdb_list], "Sequence": outputs["seq"], "Energy (U)": outputs["U"], } ) fasta_lines = [] for i, (eid, seq) in enumerate(zip(outputs["example_id"], outputs["seq"])): fasta_lines.append(f">{eid}_sample{i}") fasta_lines.append(seq) fasta_text = "\n".join(fasta_lines) return df, fasta_text, out_pdb_list @spaces.GPU(duration=_dynamic_gpu_duration) def _design_sequences_gpu( pdb_files: list | None, ensemble_mode: str, model_variant: str, num_seqs: int, omit_aas: list[str] | None, temperature: float, fixed_pos_seq: str, fixed_pos_scn: str, fixed_pos_override_seq: str, pos_restrict_aatype: str, symmetry_pos: str, num_protpardelle_conformers: int, run_af2_eval: bool = False, ): validation_error = _validate_design_inputs(pdb_files, ensemble_mode) if validation_error: return pd.DataFrame(), validation_error, None, None, {}, {} device = "cuda" if torch.cuda.is_available() else "cpu" torch.set_grad_enabled(False) download_stem = _sanitize_download_stem(_get_file_path(pdb_files[0]).stem) gr.Info("Loading model...") model = get_model(model_variant, device) with tempfile.TemporaryDirectory() as tmpdir: tmpdir = Path(tmpdir) pdb_paths = _copy_uploaded_files(pdb_files, tmpdir) input_pdb_data = {Path(p).stem: Path(p).read_text() for p in pdb_paths} out_dir = tmpdir / "outputs" out_dir.mkdir(parents=True, exist_ok=True) if ensemble_mode == "synthetic": gr.Info("Generating conformer ensemble...") elif ensemble_mode == "user": gr.Info("Preparing uploaded ensemble...") design_inputs, pos_constraint_df = _build_design_context( pdb_paths=pdb_paths, ensemble_mode=ensemble_mode, tmpdir=tmpdir, num_protpardelle_conformers=num_protpardelle_conformers, fixed_pos_seq=fixed_pos_seq, fixed_pos_scn=fixed_pos_scn, fixed_pos_override_seq=fixed_pos_override_seq, pos_restrict_aatype=pos_restrict_aatype, symmetry_pos=symmetry_pos, ) gr.Info("Designing sequences...") sample_kwargs = dict( out_dir=str(out_dir), num_seqs_per_pdb=num_seqs, omit_aas=omit_aas if omit_aas else None, temperature=temperature, num_workers=0, pos_constraint_df=pos_constraint_df, ) if ensemble_mode == "none": outputs = model.sample(design_inputs, **sample_kwargs) else: outputs = model.ensemble_sample(design_inputs, **sample_kwargs) df, fasta_text, out_pdb_list = _format_outputs(outputs) sc_zip_path = None af2_pdb_data = {} if run_af2_eval: gr.Info("Running AF2 self-consistency evaluation...") sc_zip_path, af2_pdb_data = _run_self_consistency(model, df, out_pdb_list, out_dir, download_stem) out_zip_path = _write_zip_from_paths(out_pdb_list, download_stem, "_designs.zip") return df, fasta_text, out_zip_path, sc_zip_path, af2_pdb_data, input_pdb_data def design_sequences( pdb_files: list | None, ensemble_mode: str, model_variant: str, num_seqs: int, omit_aas: list[str] | None, temperature: float, fixed_pos_seq: str, fixed_pos_scn: str, fixed_pos_override_seq: str, pos_restrict_aatype: str, symmetry_pos: str, num_protpardelle_conformers: int, run_af2_eval: bool = False, ): """Run sequence design with ZeroGPU quota-aware retry. Requests the max GPU duration first. If the scheduler returns a quota error (free — no GPU time consumed), parses the remaining seconds and retries with that exact amount to maximize GPU utilization. """ global _gpu_duration_override _gpu_duration_override = None try: return _design_sequences_gpu( pdb_files=pdb_files, ensemble_mode=ensemble_mode, model_variant=model_variant, num_seqs=num_seqs, omit_aas=omit_aas, temperature=temperature, fixed_pos_seq=fixed_pos_seq, fixed_pos_scn=fixed_pos_scn, fixed_pos_override_seq=fixed_pos_override_seq, pos_restrict_aatype=pos_restrict_aatype, symmetry_pos=symmetry_pos, num_protpardelle_conformers=num_protpardelle_conformers, run_af2_eval=run_af2_eval, ) except gr.Error as e: remaining = _parse_quota_left(e) print(f"[ZeroGPU retry] Caught gr.Error, parsed remaining={remaining}, message={getattr(e, 'message', str(e))}") if remaining is None or remaining <= 0: raise gr.Info(f"GPU quota: {remaining}s remaining, retrying with exact quota") _gpu_duration_override = remaining - 1 try: return _design_sequences_gpu( pdb_files=pdb_files, ensemble_mode=ensemble_mode, model_variant=model_variant, num_seqs=num_seqs, omit_aas=omit_aas, temperature=temperature, fixed_pos_seq=fixed_pos_seq, fixed_pos_scn=fixed_pos_scn, fixed_pos_override_seq=fixed_pos_override_seq, pos_restrict_aatype=pos_restrict_aatype, symmetry_pos=symmetry_pos, num_protpardelle_conformers=num_protpardelle_conformers, run_af2_eval=run_af2_eval, ) finally: _gpu_duration_override = None