File size: 2,174 Bytes
"""Smoke tests for the training pipeline (rollout, dry-run trainers)."""
import json
import tempfile
from pathlib import Path
from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.training.grpo_repair import run_grpo
from forgeenv.training.grpo_drift import run_drift_grpo_dry_run
from forgeenv.training.rollout import (
baseline_oracle_repair_generate,
rollout_one_episode,
)
def test_rollout_one_episode_baseline_no_op_repair():
    """Roll out one episode with only the environment supplied.

    No repair generator is passed, so ``rollout_one_episode`` falls back to
    whatever its default is (a no-op repair, going by the test name --
    confirm against the rollout module). The test only checks that the
    returned result is populated with the fields downstream code relies on.
    """
    episode = rollout_one_episode(ForgeEnvironment(seed=1))

    # Identification fields must be truthy (non-empty).
    assert episode.task_id
    assert episode.primitive_type

    # The visible reward must be a real float, not an int or None.
    assert isinstance(episode.visible_reward, float)

    # The held-out breakdown must report the "executed_cleanly" component.
    assert "executed_cleanly" in episode.held_out_breakdown
def test_rollout_one_episode_with_oracle_repair_succeeds():
    """Roll out one episode using the oracle repair generator.

    The generator is built from the same environment instance that is then
    rolled out, and the resulting ``intent_preserved`` score from the
    held-out breakdown must exceed 0.7 (a missing key counts as 0.0).
    """
    forge_env = ForgeEnvironment(seed=2)
    outcome = rollout_one_episode(
        forge_env,
        repair_generate=baseline_oracle_repair_generate(forge_env),
    )

    # Oracle should usually score well on `intent_preserved`
    # (script identical to original).
    intent_score = outcome.held_out_breakdown.get("intent_preserved", 0.0)
    assert intent_score > 0.7
def test_grpo_repair_dry_run_smoke():
    """Run ``run_grpo`` into a temporary directory and inspect its rewards file.

    After the call, ``dry_run_rewards.json`` must exist in the output
    directory, contain exactly as many entries as episodes requested, and
    every entry must be numeric.
    """
    episode_count = 5

    with tempfile.TemporaryDirectory() as out_dir:
        run_grpo(
            base_model="(unused-in-dry-run)",
            adapter_path=None,
            output_dir=out_dir,
            total_episodes=episode_count,
            group_size=2,
            seed=0,
            use_unsloth=False,
        )

        # All checks stay inside the `with`: the directory is deleted on exit.
        rewards_file = Path(out_dir) / "dry_run_rewards.json"
        assert rewards_file.exists()

        recorded = json.loads(rewards_file.read_text())
        assert len(recorded) == episode_count
        for reward in recorded:
            assert isinstance(reward, (int, float))
def test_grpo_drift_dry_run_smoke():
    """Run ``run_drift_grpo_dry_run`` into a temporary directory and check its log.

    After the call, ``drift_dry_run.json`` must exist in the output
    directory and hold one entry per requested episode. Each entry must
    carry both a "rewards" and a "candidates" key, and every reward must
    lie in the closed interval [0.0, 2.0].
    """
    episode_count = 3

    with tempfile.TemporaryDirectory() as out_dir:
        run_drift_grpo_dry_run(
            output_dir=out_dir,
            total_episodes=episode_count,
            group_size=2,
            seed=0,
        )

        # All checks stay inside the `with`: the directory is deleted on exit.
        log_file = Path(out_dir) / "drift_dry_run.json"
        assert log_file.exists()

        records = json.loads(log_file.read_text())
        assert len(records) == episode_count

        for record in records:
            assert "rewards" in record
            assert "candidates" in record
            for reward in record["rewards"]:
                assert 0.0 <= reward <= 2.0
|