Spaces:
Running on Zero
Running on Zero
| """Integration tests for get_model, _setup_user_ensemble_dir, and design_sequences.""" | |
| from pathlib import Path | |
| from unittest.mock import MagicMock, patch | |
| import gradio as gr | |
| import pandas as pd | |
| import pytest | |
| import design | |
| import ensemble | |
| import models | |
| # --------------------------------------------------------------------------- | |
| # get_model | |
| # --------------------------------------------------------------------------- | |
class TestGetModel:
    """Lazy-loads and caches CalibyModel instances via caliby.load_model."""

    @pytest.fixture(autouse=True)
    def _clear_model_cache(self):
        # FIX: this was a bare method containing `yield`, so pytest never ran it
        # and entries in models.MODELS leaked across tests — a model cached by one
        # test would suppress the caliby.load_model call another test asserts on.
        # As an autouse fixture it now wraps every test in this class.
        models.MODELS.clear()
        yield
        models.MODELS.clear()

    def test_calls_load_model_with_variant_and_device(self):
        """get_model forwards the variant and device to caliby.load_model."""
        mock_caliby_model = MagicMock()
        with patch("caliby.load_model", return_value=mock_caliby_model) as mock_load:
            result = models.get_model("caliby", "cpu")
        mock_load.assert_called_once_with("caliby", device="cpu")
        assert result is mock_caliby_model

    def test_caches_model_on_repeat_call(self):
        """A repeat request for the same (variant, device) reuses the cached model."""
        mock_caliby_model = MagicMock()
        with patch("caliby.load_model", return_value=mock_caliby_model) as mock_load:
            first = models.get_model("caliby", "cpu")
            second = models.get_model("caliby", "cpu")
        mock_load.assert_called_once()
        assert first is second

    def test_different_variants_cached_separately(self):
        """Distinct variants each trigger their own load and get their own cache slot."""
        mock_a = MagicMock()
        mock_b = MagicMock()
        with patch("caliby.load_model", side_effect=[mock_a, mock_b]):
            a = models.get_model("caliby", "cpu")
            b = models.get_model("soluble_caliby_v1", "cpu")
        assert a is mock_a
        assert b is mock_b
| # --------------------------------------------------------------------------- | |
| # _setup_user_ensemble_dir | |
| # --------------------------------------------------------------------------- | |
class TestSetupUserEnsembleDir:
    """Builds pdb_to_conformers dict from user-uploaded files."""

    def test_returns_dict_with_primary_key(self):
        # The first upload's stem names the entry; all files become its conformers.
        uploads = ["/tmp/primary.pdb", "/tmp/conf1.pdb", "/tmp/conf2.pdb"]
        mapping = ensemble._setup_user_ensemble_dir(uploads)
        assert "primary" in mapping
        assert mapping["primary"] == ["/tmp/primary.pdb", "/tmp/conf1.pdb", "/tmp/conf2.pdb"]

    def test_first_file_is_primary(self):
        # The primary structure keeps its leading position in the conformer list.
        mapping = ensemble._setup_user_ensemble_dir(["/tmp/myprotein.cif", "/tmp/alt.pdb"])
        assert mapping["myprotein"][0] == "/tmp/myprotein.cif"

    def test_uses_stem_as_key(self):
        # Directories and the file extension are stripped from the dict key.
        mapping = ensemble._setup_user_ensemble_dir(["/path/to/foo.pdb"])
        assert "foo" in mapping
| # --------------------------------------------------------------------------- | |
| # design_sequences — validation | |
| # --------------------------------------------------------------------------- | |
class TestDesignSequencesValidation:
    """Input validation before any model calls."""

    @staticmethod
    def _call(files, mode):
        # All validation tests share the same trailing arguments.
        return design.design_sequences(files, mode, "caliby", 4, None, 0.1, "", "", "", "", "", 31)

    def test_no_files(self):
        df, msg, _, _, _, _ = self._call(None, "none")
        assert df.empty
        assert "Upload at least one" in msg

    def test_empty_file_list(self):
        df, msg, _, _, _, _ = self._call([], "none")
        assert df.empty
        assert "Upload at least one" in msg

    def test_single_mode_multiple_files(self):
        _, msg, _, _, _, _ = self._call(["a.pdb", "b.pdb"], "none")
        assert "exactly one file" in msg

    def test_synthetic_mode_multiple_files(self):
        _, msg, _, _, _, _ = self._call(["a.pdb", "b.pdb"], "synthetic")
        assert "exactly one file" in msg

    def test_user_mode_too_few_files(self):
        _, msg, _, _, _, _ = self._call(["a.pdb"], "user")
        assert "at least two" in msg
| # --------------------------------------------------------------------------- | |
| # design_sequences — single structure mode | |
| # --------------------------------------------------------------------------- | |
class TestDesignSequencesSingleMode:
    """Tests ensemble_mode='none' — verifies correct args to CalibyModel.sample()."""

    def _make_mock_outputs(self):
        # Minimal sample() output: one design for one input structure.
        return {
            "example_id": ["test"],
            "out_pdb": ["/tmp/test_sample0.cif"],
            "U": [-100.0],
            "input_seq": ["NATIVE"],
            "seq": ["ACDEF"],
        }

    def test_sample_called_with_correct_args(self, tmp_path):
        """Every UI input is forwarded to CalibyModel.sample under the right name."""
        structure = tmp_path / "test.pdb"
        structure.write_text("FAKE PDB")
        mock_model = MagicMock()
        mock_model.sample.return_value = self._make_mock_outputs()
        with (
            patch.object(design, "get_model", return_value=mock_model),
            patch.object(design, "_write_zip_from_paths", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            design.design_sequences(
                [str(structure)], "none", "caliby", 4, ["C"], 0.5,
                "A1-100", "A1-10", "A26:A", "A26:AVG", "A10,B10", 31,
            )
        mock_model.sample.assert_called_once()
        positional, keyword = mock_model.sample.call_args
        # First positional arg is pdb_paths
        pdb_paths = positional[0]
        assert isinstance(pdb_paths, list)
        assert len(pdb_paths) == 1
        assert pdb_paths[0].endswith("test.pdb")
        assert keyword["num_seqs_per_pdb"] == 4
        assert keyword["omit_aas"] == ["C"]
        assert keyword["temperature"] == 0.5
        assert keyword["num_workers"] == 0
        assert isinstance(keyword["out_dir"], str)
        assert isinstance(keyword["pos_constraint_df"], pd.DataFrame)
        assert keyword["pos_constraint_df"].iloc[0]["pdb_key"] == "test"

    def test_no_constraints_passes_none(self, tmp_path):
        """Blank constraint fields must yield pos_constraint_df=None."""
        structure = tmp_path / "test.pdb"
        structure.write_text("FAKE")
        mock_model = MagicMock()
        mock_model.sample.return_value = self._make_mock_outputs()
        with (
            patch.object(design, "get_model", return_value=mock_model),
            patch.object(design, "_write_zip_from_paths", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            design.design_sequences(
                [str(structure)], "none", "caliby", 1, None, 0.1,
                "", "", "", "", "", 31,
            )
        assert mock_model.sample.call_args[1]["pos_constraint_df"] is None

    def test_empty_omit_aas_becomes_none(self, tmp_path):
        """An empty omit-AA selection is normalized to None before sampling."""
        structure = tmp_path / "test.pdb"
        structure.write_text("FAKE")
        mock_model = MagicMock()
        mock_model.sample.return_value = self._make_mock_outputs()
        with (
            patch.object(design, "get_model", return_value=mock_model),
            patch.object(design, "_write_zip_from_paths", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            design.design_sequences(
                [str(structure)], "none", "caliby", 1, [], 0.1,
                "", "", "", "", "", 31,
            )
        assert mock_model.sample.call_args[1]["omit_aas"] is None
| # --------------------------------------------------------------------------- | |
| # design_sequences — user ensemble mode | |
| # --------------------------------------------------------------------------- | |
class TestDesignSequencesUserEnsembleMode:
    """Tests ensemble_mode='user' — verifies correct args to CalibyModel.ensemble_sample()."""

    def _make_mock_outputs(self):
        # Minimal ensemble_sample() output: one design for the primary structure.
        return {
            "example_id": ["primary"],
            "out_pdb": ["/tmp/primary_sample0.cif"],
            "U": [-100.0],
            "input_seq": ["NATIVE"],
            "seq": ["AAA"],
        }

    def test_calls_ensemble_sample(self, tmp_path):
        primary = tmp_path / "primary.pdb"
        conformer = tmp_path / "conf1.pdb"
        primary.write_text("PDB1")
        conformer.write_text("PDB2")
        mock_model = MagicMock()
        mock_model.ensemble_sample.return_value = self._make_mock_outputs()
        fake_mapping = {"primary": ["/some/primary.pdb", "/some/conf1.pdb"]}
        with (
            patch.object(design, "get_model", return_value=mock_model),
            patch.object(design, "_setup_user_ensemble_dir", return_value=fake_mapping),
            patch.object(design, "_write_zip_from_paths", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            design.design_sequences(
                [str(primary), str(conformer)], "user", "caliby", 4, None, 0.1,
                "", "", "", "", "", 31,
            )
        mock_model.ensemble_sample.assert_called_once()
        positional, keyword = mock_model.ensemble_sample.call_args
        # The mapping from _setup_user_ensemble_dir is passed through untouched.
        assert positional[0] is fake_mapping
        assert keyword["pos_constraint_df"] is None

    def test_constraints_expand_via_make_ensemble_constraints(self, tmp_path):
        primary = tmp_path / "primary.pdb"
        conformer = tmp_path / "conf1.pdb"
        primary.write_text("PDB1")
        conformer.write_text("PDB2")
        mock_model = MagicMock()
        mock_model.ensemble_sample.return_value = self._make_mock_outputs()
        fake_mapping = {"primary": ["a.pdb", "b.pdb"]}
        expanded = pd.DataFrame({"pdb_key": ["a", "b"], "fixed_pos_seq": ["A1-10", "A1-10"]})
        with (
            patch.object(design, "get_model", return_value=mock_model),
            patch.object(design, "_setup_user_ensemble_dir", return_value=fake_mapping),
            patch("caliby.make_ensemble_constraints", return_value=expanded) as mock_expand,
            patch.object(design, "_write_zip_from_paths", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            design.design_sequences(
                [str(primary), str(conformer)], "user", "caliby", 1, None, 0.1,
                "A1-10", "", "", "", "", 31,
            )
        mock_expand.assert_called_once()
        constraints_by_key, mapping_arg = mock_expand.call_args[0]
        # Per-primary constraint dict keyed by pdb stem, plus the conformer mapping.
        assert isinstance(constraints_by_key, dict)
        assert "primary" in constraints_by_key
        assert constraints_by_key["primary"]["fixed_pos_seq"] == "A1-10"
        assert mapping_arg is fake_mapping
| # --------------------------------------------------------------------------- | |
| # design_sequences — error handling | |
| # --------------------------------------------------------------------------- | |
class TestDesignSequencesErrorHandling:
    """Verifies non-validation failures now raise naturally."""

    def test_value_error(self, tmp_path):
        structure = tmp_path / "test.pdb"
        structure.write_text("PDB")
        with (
            patch.object(design, "get_model", side_effect=ValueError("bad config")),
            patch("torch.cuda.is_available", return_value=False),
            pytest.raises(ValueError, match="bad config"),
        ):
            design.design_sequences([str(structure)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31)

    def test_file_not_found(self, tmp_path):
        with (
            patch.object(design, "get_model", side_effect=FileNotFoundError("missing.pdb")),
            patch("torch.cuda.is_available", return_value=False),
            pytest.raises(FileNotFoundError, match="missing.pdb"),
        ):
            design.design_sequences(
                [str(tmp_path / "ghost.pdb")], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31
            )

    def test_unexpected_runtime_error(self, tmp_path):
        structure = tmp_path / "test.pdb"
        structure.write_text("PDB")
        with (
            patch.object(design, "get_model", side_effect=RuntimeError("GPU OOM")),
            patch("torch.cuda.is_available", return_value=False),
            pytest.raises(RuntimeError, match="GPU OOM"),
        ):
            design.design_sequences([str(structure)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31)
| # --------------------------------------------------------------------------- | |
| # design_sequences — zip output | |
| # --------------------------------------------------------------------------- | |
class TestDesignSequencesZipOutput:
    """Tests ZIP file creation from output CIF files."""

    def test_creates_zip_when_out_pdb_present(self, tmp_path):
        structure = tmp_path / "test.pdb"
        structure.write_text("PDB")
        design_cif = tmp_path / "test_sample0.cif"
        design_cif.write_text("CIF CONTENT")
        mock_model = MagicMock()
        mock_model.sample.return_value = {
            "example_id": ["test"],
            "out_pdb": [str(design_cif)],
            "U": [-100.0],
            "input_seq": ["NATIVE"],
            "seq": ["AAA"],
        }
        # _write_zip_from_paths is deliberately NOT patched: exercise real zip creation.
        with (
            patch.object(design, "get_model", return_value=mock_model),
            patch("torch.cuda.is_available", return_value=False),
        ):
            _, _, zip_path, _, _, _ = design.design_sequences(
                [str(structure)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31
            )
        assert zip_path is not None
        zip_file = Path(zip_path)
        assert zip_file.name == "test_designs.zip"
        assert zip_file.exists()

    def test_empty_out_pdb_raises_for_invalid_caliby_output(self, tmp_path):
        structure = tmp_path / "test.pdb"
        structure.write_text("PDB")
        mock_model = MagicMock()
        mock_model.sample.return_value = {
            "example_id": ["test"],
            "out_pdb": [],  # length mismatch vs. the other columns
            "U": [-100.0],
            "input_seq": ["NATIVE"],
            "seq": ["AAA"],
        }
        with (
            patch.object(design, "get_model", return_value=mock_model),
            patch("torch.cuda.is_available", return_value=False),
            pytest.raises(ValueError, match="All arrays must be of the same length"),
        ):
            design.design_sequences([str(structure)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31)
| # --------------------------------------------------------------------------- | |
| # design_sequences — ZeroGPU quota-aware retry | |
| # --------------------------------------------------------------------------- | |
class TestParseQuotaLeft:
    """Tests _parse_quota_left regex parsing of ZeroGPU error messages."""

    def test_extracts_remaining_seconds(self):
        err = gr.Error("You have exceeded your free GPU quota (210s requested vs. 45s left). Try again in 0:02:45")
        assert design._parse_quota_left(err) == 45

    def test_extracts_zero_remaining(self):
        err = gr.Error("(210s requested vs. 0s left). Try again in 0:03:30")
        assert design._parse_quota_left(err) == 0

    def test_returns_none_for_non_quota_error(self):
        assert design._parse_quota_left(gr.Error("Some other error")) is None

    def test_returns_none_for_no_message_attr(self):
        # Non-gradio exceptions lack gr.Error's message attribute; must not blow up.
        assert design._parse_quota_left(RuntimeError("no message attribute")) is None
class TestDesignSequencesQuotaRetry:
    """Tests ZeroGPU quota-aware retry logic in design_sequences wrapper."""

    _DESIGN_ARGS = (None, "none", "caliby", 4, None, 0.1, "", "", "", "", "", 31)

    def test_retry_on_quota_exceeded(self, tmp_path):
        structure = tmp_path / "test.pdb"
        structure.write_text("PDB")
        mock_model = MagicMock()
        mock_model.sample.return_value = {
            "example_id": ["test"],
            "out_pdb": ["/tmp/t.cif"],
            "U": [-100.0],
            "input_seq": ["N"],
            "seq": ["A"],
        }
        quota_error = gr.Error("(210s requested vs. 45s left). Try again in 0:02:45")
        calls = {"count": 0}
        real_gpu_fn = design._design_sequences_gpu  # capture before patching

        def fail_once_then_delegate(*args, **kwargs):
            calls["count"] += 1
            if calls["count"] == 1:
                raise quota_error
            return real_gpu_fn(*args, **kwargs)

        with (
            patch.object(design, "_design_sequences_gpu", side_effect=fail_once_then_delegate),
            patch.object(design, "get_model", return_value=mock_model),
            patch.object(design, "_write_zip_from_paths", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            design.design_sequences([str(structure)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31)
        assert calls["count"] == 2
        assert design._gpu_duration_override is None  # Reset after retry

    def test_no_retry_when_remaining_zero(self):
        depleted = gr.Error("(210s requested vs. 0s left). Try again in 0:03:30")
        with patch.object(design, "_design_sequences_gpu", side_effect=depleted):
            with pytest.raises(gr.Error):
                design.design_sequences(*self._DESIGN_ARGS)

    def test_no_retry_for_non_quota_gr_error(self):
        unrelated = gr.Error("The requested GPU duration (210s) is larger than the maximum allowed")
        with patch.object(design, "_design_sequences_gpu", side_effect=unrelated):
            with pytest.raises(gr.Error, match="larger than the maximum allowed"):
                design.design_sequences(*self._DESIGN_ARGS)

    def test_non_gradio_errors_propagate(self):
        """ValueError, RuntimeError etc. are not caught by the retry logic."""
        with patch.object(design, "_design_sequences_gpu", side_effect=ValueError("bad")):
            with pytest.raises(ValueError, match="bad"):
                design.design_sequences(*self._DESIGN_ARGS)