File size: 2,345 Bytes
b0fbec3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | """AST-based script validator.
Catches forbidden imports and dangerous patterns BEFORE any execution
happens. This is a critical defense against reward hacking via system
calls, network access, or process manipulation.
"""
from __future__ import annotations
import ast
from forgeenv.tasks.models import ValidationResult
# Top-level package names a script may not import. Matching is done on the
# first dotted component only, so "os" also rejects "os.path" and
# "from os import ...".
FORBIDDEN_MODULES = {
    # Process / OS control.
    "os",
    "posix",  # C module behind `os` on POSIX; exposes system(), fork(), exec*()
    "nt",  # C module behind `os` on Windows
    "subprocess",
    "_posixsubprocess",  # C helper behind `subprocess` (fork_exec)
    "pty",  # pty.spawn() runs an arbitrary command
    "signal",
    "multiprocessing",
    "threading",
    "_thread",  # low-level module behind `threading`
    "shutil",
    "ctypes",
    # Dynamic import / builtins access that would sidestep the other checks.
    "importlib",  # importlib.import_module("os") never appears as an import
    "builtins",  # getattr(builtins, "exec") reaches exec without naming it
    # Network access.
    "socket",
    "socketserver",
    "urllib",
    "http",  # http.client
    "ftplib",
    "smtplib",
    "xmlrpc",
    "requests",
}
# NOTE(review): `sys` (via sys.modules), `pickle` (loads() can invoke arbitrary
# callables) and `asyncio` (create_subprocess_*, open_connection) are further
# escape routes. They are left allowed here because legitimate scripts commonly
# use them — confirm against the task policy.

# Builtins that evaluate code or perform dynamic imports.
FORBIDDEN_FUNCTIONS = {"eval", "exec", "compile", "__import__"}
def validate_script(script_content: str) -> ValidationResult:
    """Parse a script as AST and reject forbidden patterns.

    The script is never executed; only its syntax tree is inspected. A
    violation is recorded for each of the following:

    * ``import x`` / ``from x import ...`` where the first dotted component
      of ``x`` is in ``FORBIDDEN_MODULES``;
    * a direct call to a name or attribute listed in ``FORBIDDEN_FUNCTIONS``
      (``eval(...)``, ``builtins.exec(...)`` -- and also ``re.compile(...)``,
      because only the attribute name is compared);
    * any other by-name use of a forbidden function, e.g. ``f = eval`` or
      ``map(exec, codes)``, since an alias would otherwise defeat the call
      check;
    * ``__builtins__`` bound in any target position (plain, tuple-unpacking,
      augmented or annotated assignment, ``for``/``with`` targets, ``:=``);
    * any other use of ``__builtins__`` (read or ``del``), e.g.
      ``__builtins__.__dict__["exec"]`` or ``getattr(__builtins__, "eval")``,
      which reach forbidden functions without naming them.

    Args:
        script_content: Python source text to validate.

    Returns:
        A ValidationResult with `is_valid` and a list of `violations`. If the
        source cannot be parsed, `violations` holds a single
        ``"<ErrorType>: <message>"`` entry and `is_valid` is False.
    """
    try:
        tree = ast.parse(script_content)
    except SyntaxError as e:
        return ValidationResult(is_valid=False, violations=[f"SyntaxError: {e}"])
    except (ValueError, RecursionError, MemoryError) as e:
        # Unparseable input must be rejected, not crash the validator:
        # ast.parse raises ValueError for null bytes before Python 3.12, and
        # can raise RecursionError/MemoryError for pathologically deep nesting.
        return ValidationResult(
            is_valid=False, violations=[f"{type(e).__name__}: {e}"]
        )

    # Name nodes that are the callee of a direct call. Those are reported as
    # "Forbidden call" below and must not be reported again as a reference.
    callee_ids = {
        id(node.func)
        for node in ast.walk(tree)
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
    }

    violations: list[str] = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.name.split(".")[0] in FORBIDDEN_MODULES:
                    violations.append(f"Forbidden import: {alias.name}")
        elif isinstance(node, ast.ImportFrom):
            # node.module is None for purely relative imports ("from . import x").
            if node.module and node.module.split(".")[0] in FORBIDDEN_MODULES:
                violations.append(f"Forbidden import from: {node.module}")
        elif isinstance(node, ast.Call):
            func = node.func
            if isinstance(func, ast.Name) and func.id in FORBIDDEN_FUNCTIONS:
                violations.append(f"Forbidden call: {func.id}()")
            elif (
                isinstance(func, ast.Attribute)
                and func.attr in FORBIDDEN_FUNCTIONS
            ):
                violations.append(f"Forbidden call: .{func.attr}()")
        elif isinstance(node, ast.Name):
            if node.id == "__builtins__":
                # Store context covers every binding form, not just `x = ...`.
                if isinstance(node.ctx, ast.Store):
                    violations.append("Forbidden: __builtins__ assignment")
                else:
                    violations.append("Forbidden: __builtins__ access")
            elif (
                node.id in FORBIDDEN_FUNCTIONS
                and isinstance(node.ctx, ast.Load)
                and id(node) not in callee_ids
            ):
                # e.g. `f = eval` or `map(exec, ...)`: an alias for a later call.
                violations.append(f"Forbidden reference: {node.id}")

    return ValidationResult(
        is_valid=len(violations) == 0,
        violations=violations,
    )
|