Spaces:
Sleeping
Sleeping
| """Pydantic data classes shared by client and server (Section 2.2 of the plan). | |
| Three classes draw the trust boundary: | |
| * ``DecoderObservation`` - what the LLM sees on each step. | |
| * ``DecoderAction`` - what the LLM emits (after parsing). | |
| * ``DecoderState`` - server-side state, never serialised to the client. | |
| Keeping the wire schema explicit is what closes off reward-hacking attacks: | |
| the LLM literally cannot reach into the ground truth (``true_x_errors``, | |
| ``true_z_errors``, ``actual_observable_flip``) because those fields are not | |
| in any class it ever receives. | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Optional | |
| from pydantic import BaseModel, ConfigDict, Field | |
| # --------------------------------------------------------------------------- # | |
| # Wire types - sent across the OpenEnv HTTP boundary # | |
| # --------------------------------------------------------------------------- # | |
class DecoderObservation(BaseModel):
    """Per-step view handed to the LLM.

    Every field is required. The model is configured ``frozen=True``, so
    attributes cannot be reassigned after construction (the list held by
    ``syndrome_bits`` is still a regular mutable list).
    """

    model_config = ConfigDict(frozen=True)

    # Text the policy actually consumes.
    prompt: str = Field(
        description=(
            "Pre-formatted prompt string. This is exactly what the "
            "trainer passes to the policy - it appears verbatim in "
            "training logs."
        ),
    )

    # Raw detector data, exposed alongside the prompt for auditing.
    syndrome_bits: list[int] = Field(
        description=(
            "Raw detector activations (0/1). Provided for debugging "
            "and reward-hacking audits; the LLM should be reading the "
            "prompt, not this array."
        ),
    )

    # Episode configuration.
    distance: int = Field(description="Code distance for this episode.")
    rounds: int = Field(description="Number of stabiliser rounds.")
    p: float = Field(description="Physical error budget (SI1000 base).")
    curriculum_level: str = Field(description="Curriculum level name.")
    episode_id: int = Field(description="Monotonic episode counter.")

    # Identifier for the physics configuration, used when grouping logs.
    dem_digest: str = Field(
        description=(
            "Short hash of the detector error model used this episode. "
            "The trainer logs this so we can group rollouts by physics "
            "config."
        ),
    )
class DecoderAction(BaseModel):
    """Parsed form of what the LLM emitted for one step.

    All fields have defaults, so ``DecoderAction()`` is a valid "no
    correction" action. ``raw_response`` keeps the unparsed model output so
    generations can be inspected later (the participant guide's *inspect
    generations* mandate, Section 2.5 of the plan).
    """

    model_config = ConfigDict(frozen=True)

    # Qubit indices the LLM flagged, one list per error type; a fresh empty
    # list is created per instance.
    x_error_qubits: list[int] = Field(default_factory=list)
    z_error_qubits: list[int] = Field(default_factory=list)

    # Unmodified text produced by the model.
    raw_response: str = ""

    # Parse outcome flag; defaults to True.
    parse_success: bool = True
class StepResult(BaseModel):
    """Return value of an environment step.

    Shaped like the usual OpenEnv core / Gymnasium step result:
    observation, reward, termination flags and a free-form ``info`` dict.
    """

    # Observation produced by the step.
    observation: DecoderObservation

    # Scalar reward for the step.
    reward: float

    # Termination flags; ``truncated`` is optional and defaults to False.
    done: bool
    truncated: bool = False

    # Arbitrary extra data; a fresh empty dict is created per instance.
    info: dict[str, Any] = Field(default_factory=dict)
class ResetRequest(BaseModel):
    """Optional parameters a trainer may send with ``reset``.

    Both fields default to ``None``, so an empty request is valid.
    """

    # Optional seed supplied by the caller.
    seed: Optional[int] = None

    # Optional curriculum level name that bypasses the scheduler.
    forced_level: Optional[str] = Field(
        None,
        description=(
            "Override the curriculum scheduler. Used by eval scripts "
            "that want a specific (distance, rounds, p) configuration."
        ),
    )
class StepRequest(BaseModel):
    """Payload for a step call: the LLM's raw text plus the episode id.

    The trainer forwards the unparsed text; parsing and scoring happen on
    the server.
    """

    # Unparsed model output, exactly as generated.
    raw_response: str

    # Episode counter value carried with the request.
    episode_id: int
| # --------------------------------------------------------------------------- # | |
| # Server-only state - intentionally NOT a wire type # | |
| # --------------------------------------------------------------------------- # | |
class DecoderState(BaseModel):
    """Per-episode state kept on the server; never sent to the client.

    This is the only class in the module that carries the ground truth
    (``true_x_errors``, ``true_z_errors``, ``actual_observable_flip``). It is
    meant to live in the server's per-episode dict and be discarded on
    ``done`` rather than serialised over HTTP.

    The circuit and detector error model are stored as text
    (``circuit_text`` / ``dem_text``), so every field declared here is a
    builtin type.

    NOTE(review): ``arbitrary_types_allowed=True`` looks like a leftover from
    holding a ``stim.Circuit`` object directly - confirm nothing relies on it
    before removing.
    """

    model_config = ConfigDict(
        frozen=False,
        arbitrary_types_allowed=True,
    )

    # Episode identity and configuration.
    episode_id: int
    seed: int
    curriculum_level: str
    distance: int
    rounds: int
    p: float

    # Sampled syndrome and the hidden ground truth behind it.
    syndrome_bits: list[int]
    true_x_errors: list[int]
    true_z_errors: list[int]
    # 0 or 1; the unfakeable ground truth.
    actual_observable_flip: int
    # 0 or 1; the baseline's prediction.
    pymatching_observable_pred: int

    # Pre-computed quantities the reward functions need.
    # Data qubits whose Z error flips the X observable.
    x_observable_support: list[int]
    # Data qubits whose X error flips the Z observable.
    z_observable_support: list[int]
    num_data_qubits: int
    num_stabilizers: int

    # Stim/PyMatching artefacts, held as plain text so Pydantic sees only
    # builtin types.
    circuit_text: str
    dem_text: str

    # Reward audit log; a fresh empty dict is created per instance.
    last_reward_breakdown: dict[str, float] = Field(default_factory=dict)