QuantumScribe / scripts /comparison_table.py
ronitraj's picture
Upload scripts/comparison_table.py with huggingface_hub
be6d6a7 verified
raw
history blame
10.7 kB
"""scripts/comparison_table.py - Literature comparison table generator.
Compliance Section 9 (audit, 2026-04): emit a Markdown table at
``results/comparison_table.md`` comparing the trained Qubit-Medic model
against:
* The untrained ``Qwen/Qwen2.5-3B-Instruct`` baseline (loaded from
``--baseline-json`` written by ``scripts.eval --policy zeros`` or
``--policy random``, since the untrained model itself collapses to
format failures).
* PyMatching v2 (Higgott & Gidney 2023, arXiv:2303.15933) reference
LER ~ 3.0e-2 per cycle at distance-3, p=0.001 (the canonical decoder).
* AlphaQubit reference LER ~ 2.7e-2 per cycle at distance-3, p=0.001
(Bausch et al., *Nature* 635:834, 2024,
doi:10.1038/s41586-024-08148-8).
Inputs are JSON dumps written by ``scripts.eval``; the schema mirrors
``_summary()`` in that module. Required keys per file:
{
"name": str,
"episodes": int,
"logical_correction_rate": float,
"pymatching_beat_rate": float,
"format_compliance_rate": float,
"exact_match_pymatching": float,
"mean_total_reward": float,
... optionally "ler_per_round", "ler_per_round_log10", "level"
}
Usage::
# 1. Run model + baseline evals first
python -m scripts.eval --adapter checkpoints/grpo/best \
--episodes 1000 --out data/eval_grpo.json
python -m scripts.eval --policy pymatching \
--episodes 1000 --out data/eval_pymatching.json
# 2. Build the comparison table
python -m scripts.comparison_table \
--eval-json data/eval_grpo.json \
--baseline-json data/eval_pymatching.json \
--output results/comparison_table.md
"""
from __future__ import annotations
import argparse
import json
import math
import sys
import time
from pathlib import Path
from typing import Iterable, Optional
# --------------------------------------------------------------------------- #
# Literature reference values (locked at audit time, 2026-04). #
# --------------------------------------------------------------------------- #
# Both numbers are reported at distance-3, p ~ 1e-3, rotated surface code,
# Z-memory experiment. Sources:
#
# * PyMatching v2: Higgott & Gidney, "Sparse Blossom" (PyMatching v2),
# arXiv:2303.15933 (2023). LER ~ 3.0e-2 per round on the distance-3
# SI1000 benchmark at p=0.001.
# * AlphaQubit (Bausch et al., Nature 635:834, 2024,
# doi:10.1038/s41586-024-08148-8). The two-stage decoder hits ~2.7e-2
# per round at distance-3 on the same benchmark, beating PyMatching by
# ~10% relative.
# --------------------------------------------------------------------------- #
_PYMATCHING_REFERENCE = {
"name": "PyMatching v2 (Higgott & Gidney 2023)",
"ler_per_round": 3.0e-2,
"logical_correction_rate": None, # not directly comparable - LCR is per shot
"citation": "arXiv:2303.15933",
}
_ALPHAQUBIT_REFERENCE = {
"name": "AlphaQubit (Bausch et al. 2024)",
"ler_per_round": 2.7e-2,
"logical_correction_rate": None,
"citation": "Nature 635:834 (2024), doi:10.1038/s41586-024-08148-8",
}
# --------------------------------------------------------------------------- #
# Helpers #
# --------------------------------------------------------------------------- #
def _load(path: Optional[str]) -> Optional[dict]:
if path is None:
return None
p = Path(path)
if not p.exists():
print(f"WARNING: {p} does not exist; skipping that column",
file=sys.stderr)
return None
with p.open("r") as f:
return json.load(f)
def _fmt_pct(x: Optional[float], digits: int = 2) -> str:
if x is None:
return "—"
try:
return f"{float(x) * 100:.{digits}f}%"
except (TypeError, ValueError):
return "—"
def _fmt_sci(x: Optional[float], digits: int = 2) -> str:
if x is None:
return "—"
try:
v = float(x)
if v <= 0:
return "—"
exp = int(math.floor(math.log10(v)))
mantissa = v / (10 ** exp)
return f"{mantissa:.{digits}f}e{exp:+d}"
except (TypeError, ValueError):
return "—"
def _row(label: str, values: list[str]) -> str:
return "| " + " | ".join([label] + values) + " |"
def _sep(n: int) -> str:
return "|" + "|".join(["---"] * n) + "|"
# --------------------------------------------------------------------------- #
# Table builder #
# --------------------------------------------------------------------------- #
def build_table(model_eval: dict, baseline_eval: Optional[dict],
level: str = "L2_target") -> str:
"""Assemble the Markdown table.
Columns are: metric, model, baseline (if provided), PyMatching v2,
AlphaQubit. The two literature columns only carry the LER row.
"""
cols = ["Metric", "Trained Qubit-Medic"]
if baseline_eval is not None:
cols.append(f"Baseline ({baseline_eval.get('name', 'baseline')})")
cols.append("PyMatching v2 (lit.)")
cols.append("AlphaQubit (lit.)")
out_lines = [
"# Qubit-Medic literature comparison",
"",
f"_Generated: {time.strftime('%Y-%m-%d %H:%M:%S UTC', time.gmtime())}_",
"",
f"_Distance-3 rotated surface code, Z-memory experiment, "
f"SI1000 noise, p ~ 1e-3, level={level}._",
"",
"References:",
f"- PyMatching v2: {_PYMATCHING_REFERENCE['citation']}",
f"- AlphaQubit: {_ALPHAQUBIT_REFERENCE['citation']}",
"",
"| " + " | ".join(cols) + " |",
_sep(len(cols)),
]
# Per-shot logical correction rate (the headline binary metric).
row_vals = [_fmt_pct(model_eval.get("logical_correction_rate"))]
if baseline_eval is not None:
row_vals.append(_fmt_pct(baseline_eval.get("logical_correction_rate")))
row_vals.extend(["—", "—"])
out_lines.append(_row("logical_correction_rate (per shot)", row_vals))
# Per-round logical error rate (the literature-comparable metric).
model_ler = model_eval.get("ler_per_round")
base_ler = baseline_eval.get("ler_per_round") if baseline_eval else None
row_vals = [_fmt_sci(model_ler)]
if baseline_eval is not None:
row_vals.append(_fmt_sci(base_ler))
row_vals.append(_fmt_sci(_PYMATCHING_REFERENCE["ler_per_round"]))
row_vals.append(_fmt_sci(_ALPHAQUBIT_REFERENCE["ler_per_round"]))
out_lines.append(_row("ler_per_round (logical errors / cycle)", row_vals))
# PyMatching beat-rate: how often the model wins where PM was wrong.
row_vals = [_fmt_pct(model_eval.get("pymatching_beat_rate"))]
if baseline_eval is not None:
row_vals.append(_fmt_pct(baseline_eval.get("pymatching_beat_rate")))
row_vals.extend(["0.00%", "—"])
out_lines.append(_row("pymatching_beat_rate", row_vals))
# Format compliance.
row_vals = [_fmt_pct(model_eval.get("format_compliance_rate"))]
if baseline_eval is not None:
row_vals.append(_fmt_pct(baseline_eval.get("format_compliance_rate")))
row_vals.extend(["—", "—"])
out_lines.append(_row("format_compliance_rate", row_vals))
# Exact match against PyMatching (as a "convergence to baseline" signal).
row_vals = [_fmt_pct(model_eval.get("exact_match_pymatching"))]
if baseline_eval is not None:
row_vals.append(_fmt_pct(baseline_eval.get("exact_match_pymatching")))
row_vals.extend(["100.00%", "—"])
out_lines.append(_row("exact_match_pymatching", row_vals))
# Mean total reward (aggregate scalar; useful for sanity).
mtr = model_eval.get("mean_total_reward")
row_vals = [f"{mtr:.3f}" if mtr is not None else "—"]
if baseline_eval is not None:
bmtr = baseline_eval.get("mean_total_reward")
row_vals.append(f"{bmtr:.3f}" if bmtr is not None else "—")
row_vals.extend(["—", "—"])
out_lines.append(_row("mean_total_reward", row_vals))
out_lines.append("")
out_lines.append("## Notes")
out_lines.append("")
out_lines.append(
"- LER values for PyMatching v2 and AlphaQubit are taken verbatim "
"from the cited papers at distance-3, p~1e-3 SI1000 noise. They "
"are reproduction targets, not numbers we re-measured here."
)
out_lines.append(
"- A trained Qubit-Medic ler_per_round below 3.0e-2 means we are "
"matching or beating the canonical PyMatching reference at this "
"noise budget; below 2.7e-2 we are matching AlphaQubit's published "
"two-stage decoder (Bausch et al., Nature 2024)."
)
out_lines.append(
"- pymatching_beat_rate is exactly 0% by construction for "
"PyMatching itself (it cannot beat itself). It is shown only "
"to make the trained-model column meaningful."
)
out_lines.append("")
return "\n".join(out_lines)
# --------------------------------------------------------------------------- #
# Main #
# --------------------------------------------------------------------------- #
def main(argv: Iterable[str] = ()) -> int:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--eval-json", type=str, default="data/eval_grpo.json",
help="JSON output from scripts.eval for the trained model.",
)
parser.add_argument(
"--baseline-json", type=str, default=None,
help="Optional JSON from scripts.eval for an untrained / "
"baseline policy column. Skipped if missing.",
)
parser.add_argument(
"--output", type=str, default="results/comparison_table.md",
help="Markdown file to write.",
)
parser.add_argument(
"--level", type=str, default="L2_target",
help="Curriculum level the comparison was run on (used in the "
"table header only; values come from --eval-json).",
)
args = parser.parse_args(list(argv))
model = _load(args.eval_json)
if model is None:
print(f"ERROR: --eval-json {args.eval_json} not found; cannot "
f"build comparison table.", file=sys.stderr)
return 1
baseline = _load(args.baseline_json)
md = build_table(model, baseline, level=args.level)
out = Path(args.output)
out.parent.mkdir(parents=True, exist_ok=True)
out.write_text(md)
print(f"Wrote literature comparison table to {out}")
print()
print(md)
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))