| """Warm-start pair generator for both roles. |
| |
| Produces two parallel JSONL files: |
| |
| warmstart/data/repair_pairs.jsonl -- (prompt, completion) for Repair Agent SFT |
| warmstart/data/drift_pairs.jsonl -- (prompt, completion) for Drift Generator SFT |
| |
Each row carries the canonical chat-template fields plus routing metadata:
    {"role_target": ..., "task_id": ..., "primitive_type": ..., "category": ...,
     "messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]}
| |
We aim for at least 50 pairs per role (the hackathon brief requires this
minimum) by combining each seed-corpus script with each breakage-primitive
configuration that actually mutates it. If fewer unique pairs exist, existing
rows are duplicated until each file reaches --target_count (a role with no
pairs at all stays empty). The Drift Generator SFT teaches it to emit clean
JSON; the Repair Agent SFT teaches it to emit canonical unified diffs.
| |
| Usage: |
| python warmstart/generate_pairs.py [--target_count 64] [--out_dir warmstart/data] |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
| from typing import Iterable, Optional |
|
|
| from forgeenv.env.diff_utils import make_unified_diff |
| from forgeenv.primitives.breakage_primitives import ( |
| PRIMITIVE_REGISTRY, |
| parse_breakage_spec, |
| ) |
| from forgeenv.roles.drift_generator import _DEFAULT_PARAMS_BY_TYPE |
| from forgeenv.roles.prompts import ( |
| DRIFT_GENERATOR_SYSTEM_PROMPT, |
| REPAIR_AGENT_SYSTEM_PROMPT, |
| render_drift_generator_prompt, |
| render_repair_agent_prompt, |
| ) |
| from forgeenv.roles.repair_agent import BaselineRepairAgent |
| from forgeenv.tasks.task_sampler import TaskSampler |
|
|
|
|
def _candidate_breakages(script: str) -> list[dict]:
    """Return the default breakage specs that actually change ``script``.

    Every ``(primitive_type, params)`` combination listed in
    ``_DEFAULT_PARAMS_BY_TYPE`` is turned into a spec dict of the form
    ``{"primitive_type": ..., "params": {...}}``. A spec is dropped when
    ``parse_breakage_spec`` rejects it with ``ValueError`` or when applying
    its primitive leaves ``script`` unchanged. Specs are returned in the
    iteration order of ``_DEFAULT_PARAMS_BY_TYPE``.
    """
    candidates: list[dict] = []
    for primitive_type, options in _DEFAULT_PARAMS_BY_TYPE.items():
        for option in options:
            # Copy the params so later edits to a spec cannot leak back into
            # the shared defaults table.
            candidate = {"primitive_type": primitive_type, "params": dict(option)}
            try:
                primitive = parse_breakage_spec(candidate)
            except ValueError:
                # Unparseable default combination: skip it rather than abort.
                continue
            if primitive.apply(script) != script:
                candidates.append(candidate)
    return candidates
|
|
|
|
def _render_pairs_for_task(
    task_id: str,
    script: str,
    library_versions: dict,
    repair_agent: BaselineRepairAgent,
) -> list[dict]:
    """Build matched repair-agent / drift-generator SFT rows for one seed task.

    For each spec from ``_candidate_breakages(script)``, the primitive is
    applied to obtain a broken script and ``repair_agent.repair`` is asked
    for a diff. Specs whose repair comes back empty are skipped entirely, so
    every kept spec contributes exactly two rows, in this order:

    1. a ``"repair_agent"`` row whose assistant turn is the repair diff, and
    2. a ``"drift_generator"`` row whose assistant turn is the spec as
       indented JSON.

    Args:
        task_id: Identifier copied into every emitted row.
        script: The clean seed script to break.
        library_versions: Passed through to both prompt renderers.
        repair_agent: Baseline agent used to produce the target diffs.

    Returns:
        A flat list of row dicts with keys ``role_target``, ``task_id``,
        ``primitive_type``, ``category`` and ``messages``.
    """

    def _chat_row(
        role_target: str,
        spec: dict,
        category,
        system_prompt: str,
        user_prompt: str,
        answer: str,
    ) -> dict:
        # One SFT example in chat-template form; key order is kept stable so
        # the serialized JSONL is reproducible.
        return {
            "role_target": role_target,
            "task_id": task_id,
            "primitive_type": spec["primitive_type"],
            "category": category,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
                {"role": "assistant", "content": answer},
            ],
        }

    rows: list[dict] = []
    for spec in _candidate_breakages(script):
        primitive = parse_breakage_spec(spec)
        broken_script = primitive.apply(script)
        # _candidate_breakages already dropped no-op specs; re-checking guards
        # against a primitive whose output differs between calls.
        if broken_script == script:
            continue

        repair_diff = repair_agent.repair(
            broken_script, breakage_spec=spec, original_script=script
        )
        if not repair_diff:
            # Keep the two files parallel: no repair target, no drift row.
            continue

        # Repair Agent example: broken script in, unified diff out.
        repair_prompt = render_repair_agent_prompt(
            broken_script=broken_script,
            error_trace=f"[simulated] {primitive.description}",
            library_versions=library_versions,
            target_category=primitive.category,
        )
        rows.append(
            _chat_row(
                "repair_agent",
                spec,
                primitive.category,
                REPAIR_AGENT_SYSTEM_PROMPT,
                repair_prompt,
                repair_diff,
            )
        )

        # Drift Generator example: clean script in, breakage spec JSON out.
        # NOTE(review): target_category receives the primitive *type* here,
        # while the repair prompt and the row's "category" use
        # primitive.category -- confirm this asymmetry is intentional.
        drift_prompt = render_drift_generator_prompt(
            script=script,
            target_category=spec["primitive_type"],
            library_versions=library_versions,
        )
        rows.append(
            _chat_row(
                "drift_generator",
                spec,
                primitive.category,
                DRIFT_GENERATOR_SYSTEM_PROMPT,
                drift_prompt,
                json.dumps(spec, indent=2),
            )
        )
    return rows
|
|
|
|
def _pad_cyclically(rows: list[dict], target_count: int) -> list[dict]:
    """Return ``rows`` repeated round-robin until it holds ``target_count`` items.

    Empty lists, and lists already at least ``target_count`` long, are
    returned unchanged (as a shallow copy); nothing is ever truncated.
    """
    if not rows or len(rows) >= target_count:
        return list(rows)
    return [rows[i % len(rows)] for i in range(target_count)]


def _write_jsonl(path: Path, rows: list[dict]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 JSON Lines (one object per line)."""
    with path.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")


def generate_pairs(
    target_count: int = 64, out_dir: Optional[Path] = None
) -> dict[str, int]:
    """Build both warm-start SFT datasets and write them to ``out_dir``.

    Every task in ``TaskSampler().tasks`` is expanded into repair-agent and
    drift-generator rows by ``_render_pairs_for_task``. If a role ends up with
    at least one but fewer than ``target_count`` rows, its rows are repeated
    round-robin (whole-row duplicates, i.e. oversampling) until it reaches
    ``target_count``. Larger sets are written as-is.

    Args:
        target_count: Minimum number of rows per role, applied only when the
            role has at least one unique row.
        out_dir: Output directory (anything ``Path()`` accepts). Defaults to
            ``warmstart/data`` and is created if missing.

    Returns:
        ``{"repair_pairs": <int>, "drift_pairs": <int>}``: the number of rows
        actually written to each JSONL file.

    Side effects:
        Writes ``repair_pairs.jsonl``, ``drift_pairs.jsonl`` and
        ``summary.json`` into ``out_dir``, overwriting any existing files.
    """
    out_dir = Path(out_dir) if out_dir is not None else Path("warmstart/data")
    out_dir.mkdir(parents=True, exist_ok=True)

    # Fixed versions handed to both prompt renderers for every task.
    library_versions = {"transformers": "4.40.0", "datasets": "2.18.0", "trl": "0.10.0"}
    sampler = TaskSampler()
    repair_agent = BaselineRepairAgent()

    repair_pairs: list[dict] = []
    drift_pairs: list[dict] = []
    for task in sampler.tasks:
        for pair in _render_pairs_for_task(
            task_id=task.task_id,
            script=task.script_content,
            library_versions=library_versions,
            repair_agent=repair_agent,
        ):
            if pair["role_target"] == "repair_agent":
                repair_pairs.append(pair)
            else:
                drift_pairs.append(pair)

    # Cycle through *all* unique rows. The previous loop indexed with
    # len(x) % len(x) == 0 and so only ever duplicated the first row.
    repair_pairs = _pad_cyclically(repair_pairs, target_count)
    drift_pairs = _pad_cyclically(drift_pairs, target_count)

    _write_jsonl(out_dir / "repair_pairs.jsonl", repair_pairs)
    _write_jsonl(out_dir / "drift_pairs.jsonl", drift_pairs)

    counts = {"repair_pairs": len(repair_pairs), "drift_pairs": len(drift_pairs)}
    summary = {
        **counts,
        "target_count": target_count,
        # Coverage is computed from repair rows; padding only duplicates rows,
        # so it does not change these sets.
        "primitives_covered": sorted({p["primitive_type"] for p in repair_pairs}),
        "tasks_covered": sorted({p["task_id"] for p in repair_pairs}),
    }
    (out_dir / "summary.json").write_text(
        json.dumps(summary, indent=2),
        encoding="utf-8",
    )
    return counts
|
|
|
|
| def _parse_args() -> argparse.Namespace: |
| parser = argparse.ArgumentParser(description=__doc__) |
| parser.add_argument("--target_count", type=int, default=64) |
| parser.add_argument("--out_dir", type=str, default="warmstart/data") |
| return parser.parse_args() |
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: generate both datasets, then echo the row counts as JSON.
    _cli_args = _parse_args()
    _written = generate_pairs(
        target_count=_cli_args.target_count,
        out_dir=Path(_cli_args.out_dir),
    )
    print(json.dumps(_written, indent=2))
|
|