stabilizer-forge / server /stabilizer_forge_environment.py
ronitraj's picture
Upload folder using huggingface_hub
b1100bc verified
"""StabilizerForge environment.

Episode loop:
    reset(task_id?) -> sample or load a task; emit the initial observation.
    step(action)    -> apply one Clifford gate (H, S or CX), or FINALIZE.
                       Returns a dense shaping reward for each gate step and
                       the full terminal reward at FINALIZE.
"""
from __future__ import annotations
import json
import os
import random
from pathlib import Path
from typing import Any
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..models import StabilizerAction, StabilizerObservation
from .verifier import match_fraction
except ImportError: # pragma: no cover (in-container imports without package context)
from models import StabilizerAction, StabilizerObservation
from server.verifier import match_fraction
# Reward weights (keep aligned with README).
# Terminal reward paid at FINALIZE (computed in StabilizerForgeEnvironment.step):
#   W_MATCH * match_fraction + W_GATE_EFF * gate_eff + W_TWOQ_EFF * twoq_eff
# W_CONN and W_FORMAT are not part of that sum; they are per-step penalties.
W_MATCH: float = 0.4  # terminal: weight on the final match_fraction
W_GATE_EFF: float = 0.2  # terminal: weight on total gate count vs benchmark_optimum
W_TWOQ_EFF: float = 0.2  # terminal: weight on CX count vs benchmark_optimum_2q
W_CONN: float = 0.1  # per step: penalty for a CX between non-adjacent qubits
W_FORMAT: float = 0.1  # per step: penalty for an invalid (rejected) action
SHAPING_COEF: float = 0.05 # per-step Δmatch_fraction shaping
MAX_CONSECUTIVE_VIOLATIONS: int = 5  # episode ends after this many invalid actions in a row
def _default_tasks_path() -> str:
    """Pick the tasks file used when the environment is built without one.

    Checked in order:
      1. ``$STABILIZER_FORGE_TASKS`` (ignored when unset or empty);
      2. ``tasks.jsonl`` in the parent of this file's directory, if it exists;
      3. ``tasks.jsonl`` one directory above that (returned even if missing).
    """
    from_env = os.environ.get("STABILIZER_FORGE_TASKS")
    if from_env:
        return from_env
    package_root = Path(__file__).resolve().parent.parent  # stabilizer_forge/
    bundled = package_root / "tasks.jsonl"
    # Fall back to the project root when the package has no bundled file.
    chosen = bundled if bundled.exists() else package_root.parent / "tasks.jsonl"
    return str(chosen)
def _load_tasks(path: str) -> list[dict]:
    """Load the task pool from a JSON Lines file.

    Each non-blank line must hold one JSON value (one task). Blank and
    whitespace-only lines are skipped. A missing file yields an empty list;
    the caller decides whether that is an error.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON. The
            message is prefixed with ``<path>:<line number>:``.
    """
    p = Path(path)
    if not p.exists():
        return []
    tasks: list[dict] = []
    # JSON text is UTF-8; don't depend on the platform default encoding.
    # "utf-8-sig" also drops a leading BOM, which json.loads would reject.
    with p.open(encoding="utf-8-sig") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                tasks.append(json.loads(line))
            except json.JSONDecodeError as exc:
                # Same exception type so existing handlers keep working, but with
                # the file/line location that json alone cannot know.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {exc.msg}", exc.doc, exc.pos
                ) from exc
    return tasks
def _gate_to_stim(action: StabilizerAction) -> str:
    """Format one gate action as a single line of Stim circuit text.

    ``H`` and ``S`` use only the first qubit; ``CX`` uses the first two
    (control, target). Any other op, including FINALIZE, raises ``ValueError``.
    """
    arity = {"H": 1, "S": 1, "CX": 2}.get(action.op)
    if arity is None:
        raise ValueError(f"Cannot render gate: {action.op}")
    operands = " ".join(f"{action.qubits[i]}" for i in range(arity))
    return f"{action.op} {operands}"
def _validate_action(
    action: StabilizerAction, n_qubits: int
) -> tuple[bool, str]:
    """Check an action's op, arity and qubit indices.

    Returns ``(True, "")`` for a well-formed action, otherwise
    ``(False, message)`` describing the first problem found. Only
    FINALIZE (no qubits), H/S (one qubit) and CX (two distinct qubits) are
    accepted; every index must lie in ``[0, n_qubits)``.
    """
    op = action.op
    if op == "FINALIZE":
        if action.qubits:
            return False, "FINALIZE takes no qubits."
        return True, ""

    if op in {"H", "S"}:
        expected = 1
    elif op == "CX":
        expected = 2
    else:
        return False, f"unknown op {action.op}"

    qubits = action.qubits
    if len(qubits) != expected:
        noun = "qubit" if expected == 1 else "qubits"
        return False, f"{op} requires exactly {expected} {noun}, got {len(qubits)}."

    if expected == 2:
        control, target = qubits
        if control == target:
            return False, "CX control and target must differ."
        to_check = (control, target)
    else:
        to_check = (qubits[0],)

    # Report the first index that falls outside the register.
    for q in to_check:
        if not 0 <= q < n_qubits:
            return False, f"qubit {q} out of range [0, {n_qubits})."
    return True, ""
class StabilizerForgeEnvironment(Environment):
    """Single-agent env: emit Clifford gates to encode a target stabilizer code.

    Each episode works on one task dict from the tasks JSONL file. Keys read
    here: ``task_id``, ``n_qubits`` and ``target_stabilizers``, plus the
    optional ``connectivity_edges``, ``gate_budget``, ``benchmark_optimum``
    and ``benchmark_optimum_2q``. The agent emits H / S / CX gates one at a
    time and ends the episode with FINALIZE, which pays the terminal reward.
    The episode also ends, without a terminal reward, when the gate budget is
    used up or after ``MAX_CONSECUTIVE_VIOLATIONS`` invalid actions in a row.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self, tasks_path: str | None = None):
        """Load the task pool.

        Args:
            tasks_path: JSONL tasks file; defaults to ``_default_tasks_path()``.
                A missing file yields an empty pool, which only fails later,
                in ``reset()``, when a task has to be picked.
        """
        super().__init__()
        self._tasks_path = tasks_path or _default_tasks_path()
        self._tasks = _load_tasks(self._tasks_path)
        self._task: dict[str, Any] | None = None
        # Current task's connectivity as sorted qubit pairs; None means all-to-all.
        self._edge_set: set[tuple[int, ...]] | None = None
        self._gates: list[str] = []  # Stim lines emitted so far, in order
        self._cnot_count = 0
        self._nonadj_cnot_count = 0
        self._format_violations = 0
        self._consecutive_violations = 0
        self._last_match_fraction = 0.0
        self._last_match_results: list[bool] = []
        self._finalized = False
        # True once the episode has ended; step() then refuses to act until reset().
        self._done = False
        self._rng = random.Random()
        self._state = State(episode_id=str(uuid4()), step_count=0)

    # ---------- Helpers ----------

    def _task_int(self, key: str, default: int) -> int:
        """Return integer field ``key`` of the current task; missing or null gives ``default``."""
        assert self._task is not None
        value = self._task.get(key)
        return default if value is None else int(value)

    def _gate_budget(self) -> int:
        """Return the max gates per episode: task ``gate_budget``, else 2 * max(1, benchmark_optimum)."""
        return self._task_int(
            "gate_budget", 2 * max(1, self._task_int("benchmark_optimum", 0))
        )

    def _circuit_text(self) -> str:
        """Return the current circuit as Stim text, one gate per line."""
        return "\n".join(self._gates)

    def _is_adjacent(self, a: int, b: int) -> bool:
        """Return whether a CX between qubits ``a`` and ``b`` respects the task connectivity."""
        if self._edge_set is None:
            return True  # no connectivity_edges given: all-to-all
        return tuple(sorted((a, b))) in self._edge_set

    def _compute_match(self) -> tuple[float, list[bool]]:
        """Score the current circuit against the task's target stabilizers.

        Returns the fraction reported by ``match_fraction`` and the per-target
        results ordered like ``target_stabilizers``. Before any reset, returns
        ``(0.0, [])``.
        """
        if not self._task:
            return 0.0, []
        targets = self._task["target_stabilizers"]
        frac, results_by_target = match_fraction(
            self._circuit_text(), targets, int(self._task["n_qubits"])
        )
        return frac, [results_by_target[s] for s in targets]

    def _terminal_reward(self) -> float:
        """Compute the reward paid at FINALIZE from match quality and gate/CX efficiency.

        Efficiency is 1 for an empty circuit and falls linearly to 0 at 1.5x
        the benchmark count. It is clamped at 0, so gates beyond that are not
        punished further.
        """
        bench = max(1, self._task_int("benchmark_optimum", 1))
        bench_2q = max(1, self._task_int("benchmark_optimum_2q", bench))
        gate_eff = max(0.0, 1.0 - len(self._gates) / (1.5 * bench))
        twoq_eff = max(0.0, 1.0 - self._cnot_count / (1.5 * bench_2q))
        return (
            W_MATCH * self._last_match_fraction
            + W_GATE_EFF * gate_eff
            + W_TWOQ_EFF * twoq_eff
        )

    def _make_obs(
        self,
        *,
        reward: float,
        done: bool,
        last_action_valid: bool = True,
        last_action_error: str = "",
    ) -> StabilizerObservation:
        """Build an observation snapshot of the current episode state."""
        assert self._task is not None
        gate_budget = self._gate_budget()
        return StabilizerObservation(
            task_id=self._task["task_id"],
            target_stabilizers=list(self._task["target_stabilizers"]),
            n_qubits=int(self._task["n_qubits"]),
            connectivity_edges=self._task.get("connectivity_edges"),
            gates_so_far=list(self._gates),
            current_circuit=self._circuit_text(),
            current_match=list(self._last_match_results),
            match_fraction=float(self._last_match_fraction),
            gates_emitted=len(self._gates),
            cnot_count=self._cnot_count,
            nonadj_cnot_count=self._nonadj_cnot_count,
            gate_budget=gate_budget,
            gate_budget_remaining=max(0, gate_budget - len(self._gates)),
            benchmark_optimum=self._task_int("benchmark_optimum", 0),
            benchmark_optimum_2q=self._task_int("benchmark_optimum_2q", 0),
            format_violations=self._format_violations,
            consecutive_violations=self._consecutive_violations,
            last_action_valid=last_action_valid,
            last_action_error=last_action_error,
            step_count=self._state.step_count,
            finalized=self._finalized,
            done=done,
            reward=reward,
        )

    def _pick_task(self, task_id: str | None, seed: int | None) -> dict[str, Any]:
        """Return the task named ``task_id``, or a random one from the pool.

        With ``seed`` the random draw uses a fresh ``random.Random(seed)``;
        otherwise it uses (and advances) the environment RNG.

        Raises:
            ValueError: ``task_id`` is not in the pool.
            RuntimeError: no ``task_id`` was given and the pool is empty.
        """
        if task_id is not None:
            for t in self._tasks:
                if t.get("task_id") == task_id:
                    return t
            raise ValueError(f"task_id '{task_id}' not found in {self._tasks_path}")
        if not self._tasks:
            raise RuntimeError(
                f"No tasks loaded from {self._tasks_path}. "
                "Set STABILIZER_FORGE_TASKS or place tasks.jsonl next to the env."
            )
        rng = random.Random(seed) if seed is not None else self._rng
        return rng.choice(self._tasks)

    # ---------- Gym API ----------

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        task_id: str | None = None,
        **kwargs: Any,
    ) -> StabilizerObservation:
        """Start a new episode and return its initial observation (reward 0).

        Args:
            seed: reseeds the environment RNG before a random task is drawn.
            episode_id: id recorded in ``state``; a fresh UUID when omitted.
            task_id: pick this task instead of a random one.
        """
        if seed is not None:
            self._rng = random.Random(seed)
        # Draw from the (possibly just reseeded) env RNG so that later unseeded
        # resets continue the seeded sequence instead of repeating its first draw.
        self._task = self._pick_task(task_id=task_id, seed=None)
        edges = self._task.get("connectivity_edges")
        self._edge_set = None if edges is None else {tuple(sorted(e)) for e in edges}
        self._gates = []
        self._cnot_count = 0
        self._nonadj_cnot_count = 0
        self._format_violations = 0
        self._consecutive_violations = 0
        self._finalized = False
        self._done = False
        self._state = State(
            episode_id=episode_id or str(uuid4()), step_count=0
        )
        # Initial match (empty circuit on |0...0>)
        self._last_match_fraction, self._last_match_results = self._compute_match()
        return self._make_obs(reward=0.0, done=False)

    def step(self, action: StabilizerAction, **kwargs: Any) -> StabilizerObservation:  # type: ignore[override]
        """Apply one action and return the resulting observation.

        Rewards:
          * invalid action: ``-W_FORMAT``; the action is otherwise ignored;
          * gate: ``SHAPING_COEF`` times the change in match fraction, minus
            ``W_CONN`` for a CX between non-adjacent qubits;
          * FINALIZE: the terminal reward (``_terminal_reward``).
        Once the episode is over, further calls change nothing and return a
        zero-reward ``done`` observation until ``reset()`` is called.

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        if self._task is None:
            raise RuntimeError("step() called before reset().")
        if self._done:
            # Without this guard an agent could FINALIZE repeatedly and collect
            # the terminal reward each time, or keep emitting gates past the budget.
            return self._make_obs(
                reward=0.0,
                done=True,
                last_action_valid=False,
                last_action_error="Episode is over; call reset() to start a new one.",
            )

        self._state.step_count += 1
        n_qubits = int(self._task["n_qubits"])
        gate_budget = self._gate_budget()

        # --- Validate: invalid actions are penalised and otherwise ignored ---
        ok, err = _validate_action(action, n_qubits)
        if not ok:
            self._format_violations += 1
            self._consecutive_violations += 1
            self._done = (
                self._consecutive_violations >= MAX_CONSECUTIVE_VIOLATIONS
                or len(self._gates) >= gate_budget
            )
            return self._make_obs(
                reward=W_FORMAT * -1.0,
                done=self._done,
                last_action_valid=False,
                last_action_error=err,
            )
        self._consecutive_violations = 0

        # --- FINALIZE: terminal reward ---
        if action.op == "FINALIZE":
            self._finalized = True
            self._last_match_fraction, self._last_match_results = self._compute_match()
            self._done = True
            return self._make_obs(reward=self._terminal_reward(), done=True)

        # --- Apply gate ---
        self._gates.append(_gate_to_stim(action))
        conn_penalty = 0.0
        if action.op == "CX":
            self._cnot_count += 1
            if not self._is_adjacent(action.qubits[0], action.qubits[1]):
                self._nonadj_cnot_count += 1
                conn_penalty = -1.0

        prev_match = self._last_match_fraction
        self._last_match_fraction, self._last_match_results = self._compute_match()
        delta = self._last_match_fraction - prev_match

        # Exhausting the budget without FINALIZE ends the episode with no terminal reward.
        self._done = len(self._gates) >= gate_budget
        step_reward = SHAPING_COEF * delta + W_CONN * conn_penalty
        return self._make_obs(reward=step_reward, done=self._done)

    @property
    def state(self) -> State:
        """Return the episode bookkeeping (episode id and step count)."""
        return self._state