Prasham.Jain Claude Sonnet 4.6 committed on
Commit
e46f00b
·
1 Parent(s): 93e68bc

feat(training): Phase C6 — ablations, training curves, readme finalization

Browse files

Adds 4-entry ablation matrix (no_diagnosis, no_action_quality, no_investigation,
no_anti_gaming), each running 1000-step GRPO with one reward component zeroed.
Extends run_grpo/TrainingRollout with weights_override to thread ablation weights
into CompositeReward. Adds W&B curve plotting, ablation summary charts,
populate_readme for [FILL] marker replacement, and Colab eval notebook.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

notebooks/eval.ipynb ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "cell-0",
6
+ "metadata": {},
7
+ "source": [
8
+ "# CI-Triage-Env — Evaluation & Ablation Notebook\n",
9
+ "\n",
10
+ "Colab-runnable notebook for judges to reproduce all results.\n",
11
+ "\n",
12
+ "Steps:\n",
13
+ "1. Install dependencies\n",
14
+ "2. Load trained checkpoint from HF Hub\n",
15
+ "3. Run full 5-baseline evaluation\n",
16
+ "4. Generate all metric plots\n",
17
+ "5. Run reward-layer ablations (optional, GPU, ~5h)\n",
18
+ "6. Populate README with results\n",
19
+ "\n",
20
+ "**Prerequisites**: `HF_TOKEN`, `WANDB_API_KEY` as Colab secrets."
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "id": "cell-1",
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "# Cell 1: Install dependencies\n",
31
+ "!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121\n",
32
+ "!pip install -q unsloth trl transformers accelerate peft\n",
33
+ "!pip install -q wandb datasets huggingface_hub openai httpx fastapi uvicorn pydantic jsonschema\n",
34
+ "!pip install -q matplotlib seaborn pandas tabulate\n",
35
+ "!pip install -q -e . # install ci_triage_env package"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "id": "cell-2",
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "# Cell 2: Environment setup\n",
46
+ "import os\n",
47
+ "from google.colab import userdata\n",
48
+ "\n",
49
+ "os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')\n",
50
+ "os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')\n",
51
+ "os.environ['WANDB_PROJECT'] = 'ci-triage-env'\n",
52
+ "\n",
53
+ "# Config — replace with your values\n",
54
+ "HF_DATASET_REPO = 'YOUR_ORG/ci-triage-scenarios'\n",
55
+ "HF_MODEL_REPO = 'YOUR_ORG/ci-triage-trained-qwen3.5-4b'\n",
56
+ "WANDB_RUN_ID = 'YOUR_WANDB_RUN_ID' # e.g. 'entity/ci-triage-env/abc123'"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": null,
62
+ "id": "cell-3",
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "# Cell 3: Download scenario corpus and trained checkpoint\n",
67
+ "from huggingface_hub import snapshot_download\n",
68
+ "\n",
69
+ "scen_dir = snapshot_download(\n",
70
+ " HF_DATASET_REPO, repo_type='dataset',\n",
71
+ " local_dir='data_artifacts/scenarios'\n",
72
+ ")\n",
73
+ "ckpt_dir = snapshot_download(\n",
74
+ " HF_MODEL_REPO, repo_type='model',\n",
75
+ " local_dir='checkpoints/grpo_full'\n",
76
+ ")\n",
77
+ "print(f'Scenarios: {scen_dir}')\n",
78
+ "print(f'Checkpoint: {ckpt_dir}')"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": null,
84
+ "id": "cell-4",
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "# Cell 4: Start env server in background\n",
89
+ "import subprocess, time\n",
90
+ "server_proc = subprocess.Popen(\n",
91
+ " ['python', '-m', 'ci_triage_env.env.server'],\n",
92
+ " stdout=subprocess.PIPE, stderr=subprocess.PIPE\n",
93
+ ")\n",
94
+ "time.sleep(4)\n",
95
+ "print('Env server started, PID:', server_proc.pid)"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "id": "cell-5",
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "# Cell 5: Run full 5-baseline evaluation\n",
106
+ "from pathlib import Path\n",
107
+ "import pandas as pd\n",
108
+ "from ci_triage_env.training.eval import Evaluator\n",
109
+ "from ci_triage_env.training.plotting import plot_all_eval_metrics\n",
110
+ "\n",
111
+ "evaluator = Evaluator(\n",
112
+ " eval_set_path='data_artifacts/scenarios/held_out/',\n",
113
+ " trained_checkpoint='checkpoints/grpo_full/',\n",
114
+ ")\n",
115
+ "df_eval = evaluator.run_all(seeds=[1, 2, 3])\n",
116
+ "\n",
117
+ "out = Path('data_artifacts/results/')\n",
118
+ "out.mkdir(parents=True, exist_ok=True)\n",
119
+ "df_eval.to_csv(out / 'eval.csv', index=False)\n",
120
+ "\n",
121
+ "print(df_eval.groupby('baseline').agg({\n",
122
+ " 'diagnosis_correct': 'mean',\n",
123
+ " 'total_reward': 'mean',\n",
124
+ " 'tool_call_count': 'mean',\n",
125
+ "}))\n",
126
+ "\n",
127
+ "plot_all_eval_metrics(df_eval, out / 'plots/')\n",
128
+ "print('Plots saved to data_artifacts/results/plots/')"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "id": "cell-6",
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "# Cell 6: Pull training curves from W&B\n",
139
+ "from ci_triage_env.training.curves import plot_training_curves_from_wandb\n",
140
+ "\n",
141
+ "plot_training_curves_from_wandb(\n",
142
+ " run_id=WANDB_RUN_ID,\n",
143
+ " output_dir=Path('data_artifacts/results/plots/'),\n",
144
+ ")\n",
145
+ "print('Training curves saved.')"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "id": "cell-7",
152
+ "metadata": {},
153
+ "outputs": [],
154
+ "source": [
155
+ "# Cell 7: Run reward-layer ablations (~5h on A100; set RUN_ABLATIONS=True to enable)\n",
156
+ "RUN_ABLATIONS = False\n",
157
+ "\n",
158
+ "if RUN_ABLATIONS:\n",
159
+ " from ci_triage_env.training.ablations import ABLATIONS, run_ablation\n",
160
+ " from ci_triage_env.training.curves import plot_ablation_summary\n",
161
+ "\n",
162
+ " abl_results = []\n",
163
+ " for name, overrides in ABLATIONS.items():\n",
164
+ " print(f'=== Ablation: {name} ===')\n",
165
+ " df_abl = run_ablation(name, overrides, total_steps=1000)\n",
166
+ " abl_results.append(df_abl)\n",
167
+ " print(df_abl.groupby('baseline')['diagnosis_correct'].mean())\n",
168
+ "\n",
169
+ " df_full_abl = pd.concat(abl_results, ignore_index=True)\n",
170
+ " df_full_abl.to_csv(out / 'ablations.csv', index=False)\n",
171
+ " plot_ablation_summary(df_full_abl, output_dir=out / 'plots/')\n",
172
+ " print('Ablations saved.')\n",
173
+ "else:\n",
174
+ " print('Ablations skipped (set RUN_ABLATIONS=True to run).')"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": null,
180
+ "id": "cell-8",
181
+ "metadata": {},
182
+ "outputs": [],
183
+ "source": [
184
+ "# Cell 8: Populate README with results\n",
185
+ "from ci_triage_env.training.finalize_readme import populate_readme\n",
186
+ "\n",
187
+ "n = populate_readme(\n",
188
+ " eval_csv=out / 'eval.csv',\n",
189
+ " ablation_csv=out / 'ablations.csv',\n",
190
+ " plots_dir=out / 'plots/',\n",
191
+ ")\n",
192
+ "print(f'Replaced {n} markers in README.md')\n",
193
+ "\n",
194
+ "# Check for any remaining unfilled markers\n",
195
+ "import subprocess\n",
196
+ "result = subprocess.run(['grep', '-c', r'\\[FILL', 'README.md'], capture_output=True, text=True)\n",
197
+ "remaining = int(result.stdout.strip() or 0)\n",
198
+ "if remaining:\n",
199
+ " print(f'WARNING: {remaining} unfilled [FILL] marker(s) remain in README.md')\n",
200
+ "else:\n",
201
+ " print('README.md is clean — no unfilled markers.')"
202
+ ]
203
+ }
204
+ ],
205
+ "metadata": {
206
+ "kernelspec": {
207
+ "display_name": "Python 3",
208
+ "language": "python",
209
+ "name": "python3"
210
+ },
211
+ "language_info": {
212
+ "name": "python",
213
+ "version": "3.11.0"
214
+ }
215
+ },
216
+ "nbformat": 4,
217
+ "nbformat_minor": 5
218
+ }
src/ci_triage_env/training/ablations.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reward-layer ablation runs for CI-Triage-Env.
2
+
3
+ Each ablation zeroes one reward component weight, runs 1000-step GRPO from the
4
+ SFT checkpoint, then evaluates against the held-out set. Results are saved to
5
+ data_artifacts/results/ablations.csv.
6
+
7
+ All GPU-heavy imports are lazy; this module is importable without GPU.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from ci_triage_env.rewards.weights import REWARD_WEIGHTS
13
+ from ci_triage_env.training.eval import Evaluator
14
+ from ci_triage_env.training.grpo import run_grpo
15
+
16
+ ABLATIONS: dict[str, dict[str, float]] = {
17
+ "no_diagnosis": {"diagnosis": 0.0},
18
+ "no_action_quality": {"action_quality": 0.0},
19
+ "no_investigation": {"investigation": 0.0},
20
+ "no_anti_gaming": {"anti_gaming": 0.0},
21
+ # counterfactual ablation deferred to v2 — its weight is already 0.0 in v1.
22
+ }
23
+
24
+
25
+ def run_ablation(
26
+ name: str,
27
+ weight_overrides: dict[str, float],
28
+ total_steps: int = 1000,
29
+ base_sft_checkpoint: str = "checkpoints/sft/",
30
+ env_client=None,
31
+ ):
32
+ """Run one ablation: train GRPO with modified weights then evaluate.
33
+
34
+ Args:
35
+ name: Ablation name (used as output directory suffix and CSV column).
36
+ weight_overrides: Keys/values to merge over REWARD_WEIGHTS (zeroed components).
37
+ total_steps: GRPO training steps for this ablation (default 1000, not 3000).
38
+ base_sft_checkpoint: SFT warmstart checkpoint to train from.
39
+ env_client: Optional env client override (MockEnvClient for testing).
40
+
41
+ Returns:
42
+ pandas DataFrame with eval results; has an extra "ablation" column.
43
+ """
44
+
45
+ new_weights = {**REWARD_WEIGHTS, **weight_overrides}
46
+ output_dir = f"checkpoints/ablation_{name}/"
47
+
48
+ run_grpo(
49
+ sft_checkpoint_dir=base_sft_checkpoint,
50
+ output_dir=output_dir,
51
+ total_steps=total_steps,
52
+ weights_override=new_weights,
53
+ env_client=env_client,
54
+ )
55
+
56
+ evaluator = Evaluator(trained_checkpoint=output_dir, env_client=env_client)
57
+ evaluator.BASELINES = ["random", "heuristic", "trained"]
58
+ df = evaluator.run_all(seeds=[1])
59
+ df["ablation"] = name
60
+ return df
61
+
62
+
63
+ def main(argv=None) -> None:
64
+ import argparse
65
+ from pathlib import Path
66
+
67
+ import pandas as pd
68
+
69
+ from ci_triage_env.training.curves import plot_ablation_summary
70
+
71
+ parser = argparse.ArgumentParser(description="Run reward-layer ablation matrix")
72
+ parser.add_argument("--steps", type=int, default=1000)
73
+ parser.add_argument("--sft-checkpoint", default="checkpoints/sft/")
74
+ parser.add_argument("--output", default="data_artifacts/results/")
75
+ args = parser.parse_args(argv)
76
+
77
+ all_results = []
78
+ for name, overrides in ABLATIONS.items():
79
+ print(f"=== Ablation: {name} ===")
80
+ df = run_ablation(
81
+ name, overrides,
82
+ total_steps=args.steps,
83
+ base_sft_checkpoint=args.sft_checkpoint,
84
+ )
85
+ all_results.append(df)
86
+ print(df.groupby("baseline")["diagnosis_correct"].mean())
87
+
88
+ full = pd.concat(all_results, ignore_index=True)
89
+ out = Path(args.output)
90
+ out.mkdir(parents=True, exist_ok=True)
91
+ full.to_csv(out / "ablations.csv", index=False)
92
+ print(f"\nAblations saved to {out / 'ablations.csv'}")
93
+
94
+ plot_ablation_summary(full, output_dir=out / "plots/")
95
+
96
+
97
+ if __name__ == "__main__":
98
+ main()
src/ci_triage_env/training/curves.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training curve plots and ablation summary charts for CI-Triage-Env.
2
+
3
+ Module-level try/except lets the module import without matplotlib; tests patch
4
+ `ci_triage_env.training.curves.plt` and `.sns` directly.
5
+ W&B imports are lazy (inside functions) since wandb is Colab-only.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from pathlib import Path
11
+
12
+ try:
13
+ import matplotlib.pyplot as plt # type: ignore[import]
14
+ import seaborn as sns # type: ignore[import]
15
+ except ImportError:
16
+ plt = None # type: ignore[assignment]
17
+ sns = None # type: ignore[assignment]
18
+
19
+
20
+ def plot_training_curves_from_wandb(run_id: str, output_dir: Path) -> None:
21
+ """Pull per-step metrics from W&B and write training curve PNGs.
22
+
23
+ Args:
24
+ run_id: W&B run path, e.g. "<entity>/ci-triage-env/<run_id>".
25
+ output_dir: Directory to write PNG files.
26
+ """
27
+ if plt is None:
28
+ raise ImportError("matplotlib required — install with: pip install matplotlib")
29
+
30
+ import wandb # type: ignore[import]
31
+
32
+ output_dir = Path(output_dir)
33
+ output_dir.mkdir(parents=True, exist_ok=True)
34
+
35
+ api = wandb.Api()
36
+ run = api.run(run_id)
37
+ history = run.history(samples=10000)
38
+
39
+ # 1. Reward curve (raw + smoothed)
40
+ fig, ax = plt.subplots(figsize=(10, 5))
41
+ ax.plot(history["step"], history["mean_reward"], alpha=0.4, label="raw")
42
+ smooth = history["mean_reward"].rolling(window=20, min_periods=1).mean()
43
+ ax.plot(history["step"], smooth, label="smoothed (window=20)", linewidth=2)
44
+ ax.set_xlabel("Training step")
45
+ ax.set_ylabel("Mean episode reward")
46
+ ax.set_title("GRPO training: reward over time")
47
+ ax.legend()
48
+ ax.grid(alpha=0.3)
49
+ fig.tight_layout()
50
+ fig.savefig(output_dir / "training_reward_curve.png", dpi=120)
51
+ plt.close(fig)
52
+
53
+ # 2. Per-component reward breakdown
54
+ component_keys = [
55
+ "reward/diagnosis", "reward/action_quality", "reward/cost_efficiency",
56
+ "reward/investigation", "reward/anti_gaming",
57
+ ]
58
+ fig, ax = plt.subplots(figsize=(10, 5))
59
+ for k in component_keys:
60
+ if k in history.columns:
61
+ smooth = history[k].rolling(window=20, min_periods=1).mean()
62
+ ax.plot(history["step"], smooth, label=k.replace("reward/", ""))
63
+ ax.set_xlabel("Training step")
64
+ ax.set_ylabel("Mean component reward (smoothed)")
65
+ ax.set_title("Per-component reward breakdown over training")
66
+ ax.legend()
67
+ ax.grid(alpha=0.3)
68
+ fig.tight_layout()
69
+ fig.savefig(output_dir / "per_component_curve.png", dpi=120)
70
+ plt.close(fig)
71
+
72
+ # 3. KL divergence
73
+ if "kl" in history.columns:
74
+ fig, ax = plt.subplots(figsize=(10, 4))
75
+ ax.plot(history["step"], history["kl"])
76
+ ax.set_xlabel("Training step")
77
+ ax.set_ylabel("KL(policy || ref)")
78
+ ax.set_title("KL divergence to SFT reference")
79
+ ax.grid(alpha=0.3)
80
+ fig.tight_layout()
81
+ fig.savefig(output_dir / "kl_divergence.png", dpi=120)
82
+ plt.close(fig)
83
+
84
+ # 4. Format-gate pass rate
85
+ if "format_gate_pass_rate" in history.columns:
86
+ fig, ax = plt.subplots(figsize=(10, 4))
87
+ ax.plot(history["step"], history["format_gate_pass_rate"])
88
+ ax.set_ylim(0, 1)
89
+ ax.set_xlabel("Training step")
90
+ ax.set_ylabel("Format gate pass rate")
91
+ ax.set_title("Fraction of episodes passing format gate")
92
+ ax.grid(alpha=0.3)
93
+ fig.tight_layout()
94
+ fig.savefig(output_dir / "format_gate.png", dpi=120)
95
+ plt.close(fig)
96
+
97
+
98
+ def plot_ablation_summary(df, output_dir=None) -> None:
99
+ """Bar chart comparing diagnosis accuracy and total reward across ablations.
100
+
101
+ Args:
102
+ df: DataFrame with columns [ablation, baseline, diagnosis_correct, total_reward, ...].
103
+ output_dir: Directory to write ablation_summary.png. Defaults to
104
+ data_artifacts/results/plots/.
105
+ """
106
+ if plt is None:
107
+ raise ImportError("matplotlib required — install with: pip install matplotlib")
108
+
109
+ if output_dir is None:
110
+ output_dir = Path("data_artifacts/results/plots/")
111
+ output_dir = Path(output_dir)
112
+ output_dir.mkdir(parents=True, exist_ok=True)
113
+
114
+ summary = df.groupby("ablation").agg(
115
+ diagnosis_acc=("diagnosis_correct", "mean"),
116
+ total_reward=("total_reward", "mean"),
117
+ action_qual=("action_quality", "mean"),
118
+ ).reset_index()
119
+
120
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
121
+
122
+ axes[0].bar(list(summary["ablation"]), list(summary["diagnosis_acc"]))
123
+ axes[0].set_title("Diagnosis Accuracy by Ablation")
124
+ axes[0].set_ylabel("Diagnosis Accuracy")
125
+ axes[0].set_xlabel("Ablation")
126
+ axes[0].tick_params(axis="x", rotation=30)
127
+
128
+ axes[1].bar(list(summary["ablation"]), list(summary["total_reward"]))
129
+ axes[1].set_title("Mean Total Reward by Ablation")
130
+ axes[1].set_ylabel("Mean Total Reward")
131
+ axes[1].set_xlabel("Ablation")
132
+ axes[1].tick_params(axis="x", rotation=30)
133
+
134
+ fig.tight_layout()
135
+ fig.savefig(output_dir / "ablation_summary.png", dpi=120)
136
+ plt.close(fig)
src/ci_triage_env/training/finalize_readme.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Populate README's [FILL POST-TRAIN] markers after training is complete.
2
+
3
+ Run after eval.csv, ablations.csv, and plots/ are all present:
4
+ python -c "from ci_triage_env.training.finalize_readme import populate_readme; populate_readme()"
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from pathlib import Path
10
+
11
+
12
+ def populate_readme(
13
+ template_path: Path = Path("README.md"),
14
+ eval_csv: Path = Path("data_artifacts/results/eval.csv"),
15
+ ablation_csv: Path = Path("data_artifacts/results/ablations.csv"),
16
+ plots_dir: Path = Path("data_artifacts/results/plots/"),
17
+ ) -> int:
18
+ """Fill [FILL …] markers in README.md in-place.
19
+
20
+ Args:
21
+ template_path: Path to README.md.
22
+ eval_csv: Path to the master eval CSV from Phase C5.
23
+ ablation_csv: Path to the ablation results CSV from Phase C6.
24
+ plots_dir: Directory containing PNG plots.
25
+
26
+ Returns:
27
+ Number of markers replaced.
28
+ """
29
+ import pandas as pd
30
+
31
+ from ci_triage_env.training.readme_table import generate_results_table
32
+
33
+ text = template_path.read_text()
34
+ replaced = 0
35
+
36
+ # 1. Results table
37
+ if eval_csv.exists():
38
+ df_eval = pd.read_csv(eval_csv)
39
+ table_md = generate_results_table(df_eval)
40
+ marker = "[FILL: 5-row × 6-metric table]"
41
+ if marker in text:
42
+ text = text.replace(marker, table_md)
43
+ replaced += 1
44
+
45
+ # 2. Embed plot images — replace [FILL: <stem with spaces>] with markdown img tags
46
+ if plots_dir.exists():
47
+ for png in sorted(plots_dir.glob("*.png")):
48
+ stem_words = png.stem.replace("_", " ")
49
+ marker = f"[FILL: {stem_words}]"
50
+ rel = png.relative_to(template_path.parent)
51
+ embed = f"![{png.stem}]({rel})"
52
+ if marker in text:
53
+ text = text.replace(marker, embed)
54
+ replaced += 1
55
+
56
+ # 3. Remove any remaining generic [FILL POST-TRAIN] or [FILL] markers
57
+ # by replacing them with a placeholder so the README stays valid.
58
+ import re
59
+ generic = re.compile(r"\[FILL[^\]]*\]")
60
+ remaining = generic.findall(text)
61
+ if remaining:
62
+ print(f"WARNING: {len(remaining)} unfilled marker(s) remain: {remaining[:5]}")
63
+
64
+ template_path.write_text(text)
65
+ return replaced
src/ci_triage_env/training/grpo.py CHANGED
@@ -34,6 +34,7 @@ def run_grpo(
34
  env_url: str = "http://localhost:8000",
35
  scenarios_train_path: str = "data_artifacts/scenarios/train",
36
  hyperparams: dict | None = None,
 
37
  ) -> str:
38
  """Launch GRPO fine-tuning from an SFT checkpoint.
39
 
@@ -46,6 +47,8 @@ def run_grpo(
46
  env_url: URL of the running env server (used only when env_client is None).
47
  scenarios_train_path: Directory containing train split scenario JSON files.
48
  hyperparams: Override specific GRPO hyperparameters. Merged over GRPO_HYPERPARAMS.
 
 
49
 
50
  Returns:
51
  output_dir path as a string.
@@ -64,6 +67,7 @@ def run_grpo(
64
  rollout = TrainingRollout(
65
  env_client=env_client,
66
  scenarios_train=scenario_ids,
 
67
  )
68
 
69
  model, tokenizer = load_model_for_sft(model_name=sft_checkpoint_dir)
 
34
  env_url: str = "http://localhost:8000",
35
  scenarios_train_path: str = "data_artifacts/scenarios/train",
36
  hyperparams: dict | None = None,
37
+ weights_override: dict | None = None,
38
  ) -> str:
39
  """Launch GRPO fine-tuning from an SFT checkpoint.
40
 
 
47
  env_url: URL of the running env server (used only when env_client is None).
48
  scenarios_train_path: Directory containing train split scenario JSON files.
49
  hyperparams: Override specific GRPO hyperparameters. Merged over GRPO_HYPERPARAMS.
50
+ weights_override: Override reward component weights passed to CompositeReward.
51
+ Used by ablation runs to zero out individual reward terms.
52
 
53
  Returns:
54
  output_dir path as a string.
 
67
  rollout = TrainingRollout(
68
  env_client=env_client,
69
  scenarios_train=scenario_ids,
70
+ weights=weights_override,
71
  )
72
 
73
  model, tokenizer = load_model_for_sft(model_name=sft_checkpoint_dir)
src/ci_triage_env/training/rollout.py CHANGED
@@ -47,10 +47,12 @@ class TrainingRollout:
47
  env_client,
48
  scenarios_train: list[str] | None = None,
49
  max_turns: int = 12,
 
50
  ) -> None:
51
  self.env = env_client
52
  self.scenarios_train: list[str] = scenarios_train or []
53
  self.max_turns = max_turns
 
54
  self._quarantine_window: list[str] = []
55
  self._tools_listing: str | None = None
56
 
@@ -114,9 +116,10 @@ class TrainingRollout:
114
 
115
  trace = self.env.get_trace(episode_id)
116
  scenario = self._load_scenario(trace, episode_id)
117
- reward = compute_reward(
118
- trace, scenario, quarantine_window=self._quarantine_window
119
- )
 
120
 
121
  # Update rolling quarantine window (last 50 secondary actions)
122
  if trace.episode.final_action:
 
47
  env_client,
48
  scenarios_train: list[str] | None = None,
49
  max_turns: int = 12,
50
+ weights: dict | None = None,
51
  ) -> None:
52
  self.env = env_client
53
  self.scenarios_train: list[str] = scenarios_train or []
54
  self.max_turns = max_turns
55
+ self.weights = weights # None → CompositeReward uses frozen defaults
56
  self._quarantine_window: list[str] = []
57
  self._tools_listing: str | None = None
58
 
 
116
 
117
  trace = self.env.get_trace(episode_id)
118
  scenario = self._load_scenario(trace, episode_id)
119
+ reward_kwargs: dict = {"quarantine_window": self._quarantine_window}
120
+ if self.weights is not None:
121
+ reward_kwargs["weights"] = self.weights
122
+ reward = compute_reward(trace, scenario, **reward_kwargs)
123
 
124
  # Update rolling quarantine window (last 50 secondary actions)
125
  if trace.episode.final_action:
tests/training/test_ablations.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for Phase C6 — ablations, curves, and readme finalization (no GPU)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from unittest.mock import MagicMock, patch
7
+
8
+ import pandas as pd
9
+
10
+ from ci_triage_env.training.ablations import ABLATIONS, run_ablation
11
+ from ci_triage_env.training.finalize_readme import populate_readme
12
+
13
+ # ---------------------------------------------------------------------------
14
+ # ABLATIONS dict
15
+ # ---------------------------------------------------------------------------
16
+
17
+
18
+ def test_ablations_dict_has_4_entries() -> None:
19
+ assert len(ABLATIONS) == 4 # counterfactual deferred to v2
20
+
21
+
22
+ def test_each_ablation_zeros_exactly_one_weight() -> None:
23
+ for name, overrides in ABLATIONS.items():
24
+ zeroed = [k for k, v in overrides.items() if v == 0.0]
25
+ assert len(zeroed) == 1, f"Ablation '{name}' should zero exactly 1 weight, got {zeroed}"
26
+
27
+
28
+ def test_ablation_names_reference_valid_reward_keys() -> None:
29
+ from ci_triage_env.rewards.weights import REWARD_WEIGHTS
30
+
31
+ for name, overrides in ABLATIONS.items():
32
+ for key in overrides:
33
+ assert key in REWARD_WEIGHTS, (
34
+ f"Ablation '{name}' references unknown reward key '{key}'"
35
+ )
36
+
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # run_ablation smoke (mock run_grpo + Evaluator)
40
+ # ---------------------------------------------------------------------------
41
+
42
+
43
+ def test_run_ablation_smoke() -> None:
44
+ """Mock run_grpo and Evaluator; verify run_ablation returns a DataFrame."""
45
+ fake_df = pd.DataFrame({
46
+ "baseline": ["random", "heuristic", "trained"],
47
+ "scenario_id": ["s1", "s1", "s1"],
48
+ "family": ["real_bug"] * 3,
49
+ "difficulty": ["easy"] * 3,
50
+ "seed": [1, 1, 1],
51
+ "total_reward": [0.1, 0.5, 0.7],
52
+ "format_gate": [True] * 3,
53
+ "diagnosis_correct": [False, True, True],
54
+ "predicted_diagnosis": ["ambiguous", "real_bug", "real_bug"],
55
+ "true_diagnosis": ["real_bug"] * 3,
56
+ "action_quality": [0.0, 0.3, 0.5],
57
+ "tool_call_count": [3, 4, 5],
58
+ "total_cost": [0.03, 0.04, 0.05],
59
+ "confidence": [0.5, 0.8, 0.9],
60
+ "is_ambiguous_scenario": [False] * 3,
61
+ "brier_on_ambiguous": [None] * 3,
62
+ })
63
+
64
+ mock_evaluator = MagicMock()
65
+ mock_evaluator.run_all.return_value = fake_df
66
+
67
+ with patch("ci_triage_env.training.ablations.run_grpo") as mock_grpo, \
68
+ patch("ci_triage_env.training.ablations.Evaluator", return_value=mock_evaluator):
69
+ mock_grpo.return_value = "checkpoints/ablation_test/"
70
+ result = run_ablation(
71
+ "no_diagnosis",
72
+ {"diagnosis": 0.0},
73
+ total_steps=10,
74
+ )
75
+
76
+ assert isinstance(result, pd.DataFrame)
77
+ assert "ablation" in result.columns
78
+ assert (result["ablation"] == "no_diagnosis").all()
79
+ mock_grpo.assert_called_once()
80
+ call_kwargs = mock_grpo.call_args.kwargs
81
+ assert call_kwargs["weights_override"]["diagnosis"] == 0.0
82
+
83
+
84
+ def test_run_ablation_passes_weights_to_grpo() -> None:
85
+ """Confirm the merged weights dict reaches run_grpo."""
86
+ from ci_triage_env.rewards.weights import REWARD_WEIGHTS
87
+
88
+ fake_df = pd.DataFrame({"baseline": [], "total_reward": [], "diagnosis_correct": [],
89
+ "scenario_id": [], "family": [], "difficulty": [], "seed": [],
90
+ "format_gate": [], "predicted_diagnosis": [], "true_diagnosis": [],
91
+ "action_quality": [], "tool_call_count": [], "total_cost": [],
92
+ "confidence": [], "is_ambiguous_scenario": [], "brier_on_ambiguous": []})
93
+ mock_evaluator = MagicMock()
94
+ mock_evaluator.run_all.return_value = fake_df
95
+
96
+ with patch("ci_triage_env.training.ablations.run_grpo") as mock_grpo, \
97
+ patch("ci_triage_env.training.ablations.Evaluator", return_value=mock_evaluator):
98
+ mock_grpo.return_value = "checkpoints/ablation_no_anti_gaming/"
99
+ run_ablation("no_anti_gaming", {"anti_gaming": 0.0}, total_steps=5)
100
+
101
+ weights_sent = mock_grpo.call_args.kwargs["weights_override"]
102
+ assert weights_sent["anti_gaming"] == 0.0
103
+ # All other weights preserved from REWARD_WEIGHTS
104
+ for k, v in REWARD_WEIGHTS.items():
105
+ if k != "anti_gaming":
106
+ assert weights_sent[k] == v
107
+
108
+
109
+ # ---------------------------------------------------------------------------
110
+ # plot_ablation_summary (matplotlib mocked)
111
+ # ---------------------------------------------------------------------------
112
+
113
+
114
+ def test_plot_ablation_summary_writes_png(tmp_path: Path) -> None:
115
+ df = pd.DataFrame({
116
+ "ablation": ["no_diagnosis", "no_action_quality", "no_diagnosis", "no_action_quality"],
117
+ "baseline": ["random", "random", "heuristic", "heuristic"],
118
+ "diagnosis_correct": [0.2, 0.4, 0.5, 0.6],
119
+ "total_reward": [0.1, 0.3, 0.4, 0.5],
120
+ "action_quality": [0.0, 0.2, 0.3, 0.4],
121
+ })
122
+
123
+ saved: list[str] = []
124
+ mock_fig = MagicMock()
125
+ mock_axes = [MagicMock(), MagicMock()]
126
+ mock_plt = MagicMock()
127
+ mock_plt.subplots.return_value = (mock_fig, mock_axes)
128
+ mock_sns = MagicMock()
129
+
130
+ def _record(path, **kwargs):
131
+ saved.append(str(path))
132
+
133
+ mock_fig.savefig.side_effect = _record
134
+
135
+ with patch("ci_triage_env.training.curves.plt", mock_plt), \
136
+ patch("ci_triage_env.training.curves.sns", mock_sns):
137
+ from ci_triage_env.training.curves import plot_ablation_summary
138
+ plot_ablation_summary(df, output_dir=tmp_path / "plots")
139
+
140
+ assert len(saved) >= 1
141
+ assert any("ablation_summary" in s for s in saved)
142
+
143
+
144
+ # ---------------------------------------------------------------------------
145
+ # populate_readme
146
+ # ---------------------------------------------------------------------------
147
+
148
+
149
+ def test_finalize_readme_replaces_table_marker(tmp_path: Path) -> None:
150
+ readme = tmp_path / "README.md"
151
+ readme.write_text(
152
+ "# Results\n\n[FILL: 5-row × 6-metric table]\n\nMore text.\n"
153
+ )
154
+ eval_csv = tmp_path / "eval.csv"
155
+ pd.DataFrame({
156
+ "baseline": ["random", "heuristic"],
157
+ "diagnosis_correct": [0.3, 0.6],
158
+ "action_quality": [0.1, 0.4],
159
+ "total_cost": [0.05, 0.03],
160
+ "tool_call_count": [4, 4],
161
+ "total_reward": [0.2, 0.5],
162
+ }).to_csv(eval_csv, index=False)
163
+
164
+ # No ablation csv; no plots dir
165
+ n = populate_readme(
166
+ template_path=readme,
167
+ eval_csv=eval_csv,
168
+ ablation_csv=tmp_path / "ablations.csv",
169
+ plots_dir=tmp_path / "plots",
170
+ )
171
+
172
+ result = readme.read_text()
173
+ assert "[FILL: 5-row × 6-metric table]" not in result
174
+ assert "|" in result # table was inserted
175
+ assert n >= 1
176
+
177
+
178
+ def test_finalize_readme_embeds_plot_images(tmp_path: Path) -> None:
179
+ readme = tmp_path / "README.md"
180
+ readme.write_text("# Plots\n\n[FILL: diagnosis accuracy]\n")
181
+
182
+ plots_dir = tmp_path / "plots"
183
+ plots_dir.mkdir()
184
+ (plots_dir / "diagnosis_accuracy.png").touch()
185
+
186
+ n = populate_readme(
187
+ template_path=readme,
188
+ eval_csv=tmp_path / "eval.csv",
189
+ ablation_csv=tmp_path / "ablations.csv",
190
+ plots_dir=plots_dir,
191
+ )
192
+
193
+ result = readme.read_text()
194
+ assert "[FILL: diagnosis accuracy]" not in result
195
+ assert "diagnosis_accuracy.png" in result
196
+ assert n >= 1
197
+
198
+
199
+ def test_finalize_readme_missing_csv_does_not_crash(tmp_path: Path) -> None:
200
+ readme = tmp_path / "README.md"
201
+ readme.write_text("# CI Triage\n\nNo markers here.\n")
202
+
203
+ n = populate_readme(
204
+ template_path=readme,
205
+ eval_csv=tmp_path / "nonexistent.csv",
206
+ ablation_csv=tmp_path / "nonexistent2.csv",
207
+ plots_dir=tmp_path / "no_plots",
208
+ )
209
+ assert n == 0
210
+ assert readme.read_text() == "# CI Triage\n\nNo markers here.\n"