"""Stim + PyMatching wrapper - the physics engine (Section 2.4 of the plan). This module never makes decoding decisions: it builds circuits, samples syndromes, computes baselines, and exposes the observable's support on the data qubits so the reward functions can score predictions deterministically. Two design choices worth flagging up-front: * The LLM's action is a **terminal Pauli frame** on data qubits (the X and Z errors on each data qubit at the moment of final measurement). This representation is exact for the rotated memory_z task and lets us reuse Stim/PyMatching ground-truth machinery. The trade-off is documented in ``rewards.py``: the syndrome-consistency reward (Reward 2) only constrains the *final-round* detectors. Earlier rounds are silent w.r.t. an end-of-circuit Pauli frame; that is intentional and made explicit in the reward's docstring. * "Ground-truth error pattern" for Reward 3 is taken to be the **PyMatching-most-probable error pattern** explaining the syndrome (extracted via ``Matching.decode_to_edges_array``). This is the near-optimal canonical choice and matches what the AlphaQubit baseline comparison uses. The README's *honesty note* repeats this. """ from __future__ import annotations import hashlib from dataclasses import dataclass from typing import Optional import numpy as np import pymatching import stim from qubit_medic.config import ( CODE_TASK, CurriculumLevel, SI1000Rates, ) # --------------------------------------------------------------------------- # # Circuit + DEM construction # # --------------------------------------------------------------------------- # def build_circuit(level: CurriculumLevel) -> stim.Circuit: """Generate a Stim ``rotated_memory_z`` circuit at the given level.""" rates = SI1000Rates.from_p(level.p) return stim.Circuit.generated( CODE_TASK, distance=level.distance, rounds=level.rounds, **rates.as_stim_kwargs(), ) def build_dem(circuit: stim.Circuit) -> stim.DetectorErrorModel: """Decompose-errors=True is mandatory for PyMatching.""" return circuit.detector_error_model(decompose_errors=True) def dem_digest(dem: stim.DetectorErrorModel) -> str: """8-char digest of the DEM, useful for grouping training logs.""" return hashlib.sha256(str(dem).encode("utf-8")).hexdigest()[:8] # --------------------------------------------------------------------------- # # Layout introspection - figure out data qubits, ancillas, observable support # # --------------------------------------------------------------------------- # @dataclass(frozen=True) class CircuitLayout: """Static facts about a circuit, computed once per episode. Two indexings coexist: * **Stim IDs** (``data_qubits``) are the physical qubit IDs Stim emits (e.g. ``(1, 3, 5, 8, 10, 12, 15, 17, 19)`` for distance-3). These are what Stim/PyMatching speak. * **LLM IDs** are consecutive ``0..num_data_qubits-1``. These are what the prompt advertises and what the LLM emits, because consecutive small ints are dramatically easier for a language model to handle. :meth:`llm_to_stim` and :meth:`stim_to_llm` perform the remap. *All* server-internal scoring uses Stim IDs; the boundary at the prompt formatter / parser converts. """ data_qubits: tuple[int, ...] """Stim IDs of data qubits (measured by terminal ``M``), sorted.""" data_qubit_coords: tuple[tuple[float, float], ...] """(x, y) coordinate of each data qubit, in the same order as ``data_qubits``. Used by Reward 3 to snap PyMatching edges to qubits.""" ancilla_qubits: tuple[int, ...] """Physical qubit IDs that hold stabiliser measurements (``MR``).""" z_observable_support: tuple[int, ...] """Data qubits whose Z value is XOR'd into the logical Z observable. An X error on any of these flips the observable.""" detector_round: tuple[int, ...] """For each detector index, the round it nominally belongs to (0-based, extracted from the ``DETECTOR(x, y, t)`` coordinate).""" detector_coords: tuple[tuple[float, float], ...] """(x, y) coordinate of each detector, used by Reward 3.""" detector_is_x_type: tuple[bool, ...] """Whether the detector watches an X-stabiliser. For the rotated surface code Stim places X-stabilisers at coordinates with ``(x + y) mod 4 == 2`` and Z-stabilisers at ``(x + y) mod 4 == 0`` (verified empirically against Stim 1.15's ``surface_code:rotated_memory_z``).""" final_detectors: tuple[int, ...] """Indices of detectors that correspond to the *last* timeslice - those are the only detectors a terminal Pauli frame can affect (Reward 2).""" num_data_qubits: int num_ancilla_qubits: int num_detectors: int num_observables: int # ----- LLM <-> Stim qubit-ID remapping --------------------------------- def llm_to_stim(self, llm_ids: list[int]) -> list[int]: """Convert consecutive LLM IDs to physical Stim IDs. Out-of-range IDs are silently dropped (the parser already enforces the upper bound, but we double-check here as a defence-in-depth). """ out: list[int] = [] n = len(self.data_qubits) for i in llm_ids: if 0 <= i < n: out.append(self.data_qubits[i]) return out def stim_to_llm(self, stim_ids: list[int]) -> list[int]: """Inverse of :meth:`llm_to_stim` - used to render targets in the SFT data and the imitator policy.""" lookup = {q: i for i, q in enumerate(self.data_qubits)} return [lookup[q] for q in stim_ids if q in lookup] def _walk_measurement_records( circuit: stim.Circuit, ) -> tuple[list[int], list[Optional[str]]]: """Replay the circuit (no sampling) to map each measurement record to a qubit. Returns parallel lists: qubits[i] = qubit id, instr[i] = gate.""" qubits: list[int] = [] instrs: list[Optional[str]] = [] def _walk(c: stim.Circuit, repeats: int = 1) -> None: for _ in range(repeats): for inst in c: if isinstance(inst, stim.CircuitRepeatBlock): _walk(inst.body_copy(), inst.repeat_count) continue name = inst.name if name in { "M", "MX", "MY", "MZ", "MR", "MRX", "MRY", "MRZ", "MPP", }: for t in inst.targets_copy(): if t.is_qubit_target: qubits.append(t.qubit_value) instrs.append(name) _walk(circuit) return qubits, instrs def extract_layout(circuit: stim.Circuit) -> CircuitLayout: """Walk the circuit once to build a full :class:`CircuitLayout`.""" flat = circuit.flattened() measurement_qubits, measurement_instrs = _walk_measurement_records(circuit) # Data qubits = those measured by terminal ``M`` (destructive, no reset). data_qubits_in_order: list[int] = [] seen_data = set() for q, instr in zip(measurement_qubits, measurement_instrs): if instr == "M" and q not in seen_data: data_qubits_in_order.append(q) seen_data.add(q) data_qubits = tuple(sorted(seen_data)) # Ancilla qubits = everything measured by MR (reset after measurement). ancilla_qubits = tuple( sorted({q for q, instr in zip(measurement_qubits, measurement_instrs) if instr == "MR"}) ) # Observable support: walk OBSERVABLE_INCLUDE entries and resolve their # rec[-k] back to qubit IDs via the measurement record table. obs_support: dict[int, set[int]] = {} for inst in flat: if inst.name == "OBSERVABLE_INCLUDE": args = inst.gate_args_copy() obs_idx = int(args[0]) if args else 0 for t in inst.targets_copy(): if t.is_measurement_record_target: actual = len(measurement_qubits) + t.value # value is negative if 0 <= actual < len(measurement_qubits): obs_support.setdefault(obs_idx, set()).add( measurement_qubits[actual] ) z_obs = tuple(sorted(obs_support.get(0, set()))) # Qubit coordinates from QUBIT_COORDS instructions. qubit_coords: dict[int, tuple[float, float]] = {} for inst in flat: if inst.name == "QUBIT_COORDS": args = inst.gate_args_copy() x = float(args[0]) if len(args) >= 1 else 0.0 y = float(args[1]) if len(args) >= 2 else 0.0 for t in inst.targets_copy(): if t.is_qubit_target: qubit_coords[t.qubit_value] = (x, y) data_qubit_coords = tuple(qubit_coords.get(q, (0.0, 0.0)) for q in data_qubits) # Detector coordinates - last value of the tuple is the round index. det_coords_raw = circuit.get_detector_coordinates() num_dets = circuit.num_detectors rounds_per_det: list[int] = [] is_x_type: list[bool] = [] detector_coords: list[tuple[float, float]] = [] for i in range(num_dets): c = det_coords_raw.get(i, ()) if not c: rounds_per_det.append(0) is_x_type.append(False) detector_coords.append((0.0, 0.0)) continue round_idx = int(c[-1]) if len(c) >= 3 else 0 rounds_per_det.append(round_idx) x = float(c[0]) if len(c) >= 1 else 0.0 y = float(c[1]) if len(c) >= 2 else 0.0 detector_coords.append((x, y)) # X-stabilisers sit at (x + y) % 4 == 2 in Stim's generator. is_x_type.append((int(x + y) % 4) == 2) final_round = max(rounds_per_det) if rounds_per_det else 0 final_dets = tuple(i for i, r in enumerate(rounds_per_det) if r == final_round) return CircuitLayout( data_qubits=data_qubits, data_qubit_coords=data_qubit_coords, ancilla_qubits=ancilla_qubits, z_observable_support=z_obs, detector_round=tuple(rounds_per_det), detector_coords=tuple(detector_coords), detector_is_x_type=tuple(is_x_type), final_detectors=final_dets, num_data_qubits=len(data_qubits), num_ancilla_qubits=len(ancilla_qubits), num_detectors=num_dets, num_observables=circuit.num_observables, ) # --------------------------------------------------------------------------- # # Sampling and decoding # # --------------------------------------------------------------------------- # @dataclass(frozen=True) class SyndromeSample: """One noisy episode: detector activations, ground-truth observable flip, and PyMatching's prediction (used by Rewards 1 and 5).""" syndrome_bits: list[int] actual_observable_flip: int pymatching_observable_pred: int pymatching_x_errors: list[int] # Pauli frame at end of circuit (X part) pymatching_z_errors: list[int] # Pauli frame at end of circuit (Z part) def sample_episode( circuit: stim.Circuit, matching: pymatching.Matching, layout: CircuitLayout, seed: Optional[int] = None, ) -> SyndromeSample: """Sample one shot, decode it with PyMatching, and bundle the result.""" sampler = circuit.compile_detector_sampler(seed=seed) detection, observables = sampler.sample(1, separate_observables=True) detection_row = detection[0].astype(np.uint8) observable_flip = int(observables[0, 0]) if observables.shape[1] else 0 # PyMatching's prediction (observable level). pred_obs = int(matching.decode(detection_row)[0]) # PyMatching's predicted physical Pauli frame on data qubits. pred_x, pred_z = pymatching_predicted_pauli_frame( matching=matching, syndrome=detection_row, layout=layout, ) return SyndromeSample( syndrome_bits=detection_row.tolist(), actual_observable_flip=observable_flip, pymatching_observable_pred=pred_obs, pymatching_x_errors=pred_x, pymatching_z_errors=pred_z, ) def pymatching_predicted_pauli_frame( matching: pymatching.Matching, syndrome: np.ndarray, layout: CircuitLayout, ) -> tuple[list[int], list[int]]: """Convert PyMatching's per-edge prediction into a data-qubit Pauli frame. The matching graph's edges correspond to error mechanisms in the DEM. Each edge connects two detectors (or a detector and a boundary). The data qubit responsible for the edge sits geometrically between the two detectors on the surface-code grid - we recover it by snapping the midpoint of the detector coordinates to the nearest data qubit. This frame is used as ground-truth for Reward 3 (Hamming overlap). Z-stabiliser endpoints (``(x+y) mod 4 == 0``) catch X errors on data qubits; X-stabiliser endpoints catch Z errors. Boundary edges are snapped to the unique data qubit adjacent to that boundary. """ try: edges = matching.decode_to_edges_array(syndrome) except Exception: return [], [] if edges is None or len(edges) == 0: return [], [] data_qubits = layout.data_qubits data_coords = layout.data_qubit_coords det_coords = layout.detector_coords det_is_x = layout.detector_is_x_type n_dets = len(det_coords) def _snap(x: float, y: float) -> int: best_q = data_qubits[0] best_d = float("inf") for q, (qx, qy) in zip(data_qubits, data_coords): d = (qx - x) ** 2 + (qy - y) ** 2 if d < best_d: best_d = d best_q = q return best_q x_errs: set[int] = set() z_errs: set[int] = set() for edge in edges: a, b = int(edge[0]), int(edge[1]) ca = det_coords[a] if 0 <= a < n_dets else None cb = det_coords[b] if 0 <= b < n_dets else None if ca is None and cb is None: continue if cb is None: mid_x, mid_y = ca ref_is_x = det_is_x[a] elif ca is None: mid_x, mid_y = cb ref_is_x = det_is_x[b] else: mid_x = (ca[0] + cb[0]) / 2.0 mid_y = (ca[1] + cb[1]) / 2.0 ref_is_x = det_is_x[a] if 0 <= a < n_dets else det_is_x[b] snap = _snap(mid_x, mid_y) if ref_is_x: z_errs.add(snap) else: x_errs.add(snap) return sorted(x_errs), sorted(z_errs) # --------------------------------------------------------------------------- # # Predicted-observable computation (used by Reward 1) # # --------------------------------------------------------------------------- # def predicted_observable_flip( predicted_x_qubits: list[int], layout: CircuitLayout, ) -> int: """Compute the implied logical-Z flip from a predicted Pauli frame. Only X errors on data qubits in ``z_observable_support`` matter for the Z observable - Z errors on data qubits commute with the destructive Z measurement and so cannot flip the observable. """ support = set(layout.z_observable_support) parity = 0 for q in predicted_x_qubits: if q in support: parity ^= 1 return parity def rectify_pauli_frame_to_observable( x_errors: list[int], z_errors: list[int], target_observable_flip: int, layout: CircuitLayout, ) -> tuple[list[int], list[int]]: """Adjust a predicted X-error frame so its implied observable matches. Used by the SFT data generator and the PyMatching imitator policy: the snap-to-data-qubit edge mapping (:func:`pymatching_predicted_pauli_frame`) is only ~95% faithful, but PyMatching's *observable* prediction is exact. We patch the X frame by toggling the smallest-degree data qubit on the observable support whenever the implied parity disagrees with the target. Z errors are untouched because they don't affect a Z observable. """ implied = predicted_observable_flip(x_errors, layout) if implied == target_observable_flip: return list(x_errors), list(z_errors) support = list(layout.z_observable_support) if not support: return list(x_errors), list(z_errors) x_set = set(x_errors) intersect = sorted(x_set & set(support)) if intersect: # Remove the smallest one currently flipping the observable. x_set.discard(intersect[0]) else: # Add the smallest support qubit to introduce a flip. x_set.add(support[0]) return sorted(x_set), list(z_errors) # --------------------------------------------------------------------------- # # Stabiliser counts - derived from layout # # --------------------------------------------------------------------------- # def detector_round_split(layout: CircuitLayout, syndrome_bits: list[int]) -> dict[int, list[int]]: """Group detector bits by their nominal round (used for prompt formatting).""" out: dict[int, list[int]] = {} for idx, bit in enumerate(syndrome_bits): r = layout.detector_round[idx] if idx < len(layout.detector_round) else 0 out.setdefault(r, []).append(bit) return out def per_round_x_z_counts(layout: CircuitLayout) -> tuple[int, int]: """Best-effort count of X-type and Z-type stabiliser detectors per round. For a rotated surface code at distance d there are (d^2-1)/2 of each type. We compute that from the layout to be robust. """ # Take one fully-populated round (the one with the most detectors). round_counts: dict[int, list[bool]] = {} for idx, r in enumerate(layout.detector_round): round_counts.setdefault(r, []).append(layout.detector_is_x_type[idx]) if not round_counts: return 0, 0 full_round = max(round_counts.values(), key=len) n_x = sum(1 for v in full_round if v) n_z = sum(1 for v in full_round if not v) return n_x, n_z