caliby / tests /test_helpers.py
Justine Yuan
Caliby HuggingFace example
3beba17
"""Unit tests for helper functions."""
import types
from pathlib import Path
import pandas as pd
import constraints
import design
import file_utils
import viewers
# ---------------------------------------------------------------------------
# _get_file_path
# ---------------------------------------------------------------------------
class TestGetFilePath:
    """_get_file_path turns each supported upload representation into a Path."""

    def test_string_input(self):
        # A plain string is wrapped as-is.
        raw = "/some/path.pdb"
        assert file_utils._get_file_path(raw) == Path(raw)

    def test_object_with_path_attr(self):
        # Objects exposing a ``path`` attribute resolve to that attribute.
        upload = types.SimpleNamespace(path="/uploads/file.pdb")
        expected = Path("/uploads/file.pdb")
        assert file_utils._get_file_path(upload) == expected

    def test_dict_with_path_key(self):
        # Dicts resolve through their "path" entry; other keys are ignored.
        payload = {"path": "/uploads/file.pdb", "name": "file.pdb"}
        resolved = file_utils._get_file_path(payload)
        assert resolved == Path("/uploads/file.pdb")

    def test_fallback_to_str(self):
        # Unrecognized inputs fall back to their string form.
        assert file_utils._get_file_path(42) == Path("42")
# ---------------------------------------------------------------------------
# _build_pos_constraint_df
# ---------------------------------------------------------------------------
class TestBuildPosConstraintDf:
    """Builds a positional constraint DataFrame for caliby."""

    # Columns accepted in caliby's positional-constraint table (intended to
    # mirror caliby's _VALID_POS_CONSTRAINT_COLUMNS). Declared once so the
    # tests below cannot drift apart.
    _ALL_COLUMNS = frozenset(
        {
            "pdb_key",
            "fixed_pos_seq",
            "fixed_pos_scn",
            "fixed_pos_override_seq",
            "pos_restrict_aatype",
            "symmetry_pos",
        }
    )

    def test_all_empty_returns_none(self):
        assert constraints._build_pos_constraint_df("1YCR", "", "", "", "", "") is None

    def test_all_whitespace_returns_none(self):
        assert constraints._build_pos_constraint_df("1YCR", " ", " ", " ", " ", " ") is None

    def test_single_field_populated(self):
        df = constraints._build_pos_constraint_df("1YCR", "A1-100", "", "", "", "")
        assert df is not None
        assert len(df) == 1
        assert df.iloc[0]["pdb_key"] == "1YCR"
        assert df.iloc[0]["fixed_pos_seq"] == "A1-100"
        # Only populated columns + pdb_key should be present. Comparing the
        # exact set (rather than testing one absent column) catches any stray
        # empty column, not just fixed_pos_scn.
        assert set(df.columns) == {"pdb_key", "fixed_pos_seq"}

    def test_all_fields_populated(self):
        df = constraints._build_pos_constraint_df("X", "A1", "B2", "A3:G", "A4:V", "A5,B5")
        assert df is not None
        assert len(df) == 1
        assert df.iloc[0]["pdb_key"] == "X"
        assert set(df.columns) == self._ALL_COLUMNS

    def test_columns_match_caliby_valid_columns(self):
        """All columns must be in caliby's _VALID_POS_CONSTRAINT_COLUMNS."""
        df = constraints._build_pos_constraint_df("X", "A1", "B2", "A3:G", "A4:V", "A5,B5")
        assert df is not None
        assert set(df.columns).issubset(self._ALL_COLUMNS)
# ---------------------------------------------------------------------------
# _df_to_csv
# ---------------------------------------------------------------------------
class TestDfToCsv:
    """Writes a DataFrame to a temp CSV file."""

    def test_none_returns_none(self):
        assert file_utils._df_to_csv(None) is None

    def test_empty_dataframe_returns_none(self):
        assert file_utils._df_to_csv(pd.DataFrame()) is None

    def test_valid_dataframe_roundtrips(self):
        df = pd.DataFrame({"pdb_key": ["1YCR"], "fixed_pos_seq": ["A1-100"]})
        path = file_utils._df_to_csv(df)
        assert path is not None
        try:
            assert Path(path).exists()
            assert path.endswith(".csv")
            loaded = pd.read_csv(path)
            pd.testing.assert_frame_equal(df, loaded)
        finally:
            # Remove the file the helper wrote so repeated runs don't litter
            # the temp directory.
            Path(path).unlink(missing_ok=True)

    def test_uses_sample_name_for_csv_basename(self):
        df = pd.DataFrame(
            {
                "Sample": ["1YCR_sample0"],
                "Sequence": ["ACDE"],
                "Energy (U)": [-1.0],
            }
        )
        path = file_utils._df_to_csv(df)
        assert path is not None
        try:
            # The basename is derived from the Sample column's input prefix.
            assert Path(path).name == "1YCR_results.csv"
        finally:
            Path(path).unlink(missing_ok=True)
class TestCsvDownloadOutput:
    """_csv_download_output builds the update dict for the CSV file component."""

    def test_hides_component_for_empty_dataframe(self):
        result = viewers._csv_download_output(pd.DataFrame())
        # Nothing to offer: the component is hidden and carries no file.
        assert result["visible"] is False
        assert result["value"] is None

    def test_shows_named_csv_for_results_dataframe(self):
        results = pd.DataFrame(
            {"Sample": ["1YCR_sample0"], "Sequence": ["ACDE"], "Energy (U)": [-1.0]}
        )
        result = viewers._csv_download_output(results)
        assert result["visible"] is True
        csv_name = Path(result["value"]).name
        assert csv_name == "1YCR_results.csv"
class TestFormatResultsDisplay:
    """Formats the on-screen results table without changing the raw dataframe."""

    def test_formats_last_four_numeric_columns(self):
        df = pd.DataFrame(
            {
                "Sample": ["1YCR_sample0"],
                "Sequence": ["ACDE"],
                "Energy (U)": [-1.2345],
                "sc_ca_rmsd": [1.0],
                "avg_ca_plddt": [88.888],
                "tmalign_score": [0.12345],
            }
        )
        original = df.copy()
        styler = viewers._format_results_display(df)
        html = styler.to_html()
        assert "-1.23" in html
        assert ">1<" in html
        assert "88.89" in html
        assert "0.12" in html
        # "-1.23" and "0.12" are prefixes of the raw values, so the checks
        # above would also pass on unformatted output. Require that the
        # full-precision raw values were not rendered.
        assert "-1.2345" not in html
        assert "88.888" not in html
        assert "0.12345" not in html
        # Display formatting must leave the caller's dataframe untouched.
        pd.testing.assert_frame_equal(df, original)
# ---------------------------------------------------------------------------
# _format_outputs
# ---------------------------------------------------------------------------
class TestFormatOutputs:
    """design._format_outputs yields a (DataFrame, FASTA text, out_pdb list) triple."""

    def test_dataframe_structure(self, sample_outputs_with_out_pdbs):
        table, _fasta, _pdbs = design._format_outputs(sample_outputs_with_out_pdbs)
        assert list(table.columns) == ["Sample", "Sequence", "Energy (U)"]
        assert len(table) == 2

    def test_sample_names_from_path_stems(self, sample_outputs_with_out_pdbs):
        table, _fasta, _pdbs = design._format_outputs(sample_outputs_with_out_pdbs)
        expected = ["1YCR_sample0", "1YCR_sample1"]
        assert table["Sample"].tolist() == expected

    def test_fasta_format(self, sample_outputs_with_out_pdbs):
        _table, fasta, _pdbs = design._format_outputs(sample_outputs_with_out_pdbs)
        records = fasta.strip().split("\n")
        # Header/sequence pairs, one per sample, in order.
        assert records[:4] == [">1YCR_sample0", "MTEEQWAQ", ">1YCR_sample1", "VSEQQWAQ"]

    def test_uses_caliby_out_pdb_key(self, sample_outputs):
        # This fixture carries the "out_pdb" key, not the "out_pdbs" variant.
        assert "out_pdbs" not in sample_outputs
        table, fasta, out_pdb_list = design._format_outputs(sample_outputs)
        assert table["Sample"].tolist() == ["1YCR_sample0", "1YCR_sample1"]
        assert ">1YCR_sample0" in fasta
        assert out_pdb_list == sample_outputs["out_pdb"]
# ---------------------------------------------------------------------------
# _get_best_sc_sample
# ---------------------------------------------------------------------------
class TestGetBestScSample:
    """_get_best_sc_sample returns the Sample whose tmalign_score is largest."""

    def test_picks_highest_tmalign_score(self):
        names = ["1YCR_sample0", "1YCR_sample1", "1YCR_sample2"]
        scores = [0.5, 0.9, 0.7]
        frame = pd.DataFrame({"Sample": names, "tmalign_score": scores})
        assert viewers._get_best_sc_sample(frame) == "1YCR_sample1"

    def test_falls_back_to_first_when_no_tmalign(self):
        frame = pd.DataFrame({"Sample": ["1YCR_sample0", "1YCR_sample1"]})
        assert viewers._get_best_sc_sample(frame) == "1YCR_sample0"

    def test_falls_back_to_first_when_all_nan(self):
        nan = float("nan")
        frame = pd.DataFrame(
            {"Sample": ["A_sample0", "A_sample1"], "tmalign_score": [nan, nan]}
        )
        assert viewers._get_best_sc_sample(frame) == "A_sample0"

    def test_returns_none_for_empty_df(self):
        empty = pd.DataFrame()
        assert viewers._get_best_sc_sample(empty) is None
# ---------------------------------------------------------------------------
# _render_af2_viewer / _render_reference_viewer
# ---------------------------------------------------------------------------
# Minimal one-atom PDB text used as viewer input: a single CA ATOM record plus
# END. PDB ATOM records are fixed-width, so each field sits in its spec
# columns (x/y/z at 31-54, occupancy 55-60, B-factor 61-66). The B-factor of
# 90.00 is presumably what the AF2 viewer reads as pLDDT; confirm in viewers.
_MINIMAL_PDB = (
    "ATOM      1  CA  ALA A   1       0.000   0.000   0.000  1.00 90.00           C\n"
    "END\n"
)
class TestRenderAf2Viewer:
    """_render_af2_viewer embeds the AF2 prediction, pLDDT-colored, via molview."""

    def test_returns_html_with_valid_data(self):
        predictions = {"test_sample0": _MINIMAL_PDB}
        rendered = viewers._render_af2_viewer("test_sample0", predictions)
        assert "iframe" in rendered

    def test_returns_empty_for_missing_sample(self):
        # The requested sample has no prediction entry.
        rendered = viewers._render_af2_viewer("missing", {"other": _MINIMAL_PDB})
        assert rendered == ""

    def test_returns_empty_for_none_sample(self):
        rendered = viewers._render_af2_viewer(None, {"test": _MINIMAL_PDB})
        assert rendered == ""

    def test_returns_empty_for_empty_data(self):
        rendered = viewers._render_af2_viewer("test", {})
        assert rendered == ""
class TestRenderReferenceViewer:
    """_render_reference_viewer shows the original input structure, chain-colored, via molview."""

    def test_maps_sample_to_input_key(self):
        # "1YCR_sample0" is looked up under its input key "1YCR".
        inputs = {"1YCR": _MINIMAL_PDB}
        assert "iframe" in viewers._render_reference_viewer("1YCR_sample0", inputs)

    def test_returns_empty_when_input_key_missing(self):
        inputs = {"OTHER": _MINIMAL_PDB}
        assert viewers._render_reference_viewer("1YCR_sample0", inputs) == ""

    def test_returns_empty_for_none_sample(self):
        inputs = {"1YCR": _MINIMAL_PDB}
        assert viewers._render_reference_viewer(None, inputs) == ""
# ---------------------------------------------------------------------------
# _update_viewers
# ---------------------------------------------------------------------------
class TestUpdateViewers:
    """Combined handler for overlay toggle."""

    def test_overlay_off_hides_reference(self):
        # Use a sample whose reference IS available ("s_sample0" maps to input
        # key "s"). A hidden reference is then attributable to overlay=False,
        # not to a failed sample-to-input lookup.
        af2_html, ref_update = viewers._update_viewers(
            "s_sample0", {"s_sample0": _MINIMAL_PDB}, {"s": _MINIMAL_PDB}, False
        )
        assert "iframe" in af2_html
        assert ref_update["visible"] is False

    def test_overlay_on_shows_reference(self):
        af2_html, ref_update = viewers._update_viewers(
            "s_sample0", {"s_sample0": _MINIMAL_PDB}, {"s": _MINIMAL_PDB}, True
        )
        assert "iframe" in af2_html
        assert ref_update["visible"] is True
        assert "iframe" in ref_update["value"]