# forgeenv-source / warmstart/generate_pairs.py
# Snapshot header (from the hosting page): uploaded by akhiilll,
# "forgeenv source snapshot for training job", revision a15535e (verified).
"""Warm-start pair generator for both roles.
Produces two parallel JSONL files:
warmstart/data/repair_pairs.jsonl -- (prompt, completion) for Repair Agent SFT
warmstart/data/drift_pairs.jsonl -- (prompt, completion) for Drift Generator SFT
Each row has the canonical chat-template fields:
{"messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}],
"task_id": ..., "primitive_type": ..., "category": ...}
We generate at least 50 pairs (the hackathon brief requires this minimum) by
combining each seed-corpus script with each applicable breakage primitive
configuration. The Drift Generator SFT teaches it to emit clean JSON; the
Repair Agent SFT teaches it to emit canonical unified diffs.
Usage:
python warmstart/generate_pairs.py [--target_count 64] [--out_dir warmstart/data]
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Iterable, Optional
from forgeenv.env.diff_utils import make_unified_diff
from forgeenv.primitives.breakage_primitives import (
PRIMITIVE_REGISTRY,
parse_breakage_spec,
)
from forgeenv.roles.drift_generator import _DEFAULT_PARAMS_BY_TYPE
from forgeenv.roles.prompts import (
DRIFT_GENERATOR_SYSTEM_PROMPT,
REPAIR_AGENT_SYSTEM_PROMPT,
render_drift_generator_prompt,
render_repair_agent_prompt,
)
from forgeenv.roles.repair_agent import BaselineRepairAgent
from forgeenv.tasks.task_sampler import TaskSampler
def _candidate_breakages(script: str) -> list[dict]:
    """Collect breakage specs whose default parameters actually mutate `script`.

    Walks every (primitive_type, params) combination in
    ``_DEFAULT_PARAMS_BY_TYPE`` and keeps a spec only when it parses cleanly
    and applying its primitive to `script` yields text that differs from the
    original (i.e. the breakage is not a no-op for this particular script).
    """
    keep: list[dict] = []
    for primitive_type, option_list in _DEFAULT_PARAMS_BY_TYPE.items():
        for option in option_list:
            candidate = {"primitive_type": primitive_type, "params": dict(option)}
            try:
                prim = parse_breakage_spec(candidate)
            except ValueError:
                # Default combo does not form a valid spec — skip it quietly.
                continue
            if prim.apply(script) != script:
                keep.append(candidate)
    return keep
def _render_pairs_for_task(
    task_id: str,
    script: str,
    library_versions: dict,
    repair_agent: BaselineRepairAgent,
) -> list[dict]:
    """Build chat-format SFT rows (both roles) for one seed-corpus script.

    For each breakage spec that mutates `script`, emits up to two rows:
    a Repair Agent row whose assistant turn is the baseline agent's unified
    diff, and a Drift Generator row whose assistant turn is the breakage
    spec serialized as JSON. Specs for which the baseline agent returns an
    empty diff are skipped entirely (neither row is emitted).
    """
    pairs = []
    for spec in _candidate_breakages(script):
        primitive = parse_breakage_spec(spec)
        broken = primitive.apply(script)
        if broken == script:
            # Defensive re-check: _candidate_breakages already filtered no-ops.
            continue
        diff = repair_agent.repair(broken, breakage_spec=spec, original_script=script)
        if not diff:
            # No usable repair -> no supervision target for either role.
            continue
        # Repair Agent pair
        repair_user = render_repair_agent_prompt(
            broken_script=broken,
            error_trace=f"[simulated] {primitive.description}",
            library_versions=library_versions,
            target_category=primitive.category,
        )
        pairs.append(
            {
                "role_target": "repair_agent",
                "task_id": task_id,
                "primitive_type": spec["primitive_type"],
                "category": primitive.category,
                "messages": [
                    {"role": "system", "content": REPAIR_AGENT_SYSTEM_PROMPT},
                    {"role": "user", "content": repair_user},
                    {"role": "assistant", "content": diff},
                ],
            }
        )
        # Drift Generator pair (predict the breakage spec from the working script)
        drift_user = render_drift_generator_prompt(
            script=script,
            # NOTE(review): the repair prompt above passes primitive.category,
            # but this passes the primitive_type — confirm whether
            # render_drift_generator_prompt really expects a type name here.
            target_category=spec["primitive_type"],
            library_versions=library_versions,
        )
        pairs.append(
            {
                "role_target": "drift_generator",
                "task_id": task_id,
                "primitive_type": spec["primitive_type"],
                "category": primitive.category,
                "messages": [
                    {"role": "system", "content": DRIFT_GENERATOR_SYSTEM_PROMPT},
                    {"role": "user", "content": drift_user},
                    # Assistant target is the spec itself as pretty-printed JSON.
                    {"role": "assistant", "content": json.dumps(spec, indent=2)},
                ],
            }
        )
    return pairs
def generate_pairs(
    target_count: int = 64, out_dir: Optional[Path] = None
) -> dict[str, int]:
    """Generate warm-start SFT pairs for both roles and write them to disk.

    Writes ``repair_pairs.jsonl``, ``drift_pairs.jsonl`` and ``summary.json``
    into *out_dir*.

    Args:
        target_count: Minimum number of rows per output file; if the corpus
            yields fewer, existing rows are duplicated cyclically to reach it.
        out_dir: Output directory (created if missing); defaults to
            ``warmstart/data``.

    Returns:
        Mapping with the final ``repair_pairs`` / ``drift_pairs`` row counts.
    """
    out_dir = Path(out_dir) if out_dir is not None else Path("warmstart/data")
    out_dir.mkdir(parents=True, exist_ok=True)
    # Pinned versions rendered verbatim into every prompt.
    library_versions = {"transformers": "4.40.0", "datasets": "2.18.0", "trl": "0.10.0"}
    sampler = TaskSampler()
    repair_agent = BaselineRepairAgent()
    repair_pairs: list[dict] = []
    drift_pairs: list[dict] = []
    for task in sampler.tasks:
        for pair in _render_pairs_for_task(
            task_id=task.task_id,
            script=task.script_content,
            library_versions=library_versions,
            repair_agent=repair_agent,
        ):
            if pair["role_target"] == "repair_agent":
                repair_pairs.append(pair)
            else:
                drift_pairs.append(pair)
    # Safety net: if the corpus under-produces, pad by cycling over the
    # originally generated rows. (The previous code indexed with
    # `len(pairs) % len(pairs)`, which is always 0, so it only ever
    # duplicated the FIRST row; cycling over the pre-padding base length
    # gives the intended round-robin duplication.)
    base = len(repair_pairs)
    while base and len(repair_pairs) < target_count:
        repair_pairs.append(repair_pairs[len(repair_pairs) % base])
    base = len(drift_pairs)
    while base and len(drift_pairs) < target_count:
        drift_pairs.append(drift_pairs[len(drift_pairs) % base])
    repair_path = out_dir / "repair_pairs.jsonl"
    drift_path = out_dir / "drift_pairs.jsonl"
    with repair_path.open("w", encoding="utf-8") as f:
        for row in repair_pairs:
            f.write(json.dumps(row) + "\n")
    with drift_path.open("w", encoding="utf-8") as f:
        for row in drift_pairs:
            f.write(json.dumps(row) + "\n")
    counts = {"repair_pairs": len(repair_pairs), "drift_pairs": len(drift_pairs)}
    # Coverage summary keyed off the repair rows (drift rows mirror them 1:1).
    summary_path = out_dir / "summary.json"
    summary_path.write_text(
        json.dumps(
            {
                **counts,
                "target_count": target_count,
                "primitives_covered": sorted(
                    {p["primitive_type"] for p in repair_pairs}
                ),
                "tasks_covered": sorted({p["task_id"] for p in repair_pairs}),
            },
            indent=2,
        ),
        encoding="utf-8",
    )
    return counts
def _parse_args() -> argparse.Namespace:
    """Parse CLI flags for a standalone run (module docstring is the help text)."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--target_count", type=int, default=64)
    cli.add_argument("--out_dir", type=str, default="warmstart/data")
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: generate the pair files, then echo the row counts.
    cli_args = _parse_args()
    pair_counts = generate_pairs(
        target_count=cli_args.target_count,
        out_dir=Path(cli_args.out_dir),
    )
    print(json.dumps(pair_counts, indent=2))