Spaces:
Sleeping
Sleeping
| """scripts/make_comparison_plot.py - the "money plot" for Qubit-Medic. | |
| Renders a side-by-side bar chart comparing four conditions on the two | |
| headline metrics emitted by ``scripts.eval`` (and dumped to JSON in | |
| ``data/eval/`` or ``data/`` by the training pipeline): | |
| * Random baseline (uniform-random qubit picks) | |
| * Base Qwen2.5-3B (un-fine-tuned model; usually format failures) | |
| * SFT-only (Qwen2.5-3B after supervised fine-tuning) | |
| * SFT + GRPO (the full Qubit-Medic checkpoint) | |
| Two panels: | |
| * Left: ``logical_correction_rate`` (y-axis 0-1, fraction of shots | |
| where the predicted Pauli frame yields no logical-Z flip) | |
| * Right: ``pymatching_beat_rate`` (y-axis 0-1, fraction of shots | |
| where the model corrects but PyMatching does not) | |
| JSON schema expected per condition file (mirrors | |
| ``scripts/eval.py::_summary``):: | |
| { | |
| "name": str, | |
| "episodes": int, | |
| "logical_correction_rate": float, | |
| "pymatching_beat_rate": float, | |
| ... (other keys are ignored here) | |
| } | |
| The script never runs ``scripts.eval`` itself - it just reads JSON. | |
| Usage:: | |
| python scripts/make_comparison_plot.py # uses defaults | |
| python scripts/make_comparison_plot.py --eval-dir data/eval | |
| python scripts/make_comparison_plot.py \ | |
| --random data/eval/random.json \ | |
| --base data/eval/base_qwen.json \ | |
| --sft data/eval/sft_only.json \ | |
| --grpo data/eval/sft_grpo.json \ | |
| --out figures/before_after_comparison.png | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Iterable, Optional | |
| # --------------------------------------------------------------------------- # | |
| # Plot configuration # | |
| # --------------------------------------------------------------------------- # | |
| CONDITION_LABELS: tuple[str, ...] = ( | |
| "Random baseline", | |
| "Base Qwen2.5-3B", | |
| "SFT-only", | |
| "SFT + GRPO", | |
| ) | |
| # Colour-blind safe-ish palette: greys for the baselines, accent for ours. | |
| CONDITION_COLOURS: tuple[str, ...] = ( | |
| "#9aa0a6", # random - light grey | |
| "#5f6368", # base - dark grey | |
| "#7e57c2", # sft - purple | |
| "#1e88e5", # sft+grpo - blue (the "after" colour) | |
| ) | |
| # Default filenames the script will look for inside --eval-dir if explicit | |
| # per-condition paths are not supplied. Order matches CONDITION_LABELS. | |
| DEFAULT_FILENAMES: tuple[str, ...] = ( | |
| "random.json", | |
| "base_qwen.json", | |
| "sft_only.json", | |
| "sft_grpo.json", | |
| ) | |
| # --------------------------------------------------------------------------- # | |
| # Data structures # | |
| # --------------------------------------------------------------------------- # | |
| class Condition: | |
| """One bar per panel - a single eval JSON read off disk.""" | |
| label: str | |
| colour: str | |
| path: Path | |
| data: Optional[dict] # None if file missing | |
| def lcr(self) -> float: | |
| if self.data is None: | |
| return 0.0 | |
| return float(self.data.get("logical_correction_rate", 0.0)) | |
| def beat(self) -> float: | |
| if self.data is None: | |
| return 0.0 | |
| return float(self.data.get("pymatching_beat_rate", 0.0)) | |
| def episodes(self) -> int: | |
| if self.data is None: | |
| return 0 | |
| return int(self.data.get("episodes", 0)) | |
| # --------------------------------------------------------------------------- # | |
| # I/O # | |
| # --------------------------------------------------------------------------- # | |
| def _load_json(path: Path) -> Optional[dict]: | |
| """Read a JSON file, returning ``None`` if it does not exist.""" | |
| if not path.exists(): | |
| return None | |
| with path.open("r") as f: | |
| return json.load(f) | |
| def _resolve_paths( | |
| eval_dir: Path, | |
| explicit: dict[str, Optional[str]], | |
| ) -> list[Path]: | |
| """Resolve a path per condition, preferring explicit overrides.""" | |
| paths: list[Path] = [] | |
| for label, default_name in zip(CONDITION_LABELS, DEFAULT_FILENAMES): | |
| override = explicit.get(label) | |
| if override: | |
| paths.append(Path(override)) | |
| else: | |
| paths.append(eval_dir / default_name) | |
| return paths | |
| def load_conditions( | |
| eval_dir: Path, | |
| explicit: dict[str, Optional[str]], | |
| ) -> list[Condition]: | |
| """Materialise four ``Condition`` rows in the canonical plot order.""" | |
| paths = _resolve_paths(eval_dir, explicit) | |
| out: list[Condition] = [] | |
| for label, colour, path in zip(CONDITION_LABELS, CONDITION_COLOURS, paths): | |
| out.append( | |
| Condition( | |
| label=label, | |
| colour=colour, | |
| path=path, | |
| data=_load_json(path), | |
| ) | |
| ) | |
| return out | |
| # --------------------------------------------------------------------------- # | |
| # Plot # | |
| # --------------------------------------------------------------------------- # | |
| def render_plot( | |
| conditions: list[Condition], | |
| out_path: Path, | |
| title: str, | |
| dpi: int = 150, | |
| ) -> None: | |
| """Render the two-panel money plot to ``out_path`` at ``dpi``.""" | |
| try: | |
| import matplotlib.pyplot as plt # local import: graceful failure path | |
| except ImportError as exc: # pragma: no cover - import-time only | |
| raise SystemExit( | |
| "matplotlib is required for make_comparison_plot.py. " | |
| "Install with: pip install matplotlib" | |
| ) from exc | |
| labels = [c.label for c in conditions] | |
| colours = [c.colour for c in conditions] | |
| lcr_values = [c.lcr for c in conditions] | |
| beat_values = [c.beat for c in conditions] | |
| fig, (ax_left, ax_right) = plt.subplots( | |
| nrows=1, ncols=2, figsize=(12, 5.2), sharey=False | |
| ) | |
| x = list(range(len(labels))) | |
| bars_left = ax_left.bar(x, lcr_values, color=colours, edgecolor="black", | |
| linewidth=0.6) | |
| ax_left.set_xticks(x) | |
| ax_left.set_xticklabels(labels, rotation=20, ha="right") | |
| ax_left.set_ylim(0.0, 1.0) | |
| ax_left.set_ylabel("Logical correction rate (fraction of shots, 0-1)") | |
| ax_left.set_xlabel("Decoder condition") | |
| ax_left.set_title("Logical correction rate (per shot)") | |
| ax_left.grid(axis="y", linestyle=":", alpha=0.5) | |
| for bar, val in zip(bars_left, lcr_values): | |
| ax_left.text( | |
| bar.get_x() + bar.get_width() / 2, | |
| min(val + 0.02, 0.98), | |
| f"{val:.3f}", | |
| ha="center", va="bottom", fontsize=9, | |
| ) | |
| bars_right = ax_right.bar(x, beat_values, color=colours, edgecolor="black", | |
| linewidth=0.6) | |
| ax_right.set_xticks(x) | |
| ax_right.set_xticklabels(labels, rotation=20, ha="right") | |
| ax_right.set_ylim(0.0, 1.0) | |
| ax_right.set_ylabel("PyMatching beat rate (fraction of shots, 0-1)") | |
| ax_right.set_xlabel("Decoder condition") | |
| ax_right.set_title("PyMatching beat rate (model corrects, PM does not)") | |
| ax_right.grid(axis="y", linestyle=":", alpha=0.5) | |
| for bar, val in zip(bars_right, beat_values): | |
| ax_right.text( | |
| bar.get_x() + bar.get_width() / 2, | |
| min(val + 0.02, 0.98), | |
| f"{val:.3f}", | |
| ha="center", va="bottom", fontsize=9, | |
| ) | |
| # One shared legend across both panels. | |
| handles = [ | |
| plt.Rectangle((0, 0), 1, 1, color=c, ec="black", lw=0.6) | |
| for c in colours | |
| ] | |
| fig.legend( | |
| handles, labels, | |
| loc="lower center", ncol=len(labels), | |
| bbox_to_anchor=(0.5, -0.02), frameon=False, | |
| ) | |
| fig.suptitle(title, fontsize=13, y=1.02) | |
| fig.tight_layout() | |
| out_path.parent.mkdir(parents=True, exist_ok=True) | |
| fig.savefig(out_path, dpi=dpi, bbox_inches="tight") | |
| plt.close(fig) | |
| # --------------------------------------------------------------------------- # | |
| # CLI # | |
| # --------------------------------------------------------------------------- # | |
| def _missing_files_message(conditions: list[Condition]) -> str: | |
| """Build a helpful error when one or more eval JSONs are absent.""" | |
| missing = [(c.label, c.path) for c in conditions if c.data is None] | |
| if not missing: | |
| return "" | |
| lines = [ | |
| "ERROR: cannot build comparison plot - one or more eval JSON files " | |
| "were not found.", | |
| "", | |
| "Expected files (one per condition):", | |
| ] | |
| for label, path in missing: | |
| lines.append(f" - {label}: {path}") | |
| lines.extend([ | |
| "", | |
| "Generate them with scripts/eval.py, for example:", | |
| " python -m scripts.eval --policy random --episodes 1000 \\", | |
| " --out data/eval/random.json", | |
| " python -m scripts.eval --base-model Qwen/Qwen2.5-3B-Instruct \\", | |
| " --adapter '' --episodes 1000 --out data/eval/base_qwen.json", | |
| " python -m scripts.eval --adapter checkpoints/sft/best \\", | |
| " --episodes 1000 --out data/eval/sft_only.json", | |
| " python -m scripts.eval --adapter checkpoints/grpo/best \\", | |
| " --episodes 1000 --out data/eval/sft_grpo.json", | |
| "", | |
| "Override individual paths with --random / --base / --sft / --grpo.", | |
| ]) | |
| return "\n".join(lines) | |
| def parse_args(argv: Iterable[str]) -> argparse.Namespace: | |
| parser = argparse.ArgumentParser(description=__doc__) | |
| parser.add_argument( | |
| "--eval-dir", type=str, default="data/eval", | |
| help="Directory holding one JSON per condition " | |
| "(random.json, base_qwen.json, sft_only.json, sft_grpo.json).", | |
| ) | |
| parser.add_argument( | |
| "--random", type=str, default=None, | |
| help="Override path to the random-baseline eval JSON.", | |
| ) | |
| parser.add_argument( | |
| "--base", type=str, default=None, | |
| help="Override path to the base-Qwen eval JSON.", | |
| ) | |
| parser.add_argument( | |
| "--sft", type=str, default=None, | |
| help="Override path to the SFT-only eval JSON.", | |
| ) | |
| parser.add_argument( | |
| "--grpo", type=str, default=None, | |
| help="Override path to the SFT+GRPO eval JSON.", | |
| ) | |
| parser.add_argument( | |
| "--out", type=str, default="figures/before_after_comparison.png", | |
| help="Where to write the PNG (created at 150 dpi by default).", | |
| ) | |
| parser.add_argument( | |
| "--dpi", type=int, default=150, | |
| help="DPI for the saved PNG.", | |
| ) | |
| parser.add_argument( | |
| "--title", type=str, | |
| default=( | |
| "Qubit-Medic decoder accuracy: before vs after RLHF training " | |
| "(distance-3 surface code, p=0.001)" | |
| ), | |
| help="Figure suptitle.", | |
| ) | |
| return parser.parse_args(list(argv)) | |
| def main(argv: Iterable[str] = ()) -> int: | |
| args = parse_args(argv) | |
| explicit = { | |
| "Random baseline": args.random, | |
| "Base Qwen2.5-3B": args.base, | |
| "SFT-only": args.sft, | |
| "SFT + GRPO": args.grpo, | |
| } | |
| conditions = load_conditions(Path(args.eval_dir), explicit) | |
| msg = _missing_files_message(conditions) | |
| if msg: | |
| print(msg, file=sys.stderr) | |
| return 1 | |
| render_plot( | |
| conditions=conditions, | |
| out_path=Path(args.out), | |
| title=args.title, | |
| dpi=args.dpi, | |
| ) | |
| print(f"Wrote comparison plot to {args.out}") | |
| for c in conditions: | |
| print( | |
| f" {c.label:>18s}: LCR={c.lcr:.3f} " | |
| f"PMbeat={c.beat:.3f} (n={c.episodes}, src={c.path})" | |
| ) | |
| return 0 | |
| if __name__ == "__main__": | |
| sys.exit(main(sys.argv[1:])) | |