{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cell-0",
   "metadata": {},
   "source": [
    "# CI-Triage-Env — Evaluation & Ablation Notebook\n",
    "\n",
    "Colab-runnable notebook for judges to reproduce all results.\n",
    "\n",
    "Steps:\n",
    "1. Install dependencies\n",
    "2. Load trained checkpoint from HF Hub\n",
    "3. Run full 5-baseline evaluation\n",
    "4. Generate all metric plots\n",
    "5. Run reward-layer ablations (optional, GPU, ~5h)\n",
    "6. Populate README with results\n",
    "\n",
    "**Prerequisites**: `HF_TOKEN`, `WANDB_API_KEY` as Colab secrets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 1: Install dependencies\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121\n",
    "%pip install -q unsloth trl transformers accelerate peft\n",
    "%pip install -q wandb datasets huggingface_hub openai httpx fastapi uvicorn pydantic jsonschema\n",
    "%pip install -q matplotlib seaborn pandas tabulate\n",
    "# Install the ci_triage_env package from the current directory (editable mode).\n",
    "%pip install -q -e ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 2: Environment setup\n",
    "import os\n",
    "from google.colab import userdata\n",
    "\n",
    "# Secrets are read from the Colab secret store and exported as environment variables.\n",
    "os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')\n",
    "os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')\n",
    "os.environ['WANDB_PROJECT'] = 'ci-triage-env'\n",
    "\n",
    "# Config — replace with your values\n",
    "HF_DATASET_REPO = 'YOUR_ORG/ci-triage-scenarios'\n",
    "HF_MODEL_REPO = 'YOUR_ORG/ci-triage-trained-qwen3.5-4b'\n",
    "WANDB_RUN_ID = 'YOUR_WANDB_RUN_ID'  # e.g. 'entity/ci-triage-env/abc123'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 3: Download scenario corpus and trained checkpoint\n",
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "# Fail fast with a clear message if the Cell 2 placeholders were not replaced,\n",
    "# instead of surfacing a failed Hub download.\n",
    "for _name, _repo in [('HF_DATASET_REPO', HF_DATASET_REPO), ('HF_MODEL_REPO', HF_MODEL_REPO)]:\n",
    "    if _repo.startswith('YOUR_ORG/'):\n",
    "        raise ValueError(f'{_name} is still the placeholder {_repo!r}; set it in Cell 2.')\n",
    "\n",
    "scen_dir = snapshot_download(\n",
    "    HF_DATASET_REPO, repo_type='dataset',\n",
    "    local_dir='data_artifacts/scenarios'\n",
    ")\n",
    "ckpt_dir = snapshot_download(\n",
    "    HF_MODEL_REPO, repo_type='model',\n",
    "    local_dir='checkpoints/grpo_full'\n",
    ")\n",
    "print(f'Scenarios: {scen_dir}')\n",
    "print(f'Checkpoint: {ckpt_dir}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 4: Start env server in background\n",
    "import atexit, subprocess, sys, time\n",
    "\n",
    "# Re-running this cell reuses a live server instead of spawning a duplicate.\n",
    "if 'server_proc' in globals() and server_proc.poll() is None:\n",
    "    print('Env server already running, PID:', server_proc.pid)\n",
    "else:\n",
    "    # Send server output to a log file. Unread subprocess.PIPE buffers fill up\n",
    "    # and block the child process once the OS pipe buffer is full.\n",
    "    server_log = open('env_server.log', 'w')\n",
    "    server_proc = subprocess.Popen(\n",
    "        # sys.executable: the interpreter running this kernel (where Cell 1 installed packages).\n",
    "        [sys.executable, '-m', 'ci_triage_env.env.server'],\n",
    "        stdout=server_log, stderr=subprocess.STDOUT,\n",
    "    )\n",
    "    # Stop the server when the kernel shuts down so it is not left orphaned.\n",
    "    atexit.register(server_proc.terminate)\n",
    "    time.sleep(4)  # give the server a moment to start\n",
    "    if server_proc.poll() is not None:\n",
    "        # The process already exited: show the end of its log and stop here.\n",
    "        server_log.close()\n",
    "        with open('env_server.log') as fh:\n",
    "            log_tail = fh.read()[-2000:]\n",
    "        print(log_tail)\n",
    "        raise RuntimeError(\n",
    "            f'Env server exited early with code {server_proc.returncode}; see env_server.log'\n",
    "        )\n",
    "    print('Env server started, PID:', server_proc.pid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 5: Run full 5-baseline evaluation\n",
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "from ci_triage_env.training.eval import Evaluator\n",
    "from ci_triage_env.training.plotting import plot_all_eval_metrics\n",
    "\n",
    "evaluator = Evaluator(\n",
    "    eval_set_path='data_artifacts/scenarios/held_out/',\n",
    "    trained_checkpoint='checkpoints/grpo_full/',\n",
    ")\n",
    "df_eval = evaluator.run_all(seeds=[1, 2, 3])\n",
    "\n",
    "# `out` is the results directory reused by Cells 7 and 8.\n",
    "out = Path('data_artifacts/results/')\n",
    "out.mkdir(parents=True, exist_ok=True)\n",
    "df_eval.to_csv(out / 'eval.csv', index=False)\n",
    "\n",
    "# Per-baseline means of the headline metrics.\n",
    "print(df_eval.groupby('baseline').agg({\n",
    "    'diagnosis_correct': 'mean',\n",
    "    'total_reward': 'mean',\n",
    "    'tool_call_count': 'mean',\n",
    "}))\n",
    "\n",
    "plot_all_eval_metrics(df_eval, out / 'plots/')\n",
    "print('Plots saved to data_artifacts/results/plots/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 6: Pull training curves from W&B\n",
    "from pathlib import Path\n",
    "from ci_triage_env.training.curves import plot_training_curves_from_wandb\n",
    "\n",
    "# Fail fast with a clear message if the Cell 2 placeholder was not replaced.\n",
    "if WANDB_RUN_ID == 'YOUR_WANDB_RUN_ID':\n",
    "    raise ValueError('WANDB_RUN_ID is still the placeholder; set it in Cell 2.')\n",
    "\n",
    "plot_training_curves_from_wandb(\n",
    "    run_id=WANDB_RUN_ID,\n",
    "    output_dir=Path('data_artifacts/results/plots/'),\n",
    ")\n",
    "print('Training curves saved.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 7: Run reward-layer ablations (~5h on A100; set RUN_ABLATIONS=True to enable)\n",
    "# Uses `pd` and `out` defined in Cell 5.\n",
    "RUN_ABLATIONS = False\n",
    "\n",
    "if RUN_ABLATIONS:\n",
    "    from ci_triage_env.training.ablations import ABLATIONS, run_ablation\n",
    "    from ci_triage_env.training.curves import plot_ablation_summary\n",
    "\n",
    "    abl_results = []\n",
    "    for name, overrides in ABLATIONS.items():\n",
    "        print(f'=== Ablation: {name} ===')\n",
    "        df_abl = run_ablation(name, overrides, total_steps=1000)\n",
    "        abl_results.append(df_abl)\n",
    "        print(df_abl.groupby('baseline')['diagnosis_correct'].mean())\n",
    "\n",
    "    df_full_abl = pd.concat(abl_results, ignore_index=True)\n",
    "    df_full_abl.to_csv(out / 'ablations.csv', index=False)\n",
    "    plot_ablation_summary(df_full_abl, output_dir=out / 'plots/')\n",
    "    print('Ablations saved.')\n",
    "else:\n",
    "    print('Ablations skipped (set RUN_ABLATIONS=True to run).')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 8: Populate README with results\n",
    "# Uses `Path` and `out` defined in Cell 5.\n",
    "from ci_triage_env.training.finalize_readme import populate_readme\n",
    "\n",
    "ablation_csv = out / 'ablations.csv'\n",
    "if not ablation_csv.exists():\n",
    "    # Cell 7 only writes this file when RUN_ABLATIONS is True.\n",
    "    print(f'NOTE: {ablation_csv} not found; ablation markers may stay unfilled.')\n",
    "\n",
    "n_replaced = populate_readme(\n",
    "    eval_csv=out / 'eval.csv',\n",
    "    ablation_csv=ablation_csv,\n",
    "    plots_dir=out / 'plots/',\n",
    ")\n",
    "print(f'Replaced {n_replaced} markers in README.md')\n",
    "\n",
    "# Check for any remaining unfilled markers.\n",
    "# Counts every '[FILL' occurrence (not just matching lines) and raises\n",
    "# FileNotFoundError if README.md is missing rather than reporting it clean.\n",
    "readme_text = Path('README.md').read_text(encoding='utf-8')\n",
    "remaining = readme_text.count('[FILL')\n",
    "if remaining:\n",
    "    print(f'WARNING: {remaining} unfilled [FILL] marker(s) remain in README.md')\n",
    "else:\n",
    "    print('README.md is clean — no unfilled markers.')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}