| """Tests for the simulation-mode executor."""
|
| from forgeenv.sandbox.simulation_mode import SimulationExecutor
|
| from forgeenv.tasks.models import Task
|
|
|
# A small Hugging Face-style training script that is syntactically valid.
# On a successful simulated run, test_valid_script_can_succeed expects the
# "TRAINING_COMPLETE" marker printed on its last line to appear in stdout.
# NOTE(review): Trainer(model=None, args=None) could not actually train, so
# this presumably relies on the executor simulating rather than really running
# the script — confirm against SimulationExecutor.
VALID_HF = """

from transformers import Trainer, TrainingArguments

from datasets import load_dataset

import torch



dataset = load_dataset("glue", "sst2")

trainer = Trainer(model=None, args=None, train_dataset=dataset)

trainer.train()

trainer.save_model("/tmp/forge_output/checkpoint")

print("TRAINING_COMPLETE")

"""

# Deliberately unparseable source (unclosed parenthesis); test_syntax_error_fails
# expects exit code 1 with "SyntaxError" in stderr.
SYNTAX_ERROR = "def foo(\n broken"

# Script that imports and uses `os`; test_forbidden_import_fails expects the
# executor to reject it with exit code 1 and "Validation failed" in stderr.
OS_IMPORT = "import os\nos.listdir('.')"
|
|
|
|
|
def _task(content: str) -> Task:
    """Wrap ``content`` in a minimal, easy-difficulty Task with dummy id/description."""
    task_fields = dict(
        task_id="t",
        description="d",
        difficulty="easy",
        script_content=content,
    )
    return Task(**task_fields)
|
|
|
|
|
def test_valid_script_can_succeed():
    """The valid HF script must succeed for at least one seed.

    A given seed may legitimately produce a simulated failure, so successive
    seeds (starting at 0) are tried until one run succeeds.  Every run must
    report exit code 0 or 1, and the successful run must have printed the
    ``TRAINING_COMPLETE`` marker.  If no seed in the range succeeds, the test
    fails, since the script would then never be shown to be able to succeed.
    """
    max_seeds = 20
    for seed in range(max_seeds):
        executor = SimulationExecutor(seed=seed)
        result = executor.execute(VALID_HF, _task(VALID_HF))
        assert result.exit_code in (0, 1), (
            f"seed={seed}: unexpected exit code {result.exit_code}"
        )
        if result.exit_code == 0:
            assert "TRAINING_COMPLETE" in result.stdout
            break
    else:
        # for/else: reached only when the loop never hit `break`.
        raise AssertionError(
            f"valid script never succeeded for seeds 0..{max_seeds - 1}"
        )
|
|
|
|
|
def test_syntax_error_fails():
    """Unparseable source yields exit code 1 and a SyntaxError in stderr."""
    sim = SimulationExecutor(seed=0)
    bad_task = _task(SYNTAX_ERROR)
    outcome = sim.execute(SYNTAX_ERROR, bad_task)

    assert outcome.exit_code == 1
    assert "SyntaxError" in outcome.stderr
|
|
|
|
|
def test_forbidden_import_fails():
    """A script importing ``os`` is rejected during validation."""
    sim = SimulationExecutor(seed=0)
    os_task = _task(OS_IMPORT)
    outcome = sim.execute(OS_IMPORT, os_task)

    assert outcome.exit_code == 1
    assert "Validation failed" in outcome.stderr
|
|
|
|
|
def test_simulation_is_fast():
    """A single simulated execution should complete near-instantly.

    The nominal budget is <100ms per call; the assertion allows up to 200ms
    (presumably headroom for slow or loaded CI machines — confirm intent).

    The reported ``wall_time_ms`` field includes a synthetic delay, so real
    elapsed time is measured at this layer instead.  ``time.perf_counter`` is
    used because it is monotonic and high-resolution, unlike ``time.time``,
    which can jump if the system clock is adjusted.
    """
    import time

    executor = SimulationExecutor(seed=0)
    start = time.perf_counter()
    executor.execute(VALID_HF, _task(VALID_HF))
    elapsed_ms = (time.perf_counter() - start) * 1000
    assert elapsed_ms < 200, f"Simulation took {elapsed_ms:.1f}ms"
|
|
|
|
|
def test_seed_is_deterministic():
    """Two executors seeded identically report the same exit code and stderr."""
    # Construct both executors before running either, matching the original order.
    executors = [SimulationExecutor(seed=42) for _ in range(2)]
    first, second = [ex.execute(VALID_HF, _task(VALID_HF)) for ex in executors]

    assert first.exit_code == second.exit_code
    assert first.stderr == second.stderr
|
|
|