ronitraj committed
Commit 74d70f5 · verified · 1 Parent(s): 16c627e

Upload scripts/baseline_policies.py with huggingface_hub
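
Such an upload is typically produced with huggingface_hub's upload_file; a minimal sketch (the repo_id below is a placeholder, not taken from this page):

    from huggingface_hub import upload_file

    upload_file(
        path_or_fileobj="scripts/baseline_policies.py",
        path_in_repo="scripts/baseline_policies.py",
        repo_id="<user>/<repo>",  # placeholder, not shown on this page
        commit_message="Upload scripts/baseline_policies.py with huggingface_hub",
    )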

Files changed (1)
  1. scripts/baseline_policies.py +196 -0
scripts/baseline_policies.py ADDED
@@ -0,0 +1,196 @@
+"""Three baseline policies (Section 2.7 of the plan).
+
+Run::
+
+    .venv/bin/python -m scripts.baseline_policies --episodes 500
+
+Expected ranges (Section 2.7):
+
+* Random policy: ~10% logical correction
+* All-zeros policy: ~99% on L1 (warmup, p=0.0001), ~99% on L2 (still small)
+* PyMatching imitator: ~99-100% logical correction
+
+The plan's quoted numbers ("~10%", "~40%", "~97%") refer to a different
+counting (per-shot accuracy on a *high-noise* level). At p=0.001 the
+syndromes are mostly all-zero, so the all-zeros baseline will look very
+strong. We report both the headline level (L2) and a high-noise level
+(p=0.01) for an honest comparison.
+"""
+from __future__ import annotations
+
+import argparse
+import json
+import random
+from dataclasses import dataclass
+from typing import Callable, Iterable
+
+from qubit_medic.client.client import LocalDecoderClient
+from qubit_medic.config import CURRICULUM, primary_level
+from qubit_medic.models import DecoderObservation
+from qubit_medic.prompts import format_completion
+
+
+Policy = Callable[[DecoderObservation], str]
+
+
+# --------------------------------------------------------------------------- #
+# Three policies                                                              #
+# --------------------------------------------------------------------------- #
+
+
+def policy_random(obs: DecoderObservation, *, rng: random.Random) -> str:
+    """Random qubit IDs - the noise floor."""
+    n = max(1, obs.distance ** 2)  # number of data qubits
+    k = rng.randint(0, max(1, n // 2))
+    xs = sorted(rng.sample(range(n), k=min(k, n)))
+    k = rng.randint(0, max(1, n // 2))
+    zs = sorted(rng.sample(range(n), k=min(k, n)))
+    return format_completion(xs, zs)
+
+
+def policy_zeros(obs: DecoderObservation) -> str:
+    """Always predict 'no errors'."""
+    return format_completion([], [])
+
+
+_PM_CACHE: dict[str, tuple] = {}
+
+
+def policy_pymatching(obs: DecoderObservation, *, env_client: LocalDecoderClient) -> str:
+    """Use PyMatching's prediction as the LLM imitator's response.
+
+    This is a 'cheating' policy in the sense that it consults the same
+    baseline used by Reward 5, so beat-rate is 0 by definition. Per-level
+    Stim/PyMatching artefacts are cached so the policy stays fast.
+    """
+    import pymatching, numpy as np
+    from qubit_medic.config import level_by_name
+    from qubit_medic.server.physics import (
+        build_circuit, build_dem, extract_layout,
+        pymatching_predicted_pauli_frame, rectify_pauli_frame_to_observable,
+    )
+    cached = _PM_CACHE.get(obs.curriculum_level)
+    if cached is None:
+        lvl = level_by_name(obs.curriculum_level)
+        c = build_circuit(lvl)
+        dem = build_dem(c)
+        m = pymatching.Matching.from_detector_error_model(dem)
+        layout = extract_layout(c)
+        cached = (m, layout)
+        _PM_CACHE[obs.curriculum_level] = cached
+    m, layout = cached
+    syndrome = np.asarray(obs.syndrome_bits, dtype=np.uint8)
+    px_stim, pz_stim = pymatching_predicted_pauli_frame(m, syndrome, layout)
+    pm_obs = int(m.decode(syndrome)[0])
+    px_stim, pz_stim = rectify_pauli_frame_to_observable(
+        px_stim, pz_stim, pm_obs, layout,
+    )
+    return format_completion(layout.stim_to_llm(px_stim),
+                             layout.stim_to_llm(pz_stim))
+
+
+# --------------------------------------------------------------------------- #
+# Evaluation harness                                                          #
+# --------------------------------------------------------------------------- #
+
+
+@dataclass
+class PolicyStats:
+    name: str
+    episodes: int = 0
+    logical_correct: int = 0
+    format_ok: int = 0
+    beat_pm: int = 0
+    sum_total: float = 0.0
+
+    def update(self, info: dict, total: float) -> None:
+        self.episodes += 1
+        rewards = info["rewards"]
+        if rewards["logical_correction"] >= 0.5:
+            self.logical_correct += 1
+        if rewards["format_compliance"] >= 0.5:
+            self.format_ok += 1
+        if rewards["pymatching_beat"] >= 0.5:
+            self.beat_pm += 1
+        self.sum_total += total
+
+    def as_dict(self) -> dict:
+        n = max(1, self.episodes)
+        return {
+            "name": self.name,
+            "episodes": self.episodes,
+            "logical_correction_rate": self.logical_correct / n,
+            "format_compliance_rate": self.format_ok / n,
+            "pymatching_beat_rate": self.beat_pm / n,
+            "mean_total_reward": self.sum_total / n,
+        }
+
+
+def evaluate_policy(
+    *,
+    name: str,
+    policy: Policy,
+    episodes: int,
+    forced_level: str,
+    seed: int = 0,
+) -> dict:
+    """Run a policy for ``episodes`` shots at one curriculum level."""
+    client = LocalDecoderClient()
+    stats = PolicyStats(name=name)
+    for ep in range(episodes):
+        obs = client.reset(forced_level=forced_level, seed=seed + ep)
+        raw = policy(obs)
+        result = client.step(raw_response=raw, episode_id=obs.episode_id)
+        stats.update(info=result.info, total=result.reward)
+    return stats.as_dict()
+
+
+# --------------------------------------------------------------------------- #
+# CLI                                                                         #
+# --------------------------------------------------------------------------- #
+
+
+def main(argv: Iterable[str] = ()) -> int:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--episodes", type=int, default=200,
+                        help="episodes per (policy, level) pair")
+    parser.add_argument("--levels", nargs="*", default=["L1_warmup", "L2_target"])
+    parser.add_argument("--out", type=str, default=None,
+                        help="optional path to dump JSON results")
+    args = parser.parse_args(list(argv))
+
+    rng = random.Random(42)
+    random_policy = lambda obs: policy_random(obs, rng=rng)  # noqa: E731
+    pm_policy_client = LocalDecoderClient()
+    pm_policy = lambda obs: policy_pymatching(obs, env_client=pm_policy_client)  # noqa: E731
+
+    results = []
+    for level in args.levels:
+        for name, policy in (
+            ("random", random_policy),
+            ("zeros", policy_zeros),
+            ("pymatching", pm_policy),
+        ):
+            r = evaluate_policy(
+                name=name, policy=policy, episodes=args.episodes,
+                forced_level=level,
+            )
+            r["level"] = level
+            results.append(r)
+            print(
+                f"{level:<12} {name:<12} "
+                f"LER={1 - r['logical_correction_rate']:.3f} "
+                f"correct={r['logical_correction_rate']:.3f} "
+                f"format={r['format_compliance_rate']:.3f} "
+                f"beat={r['pymatching_beat_rate']:.3f} "
+                f"mean_R={r['mean_total_reward']:.3f}"
+            )
+    if args.out:
+        with open(args.out, "w") as f:
+            json.dump(results, f, indent=2)
+    return 0
+
+
+if __name__ == "__main__":
+    import sys
+    sys.exit(main(sys.argv[1:]))
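
For a quick smoke test without going through the CLI, the harness can also be driven directly from Python. A minimal sketch, assuming the repository root is on the import path so that scripts.baseline_policies and the qubit_medic package resolve:

    from scripts.baseline_policies import evaluate_policy, policy_zeros

    # Score the trivial all-zeros baseline on the headline level; the
    # level name "L2_target" is taken from the script's --levels default.
    stats = evaluate_policy(
        name="zeros",
        policy=policy_zeros,
        episodes=100,
        forced_level="L2_target",
    )
    print(stats["logical_correction_rate"], stats["mean_total_reward"])

The returned dict has the same shape as the entries the --out flag dumps to JSON, minus the extra "level" key that main() attaches per (policy, level) pair.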