# forgeenv-source/tests/test_warmstart.py
# akhiilll's picture
# forgeenv source snapshot for training job
# a15535e verified
"""Smoke tests for warm-start pair generation."""
import json
import tempfile
from pathlib import Path
from warmstart.generate_pairs import generate_pairs
def test_generate_pairs_produces_minimum_count():
    """Check that pair generation meets the requested count and record shape.

    Runs ``generate_pairs`` with ``target_count=50`` into a throwaway
    directory, then verifies that:

    * the returned counts report at least 50 repair and 50 drift pairs;
    * ``repair_pairs.jsonl`` and ``drift_pairs.jsonl`` exist in the output
      directory;
    * the first repair record targets ``repair_agent``, carries exactly
      three messages, and ends with a non-empty message;
    * the first drift record targets ``drift_generator`` and its final
      message content is JSON holding ``primitive_type`` and ``params``.
    """
    # Everything below reads files from the temporary directory, so it all
    # has to run before the context manager removes that directory.
    with tempfile.TemporaryDirectory() as scratch:
        out_root = Path(scratch)
        tally = generate_pairs(target_count=50, out_dir=out_root)

        assert tally["repair_pairs"] >= 50
        assert tally["drift_pairs"] >= 50

        repair_path = out_root / "repair_pairs.jsonl"
        drift_path = out_root / "drift_pairs.jsonl"
        assert repair_path.exists()
        assert drift_path.exists()

        # Only the first JSONL record of each file is inspected (smoke test).
        repair_lines = repair_path.read_text(encoding="utf-8").splitlines()
        repair_record = json.loads(repair_lines[0])
        assert repair_record["role_target"] == "repair_agent"
        assert "messages" in repair_record
        assert len(repair_record["messages"]) == 3
        final_repair_message = repair_record["messages"][-1]
        assert final_repair_message["content"]  # non-empty assistant content

        drift_lines = drift_path.read_text(encoding="utf-8").splitlines()
        drift_record = json.loads(drift_lines[0])
        assert drift_record["role_target"] == "drift_generator"

        # The last drift message's content is itself a JSON document.
        drift_payload = json.loads(drift_record["messages"][-1]["content"])
        assert "primitive_type" in drift_payload
        assert "params" in drift_payload
def test_generate_pairs_covers_multiple_primitive_types():
    """Check that the generation summary reports broad coverage.

    Runs ``generate_pairs`` with ``target_count=50`` into a throwaway
    directory, loads ``summary.json`` from that directory, and requires at
    least five entries under both ``primitives_covered`` and
    ``tasks_covered``.
    """
    with tempfile.TemporaryDirectory() as scratch:
        out_root = Path(scratch)
        generate_pairs(target_count=50, out_dir=out_root)

        # The summary must be read while the temporary directory still exists.
        summary_text = (out_root / "summary.json").read_text(encoding="utf-8")
        report = json.loads(summary_text)

        for coverage_key in ("primitives_covered", "tasks_covered"):
            assert len(report[coverage_key]) >= 5