caliby / tests /test_design_sequences.py
Justine Yuan
Caliby HuggingFace example
3beba17
"""Integration tests for get_model, _setup_user_ensemble_dir, and design_sequences."""
from pathlib import Path
from unittest.mock import MagicMock, patch
import gradio as gr
import pandas as pd
import pytest
import design
import ensemble
import models
# ---------------------------------------------------------------------------
# get_model
# ---------------------------------------------------------------------------
class TestGetModel:
    """Exercise ``models.get_model``: lazy ``caliby.load_model`` calls plus per-variant caching."""

    @pytest.fixture(autouse=True)
    def _clear_model_cache(self):
        # Empty the module-level cache before and after each test so a model
        # loaded by one test can never satisfy a lookup in another.
        models.MODELS.clear()
        yield
        models.MODELS.clear()

    def test_calls_load_model_with_variant_and_device(self):
        fake_model = MagicMock()
        with patch("caliby.load_model", return_value=fake_model) as load_spy:
            loaded = models.get_model("caliby", "cpu")
        load_spy.assert_called_once_with("caliby", device="cpu")
        assert loaded is fake_model

    def test_caches_model_on_repeat_call(self):
        with patch("caliby.load_model", return_value=MagicMock()) as load_spy:
            results = [models.get_model("caliby", "cpu") for _ in range(2)]
        # Only the first lookup should reach caliby.load_model.
        load_spy.assert_called_once()
        assert results[0] is results[1]

    def test_different_variants_cached_separately(self):
        # Insertion order matters: side_effect hands out models in this order.
        expected = {"caliby": MagicMock(), "soluble_caliby_v1": MagicMock()}
        with patch("caliby.load_model", side_effect=list(expected.values())):
            loaded = {variant: models.get_model(variant, "cpu") for variant in expected}
        for variant, model in expected.items():
            assert loaded[variant] is model
# ---------------------------------------------------------------------------
# _setup_user_ensemble_dir
# ---------------------------------------------------------------------------
class TestSetupUserEnsembleDir:
    """Check how ``ensemble._setup_user_ensemble_dir`` keys the uploaded conformer files."""

    def test_returns_dict_with_primary_key(self):
        uploads = ["/tmp/primary.pdb", "/tmp/conf1.pdb", "/tmp/conf2.pdb"]
        # Pass a copy so the comparison below is against an untouched expected list.
        mapping = ensemble._setup_user_ensemble_dir(list(uploads))
        assert "primary" in mapping
        assert mapping["primary"] == uploads

    def test_first_file_is_primary(self):
        mapping = ensemble._setup_user_ensemble_dir(["/tmp/myprotein.cif", "/tmp/alt.pdb"])
        primary_path = mapping["myprotein"][0]
        assert primary_path == "/tmp/myprotein.cif"

    def test_uses_stem_as_key(self):
        mapping = ensemble._setup_user_ensemble_dir(["/path/to/foo.pdb"])
        assert "foo" in mapping
# ---------------------------------------------------------------------------
# design_sequences — validation
# ---------------------------------------------------------------------------
class TestDesignSequencesValidation:
    """Input checks performed by ``design.design_sequences`` before any model is touched."""

    @staticmethod
    def _run(files, ensemble_mode):
        """Call ``design_sequences`` with default settings and return ``(df, msg)``."""
        # Unpacking into six names also checks the size of the returned tuple.
        df, msg, _, _, _, _ = design.design_sequences(
            files, ensemble_mode, "caliby", 4, None, 0.1, "", "", "", "", "", 31
        )
        return df, msg

    def test_no_files(self):
        df, msg = self._run(None, "none")
        assert df.empty
        assert "Upload at least one" in msg

    def test_empty_file_list(self):
        df, msg = self._run([], "none")
        assert df.empty
        assert "Upload at least one" in msg

    def test_single_mode_multiple_files(self):
        _, msg = self._run(["a.pdb", "b.pdb"], "none")
        assert "exactly one file" in msg

    def test_synthetic_mode_multiple_files(self):
        _, msg = self._run(["a.pdb", "b.pdb"], "synthetic")
        assert "exactly one file" in msg

    def test_user_mode_too_few_files(self):
        _, msg = self._run(["a.pdb"], "user")
        assert "at least two" in msg
# ---------------------------------------------------------------------------
# design_sequences — single structure mode
# ---------------------------------------------------------------------------
class TestDesignSequencesSingleMode:
    """Single-structure runs (``ensemble_mode='none'``): arguments forwarded to ``CalibyModel.sample()``."""

    def _make_mock_outputs(self):
        # Canned ``sample()`` result describing one design.
        return {
            "example_id": ["test"],
            "out_pdb": ["/tmp/test_sample0.cif"],
            "U": [-100.0],
            "input_seq": ["NATIVE"],
            "seq": ["ACDEF"],
        }

    def _sample_after_run(self, tmp_path, pdb_text, *design_args):
        """Run ``design_sequences`` on a fake ``test.pdb`` and return the mocked ``sample``.

        ``design_args`` are the positional arguments that follow the file list,
        the ensemble mode and the model variant.
        """
        pdb_file = tmp_path / "test.pdb"
        pdb_file.write_text(pdb_text)
        fake_model = MagicMock()
        fake_model.sample.return_value = self._make_mock_outputs()
        with (
            patch.object(design, "get_model", return_value=fake_model),
            patch.object(design, "_write_zip_from_paths", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            design.design_sequences([str(pdb_file)], "none", "caliby", *design_args)
        return fake_model.sample

    def test_sample_called_with_correct_args(self, tmp_path):
        sample = self._sample_after_run(
            tmp_path, "FAKE PDB", 4, ["C"], 0.5, "A1-100", "A1-10", "A26:A", "A26:AVG", "A10,B10", 31
        )
        sample.assert_called_once()
        pdb_paths = sample.call_args.args[0]
        kwargs = sample.call_args.kwargs
        # Structure paths travel positionally as a one-element list.
        assert isinstance(pdb_paths, list)
        assert len(pdb_paths) == 1
        assert pdb_paths[0].endswith("test.pdb")
        assert kwargs["num_seqs_per_pdb"] == 4
        assert kwargs["omit_aas"] == ["C"]
        assert kwargs["temperature"] == 0.5
        assert kwargs["num_workers"] == 0
        assert isinstance(kwargs["out_dir"], str)
        constraints = kwargs["pos_constraint_df"]
        assert isinstance(constraints, pd.DataFrame)
        assert constraints.iloc[0]["pdb_key"] == "test"

    def test_no_constraints_passes_none(self, tmp_path):
        sample = self._sample_after_run(tmp_path, "FAKE", 1, None, 0.1, "", "", "", "", "", 31)
        assert sample.call_args.kwargs["pos_constraint_df"] is None

    def test_empty_omit_aas_becomes_none(self, tmp_path):
        sample = self._sample_after_run(tmp_path, "FAKE", 1, [], 0.1, "", "", "", "", "", 31)
        assert sample.call_args.kwargs["omit_aas"] is None
# ---------------------------------------------------------------------------
# design_sequences — user ensemble mode
# ---------------------------------------------------------------------------
class TestDesignSequencesUserEnsembleMode:
    """User-supplied ensembles (``ensemble_mode='user'``): calls into ``CalibyModel.ensemble_sample()``."""

    def _make_mock_outputs(self):
        # Canned ``ensemble_sample()`` result describing one design.
        return {
            "example_id": ["primary"],
            "out_pdb": ["/tmp/primary_sample0.cif"],
            "U": [-100.0],
            "input_seq": ["NATIVE"],
            "seq": ["AAA"],
        }

    @staticmethod
    def _write_uploads(tmp_path):
        """Write a fake primary structure plus one conformer; return their paths as strings."""
        uploads = []
        for file_name, text in (("primary.pdb", "PDB1"), ("conf1.pdb", "PDB2")):
            path = tmp_path / file_name
            path.write_text(text)
            uploads.append(str(path))
        return uploads

    def _make_model(self):
        fake_model = MagicMock()
        fake_model.ensemble_sample.return_value = self._make_mock_outputs()
        return fake_model

    def test_calls_ensemble_sample(self, tmp_path):
        uploads = self._write_uploads(tmp_path)
        fake_model = self._make_model()
        pdb_to_conformers = {"primary": ["/some/primary.pdb", "/some/conf1.pdb"]}
        with (
            patch.object(design, "get_model", return_value=fake_model),
            patch.object(design, "_setup_user_ensemble_dir", return_value=pdb_to_conformers),
            patch.object(design, "_write_zip_from_paths", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            design.design_sequences(uploads, "user", "caliby", 4, None, 0.1, "", "", "", "", "", 31)
        fake_model.ensemble_sample.assert_called_once()
        recorded = fake_model.ensemble_sample.call_args
        # The mapping returned by _setup_user_ensemble_dir is forwarded as-is.
        assert recorded.args[0] is pdb_to_conformers
        assert recorded.kwargs["pos_constraint_df"] is None

    def test_constraints_expand_via_make_ensemble_constraints(self, tmp_path):
        uploads = self._write_uploads(tmp_path)
        fake_model = self._make_model()
        pdb_to_conformers = {"primary": ["a.pdb", "b.pdb"]}
        expanded = pd.DataFrame({"pdb_key": ["a", "b"], "fixed_pos_seq": ["A1-10", "A1-10"]})
        with (
            patch.object(design, "get_model", return_value=fake_model),
            patch.object(design, "_setup_user_ensemble_dir", return_value=pdb_to_conformers),
            patch("caliby.make_ensemble_constraints", return_value=expanded) as expand_spy,
            patch.object(design, "_write_zip_from_paths", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            design.design_sequences(uploads, "user", "caliby", 1, None, 0.1, "A1-10", "", "", "", "", 31)
        expand_spy.assert_called_once()
        per_key_constraints, mapping_arg = expand_spy.call_args.args
        # Constraints are keyed by the primary structure before expansion.
        assert isinstance(per_key_constraints, dict)
        assert "primary" in per_key_constraints
        assert per_key_constraints["primary"]["fixed_pos_seq"] == "A1-10"
        assert mapping_arg is pdb_to_conformers
# ---------------------------------------------------------------------------
# design_sequences — error handling
# ---------------------------------------------------------------------------
class TestDesignSequencesErrorHandling:
    """Failures outside input validation propagate out of ``design_sequences`` unchanged."""

    @staticmethod
    def _assert_propagates(files, error):
        """Make ``get_model`` raise ``error`` and check the same exception escapes.

        The expected type is ``type(error)`` and the match pattern is ``str(error)``.
        """
        with (
            patch.object(design, "get_model", side_effect=error),
            patch("torch.cuda.is_available", return_value=False),
        ):
            with pytest.raises(type(error), match=str(error)):
                design.design_sequences(files, "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31)

    def test_value_error(self, tmp_path):
        pdb_file = tmp_path / "test.pdb"
        pdb_file.write_text("PDB")
        self._assert_propagates([str(pdb_file)], ValueError("bad config"))

    def test_file_not_found(self, tmp_path):
        # The input file is deliberately never written; the error itself is
        # injected through the patched get_model.
        self._assert_propagates([str(tmp_path / "ghost.pdb")], FileNotFoundError("missing.pdb"))

    def test_unexpected_runtime_error(self, tmp_path):
        pdb_file = tmp_path / "test.pdb"
        pdb_file.write_text("PDB")
        self._assert_propagates([str(pdb_file)], RuntimeError("GPU OOM"))
# ---------------------------------------------------------------------------
# design_sequences — zip output
# ---------------------------------------------------------------------------
class TestDesignSequencesZipOutput:
    """ZIP archive creation from the CIF files reported in ``out_pdb``."""

    @staticmethod
    def _model_with_outputs(out_pdb):
        """Return a mock model whose ``sample()`` reports the given ``out_pdb`` list."""
        fake_model = MagicMock()
        fake_model.sample.return_value = {
            "example_id": ["test"],
            "out_pdb": out_pdb,
            "U": [-100.0],
            "input_seq": ["NATIVE"],
            "seq": ["AAA"],
        }
        return fake_model

    def test_creates_zip_when_out_pdb_present(self, tmp_path):
        pdb_file = tmp_path / "test.pdb"
        pdb_file.write_text("PDB")
        out_cif = tmp_path / "test_sample0.cif"
        out_cif.write_text("CIF CONTENT")
        fake_model = self._model_with_outputs([str(out_cif)])
        with (
            patch.object(design, "get_model", return_value=fake_model),
            patch("torch.cuda.is_available", return_value=False),
        ):
            _, _, zip_path, _, _, _ = design.design_sequences(
                [str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31
            )
        assert zip_path is not None
        archive = Path(zip_path)
        assert archive.name == "test_designs.zip"
        assert archive.exists()

    def test_empty_out_pdb_raises_for_invalid_caliby_output(self, tmp_path):
        # An ``out_pdb`` list shorter than the other result columns is expected to
        # surface pandas' length-mismatch ValueError (presumably raised while the
        # results table is built).
        pdb_file = tmp_path / "test.pdb"
        pdb_file.write_text("PDB")
        fake_model = self._model_with_outputs([])
        with (
            patch.object(design, "get_model", return_value=fake_model),
            patch("torch.cuda.is_available", return_value=False),
        ):
            with pytest.raises(ValueError, match="All arrays must be of the same length"):
                design.design_sequences([str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31)
# ---------------------------------------------------------------------------
# design_sequences — ZeroGPU quota-aware retry
# ---------------------------------------------------------------------------
class TestParseQuotaLeft:
    """``design._parse_quota_left``: pull the remaining-seconds figure out of ZeroGPU quota errors."""

    def test_extracts_remaining_seconds(self):
        message = "You have exceeded your free GPU quota (210s requested vs. 45s left). Try again in 0:02:45"
        assert design._parse_quota_left(gr.Error(message)) == 45

    def test_extracts_zero_remaining(self):
        message = "(210s requested vs. 0s left). Try again in 0:03:30"
        assert design._parse_quota_left(gr.Error(message)) == 0

    def test_returns_none_for_non_quota_error(self):
        unrelated = gr.Error("Some other error")
        assert design._parse_quota_left(unrelated) is None

    def test_returns_none_for_no_message_attr(self):
        # A plain exception, not a gr.Error, must also be handled gracefully.
        plain = RuntimeError("no message attribute")
        assert design._parse_quota_left(plain) is None
class TestDesignSequencesQuotaRetry:
    """Tests ZeroGPU quota-aware retry logic in the ``design_sequences`` wrapper."""

    # Arguments for the error-path tests below. Those tests rely on the wrapper
    # reaching the patched ``_design_sequences_gpu``, whose stub raises before
    # any real work happens.
    _DESIGN_ARGS = (None, "none", "caliby", 4, None, 0.1, "", "", "", "", "", 31)

    def test_retry_on_quota_exceeded(self, tmp_path):
        """A quota error with seconds left triggers exactly one retry, then the override is cleared."""
        pdb_file = tmp_path / "test.pdb"
        pdb_file.write_text("PDB")
        mock_model = MagicMock()
        mock_model.sample.return_value = {
            "example_id": ["test"],
            "out_pdb": ["/tmp/t.cif"],
            "U": [-100.0],
            "input_seq": ["N"],
            "seq": ["A"],
        }
        quota_error = gr.Error("(210s requested vs. 45s left). Try again in 0:02:45")
        call_count = 0
        # Captured before patching so the retry can run the real implementation.
        original_fn = design._design_sequences_gpu

        def side_effect(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise quota_error
            return original_fn(*args, **kwargs)

        with (
            patch.object(design, "_design_sequences_gpu", side_effect=side_effect),
            patch.object(design, "get_model", return_value=mock_model),
            patch.object(design, "_write_zip_from_paths", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            design.design_sequences([str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31)
        assert call_count == 2
        assert design._gpu_duration_override is None  # Reset after retry

    def test_no_retry_when_remaining_zero(self):
        """With 0s of quota left the error is re-raised without a second attempt."""
        quota_error = gr.Error("(210s requested vs. 0s left). Try again in 0:03:30")
        with patch.object(design, "_design_sequences_gpu", side_effect=quota_error) as gpu_fn:
            with pytest.raises(gr.Error):
                design.design_sequences(*self._DESIGN_ARGS)
        # The stub raises on every call, so a (wrong) retry would re-raise the same
        # error and pytest.raises alone could not tell the difference.
        gpu_fn.assert_called_once()

    def test_no_retry_for_non_quota_gr_error(self):
        """A gr.Error that is not a quota message propagates after a single attempt."""
        other_error = gr.Error("The requested GPU duration (210s) is larger than the maximum allowed")
        with patch.object(design, "_design_sequences_gpu", side_effect=other_error) as gpu_fn:
            with pytest.raises(gr.Error, match="larger than the maximum allowed"):
                design.design_sequences(*self._DESIGN_ARGS)
        gpu_fn.assert_called_once()

    def test_non_gradio_errors_propagate(self):
        """ValueError, RuntimeError etc. are not caught by the retry logic."""
        with patch.object(design, "_design_sequences_gpu", side_effect=ValueError("bad")) as gpu_fn:
            with pytest.raises(ValueError, match="bad"):
                design.design_sequences(*self._DESIGN_ARGS)
        gpu_fn.assert_called_once()