"""Smoke tests for the training pipeline (rollout, dry-run trainers)."""

import json
import tempfile
from pathlib import Path

from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.training.grpo_repair import run_grpo
from forgeenv.training.grpo_drift import run_drift_grpo_dry_run
from forgeenv.training.rollout import (
    baseline_oracle_repair_generate,
    rollout_one_episode,
)


def _is_number(value):
    """Return True if *value* is an int or float, but not a bool.

    ``json.loads`` decodes ``true``/``false`` to ``bool``, which is a subclass
    of ``int``; a plain ``isinstance(value, (int, float))`` would accept it.
    """
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def test_rollout_one_episode_baseline_no_op_repair():
    """Roll out one episode without passing a repair generator.

    Checks that the result carries a task id, a primitive type, a float
    visible reward, and an ``executed_cleanly`` held-out entry.
    """
    env = ForgeEnvironment(seed=1)
    result = rollout_one_episode(env)
    assert result.task_id
    assert result.primitive_type
    assert isinstance(result.visible_reward, float)
    assert "executed_cleanly" in result.held_out_breakdown


def test_rollout_one_episode_with_oracle_repair_succeeds():
    """Roll out one episode with the oracle repair generator.

    Expects a high ``intent_preserved`` score in the held-out breakdown.
    """
    env = ForgeEnvironment(seed=2)
    repair_gen = baseline_oracle_repair_generate(env)
    result = rollout_one_episode(env, repair_generate=repair_gen)
    breakdown = result.held_out_breakdown
    # Check presence first so a missing key is reported as such instead of
    # masquerading as a low score through a 0.0 default.
    assert "intent_preserved" in breakdown, (
        f"held_out_breakdown lacks 'intent_preserved': {breakdown!r}"
    )
    # Oracle should usually score well on `intent_preserved` (script identical
    # to original).
    assert breakdown["intent_preserved"] > 0.7


def test_grpo_repair_dry_run_smoke():
    """Run the repair GRPO trainer in dry-run mode and inspect its reward log.

    Expects ``dry_run_rewards.json`` in the output directory, holding one
    numeric reward per episode.
    """
    total_episodes = 5
    with tempfile.TemporaryDirectory() as tmp:
        run_grpo(
            base_model="(unused-in-dry-run)",
            adapter_path=None,
            output_dir=tmp,
            total_episodes=total_episodes,
            group_size=2,
            seed=0,
            use_unsloth=False,
        )
        rewards_path = Path(tmp) / "dry_run_rewards.json"
        assert rewards_path.exists(), f"dry run did not write {rewards_path}"
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale default.
        rewards = json.loads(rewards_path.read_text(encoding="utf-8"))
    # The rewards are already in memory, so the checks can run after cleanup.
    assert len(rewards) == total_episodes
    non_numeric = [r for r in rewards if not _is_number(r)]
    assert not non_numeric, f"non-numeric rewards: {non_numeric!r}"


def test_grpo_drift_dry_run_smoke():
    """Run the drift GRPO dry run and inspect its per-episode log.

    Expects ``drift_dry_run.json`` holding one entry per episode. Each entry
    must have ``rewards`` and ``candidates`` keys, with every reward in
    ``[0.0, 2.0]``.
    """
    total_episodes = 3
    with tempfile.TemporaryDirectory() as tmp:
        run_drift_grpo_dry_run(
            output_dir=tmp, total_episodes=total_episodes, group_size=2, seed=0
        )
        log_path = Path(tmp) / "drift_dry_run.json"
        assert log_path.exists(), f"dry run did not write {log_path}"
        log = json.loads(log_path.read_text(encoding="utf-8"))
    assert len(log) == total_episodes
    for i, entry in enumerate(log):
        # Separate asserts so a failure names the missing key.
        assert "rewards" in entry, f"log entry {i} has no 'rewards'"
        assert "candidates" in entry, f"log entry {i} has no 'candidates'"
        out_of_range = [r for r in entry["rewards"] if not 0.0 <= r <= 2.0]
        assert not out_of_range, (
            f"log entry {i} rewards outside [0.0, 2.0]: {out_of_range!r}"
        )