"""scripts/validate_env.py - the Section 1.1 environment validation.
Five gates, in order:
1. Imports succeed (catches install issues).
2. Stim generates a tiny distance-3 surface code.
3. PyMatching decodes 100 syndromes.
4. Logical-error rate at p=0.001 is in the expected range.
5. ``DecoderEnvironment`` reset+step works end-to-end (proves the wire
contract is intact).
Run with::
.venv/bin/python -m scripts.validate_env
Exit code is 0 iff every gate passes. The participant guide explicitly
warns: *"if any of these fail on any team member's machine, fix it now -
not at 11pm on Day 1."*
"""
from __future__ import annotations
import sys
import time
from typing import Iterable
# Ordered registry of (gate name, zero-argument check function) pairs.
# Gates are appended in the order their decorators run, i.e. definition order.
GATES = []


def gate(name: str):
    """Build a decorator that registers a function as the validation gate *name*.

    The decorated function is recorded in ``GATES`` as ``(name, function)``
    and handed back unchanged, so it remains directly callable.
    """
    def register(check):
        entry = (name, check)
        GATES.append(entry)
        return check

    return register
def _ok(name: str, msg: str = "") -> None:
extra = f" {msg}" if msg else ""
print(f" PASS {name}{extra}")
def _fail(name: str, msg: str) -> None:
print(f" FAIL {name} -- {msg}")
# --------------------------------------------------------------------------- #
@gate("imports")
def _imports() -> None:
    """Gate 1: every third-party dependency and project module imports.

    Any import failure propagates as an exception, which the gate runner
    reports as a FAIL. On success, prints the stim, pymatching and
    qubit_medic versions.
    """
    # Third-party stack, imported in the same order as before.
    import stim  # noqa: F401
    import pymatching  # noqa: F401
    import numpy  # noqa: F401
    import fastapi  # noqa: F401
    import pydantic  # noqa: F401

    # Project package and the submodules the environment relies on.
    import qubit_medic
    import qubit_medic.config
    import qubit_medic.models
    import qubit_medic.prompts
    import qubit_medic.server.physics
    import qubit_medic.server.rewards
    import qubit_medic.server.curriculum
    import qubit_medic.server.environment

    print(f" stim={stim.__version__} pymatching={pymatching.__version__} "
          f"qubit_medic={qubit_medic.__version__}")
@gate("stim_circuit_generation")
def _stim_gen() -> None:
    """Gate 2: Stim builds the primary-level circuit with the expected layout.

    Builds the circuit for ``primary_level()`` and its detector error model,
    extracts the qubit layout, and checks it against fixed expectations:
    9 data qubits, 8 ancilla qubits, and Z-observable support ``(1, 3, 5)``.

    The checks raise ``AssertionError`` explicitly instead of using ``assert``.
    ``assert`` statements are removed under ``python -O``, which would let this
    gate pass without checking anything.

    Raises:
        AssertionError: if any layout property differs from its expected value.
    """
    from qubit_medic.config import primary_level
    from qubit_medic.server.physics import build_circuit, build_dem, extract_layout

    c = build_circuit(primary_level())
    dem = build_dem(c)
    layout = extract_layout(c)

    if layout.num_data_qubits != 9:
        raise AssertionError(
            f"expected 9 data qubits, got {layout.num_data_qubits}"
        )
    if layout.num_ancilla_qubits != 8:
        raise AssertionError(
            f"expected 8 ancilla qubits, got {layout.num_ancilla_qubits}"
        )
    if layout.z_observable_support != (1, 3, 5):
        raise AssertionError(
            f"expected z_observable_support (1, 3, 5), "
            f"got {layout.z_observable_support}"
        )

    print(f" circuit={len(str(c))} chars, DEM={len(str(dem))} chars, "
          f"obs_support={layout.z_observable_support}")
@gate("pymatching_decoding_100")
def _pm_decoding() -> None:
    """Gate 3: PyMatching decodes 100 sampled syndromes from the primary level.

    Samples 100 shots from the primary-level circuit (detector sampler seeded
    with 42) and decodes them in one batch with a matcher built from the
    detector error model. It then prints the fraction of shots where any
    predicted observable flip disagrees with the sampled one. This gate does
    not range-check that rate; only ``ler_in_expected_range`` does.

    Raises:
        AssertionError: if the decoder's prediction array does not have the
            same shape as the sampled observable array.
    """
    import pymatching
    import numpy as np

    from qubit_medic.config import primary_level
    from qubit_medic.server.physics import build_circuit, build_dem

    c = build_circuit(primary_level())
    dem = build_dem(c)
    sampler = c.compile_detector_sampler(seed=42)
    det, obs = sampler.sample(100, separate_observables=True)
    m = pymatching.Matching.from_detector_error_model(dem)
    pred = m.decode_batch(det)
    # Without this check, a shape mismatch would broadcast silently in
    # ``pred != obs`` and the printed rate would be meaningless.
    if pred.shape != obs.shape:
        raise AssertionError(
            f"decode_batch returned shape {pred.shape}, expected {obs.shape}"
        )
    err_rate = float(np.mean(np.any(pred != obs, axis=1)))
    print(f" logical-error rate (100 shots): {err_rate:.4f}")
@gate("ler_in_expected_range")
def _ler_range() -> None:
    """Gate 4: PyMatching's logical-error rate lies in the expected band.

    At distance 3, p=0.001, 5000 shots, the PyMatching LER should be at most
    1%. The accepted interval is ``[0.0, 0.01]``, inclusive at both ends.
    The detector sampler uses a fixed seed (2024).

    Raises:
        AssertionError: if the prediction array's shape does not match the
            sampled observables, or if the LER falls outside the interval.
    """
    import pymatching
    import numpy as np

    from qubit_medic.config import primary_level
    from qubit_medic.server.physics import build_circuit, build_dem

    c = build_circuit(primary_level())
    dem = build_dem(c)
    sampler = c.compile_detector_sampler(seed=2024)
    det, obs = sampler.sample(5000, separate_observables=True)
    m = pymatching.Matching.from_detector_error_model(dem)
    pred = m.decode_batch(det)
    # Guard against silent broadcasting in ``pred != obs``, which would
    # otherwise distort the rate being range-checked below.
    if pred.shape != obs.shape:
        raise AssertionError(
            f"decode_batch returned shape {pred.shape}, expected {obs.shape}"
        )
    err = float(np.mean(np.any(pred != obs, axis=1)))
    expected_lo, expected_hi = 0.0, 0.01
    if not (expected_lo <= err <= expected_hi):
        raise AssertionError(
            f"PyMatching LER {err:.4f} outside [{expected_lo}, {expected_hi}]"
        )
    print(f" PyMatching LER on 5000 shots: {err:.4f} "
          f"(expected ~0.001 - 0.01)")
@gate("decoder_environment_roundtrip")
def _env_roundtrip() -> None:
    """Gate 5: the local decoder client completes a reset + step round-trip.

    Steps:

    1. Reset at forced level ``L2_target`` (seed 1) and check the
       observation's distance, rounds and curriculum level.
    2. Step once with an all-zeros policy (a completion claiming no errors)
       and check the result is ``done`` and its ``info`` has a ``rewards``
       entry.
    3. Reset again at ``L1_warmup`` (seed 2) and check its distance and
       rounds.

    The checks raise ``AssertionError`` explicitly rather than using
    ``assert``, so they still run under ``python -O``.

    Raises:
        AssertionError: if any observation or step result differs from the
            expected values.
    """
    from qubit_medic.client.client import LocalDecoderClient
    from qubit_medic.prompts import format_completion

    client = LocalDecoderClient()
    obs = client.reset(forced_level="L2_target", seed=1)
    if not (obs.distance == 3 and obs.rounds == 3):
        raise AssertionError(
            f"L2_target reset: expected distance=3, rounds=3; "
            f"got distance={obs.distance}, rounds={obs.rounds}"
        )
    if obs.curriculum_level != "L2_target":
        raise AssertionError(
            f"expected curriculum_level 'L2_target', got {obs.curriculum_level!r}"
        )

    # All-zeros policy: claim no errors.
    result = client.step(
        raw_response=format_completion([], []),
        episode_id=obs.episode_id,
    )
    if result.done is not True:
        raise AssertionError(f"expected done=True after step, got {result.done!r}")
    if "rewards" not in result.info:
        raise AssertionError(
            f"step info has no 'rewards' entry; keys: {list(result.info)}"
        )
    print(f" reset->step round-trip ok; "
          f"all-zeros total reward={result.reward:.3f}, "
          f"breakdown={result.info['rewards']}")

    # Trivial second episode under forced L1.
    obs2 = client.reset(forced_level="L1_warmup", seed=2)
    if not (obs2.distance == 3 and obs2.rounds == 1):
        raise AssertionError(
            f"L1_warmup reset: expected distance=3, rounds=1; "
            f"got distance={obs2.distance}, rounds={obs2.rounds}"
        )
    print(f" L1 warmup reset OK; prompt is {len(obs2.prompt)} chars long")
# --------------------------------------------------------------------------- #
def main(argv: Iterable[str] = ()) -> int:
    """Run every registered gate in registration order and summarise.

    A failing gate does not stop the run: its exception ``repr`` is printed on
    a FAIL line and the remaining gates still execute. ``argv`` is accepted
    but not currently read.

    Returns:
        0 if every gate passed, otherwise 1.
    """
    rule = "=" * 60
    print("Qubit-Medic environment validation")
    print(rule)
    t0 = time.monotonic()
    failures = 0
    for gate_name, check in GATES:
        try:
            check()
            _ok(gate_name)
        except Exception as exc:  # noqa: BLE001 - we want to keep going
            _fail(gate_name, repr(exc))
            failures += 1
    elapsed = time.monotonic() - t0
    print(rule)
    if not failures:
        print(f"all {len(GATES)} gates passed in {elapsed:.2f}s")
        return 0
    print(f"{failures} gate(s) failed in {elapsed:.2f}s")
    return 1
if __name__ == "__main__":
    # Use main()'s 0/1 result as the process exit status; raising SystemExit
    # is exactly what sys.exit() does internally.
    raise SystemExit(main(sys.argv[1:]))
|