| """ForgeEnvironment: the OpenEnv Environment subclass for ForgeEnv.
|
|
|
| Episode flow (exactly 2 steps per episode):
|
| reset() -> sample task, ask Teacher for category
|
| step(BreakageAction) -> Drift Generator's proposal is applied; broken
|
| script is run, error trace captured.
|
| step(RepairAction) -> Repair diff is applied; script is re-executed;
|
| visible + held-out rewards computed; episode ends.
|
| """
|
| from __future__ import annotations
|
|
|
| import time
|
| import uuid
|
| from typing import Any, Optional
|
|
|
| from openenv.core import Environment
|
|
|
| from forgeenv.drift.library_drift_engine import LibraryDriftEngine
|
| from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
|
| from forgeenv.env.diff_utils import apply_unified_diff
|
| from forgeenv.env.observations import ForgeObservation
|
| from forgeenv.primitives.breakage_primitives import (
|
| PRIMITIVE_REGISTRY,
|
| parse_breakage_spec,
|
| )
|
| from forgeenv.roles.teacher import Teacher
|
| from forgeenv.sandbox.simulation_mode import SimulationExecutor
|
| from forgeenv.tasks.models import ExecutionResult, Task
|
| from forgeenv.tasks.task_sampler import TaskSampler
|
| from forgeenv.verifier.held_out_evaluator import compute_held_out_scores
|
| from forgeenv.verifier.visible_verifier import compute_visible_reward
|
|
|
# Categories the default Teacher samples from: one per registered breakage
# primitive (the registry's keys), in sorted order.
DEFAULT_CATEGORIES = sorted(PRIMITIVE_REGISTRY)
|
|
|
|
|
class ForgeEnvironment(Environment[ForgeAction, ForgeObservation, dict]):
    """OpenEnv-compliant environment for HuggingFace ecosystem repair.

    An episode is a fixed two-step exchange tracked by ``_current_phase``
    (``"idle"`` -> ``"drift_gen"`` -> ``"repair"`` -> ``"done"``):

    * ``reset()`` samples a task, asks the Teacher for a target category and
      enters ``"drift_gen"``.
    * ``step()`` carrying a ``BreakageAction`` applies a breakage primitive to
      the task script, executes the broken script and enters ``"repair"``.
    * ``step()`` carrying a ``RepairAction`` applies the repair diff,
      re-executes the script, scores visible and held-out rewards, updates the
      Teacher and enters ``"done"``.

    Out-of-phase actions, invalid breakage specs, and failures while applying
    a primitive or a repair diff end the episode with a ``done=True`` error
    observation (reward 0.0) instead of raising.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    def __init__(
        self,
        task_sampler: Optional[TaskSampler] = None,
        teacher: Optional[Teacher] = None,
        executor: Optional[SimulationExecutor] = None,
        drift_engine: Optional[LibraryDriftEngine] = None,
        seed: Optional[int] = None,
        drift_every: int = 50,
    ) -> None:
        """Wire up collaborators, building a default for any left as ``None``.

        Args:
            task_sampler: Source of tasks; defaults to ``TaskSampler()``.
            teacher: Curriculum Teacher; defaults to one over
                ``DEFAULT_CATEGORIES`` (or ``["api_drift"]`` if that is empty).
            executor: Script runner; defaults to ``SimulationExecutor(seed=seed)``.
            drift_engine: Library-version drift engine; defaults to
                ``LibraryDriftEngine()``.
            seed: Only forwarded to the default ``SimulationExecutor``; it has
                no effect when ``executor`` is supplied.
            drift_every: Passed as ``drift_every`` to
                ``drift_engine.maybe_drift`` on every ``reset()``. The default
                of 50 keeps the previous behavior.
        """
        super().__init__()
        # Explicit ``is None`` checks rather than ``or``: a caller-supplied
        # collaborator must not be silently replaced just because it is falsy
        # (e.g. an object whose ``__len__`` currently returns 0).
        self.task_sampler = (
            task_sampler if task_sampler is not None else TaskSampler()
        )
        self.teacher = (
            teacher
            if teacher is not None
            else Teacher(categories=list(DEFAULT_CATEGORIES) or ["api_drift"])
        )
        self.executor = (
            executor if executor is not None else SimulationExecutor(seed=seed)
        )
        self.drift_engine = (
            drift_engine if drift_engine is not None else LibraryDriftEngine()
        )
        self.drift_every = drift_every

        self._episode_id: Optional[str] = None
        # Number of reset() calls so far; also fed to the drift engine.
        self._episode_count: int = 0
        self._current_task: Optional[Task] = None
        self._original_script: str = ""
        self._broken_script: str = ""
        self._error_trace: str = ""
        # {"primitive_type": ..., "params": {...}} of the applied breakage.
        self._breakage_spec: Optional[dict[str, Any]] = None
        # Category chosen by the Teacher at reset().
        self._target_category: str = ""
        # One of "idle", "drift_gen", "repair", "done".
        self._current_phase: str = "idle"
        self._last_obs: Optional[ForgeObservation] = None

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        difficulty: Optional[str] = "easy",
        **kwargs: Any,
    ) -> ForgeObservation:
        """Start a new episode and return the ``"drift_gen"`` observation.

        Args:
            seed: Accepted for OpenEnv API compatibility; not used here.
            episode_id: Caller-supplied id; a random UUID4 string is used when
                it is omitted or empty.
            difficulty: Forwarded to ``TaskSampler.sample``.
            **kwargs: Accepted and ignored.

        Returns:
            An observation holding the unbroken task script, the Teacher's
            target category, the current library versions, and in ``info``
            whether a library drift was triggered plus the available
            primitive names.

        Raises:
            RuntimeError: If the task sampler returns ``None``.
        """
        self._episode_id = episode_id or str(uuid.uuid4())
        self._episode_count += 1
        self._target_category = self.teacher.select_next_category()

        task = self.task_sampler.sample(difficulty=difficulty)
        if task is None:
            raise RuntimeError("Task sampler returned no tasks (empty seed corpus?)")

        self._current_task = task
        self._original_script = task.script_content
        self._broken_script = ""
        self._error_trace = ""
        self._breakage_spec = None
        self._current_phase = "drift_gen"

        # Whether/when versions actually drift is decided by LibraryDriftEngine;
        # presumably roughly once every ``drift_every`` episodes — confirm there.
        drifted = self.drift_engine.maybe_drift(
            self._episode_count, drift_every=self.drift_every
        )

        obs = ForgeObservation(
            current_phase="drift_gen",
            task_id=task.task_id,
            task_description=task.description,
            target_category=self._target_category,
            script_content=self._original_script,
            error_trace=None,
            library_versions=self.drift_engine.current_versions(),
            episode_step=0,
            done=False,
            reward=0.0,
            info={
                "episode_id": self._episode_id,
                "episode_count": self._episode_count,
                "drift_triggered": drifted,
                "available_primitives": sorted(PRIMITIVE_REGISTRY),
            },
        )
        self._last_obs = obs
        return obs

    def step(
        self,
        action: ForgeAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> ForgeObservation:
        """Advance the episode by one phase.

        Args:
            action: Must carry ``breakage`` in the ``"drift_gen"`` phase and
                ``repair`` in the ``"repair"`` phase.
            timeout_s: Accepted for OpenEnv API compatibility; not used here.
            **kwargs: Accepted and ignored.

        Returns:
            The next observation. A missing sub-action, or a call outside those
            two phases (before ``reset()`` or after the episode is done),
            yields a terminal error observation.
        """
        phase = self._current_phase

        if phase == "drift_gen":
            if action.breakage is None:
                return self._error_obs("Expected BreakageAction in drift_gen phase")
            return self._handle_breakage(action.breakage)

        if phase == "repair":
            if action.repair is None:
                return self._error_obs("Expected RepairAction in repair phase")
            return self._handle_repair(action.repair)

        return self._error_obs(
            f"step() called in invalid phase {phase!r} — call reset() first"
        )

    @property
    def state(self) -> dict:
        """Snapshot of the episode, curriculum and drift state.

        ``breakage_spec`` is a shallow copy; its nested ``params`` dict is
        still shared with the environment.
        """
        return {
            "phase": self._current_phase,
            "episode_id": self._episode_id,
            "episode_count": self._episode_count,
            "task_id": self._current_task.task_id if self._current_task else None,
            "target_category": self._target_category,
            "library_versions": self.drift_engine.current_versions(),
            "teacher": self.teacher.get_state(),
            "drift_history": list(self.drift_engine.drift_history),
            "breakage_spec": dict(self._breakage_spec) if self._breakage_spec else None,
        }

    def _handle_breakage(self, breakage: BreakageAction) -> ForgeObservation:
        """Apply the proposed breakage, run the broken script, enter ``"repair"``.

        An unparseable spec (``ValueError`` from ``parse_breakage_spec``) or any
        exception raised while applying the primitive ends the episode with an
        error observation.
        """
        spec = {"primitive_type": breakage.primitive_type, "params": dict(breakage.params)}
        try:
            primitive = parse_breakage_spec(spec)
        except ValueError as exc:
            return self._error_obs(f"Invalid breakage spec: {exc}")

        try:
            self._broken_script = primitive.apply(self._original_script)
        except Exception as exc:  # primitive code is arbitrary; report, don't crash
            return self._error_obs(f"Primitive apply failed: {exc}")

        # Only record the spec once it has actually been applied.
        self._breakage_spec = spec

        result = self.executor.execute(self._broken_script, self._current_task)
        if result.exit_code != 0:
            self._error_trace = result.stderr or "non-zero exit code, no stderr"
        else:
            # The breakage did not make the script fail; say so explicitly
            # rather than handing the repairer an empty trace.
            self._error_trace = "Script ran without observable error"

        self._current_phase = "repair"

        obs = ForgeObservation(
            current_phase="repair",
            task_id=self._current_task.task_id,
            task_description=self._current_task.description,
            target_category=primitive.category,
            script_content=self._broken_script,
            error_trace=self._error_trace,
            library_versions=self.drift_engine.current_versions(),
            episode_step=1,
            done=False,
            reward=0.0,
            info={
                "episode_id": self._episode_id,
                "breakage_primitive": primitive.name,
                "breakage_description": primitive.description,
            },
        )
        self._last_obs = obs
        return obs

    def _handle_repair(self, repair: RepairAction) -> ForgeObservation:
        """Apply the repair diff, re-run the script, score it and end the episode.

        The returned observation carries the visible reward as ``reward``
        (plus its breakdown), the held-out scores, and execution details in
        ``info``. The Teacher is updated with ``exit_code == 0`` as success for
        the applied primitive type. An exception from ``apply_unified_diff``
        ends the episode with an error observation instead of escaping
        ``step()``.
        """
        diff = repair.unified_diff or ""
        try:
            repaired = apply_unified_diff(self._broken_script, diff)
        except Exception as exc:  # malformed agent diffs must not crash step()
            return self._error_obs(f"Repair diff apply failed: {exc}")

        # perf_counter is monotonic, so the duration cannot go negative or jump
        # if the system clock is adjusted mid-run (unlike time.time()).
        started = time.perf_counter()
        result = self.executor.execute(repaired, self._current_task)
        result.script_content = repaired
        wall_ms = int((time.perf_counter() - started) * 1000)

        visible_reward, visible_breakdown = compute_visible_reward(
            result, self._current_task
        )
        held_out = compute_held_out_scores(
            result, self._current_task, repair_diff=diff
        )

        success = result.exit_code == 0
        category = (
            self._breakage_spec.get("primitive_type", "unknown")
            if self._breakage_spec
            else "unknown"
        )
        self.teacher.update(category, success)

        self._current_phase = "done"

        stdout = result.stdout or ""  # guard like stderr; None must not crash
        obs = ForgeObservation(
            current_phase="done",
            task_id=self._current_task.task_id,
            task_description=self._current_task.description,
            target_category=category,
            script_content=repaired,
            error_trace=result.stderr or None,
            library_versions=self.drift_engine.current_versions(),
            episode_step=2,
            done=True,
            reward=visible_reward,
            reward_breakdown=visible_breakdown,
            held_out_breakdown=held_out,
            info={
                "episode_id": self._episode_id,
                "exit_code": result.exit_code,
                "wall_time_ms": wall_ms,
                "checkpoint_exists": result.checkpoint_exists,
                "stdout_tail": "\n".join(stdout.splitlines()[-5:]),
                # Copy so callers mutating the observation cannot alter env state.
                "breakage_spec": dict(self._breakage_spec) if self._breakage_spec else None,
                "teacher_state": self.teacher.get_state(),
            },
        )
        self._last_obs = obs
        return obs

    def _error_obs(self, message: str) -> ForgeObservation:
        """End the episode and return a ``done=True`` error observation.

        The phase is moved to ``"done"`` so ``state`` agrees with the returned
        observation. Any further ``step()`` is then rejected until ``reset()``.
        The Teacher is not updated.

        Args:
            message: Reported both as ``error_trace`` and as ``info["error"]``.
        """
        self._current_phase = "done"
        task = self._current_task
        obs = ForgeObservation(
            current_phase="done",
            task_id=task.task_id if task else "",
            task_description=task.description if task else "",
            target_category=self._target_category,
            script_content=self._broken_script or self._original_script,
            error_trace=message,
            library_versions=self.drift_engine.current_versions(),
            episode_step=2,
            done=True,
            reward=0.0,
            info={"error": message, "episode_id": self._episode_id},
        )
        self._last_obs = obs
        return obs
|
|
|