"""StabilizerForge environment. Episode loop: reset(task_id?) -> sample/load a task; emit initial observation step(action) -> apply one Clifford gate, or FINALIZE returns dense shaping reward per step, full terminal reward at FINALIZE. """ from __future__ import annotations import json import os import random from pathlib import Path from typing import Any from uuid import uuid4 from openenv.core.env_server.interfaces import Environment from openenv.core.env_server.types import State try: from ..models import StabilizerAction, StabilizerObservation from .verifier import match_fraction except ImportError: # pragma: no cover (in-container imports without package context) from models import StabilizerAction, StabilizerObservation from server.verifier import match_fraction # Reward weights (keep aligned with README) W_MATCH = 0.4 W_GATE_EFF = 0.2 W_TWOQ_EFF = 0.2 W_CONN = 0.1 W_FORMAT = 0.1 SHAPING_COEF = 0.05 # per-step Δmatch_fraction shaping MAX_CONSECUTIVE_VIOLATIONS = 5 def _default_tasks_path() -> str: """Resolve the default tasks file. Looks for env var, then sibling tasks.jsonl.""" env_path = os.environ.get("STABILIZER_FORGE_TASKS") if env_path: return env_path here = Path(__file__).resolve().parent.parent # stabilizer_forge/ candidate = here / "tasks.jsonl" if candidate.exists(): return str(candidate) # Fallback: project root return str(here.parent / "tasks.jsonl") def _load_tasks(path: str) -> list[dict]: p = Path(path) if not p.exists(): return [] tasks: list[dict] = [] with p.open() as f: for line in f: line = line.strip() if not line: continue tasks.append(json.loads(line)) return tasks def _gate_to_stim(action: StabilizerAction) -> str: """Render a single action as one line of Stim text.""" if action.op in {"H", "S"}: return f"{action.op} {action.qubits[0]}" if action.op == "CX": return f"CX {action.qubits[0]} {action.qubits[1]}" raise ValueError(f"Cannot render gate: {action.op}") def _validate_action( action: StabilizerAction, n_qubits: int ) -> tuple[bool, str]: """Schema/range validation. Returns (is_valid, error_msg).""" if action.op == "FINALIZE": if action.qubits: return False, "FINALIZE takes no qubits." return True, "" if action.op in {"H", "S"}: if len(action.qubits) != 1: return False, f"{action.op} requires exactly 1 qubit, got {len(action.qubits)}." q = action.qubits[0] if not (0 <= q < n_qubits): return False, f"qubit {q} out of range [0, {n_qubits})." return True, "" if action.op == "CX": if len(action.qubits) != 2: return False, f"CX requires exactly 2 qubits, got {len(action.qubits)}." c, t = action.qubits if c == t: return False, "CX control and target must differ." for q in (c, t): if not (0 <= q < n_qubits): return False, f"qubit {q} out of range [0, {n_qubits})." return True, "" return False, f"unknown op {action.op}" class StabilizerForgeEnvironment(Environment): """Single-agent env: emit Clifford gates to encode a target stabilizer code.""" SUPPORTS_CONCURRENT_SESSIONS: bool = True def __init__(self, tasks_path: str | None = None): super().__init__() self._tasks_path = tasks_path or _default_tasks_path() self._tasks = _load_tasks(self._tasks_path) self._task: dict[str, Any] | None = None self._gates: list[str] = [] self._cnot_count = 0 self._nonadj_cnot_count = 0 self._format_violations = 0 self._consecutive_violations = 0 self._last_match_fraction = 0.0 self._last_match_results: list[bool] = [] self._finalized = False self._rng = random.Random() self._state = State(episode_id=str(uuid4()), step_count=0) # ---------- Helpers ---------- def _circuit_text(self) -> str: return "\n".join(self._gates) def _is_adjacent(self, a: int, b: int) -> bool: edges = self._task.get("connectivity_edges") if self._task else None if edges is None: return True # all-to-all edge_set = {tuple(sorted(e)) for e in edges} return tuple(sorted((a, b))) in edge_set def _compute_match(self) -> tuple[float, list[bool]]: if not self._task: return 0.0, [] text = self._circuit_text() n = self._task["n_qubits"] targets = self._task["target_stabilizers"] frac, results_dict = match_fraction(text, targets, n) ordered = [results_dict[s] for s in targets] return frac, ordered def _make_obs( self, *, reward: float, done: bool, last_action_valid: bool = True, last_action_error: str = "", ) -> StabilizerObservation: assert self._task is not None bench = int(self._task.get("benchmark_optimum", 0) or 0) return StabilizerObservation( task_id=self._task["task_id"], target_stabilizers=list(self._task["target_stabilizers"]), n_qubits=int(self._task["n_qubits"]), connectivity_edges=self._task.get("connectivity_edges"), gates_so_far=list(self._gates), current_circuit=self._circuit_text(), current_match=list(self._last_match_results), match_fraction=float(self._last_match_fraction), gates_emitted=len(self._gates), cnot_count=self._cnot_count, nonadj_cnot_count=self._nonadj_cnot_count, gate_budget=int(self._task.get("gate_budget", 2 * max(1, bench))), gate_budget_remaining=max( 0, int(self._task.get("gate_budget", 2 * max(1, bench))) - len(self._gates), ), benchmark_optimum=bench, benchmark_optimum_2q=int(self._task.get("benchmark_optimum_2q", 0) or 0), format_violations=self._format_violations, consecutive_violations=self._consecutive_violations, last_action_valid=last_action_valid, last_action_error=last_action_error, step_count=self._state.step_count, finalized=self._finalized, done=done, reward=reward, ) def _pick_task(self, task_id: str | None, seed: int | None) -> dict[str, Any]: if task_id is not None: for t in self._tasks: if t.get("task_id") == task_id: return t raise ValueError(f"task_id '{task_id}' not found in {self._tasks_path}") if not self._tasks: raise RuntimeError( f"No tasks loaded from {self._tasks_path}. " "Set STABILIZER_FORGE_TASKS or place tasks.jsonl next to the env." ) rng = random.Random(seed) if seed is not None else self._rng return rng.choice(self._tasks) # ---------- Gym API ---------- def reset( self, seed: int | None = None, episode_id: str | None = None, task_id: str | None = None, **kwargs: Any, ) -> StabilizerObservation: if seed is not None: self._rng = random.Random(seed) self._task = self._pick_task(task_id=task_id, seed=seed) self._gates = [] self._cnot_count = 0 self._nonadj_cnot_count = 0 self._format_violations = 0 self._consecutive_violations = 0 self._finalized = False self._state = State( episode_id=episode_id or str(uuid4()), step_count=0 ) # Initial match (empty circuit on |0...0>) self._last_match_fraction, self._last_match_results = self._compute_match() return self._make_obs(reward=0.0, done=False) def step(self, action: StabilizerAction, **kwargs: Any) -> StabilizerObservation: # type: ignore[override] if self._task is None: raise RuntimeError("step() called before reset().") self._state.step_count += 1 n_qubits = int(self._task["n_qubits"]) gate_budget = int( self._task.get( "gate_budget", 2 * max(1, int(self._task.get("benchmark_optimum", 1))) ) ) # --- Validate --- ok, err = _validate_action(action, n_qubits) if not ok: self._format_violations += 1 self._consecutive_violations += 1 done = ( self._consecutive_violations >= MAX_CONSECUTIVE_VIOLATIONS or len(self._gates) >= gate_budget ) return self._make_obs( reward=W_FORMAT * -1.0, done=done, last_action_valid=False, last_action_error=err, ) self._consecutive_violations = 0 # --- FINALIZE: terminal reward --- if action.op == "FINALIZE": self._finalized = True self._last_match_fraction, self._last_match_results = self._compute_match() bench = max(1, int(self._task.get("benchmark_optimum", 1))) bench_2q = max(1, int(self._task.get("benchmark_optimum_2q", bench))) gate_eff = max(0.0, 1.0 - len(self._gates) / (1.5 * bench)) twoq_eff = max(0.0, 1.0 - self._cnot_count / (1.5 * bench_2q)) terminal = ( W_MATCH * self._last_match_fraction + W_GATE_EFF * gate_eff + W_TWOQ_EFF * twoq_eff ) return self._make_obs(reward=terminal, done=True) # --- Apply gate --- gate_str = _gate_to_stim(action) self._gates.append(gate_str) conn_penalty = 0.0 if action.op == "CX": self._cnot_count += 1 if not self._is_adjacent(action.qubits[0], action.qubits[1]): self._nonadj_cnot_count += 1 conn_penalty = -1.0 prev_match = self._last_match_fraction self._last_match_fraction, self._last_match_results = self._compute_match() delta = self._last_match_fraction - prev_match # Termination if we exceed budget without finalizing done = len(self._gates) >= gate_budget step_reward = SHAPING_COEF * delta + W_CONN * conn_penalty return self._make_obs(reward=step_reward, done=done) @property def state(self) -> State: return self._state