# ci-triage-env / tests / training / test_ablations.py
# Author: Prasham.Jain
# feat(training): Phase C6 — ablations, training curves, readme finalization
# Commit: e46f00b
"""Tests for Phase C6 — ablations, curves, and readme finalization (no GPU)."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock, patch
import pandas as pd
from ci_triage_env.training.ablations import ABLATIONS, run_ablation
from ci_triage_env.training.finalize_readme import populate_readme
# ---------------------------------------------------------------------------
# ABLATIONS dict
# ---------------------------------------------------------------------------
def test_ablations_dict_has_4_entries() -> None:
    """The ablation registry exposes exactly four configurations.

    The counterfactual ablation is deferred to v2, so it is not counted.
    """
    expected_count = 4
    assert len(ABLATIONS) == expected_count
def test_each_ablation_zeros_exactly_one_weight() -> None:
    """Every ablation's override dict sets one — and only one — weight to 0.0."""
    for name, overrides in ABLATIONS.items():
        # Collect every key whose override value is exactly zero.
        zeroed = [key for key, value in overrides.items() if value == 0.0]
        assert len(zeroed) == 1, f"Ablation '{name}' should zero exactly 1 weight, got {zeroed}"
def test_ablation_names_reference_valid_reward_keys() -> None:
    """Every override key in every ablation must exist in REWARD_WEIGHTS."""
    from ci_triage_env.rewards.weights import REWARD_WEIGHTS

    for name, overrides in ABLATIONS.items():
        unknown = [key for key in overrides if key not in REWARD_WEIGHTS]
        for key in unknown:
            # Re-raise through assert so the message names the offending pair.
            assert key in REWARD_WEIGHTS, (
                f"Ablation '{name}' references unknown reward key '{key}'"
            )
# ---------------------------------------------------------------------------
# run_ablation smoke (mock run_grpo + Evaluator)
# ---------------------------------------------------------------------------
def test_run_ablation_smoke() -> None:
    """Mock run_grpo and Evaluator; verify run_ablation returns a DataFrame."""
    # Canned Evaluator.run_all output with the full per-episode schema.
    eval_output = pd.DataFrame({
        "baseline": ["random", "heuristic", "trained"],
        "scenario_id": ["s1", "s1", "s1"],
        "family": ["real_bug"] * 3,
        "difficulty": ["easy"] * 3,
        "seed": [1, 1, 1],
        "total_reward": [0.1, 0.5, 0.7],
        "format_gate": [True] * 3,
        "diagnosis_correct": [False, True, True],
        "predicted_diagnosis": ["ambiguous", "real_bug", "real_bug"],
        "true_diagnosis": ["real_bug"] * 3,
        "action_quality": [0.0, 0.3, 0.5],
        "tool_call_count": [3, 4, 5],
        "total_cost": [0.03, 0.04, 0.05],
        "confidence": [0.5, 0.8, 0.9],
        "is_ambiguous_scenario": [False] * 3,
        "brier_on_ambiguous": [None] * 3,
    })
    evaluator_stub = MagicMock()
    evaluator_stub.run_all.return_value = eval_output

    grpo_patch = patch("ci_triage_env.training.ablations.run_grpo")
    evaluator_patch = patch(
        "ci_triage_env.training.ablations.Evaluator", return_value=evaluator_stub
    )
    with grpo_patch as grpo_mock, evaluator_patch:
        grpo_mock.return_value = "checkpoints/ablation_test/"
        result = run_ablation(
            "no_diagnosis",
            {"diagnosis": 0.0},
            total_steps=10,
        )
        # An 'ablation' column tagging every row must be added to the output.
        assert isinstance(result, pd.DataFrame)
        assert "ablation" in result.columns
        assert (result["ablation"] == "no_diagnosis").all()
        # Training ran once and received the zeroed weight override.
        grpo_mock.assert_called_once()
        assert grpo_mock.call_args.kwargs["weights_override"]["diagnosis"] == 0.0
def test_run_ablation_passes_weights_to_grpo() -> None:
    """Confirm the merged weights dict reaches run_grpo."""
    from ci_triage_env.rewards.weights import REWARD_WEIGHTS

    # Empty frame with the evaluator's column schema (order matters for equality).
    schema = [
        "baseline", "total_reward", "diagnosis_correct",
        "scenario_id", "family", "difficulty", "seed",
        "format_gate", "predicted_diagnosis", "true_diagnosis",
        "action_quality", "tool_call_count", "total_cost",
        "confidence", "is_ambiguous_scenario", "brier_on_ambiguous",
    ]
    empty_results = pd.DataFrame({column: [] for column in schema})
    evaluator_stub = MagicMock()
    evaluator_stub.run_all.return_value = empty_results

    with patch("ci_triage_env.training.ablations.run_grpo") as grpo_mock, \
            patch("ci_triage_env.training.ablations.Evaluator", return_value=evaluator_stub):
        grpo_mock.return_value = "checkpoints/ablation_no_anti_gaming/"
        run_ablation("no_anti_gaming", {"anti_gaming": 0.0}, total_steps=5)
        merged = grpo_mock.call_args.kwargs["weights_override"]
        # The ablated weight is zeroed...
        assert merged["anti_gaming"] == 0.0
        # ...while every other weight is carried over from REWARD_WEIGHTS untouched.
        for key, value in REWARD_WEIGHTS.items():
            if key != "anti_gaming":
                assert merged[key] == value
# ---------------------------------------------------------------------------
# plot_ablation_summary (matplotlib mocked)
# ---------------------------------------------------------------------------
def test_plot_ablation_summary_writes_png(tmp_path: Path) -> None:
    """With matplotlib/seaborn mocked, the summary figure is saved once as a PNG."""
    summary = pd.DataFrame({
        "ablation": ["no_diagnosis", "no_action_quality", "no_diagnosis", "no_action_quality"],
        "baseline": ["random", "random", "heuristic", "heuristic"],
        "diagnosis_correct": [0.2, 0.4, 0.5, 0.6],
        "total_reward": [0.1, 0.3, 0.4, 0.5],
        "action_quality": [0.0, 0.2, 0.3, 0.4],
    })

    # Record every path passed to fig.savefig so we can inspect the filenames.
    captured: list[str] = []
    fig_stub = MagicMock()
    fig_stub.savefig.side_effect = lambda path, **kwargs: captured.append(str(path))
    plt_stub = MagicMock()
    plt_stub.subplots.return_value = (fig_stub, [MagicMock(), MagicMock()])
    sns_stub = MagicMock()

    with patch("ci_triage_env.training.curves.plt", plt_stub), \
            patch("ci_triage_env.training.curves.sns", sns_stub):
        from ci_triage_env.training.curves import plot_ablation_summary
        plot_ablation_summary(summary, output_dir=tmp_path / "plots")

    assert len(captured) >= 1
    assert any("ablation_summary" in path for path in captured)
# ---------------------------------------------------------------------------
# populate_readme
# ---------------------------------------------------------------------------
def test_finalize_readme_replaces_table_marker(tmp_path: Path) -> None:
    """The [FILL: ...] table marker is swapped for a markdown table from eval CSV."""
    readme = tmp_path / "README.md"
    readme.write_text(
        "# Results\n\n[FILL: 5-row × 6-metric table]\n\nMore text.\n"
    )

    eval_csv = tmp_path / "eval.csv"
    metrics = {
        "baseline": ["random", "heuristic"],
        "diagnosis_correct": [0.3, 0.6],
        "action_quality": [0.1, 0.4],
        "total_cost": [0.05, 0.03],
        "tool_call_count": [4, 4],
        "total_reward": [0.2, 0.5],
    }
    pd.DataFrame(metrics).to_csv(eval_csv, index=False)

    # Ablation CSV and plots dir deliberately do not exist.
    replacements = populate_readme(
        template_path=readme,
        eval_csv=eval_csv,
        ablation_csv=tmp_path / "ablations.csv",
        plots_dir=tmp_path / "plots",
    )

    contents = readme.read_text()
    assert "[FILL: 5-row × 6-metric table]" not in contents
    assert "|" in contents  # table was inserted
    assert replacements >= 1
def test_finalize_readme_embeds_plot_images(tmp_path: Path) -> None:
    """A plot marker is replaced by a reference to the matching PNG on disk."""
    readme = tmp_path / "README.md"
    readme.write_text("# Plots\n\n[FILL: diagnosis accuracy]\n")

    # Create an (empty) plot file so the marker has something to link to.
    plots_dir = tmp_path / "plots"
    plots_dir.mkdir()
    (plots_dir / "diagnosis_accuracy.png").touch()

    replacements = populate_readme(
        template_path=readme,
        eval_csv=tmp_path / "eval.csv",
        ablation_csv=tmp_path / "ablations.csv",
        plots_dir=plots_dir,
    )

    contents = readme.read_text()
    assert "[FILL: diagnosis accuracy]" not in contents
    assert "diagnosis_accuracy.png" in contents
    assert replacements >= 1
def test_finalize_readme_missing_csv_does_not_crash(tmp_path: Path) -> None:
    """Missing inputs and a marker-free README yield zero replacements, no error."""
    original_text = "# CI Triage\n\nNo markers here.\n"
    readme = tmp_path / "README.md"
    readme.write_text(original_text)

    replacements = populate_readme(
        template_path=readme,
        eval_csv=tmp_path / "nonexistent.csv",
        ablation_csv=tmp_path / "nonexistent2.csv",
        plots_dir=tmp_path / "no_plots",
    )

    # Nothing was replaced and the README is byte-identical to the template.
    assert replacements == 0
    assert readme.read_text() == original_text