QuantumScribe / scripts /make_comparison_plot.py
ronitraj's picture
Upload scripts/make_comparison_plot.py with huggingface_hub
7dc2fe6 verified
raw
history blame
11.8 kB
"""scripts/make_comparison_plot.py - the "money plot" for Qubit-Medic.
Renders a side-by-side bar chart comparing four conditions on the two
headline metrics emitted by ``scripts.eval`` (and dumped to JSON in
``data/eval/`` or ``data/`` by the training pipeline):
* Random baseline (uniform-random qubit picks)
* Base Qwen2.5-3B (un-fine-tuned model; usually format failures)
* SFT-only (Qwen2.5-3B after supervised fine-tuning)
* SFT + GRPO (the full Qubit-Medic checkpoint)
Two panels:
* Left: ``logical_correction_rate`` (y-axis 0-1, fraction of shots
where the predicted Pauli frame yields no logical-Z flip)
* Right: ``pymatching_beat_rate`` (y-axis 0-1, fraction of shots
where the model corrects but PyMatching does not)
JSON schema expected per condition file (mirrors
``scripts/eval.py::_summary``)::
{
"name": str,
"episodes": int,
"logical_correction_rate": float,
"pymatching_beat_rate": float,
... (other keys are ignored here)
}
The script never runs ``scripts.eval`` itself - it just reads JSON.
Usage::
python scripts/make_comparison_plot.py # uses defaults
python scripts/make_comparison_plot.py --eval-dir data/eval
python scripts/make_comparison_plot.py \
--random data/eval/random.json \
--base data/eval/base_qwen.json \
--sft data/eval/sft_only.json \
--grpo data/eval/sft_grpo.json \
--out figures/before_after_comparison.png
"""
from __future__ import annotations
import argparse
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional
# --------------------------------------------------------------------------- #
# Plot configuration #
# --------------------------------------------------------------------------- #
CONDITION_LABELS: tuple[str, ...] = (
"Random baseline",
"Base Qwen2.5-3B",
"SFT-only",
"SFT + GRPO",
)
# Colour-blind safe-ish palette: greys for the baselines, accent for ours.
CONDITION_COLOURS: tuple[str, ...] = (
"#9aa0a6", # random - light grey
"#5f6368", # base - dark grey
"#7e57c2", # sft - purple
"#1e88e5", # sft+grpo - blue (the "after" colour)
)
# Default filenames the script will look for inside --eval-dir if explicit
# per-condition paths are not supplied. Order matches CONDITION_LABELS.
DEFAULT_FILENAMES: tuple[str, ...] = (
"random.json",
"base_qwen.json",
"sft_only.json",
"sft_grpo.json",
)
# --------------------------------------------------------------------------- #
# Data structures #
# --------------------------------------------------------------------------- #
@dataclass(frozen=True)
class Condition:
"""One bar per panel - a single eval JSON read off disk."""
label: str
colour: str
path: Path
data: Optional[dict] # None if file missing
@property
def lcr(self) -> float:
if self.data is None:
return 0.0
return float(self.data.get("logical_correction_rate", 0.0))
@property
def beat(self) -> float:
if self.data is None:
return 0.0
return float(self.data.get("pymatching_beat_rate", 0.0))
@property
def episodes(self) -> int:
if self.data is None:
return 0
return int(self.data.get("episodes", 0))
# --------------------------------------------------------------------------- #
# I/O #
# --------------------------------------------------------------------------- #
def _load_json(path: Path) -> Optional[dict]:
"""Read a JSON file, returning ``None`` if it does not exist."""
if not path.exists():
return None
with path.open("r") as f:
return json.load(f)
def _resolve_paths(
eval_dir: Path,
explicit: dict[str, Optional[str]],
) -> list[Path]:
"""Resolve a path per condition, preferring explicit overrides."""
paths: list[Path] = []
for label, default_name in zip(CONDITION_LABELS, DEFAULT_FILENAMES):
override = explicit.get(label)
if override:
paths.append(Path(override))
else:
paths.append(eval_dir / default_name)
return paths
def load_conditions(
eval_dir: Path,
explicit: dict[str, Optional[str]],
) -> list[Condition]:
"""Materialise four ``Condition`` rows in the canonical plot order."""
paths = _resolve_paths(eval_dir, explicit)
out: list[Condition] = []
for label, colour, path in zip(CONDITION_LABELS, CONDITION_COLOURS, paths):
out.append(
Condition(
label=label,
colour=colour,
path=path,
data=_load_json(path),
)
)
return out
# --------------------------------------------------------------------------- #
# Plot #
# --------------------------------------------------------------------------- #
def render_plot(
conditions: list[Condition],
out_path: Path,
title: str,
dpi: int = 150,
) -> None:
"""Render the two-panel money plot to ``out_path`` at ``dpi``."""
try:
import matplotlib.pyplot as plt # local import: graceful failure path
except ImportError as exc: # pragma: no cover - import-time only
raise SystemExit(
"matplotlib is required for make_comparison_plot.py. "
"Install with: pip install matplotlib"
) from exc
labels = [c.label for c in conditions]
colours = [c.colour for c in conditions]
lcr_values = [c.lcr for c in conditions]
beat_values = [c.beat for c in conditions]
fig, (ax_left, ax_right) = plt.subplots(
nrows=1, ncols=2, figsize=(12, 5.2), sharey=False
)
x = list(range(len(labels)))
bars_left = ax_left.bar(x, lcr_values, color=colours, edgecolor="black",
linewidth=0.6)
ax_left.set_xticks(x)
ax_left.set_xticklabels(labels, rotation=20, ha="right")
ax_left.set_ylim(0.0, 1.0)
ax_left.set_ylabel("Logical correction rate (fraction of shots, 0-1)")
ax_left.set_xlabel("Decoder condition")
ax_left.set_title("Logical correction rate (per shot)")
ax_left.grid(axis="y", linestyle=":", alpha=0.5)
for bar, val in zip(bars_left, lcr_values):
ax_left.text(
bar.get_x() + bar.get_width() / 2,
min(val + 0.02, 0.98),
f"{val:.3f}",
ha="center", va="bottom", fontsize=9,
)
bars_right = ax_right.bar(x, beat_values, color=colours, edgecolor="black",
linewidth=0.6)
ax_right.set_xticks(x)
ax_right.set_xticklabels(labels, rotation=20, ha="right")
ax_right.set_ylim(0.0, 1.0)
ax_right.set_ylabel("PyMatching beat rate (fraction of shots, 0-1)")
ax_right.set_xlabel("Decoder condition")
ax_right.set_title("PyMatching beat rate (model corrects, PM does not)")
ax_right.grid(axis="y", linestyle=":", alpha=0.5)
for bar, val in zip(bars_right, beat_values):
ax_right.text(
bar.get_x() + bar.get_width() / 2,
min(val + 0.02, 0.98),
f"{val:.3f}",
ha="center", va="bottom", fontsize=9,
)
# One shared legend across both panels.
handles = [
plt.Rectangle((0, 0), 1, 1, color=c, ec="black", lw=0.6)
for c in colours
]
fig.legend(
handles, labels,
loc="lower center", ncol=len(labels),
bbox_to_anchor=(0.5, -0.02), frameon=False,
)
fig.suptitle(title, fontsize=13, y=1.02)
fig.tight_layout()
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path, dpi=dpi, bbox_inches="tight")
plt.close(fig)
# --------------------------------------------------------------------------- #
# CLI #
# --------------------------------------------------------------------------- #
def _missing_files_message(conditions: list[Condition]) -> str:
"""Build a helpful error when one or more eval JSONs are absent."""
missing = [(c.label, c.path) for c in conditions if c.data is None]
if not missing:
return ""
lines = [
"ERROR: cannot build comparison plot - one or more eval JSON files "
"were not found.",
"",
"Expected files (one per condition):",
]
for label, path in missing:
lines.append(f" - {label}: {path}")
lines.extend([
"",
"Generate them with scripts/eval.py, for example:",
" python -m scripts.eval --policy random --episodes 1000 \\",
" --out data/eval/random.json",
" python -m scripts.eval --base-model Qwen/Qwen2.5-3B-Instruct \\",
" --adapter '' --episodes 1000 --out data/eval/base_qwen.json",
" python -m scripts.eval --adapter checkpoints/sft/best \\",
" --episodes 1000 --out data/eval/sft_only.json",
" python -m scripts.eval --adapter checkpoints/grpo/best \\",
" --episodes 1000 --out data/eval/sft_grpo.json",
"",
"Override individual paths with --random / --base / --sft / --grpo.",
])
return "\n".join(lines)
def parse_args(argv: Iterable[str]) -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--eval-dir", type=str, default="data/eval",
help="Directory holding one JSON per condition "
"(random.json, base_qwen.json, sft_only.json, sft_grpo.json).",
)
parser.add_argument(
"--random", type=str, default=None,
help="Override path to the random-baseline eval JSON.",
)
parser.add_argument(
"--base", type=str, default=None,
help="Override path to the base-Qwen eval JSON.",
)
parser.add_argument(
"--sft", type=str, default=None,
help="Override path to the SFT-only eval JSON.",
)
parser.add_argument(
"--grpo", type=str, default=None,
help="Override path to the SFT+GRPO eval JSON.",
)
parser.add_argument(
"--out", type=str, default="figures/before_after_comparison.png",
help="Where to write the PNG (created at 150 dpi by default).",
)
parser.add_argument(
"--dpi", type=int, default=150,
help="DPI for the saved PNG.",
)
parser.add_argument(
"--title", type=str,
default=(
"Qubit-Medic decoder accuracy: before vs after RLHF training "
"(distance-3 surface code, p=0.001)"
),
help="Figure suptitle.",
)
return parser.parse_args(list(argv))
def main(argv: Iterable[str] = ()) -> int:
args = parse_args(argv)
explicit = {
"Random baseline": args.random,
"Base Qwen2.5-3B": args.base,
"SFT-only": args.sft,
"SFT + GRPO": args.grpo,
}
conditions = load_conditions(Path(args.eval_dir), explicit)
msg = _missing_files_message(conditions)
if msg:
print(msg, file=sys.stderr)
return 1
render_plot(
conditions=conditions,
out_path=Path(args.out),
title=args.title,
dpi=args.dpi,
)
print(f"Wrote comparison plot to {args.out}")
for c in conditions:
print(
f" {c.label:>18s}: LCR={c.lcr:.3f} "
f"PMbeat={c.beat:.3f} (n={c.episodes}, src={c.path})"
)
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))