Spaces:
Sleeping
Sleeping
| """DecoderEnvironment: the OpenEnv-style env that the LLM trainer talks to. | |
| This is the heart of the server (Sections 2.4 + 2.5 of the plan): | |
| * ``reset()``: pick a curriculum level, build a circuit, sample a syndrome, | |
| return a :class:`DecoderObservation`. | |
| * ``step(raw_response)``: parse the LLM's text, score five rewards, return | |
| a :class:`StepResult` whose ``info`` dict carries the per-component | |
| breakdown. | |
| Episodes are single-step (Section 2.5): the LLM emits one prediction and | |
| the episode ends. | |
| """ | |
| from __future__ import annotations | |
| import threading | |
| import time | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| import pymatching | |
| from qubit_medic.config import ( | |
| EPISODE_TIMEOUT_SECONDS, | |
| PRIMARY_SEED, | |
| REWARD_WEIGHTS, | |
| ) | |
| from qubit_medic.models import ( | |
| DecoderAction, | |
| DecoderObservation, | |
| DecoderState, | |
| StepResult, | |
| ) | |
| from qubit_medic.prompts import build_prompt, parse_action | |
| from qubit_medic.server import physics | |
| from qubit_medic.server.curriculum import CurriculumScheduler | |
| from qubit_medic.server.physics import ( | |
| CircuitLayout, | |
| SyndromeSample, | |
| build_circuit, | |
| build_dem, | |
| dem_digest, | |
| extract_layout, | |
| per_round_x_z_counts, | |
| sample_episode, | |
| ) | |
| from qubit_medic.server.rewards import ( | |
| RewardBreakdown, | |
| compute_all_rewards, | |
| compute_final_detector_supports, | |
| ) | |
| # --------------------------------------------------------------------------- # | |
| # Per-level cached compilation - building Stim/PyMatching is the slow step # | |
| # --------------------------------------------------------------------------- # | |
@dataclass
class _LevelCache:
    """Compiled Stim/PyMatching artefacts for one curriculum level.

    Instances are created through :meth:`build` and stored per level name
    by :class:`DecoderEnvironment`, so the slow circuit/DEM/matching
    compilation happens once per level rather than once per episode.
    """

    circuit: object  # result of build_circuit(level)
    dem: object  # result of build_dem(circuit)
    matching: pymatching.Matching  # matcher built from ``dem``
    layout: CircuitLayout  # result of extract_layout(circuit)
    final_detector_supports: dict  # result of compute_final_detector_supports(layout)
    dem_digest: str  # result of physics.dem_digest(dem)

    @classmethod
    def build(cls, level) -> "_LevelCache":
        """Compile every artefact for ``level`` and bundle them.

        Args:
            level: a curriculum level object, passed straight to
                ``build_circuit``.

        Returns:
            A fully populated :class:`_LevelCache`.
        """
        c = build_circuit(level)
        d = build_dem(c)
        m = pymatching.Matching.from_detector_error_model(d)
        layout = extract_layout(c)
        supports = compute_final_detector_supports(layout)
        return cls(
            circuit=c,
            dem=d,
            matching=m,
            layout=layout,
            final_detector_supports=supports,
            # ``dem_digest`` here is the module-level helper imported from
            # physics; the class field of the same name does not shadow it
            # inside this method.
            dem_digest=dem_digest(d),
        )
| # --------------------------------------------------------------------------- # | |
| # DecoderEnvironment # | |
| # --------------------------------------------------------------------------- # | |
@dataclass
class _ActiveEpisode:
    """In-flight episode bookkeeping.

    Created by ``DecoderEnvironment.reset()`` and removed again by
    ``DecoderEnvironment.step()`` (or dropped by ``close()``).
    """

    state: DecoderState  # per-episode state, including ground-truth fields
    sample: SyndromeSample  # sampled syndrome plus PyMatching's answer
    layout: CircuitLayout  # shared layout from the level cache
    final_detector_supports: dict  # shared supports from the level cache
    started_at: float  # time.monotonic() at reset, used for the timeout check
class DecoderEnvironment:
    """OpenEnv-style env for surface-code decoding.

    Thread-safe by virtue of a single ``_lock``: the FastAPI server is
    expected to be I/O bound, and per-call latency is well under a
    millisecond, so a coarse lock is fine and dramatically simplifies the
    state machine.
    """

    def __init__(self, *, base_seed: int = PRIMARY_SEED) -> None:
        """Create an env seeded from ``base_seed``.

        ``base_seed`` seeds the curriculum scheduler's RNG and is the
        offset for default per-episode shot seeds (``base_seed + id``).
        """
        import random  # function-scope import, like the file's other lazy imports

        self._lock = threading.Lock()
        self._scheduler = CurriculumScheduler(rng=random.Random(base_seed))
        self._caches: dict[str, _LevelCache] = {}
        self._episode_counter = 0
        self._base_seed = base_seed
        # episode_id -> bookkeeping for episodes reset() but not yet step()ped.
        self._active: dict[int, _ActiveEpisode] = {}

    # ----- cache helpers --------------------------------------------------

    def _cache_for(self, level_name: str) -> _LevelCache:
        """Return the compiled artefacts for ``level_name``, building them on
        first use.

        Not locked itself; every caller in this class holds ``_lock``.
        """
        cache = self._caches.get(level_name)
        if cache is not None:
            return cache
        from qubit_medic.config import level_by_name

        cache = _LevelCache.build(level_by_name(level_name))
        self._caches[level_name] = cache
        return cache

    @staticmethod
    def _make_observation(
        *,
        layout: CircuitLayout,
        distance,
        rounds,
        p,
        syndrome_bits,
        curriculum_level: str,
        episode_id: int,
        dem_digest: str,
    ) -> DecoderObservation:
        """Build the prompt and wrap it in a :class:`DecoderObservation`.

        Shared by ``reset()`` and ``step()`` so both observations for an
        episode are constructed the same way.
        """
        n_x, n_z = per_round_x_z_counts(layout)
        prompt = build_prompt(
            distance=distance,
            rounds=rounds,
            p=p,
            syndrome_bits=syndrome_bits,
            num_x_stabilizers=n_x,
            num_z_stabilizers=n_z,
            num_data_qubits=layout.num_data_qubits,
        )
        return DecoderObservation(
            prompt=prompt,
            syndrome_bits=syndrome_bits,
            distance=distance,
            rounds=rounds,
            p=p,
            curriculum_level=curriculum_level,
            episode_id=episode_id,
            dem_digest=dem_digest,
        )

    @staticmethod
    def _timeout_outcome(raw_response: str) -> tuple[RewardBreakdown, DecoderAction]:
        """Zero reward and an unparsed action for a late response.

        Hard timeout: zero reward, format compliance zero, episode closed
        cleanly (Section 2.6).
        """
        breakdown = RewardBreakdown(
            logical_correction=0.0,
            syndrome_consistency=0.0,
            hamming_overlap=0.0,
            format_compliance=0.0,
            pymatching_beat=0.0,
            total=0.0,
        )
        action = DecoderAction(
            raw_response=raw_response,
            parse_success=False,
        )
        return breakdown, action

    @staticmethod
    def _scored_outcome(
        episode: _ActiveEpisode, raw_response: str
    ) -> tuple[RewardBreakdown, DecoderAction]:
        """Parse ``raw_response`` and score it against ``episode``."""
        from qubit_medic.prompts import ParseResult

        layout = episode.layout
        parsed = parse_action(
            raw_response=raw_response,
            num_data_qubits=layout.num_data_qubits,
        )
        # Convert LLM-space qubit IDs (0..N-1) to Stim IDs before scoring;
        # rewards operate in the Stim coordinate system.
        parsed_stim = ParseResult(
            x_errors=layout.llm_to_stim(parsed.x_errors),
            z_errors=layout.llm_to_stim(parsed.z_errors),
            parse_success=parsed.parse_success,
            parse_partial=parsed.parse_partial,
            raw_response=parsed.raw_response,
        )
        breakdown = compute_all_rewards(
            parsed=parsed_stim,
            sample=episode.sample,
            layout=layout,
            final_detector_supports=episode.final_detector_supports,
            weights=REWARD_WEIGHTS,
        )
        # The reported action keeps the LLM-space IDs the model emitted.
        action = DecoderAction(
            x_error_qubits=parsed.x_errors,
            z_error_qubits=parsed.z_errors,
            raw_response=raw_response,
            parse_success=parsed.parse_success,
        )
        return breakdown, action

    # ----- public API -----------------------------------------------------

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        forced_level: Optional[str] = None,
    ) -> DecoderObservation:
        """Start a new single-step episode and return its observation.

        Args:
            seed: shot seed for syndrome sampling; defaults to
                ``base_seed + episode_id``.
            forced_level: forwarded to the curriculum scheduler's
                ``sample(forced_level=...)``.

        Returns:
            The observation (prompt, syndrome, level metadata, DEM digest)
            for the new episode.
        """
        with self._lock:
            self._episode_counter += 1
            ep_id = self._episode_counter
            shot_seed = seed if seed is not None else self._base_seed + ep_id
            level = self._scheduler.sample(forced_level=forced_level)
            cache = self._cache_for(level.name)
            sample = sample_episode(
                circuit=cache.circuit,
                matching=cache.matching,
                layout=cache.layout,
                seed=shot_seed,
            )
            state = DecoderState(
                episode_id=ep_id,
                seed=shot_seed,
                curriculum_level=level.name,
                distance=level.distance,
                rounds=level.rounds,
                p=level.p,
                syndrome_bits=sample.syndrome_bits,
                true_x_errors=sample.pymatching_x_errors,
                true_z_errors=sample.pymatching_z_errors,
                actual_observable_flip=sample.actual_observable_flip,
                pymatching_observable_pred=sample.pymatching_observable_pred,
                x_observable_support=[],  # memory_z task: no X observable
                z_observable_support=list(cache.layout.z_observable_support),
                num_data_qubits=cache.layout.num_data_qubits,
                num_stabilizers=cache.layout.num_ancilla_qubits,
                circuit_text=str(cache.circuit),
                dem_text=str(cache.dem),
            )
            self._active[ep_id] = _ActiveEpisode(
                state=state,
                sample=sample,
                layout=cache.layout,
                final_detector_supports=cache.final_detector_supports,
                started_at=time.monotonic(),
            )
            return self._make_observation(
                layout=cache.layout,
                distance=level.distance,
                rounds=level.rounds,
                p=level.p,
                syndrome_bits=sample.syndrome_bits,
                curriculum_level=level.name,
                episode_id=ep_id,
                dem_digest=cache.dem_digest,
            )

    def step(self, raw_response: str, episode_id: int) -> StepResult:
        """Score ``raw_response`` for ``episode_id`` and end that episode.

        Episodes are single-step, so the result always has ``done=True``.
        A response arriving after ``EPISODE_TIMEOUT_SECONDS`` gets zero
        reward and ``truncated=True``.

        Raises:
            ValueError: if ``episode_id`` is unknown or already finished.
        """
        with self._lock:
            episode = self._active.pop(episode_id, None)
            if episode is None:
                # Calling step() on an unknown episode ID is a clean
                # ValueError (compliance Section 1 of the participant-guide
                # audit: the env must "raise a clean ValueError, not a
                # Python traceback"). The trainer didn't follow reset/step
                # pairing, or the episode already ended; either way we
                # surface a typed exception so the FastAPI layer can turn
                # it into a 400 response instead of a 500.
                raise ValueError(
                    f"unknown or already-finished episode {episode_id}; "
                    f"call reset() before step()."
                )

            elapsed = time.monotonic() - episode.started_at
            timed_out = elapsed > EPISODE_TIMEOUT_SECONDS
            if timed_out:
                breakdown, action = self._timeout_outcome(raw_response)
            else:
                breakdown, action = self._scored_outcome(episode, raw_response)

            self._scheduler.update(
                episode.state.curriculum_level,
                logical_correction=breakdown.logical_correction,
            )
            episode.state.last_reward_breakdown = breakdown.as_dict()

            # Report the same DEM digest reset() advertised for this level,
            # not a prefix of the raw DEM text.
            digest = self._cache_for(episode.state.curriculum_level).dem_digest
            obs = self._make_observation(
                layout=episode.layout,
                distance=episode.state.distance,
                rounds=episode.state.rounds,
                p=episode.state.p,
                syndrome_bits=episode.state.syndrome_bits,
                curriculum_level=episode.state.curriculum_level,
                episode_id=episode.state.episode_id,
                dem_digest=digest,
            )
            info = {
                "rewards": breakdown.as_dict(),
                "parsed_action": action.model_dump(),
                "actual_observable_flip": episode.sample.actual_observable_flip,
                "pymatching_observable_pred": episode.sample.pymatching_observable_pred,
                "pymatching_x_errors": episode.sample.pymatching_x_errors,
                "pymatching_z_errors": episode.sample.pymatching_z_errors,
                "elapsed_seconds": elapsed,
                "timed_out": timed_out,
                "curriculum_stats": self._scheduler.stats(),
            }
            return StepResult(
                observation=obs,
                reward=breakdown.total,
                done=True,  # single-step episodes
                truncated=timed_out,
                info=info,
            )

    # ----- introspection --------------------------------------------------

    def health(self) -> dict:
        """Return a summary dict: counters, curriculum stats, cached levels."""
        with self._lock:
            return {
                "ok": True,
                "episodes_started": self._episode_counter,
                "active_episodes": len(self._active),
                "curriculum": self._scheduler.stats(),
                "cached_levels": list(self._caches.keys()),
            }

    def state(self) -> dict:
        """Return a JSON-serialisable snapshot of the env's externally-
        visible state (compliance Section 1 of the participant-guide
        audit: ``state()`` returns a JSON-serialisable object, not a raw
        Python object).

        Crucially this never includes the ground-truth fields stored on
        the per-episode :class:`DecoderState` (true error patterns,
        actual_observable_flip, pymatching_observable_pred, circuit_text,
        dem_text). Those stay in ``self._active[ep].state`` and are only
        consumed by the reward functions.
        """
        with self._lock:
            return {
                "episodes_started": int(self._episode_counter),
                "active_episodes": int(len(self._active)),
                "active_episode_ids": [int(ep) for ep in self._active.keys()],
                "cached_levels": list(self._caches.keys()),
                "curriculum": self._scheduler.stats(),
                "base_seed": int(self._base_seed),
            }

    def close(self) -> None:
        """Drop any in-flight episodes.

        Compliance Section 1: the gym-style API requires ``close()``.
        After ``close()`` the env can still be re-used by calling
        ``reset()`` again. The compiled per-level caches, the curriculum
        scheduler and the lock are all kept; only per-episode bookkeeping
        is released.
        """
        with self._lock:
            self._active.clear()