"""Integration tests for get_model, _setup_user_ensemble_dir, and design_sequences.""" from pathlib import Path from unittest.mock import MagicMock, patch import gradio as gr import pandas as pd import pytest import design import ensemble import models # --------------------------------------------------------------------------- # get_model # --------------------------------------------------------------------------- class TestGetModel: """Lazy-loads and caches CalibyModel instances via caliby.load_model.""" @pytest.fixture(autouse=True) def _clear_model_cache(self): models.MODELS.clear() yield models.MODELS.clear() def test_calls_load_model_with_variant_and_device(self): mock_caliby_model = MagicMock() with patch("caliby.load_model", return_value=mock_caliby_model) as mock_load: result = models.get_model("caliby", "cpu") mock_load.assert_called_once_with("caliby", device="cpu") assert result is mock_caliby_model def test_caches_model_on_repeat_call(self): mock_caliby_model = MagicMock() with patch("caliby.load_model", return_value=mock_caliby_model) as mock_load: first = models.get_model("caliby", "cpu") second = models.get_model("caliby", "cpu") mock_load.assert_called_once() assert first is second def test_different_variants_cached_separately(self): mock_a = MagicMock() mock_b = MagicMock() with patch("caliby.load_model", side_effect=[mock_a, mock_b]): a = models.get_model("caliby", "cpu") b = models.get_model("soluble_caliby_v1", "cpu") assert a is mock_a assert b is mock_b # --------------------------------------------------------------------------- # _setup_user_ensemble_dir # --------------------------------------------------------------------------- class TestSetupUserEnsembleDir: """Builds pdb_to_conformers dict from user-uploaded files.""" def test_returns_dict_with_primary_key(self): result = ensemble._setup_user_ensemble_dir(["/tmp/primary.pdb", "/tmp/conf1.pdb", "/tmp/conf2.pdb"]) assert "primary" in result assert result["primary"] == 
["/tmp/primary.pdb", "/tmp/conf1.pdb", "/tmp/conf2.pdb"] def test_first_file_is_primary(self): result = ensemble._setup_user_ensemble_dir(["/tmp/myprotein.cif", "/tmp/alt.pdb"]) assert result["myprotein"][0] == "/tmp/myprotein.cif" def test_uses_stem_as_key(self): result = ensemble._setup_user_ensemble_dir(["/path/to/foo.pdb"]) assert "foo" in result # --------------------------------------------------------------------------- # design_sequences — validation # --------------------------------------------------------------------------- class TestDesignSequencesValidation: """Input validation before any model calls.""" def test_no_files(self): df, msg, _, _, _, _ = design.design_sequences(None, "none", "caliby", 4, None, 0.1, "", "", "", "", "", 31) assert df.empty assert "Upload at least one" in msg def test_empty_file_list(self): df, msg, _, _, _, _ = design.design_sequences([], "none", "caliby", 4, None, 0.1, "", "", "", "", "", 31) assert df.empty assert "Upload at least one" in msg def test_single_mode_multiple_files(self): df, msg, _, _, _, _ = design.design_sequences( ["a.pdb", "b.pdb"], "none", "caliby", 4, None, 0.1, "", "", "", "", "", 31 ) assert "exactly one file" in msg def test_synthetic_mode_multiple_files(self): df, msg, _, _, _, _ = design.design_sequences( ["a.pdb", "b.pdb"], "synthetic", "caliby", 4, None, 0.1, "", "", "", "", "", 31 ) assert "exactly one file" in msg def test_user_mode_too_few_files(self): df, msg, _, _, _, _ = design.design_sequences(["a.pdb"], "user", "caliby", 4, None, 0.1, "", "", "", "", "", 31) assert "at least two" in msg # --------------------------------------------------------------------------- # design_sequences — single structure mode # --------------------------------------------------------------------------- class TestDesignSequencesSingleMode: """Tests ensemble_mode='none' — verifies correct args to CalibyModel.sample().""" def _make_mock_outputs(self): return { "example_id": ["test"], "out_pdb": 
["/tmp/test_sample0.cif"], "U": [-100.0], "input_seq": ["NATIVE"], "seq": ["ACDEF"], } def test_sample_called_with_correct_args(self, tmp_path): pdb_file = tmp_path / "test.pdb" pdb_file.write_text("FAKE PDB") mock_model = MagicMock() mock_model.sample.return_value = self._make_mock_outputs() with ( patch.object(design, "get_model", return_value=mock_model), patch.object(design, "_write_zip_from_paths", return_value=None), patch("torch.cuda.is_available", return_value=False), ): design.design_sequences( [str(pdb_file)], "none", "caliby", 4, ["C"], 0.5, "A1-100", "A1-10", "A26:A", "A26:AVG", "A10,B10", 31, ) mock_model.sample.assert_called_once() args, kwargs = mock_model.sample.call_args # First positional arg is pdb_paths assert isinstance(args[0], list) assert len(args[0]) == 1 assert args[0][0].endswith("test.pdb") assert kwargs["num_seqs_per_pdb"] == 4 assert kwargs["omit_aas"] == ["C"] assert kwargs["temperature"] == 0.5 assert kwargs["num_workers"] == 0 assert isinstance(kwargs["out_dir"], str) assert isinstance(kwargs["pos_constraint_df"], pd.DataFrame) assert kwargs["pos_constraint_df"].iloc[0]["pdb_key"] == "test" def test_no_constraints_passes_none(self, tmp_path): pdb_file = tmp_path / "test.pdb" pdb_file.write_text("FAKE") mock_model = MagicMock() mock_model.sample.return_value = self._make_mock_outputs() with ( patch.object(design, "get_model", return_value=mock_model), patch.object(design, "_write_zip_from_paths", return_value=None), patch("torch.cuda.is_available", return_value=False), ): design.design_sequences( [str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31, ) assert mock_model.sample.call_args[1]["pos_constraint_df"] is None def test_empty_omit_aas_becomes_none(self, tmp_path): pdb_file = tmp_path / "test.pdb" pdb_file.write_text("FAKE") mock_model = MagicMock() mock_model.sample.return_value = self._make_mock_outputs() with ( patch.object(design, "get_model", return_value=mock_model), patch.object(design, 
"_write_zip_from_paths", return_value=None), patch("torch.cuda.is_available", return_value=False), ): design.design_sequences( [str(pdb_file)], "none", "caliby", 1, [], 0.1, "", "", "", "", "", 31, ) assert mock_model.sample.call_args[1]["omit_aas"] is None # --------------------------------------------------------------------------- # design_sequences — user ensemble mode # --------------------------------------------------------------------------- class TestDesignSequencesUserEnsembleMode: """Tests ensemble_mode='user' — verifies correct args to CalibyModel.ensemble_sample().""" def _make_mock_outputs(self): return { "example_id": ["primary"], "out_pdb": ["/tmp/primary_sample0.cif"], "U": [-100.0], "input_seq": ["NATIVE"], "seq": ["AAA"], } def test_calls_ensemble_sample(self, tmp_path): pdb1 = tmp_path / "primary.pdb" pdb2 = tmp_path / "conf1.pdb" pdb1.write_text("PDB1") pdb2.write_text("PDB2") mock_model = MagicMock() mock_model.ensemble_sample.return_value = self._make_mock_outputs() mock_pdb_to_conf = {"primary": ["/some/primary.pdb", "/some/conf1.pdb"]} with ( patch.object(design, "get_model", return_value=mock_model), patch.object(design, "_setup_user_ensemble_dir", return_value=mock_pdb_to_conf), patch.object(design, "_write_zip_from_paths", return_value=None), patch("torch.cuda.is_available", return_value=False), ): design.design_sequences([str(pdb1), str(pdb2)], "user", "caliby", 4, None, 0.1, "", "", "", "", "", 31) mock_model.ensemble_sample.assert_called_once() args, kwargs = mock_model.ensemble_sample.call_args assert args[0] is mock_pdb_to_conf assert kwargs["pos_constraint_df"] is None def test_constraints_expand_via_make_ensemble_constraints(self, tmp_path): pdb1 = tmp_path / "primary.pdb" pdb2 = tmp_path / "conf1.pdb" pdb1.write_text("PDB1") pdb2.write_text("PDB2") mock_model = MagicMock() mock_model.ensemble_sample.return_value = self._make_mock_outputs() mock_pdb_to_conf = {"primary": ["a.pdb", "b.pdb"]} expanded_df = pd.DataFrame({"pdb_key": 
["a", "b"], "fixed_pos_seq": ["A1-10", "A1-10"]}) with ( patch.object(design, "get_model", return_value=mock_model), patch.object(design, "_setup_user_ensemble_dir", return_value=mock_pdb_to_conf), patch("caliby.make_ensemble_constraints", return_value=expanded_df) as mock_expand, patch.object(design, "_write_zip_from_paths", return_value=None), patch("torch.cuda.is_available", return_value=False), ): design.design_sequences([str(pdb1), str(pdb2)], "user", "caliby", 1, None, 0.1, "A1-10", "", "", "", "", 31) mock_expand.assert_called_once() constraints_dict, pdb_to_conf_arg = mock_expand.call_args[0] assert isinstance(constraints_dict, dict) assert "primary" in constraints_dict assert constraints_dict["primary"]["fixed_pos_seq"] == "A1-10" assert pdb_to_conf_arg is mock_pdb_to_conf # --------------------------------------------------------------------------- # design_sequences — error handling # --------------------------------------------------------------------------- class TestDesignSequencesErrorHandling: """Verifies non-validation failures now raise naturally.""" def test_value_error(self, tmp_path): pdb_file = tmp_path / "test.pdb" pdb_file.write_text("PDB") with ( patch.object(design, "get_model", side_effect=ValueError("bad config")), patch("torch.cuda.is_available", return_value=False), ): with pytest.raises(ValueError, match="bad config"): design.design_sequences([str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31) def test_file_not_found(self, tmp_path): with ( patch.object(design, "get_model", side_effect=FileNotFoundError("missing.pdb")), patch("torch.cuda.is_available", return_value=False), ): with pytest.raises(FileNotFoundError, match="missing.pdb"): design.design_sequences( [str(tmp_path / "ghost.pdb")], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31 ) def test_unexpected_runtime_error(self, tmp_path): pdb_file = tmp_path / "test.pdb" pdb_file.write_text("PDB") with ( patch.object(design, "get_model", 
side_effect=RuntimeError("GPU OOM")), patch("torch.cuda.is_available", return_value=False), ): with pytest.raises(RuntimeError, match="GPU OOM"): design.design_sequences([str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31) # --------------------------------------------------------------------------- # design_sequences — zip output # --------------------------------------------------------------------------- class TestDesignSequencesZipOutput: """Tests ZIP file creation from output CIF files.""" def test_creates_zip_when_out_pdb_present(self, tmp_path): pdb_file = tmp_path / "test.pdb" pdb_file.write_text("PDB") out_cif = tmp_path / "test_sample0.cif" out_cif.write_text("CIF CONTENT") mock_model = MagicMock() mock_model.sample.return_value = { "example_id": ["test"], "out_pdb": [str(out_cif)], "U": [-100.0], "input_seq": ["NATIVE"], "seq": ["AAA"], } with ( patch.object(design, "get_model", return_value=mock_model), patch("torch.cuda.is_available", return_value=False), ): _, _, zip_path, _, _, _ = design.design_sequences( [str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31 ) assert zip_path is not None assert Path(zip_path).name == "test_designs.zip" assert Path(zip_path).exists() def test_empty_out_pdb_raises_for_invalid_caliby_output(self, tmp_path): pdb_file = tmp_path / "test.pdb" pdb_file.write_text("PDB") mock_model = MagicMock() mock_model.sample.return_value = { "example_id": ["test"], "out_pdb": [], "U": [-100.0], "input_seq": ["NATIVE"], "seq": ["AAA"], } with ( patch.object(design, "get_model", return_value=mock_model), patch("torch.cuda.is_available", return_value=False), ): with pytest.raises(ValueError, match="All arrays must be of the same length"): design.design_sequences([str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31) # --------------------------------------------------------------------------- # design_sequences — ZeroGPU quota-aware retry # 
class TestParseQuotaLeft:
    """Tests _parse_quota_left regex parsing of ZeroGPU error messages."""

    def test_extracts_remaining_seconds(self):
        # Remaining quota is the "Ns left" figure, not the requested amount.
        e = gr.Error("You have exceeded your free GPU quota (210s requested vs. 45s left). Try again in 0:02:45")
        assert design._parse_quota_left(e) == 45

    def test_extracts_zero_remaining(self):
        # 0 must parse as the int 0 (distinct from None = "not a quota error").
        e = gr.Error("(210s requested vs. 0s left). Try again in 0:03:30")
        assert design._parse_quota_left(e) == 0

    def test_returns_none_for_non_quota_error(self):
        e = gr.Error("Some other error")
        assert design._parse_quota_left(e) is None

    def test_returns_none_for_no_message_attr(self):
        # Non-gradio exceptions carry no .message attribute; parser must not blow up.
        e = RuntimeError("no message attribute")
        assert design._parse_quota_left(e) is None


class TestDesignSequencesQuotaRetry:
    """Tests ZeroGPU quota-aware retry logic in design_sequences wrapper."""

    # Canonical argument tuple for tests that never reach the GPU path.
    _DESIGN_ARGS = (None, "none", "caliby", 4, None, 0.1, "", "", "", "", "", 31)

    def test_retry_on_quota_exceeded(self, tmp_path):
        """First quota failure with time remaining triggers exactly one retry."""
        pdb_file = tmp_path / "test.pdb"
        pdb_file.write_text("PDB")
        mock_model = MagicMock()
        mock_model.sample.return_value = {
            "example_id": ["test"],
            "out_pdb": ["/tmp/t.cif"],
            "U": [-100.0],
            "input_seq": ["N"],
            "seq": ["A"],
        }
        # 45s remaining — retry should be attempted with a reduced duration.
        quota_error = gr.Error("(210s requested vs. 45s left). Try again in 0:02:45")
        call_count = 0
        # Capture the real GPU entry point BEFORE patching so the second
        # (retry) call can fall through to the genuine implementation.
        original_fn = design._design_sequences_gpu

        def side_effect(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise quota_error
            return original_fn(*args, **kwargs)

        with (
            patch.object(design, "_design_sequences_gpu", side_effect=side_effect),
            patch.object(design, "get_model", return_value=mock_model),
            patch.object(design, "_write_zip_from_paths", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            design.design_sequences([str(pdb_file)], "none", "caliby", 1, None, 0.1, "", "", "", "", "", 31)
            assert call_count == 2
            assert design._gpu_duration_override is None  # Reset after retry

    def test_no_retry_when_remaining_zero(self):
        """0s of quota left means a retry cannot succeed — error propagates."""
        quota_error = gr.Error("(210s requested vs. 0s left). Try again in 0:03:30")
        with patch.object(design, "_design_sequences_gpu", side_effect=quota_error):
            with pytest.raises(gr.Error):
                design.design_sequences(*self._DESIGN_ARGS)

    def test_no_retry_for_non_quota_gr_error(self):
        """gr.Error messages that don't match the quota pattern are re-raised as-is."""
        other_error = gr.Error("The requested GPU duration (210s) is larger than the maximum allowed")
        with patch.object(design, "_design_sequences_gpu", side_effect=other_error):
            with pytest.raises(gr.Error, match="larger than the maximum allowed"):
                design.design_sequences(*self._DESIGN_ARGS)

    def test_non_gradio_errors_propagate(self):
        """ValueError, RuntimeError etc. are not caught by the retry logic."""
        with patch.object(design, "_design_sequences_gpu", side_effect=ValueError("bad")):
            with pytest.raises(ValueError, match="bad"):
                design.design_sequences(*self._DESIGN_ARGS)