| """The five reward functions (Section 3 of the plan). | |
| Design contract (from Section 3.6): | |
| * Each reward is a pure function ``(action, state, layout) -> float in [0, 1]``. | |
| * Rewards never observe each other - they're independent by construction so | |
| the LLM can't satisfy one at the expense of another without genuine task | |
| understanding. | |
| * The combined reward is a weighted sum (weights in :mod:`qubit_medic.config`) | |
| clamped to ``[0, 1]``. | |
| * Every per-component score is reported in the ``info`` dict so logs can | |
| surface reward-hacking early (Section 3.7). | |
| A note on Reward 2 and Reward 3 ground truth - see ``physics.py``: the LLM | |
| predicts a *terminal Pauli frame*, which fully determines the logical-Z | |
| observable but only constrains the *final-round* detectors. Earlier rounds' | |
| detectors are intentionally unscored. Reward 3 compares against PyMatching's | |
| near-optimal Pauli-frame prediction (the canonical decoder reference used in | |
| AlphaQubit's Nature paper). | |
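
Worked example of the combination (hypothetical weights - the real values
live in :mod:`qubit_medic.config`): with weights 0.4 / 0.2 / 0.2 / 0.1 / 0.1
for Rewards 1-5 and per-component scores ``(1.0, 0.8, 0.5, 1.0, 0.0)``, the
total is ``0.4*1.0 + 0.2*0.8 + 0.2*0.5 + 0.1*1.0 + 0.1*0.0 = 0.76``.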
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from qubit_medic.config import REWARD_WEIGHTS | |
| from qubit_medic.prompts import ParseResult | |
| from qubit_medic.server.physics import ( | |
| CircuitLayout, | |
| SyndromeSample, | |
| predicted_observable_flip, | |
| ) | |
| # --------------------------------------------------------------------------- # | |
| # Reward 1: logical correction success # | |
| # --------------------------------------------------------------------------- # | |
def reward_logical_correction(
    parsed: ParseResult,
    sample: SyndromeSample,
    layout: CircuitLayout,
) -> float:
    """Did the predicted correction preserve the logical state?

    Apply the predicted X errors as a Pauli frame at end-of-circuit and
    compute the implied observable flip. If this matches the actual
    observable flip recorded by Stim, the logical state was preserved.
    Returns 1.0 if so, else 0.0.

    This is the unfakeable reward - it depends only on Stim's ground truth.
    """
    implied = predicted_observable_flip(parsed.x_errors, layout)
    return 1.0 if implied == sample.actual_observable_flip else 0.0


# --------------------------------------------------------------------------- #
# Reward 2: syndrome consistency                                              #
# --------------------------------------------------------------------------- #
def _syndrome_from_pauli_frame(
    x_errors: list[int],
    layout: CircuitLayout,
    final_detector_supports: dict[int, frozenset[int]],
) -> dict[int, int]:
    """Compute the implied bits for FINAL-round detectors only.

    A terminal X error on data qubit ``q`` flips a final-round Z-stabiliser
    detector iff ``q`` is in that detector's support.
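
    Example with hypothetical supports (``layout`` is accepted for call-site
    symmetry but never read by this helper, so ``None`` suffices here):

    >>> _syndrome_from_pauli_frame(
    ...     [1, 3], None, {10: frozenset({1, 2}), 11: frozenset({1, 3})}
    ... )
    {10: 1, 11: 0}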
| """ | |
| out: dict[int, int] = {} | |
| x_set = set(x_errors) | |
| for det_idx, support in final_detector_supports.items(): | |
| out[det_idx] = 1 if len(x_set & support) % 2 == 1 else 0 | |
| return out | |


def reward_syndrome_consistency(
    parsed: ParseResult,
    sample: SyndromeSample,
    layout: CircuitLayout,
    final_detector_supports: dict[int, frozenset[int]],
) -> float:
    """How well does the predicted Pauli frame reproduce the FINAL detectors?

    Computes Hamming similarity between ``predicted_final_bits`` (induced
    by the predicted X errors) and ``observed_final_bits``. Returns
    ``1 - hamming_distance / num_final_detectors``.
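
    Illustrative numbers: 8 final detectors with 2 mismatched bits score
    ``1 - 2/8 = 0.75`` before the cap below is considered.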

    Rationale (Section 3.2): without this term, an LLM that lucky-guesses
    the right qubits could get Reward 1 occasionally; this signal forces
    it to also explain the data the syndrome carries.

    2026-04 anti-collapse cap (FIX 1, RL spec rewrite): if the prediction
    is empty AND the observed syndrome is non-empty (at least one
    detector fired), cap the score at 0.5. Without this cap, the
    "always predict empty" policy can still pull a high syndrome-
    consistency score on the prompts where the implied final-round bits
    happen to coincide with zeros, which kept GRPO trapped in the
    constant-empty mode.
| """ | |
| final_dets = layout.final_detectors | |
| if not final_dets: | |
| return 0.0 | |
| implied = _syndrome_from_pauli_frame( | |
| parsed.x_errors, layout, final_detector_supports | |
| ) | |
| distance = 0 | |
| for det_idx in final_dets: | |
| observed = sample.syndrome_bits[det_idx] | |
| predicted = implied.get(det_idx, 0) | |
| if observed != predicted: | |
| distance += 1 | |
| base = 1.0 - distance / len(final_dets) | |
| # Anti-collapse cap: empty prediction + non-empty observed syndrome | |
| # is a "did nothing while alarms were firing" failure mode. Cap at | |
| # 0.5 so the empty policy can never approach the full 1.0 even when | |
| # the implied final-round bits happen to coincide. | |
| pred_is_empty = (not parsed.x_errors) and (not parsed.z_errors) | |
| has_active_syndrome = any(int(b) != 0 for b in sample.syndrome_bits) | |
| if pred_is_empty and has_active_syndrome: | |
| return min(base, 0.5) | |
| return base | |


def compute_final_detector_supports(
    layout: CircuitLayout,
    syndrome_bits_unused: list[int] | None = None,  # kept only for API symmetry
    *,
    detector_to_data_qubits: dict[int, frozenset[int]] | None = None,
) -> dict[int, frozenset[int]]:
    """Map each final-round detector to the set of data qubits whose
    terminal X error flips it.

    For the rotated memory_z code, each Z-stabiliser final detector watches
    the four (or, on the boundary, two) data qubits adjacent to it on the
    grid. We compute adjacency by Euclidean distance: data qubits at
    distance ``sqrt(2)`` from a Z-stabiliser ancilla coordinate are
    incident.
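
    As an illustration (assuming Stim's usual rotated-code coordinates,
    with data qubits at odd/odd and ancillas at even/even positions): a
    Z ancilla at ``(2, 2)`` collects the data qubits at ``(1, 1)``,
    ``(1, 3)``, ``(3, 1)`` and ``(3, 3)``, each at squared distance
    exactly ``2.0``.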
| """ | |
| if detector_to_data_qubits is not None: | |
| return detector_to_data_qubits | |
| out: dict[int, frozenset[int]] = {} | |
| for det_idx in layout.final_detectors: | |
| dx, dy = layout.detector_coords[det_idx] | |
| adj: set[int] = set() | |
| for q, (qx, qy) in zip(layout.data_qubits, layout.data_qubit_coords): | |
| if abs((qx - dx) ** 2 + (qy - dy) ** 2 - 2.0) < 1e-6: | |
| adj.add(q) | |
| out[det_idx] = frozenset(adj) | |
| return out | |


# --------------------------------------------------------------------------- #
# Reward 3: Hamming overlap with reference Pauli frame                        #
# --------------------------------------------------------------------------- #
def _set_aware_jaccard(true_set: list[int], pred_set: list[int]) -> float:
    """Set-aware Jaccard: penalises BOTH false alarms and missed errors.

    2026-04 spec rewrite (FIX 1). The four-case rule is what makes
    "predict empty everywhere" stop being a near-optimal strategy:

    +-----------+-----------+---------------------------------------+
    | true_set  | pred_set  | score                                 |
    +-----------+-----------+---------------------------------------+
    | empty     | empty     | 1.0 (perfect, "no errors -> no edit") |
    | empty     | non-empty | 0.0 false alarm                       |
    | non-empty | empty     | 0.0 missed errors <-- the key change  |
    | non-empty | non-empty | |inter| / |union| (standard Jaccard)  |
    +-----------+-----------+---------------------------------------+

    Critically, the third case used to score 1.0 under the prior plain
    Jaccard (because both sets were treated symmetrically; "everything
    correct, just nothing predicted" was indistinguishable from "perfect
    agreement"). Under this rule a missed-error answer scores 0.0,
    which moves the GRPO reward landscape so a non-trivial prediction
    can climb out of the empty-everywhere local optimum.
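
    A few sanity checks of the rule (runnable as doctests):

    >>> _set_aware_jaccard([], [])
    1.0
    >>> _set_aware_jaccard([3], [])
    0.0
    >>> _set_aware_jaccard([1, 2, 3], [2, 3, 4])
    0.5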
| """ | |
| sa, sp = set(true_set), set(pred_set) | |
| if not sa and not sp: | |
| return 1.0 # perfect agreement: no true errors AND no claimed errors | |
| if not sa and sp: | |
| return 0.0 # false alarm: claimed errors that were not there | |
| if sa and not sp: | |
| return 0.0 # missed errors: alarms fired but model said nothing | |
| inter = len(sa & sp) | |
| union = len(sa | sp) | |
| return inter / union if union else 1.0 | |


def reward_hamming_overlap(
    parsed: ParseResult,
    sample: SyndromeSample,
    layout: CircuitLayout,
) -> float:
    """Average of set-aware Jaccard(X) and set-aware Jaccard(Z) against
    the reference Pauli frame carried by ``SyndromeSample``.

    The reference frame lives on
    ``sample.pymatching_x_errors`` / ``sample.pymatching_z_errors`` -
    in this codebase that frame is treated as the ground-truth target
    (the SFT/GRPO dataset builders fill it from the same source as the
    JSONL ``true_x_errors`` / ``true_z_errors`` fields). Per-axis score
    uses the set-aware rule (see :func:`_set_aware_jaccard`), so missed
    errors no longer score 1.0 just because the prediction set is empty.
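
    Illustration: a perfect X prediction and a half-overlapping Z
    prediction (``jx = 1.0``, ``jz = 0.5``) average to ``0.75``.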
| """ | |
| jx = _set_aware_jaccard(sample.pymatching_x_errors, parsed.x_errors) | |
| jz = _set_aware_jaccard(sample.pymatching_z_errors, parsed.z_errors) | |
| return 0.5 * (jx + jz) | |


# --------------------------------------------------------------------------- #
# Reward 4: format compliance                                                 #
# --------------------------------------------------------------------------- #
def reward_format_compliance(parsed: ParseResult) -> float:
    """Binary {0.0, 1.0}: 1.0 iff the parser fully extracted both lists.

    2026-04 spec rewrite (FIX 1): partial credit (0.5) is removed. With
    partial credit on, the model could still earn ~half the format
    weight on garbage outputs that resembled the canonical form, which
    is part of what kept the reward landscape too flat for GRPO to
    escape the empty-everywhere mode. The new rule rewards only a
    cleanly-parsed answer.
    """
    return 1.0 if parsed.parse_success else 0.0


# --------------------------------------------------------------------------- #
# Reward 5: PyMatching beat-rate bonus                                        #
# --------------------------------------------------------------------------- #
def reward_pymatching_beat(
    parsed: ParseResult,
    sample: SyndromeSample,
    layout: CircuitLayout,
) -> float:
    """1.0 iff PyMatching got this syndrome wrong AND the LLM got it right.

    This is the headline metric (Section 3.5). For most of training it'll
    be near zero; the trajectory of its mean over steps is the proof we've
    moved past pure imitation.
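
    The gate, spelled out:

        PyMatching correct,   LLM anything    -> 0.0
        PyMatching incorrect, LLM correct     -> 1.0
        PyMatching incorrect, LLM incorrect   -> 0.0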
| """ | |
| pm_correct = sample.pymatching_observable_pred == sample.actual_observable_flip | |
| if pm_correct: | |
| return 0.0 | |
| llm_implied = predicted_observable_flip(parsed.x_errors, layout) | |
| return 1.0 if llm_implied == sample.actual_observable_flip else 0.0 | |


# --------------------------------------------------------------------------- #
# Combined reward                                                             #
# --------------------------------------------------------------------------- #
@dataclass
class RewardBreakdown:
    """Per-component scores plus the weighted total."""

    logical_correction: float
    syndrome_consistency: float
    hamming_overlap: float
    format_compliance: float
    pymatching_beat: float
    total: float

    def as_dict(self) -> dict[str, float]:
        return {
            "logical_correction": self.logical_correction,
            "syndrome_consistency": self.syndrome_consistency,
            "hamming_overlap": self.hamming_overlap,
            "format_compliance": self.format_compliance,
            "pymatching_beat": self.pymatching_beat,
            "total": self.total,
        }


def compute_all_rewards(
    parsed: ParseResult,
    sample: SyndromeSample,
    layout: CircuitLayout,
    final_detector_supports: dict[int, frozenset[int]],
    weights: dict[str, float] = REWARD_WEIGHTS,
) -> RewardBreakdown:
    """Compute all five rewards and the weighted total.

    Returns a :class:`RewardBreakdown` whose ``as_dict`` is what the env's
    ``info`` payload contains. The trainer logs each component separately.
    """
    r1 = reward_logical_correction(parsed, sample, layout)
    r2 = reward_syndrome_consistency(parsed, sample, layout, final_detector_supports)
    r3 = reward_hamming_overlap(parsed, sample, layout)
    r4 = reward_format_compliance(parsed)
    r5 = reward_pymatching_beat(parsed, sample, layout)
    total = (
        weights["logical_correction"] * r1
        + weights["syndrome_consistency"] * r2
        + weights["hamming_overlap"] * r3
        + weights["format_compliance"] * r4
        + weights["pymatching_beat"] * r5
    )
    total = max(0.0, min(1.0, total))
    return RewardBreakdown(
        logical_correction=r1,
        syndrome_consistency=r2,
        hamming_overlap=r3,
        format_compliance=r4,
        pymatching_beat=r5,
        total=total,
    )
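

# Minimal wiring sketch (illustrative, not executable as-is: ``layout`` and
# ``sample`` come from qubit_medic.server.physics and ``parsed`` from the
# prompt parser; their constructors aren't shown here):
#
#     supports = compute_final_detector_supports(layout)
#     breakdown = compute_all_rewards(parsed, sample, layout, supports)
#     info["rewards"] = breakdown.as_dict()  # per-component logging (Section 3.7)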