"""Tests for the simulation-mode executor."""

import time

from forgeenv.sandbox.simulation_mode import SimulationExecutor
from forgeenv.tasks.models import Task

# A minimal Hugging Face-style training script that should pass validation
# and, on success, print the TRAINING_COMPLETE sentinel.
VALID_HF = """
from transformers import Trainer, TrainingArguments
from datasets import load_dataset
import torch

dataset = load_dataset("glue", "sst2")
trainer = Trainer(model=None, args=None, train_dataset=dataset)
trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
"""

# Unterminated ``def``: the executor must reject it with a SyntaxError.
SYNTAX_ERROR = "def foo(\n broken"

# Imports ``os``, which the executor's validation step must reject.
OS_IMPORT = "import os\nos.listdir('.')"

# Target latency for one simulated call is <100 ms; the assertion allows 2x
# headroom so slow or heavily loaded CI machines do not cause flaky failures.
_MAX_SIMULATION_MS = 200


def _task(content: str) -> Task:
    """Wrap *content* in a minimal easy-difficulty Task for the executor."""
    return Task(
        task_id="t",
        description="d",
        script_content=content,
        difficulty="easy",
    )


def test_valid_script_can_succeed() -> None:
    """The valid HF script either succeeds or fails cleanly (seed 0).

    The simulator may inject a realistic HF failure, so both outcomes are
    accepted; what matters is that the result is well-formed.
    """
    executor = SimulationExecutor(seed=0)
    result = executor.execute(VALID_HF, _task(VALID_HF))
    # Either succeeds (exit 0 with TRAINING_COMPLETE) or fails with a
    # realistic HF error; never crashes or returns an empty result.
    assert result.exit_code in (0, 1)
    assert result.stdout or result.stderr, "executor returned no output"
    if result.exit_code == 0:
        assert "TRAINING_COMPLETE" in result.stdout


def test_syntax_error_fails() -> None:
    """A script that does not parse is reported as a SyntaxError."""
    executor = SimulationExecutor(seed=0)
    result = executor.execute(SYNTAX_ERROR, _task(SYNTAX_ERROR))
    assert result.exit_code == 1
    assert "SyntaxError" in result.stderr


def test_forbidden_import_fails() -> None:
    """A script importing a forbidden module fails validation."""
    executor = SimulationExecutor(seed=0)
    result = executor.execute(OS_IMPORT, _task(OS_IMPORT))
    assert result.exit_code == 1
    assert "Validation failed" in result.stderr


def test_simulation_is_fast() -> None:
    """Each simulated call should finish well under the latency budget.

    The target is <100 ms per call; the assertion uses ``_MAX_SIMULATION_MS``
    (200 ms) to leave headroom for slow CI. The reported ``wall_time_ms``
    field includes a synthetic delay, so real elapsed time is measured here
    with a monotonic high-resolution clock instead.
    """
    executor = SimulationExecutor(seed=0)
    t0 = time.perf_counter()
    executor.execute(VALID_HF, _task(VALID_HF))
    elapsed_ms = (time.perf_counter() - t0) * 1000
    assert elapsed_ms < _MAX_SIMULATION_MS, f"Simulation took {elapsed_ms:.1f}ms"


def test_seed_is_deterministic() -> None:
    """Two executors built with the same seed produce the same outcome."""
    e1 = SimulationExecutor(seed=42)
    e2 = SimulationExecutor(seed=42)
    r1 = e1.execute(VALID_HF, _task(VALID_HF))
    r2 = e2.execute(VALID_HF, _task(VALID_HF))
    assert r1.exit_code == r2.exit_code
    assert r1.stderr == r2.stderr