File size: 6,153 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""Rollout function: connects an LLM to ForgeEnvironment for a full episode.

This is the function the GRPO trainer calls to convert a prompt into a
trajectory + reward. It runs both phases of an episode (drift + repair) by
asking the policy twice with role-switched prompts.
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, Optional

from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.roles.drift_generator import (
    BaselineDriftGenerator,
    parse_drift_output,
)
from forgeenv.roles.prompts import (
    DRIFT_GENERATOR_SYSTEM_PROMPT,
    REPAIR_AGENT_SYSTEM_PROMPT,
    render_drift_generator_prompt,
    render_repair_agent_prompt,
)
from forgeenv.roles.repair_agent import BaselineRepairAgent, extract_diff


# Generation function signature: called as ``fn(system_prompt, user_prompt)``
# and returns the assistant's completion as a string. Kept abstract so that
# different backends (e.g. TRL's batched generator, vLLM, or the deterministic
# baselines in this module) can be plugged in behind the same interface.
GenerateFn = Callable[[str, str], str]


@dataclass
class RolloutResult:
    """Everything captured from one drift + repair episode.

    Instances are built by ``rollout_one_episode``; the per-field comments
    below describe how that function populates each value.
    """

    # Task identifier taken from the observation returned by ``env.reset``.
    task_id: str
    # Breakage primitive named by the parsed drift spec ("RenameApiCall" when
    # the spec does not name one).
    primitive_type: str
    # Rendered user prompt sent to the Drift Generator.
    drift_prompt: str
    # Raw, unparsed completion returned by the drift generate function.
    drift_completion: str
    # Rendered user prompt sent to the Repair Agent.
    repair_prompt: str
    # Raw completion returned by the repair generate function (before diff
    # extraction).
    repair_completion: str
    # Reward reported on the observation after the repair step (0.0 if unset).
    visible_reward: float
    # Copy of ``reward_breakdown`` from the observation after the repair step.
    visible_breakdown: dict[str, float]
    # Copy of ``held_out_breakdown`` from the observation after the repair step.
    held_out_breakdown: dict[str, float]
    # True when the held-out "executed_cleanly" score is above 0.5.
    success: bool
    # Error trace from the observation after the repair step ("" if none).
    error_trace: str = ""
    # Copy of the ``info`` mapping from the observation after the repair step.
    info: dict[str, Any] = field(default_factory=dict)


def _baseline_drift_generate(env: ForgeEnvironment) -> GenerateFn:
    """Adapt the deterministic ``BaselineDriftGenerator`` to a ``GenerateFn``.

    The returned callable ignores the system prompt, recovers the target
    category and the script from the user prompt text, asks the baseline
    generator (seeded with 0) for a breakage spec, and returns that spec
    serialized as JSON.

    ``env`` is accepted but not consulted by this wrapper.
    """
    import json

    proposer = BaselineDriftGenerator(seed=0)
    fence_open = "```python"

    def _category_from(prompt: str) -> str:
        # First line beginning with "target category:" (case-insensitive)
        # wins; fall back to "RenameApiCall" when no such line exists.
        candidates = (
            row.split(":", 1)[1].strip()
            for row in prompt.splitlines()
            if row.lower().startswith("target category:")
        )
        return next(candidates, "RenameApiCall")

    def _script_from(prompt: str) -> str:
        # Pull out the first fenced python block so the generator can pick a
        # primitive that actually mutates the script; empty if no fence.
        if fence_open not in prompt:
            return ""
        after_fence = prompt.split(fence_open, 1)[1]
        return after_fence.split("```", 1)[0]

    def fn(system: str, user: str) -> str:
        category = _category_from(user)
        script_body = _script_from(user)
        proposal = proposer.propose(target_category=category, script=script_body)
        return json.dumps(proposal)

    return fn


def _baseline_repair_generate() -> GenerateFn:
    """Wrap our deterministic Repair Agent into a GenerateFn.

    The baseline cheats by recovering the original script from the user
    prompt is impossible (we don't pass it). Instead, when called as a
    baseline it just returns an empty diff. Use BaselineDriftGenerator-paired
    tests (which read env.state) when you want the oracle path.
    """

    def fn(system: str, user: str) -> str:
        return ""  # baseline = no-op (intentional negative baseline)

    return fn


def rollout_one_episode(
    env: ForgeEnvironment,
    drift_generate: Optional[GenerateFn] = None,
    repair_generate: Optional[GenerateFn] = None,
    difficulty: str = "easy",
) -> RolloutResult:
    """Run a single 2-step episode end-to-end and capture all signals.

    The sequence is: ``env.reset`` -> Drift Generator completion -> breakage
    step -> Repair Agent completion -> repair step.

    Args:
        env: Environment to drive. It is reset at the start of the call.
        drift_generate: Produces the Drift Generator completion from a
            (system, user) prompt pair. ``None`` selects the deterministic
            baseline drift generator.
        repair_generate: Produces the Repair Agent completion from a
            (system, user) prompt pair. ``None`` selects the no-op baseline,
            which yields an empty diff.
        difficulty: Forwarded to ``env.reset``.

    Returns:
        A ``RolloutResult`` holding both prompts, both raw completions, and
        the reward signals read from the observation returned by the repair
        step.

    Raises:
        AssertionError: If the environment is not in the ``"drift_gen"`` phase
            immediately after reset.
    """
    # Compare against None rather than relying on truthiness so that a valid
    # callable object that happens to be falsy is not silently replaced.
    if drift_generate is None:
        drift_generate = _baseline_drift_generate(env)
    if repair_generate is None:
        repair_generate = _baseline_repair_generate()

    obs = env.reset(difficulty=difficulty)
    # Raised explicitly (instead of a bare ``assert``) so the check still runs
    # under ``python -O``; the exception type is kept for existing callers.
    if obs.current_phase != "drift_gen":
        raise AssertionError(
            f"expected phase 'drift_gen' after reset, got {obs.current_phase!r}"
        )

    # ---------- Phase 1: Drift Generator ----------
    drift_prompt = render_drift_generator_prompt(
        script=obs.script_content,
        target_category=obs.target_category,
        library_versions=obs.library_versions,
    )
    drift_raw = drift_generate(DRIFT_GENERATOR_SYSTEM_PROMPT, drift_prompt)
    spec = parse_drift_output(drift_raw)
    if not spec:
        # Unparseable / empty completion: fall back to a default breakage.
        spec = {"primitive_type": "RenameApiCall", "params": {}}

    # Resolve once so the action sent to the env and the value recorded in
    # the result cannot diverge.
    primitive_type = spec.get("primitive_type", "RenameApiCall")
    breakage_action = ForgeAction(
        breakage=BreakageAction(
            primitive_type=primitive_type,
            params=spec.get("params", {}) or {},
        )
    )
    obs2 = env.step(breakage_action)

    # ---------- Phase 2: Repair Agent ----------
    repair_prompt = render_repair_agent_prompt(
        broken_script=obs2.script_content,
        error_trace=obs2.error_trace or "",
        library_versions=obs2.library_versions,
        target_category=obs2.target_category,
    )
    repair_raw = repair_generate(REPAIR_AGENT_SYSTEM_PROMPT, repair_prompt)
    # An empty completion is submitted as an empty diff.
    diff = extract_diff(repair_raw) if repair_raw else ""

    repair_action = ForgeAction(repair=RepairAction(unified_diff=diff))
    obs3 = env.step(repair_action)

    # ``reward`` and ``error_trace`` are already treated as possibly unset;
    # give the mapping fields the same tolerance instead of raising on None.
    visible_breakdown = dict(obs3.reward_breakdown or {})
    held_out_breakdown = dict(obs3.held_out_breakdown or {})
    executed_cleanly = held_out_breakdown.get("executed_cleanly") or 0.0

    return RolloutResult(
        task_id=obs.task_id,
        primitive_type=primitive_type,
        drift_prompt=drift_prompt,
        drift_completion=drift_raw,
        repair_prompt=repair_prompt,
        repair_completion=repair_raw,
        visible_reward=float(obs3.reward or 0.0),
        visible_breakdown=visible_breakdown,
        held_out_breakdown=held_out_breakdown,
        success=bool(executed_cleanly > 0.5),
        error_trace=obs3.error_trace or "",
        info=dict(obs3.info or {}),
    )


def baseline_oracle_repair_generate(env: ForgeEnvironment) -> GenerateFn:
    """Build an "oracle" repair ``GenerateFn`` bound to ``env``.

    Instead of reading the prompts, the returned callable looks up the
    current task's original script through ``env.state`` and the task
    sampler, then delegates to ``BaselineRepairAgent.repair`` together with
    the env's broken script and breakage spec. Intended for sanity-checking
    the end-to-end loop and as the upper-bound baseline in plots.

    The callable returns ``""`` when the env state has no ``task_id`` or the
    sampler does not know that task.
    """

    oracle = BaselineRepairAgent()

    def fn(system: str, user: str) -> str:
        # Resolve the task behind the running episode; either lookup may
        # come back empty, in which case there is nothing to repair against.
        episode_task_id = env.state.get("task_id")
        source_task = (
            None
            if episode_task_id is None
            else env.task_sampler.get_by_id(episode_task_id)
        )
        if source_task is None:
            return ""

        # Private env attributes are read on purpose: this path is
        # oracle-only. ``_broken_script`` is the script the user sees.
        return oracle.repair(
            env._broken_script,  # noqa: SLF001 — internal but oracle-only
            breakage_spec=env._breakage_spec,  # noqa: SLF001
            original_script=source_task.script_content,
        )

    return fn