ronitraj committed on
Commit
7dc2fe6
·
verified ·
1 Parent(s): be6d6a7

Upload scripts/make_comparison_plot.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/make_comparison_plot.py +355 -0
scripts/make_comparison_plot.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """scripts/make_comparison_plot.py - the "money plot" for Qubit-Medic.
2
+
3
+ Renders a side-by-side bar chart comparing four conditions on the two
4
+ headline metrics emitted by ``scripts.eval`` (and dumped to JSON in
5
+ ``data/eval/`` or ``data/`` by the training pipeline):
6
+
7
+ * Random baseline (uniform-random qubit picks)
8
+ * Base Qwen2.5-3B (un-fine-tuned model; usually format failures)
9
+ * SFT-only (Qwen2.5-3B after supervised fine-tuning)
10
+ * SFT + GRPO (the full Qubit-Medic checkpoint)
11
+
12
+ Two panels:
13
+
14
+ * Left: ``logical_correction_rate`` (y-axis 0-1, fraction of shots
15
+ where the predicted Pauli frame yields no logical-Z flip)
16
+ * Right: ``pymatching_beat_rate`` (y-axis 0-1, fraction of shots
17
+ where the model corrects but PyMatching does not)
18
+
19
+ JSON schema expected per condition file (mirrors
20
+ ``scripts/eval.py::_summary``)::
21
+
22
+ {
23
+ "name": str,
24
+ "episodes": int,
25
+ "logical_correction_rate": float,
26
+ "pymatching_beat_rate": float,
27
+ ... (other keys are ignored here)
28
+ }
29
+
30
+ The script never runs ``scripts.eval`` itself - it just reads JSON.
31
+
32
+ Usage::
33
+
34
+ python scripts/make_comparison_plot.py # uses defaults
35
+ python scripts/make_comparison_plot.py --eval-dir data/eval
36
+ python scripts/make_comparison_plot.py \
37
+ --random data/eval/random.json \
38
+ --base data/eval/base_qwen.json \
39
+ --sft data/eval/sft_only.json \
40
+ --grpo data/eval/sft_grpo.json \
41
+ --out figures/before_after_comparison.png
42
+ """
43
+ from __future__ import annotations
44
+
45
+ import argparse
46
+ import json
47
+ import sys
48
+ from dataclasses import dataclass
49
+ from pathlib import Path
50
+ from typing import Iterable, Optional
51
+
52
+
53
+ # --------------------------------------------------------------------------- #
54
+ # Plot configuration #
55
+ # --------------------------------------------------------------------------- #
56
+
57
# One row per plotted condition, in canonical left-to-right bar order:
# (display label, bar colour, default JSON filename inside --eval-dir).
# Colour-blind safe-ish palette: greys for the baselines, accent for ours.
_CONDITION_SPECS: tuple[tuple[str, str, str], ...] = (
    ("Random baseline", "#9aa0a6", "random.json"),     # light grey
    ("Base Qwen2.5-3B", "#5f6368", "base_qwen.json"),  # dark grey
    ("SFT-only", "#7e57c2", "sft_only.json"),          # purple
    ("SFT + GRPO", "#1e88e5", "sft_grpo.json"),        # blue (the "after" colour)
)

# Display labels, one per condition.
CONDITION_LABELS: tuple[str, ...] = tuple(row[0] for row in _CONDITION_SPECS)

# Bar colours, index-aligned with CONDITION_LABELS.
CONDITION_COLOURS: tuple[str, ...] = tuple(row[1] for row in _CONDITION_SPECS)

# Filenames the script looks for inside --eval-dir when no explicit
# per-condition path is supplied; index-aligned with CONDITION_LABELS.
DEFAULT_FILENAMES: tuple[str, ...] = tuple(row[2] for row in _CONDITION_SPECS)
80
+
81
+
82
+ # --------------------------------------------------------------------------- #
83
+ # Data structures #
84
+ # --------------------------------------------------------------------------- #
85
+
86
+
87
@dataclass(frozen=True)
class Condition:
    """A single plotted condition backed by one eval JSON file.

    ``data`` holds the parsed JSON payload, or ``None`` when the file was
    missing. Every metric property reports zero when there is no payload or
    when the payload lacks the corresponding key.
    """

    label: str  # human-readable condition name
    colour: str  # colour string associated with this condition
    path: Path  # location the eval JSON was looked up at
    data: Optional[dict]  # parsed payload; None if the file was missing

    @property
    def lcr(self) -> float:
        """Return ``logical_correction_rate`` as a float (default 0.0)."""
        return (
            0.0
            if self.data is None
            else float(self.data.get("logical_correction_rate", 0.0))
        )

    @property
    def beat(self) -> float:
        """Return ``pymatching_beat_rate`` as a float (default 0.0)."""
        return (
            0.0
            if self.data is None
            else float(self.data.get("pymatching_beat_rate", 0.0))
        )

    @property
    def episodes(self) -> int:
        """Return the ``episodes`` count as an int (default 0)."""
        return 0 if self.data is None else int(self.data.get("episodes", 0))
113
+
114
+
115
+ # --------------------------------------------------------------------------- #
116
+ # I/O #
117
+ # --------------------------------------------------------------------------- #
118
+
119
+
120
+ def _load_json(path: Path) -> Optional[dict]:
121
+ """Read a JSON file, returning ``None`` if it does not exist."""
122
+ if not path.exists():
123
+ return None
124
+ with path.open("r") as f:
125
+ return json.load(f)
126
+
127
+
128
def _resolve_paths(
    eval_dir: Path,
    explicit: dict[str, Optional[str]],
) -> list[Path]:
    """Pick the JSON path for every condition, in ``CONDITION_LABELS`` order.

    A truthy entry in ``explicit`` (keyed by condition label) wins; otherwise
    the condition's default filename is looked up inside ``eval_dir``.
    """
    resolved: list[Path] = []
    for label, fallback_name in zip(CONDITION_LABELS, DEFAULT_FILENAMES):
        chosen = explicit.get(label)
        # None and "" both mean "no override supplied".
        resolved.append(Path(chosen) if chosen else eval_dir / fallback_name)
    return resolved
141
+
142
+
143
def load_conditions(
    eval_dir: Path,
    explicit: dict[str, Optional[str]],
) -> list[Condition]:
    """Build one ``Condition`` per label, in the canonical plot order.

    Paths come from ``_resolve_paths``; each is read with ``_load_json``, so
    a condition whose file is missing ends up with ``data=None``.
    """
    resolved = _resolve_paths(eval_dir, explicit)
    return [
        Condition(label=name, colour=shade, path=src, data=_load_json(src))
        for name, shade, src in zip(CONDITION_LABELS, CONDITION_COLOURS, resolved)
    ]
160
+
161
+
162
+ # --------------------------------------------------------------------------- #
163
+ # Plot #
164
+ # --------------------------------------------------------------------------- #
165
+
166
+
167
def _draw_metric_panel(ax, values, labels, colours, ylabel, title) -> None:
    """Draw one bar-chart panel (one bar per condition) onto ``ax``.

    ``values``, ``labels`` and ``colours`` are index-aligned sequences. The
    y-axis is fixed to 0-1 and each bar is annotated with its value.
    """
    x = list(range(len(labels)))
    bars = ax.bar(x, values, color=colours, edgecolor="black", linewidth=0.6)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=20, ha="right")
    ax.set_ylim(0.0, 1.0)
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Decoder condition")
    ax.set_title(title)
    ax.grid(axis="y", linestyle=":", alpha=0.5)
    for bar, val in zip(bars, values):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            # Cap the annotation height so it stays near the top of the axes
            # for values close to 1.0.
            min(val + 0.02, 0.98),
            f"{val:.3f}",
            ha="center", va="bottom", fontsize=9,
        )


def render_plot(
    conditions: list[Condition],
    out_path: Path,
    title: str,
    dpi: int = 150,
) -> None:
    """Render the two-panel money plot to ``out_path`` at ``dpi``.

    The left panel shows each condition's ``lcr`` and the right panel its
    ``beat``; both share one legend. Parent directories of ``out_path`` are
    created if needed.

    Args:
        conditions: Conditions to plot, one bar each per panel, in order.
        out_path: Destination image file.
        title: Figure suptitle.
        dpi: Resolution passed to ``savefig``.

    Raises:
        SystemExit: If matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt  # local import: graceful failure path
    except ImportError as exc:  # pragma: no cover - import-time only
        raise SystemExit(
            "matplotlib is required for make_comparison_plot.py. "
            "Install with: pip install matplotlib"
        ) from exc

    labels = [c.label for c in conditions]
    colours = [c.colour for c in conditions]

    fig, (ax_left, ax_right) = plt.subplots(
        nrows=1, ncols=2, figsize=(12, 5.2), sharey=False
    )
    try:
        _draw_metric_panel(
            ax_left,
            [c.lcr for c in conditions],
            labels,
            colours,
            ylabel="Logical correction rate (fraction of shots, 0-1)",
            title="Logical correction rate (per shot)",
        )
        _draw_metric_panel(
            ax_right,
            [c.beat for c in conditions],
            labels,
            colours,
            ylabel="PyMatching beat rate (fraction of shots, 0-1)",
            title="PyMatching beat rate (model corrects, PM does not)",
        )

        # One shared legend across both panels.
        handles = [
            plt.Rectangle((0, 0), 1, 1, color=c, ec="black", lw=0.6)
            for c in colours
        ]
        fig.legend(
            handles, labels,
            # Keep at least one legend column even for an empty input list.
            loc="lower center", ncol=max(len(labels), 1),
            bbox_to_anchor=(0.5, -0.02), frameon=False,
        )

        fig.suptitle(title, fontsize=13, y=1.02)
        fig.tight_layout()

        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=dpi, bbox_inches="tight")
    finally:
        # pyplot keeps every open figure alive in its global registry, so
        # release it even when drawing or saving fails.
        plt.close(fig)
244
+
245
+
246
+ # --------------------------------------------------------------------------- #
247
+ # CLI #
248
+ # --------------------------------------------------------------------------- #
249
+
250
+
251
+ def _missing_files_message(conditions: list[Condition]) -> str:
252
+ """Build a helpful error when one or more eval JSONs are absent."""
253
+ missing = [(c.label, c.path) for c in conditions if c.data is None]
254
+ if not missing:
255
+ return ""
256
+ lines = [
257
+ "ERROR: cannot build comparison plot - one or more eval JSON files "
258
+ "were not found.",
259
+ "",
260
+ "Expected files (one per condition):",
261
+ ]
262
+ for label, path in missing:
263
+ lines.append(f" - {label}: {path}")
264
+ lines.extend([
265
+ "",
266
+ "Generate them with scripts/eval.py, for example:",
267
+ " python -m scripts.eval --policy random --episodes 1000 \\",
268
+ " --out data/eval/random.json",
269
+ " python -m scripts.eval --base-model Qwen/Qwen2.5-3B-Instruct \\",
270
+ " --adapter '' --episodes 1000 --out data/eval/base_qwen.json",
271
+ " python -m scripts.eval --adapter checkpoints/sft/best \\",
272
+ " --episodes 1000 --out data/eval/sft_only.json",
273
+ " python -m scripts.eval --adapter checkpoints/grpo/best \\",
274
+ " --episodes 1000 --out data/eval/sft_grpo.json",
275
+ "",
276
+ "Override individual paths with --random / --base / --sft / --grpo.",
277
+ ])
278
+ return "\n".join(lines)
279
+
280
+
281
def parse_args(argv: Iterable[str]) -> argparse.Namespace:
    """Parse the command-line arguments for the comparison-plot script.

    Args:
        argv: Argument strings, without the program name.

    Returns:
        Namespace with ``eval_dir``, ``random``, ``base``, ``sft``, ``grpo``,
        ``out``, ``dpi`` and ``title``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring holds bullet lists, a JSON schema and usage
        # examples; the default formatter would re-wrap all of that into one
        # paragraph in --help output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--eval-dir", type=str, default="data/eval",
        help="Directory holding one JSON per condition "
             "(random.json, base_qwen.json, sft_only.json, sft_grpo.json).",
    )
    # The four per-condition overrides differ only in flag name and help noun.
    for flag, noun in (
        ("--random", "random-baseline"),
        ("--base", "base-Qwen"),
        ("--sft", "SFT-only"),
        ("--grpo", "SFT+GRPO"),
    ):
        parser.add_argument(
            flag, type=str, default=None,
            help=f"Override path to the {noun} eval JSON.",
        )
    parser.add_argument(
        "--out", type=str, default="figures/before_after_comparison.png",
        help="Where to write the PNG (created at 150 dpi by default).",
    )
    parser.add_argument(
        "--dpi", type=int, default=150,
        help="DPI for the saved PNG.",
    )
    parser.add_argument(
        "--title", type=str,
        default=(
            "Qubit-Medic decoder accuracy: before vs after RLHF training "
            "(distance-3 surface code, p=0.001)"
        ),
        help="Figure suptitle.",
    )
    return parser.parse_args(list(argv))
321
+
322
+
323
def main(argv: Iterable[str] = ()) -> int:
    """Run the CLI: load the eval JSONs, render the plot, print a summary.

    Args:
        argv: Command-line arguments, without the program name.

    Returns:
        Process exit code: 0 on success, 1 if any eval JSON is missing or
        cannot be read or decoded.
    """
    args = parse_args(argv)

    # Key the overrides by CONDITION_LABELS itself rather than by hand-copied
    # label literals: path resolution looks overrides up by those labels, so
    # a drifting copy would silently drop the override.
    explicit = dict(
        zip(CONDITION_LABELS, (args.random, args.base, args.sft, args.grpo))
    )
    try:
        conditions = load_conditions(Path(args.eval_dir), explicit)
    except (OSError, ValueError) as exc:
        # Unreadable files raise OSError; invalid JSON raises
        # json.JSONDecodeError, a ValueError subclass. Report these as a CLI
        # error instead of a traceback.
        print(f"ERROR: could not load eval JSON: {exc}", file=sys.stderr)
        return 1

    msg = _missing_files_message(conditions)
    if msg:
        print(msg, file=sys.stderr)
        return 1

    render_plot(
        conditions=conditions,
        out_path=Path(args.out),
        title=args.title,
        dpi=args.dpi,
    )
    print(f"Wrote comparison plot to {args.out}")
    for c in conditions:
        print(
            f" {c.label:>18s}: LCR={c.lcr:.3f} "
            f"PMbeat={c.beat:.3f} (n={c.episodes}, src={c.path})"
        )
    return 0
352
+
353
+
354
# Script entry point: forward the CLI arguments (minus the program name) to
# main() and use its return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main(sys.argv[1:]))