ronitraj commited on
Commit
7b91523
·
verified ·
1 Parent(s): 7dc2fe6

Upload scripts/validate_env.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/validate_env.py +167 -0
scripts/validate_env.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """scripts/validate_env.py - the Section 1.1 environment validation.
2
+
3
+ Five gates, in order:
4
+
5
+ 1. Imports succeed (catches install issues).
6
+ 2. Stim generates a tiny distance-3 surface code.
7
+ 3. PyMatching decodes 100 syndromes.
8
+ 4. Logical-error rate at p=0.001 is in the expected range.
9
+ 5. ``DecoderEnvironment`` reset+step works end-to-end (proves the wire
10
+ contract is intact).
11
+
12
+ Run with::
13
+
14
+ .venv/bin/python -m scripts.validate_env
15
+
16
+ Exit code is 0 iff every gate passes. The participant guide explicitly
17
+ warns: *"if any of these fail on any team member's machine, fix it now -
18
+ not at 11pm on Day 1."*
19
+ """
20
+ from __future__ import annotations
21
+
22
+ import sys
23
+ import time
24
+ from typing import Iterable
25
+
26
+ GATES = []
27
+
28
+
29
+ def gate(name: str):
30
+ def deco(fn):
31
+ GATES.append((name, fn))
32
+ return fn
33
+ return deco
34
+
35
+
36
+ def _ok(name: str, msg: str = "") -> None:
37
+ extra = f" {msg}" if msg else ""
38
+ print(f" PASS {name}{extra}")
39
+
40
+
41
+ def _fail(name: str, msg: str) -> None:
42
+ print(f" FAIL {name} -- {msg}")
43
+
44
+
45
+ # --------------------------------------------------------------------------- #
46
+
47
+
48
+ @gate("imports")
49
+ def _imports() -> None:
50
+ import stim, pymatching, numpy, fastapi, pydantic # noqa: F401
51
+ import qubit_medic
52
+ import qubit_medic.config
53
+ import qubit_medic.models
54
+ import qubit_medic.prompts
55
+ import qubit_medic.server.physics
56
+ import qubit_medic.server.rewards
57
+ import qubit_medic.server.curriculum
58
+ import qubit_medic.server.environment
59
+ print(f" stim={stim.__version__} pymatching={pymatching.__version__} "
60
+ f"qubit_medic={qubit_medic.__version__}")
61
+
62
+
63
+ @gate("stim_circuit_generation")
64
+ def _stim_gen() -> None:
65
+ from qubit_medic.config import primary_level
66
+ from qubit_medic.server.physics import build_circuit, build_dem, extract_layout
67
+ c = build_circuit(primary_level())
68
+ dem = build_dem(c)
69
+ layout = extract_layout(c)
70
+ assert layout.num_data_qubits == 9, f"expected 9 data qubits, got {layout.num_data_qubits}"
71
+ assert layout.num_ancilla_qubits == 8
72
+ assert layout.z_observable_support == (1, 3, 5)
73
+ print(f" circuit={len(str(c))} chars, DEM={len(str(dem))} chars, "
74
+ f"obs_support={layout.z_observable_support}")
75
+
76
+
77
+ @gate("pymatching_decoding_100")
78
+ def _pm_decoding() -> None:
79
+ import pymatching, numpy as np
80
+ from qubit_medic.config import primary_level
81
+ from qubit_medic.server.physics import build_circuit, build_dem
82
+ c = build_circuit(primary_level())
83
+ dem = build_dem(c)
84
+ sampler = c.compile_detector_sampler(seed=42)
85
+ det, obs = sampler.sample(100, separate_observables=True)
86
+ m = pymatching.Matching.from_detector_error_model(dem)
87
+ pred = m.decode_batch(det)
88
+ err_rate = float(np.mean(np.any(pred != obs, axis=1)))
89
+ print(f" logical-error rate (100 shots): {err_rate:.4f}")
90
+
91
+
92
+ @gate("ler_in_expected_range")
93
+ def _ler_range() -> None:
94
+ """At distance 3, p=0.001, 5000 shots, PyMatching LER should be < 1%."""
95
+ import pymatching, numpy as np
96
+ from qubit_medic.config import primary_level
97
+ from qubit_medic.server.physics import build_circuit, build_dem
98
+ c = build_circuit(primary_level())
99
+ dem = build_dem(c)
100
+ sampler = c.compile_detector_sampler(seed=2024)
101
+ det, obs = sampler.sample(5000, separate_observables=True)
102
+ m = pymatching.Matching.from_detector_error_model(dem)
103
+ pred = m.decode_batch(det)
104
+ err = float(np.mean(np.any(pred != obs, axis=1)))
105
+ expected_lo, expected_hi = 0.0, 0.01
106
+ if not (expected_lo <= err <= expected_hi):
107
+ raise AssertionError(
108
+ f"PyMatching LER {err:.4f} outside [{expected_lo}, {expected_hi}]"
109
+ )
110
+ print(f" PyMatching LER on 5000 shots: {err:.4f} "
111
+ f"(expected ~0.001 - 0.01)")
112
+
113
+
114
+ @gate("decoder_environment_roundtrip")
115
+ def _env_roundtrip() -> None:
116
+ """Reset + step round-trip with three trivial policies."""
117
+ from qubit_medic.client.client import LocalDecoderClient
118
+ from qubit_medic.prompts import format_completion
119
+
120
+ client = LocalDecoderClient()
121
+ obs = client.reset(forced_level="L2_target", seed=1)
122
+ assert obs.distance == 3 and obs.rounds == 3
123
+ assert obs.curriculum_level == "L2_target"
124
+
125
+ # All-zeros policy: claim no errors.
126
+ result = client.step(
127
+ raw_response=format_completion([], []),
128
+ episode_id=obs.episode_id,
129
+ )
130
+ assert result.done is True
131
+ assert "rewards" in result.info
132
+ print(f" reset->step round-trip ok; "
133
+ f"all-zeros total reward={result.reward:.3f}, "
134
+ f"breakdown={result.info['rewards']}")
135
+
136
+ # Trivial second episode under forced L1.
137
+ obs2 = client.reset(forced_level="L1_warmup", seed=2)
138
+ assert obs2.distance == 3 and obs2.rounds == 1
139
+ print(f" L1 warmup reset OK; prompt is {len(obs2.prompt)} chars long")
140
+
141
+
142
+ # --------------------------------------------------------------------------- #
143
+
144
+
145
+ def main(argv: Iterable[str] = ()) -> int:
146
+ print("Qubit-Medic environment validation")
147
+ print("=" * 60)
148
+ failures = 0
149
+ started = time.monotonic()
150
+ for name, fn in GATES:
151
+ try:
152
+ fn()
153
+ _ok(name)
154
+ except Exception as exc: # noqa: BLE001 - we want to keep going
155
+ _fail(name, repr(exc))
156
+ failures += 1
157
+ elapsed = time.monotonic() - started
158
+ print("=" * 60)
159
+ if failures:
160
+ print(f"{failures} gate(s) failed in {elapsed:.2f}s")
161
+ return 1
162
+ print(f"all {len(GATES)} gates passed in {elapsed:.2f}s")
163
+ return 0
164
+
165
+
166
+ if __name__ == "__main__":
167
+ sys.exit(main(sys.argv[1:]))