"""ForgeEnvironment: the OpenEnv Environment subclass for ForgeEnv.

Episode flow (exactly 2 steps per episode):
    reset()              -> sample task, ask Teacher for category
    step(BreakageAction) -> Drift Generator's proposal is applied; broken
                            script is run, error trace captured.
    step(RepairAction)   -> Repair diff is applied; script is re-executed;
                            visible + held-out rewards computed; episode ends.
"""
| from __future__ import annotations |
|
|
| import time |
| import uuid |
| from typing import Any, Optional |
|
|
| from openenv.core import Environment |
|
|
| from forgeenv.drift.library_drift_engine import LibraryDriftEngine |
| from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction |
| from forgeenv.env.diff_utils import apply_unified_diff |
| from forgeenv.env.observations import ForgeObservation |
| from forgeenv.primitives.breakage_primitives import ( |
| PRIMITIVE_REGISTRY, |
| parse_breakage_spec, |
| ) |
| from forgeenv.roles.teacher import Teacher |
| from forgeenv.sandbox.simulation_mode import SimulationExecutor |
| from forgeenv.tasks.models import ExecutionResult, Task |
| from forgeenv.tasks.task_sampler import TaskSampler |
| from forgeenv.verifier.held_out_evaluator import compute_held_out_scores |
| from forgeenv.verifier.visible_verifier import compute_visible_reward |
|
|
# Sorted names of every registered breakage primitive; used as the category
# list handed to the default Teacher when the caller does not supply one.
DEFAULT_CATEGORIES = sorted(PRIMITIVE_REGISTRY)
|
|
|
|
class ForgeEnvironment(Environment[ForgeAction, ForgeObservation, dict]):
    """OpenEnv-compliant environment for HuggingFace ecosystem repair.

    An episode is driven by three calls: ``reset()`` samples a task and enters
    the ``drift_gen`` phase; ``step()`` carrying a breakage applies it, runs
    the broken script and enters the ``repair`` phase; ``step()`` carrying a
    repair applies the diff, re-runs the script, computes rewards and ends the
    episode. Invalid actions end the episode early with a zero-reward error
    observation instead of raising.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    def __init__(
        self,
        task_sampler: Optional[TaskSampler] = None,
        teacher: Optional[Teacher] = None,
        executor: Optional[SimulationExecutor] = None,
        drift_engine: Optional[LibraryDriftEngine] = None,
        seed: Optional[int] = None,
    ) -> None:
        """Wire up collaborators, building defaults for any that are omitted.

        Args:
            task_sampler: Source of tasks; defaults to ``TaskSampler()``.
            teacher: Category selector; defaults to a ``Teacher`` over
                ``DEFAULT_CATEGORIES`` (or ``["api_drift"]`` if that is empty).
            executor: Script runner; defaults to ``SimulationExecutor(seed=seed)``.
            drift_engine: Library version drift source; defaults to
                ``LibraryDriftEngine()``.
            seed: Only forwarded to the default ``SimulationExecutor``; it is
                ignored when ``executor`` is supplied.
        """
        super().__init__()
        self.task_sampler = task_sampler or TaskSampler()
        self.teacher = teacher or Teacher(
            categories=list(DEFAULT_CATEGORIES) or ["api_drift"]
        )
        self.executor = executor or SimulationExecutor(seed=seed)
        self.drift_engine = drift_engine or LibraryDriftEngine()

        # Per-episode bookkeeping; (re)initialised by reset().
        self._episode_id: Optional[str] = None
        self._episode_count: int = 0
        self._current_task: Optional[Task] = None
        self._original_script: str = ""
        self._broken_script: str = ""
        self._error_trace: str = ""
        self._breakage_spec: Optional[dict[str, Any]] = None
        self._target_category: str = ""
        # One of "idle", "drift_gen", "repair", "done".
        self._current_phase: str = "idle"
        self._last_obs: Optional[ForgeObservation] = None

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        difficulty: Optional[str] = "easy",
        **kwargs: Any,
    ) -> ForgeObservation:
        """Start a new episode and return its initial ``drift_gen`` observation.

        Args:
            seed: Accepted for interface compatibility; not used by this method.
            episode_id: Identifier for the episode; a random UUID4 string is
                generated when it is omitted or empty.
            difficulty: Forwarded to ``task_sampler.sample``.
            **kwargs: Accepted and ignored.

        Returns:
            Observation carrying the original script, the Teacher-selected
            target category and the current library versions.

        Raises:
            RuntimeError: If the task sampler returns ``None``.
        """
        self._episode_id = episode_id or str(uuid.uuid4())
        self._episode_count += 1
        self._target_category = self.teacher.select_next_category()

        task = self.task_sampler.sample(difficulty=difficulty)
        if task is None:
            raise RuntimeError("Task sampler returned no tasks (empty seed corpus?)")
        self._current_task = task
        self._original_script = task.script_content
        self._broken_script = ""
        self._error_trace = ""
        self._breakage_spec = None
        self._current_phase = "drift_gen"

        # Give the drift engine a chance to drift, keyed on the episode counter.
        drifted = self.drift_engine.maybe_drift(self._episode_count, drift_every=50)

        obs = ForgeObservation(
            current_phase="drift_gen",
            task_id=task.task_id,
            task_description=task.description,
            target_category=self._target_category,
            script_content=self._original_script,
            error_trace=None,
            library_versions=self.drift_engine.current_versions(),
            episode_step=0,
            done=False,
            reward=0.0,
            info={
                "episode_id": self._episode_id,
                "episode_count": self._episode_count,
                "drift_triggered": drifted,
                "available_primitives": sorted(PRIMITIVE_REGISTRY),
            },
        )
        self._last_obs = obs
        return obs

    def step(
        self,
        action: ForgeAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> ForgeObservation:
        """Advance the episode by one action, dispatching on the current phase.

        Args:
            action: Wrapper whose ``breakage`` field is consumed in the
                ``drift_gen`` phase and whose ``repair`` field is consumed in
                the ``repair`` phase.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted and ignored.

        Returns:
            The next observation. A missing payload for the current phase, or
            a call outside ``drift_gen``/``repair``, yields a ``done=True``
            error observation rather than an exception.
        """
        if self._current_phase == "drift_gen":
            if action.breakage is None:
                return self._error_obs("Expected BreakageAction in drift_gen phase")
            return self._handle_breakage(action.breakage)

        if self._current_phase == "repair":
            if action.repair is None:
                return self._error_obs("Expected RepairAction in repair phase")
            return self._handle_repair(action.repair)

        return self._error_obs(
            f"step() called in invalid phase {self._current_phase!r} — call reset() first"
        )

    @property
    def state(self) -> dict:
        """Snapshot of the environment's bookkeeping as a plain dict.

        ``drift_history`` and ``breakage_spec`` are shallow copies, so callers
        cannot rebind the environment's own containers through the snapshot.
        """
        return {
            "phase": self._current_phase,
            "episode_id": self._episode_id,
            "episode_count": self._episode_count,
            "task_id": self._current_task.task_id if self._current_task else None,
            "target_category": self._target_category,
            "library_versions": self.drift_engine.current_versions(),
            "teacher": self.teacher.get_state(),
            "drift_history": list(self.drift_engine.drift_history),
            "breakage_spec": dict(self._breakage_spec) if self._breakage_spec else None,
        }

    def _handle_breakage(self, breakage: BreakageAction) -> ForgeObservation:
        """Apply the proposed breakage, run the broken script, enter ``repair``.

        Returns an error observation (ending the episode) when the spec is
        rejected by ``parse_breakage_spec`` or the primitive fails to apply.
        """
        # `or {}` tolerates an action whose params were left as None.
        spec = {
            "primitive_type": breakage.primitive_type,
            "params": dict(breakage.params or {}),
        }
        try:
            primitive = parse_breakage_spec(spec)
        except ValueError as exc:
            return self._error_obs(f"Invalid breakage spec: {exc}")

        try:
            self._broken_script = primitive.apply(self._original_script)
        except Exception as exc:
            # Primitives run on arbitrary scripts; any failure ends the episode
            # with an error observation instead of crashing the environment.
            return self._error_obs(f"Primitive apply failed: {exc}")

        self._breakage_spec = spec

        result = self.executor.execute(self._broken_script, self._current_task)
        if result.exit_code != 0:
            self._error_trace = result.stderr or "non-zero exit code, no stderr"
        else:
            # The breakage produced no failing exit code; the episode still
            # proceeds to the repair phase with a placeholder trace.
            self._error_trace = "Script ran without observable error"

        self._current_phase = "repair"

        obs = ForgeObservation(
            current_phase="repair",
            task_id=self._current_task.task_id,
            task_description=self._current_task.description,
            target_category=primitive.category,
            script_content=self._broken_script,
            error_trace=self._error_trace,
            library_versions=self.drift_engine.current_versions(),
            episode_step=1,
            done=False,
            reward=0.0,
            info={
                "episode_id": self._episode_id,
                "breakage_primitive": primitive.name,
                "breakage_description": primitive.description,
            },
        )
        self._last_obs = obs
        return obs

    def _handle_repair(self, repair: RepairAction) -> ForgeObservation:
        """Apply the repair diff, re-run the script, score it, end the episode.

        The visible reward becomes the observation's ``reward``; held-out
        scores are reported separately in ``held_out_breakdown``. The Teacher
        is updated with whether the repaired script exited with code 0.
        """
        diff_text = repair.unified_diff or ""
        # NOTE(review): apply_unified_diff is called unguarded — confirm it
        # never raises on a malformed diff, otherwise step() will propagate it.
        repaired = apply_unified_diff(self._broken_script, diff_text)

        # perf_counter is monotonic, so the duration cannot be skewed by
        # system clock adjustments the way time.time() can.
        started = time.perf_counter()
        result = self.executor.execute(repaired, self._current_task)
        wall_ms = int((time.perf_counter() - started) * 1000)
        result.script_content = repaired

        visible_reward, visible_breakdown = compute_visible_reward(
            result, self._current_task
        )
        held_out = compute_held_out_scores(
            result, self._current_task, repair_diff=diff_text
        )

        success = result.exit_code == 0
        category = (
            self._breakage_spec.get("primitive_type", "unknown")
            if self._breakage_spec
            else "unknown"
        )
        # The Teacher is keyed by primitive type here; its default categories
        # are the PRIMITIVE_REGISTRY keys (presumably the same namespace —
        # verify against Teacher.update).
        self.teacher.update(category, success)

        self._current_phase = "done"

        obs = ForgeObservation(
            current_phase="done",
            task_id=self._current_task.task_id,
            task_description=self._current_task.description,
            target_category=category,
            script_content=repaired,
            error_trace=result.stderr or None,
            library_versions=self.drift_engine.current_versions(),
            episode_step=2,
            done=True,
            reward=visible_reward,
            reward_breakdown=visible_breakdown,
            held_out_breakdown=held_out,
            info={
                "episode_id": self._episode_id,
                "exit_code": result.exit_code,
                "wall_time_ms": wall_ms,
                "checkpoint_exists": result.checkpoint_exists,
                # Last five stdout lines; `or ""` tolerates a None stdout.
                "stdout_tail": "\n".join((result.stdout or "").splitlines()[-5:]),
                "breakage_spec": self._breakage_spec,
                "teacher_state": self.teacher.get_state(),
            },
        )
        self._last_obs = obs
        return obs

    def _error_obs(self, message: str) -> ForgeObservation:
        """Return a `done=True` error observation rather than raising.

        An in-progress episode is also marked as finished internally so that
        ``state["phase"]`` agrees with the returned observation and a further
        ``step()`` asks for ``reset()`` instead of silently continuing. An
        environment that was never reset stays ``idle``.
        """
        if self._current_phase != "idle":
            self._current_phase = "done"

        obs = ForgeObservation(
            current_phase="done",
            task_id=self._current_task.task_id if self._current_task else "",
            task_description=self._current_task.description if self._current_task else "",
            target_category=self._target_category,
            # Prefer the broken script when a breakage was already applied.
            script_content=self._broken_script or self._original_script,
            error_trace=message,
            library_versions=self.drift_engine.current_versions(),
            episode_step=2,
            done=True,
            reward=0.0,
            info={"error": message, "episode_id": self._episode_id},
        )
        self._last_obs = obs
        return obs
|
|