"""Multi-baseline evaluation pipeline for CI-Triage-Env. Run with: python -m ci_triage_env.training.eval --output data_artifacts/results/ """ from __future__ import annotations import argparse from pathlib import Path from ci_triage_env.rewards.composite import compute_reward from ci_triage_env.schemas.action import ToolCall from ci_triage_env.schemas.scenario import Scenario from ci_triage_env.training.baselines.heuristic_policy import HeuristicPolicy from ci_triage_env.training.baselines.random_policy import RandomPolicy from ci_triage_env.training.rollout import _SYSTEM_PROMPT_TEMPLATE _SYSTEM_PROMPT = _SYSTEM_PROMPT_TEMPLATE.format(tools="") def load_scenario( scenario_id: str, scenarios_dir: str = "data_artifacts/scenarios", ) -> Scenario: """Load a Scenario by ID, searching common split subdirectories.""" candidates = [ Path(scenarios_dir) / f"{scenario_id}.json", Path(scenarios_dir) / "train" / f"{scenario_id}.json", Path(scenarios_dir) / "val" / f"{scenario_id}.json", Path(scenarios_dir) / "held_out" / f"{scenario_id}.json", ] for p in candidates: if p.exists(): return Scenario.model_validate_json(p.read_text()) # Fallback: generate a plausible mock from the scenario_id prefix family = scenario_id.split("-")[0] _valid_families = { "real_bug", "race_flake", "timing_flake", "infra_network", "infra_resource", "dependency_drift", "ambiguous", } from ci_triage_env.mock.scenario import make_mock_scenario return make_mock_scenario(family=family if family in _valid_families else "real_bug") class Evaluator: """Run 5-baseline evaluation matrix over a held-out scenario set.""" BASELINES = [ "random", "heuristic", "qwen3.5_4b_zero_shot", "qwen3.5_9b_zero_shot", "trained", ] def __init__( self, eval_set_path: str = "data_artifacts/scenarios/held_out/", env_url: str = "http://localhost:8000", trained_checkpoint: str = "checkpoints/grpo_full/", env_client=None, ) -> None: if env_client is not None: self.env = env_client else: from ci_triage_env.training.env_client import EnvClient self.env = EnvClient(env_url) self.eval_scenarios = list(Path(eval_set_path).glob("*.json")) self.trained_checkpoint = trained_checkpoint def run_all(self, seeds: list[int] | None = None): import pandas as pd if seeds is None: seeds = [1, 2, 3] rows = [] for baseline_name in self.BASELINES: policy = self._build(baseline_name) for scenario_path in self.eval_scenarios: scenario_id = scenario_path.stem for seed in seeds: rows.append(self._run_one(policy, scenario_id, seed)) return pd.DataFrame(rows) def _build(self, name: str): if name == "random": return RandomPolicy() if name == "heuristic": return HeuristicPolicy() if name == "qwen3.5_4b_zero_shot": from ci_triage_env.training.baselines.zero_shot import ZeroShotPolicy return ZeroShotPolicy("Qwen/Qwen3.5-4B", _SYSTEM_PROMPT) if name == "qwen3.5_9b_zero_shot": from ci_triage_env.training.baselines.zero_shot import ZeroShotPolicy return ZeroShotPolicy( "Qwen/Qwen3.5-9B", _SYSTEM_PROMPT, name="zero_shot_Qwen3.5-9B" ) if name == "trained": from ci_triage_env.training.baselines.trained import TrainedPolicy return TrainedPolicy(self.trained_checkpoint, _SYSTEM_PROMPT) raise ValueError(f"Unknown baseline: {name}") def _run_one(self, policy, scenario_id: str, seed: int) -> dict: obs = self.env.reset(scenario_id=scenario_id, seed_override=seed) episode_id = obs.episode_id history: list = [] for _ in range(12): action = policy.act(obs, history) try: obs = self.env.step(episode_id, action) except Exception: break history.append(action) if obs.is_terminal: break trace = 
self.env.get_trace(episode_id) scenario = load_scenario(scenario_id) reward = compute_reward(trace, scenario) final = trace.episode.final_action return { "baseline": policy.name, "scenario_id": scenario_id, "family": scenario.family, "difficulty": scenario.metadata.difficulty, "seed": seed, "total_reward": reward.total, "format_gate": reward.format_gate, "diagnosis_correct": ( final.diagnosis == scenario.ground_truth.label if final else False ), "predicted_diagnosis": final.diagnosis if final else None, "true_diagnosis": scenario.ground_truth.label, "action_quality": reward.components["action_quality"].raw, "tool_call_count": sum( 1 for r in trace.episode.history if isinstance(r.action, ToolCall) ), "total_cost": sum(r.cost_charged for r in trace.episode.history), "confidence": final.confidence if final else 0.0, "is_ambiguous_scenario": scenario.ground_truth.is_ambiguous, "brier_on_ambiguous": ( (final.confidence - scenario.ground_truth.confidence_target) ** 2 if scenario.ground_truth.is_ambiguous and final else None ), } def main(argv=None) -> None: parser = argparse.ArgumentParser(description="Run multi-baseline CI-triage evaluation") parser.add_argument("--output", default="data_artifacts/results/") parser.add_argument("--eval-set", default="data_artifacts/scenarios/held_out/") parser.add_argument("--checkpoint", default="checkpoints/grpo_full/") parser.add_argument("--seeds", nargs="+", type=int, default=[1, 2, 3]) args = parser.parse_args(argv) evaluator = Evaluator( eval_set_path=args.eval_set, trained_checkpoint=args.checkpoint, ) df = evaluator.run_all(seeds=args.seeds) out = Path(args.output) out.mkdir(parents=True, exist_ok=True) df.to_csv(out / "eval.csv", index=False) print( df.groupby("baseline").agg({ "diagnosis_correct": "mean", "action_quality": "mean", "tool_call_count": "mean", "total_cost": "mean", "total_reward": "mean", }) ) from ci_triage_env.training.plotting import plot_all_eval_metrics plot_all_eval_metrics(df, out / "plots/") if __name__ == "__main__": main()
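
# Programmatic usage (a sketch; the env URL and single-seed choice below are
# illustrative, not requirements — any running env endpoint works):
#
#     from ci_triage_env.training.eval import Evaluator
#
#     evaluator = Evaluator(env_url="http://localhost:8000")
#     df = evaluator.run_all(seeds=[1])
#     print(df.groupby("baseline")["total_reward"].mean())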