"""Tests for Phase C5 — evaluation harness (no GPU required)."""
from __future__ import annotations

from pathlib import Path
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest

from ci_triage_env.schemas.observation import BudgetState, Observation
from ci_triage_env.training.baselines.heuristic_policy import HeuristicPolicy
from ci_triage_env.training.baselines.random_policy import ALL_TOOL_ARG_DEFAULTS, RandomPolicy
from ci_triage_env.training.eval import Evaluator
from ci_triage_env.training.readme_table import generate_results_table
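
# The tests below cover only the CPU-side pieces of Phase C5: the two baseline
# policies, Evaluator._run_one, the README results table, and plotting with
# matplotlib mocked out. To run just this module, something like the following
# works; the path is a placeholder for wherever this file lives in the repo:
#
#     pytest -q path/to/this_test_module.py
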
def _mock_obs() -> Observation:
    return Observation(
        episode_id="ep-test",
        step=0,
        failure_summary=None,
        budget_remaining=BudgetState(tool_calls_remaining=12, cost_remaining=1.0),
    )

# ---------------------------------------------------------------------------
# RandomPolicy
# ---------------------------------------------------------------------------
def test_random_policy_emits_valid_action() -> None:
    p = RandomPolicy(seed=42)
    action = p.act(_mock_obs(), [])
    assert "tool_name" in action or "action_type" in action


def test_random_policy_terminates_at_max_turns() -> None:
    p = RandomPolicy(max_turns=2, seed=0)
    action = p.act(_mock_obs(), [{}, {}])  # len == max_turns
    assert action["action_type"] == "submit_diagnosis"


def test_random_policy_terminal_is_valid() -> None:
    from ci_triage_env.schemas.diagnosis import DiagnosisLabel

    p = RandomPolicy(seed=7)
    action = p._random_terminal()
    assert action["diagnosis"] in list(DiagnosisLabel)
    assert 0.0 <= action["confidence"] <= 1.0
    assert action["secondary_actions"] == []


def test_random_tool_call_uses_valid_defaults() -> None:
    p = RandomPolicy(seed=0)
    for _ in range(30):
        action = p._random_tool_call()
        tool_name = action["tool_name"]
        assert tool_name in ALL_TOOL_ARG_DEFAULTS
        assert action["args"] == ALL_TOOL_ARG_DEFAULTS[tool_name]
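

# Illustrative sketch, not a test: driving RandomPolicy turn by turn until it
# emits its terminal submit_diagnosis action. The {"output": "ok"} history
# entries only mirror the minimal shape the tests above pass to ``act``; real
# transcript entries produced by the environment may carry more fields.
def _example_random_rollout(max_turns: int = 5) -> list[dict]:  # pragma: no cover
    policy = RandomPolicy(max_turns=max_turns, seed=123)
    history: list[dict] = []
    while True:
        action = policy.act(_mock_obs(), history)
        if action.get("action_type") == "submit_diagnosis":
            # Terminal action reached (randomly or because the turn budget
            # ran out), so return the full trajectory.
            return history + [action]
        history.append({"output": "ok"})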
# ---------------------------------------------------------------------------
# HeuristicPolicy
# ---------------------------------------------------------------------------
def test_heuristic_policy_completes_investigation_then_diagnoses() -> None:
    p = HeuristicPolicy()
    history: list = []
    for _ in range(4):
        action = p.act(_mock_obs(), history)
        assert "tool_name" in action
        history.append({"output": "ok"})
    final = p.act(_mock_obs(), history)
    assert final["action_type"] == "submit_diagnosis"


def test_heuristic_policy_plan_length() -> None:
    assert len(HeuristicPolicy.INVESTIGATION_PLAN) == 4


def test_heuristic_policy_defaults_to_ambiguous_on_no_match() -> None:
    p = HeuristicPolicy()
    # 4 "ok" outputs match no rules → ambiguous
    history = [{"output": "ok"}] * 4
    final = p._classify_from_history(history)
    assert final["diagnosis"] == "ambiguous"
    assert final["confidence"] == pytest.approx(0.4)


def test_heuristic_policy_secondary_for_real_bug() -> None:
    p = HeuristicPolicy()
    sa = p._secondary_for("real_bug")
    assert len(sa) == 1
    assert sa[0]["name"] == "file_bug"


def test_heuristic_policy_secondary_for_flake() -> None:
    p = HeuristicPolicy()
    for family in ("race_flake", "timing_flake"):
        sa = p._secondary_for(family)
        assert sa[0]["name"] == "quarantine_test"
# ---------------------------------------------------------------------------
# Evaluator._run_one
# ---------------------------------------------------------------------------
def test_evaluator_run_one_returns_row() -> None:
    from ci_triage_env.training.mock_env_client import MockEnvClient

    env = MockEnvClient(seed=0)
    evaluator = Evaluator(env_client=env)
    policy = RandomPolicy(seed=0)
    # "real_bug-00" prefix causes load_scenario to fall back to make_mock_scenario("real_bug")
    row = evaluator._run_one(policy, "real_bug-00", seed=1)
    expected_keys = {
        "baseline", "scenario_id", "family", "difficulty", "seed",
        "total_reward", "format_gate", "diagnosis_correct",
        "predicted_diagnosis", "true_diagnosis", "action_quality",
        "tool_call_count", "total_cost", "confidence",
        "is_ambiguous_scenario", "brier_on_ambiguous",
    }
    assert expected_keys <= set(row.keys())
    assert isinstance(row["total_reward"], float)
    assert isinstance(row["tool_call_count"], int)
    assert row["scenario_id"] == "real_bug-00"
    assert row["seed"] == 1


def test_evaluator_run_one_heuristic() -> None:
    from ci_triage_env.training.mock_env_client import MockEnvClient

    env = MockEnvClient(seed=1)
    evaluator = Evaluator(env_client=env)
    row = evaluator._run_one(HeuristicPolicy(), "race_flake-00", seed=2)
    assert 0.0 <= row["total_cost"]
    assert isinstance(row["diagnosis_correct"], bool)
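

# Illustrative sketch, not a test: one plausible way the Phase C5 harness could
# be driven end to end with the pieces exercised above. The scenario id list
# and the direct use of the private ``_run_one`` are assumptions made here for
# illustration; the real Evaluator presumably exposes a public batch entry point.
def _example_eval_run() -> pd.DataFrame:  # pragma: no cover
    from ci_triage_env.training.mock_env_client import MockEnvClient

    evaluator = Evaluator(env_client=MockEnvClient(seed=0))
    rows = []
    for policy in (RandomPolicy(seed=0), HeuristicPolicy()):
        for seed, scenario_id in enumerate(["real_bug-00", "race_flake-00"]):
            # Each call returns one flat metrics row (see expected_keys above).
            rows.append(evaluator._run_one(policy, scenario_id, seed=seed))
    return pd.DataFrame(rows)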
# ---------------------------------------------------------------------------
# generate_results_table
# ---------------------------------------------------------------------------
def test_results_table_markdown() -> None:
    df = pd.DataFrame({
        "baseline": ["random", "heuristic", "random", "heuristic"],
        "diagnosis_correct": [0.2, 0.6, 0.3, 0.7],
        "action_quality": [0.1, 0.5, 0.2, 0.55],
        "total_cost": [0.05, 0.04, 0.06, 0.03],
        "tool_call_count": [4, 4, 5, 4],
        "total_reward": [0.1, 0.5, 0.2, 0.6],
    })
    table = generate_results_table(df)
    assert isinstance(table, str)
    assert "|" in table
    assert "baseline" in table
    assert "diagnosis_acc" in table
    assert "heuristic" in table
    assert "random" in table


def test_results_table_float_format() -> None:
    df = pd.DataFrame({
        "baseline": ["a"],
        "diagnosis_correct": [0.123456],
        "action_quality": [0.5],
        "total_cost": [0.01],
        "tool_call_count": [3],
        "total_reward": [0.25],
    })
    table = generate_results_table(df)
    assert "0.123" in table
    assert "0.124" not in table  # rendered to 3 decimals (0.123), never bumped up to 0.124
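

# Illustrative sketch, not the actual implementation: the 3-decimal markdown
# output asserted above is the kind of table pandas + tabulate produce, e.g.
# via DataFrame.to_markdown(floatfmt=".3f"). generate_results_table may well be
# built differently; the tests only show that it renames diagnosis_correct to
# diagnosis_acc and renders floats to three decimals.
def _example_markdown_render(df: pd.DataFrame) -> str:  # pragma: no cover
    summary = df.rename(columns={"diagnosis_correct": "diagnosis_acc"})
    return summary.to_markdown(index=False, floatfmt=".3f")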
# ---------------------------------------------------------------------------
# plot_all_eval_metrics (matplotlib mocked)
# ---------------------------------------------------------------------------
def test_plotting_writes_pngs(tmp_path: Path) -> None:
    """plot_all_eval_metrics records ≥ 5 savefig calls (matplotlib mocked)."""
    df = pd.DataFrame({
        "baseline": ["random", "heuristic", "random", "heuristic"],
        "diagnosis_correct": [True, False, True, True],
        "total_reward": [0.1, 0.5, 0.2, 0.6],
        "family": ["real_bug", "race_flake", "real_bug", "race_flake"],
        "total_cost": [0.05, 0.04, 0.06, 0.03],
        "tool_call_count": [4, 4, 5, 3],
        "confidence": [0.9, 0.6, 0.8, 0.7],
        "is_ambiguous_scenario": [False, False, False, False],
        "brier_on_ambiguous": [None, None, None, None],
    })
    saved: list[str] = []
    mock_fig = MagicMock()
    mock_ax = MagicMock()
    mock_plt = MagicMock()
    mock_plt.subplots.return_value = (mock_fig, mock_ax)
    mock_sns = MagicMock()

    def _record_savefig(path, **kwargs):
        saved.append(str(path))

    mock_fig.savefig.side_effect = _record_savefig
    # Patch the module-level plt/sns so the function uses our mocks without
    # needing sys.modules tricks (plotting.py uses try/except module-level imports).
    with patch("ci_triage_env.training.plotting.plt", mock_plt), \
         patch("ci_triage_env.training.plotting.sns", mock_sns):
        from ci_triage_env.training.plotting import plot_all_eval_metrics

        plot_all_eval_metrics(df, tmp_path / "plots")
    assert len(saved) >= 5, f"Expected ≥ 5 savefig calls, got {len(saved)}: {saved}"
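

# Illustrative sketch, not a test: with matplotlib/seaborn actually installed,
# the same call writes the PNGs to disk. The output directory name here is an
# arbitrary choice for illustration, not something the module mandates.
def _example_write_plots(df: pd.DataFrame) -> None:  # pragma: no cover
    from ci_triage_env.training.plotting import plot_all_eval_metrics

    plot_all_eval_metrics(df, Path("eval_plots"))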