Spaces:
Sleeping
Sleeping
| """StabilizerForge environment. | |
| Episode loop: | |
| reset(task_id?) -> sample/load a task; emit initial observation | |
| step(action) -> apply one Clifford gate, or FINALIZE | |
| returns dense shaping reward per step, | |
| full terminal reward at FINALIZE. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import random | |
| from pathlib import Path | |
| from typing import Any | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| try: | |
| from ..models import StabilizerAction, StabilizerObservation | |
| from .verifier import match_fraction | |
| except ImportError: # pragma: no cover (in-container imports without package context) | |
| from models import StabilizerAction, StabilizerObservation | |
| from server.verifier import match_fraction | |
# Reward weights (keep aligned with README)
W_MATCH = 0.4  # terminal reward: multiplies the final match fraction
W_GATE_EFF = 0.2  # terminal reward: multiplies the total-gate efficiency term
W_TWOQ_EFF = 0.2  # terminal reward: multiplies the CX-count efficiency term
W_CONN = 0.1  # per step: subtracted when a CX acts on a non-adjacent qubit pair
W_FORMAT = 0.1  # per step: subtracted when an action fails validation
SHAPING_COEF = 0.05  # per-step Δmatch_fraction shaping
MAX_CONSECUTIVE_VIOLATIONS = 5  # episode ends after this many invalid actions in a row
def _default_tasks_path() -> str:
    """Return the tasks-file path to use when the caller supplies none.

    Resolution order:
      1. the ``STABILIZER_FORGE_TASKS`` environment variable, when non-empty;
      2. ``tasks.jsonl`` in the directory that contains this module's own
         directory, when that file exists;
      3. ``tasks.jsonl`` one directory further up (returned without checking
         that it exists).
    """
    override = os.environ.get("STABILIZER_FORGE_TASKS")
    if override:
        return override

    package_dir = Path(__file__).resolve().parent.parent  # stabilizer_forge/
    packaged = package_dir / "tasks.jsonl"
    project_root_copy = package_dir.parent / "tasks.jsonl"
    return str(packaged if packaged.exists() else project_root_copy)
def _load_tasks(path: str) -> list[dict]:
    """Load task records from a JSON Lines file.

    Args:
        path: Location of the tasks file, one JSON document per line.

    Returns:
        The parsed documents in file order. Blank and whitespace-only lines
        are skipped. An empty list is returned when ``path`` does not exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    tasks_file = Path(path)
    if not tasks_file.exists():
        return []
    # Decode explicitly instead of relying on the locale's default encoding,
    # which varies by platform. "utf-8-sig" reads plain UTF-8 unchanged and
    # also drops a leading byte-order mark that json.loads would reject.
    with tasks_file.open(encoding="utf-8-sig") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(line) for line in stripped_lines if line]
def _gate_to_stim(action: StabilizerAction) -> str:
    """Format one gate action as a single line of Stim circuit text.

    ``H`` and ``S`` render as ``"<op> <qubit>"`` and ``CX`` renders as
    ``"CX <control> <target>"``, reading operands from ``action.qubits``.

    Raises:
        ValueError: If ``action.op`` is anything other than ``H``, ``S`` or
            ``CX`` (including ``FINALIZE``, which is not a gate).
    """
    op = action.op
    if op in {"H", "S"}:
        (target,) = action.qubits[:1] or (action.qubits[0],)
        return f"{op} {target}"
    if op == "CX":
        control, target = action.qubits[0], action.qubits[1]
        return f"CX {control} {target}"
    raise ValueError(f"Cannot render gate: {op}")
def _validate_action(
    action: StabilizerAction, n_qubits: int
) -> tuple[bool, str]:
    """Check an action's operand count and qubit indices.

    Rules, applied in this order:
      * ``FINALIZE`` must carry no qubits.
      * ``H`` / ``S`` need exactly one qubit.
      * ``CX`` needs exactly two qubits, and they must differ.
      * Every qubit index must lie in ``[0, n_qubits)``.
      * Any other op is rejected as unknown.

    Returns:
        ``(True, "")`` when the action is acceptable, otherwise
        ``(False, message)`` describing the first rule that failed.
    """
    op = action.op

    if op == "FINALIZE":
        if action.qubits:
            return False, "FINALIZE takes no qubits."
        return True, ""

    is_single_qubit_gate = op in {"H", "S"}
    if not is_single_qubit_gate and op != "CX":
        return False, f"unknown op {op}"

    operands = action.qubits
    operand_count = len(operands)
    if is_single_qubit_gate:
        if operand_count != 1:
            return False, f"{op} requires exactly 1 qubit, got {operand_count}."
    else:
        if operand_count != 2:
            return False, f"CX requires exactly 2 qubits, got {operand_count}."
        control, target = operands
        if control == target:
            return False, "CX control and target must differ."

    # Report the first operand (in given order) that falls outside the register.
    for index in operands:
        if not (0 <= index < n_qubits):
            return False, f"qubit {index} out of range [0, {n_qubits})."
    return True, ""
class StabilizerForgeEnvironment(Environment):
    """Single-agent env: emit Clifford gates to encode a target stabilizer code.

    Episode flow:
      * ``reset`` selects a task and clears all per-episode counters.
      * ``step`` validates the action, then either appends one gate
        (``H``, ``S``, ``CX``) to the circuit or handles ``FINALIZE``.
      * A gate step is rewarded with
        ``SHAPING_COEF * Δmatch_fraction - W_CONN`` (the ``W_CONN`` term only
        for a CX on a non-adjacent pair); an invalid action costs ``W_FORMAT``.
      * ``FINALIZE`` ends the episode with the weighted terminal reward.
      * The episode also ends when the gate budget is reached or after
        ``MAX_CONSECUTIVE_VIOLATIONS`` invalid actions in a row.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self, tasks_path: str | None = None):
        """Load the task list and initialise empty episode state.

        Args:
            tasks_path: JSONL tasks file. When falsy, the path is resolved by
                ``_default_tasks_path()``.
        """
        super().__init__()
        self._tasks_path = tasks_path or _default_tasks_path()
        self._tasks = _load_tasks(self._tasks_path)
        self._task: dict[str, Any] | None = None
        self._gates: list[str] = []
        self._cnot_count = 0
        self._nonadj_cnot_count = 0
        self._format_violations = 0
        self._consecutive_violations = 0
        self._last_match_fraction = 0.0
        self._last_match_results: list[bool] = []
        self._finalized = False
        self._rng = random.Random()
        self._state = State(episode_id=str(uuid4()), step_count=0)

    # ---------- Helpers ----------
    def _require_task(self) -> dict[str, Any]:
        """Return the active task, raising if no episode has been started."""
        if self._task is None:
            raise RuntimeError("No active task; call reset() first.")
        return self._task

    def _benchmark(self) -> int:
        """Return the task's ``benchmark_optimum``; 0 when missing, null or falsy."""
        return int(self._require_task().get("benchmark_optimum", 0) or 0)

    def _gate_budget(self) -> int:
        """Return the task's gate budget.

        Uses the task's ``gate_budget`` when present and not null (an explicit
        0 is honoured); otherwise defaults to ``2 * max(1, benchmark_optimum)``.
        Single source of truth for both the observation and termination checks.
        """
        explicit = self._require_task().get("gate_budget")
        if explicit is None:
            return 2 * max(1, self._benchmark())
        return int(explicit)

    def _circuit_text(self) -> str:
        """Return the gates emitted so far as newline-separated Stim text."""
        return "\n".join(self._gates)

    def _is_adjacent(self, a: int, b: int) -> bool:
        """Tell whether qubits ``a`` and ``b`` are joined by a connectivity edge.

        Edges are treated as undirected. When the task has no
        ``connectivity_edges`` entry (or it is null), every pair is adjacent.
        """
        edges = self._task.get("connectivity_edges") if self._task else None
        if edges is None:
            return True  # all-to-all
        undirected = {tuple(sorted(edge)) for edge in edges}
        return tuple(sorted((a, b))) in undirected

    def _compute_match(self) -> tuple[float, list[bool]]:
        """Score the current circuit against the task's target stabilizers.

        Delegates to ``match_fraction`` and returns its fraction together with
        the per-target results re-ordered to follow ``target_stabilizers``.
        Returns ``(0.0, [])`` when no task is active.
        """
        if not self._task:
            return 0.0, []
        targets = self._task["target_stabilizers"]
        fraction, per_target = match_fraction(
            self._circuit_text(), targets, self._task["n_qubits"]
        )
        return fraction, [per_target[stab] for stab in targets]

    def _make_obs(
        self,
        *,
        reward: float,
        done: bool,
        last_action_valid: bool = True,
        last_action_error: str = "",
    ) -> StabilizerObservation:
        """Build the observation for the current episode state.

        Args:
            reward: Reward to attach to this observation.
            done: Whether the episode has ended.
            last_action_valid: False when the triggering action failed validation.
            last_action_error: Validation message for an invalid action.

        Raises:
            RuntimeError: If no task is active.
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        task = self._require_task()
        bench = self._benchmark()
        budget = self._gate_budget()
        emitted = len(self._gates)
        return StabilizerObservation(
            task_id=task["task_id"],
            target_stabilizers=list(task["target_stabilizers"]),
            n_qubits=int(task["n_qubits"]),
            connectivity_edges=task.get("connectivity_edges"),
            gates_so_far=list(self._gates),
            current_circuit=self._circuit_text(),
            current_match=list(self._last_match_results),
            match_fraction=float(self._last_match_fraction),
            gates_emitted=emitted,
            cnot_count=self._cnot_count,
            nonadj_cnot_count=self._nonadj_cnot_count,
            gate_budget=budget,
            gate_budget_remaining=max(0, budget - emitted),
            benchmark_optimum=bench,
            benchmark_optimum_2q=int(task.get("benchmark_optimum_2q", 0) or 0),
            format_violations=self._format_violations,
            consecutive_violations=self._consecutive_violations,
            last_action_valid=last_action_valid,
            last_action_error=last_action_error,
            step_count=self._state.step_count,
            finalized=self._finalized,
            done=done,
            reward=reward,
        )

    def _pick_task(self, task_id: str | None, seed: int | None) -> dict[str, Any]:
        """Select a task by id, or draw one at random.

        Args:
            task_id: When given, the task with this ``task_id`` is returned.
            seed: When given (and ``task_id`` is None), the draw uses a fresh
                ``random.Random(seed)``; otherwise the instance generator.

        Raises:
            ValueError: If ``task_id`` is given but not present in the task list.
            RuntimeError: If a random draw is requested but no tasks are loaded.
        """
        if task_id is not None:
            for candidate in self._tasks:
                if candidate.get("task_id") == task_id:
                    return candidate
            raise ValueError(f"task_id '{task_id}' not found in {self._tasks_path}")
        if not self._tasks:
            raise RuntimeError(
                f"No tasks loaded from {self._tasks_path}. "
                "Set STABILIZER_FORGE_TASKS or place tasks.jsonl next to the env."
            )
        rng = random.Random(seed) if seed is not None else self._rng
        return rng.choice(self._tasks)

    def _reject(self, error: str, gate_budget: int) -> StabilizerObservation:
        """Record an invalid action and return the penalised observation."""
        self._format_violations += 1
        self._consecutive_violations += 1
        done = (
            self._consecutive_violations >= MAX_CONSECUTIVE_VIOLATIONS
            or len(self._gates) >= gate_budget
        )
        return self._make_obs(
            reward=W_FORMAT * -1.0,
            done=done,
            last_action_valid=False,
            last_action_error=error,
        )

    def _finalize(self) -> StabilizerObservation:
        """Mark the episode finalized and return the terminal observation.

        The terminal reward is
        ``W_MATCH * match + W_GATE_EFF * gate_eff + W_TWOQ_EFF * twoq_eff``,
        where each efficiency term is ``max(0, 1 - count / (1.5 * benchmark))``.
        """
        self._finalized = True
        self._last_match_fraction, self._last_match_results = self._compute_match()

        # Benchmarks are clamped to >= 1 so the efficiency ratios never divide by 0.
        bench = max(1, self._benchmark())
        raw_2q = self._require_task().get("benchmark_optimum_2q")
        # A missing/null two-qubit benchmark falls back to the overall benchmark.
        bench_2q = bench if raw_2q is None else max(1, int(raw_2q))

        gate_eff = max(0.0, 1.0 - len(self._gates) / (1.5 * bench))
        twoq_eff = max(0.0, 1.0 - self._cnot_count / (1.5 * bench_2q))
        terminal = (
            W_MATCH * self._last_match_fraction
            + W_GATE_EFF * gate_eff
            + W_TWOQ_EFF * twoq_eff
        )
        return self._make_obs(reward=terminal, done=True)

    def _apply_gate(
        self, action: StabilizerAction, gate_budget: int
    ) -> StabilizerObservation:
        """Append a validated gate, update counters and return the shaped step."""
        self._gates.append(_gate_to_stim(action))

        conn_penalty = 0.0
        if action.op == "CX":
            self._cnot_count += 1
            if not self._is_adjacent(action.qubits[0], action.qubits[1]):
                self._nonadj_cnot_count += 1
                conn_penalty = -1.0

        previous = self._last_match_fraction
        self._last_match_fraction, self._last_match_results = self._compute_match()
        delta = self._last_match_fraction - previous

        # Termination if we exceed budget without finalizing
        done = len(self._gates) >= gate_budget
        step_reward = SHAPING_COEF * delta + W_CONN * conn_penalty
        return self._make_obs(reward=step_reward, done=done)

    # ---------- Gym API ----------
    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        task_id: str | None = None,
        **kwargs: Any,
    ) -> StabilizerObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: When given, re-seeds the instance random generator before
                the task is drawn.
            episode_id: Identifier for the new episode; a UUID4 string is
                generated when omitted or empty.
            task_id: Load this specific task instead of drawing one at random.
            **kwargs: Accepted for interface compatibility; unused here.

        Raises:
            ValueError: If ``task_id`` is not in the loaded task list.
            RuntimeError: If a random draw is needed but no tasks are loaded.
        """
        if seed is not None:
            self._rng = random.Random(seed)
        # Draw from the (possibly just re-seeded) instance generator so that
        # its state advances. The seeded pick is the same as drawing from a
        # fresh Random(seed), but later unseeded resets now continue the
        # stream instead of repeating this first draw.
        self._task = self._pick_task(task_id=task_id, seed=None)
        self._gates = []
        self._cnot_count = 0
        self._nonadj_cnot_count = 0
        self._format_violations = 0
        self._consecutive_violations = 0
        self._finalized = False
        self._state = State(
            episode_id=episode_id or str(uuid4()), step_count=0
        )
        # Initial match (empty circuit on |0...0>)
        self._last_match_fraction, self._last_match_results = self._compute_match()
        return self._make_obs(reward=0.0, done=False)

    def step(self, action: StabilizerAction, **kwargs: Any) -> StabilizerObservation:  # type: ignore[override]
        """Apply one action and return the resulting observation.

        Invalid actions leave the circuit untouched, cost ``W_FORMAT`` and
        count toward ``MAX_CONSECUTIVE_VIOLATIONS``. ``FINALIZE`` ends the
        episode with the terminal reward. Any other valid action appends a
        gate and yields the shaped per-step reward; the episode ends once the
        number of emitted gates reaches the gate budget.

        Args:
            action: The gate (``H``, ``S``, ``CX``) or ``FINALIZE`` request.
            **kwargs: Accepted for interface compatibility; unused here.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self._task is None:
            raise RuntimeError("step() called before reset().")
        # NOTE(review): calls made after a step already returned done=True are
        # not rejected here; presumably the caller stops stepping — confirm.
        self._state.step_count += 1
        n_qubits = int(self._task["n_qubits"])
        gate_budget = self._gate_budget()

        # --- Validate ---
        is_valid, error = _validate_action(action, n_qubits)
        if not is_valid:
            return self._reject(error, gate_budget)
        self._consecutive_violations = 0

        # --- FINALIZE: terminal reward ---
        if action.op == "FINALIZE":
            return self._finalize()

        # --- Apply gate ---
        return self._apply_gate(action, gate_budget)

    # NOTE(review): this is a plain method, so callers must use ``env.state()``.
    # The OpenEnv base class may declare ``state`` as a property — confirm
    # against openenv.core.env_server.interfaces before changing the call style.
    def state(self) -> State:
        """Return the current episode ``State`` (episode id and step count)."""
        return self._state