Spaces:
Sleeping
Sleeping
Prasham.Jain Claude Sonnet 4.6 committed on
Commit ·
e46f00b
1
Parent(s): 93e68bc
feat(training): Phase C6 — ablations, training curves, readme finalization
Browse filesAdds 4-entry ablation matrix (no_diagnosis, no_action_quality, no_investigation,
no_anti_gaming), each running 1000-step GRPO with one reward component zeroed.
Extends run_grpo/TrainingRollout with weights_override to thread ablation weights
into CompositeReward. Adds W&B curve plotting, ablation summary charts,
populate_readme for [FILL] marker replacement, and Colab eval notebook.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
notebooks/eval.ipynb
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "cell-0",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# CI-Triage-Env — Evaluation & Ablation Notebook\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"Colab-runnable notebook for judges to reproduce all results.\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"Steps:\n",
|
| 13 |
+
"1. Install dependencies\n",
|
| 14 |
+
"2. Load trained checkpoint from HF Hub\n",
|
| 15 |
+
"3. Run full 5-baseline evaluation\n",
|
| 16 |
+
"4. Generate all metric plots\n",
|
| 17 |
+
"5. Run reward-layer ablations (optional, GPU, ~5h)\n",
|
| 18 |
+
"6. Populate README with results\n",
|
| 19 |
+
"\n",
|
| 20 |
+
"**Prerequisites**: `HF_TOKEN`, `WANDB_API_KEY` as Colab secrets."
|
| 21 |
+
]
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "code",
|
| 25 |
+
"execution_count": null,
|
| 26 |
+
"id": "cell-1",
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"outputs": [],
|
| 29 |
+
"source": [
|
| 30 |
+
"# Cell 1: Install dependencies\n",
|
| 31 |
+
"!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121\n",
|
| 32 |
+
"!pip install -q unsloth trl transformers accelerate peft\n",
|
| 33 |
+
"!pip install -q wandb datasets huggingface_hub openai httpx fastapi uvicorn pydantic jsonschema\n",
|
| 34 |
+
"!pip install -q matplotlib seaborn pandas tabulate\n",
|
| 35 |
+
"!pip install -q -e . # install ci_triage_env package"
|
| 36 |
+
]
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"cell_type": "code",
|
| 40 |
+
"execution_count": null,
|
| 41 |
+
"id": "cell-2",
|
| 42 |
+
"metadata": {},
|
| 43 |
+
"outputs": [],
|
| 44 |
+
"source": [
|
| 45 |
+
"# Cell 2: Environment setup\n",
|
| 46 |
+
"import os\n",
|
| 47 |
+
"from google.colab import userdata\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')\n",
|
| 50 |
+
"os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')\n",
|
| 51 |
+
"os.environ['WANDB_PROJECT'] = 'ci-triage-env'\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"# Config — replace with your values\n",
|
| 54 |
+
"HF_DATASET_REPO = 'YOUR_ORG/ci-triage-scenarios'\n",
|
| 55 |
+
"HF_MODEL_REPO = 'YOUR_ORG/ci-triage-trained-qwen3.5-4b'\n",
|
| 56 |
+
"WANDB_RUN_ID = 'YOUR_WANDB_RUN_ID' # e.g. 'entity/ci-triage-env/abc123'"
|
| 57 |
+
]
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"cell_type": "code",
|
| 61 |
+
"execution_count": null,
|
| 62 |
+
"id": "cell-3",
|
| 63 |
+
"metadata": {},
|
| 64 |
+
"outputs": [],
|
| 65 |
+
"source": [
|
| 66 |
+
"# Cell 3: Download scenario corpus and trained checkpoint\n",
|
| 67 |
+
"from huggingface_hub import snapshot_download\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"scen_dir = snapshot_download(\n",
|
| 70 |
+
" HF_DATASET_REPO, repo_type='dataset',\n",
|
| 71 |
+
" local_dir='data_artifacts/scenarios'\n",
|
| 72 |
+
")\n",
|
| 73 |
+
"ckpt_dir = snapshot_download(\n",
|
| 74 |
+
" HF_MODEL_REPO, repo_type='model',\n",
|
| 75 |
+
" local_dir='checkpoints/grpo_full'\n",
|
| 76 |
+
")\n",
|
| 77 |
+
"print(f'Scenarios: {scen_dir}')\n",
|
| 78 |
+
"print(f'Checkpoint: {ckpt_dir}')"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"cell_type": "code",
|
| 83 |
+
"execution_count": null,
|
| 84 |
+
"id": "cell-4",
|
| 85 |
+
"metadata": {},
|
| 86 |
+
"outputs": [],
|
| 87 |
+
"source": [
|
| 88 |
+
"# Cell 4: Start env server in background\n",
|
| 89 |
+
"import subprocess, time\n",
|
| 90 |
+
"server_proc = subprocess.Popen(\n",
|
| 91 |
+
" ['python', '-m', 'ci_triage_env.env.server'],\n",
|
| 92 |
+
" stdout=subprocess.PIPE, stderr=subprocess.PIPE\n",
|
| 93 |
+
")\n",
|
| 94 |
+
"time.sleep(4)\n",
|
| 95 |
+
"print('Env server started, PID:', server_proc.pid)"
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"cell_type": "code",
|
| 100 |
+
"execution_count": null,
|
| 101 |
+
"id": "cell-5",
|
| 102 |
+
"metadata": {},
|
| 103 |
+
"outputs": [],
|
| 104 |
+
"source": [
|
| 105 |
+
"# Cell 5: Run full 5-baseline evaluation\n",
|
| 106 |
+
"from pathlib import Path\n",
|
| 107 |
+
"import pandas as pd\n",
|
| 108 |
+
"from ci_triage_env.training.eval import Evaluator\n",
|
| 109 |
+
"from ci_triage_env.training.plotting import plot_all_eval_metrics\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"evaluator = Evaluator(\n",
|
| 112 |
+
" eval_set_path='data_artifacts/scenarios/held_out/',\n",
|
| 113 |
+
" trained_checkpoint='checkpoints/grpo_full/',\n",
|
| 114 |
+
")\n",
|
| 115 |
+
"df_eval = evaluator.run_all(seeds=[1, 2, 3])\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"out = Path('data_artifacts/results/')\n",
|
| 118 |
+
"out.mkdir(parents=True, exist_ok=True)\n",
|
| 119 |
+
"df_eval.to_csv(out / 'eval.csv', index=False)\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"print(df_eval.groupby('baseline').agg({\n",
|
| 122 |
+
" 'diagnosis_correct': 'mean',\n",
|
| 123 |
+
" 'total_reward': 'mean',\n",
|
| 124 |
+
" 'tool_call_count': 'mean',\n",
|
| 125 |
+
"}))\n",
|
| 126 |
+
"\n",
|
| 127 |
+
"plot_all_eval_metrics(df_eval, out / 'plots/')\n",
|
| 128 |
+
"print('Plots saved to data_artifacts/results/plots/')"
|
| 129 |
+
]
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"cell_type": "code",
|
| 133 |
+
"execution_count": null,
|
| 134 |
+
"id": "cell-6",
|
| 135 |
+
"metadata": {},
|
| 136 |
+
"outputs": [],
|
| 137 |
+
"source": [
|
| 138 |
+
"# Cell 6: Pull training curves from W&B\n",
|
| 139 |
+
"from ci_triage_env.training.curves import plot_training_curves_from_wandb\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"plot_training_curves_from_wandb(\n",
|
| 142 |
+
" run_id=WANDB_RUN_ID,\n",
|
| 143 |
+
" output_dir=Path('data_artifacts/results/plots/'),\n",
|
| 144 |
+
")\n",
|
| 145 |
+
"print('Training curves saved.')"
|
| 146 |
+
]
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"cell_type": "code",
|
| 150 |
+
"execution_count": null,
|
| 151 |
+
"id": "cell-7",
|
| 152 |
+
"metadata": {},
|
| 153 |
+
"outputs": [],
|
| 154 |
+
"source": [
|
| 155 |
+
"# Cell 7: Run reward-layer ablations (~5h on A100; set RUN_ABLATIONS=True to enable)\n",
|
| 156 |
+
"RUN_ABLATIONS = False\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"if RUN_ABLATIONS:\n",
|
| 159 |
+
" from ci_triage_env.training.ablations import ABLATIONS, run_ablation\n",
|
| 160 |
+
" from ci_triage_env.training.curves import plot_ablation_summary\n",
|
| 161 |
+
"\n",
|
| 162 |
+
" abl_results = []\n",
|
| 163 |
+
" for name, overrides in ABLATIONS.items():\n",
|
| 164 |
+
" print(f'=== Ablation: {name} ===')\n",
|
| 165 |
+
" df_abl = run_ablation(name, overrides, total_steps=1000)\n",
|
| 166 |
+
" abl_results.append(df_abl)\n",
|
| 167 |
+
" print(df_abl.groupby('baseline')['diagnosis_correct'].mean())\n",
|
| 168 |
+
"\n",
|
| 169 |
+
" df_full_abl = pd.concat(abl_results, ignore_index=True)\n",
|
| 170 |
+
" df_full_abl.to_csv(out / 'ablations.csv', index=False)\n",
|
| 171 |
+
" plot_ablation_summary(df_full_abl, output_dir=out / 'plots/')\n",
|
| 172 |
+
" print('Ablations saved.')\n",
|
| 173 |
+
"else:\n",
|
| 174 |
+
" print('Ablations skipped (set RUN_ABLATIONS=True to run).')"
|
| 175 |
+
]
|
| 176 |
+
},
|
| 177 |
+
{
|
| 178 |
+
"cell_type": "code",
|
| 179 |
+
"execution_count": null,
|
| 180 |
+
"id": "cell-8",
|
| 181 |
+
"metadata": {},
|
| 182 |
+
"outputs": [],
|
| 183 |
+
"source": [
|
| 184 |
+
"# Cell 8: Populate README with results\n",
|
| 185 |
+
"from ci_triage_env.training.finalize_readme import populate_readme\n",
|
| 186 |
+
"\n",
|
| 187 |
+
"n = populate_readme(\n",
|
| 188 |
+
" eval_csv=out / 'eval.csv',\n",
|
| 189 |
+
" ablation_csv=out / 'ablations.csv',\n",
|
| 190 |
+
" plots_dir=out / 'plots/',\n",
|
| 191 |
+
")\n",
|
| 192 |
+
"print(f'Replaced {n} markers in README.md')\n",
|
| 193 |
+
"\n",
|
| 194 |
+
"# Check for any remaining unfilled markers\n",
|
| 195 |
+
"import subprocess\n",
|
| 196 |
+
"result = subprocess.run(['grep', '-c', r'\\[FILL', 'README.md'], capture_output=True, text=True)\n",
|
| 197 |
+
"remaining = int(result.stdout.strip() or 0)\n",
|
| 198 |
+
"if remaining:\n",
|
| 199 |
+
" print(f'WARNING: {remaining} unfilled [FILL] marker(s) remain in README.md')\n",
|
| 200 |
+
"else:\n",
|
| 201 |
+
" print('README.md is clean — no unfilled markers.')"
|
| 202 |
+
]
|
| 203 |
+
}
|
| 204 |
+
],
|
| 205 |
+
"metadata": {
|
| 206 |
+
"kernelspec": {
|
| 207 |
+
"display_name": "Python 3",
|
| 208 |
+
"language": "python",
|
| 209 |
+
"name": "python3"
|
| 210 |
+
},
|
| 211 |
+
"language_info": {
|
| 212 |
+
"name": "python",
|
| 213 |
+
"version": "3.11.0"
|
| 214 |
+
}
|
| 215 |
+
},
|
| 216 |
+
"nbformat": 4,
|
| 217 |
+
"nbformat_minor": 5
|
| 218 |
+
}
|
src/ci_triage_env/training/ablations.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reward-layer ablation runs for CI-Triage-Env.
|
| 2 |
+
|
| 3 |
+
Each ablation zeroes one reward component weight, runs 1000-step GRPO from the
|
| 4 |
+
SFT checkpoint, then evaluates against the held-out set. Results are saved to
|
| 5 |
+
data_artifacts/results/ablations.csv.
|
| 6 |
+
|
| 7 |
+
All GPU-heavy imports are lazy; this module is importable without GPU.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from ci_triage_env.rewards.weights import REWARD_WEIGHTS
|
| 13 |
+
from ci_triage_env.training.eval import Evaluator
|
| 14 |
+
from ci_triage_env.training.grpo import run_grpo
|
| 15 |
+
|
| 16 |
+
# Ablation matrix: each entry zeroes exactly one reward-component weight.
# Keys are run names (used as checkpoint-directory suffixes and as the
# "ablation" column in results CSVs); values are merged over REWARD_WEIGHTS
# by run_ablation before training.
ABLATIONS: dict[str, dict[str, float]] = {
    "no_diagnosis": {"diagnosis": 0.0},
    "no_action_quality": {"action_quality": 0.0},
    "no_investigation": {"investigation": 0.0},
    "no_anti_gaming": {"anti_gaming": 0.0},
    # counterfactual ablation deferred to v2 — its weight is already 0.0 in v1.
}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def run_ablation(
    name: str,
    weight_overrides: dict[str, float],
    total_steps: int = 1000,
    base_sft_checkpoint: str = "checkpoints/sft/",
    env_client=None,
):
    """Train one ablated GRPO run and evaluate it against the held-out set.

    Args:
        name: Ablation identifier; suffixes the checkpoint directory and tags
            every row of the returned results.
        weight_overrides: Component weights merged over REWARD_WEIGHTS
            (the zeroed components for this ablation).
        total_steps: GRPO step budget for the ablation (default 1000, not 3000).
        base_sft_checkpoint: SFT warm-start checkpoint to train from.
        env_client: Optional env client override (MockEnvClient in tests).

    Returns:
        Eval DataFrame with an extra "ablation" column set to *name*.
    """
    merged_weights = dict(REWARD_WEIGHTS)
    merged_weights.update(weight_overrides)
    ckpt_dir = f"checkpoints/ablation_{name}/"

    run_grpo(
        sft_checkpoint_dir=base_sft_checkpoint,
        output_dir=ckpt_dir,
        total_steps=total_steps,
        weights_override=merged_weights,
        env_client=env_client,
    )

    ev = Evaluator(trained_checkpoint=ckpt_dir, env_client=env_client)
    # Ablations skip the expensive frontier baselines; a single seed keeps the
    # 4-run matrix affordable.
    ev.BASELINES = ["random", "heuristic", "trained"]
    result = ev.run_all(seeds=[1])
    result["ablation"] = name
    return result
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def main(argv=None) -> None:
    """CLI entry point: run the whole ablation matrix and save results/plots."""
    import argparse
    from pathlib import Path

    import pandas as pd

    from ci_triage_env.training.curves import plot_ablation_summary

    ap = argparse.ArgumentParser(description="Run reward-layer ablation matrix")
    ap.add_argument("--steps", type=int, default=1000)
    ap.add_argument("--sft-checkpoint", default="checkpoints/sft/")
    ap.add_argument("--output", default="data_artifacts/results/")
    opts = ap.parse_args(argv)

    frames = []
    for abl_name, abl_overrides in ABLATIONS.items():
        print(f"=== Ablation: {abl_name} ===")
        frame = run_ablation(
            abl_name,
            abl_overrides,
            total_steps=opts.steps,
            base_sft_checkpoint=opts.sft_checkpoint,
        )
        frames.append(frame)
        print(frame.groupby("baseline")["diagnosis_correct"].mean())

    combined = pd.concat(frames, ignore_index=True)
    results_dir = Path(opts.output)
    results_dir.mkdir(parents=True, exist_ok=True)
    combined.to_csv(results_dir / "ablations.csv", index=False)
    print(f"\nAblations saved to {results_dir / 'ablations.csv'}")

    plot_ablation_summary(combined, output_dir=results_dir / "plots/")


if __name__ == "__main__":
    main()
|
src/ci_triage_env/training/curves.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training curve plots and ablation summary charts for CI-Triage-Env.
|
| 2 |
+
|
| 3 |
+
Module-level try/except lets the module import without matplotlib; tests patch
|
| 4 |
+
`ci_triage_env.training.curves.plt` and `.sns` directly.
|
| 5 |
+
W&B imports are lazy (inside functions) since wandb is Colab-only.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
import matplotlib.pyplot as plt # type: ignore[import]
|
| 14 |
+
import seaborn as sns # type: ignore[import]
|
| 15 |
+
except ImportError:
|
| 16 |
+
plt = None # type: ignore[assignment]
|
| 17 |
+
sns = None # type: ignore[assignment]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def plot_training_curves_from_wandb(run_id: str, output_dir: Path) -> None:
    """Fetch per-step metrics from a W&B run and write training-curve PNGs.

    Saves up to four figures into *output_dir*: the reward curve, the
    per-component reward breakdown, the KL-divergence curve (when a "kl"
    column exists) and the format-gate pass rate (when that column exists).

    Args:
        run_id: W&B run path, e.g. "<entity>/ci-triage-env/<run_id>".
        output_dir: Directory to write PNG files.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    if plt is None:
        raise ImportError("matplotlib required — install with: pip install matplotlib")

    import wandb  # type: ignore[import]

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    history = wandb.Api().run(run_id).history(samples=10000)
    steps = history["step"]

    # 1. Reward curve — raw trace plus a 20-step rolling mean.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(steps, history["mean_reward"], alpha=0.4, label="raw")
    rolled = history["mean_reward"].rolling(window=20, min_periods=1).mean()
    ax.plot(steps, rolled, label="smoothed (window=20)", linewidth=2)
    ax.set_xlabel("Training step")
    ax.set_ylabel("Mean episode reward")
    ax.set_title("GRPO training: reward over time")
    ax.legend()
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig(output_dir / "training_reward_curve.png", dpi=120)
    plt.close(fig)

    # 2. Per-component reward breakdown (only columns the run actually logged).
    component_keys = [
        "reward/diagnosis", "reward/action_quality", "reward/cost_efficiency",
        "reward/investigation", "reward/anti_gaming",
    ]
    fig, ax = plt.subplots(figsize=(10, 5))
    for key in component_keys:
        if key in history.columns:
            rolled = history[key].rolling(window=20, min_periods=1).mean()
            ax.plot(steps, rolled, label=key.replace("reward/", ""))
    ax.set_xlabel("Training step")
    ax.set_ylabel("Mean component reward (smoothed)")
    ax.set_title("Per-component reward breakdown over training")
    ax.legend()
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig(output_dir / "per_component_curve.png", dpi=120)
    plt.close(fig)

    # 3. KL divergence to the SFT reference policy, if logged.
    if "kl" in history.columns:
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(steps, history["kl"])
        ax.set_xlabel("Training step")
        ax.set_ylabel("KL(policy || ref)")
        ax.set_title("KL divergence to SFT reference")
        ax.grid(alpha=0.3)
        fig.tight_layout()
        fig.savefig(output_dir / "kl_divergence.png", dpi=120)
        plt.close(fig)

    # 4. Format-gate pass rate, if logged (a rate, so fix the y-axis to [0, 1]).
    if "format_gate_pass_rate" in history.columns:
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(steps, history["format_gate_pass_rate"])
        ax.set_ylim(0, 1)
        ax.set_xlabel("Training step")
        ax.set_ylabel("Format gate pass rate")
        ax.set_title("Fraction of episodes passing format gate")
        ax.grid(alpha=0.3)
        fig.tight_layout()
        fig.savefig(output_dir / "format_gate.png", dpi=120)
        plt.close(fig)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def plot_ablation_summary(df, output_dir=None) -> None:
    """Bar charts comparing diagnosis accuracy and total reward across ablations.

    Args:
        df: DataFrame with at least [ablation, diagnosis_correct, total_reward]
            columns (one row per evaluated episode).
        output_dir: Directory to write ablation_summary.png. Defaults to
            data_artifacts/results/plots/.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    if plt is None:
        raise ImportError("matplotlib required — install with: pip install matplotlib")

    if output_dir is None:
        output_dir = Path("data_artifacts/results/plots/")
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Aggregate only the two metrics actually plotted. (A previous
    # action_quality aggregation was computed but never used, and forced
    # callers to supply a column this function doesn't need.)
    summary = df.groupby("ablation").agg(
        diagnosis_acc=("diagnosis_correct", "mean"),
        total_reward=("total_reward", "mean"),
    ).reset_index()

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    labels = list(summary["ablation"])

    axes[0].bar(labels, list(summary["diagnosis_acc"]))
    axes[0].set_title("Diagnosis Accuracy by Ablation")
    axes[0].set_ylabel("Diagnosis Accuracy")
    axes[0].set_xlabel("Ablation")
    axes[0].tick_params(axis="x", rotation=30)

    axes[1].bar(labels, list(summary["total_reward"]))
    axes[1].set_title("Mean Total Reward by Ablation")
    axes[1].set_ylabel("Mean Total Reward")
    axes[1].set_xlabel("Ablation")
    axes[1].tick_params(axis="x", rotation=30)

    fig.tight_layout()
    fig.savefig(output_dir / "ablation_summary.png", dpi=120)
    plt.close(fig)
|
src/ci_triage_env/training/finalize_readme.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Populate README's [FILL POST-TRAIN] markers after training is complete.
|
| 2 |
+
|
| 3 |
+
Run after eval.csv, ablations.csv, and plots/ are all present:
|
| 4 |
+
python -c "from ci_triage_env.training.finalize_readme import populate_readme; populate_readme()"
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def populate_readme(
    template_path: Path = Path("README.md"),
    eval_csv: Path = Path("data_artifacts/results/eval.csv"),
    ablation_csv: Path = Path("data_artifacts/results/ablations.csv"),
    plots_dir: Path = Path("data_artifacts/results/plots/"),
) -> int:
    """Fill [FILL …] markers in README.md in-place.

    Args:
        template_path: Path to README.md.
        eval_csv: Path to the master eval CSV from Phase C5.
        ablation_csv: Path to the ablation results CSV from Phase C6.
            NOTE(review): accepted for interface stability but not yet
            consumed here — the ablation-table marker is still TODO.
        plots_dir: Directory containing PNG plots.

    Returns:
        Number of markers replaced.
    """
    import re

    text = template_path.read_text()
    replaced = 0

    # 1. Results table. Heavy deps (pandas + the table generator) are imported
    # lazily so the function still works when no eval results exist yet.
    if eval_csv.exists():
        import pandas as pd

        from ci_triage_env.training.readme_table import generate_results_table

        df_eval = pd.read_csv(eval_csv)
        table_md = generate_results_table(df_eval)
        marker = "[FILL: 5-row × 6-metric table]"
        if marker in text:
            text = text.replace(marker, table_md)
            replaced += 1

    # 2. Embed plot images — replace "[FILL: <stem with spaces>]" markers with
    # markdown image tags. (Previously the replacement was an empty string,
    # which silently deleted the markers instead of embedding the images.)
    if plots_dir.exists():
        for png in sorted(plots_dir.glob("*.png")):
            stem_words = png.stem.replace("_", " ")
            marker = f"[FILL: {stem_words}]"
            rel = png.relative_to(template_path.parent)
            embed = f"![{stem_words}]({rel})"
            if marker in text:
                text = text.replace(marker, embed)
                replaced += 1

    # 3. Warn about any [FILL …] markers still unfilled so the README is not
    # silently published with placeholders.
    generic = re.compile(r"\[FILL[^\]]*\]")
    remaining = generic.findall(text)
    if remaining:
        print(f"WARNING: {len(remaining)} unfilled marker(s) remain: {remaining[:5]}")

    template_path.write_text(text)
    return replaced
|
src/ci_triage_env/training/grpo.py
CHANGED
|
@@ -34,6 +34,7 @@ def run_grpo(
|
|
| 34 |
env_url: str = "http://localhost:8000",
|
| 35 |
scenarios_train_path: str = "data_artifacts/scenarios/train",
|
| 36 |
hyperparams: dict | None = None,
|
|
|
|
| 37 |
) -> str:
|
| 38 |
"""Launch GRPO fine-tuning from an SFT checkpoint.
|
| 39 |
|
|
@@ -46,6 +47,8 @@ def run_grpo(
|
|
| 46 |
env_url: URL of the running env server (used only when env_client is None).
|
| 47 |
scenarios_train_path: Directory containing train split scenario JSON files.
|
| 48 |
hyperparams: Override specific GRPO hyperparameters. Merged over GRPO_HYPERPARAMS.
|
|
|
|
|
|
|
| 49 |
|
| 50 |
Returns:
|
| 51 |
output_dir path as a string.
|
|
@@ -64,6 +67,7 @@ def run_grpo(
|
|
| 64 |
rollout = TrainingRollout(
|
| 65 |
env_client=env_client,
|
| 66 |
scenarios_train=scenario_ids,
|
|
|
|
| 67 |
)
|
| 68 |
|
| 69 |
model, tokenizer = load_model_for_sft(model_name=sft_checkpoint_dir)
|
|
|
|
| 34 |
env_url: str = "http://localhost:8000",
|
| 35 |
scenarios_train_path: str = "data_artifacts/scenarios/train",
|
| 36 |
hyperparams: dict | None = None,
|
| 37 |
+
weights_override: dict | None = None,
|
| 38 |
) -> str:
|
| 39 |
"""Launch GRPO fine-tuning from an SFT checkpoint.
|
| 40 |
|
|
|
|
| 47 |
env_url: URL of the running env server (used only when env_client is None).
|
| 48 |
scenarios_train_path: Directory containing train split scenario JSON files.
|
| 49 |
hyperparams: Override specific GRPO hyperparameters. Merged over GRPO_HYPERPARAMS.
|
| 50 |
+
weights_override: Override reward component weights passed to CompositeReward.
|
| 51 |
+
Used by ablation runs to zero out individual reward terms.
|
| 52 |
|
| 53 |
Returns:
|
| 54 |
output_dir path as a string.
|
|
|
|
| 67 |
rollout = TrainingRollout(
|
| 68 |
env_client=env_client,
|
| 69 |
scenarios_train=scenario_ids,
|
| 70 |
+
weights=weights_override,
|
| 71 |
)
|
| 72 |
|
| 73 |
model, tokenizer = load_model_for_sft(model_name=sft_checkpoint_dir)
|
src/ci_triage_env/training/rollout.py
CHANGED
|
@@ -47,10 +47,12 @@ class TrainingRollout:
|
|
| 47 |
env_client,
|
| 48 |
scenarios_train: list[str] | None = None,
|
| 49 |
max_turns: int = 12,
|
|
|
|
| 50 |
) -> None:
|
| 51 |
self.env = env_client
|
| 52 |
self.scenarios_train: list[str] = scenarios_train or []
|
| 53 |
self.max_turns = max_turns
|
|
|
|
| 54 |
self._quarantine_window: list[str] = []
|
| 55 |
self._tools_listing: str | None = None
|
| 56 |
|
|
@@ -114,9 +116,10 @@ class TrainingRollout:
|
|
| 114 |
|
| 115 |
trace = self.env.get_trace(episode_id)
|
| 116 |
scenario = self._load_scenario(trace, episode_id)
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
| 120 |
|
| 121 |
# Update rolling quarantine window (last 50 secondary actions)
|
| 122 |
if trace.episode.final_action:
|
|
|
|
| 47 |
env_client,
|
| 48 |
scenarios_train: list[str] | None = None,
|
| 49 |
max_turns: int = 12,
|
| 50 |
+
weights: dict | None = None,
|
| 51 |
) -> None:
|
| 52 |
self.env = env_client
|
| 53 |
self.scenarios_train: list[str] = scenarios_train or []
|
| 54 |
self.max_turns = max_turns
|
| 55 |
+
self.weights = weights # None → CompositeReward uses frozen defaults
|
| 56 |
self._quarantine_window: list[str] = []
|
| 57 |
self._tools_listing: str | None = None
|
| 58 |
|
|
|
|
| 116 |
|
| 117 |
trace = self.env.get_trace(episode_id)
|
| 118 |
scenario = self._load_scenario(trace, episode_id)
|
| 119 |
+
reward_kwargs: dict = {"quarantine_window": self._quarantine_window}
|
| 120 |
+
if self.weights is not None:
|
| 121 |
+
reward_kwargs["weights"] = self.weights
|
| 122 |
+
reward = compute_reward(trace, scenario, **reward_kwargs)
|
| 123 |
|
| 124 |
# Update rolling quarantine window (last 50 secondary actions)
|
| 125 |
if trace.episode.final_action:
|
tests/training/test_ablations.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for Phase C6 — ablations, curves, and readme finalization (no GPU)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from unittest.mock import MagicMock, patch
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
|
| 10 |
+
from ci_triage_env.training.ablations import ABLATIONS, run_ablation
|
| 11 |
+
from ci_triage_env.training.finalize_readme import populate_readme
|
| 12 |
+
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
# ABLATIONS dict
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def test_ablations_dict_has_4_entries() -> None:
    """v1 ships exactly four ablations; the counterfactual one is deferred to v2."""
    assert len(ABLATIONS) == 4
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def test_each_ablation_zeros_exactly_one_weight() -> None:
    """Every ablation entry must zero one — and only one — reward weight."""
    for abl_name, overrides in ABLATIONS.items():
        zero_keys = [key for key, weight in overrides.items() if weight == 0.0]
        assert len(zero_keys) == 1, f"Ablation '{abl_name}' should zero exactly 1 weight, got {zero_keys}"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def test_ablation_names_reference_valid_reward_keys() -> None:
    """Every override key must exist in the canonical REWARD_WEIGHTS mapping."""
    from ci_triage_env.rewards.weights import REWARD_WEIGHTS

    for abl_name, overrides in ABLATIONS.items():
        for reward_key in overrides:
            assert reward_key in REWARD_WEIGHTS, (
                f"Ablation '{abl_name}' references unknown reward key '{reward_key}'"
            )
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
# run_ablation smoke (mock run_grpo + Evaluator)
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def test_run_ablation_smoke() -> None:
    """Mock run_grpo and Evaluator; verify run_ablation returns a DataFrame."""
    # Minimal three-row eval result: one scenario scored by all three baselines.
    eval_rows = {
        "baseline": ["random", "heuristic", "trained"],
        "scenario_id": ["s1", "s1", "s1"],
        "family": ["real_bug"] * 3,
        "difficulty": ["easy"] * 3,
        "seed": [1, 1, 1],
        "total_reward": [0.1, 0.5, 0.7],
        "format_gate": [True] * 3,
        "diagnosis_correct": [False, True, True],
        "predicted_diagnosis": ["ambiguous", "real_bug", "real_bug"],
        "true_diagnosis": ["real_bug"] * 3,
        "action_quality": [0.0, 0.3, 0.5],
        "tool_call_count": [3, 4, 5],
        "total_cost": [0.03, 0.04, 0.05],
        "confidence": [0.5, 0.8, 0.9],
        "is_ambiguous_scenario": [False] * 3,
        "brier_on_ambiguous": [None] * 3,
    }
    stub_evaluator = MagicMock()
    stub_evaluator.run_all.return_value = pd.DataFrame(eval_rows)

    with patch("ci_triage_env.training.ablations.run_grpo") as mock_grpo, \
            patch("ci_triage_env.training.ablations.Evaluator", return_value=stub_evaluator):
        mock_grpo.return_value = "checkpoints/ablation_test/"
        result = run_ablation(
            "no_diagnosis",
            {"diagnosis": 0.0},
            total_steps=10,
        )

    # run_ablation tags every row with the ablation name.
    assert isinstance(result, pd.DataFrame)
    assert "ablation" in result.columns
    assert (result["ablation"] == "no_diagnosis").all()
    # Training was invoked once, with the zeroed weight threaded through.
    mock_grpo.assert_called_once()
    assert mock_grpo.call_args.kwargs["weights_override"]["diagnosis"] == 0.0
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def test_run_ablation_passes_weights_to_grpo() -> None:
    """Confirm the merged weights dict reaches run_grpo."""
    from ci_triage_env.rewards.weights import REWARD_WEIGHTS

    # Evaluator stub returns an empty frame with the full eval schema.
    empty_columns = [
        "baseline", "total_reward", "diagnosis_correct", "scenario_id", "family",
        "difficulty", "seed", "format_gate", "predicted_diagnosis", "true_diagnosis",
        "action_quality", "tool_call_count", "total_cost", "confidence",
        "is_ambiguous_scenario", "brier_on_ambiguous",
    ]
    stub_evaluator = MagicMock()
    stub_evaluator.run_all.return_value = pd.DataFrame({col: [] for col in empty_columns})

    with patch("ci_triage_env.training.ablations.run_grpo") as mock_grpo, \
            patch("ci_triage_env.training.ablations.Evaluator", return_value=stub_evaluator):
        mock_grpo.return_value = "checkpoints/ablation_no_anti_gaming/"
        run_ablation("no_anti_gaming", {"anti_gaming": 0.0}, total_steps=5)

    weights_sent = mock_grpo.call_args.kwargs["weights_override"]
    assert weights_sent["anti_gaming"] == 0.0
    # All other weights preserved from REWARD_WEIGHTS.
    for key, value in REWARD_WEIGHTS.items():
        if key != "anti_gaming":
            assert weights_sent[key] == value
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ---------------------------------------------------------------------------
|
| 110 |
+
# plot_ablation_summary (matplotlib mocked)
|
| 111 |
+
# ---------------------------------------------------------------------------
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def test_plot_ablation_summary_writes_png(tmp_path: Path) -> None:
    """plot_ablation_summary saves at least one figure named 'ablation_summary*'."""
    summary = pd.DataFrame({
        "ablation": ["no_diagnosis", "no_action_quality", "no_diagnosis", "no_action_quality"],
        "baseline": ["random", "random", "heuristic", "heuristic"],
        "diagnosis_correct": [0.2, 0.4, 0.5, 0.6],
        "total_reward": [0.1, 0.3, 0.4, 0.5],
        "action_quality": [0.0, 0.2, 0.3, 0.4],
    })

    # Capture every savefig destination instead of writing real images.
    saved_paths: list[str] = []
    fig_stub = MagicMock()
    fig_stub.savefig.side_effect = lambda path, **kwargs: saved_paths.append(str(path))
    plt_stub = MagicMock()
    plt_stub.subplots.return_value = (fig_stub, [MagicMock(), MagicMock()])
    sns_stub = MagicMock()

    with patch("ci_triage_env.training.curves.plt", plt_stub), \
            patch("ci_triage_env.training.curves.sns", sns_stub):
        from ci_triage_env.training.curves import plot_ablation_summary
        plot_ablation_summary(summary, output_dir=tmp_path / "plots")

    assert len(saved_paths) >= 1
    assert any("ablation_summary" in path for path in saved_paths)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# ---------------------------------------------------------------------------
|
| 145 |
+
# populate_readme
|
| 146 |
+
# ---------------------------------------------------------------------------
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def test_finalize_readme_replaces_table_marker(tmp_path: Path) -> None:
    """The results-table [FILL] marker is swapped for a markdown table."""
    readme = tmp_path / "README.md"
    readme.write_text(
        "# Results\n\n[FILL: 5-row × 6-metric table]\n\nMore text.\n"
    )
    eval_csv = tmp_path / "eval.csv"
    metrics = {
        "baseline": ["random", "heuristic"],
        "diagnosis_correct": [0.3, 0.6],
        "action_quality": [0.1, 0.4],
        "total_cost": [0.05, 0.03],
        "tool_call_count": [4, 4],
        "total_reward": [0.2, 0.5],
    }
    pd.DataFrame(metrics).to_csv(eval_csv, index=False)

    # Ablation CSV and plots directory intentionally absent.
    replaced = populate_readme(
        template_path=readme,
        eval_csv=eval_csv,
        ablation_csv=tmp_path / "ablations.csv",
        plots_dir=tmp_path / "plots",
    )

    rendered = readme.read_text()
    assert "[FILL: 5-row × 6-metric table]" not in rendered
    assert "|" in rendered  # table was inserted
    assert replaced >= 1
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def test_finalize_readme_embeds_plot_images(tmp_path: Path) -> None:
    """A [FILL] plot marker is replaced with a reference to the matching PNG."""
    readme = tmp_path / "README.md"
    readme.write_text("# Plots\n\n[FILL: diagnosis accuracy]\n")

    plots_dir = tmp_path / "plots"
    plots_dir.mkdir()
    # An empty file is enough: populate_readme only needs the path to exist.
    (plots_dir / "diagnosis_accuracy.png").touch()

    replaced = populate_readme(
        template_path=readme,
        eval_csv=tmp_path / "eval.csv",
        ablation_csv=tmp_path / "ablations.csv",
        plots_dir=plots_dir,
    )

    rendered = readme.read_text()
    assert "[FILL: diagnosis accuracy]" not in rendered
    assert "diagnosis_accuracy.png" in rendered
    assert replaced >= 1
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def test_finalize_readme_missing_csv_does_not_crash(tmp_path: Path) -> None:
    """Absent inputs leave a marker-free README untouched and report zero fills."""
    readme = tmp_path / "README.md"
    original_text = "# CI Triage\n\nNo markers here.\n"
    readme.write_text(original_text)

    replaced = populate_readme(
        template_path=readme,
        eval_csv=tmp_path / "nonexistent.csv",
        ablation_csv=tmp_path / "nonexistent2.csv",
        plots_dir=tmp_path / "no_plots",
    )
    assert replaced == 0
    assert readme.read_text() == original_text
|