| """Smoke tests for the training pipeline (rollout, dry-run trainers)."""
|
| import json
|
| import tempfile
|
| from pathlib import Path
|
|
|
| from forgeenv.env.forge_environment import ForgeEnvironment
|
| from forgeenv.training.grpo_repair import run_grpo
|
| from forgeenv.training.grpo_drift import run_drift_grpo_dry_run
|
| from forgeenv.training.rollout import (
|
| baseline_oracle_repair_generate,
|
| rollout_one_episode,
|
| )
|
|
|
|
|
def test_rollout_one_episode_baseline_no_op_repair():
    """A rollout with no ``repair_generate`` callback yields a populated result.

    Only the shape of the result is checked here (non-empty identifiers, a float
    visible reward, and an ``executed_cleanly`` entry in the held-out breakdown),
    not the quality of any repair.
    """
    episode = rollout_one_episode(ForgeEnvironment(seed=1))

    # Identifying fields of the sampled task must be truthy (non-empty).
    assert episode.task_id
    assert episode.primitive_type

    # The reward visible to the policy is reported as a float.
    assert isinstance(episode.visible_reward, float)

    # The held-out grading must include the clean-execution signal.
    assert "executed_cleanly" in episode.held_out_breakdown
|
|
|
|
|
def test_rollout_one_episode_with_oracle_repair_succeeds():
    """An oracle repair generator should score well on intent preservation.

    The held-out breakdown must contain an ``intent_preserved`` entry whose
    value exceeds 0.7. Key presence is asserted separately so that a missing
    key is reported as such instead of masquerading as a low (0.0) score.
    """
    env = ForgeEnvironment(seed=2)
    repair_gen = baseline_oracle_repair_generate(env)
    result = rollout_one_episode(env, repair_generate=repair_gen)

    breakdown = result.held_out_breakdown
    assert "intent_preserved" in breakdown, (
        f"held-out breakdown has no 'intent_preserved' key; keys: {list(breakdown)}"
    )
    assert breakdown["intent_preserved"] > 0.7, breakdown
|
|
|
|
|
def test_grpo_repair_dry_run_smoke():
    """A dry-run GRPO repair run writes one numeric reward per episode.

    ``run_grpo`` is called with a placeholder base model and
    ``use_unsloth=False``. NOTE(review): presumably these arguments select the
    dry-run path that writes ``dry_run_rewards.json``; confirm against
    ``run_grpo``. The test then checks that the file exists, holds a JSON list
    with one entry per episode, and that every entry is an int or float
    (JSON booleans rejected).
    """
    total_episodes = 5
    with tempfile.TemporaryDirectory() as tmp:
        run_grpo(
            base_model="(unused-in-dry-run)",
            adapter_path=None,
            output_dir=tmp,
            total_episodes=total_episodes,
            group_size=2,
            seed=0,
            use_unsloth=False,
        )

        rewards_path = Path(tmp) / "dry_run_rewards.json"
        assert rewards_path.exists()

        # JSON is UTF-8 by spec; don't rely on the platform default encoding.
        rewards = json.loads(rewards_path.read_text(encoding="utf-8"))
        assert isinstance(rewards, list), type(rewards)
        assert len(rewards) == total_episodes

        # bool subclasses int, so exclude it explicitly: a JSON true/false is
        # not a reward value.
        assert all(
            isinstance(r, (int, float)) and not isinstance(r, bool)
            for r in rewards
        ), rewards
|
|
|
|
|
def test_grpo_drift_dry_run_smoke():
    """A drift GRPO dry run logs one entry per episode with bounded rewards.

    ``run_drift_grpo_dry_run`` must write ``drift_dry_run.json`` containing a
    JSON list with one entry per episode. Each entry must have ``rewards`` and
    ``candidates`` keys, and every reward must lie in the closed range
    [0.0, 2.0].
    """
    total_episodes = 3
    with tempfile.TemporaryDirectory() as tmp:
        run_drift_grpo_dry_run(
            output_dir=tmp, total_episodes=total_episodes, group_size=2, seed=0
        )

        log_path = Path(tmp) / "drift_dry_run.json"
        assert log_path.exists()

        # JSON is UTF-8 by spec; don't rely on the platform default encoding.
        log = json.loads(log_path.read_text(encoding="utf-8"))
        assert isinstance(log, list), type(log)
        assert len(log) == total_episodes

        for index, entry in enumerate(log):
            # Separate asserts so a failure names the missing key.
            assert "rewards" in entry, f"log entry {index} has no 'rewards'"
            assert "candidates" in entry, f"log entry {index} has no 'candidates'"
            rewards = entry["rewards"]
            assert all(0.0 <= r <= 2.0 for r in rewards), (index, rewards)
|
|
|