File size: 5,575 Bytes
7b91523
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""scripts/validate_env.py - the Section 1.1 environment validation.

Five gates, in order:

    1.  Imports succeed (catches install issues).
    2.  Stim generates a tiny distance-3 surface code.
    3.  PyMatching decodes 100 syndromes.
    4.  Logical-error rate at p=0.001 is in the expected range.
    5.  ``DecoderEnvironment`` reset+step works end-to-end (proves the wire
        contract is intact).

Run with::

    .venv/bin/python -m scripts.validate_env

Exit code is 0 iff every gate passes. The participant guide explicitly
warns: *"if any of these fail on any team member's machine, fix it now -
not at 11pm on Day 1."*
"""
from __future__ import annotations

import sys
import time
from typing import Iterable

# Ordered registry of (gate name, check callable) pairs. ``gate`` appends to
# it when a decorated function is defined, and ``main`` runs the entries in
# list order.
GATES = []


def gate(name: str):
    """Return a decorator that registers a function in ``GATES`` under *name*.

    The decorated function is appended as ``(name, function)`` and handed
    back unchanged, so it stays directly callable under its own name.
    """
    def register(check):
        GATES.append((name, check))
        return check

    return register


def _ok(name: str, msg: str = "") -> None:
    """Print the PASS line for gate *name*, appending *msg* when non-empty."""
    suffix = "" if not msg else f" {msg}"
    print(f"  PASS  {name}{suffix}")


def _fail(name: str, msg: str) -> None:
    """Print the FAIL line for gate *name* followed by the detail *msg*."""
    line = f"  FAIL  {name} -- {msg}"
    print(line)


# --------------------------------------------------------------------------- #


@gate("imports")
def _imports() -> None:
    """Gate 1: import every third-party and project module used by the env.

    Nothing is caught here, so any import error propagates to the caller.
    On success the versions of stim, pymatching and qubit_medic are printed.
    """
    # Third-party dependencies (imported only to prove they are installed).
    import stim
    import pymatching
    import numpy  # noqa: F401
    import fastapi  # noqa: F401
    import pydantic  # noqa: F401

    # Project package and each submodule the later gates rely on.
    import qubit_medic
    import qubit_medic.config
    import qubit_medic.models
    import qubit_medic.prompts
    import qubit_medic.server.physics
    import qubit_medic.server.rewards
    import qubit_medic.server.curriculum
    import qubit_medic.server.environment

    versions = (
        f"stim={stim.__version__}  pymatching={pymatching.__version__}  "
        f"qubit_medic={qubit_medic.__version__}"
    )
    print(f"      {versions}")


@gate("stim_circuit_generation")
def _stim_gen() -> None:
    """Gate 2: build the primary-level circuit and check its layout.

    Builds the circuit for ``primary_level()``, derives its detector error
    model and its layout, then checks that the layout reports 9 data qubits,
    8 ancilla qubits and a Z-observable support of ``(1, 3, 5)``. On success
    a one-line summary is printed.

    Raises:
        AssertionError: if any layout value differs from the expected one.
            The checks raise explicitly instead of using ``assert`` so they
            still run when Python is started with ``-O``.
    """
    from qubit_medic.config import primary_level
    from qubit_medic.server.physics import build_circuit, build_dem, extract_layout

    c = build_circuit(primary_level())
    dem = build_dem(c)
    layout = extract_layout(c)

    if layout.num_data_qubits != 9:
        raise AssertionError(
            f"expected 9 data qubits, got {layout.num_data_qubits}"
        )
    if layout.num_ancilla_qubits != 8:
        raise AssertionError(
            f"expected 8 ancilla qubits, got {layout.num_ancilla_qubits}"
        )
    if layout.z_observable_support != (1, 3, 5):
        raise AssertionError(
            f"expected z_observable_support (1, 3, 5), "
            f"got {layout.z_observable_support}"
        )

    print(f"      circuit={len(str(c))} chars, DEM={len(str(dem))} chars, "
          f"obs_support={layout.z_observable_support}")


@gate("pymatching_decoding_100")
def _pm_decoding() -> None:
    """Gate 3: sample 100 shots and decode them with PyMatching.

    No threshold is applied in this gate; the measured logical-error rate is
    only printed. Any exception raised while building, sampling or decoding
    propagates to the caller.
    """
    import pymatching
    import numpy as np
    from qubit_medic.config import primary_level
    from qubit_medic.server.physics import build_circuit, build_dem

    circuit = build_circuit(primary_level())
    error_model = build_dem(circuit)

    # Fixed seed so repeated runs sample the same shots.
    shot_sampler = circuit.compile_detector_sampler(seed=42)
    detections, observables = shot_sampler.sample(
        100, separate_observables=True
    )

    matcher = pymatching.Matching.from_detector_error_model(error_model)
    predictions = matcher.decode_batch(detections)

    # A shot counts as an error when any predicted value differs from the
    # sampled one along axis 1.
    shot_is_wrong = np.any(predictions != observables, axis=1)
    err_rate = float(np.mean(shot_is_wrong))
    print(f"      logical-error rate (100 shots): {err_rate:.4f}")


@gate("ler_in_expected_range")
def _ler_range() -> None:
    """Gate 4: check the PyMatching logical-error rate over 5000 shots.

    The stated intent of this gate is that at distance 3 and p=0.001 the
    rate should stay below 1%; the actual circuit parameters come from
    ``primary_level()``.

    Raises:
        AssertionError: if the measured rate lies outside ``[0.0, 0.01]``.
    """
    import pymatching
    import numpy as np
    from qubit_medic.config import primary_level
    from qubit_medic.server.physics import build_circuit, build_dem

    lower_bound = 0.0
    upper_bound = 0.01

    circuit = build_circuit(primary_level())
    error_model = build_dem(circuit)

    # Fixed seed so repeated runs sample the same shots.
    shot_sampler = circuit.compile_detector_sampler(seed=2024)
    detections, observables = shot_sampler.sample(
        5000, separate_observables=True
    )

    matcher = pymatching.Matching.from_detector_error_model(error_model)
    predictions = matcher.decode_batch(detections)
    shot_is_wrong = np.any(predictions != observables, axis=1)
    rate = float(np.mean(shot_is_wrong))

    within_range = lower_bound <= rate <= upper_bound
    if not within_range:
        raise AssertionError(
            f"PyMatching LER {rate:.4f} outside [{lower_bound}, {upper_bound}]"
        )
    print(f"      PyMatching LER on 5000 shots: {rate:.4f} "
          f"(expected ~0.001 - 0.01)")


@gate("decoder_environment_roundtrip")
def _env_roundtrip() -> None:
    """Gate 5: reset+step round-trip through ``LocalDecoderClient``.

    Episode 1 is reset at forced level ``"L2_target"`` and stepped once with
    an all-zeros completion (``format_completion([], [])``). Episode 2 is
    only reset, at forced level ``"L1_warmup"``, to confirm a second level
    can be selected.

    Raises:
        AssertionError: if an observation reports an unexpected distance,
            round count or curriculum level, if the step does not finish the
            episode, or if the step info has no ``"rewards"`` entry. The
            checks raise explicitly instead of using ``assert`` so they
            still run when Python is started with ``-O``.
    """
    from qubit_medic.client.client import LocalDecoderClient
    from qubit_medic.prompts import format_completion

    client = LocalDecoderClient()
    obs = client.reset(forced_level="L2_target", seed=1)
    if not (obs.distance == 3 and obs.rounds == 3):
        raise AssertionError(
            f"L2_target: expected distance=3 and rounds=3, "
            f"got distance={obs.distance} rounds={obs.rounds}"
        )
    if obs.curriculum_level != "L2_target":
        raise AssertionError(
            f"expected curriculum_level 'L2_target', "
            f"got {obs.curriculum_level!r}"
        )

    # All-zeros policy: claim no errors.
    result = client.step(
        raw_response=format_completion([], []),
        episode_id=obs.episode_id,
    )
    if result.done is not True:
        raise AssertionError(
            f"expected step to end the episode, got done={result.done!r}"
        )
    if "rewards" not in result.info:
        raise AssertionError("step info is missing the 'rewards' entry")
    print(f"      reset->step round-trip ok; "
          f"all-zeros total reward={result.reward:.3f}, "
          f"breakdown={result.info['rewards']}")

    # Trivial second episode under forced L1.
    obs2 = client.reset(forced_level="L1_warmup", seed=2)
    if not (obs2.distance == 3 and obs2.rounds == 1):
        raise AssertionError(
            f"L1_warmup: expected distance=3 and rounds=1, "
            f"got distance={obs2.distance} rounds={obs2.rounds}"
        )
    print(f"      L1 warmup reset OK; prompt is {len(obs2.prompt)} chars long")


# --------------------------------------------------------------------------- #


def main(argv: Iterable[str] = ()) -> int:
    """Run every registered gate in order and print a PASS/FAIL report.

    A gate that raises any ``Exception`` is reported as failed and the run
    continues with the next gate.

    Args:
        argv: Command-line arguments. Accepted for the entry point's sake
            but not read by this function.

    Returns:
        0 if every gate passed, 1 if at least one gate failed.
    """
    rule = "=" * 60
    print("Qubit-Medic environment validation")
    print(rule)

    failed_gates = []
    t0 = time.monotonic()
    for gate_name, check in GATES:
        try:
            check()
            _ok(gate_name)
        except Exception as err:  # noqa: BLE001 - we want to keep going
            _fail(gate_name, repr(err))
            failed_gates.append(gate_name)
    duration = time.monotonic() - t0

    print(rule)
    if not failed_gates:
        print(f"all {len(GATES)} gates passed in {duration:.2f}s")
        return 0
    print(f"{len(failed_gates)} gate(s) failed in {duration:.2f}s")
    return 1


if __name__ == "__main__":
    # Script entry point: the process exit status is main()'s return value.
    raise SystemExit(main(sys.argv[1:]))