"""Warm-start pair generator for both roles.

Produces two parallel JSONL files:

    warmstart/data/repair_pairs.jsonl  -- (prompt, completion) for Repair Agent SFT
    warmstart/data/drift_pairs.jsonl   -- (prompt, completion) for Drift Generator SFT

Each row has the canonical chat-template fields plus bookkeeping metadata:

    {"role_target": ..., "task_id": ..., "primitive_type": ..., "category": ...,
     "messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]}

We generate at least 50 pairs (the hackathon brief requires this minimum) by
combining each seed-corpus script with each applicable breakage primitive
configuration. The Drift Generator SFT teaches it to emit clean JSON; the
Repair Agent SFT teaches it to emit canonical unified diffs.

Usage:
    python warmstart/generate_pairs.py [--target_count 64] [--out_dir warmstart/data]
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Iterable, Optional
from forgeenv.env.diff_utils import make_unified_diff
from forgeenv.primitives.breakage_primitives import (
PRIMITIVE_REGISTRY,
parse_breakage_spec,
)
from forgeenv.roles.drift_generator import _DEFAULT_PARAMS_BY_TYPE
from forgeenv.roles.prompts import (
DRIFT_GENERATOR_SYSTEM_PROMPT,
REPAIR_AGENT_SYSTEM_PROMPT,
render_drift_generator_prompt,
render_repair_agent_prompt,
)
from forgeenv.roles.repair_agent import BaselineRepairAgent
from forgeenv.tasks.task_sampler import TaskSampler
def _candidate_breakages(script: str) -> list[dict]:
    """Collect the default breakage specs that actually change `script`.

    Walks every (primitive type, default params) combination listed in
    ``_DEFAULT_PARAMS_BY_TYPE``. A combination is skipped when
    ``parse_breakage_spec`` rejects it with ``ValueError``, and it is kept
    only when applying the parsed primitive returns text that differs from
    `script`.

    Args:
        script: Source text of the working script to probe.

    Returns:
        A list of ``{"primitive_type": ..., "params": {...}}`` dicts, in the
        iteration order of ``_DEFAULT_PARAMS_BY_TYPE``.
    """
    effective_specs: list[dict] = []
    for primitive_type, default_param_sets in _DEFAULT_PARAMS_BY_TYPE.items():
        for default_params in default_param_sets:
            # Copy the params so the emitted spec never aliases the shared defaults.
            candidate = {
                "primitive_type": primitive_type,
                "params": dict(default_params),
            }
            try:
                breakage = parse_breakage_spec(candidate)
            except ValueError:
                # Spec rejected by the parser: not usable, move on.
                continue
            if breakage.apply(script) != script:
                effective_specs.append(candidate)
    return effective_specs
def _render_pairs_for_task(
    task_id: str,
    script: str,
    library_versions: dict,
    repair_agent: BaselineRepairAgent,
) -> list[dict]:
    """Build the SFT rows (repair + drift) for one seed script.

    For every candidate breakage of `script`, the script is broken, the
    baseline repair agent is asked for a diff, and two rows are emitted in
    this order: one targeting the Repair Agent and one targeting the Drift
    Generator. Breakages that leave the script unchanged, or for which the
    repair agent returns a falsy diff, produce no rows.

    Args:
        task_id: Identifier copied into each emitted row.
        script: Source text of the working script.
        library_versions: Version mapping forwarded to both prompt renderers.
        repair_agent: Agent whose ``repair`` result becomes the assistant
            message of the repair row.

    Returns:
        A flat list of row dicts, alternating repair row then drift row.
    """
    rendered: list[dict] = []

    for spec in _candidate_breakages(script):
        primitive = parse_breakage_spec(spec)
        broken_script = primitive.apply(script)
        if broken_script == script:
            continue

        fix_diff = repair_agent.repair(
            broken_script, breakage_spec=spec, original_script=script
        )
        if not fix_diff:
            continue

        primitive_type = spec["primitive_type"]

        # Row 1: Repair Agent sees the broken script and must emit the diff.
        repair_prompt = render_repair_agent_prompt(
            broken_script=broken_script,
            error_trace=f"[simulated] {primitive.description}",
            library_versions=library_versions,
            target_category=primitive.category,
        )
        repair_row = {
            "role_target": "repair_agent",
            "task_id": task_id,
            "primitive_type": primitive_type,
            "category": primitive.category,
            "messages": [
                {"role": "system", "content": REPAIR_AGENT_SYSTEM_PROMPT},
                {"role": "user", "content": repair_prompt},
                {"role": "assistant", "content": fix_diff},
            ],
        }

        # Row 2: Drift Generator sees the working script and must emit the
        # breakage spec as JSON.
        # NOTE(review): the drift prompt is given the primitive type as its
        # target_category, whereas the repair prompt above is given
        # primitive.category -- confirm this asymmetry is intended.
        drift_prompt = render_drift_generator_prompt(
            script=script,
            target_category=primitive_type,
            library_versions=library_versions,
        )
        drift_row = {
            "role_target": "drift_generator",
            "task_id": task_id,
            "primitive_type": primitive_type,
            "category": primitive.category,
            "messages": [
                {"role": "system", "content": DRIFT_GENERATOR_SYSTEM_PROMPT},
                {"role": "user", "content": drift_prompt},
                {"role": "assistant", "content": json.dumps(spec, indent=2)},
            ],
        }

        rendered += [repair_row, drift_row]

    return rendered
def generate_pairs(
    target_count: int = 64, out_dir: Optional[Path] = None
) -> dict[str, int]:
    """Generate the warm-start SFT rows and write them to `out_dir`.

    Renders rows for every task exposed by ``TaskSampler().tasks``, splits
    them by ``role_target``, pads each list up to `target_count` if the
    corpus fell short, and writes three files:

    * ``repair_pairs.jsonl`` -- one JSON row per line for the Repair Agent
    * ``drift_pairs.jsonl``  -- one JSON row per line for the Drift Generator
    * ``summary.json``       -- counts, target, and covered primitives/tasks

    Args:
        target_count: Minimum number of rows wanted in each JSONL file.
            Lists that already hold at least this many rows are left as-is.
        out_dir: Output directory, created if missing. Defaults to
            ``warmstart/data`` when ``None``.

    Returns:
        ``{"repair_pairs": <row count>, "drift_pairs": <row count>}``.
    """
    out_dir = Path(out_dir) if out_dir is not None else Path("warmstart/data")
    out_dir.mkdir(parents=True, exist_ok=True)

    library_versions = {"transformers": "4.40.0", "datasets": "2.18.0", "trl": "0.10.0"}
    sampler = TaskSampler()
    repair_agent = BaselineRepairAgent()

    repair_pairs: list[dict] = []
    drift_pairs: list[dict] = []
    for task in sampler.tasks:
        for pair in _render_pairs_for_task(
            task_id=task.task_id,
            script=task.script_content,
            library_versions=library_versions,
            repair_agent=repair_agent,
        ):
            if pair["role_target"] == "repair_agent":
                repair_pairs.append(pair)
            else:
                drift_pairs.append(pair)

    # Safety net: if the corpus produced fewer rows than requested, pad by
    # cycling through the original rows in order. The previous index,
    # len(pairs) % len(pairs), is always 0, so it duplicated only the first
    # row. An empty list is left empty -- there is nothing to duplicate.
    for pairs in (repair_pairs, drift_pairs):
        unique_count = len(pairs)
        if unique_count == 0:
            continue
        # range() of a non-positive shortfall is empty, so full lists are untouched.
        for offset in range(target_count - unique_count):
            pairs.append(pairs[offset % unique_count])

    repair_path = out_dir / "repair_pairs.jsonl"
    drift_path = out_dir / "drift_pairs.jsonl"
    with repair_path.open("w", encoding="utf-8") as f:
        for row in repair_pairs:
            f.write(json.dumps(row) + "\n")
    with drift_path.open("w", encoding="utf-8") as f:
        for row in drift_pairs:
            f.write(json.dumps(row) + "\n")

    counts = {"repair_pairs": len(repair_pairs), "drift_pairs": len(drift_pairs)}
    summary_path = out_dir / "summary.json"
    summary_path.write_text(
        json.dumps(
            {
                **counts,
                "target_count": target_count,
                "primitives_covered": sorted(
                    {p["primitive_type"] for p in repair_pairs}
                ),
                "tasks_covered": sorted({p["task_id"] for p in repair_pairs}),
            },
            indent=2,
        ),
        encoding="utf-8",
    )
    return counts
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--target_count", type=int, default=64)
parser.add_argument("--out_dir", type=str, default="warmstart/data")
return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: read the CLI flags, write both JSONL files plus the
    # summary, then echo the per-file row counts as indented JSON.
    cli = _parse_args()
    print(
        json.dumps(
            generate_pairs(target_count=cli.target_count, out_dir=Path(cli.out_dir)),
            indent=2,
        )
    )
|