"""AST-based script validator. Catches forbidden imports and dangerous patterns BEFORE any execution happens. This is a critical defense against reward hacking via system calls, network access, or process manipulation. """ from __future__ import annotations import ast from forgeenv.tasks.models import ValidationResult FORBIDDEN_MODULES = { "os", "subprocess", "socket", "urllib", "requests", "ctypes", "shutil", "signal", "multiprocessing", "threading", } FORBIDDEN_FUNCTIONS = {"eval", "exec", "compile", "__import__"} def validate_script(script_content: str) -> ValidationResult: """Parse a script as AST and reject forbidden patterns. Returns a ValidationResult with `is_valid` and a list of `violations`. """ violations: list[str] = [] try: tree = ast.parse(script_content) except SyntaxError as e: return ValidationResult(is_valid=False, violations=[f"SyntaxError: {e}"]) for node in ast.walk(tree): if isinstance(node, ast.Import): for alias in node.names: module_root = alias.name.split(".")[0] if module_root in FORBIDDEN_MODULES: violations.append(f"Forbidden import: {alias.name}") if isinstance(node, ast.ImportFrom): if node.module: module_root = node.module.split(".")[0] if module_root in FORBIDDEN_MODULES: violations.append(f"Forbidden import from: {node.module}") if isinstance(node, ast.Call): if isinstance(node.func, ast.Name): if node.func.id in FORBIDDEN_FUNCTIONS: violations.append(f"Forbidden call: {node.func.id}()") if isinstance(node.func, ast.Attribute): if node.func.attr in FORBIDDEN_FUNCTIONS: violations.append(f"Forbidden call: .{node.func.attr}()") if isinstance(node, ast.Assign): for target in node.targets: if isinstance(target, ast.Name) and target.id == "__builtins__": violations.append("Forbidden: __builtins__ assignment") return ValidationResult( is_valid=len(violations) == 0, violations=violations, )