| """Tests for the AST-based forbidden-pattern validator.""" | |
| from forgeenv.sandbox.ast_validator import validate_script | |
def test_clean_script_passes():
    """A plain torch/transformers training script produces no violations."""
    source = """
import torch
from transformers import Trainer
model = Trainer()
"""
    report = validate_script(source)
    assert report.is_valid, f"Clean script should pass: {report.violations}"
def test_os_import_fails():
    """Importing ``os`` (here used to shell out) must be rejected."""
    report = validate_script("import os\nos.system('rm -rf /')")
    assert not report.is_valid
    # At least one violation message should mention the offending module.
    mentions_os = [msg for msg in report.violations if "os" in msg]
    assert mentions_os
def test_subprocess_fails():
    """Importing ``subprocess`` must be rejected, and the report must name it.

    Checking the message (not just ``is_valid``) ensures the script is
    rejected *because of* the forbidden import rather than for some
    unrelated reason.
    """
    script = "import subprocess\nsubprocess.run(['ls'])"
    result = validate_script(script)
    assert not result.is_valid
    assert any("subprocess" in v for v in result.violations), result.violations
def test_eval_fails():
    """A direct call to the ``eval`` builtin must be rejected."""
    report = validate_script("result = eval('1+1')")
    assert not report.is_valid
    # The report should identify ``eval`` as the offending call.
    flagged = [msg for msg in report.violations if "eval" in msg]
    assert flagged
def test_syntax_error_fails():
    """Unparseable source is invalid and reported as a ``SyntaxError``."""
    broken_source = "def foo(\n broken syntax"
    report = validate_script(broken_source)
    assert not report.is_valid
    found_syntax_error = False
    for msg in report.violations:
        if "SyntaxError" in msg:
            found_syntax_error = True
            break
    assert found_syntax_error
def test_transformers_import_passes():
    """Typical fine-tuning imports (transformers, datasets, torch) are allowed.

    The violations are included in the assertion message so a regression
    shows *which* rule fired, mirroring ``test_clean_script_passes``.
    """
    script = """
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import torch
"""
    result = validate_script(script)
    assert result.is_valid, f"ML imports should pass: {result.violations}"
def test_socket_import_fails():
    """Importing ``socket`` (network access) must be rejected, and the report must name it.

    Checking the message (not just ``is_valid``) ensures the rejection is
    due to the forbidden import rather than an unrelated violation.
    """
    script = "import socket\ns = socket.socket()"
    result = validate_script(script)
    assert not result.is_valid
    assert any("socket" in v for v in result.violations), result.violations
def test_builtins_assignment_fails():
    """Rebinding ``__builtins__`` at module level must be rejected."""
    report = validate_script("__builtins__ = {}")
    assert not report.is_valid
def test_attribute_eval_fails():
    """A forbidden name called as an attribute (``obj.exec(...)``) is rejected.

    The validator must not only catch the bare ``exec``/``eval`` builtins but
    also calls whose attribute name is ``exec``, so the check cannot be
    bypassed by reaching the function through another object.
    """
    script = "obj.exec('rm -rf /')"
    result = validate_script(script)
    assert not result.is_valid
    assert any("exec" in v for v in result.violations), result.violations