"""Tests for the simulation-mode executor."""
from forgeenv.sandbox.simulation_mode import SimulationExecutor
from forgeenv.tasks.models import Task

VALID_HF = """
from transformers import Trainer, TrainingArguments
from datasets import load_dataset
import torch

dataset = load_dataset("glue", "sst2")
trainer = Trainer(model=None, args=None, train_dataset=dataset)
trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
"""

SYNTAX_ERROR = "def foo(\n  broken"

OS_IMPORT = "import os\nos.listdir('.')"


def _task(content: str) -> Task:
    return Task(
        task_id="t",
        description="d",
        script_content=content,
        difficulty="easy",
    )


def test_valid_script_can_succeed():
    """With seed 0, the valid HF script either succeeds or fails cleanly."""
    executor = SimulationExecutor(seed=0)
    result = executor.execute(VALID_HF, _task(VALID_HF))
    # Either succeeds (exit 0 with TRAINING_COMPLETE) or fails with a realistic
    # HF error; never crashes or returns an empty result.
    assert result.exit_code in (0, 1)
    if result.exit_code == 0:
        assert "TRAINING_COMPLETE" in result.stdout


def test_syntax_error_fails():
    """A script that does not parse exits with 1 and reports SyntaxError on stderr."""
    executor = SimulationExecutor(seed=0)
    result = executor.execute(SYNTAX_ERROR, _task(SYNTAX_ERROR))
    assert result.exit_code == 1
    assert "SyntaxError" in result.stderr


def test_forbidden_import_fails():
    """A script that imports os exits with 1 and reports a validation failure."""
    executor = SimulationExecutor(seed=0)
    result = executor.execute(OS_IMPORT, _task(OS_IMPORT))
    assert result.exit_code == 1
    assert "Validation failed" in result.stderr


def test_simulation_is_fast():
    """Simulation mode targets <100ms of real elapsed time per call.

    The reported wall_time_ms field includes a synthetic delay, so we measure
    real elapsed time at this layer instead. The assertion uses a looser
    200ms bound than the 100ms target.
    """
    import time
    executor = SimulationExecutor(seed=0)
    # perf_counter is monotonic and intended for timing short intervals.
    t0 = time.perf_counter()
    executor.execute(VALID_HF, _task(VALID_HF))
    elapsed_ms = (time.perf_counter() - t0) * 1000
    assert elapsed_ms < 200, f"Simulation took {elapsed_ms:.1f}ms"


def test_seed_is_deterministic():
    """Two executors built with the same seed return the same exit code and stderr."""
    e1 = SimulationExecutor(seed=42)
    e2 = SimulationExecutor(seed=42)
    r1 = e1.execute(VALID_HF, _task(VALID_HF))
    r2 = e2.execute(VALID_HF, _task(VALID_HF))
    assert r1.exit_code == r2.exit_code
    assert r1.stderr == r2.stderr