ronitraj committed
Commit ff28459 · verified · 1 Parent(s): bd1a695

Upload scripts/eval.py with huggingface_hub

Files changed (1):
  1. scripts/eval.py +294 -0
scripts/eval.py ADDED
@@ -0,0 +1,294 @@
"""scripts/eval.py - held-out evaluation harness (Sections 6.2 + 7.3).

Runs a model (or one of the deterministic baselines) over a held-out set
of syndromes and reports:

* format compliance rate
* logical correction rate
* mean Hamming overlap with PyMatching
* PyMatching beat rate
* mean total reward

Usage::

    # Baseline run (no model; uses the PyMatching imitator):
    python -m scripts.eval --policy pymatching --episodes 200

    # Trained model (loads adapters via Unsloth):
    python -m scripts.eval --adapter checkpoints/grpo --episodes 500

    # With W&B logging (summary + per-episode table):
    python -m scripts.eval --adapter checkpoints/grpo --episodes 500 \
        --report-to wandb --wandb-group my-experiment
"""
from __future__ import annotations

import argparse
import json
import sys
from typing import Iterable

from qubit_medic.client.client import LocalDecoderClient
from qubit_medic.config import primary_level


def _summary(name: str, results: list[dict]) -> dict:
    """Aggregate per-episode reward dicts into the metrics the master spec
    benchmarks against (sections 6 + 7 of the locked spec).

    Each entry in ``results`` is the env's per-step ``info["rewards"]``
    dict, optionally with extra fields the eval loop decorated:

    * ``exact_match_pymatching`` (model-eval only)
    * ``output_length`` (model-eval only)
    * ``n_true_errors`` (any caller; enables the hard-syndrome subset)
    """
    n = max(1, len(results))
    # Hard-syndrome subset = episodes where the simulated truth contains
    # at least 2 X|Z errors. This is the cohort where MWPM ambiguity
    # matters and trained-model contributions are most visible.
    hard = [r for r in results if int(r.get("n_true_errors", 0)) >= 2]
    n_hard = len(hard)
    out = {
        "name": name,
        "episodes": len(results),
        # Headline metrics (master spec, section 6).
        "logical_correction_rate":
            sum(r["logical_correction"] >= 0.5 for r in results) / n,
        "pymatching_beat_rate":
            sum(r["pymatching_beat"] >= 0.5 for r in results) / n,
        "format_compliance_rate":
            sum(r["format_compliance"] >= 0.999 for r in results) / n,
        "format_partial_rate":
            sum(0.5 <= r["format_compliance"] < 0.999 for r in results) / n,
        # Continuous progress metrics.
        "syndrome_consistency_rate":
            sum(r["syndrome_consistency"] >= 0.999 for r in results) / n,
        "mean_syndrome_consistency":
            sum(r["syndrome_consistency"] for r in results) / n,
        "mean_hamming_overlap":
            sum(r["hamming_overlap"] for r in results) / n,
        "mean_total_reward":
            sum(r["total"] for r in results) / n,
        # Model-eval extras (present iff the model loop populated them).
        "exact_match_pymatching":
            sum(int(r.get("exact_match_pymatching", 0)) for r in results) / n,
        "mean_output_length":
            sum(int(r.get("output_length", 0)) for r in results) / n,
        # Hard-syndrome subset (FIX 5, 2026-04 eval spec). Easy syndromes
        # are where every baseline already hits ~95%+; the hard subset is
        # where differentiation actually shows up.
        "hard_syndrome_count": n_hard,
        "hard_syndrome_lcr":
            (sum(r["logical_correction"] >= 0.5 for r in hard) / n_hard
             if n_hard else 0.0),
        "hard_syndrome_beat_rate":
            (sum(r["pymatching_beat"] >= 0.5 for r in hard) / n_hard
             if n_hard else 0.0),
    }
    return out

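
# Illustrative only (not from the original upload): a minimal sketch of how
# _summary aggregates, on two hand-written toy reward dicts. The keys follow
# the info["rewards"] schema assumed above; the values are made up.
#
#   >>> toy = [
#   ...     {"logical_correction": 1.0, "pymatching_beat": 1.0,
#   ...      "format_compliance": 1.0, "syndrome_consistency": 1.0,
#   ...      "hamming_overlap": 0.9, "total": 2.5, "n_true_errors": 3},
#   ...     {"logical_correction": 0.0, "pymatching_beat": 0.0,
#   ...      "format_compliance": 1.0, "syndrome_consistency": 0.5,
#   ...      "hamming_overlap": 0.4, "total": 0.8, "n_true_errors": 1},
#   ... ]
#   >>> s = _summary("toy", toy)
#   >>> (s["logical_correction_rate"], s["hard_syndrome_count"],
#   ...  s["hard_syndrome_lcr"])
#   (0.5, 1, 1.0)
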

def _eval_baseline(name: str, episodes: int, level: str,
                   collect_rows: bool = False):
    from scripts.baseline_policies import (
        policy_pymatching, policy_zeros, policy_random,
    )
    import random as _r
    rng = _r.Random(0)
    pol_map = {
        "pymatching": lambda obs: policy_pymatching(obs, env_client=None),
        "zeros": policy_zeros,
        "random": lambda obs: policy_random(obs, rng=rng),
    }
    if name not in pol_map:
        raise ValueError(f"unknown baseline {name}; choose from {sorted(pol_map)}")
    pol = pol_map[name]
    client = LocalDecoderClient()
    rewards = []
    rows = []
    for ep in range(episodes):
        # Fixed seeds (10_000 + ep) pin the held-out syndrome set, so every
        # policy and model is scored on the same episodes.
        obs = client.reset(forced_level=level, seed=10_000 + ep)
        completion = pol(obs)
        result = client.step(raw_response=completion, episode_id=obs.episode_id)
        rwd = dict(result.info["rewards"])  # copy so we can decorate
        # Tag with the error count from PyMatching's decode (a proxy for the
        # true-error count) so _summary can filter the hard subset.
        rwd["n_true_errors"] = (
            len(result.info.get("pymatching_x_errors", []) or [])
            + len(result.info.get("pymatching_z_errors", []) or [])
        )
        rewards.append(rwd)
        if collect_rows and ep < 50:  # cap table size
            rows.append({
                "episode": ep,
                "completion": completion,
                "logical_correction": rwd["logical_correction"],
                "syndrome_consistency": rwd["syndrome_consistency"],
                "hamming_overlap": rwd["hamming_overlap"],
                "format_compliance": rwd["format_compliance"],
                "pymatching_beat": rwd["pymatching_beat"],
                "total": rwd["total"],
                "actual_obs_flip": result.info["actual_observable_flip"],
                "pm_obs_flip": result.info["pymatching_observable_pred"],
            })
    return _summary(name, rewards), rows


def _eval_model(adapter: str, episodes: int, level: str,
                base_model: str, max_new_tokens: int,
                collect_rows: bool = False):
    """Use Unsloth to load the adapter and generate completions.

    Populates ``exact_match_pymatching`` and ``output_length`` on each
    per-episode reward dict so :func:`_summary` can report the master
    spec's full benchmark suite (section 6 + section 7).
    """
    from unsloth import FastLanguageModel
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=adapter if adapter else base_model,
        max_seq_length=2048,
        load_in_4bit=True,
        dtype=None,
    )
    FastLanguageModel.for_inference(model)

    client = LocalDecoderClient()
    rewards = []
    rows = []
    for ep in range(episodes):
        obs = client.reset(forced_level=level, seed=10_000 + ep)
        chat = [{"role": "user", "content": obs.prompt}]
        text = tokenizer.apply_chat_template(chat, tokenize=False,
                                             add_generation_prompt=True)
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        out = model.generate(
            **inputs, max_new_tokens=max_new_tokens,
            do_sample=False,  # deterministic / greedy eval
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
        gen_ids = out[0][inputs["input_ids"].shape[1]:]
        completion = tokenizer.decode(gen_ids, skip_special_tokens=True)
        n_tokens = int(gen_ids.shape[0])
        result = client.step(raw_response=completion, episode_id=obs.episode_id)
        rwd = dict(result.info["rewards"])  # copy so we can decorate

        # Decorate with the master-spec extras.
        action = result.info.get("parsed_action", {}) or {}
        pm_x = sorted(set(map(int, result.info.get("pymatching_x_errors", []) or [])))
        pm_z = sorted(set(map(int, result.info.get("pymatching_z_errors", []) or [])))
        our_x = sorted(set(map(int, action.get("x_error_qubits", []) or [])))
        our_z = sorted(set(map(int, action.get("z_error_qubits", []) or [])))
        rwd["exact_match_pymatching"] = int(
            bool(action.get("parse_success", False))
            and our_x == pm_x and our_z == pm_z
        )
        rwd["output_length"] = n_tokens
        rwd["n_true_errors"] = len(pm_x) + len(pm_z)
        rewards.append(rwd)

        if collect_rows and ep < 50:
            rows.append({
                "episode": ep,
                "completion": completion[:300],
                "logical_correction": rwd["logical_correction"],
                "syndrome_consistency": rwd["syndrome_consistency"],
                "hamming_overlap": rwd["hamming_overlap"],
                "format_compliance": rwd["format_compliance"],
                "pymatching_beat": rwd["pymatching_beat"],
                "exact_match_pymatching": rwd["exact_match_pymatching"],
                "output_length": rwd["output_length"],
                "total": rwd["total"],
                "actual_obs_flip": result.info["actual_observable_flip"],
                "pm_obs_flip": result.info["pymatching_observable_pred"],
            })
    return _summary(f"model[{adapter}]", rewards), rows

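
# Illustrative only: the exact-match check above normalises both sides with
# sorted(set(map(int, ...))), so qubit ordering, duplicates, and string-typed
# indices never cause a spurious mismatch:
#
#   >>> sorted(set(map(int, ["3", 1, 3]))) == sorted(set(map(int, [1, 3])))
#   True
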

def main(argv: Iterable[str] = ()) -> int:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--policy", choices=["random", "zeros", "pymatching"],
                        default=None,
                        help="evaluate a deterministic baseline instead of a model")
    parser.add_argument("--adapter", type=str, default=None,
                        help="path to LoRA adapter dir; mutually exclusive with --policy")
    parser.add_argument("--base-model", type=str,
                        default="Qwen/Qwen2.5-3B-Instruct")
    parser.add_argument("--episodes", type=int, default=200)
    parser.add_argument("--level", type=str, default=primary_level().name)
    parser.add_argument("--max-new-tokens", type=int, default=160)
    parser.add_argument("--out", type=str, default=None)
    parser.add_argument("--report-to", type=str, default="none",
                        choices=["wandb", "none"],
                        help="If 'wandb', log summary + per-episode table.")
    parser.add_argument("--wandb-run-name", type=str, default=None)
    parser.add_argument("--wandb-group", type=str, default=None)
    parser.add_argument("--wandb-tags", type=str, nargs="*", default=("eval",))
    parser.add_argument("--wandb-notes", type=str, default=None)
    args = parser.parse_args(list(argv))

    if (args.policy is None) == (args.adapter is None):
        print("ERROR: exactly one of --policy and --adapter is required",
              file=sys.stderr)
        return 1

    from qubit_medic import wandb_utils

    report_to = wandb_utils.derive_report_to(args.report_to)
    use_wandb = report_to == "wandb"
    if use_wandb:
        slug = args.policy or (args.adapter or "model").replace("/", "_")
        run_name = args.wandb_run_name or wandb_utils.make_run_name(
            "eval", suffix=slug)
        wandb_utils.init_run(
            run_name=run_name,
            job_type="eval",
            tags=tuple(list(args.wandb_tags) + [args.level]),
            notes=args.wandb_notes,
            group=args.wandb_group,
            extra_config={
                "cli": {
                    "policy": args.policy,
                    "adapter": args.adapter,
                    "episodes": args.episodes,
                    "level": args.level,
                    "max_new_tokens": args.max_new_tokens,
                    "base_model": args.base_model,
                },
            },
        )

    if args.policy is not None:
        result, rows = _eval_baseline(args.policy, args.episodes, args.level,
                                      collect_rows=use_wandb)
    else:
        result, rows = _eval_model(args.adapter, args.episodes, args.level,
                                   args.base_model, args.max_new_tokens,
                                   collect_rows=use_wandb)
    result["level"] = args.level
    print(json.dumps(result, indent=2))

    if args.out:
        from pathlib import Path
        Path(args.out).parent.mkdir(parents=True, exist_ok=True)
        with open(args.out, "w") as f:
            json.dump(result, f, indent=2)

    if use_wandb:
        wandb_utils.log_eval_summary(result, prefix="eval")
        if rows:
            wandb_utils.log_generation_table(
                rows, step=None, table_name="eval/episode_breakdown",
            )
        wandb_utils.update_summary({
            "eval/policy_or_adapter": args.policy or args.adapter,
            "eval/episodes": args.episodes,
            "eval/level": args.level,
        })
        wandb_utils.finish_run()

    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
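
# Illustrative only (not part of the original file): one way to compare two
# --out summaries after separate runs; the runs/*.json paths are hypothetical.
#
#   python -m scripts.eval --policy pymatching --episodes 500 --out runs/pm.json
#   python -m scripts.eval --adapter checkpoints/grpo --episodes 500 --out runs/grpo.json
#   python - <<'PY'
#   import json
#   pm = json.load(open("runs/pm.json"))
#   gr = json.load(open("runs/grpo.json"))
#   print("delta LCR:", gr["logical_correction_rate"] - pm["logical_correction_rate"])
#   PY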