# forgeenv-source / forgeenv/env/forge_environment.py
# akhiilll's picture
# forgeenv source snapshot for training job
# a15535e verified
"""ForgeEnvironment: the OpenEnv Environment subclass for ForgeEnv.

Episode flow (exactly 2 steps per episode):

    reset()              -> sample a task and ask the Teacher for a category.
    step(BreakageAction) -> the Drift Generator's proposal is applied; the
                            broken script is run and its error trace captured.
    step(RepairAction)   -> the repair diff is applied; the script is
                            re-executed; visible + held-out rewards are
                            computed; the episode ends.
"""
from __future__ import annotations
import time
import uuid
from typing import Any, Optional
from openenv.core import Environment
from forgeenv.drift.library_drift_engine import LibraryDriftEngine
from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
from forgeenv.env.diff_utils import apply_unified_diff
from forgeenv.env.observations import ForgeObservation
from forgeenv.primitives.breakage_primitives import (
PRIMITIVE_REGISTRY,
parse_breakage_spec,
)
from forgeenv.roles.teacher import Teacher
from forgeenv.sandbox.simulation_mode import SimulationExecutor
from forgeenv.tasks.models import ExecutionResult, Task
from forgeenv.tasks.task_sampler import TaskSampler
from forgeenv.verifier.held_out_evaluator import compute_held_out_scores
from forgeenv.verifier.visible_verifier import compute_visible_reward
# Alphabetically ordered names of every registered breakage primitive; used as
# the default category list handed to the Teacher.
DEFAULT_CATEGORIES = sorted(PRIMITIVE_REGISTRY)
class ForgeEnvironment(Environment[ForgeAction, ForgeObservation, dict]):
    """OpenEnv-compliant environment for HuggingFace ecosystem repair.

    An episode consists of exactly two ``step()`` calls after ``reset()``:

    1. ``drift_gen`` phase: a breakage is applied to the sampled task's
       script and the broken script is executed to capture an error trace.
    2. ``repair`` phase: a unified diff is applied to the broken script, the
       result is executed, rewards are computed and the episode ends.

    Invalid actions do not raise; they end the episode with a zero-reward
    error observation (see ``_error_obs``).
    """

    SUPPORTS_CONCURRENT_SESSIONS = False  # Teacher state is global per env

    def __init__(
        self,
        task_sampler: Optional[TaskSampler] = None,
        teacher: Optional[Teacher] = None,
        executor: Optional[SimulationExecutor] = None,
        drift_engine: Optional[LibraryDriftEngine] = None,
        seed: Optional[int] = None,
        *,
        drift_every: int = 50,
    ) -> None:
        """Build the environment, creating default collaborators as needed.

        Args:
            task_sampler: Source of tasks; defaults to ``TaskSampler()``.
            teacher: Curriculum teacher; defaults to a ``Teacher`` over
                ``DEFAULT_CATEGORIES`` (or ``["api_drift"]`` if that is empty).
            executor: Script executor; defaults to ``SimulationExecutor``.
            drift_engine: Library drift engine; defaults to
                ``LibraryDriftEngine()``.
            seed: Passed to the default ``SimulationExecutor`` only; it is
                not used when ``executor`` is supplied.
            drift_every: Episode interval forwarded to
                ``drift_engine.maybe_drift`` on every ``reset()``. Also
                adjustable afterwards through the ``drift_every`` attribute.
        """
        super().__init__()
        self.task_sampler = task_sampler or TaskSampler()
        self.teacher = teacher or Teacher(
            categories=list(DEFAULT_CATEGORIES) or ["api_drift"]
        )
        self.executor = executor or SimulationExecutor(seed=seed)
        self.drift_engine = drift_engine or LibraryDriftEngine()
        self.drift_every: int = drift_every

        # Per-episode bookkeeping; overwritten by reset().
        self._episode_id: Optional[str] = None
        self._episode_count: int = 0
        self._current_task: Optional[Task] = None
        self._original_script: str = ""
        self._broken_script: str = ""
        self._error_trace: str = ""
        self._breakage_spec: Optional[dict[str, Any]] = None
        self._target_category: str = ""
        # One of "idle", "drift_gen", "repair", "done".
        self._current_phase: str = "idle"
        self._last_obs: Optional[ForgeObservation] = None

    # ------------------------------------------------------------------ API
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        difficulty: Optional[str] = "easy",
        **kwargs: Any,
    ) -> ForgeObservation:
        """Start a new episode and return its first observation.

        Args:
            seed: Accepted for interface compatibility; not used here.
            episode_id: Identifier for the episode; a fresh UUID4 string is
                generated when omitted.
            difficulty: Forwarded to ``task_sampler.sample``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``drift_gen``-phase observation carrying the original script.

        Raises:
            RuntimeError: If the task sampler returns ``None``.
        """
        self._episode_id = episode_id or str(uuid.uuid4())
        self._episode_count += 1
        self._target_category = self.teacher.select_next_category()

        task = self.task_sampler.sample(difficulty=difficulty)
        if task is None:
            raise RuntimeError("Task sampler returned no tasks (empty seed corpus?)")

        self._current_task = task
        self._original_script = task.script_content
        self._broken_script = ""
        self._error_trace = ""
        self._breakage_spec = None
        self._current_phase = "drift_gen"

        # Give the drift engine a chance to drift library versions; the
        # interval is configurable through ``self.drift_every``.
        drifted = self.drift_engine.maybe_drift(
            self._episode_count, drift_every=self.drift_every
        )

        obs = ForgeObservation(
            current_phase="drift_gen",
            task_id=task.task_id,
            task_description=task.description,
            target_category=self._target_category,
            script_content=self._original_script,
            error_trace=None,
            library_versions=self.drift_engine.current_versions(),
            episode_step=0,
            done=False,
            reward=0.0,
            info={
                "episode_id": self._episode_id,
                "episode_count": self._episode_count,
                "drift_triggered": drifted,
                "available_primitives": sorted(PRIMITIVE_REGISTRY),
            },
        )
        self._last_obs = obs
        return obs

    def step(
        self,
        action: ForgeAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> ForgeObservation:
        """Advance the episode by one step, dispatching on the current phase.

        Args:
            action: Wrapper whose ``breakage`` field is required in the
                ``drift_gen`` phase and whose ``repair`` field is required in
                the ``repair`` phase.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The next observation. A wrong action type, or a call outside the
            ``drift_gen``/``repair`` phases, yields a terminal error
            observation instead of raising.
        """
        if self._current_phase == "drift_gen":
            if action.breakage is None:
                return self._error_obs("Expected BreakageAction in drift_gen phase")
            return self._handle_breakage(action.breakage)

        if self._current_phase == "repair":
            if action.repair is None:
                return self._error_obs("Expected RepairAction in repair phase")
            return self._handle_repair(action.repair)

        return self._error_obs(
            f"step() called in invalid phase {self._current_phase!r} — call reset() first"
        )

    @property
    def state(self) -> dict:
        """Return a snapshot of the environment's internal state as a dict."""
        return {
            "phase": self._current_phase,
            "episode_id": self._episode_id,
            "episode_count": self._episode_count,
            "task_id": self._current_task.task_id if self._current_task else None,
            "target_category": self._target_category,
            "library_versions": self.drift_engine.current_versions(),
            "teacher": self.teacher.get_state(),
            "drift_history": list(self.drift_engine.drift_history),
            # Shallow copy so callers cannot rebind keys of the stored spec.
            "breakage_spec": dict(self._breakage_spec) if self._breakage_spec else None,
        }

    # ---------------------------------------------------------------- helpers
    def _handle_breakage(self, breakage: BreakageAction) -> ForgeObservation:
        """Apply the proposed breakage, run the broken script, enter ``repair``.

        Returns a terminal error observation if the spec cannot be parsed or
        the primitive fails while being applied.
        """
        # ``or {}`` tolerates an action whose params were left as None.
        spec = {
            "primitive_type": breakage.primitive_type,
            "params": dict(breakage.params or {}),
        }
        try:
            primitive = parse_breakage_spec(spec)
        except ValueError as exc:
            return self._error_obs(f"Invalid breakage spec: {exc}")

        try:
            self._broken_script = primitive.apply(self._original_script)
        except Exception as exc:  # primitive bug — surface but don't crash server
            return self._error_obs(f"Primitive apply failed: {exc}")

        self._breakage_spec = spec
        result = self.executor.execute(self._broken_script, self._current_task)
        if result.exit_code != 0:
            self._error_trace = result.stderr or "non-zero exit code, no stderr"
        else:
            # The breakage didn't actually break it; still proceed to repair phase
            # (no-op repair is then a valid choice).
            self._error_trace = "Script ran without observable error"

        self._current_phase = "repair"
        obs = ForgeObservation(
            current_phase="repair",
            task_id=self._current_task.task_id,
            task_description=self._current_task.description,
            target_category=primitive.category,
            script_content=self._broken_script,
            error_trace=self._error_trace,
            library_versions=self.drift_engine.current_versions(),
            episode_step=1,
            done=False,
            reward=0.0,
            info={
                "episode_id": self._episode_id,
                "breakage_primitive": primitive.name,
                "breakage_description": primitive.description,
            },
        )
        self._last_obs = obs
        return obs

    def _handle_repair(self, repair: RepairAction) -> ForgeObservation:
        """Apply the repair diff, re-run the script, score it, end the episode.

        The visible reward becomes the observation's ``reward``; held-out
        scores are reported separately. The Teacher is updated with whether
        the repaired script exited with code 0.
        """
        diff_text = repair.unified_diff or ""
        # NOTE(review): apply_unified_diff is not guarded here; if it can raise
        # on a malformed diff the exception propagates to the caller — confirm.
        repaired = apply_unified_diff(self._broken_script, diff_text)

        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by system clock adjustments the way time.time() can.
        started = time.perf_counter()
        result = self.executor.execute(repaired, self._current_task)
        wall_ms = int((time.perf_counter() - started) * 1000)
        result.script_content = repaired  # ensure verifier sees what we ran

        visible_reward, visible_breakdown = compute_visible_reward(
            result, self._current_task
        )
        held_out = compute_held_out_scores(
            result, self._current_task, repair_diff=diff_text
        )

        success = result.exit_code == 0
        if self._breakage_spec:
            category = self._breakage_spec.get("primitive_type", "unknown")
        else:
            category = "unknown"

        # Update Teacher's curriculum state
        self.teacher.update(category, success)

        self._current_phase = "done"
        obs = ForgeObservation(
            current_phase="done",
            task_id=self._current_task.task_id,
            task_description=self._current_task.description,
            target_category=category,
            script_content=repaired,
            error_trace=result.stderr or None,
            library_versions=self.drift_engine.current_versions(),
            episode_step=2,
            done=True,
            reward=visible_reward,
            reward_breakdown=visible_breakdown,
            held_out_breakdown=held_out,
            info={
                "episode_id": self._episode_id,
                "exit_code": result.exit_code,
                "wall_time_ms": wall_ms,
                "checkpoint_exists": result.checkpoint_exists,
                # Last five stdout lines; tolerate a missing/None stdout.
                "stdout_tail": "\n".join((result.stdout or "").splitlines()[-5:]),
                "breakage_spec": self._breakage_spec,
                "teacher_state": self.teacher.get_state(),
            },
        )
        self._last_obs = obs
        return obs

    def _error_obs(self, message: str) -> ForgeObservation:
        """Return a `done=True` error observation rather than raising.

        The environment is also moved to the ``done`` phase so that ``state``
        agrees with the returned observation and further ``step()`` calls are
        rejected until ``reset()`` is called.
        """
        self._current_phase = "done"
        task = self._current_task
        obs = ForgeObservation(
            current_phase="done",
            task_id=task.task_id if task else "",
            task_description=task.description if task else "",
            target_category=self._target_category,
            # Show the broken script when one exists, else the original.
            script_content=self._broken_script or self._original_script,
            error_trace=message,
            library_versions=self.drift_engine.current_versions(),
            episode_step=2,
            done=True,
            reward=0.0,
            info={"error": message, "episode_id": self._episode_id},
        )
        self._last_obs = obs
        return obs