"""The five reward functions (Section 3 of the plan).
Design contract (from Section 3.6):
* Each reward is a pure function ``(action, state, layout) -> float in [0, 1]``.
* Rewards never observe each other - they're independent by construction so
the LLM can't satisfy one at the expense of another without genuine task
understanding.
* The combined reward is a weighted sum (weights in :mod:`qubit_medic.config`)
clamped to ``[0, 1]``.
* Every per-component score is reported in the ``info`` dict so logs can
surface reward-hacking early (Section 3.7).
A note on Reward 2 and Reward 3 ground truth - see ``physics.py``: the LLM
predicts a *terminal Pauli frame*, which fully determines the logical-Z
observable but only constrains the *final-round* detectors. Earlier rounds'
detectors are intentionally unscored. Reward 3 compares against PyMatching's
near-optimal Pauli-frame prediction (the canonical decoder reference used in
AlphaQubit's Nature paper).
"""
from __future__ import annotations
from dataclasses import dataclass
from qubit_medic.config import REWARD_WEIGHTS
from qubit_medic.prompts import ParseResult
from qubit_medic.server.physics import (
CircuitLayout,
SyndromeSample,
predicted_observable_flip,
)
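
# The weights imported above live in qubit_medic.config. Illustrative shape
# only (hypothetical values, not the real config):
#   REWARD_WEIGHTS = {"logical_correction": 0.4, "syndrome_consistency": 0.2,
#                     "hamming_overlap": 0.2, "format_compliance": 0.1,
#                     "pymatching_beat": 0.1}
# The keys must match those read in compute_all_rewards at the bottom of
# this module.
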
# --------------------------------------------------------------------------- #
# Reward 1: logical correction success #
# --------------------------------------------------------------------------- #
def reward_logical_correction(
parsed: ParseResult,
sample: SyndromeSample,
layout: CircuitLayout,
) -> float:
"""Did the predicted correction preserve the logical state?
Apply the predicted X errors as a Pauli frame at end-of-circuit and
compute the implied observable flip. If this matches the actual
observable flip recorded by Stim, the logical state was preserved.
Outputs 1.0 if so, else 0.0.
This is the unfakeable reward - it depends only on Stim's ground truth.
"""
implied = predicted_observable_flip(parsed.x_errors, layout)
return 1.0 if implied == sample.actual_observable_flip else 0.0
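
# Worked example (hypothetical numbers, assuming the standard parity rule:
# an X frame flips logical Z iff it overlaps the observable's support an
# odd number of times): if logical-Z is supported on data qubits {0, 1, 2}
# and the model predicts x_errors=[1], the implied flip is 1 (odd overlap);
# the reward is 1.0 iff Stim recorded an actual flip on this shot as well.
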
# --------------------------------------------------------------------------- #
# Reward 2: syndrome consistency #
# --------------------------------------------------------------------------- #
def _syndrome_from_pauli_frame(
x_errors: list[int],
layout: CircuitLayout,  # accepted for call-site symmetry; unused in the body
final_detector_supports: dict[int, frozenset[int]],
) -> dict[int, int]:
"""Compute the implied bits for FINAL-round detectors only.
A terminal X error on data qubit ``q`` flips a final-round Z-stabiliser
detector iff ``q`` is in that detector's support.
"""
out: dict[int, int] = {}
x_set = set(x_errors)
for det_idx, support in final_detector_supports.items():
out[det_idx] = 1 if len(x_set & support) % 2 == 1 else 0
return out
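
# Worked example (hypothetical supports): with final detectors {10, 11}
# supported on data qubits {1, 2, 4, 5} and {4, 5, 7, 8} respectively, a
# predicted frame x_errors=[4] flips both (odd overlap of 1 each):
#   _syndrome_from_pauli_frame([4], layout, {10: frozenset({1, 2, 4, 5}),
#                                            11: frozenset({4, 5, 7, 8})})
#   -> {10: 1, 11: 1}
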
def reward_syndrome_consistency(
parsed: ParseResult,
sample: SyndromeSample,
layout: CircuitLayout,
final_detector_supports: dict[int, frozenset[int]],
) -> float:
"""How well does the predicted Pauli frame reproduce the FINAL detectors?
Computes Hamming similarity between ``predicted_final_bits`` (induced
by the predicted X errors) and ``observed_final_bits``. Returns
``1 - hamming_distance / num_final_detectors``.
Rationale (Section 3.2): without this term, an LLM that lucky-guesses
the right qubits could get Reward 1 occasionally; this signal forces
it to also explain the data the syndrome carries.
2026-04 anti-collapse cap (FIX 1, RL spec rewrite): if the prediction
is empty AND the observed syndrome is non-empty (at least one
detector fired), cap the score at 0.5. Without this cap, the
"always predict empty" policy can still pull a high syndrome-
consistency score on the prompts where the implied final-round bits
happen to coincide with zeros, which kept GRPO trapped in the
constant-empty mode.
"""
final_dets = layout.final_detectors
if not final_dets:
return 0.0
implied = _syndrome_from_pauli_frame(
parsed.x_errors, layout, final_detector_supports
)
distance = 0
for det_idx in final_dets:
observed = sample.syndrome_bits[det_idx]
predicted = implied.get(det_idx, 0)
if observed != predicted:
distance += 1
base = 1.0 - distance / len(final_dets)
# Anti-collapse cap: empty prediction + non-empty observed syndrome
# is a "did nothing while alarms were firing" failure mode. Cap at
# 0.5 so the empty policy can never approach the full 1.0 even when
# the implied final-round bits happen to coincide.
pred_is_empty = (not parsed.x_errors) and (not parsed.z_errors)
has_active_syndrome = any(int(b) != 0 for b in sample.syndrome_bits)
if pred_is_empty and has_active_syndrome:
return min(base, 0.5)
return base
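
# Worked example (hypothetical numbers): with 4 final detectors and the
# implied bits disagreeing with the observed bits on exactly 1 of them,
# the base score is 1 - 1/4 = 0.75. If the prediction were empty while any
# detector had fired, the anti-collapse cap would return min(0.75, 0.5) = 0.5.
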
def compute_final_detector_supports(
layout: CircuitLayout,
syndrome_bits_unused: list[int] | None = None, # API symmetry
*,
detector_to_data_qubits: dict[int, frozenset[int]] | None = None,
) -> dict[int, frozenset[int]]:
"""Map each final-round detector to the set of data qubits whose
terminal X error flips it.
For the rotated memory_z code, each Z-stabiliser final detector watches
the four (or two/one on the boundary) data qubits adjacent to it on the
grid. We compute adjacency by Euclidean distance; data qubits at
distance ``sqrt(2)`` from a Z-stabiliser ancilla coordinate are
incident.
"""
if detector_to_data_qubits is not None:
return detector_to_data_qubits
out: dict[int, frozenset[int]] = {}
for det_idx in layout.final_detectors:
dx, dy = layout.detector_coords[det_idx]
adj: set[int] = set()
for q, (qx, qy) in zip(layout.data_qubits, layout.data_qubit_coords):
if abs((qx - dx) ** 2 + (qy - dy) ** 2 - 2.0) < 1e-6:
adj.add(q)
out[det_idx] = frozenset(adj)
return out
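
# Geometry sanity check (hypothetical coordinates, assuming Stim's usual
# rotated-code layout where data qubits sit diagonally around each ancilla):
# a Z ancilla at (3, 3) picks up data qubits at (2, 2), (2, 4), (4, 2) and
# (4, 4), each at squared distance 2.0 from it, i.e. Euclidean sqrt(2),
# which is exactly the test on the line above.
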
# --------------------------------------------------------------------------- #
# Reward 3: Hamming overlap with reference Pauli frame #
# --------------------------------------------------------------------------- #
def _set_aware_jaccard(true_set: list[int], pred_set: list[int]) -> float:
"""Set-aware Jaccard: penalises BOTH false alarms and missed errors.
2026-04 spec rewrite (FIX 1). The four-case rule is what makes
"predict empty everywhere" stop being a near-optimal strategy:
+-------------+-----------+-----------------------------------------+
| true_set | pred_set | score |
+-------------+-----------+-----------------------------------------+
| empty | empty | 1.0 (perfect, "no errors -> no edit") |
| empty | non-empty | 0.0 false alarm |
| non-empty | empty | 0.0 missed errors <-- the key change |
| non-empty | non-empty | |inter| / |union| (standard Jaccard) |
+-------------+-----------+-----------------------------------------+
Critically the third case used to score 1.0 under the prior plain
Jaccard (because both sets were treated symmetrically; "everything
correct, just nothing predicted" was indistinguishable from "perfect
agreement"). Under this rule a missed-error answer scores 0.0,
which moves the GRPO reward landscape so a non-trivial prediction
can climb out of the empty-everywhere local optimum.
"""
sa, sp = set(true_set), set(pred_set)
if not sa and not sp:
return 1.0 # perfect agreement: no true errors AND no claimed errors
if not sa and sp:
return 0.0 # false alarm: claimed errors that were not there
if sa and not sp:
return 0.0 # missed errors: alarms fired but model said nothing
inter = len(sa & sp)
union = len(sa | sp)
return inter / union if union else 1.0
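
# Sanity checks of the four-case rule (hypothetical inputs):
#   _set_aware_jaccard([], [])                  -> 1.0   both empty: perfect
#   _set_aware_jaccard([], [3])                 -> 0.0   false alarm
#   _set_aware_jaccard([3, 7], [])              -> 0.0   missed errors
#   _set_aware_jaccard([3, 7], [3, 7, 11, 15])  -> 0.5   inter=2 / union=4
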
def reward_hamming_overlap(
parsed: ParseResult,
sample: SyndromeSample,
layout: CircuitLayout,
) -> float:
"""Average of set-aware Jaccard(X) and set-aware Jaccard(Z) against
the reference Pauli frame carried by ``SyndromeSample``.
The reference frame lives on
``sample.pymatching_x_errors`` / ``sample.pymatching_z_errors`` —
in this codebase that frame is treated as the ground-truth target
(the SFT/GRPO dataset builders fill it from the same source as the
JSONL ``true_x_errors`` / ``true_z_errors`` fields). Per-axis score
uses the set-aware rule (see :func:`_set_aware_jaccard`), so missed
errors no longer score 1.0 just because the prediction set is empty.
"""
jx = _set_aware_jaccard(sample.pymatching_x_errors, parsed.x_errors)
jz = _set_aware_jaccard(sample.pymatching_z_errors, parsed.z_errors)
return 0.5 * (jx + jz)
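
# Worked example (hypothetical reference frame): with
# pymatching_x_errors=[3, 7], pymatching_z_errors=[] and a prediction of
# x_errors=[3], z_errors=[]: jx = 1/2 (inter={3}, union={3, 7}), jz = 1.0
# (both empty), so the reward is 0.5 * (0.5 + 1.0) = 0.75.
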
# --------------------------------------------------------------------------- #
# Reward 4: format compliance #
# --------------------------------------------------------------------------- #
def reward_format_compliance(parsed: ParseResult) -> float:
"""Binary {0.0, 1.0}: 1.0 iff the parser fully extracted both lists.
2026-04 spec rewrite (FIX 1): partial credit (0.5) is removed. With
partial credit on, the model could still earn ~half the format
weight on garbage outputs that resembled the canonical form, which
is part of what kept the reward landscape too flat for GRPO to
escape the empty-everywhere mode. The new rule rewards only a
cleanly-parsed answer.
"""
return 1.0 if parsed.parse_success else 0.0
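
# Example: parse_success=True -> 1.0; anything else, including outputs that
# merely resemble the canonical form, -> 0.0. There is no partial credit.
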
# --------------------------------------------------------------------------- #
# Reward 5: PyMatching beat-rate bonus #
# --------------------------------------------------------------------------- #
def reward_pymatching_beat(
parsed: ParseResult,
sample: SyndromeSample,
layout: CircuitLayout,
) -> float:
"""1.0 iff PyMatching got this syndrome wrong AND the LLM got it right.
This is the headline metric (Section 3.5). Most of training it'll be
near zero; the trajectory of its mean over steps is the proof we've
moved past pure imitation.
"""
pm_correct = sample.pymatching_observable_pred == sample.actual_observable_flip
if pm_correct:
return 0.0
llm_implied = predicted_observable_flip(parsed.x_errors, layout)
return 1.0 if llm_implied == sample.actual_observable_flip else 0.0
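
# Truth table (pm = PyMatching matched the actual flip, llm = the implied
# LLM flip matched it): only the (pm=False, llm=True) cell scores 1.0; the
# other three combinations score 0.0.
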
# --------------------------------------------------------------------------- #
# Combined reward #
# --------------------------------------------------------------------------- #
@dataclass(frozen=True)
class RewardBreakdown:
"""Per-component scores plus the weighted total."""
logical_correction: float
syndrome_consistency: float
hamming_overlap: float
format_compliance: float
pymatching_beat: float
total: float
def as_dict(self) -> dict[str, float]:
return {
"logical_correction": self.logical_correction,
"syndrome_consistency": self.syndrome_consistency,
"hamming_overlap": self.hamming_overlap,
"format_compliance": self.format_compliance,
"pymatching_beat": self.pymatching_beat,
"total": self.total,
}
def compute_all_rewards(
parsed: ParseResult,
sample: SyndromeSample,
layout: CircuitLayout,
final_detector_supports: dict[int, frozenset[int]],
weights: dict[str, float] = REWARD_WEIGHTS,  # shared default; read-only here
) -> RewardBreakdown:
"""Compute all five rewards and the weighted total.
Returns a :class:`RewardBreakdown` whose ``as_dict`` is what the env's
``info`` payload contains. The trainer logs each component separately.
"""
r1 = reward_logical_correction(parsed, sample, layout)
r2 = reward_syndrome_consistency(parsed, sample, layout, final_detector_supports)
r3 = reward_hamming_overlap(parsed, sample, layout)
r4 = reward_format_compliance(parsed)
r5 = reward_pymatching_beat(parsed, sample, layout)
total = (
weights["logical_correction"] * r1
+ weights["syndrome_consistency"] * r2
+ weights["hamming_overlap"] * r3
+ weights["format_compliance"] * r4
+ weights["pymatching_beat"] * r5
)
total = max(0.0, min(1.0, total))
return RewardBreakdown(
logical_correction=r1,
syndrome_consistency=r2,
hamming_overlap=r3,
format_compliance=r4,
pymatching_beat=r5,
total=total,
)
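
# A minimal usage sketch (``layout``/``sample``/``parsed`` construction is
# assumed to happen elsewhere in the env; not a guaranteed API):
#
#   supports = compute_final_detector_supports(layout)
#   breakdown = compute_all_rewards(parsed, sample, layout, supports)
#   info["rewards"] = breakdown.as_dict()  # what the trainer logs per step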