File size: 6,811 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""Warm-start pair generator for both roles.

Produces two parallel JSONL files:

  warmstart/data/repair_pairs.jsonl    -- (prompt, completion) for Repair Agent SFT
  warmstart/data/drift_pairs.jsonl     -- (prompt, completion) for Drift Generator SFT

Each row has the canonical chat-template fields plus provenance metadata:
  {"role_target": ..., "task_id": ..., "primitive_type": ..., "category": ...,
   "messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]}

We generate at least 50 pairs (the hackathon brief requires this minimum) by
combining each seed-corpus script with each applicable breakage primitive
configuration. The Drift Generator SFT teaches it to emit clean JSON; the
Repair Agent SFT teaches it to emit canonical unified diffs.

Usage:
  python warmstart/generate_pairs.py [--target_count 64] [--out_dir warmstart/data]
"""
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Iterable, Optional

from forgeenv.env.diff_utils import make_unified_diff
from forgeenv.primitives.breakage_primitives import (
    PRIMITIVE_REGISTRY,
    parse_breakage_spec,
)
from forgeenv.roles.drift_generator import _DEFAULT_PARAMS_BY_TYPE
from forgeenv.roles.prompts import (
    DRIFT_GENERATOR_SYSTEM_PROMPT,
    REPAIR_AGENT_SYSTEM_PROMPT,
    render_drift_generator_prompt,
    render_repair_agent_prompt,
)
from forgeenv.roles.repair_agent import BaselineRepairAgent
from forgeenv.tasks.task_sampler import TaskSampler


def _candidate_breakages(script: str) -> list[dict]:
    """Return the default breakage specs that actually change `script`.

    Every ``(primitive_type, params)`` combination listed in
    ``_DEFAULT_PARAMS_BY_TYPE`` is turned into a spec dict and parsed. Specs
    that fail to parse with ``ValueError`` are dropped, as are specs whose
    primitive returns the script unchanged. The surviving specs are returned
    in table order.
    """
    effective_specs: list[dict] = []
    for primitive_type, param_variants in _DEFAULT_PARAMS_BY_TYPE.items():
        for variant in param_variants:
            # Copy the params so the returned spec never aliases the table.
            candidate = {"primitive_type": primitive_type, "params": dict(variant)}
            try:
                breakage = parse_breakage_spec(candidate)
            except ValueError:
                # Not a parseable configuration; nothing to apply.
                continue
            else:
                if breakage.apply(script) != script:
                    effective_specs.append(candidate)
    return effective_specs


def _render_pairs_for_task(
    task_id: str,
    script: str,
    library_versions: dict,
    repair_agent: BaselineRepairAgent,
) -> list[dict]:
    """Build the warm-start rows for a single seed script.

    For each spec returned by ``_candidate_breakages(script)`` the breakage is
    applied and ``repair_agent`` is asked for a repair. When the script really
    changed and the agent returned a non-empty result, two rows are appended,
    in this order:

    1. a ``"repair_agent"`` row: broken script + simulated error -> repair;
    2. a ``"drift_generator"`` row: working script -> JSON breakage spec.

    Specs that leave the script unchanged, or for which the agent returns a
    falsy repair, contribute no rows.
    """
    rows: list[dict] = []

    def chat_row(
        target: str, breakage: dict, category, system: str, user: str, answer: str
    ) -> dict:
        # Key order is preserved on purpose: it is the order json.dumps emits.
        return {
            "role_target": target,
            "task_id": task_id,
            "primitive_type": breakage["primitive_type"],
            "category": category,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user},
                {"role": "assistant", "content": answer},
            ],
        }

    for breakage_spec in _candidate_breakages(script):
        primitive = parse_breakage_spec(breakage_spec)
        broken_script = primitive.apply(script)
        if broken_script == script:
            # Defensive: the candidate list should already exclude no-ops.
            continue

        patch = repair_agent.repair(
            broken_script, breakage_spec=breakage_spec, original_script=script
        )
        if not patch:
            continue

        # Row 1: teach the Repair Agent to answer with the baseline repair.
        repair_prompt = render_repair_agent_prompt(
            broken_script=broken_script,
            error_trace=f"[simulated] {primitive.description}",
            library_versions=library_versions,
            target_category=primitive.category,
        )
        rows.append(
            chat_row(
                "repair_agent",
                breakage_spec,
                primitive.category,
                REPAIR_AGENT_SYSTEM_PROMPT,
                repair_prompt,
                patch,
            )
        )

        # Row 2: teach the Drift Generator to emit the spec as JSON.
        # NOTE(review): target_category receives the primitive *type* here,
        # whereas the repair prompt above receives primitive.category --
        # confirm this asymmetry is intended.
        drift_prompt = render_drift_generator_prompt(
            script=script,
            target_category=breakage_spec["primitive_type"],
            library_versions=library_versions,
        )
        rows.append(
            chat_row(
                "drift_generator",
                breakage_spec,
                primitive.category,
                DRIFT_GENERATOR_SYSTEM_PROMPT,
                drift_prompt,
                json.dumps(breakage_spec, indent=2),
            )
        )
    return rows


def _pad_by_cycling(rows: list[dict], target_count: int) -> None:
    """Extend ``rows`` in place to ``target_count`` by cycling the originals.

    Rows are repeated in their original order (0, 1, 2, ..., 0, 1, ...), so
    the padding is spread evenly over the distinct rows instead of repeating
    a single one. An empty list is left empty: there is nothing to repeat.
    A list that already has ``target_count`` or more rows is left unchanged.
    """
    distinct = len(rows)
    if distinct == 0:
        return
    # ``distinct`` is frozen before the list grows: indexing modulo the
    # *current* length would always evaluate to 0 and duplicate only row 0.
    while len(rows) < target_count:
        rows.append(rows[len(rows) % distinct])


def _write_jsonl(path: Path, rows: list[dict]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 JSON Lines, one object per line."""
    with path.open("w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(row) + "\n" for row in rows)


def generate_pairs(
    target_count: int = 64, out_dir: Optional[Path] = None
) -> dict[str, int]:
    """Generate the warm-start SFT files for both roles and a summary.

    Every task exposed by ``TaskSampler().tasks`` is expanded with
    ``_render_pairs_for_task`` and the resulting rows are split by their
    ``"role_target"`` into repair rows and drift rows. Each list that is
    non-empty but shorter than ``target_count`` is padded by cycling through
    its own rows; lists are never truncated, so ``target_count`` is a minimum.

    Three files are written into ``out_dir`` (created if missing):
    ``repair_pairs.jsonl``, ``drift_pairs.jsonl`` and ``summary.json``.

    Args:
        target_count: Minimum number of rows wanted per role.
        out_dir: Output directory; defaults to ``warmstart/data``.

    Returns:
        ``{"repair_pairs": <row count>, "drift_pairs": <row count>}``.
    """
    out_dir = Path(out_dir) if out_dir is not None else Path("warmstart/data")
    out_dir.mkdir(parents=True, exist_ok=True)

    # Library versions rendered into every prompt.
    library_versions = {"transformers": "4.40.0", "datasets": "2.18.0", "trl": "0.10.0"}
    sampler = TaskSampler()
    repair_agent = BaselineRepairAgent()

    repair_pairs: list[dict] = []
    drift_pairs: list[dict] = []
    for task in sampler.tasks:
        for pair in _render_pairs_for_task(
            task_id=task.task_id,
            script=task.script_content,
            library_versions=library_versions,
            repair_agent=repair_agent,
        ):
            if pair["role_target"] == "repair_agent":
                repair_pairs.append(pair)
            else:
                drift_pairs.append(pair)

    # Safety net: if the corpus yields fewer rows than requested, pad each
    # role by cycling through its existing rows.
    _pad_by_cycling(repair_pairs, target_count)
    _pad_by_cycling(drift_pairs, target_count)

    _write_jsonl(out_dir / "repair_pairs.jsonl", repair_pairs)
    _write_jsonl(out_dir / "drift_pairs.jsonl", drift_pairs)

    counts = {"repair_pairs": len(repair_pairs), "drift_pairs": len(drift_pairs)}
    summary = {
        **counts,
        "target_count": target_count,
        # Coverage is reported from the repair rows only.
        "primitives_covered": sorted({p["primitive_type"] for p in repair_pairs}),
        "tasks_covered": sorted({p["task_id"] for p in repair_pairs}),
    }
    (out_dir / "summary.json").write_text(
        json.dumps(summary, indent=2), encoding="utf-8"
    )
    return counts


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--target_count", type=int, default=64)
    parser.add_argument("--out_dir", type=str, default="warmstart/data")
    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: parse the CLI options, generate the files, then
    # print the per-role row counts as JSON.
    cli_args = _parse_args()
    written = generate_pairs(
        target_count=cli_args.target_count,
        out_dir=Path(cli_args.out_dir),
    )
    print(json.dumps(written, indent=2))