File size: 10,615 Bytes
b0fbec3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
"""ForgeEnvironment: the OpenEnv Environment subclass for ForgeEnv.



Episode flow (exactly 2 steps per episode):

  reset()             -> sample task, ask Teacher for category

  step(BreakageAction) -> Drift Generator's proposal is applied; broken

                          script is run, error trace captured.

  step(RepairAction)   -> Repair diff is applied; script is re-executed;

                          visible + held-out rewards computed; episode ends.

"""
from __future__ import annotations

import time
import uuid
from typing import Any, Optional

from openenv.core import Environment

from forgeenv.drift.library_drift_engine import LibraryDriftEngine
from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
from forgeenv.env.diff_utils import apply_unified_diff
from forgeenv.env.observations import ForgeObservation
from forgeenv.primitives.breakage_primitives import (
    PRIMITIVE_REGISTRY,
    parse_breakage_spec,
)
from forgeenv.roles.teacher import Teacher
from forgeenv.sandbox.simulation_mode import SimulationExecutor
from forgeenv.tasks.models import ExecutionResult, Task
from forgeenv.tasks.task_sampler import TaskSampler
from forgeenv.verifier.held_out_evaluator import compute_held_out_scores
from forgeenv.verifier.visible_verifier import compute_visible_reward

# Sorted keys of the breakage-primitive registry; ForgeEnvironment passes a
# copy of this list to the default Teacher as its curriculum categories.
DEFAULT_CATEGORIES = sorted(PRIMITIVE_REGISTRY)


class ForgeEnvironment(Environment[ForgeAction, ForgeObservation, dict]):
    """OpenEnv-compliant environment for HuggingFace ecosystem repair."""

    SUPPORTS_CONCURRENT_SESSIONS = False  # Teacher state is global per env

    def __init__(
        self,
        task_sampler: Optional[TaskSampler] = None,
        teacher: Optional[Teacher] = None,
        executor: Optional[SimulationExecutor] = None,
        drift_engine: Optional[LibraryDriftEngine] = None,
        seed: Optional[int] = None,
    ) -> None:
        """Wire up the environment's collaborators and clear per-episode state.

        Every collaborator can be injected; any argument left as ``None`` is
        replaced with a default instance.

        Args:
            task_sampler: Source of tasks for ``reset()``.
            teacher: Curriculum component that selects the target category
                and is updated after each repair attempt. The default one is
                built from ``DEFAULT_CATEGORIES`` (falling back to
                ``["api_drift"]`` when that list is empty).
            executor: Runs scripts and returns their execution result.
            drift_engine: Provides library versions and periodic drift.
            seed: Forwarded to the default ``SimulationExecutor`` only; it is
                not used when ``executor`` is supplied.
        """
        super().__init__()

        # Compare against None rather than relying on truthiness: an injected
        # collaborator that happens to be falsy (for example because it
        # defines __len__ or __bool__) must not be silently replaced by a
        # freshly constructed default.
        if task_sampler is None:
            task_sampler = TaskSampler()
        if teacher is None:
            teacher = Teacher(
                categories=list(DEFAULT_CATEGORIES) or ["api_drift"]
            )
        if executor is None:
            executor = SimulationExecutor(seed=seed)
        if drift_engine is None:
            drift_engine = LibraryDriftEngine()

        self.task_sampler = task_sampler
        self.teacher = teacher
        self.executor = executor
        self.drift_engine = drift_engine

        # Per-episode bookkeeping; populated by reset() and the step handlers.
        self._episode_id: Optional[str] = None
        self._episode_count: int = 0
        self._current_task: Optional[Task] = None
        self._original_script: str = ""
        self._broken_script: str = ""
        self._error_trace: str = ""
        self._breakage_spec: Optional[dict[str, Any]] = None
        self._target_category: str = ""
        # "idle" until the first reset(); then "drift_gen" -> "repair" -> "done".
        self._current_phase: str = "idle"
        self._last_obs: Optional[ForgeObservation] = None

    # ------------------------------------------------------------------ API
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        difficulty: Optional[str] = "easy",
        *,
        drift_every: int = 50,
        **kwargs: Any,
    ) -> ForgeObservation:
        """Start a new episode and return the initial ``drift_gen`` observation.

        Args:
            seed: Accepted for interface compatibility; not used by this
                method.
            episode_id: Identifier for the episode; a random UUID4 string is
                generated when omitted.
            difficulty: Passed through to the task sampler.
            drift_every: Episode interval handed to the drift engine's
                ``maybe_drift`` check. Defaults to 50.
            **kwargs: Ignored.

        Returns:
            An observation carrying the original script, the Teacher's target
            category, the current library versions, and ``done=False``.

        Raises:
            RuntimeError: If the task sampler returns no task.
        """
        self._episode_id = episode_id or str(uuid.uuid4())
        self._episode_count += 1
        self._target_category = self.teacher.select_next_category()

        task = self.task_sampler.sample(difficulty=difficulty)
        if task is None:
            # Do not leave the previous episode's phase active: otherwise a
            # later step() would keep operating on the stale task.
            self._current_phase = "idle"
            raise RuntimeError("Task sampler returned no tasks (empty seed corpus?)")

        self._current_task = task
        self._original_script = task.script_content
        self._broken_script = ""
        self._error_trace = ""
        self._breakage_spec = None
        self._current_phase = "drift_gen"

        # Give the drift engine a chance to drift library versions; the
        # interval is controlled by the caller through ``drift_every``.
        drifted = self.drift_engine.maybe_drift(
            self._episode_count, drift_every=drift_every
        )

        obs = ForgeObservation(
            current_phase="drift_gen",
            task_id=task.task_id,
            task_description=task.description,
            target_category=self._target_category,
            script_content=self._original_script,
            error_trace=None,
            library_versions=self.drift_engine.current_versions(),
            episode_step=0,
            done=False,
            reward=0.0,
            info={
                "episode_id": self._episode_id,
                "episode_count": self._episode_count,
                "drift_triggered": drifted,
                "available_primitives": sorted(PRIMITIVE_REGISTRY),
            },
        )
        self._last_obs = obs
        return obs

    def step(
        self,
        action: ForgeAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> ForgeObservation:
        """Route ``action`` to the handler that matches the current phase.

        In the ``drift_gen`` phase the action's ``breakage`` payload is
        handled; in the ``repair`` phase its ``repair`` payload is handled.
        A missing payload, or any other phase, produces an error observation
        instead of an exception.

        Args:
            action: Wrapper exposing optional ``breakage`` / ``repair`` parts.
            timeout_s: Not used by this method.
            **kwargs: Ignored.
        """
        phase = self._current_phase

        if phase == "drift_gen":
            breakage = action.breakage
            if breakage is not None:
                return self._handle_breakage(breakage)
            return self._error_obs("Expected BreakageAction in drift_gen phase")

        if phase == "repair":
            repair = action.repair
            if repair is not None:
                return self._handle_repair(repair)
            return self._error_obs("Expected RepairAction in repair phase")

        # Neither active phase: the episode has not started or already ended.
        return self._error_obs(
            f"step() called in invalid phase {phase!r} — call reset() first"
        )

    @property
    def state(self) -> dict:
        """Dictionary snapshot of the environment's current bookkeeping.

        ``drift_history`` is returned as a new list and ``breakage_spec`` as a
        shallow copy (or ``None`` when no breakage has been recorded).
        """
        task = self._current_task
        task_id = task.task_id if task else None

        versions = self.drift_engine.current_versions()
        teacher_state = self.teacher.get_state()
        history = list(self.drift_engine.drift_history)

        spec_copy = None
        if self._breakage_spec:
            spec_copy = dict(self._breakage_spec)

        return {
            "phase": self._current_phase,
            "episode_id": self._episode_id,
            "episode_count": self._episode_count,
            "task_id": task_id,
            "target_category": self._target_category,
            "library_versions": versions,
            "teacher": teacher_state,
            "drift_history": history,
            "breakage_spec": spec_copy,
        }

    # ---------------------------------------------------------------- helpers
    def _handle_breakage(self, breakage: BreakageAction) -> ForgeObservation:
        """Apply the proposed breakage, run the broken script, enter ``repair``.

        Args:
            breakage: Proposal providing ``primitive_type`` and ``params``.

        Returns:
            A ``repair``-phase observation with the broken script and its
            error trace, or an error observation when the spec is rejected or
            the primitive fails while being applied.
        """
        spec = {
            "primitive_type": breakage.primitive_type,
            "params": dict(breakage.params),
        }

        try:
            primitive = parse_breakage_spec(spec)
        except ValueError as exc:
            return self._error_obs(f"Invalid breakage spec: {exc}")

        try:
            broken = primitive.apply(self._original_script)
        except Exception as exc:
            # A buggy primitive is reported to the caller, not allowed to
            # take the server down.
            return self._error_obs(f"Primitive apply failed: {exc}")

        self._broken_script = broken
        self._breakage_spec = spec

        outcome = self.executor.execute(self._broken_script, self._current_task)
        if outcome.exit_code == 0:
            # The breakage had no observable effect. The episode still moves
            # on to the repair phase, where a no-op repair is a valid answer.
            trace = "Script ran without observable error"
        else:
            trace = outcome.stderr or "non-zero exit code, no stderr"
        self._error_trace = trace

        self._current_phase = "repair"

        info = {
            "episode_id": self._episode_id,
            "breakage_primitive": primitive.name,
            "breakage_description": primitive.description,
        }
        obs = ForgeObservation(
            current_phase="repair",
            task_id=self._current_task.task_id,
            task_description=self._current_task.description,
            target_category=primitive.category,
            script_content=self._broken_script,
            error_trace=self._error_trace,
            library_versions=self.drift_engine.current_versions(),
            episode_step=1,
            done=False,
            reward=0.0,
            info=info,
        )
        self._last_obs = obs
        return obs

    def _handle_repair(self, repair: RepairAction) -> ForgeObservation:
        """Apply the repair diff, re-run the script, score it, end the episode.

        Args:
            repair: Repair proposal; a missing ``unified_diff`` is treated as
                an empty diff.

        Returns:
            A terminal (``done=True``) observation whose ``reward`` is the
            visible reward, with the visible and held-out breakdowns attached.
        """
        diff_text = repair.unified_diff or ""
        # NOTE(review): apply_unified_diff is not guarded here; presumably it
        # tolerates malformed diffs — confirm against diff_utils.
        repaired = apply_unified_diff(self._broken_script, diff_text)

        # perf_counter() is monotonic and high-resolution, so the measured
        # duration cannot be skewed (or go negative) by wall-clock changes.
        t0 = time.perf_counter()
        result = self.executor.execute(repaired, self._current_task)
        wall_ms = int((time.perf_counter() - t0) * 1000)
        result.script_content = repaired  # ensure verifier sees what we ran

        visible_reward, visible_breakdown = compute_visible_reward(
            result, self._current_task
        )
        held_out = compute_held_out_scores(
            result, self._current_task, repair_diff=diff_text
        )

        success = result.exit_code == 0
        category = (
            self._breakage_spec.get("primitive_type", "unknown")
            if self._breakage_spec
            else "unknown"
        )
        # Update Teacher's curriculum state with the outcome for the breakage
        # type that was actually applied.
        self.teacher.update(category, success)

        self._current_phase = "done"

        # stdout may be empty or missing; treat both as "no output" instead of
        # failing while building the observation.
        stdout_tail = "\n".join((result.stdout or "").splitlines()[-5:])

        obs = ForgeObservation(
            current_phase="done",
            task_id=self._current_task.task_id,
            task_description=self._current_task.description,
            target_category=category,
            script_content=repaired,
            error_trace=result.stderr or None,
            library_versions=self.drift_engine.current_versions(),
            episode_step=2,
            done=True,
            reward=visible_reward,
            reward_breakdown=visible_breakdown,
            held_out_breakdown=held_out,
            info={
                "episode_id": self._episode_id,
                "exit_code": result.exit_code,
                "wall_time_ms": wall_ms,
                "checkpoint_exists": result.checkpoint_exists,
                "stdout_tail": stdout_tail,
                "breakage_spec": self._breakage_spec,
                "teacher_state": self.teacher.get_state(),
            },
        )
        self._last_obs = obs
        return obs

    def _error_obs(self, message: str) -> ForgeObservation:
        """Return a `done=True` error observation rather than raising.

        The observation reports phase ``"done"`` with zero reward and carries
        ``message`` both as the error trace and under ``info["error"]``. The
        script shown is the broken script when one exists, otherwise the
        original. This helper does not modify the environment's own phase or
        its last-observation record.
        """
        task = self._current_task
        if task:
            task_id, description = task.task_id, task.description
        else:
            task_id, description = "", ""

        script = self._broken_script or self._original_script
        info = {"error": message, "episode_id": self._episode_id}

        return ForgeObservation(
            current_phase="done",
            task_id=task_id,
            task_description=description,
            target_category=self._target_category,
            script_content=script,
            error_trace=message,
            library_versions=self.drift_engine.current_versions(),
            episode_step=2,
            done=True,
            reward=0.0,
            info=info,
        )