"""ForgeEnvironment: the OpenEnv Environment subclass for ForgeEnv. Episode flow (exactly 2 steps per episode): reset() -> sample task, ask Teacher for category step(BreakageAction) -> Drift Generator's proposal is applied; broken script is run, error trace captured. step(RepairAction) -> Repair diff is applied; script is re-executed; visible + held-out rewards computed; episode ends. """ from __future__ import annotations import time import uuid from typing import Any, Optional from openenv.core import Environment from forgeenv.drift.library_drift_engine import LibraryDriftEngine from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction from forgeenv.env.diff_utils import apply_unified_diff from forgeenv.env.observations import ForgeObservation from forgeenv.primitives.breakage_primitives import ( PRIMITIVE_REGISTRY, parse_breakage_spec, ) from forgeenv.roles.teacher import Teacher from forgeenv.sandbox.simulation_mode import SimulationExecutor from forgeenv.tasks.models import ExecutionResult, Task from forgeenv.tasks.task_sampler import TaskSampler from forgeenv.verifier.held_out_evaluator import compute_held_out_scores from forgeenv.verifier.visible_verifier import compute_visible_reward DEFAULT_CATEGORIES = sorted(PRIMITIVE_REGISTRY.keys()) class ForgeEnvironment(Environment[ForgeAction, ForgeObservation, dict]): """OpenEnv-compliant environment for HuggingFace ecosystem repair.""" SUPPORTS_CONCURRENT_SESSIONS = False # Teacher state is global per env def __init__( self, task_sampler: Optional[TaskSampler] = None, teacher: Optional[Teacher] = None, executor: Optional[SimulationExecutor] = None, drift_engine: Optional[LibraryDriftEngine] = None, seed: Optional[int] = None, ) -> None: super().__init__() self.task_sampler = task_sampler or TaskSampler() self.teacher = teacher or Teacher( categories=list(DEFAULT_CATEGORIES) or ["api_drift"] ) self.executor = executor or SimulationExecutor(seed=seed) self.drift_engine = drift_engine or LibraryDriftEngine() self._episode_id: Optional[str] = None self._episode_count: int = 0 self._current_task: Optional[Task] = None self._original_script: str = "" self._broken_script: str = "" self._error_trace: str = "" self._breakage_spec: Optional[dict[str, Any]] = None self._target_category: str = "" self._current_phase: str = "idle" self._last_obs: Optional[ForgeObservation] = None # ------------------------------------------------------------------ API def reset( self, seed: Optional[int] = None, episode_id: Optional[str] = None, difficulty: Optional[str] = "easy", **kwargs: Any, ) -> ForgeObservation: self._episode_id = episode_id or str(uuid.uuid4()) self._episode_count += 1 self._target_category = self.teacher.select_next_category() task = self.task_sampler.sample(difficulty=difficulty) if task is None: raise RuntimeError("Task sampler returned no tasks (empty seed corpus?)") self._current_task = task self._original_script = task.script_content self._broken_script = "" self._error_trace = "" self._breakage_spec = None self._current_phase = "drift_gen" # Library drift trigger every 50 episodes (configurable from outside). drifted = self.drift_engine.maybe_drift(self._episode_count, drift_every=50) obs = ForgeObservation( current_phase="drift_gen", task_id=task.task_id, task_description=task.description, target_category=self._target_category, script_content=self._original_script, error_trace=None, library_versions=self.drift_engine.current_versions(), episode_step=0, done=False, reward=0.0, info={ "episode_id": self._episode_id, "episode_count": self._episode_count, "drift_triggered": drifted, "available_primitives": sorted(PRIMITIVE_REGISTRY), }, ) self._last_obs = obs return obs def step( self, action: ForgeAction, timeout_s: Optional[float] = None, **kwargs: Any, ) -> ForgeObservation: if self._current_phase == "drift_gen": if action.breakage is None: return self._error_obs("Expected BreakageAction in drift_gen phase") return self._handle_breakage(action.breakage) if self._current_phase == "repair": if action.repair is None: return self._error_obs("Expected RepairAction in repair phase") return self._handle_repair(action.repair) return self._error_obs( f"step() called in invalid phase {self._current_phase!r} — call reset() first" ) @property def state(self) -> dict: return { "phase": self._current_phase, "episode_id": self._episode_id, "episode_count": self._episode_count, "task_id": self._current_task.task_id if self._current_task else None, "target_category": self._target_category, "library_versions": self.drift_engine.current_versions(), "teacher": self.teacher.get_state(), "drift_history": list(self.drift_engine.drift_history), "breakage_spec": dict(self._breakage_spec) if self._breakage_spec else None, } # ---------------------------------------------------------------- helpers def _handle_breakage(self, breakage: BreakageAction) -> ForgeObservation: spec = {"primitive_type": breakage.primitive_type, "params": dict(breakage.params)} try: primitive = parse_breakage_spec(spec) except ValueError as exc: return self._error_obs(f"Invalid breakage spec: {exc}") try: self._broken_script = primitive.apply(self._original_script) except Exception as exc: # primitive bug — surface but don't crash server return self._error_obs(f"Primitive apply failed: {exc}") self._breakage_spec = spec result = self.executor.execute(self._broken_script, self._current_task) if result.exit_code != 0: self._error_trace = result.stderr or "non-zero exit code, no stderr" else: # The breakage didn't actually break it; still proceed to repair phase # (no-op repair is then a valid choice). self._error_trace = "Script ran without observable error" self._current_phase = "repair" obs = ForgeObservation( current_phase="repair", task_id=self._current_task.task_id, task_description=self._current_task.description, target_category=primitive.category, script_content=self._broken_script, error_trace=self._error_trace, library_versions=self.drift_engine.current_versions(), episode_step=1, done=False, reward=0.0, info={ "episode_id": self._episode_id, "breakage_primitive": primitive.name, "breakage_description": primitive.description, }, ) self._last_obs = obs return obs def _handle_repair(self, repair: RepairAction) -> ForgeObservation: repaired = apply_unified_diff(self._broken_script, repair.unified_diff or "") t0 = time.time() result = self.executor.execute(repaired, self._current_task) result.script_content = repaired # ensure verifier sees what we ran wall_ms = int((time.time() - t0) * 1000) visible_reward, visible_breakdown = compute_visible_reward( result, self._current_task ) held_out = compute_held_out_scores( result, self._current_task, repair_diff=repair.unified_diff or "" ) success = result.exit_code == 0 category = ( self._breakage_spec.get("primitive_type", "unknown") if self._breakage_spec else "unknown" ) # Update Teacher's curriculum state self.teacher.update(category, success) self._current_phase = "done" obs = ForgeObservation( current_phase="done", task_id=self._current_task.task_id, task_description=self._current_task.description, target_category=category, script_content=repaired, error_trace=result.stderr or None, library_versions=self.drift_engine.current_versions(), episode_step=2, done=True, reward=visible_reward, reward_breakdown=visible_breakdown, held_out_breakdown=held_out, info={ "episode_id": self._episode_id, "exit_code": result.exit_code, "wall_time_ms": wall_ms, "checkpoint_exists": result.checkpoint_exists, "stdout_tail": "\n".join(result.stdout.splitlines()[-5:]), "breakage_spec": self._breakage_spec, "teacher_state": self.teacher.get_state(), }, ) self._last_obs = obs return obs def _error_obs(self, message: str) -> ForgeObservation: """Return a `done=True` error observation rather than raising.""" return ForgeObservation( current_phase="done", task_id=self._current_task.task_id if self._current_task else "", task_description=self._current_task.description if self._current_task else "", target_category=self._target_category, script_content=self._broken_script or self._original_script, error_trace=message, library_versions=self.drift_engine.current_versions(), episode_step=2, done=True, reward=0.0, info={"error": message, "episode_id": self._episode_id}, )