| """Tests for the AST-based forbidden-pattern validator.""" |
| from forgeenv.sandbox.ast_validator import validate_script |
|
|
|
|
def test_clean_script_passes():
    """A script that only imports torch/transformers must validate cleanly."""
    source = """
import torch
from transformers import Trainer
model = Trainer()
"""
    outcome = validate_script(source)
    assert outcome.is_valid, f"Clean script should pass: {outcome.violations}"
|
|
|
|
def test_os_import_fails():
    """Importing ``os`` and calling ``os.system`` must be rejected, naming ``os``."""
    outcome = validate_script("import os\nos.system('rm -rf /')")
    assert not outcome.is_valid
    assert any("os" in message for message in outcome.violations)
|
|
|
|
def test_subprocess_fails():
    """Importing ``subprocess`` must be rejected, and a violation must name it.

    Checking the violation text (as ``test_os_import_fails`` and
    ``test_eval_fails`` do) keeps this test from passing merely because some
    unrelated rule happened to fire on the script.
    """
    script = "import subprocess\nsubprocess.run(['ls'])"
    result = validate_script(script)
    assert not result.is_valid
    assert any("subprocess" in v for v in result.violations), result.violations
|
|
|
|
def test_eval_fails():
    """A direct call to the ``eval`` builtin must be rejected and reported by name."""
    outcome = validate_script("result = eval('1+1')")
    assert not outcome.is_valid
    assert any("eval" in message for message in outcome.violations)
|
|
|
|
def test_syntax_error_fails():
    """Source that cannot be parsed is reported as a ``SyntaxError`` violation."""
    broken_source = "def foo(\n broken syntax"
    outcome = validate_script(broken_source)
    assert not outcome.is_valid
    assert any("SyntaxError" in entry for entry in outcome.violations)
|
|
|
|
def test_transformers_import_passes():
    """Typical training-script imports (transformers, datasets, torch) are allowed."""
    script = """
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import torch
"""
    result = validate_script(script)
    # Surface the reported violations on failure, matching test_clean_script_passes.
    assert result.is_valid, f"Transformers imports should pass: {result.violations}"
|
|
|
|
def test_socket_import_fails():
    """Importing ``socket`` must be rejected, and a violation must name it.

    As with the ``os``/``eval`` tests, asserting on the violation text ensures
    the rejection is caused by the ``socket`` usage rather than some other rule.
    """
    script = "import socket\ns = socket.socket()"
    result = validate_script(script)
    assert not result.is_valid
    assert any("socket" in v for v in result.violations), result.violations
|
|
|
|
def test_builtins_assignment_fails():
    """Rebinding ``__builtins__`` at module level must be rejected."""
    assert not validate_script("__builtins__ = {}").is_valid
|
|
|
|
def test_attribute_eval_fails():
    """A forbidden name reached via attribute access (``obj.exec(...)``) is flagged.

    Despite the test's name, the forbidden attribute exercised here is ``exec``.
    The intent is that the validator rejects forbidden builtin names even when
    they appear as a method call on an arbitrary object, not only as bare calls.
    (The original notes suggest names that merely *contain* a forbidden word,
    such as ``ast.literal_eval``, should stay allowed; that case is not
    exercised here.)
    """
    script = "obj.exec('rm -rf /')"
    result = validate_script(script)
    assert not result.is_valid
    # Make sure the rejection is attributed to ``exec`` and not some other rule.
    assert any("exec" in v for v in result.violations), result.violations
|
|