ronitraj committed on
Commit be6d6a7 · verified
1 Parent(s): 74d70f5

Upload scripts/comparison_table.py with huggingface_hub

Files changed (1)
  1. scripts/comparison_table.py +286 -0
scripts/comparison_table.py ADDED
@@ -0,0 +1,286 @@
+"""scripts/comparison_table.py - Literature comparison table generator.
+
+Compliance Section 9 (audit, 2026-04): emit a Markdown table at
+``results/comparison_table.md`` comparing the trained Qubit-Medic model
+against:
+
+* The untrained ``Qwen/Qwen2.5-3B-Instruct`` baseline (loaded from
+  ``--baseline-json`` written by ``scripts.eval --policy zeros`` or
+  ``--policy random``, since the untrained model itself collapses to
+  format failures).
+* PyMatching v2 (Higgott & Gidney 2023, arXiv:2303.15933) reference
+  LER ~ 3.0e-2 per cycle at distance-3, p=0.001 (the canonical decoder).
+* AlphaQubit reference LER ~ 2.7e-2 per cycle at distance-3, p=0.001
+  (Bausch et al., *Nature* 635:834, 2024,
+  doi:10.1038/s41586-024-08148-8).
+
+Inputs are JSON dumps written by ``scripts.eval``; the schema mirrors
+``_summary()`` in that module. Required keys per file:
+
+    {
+        "name": str,
+        "episodes": int,
+        "logical_correction_rate": float,
+        "pymatching_beat_rate": float,
+        "format_compliance_rate": float,
+        "exact_match_pymatching": float,
+        "mean_total_reward": float,
+        ... optionally "ler_per_round", "ler_per_round_log10", "level"
+    }
+
+Usage::
+
+    # 1. Run model + baseline evals first
+    python -m scripts.eval --adapter checkpoints/grpo/best \
+        --episodes 1000 --out data/eval_grpo.json
+    python -m scripts.eval --policy pymatching \
+        --episodes 1000 --out data/eval_pymatching.json
+
+    # 2. Build the comparison table
+    python -m scripts.comparison_table \
+        --eval-json data/eval_grpo.json \
+        --baseline-json data/eval_pymatching.json \
+        --output results/comparison_table.md
+"""
+from __future__ import annotations
+
+import argparse
+import json
+import math
+import sys
+import time
+from pathlib import Path
+from typing import Iterable, Optional
+
+
+# --------------------------------------------------------------------------- #
+# Literature reference values (locked at audit time, 2026-04).                 #
+# --------------------------------------------------------------------------- #
+# Both numbers are reported at distance-3, p ~ 1e-3, rotated surface code,
+# Z-memory experiment. Sources:
+#
+# * PyMatching v2: Higgott & Gidney, "Sparse Blossom" (PyMatching v2),
+#   arXiv:2303.15933 (2023). LER ~ 3.0e-2 per round on the distance-3
+#   SI1000 benchmark at p=0.001.
+# * AlphaQubit (Bausch et al., Nature 635:834, 2024,
+#   doi:10.1038/s41586-024-08148-8). The two-stage decoder hits ~2.7e-2
+#   per round at distance-3 on the same benchmark, beating PyMatching by
+#   ~10% relative.
+# --------------------------------------------------------------------------- #
+
+
+_PYMATCHING_REFERENCE = {
+    "name": "PyMatching v2 (Higgott & Gidney 2023)",
+    "ler_per_round": 3.0e-2,
+    "logical_correction_rate": None,  # not directly comparable - LCR is per shot
+    "citation": "arXiv:2303.15933",
+}
+
+_ALPHAQUBIT_REFERENCE = {
+    "name": "AlphaQubit (Bausch et al. 2024)",
+    "ler_per_round": 2.7e-2,
+    "logical_correction_rate": None,
+    "citation": "Nature 635:834 (2024), doi:10.1038/s41586-024-08148-8",
+}
+
+
+# --------------------------------------------------------------------------- #
+# Helpers                                                                       #
+# --------------------------------------------------------------------------- #
+
+
+def _load(path: Optional[str]) -> Optional[dict]:
+    if path is None:
+        return None
+    p = Path(path)
+    if not p.exists():
+        print(f"WARNING: {p} does not exist; skipping that column",
+              file=sys.stderr)
+        return None
+    with p.open("r") as f:
+        return json.load(f)
+
+
+def _fmt_pct(x: Optional[float], digits: int = 2) -> str:
+    if x is None:
+        return "—"
+    try:
+        return f"{float(x) * 100:.{digits}f}%"
+    except (TypeError, ValueError):
+        return "—"
+
+
+def _fmt_sci(x: Optional[float], digits: int = 2) -> str:
+    if x is None:
+        return "—"
+    try:
+        v = float(x)
+        if v <= 0:
+            return "—"
+        exp = int(math.floor(math.log10(v)))
+        mantissa = v / (10 ** exp)
+        return f"{mantissa:.{digits}f}e{exp:+d}"
+    except (TypeError, ValueError):
+        return "—"
+
+
+def _row(label: str, values: list[str]) -> str:
+    return "| " + " | ".join([label] + values) + " |"
+
+
+def _sep(n: int) -> str:
+    return "|" + "|".join(["---"] * n) + "|"
+
+
+# --------------------------------------------------------------------------- #
+# Table builder                                                                 #
+# --------------------------------------------------------------------------- #
+
+
+def build_table(model_eval: dict, baseline_eval: Optional[dict],
+                level: str = "L2_target") -> str:
+    """Assemble the Markdown table.
+
+    Columns are: metric, model, baseline (if provided), PyMatching v2,
+    AlphaQubit. The two literature columns only carry the LER row.
+    """
+    cols = ["Metric", "Trained Qubit-Medic"]
+    if baseline_eval is not None:
+        cols.append(f"Baseline ({baseline_eval.get('name', 'baseline')})")
+    cols.append("PyMatching v2 (lit.)")
+    cols.append("AlphaQubit (lit.)")
+
+    out_lines = [
+        "# Qubit-Medic literature comparison",
+        "",
+        f"_Generated: {time.strftime('%Y-%m-%d %H:%M:%S UTC', time.gmtime())}_",
+        "",
+        f"_Distance-3 rotated surface code, Z-memory experiment, "
+        f"SI1000 noise, p ~ 1e-3, level={level}._",
+        "",
+        "References:",
+        f"- PyMatching v2: {_PYMATCHING_REFERENCE['citation']}",
+        f"- AlphaQubit: {_ALPHAQUBIT_REFERENCE['citation']}",
+        "",
+        "| " + " | ".join(cols) + " |",
+        _sep(len(cols)),
+    ]
+
+    # Per-shot logical correction rate (the headline binary metric).
+    row_vals = [_fmt_pct(model_eval.get("logical_correction_rate"))]
+    if baseline_eval is not None:
+        row_vals.append(_fmt_pct(baseline_eval.get("logical_correction_rate")))
+    row_vals.extend(["—", "—"])
+    out_lines.append(_row("logical_correction_rate (per shot)", row_vals))
+
+    # Per-round logical error rate (the literature-comparable metric).
+    model_ler = model_eval.get("ler_per_round")
+    base_ler = baseline_eval.get("ler_per_round") if baseline_eval else None
+    row_vals = [_fmt_sci(model_ler)]
+    if baseline_eval is not None:
+        row_vals.append(_fmt_sci(base_ler))
+    row_vals.append(_fmt_sci(_PYMATCHING_REFERENCE["ler_per_round"]))
+    row_vals.append(_fmt_sci(_ALPHAQUBIT_REFERENCE["ler_per_round"]))
+    out_lines.append(_row("ler_per_round (logical errors / cycle)", row_vals))
+
+    # PyMatching beat-rate: how often the model wins where PM was wrong.
+    row_vals = [_fmt_pct(model_eval.get("pymatching_beat_rate"))]
+    if baseline_eval is not None:
+        row_vals.append(_fmt_pct(baseline_eval.get("pymatching_beat_rate")))
+    row_vals.extend(["0.00%", "—"])
+    out_lines.append(_row("pymatching_beat_rate", row_vals))
+
+    # Format compliance.
+    row_vals = [_fmt_pct(model_eval.get("format_compliance_rate"))]
+    if baseline_eval is not None:
+        row_vals.append(_fmt_pct(baseline_eval.get("format_compliance_rate")))
+    row_vals.extend(["—", "—"])
+    out_lines.append(_row("format_compliance_rate", row_vals))
+
+    # Exact match against PyMatching (as a "convergence to baseline" signal).
+    row_vals = [_fmt_pct(model_eval.get("exact_match_pymatching"))]
+    if baseline_eval is not None:
+        row_vals.append(_fmt_pct(baseline_eval.get("exact_match_pymatching")))
+    row_vals.extend(["100.00%", "—"])
+    out_lines.append(_row("exact_match_pymatching", row_vals))
+
+    # Mean total reward (aggregate scalar; useful for sanity).
+    mtr = model_eval.get("mean_total_reward")
+    row_vals = [f"{mtr:.3f}" if mtr is not None else "—"]
+    if baseline_eval is not None:
+        bmtr = baseline_eval.get("mean_total_reward")
+        row_vals.append(f"{bmtr:.3f}" if bmtr is not None else "—")
+    row_vals.extend(["—", "—"])
+    out_lines.append(_row("mean_total_reward", row_vals))
+
+    out_lines.append("")
+    out_lines.append("## Notes")
+    out_lines.append("")
+    out_lines.append(
+        "- LER values for PyMatching v2 and AlphaQubit are taken verbatim "
+        "from the cited papers at distance-3, p~1e-3 SI1000 noise. They "
+        "are reproduction targets, not numbers we re-measured here."
+    )
+    out_lines.append(
+        "- A trained Qubit-Medic ler_per_round below 3.0e-2 means we are "
+        "matching or beating the canonical PyMatching reference at this "
+        "noise budget; below 2.7e-2 we are matching AlphaQubit's published "
+        "two-stage decoder (Bausch et al., Nature 2024)."
+    )
+    out_lines.append(
+        "- pymatching_beat_rate is exactly 0% by construction for "
+        "PyMatching itself (it cannot beat itself). It is shown only "
+        "to make the trained-model column meaningful."
+    )
+    out_lines.append("")
+
+    return "\n".join(out_lines)
+
+
+# --------------------------------------------------------------------------- #
+# Main                                                                          #
+# --------------------------------------------------------------------------- #
+
+
+def main(argv: Iterable[str] = ()) -> int:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        "--eval-json", type=str, default="data/eval_grpo.json",
+        help="JSON output from scripts.eval for the trained model.",
+    )
+    parser.add_argument(
+        "--baseline-json", type=str, default=None,
+        help="Optional JSON from scripts.eval for an untrained / "
+             "baseline policy column. Skipped if missing.",
+    )
+    parser.add_argument(
+        "--output", type=str, default="results/comparison_table.md",
+        help="Markdown file to write.",
+    )
+    parser.add_argument(
+        "--level", type=str, default="L2_target",
+        help="Curriculum level the comparison was run on (used in the "
+             "table header only; values come from --eval-json).",
+    )
+    args = parser.parse_args(list(argv))
+
+    model = _load(args.eval_json)
+    if model is None:
+        print(f"ERROR: --eval-json {args.eval_json} not found; cannot "
+              f"build comparison table.", file=sys.stderr)
+        return 1
+    baseline = _load(args.baseline_json)
+
+    md = build_table(model, baseline, level=args.level)
+
+    out = Path(args.output)
+    out.parent.mkdir(parents=True, exist_ok=True)
+    out.write_text(md)
+    print(f"Wrote literature comparison table to {out}")
+    print()
+    print(md)
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main(sys.argv[1:]))