File size: 2,174 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Smoke tests for the training pipeline (rollout, dry-run trainers)."""
import json
import tempfile
from pathlib import Path

from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.training.grpo_repair import run_grpo
from forgeenv.training.grpo_drift import run_drift_grpo_dry_run
from forgeenv.training.rollout import (
    baseline_oracle_repair_generate,
    rollout_one_episode,
)


def test_rollout_one_episode_baseline_no_op_repair():
    """Rolling out one episode without a repair callback yields a populated result.

    No ``repair_generate`` is passed, which (per the test name) exercises the
    no-op baseline repair path.
    """
    outcome = rollout_one_episode(ForgeEnvironment(seed=1))
    # Identity fields must be present and non-empty.
    assert outcome.task_id
    assert outcome.primitive_type
    # The visible reward is a float; the held-out breakdown reports clean execution.
    assert isinstance(outcome.visible_reward, float)
    assert "executed_cleanly" in outcome.held_out_breakdown


def test_rollout_one_episode_with_oracle_repair_succeeds():
    """The baseline oracle repair generator scores high on intent preservation."""
    env = ForgeEnvironment(seed=2)
    oracle = baseline_oracle_repair_generate(env)
    outcome = rollout_one_episode(env, repair_generate=oracle)
    # The oracle's script is identical to the original, so `intent_preserved`
    # should be high; a missing key is treated as 0.0 and fails the check.
    intent_score = outcome.held_out_breakdown.get("intent_preserved", 0.0)
    assert intent_score > 0.7


def test_grpo_repair_dry_run_smoke():
    """A dry-run of GRPO repair training writes one numeric reward per episode."""
    episodes = 5
    with tempfile.TemporaryDirectory() as tmp:
        run_grpo(
            base_model="(unused-in-dry-run)",
            adapter_path=None,
            output_dir=tmp,
            total_episodes=episodes,
            group_size=2,
            seed=0,
            use_unsloth=False,
        )
        rewards_file = Path(tmp).joinpath("dry_run_rewards.json")
        assert rewards_file.exists()
        rewards = json.loads(rewards_file.read_text())
        # Exactly one reward per requested episode, each an int or float.
        assert len(rewards) == episodes
        for reward in rewards:
            assert isinstance(reward, (int, float))


def test_grpo_drift_dry_run_smoke():
    """A dry-run of drift GRPO logs candidates and bounded rewards for each episode."""
    episodes = 3
    with tempfile.TemporaryDirectory() as tmp:
        run_drift_grpo_dry_run(
            output_dir=tmp, total_episodes=episodes, group_size=2, seed=0
        )
        log_file = Path(tmp).joinpath("drift_dry_run.json")
        assert log_file.exists()
        entries = json.loads(log_file.read_text())
        assert len(entries) == episodes
        for entry in entries:
            assert "rewards" in entry
            assert "candidates" in entry
            # Every logged reward must fall within the closed interval [0.0, 2.0].
            assert all(0.0 <= r <= 2.0 for r in entry["rewards"])