"""Warm-start pair generator for both roles. Produces two parallel JSONL files: warmstart/data/repair_pairs.jsonl -- (prompt, completion) for Repair Agent SFT warmstart/data/drift_pairs.jsonl -- (prompt, completion) for Drift Generator SFT Each row has the canonical chat-template fields: {"messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}], "task_id": ..., "primitive_type": ..., "category": ...} We generate at least 50 pairs (the hackathon brief requires this minimum) by combining each seed-corpus script with each applicable breakage primitive configuration. The Drift Generator SFT teaches it to emit clean JSON; the Repair Agent SFT teaches it to emit canonical unified diffs. Usage: python warmstart/generate_pairs.py [--target_count 64] [--out_dir warmstart/data] """ from __future__ import annotations import argparse import json from pathlib import Path from typing import Iterable, Optional from forgeenv.env.diff_utils import make_unified_diff from forgeenv.primitives.breakage_primitives import ( PRIMITIVE_REGISTRY, parse_breakage_spec, ) from forgeenv.roles.drift_generator import _DEFAULT_PARAMS_BY_TYPE from forgeenv.roles.prompts import ( DRIFT_GENERATOR_SYSTEM_PROMPT, REPAIR_AGENT_SYSTEM_PROMPT, render_drift_generator_prompt, render_repair_agent_prompt, ) from forgeenv.roles.repair_agent import BaselineRepairAgent from forgeenv.tasks.task_sampler import TaskSampler def _candidate_breakages(script: str) -> list[dict]: """Yield breakage specs whose default params we know will mutate `script`.""" out: list[dict] = [] for ptype, param_options in _DEFAULT_PARAMS_BY_TYPE.items(): for params in param_options: spec = {"primitive_type": ptype, "params": dict(params)} try: primitive = parse_breakage_spec(spec) except ValueError: continue mutated = primitive.apply(script) if mutated != script: out.append(spec) return out def _render_pairs_for_task( task_id: str, script: str, library_versions: dict, repair_agent: BaselineRepairAgent, ) -> list[dict]: pairs = [] for spec in _candidate_breakages(script): primitive = parse_breakage_spec(spec) broken = primitive.apply(script) if broken == script: continue diff = repair_agent.repair(broken, breakage_spec=spec, original_script=script) if not diff: continue # Repair Agent pair repair_user = render_repair_agent_prompt( broken_script=broken, error_trace=f"[simulated] {primitive.description}", library_versions=library_versions, target_category=primitive.category, ) pairs.append( { "role_target": "repair_agent", "task_id": task_id, "primitive_type": spec["primitive_type"], "category": primitive.category, "messages": [ {"role": "system", "content": REPAIR_AGENT_SYSTEM_PROMPT}, {"role": "user", "content": repair_user}, {"role": "assistant", "content": diff}, ], } ) # Drift Generator pair (predict the breakage spec from the working script) drift_user = render_drift_generator_prompt( script=script, target_category=spec["primitive_type"], library_versions=library_versions, ) pairs.append( { "role_target": "drift_generator", "task_id": task_id, "primitive_type": spec["primitive_type"], "category": primitive.category, "messages": [ {"role": "system", "content": DRIFT_GENERATOR_SYSTEM_PROMPT}, {"role": "user", "content": drift_user}, {"role": "assistant", "content": json.dumps(spec, indent=2)}, ], } ) return pairs def generate_pairs( target_count: int = 64, out_dir: Optional[Path] = None ) -> dict[str, int]: out_dir = Path(out_dir) if out_dir is not None else Path("warmstart/data") out_dir.mkdir(parents=True, exist_ok=True) library_versions = {"transformers": "4.40.0", "datasets": "2.18.0", "trl": "0.10.0"} sampler = TaskSampler() repair_agent = BaselineRepairAgent() repair_pairs: list[dict] = [] drift_pairs: list[dict] = [] for task in sampler.tasks: for pair in _render_pairs_for_task( task_id=task.task_id, script=task.script_content, library_versions=library_versions, repair_agent=repair_agent, ): if pair["role_target"] == "repair_agent": repair_pairs.append(pair) else: drift_pairs.append(pair) # If we don't have enough, duplicate-with-shuffle. (We never do this in # practice — the corpus produces > 64 pairs — but the safety net is cheap.) while len(repair_pairs) < target_count and repair_pairs: repair_pairs.append(repair_pairs[len(repair_pairs) % len(repair_pairs)]) while len(drift_pairs) < target_count and drift_pairs: drift_pairs.append(drift_pairs[len(drift_pairs) % len(drift_pairs)]) repair_path = out_dir / "repair_pairs.jsonl" drift_path = out_dir / "drift_pairs.jsonl" with repair_path.open("w", encoding="utf-8") as f: for row in repair_pairs: f.write(json.dumps(row) + "\n") with drift_path.open("w", encoding="utf-8") as f: for row in drift_pairs: f.write(json.dumps(row) + "\n") counts = {"repair_pairs": len(repair_pairs), "drift_pairs": len(drift_pairs)} summary_path = out_dir / "summary.json" summary_path.write_text( json.dumps( { **counts, "target_count": target_count, "primitives_covered": sorted( {p["primitive_type"] for p in repair_pairs} ), "tasks_covered": sorted({p["task_id"] for p in repair_pairs}), }, indent=2, ), encoding="utf-8", ) return counts def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--target_count", type=int, default=64) parser.add_argument("--out_dir", type=str, default="warmstart/data") return parser.parse_args() if __name__ == "__main__": args = _parse_args() counts = generate_pairs(target_count=args.target_count, out_dir=Path(args.out_dir)) print(json.dumps(counts, indent=2))